/*

	Set up a gene regulation network, train it, test it, print it.

	This file contains all the functions that
		read in data, set up variables and nodes, 
		add them to a network, 
		include the appropriate connections, 
		initialize the CPDs for all nodes,
		assign evidence data to the nodes,
		evaluate the network on a testing set


*/

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <list>
#include <map>
#include <sstream>
#include <string>
#include <vector>
using namespace std;

#include "gaussian.h"
#include "BayesNet.h"
#include "CPDTree.h"
#include "datafile.h"
#include "probability.h"
#include "OptionParser.h"
#include "stl.h"


// constants for the hidden "regulator state" variables:
//	HIDDEN_STATE_ARITY values, named "I" and "A" (presumably Inactive / Active)
const int HIDDEN_STATE_ARITY = 2;
const string HIDDEN_STATE_NAMES[HIDDEN_STATE_ARITY] = { "I", "A" };

// network variant to build; values 0, 1, 2 match the --baseline option
enum BASELINES { FULL = 0, NOHIDDEN = 1, NOHIDDEN_NOCONDITIONS = 2 };

/** GRN variable data read from three input files:
 *	roles       -- regulator/regulatee relationships (roles file)
 *	conditions  -- each condition name with its list of possible values
 *	states      -- per-gene GaussianMixture of expression states
 */
struct GRNVariableData {

	list<Role> roles;
	list<pair<string,list<string> > > conditions;
	map<string,GaussianMixture> states;

	GRNVariableData() { }

	/** construct and immediately load all three files (see read) */
	GRNVariableData(string rfile, string cfile, string sfile) {
		read(rfile, cfile, sfile);
	}

	/** (re)load roles, conditions and states, in that order; each reader is
	 *	handed cerr (presumably for its diagnostics) */
	void read(string rfile, string cfile, string sfile) {
		roles      = read_roles(rfile, cerr);
		conditions = read_conditions(cfile, cerr);
		states     = read_states(sfile, cerr);
	}

	/** empty all three collections */
	void clear() {
		roles.clear();
		conditions.clear();
		states.clear();
	}
};

/** Maps from node NAMES to pointers to the NODES, one map per node kind:
 *	cmap = condition nodes, tmap = regulator nodes, emap = regulatee nodes,
 *	hmap = hidden-state nodes (keyed by their regulator's name)
 */
struct NodeMaps {
	map<string,Node*> cmap;
	map<string,Node*> tmap;
	map<string,Node*> emap;
	map<string,Node*> hmap;

	/** total number of nodes over all four maps */
	int size() {
		int total = 0;
		total += cmap.size();
		total += tmap.size();
		total += emap.size();
		total += hmap.size();
		return total;
	}
};

/** Assay data set read from a data file (format described by --trainset):
 *	N arrays (rows), C condition columns and G gene columns.
 *	values[a][c] is the condition value of array a for condition c and
 *	expression[a][g] the expression of gene g in array a (this is how
 *	set_network_evidence and test index them).
 *
 *	The counts start at 0 and clear() resets them, so a default-constructed
 *	or cleared object never reports stale/garbage sizes.
 */
struct GRNAssayData {

	int N, C, G;		// number of arrays, conditions, genes
	vector<string> arrays;
	vector<string> conditions;
	vector<string> genes;
	vector<vector<string> > values;
	vector<vector<double> > expression;

	map<string,int> amap;	// map array name -> index
	map<string,int> cmap;	// map condition name -> index
	map<string,int> gmap;	// map gene name -> index

	GRNAssayData() : N(0), C(0), G(0) { }
	GRNAssayData(string filename) : N(0), C(0), G(0) { read(filename); }
	void read(string filename);	// read data and set up maps

	/** discard all data, including the counts */
	void clear() { 
		N = C = G = 0;
		arrays.clear(); conditions.clear(); genes.clear();
		values.clear(); expression.clear();
		amap.clear(); cmap.clear(); gmap.clear();
	}
	
};

/** Parse 'filename' with read_assay_data (diagnostics go to cerr), then index
 *	every array, condition and gene name by its position. */
void GRNAssayData::read(string filename) { 

	// counts and row/column data are presumably filled in by reference
	read_assay_data(filename, N, C, G, arrays, conditions, genes, 
	                values, expression, cerr);

	// name -> index lookups
	for (int i = 0; i < N; ++i) { amap[arrays[i]] = i; }
	for (int i = 0; i < C; ++i) { cmap[conditions[i]] = i; }
	for (int i = 0; i < G; ++i) { gmap[genes[i]] = i; }
}

/** data types for using ROLES: how a regulator affects a regulatee */
enum role_type { UNKNOWN = 0, ACTIVATOR = 1, REPRESSOR = 2, DUAL = 3 };

/** translate role -> string; anything other than ACTIVATOR, REPRESSOR or
 *	DUAL yields "Unknown" */
string role_type2string(role_type role) { 
	string name = "Unknown";
	if      (role == ACTIVATOR) { name = "Activator"; }
	else if (role == REPRESSOR) { name = "Repressor"; }
	else if (role == DUAL)      { name = "Dual"; }
	return name;
}

/** translate string -> role_type by its first character only, either case:
 *	A/a -> ACTIVATOR, R/r -> REPRESSOR, D/d -> DUAL, anything else
 *	(including the empty string) -> UNKNOWN */
