package edu.wisc.sjm.machlearn.regressors;

import edu.wisc.sjm.jutil.vars.DoubleVar;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;
import java.util.Iterator;
import java.util.Vector;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/regressors/Regressor.class */
/**
 * Base class for learners that predict a real-valued output for an {@link Example}.
 *
 * <p>Subclasses implement {@link #train(FeatureDataSet)} and {@link #regress(Example)}. On top of
 * those, this class provides:
 *
 * <ul>
 *   <li>batch prediction;
 *   <li>per-example absolute and squared errors;
 *   <li>sum-of-squared-errors and mean-squared-error;
 *   <li>an R²-style score ({@link #getScore(DataSet)});
 *   <li>a cross-validation helper for picking the lowest-MSE regressor from a set of candidates.
 * </ul>
 */
public abstract class Regressor extends Scorer {

    /** Fold count used by {@link #getBestRegressor(DataSet, Vector, DoubleVar)}. */
    private static final int DEFAULT_NUM_FOLDS = 10;

    /**
     * Fits this regressor to the given training data.
     *
     * @param featureDataSet training examples
     * @throws Exception if training fails
     */
    public abstract void train(FeatureDataSet featureDataSet) throws Exception;

    /**
     * Predicts the real-valued output for a single example.
     *
     * @param example the example to predict
     * @return the predicted output value
     * @throws Exception if prediction fails
     */
    public abstract double regress(Example example) throws Exception;

    /**
     * Scores a single example; for a regressor the score is simply the prediction.
     *
     * @param example the example to score
     * @return {@code regress(example)}
     * @throws Exception if prediction fails
     */
    @Override // edu.wisc.sjm.machlearn.Scorer
    public double getScore(Example example) throws Exception {
        return regress(example);
    }

    /**
     * Predicts every example of {@code featureDataSet}, writing prediction {@code i} into
     * {@code dArr[i]}. Entries of {@code dArr} beyond the data set size are left untouched.
     *
     * @param featureDataSet examples to predict
     * @param dArr destination array; must be at least {@code featureDataSet.size()} long
     * @throws IllegalArgumentException if {@code dArr} is shorter than the data set
     * @throws Exception if any prediction fails
     */
    public void regress(FeatureDataSet featureDataSet, double[] dArr) throws Exception {
        int size = featureDataSet.size();
        // Fail before running the model, instead of part-way through with an index exception.
        if (dArr.length < size) {
            throw new IllegalArgumentException(
                    "output array length " + dArr.length + " is smaller than data set size " + size);
        }
        for (int i = 0; i < size; i++) {
            dArr[i] = regress(featureDataSet.getExample(i));
        }
    }

    /**
     * Trains on a generic data set by delegating to {@link #train(FeatureDataSet)}.
     *
     * @param dataSet training data; must be a {@link FeatureDataSet}
     * @throws ClassCastException if {@code dataSet} is not a {@link FeatureDataSet}
     * @throws Exception if training fails
     */
    @Override // edu.wisc.sjm.machlearn.Scorer
    public void doTrain(DataSet dataSet) throws Exception {
        train((FeatureDataSet) dataSet);
    }

    /**
     * Computes the coefficient of determination, {@code 1 - SSE / SS_total}, on a data set.
     *
     * <p>When every output in the set is identical (or the set is empty), SS_total is 0. The
     * result is then {@code -Infinity} if SSE > 0, or {@code NaN} if SSE is also 0.
     *
     * @param dataSet evaluation data; must be a {@link FeatureDataSet}
     * @return the R² value
     * @throws ClassCastException if {@code dataSet} is not a {@link FeatureDataSet}
     * @throws Exception if prediction or feature access fails
     */
    public double getScore(DataSet dataSet) throws Exception {
        FeatureDataSet featureDataSet = (FeatureDataSet) dataSet;
        return 1.0d - (getSE(featureDataSet) / getSSTotal(featureDataSet));
    }

