
#include "CPT.h"
#include <string.h>
#include <new>
using namespace std;

/** Stream-insertion operator for a CPT: forwards to CPT::print and returns whatever it returns. */
ostream& operator<<(ostream &out, const CPT &cpt)
{
	ostream &printed = cpt.print(out);
	return printed;
}

/**
 * Build a table over the given variables.
 *
 * Only the pointers are stored; the Variable objects themselves are not
 * duplicated.  The data buffer and the per-variable offsets are then set up
 * by allocate_and_set_offset().
 */
CPT::CPT(vector<const Variable*> vars)
{
	variables.assign(vars.begin(), vars.end());	// pointer copy only
	allocate_and_set_offset();
}

void CPT::allocate_and_set_offset() { 


	size = 1;
	for (vector<const Variable*>::const_iterator v=variables.begin(); v!=variables.end(); v++) { 
		size *= (*v)->arity();
	}
	
	// next line is where a CPDTree::as_table would crash if node had too many parents
	
	data = new double[size];	
	
	if (!data) { 
		// warn if call to new fails (although it's probably because 'size' is SO big
		//	that the call actually crashes the executable before it returns)
		cerr << "ERROR (CPT::allocate_data_and_offset):  OUT OF MEMORY; "
			 << "new double[" << size << "];" << endl;
	}
	
	// reset offsets
	unsigned long total = size;
	for (vector<const Variable*>::const_iterator v=variables.begin(); v!=variables.end(); v++) { 
		offset.push_back(total / (*v)->arity());
		total /= (*v)->arity();
	}

}


/**
 * Return a heap-allocated copy of this table.
 *
 * The object is first copy-constructed from *this; when this table has a data
 * buffer, the copy is then given its own array of `size` doubles filled from
 * this->data.  The returned pointer is produced with `new`.
 */
CPT* CPT::as_table() const
{
	CPT *clone = new CPT(*this);	// copies every member

	// nothing further to duplicate when there is no buffer
	if (!data) {
		return clone;
	}

	clone->data = new double[size];
	if (!clone->data) {
		cerr << "ERROR (CPT::as_table):  OUT OF MEMORY;" << endl;
	}
	memcpy(clone->data, this->data, size*sizeof(double));

	return clone;
}





/**
 * Recompute every table entry from the training data.
 *
 * The table is zeroed first.  Then, for each example and each table entry,
 * the product over all variables of
 * (*training_data)(variable ID, example, value of that variable at the entry)
 * is added to the entry.
 */
void CPT::train(TrainingData *training_data)
{
	// start from an all-zero table; the loops below only accumulate
	fill(this->begin(), this->end(), 0.0);

	int example = 0;
	while (example < training_data->N) {

		for (unsigned long cell = 0; cell < size; ++cell) {

			// value index of each variable at this table entry
			vector<int> assignment = coordinates(cell);

			double weight = 1.0;
			for (size_t k = 0; k < variables.size(); ++k) {
				weight *= (*training_data)(variables[k]->ID, example, assignment[k]);
			}

			data[cell] += weight;
		}

		++example;
	}
}

/**
 * Compute a distribution over the last variable of this table for one example.
 *
 * Works on a copy obtained from as_table().  For every variable except the
 * last, the copy is multiplied by the values (*data)(ID, value) for that
 * variable and the variable is then marginalized out.  What remains is
 * normalized and copied into a distribution whose length is the arity of the
 * last variable.
 */
distribution CPT::classify(ExampleData *data) const
{
	CPT *working = this->as_table();

	// every variable but the final one gets folded in and summed out
	const vector<const Variable*>::const_iterator last = variables.end()-1;
	vector<const Variable*>::const_iterator it = variables.begin();
	while (it != last) {

		const Variable *var = *it;

		// per-value weights of this variable for the example
		distribution evidence(var->arity());
		for (int k=0; k<evidence.size(); k++) {
			evidence[k] = (*data)(var->ID, k);
		}

		working->multiply(var->ID, evidence.begin());

		// marginalize() hands back a new table; release the one it replaces
		CPT *reduced = working->marginalize(var->ID);
		delete working;
		working = reduced;

		++it;
	}

	working->normalize();

	distribution ans(variables.back()->arity());
	copy(working->begin(), working->end(), ans.begin());
	delete working;
	return ans;

} // classify

