

"""Module Description

Copyright (c) 2014, Ying Jin <yjin@cshl.edu >


This code is free software; you can redistribute it and/or modify it
under the terms of the Artistic License (see the file COPYING included
with the distribution).

@author:  Ying Jin
@contact: yjin@cshl.edu
"""
import sys, time
import logging
from math import ceil,floor
from TEToolkit.Constants import TEindex_BINSIZE

# TEindex_BINSIZE is imported from TEToolkit.Constants (a hard-coded 200 was
# previously used here during development).
# Several tree helpers below are recursive (e.g. updateBalance, print_tree),
# so the interpreter's recursion limit is raised as a safety margin.
sys.setrecursionlimit(10000)
class Node:
    """One bin of the TE index tree.

    A node is keyed by the bin ID of the ``start`` it was created with (see
    getStart); bin k covers positions (k*TEindex_BINSIZE, (k+1)*TEindex_BINSIZE].
    Every segment recorded in the node is stored in a dict mapping
    segment start -> list of (name, end) pairs.
    """

    def __init__(self, start=-1, end=-1, name=-1, parent=None, left=None, right=None):
        self.__start = start
        self.__end = end
        self.__name = name  # idx in nameIDmap
        self.__namelist = {}  # segment start -> [(name, end), ...]
        self.left = left
        self.right = right
        # AVL balance factor; a left-child insertion increments it (see BinaryTree).
        self.balanceFactor = 0
        self.parent = parent
        self.isroot = False
        self.add(start, end, name)

    def add(self, start, end, name):
        """Record segment [start, end] of TE *name* in this node."""
        self.__namelist.setdefault(start, []).append((name, end))

    def isRoot(self):
        """Return the root flag maintained by BinaryTree."""
        return self.isroot

    def isLeftChild(self):
        """Return True if this node is its parent's left child."""
        return self.parent is not None and self.parent.left is self

    def isRightChild(self):
        """Return True if this node is its parent's right child."""
        return self.parent is not None and self.parent.right is self

    def getStart(self):
        """Return the bin ID of the position this node was created with."""
        # Floor division keeps the bin ID integral on both Python 2 and 3.
        bin_startID = self.__start // TEindex_BINSIZE
        # A position lying exactly on a bin boundary closes the previous bin.
        if self.__start == bin_startID * TEindex_BINSIZE:
            bin_startID -= 1
        return bin_startID

    def getEnd(self):
        """Return ``end // TEindex_BINSIZE`` for the end this node was created with."""
        return self.__end // TEindex_BINSIZE

    def getName(self):
        """Return the name index this node was created with."""
        return self.__name

    def overlaps(self, start, end):
        """Return the names of stored segments overlapping [start, end] (inclusive).

        Segments are scanned in ascending start order, so the scan stops at
        the first start beyond *end*; one entry is returned per overlapping
        segment.
        """
        TEnamelist = []
        for seg_start in sorted(self.__namelist):
            if seg_start > end:
                break
            # seg_start <= end holds here, so only the other bound needs checking.
            for name, seg_end in self.__namelist[seg_start]:
                if start <= seg_end:
                    TEnamelist.append(name)
        return TEnamelist

