package edu.wisc.sjm.machlearn.classifiers.naivebayes.probability;

import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.jutil.xml.XMLUtil;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.FeatureId;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.policy.fdspreprocessor.discretize.BinaryDiscretizeIG;
import java.util.Arrays;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/naivebayes/probability/ProbabilityDiscrete.class */
/**
 * Naive Bayes conditional probability table (CPT) for a single feature.
 *
 * <p>A discrete feature gets one column per value. A continuous feature is binarized at a
 * threshold computed by {@link BinaryDiscretizeIG#calcThreshold}. In that case the table has
 * two columns: bin 0 for values below the threshold and bin 1 for values at or above it.
 *
 * <p>{@code probs[c][v]} holds P(column v | class c) after training. {@code priors[c][v]} holds
 * the prior values that training starts from, scaled by the inherited {@code dm}.
 * {@code counts[c][v]} holds the raw example counts from the most recent training call.
 */
public class ProbabilityDiscrete extends Probability {
    /** Conditional probabilities, indexed [class][column]. */
    double[][] probs;
    /** Prior values (scaled by {@code dm} at training time), indexed [class][column]. */
    double[][] priors;
    /** Raw example counts from the last training call, indexed [class][column]. */
    int[][] counts;
    /** Split point used when {@link #discretize} is true; values >= threshold go to column 1. */
    double threshold;
    /** True when the feature is binarized at {@link #threshold} instead of looked up by value id. */
    boolean discretize;
    /** Number of columns per class; always 2 when discretizing. */
    int nvalues;
    /** Number of output classes (rows). */
    int nclass;

    /** Creates an empty table; call {@link #buildCPT} or {@link #fromXML} before use. */
    public ProbabilityDiscrete() {
    }

    /**
     * Restores the table from XML produced by {@link #toXML(Document, Element)}.
     *
     * <p>NOTE(review): {@code discretize} and {@code threshold} are not part of the serialized
     * form. A table restored into a fresh instance therefore looks features up by value id.
     */
    @Override
    public void fromXML(Element element) throws Exception {
        this.nvalues = XMLUtil.getXMLInteger(element, "NValues", -1);
        this.probs = XMLUtil.getXMLDouble2DArray(element, "Probs", KStarConstants.FLOOR);
        this.priors = XMLUtil.getXMLDouble2DArray(element, "Priors", KStarConstants.FLOOR);
        this.counts = XMLUtil.getXMLInteger2DArray(element, "Counts", 0);
        // nclass is not serialized. Recover it from the table shape; otherwise every train()
        // loop after deserialization would run zero times.
        this.nclass = this.probs == null ? 0 : this.probs.length;
    }

    /** Serializes the column count and the probability, prior and count tables. */
    @Override
    public Element toXML(Document document, Element element) {
        Element xml = super.toXML(document, element);
        XMLUtil.setXMLValue(document, xml, "NValues", this.nvalues);
        XMLUtil.setXMLValue(document, xml, "Probs", this.probs);
        XMLUtil.setXMLValue(document, xml, "Priors", this.priors);
        XMLUtil.setXMLValue(document, xml, "Counts", this.counts);
        return xml;
    }

    /**
     * Trains the table from parallel vectors of class ids and column (feature value) ids.
     *
     * @param intVector class id of each example
     * @param intVector2 column id of each example, aligned with {@code intVector}
     */
    public void train(IntVector intVector, IntVector intVector2) {
        if (weightExamples && output_counts == null) {
            output_counts = new int[this.nclass];
            for (int i = 0; i < intVector.size(); i++) {
                output_counts[intVector.get(i)]++;
            }
        }
        resetToPriors();
        for (int i = 0; i < intVector.size(); i++) {
            addExample(intVector.get(i), intVector2.get(i));
        }
        normalizeRows();
    }

    /**
     * Trains the table from the feature at {@code findex} of every example in the data set.
     * When discretizing, the split threshold is recomputed from this data set first.
     */
    @Override
    public void train(FeatureDataSet featureDataSet) {
        double weight = 1.0d;
        if (weightExamples) {
            if (output_counts == null) {
                output_counts = new int[featureDataSet.getOutputFeatureId().numValues()];
                for (int i = 0; i < featureDataSet.size(); i++) {
                    output_counts[featureDataSet.getOutputFeature(i).getValueId()]++;
                }
            }
            weight = classOneWeight(output_counts);
        }
        if (this.discretize) {
            this.threshold = BinaryDiscretizeIG.calcThreshold(featureDataSet, this.findex);
        }
        resetToPriors();
        for (int i = 0; i < featureDataSet.size(); i++) {
            int cls = featureDataSet.getOutputFeature(i).getValueId();
            Feature feature = featureDataSet.get(i, this.findex);
            addExample(cls, binOf(feature));
        }
        // NOTE(review): normalizeRows() rescales each class row independently, so a uniform
        // factor applied to row 1 cancels out. This weighting therefore does not affect probs.
        if (this.nclass > 1) {
            for (int v = 0; v < this.nvalues; v++) {
                this.probs[1][v] *= weight;
            }
        }
        normalizeRows();
    }

    /**
     * Returns the ratio of class-0 to class-1 counts. Returns 1.0 when the ratio is undefined
     * (fewer than two classes, or a zero count). Floating-point division is used because
     * integer division truncated the ratio, often to 0, which led to NaN rows.
     */
    private static double classOneWeight(int[] classCounts) {
        if (classCounts.length < 2 || classCounts[0] <= 0 || classCounts[1] <= 0) {
            return 1.0d;
        }
        return (double) classCounts[0] / classCounts[1];
    }