role_type string2role_type(string s) { 
	const char first = s.empty() ? '\0' : s[0];
	switch (first) {
		case 'A': case 'a': return ACTIVATOR;
		case 'R': case 'r': return REPRESSOR;
		case 'D': case 'd': return DUAL;
		default:            return UNKNOWN;
	}
}


/** an evaluation over test data (accumulated by 'test') */
struct Evaluation { 
	
	int N;			// number of predictions scored: one per (regulatee, array) pair
	double error;		// expected number of incorrect predictions:
				//	sum of (1 - P(predicted state matches evidence))
	double SSE;		// sum squared error, total squared distance 
	                	//	between prediction and expression
	double log_prob;	// log probability of data given model

	Evaluation() : N(0), error(0.0), SSE(0.0), log_prob(0.0) { }

	// sum up statistics in a human-readable way
	string stats() const;
};

string Evaluation::stats() const { 
	// error percentage; reported as 0 when nothing was scored, instead of
	// printing "nan%" from 0/0
	const double error_pct = (N > 0) ? (100.0 * (error / N)) : 0.0;

	ostringstream oss;
	oss << "N = " << N << ", "
	    << "error = " << error << " "
		<< "(" << error_pct << "%), "
	    << "SSE = " << SSE << ", "
		<< "log probability = " << log_prob << " "
		<< "(" << exp(log_prob) << ")";
	return oss.str();
}


/** machine-readable form: "N error SSE log_prob", space separated */
ostream& operator<<(ostream &out, const Evaluation &evaluation) { 

	out << evaluation.N << " " << evaluation.error << " "
		<< evaluation.SSE << " " << evaluation.log_prob;

	return out;
}



	
/** remove regulatees with 1 state only (remove from roles list) */
void remove_single_state_regulatees(list<Role> &roles, 
                                    const map<string,GaussianMixture> &smap) { 

	list<Role>::iterator r = roles.begin();

	while (r != roles.end()) { 

		map<string,GaussianMixture>::const_iterator s = smap.find(r->regulatee);
		if (s == smap.end()) { 
			cerr << *r << " regulatee " << r->regulatee 
				 << ", state not found:  REMOVING REGULATEE." << endl;
			r = roles.erase(r);
			
		} else if (s->second.size() <= 1) { 

			// Optionally, we can print a message, but often there are so many removed
			////cerr << *r << " regulatee " << r->regulatee 
			////	 << " has one state:  REMOVING REGULATEE." << endl;
			r = roles.erase(r);

		} else {
			r++;
		}
	}
}



/** Names for the states of a gene's mixture, 1-based: "1of3", "2of3", "3of3"
 *	for a 3-state GaussianMixture. */
vector<string> create_states_names(const GaussianMixture &gm) { 

	vector<string> names;
	for (int k = 1; k <= gm.size(); ++k) {
		ostringstream label;
		label << k << "of" << gm.size();
		names.push_back(label.str());
	}
	return names;
}


/** Non-owning pair of a parent Variable and the Role linking it to a
 *	regulatee (create_regulatee_cpds fills it with the regulator's hidden-state
 *	variable); consumed by 'create_regulatee_cpd'.
 */
struct VarRole {
	const Variable *variable;	// parent variable the CPD tree splits on
	const Role *role;		// that parent's role (type and priority)

	VarRole(const Variable *v, const Role *r)
		: variable(v)
		, role(r)
	{ }
};

/** Ability to sort by role, then priority */
struct SortByRolePriority { 
	bool operator()(const VarRole &LHS, const VarRole &RHS) { 
		
		role_type L = string2role_type(LHS.role->role);
		role_type R = string2role_type(RHS.role->role);
		
		// same role => use priority
		if (L == R) { return (LHS.role->priority < RHS.role->priority); }

		// otherwise, sort by ACTIVATOR < REPRESSOR < DUAL < UNKNOWN
		else {	
				if (L == ACTIVATOR) { return true; }
				if (R == ACTIVATOR) { return false; }
				if (L == REPRESSOR) { return true; }
				if (R == REPRESSOR) { return false; }
				if (L == DUAL) { return true; }
				return false;	// L == Unknown, R == Dual
		}
	}
};

/** Design a CPD Tree given a sorted list of VarRoles (variables and roles).
 *
 *	varroles:  parents of one regulatee, already sorted by the caller
 *	           (create_regulatee_cpds sorts with SortByRolePriority:
 *	           activators, then repressors, then others; ascending priority)
 *	arity:     number of regulatee states; only 2 and 3 have preset
 *	           distributions.  For any other value every leaf stays uniform
 *	           and a message is printed on cerr for each repressor/activator.
 *
 *	Returns the root of a newly allocated tree (every node is created with
 *	new); the caller is responsible for it.
 */
