// Written by: Young Wu
// Attribution: Ainur Ainabekova's CS540 P4 Solution 2020

import java.io.File;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Scanner;

/**
 * Clusters time series (one per CSV row) by a handful of summary statistics.
 *
 * <p>{@link #main} reads {@code inflation.csv}, computes {@code M} features per
 * row, rescales every feature to [0, 1] and prints the k-means cluster index
 * of each row. All methods are static and operate on plain arrays.
 */
public class P4_public {
	private static final int K = 8; // Number of clusters
	private static final int M = 5; // Number of features (parameters) computed per row
	private static final int dataColumn = 5; // Number of leading columns to omit
	private static final int dataRow = 4; // Number of leading rows to omit (before the header row)
	private static final int numCountries = 50; // Maximum number of rows (countries) to keep
	private static final int numMissing = 3; // A row is kept only if it has fewer missing values than this
	private static final int maxIter = 10; // Maximum number of k-means iterations

	/**
	 * Reads the data, builds the feature matrix and prints one cluster index per row.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		double[][] data = preprocessData("inflation.csv");
		double[][] param = new double[data.length][M];
		for (int i = 0; i < data.length; i++) {
			// One distinct statistic per feature column; filling every column
			// with the mean would make the five features redundant.
			param[i][0] = getMu(data[i]);
			param[i][1] = getSigma(data[i]);
			param[i][2] = getMedian(data[i]);
			param[i][3] = getBeta(data[i]);
			param[i][4] = getRho(data[i]);
		}
		param = rescale(param, 0, 1);
		int[] cluster = kMeansCluster(param);
		for (int i = 0; i < param.length; i++)
			System.out.println(cluster[i]);
	}

	/**
	 * Agglomerative clustering that repeatedly merges the two closest clusters,
	 * taking the minimum point-to-point distance as the cluster distance,
	 * until {@code K} clusters remain.
	 *
	 * @param x data points, one per row
	 * @return for each row of {@code x}, a cluster index; indices are consecutive
	 *         starting at 0 and ordered by the smallest row index in each cluster.
	 *         When {@code x} has at most {@code K} rows, every row is its own cluster.
	 */
	public static int[] hierachicalCluster(double[][] x) {
		int n = x.length;
		double[][] d = pairDistance(x);
		double max = 0;
		for (int i = 0; i < n; i++) {
			for (int j = i + 1; j < n; j++)
				max = Math.max(max, d[i][j]);
		}
		int[] cluster = new int[n];
		for (int i = 0; i < n; i++)
			cluster[i] = i;
		for (int iter = 0; iter < n - K; iter++) {
			// Find the closest pair of live clusters; a distance of -1 marks a
			// row/column that has already been merged into another cluster.
			double min = max + 1;
			int argmin1 = -1;
			int argmin2 = -1;
			for (int i = 0; i < n; i++) {
				for (int j = i + 1; j < n; j++) {
					if (d[i][j] >= 0 && d[i][j] < min) {
						min = d[i][j];
						argmin1 = i;
						argmin2 = j;
					}
				}
			}
			// No mergeable pair left (e.g. every remaining distance is NaN).
			if (argmin1 < 0)
				break;
			// Fold cluster argmin2 into argmin1: keep the smaller distance to
			// every other cluster and retire row/column argmin2.
			for (int i = 0; i < n; i++) {
				d[argmin1][i] = Math.min(d[argmin1][i], d[argmin2][i]);
				d[argmin2][i] = -1;
				d[i][argmin1] = d[argmin1][i];
				d[i][argmin2] = d[argmin2][i];
				d[argmin1][argmin2] = -1;
				d[i][i] = 0;
			}
			int c1 = cluster[argmin1];
			int c2 = cluster[argmin2];
			for (int i = 0; i < n; i++) {
				if (cluster[i] == c2)
					cluster[i] = c1;
			}
		}
		// Compress the surviving labels to 0, 1, 2, ... in ascending order.
		HashSet<Integer> labels = new HashSet<Integer>();
		for (int i = 0; i < cluster.length; i++)
			labels.add(cluster[i]);
		// Size the array from the set itself: there may be fewer (or more)
		// than K distinct labels, and padding with null would break the sort.
		Integer[] order = labels.toArray(new Integer[0]);
		Arrays.sort(order);
		for (int i = 0; i < cluster.length; i++) {
			for (int j = 0; j < order.length; j++) {
				if (cluster[i] == order[j]) {
					cluster[i] = j;
					break;
				}
			}
		}
		return cluster;
	}

