00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00044 #ifndef FASTLIB_MIXTURE_GAUSSIAN_H
00045 #define FASTLIB_MIXTURE_GAUSSIAN_H
00046 #include "fastlib/fastlib.h"
00047 #include "support.h"
00048 class MixtureGauss {
00050 private:
00052 ArrayList<Vector> means;
00053
00055 ArrayList<Matrix> covs;
00056
00058 Vector prior;
00059
00061 ArrayList<Matrix> inv_covs;
00062
00064 Vector det_covs;
00065
00067 ArrayList<Vector> ACC_means;
00068
00070 ArrayList<Matrix> ACC_covs;
00071
00073 Vector ACC_prior;
00074
00076 double total;
00077
00078 OT_DEF(MixtureGauss) {
00079 OT_MY_OBJECT(means);
00080 OT_MY_OBJECT(covs);
00081 OT_MY_OBJECT(prior);
00082 OT_MY_OBJECT(inv_covs);
00083 OT_MY_OBJECT(det_covs);
00084 OT_MY_OBJECT(ACC_means);
00085 OT_MY_OBJECT(ACC_covs);
00086 OT_MY_OBJECT(ACC_prior);
00087 OT_MY_OBJECT(total);
00088 }
00089 public:
00091 void InitFromFile(const char* mean_fn, const char* covs_fn = NULL, const char* prior_fn = NULL);
00092
00098 void InitFromProfile(const ArrayList<Matrix>& matlst, int start, int N);
00099
00101 void Init(int K, int N);
00102
00104 void Init(int K, const Matrix& data, const ArrayList<int>& labels);
00105
00107 void print_mixture(const char* s) const;
00108
00110 void generate(Vector* v) const;
00111
00113 double getPDF(const Vector& v) const;
00114
00116 double getPDF(int cluster, const Vector& v) const;
00117
00119 const Vector& get_prior() const { return prior; }
00120
00122 const Vector& get_mean(int k) const { return means[k]; }
00123
00125 const Matrix& get_cov(int k) const { return covs[k]; }
00126
00128 int n_clusters() const { return means.size(); }
00129
00131 int v_length() const { return means[0].length(); }
00132
00133
00135 void start_accumulate() {
00136 total = 0;
00137 for (int i = 0; i < means.size(); i++) {
00138 ACC_means[i].SetZero();
00139 ACC_covs[i].SetZero();
00140 ACC_prior.SetZero();
00141 }
00142 }
00143
00145 void accumulate(const Vector& v) {
00146 double s = getPDF(v);
00147 for (int i = 0; i < means.size(); i++) {
00148 double p = getPDF(i, v) / s;
00149 ACC_prior[i] += p;
00150 la::AddExpert(p, v, &ACC_means[i]);
00151 Vector d;
00152 la::SubInit(v, means[i], &d);
00153 Matrix D;
00154 D.AliasColVector(d);
00155 la::MulExpert(p, false, D, true, D, 1.0, &ACC_covs[i]);
00156 }
00157 total ++;
00158 }
00159
00161 void accumulate_cluster(int i, const Vector& v) {
00162 la::AddTo(v, &ACC_means[i]);
00163 Matrix V;
00164 V.AliasColVector(v);
00165 la::MulExpert(1.0, false, V, true, V, 1.0, &ACC_covs[i]);
00166 ACC_prior[i]++;
00167 total++;
00168 }
00169
00171 void accumulate(double p, int i, const Vector& v) {
00172 la::AddExpert(p, v, &ACC_means[i]);
00173 Matrix V;
00174 V.AliasColVector(v);
00175 la::MulExpert(p, false, V, true, V, 1.0, &ACC_covs[i]);
00176 ACC_prior[i] += p;
00177 total += p;
00178 }
00179
00181 void end_accumulate_cluster() {
00182 for (int i = 0; i < means.size(); i++)
00183 if (ACC_prior[i] != 0) {
00184 la::ScaleOverwrite(1.0/ACC_prior[i], ACC_means[i], &means[i]);
00185 Matrix M;
00186 M.AliasColVector(means[i]);
00187 la::MulExpert(-1.0, false, M, true, M, 1.0/ACC_prior[i], &ACC_covs[i]);
00188 covs[i].CopyValues(ACC_covs[i]);
00189 prior[i] = ACC_prior[i]/total;
00190
00191 double det = la::Determinant(covs[i]);
00192 la::InverseOverwrite(covs[i], &inv_covs[i]);
00193 det_covs[i] = pow(2.0*math::PI, -means[i].length()/2.0) * pow(det, -0.5);
00194 }
00195 }
00196
00198 void end_accumulate() {
00199 for (int i = 0; i < means.size(); i++) {
00200 if (ACC_prior[i] != 0) {
00201 la::ScaleOverwrite(1.0/ACC_prior[i], ACC_means[i], &means[i]);
00202 Matrix M;
00203 M.AliasColVector(means[i]);
00204 la::MulExpert(-1.0, false, M, true, M, 1.0/ACC_prior[i], &ACC_covs[i]);
00205 covs[i].CopyValues(ACC_covs[i]);
00206 prior[i] = ACC_prior[i]/total;
00207
00208 double det = la::Determinant(covs[i]);
00209 la::InverseOverwrite(covs[i], &inv_covs[i]);
00210 det_covs[i] = pow(2.0*math::PI, -means[i].length()/2.0) * pow(det, -0.5);
00211 }
00212 }
00213 }
00214 };
00215 #endif