00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00055 #include "mog_l2e.h"
00056 #include "fastlib/optimization/contrib/optimizers.h"
00057
00058
00059 const fx_entry_doc mog_l2e_main_entries[] = {
00060 {"data", FX_REQUIRED, FX_STR, NULL,
00061 " A file containing the data on which the model"
00062 " has to be fit.\n"},
00063 {"output", FX_PARAM, FX_STR, NULL,
00064 " The file into which the output is to be written into.\n"},
00065 FX_ENTRY_DOC_DONE
00066 };
00067
00068 const fx_submodule_doc mog_l2e_main_submodules[] = {
00069 {"mog_l2e", &mog_l2e_doc,
00070 " Responsible for intializing the model and"
00071 " computing the parameters.\n"},
00072 {"opt", &opt_doc,
00073 " Responsible for minimizing the L2 loss function"
00074 " and obtaining the parameter values.\n"},
00075 FX_SUBMODULE_DOC_DONE
00076 };
00077
00078 const fx_module_doc mog_l2e_main_doc = {
00079 mog_l2e_main_entries, mog_l2e_main_submodules,
00080 " This program test drives the parametric estimation "
00081 "of a Gaussian mixture model using L2 loss function.\n"
00082 };
00083
00084 int main(int argc, char* argv[]) {
00085
00086 fx_module *root =
00087 fx_init(argc, argv, &mog_l2e_main_doc);
00088
00090
00091 const char *data_filename = fx_param_str_req(root, "data");
00092
00093 Matrix data_points;
00094 data::Load(data_filename, &data_points);
00095
00097
00098 datanode *mog_l2e_module = fx_submodule(root, "mog_l2e");
00099 index_t number_of_gaussians = fx_param_int(mog_l2e_module, "K", 1);
00100 fx_set_param_int(mog_l2e_module, "D", data_points.n_rows());
00101 index_t dimension = fx_param_int_req(mog_l2e_module, "D");;
00102
00104
00105 datanode *opt_module = fx_submodule(root, "opt");
00106 const char *opt_method = fx_param_str(opt_module, "method", "QuasiNewton");
00107 index_t param_dim = (number_of_gaussians*(dimension+1)*(dimension+2)/2 - 1);
00108 fx_set_param_int(opt_module, "param_space_dim", param_dim);
00109
00110 index_t optim_flag = (strcmp(opt_method, "NelderMead") == 0 ? 1 : 0);
00111 MoGL2E mog;
00112
00113 if (optim_flag == 1) {
00114
00116
00117 NelderMead opt;
00118
00120 fx_timer_start(opt_module, "init_opt");
00121 opt.Init(MoGL2E::L2ErrorForOpt, data_points, opt_module);
00122 fx_timer_stop(opt_module, "init_opt");
00123
00125 double **pts;
00126 pts = (double**)malloc((param_dim+1)*sizeof(double*));
00127 for(index_t i = 0; i < param_dim+1; i++) {
00128 pts[i] = (double*)malloc(param_dim*sizeof(double));
00129 }
00130
00131 fx_timer_start(opt_module, "get_init_pts");
00132 MoGL2E::MultiplePointsGenerator(pts, param_dim+1,
00133 data_points, number_of_gaussians);
00134 fx_timer_stop(opt_module, "get_init_pts");
00135
00137
00138 fx_timer_start(opt_module, "optimizing");
00139 opt.Eval(pts);
00140 fx_timer_stop(opt_module, "optimizing");
00141
00143 mog.MakeModel(mog_l2e_module, pts[0]);
00144
00145 }
00146 else {
00147
00149
00150 QuasiNewton opt;
00151
00153 fx_timer_start(opt_module, "init_opt");
00154 opt.Init(MoGL2E::L2ErrorForOpt, data_points, opt_module);
00155 fx_timer_stop(opt_module, "init_opt");
00156
00158 double *pt;
00159 pt = (double*)malloc(param_dim*sizeof(double));
00160
00161 fx_timer_start(opt_module, "get_init_pt");
00162 MoGL2E::InitialPointGenerator(pt, data_points, number_of_gaussians);
00163 fx_timer_stop(opt_module, "get_init_pt");
00164
00166
00167 fx_timer_start(opt_module, "optimizing");
00168 opt.Eval(pt);
00169 fx_timer_stop(opt_module, "optimizing");
00170
00172 mog.MakeModel(mog_l2e_module, pt);
00173
00174 }
00175
00176 long double error = mog.L2Error(data_points);
00177 NOTIFY("Minimum L2 error achieved: %Lf", error);
00178 mog.Display();
00179
00180 ArrayList<double> results;
00181 mog.OutputResults(&results);
00182
00183
00185
00186 const char *output_filename = fx_param_str(NULL, "output", "output.csv");
00187
00188 FILE *output_file = fopen(output_filename, "w");
00189
00190 ot::Print(results, output_file);
00191 fclose(output_file);
00192 fx_done(root);
00193
00194 return 1;
00195 }