	/**
	 * K-means clustering with {@code K} centers initialised to randomly chosen
	 * data points. Runs at most {@code maxIter} iterations and stops early once
	 * the assignment no longer changes.
	 *
	 * @param x data points, one per row; all rows are expected to have the
	 *          same length as the first row
	 * @return for each row of {@code x}, the index (0 to K-1) of its nearest
	 *         center; an empty array when {@code x} is empty
	 */
	public static int[] kMeansCluster(double[][] x) {
		int n = x.length;
		int[] cluster = new int[n];
		if (n == 0)
			return cluster;
		int dim = x[0].length;
		double[][] center = new double[K][];
		for (int i = 0; i < K; i++) {
			// Pick ONE random point per center so that every initial center is
			// an actual data point rather than a mix of coordinates.
			center[i] = Arrays.copyOf(x[(int) (Math.random() * n)], dim);
		}
		double[][] sum = new double[K][dim];
		int[] count = new int[K];
		// You can change this and use another stopping criterion
		for (int iter = 0; iter < maxIter; iter++) {
			// Assignment step: nearest center, lowest index wins ties.
			boolean changed = false;
			for (int i = 0; i < n; i++) {
				double min = distance(x[i], center[0]);
				int argmin = 0;
				for (int j = 1; j < K; j++) {
					double dist = distance(x[i], center[j]);
					if (dist < min) {
						min = dist;
						argmin = j;
					}
				}
				if (cluster[i] != argmin) {
					cluster[i] = argmin;
					changed = true;
				}
			}
			// Same assignment as the previous iteration means the centers
			// would not move either, so further iterations change nothing.
			if (iter > 0 && !changed)
				break;
			// Update step: each center becomes the mean of its members.
			for (int i = 0; i < K; i++) {
				count[i] = 0;
				Arrays.fill(sum[i], 0);
			}
			for (int i = 0; i < n; i++) {
				count[cluster[i]]++;
				for (int j = 0; j < dim; j++)
					sum[cluster[i]][j] += x[i][j];
			}
			for (int i = 0; i < K; i++) {
				// An empty cluster keeps its previous center; dividing by a
				// zero count would turn the center into NaN.
				if (count[i] == 0)
					continue;
				for (int j = 0; j < dim; j++)
					center[i][j] = sum[i][j] / count[i];
			}
		}
		return cluster;
	}

	/**
	 * Computes the symmetric matrix of pairwise {@link #distance} values.
	 *
	 * @param x data points, one per row
	 * @return an {@code x.length} by {@code x.length} matrix with zeros on the diagonal
	 */
	public static double[][] pairDistance(double[][] x) {
		double[][] d = new double[x.length][x.length];
		for (int i = 0; i < x.length; i++) {
			for (int j = i + 1; j < x.length; j++) {
				d[i][j] = distance(x[i], x[j]);
				d[j][i] = d[i][j];
			}
		}
		return d;
	}

	/**
	 * The main distance function: the SQUARED Euclidean distance (no square
	 * root is taken). Only the first {@code min(x.length, y.length)}
	 * coordinates are compared. Change the loop body to switch to another
	 * metric such as Manhattan.
	 *
	 * @param x first point
	 * @param y second point
	 * @return sum of squared coordinate differences
	 */
	public static double distance(double[] x, double[] y) {
		double d = 0;
		for (int i = 0; i < Math.min(x.length, y.length); i++)
			d += (x[i] - y[i]) * (x[i] - y[i]);
		return d;
	}

	/**
	 * Arithmetic mean.
	 *
	 * @param list values
	 * @return the mean, or NaN for an empty list
	 */
	public static double getMu(double[] list) {
		double mu = 0;
		for (int i = 0; i < list.length; i++)
			mu += list[i];
		return mu / list.length;
	}

	/**
	 * Population standard deviation (divides by {@code list.length}).
	 *
	 * @param list values
	 * @return the standard deviation, or NaN for an empty list
	 */
	public static double getSigma(double[] list) {
		double mu = getMu(list);
		double sig = 0;
		for (int i = 0; i < list.length; i++)
			sig += (list[i] - mu) * (list[i] - mu);
		return Math.sqrt(sig / list.length);
	}

	/**
	 * Median of the values; the input array is not modified.
	 *
	 * @param list values
	 * @return the middle value for an odd count, the mean of the two middle
	 *         values for an even count, or NaN for an empty list
	 */
	public static double getMedian(double[] list) {
		int n = list.length;
		if (n == 0)
			return Double.NaN;
		double[] sorted = list.clone();
		Arrays.sort(sorted);
		int mid = n / 2;
		if (n % 2 == 1)
			return sorted[mid];
		return 0.5 * (sorted[mid - 1] + sorted[mid]);
	}

