// === CS400 Spring 2024 File Header Information ===
// Name: Tony Zhao
// Email: gzhao46@wisc.edu
// Lecturer: Gary Dahl
// Notes to Grader: N/A


import org.junit.Test;
import org.junit.jupiter.api.Assertions;

import java.util.*;

/**
 * This class extends the BaseGraph data structure with additional methods for
 * computing the total cost and list of node data along the shortest path
 * connecting a provided starting node to a provided ending node. This class
 * makes use of Dijkstra's shortest path algorithm, which only guarantees
 * shortest paths when every edge weight is non-negative.
 */
public class DijkstraGraph<NodeType, EdgeType extends Number>
        extends BaseGraph<NodeType, EdgeType>
        implements GraphADT<NodeType, EdgeType> {

    /**
     * While searching for the shortest path between two nodes, a SearchNode
     * contains data about one specific path between the start node and another
     * node in the graph. The final node in this path is stored in its node
     * field. The total cost of this path is stored in its cost field. And the
     * predecessor SearchNode within this path is referenced by the predecessor
     * field (this field is null within the SearchNode containing the starting
     * node in its node field).
     *
     * SearchNodes are Comparable and are sorted by cost so that the lowest cost
     * SearchNode has the highest priority within a java.util.PriorityQueue.
     */
    protected class SearchNode implements Comparable<SearchNode> {
        public Node node;
        public double cost;
        public SearchNode predecessor;

        public SearchNode(Node node, double cost, SearchNode predecessor) {
            this.node = node;
            this.cost = cost;
            this.predecessor = predecessor;
        }

        /**
         * Orders SearchNodes by ascending path cost.
         *
         * @param other the SearchNode to compare against
         * @return a negative value, zero, or a positive value when this path's
         *         cost is less than, equal to, or greater than other's cost
         */
        @Override
        public int compareTo(SearchNode other) {
            return Double.compare(cost, other.cost);
        }
    }

    /**
     * Creates an empty graph whose nodes are stored in a PlaceholderMap.
     */
    public DijkstraGraph() {
        super(new PlaceholderMap<>());
    }

    /**
     * This helper method creates a network of SearchNodes while computing the
     * shortest path between the provided start and end locations. The
     * SearchNode that is returned by this method represents the end of the
     * shortest path that is found: its cost is the cost of that shortest path,
     * and the nodes linked together through predecessor references represent
     * all of the nodes along that shortest path (ordered from end to start).
     *
     * The priority queue may hold several SearchNodes for the same graph node
     * (one per improvement found). Only the first, cheapest one polled for a
     * node is used; later ones are stale and skipped. This avoids the
     * linear-time PriorityQueue.remove(Object) call a decrease-key would need.
     *
     * @param start the data item in the starting node for the path
     * @param end   the data item in the destination node for the path
     * @return SearchNode for the final end node within the shortest path
     * @throws NoSuchElementException when no path from start to end is found
     *                                or when either start or end data do not
     *                                correspond to a graph node
     */
    protected SearchNode computeShortestPath(NodeType start, NodeType end) {
        if (!nodes.containsKey(start) || !nodes.containsKey(end)) {
            throw new NoSuchElementException(
                    "start or end data does not correspond to a graph node");
        }
        // frontier of candidate paths, cheapest first
        PriorityQueue<SearchNode> queue = new PriorityQueue<>();
        // nodes whose shortest path is already final, mapped to that path
        PlaceholderMap<NodeType, SearchNode> visited = new PlaceholderMap<>();
        queue.add(new SearchNode(nodes.get(start), 0, null));

        while (!queue.isEmpty()) {
            SearchNode current = queue.poll();
            NodeType data = current.node.data;
            // a cheaper path to this node was already finalized; discard
            if (visited.containsKey(data)) {
                continue;
            }
            visited.put(data, current);
            if (data.equals(end)) {
                return current;
            }
            // relax every edge leading to a node that is not yet finalized
            for (Edge edge : current.node.edgesLeaving) {
                Node successor = edge.successor;
                if (!visited.containsKey(successor.data)) {
                    double cost = current.cost + edge.data.doubleValue();
                    queue.add(new SearchNode(successor, cost, current));
                }
            }
        }
        throw new NoSuchElementException(
                "no path from " + start + " to " + end);
    }

    /**
     * Returns the list of data values from nodes along the shortest path
     * from the node with the provided start value through the node with the
     * provided end value. This list of data values starts with the start
     * value, ends with the end value, and contains intermediary values in the
     * order they are encountered while traversing this shortest path. This
     * method uses Dijkstra's shortest path algorithm to find this solution.
     *
     * @param start the data item in the starting node for the path
     * @param end   the data item in the destination node for the path
     * @return list of data items from nodes along this shortest path
     * @throws NoSuchElementException when no path from start to end is found
     *                                or when either start or end data do not
     *                                correspond to a graph node
     */
    public List<NodeType> shortestPathData(NodeType start, NodeType end) {
        SearchNode endOfPath = computeShortestPath(start, end);
        // predecessor links run end -> start, so prepend each value to build
        // the list in start -> end order in one linear pass
        LinkedList<NodeType> result = new LinkedList<>();
        for (SearchNode step = endOfPath; step != null; step = step.predecessor) {
            result.addFirst(step.node.data);
        }
        return result;
    }

    /**
     * Returns the cost of the path (sum over edge weights) of the shortest
     * path from the node containing the start data to the node containing the
     * end data. This method uses Dijkstra's shortest path algorithm to find
     * this solution.
     *
     * @param start the data item in the starting node for the path
     * @param end   the data item in the destination node for the path
     * @return the cost of the shortest path between these nodes
     * @throws NoSuchElementException when no path from start to end is found
     *                                or when either start or end data do not
     *                                correspond to a graph node
     */
    public double shortestPathCost(NodeType start, NodeType end) {
        return computeShortestPath(start, end).cost;
    }

    // NOTE(review): @Test below is JUnit 4's org.junit.Test while assertions
    // come from JUnit 5's org.junit.jupiter.api.Assertions; confirm the test
    // runner includes the JUnit vintage engine, or switch to
    // org.junit.jupiter.api.Test.

    /**
     * Test correct shortest path between two nodes.
     */
    @Test
    public void testShortestPath() {
        DijkstraGraph<String, Integer> graph = new DijkstraGraph<>();
        graph.insertNode("A");
        graph.insertNode("B");
        graph.insertNode("C");
        graph.insertNode("D");
        graph.insertNode("E");
        graph.insertNode("F");
        graph.insertNode("G");
        graph.insertEdge("A", "C", 3);
        graph.insertEdge("A", "D", 5);
        graph.insertEdge("A", "E", 1);
        graph.insertEdge("B", "G", 4);
        graph.insertEdge("C", "D", 5);
        graph.insertEdge("D", "B", 4);
        graph.insertEdge("E", "F", 4);
        graph.insertEdge("F", "C", 3);
        graph.insertEdge("F", "G", 2);
        graph.insertEdge("G", "D", 2);
        List<String> expect = Arrays.asList("A", "E", "F", "G");
        List<String> actual = graph.shortestPathData("A", "G");
        Assertions.assertEquals(expect, actual);
    }

    /**
     * Test correct shortest path cost and sequence between two nodes.
     */
    @Test
    public void testCheckCostAndSequence() {
        DijkstraGraph<String, Integer> graph = new DijkstraGraph<>();
        graph.insertNode("A");
        graph.insertNode("B");
        graph.insertNode("C");
        graph.insertNode("D");
        graph.insertNode("E");
        graph.insertNode("F");
        graph.insertNode("G");
        graph.insertEdge("A", "C", 3);
        graph.insertEdge("A", "D", 5);
        graph.insertEdge("A", "E", 1);
        graph.insertEdge("B", "G", 4);
        graph.insertEdge("C", "D", 5);
        graph.insertEdge("D", "B", 4);
        graph.insertEdge("E", "F", 4);
        graph.insertEdge("F", "C", 3);
        graph.insertEdge("F", "G", 2);
        graph.insertEdge("G", "D", 2);
        double expectCost = 5;
        double actualCost = graph.shortestPathCost("A", "D");
        Assertions.assertEquals(expectCost, actualCost);
        List<String> expectPath = Arrays.asList("A", "D");
        List<String> actualPath = graph.shortestPathData("A", "D");
        Assertions.assertEquals(expectPath, actualPath);
    }

    /**
     * Test that requesting a path against the direction of the only edges
     * throws NoSuchElementException.
     */
    @Test
    public void testCheckNoPath() {
        DijkstraGraph<String, Integer> graph = new DijkstraGraph<>();
        graph.insertNode("A");
        graph.insertNode("B");
        graph.insertNode("C");
        graph.insertEdge("A", "B", 2);
        graph.insertEdge("B", "C", 3);
        Assertions.assertThrows(NoSuchElementException.class, () -> graph.shortestPathData("B", "A"));
    }

    /**
     * Test that a path from a node to itself has cost 0 and contains only that
     * node, and that requesting a path to a missing node throws.
     */
    @Test
    public void testTrivialPathAndMissingNode() {
        DijkstraGraph<String, Integer> graph = new DijkstraGraph<>();
        graph.insertNode("A");
        graph.insertNode("B");
        graph.insertEdge("A", "B", 2);
        Assertions.assertEquals(0.0, graph.shortestPathCost("A", "A"));
        Assertions.assertEquals(Arrays.asList("A"), graph.shortestPathData("A", "A"));
        Assertions.assertThrows(NoSuchElementException.class,
                () -> graph.shortestPathCost("A", "Z"));
    }
}
