package edu.wisc.sjm.machlearn.policy.fdspreprocessor.misc;

import edu.wisc.sjm.jutil.misc.DoubleUtil;
import edu.wisc.sjm.jutil.vectors.IntVector;
import edu.wisc.sjm.machlearn.Scorer;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.conversions.ID3Converter;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.policy.FDSPreProcessor;
import edu.wisc.sjm.machlearn.util.Util;
import java.io.FileOutputStream;
import java.io.PrintStream;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/policy/fdspreprocessor/misc/PrintFeatureStats.class */
/**
 * Pre-processor that prints per-feature statistics for a data set and then
 * passes the data set through unchanged.
 *
 * <p>The data set is first run through {@link ID3Converter}; the statistics
 * (info gain, accuracy, recall, precision and four output/feature bucket
 * counts per feature) are computed on the converted copy. The table is always
 * written to {@code System.out} and, when an output file has been configured
 * via {@link #setOutput(String)}, to that file as well.
 */
public class PrintFeatureStats extends FDSPreProcessor {
    /**
     * Number of (output value id, feature value id) combinations tracked:
     * value ids 0 and 1 for the output times value ids 0 and 1 for the feature.
     */
    private static final int NUM_BUCKETS = 4;

    /** Path of the file the table is additionally written to; {@code null} means console only. */
    protected String output_file;

    /**
     * Mode forwarded to the {@code DiscreteStats} accuracy/recall/precision
     * calls. {@link #setType(String)} maps "max" to 0, "pos" to 1 and "neg" to 2.
     */
    protected int type = 1;

    /**
     * When true, each bucket count in the table is followed by the list of
     * example indices that fell into that bucket.
     */
    protected boolean print_credit = false;

    /**
     * Sets the file that {@link #process(FeatureDataSet)} writes the table to
     * in addition to the console.
     *
     * @param str output file path, or {@code null} to print to the console only
     */
    public void setOutput(String str) {
        this.output_file = str;
    }

    /** No-op: this pre-processor has nothing to learn from the training data. */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void train(FeatureDataSet featureDataSet) {
    }

    /**
     * Sets the mode passed to the {@code DiscreteStats} calls.
     *
     * @param i numeric mode; see {@link #setType(String)} for the named values
     */
    public void setType(int i) {
        this.type = i;
    }

    /**
     * Sets the mode from its textual form: "max" (0), "pos" (1), "neg" (2),
     * or any other string that parses as a decimal integer.
     *
     * @param str mode name or integer literal
     * @throws NumberFormatException if {@code str} is neither a known name nor an integer
     */
    public void setType(String str) {
        if (str.equals("max")) {
            setType(0);
        } else if (str.equals("pos")) {
            setType(1);
        } else if (str.equals("neg")) {
            setType(2);
        } else {
            setType(Integer.parseInt(str));
        }
    }

    /**
     * Prints the feature statistics table and a per-example summary, then
     * returns the input data set untouched.
     *
     * <p>Any exception raised while computing or printing is reported via
     * {@code printStackTrace()} and handed to {@code internalError}; the
     * method then still returns the original data set (unless
     * {@code internalError} itself does not return — not visible from here).
     *
     * @param featureDataSet data set to report on
     * @return the same {@code featureDataSet} instance that was passed in
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public FeatureDataSet process(FeatureDataSet featureDataSet) {
        System.out.println("PrintFeatureStats : start().\n");
        try {
            System.out.println("fds.size():" + featureDataSet.size());
            System.out.println("converting features...\n");
            FeatureDataSet converted = new ID3Converter().convert(featureDataSet);
            int numExamples = converted.size();
            System.out.println("fdsd.size():" + numExamples);

            // One accumulator per example, summed over all features below.
            IntVector[] perExampleTotals = new IntVector[numExamples];
            for (int ex = 0; ex < numExamples; ex++) {
                perExampleTotals[ex] = new IntVector(NUM_BUCKETS);
            }
            for (int feat = 0; feat < converted.numFeatures(); feat++) {
                IntVector[] perExample = getExampleCounts(getCounts(converted, feat), numExamples);
                for (int ex = 0; ex < numExamples; ex++) {
                    perExampleTotals[ex].sum(perExample[ex]);
                }
            }

            // File copy first (if configured), then the console copy.
            if (this.output_file != null) {
                printStats(converted, this.output_file);
            }
            printStats(converted, null);

            for (int ex = 0; ex < numExamples; ex++) {
                System.out.println(converted.getExample(ex).getName() + ":"
                        + Util.printArrayRow(perExampleTotals[ex].getIntValues()));
            }
        } catch (Exception e) {
            e.printStackTrace();
            internalError(e);
        }
        System.out.println("done printfeature stats..\n");
        return featureDataSet;
    }

    /**
     * Writes one tab-separated row per feature: index, info gain, accuracy,
     * recall, precision, the four bucket sizes from
     * {@link #getCounts(FeatureDataSet, int)} and the feature name.
     *
     * <p>When writing to a file, the stream is closed even if producing a row
     * throws, so the file handle is not leaked. {@code System.out} is never
     * closed.
     *
     * @param featureDataSet data set to report on; its first output feature
     *                       must expose at least two values (used in the header)
     * @param str            path of the file to write, or {@code null} for {@code System.out}
     * @throws Exception if the file cannot be opened or a statistic cannot be computed
     */
    public void printStats(FeatureDataSet featureDataSet, String str) throws Exception {
        String[] values = featureDataSet.getOutputFeatureId(0).getValues();
        boolean toFile = str != null;
        PrintStream printStream = toFile ? new PrintStream(new FileOutputStream(str)) : System.out;
        try {
            printStream.print("I\tIG\tA\tR\tP\t\t");
            printStream.println(values[0] + ",f\t" + values[0] + ",t\t"
                    + values[1] + ",f\t" + values[1] + ",t\tFeature");
            for (int feat = 0; feat < featureDataSet.numFeatures(); feat++) {
                printStream.print(feat + "\t");
                printStream.print(DoubleUtil.printDecimal(featureDataSet.getInfoGain(feat), 3));
                printStream.print("\t");
                printStream.print(DoubleUtil.printDecimal(
                        DiscreteStats.getAccuracy(featureDataSet, feat, this.type), 3));
                printStream.print("\t");
                printStream.print(DoubleUtil.printDecimal(
                        DiscreteStats.getRecall(featureDataSet, feat, this.type), 3));
                printStream.print("\t");
                printStream.print(DoubleUtil.printDecimal(
                        DiscreteStats.getPrecision(featureDataSet, feat, this.type), 3));
                printStream.print("\t");
                IntVector[] counts = getCounts(featureDataSet, feat);
                for (int bucket = 0; bucket < NUM_BUCKETS; bucket++) {
                    printStream.print("\t" + counts[bucket].size());
                    if (this.print_credit) {
                        // Append the example indices that landed in this bucket.
                        printStream.print(":");
                        printStream.print(Util.printArrayRow(counts[bucket].getIntValues()));
                        printStream.print("\t");
                    }
                }
                printStream.print("\t");
                printStream.println(featureDataSet.getFeatureId(feat).printName());
            }
        } finally {
            if (toFile) {
                printStream.close();
            }
        }
    }

