dualtree_kde.h

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00093 #ifndef DUALTREE_KDE_H
00094 #define DUALTREE_KDE_H
00095 
00096 #define INSIDE_DUALTREE_KDE_H
00097 
00098 #include "fastlib/fastlib.h"
00099 #include "mlpack/series_expansion/farfield_expansion.h"
00100 #include "mlpack/series_expansion/local_expansion.h"
00101 #include "mlpack/series_expansion/mult_farfield_expansion.h"
00102 #include "mlpack/series_expansion/mult_local_expansion.h"
00103 #include "mlpack/series_expansion/kernel_aux.h"
00104 #include "gen_metric_tree.h"
00105 #include "dualtree_kde_common.h"
00106 #include "kde_stat.h"
00107 
00109 const fx_entry_doc kde_main_entries[] = {
00110   {"data", FX_REQUIRED, FX_STR, NULL,
00111    "  A file containing reference data.\n"},
00112   {"query", FX_PARAM, FX_STR, NULL,
00113    "  A file containing query data (defaults to data).\n"},
00114   FX_ENTRY_DOC_DONE
00115 };
00116 
00117 const fx_entry_doc kde_entries[] = {
00118   {"bandwidth", FX_PARAM, FX_DOUBLE, NULL,
00119    "  The bandwidth parameter.\n"},
00120   {"do_naive", FX_PARAM, FX_BOOL, NULL,
00121    "  Whether to perform naive computation as well.\n"},
00122   {"dwgts", FX_PARAM, FX_STR, NULL,
00123    "  A file that contains the weight of each point. If missing, will\
00124  assume uniform weight\n"},
00125   {"fast_kde_output", FX_PARAM, FX_STR, NULL,
00126    "  A file to receive the results of computation.\n"},
00127   {"kernel", FX_PARAM, FX_STR, NULL,
00128    "  The type of kernel to use.\n"},
00129   {"knn", FX_PARAM, FX_INT, NULL,
00130    "  The number of k-nearest neighbor to use for variable bandwidth.\n"},
00131   {"loo", FX_PARAM, FX_BOOL, NULL,
00132    "  Whether to output the density estimates using leave-one-out.\n"},
00133   {"mode", FX_PARAM, FX_STR, NULL,
00134    "  Fixed bandwidth or variable bandwidth mode.\n"},
00135   {"multiplicative_expansion", FX_PARAM, FX_BOOL, NULL,
00136    "  Whether to do O(p^D) kernel expansion instead of O(D^p).\n"},
00137   {"probability", FX_PARAM, FX_DOUBLE, NULL,
00138    "  The probability guarantee that the relative error accuracy holds.\n"},
00139   {"relative_error", FX_PARAM, FX_DOUBLE, NULL,
00140    "  The required relative error accuracy.\n"},
00141   {"threshold", FX_PARAM, FX_DOUBLE, NULL,
00142    "  If less than this value, then absolute error bound.\n"},
00143   {"scaling", FX_PARAM, FX_STR, NULL,
00144    "  The scaling option.\n"},
00145   FX_ENTRY_DOC_DONE
00146 };
00147 
00148 const fx_module_doc kde_doc = {
00149   kde_entries, NULL,
00150   "Performs dual-tree kernel density estimate computation.\n"
00151 };
00152 
00153 const fx_submodule_doc kde_main_submodules[] = {
00154   {"kde", &kde_doc,
00155    "  Responsible for dual-tree kernel density estimate computation.\n"},
00156   FX_SUBMODULE_DOC_DONE
00157 };
00158 
00159 const fx_module_doc kde_main_doc = {
00160   kde_main_entries, kde_main_submodules,
00161   "This is the driver for the kernel density estimator.\n"
00162 };
00163 
00164 
00165 
00189 template<typename TKernelAux>
00190 class DualtreeKde {
00191 
00192   friend class DualtreeKdeCommon;
00193 
00194  public:
00195   
00196   // our tree type using the KdeStat
00197   typedef GeneralBinarySpaceTree<DBallBound < LMetric<2>, Vector>, Matrix, KdeStat<TKernelAux> > Tree;
00198     
00199  private:
00200 
00202 
00206   static const int num_initial_samples_per_query_ = 25;
00207 
00208   static const int sample_multiple_ = 1;
00209 
00211 
00214   struct datanode *module_;
00215 
00218   bool leave_one_out_;
00219 
00222   double mult_const_;
00223 
00226   TKernelAux ka_;
00227 
00230   Matrix qset_;
00231 
00234   Tree *qroot_;
00235 
00238   Matrix rset_;
00239   
00242   Tree *rroot_;
00243 
00246   Vector rset_weights_;
00247 
00250   Vector densities_l_;
00251 
00254   Vector densities_e_;
00255 
00258   Vector densities_u_;
00259 
00262   Vector used_error_;
00263 
00267   Vector n_pruned_;
00268 
00271   double rset_weight_sum_;
00272 
00276   double relative_error_;
00277 
00282   double threshold_;
00283   
00286   int num_farfield_to_local_prunes_;
00287 
00290   int num_farfield_prunes_;
00291   
00294   int num_local_prunes_;
00295   
00298   int num_finite_difference_prunes_;
00299 
00302   int num_monte_carlo_prunes_;
00303 
00307   ArrayList<index_t> old_from_new_queries_;
00308   
00312   ArrayList<index_t> old_from_new_references_;
00313 
00315 
00316   void RefineBoundStatistics_(Tree *destination);
00317 
00320   void DualtreeKdeBase_(Tree *qnode, Tree *rnode, double probability);
00321 
00325   bool PrunableEnhanced_(Tree *qnode, Tree *rnode, double probability,
00326                          DRange &dsqd_range, DRange &kernel_value_range, 
00327                          double &dl, double &du,
00328                          double &used_error, double &n_pruned,
00329                          int &order_farfield_to_local,
00330                          int &order_farfield, int &order_local);
00331   
00332   double EvalUnnormOnSq_(index_t reference_point_index,
00333                          double squared_distance);
00334 
00345   bool DualtreeKdeCanonical_(Tree *qnode, Tree *rnode, double probability);
00346 
00351   void PreProcess(Tree *node);
00352 
00355   void PostProcess(Tree *qnode);
00356 
00357  public:
00358 
00360 
00363   DualtreeKde() {
00364     qroot_ = rroot_ = NULL;
00365   }
00366 
00369   ~DualtreeKde() { 
00370     
00371     if(qroot_ != rroot_ ) {
00372       delete qroot_; 
00373       delete rroot_; 
00374     } 
00375     else {
00376       delete rroot_;
00377     }
00378 
00379   }
00380 
00382 
00385   void get_density_estimates(Vector *results) { 
00386     results->Init(densities_e_.length());
00387     
00388     for(index_t i = 0; i < densities_e_.length(); i++) {
00389       (*results)[i] = densities_e_[i];
00390     }
00391   }
00392 
00394 
00395   void Compute(Vector *results) {
00396 
00397     // compute normalization constant
00398     mult_const_ = 1.0 / ka_.kernel_.CalcNormConstant(qset_.n_rows());
00399 
00400     // Set accuracy parameters.
00401     relative_error_ = fx_param_double(module_, "relative_error", 0.1);
00402     threshold_ = fx_param_double(module_, "threshold", 0) *
00403       ka_.kernel_.CalcNormConstant(qset_.n_rows());
00404     
00405     // initialize the lower and upper bound densities
00406     densities_l_.SetZero();
00407     densities_e_.SetZero();
00408     densities_u_.SetAll(rset_weight_sum_);
00409 
00410     // Set zero for error accounting stuff.
00411     used_error_.SetZero();
00412     n_pruned_.SetZero();
00413 
00414     // Reset prune statistics.
00415     num_finite_difference_prunes_ = num_monte_carlo_prunes_ =
00416       num_farfield_to_local_prunes_ = num_farfield_prunes_ = 
00417       num_local_prunes_ = 0;
00418 
00419     printf("\nStarting fast KDE on bandwidth value of %g...\n",
00420            sqrt(ka_.kernel_.bandwidth_sq()));
00421     fx_timer_start(NULL, "fast_kde_compute");
00422 
00423     // Preprocessing step for initializing series expansion objects
00424     PreProcess(rroot_);
00425     if(qroot_ != rroot_) {
00426       PreProcess(qroot_);
00427     }
00428     
00429     // Get the required probability guarantee for each query and call
00430     // the main routine.
00431     double probability = fx_param_double(module_, "probability", 1);
00432     DualtreeKdeCanonical_(qroot_, rroot_, probability);
00433 
00434     // Postprocessing step for finalizing the sums.
00435     PostProcess(qroot_);
00436     fx_timer_stop(NULL, "fast_kde_compute");
00437     printf("\nFast KDE completed...\n");
00438     printf("Finite difference prunes: %d\n", num_finite_difference_prunes_);
00439     printf("Monte Carlo prunes: %d\n", num_monte_carlo_prunes_);
00440     printf("F2L prunes: %d\n", num_farfield_to_local_prunes_);
00441     printf("F prunes: %d\n", num_farfield_prunes_);
00442     printf("L prunes: %d\n", num_local_prunes_);
00443 
00444     // Reshuffle the results to account for dataset reshuffling
00445     // resulted from tree constructions.
00446     Vector tmp_q_results;
00447     tmp_q_results.Init(densities_e_.length());
00448     
00449     for(index_t i = 0; i < tmp_q_results.length(); i++) {
00450       tmp_q_results[old_from_new_queries_[i]] =
00451         densities_e_[i];
00452     }
00453     for(index_t i = 0; i < tmp_q_results.length(); i++) {
00454       densities_e_[i] = tmp_q_results[i];
00455     }
00456 
00457     // Retrieve density estimates.
00458     get_density_estimates(results);
00459   }
00460 
00461   void Init(const Matrix &queries, const Matrix &references,
00462             const Matrix &rset_weights, bool queries_equal_references, 
00463             struct datanode *module_in) {
00464 
00465     // point to the incoming module
00466     module_ = module_in;
00467 
00468     // Set the flag for whether to perform leave-one-out computation.
00469     leave_one_out_ = fx_param_exists(module_in, "loo") &&
00470       (queries.ptr() == references.ptr());
00471 
00472     // Read in the number of points owned by a leaf.
00473     int leaflen = fx_param_int(module_in, "leaflen", 20);
00474     
00475 
00476     // Copy reference dataset and reference weights and compute its
00477     // sum.
00478     rset_.Copy(references);
00479     rset_weights_.Init(rset_weights.n_cols());
00480     rset_weight_sum_ = 0;
00481     for(index_t i = 0; i < rset_weights.n_cols(); i++) {
00482       rset_weights_[i] = rset_weights.get(0, i);
00483       rset_weight_sum_ += rset_weights_[i];
00484     }
00485 
00486     // Copy query dataset.
00487     if(queries_equal_references) {
00488       qset_.Alias(rset_);
00489     }
00490     else {
00491       qset_.Copy(queries);
00492     }
00493 
00494     // Construct query and reference trees. Shuffle the reference
00495     // weights according to the permutation of the reference set in
00496     // the reference tree.
00497     fx_timer_start(NULL, "tree_d");
00498     rroot_ = proximity::MakeGenMetricTree<Tree>(rset_, leaflen,
00499                                                 &old_from_new_references_, 
00500                                                 NULL);
00501     DualtreeKdeCommon::ShuffleAccordingToPermutation
00502       (rset_weights_, old_from_new_references_);
00503 
00504     if(queries_equal_references) {
00505       qroot_ = rroot_;
00506       old_from_new_queries_.InitCopy(old_from_new_references_);
00507     }
00508     else {
00509       qroot_ = proximity::MakeGenMetricTree<Tree>(qset_, leaflen,
00510                                                   &old_from_new_queries_, 
00511                                                   NULL);
00512     }
00513     fx_timer_stop(NULL, "tree_d");
00514     
00515     // Initialize the density lists
00516     densities_l_.Init(qset_.n_cols());
00517     densities_e_.Init(qset_.n_cols());
00518     densities_u_.Init(qset_.n_cols());
00519 
00520     // Initialize the error accounting stuff.
00521     used_error_.Init(qset_.n_cols());
00522     n_pruned_.Init(qset_.n_cols());
00523 
00524     // Initialize the kernel.
00525     double bandwidth = fx_param_double_req(module_, "bandwidth");
00526 
00527     // initialize the series expansion object
00528     if(qset_.n_rows() <= 2) {
00529       ka_.Init(bandwidth, fx_param_int(module_, "order", 7), qset_.n_rows());
00530     }
00531     else if(qset_.n_rows() <= 3) {
00532       ka_.Init(bandwidth, fx_param_int(module_, "order", 5), qset_.n_rows());
00533     }
00534     else if(qset_.n_rows() <= 5) {
00535       ka_.Init(bandwidth, fx_param_int(module_, "order", 3), qset_.n_rows());
00536     }
00537     else if(qset_.n_rows() <= 6) {
00538       ka_.Init(bandwidth, fx_param_int(module_, "order", 1), qset_.n_rows());
00539     }
00540     else {
00541       ka_.Init(bandwidth, fx_param_int(module_, "order", 0), qset_.n_rows());
00542     }
00543   }
00544 
00545   void PrintDebug() {
00546 
00547     FILE *stream = stdout;
00548     const char *fname = NULL;
00549 
00550     if((fname = fx_param_str(module_, "fast_kde_output", 
00551                              "fast_kde_output.txt")) != NULL) {
00552       stream = fopen(fname, "w+");
00553     }
00554     for(index_t q = 0; q < qset_.n_cols(); q++) {
00555       fprintf(stream, "%g\n", densities_e_[q]);
00556     }
00557     
00558     if(stream != stdout) {
00559       fclose(stream);
00560     }
00561   }
00562 
00563 };
00564 
00565 #include "dualtree_kde_impl.h"
00566 #undef INSIDE_DUALTREE_KDE_H
00567 
00568 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3