/*

	Mathematical position weight matrix for DNA_ABLEN (4) bases.
	Works with genomic.h


*/ 

#ifndef PWM_H
#define PWM_H

#include "genomic.h"
#include "grid.h"

#include <stdlib.h>	// 2011-11-28, noto added for compile errors in Debian

#include <algorithm>	// std::fill, std::min
#include <cmath>	// log2
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
using namespace std;
	
#define PWM_DEBUG 0      // out-of-bounds checking, etc.

#if PWM_DEBUG
#include <assert.h>
#endif

template <typename T> class PWM;

/**
 *	Position weight matrix: one T-valued cell per (site, base) pair.
 *
 *	Cells are kept in a single flat vector holding four consecutive values
 *	per site, so cell (s,b) is data[4*s+b] and width() is data.size()/4.
 *	T has to accept the numeric literals used as default arguments below and
 *	support the arithmetic performed by the member functions.
 */
template <typename T> 
class PWM {

  protected:

	vector<T> data;	// flat cell storage, four cells per site

	#if PWM_DEBUG

	// Checked accessors (PWM_DEBUG builds only): an out-of-range flat index
	// is reported on cerr and then trips an assert.
	T& entry(unsigned int s, unsigned int b) { 
		unsigned int index = 4*s+b;
		if (index >= data.size()) {
			cerr << "PWM Error: index " << index << " out-of-bounds (width=" 
				 << width() << ", site=" << s << ", base=" << b << ")" << endl; 
			assert(index < data.size());
		}
		return data[index];
	}

	const T& entry(unsigned int s, unsigned int b) const { 
		unsigned int index = 4*s+b;
		if (index >= data.size()) {
			cerr << "PWM Error: index " << index << " out-of-bounds (width=" 
				 << width() << ", site=" << s << ", base=" << b << ")" << endl; 
			assert(index < data.size());
		}
		return data[index];
	}

	#else 

	// Unchecked accessors (default build).
	inline T& entry(unsigned int s, unsigned int b) { return data[4*s+b]; }
	inline const T& entry(unsigned int s, unsigned int b) const { return data[4*s+b]; }

	#endif
  
  public:

	// ---- construction and sizing ------------------------------------

	PWM() { }	// empty matrix (width 0)
	PWM(unsigned int width) { this->resize(width); }	// matrix with `width' sites
	inline void resize(unsigned int newsize);	// change the number of sites
	inline void clear() { data.clear(); }	// drop every site
	inline unsigned int width() const { return data.size() / 4; }	// number of sites

	// ---- cell access ------------------------------------------------

	// shorthand for entry(site, base):
	inline T& operator()(unsigned int s, unsigned int b) { return entry(s,b); }
	inline const T& operator()(unsigned int s, unsigned int b) const { return entry(s,b); }

	// ---- accumulation -----------------------------------------------

	// add `weight' for every character of sequence[start,end)
	// (`end' is clamped to the sequence length; revcomp = count the other strand)
	void add(const string &sequence, const T &weight=1.0, bool revcomp=false, unsigned int start=0, unsigned int end=-1); 

	// add `weight' times each cell of another PWM to the matching cell here
	void add(const PWM&, const T &weight=1.0);

	void pseudocount(const T &pc= (T)1);	// add `pc' to every cell; count always linear scale
	void normalize();	// divide each site's cells by their sum (when that sum is positive)
	void fill(const T& value= (T)0);	// set everything to zero (or whatever you want)

	// ---- scoring and sampling ---------------------------------------

	/**
	 *	Calculate P([sub]sequence|PWM)
	 *
	 *	revcomp = calculate the likelihood of the sequence on the OTHER STRAND, e.g.
	 *	if the PWM looks like AACGGG, then CCCGTT will score high when revcomp=true.
	 */
	T likelihood(const string &sequence, bool revcomp=false, unsigned int start=0, unsigned int end=-1) const;

