test_simple_nbc_main.cc

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00032 #include "simple_nbc.h"
00033 #include "fastlib/base/test.h"
00034 
00035 const fx_entry_doc test_simple_nbc_main_entries[] = {
00036   {"nbc/classes", FX_RESERVED, FX_INT, NULL,
00037    "Set during testing."},
00038   FX_ENTRY_DOC_DONE
00039 };
00040 
00041 const fx_submodule_doc test_simple_nbc_main_submodules[] = {
00042   {"nbc", &parm_nbc_doc,
00043    " Trains on a given set and number of classes and "
00044    "tests them on a given set\n"},
00045   FX_SUBMODULE_DOC_DONE
00046 };
00047 
00048 const fx_module_doc test_simple_nbc_main_doc = {
00049   test_simple_nbc_main_entries, test_simple_nbc_main_submodules,
00050   " Tests the simple nbc class.\n"
00051 };
00052 
00053 class TestClassSimpleNBC{
00054  private:
00055   SimpleNaiveBayesClassifier *nbc_test_;
00056   const char *filename_train_, *filename_test_;
00057   const char *train_result_, *test_result_;
00058   index_t number_of_classes_;
00059 
00060  public:
00061 
00062   void Init(const char *filename_train, const char *filename_test,
00063             const char *train_result, const char *test_result,
00064             const int number_of_classes) {
00065     nbc_test_ = new SimpleNaiveBayesClassifier();
00066     filename_train_ = filename_train;
00067     filename_test_ = filename_test;
00068     train_result_ = train_result;
00069     test_result_ = test_result;
00070     number_of_classes_ = number_of_classes;
00071   }
00072 
00073   void Destruct() {
00074     delete nbc_test_;
00075     delete filename_train_;
00076     delete filename_test_;
00077     delete train_result_;
00078     delete test_result_;
00079   }
00080 
00081   void TestInitTrain(fx_module *root) {
00082     Matrix train_data, train_res, calc_mat;
00083     data::Load(filename_train_, &train_data);
00084     data::Load(train_result_, &train_res); 
00085     struct datanode* nbc_module = fx_submodule(root,"nbc");
00086     fx_set_param_int(nbc_module, "classes", 2);
00087     nbc_test_->InitTrain(train_data, nbc_module);
00088     index_t number_of_features = nbc_test_->means_.n_rows();
00089     calc_mat.Init(2*number_of_features + 1, number_of_classes_);
00090     for(index_t i = 0; i < number_of_features; i++) {
00091       for(index_t j = 0; j < number_of_classes_; j++) {
00092         calc_mat.set(i, j, nbc_test_->means_.get(i, j));
00093         calc_mat.set(i + number_of_features, j, nbc_test_->variances_.get(i, j));       
00094       }
00095     }
00096     for(index_t i = 0; i < number_of_classes_; i++) {
00097       calc_mat.set(2 * number_of_features, i, nbc_test_->class_probabilities_[i]);      
00098     }
00099     
00100     for(index_t i = 0; i < calc_mat.n_rows(); i++) {
00101       for(index_t j = 0; j < number_of_classes_; j++) {
00102         TEST_DOUBLE_APPROX(train_res.get(i, j), calc_mat.get(i, j), 0.0001);
00103       }
00104     }
00105     NONFATAL("Test InitTrain passed...\n");
00106     
00107   }
00108 
00109   void TestClassify() {
00110     Matrix test_data, test_res;
00111     Vector test_res_vec, calc_vec;
00112     data::Load(filename_test_, &test_data);
00113     data::Load(test_result_, &test_res); 
00114     nbc_test_->Classify(test_data, &calc_vec);
00115     index_t number_of_datum = test_data.n_cols();
00116     test_res.MakeColumnVector(0, &test_res_vec);
00117     for(index_t i = 0; i < number_of_datum; i++) {
00118       TEST_ASSERT(test_res_vec.get(i) == calc_vec.get(i));
00119     }
00120     NONFATAL("Test Classify passed...\n");
00121   }
00122 
00123   void TestAll(fx_module *root) {
00124     TestInitTrain(root);
00125     TestClassify();
00126   }
00127 };
00128 
00129 int main(int argc, char *argv[]) {
00130 
00131   fx_module *root =
00132     fx_init(argc, argv, &test_simple_nbc_main_doc);
00133 
00134   TestClassSimpleNBC test;
00135 
00136   const char *train_data = "trainSet.arff";
00137   const char *train_res = "trainRes.arff";
00138   const char *test_data = "testSet.arff";
00139   const char *test_res = "testRes.arff";
00140   const int num_classes = 2;
00141 
00142   test.Init(train_data, test_data, train_res, test_res, num_classes);
00143   test.TestAll(root);
00144   
00145   fx_done(root);
00146 }
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3