CPDTreeNode* create_regulatee_cpd(const vector<VarRole> &varroles, int arity) {

	// The approach:
	//	
	//	walk the sorted roles and set up a simple tree (linked list, really):
	//	each REPRESSOR or ACTIVATOR parent adds one interior node that splits
	//	on the parent's variable, with a preset leaf on the left and the tree
	//	built so far on the right.  DUAL and UNKNOWN parents add nothing.
	//	Intent:
	//	1st:  IF REPRESSOR is ACTIVATED, then probability is HIGH,
	//			that regulatee expression is LOW
	//	2nd:  IF ACTIVATOR is ACTIVATED, then probability is HIGH,
	//			that regulatee expression is HIGH
	//	
	//	The intended precedence is HIGH PRIORITY repressors first, then
	//	other repressors, then HIGH PRIORITY activators, et cetera.  For
	//	example, if REPX is a repressor with priority X,
	//
	//		if (REP1 activated) -> { LOW=0.9, HI=0.1 }
	//		else if (REP2 activated) -> { LOW=0.8, HI=0.2 }
	//		else if (ACT1 inactivated) -> { LOW=0.7, HI=0.3 }
	//		else (repressors inactivated, activators activated)
	//									-> { LOW=0.1, HI=0.9 }
	//
	//	NOTE(review): every new interior node becomes the new ROOT, so the
	//	LAST entry of varroles is tested first.  With the SortByRolePriority
	//	order that is the repressor with the largest priority number, not
	//	REP1 as in the example above -- confirm which order is intended.
	//
	//	Priority 1 selects the R1_/A1_ tables, priority > 1 the R2_/A2_
	//	tables; a priority below 1 matches neither, leaving that leaf uniform.

	// PROTOTYPES (NOTE(review): the call below passes a Variable* where this
	// comment lists 'int id' -- the prototype note may be out of date):
	//
	/** Create a LEAF with a uniform probability distribution */
	// CPDTreeNode(int v) : left(NULL), right(NULL), split_variable(NULL);

	/** Create an INTERIOR node */
	// CPDTreeNode(int id, int v, CPDTreeNode *L, CPDTreeNode *R);

	// preset leaf distributions, state 0 = LOW expression first
	// (R = repressor, A = activator; 1/2 = priority 1 / > 1; _A2/_A3 = arity)
	const double R1_A2[] = { 0.9, 0.1 };
	const double R1_A3[] = { 0.8, 0.1, 0.1 };
	const double R2_A2[] = { 0.8, 0.2 };
	const double R2_A3[] = { 0.7, 0.2, 0.1 };
	const double A1_A2[] = { 0.8, 0.2 };
	const double A1_A3[] = { 0.7, 0.2, 0.1 };
	const double A2_A2[] = { 0.7, 0.3 };
	const double A2_A3[] = { 0.6, 0.2, 0.2 };

	CPDTreeNode *root = new CPDTreeNode(arity);	// uniform distribution

	// default ("else") leaf: the R1 table reversed, i.e. favoring the
	// last (HIGH) state, e.g. { 0.1, 0.9 } for arity 2
	if (arity==2) { copy(R1_A2, R1_A2+2, root->dist.begin()); }
	if (arity==3) { copy(R1_A3, R1_A3+3, root->dist.begin()); }

	std::reverse(root->dist.begin(), root->dist.end());
	
	for (vector<VarRole>::const_iterator vr=varroles.begin();
	                                     vr!=varroles.end(); vr++) { 

		role_type rt = string2role_type(vr->role->role);
		if (rt == REPRESSOR || rt == ACTIVATOR) { 
			CPDTreeNode *leaf = new CPDTreeNode(arity);
			if (arity==2) { 
				if (rt == REPRESSOR && vr->role->priority == 1) 
					{ copy(R1_A2, R1_A2+2, leaf->dist.begin()); }
				if (rt == REPRESSOR && vr->role->priority >  1) 
					{ copy(R2_A2, R2_A2+2, leaf->dist.begin()); }
				if (rt == ACTIVATOR && vr->role->priority == 1) 
					{ copy(A1_A2, A1_A2+2, leaf->dist.begin()); }
				if (rt == ACTIVATOR && vr->role->priority >  1) 
					{ copy(A2_A2, A2_A2+2, leaf->dist.begin()); }
			} else if (arity==3) { 
				if (rt == REPRESSOR && vr->role->priority == 1) 
					{ copy(R1_A3, R1_A3+3, leaf->dist.begin()); }
				if (rt == REPRESSOR && vr->role->priority >  1) 
					{ copy(R2_A3, R2_A3+3, leaf->dist.begin()); }
				if (rt == ACTIVATOR && vr->role->priority == 1) 
					{ copy(A1_A3, A1_A3+3, leaf->dist.begin()); }
				if (rt == ACTIVATOR && vr->role->priority >  1) 
					{ copy(A2_A3, A2_A3+3, leaf->dist.begin()); }
			} else {
				cerr << "create_regulatee_cpd:  illegal value for arity (" 
					 << arity << ")" << endl;
			}

			// value of the parent that leads to 'leaf': 1 for a repressor,
			// 0 for an activator (presumably hidden state 1 = "A", active,
			// and 0 = "I", inactive, per HIDDEN_STATE_NAMES); other values
			// fall through to the previously built tree
			int left_value = (rt==REPRESSOR) ? 1 : 0;
			root = new CPDTreeNode(vr->variable, left_value, leaf, root);
		}
		
	} // next parent node
		
	return root;
}


