package edu.wisc.sjm.machlearn.classifiers.naivebayes;

import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/naivebayes/NaiveBayesLog2.class */
/**
 * Two-class naive Bayes classifier over discrete (value id 0/1) features.
 *
 * <p>Training counts, separately for the positive class (output value id 1) and the negative
 * class (any other output value id), how often each feature takes value id 0 or 1. It converts
 * those counts into probabilities with the m-estimate {@code (count + m * prior) / (total + m)}.
 *
 * <p>Classification starts from the two class probabilities and multiplies in one feature
 * likelihood at a time. After each feature it renormalizes the two running scores so they sum
 * to 1. Despite the class name, nothing here is stored in log space. The per-step
 * renormalization presumably serves the same purpose of avoiding floating-point underflow.
 *
 * <p>Input feature value ids must be 0 or 1: the probability tables have exactly two columns, so
 * any other value id indexes out of bounds.
 */
public class NaiveBayesLog2 extends Classifier {
    /** Default number of bins for continuous-to-discrete conversion. */
    private static final int DEFAULT_NUM_BINS = 10;
    /** Default equivalent sample size {@code m} of the m-estimate. */
    private static final double DEFAULT_M = 2.0d;
    /** Default prior estimate {@code p} of the m-estimate. */
    private static final double DEFAULT_PRIOR = 0.5d;
    /** Default minimum positive-class posterior required to predict the positive class. */
    private static final double DEFAULT_THRESHOLD = 0.5d;

    // NOTE(review): outputProbs, featureProbs, trainingData and converter are not read anywhere
    // in this class; presumably kept for subclasses or legacy use — confirm before removing.
    protected FeatureCounts outputProbs;
    protected FeatureCounts[][] featureProbs;
    protected FeatureDataSet trainingData;
    /** Number of bins for continuous-to-discrete conversion (reported by printClassifier). */
    protected int nbins;
    protected Cont2DiscConverter converter;
    /** Prior estimate {@code p} used in the m-estimate. */
    protected double prior;
    /** Equivalent sample size {@code m} used in the m-estimate. */
    protected double m;
    /** Smoothed probability of the positive class (output value id 1). */
    protected double p_pos;
    /** Smoothed probability of the negative class (any other output value id). */
    protected double p_neg;
    /** {@code p_fpos[f][v]}: smoothed P(feature f has value id v | positive class). */
    protected double[][] p_fpos;
    /** {@code p_fneg[f][v]}: smoothed P(feature f has value id v | negative class). */
    protected double[][] p_fneg;
    /** The positive class is predicted when its posterior is at least this value. */
    protected double threshold;

    /** Creates a classifier with 10 bins, m = 2, prior = 0.5 and threshold 0.5. */
    public NaiveBayesLog2() {
        this(DEFAULT_NUM_BINS);
    }

    /**
     * Creates a classifier with the given number of bins, m = 2, prior = 0.5 and threshold 0.5.
     *
     * @param i number of bins for continuous-to-discrete conversion
     */
    public NaiveBayesLog2(int i) {
        setNumBins(i);
        // Previously only the no-arg constructor set a threshold. Here it stayed at 0.0, which
        // made every example classify as positive (posterior >= 0.0 always holds).
        this.threshold = DEFAULT_THRESHOLD;
    }

    /**
     * Sets the bin count and rebuilds the converter.
     *
     * <p>Note that this also resets {@code m} and {@code prior} to their defaults.
     *
     * @param i number of bins for continuous-to-discrete conversion
     */
    public void setNumBins(int i) {
        this.nbins = i;
        this.converter = new Cont2DiscConverter(i);
        this.m = DEFAULT_M;
        this.prior = DEFAULT_PRIOR;
    }

    /**
     * Sets the minimum positive-class posterior required to predict the positive class.
     *
     * @param d decision threshold
     */
    public void setThreshold(double d) {
        this.threshold = d;
    }

    /**
     * Parses and sets the decision threshold.
     *
     * @param str decimal representation of the threshold
     * @throws NumberFormatException if {@code str} is not a valid double
     */
    public void setThreshold(String str) {
        setThreshold(Double.parseDouble(str));
    }

