// Written by: Young Wu
// Attribution: Ainur Ainabekova's CS540 P3 Solution 2020

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Random;

public class P3_public {
	private static final char[] ALPHABET = " abcdefghijklmnopqrstuvwxyz".toCharArray(); // space + alphabet
	private static final int SENTENCE_LENGTH = 1000; // characters
	private static final int SMOOTHING = 0; // smoothing
	private static final double PRIOR = 0.5; // probability of fake script
	
	public static void main(String[] args) throws IOException {
		String script = new String(Files.readAllBytes(Paths.get("Interstellar.txt")));
		script = process_text(script);
		HashMap<String, Integer> unigram_count = ngram(script, 1);
		HashMap<String, Integer> bigram_count = ngram(script, 2);
		HashMap<String, Integer> trigram_count = ngram(script, 3);
		HashMap<String, Double> unigram = estimateTransitionProbabilities(script, 1, unigram_count, null);
		HashMap<String, Double> bigram = estimateTransitionProbabilities(script, 2, bigram_count, unigram_count);
		HashMap<String, Double> trigram = estimateTransitionProbabilities(script, 3, trigram_count, bigram_count);
		String[] sentences = generateSentences(trigram, bigram);
		HashMap<String, Double> posterior = computePosterior(unigram, unigram);
		int[] prediction = naive_bayes(sentences, posterior);
		for (int i = 0; i < prediction.length; i ++) System.out.print(prediction[i] + ", ");
	}

	public static String process_text(String script) {
		return script.toLowerCase().replaceAll("[^a-z ]", " ").replaceAll(" +", " ");
	}

	public static HashMap<String, Integer> ngram(String script, int n) {
		HashMap<String, Integer> counts = new HashMap<String, Integer>();
		int count = 0;
		if (n == 1) {
			for (int i = 0; i < ALPHABET.length; i++) {
				count = script.length() - script.replace(String.valueOf(ALPHABET[i]), "").length();
				counts.put(String.valueOf(ALPHABET[i]), count);
			}
		} else if (n == 2) {
			for (int i = 0; i < ALPHABET.length; i++) {
				for (int j = 0; j < ALPHABET.length; j++) {
					count = (script.length()
							- script.replace(String.valueOf(ALPHABET[i]) + String.valueOf(ALPHABET[j]), "").length())
							/ 2;
					counts.put(String.valueOf(ALPHABET[i]) + String.valueOf(ALPHABET[j]), count);
				}
			}
		} else if (n == 3) {
			for (int i = 0; i < ALPHABET.length; i++) {
				for (int j = 0; j < ALPHABET.length; j++) {
					for (int k = 0; k < ALPHABET.length; k++) {
						count = (script.length() - script.replace(
								String.valueOf(ALPHABET[i]) + String.valueOf(ALPHABET[j]) + String.valueOf(ALPHABET[k]),
								"").length()) / 3;
						counts.put(
								String.valueOf(ALPHABET[i]) + String.valueOf(ALPHABET[j]) + String.valueOf(ALPHABET[k]),
								count);
					}
				}
			}
		}
		return counts;
	}

	public static HashMap<String, Double> estimateTransitionProbabilities(String script, int n, HashMap<String, Integer> ngram, HashMap<String, Integer> n1gram) {
		HashMap<String, Double> probabilities = new HashMap<String, Double>();
		double probability = 0;
		int count = 0;
		for (String key : ngram.keySet()) {
			count = (n == 1 ? script.length() : n1gram.get(key.substring(0, n - 1)));
			probability = (ngram.get(key) + SMOOTHING) / (double) (count + SMOOTHING * ALPHABET.length);
			probabilities.put(key, probability);
		}
		return probabilities;
	}

	public static String[] generateSentences(HashMap<String, Double> trigram, HashMap<String, Double> bigram) {
		String[] sentences = new String[ALPHABET.length - 1];
		for (int i = 1; i < ALPHABET.length; i++) {
			StringBuilder sb = new StringBuilder();
			sb.append(String.valueOf(ALPHABET[i]));
			while (sb.length() < SENTENCE_LENGTH) {
				double[] cdf;
				if (sb.length() == 1
						|| bigram.get(sb.toString().substring(sb.length() - 2, sb.length())) == 0.0) {
					cdf = computeCDF(String.valueOf(sb.toString().charAt(0)), bigram);
				} else 
					cdf = computeCDF(sb.toString().substring(sb.length() - 2, sb.length()), trigram);
				Random r = new Random();
				double randomValue = cdf[cdf.length - 1] * r.nextDouble();
				char letter = findNextLetter(randomValue, cdf);
				sb.append(letter);
			}
			sentences[i - 1] = sb.toString();
		}
		return sentences;
	}

	public static double[] computeCDF(String c, Map<String, Double> transition) {
		double[] cdf = new double[27];
		double sum = 0;
		for (int i = 0; i < 27; i++) {
			sum = sum + transition.get(c + String.valueOf(ALPHABET[i]));
			cdf[i] = sum;
		}
		return cdf;
	}

	public static char findNextLetter(double u, double[] cdf) {
		for (int i = 0; i < ALPHABET.length; i++) {
			if (u <= cdf[i]) {
				return ALPHABET[i];
			}
		}
		return ALPHABET[ALPHABET.length - 1];
	}
	
	public static HashMap<String, Double> computePosterior(HashMap<String, Double> unigram_real, HashMap<String, Double> unigram_fake) {
		HashMap<String, Double> probabilities = new HashMap<String, Double>();
		double probability = 0;
		for (String key: unigram_real.keySet()) {
			probability = PRIOR * unigram_fake.get(key) / (PRIOR * unigram_fake.get(key) + (1 - PRIOR) * unigram_real.get(key));
			probabilities.put(key, probability);
		}
		return probabilities;
	}
	
	public static int[] naive_bayes(String[] sentences, HashMap<String, Double> posterior) {
		int[] prediction = new int[sentences.length];
		double probability = 0;
		double log_prob_real = 0;
		double log_prob_fake = 0;
		for (int i = 0; i < sentences.length; i ++) {
			log_prob_real = 0;
			log_prob_fake = 0;
			for (int j = 0; j < sentences[i].length(); j ++) {
				probability = posterior.get(String.valueOf(sentences[i].charAt(j)));
				log_prob_real += Math.log(1 - probability);
				log_prob_fake += Math.log(probability);
			}
			if (log_prob_real < log_prob_fake) prediction[i] = 1;
		}
		return prediction;
	}
	
	public static void print_probability(double[] prob) {
		double max = 0;
		double min = 1;
		int imax = 0;
		int imin = 0;
		double p = 0;
		double sum = 0;
		double scale = Math.pow(10, 4);
		for (int i = 0; i < prob.length; i ++) {
			p = Math.round(prob[i] * scale) / scale;
			if (p > max) {
				max = p;
				imax = i;
			}
			if (p < min && p > 0) {
				min = p;
				imin = i;
			}
			sum += p;
			prob[i] = p;
		}
		if (sum > 1) prob[imax] -= (sum - 1);
		if (sum < 1) prob[imin] += (1 - sum);
		System.out.printf(", %.4f", prob[0]);
		for (int i = 1; i < prob.length; i ++)
			System.out.printf(", %.4f", prob[i]);
	}
}