package edu.wisc.sjm.machlearn.policy.fdspreprocessor.selection.wrapper;

import edu.wisc.sjm.jutil.misc.BooleanArray;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.dataset.DataSet;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.featureselection.FSDataSet;
import edu.wisc.sjm.machlearn.policy.FDSPreProcessor;
import edu.wisc.sjm.machlearn.util.Util;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/policy/fdspreprocessor/selection/wrapper/WalkSat.class */
/**
 * Wrapper feature selector that searches over boolean feature masks with a
 * WalkSAT/GSAT-style local search.
 *
 * <p>{@link #train(FeatureDataSet)} runs {@code max_climbs} restarts. Each restart draws a
 * random mask (re-drawn until it has at least {@code min_features} false entries) and then
 * makes up to {@code max_flips} moves. A move is either a random single-entry flip (when
 * {@code Util.randomBoolean(rand_prob)} is true) or a greedy "GSAT" move that applies the
 * single flip with the strictly highest {@link #calcScore} (the output column is never
 * considered for greedy flips). A restart stops early when no greedy flip improves the score.
 * The final mask of the best-scoring restart is stored and later applied by
 * {@link #process(FeatureDataSet)}.
 *
 * <p>NOTE(review): which mask value means "keep this feature" is decided by
 * {@code FSDataSet.applyMask}, which is not visible here; {@code min_features} bounds the
 * number of <em>false</em> entries — confirm that false means "kept".
 */
public class WalkSat extends FDSPreProcessor {
    /** Number of random restarts; values &lt;= 0 are replaced by 10 at training time. */
    protected int max_climbs;
    /** Moves per restart; values &lt;= 0 are replaced by 2 * numFeatures() at training time. */
    protected int max_flips;
    /** Argument to Util.randomBoolean that selects a random move; values below 1e-5 become 0.15. */
    protected double rand_prob;
    /** Minimum number of false mask entries every visited mask must keep (default 3). */
    protected int min_features;
    /** Scorer trained on each masked validation split. */
    protected Scorer scorer;
    /** Mask chosen by the last training run; null until trained. */
    protected BooleanArray ba;
    /** Argument passed to FeatureDataSet.splitDataSetValidation (default -1). */
    protected int validation;

    /** Creates a WalkSat selector with no scorer; one must be supplied via setScorer. */
    public WalkSat() {
        this(null);
    }

    /**
     * Creates a WalkSat selector.
     *
     * @param scorer scorer used to evaluate masks (may be null and set later)
     */
    public WalkSat(Scorer scorer) {
        this.scorer = scorer;
        this.ba = null;
        // Non-positive / tiny values act as "unset"; initializeConstants fills in defaults.
        this.max_climbs = -1;
        this.max_flips = -1;
        this.rand_prob = -1.0d;
        this.min_features = 3;
        this.validation = -1;
    }

    /**
     * Sets the minimum number of false mask entries a candidate mask must have.
     *
     * @param i new minimum
     */
    public void setMinFeatures(int i) {
        this.min_features = i;
    }

    /**
     * Parses and sets the validation argument.
     *
     * @param str decimal integer
     * @throws NumberFormatException if {@code str} is not an integer
     */
    public void setValidation(String str) {
        setValidation(Integer.parseInt(str));
    }

    /**
     * Sets the argument later passed to {@code FeatureDataSet.splitDataSetValidation}.
     *
     * @param i validation argument
     */
    public void setValidation(int i) {
        this.validation = i;
    }

    /**
     * Replaces unset search parameters with defaults and prints the configuration.
     *
     * @param featureDataSet training data; its feature count sizes the default max_flips
     */
    protected void initializeConstants(FeatureDataSet featureDataSet) {
        if (this.max_climbs <= 0) {
            System.out.println("Setting max_climbs to 10");
            this.max_climbs = 10;
        }
        if (this.max_flips <= 0) {
            System.out.println("Setting max_flips to 2 * numFeatures()");
            this.max_flips = 2 * featureDataSet.numFeatures();
        }
        if (this.rand_prob < 1.0E-5d) {
            System.out.println("Setting rand_prob to 0.15");
            this.rand_prob = 0.15d;
        }
        System.out.println("WALKSAT:");
        System.out.println("Size of training set:" + featureDataSet.size());
        System.out.println("max_climbs:" + this.max_climbs);
        System.out.println("max_flips:" + this.max_flips);
        System.out.println("rand_prob:" + this.rand_prob);
        System.out.println("min_features:" + this.min_features);
    }

