00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
#include <cstring>
#include <string>
#include "fastlib/fastlib.h"
#include "ridge_regression.h"
#include "ridge_regression_util.h"
00053
00054 int main(int argc, char *argv[]) {
00055
00057 const fx_entry_doc ridge_main_entries[] = {
00058 {"inversion_method", FX_PARAM, FX_STR, NULL,
00059 " The method chosen for inverting the design matrix: normalsvd\
00060 (SVD on normal equation: default), svd (SVD), quicsvd (QUIC-SVD).\n"},
00061 {"lambda_min", FX_PARAM, FX_DOUBLE, NULL,
00062 " The minimum lambda value used for CV (set to zero by default).\n"},
00063 {"lambda_max", FX_PARAM, FX_DOUBLE, NULL,
00064 " The maximum lambda value used for CV (set to zero by default).\n"},
00065 {"mode", FX_PARAM, FX_STR, NULL,
00066 " The computation mode: regress, cvregress (cross-validated regression),\
00067 fsregress (feature selection then regress).\n"},
00068 {"num_lambdas", FX_PARAM, FX_INT, NULL,
00069 " The number of lambdas to try for CV (set to 1 by default).\n"},
00070 {"predictions", FX_REQUIRED, FX_STR, NULL,
00071 " A file containing the observed predictions.\n"},
00072 {"predictor_indices", FX_PARAM, FX_STR, NULL,
00073 " The file containing the indices of the dimensions that act as the \
00074 predictors for the input dataset.\n"},
00075 {"predictors", FX_REQUIRED, FX_STR, NULL,
00076 " A file containing the predictors.\n"},
00077 {"prune_predictor_indices", FX_PARAM, FX_STR, NULL,
00078 " The file containing the indices of the dimensions that must be \
00079 considered for pruning for the input dataset.\n"},
00080 FX_ENTRY_DOC_DONE
00081 };
00082
00083 const fx_submodule_doc ridge_main_submodules[] = {
00084 FX_SUBMODULE_DOC_DONE
00085 };
00086
00087 const fx_module_doc ridge_main_doc = {
00088 ridge_main_entries, ridge_main_submodules,
00089 "This is the driver for the ridge regression.\n"
00090 };
00091
00092 fx_module *module = fx_init(argc, argv, &ridge_main_doc);
00093 double lambda_min = fx_param_double(module, "lambda_min", 0.0);
00094 double lambda_max = fx_param_double(module, "lambda_max", 0.0);
00095 int num_lambdas_to_cv = fx_param_int(module, "num_lambdas", 1);
00096 const char *mode = fx_param_str(module, "mode", "regress");
00097 if(lambda_min == lambda_max) {
00098 num_lambdas_to_cv = 1;
00099 if(!strcmp(mode, "crossvalidate")) {
00100 fx_set_param_str(module, "mode", "regress");
00101 mode = fx_param_str(module, "mode", "regress");
00102 }
00103 }
00104 else {
00105 fx_set_param_str(module, "mode", "cvregress");
00106 mode = fx_param_str(module, "mode", "cvregress");
00107 }
00108
00109
00110 std::string predictors_file = fx_param_str_req(module, "predictors");
00111 std::string predictions_file = fx_param_str_req(module, "predictions");
00112
00113 Matrix predictors;
00114 if (data::Load(predictors_file.c_str(), &predictors) == SUCCESS_FAIL) {
00115 FATAL("Unable to open file %s", predictors_file.c_str());
00116 }
00117
00118 Matrix predictions;
00119 if (data::Load(predictions_file.c_str(), &predictions) == SUCCESS_FAIL) {
00120 FATAL("Unable to open file %s", predictions_file.c_str());
00121 }
00122
00123 RidgeRegression engine;
00124 NOTIFY("Computing Regression...");
00125
00126 if(!strcmp(mode, "regress")) {
00127
00128 engine.Init(module, predictors, predictions);
00129 engine.SVDRegress(lambda_min);
00130 }
00131 else if(!strcmp(mode, "cvregress")) {
00132 NOTIFY("Crossvalidating for the optimal lambda in [ %g %g ] by trying \
00133 %d values...", lambda_min, lambda_max, num_lambdas_to_cv);
00134 engine.Init(module, predictors, predictions);
00135 engine.CrossValidatedRegression(lambda_min, lambda_max, num_lambdas_to_cv);
00136 }
00137 else if(!strcmp(mode, "fsregress")) {
00138
00139 NOTIFY("Feature selection based regression.\n");
00140
00141 Matrix predictor_indices_intermediate;
00142 Matrix prune_predictor_indices_intermediate;
00143 std::string predictor_indices_file = fx_param_str_req(module,
00144 "predictor_indices");
00145 std::string prune_predictor_indices_file =
00146 fx_param_str_req(module, "prune_predictor_indices");
00147 if(data::Load(predictor_indices_file.c_str(),
00148 &predictor_indices_intermediate) == SUCCESS_FAIL) {
00149 FATAL("Unable to open file %s", predictor_indices_file.c_str());
00150 }
00151 if(data::Load(prune_predictor_indices_file.c_str(),
00152 &prune_predictor_indices_intermediate) == SUCCESS_FAIL) {
00153 FATAL("Unable to open file %s", prune_predictor_indices_file.c_str());
00154 }
00155
00156 GenVector<index_t> predictor_indices;
00157 GenVector<index_t> prune_predictor_indices;
00158 predictor_indices.Init(predictor_indices_intermediate.n_cols());
00159 prune_predictor_indices.Init
00160 (prune_predictor_indices_intermediate.n_cols());
00161
00162
00163
00164
00165 for(index_t i = 0; i < predictor_indices_intermediate.n_cols(); i++) {
00166 predictor_indices[i] =
00167 (index_t) predictor_indices_intermediate.get(0, i);
00168 }
00169 for(index_t i = 0; i < prune_predictor_indices_intermediate.n_cols();
00170 i++) {
00171 prune_predictor_indices[i] = (index_t)
00172 prune_predictor_indices_intermediate.get(0, i);
00173 }
00174
00175
00176 GenVector<index_t> output_predictor_indices;
00177 engine.Init(module, predictors, predictor_indices, predictions);
00178 engine.FeatureSelectedRegression(predictor_indices,
00179 prune_predictor_indices,
00180 predictions,
00181 &output_predictor_indices);
00182 }
00183
00184 NOTIFY("Ridge Regression Model Training Complete!");
00185 double square_error = engine.ComputeSquareError();
00186 NOTIFY("Square Error:%g", square_error);
00187 fx_result_double(module, "square error", square_error);
00188 Matrix factors;
00189 engine.factors(&factors);
00190 std::string factors_file = fx_param_str(module, "factors", "factors.csv");
00191 NOTIFY("Saving factors...");
00192 data::Save(factors_file.c_str(), factors);
00193
00194 fx_done(module);
00195 return 0;
00196 }