00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00089 #ifndef DUALTREE_VKDE_H
00090 #define DUALTREE_VKDE_H
00091
00092 #define INSIDE_DUALTREE_VKDE_H
00093
00094 #include "fastlib/fastlib.h"
00095 #include "gen_metric_tree.h"
00096 #include "dualtree_kde_common.h"
00097 #include "kde_stat.h"
00098 #include "mlpack/allknn/allknn.h"
00099
00123 template<typename TKernel>
00124 class DualtreeVKde {
00125
00126 friend class DualtreeKdeCommon;
00127
00128 public:
00129
00130
00131 typedef GeneralBinarySpaceTree<DBallBound < LMetric<2>, Vector>, Matrix, VKdeStat<TKernel> > Tree;
00132
00133 private:
00134
00136
00140 static const int num_initial_samples_per_query_ = 25;
00141
00142 static const int sample_multiple_ = 10;
00143
00145
00148 struct datanode *module_;
00149
00152 bool leave_one_out_;
00153
00156 double mult_const_;
00157
00160 ArrayList<TKernel> kernels_;
00161
00164 Matrix qset_;
00165
00168 Tree *qroot_;
00169
00172 Matrix rset_;
00173
00176 Tree *rroot_;
00177
00180 Vector rset_weights_;
00181
00184 Vector densities_l_;
00185
00188 Vector densities_e_;
00189
00192 Vector densities_u_;
00193
00196 Vector used_error_;
00197
00201 Vector n_pruned_;
00202
00205 double rset_weight_sum_;
00206
00210 double relative_error_;
00211
00216 double threshold_;
00217
00220 int num_finite_difference_prunes_;
00221
00224 int num_monte_carlo_prunes_;
00225
00229 ArrayList<index_t> old_from_new_queries_;
00230
00234 ArrayList<index_t> old_from_new_references_;
00235
00237
00240 void DualtreeVKdeBase_(Tree *qnode, Tree *rnode, double probability);
00241
00245 bool PrunableEnhanced_(Tree *qnode, Tree *rnode, double probability,
00246 DRange &dsqd_range, DRange &kernel_value_range,
00247 double &dl, double &du,
00248 double &used_error, double &n_pruned,
00249 int &order_farfield_to_local,
00250 int &order_farfield, int &order_local);
00251
00252 double EvalUnnormOnSq_(index_t reference_point_index,
00253 double squared_distance);
00254
00265 bool DualtreeVKdeCanonical_(Tree *qnode, Tree *rnode, double probability);
00266
00271 void PreProcess(Tree *node, bool reference_side);
00272
00275 void PostProcess(Tree *qnode);
00276
00277 public:
00278
00280
00283 DualtreeVKde() {
00284 qroot_ = rroot_ = NULL;
00285 }
00286
00289 ~DualtreeVKde() {
00290
00291 if(qroot_ != rroot_ ) {
00292 delete qroot_;
00293 delete rroot_;
00294 }
00295 else {
00296 delete rroot_;
00297 }
00298
00299 }
00300
00302
00305 void get_density_estimates(Vector *results) {
00306 results->Init(densities_e_.length());
00307
00308 for(index_t i = 0; i < densities_e_.length(); i++) {
00309 (*results)[i] = densities_e_[i];
00310 }
00311 }
00312
00314
00315 void Compute(Vector *results) {
00316
00317
00318 relative_error_ = fx_param_double(module_, "relative_error", 0.1);
00319 threshold_ = fx_param_double(module_, "threshold", 0) *
00320 kernels_[0].CalcNormConstant(qset_.n_rows());
00321
00322
00323 densities_l_.SetZero();
00324 densities_e_.SetZero();
00325 densities_u_.SetAll(rset_weight_sum_);
00326
00327
00328 used_error_.SetZero();
00329 n_pruned_.SetZero();
00330
00331
00332 num_finite_difference_prunes_ = num_monte_carlo_prunes_ = 0;
00333
00334 printf("\nStarting variable KDE using %d neighbors...\n",
00335 (int) fx_param_int_req(module_, "knn"));
00336
00337 fx_timer_start(NULL, "fast_kde_compute");
00338
00339
00340 PreProcess(rroot_, true);
00341 if(qroot_ != rroot_) {
00342 PreProcess(qroot_, false);
00343 }
00344
00345
00346
00347 double probability = fx_param_double(module_, "probability", 1);
00348 DualtreeVKdeCanonical_(qroot_, rroot_, probability);
00349
00350
00351 PostProcess(qroot_);
00352 fx_timer_stop(NULL, "fast_kde_compute");
00353 printf("\nFast KDE completed...\n");
00354 printf("Finite difference prunes: %d\n", num_finite_difference_prunes_);
00355 printf("Monte Carlo prunes: %d\n", num_monte_carlo_prunes_);
00356
00357
00358
00359 Vector tmp_q_results;
00360 tmp_q_results.Init(densities_e_.length());
00361
00362 for(index_t i = 0; i < tmp_q_results.length(); i++) {
00363 tmp_q_results[old_from_new_queries_[i]] =
00364 densities_e_[i];
00365 }
00366 for(index_t i = 0; i < tmp_q_results.length(); i++) {
00367 densities_e_[i] = tmp_q_results[i];
00368 }
00369
00370
00371 get_density_estimates(results);
00372 }
00373
00374 void Init(const Matrix &queries, const Matrix &references,
00375 const Matrix &rset_weights, bool queries_equal_references,
00376 struct datanode *module_in) {
00377
00378
00379 module_ = module_in;
00380
00381
00382 leave_one_out_ = fx_param_exists(module_in, "loo") &&
00383 (queries.ptr() == references.ptr());
00384
00385
00386 int leaflen = fx_param_int(module_in, "leaflen", 20);
00387
00388
00389
00390
00391
00392 rset_.Copy(references);
00393 rset_weights_.Init(rset_weights.n_cols());
00394 rset_weight_sum_ = 0;
00395 for(index_t i = 0; i < rset_weights.n_cols(); i++) {
00396 rset_weights_[i] = rset_weights.get(0, i);
00397 rset_weight_sum_ += rset_weights_[i];
00398 }
00399
00400
00401 if(queries_equal_references) {
00402 qset_.Alias(rset_);
00403 }
00404 else {
00405 qset_.Copy(queries);
00406 }
00407
00408
00409
00410
00411 fx_timer_start(NULL, "tree_d");
00412 rroot_ = proximity::MakeGenMetricTree<Tree>(rset_, leaflen,
00413 &old_from_new_references_,
00414 NULL);
00415 DualtreeKdeCommon::ShuffleAccordingToPermutation
00416 (rset_weights_, old_from_new_references_);
00417
00418 if(queries_equal_references) {
00419 qroot_ = rroot_;
00420 old_from_new_queries_.InitCopy(old_from_new_references_);
00421 }
00422 else {
00423 qroot_ = proximity::MakeGenMetricTree<Tree>(qset_, leaflen,
00424 &old_from_new_queries_,
00425 NULL);
00426 }
00427 fx_timer_stop(NULL, "tree_d");
00428
00429
00430 densities_l_.Init(qset_.n_cols());
00431 densities_e_.Init(qset_.n_cols());
00432 densities_u_.Init(qset_.n_cols());
00433
00434
00435 used_error_.Init(qset_.n_cols());
00436 n_pruned_.Init(qset_.n_cols());
00437
00438
00439 int knns = fx_param_int_req(module_, "knn");
00440 AllkNN all_knn;
00441 kernels_.Init(rset_.n_cols());
00442 all_knn.Init(rset_, 20, knns);
00443 ArrayList<index_t> resulting_neighbors;
00444 ArrayList<double> squared_distances;
00445
00446 fx_timer_start(fx_root, "bandwidth_initialization");
00447 all_knn.ComputeNeighbors(&resulting_neighbors, &squared_distances);
00448
00449 for(index_t i = 0; i < squared_distances.size(); i += knns) {
00450 kernels_[i / knns].Init(sqrt(squared_distances[i + knns - 1]));
00451 }
00452 fx_timer_stop(fx_root, "bandwidth_initialization");
00453
00454
00455
00456 double min_norm_const = DBL_MAX;
00457 for(index_t i = 0; i < rset_weights_.length(); i++) {
00458 double norm_const = kernels_[i].CalcNormConstant(qset_.n_rows());
00459 min_norm_const = std::min(min_norm_const, norm_const);
00460 }
00461 for(index_t i = 0; i < rset_weights_.length(); i++) {
00462 double norm_const = kernels_[i].CalcNormConstant(qset_.n_rows());
00463 rset_weights_[i] *= (min_norm_const / norm_const);
00464 }
00465
00466
00467 mult_const_ = 1.0 / min_norm_const;
00468 }
00469
00470 void PrintDebug() {
00471
00472 FILE *stream = stdout;
00473 const char *fname = NULL;
00474
00475 if((fname = fx_param_str(module_, "fast_kde_output",
00476 "fast_kde_output.txt")) != NULL) {
00477 stream = fopen(fname, "w+");
00478 }
00479 for(index_t q = 0; q < qset_.n_cols(); q++) {
00480 fprintf(stream, "%g\n", densities_e_[q]);
00481 }
00482
00483 if(stream != stdout) {
00484 fclose(stream);
00485 }
00486 }
00487
00488 };
00489
00490 #include "dualtree_vkde_impl.h"
00491 #undef INSIDE_DUALTREE_VKDE_H
00492
00493 #endif