package edu.wisc.sjm.machlearn.classifiers.knn;

import edu.wisc.sjm.jutil.misc.IndexScoreObject;
import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.jutil.misc.Util;
import edu.wisc.sjm.jutil.vectors.ObjectHeap;
import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;
import edu.wisc.sjm.machlearn.util.APRUtil;
import java.util.Arrays;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/knn/WeightedKNN.class */
/**
 * Distance-weighted k-nearest-neighbour classifier.
 *
 * <p>Each training example votes for its own output value with weight
 * {@code alpha / (d + alpha)}, where {@code d} is the squared Euclidean distance
 * (see {@link #getScoreEuc}) between it and the query. Every class starts with a
 * vote of 1.0 (additive smoothing). The class-1 vote may be rescaled by
 * {@code positive_weight}, and the votes are then normalized into a distribution.
 * When {@code nearest_k == -1} every training example votes; otherwise only the
 * {@code nearest_k} closest ones do.
 */
public class WeightedKNN extends Classifier {
    /** Number of cross-validation folds used by {@link #tune(FeatureDataSet)}. */
    private static final int TUNE_FOLDS = 10;

    /**
     * Per-feature distance strategy used by {@link #getScore(Example, Example)}.
     * The distribution methods use {@link #getScoreEuc} instead.
     */
    protected KNNScoreInterface score_object;
    /** Examples memorized by {@link #train_(FeatureDataSet)}; these are the voters. */
    protected FeatureDataSet trainingData;
    /** Kernel width: a neighbour at squared distance d contributes alpha / (d + alpha). */
    protected double alpha;
    /** Multiplier applied to the class-1 vote before normalization; set in train_. */
    protected double positive_weight;
    /** When true, train_ chooses nearest_k by cross-validation via tune(). */
    protected boolean doTune;
    /** Number of nearest neighbours that vote; -1 means all training examples vote. */
    protected int nearest_k;
    /** When true, train_ derives positive_weight from the class counts. */
    protected boolean usePositiveWeight;

    /** Sets the number of voting neighbours (-1 = all training examples). */
    public void setK(int i) {
        this.nearest_k = i;
    }

    /** Enables or disables class-count based reweighting of the class-1 vote. */
    public void setUsePositiveWeight(boolean z) {
        this.usePositiveWeight = z;
    }

    /** Enables or disables choosing k by cross-validation during training. */
    public void setDoTune(boolean z) {
        this.doTune = z;
    }

    /** Creates a classifier with alpha = 1.0, a {@link HannibisDist} scorer and all-neighbour voting. */
    public WeightedKNN() {
        this(1.0d);
    }

    /**
     * Creates a classifier with alpha = 1.0 and a {@link HannibisDist} scorer.
     *
     * @param i number of nearest neighbours that vote (-1 = all)
     */
    public WeightedKNN(int i) {
        this(1.0d, new HannibisDist(), i);
    }

    /**
     * Creates a classifier with a {@link HannibisDist} scorer and all-neighbour voting.
     *
     * @param d kernel width alpha
     */
    public WeightedKNN(double d) {
        this(d, new HannibisDist());
    }

    /**
     * Creates a classifier that uses all-neighbour voting.
     *
     * @param d kernel width alpha
     * @param kNNScoreInterface per-feature distance strategy used by getScore
     */
    public WeightedKNN(double d, KNNScoreInterface kNNScoreInterface) {
        this(d, kNNScoreInterface, -1);
    }

    /**
     * Full constructor.
     *
     * @param d kernel width alpha
     * @param kNNScoreInterface per-feature distance strategy used by getScore
     * @param i number of nearest neighbours that vote (-1 = all)
     */
    public WeightedKNN(double d, KNNScoreInterface kNNScoreInterface, int i) {
        this.alpha = d;
        this.score_object = kNNScoreInterface;
        this.nearest_k = i;
        setUsePositiveWeight(false);
    }

