#include <errno.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <dlfcn.h>
#include "md5.h"
#include "md5.c"
#include <string.h>
#include <stdarg.h>
#define SIZEOF_MD5 (16)
#define META_FILE_SIZE  (SIZEOF_MD5+BLOCKSIZE)
#define MAXLENGTH 8192
#define BLOCKSIZE 4096
#define DEBUG0 (1)

/* fs_init() is registered as a load-time constructor and fs_fini() as a
** destructor, so they run without being called explicitly. */
void fs_init() __attribute__((constructor));
void fs_fini() __attribute__((destructor));



/* dlopen() handle for libc; assigned in fs_init() and passed to dlclose()
** in fs_fini(). */
static void *handle;
extern int errno;

/* Bookkeeping record for one tracked file.  Records are chained through
** `next`, starting at open_file_head, and are looked up by `fd`. */
struct open_file {
	md5_context md5_c;	/* NOTE(review): not referenced by the code visible here - confirm before removing */
	uint8 *parity;		/* BLOCKSIZE-byte buffer allocated in new_open_file(); file blocks are XORed into it */
	uint8 *digest;		/* SIZEOF_MD5 bytes per block; allocated in resize_open_file() */
	int digest_len;		/* size of `digest` in bytes; 0 until resize_open_file() has run */
	char *file_path;	/* strdup() copy of the path handed to new_open_file() */
	int block_count;	/* taken from st_blocks in update_open_file(), or set by resize_open_file() */
	int file_size;		/* st_size as seen by update_open_file() */
	int last_read;		/* NOTE(review): not referenced by the code visible here */
	long inode;		/* st_ino; also used to name the metadata file */
	int fd;			/* descriptor that get_open_file() matches against */
	struct open_file *next;	/* next record in the list, NULL at the tail */
	char meta_filepath[MAXLENGTH];	/* "<FSPROTECT_HOME>/<inode as 16 hex digits>", built in update_open_file() */
	int dirty;		/* set to 1 when close() must write the metadata out again */
};

/* Forward declarations for helpers that are used before their definition. */
void  print_hex(unsigned char *data, int length, char *lbl);
void generate_parity(struct open_file *open_file, int skip_block);

/* Head of the list of tracked files; NULL while the list is empty. */
struct open_file *open_file_head = NULL;

/* pointers to original, real functions  */
/* They start out NULL and are filled in with dlsym() by fs_init().
** o_open_2 and o_open_3 both resolve the symbol "open"; they differ only
** in whether the optional mode argument is passed. */
int (*o_open_2)(const char*,int) = NULL;
int (*o_open_3)(const char*,int, mode_t) = NULL;
ssize_t (*o_read)(int, void *, size_t) = NULL;
ssize_t (*o_write)(int, const void *, size_t) = NULL;
int (*o_close)(int) = NULL;
off_t (*o_lseek)(int, off_t, int) = NULL;

/* Report a fatal problem as "Error: <str>" on stderr and terminate the
** process with exit status 1.  Never returns. */
void error(char *str)
{
	static const char report_format[] = "Error: %s\n";

	fprintf(stderr, report_format, str);
	exit(1);
}
/* Treat any non-zero error_code as fatal: hand str to error(), which exits.
** A zero error_code is a no-op. */
void chkerr(int error_code, char *str)
{
	if (error_code == 0)
		return;

	error(str);
}

