package edu.wisc.sjm.machlearn.featureselection;

import edu.wisc.sjm.jutil.misc.BooleanArray;
import edu.wisc.sjm.jutil.misc.IntegerKey;
import edu.wisc.sjm.machlearn.MachLearnConstants;
import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.FeatureIdList;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.util.APRUtil;
import java.util.Hashtable;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/featureselection/RandomClimbSelect.class */
/**
 * Feature selection by random-restart hill climbing over feature-inclusion masks.
 *
 * <p>Each restart randomizes a {@link BooleanArray} mask. It then repeatedly applies the single
 * feature flip that gives the highest score from {@link APRUtil#getAPR}, never flipping the data
 * set's output column. A restart stops when no flip yields a strictly higher score, or after
 * {@code max_iterations} flips. The best-scoring mask over all restarts is returned.
 *
 * <p>Scores are memoized in {@code score_cache}, keyed only by {@link BooleanArray#getKey()}. The
 * cache is cleared at the start of every {@link #doFeatureSelection} call, so cached values are
 * only meaningful for the classifier and folds of the current call.
 */
public class RandomClimbSelect extends FeatureSelect implements MachLearnConstants {
    /** Set by {@link #setMaxFeatures}; not read by this class (presumably used by subclasses — TODO confirm). */
    protected int max_features = 1;
    /** Upper bound on the number of improving flips per restart (100 after construction). */
    protected int max_iterations = 1;
    /** Number of random restarts performed by {@link #doFeatureSelection} (10 after construction). */
    protected int max_restarts = 1;
    /** Set by {@link #setAlpha}; not read by this class. */
    protected double alpha = 0.5d;
    /**
     * Set by {@link #setValidationType} (default 10).
     * NOTE(review): not read by this class — doFeatureSelection always splits into 10 folds.
     */
    protected int validation;
    /** Not read or written by this class. */
    protected FeatureDataSet[][] foldsets;
    /** Not read or written by this class. */
    protected FeatureIdList fid;
    /** Not read or written by this class. */
    protected Classifier score_class;
    /** Set to true by {@link #setTuneParameters}; not read by this class. */
    protected boolean dotune;
    /** Recorded by {@link #setTuneParameters}; not read by this class. */
    protected Object[] tune_parameters;
    /** Recorded by {@link #setTuneParameters}; not read by this class. */
    protected int tune_validation;
    /** Memoized scores keyed by {@link BooleanArray#getKey()}; cleared per doFeatureSelection call. */
    private final Hashtable<IntegerKey, Double> score_cache = new Hashtable<>();

    /** Creates a selector with default settings; equivalent to {@code RandomClimbSelect(false)}. */
    public RandomClimbSelect() {
        this(false);
    }

    /**
     * Creates a selector that performs 10 restarts of at most 100 iterations each.
     *
     * @param z flag forwarded unchanged to the {@link FeatureSelect} constructor
     */
    public RandomClimbSelect(boolean z) {
        super(z);
        this.validation = 10;
        // Kept as setter calls (not direct assignment) so subclass overrides still observe them.
        setMaxIterations(100);
        setMaxRestarts(10);
    }

    /** @param i value stored in {@code max_features} */
    public void setMaxFeatures(int i) {
        this.max_features = i;
    }

    /** @param i maximum number of improving flips per restart */
    public void setMaxIterations(int i) {
        this.max_iterations = i;
    }

    /** @param i number of random restarts */
    public void setMaxRestarts(int i) {
        this.max_restarts = i;
    }

    /** @param d value stored in {@code alpha} */
    public void setAlpha(double d) {
        this.alpha = d;
    }

    /** @param i value stored in {@code validation} (currently not used by the search) */
    public void setValidationType(int i) {
        this.validation = i;
    }

    /**
     * Records tuning parameters with a tune-validation value of -1.
     *
     * @param objArr tuning parameters to record (stored by reference)
     */
    public void setTuneParameters(Object[] objArr) {
        setTuneParameters(objArr, -1);
    }

    /**
     * Records tuning parameters and enables the {@code dotune} flag.
     *
     * @param objArr tuning parameters to record (stored by reference)
     * @param i value stored in {@code tune_validation}
     */
    public void setTuneParameters(Object[] objArr, int i) {
        this.tune_parameters = objArr;
        this.dotune = true;
        this.tune_validation = i;
    }

    /** Prints the name of this selection method to standard output. */
    public void printParameters() {
        System.out.println("Random Hill-Climbing");
    }

