/**
 *
 * The Linux misc device. The job of this routine is to transfer control
 * from the master to the SFI component and vice versa. Unfortunately,
 * both routines only pass a name to the mischelp device.
 * Search the code for TODO.
 **/

//#include "../common/MJR_external_functions.h"
#include "common_h.h"

#include "misc.h"

#include "wrappers_nooks.h"
#include "wrappers_alloc.h"
#include "wrappers_misc.h"
#include "rec_lock.h"

#include <linux/pci.h>
#include <linux/bug.h>
#include <linux/kallsyms.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/rtnetlink.h>
#include <linux/lockdep.h>

#include <linux/suspend.h>
/*
 * Processor-state save/restore helpers. They are declared here and defined
 * outside this file. The original declarations had no return type, which
 * relies on implicit int (not valid since C99); disp_user() ignores any
 * return value, so they are declared void.
 * NOTE(review): confirm the return type against the kernel definition.
 */
extern void __restore_processor_state(struct saved_context *);
extern void __save_processor_state(struct saved_context *);

/*
 * Context captured by disp_user() before it calls into SFI code.
 * check_and_set_odft_failure_record() hands out a pointer to its regs member.
 */
static struct saved_context orig_context; 

MODULE_AUTHOR("Asim Kadav");
MODULE_LICENSE("GPL");

// Name of the driver being isolated. It is used as the prefix when building
// the symbol names looked up with kallsyms_lookup_name():
//   "<name>_sfi:__MARSH_WRAP__<fn>", "<name>_stub:__MARSH_WRAP__<fn>",
//   "<name>_stub:<name>_checkpoint" and "<name>_stub:<name>_restore".
// Exactly one definition must be active.
//#define ODFT_DRIVER_NAME "e1000"
#define ODFT_DRIVER_NAME "ens1371"
//#define ODFT_DRIVER_NAME "psmousebase"
//#define ODFT_DRIVER_NAME "r8169"
//#define ODFT_DRIVER_NAME "r8139too"


// Set to 1 by odft_record_failure(); cleared by disp_user() after recovery.
static int odft_failure_record = 0;
// Held around the writes to odft_failure_record.
static spinlock_t odft_failure_lock;
// The misc device registered in init_module().
static struct miscdevice misc_help;

// Shared variables (or related functions) between kernel and misc device
// Provided by kernel:
static pDispKern disp_kern_ptr = NULL;      // Kernel dispatch function
pnooks_hash_table_t g_nooks_table; // Pointer to nooks hash table

// Capacity (in entries) of misc_function_id_map, allocated in init_module().
#define MISC_MAX_LEN 8192
static char **misc_function_id_map; // Defined in the driver kernel module
// NOTE(review): not referenced anywhere in the visible part of this file.
static void *marsh_buffers[10000];
// Number of entries of misc_function_id_map filled in by register_miscfn().
static int misc_function_id_map_len;

// pid of the task currently inside an SFI call made through disp_user();
// 0 when no such call has been recorded.
static int sfi_pid; 
// Name of the function most recently dispatched by disp_user(); points to a
// 100-byte buffer allocated in init_module().
static char* sfi_fn_name;



/*
 * Fill *newregs with a register snapshot.
 *
 * newregs: destination register set; always written.
 * oldregs: if non-NULL, it is copied verbatim into *newregs. If NULL, the
 *          current register contents are stored one by one with inline asm.
 *
 * The asm uses 64-bit x86 instructions and registers (movq, r8-r15, pushfq).
 * NOTE(review): each value is whatever the register holds when its own asm
 * statement executes, so this is not an atomic snapshot.
 */
