/*
	INPUT:  "PSSM" file

		CONSENSUS	E-VALUE	[ [site-1-A-probability site-1-C-probability ..G.. ..T..] [site-2-A-prob... ] ... ]

	OUTPUT:	"MIF" file
		
		Sequence-ID	Motif-ID Position Strand log-probability-ratio STRAND(English) sequence-match sequence-match-template-strand


*/

#include <stdlib.h>
#include <math.h>
	
#include <vector>
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <iterator>
#include <list>
using namespace std;

#include "Option.h"
#include "fasta.h"
#include "markov.h"
#include "dna.h"

// Column separator used for the tab-delimited records written to standard output.
#ifndef TAB
#define TAB '\t'
#endif

// Fallback used only when no INFINITY macro is already defined (e.g. by <math.h>).
// NOTE(review): relies on log(0.0) evaluating to -infinity (IEEE-754 behaviour),
// so that the negation is +infinity -- confirm on exotic platforms.
#ifndef INFINITY
const double INFINITY = -log(0.0);
#endif


/** Lower-case `s` in place.
 *	Only the characters 'A'..'Z' are shifted; everything else is left as is.
 */
void tolower(string &s) {
	const int shift = 'a' - 'A';	// distance from an upper-case letter to its lower-case twin

	for (string::iterator ch = s.begin(); ch != s.end(); ++ch) {
		const bool is_upper = (*ch >= 'A') && (*ch <= 'Z');
		if (!is_upper) { continue; }
		*ch += shift;
	}
}



/** One motif record as produced by read_pssm(). */
struct PSSM {
	string consensus;	// first token of the record; its length is the number of sites read
	double e_value;		// second token of the record, converted with atof()
	vector<vector<double> > pssm; // pssm[position][DNA]; read_pssm() stores log(value / row sum)
};

/** Read one motif record from `in`.
 *
 *	The record is a consensus string, an E-value token, and then ABLEN numbers
 *	for each character of the consensus.  Each group of ABLEN numbers is
 *	divided by its own sum and stored as a natural logarithm.
 *
 *	No stream-state checking is done here; the caller inspects `in` afterwards.
 *
 *	return:  the parsed record (pssm has consensus.length() rows of ABLEN entries)
 */
PSSM read_pssm(istream &in) {

	PSSM motif;
	string e_text;	// the E-value is taken as text first because it may be non-numeric ("inf")

	in >> motif.consensus;
	in >> e_text;
	motif.e_value = atof(e_text.c_str());

	const string::size_type sites = motif.consensus.length();
	motif.pssm.resize(sites);

	for (string::size_type site = 0; site < sites; ++site) {
		vector<double> &row = motif.pssm[site];
		row.resize(ABLEN);

		// read the raw values for this site, accumulating their total
		double total = 0.0;
		for (int b = 0; b < ABLEN; ++b) {
			in >> row[b];
			total += row[b];
		}

		// normalize and take the log while the row is at hand
		for (int b = 0; b < ABLEN; ++b) {
			row[b] = log(row[b] / total);
		}

		in >> ws;	// swallow whitespace so end-of-file is noticed right after the last record
	}

	return motif;
}





/** Log background probability of EACH base of `sequence` under a Markov model.
 *
 *	mm[k] is indexed with markov_offset() applied to the k+1 characters ending
 *	at the current base.  The highest order available is mm.size()-1; the first
 *	few bases, which have fewer predecessors, fall back to a lower order.
 *
 *	return:  vector of sequence.length() natural-log probabilities
 */
vector<double> calc_base_probability(const string &sequence, const vector<vector<long double> > &mm) { 

	const int length = sequence.length();
	const int max_order = mm.size() - 1;	// highest order the model supports
	const char *bases = sequence.c_str();

	vector<double> log_prob(length);

	for (int pos = 0; pos < length; ++pos) {

		// near the left edge there are only `pos` preceding bases to condition on
		int order = max_order;
		if (pos < max_order) { order = pos; }

		const char *context = bases + pos - order;	// first character of the (order+1)-mer
		const int offset = markov_offset(context, order + 1);

		log_prob[pos] = log( mm[order][offset] );
	}

	return log_prob;
}



/** 
 *	Scan for motifs in sequence.
 *	sequence:	sequence to scan.  Each character is compared with the first
 *				four entries of `alphabet`, then of `ALPHABET`; any other
 *				character disqualifies every window that contains it.
 *	pssm:		motif log-probabilities, pssm[site][base]
 *	bg_prob:	log background probability of each base in the sequence;
 *				must hold at least sequence.length() entries
 *	log_ratio:	log probability ratio that must be exceeded for a motif match
 *
 *	return:  list of (position, log ratio) pairs in increasing position order.
 *	         Empty when the motif has no sites or is longer than the sequence.
 */
