dualtree_kde_cv_common.h

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00032 #ifndef DUALTREE_KDE_CV_COMMON_H
00033 #define DUALTREE_KDE_CV_COMMON_H
00034 
#include <algorithm>
#include <cmath>

#include "inverse_normal_cdf.h"
00036 
00037 class DualtreeKdeCVCommon {
00038 
00039  public:
00040   
00041   template<typename TTree, typename TAlgorithm>
00042   static bool MonteCarloPrunable
00043   (TTree *qnode, TTree *rnode, double probability, DRange &dsqd_range,
00044    DRange &first_kernel_value_range, DRange &second_kernel_value_range,
00045    double &first_dl, double &first_de, double &first_du, 
00046    double &first_used_error, double &second_dl, double &second_de, 
00047    double &second_du, double &second_used_error, double &delta_n_pruned,
00048    TAlgorithm *kde_object) {
00049     
00050     // If there are too few pairs, then return.
00051     if(qnode->count() * rnode->count() < 50) {
00052       return false;
00053     }
00054     
00055     // Temporary kernel sums.
00056     double first_kernel_sum = 0, second_kernel_sum = 0;
00057     double first_squared_kernel_sum = 0, second_squared_kernel_sum = 0;
00058     
00059     // Commence sampling...
00060     double standard_score = 
00061       InverseNormalCDF::Compute(probability + 0.5 * (1 - probability));
00062     
00063     // The initial number of samples is equal to the default.
00064     int num_samples = 50;
00065     int total_samples = 0;
00066     
00067     for(index_t s = 0; s < num_samples; s++) {
00068       
00069       index_t random_query_point_index =
00070         math::RandInt(qnode->begin(), qnode->end());
00071       index_t random_reference_point_index = 
00072         math::RandInt(rnode->begin(), rnode->end());
00073       
00074       // Get the pointer to the current query point.
00075       const double *query_point = 
00076         (kde_object->rset_).GetColumnPtr(random_query_point_index);
00077       
00078       // Get the pointer to the current reference point.
00079       const double *reference_point = 
00080         (kde_object->rset_).GetColumnPtr(random_reference_point_index);
00081       
00082       // Compute the pairwise distance and kernel value.
00083       double squared_distance = la::DistanceSqEuclidean
00084         ((kde_object->rset_).n_rows(), query_point, reference_point);
00085       
00086       double first_weighted_kernel_value;
00087       double second_weighted_kernel_value;
00088       kde_object->EvalUnnormOnSq_
00089         (random_reference_point_index, squared_distance, 
00090          &first_weighted_kernel_value, &second_weighted_kernel_value);
00091       first_kernel_sum += first_weighted_kernel_value;
00092       second_kernel_sum += second_weighted_kernel_value;
00093       first_squared_kernel_sum += math::Sqr(first_weighted_kernel_value);
00094       second_squared_kernel_sum += math::Sqr(second_weighted_kernel_value);
00095 
00096     } // end of taking samples...
00097     
00098     // Increment total number of samples.
00099     total_samples += num_samples;
00100     
00101     // Compute the current estimate of the sample mean and the sample
00102     // variance.
00103     double first_sample_mean = first_kernel_sum / ((double) total_samples);
00104     double first_sample_variance =
00105       (first_squared_kernel_sum - total_samples * first_sample_mean * 
00106        first_sample_mean) / math::Sqr((double) total_samples - 1);
00107     double second_sample_mean = second_kernel_sum / ((double) total_samples);
00108     double second_sample_variance =
00109       (second_squared_kernel_sum - total_samples * second_sample_mean *
00110        second_sample_mean) / math::Sqr((double) total_samples - 1);
00111 
00112     // Refine the lower bound using the new lower bound info.
00113     double first_mass_l_change = qnode->count() * 
00114       rnode->stat().get_weight_sum() *
00115       (first_sample_mean - standard_score * sqrt(first_sample_variance));
00116     double first_new_mass_l = (kde_object->first_sum_l_) + first_mass_l_change;
00117     double second_mass_l_change = qnode->count() *
00118       rnode->stat().get_weight_sum() *
00119       (second_sample_mean - standard_score * sqrt(second_sample_variance));
00120     double second_new_mass_l = (kde_object-> second_sum_l_) +
00121       second_mass_l_change;
00122     
00123     // Compute the allowed error.
00124     double proportion = 1.0 / (kde_object->rroot_->count() * 
00125                                kde_object->rroot_->stat().get_weight_sum() -
00126                                kde_object->n_pruned_);
00127     double first_allowed_err = 
00128       (kde_object->relative_error_ * first_new_mass_l - 
00129        kde_object->first_used_error_) * proportion;
00130     double second_allowed_err =
00131       (kde_object->relative_error_ * second_new_mass_l - 
00132        kde_object->second_used_error_) * proportion;
00133         
00134     if(sqrt(first_sample_variance) * standard_score <= first_allowed_err &&
00135        sqrt(second_sample_variance) * standard_score <= second_allowed_err) {
00136       first_dl = std::max(first_dl, first_mass_l_change);
00137       first_de = qnode->count() * rnode->stat().get_weight_sum() * 
00138         first_sample_mean;
00139       first_used_error = qnode->count() * rnode->stat().get_weight_sum() * 
00140         standard_score * sqrt(first_sample_variance);
00141       second_dl = std::max(second_dl, second_mass_l_change);
00142       second_de = qnode->count() * rnode->stat().get_weight_sum() * 
00143         second_sample_mean;
00144       second_used_error = qnode->count() * rnode->stat().get_weight_sum() * 
00145         standard_score * sqrt(second_sample_variance);
00146       return true;
00147     }
00148     else {
00149       return false;
00150     }
00151   }
00152   
00153   template<typename TTree, typename TAlgorithm>
00154   static bool Prunable(TTree *qnode, TTree *rnode, double probability, 
00155                        DRange &dsqd_range, DRange &first_kernel_value_range,
00156                        DRange &second_kernel_value_range,
00157                        double &first_dl, double &first_de, double &first_du, 
00158                        double &first_used_error, 
00159                        double &second_dl, double &second_de, double &second_du,
00160                        double &second_used_error, double &delta_n_pruned,
00161                        TAlgorithm *kde_object) {
00162     
00163     // the new lower bound after incorporating new info
00164     first_dl = first_kernel_value_range.lo * qnode->count() * 
00165       rnode->stat().get_weight_sum();
00166     first_de = 0.5 * qnode->count() * rnode->stat().get_weight_sum() * 
00167       (first_kernel_value_range.lo + first_kernel_value_range.hi);
00168     first_du = (first_kernel_value_range.hi - 1) * qnode->count() *
00169       rnode->stat().get_weight_sum();
00170     second_dl = second_kernel_value_range.lo * qnode->count() *
00171       rnode->stat().get_weight_sum();
00172     second_de = 0.5 * qnode->count() * rnode->stat().get_weight_sum() * 
00173       (second_kernel_value_range.lo + second_kernel_value_range.hi);
00174     second_du = (second_kernel_value_range.hi - 1) * qnode->count() * 
00175       rnode->stat().get_weight_sum();
00176    
00177     // Refine the lower bound using the new lower bound info.
00178     double first_new_mass_l = (kde_object->first_sum_l_) + first_dl;
00179     double second_new_mass_l = (kde_object-> second_sum_l_) + second_dl;
00180     
00181     // Compute the allowed error.
00182     double proportion = 
00183       (qnode->count() * rnode->stat().get_weight_sum()) /
00184       (kde_object->rroot_->count() * 
00185        kde_object->rroot_->stat().get_weight_sum() - kde_object->n_pruned_);
00186     double first_allowed_err = 
00187       (kde_object->relative_error_ * first_new_mass_l - 
00188        kde_object->first_used_error_) *
00189       proportion;
00190     double second_allowed_err =
00191       (kde_object->relative_error_ * second_new_mass_l - 
00192        kde_object->second_used_error_) *
00193       proportion;
00194 
00195     // This is error per each query/reference pair for a fixed query
00196     double first_kernel_diff = 
00197       0.5 * (first_kernel_value_range.hi - first_kernel_value_range.lo);
00198     double second_kernel_diff =
00199       0.5 * (second_kernel_value_range.hi - second_kernel_value_range.lo);
00200     
00201     // this is total error for each query point
00202     first_used_error = first_kernel_diff * qnode->count() * 
00203       rnode->stat().get_weight_sum();
00204     second_used_error = second_kernel_diff * qnode->count() *
00205       rnode->stat().get_weight_sum();
00206     
00207     // number of reference points for possible pruning.
00208     delta_n_pruned = qnode->count() * rnode->stat().get_weight_sum();
00209     
00210     // If the error bound is satisfied by the hard error bound, it is
00211     // safe to prune.
00212     return (!isnan(first_allowed_err)) && (!isnan(second_allowed_err)) &&
00213       (first_used_error <= first_allowed_err) &&
00214       (second_used_error <= second_allowed_err);
00215   }
00216 };
00217 
00218 #endif
Generated on Mon Jan 24 12:04:38 2011 for FASTlib by  doxygen 1.6.3