dualtree_vkde.h

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00089 #ifndef DUALTREE_VKDE_H
00090 #define DUALTREE_VKDE_H
00091 
00092 #define INSIDE_DUALTREE_VKDE_H
00093 
00094 #include "fastlib/fastlib.h"
00095 #include "gen_metric_tree.h"
00096 #include "dualtree_kde_common.h"
00097 #include "kde_stat.h"
00098 #include "mlpack/allknn/allknn.h"
00099 
00123 template<typename TKernel>
00124 class DualtreeVKde {
00125   
00126   friend class DualtreeKdeCommon;
00127 
00128  public:
00129   
00130   // our tree type using the VKdeStat
00131   typedef GeneralBinarySpaceTree<DBallBound < LMetric<2>, Vector>, Matrix, VKdeStat<TKernel> > Tree;
00132     
00133  private:
00134 
00136 
00140   static const int num_initial_samples_per_query_ = 25;
00141 
00142   static const int sample_multiple_ = 10;
00143 
00145 
00148   struct datanode *module_;
00149 
00152   bool leave_one_out_;
00153 
00156   double mult_const_;
00157 
00160   ArrayList<TKernel> kernels_;
00161 
00164   Matrix qset_;
00165 
00168   Tree *qroot_;
00169 
00172   Matrix rset_;
00173   
00176   Tree *rroot_;
00177 
00180   Vector rset_weights_;
00181 
00184   Vector densities_l_;
00185 
00188   Vector densities_e_;
00189 
00192   Vector densities_u_;
00193 
00196   Vector used_error_;
00197 
00201   Vector n_pruned_;
00202 
00205   double rset_weight_sum_;
00206 
00210   double relative_error_;
00211 
00216   double threshold_;
00217   
00220   int num_finite_difference_prunes_;
00221 
00224   int num_monte_carlo_prunes_;
00225 
00229   ArrayList<index_t> old_from_new_queries_;
00230   
00234   ArrayList<index_t> old_from_new_references_;
00235 
00237 
00240   void DualtreeVKdeBase_(Tree *qnode, Tree *rnode, double probability);
00241 
00245   bool PrunableEnhanced_(Tree *qnode, Tree *rnode, double probability,
00246                          DRange &dsqd_range, DRange &kernel_value_range, 
00247                          double &dl, double &du,
00248                          double &used_error, double &n_pruned,
00249                          int &order_farfield_to_local,
00250                          int &order_farfield, int &order_local);
00251 
00252   double EvalUnnormOnSq_(index_t reference_point_index,
00253                          double squared_distance);
00254 
00265   bool DualtreeVKdeCanonical_(Tree *qnode, Tree *rnode, double probability);
00266 
00271   void PreProcess(Tree *node, bool reference_side);
00272 
00275   void PostProcess(Tree *qnode);
00276     
00277  public:
00278 
00280 
00283   DualtreeVKde() {
00284     qroot_ = rroot_ = NULL;
00285   }
00286 
00289   ~DualtreeVKde() { 
00290     
00291     if(qroot_ != rroot_ ) {
00292       delete qroot_; 
00293       delete rroot_; 
00294     } 
00295     else {
00296       delete rroot_;
00297     }
00298 
00299   }
00300 
00302 
00305   void get_density_estimates(Vector *results) { 
00306     results->Init(densities_e_.length());
00307     
00308     for(index_t i = 0; i < densities_e_.length(); i++) {
00309       (*results)[i] = densities_e_[i];
00310     }
00311   }
00312 
00314 
00315   void Compute(Vector *results) {
00316 
00317     // Set accuracy parameters.
00318     relative_error_ = fx_param_double(module_, "relative_error", 0.1);
00319     threshold_ = fx_param_double(module_, "threshold", 0) *
00320       kernels_[0].CalcNormConstant(qset_.n_rows());
00321     
00322     // initialize the lower and upper bound densities
00323     densities_l_.SetZero();
00324     densities_e_.SetZero();
00325     densities_u_.SetAll(rset_weight_sum_);
00326 
00327     // Set zero for error accounting stuff.
00328     used_error_.SetZero();
00329     n_pruned_.SetZero();
00330 
00331     // Reset prune statistics.
00332     num_finite_difference_prunes_ = num_monte_carlo_prunes_ = 0;
00333 
00334     printf("\nStarting variable KDE using %d neighbors...\n",
00335            (int) fx_param_int_req(module_, "knn"));
00336 
00337     fx_timer_start(NULL, "fast_kde_compute");
00338 
00339     // Preprocessing step for initializing series expansion objects
00340     PreProcess(rroot_, true);
00341     if(qroot_ != rroot_) {
00342       PreProcess(qroot_, false);
00343     }
00344     
00345     // Get the required probability guarantee for each query and call
00346     // the main routine.
00347     double probability = fx_param_double(module_, "probability", 1);
00348     DualtreeVKdeCanonical_(qroot_, rroot_, probability);
00349 
00350     // Postprocessing step for finalizing the sums.
00351     PostProcess(qroot_);
00352     fx_timer_stop(NULL, "fast_kde_compute");
00353     printf("\nFast KDE completed...\n");
00354     printf("Finite difference prunes: %d\n", num_finite_difference_prunes_);
00355     printf("Monte Carlo prunes: %d\n", num_monte_carlo_prunes_);
00356 
00357     // Reshuffle the results to account for dataset reshuffling
00358     // resulted from tree constructions.
00359     Vector tmp_q_results;
00360     tmp_q_results.Init(densities_e_.length());
00361     
00362     for(index_t i = 0; i < tmp_q_results.length(); i++) {
00363       tmp_q_results[old_from_new_queries_[i]] =
00364         densities_e_[i];
00365     }
00366     for(index_t i = 0; i < tmp_q_results.length(); i++) {
00367       densities_e_[i] = tmp_q_results[i];
00368     }
00369 
00370     // Retrieve density estimates.
00371     get_density_estimates(results);
00372   }
00373 
00374   void Init(const Matrix &queries, const Matrix &references,
00375             const Matrix &rset_weights, bool queries_equal_references, 
00376             struct datanode *module_in) {
00377 
00378     // point to the incoming module
00379     module_ = module_in;
00380 
00381     // Set the flag for whether to perform leave-one-out computation.
00382     leave_one_out_ = fx_param_exists(module_in, "loo") &&
00383       (queries.ptr() == references.ptr());
00384 
00385     // read in the number of points owned by a leaf
00386     int leaflen = fx_param_int(module_in, "leaflen", 20);
00387 
00388     // Copy reference dataset and reference weights and compute its
00389     // sum. rset_weight_sum_ should be the raw sum of the reference
00390     // weights, ignoring the possibly different normalizing constants
00391     // in the case of variable-bandwidth case.
00392     rset_.Copy(references);
00393     rset_weights_.Init(rset_weights.n_cols());
00394     rset_weight_sum_ = 0;
00395     for(index_t i = 0; i < rset_weights.n_cols(); i++) {
00396       rset_weights_[i] = rset_weights.get(0, i);
00397       rset_weight_sum_ += rset_weights_[i];
00398     }
00399 
00400     // Copy the query dataset.
00401     if(queries_equal_references) {
00402       qset_.Alias(rset_);
00403     }
00404     else {
00405       qset_.Copy(queries);
00406     }
00407 
00408     // Construct query and reference trees. Shuffle the reference
00409     // weights according to the permutation of the reference set in
00410     // the reference tree.
00411     fx_timer_start(NULL, "tree_d");
00412     rroot_ = proximity::MakeGenMetricTree<Tree>(rset_, leaflen,
00413                                                 &old_from_new_references_, 
00414                                                 NULL);
00415     DualtreeKdeCommon::ShuffleAccordingToPermutation
00416       (rset_weights_, old_from_new_references_);
00417 
00418     if(queries_equal_references) {
00419       qroot_ = rroot_;
00420       old_from_new_queries_.InitCopy(old_from_new_references_);
00421     }
00422     else {
00423       qroot_ = proximity::MakeGenMetricTree<Tree>(qset_, leaflen,
00424                                                   &old_from_new_queries_, 
00425                                                   NULL);
00426     }
00427     fx_timer_stop(NULL, "tree_d");
00428     
00429     // Initialize the density lists
00430     densities_l_.Init(qset_.n_cols());
00431     densities_e_.Init(qset_.n_cols());
00432     densities_u_.Init(qset_.n_cols());
00433 
00434     // Initialize the error accounting stuff.
00435     used_error_.Init(qset_.n_cols());
00436     n_pruned_.Init(qset_.n_cols());
00437 
00438     // Initialize the kernels for each reference point.
00439     int knns = fx_param_int_req(module_, "knn");
00440     AllkNN all_knn;
00441     kernels_.Init(rset_.n_cols());
00442     all_knn.Init(rset_, 20, knns);
00443     ArrayList<index_t> resulting_neighbors;
00444     ArrayList<double> squared_distances;    
00445 
00446     fx_timer_start(fx_root, "bandwidth_initialization");
00447     all_knn.ComputeNeighbors(&resulting_neighbors, &squared_distances);
00448 
00449     for(index_t i = 0; i < squared_distances.size(); i += knns) {
00450       kernels_[i / knns].Init(sqrt(squared_distances[i + knns - 1]));
00451     }
00452     fx_timer_stop(fx_root, "bandwidth_initialization");
00453 
00454     // Renormalize the reference weights according to the bandwidths
00455     // that have been chosen.
00456     double min_norm_const = DBL_MAX;
00457     for(index_t i = 0; i < rset_weights_.length(); i++) {
00458       double norm_const = kernels_[i].CalcNormConstant(qset_.n_rows());
00459       min_norm_const = std::min(min_norm_const, norm_const);
00460     }
00461     for(index_t i = 0; i < rset_weights_.length(); i++) {
00462       double norm_const = kernels_[i].CalcNormConstant(qset_.n_rows());
00463       rset_weights_[i] *= (min_norm_const / norm_const);
00464     }
00465 
00466     // Compute normalization constant.
00467     mult_const_ = 1.0 / min_norm_const;
00468   }
00469 
00470   void PrintDebug() {
00471 
00472     FILE *stream = stdout;
00473     const char *fname = NULL;
00474 
00475     if((fname = fx_param_str(module_, "fast_kde_output", 
00476                              "fast_kde_output.txt")) != NULL) {
00477       stream = fopen(fname, "w+");
00478     }
00479     for(index_t q = 0; q < qset_.n_cols(); q++) {
00480       fprintf(stream, "%g\n", densities_e_[q]);
00481     }
00482     
00483     if(stream != stdout) {
00484       fclose(stream);
00485     }
00486   }
00487 
00488 };
00489 
00490 #include "dualtree_vkde_impl.h"
00491 #undef INSIDE_DUALTREE_VKDE_H
00492 
00493 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3