package edu.wisc.sjm.machlearn.classifiers.naivebayes.probability;

import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import java.awt.Color;
import java.awt.Graphics;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/naivebayes/probability/ProbabilityHistogram.class */
/**
 * Class-conditional probability estimate for one numeric feature, modelled as a fixed-width
 * histogram per output class.
 *
 * <p>{@link #train(FeatureDataSet)} finds the range {@code [min, max]} of the feature over all
 * training examples (all classes pooled). It splits that range into {@link #numBins} equal-width
 * bins of width {@code dx}. Two extra bins catch out-of-range values, so every class has
 * {@code numBins + 2} bins:
 *
 * <ul>
 *   <li>bin {@code 0}: values {@code <= min}
 *   <li>bins {@code 1..numBins}: values in {@code (min + (k-1)*dx, min + k*dx]}
 *   <li>bin {@code numBins + 1}: values {@code > max}
 * </ul>
 *
 * <p>Every bin is seeded with a pseudo-count of {@code dm / (numBins + 2)} before the per-class
 * counts are added and normalized. {@code dm}, {@code nclass}, {@code findex},
 * {@code output_index} and {@code fid} are not declared in this class; presumably they are
 * inherited from {@link Probability}.
 */
public class ProbabilityHistogram extends Probability {
    /**
     * Number of interior bins. The per-class arrays are sized from this value in the constructor,
     * so changing it after an instance has been built will break that instance.
     */
    static int numBins = 20;

    /** Smallest feature value seen in the last {@link #train(FeatureDataSet)} call. */
    double min;

    /** Largest feature value seen in the last {@link #train(FeatureDataSet)} call. */
    double max;

    /** Width of each interior bin: {@code (max - min) / numBins}. */
    double dx;

    /** Raw per-class, per-bin example counts from the last training pass. */
    int[][] counts;

    /** Normalized per-class, per-bin probabilities (each row sums to 1 after training). */
    double[][] probs;

    /** Per-class raw feature values gathered during the last training pass. */
    DoubleVector[] temp_values;

    /**
     * Allocates the per-class histogram storage for feature index {@code i}.
     *
     * @param featureDataSet data set forwarded to the {@link Probability} constructor
     * @param i feature index forwarded to the {@link Probability} constructor
     */
    public ProbabilityHistogram(FeatureDataSet featureDataSet, int i) {
        super(featureDataSet, i);
        int nbins = numBins + 2;
        this.counts = new int[this.nclass][nbins];
        this.probs = new double[this.nclass][nbins];
        this.temp_values = new DoubleVector[this.nclass];
        for (int c = 0; c < this.nclass; c++) {
            this.temp_values[c] = new DoubleVector();
        }
    }

    /**
     * Rebuilds the histograms from {@code featureDataSet}, discarding any previous training state.
     *
     * @param featureDataSet training examples; the feature at {@code findex} is read as a double
     */
    @Override
    public void train(FeatureDataSet featureDataSet) {
        int nbins = numBins + 2;

        // Reset all per-class state. counts must be cleared as well, otherwise repeated train()
        // calls would keep accumulating on top of the previous run.
        for (int c = 0; c < this.nclass; c++) {
            this.temp_values[c].empty();
            for (int b = 0; b < nbins; b++) {
                this.probs[c][b] = (dm * 1.0d) / nbins; // additive smoothing pseudo-count
                this.counts[c][b] = 0;
            }
        }

        // Gather the raw feature values per class so the overall range can be found.
        // NOTE(review): the class id is read via getOutputFeature(e) here and via
        // get(e, output_index) below; presumably equivalent — confirm.
        for (int e = 0; e < featureDataSet.size(); e++) {
            this.temp_values[featureDataSet.getOutputFeature(e).getValueId()]
                    .add(featureDataSet.get(e, this.findex).getDValue());
        }

        // The bin layout is shared by all classes, so min/max are taken over every class.
        this.min = Double.POSITIVE_INFINITY;
        this.max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.nclass; c++) {
            this.min = Math.min(this.temp_values[c].min(), this.min);
            this.max = Math.max(this.temp_values[c].max(), this.max);
        }
        this.dx = (this.max - this.min) / numBins;

        // Accumulate each example into its class's bin.
        for (int e = 0; e < featureDataSet.size(); e++) {
            int classId = featureDataSet.get(e, this.output_index).getValueId();
            int bin = findBin(featureDataSet.get(e, this.findex).getDValue());
            this.probs[classId][bin] += 1.0d;
            this.counts[classId][bin] += 1;
        }

