package edu.wisc.sjm.machlearn.policy.fdspreprocessor.discretize;

import edu.wisc.sjm.jutil.math.JMath;
import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.FeatureId;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.policy.FDSPreProcessor;
import edu.wisc.sjm.machlearn.util.Util;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/policy/fdspreprocessor/discretize/BinaryDiscretizeIG.class */
/**
 * Binarizes features using one information-gain-maximizing threshold per feature.
 *
 * <p>{@link #train} learns a threshold for every feature that is continuous or discrete with more
 * than two values. The {@code process} methods then replace each such non-output feature with a
 * discrete feature named {@code <name>.gt.<threshold>} (threshold printed to 3 decimals). Its
 * value is {@code "t"} when the original value is strictly greater than the threshold and
 * {@code "f"} otherwise. The output feature and discrete features with at most two values are
 * cloned unchanged.
 *
 * <p>Scoring indexes per-class counts by output value id into length-2 arrays, so the output
 * feature is assumed to be binary with value ids 0 and 1.
 */
public class BinaryDiscretizeIG extends FDSPreProcessor {
    /** Natural log of 2, used to convert {@link Math#log} results to bits. */
    private static final double LN2 = Math.log(2.0d);

    /** Thresholds learned by the last {@link #train} call, indexed by feature position. */
    private final DoubleVector thresholds = new DoubleVector();

    /**
     * Lower bound for candidate thresholds; it is also always a candidate itself.
     * NOTE(review): {@code KStarConstants.FLOOR} looks like a decompiler substitution for a plain
     * numeric literal; confirm the intended default.
     */
    private double min_threshold = KStarConstants.FLOOR;

    /** Sets the minimum candidate threshold from its string form (via {@link Double#parseDouble}). */
    public void setMinThreshold(String str) {
        setMinThreshold(Double.parseDouble(str));
    }

    /** Sets the minimum candidate threshold; {@code NaN} disables the lower bound. */
    public void setMinThreshold(double d) {
        this.min_threshold = d;
    }

    /**
     * Learns one threshold per feature of {@code featureDataSet}.
     *
     * <p>Features that are continuous or have more than two discrete values get
     * {@link #calcThreshold(FeatureDataSet, int, double)} bounded by the minimum threshold. Every
     * other feature gets a placeholder so positions stay aligned with feature indices.
     */
    @Override
    public void train(FeatureDataSet featureDataSet) {
        thresholds.empty();
        for (int i = 0; i < featureDataSet.numFeatures(); i++) {
            boolean split = featureDataSet.getFeatureId(i).isContinuous()
                    || featureDataSet.getFeatureId(i).numValues() > 2;
            thresholds.add(split ? calcThreshold(featureDataSet, i, min_threshold) : KStarConstants.FLOOR);
        }
    }

    /** Unbounded variant of {@link #calcThreshold(FeatureDataSet, int, double)} (bound = NaN). */
    public static double calcThreshold(FeatureDataSet featureDataSet, int i) {
        return calcThreshold(featureDataSet, i, Double.NaN);
    }

    /**
     * Chooses the information-gain-maximizing threshold for feature column {@code i}.
     *
     * <p>Candidates are the midpoints between adjacent sorted values whose output value ids differ.
     * When {@code d} is not NaN, only midpoints strictly greater than {@code d} are kept and
     * {@code d} itself is also a candidate. The first candidate with the highest
     * {@link #getScore} wins.
     *
     * @param featureDataSet training data
     * @param i feature column to threshold
     * @param d lower bound / default candidate, or NaN for no bound
     * @return the best candidate; {@code d} if the data set is empty or there are no candidates
     */
    public static double calcThreshold(FeatureDataSet featureDataSet, int i, double d) {
        int rows = featureDataSet.size();
        if (rows == 0) {
            return d;
        }
        // Local scratch vectors keep this method reentrant (previously shared static state).
        DoubleVector values = new DoubleVector();
        for (int row = 0; row < rows; row++) {
            values.add(featureDataSet.get(row, i).getDValue());
        }
        // NOTE(review): values are read through `order`, i.e. `order` is treated as the ascending
        // permutation of row positions and `values` as unpermuted; confirm QuickSort's contract.
        IntVector order = new IntVector();
        DoubleVector.QuickSort(values, order);

        boolean bounded = !Double.isNaN(d);
        DoubleVector candidates = new DoubleVector();
        if (bounded) {
            candidates.add(d);
        }
        int previousClass = featureDataSet.getOutputFeature(order.get(0)).getValueId();
        for (int k = 1; k < rows; k++) {
            int currentClass = featureDataSet.getOutputFeature(order.get(k)).getValueId();
            if (currentClass != previousClass) {
                double midpoint = (values.get(order.get(k - 1)) + values.get(order.get(k))) / 2.0d;
                if (!bounded || midpoint > d) {
                    candidates.add(midpoint);
                }
                previousClass = currentClass;
            }
        }
        if (candidates.size() == 0) {
            return d;
        }

        int best = 0;
        double bestScore = getScore(featureDataSet, i, candidates.get(0));
        for (int c = 1; c < candidates.size(); c++) {
            double score = getScore(featureDataSet, i, candidates.get(c));
            if (score > bestScore) {
                best = c;
                bestScore = score;
            }
        }
        return candidates.get(best);
    }

