package edu.wisc.sjm.machlearn.classifiers.bayes.tan;

import edu.wisc.sjm.machlearn.classifiers.Classifier;
import edu.wisc.sjm.machlearn.dataset.Example;
import edu.wisc.sjm.machlearn.dataset.Feature;
import edu.wisc.sjm.machlearn.dataset.featuredataset.FeatureDataSet;
import edu.wisc.sjm.machlearn.policy.fdspreprocessor.discretize.BinaryDiscretizeIG;

/* loaded from: input_file:builds/machlearn_install.jar:machlearn.jar:edu/wisc/sjm/machlearn/classifiers/bayes/tan/JMaxTan.class */
/**
 * Classifier that discretizes its input with a {@link BinaryDiscretizeIG}, learns a tree-augmented
 * network via {@code MaxTan.maxTanGraph} / {@link BayesNet}, and thresholds the resulting
 * per-example score to predict output value id 1 or 0.
 *
 * <p>The score produced by {@code BayesNet.accuracyScore} is treated as the probability of output
 * value id 1: {@link #getDistribution} reports it in column 1 and its complement in column 0.
 */
public class JMaxTan extends Classifier {
    /** Discretizer fitted in {@link #train_} and applied to every input before scoring. */
    BinaryDiscretizeIG discretize;
    /** Discretized training matrix: one row per example, output value id in the last column. */
    int[][] TrainMatrix;
    /** Number of distinct values per column of {@link #TrainMatrix} (output in the last slot). */
    int[] attributeValues;
    /** Predict value id 1 when the network score is strictly greater than this. */
    double threshold;
    /** Network built in {@link #train_}. */
    BayesNet bn;
    /** Per-column attribute descriptors; the last column is tagged as the output. */
    Datapairs[] attributes;
    /** Conditional probability tables produced by {@code BayesNet.buildCPT}. */
    CondProb[] netcpts;

    /** Creates a classifier with a decision threshold of 0.5. */
    public JMaxTan() {
        this(0.5d);
    }

    /**
     * Creates a classifier with the given decision threshold.
     *
     * @param d threshold on the network score above which value id 1 is predicted
     */
    public JMaxTan(double d) {
        this.discretize = new BinaryDiscretizeIG();
        setThreshold(d);
    }

    /** Sets the decision threshold on the network score. */
    public void setThreshold(double d) {
        this.threshold = d;
    }

    /**
     * Sets the decision threshold from its textual form.
     *
     * @throws NumberFormatException if {@code str} is not a parsable double
     */
    public void setThreshold(String str) {
        setThreshold(Double.parseDouble(str));
    }

    /**
     * Fits the discretizer, discretizes the training data and learns the TAN structure and CPTs.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void train_(FeatureDataSet featureDataSet) {
        this.discretize.train(featureDataSet);
        FeatureDataSet process = this.discretize.process(featureDataSet);
        this.TrainMatrix = makeDataMatrix(process);
        this.attributeValues = makeAValueArray(process);
        this.bn = new BayesNet();
        trainTan(this.TrainMatrix, process.numFeatures() - 1, this.attributeValues);
    }

    /**
     * Returns the number of values of every feature, with input features first (in their
     * original order, output skipped) and the output feature in the last slot.
     *
     * <p>Input cardinalities are read from the features of example 0, so the data set must be
     * non-empty.
     */
    public static int[] makeAValueArray(FeatureDataSet featureDataSet) {
        int numFeatures = featureDataSet.numFeatures();
        int outputIndex = featureDataSet.getOutputIndex();
        int[] valueCounts = new int[numFeatures];
        int column = 0;
        for (int f = 0; f < numFeatures; f++) {
            if (f != outputIndex) {
                valueCounts[column++] = featureDataSet.get(0, f).getFeatureId().numValues();
            }
        }
        valueCounts[numFeatures - 1] = featureDataSet.getOutputFeatureId().numValues();
        return valueCounts;
    }

    /**
     * Converts one example into a 1-row value-id matrix: input features first (output skipped),
     * output value id in the last column.
     */
    public static int[][] makeDataMatrix(Example example) {
        int numFeatures = example.numFeatures();
        int outputIndex = example.getOutputIndex();
        int[][] matrix = new int[1][numFeatures];
        int column = 0;
        for (int f = 0; f < numFeatures; f++) {
            if (f != outputIndex) {
                matrix[0][column++] = example.get(f).getValueId();
            }
        }
        // The output goes in the last COLUMN of the row (matrix[0].length - 1); the previous
        // matrix.length - 1 was the row count minus one (0) and clobbered the first input.
        matrix[0][numFeatures - 1] = example.getOutputFeature().getValueId();
        return matrix;
    }

    /**
     * Converts a data set into a value-id matrix, one row per example, laid out like
     * {@link #makeDataMatrix(Example)}.
     */
    public static int[][] makeDataMatrix(FeatureDataSet featureDataSet) {
        int numFeatures = featureDataSet.numFeatures();
        int outputIndex = featureDataSet.getOutputIndex();
        int[][] matrix = new int[featureDataSet.size()][numFeatures];
        for (int row = 0; row < featureDataSet.size(); row++) {
            int column = 0;
            for (int f = 0; f < numFeatures; f++) {
                if (f != outputIndex) {
                    matrix[row][column++] = featureDataSet.get(row, f).getValueId();
                }
            }
            matrix[row][numFeatures - 1] = featureDataSet.get(row, outputIndex).getValueId();
        }
        return matrix;
    }

