package com.sjm.machlearn.classifiers;

import com.sjm.machlearn.dataset.DataSet;
import com.sjm.machlearn.dataset.Example;
import com.sjm.machlearn.dataset.Feature;
import com.sjm.machlearn.util.MainClass;
import com.sjm.machlearn.util.Util;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;

/* loaded from: input_file:com/sjm/machlearn/classifiers/Classifier.class */
/**
 * Base class for classifiers that are trained on and evaluated against {@link DataSet}s.
 *
 * <p>Subclasses provide training ({@link #train}), prediction ({@link #classify}),
 * parameter setting ({@link #setParameter}), copying ({@link #cloneClassifier}) and a
 * printable description ({@link #printClassifier}). This class adds shared evaluation
 * helpers: accuracy, extraction of mis-classified examples, simple parameter tuning and
 * ROC curve export.
 */
public abstract class Classifier extends MainClass {
    /** Validation mode for {@link #doTune}; the implementation in this class ignores it. */
    public static final int Fold10Validation = 0;
    /** Validation mode for {@link #doTune}: tune using one {@code splitRandom(10.0)} split. */
    public static final int Random10Validation = 1;
    /** Validation mode for {@link #doTune}; the implementation in this class ignores it. */
    public static final int JackKnifeValidation = 2;

    /**
     * Predicts the output feature for a single example.
     *
     * @param example the example to classify
     * @return the predicted output feature
     */
    public abstract Feature classify(Example example);

    /**
     * Returns a copy of this classifier. Whether trained state is copied is up to the
     * subclass.
     *
     * @return a new classifier instance
     */
    public abstract Classifier cloneClassifier();

    /**
     * Tunes parameter 0 over a list of candidate values, then retrains on the whole data
     * set using the candidate that scored best.
     *
     * <p>Only {@link #Random10Validation} is implemented here. The data set is split with
     * {@code dataSet.splitRandom(10.0)}. For each candidate, the classifier is trained on
     * element 1 of the split and scored with {@link #getAccuracy} on element 0. Any other
     * mode returns immediately without training.
     *
     * @param dataSet the data used for tuning and for the final training run
     * @param objArr  tuning arguments; element 0 must be an {@code int[]} of candidate values
     * @param i       validation mode, one of the {@code *Validation} constants
     * @throws ClassCastException       if {@code objArr[0]} is not an {@code int[]}
     * @throws IllegalArgumentException if random validation is requested with no candidates
     * @throws Exception                if training fails
     */
    public void doTune(DataSet dataSet, Object[] objArr, int i) throws Exception {
        int[] candidates = (int[]) objArr[0];
        // NOTE(review): Fold10Validation / JackKnifeValidation are accepted but not implemented.
        if (i != Random10Validation) {
            return;
        }
        if (candidates.length == 0) {
            throw new IllegalArgumentException("doTune: no candidate parameter values supplied");
        }
        DataSet[] split = dataSet.splitRandom(10.0d);
        DataSet trainPart = split[1];
        DataSet testPart = split[0];
        double[] accuracies = new double[candidates.length];
        for (int k = 0; k < candidates.length; k++) {
            setParameter(0, Integer.valueOf(candidates[k]));
            train(trainPart);
            accuracies[k] = getAccuracy(testPart);
        }
        setParameter(0, Integer.valueOf(candidates[Util.argmax(accuracies)]));
        train(dataSet);
    }