	string consensus() const; // translate to a consensus
	char sample(unsigned int site) const;	// sample from this site of the PWM
	string sample() const;	// sample from this PWM
	string mle() const;		// print the most-likely DNA instantiation of this PWM

	double entropy(unsigned int site) const;	// calc. entropy for given site
	double entropy() const;	// calc. entropy of PWM (sum of site-wise entropy)

	// ---- derived matrices -------------------------------------------

	PWM<T> reverse() const;	// return a reversed version of this PWM
	PWM<T> complement() const; // return a DNA-complemented version of this PWM

	// ---- output -----------------------------------------------------

	void print(ostream&, const string indent="") const;

	// member form of stream output: the PWM is the LEFT operand (pwm << out)
	ostream& operator<<(ostream &out) const {
		this->print(out);
		return out;
	}

}; // class PWM


/** 
 *	Print a PWM (assuming `operator<<(ostream&, const T&)' is defined)
 *	as a grid with one column per site and one row per base:
 *
 *	       1    2    3  ...
 *	  A | 0.33 0.67 0.01
 *	  C | ...
 *
 *	Every value cell is exactly WIDTH (4) characters: the value is written
 *	in fixed notation and then cut or space-padded to WIDTH, so 1/3 shows as
 *	"0.33".  A cell that compares equal to 0 is shown as " -- " so that it
 *	cannot be confused with a small value that merely rounds to "0.00".
 *
 *	@param out     stream handed to print_grid
 *	@param indent  text placed before the label column of every row
 */
template <typename T>
void PWM<T>::print(ostream &out, const string indent) const { 

	const unsigned int WIDTH = 4;	// width of PWM printed column

	// grid[row][column]: row 0 holds site numbers, rows 1.. hold one base each;
	// column 0 holds the row labels
	vector<vector<string> > grid(DNA_ABLEN + 1);	// transposed
	for (unsigned int b=0; b<grid.size(); b++) { grid[b].resize(width()+1); }

	for (unsigned int site=0; site<width(); site++) {
		for (unsigned int base=0; base<DNA_ABLEN; base++) { 
			ostringstream cell;
			if (entry(site,base)==0) { 
				// special case to distinguish zero from "0.00"
				cell << " ";
				for (unsigned int space=2; space<WIDTH; space++) { cell << "-"; }
			} else {
				cell << setprecision(WIDTH) << fixed << entry(site,base);
			}
			for (unsigned int space=0; space<WIDTH; space++) { cell << " "; } // pad in case < WIDTH char's so far
			grid[base+1][site+1] = cell.str().substr(0,WIDTH);
		}

		// Column header: the 1-based site number, preceded by WIDTH/2 spaces
		// minus one space per digit beyond the first.  The digits are counted
		// with integer arithmetic and the padding is clamped at zero; the
		// plain unsigned subtraction would wrap around to a huge count once
		// the site number reaches four digits.
		unsigned int extra_digits = 0;
		for (unsigned int number=site+1; number>=10; number/=10) { extra_digits++; }
		const unsigned int lead = (extra_digits < WIDTH/2) ? (WIDTH/2 - extra_digits) : 0;

		ostringstream oss;  
		for (unsigned int space=0; space<lead; space++) { oss << " "; }	// spaces before site number
		oss << (site+1);
		grid[0][site+1] = oss.str();
	}

	// row labels
	for (unsigned int base=0; base<DNA_ABLEN; base++) {
		grid[base+1][0] = indent + DNA_ALPHABET.substr(base,1) + " | ";
	}
	grid[0][0] = indent + "  | ";

	// NOTE(review): the last argument selects the justification used by
	// print_grid (declared outside this file).  An earlier comment here read
	// "1 <=> right justify" although -1 is what is passed -- confirm against
	// print_grid.  All value cells are fixed-width either way.
	print_grid(grid, out, " ", "", -1);
}
/**
 *	Stream insertion for a PWM: forwards to PWM::print with its default
 *	indent and hands the stream back so that insertions can be chained.
 *	(Requires `operator<<(ostream&, const T&)' to be defined.)
 */
