bandwidth_lscv.h

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00039 #ifndef BANDWIDTH_LSCV_H
00040 #define BANDWIDTH_LSCV_H
00041 
00042 #include "dualtree_kde.h"
00043 #include "dualtree_kde_cv.h"
00044 
00045 class BandwidthLSCV {
00046   
00047  private:
00048 
00049   static double plugin_bandwidth_(const Matrix &references) {
00050     
00051     double avg_sdev = 0;
00052     int num_dims = references.n_rows();
00053     int num_data = references.n_cols();
00054     Vector mean_vector;
00055     mean_vector.Init(references.n_rows());
00056     mean_vector.SetZero();
00057     
00058     // First compute the mean vector.
00059     for(index_t i = 0; i < references.n_cols(); i++) {
00060       for(index_t j = 0; j < references.n_rows(); j++) {
00061         mean_vector[j] += references.get(j, i);
00062       }
00063     }
00064     la::Scale(1.0 / ((double) num_data), &mean_vector);
00065     
00066     // Loop over the dataset again and compute variance along each
00067     // dimension.
00068     for(index_t j = 0; j < num_dims; j++) {
00069       double sdev = 0;
00070       for(index_t i = 0; i < num_data; i++) {
00071         sdev += math::Sqr(references.get(j, i) - mean_vector[j]);
00072       }
00073       sdev /= ((double) num_data - 1);
00074       sdev = sqrt(sdev);
00075       avg_sdev += sdev;
00076     }
00077     avg_sdev /= ((double) num_dims);
00078 
00079     double plugin_bw = 
00080       pow((4.0 / (num_dims + 2.0)), 1.0 / (num_dims + 4.0)) * avg_sdev * 
00081       pow(num_data, -1.0 / (num_dims + 4.0));
00082 
00083     return plugin_bw;
00084   }
00085   
00086  public:
00087 
00088   template<typename TKernelAux>
00089   static void ComputeLSCVScore(const Matrix &references,
00090                                const Matrix &reference_weights,
00091                                double bandwidth) {
00092 
00093     // Get the parameters.
00094     struct datanode *kde_module = fx_submodule(fx_root, "kde");
00095 
00096     // Kernel object.
00097     typename TKernelAux::TKernel kernel;
00098     
00099     // LSCV score.
00100     double lscv_score;
00101 
00102     // Set the bandwidth of the kernel.
00103     kernel.Init(bandwidth);
00104     
00105     printf("Trying the bandwidth value of %g...\n", bandwidth);
00106     
00107     // Need to run density estimates twice: on $h$ and $sqrt(2)
00108     // h$. Free memory after each run to minimize memory usage.
00109     fx_set_param_double(kde_module, "bandwidth", bandwidth);
00110     DualtreeKdeCV<TKernelAux> *fast_kde_on_bandwidth = 
00111       new DualtreeKdeCV<TKernelAux>();
00112     fast_kde_on_bandwidth->Init(references, reference_weights, kde_module);
00113     lscv_score = fast_kde_on_bandwidth->Compute();
00114     delete fast_kde_on_bandwidth;
00115     
00116     printf("Least squares cross-validation score is %g...\n\n", lscv_score); 
00117   }
00118 
00119   template<typename TKernelAux>
00120   static void Optimize(const Matrix &references,
00121                        const Matrix &reference_weights) {
00122     
00123     // Get the parameters.
00124     struct datanode *kde_module = fx_submodule(fx_root, "kde");
00125 
00126     // Minimum LSCV score so far.
00127     double min_lscv_score = DBL_MAX;
00128 
00129     // Kernel object.
00130     typename TKernelAux::TKernel kernel;
00131 
00132     // The current lower and upper search limit.
00133     double plugin_bandwidth = plugin_bandwidth_(references);
00134     double current_lower_search_limit = plugin_bandwidth * 0.00001;
00135     double current_upper_search_limit = plugin_bandwidth;
00136     double min_bandwidth = DBL_MAX;
00137 
00138     printf("Searching the optimal bandwidth in [%g %g]...\n",
00139            current_lower_search_limit, current_upper_search_limit);
00140 
00141     do {
00142       
00143       // Set bandwidth to the middle of the lower and the upper limit
00144       // and initialize the kernel.
00145       double bandwidth = current_upper_search_limit;
00146       kernel.Init(bandwidth);
00147 
00148       printf("Trying the bandwidth value of %g...\n", bandwidth);
00149 
00150       // Need to run density estimates twice: on $h$ and $sqrt(2)
00151       // h$. Free memory after each run to minimize memory usage.
00152       fx_set_param_double(kde_module, "bandwidth", bandwidth);
00153       DualtreeKdeCV<TKernelAux> *fast_kde_on_bandwidth =
00154         new DualtreeKdeCV<TKernelAux>();
00155       fast_kde_on_bandwidth->Init(references, reference_weights, kde_module);
00156       double lscv_score = fast_kde_on_bandwidth->Compute();
00157       delete fast_kde_on_bandwidth;
00158 
00159       printf("Least squares cross-validation score is %g...\n\n", lscv_score);
00160 
00161       if(lscv_score < min_lscv_score) {
00162         min_lscv_score = lscv_score;
00163         min_bandwidth = bandwidth;
00164       }
00165       current_upper_search_limit /= 2.0;
00166       
00167     } while(current_upper_search_limit > current_lower_search_limit);
00168 
00169 
00170     // Output the final density estimates that minimize the least
00171     // squares cross-validation to the file.
00172     printf("Minimum score was %g and achieved at the bandwidth value of %g\n", 
00173            min_lscv_score, min_bandwidth);
00174 
00175   }
00176 
00177 };
00178 
00179 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by doxygen 1.6.3