// Written by: Young Wu
// Attribution: Ainur Ainabekova's CS540 P2 Solution 2020

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class P2_public {
	// NOTE(review): featureList contains 10. informationGain/buildTree treat the LAST column
	// (data[data.length - 1]) as the class label, so if each row has 11 columns, index 10 is
	// the label itself. If these indices are meant to be 1-based, they are off by one —
	// confirm against the assignment spec.
	/** Column indices buildTree may split on (used directly as 0-based indices into a row). */
	public static int[] featureList = new int[] { 5, 10, 4, 6, 2, 8 };
	/** Candidate thresholds; a split sends rows with value <= threshold to the left child. */
	public static int[] thresholdList = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
	// NOTE(review): not read anywhere in this file — confirm whether it is still needed.
	public static int stumpFeature = 9;
	/** Nodes deeper than this (root is depth 0) are forced to be leaves. */
	public static int targetDepth = Integer.MAX_VALUE;

	/** Natural log of 2, hoisted so entropy() does not recompute it on every call. */
	private static final double LN2 = Math.log(2);

	/** Reads the training data, builds a tree, and prints its maximum depth (counted in nodes). */
	public static void main(String[] args) {
		int[][] trainData = cleanAndTransformData("breast-cancer-wisconsin.data");
		DecTreeNode root = buildTree(trainData);
		System.out.println(root.maxDepth());
	}

	/**
	 * Reads a comma-separated integer file, dropping blank lines and any line containing '?'
	 * (a missing value).
	 *
	 * @param file path of the data file
	 * @return one int[] per kept line, in file order; an empty array if the file cannot be opened
	 */
	public static int[][] cleanAndTransformData(String file) {
		List<int[]> rows = new ArrayList<>();
		// Single pass with try-with-resources: the Scanner is closed even if parsing throws.
		try (Scanner scan = new Scanner(new File(file))) {
			while (scan.hasNextLine()) {
				String line = scan.nextLine().trim();
				// Skip missing-value rows and blank lines (e.g. a trailing newline),
				// which would otherwise fail in Integer.parseInt.
				if (line.isEmpty() || line.indexOf('?') != -1)
					continue;
				String[] split = line.split(",");
				int[] instance = new int[split.length];
				for (int j = 0; j < split.length; j++)
					instance[j] = Integer.parseInt(split[j].trim());
				rows.add(instance);
			}
		} catch (FileNotFoundException e) {
			System.out.println("File with the name " + file + " cannot be read!");
		}
		return rows.toArray(new int[0][]);
	}

	/**
	 * Binary entropy in bits.
	 *
	 * @param p0 probability of one of the two outcomes
	 * @return -(p0*log2(p0) + p1*log2(p1)) with p1 = 1 - p0; 0 when p0 is exactly 0 or 1
	 */
	public static double entropy(double p0) {
		if (p0 == 0 || p0 == 1)
			return 0;
		double p1 = 1 - p0;
		return -(p0 * Math.log(p0) / LN2 + p1 * Math.log(p1) / LN2);
	}

	/**
	 * Information gain H(Y) - H(Y | X <= threshold) of splitting on one column.
	 * The label is the last column of each row, and label value 2 is the counted class.
	 *
	 * @param dataSet   rows to evaluate
	 * @param feature   column index to split on
	 * @param threshold rows with data[feature] <= threshold go to the "less" side
	 * @return the gain in bits; 0 for an empty data set
	 */
	public static double informationGain(int[][] dataSet, int feature, int threshold) {
		int dataSize = dataSet.length;
		// An empty set has nothing to gain; this also avoids 0/0 = NaN below.
		if (dataSize == 0)
			return 0;
		int countLess = 0;
		int countGreater = 0;
		int countLessAndPositive = 0;
		int countGreaterAndPositive = 0;
		for (int[] data : dataSet) {
			boolean positive = data[data.length - 1] == 2;
			if (data[feature] <= threshold) {
				countLess++;
				if (positive)
					countLessAndPositive++;
			} else {
				countGreater++;
				if (positive)
					countGreaterAndPositive++;
			}
		}
		// H(Y)
		int count = countLessAndPositive + countGreaterAndPositive;
		double Hy = entropy(1.0 * count / dataSize);
		// H(Y|X): weighted entropy of each non-empty side
		double Hyx = 0;
		double prob1 = 1.0 * countLess / dataSize;
		double prob2 = 1.0 * countGreater / dataSize;
		if (prob1 > 0)
			Hyx = Hyx + prob1 * entropy((1.0 * countLessAndPositive) / countLess);
		if (prob2 > 0)
			Hyx = Hyx + prob2 * entropy((1.0 * countGreaterAndPositive) / countGreater);
		// InfoGain(Y|X)
		return Hy - Hyx;
	}

	/** Builds a decision tree over the whole data set, starting at depth 0. */
	private static DecTreeNode buildTree(int[][] dataSet) {
		return buildTree(dataSet, 0);
	}

	/**
	 * Recursively builds a tree by greedily choosing the (feature, threshold) pair with the
	 * highest information gain.
	 *
	 * @param dataSet rows reaching this node; the last column is the label (2 or 4)
	 * @param depth   depth of this node (root = 0)
	 * @return an internal node with both children set, or a leaf labelled with the
	 *         majority class (ties go to 2)
	 */
	private static DecTreeNode buildTree(int[][] dataSet, int depth) {
		int numData = dataSet.length;
		int bestAttr = -1;
		int bestThres = Integer.MIN_VALUE;
		double bestScore = Double.NEGATIVE_INFINITY;
		// NOTE(review): with targetDepth = d, a node at depth d may still split, so leaves can
		// appear at depth d + 1 — confirm this matches the intended depth limit.
		boolean leaf = depth > targetDepth;
		DecTreeNode node = null;
		if (!leaf) {
			// Exhaustive search; ties keep the first pair in list order (strict '>').
			for (int feature : featureList) {
				for (int threshold : thresholdList) {
					double score = informationGain(dataSet, feature, threshold);
					if (score > bestScore) {
						bestScore = score;
						bestAttr = feature;
						bestThres = threshold;
					}
				}
			}
			// No split reduces entropy: stop here.
			if (bestScore <= 0)
				leaf = true;
		}
		if (!leaf) {
			int nLeft = 0;
			for (int[] data : dataSet) {
				if (data[bestAttr] <= bestThres)
					nLeft++;
			}
			int nRight = numData - nLeft;
			if (nLeft == 0 || nRight == 0) {
				leaf = true;
			} else {
				// Rows are shared with dataSet, so allocate only the outer arrays.
				int[][] leftList = new int[nLeft][];
				int[][] rightList = new int[nRight][];
				int left = 0;
				int right = 0;
				for (int[] data : dataSet) {
					if (data[bestAttr] <= bestThres)
						leftList[left++] = data;
					else
						rightList[right++] = data;
				}
				node = new DecTreeNode(-1, bestAttr, bestThres);
				node.left = buildTree(leftList, depth + 1);
				node.right = buildTree(rightList, depth + 1);
			}
		}
		// Create label: majority class, ties go to 2.
		if (leaf) {
			int count = 0;
			for (int[] data : dataSet) {
				if (data[data.length - 1] == 2)
					count++;
			}
			node = (count >= numData - count) ? new DecTreeNode(2, -1, -1) : new DecTreeNode(4, -1, -1);
		}
		return node;
	}
}