static inline void crash_setup_regs(struct pt_regs *newregs,
                                    struct pt_regs *oldregs)
{
        if (oldregs) {
                memcpy(newregs, oldregs, sizeof(*newregs));
        } else {
                /* General-purpose registers, stored straight to memory. */
                asm volatile("movq %%rbx,%0" : "=m"(newregs->bx));
                asm volatile("movq %%rcx,%0" : "=m"(newregs->cx));
                asm volatile("movq %%rdx,%0" : "=m"(newregs->dx));
                asm volatile("movq %%rsi,%0" : "=m"(newregs->si));
                asm volatile("movq %%rdi,%0" : "=m"(newregs->di));
                asm volatile("movq %%rbp,%0" : "=m"(newregs->bp));
                asm volatile("movq %%rax,%0" : "=m"(newregs->ax));
                asm volatile("movq %%rsp,%0" : "=m"(newregs->sp));
                asm volatile("movq %%r8,%0" : "=m"(newregs->r8));
                asm volatile("movq %%r9,%0" : "=m"(newregs->r9));
                asm volatile("movq %%r10,%0" : "=m"(newregs->r10));
                asm volatile("movq %%r11,%0" : "=m"(newregs->r11));
                asm volatile("movq %%r12,%0" : "=m"(newregs->r12));
                asm volatile("movq %%r13,%0" : "=m"(newregs->r13));
                asm volatile("movq %%r14,%0" : "=m"(newregs->r14));
                asm volatile("movq %%r15,%0" : "=m"(newregs->r15));
                /* Segment selectors are read through eax (the "=a" output). */
                asm volatile("movl %%ss, %%eax;" :"=a"(newregs->ss));
                asm volatile("movl %%cs, %%eax;" :"=a"(newregs->cs));
                /* Flags are pushed on the stack and popped into memory. */
                asm volatile("pushfq; popq %0" :"=m"(newregs->flags));
                /* Instruction pointer: the address of this statement. */
                newregs->ip = (unsigned long)current_text_addr();
        }
}


// Don't use the full_slab_verify.h
// That's designed for use elsewhere.
// Here we need this function all the time.
//void full_slab_verify(void);
// Declared here, defined outside this file; called by disp_user() at the
// start of recovery.
extern void odft_rollback_klog(void); 
// Fault-injection settings. Each is written by mischelp_ioctl() and read back
// through the matching odft_*_to_*() accessor below.
static int odft_pointer_flip_count = 0;
static int odft_stack_corrupt_count = 0; 
static int odft_stack_really_corrupt_count = 0;

/*
 * Accessor for the pointer-flip setting.
 *
 * Returns the current value of odft_pointer_flip_count, i.e. the last value
 * stored by the FLIP_POINTERS ioctl (0 if none was issued).
 */
int odft_pointer_to_flip(void)
{
    const int pending_flip = odft_pointer_flip_count;

    return pending_flip;
}

/*
 * Accessor for the stack-corruption setting.
 *
 * Returns the current value of odft_stack_corrupt_count, i.e. the last value
 * stored by the CORRUPT_STACK ioctl (0 if none was issued).
 */
int odft_stack_to_corrupt(void)
{
    const int pending_corrupt = odft_stack_corrupt_count;

    return pending_corrupt;
}

/*
 * Accessor for the "really corrupt stack" setting.
 *
 * Returns the current value of odft_stack_really_corrupt_count, i.e. the last
 * value stored by the REALLY_CORRUPT_STACK ioctl (0 if none was issued).
 */
int odft_stack_to_really_corrupt(void)
{
    const int pending_really_corrupt = odft_stack_really_corrupt_count;

    return pending_really_corrupt;
}

/*
 * ioctl handler of the mischelp device: selects which fault to inject.
 *
 * inode, fp: unused.
 * cmd:       FLIP_POINTERS, CORRUPT_STACK or REALLY_CORRUPT_STACK.
 * arg:       value stored in the counter that matches cmd.
 *
 * Returns 0 in every case, including for a cmd that matches none of the
 * three commands (nothing is changed then).
 *
 * arg is an unsigned long, so it is printed with %lu (the previous %d did
 * not match the argument type). The counters are plain ints; the narrowing
 * is made explicit with a cast.
 */
