package edu.wisc.sjm.machlearn.regressors.neuralnets;

import edu.wisc.sjm.jutil.misc.MainClass;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.FeatureIdList;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;
import edu.wisc.sjm.machlearn.util.DoubleMatrix;
import edu.wisc.sjm.machlearn.util.Util;
import java.io.FileWriter;
import java.io.PrintWriter;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/regressors/neuralnets/BPNeuralNet.class */
/**
 * Feed-forward neural network with one hidden layer of sigmoid units, trained
 * by per-example back-propagation with tune-set based model selection.
 *
 * <p>Weight layout (row 0 of each matrix is the bias row):
 * <ul>
 *   <li>{@code whidden} is {@code (ninputs + 1) x nhidden}; entry {@code (i, h)}
 *       with {@code i >= 1} connects input {@code i - 1} to hidden unit {@code h}.</li>
 *   <li>{@code woutput} is {@code (nhidden + 1) x noutputs}; entry {@code (h, o)}
 *       with {@code h >= 1} connects hidden unit {@code h - 1} to output {@code o}.</li>
 * </ul>
 *
 * <p>The network dimensions and weight matrices are created by
 * {@link #train(FeatureDataSet)}; methods that read the weights must not be
 * called before that (or before a subclass has set them up).
 */
public class BPNeuralNet extends MainClass {
    /** Input-to-hidden weights, bias in row 0. */
    protected DoubleMatrix whidden;
    /** Hidden-to-output weights, bias in row 0. */
    protected DoubleMatrix woutput;
    protected int ninputs;
    protected int noutputs;
    protected int nhidden;
    /** Current learning rate; adapted after every training example. */
    protected double learn_rate;
    /** Number of epochs run by {@link #train(FeatureDataSet, FeatureDataSet)}. */
    protected int maxEpoch;
    protected int trainepoch;
    private FeatureDataSet test_set;
    /** Log of per-epoch errors; opened (and closed) by the two-argument train. */
    private PrintWriter accfile;
    private FeatureIdList fid;
    /** Hidden layer size as a percentage of (inputs + outputs). */
    private double hidden_percentage;

    /** Creates a network with 50% hidden units, 100 epochs and learning rate 0.1. */
    public BPNeuralNet() {
        this(50.0d, 100, 0.1d);
    }

    /**
     * @param hiddenPercentage hidden layer size as a percentage of (inputs + outputs)
     * @param maxEpoch         number of training epochs
     * @param learnRate        initial learning rate
     */
    public BPNeuralNet(double hiddenPercentage, int maxEpoch, double learnRate) {
        this.hidden_percentage = hiddenPercentage;
        this.maxEpoch = maxEpoch;
        this.learn_rate = learnRate;
    }

    /**
     * Creates a network with 100 epochs and learning rate 0.1.
     *
     * @param hiddenPercentage hidden layer size as a percentage of (inputs + outputs)
     */
    public BPNeuralNet(double hiddenPercentage) {
        this(hiddenPercentage, 100, 0.1d);
    }

    public void setMaxEpoch(int maxEpoch) {
        this.maxEpoch = maxEpoch;
    }

    public void setLearnRate(double learnRate) {
        this.learn_rate = learnRate;
    }

    public void setTestSet(FeatureDataSet testSet) {
        this.test_set = testSet;
    }

    /** Randomizes all weights using the bounds -0.3 and 0.3. */
    public void initializeWeights() {
        initializeWeights(-0.3d, 0.3d);
    }

    /**
     * Randomizes every row of both weight matrices via {@code Util.randomizeDbl}.
     *
     * @param low  first bound handed to {@code Util.randomizeDbl} (presumably the minimum)
     * @param high second bound handed to {@code Util.randomizeDbl} (presumably the maximum)
     */
    public void initializeWeights(double low, double high) {
        for (int row = 0; row <= this.ninputs; row++) {
            Util.randomizeDbl(this.whidden.get(row), low, high);
        }
        for (int row = 0; row <= this.nhidden; row++) {
            Util.randomizeDbl(this.woutput.get(row), low, high);
        }
    }

    /**
     * Runs the network on an example, using every feature except the output
     * feature as input (in feature order).
     *
     * @return the output activations, length {@code noutputs}
     */
    protected double[] getOutput(Example example) {
        double[] inputs = new double[this.ninputs];
        int next = 0;
        for (int f = 0; f < example.numFeatures(); f++) {
            if (f != example.getOutputIndex()) {
                inputs[next] = example.get(f).getDValue();
                next++;
            }
        }
        return getOutput(inputs);
    }

