#include <errno.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <dlfcn.h>
#include "md5.h"
#include "md5.c"
#include <string.h>
#include <stdarg.h>
#define SIZEOF_MD5 (16)
#define MAXLENGTH 8192
#define BLOCKSIZE 4096
#define DEBUG0 (1)

void fs_init() __attribute__((constructor));
void fs_fini() __attribute__((destructor));

/* dlopen() handle for the C library whose functions the wrappers forward to;
** set in fs_init(), released in fs_fini(). */
static void *handle;

/*
** One record per descriptor returned by the open() wrapper, kept in the
** open_file_head list. It caches what is needed to verify and repair the
** file: per-block MD5 digests, one XOR parity block, and the path of the
** metafile (under FSPROTECT_HOME) where both are persisted.
*/
struct open_file {
	/* not referenced by any function in this file */
	md5_context md5_c;
	/* BLOCKSIZE bytes: XOR of the file's whole blocks
	** (see generate_parity_block) */
	uint8 *parity;
	/* block_count * SIZEOF_MD5 bytes: MD5 of each block, in block order */
	uint8 *digest;
	/* size of the digest buffer in bytes (0 until resize_open_file runs) */
	int digest_len;
	/* strdup'd copy of the path passed to open(); used to reopen the file */
	char *file_path;
	/* number of digest slots; set from st_blocks by update_open_file
	** (NOTE(review): st_blocks is not measured in BLOCKSIZE units) */
	int block_count;
	/* st_size as of the last update_open_file() */
	int file_size;
	/* not referenced by any function in this file */
	int last_read;
	/* st_ino of the file; the metafile is named after it */
	long inode;
	/* descriptor this record tracks (lookup key for get_open_file) */
	int fd;
	/* next record in the open_file_head list */
	struct open_file *next;
	/* "<FSPROTECT_HOME>/<inode formatted as %016lx>" */
	char meta_filepath[MAXLENGTH];
	/* 1 when the metafile must be rewritten at close() */
	int dirty;
	/* 1 when more than one corrupt block was found; write() then refuses */
	int unfixable;
};

/* Forward declarations for helpers that are called before their definitions
** (by reconstruct_file and delete_meta_for_path). */
void generate_parity_block(unsigned char *parity_buf, int fd, int skip_block);
void delete_meta(struct open_file *open_file);

/* Head of the singly linked list of tracked descriptors: new_open_file()
** pushes onto the front, get_open_file() searches it by fd. */
struct open_file *open_file_head = NULL;

/* pointers to original, real functions  */
/* They are resolved with dlsym() in fs_init(). The wrappers below call through
** them so that their own I/O bypasses the interposed versions.
** NOTE(review): o_open_2 and o_open_3 both point at libc's variadic open();
** calling a variadic function through a non-variadic pointer type is undefined
** in ISO C. It works on common ABIs, but a single
** int (*)(const char *, int, ...) pointer would be the portable form. */
int (*o_open_2)(const char*,int) = NULL;
int (*o_open_3)(const char*,int, mode_t) = NULL;
ssize_t (*o_read)(int, void *, size_t) = NULL;
ssize_t (*o_write)(int, const void *, size_t) = NULL;
int (*o_close)(int) = NULL;
off_t (*o_lseek)(int, off_t, int) = NULL;
int (*o_truncate)(const char*, off_t) = NULL;
int (*o_ftruncate)(int, off_t) = NULL;
int (*o_unlink)(const char*) = NULL;


/*
** error reports a fatal condition as "Error: <str>" on stderr and terminates
** the process with exit status 1. It never returns.
*/
void error(char *str)
{
	(void) fprintf(stderr, "Error: %s\n", str);
	exit(1);
}
/*
** chkerr treats a negative error_code (the usual -1 failure return of a
** system call) as fatal and reports str through error(), which exits.
** Non-negative codes are ignored.
*/
void chkerr(int error_code, char *str){
	if (error_code >= 0)
		return;
	error(str);
}

