package edu.wisc.sjm.machlearn.classifiers.ppc;

import edu.wisc.sjm.jutil.matrices.DoubleMatrix;
import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.parameters.ParameterException;
import edu.wisc.sjm.machlearn.parameters.Parameter;
import edu.wisc.sjm.machlearn.parameters.ParameterSupportObject;
import edu.wisc.sjm.machlearn.policy.fdspreprocessor.selection.filter.ShrunkProbFilterAbsolute;
import edu.wisc.sjm.machlearn.util.Util;
import java.util.StringTokenizer;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/ppc/NearestShrunkenProbs.class */
/**
 * Two-class nearest-centroid classifier whose per-feature class "centroids" are
 * produced by {@link ShrunkProbFilterAbsolute#getProbs} with a shrinkage amount
 * {@code delta}.
 *
 * <p>Training fills {@link #pik} with one row per feature and one column per class.
 * An example is assigned to the class whose column is closer in squared Euclidean
 * distance over all non-output features. Ties go to class 0.
 *
 * <p>Optionally, {@code delta} is chosen by cross-validation over the candidate
 * values supplied to {@link #setTuneValues(String)}.
 */
public class NearestShrunkenProbs extends Classifier implements ParameterSupportObject {
    /**
     * Per-feature centroid matrix, shape numFeatures x 2. Row i holds the values
     * that getProbs returned for feature i: column 0 for class 0, column 1 for class 1.
     * NOTE(review): presumably shrunken per-class probabilities; confirm against
     * ShrunkProbFilterAbsolute.
     */
    protected DoubleMatrix pik;
    /** Shrinkage amount passed to ShrunkProbFilterAbsolute.getProbs. */
    protected double delta;
    /** Whether train_ tunes delta by cross-validation before building pik. */
    protected boolean do_tune;
    /** Number of cross-validation folds used when tuning delta. */
    protected int tune_folds;
    /** Candidate delta values tried by tuneDelta; null until setTuneValues is called. */
    protected double[] tune_values;

    /** Creates an untuned classifier with delta = KStarConstants.FLOOR and 10 tuning folds. */
    public NearestShrunkenProbs() {
        this(KStarConstants.FLOOR, false, 10);
    }

    /**
     * Creates a classifier with explicit settings.
     *
     * @param d initial shrinkage amount (delta)
     * @param z whether train_ should tune delta by cross-validation
     * @param i number of cross-validation folds used for tuning
     */
    public NearestShrunkenProbs(double d, boolean z, int i) {
        this.delta = d;
        this.do_tune = z;
        this.tune_folds = i;
        this.pik = new DoubleMatrix(1, 2);
    }

    /** Sets the shrinkage amount used by the next createPik call. */
    public void setDelta(double d) {
        this.delta = d;
    }

    /** Parses {@code str} as a double and sets it as delta. */
    public void setDelta(String str) {
        setDelta(Double.parseDouble(str));
    }

    /**
     * Sets the candidate delta values for tuning and enables tuning.
     *
     * @param str candidate values separated by '$', e.g. "0.5$1$2"
     */
    public void setTuneValues(String str) {
        StringTokenizer stringTokenizer = new StringTokenizer(str, "$");
        int countTokens = stringTokenizer.countTokens();
        this.tune_values = new double[countTokens];
        for (int i = 0; i < countTokens; i++) {
            this.tune_values[i] = Double.parseDouble(stringTokenizer.nextToken());
            System.out.println(i + ")" + this.tune_values[i]);
        }
        this.do_tune = true;
        // NOTE(review): this overrides any fold count set earlier via setTuneFolds;
        // call setTuneFolds afterwards to use a different number of folds.
        this.tune_folds = 10;
    }

    /** Sets the number of cross-validation folds used when tuning delta. */
    public void setTuneFolds(int i) {
        this.tune_folds = i;
    }

    /** Parses {@code str} as an int and sets it as the tuning fold count. */
    public void setTuneFolds(String str) {
        setTuneFolds(Integer.parseInt(str));
    }

    /**
     * Rebuilds {@link #pik} from {@code featureDataSet} using the current delta.
     *
     * @param featureDataSet data whose features (including the output index) get a row in pik
     */
    public void createPik(FeatureDataSet featureDataSet) {
        this.pik.resize(featureDataSet.numFeatures(), 2);
        double[] classValues = new double[2];
        for (int i = 0; i < featureDataSet.numFeatures(); i++) {
            ShrunkProbFilterAbsolute.getProbs(featureDataSet, i, this.delta, classValues);
            this.pik.set(i, 0, classValues[0]);
            this.pik.set(i, 1, classValues[1]);
        }
    }