/* Value of the FSPROTECT_HOME environment variable, read in fs_init();
** update_open_file() uses it as the directory part of metadata file paths. */
static char *home_dir;
/*
void md5_starts( md5_context *ctx );
void md5_update( md5_context *ctx, uint8 *input, uint32 length );
void md5_finish( md5_context *ctx, uint8 digest[16] );
*/
/*
** Library constructor.  Opens libc, reads FSPROTECT_HOME and resolves the
** real open/read/write/close/lseek so the wrappers in this file can forward
** to them.  Any failure is fatal: the process exits with status 1.
*/
void fs_init()
{
	fprintf(stderr, "fs init\n");

	/* Ask the dynamic linker for libc by soname first, so the lookup does
	** not depend on where the distribution installs it; keep the old
	** absolute path as a fallback. */
	handle = dlopen("libc.so.6", RTLD_LAZY);
	if (handle == NULL)
		handle = dlopen("/lib/libc.so.6", RTLD_LAZY);
	if (handle == NULL) {
		fprintf(stderr, "%s", dlerror());
		exit(1);
	}

	if(!(home_dir = getenv("FSPROTECT_HOME")))
		error("Environment variable FSPROTECT_HOME is not set.");

	/* create some of the "real" functions for later use */
	o_open_2 = (int(*)(const char*, int)) dlsym(handle, "open");
	o_open_3 = (int(*)(const char*, int, mode_t)) dlsym(handle, "open");
	o_read = (ssize_t (*)(int, void *, size_t)) dlsym(handle, "read");
	o_write = (ssize_t (*)(int, const void *, size_t) )
		dlsym(handle, "write");
	o_close = (int(*)(int)) dlsym(handle, "close");
	o_lseek = (off_t (*)(int,off_t,int)) dlsym(handle, "lseek");

	/* Every wrapper calls through these pointers unconditionally, so a
	** missing symbol has to be caught here rather than crash later. */
	if (o_open_2 == NULL || o_open_3 == NULL || o_read == NULL
			|| o_write == NULL || o_close == NULL || o_lseek == NULL)
		error("Could not resolve the real libc I/O functions.");
}

/*
** Library destructor: releases the libc handle obtained in fs_init().
** fs_init() can exit before the handle is set, so a NULL handle is skipped
** instead of being passed to dlclose().
*/
void fs_fini()
{
	fprintf(stderr, "fs fini\n");
	if (handle != NULL) {
		dlclose(handle);
		handle = NULL;
	}
}
struct open_file* new_open_file(char *pathname){
	struct open_file *open_file = malloc(sizeof(struct open_file));
	open_file->next = open_file_head;
	open_file_head = open_file;
	open_file->parity = malloc(BLOCKSIZE * sizeof(uint8));
	open_file->digest_len = 0;
	open_file->file_path = strdup(pathname);
	return open_file;
}
/* Look up the tracking record whose fd field equals fd.
** Returns the first match in list order, or NULL when no record matches. */
struct open_file* get_open_file(int fd)
{
	struct open_file *node;

	for (node = open_file_head; node != NULL; node = node->next) {
		if (node->fd == fd)
			return node;
	}
	return NULL;
}
/*
** Debug aid: dump the address and descriptor of every tracked record to
** stderr, framed by banner lines, then flush stderr.
*/
void print_open_files(){
	struct open_file *open_file;

	fprintf(stderr, "======== Currently Open files ============\n");
	for (open_file = open_file_head; open_file != NULL;
			open_file = open_file->next) {
		/* %lx needs an unsigned long argument, so the pointer is
		** converted explicitly instead of being passed raw. */
		fprintf(stderr, "Open file: %016lx,  fd: %04d \n",
				(unsigned long) open_file, open_file->fd);
	}
	fprintf(stderr, "==========================================\n");
	fflush(stderr);
}
/*
** Unlink open_file from the open-file list and release the record.
**
** A NULL argument, or a record that is not on the list, is ignored.
**
** Only the path copy and the record itself are freed.  parity and digest
** are left alone on purpose: elsewhere in this file parity can be pointed
** at memory that did not come from malloc() and digest may never have been
** allocated, so freeing them here would not be safe.
*/
void del_open_file(struct open_file* open_file){
	struct open_file **link;

	if (open_file == NULL)
		return;

	/* Walk the links rather than the nodes so the head needs no special
	** case. */
	for (link = &open_file_head; *link != NULL; link = &(*link)->next) {
		if (*link != open_file)
			continue;

		*link = open_file->next;
		free(open_file->file_path);
		free(open_file);
		return;
	}
}
      /* void *memcpy(void *dest, const void *src, size_t n); */



void resize_open_file(struct open_file* open_file, int new_block_count){
	uint8 * old_digest = open_file->digest;
	open_file->digest = malloc(sizeof(uint8) * SIZEOF_MD5*new_block_count);
	if(open_file->digest_len != 0)
		memcpy(open_file->digest, old_digest, open_file->digest_len);
	open_file->block_count = new_block_count;
	open_file->digest_len = sizeof(uint8) * SIZEOF_MD5 * new_block_count;
}

void update_open_file(struct open_file *open_file, int fd){
	struct stat filestat;
	open_file->fd = fd; 
	chkerr(fstat(fd, &filestat), "open: could not fstat file.");
	open_file->inode = (long) filestat.st_ino;
	open_file->block_count = (int) filestat.st_blocks; 
	open_file->file_size = (int) filestat.st_size;
	/* create filepath to metafile */
	sprintf(open_file->meta_filepath,
			"%s/%016lx", home_dir, open_file->inode);
}

