package com.sjm.machlearn.classifiers.neuralnets;

import com.sjm.machlearn.classifiers.Classifier;
import com.sjm.machlearn.dataset.DataSet;
import com.sjm.machlearn.dataset.Example;
import com.sjm.machlearn.dataset.Feature;
import com.sjm.machlearn.dataset.FeatureIdList;
import com.sjm.machlearn.exceptions.InvalidFeature;
import com.sjm.machlearn.util.DoubleMatrix;
import com.sjm.machlearn.util.Util;
import java.io.PrintWriter;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:com/sjm/machlearn/classifiers/neuralnets/BPNeuralNet.class */
/**
 * Feed-forward neural network with a single hidden layer of {@link #activate sigmoid} units,
 * trained by stochastic back-propagation with an adaptive learning rate.
 *
 * <p>Weight layout: {@code whidden} has {@code ninputs + 1} rows and {@code nhidden} columns;
 * {@code woutput} has {@code nhidden + 1} rows and {@code noutputs} columns. In both matrices
 * row 0 holds the bias weights and row {@code r > 0} holds the weights coming from unit
 * {@code r - 1} of the previous layer.
 */
public class BPNeuralNet extends Classifier {
    /** Number of passes over the training data made by {@link #train(DataSet, DataSet)}. */
    private static final int MAX_EPOCHS = 100;
    /** Tune sets smaller than this fraction of the training set are regenerated. */
    private static final double MIN_TUNE_RATIO = 0.09d;
    /** Multiplier applied to the learning rate when an update did not reduce the error. */
    private static final double LEARN_RATE_DECAY = 0.9d;
    /** Multiplier applied to the learning rate when an update reduced the error. */
    private static final double LEARN_RATE_GROWTH = 1.1d;
    /** Upper bound on the learning rate when it grows. */
    private static final double MAX_LEARN_RATE = 0.5d;

    /** Input-to-hidden weights, (ninputs + 1) x nhidden; row 0 is the bias row. */
    protected DoubleMatrix whidden;
    /** Hidden-to-output weights, (nhidden + 1) x noutputs; row 0 is the bias row. */
    protected DoubleMatrix woutput;
    /** Converts examples / data sets into numeric input and target vectors. */
    protected BPFeatureVector feature_converter;
    protected int ninputs;
    protected int noutputs;
    protected int nhidden;
    /** Current step size; adapted after every training example in {@link #trainExample}. */
    protected double learn_rate;
    protected int trainepoch;
    /** Stored by {@link #setTestSet}; not read elsewhere in this class. */
    private DataSet test_set;
    /** Optional accuracy log; never assigned in this class, so it may be null. */
    private PrintWriter accfile;
    private FeatureIdList fid;
    /** Hidden-layer size as a percentage of (ninputs + noutputs). */
    private double hidden_percentage;

    /** Creates a network whose hidden layer is 50% of (inputs + outputs). */
    public BPNeuralNet() {
        this(50.0d);
    }

    /**
     * Creates a network with the given hidden-layer size percentage.
     *
     * @param d hidden-layer size as a percentage of (number of inputs + number of outputs)
     */
    public BPNeuralNet(double d) {
        this.hidden_percentage = d;
        this.learn_rate = 0.1d;
    }

    /**
     * Logistic sigmoid activation.
     *
     * @param d net input to a unit
     * @return {@code 1 / (1 + e^-d)}, in the open interval (0, 1)
     */
    public double activate(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Classifies an example by taking the index of the most active output unit.
     *
     * @param example the example to classify
     * @return a clone of the example's output feature with its value set to the argmax output
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public Feature classify(Example example) {
        double[] output = getOutput(example);
        Feature feature = (Feature) example.getOutputFeature().clone();
        try {
            feature.setValue(Util.argmax(output));
        } catch (InvalidFeature e) {
            internalError(e);
        }
        return feature;
    }

    /** Returns a new, untrained network with the same hidden-layer percentage. */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        return new BPNeuralNet(this.hidden_percentage);
    }

