dualtree_kde_common.h

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00032 #ifndef DUALTREE_KDE_COMMON_H
00033 #define DUALTREE_KDE_COMMON_H
00034 
00035 #include "inverse_normal_cdf.h"
00036 
00037 class DualtreeKdeCommon {
00038 
00039  public:
00040 
00043   static int qsort_comparator(const void *a, const void *b) {
00044     double *typecasted_a = (double *) a;
00045     double *typecasted_b = (double *) b;
00046     
00047     if(*typecasted_a < *typecasted_b) {
00048       return -1;
00049     }
00050     else if(*typecasted_a > *typecasted_b) {
00051       return 1;
00052     }
00053     else {
00054       return 0;
00055     }
00056   }
00057  
00066   template<typename TTree, typename TAlgorithm>
00067   static void AddPostponed(TTree *node, index_t destination, 
00068                            TAlgorithm *kde_object) {
00069 
00070     kde_object->densities_l_[destination] += node->stat().postponed_l_;
00071     kde_object->densities_e_[destination] += node->stat().postponed_e_;
00072     kde_object->densities_u_[destination] += node->stat().postponed_u_;
00073     kde_object->used_error_[destination] += node->stat().postponed_used_error_;
00074     kde_object->n_pruned_[destination] += node->stat().postponed_n_pruned_; 
00075   }
00076 
00077   template<typename TTree>
00078   static void BestNodePartners
00079   (TTree *nd, TTree *nd1, TTree *nd2, double probability, 
00080    TTree **partner1, double *probability1, 
00081    TTree **partner2, double *probability2) {
00082   
00083   double d1 = nd->bound().MinDistanceSq(nd1->bound());
00084   double d2 = nd->bound().MinDistanceSq(nd2->bound());
00085 
00086   // Prioritized traversal based on the squared distance bounds.
00087   if(d1 <= d2) {
00088     *partner1 = nd1;
00089     *probability1 = sqrt(probability);
00090     *partner2 = nd2;
00091     *probability2 = sqrt(probability);
00092   }
00093   else {
00094     *partner1 = nd2;
00095     *probability1 = sqrt(probability);
00096     *partner2 = nd1;
00097     *probability2 = sqrt(probability);
00098   }
00099 }
00100 
00110   template<typename TTree, typename TAlgorithm>
00111   static void RefineBoundStatistics(index_t q, TTree *qnode, 
00112                                     TAlgorithm *kde_object) {
00113     
00114     qnode->stat().mass_l_ = std::min(qnode->stat().mass_l_, 
00115                                      kde_object->densities_l_[q]);
00116     qnode->stat().mass_u_ = std::max(qnode->stat().mass_u_, 
00117                                      kde_object->densities_u_[q]);
00118     qnode->stat().used_error_ = std::max(qnode->stat().used_error_,
00119                                          kde_object->used_error_[q]);
00120     qnode->stat().n_pruned_ = std::min(qnode->stat().n_pruned_, 
00121                                        kde_object->n_pruned_[q]);
00122   }
00123 
00129   static void ShuffleAccordingToPermutation
00130   (Vector &v, const ArrayList<index_t> &permutation) {
00131     
00132     Vector v_tmp;
00133     v_tmp.Init(v.length());
00134     for(index_t i = 0; i < v_tmp.length(); i++) {
00135       v_tmp[i] = v[permutation[i]];
00136     }
00137     v.CopyValues(v_tmp);
00138   }
00139 
00140   static double OuterConfidenceInterval
00141   (double population_size, double sample_size,
00142    double sample_order_statistics_min_index,
00143    double population_order_statistics_min_index) {
00144     
00145     double total_probability = 0;
00146     double lower_percentile = population_order_statistics_min_index /
00147       population_size;
00148     
00149     for(double r_star = sample_order_statistics_min_index;
00150         r_star <= std::min(population_order_statistics_min_index, sample_size);
00151         r_star += 1.0) {
00152       
00153       // If any of the arguments to the binomial coefficient is
00154       // invalid, then the contribution is zero.
00155       if(r_star > population_order_statistics_min_index ||
00156          sample_size - r_star < 0 || 
00157          population_size - population_order_statistics_min_index < 0 ||
00158          sample_size - r_star >
00159          population_size - population_order_statistics_min_index) {
00160         continue;
00161       }
00162       
00163       /*
00164         total_probability +=
00165         BinomialCoefficientHelper_
00166         (population_order_statistics_min_index, r_star,
00167         population_size - population_order_statistics_min_index,
00168         sample_size - r_star, population_size, sample_size);
00169       */
00170       total_probability +=
00171         math::BinomialCoefficient((int) sample_size, (int) r_star) *
00172         pow(lower_percentile, r_star) * 
00173         pow(1 - lower_percentile, sample_size - r_star);
00174     }
00175     return std::max(std::min(total_probability, 1.0), 0.0);
00176   }
00177   
00178   static double BinomialCoefficientHelper(double n3, double k3, double n1, 
00179                                           double k1, double n2, double k2) {
00180     
00181     double n_k3 = n3 - k3;
00182     double n_k1 = n1 - k1;
00183     double n_k2 = n2 - k2;
00184     double nchsk = 1;
00185     double i;
00186     
00187     if(k3 > n3 || k3 < 0 || k1 > n1 || k1 < 0 || k2 > n2 || k2 < 0) {
00188       return 0;
00189     }
00190     
00191     if(k3 < n_k3) {
00192       k3 = n_k3;
00193       n_k3 = n3 - k3;
00194     }
00195     if(k1 < n_k1) {
00196       k1 = n_k1;
00197       n_k1 = n1 - k1;
00198     }
00199     if(k2 < n_k2) {
00200       k2 = n_k2;
00201       n_k2 = n2 - k2;
00202     }
00203     
00204     double min_index = std::min(n_k1, n_k2);
00205     double max_index = std::max(n_k1, n_k2);
00206     for(i = 1; i <= min_index; i += 1.0) {
00207       k1 += 1.0;
00208       k2 += 1.0;
00209       nchsk *= k1;
00210       nchsk /= k2;
00211     }
00212     for(i = min_index + 1; i <= max_index; i += 1.0) {
00213       if(n_k1 < n_k2) {
00214         k2 += 1.0;
00215         nchsk *= i;
00216         nchsk /= k2;
00217       }
00218       else {
00219         k1 += 1.0;
00220         nchsk *= k1;
00221         nchsk /= i;
00222       }
00223     }
00224     for(i = 1; i <= n_k3; i += 1.0) {
00225       k3 += 1.0;
00226       nchsk *= k3;
00227       nchsk /= i;
00228     }
00229     
00230     return nchsk;
00231   }
00232 
00233   template<typename TTree, typename TAlgorithm>
00234   static bool MonteCarloPrunable_
00235   (TTree *qnode, TTree *rnode, double probability, DRange &dsqd_range,
00236    DRange &kernel_value_range, double &dl, double &de, double &du, 
00237    double &used_error, double &n_pruned, TAlgorithm *kde_object) {
00238     
00239     // If the reference node contains too few points, then return.
00240     if(qnode->count() * rnode->count() < 50) {
00241       return false;
00242     }
00243     
00244     // Refine the lower bound using the new lower bound info.
00245     double max_used_error = 0;
00246     
00247     // Take random query/reference pair samples and determine how many
00248     // more samples are needed.
00249     bool flag = true;
00250     
00251     // Reset the current position of the scratch space to zero.
00252     double kernel_sums = 0;
00253     double squared_kernel_sums = 0;
00254     
00255     // Commence sampling...
00256     {
00257       double standard_score = 
00258         InverseNormalCDF::Compute(probability + 0.5 * (1 - probability));
00259       
00260       // The initial number of samples is equal to the default.
00261       int num_samples = 25;
00262       int total_samples = 0;
00263       
00264       do {
00265         for(index_t s = 0; s < num_samples; s++) {
00266           
00267           index_t random_query_point_index =
00268             math::RandInt(qnode->begin(), qnode->end());
00269           index_t random_reference_point_index = 
00270             math::RandInt(rnode->begin(), rnode->end());
00271           
00272           // Get the pointer to the current query point.
00273           const double *query_point = 
00274             (kde_object->qset_).GetColumnPtr(random_query_point_index);
00275           
00276           // Get the pointer to the current reference point.
00277           const double *reference_point = 
00278             (kde_object->rset_).GetColumnPtr(random_reference_point_index);
00279           
00280           // Compute the pairwise distance and kernel value.
00281           double squared_distance = la::DistanceSqEuclidean
00282             ((kde_object->rset_).n_rows(), query_point, reference_point);
00283           
00284           double weighted_kernel_value = 
00285             kde_object->EvalUnnormOnSq_(random_reference_point_index,
00286                                         squared_distance);
00287           kernel_sums += weighted_kernel_value;
00288           squared_kernel_sums += weighted_kernel_value * weighted_kernel_value;
00289           
00290         } // end of taking samples for this roune...
00291         
00292         // Increment total number of samples.
00293         total_samples += num_samples;
00294         
00295         // Compute the current estimate of the sample mean and the
00296         // sample variance.
00297         double sample_mean = kernel_sums / ((double) total_samples);
00298         double sample_variance =
00299           (squared_kernel_sums - total_samples * sample_mean * sample_mean) / 
00300           ((double) total_samples - 1);
00301         
00302         // Compute the current threshold for guaranteeing the relative
00303         // error bound.
00304         double new_used_error = qnode->stat().used_error_ +
00305           qnode->stat().postponed_used_error_;
00306         double new_n_pruned = qnode->stat().n_pruned_ + 
00307           qnode->stat().postponed_n_pruned_;
00308         
00309         // The currently proven lower bound.
00310         double new_mass_l = qnode->stat().mass_l_ + 
00311           qnode->stat().postponed_l_ + dl;
00312         double right_hand_side = 
00313           (kde_object->relative_error_ * new_mass_l - new_used_error) /
00314           (kde_object->rroot_->stat().get_weight_sum() - new_n_pruned);
00315         
00316         if(sqrt(sample_variance) * standard_score < right_hand_side) {
00317           kernel_sums = kernel_sums / ((double) total_samples) * 
00318             rnode->stat().get_weight_sum();
00319           max_used_error = rnode->stat().get_weight_sum() * 
00320             standard_score * sqrt(sample_variance);
00321           break;
00322         }
00323         else {
00324           flag = false;
00325           break;
00326         }
00327         
00328       } while(true);
00329       
00330     } // end of sampling...
00331     
00332     // If all queries can be pruned, then add the approximations.
00333     if(flag) {
00334       de = kernel_sums;
00335       used_error = max_used_error;
00336       return true;
00337     }
00338     return false;
00339   }
00340 
00341   template<typename TTree, typename TAlgorithm>
00342   static bool Prunable(TTree *qnode, TTree *rnode, double probability, 
00343                        DRange &dsqd_range, DRange &kernel_value_range, 
00344                        double &dl, double &de, double &du, 
00345                        double &used_error, double &n_pruned, 
00346                        TAlgorithm *kde_object) {
00347     
00348     // the new lower bound after incorporating new info
00349     dl = kernel_value_range.lo * rnode->stat().get_weight_sum();
00350     de = 0.5 * rnode->stat().get_weight_sum() * 
00351       (kernel_value_range.lo + kernel_value_range.hi);
00352     du = (kernel_value_range.hi - 1) * rnode->stat().get_weight_sum();
00353     
00354     // refine the lower bound using the new lower bound info
00355     double new_mass_l = qnode->stat().mass_l_ + 
00356       qnode->stat().postponed_l_ + dl;
00357     double new_used_error = qnode->stat().used_error_ + 
00358       qnode->stat().postponed_used_error_;
00359     double new_n_pruned = qnode->stat().n_pruned_ + 
00360       qnode->stat().postponed_n_pruned_;
00361     
00362     double allowed_err;
00363     
00364     // Compute the allowed error.
00365     allowed_err = (kde_object->relative_error_ * new_mass_l - new_used_error) *
00366       rnode->stat().get_weight_sum() / 
00367       ((double) kde_object->rroot_->stat().get_weight_sum() - new_n_pruned);
00368     
00369     // This is error per each query/reference pair for a fixed query
00370     double kernel_diff = 0.5 * (kernel_value_range.hi - kernel_value_range.lo);
00371     
00372     // this is total error for each query point
00373     used_error = kernel_diff * rnode->stat().get_weight_sum();
00374     
00375     // number of reference points for possible pruning.
00376     n_pruned = rnode->stat().get_weight_sum();
00377     
00378     // If the error bound is satisfied by the hard error bound, it is
00379     // safe to prune.
00380     return (!isnan(allowed_err)) && (used_error <= allowed_err);
00381   }
00382 
00383 };
00384 
00385 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3