package edu.wisc.sjm.machlearn.policy.fdspreprocessor.discretize;

import edu.wisc.sjm.jutil.math.JMath;
import edu.wisc.sjm.jutil.misc.PropertiesUtil;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.jutil.xml.XMLUtil;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.ProbabilityDiscrete;
import edu.wisc.sjm.machlearn.classifiers.naivebayes.probability.ProbabilityGaussian;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.FeatureId;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.policy.FDSPreProcessor;
import edu.wisc.sjm.machlearn.util.APRUtil;
import java.util.Vector;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/policy/fdspreprocessor/discretize/BinDiscretize.class */
/**
 * Pre-processor that replaces continuous features with discrete "bin" features.
 *
 * <p>For every continuous feature a list of thresholds is computed during
 * {@link #train(FeatureDataSet)}; a value is later mapped to a bin id by
 * {@link #findBin(DoubleVector, double)}. Features whose threshold list is empty are
 * passed through unchanged. The number of bins is either fixed ({@link #numBins}) or,
 * when tuning is enabled, chosen per feature by {@link #tuneAPR(FeatureDataSet, int)}.
 *
 * <p>NOTE(review): the public static vectors below are scratch buffers shared by every
 * instance; {@code tuneAPR} mutates them, so concurrent tuning from several threads
 * looks unsafe — confirm before using this class concurrently.
 */
public class BinDiscretize extends FDSPreProcessor {
    /** Number of cross-validation folds requested when tuning the bin count. */
    private static final int NUM_FOLDS = 10;

    /** Requested number of bins; a non-positive value passed to {@link #setNumBins(int)} enables tuning. */
    int numBins;
    /** One threshold list per feature (index = feature index); empty means "do not discretize". */
    Vector<DoubleVector> thresholds;
    // The three tune* constants are not referenced inside this class;
    // presumably they are mode selectors used by callers — TODO confirm.
    public static final int tuneIG = -1;
    public static final int tuneIGR = -2;
    public static final int tunePR = -3;
    /** When true, tuning also scores the un-discretized (Gaussian) feature as the baseline to beat. */
    protected boolean use_continuous;
    /** When true, {@link #train(FeatureDataSet)} tunes the bin count instead of using {@link #numBins}. */
    protected boolean doTune;
    /** Not read or written anywhere in this class. */
    protected int mode;
    // Shared scratch buffers. Only probs, train_values, train_outputs, pd and
    // temp_thresholds are touched by this class (in tuneAPR); the rest are kept
    // because they are part of the public surface.
    public static DoubleVector probs = new DoubleVector();
    public static IntVector test_outputs = new IntVector();
    public static IntVector test_values = new IntVector();
    public static IntVector train_outputs = new IntVector();
    public static IntVector train_values = new IntVector();
    public static ProbabilityDiscrete pd = new ProbabilityDiscrete();
    public static DoubleVector temp_thresholds = new DoubleVector();
    public static DoubleVector temp_values = new DoubleVector();
    /** Lazily built output feature descriptors; {@code null} until first built. */
    FeatureId[] fids;

    /** Creates a discretizer configured from the global properties (see {@link #readProperties()}). */
    public BinDiscretize() {
        this.thresholds = new Vector<>();
        readProperties();
    }

    /**
     * Creates a discretizer configured from the global properties, then overrides the bin count.
     *
     * @param i number of bins; a non-positive value enables tuning
     */
    public BinDiscretize(int i) {
        this();
        setNumBins(i);
    }

    /** Loads {@code useContinuous} (default true) and {@code numBins} (default 10) from the properties. */
    protected void readProperties() {
        setUseContinuous(PropertiesUtil.getBoolean("edu.wisc.sjm.machlearn.policy.fdspreprocessor.BinDiscretize.useContinuous", true));
        setNumBins(PropertiesUtil.getInt("edu.wisc.sjm.machlearn.policy.fdspreprocessor.BinDiscretize.numBins", 10));
    }

    /**
     * Serializes the feature count, the bin count and every per-feature threshold list
     * into a new {@code BinDiscretize} element appended to {@code element}.
     *
     * @param document owner document used to create nodes
     * @param element parent that receives the new element
     * @return the newly created {@code BinDiscretize} element
     */
    @Override // edu.wisc.sjm.machlearn.policy.PreProcessor, edu.wisc.sjm.jutil.xml.XMLObject
    public Element toXML(Document document, Element element) {
        Element node = document.createElement("BinDiscretize");
        int featureCount = this.thresholds.size();
        XMLUtil.setXMLValue(document, node, "NumFeatures", featureCount);
        XMLUtil.setXMLValue(document, node, "numBins", this.numBins);
        for (int f = 0; f < featureCount; f++) {
            XMLUtil.setXMLValue(document, node, "Bin_" + f, this.thresholds.get(f).toArray());
        }
        element.appendChild(node);
        return node;
    }

