#include "driver-config.h"

#include <linux/time.h>
#include <linux/kernel.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include "utils.h"
#include "print_string.h"

#ifdef PACIA_TRACE_DEBUG_PRINT_SKB
/*
 * Hexdumps nwords 32-bit words starting at p, printing per_line words on
 * each output line (the last line may hold fewer).
 *
 * Each word is assembled byte by byte in network (big-endian) order, which
 * yields the same value as ntohl() on an aligned word but never reads past
 * the last requested word and does not depend on the alignment of p.
 *
 * per_line must be non-zero and small enough for one line to fit in the
 * local buffer (11 characters per word); the callers in this file pass
 * 2 and 4.
 */
static void
print_skb_words (const unsigned char *p, unsigned int nwords,
		unsigned int per_line)
{
	char buf[256];
	unsigned int i;
	int pos = 0;

	for (i = 0; i < nwords; i++, p += 4)
	{
		__u32 word = ( (__u32) p[0] << 24) | ( (__u32) p[1] << 16)
			| ( (__u32) p[2] << 8) | (__u32) p[3];

		pos += sprintf (buf + pos, pos ? " 0x%08x" : "0x%08x", word);

		/* Flush a full line, or the final partial one. */
		if ( (i + 1) % per_line == 0 || i + 1 == nwords)
		{
			print_string (buf);
			pos = 0;
		}
	}
}
#endif // PACIA_TRACE_DEBUG_PRINT_SKB

/*
 * Prints stuff I care about in an sk_buff.
 *
 * For an IPv4 packet carrying TCP, prints the skb timestamp and the current
 * time, both as microsecond counts truncated to 32 bits.  When
 * PACIA_TRACE_DEBUG_PRINT_SKB is defined it additionally prints the
 * ethertype, the header pointers, the TCP ports, a hexdump of the TCP
 * header and a hexdump of the packet data.
 *
 * Assumptions inherited from the original code (confirm against the
 * caller): skb->data points at the 14-byte Ethernet header, the IP header
 * carries no options (20 bytes), and the first skb->len bytes are
 * reachable through skb->data.
 */
