train.cc

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00045 #include "fastlib/fastlib.h"
00046 #include "support.h"
00047 #include "discreteHMM.h"
00048 #include "gaussianHMM.h"
00049 #include "mixgaussHMM.h"
00050 #include "mixtureDST.h"
00051 
00052 using namespace hmm_support;
00053 
00054 success_t train_baumwelch();
00055 success_t train_viterbi();
00056 void usage();
00057 
00058 const fx_entry_doc hmm_train_main_entries[] = {
00059   {"type", FX_REQUIRED, FX_STR, NULL,
00060    "  HMM type : discrete | gaussian | mixture.\n"},
00061   {"algorithm", FX_PARAM, FX_STR, NULL,
00062    "  Training algoritm: baumwelch | viterbi.\n"},
00063   {"seqfile", FX_REQUIRED, FX_STR, NULL,
00064    "  Output file for the data sequences.\n"},
00065   {"guess", FX_PARAM, FX_STR, NULL,
00066    "  File containing guessing HMM model profile.\n"},
00067   {"numstate", FX_PARAM, FX_INT, NULL,
00068    "  If no guessing profile specified, at least provide the number of states.\n"},
00069   {"profile", FX_REQUIRED, FX_STR, NULL,
00070    "  Output file containing trained HMM profile.\n"},
00071   {"maxiter", FX_PARAM, FX_INT, NULL,
00072    "  Maximum number of iterations, default = 500.\n"},
00073   {"tolerance", FX_PARAM, FX_DOUBLE, NULL,
00074    "  Error tolerance on log-likelihood as a stopping criteria.\n"},
00075   FX_ENTRY_DOC_DONE
00076 };
00077 
00078 const fx_submodule_doc hmm_train_main_submodules[] = {
00079   FX_SUBMODULE_DOC_DONE
00080 };
00081 
00082 const fx_module_doc hmm_train_main_doc = {
00083   hmm_train_main_entries, hmm_train_main_submodules,
00084   "This is a program training HMM models from data sequences. \n"
00085 };
00086 
/**
 * Prints the command-line usage summary to standard output.
 *
 * The synopsis uses a single '=' between an option and its value, matching
 * the form shown for every other option in the list.
 */