/* Split a metadata image into its two parts without copying: *md5_digest
** is set to the start of text and *parity to the byte SIZEOF_MD5 further
** on.  Both results point into text. */
void parse_metafile(unsigned char *text, unsigned char **md5_digest,
		unsigned char **parity)
{
	unsigned char *cursor = text;

	*md5_digest = cursor;
	cursor += SIZEOF_MD5;
	*parity = cursor;
}
/*
** Compare open_file's digest table with digest1, one SIZEOF_MD5 entry per
** block for block_count blocks.
**
** Returns the index of the single differing block.  Returns -1 when no
** block differs and also when more than one does; the latter case is
** announced on stderr.
**
** Diagnostics go to stderr, not stdout, so they cannot end up in the
** output of the program this library is loaded into.
*/
int find_corrupt_block(struct open_file* open_file, uint8 *digest1){
	int corrupt_block = -1;
	int i;

	for (i = 0; i < open_file->block_count; i++) {
		size_t offset = (size_t) i * SIZEOF_MD5;

		if (memcmp(open_file->digest + offset, digest1 + offset,
					SIZEOF_MD5) == 0)
			continue;

		if (corrupt_block != -1) {
			fprintf(stderr, "There are more than one block that are corrupt\n");
			return -1;
		}
		corrupt_block = i;
		fprintf(stderr, "Hey this is the corrupt block: %d\n", corrupt_block);
	}
	return corrupt_block;
}

/* Trace helper: write "I'm here: <str>" to stderr and flush it at once. */
void im_here(char * str)
{
	FILE *trace = stderr;

	fprintf(trace, "I'm here: %s\n", str);
	fflush(trace);
}

/* One-shot MD5: hash the length bytes at buf and store the digest in
** result.  result is presumably expected to hold SIZEOF_MD5 bytes - verify
** against md5.h. */
void get_md5(unsigned char* result, unsigned char* buf, int length)
{
	md5_context state;

	md5_starts(&state);
	md5_update(&state, buf, length);
	md5_finish(&state, result);
}

/*
** Fill open_file->digest with one MD5 per BLOCKSIZE chunk of the file,
** reading through open_file->fd from offset 0.  A final short chunk is
** hashed over the bytes actually read.  The descriptor's position is put
** back where it was before returning.
**
** An empty file (file_size <= 0) is left untouched.  Failing to rewind, a
** read error, or finding data beyond block_count blocks is fatal.  A file
** that holds exactly block_count blocks is accepted.
*/
void generate_md5_for_open_file(struct open_file *open_file){
	unsigned char buf[BLOCKSIZE];
	int block_counter = 0;
	ssize_t bytes_read = 0;
	off_t start_position;

	/* empty file just return - checked before the position is moved */
	if(open_file->file_size <= 0) return;

	start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);
	if((*o_lseek)(open_file->fd, 0, SEEK_SET)!=0)
		error("generate_md5: could not seek to beginning");
	if(DEBUG0) fprintf(stderr, "bc:%04d\n", open_file->block_count);

	while((bytes_read = (*o_read)(open_file->fd, buf, BLOCKSIZE)) > 0
			&& block_counter < open_file->block_count) {
		/* Do not dump more bytes than the read delivered. */
		int shown = bytes_read < SIZEOF_MD5 ? (int) bytes_read : SIZEOF_MD5;

		print_hex(buf, shown, "generate_md5: buf");
		get_md5(open_file->digest + (size_t) block_counter * SIZEOF_MD5,
				buf, (int) bytes_read);
		block_counter++;
	}
	(*o_lseek)(open_file->fd, start_position, SEEK_SET);

	if(bytes_read < 0)
		error("generate_md5: read failed.");
	/* bytes_read > 0 here means the loop stopped on the block limit
	** while the file still had data. */
	if(bytes_read > 0)
		error("generate_md5: File size not what is expected.");
}
/*
** Overwrite data_length bytes of the file at pathname, starting at byte
** offset location, with the contents of data.  The file is opened through
** its own write-only descriptor, so no tracked descriptor is disturbed.
**
** Every failure - open, seek, a failed or short write, close - is fatal.
*/
void replace_data(char *pathname, off_t location, uint8 *data, int data_length){
	int fd;
	ssize_t written;

	// TODO investigate "O_DIRECT"
	if((fd = (*o_open_2)(pathname, O_WRONLY))<0)
		error("replace_data: open.");
	if((*o_lseek)(fd, location, SEEK_SET)!=location)
		error("replace_data: Could not lseek to bad block.");

	/* A partial write would leave the block half repaired, so anything
	** short of data_length counts as failure. */
	written = (*o_write)(fd, data, data_length);
	if(written < 0 || written != (ssize_t) data_length)
		error("replace_data: write.");
	chkerr((*o_close)(fd), "replace_data: close");
}

