package edu.wisc.sjm.machlearn.classifiers.naivebayes;

import edu.wisc.sjm.jutil.math.JMath;
import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;
import edu.wisc.sjm.machlearn.util.Util;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/naivebayes/NaiveBayesLog.class */
/**
 * Naive Bayes classifier that scores each candidate output value by summing
 * base-2 logarithms of probabilities instead of multiplying the raw
 * probabilities.
 *
 * <p>Training and test data are first passed through a
 * {@link Cont2DiscConverter} configured with {@link #nbins} bins
 * (presumably to discretise continuous features; confirm against the
 * converter's implementation).
 */
public class NaiveBayesLog extends Classifier {
    /** Counts over the values of the output feature, filled in by {@link #train_}. */
    protected FeatureCounts outputProbs;

    /** Per-feature counts, indexed as {@code [outputValueId][featureIndex]}. */
    protected FeatureCounts[][] featureProbs;

    /** The converted training set produced by {@link #converter} in {@link #train_}. */
    protected FeatureDataSet trainingData;

    /** Number of bins handed to the {@link Cont2DiscConverter}. */
    protected int nbins;

    /** Converter applied to both the training set and every example to classify. */
    protected Cont2DiscConverter converter;

    /** Creates a classifier with the default of 10 bins. */
    public NaiveBayesLog() {
        this(10);
    }

    /**
     * Creates a classifier with the given number of bins.
     *
     * @param numBins number of bins passed to the converter
     */
    public NaiveBayesLog(int numBins) {
        setNumBins(numBins);
    }

    /**
     * Sets the bin count and replaces the converter with a new one built for it.
     * Previously trained counts are not touched by this call.
     *
     * @param numBins number of bins passed to the new converter
     */
    public void setNumBins(int numBins) {
        this.nbins = numBins;
        this.converter = new Cont2DiscConverter(numBins);
    }

    /**
     * Allocates fresh count tables for the current {@link #trainingData}: one
     * table for the output feature, and one table per (output value, feature)
     * pair. The first training example is used as the template for each
     * feature's table.
     */
    protected void initializeFeatures() {
        Feature outputFeature = this.trainingData.getOutputFeature();
        int numClasses = outputFeature.numValues();
        int numFeatures = this.trainingData.numFeatures();
        this.outputProbs = new FeatureCounts(outputFeature);
        this.featureProbs = new FeatureCounts[numClasses][numFeatures];
        Example template = this.trainingData.getExample(0);
        for (int classId = 0; classId < numClasses; classId++) {
            for (int featureIdx = 0; featureIdx < numFeatures; featureIdx++) {
                // NOTE(review): the 1.0 argument looks like an initial/smoothing
                // count so unseen values do not get zero probability - confirm
                // against FeatureCounts.
                this.featureProbs[classId][featureIdx] = new FeatureCounts(template.get(featureIdx), 1.0d);
            }
        }
    }

    /**
     * Converts the data set, resets the count tables and accumulates counts
     * from every converted example.
     *
     * @param featureDataSet the raw training data
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void train_(FeatureDataSet featureDataSet) {
        this.trainingData = this.converter.convert(featureDataSet);
        initializeFeatures();
        // Bounds come from the converted set: it is the one the examples are
        // read from and the one featureProbs was sized against.
        int numExamples = this.trainingData.size();
        int numFeatures = this.trainingData.numFeatures();
        int outputIndex = this.trainingData.getOutputIndex();
        for (int exampleIdx = 0; exampleIdx < numExamples; exampleIdx++) {
            Example example = this.trainingData.getExample(exampleIdx);
            for (int featureIdx = 0; featureIdx < numFeatures; featureIdx++) {
                Feature feature = example.get(featureIdx);
                if (featureIdx == outputIndex) {
                    try {
                        this.outputProbs.addCount(feature.getValueId());
                    } catch (InvalidFeature e) {
                        internalError(e);
                    }
                } else {
                    // Count this feature value under the example's own output value.
                    this.featureProbs[example.get(outputIndex).getValueId()][featureIdx].addCount(feature.getValueId());
                }
            }
        }
        debugMesg("output:\n" + this.outputProbs);
    }

    /**
     * Classifies an example by computing, for every output value, the log2 of
     * its probability plus the log2 of each non-output feature value's
     * probability under that output value, and picking the arg-max.
     *
     * @param example the raw example to classify
     * @return a copy of the example's output feature whose value is set to the
     *         highest-scoring output value
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature classify_(Example example) {
        Example converted = this.converter.convert(example);
        Feature actualOutput = converted.getOutputFeature();
        int outputIndex = this.trainingData.getOutputIndex();
        Feature prediction = new Feature(actualOutput);
        double[] logScores = new double[actualOutput.numValues()];
        for (int classId = 0; classId < logScores.length; classId++) {
            try {
                logScores[classId] = JMath.log2(this.outputProbs.getProb(classId));
                for (int featureIdx = 0; featureIdx < converted.size(); featureIdx++) {
                    if (featureIdx != outputIndex) {
                        logScores[classId] += JMath.log2(
                                this.featureProbs[classId][featureIdx].getProb(converted.get(featureIdx).getValueId()));
                    }
                }
            } catch (InvalidFeature e) {
                // Reported via internalError; the score keeps whatever was
                // accumulated so far and the remaining classes are still scored.
                internalError(e);
            }
        }
        prediction.setValue(Util.argmax(logScores));
        debugMesg("Prob array");
        debugMesg(Util.printArray(logScores));
        debugMesg("answer is:" + prediction.getValue());
        debugMesg("correct is:" + actualOutput.getValue());
        return prediction;
    }

    /**
     * Intentionally a no-op: this classifier performs no tuning.
     *
     * @param featureDataSet ignored
     * @param objArr ignored
     */
    public void tune(FeatureDataSet featureDataSet, Object[] objArr) {
    }

    /**
     * Returns a short description of the classifier and its bin count.
     *
     * @return the description text
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        StringBuilder description = new StringBuilder();
        description.append("Naive Bayes\n");
        description.append("Number of bins for continuous:").append(this.nbins);
        return description.toString();
    }

    /**
     * Creates a new, untrained classifier of this same type with the same bin
     * count.
     *
     * @return a fresh {@code NaiveBayesLog}
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        // Must be NaiveBayesLog, not NaiveBayes: a clone has to keep the
        // log-space scoring of the instance it was cloned from.
        return new NaiveBayesLog(this.nbins);
    }

    /**
     * Sets the single tunable parameter, the number of bins. The parameter
     * index is ignored.
     *
     * @param index ignored
     * @param value a {@link Number} (e.g. {@code Double}) whose value is
     *        truncated to an {@code int} bin count
     * @throws ClassCastException if {@code value} is not a {@code Number}
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void setParameter(int index, Object value) {
        // Number.intValue() on a Double is the same narrowing as (int) doubleValue().
        setNumBins(((Number) value).intValue());
    }
}