    /**
     * Fraction of already-converted examples whose argmax output equals the expected id.
     *
     * @param dArr converted input vectors, one per example
     * @param iArr expected output value ids, parallel to {@code dArr}
     * @return number of correct predictions divided by {@code iArr.length}
     */
    public double getAccuracy(double[][] dArr, int[] iArr) {
        double correct = 0.0d;
        for (int i = 0; i < iArr.length; i++) {
            if (Util.argmax(getOutput(dArr[i])) == iArr[i]) {
                correct += 1.0d;
            }
        }
        return correct / iArr.length;
    }

    /**
     * Per-output error {@code target - output} for one input vector.
     *
     * @param dArr  input vector
     * @param dArr2 target vector (at least {@code noutputs} long)
     * @return a new array of length {@code noutputs}
     */
    public double[] getError(double[] dArr, double[] dArr2) {
        double[] output = getOutput(dArr);
        for (int k = 0; k < this.noutputs; k++) {
            output[k] = dArr2[k] - output[k];
        }
        return output;
    }

    /**
     * Difference between the first two output activations for the example.
     * Assumes the network has at least two outputs.
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public double getExampleWeight(Example example) {
        double[] output = getOutput(example);
        return output[0] - output[1];
    }

    /** Forward pass for an example, after conversion through {@code feature_converter}. */
    protected double[] getOutput(Example example) {
        return getOutput(this.feature_converter.convert(example));
    }

    /**
     * Forward pass for a numeric input vector.
     *
     * @param inputs input vector (at least {@code ninputs} long)
     * @return output-layer sigmoid activations, length {@code noutputs}
     */
    protected double[] getOutput(double[] inputs) {
        return forwardOutputs(forwardHidden(inputs));
    }

    /** Computes hidden-layer activations: sigmoid(bias + sum of weighted inputs) per unit. */
    private double[] forwardHidden(double[] inputs) {
        double[] hidden = new double[this.nhidden];
        for (int j = 0; j < this.nhidden; j++) {
            double net = this.whidden.get(0, j);
            for (int i = 0; i < this.ninputs; i++) {
                net += this.whidden.get(i + 1, j) * inputs[i];
            }
            hidden[j] = activate(net);
        }
        return hidden;
    }

    /** Computes output-layer activations from the hidden-layer activations. */
    private double[] forwardOutputs(double[] hidden) {
        double[] outputs = new double[this.noutputs];
        for (int k = 0; k < this.noutputs; k++) {
            double net = this.woutput.get(0, k);
            for (int j = 0; j < this.nhidden; j++) {
                net += this.woutput.get(j + 1, k) * hidden[j];
            }
            outputs[k] = activate(net);
        }
        return outputs;
    }

    /**
     * Euclidean norm of the first {@code noutputs} entries of an error vector.
     *
     * @param dArr error vector, e.g. from {@link #getError}
     * @return {@code sqrt(sum(e_k^2))}
     */
    public double getTotalError(double[] dArr) {
        double sumSq = 0.0d;
        for (int k = 0; k < this.noutputs; k++) {
            sumSq += dArr[k] * dArr[k];
        }
        return Math.sqrt(sumSq);
    }

    /** Euclidean norm of {@code target - output} for one input/target pair. */
    public double getTotalError(double[] dArr, double[] dArr2) {
        return getTotalError(getError(dArr, dArr2));
    }

    /** Randomizes all weights using the bounds -0.3 and 0.3. */
    public void initializeWeights() {
        initializeWeights(-0.3d, 0.3d);
    }

    /**
     * Randomizes every row of both weight matrices via {@code Util.randomizeDbl}.
     * Presumably draws uniformly in [d, d2] — confirm against Util.
     *
     * @param d  lower bound
     * @param d2 upper bound
     */
    public void initializeWeights(double d, double d2) {
        for (int i = 0; i <= this.ninputs; i++) {
            Util.randomizeDbl(this.whidden.get(i), d, d2);
        }
        for (int j = 0; j <= this.nhidden; j++) {
            Util.randomizeDbl(this.woutput.get(j), d, d2);
        }
    }

