package edu.wisc.sjm.machlearn.confusion;

import edu.wisc.sjm.jutil.matrices.DoubleMatrix;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.FeatureId;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.policy.PolicyClassifier;
import java.util.Iterator;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/confusion/ConfusionMatrix.class */
public class ConfusionMatrix {
    /** Output feature whose value ids index the rows and columns of every stored matrix. */
    protected FeatureId outputfeature;
    /** One matrix per call to an {@code add} method, in insertion order. */
    protected Vector<DoubleMatrix> matrices;
    /** Not read or written by this class; retained for subclasses / compatibility. */
    protected int num;
    /** When true, the classifier-based {@code add} methods print one line per example. */
    protected boolean verbose;

    /**
     * Creates an empty, non-verbose confusion matrix over the values of {@code featureId}.
     *
     * @param featureId the output feature whose values label rows and columns
     */
    public ConfusionMatrix(FeatureId featureId) {
        this(featureId, false);
    }

    /**
     * Creates an empty confusion matrix over the values of {@code featureId}.
     *
     * @param featureId the output feature whose values label rows and columns
     * @param verbose   whether classifier-based {@code add} calls print per-example output
     */
    public ConfusionMatrix(FeatureId featureId, boolean verbose) {
        this.outputfeature = featureId;
        this.matrices = new Vector<>();
        this.verbose = verbose;
    }

    /** Enables or disables per-example printing in the classifier-based {@code add} methods. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Discards every stored matrix. */
    public void clear() {
        this.matrices.clear();
    }

    /**
     * Builds one matrix from paired feature arrays and stores it.
     *
     * <p>For each index, the cell (row = {@code predicted[i]} value id, column =
     * {@code actual[i]} value id) is incremented. {@link #toString()} labels rows
     * "Predicted" and columns "Actual".
     *
     * @param predicted row (predicted) features
     * @param actual    column (actual) features
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public void add(Feature[] predicted, Feature[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException(
                    "predicted/actual length mismatch: " + predicted.length + " vs " + actual.length);
        }
        DoubleMatrix matrix = newEmptyMatrix();
        for (int i = 0; i < actual.length; i++) {
            matrix.increment(predicted[i].getValueId(), actual[i].getValueId());
        }
        this.matrices.add(matrix);
    }

    /**
     * Classifies {@code featureDataSet} with {@code classifier} and stores the resulting matrix.
     *
     * @throws Exception propagated from the classifier or data set
     */
    public void add(Classifier classifier, FeatureDataSet featureDataSet) throws Exception {
        Feature[] predicted = classifier.classify(featureDataSet);
        double[][] distribution = classifier.getDistribution(featureDataSet);
        addPredictions(featureDataSet, predicted, distribution);
    }

    /**
     * Converts {@code dataSet} with the policy classifier, classifies it and stores the
     * resulting matrix.
     *
     * @throws Exception propagated from the classifier or data set
     */
    public void add(PolicyClassifier policyClassifier, DataSet dataSet) throws Exception {
        FeatureDataSet fdsDataSet = policyClassifier.getFDSDataSet(dataSet);
        Feature[] predicted = policyClassifier.classify(fdsDataSet);
        double[][] distribution = policyClassifier.getDistribution(fdsDataSet);
        addPredictions(fdsDataSet, predicted, distribution);
    }

    /**
     * Shared body of the classifier-based {@code add} overloads: tallies (predicted, actual)
     * pairs into a new matrix and optionally prints each example.
     *
     * <p>Verbose output reads columns 0 and 1 of {@code distribution} and values 0 and 1 of
     * the output feature, so it assumes at least two classes.
     */
    private void addPredictions(FeatureDataSet data, Feature[] predicted, double[][] distribution)
            throws Exception {
        DoubleMatrix matrix = newEmptyMatrix();
        if (this.verbose) {
            System.out.println("Name\tActual\tPredicted\t" + data.getOutputFeatureId().getValue(0) + "\t" + data.getOutputFeatureId().getValue(1));
            System.out.println("====\t======\t=========");
        }
        for (int i = 0; i < data.size(); i++) {
            int predictedId = predicted[i].getValueId();
            int actualId = data.getOutputFeature(i).getValueId();
            if (this.verbose) {
                System.out.print(String.valueOf(data.getName(i)) + "\t" + actualId + "\t" + predictedId);
                System.out.println("\t" + distribution[i][0] + "\t" + distribution[i][1]);
            }
            matrix.increment(predictedId, actualId);
        }
        this.matrices.add(matrix);
    }

