nbc_main.cc
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00059 #include "simple_nbc.h"
00060
00061 const fx_entry_doc parm_nbc_main_entries[] = {
00062 {"train", FX_REQUIRED, FX_STR, NULL,
00063 " A file containing the training set\n"},
00064 {"test", FX_REQUIRED, FX_STR, NULL,
00065 " A file containing the test set\n"},
00066 {"output", FX_PARAM, FX_STR, NULL,
00067 " The file in which the output of the test would be "
00068 "written (defaults to 'output.csv')\n"},
00069 FX_ENTRY_DOC_DONE
00070 };
00071
00072 const fx_submodule_doc parm_nbc_main_submodules[] = {
00073 {"nbc", &parm_nbc_doc,
00074 " Trains on a given set and number of classes and "
00075 "tests them on a given set\n"},
00076 FX_SUBMODULE_DOC_DONE
00077 };
00078
00079 const fx_module_doc parm_nbc_main_doc = {
00080 parm_nbc_main_entries, parm_nbc_main_submodules,
00081 "This program test drives the Parametric Naive Bayes \n"
00082 "Classifier assuming that the features are sampled \n"
00083 "from a Gaussian distribution.\n"
00084 };
00085
00086 int main(int argc, char* argv[]) {
00087
00088 fx_module *root = fx_init(argc, argv, &parm_nbc_main_doc);
00089
00091
00092 const char *training_data_filename = fx_param_str_req(root, "train");
00093 Matrix training_data;
00094 data::Load(training_data_filename, &training_data);
00095
00096 const char *testing_data_filename = fx_param_str_req(root, "test");
00097 Matrix testing_data;
00098 data::Load(testing_data_filename, &testing_data);
00099
00101
00103 SimpleNaiveBayesClassifier nbc;
00104
00105 struct datanode* nbc_module = fx_submodule(root, "nbc");
00106
00108 fx_timer_start(nbc_module, "training");
00109
00111 nbc.InitTrain(training_data, nbc_module);
00112
00113 fx_timer_stop(nbc_module, "training");
00114
00117 Vector results;
00118
00119 fx_timer_start(nbc_module, "testing");
00120
00122 nbc.Classify(testing_data, &results);
00123
00124 fx_timer_stop(nbc_module, "testing");
00125
00127
00128 const char *output_filename = fx_param_str(root, "output", "output.csv");
00129
00130 FILE *output_file = fopen(output_filename, "w");
00131
00132 ot::Print(results, output_file);
00133
00134 fclose(output_file);
00135
00136 fx_done(root);
00137
00138 return 1;
00139 }