    /**
     * Trains the classifier. If tuning is enabled, delta is first chosen by
     * cross-validation, then pik is rebuilt on the full training set.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void train_(FeatureDataSet featureDataSet) {
        if (this.do_tune) {
            System.out.println("Tuning delta with:" + this.tune_folds + " cv-folds");
            tuneDelta(featureDataSet);
        }
        createPik(featureDataSet);
    }

    /**
     * Assigns {@code example} to the class whose pik column is closer in squared
     * Euclidean distance over all non-output features. Ties go to class 0.
     *
     * @return a clone of the example's output feature with its value id set to 0 or 1
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature classify_(Example example) {
        Feature feature = (Feature) example.getOutputFeature().clone();
        double distance0 = 0.0d;
        double distance1 = 0.0d;
        for (int i = 0; i < example.numFeatures(); i++) {
            try {
                if (i != example.getOutputIndex()) {
                    double value = example.get(i).getDValue();
                    double diff0 = value - this.pik.get(i, 0);
                    double diff1 = value - this.pik.get(i, 1);
                    distance0 += diff0 * diff0;
                    distance1 += diff1 * diff1;
                }
            } catch (Exception e) {
                internalError(e);
            }
        }
        feature.setValueId(distance1 < distance0 ? 1 : 0);
        return feature;
    }

    /**
     * Returns a two-element pseudo-distribution {p0, p1} with
     * p_k = 1 - dist_k / (dist_0 + dist_1), where dist_k is the squared distance to
     * pik column k. The entries sum to 1. If both distances are 0 (the example sits
     * on both centroids, or there are no non-output features), the ratio is
     * undefined and {0.5, 0.5} is returned instead of NaN.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public double[] getDistribution(Example example) {
        double distance0 = 0.0d;
        double distance1 = 0.0d;
        for (int i = 0; i < example.numFeatures(); i++) {
            if (i != example.getOutputIndex()) {
                double value = example.get(i).getDValue();
                double diff0 = value - this.pik.get(i, 0);
                double diff1 = value - this.pik.get(i, 1);
                distance0 += diff0 * diff0;
                distance1 += diff1 * diff1;
            }
        }
        double total = distance0 + distance1;
        if (total == 0.0d) {
            // Equidistant at zero: avoid 0/0 = NaN and report no preference.
            return new double[]{0.5d, 0.5d};
        }
        return new double[]{1.0d - (distance0 / total), 1.0d - (distance1 / total)};
    }

    /**
     * Chooses delta from {@link #tune_values} by cross-validation and sets it.
     *
     * <p>For each candidate, pik is rebuilt on each fold's training split
     * ([fold][0]) and accuracy is summed over the fold's held-out split
     * ([fold][1]). The candidate selected by Util.argmaxr over those sums becomes
     * the new delta. If no candidates have been set, delta is left unchanged.
     * Any exception raised during tuning is passed to internalError.
     */
    public void tuneDelta(FeatureDataSet featureDataSet) {
        if (this.tune_values == null || this.tune_values.length == 0) {
            // Tuning was requested without candidates (e.g. via the constructor);
            // keep the current delta instead of failing with an NPE.
            System.out.println("No delta tune values set; keeping delta:" + this.delta);
            return;
        }
        try {
            DataSet[][] folds = featureDataSet.splitDataSetFolds(this.tune_folds, true);
            double[] accuracySums = new double[this.tune_values.length];
            for (int v = 0; v < this.tune_values.length; v++) {
                setDelta(this.tune_values[v]);
                for (int f = 0; f < folds.length; f++) {
                    createPik((FeatureDataSet) folds[f][0]);
                    accuracySums[v] += getAccuracy(folds[f][1]);
                }
            }
            // NOTE(review): argmaxr presumably breaks ties toward the last index; confirm in Util.
            int best = Util.argmaxr(accuracySums);
            System.out.println("best delta score:" + (Util.max(accuracySums) / folds.length));
            System.out.println("best delta:" + this.tune_values[best]);
            setDelta(this.tune_values[best]);
        } catch (Exception e) {
            internalError(e);
        }
    }

    /** Returns a short human-readable description: the classifier name and current delta. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        StringBuffer stringBuffer = new StringBuffer();
        stringBuffer.append("NearestShrunkenProbs delta=").append(this.delta).append('\n');
        return stringBuffer.toString();
    }

    /**
     * Returns a new, untrained classifier with the same delta, tuning flag, fold
     * count and (copied) candidate delta values. The trained pik is not copied.
     * NOTE(review): confirm that callers expect an untrained copy here.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        NearestShrunkenProbs copy = new NearestShrunkenProbs(this.delta, this.do_tune, this.tune_folds);
        if (this.tune_values != null) {
            copy.tune_values = this.tune_values.clone();
        }
        return copy;
    }

    /** Ignored: this classifier accepts no object-valued parameters. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
    }

    /** Always throws: no string-valued parameters are supported. */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public void setParameter(int i, String str) throws ParameterException {
        throw new ParameterException();
    }

    /** Always throws: no parameters are exposed. */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public Parameter getParameter(int i) throws ParameterException {
        throw new ParameterException("Unsupported type!");
    }

    /** Returns null for every index: no parameter values are exposed. */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public String getParameterValue(int i) {
        return null;
    }

    /**
     * Returns 1.
     * NOTE(review): inconsistent with getParameter/setParameter, which support no
     * index; left unchanged in case the parameter framework relies on it.
     */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public int numParameters() {
        return 1;
    }

    /** Returns an empty name. */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public String getPSOName() {
        return "";
    }
}