    /**
     * Builds the attribute descriptors, learns the TAN graph and fills {@link #netcpts}.
     *
     * @param iArr training value-id matrix
     * @param i column index of the output attribute
     * @param iArr2 number of values per column
     */
    public void trainTan(int[][] iArr, int i, int[] iArr2) {
        // Scratch arrays filled by setDegrees; the first is then passed to buildCPT
        // (presumably per-node in/out degrees of the graph — confirm against BayesNet).
        int[] iArr3 = new int[iArr2.length];
        int[] iArr4 = new int[iArr2.length];
        this.attributes = new Datapairs[iArr2.length];
        for (int i2 = 0; i2 < this.attributes.length; i2++) {
            if (i2 == i) {
                this.attributes[i2] = new Datapairs("c" + i2, iArr2[i2], "output");
            } else {
                this.attributes[i2] = new Datapairs("r" + i2, iArr2[i2], "discrete");
            }
        }
        int[][] maxTanGraph = MaxTan.maxTanGraph(iArr, i, this.attributes);
        this.bn.setDegrees(maxTanGraph, iArr3, iArr4);
        this.netcpts = this.bn.buildCPT(maxTanGraph, iArr, iArr3, this.attributes, i);
    }

    /**
     * Scores each row of {@code iArr} with the trained network, writing one score per row
     * into {@code dArr}.
     */
    public void classifyTan(int[][] iArr, int i, int[] iArr2, double[] dArr) {
        BayesNet.accuracyScore(iArr, this.netcpts, dArr, iArr2, this.attributes, i);
    }

    /**
     * Predicts the output of every example. Inputs are discretized with the trained
     * discretizer first, exactly as in {@link #train_} and {@link #getDistribution}.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature[] classify(FeatureDataSet featureDataSet) {
        // The network was trained on discretized value ids; scoring the raw set would be
        // inconsistent with training and with the other prediction entry points.
        FeatureDataSet process = this.discretize.process(featureDataSet);
        double[] scores = scoreProcessed(process);
        Feature[] predictions = new Feature[process.size()];
        for (int i = 0; i < process.size(); i++) {
            predictions[i] = (Feature) process.getOutputFeature(i).clone();
            applyThreshold(predictions[i], scores[i]);
        }
        return predictions;
    }

    /**
     * Returns, per example, {@code [1 - score, score]} in columns 0 and 1 of a row sized to the
     * output feature's number of values.
     */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public double[][] getDistribution(FeatureDataSet featureDataSet) {
        FeatureDataSet process = this.discretize.process(featureDataSet);
        double[] scores = scoreProcessed(process);
        double[][] distribution =
                new double[process.size()][featureDataSet.getOutputFeatureId().numValues()];
        for (int i = 0; i < process.size(); i++) {
            distribution[i][1] = scores[i];
            distribution[i][0] = 1.0d - scores[i];
        }
        return distribution;
    }

    /** Returns {@code [1 - score, score]} for a single example. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public double[] getDistribution(Example example) {
        Example process = this.discretize.process(example);
        int[][] matrix = makeDataMatrix(process);
        double[] score = new double[1];
        classifyTan(matrix, process.numFeatures() - 1,
                new int[] {process.getOutputFeature().getValueId()}, score);
        // Only score[0] exists; return the two-class distribution, matching the data-set overload.
        return new double[] {1.0d - score[0], score[0]};
    }

    /** Predicts the output of a single (discretized) example. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Feature classify_(Example example) {
        Example process = this.discretize.process(example);
        Feature feature = (Feature) process.getOutputFeature().clone();
        int[][] matrix = makeDataMatrix(process);
        double[] score = new double[1];
        classifyTan(matrix, process.numFeatures() - 1, new int[] {feature.getValueId()}, score);
        applyThreshold(feature, score[0]);
        return feature;
    }

    /** Not supported: always returns {@code null}. NOTE(review): callers must handle null. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public Classifier cloneClassifier() {
        return null;
    }

    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public String printClassifier() {
        return "TAN";
    }

    /** Ignored: this classifier exposes no indexed parameters. */
    @Override // edu.wisc.sjm.machlearn.classifiers.Classifier
    public void setParameter(int i, Object obj) {
    }

    /** Scores every example of an already-discretized data set, one score per example. */
    private double[] scoreProcessed(FeatureDataSet process) {
        int n = process.size();
        int[] labels = new int[n];
        for (int i = 0; i < n; i++) {
            labels[i] = process.getOutputFeature(i).getValueId();
        }
        double[] scores = new double[n];
        classifyTan(makeDataMatrix(process), process.numFeatures() - 1, labels, scores);
        return scores;
    }

    /** Sets {@code feature} to value id 1 if {@code score} exceeds the threshold, else 0. */
    private void applyThreshold(Feature feature, double score) {
        try {
            feature.setValueId(score > this.threshold ? 1 : 0);
        } catch (Exception e) {
            internalError(e);
        }
    }
}