static int mischelp_ioctl(struct inode *inode, struct file *fp,
                          unsigned int cmd, unsigned long arg) {
    //
    // Note:  Use to toggle different types of fault injection.
    //
    if (cmd == FLIP_POINTERS) {
        printk ("Will flip pointer %lu in next SFI call.\n", arg);
        odft_pointer_flip_count = (int) arg;
    }

    if (cmd == CORRUPT_STACK) {
        printk ("Will corrupt stack %lu in next SFI call.\n", arg);
        odft_stack_corrupt_count = (int) arg;
    }

    if (cmd == REALLY_CORRUPT_STACK) {
        printk ("Will really corrupt stack %lu in next SFI call.\n", arg);
        odft_stack_really_corrupt_count = (int) arg;
    }

    return 0;
}



/*
 * File operations of the mischelp device. Only ioctl is provided; every
 * other operation is left at its zero-initialised default.
 */
struct file_operations misc_fops = {
    .owner = THIS_MODULE,
    .ioctl = mischelp_ioctl,
};

// Range hash table interface. Declared here, defined outside this file.
extern void odft_init_range_hashtable(void);
extern void odft_init_klog_hashtable(void);
extern int odft_insert_range_hash (const char *, void *, int);
extern int odft_check_range_hash (const char *, void *);
extern int odft_delete_range_hash (const char *, void *);
extern int odft_delete_range_hash_byname (const char *);
extern int odft_truncate_range_hashtable(void);

// Kernel log ("klog") hash table interface. Declared here, defined outside
// this file. odft_init_klog_hashtable is declared a second time below; the
// repeated declaration is identical to the one above.
extern void odft_init_klog_hashtable(void);
extern int odft_insert_klog_hash(const char *, void *, int, void *, int);
extern void rollback_klog(void);
extern void * odft_check_klog_hash (const char*, int);
extern int odft_truncate_klog_hashtable_by_function(const char* fnname);
/*
void test_klog_hash(void)   {
    void * asim=0xabcddead; 
    char a2[100]  = "kmalloc\0";
    char b2[100]  = "netif_device_attach\0";    
    printk ("Inserted hash %d.\n", odft_insert_klog_hash(a2, sizeof(int), 8, asim, 8));
    printk ("Inserted hash %d.\n", odft_insert_klog_hash(b2, sizeof(int), 8, asim, 8));
    printk ("Retreived %p.\n", odft_check_klog_hash("asim", 1));
    rollback_klog();
}

*/
/*
 * Manual smoke test for the range hash table (the only call to it, in
 * init_module(), is commented out).
 *
 * Re-initialises the table, inserts two entries under the same name, checks
 * several addresses, deletes by name and checks again. Every result is
 * printed with printk; nothing is returned or asserted.
 *
 * The addresses are passed as explicit void * values: the hash functions
 * take a void *, and passing a bare integer literal needs a cast.
 */
void test_range_hash(void)  {
    static const char test_name[] = "e1000_get_msglevel";
    void *const first_base  = (void *) 0x10000UL;
    void *const second_base = (void *) 0x20000UL;

    odft_init_range_hashtable();
    printk ("Inserted hash %d.\n", odft_insert_range_hash(test_name, first_base, 25));
    printk ("Inserted hash %d.\n", odft_insert_range_hash(test_name, second_base, 25));
    printk ("Checking hash %d.\n", odft_check_range_hash(test_name, first_base));
    printk ("Checking hash %d.\n", odft_check_range_hash(test_name, (void *) 0x10005UL));
    printk ("Checking hash %d.\n", odft_check_range_hash(test_name, second_base));
    printk ("***Starting delete****\n");
    printk ("Deleted hash %d.\n", odft_delete_range_hash_byname(test_name));
    printk ("Checking hash %d.\n", odft_check_range_hash(test_name, (void *) 0x20005UL));
    printk ("Checking hash %d.\n", odft_check_range_hash(test_name, first_base));
}


