#!/usr/bin/env python

import sys
import urllib2
import random
import Queue
import time
from threading import Thread
import copy

import cv2
import numpy as np

import argparse
import os
import socket
import string
import cPickle as pkl
from datetime import datetime
#from datetime import date

# Tracking parameters are persisted between runs in the user's home directory.
# expanduser('~') also works where $HOME is unset (e.g. Windows); there
# os.getenv("HOME") returns None and os.path.join would raise at import time.
configfile = os.path.join(os.path.expanduser('~'), '.IPS_config.pkl')
print(configfile)

# TCP port (on localhost) on which Target serves the tracking results.
PORT = 6012

def CV_FOURCC(c1, c2, c3, c4):
  '''Pack four character codes into a FOURCC integer.

  Only the low 8 bits of each argument are kept; ``c1`` lands in the least
  significant byte and ``c4`` in the most significant one.
  '''
  code = 0
  for shift, char_code in zip((0, 8, 16, 24), (c1, c2, c3, c4)):
    # The masked bytes never overlap, so OR-ing equals adding them.
    code |= (char_code & 0xFF) << shift
  return code
    
def get_codec():
  '''Parse the command line and return the video codec as a list of chars.

  All of the script's options are declared so that ``--help`` lists them, but
  only ``--codec`` is used: the result is e.g. ``['D', 'I', 'V', 'X']``, ready
  to be packed into a FOURCC code with CV_FOURCC.

  Exits through ``parser.error`` (SystemExit) on invalid options, including a
  codec that is not exactly four characters long (a FOURCC needs four).
  '''
  home = os.path.expanduser('~')  # works even when $HOME is unset
  parser = argparse.ArgumentParser()
  parser.add_argument("-dis","--dis", help="Display video frame (set DIS to True or False)",default=True)
  parser.add_argument("-viddir","--viddir",help="Set directory to save videos", default=home)
  parser.add_argument("-viddev","--viddev",help="Select video device number (usually 0 or 1, default is 0)",type=int,default=0)
  parser.add_argument("-vidint","--vidint",help="recording interval in sec (default=60)",type=int,default=60)
  parser.add_argument("-vidnum","--vidnum",help="number of videos to record (default is 0)",type=int,default=0)
  parser.add_argument("-codec","--codec",help="Set recording codec MJPG,PIM1,MP42,U263,I263,FLV1 default is DIVX (mpeg4)",default="DIVX")
  parser.add_argument("-vidfps", "--vidfps", help="Set recording fps", type=int, default=30)
  parser.add_argument("-imgfmt","--imgfmt",help="Set video image format [jpg, png, tiff]",default="jpg")
  parser.add_argument("-imgint","--imgint",help="Set interval between images (sec)", type=int,default=60)
  parser.add_argument("-imgdir", "--imgdir", help="Set directory to save images",default=home)
  parser.add_argument("-imgnum", "--imgnum", help="Number of images to take", type=int, default=0)
  args = parser.parse_args()
  if len(args.codec) != 4:
    # Fail here with a clear message instead of an IndexError in the caller.
    parser.error('--codec must be a four-character FOURCC code, got %r' % args.codec)
  return list(args.codec)
    
def explicit(l):
  '''Return ``(index, value)`` of the largest element of the sequence ``l``.

  Ties resolve to the first occurrence. Raises ValueError when ``l`` is empty.
  '''
  best_idx, best_val = max(enumerate(l), key=lambda pair: pair[1])
  return best_idx, best_val
    