/*
** Rebuild block bad_block of the file from the parity block and write it
** back to disk.
**
** XORing every other full block into the stored parity leaves the original
** contents of the bad block in open_file->parity; that buffer is written to
** the file at bad_block * BLOCKSIZE.  Afterwards the parity is recomputed
** from the repaired file, so open_file->parity holds parity again rather
** than a copy of the repaired block.
**
** ptr_of is only used for the trace output.  I/O failures are fatal through
** the helpers that are called.
*/
void reconstruct(struct open_file **ptr_of, struct open_file *open_file, int bad_block){
	fprintf(stderr, "Point 0: pointer to the old open file: %p\n",
			(void *) *ptr_of);
	fflush(stderr);
	fprintf(stderr, "before GEN PARITYT: open_file:  %p \n",
			(void *) open_file);
	fprintf(stderr, "Point 1: pointer to the old open file: %p\n",
			(void *) *ptr_of);
	fflush(stderr);

	/* generate_parity() is the variant declared ahead of this function
	** and performs the same XOR pass as the debug version. */
	generate_parity(open_file, bad_block);
	fprintf(stderr, "Point 2: pointer to the old open file: %p\n",
			(void *) *ptr_of);
	fflush(stderr);

	/* parity is really at this point the fixed corrupt block */
	replace_data(open_file->file_path, (off_t) bad_block * BLOCKSIZE,
			open_file->parity, BLOCKSIZE);
	fprintf(stderr, "Point 3: pointer to the old open file: %p\n",
			(void *) *ptr_of);
	fflush(stderr);

	/* Turn the buffer back into real parity of the repaired file. */
	memset(open_file->parity, 0, BLOCKSIZE);
	generate_parity(open_file, -1);
}
/* Trace helper for DEBUG_generate_parity(): print checkpoint number point
** together with the record pointer *ptr_of, then flush stderr. */
static void debug_trace_ptr(int point, struct open_file **ptr_of)
{
	fprintf(stderr, "Point %d: pointer to the old open file: %p\n",
			point, (void *) *ptr_of);
	fflush(stderr);
}

/*
** Tracing variant of generate_parity(): XOR every full BLOCKSIZE block of
** the file except block skip_block into open_file->parity, printing *ptr_of
** at numbered checkpoints and dumping the open-file list at the end.
**
** Reads through open_file->fd from offset 0 and puts the descriptor's
** position back afterwards.  Failing to rewind is fatal.
*/
void DEBUG_generate_parity(struct open_file **ptr_of, struct open_file *open_file, int skip_block){
	unsigned char buf[BLOCKSIZE];
	int i;
	int block_counter = 0;
	off_t start_position;

	debug_trace_ptr(4, ptr_of);
	start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);
	debug_trace_ptr(6, ptr_of);
	if((*o_lseek)(open_file->fd, 0, SEEK_SET)!=0)
		error("generate_parity: could not seek to beginning");
	debug_trace_ptr(5, ptr_of);

	while((*o_read)(open_file->fd, buf, BLOCKSIZE) == BLOCKSIZE){
		debug_trace_ptr(7, ptr_of);
		if(block_counter!=skip_block){
			for(i=0; i < BLOCKSIZE; i++){
				open_file->parity[i] ^= buf[i];
			}
		}
		block_counter++;
		debug_trace_ptr(10, ptr_of);
	}
	debug_trace_ptr(11, ptr_of);
	(*o_lseek)(open_file->fd, start_position, SEEK_SET);

	debug_trace_ptr(5, ptr_of);
	print_open_files();
}