    /**
     * Memorizes the training data and, if enabled, computes positive_weight and tunes k.
     *
     * @param featureDataSet training examples
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void train_(FeatureDataSet featureDataSet) {
        this.trainingData = featureDataSet;
        if (this.usePositiveWeight) {
            int[] outputCounts = featureDataSet.getOutputCounts();
            // Floating-point division: the previous int/int expression truncated the ratio
            // (e.g. 3/5 became 0.0, which wiped out every class-1 vote).
            this.positive_weight = (outputCounts[0] + 1.0d) / (outputCounts[1] + 1.0d);
        } else {
            this.positive_weight = 1.0d;
        }
        if (this.doTune) {
            tune(featureDataSet);
        }
    }

    /**
     * Chooses nearest_k by cross-validated APR.
     *
     * <p>k = -1 (all neighbours) is the baseline. Each k from 1 up to (but excluding) the
     * size of {@code folds[0][0]} is then tried, and the first k that strictly beats the
     * best score so far is kept. Data sets with fewer than {@value #TUNE_FOLDS} examples
     * fall back to k = -1.
     *
     * @param featureDataSet data to cross-validate on
     */
    public void tune(FeatureDataSet featureDataSet) {
        if (featureDataSet.size() < TUNE_FOLDS) {
            this.nearest_k = -1;
            return;
        }
        try {
            DataSet[][] folds = featureDataSet.splitDataSetFolds(TUNE_FOLDS, true, true);
            // NOTE(review): folds[0][0] is presumably the first fold's training split; its
            // size bounds the k values tried - confirm against splitDataSetFolds.
            int maxK = folds[0][0].size();
            this.nearest_k = -1;
            int bestK = -1;
            double bestApr = APRUtil.getAPR(this, folds);
            for (int k = 1; k < maxK; k++) {
                this.nearest_k = k;
                double apr = APRUtil.getAPR(this, folds);
                if (apr > bestApr) {
                    bestK = k;
                    bestApr = apr;
                }
                System.out.println("k:" + k + " score:" + apr);
            }
            System.out.println("best k:" + bestK);
            System.out.println("best score:" + bestApr);
            this.nearest_k = bestK;
        } catch (Exception e) {
            // Do not leave whichever k was being evaluated when the failure happened.
            this.nearest_k = -1;
            internalError(e);
        }
    }

    /**
     * Predicts the output value with the highest probability in the vote distribution.
     *
     * @param example example to classify
     * @return a clone of the example's output feature with its value id set to the argmax
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature classify_(Example example) throws Exception {
        double[] distribution = getDistribution(example);
        Feature feature = (Feature) example.getOutputFeature().clone();
        try {
            feature.setValueId(Util.argmax(distribution));
        } catch (InvalidFeature e) {
            internalError(e);
        }
        return feature;
    }

    /**
     * Fills {@code dArr} with the normalized class distribution for the processed example.
     * Dispatches to all-neighbour voting when nearest_k is -1, otherwise to k-nearest voting.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void getDistribution(Example example, double[] dArr) throws Exception {
        if (this.nearest_k == -1) {
            getDistributionAll(process(example), dArr);
        } else {
            getDistributionNearest(process(example), dArr);
        }
    }

    /**
     * Distribution in which every training example votes with weight alpha / (d + alpha).
     *
     * @param example query example (already processed)
     * @param dArr output array indexed by output value id; overwritten
     */
    public void getDistributionAll(Example example, double[] dArr) {
        initVotes(dArr);
        for (int i = 0; i < this.trainingData.size(); i++) {
            Example neighbour = this.trainingData.getExample(i);
            int valueId = neighbour.getOutputFeature().getValueId();
            dArr[valueId] += kernel(getScoreEuc(example, neighbour));
        }
        finishVotes(dArr);
    }

