package edu.wisc.sjm.machlearn.classifiers.naivebayes.probability;

import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.util.APRUtil;
import edu.wisc.sjm.machlearn.util.Util;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/naivebayes/probability/ProbabilityMixedGaussian.class */
/**
 * Class-conditional density for one continuous feature, modelled for each output class as a
 * mixture of {@code k} one-dimensional Gaussians fitted by expectation-maximization (EM).
 *
 * <p>{@link #train(FeatureDataSet)} picks {@code k} from 2..9 by cross-validation (see
 * {@link #tune(FeatureDataSet)}) and then refits every class on the whole data set.
 */
public class ProbabilityMixedGaussian extends Probability {
    /** Component means, indexed {@code [class][component]}. */
    double[][] mean;
    /** Component variances, indexed {@code [class][component]}. */
    double[][] var;
    /** Mixing weights, indexed {@code [class][component]}. */
    double[][] weight;
    /** Values of feature {@code findex}, grouped by output class id (rebuilt by {@link #_train}). */
    DoubleVector[] values;
    /** Number of mixture components fitted per class. */
    int k;
    /** EM stops once it has run more than this many iterations. */
    int max_iter;

    /**
     * Creates an untrained mixture density with {@code k = 2} and {@code max_iter = 50}.
     *
     * @param featureDataSet data set forwarded to the {@link Probability} constructor
     * @param i forwarded to the {@link Probability} constructor (presumably the feature index
     *     later read as {@code findex} — TODO confirm)
     */
    public ProbabilityMixedGaussian(FeatureDataSet featureDataSet, int i) {
        super(featureDataSet, i);
        this.k = 2;
        this.max_iter = 50;
    }

    /**
     * Trains the density: chooses {@code k} by cross-validation, then fits the final model.
     *
     * @param featureDataSet training data
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public void train(FeatureDataSet featureDataSet) {
        tune(featureDataSet);
    }

    /**
     * Fits a {@link #k}-component mixture for every output class, using the current value of
     * {@code k}. Replaces {@link #mean}, {@link #var}, {@link #weight} and {@link #values}.
     *
     * @param featureDataSet training data; its output feature defines the set of classes
     */
    public void _train(FeatureDataSet featureDataSet) {
        this.nclass = featureDataSet.getOutputFeature().numValues();
        this.mean = new double[this.nclass][this.k];
        this.var = new double[this.nclass][this.k];
        this.weight = new double[this.nclass][this.k];
        this.values = new DoubleVector[this.nclass];
        for (int c = 0; c < this.nclass; c++) {
            this.values[c] = new DoubleVector(featureDataSet.size());
            // Start from an empty vector; values are appended below.
            this.values[c].empty();
        }
        for (int e = 0; e < featureDataSet.size(); e++) {
            this.values[featureDataSet.getOutputValueId(e)].add(featureDataSet.get(e, this.findex).getDValue());
        }
        for (int c = 0; c < this.nclass; c++) {
            MixedEM(this.values[c], this.mean[c], this.var[c], this.weight[c]);
        }
    }

