// crossvalidation.h

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00038 #ifndef DATA_CROSSVALIDATION
00039 #define DATA_CROSSVALIDATION
00040 
00041 #include "fastlib/data/dataset.h"
00042 //#include "dataset.h"
00043 
00044 #include "fastlib/la/matrix.h"
00045 #include "fastlib/fx/fx.h"
00046 
/**
 * Performs k-fold cross-validation of a classifier over a labeled dataset,
 * accumulating a total correct-count and a confusion matrix across folds.
 *
 * TClassifier must provide InitTrain(train, n_classes, module) and
 * Classify(vector) returning an integer label (see Run()'s usage).
 * The class label is expected in the dataset's last feature column.
 */
template<class TClassifier>
class SimpleCrossValidator {
  FORBID_ACCIDENTAL_COPIES(SimpleCrossValidator);
  
 public:
  /** The classifier type being cross-validated. */
  typedef TClassifier Classifier;
  
 private:
  /** The labeled dataset; the label lives in the last feature column. */
  const Dataset *data_;
  /** Module that per-fold classifier sub-modules are copied from. */
  datanode *root_module_;
  /** fx module holding CV parameters ("k", "save") and aggregate results. */
  datanode *kfold_module_;
  /** Number of folds (fx parameter "k"). */
  int n_folds_;
  /** Number of distinct class labels. */
  int n_classes_;
  /** fx name under which each fold's classifier module is created. */
  const char *classifier_fx_name_;
  /** Correctly classified points summed over all folds (valid after Run). */
  index_t n_correct_;
  /** Confusion matrix: entry (expected, predicted) counts test points. */
  Matrix confusion_matrix_;
  
 public:
  SimpleCrossValidator() {}
  ~SimpleCrossValidator() {}
  
  /**
   * Stores the dataset and fx modules, determines the class count, and
   * zeroes the confusion matrix.
   *
   * If n_labels <= 0, the class count is inferred from the last feature,
   * which must then be nominal.
   */
  void Init(
      const Dataset *data_with_labels,
      int n_labels,
      int default_k,
      struct datanode *module_root,
      const char *classifier_fx_name,
      const char *kfold_fx_name = "kfold");
  
  /** Runs k-fold CV; if randomized, points are permuted randomly first. */
  void Run(bool randomized = false);
  
  
  /** Total number of correctly classified points (valid after Run). */
  index_t n_correct() {
    return n_correct_;
  }
  
  /** Total number of misclassified points (valid after Run). */
  index_t n_incorrect() {
    return data_->n_points() - n_correct_;
  }
  
  /** Fraction of points classified correctly (valid after Run). */
  double portion_correct() {
    return n_correct_ * 1.0 / data_->n_points();
  }
  
  /**
   * The confusion matrix: row index is the expected class, column index
   * is the predicted class.
   */
  const Matrix& confusion_matrix() const {
    return confusion_matrix_;
  }
  
  /** The dataset being cross-validated. */
  const Dataset& data() const {
    return *data_;
  }

 private:
  /** Writes one fold's train/test split to CSV (fx "save" option). */
  void SaveTrainTest_(int i_fold,
      const Dataset& train, const Dataset& test) const;
};
00210 
00211 template<class TClassifier>
00212 void SimpleCrossValidator<TClassifier>::SaveTrainTest_(
00213     int i_fold,
00214     const Dataset& train, const Dataset& test) const {
00215   String train_name;
00216   String test_name;
00217   
00218   train_name.InitSprintf("train_%d.csv", i_fold);
00219   test_name.InitSprintf("test_%d.csv", i_fold);
00220   
00221   train.WriteCsv(train_name);
00222   test.WriteCsv(test_name);
00223 }
00224 
00225 
00226 template<class TClassifier>
00227 void SimpleCrossValidator<TClassifier>::Init(
00228     const Dataset *data_with_labels,
00229     int n_labels,
00230     int default_k,
00231     struct datanode *module_root,
00232     const char *classifier_fx_name,
00233     const char *kfold_fx_name) {
00234   data_ = data_with_labels;
00235   
00236   if (n_labels <= 0) {
00237     const DatasetFeature *feature =
00238         &data_->info().feature(data_->n_features() - 1);
00239     DEBUG_ASSERT_MSG(feature->type() == DatasetFeature::NOMINAL,
00240         "Must specify number of classes/labels if the feature is not nominal.");
00241     n_classes_ = feature->n_values();
00242   } else {
00243     n_classes_ = n_labels;
00244   }
00245   
00246   root_module_ = module_root;
00247   kfold_module_ = fx_submodule(module_root, kfold_fx_name);
00248   classifier_fx_name_ = classifier_fx_name;
00249   
00250   n_folds_ = fx_param_int(kfold_module_, "k", default_k);
00251   
00252   DEBUG_ONLY(n_correct_ = BIG_BAD_NUMBER);
00253   
00254   confusion_matrix_.Init(n_classes_, n_classes_);
00255   confusion_matrix_.SetZero();
00256 }
00257 
/**
 * Runs k-fold cross-validation: for each fold, trains a fresh classifier
 * on the training split and scores it on the held-out split, accumulating
 * n_correct_ and the confusion matrix, and reporting per-fold and overall
 * results through fx.
 */
