package edu.wisc.sjm.machlearn.classifiers.lvq;

import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.classifiers.neuralnets.BPFeatureVector;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.util.DoubleMatrix;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/lvq/LVQ.class */
/**
 * Abstract base for Learning Vector Quantization (LVQ) classifiers.
 *
 * <p>The model is a set of {@code nvectors} code vectors (rows of {@link #code_vectors}),
 * each tagged with a class id in {@link #code_class_id}. An example is classified with the
 * class id of the code vector that is closest in Euclidean distance. Subclasses supply the
 * update rule ({@link #runEpoch(double[][], int[])}) and any extra initialization
 * ({@link #initializeLVQ()}).
 *
 * <p>NOTE(review): the class reads and writes its fields without any synchronization, and
 * {@link #classify_(Example)} requires a prior successful {@link #train_(FeatureDataSet)}
 * (both {@code converter} and {@code code_vectors} are only assigned during training).
 */
public abstract class LVQ extends Classifier {
    /** Number of random re-initializations tried by {@link #train_(FeatureDataSet)}. */
    private static final int NUM_RESTARTS = 5;

    /**
     * Argument passed to {@code FeatureDataSet.splitRandom} when carving out the hold-out part.
     * NOTE(review): presumably a percentage -- confirm against {@code splitRandom}.
     */
    private static final double HOLDOUT_SPLIT = 10.0d;

    /** Length of each code vector; set by {@link #initializeVectors(int)}. */
    protected int nweights;
    /** Number of code vectors in the codebook. */
    protected int nvectors;
    /** Codebook: one row per code vector, {@code nweights} columns. */
    protected DoubleMatrix code_vectors;
    /** Class id attached to each code vector (same index as the codebook row). */
    protected int[] code_class_id;
    /** Not used in this class; presumably the learning rate applied by subclasses. */
    protected double alpha;
    /** Number of epochs run per restart in {@link #train_(FeatureDataSet)}. */
    protected int max_epochs;
    /** Turns examples / data sets into numeric vectors; created in {@link #train_(FeatureDataSet)}. */
    protected BPFeatureVector converter;

    /**
     * Runs one training pass over the given examples.
     *
     * @param dArr converted input vectors, one row per example
     * @param iArr output-feature value id of each example, parallel to {@code dArr}
     */
    public abstract void runEpoch(double[][] dArr, int[] iArr);

    /**
     * Subclass hook invoked right after {@link #initializeVectors(int)} every time the
     * codebook is (re)initialized during training.
     */
    public abstract void initializeLVQ();

    /**
     * Converts the data set with {@link #converter} and runs one epoch on it.
     *
     * @param featureDataSet examples to train on for one epoch
     */
    public void runEpoch(FeatureDataSet featureDataSet) {
        double[][] inputs = this.converter.convert(featureDataSet);
        runEpoch(inputs, extractLabels(featureDataSet));
    }

    /** Creates an LVQ with 6 code vectors. */
    public LVQ() {
        this(6);
    }

    /**
     * Creates an LVQ with the given number of code vectors, {@code alpha = 0.05} and
     * {@code max_epochs = 100}.
     *
     * @param i number of code vectors
     */
    public LVQ(int i) {
        this.alpha = 0.05d;
        this.nvectors = i;
        this.max_epochs = 100;
    }

    /**
     * Sets {@link #alpha}.
     *
     * @param d new value
     */
    public void setAlpha(double d) {
        this.alpha = d;
    }

    /**
     * Sets the number of epochs run per restart during training.
     *
     * @param i new epoch count
     */
    public void setMaxEpochs(int i) {
        this.max_epochs = i;
    }

    /**
     * Changes the codebook size and discards the current codebook, so the classifier has to
     * be trained again before it can classify.
     *
     * @param i new number of code vectors
     */
    public void setNumVectors(int i) {
        this.nvectors = i;
        this.code_vectors = null;
        this.code_class_id = null;
    }

    /**
     * Classifies an example with the class id of its nearest code vector.
     *
     * @param example example to classify
     * @return a clone of the example's output feature whose value was set to the predicted
     *         class id; if setting the value throws, the error is passed to
     *         {@code internalError} and the clone is returned as it is
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature classify_(Example example) {
        int winner = getNearest(example);
        Feature prediction = (Feature) example.getOutputFeature().clone();
        try {
            prediction.setValue(this.code_class_id[winner]);
        } catch (Exception e) {
            internalError(e);
        }
        return prediction;
    }

    /**
     * Allocates a fresh randomized codebook and assigns class ids: the first
     * {@code nvectors / 2} code vectors get class 0, the remaining ones class 1.
     *
     * @param i length of each code vector
     */
    public void initializeVectors(int i) {
        this.nweights = i;
        this.code_vectors = new DoubleMatrix(this.nvectors, this.nweights);
        this.code_class_id = new int[this.nvectors];
        // NOTE(review): KStarConstants.FLOOR looks like a decompiler substitution for a plain
        // literal lower bound -- confirm its value before replacing it.
        this.code_vectors.randomize(KStarConstants.FLOOR, 1.0d);
        int half = this.nvectors / 2;
        for (int v = 0; v < this.nvectors; v++) {
            this.code_class_id[v] = (v < half) ? 0 : 1;
        }
    }