	/**
	 * Least-squares slope of the values regressed on their index 0..n-1.
	 *
	 * @param list values in index order
	 * @return the slope, or 0 when it is undefined (fewer than two values)
	 */
	public static double getBeta(double[] list) {
		double mu = getMu(list);
		double num = 0;
		double den = 0;
		double ht = 0.5 * (list.length + 1) - 1; // mean of the indices 0..n-1
		for (int i = 0; i < list.length; i++) {
			num += (list[i] - mu) * (i - ht);
			den += (i - ht) * (i - ht);
		}
		// Avoid 0/0 = NaN, which would poison rescaling and clustering.
		return den == 0 ? 0 : num / den;
	}

	/**
	 * Lag-1 autocorrelation: the sum of products of consecutive deviations
	 * from the mean divided by the sum of squared deviations.
	 *
	 * @param list values in index order
	 * @return the autocorrelation, or 0 when it is undefined (empty or constant list)
	 */
	public static double getRho(double[] list) {
		double mu = getMu(list);
		double num = 0;
		double den = 0;
		for (int i = 0; i < list.length; i++) {
			if (i > 0)
				num += (list[i] - mu) * (list[i - 1] - mu);
			den += (list[i] - mu) * (list[i] - mu);
		}
		// Avoid 0/0 = NaN, which would poison rescaling and clustering.
		return den == 0 ? 0 : num / den;
	}

	/**
	 * Linearly rescales every column so that its smallest value maps to
	 * {@code min} and its largest to {@code max}. The input is not modified.
	 *
	 * @param list matrix to rescale; the column count is taken from the first row
	 * @param min  target value for each column's minimum
	 * @param max  target value for each column's maximum
	 * @return a new matrix of the same shape; a column whose values are all
	 *         equal is mapped to {@code min}; an empty input yields an empty matrix
	 */
	public static double[][] rescale(double[][] list, double min, double max) {
		if (list.length == 0)
			return new double[0][0];
		int rows = list.length;
		int cols = list[0].length;
		double[][] scale = new double[rows][cols];
		for (int c = 0; c < cols; c++) {
			double lo = list[0][c];
			double hi = list[0][c];
			for (int r = 1; r < rows; r++) {
				lo = Math.min(lo, list[r][c]);
				hi = Math.max(hi, list[r][c]);
			}
			double range = hi - lo;
			for (int r = 0; r < rows; r++) {
				// A constant column has zero range; dividing would give NaN.
				scale[r][c] = range == 0 ? min : (list[r][c] - lo) / range * (max - min) + min;
			}
		}
		return scale;
	}

	/**
	 * Reads numeric rows from a comma-separated file.
	 *
	 * <p>The first {@code dataRow} lines are skipped and the next line is
	 * treated as the header, which fixes the number of columns. In every
	 * following line the first {@code dataColumn} columns are ignored; the
	 * remaining cells are parsed as doubles, with blank or absent cells stored
	 * as 0 and counted as missing. A row is kept only when it has fewer than
	 * {@code numMissing} missing cells, and at most {@code numCountries} rows
	 * are kept. Blank lines are skipped.
	 *
	 * @param file path of the CSV file
	 * @return the kept rows only (no padding rows); zero rows when the file is
	 *         missing or has no header line
	 * @throws NumberFormatException if a non-blank data cell is not a number
	 */
	public static double[][] preprocessData(String file) {
		double[][] globalData = new double[numCountries][];
		int kept = 0;
		try (Scanner sc = new Scanner(new File(file))) {
			for (int i = 0; i < dataRow && sc.hasNextLine(); i++)
				sc.nextLine();
			if (!sc.hasNextLine())
				return new double[0][0];
			String[] firstLineColumnNames = sc.nextLine().split(",");
			int numValues = Math.max(0, firstLineColumnNames.length - dataColumn);
			while (sc.hasNextLine() && kept < numCountries) {
				String raw = sc.nextLine();
				if (raw.trim().isEmpty())
					continue;
				// Limit -1 keeps trailing empty cells so that missing values at
				// the end of a line are still counted.
				String[] line = raw.split(",", -1);
				// Fresh row per line: a rejected line must not leak values
				// into the next accepted one.
				double[] row = new double[numValues];
				int missing = 0;
				for (int j = 0; j < numValues; j++) {
					int col = j + dataColumn; // skip the omitted leading columns
					String cell = col < line.length ? line[col].trim() : "";
					if (cell.isEmpty())
						missing++; // row[j] stays 0
					else
						row[j] = Double.parseDouble(cell);
				}
				if (missing < numMissing)
					globalData[kept++] = row;
			}
		} catch (FileNotFoundException e) {
			System.out.println("The input file cannot be found!");
		}
		// Return only the rows actually read instead of padding up to numCountries.
		return Arrays.copyOf(globalData, kept);
	}
}