/**
 * Scale the table so that each consecutive run of V entries sums to one,
 * where V is the arity of the last variable.
 *
 * Does nothing when the table has no variables or when V is not positive
 * (a zero step would never advance through the table).
 * A run whose entries sum to exactly zero is left untouched instead of being
 * divided by zero, which would fill it with NaN.
 */
void CPT::normalize() { 
	
	if (variables.empty()) { return; }

	const int V = variables.back()->arity();
	if (V <= 0) { return; }

	for (unsigned long p=0; p<size; p+=V) {
		// normalize from data+p -> data+p+V
		double *const row_begin = data+p;
		double *const row_end   = row_begin+V;

		double sum = 0.0;
		for (double *pointer=row_begin; pointer<row_end; pointer++) { sum += *pointer; }

		// nothing to scale; dividing would turn every entry into NaN
		if (sum == 0.0) { continue; }

		for (double *pointer=row_begin; pointer<row_end; pointer++) { *pointer /= sum; }
	}
	return;
}


/**
 * Multiply this table by B and return the product as a new table.
 *
 * The result ranges over this table's variables followed by those variables
 * of B whose ID does not occur in this table.  Each result entry is the
 * product of the matching entry of this table and the matching entry of B.
 * The returned table is allocated with `new`.
 */
CPT* CPT::multiply(const CPT *B)
{
	// result variables: all of ours first, then whatever B adds
	vector<const Variable*> ans_variables = this->variables;

	// for each variable of B, its position in ans_variables
	vector<int> position_in_ans(B->variables.size());

	for (size_t b = 0; b < B->variables.size(); ++b) {

		// look for B's variable among our own
		size_t a = 0;
		while (a < this->variables.size()
		       && !(this->variables[a]->ID == B->variables[b]->ID)) {
			++a;
		}

		if (a < this->variables.size()) {
			position_in_ans[b] = a;
		} else {
			ans_variables.push_back(B->variables[b]);
			position_in_ans[b] = ans_variables.size()-1;
		}
	}

	CPT *product = new CPT(ans_variables);

	for (unsigned long cell = 0; cell < product->size; ++cell) {

		// coordinates in the result; our own variables come first in it
		vector<int> coorsA = product->coordinates(cell);

		// pick out the coordinates B cares about, in B's variable order
		vector<int> coorsB(B->variables.size());
		for (size_t b = 0; b < B->variables.size(); ++b) {
			coorsB[b] = coorsA[position_in_ans[b]];
		}

		product->data[cell] =   this->entry(coorsA.begin())
		                      * B->entry(coorsB.begin());
	}

	return product;
}

/**
 * Sum out the variable with the given ID and return the result as a new table.
 *
 * The returned table ranges over the remaining variables in their original
 * order; each of its entries is the sum of the entries of this table that
 * agree with it on those variables.
 *
 * If no variable of this table has the given ID (including the case of a
 * table with no variables), an exact copy from as_table() is returned.
 * The returned table is heap-allocated in either case.
 */
CPT* CPT::marginalize(int ID) const {

	vector<const Variable*> ans_variables = variables;	// copy
	int m_index = 0;	// index of marginalized variable in variable array (updated below)
	vector<const Variable*>::iterator v = ans_variables.begin();

	// The iterator walks ans_variables, so it must be tested against
	// ans_variables.end() -- and before every dereference.
	while (v != ans_variables.end() && ID != (*v)->ID) { 
		m_index++;
		++v;
	}

	if (v == ans_variables.end()) {
		// Marginalization of variable not in table returns exact copy
		return this->as_table();
	}

	ans_variables.erase(v);	// remove the marginalized variable from the list
	CPT *ans = new CPT(ans_variables);	
		// this updates the offsets of all variables

	fill(ans->data, ans->data+ans->size, 0.0);	// start with zero b/c we'll be adding

	for (unsigned long p=0; p<size; p++) {	// for each value in the original data

		vector<int> coors = coordinates(p);
		coors.erase(coors.begin() + m_index);	// remove the coordinate being marginalized

		ans->entry(coors.begin()) += this->data[p];
	}

	return ans;

}

/** Tell whether one of this table's variables has the given ID. */
bool CPT::has(int ID) const
{
	for (size_t i = 0; i < variables.size(); ++i) {
		if (variables[i]->ID == ID) {
			return true;
		}
	}
	return false;
}

