
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <stdlib.h>
using namespace std;

#include "BayesNet.h"
#include "CPT.h"
#include "CPDTree.h"
#include "stl.h"

#ifndef rnd
#define rnd() ((double)rand() / RAND_MAX)
#endif


/*
 * Draw one index at random from a discrete distribution.
 *
 * [begin, end) holds the weights of the outcomes, e.g. one CPT row or a
 * `distribution`. Index k is returned with probability
 * weight[k] / (sum of positive weights), so the weights need not sum to
 * exactly 1. Weights that are not positive (including NaN) are never chosen.
 *
 * One uniform draw from rnd() is made, and the cumulative sum is walked until
 * it passes that draw (inverse-CDF sampling). The range is traversed twice,
 * so Iterator must be a forward (multi-pass) iterator.
 *
 * Returns -1 if the range is empty or contains no positive weight.
 */
template <typename Iterator>
int choose_probablistically(Iterator begin, Iterator end) {
	// First pass: total positive mass, used to scale the uniform draw.
	double total = 0.0;
	for (Iterator i = begin; i != end; ++i) {
		if (*i > 0.0) { total += *i; }
	}
	if (!(total > 0.0)) { return -1; }	// empty range, all-zero weights, or NaN

	// Second pass: walk the cumulative distribution until it exceeds the target.
	const double target = rnd() * total;
	double cumulative = 0.0;
	int index = 0;
	int last_positive = -1;
	for (Iterator i = begin; i != end; ++i, ++index) {
		if (*i > 0.0) {
			cumulative += *i;
			last_positive = index;
			if (target < cumulative) { return index; }
		}
	}
	// Reached only when target == total (rnd() may return exactly 1.0) or via
	// floating-point rounding; fall back to the last outcome that can occur.
	return last_positive;
}


/*
 * ExampleData view of one sampled example (one row of the data table).
 * operator()(ID, value) yields 1.0 when that row assigns `value` to the
 * variable whose column is ID, and 0.0 otherwise.
 * The table is not owned: the caller keeps it alive for this object's life.
 */
class AlarmExampleData : public ExampleData {
  private:
	int example;                       // which row of *data this view reads
	vector<vector<int> > *data;        // data[example][ID] = sampled value index
  public:
	AlarmExampleData(int E, vector<vector<int> > *D) : example(E), data(D) { }

	double operator()(int ID, int value) {
		const vector<int> &row = (*data)[example];
		return (row[ID] == value) ? 1.0 : 0.0;
	}
};

/*
 * Evidence for a single variable (column ID) across all sampled examples.
 * operator()(example, value) yields 1.0 when that example's observed value
 * for the variable equals `value`, and 0.0 otherwise.
 * The data table is not owned; the caller must keep it alive.
 */
class AlarmEvidence : public Evidence {
  private:
	int ID;                            // column of the observed variable
	vector<vector<int> > *data;        // data[example][ID] = sampled value index
  public:
	AlarmEvidence(int id, vector<vector<int> > *D) : ID(id), data(D) { }

	double operator()(int example, int value) {
		const int observed = (*data)[example][ID];
		return (observed != value) ? 0.0 : 1.0;
	}
};


/*
 * Unit test / demo for the BayesNet library using the classic alarm network:
 *   burglary (B) -> alarm (A) <- earthquake (E),  A -> johncalls (J),  A -> marycalls (M).
 *
 * Steps:
 *   1. Build the variables, nodes and graph, and attach hand-written CPTs.
 *   2. Sample N training examples from those CPTs.
 *   3. Give every node except alarm per-example evidence.
 *   4. Replace the alarm and earthquake CPDs with CPDTrees and run EM.
 *   5. Normalize all CPDs and print the network with BN.dot to cout.
 *
 * Progress messages go to cerr. Nothing is asserted; main always returns 0.
 */