list<pair<int, double> > motif_scan(const string &sequence, 
                     const vector<vector<double> > &pssm,
					 const vector<double> &bg_prob,
					 double log_ratio) {	

	list<pair<int, double> > ans;

	const int L = sequence.length();
	const int W = pssm.size();

	// A motif without sites cannot match anything (and pssm[0] would not exist).
	if (W == 0) { return ans; }

	// log probability of the Background model for the first window
	long double lpB = 0.0;
	for (int i=0; i<W && i<L; i++) { 
		lpB += bg_prob[i];	
	}

	// for each possible sequence site for the START of the motif...
	// NOTE:  L and W are signed ints, so (L-W+1) is <= 0 when the motif
	//	can't fit in the sequence and the loop body never runs.
	for (int i=0; i<(L-W+1); i++) {

		double lpM = 0.0;	// log probability of a Motif
									
		for (int m=0; m<W && lpM > -INFINITY; m++) {	// for each site in the motif

			// TODO:  this is specific to DNA, and it ignores any non-ACGT
			// characters (e.g., 'N' or 'R' for PURINE).  A more general way
			// would be to ADD probabilities of all qualifying characters
			// (e.g., for 'R', add P('a') + P('g')).

			const char base = sequence[i+m];

			// look the base up in `alphabet` first, then in `ALPHABET`
			int column = -1;
			for (int c=0; c<4 && column<0; c++) { if (base == alphabet[c]) { column = c; } }
			for (int c=0; c<4 && column<0; c++) { if (base == ALPHABET[c]) { column = c; } }

			if (column < 0) {
				// illegal base found
				lpM = -INFINITY;
				break;
			}

			lpM += pssm[m][column];
		}

		// prob's calc'ed.  Is this a hit?
		const long double ratio = lpM - lpB;
		if ( ratio  >  log_ratio ) { ans.push_back(pair<int,double>(i, ratio)); }
	
		// update background probability:  slide the window one base to the right
		// NOTE(review): an infinite entry in bg_prob would turn this running sum
		// into NaN once it leaves the window -- confirm the background model
		// never assigns probability zero.
		if (i+W < L) { 
			lpB -= (bg_prob[i]);
			lpB += (bg_prob[i+W]); 
		}
		
	} // next sequence position

	return ans;
}

	