    /** Allocates a new square matrix sized to the output feature's number of values. */
    private DoubleMatrix newEmptyMatrix() {
        int n = this.outputfeature.numValues();
        return new DoubleMatrix(n, n);
    }

    /** Returns the number of examples tallied in matrix {@code i} (its cell sum, truncated to int). */
    public int numExamples(int i) {
        return (int) getMatrix(i).sum();
    }

    /** Returns the number of examples tallied across all stored matrices. */
    public int numExamples() {
        int total = 0;
        for (int i = 0; i < size(); i++) {
            total += numExamples(i);
        }
        return total;
    }

    /** Returns the number of stored matrices. */
    public int size() {
        return this.matrices.size();
    }

    /**
     * Returns a new matrix holding the cell-wise sum of all stored matrices.
     *
     * <p>The stored matrices are not modified. The previous implementation accumulated into
     * matrix 0 in place, so repeated calls kept adding the other matrices into it. With no
     * stored matrices, this returns a zero matrix.
     */
    public DoubleMatrix getMatrix() {
        DoubleMatrix total = newEmptyMatrix();
        for (DoubleMatrix matrix : this.matrices) {
            total.add(matrix);
        }
        return total;
    }

    /** Returns the stored matrix at index {@code i} (the live object, not a copy). */
    public DoubleMatrix getMatrix(int i) {
        return this.matrices.get(i);
    }

    /**
     * Returns the percentage of examples on the diagonal of {@code doubleMatrix}.
     *
     * <p>The result is NaN when the matrix sum is 0.
     */
    public double getAccuracy(DoubleMatrix doubleMatrix) {
        double sum = doubleMatrix.sum();
        double correct = 0.0d;
        for (int i = 0; i < this.outputfeature.numValues(); i++) {
            correct += doubleMatrix.get(i, i);
        }
        return (correct / sum) * 100.0d;
    }

    /** Returns the accuracy percentage of stored matrix {@code i}. */
    public double getAccuracy(int i) {
        return getAccuracy(getMatrix(i));
    }

    /**
     * Returns the accuracy percentage over all stored matrices combined.
     *
     * <p>For two classes this equals (TP + TN) / (TP + TN + FP + FN). For more classes it now
     * uses every cell, consistent with {@link #getAccuracy(DoubleMatrix)}. The result is NaN
     * when no examples were tallied.
     */
    public double getAccuracy() {
        return getAccuracy(getMatrix());
    }

    /** Returns the pooled precision for class 1 as positive. */
    public double getPrecision() {
        return getPrecision(1);
    }

    /**
     * Returns the F1 score of class {@code i2} in matrix {@code i}.
     *
     * <p>This is the harmonic mean of {@link #getRecall(int, int)} and
     * {@link #getPrecision(int, int)}, or 0 when either is 0.
     */
    public double getF1Score(int i, int i2) {
        return harmonicMean(getPrecision(i, i2), getRecall(i, i2));
    }

    /**
     * Returns the recall percentage of class {@code i2} in matrix {@code i}.
     *
     * <p>The other of classes {0, 1} is treated as negative:
     * cell(i2, i2) / (cell(i2, i2) + cell(other, i2)). Returns 0 when cell(i2, i2) is 0.
     * Binary problems only.
     */
    public double getRecall(int i, int i2) {
        DoubleMatrix matrix = getMatrix(i);
        double truePositives = matrix.get(i2, i2);
        if (truePositives == 0.0d) {
            return 0.0d;
        }
        return (truePositives / (truePositives + matrix.get(otherClass(i2), i2))) * 100.0d;
    }

