package edu.wisc.sjm.machlearn.util;

import edu.wisc.mgr.auc.Confusion;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.vars.DoubleVar;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.classifiers.knn.WeightedKNN;
import edu.wisc.sjm.machlearn.classifiers.misc.JMaxTan;
import edu.wisc.sjm.machlearn.classifiers.naivebayes.NaiveBayesHybrid;
import edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability;
import edu.wisc.sjm.machlearn.classifiers.svm.SMOLinear;
import edu.wisc.sjm.machlearn.classifiers.svm.SMOP;
import edu.wisc.sjm.machlearn.classifiers.svm.SMOR;
import edu.wisc.sjm.machlearn.classifiers.trees.LJ48;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.featureselection.FSDataSet;
import java.util.Iterator;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/util/APRUtil.class */
/**
 * Static helpers for estimating a classifier's area under the precision-recall curve (APR)
 * by cross-validation, choosing the best classifier by APR, and building or writing
 * precision-recall / ROC curves from paired (score, label) vectors.
 *
 * <p><b>Not thread-safe.</b> The {@code internal_*} scratch objects are shared static state:
 * every cross-validated {@code getAPR} call empties and refills the score/label buffers,
 * {@link #getAPR()} reads whatever the most recent such call left there, and
 * {@link #getAPR(DoubleVector, IntVector)} and {@link #writePerformanceCurve} overwrite the
 * shared {@code Confusion}. Do not call into this class from several threads at once.
 *
 * <p>Label convention used throughout: a label (output value id) of {@code 0} is a negative,
 * any other value is a positive.
 */
public final class APRUtil {
    /** Scratch buffer: labels of every test example evaluated by the last getAPR run. */
    private static final IntVector internal_outputs = new IntVector();
    /** Scratch buffer: positive-class scores, position-aligned with {@code internal_outputs}. */
    private static final DoubleVector internal_probs = new DoubleVector();
    /** Scratch curve reused by {@link #getAPR(DoubleVector, IntVector)} and the curve writers. */
    private static final Confusion internal_confusion = new Confusion();

    /**
     * Minimum recall handed to {@code Confusion.calculateAUCPR}. The value is taken from WEKA's
     * {@code KStarConstants.FLOOR}, which looks like a decompiler substitution for a plain
     * literal (presumably {@code 0.0} -- confirm); it is kept as-is so results do not change.
     */
    private static final double MIN_RECALL = KStarConstants.FLOOR;

    /**
     * Cross-validates a fixed panel of classifiers ({@code NaiveBayesHybrid}, {@code JMaxTan},
     * {@code WeightedKNN}, {@code LJ48}, {@code SMOLinear}, {@code SMOP}, {@code SMOR}) on
     * {@code dataSet} and returns the one with the highest APR.
     *
     * @param dataSet   data to split into cross-validation folds
     * @param doubleVar receives the winning APR
     * @return the best classifier, or {@code null} if no candidate scored above -infinity
     * @throws Exception propagated from fold splitting or evaluation
     */
    public static Classifier getBestClassifier(DataSet dataSet, DoubleVar doubleVar) throws Exception {
        Vector<Classifier> candidates = new Vector<Classifier>();
        candidates.add(new NaiveBayesHybrid());
        candidates.add(new JMaxTan());
        candidates.add(new WeightedKNN());
        candidates.add(new LJ48());
        candidates.add(new SMOLinear());
        candidates.add(new SMOP());
        candidates.add(new SMOR());
        return getBestClassifier(dataSet, candidates, doubleVar);
    }

    /**
     * Same as {@link #getBestClassifier(DataSet, DoubleVar)} but discards the winning APR.
     *
     * @param dataSet data to split into cross-validation folds
     * @return the best classifier of the default panel, or {@code null}
     * @throws Exception propagated from fold splitting or evaluation
     */
    public static Classifier getBestClassifier(DataSet dataSet) throws Exception {
        return getBestClassifier(dataSet, new DoubleVar());
    }

