00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00039
00040 #ifndef ALLKFN_H
00041 #define ALLKFN_H
00042
00043
00044
00045
#include <algorithm>
#include <cfloat>
#include <functional>
#include <utility>
#include <vector>

#include <fastlib/fastlib.h>
00052 class TestAllkFN;
00057 class AllkFN {
00058
00059
00060 friend class TestAllkNFN;
00061
00063
00068 class QueryStat {
00069
00070
00071
00072 OT_DEF_BASIC(QueryStat) {
00073
00074
00075 OT_MY_OBJECT(min_distance_so_far_);
00076 }
00077
00078 private:
00079
00083 double min_distance_so_far_;
00084
00085 public:
00086
00087 double min_distance_so_far() {
00088 return min_distance_so_far_;
00089 }
00090
00091
00092 void set_min_distance_so_far(double new_dist) {
00093 min_distance_so_far_ = new_dist;
00094 }
00095
00096
00097
00098
00099
00105 void Init(const Matrix& matrix, index_t start, index_t count) {
00106
00107 min_distance_so_far_ = 0;
00108 }
00109
00114 void Init(const Matrix& matrix, index_t start, index_t count,
00115 const QueryStat& left, const QueryStat& right) {
00116
00117 Init(matrix, start, count);
00118 }
00119
00120 };
00121
00122
00123
00124
00125 typedef BinarySpaceTree<DHrectBound<2>, Matrix, QueryStat> TreeType;
00126
00127
00129 private:
00130
00131 Matrix queries_;
00132 Matrix references_;
00133
00134 TreeType* query_tree_;
00135 TreeType* reference_tree_;
00136
00137 index_t number_of_prunes_;
00138
00139 ArrayList<index_t> old_from_new_queries_;
00140 ArrayList<index_t> old_from_new_references_;
00141
00142 index_t leaf_size_;
00143
00144 Vector neighbor_distances_;
00145
00146 ArrayList<index_t> neighbor_indices_;
00147
00148 index_t kfns_;
00149
00150 struct datanode* module_;
00151
00152
00154
00155
00156 FORBID_ACCIDENTAL_COPIES(AllkFN);
00157
00158
00159 public:
00160
00165 AllkFN() {
00166 query_tree_ = NULL;
00167 reference_tree_ = NULL;
00168 }
00169
00173 ~AllkFN() {
00174 if (query_tree_ != NULL) {
00175 delete query_tree_;
00176 }
00177 if (reference_tree_ != NULL) {
00178 delete reference_tree_;
00179 }
00180 }
00181
00182
00184
00188 double MaxNodeDistSq_ (TreeType* query_node, TreeType* reference_node) {
00189
00190
00191 return query_node->bound().MaxDistanceSq(reference_node->bound());
00192 }
00193
00194
00198 void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node) {
00199
00200
00201
00202
00203
00204 DEBUG_ASSERT(query_node != NULL);
00205 DEBUG_ASSERT(reference_node != NULL);
00206
00207 DEBUG_WARN_IF(!query_node->is_leaf());
00208 DEBUG_WARN_IF(!reference_node->is_leaf());
00209
00210
00211 double query_min_neighbor_distance = DBL_MAX;
00212 std::vector<std::pair<double, index_t> > neighbors(kfns_);
00213
00214
00215 for (index_t query_index = query_node->begin();
00216 query_index < query_node->end(); query_index++) {
00217
00218
00219 Vector query_point;
00220 queries_.MakeColumnVector(query_index, &query_point);
00221
00222 index_t ind = query_index*kfns_;
00223 for(index_t i=0; i<kfns_; i++) {
00224 neighbors[i]=std::make_pair(neighbor_distances_[ind+i],
00225 neighbor_indices_[ind+i]);
00226 }
00227
00228 for (index_t reference_index = reference_node->begin();
00229 reference_index < reference_node->end(); reference_index++) {
00230
00231
00232
00233 if (likely(reference_node != query_node ||
00234 reference_index != query_index)) {
00235 Vector reference_point;
00236 references_.MakeColumnVector(reference_index, &reference_point);
00237
00238 double distance =
00239 la::DistanceSqEuclidean(query_point, reference_point);
00240
00241
00242 if (distance > neighbor_distances_[ind+kfns_-1]) {
00243 neighbors.push_back(std::make_pair(distance, reference_index));
00244 }
00245 }
00246 }
00247 std::sort(neighbors.begin(), neighbors.end(),
00248 std::greater<std::pair<double, index_t> >());
00249 for(index_t i=0; i<kfns_; i++) {
00250 neighbor_distances_[ind+i] = neighbors[i].first;
00251 neighbor_indices_[ind+i] = neighbors[i].second;
00252 }
00253 neighbors.resize(kfns_);
00254
00255 if (neighbor_distances_[ind+kfns_-1] < query_min_neighbor_distance) {
00256 query_min_neighbor_distance = neighbor_distances_[ind+kfns_-1];
00257 }
00258 }
00259
00260
00261 query_node->stat().set_min_distance_so_far(query_min_neighbor_distance);
00262 }
00263
00264
00268 void ComputeNeighborsRecursion_ (TreeType* query_node,
00269 TreeType* reference_node,
00270 double higher_bound_distance) {
00271
00272
00273
00274
00275
00276 DEBUG_ASSERT(query_node != NULL);
00277
00278 DEBUG_ASSERT_MSG(reference_node != NULL, "reference node is null");
00279
00280 DEBUG_ASSERT(higher_bound_distance == MaxNodeDistSq_(query_node,
00281 reference_node));
00282
00283 if (higher_bound_distance < query_node->stat().min_distance_so_far()) {
00284
00285 number_of_prunes_++;
00286 }
00287
00288 else if (query_node->is_leaf() && reference_node->is_leaf()) {
00289
00290 ComputeBaseCase_(query_node, reference_node);
00291 }
00292 else if (query_node->is_leaf()) {
00293
00294
00295
00296 double left_distance = MaxNodeDistSq_(query_node, reference_node->left());
00297 double right_distance = MaxNodeDistSq_(query_node, reference_node->right());
00298
00299 if (left_distance > right_distance) {
00300 ComputeNeighborsRecursion_(query_node, reference_node->left(),
00301 left_distance);
00302 ComputeNeighborsRecursion_(query_node, reference_node->right(),
00303 right_distance);
00304 }
00305 else {
00306 ComputeNeighborsRecursion_(query_node, reference_node->right(),
00307 right_distance);
00308 ComputeNeighborsRecursion_(query_node, reference_node->left(),
00309 left_distance);
00310 }
00311
00312 }
00313
00314 else if (reference_node->is_leaf()) {
00315
00316
00317 double left_distance = MaxNodeDistSq_(query_node->left(), reference_node);
00318 double right_distance = MaxNodeDistSq_(query_node->right(), reference_node);
00319
00320 ComputeNeighborsRecursion_(query_node->left(), reference_node,
00321 left_distance);
00322 ComputeNeighborsRecursion_(query_node->right(), reference_node,
00323 right_distance);
00324
00325
00326
00327 query_node->stat().set_min_distance_so_far(
00328 min(query_node->left()->stat().min_distance_so_far(),
00329 query_node->right()->stat().min_distance_so_far()));
00330 } else {
00331
00332
00333 double left_distance = MaxNodeDistSq_(query_node->left(),
00334 reference_node->left());
00335 double right_distance = MaxNodeDistSq_(query_node->left(),
00336 reference_node->right());
00337
00338 if (left_distance > right_distance) {
00339 ComputeNeighborsRecursion_(query_node->left(), reference_node->left(),
00340 left_distance);
00341 ComputeNeighborsRecursion_(query_node->left(), reference_node->right(),
00342 right_distance);
00343 }
00344 else {
00345 ComputeNeighborsRecursion_(query_node->left(), reference_node->right(),
00346 right_distance);
00347 ComputeNeighborsRecursion_(query_node->left(), reference_node->left(),
00348 left_distance);
00349 }
00350 left_distance = MaxNodeDistSq_(query_node->right(), reference_node->left());
00351 right_distance = MaxNodeDistSq_(query_node->right(),
00352 reference_node->right());
00353
00354 if (left_distance > right_distance) {
00355 ComputeNeighborsRecursion_(query_node->right(), reference_node->left(),
00356 left_distance);
00357 ComputeNeighborsRecursion_(query_node->right(), reference_node->right(),
00358 right_distance);
00359 }
00360 else {
00361 ComputeNeighborsRecursion_(query_node->right(), reference_node->right(),
00362 right_distance);
00363 ComputeNeighborsRecursion_(query_node->right(), reference_node->left(),
00364 left_distance);
00365 }
00366
00367
00368 query_node->stat().set_min_distance_so_far(
00369 min(query_node->left()->stat().min_distance_so_far(),
00370 query_node->right()->stat().min_distance_so_far()));
00371
00372 }
00373
00374 }
00375
00376
00377
00378
00380
00386 void Init(const Matrix& queries_in, const Matrix& references_in, struct datanode* module_in) {
00387
00388
00389 module_ = module_in;
00390
00391
00392 number_of_prunes_ = 0;
00393
00394
00395 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00396
00397 DEBUG_ASSERT(leaf_size_ > 0);
00398
00399
00400 queries_.Copy(queries_in);
00401 references_.Copy(references_in);
00402
00403
00404 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00405
00406
00407 kfns_ = fx_param_int(module_, "kfns", 1);
00408
00409
00410 neighbor_indices_.Init(queries_.n_cols() * kfns_);
00411
00412
00413 neighbor_distances_.Init(queries_.n_cols() * kfns_);
00414 neighbor_distances_.SetAll(0);
00415
00416
00417 fx_timer_start(module_, "tree_building");
00418
00419
00420
00421
00422 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, leaf_size_,
00423 &old_from_new_queries_, NULL);
00424 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00425 leaf_size_, &old_from_new_references_, NULL);
00426
00427
00428 fx_timer_stop(module_, "tree_building");
00429
00430 }
00431
00435 void Init(const Matrix& references_in, struct datanode* module_in) {
00436
00437
00438 module_ = module_in;
00439
00440
00441 number_of_prunes_ = 0;
00442
00443
00444 leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00445
00446 DEBUG_ASSERT(leaf_size_ > 0);
00447
00448
00449 references_.Copy(references_in);
00450 queries_.Alias(references_);
00451
00452 kfns_ = fx_param_int(module_, "kfns", 1);
00453
00454
00455 neighbor_indices_.Init(references_.n_cols() * kfns_);
00456
00457
00458 neighbor_distances_.Init(references_.n_cols() * kfns_);
00459 neighbor_distances_.SetAll(0.0);
00460
00461
00462 fx_timer_start(module_, "tree_building");
00463
00464
00465
00466
00467 query_tree_ = NULL;
00468 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00469 leaf_size_, &old_from_new_references_, NULL);
00470
00471
00472 fx_timer_stop(module_, "tree_building");
00473
00474 }
00475 void Init(const Matrix& queries_in, const Matrix& references_in,
00476 index_t leaf_size, index_t kfns) {
00477
00478
00479 number_of_prunes_ = 0;
00480
00481
00482 leaf_size_ = leaf_size;
00483 DEBUG_ASSERT(leaf_size_ > 0);
00484
00485
00486 kfns_ = kfns;
00487 DEBUG_ASSERT(kfns_ > 0);
00488
00489 queries_.Copy(queries_in);
00490 references_.Copy(references_in);
00491
00492
00493 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00494
00495
00496
00497 neighbor_indices_.Init(queries_.n_cols() * kfns_);
00498
00499
00500 neighbor_distances_.Init(queries_.n_cols() * kfns_);
00501 neighbor_distances_.SetAll(0);
00502
00503
00504
00505
00506
00507 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, leaf_size_,
00508 &old_from_new_queries_, NULL);
00509 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00510 leaf_size_, &old_from_new_references_, NULL);
00511
00512 }
00513
00514 void Init(const Matrix& references_in, index_t leaf_size, index_t kfns) {
00515
00516 number_of_prunes_ = 0;
00517
00518
00519 leaf_size_ = leaf_size;
00520 DEBUG_ASSERT(leaf_size_ > 0);
00521
00522
00523 kfns_ = kfns;
00524 DEBUG_ASSERT(kfns_ > 0);
00525
00526 references_.Copy(references_in);
00527 queries_.Alias(references_);
00528
00529
00530 neighbor_indices_.Init(references_.n_cols() * kfns_);
00531
00532
00533 neighbor_distances_.Init(references_.n_cols() * kfns_);
00534 neighbor_distances_.SetAll(0.0);
00535
00536
00537
00538
00539
00540 query_tree_ = NULL;
00541 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_,
00542 leaf_size_, &old_from_new_references_, NULL);
00543
00544 old_from_new_queries_.Init();
00545 }
00546
00547 void Destruct() {
00548 queries_.Destruct();
00549 references_.Destruct();
00550 old_from_new_queries_.Renew();
00551 old_from_new_references_.Renew();
00552 neighbor_distances_.Destruct();
00553 neighbor_indices_.Renew();
00554 if (query_tree_ != NULL) {
00555 delete query_tree_;
00556 query_tree_=NULL;
00557 }
00558 if (reference_tree_ != NULL) {
00559 delete reference_tree_;
00560 reference_tree_=NULL;
00561 }
00562 }
00563
00568 void InitNaive(const Matrix& queries_in,
00569 const Matrix& references_in, index_t kfns){
00570
00571 queries_.Copy(queries_in);
00572 references_.Copy(references_in);
00573 kfns_=kfns;
00574
00575 DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00576
00577 neighbor_indices_.Init(queries_.n_cols()*kfns_);
00578 neighbor_distances_.Init(queries_.n_cols()*kfns_);
00579 neighbor_distances_.SetAll(0.0);
00580
00581
00582
00583 leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00584
00585 query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_,
00586 leaf_size_, &old_from_new_queries_, NULL);
00587 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00588 references_, leaf_size_, &old_from_new_references_, NULL);
00589
00590 }
00591
00592 void InitNaive(const Matrix& references_in, index_t kfns){
00593
00594 references_.Copy(references_in);
00595 queries_.Alias(references_);
00596 kfns_=kfns;
00597
00598 neighbor_indices_.Init(references_.n_cols()*kfns_);
00599 neighbor_distances_.Init(references_.n_cols()*kfns_);
00600 neighbor_distances_.SetAll(0.0);
00601
00602
00603
00604 leaf_size_ = references_.n_cols();
00605
00606 query_tree_ = NULL;
00607 reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00608 references_, leaf_size_, &old_from_new_references_, NULL);
00609
00610 old_from_new_queries_.Init();
00611 }
00612
00616 void ComputeNeighbors(ArrayList<index_t>* resulting_neighbors,
00617 ArrayList<double>* distances) {
00618
00619
00620 if (query_tree_!=NULL) {
00621 ComputeNeighborsRecursion_(query_tree_, reference_tree_,
00622 MaxNodeDistSq_(query_tree_, reference_tree_));
00623 } else {
00624 ComputeNeighborsRecursion_(reference_tree_, reference_tree_,
00625 MaxNodeDistSq_(reference_tree_, reference_tree_));
00626 }
00627
00628
00629 resulting_neighbors->Init(neighbor_indices_.size());
00630 distances->Init(neighbor_distances_.length());
00631
00632
00633 if (query_tree_ != NULL) {
00634 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00635 (*resulting_neighbors)[
00636 old_from_new_queries_[i/kfns_]*kfns_+ i%kfns_] =
00637 old_from_new_references_[neighbor_indices_[i]];
00638 (*distances)[
00639 old_from_new_queries_[i/kfns_]*kfns_+ i%kfns_] =
00640 neighbor_distances_[i];
00641 }
00642 } else {
00643 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00644 (*resulting_neighbors)[
00645 old_from_new_references_[i/kfns_]*kfns_+ i%kfns_] =
00646 old_from_new_references_[neighbor_indices_[i]];
00647 (*distances)[
00648 old_from_new_references_[i/kfns_]*kfns_+ i%kfns_] =
00649 neighbor_distances_[i];
00650 }
00651 }
00652 }
00653
00654
00658 void ComputeNaive(ArrayList<index_t>* resulting_neighbors,
00659 ArrayList<double>* distances) {
00660 if (query_tree_!=NULL) {
00661 ComputeBaseCase_(query_tree_, reference_tree_);
00662 } else {
00663 ComputeBaseCase_(reference_tree_, reference_tree_);
00664 }
00665
00666
00667 resulting_neighbors->Init(neighbor_indices_.size());
00668 distances->Init(neighbor_distances_.length());
00669
00670
00671 for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00672 (*resulting_neighbors)[
00673 old_from_new_references_[i/kfns_]*kfns_+ i%kfns_] =
00674 old_from_new_references_[neighbor_indices_[i]];
00675 (*distances)[
00676 old_from_new_references_[i/kfns_]*kfns_+ i%kfns_] =
00677 neighbor_distances_[i];
00678
00679 }
00680 }
00681
00682 };
00683
00684
00685 #endif
00686