    /**
     * Computes the total sum of squares of the outputs: the sum over examples of
     * {@code (y_i - mean(y))^2}. Uses a two-pass algorithm: first the mean, then the squared
     * deviations.
     *
     * @param featureDataSet data whose output feature values are used
     * @return the total sum of squares; 0 for an empty data set
     * @throws InvalidFeature if an output feature value cannot be read
     */
    public static double getSSTotal(FeatureDataSet featureDataSet) throws InvalidFeature {
        int size = featureDataSet.size();
        double sum = 0.0d;
        for (int i = 0; i < size; i++) {
            sum += featureDataSet.getOutputFeature(i).getDValue();
        }
        double mean = sum / size;
        double sumSquares = 0.0d;
        for (int i = 0; i < size; i++) {
            double deviation = featureDataSet.getOutputFeature(i).getDValue() - mean;
            sumSquares += deviation * deviation;
        }
        return sumSquares;
    }

    /**
     * Predicts every example of a data set into a freshly allocated array.
     *
     * @param featureDataSet examples to predict
     * @return predictions, one per example, in data-set order
     * @throws Exception if any prediction fails
     */
    public double[] regress(FeatureDataSet featureDataSet) throws Exception {
        double[] predictions = new double[featureDataSet.size()];
        regress(featureDataSet, predictions);
        return predictions;
    }

    /**
     * Computes the absolute error of every example in a data set.
     *
     * @param featureDataSet examples to evaluate
     * @param doubleVector vector to reuse (it is emptied first); if {@code null}, a new one is
     *     allocated
     * @return the vector holding {@code |y_i - prediction_i|} for each example, in order
     * @throws Exception if prediction or feature access fails
     */
    public DoubleVector getAbsError(FeatureDataSet featureDataSet, DoubleVector doubleVector) throws Exception {
        DoubleVector errors = emptiedOrNew(doubleVector);
        for (int i = 0; i < featureDataSet.size(); i++) {
            errors.add(getAbsError(featureDataSet.getExample(i)));
        }
        return errors;
    }

    /**
     * Computes the absolute error of a single example.
     *
     * @param example the example; its output feature supplies the true value
     * @return {@code |y - prediction|}
     * @throws Exception if prediction or feature access fails
     */
    public double getAbsError(Example example) throws Exception {
        return Math.abs(example.getOutputFeature().getDValue() - regress(example));
    }

    /**
     * Computes the squared error of every example in a data set.
     *
     * @param featureDataSet examples to evaluate
     * @param doubleVector vector to reuse (it is emptied first); if {@code null}, a new one is
     *     allocated
     * @return the vector holding {@code (y_i - prediction_i)^2} for each example, in order
     * @throws Exception if prediction or feature access fails
     */
    public DoubleVector getSQError(FeatureDataSet featureDataSet, DoubleVector doubleVector) throws Exception {
        DoubleVector errors = emptiedOrNew(doubleVector);
        for (int i = 0; i < featureDataSet.size(); i++) {
            errors.add(getSQError(featureDataSet.getExample(i)));
        }
        return errors;
    }

    /**
     * Computes the squared error of a single example.
     *
     * @param example the example; its output feature supplies the true value
     * @return {@code (y - prediction)^2}
     * @throws Exception if prediction or feature access fails
     */
    public double getSQError(Example example) throws Exception {
        double residual = example.getOutputFeature().getDValue() - regress(example);
        return residual * residual;
    }

    /**
     * Computes the sum of squared errors over a data set.
     *
     * <p>Predictions are obtained through the batch {@link #regress(FeatureDataSet)}, so a subclass
     * that overrides batch prediction is used here.
     *
     * @param featureDataSet examples to evaluate
     * @return the sum over examples of {@code (y_i - prediction_i)^2}; 0 for an empty data set
     * @throws Exception if prediction or feature access fails
     */
    public double getSE(FeatureDataSet featureDataSet) throws Exception {
        double[] predictions = regress(featureDataSet);
        double sse = 0.0d;
        for (int i = 0; i < featureDataSet.size(); i++) {
            double residual = featureDataSet.getExample(i).getOutputFeature().getDValue() - predictions[i];
            sse += residual * residual;
        }
        return sse;
    }

    /**
     * Computes the mean squared error over a data set.
     *
     * @param featureDataSet examples to evaluate
     * @return {@link #getSE(FeatureDataSet)} divided by the data set size; {@code NaN} for an
     *     empty data set
     * @throws Exception if prediction or feature access fails
     */
    public double getMSE(FeatureDataSet featureDataSet) throws Exception {
        return getSE(featureDataSet) / featureDataSet.size();
    }