int init_module(void){
    int retval = 0;
    static char *mischelp_name = "mischelp";

    // Set up some memory
    // 8192 is really big.
    misc_function_id_map = vmalloc (sizeof (char *) * MISC_MAX_LEN);

    // Defined in rec_lock.c/h
    //init_recursive_lock ();
    g_nooks_table = nooks_ot_create_new_table(4096);
    
    // Open the pipe
    //printk ("miscdevice: before opening\n");
    //MJR_open (WE_ARE_KERNEL_MODE);
    //printk ("miscdevice: after opening\n");

    // Initialize linux kernel junk
    misc_help.minor = MISC_MINOR;
    misc_help.name = mischelp_name;
    misc_help.fops = &misc_fops;
    retval = misc_register(&misc_help);
    
    if (retval)
        return retval;

    // Initialize reange hash table
    odft_init_range_hashtable();
    printk ("Range hashtable initialized.\n");

    odft_init_klog_hashtable();
    printk ("klog hashtable initialized.\n");
    //test_range_hash();
    //test_klog_hash();
    sfi_pid = 0;
    sfi_fn_name = (char *) kmalloc (100, GFP_ATOMIC);
    memset(sfi_fn_name, 0, 100);
    odft_pointer_flip_count = 0;
    register_odft_failure_record(&check_and_set_odft_failure_record);
    register_odft_mm_failure_record(&check_and_set_odft_failure_record);

    printk ("Driver failure function registered.\n");
    spin_lock_init(&odft_failure_lock);

    printk ("Misc device initialized successfully.\n");
    return 0;
}

/*
 * Module exit point: deregisters the mischelp device and releases the memory
 * obtained in init_module().
 *
 * The device is now deregistered before any buffer is freed, so it is no
 * longer reachable while its backing state is being torn down. Freed
 * pointers are reset to NULL. The unused req_args local (left over from the
 * disabled "notify user daemon" code) has been removed.
 *
 * NOTE(review): the callbacks installed with register_odft_failure_record()
 * and register_odft_mm_failure_record() are not removed here; no
 * unregistration call is visible in this file.
 */
void cleanup_module(void){
    int status;

    printk ("miscdevice cleanup\n");

    // Deregister from kernel
    status = misc_deregister(&misc_help);
    if (status < 0) {
        printk ("misc_deregister failed. %d\n", status);
    }

    // Free memory
    vfree (misc_function_id_map);
    misc_function_id_map = NULL;
    kfree (sfi_fn_name);
    sfi_fn_name = NULL;
    nooks_ot_free_table (g_nooks_table);

    printk ("It's over for misc device..\n");
}

/*****************************************************************************
 * Full slab verification
 * ASSUMES we're using SLUB not SLAB
 ******************************************************************************/
/* to be added later 
extern struct kmem_cache kmalloc_caches[PAGE_SHIFT + 1] __cacheline_aligned;
extern struct kmem_cache *skbuff_head_cache;
extern struct kmem_cache *skbuff_fclone_cache;

extern long validate_slab_cache(struct kmem_cache *s);
void full_slab_verify(void) {
    int i;
    
    //printk ("Verifying slabs..\n");
    // Fancy 1-off issue:  kernel uses first cache only with CONFIG_NUMA
    // Verify this code against resiliency_test and other functions
    // defined in slub.c
    for (i = 1; i < PAGE_SHIFT + 1; i++) {
        validate_slab_cache(kmalloc_caches + i);
    }

    validate_slab_cache(skbuff_head_cache);
    validate_slab_cache(skbuff_fclone_cache);
}
*/
/*****************************************************************************
 Centralized printk  
    38 #define KERN_EMERG      "<0>"    system is unusable                  
    39 #define KERN_ALERT      "<1>"    action must be taken immediately    
    40 #define KERN_CRIT       "<2>"    critical conditions                  
    41 #define KERN_ERR        "<3>"    error conditions                    
    42 #define KERN_WARNING    "<4>"    warning conditions                   
    43 #define KERN_NOTICE     "<5>"    normal but significant condition     
    44 #define KERN_INFO       "<6>"    informational                       
    45 #define KERN_DEBUG      "<7>"    debug-level messages                
*****************************************************************************/
/*int uprintk (const char *fmt, ...) {
#ifdef ENABLE_UPRINTK
    va_list args;
    int r;
    char format[UPRINTK_BUFLEN];
    
    strcpy (format + 3, fmt);
    format[0] = '<';
    format[1] = '1'; // 1 = print to console, 6 = don't print.
    format[2] = '>';
    
    va_start(args, fmt);
    r = vprintk(format, args);
    va_end(args);
    
    return r;
#else
    return 0;
#endif
}*/