    /**
     * Scores threshold {@code d} on feature column {@code i} by information gain.
     *
     * <p>Rows are split into {@code value <= d} and {@code value > d}. This matches how
     * {@code process} assigns {@code "f"}/{@code "t"}, so the scored split is the one applied.
     * Per-class counts are indexed by output value id into length-2 arrays (binary output assumed).
     */
    public static double getScore(FeatureDataSet featureDataSet, int i, double d) {
        int[] atOrBelow = new int[2];
        int[] above = new int[2];
        for (int row = 0; row < featureDataSet.size(); row++) {
            int label = featureDataSet.getOutputFeature(row).getValueId();
            if (featureDataSet.get(row, i).getDValue() > d) {
                above[label]++;
            } else {
                atOrBelow[label]++;
            }
        }
        return calcGain(atOrBelow, above);
    }

    /**
     * Computes the information gain of splitting a sample into two partitions.
     *
     * @param iArr per-class counts of the first partition
     * @param iArr2 per-class counts of the second partition; the arrays may differ in length and
     *     missing entries count as zero
     * @return {@code H(total) - |first|/n * H(first) - |second|/n * H(second)} in bits; 0 for an
     *     empty sample
     */
    public static double calcGain(int[] iArr, int[] iArr2) {
        int[] total = new int[Math.max(iArr.length, iArr2.length)];
        for (int c = 0; c < iArr.length; c++) {
            total[c] += iArr[c];
        }
        for (int c = 0; c < iArr2.length; c++) {
            total[c] += iArr2[c];
        }
        int n = Util.sum(total);
        double gain = calcEntropy(total);
        if (n > 0) {
            // Cast before dividing: integer division would truncate every weight to 0 or 1.
            gain -= ((double) Util.sum(iArr) / n) * calcEntropy(iArr)
                    + ((double) Util.sum(iArr2) / n) * calcEntropy(iArr2);
        }
        return gain;
    }

    /**
     * Computes the Shannon entropy, in bits, of a class distribution given as counts.
     *
     * @param iArr non-negative per-class counts
     * @return {@code -sum(p * log2 p)} over classes with non-zero count; 0 when all counts are zero
     */
    public static double calcEntropy(int[] iArr) {
        int sum = Util.sum(iArr);
        if (sum <= 0) {
            // Empty partition: no uncertainty, and avoids dividing by zero.
            return 0.0d;
        }
        double entropy = 0.0d;
        for (int count : iArr) {
            if (count > 0) {
                double p = (double) count / sum; // floating-point, not integer, division
                entropy -= p * (Math.log(p) / LN2);
            }
        }
        return entropy;
    }

    /** Returns true when {@code featureId} is binarized by this preprocessor. */
    private static boolean needsThreshold(FeatureId featureId) {
        return featureId.isContinuous() || featureId.numValues() > 2;
    }

    /** Builds the {@code <name>.gt.<threshold>} feature with value "t" iff the value exceeds it. */
    private static Feature binarize(Feature feature, double threshold, String[] binaryValues) throws Exception {
        FeatureId featureId = feature.getFeatureId();
        String name = String.valueOf(featureId.printName()) + ".gt." + DoubleUtil.printDecimal(threshold, 3);
        return Feature.createDiscreteFeature(name, binaryValues, feature.getDValue() > threshold ? "t" : "f");
    }