/** Build a CPDTree for every REGULATEE node and assign it to the node's cpd.

	roles are all roles
	hmap maps regulator names to hidden node pointers (the trees' split variables)
	emap maps regulatee names to regulatee node pointers

	Each tree's variables are the sorted parents followed by the regulatee's
	own variable; the tree itself comes from create_regulatee_cpd.  Trees are
	allocated with new.  A role whose regulator has no hidden node is
	reported on cerr and ignored (it used to dereference hmap.end()).
 */
void create_regulatee_cpds(const list<Role> &roles, 
                            const map<string,Node*> &hmap, 
							const map<string,Node*> &emap) { 

	// Group the roles by regulatee in a single pass, instead of rescanning
	// every role for every regulatee.  Within a regulatee the role-list order
	// is kept, exactly as the per-regulatee scan produced it.
	map<string, vector<VarRole> > roles_by_regulatee;
	for (list<Role>::const_iterator r=roles.begin(); r!=roles.end(); ++r) { 
		if (emap.find(r->regulatee) == emap.end()) { continue; }	// not a network regulatee

		map<string,Node*>::const_iterator h = hmap.find(r->regulator);
		if (h == hmap.end()) { 
			cerr << "create_regulatee_cpds:  no hidden node for regulator "
			     << r->regulator << "; ignoring its role for regulatee "
			     << r->regulatee << "." << endl;
			continue;
		}
		roles_by_regulatee[r->regulatee].push_back(VarRole(h->second->variable, &(*r)));
	}

	// for each regulatee
	for (map<string,Node*>::const_iterator e=emap.begin(); e!=emap.end(); ++e) {

		vector<VarRole> &varroles = roles_by_regulatee[e->first];

		// sort by role (activator < repressor), then priority (lower < higher)
		std::sort(varroles.begin(), varroles.end(), SortByRolePriority());

		// create the CPD Tree structure
		CPDTreeNode *root = create_regulatee_cpd(varroles, e->second->arity());

		// relevant variables: the parents in sorted order, then the regulatee
		vector<const Variable*> variables;
		variables.reserve(varroles.size() + 1);
		for (vector<VarRole>::const_iterator vr=varroles.begin();
		                                     vr!=varroles.end(); ++vr) { 
			variables.push_back(vr->variable);
		}
		variables.push_back(e->second->variable);

		CPDTree *tree = new CPDTree(variables);
		tree->replace(root);
		e->second->cpd = tree;
	
	} // next regulatee
}



/**
 * Set up a network VARIABLES:
 *	create variables and nodes and store in maps
 *	depends on baseline, some baselines don't contain certain nodes
 */
NodeMaps create_network_variables(const GRNVariableData &data, int BASELINE) { 

	NodeMaps maps;	//return value
	int next_id = 0;

	// CONDITIONS
	if (BASELINE != NOHIDDEN_NOCONDITIONS) { 
		for (list<pair<string,list<string> > >::const_iterator i=data.conditions.begin();
		                                                       i!=data.conditions.end(); i++) {
			Variable *v = new Variable(i->first, next_id++, i->second.begin(), i->second.end());
			Node *n = new Node(v);
			maps.cmap[i->first] = n;
		}
	}

	// REGULATORS
	for (list<Role>::const_iterator i=data.roles.begin(); i!=data.roles.end(); i++) { 
		string name = i->regulator;
		map<string,Node*>::iterator t = maps.tmap.find(name);
		if (t == maps.tmap.end()) { 
			map<string,GaussianMixture>::const_iterator s = data.states.find(name);
			if (s == data.states.end()) { cerr << "No states found for regulator " << name << endl; }
			vector<string> sn = create_states_names(s->second);
			Variable *v = new Variable(name, next_id++, sn.begin(), sn.end());
			Node *n = new Node(v);
			maps.tmap[name] = n;
		}
	}

	// REGULATEES
	for (list<Role>::const_iterator i=data.roles.begin(); i!=data.roles.end(); i++) { 
		string name = i->regulatee;
		map<string,Node*>::iterator e = maps.emap.find(name);
		if (e == maps.emap.end()) { 
			map<string,GaussianMixture>::const_iterator s = data.states.find(name);
			if (s == data.states.end()) { cerr << "No states found for regulatee " << name << endl; }
			vector<string> sn = create_states_names(s->second);
			Variable *v = new Variable(name, next_id++, sn.begin(), sn.end());
			Node *n = new Node(v);
			maps.emap[name] = n;
		}
	}

	// HIDDEN NODES
	if (BASELINE == FULL) { 
		for (list<Role>::const_iterator i=data.roles.begin(); i!=data.roles.end(); i++) { 
			string tname = i->regulator;
			string hname = i->regulator + " state";
			map<string,Node*>::iterator h = maps.hmap.find(tname);
			if (h == maps.hmap.end()) {
				Variable *v = new Variable(hname, next_id++,
										   HIDDEN_STATE_NAMES, HIDDEN_STATE_NAMES+HIDDEN_STATE_ARITY);
				Node *n = new Node(v);
				maps.hmap[tname] = n;
			}
		}
	}

	return maps;
}


/** FULL model structure:
 *	regulator -> its own hidden-state node,
 *	every condition -> every hidden-state node,
 *	hidden-state node of each role's regulator -> that role's regulatee.
 */