    /**
     * Returns the precision percentage of class {@code i2} in matrix {@code i}.
     *
     * <p>The other of classes {0, 1} is treated as negative:
     * cell(i2, i2) / (cell(i2, i2) + cell(i2, other)). Returns 0 when cell(i2, i2) is 0.
     * Binary problems only.
     */
    public double getPrecision(int i, int i2) {
        DoubleMatrix matrix = getMatrix(i);
        double truePositives = matrix.get(i2, i2);
        if (truePositives == 0.0d) {
            return 0.0d;
        }
        return (truePositives / (truePositives + matrix.get(i2, otherClass(i2)))) * 100.0d;
    }

    /**
     * Returns the precision percentage for class {@code i}, pooled over all stored matrices.
     *
     * <p>Returns 0 when there are no true positives.
     */
    public double getPrecision(int i) {
        double truePositives = getTruePositives(i);
        if (truePositives == 0.0d) {
            return 0.0d;
        }
        return (truePositives / (truePositives + getFalsePositives(i))) * 100.0d;
    }

    /** Returns the pooled F1 score for class 1 as positive. */
    public double getFScore() {
        return getFScore(1);
    }

    /** Returns the pooled F1 score for class {@code i}, or 0 when recall or precision is 0. */
    public double getFScore(int i) {
        return harmonicMean(getPrecision(i), getRecall(i));
    }

    /** Returns the pooled recall for class 1 as positive. */
    public double getRecall() {
        return getRecall(1);
    }

    /**
     * Returns the recall percentage for class {@code i}, pooled over all stored matrices.
     *
     * <p>Returns 0 when there are no true positives.
     */
    public double getRecall(int i) {
        double truePositives = getTruePositives(i);
        if (truePositives == 0.0d) {
            return 0.0d;
        }
        return (truePositives / (truePositives + getFalseNegatives(i))) * 100.0d;
    }

    /**
     * Renders the combined matrix, then pooled recall and precision with each of classes 0
     * and 1 taken as positive.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Confusion Matrix\n");
        DoubleMatrix matrix = getMatrix();
        int numValues = this.outputfeature.numValues();
        sb.append("Output Feature:");
        sb.append(this.outputfeature.toString());
        sb.append("\n");
        sb.append("\t\tActual\n");
        sb.append("Predicted");
        for (int col = 0; col < numValues; col++) {
            sb.append("\t");
            sb.append(this.outputfeature.getValue(col));
        }
        sb.append("\n");
        for (int row = 0; row < numValues; row++) {
            sb.append(this.outputfeature.getValue(row));
            sb.append("\t\t");
            for (int col = 0; col < numValues; col++) {
                sb.append(matrix.get(row, col));
                sb.append("\t");
            }
            sb.append("\n");
        }
        sb.append("\n");
        sb.append("Assuming\n");
        sb.append(this.outputfeature.getValue(0));
        sb.append(" is positive and\n");
        sb.append(this.outputfeature.getValue(1));
        sb.append(" is negative\n");
        sb.append("Recall:");
        sb.append(getRecall(0));
        sb.append("\nPrecision:");
        sb.append(getPrecision(0));
        sb.append("\n------------\n");
        sb.append("Assuming\n");
        sb.append(this.outputfeature.getValue(1));
        sb.append(" is positive and\n");
        sb.append(this.outputfeature.getValue(0));
        sb.append(" is negative\n");
        sb.append("Recall:");
        sb.append(getRecall(1));
        sb.append("\nPrecision:");
        sb.append(getPrecision(1));
        return sb.toString();
    }

    /**
     * Summarizes per-matrix accuracy, recall, precision and F1 as "average+/-conf".
     *
     * <p>The summary is produced once with class 0 as positive and once with class 1 as
     * positive. Averages and confidence values come from {@code DoubleVector}.
     */
    public String getStats() {
        StringBuilder sb = new StringBuilder();
        DoubleVector accuracies = new DoubleVector();
        DoubleVector recalls = new DoubleVector();
        DoubleVector precisions = new DoubleVector();
        DoubleVector f1Scores = new DoubleVector();
        for (int positive = 0; positive <= 1; positive++) {
            int negative = otherClass(positive);
            for (int m = 0; m < size(); m++) {
                accuracies.add(getAccuracy(m));
                recalls.add(getRecall(m, positive));
                precisions.add(getPrecision(m, positive));
                f1Scores.add(getF1Score(m, positive));
            }
            sb.append("\nAccuracy:\n");
            sb.append(String.valueOf(accuracies.average()) + "+/-" + accuracies.conf());
            sb.append("\nAssuming\n");
            sb.append(this.outputfeature.getValue(positive));
            sb.append(" is positive and\n");
            sb.append(this.outputfeature.getValue(negative));
            sb.append(" is negative\n");
            sb.append("Recall:" + recalls.average() + "+/-" + recalls.conf());
            sb.append("\nPrecision:" + precisions.average() + "+/-" + precisions.conf());
            sb.append("\nF1-Score:" + f1Scores.average() + "+/-" + f1Scores.conf());
            sb.append("\n----------------------");
            accuracies.empty();
            recalls.empty();
            precisions.empty();
            f1Scores.empty();
        }
        return sb.toString();
    }