    /**
     * Restores the bin count and the per-feature thresholds written by
     * {@link #toXML(Document, Element)}.
     *
     * @param element element to read {@code NumFeatures}, {@code numBins} and {@code Bin_<n>} from
     * @throws Exception propagated from the XML helper
     */
    @Override // edu.wisc.sjm.machlearn.policy.PreProcessor, edu.wisc.sjm.jutil.xml.XMLObject
    public void fromXML(Element element) throws Exception {
        int featureCount = XMLUtil.getXMLInteger(element, "NumFeatures", 0);
        this.numBins = XMLUtil.getXMLInteger(element, "numBins", 0);
        this.thresholds.clear();
        for (int f = 0; f < featureCount; f++) {
            this.thresholds.add(new DoubleVector(XMLUtil.getXMLDoubleArray(element, "Bin_" + f, -1.0d)));
        }
        // The cached descriptors were derived from the previous thresholds; drop them so
        // getFids() rebuilds them against the thresholds that were just loaded.
        this.fids = null;
    }

    /**
     * Sets the bin count. A non-positive value switches tuning on; a positive value
     * leaves the tuning flag as it was (it is never switched off here).
     *
     * @param i requested number of bins
     */
    public void setNumBins(int i) {
        if (i <= 0) {
            this.doTune = true;
        }
        this.numBins = i;
        System.out.println("number of bins is :" + i);
        System.out.println("doTune is:" + this.doTune);
    }

    /** @param z whether tuning should compare against the continuous (un-discretized) feature */
    public void setUseContinuous(boolean z) {
        this.use_continuous = z;
        System.out.println("BinDiscretize.setUseContinuous:" + this.use_continuous);
    }

    /** String overload of {@link #setUseContinuous(boolean)}; parsed with {@link Boolean#parseBoolean(String)}. */
    public void setUseContinuous(String str) {
        setUseContinuous(Boolean.parseBoolean(str));
    }

    /** String overload of {@link #setNumBins(int)}; parsed with {@link Integer#parseInt(String)}. */
    public void setNumBins(String str) {
        setNumBins(Integer.parseInt(str));
    }

    /** @param z whether {@link #train(FeatureDataSet)} should tune the bin count per feature */
    public void setDoTune(boolean z) {
        this.doTune = z;
    }

    /** String overload of {@link #setDoTune(boolean)}; parsed with {@link Boolean#parseBoolean(String)}. */
    public void setDoTune(String str) {
        setDoTune(Boolean.parseBoolean(str));
    }

    /** @return whether the bin count is tuned during training */
    public boolean getDoTune() {
        return this.doTune;
    }

    /** @return the configured number of bins */
    public int getNumBins() {
        return this.numBins;
    }

    /**
     * Computes a threshold list for every feature of the data set and rebuilds the
     * output feature descriptors. Non-continuous features get an empty list.
     *
     * @param featureDataSet training data
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void train(FeatureDataSet featureDataSet) {
        this.thresholds.clear();
        for (int f = 0; f < featureDataSet.numFeatures(); f++) {
            DoubleVector featureThresholds = new DoubleVector();
            this.thresholds.add(featureThresholds);
            if (!featureDataSet.getFeatureId(f).isContinuous()) {
                continue;
            }
            if (!this.doTune) {
                getThresholds(featureDataSet, f, this.numBins, featureThresholds);
                continue;
            }
            try {
                getThresholdsAPR(featureDataSet, f, featureThresholds);
            } catch (Exception e) {
                internalError(e);
            }
        }
        buildFids(featureDataSet);
    }

    /**
     * Fills {@code doubleVector} with thresholds for the tuned bin count of feature {@code i};
     * leaves it empty when tuning reports {@code -1} (no bin count selected).
     *
     * @param featureDataSet training data
     * @param i feature index
     * @param doubleVector output threshold list (emptied first)
     * @throws Exception propagated from {@link #tuneAPR(FeatureDataSet, int)}
     */
    public void getThresholdsAPR(FeatureDataSet featureDataSet, int i, DoubleVector doubleVector) throws Exception {
        int tunedBins = tuneAPR(featureDataSet, i);
        doubleVector.empty();
        if (tunedBins == -1) {
            return;
        }
        getThresholds(featureDataSet, i, tunedBins, doubleVector);
    }

