package edu.wisc.sjm.machlearn.policy.fdspreprocessor.selection.filter;

import edu.wisc.sjm.jutil.misc.BooleanArray;
import edu.wisc.sjm.jutil.vectors.DoubleVector;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.classifiers.knn.KNNComparator;
import edu.wisc.sjm.machlearn.classifiers.knn.KNNIndex;
import edu.wisc.sjm.machlearn.dataset.conversions.ID3Converter;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.exceptions.InvalidFeature;
import edu.wisc.sjm.machlearn.exceptions.parameters.ParameterException;
import edu.wisc.sjm.machlearn.featureselection.FSDataSet;
import edu.wisc.sjm.machlearn.parameters.Parameter;
import edu.wisc.sjm.machlearn.parameters.ParameterSupportObject;
import edu.wisc.sjm.machlearn.policy.FDSPreProcessor;
import edu.wisc.sjm.machlearn.util.Util;
import java.util.Iterator;
import java.util.TreeSet;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/policy/fdspreprocessor/selection/filter/InfoGainFilterPercent.class */
/**
 * Filter-style feature selector that ranks every non-output feature by information gain and keeps
 * at most {@code keep_percent} percent of them.
 *
 * <p>Information gain is computed on a copy of the training data produced by {@link ID3Converter}.
 * Features whose information gain is at or below {@value #MIN_INFO_GAIN} are never kept. The
 * selection is stored in {@link #mask} by {@link #train} and applied by {@link #process}.
 */
public class InfoGainFilterPercent extends FDSPreProcessor implements ParameterSupportObject {
    /** Smallest accepted keep percentage; smaller requests are raised to this. */
    private static final double MIN_KEEP_PERCENT = 1.0E-4d;
    /** Largest accepted keep percentage; larger requests are lowered to this. */
    private static final double MAX_KEEP_PERCENT = 100.0d;
    /** Features with information gain at or below this value are never kept. */
    private static final double MIN_INFO_GAIN = 1.0E-10d;
    /** Lower bound applied to the computed keep count. */
    private static final int MIN_FEATURES_KEPT = 2;
    /** The kept features are listed on stdout only when fewer than this many were kept. */
    private static final int MAX_FEATURES_LISTED = 50;

    /** Percentage (clamped to [1e-4, 100]) of the non-output features to keep. */
    protected double keep_percent;

    /**
     * Selection mask built by the last {@link #train} call, or {@code null} before training. All
     * entries are set true, then the entry of every kept feature is set false.
     * NOTE(review): presumably {@link FSDataSet#applyMask} drops the entries that are true — confirm.
     */
    protected BooleanArray mask;

    /** Creates a filter that keeps 100% of the features (subject to the info-gain threshold). */
    public InfoGainFilterPercent() {
        this.keep_percent = MAX_KEEP_PERCENT;
    }

    /**
     * Creates a filter keeping {@code d} percent of the features.
     *
     * @param d requested keep percentage; clamped to [1e-4, 100]
     */
    public InfoGainFilterPercent(double d) {
        // Use the static helper rather than the overridable setter from a constructor.
        this.keep_percent = clampPercent(d);
    }

    /** @return the current keep percentage */
    public double getKeepPercent() {
        return this.keep_percent;
    }

    /**
     * Sets the keep percentage.
     *
     * @param d requested keep percentage; clamped to [1e-4, 100]
     */
    public void setKeepPercent(double d) {
        this.keep_percent = clampPercent(d);
    }

    /**
     * Parses and sets the keep percentage.
     *
     * @param str decimal representation of the percentage
     * @throws NumberFormatException if {@code str} is not a parsable double
     */
    public void setKeepPercent(String str) {
        setKeepPercent(Double.parseDouble(str));
    }

    /** Clamps a requested percentage into [MIN_KEEP_PERCENT, MAX_KEEP_PERCENT]. */
    private static double clampPercent(double d) {
        return Util.min(Util.max(MIN_KEEP_PERCENT, d), MAX_KEEP_PERCENT);
    }