    /**
     * Trains on {@code dataSet}, scores every example of {@code dataSet2} with
     * {@link #getExampleWeight}, and writes ROC points to the file {@code str}.
     *
     * <p>Examples whose output feature has value id 0 are counted as positives. Examples
     * are sorted by ascending weight, and those at or above a threshold position are
     * counted as predicted positive. A threshold is placed halfway between each pair of
     * adjacent sorted examples whose labels differ. For each threshold, one tab-separated
     * line {@code threshold tprate fprate} is written. These lines follow a header line
     * and a {@code 0 0 0} origin line. The threshold column is a position in the sorted
     * order, not a weight value. Rates are fractions in [0, 1].
     *
     * @param dataSet  training data
     * @param dataSet2 evaluation data (may be empty, in which case only the header and
     *                 origin lines are written)
     * @param str      path of the output file; an existing file is overwritten
     * @throws IOException if the output file cannot be opened or written
     * @throws Exception   if training fails
     */
    public void generateROCCurve(DataSet dataSet, DataSet dataSet2, String str) throws Exception {
        int n = dataSet2.size();
        double[] weights = new double[n];
        boolean[] positive = new boolean[n];
        train(dataSet);
        for (int k = 0; k < n; k++) {
            Example example = dataSet2.get(k);
            weights[k] = getExampleWeight(example);
            positive[k] = example.getOutputFeature().getValueId() == 0;
        }
        sortByWeight(weights, positive);

        int totalPositive = 0;
        for (int k = 0; k < n; k++) {
            if (positive[k]) {
                totalPositive++;
            }
        }
        int totalNegative = n - totalPositive;

        PrintWriter out;
        try {
            out = new PrintWriter(new FileWriter(str));
        } catch (IOException e) {
            // Report through the shared handler, then fail loudly instead of hitting an NPE.
            internalError(e);
            throw e;
        }
        try {
            out.println("threshold\ttprate\tfprate");
            out.println("0\t0\t0");
            int positiveBelow = 0;
            int negativeBelow = 0;
            for (int k = 0; k < n; k++) {
                if (k > 0 && positive[k] != positive[k - 1]) {
                    // Sorted indices >= k are predicted positive. A label change means both
                    // classes are present, so neither denominator below can be zero.
                    double threshold = k - 0.5d;
                    double tpRate = (double) (totalPositive - positiveBelow) / totalPositive;
                    double fpRate = (double) (totalNegative - negativeBelow) / totalNegative;
                    out.println(String.valueOf(threshold) + "\t" + tpRate + "\t" + fpRate);
                }
                if (positive[k]) {
                    positiveBelow++;
                } else {
                    negativeBelow++;
                }
            }
            // PrintWriter swallows I/O errors; surface them to the caller.
            if (out.checkError()) {
                throw new IOException("error writing ROC curve to " + str);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Sorts {@code weights} ascending in place and applies the same swaps to
     * {@code labels}. This keeps the original exchange sort so the relative order of
     * tied weights, and hence the ROC output, is unchanged.
     * NOTE(review): O(n^2); consider an index sort if evaluation sets get large.
     */
    private static void sortByWeight(double[] weights, boolean[] labels) {
        for (int a = 0; a < weights.length - 1; a++) {
            for (int b = a + 1; b < weights.length; b++) {
                if (weights[a] > weights[b]) {
                    Util.swap(weights, a, b);
                    Util.swap(labels, a, b);
                }
            }
        }
    }

    /**
     * Returns the percentage (0-100) of examples in {@code dataSet} whose output is
     * predicted correctly by {@link #classify}.
     *
     * @param dataSet the examples to evaluate
     * @return accuracy as a percentage
     * @throws IllegalArgumentException if {@code dataSet} is empty
     */
    public double getAccuracy(DataSet dataSet) {
        int size = dataSet.size();
        if (size == 0) {
            throw new IllegalArgumentException("getAccuracy: data set is empty");
        }
        int correct = 0;
        for (int k = 0; k < size; k++) {
            Example example = dataSet.get(k);
            if (example.correctOutput(classify(example))) {
                correct++;
            }
        }
        // Floating-point division; the previous int division yielded only 0 or 100.
        return 100.0d * correct / size;
    }

    /**
     * Classifies each example once and partitions the examples by correctness.
     *
     * @param dataSet the examples to evaluate
     * @return a two-element array: index 0 holds mis-classified examples and index 1 holds
     *         correctly classified ones, each in data-set order
     */
    public Example[][] getBothCat(DataSet dataSet) {
        List<Example> wrong = new ArrayList<Example>();
        List<Example> right = new ArrayList<Example>();
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = dataSet.get(k);
            if (example.correctOutput(classify(example))) {
                right.add(example);
            } else {
                wrong.add(example);
            }
        }
        return new Example[][] {
            wrong.toArray(new Example[wrong.size()]),
            right.toArray(new Example[right.size()])
        };
    }

    /**
     * Returns a per-example score used by {@link #generateROCCurve} to rank examples.
     * This default implementation returns 0.0 for every example.
     *
     * @param example the example to score
     * @return the example's score
     */
    public double getExampleWeight(Example example) {
        return 0.0d;
    }

    /**
     * Returns the examples of {@code dataSet} that {@link #classify} gets wrong.
     *
     * @param dataSet the examples to evaluate
     * @return the mis-classified examples in data-set order (possibly empty, never null)
     */
    public Example[] getMisCat(DataSet dataSet) {
        List<Example> wrong = new ArrayList<Example>();
        for (int k = 0; k < dataSet.size(); k++) {
            Example example = dataSet.get(k);
            if (!example.correctOutput(classify(example))) {
                wrong.add(example);
            }
        }
        return wrong.toArray(new Example[wrong.size()]);
    }

    /**
     * Returns a printable description of this classifier.
     *
     * @return the description
     */
    public abstract String printClassifier();

    /**
     * Sets a classifier-specific parameter.
     *
     * @param i   parameter index; {@link #doTune} tunes index 0
     * @param obj the new value ({@link #doTune} passes an {@link Integer})
     */
    public abstract void setParameter(int i, Object obj);

    /**
     * Trains this classifier on the given data.
     *
     * @param dataSet training data
     * @throws Exception if training fails
     */
    public abstract void train(DataSet dataSet) throws Exception;
}