template<class TClassifier>
void SimpleCrossValidator<TClassifier>::Run(bool randomized) {
  ArrayList<index_t> permutation;
  
  // Either shuffle the point order or keep it as-is; the permutation
  // drives how SplitTrainTest assigns points to folds.
  if (randomized) {
    math::MakeRandomPermutation(data_->n_points(), &permutation);
  } else {
    math::MakeIdentityPermutation(data_->n_points(), &permutation);
  }
  
  n_correct_ = 0;
  
  fx_timer_start(kfold_module_, "total");
  
  for (int i_fold = 0; i_fold < n_folds_; i_fold++) {
    Classifier classifier;
    Dataset test;
    Dataset train;
    index_t local_n_correct = 0;
    // Give each fold its own copy of the classifier's fx module so
    // per-fold parameters/results don't collide.
    datanode *classifier_module = fx_copy_module(root_module_,
        classifier_fx_name_, "%s/%d/%s",
        kfold_module_->key, i_fold, classifier_fx_name_);
    datanode *foldmodule = fx_submodule(classifier_module, "..");

    data_->SplitTrainTest(n_folds_, i_fold, permutation, &train, &test);
    
    // Optionally dump this fold's split to CSV for inspection.
    if (fx_param_bool(kfold_module_, "save", 0)) {
      SaveTrainTest_(i_fold, train, test);
    }
  
    VERBOSE_MSG(1, "cross: Training fold %d", i_fold);
    fx_timer_start(foldmodule, "train");
    classifier.InitTrain(train, n_classes_, classifier_module);
    fx_timer_stop(foldmodule, "train");
    
    fx_timer_start(foldmodule, "test");
    VERBOSE_MSG(1, "cross: Testing fold %d", i_fold);
    for (index_t i = 0; i < test.n_points(); i++) {
      Vector test_vector_with_label;
      Vector test_vector;
      
      // The label is the last entry; classify on the leading features only.
      test.matrix().MakeColumnVector(i, &test_vector_with_label);
      test_vector_with_label.MakeSubvector(
          0, test.n_features()-1, &test_vector);
      
      int label_predict = classifier.Classify(test_vector);
      double label_expect_dbl = test_vector_with_label[test.n_features()-1];
      int label_expect = int(label_expect_dbl);
      
      // Labels must be exact integers within [0, n_classes_).
      DEBUG_ASSERT(double(label_expect) == label_expect_dbl);
      DEBUG_ASSERT(label_expect < n_classes_);
      DEBUG_ASSERT(label_expect >= 0);
      DEBUG_ASSERT(label_predict < n_classes_);
      DEBUG_ASSERT(label_predict >= 0);
      
      if (label_expect == label_predict) {
        local_n_correct++;
      }
      
      confusion_matrix_.ref(label_expect, label_predict) += 1;
    }
    fx_timer_stop(foldmodule, "test");
    
    // Per-fold accuracy results.
    fx_format_result(foldmodule, "n_correct", "%"LI"d",
        local_n_correct);
    fx_format_result(foldmodule, "n_incorrect", "%"LI"d",
        test.n_points() - local_n_correct);
    fx_format_result(foldmodule, "p_correct", "%.03f",
        local_n_correct * 1.0 / test.n_points());
    
    n_correct_ += local_n_correct;
  }
  fx_timer_stop(kfold_module_, "total");

  // Aggregate results over all folds.
  fx_format_result(kfold_module_, "n_points", "%"LI"d",
      data_->n_points());
  fx_format_result(kfold_module_, "n_correct", "%"LI"d",
      n_correct());
  fx_format_result(kfold_module_, "n_incorrect", "%"LI"d",
      n_incorrect());
  fx_format_result(kfold_module_, "p_correct", "%.03f",
      1.0 * portion_correct());
}
00341 
00342 
00343 
00344 
00345 
00346 
00360 template<class TLearner>
00361 class GeneralCrossValidator {
00362   FORBID_ACCIDENTAL_COPIES(GeneralCrossValidator);
00363   
00364  public:
00366   typedef TLearner Learner;
00367   
00368  private:
00369 
00380   int learner_typeid_;
00382   int n_folds_;
00384   const Dataset *data_;
00386   index_t num_data_points_;
00388   datanode *root_module_;
00390   datanode *kfold_module_;
00392   const char *learner_fx_name_;
00393 
00394   
00397   int clsf_n_classes_;
00399   index_t clsf_n_correct_;
00401   Matrix clsf_confusion_matrix_;
00402 
00405   double msq_err_all_folds_;
00406 
00407 
00408  public:
00409   GeneralCrossValidator() {}
00410   ~GeneralCrossValidator() {}
00425   void Init(int learner_typeid,
00426             int default_k,
00427             const Dataset *data_input,
00428             struct datanode *module_root,
00429             const char *learner_fx_name,
00430             const char *kfold_fx_name = "kfold");
00431 
00433   const Dataset& data() const {
00434     return *data_;
00435   }
00436 
00443   void Run(bool randomized);
00444 
00447   index_t clsf_n_correct() {
00448     return clsf_n_correct_;
00449   }
00451   index_t clsf_n_incorrect() {
00452     return data_->n_points() - clsf_n_correct_;
00453   }
00455   double clsf_portion_correct() {
00456     return clsf_n_correct_ * 1.0 / data_->n_points();
00457   }
00464   const Matrix& clsf_confusion_matrix() const {
00465     return clsf_confusion_matrix_;
00466   }
00467 
00468 
00469  private:
00471   void SaveTrainValidationSet_(int i_fold,
00472       const Dataset& train, const Dataset& validation) const;
00473 
00477   void StratifiedSplitCVSet_(int i_fold, index_t num_classes, ArrayList<index_t>& cv_labels_ct, 
00478                              ArrayList<index_t>& cv_labels_startpos, const ArrayList<index_t>& permutation, Dataset *train, Dataset *validation){
00479     // Begin stratified splitting for the i-th fold stratified CV
00480     index_t n_cv_features = data_->n_features();
00481     
00482     // detemine the number of data samples for training and validation according to i_fold
00483     index_t n_cv_validation, i_validation, i_train;
00484     n_cv_validation = 0;
00485     for (index_t i_classes=0; i_classes<num_classes; i_classes++) {
00486       i_validation = 0;
00487       for (index_t j=0; j<cv_labels_ct[i_classes]; j++) {
00488         if ((j - i_fold) % n_folds_ == 0) { // point for validation
00489           i_validation++;
00490         }
00491       }
00492       n_cv_validation = n_cv_validation + i_validation;
00493     }
00494     index_t n_cv_train = num_data_points_ - n_cv_validation;
00495     train->InitBlank();
00496     train->info().InitContinuous(n_cv_features);
00497     train->matrix().Init(n_cv_features, n_cv_train);
00498 
00499     validation->InitBlank();
00500     validation->info().InitContinuous(n_cv_features);
00501     validation->matrix().Init(n_cv_features, n_cv_validation);
00502 
00503     // make training set and vaidation set by concatenation
00504     i_train = 0;
00505     i_validation = 0;
00506     for (index_t i_classes=0; i_classes<num_classes; i_classes++) {
00507       for (index_t j=0; j<cv_labels_ct[i_classes]; j++) {
00508         Vector source, dest;
00509         if ((j - i_fold) % n_folds_ != 0) { // add to training set
00510           train->matrix().MakeColumnVector(i_train, &dest);
00511           i_train++;
00512         }
00513         else { // add to validation set
00514           validation->matrix().MakeColumnVector(i_validation, &dest);
00515           i_validation++;
00516         }
00517         data_->matrix().MakeColumnVector(cv_labels_startpos[i_classes]+j, &source);
00518         dest.CopyValues(source);
00519       }
00520     }
00521   }
00522 
00523 };
00524   
00525 template<class TLearner>
00526 void GeneralCrossValidator<TLearner>::SaveTrainValidationSet_(
00527     int i_fold, const Dataset& train, const Dataset& validation) const {
00528   String train_name;
00529   String validation_name;
00530   
00531   // save training and validation sets for this fold
00532   train_name.InitSprintf("cv_train_%d.csv", i_fold);
00533   validation_name.InitSprintf("cv_validation_%d.csv", i_fold);
00534   
00535   train.WriteCsv(train_name);
00536   validation.WriteCsv(validation_name);
00537 }
00538 
00539 template<class TLearner>
00540 void GeneralCrossValidator<TLearner>::Init(
00541     int learner_typeid,
00542     int default_k,
00543     const Dataset *data_input,
00544     struct datanode *module_root,
00545     const char *learner_fx_name,
00546     const char *kfold_fx_name) {
00548   learner_typeid_ = learner_typeid;
00549   data_ = data_input;
00550 
00551   root_module_ = module_root;
00552   kfold_module_ = fx_submodule(module_root, kfold_fx_name);
00553   n_folds_ = fx_param_int(kfold_module_, "k", default_k);
00554   learner_fx_name_ = learner_fx_name;
00555 
00557   if(learner_typeid_ == 0) {
00558     // get the number of classes
00559     clsf_n_classes_ = data_->n_labels();
00560     clsf_n_correct_ = 0;
00561     // initialize confusion matrix
00562     clsf_confusion_matrix_.Init(clsf_n_classes_, clsf_n_classes_);
00563     clsf_confusion_matrix_.SetZero();
00564   }
00565   else if (learner_typeid_ == 1 || learner_typeid_ == 2) {
00566     clsf_confusion_matrix_.Init(1,1);
00567     // initialize mean squared error over all folds
00568     msq_err_all_folds_ = 0.0;
00569   }
00570 
00571 }
00572 
00573 template<class TLearner>
00574 void GeneralCrossValidator<TLearner>::Run(bool randomized) {  
00575   fx_timer_start(kfold_module_, "total");
00576   num_data_points_ = data_->n_points();
00577 
00579   if (learner_typeid_ == 0) {
00580     // get label information
00581     /* list of labels, need to be integers. e.g. [0,1,2] for a 3-class dataset */
00582     ArrayList<double> cv_labels_list;
00583     /* array of label indices, after grouping. e.g. [c1[0,5,6,7,10,13,17],c2[1,2,4,8,9],c3[...]]*/
00584     ArrayList<index_t> cv_labels_index;
00585     /* counted number of label for each class. e.g. [7,5,8]*/
00586     ArrayList<index_t> cv_labels_ct;
00587     /* start positions of each classes in the cv label list. e.g. [0,7,12] */
00588     ArrayList<index_t> cv_labels_startpos;
00589     // Get label list and label indices from the cross validation data set
00590     index_t num_classes = data_->n_labels();
00591 
00592     cv_labels_list.Init();
00593     cv_labels_index.Init();
00594     cv_labels_ct.Init();
00595     cv_labels_startpos.Init();
00596     data_->GetLabels(cv_labels_list, cv_labels_index, cv_labels_ct, cv_labels_startpos);
00597 
00598     // randomize the original data set within each class if necessary
00599     ArrayList<index_t> permutation;
00600 
00601     if (randomized) {
00602       permutation.Init(num_data_points_);
00603       for (index_t i_classes=0; i_classes<num_classes; i_classes++) {
00604         ArrayList<index_t> sub_permutation; // within class permut indices
00605         math::MakeRandomPermutation(cv_labels_ct[i_classes], &sub_permutation);
00606         // use sub-permutation indicies to form the whole permutation
00607         if (i_classes==0){
00608           for (index_t j=0; j<cv_labels_ct[i_classes]; j++)
00609             permutation[cv_labels_startpos[i_classes]+j] = cv_labels_index[ sub_permutation[j] ];
00610         }
00611         else {
00612           for (index_t j=0; j<cv_labels_ct[i_classes]; j++)
00613             permutation[cv_labels_startpos[i_classes]+j] = cv_labels_index[ cv_labels_ct[i_classes-1]+sub_permutation[j] ];
00614         }
00615         sub_permutation.Clear();
00616       }
00617     } // e.g. [10,13,5,17,0,6,7,,4,9,8,1,2,,...]
00618     else {
00619       permutation.InitCopy(cv_labels_index, cv_labels_index.size()); // e.g. [0,5,6,7,10,13,17,,1,2,4,8,9,,...]
00620     }
00621     // begin CV
00622     for (int i_fold = 0; i_fold < n_folds_; i_fold++) {
00623       Learner classifier;
00624       Dataset train;
00625       Dataset validation;
00626       
00627       index_t local_n_correct = 0;
00628       datanode *learner_module = fx_copy_module(root_module_,
00629           learner_fx_name_, "%s/%d/%s",
00630           kfold_module_->key, i_fold, learner_fx_name_);
00631       datanode *foldmodule = fx_submodule(learner_module, "..");
00632 
00633       // Split labeled data sets according to i_fold. Use Stratified Cross-Validation to ensure 
00634       // that approximately the same portion of data (training/validation) are used for each class.
00635       StratifiedSplitCVSet_(i_fold, num_classes, cv_labels_ct, cv_labels_startpos, permutation, &train, &validation);
00636       if (fx_param_bool(kfold_module_, "save", 0)) {
00637         SaveTrainValidationSet_(i_fold, train, validation);
00638       }
00639       
00640       VERBOSE_MSG(1, "cross: Training fold %d", i_fold);
00641       fx_timer_start(foldmodule, "train");
00642       // training
00643       classifier.InitTrain(learner_typeid_, train, learner_module);
00644       fx_timer_stop(foldmodule, "train");
00645 
00646       // validation; measure method: percent of correctly classified validation samples
00647       fx_timer_start(foldmodule, "validation");
00648       VERBOSE_MSG(1, "cross: Validation fold %d", i_fold);
00649 
00650       for (index_t i = 0; i < validation.n_points(); i++) {
00651         Vector validation_vector_with_label;
00652         Vector validation_vector;
00653         
00654         validation.matrix().MakeColumnVector(i, &validation_vector_with_label);
00655         validation_vector_with_label.MakeSubvector(0, validation.n_features()-1, &validation_vector);
00656         // testing (classification)
00657         int label_predict = int(classifier.Predict(learner_typeid_, validation_vector));
00658         double label_expect_dbl = validation_vector_with_label[validation.n_features()-1];
00659         int label_expect = int(label_expect_dbl);
00660 
00661         DEBUG_ASSERT(double(label_expect) == label_expect_dbl);
00662         DEBUG_ASSERT(label_expect < clsf_n_classes_);
00663         DEBUG_ASSERT(label_expect >= 0);
00664         DEBUG_ASSERT(label_predict < clsf_n_classes_);
00665         DEBUG_ASSERT(label_predict >= 0);
00666         
00667         if (label_expect == label_predict) {
00668           local_n_correct++;
00669         }
00670         clsf_confusion_matrix_.ref(label_expect, label_predict) += 1;
00671       }
00672       fx_timer_stop(foldmodule, "validation");
00673 
00674       fx_format_result(foldmodule, "local_n_correct", "%"LI"d", local_n_correct);
00675       fx_format_result(foldmodule, "local_n_incorrect", "%"LI"d", validation.n_points() - local_n_correct);
00676       fx_format_result(foldmodule, "local_p_correct", "%.03f", local_n_correct * 1.0 / validation.n_points());
00677 
00678       clsf_n_correct_ += local_n_correct;
00679     }
00680     fx_timer_stop(kfold_module_, "total");
00681     
00682     fx_format_result(kfold_module_, "n_points", "%"LI"d", num_data_points_);
00683     fx_format_result(kfold_module_, "n_correct", "%"LI"d", clsf_n_correct());
00684     fx_format_result(kfold_module_, "n_incorrect", "%"LI"d", clsf_n_incorrect());
00685     fx_format_result(kfold_module_, "p_correct", "%.03f", 1.0 * clsf_portion_correct());
00686   }
00688   else if (learner_typeid_ == 1 || learner_typeid_ == 2) {
00689     double accu_msq_err_all_folds = 0.0;
00690 
00691     // randomize the original data set if necessary
00692     ArrayList<index_t> permutation;  
00693     if (randomized) {
00694       math::MakeRandomPermutation(num_data_points_, &permutation);
00695     } else {
00696       math::MakeIdentityPermutation(num_data_points_, &permutation);
00697     }
00698     // begin CV
00699     for (int i_fold = 0; i_fold < n_folds_; i_fold++) {
00700       Learner learner;
00701       Dataset train;
00702       Dataset validation;
00703       
00704       double msq_err_local = 0.0;
00705       double accu_sq_err_local = 0.0;
00706       datanode *learner_module = fx_copy_module(root_module_,
00707           learner_fx_name_, "%s/%d/%s",
00708           kfold_module_->key, i_fold, learner_fx_name_);
00709       datanode *foldmodule = fx_submodule(learner_module, "..");
00710       
00711       // Split general data sets according to i_fold
00712       data_->SplitTrainTest(n_folds_, i_fold, permutation, &train, &validation);
00713       
00714       if (fx_param_bool(kfold_module_, "save", 0)) {
00715         SaveTrainValidationSet_(i_fold, train, validation);
00716       }
00717       
00718       VERBOSE_MSG(1, "cross: Training fold %d", i_fold);
00719       fx_timer_start(foldmodule, "train");
00720       // training
00721       learner.InitTrain(learner_typeid_, train, learner_module); // 0: dummy number of classes
00722       fx_timer_stop(foldmodule, "train");
00723       
00724       // validation
00725       fx_timer_start(foldmodule, "validation");
00726       VERBOSE_MSG(1, "cross: Validation fold %d", i_fold);
00727       for (index_t i = 0; i < validation.n_points(); i++) {
00728         Vector validation_vector_with_label;
00729         Vector validation_vector;
00730         
00731         validation.matrix().MakeColumnVector(i, &validation_vector_with_label);
00732         validation_vector_with_label.MakeSubvector(
00733                                              0, validation.n_features()-1, &validation_vector);
00734         // testing
00735         double value_predict = learner.Predict(learner_typeid_, validation_vector);
00736         double value_true = validation_vector_with_label[validation.n_features()-1];
00737         double value_err = value_predict - value_true;
00738         
00739         // Calculate squared error: sublevel
00740         accu_sq_err_local  += pow(value_err, 2);
00741       }
00742       fx_timer_stop(foldmodule, "validation");
00743       
00744       msq_err_local = accu_sq_err_local / validation.n_points();
00745       fx_format_result(foldmodule, "local_msq_err", "%f", msq_err_local);
00746 
00747       accu_msq_err_all_folds += msq_err_local;
00748     }
00749     fx_timer_stop(kfold_module_, "total");
00750     
00751     // Calculate mean squared error: over all folds
00752     msq_err_all_folds_ = accu_msq_err_all_folds / n_folds_;
00753     fx_format_result(kfold_module_, "msq_err_all_folds", "%f", msq_err_all_folds_);
00754   }
00755   else {
00756     fprintf(stderr, "Other learner types or Unknown learner type id! Cross validation stops!\n");
00757     return;
00758   }
00759 }
00760 
00761 
00762 #endif
// Generated on Mon Jan 24 12:04:37 2011 for FASTlib by doxygen 1.6.3