    /**
     * Scores feature {@code i} without discretization, using a {@link ProbabilityGaussian}
     * built from the first entry of the first fold.
     *
     * @param dataSetArr fold array as returned by {@code splitDataSetFolds}
     * @param i feature index
     * @return the value reported by {@code APRUtil.getAPR}
     * @throws Exception propagated from the scoring helpers
     */
    public double getContinuousAPR(DataSet[][] dataSetArr, int i) throws Exception {
        FeatureDataSet firstTrain = (FeatureDataSet) dataSetArr[0][0];
        return APRUtil.getAPR(new ProbabilityGaussian(firstTrain, i), dataSetArr);
    }

    /**
     * Picks the bin count for feature {@code i} by cross-validation: every candidate from 2
     * up to {@code max(2, sqrt(2 * min output count))} is scored with {@code APRUtil.getAPR}
     * and the best-scoring candidate wins.
     *
     * <p>NOTE(review): the discrete model is built with 2 as its first argument and only
     * outputs 0 and 1 are queried, so this presumably assumes a binary output — confirm.
     *
     * @param featureDataSet training data
     * @param i feature index
     * @return the best bin count, or {@code -1} when no candidate beat the baseline
     *         (the continuous score if {@code use_continuous}, otherwise negative infinity)
     * @throws Exception propagated from the fold splitting and scoring helpers
     */
    public int tuneAPR(FeatureDataSet featureDataSet, int i) throws Exception {
        int maxBins = Math.max(2, (int) Math.sqrt(2 * JMath.min(featureDataSet.getOutputCounts())));
        System.out.println("max_bins:" + maxBins);
        DataSet[][] folds = featureDataSet.splitDataSetFolds(NUM_FOLDS, true, true);
        double bestApr = this.use_continuous ? getContinuousAPR(folds, i) : Double.NEGATIVE_INFINITY;
        int bestBins = -1;

        // Per fold: distinct training values and training outputs. The held-out outputs of
        // all folds are concatenated in fold order, matching the order probs is filled below.
        DoubleVector[] foldValues = new DoubleVector[folds.length];
        IntVector[] foldTrainOutputs = new IntVector[folds.length];
        IntVector heldOutOutputs = new IntVector();
        for (int f = 0; f < folds.length; f++) {
            foldValues[f] = new DoubleVector();
            foldTrainOutputs[f] = new IntVector();
            FeatureDataSet trainSet = (FeatureDataSet) folds[f][0];
            FeatureDataSet testSet = (FeatureDataSet) folds[f][1];
            for (int row = 0; row < trainSet.size(); row++) {
                foldValues[f].add(trainSet.get(row, i).getDValue());
                foldTrainOutputs[f].add(trainSet.getOutputFeature(row).getValueId());
            }
            for (int row = 0; row < testSet.size(); row++) {
                heldOutOutputs.add(testSet.getOutputFeature(row).getValueId());
            }
            foldValues[f] = DoubleVector.makeUnique(foldValues[f]);
        }

        for (int bins = 2; bins <= maxBins; bins++) {
            probs.empty();
            for (int f = 0; f < folds.length; f++) {
                FeatureDataSet trainSet = (FeatureDataSet) folds[f][0];
                FeatureDataSet testSet = (FeatureDataSet) folds[f][1];
                if (foldValues[f].size() > 1) {
                    calcThresholds(foldValues[f], bins, temp_thresholds);
                } else {
                    // A single distinct value cannot be split; use it as the only threshold.
                    temp_thresholds.empty();
                    temp_thresholds.add(foldValues[f].get(0));
                }
                train_values.empty();
                train_outputs.empty();
                for (int row = 0; row < trainSet.size(); row++) {
                    int bin = findBin(temp_thresholds, trainSet.get(row, i).getDValue());
                    if (bin > bins) {
                        System.out.println(bin + ">" + bins);
                    }
                    train_values.add(bin);
                }
                pd.buildCPT(2, bins, false);
                pd.train(foldTrainOutputs[f], train_values);
                for (int row = 0; row < testSet.size(); row++) {
                    int bin = findBin(temp_thresholds, testSet.get(row, i).getDValue());
                    double p0 = pd.getProb(0, bin);
                    double p1 = pd.getProb(1, bin);
                    double total = p0 + p1;
                    // Fall back to 0.5 when both probabilities are (nearly) zero.
                    probs.add(total < 1.0E-10d ? 0.5d : p1 / total);
                }
            }
            double apr = APRUtil.getAPR(probs, heldOutOutputs);
            System.out.println("numbins:" + bins + " apr:" + apr);
            if (apr > bestApr) {
                bestApr = apr;
                bestBins = bins;
            }
        }
        System.out.print("Feature name:" + featureDataSet.getFeatureId(i).printName());
        System.out.print(" best_apr:" + bestApr);
        System.out.println(" best_bins:" + bestBins);
        return bestBins;
    }