class BinaryTree:
    """AVL-balanced binary search tree of Node bins.

    Nodes are ordered by Node.getStart() (their bin ID). Inserting a segment
    whose bin already has a node appends the segment to that node instead of
    creating a new one.
    """

    def __init__(self):
        self.root = None
        # Number of insert() calls (segments), not the number of nodes.
        self.size = 0

    @staticmethod
    def _bin_id(pos):
        """Return the bin ID of *pos* using the same rule as Node.getStart()."""
        bin_id = pos // TEindex_BINSIZE
        # A position lying exactly on a bin boundary closes the previous bin.
        if pos == bin_id * TEindex_BINSIZE:
            bin_id -= 1
        return bin_id

    def print_tree(self, node):
        """Print the bin ID of every node in the subtree at *node*, in order."""
        if node is None:
            return
        self.print_tree(node.left)
        print(str(node.getStart()))
        self.print_tree(node.right)

    def children_count(self, node=None):
        """Return the number of children (0, 1 or 2) of *node*.

        @param node tree node to inspect; defaults to the root
        @returns number of children; 0 for an empty tree
        """
        if node is None:
            node = self.root
        if node is None:
            return 0
        return (node.left is not None) + (node.right is not None)

    def insert(self, start, end, name):
        """Add segment [start, end] of TE *name*, creating the root if needed."""
        if self.root is not None:
            self.__insert(self.root, start, end, name)
        else:
            self.root = Node(start, end, name)
            self.root.isroot = True
        self.size += 1

    def __insert(self, node, start, end, name):
        """Insert segment [start, end] of *name* into the subtree at *node*.

        If a node for the segment's bin exists the segment is appended to it;
        otherwise a new leaf is created and the tree is rebalanced.
        """
        binstart = self._bin_id(start)
        while True:
            node_bin = node.getStart()
            if binstart == node_bin:
                node.add(start, end, name)
                return
            if binstart < node_bin:
                if node.left is None:
                    node.left = Node(start=start, end=end, name=name, parent=node)
                    self.updateBalance(node.left)
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = Node(start, end, name, parent=node)
                    self.updateBalance(node.right)
                    return
                node = node.right

    def updateBalance(self, node):
        """Propagate a height change upward from *node*, rebalancing at need."""
        if node.balanceFactor > 1 or node.balanceFactor < -1:
            self.rebalance(node)
            return
        if node.parent is not None:
            if node.isLeftChild():
                node.parent.balanceFactor += 1
            elif node.isRightChild():
                node.parent.balanceFactor -= 1
            # A parent that became exactly balanced did not change height.
            if node.parent.balanceFactor != 0:
                self.updateBalance(node.parent)

    def rebalance(self, node):
        """Restore balance at *node* with a single or double rotation."""
        if node.balanceFactor < 0:
            if node.right.balanceFactor > 0:
                # Right-left case.
                self.rotateRight(node.right)
                self.rotateLeft(node)
            else:
                self.rotateLeft(node)
        elif node.balanceFactor > 0:
            if node.left.balanceFactor < 0:
                # Left-right case.
                self.rotateLeft(node.left)
                self.rotateRight(node)
            else:
                self.rotateRight(node)

    def rotateRight(self, oldRoot):
        """Rotate the subtree at *oldRoot* right; its left child becomes the subtree root."""
        newRoot = oldRoot.left
        oldRoot.left = newRoot.right
        if newRoot.right is not None:
            newRoot.right.parent = oldRoot
        newRoot.parent = oldRoot.parent
        if oldRoot.parent is None:
            self.root = newRoot
            newRoot.isroot = True
            # Clear the stale flag so later rotations here do not reset self.root.
            oldRoot.isroot = False
        elif oldRoot.isLeftChild():
            oldRoot.parent.left = newRoot
        else:
            oldRoot.parent.right = newRoot
        newRoot.right = oldRoot
        oldRoot.parent = newRoot
        oldRoot.balanceFactor = oldRoot.balanceFactor - 1 - max(newRoot.balanceFactor, 0)
        newRoot.balanceFactor = newRoot.balanceFactor - 1 + min(oldRoot.balanceFactor, 0)

    def rotateLeft(self, oldRoot):
        """Rotate the subtree at *oldRoot* left; its right child becomes the subtree root."""
        newRoot = oldRoot.right
        oldRoot.right = newRoot.left
        if newRoot.left is not None:
            newRoot.left.parent = oldRoot
        newRoot.parent = oldRoot.parent
        if oldRoot.parent is None:
            self.root = newRoot
            newRoot.isroot = True
            # Clear the stale flag so later rotations here do not reset self.root.
            oldRoot.isroot = False
        elif oldRoot.isLeftChild():
            oldRoot.parent.left = newRoot
        else:
            oldRoot.parent.right = newRoot
        newRoot.left = oldRoot
        oldRoot.parent = newRoot
        oldRoot.balanceFactor = oldRoot.balanceFactor + 1 - min(newRoot.balanceFactor, 0)
        # Mirror of rotateRight: the new root gains 1 + max(bf(oldRoot), 0).
        newRoot.balanceFactor = newRoot.balanceFactor + 1 + max(oldRoot.balanceFactor, 0)

    def lookup_r(self, start, end, node):
        """Range query over bin IDs [start, end] in the subtree at *node*.

        @returns (LBnode, RBnode): the node for bin *start* and the node for
                 bin *end*; either may be None when that bin has no node, and
                 RBnode is None when start == end.  Nodes for bins strictly
                 between start and end are not returned.
        """
        while node is not None:
            node_start = node.getStart()
            if end < node_start:
                node = node.left
            elif start > node_start:
                node = node.right
            else:
                break
        if node is None:
            return (None, None)
        # start <= node_start <= end: the start bin is this node or in its left
        # subtree, the end bin is this node or in its right subtree.
        if start == node_start and end == node_start:
            return (node, None)
        if start == node_start:
            return (node, self.lookup_p(end, node.right))
        if end == node_start:
            return (self.lookup_p(start, node.left), node)
        return (self.lookup_p(start, node.left), self.lookup_p(end, node.right))

    def lookup_p(self, start, node):
        """Return the node whose bin ID equals *start* in the subtree at *node*, or None."""
        while node is not None:
            node_bin = node.getStart()
            if start < node_bin:
                node = node.left
            elif start > node_bin:
                node = node.right
            else:
                return node
        return None
        
       