    /**
     * Picks the candidate with the highest cross-validated APR, discarding the score.
     *
     * @param dataSet data to split into cross-validation folds
     * @param vector  candidate classifiers
     * @return the best candidate, or {@code null} if {@code vector} is empty
     * @throws Exception propagated from fold splitting or evaluation
     */
    public static Classifier getBestClassifier(DataSet dataSet, Vector<Classifier> vector) throws Exception {
        return getBestClassifier(dataSet, vector, new DoubleVar());
    }

    /**
     * Splits {@code dataSet} into 10 folds and picks the candidate with the highest APR.
     * The two boolean flags are passed straight to {@code DataSet.splitDataSetFolds}; their
     * meaning is defined there.
     *
     * @param dataSet   data to split into cross-validation folds
     * @param vector    candidate classifiers
     * @param doubleVar receives the winning APR
     * @return the best candidate, or {@code null} if {@code vector} is empty
     * @throws Exception propagated from fold splitting or evaluation
     */
    public static Classifier getBestClassifier(DataSet dataSet, Vector<Classifier> vector, DoubleVar doubleVar) throws Exception {
        return getBestClassifier(dataSet.splitDataSetFolds(10, true, true), vector, doubleVar);
    }

    /**
     * Evaluates every candidate on pre-split folds and returns the one with the highest APR.
     * Ties keep the earliest candidate. A candidate whose APR is NaN can never win.
     *
     * @param dataSetArr folds as {@code {train, test}} pairs (see {@link #getAPR(Classifier, DataSet[][])})
     * @param vector     candidate classifiers; each is retrained in place during evaluation
     * @param doubleVar  receives the winning APR ({@code -Infinity} if nothing was selected)
     * @return the best candidate, or {@code null} if none scored above -infinity
     * @throws Exception propagated from evaluation
     */
    public static Classifier getBestClassifier(DataSet[][] dataSetArr, Vector<Classifier> vector, DoubleVar doubleVar) throws Exception {
        doubleVar.value = Double.NEGATIVE_INFINITY;
        Classifier best = null;
        for (Classifier candidate : vector) {
            double apr = getAPR(candidate, dataSetArr);
            // Strict '>' keeps the first candidate on ties.
            if (apr > doubleVar.value) {
                best = candidate;
                doubleVar.value = apr;
            }
        }
        return best;
    }

    /**
     * Cross-validated APR of a probability model.
     *
     * <p>For each fold, trains {@code probability} on {@code dataSetArr[fold][0]} and scores every
     * example of {@code dataSetArr[fold][1]} with {@code getPositiveProb}. The pooled scores and
     * labels are left in the shared scratch buffers and evaluated by {@link #getAPR()}.
     *
     * @param probability model to train and evaluate (retrained once per fold)
     * @param dataSetArr  folds as {@code {train, test}} pairs; both must be {@link FeatureDataSet}s
     * @return APR over all pooled test predictions
     */
    public static double getAPR(Probability probability, DataSet[][] dataSetArr) {
        DoubleVector scores = internal_probs;
        IntVector labels = internal_outputs;
        scores.empty();
        labels.empty();
        for (DataSet[] fold : dataSetArr) {
            FeatureDataSet train = (FeatureDataSet) fold[0];
            FeatureDataSet test = (FeatureDataSet) fold[1];
            probability.train(train);
            for (int k = 0; k < test.size(); k++) {
                scores.add(probability.getPositiveProb(test.getExample(k)));
                labels.add(test.getOutputFeature(k).getValueId());
            }
        }
        return getAPR();
    }

    /**
     * Cross-validated APR of {@code classifier} after restricting every fold to the features
     * selected by {@code zArr} (via {@code FSDataSet.applyMask}).
     *
     * @param classifier classifier to evaluate
     * @param dataSetArr folds as {@code {train, test}} pairs of {@link FeatureDataSet}s; rows may
     *                   differ in length and the array may be empty
     * @param zArr       feature mask passed to {@code FSDataSet.applyMask}
     * @return APR over all pooled test predictions on the masked data
     * @throws Exception propagated from evaluation
     */
    public static double getAPR(Classifier classifier, DataSet[][] dataSetArr, boolean[] zArr) throws Exception {
        DataSet[][] masked = new DataSet[dataSetArr.length][];
        for (int fold = 0; fold < dataSetArr.length; fold++) {
            // Size each row individually: indexing dataSetArr[0] failed on an empty fold array.
            masked[fold] = new DataSet[dataSetArr[fold].length];
            for (int part = 0; part < dataSetArr[fold].length; part++) {
                masked[fold][part] = FSDataSet.applyMask((FeatureDataSet) dataSetArr[fold][part], zArr);
            }
        }
        return getAPR(classifier, masked);
    }

