package edu.wisc.sjm.machlearn;

import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.misc.PropertiesUtil;
import edu.wisc.sjm.machlearn.confusion.ConfusionMatrix;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.xydataset.XYDataSet;
import edu.wisc.sjm.machlearn.policy.PolicyClassifier;
import edu.wisc.sjm.machlearn.util.Util;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/nfold.class */
/**
 * Command-line driver that runs an n-fold cross-validation of a {@link PolicyClassifier} on an
 * {@link XYDataSet} and prints per-fold and aggregate results to standard output.
 *
 * <p>The data path, number of folds, sub-sample count and fold randomization/balancing flags are
 * read from a {@link MachlearnParameters} properties file.
 */
public class nfold extends MainClass {
    /** Data set under evaluation; replaced by its preprocessed form at the start of {@link #run()}. */
    XYDataSet data;
    /** Classifier being cross-validated. */
    PolicyClassifier policy;
    /** Per-fold sets: {@code foldsets[i][0]} is fold i's training set, {@code foldsets[i][1]} its test set. */
    XYDataSet[][] foldsets;
    /** Run configuration loaded from the properties file. */
    MachlearnParameters parameters;

    /**
     * Loads the run configuration and creates the classifier.
     *
     * @param str properties-file path handed to {@link MachlearnParameters}; the no-arg constructor
     *     passes {@code null} (NOTE(review): confirm how MachlearnParameters treats a null path)
     * @throws Exception propagated from loading the parameters or creating the classifier
     */
    public nfold(String str) throws Exception {
        this.parameters = new MachlearnParameters(str);
        this.policy = new PolicyClassifier();
        this.policy.setSubSamples(this.parameters.getSubSamples());
        // Echo the "edu.wisc.sjm" properties so the run log records the configuration used.
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
    }

    /**
     * Creates a driver with {@code null} passed as the properties path.
     *
     * @throws Exception propagated from {@link #nfold(String)}
     */
    public nfold() throws Exception {
        this(null);
    }

    /** Writes the "edu.wisc.sjm" properties to standard output. */
    public void printParameters() {
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
    }

    /**
     * Loads {@link #data} from the configured data path.
     *
     * @throws Exception propagated from {@link XYDataSet#loadXYDataSet}
     */
    public void setup() throws Exception {
        this.data = XYDataSet.loadXYDataSet(this.parameters.getDataPath());
    }

    /**
     * Runs the cross-validation: preprocesses {@link #data}, splits it into folds, trains and
     * evaluates the classifier on every fold, and prints per-fold and aggregate results.
     *
     * <p>When more than one sub-sample is configured, a third "vote" confusion matrix is also
     * filled from the classifier's sub-sample voting.
     *
     * @throws IllegalArgumentException if the configured number of folds is less than 1
     * @throws Exception propagated from preprocessing, splitting, training or classification
     */
    public void run() throws Exception {
        int folds = this.parameters.getFolds();
        int subSamples = this.parameters.getSubSamples();
        // Fail fast: with fewer than one fold the fold array is empty and foldsets[0] below would
        // throw an opaque ArrayIndexOutOfBoundsException after all the preprocessing output.
        if (folds < 1) {
            throw new IllegalArgumentException("Number of folds must be at least 1, got " + folds);
        }
        boolean voting = subSamples > 1;

        System.out.println("Running preprocess stuff");
        printFullSet();
        this.data = this.policy.preprocess(this.data);
        System.out.println("Processed set:");
        System.out.println(Util.printArray(this.data.getExampleNames()));

        System.out.println("Splitting into folds");
        this.foldsets = splitIntoFolds(folds, subSamples);

        ConfusionMatrix trainMatrix = new ConfusionMatrix(this.foldsets[0][0].getOutputFeatureId());
        ConfusionMatrix testMatrix = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        ConfusionMatrix voteMatrix = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        int totalFeatures = 0;
        System.out.println("Starting run");
        for (int fold = 0; fold < folds; fold++) {
            totalFeatures += runFold(fold, voting, trainMatrix, testMatrix, voteMatrix);
        }

        printSummary(folds, totalFeatures, voting, trainMatrix, testMatrix, voteMatrix);
        System.out.flush();
    }

    /** Prints every example's name and output-feature value id (tab-separated) before preprocessing. */
    private void printFullSet() throws Exception {
        System.out.println("Full set:");
        for (int i = 0; i < this.data.size(); i++) {
            System.out.println(String.valueOf(this.data.getExampleName(i)) + "\t"
                    + this.data.getOutputFeature(i).getValueId());
        }
    }

