package edu.wisc.sjm.machlearn.policy.fdspreprocessor.selection.filter;

import edu.wisc.sjm.jutil.math.JMath;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.dataset.conversions.ID3Converter;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.featureselection.FSDataSet;
import edu.wisc.sjm.machlearn.policy.FDSPreProcessor;
import java.io.OutputStream;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/policy/fdspreprocessor/selection/filter/CMIM.class */
/**
 * Filter-style feature selector that ranks features by conditional mutual information with the
 * output column (CMIM).
 *
 * <p>{@link #train(FeatureDataSet)} converts the data set with {@link ID3Converter}, copies the
 * value ids into an {@code int} matrix and runs {@link #sCMIM(int[][], int, int)} to build
 * {@link #mask}; {@link #process(FeatureDataSet)} then applies that mask.
 *
 * <p>All entropy helpers count into arrays of size 2 per column, so every value id in the matrix
 * must be 0 or 1; any other value raises {@link ArrayIndexOutOfBoundsException}.
 *
 * <p>NOTE(review): the method names and structure suggest this follows Fleuret's CMIM algorithm
 * ("standard" and "fast" variants) - confirm against the project documentation.
 */
public class CMIM extends FDSPreProcessor {
    /** Number of features to request from the selector; {@code -1} requests every column. */
    protected int keep = 50;

    /**
     * Column indices produced by the last {@link #train(FeatureDataSet)} call. The first entry is
     * the output column, followed by the selected features in selection order. {@code null}
     * until {@code train} has been called.
     */
    protected IntVector mask;

    /**
     * Sets the number of features to keep from its decimal string form.
     *
     * @param str decimal integer, parsed with {@link Integer#parseInt(String)}
     * @throws NumberFormatException if {@code str} is not a parsable integer
     */
    public void setKeep(String str) {
        setKeep(Integer.parseInt(str));
    }

    /**
     * Sets the number of features to keep.
     *
     * @param i requested feature count; {@code -1} is treated by {@link #sCMIM} as "all columns"
     */
    public void setKeep(int i) {
        this.keep = i;
    }

    /**
     * Builds the selection mask from the given data set and prints the selected features to
     * standard output.
     *
     * @param featureDataSet training data; it is converted with {@link ID3Converter} before the
     *     value ids are read
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void train(FeatureDataSet featureDataSet) {
        FeatureDataSet discrete = new ID3Converter().convert(featureDataSet);

        // Flatten the converted data set into a row-major matrix of value ids.
        int[][] values = new int[discrete.size()][discrete.numFeatures()];
        for (int row = 0; row < values.length; row++) {
            for (int col = 0; col < values[0].length; col++) {
                values[row][col] = discrete.get(row, col).getValueId();
            }
        }

        this.mask = sCMIM(values, discrete.getOutputIndex(), this.keep);

        // The reported count includes the output column, which sCMIM always places first.
        System.out.println("CMIM features selected:" + this.mask.size());
        for (int pos = 0; pos < this.mask.size(); pos++) {
            System.out.println(pos + ")" + discrete.getFeatureId(this.mask.get(pos)).printName());
        }
    }

    /**
     * Placeholder with an empty body; nothing is written to {@code outputStream}.
     *
     * @param featureDataSet unused
     * @param intVector unused
     * @param outputStream unused
     * @throws Exception declared for signature compatibility; never thrown by this implementation
     */
    public void printSelectedFeatures(FeatureDataSet featureDataSet, IntVector intVector, OutputStream outputStream) throws Exception {
    }