    /**
     * Collects every value of feature {@code i}, reduces them with
     * {@code DoubleVector.makeUnique} and computes thresholds for {@code i2} bins.
     *
     * @param featureDataSet source data
     * @param i feature index
     * @param i2 number of bins
     * @param doubleVector output threshold list (emptied first)
     */
    public void getThresholds(FeatureDataSet featureDataSet, int i, int i2, DoubleVector doubleVector) {
        DoubleVector values = new DoubleVector();
        int rows = featureDataSet.size();
        for (int row = 0; row < rows; row++) {
            values.add(featureDataSet.get(row, i).getDValue());
        }
        calcThresholds(DoubleVector.makeUnique(values), i2, doubleVector);
    }

    /**
     * Computes cut points that split {@code doubleVector} into at most {@code i} groups of
     * roughly equal element count. Each cut point is the midpoint between the last value of
     * one group and the first value of the next.
     *
     * <p>Presumably {@code doubleVector} is sorted ascending (callers pass the result of
     * {@code makeUnique}) — TODO confirm. Fewer than {@code i - 1} cut points are produced
     * when the values run out; in that case the maximum value is appended as a final cut
     * point (unless one within 1e-4 of it is already present) and the computation stops.
     *
     * <p>An empty input yields no thresholds; a single value yields that value as the only
     * threshold, the same fallback {@link #tuneAPR(FeatureDataSet, int)} uses.
     *
     * <p>NOTE(review): a bin count of exactly 1 still reads one element past the end for
     * the first cut point, as before — callers here never pass 1 when tuning.
     *
     * @param doubleVector values to split
     * @param i number of bins; must be positive
     * @param doubleVector2 output threshold list (emptied first)
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void calcThresholds(DoubleVector doubleVector, int i, DoubleVector doubleVector2) {
        doubleVector2.empty();
        if (i <= 0) {
            // Previously surfaced as an opaque "/ by zero" or an out-of-range read.
            throw new IllegalArgumentException("number of bins must be positive: " + i);
        }
        int size = doubleVector.size();
        if (size == 0) {
            return;
        }
        if (size == 1) {
            doubleVector2.add(doubleVector.get(0));
            return;
        }
        int perBin = size / i;
        int remainder = size % i;
        // Index of the last element of the first group; the first `remainder`
        // groups receive one extra element each.
        int left = perBin - 1;
        if (remainder > 0) {
            left++;
            remainder--;
        }
        int right = left + 1;
        try {
            doubleVector2.add((doubleVector.get(left) + doubleVector.get(right)) / 2.0d);
            int produced = 1;
            while (produced < i - 1) {
                left += perBin;
                if (remainder > 0) {
                    left++;
                    remainder--;
                }
                int next = left + 1;
                // Never cut between two equal values: slide right past duplicates.
                while (next < size && doubleVector.get(left) == doubleVector.get(next)) {
                    left++;
                    next++;
                }
                if (next >= size) {
                    // Values exhausted: close with the maximum and stop. Without this exit
                    // the loop never advanced its counter and spun forever.
                    if (!doubleVector2.contains(doubleVector.max(), 1.0E-4d)) {
                        doubleVector2.add(doubleVector.max());
                    }
                    break;
                }
                doubleVector2.add((doubleVector.get(left) + doubleVector.get(next)) / 2.0d);
                produced++;
            }
            if (doubleVector2.size() >= i) {
                System.out.println("Warning, number of thresholds is too big:" + doubleVector2.size() + ">=" + i);
            }
        } catch (NullPointerException e) {
            // Diagnostic dump kept from the original implementation; the exception is rethrown.
            System.out.println("nullptrexp:");
            System.out.println("n:" + size);
            System.out.println("nthres:" + i);
            System.out.println("left:" + left);
            System.out.println("p:" + perBin);
            System.out.println("r:" + remainder);
            System.out.println("right:" + right);
            throw e;
        }
    }

    /**
     * Maps a value to a bin id: starts at the index reported by {@code doubleVector.search(d)}
     * and advances past every threshold that is {@code <= d}.
     *
     * @param doubleVector threshold list
     * @param d value to classify
     * @return the resulting index, at most {@code doubleVector.size()}
     */
    public static int findBin(DoubleVector doubleVector, double d) {
        int bin = doubleVector.search(d);
        while (bin < doubleVector.size() && doubleVector.get(bin) <= d) {
            bin++;
        }
        return bin;
    }

