
#include "CPT.h"
#include "stl.h"
#include <stdlib.h>
#include <iostream>
#include <vector>
#include <list>

using namespace std;

// Pseudo-random value drawn from {0.0, 0.1, ..., 0.9}, one decimal step at a
// time. Backed by the C library rand(), so the sequence repeats run to run
// unless srand() is called.
double rr() {
	const int digit = rand() % 10;
	return digit / 10.0;
}

// Per-value scaling factors passed to CPT::multiply(ID, factors) in main():
// 0.1 for value 0, 0.9 for value 1, and 0.0 for every later value.
// NOTE(review): presumably multiply() reads one factor per value of the
// variable, so this array must be at least as long as the largest domain it
// is used with ("Numbers" has 4 values) — confirm against CPT.h. The extra
// trailing 0.0 prevents an out-of-bounds read for that variable.
const double point1point9[] = { 0.1, 0.9, 0.0, 0.0 };


// Example evidence for the classify() demo in main().
// IDs 0 and 1 correspond to A ("Binary") and B ("Numbers") there:
//   - ID 0: weight 0.5 on each of its two values (0 and 1);
//   - ID 1: all weight on value 1, none on any other value.
// NOTE(review): presumably ExampleData::operator() is expected to return the
// weight this example assigns to `value` of variable `ID` — confirm in CPT.h.
class TestExampleData : public ExampleData { 
  public:
	double operator()(int ID, int value) { 
		switch (ID) {
		case 0:
			if (value == 0 || value == 1) { return 0.5; }
			break;
		case 1:
			return (value == 1) ? 1.0 : 0.0;
		default:
			break;
		}
		// Any (ID, value) pair not listed above used to fall off the end of
		// this non-void function (undefined behavior). Give it zero weight
		// instead, matching the 0.0 already used for B's other values.
		return 0.0;
	}
};
	

int main(int argc, char **argv) { 

	enum IDs { A, B, C };

	const string A_VALUES[] = { "False", "True" };	const int A_NUM_VALUES = 2;
	const string B_VALUES[] = { "zero", "one", "two", "three" }; const int B_NUM_VALUES = 4;
	const string C_VALUES[] = { "Red", "Orange", "Blue" }; const int C_NUM_VALUES = 3;

	Variable alpha  ("Binary",  A, A_VALUES, A_VALUES+A_NUM_VALUES);
	Variable bravo  ("Numbers", B, B_VALUES, B_VALUES+B_NUM_VALUES);
	Variable charlie("Color",   C, C_VALUES, C_VALUES+C_NUM_VALUES);

	vector<Variable> ab_vars;
	ab_vars.push_back(alpha);
	ab_vars.push_back(bravo);

	vector<Variable> cb_vars;
	cb_vars.push_back(charlie);
	cb_vars.push_back(bravo);

	CPT *ab = new CPT(ab_vars.begin(), ab_vars.end());
	CPT *cb = new CPT(cb_vars.begin(), cb_vars.end());

	generate(ab->begin(), ab->end(), rr);	// fill with random numbers
	generate(cb->begin(), cb->end(), rr);	

	cout << "ab:" << endl << *ab << endl;
	cout << "cb:" << endl << *cb << endl;

	CPT *abc = ab->multiply(cb);
	cout << "abc = ab * cb:" << endl << *abc << endl;

	CPT *ac = abc->marginalize(B);
	cout << "ac = abc.marginalize(B):" << endl << *ac << endl;

	CPT *c = ac->marginalize(A);
	cout << "c = ac.marginalize(A):" << endl << *c << endl;

	cout << "abc:" << endl << *abc << endl;
	abc->multiply(A, point1point9);
	cout << "abc * 0.1 that A=0, 0.9 that A<0:" << endl << *abc << endl;
	abc->multiply(B, point1point9);
	cout << "abc * 0.1 that B=0, 0.9 that B<0:" << endl << *abc << endl;
	abc->multiply(C, point1point9);
	cout << "abc * 0.1 that C=0, 0.9 that C<0:" << endl << *abc << endl;

	TestExampleData ted;
	distribution dist = abc->classify(&ted);
	cout << "probability distribution over C when A=(0.5@0, 0.5@1, B==1):" << endl << dist << endl;

	delete ab, cb, abc, ac, c;
	return 0;

}