    /**
     * Scores a feature mask by training the scorer on the masked first data set of each
     * validation pair.
     *
     * <p>NOTE(review): as written, no per-split score is ever accumulated. Every mask gets
     * {@code KStarConstants.FLOOR / dataSetArr.length} (presumably 0.0, or NaN with no splits;
     * this looks like a decompiler-inlined literal, so confirm the constant's value). The masked
     * second data set of each pair is computed and then discarded. If the score is constant,
     * the greedy moves in wstrain can never find an improving flip. Restoring a real score
     * (e.g. evaluating the scorer on the masked second split) needs the Scorer evaluation
     * API, which is not visible here.
     *
     * @param booleanArray candidate mask
     * @param dataSetArr validation splits; each row holds two data sets
     * @return the mask's score (see note above)
     * @throws Exception propagated from masking or training
     */
    protected double calcScore(BooleanArray booleanArray, DataSet[][] dataSetArr) throws Exception {
        for (int i = 0; i < dataSetArr.length; i++) {
            FeatureDataSet applyMask = FSDataSet.applyMask((FeatureDataSet) dataSetArr[i][0], booleanArray.getArray());
            FSDataSet.applyMask((FeatureDataSet) dataSetArr[i][1], booleanArray.getArray());
            this.scorer.doTrain(applyMask);
        }
        return KStarConstants.FLOOR / dataSetArr.length;
    }

