package edu.umass.cs.mallet.base.classify.tests;

import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.ClassifierTrainer;
import edu.umass.cs.mallet.base.classify.DecisionTreeTrainer;
import edu.umass.cs.mallet.base.classify.MaxEntTrainer;
import edu.umass.cs.mallet.base.classify.NaiveBayesTrainer;
import edu.umass.cs.mallet.base.classify.Trial;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.Random;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * Smoke test for several MALLET classifier trainers.
 *
 * <p>Builds a synthetic {@link InstanceList} (200 instances over a 3-feature alphabet and
 * three class labels), splits it 50/50, trains a Naive Bayes, a MaxEnt and a decision-tree
 * classifier on the first half, and prints each classifier's accuracy on both halves.
 * Both random sources are seeded with fixed values.
 */
public class TestClassifiers extends TestCase {

    public TestClassifiers(String str) {
        super(str);
    }

    /**
     * Builds an alphabet by looking up the names {@code "feature0"} .. {@code "feature(i-1)"}
     * in index order.
     *
     * @param i number of feature names to look up; zero or negative yields no lookups
     * @return the newly created alphabet
     */
    private static Alphabet dictOfSize(int i) {
        Alphabet alphabet = new Alphabet();
        for (int k = 0; k < i; k++) {
            alphabet.lookupIndex("feature" + k);
        }
        return alphabet;
    }

    /**
     * Trains each classifier on the first half of a random data set and reports (and
     * sanity-checks) its accuracy on the training half and on the held-out half.
     */
    public void testRandomTrained() {
        ClassifierTrainer[] trainers = {
            new NaiveBayesTrainer(), new MaxEntTrainer(), new DecisionTreeTrainer()
        };

        // MALLET's own Random (seed 1) is passed to the InstanceList constructor that builds
        // the synthetic data; java.util.Random (seed 2) drives the 50/50 split.
        // NOTE(review): the exact generation scheme is defined by InstanceList — confirm there.
        InstanceList data = new InstanceList(
                new Random(1L), dictOfSize(3), new String[] {"class0", "class1", "class2"}, 200);
        InstanceList[] split = data.split(new java.util.Random(2L), new double[] {0.5d, 0.5d});
        InstanceList training = split[0];
        InstanceList testing = split[1];

        Classifier[] classifiers = new Classifier[trainers.length];
        for (int i = 0; i < trainers.length; i++) {
            classifiers[i] = trainers[i].train(training);
            assertNotNull("trainer returned no classifier: " + trainers[i].getClass().getName(),
                    classifiers[i]);
        }

        reportAccuracy("Accuracy on training set:", classifiers, training);
        reportAccuracy("Accuracy on testing set:", classifiers, testing);
    }

    /**
     * Prints {@code header}, then one {@code "<classifier class>: <accuracy>"} line per
     * classifier evaluated on {@code data}, asserting each accuracy lies in [0, 1].
     */
    private static void reportAccuracy(String header, Classifier[] classifiers, InstanceList data) {
        System.out.println(header);
        for (int i = 0; i < classifiers.length; i++) {
            String name = classifiers[i].getClass().getName();
            double accuracy = new Trial(classifiers[i], data).accuracy();
            System.out.println(name + ": " + accuracy);
            // Accuracy is a fraction of correctly classified instances; anything outside
            // [0, 1] (or NaN) indicates a broken evaluation.
            assertTrue(name + " accuracy out of range: " + accuracy,
                    accuracy >= 0.0 && accuracy <= 1.0);
        }
    }

    /**
     * Returns a suite containing every {@code test*} method of this class.
     *
     * <p>The class literal replaces the decompiler-synthesized {@code class$} lookup, whose
     * {@code throw ... initCause(e)} did not compile (it throws a plain {@code Throwable}).
     */
    public static Test suite() {
        return new TestSuite(TestClassifiers.class);
    }

    /** No per-test fixture is needed; kept so subclasses can call {@code super.setUp()}. */
    protected void setUp() {
    }

    /** Runs this test suite with the JUnit text runner. */
    public static void main(String[] strArr) {
        TestRunner.run(suite());
    }
}