def process_image(source):
  '''Fit an ellipse to the longest contour of a thresholded image.

  source -- expected to be a single-channel 8-bit binary image (the caller
            passes the output of cv2.threshold). It is copied first because
            some OpenCV versions modify the input of cv2.findContours.

  Returns the cv2.fitEllipse result ``((cx, cy), (width, height), angle)``
  for the longest contour, or an empty list when there is no contour, the
  longest one has 10 points or fewer, or the fit fails.
  '''
  # [-2] selects the contour list whether findContours returns 2 values
  # (OpenCV 2.4 / 4.x) or 3 values (OpenCV 3.x).
  contours = cv2.findContours(source.copy(), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
  if not contours:
    return []
  longest = max(contours, key=len)  # first contour with the most points
  # fitEllipse needs at least 5 points; the > 10 cut also drops tiny blobs.
  if len(longest) <= 10:
    return []
  try:
    return cv2.fitEllipse(longest)
  except cv2.error:
    return []

#class serveronEyeTrackPC(Thread):
  #''' This is a Thread that runs a server TCP onto the EyeTrack PC
  #'''
  #def __init__(self,Eyepydata,Quitb):
    #Thread.__init__(self)
    #self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    #self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    #self.server_socket.bind(("", PORT))
    #self.server_socket.listen(5)
    #self.Eyepydata = Eyepydata
    #self.localdict = {}
    #self.Quitb = Quitb
    #self.Quitb
  #def run(self):
    #while self.Quitb.empty():
      #try :
	#self.client_socket, address = self.server_socket.accept()
	#print "serveronEyeTrackPC Got a connection from ", address
      #except :
	#print 'Server has been closed after the client'
	#break;
      #while self.Quitb.empty():
	  #if not(self.Eyepydata.empty()):
	    #data = self.Eyepydata.get()
	    #print data
	    #self.client_socket.send(cPickle.dumps(data))
	    #tmp = self.client_socket.recv(buffsize)
	  #print tmp
    #self.localdict['QuitTCP']=True
    #self.client_socket.send(cPickle.dumps(self.localdict))
    #time.sleep(2)
    #print 'Quit TCP server'
    #self.server_socket.close()
def saveconfig(configfile,listvar):
  '''Pickle ``listvar`` into ``configfile`` and print a timestamped message.

  The file is opened in binary mode (pickle output is bytes), and the
  ``with`` block closes it even when pickling raises.
  '''
  with open(configfile, 'wb') as f:
    pkl.dump(listvar, f)
  now = datetime.now()
  print('Config File %s saved %d/%d/%d %d:%d' % (configfile,now.year,now.month,now.day,now.hour,now.minute))
def loadconfig(configfile):
  '''Return the object previously pickled into ``configfile`` by saveconfig.

  Raises IOError/OSError when the file cannot be opened; errors from
  ``pkl.load`` on corrupt content (e.g. EOFError for an empty file)
  propagate unchanged. The ``with`` block always closes the file.

  NOTE: unpickling can execute arbitrary code -- only load config files the
  user wrote themselves.
  '''
  with open(configfile, 'rb') as f:
    return pkl.load(f)
class Target(Thread):
  '''Pupil-tracking thread that also serves its results to one TCP client.

  Every camera frame is rotated by ``rotation`` degrees, converted to grey,
  inverse-thresholded at ``valTH`` and restricted to a square search box of
  half-width ``point_rad`` centred on (``lx``, ``ly``). ``process_image``
  fits an ellipse to the longest contour in the box; the fit is accepted as
  the pupil when ``pi * width * height`` lies strictly between ``DiaMin`` and
  ``DiaMax``.

  Once a client is connected on localhost:PORT, each frame waits for one
  request (its content is ignored) and answers with the line
  ``x<TAB>y<TAB>area<TAB>seconds``; x, y and area are zero when no pupil was
  accepted. ``seconds`` counts from the start of ``run`` or from the last
  time a recording was started.

  Window controls: trackbars for threshold, area limits and box size;
  double-click moves the box; right-click stores (``rx``, ``ry``); ESC quits;
  ``r`` rotates by 5 degrees; ``s`` records the display to an .avi file for
  RECORD_SECONDS; key code 84 shrinks the box by 10 px. Each change of a
  tracking parameter is saved to ``configfile``.
  '''

  # Maximum number of bytes read per client request.
  RECV_SIZE = 512
  # A recording started with the 's' key is closed after this many seconds.
  RECORD_SECONDS = 20

  def __init__(self,Eyepydata,quit):
    '''Open camera 0 and the display window, restore the saved tracking
    parameters and start listening on localhost:PORT.

    Eyepydata -- queue stored as self.Eyepydata; this class does not use it.
    quit      -- queue; run() stops as soon as it is non-empty, and ESC puts
                 'Stop' on it.

    Raises IOError when no frame can be read from the camera.
    '''
    Thread.__init__(self)
    self.Eyepydata = Eyepydata
    self.quit = quit
    self.capture = cv2.VideoCapture(0)
    self.time = time.time()
    self.nameWIN = 'IPS Tracking - O. Joly'
    cv2.namedWindow(self.nameWIN,cv2.WINDOW_NORMAL)
    ret, frame = self.capture.read()
    if not ret or frame is None:
      self.capture.release()
      raise IOError('Could not read a frame from video device 0')
    height, width = frame.shape[:2]
    # Defaults, overridden below by the saved configuration when available.
    self.point_rad = width // 10
    self.points = [30,30]
    self.lx = width // 2
    self.ly = height // 2
    self.rx = 0
    self.ry = 0
    self.valTH = 50
    self.DiaMax = 200
    self.DiaMin = 10
    self.framenb = 0
    self.rotation = 0
    self.center = np.array([0,0])
    self.size = np.array([0,0])
    self.pupil_area = 0
    try:
      # Field order must match _save_config.
      [self.point_rad,self.valTH,self.DiaMax,self.DiaMin,self.lx,self.ly,self.rx,self.ry,self.rotation] = loadconfig(configfile)
    except Exception:
      # Missing, unreadable or outdated file: keep the defaults above.
      print('Could not load previous config %s' % configfile)
    self.save = False           # True while a video recording is active
    self._writer = None         # cv2.VideoWriter of the active recording
    self._video_file = None     # path of the active / last recording
    self._timez = time.time()   # time base of the streamed 'seconds' field
    cv2.createTrackbar('Threshold', self.nameWIN, self.valTH, 255, self.changeThresholdMax)
    cv2.createTrackbar('DiaMax', self.nameWIN, self.DiaMax, 20000, self.change_DiaMax)
    cv2.createTrackbar('DiaMin',self.nameWIN, self.DiaMin, 20000, self.change_DiaMin)
    cv2.createTrackbar('SizeBox', self.nameWIN, self.point_rad, width, self.change_SizeBox)
    self.client_socket = None
    self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self.server_socket.bind(('localhost', PORT))
    self.server_socket.listen(5)

  def _save_config(self):
    '''Persist the current tracking parameters to configfile.'''
    # Field order must match the unpacking in __init__.
    saveconfig(configfile,[self.point_rad,self.valTH,self.DiaMax,self.DiaMin,self.lx,self.ly,self.rx,self.ry,self.rotation])

  def change_SizeBox( self, value ):
    '''Trackbar callback: set the search-box half-width and save the config.'''
    self.point_rad = value
    self._save_config()

  def changeThresholdMax( self, value ):
    '''Trackbar callback: set the grey-level threshold and save the config.'''
    self.valTH = value
    self._save_config()

  def change_DiaMax( self, value ):
    '''Trackbar callback: set the upper area limit (pulling DiaMin down to
    it if needed) and save the config.'''
    self.DiaMax = value
    if self.DiaMax <self.DiaMin:
      self.DiaMin=self.DiaMax
    self._save_config()

  def change_DiaMin( self, value ):
    '''Trackbar callback: set the lower area limit and save the config.'''
    self.DiaMin = value
    self._save_config()

  def draw_circle(self, point):
    '''Draw a circle of radius point_rad at point on self.grey_image.'''
    cv2.circle(self.grey_image, point, self.point_rad,[255, 0, 0],1)

  def mouse_handler(self, event, x, y, flags, param):
    '''OpenCV mouse callback for the tracking window.

    Every event records (x, y) and draws a circle on self.grey_image (note:
    that is not the image shown in the window). A left double-click moves
    the search box to (x, y); a right-click stores (x, y) in rx/ry. Both
    save the config.
    '''
    self.x = x
    self.y = y
    self.draw_circle((self.x, self.y))
    if event == cv2.EVENT_LBUTTONDBLCLK:
      print((event, x, y, flags, param))
      self.lx = x
      self.ly = y
      self._save_config()
    elif event == cv2.EVENT_RBUTTONDOWN:
      print((event, x, y, flags, param))
      self.rx = x
      self.ry = y
      self._save_config()

  def run(self):
    '''Thread body: wait for a client, then track until ESC is pressed, the
    quit queue becomes non-empty or the camera stops delivering frames.

    Camera, window, video file and sockets are released on exit, also when
    an exception escapes the loop.
    '''
    self._timez = time.time()
    try:
      self._accept_client()
      # HighGUI delivers mouse events while cv2.waitKey runs, which first
      # happens after self.grey_image exists, so registering once is enough.
      cv2.setMouseCallback(self.nameWIN, self.mouse_handler)
      time0 = time.time()
      while self.quit.empty():
        ret, frame = self.capture.read()
        if not ret or frame is None:
          print('Could not read a frame from the camera, stopping')
          break
        display, binary = self._prepare_frame(frame)
        ellipse = process_image(binary)
        time1 = time.time()
        self._exchange(self._measure_pupil(ellipse, display, time1 - self._timez))
        status = 'X: %.3dpx | Y: %.3dpx | S: %.4dpx | FPS: %d' % (
          self.center[0], self.center[1], self.pupil_area, self._fps(time1 - time0))
        time0 = time1
        cv2.putText(display, status, (10, 25), cv2.FONT_HERSHEY_COMPLEX, 0.8, (0, 0, 0), 1)
        cv2.imshow(self.nameWIN, display)
        self._record(display)
        if self._handle_key(cv2.waitKey(10) & 0xff, display):
          break
    finally:
      self._cleanup()

  def _accept_client(self):
    '''Block until one client connects; on a socket error continue without
    a client (results are then only displayed, not streamed).'''
    print('serveronEyeTrackPC Waiting for a connection ...')
    try:
      self.client_socket, address = self.server_socket.accept()
    except socket.error:
      print('Server has been closed after the client')
      self.client_socket = None
      return
    print('serveronEyeTrackPC Got a connection from  %s' % (address,))

  def _box_bounds(self, rows, cols):
    '''Return (top, bottom, left, right) of the search box clipped to a
    rows x cols frame, so a box crossing the top/left edge never yields
    negative (wrap-around) numpy slice indices.'''
    cx, cy, r = int(self.lx), int(self.ly), self.point_rad

    def clip(value, limit):
      return min(max(value, 0), limit)

    return clip(cy - r, rows), clip(cy + r, rows), clip(cx - r, cols), clip(cx + r, cols)

  def _prepare_frame(self, frame):
    '''Rotate and threshold frame; return (display, binary).

    display -- rotated BGR frame with the thresholded region highlighted,
               the search box outlined and a grey banner for the status text.
    binary  -- inverse-thresholded grey image (0 or 255), zero outside the
               search box.
    The rotated grey frame is kept in self.grey_image.
    '''
    rows, cols = frame.shape[:2]
    M = cv2.getRotationMatrix2D((cols/2, rows/2), self.rotation, 1)
    display = cv2.warpAffine(frame, M, (cols, rows))
    self.grey_image = cv2.cvtColor(display, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(self.grey_image, self.valTH, 255, cv2.THRESH_BINARY_INV)
    top, bottom, left, right = self._box_bounds(rows, cols)
    in_box = np.zeros((rows, cols), np.uint8)
    in_box[top:bottom, left:right] = 1
    binary = binary * in_box
    # uint8 products wrap modulo 256, so p * 255 == 256 - p: thresholded
    # pixels inside the box appear inverted in this overlay layer.
    highlight = display * binary[:, :, np.newaxis]
    cv2.addWeighted(display, 0.9, highlight, 0.4, 0, display)
    r, cx, cy = self.point_rad, int(self.lx), int(self.ly)
    # The outline uses the unclipped box; OpenCV clips drawing to the image.
    cv2.rectangle(display, (cx - r, cy - r), (cx + r, cy + r), [255, 255, 255], 1)
    banner = np.array([[(0, 0), (0, 40), (cols, 40), (cols, 0)]], dtype=np.int32)
    cv2.fillPoly(display, banner, [180, 180, 180])
    return display, binary

  @staticmethod
  def _is_valid_ellipse(ellipse):
    '''True when ellipse is a non-empty fit with finite centre and size.'''
    if not ellipse:
      return False
    (cx, cy), (w, h) = ellipse[0], ellipse[1]
    return bool(np.all(np.isfinite([cx, cy, w, h])))

  def _measure_pupil(self, ellipse, display, elapsed):
    '''Update center/size/pupil_area from ellipse, annotate display and
    return the tab-separated line to send to the client.'''
    if self._is_valid_ellipse(ellipse):
      cv2.ellipse(display, ellipse, (50, 0, 255), 1)
      (cx, cy), (w, h) = ellipse[0], ellipse[1]
      self.center = np.array([int(np.round(cx)), int(np.round(cy))])
      self.size = np.array([int(np.round(w)), int(np.round(h))])
      # NOTE(review): fitEllipse sizes are full axis lengths, so this is 4x
      # the geometric area; DiaMin/DiaMax are calibrated against this value.
      self.pupil_area = np.pi*self.size[0]*self.size[1]
      if self.DiaMin < self.pupil_area < self.DiaMax:
        cv2.circle(display, (int(self.center[0]), int(self.center[1])), 2, [200, 200, 0], 1)
        return '%.3f\t%.3f\t%.3d\t%.3f' % (cx, cy, self.pupil_area, elapsed)
    # No acceptable pupil in this frame: report zeros and clear the
    # on-screen values instead of showing the previous frame's.
    self.center = np.array([0, 0])
    self.pupil_area = 0
    return '%.3f\t%.3f\t%.3d\t%.3f' % (0, 0, 0, elapsed)

  def _exchange(self, line):
    '''Answer one client request with line.

    Streaming is best-effort: if the peer closed the connection or a socket
    error occurs, the client socket is closed and tracking continues without
    streaming.
    '''
    if self.client_socket is None:
      return
    try:
      if self.client_socket.recv(self.RECV_SIZE):
        self.client_socket.sendall(line)
        return
      reason = 'client closed the connection'
    except socket.error as err:
      reason = err
    print('Stop streaming to client: %s' % (reason,))
    self.client_socket.close()
    self.client_socket = None

  @staticmethod
  def _fps(dt):
    '''Frames per second for a frame interval of dt seconds (0 if dt <= 0).'''
    return 1.0 / dt if dt > 0 else 0

  def _record(self, image):
    '''Append image to the active recording; close it after RECORD_SECONDS.'''
    if not self.save or self._writer is None:
      return
    self._writer.write(image)
    if time.time() - self._timez > self.RECORD_SECONDS:
      print('release video')
      print('stop rec. saved video in %s ' % self._video_file)
      self._stop_recording()

  def _start_recording(self, image):
    '''Start writing displayed frames (sized like image) to a new .avi file
    in the home directory, using the codec given on the command line.'''
    try:
      codec = get_codec()
    except SystemExit:
      # argparse exits on bad options; keep tracking instead of ending the thread.
      print('Invalid command line options, video recording not started')
      return
    if len(codec) != 4:
      print('Codec must be exactly 4 characters, video recording not started')
      return
    self._stop_recording()  # pressing 's' again must not leak the previous writer
    self._timez = time.time()
    fourcc = CV_FOURCC(*[ord(c) for c in codec])
    print(fourcc)
    now = datetime.now()
    self._video_file = os.path.join(
      os.path.expanduser('~'),
      'eye_tracking_%s_%s_%s_%s_%s.avi' % (now.year, now.month, now.day, now.hour, now.minute))
    print('start saving video in %s ' % self._video_file)
    writer = cv2.VideoWriter(self._video_file, fourcc, 10.0, (image.shape[1], image.shape[0]))
    if not writer.isOpened():
      print('Could not open %s for writing, video recording not started' % self._video_file)
      return
    self._writer = writer
    self.save = True

  def _stop_recording(self):
    '''Close the active video file, if any.'''
    if self._writer is not None:
      self._writer.release()
    self._writer = None
    self.save = False

  def _handle_key(self, k, image):
    '''React to key code k (cv2.waitKey result masked with 0xff).

    Returns True when the tracking loop must stop.
    '''
    if k == 27:  # ESC
      self.quit.put('Stop')
      return True
    if k == 84:
      # Shrink the search box (presumably the down arrow on GTK builds,
      # 65364 & 0xff == 84 -- TODO confirm on the target platform).
      self.point_rad = max(self.point_rad - 10, 1)
      self._save_config()
      print(k)
      print(self.point_rad)
    elif k == 114:  # 'r': rotate the image by 5 more degrees
      self.rotation += 5
      self._save_config()
      print(k)
    elif k == 115:  # 's': start a video recording
      print(k)
      self._start_recording(image)
    elif k < 255:  # 255 is waitKey's -1 ("no key") masked to 8 bits
      print(k)
    return False

  def _cleanup(self):
    '''Release the video writer, camera, windows and sockets.'''
    self._stop_recording()
    self.capture.release()
    cv2.destroyAllWindows()
    if self.client_socket is not None:
      self.client_socket.close()
      self.client_socket = None
    self.server_socket.close()

#TCPDataQueue = Queue.Queue() # to be sent to stim PC for starting/stoping trial

  
if __name__ == "__main__":
  # Data queue handed to the tracker, and the stop queue: the tracking loop
  # runs while Quitb is empty and ESC puts 'Stop' on it. Both stay
  # module-level names because code in Target may refer to Quitb directly.
  Eyepydata, Quitb = Queue.Queue(), Queue.Queue()
  tracker = Target(Eyepydata, Quitb)
  tracker.start()