/*
 * Dispatch a request identified by misc_fnargs->function_id.
 *
 * function_name: name of the requesting function; only used in log and
 *                panic messages.
 * misc_fnargs:   request descriptor; passed on unchanged to the selected
 *                handler.
 *
 * The FIFO_* ids listed in the switch are served by their handle_*()
 * routines (FIFO_RETURN only logs). Every other id is forwarded to the
 * kernel dispatch function registered through register_miscfn(). The
 * function panics if no dispatch function is registered or if it returns
 * non-zero.
 *
 * The name table is only consulted for ids inside
 * [0, misc_function_id_map_len); other ids are logged as "<unknown>" instead
 * of reading outside the table. The unused retval local and the unused
 * "release" label have been removed.
 */
void dispatch_user_request (char *function_name, struct req_args *misc_fnargs) {
    const char *callee_name = "<unknown>";

    switch (misc_fnargs->function_id){
        // Miscellaneous wrappers
        case FIFO_RETURN:
            // The caller, being in the driver, will demarshal everything.
            printk ("Returning to marsh_stub %s ...\r\n", function_name);
            return;
        case FIFO_WRAPPER_PRINTK:
            handle_printk(misc_fnargs);
            return;
        case FIFO_WRAPPER_JIFFIES:
            handle_jiffies (misc_fnargs);
            return;

        // Nooks OT functions
        case FIFO_NOOKS_XLATE_U2K:
            handle_nooks_xlate_u2k (misc_fnargs);
            return;
        case FIFO_NOOKS_XLATE_K2U:
            handle_nooks_xlate_k2u (misc_fnargs);
            return;
        case FIFO_NOOKS_ADD_TO_HASH:
            handle_nooks_add_to_hash (misc_fnargs);
            return;
        case FIFO_NOOKS_DEL_FROM_HASH:
            handle_nooks_del_from_hash (misc_fnargs);
            return;
        case FIFO_NOOKS_DEL_FROM_HASH_REVERSE:
            handle_nooks_del_from_hash_reverse (misc_fnargs);
            return;
        case FIFO_NOOKS_REGISTER_USERFN:
            handle_nooks_register_userfn (misc_fnargs);
            return;
        case FIFO_NOOKS_MEMORY_ASSOC:
            handle_nooks_memory_assoc (misc_fnargs);
            return;
        case FIFO_NOOKS_GET_ARRAY_NUMELTS_REVERSE:
            handle_nooks_get_array_numelts_reverse (misc_fnargs);
            return;
        case FIFO_NOOKS_STORE_ARRAY_NUMELTS_REVERSE:
            handle_nooks_store_array_numelts_reverse (misc_fnargs);
            return;
        case FIFO_NOOKS_RANGE_UPDATE:
            handle_nooks_range_update (misc_fnargs);
            return;
        case FIFO_NOOKS_RANGE_FREE:
            handle_nooks_range_free (misc_fnargs);
            return;
        case FIFO_NOOKS_RANGE_ADD:
            handle_nooks_range_add (misc_fnargs);
            return;
        case FIFO_NOOKS_DEBUG_TRANSLATE_U2K:
            handle_nooks_debug_translate_u2k (misc_fnargs);
            return;

        // All others are forwarded to the kernel dispatch function below.
        default:
            break;
    }

    //
    // Not one of the wrappers or nooks functions handled above.
    // TODO:  Why don't we merge all this together using a single more unified approach?
    // Should be straightforward.  Is a giant switch statement the way to go?  Maybe
    // function pointers?  Ugh.
    //
    if (disp_kern_ptr == NULL) {
        panic ("Oops--the kernel driver stubs are screwed up and have not set up disp_kern!\n");
    }

    // Only look the name up when the id is a valid index into the table.
    if (misc_fnargs->function_id >= 0 &&
        misc_fnargs->function_id < misc_function_id_map_len) {
        callee_name = misc_function_id_map[misc_fnargs->function_id];
    }

    uprintk ("Calling function %s, from %s\r\n", callee_name, function_name);
    if (disp_kern_ptr ("", misc_fnargs) != 0) {
        panic ("Failed to disp_kern with %s from %s!\n",
               callee_name,
               function_name);
    }
}

