import java.io.*;
import java.net.*;
import java.util.*;

import javax.net.ssl.*;
import java.security.cert.*;

import org.w3c.dom.*;

/**
 * Arithmetic server.
 *
 * Listens on an SSL server socket and hands every accepted connection to an
 * XMLSocketThread.  Each document given to the connection's Receiver is
 * treated as an arithmetic expression: its value is computed and sent back
 * to the client as an "answer" document.
 */
public class Server {
    /**
     * Program entry point: creates a server and runs its accept loop.
     *
     * @param args command line arguments (unused)
     */
    public static void main(String[] args) {
        new Server().go();
    }

    /** The listening socket; assigned in go() and closed when go() ends. */
    SSLServerSocket ss;

    final int PORT = 3346;	// server port to listen on

    /**
     * Opens the listening socket on PORT and serves connections until an
     * exception occurs.  Any exception is reported on standard output and
     * ends the loop; the listening socket is closed before returning.
     */
    void go() {
        try {
            ss = (SSLServerSocket) SSLServerSocketFactory.getDefault().createServerSocket(PORT);
            // NOTE(review): this enables every suite the provider supports,
            // which may include weak or unauthenticated suites depending on
            // the JSSE provider in use -- consider restricting the list.
            ss.setEnabledCipherSuites(ss.getSupportedCipherSuites());
            // A client certificate is requested but not required; clients
            // without one are still accepted (see reportConnection).
            ss.setWantClientAuth(true);

            // main server loop
            // waits for a connection, creates a thread to handle the connection
            // and goes back to waiting
            while (true) {
                Socket sock = ss.accept();
                reportConnection(sock);

                XMLSocketThread xml = new XMLSocketThread(sock);
                xml.setDOMProcessor(new Receiver(xml));
                xml.start();
            }
        } catch (Exception ex) {
            System.out.println("Shutting down because of exception");
            ex.printStackTrace(System.out);
        } finally {
            closeServerSocket();
        }
    }

    /**
     * Prints the details of a freshly accepted connection: the socket, its
     * SSL session, the peer host and, when the client authenticated itself,
     * each certificate of its chain.
     *
     * NOTE(review): SSLSocket.getSession() performs the handshake if it has
     * not happened yet, so this runs on the accept thread and a slow client
     * delays the acceptance of further connections.
     */
    private void reportConnection(Socket sock) {
        System.out.println("Accepted connection: " + sock);
        SSLSession sess = ((SSLSocket) sock).getSession();
        System.out.println("  SSLConnection: " + sess);
        System.out.println("  PeerHost = " + sess.getPeerHost());
        try {
            Certificate[] certs = sess.getPeerCertificates();
            for (int i = 0; i < certs.length; ++i) {
                System.out.println("  PeerCert = " + certs[i]);
            }
        } catch (SSLPeerUnverifiedException ex) {
            // client authentication is optional, so this is not an error
            System.out.println("Client NOT authenticated");
        }
    }

    /**
     * Closes the listening socket if it was ever opened, so the port is
     * released once the accept loop has ended.
     */
    private void closeServerSocket() {
        if (ss == null) {
            return;
        }
        try {
            ss.close();
        } catch (IOException ex) {
            ex.printStackTrace(System.out);
        }
    }

    /**
     * This class deals with requests from a client.  When a request is
     * received with an arithmetic expression, the value of the expression
     * is calculated and returned to the client.
     */
    class Receiver implements DOMProcessor {
        /** The connection the answers are sent back on. */
        XMLSocketThread xml;

        Receiver(XMLSocketThread xml) {
            this.xml = xml;
        }

        /**
         * Logs the received document, evaluates it, logs the answer and
         * sends the answer back over the connection.  An IOException is
         * reported with a stack trace and otherwise ignored; failures of the
         * evaluation itself (see eval) are not caught here.
         *
         * @param doc the request document; it is normalized and stripped of
         *            text nodes in place
         */
        public void process(Document doc) {
            try {
                System.out.println();
                System.out.println("Received message");
                doc.normalize();

                XML.format(System.out, doc);

                Document ans = eval(doc);

                System.out.println();
                System.out.println("Send back");
                XML.format(System.out, ans);

                xml.send(ans);
            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }
    }

    /**
     * Evaluates the arithmetic expression held by a request document and
     * wraps the result in a new "answer" document.
     *
     * The expression is the first child of the document's root element once
     * the root's text nodes have been removed.  The root element is located
     * with getDocumentElement(), so comments or processing instructions that
     * follow it in the document are not mistaken for the expression.
     *
     * @param doc the request document; text nodes are removed from it
     * @return a document whose "answer" element contains the value as text
     * @throws NullPointerException if the document has no root element or
     *         the root element has no expression child
     */
    static Document eval(Document doc) {
        Node calc = doc.getDocumentElement();
        stripText(calc);
        double value = eval(calc.getFirstChild());

        org.apache.xerces.dom.DOMImplementationImpl domImpl = new
            org.apache.xerces.dom.DOMImplementationImpl();

        DocumentType docType = domImpl.createDocumentType("answer", "http://localhost/answer.dtd", "answer.dtd");

        Document ans = new org.apache.xerces.dom.DocumentImpl(docType);
        Node answer = ans.createElement("answer");
        answer.appendChild(ans.createTextNode(String.valueOf(value)));
        ans.appendChild(answer);
        return ans;
    }

    /**
     * Recursively evaluates an arithmetic expression node.
     *
     * "plus", "minus", "mult" and "div" combine the values of the node's
     * first and last child (left operand first); "number" yields the value
     * of its "val" attribute.  Arithmetic is done in double precision, so a
     * division by zero produces an infinity or NaN rather than an exception.
     *
     * NOTE(review): operators look only at the first and last child, so an
     * operator with a single child uses that child for both operands and
     * children in between are ignored.
     *
     * @param n the expression node; text nodes are removed from it
     * @return the value of the expression
     * @throws AssertionError if the node name is not one of the above
     * @throws NumberFormatException if a "val" attribute is not a number
     * @throws NullPointerException if n is null, e.g. an operator without
     *         children
     */
    static double eval(Node n) {
        stripText(n);

        String name = n.getNodeName();
        if (name.equals("plus")) {
            return eval(n.getFirstChild()) + eval(n.getLastChild());
        }
        if (name.equals("minus")) {
            return eval(n.getFirstChild()) - eval(n.getLastChild());
        }
        if (name.equals("mult")) {
            return eval(n.getFirstChild()) * eval(n.getLastChild());
        }
        if (name.equals("div")) {
            return eval(n.getFirstChild()) / eval(n.getLastChild());
        }
        if (name.equals("number")) {
            return Double.parseDouble(((Element) n).getAttribute("val"));
        }
        throw new AssertionError(name);
    }

    /**
     * Helper method to strip bogus TEXT_NODE nodes.  These are
     * caused by whitespace in the XML.  The DTD doesn't allow any real
     * text to appear.
     *
     * Only the direct children of the node are examined.
     *
     * @param n the node whose text children are removed
     */
    static void stripText(Node n) {
        NodeList children = n.getChildNodes();
        // The list is live: removing a child shifts every later entry down
        // by one.  Walking from the end keeps the indices still to be
        // visited stable, so consecutive text nodes are all removed.
        for (int i = children.getLength() - 1; i >= 0; --i) {
            Node child = children.item(i);
            if (child.getNodeType() == Node.TEXT_NODE) {
                n.removeChild(child);
            }
        }
    }
}