/* Directory that holds the metafiles. It is taken from the FSPROTECT_HOME
** environment variable in fs_init(), which exits if that variable is unset. */
static char *home_dir;
/*
** Prototypes of the MD5 routines used by get_md5() (presumably declared in
** md5.h), kept here for reference:
void md5_starts( md5_context *ctx );
void md5_update( md5_context *ctx, uint8 *input, uint32 length );
void md5_finish( md5_context *ctx, uint8 digest[16] );
*/
/*
** fs_resolve_symbol looks up name in the libc handle opened by fs_init().
** The wrappers call these pointers unconditionally, so a missing symbol is
** fatal here rather than a NULL-pointer call later.
*/
static void *fs_resolve_symbol(const char *name)
{
	void *symbol;
	const char *reason;

	(void) dlerror(); /* clear any stale error state before dlsym */
	symbol = dlsym(handle, name);
	if (symbol == NULL) {
		reason = dlerror();
		fprintf(stderr, "Error: could not resolve %s: %s\n", name,
				reason != NULL ? reason : "symbol is NULL");
		exit(1);
	}
	return symbol;
}

/*
** fs_init runs as a shared-object constructor (see its prototype). It opens
** the C library, reads FSPROTECT_HOME into home_dir, and resolves the real
** libc functions that the wrappers in this file forward to. Any failure is
** fatal (exit status 1).
*/
void fs_init()
{
	/* Try the soname first: it resolves to the libc already mapped into the
	** process wherever it is installed (e.g. multiarch lib directories).
	** Fall back to the historical fixed path. */
	static const char *const libc_names[] = { "libc.so.6", "/lib/libc.so.6" };
	size_t i;
	const char *reason;

	handle = NULL;
	for (i = 0; handle == NULL && i < sizeof libc_names / sizeof libc_names[0]; i++)
		handle = dlopen(libc_names[i], RTLD_LAZY);
	if (handle == NULL) {
		reason = dlerror();
		fprintf(stderr, "%s\n", reason != NULL ? reason : "could not dlopen libc");
		exit(1);
	}

	if(!(home_dir = getenv("FSPROTECT_HOME")))
		error("Environment variable FSPROTECT_HOME is not set.");

	/* create some of the "real" functions for later use */
	o_open_2 = (int (*)(const char *, int)) fs_resolve_symbol("open");
	o_open_3 = (int (*)(const char *, int, mode_t)) fs_resolve_symbol("open");
	o_read = (ssize_t (*)(int, void *, size_t)) fs_resolve_symbol("read");
	o_write = (ssize_t (*)(int, const void *, size_t)) fs_resolve_symbol("write");
	o_close = (int (*)(int)) fs_resolve_symbol("close");
	o_lseek = (off_t (*)(int, off_t, int)) fs_resolve_symbol("lseek");
	o_truncate = (int (*)(const char *, off_t)) fs_resolve_symbol("truncate");
	o_ftruncate = (int (*)(int, off_t)) fs_resolve_symbol("ftruncate");
	o_unlink = (int (*)(const char *)) fs_resolve_symbol("unlink");
}

/*
** fs_fini runs as a shared-object destructor (see its prototype) and drops
** the reference to the C library taken in fs_init().
*/
void fs_fini()
{
	void *libc_handle = handle;

	dlclose(libc_handle);
}
/*
** new_open_file allocates a tracking record for pathname, pushes it onto the
** front of the open_file_head list, and returns it.
**
** Every field starts zeroed:
**   - dirty and unfixable are clear;
**   - there is no digest buffer yet (digest NULL, digest_len 0);
**   - parity is an all-zero block that generate_parity_block() can XOR into.
** fd is -1 until the caller assigns a real descriptor. Exits if memory runs out.
*/
struct open_file* new_open_file(char *pathname){
	struct open_file *open_file = calloc(1, sizeof *open_file);

	if (open_file == NULL) {
		perror("new_open_file");
		exit(1);
	}
	open_file->parity = calloc(BLOCKSIZE, sizeof(uint8));
	open_file->file_path = strdup(pathname);
	if (open_file->parity == NULL || open_file->file_path == NULL) {
		perror("new_open_file");
		exit(1);
	}
	open_file->digest = NULL;
	open_file->digest_len = 0;
	open_file->fd = -1;

	/* link in only once the record is fully initialised */
	open_file->next = open_file_head;
	open_file_head = open_file;
	return open_file;
}
/*
** get_open_file returns the first record in the open_file_head list whose fd
** equals fd, or NULL if no record tracks that descriptor.
*/
struct open_file* get_open_file(int fd){
	struct open_file *cursor;

	for (cursor = open_file_head; cursor != NULL; cursor = cursor->next) {
		if (cursor->fd == fd)
			return cursor;
	}
	return NULL;
}

