package edu.wisc.mgr.auc;

import edu.wisc.sjm.jutil.misc.GenericCache;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:builds/auc_install.jar:auc.jar:edu/wisc/mgr/auc/Confusion.class */
/**
 * A sorted collection of {@link PNPoint}s (a "pos" count paired with a "neg" count) together
 * with the totals {@code totPos} and {@code totNeg} used to normalise them.
 *
 * <p>Points are gathered in the underlying {@link TreeSet}. {@link #doneAdding()} copies the
 * set, in iteration order, into an internal vector; {@link #interpolate()}, the AUC
 * calculations, the file writers and {@link #toString()} all work on that vector, so they see
 * no data until {@code doneAdding()} (or {@link #sort()}) has been called.
 *
 * <p>Throughout this class recall is computed as {@code pos / totPos}, precision as
 * {@code pos / (pos + neg)} and the ROC axes as {@code neg / totNeg} and {@code pos / totPos}.
 *
 * <p>NOTE(review): the algorithms assume the set iterates in ascending {@code pos} order;
 * that ordering is defined by {@code PNPoint} and is not visible here - confirm.
 */
public class Confusion extends TreeSet<PNPoint> {
    private double totPos;
    private double totNeg;

    /** Snapshot of the set taken by {@link #doneAdding()}, plus any interpolated points. */
    private Vector<PNPoint> vector;

    /** Creates an empty collection with both totals set to 1. */
    public Confusion() {
        this(1.0d, 1.0d);
    }

    /**
     * Creates an empty collection with the given totals.
     *
     * @param totPos total used to normalise "pos" values; see {@link #setTotPosNeg}
     * @param totNeg total used to normalise "neg" values; see {@link #setTotPosNeg}
     */
    public Confusion(double totPos, double totNeg) {
        this.vector = new Vector<>();
        setTotPosNeg(totPos, totNeg);
    }

    /**
     * Sets both totals. If either value is below 1 (or is NaN) both totals fall back to 1 and
     * an error line is written to {@code System.err}.
     *
     * @param totPos new positive total, expected to be at least 1
     * @param totNeg new negative total, expected to be at least 1
     */
    public void setTotPosNeg(double totPos, double totNeg) {
        if (totPos >= 1.0d && totNeg >= 1.0d) {
            this.totPos = totPos;
            this.totNeg = totNeg;
        } else {
            this.totPos = 1.0d;
            this.totNeg = 1.0d;
            System.err.println("ERROR: " + totPos + "," + totNeg + " - Defaulting Confusion to 1,1");
        }
    }

    /** @return the current positive total */
    public double getTotPos() {
        return this.totPos;
    }

    /** @return the current negative total */
    public double getTotNeg() {
        return this.totNeg;
    }

    /**
     * Adds a point given as (recall, precision). It is stored as
     * {@code pos = recall * totPos} and {@code neg = pos * (1 - precision) / precision}.
     *
     * <p>NOTE(review): a precision of exactly 0 divides by zero and yields an infinite or NaN
     * {@code neg}; this matches the previous behaviour and is left for the caller to avoid.
     *
     * @param recall value in [0, 1]
     * @param precision value in [0, 1]
     * @throws NumberFormatException if either argument lies outside [0, 1]
     */
    public void addPRPoint(double recall, double precision) throws NumberFormatException {
        if (recall > 1.0d || recall < 0.0d || precision > 1.0d || precision < 0.0d) {
            throw new NumberFormatException();
        }
        double pos = recall * this.totPos;
        double neg = (pos - (precision * pos)) / precision;
        addOrRelease(pos, neg);
    }

    /**
     * Discards all points: hands the vector's points back to the {@code PNPoint} cache and
     * clears both the vector and the set.
     */
    public void empty() {
        PNPoint.cache.release(this.vector);
        // Drop the references explicitly so later calls cannot read (or re-release) points
        // that have already been returned to the cache.
        this.vector.clear();
        clear();
    }

