package edu.wisc.sjm.machlearn.classifiers;

import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.xml.XMLUtil;
import edu.wisc.sjm.machlearn.MachLearnConstants;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.featurediscretize.FeatureDiscretize;
import edu.wisc.sjm.machlearn.featureselection.FSDataSet;
import edu.wisc.sjm.machlearn.featureselection.FeatureSelect;
import edu.wisc.sjm.machlearn.util.Util;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Vector;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.xml.XMLSerialization;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/Classifier.class */
public abstract class Classifier extends Scorer implements MachLearnConstants {

    /** XML element name written by {@link #createClassifierXMLNode} and looked up by {@link #createFromXML}. */
    private static final String CLASSIFIER_TAG = "Classifier";

    /** Tolerance used by the alpha-score methods to decide whether alpha is (effectively) 0 or 1. */
    private static final double ALPHA_EPSILON = 1.0E-4d;

    /** Optional discretizer applied before training and classification; {@code null} means none. */
    FeatureDiscretize feature_discretize;

    /** Optional feature selector run during training; {@code null} means no feature selection. */
    FeatureSelect feature_select;

    /** Feature mask computed by {@link #feature_select} during the most recent {@link #process_train}. */
    boolean[] feature_mask;

    /**
     * Former shared scratch buffer for {@link #getPositiveProb}. This class no longer uses it:
     * a static buffer is not thread-safe and assumed exactly two output values. It is retained
     * only because other classes in this package may still reference it.
     */
    static double[] temp_prob = new double[2];

    /**
     * Trains the underlying model. {@link #train} calls this with data that has already been
     * passed through the configured discretizer and feature mask.
     *
     * @param featureDataSet preprocessed training data
     * @throws Exception if training fails
     */
    public abstract void train_(FeatureDataSet featureDataSet) throws Exception;

    /**
     * Classifies a single example. {@link #classify(Example)} calls this with an example that has
     * already been passed through the configured discretizer and feature mask.
     *
     * @param example preprocessed example
     * @return the predicted output feature
     * @throws Exception if classification fails
     */
    public abstract Feature classify_(Example example) throws Exception;

    /** Returns a textual representation of this classifier (format defined by subclasses). */
    public abstract String printClassifier();

    /** Returns a copy of this classifier (copy depth defined by subclasses). */
    public abstract Classifier cloneClassifier();

    /**
     * Sets a classifier parameter by index. {@link #doTune} uses index 0 with a {@link Double} value.
     *
     * @param i   parameter index
     * @param obj parameter value
     */
    public abstract void setParameter(int i, Object obj);

    /** Sets the discretizer used during training and classification ({@code null} disables it). */
    public void setFeatureDiscretize(FeatureDiscretize featureDiscretize) {
        this.feature_discretize = featureDiscretize;
    }

    /**
     * Instantiates a discretizer reflectively from its fully-qualified class name via its no-arg
     * constructor. Failures are reported through {@code internalError}.
     */
    public void setFeatureDiscretize(String str) {
        try {
            this.feature_discretize = (FeatureDiscretize) Class.forName(str).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            internalError(e);
        }
    }

    /** Sets the feature selector run during training ({@code null} disables it). */
    public void setFeatureSelect(FeatureSelect featureSelect) {
        System.out.println("Setting feature select");
        this.feature_select = featureSelect;
    }

    /**
     * Instantiates a feature selector reflectively from its fully-qualified class name via its
     * no-arg constructor. Failures are reported through {@code internalError}.
     */
    public void setFeatureSelect(String str) {
        try {
            this.feature_select = (FeatureSelect) Class.forName(str).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            internalError(e);
        }
    }

    /**
     * Appends a {@code <Classifier>} element, tagged with this object's runtime class name, to
     * {@code element}.
     *
     * @return the newly created element
     */
    public Element createClassifierXMLNode(Document document, Element element) {
        Element node = document.createElement(CLASSIFIER_TAG);
        node.setAttribute(XMLSerialization.ATT_CLASS, getClass().getName());
        element.appendChild(node);
        return node;
    }

    /** Serializes this classifier under {@code element}; the base version writes only the class node. */
    public void toXML(Document document, Element element) {
        createClassifierXMLNode(document, element);
    }