int disp_user(char *function_name, struct req_args *misc_fnargs) {
    printk ("In disp_user.\n");
    
    if ((function_name == NULL) || (odft_failure_record == 1)) {
	  printk ("Recovery start\n");
	  dump_stack();	
	  __restore_processor_state(&orig_context);
	printk ("Going to sfi failed.\n");
	goto sfi_failed;
    }
   
    __save_processor_state(&orig_context);  
    crash_setup_regs(&orig_context.regs, NULL);	
 
    char fn_name[100];
    int kill_self = 0;
    char ckpt_fn_name[100];
    struct marshret_struct retval; 
    void (*fn_addr)(void *, struct marshret_struct *);
    int (*ckpt_fn_addr)(struct pci_dev *);
    unsigned long flags;
   

    //printk ("Calling into SFI for %s.\n", function_name);

    //sprintf (fn_name, "e1000_sfi:__MARSH_WRAP__%s", function_name);
    sprintf( fn_name, "%s_sfi:__MARSH_WRAP__%s", ODFT_DRIVER_NAME, function_name); 
    fn_addr = (void (*)) kallsyms_lookup_name(fn_name);
    if (fn_addr == NULL)    {
        printk ("Failed to find the SFI function for %s.\n", fn_name);
        panic("Failed to find %s.\n", function_name);
    }

    sfi_pid = get_current()->pid;
    strcpy (sfi_fn_name, function_name);

    
    printk ("Entering SFI %s with pid %d.\n", sfi_fn_name, sfi_pid);
    // Checkpoint device here.
    //
    // I think it would make more sense to call this inside
    // disp_kern.
    //sprintf (ckpt_fn_name, "e1000_stub:%s", "e1000_checkpoint");
    //sprintf (ckpt_fn_name, (ODFT_DRIVER_NAME, "_stub:%s"), strcat(ODFT_DRIVER_NAME, "_checkpoint"));
    sprintf (ckpt_fn_name, "%s_stub:%s_checkpoint", ODFT_DRIVER_NAME, ODFT_DRIVER_NAME);
    ckpt_fn_addr = (int (*)) kallsyms_lookup_name(ckpt_fn_name);
    if (ckpt_fn_addr == NULL)    {
        printk ("Failed to find the checkpoint function for %s.\n", ckpt_fn_name);
        //panic("Failed to checkpoint device.\n");
    }
   
    ckpt_fn_addr(NULL); 

    // Checkpoint kernel? no-need 
    
    //
    // Generate a log for undo with function name, args
    //
    
    //printk ("In disp_user: attempting to call %s at %p.\n", fn_name, fn_addr);
    
    fn_addr(misc_fnargs->data, &retval);
    misc_fnargs->data = retval.buf;
    misc_fnargs->length = retval.len;
    //printk ("Done with disp_user for %s.\n", fn_name);

sfi_failed: 
   
     printk ("Now checking failure record.\n");

    if (odft_failure_record == 1)  {
      printk ("Failure detected. Performing recovery.\n");  
      dump_stack();
      kill_self = 0;
      int depth = 0;	
      // Rollback kernel
      // Cleanup of locks is critical - else hangs device.
      odft_rollback_klog();
      // Rollback device
      sprintf(ckpt_fn_name, "%s_stub:%s_restore",ODFT_DRIVER_NAME, ODFT_DRIVER_NAME); 
      ckpt_fn_addr = (int (*)) kallsyms_lookup_name(ckpt_fn_name);
      if (ckpt_fn_addr == NULL)    {
          printk ("Failed to find the restore function for %s.\n", ckpt_fn_name);
          printk("Failed to restore device.\n");
      }
      else  {
        printk ("Invoking %s to recover device.\n", ckpt_fn_name); 
        ckpt_fn_addr(NULL);
      }
      //spin lock?
      spin_lock_irqsave(&odft_failure_lock, flags);
      odft_failure_record = 0;
      spin_unlock_irqrestore(&odft_failure_lock, flags);
      sfi_pid = 0;
      
      printk ("Done with recovery from driver.\n"); 
      
      if (kill_self == 1)   {
        if (rtnl_is_locked()) { 
        	rtnl_unlock(); 
	}
		
	/* The code to walk the task structure and clear
	 * up the locks..
	 
	struct task_struct *curr = get_current();
	int i, depth = curr->lockdep_depth;
	for (i = 0; i < depth; i++) {
		printk("\nLock: #%d:\n ", i);
		//print_lock(curr->held_locks + i);
		printk("[<%p>] %pS\n", (void *) (curr->held_locks + i)->acquire_ip, (void *) (curr->held_locks + i)->acquire_ip);	
		//lock_release((curr->held_locks + i)->instance, 1, (curr->held_locks + i)->acquire_ip);
	}

	// If there are locks outside the SFI module such as a proc 
	// or sysfs lock we return and hope after our cleanup everyhting
	// will workkout OK.	
 	*/	
	if ((depth > 0) || 
		(kill_self == 0) || (in_interrupt())) 	{
		return -EINVAL;
	}
	
	/* The final hammer if its Ok to cleanup the thread! */
        printk ("*****Killing %s %d. \n", sfi_fn_name, sfi_pid);
        do_exit(SIGKILL);
	//return EINVAL; 
      }
    }
    // end recovery	
    //obliterate klog
    odft_truncate_klog_hashtable_by_function(sfi_fn_name);	

    return (0);

}

