package edu.wisc.sjm.machlearn;

import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.misc.PropertiesUtil;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.confusion.ConfusionMatrix;
import edu.wisc.sjm.machlearn.dataset.xydataset.XYDataSet;
import edu.wisc.sjm.machlearn.policy.PolicyClassifier;
import java.io.FileWriter;
import java.io.PrintWriter;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/traintest_tune.class */
/**
 * Command-line driver that tunes a {@link PolicyClassifier} on a training set and reports
 * accuracy and confusion matrices for both the training set and the test set.
 *
 * <p>Settings are read through {@link PropertiesUtil} in {@link #loadProperties(String)}.
 * The expected call sequence, as used by {@link #main(String[])}, is: construct,
 * {@link #setup()}, {@link #run()}.
 */
public class traintest_tune extends MainClass {
    /** Training set path; property {@code edu.wisc.sjm.machlearn.traintest.trainpath}, default {@code train.xy}. */
    String trainpath;
    /** Test set path; property {@code edu.wisc.sjm.machlearn.traintest.testpath}, default {@code test.xy}. */
    String testpath;
    /**
     * Value of property {@code edu.wisc.sjm.machlearn.nfold_fold}, default {@code -1}.
     * When it is {@code -1}, {@link #saveProbClasses()} writes nothing.
     */
    int fold_number;
    /** Classifier policy used for preprocessing, tuning and prediction. */
    PolicyClassifier policy;
    /** Log path last passed to {@link #setLogPath(String)}; may be {@code null}. */
    String logpath;
    /** Data sets indexed as {@code [fold][0]} = training set and {@code [fold][1]} = test set. */
    XYDataSet[][] foldsets;

    /**
     * Loads the properties source and initialises every configuration field and the policy.
     *
     * @param str argument handed unchanged to {@link PropertiesUtil#load}
     * @throws Exception if loading the properties or setting the log path fails
     */
    protected void loadProperties(String str) throws Exception {
        PropertiesUtil.load(str);
        this.trainpath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.traintest.trainpath", "train.xy");
        this.testpath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.traintest.testpath", "test.xy");
        this.fold_number = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold_fold", -1);
        // The policy must exist BEFORE setLogPath(): setLogPath() calls printParameters(), which
        // dereferences this.policy. With the previous ordering the field was still null at that
        // point and construction failed with a NullPointerException.
        this.policy = new PolicyClassifier();
        setLogPath(PropertiesUtil.getString("edu.wisc.sjm.machlearn.traintest.logpath", null));
    }

    /**
     * Forwards the path to {@link MainClass#setLogPath}, remembers it in {@link #logpath} and
     * prints the current parameters.
     *
     * @param str the log path; may be {@code null}
     * @throws Exception if the superclass rejects the path
     */
    @Override // edu.wisc.sjm.jutil.misc.MainClass
    public void setLogPath(String str) throws Exception {
        super.setLogPath(str);
        this.logpath = str;
        printParameters();
    }

    /**
     * Creates an instance by calling {@link #traintest_tune(String)} with {@code null}.
     *
     * @throws Exception if loading the properties fails
     */
    public traintest_tune() throws Exception {
        this(null);
    }

    /**
     * Creates an instance configured by {@link #loadProperties(String)}.
     *
     * @param str argument handed to {@link #loadProperties(String)}
     * @throws Exception if loading the properties fails
     */
    public traintest_tune(String str) throws Exception {
        loadProperties(str);
    }

    /** Prints the train, test and log paths to standard output, followed by the policy if one is set. */
    public void printParameters() {
        System.out.println("trainpath:" + this.trainpath);
        System.out.println("testpath:" + this.testpath);
        System.out.println("logpath:" + this.logpath);
        // setLogPath() is public and may be invoked before the policy has been created;
        // skip the policy dump in that case instead of failing.
        if (this.policy != null) {
            this.policy.printPolicy();
        }
    }

    /**
     * Loads the training and test sets and stores them as the single fold {@code foldsets[0]}.
     *
     * @throws Exception if either data set cannot be loaded
     */
    public void setup() throws Exception {
        XYDataSet trainSet = XYDataSet.loadXYDataSet(this.trainpath);
        XYDataSet testSet = XYDataSet.loadXYDataSet(this.testpath);
        this.foldsets = new XYDataSet[1][2];
        this.foldsets[0][0] = trainSet;
        this.foldsets[0][1] = testSet;
    }

