package edu.wisc.sjm.machlearn;

import edu.wisc.sjm.jutil.io.UtilPrintStream;
import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.misc.PropertiesUtil;
import edu.wisc.sjm.machlearn.confusion.ConfusionMatrix;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.xydataset.XYDataSet;
import edu.wisc.sjm.machlearn.policy.PolicyClassifier;
import edu.wisc.sjm.machlearn.policy.xypreprocessor.misc.RandomizeClasses;
import edu.wisc.sjm.machlearn.util.Util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/nfoldPermute.class */
/**
 * Runs n-fold cross-validation of a {@link PolicyClassifier} on a dataset and on
 * repeated {@link RandomizeClasses} variants of it.
 *
 * <p>After each run, one tab-separated row is appended to an output file. The row holds:
 * <ul>
 *   <li>the percentage of examples whose output value matches the original data;</li>
 *   <li>test-set statistics;</li>
 *   <li>subsample-vote test-set statistics.</li>
 * </ul>
 *
 * <p>The first data row is produced from the unmodified (preprocessed) dataset.
 * Every following row comes from a dataset returned by {@code RandomizeClasses.process}.
 * If the output file already holds rows, the run resumes: rows already present are
 * counted, and only the missing runs are executed.
 */
public class nfoldPermute extends MainClass {
    /** Dataset path, from property {@code edu.wisc.sjm.machlearn.nfold.datapath}. */
    String datapath;
    /** Number of cross-validation folds (property {@code edu.wisc.sjm.machlearn.nfold.folds}, default 10). */
    int folds;
    /** Subsample count (property {@code edu.wisc.sjm.machlearn.nfold.subsamples}, default 1). */
    int subsamples;
    /** Dataset evaluated by the current {@link #runPermutation()} call. */
    XYDataSet data;
    /** Dataset as loaded from {@link #datapath} and passed through {@code policy.preprocess}. */
    XYDataSet orig_data;
    /** Classifier trained once per fold. */
    PolicyClassifier policy;
    /** Optional log file that stdout/stderr are redirected to (see {@link #setLogPath()}). */
    String logpath;
    /** Per-fold {train, test} datasets of the current run. */
    XYDataSet[][] foldsets;
    /** Number of permuted (non-original) runs to produce in total. */
    int permutations;
    /** Path of the tab-separated results file. */
    String output_file;
    /** Writer for {@link #output_file}; open only during {@link #run()}. */
    PrintWriter pw;

    /**
     * Creates a runner configured from a properties file.
     *
     * @param str properties file to load, or {@code null} to use already-loaded properties
     * @param i number of permuted runs to perform
     * @param str2 path of the results file
     * @throws Exception if loading properties or opening the log file fails
     */
    public nfoldPermute(String str, int i, String str2) throws Exception {
        if (str != null) {
            PropertiesUtil.load(str);
        }
        this.permutations = i;
        this.datapath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.nfold.datapath", null);
        setLogPath();
        this.folds = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold.folds", 10);
        this.subsamples = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold.subsamples", 1);
        this.policy = new PolicyClassifier();
        this.policy.setSubSamples(this.subsamples);
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
        this.output_file = str2;
    }

    /** Creates a runner with one permutation, writing to {@code permute.out}. */
    public nfoldPermute() throws Exception {
        this(null, 1, "permute.out");
    }

    /**
     * Redirects {@code System.out} and {@code System.err} to the file named by
     * {@code edu.wisc.sjm.machlearn.nfold.logpath}, if that property is set.
     * The stream is intentionally left open for the rest of the process.
     */
    public void setLogPath() throws Exception {
        this.logpath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.nfold.logpath", null);
        if (this.logpath != null) {
            UtilPrintStream utilPrintStream = new UtilPrintStream(new FileOutputStream(this.logpath));
            System.setOut(utilPrintStream);
            System.setErr(utilPrintStream);
        }
    }

    /** Prints all {@code edu.wisc.sjm} properties to stdout. */
    public void printParameters() {
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
    }

    /** Loads the dataset from {@link #datapath} and preprocesses it with the policy. */
    public void setup() throws Exception {
        this.orig_data = XYDataSet.loadXYDataSet(this.datapath);
        this.orig_data = this.policy.preprocess(this.orig_data);
    }

    /**
     * Executes the original run (if not already recorded) plus the remaining permuted runs.
     *
     * <p>A non-empty existing results file is treated as a previous partial run:
     * its first line is taken as the header, and each later line as one completed run.
     * New rows are appended to it. Otherwise the file is (re)created with a header line.
     */
    public void run() throws Exception {
        RandomizeClasses randomizeClasses = new RandomizeClasses();
        randomizeClasses.setSubSamples(this.subsamples);
        File outFile = new File(this.output_file);
        // Number of permuted runs already recorded; -1 means even the original run is missing.
        int i = -1;
        // An empty existing file has no header, so treat it like a missing file.
        if (outFile.exists() && outFile.length() > 0) {
            i = countDataRows(this.output_file) - 1;
            System.out.println("Number of permutations already done:" + i);
            this.pw = new PrintWriter(new FileWriter(this.output_file, true));
        } else {
            this.pw = new PrintWriter(new FileWriter(this.output_file));
            writeHeader();
        }
        try {
            if (i == -1) {
                this.data = this.orig_data;
                runPermutation();
                i++;
            }
            for (int i2 = i; i2 < this.permutations; i2++) {
                this.data = randomizeClasses.process(this.orig_data);
                runPermutation();
            }
        } finally {
            this.pw.close();
        }
    }