/*
 * Call the kernel-side stub wrapper of function_name and log the call.
 *
 * function_name: name of the kernel function. The symbol
 *                "<ODFT_DRIVER_NAME>_stub:__MARSH_WRAP__<function_name>" is
 *                looked up and called.
 * misc_fnargs:   in: data is passed to the wrapper; out: data and length are
 *                replaced by the wrapper's returned buffer and length.
 *
 * Returns 0 on success, -1 if the wrapper symbol cannot be found or its name
 * does not fit in the local buffer.
 *
 * After the call, the input buffer/length and the returned buffer/length are
 * recorded with odft_insert_klog_hash() under function_name.
 *
 * The symbol name is now built with snprintf(); an over-long function_name
 * used to overflow the 100-byte stack buffer.
 */
int disp_kern(char *function_name, struct req_args *misc_fnargs)    {
    char fn_name[100];
    int name_len;
    struct marshret_struct retval;
    void (*fn_addr)(void *, struct marshret_struct *);

    name_len = snprintf (fn_name, sizeof (fn_name), "%s_stub:__MARSH_WRAP__%s",
                         ODFT_DRIVER_NAME, function_name);
    if ((name_len < 0) || (name_len >= (int) sizeof (fn_name))) {
        // The name was truncated: no such symbol can exist.
        printk ("Failed to find kern function for %s.\n", fn_name);
        return (-1);
    }

    printk ("Sending a call into the kernel for %s.\n", fn_name);
    fn_addr = (void *) kallsyms_lookup_name(fn_name);
    if (fn_addr == NULL)    {
        printk ("Failed to find kern function for %s.\n", fn_name);
        return (-1);
    }

    printk ("In disp_kern: attempting to call %s at %p.\n", fn_name, fn_addr);
    fn_addr(misc_fnargs->data, &retval);

    // Log the call (input and output buffers) before handing the result back.
    odft_insert_klog_hash(function_name, misc_fnargs->data, misc_fnargs->length,
            retval.buf, retval.len);
    misc_fnargs->data = retval.buf;
    misc_fnargs->length = retval.len;

    return (0);
}



/*****************************************************************************
 * Get a pointer to the disp_kern function from the k-driver
 *****************************************************************************/
