dualtree_kde_cv.h

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00090 #ifndef DUALTREE_KDE_CV_H
00091 #define DUALTREE_KDE_CV_H
00092 
00093 #define INSIDE_DUALTREE_KDE_CV_H
00094 
00095 #include "fastlib/fastlib.h"
00096 #include "mlpack/series_expansion/farfield_expansion.h"
00097 #include "mlpack/series_expansion/local_expansion.h"
00098 #include "mlpack/series_expansion/mult_farfield_expansion.h"
00099 #include "mlpack/series_expansion/mult_local_expansion.h"
00100 #include "mlpack/series_expansion/kernel_aux.h"
00101 #include "gen_metric_tree.h"
00102 #include "dualtree_kde_cv_common.h"
00103 #include "kde_cv_stat.h"
00104 
00127 template<typename TKernelAux>
00128 class DualtreeKdeCV {
00129 
00130   friend class DualtreeKdeCommon;
00131   
00132   friend class DualtreeKdeCVCommon;
00133 
00134  public:
00135   
00136   // our tree type using the KdeStat
00137   typedef GeneralBinarySpaceTree<DBallBound < LMetric<2>, Vector>, Matrix, KdeCVStat<TKernelAux> > Tree;
00138     
00139  private:
00140 
00142 
00146   static const int num_initial_samples_per_query_ = 25;
00147 
00148   static const int sample_multiple_ = 1;
00149 
00151 
00154   struct datanode *module_;
00155 
00159   TKernelAux first_ka_;
00160 
00164   TKernelAux second_ka_;
00165 
00168   Matrix rset_;
00169   
00172   Tree *rroot_;
00173 
00176   Vector rset_weights_;
00177 
00178   double first_sum_l_;
00179   
00184   double first_sum_e_;
00185 
00186   double first_sum_u_;
00187 
00188   double second_sum_l_;
00189 
00194   double second_sum_e_;
00195 
00196   double second_sum_u_;
00197 
00200   double first_mult_const_;
00201   
00204   double second_mult_const_;
00205 
00206   double first_used_error_;
00207   
00208   double second_used_error_;
00209   
00210   double n_pruned_;
00211 
00214   double rset_weight_sum_;
00215 
00219   double relative_error_;
00220 
00225   double threshold_;
00226   
00229   int num_farfield_to_local_prunes_;
00230 
00233   int num_farfield_prunes_;
00234   
00237   int num_local_prunes_;
00238   
00241   int num_finite_difference_prunes_;
00242 
00245   int num_monte_carlo_prunes_;
00246   
00250   ArrayList<index_t> old_from_new_references_;
00251 
00253 
00256   void DualtreeKdeCVBase_(Tree *qnode, Tree *rnode, double probability);
00257 
00261   bool PrunableEnhanced_(Tree *qnode, Tree *rnode, double probability, 
00262                          DRange &dsqd_range, DRange &first_kernel_value_range,
00263                          DRange &second_kernel_value_range, double &first_dl,
00264                          double &first_de, double &first_du, 
00265                          double &first_used_error, int &first_order,
00266                          double &second_dl, double &second_de,
00267                          double &second_du, double &second_used_error, 
00268                          int &second_order, double &n_pruned);
00269   
00270   void EvalUnnormOnSq_(index_t reference_point_index, double squared_distance,
00271                        double *first_kernel_value,
00272                        double *second_kernel_value);
00273 
00284   bool DualtreeKdeCVCanonical_(Tree *qnode, Tree *rnode, double probability);
00285 
00290   void PreProcess(Tree *node);
00291 
00292  public:
00293 
00295 
00298   DualtreeKdeCV() {
00299     rroot_ = NULL;
00300   }
00301 
00304   ~DualtreeKdeCV() {    
00305     delete rroot_;
00306   }
00307 
00309 
00310 
00312 
00313   double Compute() {
00314 
00315     // Compute normalization constant.
00316     first_mult_const_ = 1.0 / 
00317       (pow(sqrt(2), rset_.n_rows()) * 
00318        second_ka_.kernel_.CalcNormConstant(rset_.n_rows()));
00319     second_mult_const_ = 1.0 /
00320       second_ka_.kernel_.CalcNormConstant(rset_.n_rows());
00321 
00322     // Set accuracy parameters.
00323     relative_error_ = fx_param_double(module_, "relative_error", 0.1);
00324     threshold_ = fx_param_double(module_, "threshold", 0) *
00325       first_ka_.kernel_.CalcNormConstant(rset_.n_rows());
00326 
00327     // Reset prune statistics.
00328     num_finite_difference_prunes_ = num_monte_carlo_prunes_ =
00329       num_farfield_to_local_prunes_ = num_farfield_prunes_ = 
00330       num_local_prunes_ = 0;
00331 
00332     printf("\nStarting fast KDE on bandwidth value of %g...\n",
00333            sqrt(second_ka_.kernel_.bandwidth_sq()));
00334     fx_timer_start(NULL, "fast_kde_compute");
00335 
00336     // Reset the accumulated sum...
00337     first_sum_l_ = first_sum_e_ = 0;
00338     first_sum_u_ = rset_weight_sum_ * rroot_->count();
00339     second_sum_l_ = second_sum_e_ = 0;
00340     second_sum_u_ = rset_weight_sum_ * rroot_->count();
00341     first_used_error_ = second_used_error_ = 0;
00342     n_pruned_ = 0;
00343 
00344     // Preprocessing step for initializing series expansion objects
00345     PreProcess(rroot_);
00346         
00347     // Get the required probability guarantee for each query and call
00348     // the main routine.
00349     double probability = fx_param_double(module_, "probability", 1);
00350     DualtreeKdeCVCanonical_(rroot_, rroot_, probability);
00351     fx_timer_stop(NULL, "fast_kde_compute");
00352     printf("\nFast KDE completed...\n");
00353     printf("Finite difference prunes: %d\n", num_finite_difference_prunes_);
00354     printf("Monte Carlo prunes: %d\n", num_monte_carlo_prunes_);
00355     printf("F2L prunes: %d\n", num_farfield_to_local_prunes_);
00356     printf("F prunes: %d\n", num_farfield_prunes_);
00357     printf("L prunes: %d\n", num_local_prunes_);
00358 
00359     // Normalize accordingly.
00360     first_sum_e_ *= (first_mult_const_ / rset_weight_sum_);
00361     second_sum_e_ *= (second_mult_const_ / rset_weight_sum_);
00362 
00363     // Return the sum of the two sums.
00364     double lscv_score = 
00365       (first_sum_e_ - 2.0 * second_sum_e_ +
00366        2.0 * second_ka_.kernel_.EvalUnnormOnSq(0.0) / 
00367        second_ka_.kernel_.CalcNormConstant(rset_.n_rows())) /
00368       ((double) rset_.n_cols());
00369     return lscv_score;
00370   }
00371 
00372   void Init(const Matrix &references, const Matrix &rset_weights,
00373             struct datanode *module_in) {
00374 
00375     // point to the incoming module
00376     module_ = module_in;
00377 
00378     // Read in the number of points owned by a leaf.
00379     int leaflen = fx_param_int(module_in, "leaflen", 20);
00380     
00381     // Copy reference dataset and reference weights and compute its
00382     // sum.
00383     rset_.Copy(references);
00384     rset_weights_.Init(rset_weights.n_cols());
00385     rset_weight_sum_ = 0;
00386     for(index_t i = 0; i < rset_weights.n_cols(); i++) {
00387       rset_weights_[i] = rset_weights.get(0, i);
00388       rset_weight_sum_ += rset_weights_[i];
00389     }
00390 
00391     // Construct query and reference trees. Shuffle the reference
00392     // weights according to the permutation of the reference set in
00393     // the reference tree.
00394     fx_timer_start(NULL, "tree_d");
00395     rroot_ = proximity::MakeGenMetricTree<Tree>(rset_, leaflen,
00396                                                 &old_from_new_references_, 
00397                                                 NULL);
00398     DualtreeKdeCommon::ShuffleAccordingToPermutation
00399       (rset_weights_, old_from_new_references_);
00400     fx_timer_stop(NULL, "tree_d");
00401 
00402     // Initialize the kernel.
00403     double bandwidth = fx_param_double_req(module_, "bandwidth");
00404 
00405     // Initialize the series expansion object. I should think about
00406     // whether this is true for kernels other than Gaussian.
00407     if(rset_.n_rows() <= 2) {
00408       first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 7), 
00409                      rset_.n_rows());
00410       second_ka_.Init(bandwidth, fx_param_int(module_, "order", 7), 
00411                       rset_.n_rows());
00412     }
00413     else if(rset_.n_rows() <= 3) {
00414       first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 5), 
00415                      rset_.n_rows());
00416       second_ka_.Init(bandwidth, fx_param_int(module_, "order", 5), 
00417                       rset_.n_rows());
00418     }
00419     else if(rset_.n_rows() <= 5) {
00420       first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 3), 
00421                      rset_.n_rows());
00422       second_ka_.Init(bandwidth, fx_param_int(module_, "order", 3), 
00423                       rset_.n_rows());
00424     }
00425     else if(rset_.n_rows() <= 6) {
00426       first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 1), 
00427                      rset_.n_rows());
00428       second_ka_.Init(bandwidth, fx_param_int(module_, "order", 1), 
00429                       rset_.n_rows());
00430     }
00431     else {
00432       first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 0), 
00433                      rset_.n_rows());
00434       second_ka_.Init(bandwidth, fx_param_int(module_, "order", 0), 
00435                       rset_.n_rows()); 
00436     }
00437   }
00438 };
00439 
00440 #include "dualtree_kde_cv_impl.h"
00441 #undef INSIDE_DUALTREE_KDE_CV_H
00442 
00443 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3