    /**
     * Adds a point given as ROC coordinates. It is stored as
     * {@code pos = truePosRate * totPos} and {@code neg = falsePosRate * totNeg}.
     *
     * @param falsePosRate value in [0, 1], scaled by {@code totNeg}
     * @param truePosRate value in [0, 1], scaled by {@code totPos}
     * @throws NumberFormatException if either argument lies outside [0, 1]
     */
    public void addROCPoint(double falsePosRate, double truePosRate) throws NumberFormatException {
        if (falsePosRate > 1.0d || falsePosRate < 0.0d || truePosRate > 1.0d || truePosRate < 0.0d) {
            throw new NumberFormatException();
        }
        addOrRelease(truePosRate * this.totPos, falsePosRate * this.totNeg);
    }

    /**
     * Adds a point given directly as counts. Points whose {@code pos} is not greater than
     * 0.001 are silently ignored.
     *
     * @param pos value in [0, totPos]
     * @param neg value in [0, totNeg]
     * @throws NumberFormatException if either argument lies outside its range
     */
    public void addPoint(double pos, double neg) throws NumberFormatException {
        if (pos < 0.0d || pos > this.totPos || neg < 0.0d || neg > this.totNeg) {
            throw new NumberFormatException("pos:" + pos + " totPos:" + this.totPos + " neg:" + neg + " totNeg:" + this.totNeg);
        }
        if (pos > 0.001d) {
            addOrRelease(pos, neg);
        }
    }

    /**
     * Finishes a batch of additions and rebuilds the internal vector from the set.
     *
     * <p>If the first point has {@code pos > 1}, a point at {@code pos = 1} with the same
     * neg/pos ratio is added in front of it. The end point {@code (totPos, totNeg)} is always
     * added. Writes an error line to {@code System.err} and does nothing if the set is empty.
     */
    public void doneAdding() {
        if (isEmpty()) {
            System.err.println("ERROR: No data...");
            return;
        }
        PNPoint first = first();
        if (first.getPos() > 1.0d) {
            addOrRelease(1.0d, first.getNeg() / first.getPos());
        }
        addOrRelease(this.totPos, this.totNeg);

        // NOTE(review): on a repeated call the vector still holds points that are also in the
        // set, so this hands live points back to the cache - confirm GenericCache tolerates it.
        PNPoint.cache.release(this.vector);
        this.vector.clear();
        this.vector.addAll(this);
    }

    /** Equivalent to {@link #doneAdding()}. */
    public void sort() {
        doneAdding();
    }

    /**
     * Fills gaps in the vector: between each pair of neighbouring points whose {@code pos}
     * values differ by more than 1.001, inserts points at {@code pos} steps of 1 with
     * {@code neg} linearly interpolated between the pair. Only the vector is modified, not the
     * set. Writes an error line to {@code System.err} if the vector is empty.
     */
    public void interpolate() {
        if (this.vector.isEmpty()) {
            System.err.println("ERROR: No data to interpolate....");
            return;
        }
        int i = 0;
        while (i < this.vector.size() - 1) {
            PNPoint current = this.vector.get(i);
            PNPoint next = this.vector.get(i + 1);
            double slope = (next.getNeg() - current.getNeg()) / (next.getPos() - current.getPos());
            double startPos = current.getPos();
            double startNeg = current.getNeg();
            // Signed gap: the loop only ever steps upwards, so a non-ascending pair must not
            // enter it (an absolute gap would grow forever).
            while (next.getPos() - current.getPos() > 1.001d) {
                double interpolatedNeg = startNeg + (((current.getPos() - startPos) + 1.0d) * slope);
                PNPoint inserted = PNPoint.cache.get();
                inserted.setPosNeg(current.getPos() + 1.0d, interpolatedNeg);
                i++;
                this.vector.insertElementAt(inserted, i);
                current = inserted;
            }
            i++;
        }
    }

    /**
     * Computes the area under the precision-recall curve from {@code minRecall} to the last
     * point, without printing the result.
     *
     * @param minRecall lower recall bound in [0, 1]
     * @return the area, or 0 on invalid input or when there is no data
     * @see #calculateAUCPR(double, boolean)
     */
    public double calculateAUCPR(double minRecall) {
        return calculateAUCPR(minRecall, false);
    }