    /**
     * Picks the candidate with the lowest mean cross-validated MSE, using
     * {@value #DEFAULT_NUM_FOLDS} folds. The winning MSE is discarded.
     *
     * @param dataSet data to split into folds
     * @param vector candidate regressors
     * @return the best candidate, or {@code null} if there is none
     * @throws Exception if splitting, training, or evaluation fails
     * @see #getBestRegressor(DataSet[][], Vector, DoubleVar)
     */
    public static Regressor getBestRegressor(DataSet dataSet, Vector<Regressor> vector) throws Exception {
        return getBestRegressor(dataSet, vector, new DoubleVar());
    }

    /**
     * Picks the candidate with the lowest mean cross-validated MSE, using
     * {@value #DEFAULT_NUM_FOLDS} folds.
     *
     * @param dataSet data to split into folds
     * @param vector candidate regressors
     * @param doubleVar receives the winning mean MSE
     * @return the best candidate, or {@code null} if there is none
     * @throws Exception if splitting, training, or evaluation fails
     * @see #getBestRegressor(DataSet[][], Vector, DoubleVar)
     */
    public static Regressor getBestRegressor(DataSet dataSet, Vector<Regressor> vector, DoubleVar doubleVar) throws Exception {
        return getBestRegressor(dataSet, vector, doubleVar, DEFAULT_NUM_FOLDS);
    }

    /**
     * Picks the candidate with the lowest mean cross-validated MSE, using a caller-chosen number
     * of folds.
     *
     * @param dataSet data to split into folds; split via
     *     {@code dataSet.splitDataSetFolds(numFolds, true, false)}
     * @param vector candidate regressors
     * @param doubleVar receives the winning mean MSE
     * @param numFolds number of folds to split {@code dataSet} into
     * @return the best candidate, or {@code null} if there is none
     * @throws Exception if splitting, training, or evaluation fails
     * @see #getBestRegressor(DataSet[][], Vector, DoubleVar)
     */
    public static Regressor getBestRegressor(DataSet dataSet, Vector<Regressor> vector, DoubleVar doubleVar, int numFolds) throws Exception {
        return getBestRegressor(dataSet.splitDataSetFolds(numFolds, true, false), vector, doubleVar);
    }

    /**
     * Picks the candidate with the lowest mean MSE over pre-split folds.
     *
     * <p>For every fold {@code f}, each candidate is trained on {@code dataSetArr[f][0]} and its MSE
     * is measured on {@code dataSetArr[f][1]}.
     *
     * <p>Selection rules:
     *
     * <ul>
     *   <li>Strict {@code <} is used, so the earliest candidate wins ties.
     *   <li>A candidate whose mean MSE is {@code NaN} is never selected.
     *   <li>If there are no candidates or no folds, {@code null} is returned and
     *       {@code doubleVar.value} stays {@code +Infinity}.
     * </ul>
     *
     * <p>Note: the returned regressor is NOT retrained on all data. It is left in whatever state
     * training on the last fold's training split produced. Callers wanting a model fit on the full
     * data should retrain it.
     *
     * @param dataSetArr folds; each row must hold {@code {training, test}} {@link FeatureDataSet}s
     * @param vector candidate regressors
     * @param doubleVar receives the winning mean MSE
     * @return the best candidate, or {@code null} as described above
     * @throws Exception if training or evaluation fails
     */
    public static Regressor getBestRegressor(DataSet[][] dataSetArr, Vector<Regressor> vector, DoubleVar doubleVar) throws Exception {
        doubleVar.value = Double.POSITIVE_INFINITY;
        Regressor best = null;
        for (Regressor candidate : vector) {
            double mseSum = 0.0d;
            for (DataSet[] fold : dataSetArr) {
                candidate.train((FeatureDataSet) fold[0]);
                mseSum += candidate.getMSE((FeatureDataSet) fold[1]);
            }
            double meanMse = mseSum / dataSetArr.length;
            if (meanMse < doubleVar.value) {
                best = candidate;
                doubleVar.value = meanMse;
            }
        }
        return best;
    }

    /** Returns {@code reuse} after emptying it, or a new vector if {@code reuse} is null. */
    private static DoubleVector emptiedOrNew(DoubleVector reuse) {
        if (reuse == null) {
            return new DoubleVector();
        }
        reuse.empty();
        return reuse;
    }
}
