bandwidth_lscv.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00039 #ifndef BANDWIDTH_LSCV_H
00040 #define BANDWIDTH_LSCV_H
00041
00042 #include "dualtree_kde.h"
00043 #include "dualtree_kde_cv.h"
00044
00045 class BandwidthLSCV {
00046
00047 private:
00048
00049 static double plugin_bandwidth_(const Matrix &references) {
00050
00051 double avg_sdev = 0;
00052 int num_dims = references.n_rows();
00053 int num_data = references.n_cols();
00054 Vector mean_vector;
00055 mean_vector.Init(references.n_rows());
00056 mean_vector.SetZero();
00057
00058
00059 for(index_t i = 0; i < references.n_cols(); i++) {
00060 for(index_t j = 0; j < references.n_rows(); j++) {
00061 mean_vector[j] += references.get(j, i);
00062 }
00063 }
00064 la::Scale(1.0 / ((double) num_data), &mean_vector);
00065
00066
00067
00068 for(index_t j = 0; j < num_dims; j++) {
00069 double sdev = 0;
00070 for(index_t i = 0; i < num_data; i++) {
00071 sdev += math::Sqr(references.get(j, i) - mean_vector[j]);
00072 }
00073 sdev /= ((double) num_data - 1);
00074 sdev = sqrt(sdev);
00075 avg_sdev += sdev;
00076 }
00077 avg_sdev /= ((double) num_dims);
00078
00079 double plugin_bw =
00080 pow((4.0 / (num_dims + 2.0)), 1.0 / (num_dims + 4.0)) * avg_sdev *
00081 pow(num_data, -1.0 / (num_dims + 4.0));
00082
00083 return plugin_bw;
00084 }
00085
00086 public:
00087
00088 template<typename TKernelAux>
00089 static void ComputeLSCVScore(const Matrix &references,
00090 const Matrix &reference_weights,
00091 double bandwidth) {
00092
00093
00094 struct datanode *kde_module = fx_submodule(fx_root, "kde");
00095
00096
00097 typename TKernelAux::TKernel kernel;
00098
00099
00100 double lscv_score;
00101
00102
00103 kernel.Init(bandwidth);
00104
00105 printf("Trying the bandwidth value of %g...\n", bandwidth);
00106
00107
00108
00109 fx_set_param_double(kde_module, "bandwidth", bandwidth);
00110 DualtreeKdeCV<TKernelAux> *fast_kde_on_bandwidth =
00111 new DualtreeKdeCV<TKernelAux>();
00112 fast_kde_on_bandwidth->Init(references, reference_weights, kde_module);
00113 lscv_score = fast_kde_on_bandwidth->Compute();
00114 delete fast_kde_on_bandwidth;
00115
00116 printf("Least squares cross-validation score is %g...\n\n", lscv_score);
00117 }
00118
00119 template<typename TKernelAux>
00120 static void Optimize(const Matrix &references,
00121 const Matrix &reference_weights) {
00122
00123
00124 struct datanode *kde_module = fx_submodule(fx_root, "kde");
00125
00126
00127 double min_lscv_score = DBL_MAX;
00128
00129
00130 typename TKernelAux::TKernel kernel;
00131
00132
00133 double plugin_bandwidth = plugin_bandwidth_(references);
00134 double current_lower_search_limit = plugin_bandwidth * 0.00001;
00135 double current_upper_search_limit = plugin_bandwidth;
00136 double min_bandwidth = DBL_MAX;
00137
00138 printf("Searching the optimal bandwidth in [%g %g]...\n",
00139 current_lower_search_limit, current_upper_search_limit);
00140
00141 do {
00142
00143
00144
00145 double bandwidth = current_upper_search_limit;
00146 kernel.Init(bandwidth);
00147
00148 printf("Trying the bandwidth value of %g...\n", bandwidth);
00149
00150
00151
00152 fx_set_param_double(kde_module, "bandwidth", bandwidth);
00153 DualtreeKdeCV<TKernelAux> *fast_kde_on_bandwidth =
00154 new DualtreeKdeCV<TKernelAux>();
00155 fast_kde_on_bandwidth->Init(references, reference_weights, kde_module);
00156 double lscv_score = fast_kde_on_bandwidth->Compute();
00157 delete fast_kde_on_bandwidth;
00158
00159 printf("Least squares cross-validation score is %g...\n\n", lscv_score);
00160
00161 if(lscv_score < min_lscv_score) {
00162 min_lscv_score = lscv_score;
00163 min_bandwidth = bandwidth;
00164 }
00165 current_upper_search_limit /= 2.0;
00166
00167 } while(current_upper_search_limit > current_lower_search_limit);
00168
00169
00170
00171
00172 printf("Minimum score was %g and achieved at the bandwidth value of %g\n",
00173 min_lscv_score, min_bandwidth);
00174
00175 }
00176
00177 };
00178
00179 #endif