00042 #ifndef NBC_H
00043 #define NBC_H
00044
00045 #include "fastlib/fastlib.h"
00046 #include "phi.h"
00047 #include "math_functions.h"
00048
00049 const fx_entry_doc parm_nbc_entries[] ={
00050 {"training", FX_TIMER, FX_CUSTOM, NULL,
00051 " The timer to record the training time\n"},
00052 {"testing", FX_TIMER, FX_CUSTOM, NULL,
00053 " The timer to record the testing time\n"},
00054 {"classes", FX_REQUIRED, FX_INT, NULL,
00055 " The number of classes present in the data\n"},
00056 {"features", FX_RESULT, FX_INT, NULL,
00057 " The number of features in the data\n"},
00058 {"examples", FX_RESULT, FX_INT, NULL,
00059 " The number of examples in the training set\n"},
00060 {"tests", FX_RESULT, FX_INT, NULL,
00061 " The number of data points in the test set\n"},
00062 FX_ENTRY_DOC_DONE
00063 };
00064
00065 const fx_submodule_doc parm_nbc_submodules[] = {
00066 FX_SUBMODULE_DOC_DONE
00067 };
00068
00069 const fx_module_doc parm_nbc_doc = {
00070 parm_nbc_entries, parm_nbc_submodules,
00071 " Trains the classifier using the training set and "
00072 "outputs the results for the test set\n"
00073 };
00074
00075
00106 class SimpleNaiveBayesClassifier {
00107
00108
00109 friend class TestClassSimpleNBC;
00110
00111 private:
00112
00113
00114
00115 Matrix means_, variances_;
00116
00117
00118 ArrayList<double> class_probabilities_;
00119
00120
00121
00122 index_t number_of_classes_;
00123
00124 datanode *nbc_module_;
00125
00126 public:
00127
00128 SimpleNaiveBayesClassifier(){
00129 means_.Init(0, 0);
00130 variances_.Init(0, 0);
00131 class_probabilities_.Init(0);
00132 }
00133
00134 ~SimpleNaiveBayesClassifier(){
00135 }
00136
00149 void InitTrain(const Matrix& data, datanode* nbc_module) {
00150
00151 ArrayList<double> feature_sum, feature_sum_squared;
00152 index_t number_examples = data.n_cols();
00153 index_t number_features = data.n_rows() - 1;
00154 nbc_module_ = nbc_module;
00155
00156
00157
00158 number_of_classes_ = fx_param_int_req(nbc_module_,"classes");
00159 class_probabilities_.Resize(number_of_classes_);
00160 means_.Destruct();
00161 means_.Init(number_features, number_of_classes_ );
00162 variances_.Destruct();
00163 variances_.Init(number_features, number_of_classes_);
00164 feature_sum.Init(number_features);
00165 feature_sum_squared.Init(number_features);
00166 for(index_t k = 0; k < number_features; k++) {
00167 feature_sum[k] = 0;
00168 feature_sum_squared[k] = 0;
00169 }
00170 NOTIFY("%"LI"d examples with %"LI"d features each\n",
00171 number_examples, number_features);
00172 fx_result_int(nbc_module_, "features", number_features);
00173 fx_result_int(nbc_module_, "examples", number_examples);
00174
00175
00176
00177
00178 for(index_t i = 0; i < number_of_classes_; i++ ) {
00179 index_t number_of_occurrences = 0;
00180 for (index_t j = 0; j < number_examples; j++) {
00181 index_t flag = (index_t) data.get(number_features, j);
00182 if(i == flag) {
00183 ++number_of_occurrences;
00184 for(index_t k = 0; k < number_features; k++) {
00185 double tmp = data.get(k, j);
00186 feature_sum[k] += tmp;
00187 feature_sum_squared[k] += tmp*tmp;
00188 }
00189 }
00190 }
00191 class_probabilities_[i] = (double)number_of_occurrences
00192 / (double)number_examples ;
00193 for(index_t k = 0; k < number_features; k++) {
00194 means_.set(k, i, (feature_sum[k] / number_of_occurrences));
00195 variances_.set(k, i, (feature_sum_squared[k]
00196 - (feature_sum[k] * feature_sum[k] / number_of_occurrences))
00197 /(number_of_occurrences - 1));
00198 feature_sum[k] = 0;
00199 feature_sum_squared[k] = 0;
00200 }
00201 }
00202 }
00203
00215 void Classify(const Matrix& test_data, Vector *results){
00216
00217
00218
00219 DEBUG_ASSERT(test_data.n_rows() - 1 == means_.n_rows());
00220
00221 ArrayList<double> tmp_vals;
00222 double *evaluated_result;
00223 index_t number_features = test_data.n_rows() - 1;
00224
00225 evaluated_result = (double*)malloc(test_data.n_cols() * sizeof(double));
00226 tmp_vals.Init(number_of_classes_);
00227
00228 NOTIFY("%"LI"d test cases with %"LI"d features each\n",
00229 test_data.n_cols(), number_features);
00230
00231 fx_result_int(nbc_module_,"tests", test_data.n_cols());
00232
00233
00234
00235
00236 for (index_t n = 0; n < test_data.n_cols(); n++) {
00237
00238
00239 for (index_t i = 0; i < number_of_classes_; i++) {
00240
00241 tmp_vals[i] = log(class_probabilities_[i]);
00242 for (index_t j = 0; j < number_features; j++) {
00243 tmp_vals[i] += log(phi(test_data.get(j, n),
00244 means_.get(j, i),
00245 variances_.get(j, i))
00246 );
00247 }
00248 }
00249
00250
00251 evaluated_result[n] = (double) max_element_index(tmp_vals);
00252 }
00253
00254 results->Copy(evaluated_result, test_data.n_cols());
00255
00256 return;
00257 }
00258 };
00259 #endif