package edu.wisc.sjm.machlearn.apps;

import edu.wisc.sjm.jutil.io.UtilPrintStream;
import edu.wisc.sjm.jutil.io.locking.RemoteLock;
import edu.wisc.sjm.jutil.io.locking.RemoteLockClient;
import edu.wisc.sjm.jutil.io.locking.RemoteLockConnectException;
import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.misc.PropertiesUtil;
import edu.wisc.sjm.machlearn.confusion.ConfusionMatrix;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.xydataset.XYDataSet;
import edu.wisc.sjm.machlearn.policy.PolicyClassifier;
import edu.wisc.sjm.machlearn.policy.xypreprocessor.misc.RandomizeClasses;
import edu.wisc.sjm.machlearn.util.Util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Vector;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/apps/nfoldPermuteLock.class */
/**
 * Permutation-test driver for n-fold cross-validation that lets several processes cooperate
 * through one shared output file.
 *
 * <p>The classifier is evaluated first on the original data set and then on repeated
 * class-randomized copies of it (produced by {@link RandomizeClasses}). Each evaluation writes
 * one tab-separated row to the output file. Every read-modify-write of that file is bracketed
 * by a lock obtained from a {@link RemoteLockClient}. If the lock server cannot be reached,
 * the file is accessed without locking and a warning is logged.
 *
 * <p>Output file layout:
 * <ol>
 *   <li>the column header ({@code PAccuracy ... VFPR});</li>
 *   <li>the row for the original data, or the placeholder {@code Calc} while the process that
 *       created the file is still computing it;</li>
 *   <li>one row per completed permutation.</li>
 * </ol>
 */
public class nfoldPermuteLock extends MainClass {
    /** Data set path, from property {@code edu.wisc.sjm.machlearn.nfold.datapath}. */
    String datapath;
    /** Number of cross-validation folds (property {@code ...nfold.folds}, default 10). */
    int folds;
    /** Subsample count for the classifier and randomizer (property {@code ...nfold.subsamples}, default 1). */
    int subsamples;
    /** Data set evaluated by the current run: the original data or a class-randomized copy. */
    XYDataSet data;
    /** Preprocessed original data set, loaded by {@link #setup()}. */
    XYDataSet orig_data;
    /** Classifier under evaluation. */
    PolicyClassifier policy;
    /** Optional file that stdout and stderr are redirected to. */
    String logpath;
    /** Per-fold training ({@code [i][0]}) and test ({@code [i][1]}) sets of the current run. */
    XYDataSet[][] foldsets;
    /** Target number of permutation rows in the output file, across all cooperating processes. */
    int permutations;
    /** Path of the shared output file. */
    String output_file;
    /** Host name of the lock server. */
    String lock_server_name;
    /** Port of the lock server. */
    int lock_server_port;
    /** {@link File} for {@link #output_file}; also the object that is locked. */
    File file_obj;
    /** Connection to the lock server while a lock is held, otherwise {@code null}. */
    RemoteLockClient lock_client;
    /** Writer for the output file while it is open, otherwise {@code null}. */
    PrintWriter pw;
    /** Currently held file lock, or {@code null}. */
    RemoteLock file_lock;

    /**
     * Creates the driver and reads its configuration from the loaded properties.
     *
     * @param str path of a properties file to load, or {@code null} to use the properties already loaded
     * @param str2 lock server address in the form {@code host:port}
     * @param str3 path of the shared output file
     * @param i target number of permutation rows in the output file
     * @throws IllegalArgumentException if {@code str2} is not of the form {@code host:port}
     * @throws Exception if the properties or the log file cannot be loaded or opened
     */
    public nfoldPermuteLock(String str, String str2, String str3, int i) throws Exception {
        if (str != null) {
            PropertiesUtil.load(str);
        }
        String[] hostAndPort = str2.split(":");
        if (hostAndPort.length != 2) {
            throw new IllegalArgumentException("Lock server must be given as host:port, got: " + str2);
        }
        this.lock_server_name = hostAndPort[0];
        this.lock_server_port = Integer.parseInt(hostAndPort[1]);
        this.permutations = i;
        this.datapath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.nfold.datapath", null);
        setLogPath();
        this.folds = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold.folds", 10);
        this.subsamples = PropertiesUtil.getInt("edu.wisc.sjm.machlearn.nfold.subsamples", 1);
        this.policy = new PolicyClassifier();
        this.policy.setSubSamples(this.subsamples);
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
        this.output_file = str3;
        this.file_obj = new File(str3);
    }

