allnn.h

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00042 // Header files should always have inclusion guards.  It's a good idea
00043 // to "sign" these guards with the containing folder or project name,
00044 // in the off chance that someone else has a file with the same name.
00045 #ifndef PLATONIC_ALLNN_H
00046 #define PLATONIC_ALLNN_H
00047 
00048 // You can include all core FASTlib components at once as follows.
00049 // Your "deplibs" entry in build.py should mirror your includes.
00050 #include <fastlib/fastlib.h>
00051 
00052 // TODO: Move these constants into the AllNN class.  I have not
00053 // immediately done this because C++ doesn't like non-integers to be
00054 // defined on declaration within classes.  We would need to do this in
00055 // a .cc file, and in fact should probably move this file's function
00056 // definitions into one as well.
00057 
00058 const fx_entry_doc allnn_entries[] = {
00059   {"leaf_size", FX_PARAM, FX_INT, NULL,
00060    "  The maximum number of points to store at a leaf.\n"},
00061   {"tree_building", FX_TIMER, FX_CUSTOM, NULL,
00062    "  Time spent building the kd-tree.\n"},
00063   {"dual_tree_computation", FX_TIMER, FX_CUSTOM, NULL,
00064    "  Time spent computing the nearest neighbors.\n"},
00065   {"number_of_prunes", FX_RESULT, FX_INT, NULL,
00066    "  Total node-pairs found to be too far to matter.\n"},
00067   FX_ENTRY_DOC_DONE
00068 };
00069 
00070 const fx_module_doc allnn_doc = {
00071   allnn_entries, NULL,
00072   "Performs dual-tree all-nearest-neighbors computation.\n"
00073 };
00074 
00075 const fx_entry_doc allnn_naive_entries[] = {
00076   {"naive_time", FX_TIMER, FX_CUSTOM, NULL,
00077    "  Time spend performing the naive computation.\n"},
00078   FX_ENTRY_DOC_DONE
00079 };
00080 
00081 const fx_module_doc allnn_naive_doc = {
00082   allnn_naive_entries, NULL,
00083   "Performs naive all-nearest-neighbors computation.\n"
00084 };
00085 
00107 class AllNN {
00108 
00110 
00111  private:
00119   class QueryStat {
00120 
00121    private:
00126     double max_distance_so_far_;
00127 
00128     // The object traversal macros establish a FASTlib-complient
00129     // storage class, providing many tools including pretty printing
00130     // and copy construction.  See base/otrav.h for more details.
00131     //
00132     // OT_DEF_BASIC is suitable for when you want pretty printing, but
00133     // don't need a special destructor (your object has no pointers).
00134     // Otherwise, you should use OT_DEF.
00135     OBJECT_TRAVERSAL_SHALLOW(QueryStat) {
00136       // Declare a non-pointer/array member variable to be traversed.
00137       // See base/otrav.h for other kinds of declarations.
00138       OT_OBJ(max_distance_so_far_);
00139     }
00140 
00141    public:
00142     double max_distance_so_far() {
00143       return max_distance_so_far_;
00144     }
00145 
00146     void set_max_distance_so_far(double new_dist) {
00147       max_distance_so_far_ = new_dist;
00148     }
00149 
00157     void Init(const Matrix& matrix, index_t start, index_t count) {
00158       // The bound starts at infinity
00159       max_distance_so_far_ = DBL_MAX;
00160     }
00161 
00168     void Init(const Matrix& matrix, index_t start, index_t count,
00169               const QueryStat& left, const QueryStat& right) {
00170       Init(matrix, start, count);
00171     }
00172 
00173   }; /* class AllNNStat */
00174 
00175   // The tree directory defines several tools for the creation of
00176   // custom tree types, especially for kd-trees.  The DHrectBound<2>
00177   // gives us the normal kind of kd-tree bounding boxes, using the
00178   // 2-norm, and Matrix specifies the storage type of our data.
00179 
00181   typedef BinarySpaceTree<DHrectBound<2>, Matrix, QueryStat> QueryTree;
00183   typedef BinarySpaceTree<DHrectBound<2>, Matrix> ReferenceTree;
00184 
00186 
00187  private:
00189   struct datanode* module_;
00190 
00192   Matrix queries_;
00194   Matrix references_;
00195 
00197   QueryTree* query_tree_;
00199   ReferenceTree* reference_tree_;
00200 
00202   index_t leaf_size_;
00203 
00205   ArrayList<index_t> old_from_new_queries_;
00207   ArrayList<index_t> old_from_new_references_;
00208 
00213   Vector neighbor_distances_;
00218   ArrayList<index_t> neighbor_indices_;
00219 
00221   index_t number_of_prunes_;
00222 
00224   bool initialized_;
00226   bool already_used_;
00227 
00228 
00230 
00231   // It is easy to accidently call copy constructors in C++.  The most
00232   // common mistake is to define functions with object arguments
00233   // passed by value:
00234   //
00235   //   void foo(HugeObject x) {...}
00236   //
00237   // This recursively copies each member variable of the object, which
00238   // would be disasterous, for instance, if stored query and reference
00239   // matrices are huge.  Core FASTlib components usually mitigate this
00240   // by passing objects by const reference or by pointer:
00241   //
00242   //   void bar(const HugeObject& x, HugeObject* y) {...}
00243   //
00244   // Non-const pointers are used when the outside object is modified.
00245   //
00246   // The following disables copy construction and assignment for
00247   // objects of this class, which prevents functions like foo from
00248   // compiling, saving you from poor performance and strange bugs.
00249   FORBID_ACCIDENTAL_COPIES(AllNN);
00250 
00251  public:
00252   // Default constructors should be kept very simple and should never
00253   // allocate memory.  Their two responsibilities are to ensure that
00254   // it's safe to destroy the object without having otherwise used it
00255   // (e.g. to set pointers to NULL) and to poison memory when in debug
00256   // mode with BIG_BAD_NUMBER = 2146666666 = NaN as a double and
00257   // BIG_BAD_POINTER = 0xdeadbeef.
00258   AllNN() {
00259     query_tree_ = NULL;
00260     reference_tree_ = NULL;
00261 
00262     DEBUG_POISON_PTR(module_);
00263     DEBUG_ONLY(leaf_size_ = BIG_BAD_NUMBER);
00264     DEBUG_ONLY(number_of_prunes_ = BIG_BAD_NUMBER);
00265 
00266     DEBUG_ONLY(initialized_ = false);
00267     DEBUG_ONLY(already_used_ = false);
00268   }
00269 
00270   // Note that we don't delete the fx module; it's managed externally.
00271   ~AllNN() {
00272     if (query_tree_ != NULL) {
00273       delete query_tree_;
00274     }
00275     if (reference_tree_ != NULL) {
00276       delete reference_tree_;
00277     }
00278   }
00279 
00281 
00286   double MinNodeDistSq_(QueryTree* query_node, ReferenceTree* reference_node) {
00287     return query_node->bound().MinDistanceSq(reference_node->bound());
00288   }
00289 
00290 
00296   void GNPBaseCase_(QueryTree* query_node, ReferenceTree* reference_node) {
00297 
00298     // Debug checks should be used frequently.  They incur no overhead
00299     // when compiled in --mode=fast and very little otherwise.
00300 
00301     /* Make sure we didn't try to split children */
00302     DEBUG_ASSERT(query_node != NULL);
00303     DEBUG_ASSERT(reference_node != NULL);
00304 
00305     /* Make sure we should be in the base case */
00306     DEBUG_WARN_IF(!query_node->is_leaf());
00307     DEBUG_WARN_IF(!reference_node->is_leaf());
00308 
00309     /* Used to find the query node's new upper bound */
00310     double max_nearest_neighbor_distance = -1.0;
00311 
00312     /* Loop over all query-reference pairs */
00313 
00314     // Trees don't store their points, but instead give index ranges.
00315     // To make this feasible, they have to rearrange their input
00316     // matrices, which is why we were sure to make copies.
00317     for (index_t query_index = query_node->begin();
00318         query_index < query_node->end(); query_index++) {
00319 
00320       // MakeColumnVector aliases (i.e. points to but does not copy) a
00321       // column from the matrix.
00322       //
00323       // A brief aside: BLAS/LAPACK is coded in Fortran and thus
00324       // expects matrices to be column major.  We side with their
00325       // format for compatiblity, and accordingly, it is more cache
00326       // friendly to store data points along columns, as is common in
00327       // statistics, than along rows, as is more conventional.
00328       Vector query_point;
00329       queries_.MakeColumnVector(query_index, &query_point);
00330 
00331       // It's not terrible form to leave TODO statements in code you
00332       // intend to maintain, especially when coding under a deadline.
00333       // These are easy to search for, though for some reason, Garry
00334       // was more partial to "where's WALDO".  More memorable, maybe?
00335 
00336       /* TODO: try pruning query points vs reference node */
00337 
00338       for (index_t reference_index = reference_node->begin();
00339           reference_index < reference_node->end(); reference_index++) {
00340 
00341         Vector reference_point;
00342         references_.MakeColumnVector(reference_index, &reference_point);
00343 
00344         // BLAS can perform many vectors ops more quickly than C/C++.
00345         double distance =
00346             la::DistanceSqEuclidean(query_point, reference_point);
00347 
00348         /* Record points found to be closer than the best so far */
00349         if (distance < neighbor_distances_[query_index]) {
00350           neighbor_distances_[query_index] = distance;
00351           neighbor_indices_[query_index] = reference_index;
00352         }
00353 
00354       } /* for reference_index */
00355 
00356       /* Find the upper bound nn distance for this node */
00357       if (neighbor_distances_[query_index] > max_nearest_neighbor_distance) {
00358         max_nearest_neighbor_distance = neighbor_distances_[query_index];
00359       }
00360 
00361     } /* for query_index */
00362 
00363     /* Update the upper bound nn distance for the node */
00364     query_node->stat().set_max_distance_so_far(max_nearest_neighbor_distance);
00365 
00366   } /* GNPBaseCase_ */
00367 
00368 
00373   void GNPRecursion_(QueryTree* query_node, ReferenceTree* reference_node,
00374                      double lower_bound_distance) {
00375 
00376     /* Make sure we didn't try to split children */
00377     DEBUG_ASSERT(query_node != NULL);
00378     DEBUG_ASSERT(reference_node != NULL);
00379 
00380     // The following asserts equality of two doubles and prints their
00381     // values if it fails.  Note that this *isn't* a particularly fast
00382     // debug check, though; it negates the benefit of passing ahead a
00383     // precomputed distance entirely.  That's why we have --mode=fast.
00384 
00385     /* Make sure the precomputed bounding information is correct */
00386     DEBUG_SAME_DOUBLE(lower_bound_distance,
00387         MinNodeDistSq_(query_node, reference_node));
00388 
00389     if (lower_bound_distance > query_node->stat().max_distance_so_far()) {
00390 
00391       /*
00392        * A reference node with lower-bound distance greater than this
00393        * query node's upper-bound nearest neighbor distance cannot
00394        * contribute a reference closer than any of the queries'
00395        * current neighbors, hence prune
00396        */
00397       number_of_prunes_++;
00398 
00399     } else if (query_node->is_leaf() && reference_node->is_leaf()) {
00400 
00401       /* Cannot further split leaves, so process exhaustively */
00402       GNPBaseCase_(query_node, reference_node);
00403 
00404     } else if (query_node->is_leaf()) {
00405 
00406       /* Query node's a leaf, but we can split references */
00407       double left_distance =
00408           MinNodeDistSq_(query_node, reference_node->left());
00409       double right_distance =
00410           MinNodeDistSq_(query_node, reference_node->right());
00411 
00412       /*
00413        * Nearer reference node more likely to contribute neighbors
00414        * (and thus tighten bounds), so visit it first
00415        */
00416       if (left_distance < right_distance) {
00417         GNPRecursion_(query_node, reference_node->left(), left_distance);
00418         GNPRecursion_(query_node, reference_node->right(), right_distance);
00419       } else {
00420         GNPRecursion_(query_node, reference_node->right(), right_distance);
00421         GNPRecursion_(query_node, reference_node->left(), left_distance);
00422       }
00423 
00424     } else if (reference_node->is_leaf()) {
00425 
00426       /* Reference node's a leaf, but we can split queries */
00427       double left_distance =
00428           MinNodeDistSq_(query_node->left(), reference_node);
00429       double right_distance =
00430           MinNodeDistSq_(query_node->right(), reference_node);
00431 
00432       /* Order of recursion does not matter */
00433       GNPRecursion_(query_node->left(), reference_node, left_distance);
00434       GNPRecursion_(query_node->right(), reference_node, right_distance);
00435 
00436       /* Update upper bound nn distance base new child bounds */
00437       query_node->stat().set_max_distance_so_far(
00438           max(query_node->left()->stat().max_distance_so_far(),
00439               query_node->right()->stat().max_distance_so_far()));
00440 
00441     } else {
00442 
00443       /*
00444        * Neither node is a leaf, so split both
00445        *
00446        * The order we process the query node's children doesn't
00447        * matter, but for each we should visit their nearer reference
00448        * node first.
00449        */
00450 
00451       double left_distance =
00452           MinNodeDistSq_(query_node->left(), reference_node->left());
00453       double right_distance =
00454           MinNodeDistSq_(query_node->left(), reference_node->right());
00455 
00456       if (left_distance < right_distance) {
00457         GNPRecursion_(query_node->left(),
00458             reference_node->left(), left_distance);
00459         GNPRecursion_(query_node->left(),
00460             reference_node->right(), right_distance);
00461       } else {
00462         GNPRecursion_(query_node->left(),
00463             reference_node->right(), right_distance);
00464         GNPRecursion_(query_node->left(),
00465             reference_node->left(), left_distance);
00466       }
00467 
00468       left_distance =
00469           MinNodeDistSq_(query_node->right(), reference_node->left());
00470       right_distance =
00471           MinNodeDistSq_(query_node->right(), reference_node->right());
00472 
00473       if (left_distance < right_distance) {
00474         GNPRecursion_(query_node->right(),
00475             reference_node->left(), left_distance);
00476         GNPRecursion_(query_node->right(),
00477             reference_node->right(), right_distance);
00478       } else {
00479         GNPRecursion_(query_node->right(),
00480             reference_node->right(), right_distance);
00481         GNPRecursion_(query_node->right(),
00482             reference_node->left(), left_distance);
00483       }
00484 
00485       /* Update upper bound nn distance base new child bounds */
00486       query_node->stat().set_max_distance_so_far(
00487           max(query_node->left()->stat().max_distance_so_far(),
00488               query_node->right()->stat().max_distance_so_far()));
00489 
00490     }
00491 
00492   } /* GNPRecursion_ */
00493 
00495 
00496   // Note that we initialize with const references below to keep from
00497   // copying data until we want to.  By the way, which side you put
00498   // the &'s and *'s on is on the level of deep-seated religious
00499   // belief: some people get real angry if you defy them, but you're
00500   // really no worse a person either way.  The compiler is agnostic.
00501 
00505   void Init(const Matrix& queries_in, const Matrix& references_in,
00506             struct datanode* module_in) {
00507 
00508     // It's a good idea to make sure the object isn't initialized a
00509     // second time, as this is almost certainly mistaken.
00510     DEBUG_ASSERT(initialized_ == false);
00511     DEBUG_ONLY(initialized_ = true);
00512 
00513     module_ = module_in;
00514 
00515     /* The data sets need to have the same dimensionality */
00516     DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00517 
00518     /* Copy input matrices as they will be rearranged */
00519     queries_.Copy(queries_in);
00520     references_.Copy(references_in);
00521 
00522     leaf_size_ = fx_param_int(module_, "leaf_size", 20);
00523     DEBUG_ASSERT(leaf_size_ > 0);
00524 
00525     // Timers are another handy tool provided by FASTexec.  These are
00526     // emitted automatically once you call fx_done.
00527     fx_timer_start(module_, "tree_building");
00528 
00529     // Input matrices are rearranged to an in-order traversal of
00530     // either tree.  To help in iterpretting results, the third
00531     // argument is Init'd to a mapping from rearranged indices to the
00532     // original order.  The fourth argument, if provided, would
00533     // initialize the reverse of said.
00534 
00535     /* Build the trees */
00536     query_tree_ = tree::MakeKdTreeMidpoint<QueryTree>(
00537         queries_, leaf_size_, &old_from_new_queries_, NULL);
00538     reference_tree_ = tree::MakeKdTreeMidpoint<ReferenceTree>(
00539         references_, leaf_size_, &old_from_new_references_, NULL);
00540 
00541     // While we don't make use of this here, it is possible to start
00542     // timers after stopping them.  They continue where they left off.
00543     fx_timer_stop(module_, "tree_building");
00544 
00545     /* Ready the list of nearest neighbor candidates to be filled. */
00546     neighbor_indices_.Init(queries_.n_cols());
00547 
00548     /* Ready the vector of upper bound nn distances for use. */
00549     neighbor_distances_.Init(queries_.n_cols());
00550     neighbor_distances_.SetAll(DBL_MAX);
00551 
00552     number_of_prunes_ = 0;
00553 
00554   } /* Init */
00555 
00556 
00562   void InitNaive(const Matrix& queries_in, const Matrix& references_in,
00563                  struct datanode* module_in){
00564 
00565     DEBUG_ASSERT(initialized_ == false);
00566     DEBUG_ONLY(initialized_ = true);
00567 
00568     module_ = module_in;
00569 
00570     /* The data sets need to have the same dimensionality */
00571     DEBUG_SAME_SIZE(queries_.n_rows(), references_.n_rows());
00572 
00573     /* Copy input matrices */
00574     queries_.Copy(queries_in);
00575     references_.Copy(references_in);
00576 
00577     /*
00578      * A bit of a trick so we can still use BaseCase_: we'll expand
00579      * the leaf size so that our trees only have one node.
00580      */
00581     leaf_size_ = max(queries_.n_cols(), references_.n_cols());
00582 
00583     /* Build the (single node) trees */
00584     query_tree_ = tree::MakeKdTreeMidpoint<QueryTree>(
00585         queries_, leaf_size_, &old_from_new_queries_, NULL);
00586     reference_tree_ = tree::MakeKdTreeMidpoint<ReferenceTree>(
00587         references_, leaf_size_, &old_from_new_references_, NULL);
00588 
00589     /* Ready the list of nearest neighbor candidates to be filled. */
00590     neighbor_indices_.Init(queries_.n_cols());
00591 
00592     /* Ready the vector of upper bound nn distances for use. */
00593     neighbor_distances_.Init(queries_.n_cols());
00594     neighbor_distances_.SetAll(DBL_MAX);
00595 
00596     number_of_prunes_ = 0;
00597 
00598   } /* InitNaive */
00599 
00600 
00605   void ComputeNeighbors(ArrayList<index_t>* results) {
00606 
00607     // In addition to confirming the object's been initialized, we
00608     // want to make sure we aren't asking it to compute a second time.
00609     DEBUG_ASSERT(initialized_ == true);
00610     DEBUG_ASSERT(already_used_ == false);
00611     DEBUG_ONLY(already_used_ = true);
00612 
00613     fx_timer_start(module_, "dual_tree_computation");
00614 
00615     /* Start recursion on the roots of either tree */
00616     GNPRecursion_(query_tree_, reference_tree_,
00617         MinNodeDistSq_(query_tree_, reference_tree_));
00618 
00619     fx_timer_stop(module_, "dual_tree_computation");
00620 
00621     // Save the total number of prunes to the FASTexec module; this
00622     // will printed after calling fx_done or can be read back later.
00623     fx_result_int(module_, "number_of_prunes", number_of_prunes_);
00624 
00625     if (results) {
00626       EmitResults(results);
00627     }
00628 
00629   } /* ComputeNeighbors */
00630 
00631 
00635   void ComputeNaive(ArrayList<index_t>* results) {
00636 
00637     DEBUG_ASSERT(initialized_ == true);
00638     DEBUG_ASSERT(already_used_ == false);
00639     DEBUG_ONLY(already_used_ = true);
00640 
00641     fx_timer_start(module_, "naive_time");
00642 
00643     /* BaseCase_ on the roots is equivalent to naive */
00644     GNPBaseCase_(query_tree_, reference_tree_);
00645 
00646     fx_timer_stop(module_, "naive_time");
00647 
00648     if (results) {
00649       EmitResults(results);
00650     }
00651 
00652   } /* ComputeNaive */
00653 
00657   void EmitResults(ArrayList<index_t>* results) {
00658 
00659     DEBUG_ASSERT(initialized_ == true);
00660 
00661     results->Init(neighbor_indices_.size());
00662 
00663     /* Map the indices back from how they have been permuted. */
00664     for (index_t i = 0; i < neighbor_indices_.size(); i++) {
00665       (*results)[old_from_new_queries_[i]] =
00666           old_from_new_references_[neighbor_indices_[i]];
00667     }
00668 
00669   } /* EmitResults */
00670 
00671 }; /* class AllNN */
00672 
00673 #endif /* PLATONIC_ALLNN_H */
Generated on Mon Jan 24 12:04:37 2011 for FASTlib by  doxygen 1.6.3