    /**
     * Runs {@code max_restarts} random-restart hill climbs and returns the best mask found.
     *
     * <p>If {@code max_restarts <= 0}, the result is the mask of a freshly constructed,
     * never-climbed {@link BooleanArray}.
     *
     * @param featureDataSet data set to select features from; it is split into folds with
     *     {@code splitDataSetFolds(10, true, true)}
     * @param classifier classifier used to score each candidate mask
     * @return the best mask as returned by {@link BooleanArray#getArray()}; it is also printed
     *     via {@code printFeatures}
     * @throws Exception if splitting or scoring throws
     */
    @Override // edu.wisc.sjm.machlearn.featureselection.FeatureSelect
    public boolean[] doFeatureSelection(FeatureDataSet featureDataSet, Classifier classifier) throws Exception {
        // NOTE(review): fold count is hard-coded to 10 and ignores this.validation — confirm intent.
        DataSet[][] folds = featureDataSet.splitDataSetFolds(10, true, true);
        this.score_cache.clear();
        int outputIndex = featureDataSet.getOutputIndex();
        BooleanArray scratch = new BooleanArray(featureDataSet.numFeatures());
        BooleanArray best = new BooleanArray(featureDataSet.numFeatures());
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int restart = 0; restart < this.max_restarts; restart++) {
            double score = doFeatureSelectionIteration(folds, outputIndex, classifier, scratch);
            if (score > bestScore) {
                bestScore = score;
                // Swap buffers: the climbed mask becomes the incumbent, and the old incumbent is
                // reused (re-randomized) by the next restart.
                BooleanArray previousBest = best;
                best = scratch;
                scratch = previousBest;
            }
        }
        printFeatures(featureDataSet, best.getArray());
        return best.getArray();
    }

    /**
     * Returns the score of a feature mask. On a cache miss the score is computed with
     * {@link APRUtil#getAPR} and memoized under {@code booleanArray.getKey()}.
     *
     * <p>NOTE(review): the cache is only correct if {@code getKey()} returns a value snapshot of
     * the mask, because callers flip bits after scoring — TODO confirm in BooleanArray.
     *
     * @param classifier classifier passed through to {@code APRUtil.getAPR}
     * @param dataSetArr folds passed through to {@code APRUtil.getAPR}
     * @param booleanArray feature-inclusion mask to score
     * @return the (possibly cached) score; higher is treated as better by the search
     * @throws Exception if {@code APRUtil.getAPR} throws
     */
    public double getAPR(Classifier classifier, DataSet[][] dataSetArr, BooleanArray booleanArray) throws Exception {
        IntegerKey key = booleanArray.getKey();
        Double cached = this.score_cache.get(key);
        if (cached != null) {
            return cached;
        }
        double score = APRUtil.getAPR(classifier, dataSetArr, booleanArray.getArray());
        this.score_cache.put(key, score);
        return score;
    }

    /**
     * Performs one hill climb, mutating {@code booleanArray} in place.
     *
     * <p>The mask is first randomized. Each iteration then scores every single-bit flip except
     * index {@code i} and applies the best one if it strictly beats the current score. The climb
     * stops, printing a message, when no flip improves; it also stops silently after
     * {@code max_iterations} applied flips.
     *
     * @param dataSetArr folds used for scoring
     * @param i index of the output column, which is never flipped by the climb
     * @param classifier classifier used for scoring
     * @param booleanArray mask to randomize and climb; holds the final mask on return
     * @return the score of the final mask
     * @throws Exception if scoring throws
     */
    public double doFeatureSelectionIteration(DataSet[][] dataSetArr, int i, Classifier classifier, BooleanArray booleanArray) throws Exception {
        booleanArray.randomize();
        // Score of the current mask; only ever replaced by a strictly higher neighbour score.
        double bestScore = getAPR(classifier, dataSetArr, booleanArray);
        for (int iteration = 0; iteration < this.max_iterations; iteration++) {
            int bestFlip = -1;
            for (int feature = 0; feature < booleanArray.size(); feature++) {
                if (feature == i) {
                    continue; // never toggle the output column
                }
                // Probe the neighbour by flipping, scoring, and flipping back.
                booleanArray.flip(feature);
                double neighbourScore = getAPR(classifier, dataSetArr, booleanArray);
                booleanArray.flip(feature);
                if (neighbourScore > bestScore) {
                    bestFlip = feature;
                    bestScore = neighbourScore;
                }
            }
            if (bestFlip == -1) {
                System.out.println("Maximized at " + iteration + " iterations:" + bestScore);
                break;
            }
            booleanArray.flip(bestFlip);
        }
        return bestScore;
    }
}