BayesNet set_network_structure_full(const list<Role> &roles, const NodeMaps &maps) { 

	BayesNet network;
	typedef map<string,Node*>::const_iterator NodeIter;

	// each regulator feeds its hidden state
	for (NodeIter reg = maps.tmap.begin(); reg != maps.tmap.end(); ++reg) {
		Node *hidden = maps.hmap.find(reg->first)->second;
		network.connect(reg->second, hidden);
	}

	// every condition feeds every hidden state
	for (NodeIter cond = maps.cmap.begin(); cond != maps.cmap.end(); ++cond) {
		for (NodeIter hid = maps.hmap.begin(); hid != maps.hmap.end(); ++hid) {
			network.connect(cond->second, hid->second);
		}
	}

	// hidden states feed the regulatees named by the roles
	for (list<Role>::const_iterator role = roles.begin(); role != roles.end(); ++role) {
		Node *parent = maps.hmap.find(role->regulator)->second;
		Node *child  = maps.emap.find(role->regulatee)->second;
		network.connect(parent, child);
	}

	return network;
}



/** BASELINE 1 (NOHIDDEN) structure, no hidden nodes:
 *	every condition -> every regulatee,
 *	each role's regulator -> that role's regulatee.
 */
BayesNet set_network_structure_baseline1(const list<Role> &roles, NodeMaps &maps) { 

	BayesNet network;
	typedef map<string,Node*>::const_iterator NodeIter;

	// conditions are parents of every regulatee
	for (NodeIter cond = maps.cmap.begin(); cond != maps.cmap.end(); ++cond) {
		for (NodeIter reg = maps.emap.begin(); reg != maps.emap.end(); ++reg) {
			network.connect(cond->second, reg->second);
		}
	}

	// regulators are parents of the regulatees they regulate
	for (list<Role>::const_iterator role = roles.begin(); role != roles.end(); ++role) {
		Node *parent = maps.tmap.find(role->regulator)->second;
		Node *child  = maps.emap.find(role->regulatee)->second;
		network.connect(parent, child);
	}

	return network;
}


/** BASELINE 2 (NOHIDDEN_NOCONDITIONS) structure, no hidden and no condition
 *	nodes: each role's regulator -> that role's regulatee, nothing else.
 */
BayesNet set_network_structure_baseline2(const list<Role> &roles, NodeMaps &maps) { 

	BayesNet network;

	list<Role>::const_iterator role = roles.begin();
	while (role != roles.end()) {
		Node *parent = maps.tmap.find(role->regulator)->second;
		Node *child  = maps.emap.find(role->regulatee)->second;
		network.connect(parent, child);
		++role;
	}

	return network;
}


/** Build the network structure for BASELINE (see enum BASELINES).
 *	An unrecognized BASELINE is reported on cerr and yields an empty
 *	BayesNet; previously control fell off the end of this non-void function,
 *	which is undefined behavior.
 */
BayesNet set_network_structure(const list<Role> &roles, NodeMaps &maps, int BASELINE) { 

	switch (BASELINE) {
		case FULL:                  return set_network_structure_full(roles, maps);
		case NOHIDDEN:              return set_network_structure_baseline1(roles, maps);
		case NOHIDDEN_NOCONDITIONS: return set_network_structure_baseline2(roles, maps);
	}

	cerr << "set_network_structure:  unknown baseline " << BASELINE << endl;
	return BayesNet();
}


void initialize_CPDs_full(const BayesNet &network, const list<Role> &roles, const NodeMaps &maps) {

	// create CPDs for hidden nodes
	for (map<string,Node*>::const_iterator h=maps.hmap.begin(); h!=maps.hmap.end(); h++) { 
		vector<const Variable*> hvars = network.parents(h->second->variable->ID);
		hvars.push_back(h->second->variable);
		CPDTree *tree = new CPDTree(hvars);	// creates a default leaf with uniform dist.
		// TODO:  make the probability of active slightly higher if regulator is active?
		h->second->cpd = tree;
	}

	// create CPDs for regulatee nodes
	create_regulatee_cpds(roles, maps.hmap, maps.emap);
}

/** Initialize CPDs for the baselines without hidden nodes: every regulatee
 *	gets a new CPDTree over (its parents in 'network', then itself) with the
 *	default uniform leaf.  'roles' is not used here.
 */
void initialize_CPDs_baseline1(BayesNet &network, const list<Role> &roles, const NodeMaps &maps) {

	map<string,Node*>::const_iterator reg = maps.emap.begin();
	while (reg != maps.emap.end()) {
		Node *node = reg->second;
		vector<const Variable*> tree_vars = network.parents(node->variable->ID);
		tree_vars.push_back(node->variable);
		node->cpd = new CPDTree(tree_vars);
		++reg;
	}
}

/** Dispatch CPD initialization on BASELINE: FULL uses the role-structured
 *	trees; every other baseline (2 is handled the same as 1) uses
 *	initialize_CPDs_baseline1.
 */
void initialize_CPDs(BayesNet &network, const list<Role> &roles, const NodeMaps &maps, int BASELINE) {

	if (BASELINE != FULL) {
		initialize_CPDs_baseline1(network, roles, maps);
		return;
	}
	initialize_CPDs_full(network, roles, maps);
}