template <typename T>
ostream& operator<<(ostream &out, const PWM<T> &pwm) {
	pwm.print(out);
	return out;
}


/**
 *	Change the number of sites to `sites'.
 *
 *	Cells of sites that survive keep their values, cells of newly created
 *	sites are set to (T)0, and shrinking simply drops the trailing sites.
 */
template <typename T>
void PWM<T>::resize(const unsigned int sites) { 

	if (sites == width()) { return; }

	// The two-argument vector::resize copies the given value into every slot
	// it creates and only truncates when the new size is smaller.  Zeroing
	// with std::fill from the OLD end of the vector is not safe here: after a
	// shrink that start position lies beyond the new end, which is not a
	// valid range.
	data.resize(4*sites, (T)0);
}
	

/**
 *	Scale every site so that its four cells sum to one.
 *	Sites whose cells do not add up to a positive value are left untouched.
 */
template <typename T>
void PWM<T>::normalize() {
	const unsigned int sites = width();
	for (unsigned int s=0; s<sites; ++s) { 
		const T total = entry(s,0) + entry(s,1) + entry(s,2) + entry(s,3);
		if (total > ((T)0)) {
			for (unsigned int b=0; b<4; ++b) { entry(s,b) /= total; }
		}
	}
}

/**
 *	Add the pseudocount `pc' to every cell, site by site and base by base.
 */
template <typename T>
void PWM<T>::pseudocount(const T &pc) {
	for (unsigned int s=0; s<width(); ++s) { 
		unsigned int b = 0;
		while (b < DNA_ABLEN) {
			entry(s,b) += pc;
			++b;
		}
	}
}

/**
 *	Overwrite every cell with `value'; the number of sites does not change.
 */
template <typename T>
void PWM<T>::fill(const T &value) {
	for (typename vector<T>::iterator cell=data.begin(); cell!=data.end(); ++cell) {
		*cell = value;
	}
}



/**
 *	Accumulate the characters dna[start,end) into this PWM.
 *
 *	The character at offset k from `start' is counted at site k, or, when
 *	`revcomp' is set, its complement is counted at site width()-1-k (the
 *	sequence is read as lying on the other strand).
 *
 *	  - a,c,g,t:  `weight' is added to the matching cell
 *	  - n/N:      treated as missing data; nothing is added
 *	  - other IUPAC codes: `weight' is split evenly over the bases the
 *	                code stands for
 *	  - anything else: reported on cerr and skipped
 *
 *	@param dna      sequence to read
 *	@param weight   amount contributed by each character
 *	@param revcomp  count the reverse complement instead
 *	@param start    first character used
 *	@param end      one past the last character used; clamped to the
 *	                sequence length and to at most width() characters, so
 *	                cells outside the matrix are never touched
 */