int main(int argc, char **argv) {

	bool verbose;
	double score;			// probability ratio
	string mfile;			// PSSM motif file
	string bfile, rfile;	// background Markov model files

	ostringstream oss;
	oss << "Find motif instances.  INPUT:  a \"pssm\" file with records as:  "
		<< "'CONSENSUS	E-VALUE	[[site-1-A-probability site-1-C-probability..G.. ..T..] [site-2-A-prob... ] ... ]'  "
		<< "OUTPUT:  a motif instance \"MIF\" file as:  "
		<< "'Sequence-ID(0-indexed) Motif-ID(0-indexed) Position "
		<< "Strand(1=Template,2=Transcribed) log-probability-ratio "
		<< "STRAND(English) sequence-match sequence-match-template-strand'.  "
		<< "Position is negative, measured from RIGHT side of input sequence.  ";
	
	OptionParser parser(oss.str());

	parser.add("motif", 'm', &mfile, "PSSM motif file");
	parser.add("score", 's', &score, 1.0, "Score:  Odds ratio");
	parser.add("bfile", 'b', &bfile, "", "Markov model background distribution file (otherwise uniform)");
	parser.add("rfile", 'r', &rfile, "", "Markov model background distribution file for reverse-complement (otherwise same as forward)");
	
	parser.add("silent", 'x', &verbose, true, "Silence progress/statistical output (to the standard error)");

	vector<string> args = parser.parse(argc, argv, "< FASTA input");

	const double log_ratio = log(score);


	if (verbose) { cerr << endl; }	// just a blank line

	// read Sequence file
	list<pair<string,string> > S;
	int Ssize;
	if (args.size()) { 
		ifstream fin;
		fin.open(args[0].c_str());
		if (!fin) { cerr << "Can't open " << args[0] << endl;  exit(-1); }
		while (fin.good()) { 
			S.push_back(pair<string,string>("",""));
			fasta_read(fin, S.back().first, S.back().second);
		}
		Ssize = S.size();
		if (verbose) { cerr << "Read " << Ssize << " sequences (" << args[0] << ")." << endl; }
		fin.close();
	} else {
		while (cin.good()) { 
			S.push_back(pair<string,string>("",""));
			fasta_read(cin, S.back().first, S.back().second);
		}
		Ssize = S.size();
		if (verbose) { cerr << "Read " << Ssize << " sequences (stdin)." << endl; }
	}

	// convert ALL sequences to LOWER CASE 
	for (list<pair<string,string> >::iterator s=S.begin(); s!=S.end(); s++) {
		tolower(s->second);
	}

	// read PSSM motif file 
	list<PSSM> pssms;
	int num_pssms = 0;
	ifstream mfin(mfile.c_str());
	if (!mfin) { cerr << "Can't open " << mfile << endl; exit(-1); }
	while (mfin.good()) { 
		pssms.push_back(read_pssm(mfin)); 
		num_pssms++;
	}
	if (verbose) { cerr << "Read " << num_pssms << " PSSMs (" << mfile << ")." << endl; }
	mfin.close();

	// read background file 
	vector<vector<long double> > fbgmm;
	if (bfile == "") { 
		fbgmm.resize(1);
		fbgmm[0].resize(ABLEN);
		std::fill(fbgmm[0].begin(), fbgmm[0].end(), 1.0/ABLEN);
		if (verbose) { cerr << "Using uniform background model." << endl; }
	} else {
		ifstream bfin(bfile.c_str());
		if (!bfin) { cerr << "Can't open " << bfile << endl;  exit(-1); }
		fbgmm = markov_read(bfin);
		bfin.close();
		if (verbose) { cerr << "Read " << (fbgmm.size()-1) << "-th order markov model (" << bfile << ")." << endl; }
	}

	// read revcomp background file
	vector<vector<long double> > rbgmm;
	if (rfile == "") { 
		rbgmm = fbgmm;
		if (verbose) { cerr << "Using forward background model as reverse complement." << endl; }
	} else {
		ifstream rfin(rfile.c_str());
		if (!rfin) { cerr << "Can't open " << rfile << endl;  exit(-1); }
		rbgmm = markov_read(rfin);
		rfin.close();
		if (verbose) { cerr << "Read " << (rbgmm.size()-1) << "-th order markov model (" << rfile << ")." << endl; }
	}

	// normalize background distribution over the LAST BASE
	for (int order=0; order<fbgmm.size(); order++) { 
		for (int entry=0; entry<fbgmm[order].size(); entry += ABLEN) { 
			long double sum = 0.0;
			for (int c=0; c<ABLEN; c++) { sum += fbgmm[order][entry+c]; }
			for (int c=0; c<ABLEN; c++) { fbgmm[order][entry+c] /= sum; }
		}
	}
	for (int order=0; order<rbgmm.size(); order++) { 
		for (int entry=0; entry<rbgmm[order].size(); entry += ABLEN) { 
			long double sum = 0.0;
			for (int c=0; c<ABLEN; c++) { sum += rbgmm[order][entry+c]; }
			for (int c=0; c<ABLEN; c++) { rbgmm[order][entry+c] /= sum; }
		}
	}

	int instances = 0;	// count of all found instances.

	// search each sequence
	int sid = 0;
	for (list<pair<string,string> >::const_iterator s=S.begin(); s!=S.end(); s++) {

		int L = s->second.length();	// sequence length
		
		string rc = revcomp(s->second);

		if (verbose) { cerr << "Processed " << sid << " sequences (" 
		                    << ((int)(100.0 * sid / Ssize)) << "%)...    \r"; }

		// calculate probability of each base
		vector<double> base_prob = calc_base_probability(s->second, fbgmm);
		vector<double> rc_base_prob = calc_base_probability(rc, rbgmm);
	
		// for each motif
		int mid = 0;
		for (list<PSSM>::const_iterator m=pssms.begin(); m!=pssms.end(); m++) {

			const int W = m->pssm.size();	// motif width; number of sites

			// scan forward sequence
			list<pair<int, double> > f_pos = motif_scan(s->second, m->pssm, base_prob, log_ratio);

			// write out answers
			for (list<pair<int, double> >::const_iterator i=f_pos.begin(); i!=f_pos.end(); i++) { 
				// write out data for sequence i:
				int p = i->first;
				double score = i->second;
				cout << sid << TAB << mid << TAB 
				     << (p - L) << TAB 
					 << "1" << TAB << score << TAB << "(template)" << TAB
				     << s->second.substr(p, W) << TAB 	// hit
					 << s->second.substr(p, W) << endl; 	// sequence

				instances++;
			}

			// scan reverse sequence
			list<pair<int, double> > r_pos = motif_scan(rc, m->pssm, rc_base_prob, log_ratio);

			// write these out, too
			for (list<pair<int, double> >::const_iterator i=r_pos.begin(); i!=r_pos.end(); i++) { 
				// write out data for sequence i:
				// need to "reverse" the position, first
				int p = i->first;
				double score = i->second;
				cout << sid << TAB << mid << TAB
				     << (-(p+W)) << TAB 
					 << "2" << TAB << score << TAB << "(transcribed)" << TAB
					 << rc.substr(p, W) << TAB 	// hit
					 << s->second.substr(L-W-p, W) << endl;	// sequence

				instances++;
			}

			mid++;
		} // next motif

		sid++;
	
	} // next sequence
	
	if (verbose) {
		cerr << argv[0] << ":  processed " << sid << " sequences and "
	         << "found " << instances << " motif instances." << endl << endl;
	}

	return 0;

}

