00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00040 #ifndef NAIVE_KDE_H
00041 #define NAIVE_KDE_H
00042
00043 #include "mlpack/allknn/allknn.h"
00044
00063 template<typename TKernel>
00064 class NaiveKde {
00065
00066 FORBID_ACCIDENTAL_COPIES(NaiveKde);
00067
00068 private:
00069
00071
00073 struct datanode *module_;
00074
00076 Matrix qset_;
00077
00079 Matrix rset_;
00080
00082 Vector rset_weights_;
00083
00085 ArrayList<TKernel> kernels_;
00086
00088 Vector densities_;
00089
00091 double norm_const_;
00092
00093 public:
00094
00096
00098 NaiveKde() {
00099 }
00100
00102 ~NaiveKde() {
00103 }
00104
00106
00112 void get_density_estimates(Vector *results) {
00113 results->Init(densities_.length());
00114
00115 for(index_t i = 0; i < densities_.length(); i++) {
00116 (*results)[i] = densities_[i];
00117 }
00118 }
00119
00121
00127 void Compute(Vector *results) {
00128
00129 printf("\nStarting naive KDE...\n");
00130 fx_timer_start(module_, "naive_kde_compute");
00131
00132 for(index_t q = 0; q < qset_.n_cols(); q++) {
00133
00134 const double *q_col = qset_.GetColumnPtr(q);
00135
00136
00137 for(index_t r = 0; r < rset_.n_cols(); r++) {
00138 const double *r_col = rset_.GetColumnPtr(r);
00139 double dsqd = la::DistanceSqEuclidean(qset_.n_rows(), q_col, r_col);
00140
00141 densities_[q] += rset_weights_[r] * kernels_[r].EvalUnnormOnSq(dsqd);
00142 }
00143
00144
00145 densities_[q] /= norm_const_;
00146 }
00147 fx_timer_stop(module_, "naive_kde_compute");
00148 printf("\nNaive KDE completed...\n");
00149
00150
00151 get_density_estimates(results);
00152 }
00153
00156 void Compute() {
00157
00158 printf("\nStarting naive KDE...\n");
00159 fx_timer_start(module_, "naive_kde_compute");
00160
00161 for(index_t q = 0; q < qset_.n_cols(); q++) {
00162
00163 const double *q_col = qset_.GetColumnPtr(q);
00164
00165
00166 for(index_t r = 0; r < rset_.n_cols(); r++) {
00167 const double *r_col = rset_.GetColumnPtr(r);
00168 double dsqd = la::DistanceSqEuclidean(qset_.n_rows(), q_col, r_col);
00169
00170 densities_[q] += rset_weights_[r] * kernels_[r].EvalUnnormOnSq(dsqd);
00171 }
00172
00173 densities_[q] /= norm_const_;
00174 }
00175 fx_timer_stop(module_, "naive_kde_compute");
00176 printf("\nNaive KDE completed...\n");
00177 }
00178
00179 void Init(Matrix &qset, Matrix &rset, struct datanode *module_in) {
00180
00181
00182 Matrix uniform_weights;
00183 uniform_weights.Init(1, rset.n_cols());
00184 uniform_weights.SetAll(1.0);
00185
00186 Init(qset, rset, uniform_weights, module_in);
00187 }
00188
00196 void Init(Matrix &qset, Matrix &rset, Matrix &reference_weights,
00197 struct datanode *module_in) {
00198
00199
00200 module_ = module_in;
00201
00202
00203 qset_.Copy(qset);
00204 rset_.Copy(rset);
00205 rset_weights_.Init(reference_weights.n_cols());
00206 for(index_t i = 0; i < rset_weights_.length(); i++) {
00207 rset_weights_[i] = reference_weights.get(0, i);
00208 }
00209
00210
00211 double weight_sum = 0;
00212 for(index_t i = 0; i < rset_weights_.length(); i++) {
00213 weight_sum += rset_weights_[i];
00214 }
00215
00216
00217 kernels_.Init(rset_.n_cols());
00218 if(!strcmp(fx_param_str(module_, "mode", "variablebw"), "variablebw")) {
00219
00220
00221 int knns = fx_param_int_req(module_, "knn");
00222 AllkNN all_knn;
00223 all_knn.Init(rset_, 20, knns);
00224 ArrayList<index_t> resulting_neighbors;
00225 ArrayList<double> squared_distances;
00226
00227 fx_timer_start(fx_root, "bandwidth_initialization");
00228 all_knn.ComputeNeighbors(&resulting_neighbors, &squared_distances);
00229
00230 for(index_t i = 0; i < squared_distances.size(); i += knns) {
00231 kernels_[i / knns].Init(sqrt(squared_distances[i + knns - 1]));
00232 }
00233 fx_timer_stop(fx_root, "bandwidth_initialization");
00234
00235
00236
00237 double min_norm_const = DBL_MAX;
00238 for(index_t i = 0; i < rset_weights_.length(); i++) {
00239 double norm_const = kernels_[i].CalcNormConstant(qset_.n_rows());
00240 min_norm_const = std::min(min_norm_const, norm_const);
00241 }
00242 for(index_t i = 0; i < rset_weights_.length(); i++) {
00243 double norm_const = kernels_[i].CalcNormConstant(qset_.n_rows());
00244 rset_weights_[i] *= (min_norm_const / norm_const);
00245 }
00246
00247
00248 norm_const_ = weight_sum * min_norm_const;
00249 }
00250 else {
00251 for(index_t i = 0; i < kernels_.size(); i++) {
00252 kernels_[i].Init(fx_param_double_req(module_, "bandwidth"));
00253 }
00254 norm_const_ = kernels_[0].CalcNormConstant(qset_.n_rows()) * weight_sum;
00255 }
00256
00257
00258 densities_.Init(qset.n_cols());
00259 densities_.SetZero();
00260 }
00261
00269 void PrintDebug() {
00270
00271 FILE *stream = stdout;
00272 const char *fname = NULL;
00273
00274 {
00275 fname = fx_param_str(module_, "naive_kde_output",
00276 "naive_kde_output.txt");
00277 stream = fopen(fname, "w+");
00278 }
00279 for(index_t q = 0; q < qset_.n_cols(); q++) {
00280 fprintf(stream, "%g\n", densities_[q]);
00281 }
00282
00283 if(stream != stdout) {
00284 fclose(stream);
00285 }
00286 }
00287
00297 void ComputeMaximumRelativeError(const Vector &density_estimates) {
00298
00299 double max_rel_err = 0;
00300 FILE *stream = fopen("relative_error_output.txt", "w+");
00301 int within_limit = 0;
00302
00303 for(index_t q = 0; q < densities_.length(); q++) {
00304 double rel_err =
00305 (fabs(density_estimates[q] - densities_[q]) < DBL_EPSILON) ?
00306 0 : fabs(density_estimates[q] - densities_[q]) / densities_[q];
00307
00308 if(isnan(density_estimates[q]) || isinf(density_estimates[q]) ||
00309 isnan(densities_[q]) || isinf(densities_[q])) {
00310 printf("Warning: Got infs or nans!\n");
00311 }
00312
00313 if(rel_err > max_rel_err) {
00314 max_rel_err = rel_err;
00315 }
00316 if(rel_err <= fx_param_double(module_, "relative_error", 0.01)) {
00317 within_limit++;
00318 }
00319
00320 fprintf(stream, "%g\n", rel_err);
00321 }
00322
00323 fclose(stream);
00324 fx_format_result(module_, "maximum_relative_error_for_fast_KDE", "%g",
00325 max_rel_err);
00326 fx_format_result(module_, "under_relative_error_limit", "%d",
00327 within_limit);
00328 }
00329
00330 };
00331
00332 #endif