package edu.wisc.sjm.jutil.math;

import java.util.Arrays;

/* loaded from: input_file:builds/machlearn_install.jar:builds/auc_install.jar:builds/jutil_install.jar:jutil.jar:edu/wisc/sjm/jutil/math/SimplexSolver.class */
/**
 * Nelder-Mead downhill-simplex minimizer.
 *
 * <p>Everything problem-specific is read from the {@link SimplexFunction} given to
 * {@link #setFunction}: the dimension, the first guess, the initial simplex edge length,
 * the convergence threshold and the Nelder-Mead coefficients rho (reflection),
 * chi (expansion), gamma (contraction) and sigma (shrink).
 *
 * <p>{@link #fit()} builds an axis-aligned starting simplex around the current guess.
 * It iterates until the best and worst vertices differ by at most the convergence
 * threshold, or until {@code max_iterations} iterations have run. It then restarts from
 * the new best point, up to 5 passes in total. It stops early once a pass no longer moves
 * the best point by more than the threshold.
 *
 * <p>Instances hold unsynchronized mutable state and are not thread-safe; use one solver
 * per thread.
 */
public class SimplexSolver implements SolverInterface {
    /** Iteration cap used by the constructors that do not take one explicitly. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;

    /** Maximum number of simplex (re)starts performed by {@link #fit()}. */
    private static final int MAX_RESTARTS = 5;

    /** Current simplex: dimension + 1 vertices, kept sorted best-first by fit(). */
    protected SimplexPoint[] simplex;
    /** Best point found so far; the result reported by getSolution()/getSolutionY(). */
    protected SimplexPoint guessPoint;
    /** Objective and configuration source. */
    protected SimplexFunction function;
    /** Maximum number of iterate() calls per simplex (re)start. */
    protected int max_iterations;
    /**
     * Weights used to form {@link #average}: 1/n for each of the n best vertices and 0 for
     * the worst, i.e. the centroid of the face opposite the worst vertex (assuming
     * {@code redefine(points, weights)} forms the weighted sum).
     */
    double[] averageWeight;

    // Scratch points reused by iterate(). These were previously static, which let every
    // SimplexSolver instance share (and overwrite) the same intermediate points whenever two
    // solvers were active at once. They are now owned by each instance.
    SimplexPoint average;
    SimplexPoint reflected;
    SimplexPoint expanded;
    SimplexPoint outside;
    SimplexPoint inside;

