package com.sjm.machlearn.featureselection;

import com.sjm.machlearn.classifiers.Classifier;
import com.sjm.machlearn.dataset.DataSet;
import com.sjm.machlearn.util.Util;

/* loaded from: input_file:com/sjm/machlearn/featureselection/ForwardSelect.class */
/**
 * Greedy forward feature selection.
 *
 * <p>Starting from a view of the data in which only the output attribute is revealed, the
 * selector repeatedly tries every still-hidden feature, permanently reveals the one that
 * yields the highest held-out accuracy, and stops as soon as no remaining feature strictly
 * improves on the best accuracy seen so far, or when no features are left to try.
 */
public class ForwardSelect extends FeatureSelect {

    /**
     * Creates a forward selector.
     *
     * @param z forwarded unchanged to the {@link FeatureSelect} constructor; presumably the
     *          debug flag later read as {@code this.debug} — TODO confirm against FeatureSelect
     */
    public ForwardSelect(boolean z) {
        super(z);
    }

    /**
     * Runs greedy forward selection for {@code classifier} on {@code dataSet}.
     *
     * <p>The data is split with {@code dataSet.splitRandom(10.0)}. The classifier is trained
     * on the second split and its accuracy is measured on the first. A feature is added only
     * if it strictly increases that accuracy (the baseline is {@code 0.0}).
     *
     * @param dataSet    the data to select features from
     * @param classifier the classifier used to score each candidate feature set; it is
     *                   retrained many times and is left trained on the last set evaluated
     * @return the masks of the training split after selection, as returned by
     *         {@code FSDataSet.getMasks()}; in this class a revealed feature is one whose
     *         mask was cleared with {@code unsetMask}
     */
    @Override // com.sjm.machlearn.featureselection.FeatureSelect
    public boolean[] doFeatureSelection(DataSet dataSet, Classifier classifier) {
        // NOTE(review): what 10.0 means (percentage, ratio, ...) is defined by
        // DataSet.splitRandom — confirm before changing.
        DataSet[] splits = dataSet.splitRandom(10.0d);
        FSDataSet evalSet = new FSDataSet(splits[0]);  // accuracy is measured on this split
        FSDataSet trainSet = new FSDataSet(splits[1]); // the classifier is trained on this split
        if (this.debug) {
            evalSet.debugOn();
            trainSet.debugOn();
        }

        // Hide every attribute, then reveal only the output so the search starts empty.
        evalSet.setAll();
        trainSet.setAll();
        int outputIndex = evalSet.getOutputIndex();
        evalSet.unsetMask(outputIndex);
        trainSet.unsetMask(outputIndex);
        debugMesg("shown features(should be 1 (output):" + trainSet.numShownFeatures());

        double bestAccuracy = 0.0d;
        int[] candidates = trainSet.getHiddenFeatureIndices();
        // Looping only while candidates remain also covers data with no non-output
        // attributes. The original code indexed an empty accuracy array there and threw.
        while (candidates.length > 0) {
            double[] accuracies = scoreCandidates(candidates, trainSet, evalSet, classifier);
            debugMesg(Util.printArray(accuracies));

            int best = Util.argmax(accuracies);
            if (accuracies[best] <= bestAccuracy) {
                break; // no remaining feature strictly improves held-out accuracy
            }
            bestAccuracy = accuracies[best];

            // Permanently reveal the winning feature in both splits.
            trainSet.unsetMask(candidates[best]);
            evalSet.unsetMask(candidates[best]);
            candidates = trainSet.getHiddenFeatureIndices();
        }
        return trainSet.getMasks();
    }

    /**
     * Scores each candidate feature on top of the currently revealed features.
     *
     * <p>For each candidate the method does four things:
     * <ol>
     *   <li>temporarily reveals the feature in both splits;</li>
     *   <li>trains {@code classifier} on {@code trainSet};</li>
     *   <li>records the accuracy on {@code evalSet};</li>
     *   <li>hides the feature again, so the mask state is unchanged on return.</li>
     * </ol>
     *
     * @param candidates feature indices to try, one at a time
     * @param trainSet   split the classifier is trained on
     * @param evalSet    split the accuracy is measured on
     * @param classifier classifier to retrain for each candidate
     * @return accuracies aligned with {@code candidates}
     */
    private double[] scoreCandidates(int[] candidates, FSDataSet trainSet, FSDataSet evalSet,
            Classifier classifier) {
        double[] accuracies = new double[candidates.length];
        for (int i = 0; i < candidates.length; i++) {
            int feature = candidates[i];
            debugMesg("testing feature #:" + feature);
            trainSet.unsetMask(feature);
            evalSet.unsetMask(feature);
            try {
                classifier.train(trainSet.getDataSet());
            } catch (Exception e) {
                // NOTE(review): if internalError returns normally, the accuracy below is
                // measured with whatever state the failed training left behind.
                internalError(e);
            }
            accuracies[i] = classifier.getAccuracy(evalSet.getDataSet());
            debugMesg("accuracy:" + accuracies[i]);
            trainSet.setMask(feature);
            evalSet.setMask(feature);
        }
        return accuracies;
    }
}