/*
** del_open_file unlinks open_file from the open_file_head list and frees it,
** together with the buffers it owns (parity, digest, file_path).
** A NULL pointer, or a record that is not in the list, is ignored. This way an
** untracked descriptor can neither corrupt the list nor be freed twice.
*/
void del_open_file(struct open_file* open_file){
	struct open_file **link = &open_file_head;

	if (open_file == NULL)
		return;
	/* walk the chain of next-pointers until we reach the one naming open_file */
	while (*link != NULL && *link != open_file)
		link = &(*link)->next;
	if (*link == NULL)
		return;
	*link = open_file->next;

	free(open_file->parity);
	/* digest is only known to be a live allocation once resize_open_file()
	** has given it a non-zero length */
	if (open_file->digest_len > 0)
		free(open_file->digest);
	free(open_file->file_path);
	free(open_file);
}


/*
** resize_open_file makes open_file->digest hold new_block_count MD5 slots
** (SIZEOF_MD5 bytes each) and sets block_count and digest_len to match.
**
** Existing digests are kept, and truncated when shrinking. Newly added slots
** are zero-filled, so comparisons against them are deterministic.
**
** A digest_len of 0 means "no digest buffer owned yet". In that case the old
** pointer is not touched, because new_open_file() may not have initialised it.
** One consequence: a 1-byte buffer left by an earlier resize to 0 blocks is not
** reused. Exits if memory runs out.
*/
void resize_open_file(struct open_file* open_file, int new_block_count){
	size_t old_len = open_file->digest_len > 0 ? (size_t) open_file->digest_len : 0;
	size_t new_len = new_block_count > 0
		? (size_t) new_block_count * SIZEOF_MD5 : 0;
	/* never request 0 bytes: realloc(p, 0) may free p and return NULL */
	size_t alloc_len = new_len > 0 ? new_len : 1;
	uint8 *new_digest;

	if (old_len == 0)
		new_digest = malloc(alloc_len);
	else
		new_digest = realloc(open_file->digest, alloc_len);
	if (new_digest == NULL) {
		perror("resize_open_file");
		exit(1);
	}
	if (new_len > old_len)
		memset(new_digest + old_len, 0, new_len - old_len);

	open_file->digest = new_digest;
	open_file->block_count = new_block_count;
	open_file->digest_len = (int) new_len;
}

/*
** update_open_file points open_file at descriptor fd and refreshes the fields
** derived from fstat(): inode, block_count, file_size, and the metafile path
** "<FSPROTECT_HOME>/<inode as 16 hex digits>".
** Exits if fstat fails, or if that path does not fit in meta_filepath
** (MAXLENGTH bytes).
*/
void update_open_file(struct open_file *open_file, int fd){
	struct stat filestat;
	int path_len;

	open_file->fd = fd;
	chkerr(fstat(fd, &filestat), "open: could not fstat file.");
	open_file->inode = (long) filestat.st_ino;
	/* NOTE(review): st_blocks counts allocated storage in implementation-
	** defined units (512 bytes on Linux), not BLOCKSIZE chunks of data. So
	** block_count is usually larger than ceil(file_size / BLOCKSIZE), and can
	** be smaller for sparse or inline-data files. The digest sizing and the
	** metafile-size check in load_meta() both depend on this value; confirm
	** before changing it. */
	open_file->block_count = (int) filestat.st_blocks;
	open_file->file_size = (int) filestat.st_size;
	/* bounded: FSPROTECT_HOME is external input and may be arbitrarily long */
	path_len = snprintf(open_file->meta_filepath,
			sizeof open_file->meta_filepath,
			"%s/%016lx", home_dir, (unsigned long) open_file->inode);
	chkerr((path_len < 0 ||
			(size_t) path_len >= sizeof open_file->meta_filepath) ? -1 : 0,
			"update_open_file: metafile path too long");
}
/*
** delete_meta_for_path removes the metafile that belongs to the file at path.
** Metafiles are named after the inode, so the file is opened read-only and
** fstat'ed (via update_open_file on a throw-away record) to find it.
** If path cannot be opened (e.g. it does not exist yet), its inode cannot be
** determined and the call does nothing, rather than terminating the process.
*/
void delete_meta_for_path(char *path){
	struct open_file open_file;
	int fd = (*o_open_2)(path, O_RDONLY);

	if (fd < 0)
		return;
	update_open_file(&open_file, fd);
	delete_meta(&open_file);
	chkerr((*o_close)(fd), "dm:close");
}