    /**
     * Runs a forward pass.
     *
     * @param inputs input vector of length {@code ninputs}
     * @return the output activations, length {@code noutputs}
     */
    protected double[] getOutput(double[] inputs) {
        return computeOutputActivations(computeHiddenActivations(inputs));
    }

    /** Computes the activation of every hidden unit for the given input vector. */
    private double[] computeHiddenActivations(double[] inputs) {
        double[] hidden = new double[this.nhidden];
        for (int h = 0; h < this.nhidden; h++) {
            double net = this.whidden.get(0, h); // bias
            for (int in = 1; in <= this.ninputs; in++) {
                net += this.whidden.get(in, h) * inputs[in - 1];
            }
            hidden[h] = activate(net);
        }
        return hidden;
    }

    /** Computes the activation of every output unit from the hidden activations. */
    private double[] computeOutputActivations(double[] hidden) {
        double[] outputs = new double[this.noutputs];
        for (int o = 0; o < this.noutputs; o++) {
            double net = this.woutput.get(0, o); // bias
            for (int h = 1; h <= this.nhidden; h++) {
                net += this.woutput.get(h, o) * hidden[h - 1];
            }
            outputs[o] = activate(net);
        }
        return outputs;
    }

    /**
     * Sum of squared errors over a set of examples. Only the first output
     * component of each example contributes.
     *
     * @param inputs  one input vector per example
     * @param targets one target vector per example
     */
    public double getSE(double[][] inputs, double[][] targets) {
        double sum = 0.0d;
        for (int ex = 0; ex < inputs.length; ex++) {
            double err = getError(inputs[ex], targets[ex])[0];
            sum += err * err;
        }
        return sum;
    }

    /** Sum of squared errors over a data set. */
    public double getSE(FeatureDataSet featureDataSet) throws InvalidFeature {
        return getSE(makeInput(featureDataSet), makeOutput(featureDataSet));
    }

    /**
     * Extracts the input matrix of a data set: one row per example, containing
     * every feature value except the output feature, in feature order.
     */
    public double[][] makeInput(FeatureDataSet featureDataSet) throws InvalidFeature {
        double[][] inputs = new double[featureDataSet.size()][this.ninputs];
        for (int ex = 0; ex < featureDataSet.size(); ex++) {
            int col = 0;
            for (int f = 0; f < featureDataSet.numFeatures(); f++) {
                if (f != featureDataSet.getOutputIndex()) {
                    inputs[ex][col] = featureDataSet.get(ex, f).getDValue();
                    col++;
                }
            }
        }
        return inputs;
    }

    /** Extracts the target matrix of a data set: one single-element row per example. */
    public double[][] makeOutput(FeatureDataSet featureDataSet) throws InvalidFeature {
        double[][] targets = new double[featureDataSet.size()][1];
        for (int ex = 0; ex < featureDataSet.size(); ex++) {
            targets[ex][0] = featureDataSet.get(ex, featureDataSet.getOutputIndex()).getDValue();
        }
        return targets;
    }

    /** Mean squared error over a data set. */
    public double getMSE(FeatureDataSet featureDataSet) throws InvalidFeature {
        return getSE(featureDataSet) / featureDataSet.size();
    }

    /** Mean squared error over explicit input/target arrays. */
    public double getMSE(double[][] inputs, double[][] targets) {
        return getSE(inputs, targets) / inputs.length;
    }

    /** Logistic sigmoid activation. */
    public double activate(double net) {
        return 1.0d / (1.0d + Math.exp(-net));
    }

    /**
     * Difference between the first and second network outputs for an example.
     * Requires a network with at least two outputs; {@link #train(FeatureDataSet)}
     * configures only one, in which case this throws
     * {@link ArrayIndexOutOfBoundsException}.
     */
    public double getExampleWeight(Example example) {
        double[] output = getOutput(example);
        return output[0] - output[1];
    }