/** Write the table to `out` as aligned text columns and return `out`.
 *	Every column is as wide as the longest of the variable's name and its
 *	value strings, followed by one space.
 *	Example:
 *		Var1  Var2
 *		----- -----
 *      False False | 0.25
 *      False True  | 0.75
 *      True  False | 0.4
 *      True  True  | 0.6
 */
ostream& CPT::print(ostream &out) const
{
	const size_t n_vars = variables.size();

	// column width per variable: longest of its name and its value strings
	vector<int> lengths(n_vars);
	for (size_t v = 0; v < n_vars; ++v) {
		const Variable *var = variables[v];
		int widest = var->name.length();
		for (size_t k = 0; k < var->values.size(); ++k) {
			if (var->values[k].length() > widest) {
				widest = var->values[k].length();
			}
		}
		lengths[v] = widest;
	}

	// header row: each name padded to its column width plus a separating space
	for (size_t v = 0; v < n_vars; ++v) {
		const Variable *var = variables[v];
		out << var->name
		    << string(lengths[v] - var->name.length() + 1, ' ');
	}
	out << endl;

	// rule under the header
	for (size_t v = 0; v < n_vars; ++v) {
		out << string(lengths[v], '-') << " ";
	}
	out << endl;

	// one line per table entry: the value of every variable, a '|', the entry
	for (unsigned long row = 0; row < size; ++row) {

		for (size_t v = 0; v < n_vars; ++v) {
			const Variable *var = variables[v];

			// value index of variable v at this entry
			int t = row/offset[v] % var->values.size();

			out << var->values[t]
			    << string(lengths[v] - var->values[t].length() + 1, ' ');
		}
		out << "| " << data[row] << endl;
	}

	return out;
}


/**
 * Write a Graphviz digraph describing this table to `out`.
 *
 * The graph has a title node labelled "CPT for <name of the last variable>"
 * and one box node whose label is the text produced by print(), set in a
 * Courier font.  Within the label, newlines become the literal sequence \l
 * and double quotes and backslashes are backslash-escaped.
 *
 * All output goes to `out`.
 */
void CPT::dot(ostream &out) const { 

	// cheat and print a node that just contains the
	//	ASCII version of the CPT printout

	const string FONTNAME = "Courier";
	// const string FONTSIZE = "18";

	out << "digraph {" << endl;
	out << "titlenode [color=\"white\", label=\"CPT for " 
		<< variables.back()->name << "\"];" << endl;
	out << "nodeCPT [shape=\"box\" "
		// use default fontsize << "fontsize=\"" << FONTSIZE << "\" "
		<< "fontname=\"" << FONTNAME << "\" ";
		
	out << "label=\"";

	// render the table to text first so it can be escaped character by character
	ostringstream oss;
	this->print(oss);
	const string S = oss.str();
	
	// replace newlines with literal \l 
	//	(left justify) and quotes with escaped quotes.
	// This must go to the same stream as the header above, not to cout,
	// or the graph is split across two streams.
	for (string::size_type c=0; c<S.length(); c++) { 
		if (S[c] == '\n') { 
			out << "\\l";
		} else if (S[c] == '\"') { 
			out << "\\\"";
		} else if (S[c] == '\\') { 
			out << "\\\\";
		} else { 
			out << S[c];
		}
	}
	out << "\"];" << endl << "}" << endl;
}
	
/**
 * Decode a flat table index into one value index per variable.
 *
 * Walks the variables in order; for each, the quotient of what is left of the
 * index by that variable's offset is its coordinate, and that part is then
 * subtracted before moving on to the next variable.
 */
vector<int> CPT::coordinates(int index) const
{
	const size_t n_vars = variables.size();
	vector<int> ans(n_vars);

	int remaining = index;	// part of the index not yet decoded
	for (size_t v = 0; v < n_vars; ++v) {
		const int digit = remaining / offset[v];
		ans[v] = digit;
		remaining -= digit * offset[v];
	}

	return ans;
}

/** Return the offset stored for the variable with the given ID, or -1 when no variable has that ID. */
int CPT::var_offset(int ID) const
{
	size_t idx = 0;
	while (idx < variables.size()) {
		if (variables[idx]->ID == ID) {
			return offset[idx];
		}
		++idx;
	}
	return -1;
}