/*
** parse_metafile splits a metafile image into two views, without copying:
** *md5_digest points at the start of text, and *parity points SIZEOF_MD5
** bytes in.
** NOTE(review): store_meta() writes the parity block first and the digests
** after it, so this layout looks inconsistent with it. The helper is not
** called anywhere in this file.
*/
void parse_metafile(unsigned char *text, unsigned char **md5_digest,
		unsigned char **parity ){
	unsigned char *digest_start = text;
	unsigned char *parity_start = &text[SIZEOF_MD5];

	*md5_digest = digest_start;
	*parity = parity_start;
}
/* Return codes of find_corrupt_block() other than a block index. */
enum {
	CORRUPT_BLOCK_MULTIPLE = -1, /* more than one block differs: cannot repair */
	CORRUPT_BLOCK_NONE = -2      /* every block matched */
};

/*
** find_corrupt_block compares the stored digests (open_file->digest) with
** freshly computed ones (digest1), over block_count slots of SIZEOF_MD5 bytes.
**
** Returns one of:
**   - the index of the single differing block;
**   - CORRUPT_BLOCK_MULTIPLE (-1) if two or more blocks differ;
**   - CORRUPT_BLOCK_NONE (-2) if all blocks match.
** "None" must be distinct from -1: open() marks the file unfixable when it
** sees -1.
*/
int find_corrupt_block(struct open_file* open_file, uint8 *digest1){
	int i;
	int corrupt_block = CORRUPT_BLOCK_NONE;

	for (i = 0; i < open_file->block_count; i++) {
		if (memcmp(open_file->digest + i * SIZEOF_MD5,
				digest1 + i * SIZEOF_MD5, SIZEOF_MD5) == 0)
			continue;
		if (corrupt_block != CORRUPT_BLOCK_NONE)
			return CORRUPT_BLOCK_MULTIPLE;
		corrupt_block = i;
	}
	return corrupt_block;
}

/*
** get_md5 writes the 16-byte MD5 digest of buf[0..length) into result, using
** a fresh md5_context for every call.
*/
void get_md5(unsigned char* result, unsigned char* buf, int length){
	md5_context md5_state;

	md5_starts(&md5_state);
	md5_update(&md5_state, buf, length);
	md5_finish(&md5_state, result);
}

/*
** generate_md5_for_open_file fills digest (block_count * SIZEOF_MD5 bytes)
** with the MD5 of each BLOCKSIZE chunk of open_file->fd, reading from offset 0.
** The last chunk may be shorter and is hashed over the bytes actually read.
** Slots past the end of the data are zero-filled.
** The descriptor's file offset is restored before returning.
** Exits via error() if the file cannot be rewound, or if it holds more data
** than digest has slots for.
*/
void generate_md5_for_open_file(struct open_file *open_file,
		unsigned char *digest){
	unsigned char buf[BLOCKSIZE];
	int block_counter = 0;
	ssize_t bytes_read = 0;
	off_t start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);

	if (open_file->block_count > 0)
		memset(digest, 0, (size_t) open_file->block_count * SIZEOF_MD5);
	if ((*o_lseek)(open_file->fd, 0, SEEK_SET) != 0)
		error("generate_md5: could not seek to beginning");
	if (open_file->file_size <= 0) { /* empty file: nothing to hash */
		(*o_lseek)(open_file->fd, start_position, SEEK_SET);
		return;
	}

	while (block_counter < open_file->block_count &&
	       (bytes_read = (*o_read)(open_file->fd, buf, BLOCKSIZE)) > 0) {
		get_md5(digest + (size_t) block_counter * SIZEOF_MD5,
				buf, (int) bytes_read);
		block_counter++;
	}
	/* Every slot is used: peek one byte to see whether data is left over.
	** Data that exactly fills the slots is fine. */
	if (block_counter >= open_file->block_count)
		bytes_read = (*o_read)(open_file->fd, buf, 1);
	(*o_lseek)(open_file->fd, start_position, SEEK_SET);

	if (bytes_read > 0)
		error("generate_md5: File size not what is expected.");
}

