dataset.h

00001 /* MLPACK 0.2
00002  *
00003  * Copyright (c) 2008, 2009 Alexander Gray,
00004  *                          Garry Boyer,
00005  *                          Ryan Riegel,
00006  *                          Nikolaos Vasiloglou,
00007  *                          Dongryeol Lee,
00008  *                          Chip Mappus, 
00009  *                          Nishant Mehta,
00010  *                          Hua Ouyang,
00011  *                          Parikshit Ram,
00012  *                          Long Tran,
00013  *                          Wee Chin Wong
00014  *
00015  * Copyright (c) 2008, 2009 Georgia Institute of Technology
00016  *
00017  * This program is free software; you can redistribute it and/or
00018  * modify it under the terms of the GNU General Public License as
00019  * published by the Free Software Foundation; either version 2 of the
00020  * License, or (at your option) any later version.
00021  *
00022  * This program is distributed in the hope that it will be useful, but
00023  * WITHOUT ANY WARRANTY; without even the implied warranty of
00024  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00025  * General Public License for more details.
00026  *
00027  * You should have received a copy of the GNU General Public License
00028  * along with this program; if not, write to the Free Software
00029  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
00030  * 02110-1301, USA.
00031  */
00043 #ifndef DATA_DATASET_H
00044 #define DATA_DATASET_H
00045 
#include <climits>
#include <cstdio>