/** Evidence backed by a per-array table: (*data)[array][value] is the weight
 *	of 'value' for the node in array 'array'.
 *	Takes ownership of 'data' (deleted in the destructor), so copying is
 *	disabled: a copy would share the pointer and delete it twice.
 *	NOTE(review): deleting through Evidence* only runs this destructor if
 *	Evidence declares a virtual destructor -- confirm in its header.
 */
class GRNEvidence : public Evidence { 
  private:
  	vector<vector<double> > *data;	// for each array, value

	// non-copyable (C++98 idiom: declared private, never defined)
	GRNEvidence(const GRNEvidence&);
	GRNEvidence& operator=(const GRNEvidence&);

  public:
  	explicit GRNEvidence(vector<vector<double> > *D) : data(D) { }
	~GRNEvidence() { delete data; }
  	inline double operator()(int array, int value) { return (*data)[array][value]; }
};



/** 
 * Take all data and setup Nodes' evidence
 */
void set_network_evidence(BayesNet &network, const GRNAssayData &assay,
							const map<string,GaussianMixture> &states, const NodeMaps &maps) { 

	// create evidence for all conditions
	for (int c=0; c<assay.C; c++) { 
		if (maps.cmap.empty()) { break; }	// skip this step if there are no such thing as condition nodes
		string name = assay.conditions[c];
		Node *node = maps.cmap.find(name)->second;
		if (node->evidence) { delete node->evidence; node->evidence = NULL; }
		vector<vector<double> > *evidence_data = new vector<vector<double> >(assay.N);
		for (int a=0; a<assay.N; a++) { 
			(*evidence_data)[a].resize(node->arity());
			bool found = false;
			for (int v=0; v<node->arity(); v++) { 
				(*evidence_data)[a][v] = (assay.values[a][c] == node->variable->values[v]) ? 1.0 : 0.0;
				if ((*evidence_data)[a][v] == 1.0) { found=true; }
			}
		
			// this is for debugging mostly
			if (!found) { 
				cerr << "condition " << *node << ":  value " << assay.values[a][c] 
					 << " not found in list " << node->variable->values << endl; 
			}
		}
		GRNEvidence *evidence = new GRNEvidence(evidence_data);
		node->evidence = evidence;
	} // next condition

	// create evidence for all regulators
	for (int g=0; g<assay.G; g++) { 
		
		string name = assay.genes[g];
		
		map<string,Node*>::const_iterator t_itr = maps.tmap.find(name);
		if (t_itr != maps.tmap.end()) { 
			// gene is a regulator
			Node *node = t_itr->second;
			if (node->evidence) { delete node->evidence; node->evidence = NULL; }
			const GaussianMixture *gm = &(states.find(name)->second);
			vector<vector<double> > *evidence_data = new vector<vector<double> >(assay.N);
			for (int a=0; a<assay.N; a++) { 
				(*evidence_data)[a].resize(gm->size());
				vector<double> dist = gm->prob_dist(assay.expression[a][g]);
				copy(dist.begin(), dist.end(), (*evidence_data)[a].begin());
			} // next array
			GRNEvidence *evidence = new GRNEvidence(evidence_data);
			node->evidence = evidence;
		}
		
		map<string,Node*>::const_iterator e_itr = maps.emap.find(name);
		if (e_itr != maps.emap.end()) { 
			// gene is existing reguatee
			Node *node = e_itr->second;
			if (node->evidence) { delete node->evidence; node->evidence = NULL; }
			const GaussianMixture *gm = &(states.find(name)->second);
			vector<vector<double> > *evidence_data = new vector<vector<double> >(assay.N);

			for (int a=0; a<assay.N; a++) {
				(*evidence_data)[a].resize(gm->size());
				vector<double> dist = gm->prob_dist(assay.expression[a][g]);
				copy(dist.begin(), dist.end(), (*evidence_data)[a].begin());
			} // next array
			GRNEvidence *evidence = new GRNEvidence(evidence_data);
			node->evidence = evidence;

		}

	} // next gene in data file

	// make sure all regulators, regulatees have evidence
	for (map<string,Node*>::const_iterator t = maps.tmap.begin(); t!=maps.tmap.end(); t++) { 
		if (!(t->second->evidence)) { cerr << "ERROR:  No evidence found for regulator " 
		                                   << t->second->variable->name << "." << endl; }
	}
	for (map<string,Node*>::const_iterator e = maps.emap.begin(); e!=maps.emap.end(); e++) { 
		if (!(e->second->evidence)) { cerr << "ERROR:  No evidence found for regulatee " 
		                                   << e->second->variable->name << "." << endl; }
	}
		
				
	// that's it!

	return;

}

void trim_trees(const map<string,Node*> &hmap, const map<string,Node*> &emap, double ratio) { 

	for (map<string,Node*>::const_iterator h=hmap.begin(); h!=hmap.end(); h++) {
		CPDTree *tree = (CPDTree*)(h->second->cpd);
		tree->prune(ratio);
	}

	for (map<string,Node*>::const_iterator e=emap.begin(); e!=emap.end(); e++) {
		CPDTree *tree = (CPDTree*)(e->second->cpd);
		tree->prune(ratio);
	}
}

	