template <typename T>
void PWM<T>::add(const string &dna, const T &weight, bool revcomp, unsigned int start, unsigned int end) { 

	if (end > dna.length()) { end = dna.length(); }
	if (start >= end) { return; }	// nothing to read

	// A range longer than the matrix would index past the last site (and
	// wrap the unsigned site number below zero when revcomp is set).
	if (end - start > width()) { end = start + width(); }

	for (unsigned int i=start; i<end; i++) { 

		const unsigned int site = (revcomp) 
					? width() - 1 - (i-start)
					: i-start;

		const unsigned int DNA_ID = dna_id(dna[i]);

		if (DNA_ID < DNA_ABLEN) { 
			// normal case:  a,c,g,t

			entry( site, revcomp ? complement_dna_id(dna[i]) : DNA_ID ) += weight; 

		} else if (dna[i] == 'n' || dna[i] == 'N') { 
			
			// next most common case:  unsequenced section
			//	I don't do anything here, with no evidence.
			//	One could interpret an 'n' as 25% all bases, but I 
			//	treat it as missing data.

		} else {
			
			// assume non-DNA IUPAC-alphabet 

			const unsigned int dna_map = iupac_id(dna[i]);
			if (dna_map == IUPAC_ABLEN) { 
				cerr << "Cannot add '" << dna[i] << "' to PWM." << endl;
			} else {
				
				// number of possible DNA nt's for this IUPAC character
				//	i.e. 'y' = 2, 'b' = 3
				unsigned int letters = 0;
				unsigned int base_map = 1;
				for (unsigned int base=0; base<DNA_ABLEN; base++) { 
					if (dna_map & base_map) { letters++; }
					base_map <<= 1;
				}

				// Spread the weight over those bases.  On the other strand
				// each base is replaced by its complement; inverting the bit
				// map instead would select every base the code does NOT
				// stand for, which is a different set (e.g. 'w' -> 's').
				base_map = 1; 
				for (unsigned int base=0; base<DNA_ABLEN; base++) { 
					if (dna_map & base_map) { 
						entry( site, revcomp ? complement_id(base) : base ) += (weight/letters); 
					}
					base_map <<= 1;
				}

			}

		}
				
	} // next i
}

/**
 *	Add `weight' times every cell of `pwm' to the matching cell of this PWM.
 *
 *	Only the sites that both matrices have are combined: when `pwm' is wider
 *	than this PWM its extra sites are ignored rather than written past the
 *	end of this matrix, and when it is narrower the remaining sites here are
 *	left unchanged.
 */
template <typename T>
void PWM<T>::add(const PWM<T> &pwm, const T &weight) { 

	const unsigned int shared_sites = std::min(width(), pwm.width());

	for (unsigned int site=0; site<shared_sites; site++) { 
		for (unsigned int base=0; base<DNA_ABLEN; base++) { 
			entry(site,base) += weight * pwm.entry(site,base);
		}
	}
}


/**
 *	Draw one base letter for `site'.
 *
 *	The bases are visited in alphabet order.  For each one a fresh rand()
 *	value scaled to [0,1] is multiplied by the mass not yet passed over and
 *	compared with that base's cell; the first base whose cell is larger is
 *	returned.  If none is chosen the answer is 'T'.
 */
template <typename T>
char PWM<T>::sample(unsigned int site) const { 
	
	T remaining = 1;	// mass left after the bases already passed over
	for (unsigned int base=0; base < DNA_ABLEN; base++) { 
		const T draw = (((T)rand()) / RAND_MAX);
		const T cell = entry(site,base);
		if ( draw*remaining < cell ) { 
			return DNA_ALPHABET[base];
		}
		remaining -= cell;
	}
	return 'T';
}

/**
 *	Draw a whole sequence: one sampled letter per site, first site first.
 */
template <typename T>
string PWM<T>::sample() const { 

	const unsigned int sites = width();

	string drawn;
	drawn.reserve(sites);
	for (unsigned int s=0; s<sites; ++s) { 
		drawn.push_back(this->sample(s));
	}
	return drawn;

}





/**
 *	Calculate P(sequence[start,end) | PWM): the product, over the characters
 *	read, of the probability this PWM gives each character at its site.
 *
 *	The character at offset k from `start' is scored at site k, or, when
 *	`revcomp' is set, its complement is scored at site width()-1-k.
 *
 *	  - a,c,g,t:  factor is the matching cell
 *	  - n/N:      factor 1 (any base is acceptable)
 *	  - other IUPAC codes: factor is the SUM of the cells of the bases the
 *	                code stands for (the probability that the site holds any
 *	                one of them), which agrees with n/N counting as 1; a code
 *	                that stands for no base at all contributes no factor
 *	  - anything else: reported on cerr and ignored
 *
 *	@param sequence  sequence to score
 *	@param revcomp   score the sequence as lying on the other strand
 *	@param start     first character used
 *	@param end       one past the last character used; clamped to the
 *	                 sequence length and to at most width() characters, so
 *	                 cells outside the matrix are never read
 *	@return the product described above (1 when no character is read)
 */