    /** Sums cell(i, other) over all matrices: predicted {@code i}, actually the other binary class. */
    public double getFalsePositives(int i) {
        int other = otherClass(i);
        double total = 0.0d;
        for (DoubleMatrix matrix : this.matrices) {
            total += matrix.get(i, other);
        }
        return total;
    }

    /** Sums cell(other, i) over all matrices: predicted the other binary class, actually {@code i}. */
    public double getFalseNegatives(int i) {
        int other = otherClass(i);
        double total = 0.0d;
        for (DoubleMatrix matrix : this.matrices) {
            total += matrix.get(other, i);
        }
        return total;
    }

    /** Sums the diagonal cell(i, i) over all stored matrices. */
    public double getTruePositives(int i) {
        double total = 0.0d;
        for (DoubleMatrix matrix : this.matrices) {
            total += matrix.get(i, i);
        }
        return total;
    }

    /** Sums cell(other, other) over all matrices, where other is the opposite binary class of {@code i}. */
    public double getTrueNegatives(int i) {
        int other = otherClass(i);
        double total = 0.0d;
        for (DoubleMatrix matrix : this.matrices) {
            total += matrix.get(other, other);
        }
        return total;
    }

    /**
     * Returns the pooled false-positive rate percentage for class 1, FP / (FP + TN).
     *
     * <p>Returns 0 when there are no false positives.
     */
    public double getFPR() {
        double falsePositives = getFalsePositives(1);
        if (falsePositives == 0.0d) {
            return 0.0d;
        }
        return (falsePositives / (falsePositives + getTrueNegatives(1))) * 100.0d;
    }

    /**
     * Computes an F-beta style score from raw counts.
     *
     * <p>Let a = i / (i + i4) and b = i / (i + i3). The result is
     * (beta^2 + 1) * b * a / (b + beta^2 * a). This is the standard F-beta when {@code i} is
     * the true-positive count, {@code i4} the false-positive count (a = precision) and
     * {@code i3} the false-negative count (b = recall). NOTE(review): confirm this parameter
     * meaning against callers.
     *
     * <p>{@code i2} is not used. The ratios are now computed in floating point; the previous
     * int division truncated them to 0 or 1. A zero denominator now yields NaN instead of
     * throwing {@code ArithmeticException}.
     *
     * @param d beta
     */
    public static double getF1(int i, int i2, int i3, int i4, double d) {
        double a = (double) i / ((double) i + i4);
        double b = (double) i / ((double) i + i3);
        double betaSquared = d * d;
        return (((betaSquared + 1.0d) * b) * a) / (b + (betaSquared * a));
    }

    /** Returns the opposite class id of a binary problem: 0 for 1, and 1 for any other id. */
    private static int otherClass(int classId) {
        return classId == 1 ? 0 : 1;
    }

    /** Returns 2PR / (P + R), or 0 when either input is 0. */
    private static double harmonicMean(double precision, double recall) {
        if (recall == 0.0d || precision == 0.0d) {
            return 0.0d;
        }
        return ((2.0d * precision) * recall) / (precision + recall);
    }
}