    /**
     * Trains the already-dimensioned network for {@code maxEpoch} epochs and
     * keeps the weights of the epoch with the lowest tune-set squared error.
     * Per-epoch errors are written to {@code bp_accuracies.txt}.
     *
     * <p>If the tune set is smaller than 9% of the train set, the supplied tune
     * set is ignored and {@code featureDataSet} is re-split instead.
     *
     * @param featureDataSet  training examples
     * @param featureDataSet2 tuning examples used for model selection
     * @throws Exception on data access or file errors
     */
    public void train(FeatureDataSet featureDataSet, FeatureDataSet featureDataSet2) throws Exception {
        FeatureDataSet trainSet = featureDataSet;
        FeatureDataSet tuneSet = featureDataSet2;
        // Floating-point ratio: integer division would truncate to 0 for any tune
        // set smaller than the train set and always trigger the re-split.
        if ((double) featureDataSet2.size() / featureDataSet.size() < 0.09d) {
            System.out.println("Warning tune set is < 9% of train set");
            System.out.println("Regenerating tune set");
            DataSet[] parts = featureDataSet.splitRandom(10.0d);
            trainSet = (FeatureDataSet) parts[1];
            tuneSet = (FeatureDataSet) parts[0];
        }
        System.out.println("BPNN Sean McIlwain");
        System.out.println("# inputs:" + this.ninputs);
        System.out.println("# outputs:" + this.noutputs);
        System.out.println("# hidden:" + this.nhidden);
        System.out.println("# Examples in train set:" + trainSet.size());
        System.out.println("# Examples in tune  set:" + tuneSet.size());
        initializeWeights();

        // Epoch 0 (random weights) is the initial "best" candidate.
        int bestEpoch = 0;
        double bestTuneSE = getSE(tuneSet);
        DoubleMatrix bestHidden = this.whidden.copyMatrix();
        DoubleMatrix bestOutput = this.woutput.copyMatrix();

        PrintWriter log = new PrintWriter(new FileWriter("bp_accuracies.txt"));
        this.accfile = log;
        try {
            log.println("epoch 0 : SE " + bestTuneSE);
            double[][] trainIn = makeInput(trainSet);
            double[][] tuneIn = makeInput(tuneSet);
            double[][] trainOut = makeOutput(trainSet);
            double[][] tuneOut = makeOutput(tuneSet);
            for (int epoch = 1; epoch <= this.maxEpoch; epoch++) {
                runEpoch(trainIn, trainOut);
                double tuneSE = getSE(tuneIn, tuneOut);
                log.println("epoch " + epoch + " : Tune SE: " + tuneSE + " Train SE: " + getSE(trainIn, trainOut));
                if (tuneSE < bestTuneSE) {
                    bestTuneSE = tuneSE;
                    bestEpoch = epoch;
                    bestHidden = this.whidden.copyMatrix();
                    bestOutput = this.woutput.copyMatrix();
                }
            }
            // Restore the best weights before reporting so that the train SE in
            // the summary line belongs to the same epoch as the tune SE.
            this.whidden = bestHidden;
            this.woutput = bestOutput;
            System.out.println("best epoch:" + bestEpoch + "Tune SE:" + bestTuneSE + " Train SE:" + getSE(trainIn, trainOut));
        } finally {
            // Close the log even when an epoch fails.
            log.close();
        }
    }

    /**
     * Dimensions the network from the data set (all features but one are inputs,
     * a single output, hidden layer sized by the configured percentage),
     * allocates the weight matrices, splits the data via {@code splitRandom(10.0)}
     * and trains on the resulting parts.
     *
     * @throws Exception on data access or file errors
     */
    public void train(FeatureDataSet featureDataSet) throws Exception {
        this.ninputs = featureDataSet.numFeatures() - 1;
        this.noutputs = 1;
        int scaled = (int) ((this.ninputs + this.noutputs) * (this.hidden_percentage / 100.0d));
        // Truncation can yield 0 for small networks/percentages; keep at least one hidden unit.
        this.nhidden = Math.max(1, scaled);
        this.whidden = new DoubleMatrix(this.ninputs + 1, this.nhidden);
        this.woutput = new DoubleMatrix(this.nhidden + 1, this.noutputs);
        DataSet[] parts = featureDataSet.splitRandom(10.0d);
        train((FeatureDataSet) parts[1], (FeatureDataSet) parts[0]);
    }

    /**
     * Prints the squared error on both sets and the current learning rate, and
     * appends a tab-separated line to the accuracy log if one has been opened.
     */
    public void reportAccuracies(int epoch, FeatureDataSet featureDataSet, FeatureDataSet featureDataSet2) throws InvalidFeature {
        double trainSE = getSE(featureDataSet);
        double tuneSE = getSE(featureDataSet2);
        System.out.println("=============================");
        System.out.println("EPOCH #" + epoch);
        System.out.println("=============================");
        System.out.println("Train Accuracy:" + trainSE);
        System.out.println("Tune Accuracy:" + tuneSE);
        System.out.println("Learning Rate:" + this.learn_rate);
        if (this.accfile != null) {
            this.accfile.println(epoch + "\t" + trainSE + "\t" + tuneSE);
            this.accfile.flush();
        }
    }

    /**
     * Presents every example once, in the order given by
     * {@code Util.randomIntList(0, n - 1)} (presumably a random permutation).
     */
    public void runEpoch(double[][] inputs, double[][] targets) throws Exception {
        int[] order = Util.randomIntList(0, inputs.length - 1);
        for (int k = 0; k < inputs.length; k++) {
            trainExample(inputs[order[k]], targets[order[k]]);
        }
    }