/*
** reconstruct_file rebuilds block bad_block of open_file from parity:
**   1. The stored parity is XORed with every other block read from
**      open_file->fd.
**   2. The result is written back through a separate O_WRONLY descriptor on
**      file_path.
** When the block is the file's partial tail, only the bytes that belong to
** the file are written, so the file is not extended.
** Any I/O failure, including a short write, is fatal (chkerr).
*/
void reconstruct_file(struct open_file *open_file, int bad_block){
	int q_fd;
	unsigned char new_block[BLOCKSIZE];
	off_t offset = (off_t) bad_block * BLOCKSIZE;
	size_t length = BLOCKSIZE;
	ssize_t written;
	struct stat filestat;

	/* copy over the parity data to start off our parity generation func */
	memcpy(new_block, open_file->parity, BLOCKSIZE);
	generate_parity_block(new_block, open_file->fd, bad_block);
	/* new_block now holds the recovered contents of the corrupt block */

	/* Clamp to the current end of file: use fstat rather than the cached
	** file_size, which may be stale after writes. */
	if (fstat(open_file->fd, &filestat) == 0 &&
	    filestat.st_size > offset &&
	    filestat.st_size - offset < BLOCKSIZE)
		length = (size_t) (filestat.st_size - offset);

	chkerr(q_fd=(*o_open_2)(open_file->file_path, O_WRONLY), "rec:open");
	/* compare as off_t: converting a large offset to int could look negative */
	chkerr((*o_lseek)(q_fd, offset, SEEK_SET) < 0 ? -1 : 0, "rec:lseek");
	written = (*o_write)(q_fd, new_block, length);
	chkerr(written == (ssize_t) length ? 0 : -1, "rec:write");
	chkerr((*o_close)(q_fd), "rec:close");
}

/*
** generate_parity_block XORs every whole BLOCKSIZE block of fd into
** parity_buf, reading from offset 0.
**   - Block number skip_block is left out; pass -1 to include all blocks.
**   - A trailing partial block is never included: the loop stops at the first
**     read that is not exactly BLOCKSIZE bytes.
** parity_buf is accumulated into, not cleared. The descriptor's offset is
** restored afterwards, and failure of the initial rewind is fatal.
*/
void generate_parity_block(unsigned char *parity_buf, int fd, int skip_block){
	unsigned char chunk[BLOCKSIZE];
	int index = 0;
	off_t saved_offset = (*o_lseek)(fd, 0, SEEK_CUR);

	chkerr((*o_lseek)(fd, 0, SEEK_SET), "gen_parity: couldn't seek");
	for (; (*o_read)(fd, chunk, BLOCKSIZE) == BLOCKSIZE; index++) {
		if (index == skip_block)
			continue;
		for (size_t b = 0; b < BLOCKSIZE; b++)
			parity_buf[b] ^= chunk[b];
	}
	(*o_lseek)(fd, saved_offset, SEEK_SET);
}


/*
** update_parity folds a pending overwrite of block block_number into
** open_file->parity:
**     parity ^= old_block ^ new_block
** old_block is the BLOCKSIZE bytes currently on disk at that block, read
** through a separate descriptor on file_path. Bytes that cannot be read
** (past end of file, or if the file cannot be opened) count as zero instead
** of leaving uninitialised stack memory in the parity.
*/
void update_parity(int block_number, unsigned char *new_block, struct open_file *open_file){
	unsigned char old_block[BLOCKSIZE] = {0};
	int i;
	int fd = (*o_open_2)(open_file->file_path, O_RDONLY);

	if (fd >= 0) {
		/* off_t arithmetic: block_number * BLOCKSIZE can exceed INT_MAX */
		if ((*o_lseek)(fd, (off_t) block_number * BLOCKSIZE, SEEK_SET) >= 0)
			(void) (*o_read)(fd, old_block, BLOCKSIZE);
		(*o_close)(fd);
	}

	for (i = 0; i < BLOCKSIZE; i++)
		open_file->parity[i] ^= old_block[i] ^ new_block[i];
}


/*
** load_meta loads the metafile named by open_file->meta_filepath:
**   - the first BLOCKSIZE bytes go into open_file->parity;
**   - the next digest_len bytes go into open_file->digest.
** Returns 0 on success. Returns -1, without exiting, when the metafile:
**   - does not exist or cannot be fstat'ed;
**   - is not exactly BLOCKSIZE + digest_len bytes long, i.e. does not match
**     the block_count the digest buffer was sized for;
**   - comes back short.
** The metafile descriptor is closed on every path. A read() error is fatal
** (chkerr).
*/

