package com.sjm.machlearn.classifiers.knn;

import com.sjm.machlearn.classifiers.Classifier;
import com.sjm.machlearn.dataset.DataSet;
import com.sjm.machlearn.dataset.Example;
import com.sjm.machlearn.dataset.Feature;
import com.sjm.machlearn.exceptions.InvalidFeature;
import com.sjm.machlearn.util.Util;
import java.util.Iterator;
import java.util.TreeSet;

/* loaded from: input_file:com/sjm/machlearn/classifiers/knn/KNN.class */
/**
 * k-nearest-neighbour classifier.
 *
 * <p>Training only stores a reference to the training {@link DataSet}. Classification scores
 * every stored example against the query with {@link #getScore(Example, Example)} and ranks
 * the examples with a {@code KNNComparator}. It then takes the first {@code k} entries and
 * assigns the output value that receives the most votes among them.
 */
public class KNN extends Classifier {
    /** Distance/score strategy used to compare individual features. */
    protected KNNScoreInterface score_object;
    /** Number of neighbours that vote on a classification. */
    protected int k_nearest;
    /** Training examples; held by reference, not copied. */
    protected DataSet trainingData;

    /** Creates a 1-nearest-neighbour classifier using {@code HannibisDist} scoring. */
    public KNN() {
        this(1);
    }

    /**
     * Creates a classifier with the given neighbour count and {@code HannibisDist} scoring.
     *
     * @param i number of neighbours (k)
     */
    public KNN(int i) {
        this(i, new HannibisDist());
    }

    /**
     * Creates a classifier with the given neighbour count and scoring strategy.
     *
     * @param i number of neighbours (k)
     * @param kNNScoreInterface strategy used to compute per-feature distances
     */
    public KNN(int i, KNNScoreInterface kNNScoreInterface) {
        this.k_nearest = i;
        this.score_object = kNNScoreInterface;
    }

    /**
     * Classifies an example by majority vote of its nearest training neighbours.
     *
     * @param example the example to classify
     * @return a clone of the example's output feature with its value set to the winning class
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public Feature classify(Example example) {
        Feature feature = (Feature) example.getOutputFeature().clone();
        getVote(feature, getNearest(example));
        return feature;
    }

    /**
     * Returns a new KNN with the same k. The new instance shares (does not copy) this
     * instance's score strategy and training data.
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        KNN knn = new KNN(this.k_nearest, this.score_object);
        knn.trainingData = this.trainingData;
        return knn;
    }

    /** @return the current neighbour count k */
    public int getK() {
        return this.k_nearest;
    }

    /**
     * Finds the {@code i} best-ranked examples of {@code dataSet} relative to {@code example}.
     *
     * <p>If {@code i} exceeds the number of ranked examples, all of them are returned. A
     * negative {@code i} yields an empty array.
     *
     * @param dataSet candidate neighbours
     * @param example the query example
     * @param i requested number of neighbours
     * @return up to {@code i} neighbours in {@code KNNComparator} order
     * @throws IllegalStateException if the ranking set rejects an entry as a duplicate
     */
    protected KNNIndex[] getNearest(DataSet dataSet, Example example, int i) {
        TreeSet<KNNIndex> ranked = new TreeSet<KNNIndex>(new KNNComparator());
        int size = dataSet.size();
        for (int idx = 0; idx < size; idx++) {
            // A TreeSet drops elements its comparator reports as equal. Presumably
            // KNNComparator breaks score ties (e.g. by index). If it does not, a neighbour would
            // be lost silently, so fail loudly instead of killing the JVM as before.
            if (!ranked.add(new KNNIndex(idx, getScore(example, dataSet.get(idx))))) {
                throw new IllegalStateException(
                        "KNN: ranking rejected training example " + idx + " as a duplicate");
            }
        }
        // Never request more neighbours than exist, and never a negative count.
        int count = Math.max(0, Math.min(i, ranked.size()));
        KNNIndex[] nearest = new KNNIndex[count];
        Iterator<KNNIndex> it = ranked.iterator();
        for (int n = 0; n < count; n++) {
            nearest[n] = it.next();
        }
        return nearest;
    }

    /**
     * Finds the {@code k_nearest} neighbours of {@code example} in the training data.
     *
     * @param example the query example
     * @return up to {@code k_nearest} neighbours
     */
    protected KNNIndex[] getNearest(Example example) {
        return getNearest(this.trainingData, example, this.k_nearest);
    }