        // Normalize each class's row into a probability distribution.
        for (int c = 0; c < this.nclass; c++) {
            double total = 0.0d;
            for (int b = 0; b < nbins; b++) {
                total += this.probs[c][b];
            }
            for (int b = 0; b < nbins; b++) {
                this.probs[c][b] = this.probs[c][b] / total;
            }
        }
    }

    /** Returns the trained lower bound of the range; the argument is ignored. */
    public double getMin(int i) {
        return this.min;
    }

    /** Returns the trained upper bound of the range; the argument is ignored. */
    public double getMax(int i) {
        return this.max;
    }

    /** Returns the trained interior bin width; the argument is ignored. */
    public double getDx(int i) {
        return this.dx;
    }

    /**
     * Maps a feature value to its bin index.
     *
     * @param d feature value
     * @return {@code 0} if {@code d <= min}, {@code numBins + 1} if {@code d > max}, otherwise an
     *     interior bin in {@code [1, numBins]}
     */
    public int findBin(double d) {
        if (d <= this.min) {
            return 0;
        }
        if (d > this.max) {
            return numBins + 1;
        }
        // d lies in (min, max]. Rounding in (d - min) / dx can land just above numBins when
        // d == max, which would misfile it into the "> max" overflow bin; clamp to the last
        // interior bin instead.
        return Math.min(numBins, (int) Math.ceil((d - this.min) / this.dx));
    }

    /**
     * Returns P(feature value falls in d's bin | class {@code i}).
     *
     * @param i class index
     * @param d feature value
     */
    public double getProb(int i, double d) {
        return this.probs[i][findBin(d)];
    }

    /** Looks up the probability of {@code example}'s {@code findex} feature under class {@code i}. */
    @Override
    public double getProb(int i, Example example) {
        return getProb(i, example.get(this.findex).getDValue());
    }

    /**
     * Serializes the histogram as a {@code <Probability type="Histogram">} element.
     *
     * <p>Each class's bins are written as one line of {@code <prob/>} entries.
     */
    @Override
    public String toXMLString() {
        StringBuilder sb = new StringBuilder();
        sb.append("<Probability type=\"Histogram\">\n   ");
        sb.append(this.fid.toXMLString());
        for (int c = 0; c < this.nclass; c++) {
            for (int b = 0; b < numBins + 2; b++) {
                sb.append("   <prob class_id=\"").append(c)
                        .append("\" bin_id=\"").append(b)
                        .append("\" prob=\"").append(this.probs[c][b])
                        .append("\"/>");
            }
            sb.append("\n");
        }
        // Proper XML end tag (was previously emitted as "<\Probability>").
        sb.append("</Probability>\n");
        return sb.toString();
    }

    /**
     * Draws the histograms of class 0 (blue) and, when present, class 1 (red) side by side in
     * each bin. The plot includes labelled x- and y-axes and the feature name.
     *
     * <p>The plot occupies the rectangle at ({@code x}, {@code y}) of size {@code width} x
     * {@code height}, with a 50-pixel margin on every side reserved for labels.
     *
     * @param graphics target graphics context
     * @param x left edge of the drawing area
     * @param y top edge of the drawing area
     * @param width total width of the drawing area, including margins
     * @param height total height of the drawing area, including margins
     */
    public void paint(Graphics graphics, int x, int y, int width, int height) {
        final int margin = 100 / 2;
        final int nbins = numBins + 2;
        final int nTicks = 5;

        int plotWidth = width - 2 * margin;
        int plotHeight = height - 2 * margin;
        int binWidth = plotWidth / nbins;
        int left = x + margin;
        int baseline = y + plotHeight + margin; // y coordinate of the horizontal axis

        // Bars: class 0 in the left half of each bin, class 1 in the right half.
        for (int b = 0; b < nbins; b++) {
            int bx = binWidth * b + left;
            int h0 = Math.max(0, (int) (plotHeight * this.probs[0][b]));
            graphics.setColor(Color.blue);
            graphics.fillRect(bx, baseline - h0, binWidth / 2, h0);
            if (this.nclass > 1) { // avoid indexing a non-existent second class
                int h1 = Math.max(0, (int) (plotHeight * this.probs[1][b]));
                graphics.setColor(Color.red);
                graphics.fillRect(bx + binWidth / 2, baseline - h1, binWidth / 2, h1);
            }
        }

        // Horizontal axis with one labelled tick per bin.
        graphics.setColor(Color.black);
        graphics.drawLine(left, baseline, left + plotWidth, baseline);
        for (int b = 0; b < nbins; b++) {
            int tx = binWidth * b + left;
            int tickEnd = baseline + 10;
            graphics.drawLine(tx, baseline, tx, tickEnd);
            if (b == 0) {
                graphics.drawString("<=" + DoubleUtil.printDecimal(this.min, 3), tx + 1, tickEnd + 1);
            } else if (b == nbins - 1) {
                graphics.drawString(">" + DoubleUtil.printDecimal(this.max, 3), tx + 1, tickEnd + 1);
            } else {
                // Label interior bins with their midpoint.
                graphics.drawString(
                        DoubleUtil.printDecimal((this.min + (this.dx * b)) - (this.dx / 2.0d), 3),
                        tx + 1, tickEnd + 1);
            }
        }

        // Vertical axis with nTicks evenly spaced probability ticks (0 .. 1).
        graphics.drawLine(left, baseline, left, y + margin);
        // Floating-point division so the top tick lands exactly at the top of the plot.
        double tickSpacing = (double) plotHeight / (nTicks - 1);
        for (int t = 0; t < nTicks; t++) {
            int ty = y + (plotHeight - (int) (tickSpacing * t)) + margin;
            int tickStart = left - 10;
            graphics.drawLine(tickStart, ty, left, ty);
            graphics.drawString(String.valueOf((1.0d / (nTicks - 1)) * t), tickStart - 20, ty);
        }

        graphics.drawString(this.fid.printName(), left + 10, y + margin);
    }

    /** Intentionally empty entry point. */
    public static void main(String[] strArr) throws Exception {
    }
}