class TEfeatures:
    """Index of TE annotations read from a GTF file.

    Each TE record gets a name index: its position in ``_nameIDmap`` and
    ``_length``.  For each chromosome, ``indexlist`` holds a BinaryTree into
    which the TE is inserted once per bin it touches.
    """

    def __init__(self):
        self.indexlist = {}    # chrom -> BinaryTree
        self._length = []      # name index -> TE length (end - start + 1)
        self._nameIDmap = []   # name index -> "transcript_id:gene_id:family_id:class_id"

    @staticmethod
    def _bin_id(pos):
        """Return the bin ID of *pos*; bin k covers (k*B, (k+1)*B], B = TEindex_BINSIZE."""
        bin_id = pos // TEindex_BINSIZE
        # A position lying exactly on a bin boundary closes the previous bin.
        if pos == bin_id * TEindex_BINSIZE:
            bin_id -= 1
        return bin_id

    @staticmethod
    def _parse_attributes(field):
        """Parse a GTF attribute column into a {key: value} dict (last key wins)."""
        attrs = {}
        for item in field.replace("; ", ";").split(';'):
            item = item.replace("\"", "")
            pos = item.find(" ")
            attrs[item[:pos]] = item[pos + 1:]
        return attrs

    @staticmethod
    def _format_error(line):
        """Report a malformed GTF *line* on stderr and raise ValueError."""
        sys.stderr.write(line + "\n")
        sys.stderr.write("TE GTF format error!\n")
        raise ValueError("TE GTF format error: " + line)

    def _insert_binned(self, index, start, end, name_idx):
        """Insert TE *name_idx* spanning [start, end] into *index*, one segment per bin."""
        bin_id = self._bin_id(start)
        # Use the same boundary rule for the end so that no empty (start > end)
        # segment is created when *end* falls exactly on a bin boundary.
        last_bin = self._bin_id(end)
        while bin_id <= last_bin:
            seg_start = max(start, bin_id * TEindex_BINSIZE + 1)
            seg_end = min(end, (bin_id + 1) * TEindex_BINSIZE)
            index.insert(seg_start, seg_end, name_idx)
            bin_id += 1

    def getNames(self):
        """Return the (live) list of full TE names, indexed by name index."""
        return self._nameIDmap

    def numTEs(self):
        """Return the number of TE records loaded."""
        return len(self._nameIDmap)

    def getEleName(self, idx):
        """Return the full name of TE *idx* minus its first ':' field, or None."""
        full_name = self.getFullName(idx)
        if full_name is None:
            return None
        return full_name[full_name.find(':') + 1:]

    def getFullName(self, idx):
        """Return the full "transcript:gene:family:class" name of TE *idx*, or None."""
        if idx >= len(self._nameIDmap) or idx < 0:
            return None
        return self._nameIDmap[idx]

    def getLength(self, TE_name_idx):
        """Return the length of TE *TE_name_idx*, or -1 if the index is out of range."""
        # Reject negative indices too; Python would otherwise index from the end.
        if 0 <= TE_name_idx < len(self._length):
            return self._length[TE_name_idx]
        return -1

    def getFamilyID(self, chr, start, end):
        """Return the family_id of the first TE overlapping [start, end] on *chr*.

        @returns the family_id string, or None when nothing overlaps or the
                 chromosome is not indexed
        """
        name_idx_list = self.findOvpTE(chr, start, end)
        if not name_idx_list:
            return None
        # Full names are "transcript_id:gene_id:family_id:class_id".
        return self.getFullName(name_idx_list[0]).split(':')[2]

    def findOvpTE(self, chrom, start, end):
        """Return the name indices of TE segments overlapping [start, end] on *chrom*.

        Every bin touched by the interval is searched, so intervals of any
        length are handled.  A TE is listed once per overlapping bin segment.

        @returns list of name indices, or None if *chrom* is not indexed
        """
        index = self.indexlist.get(chrom)
        if index is None:
            return None
        name_idx_list = []
        for bin_id in range(self._bin_id(start), self._bin_id(end) + 1):
            node = index.lookup_p(bin_id, index.root)
            if node is not None:
                name_idx_list.extend(node.overlaps(start, end))
        return name_idx_list

    def build(self, filename, te_mode):
        """Read TE annotations from the GTF *filename* and add them to the index.

        Blank lines and '#' comment lines are skipped.  Each record must carry
        gene_id, transcript_id, family_id and class_id attributes.

        @param filename path of the TE GTF file
        @param te_mode  accepted for interface compatibility; not used here
        @raises ValueError on a malformed record (the offending line and
                "TE GTF format error!" are also written to stderr)
        Exits the process with status 1 if the file cannot be opened.
        """
        self.__srcfile = filename
        try:
            f = open(self.__srcfile, 'r')
        except (IOError, OSError):
            logging.error("cannot open such file %s !\n" % (self.__srcfile))
            sys.exit(1)

        with f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                items = line.split('\t')
                if len(items) < 9:
                    self._format_error(line)
                chrom = items[0]
                start = int(items[3])
                end = int(items[4])

                attrs = self._parse_attributes(items[8])
                ele_id = attrs.get("gene_id", "")
                name = attrs.get("transcript_id", "")
                family_id = attrs.get("family_id", "")
                class_id = attrs.get("class_id", "")
                if not (ele_id and name and family_id and class_id):
                    self._format_error(line)

                # Continue numbering after any previously built records.
                name_idx = len(self._nameIDmap)
                self._length.append(end - start + 1)
                self._nameIDmap.append(name + ':' + ele_id + ':' + family_id + ':' + class_id)

                index = self.indexlist.get(chrom)
                if index is None:
                    index = BinaryTree()
                    self.indexlist[chrom] = index
                self._insert_binned(index, start, end, name_idx)
 
def main(gtf_file="/home/yjin/Workspace/DiffChip/DiffChip2/TEToolkit/te.gtf"):
    """Developer smoke test: build a TE index and time repeated queries.

    Builds the index from *gtf_file* (defaults to the original developer path),
    prints wall-clock timestamps around 10000 rounds of five findOvpTE queries
    on chr2L, then prints the TE name indices returned for two sample intervals.
    """
    te = TEfeatures()
    te.build(gtf_file, "multi")

    print(time.time())
    for _ in range(10000):
        telist = te.findOvpTE('chr2L', 9800, 9900)
        telist = te.findOvpTE('chr2L', 10800, 10900)
        telist = te.findOvpTE('chr2L', 130800, 130900)
        telist = te.findOvpTE('chr2L', 120800, 120900)
        telist = te.findOvpTE('chr2L', 110800, 110900)
    print(time.time())

    # findOvpTE returns None when the chromosome is absent from the index.
    for t in telist or []:
        print(t)
    print("-------------------")
    telist = te.findOvpTE('chr2L', 10800, 10900)
    for t in telist or []:
        print(t)
 
if __name__ == '__main__':
    # Run the developer benchmark; Ctrl-C ends it quietly with exit status 0.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupt !\n")
        sys.exit(0)