void
print_skb (struct sk_buff *skb)
{
	struct timeval tv = skb->stamp;
	struct timeval now_tv;
	unsigned int total_us, curr_us;
	unsigned short protl = ntohs (skb->protocol);
	struct iphdr *iph;
	char buf[256];
#ifdef PACIA_TRACE_DEBUG_PRINT_SKB
	enum
	{
		ETH_HDR_LEN = 14,	/* size of ethernet header */
		IP_HDR_LEN = 20,	/* size of ip header without options */
		TCP_HDR_LEN = 20,	/* size of tcp header without options */
		TCP_OFFSET = ETH_HDR_LEN + IP_HDR_LEN,
		PORTS_END = TCP_OFFSET + 4	/* 2 bytes source + 2 bytes dest */
	};
	const unsigned char *data = skb->data;
	struct tcphdr *th;
	unsigned int src, dst;
#endif // PACIA_TRACE_DEBUG_PRINT_SKB

#ifdef PACIA_TRACE_DEBUG_PRINT_SKB
	sprintf (buf, "protocol = 0x%04x", protl);
	print_string (buf);
#endif // PACIA_TRACE_DEBUG_PRINT_SKB

	if (protl != ETH_P_IP)
		return;

	iph = skb->nh.iph;

#ifdef PACIA_TRACE_DEBUG_PRINT_SKB
	sprintf (buf, "&ethhdr = 0x%08lx", (unsigned long) skb->mac.ethernet);
	print_string (buf);
	sprintf (buf, "&iph = 0x%08lx", (unsigned long) iph);
	print_string (buf);
#endif // PACIA_TRACE_DEBUG_PRINT_SKB

	/*
	 * iph->protocol is a single byte, so it needs no byte swapping.  The
	 * old ntohs() call zeroed it on little-endian hosts, which is why the
	 * comparison only ever "worked" against IPPROTO_IP (0).
	 */
	if (iph->protocol != IPPROTO_TCP)
		return;

#ifdef PACIA_TRACE_DEBUG_PRINT_SKB
	th = skb->h.th;
	if ( (unsigned long) th == (unsigned long) iph
		&& skb->len >= ETH_HDR_LEN + IP_HDR_LEN + TCP_HDR_LEN)
	{
		/* Transport pointer not set up yet: step over the IP header. */
		th = (struct tcphdr *) ( (unsigned long) iph + IP_HDR_LEN);
	}

	sprintf (buf, "&tcp header = 0x%08lx", (unsigned long) th);
	print_string (buf);
#endif // PACIA_TRACE_DEBUG_PRINT_SKB

	/* Unsigned arithmetic: the 32-bit wrap-around is intentional. */
	do_gettimeofday (&now_tv);
	total_us = 1000000U * (unsigned int) tv.tv_sec
		+ (unsigned int) tv.tv_usec;
	curr_us = 1000000U * (unsigned int) now_tv.tv_sec
		+ (unsigned int) now_tv.tv_usec;
	sprintf (buf, "timestamp = %u, current time = %u",
		total_us, curr_us);
	print_string (buf);

#ifdef PACIA_TRACE_DEBUG_PRINT_SKB
	/* Only touch the port bytes when the packet is long enough. */
	if (skb->len >= (unsigned int) PORTS_END)
	{
		src = ( (unsigned int) data[TCP_OFFSET] << 8)
			| data[TCP_OFFSET + 1];
		dst = ( (unsigned int) data[TCP_OFFSET + 2] << 8)
			| data[TCP_OFFSET + 3];

		sprintf (buf, "timestamp = %u, source = 0x%04x, dest = 0x%04x",
			total_us, src, dst);
		print_string (buf);
	}

	// hexdump the tcp header, assuming it really has one
	sprintf (buf, "tcp header:");
	print_string (buf);
	print_skb_words ( (const unsigned char *) th, TCP_HDR_LEN >> 2, 2);

	// hexdump the data itself (whole 32-bit words only)
	sprintf (buf, "&data = 0x%08lx", (unsigned long) data);
	print_string (buf);
	sprintf (buf, "data len = %u:", skb->len);
	print_string (buf);
	print_skb_words (data, skb->len >> 2, 4);
#endif // PACIA_TRACE_DEBUG_PRINT_SKB
}

/*
 * Extracts the TCP source and destination ports of an IPv4/TCP packet.
 *
 * skb	packet to inspect
 * src	receives the source port, in host byte order
 * dst	receives the destination port, in host byte order
 *
 * Returns 1 on success, 0 on failure (not IPv4, not TCP, or too short to
 * contain the port fields).  *src and *dst are left untouched on failure.
 *
 * The ports are read at fixed offsets from skb->data, which assumes that
 * skb->data points at the 14-byte Ethernet header and that the IP header
 * carries no options (20 bytes) -- the layout the original offsets were
 * tuned for; confirm against the caller.
 */
int getTcpPorts (struct sk_buff *skb, __u16 * src, __u16 * dst)
{
	enum
	{
		ETH_HDR_LEN = 14,	/* size of ethernet header */
		IP_HDR_LEN = 20,	/* size of ip header without options */
		TCP_OFFSET = ETH_HDR_LEN + IP_HDR_LEN,
		PORTS_END = TCP_OFFSET + 4	/* 2 bytes source + 2 bytes dest */
	};
	const unsigned char *tcp;

	if (ntohs (skb->protocol) != ETH_P_IP)
		return 0;

	/*
	 * iph->protocol is a single byte, so it needs no byte swapping.  The
	 * old ntohs() call zeroed it on little-endian hosts, which made the
	 * former "== IPPROTO_IP" test accept every IPv4 packet.
	 */
	if (skb->nh.iph->protocol != IPPROTO_TCP)
		return 0;

	/* Both port fields must lie inside the packet before we read them. */
	if (skb->len < (unsigned int) PORTS_END)
		return 0;

	/* Assemble the big-endian fields byte by byte: endian-independent
	 * and free of misaligned 32-bit loads. */
	tcp = (const unsigned char *) skb->data + TCP_OFFSET;
	*src = (__u16) ( (tcp[0] << 8) | tcp[1]);
	*dst = (__u16) ( (tcp[2] << 8) | tcp[3]);

	return 1;
}