template <typename T>
T PWM<T>::likelihood(const string &sequence, bool revcomp, unsigned int start, unsigned int end) const { 
	
	if (end > sequence.length()) { end = sequence.length(); }
	
	T ans = 1;

	if (start >= end) { return ans; }	// nothing to score

	// A range longer than the matrix would index past the last site (and
	// wrap the unsigned site number below zero when revcomp is set).
	if (end - start > width()) { end = start + width(); }
	
	for (unsigned int i=start; i<end; i++) { 

		const unsigned int site = (revcomp) 
					?	width() - 1 - (i-start)
					:	i - start;

		const unsigned int DNA_ID = dna_id(sequence[i]);

		if (DNA_ID < DNA_ABLEN) {

			// normal case, a,c,g, or t

			ans *= entry(site, revcomp ? complement_dna_id(sequence[i]) : DNA_ID );

		} else if (sequence[i] == 'n' || sequence[i] == 'N') { 
		
			// another common case (best to check and skip)
			// Do nothing:  100% probability of *anything*
			
		} else {
			
			// non-ACGT
			const unsigned int dna_map = iupac_id(sequence[i]);
			
			if (dna_map == IUPAC_ABLEN) { 
				cerr << " PWM::likelihood:  Ignorning illegal character '" << sequence[i] << "'." << endl;
			} else {
			
				// Total mass of the bases this code allows.  On the other
				// strand each base is replaced by its complement; inverting
				// the bit map instead would select the bases the code does
				// NOT stand for.
				T mass = 0;
				bool any_base = false;
				unsigned int base_map = 1;
				for (unsigned int base=0; base<DNA_ABLEN; base++) { 

					if (dna_map & base_map) { 
						mass += entry( site, revcomp ? complement_id(base) : base );
						any_base = true;
					}

					base_map <<= 1;

				}
				if (any_base) { ans *= mass; }
			}

		}

	} // next sequence character

	return ans;

} // likelihood_iupac

/*
                  T  G  C  A
----------------+-----------
IUPAC[ 0] = 'X' | 0  0  0  0    ('X' also means 'N', but I'll use it for 0000 and 'N' for 1111)
IUPAC[ 1] = 'A' | 0  0  0  1
IUPAC[ 2] = 'C' | 0  0  1  0
IUPAC[ 3] = 'M' | 0  0  1  1
IUPAC[ 4] = 'G' | 0  1  0  0
IUPAC[ 5] = 'R' | 0  1  0  1
IUPAC[ 6] = 'S' | 0  1  1  0
IUPAC[ 7] = 'V' | 0  1  1  1
IUPAC[ 8] = 'T' | 1  0  0  0
IUPAC[ 9] = 'W' | 1  0  0  1
IUPAC[10] = 'Y' | 1  0  1  0
IUPAC[11] = 'H' | 1  0  1  1
IUPAC[12] = 'K' | 1  1  0  0
IUPAC[13] = 'D' | 1  1  0  1
IUPAC[14] = 'B' | 1  1  1  0
IUPAC[15] = 'N' | 1  1  1  1
*/

/**
 * Translate the PWM to a consensus sequence, one character per site.
 *
 * At each site the largest cell (or 0, if no cell is larger than 0) is
 * found, and every base whose cell reaches at least half of it is
 * selected.  The selected bases, one bit per base, index the IUPAC
 * alphabet: the upper-case table when at most one base was selected, the
 * lower-case table otherwise.
 */
