00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00093 #ifndef DUALTREE_KDE_H
00094 #define DUALTREE_KDE_H
00095
00096 #define INSIDE_DUALTREE_KDE_H
00097
00098 #include "fastlib/fastlib.h"
00099 #include "mlpack/series_expansion/farfield_expansion.h"
00100 #include "mlpack/series_expansion/local_expansion.h"
00101 #include "mlpack/series_expansion/mult_farfield_expansion.h"
00102 #include "mlpack/series_expansion/mult_local_expansion.h"
00103 #include "mlpack/series_expansion/kernel_aux.h"
00104 #include "gen_metric_tree.h"
00105 #include "dualtree_kde_common.h"
00106 #include "kde_stat.h"
00107
00109 const fx_entry_doc kde_main_entries[] = {
00110 {"data", FX_REQUIRED, FX_STR, NULL,
00111 " A file containing reference data.\n"},
00112 {"query", FX_PARAM, FX_STR, NULL,
00113 " A file containing query data (defaults to data).\n"},
00114 FX_ENTRY_DOC_DONE
00115 };
00116
00117 const fx_entry_doc kde_entries[] = {
00118 {"bandwidth", FX_PARAM, FX_DOUBLE, NULL,
00119 " The bandwidth parameter.\n"},
00120 {"do_naive", FX_PARAM, FX_BOOL, NULL,
00121 " Whether to perform naive computation as well.\n"},
00122 {"dwgts", FX_PARAM, FX_STR, NULL,
00123 " A file that contains the weight of each point. If missing, will\
00124 assume uniform weight\n"},
00125 {"fast_kde_output", FX_PARAM, FX_STR, NULL,
00126 " A file to receive the results of computation.\n"},
00127 {"kernel", FX_PARAM, FX_STR, NULL,
00128 " The type of kernel to use.\n"},
00129 {"knn", FX_PARAM, FX_INT, NULL,
00130 " The number of k-nearest neighbor to use for variable bandwidth.\n"},
00131 {"loo", FX_PARAM, FX_BOOL, NULL,
00132 " Whether to output the density estimates using leave-one-out.\n"},
00133 {"mode", FX_PARAM, FX_STR, NULL,
00134 " Fixed bandwidth or variable bandwidth mode.\n"},
00135 {"multiplicative_expansion", FX_PARAM, FX_BOOL, NULL,
00136 " Whether to do O(p^D) kernel expansion instead of O(D^p).\n"},
00137 {"probability", FX_PARAM, FX_DOUBLE, NULL,
00138 " The probability guarantee that the relative error accuracy holds.\n"},
00139 {"relative_error", FX_PARAM, FX_DOUBLE, NULL,
00140 " The required relative error accuracy.\n"},
00141 {"threshold", FX_PARAM, FX_DOUBLE, NULL,
00142 " If less than this value, then absolute error bound.\n"},
00143 {"scaling", FX_PARAM, FX_STR, NULL,
00144 " The scaling option.\n"},
00145 FX_ENTRY_DOC_DONE
00146 };
00147
00148 const fx_module_doc kde_doc = {
00149 kde_entries, NULL,
00150 "Performs dual-tree kernel density estimate computation.\n"
00151 };
00152
00153 const fx_submodule_doc kde_main_submodules[] = {
00154 {"kde", &kde_doc,
00155 " Responsible for dual-tree kernel density estimate computation.\n"},
00156 FX_SUBMODULE_DOC_DONE
00157 };
00158
00159 const fx_module_doc kde_main_doc = {
00160 kde_main_entries, kde_main_submodules,
00161 "This is the driver for the kernel density estimator.\n"
00162 };
00163
00164
00165
00189 template<typename TKernelAux>
00190 class DualtreeKde {
00191
00192 friend class DualtreeKdeCommon;
00193
00194 public:
00195
00196
00197 typedef GeneralBinarySpaceTree<DBallBound < LMetric<2>, Vector>, Matrix, KdeStat<TKernelAux> > Tree;
00198
00199 private:
00200
00202
00206 static const int num_initial_samples_per_query_ = 25;
00207
00208 static const int sample_multiple_ = 1;
00209
00211
00214 struct datanode *module_;
00215
00218 bool leave_one_out_;
00219
00222 double mult_const_;
00223
00226 TKernelAux ka_;
00227
00230 Matrix qset_;
00231
00234 Tree *qroot_;
00235
00238 Matrix rset_;
00239
00242 Tree *rroot_;
00243
00246 Vector rset_weights_;
00247
00250 Vector densities_l_;
00251
00254 Vector densities_e_;
00255
00258 Vector densities_u_;
00259
00262 Vector used_error_;
00263
00267 Vector n_pruned_;
00268
00271 double rset_weight_sum_;
00272
00276 double relative_error_;
00277
00282 double threshold_;
00283
00286 int num_farfield_to_local_prunes_;
00287
00290 int num_farfield_prunes_;
00291
00294 int num_local_prunes_;
00295
00298 int num_finite_difference_prunes_;
00299
00302 int num_monte_carlo_prunes_;
00303
00307 ArrayList<index_t> old_from_new_queries_;
00308
00312 ArrayList<index_t> old_from_new_references_;
00313
00315
00316 void RefineBoundStatistics_(Tree *destination);
00317
00320 void DualtreeKdeBase_(Tree *qnode, Tree *rnode, double probability);
00321
00325 bool PrunableEnhanced_(Tree *qnode, Tree *rnode, double probability,
00326 DRange &dsqd_range, DRange &kernel_value_range,
00327 double &dl, double &du,
00328 double &used_error, double &n_pruned,
00329 int &order_farfield_to_local,
00330 int &order_farfield, int &order_local);
00331
00332 double EvalUnnormOnSq_(index_t reference_point_index,
00333 double squared_distance);
00334
00345 bool DualtreeKdeCanonical_(Tree *qnode, Tree *rnode, double probability);
00346
00351 void PreProcess(Tree *node);
00352
00355 void PostProcess(Tree *qnode);
00356
00357 public:
00358
00360
00363 DualtreeKde() {
00364 qroot_ = rroot_ = NULL;
00365 }
00366
00369 ~DualtreeKde() {
00370
00371 if(qroot_ != rroot_ ) {
00372 delete qroot_;
00373 delete rroot_;
00374 }
00375 else {
00376 delete rroot_;
00377 }
00378
00379 }
00380
00382
00385 void get_density_estimates(Vector *results) {
00386 results->Init(densities_e_.length());
00387
00388 for(index_t i = 0; i < densities_e_.length(); i++) {
00389 (*results)[i] = densities_e_[i];
00390 }
00391 }
00392
00394
00395 void Compute(Vector *results) {
00396
00397
00398 mult_const_ = 1.0 / ka_.kernel_.CalcNormConstant(qset_.n_rows());
00399
00400
00401 relative_error_ = fx_param_double(module_, "relative_error", 0.1);
00402 threshold_ = fx_param_double(module_, "threshold", 0) *
00403 ka_.kernel_.CalcNormConstant(qset_.n_rows());
00404
00405
00406 densities_l_.SetZero();
00407 densities_e_.SetZero();
00408 densities_u_.SetAll(rset_weight_sum_);
00409
00410
00411 used_error_.SetZero();
00412 n_pruned_.SetZero();
00413
00414
00415 num_finite_difference_prunes_ = num_monte_carlo_prunes_ =
00416 num_farfield_to_local_prunes_ = num_farfield_prunes_ =
00417 num_local_prunes_ = 0;
00418
00419 printf("\nStarting fast KDE on bandwidth value of %g...\n",
00420 sqrt(ka_.kernel_.bandwidth_sq()));
00421 fx_timer_start(NULL, "fast_kde_compute");
00422
00423
00424 PreProcess(rroot_);
00425 if(qroot_ != rroot_) {
00426 PreProcess(qroot_);
00427 }
00428
00429
00430
00431 double probability = fx_param_double(module_, "probability", 1);
00432 DualtreeKdeCanonical_(qroot_, rroot_, probability);
00433
00434
00435 PostProcess(qroot_);
00436 fx_timer_stop(NULL, "fast_kde_compute");
00437 printf("\nFast KDE completed...\n");
00438 printf("Finite difference prunes: %d\n", num_finite_difference_prunes_);
00439 printf("Monte Carlo prunes: %d\n", num_monte_carlo_prunes_);
00440 printf("F2L prunes: %d\n", num_farfield_to_local_prunes_);
00441 printf("F prunes: %d\n", num_farfield_prunes_);
00442 printf("L prunes: %d\n", num_local_prunes_);
00443
00444
00445
00446 Vector tmp_q_results;
00447 tmp_q_results.Init(densities_e_.length());
00448
00449 for(index_t i = 0; i < tmp_q_results.length(); i++) {
00450 tmp_q_results[old_from_new_queries_[i]] =
00451 densities_e_[i];
00452 }
00453 for(index_t i = 0; i < tmp_q_results.length(); i++) {
00454 densities_e_[i] = tmp_q_results[i];
00455 }
00456
00457
00458 get_density_estimates(results);
00459 }
00460
00461 void Init(const Matrix &queries, const Matrix &references,
00462 const Matrix &rset_weights, bool queries_equal_references,
00463 struct datanode *module_in) {
00464
00465
00466 module_ = module_in;
00467
00468
00469 leave_one_out_ = fx_param_exists(module_in, "loo") &&
00470 (queries.ptr() == references.ptr());
00471
00472
00473 int leaflen = fx_param_int(module_in, "leaflen", 20);
00474
00475
00476
00477
00478 rset_.Copy(references);
00479 rset_weights_.Init(rset_weights.n_cols());
00480 rset_weight_sum_ = 0;
00481 for(index_t i = 0; i < rset_weights.n_cols(); i++) {
00482 rset_weights_[i] = rset_weights.get(0, i);
00483 rset_weight_sum_ += rset_weights_[i];
00484 }
00485
00486
00487 if(queries_equal_references) {
00488 qset_.Alias(rset_);
00489 }
00490 else {
00491 qset_.Copy(queries);
00492 }
00493
00494
00495
00496
00497 fx_timer_start(NULL, "tree_d");
00498 rroot_ = proximity::MakeGenMetricTree<Tree>(rset_, leaflen,
00499 &old_from_new_references_,
00500 NULL);
00501 DualtreeKdeCommon::ShuffleAccordingToPermutation
00502 (rset_weights_, old_from_new_references_);
00503
00504 if(queries_equal_references) {
00505 qroot_ = rroot_;
00506 old_from_new_queries_.InitCopy(old_from_new_references_);
00507 }
00508 else {
00509 qroot_ = proximity::MakeGenMetricTree<Tree>(qset_, leaflen,
00510 &old_from_new_queries_,
00511 NULL);
00512 }
00513 fx_timer_stop(NULL, "tree_d");
00514
00515
00516 densities_l_.Init(qset_.n_cols());
00517 densities_e_.Init(qset_.n_cols());
00518 densities_u_.Init(qset_.n_cols());
00519
00520
00521 used_error_.Init(qset_.n_cols());
00522 n_pruned_.Init(qset_.n_cols());
00523
00524
00525 double bandwidth = fx_param_double_req(module_, "bandwidth");
00526
00527
00528 if(qset_.n_rows() <= 2) {
00529 ka_.Init(bandwidth, fx_param_int(module_, "order", 7), qset_.n_rows());
00530 }
00531 else if(qset_.n_rows() <= 3) {
00532 ka_.Init(bandwidth, fx_param_int(module_, "order", 5), qset_.n_rows());
00533 }
00534 else if(qset_.n_rows() <= 5) {
00535 ka_.Init(bandwidth, fx_param_int(module_, "order", 3), qset_.n_rows());
00536 }
00537 else if(qset_.n_rows() <= 6) {
00538 ka_.Init(bandwidth, fx_param_int(module_, "order", 1), qset_.n_rows());
00539 }
00540 else {
00541 ka_.Init(bandwidth, fx_param_int(module_, "order", 0), qset_.n_rows());
00542 }
00543 }
00544
00545 void PrintDebug() {
00546
00547 FILE *stream = stdout;
00548 const char *fname = NULL;
00549
00550 if((fname = fx_param_str(module_, "fast_kde_output",
00551 "fast_kde_output.txt")) != NULL) {
00552 stream = fopen(fname, "w+");
00553 }
00554 for(index_t q = 0; q < qset_.n_cols(); q++) {
00555 fprintf(stream, "%g\n", densities_e_[q]);
00556 }
00557
00558 if(stream != stdout) {
00559 fclose(stream);
00560 }
00561 }
00562
00563 };
00564
00565 #include "dualtree_kde_impl.h"
00566 #undef INSIDE_DUALTREE_KDE_H
00567
00568 #endif