    /**
     * Inverts the bucket lists produced by {@link #getCounts(FeatureDataSet, int)}:
     * for every example index stored in bucket {@code k}, calls
     * {@code increment(k)} on that example's vector.
     *
     * @param intVectorArr the four bucket vectors, each holding example indices
     * @param i            number of examples; every stored index must be below this
     * @return an array of {@code i} vectors, one per example (presumably
     *         4-slot counters — depends on {@code IntVector(int)} and
     *         {@code increment}, which are not visible here)
     */
    public static IntVector[] getExampleCounts(IntVector[] intVectorArr, int i) {
        IntVector[] perExample = new IntVector[i];
        for (int ex = 0; ex < i; ex++) {
            perExample[ex] = new IntVector(NUM_BUCKETS);
        }
        for (int bucket = 0; bucket < NUM_BUCKETS; bucket++) {
            IntVector members = intVectorArr[bucket];
            for (int pos = 0; pos < members.size(); pos++) {
                perExample[members.get(pos)].increment(bucket);
            }
        }
        return perExample;
    }

    /**
     * Partitions the example indices of a data set into four buckets by the
     * value ids of the output feature and of feature {@code i}:
     * bucket {@code 2 * outputValueId + featureValueId}, i.e.
     * 0 = (0,0), 1 = (0,1), 2 = (1,0), 3 = (1,1).
     *
     * <p>Examples whose output or feature value id is neither 0 nor 1 are not
     * placed in any bucket.
     *
     * @param featureDataSet data set to scan
     * @param i              index of the feature to cross with the output
     * @return four vectors of example indices, in the bucket order above
     */
    public static IntVector[] getCounts(FeatureDataSet featureDataSet, int i) {
        IntVector[] buckets = new IntVector[NUM_BUCKETS];
        for (int bucket = 0; bucket < NUM_BUCKETS; bucket++) {
            buckets[bucket] = new IntVector();
        }
        int outputIndex = featureDataSet.getOutputIndex();
        for (int ex = 0; ex < featureDataSet.size(); ex++) {
            Example example = featureDataSet.getExample(ex);
            int outputValueId = example.get(outputIndex).getValueId();
            if (outputValueId != 0 && outputValueId != 1) {
                // Feature value is deliberately not read for such examples.
                continue;
            }
            int featureValueId = example.get(i).getValueId();
            if (featureValueId != 0 && featureValueId != 1) {
                continue;
            }
            buckets[outputValueId * 2 + featureValueId].add(ex);
        }
        return buckets;
    }

    /**
     * Reports that this pre-processor does not use a scorer.
     *
     * @return always {@code false}
     */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public boolean needScorer() {
        return false;
    }

    /** No-op: the scorer is ignored, see {@link #needScorer()}. */
    @Override // edu.wisc.sjm.machlearn.policy.FDSPreProcessor
    public void setScorer(Scorer scorer) {
    }
}