/*
** XOR every full BLOCKSIZE block of the file into open_file->parity,
** leaving out block number skip_block (pass -1 to include all blocks).
** A trailing partial block is not included.
**
** The result is accumulated onto whatever parity already holds, so the
** caller decides whether it starts from zeros or from a stored parity.
** Reads through open_file->fd from offset 0 and puts the descriptor's
** position back afterwards; the position is kept in an off_t so it is not
** truncated.  Failing to rewind is fatal.
*/
void generate_parity(struct open_file *open_file, int skip_block){
	unsigned char buf[BLOCKSIZE];
	int i;
	int block_counter = 0;
	off_t start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);

	if((*o_lseek)(open_file->fd, 0, SEEK_SET)!=0)
		error("generate_parity: could not seek to beginning");

	while((*o_read)(open_file->fd, buf, BLOCKSIZE) == BLOCKSIZE){
		if(block_counter!=skip_block){
			for(i=0; i < BLOCKSIZE; i++){
				open_file->parity[i] ^= buf[i];
			}
		}
		block_counter++;
	}
	(*o_lseek)(open_file->fd, start_position, SEEK_SET);
}


/*
** Adjust open_file->parity for block block_number being replaced by the
** BLOCKSIZE bytes at new_block: parity ^= old_block ^ new_block.
**
** The old contents are read from disk through a private read-only
** descriptor, so this must run before the new data is written.  Bytes that
** lie beyond the end of the file count as zero, which makes appending a new
** block work.  If the old contents cannot be obtained (open, seek or read
** fails) a warning is printed and the parity is left unchanged rather than
** being mixed with undefined data.
*/
void update_parity(int block_number, char *new_block, struct open_file *open_file){
	unsigned char old_block[BLOCKSIZE];
	off_t block_offset = (off_t) block_number * BLOCKSIZE;
	ssize_t got = -1;
	int fd;
	int i;

	/* Start from zeros so a short read leaves the tail well defined. */
	memset(old_block, 0, sizeof old_block);

	fd = (*o_open_2)(open_file->file_path, O_RDONLY);
	if (fd >= 0) {
		if ((*o_lseek)(fd, block_offset, SEEK_SET) == block_offset)
			got = (*o_read)(fd, old_block, BLOCKSIZE);
		(*o_close)(fd);
	}
	if (got < 0) {
		fprintf(stderr, "update_parity: could not read old block %d of %s\n",
				block_number, open_file->file_path);
		return;
	}

	for (i = 0; i < BLOCKSIZE; i++) {
		open_file->parity[i] ^= (unsigned char)
			(old_block[i] ^ (unsigned char) new_block[i]);
	}
}

void store_meta(struct open_file *open_file){
	im_here("\nAttempting to store meta data");
	int sizeof_buf = (SIZEOF_MD5 * open_file->block_count) + BLOCKSIZE;
	char buf[sizeof_buf];
	int fd;
	int i = 0;
	for(; i < (SIZEOF_MD5 * open_file->block_count); i++){
		buf[i] = open_file->digest[i];
	}
	for(; i < sizeof_buf; i++){
		buf[i] = open_file->parity[i];
	}
	print_hex(buf, sizeof_buf, open_file->meta_filepath);
	if((fd=(*o_open_3)(open_file->meta_filepath, 
			O_WRONLY | O_CREAT | O_TRUNC,
			S_IRWXU | S_IRWXG | S_IRWXO))<0)
		error("store_meta: could not open");
	fprintf(stderr, "This is the path: %s|%d\n", open_file->meta_filepath,
			fd); 
	if((*o_write)(fd, buf,sizeof_buf)!=sizeof_buf)
		error("store_meta: write");
	im_here("finished writing...");
	fprintf(stderr, "This is the path: %s|%d\n", open_file->meta_filepath,
			fd); 
	chkerr((*o_close)(fd), "store_meta: could not close");
}

