00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00039
00040 #ifndef ALLKNN_H
00041 #define ALLKNN_H
00042
00043
00044
00045
00046 #include <fastlib/fastlib.h>
00047 #include <vector>
00048 #include <string>
// Forward declaration of the unit-test class so AllkNN can name it as a
// friend (granting it access to private members) before it is defined.
00052 class TestAllkNN;
00057 class AllkNN {
00058
00059
00060 friend class TestAllkNN;
00061
00063
00068 class QueryStat {
00069
00070
00071
00072 OT_DEF_BASIC(QueryStat) {
00073
00074
00075 OT_MY_OBJECT(max_distance_so_far_);
00076 }
00077
00078 private:
00079
00083 double max_distance_so_far_;
00084
00085 public:
00086
00087 double max_distance_so_far() {
00088 return max_distance_so_far_;
00089 }
00090
00091
00092 void set_max_distance_so_far(double new_dist) {
00093 max_distance_so_far_ = new_dist;
00094 }
00095
00096
00097
00098
00099
00105 void Init(const Matrix& matrix, index_t start, index_t count) {
00106
00107 max_distance_so_far_ = DBL_MAX;
00108 }
00109
00114 void Init(const Matrix& matrix, index_t start, index_t count,
00115 const QueryStat& left, const QueryStat& right) {
00116
00117 Init(matrix, start, count);
00118 }
00119
00120 };
00121
00122
00123
00124
00125 typedef BinarySpaceTree<DHrectBound<2>, Matrix, QueryStat> TreeType;
00126
00127
00129 private:
00130
00131 Matrix queries_;
00132 Matrix references_;
00133
00134 TreeType* query_tree_;
00135 TreeType* reference_tree_;
00136
00137 index_t number_of_prunes_;
00138
00139 ArrayList<index_t> old_from_new_queries_;
00140 ArrayList<index_t> old_from_new_references_;
00141
00142 index_t leaf_size_;
00143
00144 Vector neighbor_distances_;
00145
00146 ArrayList<index_t> neighbor_indices_;
00147
00148 index_t knns_;
00149
00150 bool k_only_;
00151
00152 std::string mode_;
00153
00154 struct datanode* module_;
00155
00156
00158
00159
00160 FORBID_ACCIDENTAL_COPIES(AllkNN);
00161
00162
00163 public:
00164
00169 AllkNN() {
00170 query_tree_ = NULL;
00171 reference_tree_ = NULL;
00172 }
00173
00177 ~AllkNN() {
00178 if (query_tree_ != NULL) {
00179 delete query_tree_;
00180 }
00181 if (reference_tree_ != NULL) {
00182 delete reference_tree_;
00183 }
00184 }
00185
00186
00188
00192 double MinNodeDistSq_ (TreeType* query_node, TreeType* reference_node) {
00193
00194
00195 return query_node->bound().MinDistanceSq(reference_node->bound());
00196 }
00197
00201 double MinPointNodeDistSq_ (const Vector& query_point, TreeType* reference_node) {
00202
00203
00204 return reference_node->bound().MinDistanceSq(query_point);
00205 }
00206
00207
00211 void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node) {
00212
00213
00214
00215
00216
00217 DEBUG_ASSERT(query_node != NULL);
00218 DEBUG_ASSERT(reference_node != NULL);
00219
00220 DEBUG_WARN_IF(!query_node->is_leaf());
00221 DEBUG_WARN_IF(!reference_node->is_leaf());
00222
00223
00224 double query_max_neighbor_distance = -1.0;
00225 std::vector<std::pair<double, index_t> > neighbors(knns_);
00226
00227
00228 for (index_t query_index = query_node->begin();
00229 query_index < query_node->end(); query_index++) {
00230
00231
00232 Vector query_point;
00233 queries_.MakeColumnVector(query_index, &query_point);
00234
00235 index_t ind = query_index*knns_;
00236 for(index_t i=0; i<knns_; i++) {
00237 neighbors[i]=std::make_pair(neighbor_distances_[ind+i],
00238 neighbor_indices_[ind+i]);
00239 }
00240
00241 double query_to_node_distance =
00242 MinPointNodeDistSq_(query_point, reference_node);
00243 if (query_to_node_distance < neighbor_distances_[ind+knns_-1]) {
00244
00245 for (index_t reference_index = reference_node->begin();
00246 reference_index < reference_node->end(); reference_index++) {
00247
00248
00249
00250 if (likely(reference_node != query_node ||
00251 reference_index != query_index)) {
00252 Vector reference_point;
00253 references_.MakeColumnVector(reference_index, &reference_point);
00254
00255 double distance =
00256 la::DistanceSqEuclidean(query_point, reference_point);
00257
00258
00259 if (distance < neighbor_distances_[ind+knns_-1]) {
00260 neighbors.push_back(std::make_pair(distance, reference_index));
00261 }
00262 }
00263 }
00264
00265 std::sort(neighbors.begin(), neighbors.end());
00266 for(index_t i=0; i<knns_; i++) {
00267 neighbor_distances_[ind+i] = neighbors[i].first;
00268 neighbor_indices_[ind+i] = neighbors[i].second;
00269 }
00270 neighbors.resize(knns_);
00271 }
00272
00273 if (neighbor_distances_[ind+knns_-1] > query_max_neighbor_distance) {
00274 query_max_neighbor_distance = neighbor_distances_[ind+knns_-1];
00275 }
00276
00277 }
00278
00279 query_node->stat().set_max_distance_so_far(query_max_neighbor_distance);
00280
00281 }
00282
00283
00287 void ComputeDualNeighborsRecursion_(TreeType* query_node, TreeType* reference_node,
00288 double lower_bound_distance) {
00289
00290
00291
00292
00293
00294 DEBUG_ASSERT(query_node != NULL);
00295
00296 DEBUG_ASSERT_MSG(reference_node != NULL, "reference node is null");
00297
00298 DEBUG_ASSERT(lower_bound_distance == MinNodeDistSq_(query_node,
00299 reference_node));
00300
00301 if (lower_bound_distance > query_node->stat().max_distance_so_far()) {
00302
00303 number_of_prunes_++;
00304 }
00305
00306 else if (query_node->is_leaf() && reference_node->is_leaf()) {
00307
00308 ComputeBaseCase_(query_node, reference_node);
00309 }
00310 else if (query_node->is_leaf()) {
00311
00312
00313
00314 double left_distance = MinNodeDistSq_(query_node, reference_node->left());
00315 double right_distance = MinNodeDistSq_(query_node, reference_node->right());
00316
00317 if (left_distance < right_distance) {
00318 ComputeDualNeighborsRecursion_(query_node, reference_node->left(),
00319 left_distance);
00320 ComputeDualNeighborsRecursion_(query_node, reference_node->right(),
00321 right_distance);
00322 }
00323 else {
00324 ComputeDualNeighborsRecursion_(query_node, reference_node->right(),
00325 right_distance);
00326 ComputeDualNeighborsRecursion_(query_node, reference_node->left(),
00327 left_distance);
00328 }
00329
00330 }
00331
00332 else if (reference_node->is_leaf()) {
00333
00334
00335 double left_distance = MinNodeDistSq_(query_node->left(), reference_node);
00336 double right_distance = MinNodeDistSq_(query_node->right(), reference_node);
00337
00338 ComputeDualNeighborsRecursion_(query_node->left(), reference_node,
00339 left_distance);
00340 ComputeDualNeighborsRecursion_(query_node->right(), reference_node,
00341 right_distance);
00342
00343
00344
00345 query_node->stat().set_max_distance_so_far(
00346 max(query_node->left()->stat().max_distance_so_far(),
00347 query_node->right()->stat().max_distance_so_far()));
00348 } else {
00349
00350
00351 double left_distance = MinNodeDistSq_(query_node->left(),
00352 reference_node->left());
00353 double right_distance = MinNodeDistSq_(query_node->left(),
00354 reference_node->right());
00355
00356 if (left_distance < right_distance) {
00357 ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(),
00358 left_distance);
00359 ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(),
00360 right_distance);
00361 }
00362 else {
00363 ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(),
00364 right_distance);
00365 ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(),
00366 left_distance);
00367 }
00368 left_distance = MinNodeDistSq_(query_node->right(), reference_node->left());
00369 right_distance = MinNodeDistSq_(query_node->right(),
00370 reference_node->right());
00371
00372 if (left_distance < right_distance) {
00373 ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(),
00374 left_distance);
00375 ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(),
00376 right_distance);
00377 }
00378 else {
00379 ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(),
00380 right_distance);
00381 ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(),
00382 left_distance);
00383 }
00384
00385
00386 query_node->stat().set_max_distance_so_far(
00387 max(query_node->left()->stat().max_distance_so_far(),
00388 query_node->right()->stat().max_distance_so_far()));
00389
00390 }
00391
00392 }
00393
00394
00395 void ComputeSingleNeighborsRecursion_(index_t point_id,
00396 Vector &point, TreeType* reference_node,
00397 double *min_dist_so_far) {
00398
00399
00400 DEBUG_ASSERT_MSG(reference_node != NULL, "reference node is null");
00401
00402
00403
00404 if (reference_node->is_leaf()) {
00405
00406 std::vector<std::pair<double, index_t> > neighbors(knns_);
00407 index_t ind = point_id*knns_;
00408 for(index_t i=0; i<knns_; i++) {
00409 neighbors[i]=std::make_pair(neighbor_distances_[ind+i],
00410 neighbor_indices_[ind+i]);
00411 }
00412
00413 for (index_t reference_index = reference_node->begin();
00414 reference_index < reference_node->end(); reference_index++) {
00415
00416
00417 if (likely(!(references_.ptr()==queries_.ptr() &&
00418 reference_index == point_id))) {
00419 Vector reference_point;
00420 references_.MakeColumnVector(reference_index, &reference_point);
00421
00422 double distance =
00423 la::DistanceSqEuclidean(point, reference_point);
00424
00425
00426 if (distance < neighbor_distances_[ind+knns_-1]) {
00427 neighbors.push_back(std::make_pair(distance, reference_index));
00428 }
00429 }
00430 }
00431 std::sort(neighbors.begin(), neighbors.end());
00432 for(index_t i=0; i<knns_; i++) {
00433 neighbor_distances_[ind+i] = neighbors[i].first;
00434 neighbor_indices_[ind+i] = neighbors[i].second;
00435 }
00436 *min_dist_so_far=neighbor_distances_[ind+knns_-1];
00437 } else {
00438
00439 double left_distance = reference_node->left()->bound().MinDistanceSq(point);
00440 double right_distance = reference_node->right()->bound().MinDistanceSq(point);
00441
00442 if (left_distance < right_distance) {
00443 ComputeSingleNeighborsRecursion_(point_id, point, reference_node->left(),
00444 min_dist_so_far);
00445 if (*min_dist_so_far <right_distance){
00446 number_of_prunes_++;
00447 return;
00448 }
00449 ComputeSingleNeighborsRecursion_(point_id, point, reference_node->right(),
00450 min_dist_so_far);
00451 } else {
00452 ComputeSingleNeighborsRecursion_(point_id, point, reference_node->right(),
00453 min_dist_so_far);
00454 if (*min_dist_so_far <left_distance){
00455 number_of_prunes_++;
00456 return;
00457 }
00458 ComputeSingleNeighborsRecursion_(point_id, point, reference_node->left(),
00459 min_dist_so_far);
00460 }
00461 }
00462 }
00464
00470 void Init(const Matrix& queries_in, const Matrix& references_in, struct datanode* module_in) {
00471
00472
00473 module_ = module_in;
00474
00475
00476 number_of_prunes_ = 0;
00477
00478 mode_=fx_param_str(module_, "mode", "dual");
00479
00480 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00481
00482 DEBUG_ASSERT(leaf_size_ > 0);
00483
00484
00485 queries_.Copy(queries_in);
00486 references_.Copy(references_in);
00487
00488
00489 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00490
00491
00492 knns_ = fx_param_int(module_, "knns", 5);
00493
00494
00495 neighbor_indices_.Init(queries_.n_cols() * knns_);
00496
00497
00498 neighbor_distances_.Init(queries_.n_cols() * knns_);
00499 neighbor_distances_.SetAll(DBL_MAX);
00500
00501
00502 fx_timer_start(module_, "tree_building");
00503
00504
00505
00506
00507 if (mode_=="dual") {
00508 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, leaf_size_,
00509 &old_from_new_queries_, NULL);
00510 } else {
00511 query_tree_=NULL;
00512 }
00513 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00514 leaf_size_, &old_from_new_references_, NULL);
00515
00516
00517 fx_timer_stop(module_, "tree_building");
00518
00519 }
00520
00524 void Init(const Matrix& references_in, struct datanode* module_in) {
00525
00526
00527 module_ = module_in;
00528
00529 mode_=fx_param_str(module_, "mode", "dual");
00530
00531
00532 number_of_prunes_ = 0;
00533
00534
00535 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00536
00537 DEBUG_ASSERT(leaf_size_ > 0);
00538
00539
00540 references_.Copy(references_in);
00541 queries_.Alias(references_);
00542
00543 knns_ = fx_param_int(module_, "knns", 5);
00544
00545
00546 neighbor_indices_.Init(references_.n_cols() * knns_);
00547
00548
00549 neighbor_distances_.Init(references_.n_cols() * knns_);
00550 neighbor_distances_.SetAll(DBL_MAX);
00551
00552
00553 fx_timer_start(module_, "tree_building");
00554
00555
00556
00557
00558 query_tree_ = NULL;
00559 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00560 leaf_size_, &old_from_new_references_, NULL);
00561
00562
00563 fx_timer_stop(module_, "tree_building");
00564
00565 }
00566 void Init(const Matrix& queries_in, const Matrix& references_in,
00567 index_t leaf_size, index_t knns, const char *mode="dual") {
00568
00569
00570 module_ = NULL;
00571
00572
00573 number_of_prunes_ = 0;
00574 mode_=mode;
00575
00576
00577 leaf_size_ = leaf_size;
00578 DEBUG_ASSERT(leaf_size_ > 0);
00579
00580
00581 knns_ = knns;
00582 DEBUG_ASSERT(knns_ > 0);
00583
00584 queries_.Copy(queries_in);
00585 references_.Copy(references_in);
00586
00587
00588 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00589
00590
00591
00592 neighbor_indices_.Init(queries_.n_cols() * knns_);
00593
00594
00595 neighbor_distances_.Init(queries_.n_cols() * knns_);
00596 neighbor_distances_.SetAll(DBL_MAX);
00597
00598
00599
00600
00601
00602 if (mode_=="dual") {
00603 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, leaf_size_,
00604 &old_from_new_queries_, NULL);
00605 } else {
00606 query_tree_=NULL;
00607 }
00608 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00609 leaf_size_, &old_from_new_references_, NULL);
00610
00611 }
00612
00613 void Init(const Matrix& references_in, index_t leaf_size,
00614 index_t knns, const char *mode="dual") {
00615
00616 module_ = NULL;
00617
00618
00619 number_of_prunes_ = 0;
00620 mode_=mode;
00621
00622
00623 leaf_size_ = leaf_size;
00624 DEBUG_ASSERT(leaf_size_ > 0);
00625
00626
00627 knns_ = knns;
00628 DEBUG_ASSERT(knns_ > 0);
00629
00630 references_.Copy(references_in);
00631 queries_.Alias(references_);
00632
00633
00634 neighbor_indices_.Init(references_.n_cols() * knns_);
00635
00636
00637 neighbor_distances_.Init(references_.n_cols() * knns_);
00638 neighbor_distances_.SetAll(DBL_MAX);
00639
00640
00641
00642
00643
00644 query_tree_ = NULL;
00645 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00646 leaf_size_, &old_from_new_references_, NULL);
00647
00648 old_from_new_queries_.Init();
00649 }
00654 void Destruct() {
00655 if (query_tree_ != NULL) {
00656 delete query_tree_;
00657 }
00658 if (reference_tree_ != NULL) {
00659 delete reference_tree_;
00660 }
00661 queries_.Destruct();
00662 references_.Destruct();
00663 old_from_new_queries_.Renew();
00664 old_from_new_references_.Renew();
00665 neighbor_distances_.Destruct();
00666 neighbor_indices_.Renew();
00667 }
00668 void InitNaive(const Matrix& queries_in,
00669 const Matrix& references_in, index_t knns){
00670
00671 queries_.Copy(queries_in);
00672 references_.Copy(references_in);
00673 knns_=knns;
00674
00675 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00676
00677 neighbor_indices_.Init(queries_.n_cols()*knns_);
00678 neighbor_distances_.Init(queries_.n_cols()*knns_);
00679 neighbor_distances_.SetAll(DBL_MAX);
00680
00681
00682
00683 leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00684
00685 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_,
00686 leaf_size_, &old_from_new_queries_, NULL);
00687 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00688 references_, leaf_size_, &old_from_new_references_, NULL);
00689
00690 }
00691
00692 void InitNaive(const Matrix& references_in, index_t knns){
00693
00694 references_.Copy(references_in);
00695 queries_.Alias(references_);
00696 knns_=knns;
00697
00698 neighbor_indices_.Init(references_.n_cols()*knns_);
00699 neighbor_distances_.Init(references_.n_cols()*knns_);
00700 neighbor_distances_.SetAll(DBL_MAX);
00701
00702
00703
00704 leaf_size_ = references_.n_cols();
00705
00706 query_tree_ = NULL;
00707 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00708 references_, leaf_size_, &old_from_new_references_, NULL);
00709
00710 old_from_new_queries_.Init();
00711 }
00712
00716 void ComputeNeighbors(ArrayList<index_t>* resulting_neighbors,
00717 ArrayList<double>* distances) {
00718 fx_timer_start(module_, "computing_neighbors");
00719 if (mode_=="dual") {
00720
00721 if (query_tree_!=NULL) {
00722 ComputeDualNeighborsRecursion_(query_tree_, reference_tree_,
00723 MinNodeDistSq_(query_tree_, reference_tree_));
00724 } else {
00725 ComputeDualNeighborsRecursion_(reference_tree_, reference_tree_,
00726 MinNodeDistSq_(reference_tree_, reference_tree_));
00727 }
00728 } else {
00729 index_t chunk = queries_.n_cols()/10;
00730 printf("Progress:00%%");
00731 fflush(stdout);
00732 for(index_t i=0; i<10; i++) {
00733 for(index_t j=0; j<chunk; j++) {
00734 Vector point;
00735 point.Alias(queries_.GetColumnPtr(i*chunk+j), queries_.n_rows());
00736 double min_dist_so_far=DBL_MAX;
00737 ComputeSingleNeighborsRecursion_(i*chunk+j, point, reference_tree_, &min_dist_so_far);
00738 }
00739 printf("\b\b\b%02"LI"d%%", (i+1)*10);
00740 fflush(stdout);
00741 }
00742 for(index_t i=0; i<queries_.n_cols() % 10; i++) {
00743 index_t ind = (queries_.n_cols()/10)*10+i;
00744 Vector point;
00745 point.Alias(queries_.GetColumnPtr(ind), queries_.n_rows());
00746 double min_dist_so_far=DBL_MAX;
00747 ComputeSingleNeighborsRecursion_(i, point, reference_tree_, &min_dist_so_far);
00748 }
00749 printf("\n");
00750 }
00751 fx_timer_stop(module_, "computing_neighbors");
00752
00753 resulting_neighbors->Init(neighbor_indices_.size());
00754 distances->Init(neighbor_distances_.length());
00755
00756
00757 if (query_tree_ != NULL) {
00758 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00759 (*resulting_neighbors)[
00760 old_from_new_queries_[i/knns_]*knns_+ i%knns_] =
00761 old_from_new_references_[neighbor_indices_[i]];
00762 (*distances)[
00763 old_from_new_queries_[i/knns_]*knns_+ i%knns_] =
00764 neighbor_distances_[i];
00765 }
00766 } else {
00767 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00768 (*resulting_neighbors)[
00769 old_from_new_references_[i/knns_]*knns_+ i%knns_] =
00770 old_from_new_references_[neighbor_indices_[i]];
00771 (*distances)[
00772 old_from_new_references_[i/knns_]*knns_+ i%knns_] =
00773 neighbor_distances_[i];
00774 }
00775 }
00776 }
00777
00778
00782 void ComputeNaive(ArrayList<index_t>* resulting_neighbors,
00783 ArrayList<double>* distances) {
00784 if (query_tree_!=NULL) {
00785 ComputeBaseCase_(query_tree_, reference_tree_);
00786 } else {
00787 ComputeBaseCase_(reference_tree_, reference_tree_);
00788 }
00789
00790
00791 resulting_neighbors->Init(neighbor_indices_.size());
00792 distances->Init(neighbor_distances_.length());
00793
00794
00795 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00796 (*resulting_neighbors)[
00797 old_from_new_references_[i/knns_]*knns_+ i%knns_] =
00798 old_from_new_references_[neighbor_indices_[i]];
00799 (*distances)[
00800 old_from_new_references_[i/knns_]*knns_+ i%knns_] =
00801 neighbor_distances_[i];
00802
00803 }
00804 }
00805
00806 };
00807
00808
00809 #endif
00810