    /**
     * Converts a single example into its discretized counterpart.
     *
     * @param example input example
     * @return a new example built on the discretized feature descriptors
     */
    public Example process(Example example) {
        Example converted = new Example(getFids(example));
        setExample(converted, example);
        return converted;
    }

    /**
     * @param featureDataSet data set used to build the descriptors if they are not cached yet
     * @return the cached output feature descriptors
     */
    public FeatureId[] getFids(FeatureDataSet featureDataSet) {
        if (this.fids == null) {
            buildFids(featureDataSet);
        }
        return this.fids;
    }

    /**
     * @param example example used to build the descriptors if they are not cached yet
     * @return the cached output feature descriptors
     */
    public FeatureId[] getFids(Example example) {
        if (this.fids == null) {
            buildFids(example);
        }
        return this.fids;
    }

    /** Builds the output feature descriptors from the first example of the data set. */
    public void buildFids(FeatureDataSet featureDataSet) {
        buildFids(featureDataSet.getExample(0));
    }

    /**
     * Builds one output descriptor per feature of {@code example}. A feature with thresholds
     * becomes a discrete feature named {@code d_<name>} whose values are {@code "<t"} for
     * each threshold {@code t} plus a final {@code ">=last"}; a feature without thresholds
     * keeps a copy of its original descriptor.
     *
     * @param example example supplying the feature count and the original descriptors
     */
    public void buildFids(Example example) {
        int featureCount = example.numFeatures();
        this.fids = new FeatureId[featureCount];
        for (int f = 0; f < featureCount; f++) {
            DoubleVector cuts = this.thresholds.get(f);
            int cutCount = cuts.size();
            if (cutCount == 0) {
                this.fids[f] = new FeatureId(example.getFeatureId(f));
                continue;
            }
            String[] labels = new String[cutCount + 1];
            for (int c = 0; c < cutCount; c++) {
                labels[c] = "<" + cuts.get(c);
            }
            labels[cutCount] = ">=" + cuts.get(cutCount - 1);
            this.fids[f] = FeatureId.createDiscreteFeatureId("d_" + example.getFeatureId(f).printName(), labels);
        }
    }

    /**
     * Converts every example of the data set, keeping the example names.
     *
     * @param featureDataSet input data
     * @return a new data set built on the discretized feature descriptors
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public FeatureDataSet process(FeatureDataSet featureDataSet) {
        FeatureDataSet result = new FeatureDataSet(getFids(featureDataSet));
        int rows = featureDataSet.size();
        result.expand(rows, false);
        for (int row = 0; row < rows; row++) {
            Example target = result.addExample(featureDataSet.getName(row));
            setExample(target, featureDataSet.getExample(row));
        }
        return result;
    }

    /**
     * Fills {@code example} from {@code example2}: discretized features receive their bin id,
     * all others a copy of the source value. A failure on one feature is reported through
     * {@code internalError} and the remaining features are still processed.
     *
     * @param example destination example
     * @param example2 source example
     */
    public void setExample(Example example, Example example2) {
        int featureCount = example2.numFeatures();
        for (int f = 0; f < featureCount; f++) {
            try {
                DoubleVector cuts = this.thresholds.get(f);
                if (cuts.size() == 0) {
                    example.get(f).setValue(example2.get(f));
                } else {
                    example.get(f).setValueId(findBin(cuts, example2.get(f).getDValue()));
                }
            } catch (Exception e) {
                internalError(e);
            }
        }
    }

    /** @return always {@code false}; this pre-processor does not use a scorer */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public boolean needScorer() {
        return false;
    }

    /** Ignored; see {@link #needScorer()}. */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void setScorer(Scorer scorer) {
    }
}