    /**
     * Chooses the number of components {@code k} in [2, 9] that maximizes the cross-validated
     * score from {@code APRUtil.getAPR}. It then refits on the full data set with that {@code k}.
     * If no candidate scores above 0, {@code k} falls back to 1.
     *
     * <p>Only classes 0 and 1 are scored, so this assumes a binary output feature.
     *
     * @param featureDataSet data to tune and then train on
     */
    public void tune(FeatureDataSet featureDataSet) {
        try {
            DataSet[][] folds = featureDataSet.splitDataSetFolds(10, true);
            DoubleVector scores = new DoubleVector();
            IntVector labels = new IntVector();
            double bestScore = 0.0d;
            int bestK = 1;
            for (int candidateK = 2; candidateK < 10; candidateK++) {
                this.k = candidateK;
                scores.empty();
                labels.empty();
                for (DataSet[] fold : folds) {
                    // Fit on the fold's training split only. Fitting on the whole data set would
                    // leak the held-out examples into the model being scored.
                    // NOTE(review): assumes fold[0] is the training split and fold[1] the
                    // held-out split — confirm against FeatureDataSet.splitDataSetFolds.
                    FeatureDataSet trainFold = (FeatureDataSet) fold[0];
                    FeatureDataSet testFold = (FeatureDataSet) fold[1];
                    _train(trainFold);
                    for (int e = 0; e < testFold.size(); e++) {
                        double p0 = getProb(0, testFold.getExample(e));
                        double p1 = getProb(1, testFold.getExample(e));
                        double total = p0 + p1;
                        // Normalized score for class 1; 0.5 when both densities are ~0.
                        scores.add(total < 1.0E-10d ? 0.5d : p1 / total);
                        labels.add(testFold.getOutputValueId(e));
                    }
                }
                double apr = APRUtil.getAPR(scores, labels);
                System.out.println("k:" + candidateK);
                System.out.println("apr:" + apr);
                System.out.println("best score:" + bestScore);
                System.out.println("best k:" + bestK);
                if (apr > bestScore) {
                    bestScore = apr;
                    bestK = candidateK;
                }
            }
            System.out.println("best score:" + bestScore);
            System.out.println("best k:" + bestK);
            this.k = bestK;
            _train(featureDataSet);
        } catch (Exception e) {
            internalError(e);
        }
    }

