allknn.h

Go to the documentation of this file.
00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00039 // inclusion guards, please add them to your .h files
00040 #ifndef ALLKNN_H
00041 #define ALLKNN_H
00042 
00043 // We need to include fastlib.  If you want to use fastlib, 
00044 // you need to have this line in addition to
00045 // the deplibs section of your build.py
00046 #include <fastlib/fastlib.h>
00047 #include <vector>
00048 #include <string>
00052 class TestAllkNN;
00057 class AllkNN {
00058   // Declare the tester class as a friend class so that it has access
00059   // to the private members of the class
00060   friend class TestAllkNN;
00061   
00063 
00068   class QueryStat {
00069     
00070     // Defines many useful things for a class, including a pretty 
00071     // printer and copy constructor
00072     OT_DEF_BASIC(QueryStat) {
00073       // Include this line for all non-pointer members
00074       // There are other versions for arrays and pointers, see base/otrav.h
00075       OT_MY_OBJECT(max_distance_so_far_); 
00076     } // OT_DEF_BASIC
00077     
00078    private:
00079     
00083     double max_distance_so_far_;
00084     
00085    public:
00086     
00087     double max_distance_so_far() {
00088       return max_distance_so_far_; 
00089     } 
00090     
00091     
00092     void set_max_distance_so_far(double new_dist) {
00093       max_distance_so_far_ = new_dist; 
00094     } 
00095     
00096     // In addition to any member variables for the statistic, all stat 
00097     // classes need two Init 
00098     // functions, one for leaves and one for non-leaves. 
00099     
00105     void Init(const Matrix& matrix, index_t start, index_t count) {
00106       // The bound starts at infinity
00107       max_distance_so_far_ = DBL_MAX;
00108     } 
00109     
00114     void Init(const Matrix& matrix, index_t start, index_t count, 
00115         const QueryStat& left, const QueryStat& right) {
00116       // For allnn, non-leaves can be initialized in the same way as leaves
00117       Init(matrix, start, count);
00118     } 
00119     
00120   }; //class AllNNStat  
00121   
00122   // TreeType are BinarySpaceTrees where the data are bounded by 
00123   // Euclidean bounding boxes, the data are stored in a Matrix, 
00124   // and each node has a QueryStat for its bound.
00125   typedef BinarySpaceTree<DHrectBound<2>, Matrix, QueryStat> TreeType;
00126    
00127   
00129  private:
00130   // These will store our data sets.
00131   Matrix queries_;
00132   Matrix references_;
00133   // Pointers to the roots of the two trees.
00134   TreeType* query_tree_;
00135   TreeType* reference_tree_;
00136   // The total number of prunes.
00137   index_t number_of_prunes_;
00138   // A permutation of the indices for tree building.
00139   ArrayList<index_t> old_from_new_queries_;
00140   ArrayList<index_t> old_from_new_references_;
00141   // The number of points in a leaf
00142   index_t leaf_size_;
00143   // The distance to the candidate nearest neighbor for each query
00144   Vector neighbor_distances_;
00145   // The indices of the candidate nearest neighbor for each query
00146   ArrayList<index_t> neighbor_indices_;
00147   // Number of nearest neighbors to find for each query point
00148   index_t knns_; 
00149   // if this flag is true then only the k-neighbor and distance are computed
00150    bool k_only_;
00151    // This can be either "single" or "dual" referring to dual tree and single tree algorithm
00152    std::string mode_;
00153   // The module containing the parameters for this computation. 
00154   struct datanode* module_;
00155   
00156   
00158   
00159   // Add this at the beginning of a class to prevent accidentally calling the copy constructor
00160   FORBID_ACCIDENTAL_COPIES(AllkNN);
00161   
00162 
00163  public:
00164 
00169   AllkNN() {
00170     query_tree_ = NULL;
00171     reference_tree_ = NULL;
00172   } 
00173   
00177   ~AllkNN() {
00178     if (query_tree_ != NULL) {
00179       delete query_tree_;
00180     }
00181     if (reference_tree_ != NULL) {
00182       delete reference_tree_;
00183     }
00184   } 
00185     
00186       
00188   
00192   double MinNodeDistSq_ (TreeType* query_node, TreeType* reference_node) {
00193     // node->bound() gives us the DHrectBound class for the node
00194     // It has a function MinDistanceSq which takes another DHrectBound
00195     return query_node->bound().MinDistanceSq(reference_node->bound());
00196   } 
00197 
00201   double MinPointNodeDistSq_ (const Vector& query_point, TreeType* reference_node) {
00202     // node->bound() gives us the DHrectBound class for the node
00203     // It has a function MinDistanceSq which takes another DHrectBound
00204     return reference_node->bound().MinDistanceSq(query_point);
00205   } 
00206   
00207   
00211   void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node) {
00212    
00213     // DEBUG statements should be used frequently, since they incur no overhead
00214     // when compiled in fast mode
00215     
00216     // Check that the pointers are not NULL
00217     DEBUG_ASSERT(query_node != NULL);
00218     DEBUG_ASSERT(reference_node != NULL);
00219     // Check that we really should be in the base case
00220     DEBUG_WARN_IF(!query_node->is_leaf());
00221     DEBUG_WARN_IF(!reference_node->is_leaf());
00222     
00223     // Used to find the query node's new upper bound
00224     double query_max_neighbor_distance = -1.0;
00225     std::vector<std::pair<double, index_t> > neighbors(knns_);
00226     // node->begin() is the index of the first point in the node, 
00227     // node->end is one past the last index
00228     for (index_t query_index = query_node->begin(); 
00229          query_index < query_node->end(); query_index++) {
00230        
00231       // Get the query point from the matrix
00232       Vector query_point;
00233       queries_.MakeColumnVector(query_index, &query_point);
00234       
00235       index_t ind = query_index*knns_;
00236       for(index_t i=0; i<knns_; i++) {
00237         neighbors[i]=std::make_pair(neighbor_distances_[ind+i],
00238                                     neighbor_indices_[ind+i]);
00239       }
00240 
00241       double query_to_node_distance =
00242         MinPointNodeDistSq_(query_point, reference_node);
00243       if (query_to_node_distance < neighbor_distances_[ind+knns_-1]) {
00244         // We'll do the same for the references
00245         for (index_t reference_index = reference_node->begin(); 
00246              reference_index < reference_node->end(); reference_index++) {
00247         
00248           // Confirm that points do not identify themselves as neighbors
00249           // in the monochromatic case
00250           if (likely(reference_node != query_node ||
00251                      reference_index != query_index)) {
00252             Vector reference_point;
00253             references_.MakeColumnVector(reference_index, &reference_point);
00254             // We'll use lapack to find the distance between the two vectors
00255             double distance =
00256               la::DistanceSqEuclidean(query_point, reference_point);
00257             // If the reference point is closer than the current candidate, 
00258             // we'll update the candidate
00259             if (distance < neighbor_distances_[ind+knns_-1]) {
00260               neighbors.push_back(std::make_pair(distance, reference_index));
00261             }
00262           }
00263         } // for reference_index
00264         // if ((index_t)neighbors.size()>knns_) {
00265         std::sort(neighbors.begin(), neighbors.end());
00266         for(index_t i=0; i<knns_; i++) {
00267           neighbor_distances_[ind+i] = neighbors[i].first;
00268           neighbor_indices_[ind+i]  = neighbors[i].second;
00269         }
00270         neighbors.resize(knns_);
00271       }
00272       // We need to find the upper bound distance for this query node
00273       if (neighbor_distances_[ind+knns_-1] > query_max_neighbor_distance) {
00274         query_max_neighbor_distance = neighbor_distances_[ind+knns_-1]; 
00275       }
00276       
00277     } // for query_index 
00278     // Update the upper bound for the query_node
00279     query_node->stat().set_max_distance_so_far(query_max_neighbor_distance);
00280          
00281   } // ComputeBaseCase_
00282   
00283   
00287   void ComputeDualNeighborsRecursion_(TreeType* query_node, TreeType* reference_node, 
00288       double lower_bound_distance) {
00289    
00290     // DEBUG statements should be used frequently, 
00291     // either with or without messages 
00292     
00293     // A DEBUG statement with no predefined message
00294     DEBUG_ASSERT(query_node != NULL);
00295     // A DEBUG statement with a predefined message
00296     DEBUG_ASSERT_MSG(reference_node != NULL, "reference node is null");
00297     // Make sure the bounding information is correct
00298     DEBUG_ASSERT(lower_bound_distance == MinNodeDistSq_(query_node, 
00299         reference_node));
00300     
00301     if (lower_bound_distance > query_node->stat().max_distance_so_far()) {
00302       // Pruned by distance
00303       number_of_prunes_++;
00304     }
00305     // node->is_leaf() works as one would expect
00306     else if (query_node->is_leaf() && reference_node->is_leaf()) {
00307       // Base Case
00308       ComputeBaseCase_(query_node, reference_node);
00309     }
00310     else if (query_node->is_leaf()) {
00311       // Only query is a leaf
00312       
00313       // We'll order the computation by distance 
00314       double left_distance = MinNodeDistSq_(query_node, reference_node->left());
00315       double right_distance = MinNodeDistSq_(query_node, reference_node->right());
00316       
00317       if (left_distance < right_distance) {
00318         ComputeDualNeighborsRecursion_(query_node, reference_node->left(), 
00319             left_distance);
00320         ComputeDualNeighborsRecursion_(query_node, reference_node->right(), 
00321             right_distance);
00322       }
00323       else {
00324         ComputeDualNeighborsRecursion_(query_node, reference_node->right(), 
00325             right_distance);
00326         ComputeDualNeighborsRecursion_(query_node, reference_node->left(), 
00327             left_distance);
00328       }
00329       
00330     }
00331     
00332     else if (reference_node->is_leaf()) {
00333       // Only reference is a leaf 
00334       
00335       double left_distance = MinNodeDistSq_(query_node->left(), reference_node);
00336       double right_distance = MinNodeDistSq_(query_node->right(), reference_node);
00337       
00338       ComputeDualNeighborsRecursion_(query_node->left(), reference_node, 
00339           left_distance);
00340       ComputeDualNeighborsRecursion_(query_node->right(), reference_node, 
00341           right_distance);
00342       
00343       // We need to update the upper bound based on the new upper bounds of 
00344       // the children
00345       query_node->stat().set_max_distance_so_far(
00346           max(query_node->left()->stat().max_distance_so_far(),
00347               query_node->right()->stat().max_distance_so_far()));
00348     } else {
00349       // Recurse on both as above
00350       
00351       double left_distance = MinNodeDistSq_(query_node->left(), 
00352           reference_node->left());
00353       double right_distance = MinNodeDistSq_(query_node->left(), 
00354           reference_node->right());
00355       
00356       if (left_distance < right_distance) {
00357         ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(), 
00358             left_distance);
00359         ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(), 
00360             right_distance);
00361       }
00362       else {
00363         ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(), 
00364             right_distance);
00365         ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(), 
00366             left_distance);
00367       }
00368       left_distance = MinNodeDistSq_(query_node->right(), reference_node->left());
00369       right_distance = MinNodeDistSq_(query_node->right(), 
00370           reference_node->right());
00371       
00372       if (left_distance < right_distance) {
00373         ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(), 
00374             left_distance);
00375         ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(), 
00376             right_distance);
00377       }
00378       else {
00379         ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(), 
00380             right_distance);
00381         ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(), 
00382             left_distance);
00383       }
00384       
00385       // Update the upper bound as above
00386       query_node->stat().set_max_distance_so_far(
00387           max(query_node->left()->stat().max_distance_so_far(),
00388               query_node->right()->stat().max_distance_so_far()));
00389       
00390     }
00391     
00392   } // ComputeDualNeighborsRecursion_
00393   
00394   
00395   void ComputeSingleNeighborsRecursion_(index_t point_id, 
00396       Vector &point, TreeType* reference_node, 
00397       double *min_dist_so_far) {
00398      
00399     // A DEBUG statement with a predefined message
00400     DEBUG_ASSERT_MSG(reference_node != NULL, "reference node is null");
00401     // Make sure the bounding information is correct
00402     
00403     // node->is_leaf() works as one would expect
00404     if (reference_node->is_leaf()) {
00405       // Base Case
00406       std::vector<std::pair<double, index_t> > neighbors(knns_);
00407       index_t ind = point_id*knns_;
00408       for(index_t i=0; i<knns_; i++) {
00409         neighbors[i]=std::make_pair(neighbor_distances_[ind+i],
00410                                     neighbor_indices_[ind+i]);
00411       }
00412       // We'll do the same for the references
00413       for (index_t reference_index = reference_node->begin(); 
00414            reference_index < reference_node->end(); reference_index++) {
00415               // Confirm that points do not identify themselves as neighbors
00416               // in the monochromatic case
00417         if (likely(!(references_.ptr()==queries_.ptr() &&
00418                       reference_index == point_id))) {
00419                 Vector reference_point;
00420                 references_.MakeColumnVector(reference_index, &reference_point);
00421                 // We'll use lapack to find the distance between the two vectors
00422                 double distance =
00423                 la::DistanceSqEuclidean(point, reference_point);
00424                 // If the reference point is closer than the current candidate, 
00425                 // we'll update the candidate
00426                 if (distance < neighbor_distances_[ind+knns_-1]) {
00427                   neighbors.push_back(std::make_pair(distance, reference_index));
00428                 }
00429               }
00430       } // for reference_index
00431       std::sort(neighbors.begin(), neighbors.end());
00432       for(index_t i=0; i<knns_; i++) {
00433         neighbor_distances_[ind+i] = neighbors[i].first;
00434         neighbor_indices_[ind+i]  = neighbors[i].second;
00435       }
00436       *min_dist_so_far=neighbor_distances_[ind+knns_-1];
00437     } else {
00438       // We'll order the computation by distance 
00439       double left_distance = reference_node->left()->bound().MinDistanceSq(point);
00440       double right_distance = reference_node->right()->bound().MinDistanceSq(point);
00441       
00442       if (left_distance < right_distance) {
00443         ComputeSingleNeighborsRecursion_(point_id, point, reference_node->left(), 
00444             min_dist_so_far);
00445         if (*min_dist_so_far <right_distance){
00446           number_of_prunes_++;
00447           return;
00448         }
00449         ComputeSingleNeighborsRecursion_(point_id, point, reference_node->right(), 
00450             min_dist_so_far);
00451       } else {
00452         ComputeSingleNeighborsRecursion_(point_id, point, reference_node->right(), 
00453             min_dist_so_far);
00454         if (*min_dist_so_far <left_distance){
00455           number_of_prunes_++;
00456           return;
00457         }
00458         ComputeSingleNeighborsRecursion_(point_id, point, reference_node->left(), 
00459             min_dist_so_far);
00460       }
00461     }    
00462   }  
00464   
00470    void Init(const Matrix& queries_in, const Matrix& references_in, struct datanode* module_in) {
00471     
00472     // set the module
00473     module_ = module_in;
00474     
00475     // track the number of prunes
00476     number_of_prunes_ = 0;
00477     
00478     mode_=fx_param_str(module_, "mode", "dual"); 
00479     // Get the leaf size from the module
00480     leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00481     // Make sure the leaf size is valid
00482     DEBUG_ASSERT(leaf_size_ > 0);
00483     
00484     // Copy the matrices to the class members since they will be rearranged.  
00485     queries_.Copy(queries_in);
00486     references_.Copy(references_in);
00487     
00488     // The data sets need to have the same number of points
00489     DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00490     
00491                 // K-nearest neighbors initialization
00492                 knns_ = fx_param_int(module_, "knns", 5);
00493   
00494     // Initialize the list of nearest neighbor candidates
00495     neighbor_indices_.Init(queries_.n_cols() * knns_);
00496     
00497                 // Initialize the vector of upper bounds for each point.  
00498     neighbor_distances_.Init(queries_.n_cols() * knns_);
00499     neighbor_distances_.SetAll(DBL_MAX);
00500 
00501     // We'll time tree building
00502     fx_timer_start(module_, "tree_building");
00503 
00504     // This call makes each tree from a matrix, leaf size, and two arrays 
00505                 // that record the permutation of the data points
00506     // Instead of NULL, it is possible to specify an array new_from_old_
00507     if (mode_=="dual") {
00508       query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, leaf_size_, 
00509                                 &old_from_new_queries_, NULL);
00510     } else {
00511       query_tree_=NULL;
00512     }
00513     reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_, 
00514                                 leaf_size_, &old_from_new_references_, NULL);
00515     
00516     // Stop the timer we started above
00517     fx_timer_stop(module_, "tree_building");
00518 
00519   } // Init
00520 
00524   void Init(const Matrix& references_in, struct datanode* module_in) {
00525      
00526     // set the module
00527     module_ = module_in;
00528   
00529     mode_=fx_param_str(module_, "mode", "dual"); 
00530    
00531     // track the number of prunes
00532     number_of_prunes_ = 0;
00533     
00534     // Get the leaf size from the module
00535     leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00536     // Make sure the leaf size is valid
00537     DEBUG_ASSERT(leaf_size_ > 0);
00538     
00539     // Copy the matrices to the class members since they will be rearranged.  
00540     references_.Copy(references_in);
00541     queries_.Alias(references_);    
00542                 // K-nearest neighbors initialization
00543                 knns_ = fx_param_int(module_, "knns", 5);
00544   
00545     // Initialize the list of nearest neighbor candidates
00546     neighbor_indices_.Init(references_.n_cols() * knns_);
00547     
00548                 // Initialize the vector of upper bounds for each point.  
00549     neighbor_distances_.Init(references_.n_cols() * knns_);
00550     neighbor_distances_.SetAll(DBL_MAX);
00551 
00552     // We'll time tree building
00553     fx_timer_start(module_, "tree_building");
00554 
00555     // This call makes each tree from a matrix, leaf size, and two arrays 
00556                 // that record the permutation of the data points
00557     // Instead of NULL, it is possible to specify an array new_from_old_
00558     query_tree_ = NULL;
00559     reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_, 
00560                                 leaf_size_, &old_from_new_references_, NULL);
00561     
00562     // Stop the timer we started above
00563     fx_timer_stop(module_, "tree_building");
00564 
00565   }
00566   void Init(const Matrix& queries_in, const Matrix& references_in, 
00567       index_t leaf_size, index_t knns, const char *mode="dual") {
00568     
00569     // set the module even though we're getting the params elsewhere
00570     module_ = NULL;
00571 
00572     // track the number of prunes
00573     number_of_prunes_ = 0;
00574     mode_=mode;
00575      
00576     // Make sure the leaf size is valid
00577     leaf_size_ = leaf_size;
00578     DEBUG_ASSERT(leaf_size_ > 0);
00579     
00580     // Make sure the knns is valid
00581     knns_ = knns;
00582     DEBUG_ASSERT(knns_ > 0);
00583     // Copy the matrices to the class members since they will be rearranged.  
00584     queries_.Copy(queries_in);
00585     references_.Copy(references_in);
00586     
00587     // The data sets need to have the same number of points
00588     DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00589     
00590   
00591     // Initialize the list of nearest neighbor candidates
00592     neighbor_indices_.Init(queries_.n_cols() * knns_);
00593     
00594     // Initialize the vector of upper bounds for each point.  
00595     neighbor_distances_.Init(queries_.n_cols() * knns_);
00596     neighbor_distances_.SetAll(DBL_MAX);
00597 
00598 
00599     // This call makes each tree from a matrix, leaf size, and two arrays 
00600     // that record the permutation of the data points
00601     // Instead of NULL, it is possible to specify an array new_from_old_
00602     if (mode_=="dual") {
00603       query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, leaf_size_, 
00604           &old_from_new_queries_, NULL);
00605     } else {
00606       query_tree_=NULL;
00607     }
00608     reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_, 
00609         leaf_size_, &old_from_new_references_, NULL);
00610 
00611   } // Init
00612 
00613   void Init(const Matrix& references_in, index_t leaf_size, 
00614             index_t knns, const char *mode="dual") {
00615     // set the module even though we're getting the params elsewhere
00616     module_ = NULL;
00617 
00618     // track the number of prunes
00619     number_of_prunes_ = 0;
00620     mode_=mode; 
00621      
00622     // Make sure the leaf size is valid
00623     leaf_size_ = leaf_size;
00624     DEBUG_ASSERT(leaf_size_ > 0);
00625     
00626     // Make sure the knns is valid
00627     knns_ = knns;
00628     DEBUG_ASSERT(knns_ > 0);
00629     // Copy the matrices to the class members since they will be rearranged.  
00630     references_.Copy(references_in);
00631     queries_.Alias(references_); 
00632   
00633     // Initialize the list of nearest neighbor candidates
00634     neighbor_indices_.Init(references_.n_cols() * knns_);
00635     
00636     // Initialize the vector of upper bounds for each point.  
00637     neighbor_distances_.Init(references_.n_cols() * knns_);
00638     neighbor_distances_.SetAll(DBL_MAX);
00639 
00640 
00641     // This call makes each tree from a matrix, leaf size, and two arrays 
00642     // that record the permutation of the data points
00643     // Instead of NULL, it is possible to specify an array new_from_old_
00644     query_tree_ = NULL;
00645     reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(references_, 
00646         leaf_size_, &old_from_new_references_, NULL);
00647    // This is an annoying feature of fastlib
00648     old_from_new_queries_.Init();
00649   }
00654   void Destruct() {
00655     if (query_tree_ != NULL) {
00656       delete query_tree_;
00657     }
00658     if (reference_tree_ != NULL) {
00659       delete reference_tree_;
00660     }
00661     queries_.Destruct();
00662     references_.Destruct();
00663     old_from_new_queries_.Renew();
00664     old_from_new_references_.Renew();
00665     neighbor_distances_.Destruct();
00666     neighbor_indices_.Renew();
00667   }
00668   void InitNaive(const Matrix& queries_in, 
00669       const Matrix& references_in, index_t knns){
00670     
00671     queries_.Copy(queries_in);
00672     references_.Copy(references_in);
00673     knns_=knns;
00674     
00675     DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00676     
00677     neighbor_indices_.Init(queries_.n_cols()*knns_);
00678     neighbor_distances_.Init(queries_.n_cols()*knns_);
00679     neighbor_distances_.SetAll(DBL_MAX);
00680     
00681     // The only difference is that we set leaf_size_ to be large enough 
00682     // that each tree has only one node
00683     leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00684     
00685     query_tree_ = tree::MakeKdTreeMidpoint<TreeType>(queries_, 
00686         leaf_size_, &old_from_new_queries_, NULL);
00687     reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00688         references_, leaf_size_, &old_from_new_references_, NULL);
00689         
00690   } // InitNaive
00691   
00692    void InitNaive(const Matrix& references_in, index_t knns){
00693     
00694     references_.Copy(references_in);
00695     queries_.Alias(references_);
00696     knns_=knns;
00697     
00698     neighbor_indices_.Init(references_.n_cols()*knns_);
00699     neighbor_distances_.Init(references_.n_cols()*knns_);
00700     neighbor_distances_.SetAll(DBL_MAX);
00701     
00702     // The only difference is that we set leaf_size_ to be large enough 
00703     // that each tree has only one node
00704     leaf_size_ = references_.n_cols();
00705     
00706     query_tree_ = NULL;
00707     reference_tree_ = tree::MakeKdTreeMidpoint<TreeType>(
00708         references_, leaf_size_, &old_from_new_references_, NULL);
00709     // This is an annoying feature of fastlib
00710     old_from_new_queries_.Init();
00711   } // InitNaive
00712   
00716   void ComputeNeighbors(ArrayList<index_t>* resulting_neighbors,
00717                         ArrayList<double>* distances) {
00718     fx_timer_start(module_, "computing_neighbors");
00719     if (mode_=="dual") {
00720       // Start on the root of each tree
00721       if (query_tree_!=NULL) {
00722         ComputeDualNeighborsRecursion_(query_tree_, reference_tree_, 
00723             MinNodeDistSq_(query_tree_, reference_tree_));
00724       } else {
00725          ComputeDualNeighborsRecursion_(reference_tree_, reference_tree_, 
00726             MinNodeDistSq_(reference_tree_, reference_tree_));
00727       }
00728     } else {
00729       index_t chunk = queries_.n_cols()/10;
00730       printf("Progress:00%%");
00731       fflush(stdout);
00732       for(index_t i=0; i<10; i++) {
00733         for(index_t j=0; j<chunk; j++) {
00734           Vector point;
00735           point.Alias(queries_.GetColumnPtr(i*chunk+j), queries_.n_rows());
00736           double min_dist_so_far=DBL_MAX;
00737           ComputeSingleNeighborsRecursion_(i*chunk+j, point, reference_tree_, &min_dist_so_far);
00738         }
00739         printf("\b\b\b%02"LI"d%%", (i+1)*10);  
00740         fflush(stdout);
00741       }
00742       for(index_t i=0; i<queries_.n_cols() % 10; i++) {
00743         index_t ind = (queries_.n_cols()/10)*10+i;
00744         Vector point;
00745         point.Alias(queries_.GetColumnPtr(ind), queries_.n_rows());
00746         double min_dist_so_far=DBL_MAX;
00747         ComputeSingleNeighborsRecursion_(i, point, reference_tree_, &min_dist_so_far);
00748       }
00749       printf("\n");
00750     }
00751     fx_timer_stop(module_, "computing_neighbors");
00752     // We need to initialize the results list before filling it
00753     resulting_neighbors->Init(neighbor_indices_.size());
00754     distances->Init(neighbor_distances_.length());
00755     // We need to map the indices back from how they have 
00756     // been permuted
00757     if (query_tree_ != NULL) {
00758       for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00759         (*resulting_neighbors)[
00760           old_from_new_queries_[i/knns_]*knns_+ i%knns_] = 
00761           old_from_new_references_[neighbor_indices_[i]];
00762         (*distances)[
00763           old_from_new_queries_[i/knns_]*knns_+ i%knns_] = 
00764           neighbor_distances_[i];
00765       }
00766     } else {
00767       for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00768         (*resulting_neighbors)[
00769           old_from_new_references_[i/knns_]*knns_+ i%knns_] = 
00770           old_from_new_references_[neighbor_indices_[i]];
00771         (*distances)[
00772           old_from_new_references_[i/knns_]*knns_+ i%knns_] = 
00773           neighbor_distances_[i];
00774       }
00775     }
00776   } // ComputeNeighbors
00777   
00778   
00782   void ComputeNaive(ArrayList<index_t>* resulting_neighbors,
00783                     ArrayList<double>*  distances) {
00784     if (query_tree_!=NULL) {
00785       ComputeBaseCase_(query_tree_, reference_tree_);
00786     } else {
00787        ComputeBaseCase_(reference_tree_, reference_tree_);
00788     }
00789 
00790     // The same code as above
00791     resulting_neighbors->Init(neighbor_indices_.size());
00792     distances->Init(neighbor_distances_.length());
00793     // We need to map the indices back from how they have 
00794     // been permuted
00795     for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00796       (*resulting_neighbors)[
00797         old_from_new_references_[i/knns_]*knns_+ i%knns_] = 
00798         old_from_new_references_[neighbor_indices_[i]];
00799       (*distances)[
00800         old_from_new_references_[i/knns_]*knns_+ i%knns_] = 
00801         neighbor_distances_[i];
00802 
00803     }
00804   }
00805    
00806 }; //class AllNN
00807 
00808 
00809 #endif
00810 // end inclusion guards
Generated on Mon Jan 24 12:04:37 2011 for FASTlib by  doxygen 1.6.3