    /**
     * Applies the mask computed by {@link #train(FeatureDataSet)} to a data set.
     *
     * @param featureDataSet data set to reduce
     * @return the result of {@code FSDataSet.applyMask} for the stored mask
     * @throws NullPointerException if {@code train} has not been called yet ({@link #mask} is null)
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public FeatureDataSet process(FeatureDataSet featureDataSet) {
        return FSDataSet.applyMask(featureDataSet, this.mask.getIntValues());
    }

    /**
     * Entropy estimate of a single binary column, computed as
     * {@code log2(N) - sum(xlog2(count)) / N} over the two value counts.
     *
     * <p>Presumably {@code JMath.xlog2(c)} is {@code c * log2(c)} with 0 for {@code c == 0}, which
     * makes the result the Shannon entropy in bits - TODO confirm against JMath.
     *
     * @param iArr data matrix, one row per example; values in column {@code i} must be 0 or 1
     * @param i column index
     * @return the entropy estimate for column {@code i}
     */
    public static double entropy(int[][] iArr, int i) {
        int length = iArr.length;
        int[] counts = new int[2];
        for (int[] row : iArr) {
            counts[row[i]]++;
        }
        double h = JMath.log2(length);
        for (int v = 0; v < 2; v++) {
            h -= JMath.xlog2(counts[v]) / length;
        }
        return h;
    }

    /**
     * Joint entropy estimate of two binary columns, using the same formula as
     * {@link #entropy(int[][], int)} over the 2x2 table of joint counts.
     *
     * @param iArr data matrix, one row per example; values in both columns must be 0 or 1
     * @param i first column index
     * @param i2 second column index
     * @return the joint entropy estimate of columns {@code i} and {@code i2}
     */
    public static double entropy(int[][] iArr, int i, int i2) {
        int length = iArr.length;
        int[][] counts = new int[2][2];
        double h = JMath.log2(length);
        for (int row = 0; row < length; row++) {
            counts[iArr[row][i]][iArr[row][i2]]++;
        }
        for (int a = 0; a < 2; a++) {
            for (int b = 0; b < 2; b++) {
                h -= JMath.xlog2(counts[a][b]) / length;
            }
        }
        return h;
    }

    /**
     * Joint entropy estimate of three binary columns, using the same formula as
     * {@link #entropy(int[][], int)} over the 2x2x2 table of joint counts.
     *
     * @param iArr data matrix, one row per example; values in all three columns must be 0 or 1
     * @param i first column index
     * @param i2 second column index
     * @param i3 third column index
     * @return the joint entropy estimate of columns {@code i}, {@code i2} and {@code i3}
     */
    public static double entropy(int[][] iArr, int i, int i2, int i3) {
        int length = iArr.length;
        int[][][] counts = new int[2][2][2];
        double h = JMath.log2(length);
        for (int row = 0; row < length; row++) {
            counts[iArr[row][i]][iArr[row][i2]][iArr[row][i3]]++;
        }
        for (int a = 0; a < 2; a++) {
            for (int b = 0; b < 2; b++) {
                for (int c = 0; c < 2; c++) {
                    h -= JMath.xlog2(counts[a][b][c]) / length;
                }
            }
        }
        return h;
    }

    /**
     * Mutual information between two columns: {@code H(i) + H(i2) - H(i, i2)}.
     *
     * @param iArr data matrix, one row per example
     * @param i first column index
     * @param i2 second column index
     * @return the mutual-information estimate
     */
    public static double mut_inf(int[][] iArr, int i, int i2) {
        return (entropy(iArr, i) + entropy(iArr, i2)) - entropy(iArr, i, i2);
    }

    /**
     * Conditional mutual information between columns {@code i} and {@code i2} given column
     * {@code i3}: {@code H(i, i3) - H(i3) - H(i, i2, i3) + H(i2, i3)}.
     *
     * @param iArr data matrix, one row per example
     * @param i first column index (the output column in this class's callers)
     * @param i2 second column index (the candidate feature in this class's callers)
     * @param i3 index of the column being conditioned on
     * @return the conditional mutual-information estimate
     */
    public static double cond_mut_inf(int[][] iArr, int i, int i2, int i3) {
        return ((entropy(iArr, i, i3) - entropy(iArr, i3)) - entropy(iArr, i, i2, i3)) + entropy(iArr, i2, i3);
    }