/** Evaluate 'network' on 'assay' by hiding and re-predicting each regulatee.
 *
 *	Evidence must already be set for 'assay' (see set_network_evidence).
 *	All regulatee evidence is detached while querying, so each prediction
 *	network.query(ID, a) uses only the remaining evidence; it is restored
 *	before returning.  For each regulatee and array:
 *		dot       = sum_v prediction[v] * evidence[v]
 *		log_prob += log(dot)      (a dot of 0 contributes -inf)
 *		error    += 1 - dot
 *		SSE      += (sum_v prediction[v] * |expression - mean of state v|)^2
 *	A regulatee with no column in 'assay' or no evidence is reported on cerr
 *	and skipped; the old code dereferenced the missing entry.
 *	fabs is used explicitly so the distance can never go through int abs().
 */
Evaluation test(const BayesNet &network, const GRNAssayData &assay, 
                const map<string,GaussianMixture> &states, map<string,Node*> &emap,
				bool verbose = false) { 

	Evaluation ans;	// (constructor zeros out parameters)
	
	// remove evidence from regulatee nodes
	map<int,Evidence*> evidence_map;
	for (map<string,Node*>::iterator e=emap.begin(); e!=emap.end(); e++) { 
		evidence_map[e->second->variable->ID] = e->second->evidence;
		e->second->evidence = NULL;
	}

	// for each array, regulatee, query and compare to evidence
	int ecnt = 0;
	for (map<string,Node*>::iterator e=emap.begin(); e!=emap.end(); e++) {

		if (verbose) { 
			cerr << e->second->variable->name << " " 
				 << ((int)(100.0 * ((double)(ecnt++)/emap.size()))) 
				 << " %            \r";
		}

		int ID = e->second->variable->ID;
		Evidence *evidence = evidence_map[ID];
		map<string,int>::const_iterator g_itr = assay.gmap.find(e->second->variable->name);
		if (g_itr == assay.gmap.end() || evidence == NULL) { 
			cerr << "WARNING:  no test data for regulatee " << e->second->variable->name
			     << "; skipping it." << endl;
			continue;
		}
		int gene_index_in_assay = g_itr->second;
		
		const GaussianMixture *gm = &(states.find(e->first)->second);
		const int arity = e->second->arity();
		
		for (int a=0; a<assay.N; a++) {

			// actual expression
			double expression = assay.expression[a][gene_index_in_assay];

			// network prediction based on evidence still there
			distribution prediction = network.query(ID, a);

			// actual value for this node's evidence
			distribution actual(arity);
			for (int v=0; v<arity; v++) { 
				actual[v] = (*evidence)(a, v);
			}

			// measure of similarity between prediction and evidence
			double dot_product = 0.0;
			for (int v=0; v<arity; v++) { 
				dot_product += (prediction[v] * actual[v]);
			}

			// distance between mean of gaussian and actual expression (weight by prediction)
			double discrepency = 0.0;
			for (int v=0; v<arity; v++) { 
				discrepency += prediction[v] * fabs(expression - (*gm)[v].mu);
			}

			ans.log_prob += log(dot_product);	// multiply probability
			ans.error += (1.0 - dot_product);
			ans.SSE += (discrepency * discrepency);
			ans.N++;

		} // next array
	} // next regulatee

	if (verbose) { cerr << "100%                " << endl; }
			
	// replace evidence
	for (map<string,Node*>::iterator e=emap.begin(); e!=emap.end(); e++) { 
		e->second->evidence = evidence_map[e->second->variable->ID];
	}

	return ans;
}




	

/** Command-line driver:
 *	parse options, build the network for the chosen baseline, set training
 *	evidence, initialize CPDs, train with EM, optionally prune the CPD trees
 *	(only when pruning is enabled) and write dot output, then, if a test set
 *	is given, evaluate it and print the Evaluation on stdout.
 */