    /** Creates a driver using the default lock server, the default output file and 1 permutation. */
    public nfoldPermuteLock() throws Exception {
        this(null, "nova-1.cs.wisc.edu:4100", "permute.out", 1);
    }

    /**
     * Reads property {@code edu.wisc.sjm.machlearn.nfold.logpath}. When it is set, redirects both
     * {@code System.out} and {@code System.err} to that file.
     */
    public void setLogPath() throws Exception {
        this.logpath = PropertiesUtil.getString("edu.wisc.sjm.machlearn.nfold.logpath", null);
        if (this.logpath != null) {
            // Intentionally left open: it becomes the process-wide stdout/stderr.
            UtilPrintStream utilPrintStream = new UtilPrintStream(new FileOutputStream(this.logpath));
            System.setOut(utilPrintStream);
            System.setErr(utilPrintStream);
        }
    }

    /** Prints all {@code edu.wisc.sjm} properties to stdout. */
    public void printParameters() {
        PropertiesUtil.save(System.out, "edu.wisc.sjm");
    }

    /** Loads the data set from {@link #datapath} and runs the classifier's preprocessing on it. */
    public void setup() throws Exception {
        this.orig_data = XYDataSet.loadXYDataSet(this.datapath);
        this.orig_data = this.policy.preprocess(this.orig_data);
    }

    /**
     * Acquires the remote lock on the output file.
     *
     * <p>If the lock server cannot be reached, a warning is logged and both lock fields are left
     * {@code null}, so callers proceed without locking. On any other failure the freshly opened
     * client is closed and the exception is rethrown.
     */
    private void getLock() throws Exception {
        RemoteLockClient client = null;
        try {
            client = new RemoteLockClient(this.lock_server_name, this.lock_server_port);
            this.file_lock = client.lock(this.file_obj);
            this.lock_client = client;
        } catch (RemoteLockConnectException e) {
            log("WARNING: Unable to connect to server");
            log(e.toString());
            log("  Accessing file anyway");
            this.file_lock = null;
            this.lock_client = null;
        } catch (Exception e) {
            // Do not leak the connection when locking fails for another reason.
            if (client != null) {
                try {
                    client.close();
                } catch (Exception closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw e;
        }
    }

    /** Releases the file lock (if held) and closes the lock-server connection (if open). */
    private void releaseLock() throws Exception {
        try {
            if (this.file_lock != null) {
                this.file_lock.release();
            }
        } finally {
            // Always drop the connection, even if release() failed.
            this.file_lock = null;
            if (this.lock_client != null) {
                RemoteLockClient client = this.lock_client;
                this.lock_client = null;
                client.close();
            }
        }
    }

    /**
     * Closes {@link #pw} (if open) and reports any write error that {@link PrintWriter} swallowed.
     *
     * @throws IOException if any write to, or the close of, the output file failed
     */
    private void closeWriter() throws IOException {
        if (this.pw == null) {
            return;
        }
        PrintWriter writer = this.pw;
        this.pw = null;
        writer.close();
        if (writer.checkError()) {
            throw new IOException("Error writing output file " + this.output_file);
        }
    }

    /**
     * Reads every line of the output file after the header.
     *
     * @param z {@code true} if the caller already holds the lock; otherwise the lock is acquired
     *     and released around the read
     * @return the non-header lines, in file order
     */
    public Vector<String> getData(boolean z) throws Exception {
        if (!z) {
            getLock();
        }
        try {
            Vector<String> lines = new Vector<>();
            try (BufferedReader reader = new BufferedReader(new FileReader(this.file_obj))) {
                reader.readLine(); // skip the header
                for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                    lines.add(line);
                }
            }
            return lines;
        } finally {
            if (!z) {
                releaseLock();
            }
        }
    }

    /**
     * Counts the permutation rows already in the output file.
     *
     * <p>The result is the number of lines after the header, minus one for the original-data row
     * (or its {@code Calc} placeholder). It is -1 when the file has no line after the header.
     *
     * @param z {@code true} if the caller already holds the lock; otherwise the lock is acquired
     *     and released around the read
     * @return the number of completed permutations, or -1 if even the original row is missing
     */
    public int countPermutations(boolean z) throws Exception {
        if (!z) {
            getLock();
        }
        try {
            int count = -1;
            try (BufferedReader reader = new BufferedReader(new FileReader(this.file_obj))) {
                reader.readLine(); // skip the header
                for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                    count++;
                }
            }
            System.out.println("Number of permutations already done:" + count);
            return count;
        } finally {
            if (!z) {
                releaseLock();
            }
        }
    }