    /**
     * Cross-validated APR of a classifier.
     *
     * <p>For each fold, trains {@code classifier} on {@code dataSetArr[fold][0]} and scores every
     * example of {@code dataSetArr[fold][1]} with element {@code [1]} of its distribution. This
     * assumes a binary task whose positive class has value id 1. The pooled scores and labels
     * are left in the shared scratch buffers and evaluated by {@link #getAPR()}.
     *
     * <p>A training exception is handed to {@code MainClass._internalError}. What that does is
     * not visible here; if it returns, scoring proceeds with the classifier in whatever state
     * {@code train} left it.
     *
     * @param classifier classifier to evaluate (retrained once per fold)
     * @param dataSetArr folds as {@code {train, test}} pairs; both must be {@link FeatureDataSet}s
     * @return APR over all pooled test predictions
     * @throws Exception propagated from {@code getDistribution}
     */
    public static double getAPR(Classifier classifier, DataSet[][] dataSetArr) throws Exception {
        DoubleVector scores = internal_probs;
        IntVector labels = internal_outputs;
        scores.empty();
        labels.empty();
        for (DataSet[] fold : dataSetArr) {
            FeatureDataSet train = (FeatureDataSet) fold[0];
            FeatureDataSet test = (FeatureDataSet) fold[1];
            try {
                classifier.train(train);
            } catch (Exception e) {
                MainClass._internalError(e);
            }
            for (int k = 0; k < test.size(); k++) {
                scores.add(classifier.getDistribution(test.getExample(k))[1]);
                labels.add(test.getOutputFeature(k).getValueId());
            }
        }
        return getAPR();
    }

    /**
     * APR of the scores/labels currently held in the shared scratch buffers, i.e. those left by
     * the most recent cross-validated {@code getAPR} call.
     *
     * @return area under the precision-recall curve
     */
    public static double getAPR() {
        return getAPR(internal_probs, internal_outputs);
    }

    /**
     * Builds an interpolated curve from paired scores and labels.
     *
     * @param confusion   object to reuse (emptied first), or {@code null} to allocate a new one
     * @param doubleVector scores; higher means more likely positive
     * @param intVector   labels aligned by position with {@code doubleVector} (0 = negative)
     * @return the populated curve ({@code confusion} itself when it was non-null)
     */
    public static Confusion getConfusion(Confusion confusion, DoubleVector doubleVector, IntVector intVector) {
        // The 4-arg overload allocates when confusion is null.
        return getConfusion(confusion, doubleVector, intVector, true);
    }

    /**
     * Builds a fresh interpolated curve from paired scores and labels.
     *
     * @param doubleVector scores; higher means more likely positive
     * @param intVector   labels aligned by position with {@code doubleVector} (0 = negative)
     * @return a newly allocated, populated curve
     */
    public static Confusion getConfusion(DoubleVector doubleVector, IntVector intVector) {
        return getConfusion(null, doubleVector, intVector, true);
    }

