
#include "mixture_constraints.h"

// Distance d >= 0 from the mean of Gaussian "b" at which b's density equals the density of
// Gaussian "a" evaluated num_stdevs standard deviations away from a's own mean, i.e. the
// solution of
//     (w_b / sigma_b) * exp(-d^2 / (2 sigma_b^2)) == (w_a / sigma_a) * exp(-num_stdevs^2 / 2)
// (the common 1/sqrt(2*pi) factor cancels).
// NOTE(review): this presumes Gaussian::density() is the weight-scaled normal pdf — confirm.
// The result is only real when the log argument is < 1, i.e. b's peak density exceeds a's
// density at its num_stdevs point; the callers below only invoke it after a density
// comparison that (under the assumption above) guarantees this, otherwise this yields NaN.
static double equal_density_distance(double num_stdevs,
                                     double sigma_a, double w_a,
                                     double sigma_b, double w_b) {
	return sqrt( -log( (exp(-(num_stdevs*num_stdevs)/2.0) * sigma_b * w_a) /
	                   (sigma_a * w_b) )
	             * 2.0 * (sigma_b * sigma_b) );
}

// Enforce structural constraints on a Gaussian mixture, modifying it in place:
//   1. every component gets sigma >= min_stdev and w >= min_weight, then weights are
//      renormalized (so the weight floor is only approximate — see TODO below);
//   2. components are sorted (presumably by mean, via Gaussian's operator<);
//   3. for each adjacent pair (j, j+1), component j+1 is pushed to the right so that
//      neither component is dominated by its neighbour within K standard deviations of
//      its own mean;
//   4. the whole mixture is translated so the middle component ends where it started,
//      making the components "spread out" instead of drifting strictly to the right.
// Finally, a warning is written to cerr for any component with a NaN parameter.
// An empty mixture is only normalized; there is nothing else to constrain.
void GRN_GaussianMixtureConstraints::operator()(GaussianMixture &gm) const {

	// Make sure each Gaussian has the minimum variance and minimum weight.
	for (GaussianMixture::iterator i = gm.begin(); i != gm.end(); ++i) {
		if (i->sigma < min_stdev) { i->sigma = min_stdev; }
		// After normalization this won't make the weight exactly min_weight.
		// TODO:  implement this for real
		if (i->w < min_weight) { i->w = min_weight; }
	}
	gm.normalize();	// for re-adjusted weights

	const std::size_t n = gm.size();
	if (n == 0) { return; }	// also guards the n-1 / gm[middle_index] accesses below

	// Enforce constraint that each Gaussian has maximum density within K standard deviations
	// of its mean.  The pairwise pass below assumes Gaussians are sorted by mean.
	std::sort(gm.begin(), gm.end());

	// Iteration j only ever moves gm[j+1], so the middle component moves only during
	// iteration middle_index-1.  Remember where it started so its net shift can be undone.
	const std::size_t middle_index = n / 2;
	const double middle_start = gm[middle_index].mu;

	for (std::size_t j = 0; j + 1 < n; ++j) {	// for each adjacent pair, j and j+1

		// gm[j] overlaps gm[j+1] from the left: push gm[j+1]'s mean to gm[j]'s K-sigma point.
		if ( gm[j].mu + K*gm[j].sigma > gm[j+1].mu ) {
			gm[j+1].mu = gm[j].mu + K*gm[j].sigma;
		}

		// gm[j+1] overlaps gm[j] from the right: push gm[j+1] right until its own
		// K-sigma point (on the left) is no further left than gm[j]'s mean.
		if ( gm[j+1].mu - K*gm[j+1].sigma < gm[j].mu ) {
			gm[j+1].mu = gm[j].mu + K*gm[j+1].sigma;
		}

		// At gm[j]'s right K-sigma point, gm[j] must not be out-weighed by gm[j+1].
		// If it is, move gm[j+1] right by just enough that the two densities are equal there.
		const double r_left = gm[j].mu + K*gm[j].sigma;
		if ( gm[j].w > 0.0 && gm[j+1].w > 0.0 && gm[j].density(r_left) < gm[j+1].density(r_left) ) {
			gm[j+1].mu = r_left + equal_density_distance(K, gm[j].sigma, gm[j].w,
			                                                gm[j+1].sigma, gm[j+1].w);
		}

		// Symmetrically, at gm[j+1]'s left K-sigma point, gm[j+1] must not be out-weighed by
		// gm[j].  Rather than moving gm[j] left, move gm[j+1] right so that this point sits
		// exactly the equal-density distance to the right of gm[j]'s mean.
		const double r_right = gm[j+1].mu - K*gm[j+1].sigma;
		if ( gm[j].w > 0.0 && gm[j+1].w > 0.0 && gm[j+1].density(r_right) < gm[j].density(r_right) ) {
			gm[j+1].mu = gm[j].mu + K*gm[j+1].sigma
			           + equal_density_distance(K, gm[j+1].sigma, gm[j+1].w,
			                                       gm[j].sigma, gm[j].w);
		}
	}

	// Shift the middle Gaussian back to its original position (so that Gaussians "spread out"
	// instead of just shifting strictly to the right).
	const double middle_shift = gm[middle_index].mu - middle_start;
	for (std::size_t j = 0; j < n; ++j) {
		gm[j].mu -= middle_shift;
	}


	// --------------------------------------------------------------------------
	// TEMPORARY (?) warn about ill-defined parameters (they should never occur)
	for (std::size_t j = 0; j < n; ++j) {
		if (isnan(gm[j].mu) || isnan(gm[j].sigma) || isnan(gm[j].w)) {
			cerr << "GRN_GaussianMixture_Constraints::operator():  WARNING:  NaN on Gaussian index " << j << ".  ";
			cerr << "nans:  " << ((isnan(gm[j].mu)) ? "mean " : "")
						  << ((isnan(gm[j].sigma)) ? "stdev " : "")
						  << ((isnan(gm[j].w)) ? "weight " : "");
			cerr << endl;
		}
	}
	// --------------------------------------------------------------------------

} // operator()
	