/**
 * A binary decision-tree node. A node with no children is a leaf and answers with
 * {@code classLabel}. Otherwise it routes a row to {@code left} when
 * {@code data[feature] <= threshold} and to {@code right} otherwise.
 */
class DecTreeNode {
	/** Label reported by {@link #predict} when this node is a leaf. */
	public int classLabel;
	/** Column index tested by an internal node. */
	public int feature;
	/** Split point: rows with {@code data[feature] <= threshold} go left. */
	public int threshold;
	/** Subtree for rows with {@code data[feature] <= threshold}. */
	public DecTreeNode left;
	/** Subtree for rows with {@code data[feature] > threshold}. */
	public DecTreeNode right;

	public DecTreeNode(int classLabel, int feature, int threshold) {
		this.feature = feature;
		this.threshold = threshold;
		this.classLabel = classLabel;
	}

	/** @return true when this node has neither a left nor a right child */
	public boolean isLeaf() {
		return !(left != null || right != null);
	}

	/**
	 * Number of nodes on the longest root-to-leaf path; a lone leaf has depth 1.
	 * Internal nodes are expected to have both children set.
	 */
	public int maxDepth() {
		return isLeaf() ? 1 : 1 + Math.max(left.maxDepth(), right.maxDepth());
	}

	/**
	 * Walks the tree for one row and returns the label of the leaf it reaches.
	 *
	 * @param data the row; indexed by each visited node's {@code feature}
	 */
	public int predict(int[] data) {
		if (isLeaf())
			return classLabel;
		DecTreeNode next = (data[feature] <= threshold) ? left : right;
		return next.predict(data);
	}

	/** Prints the tree to stdout in pre-order as if/else/return lines (no indentation). */
	public void print() {
		if (isLeaf()) {
			System.out.println("return " + classLabel);
			return;
		}
		System.out.println("if (x" + feature + ") <= " + threshold);
		left.print();
		System.out.println("else");
		right.print();
	}
}