package com.sjm.machlearn.classifiers.naivebayes;

import com.sjm.machlearn.classifiers.Classifier;
import com.sjm.machlearn.dataset.DataSet;
import com.sjm.machlearn.dataset.Example;
import com.sjm.machlearn.dataset.Feature;
import com.sjm.machlearn.exceptions.InvalidFeature;
import com.sjm.machlearn.util.Util;

/* loaded from: input_file:com/sjm/machlearn/classifiers/naivebayes/NaiveBayes.class */
/**
 * Naive Bayes classifier.
 *
 * <p>Both the training data and every example being classified are first passed through a
 * {@link Cont2DiscConverter} configured with {@link #nbins} bins (10 by default). Judging by its
 * name and {@link #printClassifier()}, it discretizes continuous features. Training fills one
 * {@link FeatureCounts} for the output feature, plus one {@code FeatureCounts} per
 * (output value, feature) pair. Classification returns the output value with the highest score:
 * the log of the output probability plus the sum of the logs of each other feature's conditional
 * probability.
 */
public class NaiveBayes extends Classifier {
    /** Counts of the output feature's values seen in training; queried as the class prior. */
    protected FeatureCounts outputProbs;
    /**
     * Conditional counts indexed {@code [outputValueId][featureIndex]}. The entry at the output
     * feature's own index is allocated but never filled or read.
     */
    protected FeatureCounts[][] featureProbs;
    /** Training data after conversion by {@link #converter}; {@code null} until {@link #train}. */
    protected DataSet trainingData;
    /** Number of bins handed to the {@link Cont2DiscConverter}. */
    protected int nbins;
    /** Converter applied to the training set and to every example passed to {@link #classify}. */
    protected Cont2DiscConverter converter;

    /** Creates a classifier that uses 10 bins. */
    public NaiveBayes() {
        this(10);
    }

    /**
     * Creates a classifier with the given number of bins.
     *
     * @param i number of bins passed to the {@link Cont2DiscConverter}
     */
    public NaiveBayes(int i) {
        setNumBins(i);
    }

    /**
     * Predicts the output value of {@code example}.
     *
     * <p>Scores are accumulated in log space. A plain product of many probabilities underflows to
     * 0.0, which makes every class tie. Because {@code log} is monotonic, the winner is unchanged
     * wherever the product would not have underflowed.
     *
     * @param example example to classify; converted with the same converter used in training
     * @return a copy of the example's output feature whose value is set to the predicted value id
     * @throws IllegalStateException if {@link #train} has not been called yet
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public Feature classify(Example example) {
        if (this.outputProbs == null || this.trainingData == null) {
            throw new IllegalStateException("NaiveBayes.classify called before train");
        }
        Example convert = this.converter.convert(example);
        Feature outputFeature = convert.getOutputFeature();
        int outputIndex = this.trainingData.getOutputIndex();
        Feature feature = new Feature(outputFeature);
        double[] logProbs = new double[outputFeature.numValues()];
        for (int i = 0; i < logProbs.length; i++) {
            try {
                double score = Math.log(this.outputProbs.getProb(i));
                for (int i2 = 0; i2 < convert.size(); i2++) {
                    if (i2 != outputIndex) {
                        // A zero probability gives -Infinity, which rules the class out just as
                        // a zero factor did in the product form.
                        score += Math.log(
                                this.featureProbs[i][i2].getProb(convert.get(i2).getValueId()));
                    }
                }
                logProbs[i] = score;
            } catch (InvalidFeature e) {
                internalError(e);
                // If internalError returns, never prefer a class whose score is incomplete.
                logProbs[i] = Double.NEGATIVE_INFINITY;
            }
        }
        feature.setValue(indexOfMax(logProbs));
        debugMesg("Log prob array");
        if (this.debug) {
            Util.printArray(logProbs);
        }
        debugMesg("answer is:" + feature.getValue());
        debugMesg("correct is:" + outputFeature.getValue());
        return feature;
    }

    /**
     * Returns the index of the first maximal element of {@code values}. Ties go to the lowest
     * index, and an empty array yields 0. Unlike a max that starts from 0 or
     * {@code Double.MIN_VALUE}, this works correctly for the negative and -Infinity values that
     * log-space scores contain.
     */
    private static int indexOfMax(double[] values) {
        int best = 0;
        for (int k = 1; k < values.length; k++) {
            if (values[k] > values[best]) {
                best = k;
            }
        }
        return best;
    }