    /**
     * Builds a curve from paired scores and labels.
     *
     * <p>Points are added for two kinds of threshold:
     * <ul>
     *   <li>the all-positive point {@code (totalPos, totalNeg)};</li>
     *   <li>one point per example, using that example's score as the threshold. Each such point
     *       counts the positives and negatives whose score is {@code >=} the threshold.</li>
     * </ul>
     * Neither input vector is modified.
     *
     * @param confusion   object to reuse (emptied first), or {@code null} to allocate a new one
     * @param doubleVector scores; higher means more likely positive
     * @param intVector   labels aligned by position with {@code doubleVector} (0 = negative)
     * @param z           whether to call {@code interpolate()} after {@code doneAdding()}
     * @return the populated curve ({@code confusion} itself when it was non-null)
     * @throws IllegalArgumentException if the two vectors differ in size
     */
    public static Confusion getConfusion(Confusion confusion, DoubleVector doubleVector, IntVector intVector, boolean z) {
        int n = doubleVector.size();
        if (intVector.size() != n) {
            throw new IllegalArgumentException(
                    "score/label size mismatch: " + n + " scores vs " + intVector.size() + " labels");
        }
        Confusion curve = confusion;
        if (curve == null) {
            curve = new Confusion();
        } else {
            curve.empty();
        }

        int totalPos = 0;
        int totalNeg = 0;
        for (int k = 0; k < n; k++) {
            if (intVector.get(k) == 0) {
                totalNeg++;
            } else {
                totalPos++;
            }
        }
        curve.setTotPosNeg(totalPos, totalNeg);
        // Threshold below every score: everything is predicted positive.
        curve.addPoint(totalPos, totalNeg);

        // No sort is needed: the counts are order-independent. Indexing scores and labels by the
        // same position keeps every score paired with its own label (an earlier sort-based
        // version read scores after sorting but labels without the permutation).
        for (int t = 0; t < n; t++) {
            double threshold = doubleVector.get(t);
            int truePos = 0;
            int falsePos = 0;
            for (int k = 0; k < n; k++) {
                if (doubleVector.get(k) >= threshold) {
                    if (intVector.get(k) == 0) {
                        falsePos++;
                    } else {
                        truePos++;
                    }
                }
            }
            curve.addPoint(truePos, falsePos);
        }
        curve.doneAdding();
        if (z) {
            curve.interpolate();
        }
        return curve;
    }

    /**
     * APR of the given scores/labels. Overwrites the shared scratch {@code Confusion}.
     *
     * @param doubleVector scores; higher means more likely positive
     * @param intVector   labels aligned by position with {@code doubleVector} (0 = negative)
     * @return area under the interpolated precision-recall curve
     */
    public static double getAPR(DoubleVector doubleVector, IntVector intVector) {
        return getConfusion(internal_confusion, doubleVector, intVector).calculateAUCPR(MIN_RECALL);
    }

    /**
     * Builds a curve (in the shared scratch {@code Confusion}) and writes it to a file.
     *
     * @param doubleVector scores; higher means more likely positive
     * @param intVector   labels aligned by position with {@code doubleVector} (0 = negative)
     * @param str         output file path
     * @param z           {@code true} writes a PR curve, {@code false} a ROC curve
     * @param z2          whether to interpolate the curve before writing
     */
    public static void writePerformanceCurve(DoubleVector doubleVector, IntVector intVector, String str, boolean z, boolean z2) {
        Confusion confusion = getConfusion(internal_confusion, doubleVector, intVector, z2);
        // The calculateAUC* results are discarded; presumably these calls are made for their
        // side effect (the boolean flag) before writing -- confirm against Confusion.
        if (z) {
            confusion.calculateAUCPR(MIN_RECALL, true);
            confusion.writePRFile(str);
        } else {
            confusion.calculateAUCROC(true);
            confusion.writeROCFile(str);
        }
    }

    /**
     * Writes an interpolated precision-recall curve to {@code str}.
     *
     * @param doubleVector scores
     * @param intVector   labels aligned with {@code doubleVector} (0 = negative)
     * @param str         output file path
     */
    public static void writePRCurve(DoubleVector doubleVector, IntVector intVector, String str) {
        writePerformanceCurve(doubleVector, intVector, str, true, true);
    }

    /**
     * Writes an interpolated ROC curve to {@code str}.
     *
     * @param doubleVector scores
     * @param intVector   labels aligned with {@code doubleVector} (0 = negative)
     * @param str         output file path
     */
    public static void writeROCCurve(DoubleVector doubleVector, IntVector intVector, String str) {
        writePerformanceCurve(doubleVector, intVector, str, false, true);
    }

    /**
     * Writes a ROC curve to {@code str}, optionally interpolated.
     *
     * @param doubleVector scores
     * @param intVector   labels aligned with {@code doubleVector} (0 = negative)
     * @param str         output file path
     * @param z           whether to interpolate the curve before writing
     */
    public static void writeROCCurve(DoubleVector doubleVector, IntVector intVector, String str, boolean z) {
        writePerformanceCurve(doubleVector, intVector, str, false, z);
    }
}