#include "fastlib/col/col_string.h"
#include "fastlib/la/matrix.h"
#include "fastlib/math/discrete.h"
#include "fastlib/file/textfile.h"
00050 
00051 class TextLineReader;
00052 class TextWriter;
00053 
00059 class DatasetFeature {
00060  public:
00064   enum Type {
00066       CONTINUOUS,
00068       INTEGER,
00070       NOMINAL
00071   };
00072   
00073  private:
00075   String name_;
00077   Type type_;
00079   ArrayList<String> value_names_;
00080   
00081   OBJECT_TRAVERSAL(DatasetFeature) {
00082     OT_OBJ(name_);
00083     //OT_OBJ(reinterpret_cast<int &>(type_));
00084     OT_ENUM_EXPERT(type_, int,
00085       OT_ENUM_VAL(CONTINUOUS)
00086       OT_ENUM_VAL(INTEGER)
00087       OT_ENUM_VAL(NOMINAL));
00088     OT_OBJ(value_names_);
00089   }
00090 
00096  void InitGeneral(const char *name_in) {
00097     name_.Copy(name_in);
00098     value_names_.Init();
00099  }
00100 
00101  public:
00107   void InitContinuous(const char *name_in) {
00108     InitGeneral(name_in);
00109     type_ = CONTINUOUS;
00110   }
00111 
00117   void InitInteger(const char *name_in) {
00118     InitGeneral(name_in);
00119     type_ = INTEGER;
00120   }
00121 
00131   void InitNominal(const char *name_in) {
00132     InitGeneral(name_in);
00133     type_ = NOMINAL;
00134   }
00135   
00146   void Format(double value, String *result) const;
00147   
00160   success_t Parse(const char *str, double *d) const;
00161   
00167   const String& name() const {
00168     return name_;
00169   }
00170   
00176   Type type() const {
00177     return type_;
00178   }
00179   
00187   const String& value_name(int value) const {
00188     DEBUG_ASSERT(type_ == NOMINAL);
00189     return value_names_[value];
00190   }
00191   
00200   index_t n_values() const {
00201     return value_names_.size();
00202   }
00203   
00211   ArrayList<String>& value_names() {
00212     return value_names_;
00213   }
00214 };
00215 
00219 class DatasetInfo {
00220  private:
00221   String name_;
00222   ArrayList<DatasetFeature> features_;
00223 
00224   OBJECT_TRAVERSAL(DatasetInfo) {
00225     OT_OBJ(name_);
00226     OT_OBJ(features_);
00227   }
00228 
00229  public:
00231   ArrayList<DatasetFeature>& features() {
00232     return features_;
00233   }
00234 
00236   const DatasetFeature& feature(index_t attrib_num) const {
00237     return features_[attrib_num];
00238   }
00239 
00241   index_t n_features() const {
00242     return features_.size();
00243   }
00244   
00246   const char *name() const {
00247     return name_;
00248   }
00249   
00251   void set_name(const char *name_in) {
00252     name_.Destruct();
00253     name_.Copy(name_in);
00254   }
00255 
00259   bool is_all_continuous() const;
00260 
00267   void InitContinuous(index_t n_features,
00268       const char *name_in = "dataset");
00269 
00277   void Init(const char *name_in = "dataset");
00278 
00282   void WriteArffHeader(TextWriter *writer) const;
00283   
00290   void WriteCsvHeader(const char *sep, TextWriter *writer) const;
00291 
00299   void WriteMatrix(const Matrix& matrix, const char *sep,
00300       TextWriter *writer) const;
00301 
00313   success_t InitFromArff(TextLineReader *reader,
00314       const char *filename = "dataset");
00315   
00323   success_t InitFromCsv(TextLineReader *reader,
00324       const char *filename = "dataset");
00325 
00333   success_t InitFromFile(TextLineReader *reader,
00334       const char *filename = "dataset");
00344   success_t ReadMatrix(TextLineReader *reader, Matrix *matrix) const;
00345 
00356   success_t ReadPoint(TextLineReader *reader, double *point,
00357       bool *is_done) const;
00358 
00359  private:
00360   char *SkipSpace_(char *s);
00361 
00362   char *SkipNonspace_(char *s);
00363 
00364   void SkipBlanks_(TextLineReader *reader);
00365 
00366 };
00367 
00368 
00380 class Dataset {
00381  private:
00382   Matrix matrix_;
00383   DatasetInfo info_;
00384   
00385   OBJECT_TRAVERSAL(Dataset) {
00386     OT_OBJ(matrix_);
00387     OT_OBJ(info_);
00388   }
00389   
00390  public:
00400   const DatasetInfo& info() const {
00401     return info_;
00402   }
00403   
00408   DatasetInfo& info() {
00409     return info_;
00410   }
00411   
00419   index_t n_features() const {
00420     return matrix_.n_rows();
00421   }
00422   
00430   index_t n_points() const {
00431     return matrix_.n_cols();
00432   }
00433 
00442   index_t n_labels() const;
00443 
00464   void GetLabels(ArrayList<double> &labels_list,
00465                  ArrayList<index_t> &labels_index,
00466                  ArrayList<index_t> &labels_ct,
00467                  ArrayList<index_t> &labels_startpos) const;
00468  
00475   double get(index_t feature, index_t point) const {
00476     return matrix_.get(feature, point);
00477   }
00478   
00482   int get_int(index_t feature, index_t point) const {
00483     double d = get(feature, point);
00484     int i = int(d);
00485     DEBUG_ASSERT(d == double(i));
00486     return i;
00487   }
00488   
00496   void set(index_t feature, index_t point, double d) {
00497     matrix_.set(feature, point, d);
00498   }
00499   
00505   const double *point(index_t point) const {
00506     return matrix_.GetColumnPtr(point);
00507   }
00513   double *point(index_t point) {
00514     return matrix_.GetColumnPtr(point);
00515   }
00516   
00520   const Matrix& matrix() const {
00521     return matrix_;
00522   }
00527   Matrix& matrix() {
00528     return matrix_;
00529   }
00530   
00538   void Format(index_t feature, index_t point, String *result) const {
00539     info_.feature(feature).Format(get(feature, point), result);
00540   }
00541   
00549   void InitBlank() {
00550   }
00551   
00560   success_t InitFromFile(const char *fname);
00561   
00572   success_t InitFromFile(TextLineReader *reader,
00573       const char *filename = "dataset");
00574   
00582   success_t WriteCsv(const char *fname, bool header = false) const;
00583 
00589   success_t WriteArff(const char *fname) const;
00590 
00597   void CopyMatrix(const Matrix& matrix_in) {
00598     InitBlank();
00599     matrix_.Copy(matrix_in);
00600     info_.InitContinuous(matrix_.n_rows());
00601   }
00602   
00613   void OwnMatrix(Matrix* matrix_in) {
00614     InitBlank();
00615     matrix_.Own(matrix_in);
00616     info_.InitContinuous(matrix_.n_rows());
00617   }
00618   
00629   void AliasMatrix(const Matrix& matrix_in) {
00630     InitBlank();
00631     matrix_.Alias(matrix_in);
00632     info_.InitContinuous(matrix_.n_rows());
00633   }
00634   
00635   //--- Cross-validation features ---
00636 
00652   void SplitTrainTest(int folds, int fold_number,
00653       const ArrayList<index_t>& permutation,
00654       Dataset *train, Dataset *test) const;
00655 };
00656 
00660 namespace data {
00675   success_t Load(const char *fname, Matrix *matrix);
00690   template<typename Precision>
00691   success_t LargeLoad(const char *fname, GenMatrix<Precision> *matrix) {
00692     TextLineReader *reader = new TextLineReader();
00693     if (reader->Open(fname)==SUCCESS_FAIL) {
00694       reader->Error("Couldn't open %s", fname);
00695       return SUCCESS_FAIL;
00696     } 
00697     index_t dimension=0;
00698     String line=reader->Peek();
00699     ArrayList<String> result;
00700     result.Init();
00701     line.Split(",", &result);
00702     dimension=result.size();
00703     while (reader->Gobble()) {
00704     }
00705     matrix->StaticInit(dimension, reader->line_num());
00706     matrix->SetAll(0.0);
00707     delete reader;
00708     reader = new TextLineReader();
00709     reader->Open(fname);
00710     while (true) {
00711       String line=reader->Peek();
00712       ArrayList<String> result;
00713       result.Init();
00714       line.Split(",", &result);
00715       for(index_t i=0; i<result.size(); i++) {
00716         Precision num;
00717         sscanf(result[i].c_str(), "%lf", &num);
00718         matrix->set(i, reader->line_num()-1, (Precision)num);
00719       }
00720       if (reader->Gobble()==false) {
00721         break;
00722       }
00723     }
00724     return SUCCESS_PASS;  
00725   }
00726 
00727 
00742   success_t Save(const char *fname, const Matrix& matrix);
00743 };
00744 
00745 #endif
Generated on Mon Jan 24 12:04:37 2011 for FASTlib by  doxygen 1.6.3