    @Override // com.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        return "BPNeuralNet\n";
    }

    /**
     * Prints train/tune accuracy and the current learning rate to stdout. It also appends a
     * tab-separated line to the accuracy log file when one has been set.
     *
     * @param i        epoch number
     * @param dataSet  training set
     * @param dataSet2 tune set
     */
    public void reportAccuracies(int i, DataSet dataSet, DataSet dataSet2) {
        double trainAccuracy = getAccuracy(dataSet);
        double tuneAccuracy = getAccuracy(dataSet2);
        System.out.println("=============================");
        System.out.println("EPOCH #" + i);
        System.out.println("=============================");
        System.out.println("Train Accuracy:" + trainAccuracy);
        System.out.println("Tune Accuracy:" + tuneAccuracy);
        System.out.println("Learning Rate:" + this.learn_rate);
        // accfile is never assigned in this class; writing unconditionally always threw an NPE.
        if (this.accfile != null) {
            this.accfile.println(i + "\t" + trainAccuracy + "\t" + tuneAccuracy);
            this.accfile.flush();
        }
    }

    /**
     * Runs {@link #trainExample} once per example in an order from {@code Util.randomIntList}
     * (presumably a random permutation of 0..n-1 — confirm).
     *
     * @param dArr  converted input vectors
     * @param dArr2 target vectors, parallel to {@code dArr}
     */
    public void runEpoch(double[][] dArr, double[][] dArr2) throws Exception {
        int[] order = Util.randomIntList(0, dArr.length - 1);
        for (int idx : order) {
            trainExample(dArr[idx], dArr2[idx]);
        }
    }

    /**
     * Sets the hidden-layer percentage, clamped to at least 50. The parameter index is ignored.
     *
     * @param i   parameter index (unused)
     * @param obj any {@link Number}; previously only {@code Integer} was accepted
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
        this.hidden_percentage = Util.max(50.0d, ((Number) obj).doubleValue());
    }

    public void setTestSet(DataSet dataSet) {
        this.test_set = dataSet;
    }

    /**
     * Builds the feature converter, sizes and allocates the weight matrices, then splits
     * {@code dataSet} with {@code splitRandom(10.0)}. Element [1] of the split becomes the
     * training set and element [0] the tune set.
     */
    @Override // com.sjm.machlearn.classifiers.Classifier
    public void train(DataSet dataSet) throws Exception {
        this.feature_converter = new BPFeatureVector(dataSet, true);
        this.ninputs = this.feature_converter.size();
        this.noutputs = dataSet.getOutputFeatureId().numValues();
        this.nhidden = (int) ((this.ninputs + this.noutputs) * (this.hidden_percentage / 100.0d));
        this.whidden = new DoubleMatrix(this.ninputs + 1, this.nhidden);
        this.woutput = new DoubleMatrix(this.nhidden + 1, this.noutputs);
        DataSet[] split = dataSet.splitRandom(10.0d);
        train(split[1], split[0]);
    }

    /**
     * Trains for {@value #MAX_EPOCHS} epochs with randomly initialized weights. Keeps the
     * weights that scored the highest accuracy on the tune set.
     *
     * @param dataSet  training set
     * @param dataSet2 tune set; replaced by a fresh split of {@code dataSet} if it is smaller than
     *                 {@value #MIN_TUNE_RATIO} of the training set
     */
    public void train(DataSet dataSet, DataSet dataSet2) throws Exception {
        DataSet trainSet = dataSet;
        DataSet tuneSet = dataSet2;
        // Floating-point division: integer division truncated to 0 and always regenerated.
        if ((double) dataSet2.size() / dataSet.size() < MIN_TUNE_RATIO) {
            System.out.println("Warning tune set is < 9% of train set");
            System.out.println("Regenerating tune set");
            DataSet[] split = dataSet.splitRandom(10.0d);
            trainSet = split[1];
            tuneSet = split[0];
        }
        initializeWeights();
        double bestAccuracy = getAccuracy(tuneSet);
        DoubleMatrix bestHidden = this.whidden.copyMatrix();
        DoubleMatrix bestOutput = this.woutput.copyMatrix();
        double[][] trainInputs = this.feature_converter.convert(trainSet);
        double[][] trainTargets = this.feature_converter.getOutputValues(trainSet);
        double[][] tuneInputs = this.feature_converter.convert(tuneSet);
        int[] tuneIds = tuneSet.getOutputValueIds();
        for (int epoch = 1; epoch <= MAX_EPOCHS; epoch++) {
            runEpoch(trainInputs, trainTargets);
            double accuracy = getAccuracy(tuneInputs, tuneIds);
            if (accuracy > bestAccuracy) {
                bestAccuracy = accuracy;
                bestHidden = this.whidden.copyMatrix();
                bestOutput = this.woutput.copyMatrix();
            }
        }
        this.whidden = bestHidden;
        this.woutput = bestOutput;
    }

    /**
     * One stochastic back-propagation step on a single example, followed by a learning-rate
     * adjustment. The rate is multiplied by {@value #LEARN_RATE_DECAY} when the step did not
     * reduce the total error. Otherwise it is grown by {@value #LEARN_RATE_GROWTH}, capped at
     * {@value #MAX_LEARN_RATE}.
     *
     * @param inputs  input vector (at least {@code ninputs} long)
     * @param targets target vector (at least {@code noutputs} long)
     */
    public void trainExample(double[] inputs, double[] targets) {
        double[] hidden = forwardHidden(inputs);
        double[] outputs = forwardOutputs(hidden);
        double errorBefore = getTotalError(inputs, targets);

        // Output deltas: sigmoid derivative times (target - output).
        double[] outputDelta = new double[this.noutputs];
        for (int k = 0; k < this.noutputs; k++) {
            outputDelta[k] = outputs[k] * (1.0d - outputs[k]) * (targets[k] - outputs[k]);
        }

        // Hidden deltas, indexed by hidden unit j. The unit's outgoing weights are row j + 1
        // of woutput (row 0 is the bias). The original mixed row and unit indices, so unit j
        // was trained with unit j - 1's delta.
        double[] hiddenDelta = new double[this.nhidden];
        for (int j = 0; j < this.nhidden; j++) {
            double backError = 0.0d;
            for (int k = 0; k < this.noutputs; k++) {
                backError += this.woutput.get(j + 1, k) * outputDelta[k];
            }
            hiddenDelta[j] = backError * hidden[j] * (1.0d - hidden[j]);
        }

        // Both deltas are computed from the pre-update weights before either matrix changes.
        DoubleMatrix outputStep = this.woutput.zeroMatrix();
        for (int k = 0; k < this.noutputs; k++) {
            outputStep.set(0, k, this.learn_rate * outputDelta[k]);
            for (int j = 0; j < this.nhidden; j++) {
                outputStep.set(j + 1, k, this.learn_rate * outputDelta[k] * hidden[j]);
            }
        }
        DoubleMatrix hiddenStep = this.whidden.zeroMatrix();
        for (int j = 0; j < this.nhidden; j++) {
            hiddenStep.set(0, j, this.learn_rate * hiddenDelta[j]);
            for (int i = 0; i < this.ninputs; i++) {
                hiddenStep.set(i + 1, j, this.learn_rate * hiddenDelta[j] * inputs[i]);
            }
        }
        this.whidden.add(hiddenStep);
        this.woutput.add(outputStep);

        if (getTotalError(inputs, targets) >= errorBefore) {
            this.learn_rate *= LEARN_RATE_DECAY;
        } else {
            this.learn_rate = Util.min(MAX_LEARN_RATE, this.learn_rate * LEARN_RATE_GROWTH);
        }
    }

    /** No-op: this classifier performs no parameter tuning here. */
    public void tune(DataSet dataSet, Object[] objArr) {
    }
}
