kernel.h

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00038 #ifndef MATH_KERNEL_H
00039 #define MATH_KERNEL_H
00040 
00041 #include "fastlib/base/base.h"
00042 #include "fastlib/math/geometry.h"
00043 #include "fastlib/math/math_lib.h"
00044 
00045 #include <math.h>
00046 
/* More kernels to come soon. The Gaussian, Gaussian-star, and Epanechnikov kernels are defined below. */
00048 
00053 class GaussianKernel {
00054  private:
00055   double neg_inv_bandwidth_2sq_;
00056   double bandwidth_sq_;
00057 
00058   OBJECT_TRAVERSAL(GaussianKernel) {
00059     OT_OBJ(neg_inv_bandwidth_2sq_);
00060     OT_OBJ(bandwidth_sq_);
00061   }
00062 
00063  public:
00064   static const bool HAS_CUTOFF = false;
00065 
00066  public:
00067   double bandwidth_sq() const {
00068     return bandwidth_sq_;
00069   }
00070 
00071   void Init(double bandwidth_in, index_t dims) {
00072     Init(bandwidth_in);
00073   }
00074 
00080   void Init(double bandwidth_in) {
00081     bandwidth_sq_ = bandwidth_in * bandwidth_in;
00082     neg_inv_bandwidth_2sq_ = -1.0 / (2.0 * bandwidth_sq_);
00083   }
00084 
00089   double EvalUnnorm(double dist) const {
00090     return EvalUnnormOnSq(dist * dist);
00091   }
00092 
00097   double EvalUnnormOnSq(double sqdist) const {
00098     double d = exp(sqdist * neg_inv_bandwidth_2sq_);
00099     return d;
00100   }
00101 
00103   DRange RangeUnnormOnSq(const DRange& range) const {
00104     return DRange(EvalUnnormOnSq(range.hi), EvalUnnormOnSq(range.lo));  
00105   }
00106 
00110   double MaxUnnormValue() {
00111     return 1;
00112   }
00113 
00117   double CalcNormConstant(index_t dims) const {
00118     // Changed because * faster than / and 2 * math::PI opt out.  RR
00119     //return pow((-math::PI/neg_inv_bandwidth_2sq_), dims/2.0);
00120     return pow(2 * math::PI * bandwidth_sq_, dims / 2.0);
00121   }
00122 };
00123 
00128 class GaussianStarKernel {
00129  private:
00130   double neg_inv_bandwidth_2sq_;
00131   double factor_;
00132   double bandwidth_sq_;
00133   double critical_point_sq_;
00134   double critical_point_value_;
00135   
00136   OBJECT_TRAVERSAL(GaussianStarKernel) {
00137     OT_OBJ(neg_inv_bandwidth_2sq_);
00138     OT_OBJ(factor_);
00139     OT_OBJ(bandwidth_sq_);
00140     OT_OBJ(critical_point_sq_);
00141     OT_OBJ(critical_point_value_);
00142   }
00143   
00144  public:
00145   static const bool HAS_CUTOFF = false;
00146  
00147  public:
00148   double bandwidth_sq() const {
00149     return bandwidth_sq_;
00150   }
00151 
00157   void Init(double bandwidth_in, index_t dims) {
00158     bandwidth_sq_ = bandwidth_in * bandwidth_in;
00159     neg_inv_bandwidth_2sq_ = -1.0 / (2.0 * bandwidth_sq_);
00160     factor_ = pow(2.0, -dims / 2.0 - 1);
00161     critical_point_sq_ = 4 * bandwidth_sq_ * (dims / 2.0 + 2) * math::LN_2;
00162     critical_point_value_ = EvalUnnormOnSq(critical_point_sq_);
00163   }
00164   
00169   double EvalUnnorm(double dist) const {
00170     return EvalUnnormOnSq(dist * dist);
00171   }
00172 
00177   double EvalUnnormOnSq(double sqdist) const {
00178     double d =
00179       factor_ * exp(sqdist * neg_inv_bandwidth_2sq_ * 0.5)
00180       - exp(sqdist * neg_inv_bandwidth_2sq_);
00181     return d;
00182   }
00183 
00185   DRange RangeUnnormOnSq(const DRange& range) const {
00186     double eval_lo = EvalUnnormOnSq(range.lo);
00187     double eval_hi = EvalUnnormOnSq(range.hi);
00188     if (range.lo < critical_point_sq_) {
00189       if (range.hi < critical_point_sq_) {
00190         // Strictly under critical point.
00191         return DRange(eval_lo, eval_hi);
00192       } else {
00193         // Critical point is included
00194         return DRange(std::min(eval_lo, eval_hi), critical_point_value_);
00195       }
00196     } else {
00197       return DRange(eval_hi, eval_lo);  
00198     }
00199   }
00200 
00206   double CalcNormConstant(index_t dims) const {
00207     return pow(math::PI_2*bandwidth_sq_, dims / 2) / 2;
00208   }
00209   
00213   double CalcMultiplicativeNormConstant(index_t dims) const {
00214     return 1.0 / CalcNormConstant(dims);
00215   }
00216 };
00217 
00224 class EpanKernel {
00225  private:
00226   double inv_bandwidth_sq_;
00227   double bandwidth_sq_;
00228 
00229   OBJECT_TRAVERSAL(EpanKernel) {
00230     OT_OBJ(inv_bandwidth_sq_);
00231     OT_OBJ(bandwidth_sq_);
00232   }
00233   
00234  public:
00235   static const bool HAS_CUTOFF = true;
00236   
00237  public:
00238   void Init(double bandwidth_in, index_t dims) {
00239     Init(bandwidth_in);
00240   }
00241 
00245   void Init(double bandwidth_in) {
00246     bandwidth_sq_ = (bandwidth_in * bandwidth_in);
00247     inv_bandwidth_sq_ = 1.0 / bandwidth_sq_;
00248   }
00249   
00254   double EvalUnnorm(double dist) const {
00255     return EvalUnnormOnSq(dist * dist);
00256   }
00257   
00262   double EvalUnnormOnSq(double sqdist) const {
00263     // TODO: Try the fabs non-branching version.
00264     if (sqdist < bandwidth_sq_) {
00265       return 1 - sqdist * inv_bandwidth_sq_;
00266     } else {
00267       return 0;
00268     }
00269   }
00270 
00272   DRange RangeUnnormOnSq(const DRange& range) const {
00273     return DRange(EvalUnnormOnSq(range.hi), EvalUnnormOnSq(range.lo));  
00274   }
00275 
00279   double MaxUnnormValue() {
00280     return 1.0;
00281   }
00282   
00286   double CalcNormConstant(index_t dims) const {
00287     return 2.0 * math::SphereVolume(sqrt(bandwidth_sq_), dims)
00288         / (dims + 2.0);
00289   }
00290   
00294   double bandwidth_sq() const {
00295     return bandwidth_sq_;
00296   }
00297   
00301   double inv_bandwidth_sq() const {
00302     return inv_bandwidth_sq_;
00303   }
00304 };
00305 
00306 
00307 #endif
Generated on Mon Jan 24 12:04:37 2011 for FASTlib by  doxygen 1.6.3