// 
// how to define open()
// 
// ... -> some variable number of arguments beyond pathname, flag
// in this case, open optionally takes one extra mode when a new
// file is created, which specifies the protection bits in that case
int open(const char *pathname, int flags, ...)
{
	int fd, meta_fd, corrupt_block;
	struct open_file *open_file;
	unsigned char meta_filecontents[META_FILE_SIZE];
	unsigned char *meta_md5_digest;
	//unsigned char *meta_parity;
	open_file = new_open_file(pathname);

	if((fd=(*o_open_2)(pathname, O_RDONLY))<0 && !(flags & O_CREAT)){
		return -1; /* doesn't exist */
	} else {

		/* TODO add the O_TRUNC check, if it is there, delete the meta data and
		** bypass the rest of this else block */

		// Get stat of file
		update_open_file(open_file, fd);	
		resize_open_file(open_file, open_file->block_count);
		
		im_here("CHECKING OPEN STUFF");
		if((meta_fd=(*o_open_2)(open_file->meta_filepath, O_RDONLY))<0){
			// need to create and generate
			im_here("generatnig md5 etc");
			
			generate_md5_for_open_file(open_file);
			generate_parity(open_file, -1);
			open_file->dirty = 1;
			print_hex(open_file->digest, open_file->digest_len,
					"md5 digest");
		} else {
			if((*o_read)(meta_fd, meta_filecontents, META_FILE_SIZE)
					!= META_FILE_SIZE){
				// corrupt meta file
					im_here("CORRUPT META FILE");
			} else {
				im_here("Checking meta stuff.");
				parse_metafile(meta_filecontents,
						&meta_md5_digest,
						&open_file->parity);
				generate_md5_for_open_file(open_file);

				corrupt_block = find_corrupt_block(open_file,
						meta_md5_digest);
				if(corrupt_block != -1) {
					im_here("There is a corrupt block");
					reconstruct(&open_file, open_file, corrupt_block);
				}
			}

		}
		store_meta(open_file);

	}

	if (flags & O_CREAT) { 
		// if O_CREAT is set, look for the mode argument too...
		va_list arg;
		va_start(arg, flags);
		mode_t mode = va_arg(arg, int);
		va_end(arg);

		// do something, perhaps calling the real open
		// with pathname, flags, and mode --
		// e.g., (*openPtr)(pathname, flags, mode);
		fd = (*o_open_3)(pathname, flags, mode);
	} else {
		// do something else, perhaps calling the real open
		// with just pathname and flags --
		// e.g., (*openPtr)(pathname, flags);
		fd = (*o_open_2)(pathname, flags);
	}
	open_file->fd = fd; 
	fprintf(stderr, "Hey, we just opened %04d\n", fd);

	// take the inode from the stat that we just got, use that as the name of the metadata file
	// need to check to see if there is already metadata
	// and make the metadata if there isn't any yet


	return fd;
}


/*
** Replacement for close(): if fd is tracked, write its metadata out when it
** is marked dirty and drop its tracking record; then perform the real
** close.  Descriptors this library does not track are passed straight
** through.
**
** Returns whatever the real close returns.  The trace line goes to stderr
** so that it does not mix into the program's own output.
*/
int close(int fd)
{
	struct open_file *open_file;

	fprintf(stderr, "closing fd: %d\n", fd);

	open_file = get_open_file(fd);
	if (open_file != NULL) {
		if (open_file->dirty == 1) {
			// write out new meta-data for this file
			store_meta(open_file);
		}
		del_open_file(open_file);
	}

	// now do the close ... with this magic function pointer
	return (*o_close)(fd);
}

/* Replacement for lseek(): a plain pass-through to the real lseek resolved
** in fs_init(); arguments and result are forwarded unchanged. */
off_t lseek(int filedes, off_t offset, int whence)
{
	off_t new_position = (*o_lseek)(filedes, offset, whence);

	return new_position;
}

/*
** Replacement for write(): for a tracked descriptor, update the digest and
** the parity of every whole BLOCKSIZE block contained in buf, then perform
** the real write.  Untracked descriptors, and descriptors whose position
** cannot be queried, are passed straight through.
**
** Block numbers are derived from the current file position.  The digest
** table is grown when the write reaches past the blocks known so far.
** Parity is updated before the real write because update_parity() needs
** the old block contents from disk.
**
** NOTE(review): only count / BLOCKSIZE whole blocks are accounted for, and
** the start position is assumed to be block aligned; partial or unaligned
** writes leave the stored digests stale.
**
** Returns whatever the real write returns.
*/
ssize_t write(int fd, const void *buf, size_t count){
	struct open_file *open_file = get_open_file(fd);
	const unsigned char *data = buf;
	unsigned char new_md5[SIZEOF_MD5];
	off_t start_position;
	int start_block;
	int blocks_to_write;
	int i;

	if (open_file == NULL)
		return (*o_write)(fd, buf, count);

	// set some sort of dirty bit
	open_file->dirty = 1;

	start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);
	if (start_position < 0)
		return (*o_write)(fd, buf, count);

	start_block = (int) (start_position / BLOCKSIZE);
	blocks_to_write = (int) (count / BLOCKSIZE);

	for (i = 0; i < blocks_to_write; i++) {
		/* The helpers take non-const pointers but only read. */
		unsigned char *new_block =
			(unsigned char *) data + (size_t) i * BLOCKSIZE;
		int block = start_block + i;

		if (block >= open_file->block_count)
			resize_open_file(open_file, block + 1);

		/* Store the digest of the block, not the block's own bytes. */
		get_md5(new_md5, new_block, BLOCKSIZE);
		memcpy(open_file->digest + (size_t) block * SIZEOF_MD5,
				new_md5, SIZEOF_MD5);
		update_parity(block, (char *) new_block, open_file);
	}

	return (*o_write)(fd, buf, count);
}