    /**
     * Greedy CMIM selection that re-scores every remaining candidate after each pick.
     *
     * <p>Each column starts with its mutual information with the output column. On every round
     * the column chosen by {@code JMath.amax} is added, and each remaining score is lowered to the
     * minimum of its current value and its conditional mutual information with the output given
     * the column just picked. Already-used columns are marked with negative infinity.
     *
     * @param iArr data matrix, one row per example; must contain at least one row
     * @param i index of the output column
     * @param i2 number of features to select; {@code -1} means all columns. At most
     *     {@code columns - 1} features are selected
     * @return a vector whose first entry is {@code i}, followed by the selected column indices in
     *     selection order
     */
    public static IntVector sCMIM(int[][] iArr, int i, int i2) {
        int numColumns = iArr[0].length;
        int requested = (i2 == -1) ? numColumns : i2;

        double[] score = new double[numColumns];
        for (int col = 0; col < numColumns; col++) {
            score[col] = mut_inf(iArr, i, col);
        }

        IntVector chosen = new IntVector();
        chosen.add(i);
        // Negative infinity marks a column as no longer selectable.
        score[i] = Double.NEGATIVE_INFINITY;

        int limit = Math.min(requested, numColumns - 1);
        for (int picked = 0; picked < limit; picked++) {
            int best = JMath.amax(score);
            // -1 is treated as "no usable maximum" - exact conditions depend on JMath.amax.
            if (best == -1) {
                System.out.println("no best score.");
                break;
            }
            score[best] = Double.NEGATIVE_INFINITY;
            chosen.add(best);
            for (int col = 0; col < numColumns; col++) {
                if (!Double.isInfinite(score[col])) {
                    score[col] = Math.min(score[col], cond_mut_inf(iArr, i, col, best));
                }
            }
        }
        return chosen;
    }

    /**
     * Greedy CMIM selection with lazy score refinement.
     *
     * <p>Each non-output column keeps a partial score (initially its mutual information with the
     * output) and a count of how many already-selected features that score has been conditioned
     * on. During a round a candidate's score is only refined further while it still exceeds the
     * best score seen so far in that round, which avoids conditional-information evaluations for
     * candidates that cannot win.
     *
     * @param iArr data matrix, one row per example; must contain at least one row
     * @param i index of the output column
     * @param i2 number of features to select
     * @return {@code i2} column indices into {@code iArr} in selection order. A round in which no
     *     candidate has a score strictly greater than 0 yields the first non-output column index
     *     ({@code 0}, or {@code 1} when {@code i == 0})
     */
    public static int[] fCMIM(int[][] iArr, int i, int i2) {
        int numColumns = iArr[0].length;
        int numCandidates = numColumns - 1;

        // Candidates are stored compactly (output column skipped); "slot" indexes these arrays,
        // while "col" always denotes a real column of iArr.
        double[] score = new double[numCandidates];
        int[] conditionedOn = new int[numCandidates]; // zero-initialised by the JVM
        int slot = 0;
        for (int col = 0; col < numColumns; col++) {
            if (col != i) {
                score[slot] = mut_inf(iArr, i, col);
                slot++;
            }
        }

        // Value reported for a round in which nothing scores above zero.
        int fallback = (i == 0) ? 1 : 0;

        int[] selected = new int[i2];
        for (int round = 0; round < i2; round++) {
            selected[round] = fallback;
            double best = 0.0d;
            slot = 0;
            for (int col = 0; col < numColumns; col++) {
                if (col == i) {
                    continue;
                }
                // Condition on selected[0..round-1] one at a time, stopping early once this
                // candidate can no longer beat the current best.
                while (score[slot] > best && conditionedOn[slot] < round) {
                    int given = selected[conditionedOn[slot]];
                    score[slot] = Math.min(score[slot], cond_mut_inf(iArr, i, col, given));
                    conditionedOn[slot]++;
                }
                if (score[slot] > best) {
                    best = score[slot];
                    selected[round] = col;
                }
                slot++;
            }
        }
        return selected;
    }

    /**
     * Reports that this preprocessor does not use a {@link Scorer}.
     *
     * @return always {@code false}
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public boolean needScorer() {
        return false;
    }

    /**
     * Ignores the supplied scorer; the body is intentionally empty.
     *
     * @param scorer unused
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void setScorer(Scorer scorer) {
    }
}
