package edu.wisc.sjm.machlearn.apps;

import edu.wisc.sjm.jutil.io.UtilPrintStream;
import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.misc.PropertiesUtil;
import edu.wisc.sjm.machlearn.confusion.ConfusionMatrix;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.dataset.xydataset.XYDataSet;
import edu.wisc.sjm.machlearn.policy.PolicyClassifier;
import edu.wisc.sjm.machlearn.util.Util;
import java.io.FileOutputStream;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/apps/nfoldProbCompressTune.class */
/**
 * Command-line driver that runs n-fold cross-validation of a {@link PolicyClassifier}.
 *
 * <p>Configuration is read through {@link PropertiesUtil} from the
 * {@code edu.wisc.sjm.machlearn.nfold.*} keys ({@code useSubsamples}, {@code datapath},
 * {@code logpath}, {@code folds}, {@code subsamples}). The loaded data set is preprocessed by the
 * policy and split into folds. For each fold the policy is invalidated, re-tuned on the training
 * split, and evaluated on both the training and the test split. Per-fold and aggregate
 * accuracies, confusion matrices and statistics are printed to {@code System.out}.
 */
public class nfoldProbCompressTune extends MainClass {
    /** Path of the data set loaded by {@link #setup()}. */
    String datapath;
    /** Number of cross-validation folds (property default 10). */
    int folds;
    /** Subsample count handed to the policy and to the fold splitter (property default 1). */
    int subsamples;
    /** Data set under evaluation; replaced by its policy-preprocessed version in {@link #run()}. */
    XYDataSet data;
    /** Classifier pipeline that is preprocessed, tuned and evaluated per fold. */
    PolicyClassifier policy;
    /** Optional log file; when set, System.out and System.err are redirected to it. */
    String logpath;
    /** Per-fold splits: {@code foldsets[i][0]} is the training set, {@code foldsets[i][1]} the test set. */
    XYDataSet[][] foldsets;
    /** Value of the useSubsamples property; it is not otherwise read within this class. */
    boolean use_subsamples;

    /**
     * Creates the driver, optionally loading a properties file first.
     *
     * <p>If the logpath property is set, System.out and System.err are redirected to that file
     * before the remaining properties are read. The effective {@code edu.wisc.sjm} properties are
     * echoed to System.out at the end.
     *
     * @param str path of a properties file to load, or {@code null} to use whatever properties
     *     are already loaded in {@link PropertiesUtil}
     * @throws Exception propagated from loading the properties or from opening the log file
     */
    public nfoldProbCompressTune(String str) throws Exception {
        if (str != null) {
            PropertiesUtil.load(str);
        }
        this.use_subsamples = PropertiesUtil.getBoolean("edu.wisc.sjm.machlearn.nfold.useSubsamples", false);
        this.datapath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.nfold.datapath", null);
        // Redirect output early so everything printed afterwards ends up in the log file.
        setLogPath();
        this.folds = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold.folds", 10);
        this.subsamples = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold.subsamples", 1);
        this.policy = new PolicyClassifier();
        this.policy.setSubSamples(this.subsamples);
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
    }

    /**
     * Creates the driver using the properties already loaded in {@link PropertiesUtil}.
     *
     * @throws Exception see {@link #nfoldProbCompressTune(String)}
     */
    public nfoldProbCompressTune() throws Exception {
        this(null);
    }

    /**
     * Reads the logpath property and, if it is set, redirects System.out and System.err to it.
     *
     * <p>The stream is intentionally left open: it becomes the process-wide standard output.
     *
     * @throws Exception if the log file cannot be opened
     */
    public void setLogPath() throws Exception {
        this.logpath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.nfold.logpath", null);
        if (this.logpath != null) {
            UtilPrintStream utilPrintStream = new UtilPrintStream(new FileOutputStream(this.logpath));
            System.setOut(utilPrintStream);
            System.setErr(utilPrintStream);
        }
    }

    /** Prints all {@code edu.wisc.sjm} properties to System.out. */
    public void printParameters() {
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
    }

    /**
     * Loads the data set from {@link #datapath}.
     *
     * @throws Exception propagated from {@code XYDataSet.loadXYDataSet}
     */
    public void setup() throws Exception {
        this.data = XYDataSet.loadXYDataSet(this.datapath);
    }