    /**
     * Ranks the non-output features of {@code featureDataSet} by information gain and records the
     * selection in {@link #mask}.
     *
     * <p>The number of features kept is {@code keep_percent} percent of the non-output features
     * (truncated), capped at {@code numFeatures() - 1} and raised to at least 2. Fewer are kept
     * when fewer features are available or when the ranking reaches a feature whose info gain is
     * at or below {@value #MIN_INFO_GAIN}.
     *
     * @param featureDataSet training data
     * @throws IllegalStateException if the info gain of a feature cannot be computed, or the
     *     ranking rejects a feature as a duplicate
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void train(FeatureDataSet featureDataSet) {
        int numFeatures = featureDataSet.numFeatures();
        int max = Util.max(
                Util.min((int) (((numFeatures - 1) * this.keep_percent) / 100.0d), numFeatures - 1),
                MIN_FEATURES_KEPT);
        System.out.println("keeping " + max + " features");
        FeatureDataSet convert = new ID3Converter().convert(featureDataSet);
        this.mask = new BooleanArray(numFeatures);
        TreeSet<KNNIndex> ranked = rankByInfoGain(convert);
        this.mask.setTrue();
        IntVector keptIndices = new IntVector();
        DoubleVector keptGains = new DoubleVector();
        Iterator<KNNIndex> it = ranked.iterator();
        // hasNext() guard: `max` is raised to MIN_FEATURES_KEPT even when fewer features exist,
        // so without it next() would throw NoSuchElementException on small datasets.
        for (int taken = 0; taken < max && it.hasNext(); taken++) {
            KNNIndex candidate = it.next();
            // Assumes the ranking is in descending info-gain order (see rankByInfoGain), so the
            // first non-informative feature ends the selection.
            if (candidate.getDist() <= MIN_INFO_GAIN) {
                break;
            }
            keptIndices.add(candidate.getIndex());
            keptGains.add(candidate.getDist());
            this.mask.setFalse(candidate.getIndex());
        }
        System.out.println("Num features kept:" + keptIndices.size());
        if (keptIndices.size() < MAX_FEATURES_LISTED) {
            System.out.println("Kept features");
            for (int k = 0; k < keptIndices.size(); k++) {
                System.out.println(featureDataSet.getFeatureId(keptIndices.get(k)).printName() + "\t" + keptGains.get(k));
            }
        }
    }

    /**
     * Builds the info-gain ranking of every non-output feature of {@code converted}.
     * NOTE(review): assumes {@code KNNComparator(true)} orders entries by descending distance
     * (here: info gain) — confirm.
     *
     * @throws IllegalStateException if a feature's info gain cannot be computed or the ranking
     *     rejects an entry
     */
    private static TreeSet<KNNIndex> rankByInfoGain(FeatureDataSet converted) {
        TreeSet<KNNIndex> ranked = new TreeSet<KNNIndex>(new KNNComparator(true));
        for (int i = 0; i < converted.numFeatures(); i++) {
            if (i == converted.getOutputIndex()) {
                continue;
            }
            try {
                if (!ranked.add(new KNNIndex(i, converted.getInfoGain(i)))) {
                    throw new IllegalStateException("Bug here: ranking rejected feature " + i);
                }
            } catch (InvalidFeature e) {
                throw new IllegalStateException("Bug here: could not compute info gain of feature " + i, e);
            }
        }
        return ranked;
    }

    /**
     * Applies the mask from the last {@link #train} call to {@code featureDataSet}.
     *
     * @param featureDataSet data to filter
     * @return the result of {@link FSDataSet#applyMask} with the trained mask
     * @throws IllegalStateException if {@link #train} has not been called yet
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public FeatureDataSet process(FeatureDataSet featureDataSet) {
        if (this.mask == null) {
            throw new IllegalStateException("train must be called before process");
        }
        return FSDataSet.applyMask(featureDataSet, this.mask.getArray());
    }

    /**
     * Sets parameter {@code i} from its string form. Only parameter 0 (KeepPercent) exists.
     *
     * @throws ParameterException if {@code i} is not 0
     * @throws NumberFormatException if {@code str} is not a parsable double
     */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public void setParameter(int i, String str) throws ParameterException {
        switch (i) {
            case 0:
                setKeepPercent(Double.parseDouble(str));
                System.out.println("KeepPercent is now:" + getKeepPercent());
                return;
            default:
                throw new ParameterException("Parameter doesn't exist!");
        }
    }

    /**
     * Describes parameter {@code i}. Only parameter 0 (KeepPercent) exists.
     *
     * @throws ParameterException if {@code i} is not 0
     */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public Parameter getParameter(int i) throws ParameterException {
        switch (i) {
            case 0:
                return new Parameter(this, "KeepPercent", 0, String.valueOf(getKeepPercent()));
            default:
                throw new ParameterException("Parameter doesn't exist!");
        }
    }

    /** @return the string value of parameter {@code i}, or {@code null} if it does not exist */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public String getParameterValue(int i) {
        switch (i) {
            case 0:
                return String.valueOf(getKeepPercent());
            default:
                return null;
        }
    }

    /** @return the number of tunable parameters (always 1) */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public int numParameters() {
        return 1;
    }

    /** @return the parameter-support name of this object */
    @Override // edu.wisc.sjm.machlearn.parameters.ParameterSupportObject
    public String getPSOName() {
        return "InfoGainFilterPercent";
    }

    /** @return {@code false}; this filter does not use a scorer */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public boolean needScorer() {
        return false;
    }

    /** Ignored; this filter does not use a scorer. */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void setScorer(Scorer scorer) {
    }
}