    /**
     * Recreates a classifier from XML. {@code element} may be the {@code <Classifier>} node itself
     * or its parent. Failures are reported via {@code MainClass._internalError}; in that case the
     * return value may be {@code null}.
     */
    public static Classifier createFromXML(Element element) {
        Element node = element;
        // Compare by value: DOM tag names are not guaranteed to be interned, so '!=' was unreliable.
        if (!CLASSIFIER_TAG.equals(node.getTagName())) {
            node = XMLUtil.getChildElement(element, CLASSIFIER_TAG);
        }
        Classifier classifier = null;
        try {
            classifier = (Classifier) Class.forName(node.getAttribute(XMLSerialization.ATT_CLASS))
                    .getDeclaredConstructor().newInstance();
            classifier.fromXML(node);
        } catch (Exception e) {
            MainClass._internalError(e);
        }
        return classifier;
    }

    /** Hook for subclasses to restore state from XML; the base implementation does nothing. */
    public void fromXML(Element element) {
    }

    /**
     * Splits {@code dataSet} via {@code splitDataSetValidation(splitParam)} and delegates to
     * {@link #getClassifierScores(Classifier, FeatureDataSet[][], boolean, boolean)}.
     */
    public static double[] getClassifierScores(Classifier classifier, DataSet dataSet, int splitParam,
            boolean scoreTrain, boolean scoreTest) throws Exception {
        return getClassifierScores(classifier, (FeatureDataSet[][]) dataSet.splitDataSetValidation(splitParam),
                scoreTrain, scoreTest);
    }

    /**
     * Splits {@code dataSet} via {@code splitDataSetValidation(splitParam)} and delegates to the
     * tuning variant
     * {@link #getClassifierScores(Classifier, DataSet[][], Object[], int, boolean, boolean)}.
     */
    public static double[] getClassifierScores(Classifier classifier, DataSet dataSet, Object[] tuneValues,
            int splitParam, int tuneFolds, boolean scoreTrain, boolean scoreTest) throws Exception {
        return getClassifierScores(classifier, (FeatureDataSet[][]) dataSet.splitDataSetValidation(splitParam),
                tuneValues, tuneFolds, scoreTrain, scoreTest);
    }

    /**
     * Trains on {@code splits[k][0]} for every row {@code k} and averages the accuracies.
     *
     * @param splits     rows of {@code {train, test}} data sets
     * @param scoreTrain whether to accumulate accuracy on the training part
     * @param scoreTest  whether to accumulate accuracy on the test part
     * @return {@code {meanTrainAccuracy, meanTestAccuracy}} in percent; a disabled side stays 0
     */
    public static double[] getClassifierScores(Classifier classifier, FeatureDataSet[][] splits,
            boolean scoreTrain, boolean scoreTest) throws Exception {
        double[] totals = {0.0d, 0.0d};
        for (FeatureDataSet[] split : splits) {
            classifier.train(split[0]);
            if (scoreTrain) {
                totals[0] += classifier.getAccuracy(split[0]);
            }
            if (scoreTest) {
                totals[1] += classifier.getAccuracy(split[1]);
            }
        }
        totals[0] /= splits.length;
        totals[1] /= splits.length;
        return totals;
    }

    /** Classifies every example of {@code featureDataSet}, in order. */
    public Feature[] classify(FeatureDataSet featureDataSet) throws Exception {
        Feature[] predictions = new Feature[featureDataSet.size()];
        for (int k = 0; k < predictions.length; k++) {
            predictions[k] = classify(featureDataSet.getExample(k));
        }
        return predictions;
    }

    /** Returns one class distribution per example of {@code featureDataSet}, in order. */
    public double[][] getDistribution(FeatureDataSet featureDataSet) throws Exception {
        double[][] distributions = new double[featureDataSet.size()][];
        for (int k = 0; k < distributions.length; k++) {
            distributions[k] = getDistribution(featureDataSet.getExample(k));
        }
        return distributions;
    }