    /**
     * Applies the mask chosen during training to {@code featureDataSet}.
     *
     * @param featureDataSet data to reduce
     * @return the masked data set, or null after reporting an error (including when the
     *     selector has not been trained yet)
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public FeatureDataSet process(FeatureDataSet featureDataSet) {
        if (this.ba == null) {
            // Without this check the missing mask surfaces only as an anonymous NullPointerException.
            internalError(new IllegalStateException("WalkSat.process() called before train()"));
            return null;
        }
        try {
            return FSDataSet.applyMask(featureDataSet, this.ba.getArray());
        } catch (Exception e) {
            internalError(e);
            return null;
        }
    }

    /**
     * Runs the mask search (see {@link #wstrain}); any exception is reported via internalError.
     *
     * @param featureDataSet training data
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void train(FeatureDataSet featureDataSet) {
        try {
            wstrain(featureDataSet);
        } catch (Exception e) {
            internalError(e);
        }
    }

    /**
     * Performs the WalkSat search and stores the winning mask in {@code ba}.
     *
     * @param featureDataSet training data, split via splitDataSetValidation(validation)
     * @throws IllegalArgumentException if min_features exceeds the number of features, since
     *     no mask could ever satisfy the constraint
     * @throws Exception propagated from splitting, masking or scoring
     */
    public void wstrain(FeatureDataSet featureDataSet) throws Exception {
        initializeConstants(featureDataSet);
        int numFeatures = featureDataSet.numFeatures();
        if (this.min_features > numFeatures) {
            // The re-draw loop below could never terminate.
            throw new IllegalArgumentException("min_features (" + this.min_features
                    + ") exceeds the number of features (" + numFeatures + ")");
        }
        int outputIndex = featureDataSet.getOutputIndex();
        DataSet[][] splits = featureDataSet.splitDataSetValidation(this.validation);

        BooleanArray bestMask = null;
        double bestScore = -1.0d;
        // Last candidate score seen; only reported as "temp:" in the progress output.
        double candidateScore = 0.0d;
        for (int climb = 0; climb < this.max_climbs; climb++) {
            BooleanArray mask = new BooleanArray(numFeatures);
            mask.randomize();
            while (mask.numFalse() < this.min_features) {
                mask.randomize();
            }
            double score = calcScore(mask, splits);

            for (int flip = 0; flip < this.max_flips; flip++) {
                if (!Util.randomBoolean(this.rand_prob)) {
                    System.out.println("GSAT move");
                    int bestFlip = -1;
                    double bestFlipScore = score;
                    for (int f = 0; f < numFeatures; f++) {
                        if (f == outputIndex) {
                            continue;
                        }
                        mask.flip(f);
                        // Masks violating min_features must never win, whatever the score range.
                        candidateScore = mask.numFalse() < this.min_features
                                ? Double.NEGATIVE_INFINITY
                                : calcScore(mask, splits);
                        mask.flip(f);
                        if (candidateScore > bestFlipScore) {
                            bestFlipScore = candidateScore;
                            bestFlip = f;
                        }
                    }
                    if (bestFlip == -1) {
                        System.out.println("a best flip wasn't found!\n");
                        break;
                    }
                    // bestFlipScore only rises above `score` on a strict improvement, so a
                    // found flip is always uphill.
                    candidateScore = bestFlipScore;
                    System.out.println("old max:" + score + " new max:" + candidateScore);
                    score = candidateScore;
                    mask.flip(bestFlip);
                } else {
                    if (!hasLegalFlip(mask.numFalse(), numFeatures)) {
                        // Every single flip would break min_features; the retry loop below
                        // would spin forever, so end this restart instead.
                        System.out.println("no legal random flip exists!\n");
                        break;
                    }
                    // NOTE(review): unlike the GSAT move, this may flip the output column.
                    int randomFeature = Util.randomInteger(0, numFeatures - 1);
                    System.out.println("Random Move");
                    mask.flip(randomFeature);
                    while (mask.numFalse() < this.min_features) {
                        // Undo the illegal flip and try another feature.
                        mask.flip(randomFeature);
                        randomFeature = Util.randomInteger(0, numFeatures - 1);
                        mask.flip(randomFeature);
                    }
                    score = calcScore(mask, splits);
                }
                System.out.println("Flip #" + flip + " temp:" + candidateScore + " max:" + score);
            }

            if (bestMask == null || score > bestScore) {
                bestMask = mask;
                bestScore = score;
            }
            System.out.println("Climb #:" + climb + " best score:" + bestScore);
            System.out.println("Features:");
            System.out.println(FSDataSet.printFeatureIdList(featureDataSet, bestMask.getArray()));
        }
        this.ba = bestMask;
    }

    /**
     * Reports whether some single-entry flip of a mask keeps at least {@code min_features}
     * false entries.
     *
     * @param numFalse current number of false entries in the mask
     * @param size total number of mask entries
     * @return true if flipping some true entry to false, or some false entry to true, leaves
     *     at least min_features false entries
     */
    private boolean hasLegalFlip(int numFalse, int size) {
        boolean trueToFalseOk = numFalse < size && numFalse + 1 >= this.min_features;
        boolean falseToTrueOk = numFalse > 0 && numFalse - 1 >= this.min_features;
        return trueToFalseOk || falseToTrueOk;
    }

    /**
     * Prints the feature ids selected by the trained mask, or a hint to train first.
     *
     * @param featureDataSet data set used to resolve feature ids
     */
    public void printFeatureIdList(FeatureDataSet featureDataSet) {
        if (this.ba != null) {
            System.out.println(FSDataSet.printFeatureIdList(featureDataSet, this.ba.getArray()));
        } else {
            System.out.println("Train WalkSat first...\n");
        }
    }

    /** @return true; WalkSat always needs a scorer to evaluate masks */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public boolean needScorer() {
        return true;
    }

    /**
     * Sets the scorer used to evaluate masks.
     *
     * @param scorer scorer to use
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void setScorer(Scorer scorer) {
        System.out.println("WalkSat(): setScorer");
        this.scorer = scorer;
    }
}