    /**
     * Distribution in which only the nearest_k closest training examples vote, each with
     * weight alpha / (d + alpha).
     *
     * @param example query example (already processed)
     * @param dArr output array indexed by output value id; overwritten
     */
    protected void getDistributionNearest(Example example, double[] dArr) {
        ObjectHeap heap = new ObjectHeap();
        for (int i = 0; i < this.trainingData.size(); i++) {
            heap.add(new IndexScoreObject(i, getScoreEuc(example, this.trainingData.getExample(i))));
        }
        int voters = Math.min(this.nearest_k, heap.size());
        initVotes(dArr);
        for (int n = 0; n < voters; n++) {
            // popMin yields the remaining entry with the smallest distance score.
            IndexScoreObject nearest = (IndexScoreObject) heap.popMin();
            int valueId = this.trainingData.getOutputFeature(nearest.getIndex()).getValueId();
            dArr[valueId] += kernel(nearest.getScore());
        }
        // Apply positive_weight here too, matching getDistributionAll.
        finishVotes(dArr);
    }

    /** Resets every class to the smoothing prior of 1.0. */
    private static void initVotes(double[] votes) {
        Arrays.fill(votes, 1.0d);
    }

    /** Vote weight for a neighbour at squared distance {@code distance}. */
    private double kernel(double distance) {
        return this.alpha / (distance + this.alpha);
    }

    /** Applies positive_weight to the class-1 vote and normalizes the array. */
    private void finishVotes(double[] votes) {
        votes[1] *= this.positive_weight;
        Util.normalize(votes);
    }

    /**
     * Squared Euclidean distance between two examples over all non-output features,
     * using each feature's {@code getDValue()}. No square root is taken.
     */
    protected double getScoreEuc(Example example, Example example2) {
        int outputIndex = example.getOutputIndex();
        double sum = 0.0d;
        for (int i = 0; i < example.numFeatures(); i++) {
            if (i == outputIndex) {
                continue;
            }
            double diff = example.get(i).getDValue() - example2.get(i).getDValue();
            sum += diff * diff;
        }
        return sum;
    }

    /** Distance between two examples using this classifier's score_object. */
    protected double getScore(Example example, Example example2) {
        return getScore(example, example2, this.score_object);
    }

    /**
     * Sums per-feature distances over all non-output features. Features whose
     * {@code getType()} is 0 use the scorer's discrete distance; all others use its
     * continuous distance. An {@link InvalidFeature} from the scorer is reported through
     * {@code MainClass._internalError} and that feature contributes nothing.
     */
    public static double getScore(Example example, Example example2, KNNScoreInterface kNNScoreInterface) {
        double total = 0.0d;
        for (int i = 0; i < example.size(); i++) {
            if (i == example.getOutputIndex()) {
                continue;
            }
            Feature feature = example.getFeature(i);
            Feature feature2 = example2.getFeature(i);
            try {
                if (feature.getType() == 0) {
                    total += kNNScoreInterface.getDiscreteDist(feature, feature2);
                } else {
                    total += kNNScoreInterface.getContinuousDist(feature, feature2);
                }
            } catch (InvalidFeature e) {
                MainClass._internalError(e);
            }
        }
        return total;
    }

    /** Returns a one-line description of this classifier. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        return "WeightedKNN\n";
    }

    /**
     * Returns a copy that shares the scorer and training data and keeps this instance's
     * k and positive-weight settings. Previously the copy silently reverted to k = -1.
     * doTune is intentionally not copied: tune() evaluates candidates via APRUtil, and a
     * self-tuning copy could start nested tuning (NOTE(review): assumes getAPR trains
     * clones - confirm).
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        WeightedKNN copy = new WeightedKNN(this.alpha, this.score_object, this.nearest_k);
        copy.trainingData = this.trainingData;
        copy.usePositiveWeight = this.usePositiveWeight;
        copy.positive_weight = this.positive_weight;
        return copy;
    }

    /** No parameters are settable through this interface; use the setters instead. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
    }
}