int load_meta(struct open_file *open_file){
	int fd;
	struct stat filestat;
	ssize_t got;

	if ((fd = (*o_open_2)(open_file->meta_filepath, O_RDONLY)) < 0)
		return -1;
	if (fstat(fd, &filestat) < 0 ||
	    filestat.st_size != (off_t) BLOCKSIZE + open_file->digest_len) {
		(*o_close)(fd);
		return -1;
	}

	got = (*o_read)(fd, open_file->parity, BLOCKSIZE);
	chkerr((int) got, "load_meta: read parity block");
	if (got != BLOCKSIZE) {
		(*o_close)(fd);
		return -1;
	}
	got = (*o_read)(fd, open_file->digest, open_file->digest_len);
	chkerr((int) got, "load_meta: read digest block");
	if (got != open_file->digest_len) {
		(*o_close)(fd);
		return -1;
	}
	chkerr((*o_close)(fd), "load_meta: could not close");
	return 0;
}

/*
** store_meta_write_all writes len bytes of data to fd, retrying short writes
** and EINTR. Any other failure is reported through chkerr() with the message
** what, which exits.
*/
static void store_meta_write_all(int fd, const void *data, size_t len,
		char *what){
	const unsigned char *p = data;

	while (len > 0) {
		ssize_t n = (*o_write)(fd, p, len);
		if (n < 0 && errno == EINTR)
			continue; /* interrupted before anything was written: retry */
		chkerr(n > 0 ? 0 : -1, what);
		p += n;
		len -= (size_t) n;
	}
}

/*
** store_meta (re)writes open_file's metafile at open_file->meta_filepath:
** the BLOCKSIZE-byte parity block, followed by one SIZEOF_MD5-byte digest per
** block.
**
** Before calling, make block_count (and meta_filepath) current:
** load_meta() only accepts a metafile of exactly BLOCKSIZE + digest_len bytes.
**
** If block_count claims more slots than the digest buffer holds, only the
** existing slots are written instead of reading past the buffer. The next
** load_meta() then rejects the short metafile, and the metadata is
** regenerated.
** Any I/O failure is fatal.
*/
void store_meta(struct open_file *open_file){
	int fd;
	size_t digest_bytes = open_file->block_count > 0
		? (size_t) open_file->block_count * SIZEOF_MD5 : 0;
	size_t digest_capacity = open_file->digest_len > 0
		? (size_t) open_file->digest_len : 0;

	if (digest_bytes > digest_capacity)
		digest_bytes = digest_capacity;

	/* mode 0666 before umask: a data file, so no execute bits */
	chkerr((fd=(*o_open_3)(open_file->meta_filepath,
			O_WRONLY | O_CREAT | O_TRUNC,
			S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH | S_IWOTH)),
			"store_meta: open meta file");
	store_meta_write_all(fd, open_file->parity, BLOCKSIZE,
			"store_meta: write parity block");
	store_meta_write_all(fd, open_file->digest, digest_bytes,
			"store_meta: write entire digest");
	chkerr((*o_close)(fd), "store_meta: could not close");
}


/*
** delete_meta removes open_file's metafile (open_file->meta_filepath) with the
** real unlink(). Failure is not fatal: a message is printed to stderr and the
** function returns normally.
*/
void delete_meta(struct open_file *open_file){
	int unlink_status = (*o_unlink)(open_file->meta_filepath);

	if (unlink_status == 0)
		return;
	fprintf(stderr, "the unlink to delete meta didn't work\n");
}