    /**
     * Returns a new, untrained classifier with the same number of bins.
     *
     * @return fresh {@code NaiveBayes} instance
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        return new NaiveBayes(this.nbins);
    }

    /**
     * Allocates empty {@link #outputProbs} and {@link #featureProbs} sized from
     * {@link #trainingData}. The first training example is used as the template for each
     * feature's {@code FeatureCounts}.
     *
     * @throws IllegalArgumentException if the training data contains no examples
     */
    protected void initializeFeatures() {
        if (this.trainingData.size() == 0) {
            throw new IllegalArgumentException("NaiveBayes cannot be trained on an empty data set");
        }
        Feature outputFeature = this.trainingData.getOutputFeature();
        int numFeatures = this.trainingData.numFeatures();
        int numOutputValues = outputFeature.numValues();
        this.outputProbs = new FeatureCounts(outputFeature);
        this.featureProbs = new FeatureCounts[numOutputValues][numFeatures];
        Example template = this.trainingData.get(0);
        for (int i = 0; i < numOutputValues; i++) {
            for (int i2 = 0; i2 < numFeatures; i2++) {
                this.featureProbs[i][i2] = new FeatureCounts(template.get(i2));
            }
        }
    }

    /**
     * Describes this classifier's configuration.
     *
     * @return a short human-readable description including the bin count
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        StringBuilder description = new StringBuilder();
        description.append("Naive Bayes\n");
        description.append("Number of bins for continuous:").append(this.nbins);
        return description.toString();
    }

    /**
     * Sets the number of bins and replaces {@link #converter} with a new one using that count.
     * Call this before {@link #train}; replacing the converter after training discards whatever
     * state the old converter held.
     *
     * @param i number of bins
     */
    public void setNumBins(int i) {
        this.nbins = i;
        this.converter = new Cont2DiscConverter(i);
    }

    /**
     * Sets the only tunable parameter, the number of bins.
     *
     * @param i parameter index (ignored; there is only one parameter)
     * @param obj new bin count; any {@link Number} is accepted and truncated via
     *     {@link Number#intValue()}
     * @throws ClassCastException if {@code obj} is not a {@code Number}
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
        setNumBins(((Number) obj).intValue());
    }

    /**
     * Converts {@code dataSet} and accumulates output and conditional feature counts from it.
     *
     * @param dataSet training examples; must contain at least one example
     * @throws IllegalArgumentException if the converted data set is empty
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public void train(DataSet dataSet) {
        this.trainingData = this.converter.convert(dataSet);
        initializeFeatures();
        // Take all sizes from the converted set so the loop bounds match the data being indexed.
        int numFeatures = this.trainingData.numFeatures();
        int outputIndex = this.trainingData.getOutputIndex();
        for (int i = 0; i < this.trainingData.size(); i++) {
            Example example = this.trainingData.get(i);
            for (int i2 = 0; i2 < numFeatures; i2++) {
                Feature feature = example.get(i2);
                if (i2 == outputIndex) {
                    try {
                        this.outputProbs.addCount(feature.getValueId());
                    } catch (InvalidFeature e) {
                        internalError(e);
                    }
                } else {
                    this.featureProbs[example.get(outputIndex).getValueId()][i2]
                            .addCount(feature.getValueId());
                }
            }
        }
        debugMesg("output:\n" + this.outputProbs);
    }

    /**
     * Parameter tuning hook. Not implemented: this method does nothing.
     *
     * @param dataSet unused
     * @param objArr unused
     */
    public void tune(DataSet dataSet, Object[] objArr) {
    }
}
