simple_nbc.h

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00042 #ifndef NBC_H
00043 #define NBC_H
00044 
00045 #include "fastlib/fastlib.h"
00046 #include "phi.h"
00047 #include "math_functions.h"
00048 
// Documentation entries for the NBC module: the timers recorded during
// training/testing, the required "classes" parameter, and the counts
// reported back as results.  Each row is {name, type, value-type,
// default, description}; the list is terminated by FX_ENTRY_DOC_DONE.
const fx_entry_doc parm_nbc_entries[] ={
  {"training", FX_TIMER, FX_CUSTOM, NULL,
   " The timer to record the training time\n"},
  {"testing", FX_TIMER, FX_CUSTOM, NULL,
   " The timer to record the testing time\n"},
  {"classes", FX_REQUIRED, FX_INT, NULL,
   " The number of classes present in the data\n"},
  {"features", FX_RESULT, FX_INT, NULL,
   " The number of features in the data\n"},
  {"examples", FX_RESULT, FX_INT, NULL,
   " The number of examples in the training set\n"},
  {"tests", FX_RESULT, FX_INT, NULL,
   " The number of data points in the test set\n"},
  FX_ENTRY_DOC_DONE
};
00064 
// The NBC module has no submodules; only the terminator entry is needed.
const fx_submodule_doc parm_nbc_submodules[] = {
  FX_SUBMODULE_DOC_DONE
};
00068 
// Top-level documentation for the NBC module, tying together the entry
// and submodule tables above with a short description.
const fx_module_doc parm_nbc_doc = {
  parm_nbc_entries, parm_nbc_submodules,
  " Trains the classifier using the training set and "
  "outputs the results for the test set\n"
};
00074   
00075 
00106 class SimpleNaiveBayesClassifier {
00107 
00108   // The class for testing this class is made a friend class
00109   friend class TestClassSimpleNBC;
00110 
00111  private:
00112 
00113   // The variables containing the sample mean and variance
00114   // for each of the features with respect to each class
00115   Matrix means_, variances_;
00116 
00117   // The variable containing the class probabilities
00118   ArrayList<double> class_probabilities_;
00119 
00120   // The variable keeping the information about the 
00121   // number of classes present
00122   index_t number_of_classes_;
00123 
00124   datanode *nbc_module_;
00125                    
00126  public:
00127 
00128   SimpleNaiveBayesClassifier(){
00129     means_.Init(0, 0);
00130     variances_.Init(0, 0);
00131     class_probabilities_.Init(0);
00132   }
00133 
00134   ~SimpleNaiveBayesClassifier(){
00135   }
00136 
00149   void InitTrain(const Matrix& data, datanode* nbc_module) {
00150 
00151     ArrayList<double> feature_sum, feature_sum_squared;
00152     index_t number_examples = data.n_cols();
00153     index_t number_features = data.n_rows() - 1;
00154     nbc_module_ = nbc_module;
00155 
00156     // updating the variables, private and local, according to
00157     // the number of features and classes present in the data
00158     number_of_classes_ = fx_param_int_req(nbc_module_,"classes");
00159     class_probabilities_.Resize(number_of_classes_);
00160     means_.Destruct();
00161     means_.Init(number_features, number_of_classes_ );
00162     variances_.Destruct();
00163     variances_.Init(number_features, number_of_classes_);
00164     feature_sum.Init(number_features);
00165     feature_sum_squared.Init(number_features);
00166     for(index_t k = 0; k < number_features; k++) {
00167       feature_sum[k] = 0;
00168       feature_sum_squared[k] = 0;
00169     }
00170     NOTIFY("%"LI"d examples with %"LI"d features each\n",
00171            number_examples, number_features);
00172     fx_result_int(nbc_module_, "features", number_features);
00173     fx_result_int(nbc_module_, "examples", number_examples);
00174 
00175     // calculating the class probabilities as well as the 
00176     // sample mean and variance for each of the features
00177     // with respect to each of the labels
00178     for(index_t i = 0; i < number_of_classes_; i++ ) {
00179       index_t number_of_occurrences = 0;
00180       for (index_t j = 0; j < number_examples; j++) {
00181         index_t flag = (index_t)  data.get(number_features, j);
00182         if(i == flag) {
00183           ++number_of_occurrences;
00184           for(index_t k = 0; k < number_features; k++) {
00185             double tmp = data.get(k, j);
00186             feature_sum[k] += tmp;
00187             feature_sum_squared[k] += tmp*tmp;
00188           }
00189         }
00190       }
00191       class_probabilities_[i] = (double)number_of_occurrences 
00192         / (double)number_examples ;
00193       for(index_t k = 0; k < number_features; k++) {
00194         means_.set(k, i, (feature_sum[k] / number_of_occurrences));
00195         variances_.set(k, i, (feature_sum_squared[k] 
00196                               - (feature_sum[k] * feature_sum[k] / number_of_occurrences))
00197                              /(number_of_occurrences - 1));
00198         feature_sum[k] = 0;
00199         feature_sum_squared[k] = 0;
00200       }
00201     }
00202   }
00203 
00215   void Classify(const Matrix& test_data, Vector *results){
00216 
00217     // Checking that the number of features in the test data is same
00218     // as in the training data
00219     DEBUG_ASSERT(test_data.n_rows() - 1 == means_.n_rows());
00220 
00221     ArrayList<double> tmp_vals;
00222     double *evaluated_result;
00223     index_t number_features = test_data.n_rows() - 1;
00224                         
00225     evaluated_result = (double*)malloc(test_data.n_cols() * sizeof(double));
00226     tmp_vals.Init(number_of_classes_);
00227     
00228     NOTIFY("%"LI"d test cases with %"LI"d features each\n",
00229            test_data.n_cols(), number_features);
00230 
00231     fx_result_int(nbc_module_,"tests", test_data.n_cols());
00232     // Calculating the joint probability for each of the data points
00233     // for each of the classes
00234 
00235     // looping over every test case
00236     for (index_t n = 0; n < test_data.n_cols(); n++) {                  
00237       
00238       //looping over every class
00239       for (index_t i = 0; i < number_of_classes_; i++) {
00240         // Using the log values to prevent floating point underflow
00241         tmp_vals[i] = log(class_probabilities_[i]);
00242         for (index_t j = 0; j < number_features; j++) {
00243           tmp_vals[i] += log(phi(test_data.get(j, n),
00244                                  means_.get(j, i),
00245                                  variances_.get(j, i))
00246                              );   
00247         }
00248       }                 
00249       // Calling a function 'max_element_index' from the file 'math_functions.h
00250       // to obtain the index of the maximum element in an array
00251       evaluated_result[n] = (double) max_element_index(tmp_vals);      
00252     }
00253     // The result is being put in a vector
00254     results->Copy(evaluated_result, test_data.n_cols());
00255     
00256     return;
00257   }
00258 };
00259 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3