dualtree_kde_cv_common.h
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef DUALTREE_KDE_CV_COMMON_H
00033 #define DUALTREE_KDE_CV_COMMON_H
00034
00035 #include "inverse_normal_cdf.h"
00036
00037 class DualtreeKdeCVCommon {
00038
00039 public:
00040
00041 template<typename TTree, typename TAlgorithm>
00042 static bool MonteCarloPrunable
00043 (TTree *qnode, TTree *rnode, double probability, DRange &dsqd_range,
00044 DRange &first_kernel_value_range, DRange &second_kernel_value_range,
00045 double &first_dl, double &first_de, double &first_du,
00046 double &first_used_error, double &second_dl, double &second_de,
00047 double &second_du, double &second_used_error, double &delta_n_pruned,
00048 TAlgorithm *kde_object) {
00049
00050
00051 if(qnode->count() * rnode->count() < 50) {
00052 return false;
00053 }
00054
00055
00056 double first_kernel_sum = 0, second_kernel_sum = 0;
00057 double first_squared_kernel_sum = 0, second_squared_kernel_sum = 0;
00058
00059
00060 double standard_score =
00061 InverseNormalCDF::Compute(probability + 0.5 * (1 - probability));
00062
00063
00064 int num_samples = 50;
00065 int total_samples = 0;
00066
00067 for(index_t s = 0; s < num_samples; s++) {
00068
00069 index_t random_query_point_index =
00070 math::RandInt(qnode->begin(), qnode->end());
00071 index_t random_reference_point_index =
00072 math::RandInt(rnode->begin(), rnode->end());
00073
00074
00075 const double *query_point =
00076 (kde_object->rset_).GetColumnPtr(random_query_point_index);
00077
00078
00079 const double *reference_point =
00080 (kde_object->rset_).GetColumnPtr(random_reference_point_index);
00081
00082
00083 double squared_distance = la::DistanceSqEuclidean
00084 ((kde_object->rset_).n_rows(), query_point, reference_point);
00085
00086 double first_weighted_kernel_value;
00087 double second_weighted_kernel_value;
00088 kde_object->EvalUnnormOnSq_
00089 (random_reference_point_index, squared_distance,
00090 &first_weighted_kernel_value, &second_weighted_kernel_value);
00091 first_kernel_sum += first_weighted_kernel_value;
00092 second_kernel_sum += second_weighted_kernel_value;
00093 first_squared_kernel_sum += math::Sqr(first_weighted_kernel_value);
00094 second_squared_kernel_sum += math::Sqr(second_weighted_kernel_value);
00095
00096 }
00097
00098
00099 total_samples += num_samples;
00100
00101
00102
00103 double first_sample_mean = first_kernel_sum / ((double) total_samples);
00104 double first_sample_variance =
00105 (first_squared_kernel_sum - total_samples * first_sample_mean *
00106 first_sample_mean) / math::Sqr((double) total_samples - 1);
00107 double second_sample_mean = second_kernel_sum / ((double) total_samples);
00108 double second_sample_variance =
00109 (second_squared_kernel_sum - total_samples * second_sample_mean *
00110 second_sample_mean) / math::Sqr((double) total_samples - 1);
00111
00112
00113 double first_mass_l_change = qnode->count() *
00114 rnode->stat().get_weight_sum() *
00115 (first_sample_mean - standard_score * sqrt(first_sample_variance));
00116 double first_new_mass_l = (kde_object->first_sum_l_) + first_mass_l_change;
00117 double second_mass_l_change = qnode->count() *
00118 rnode->stat().get_weight_sum() *
00119 (second_sample_mean - standard_score * sqrt(second_sample_variance));
00120 double second_new_mass_l = (kde_object-> second_sum_l_) +
00121 second_mass_l_change;
00122
00123
00124 double proportion = 1.0 / (kde_object->rroot_->count() *
00125 kde_object->rroot_->stat().get_weight_sum() -
00126 kde_object->n_pruned_);
00127 double first_allowed_err =
00128 (kde_object->relative_error_ * first_new_mass_l -
00129 kde_object->first_used_error_) * proportion;
00130 double second_allowed_err =
00131 (kde_object->relative_error_ * second_new_mass_l -
00132 kde_object->second_used_error_) * proportion;
00133
00134 if(sqrt(first_sample_variance) * standard_score <= first_allowed_err &&
00135 sqrt(second_sample_variance) * standard_score <= second_allowed_err) {
00136 first_dl = std::max(first_dl, first_mass_l_change);
00137 first_de = qnode->count() * rnode->stat().get_weight_sum() *
00138 first_sample_mean;
00139 first_used_error = qnode->count() * rnode->stat().get_weight_sum() *
00140 standard_score * sqrt(first_sample_variance);
00141 second_dl = std::max(second_dl, second_mass_l_change);
00142 second_de = qnode->count() * rnode->stat().get_weight_sum() *
00143 second_sample_mean;
00144 second_used_error = qnode->count() * rnode->stat().get_weight_sum() *
00145 standard_score * sqrt(second_sample_variance);
00146 return true;
00147 }
00148 else {
00149 return false;
00150 }
00151 }
00152
00153 template<typename TTree, typename TAlgorithm>
00154 static bool Prunable(TTree *qnode, TTree *rnode, double probability,
00155 DRange &dsqd_range, DRange &first_kernel_value_range,
00156 DRange &second_kernel_value_range,
00157 double &first_dl, double &first_de, double &first_du,
00158 double &first_used_error,
00159 double &second_dl, double &second_de, double &second_du,
00160 double &second_used_error, double &delta_n_pruned,
00161 TAlgorithm *kde_object) {
00162
00163
00164 first_dl = first_kernel_value_range.lo * qnode->count() *
00165 rnode->stat().get_weight_sum();
00166 first_de = 0.5 * qnode->count() * rnode->stat().get_weight_sum() *
00167 (first_kernel_value_range.lo + first_kernel_value_range.hi);
00168 first_du = (first_kernel_value_range.hi - 1) * qnode->count() *
00169 rnode->stat().get_weight_sum();
00170 second_dl = second_kernel_value_range.lo * qnode->count() *
00171 rnode->stat().get_weight_sum();
00172 second_de = 0.5 * qnode->count() * rnode->stat().get_weight_sum() *
00173 (second_kernel_value_range.lo + second_kernel_value_range.hi);
00174 second_du = (second_kernel_value_range.hi - 1) * qnode->count() *
00175 rnode->stat().get_weight_sum();
00176
00177
00178 double first_new_mass_l = (kde_object->first_sum_l_) + first_dl;
00179 double second_new_mass_l = (kde_object-> second_sum_l_) + second_dl;
00180
00181
00182 double proportion =
00183 (qnode->count() * rnode->stat().get_weight_sum()) /
00184 (kde_object->rroot_->count() *
00185 kde_object->rroot_->stat().get_weight_sum() - kde_object->n_pruned_);
00186 double first_allowed_err =
00187 (kde_object->relative_error_ * first_new_mass_l -
00188 kde_object->first_used_error_) *
00189 proportion;
00190 double second_allowed_err =
00191 (kde_object->relative_error_ * second_new_mass_l -
00192 kde_object->second_used_error_) *
00193 proportion;
00194
00195
00196 double first_kernel_diff =
00197 0.5 * (first_kernel_value_range.hi - first_kernel_value_range.lo);
00198 double second_kernel_diff =
00199 0.5 * (second_kernel_value_range.hi - second_kernel_value_range.lo);
00200
00201
00202 first_used_error = first_kernel_diff * qnode->count() *
00203 rnode->stat().get_weight_sum();
00204 second_used_error = second_kernel_diff * qnode->count() *
00205 rnode->stat().get_weight_sum();
00206
00207
00208 delta_n_pruned = qnode->count() * rnode->stat().get_weight_sum();
00209
00210
00211
00212 return (!isnan(first_allowed_err)) && (!isnan(second_allowed_err)) &&
00213 (first_used_error <= first_allowed_err) &&
00214 (second_used_error <= second_allowed_err);
00215 }
00216 };
00217
00218 #endif