    /**
     * Computes the area under the precision-recall curve for recall values at or above
     * {@code minRecall}, using the trapezoid rule over the vector's points. The partial
     * segment that straddles {@code minRecall} is cut at {@code minRecall} using linear
     * interpolation of precision.
     *
     * <p>If no point reaches {@code minRecall} an error is printed and the JVM is terminated
     * with {@code System.exit(-1)}.
     *
     * @param minRecall lower recall bound in [0, 1]
     * @param verbose if true, prints the intermediate values, the totals and the result to
     *     {@code System.out}
     * @return the area, or 0 (with a message on {@code System.err}) if {@code minRecall} is
     *     out of range or the vector is empty
     */
    public double calculateAUCPR(double minRecall, boolean verbose) {
        if (minRecall < 0.0d || minRecall > 1.0d) {
            System.err.println("ERROR: invalid minRecall, must be between 0 and 1 - returning 0");
            return 0.0d;
        }
        if (this.vector.isEmpty()) {
            System.err.println("ERROR: No data to calculate....");
            return 0.0d;
        }
        double minPos = minRecall * this.totPos;

        // Find the first point at or beyond minRecall, remembering the one just before it.
        int index = 0;
        PNPoint current = this.vector.elementAt(0);
        PNPoint previous = null;
        while (current.getPos() < minPos) {
            previous = current;
            index++;
            if (index >= this.vector.size()) {
                System.out.println("ERROR: minRecall out of bounds - exiting...");
                System.exit(-1);
            }
            current = this.vector.elementAt(index);
        }

        // Rectangle from minRecall up to that point ...
        double width = (current.getPos() - minPos) / this.totPos;
        double precision = precisionOf(current);
        double area = width * precision;
        if (previous != null) {
            // ... plus the triangle obtained by interpolating precision at minRecall.
            double previousPrecision = precisionOf(previous);
            double slope = (precision - previousPrecision)
                    / ((current.getPos() / this.totPos) - (previous.getPos() / this.totPos));
            double precisionAtMin = previousPrecision + ((slope * (minPos - previous.getPos())) / this.totPos);
            double bonus = 0.5d * width * (precisionAtMin - precision);
            if (verbose) {
                System.out.println("slope is " + slope);
                System.out.println("Tempprec is " + precisionAtMin);
                System.out.println("Bonus area is " + bonus);
            }
            area += bonus;
        }

        // Trapezoids for every remaining segment.
        double recall = current.getPos() / this.totPos;
        for (int k = index + 1; k < this.vector.size(); k++) {
            PNPoint point = this.vector.elementAt(k);
            double nextRecall = point.getPos() / this.totPos;
            double nextPrecision = precisionOf(point);
            area += ((nextRecall - recall) * nextPrecision)
                    + (0.5d * (nextRecall - recall) * (precision - nextPrecision));
            recall = nextRecall;
            precision = nextPrecision;
        }
        if (verbose) {
            System.out.println("Total pos:" + this.totPos);
            System.out.println("Total neg:" + this.totNeg);
            System.out.println("Area Under the Curve for Precision - Recall is " + area);
        }
        return area;
    }

    /**
     * Computes the area under the ROC curve without printing the result.
     *
     * @return the area, or 0 when there is no data
     * @see #calculateAUCROC(boolean)
     */
    public double calculateAUCROC() {
        return calculateAUCROC(false);
    }