    /** Splits {@link #data} into {@code folds} train/test pairs using the configured fold options. */
    private XYDataSet[][] splitIntoFolds(int folds, int subSamples) throws Exception {
        DataSet[][] split = this.data.splitDataSetFolds(folds, subSamples,
                this.parameters.getRandomizeFolds(), this.parameters.getBalanceFolds());
        XYDataSet[][] result = new XYDataSet[folds][2];
        for (int i = 0; i < folds; i++) {
            result[i][0] = (XYDataSet) split[i][0];
            result[i][1] = (XYDataSet) split[i][1];
        }
        return result;
    }

    /**
     * Trains on one fold, adds its results to the given matrices, prints the fold report, and
     * returns the number of features in the classifier's feature data set for the fold's test set.
     */
    private int runFold(int fold, boolean voting, ConfusionMatrix trainMatrix,
            ConfusionMatrix testMatrix, ConfusionMatrix voteMatrix) throws Exception {
        XYDataSet trainSet = this.foldsets[fold][0];
        XYDataSet testSet = this.foldsets[fold][1];

        System.out.print("Running fold #" + fold + "\n");
        System.out.print("Train Names\n");
        for (int i = 0; i < trainSet.size(); i++) {
            System.out.println(i + ":" + trainSet.getName(i) + " ");
        }
        System.out.print("\n");
        System.out.println("Test Names\n");
        for (int i = 0; i < testSet.size(); i++) {
            System.out.println(i + ":" + testSet.getName(i) + " ");
        }
        System.out.print("\n");

        System.out.println("Training...\n");
        this.policy.train(trainSet);
        // Build the test feature data set once. The count is both accumulated and printed, so
        // computing it once keeps the two consistent and avoids repeating the work.
        // NOTE(review): assumes getFDSDataSet has no side effects relied on by a second call.
        int numFeatures = this.policy.getFDSDataSet(testSet).numFeatures();
        trainMatrix.add(this.policy, trainSet);
        testMatrix.add(this.policy, testSet);
        if (voting) {
            voteMatrix.add(this.policy.classifyBySubsample(testSet),
                    this.policy.getSubsampleClassifications(testSet));
        }

        // getAccuracy(fold) presumably reports the fold-th add() to each matrix -- TODO confirm.
        System.out.print("nf:" + numFeatures);
        System.out.print(" train:");
        System.out.print(DoubleUtil.printDecimal(trainMatrix.getAccuracy(fold), 3));
        System.out.print(" test:");
        System.out.println(DoubleUtil.printDecimal(testMatrix.getAccuracy(fold), 3));
        if (voting) {
            System.out.println(" Vote test:" + this.policy.getSubsampleThreshold());
            System.out.println(DoubleUtil.printDecimal(voteMatrix.getAccuracy(fold), 3));
        }

        System.out.println("Test set example names:");
        for (int i = 0; i < testSet.size(); i++) {
            System.out.println(i + ")" + testSet.getExampleName(i));
        }
        return numFeatures;
    }

    /** Prints aggregate accuracies, confusion matrices and statistics over all folds. */
    private void printSummary(int folds, int totalFeatures, boolean voting,
            ConfusionMatrix trainMatrix, ConfusionMatrix testMatrix, ConfusionMatrix voteMatrix)
            throws Exception {
        double testAccuracy = testMatrix.getAccuracy();
        double trainAccuracy = trainMatrix.getAccuracy();
        // Floating-point division: integer division silently truncated fractional averages.
        System.out.println("Average # of features:" + ((double) totalFeatures / folds));
        System.out.println("Train Accuracy is:" + trainAccuracy);
        System.out.println("Test Accuracy is:" + testAccuracy);
        if (voting) {
            System.out.println("Vote Test Accuracy is:" + voteMatrix.getAccuracy());
        }

        System.out.println("Train Confusion");
        System.out.println("===============");
        System.out.println(trainMatrix.toString());
        System.out.println("Test Confusion");
        System.out.println("==============");
        System.out.println(testMatrix.toString());
        if (voting) {
            System.out.println("VTest Confusion");
            System.out.println("===============");
            System.out.println(voteMatrix.toString());
        }

        System.out.println("Train Stats");
        System.out.println("===============");
        System.out.println(trainMatrix.getStats());
        System.out.println("Test Stats");
        System.out.println("==============");
        System.out.println(testMatrix.getStats());
        if (voting) {
            System.out.println("VTest Stats");
            System.out.println("===============");
            System.out.println(voteMatrix.getStats());
        }
    }

    /**
     * Entry point: {@code nfold prop.txt}. Loads the configuration and data, saves the current
     * properties to {@code tmp.prop}, then runs the cross-validation.
     *
     * @param strArr command-line arguments; exactly one (the properties file) is required
     * @throws Exception propagated from setup or the run
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 1) {
            System.out.println("Usage: nfold prop.txt");
            System.exit(-1);
        }
        nfold nfoldVar = new nfold(strArr[0]);
        nfoldVar.setup();
        PropertiesUtil.save("tmp.prop");
        nfoldVar.run();
    }
}
