package edu.wisc.sjm.machlearn.classifiers.naivebayes.probability;

import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.xml.XMLUtil;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import java.io.FileWriter;
import java.io.PrintWriter;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/naivebayes/probability/ProbabilityGaussian.class */
public class ProbabilityGaussian extends Probability {
    /** Normalisation constant of the normal density: 1 / sqrt(2*pi) (6.283185307179586 == 2*pi). */
    static double k = 1.0d / Math.sqrt(6.283185307179586d);
    /** Per-output-class mean of feature {@code findex}; indexed by output value id. */
    double[] mean;
    /** Per-output-class standard deviation of feature {@code findex}; indexed by output value id. */
    double[] std;

    /**
     * Creates an empty model. {@link #mean} and {@link #std} stay unallocated until
     * {@link #fromXML(Element)} or {@link #train(FeatureDataSet)} fills them.
     */
    public ProbabilityGaussian() {
    }

    /**
     * Creates a model for feature {@code i} of {@code featureDataSet}, with one (mean, std)
     * slot per value of the dataset's output feature.
     *
     * @param featureDataSet dataset whose output feature determines the number of classes
     * @param i index of the modelled feature (forwarded to the superclass constructor)
     */
    public ProbabilityGaussian(FeatureDataSet featureDataSet, int i) {
        super(featureDataSet, i);
        // Return value unused; the call is kept as in the original implementation.
        featureDataSet.getOutputIndex();
        int numValues = featureDataSet.getOutputFeatureId().numValues();
        this.mean = new double[numValues];
        this.std = new double[numValues];
    }

    /**
     * Restores the per-class means and standard deviations from the "mean" and "std" XML values.
     * The third argument to {@code XMLUtil.getXMLDoubleArray} is presumably the default used
     * for missing entries (TODO confirm against XMLUtil).
     *
     * @param element element previously produced by {@link #toXML(Document, Element)}
     * @throws Exception if {@code XMLUtil} fails to read the arrays
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public void fromXML(Element element) throws Exception {
        this.mean = XMLUtil.getXMLDoubleArray(element, "mean", KStarConstants.FLOOR);
        this.std = XMLUtil.getXMLDoubleArray(element, "std", 1.0d);
    }

    /**
     * Serialises this model by extending the superclass element with "mean" and "std" arrays.
     *
     * @param document owner document used to create nodes
     * @param element parent element handed to the superclass
     * @return the element returned by the superclass, with the two arrays attached
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public Element toXML(Document document, Element element) {
        Element xml = super.toXML(document, element);
        XMLUtil.setXMLValue(document, xml, "mean", this.mean);
        XMLUtil.setXMLValue(document, xml, "std", this.std);
        return xml;
    }

    /**
     * Estimates, for every output class, the mean and standard deviation of feature
     * {@code findex}.
     *
     * <p>Both statistics are smoothed toward the whole-dataset statistic with weight {@code cm}:
     * {@code mean_c = (sum_c + cm * globalMean) / (n_c + cm)} and
     * {@code std_c = (std_c * n_c + cm * globalStd) / (n_c + cm)}.
     * {@code 1.0E-5} is then added to each deviation, presumably so that
     * {@link #gaussianValue} never divides by zero for a constant-valued class.
     * {@code internalError} is invoked if either statistic comes out as NaN.
     *
     * @param featureDataSet training data; each example's output value id must be a valid
     *     index in {@code [0, numValues)} of the output feature
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public void train(FeatureDataSet featureDataSet) {
        // Return value unused; the call is kept as in the original implementation.
        featureDataSet.getOutputIndex();
        int numValues = featureDataSet.getOutputFeatureId().numValues();
        // The no-arg constructor leaves the arrays null, and a different dataset may have a
        // different number of classes: (re)allocate so every class slot is written below.
        if (this.mean == null || this.mean.length != numValues) {
            this.mean = new double[numValues];
        }
        if (this.std == null || this.std.length != numValues) {
            this.std = new double[numValues];
        }

        DoubleVector[] perClass = new DoubleVector[numValues];
        for (int c = 0; c < numValues; c++) {
            perClass[c] = new DoubleVector();
        }
        DoubleVector all = new DoubleVector();
        for (int row = 0; row < featureDataSet.size(); row++) {
            int classId = featureDataSet.getOutputFeature(row).getValueId();
            double value = featureDataSet.get(row, this.findex).getDValue();
            all.add(value);
            perClass[classId].add(value);
        }

        // The whole-data statistics do not change inside the loop below; compute them once.
        double globalMean = all.average();
        double globalStd = all.stddev();
        for (int c = 0; c < numValues; c++) {
            DoubleVector values = perClass[c];
            double classMean = (values.sum() + (cm * globalMean)) / (values.size() + cm);
            if (Double.isNaN(classMean)) {
                internalError("tmean is NAN!");
            }
            this.mean[c] = classMean;
            double classStd = ((values.stddev() * values.size()) + (cm * globalStd)) / (values.size() + cm);
            if (Double.isNaN(classStd)) {
                internalError("tstd is NAN!");
            }
            this.std[c] = classStd + 1.0E-5d;
        }
    }

    /**
     * Evaluates the normal probability density function.
     *
     * @param d point at which to evaluate
     * @param d2 mean of the distribution
     * @param d3 standard deviation of the distribution (a value of 0 yields NaN or infinity)
     * @return {@code (1 / (d3 * sqrt(2*pi))) * exp(-0.5 * (d - d2)^2 / d3^2)}
     */
    public static double gaussianValue(double d, double d2, double d3) {
        double d4 = d - d2;
        return (k / d3) * Math.exp((((-0.5d) * d4) * d4) / (d3 * d3));
    }