    /**
     * Runs the full cross-validation experiment and prints its results.
     *
     * <p>{@link #setup()} must have been called first so that {@link #data} is loaded.
     *
     * @throws IllegalArgumentException if the configured number of folds is less than 1
     * @throws Exception propagated from preprocessing, tuning or classification
     */
    public void run() throws Exception {
        // Fail fast: with no folds, foldsets[0] below and the final average would both blow up.
        if (this.folds < 1) {
            throw new IllegalArgumentException(
                    "edu.wisc.sjm.machlearn.nfold.folds must be at least 1, got " + this.folds);
        }
        System.out.println("Running preprocess stuff");
        System.out.println("Full set:");
        for (int i = 0; i < this.data.size(); i++) {
            System.out.println(String.valueOf(this.data.getExampleName(i)) + "\t" + this.data.getOutputFeature(i).getValueId());
        }
        this.data = this.policy.preprocess(this.data);
        System.out.println("Processed set:");
        System.out.println(Util.printArray(this.data.getExampleNames()));
        System.out.println("Splitting into folds");
        splitIntoFolds();

        ConfusionMatrix trainMatrix = new ConfusionMatrix(this.foldsets[0][0].getOutputFeatureId());
        ConfusionMatrix testMatrix = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        int totalFeatures = 0;
        System.out.println("Starting run");
        for (int fold = 0; fold < this.folds; fold++) {
            totalFeatures += runFold(fold, trainMatrix, testMatrix);
        }
        printSummary(trainMatrix, testMatrix, totalFeatures);
    }

    /** Splits {@link #data} into {@link #folds} train/test pairs and stores them in {@link #foldsets}. */
    private void splitIntoFolds() throws Exception {
        DataSet[][] split = this.data.splitDataSetFolds(this.folds, this.subsamples, false, true);
        this.foldsets = new XYDataSet[this.folds][2];
        for (int fold = 0; fold < this.folds; fold++) {
            this.foldsets[fold][0] = (XYDataSet) split[fold][0];
            this.foldsets[fold][1] = (XYDataSet) split[fold][1];
        }
    }

    /**
     * Re-tunes the policy on one fold's training split and evaluates it on both splits.
     *
     * <p>The fold's results are added to the two confusion matrices and printed.
     *
     * @param fold index into {@link #foldsets}
     * @param trainMatrix accumulator for training-split results
     * @param testMatrix accumulator for test-split results
     * @return the feature count of the policy's feature view of the test split
     */
    private int runFold(int fold, ConfusionMatrix trainMatrix, ConfusionMatrix testMatrix) throws Exception {
        XYDataSet train = this.foldsets[fold][0];
        XYDataSet test = this.foldsets[fold][1];
        System.out.print("Running fold #" + fold + "\n");
        System.out.println("Training...\n");
        this.policy.invalidate();
        this.policy.tune(train);
        // Computed once per fold; the original recomputed it three times without any policy change in between.
        FeatureDataSet testFds = this.policy.getFDSDataSet(test);
        int numFeatures = testFds.numFeatures();
        trainMatrix.add(this.policy.getClassifier(), this.policy.getFDSDataSet(train));
        Feature[] predicted = this.policy.classify(test);
        Feature[] actual = new Feature[predicted.length];
        for (int i = 0; i < actual.length; i++) {
            actual[i] = testFds.getOutputFeature(i);
        }
        testMatrix.add(predicted, actual);
        System.out.print("nf:" + numFeatures);
        System.out.print(" train:");
        System.out.print(DoubleUtil.printDecimal(trainMatrix.getAccuracy(fold), 3));
        System.out.print(" test:");
        System.out.println(DoubleUtil.printDecimal(testMatrix.getAccuracy(fold), 3));
        System.out.println("Test set example names:");
        for (int i = 0; i < testFds.size(); i++) {
            System.out.println(i + ")" + testFds.getExampleName(i) + "\t" + testFds.getOutputValueId(i) + "\t" + predicted[i].getValueId());
        }
        return numFeatures;
    }

    /** Prints the aggregate accuracies, confusion matrices and statistics for all folds. */
    private void printSummary(ConfusionMatrix trainMatrix, ConfusionMatrix testMatrix, int totalFeatures)
            throws Exception {
        double testAccuracy = testMatrix.getAccuracy();
        double trainAccuracy = trainMatrix.getAccuracy();
        // Floating-point division: integer division silently truncated the reported average.
        System.out.println("Average # of features:" + ((double) totalFeatures / this.folds));
        System.out.println("Train Accuracy is:" + trainAccuracy);
        System.out.println("Test Accuracy is:" + testAccuracy);
        System.out.println("Train Confusion");
        System.out.println("===============");
        System.out.println(trainMatrix.toString());
        System.out.println("Test Confusion");
        System.out.println("==============");
        System.out.println(testMatrix.toString());
        System.out.println("Train Stats");
        System.out.println("===============");
        System.out.println(trainMatrix.getStats());
        System.out.println("Test Stats");
        System.out.println("==============");
        System.out.println(testMatrix.getStats());
        System.out.flush();
    }

    /**
     * Entry point: {@code nfold prop.txt}.
     *
     * <p>Builds the driver from the given properties file and loads the data. It then writes the
     * effective properties to {@code tmp.prop} and runs the experiment.
     *
     * @param strArr exactly one argument, the properties file path
     * @throws Exception propagated from setup or the run
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 1) {
            System.out.println("Usage: nfold prop.txt");
            System.exit(-1);
        }
        nfoldProbCompressTune nfoldprobcompresstune = new nfoldProbCompressTune(strArr[0]);
        nfoldprobcompresstune.setup();
        PropertiesUtil.save("tmp.prop");
        nfoldprobcompresstune.run();
    }
}
