package edu.wisc.sjm.machlearn;

import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.machlearn.confusion.ConfusionMatrix;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.xydataset.XYDataSet;
import edu.wisc.sjm.machlearn.policy.TuneParameter;
import edu.wisc.sjm.machlearn.util.Util;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/nfold_tune.class */
/**
 * N-fold cross-validation driver that tunes the policy on the training split of every fold and
 * then evaluates it on both the training and the held-out test split.
 *
 * <p>Per-fold and aggregate train/test accuracies, the average feature count and both confusion
 * matrices are printed to standard output.
 */
public class nfold_tune extends nfold {
    // NOTE(review): the three fields below are not read or written anywhere in this class;
    // presumably they are used by subclasses or by configuration code elsewhere. TODO confirm
    // before removing them.
    String[] s_tune_parameters;
    TuneParameter[] tune_parameters;
    protected int tune_folds;

    /**
     * Creates a driver from the given argument, which is forwarded unchanged to
     * {@link nfold#nfold(String)}. {@link #main(String[])} passes its single command-line
     * argument (the properties file named in the usage message) here.
     *
     * @param str configuration argument forwarded to the superclass constructor
     * @throws Exception propagated from the superclass constructor
     */
    public nfold_tune(String str) throws Exception {
        super(str);
    }

    /**
     * Creates a driver using the superclass no-argument constructor.
     *
     * @throws Exception propagated from the superclass constructor
     */
    public nfold_tune() throws Exception {
    }

    /**
     * Runs tuned n-fold cross-validation and prints the results to standard output.
     *
     * <p>The data is first preprocessed by the policy and split into
     * {@code parameters.getFolds()} train/test pairs stored in {@code foldsets}. For each fold
     * the policy is invalidated, tuned on the training split, and its predictions on the
     * training and test splits are accumulated into separate confusion matrices.
     *
     * @throws IllegalArgumentException if the configured number of folds is less than 1
     * @throws Exception propagated from preprocessing, splitting, tuning or evaluation
     */
    @Override // edu.wisc.sjm.machlearn.nfold
    public void run() throws Exception {
        int folds = this.parameters.getFolds();
        // Fail fast with a clear message instead of an array-index/size error further down.
        if (folds < 1) {
            throw new IllegalArgumentException("Number of folds must be at least 1, got " + folds);
        }
        preprocessAndSplitForTuning(folds);

        ConfusionMatrix trainMatrix = new ConfusionMatrix(this.foldsets[0][0].getOutputFeatureId());
        ConfusionMatrix testMatrix = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        int totalFeatures = 0;
        System.out.println("Starting run");
        for (int fold = 0; fold < folds; fold++) {
            System.out.print("Running fold #" + fold + "\n");
            System.out.println("Tune/Train...\n");
            this.policy.invalidate();
            this.policy.tune(this.foldsets[fold][0]);
            // Feature count of the policy's FDS view of the test split. It is evaluated once
            // and reused for both the running total and the per-fold log line.
            int numFeatures = this.policy.getFDSDataSet(this.foldsets[fold][1]).numFeatures();
            totalFeatures += numFeatures;
            System.out.println("getting train accuracy...\n");
            trainMatrix.add(this.policy, this.foldsets[fold][0]);
            System.out.println("getting test accuracy...\n");
            testMatrix.add(this.policy, this.foldsets[fold][1]);
            System.out.print("nf:" + numFeatures);
            System.out.print(" train:");
            // NOTE(review): getAccuracy(int) presumably returns the accuracy of the fold-th
            // batch added to the matrix. TODO confirm against ConfusionMatrix.
            System.out.print(DoubleUtil.printDecimal(trainMatrix.getAccuracy(fold), 3));
            System.out.print(" test:");
            System.out.println(DoubleUtil.printDecimal(testMatrix.getAccuracy(fold), 3));
        }
        printTuneSummary(trainMatrix, testMatrix, totalFeatures, folds);
    }

    /**
     * Preprocesses {@code data} with the policy (replacing {@code data} with the result) and
     * splits it into {@code folds} pairs stored in {@code foldsets}.
     *
     * <p>Index {@code [i][0]} holds the training split of fold {@code i} and {@code [i][1]}
     * holds its test split.
     *
     * @param folds number of folds to create (at least 1)
     * @throws Exception propagated from preprocessing or splitting
     */
    private void preprocessAndSplitForTuning(int folds) throws Exception {
        System.out.println("Running preprocess stuff");
        System.out.println("Full set:");
        System.out.println(Util.printArray(this.data.getExampleNames()));
        this.data = this.policy.preprocess(this.data);
        System.out.println("Processed set:");
        System.out.println(Util.printArray(this.data.getExampleNames()));
        System.out.println("Splitting into folds");
        // The boolean flag's meaning is defined by DataSet.splitDataSetFolds (not visible here).
        DataSet[][] split = this.data.splitDataSetFolds(folds, false);
        this.foldsets = new XYDataSet[folds][2];
        for (int i = 0; i < folds; i++) {
            this.foldsets[i][0] = (XYDataSet) split[i][0];
            this.foldsets[i][1] = (XYDataSet) split[i][1];
        }
    }

    /**
     * Prints the average feature count, the overall train/test accuracies and both confusion
     * matrices, then flushes standard output.
     *
     * @param trainMatrix   confusion matrix accumulated over the training splits
     * @param testMatrix    confusion matrix accumulated over the test splits
     * @param totalFeatures sum of the per-fold feature counts
     * @param folds         number of folds (at least 1)
     */
    private static void printTuneSummary(ConfusionMatrix trainMatrix, ConfusionMatrix testMatrix,
            int totalFeatures, int folds) {
        double testAccuracy = testMatrix.getAccuracy();
        double trainAccuracy = trainMatrix.getAccuracy();
        // Floating-point division: an integer division would truncate the reported average.
        double averageFeatures = (double) totalFeatures / folds;
        System.out.println("Average # of features:" + averageFeatures);
        System.out.println("Train Accuracy is:" + trainAccuracy);
        System.out.println("Test Accuracy is:" + testAccuracy);
        System.out.println("Train Confusion");
        System.out.println("===============");
        System.out.println(trainMatrix.toString());
        System.out.println("Test Confusion");
        System.out.println("==============");
        System.out.println(testMatrix.toString());
        System.out.flush();
    }

    /**
     * Command-line entry point: expects exactly one argument (the properties file), then
     * constructs the driver, calls {@code setup()} and runs the cross-validation.
     *
     * @param strArr command-line arguments; exactly one is required
     * @throws Exception propagated from construction, setup or the run
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 1) {
            System.out.println("Usage: nfold_tune prop.txt");
            System.exit(-1);
        }
        nfold_tune driver = new nfold_tune(strArr[0]);
        driver.setup();
        driver.run();
    }
}