    /**
     * For every fold: preprocesses both sets with the policy (the preprocessed sets replace the
     * stored ones), tunes the policy on the training set, and accumulates train and test
     * confusion matrices. Afterwards prints both accuracies and both matrices to standard
     * output and calls {@link #saveProbClasses()}.
     *
     * @throws IllegalStateException if {@link #setup()} has not been called
     * @throws Exception if preprocessing, tuning, evaluation or saving fails
     */
    public void run() throws Exception {
        requireSetup();
        ConfusionMatrix trainConfusion = new ConfusionMatrix(this.foldsets[0][0].getOutputFeatureId());
        ConfusionMatrix testConfusion = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        for (int i = 0; i < this.foldsets.length; i++) {
            System.out.println("Running fold #" + i);
            System.out.println("Preprocessing");
            this.foldsets[i][0] = this.policy.preprocess(this.foldsets[i][0]);
            this.foldsets[i][1] = this.policy.preprocess(this.foldsets[i][1]);
            System.out.println("Tuning");
            this.policy.tune(this.foldsets[i][0]);
            System.out.println("Getting train set confusion");
            trainConfusion.add(this.policy, this.foldsets[i][0]);
            System.out.println("Getting test set confusion");
            testConfusion.add(this.policy, this.foldsets[i][1]);
        }
        // Test accuracy is read before train accuracy, matching the original call order.
        double testAccuracy = testConfusion.getAccuracy();
        System.out.println("Train Accuracy is:" + trainConfusion.getAccuracy());
        System.out.println("Test Accuracy is:" + testAccuracy);
        System.out.println("Train Confusion");
        System.out.println("===============");
        System.out.println(trainConfusion.toString());
        System.out.println("Test Confusion");
        System.out.println("==============");
        System.out.println(testConfusion.toString());
        saveProbClasses();
        System.out.flush();
    }

    /**
     * When {@link #fold_number} is not {@code -1}, writes one line per test example of
     * {@code foldsets[0][1]} to the file {@code logpath + ".probs_outputs"}. Each line holds
     * column 1 of the example's row in the policy's distribution, a tab, and the example's
     * output value id.
     *
     * <p>NOTE(review): a {@code null} {@link #logpath} produces the file name
     * {@code "null.probs_outputs"} through string concatenation; this is unchanged from the
     * original behaviour. Confirm whether that is intended.
     *
     * @throws IllegalStateException if output is requested but {@link #setup()} has not been called
     * @throws Exception if the distribution cannot be computed or the file cannot be opened
     */
    public void saveProbClasses() throws Exception {
        if (this.fold_number == -1) {
            return;
        }
        requireSetup();
        System.out.println("Saving fold info");
        XYDataSet testSet = this.foldsets[0][1];
        DoubleVector probs = new DoubleVector();
        IntVector outputs = new IntVector();
        double[][] distribution = this.policy.getDistribution(testSet);
        for (int i = 0; i < testSet.size(); i++) {
            // Presumably column 1 is the probability of the class with index 1 - TODO confirm.
            probs.add(distribution[i][1]);
            outputs.add(testSet.getOutputValueId(i));
        }
        // NOTE(review): QuickSort appears to fill 'order' with an index permutation, since the
        // values are read back through it below - confirm against DoubleVector.QuickSort.
        IntVector order = new IntVector();
        DoubleVector.QuickSort(probs, order);
        PrintWriter out = new PrintWriter(new FileWriter(this.logpath + ".probs_outputs"));
        try {
            for (int i = 0; i < probs.size(); i++) {
                int idx = order.get(i);
                out.println(probs.get(idx) + "\t" + outputs.get(idx));
            }
        } finally {
            // Close the file even if a vector lookup throws, so the handle is never leaked.
            out.close();
        }
    }

    /** Fails with a clear message when the fold data sets have not been loaded yet. */
    private void requireSetup() {
        if (this.foldsets == null) {
            throw new IllegalStateException("setup() must be called before the data sets are used");
        }
    }

    /**
     * Entry point. Expects exactly one argument, which is passed to the constructor; otherwise
     * prints a usage line and exits with status {@code -1}.
     *
     * @param strArr command-line arguments
     * @throws Exception if configuration, setup or the run fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 1) {
            System.out.println("Usage: traintest_tune prop.txt");
            System.exit(-1);
        }
        traintest_tune app = new traintest_tune(strArr[0]);
        app.setup();
        app.run();
    }
}
