package edu.wisc.sjm.machlearn.regressors.knn;

import edu.wisc.sjm.jutil.misc.JParameters;
import edu.wisc.sjm.machlearn.classifiers.knn.HannibisDist;
import edu.wisc.sjm.machlearn.classifiers.knn.KNNComparator;
import edu.wisc.sjm.machlearn.classifiers.knn.KNNIndex;
import edu.wisc.sjm.machlearn.classifiers.knn.KNNScoreInterface;
import edu.wisc.sjm.machlearn.classifiers.knn.KNNUtil;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;
import edu.wisc.sjm.machlearn.regressors.Regressor;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/regressors/knn/WKNNRegressor.class */
/**
 * Distance-weighted k-nearest-neighbour regressor.
 *
 * <p>{@link #train(FeatureDataSet)} only stores the data set. {@link #regress(Example)} scores the
 * query against the training examples with {@link #getScore(Example, Example)}. It then selects
 * the {@code k_nearest} closest (or all of them when {@code k_nearest == -1}) and returns the mean
 * of their output values, each weighted by {@code 1 / (dist + 1)}.
 *
 * <p>NOTE(review): prediction writes to the shared {@link #scored_examples} and
 * {@link #index_cache} buffers, so a single instance must not be used for concurrent
 * {@code regress} calls.
 */
public class WKNNRegressor extends Regressor {
    /** Training data stored by {@link #train(FeatureDataSet)}. */
    FeatureDataSet data;
    /** Set to 1 by the constructor; not otherwise read in this class. */
    int K;
    /** Supplies the per-feature discrete and continuous distances. */
    protected KNNScoreInterface score_object;
    /** Number of neighbours to average; -1 means use every training example. */
    protected int k_nearest;
    /** Not referenced by the methods in this class. */
    protected FeatureDataSet trainingData;
    /** Created by the constructor; not used by the methods in this class. */
    protected KNNComparator knn_compare;
    /** Per-query working array with one slot per training example; allocated in train(). */
    protected KNNIndex[] scored_examples;
    /** Pool of reusable KNNIndex objects so that scoring does not allocate on every query. */
    protected Vector<KNNIndex> index_cache;
    /** Candidate K values read from JParameters; this class only logs how many there are. */
    protected int[] k_tune;

    /** Creates a 1-nearest-neighbour regressor using {@link HannibisDist}. */
    public WKNNRegressor() {
        this(1);
    }

    /**
     * Creates a regressor with the given neighbour count using {@link HannibisDist}.
     *
     * @param i number of neighbours; -1 means use all training examples
     */
    public WKNNRegressor(int i) {
        this(i, new HannibisDist());
    }