    /** Counts the lines after the first (header) line of {@code path}; 0 if the file is empty. */
    private static int countDataRows(String path) throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            int rows = 0;
            if (reader.readLine() != null) {
                while (reader.readLine() != null) {
                    rows++;
                }
            }
            return rows;
        } finally {
            reader.close();
        }
    }

    /** Writes the column header line of the results file to {@link #pw}. */
    private void writeHeader() {
        this.pw.print("PAccuracy\t");
        this.pw.print("Accuracy\tRecall\tPrecision\tFScore\tFPR");
        this.pw.println("\tVAccuracy\tVRecall\tVPrecision\tVFScore\tVFPR");
    }

    /**
     * Cross-validates the policy on {@link #data}.
     *
     * <p>Train, test, and vote-test confusion matrices and statistics are printed to stdout.
     * One result row is then appended to {@link #pw}.
     */
    public void runPermutation() throws Exception {
        System.out.println("Processed set:");
        System.out.println(Util.printArray(this.data.getExampleNames()));
        System.out.println("Splitting into folds");
        DataSet[][] splits = this.data.splitDataSetFolds(this.folds, this.subsamples, true, true);
        this.foldsets = new XYDataSet[this.folds][2];
        for (int f = 0; f < this.folds; f++) {
            this.foldsets[f][0] = (XYDataSet) splits[f][0];
            this.foldsets[f][1] = (XYDataSet) splits[f][1];
        }
        ConfusionMatrix trainCm = new ConfusionMatrix(this.foldsets[0][0].getOutputFeatureId());
        ConfusionMatrix testCm = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        ConfusionMatrix voteCm = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        int totalFeatures = 0;
        System.out.println("Starting run");
        for (int f = 0; f < this.folds; f++) {
            System.out.print("Running fold #" + f + "\n");
            System.out.println("Training...\n");
            this.policy.train(this.foldsets[f][0]);
            int numFeatures = this.policy.getFDSDataSet(this.foldsets[f][1]).numFeatures();
            totalFeatures += numFeatures;
            trainCm.add(this.policy, this.foldsets[f][0]);
            testCm.add(this.policy, this.foldsets[f][1]);
            voteCm.add(this.policy.classifyBySubsample(this.foldsets[f][1]), this.policy.getSubsampleClassifications(this.foldsets[f][1]));
            System.out.print("nf:" + numFeatures);
            System.out.print(" train:");
            System.out.print(DoubleUtil.printDecimal(trainCm.getAccuracy(f), 3));
            System.out.print(" test:");
            System.out.println(DoubleUtil.printDecimal(testCm.getAccuracy(f), 3));
            System.out.println(" Vote test:" + this.policy.getSubsampleThreshold());
            System.out.println(DoubleUtil.printDecimal(voteCm.getAccuracy(f), 3));
        }
        double testAccuracy = testCm.getAccuracy();
        double trainAccuracy = trainCm.getAccuracy();
        // Floating-point division: integer division would truncate the average.
        System.out.println("Average # of features:" + ((double) totalFeatures / this.folds));
        System.out.println("Train Accuracy is:" + trainAccuracy);
        System.out.println("Test Accuracy is:" + testAccuracy);
        System.out.println("Vote Test Accuracy is:" + voteCm.getAccuracy());
        System.out.println("Train Confusion");
        System.out.println("===============");
        System.out.println(trainCm.toString());
        System.out.println("Test Confusion");
        System.out.println("==============");
        System.out.println(testCm.toString());
        System.out.println("VTest Confusion");
        System.out.println("===============");
        System.out.println(voteCm.toString());
        System.out.println("Train Stats");
        System.out.println("===============");
        System.out.println(trainCm.getStats());
        System.out.println("Test Stats");
        System.out.println("==============");
        System.out.println(testCm.getStats());
        System.out.println("VTest Stats");
        System.out.println("===============");
        System.out.println(voteCm.getStats());
        System.out.flush();
        this.pw.println(String.valueOf(percentUnchangedLabels()) + "\t" + statsRow(testCm) + "\t" + statsRow(voteCm));
        this.pw.flush();
    }

    /**
     * Returns the percentage (0-100) of examples whose output value id in {@link #data}
     * equals the one in {@link #orig_data}, compared index by index.
     */
    private double percentUnchangedLabels() {
        int unchanged = 0;
        for (int k = 0; k < this.orig_data.size(); k++) {
            if (this.orig_data.getOutputValueId(k) == this.data.getOutputValueId(k)) {
                unchanged++;
            }
        }
        // Multiply before dividing in floating point; integer division yields only 0 or 100.
        return (unchanged * 100.0d) / this.orig_data.size();
    }

    /** Formats accuracy, recall, precision, F-score and FPR of {@code cm} tab-separated. */
    private static String statsRow(ConfusionMatrix cm) {
        return String.valueOf(cm.getAccuracy()) + "\t" + cm.getRecall() + "\t" + cm.getPrecision()
                + "\t" + cm.getFScore() + "\t" + cm.getFPR();
    }

    /**
     * Command-line entry point.
     *
     * <p>Arguments: {@code prop.txt permutations [output_file]}.
     * The output file defaults to {@code permute.out}.
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 2) {
            System.out.println("Usage: nfoldPermute prop.txt permutations [output_file]");
            System.exit(-1);
        }
        nfoldPermute nfoldpermute = new nfoldPermute(strArr[0], Integer.parseInt(strArr[1]), strArr.length >= 3 ? strArr[2] : "permute.out");
        nfoldpermute.setup();
        PropertiesUtil.save("tmp.prop");
        nfoldpermute.run();
    }
}
