import os, sys, json, subprocess, socket, time, argparse, random
from sys import stdout

from twisted.internet import reactor
from twisted.internet.protocol import Protocol, Factory, ClientFactory

def timeflt():
    """Return the current wall-clock time in seconds, as given by time.time()."""
    return time.time()

# ANSI foreground colour escape sequences
C_REQ = '\033[34m'   # blue
C_GOOD = '\033[32m'  # green
C_BAD = '\033[31m'   # red
C_END = '\033[0m'    # reset all attributes
# Verbosity flag; the main section below rebinds it from the -v option.
verbose = False

def setupArgParse():
    """Build the command-line parser for the test client.

    Options:
        -host  host name to connect to (default 'localhost')
        -p     TCP port (default 12345)
        -num   node number, recorded in every result row (default 1)
        -t     partition type (default 'random')
        -a     partition arguments appended to the type (default '(1)')
        -s     number of rows to insert (default 10)
        -v     verbose flag

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    p = argparse.ArgumentParser(description='Perform tests by connecting to the YARN/HADOOP Client')
    p.add_argument('-host', help='Host name', type=str, default='localhost')
    p.add_argument('-p', help='Port', type=int, default=12345)
    # Previously mislabelled 'Port'; the value is used as the node number.
    p.add_argument('-num', help='Node number', type=int, default=1)
    p.add_argument('-t', help='Partition type', type=str, default='random')
    p.add_argument('-a', help='Partition arguments', type=str, default='(1)')
    p.add_argument('-s', help='Table size', type=int, default=10)
    p.add_argument('-v', help='Verbose', action='store_true')
    return p

def output(arg):
    """Write *arg* to standard output exactly as given (no newline added), then flush."""
    stream = sys.stdout
    stream.write(arg)
    stream.flush()

def getChunk(a, i, radius=10):
    """Return the part of a sequence that surrounds index *i*.

    The window is ``a[i - radius : i + radius]``, clamped so it never starts
    before 0 or ends past ``len(a)``.

    Args:
        a: any sliceable sequence with a length (e.g. a string).
        i: centre index of the window.
        radius: number of positions kept on each side of *i* (default 10,
            matching the original fixed window).

    Returns:
        The clamped slice of *a*.
    """
    start = max(0, i - radius)
    end = min(len(a), i + radius)
    return a[start:end]

class SenderProtocol(Protocol):
    """
    Drive a Tester over one TCP connection to the cluster client.

    The CREATE statement is sent as soon as the connection is made. Each
    response is handed to the tester for validation and answered with the
    next statement, or with a DROP after an error. When the tester has
    nothing left to send, "!exit" is written and the results are emitted.
    """
    def __init__(self, addr, f):
        self.addr = addr
        self.factory = f

    def _sendLine(self, line):
        """Write *line* plus a trailing newline to the transport as UTF-8 bytes.

        Twisted transports only accept bytes on Python 3. The statements
        produced by Tester are ASCII, so this encoding also works on Python 2.
        """
        self.transport.write((line + "\n").encode("utf-8"))

    def connectionMade(self):
        print('=============== Starting testing ================')
        # Get the first statement (CREATE) and send it
        req = self.factory.tester.getCreateStmt()
        print("Query %d: %s" % (self.factory.tester.getStmtNumber(), req))
        self._sendLine(req)

    def dataReceived(self, data):
        """Handle a response from the server and send the next statement.

        NOTE(review): this assumes each dataReceived call carries exactly one
        complete response. TCP does not guarantee that, so a split or coalesced
        reply would be validated piecemeal. Consider line-based framing.
        """
        tester = self.factory.tester
        # Stop the select timer; Tester.selectToc decides whether one is running.
        tester.selectToc()

        if tester.isDone():
            # "!exit" was already sent, so close our side. The reactor is
            # stopped from connectionLost only, so stop() is not called twice.
            self.transport.loseConnection()
            return

        if isinstance(data, bytes):
            data = data.decode("utf-8", "replace")

        # True = response ok, False = error reported, None = tester finished
        resp = tester.validateResponse(data)

        if resp is True:
            # Keep sending statements until the tester returns None for them
            req = tester.getNextStmt()
            if req:
                print("Query %d: %s" % (tester.getStmtNumber(), req))
                self._sendLine(req)
            else:
                # We are done sending everything
                print("===== ALL DONE 2 =====")
                # Mark done so a later response cannot write the results twice.
                tester.setDone()
                tester.writeResults()
                self._sendLine("!exit")
        elif resp is False:
            print("Got error, sending drop statement")
            self._sendLine(tester.getDropStmt())
        else:
            # We are done sending everything
            print("===== ALL DONE 1 =====")
            tester.setDone()
            self._sendLine("!exit")
            tester.writeResults()

    def connectionLost(self, reason):
        from twisted.internet.error import ReactorNotRunning
        print('Connection lost')
        try:
            reactor.stop()
        except ReactorNotRunning:
            # Shutdown is already under way (e.g. the reactor was stopped
            # elsewhere, and this callback fires during teardown).
            pass

class SenderFactory(ClientFactory):
    """Client factory that builds a SenderProtocol bound to one Tester.

    ``host`` and ``port`` are stored on the instance but not used by this
    class itself.
    """
    def __init__(self, host, port, tester):
        self.host = host
        self.port = port
        self.tester = tester

    def buildProtocol(self, addr):
        return SenderProtocol(addr, self)

    def startedConnecting(self, connector):
        pass

    def clientConnectionLost(self, connector, reason):
        # Shutdown after a lost connection is handled by the protocol.
        pass

    def clientConnectionFailed(self, connector, reason):
        from twisted.internet.error import ReactorNotRunning
        print('Cannot connect')
        print(reason.getErrorMessage())
        # No protocol was ever built, so nothing else will stop the reactor;
        # without this the script would hang forever.
        try:
            reactor.stop()
        except ReactorNotRunning:
            pass

class Tester:
    """Setup and deliver SQL statements to send to the cluster.

    Statement order is driven by the state held in ``self.FSM``:
    CREATE -> INSERT (one row per statement until ``tableSize`` rows exist)
    -> SELECT (``numSelects`` full-table selects, each timed) -> DROP -> DONE.
    An error response (see validateResponse) jumps straight to DROP.
    """

    # Accepted partition types, compared case-insensitively.
    PARTITION_TYPES = ('random', 'hash', 'range', 'roundrobin')

    def __init__(self, partition, partArgs, tableSize, numSelects, nodeNumber):
        """
        Args:
            partition: partition type name, one of PARTITION_TYPES in any case.
                Its first character (as given) prefixes the table name.
            partArgs: text appended verbatim after the partition type in the
                CREATE statement, e.g. "(1)".
            tableSize: number of rows to insert.
            numSelects: number of timed select statements to run.
            nodeNumber: value written next to each result in writeResults.

        Raises:
            Exception("PartitionTypeNotSupported"): unknown partition type.
        """
        if partition.lower() not in self.PARTITION_TYPES:
            raise Exception("PartitionTypeNotSupported")
        self.partType = partition
        self.nodeNumber = nodeNumber
        self.partArgs = partArgs
        self.tableSize = tableSize
        self.curTableSize = 0
        self.stmtNumber = 0
        self.numSelects = numSelects
        # One entry per select: its start timestamp until selectToc() replaces
        # it with the elapsed time in seconds.
        self.selectResults = []
        # True while the last selectResults entry is still a start timestamp.
        self.selectPending = False
        self.TABLENAME = "%stest" % self.partType[0]
        self.FSM = "CREATE"
        self.doneFlag = False

    def getRandomString(self):
        """Return 25 random lowercase ASCII letters."""
        return "".join([chr(random.randint(ord('a'), ord('z'))) for i in range(0, 25)])

    def selectToc(self):
        """Stop the timer for the most recent select statement.

        Safe to call on every response. It only acts while a select issued by
        getSelectStmt() is outstanding, so non-select responses and repeated
        calls for the same select leave the recorded durations untouched.
        """
        if not self.selectPending:
            return
        self.selectResults[-1] = time.time() - self.selectResults[-1]
        self.selectPending = False

    def writeResults(self):
        """Write one 'tableSize,nodeNumber,seconds' line per select to stderr."""
        for s in self.selectResults:
            sys.stderr.write('%d,%d,%f\n' % (self.tableSize, self.nodeNumber, s))
            sys.stderr.flush()

    def getNextStmt(self):
        """Return the next statement for the current FSM state, or None when DONE.

        Raises:
            Exception("UnknownFsm"): if self.FSM holds an unrecognised state.
        """
        if self.FSM == "CREATE":
            return self.getCreateStmt()
        elif self.FSM == "INSERT":
            r = self.getInsertStmt()
            # When all rows are inserted, getInsertStmt switches the FSM to
            # SELECT and returns None, so dispatch again for the select.
            if not r:
                return self.getNextStmt()
            return r
        elif self.FSM == "SELECT":
            return self.getSelectStmt()
        elif self.FSM == "DROP":
            return self.getDropStmt()
        elif self.FSM == "DONE":
            return None
        else:
            print('!! Error: unknown FSM state: %s' % self.FSM)
            raise Exception("UnknownFsm")

    def getCreateStmt(self):
        """Return the 'create table' statement and move to the INSERT state."""
        self.stmtNumber += 1
        self.FSM = "INSERT"
        return "create table %s(a integer, b char(32)) partition by %s%s" % (self.TABLENAME, self.partType, self.partArgs)

    def getDropStmt(self):
        """Return the 'drop table' statement and move to the DONE state."""
        self.stmtNumber += 1
        self.FSM = "DONE"
        return "drop table %s" % self.TABLENAME

    def getInsertStmt(self, numValues=1):
        """Return an insert of up to *numValues* new rows, or None when the table is full.

        Rows are numbered from the current table size. When no rows remain
        to insert, the FSM moves to SELECT and None is returned.
        """
        valArray = []
        while self.curTableSize < self.tableSize and numValues > 0:
            valArray.append("(%d, '%s')" % (self.curTableSize, self.getRandomString()))
            self.curTableSize += 1
            numValues -= 1
        if not valArray:
            self.FSM = "SELECT"
            return None
        self.stmtNumber += 1
        # Each row tuple is already parenthesised, so multiple rows are just
        # comma-separated: VALUES (..),(..). Wrapping them again was invalid SQL.
        return "insert into %s values %s" % (self.TABLENAME, ",".join(valArray))

    def getSelectStmt(self):
        """Return a full-table select and start its timer.

        After numSelects selects have been issued the FSM moves to DROP.
        """
        self.stmtNumber += 1
        self.selectResults.append(time.time())
        self.selectPending = True
        # >= rather than == so a non-positive numSelects still terminates.
        if len(self.selectResults) >= self.numSelects:
            self.FSM = "DROP"
        return "select * from %s" % (self.TABLENAME)

    def getSelectStmtWhereUpper(self):
        """Designed to return the upper half of the data (a > tableSize / 2)."""
        self.stmtNumber += 1
        self.FSM = "DROP"
        return "select * from %s where a > %d" % (self.TABLENAME, self.tableSize / 2)

    def getSelectStmtWhereLower(self):
        """Designed to return the lower half of the data (a < tableSize / 2)."""
        self.stmtNumber += 1
        self.FSM = "DROP"
        return "select * from %s where a < %d" % (self.TABLENAME, self.tableSize / 2)

    def getStmtNumber(self):
        """Return how many statements have been generated so far."""
        return self.stmtNumber

    def isDone(self):
        return self.doneFlag

    def setDone(self):
        self.doneFlag = True

    def validateResponse(self, data):
        """Look at a response and attempt to validate it.

        Args:
            data: response text (str) or raw bytes, which are decoded as UTF-8.

        Returns:
            None if the FSM is already DONE; False if the response contains
            "ERROR" (the FSM is then moved to DROP); True otherwise.
        """
        if self.FSM == "DONE":
            return None
        if isinstance(data, bytes):
            data = data.decode("utf-8", "replace")
        if "ERROR" in data:
            print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            print(data)
            print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            self.FSM = "DROP"
            return False
        return True

###############################################################################
# Main
# Flip to True to step through the Tester state machine offline: statements
# are printed instead of sent, and every response is treated as 'SUCCESS'.
DRY_RUN = False
# Number of timed select statements each run performs.
NUM_SELECTS = 10

p = setupArgParse()
args = p.parse_args()

host = args.host
port = args.p
partType = args.t
partArgs = args.a
tableSize = args.s
nodeNumber = args.num
verbose = args.v

try:
    tester = Tester(partition=partType, partArgs=partArgs, tableSize=tableSize,
                    numSelects=NUM_SELECTS, nodeNumber=nodeNumber)
except Exception as err:
    # Tester reports an unknown partition type with a plain Exception. Show it
    # as a usage error rather than a traceback; let anything else propagate.
    if str(err) != "PartitionTypeNotSupported":
        raise
    p.error("unsupported partition type: %s" % partType)

if DRY_RUN:
    sTime = 0.1
    while True:
        tester.selectToc()
        req = tester.getNextStmt()
        print(req)
        resp = tester.validateResponse('SUCCESS')
        if resp is None:
            break
        time.sleep(sTime)
    tester.writeResults()
else:
    f = SenderFactory(host, port, tester)
    reactor.connectTCP(host, port, f)
    reactor.run()