int main() { 

	// Fixed seed so the sampled training data is the same on every run
	// (for a given C library's rand()).
	srand(6141977);

	// values of a binary variable (value index 0 = "F", 1 = "T")
	const string BV[] = { "F", "T" };

	// create variables
	cerr << "create variables...\n";
	const int NUM_VARS = 5;            // under the alternate enum below, the IDs would be 4..8
	// Alternate enum (commented out) whose IDs start at 4 instead of 0. It was
	// presumably kept to check that nothing assumes 0-based IDs; that is why the
	// vectors below are indexed with offsets such as nodes[A-B] rather than nodes[A].
	//enum VIDS { DOES, THIS, BREAK, ANYTHING, B, E, A, J, M };
	enum VIDS { B, E, A, J, M };
	vector<Variable> variables;	// must be same order as VIDS
	variables.push_back(Variable("burglary",   B, BV, BV+2));
	variables.push_back(Variable("earthquake", E, BV, BV+2));
	variables.push_back(Variable("alarm",      A, BV, BV+2));
	variables.push_back(Variable("johncalls",  J, BV, BV+2));
	variables.push_back(Variable("marycalls",  M, BV, BV+2));

	// create nodes for each variable
	// Each Node is given a pointer into `variables`. It presumably keeps that
	// pointer, so `variables` must not grow (reallocate) after this point.
	cerr << "create nodes for each variable...\n";
	vector<Node> nodes;
	for (vector<Variable>::const_iterator v=variables.begin(); v!=variables.end(); v++) { 
		nodes.push_back(Node(&(*v)));
	}

	// create the network: B -> A, E -> A, A -> J, A -> M.
	// BN is handed pointers into `nodes`, so `nodes` must not grow after this point.
	cerr << "create the network...\n";

	BayesNet BN;
	BN.connect(&(nodes[B-B]), &(nodes[A-B]));
	BN.connect(&(nodes[E-B]), &(nodes[A-B]));
	BN.connect(&(nodes[A-B]), &(nodes[J-B]));
	BN.connect(&(nodes[A-B]), &(nodes[M-B]));

	// Create CPDs:
	// Within each row, entries are { P(F), P(T) } for the child variable.
	cerr << "Create CPDs...\n";
	const double B_CPT[] = { 0.1 , 0.9 };	// burglaries made likely so training sees many
	const double E_CPT[] = { 0.2 , 0.8 };	// earthquakes made common for the same reason
	                    //   !B!E!A  !B!E,A !B,E!A, ...   (A varies fastest, then E, then B)
	const double A_CPT[] = { 0.999 , 0.001, 0.71, 0.29, 0.06, 0.94, 0.05, 0.95 };
	const double J_CPT[] = { 0.95, 0.05, 0.10, 0.90 };	// rows A=F, A=T (same layout as above)
	const double M_CPT[] = { 0.99, 0.01, 0.30, 0.70 };	// rows A=F, A=T (same layout as above)

	// take advantage of node ordering in vector variables:
	// each CPT's scope is a contiguous range [first, last) of `variables`, with the child last.
	CPT *bcpt = new CPT(variables.begin() + B-B, variables.begin() + E-B);	// B
	CPT *ecpt = new CPT(variables.begin() + E-B, variables.begin() + A-B);	// E
	CPT *acpt = new CPT(variables.begin() + B-B, variables.begin() + J-B);	// B,E,A
	CPT *jcpt = new CPT(variables.begin() + A-B, variables.begin() + M-B);	// A,J
	// M's parent A is not adjacent to M in `variables`, so M's scope is built separately.
	// `marys` must stay alive as long as mcpt may refer to it (it lives until the end of main).
	vector<Variable> marys;  marys.push_back(variables[A-B]);  marys.push_back(variables[M-B]);
	CPT *mcpt = new CPT(marys.begin(), marys.end());	// A, M

	// copy in data (each table has 2^(number of binary variables in its scope) entries)
	cerr << "copy in data...\n";
	copy(B_CPT, B_CPT + 2, bcpt->begin());
	copy(E_CPT, E_CPT + 2, ecpt->begin());
	copy(A_CPT, A_CPT + 8, acpt->begin());
	copy(J_CPT, J_CPT + 4, jcpt->begin());
	copy(M_CPT, M_CPT + 4, mcpt->begin());

	// assign cpts to nodes
	cerr << "assign cpts to nodes...\n";
	nodes[B-B].cpd = bcpt; nodes[E-B].cpd = ecpt; nodes[A-B].cpd = acpt; 
	nodes[J-B].cpd = jcpt; nodes[M-B].cpd = mcpt;

	// generate some data
	// Sampling is done in topological order: B and E from their priors, then A,
	// then J and M. Each classify() call reads example x through `aed`. It
	// presumably conditions on the parent values already written to data[x].
	cerr << "generate some data...\n";
	const int N = 1234;	// number of examples
	vector<vector<int> > data(N);	// for each example, variable, which value?
	for (int x=0; x<N; x++) { 
		// Columns are indexed by raw variable ID (data[x][B], not data[x][B-B]). The
		// +10 headroom presumably lets the alternate enum (IDs up to 8) fit as well.
		// resize() zero-fills, so A/J/M read as 0 until they are sampled below.
		data[x].resize(NUM_VARS + 10);
		data[x][B] = choose_probablistically(bcpt->begin(), bcpt->end());
		data[x][E] = choose_probablistically(ecpt->begin(), ecpt->end());
		AlarmExampleData aed(x, &data);
		distribution adist = acpt->classify(&aed);
		data[x][A] = choose_probablistically(adist.begin(), adist.end());
		distribution jdist = jcpt->classify(&aed);
		data[x][J] = choose_probablistically(jdist.begin(), jdist.end());
		distribution mdist = mcpt->classify(&aed);
		data[x][M] = choose_probablistically(mdist.begin(), mdist.end());
	}

	// set up evidence for each node except alarm
	// A gets no evidence, so its sampled values are never shown to the learner.
	// Presumably EM treats it as a hidden variable.
	// The evidence objects are stack locals; the nodes hold pointers to them.
	cerr << "set up evidence for each node except alarm...\n";
	AlarmEvidence b_evidence(B, &data);	nodes[B-B].evidence = &b_evidence;
	AlarmEvidence e_evidence(E, &data);	nodes[E-B].evidence = &e_evidence;
	                                    nodes[A-B].evidence = NULL;
	AlarmEvidence j_evidence(J, &data);	nodes[J-B].evidence = &j_evidence;
	AlarmEvidence m_evidence(M, &data);	nodes[M-B].evidence = &m_evidence;

	// replace alarm's CPD with a tree.
	// Each leaf Lbe holds the 2-entry A row for B=b, E=e, copied from acpt.
	// The root splits on B and its two children split on E. Given the leaf
	// placement, the left child presumably corresponds to value 0 (F).
	// NOTE(review): these nodes are allocated with new and never deleted here;
	// presumably atree.replace() takes ownership. TODO confirm.
	
	CPDTreeNode *L00 = new CPDTreeNode(2);	copy(acpt->begin() + 0, acpt->begin() + 0 + 2, L00->dist.begin());
	CPDTreeNode *L01 = new CPDTreeNode(2);	copy(acpt->begin() + 2, acpt->begin() + 2 + 2, L01->dist.begin());
	CPDTreeNode *L10 = new CPDTreeNode(2);	copy(acpt->begin() + 4, acpt->begin() + 4 + 2, L10->dist.begin());
	CPDTreeNode *L11 = new CPDTreeNode(2);	copy(acpt->begin() + 6, acpt->begin() + 6 + 2, L11->dist.begin());
	CPDTreeNode *EL = new CPDTreeNode(&(variables[E-B]), 0, L00, L01);
	CPDTreeNode *ER = new CPDTreeNode(&(variables[E-B]), 0, L10, L11);
	CPDTreeNode *aroot = new CPDTreeNode(&(variables[B-B]), 0, EL, ER);
	CPDTree atree(variables.begin() + B-B, variables.begin() + J-B);	// B,E,A
		// NOTE:  the constructor creates a uniform distribution by default; replace() swaps in aroot
	atree.replace(aroot);
	nodes[A-B].cpd = &atree;

	// replace earthquake's CPD with a tree, too (a single leaf holding E's prior)
	CPDTreeNode *eroot = new CPDTreeNode(2);
	copy(ecpt->begin(), ecpt->end(), eroot->dist.begin());
	CPDTree etree(variables.begin() + E-B, variables.begin() + A-B);
	etree.replace(eroot);
	nodes[E-B].cpd = &etree;
	
	// use EM to learn tree from data
	cerr << "use EM to learn tree from data...\n";
	BN.em(N, 42);	// BN.em(num_examples, num_iterations)

	// normalize ALL network CPDs
	cerr << "normalize ALL network CPDs...\n";
	BN.normalize_all();

	// print out network (the title uses a literal "\n" escape for the output format)
	cerr << "print out network...\n";
	BN.dot(cout, "Alarm Network\\n(test of BayesNet class)");

	// clean up
	// Only the heap CPTs are freed. acpt and ecpt are no longer referenced by any
	// node (they were replaced by the stack trees above). bcpt, jcpt and mcpt are
	// still pointed to by nodes, so nothing may dereference them after this point.
	// The CPDTreeNodes are not freed here (see the ownership note above).
	cerr << "clean up" << endl;
	delete bcpt;
	delete ecpt;
	delete acpt;
	delete jcpt;
	delete mcpt;
	
	// unit test is done, man
	return 0;
}