template <typename T>
string PWM<T>::consensus() const { 

	// Anybody's in who has 50% of the maximum
	const double RATIO = 0.5;

	string ans = "";
	for (unsigned int site=0; site<width(); site++) { 

		// largest cell of this site
		T peak = 0;
		for (unsigned int base=0; base<DNA_ABLEN; base++) { 
			if (entry(site,base) > peak) { peak = entry(site,base); }
		}
		const T cutoff = RATIO * peak;

		// collect the bases that make the cut
		unsigned int mask = 0;
		unsigned int hits = 0;	
		for (unsigned int base=0; base<DNA_ABLEN; base++) { 
			if (entry(site,base) >= cutoff) { 
				mask |= (1 << base);
				hits++;
			}
		}

		// add appropriate letter (capital for ACGT, one match, lower-case otherwise)
		if (hits > 1) {
			ans += iupac_alphabet[mask];
		} else {
			ans += IUPAC_ALPHABET[mask];
		}
	}
	return ans;
}
			

/**
 * Translate the PWM to its most-likely sequence, one character per site.
 *
 * At each site the heaviest value is the largest cell (or 0, if no cell is
 * larger than 0); every base whose cell equals it is selected.  The selected
 * bases, one bit per base, index the IUPAC alphabet: the upper-case table
 * when at most one base was selected, the lower-case table for ties.
 */
template <typename T>
string PWM<T>::mle() const { 

	string ans = "";
	for (unsigned int site=0; site<width(); site++) { 

		// first pass: heaviest value at this site
		T heavy = 0;
		for (unsigned int base=0; base<DNA_ABLEN; base++) { 
			if (entry(site,base) > heavy) { heavy = entry(site,base); }
		}

		// second pass: every base that ties with it
		unsigned int mask = 0;
		unsigned int ties = 0;	
		for (unsigned int base=0; base<DNA_ABLEN; base++) { 
			if (entry(site,base) == heavy) { 
				mask |= (1 << base);
				ties++;
			}
		}

		// add appropriate letter (capital for ACGT, one match, lower-case otherwise)
		if (ties > 1) {
			ans += iupac_alphabet[mask];
		} else {
			ans += IUPAC_ALPHABET[mask];
		}
	}
	return ans;
}

/**
 *	Return a copy of this PWM with the order of its sites reversed
 *	(the bases within each site stay where they are).
 */
template <typename T>
PWM<T> PWM<T>::reverse() const { 

	const unsigned int sites = this->width();
	PWM<T> flipped(sites);

	unsigned int target = sites;	// one past the site written next
	for (unsigned int source=0; source<sites; ++source) { 
		--target;
		for (unsigned int base=0; base<DNA_ABLEN; ++base) { 
			flipped(target, base) = this->entry(source, base);
		}
	}
	return flipped;

}


/**
 *	Return a copy of this PWM in which, at every site, each base's value has
 *	been moved to the cell selected by complement_id(base).  The order of the
 *	sites is unchanged.
 */
template <typename T>
PWM<T> PWM<T>::complement() const {

	const unsigned int sites = this->width();
	PWM<T> swapped(sites);

	for (unsigned int s=0; s<sites; ++s) { 
		unsigned int base = 0;
		while (base < DNA_ABLEN) { 
			swapped( s, complement_id(base) ) = this->entry(s, base);
			++base;
		}
	}
	return swapped;

}

/**
 *	Entropy of one site in bits: the sum of -p*log2(p) over the bases,
 *	where cells that are not positive contribute nothing.
 */
template <typename T>
double PWM<T>::entropy(unsigned int site) const { 

	double bits = 0;
	for (unsigned int base=0; base<DNA_ABLEN; ++base) { 
		const double p = ((double)(this->entry(site,base)));	// convert to double
		if (!(p > 0.0)) { continue; }	// skip cells that are not positive
		bits -= p * log2(p);
	}
	return bits;
}


/**
 *	Entropy of the whole PWM: the site entropies added up, first site first.
 */
template <typename T>
double PWM<T>::entropy() const { 

	const unsigned int sites = this->width();

	double total = 0;
	unsigned int s = 0;
	while (s < sites) { 
		total += this->entropy(s);
		++s;
	}
	return total;
}



#endif	// PWM_H
