00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00039 #ifndef MLPACK_ALLNN_H_
00040 #define MLPACK_ALLNN_H_
00041
00042 #include <fastlib/fastlib.h>
00043
00044 const fx_entry_doc allnn_entries[] = {
00045 {"leaf_size", FX_PARAM, FX_INT, NULL,
00046 " The maximum number of points to store at a leaf.\n"},
00047 {"tree_building", FX_TIMER, FX_CUSTOM, NULL,
00048 " Time spent building the kd-tree.\n"},
00049 {"dual_tree_computation", FX_TIMER, FX_CUSTOM, NULL,
00050 " Time spent computing the nearest neighbors.\n"},
00051 {"number_of_prunes", FX_RESULT, FX_INT, NULL,
00052 " Total node-pairs found to be too far to matter.\n"},
00053 FX_ENTRY_DOC_DONE
00054 };
00055
00056 const fx_module_doc allnn_doc = {
00057 allnn_entries, NULL,
00058 "Performs dual-tree all-nearest-neighbors computation.\n"
00059 };
00060
00061 const fx_entry_doc allnn_naive_entries[] = {
00062 {"naive_time", FX_TIMER, FX_CUSTOM, NULL,
00063 " Time spend performing the naive computation.\n"},
00064 FX_ENTRY_DOC_DONE
00065 };
00066
00067 const fx_module_doc allnn_naive_doc = {
00068 allnn_naive_entries, NULL,
00069 "Performs naive all-nearest-neighbors computation.\n"
00070 };
00071
00095 class AllNN {
00096
00098
00099 private:
00107 class QueryStat {
00108
00109 private:
00114 double max_distance_so_far_;
00115
00116 public:
00117 QueryStat() {
00118 }
00119 double max_distance_so_far() {
00120 return max_distance_so_far_;
00121 }
00122
00123 void set_max_distance_so_far(double new_dist) {
00124 max_distance_so_far_ = new_dist;
00125 }
00126
00134 void Init(const Matrix& matrix, index_t start, index_t count) {
00135
00136 max_distance_so_far_ = DBL_MAX;
00137 }
00138
00145 void Init(const Matrix& matrix, index_t start, index_t count,
00146 const QueryStat& left, const QueryStat& right) {
00147 Init(matrix, start, count);
00148 }
00149
00150 };
00151
00152
00153
00154
00155
00156
00158 typedef BinarySpaceTree<DHrectBound<2>, Matrix, QueryStat> TreeType;
00159
00161
00162 private:
00164 struct datanode* module_;
00165
00167 Matrix queries_;
00169 Matrix references_;
00170
00172 TreeType* query_tree_;
00174 TreeType* reference_tree_;
00175
00177 index_t leaf_size_;
00178
00180 GenVector<index_t> old_from_new_queries_;
00182 GenVector<index_t> old_from_new_references_;
00183
00188 Vector neighbor_distances_;
00193 GenVector<index_t> neighbor_indices_;
00194
00196 index_t number_of_prunes_;
00197
00199 bool initialized_;
00201 bool already_used_;
00202
00203
00205
00206 FORBID_ACCIDENTAL_COPIES(AllNN);
00207
00208 public:
00209
00210
00211
00212
00213
00214
00215 AllNN() {
00216 query_tree_ = NULL;
00217 reference_tree_ = NULL;
00218
00219 DEBUG_POISON_PTR(module_);
00220 DEBUG_ONLY(leaf_size_ = BIG_BAD_NUMBER);
00221 DEBUG_ONLY(number_of_prunes_ = BIG_BAD_NUMBER);
00222
00223 DEBUG_ONLY(initialized_ = false);
00224 DEBUG_ONLY(already_used_ = false);
00225 }
00226
00227
00228 ~AllNN() {
00229 if (query_tree_ != NULL) {
00230 delete query_tree_;
00231 }
00232 if (reference_tree_ != query_tree_) {
00233 delete reference_tree_;
00234 }
00235
00236 }
00237
00239
00244 double MinNodeDistSq_(TreeType* query_node, TreeType* reference_node) {
00245 return query_node->bound().MinDistanceSq(reference_node->bound());
00246 }
00247
00248
00254 void GNPBaseCase_(TreeType* query_node, TreeType* reference_node) {
00255
00256
00257
00258
00259
00260 DEBUG_ASSERT(query_node != NULL);
00261 DEBUG_ASSERT(reference_node != NULL);
00262
00263
00264 DEBUG_WARN_IF(!query_node->is_leaf());
00265 DEBUG_WARN_IF(!reference_node->is_leaf());
00266
00267
00268 double max_nearest_neighbor_distance = -1.0;
00269
00270
00271
00272
00273
00274
00275 for (index_t query_index = query_node->begin();
00276 query_index < query_node->end(); query_index++) {
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286 Vector query_point;
00287 queries_.MakeColumnVector(query_index, &query_point);
00288
00289
00290
00291
00292
00293
00294
00295
00296 for (index_t reference_index = reference_node->begin();
00297 reference_index < reference_node->end(); reference_index++) {
00298
00299 Vector reference_point;
00300 references_.MakeColumnVector(reference_index, &reference_point);
00301 if (likely(reference_node != query_node ||
00302 reference_index != query_index)) {
00303
00304 double distance =
00305 la::DistanceSqEuclidean(query_point, reference_point);
00306
00307
00308 if (distance < neighbor_distances_[query_index]) {
00309 neighbor_distances_[query_index] = distance;
00310 neighbor_indices_[query_index] = reference_index;
00311 }
00312 }
00313 }
00314
00315
00316 if (neighbor_distances_[query_index] > max_nearest_neighbor_distance) {
00317 max_nearest_neighbor_distance = neighbor_distances_[query_index];
00318 }
00319
00320 }
00321
00322
00323 query_node->stat().set_max_distance_so_far(max_nearest_neighbor_distance);
00324
00325 }
00326
00327
00332 void GNPRecursion_(TreeType* query_node, TreeType* reference_node,
00333 double lower_bound_distance) {
00334
00335
00336 DEBUG_ASSERT(query_node != NULL);
00337 DEBUG_ASSERT(reference_node != NULL);
00338
00339
00340
00341
00342
00343
00344
00345 DEBUG_SAME_DOUBLE(lower_bound_distance,
00346 MinNodeDistSq_(query_node, reference_node));
00347
00348 if (lower_bound_distance > query_node->stat().max_distance_so_far()) {
00349
00350
00351
00352
00353
00354
00355
00356 number_of_prunes_++;
00357
00358 } else if (query_node->is_leaf() && reference_node->is_leaf()) {
00359
00360
00361 GNPBaseCase_(query_node, reference_node);
00362
00363 } else if (query_node->is_leaf()) {
00364
00365
00366 double left_distance =
00367 MinNodeDistSq_(query_node, reference_node->left());
00368 double right_distance =
00369 MinNodeDistSq_(query_node, reference_node->right());
00370
00371
00372
00373
00374
00375 if (left_distance < right_distance) {
00376
00377 GNPRecursion_(query_node, reference_node->left(), left_distance);
00378
00379 GNPRecursion_(query_node, reference_node->right(), right_distance);
00380 } else {
00381
00382 GNPRecursion_(query_node, reference_node->right(), right_distance);
00383
00384
00385 GNPRecursion_(query_node, reference_node->left(), left_distance);
00386 }
00387
00388 } else if (reference_node->is_leaf()) {
00389
00390
00391 double left_distance =
00392 MinNodeDistSq_(query_node->left(), reference_node);
00393 double right_distance =
00394 MinNodeDistSq_(query_node->right(), reference_node);
00395
00396
00397
00398 GNPRecursion_(query_node->left(), reference_node, left_distance);
00399 GNPRecursion_(query_node->right(), reference_node, right_distance);
00400
00401
00402 query_node->stat().set_max_distance_so_far(
00403 max(query_node->left()->stat().max_distance_so_far(),
00404 query_node->right()->stat().max_distance_so_far()));
00405
00406 } else {
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416 double left_distance =
00417 MinNodeDistSq_(query_node->left(), reference_node->left());
00418 double right_distance =
00419 MinNodeDistSq_(query_node->left(), reference_node->right());
00420
00421 if (left_distance < right_distance) {
00422
00423 GNPRecursion_(query_node->left(),
00424 reference_node->left(), left_distance);
00425
00426 GNPRecursion_(query_node->left(),
00427 reference_node->right(), right_distance);
00428 } else {
00429
00430 GNPRecursion_(query_node->left(),
00431 reference_node->right(), right_distance);
00432
00433 GNPRecursion_(query_node->left(),
00434 reference_node->left(), left_distance);
00435 }
00436
00437 left_distance =
00438 MinNodeDistSq_(query_node->right(), reference_node->left());
00439 right_distance =
00440 MinNodeDistSq_(query_node->right(), reference_node->right());
00441
00442 if (left_distance < right_distance) {
00443 GNPRecursion_(query_node->right(),
00444 reference_node->left(), left_distance);
00445 GNPRecursion_(query_node->right(),
00446 reference_node->right(), right_distance);
00447 } else {
00448
00449 GNPRecursion_(query_node->right(),
00450 reference_node->right(), right_distance);
00451
00452 GNPRecursion_(query_node->right(),
00453 reference_node->left(), left_distance);
00454 }
00455
00456 query_node->stat().set_max_distance_so_far(
00457 max(query_node->left()->stat().max_distance_so_far(),
00458 query_node->right()->stat().max_distance_so_far()));
00459
00460 }
00461
00462 }
00463
00465
00466
00467
00468
00469
00470
00471
00479 void Init(const Matrix& queries_in, const Matrix& references_in,
00480 struct datanode* module_in) {
00481 if (queries_in.ptr()==references_in.ptr()) {
00482 FATAL("Data matrices for query tree and reference tree should be different");
00483 }
00484
00485
00486 DEBUG_ASSERT(initialized_ == false);
00487 DEBUG_ONLY(initialized_ = true);
00488
00489 module_ = module_in;
00490
00491
00492 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00493
00494
00495
00496 queries_.Alias(queries_in);
00497 references_.Alias(references_in);
00498
00499 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00500 DEBUG_ASSERT(leaf_size_ > 0);
00501
00502
00503
00504
00505 fx_timer_start(module_, "tree_building");
00506
00507
00508
00509
00510
00511
00512
00513
00514 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00515 queries_, leaf_size_, &old_from_new_queries_, NULL);
00516 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00517 references_, leaf_size_, &old_from_new_references_, NULL);
00518
00519
00520
00521 fx_timer_stop(module_, "tree_building");
00522
00523
00524 neighbor_indices_.Init(queries_.n_cols());
00525
00526
00527 neighbor_distances_.Init(queries_.n_cols());
00528 neighbor_distances_.SetAll(DBL_MAX);
00529
00530 number_of_prunes_ = 0;
00531
00532 }
00533
00534 void Init(const Matrix& references_in, fx_module* module_in) {
00535
00536
00537
00538 DEBUG_ASSERT(initialized_ == false);
00539 DEBUG_ONLY(initialized_ = true);
00540
00541 module_ = module_in;
00542
00543
00544
00545 references_.Alias(references_in);
00546 queries_.Alias(references_in);
00547 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00548 DEBUG_ASSERT(leaf_size_ > 0);
00549
00550
00551
00552
00553 fx_timer_start(module_, "tree_building");
00554
00555
00556
00557
00558
00559
00560
00561
00562 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00563 queries_, leaf_size_, &old_from_new_queries_, NULL);
00564 reference_tree_ = query_tree_;
00565 old_from_new_references_.Alias(old_from_new_queries_);
00566
00567
00568
00569 fx_timer_stop(module_, "tree_building");
00570
00571
00572 neighbor_indices_.Init(queries_.n_cols());
00573
00574
00575 neighbor_distances_.Init(queries_.n_cols());
00576 neighbor_distances_.SetAll(DBL_MAX);
00577
00578 number_of_prunes_ = 0;
00579
00580 }
00581
00582 void Destruct() {
00583 if (query_tree_ != NULL) {
00584 delete query_tree_;
00585 }
00586 if (reference_tree_ != query_tree_) {
00587 delete reference_tree_;
00588 }
00589 queries_.Destruct();
00590 references_.Destruct();
00591 old_from_new_queries_.Destruct();
00592 old_from_new_references_.Destruct();
00593 neighbor_distances_.Destruct();
00594 neighbor_indices_.Destruct();
00595 }
00596
00597
00603 void InitNaive(const Matrix& queries_in, const Matrix& references_in,
00604 fx_module* module_in){
00605
00606 DEBUG_ASSERT(initialized_ == false);
00607 DEBUG_ONLY(initialized_ = true);
00608
00609 module_ = module_in;
00610
00611
00612 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00613
00614
00615 queries_.Alias(queries_in);
00616 references_.Alias(references_in);
00617
00618
00619
00620
00621
00622 leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00623
00624
00625 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00626 queries_, leaf_size_, &old_from_new_queries_, NULL);
00627 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00628 references_, leaf_size_, &old_from_new_references_, NULL);
00629
00630
00631 neighbor_indices_.Init(queries_.n_cols());
00632
00633
00634 neighbor_distances_.Init(queries_.n_cols());
00635 neighbor_distances_.SetAll(DBL_MAX);
00636
00637 number_of_prunes_ = 0;
00638
00639 }
00640
00646 void InitNaive(const Matrix& references_in,
00647 fx_module* module_in){
00648
00649 DEBUG_ASSERT(initialized_ == false);
00650 DEBUG_ONLY(initialized_ = true);
00651
00652 module_ = module_in;
00653
00654
00655 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00656
00657
00658 queries_.Alias(references_in);
00659 references_.Alias(references_in);
00660
00661
00662
00663
00664
00665 leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00666
00667
00668 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00669 queries_, leaf_size_, &old_from_new_queries_, NULL);
00670 reference_tree_ = query_tree_;
00671 old_from_new_references_.Alias(old_from_new_queries_);
00672
00673 neighbor_indices_.Init(queries_.n_cols());
00674
00675
00676 neighbor_distances_.Init(queries_.n_cols());
00677 neighbor_distances_.SetAll(DBL_MAX);
00678
00679 number_of_prunes_ = 0;
00680
00681 }
00682
00683
00688 void ComputeNeighbors(GenVector<index_t>* results, GenVector<double>* distances) {
00689
00690
00691
00692 DEBUG_ASSERT(initialized_ == true);
00693 DEBUG_ASSERT(already_used_ == false);
00694 DEBUG_ONLY(already_used_ = true);
00695
00696 fx_timer_start(module_, "dual_tree_computation");
00697
00698
00699 GNPRecursion_(query_tree_, reference_tree_,
00700 MinNodeDistSq_(query_tree_, reference_tree_));
00701
00702 fx_timer_stop(module_, "dual_tree_computation");
00703
00704
00705
00706 fx_result_int(module_, "number_of_prunes", number_of_prunes_);
00707
00708 if (results!=NULL) {
00709 EmitResults(results, distances);
00710 }
00711
00712 }
00713
00717 void ComputeNaive(GenVector<index_t>* results, GenVector<double>* distances) {
00718
00719 DEBUG_ASSERT(initialized_ == true);
00720 DEBUG_ASSERT(already_used_ == false);
00721 DEBUG_ONLY(already_used_ = true);
00722
00723 fx_timer_start(module_, "naive_time");
00724
00725
00726 GNPBaseCase_(query_tree_, reference_tree_);
00727
00728 fx_timer_stop(module_, "naive_time");
00729
00730 if (results) {
00731 EmitResults(results, distances);
00732 }
00733
00734 }
00735
00739 void EmitResults(GenVector<index_t>* results, GenVector<double>* distances) {
00740
00741 DEBUG_ASSERT(initialized_ == true);
00742
00743 results->Init(neighbor_indices_.length());
00744 distances->Init(neighbor_distances_.length());
00745
00746
00747 for (index_t i = 0; i < neighbor_indices_.length(); i++) {
00748 (*results)[old_from_new_queries_[i]] =
00749 old_from_new_references_[neighbor_indices_[i]];
00750 (*distances)[
00751 old_from_new_references_[i]] = neighbor_distances_[i];
00752 }
00753
00754 }
00755
00756 };
00757
00758 #endif