/*
12345678901234567890123456789012345678901234567890123456789012345678901234567890
*/

/*
	class Node:  represents a Bayesian network node.

	There are two types of nodes, those WITH evidence and those WITHOUT.
	It is assumed that there is a fixed number of data examples (atomic
	events), which together specify a probability distribution over values
	for each node.
	
	NOTE:  often, evidence consists of KNOWN, SPECIFIC values for each 
	random variable, but this package allows for a PROBABILITY DISTRIBUTION
	over values for evidence.  (Of course, the probability distribution
	could be 1.0 for a certain possible value and 0.0 for all others.)
	
	Evidence is given to each node as a pointer to an instance of an Evidence
	class, which returns a probability for each example (atomic event),
	for each possible variable value (see class Evidence below).

	NOTE:  If a node HAS evidence, it is assumed that it has evidence
	for ALL examples.  If a node DOES NOT HAVE evidence, it is assumed that
	it has NO evidence for ANY example.
	TODO:  come up with a way to allow nodes to have (or not have)
	evidence for SOME examples, and to not have (or to have) evidence 
	for others.

	class BayesNet:

	Uses class Node to create a collection of Nodes and edges between them.
	Maintains BOTH Node->child and Node->parent links.

	Implements various algorithms for Bayesian networks, such as:
	query (compute a distribution over a node's values from the other
	       nodes and the CPDs),
	em (update ALL CPDs using the evidence of the nodes that have it,
	    while the remaining nodes have no evidence)

*/

#ifndef BAYES_NET_H
#define BAYES_NET_H 1

#include <cassert>
#include <iostream>
#include <list>
#include <sstream>
#include <string>
#include <vector>
using namespace std;

#include "CPD.h"
#include "CPT.h"

class Evidence;
class Node;
class BayesNet;
class BayesNetTrainingData; 

/** Functor takes an example and a variable value and returns the
	probability or the extent to which the variable takes on that
	value in that example.

	Evidence is an abstract interface: Nodes hold it through an Evidence*,
	and BayesNet::delete_all_data() deletes evidence through that pointer.
	The destructor is therefore virtual, so that deleting a derived
	functor via Evidence* runs the derived destructor (without it the
	delete is undefined behavior). */
class Evidence {
  public:
	virtual ~Evidence() {}

	/** @param example  index of the example (atomic event)
		@param value    one of the variable's possible values
		@return         probability / degree to which the variable takes
		                `value` in `example` */
	virtual double operator()(int example, int value) = 0;
};

class Node {
  
  friend class BayesNet;
  friend ostream& operator<<(ostream &out, const Node &node);
	
  private:

	vector<const Node*> parents;		// set via BayesNet::connect(...)
	vector<const Node*> children;		// BayesNet will use these

  public:

  	const Variable *variable;			// Node represents this variable

  	Evidence *evidence;			// Node may have evidence (NULL iff no evidence)
	CPD *cpd;					// CPD not necessarily required 
								//	(e.g. parent w/evidence)
	
	/** create a node */
	Node(const Variable *var) : variable(var) {
		evidence = NULL;
		cpd = NULL;
	}

	/** quick reference */
	inline int arity() const { return variable->arity(); }
	inline int num_parents() const { return parents.size(); }
	inline int num_children() const { return children.size(); }

};

/**
 * A Bayesian network built from Nodes and the directed edges between them.
 *
 * Both Node->child and Node->parent links are maintained (see connect()).
 * Every node is kept in `nodes` and is also reachable by its variable ID
 * through `node_map`.
 *
 * The class declares no destructor, so it never frees the Nodes, Variables,
 * CPDs or Evidence it points to; call delete_all_data() when all of them
 * were dynamically allocated.
 */
class BayesNet {

  friend class BayesNetTrainingData;

  private:

	list<Node*> nodes;			// all nodes

	vector<Node*> node_map;		// if the node with ID = X is in `nodes`,
								//	then node_map[X] points to that node

  public:

	/**
	 * Add a node (identified by its variable ID).
	 * If it is already present, it is not added again.
	 * Afterwards, node_map[node->variable->ID] equals this node pointer.
	 *
	 * (Declared without the `BayesNet::` qualifier: a qualified name is
	 * not allowed in a member declaration inside the class body.)
	 */
	void add(Node *node);

	/**
	 * Add an edge from parent->child.
	 * If either or both nodes are not already part of the network,
	 * they are added.
	 */
	void connect(Node *parent, Node *child);

	/**
	 * Create a *new* (empty) CPT for the node with the given ID, with the
	 * appropriate variables included based on the network structure.
	 * NOTE(review): returns a freshly created object; presumably the
	 * caller owns and must delete it — confirm against the implementation.
	 */
	CPT* create_table(int ID) const;

	/**
	 * Create a list of variable pointers: the parents of the node with
	 * the given ID.
	 */
	vector<const Variable*> parents(int ID) const;

	/**
	 * Get a distribution over possible values for the given node and example.
	 */
	distribution query(int ID, int example) const;

	/**
	 * Normalize ALL Nodes' CPDs.
	 */
	void normalize_all();

	/**
	 * Run the EM algorithm to update CPDs for all nodes
	 * (all nodes that *have* CPDs), using Nodes' evidence where applicable.
	 * @param num_examples    number of examples (atomic events) in the data
	 * @param num_iterations  number of EM iterations (default 1)
	 */
	void em(int num_examples, int num_iterations=1);

	/** Print DOT markup representing the network (not the CPDs). */
	void dot(ostream &out, string title = "") const;

	/**
	 * Delete all evidence, CPDs, Nodes and Variables
	 *	(handy if ALL were dynamically allocated).
	 */
	void delete_all_data();

  private:

	/** Create a whole big list of CPTs: potentials for each node.
		NOTE:  the pointers are new and must be deleted by the caller.
		NOTE:  evidence is applied where available. */
	list<CPT*> create_normalized_potentials(const list<const Node*>&, int example) const;

	list<const Node*> relevant_nodes(int ID) const;
	void mark_relevant_parents(int ID, vector<bool> &marked) const;
	void mark_relevant_children(int ID, vector<bool> &marked) const;

};

/** BayesNetTrainingData uses the evidence supplied for each node */
class BayesNetTrainingData : public TrainingData {
  private:
  	BayesNet *BN;
  public:
  	BayesNetTrainingData(BayesNet *net, int num_examples) 
		: TrainingData(num_examples), BN(net) { }
	double operator()(int ID, int example, int value) { 
		return (*(BN->node_map[ID]->evidence))(example, value);
	}
};

	
#endif	