    /**
     * Returns the peak density of class {@code i}'s Gaussian, i.e. its value at the mean.
     *
     * @param i output class index
     * @return {@code gaussianValue(mean[i], mean[i], std[i])}
     */
    public double maxValue(int i) {
        return gaussianValue(this.mean[i], this.mean[i], this.std[i]);
    }

    /**
     * @param i output class index
     * @return the standard deviation stored for class {@code i}
     */
    public double getStd(int i) {
        return this.std[i];
    }

    /**
     * @param i output class index
     * @return the mean stored for class {@code i}
     */
    public double getMean(int i) {
        return this.mean[i];
    }

    /**
     * Returns the Gaussian density of the example's feature {@code findex} under class
     * {@code i}. This is a density, not a probability, so it can exceed 1 for small deviations.
     *
     * @param i output class index
     * @param example example whose feature value is evaluated
     * @return density value from {@link #gaussianValue}
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public double getProb(int i, Example example) {
        return gaussianValue(example.get(this.findex).getDValue(), this.mean[i], this.std[i]);
    }

    /**
     * Overwrites the parameters stored for class {@code i}.
     *
     * @param i output class index
     * @param d new mean
     * @param d2 new standard deviation
     */
    public void setProb(int i, double d, double d2) {
        this.mean[i] = d;
        this.std[i] = d2;
    }

    /**
     * Renders this model as a small XML fragment: a {@code <Probability type="Gaussian">}
     * element holding the feature id's XML and one empty {@code <prob .../>} element per class.
     *
     * @return the XML text
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public String toXMLString() {
        StringBuilder sb = new StringBuilder();
        sb.append("<Probability type=\"Gaussian\">\n   ");
        sb.append(this.fid.toXMLString());
        for (int i = 0; i < this.mean.length; i++) {
            sb.append("   <prob class_id=\"").append(i).append('"');
            sb.append(" mean=\"").append(this.mean[i]).append('"');
            // Empty-element tag must close with "/>" (the old "\>" was not well-formed XML).
            sb.append(" stddev=\"").append(this.std[i]).append("\"/>");
            sb.append('\n');
        }
        // End tag is "</Probability>", not "<\Probability>".
        sb.append("</Probability>\n");
        return sb.toString();
    }

    /**
     * Diagnostic entry point. It writes 1000 samples of the density over
     * {@code [mean - 4*std, mean + 4*std]} to {@code gauss.txt} (tab-separated x and density).
     * It then prints a trapezoidal-rule estimate of the area under the sampled curve, which
     * should be close to 1 for a positive std.
     *
     * @param strArr {@code <mean> <std>}
     * @throws Exception if the arguments are not numbers or the output file cannot be opened
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length != 2) {
            System.out.println("usage: ProbabilityGaussian <mean> <std>");
            System.exit(-1);
        }
        double mu = Double.parseDouble(strArr[0]);
        double sigma = Double.parseDouble(strArr[1]);
        final int samples = 1000;
        double lo = mu - (4.0d * sigma);
        double hi = mu + (4.0d * sigma);
        double step = (hi - lo) / (samples - 1);
        double area = 0.0d;
        // try-with-resources closes (and flushes) the writer even if writing fails.
        try (PrintWriter printWriter = new PrintWriter(new FileWriter("gauss.txt"))) {
            double previous = 0.0d;
            for (int i = 0; i < samples; i++) {
                // Compute x directly instead of accumulating, so no rounding drift builds up.
                double x = lo + (i * step);
                double y = gaussianValue(x, mu, sigma);
                printWriter.println(x + "\t" + y);
                // Trapezoid between consecutive printed samples: integrates exactly [lo, hi].
                if (i > 0) {
                    area += 0.5d * step * (previous + y);
                }
                previous = y;
            }
        }
        System.out.println("estimated area:" + area);
    }
}