    /**
     * Sums per-feature distances between two examples, skipping the output feature.
     *
     * <p>Features whose {@code getType()} is 0 are scored with
     * {@code score_object.getDiscreteDist}; all others use {@code getContinuousDist}.
     * Iteration runs over {@code example}'s features. {@code example2} is assumed to have
     * the same layout (TODO confirm).
     *
     * @param example the query example
     * @param example2 the example being compared against
     * @return the accumulated score
     */
    protected double getScore(Example example, Example example2) {
        double total = 0.0d;
        int outputIndex = example.getOutputIndex();
        for (int i = 0; i < example.size(); i++) {
            if (i == outputIndex) {
                continue; // the class label must not contribute to the distance
            }
            Feature feature = example.getFeature(i);
            Feature feature2 = example2.getFeature(i);
            try {
                if (feature.getType() == 0) {
                    total += this.score_object.getDiscreteDist(feature, feature2);
                } else {
                    total += this.score_object.getContinuousDist(feature, feature2);
                }
            } catch (InvalidFeature e) {
                internalError(e);
            }
        }
        return total;
    }

    /**
     * Sets {@code feature} to the output value most common among the first {@code i}
     * neighbours.
     *
     * @param dataSet the data set the neighbour indices refer to
     * @param feature output feature to receive the winning value (mutated in place)
     * @param kNNIndexArr ranked neighbours
     * @param i number of leading neighbours that vote; clamped to the array length
     */
    protected void getVote(DataSet dataSet, Feature feature, KNNIndex[] kNNIndexArr, int i) {
        int[] votes = new int[feature.numValues()];
        int count = Math.min(i, kNNIndexArr.length);
        for (int n = 0; n < count; n++) {
            int valueId = dataSet.get(kNNIndexArr[n].getIndex()).getOutputFeature().getValueId();
            votes[valueId]++;
        }
        try {
            feature.setValue(Util.argmax(votes));
        } catch (InvalidFeature e) {
            internalError(e);
        }
    }

    /**
     * Sets {@code feature} to the output value most common among all given neighbours, which
     * index into the training data.
     *
     * @param feature output feature to receive the winning value (mutated in place)
     * @param kNNIndexArr neighbours from the training data
     */
    protected void getVote(Feature feature, KNNIndex[] kNNIndexArr) {
        getVote(this.trainingData, feature, kNNIndexArr, kNNIndexArr.length);
    }

    /** @return a textual description: {@code "KNN\nk=<k>"} */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        return "KNN\nk=" + this.k_nearest;
    }

    /**
     * Sets the neighbour count.
     *
     * @param i new k
     */
    public void setK(int i) {
        this.k_nearest = i;
    }

    /**
     * Sets k from a numeric parameter. KNN has a single tunable parameter, so the index
     * {@code i} is ignored.
     *
     * @param i parameter index (unused)
     * @param obj the new k; must be a {@link Number} such as an {@link Integer}
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
        setK(((Number) obj).intValue());
    }

    /**
     * Stores the training data by reference.
     *
     * @param dataSet the training examples
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public void train(DataSet dataSet) {
        this.trainingData = dataSet;
    }

    /**
     * Picks k from a set of candidates using leave-one-out evaluation, then trains on the
     * full data set.
     *
     * <p>For each example, {@code splitJackKnife(i)} is called. Element 0 of the result
     * supplies the query (its first example; presumably the single held-out example). Element
     * 1 supplies the neighbours. Each candidate k is scored by how often its vote matches
     * the query's true output value. The candidate with the most matches, as chosen by
     * {@code Util.argmax}, becomes the new k.
     *
     * @param dataSet the data to tune and then train on
     * @param objArr {@code objArr[0]} must be an {@code int[]} of candidate k values
     */
    public void tune(DataSet dataSet, Object[] objArr) {
        int[] candidates = (int[]) objArr[0];
        // Never look further than the largest candidate, nor beyond the held-out training size.
        int maxK = Util.min(Util.max(candidates), dataSet.size() - 1);
        int[] correct = new int[candidates.length];
        Feature feature = (Feature) dataSet.getOutputFeature().clone();
        for (int i = 0; i < dataSet.size(); i++) {
            DataSet[] split = dataSet.splitJackKnife(i);
            Example query = split[0].get(0);
            DataSet rest = split[1];
            // Rank once at the largest k; smaller candidates reuse the prefix.
            KNNIndex[] nearest = getNearest(rest, query, maxK);
            Feature actual = query.getOutputFeature();
            for (int c = 0; c < candidates.length; c++) {
                getVote(rest, feature, nearest, Util.min(candidates[c], maxK));
                if (feature.getValue().equals(actual.getValue())) {
                    correct[c]++;
                }
            }
        }
        this.k_nearest = candidates[Util.argmax(correct)];
        train(dataSet);
    }
}
