00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00038 #ifndef DATA_CROSSVALIDATION
00039 #define DATA_CROSSVALIDATION
00040
00041 #include "fastlib/data/dataset.h"
00042
00043
00044 #include "fastlib/la/matrix.h"
00045 #include "fastlib/fx/fx.h"
00046
00114 template<class TClassifier>
00115 class SimpleCrossValidator {
00116 FORBID_ACCIDENTAL_COPIES(SimpleCrossValidator);
00117
00118 public:
00120 typedef TClassifier Classifier;
00121
00122 private:
00124 const Dataset *data_;
00126 datanode *root_module_;
00128 datanode *kfold_module_;
00130 int n_folds_;
00132 int n_classes_;
00134 const char *classifier_fx_name_;
00136 index_t n_correct_;
00138 Matrix confusion_matrix_;
00139
00140 public:
00141 SimpleCrossValidator() {}
00142 ~SimpleCrossValidator() {}
00143
00159 void Init(
00160 const Dataset *data_with_labels,
00161 int n_labels,
00162 int default_k,
00163 struct datanode *module_root,
00164 const char *classifier_fx_name,
00165 const char *kfold_fx_name = "kfold");
00166
00173 void Run(bool randomized = false);
00174
00175
00177 index_t n_correct() {
00178 return n_correct_;
00179 }
00180
00182 index_t n_incorrect() {
00183 return data_->n_points() - n_correct_;
00184 }
00185
00187 double portion_correct() {
00188 return n_correct_ * 1.0 / data_->n_points();
00189 }
00190
00197 const Matrix& confusion_matrix() const {
00198 return confusion_matrix_;
00199 }
00200
00202 const Dataset& data() const {
00203 return *data_;
00204 }
00205
00206 private:
00207 void SaveTrainTest_(int i_fold,
00208 const Dataset& train, const Dataset& test) const;
00209 };
00210
00211 template<class TClassifier>
00212 void SimpleCrossValidator<TClassifier>::SaveTrainTest_(
00213 int i_fold,
00214 const Dataset& train, const Dataset& test) const {
00215 String train_name;
00216 String test_name;
00217
00218 train_name.InitSprintf("train_%d.csv", i_fold);
00219 test_name.InitSprintf("test_%d.csv", i_fold);
00220
00221 train.WriteCsv(train_name);
00222 test.WriteCsv(test_name);
00223 }
00224
00225
00226 template<class TClassifier>
00227 void SimpleCrossValidator<TClassifier>::Init(
00228 const Dataset *data_with_labels,
00229 int n_labels,
00230 int default_k,
00231 struct datanode *module_root,
00232 const char *classifier_fx_name,
00233 const char *kfold_fx_name) {
00234 data_ = data_with_labels;
00235
00236 if (n_labels <= 0) {
00237 const DatasetFeature *feature =
00238 &data_->info().feature(data_->n_features() - 1);
00239 DEBUG_ASSERT_MSG(feature->type() == DatasetFeature::NOMINAL,
00240 "Must specify number of classes/labels if the feature is not nominal.");
00241 n_classes_ = feature->n_values();
00242 } else {
00243 n_classes_ = n_labels;
00244 }
00245
00246 root_module_ = module_root;
00247 kfold_module_ = fx_submodule(module_root, kfold_fx_name);
00248 classifier_fx_name_ = classifier_fx_name;
00249
00250 n_folds_ = fx_param_int(kfold_module_, "k", default_k);
00251
00252 DEBUG_ONLY(n_correct_ = BIG_BAD_NUMBER);
00253
00254 confusion_matrix_.Init(n_classes_, n_classes_);
00255 confusion_matrix_.SetZero();
00256 }
00257
00258 template<class TClassifier>
00259 void SimpleCrossValidator<TClassifier>::Run(bool randomized) {
00260 ArrayList<index_t> permutation;
00261
00262 if (randomized) {
00263 math::MakeRandomPermutation(data_->n_points(), &permutation);
00264 } else {
00265 math::MakeIdentityPermutation(data_->n_points(), &permutation);
00266 }
00267
00268 n_correct_ = 0;
00269
00270 fx_timer_start(kfold_module_, "total");
00271
00272 for (int i_fold = 0; i_fold < n_folds_; i_fold++) {
00273 Classifier classifier;
00274 Dataset test;
00275 Dataset train;
00276 index_t local_n_correct = 0;
00277 datanode *classifier_module = fx_copy_module(root_module_,
00278 classifier_fx_name_, "%s/%d/%s",
00279 kfold_module_->key, i_fold, classifier_fx_name_);
00280 datanode *foldmodule = fx_submodule(classifier_module, "..");
00281
00282 data_->SplitTrainTest(n_folds_, i_fold, permutation, &train, &test);
00283
00284 if (fx_param_bool(kfold_module_, "save", 0)) {
00285 SaveTrainTest_(i_fold, train, test);
00286 }
00287
00288 VERBOSE_MSG(1, "cross: Training fold %d", i_fold);
00289 fx_timer_start(foldmodule, "train");
00290 classifier.InitTrain(train, n_classes_, classifier_module);
00291 fx_timer_stop(foldmodule, "train");
00292
00293 fx_timer_start(foldmodule, "test");
00294 VERBOSE_MSG(1, "cross: Testing fold %d", i_fold);
00295 for (index_t i = 0; i < test.n_points(); i++) {
00296 Vector test_vector_with_label;
00297 Vector test_vector;
00298
00299 test.matrix().MakeColumnVector(i, &test_vector_with_label);
00300 test_vector_with_label.MakeSubvector(
00301 0, test.n_features()-1, &test_vector);
00302
00303 int label_predict = classifier.Classify(test_vector);
00304 double label_expect_dbl = test_vector_with_label[test.n_features()-1];
00305 int label_expect = int(label_expect_dbl);
00306
00307 DEBUG_ASSERT(double(label_expect) == label_expect_dbl);
00308 DEBUG_ASSERT(label_expect < n_classes_);
00309 DEBUG_ASSERT(label_expect >= 0);
00310 DEBUG_ASSERT(label_predict < n_classes_);
00311 DEBUG_ASSERT(label_predict >= 0);
00312
00313 if (label_expect == label_predict) {
00314 local_n_correct++;
00315 }
00316
00317 confusion_matrix_.ref(label_expect, label_predict) += 1;
00318 }
00319 fx_timer_stop(foldmodule, "test");
00320
00321 fx_format_result(foldmodule, "n_correct", "%"LI"d",
00322 local_n_correct);
00323 fx_format_result(foldmodule, "n_incorrect", "%"LI"d",
00324 test.n_points() - local_n_correct);
00325 fx_format_result(foldmodule, "p_correct", "%.03f",
00326 local_n_correct * 1.0 / test.n_points());
00327
00328 n_correct_ += local_n_correct;
00329 }
00330 fx_timer_stop(kfold_module_, "total");
00331
00332 fx_format_result(kfold_module_, "n_points", "%"LI"d",
00333 data_->n_points());
00334 fx_format_result(kfold_module_, "n_correct", "%"LI"d",
00335 n_correct());
00336 fx_format_result(kfold_module_, "n_incorrect", "%"LI"d",
00337 n_incorrect());
00338 fx_format_result(kfold_module_, "p_correct", "%.03f",
00339 1.0 * portion_correct());
00340 }
00341
00342
00343
00344
00345
00346
00360 template<class TLearner>
00361 class GeneralCrossValidator {
00362 FORBID_ACCIDENTAL_COPIES(GeneralCrossValidator);
00363
00364 public:
00366 typedef TLearner Learner;
00367
00368 private:
00369
00380 int learner_typeid_;
00382 int n_folds_;
00384 const Dataset *data_;
00386 index_t num_data_points_;
00388 datanode *root_module_;
00390 datanode *kfold_module_;
00392 const char *learner_fx_name_;
00393
00394
00397 int clsf_n_classes_;
00399 index_t clsf_n_correct_;
00401 Matrix clsf_confusion_matrix_;
00402
00405 double msq_err_all_folds_;
00406
00407
00408 public:
00409 GeneralCrossValidator() {}
00410 ~GeneralCrossValidator() {}
00425 void Init(int learner_typeid,
00426 int default_k,
00427 const Dataset *data_input,
00428 struct datanode *module_root,
00429 const char *learner_fx_name,
00430 const char *kfold_fx_name = "kfold");
00431
00433 const Dataset& data() const {
00434 return *data_;
00435 }
00436
00443 void Run(bool randomized);
00444
00447 index_t clsf_n_correct() {
00448 return clsf_n_correct_;
00449 }
00451 index_t clsf_n_incorrect() {
00452 return data_->n_points() - clsf_n_correct_;
00453 }
00455 double clsf_portion_correct() {
00456 return clsf_n_correct_ * 1.0 / data_->n_points();
00457 }
00464 const Matrix& clsf_confusion_matrix() const {
00465 return clsf_confusion_matrix_;
00466 }
00467
00468
00469 private:
00471 void SaveTrainValidationSet_(int i_fold,
00472 const Dataset& train, const Dataset& validation) const;
00473
00477 void StratifiedSplitCVSet_(int i_fold, index_t num_classes, ArrayList<index_t>& cv_labels_ct,
00478 ArrayList<index_t>& cv_labels_startpos, const ArrayList<index_t>& permutation, Dataset *train, Dataset *validation){
00479
00480 index_t n_cv_features = data_->n_features();
00481
00482
00483 index_t n_cv_validation, i_validation, i_train;
00484 n_cv_validation = 0;
00485 for (index_t i_classes=0; i_classes<num_classes; i_classes++) {
00486 i_validation = 0;
00487 for (index_t j=0; j<cv_labels_ct[i_classes]; j++) {
00488 if ((j - i_fold) % n_folds_ == 0) {
00489 i_validation++;
00490 }
00491 }
00492 n_cv_validation = n_cv_validation + i_validation;
00493 }
00494 index_t n_cv_train = num_data_points_ - n_cv_validation;
00495 train->InitBlank();
00496 train->info().InitContinuous(n_cv_features);
00497 train->matrix().Init(n_cv_features, n_cv_train);
00498
00499 validation->InitBlank();
00500 validation->info().InitContinuous(n_cv_features);
00501 validation->matrix().Init(n_cv_features, n_cv_validation);
00502
00503
00504 i_train = 0;
00505 i_validation = 0;
00506 for (index_t i_classes=0; i_classes<num_classes; i_classes++) {
00507 for (index_t j=0; j<cv_labels_ct[i_classes]; j++) {
00508 Vector source, dest;
00509 if ((j - i_fold) % n_folds_ != 0) {
00510 train->matrix().MakeColumnVector(i_train, &dest);
00511 i_train++;
00512 }
00513 else {
00514 validation->matrix().MakeColumnVector(i_validation, &dest);
00515 i_validation++;
00516 }
00517 data_->matrix().MakeColumnVector(cv_labels_startpos[i_classes]+j, &source);
00518 dest.CopyValues(source);
00519 }
00520 }
00521 }
00522
00523 };
00524
00525 template<class TLearner>
00526 void GeneralCrossValidator<TLearner>::SaveTrainValidationSet_(
00527 int i_fold, const Dataset& train, const Dataset& validation) const {
00528 String train_name;
00529 String validation_name;
00530
00531
00532 train_name.InitSprintf("cv_train_%d.csv", i_fold);
00533 validation_name.InitSprintf("cv_validation_%d.csv", i_fold);
00534
00535 train.WriteCsv(train_name);
00536 validation.WriteCsv(validation_name);
00537 }
00538
00539 template<class TLearner>
00540 void GeneralCrossValidator<TLearner>::Init(
00541 int learner_typeid,
00542 int default_k,
00543 const Dataset *data_input,
00544 struct datanode *module_root,
00545 const char *learner_fx_name,
00546 const char *kfold_fx_name) {
00548 learner_typeid_ = learner_typeid;
00549 data_ = data_input;
00550
00551 root_module_ = module_root;
00552 kfold_module_ = fx_submodule(module_root, kfold_fx_name);
00553 n_folds_ = fx_param_int(kfold_module_, "k", default_k);
00554 learner_fx_name_ = learner_fx_name;
00555
00557 if(learner_typeid_ == 0) {
00558
00559 clsf_n_classes_ = data_->n_labels();
00560 clsf_n_correct_ = 0;
00561
00562 clsf_confusion_matrix_.Init(clsf_n_classes_, clsf_n_classes_);
00563 clsf_confusion_matrix_.SetZero();
00564 }
00565 else if (learner_typeid_ == 1 || learner_typeid_ == 2) {
00566 clsf_confusion_matrix_.Init(1,1);
00567
00568 msq_err_all_folds_ = 0.0;
00569 }
00570
00571 }
00572
00573 template<class TLearner>
00574 void GeneralCrossValidator<TLearner>::Run(bool randomized) {
00575 fx_timer_start(kfold_module_, "total");
00576 num_data_points_ = data_->n_points();
00577
00579 if (learner_typeid_ == 0) {
00580
00581
00582 ArrayList<double> cv_labels_list;
00583
00584 ArrayList<index_t> cv_labels_index;
00585
00586 ArrayList<index_t> cv_labels_ct;
00587
00588 ArrayList<index_t> cv_labels_startpos;
00589
00590 index_t num_classes = data_->n_labels();
00591
00592 cv_labels_list.Init();
00593 cv_labels_index.Init();
00594 cv_labels_ct.Init();
00595 cv_labels_startpos.Init();
00596 data_->GetLabels(cv_labels_list, cv_labels_index, cv_labels_ct, cv_labels_startpos);
00597
00598
00599 ArrayList<index_t> permutation;
00600
00601 if (randomized) {
00602 permutation.Init(num_data_points_);
00603 for (index_t i_classes=0; i_classes<num_classes; i_classes++) {
00604 ArrayList<index_t> sub_permutation;
00605 math::MakeRandomPermutation(cv_labels_ct[i_classes], &sub_permutation);
00606
00607 if (i_classes==0){
00608 for (index_t j=0; j<cv_labels_ct[i_classes]; j++)
00609 permutation[cv_labels_startpos[i_classes]+j] = cv_labels_index[ sub_permutation[j] ];
00610 }
00611 else {
00612 for (index_t j=0; j<cv_labels_ct[i_classes]; j++)
00613 permutation[cv_labels_startpos[i_classes]+j] = cv_labels_index[ cv_labels_ct[i_classes-1]+sub_permutation[j] ];
00614 }
00615 sub_permutation.Clear();
00616 }
00617 }
00618 else {
00619 permutation.InitCopy(cv_labels_index, cv_labels_index.size());
00620 }
00621
00622 for (int i_fold = 0; i_fold < n_folds_; i_fold++) {
00623 Learner classifier;
00624 Dataset train;
00625 Dataset validation;
00626
00627 index_t local_n_correct = 0;
00628 datanode *learner_module = fx_copy_module(root_module_,
00629 learner_fx_name_, "%s/%d/%s",
00630 kfold_module_->key, i_fold, learner_fx_name_);
00631 datanode *foldmodule = fx_submodule(learner_module, "..");
00632
00633
00634
00635 StratifiedSplitCVSet_(i_fold, num_classes, cv_labels_ct, cv_labels_startpos, permutation, &train, &validation);
00636 if (fx_param_bool(kfold_module_, "save", 0)) {
00637 SaveTrainValidationSet_(i_fold, train, validation);
00638 }
00639
00640 VERBOSE_MSG(1, "cross: Training fold %d", i_fold);
00641 fx_timer_start(foldmodule, "train");
00642
00643 classifier.InitTrain(learner_typeid_, train, learner_module);
00644 fx_timer_stop(foldmodule, "train");
00645
00646
00647 fx_timer_start(foldmodule, "validation");
00648 VERBOSE_MSG(1, "cross: Validation fold %d", i_fold);
00649
00650 for (index_t i = 0; i < validation.n_points(); i++) {
00651 Vector validation_vector_with_label;
00652 Vector validation_vector;
00653
00654 validation.matrix().MakeColumnVector(i, &validation_vector_with_label);
00655 validation_vector_with_label.MakeSubvector(0, validation.n_features()-1, &validation_vector);
00656
00657 int label_predict = int(classifier.Predict(learner_typeid_, validation_vector));
00658 double label_expect_dbl = validation_vector_with_label[validation.n_features()-1];
00659 int label_expect = int(label_expect_dbl);
00660
00661 DEBUG_ASSERT(double(label_expect) == label_expect_dbl);
00662 DEBUG_ASSERT(label_expect < clsf_n_classes_);
00663 DEBUG_ASSERT(label_expect >= 0);
00664 DEBUG_ASSERT(label_predict < clsf_n_classes_);
00665 DEBUG_ASSERT(label_predict >= 0);
00666
00667 if (label_expect == label_predict) {
00668 local_n_correct++;
00669 }
00670 clsf_confusion_matrix_.ref(label_expect, label_predict) += 1;
00671 }
00672 fx_timer_stop(foldmodule, "validation");
00673
00674 fx_format_result(foldmodule, "local_n_correct", "%"LI"d", local_n_correct);
00675 fx_format_result(foldmodule, "local_n_incorrect", "%"LI"d", validation.n_points() - local_n_correct);
00676 fx_format_result(foldmodule, "local_p_correct", "%.03f", local_n_correct * 1.0 / validation.n_points());
00677
00678 clsf_n_correct_ += local_n_correct;
00679 }
00680 fx_timer_stop(kfold_module_, "total");
00681
00682 fx_format_result(kfold_module_, "n_points", "%"LI"d", num_data_points_);
00683 fx_format_result(kfold_module_, "n_correct", "%"LI"d", clsf_n_correct());
00684 fx_format_result(kfold_module_, "n_incorrect", "%"LI"d", clsf_n_incorrect());
00685 fx_format_result(kfold_module_, "p_correct", "%.03f", 1.0 * clsf_portion_correct());
00686 }
00688 else if (learner_typeid_ == 1 || learner_typeid_ == 2) {
00689 double accu_msq_err_all_folds = 0.0;
00690
00691
00692 ArrayList<index_t> permutation;
00693 if (randomized) {
00694 math::MakeRandomPermutation(num_data_points_, &permutation);
00695 } else {
00696 math::MakeIdentityPermutation(num_data_points_, &permutation);
00697 }
00698
00699 for (int i_fold = 0; i_fold < n_folds_; i_fold++) {
00700 Learner learner;
00701 Dataset train;
00702 Dataset validation;
00703
00704 double msq_err_local = 0.0;
00705 double accu_sq_err_local = 0.0;
00706 datanode *learner_module = fx_copy_module(root_module_,
00707 learner_fx_name_, "%s/%d/%s",
00708 kfold_module_->key, i_fold, learner_fx_name_);
00709 datanode *foldmodule = fx_submodule(learner_module, "..");
00710
00711
00712 data_->SplitTrainTest(n_folds_, i_fold, permutation, &train, &validation);
00713
00714 if (fx_param_bool(kfold_module_, "save", 0)) {
00715 SaveTrainValidationSet_(i_fold, train, validation);
00716 }
00717
00718 VERBOSE_MSG(1, "cross: Training fold %d", i_fold);
00719 fx_timer_start(foldmodule, "train");
00720
00721 learner.InitTrain(learner_typeid_, train, learner_module);
00722 fx_timer_stop(foldmodule, "train");
00723
00724
00725 fx_timer_start(foldmodule, "validation");
00726 VERBOSE_MSG(1, "cross: Validation fold %d", i_fold);
00727 for (index_t i = 0; i < validation.n_points(); i++) {
00728 Vector validation_vector_with_label;
00729 Vector validation_vector;
00730
00731 validation.matrix().MakeColumnVector(i, &validation_vector_with_label);
00732 validation_vector_with_label.MakeSubvector(
00733 0, validation.n_features()-1, &validation_vector);
00734
00735 double value_predict = learner.Predict(learner_typeid_, validation_vector);
00736 double value_true = validation_vector_with_label[validation.n_features()-1];
00737 double value_err = value_predict - value_true;
00738
00739
00740 accu_sq_err_local += pow(value_err, 2);
00741 }
00742 fx_timer_stop(foldmodule, "validation");
00743
00744 msq_err_local = accu_sq_err_local / validation.n_points();
00745 fx_format_result(foldmodule, "local_msq_err", "%f", msq_err_local);
00746
00747 accu_msq_err_all_folds += msq_err_local;
00748 }
00749 fx_timer_stop(kfold_module_, "total");
00750
00751
00752 msq_err_all_folds_ = accu_msq_err_all_folds / n_folds_;
00753 fx_format_result(kfold_module_, "msq_err_all_folds", "%f", msq_err_all_folds_);
00754 }
00755 else {
00756 fprintf(stderr, "Other learner types or Unknown learner type id! Cross validation stops!\n");
00757 return;
00758 }
00759 }
00760
00761
00762 #endif