    /**
     * Allocates an array sized to the number of output values and fills it via
     * {@link #getDistribution(Example, double[])}.
     */
    public double[] getDistribution(Example example) throws Exception {
        double[] distribution = new double[example.getOutputFeatureId().numValues()];
        getDistribution(example, distribution);
        return distribution;
    }

    /**
     * Returns the distribution entry at index 1 (treated as the positive class). A fresh, correctly
     * sized array is used per call instead of a shared static buffer, so this is safe for
     * concurrent callers and for outputs with more than two values.
     */
    public double getPositiveProb(Example example) throws Exception {
        return getDistribution(example)[1];
    }

    /**
     * Fits the preprocessing pipeline on {@code featureDataSet} and returns the transformed data.
     * <ol>
     *   <li>Trains the discretizer (if any) and discretizes the data.</li>
     *   <li>Runs feature selection (if any) and stores the resulting mask in {@link #feature_mask}.</li>
     * </ol>
     * NOTE(review): feature selection evaluates a freshly constructed instance of this class
     * (no-arg constructor), so parameters set on this instance are not carried over.
     */
    protected FeatureDataSet process_train(FeatureDataSet featureDataSet) throws Exception {
        FeatureDataSet data = featureDataSet;
        if (this.feature_discretize != null) {
            this.feature_discretize.train(data);
            data = this.feature_discretize.discretize(data);
        }
        if (this.feature_select != null) {
            System.out.println("doing feature selection");
            Classifier evaluator = (Classifier) getClass().getDeclaredConstructor().newInstance();
            this.feature_mask = this.feature_select.doFeatureSelection(data, evaluator);
            data = FSDataSet.applyMask(data, this.feature_mask);
        }
        return data;
    }

    /**
     * Applies the already-fitted discretizer and the stored feature mask to a data set without
     * refitting anything.
     */
    public FeatureDataSet process(FeatureDataSet featureDataSet) throws Exception {
        FeatureDataSet data = featureDataSet;
        if (this.feature_discretize != null) {
            data = this.feature_discretize.discretize(data);
        }
        if (this.feature_select != null) {
            System.out.println("applying feature mask");
            data = FSDataSet.applyMask(data, this.feature_mask);
        }
        return data;
    }

    /**
     * Applies the already-fitted discretizer and the stored feature mask to a single example.
     * If a selector is configured but no mask has been computed yet, a warning is printed and
     * the mask step is skipped.
     */
    public Example process(Example example) throws Exception {
        Example result = example;
        if (this.feature_discretize != null) {
            result = this.feature_discretize.discretize(result);
        }
        if (this.feature_select != null) {
            if (this.feature_mask != null) {
                result = FSDataSet.applyMask(result, this.feature_mask);
            } else {
                System.out.println("Null mask????");
            }
        }
        return result;
    }

    /**
     * Default distribution: one-hot, with 1.0 at the predicted value id and 0.0 elsewhere.
     * Subclasses that can produce graded scores may override this.
     *
     * @param dArr output array, overwritten in place; must have an entry for the predicted id
     */
    public void getDistribution(Example example, double[] dArr) throws Exception {
        int predicted = classify(example).getValueId();
        for (int k = 0; k < dArr.length; k++) {
            dArr[k] = 0.0d;
        }
        dArr[predicted] = 1.0d;
    }

    /**
     * Predicts by taking the argmax of {@link #getDistribution(Example)}. Errors are reported via
     * {@code internalError}; in that case the cloned output feature is returned without its value
     * being set.
     */
    public Feature classifyByDistribution(Example example) throws Exception {
        Feature feature = (Feature) example.getOutputFeature().clone();
        try {
            feature.setValue(Util.argmax(getDistribution(example)));
        } catch (Exception e) {
            internalError(e);
        }
        return feature;
    }

    /**
     * Returns {@code (alpha * meanTrainAcc + (1 - alpha) * meanTestAcc) / 100}. Whichever side has
     * zero weight (alpha ~ 0 or alpha ~ 1) is not evaluated.
     */
    public static double getAlphaScore(Classifier classifier, FeatureDataSet[][] featureDataSetArr, double d)
            throws Exception {
        boolean[] flags = alphaScoreFlags(d);
        return combineAlphaScore(d, getClassifierScores(classifier, featureDataSetArr, flags[0], flags[1]));
    }