int main(int argc, char **argv) {

	int baseline;
	
	string roles_filename, conditions_filename, states_filename;
	string trainset_filename, testset_filename;
	string dot_file;
	
	int em_iterations;

	bool prune;
	double ratio;
	
	bool verbose;

	OptionParser parser("Noto and Craven Bayesian Gene Regulation Network");
	
	parser.add(Option("roles", 'r', &roles_filename, "",
		"Roles file:  records are regulator-name regulatee-name Activator|Repressor priority", true));
	parser.add(Option("conditions", 'c', &conditions_filename, "",
		"Conditions file:  records are condition-name possible-values-separated-by-whitespace", true));
	parser.add(Option("states", 's', &states_filename, "",
		"Gene states as GaussianMixture:  records are gene-name number-of-states mean \
		standard-deviation weight mean standard-deviation weight ...", true));
	parser.add(Option("trainset", 'x', &trainset_filename, "",
		"Data file:  first line has number-of-arrays number-of-conditions number-of-genes, second \
		line is conditions-names and genes-names (column header), rest of lines are array-name \
		followed by variable values (conditions values or genes expression)", true));
	parser.add(Option("testset", 't', &testset_filename, "", 
		"Data file as trainset, but used for testing"));
	parser.add(Option("dot", 'd', &dot_file, "", 
		"If provided, print network dot markup to this file"));
	parser.add(Option("noprune", 'p', &prune, true, "After training, do not prune all CPD trees"));
	parser.add(Option("ratio", 'm', &ratio, 1.0, "During CPD tree pruning, maximum LDL/IDL ratio to prune"));
	parser.add(Option("baseline", 'B', &baseline, 0, "Version to run, 0=Full Model, 1=No hidden \
		states, 2=No hidden states and no conditions"));
	parser.add(Option("emitr", 'i', &em_iterations, 10, 
		"Number of EM iterations to run"));
	parser.add(Option("verbose", 'v', &verbose, false, "Verbose output to stderr"));

	parser.parse(argc, argv);

	// check
	if (baseline < 0 || baseline > 2) { 
		cerr << "Error:  There is no baseline " << baseline << "." << endl;
		parser.usage(cerr);
		exit(-1);
	}

	if (baseline == 1 || baseline == 2) { em_iterations = 1; }	// no hidden nodes

	// print out program options
	cerr << endl << "Noto and Craven Bayesian Gene Regulation Network" << endl << endl;
	cerr << "Roles:  " << roles_filename  << endl;
	cerr << "Conditions:  " << conditions_filename << endl;
	cerr << "States:  " << states_filename << endl;
	cerr << "Trainset:  " << trainset_filename << endl;
	if (testset_filename != "") { cerr << "Testset:  " << testset_filename << endl; }
	cerr << endl;

	cerr << "Baseline:  " << (baseline==0?"Full":(baseline==1?"No hidden":"No hidden, no conditions")) << endl;
	cerr << "EM iterations:  " << em_iterations << endl;
	cerr << "Prune after training:  " << (prune?"YES":"NO") << endl;
	if (prune) { cerr << "Prune LDL/MDL ratio:  " << ratio << endl; }
	if (dot_file != "") { cerr << "Dot output:  " << dot_file << endl; }
	cerr << endl;

	// read in data
	GRNVariableData vardata(roles_filename, conditions_filename, states_filename);

	// remove roles whose regulatees do not have >1 state (print to cerr)
	remove_single_state_regulatees(vardata.roles, vardata.states);

	// create all network variables (store in maps)
	NodeMaps maps = create_network_variables(vardata, baseline);

	cerr << maps.cmap.size() << " conditions." << endl;
	cerr << maps.tmap.size() << " regulators." << endl;
	cerr << maps.emap.size() << " regulatees." << endl;
	//cerr << maps.hmap.size() << " hidden states." << endl;	// may put this back in if
																//	1-state-regulators are removed
	cerr << endl;

	// set up the network structure
	BayesNet network = set_network_structure(vardata.roles, maps, baseline);

	// read trainset data
	GRNAssayData trainset(trainset_filename);
	if (verbose) { cerr << "Training on data from " << trainset_filename << ":  "
						<<  trainset.N << " arrays, " << trainset.C 
						<< " conditions, " << trainset.G << " genes." << endl; }

	// set up data as evidence for variables
	set_network_evidence(network, trainset, vardata.states, maps);

	initialize_CPDs(network, vardata.roles, maps, baseline);

	if (verbose) { 
		cerr << "Train network; run " << em_iterations << " EM iterations on " 
			 << trainset.N << " arrays..." << endl;
		for (int i=0; i<em_iterations; i++) {
			cerr << (int)((100.0 * i) / em_iterations) << "%   \r";
			network.em(trainset.N, 1);


			/*** This code checks to see how testset results improve with each iteration ***
			network.normalize_all();
			if (testset_filename != "") { 
				GRNAssayData testset(testset_filename);
				if (verbose) { cerr << "Testing on data from " << testset_filename << ":  "
									<<  testset.N << " arrays, " << testset.C 
									<< " conditions, " << testset.G << " genes." << endl; }
				set_network_evidence(network, testset, vardata.states, maps);	
				Evaluation evaluation = test(network, testset, vardata.states, maps.emap, verbose);
				if (verbose) { cerr << "Results are:  " << evaluation.stats() << endl; }
			}
			
			// Must reset to trainset for next EM-itr.
			set_network_evidence(network, trainset, vardata.states, maps);
			
			*******************************************************************************/
			
		}
		cerr << "100% Done training network." << endl;
	} else {
		network.em(trainset.N, em_iterations);
	}
	network.normalize_all();

	// prune exactly once, and only when requested (--noprune disables it);
	// the dot output and the test below both use these trees
	if (prune) { trim_trees(maps.hmap, maps.emap, ratio); }

	// if user says so, print network
	if (dot_file != "") { 
		ofstream dot_out(dot_file.c_str());
		if (!dot_out) {
			cerr << "Cannot open " << dot_file << " for output." << endl; 
		} else {
			network.dot(dot_out, roles_filename);
			dot_out.close();
		}
	}

	// if testset given, run tests and print output
	if (testset_filename != "") { 
		GRNAssayData testset(testset_filename);
		if (verbose) { cerr << "Testing on data from " << testset_filename << ":  "
							<<  testset.N << " arrays, " << testset.C 
							<< " conditions, " << testset.G << " genes." << endl; }
		set_network_evidence(network, testset, vardata.states, maps);	
		Evaluation evaluation = test(network, testset, vardata.states, maps.emap, verbose);
		if (verbose) { cerr << "Results are:  " << evaluation.stats() << endl; }
		cout << evaluation << endl;
	}

	network.delete_all_data();	// the network was given all *new* Variables, Nodes,
								//	CPDs, evidence (this function deletes Variable
								//	structs and evidence as well as nodes and CPDs)

	return 0;

} // main