/*
 * Register the kernel dispatch function and the table of function names.
 *
 * fnptr:         dispatch function; stored in disp_kern_ptr and later called
 *                by dispatch_user_request().
 * fn_id_map:     table of names, 128 bytes per entry. Only the row pointers
 *                are stored (no strings are copied), so the caller's table
 *                has to stay valid while it is in use.
 * fn_id_map_len: number of rows in fn_id_map. Panics if it is
 *                >= MISC_MAX_LEN. A negative value is treated as an empty
 *                table.
 *
 * The unused fn_name/fn_addr locals (left over from the disabled
 * register_globals call) have been removed.
 */
void register_miscfn(void *fnptr, char fn_id_map[][128], int fn_id_map_len) {
    int i;

    disp_kern_ptr = (pDispKern) fnptr;

    if (fn_id_map_len < 0) {
        fn_id_map_len = 0;
    }

    misc_function_id_map_len = fn_id_map_len;
    if (misc_function_id_map_len >= MISC_MAX_LEN) {
        panic ("Failed to init misc_function_id_map because the table is really big!\n");
    }

    for (i = 0; i < misc_function_id_map_len; i++) {
        misc_function_id_map[i] = fn_id_map[i];
    }
}

/*****************************************************************************
 * Record that we've executed this function
 *****************************************************************************/
/*
 * Hook for recording that a function has been executed.
 *
 * fn: name of the function; currently ignored.
 *
 * The body is intentionally empty: the function does nothing.
 */
void record_function(const char *fn)
{
    (void) fn;
}

/*
 * Note that the SFI component has failed.
 *
 * Logs a message and sets odft_failure_record to 1 while holding
 * odft_failure_lock with interrupts saved and restored around the write.
 * The pointer-flip setting is left untouched.
 */
void odft_record_failure(void)
{
    unsigned long irq_state = 0;

    printk ("SFI failed. making note.\n");

    spin_lock_irqsave(&odft_failure_lock, irq_state);
    odft_failure_record = 1;
    spin_unlock_irqrestore(&odft_failure_lock, irq_state);
}




/*
 * Failure callback registered in init_module().
 *
 * pid: pid of the task being reported.
 *
 * Returns NULL, without recording anything, when no SFI call is recorded
 * (sfi_pid == 0) or when pid is not the recorded SFI pid. Otherwise it
 * records the failure through odft_record_failure() and returns a pointer to
 * the register set saved in orig_context.
 */
void * check_and_set_odft_failure_record(int pid)
{
    const int not_our_task = (sfi_pid == 0) || (pid != sfi_pid);

    if (not_our_task) {
        return NULL;
    }

    odft_record_failure();
    return &orig_context.regs;
}



/*****************************************************************************
 * Test driver function.  Used to see if marshaling code/replay works
 *****************************************************************************/
/*
 * Test helper for the marshaling/replay code.
 *
 * x: out parameter; receives 100 plus the number of calls made so far,
 *    counting this one.
 *
 * Returns 200 plus that same call count. The count lives in a function-local
 * static and starts at zero.
 */
int test_driver_function (int *x)
{
    static int call_count = 0;

    call_count += 1;
    *x = call_count + 100;

    return call_count + 200;
}

/*
 * Test helper: prints its argument as a decimal number with printk.
 *
 * x: byte value to print.
 */
void test_driver_function2 (unsigned char x)
{
    const int shown = x;

    printk ("test_driver_function2: %d\n", shown);
}


// Symbols exported to other modules. check_and_set_odft_failure_record() is
// not in this list; init_module() passes it as a callback to
// register_odft_failure_record() and register_odft_mm_failure_record().
EXPORT_SYMBOL(disp_user);
EXPORT_SYMBOL(disp_kern);
EXPORT_SYMBOL(dispatch_user_request);
EXPORT_SYMBOL(odft_pointer_to_flip);
EXPORT_SYMBOL(odft_stack_to_corrupt);
EXPORT_SYMBOL(odft_stack_to_really_corrupt);
//EXPORT_SYMBOL(full_slab_verify);
EXPORT_SYMBOL(g_nooks_table);
EXPORT_SYMBOL(register_miscfn);
EXPORT_SYMBOL(record_function);
EXPORT_SYMBOL(odft_record_failure);
EXPORT_SYMBOL(test_driver_function);
EXPORT_SYMBOL(test_driver_function2);