    /** Same as {@link #countPermutations(boolean) countPermutations(false)}: acquires its own lock. */
    public int countPermutations() throws Exception {
        return countPermutations(false);
    }

    /** Returns the tab-separated column header of the output file. */
    private String getHeader() {
        return "PAccuracy\tAccuracy\tRecall\tPrecision\tFScore\tFPR\tVAccuracy\tVRecall\tVPrecision\tVFScore\tVFPR";
    }

    /** Creates the output file with the header and the {@code Calc} placeholder row. */
    private void createOutputFile() throws Exception {
        this.pw = new PrintWriter(new FileWriter(this.file_obj));
        try {
            this.pw.println(getHeader());
            this.pw.println("Calc");
        } finally {
            closeWriter();
        }
    }

    /**
     * Runs permutations until the shared output file contains {@link #permutations} permutation rows.
     *
     * <p>The process that creates the output file also computes the original-data row. Processes
     * that find the file already present only add permutations.
     */
    public void run() throws Exception {
        log("Number of permutations requested:" + this.permutations);
        RandomizeClasses randomizeClasses = new RandomizeClasses();
        randomizeClasses.setSubSamples(this.subsamples);
        int done;
        // Check and create under one lock. Otherwise two processes starting together could both
        // see "no file" and the second FileWriter would truncate the first one's results.
        getLock();
        try {
            if (this.file_obj.exists()) {
                done = countPermutations(true);
            } else {
                createOutputFile();
                done = -1;
            }
        } finally {
            releaseLock();
        }
        if (done == -1) {
            this.data = this.orig_data;
            System.out.println("Calculating original data set accuracy");
            runPermutation(true);
            done++;
        }
        while (done < this.permutations) {
            this.data = randomizeClasses.process(this.orig_data);
            runPermutation(false);
            // Re-read the file: other processes may have appended rows in the meantime.
            done = countPermutations();
            System.out.println("Number of permutations done:" + done);
        }
    }

    /**
     * Cross-validates {@link #policy} on {@link #data} and records the result in the output file.
     *
     * @param z {@code true} when {@link #data} is the original data set. The result then replaces
     *     the second line of the file (the {@code Calc} placeholder), and {@code !z} is passed as
     *     the third argument of {@code splitDataSetFolds}. Otherwise the result row is appended.
     */
    public void runPermutation(boolean z) throws Exception {
        System.out.println("Processed set:");
        System.out.println(Util.printArray(this.data.getExampleNames()));
        System.out.println("Splitting into folds");
        DataSet[][] splits = this.data.splitDataSetFolds(this.folds, this.subsamples, !z, true);
        this.foldsets = new XYDataSet[this.folds][2];
        for (int fold = 0; fold < this.folds; fold++) {
            this.foldsets[fold][0] = (XYDataSet) splits[fold][0];
            this.foldsets[fold][1] = (XYDataSet) splits[fold][1];
        }
        ConfusionMatrix trainMatrix = new ConfusionMatrix(this.foldsets[0][0].getOutputFeatureId());
        ConfusionMatrix testMatrix = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        ConfusionMatrix voteMatrix = new ConfusionMatrix(this.foldsets[0][1].getOutputFeatureId());
        int totalFeatures = 0;
        System.out.println("Starting run");
        for (int fold = 0; fold < this.folds; fold++) {
            System.out.print("Running fold #" + fold + "\n");
            System.out.println("Training...\n");
            this.policy.train(this.foldsets[fold][0]);
            int numFeatures = this.policy.getFDSDataSet(this.foldsets[fold][1]).numFeatures();
            totalFeatures += numFeatures;
            trainMatrix.add(this.policy, this.foldsets[fold][0]);
            testMatrix.add(this.policy, this.foldsets[fold][1]);
            voteMatrix.add(this.policy.classifyBySubsample(this.foldsets[fold][1]), this.policy.getSubsampleClassifications(this.foldsets[fold][1]));
            System.out.print("nf:" + numFeatures);
            System.out.print(" train:");
            System.out.print(DoubleUtil.printDecimal(trainMatrix.getAccuracy(fold), 3));
            System.out.print(" test:");
            System.out.println(DoubleUtil.printDecimal(testMatrix.getAccuracy(fold), 3));
            System.out.println(" Vote test:" + this.policy.getSubsampleThreshold());
            System.out.println(DoubleUtil.printDecimal(voteMatrix.getAccuracy(fold), 3));
        }
        printRunSummary(trainMatrix, testMatrix, voteMatrix, totalFeatures);
        writeResultRow(formatResultRow(percentUnchangedLabels(), testMatrix, voteMatrix), z);
    }

