00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #include "simple_nbc.h"
00033 #include "fastlib/base/test.h"
00034
00035 const fx_entry_doc test_simple_nbc_main_entries[] = {
00036 {"nbc/classes", FX_RESERVED, FX_INT, NULL,
00037 "Set during testing."},
00038 FX_ENTRY_DOC_DONE
00039 };
00040
00041 const fx_submodule_doc test_simple_nbc_main_submodules[] = {
00042 {"nbc", &parm_nbc_doc,
00043 " Trains on a given set and number of classes and "
00044 "tests them on a given set\n"},
00045 FX_SUBMODULE_DOC_DONE
00046 };
00047
00048 const fx_module_doc test_simple_nbc_main_doc = {
00049 test_simple_nbc_main_entries, test_simple_nbc_main_submodules,
00050 " Tests the simple nbc class.\n"
00051 };
00052
00053 class TestClassSimpleNBC{
00054 private:
00055 SimpleNaiveBayesClassifier *nbc_test_;
00056 const char *filename_train_, *filename_test_;
00057 const char *train_result_, *test_result_;
00058 index_t number_of_classes_;
00059
00060 public:
00061
00062 void Init(const char *filename_train, const char *filename_test,
00063 const char *train_result, const char *test_result,
00064 const int number_of_classes) {
00065 nbc_test_ = new SimpleNaiveBayesClassifier();
00066 filename_train_ = filename_train;
00067 filename_test_ = filename_test;
00068 train_result_ = train_result;
00069 test_result_ = test_result;
00070 number_of_classes_ = number_of_classes;
00071 }
00072
00073 void Destruct() {
00074 delete nbc_test_;
00075 delete filename_train_;
00076 delete filename_test_;
00077 delete train_result_;
00078 delete test_result_;
00079 }
00080
00081 void TestInitTrain(fx_module *root) {
00082 Matrix train_data, train_res, calc_mat;
00083 data::Load(filename_train_, &train_data);
00084 data::Load(train_result_, &train_res);
00085 struct datanode* nbc_module = fx_submodule(root,"nbc");
00086 fx_set_param_int(nbc_module, "classes", 2);
00087 nbc_test_->InitTrain(train_data, nbc_module);
00088 index_t number_of_features = nbc_test_->means_.n_rows();
00089 calc_mat.Init(2*number_of_features + 1, number_of_classes_);
00090 for(index_t i = 0; i < number_of_features; i++) {
00091 for(index_t j = 0; j < number_of_classes_; j++) {
00092 calc_mat.set(i, j, nbc_test_->means_.get(i, j));
00093 calc_mat.set(i + number_of_features, j, nbc_test_->variances_.get(i, j));
00094 }
00095 }
00096 for(index_t i = 0; i < number_of_classes_; i++) {
00097 calc_mat.set(2 * number_of_features, i, nbc_test_->class_probabilities_[i]);
00098 }
00099
00100 for(index_t i = 0; i < calc_mat.n_rows(); i++) {
00101 for(index_t j = 0; j < number_of_classes_; j++) {
00102 TEST_DOUBLE_APPROX(train_res.get(i, j), calc_mat.get(i, j), 0.0001);
00103 }
00104 }
00105 NONFATAL("Test InitTrain passed...\n");
00106
00107 }
00108
00109 void TestClassify() {
00110 Matrix test_data, test_res;
00111 Vector test_res_vec, calc_vec;
00112 data::Load(filename_test_, &test_data);
00113 data::Load(test_result_, &test_res);
00114 nbc_test_->Classify(test_data, &calc_vec);
00115 index_t number_of_datum = test_data.n_cols();
00116 test_res.MakeColumnVector(0, &test_res_vec);
00117 for(index_t i = 0; i < number_of_datum; i++) {
00118 TEST_ASSERT(test_res_vec.get(i) == calc_vec.get(i));
00119 }
00120 NONFATAL("Test Classify passed...\n");
00121 }
00122
00123 void TestAll(fx_module *root) {
00124 TestInitTrain(root);
00125 TestClassify();
00126 }
00127 };
00128
00129 int main(int argc, char *argv[]) {
00130
00131 fx_module *root =
00132 fx_init(argc, argv, &test_simple_nbc_main_doc);
00133
00134 TestClassSimpleNBC test;
00135
00136 const char *train_data = "trainSet.arff";
00137 const char *train_res = "trainRes.arff";
00138 const char *test_data = "testSet.arff";
00139 const char *test_res = "testRes.arff";
00140 const int num_classes = 2;
00141
00142 test.Init(train_data, test_data, train_res, test_res, num_classes);
00143 test.TestAll(root);
00144
00145 fx_done(root);
00146 }