    /**
     * Creates a regressor with the given neighbour count and distance provider.
     *
     * @param i number of neighbours; -1 means use all training examples
     * @param kNNScoreInterface per-feature distance provider
     */
    public WKNNRegressor(int i, KNNScoreInterface kNNScoreInterface) {
        this.K = 1;
        setK(i);
        this.score_object = kNNScoreInterface;
        this.knn_compare = new KNNComparator();
        this.index_cache = new Vector<>();
        this.k_tune = JParameters.getIntArray("edu.wisc.sjm.machlearn.classifiers.knn.KNN.KTune", new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
        if (this.k_tune != null) {
            System.out.println("KNN:tune values:" + this.k_tune.length);
        }
        setProperties();
    }

    /**
     * Sets the number of neighbours used by {@link #regress(Example)} and logs the new value.
     *
     * @param i neighbour count; -1 means use all training examples
     */
    public void setK(int i) {
        this.k_nearest = i;
        System.out.println("KNNRegressor: K is now " + this.k_nearest);
    }

    /**
     * String overload of {@link #setK(int)}.
     *
     * @param str decimal neighbour count
     * @throws NumberFormatException if {@code str} is not an integer
     */
    public void setK(String str) {
        setK(Integer.parseInt(str));
    }

    /** Grows {@link #index_cache} until it holds at least {@code size} reusable entries. */
    private void ensureIndexCache(int size) {
        while (this.index_cache.size() < size) {
            // Initial distance value is overwritten by init() before use;
            // presumably KStarConstants.FLOOR is just a placeholder.
            this.index_cache.add(new KNNIndex(this.index_cache.size(), KStarConstants.FLOOR));
        }
    }

    /**
     * Scores {@code example} against every training example. On return,
     * {@code scored_examples[i]} holds training example {@code i} and its distance.
     *
     * @param example query example
     */
    protected void scoreExamples(Example example) {
        int n = this.data.size();
        ensureIndexCache(n);
        for (int i = 0; i < n; i++) {
            double score = getScore(example, this.data.getExample(i));
            this.scored_examples[i] = this.index_cache.get(i);
            this.scored_examples[i].init(i, score);
        }
    }

    /**
     * Places the {@code i} nearest training examples at the front of {@link #scored_examples}.
     *
     * <p>For {@code i == 1} this is a single linear scan that keeps the minimum distance, with
     * ties resolved in favour of the lower index. Otherwise every example is scored and
     * {@link KNNUtil#findFirstK} is applied; presumably that partially orders the array so that
     * the first {@code i} slots hold the closest examples (TODO confirm against KNNUtil).
     *
     * @param example query example
     * @param i number of neighbours required
     */
    protected void getNearest(Example example, int i) {
        if (i != 1) {
            scoreExamples(example);
            KNNUtil.findFirstK(this.scored_examples, 0, this.scored_examples.length - 1, i);
            return;
        }
        ensureIndexCache(1);
        int bestIndex = 0;
        double bestScore = getScore(example, this.data.getExample(0));
        for (int j = 1; j < this.data.size(); j++) {
            double candidate = getScore(example, this.data.getExample(j));
            // Track the running minimum. The previous code never updated the best score and
            // recorded the stale distance, so it returned the last example that beat example 0.
            if (candidate < bestScore) {
                bestScore = candidate;
                bestIndex = j;
            }
        }
        KNNIndex nearest = this.index_cache.get(0);
        nearest.init(bestIndex, bestScore);
        this.scored_examples[0] = nearest;
    }

    /**
     * Returns the distance between two examples: the sum of the per-feature distances from
     * {@link #score_object}, skipping {@code example}'s output feature. Features whose
     * {@code getType()} is 0 use the discrete distance; all others use the continuous distance.
     * An {@link InvalidFeature} is passed to {@code internalError}, which is inherited and whose
     * behaviour is not visible here. In that case the feature adds nothing, provided
     * internalError returns normally.
     *
     * @param example query example; its feature count and output index drive the loop
     * @param example2 training example
     * @return summed distance
     */
    protected double getScore(Example example, Example example2) {
        double d = 0.0d;
        for (int i = 0; i < example.size(); i++) {
            if (i != example.getOutputIndex()) {
                Feature feature = example.getFeature(i);
                Feature feature2 = example2.getFeature(i);
                try {
                    d = feature.getType() == 0 ? d + this.score_object.getDiscreteDist(feature, feature2) : d + this.score_object.getContinuousDist(feature, feature2);
                } catch (InvalidFeature e) {
                    internalError(e);
                }
            }
        }
        return d;
    }

    /**
     * Predicts the output of {@code example}. The prediction is the mean of the neighbours'
     * output values, each weighted by {@code 1 / (dist + 1)}.
     *
     * <p>With an empty training set the weighted sum is 0/0, which yields {@code NaN}.
     *
     * @param example query example
     * @return the distance-weighted mean output of the selected neighbours
     * @throws IllegalStateException if called before {@link #train(FeatureDataSet)}
     * @throws Exception as declared by {@link Regressor}
     */
    @Override // edu.wisc.sjm.machlearn.regressors.Regressor
    public double regress(Example example) throws Exception {
        if (this.data == null || this.scored_examples == null) {
            throw new IllegalStateException("WKNNRegressor: regress called before train");
        }
        int min;
        if (this.k_nearest == -1) {
            min = this.scored_examples.length;
            scoreExamples(example);
        } else {
            min = Math.min(this.k_nearest, this.scored_examples.length);
            getNearest(example, min);
        }
        double d = 0.0d;
        double d2 = 0.0d;
        for (int i = 0; i < min; i++) {
            KNNIndex kNNIndex = this.scored_examples[i];
            // Closer neighbours get larger weights; the +1 avoids dividing by zero at distance 0.
            double dist = 1.0d / (kNNIndex.getDist() + 1.0d);
            d += dist;
            d2 += dist * this.data.getExample(kNNIndex.getIndex()).getOutputFeature().getDValue();
        }
        return d2 / d;
    }

    /**
     * Stores the training data and allocates the per-query working array.
     *
     * @param featureDataSet training examples
     * @throws Exception as declared by {@link Regressor}
     */
    @Override // edu.wisc.sjm.machlearn.regressors.Regressor
    public void train(FeatureDataSet featureDataSet) throws Exception {
        this.data = featureDataSet;
        this.scored_examples = new KNNIndex[this.data.size()];
    }
}