// 
// how to define open()
// 
// ... -> some variable number of arguments beyond pathname, flag
// in this case, open optionally takes one extra mode when a new
// file is created, which specifies the protection bits in that case
int open(const char *pathname, int flags, ...)
{
	int fd; 
	int corrupt_block = -999;
	struct open_file *open_file;
	unsigned char *file_md5_digest;
	//unsigned char *meta_parity;
	open_file = new_open_file( (char*) pathname);

	if( flags & O_TRUNC) delete_meta_for_path( (char*) pathname);
	
	if((fd=(*o_open_2)(pathname, O_RDONLY))<0 && !(flags & O_CREAT))
		return -1; /* doesn't exist */
	
	/* Get stats of file */
	update_open_file(open_file, fd);
	resize_open_file(open_file, open_file->block_count);

	if(load_meta(open_file) < 0){
		/* we have to generate meta data */
		generate_md5_for_open_file(open_file, open_file->digest);
		generate_parity_block(open_file->parity, open_file->fd, -1);
		open_file->dirty = 1;
		store_meta(open_file);
	} else {
		/* check md5 against meta data md5 */
		file_md5_digest = malloc(open_file->block_count 
				* SIZEOF_MD5);
		generate_md5_for_open_file(open_file, file_md5_digest);
		corrupt_block = find_corrupt_block(open_file, file_md5_digest);
		free(file_md5_digest);
		if(corrupt_block >= 0) {
			reconstruct_file(open_file, corrupt_block);
		} else if(corrupt_block == -1){
			// more than one corrupt block was found
			open_file->unfixable = 1;
		}
	}

	if (flags & O_CREAT) { 
		// if O_CREAT is set, look for the mode argument too...
		va_list arg;
		va_start(arg, flags);
		mode_t mode = va_arg(arg, int);
		va_end(arg);

		// do something, perhaps calling the real open
		// with pathname, flags, and mode --
		// e.g., (*openPtr)(pathname, flags, mode);
		fd = (*o_open_3)(pathname, flags, mode);
	} else {
		// do something else, perhaps calling the real open
		// with just pathname and flags --
		// e.g., (*openPtr)(pathname, flags);
		fd = (*o_open_2)(pathname, flags);
	}
	open_file->fd = fd; 

	// take the inode from the stat that we just got, use that as the name of the metadata file
	// need to check to see if there is already metadata
	// and make the metadata if there isn't any yet

	return fd;
}


int close(int fd)
{
	struct open_file *open_file = get_open_file(fd);
	if(open_file->dirty == 1){
		// write out new meta-data for this file
		update_open_file(open_file, open_file->fd);
		store_meta(open_file);
	}
	del_open_file(open_file);
	// now do the close ... with this magic function pointer
	return (*o_close)(fd); 
}

/*
** lseek() interposer: forwards every call unchanged to the real lseek
** resolved in fs_init().
*/
off_t lseek(int filedes, off_t offset, int whence){
	off_t (*real_lseek)(int, off_t, int) = o_lseek;

	return real_lseek(filedes, offset, whence);
}

/*
** write() interposer.
**
** A file marked unfixable is refused with -1/EIO. Otherwise, for every whole
** BLOCKSIZE chunk in buf, two updates happen before the real write() is
** issued:
**   - the stored MD5 of the block the chunk lands on (assuming block-aligned
**     placement from the current offset) is replaced;
**   - the parity is updated with the old ^ new contents.
** Blocks beyond the allocated digest slots only update parity.
** Descriptors not obtained from this library's open() are passed straight to
** the real write().
**
** NOTE(review): the metadata does not reflect writes shorter than BLOCKSIZE,
** the trailing partial chunk, unaligned offsets, or O_APPEND.
*/
ssize_t write(int fd, const void *buf, size_t count){
	struct open_file *open_file = get_open_file(fd);
	const unsigned char *bytes = buf;
	unsigned char new_md5[SIZEOF_MD5];
	int blocks_to_write = (int) (count / BLOCKSIZE);
	off_t start_position;
	int start_block;
	int capacity;
	int i;

	if (open_file == NULL)
		return (*o_write)(fd, buf, count);

	if (open_file->unfixable == 1) {
		errno = EIO;
		return -1;
	}

	open_file->dirty = 1;

	start_position = (*o_lseek)(open_file->fd, 0, SEEK_CUR);
	if (start_position >= 0) {
		start_block = (int) (start_position / BLOCKSIZE);
		capacity = open_file->digest_len / SIZEOF_MD5;
		for (i = 0; i < blocks_to_write; i++) {
			unsigned char *new_block =
				(unsigned char *) bytes + (size_t) i * BLOCKSIZE;
			int block = start_block + i;

			/* store the block's MD5, not its raw bytes */
			get_md5(new_md5, new_block, BLOCKSIZE);
			if (block < capacity)
				memcpy(open_file->digest + (size_t) block * SIZEOF_MD5,
						new_md5, SIZEOF_MD5);
			/* must run before the real write: it reads the old block */
			update_parity(block, new_block, open_file);
		}
		(*o_lseek)(open_file->fd, start_position, SEEK_SET);
	}
	return (*o_write)(fd, buf, count);
}