    /**
     * Trains the codebook with random restarts and hold-out based model selection.
     *
     * <p>The data set is split in two; part 1 of the split drives the epochs and part 0 is
     * used only to measure accuracy. After every epoch of every restart the hold-out accuracy
     * is compared with the best one seen so far (starting from the accuracy of an untrained
     * random codebook), and the best codebook is kept.
     *
     * <p>The class ids are snapshotted together with the code vectors. Restoring only the
     * vectors would pair them with whatever {@link #code_class_id} the last restart left
     * behind, which is wrong whenever a subclass reassigns class ids in
     * {@link #initializeLVQ()} or {@link #runEpoch(double[][], int[])}.
     *
     * @param featureDataSet training data
     * @throws Exception propagated from the data set / converter operations
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void train_(FeatureDataSet featureDataSet) throws Exception {
        this.converter = new BPFeatureVector(featureDataSet, true);
        FeatureDataSet[] parts = (FeatureDataSet[]) featureDataSet.splitRandom(HOLDOUT_SPLIT);
        FeatureDataSet holdout = parts[0];
        FeatureDataSet fitting = parts[1];
        double[][] inputs = this.converter.convert(fitting);
        int[] labels = extractLabels(fitting);

        // Baseline: an untrained, randomly initialized codebook.
        initializeVectors(this.converter.size());
        initializeLVQ();
        DoubleMatrix bestVectors = this.code_vectors.copyMatrix();
        int[] bestClassIds = snapshotClassIds();
        double bestAccuracy = getAccuracy(holdout);

        for (int restart = 0; restart < NUM_RESTARTS; restart++) {
            initializeVectors(this.converter.size());
            initializeLVQ();
            for (int epoch = 1; epoch <= this.max_epochs; epoch++) {
                runEpoch(inputs, labels);
                double accuracy = getAccuracy(holdout);
                if (accuracy > bestAccuracy) {
                    bestAccuracy = accuracy;
                    bestVectors = this.code_vectors.copyMatrix();
                    bestClassIds = snapshotClassIds();
                }
            }
        }
        // Restore vectors and labels as a pair so they always belong to the same model.
        this.code_vectors = bestVectors;
        this.code_class_id = bestClassIds;
    }

    /**
     * Returns the component-wise difference between an input vector and a code vector.
     *
     * @param dArr input vector
     * @param i    row index of the code vector
     * @return new array holding {@code dArr[k] - code_vectors[i][k]} for every {@code k}
     */
    public double[] getDiff(double[] dArr, int i) {
        double[] diff = new double[dArr.length];
        for (int k = 0; k < dArr.length; k++) {
            diff[k] = dArr[k] - this.code_vectors.get(i, k);
        }
        return diff;
    }

    /**
     * Same as {@link #getDiff(double[], int)} after converting the example with
     * {@link #converter}.
     *
     * @param example example to convert
     * @param i       row index of the code vector
     * @return component-wise difference
     */
    public double[] getDiff(Example example, int i) {
        return getDiff(this.converter.convert(example), i);
    }

    /**
     * Finds the code vector closest to the given input vector.
     *
     * @param dArr input vector
     * @return row index of the nearest code vector; on ties the lowest index wins
     */
    public int getNearest(double[] dArr) {
        int winner = 0;
        double winnerDist = getDist(dArr, this.code_vectors.get(0));
        for (int row = 1; row < this.code_vectors.numRow(); row++) {
            double dist = getDist(dArr, this.code_vectors.get(row));
            if (dist < winnerDist) {
                winnerDist = dist;
                winner = row;
            }
        }
        return winner;
    }

    /**
     * Same as {@link #getNearest(double[])} after converting the example with
     * {@link #converter}.
     *
     * @param example example to convert
     * @return row index of the nearest code vector
     */
    public int getNearest(Example example) {
        return getNearest(this.converter.convert(example));
    }

    /**
     * Euclidean distance between a converted example and a vector.
     *
     * @param example example to convert
     * @param dArr    vector to measure against
     * @return Euclidean distance
     */
    public double getDist(Example example, double[] dArr) {
        return getDist(this.converter.convert(example), dArr);
    }

    /**
     * Euclidean distance between two vectors. Only the first {@code dArr.length} components
     * are read, so {@code dArr2} must be at least as long as {@code dArr}.
     *
     * @param dArr  first vector
     * @param dArr2 second vector
     * @return square root of the summed squared component differences
     */
    public double getDist(double[] dArr, double[] dArr2) {
        double sumSquares = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            double delta = dArr[k] - dArr2[k];
            sumSquares += delta * delta;
        }
        return Math.sqrt(sumSquares);
    }

    /**
     * Sets the number of code vectors from a generic parameter value. The parameter index is
     * ignored; the value must be an {@link Integer}.
     *
     * @param i   parameter index (unused)
     * @param obj new number of code vectors, as an {@code Integer}
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
        setNumVectors(((Integer) obj).intValue());
    }

    /** Collects the output-feature value id of every example in the data set. */
    private int[] extractLabels(FeatureDataSet dataSet) {
        int[] labels = new int[dataSet.size()];
        for (int e = 0; e < labels.length; e++) {
            labels[e] = dataSet.getExample(e).getOutputFeature().getValueId();
        }
        return labels;
    }

    /** Returns an independent copy of the current class ids, or null when none are set. */
    private int[] snapshotClassIds() {
        return this.code_class_id == null ? null : this.code_class_id.clone();
    }
}