    /**
     * Performs one back-propagation step on a single example and then adapts the
     * learning rate: it is multiplied by 0.9 if the example's error did not
     * decrease, otherwise by 1.1 (capped at 0.5).
     *
     * @param inputs  input vector of length {@code ninputs}
     * @param targets target vector of length {@code noutputs}
     */
    public void trainExample(double[] inputs, double[] targets) {
        double[] hidden = computeHiddenActivations(inputs);
        double[] outputs = computeOutputActivations(hidden);

        // Output deltas: sigmoid derivative times the error.
        double[] outputDelta = new double[this.noutputs];
        for (int o = 0; o < this.noutputs; o++) {
            outputDelta[o] = outputs[o] * (1.0d - outputs[o]) * (targets[o] - outputs[o]);
        }

        double errorBefore = getTotalError(inputs, targets);

        // Hidden deltas, indexed by hidden unit. Row h + 1 of woutput holds the
        // outgoing weights of hidden unit h (row 0 is the bias, which has no delta).
        double[] hiddenDelta = new double[this.nhidden];
        for (int h = 0; h < this.nhidden; h++) {
            double backProp = 0.0d;
            for (int o = 0; o < this.noutputs; o++) {
                backProp += this.woutput.get(h + 1, o) * outputDelta[o];
            }
            hiddenDelta[h] = backProp * hidden[h] * (1.0d - hidden[h]);
        }

        // zeroMatrix() is expected to return a fresh all-zero matrix of the same shape.
        DoubleMatrix outputStep = this.woutput.zeroMatrix();
        DoubleMatrix hiddenStep = this.whidden.zeroMatrix();
        for (int o = 0; o < this.noutputs; o++) {
            outputStep.set(0, o, this.learn_rate * outputDelta[o]);
            for (int h = 1; h <= this.nhidden; h++) {
                outputStep.set(h, o, this.learn_rate * outputDelta[o] * hidden[h - 1]);
            }
        }
        for (int h = 0; h < this.nhidden; h++) {
            hiddenStep.set(0, h, this.learn_rate * hiddenDelta[h]);
            for (int in = 1; in <= this.ninputs; in++) {
                hiddenStep.set(in, h, this.learn_rate * hiddenDelta[h] * inputs[in - 1]);
            }
        }
        this.whidden.add(hiddenStep);
        this.woutput.add(outputStep);

        if (getTotalError(inputs, targets) >= errorBefore) {
            this.learn_rate *= 0.9d;
        } else {
            this.learn_rate = Util.min(0.5d, this.learn_rate * 1.1d);
        }
    }

    /**
     * Per-output error {@code target - output} for one example.
     *
     * @return array of length {@code noutputs}
     */
    public double[] getError(double[] inputs, double[] targets) {
        double[] error = getOutput(inputs);
        for (int o = 0; o < this.noutputs; o++) {
            error[o] = targets[o] - error[o];
        }
        return error;
    }

    /** Euclidean norm of the error vector for one example. */
    public double getTotalError(double[] inputs, double[] targets) {
        return getTotalError(getError(inputs, targets));
    }

    /** Euclidean norm of the first {@code noutputs} components of an error vector. */
    public double getTotalError(double[] error) {
        double sumSquares = 0.0d;
        for (int o = 0; o < this.noutputs; o++) {
            sumSquares += error[o] * error[o];
        }
        return Math.sqrt(sumSquares);
    }

    /**
     * Returns a clone of the example's output feature whose value is set to
     * {@code Util.argmax} of the network outputs. An {@link InvalidFeature}
     * raised while setting the value is passed to {@code internalError}.
     */
    public Feature classify(Example example) {
        double[] output = getOutput(example);
        Feature prediction = (Feature) example.getOutputFeature().clone();
        try {
            prediction.setValue(Util.argmax(output));
        } catch (InvalidFeature e) {
            internalError(e);
        }
        return prediction;
    }

    /** No-op: this learner has no tuning procedure. */
    public void tune(FeatureDataSet featureDataSet, Object[] objArr) {
    }

    public String printClassifier() {
        return "BPNeuralNet\n";
    }

    /**
     * Sets the hidden-layer percentage; values below 50 are raised to 50.
     * The parameter index is ignored because this is the only parameter.
     *
     * @param index ignored
     * @param value a {@link Number} (callers have historically passed {@link Integer})
     */
    public void setParameter(int index, Object value) {
        double requested = ((Number) value).doubleValue();
        this.hidden_percentage = Util.max(50.0d, requested);
    }
}
