discreteHMM.cc

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00038 #include "fastlib/fastlib.h"
00039 #include "support.h"
00040 #include "discreteHMM.h"
00041 
00042 using namespace hmm_support;
00043 
00044 void DiscreteHMM::setModel(const Matrix& transmission, const Matrix& emission) {
00045   DEBUG_ASSERT(transmission.n_rows() == transmission.n_cols());
00046   DEBUG_ASSERT(transmission.n_rows() == emission.n_rows());
00047   transmission_.Destruct();
00048   emission_.Destruct();
00049   transmission_.Copy(transmission);
00050   emission_.Copy(emission);
00051 }
00052 
00053 void DiscreteHMM::Init(const Matrix& transmission, const Matrix& emission) {
00054   transmission_.Copy(transmission);
00055   emission_.Copy(emission);
00056   DEBUG_ASSERT(transmission.n_rows() == transmission.n_cols());
00057   DEBUG_ASSERT(transmission.n_rows() == emission.n_rows());
00058 }
00059 
00060 void DiscreteHMM::InitFromFile(const char* profile) {
00061   ArrayList<Matrix> list_mat;
00062   load_matrix_list(profile, &list_mat);
00063   if (list_mat.size() < 2)
00064     FATAL("Number of matrices in the file should be at least 2.");
00065   else if (list_mat.size() > 2)
00066     NONFATAL("Number of matrices in the file should be 2.");
00067   transmission_.Copy(list_mat[0]);
00068   emission_.Copy(list_mat[1]);
00069   DEBUG_ASSERT(transmission_.n_rows() == transmission_.n_cols());
00070   DEBUG_ASSERT(transmission_.n_rows() == emission_.n_rows());
00071 }
00072 
00073 void DiscreteHMM::InitFromData(const ArrayList<Vector>& list_data_seq, int numstate) {
00074   int numsymbol = 0;
00075   int maxseq = 0;
00076   for (int i = 0; i < list_data_seq.size(); i++) 
00077     if (list_data_seq[i].length() > list_data_seq[maxseq].length()) maxseq = i;
00078   for (int i = 0; i < list_data_seq[maxseq].length(); i++)
00079     if (list_data_seq[maxseq][i] > numsymbol) numsymbol = (int) list_data_seq[maxseq][i];
00080   numsymbol++;
00081   Vector states;
00082   int L = list_data_seq[maxseq].length();
00083   states.Init(L);
00084   for (int i = 0; i < L; i++) states[i] = rand() % numstate;
00085   DiscreteHMM::EstimateInit(numsymbol, numstate, list_data_seq[maxseq], states, &transmission_, &emission_);  
00086 }
00087 
00088 void DiscreteHMM::LoadProfile(const char* profile) {
00089   transmission_.Destruct();
00090   emission_.Destruct();
00091   InitFromFile(profile);
00092 }
00093 
00094 void DiscreteHMM::SaveProfile(const char* profile) const {
00095   TextWriter w_pro;
00096   if (!PASSED(w_pro.Open(profile))) {
00097     NONFATAL("Couldn't open '%s' for writing.", profile);
00098     return;
00099   }
00100 
00101   print_matrix(w_pro, transmission_, "%% transmision", "%f,");
00102   print_matrix(w_pro, emission_, "%% emission", "%f,");
00103 }
00104 
00105 void DiscreteHMM::GenerateSequence(int length, Vector* data_seq, Vector* state_seq) const {
00106   DiscreteHMM::GenerateInit(length, transmission_, emission_, data_seq, state_seq);
00107 }
00108 
00109 void DiscreteHMM::EstimateModel(const Vector& data_seq, const Vector& state_seq) {
00110   transmission_.Destruct();
00111   emission_.Destruct();
00112   DiscreteHMM::EstimateInit(data_seq, state_seq, &transmission_, &emission_);
00113 }
00114 
00115 void DiscreteHMM::EstimateModel(int numstate, int numsymbol, const Vector& data_seq, const Vector& state_seq) {
00116   transmission_.Destruct();
00117   emission_.Destruct();
00118   DiscreteHMM::EstimateInit(numsymbol, numstate, data_seq, state_seq, &transmission_, &emission_);
00119 }
00120 
00121 void DiscreteHMM::DecodeOverwrite(const Vector& data_seq, Matrix* state_prob_mat, Matrix* forward_prob_mat, Matrix* backward_prob_mat, Vector* scale_vec) const {
00122   DiscreteHMM::Decode(data_seq, transmission_, emission_, state_prob_mat, forward_prob_mat, backward_prob_mat, scale_vec);
00123 }
00124 
00125 void DiscreteHMM::DecodeInit(const Vector& data_seq, Matrix* state_prob_mat, Matrix* forward_prob_mat, Matrix* backward_prob_mat, Vector* scale_vec) const {
00126   int M = transmission_.n_rows();
00127   int L = data_seq.length();
00128   state_prob_mat->Init(M, L);
00129   forward_prob_mat->Init(M, L);
00130   backward_prob_mat->Init(M, L);
00131   scale_vec->Init(L);
00132   DiscreteHMM::Decode(data_seq, transmission_, emission_, state_prob_mat, forward_prob_mat, backward_prob_mat, scale_vec);
00133 }
00134 
00135 void forward_procedure(const Vector& seq, const Matrix& trans, const Matrix& emis, Vector *scales, Matrix* fs);
00136 
00137 double DiscreteHMM::ComputeLogLikelihood(const Vector& data_seq) const {
00138   int L = data_seq.length();
00139   int M = transmission_.n_rows();
00140   Matrix fs(M, L);
00141   Vector sc;
00142   sc.Init(L);
00143   DiscreteHMM::ForwardProcedure(data_seq, transmission_, emission_, &sc, &fs);
00144   double loglik = 0;
00145   for (int t = 0; t < L; t++)
00146     loglik += log(sc[t]);
00147   return loglik;
00148 }
00149 
00150 void DiscreteHMM::ComputeLogLikelihood(const ArrayList<Vector>& list_data_seq, ArrayList<double>* list_likelihood) const {
00151   int L = 0;
00152   for (int i = 0; i < list_data_seq.size(); i++)
00153     if (list_data_seq[i].length() > L) L = list_data_seq[i].length();
00154   int M = transmission_.n_rows();
00155   Matrix fs(M, L);
00156   Vector sc;
00157   sc.Init(L);
00158   list_likelihood->Init();
00159   for (int i = 0; i < list_data_seq.size(); i++) {
00160     DiscreteHMM::ForwardProcedure(list_data_seq[i], transmission_, emission_, &sc, &fs);
00161     int L = list_data_seq[i].length();
00162     double loglik = 0;
00163     for (int t = 0; t < L; t++)
00164       loglik += log(sc[t]);
00165     list_likelihood->PushBackCopy(loglik);
00166   }
00167 }
00168 
00169 void DiscreteHMM::ComputeViterbiStateSequence(const Vector& data_seq, Vector* state_seq) const {
00170   DiscreteHMM::ViterbiInit(data_seq, transmission_, emission_, state_seq);
00171 }
00172 
00173 void DiscreteHMM::TrainBaumWelch(const ArrayList<Vector>& list_data_seq, int max_iteration, double tolerance) {
00174   DiscreteHMM::Train(list_data_seq, &transmission_, &emission_, max_iteration, tolerance);
00175 }
00176 
00177 void DiscreteHMM::TrainViterbi(const ArrayList<Vector>& list_data_seq, int max_iteration, double tolerance) {
00178   DiscreteHMM::TrainViterbi(list_data_seq, &transmission_, &emission_, max_iteration, tolerance);
00179 }
00180 
00181 void DiscreteHMM::GenerateInit(int L, const Matrix& trans, const Matrix& emis, Vector* seq, Vector* states) {
00182   DEBUG_ASSERT_MSG((trans.n_rows()==trans.n_cols() && trans.n_rows()==emis.n_rows()), "hmm_generateD_init: matrices sizes do not match");
00183   Matrix trsum, esum;
00184   Vector &seq_ = *seq, &states_ = *states;
00185   int M, N;
00186   int cur_state;
00187 
00188   M = trans.n_rows();
00189   N = emis.n_cols();
00190 
00191   trsum.Copy(trans);
00192   esum.Copy(emis);
00193 
00194   for (int i = 0; i < M; i++) {
00195     for (int j = 1; j < M; j++)
00196       trsum.set(i, j, trsum.get(i, j) + trsum.get(i, j-1));
00197     for (int j = 1; j < N; j++) 
00198       esum.set(i, j, esum.get(i, j) + esum.get(i, j-1));
00199   }
00200 
00201   seq_.Init(L);
00202   states_.Init(L);
00203 
00204   cur_state = 0; // starting state is 0
00205   
00206   for (int i = 0; i < L; i++) {
00207     int j;
00208     double r;
00209 
00210     // next state
00211     r = RAND_UNIFORM_01();
00212     for (j = 0; j < M; j++)
00213       if (r <= trsum.get(cur_state, j)) break;
00214     cur_state = j;
00215         
00216     // emission
00217     r = RAND_UNIFORM_01();
00218     for (j = 0; j < N; j++)
00219       if (r <= esum.get(cur_state, j)) break;
00220     seq_[i] = j;
00221     states_[i] = cur_state;
00222   }
00223 }
00224 
00225 void DiscreteHMM::EstimateInit(const Vector& seq, const Vector& states, Matrix* trans, Matrix* emis) {
00226   DEBUG_ASSERT_MSG((seq.length()==states.length()), "hmm_estimateD_init: sequence and states length must be the same");
00227   int M = 0, N=0;
00228   for (int i = 0; i < seq.length(); i++) {
00229     if (seq[i] > N) N = (int) seq[i];
00230     if (states[i] > M) M = (int) states[i];
00231   }
00232   M++;
00233   N++;
00234   DiscreteHMM::EstimateInit(N, M, seq, states, trans, emis);
00235 }
00236 
00237 void DiscreteHMM::EstimateInit(int numSymbols, int numStates, const Vector& seq, const Vector& states, Matrix* trans, Matrix* emis){
00238   DEBUG_ASSERT_MSG((seq.length()==states.length()), "hmm_estimateD_init: sequence and states length must be the same");
00239   int N = numSymbols;
00240   int M = numStates;
00241   int L = seq.length();
00242   
00243   Matrix &trans_ = *trans;
00244   Matrix &emis_ = *emis;
00245   Vector stateSum;
00246 
00247   trans_.Init(M, M);
00248   emis_.Init(M, N);
00249   stateSum.Init(M);
00250 
00251   trans_.SetZero();
00252   emis_.SetZero();
00253 
00254   stateSum.SetZero();
00255   for (int i = 0; i < L-1; i++) {
00256     int state = (int) states[i];
00257     int next_state = (int) states[i+1];
00258     stateSum[state]++;
00259     trans_.ref(state, next_state)++;
00260   }
00261   for (int i = 0; i < M; i++) {
00262     if (stateSum[i] == 0) stateSum[i] = -INFINITY;
00263     for (int j = 0; j < M; j++)
00264       trans_.ref(i, j) /= stateSum[i];
00265   }
00266 
00267   stateSum.SetZero();
00268   for (int i = 0; i < L; i++) {
00269     int state = (int) states[i];
00270     int emission = (int) seq[i];
00271     stateSum[state]++;
00272     emis_.ref(state, emission)++;
00273   }
00274   for (int i = 0; i < M; i++) {
00275     if (stateSum[i] == 0) stateSum[i] = -INFINITY;
00276     for (int j = 0; j < N; j++)
00277       emis_.ref(i, j) /= stateSum[i];
00278   }
00279 }
00280 
00281 void DiscreteHMM::ForwardProcedure(const Vector& seq, const Matrix& trans, const Matrix& emis, Vector *scales, Matrix* fs) {
00282   int L = seq.length();
00283   int M = trans.n_rows();
00284 
00285   Matrix& fs_ = *fs;
00286   Vector& s_ = *scales;
00287 
00288   fs_.SetZero();
00289   s_.SetZero();
00290   // NOTE: start state is 0
00291   // time t = 0
00292   int e = (int) seq[0];
00293   for (int i = 0; i < M; i++) {
00294     fs_.ref(i, 0) = trans.get(0, i) * emis.get(i, e);
00295     s_[0] += fs_.get(i, 0);
00296   }
00297   for (int i = 0; i < M; i++)
00298     fs_.ref(i, 0) /= s_[0];
00299 
00300   // time t = 1 -> L-1
00301   for (int t = 1; t < L; t++) {
00302     e = (int) seq[t];
00303     for (int j = 0; j < M; j++) {
00304       for (int i = 0; i < M; i++)
00305         fs_.ref(j, t) += fs_.get(i, t-1)*trans.get(i, j);
00306       fs_.ref(j, t) *= emis.get(j, e);
00307       s_[t] += fs_.get(j, t);
00308     }
00309     for (int j = 0; j < M; j++)
00310       fs_.ref(j, t) /= s_[t];
00311   }
00312 }
00313 
00314 void DiscreteHMM::BackwardProcedure(const Vector& seq, const Matrix& trans, const Matrix& emis, const Vector& scales, Matrix* bs) {
00315   int L = seq.length();
00316   int M = trans.n_rows();
00317 
00318   Matrix& bs_ = *bs;
00319   bs_.SetZero();
00320   for (int i = 0; i < M; i++)
00321     bs_.ref(i, L-1) = 1.0;
00322 
00323   for (int t = L-2; t >= 0; t--) {
00324     int e = (int) seq[t+1];
00325     for (int i = 0; i < M; i++) {
00326       for (int j = 0; j < M; j++)
00327         bs_.ref(i, t) += trans.get(i, j) * bs_.ref(j, t+1) * emis.get(j, e);
00328       bs_.ref(i, t) /= scales[t+1];
00329     }
00330   }
00331 }
00332 
00333 double DiscreteHMM::Decode(const Vector& seq, const Matrix& trans, const Matrix& emis, Matrix* pstates, Matrix* fs, Matrix* bs, Vector* scales) {
00334   int L = seq.length();
00335   int M = trans.n_rows();
00336 
00337   DEBUG_ASSERT_MSG((L==pstates->n_cols() && L==fs->n_cols() && L == bs->n_cols() && 
00338                     M==trans.n_cols() && M==emis.n_rows()),"hmm_decodeD: sizes do not match");
00339   
00340   Matrix& ps_ = *pstates;
00341   Vector& s_ = *scales;
00342 
00343   DiscreteHMM::ForwardProcedure(seq, trans, emis, &s_, fs);
00344   DiscreteHMM::BackwardProcedure(seq, trans, emis, s_, bs);
00345 
00346   for (int i = 0; i < M; i++)
00347     for (int t = 0; t < L; t++)
00348       ps_.ref(i, t) = fs->get(i,t) * bs->get(i,t);
00349 
00350   double logpseq = 0;
00351   for (int t = 0; t < L; t++) 
00352     logpseq += log(s_[t]);
00353 
00354   return logpseq;
00355 }
00356 
00357 double DiscreteHMM::ViterbiInit(const Vector& seq, const Matrix& trans, const Matrix& emis, Vector* states) {
00358   int L = seq.length();
00359   return DiscreteHMM::ViterbiInit(L, seq, trans, emis, states);
00360 }
00361 
00362 double DiscreteHMM::ViterbiInit(int L, const Vector& seq, const Matrix& trans, const Matrix& emis, Vector* states) {
00363   int M = trans.n_rows();
00364   int N = emis.n_cols();
00365   DEBUG_ASSERT_MSG((M==trans.n_cols() && M==emis.n_rows()),"hmm_viterbiD: sizes do not match");
00366   
00367   Vector& s_ = *states;
00368   s_.Init(L);
00369   
00370   Vector v, vOld;
00371   v.Init(M);
00372   v.SetAll(-INFINITY);
00373   v[0] = 0;
00374   vOld.Copy(v);
00375 
00376   Matrix w;
00377   w.Init(M, L);
00378 
00379   Matrix logtrans, logemis;
00380   logtrans.Init(M, M);
00381   logemis.Init(M, N);
00382 
00383   for (int i = 0; i < M; i++) {
00384     for (int j = 0; j < M; j++) logtrans.ref(i, j) = log(trans.get(i, j));
00385     for (int j = 0; j < N; j++) logemis.ref(i, j) = log(emis.get(i, j));
00386   }
00387 
00388   for (int t = 0; t < L; t++) {
00389     int e = (int) seq[t];
00390     for (int j = 0; j < M; j++) {
00391       double bestVal = -INFINITY;
00392       double bestPtr = -1;      
00393       for (int i = 0; i < M; i++) {
00394         double val = vOld[i] + logtrans.get(i, j);
00395         if (val > bestVal) {
00396           bestVal = val;
00397           bestPtr = i;
00398         }
00399       }
00400       v[j] = bestVal + logemis.get(j, e);
00401       w.ref(j, t) = bestPtr;
00402     }
00403     vOld.CopyValues(v);
00404   }
00405 
00406   double bestVal = -INFINITY;
00407   double bestPtr = -1;
00408   for (int i = 0; i < M; i++)
00409     if (v[i] > bestVal) {
00410       bestVal = v[i];
00411       bestPtr = i;
00412     }
00413   
00414   s_[L-1] = bestPtr;
00415   for (int t = L-2; t >= 0; t--) {
00416     s_[t] = w.get((int)s_[t+1], t+1);
00417   }
00418 
00419   return bestVal;
00420 }
00421 
00422 void DiscreteHMM::Train(const ArrayList<Vector>& seqs, Matrix* guessTR, Matrix* guessEM, int max_iter, double tol) {
00423   int L = -1;
00424   int M = guessTR->n_rows();
00425   int N = guessEM->n_cols();
00426   DEBUG_ASSERT_MSG((M==guessTR->n_cols() && M==guessEM->n_rows()),"hmm_trainD: sizes do not match");
00427   
00428   for (int i = 0; i < seqs.size(); i++)
00429     if (seqs[i].length() > L) L = seqs[i].length();
00430 
00431   Matrix &gTR = *guessTR, &gEM = *guessEM;
00432   Matrix TR, EM; // guess transition and emission matrix
00433   TR.Init(M, M);
00434   EM.Init(M, N);
00435 
00436   Matrix ps, fs, bs;
00437   Vector s;
00438 
00439   ps.Init(M, L);
00440   fs.Init(M, L);
00441   bs.Init(M, L);
00442   s.Init(L);
00443 
00444   double loglik = 0, oldlog;
00445   for (int iter = 0; iter < max_iter; iter++) {
00446     oldlog = loglik;
00447     loglik = 0;
00448 
00449     TR.SetZero();
00450     EM.SetZero();
00451     for (int idx = 0; idx < seqs.size(); idx++) {
00452       L = seqs[idx].length();
00453       loglik += DiscreteHMM::Decode(seqs[idx], gTR, gEM, &ps, &fs, &bs, &s);
00454       
00455       for (int t = 0; t < L-1; t++) {
00456         int e = (int) seqs[idx][t+1];
00457         for (int i = 0; i < M; i++)
00458           for (int j = 0; j < M; j++)
00459             TR.ref(i, j) += fs.get(i, t) * gTR.get(i, j) * gEM.get(j, e) * bs.get(j, t+1) / s[t+1];
00460       }
00461       
00462       for (int t = 0; t < L; t++) {
00463         int e = (int) seqs[idx][t];
00464         for (int i = 0; i < M; i++)
00465           EM.ref(i, e) += ps.get(i, t);
00466       }
00467     }
00468 
00469     double s;
00470     for (int i = 0; i < M; i++) {
00471       s = 0;
00472       for (int j = 0; j < M; j++) s += TR.get(i, j);
00473       if (s == 0) {
00474         for (int j = 0; j < M; j++) gTR.ref(i, j) = 0;
00475         gTR.ref(i, i) = 1;
00476       }
00477       else {
00478         for (int j = 0; j < M; j++) gTR.ref(i, j) = TR.get(i, j) / s;
00479       }
00480       
00481       s = 0;
00482       for (int j = 0; j < N; j++) s += EM.get(i, j);
00483       for (int j = 0; j < N; j++) gEM.ref(i, j) = EM.get(i, j) / s;
00484     }
00485 
00486     printf("Iter = %d Loglik = %8.4f\n", iter, loglik);
00487     if (fabs(oldlog - loglik) < tol) {
00488       printf("\nConverged after %d iterations\n", iter);
00489       break;
00490     }
00491     oldlog = loglik;
00492   }
00493 }
00494 
00495 void DiscreteHMM::TrainViterbi(const ArrayList<Vector>& seqs, Matrix* guessTR, Matrix* guessEM, int max_iter, double tol) {
00496   int L = -1;
00497   int M = guessTR->n_rows();
00498   int N = guessEM->n_cols();
00499   DEBUG_ASSERT_MSG((M==guessTR->n_cols() && M==guessEM->n_rows()),"hmm_trainD: sizes do not match");
00500   
00501   for (int i = 0; i < seqs.size(); i++)
00502     if (seqs[i].length() > L) L = seqs[i].length();
00503 
00504   Matrix &gTR = *guessTR, &gEM = *guessEM;
00505   Matrix TR, EM; // guess transition and emission matrix
00506   TR.Init(M, M);
00507   EM.Init(M, N);
00508 
00509   double loglik = 0, oldlog;
00510   for (int iter = 0; iter < max_iter; iter++) {
00511     oldlog = loglik;
00512     loglik = 0;
00513 
00514     TR.SetAll(1e-4);
00515     EM.SetAll(1e-4);
00516     for (int idx = 0; idx < seqs.size(); idx++) {
00517       Vector states;
00518       L = seqs[idx].length();
00519       loglik += DiscreteHMM::ViterbiInit(L, seqs[idx], gTR, gEM, &states);
00520       
00521       for (int t = 0; t < L-1; t++) {
00522         int i = (int) states[t];
00523         int j = (int) states[t+1];
00524         TR.ref(i, j) ++;
00525       }
00526       
00527       for (int t = 0; t < L; t++) {
00528         int e = (int) seqs[idx][t];
00529         int i = (int) states[t];
00530         EM.ref(i, e) ++;
00531       }
00532     }
00533 
00534     double s;
00535     print_matrix(TR, "TR");
00536     for (int i = 0; i < M; i++) {
00537       s = 0;
00538       for (int j = 0; j < M; j++) s += TR.get(i, j);
00539       if (s == 0) {
00540         for (int j = 0; j < M; j++) gTR.ref(i, j) = 0;
00541         gTR.ref(i, i) = 1;
00542       }
00543       else {
00544         for (int j = 0; j < M; j++) gTR.ref(i, j) = TR.get(i, j) / s;
00545       }
00546       
00547       s = 0;
00548       for (int j = 0; j < N; j++) s += EM.get(i, j);
00549       for (int j = 0; j < N; j++) gEM.ref(i, j) = EM.get(i, j) / s;
00550     }
00551 
00552     printf("Iter = %d Loglik = %8.4f\n", iter, loglik);
00553     if (fabs(oldlog - loglik) < tol) {
00554       printf("\nConverged after %d iterations\n", iter);
00555       break;
00556     }
00557     oldlog = loglik;
00558   }
00559 }
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3