    /**
     * Tuning variant of {@link #getAlphaScore(Classifier, FeatureDataSet[][], double)}. Each
     * training split is fitted with {@link #doTune} instead of {@link #train}.
     */
    public static double getAlphaScore(Classifier classifier, DataSet[][] dataSetArr, Object[] objArr, int i,
            double d) throws Exception {
        boolean[] flags = alphaScoreFlags(d);
        return combineAlphaScore(d, getClassifierScores(classifier, dataSetArr, objArr, i, flags[0], flags[1]));
    }

    /** Returns {scoreTrain, scoreTest}: skips the side that alpha gives (near) zero weight. */
    private static boolean[] alphaScoreFlags(double alpha) {
        if (DoubleUtil.eq(0.0d, alpha, ALPHA_EPSILON)) {
            return new boolean[] {false, true};
        }
        if (DoubleUtil.eq(1.0d, alpha, ALPHA_EPSILON)) {
            return new boolean[] {true, false};
        }
        return new boolean[] {true, true};
    }

    /** Blends percent accuracies {train, test} with weight alpha and rescales to a fraction. */
    private static double combineAlphaScore(double alpha, double[] scores) {
        return ((alpha * scores[0]) + ((1.0d - alpha) * scores[1])) / 100.0d;
    }

    /**
     * Like {@link #getClassifierScores(Classifier, FeatureDataSet[][], boolean, boolean)}, but each
     * training split is fitted with {@code classifier.doTune(split[0], tuneValues, tuneFolds)}.
     */
    public static double[] getClassifierScores(Classifier classifier, DataSet[][] splits, Object[] tuneValues,
            int tuneFolds, boolean scoreTrain, boolean scoreTest) throws Exception {
        double[] totals = {0.0d, 0.0d};
        for (DataSet[] split : splits) {
            classifier.doTune(split[0], tuneValues, tuneFolds);
            if (scoreTrain) {
                totals[0] += classifier.getAccuracy(split[0]);
            }
            if (scoreTest) {
                totals[1] += classifier.getAccuracy(split[1]);
            }
        }
        totals[0] /= splits.length;
        totals[1] /= splits.length;
        return totals;
    }

    /**
     * Grid-searches parameter 0 over the candidates in {@code (double[]) objArr[0]}. The candidate
     * with the highest mean held-out accuracy (ties resolved by {@code Util.argmax}) is selected,
     * and the classifier is then retrained on all of {@code dataSet}.
     *
     * @param i fold selector:
     *          {@code > 0} means i folds;
     *          {@code -2} means one fold per example;
     *          anything else means a single {@code splitRandom(10.0)} split
     */
    public void doTune(DataSet dataSet, Object[] objArr, int i) throws Exception {
        double[] candidates = (double[]) objArr[0];
        double[] scores = new double[candidates.length];
        DataSet[][] folds;
        if (i > 0) {
            folds = dataSet.splitDataSetFolds(i, true);
        } else if (i == -2) {
            folds = dataSet.splitDataSetFolds(dataSet.size(), true);
        } else {
            folds = new DataSet[][] {dataSet.splitRandom(10.0d)};
        }
        for (int c = 0; c < candidates.length; c++) {
            setParameter(0, Double.valueOf(candidates[c]));
            for (DataSet[] fold : folds) {
                train((FeatureDataSet) fold[0]);
                scores[c] += getAccuracy((FeatureDataSet) fold[1]) / folds.length;
                debugMesg(1, "tunescore[" + c + "]=" + scores[c]);
            }
        }
        setParameter(0, Double.valueOf(candidates[Util.argmax(scores)]));
        train((FeatureDataSet) dataSet);
    }