ssize_t read(int fd, void *buf, size_t count){

	/* when reading, we need to make sure that everything checks out before 
	** allowing the libc version to actually give it back to them */

	struct open_file *open_file = get_open_file(fd);
	unsigned char f_buf[BLOCKSIZE];
	unsigned char read_md5[BLOCKSIZE];
	int start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);
	int start_block = start_position / BLOCKSIZE;
	int blocks_to_read = count / BLOCKSIZE;
	int i = 0;
	int reconstructions = 0;
	int size_read;

	for(; i < blocks_to_read; i++){
	  im_here("A35");
	  fprintf(stderr, "open_file:  %08x ",  open_file);
	  fflush(stderr);
	  im_here("A36");
		(*o_read)(open_file->fd, f_buf, BLOCKSIZE);
	  im_here("A34");

	im_here("STUFF HERE");
		// generate a md5 for the block just read in
		get_md5(read_md5 ,f_buf, BLOCKSIZE);
	  im_here("A33");
	im_here("MORE STUFF");
	  im_here("A31");
		print_hex(read_md5, SIZEOF_MD5, "read(): read_md5"); 
	  im_here("A32");
		print_hex(open_file->digest + 
				(SIZEOF_MD5 * (start_block + i)),
				 SIZEOF_MD5, "read(): actual digest"); 
	  im_here("A36");
		print_hex(f_buf, SIZEOF_MD5, "read(): actual read"); 
	  im_here("A37");
		if(reconstructions > 1){
			//there is more than one bad block and that is a problem
			errno = EIO;
			im_here("reconstruct > 1");
			return -1;
		}

		/* compare the md5 of open_file->digest[start_block + i] with 
		** the md5 of buf if they are the same...all good otherwise 
		** need to reconstruct this block */
		if(memcmp(open_file->digest + (SIZEOF_MD5 * (start_block + i)),
					 read_md5, SIZEOF_MD5) != 0){
	  im_here("A38");
			// this block is corrupt and needs to be reconstructed
	  fprintf(stderr, "before RECONSTRUCT: open_file:  %08x \n",  open_file);
	  fflush(stderr);
			 reconstruct( &open_file,	open_file, start_block + i);
	  fprintf(stderr, "after RECONSTRUCT: open_file:  %08x \n",  open_file);
	  fflush(stderr);
	  im_here("A39");
			reconstructions++;
			im_here("reconstructions ++");
		}
	}

	// rewind back to where the libc version of read() needs to start
	(*o_lseek)(open_file->fd, start_position, SEEK_SET);
	size_read = (*o_read)(fd, buf, count);
	fprintf(stderr, "size read: %04d", size_read);
	return size_read;
}


/* Debug dump to stderr: a header line with the length in decimal and the
** label lbl, the first length bytes of data as two-digit hex with no
** separators, and a footer line with the length in hex.  Nothing is printed
** between header and footer when length is zero or negative. */
void  print_hex(unsigned char *data, int length, char *lbl)
{
	const unsigned char *cursor = data;
	int remaining;

	fprintf(stderr, "DATA ----- %08d -------- %s\n", length, lbl);
	for (remaining = length; remaining > 0; remaining--)
		fprintf(stderr, "%02x", *cursor++);
	fprintf(stderr, "\nEND DATA - %08x --------\n", length);
}

/* Placeholder that takes the place of ftruncate(): it does not touch any
** file and simply returns fdij plus length (converted to int).
** NOTE(review): while this definition is in effect, callers of ftruncate()
** get no truncation - confirm that is intended. */
int ftruncate(int fdij, off_t length)
{
	int length_as_int = (int) length;

	return fdij + length_as_int;
}