    /** Starts a training pass: sets probs to dm-scaled priors and clears the counts. */
    private void resetToPriors() {
        for (int c = 0; c < this.nclass; c++) {
            for (int v = 0; v < this.nvalues; v++) {
                this.probs[c][v] = this.priors[c][v] * dm;
            }
            // Counts must be cleared too; otherwise retraining accumulates stale counts.
            Arrays.fill(this.counts[c], 0);
        }
    }

    /** Records one training example of class {@code cls} that falls in column {@code bin}. */
    private void addExample(int cls, int bin) {
        this.probs[cls][bin] += 1.0d;
        this.counts[cls][bin]++;
    }

    /**
     * Rescales each class row of probs so it sums to 1. A row whose sum is not positive is
     * left unchanged rather than divided by zero, which would fill it with NaN.
     */
    private void normalizeRows() {
        for (int c = 0; c < this.nclass; c++) {
            double sum = 0.0d;
            for (int v = 0; v < this.nvalues; v++) {
                sum += this.probs[c][v];
            }
            if (sum <= 0.0d) {
                continue;
            }
            for (int v = 0; v < this.nvalues; v++) {
                this.probs[c][v] /= sum;
            }
        }
    }

    /** Maps a feature to its column: the threshold bin when discretizing, else its value id. */
    private int binOf(Feature feature) {
        if (this.discretize) {
            return feature.getDValue() >= this.threshold ? 1 : 0;
        }
        return feature.getValueId();
    }

    /**
     * Allocates the table and fills the priors with a uniform distribution over the columns.
     *
     * @param i number of classes (rows)
     * @param i2 number of feature values; ignored when {@code z} is true
     * @param z whether to binarize the feature at a threshold (forces two columns)
     */
    public void buildCPT(int i, int i2, boolean z) {
        this.nvalues = z ? 2 : i2;
        this.discretize = z;
        this.nclass = i;
        // Size the columns by nvalues, not i2. When discretizing only bins 0 and 1 are used,
        // and i2 (the value count reported for a continuous feature) need not be 2.
        this.probs = new double[i][this.nvalues];
        this.priors = new double[i][this.nvalues];
        this.counts = new int[i][this.nvalues];
        double uniform = 1.0d / this.nvalues;
        for (double[] row : this.priors) {
            Arrays.fill(row, uniform);
        }
    }

    /**
     * Builds a table for input feature {@code featureId2} given output feature {@code featureId}.
     * A continuous input feature is binarized.
     */
    public ProbabilityDiscrete(FeatureId featureId, FeatureId featureId2) {
        buildCPT(featureId.numValues(), featureId2.numValues(), featureId2.isContinuous());
    }

    /** Convenience overload of {@link #ProbabilityDiscrete(FeatureId, FeatureId)}. */
    public ProbabilityDiscrete(Feature feature, Feature feature2) {
        this(feature.getFeatureId(), feature2.getFeatureId());
    }

    /**
     * Builds a table for feature {@code i} of the data set against its output feature.
     * Discretization is decided by whether feature {@code i} (not the output) is continuous,
     * matching the {@code FeatureId} constructor.
     */
    public ProbabilityDiscrete(FeatureDataSet featureDataSet, int i) {
        super(featureDataSet, i);
        buildCPT(featureDataSet.getOutputFeatureId().numValues(),
                featureDataSet.getFeatureId(i).numValues(),
                featureDataSet.getFeatureId(i).isContinuous());
    }

    /** Returns P(column {@code i2} | class {@code i}). */
    public double getProb(int i, int i2) {
        return this.probs[i][i2];
    }

    /** Returns the probability of {@code example}'s feature at {@code findex} given class {@code i}. */
    @Override
    public double getProb(int i, Example example) {
        return getProb(i, example.get(this.findex));
    }

    /** Returns the probability of {@code feature}'s column given class {@code i}. */
    public double getProb(int i, Feature feature) {
        return getProb(i, binOf(feature));
    }

    /** Returns the number of columns per class. */
    public int getNumValues() {
        return this.nvalues;
    }

    /** Overwrites P(column {@code i2} | class {@code i}). */
    public void setProb(int i, int i2, double d) {
        this.probs[i][i2] = d;
    }

    /** Overwrites the prior for class {@code i}, column {@code i2}. */
    public void setPrior(int i, int i2, double d) {
        this.priors[i][i2] = d;
    }

    /** Returns the prior for class {@code i}, column {@code i2}. */
    public double getPrior(int i, int i2) {
        return this.priors[i][i2];
    }

    /** Renders the table as an XML fragment, one {@code <prob .../>} element per cell. */
    @Override
    public String toXMLString() {
        StringBuilder sb = new StringBuilder();
        sb.append("<Probability type=\"discrete\">\n   ");
        sb.append(this.fid.toXMLString());
        for (int c = 0; c < this.probs.length; c++) {
            for (int v = 0; v < this.probs[c].length; v++) {
                sb.append("   <prob class=\"").append(this.ofid.getValue(c)).append('"');
                if (this.discretize) {
                    // Column 1 holds values >= threshold (see binOf). The comparison operators
                    // are escaped so the attribute value stays well-formed XML.
                    sb.append(" feature_value=\"").append(v == 0 ? "&lt; " : "&gt;= ")
                            .append(this.threshold).append('"');
                } else {
                    sb.append(" feature_value=\"").append(this.fid.getValue(v)).append('"');
                }
                sb.append(" prob=\"").append(DoubleUtil.printDecimal(this.probs[c][v], 5))
                        .append("\" counts=\"").append(this.counts[c][v]).append("\"/>\n");
            }
        }
        sb.append("</Probability>\n");
        return sb.toString();
    }

    /** Currently a no-op: parsing the string form produced by {@link #toXMLString()} is not implemented. */
    @Override
    public void fromXMLString(String str) {
    }
}