    /**
     * Estimates the class probabilities and the per-class feature value probabilities.
     *
     * @param featureDataSet training data; input feature value ids must be 0 or 1
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void train_(FeatureDataSet featureDataSet) {
        int numFeatures = featureDataSet.numFeatures();
        int numExamples = featureDataSet.size();
        int[][] posCounts = new int[numFeatures][2];
        int[][] negCounts = new int[numFeatures][2];
        int numPos = 0;
        int numNeg = 0;

        for (int e = 0; e < numExamples; e++) {
            int[][] counts;
            if (featureDataSet.getOutputFeature(e).getValueId() == 1) {
                numPos++;
                counts = posCounts;
            } else {
                numNeg++;
                counts = negCounts;
            }
            // If the output feature is among the indexed features it is counted here as well;
            // classification skips the output index, so that column is never consulted.
            for (int f = 0; f < numFeatures; f++) {
                counts[f][featureDataSet.get(e, f).getValueId()]++;
            }
        }

        this.p_pos = getLProb(numPos, numExamples);
        this.p_neg = getLProb(numNeg, numExamples);

        double[][] fpos = new double[numFeatures][2];
        double[][] fneg = new double[numFeatures][2];
        for (int f = 0; f < numFeatures; f++) {
            for (int v = 0; v < 2; v++) {
                fpos[f][v] = getLProb(posCounts[f][v], numPos);
                fneg[f][v] = getLProb(negCounts[f][v], numNeg);
            }
        }
        this.p_fpos = fpos;
        this.p_fneg = fneg;
    }

    /**
     * Returns the m-estimate {@code (i + m * prior) / (i2 + m)}.
     *
     * <p>Despite the name, this is a plain (non-logarithmic) probability. With {@code m > 0} it
     * is defined even when {@code i2 == 0}, and then evaluates to {@code prior}.
     *
     * @param i number of matching observations
     * @param i2 total number of observations
     * @return smoothed probability estimate
     */
    public double getLProb(int i, int i2) {
        return (i + (this.m * this.prior)) / (i2 + this.m);
    }

    /**
     * Predicts value id 1 when the positive-class posterior reaches the threshold, else 0.
     *
     * @param example example to classify
     * @return a new feature based on the example's output feature, holding the prediction
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature classify_(Example example) {
        Feature feature = new Feature(example.getOutputFeature());
        double positive = posterior(example)[1];
        try {
            feature.setValueId(positive >= this.threshold ? 1 : 0);
        } catch (InvalidFeature e) {
            internalError(e);
        }
        return feature;
    }

    /**
     * Returns the normalized posterior distribution over the two classes.
     *
     * @param example example to score
     * @return {@code [P(negative), P(positive)]}
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public double[] getDistribution(Example example) {
        return posterior(example);
    }

    /**
     * Combines the class probabilities with each non-output feature's likelihood.
     *
     * <p>The two running scores are renormalized after every feature. A feature whose lookup
     * fails is reported on stdout and skipped. Such failures include a value id outside 0/1,
     * or an example with more features than the training set. This keeps the original
     * best-effort behavior.
     *
     * @param example example to score
     * @return {@code [negative score, positive score]}, summing to 1 once any feature was used
     */
    private double[] posterior(Example example) {
        int outputIndex = example.getOutputIndex();
        double pos = this.p_pos;
        double neg = this.p_neg;
        for (int i = 0; i < example.numFeatures(); i++) {
            if (i == outputIndex) {
                continue;
            }
            try {
                int valueId = example.get(i).getValueId();
                double posScore = pos * this.p_fpos[i][valueId];
                double negScore = neg * this.p_fneg[i][valueId];
                double total = posScore + negScore;
                pos = posScore / total;
                neg = negScore / total;
            } catch (Exception e) {
                System.out.println("Error during classification");
                e.printStackTrace();
            }
        }
        return new double[] {neg, pos};
    }

    /** Does nothing; this classifier has no tuning procedure. */
    public void tune(FeatureDataSet featureDataSet, Object[] objArr) {
    }

    /**
     * Returns a short human-readable description of this classifier.
     *
     * @return description including the bin count
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        StringBuilder sb = new StringBuilder();
        sb.append("Naive Bayes\n");
        sb.append("Number of bins for continuous:").append(this.nbins);
        return sb.toString();
    }

    /**
     * Returns a new, untrained classifier of this same type with the same bin count and
     * threshold.
     *
     * @return an untrained copy of this classifier's configuration
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        // Previously returned a NaiveBayes, i.e. a different classifier class, and lost the
        // configured threshold.
        NaiveBayesLog2 copy = new NaiveBayesLog2(this.nbins);
        copy.setThreshold(this.threshold);
        return copy;
    }

    /**
     * Sets the number of bins from a numeric parameter value (truncated toward zero).
     *
     * <p>The parameter index is ignored: the bin count is the only settable parameter. Any
     * {@link Number} is accepted; previously only {@link Double} was.
     *
     * @param i parameter index (unused)
     * @param obj the new bin count as a {@link Number}
     * @throws ClassCastException if {@code obj} is not a {@link Number}
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
        setNumBins(((Number) obj).intValue());
    }
}