    /** Creates a solver with no objective and the default iteration cap (1000). */
    public SimplexSolver() {
        this.averageWeight = null;
        setMaxIterations(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a solver for the given objective with the default iteration cap (1000).
     *
     * @param simplexFunction objective and configuration source
     */
    public SimplexSolver(SimplexFunction simplexFunction) {
        this(simplexFunction, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a solver for the given objective.
     *
     * @param simplexFunction objective and configuration source
     * @param i maximum number of iterations per simplex (re)start
     */
    public SimplexSolver(SimplexFunction simplexFunction, int i) {
        this.averageWeight = null;
        setMaxIterations(i);
        setFunction(simplexFunction);
    }

    /**
     * Sets the maximum number of Nelder-Mead iterations per simplex (re)start.
     *
     * @param i iteration cap; values {@code <= 0} make each pass skip iteration entirely
     */
    public void setMaxIterations(int i) {
        this.max_iterations = i;
    }

    /**
     * Sets the objective and configuration source used by {@link #fit()}.
     *
     * @param simplexFunction the function to minimize
     */
    @Override // edu.wisc.sjm.jutil.math.SolverInterface
    public void setFunction(SimplexFunction simplexFunction) {
        this.function = simplexFunction;
    }

    /**
     * Minimizes the configured function, leaving the best point in {@link #guessPoint}.
     *
     * @throws Exception propagated from the function or from {@link #iterate()}
     */
    @Override // edu.wisc.sjm.jutil.math.SolverInterface
    public void fit() throws Exception {
        double initialSimplexSize = this.function.getInitialSimplexSize();
        double convergenceThreshold = this.function.getConvergenceThreshold();
        SimplexPoint previousGuess = new SimplexPoint();

        this.simplex = new SimplexPoint[this.function.getDimension() + 1];
        for (int v = 0; v < this.simplex.length; v++) {
            this.simplex[v] = new SimplexPoint();
        }

        double[] x = this.function.getFirstGuess();
        this.guessPoint = new SimplexPoint(x, this.function);

        for (int restart = 0; restart < MAX_RESTARTS; restart++) {
            // Vertex 0 is the current guess; vertex k (k >= 1) is the guess displaced by
            // initialSimplexSize along coordinate k-1. The coordinate is restored afterwards
            // so the next vertex starts from the unshifted guess again.
            this.simplex[0].copyFrom(this.guessPoint);
            x = this.simplex[0].getX(x);
            for (int k = 1; k < this.simplex.length; k++) {
                double saved = x[k - 1];
                x[k - 1] = saved + initialSimplexSize;
                this.simplex[k].redefine(x, this.function);
                x[k - 1] = saved;
            }
            Arrays.sort(this.simplex);

            double spread = Double.POSITIVE_INFINITY;
            for (int iter = 0; spread > convergenceThreshold && iter < this.max_iterations; iter++) {
                iterate();
                Arrays.sort(this.simplex);
                spread = this.simplex[0].getDifference(this.simplex[this.simplex.length - 1]);
            }

            previousGuess.copyFrom(this.guessPoint);
            this.guessPoint.copyFrom(this.simplex[0]);
            // A restart that barely moves the best point means further restarts won't help.
            if (this.guessPoint.getDifference(previousGuess) <= convergenceThreshold) {
                return;
            }
        }
    }

    /**
     * Returns the best point found, post-processed by the function's
     * {@code makeFinalAnswerNice}.
     *
     * @return the solution coordinates
     */
    @Override // edu.wisc.sjm.jutil.math.SolverInterface
    public double[] getSolution() {
        return this.function.makeFinalAnswerNice(this.guessPoint.getX());
    }

    /**
     * Returns the objective value stored on the best point found.
     *
     * @return the y value of {@link #guessPoint}
     */
    public double getSolutionY() {
        return this.guessPoint.getY();
    }

    /** Throws {@link IllegalStateException} (with {@code str} if non-null) when {@code z} is false. */
    private void myAssert(boolean z, String str) {
        if (z) {
            return;
        }
        if (str != null) {
            throw new IllegalStateException("ASSERTION FAILED: " + str);
        }
        throw new IllegalStateException("ASSERTION FAILED");
    }

    /**
     * Performs one Nelder-Mead step on {@link #simplex}, replacing its worst vertex or
     * shrinking the whole simplex toward the best vertex.
     *
     * <p>Expects {@link #simplex} sorted so that {@code simplex[0]} is the best vertex and
     * {@code simplex[n]} the worst; the simplex is left unsorted afterwards.
     *
     * @throws Exception propagated from point evaluation
     */
    protected void iterate() throws Exception {
        int n = this.simplex.length - 1; // index of the worst vertex
        double rho = this.function.getRho();
        double chi = this.function.getChi();
        double gamma = this.function.getGamma();
        double sigma = this.function.getSigma();
        if (n == 0) {
            return; // a single vertex cannot move
        }
        if (this.averageWeight == null || this.averageWeight.length != this.simplex.length) {
            this.averageWeight = new double[this.simplex.length];
            for (int v = 0; v < n; v++) {
                this.averageWeight[v] = 1.0d / n;
            }
            this.averageWeight[n] = 0.0d;
        }
        if (this.average == null) {
            this.average = new SimplexPoint();
            this.reflected = new SimplexPoint();
            this.expanded = new SimplexPoint();
            this.outside = new SimplexPoint();
            this.inside = new SimplexPoint();
        }

        SimplexPoint best = this.simplex[0];
        SimplexPoint secondWorst = this.simplex[n - 1];
        SimplexPoint worst = this.simplex[n];

        // Reflect the worst vertex through the centroid of the others.
        this.average.redefine(this.simplex, this.averageWeight);
        this.reflected.redefine(this.average, 1.0d + rho, worst, -rho);
        if (best.lte(this.reflected) && this.reflected.lt(secondWorst)) {
            worst.copyFrom(this.reflected);
            return;
        }

        // Reflection beat the best vertex: try going further in the same direction.
        if (this.reflected.lt(best)) {
            this.expanded.redefine(this.average, 1.0d + (rho * chi), worst, (-rho) * chi);
            worst.copyFrom(this.expanded.lt(this.reflected) ? this.expanded : this.reflected);
            return;
        }

        myAssert(this.reflected.gte(secondWorst), "f_r >= f_n");
        if (this.reflected.lt(worst)) {
            // Outside contraction; per the reference algorithm it is accepted on ties too
            // (f_c <= f_r), which avoids an unnecessary shrink.
            this.outside.redefine(this.average, 1.0d + (rho * gamma), worst, (-rho) * gamma);
            if (this.outside.lte(this.reflected)) {
                worst.copyFrom(this.outside);
                return;
            }
        } else {
            // Inside contraction toward the worst vertex.
            this.inside.redefine(this.average, 1.0d - gamma, worst, gamma);
            if (this.inside.lt(worst)) {
                worst.copyFrom(this.inside);
                return;
            }
        }

        // Shrink every vertex toward the best one: v_i = (1 - sigma) * best + sigma * v_i.
        for (int v = 1; v < this.simplex.length; v++) {
            this.simplex[v].redefine(best, 1.0d - sigma, this.simplex[v], sigma);
        }
    }

    /** Debug helper: prints each vertex with its index to standard output. */
    private void printSimplex() {
        for (int v = 0; v < this.simplex.length; v++) {
            System.out.println(String.valueOf(v) + "\t" + this.simplex[v]);
        }
    }
}
