00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00042
00043
00044
00045 #ifndef PLATONIC_ALLNN_H
00046 #define PLATONIC_ALLNN_H
00047
00048
00049
00050 #include <fastlib/fastlib.h>
00051
00052
00053
00054
00055
00056
00057
00058 const fx_entry_doc allnn_entries[] = {
00059 {"leaf_size", FX_PARAM, FX_INT, NULL,
00060 " The maximum number of points to store at a leaf.\n"},
00061 {"tree_building", FX_TIMER, FX_CUSTOM, NULL,
00062 " Time spent building the kd-tree.\n"},
00063 {"dual_tree_computation", FX_TIMER, FX_CUSTOM, NULL,
00064 " Time spent computing the nearest neighbors.\n"},
00065 {"number_of_prunes", FX_RESULT, FX_INT, NULL,
00066 " Total node-pairs found to be too far to matter.\n"},
00067 FX_ENTRY_DOC_DONE
00068 };
00069
00070 const fx_module_doc allnn_doc = {
00071 allnn_entries, NULL,
00072 "Performs dual-tree all-nearest-neighbors computation.\n"
00073 };
00074
00075 const fx_entry_doc allnn_naive_entries[] = {
00076 {"naive_time", FX_TIMER, FX_CUSTOM, NULL,
00077 " Time spend performing the naive computation.\n"},
00078 FX_ENTRY_DOC_DONE
00079 };
00080
00081 const fx_module_doc allnn_naive_doc = {
00082 allnn_naive_entries, NULL,
00083 "Performs naive all-nearest-neighbors computation.\n"
00084 };
00085
00107 class AllNN {
00108
00110
00111 private:
00119 class QueryStat {
00120
00121 private:
00126 double max_distance_so_far_;
00127
00128
00129
00130
00131
00132
00133
00134
00135 OBJECT_TRAVERSAL_SHALLOW(QueryStat) {
00136
00137
00138 OT_OBJ(max_distance_so_far_);
00139 }
00140
00141 public:
00142 double max_distance_so_far() {
00143 return max_distance_so_far_;
00144 }
00145
00146 void set_max_distance_so_far(double new_dist) {
00147 max_distance_so_far_ = new_dist;
00148 }
00149
00157 void Init(const Matrix& matrix, index_t start, index_t count) {
00158
00159 max_distance_so_far_ = DBL_MAX;
00160 }
00161
00168 void Init(const Matrix& matrix, index_t start, index_t count,
00169 const QueryStat& left, const QueryStat& right) {
00170 Init(matrix, start, count);
00171 }
00172
00173 };
00174
00175
00176
00177
00178
00179
00181 typedef BinarySpaceTree<DHrectBound<2>, Matrix, QueryStat> QueryTree;
00183 typedef BinarySpaceTree<DHrectBound<2>, Matrix> ReferenceTree;
00184
00186
00187 private:
00189 struct datanode* module_;
00190
00192 Matrix queries_;
00194 Matrix references_;
00195
00197 QueryTree* query_tree_;
00199 ReferenceTree* reference_tree_;
00200
00202 index_t leaf_size_;
00203
00205 ArrayList<index_t> old_from_new_queries_;
00207 ArrayList<index_t> old_from_new_references_;
00208
00213 Vector neighbor_distances_;
00218 ArrayList<index_t> neighbor_indices_;
00219
00221 index_t number_of_prunes_;
00222
00224 bool initialized_;
00226 bool already_used_;
00227
00228
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249 FORBID_ACCIDENTAL_COPIES(AllNN);
00250
00251 public:
00252
00253
00254
00255
00256
00257
00258 AllNN() {
00259 query_tree_ = NULL;
00260 reference_tree_ = NULL;
00261
00262 DEBUG_POISON_PTR(module_);
00263 DEBUG_ONLY(leaf_size_ = BIG_BAD_NUMBER);
00264 DEBUG_ONLY(number_of_prunes_ = BIG_BAD_NUMBER);
00265
00266 DEBUG_ONLY(initialized_ = false);
00267 DEBUG_ONLY(already_used_ = false);
00268 }
00269
00270
00271 ~AllNN() {
00272 if (query_tree_ != NULL) {
00273 delete query_tree_;
00274 }
00275 if (reference_tree_ != NULL) {
00276 delete reference_tree_;
00277 }
00278 }
00279
00281
00286 double MinNodeDistSq_(QueryTree* query_node, ReferenceTree* reference_node) {
00287 return query_node->bound().MinDistanceSq(reference_node->bound());
00288 }
00289
00290
00296 void GNPBaseCase_(QueryTree* query_node, ReferenceTree* reference_node) {
00297
00298
00299
00300
00301
00302 DEBUG_ASSERT(query_node != NULL);
00303 DEBUG_ASSERT(reference_node != NULL);
00304
00305
00306 DEBUG_WARN_IF(!query_node->is_leaf());
00307 DEBUG_WARN_IF(!reference_node->is_leaf());
00308
00309
00310 double max_nearest_neighbor_distance = -1.0;
00311
00312
00313
00314
00315
00316
00317 for (index_t query_index = query_node->begin();
00318 query_index < query_node->end(); query_index++) {
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328 Vector query_point;
00329 queries_.MakeColumnVector(query_index, &query_point);
00330
00331
00332
00333
00334
00335
00336
00337
00338 for (index_t reference_index = reference_node->begin();
00339 reference_index < reference_node->end(); reference_index++) {
00340
00341 Vector reference_point;
00342 references_.MakeColumnVector(reference_index, &reference_point);
00343
00344
00345 double distance =
00346 la::DistanceSqEuclidean(query_point, reference_point);
00347
00348
00349 if (distance < neighbor_distances_[query_index]) {
00350 neighbor_distances_[query_index] = distance;
00351 neighbor_indices_[query_index] = reference_index;
00352 }
00353
00354 }
00355
00356
00357 if (neighbor_distances_[query_index] > max_nearest_neighbor_distance) {
00358 max_nearest_neighbor_distance = neighbor_distances_[query_index];
00359 }
00360
00361 }
00362
00363
00364 query_node->stat().set_max_distance_so_far(max_nearest_neighbor_distance);
00365
00366 }
00367
00368
00373 void GNPRecursion_(QueryTree* query_node, ReferenceTree* reference_node,
00374 double lower_bound_distance) {
00375
00376
00377 DEBUG_ASSERT(query_node != NULL);
00378 DEBUG_ASSERT(reference_node != NULL);
00379
00380
00381
00382
00383
00384
00385
00386 DEBUG_SAME_DOUBLE(lower_bound_distance,
00387 MinNodeDistSq_(query_node, reference_node));
00388
00389 if (lower_bound_distance > query_node->stat().max_distance_so_far()) {
00390
00391
00392
00393
00394
00395
00396
00397 number_of_prunes_++;
00398
00399 } else if (query_node->is_leaf() && reference_node->is_leaf()) {
00400
00401
00402 GNPBaseCase_(query_node, reference_node);
00403
00404 } else if (query_node->is_leaf()) {
00405
00406
00407 double left_distance =
00408 MinNodeDistSq_(query_node, reference_node->left());
00409 double right_distance =
00410 MinNodeDistSq_(query_node, reference_node->right());
00411
00412
00413
00414
00415
00416 if (left_distance < right_distance) {
00417 GNPRecursion_(query_node, reference_node->left(), left_distance);
00418 GNPRecursion_(query_node, reference_node->right(), right_distance);
00419 } else {
00420 GNPRecursion_(query_node, reference_node->right(), right_distance);
00421 GNPRecursion_(query_node, reference_node->left(), left_distance);
00422 }
00423
00424 } else if (reference_node->is_leaf()) {
00425
00426
00427 double left_distance =
00428 MinNodeDistSq_(query_node->left(), reference_node);
00429 double right_distance =
00430 MinNodeDistSq_(query_node->right(), reference_node);
00431
00432
00433 GNPRecursion_(query_node->left(), reference_node, left_distance);
00434 GNPRecursion_(query_node->right(), reference_node, right_distance);
00435
00436
00437 query_node->stat().set_max_distance_so_far(
00438 max(query_node->left()->stat().max_distance_so_far(),
00439 query_node->right()->stat().max_distance_so_far()));
00440
00441 } else {
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451 double left_distance =
00452 MinNodeDistSq_(query_node->left(), reference_node->left());
00453 double right_distance =
00454 MinNodeDistSq_(query_node->left(), reference_node->right());
00455
00456 if (left_distance < right_distance) {
00457 GNPRecursion_(query_node->left(),
00458 reference_node->left(), left_distance);
00459 GNPRecursion_(query_node->left(),
00460 reference_node->right(), right_distance);
00461 } else {
00462 GNPRecursion_(query_node->left(),
00463 reference_node->right(), right_distance);
00464 GNPRecursion_(query_node->left(),
00465 reference_node->left(), left_distance);
00466 }
00467
00468 left_distance =
00469 MinNodeDistSq_(query_node->right(), reference_node->left());
00470 right_distance =
00471 MinNodeDistSq_(query_node->right(), reference_node->right());
00472
00473 if (left_distance < right_distance) {
00474 GNPRecursion_(query_node->right(),
00475 reference_node->left(), left_distance);
00476 GNPRecursion_(query_node->right(),
00477 reference_node->right(), right_distance);
00478 } else {
00479 GNPRecursion_(query_node->right(),
00480 reference_node->right(), right_distance);
00481 GNPRecursion_(query_node->right(),
00482 reference_node->left(), left_distance);
00483 }
00484
00485
00486 query_node->stat().set_max_distance_so_far(
00487 max(query_node->left()->stat().max_distance_so_far(),
00488 query_node->right()->stat().max_distance_so_far()));
00489
00490 }
00491
00492 }
00493
00495
00496
00497
00498
00499
00500
00501
00505 void Init(const Matrix& queries_in, const Matrix& references_in,
00506 struct datanode* module_in) {
00507
00508
00509
00510 DEBUG_ASSERT(initialized_ == false);
00511 DEBUG_ONLY(initialized_ = true);
00512
00513 module_ = module_in;
00514
00515
00516 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00517
00518
00519 queries_.Copy(queries_in);
00520 references_.Copy(references_in);
00521
00522 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00523 DEBUG_ASSERT(leaf_size_ > 0);
00524
00525
00526
00527 fx_timer_start(module_, "tree_building");
00528
00529
00530
00531
00532
00533
00534
00535
00536 query_tree_ = tree::MakeKdTreeMidpoint<QueryTree>(
00537 queries_, leaf_size_, &old_from_new_queries_, NULL);
00538 reference_tree_ = tree::MakeKdTreeMidpoint<ReferenceTree>(
00539 references_, leaf_size_, &old_from_new_references_, NULL);
00540
00541
00542
00543 fx_timer_stop(module_, "tree_building");
00544
00545
00546 neighbor_indices_.Init(queries_.n_cols());
00547
00548
00549 neighbor_distances_.Init(queries_.n_cols());
00550 neighbor_distances_.SetAll(DBL_MAX);
00551
00552 number_of_prunes_ = 0;
00553
00554 }
00555
00556
00562 void InitNaive(const Matrix& queries_in, const Matrix& references_in,
00563 struct datanode* module_in){
00564
00565 DEBUG_ASSERT(initialized_ == false);
00566 DEBUG_ONLY(initialized_ = true);
00567
00568 module_ = module_in;
00569
00570
00571 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00572
00573
00574 queries_.Copy(queries_in);
00575 references_.Copy(references_in);
00576
00577
00578
00579
00580
00581 leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00582
00583
00584 query_tree_ = tree::MakeKdTreeMidpoint<QueryTree>(
00585 queries_, leaf_size_, &old_from_new_queries_, NULL);
00586 reference_tree_ = tree::MakeKdTreeMidpoint<ReferenceTree>(
00587 references_, leaf_size_, &old_from_new_references_, NULL);
00588
00589
00590 neighbor_indices_.Init(queries_.n_cols());
00591
00592
00593 neighbor_distances_.Init(queries_.n_cols());
00594 neighbor_distances_.SetAll(DBL_MAX);
00595
00596 number_of_prunes_ = 0;
00597
00598 }
00599
00600
00605 void ComputeNeighbors(ArrayList<index_t>* results) {
00606
00607
00608
00609 DEBUG_ASSERT(initialized_ == true);
00610 DEBUG_ASSERT(already_used_ == false);
00611 DEBUG_ONLY(already_used_ = true);
00612
00613 fx_timer_start(module_, "dual_tree_computation");
00614
00615
00616 GNPRecursion_(query_tree_, reference_tree_,
00617 MinNodeDistSq_(query_tree_, reference_tree_));
00618
00619 fx_timer_stop(module_, "dual_tree_computation");
00620
00621
00622
00623 fx_result_int(module_, "number_of_prunes", number_of_prunes_);
00624
00625 if (results) {
00626 EmitResults(results);
00627 }
00628
00629 }
00630
00631
00635 void ComputeNaive(ArrayList<index_t>* results) {
00636
00637 DEBUG_ASSERT(initialized_ == true);
00638 DEBUG_ASSERT(already_used_ == false);
00639 DEBUG_ONLY(already_used_ = true);
00640
00641 fx_timer_start(module_, "naive_time");
00642
00643
00644 GNPBaseCase_(query_tree_, reference_tree_);
00645
00646 fx_timer_stop(module_, "naive_time");
00647
00648 if (results) {
00649 EmitResults(results);
00650 }
00651
00652 }
00653
00657 void EmitResults(ArrayList<index_t>* results) {
00658
00659 DEBUG_ASSERT(initialized_ == true);
00660
00661 results->Init(neighbor_indices_.size());
00662
00663
00664 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00665 (*results)[old_from_new_queries_[i]] =
00666 old_from_new_references_[neighbor_indices_[i]];
00667 }
00668
00669 }
00670
00671 };
00672
00673 #endif