    /**
     * Initializes EM parameters in place. Each component gets:
     * <ul>
     *   <li>a mean equal to a randomly chosen data value,</li>
     *   <li>the overall sample variance,</li>
     *   <li>a uniform mixing weight.</li>
     * </ul>
     *
     * @param doubleVector data values (must be non-empty)
     * @param dArr component means (output)
     * @param dArr2 component variances (output)
     * @param dArr3 component weights (output); every entry is set to {@code 1 / dArr.length}
     */
    public void initializeValues(DoubleVector doubleVector, double[] dArr, double[] dArr2, double[] dArr3) {
        // The starting variance is the same for every component, so compute it once.
        double sampleVariance = doubleVector.variance();
        double uniformWeight = 1.0d / dArr.length;
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = doubleVector.get(Util.randomInteger(0, doubleVector.size() - 1));
            dArr2[i] = sampleVariance;
            dArr3[i] = uniformWeight;
        }
    }

    /**
     * Returns the log-likelihood of the data under the given mixture.
     *
     * @param doubleVector data values
     * @param dArr component means
     * @param dArr2 component variances
     * @param dArr3 component weights
     * @return the sum of log mixture densities; {@code -Infinity} if any point has density 0
     */
    public double getLk(DoubleVector doubleVector, double[] dArr, double[] dArr2, double[] dArr3) {
        double total = 0.0d;
        for (int i = 0; i < doubleVector.size(); i++) {
            total += Math.log(getFk(doubleVector.get(i), dArr, dArr2, dArr3));
        }
        return total;
    }

    /**
     * Fits a Gaussian mixture to {@code doubleVector} by EM, writing the result into the three
     * parameter arrays in place. All three arrays must have the same length (the number of
     * components).
     *
     * <p>A variance update is accepted only when the new log-likelihood is finite and no lower
     * than the previous one. Means and weights are always accepted. Iteration stops when the
     * log-likelihood changes by less than 0.001, or after more than {@link #max_iter} iterations.
     *
     * @param doubleVector data values (must be non-empty)
     * @param dArr component means (output)
     * @param dArr2 component variances (output)
     * @param dArr3 component weights (output)
     */
    public void MixedEM(DoubleVector doubleVector, double[] dArr, double[] dArr2, double[] dArr3) {
        initializeValues(doubleVector, dArr, dArr2, dArr3);
        int n = doubleVector.size();
        int components = dArr.length;
        double[] newWeight = new double[components];
        double[] newMean = new double[components];
        double[] newVar = new double[components];
        // Responsibilities of the component currently being updated, one per data point.
        double[] resp = new double[n];
        double lk = getLk(doubleVector, dArr, dArr2, dArr3);
        int i = 0;
        boolean running = true;
        do {
            for (int c = 0; c < components; c++) {
                // E-step: responsibilities of component c under the current parameters.
                double sumP = 0.0d;
                double sumPx = 0.0d;
                for (int j = 0; j < n; j++) {
                    double p = getP(doubleVector.get(j), dArr, dArr2, dArr3, c);
                    if (Double.isNaN(p)) {
                        // 0/0 when the total mixture density underflows at this point.
                        p = 1.0E-10d;
                    }
                    resp[j] = p;
                    sumP += p;
                    sumPx += p * doubleVector.get(j);
                }
                // M-step: weight, then mean, then variance about the *new* mean.
                newWeight[c] = sumP / n;
                if (sumP > 0.0d) {
                    double mu = sumPx / sumP;
                    double sumPSq = 0.0d;
                    for (int j = 0; j < n; j++) {
                        double diff = doubleVector.get(j) - mu;
                        sumPSq += resp[j] * diff * diff;
                    }
                    newMean[c] = mu;
                    newVar[c] = sumPSq / sumP;
                } else {
                    // The component owns no probability mass; keep its previous location and
                    // spread rather than producing NaN from a division by zero.
                    newMean[c] = dArr[c];
                    newVar[c] = dArr2[c];
                }
            }
            // getLk takes (means, variances, weights) in that order.
            double lkNew = getLk(doubleVector, newMean, newVar, newWeight);
            boolean improved = !Double.isInfinite(lkNew) && lkNew >= lk;
            for (int c = 0; c < components; c++) {
                dArr3[c] = newWeight[c];
                dArr[c] = newMean[c];
                if (improved) {
                    dArr2[c] = newVar[c];
                }
            }
            // When the variances were rejected, the likelihood must be recomputed for the
            // parameters that were actually kept.
            double lkCurrent = improved ? lkNew : getLk(doubleVector, dArr, dArr2, dArr3);
            i++;
            if (Math.abs(lk - lkCurrent) < 0.001d || i > this.max_iter) {
                running = false;
            }
            lk = lkCurrent;
        } while (running);
        System.out.println("i:" + i + " last lk:" + lk);
    }

    /**
     * Returns the mixture density at {@code d}: the weighted sum of component densities.
     *
     * @param d point to evaluate
     * @param dArr component means
     * @param dArr2 component variances
     * @param dArr3 component weights
     * @return {@code sum_i dArr3[i] * N(d; dArr[i], dArr2[i])}
     */
    protected static double getFk(double d, double[] dArr, double[] dArr2, double[] dArr3) {
        double density = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            density += dArr3[i] * gaussianValue(d, dArr[i], dArr2[i]);
        }
        return density;
    }

    /**
     * Evaluates one Gaussian component, given as a variance.
     *
     * <p>It delegates to {@code ProbabilityGaussian.gaussianValue} with {@code sqrt(d3)}. That
     * method presumably expects a standard deviation — TODO confirm.
     *
     * @param d point to evaluate
     * @param d2 component mean
     * @param d3 component variance
     * @return the component density at {@code d}
     */
    protected static double gaussianValue(double d, double d2, double d3) {
        return ProbabilityGaussian.gaussianValue(d, d2, Math.sqrt(d3));
    }

    /**
     * Returns the responsibility of component {@code i} for point {@code d}. This is the
     * component's weighted density divided by the full mixture density. It is NaN when the
     * mixture density is 0.
     *
     * @param d point to evaluate
     * @param dArr component means
     * @param dArr2 component variances
     * @param dArr3 component weights
     * @param i component index
     * @return posterior probability that {@code d} came from component {@code i}
     */
    protected static double getP(double d, double[] dArr, double[] dArr2, double[] dArr3, int i) {
        return (dArr3[i] * gaussianValue(d, dArr[i], dArr2[i])) / getFk(d, dArr, dArr2, dArr3);
    }

    /**
     * Returns the fitted mixture density of class {@code i} at feature value {@code d}.
     *
     * @param i output class id
     * @param d feature value
     * @return class-conditional density at {@code d}
     */
    public double getProb(int i, double d) {
        return getFk(d, this.mean[i], this.var[i], this.weight[i]);
    }

    /**
     * Returns the fitted density of class {@code i} at this feature's value in {@code example}.
     *
     * @param i output class id
     * @param example example whose {@code findex} feature is evaluated
     * @return class-conditional density of the example's feature value
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.Probability
    public double getProb(int i, Example example) {
        return getProb(i, example.get(this.findex).getDValue());
    }
}