    /**
     * Returns the percentage (0-100) of examples classified correctly. Returns NaN for an empty
     * set. Floating-point division is used; the previous integer division could only yield 0 or 100.
     */
    public double getAccuracy(DataSet dataSet) throws Exception {
        FeatureDataSet data = (FeatureDataSet) dataSet;
        int correct = 0;
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = data.getExample(k);
            if (example.correctOutput(classify(example))) {
                correct++;
            }
        }
        return 100.0d * correct / dataSet.size();
    }

    /**
     * Runs {@code i}-fold cross-validation with {@code splitDataSetFolds(i, false)}.
     *
     * @return {@code {meanTrainAccuracy, meanTestAccuracy}} in percent, or {@code null} after
     *         reporting an error via {@code internalError}
     */
    public double[] getAccuracy(FeatureDataSet featureDataSet, int i) {
        try {
            double[] means = new double[2];
            DataSet[][] folds = featureDataSet.splitDataSetFolds(i, false);
            for (int f = 0; f < i; f++) {
                FeatureDataSet trainFold = (FeatureDataSet) folds[f][0];
                FeatureDataSet testFold = (FeatureDataSet) folds[f][1];
                train(trainFold);
                means[0] += getAccuracy(trainFold) / i;
                means[1] += getAccuracy(testFold) / i;
            }
            return means;
        } catch (Exception e) {
            internalError(e);
            return null;
        }
    }

    /**
     * Recall for value id 1: the fraction of examples whose true value id is 1 that are classified
     * correctly. Returns NaN when no such example exists.
     */
    public double getRecall(DataSet dataSet) throws Exception {
        FeatureDataSet data = (FeatureDataSet) dataSet;
        int truePositives = 0;
        int falseNegatives = 0;
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = data.getExample(k);
            if (example.getOutputFeature().getValueId() != 1) {
                continue;
            }
            if (example.correctOutput(classify(example))) {
                truePositives++;
            } else {
                falseNegatives++;
            }
        }
        return (double) truePositives / (truePositives + falseNegatives);
    }

    /**
     * Precision for value id 1. True positives are value-id-1 examples classified correctly.
     * False positives are other examples classified incorrectly; in the two-class case these are
     * predicted as 1. Returns NaN when both counts are zero.
     */
    public double getPrecision(DataSet dataSet) throws Exception {
        FeatureDataSet data = (FeatureDataSet) dataSet;
        int truePositives = 0;
        int falsePositives = 0;
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = data.getExample(k);
            boolean correct = example.correctOutput(classify(example));
            if (example.getOutputFeature().getValueId() == 1) {
                if (correct) {
                    truePositives++;
                }
            } else if (!correct) {
                falsePositives++;
            }
        }
        return (double) truePositives / (truePositives + falsePositives);
    }

    /**
     * Partitions the examples by classification outcome.
     *
     * @return {@code {misclassified, correctlyClassified}}, each in data-set order
     */
    public Example[][] getBothCat(DataSet dataSet) throws Exception {
        Vector<Example> wrong = new Vector<Example>();
        Vector<Example> right = new Vector<Example>();
        FeatureDataSet data = (FeatureDataSet) dataSet;
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = data.getExample(k);
            if (example.correctOutput(classify(example))) {
                right.add(example);
            } else {
                wrong.add(example);
            }
        }
        return new Example[][] {wrong.toArray(new Example[wrong.size()]), right.toArray(new Example[right.size()])};
    }

    /** Returns the misclassified examples, in data-set order. */
    public Example[] getMisCat(DataSet dataSet) throws Exception {
        Vector<Example> wrong = new Vector<Example>();
        FeatureDataSet data = (FeatureDataSet) dataSet;
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = data.getExample(k);
            if (!example.correctOutput(classify(example))) {
                wrong.add(example);
            }
        }
        return wrong.toArray(new Example[wrong.size()]);
    }

    /**
     * Per-example ranking score used by {@link #generateROCCurve}. The base implementation
     * returns 0.0 for every example; presumably subclasses override it.
     */
    public double getExampleWeight(Example example) {
        return 0.0d;
    }

    /**
     * Trains on {@code dataSet}, then ranks {@code dataSet2} by {@link #getExampleWeight} and writes
     * a tab-separated ROC table to the file {@code str}. Value id 0 is counted as the positive class.
     *
     * Thresholds are positions in the ascending-weight order (midpoints where the label changes),
     * not weight values. Examples ranked above a threshold are counted as predicted positive.
     * File-open failures are reported through {@code internalError}, after which nothing is written.
     */
    public void generateROCCurve(DataSet dataSet, DataSet dataSet2, String str) throws Exception {
        FeatureDataSet test = (FeatureDataSet) dataSet2;
        int n = dataSet2.size();
        double[] weights = new double[n];
        boolean[] isClassZero = new boolean[n];
        train((FeatureDataSet) dataSet);
        for (int k = 0; k < n; k++) {
            Example example = test.getExample(k);
            weights[k] = getExampleWeight(example);
            isClassZero[k] = example.getOutputFeature().getValueId() == 0;
        }
        // Exchange sort by ascending weight, permuting labels in step (kept so tie order is unchanged).
        for (int a = 0; a < n - 1; a++) {
            for (int b = a + 1; b < n; b++) {
                if (weights[a] > weights[b]) {
                    Util.swap(weights, a, b);
                    Util.swap(isClassZero, a, b);
                }
            }
        }
        Vector<Double> thresholds = new Vector<Double>();
        if (n > 0) {
            boolean current = isClassZero[0];
            for (int k = 0; k < n; k++) {
                if (isClassZero[k] != current) {
                    thresholds.add(Double.valueOf(k - 0.5d));
                    current = isClassZero[k];
                }
            }
        }
        PrintWriter out;
        try {
            out = new PrintWriter(new FileWriter(str));
        } catch (Exception e) {
            internalError(e);
            return;
        }
        try {
            out.println("threshold\ttprate\tfprate");
            out.println("0\t0\t0");
            for (Double boxed : thresholds) {
                double threshold = boxed.doubleValue();
                int truePos = 0;
                int falsePos = 0;
                int falseNeg = 0;
                int trueNeg = 0;
                for (int k = 0; k < n; k++) {
                    if (k > threshold) {
                        if (isClassZero[k]) {
                            truePos++;
                        } else {
                            falsePos++;
                        }
                    } else if (isClassZero[k]) {
                        falseNeg++;
                    } else {
                        trueNeg++;
                    }
                }
                // A threshold exists only where the label changes, so both classes are present
                // and neither denominator can be zero.
                double tpRate = (double) truePos / (truePos + falseNeg);
                double fpRate = (double) falsePos / (falsePos + trueNeg);
                out.println(String.valueOf(threshold) + "\t" + tpRate + "\t" + fpRate);
            }
        } finally {
            out.close();
        }
    }

    /** Scorer hook: the score of an example is its probability for value id 1. */
    @Override // edu.wisc.sjm.machlearn.Scorer
    public double getScore(Example example) throws Exception {
        return getPositiveProb(example);
    }

    /**
     * Fits the preprocessing on {@code featureDataSet} and trains the model on the transformed data.
     * The output of {@link #process_train} is used directly; the data is no longer transformed twice.
     */
    public void train(FeatureDataSet featureDataSet) throws Exception {
        train_(process_train(featureDataSet));
    }

    /**
     * Classifies an example after applying the same preprocessing used at training time.
     * Previously the processed example was discarded and the raw one was classified.
     */
    public Feature classify(Example example) throws Exception {
        return classify_(process(example));
    }

    /** Scorer hook: trains on {@code dataSet}, which must be a {@link FeatureDataSet}. */
    @Override // edu.wisc.sjm.machlearn.Scorer
    public void doTrain(DataSet dataSet) throws Exception {
        train((FeatureDataSet) dataSet);
    }

    /**
     * Applies comma-separated {@code Name=value} assignments by reflectively invoking
     * {@code setName(String)} on this object. Malformed entries and invocation failures are
     * reported through {@code internalError}.
     */
    public void setParameters(String str) {
        for (String assignment : Util.splitString(str, ",")) {
            String[] parts = Util.splitString(assignment, "=");
            if (parts.length < 2) {
                internalError(new IllegalArgumentException("Malformed parameter assignment: " + assignment));
                continue;
            }
            System.out.println("Setting " + parts[0] + " = " + parts[1]);
            try {
                getClass().getMethod("set" + parts[0], String.class).invoke(this, parts[1]);
            } catch (Exception e) {
                internalError(e);
            }
        }
    }
}