    /**
     * Computes the area under the ROC curve. The method integrates {@code neg / totNeg} over
     * {@code pos / totPos} with the trapezoid rule, starting from the origin, and returns one
     * minus that integral.
     *
     * @param verbose if true, prints the totals and the result to {@code System.out}
     * @return the area, or 0 (with a message on {@code System.err}) if the vector is empty
     */
    public double calculateAUCROC(boolean verbose) {
        if (this.vector.isEmpty()) {
            System.err.println("ERROR: No data to calculate....");
            return 0.0d;
        }
        PNPoint first = this.vector.elementAt(0);
        double pos = first.getPos() / this.totPos;
        double neg = first.getNeg() / this.totNeg;
        // Triangle between the origin and the first point.
        double complement = 0.5d * pos * neg;
        for (int i = 1; i < this.vector.size(); i++) {
            PNPoint point = this.vector.elementAt(i);
            double nextPos = point.getPos() / this.totPos;
            double nextNeg = point.getNeg() / this.totNeg;
            // Trapezoid: the rectangle at the new height minus the triangle above the chord.
            // The triangle term must use (nextNeg - neg); the reversed difference would add
            // the triangle and over-count every segment where both coordinates change.
            complement += ((nextPos - pos) * nextNeg) - ((0.5d * (nextPos - pos)) * (nextNeg - neg));
            pos = nextPos;
            neg = nextNeg;
        }
        double area = 1.0d - complement;
        if (verbose) {
            System.out.println("Total pos:" + this.totPos);
            System.out.println("Total neg:" + this.totNeg);
            System.out.println("Area Under the Curve for ROC is " + area);
        }
        return area;
    }

    /**
     * Writes one line per vector point to {@code str}: recall, a tab, then precision.
     * Prints a banner to {@code System.out} first. Does nothing (beyond an error line on
     * {@code System.err}) when the vector is empty. On an I/O failure an error is printed and
     * the JVM is terminated with {@code System.exit(-1)}.
     *
     * @param str path of the file to create or overwrite
     */
    public void writePRFile(String str) {
        System.out.println("--- Writing file " + str + " ---");
        if (this.vector.isEmpty()) {
            System.err.println("ERROR: No data to write....");
            return;
        }
        try (PrintWriter out = new PrintWriter(new FileWriter(new File(str)))) {
            for (PNPoint point : this.vector) {
                out.println(String.valueOf(point.getPos() / this.totPos) + "\t" + precisionOf(point));
            }
            // PrintWriter swallows write errors; surface them through the handler below.
            if (out.checkError()) {
                throw new IOException("write failed: " + str);
            }
        } catch (IOException e) {
            System.out.println("ERROR: IO Exception in file " + str + " - exiting...");
            System.exit(-1);
        }
    }

    /**
     * Writes the line {@code 0<TAB>0} followed by one line per vector point to {@code str}:
     * {@code neg / totNeg}, a tab, then {@code pos / totPos}. Prints a banner to
     * {@code System.out} first. Does nothing (beyond an error line on {@code System.err}) when
     * the vector is empty. On an I/O failure an error is printed and the JVM is terminated
     * with {@code System.exit(-1)}.
     *
     * @param str path of the file to create or overwrite
     */
    public void writeROCFile(String str) {
        System.out.println("--- Writing file " + str + " ---");
        if (this.vector.isEmpty()) {
            System.err.println("ERROR: No data to write....");
            return;
        }
        try (PrintWriter out = new PrintWriter(new FileWriter(new File(str)))) {
            out.println("0\t0");
            for (PNPoint point : this.vector) {
                out.println(String.valueOf(point.getNeg() / this.totNeg) + "\t" + (point.getPos() / this.totPos));
            }
            // PrintWriter swallows write errors; surface them through the handler below.
            if (out.checkError()) {
                throw new IOException("write failed: " + str);
            }
        } catch (IOException e) {
            System.out.println("ERROR: IO Exception in file " + str + " - exiting...");
            System.exit(-1);
        }
    }

    /**
     * Returns the totals on the first line followed by one line per vector point.
     *
     * @return a multi-line description of the totals and the vector's contents
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("TotPos: ").append(this.totPos).append(", TotNeg: ").append(this.totNeg).append("\n");
        for (PNPoint point : this.vector) {
            text.append(point).append("\n");
        }
        return text.toString();
    }

    /**
     * Takes a point from the {@code PNPoint} cache, fills it in and adds it to the set; if the
     * set rejects it as a duplicate the point goes straight back to the cache.
     */
    private void addOrRelease(double pos, double neg) {
        PNPoint point = PNPoint.cache.get();
        point.setPosNeg(pos, neg);
        if (!add(point)) {
            PNPoint.cache.release(point);
        }
    }

    /** Returns {@code pos / (pos + neg)} for the given point. */
    private static double precisionOf(PNPoint point) {
        return point.getPos() / (point.getPos() + point.getNeg());
    }
}
