kde_bandwidth_cv_main.cc

Go to the documentation of this file.
/* MLPACK 0.2
 *
 * Copyright (c) 2008, 2009 Alexander Gray,
 *                          Garry Boyer,
 *                          Ryan Riegel,
 *                          Nikolaos Vasiloglou,
 *                          Dongryeol Lee,
 *                          Chip Mappus,
 *                          Nishant Mehta,
 *                          Hua Ouyang,
 *                          Parikshit Ram,
 *                          Long Tran,
 *                          Wee Chin Wong
 *
 * Copyright (c) 2008, 2009 Georgia Institute of Technology
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 */
#include <cstdio>

#include "fastlib/fastlib.h"
#include "bandwidth_lscv.h"
#include "dataset_scaler.h"
#include "dualtree_kde.h"
#include "naive_kde.h"
00044 
00110 int main(int argc, char *argv[]) {
00111 
00112   // initialize FastExec (parameter handling stuff)
00113   fx_init(argc, argv, &kde_main_doc);
00114 
00116 
00117   // FASTexec organizes parameters and results into submodules.  Think
00118   // of this as creating a new folder named "kde_module" under the
00119   // root directory (NULL) for the Kde object to work inside.  Here,
00120   // we initialize it with all parameters defined "--kde/...=...".
00121   struct datanode* kde_module = fx_submodule(fx_root, "kde");
00122 
00123   // The reference data file is a required parameter.
00124   const char* references_file_name = fx_param_str_req(fx_root, "data");
00125   
00126   // The query data file defaults to the references.
00127   const char* queries_file_name =
00128     fx_param_str(fx_root, "query", references_file_name);
00129 
00130   // Query and reference datasets, reference weight dataset.
00131   Matrix references;
00132   Matrix reference_weights;
00133   Matrix queries;
00134 
00135   // Flag for telling whether references are equal to queries
00136   bool queries_equal_references = 
00137     !strcmp(queries_file_name, references_file_name);
00138 
00139   // data::Load inits a matrix with the contents of a .csv or .arff.
00140   data::Load(references_file_name, &references);  
00141   if(queries_equal_references) {
00142     queries.Alias(references);
00143   }
00144   else {
00145     data::Load(queries_file_name, &queries);
00146   }
00147   
00148   // If the reference weight file name is specified, then read in,
00149   // otherwise, initialize to uniform weights.
00150   if(fx_param_exists(fx_root, "dwgts")) {
00151     data::Load(fx_param_str(fx_root, "dwgts", NULL), &reference_weights);
00152   }
00153   else {
00154     reference_weights.Init(1, queries.n_cols());
00155     reference_weights.SetAll(1);
00156   }
00157 
00158   // Confirm whether the user asked for scaling of the dataset.
00159   if(!strcmp(fx_param_str(kde_module, "scaling", "none"), "range")) {
00160     DatasetScaler::ScaleDataByMinMax(queries, references,
00161                                      queries_equal_references);
00162   }
00163   else if(!strcmp(fx_param_str(kde_module, "scaling", "none"), 
00164                   "standardize")) {
00165     DatasetScaler::StandardizeData(queries, references, 
00166                                    queries_equal_references);
00167   }
00168 
00169   // There are two options: 1) do bandwidth optimization 2) output a
00170   // goodness score of a given bandwidth.
00171   if(!strcmp(fx_param_str(kde_module, "task", "optimize"), "optimize")) {
00172 
00173     // Optimize bandwidth using least squares cross-validation.
00174     if(!strcmp(fx_param_str(kde_module, "kernel", "gaussian"), "gaussian")) {
00175       BandwidthLSCV::Optimize<GaussianKernelAux>(references, 
00176                                                  reference_weights);
00177     }
00178     else if(!strcmp(fx_param_str(kde_module, "kernel", "epan"), "epan")) {
00179       
00180       // Currently, I have not implemented the direct way to
00181       // cross-validate for the optimal bandwidth using the Epanechnikov
00182       // kernel, so I will use cross-validation using the Gaussian
00183       // kernel with the equivalent kernel scaling.
00184       BandwidthLSCV::Optimize<GaussianKernelAux>(references, 
00185                                                  reference_weights);
00186     }
00187   }
00188   else if(!strcmp(fx_param_str(kde_module, "task", "lscvscore"), 
00189                   "lscvscore")) {
00190     
00191     // Get the bandwidth.
00192     double bandwidth = fx_param_double(kde_module, "bandwidth", 0.1);
00193 
00194     // Optimize bandwidth using least squares cross-validation.
00195     if(!strcmp(fx_param_str(kde_module, "kernel", "gaussian"), "gaussian")) {
00196       BandwidthLSCV::ComputeLSCVScore<GaussianKernelAux>
00197         (references, reference_weights, bandwidth);
00198     }
00199     else if(!strcmp(fx_param_str(kde_module, "kernel", "epan"), "epan")) {
00200       
00201       // Currently, I have not implemented the direct way to
00202       // cross-validate for the optimal bandwidth using the Epanechnikov
00203       // kernel, so I will use cross-validation using the Gaussian
00204       // kernel with the equivalent kernel scaling.
00205       BandwidthLSCV::ComputeLSCVScore<GaussianKernelAux>
00206         (references, reference_weights, bandwidth);
00207     }    
00208   }
00209 
00210   fx_done(fx_root);
00211   return 0;
00212 }
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3