    /**
     * Returns a copy of {@code example} with every thresholded non-output feature binarized.
     * Exceptions raised while building a feature are passed to {@code internalError}.
     */
    public Example process(Example example) {
        String[] binaryValues = {"f", "t"};
        Feature[] features = new Feature[example.numFeatures()];
        for (int i = 0; i < features.length; i++) {
            try {
                Feature feature = example.get(i);
                if (i != example.getOutputIndex() && needsThreshold(feature.getFeatureId())) {
                    features[i] = binarize(feature, thresholds.get(i), binaryValues);
                } else {
                    features[i] = (Feature) feature.clone();
                }
            } catch (Exception e) {
                internalError(e);
            }
        }
        return new Example(features, example.getName());
    }

    /**
     * Returns a copy of {@code featureDataSet} with every thresholded non-output feature binarized.
     *
     * <p>Uses the same feature predicate as {@link #train} and {@link #process(Example)}. Previously
     * only continuous features were binarized here, so multi-valued discrete features were handled
     * inconsistently. Exceptions raised while building a row are passed to {@code internalError}.
     */
    @Override
    public FeatureDataSet process(FeatureDataSet featureDataSet) {
        String[] binaryValues = {"f", "t"};
        int rows = featureDataSet.size();
        int cols = featureDataSet.numFeatures();
        Feature[][] features = new Feature[rows][cols];
        String[] names = new String[rows];
        for (int row = 0; row < rows; row++) {
            try {
                for (int col = 0; col < cols; col++) {
                    Feature feature = featureDataSet.get(row, col);
                    if (col != featureDataSet.getOutputIndex() && needsThreshold(feature.getFeatureId())) {
                        features[row][col] = binarize(feature, thresholds.get(col), binaryValues);
                    } else {
                        features[row][col] = (Feature) feature.clone();
                    }
                }
                names[row] = featureDataSet.getName(row);
            } catch (Exception e) {
                internalError(e);
            }
        }
        return new FeatureDataSet(features, names);
    }

    /**
     * Finds the accuracy-maximizing threshold for binary labels.
     *
     * <p>Values are sorted. For every split point with both sides non-empty, the number of
     * 0-labels left of the split plus 1-labels right of it is counted. The result is the midpoint
     * of the two adjacent sorted values around the best split (ties resolved by
     * {@code Util.argmax}).
     *
     * @param iArr labels (0 or 1), parallel to {@code dArr}
     * @param dArr values
     * @return the chosen threshold
     * @throws IllegalArgumentException if the arrays differ in length or hold fewer than two values
     */
    public static double maxAccuracyThreshold(int[] iArr, double[] dArr) {
        int n = dArr.length;
        if (iArr.length != n) {
            throw new IllegalArgumentException("labels/values length mismatch: " + iArr.length + " vs " + n);
        }
        if (n < 2) {
            throw new IllegalArgumentException("need at least two values, got " + n);
        }
        IntVector order = new IntVector();
        DoubleVector.QuickSort(new DoubleVector(dArr), order);

        int onesRight = 0;
        for (int label : iArr) {
            if (label == 1) {
                onesRight++;
            }
        }
        int zerosLeft = 0;
        // correct[s - 1] scores the split whose left side is sorted positions 0..s-1, s = 1..n-1
        // (the previous version stopped at s = n-2 and skipped the last valid split).
        int[] correct = new int[n - 1];
        for (int s = 1; s < n; s++) {
            int label = iArr[order.get(s - 1)];
            if (label == 0) {
                zerosLeft++;
            } else if (label == 1) {
                onesRight--;
            }
            correct[s - 1] = zerosLeft + onesRight;
        }
        int best = Util.argmax(correct);
        return (dArr[order.get(best)] + dArr[order.get(best + 1)]) / 2.0d;
    }

    /** This preprocessor does not use a {@link Scorer}. */
    @Override
    public boolean needScorer() {
        return false;
    }

    /** Ignored: no scorer is needed. */
    @Override
    public void setScorer(Scorer scorer) {
    }
}
