00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00090 #ifndef DUALTREE_KDE_CV_H
00091 #define DUALTREE_KDE_CV_H
00092
00093 #define INSIDE_DUALTREE_KDE_CV_H
00094
00095 #include "fastlib/fastlib.h"
00096 #include "mlpack/series_expansion/farfield_expansion.h"
00097 #include "mlpack/series_expansion/local_expansion.h"
00098 #include "mlpack/series_expansion/mult_farfield_expansion.h"
00099 #include "mlpack/series_expansion/mult_local_expansion.h"
00100 #include "mlpack/series_expansion/kernel_aux.h"
00101 #include "gen_metric_tree.h"
00102 #include "dualtree_kde_cv_common.h"
00103 #include "kde_cv_stat.h"
00104
00127 template<typename TKernelAux>
00128 class DualtreeKdeCV {
00129
00130 friend class DualtreeKdeCommon;
00131
00132 friend class DualtreeKdeCVCommon;
00133
00134 public:
00135
00136
00137 typedef GeneralBinarySpaceTree<DBallBound < LMetric<2>, Vector>, Matrix, KdeCVStat<TKernelAux> > Tree;
00138
00139 private:
00140
00142
00146 static const int num_initial_samples_per_query_ = 25;
00147
00148 static const int sample_multiple_ = 1;
00149
00151
00154 struct datanode *module_;
00155
00159 TKernelAux first_ka_;
00160
00164 TKernelAux second_ka_;
00165
00168 Matrix rset_;
00169
00172 Tree *rroot_;
00173
00176 Vector rset_weights_;
00177
00178 double first_sum_l_;
00179
00184 double first_sum_e_;
00185
00186 double first_sum_u_;
00187
00188 double second_sum_l_;
00189
00194 double second_sum_e_;
00195
00196 double second_sum_u_;
00197
00200 double first_mult_const_;
00201
00204 double second_mult_const_;
00205
00206 double first_used_error_;
00207
00208 double second_used_error_;
00209
00210 double n_pruned_;
00211
00214 double rset_weight_sum_;
00215
00219 double relative_error_;
00220
00225 double threshold_;
00226
00229 int num_farfield_to_local_prunes_;
00230
00233 int num_farfield_prunes_;
00234
00237 int num_local_prunes_;
00238
00241 int num_finite_difference_prunes_;
00242
00245 int num_monte_carlo_prunes_;
00246
00250 ArrayList<index_t> old_from_new_references_;
00251
00253
00256 void DualtreeKdeCVBase_(Tree *qnode, Tree *rnode, double probability);
00257
00261 bool PrunableEnhanced_(Tree *qnode, Tree *rnode, double probability,
00262 DRange &dsqd_range, DRange &first_kernel_value_range,
00263 DRange &second_kernel_value_range, double &first_dl,
00264 double &first_de, double &first_du,
00265 double &first_used_error, int &first_order,
00266 double &second_dl, double &second_de,
00267 double &second_du, double &second_used_error,
00268 int &second_order, double &n_pruned);
00269
00270 void EvalUnnormOnSq_(index_t reference_point_index, double squared_distance,
00271 double *first_kernel_value,
00272 double *second_kernel_value);
00273
00284 bool DualtreeKdeCVCanonical_(Tree *qnode, Tree *rnode, double probability);
00285
00290 void PreProcess(Tree *node);
00291
00292 public:
00293
00295
00298 DualtreeKdeCV() {
00299 rroot_ = NULL;
00300 }
00301
00304 ~DualtreeKdeCV() {
00305 delete rroot_;
00306 }
00307
00309
00310
00312
00313 double Compute() {
00314
00315
00316 first_mult_const_ = 1.0 /
00317 (pow(sqrt(2), rset_.n_rows()) *
00318 second_ka_.kernel_.CalcNormConstant(rset_.n_rows()));
00319 second_mult_const_ = 1.0 /
00320 second_ka_.kernel_.CalcNormConstant(rset_.n_rows());
00321
00322
00323 relative_error_ = fx_param_double(module_, "relative_error", 0.1);
00324 threshold_ = fx_param_double(module_, "threshold", 0) *
00325 first_ka_.kernel_.CalcNormConstant(rset_.n_rows());
00326
00327
00328 num_finite_difference_prunes_ = num_monte_carlo_prunes_ =
00329 num_farfield_to_local_prunes_ = num_farfield_prunes_ =
00330 num_local_prunes_ = 0;
00331
00332 printf("\nStarting fast KDE on bandwidth value of %g...\n",
00333 sqrt(second_ka_.kernel_.bandwidth_sq()));
00334 fx_timer_start(NULL, "fast_kde_compute");
00335
00336
00337 first_sum_l_ = first_sum_e_ = 0;
00338 first_sum_u_ = rset_weight_sum_ * rroot_->count();
00339 second_sum_l_ = second_sum_e_ = 0;
00340 second_sum_u_ = rset_weight_sum_ * rroot_->count();
00341 first_used_error_ = second_used_error_ = 0;
00342 n_pruned_ = 0;
00343
00344
00345 PreProcess(rroot_);
00346
00347
00348
00349 double probability = fx_param_double(module_, "probability", 1);
00350 DualtreeKdeCVCanonical_(rroot_, rroot_, probability);
00351 fx_timer_stop(NULL, "fast_kde_compute");
00352 printf("\nFast KDE completed...\n");
00353 printf("Finite difference prunes: %d\n", num_finite_difference_prunes_);
00354 printf("Monte Carlo prunes: %d\n", num_monte_carlo_prunes_);
00355 printf("F2L prunes: %d\n", num_farfield_to_local_prunes_);
00356 printf("F prunes: %d\n", num_farfield_prunes_);
00357 printf("L prunes: %d\n", num_local_prunes_);
00358
00359
00360 first_sum_e_ *= (first_mult_const_ / rset_weight_sum_);
00361 second_sum_e_ *= (second_mult_const_ / rset_weight_sum_);
00362
00363
00364 double lscv_score =
00365 (first_sum_e_ - 2.0 * second_sum_e_ +
00366 2.0 * second_ka_.kernel_.EvalUnnormOnSq(0.0) /
00367 second_ka_.kernel_.CalcNormConstant(rset_.n_rows())) /
00368 ((double) rset_.n_cols());
00369 return lscv_score;
00370 }
00371
00372 void Init(const Matrix &references, const Matrix &rset_weights,
00373 struct datanode *module_in) {
00374
00375
00376 module_ = module_in;
00377
00378
00379 int leaflen = fx_param_int(module_in, "leaflen", 20);
00380
00381
00382
00383 rset_.Copy(references);
00384 rset_weights_.Init(rset_weights.n_cols());
00385 rset_weight_sum_ = 0;
00386 for(index_t i = 0; i < rset_weights.n_cols(); i++) {
00387 rset_weights_[i] = rset_weights.get(0, i);
00388 rset_weight_sum_ += rset_weights_[i];
00389 }
00390
00391
00392
00393
00394 fx_timer_start(NULL, "tree_d");
00395 rroot_ = proximity::MakeGenMetricTree<Tree>(rset_, leaflen,
00396 &old_from_new_references_,
00397 NULL);
00398 DualtreeKdeCommon::ShuffleAccordingToPermutation
00399 (rset_weights_, old_from_new_references_);
00400 fx_timer_stop(NULL, "tree_d");
00401
00402
00403 double bandwidth = fx_param_double_req(module_, "bandwidth");
00404
00405
00406
00407 if(rset_.n_rows() <= 2) {
00408 first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 7),
00409 rset_.n_rows());
00410 second_ka_.Init(bandwidth, fx_param_int(module_, "order", 7),
00411 rset_.n_rows());
00412 }
00413 else if(rset_.n_rows() <= 3) {
00414 first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 5),
00415 rset_.n_rows());
00416 second_ka_.Init(bandwidth, fx_param_int(module_, "order", 5),
00417 rset_.n_rows());
00418 }
00419 else if(rset_.n_rows() <= 5) {
00420 first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 3),
00421 rset_.n_rows());
00422 second_ka_.Init(bandwidth, fx_param_int(module_, "order", 3),
00423 rset_.n_rows());
00424 }
00425 else if(rset_.n_rows() <= 6) {
00426 first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 1),
00427 rset_.n_rows());
00428 second_ka_.Init(bandwidth, fx_param_int(module_, "order", 1),
00429 rset_.n_rows());
00430 }
00431 else {
00432 first_ka_.Init(sqrt(2) * bandwidth, fx_param_int(module_, "order", 0),
00433 rset_.n_rows());
00434 second_ka_.Init(bandwidth, fx_param_int(module_, "order", 0),
00435 rset_.n_rows());
00436 }
00437 }
00438 };
00439
00440 #include "dualtree_kde_cv_impl.h"
00441 #undef INSIDE_DUALTREE_KDE_CV_H
00442
00443 #endif