    /** Prints accuracies, confusion matrices and statistics of a completed cross-validation run. */
    private void printRunSummary(ConfusionMatrix trainMatrix, ConfusionMatrix testMatrix,
            ConfusionMatrix voteMatrix, int totalFeatures) {
        double testAccuracy = testMatrix.getAccuracy();
        double trainAccuracy = trainMatrix.getAccuracy();
        // Floating-point average; integer division used to truncate it.
        System.out.println("Average # of features:" + ((double) totalFeatures / this.folds));
        System.out.println("Train Accuracy is:" + trainAccuracy);
        System.out.println("Test Accuracy is:" + testAccuracy);
        System.out.println("Vote Test Accuracy is:" + voteMatrix.getAccuracy());
        System.out.println("Train Confusion");
        System.out.println("===============");
        System.out.println(trainMatrix.toString());
        System.out.println("Test Confusion");
        System.out.println("==============");
        System.out.println(testMatrix.toString());
        System.out.println("VTest Confusion");
        System.out.println("===============");
        System.out.println(voteMatrix.toString());
        System.out.println("Train Stats");
        System.out.println("===============");
        System.out.println(trainMatrix.getStats());
        System.out.println("Test Stats");
        System.out.println("==============");
        System.out.println(testMatrix.getStats());
        System.out.println("VTest Stats");
        System.out.println("===============");
        System.out.println(voteMatrix.getStats());
        System.out.flush();
    }

    /**
     * Returns the percentage (0-100) of examples whose output value in {@link #data} equals the
     * one in {@link #orig_data}, comparing examples at the same index.
     */
    private double percentUnchangedLabels() {
        int n = this.orig_data.size();
        int unchanged = 0;
        for (int k = 0; k < n; k++) {
            if (this.orig_data.getOutputValueId(k) == this.data.getOutputValueId(k)) {
                unchanged++;
            }
        }
        // Floating-point division: integer division made every value below 100% collapse to 0.
        return 100.0d * unchanged / n;
    }

    /** Builds one output row whose columns match {@link #getHeader()}. */
    private String formatResultRow(double pAccuracy, ConfusionMatrix testMatrix, ConfusionMatrix voteMatrix) {
        return String.valueOf(pAccuracy)
                + "\t" + testMatrix.getAccuracy() + "\t" + testMatrix.getRecall() + "\t" + testMatrix.getPrecision()
                + "\t" + testMatrix.getFScore() + "\t" + testMatrix.getFPR()
                + "\t" + voteMatrix.getAccuracy() + "\t" + voteMatrix.getRecall() + "\t" + voteMatrix.getPrecision()
                + "\t" + voteMatrix.getFScore() + "\t" + voteMatrix.getFPR();
    }

    /**
     * Writes a result row to the output file while holding the lock.
     *
     * @param row the formatted result row
     * @param replaceOriginalRow {@code true} to rewrite the file with {@code row} as its second line
     *     (replacing the {@code Calc} placeholder or a previous original row); {@code false} to append
     */
    private void writeResultRow(String row, boolean replaceOriginalRow) throws Exception {
        getLock();
        try {
            if (replaceOriginalRow) {
                Vector<String> existing = getData(true);
                this.pw = new PrintWriter(new FileWriter(this.file_obj, false));
                this.pw.println(getHeader());
                this.pw.println(row);
                // existing.get(0) is the line being replaced; keep every permutation row after it.
                for (int k = 1; k < existing.size(); k++) {
                    this.pw.println(existing.get(k));
                }
            } else {
                this.pw = new PrintWriter(new FileWriter(this.file_obj, true));
                this.pw.println(row);
            }
            this.pw.flush();
        } finally {
            try {
                closeWriter();
            } finally {
                releaseLock();
            }
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code nfoldPermuteLock prop.txt server:port permutations [output_file]}. The output
     * file defaults to {@code permute.out}.
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 3) {
            System.out.println("Usage: nfoldPermuteLock prop.txt server:port permutations <output_file>");
            System.exit(-1);
        }
        nfoldPermuteLock nfoldpermutelock = new nfoldPermuteLock(strArr[0], strArr[1], strArr.length >= 4 ? strArr[3] : "permute.out", Integer.parseInt(strArr[2]));
        nfoldpermutelock.setup();
        PropertiesUtil.save("tmp.prop");
        nfoldpermutelock.run();
    }
}