void usage() {
  printf("\nUsage:\n"
         "  train --type={discrete|gaussian|mixture} OPTION\n"
         "[OPTIONS]\n"
         "  --algorithm={baumwelch|viterbi} : algorithm used for training, default Baum-Welch\n"
         "  --seqfile=file   : file contains input sequences\n"
         "  --guess=file     : file contains guess HMM profile\n"
         "  --numstate=NUM   : if no guess profile is specified, at least specify the number of state\n"
         "  --profile=file   : output file for estimated HMM profile\n"
         "  --maxiter=NUM    : maximum number of iteration, default=500\n"
         "  --tolerance=NUM  : error tolerance on log-likelihood, default=1e-3\n"
         );
}
00100 
00101 int main(int argc, char* argv[]) {
00102   fx_init(argc, argv, &hmm_train_main_doc);
00103   success_t s = SUCCESS_PASS;
00104   if (fx_param_exists(NULL,"type")) {
00105     const char* algorithm = fx_param_str(NULL, "algorithm", "baumwelch");
00106     if (strcmp(algorithm,"baumwelch")==0)
00107       s = train_baumwelch();
00108     else if (strcmp(algorithm,"viterbi")==0)
00109       s = train_viterbi();
00110     else {
00111       printf("Unrecognized algorithm: must be baumwelch or viterbi !!!\n");
00112       s = SUCCESS_FAIL;
00113     }
00114   }
00115   else {
00116     printf("Unrecognized type: must be: discrete | gaussian | mixture  !!!\n");
00117     s = SUCCESS_FAIL;
00118   }
00119   if (!PASSED(s)) usage();
00120   fx_done(NULL);
00121 }
00122 
00123 success_t train_baumwelch_discrete();
00124 success_t train_baumwelch_gaussian();
00125 success_t train_baumwelch_mixture();
00126 
00127 success_t train_baumwelch() {
00128   const char* type = fx_param_str_req(NULL, "type");
00129   if (strcmp(type, "discrete")==0)
00130     return train_baumwelch_discrete();
00131   else if (strcmp(type, "gaussian")==0)
00132     return train_baumwelch_gaussian();
00133   else if (strcmp(type, "mixture")==0)
00134     return train_baumwelch_mixture();
00135   else {
00136     printf("Unrecognized type: must be: discrete | gaussian | mixture !!!\n");
00137     return SUCCESS_FAIL;
00138   }
00139 }
00140 
00141 success_t train_viterbi_discrete();
00142 success_t train_viterbi_gaussian();
00143 success_t train_viterbi_mixture();
00144 
00145 success_t train_viterbi() {
00146   const char* type = fx_param_str_req(NULL, "type");
00147   if (strcmp(type, "discrete")==0)
00148     return train_viterbi_discrete();
00149   else if (strcmp(type, "gaussian")==0)
00150     return train_viterbi_gaussian();
00151   else if (strcmp(type, "mixture")==0)
00152     return train_viterbi_mixture();
00153   else {
00154     printf("Unrecognized type: must be: discrete | gaussian | mixture !!!\n");
00155     return SUCCESS_FAIL;
00156   }
00157 }
00158 
00159 success_t train_baumwelch_mixture() {
00160   if (!fx_param_exists(NULL, "seqfile")) {
00161     printf("--seqfile must be defined.\n");
00162     return SUCCESS_FAIL;
00163   }
00164 
00165   MixtureofGaussianHMM hmm;
00166   ArrayList<Matrix> seqs;
00167 
00168   const char* seqin = fx_param_str_req(NULL, "seqfile");
00169   const char* proout = fx_param_str(NULL, "profile", "pro.mix.out");
00170 
00171   load_matrix_list(seqin, &seqs);
00172 
00173   if (fx_param_exists(NULL, "guess")) { // guessed parameters in a file
00174     const char* guess = fx_param_str_req(NULL, "guess");
00175     printf("Load parameters from file %s\n", guess);
00176     hmm.InitFromFile(guess);
00177   }
00178   else {
00179     hmm.Init();
00180     printf("Automatic initialization not supported !!!");
00181     return SUCCESS_FAIL;
00182   }
00183 
00184   int maxiter = fx_param_int(NULL, "maxiter", 500);
00185   double tol = fx_param_double(NULL, "tolerance", 1e-3);
00186 
00187   hmm.TrainBaumWelch(seqs, maxiter, tol);
00188 
00189   hmm.SaveProfile(proout);
00190 
00191   return SUCCESS_PASS;
00192 }
00193 
00194 success_t train_baumwelch_gaussian() {
00195   if (!fx_param_exists(NULL, "seqfile")) {
00196     printf("--seqfile must be defined.\n");
00197     return SUCCESS_FAIL;
00198   }
00199   GaussianHMM hmm;
00200   ArrayList<Matrix> seqs;
00201 
00202   const char* seqin = fx_param_str_req(NULL, "seqfile");
00203   const char* proout = fx_param_str(NULL, "profile", "pro.gauss.out");
00204 
00205   load_matrix_list(seqin, &seqs);
00206 
00207   if (fx_param_exists(NULL, "guess")) { // guessed parameters in a file
00208     const char* guess = fx_param_str_req(NULL, "guess");
00209     printf("Load parameters from file %s\n", guess);
00210     hmm.InitFromFile(guess);
00211   }
00212   else { // otherwise initialized using information from the data
00213     int numstate = fx_param_int_req(NULL, "numstate");
00214     printf("Generate HMM parameters: NUMSTATE = %d\n", numstate);
00215     hmm.InitFromData(seqs, numstate);
00216     printf("Done.\n");
00217   }
00218 
00219   int maxiter = fx_param_int(NULL, "maxiter", 500);
00220   double tol = fx_param_double(NULL, "tolerance", 1e-3);
00221 
00222   printf("Training ...\n");
00223   hmm.TrainBaumWelch(seqs, maxiter, tol);
00224   printf("Done.\n");
00225 
00226   hmm.SaveProfile(proout);
00227 
00228   return SUCCESS_PASS;
00229 }
00230 
00231 success_t train_baumwelch_discrete() {
00232   if (!fx_param_exists(NULL, "seqfile")) {
00233     printf("--seqfile must be defined.\n");
00234     return SUCCESS_FAIL;
00235   }
00236 
00237   const char* seqin = fx_param_str_req(NULL, "seqfile");
00238   const char* proout = fx_param_str(NULL, "profile", "pro.dis.out");
00239 
00240   ArrayList<Vector> seqs;
00241   load_vector_list(seqin, &seqs);
00242 
00243   DiscreteHMM hmm;
00244 
00245   if (fx_param_exists(NULL, "guess")) { // guessed parameters in a file
00246     const char* guess = fx_param_str_req(NULL, "guess");
00247     printf("Load HMM parameters from file %s\n", guess);
00248     hmm.InitFromFile(guess);
00249   }
00250   else { // otherwise randomly initialized using information from the data
00251     int numstate = fx_param_int_req(NULL, "numstate");
00252     printf("Randomly generate parameters: NUMSTATE = %d\n", numstate);
00253     hmm.InitFromData(seqs, numstate);
00254   }
00255 
00256   int maxiter = fx_param_int(NULL, "maxiter", 500);
00257   double tol = fx_param_double(NULL, "tolerance", 1e-3);
00258 
00259   hmm.TrainBaumWelch(seqs, maxiter, tol);
00260 
00261   hmm.SaveProfile(proout);
00262 
00263   return SUCCESS_PASS;
00264 }
00265 
00266 success_t train_viterbi_mixture() {
00267   if (!fx_param_exists(NULL, "seqfile")) {
00268     printf("--seqfile must be defined.\n");
00269     return SUCCESS_FAIL;
00270   }
00271   
00272   MixtureofGaussianHMM hmm;
00273   ArrayList<Matrix> seqs;
00274 
00275   const char* seqin = fx_param_str_req(NULL, "seqfile");
00276   const char* proout = fx_param_str(NULL, "profile", "pro.mix.out");
00277 
00278   load_matrix_list(seqin, &seqs);
00279 
00280   if (fx_param_exists(NULL, "guess")) { // guessed parameters in a file
00281     const char* guess = fx_param_str_req(NULL, "guess");
00282     printf("Load parameters from file %s\n", guess);
00283     hmm.InitFromFile(guess);
00284   }
00285   else {
00286     hmm.Init();
00287     printf("Automatic initialization not supported !!!");
00288     return SUCCESS_FAIL;
00289   }
00290 
00291   int maxiter = fx_param_int(NULL, "maxiter", 500);
00292   double tol = fx_param_double(NULL, "tolerance", 1e-3);
00293 
00294   hmm.TrainViterbi(seqs, maxiter, tol);
00295 
00296   hmm.SaveProfile(proout);
00297 
00298   return SUCCESS_PASS;
00299 }
00300 
00301 success_t train_viterbi_gaussian() {
00302   if (!fx_param_exists(NULL, "seqfile")) {
00303     printf("--seqfile must be defined.\n");
00304     return SUCCESS_FAIL;
00305   }
00306   
00307   GaussianHMM hmm;
00308   ArrayList<Matrix> seqs;
00309 
00310   const char* seqin = fx_param_str_req(NULL, "seqfile");
00311   const char* proout = fx_param_str(NULL, "profile", "pro.gauss.viterbi.out");
00312 
00313   load_matrix_list(seqin, &seqs);
00314 
00315   if (fx_param_exists(NULL, "guess")) { // guessed parameters in a file
00316     const char* guess = fx_param_str_req(NULL, "guess");
00317     printf("Load parameters from file %s\n", guess);
00318     hmm.InitFromFile(guess);
00319   }
00320   else { // otherwise initialized using information from the data
00321     int numstate = fx_param_int_req(NULL, "numstate");
00322     printf("Generate parameters: NUMSTATE = %d\n", numstate);
00323     hmm.InitFromData(seqs, numstate);
00324   }
00325 
00326   int maxiter = fx_param_int(NULL, "maxiter", 500);
00327   double tol = fx_param_double(NULL, "tolerance", 1e-3);
00328 
00329   hmm.TrainViterbi(seqs, maxiter, tol);
00330 
00331   hmm.SaveProfile(proout);
00332 
00333   return SUCCESS_PASS;
00334 }
00335 
00336 success_t train_viterbi_discrete() {
00337   if (!fx_param_exists(NULL, "seqfile")) {
00338     printf("--seqfile must be defined.\n");
00339     return SUCCESS_FAIL;
00340   }
00341 
00342   DiscreteHMM hmm;
00343   ArrayList<Vector> seqs;
00344 
00345   const char* seqin = fx_param_str_req(NULL, "seqfile");
00346   const char* proout = fx_param_str(NULL, "profile", "pro.dis.viterbi.out");
00347 
00348   load_vector_list(seqin, &seqs);
00349 
00350   if (fx_param_exists(NULL, "guess")) { // guessed parameters in a file
00351     ArrayList<Matrix> matlst;
00352     const char* guess = fx_param_str_req(NULL, "guess");
00353     printf("Load parameters from file %s\n", guess);
00354     hmm.InitFromFile(guess);
00355   }
00356   else { // otherwise randomly initialized using information from the data
00357     int numstate = fx_param_int_req(NULL, "numstate");
00358     printf("Generate parameters with NUMSTATE = %d\n", numstate);
00359     hmm.InitFromData(seqs, numstate);
00360   }
00361 
00362   int maxiter = fx_param_int(NULL, "maxiter", 500);
00363   double tol = fx_param_double(NULL, "tolerance", 1e-3);
00364 
00365   hmm.TrainViterbi(seqs, maxiter, tol);
00366 
00367   hmm.SaveProfile(proout);
00368 
00369   return SUCCESS_PASS;
00370 }
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3