ssize_t read(int fd, void *buf, size_t count){
	struct open_file open_file = * get_open_file(fd);
	off_t start_position = (*o_lseek)(open_file.fd, 0, SEEK_CUR);
	int block_counter = start_position / BLOCKSIZE;
	int end_block = block_counter + (count>BLOCKSIZE ? count/BLOCKSIZE : 1);
	/* here we read ahead to check if this block is corrupted or not, and
	** possibly reconstruct it if it is */ 
	unsigned char readahead_buf[BLOCKSIZE];
	unsigned char readahead_digest[SIZEOF_MD5];
	unsigned char *current_digest; 
	char have_found_error = 0;

	while(block_counter < end_block){
		/* read ahead one block */
		if((*o_read)(open_file.fd, readahead_buf, BLOCKSIZE)==0)
			break;
		/* get the md5 of that block */
		get_md5(readahead_digest, readahead_buf, BLOCKSIZE);
		current_digest = open_file.digest+(SIZEOF_MD5*block_counter);
		/* Check for an error */
		if(memcmp(current_digest, readahead_digest, SIZEOF_MD5)!=0){
			/* Check if there are multiple errors and therefore this
			** is unfixable  (note, maybe rewind too?)*/
			if(open_file.unfixable || have_found_error) return -1;
			reconstruct_file(&open_file, block_counter);
			have_found_error = 1; 
		}
		block_counter++;
	}
	/* Rewind back to where we started */
	(*o_lseek)(open_file.fd, start_position, SEEK_SET);
	return (*o_read)(open_file.fd, buf, count);
}

/*
** truncate() interposer.
**
** For size > 0, the file is opened through this library's open():
**   - if it is marked unfixable, the call fails with -1/EIO;
**   - when shrinking, the blocks from offset size up to the old end of file
**     are overwritten with zeros through write(), so their digests and parity
**     are updated before the data is cut off. close() then stores the metadata.
** For size == 0, the metafile is simply deleted.
**
** The real truncate() then runs, and its result is returned. Negative sizes
** go straight to the real truncate(), which rejects them. Growing a file
** writes nothing here.
*/
int truncate(const char* pathname, off_t size){
	if (size > 0) {
		int fd = open(pathname, O_WRONLY);

		if (fd >= 0) {
			struct open_file *open_file = get_open_file(fd);

			if (open_file != NULL && open_file->unfixable) {
				close(fd);
				errno = EIO;
				return -1;
			}
			if (open_file != NULL && size < open_file->file_size) {
				char zero[BLOCKSIZE] = {0};
				/* blocks from offset size to the old EOF, rounded up */
				off_t blocks = ((off_t) open_file->file_size - size
						+ BLOCKSIZE - 1) / BLOCKSIZE;

				lseek(fd, size, SEEK_SET);
				while (blocks-- > 0 && write(fd, zero, BLOCKSIZE) > 0)
					;
			}
			close(fd);
		}
	} else if (size == 0) {
		delete_meta_for_path((char *) pathname);
	}
	return (*o_truncate)(pathname, size);
}

/*
** ftruncate() interposer.
**
** The real ftruncate() runs first. On success, this file's metafile is
** deleted. For a non-zero size the metadata is rebuilt immediately, by opening
** and closing the file through this library's open()/close().
** Descriptors not obtained from this library's open() are passed straight
** through.
**
** NOTE(review): the in-memory digest and parity of fd itself are not
** refreshed. If fd is later closed while dirty, close() writes those stale
** values back.
*/
int ftruncate(int fd, off_t size){
	struct open_file *open_file = get_open_file(fd);
	int return_val = (*o_ftruncate)(fd, size);

	if (return_val != 0 || open_file == NULL)
		return return_val;

	delete_meta(open_file);
	/* compare the off_t itself: an (int) cast would wrap at 4 GiB */
	if (size != 0) {
		int fd2 = open(open_file->file_path, O_RDONLY);

		if (fd2 >= 0)
			close(fd2);
	}
	return return_val;
}

/*
** unlink() interposer.
**
** If pathname names a file with exactly one hard link, its metafile is deleted
** before the real unlink() runs. With more links, the inode (and with it the
** metadata) stays in use.
** The file is opened through this library's open() to locate its metafile.
** If that open fails, the real unlink() is still attempted and reports its own
** error.
*/
int unlink(const char* pathname){
	struct stat filestat;
	struct open_file *open_file;
	int fd = open(pathname, O_RDONLY);

	if (fd >= 0) {
		open_file = get_open_file(fd);
		if (open_file != NULL &&
		    fstat(open_file->fd, &filestat) == 0 &&
		    filestat.st_nlink == 1)
			delete_meta(open_file);
		close(fd);
	}
	return (*o_unlink)(pathname);
}


