#!/usr/bin/env python
# -*- coding: utf-8 -*-

import random
import os
import itertools
import urllib
import urlparse
import cPickle as pickle
import sys

from Bio             import SeqIO
from Bio.SeqFeature  import SeqFeature
from Bio.SeqFeature  import FeatureLocation

import percache
import appdirs

class saccharomyces_cerevisiae_genome():
    '''
    Gene, promoter and terminator lookup in the Saccharomyces cerevisiae
    S288C reference sequence.

    The sequence is read from sixteen GenBank files (chr01.gb .. chr16.gb)
    kept in a per-user data directory. Missing files can be downloaded
    from base_url after an interactive confirmation. The lookups
    decorated with @cache go through the percache cache stored in the
    same directory.
    '''

    data_dir = appdirs.user_data_dir("sgd_genome_data_files")

    # makedirs instead of mkdir: the parent of the user data directory
    # may not exist yet on a fresh account.
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)

    cache = percache.Cache( os.path.join( data_dir, "sgd-cache" ))

    def __init__(self):
        '''
        Check that the chromosome files are present (offering to download
        the missing ones) and load, or rebuild, the gene name indexes.

        Sets:
            feature_list -- locus tags of all CDS features, ordered by
                            chromosome file name and then by position
                            in the file
            gene_to_syst -- dict mapping the "gene" qualifier to the
                            locus tag, for CDS features that have one
        '''
        self.data_dir = saccharomyces_cerevisiae_genome.data_dir

        self.base_url = "http://downloads.yeastgenome.org/sequence/S288C_reference/NCBI_genome_source/"

        self.chromosome_files = {   "A":"chr01.gb", "B":"chr02.gb",
                                    "C":"chr03.gb", "D":"chr04.gb",
                                    "E":"chr05.gb", "F":"chr06.gb",
                                    "G":"chr07.gb", "H":"chr08.gb",
                                    "I":"chr09.gb", "J":"chr10.gb",
                                    "K":"chr11.gb", "L":"chr12.gb",
                                    "M":"chr13.gb", "N":"chr14.gb",
                                    "O":"chr15.gb", "P":"chr16.gb",}

        # Previously only created when the indexes had to be rebuilt.
        self._cds = {}

        missing_files = []

        for file_ in self._ordered_files():
            if not os.path.exists(os.path.join(self.data_dir, file_)):
                print("data file {} is missing".format(file_))
                missing_files.append(file_)

        if missing_files:
            self.download(missing_files)

        feature_path = os.path.join(self.data_dir, "feature_list.p")
        mapping_path = os.path.join(self.data_dir, "gene_to_syst.p")

        # NOTE: unpickling executes whatever the pickle contains; these
        # files are expected to be the ones written below by this class.
        try:
            with open(feature_path, "rb") as handle:
                self.feature_list = pickle.load(handle)
            with open(mapping_path, "rb") as handle:
                self.gene_to_syst = pickle.load(handle)
        except Exception:
            # Index files absent or unreadable: rebuild them from the
            # GenBank files. "except Exception" rather than a bare except
            # so that Ctrl-C is not mistaken for a missing index.
            self.feature_list = []
            self.gene_to_syst = {}
            for file_ in self._ordered_files():
                krom = SeqIO.read(os.path.join(self.data_dir, file_), "gb")
                for feature in krom.features:
                    if feature.type != "CDS":
                        continue
                    locus_tag = feature.qualifiers['locus_tag'][0]
                    self.feature_list.append(locus_tag)
                    if "gene" in feature.qualifiers:
                        self.gene_to_syst[feature.qualifiers['gene'][0]] = locus_tag

            with open(feature_path, "wb") as handle:
                pickle.dump(self.feature_list, handle, -1)
            with open(mapping_path, "wb") as handle:
                pickle.dump(self.gene_to_syst, handle, -1)

    def _ordered_files(self):
        '''
        Return the chromosome file names sorted by name (chr01.gb first).

        dict.values() has no guaranteed order, so every place that
        depends on chromosome order goes through this helper.
        '''
        return sorted(self.chromosome_files.values())

    def chromosome(self, id):
        '''
        Return one chromosome as read by SeqIO.read.

        id is either something int() accepts (1 selects chr01.gb) or a
        single letter "A".."P" / "a".."p". Returns None for a letter
        outside that range.
        '''
        # Only the conversion is guarded, so a ValueError raised while
        # parsing the file is no longer mistaken for "id is a letter".
        try:
            index = int(id) - 1
        except ValueError:
            index = None

        if index is not None:
            return SeqIO.read(os.path.join(self.data_dir, self._ordered_files()[index]), "gb")

        if 1 <= (ord(id.lower())-96) <= 16:
            return SeqIO.read(os.path.join(self.data_dir, self.chromosome_files[id.upper()]), "gb")
        return None

    def chromosomes(self):
        '''
        Return a generator yielding every chromosome, chr01.gb first.
        '''
        return ( SeqIO.read(os.path.join(self.data_dir, f), "gb")
                 for f in self._ordered_files() )

    def download(self, missing_files=None):
        '''
        Ask on the console whether to download chromosome files from
        base_url and, if the answer contains a "y", fetch them into
        data_dir.

        missing_files -- file names to fetch; all sixteen files when
                         None or empty.
        '''
        def reporthook(blocknum, blocksize, totalsize):
            # Progress callback for urlretrieve, written to stderr.
            readsofar = blocknum * blocksize
            if totalsize > 0:
                percent = readsofar * 1e2 / totalsize
                s = "\r%5.1f%% %*d / %d" % (
                    percent, len(str(totalsize)), readsofar, totalsize)
                sys.stderr.write(s)
                if readsofar >= totalsize: # near the end
                    sys.stderr.write("\n")
            else: # total size is unknown
                sys.stderr.write("read %d\n" % (readsofar,))

        print("do you want to download missing files from {}".format(self.base_url))
        answer = raw_input("yes/no <return>")

        if "y" not in answer.lower():
            return
        sys.stderr.write("\n")
        if not missing_files:
            missing_files = self.chromosome_files.values()
        for file_ in sorted(missing_files):
            sys.stderr.write("downloading {}\n".format(file_))
            urllib.urlretrieve( urlparse.urljoin(self.base_url, file_),
                                os.path.join(self.data_dir, file_),
                                reporthook = reporthook)
            sys.stderr.write("{} successfully downloaded\n\n".format(file_))

    def systematic_name(self, gene):
        '''
        Return the locus tag for gene (case-insensitive).

        A name that looks like a systematic name and is present in
        feature_list is returned as is (upper-cased); anything else is
        looked up in gene_to_syst.

        Raises Exception if the name is found in neither.
        '''
        import re
        gene = gene.upper()
        if re.match(r"Y[A-P](R|L)\d{3}(W|C)(-.)*", gene[:7]) and gene in self.feature_list:
            return gene
        try:
            return self.gene_to_syst[gene]
        except KeyError:
            raise Exception("gene {} does not exist".format(gene))

    def cds(self, gene):
        '''
        Return the locus of gene without any flanking sequence.
        '''
        return self.locus(gene, upstream=0, downstream=0)

    @cache
    def locus(self, gene, upstream=1000, downstream=1000):
        '''
        Return the CDS of gene with flanking sequence, by default
        1000 bp upstream and 1000 bp downstream.

        The record is reverse complemented unless the seventh character
        of the systematic name is "W", and the flanks are applied
        relative to the orientation of the returned record. The flank on
        the low-coordinate side is cut short at the start of the
        chromosome.

        NOTE(review): results computed before the flank fix may still be
        served from the sgd-cache file -- confirm whether it needs to be
        cleared.
        '''
        gene = self.systematic_name(gene)
        if not gene:
            return

        krom = SeqIO.read(os.path.join(self.data_dir, self.chromosome_files[gene[1]]), "gb")

        cds = {f.qualifiers['locus_tag'][0]: f for f in krom.features if f.type=="CDS"}
        feature = cds[gene]

        # Integer channels: "%x" is not meant for the floats that
        # random.uniform returns.
        color = '#%02x%02x%02x' % (random.randint(150, 255),
                                   random.randint(150, 255),
                                   random.randint(150, 255),)
        # Set before slicing so that the sliced record carries the colour.
        feature.qualifiers.update({"ApEinfo_fwdcolor" : color,
                                   "ApEinfo_revcolor" : color,
                                   })

        start, stop = int(feature.location.start), int(feature.location.end)
        watson = gene[6] == "W"

        # Coordinates are on the forward strand; for a reverse-strand gene
        # the upstream flank therefore lies after "stop".
        if watson:
            before, after = upstream, downstream
        else:
            before, after = downstream, upstream

        # max(): a negative slice start would count from the chromosome end.
        lcs = krom[max(0, start-before):stop+after]

        if watson:
            return lcs
        return lcs.reverse_complement()

    def _neighbour(self, gene, offset):
        '''
        Return the feature_list entry offset positions away from gene.

        Raises IndexError past either end of the list; a plain negative
        index would silently wrap around to the other end.
        '''
        position = self.feature_list.index(gene) + offset
        if not 0 <= position < len(self.feature_list):
            raise IndexError("gene {} has no neighbouring gene in that direction".format(gene))
        return self.feature_list[position]

    def upstream_gene(self, gene):
        '''
        Return the locus tag of the gene next to gene on its upstream
        side: the previous feature_list entry for a "W" gene, the next
        one otherwise.
        '''
        gene = self.systematic_name(gene)
        return self._neighbour(gene, -1 if gene[6]=="W" else 1)

    def downstream_gene(self, gene):
        '''
        Return the locus tag of the gene next to gene on its downstream
        side: the previous feature_list entry for a "C" gene, the next
        one otherwise.
        '''
        gene = self.systematic_name(gene)
        return self._neighbour(gene, -1 if gene[6]=="C" else 1)

    @cache
    def promoter(self, gene):
        '''
        Return the sequence between the upstream gene and gene, with a
        "promoter" feature spanning all of it.

        Returns None when intergenic_sequence gives no sequence.
        '''
        gene = self.systematic_name(gene)
        if not gene:
            return
        upstream_gene = self.upstream_gene(gene)
        pr = self.intergenic_sequence(upstream_gene, gene)
        if pr is None:
            return
        pr.features.append(SeqFeature(FeatureLocation(0, len(pr)),
                                      type = "promoter",
                                      strand = 1,
                                      qualifiers = {"note"              : "tp {} {}".format(upstream_gene,gene),
                                                    "ApEinfo_fwdcolor": "#b1e6cc",
                                                    "ApEinfo_revcolor": "#b1e681",
                                                    }))
        return pr

    def tandem(self, gene):
        '''
        Return True if gene and its upstream gene carry the same strand
        letter (seventh character of the systematic name).
        '''
        return self.systematic_name(gene)[6] == self.systematic_name(self.upstream_gene(gene))[6]

    def bidirectional(self, gene):
        '''
        Return the negation of tandem(gene).
        '''
        return not self.tandem(gene)

    @cache
    def terminator(self, gene):
        '''
        Return the sequence between gene and the downstream gene, with a
        "terminator" feature spanning all of it.

        Returns None when intergenic_sequence gives no sequence.
        '''
        gene = self.systematic_name(gene)
        if not gene:
            return
        downstream_gene = self.downstream_gene(gene)
        tm = self.intergenic_sequence(gene, downstream_gene)
        if tm is None:
            return
        tm.features.append(SeqFeature(FeatureLocation(0,len(tm)),
                                      type = "terminator",strand = 1,
                                      qualifiers = {"note": "tp {} {}".format(gene,downstream_gene),
                                                    "ApEinfo_fwdcolor": "#b1e6cc",
                                                    "ApEinfo_revcolor": "#b1e681",
                                                    }))

        return tm

    @cache
    def intergenic_sequence(self, upgene, dngene):
        '''
        Return the sequence between the CDS features of upgene and
        dngene, oriented so that it reads from upgene towards dngene.

        Returns None unless both genes are on the same chromosome
        (second character of the systematic name). The assert fails if
        the two features overlap.
        '''
        upgene = self.systematic_name(upgene)
        dngene = self.systematic_name(dngene)
        # The whole conjunction has to be negated; "not upgene and ..."
        # only negated the first operand, so the guard never fired.
        if not (upgene and dngene and upgene[1] == dngene[1]):
            return
        krom = SeqIO.read(os.path.join(self.data_dir, self.chromosome_files[upgene[1]]), "gb")
        cds  = {f.qualifiers['locus_tag'][0]: f for f in krom.features if f.type=="CDS"}
        upfeature = cds[upgene]
        startup, stopup  = upfeature.location.start, upfeature.location.end
        dnfeature = cds[dngene]
        startdn, stopdn = dnfeature.location.start, dnfeature.location.end

        # Sorting all four coordinates must give the same list as putting
        # the two sorted intervals one after the other.
        assert sorted( (startup, stopup, startdn, stopdn) ) == list(itertools.chain.from_iterable(sorted( [sorted((startup,stopup)),sorted((startdn, stopdn))] )))

        # The closest pair of end points (one from each feature) bounds
        # the gap between the two features.
        _, a, b = min([ (abs(a-b), a, b) for a, b in itertools.product((startup,stopup),(startdn,stopdn))])

        if a<b:
            return krom[a:b]
        return krom[b:a].reverse_complement()



if __name__=="__main__":

    # Self-test run against the chromosome files in the data directory.
    sc = saccharomyces_cerevisiae_genome()

    # tandem() and bidirectional() must contradict each other.
    for name in ("tpi1", "gal1"):
        assert sc.tandem(name) == (not sc.bidirectional(name))


    from pydna_helper import ape

    print(sc.locus("TPI1").format("gb"))

    print(sc.locus("fun26"))
    print(sc.cds("fun26"))
    print(sc.upstream_gene("fun26"))
    print(sc.systematic_name("fun26"))
    print(sc.downstream_gene("fun26"))
    print(sc.intergenic_sequence("YAL021C","YAL023C"))

    def text(record):
        # Sequence of a record as a plain string.
        return str(record.seq)

    def rc_text(record):
        # Reverse complement of a record's sequence as a plain string.
        return str(record.seq.reverse_complement())

    # Promoters and terminators lie inside the default locus.
    for region, name in ((sc.promoter,   "dep1"),
                         (sc.terminator, "dep1"),
                         (sc.promoter,   "cys3"),
                         (sc.terminator, "cys3"),
                         (sc.promoter,   "swc3"),
                         (sc.terminator, "swc3"),
                         (sc.promoter,   "fun26"),
                         (sc.terminator, "fun26"),
                         (sc.promoter,   "pmt2"),
                         (sc.terminator, "lte1"),):
        assert text(region(name)) in text(sc.locus(name))

    # promoter-promoter
    assert text(sc.promoter("dep1")) == rc_text(sc.promoter("syn8"))

    # terminator-terminator
    assert text(sc.terminator("fun14")) == rc_text(sc.terminator("erp2"))

    # promoter-promoter
    assert text(sc.promoter("spo7")) == rc_text(sc.promoter("mdm10"))

    # terminator-promoter
    assert text(sc.promoter("cys3")) == text(sc.terminator("dep1"))

    # terminator-promoter
    assert text(sc.promoter("ccr4")) == text(sc.terminator("ats1"))

    # Same containment check with wider flanks.
    for region, name, flank in ((sc.promoter,   "CLN3", 2000),
                                (sc.terminator, "CLN3", 2000),
                                (sc.terminator, "cyc3", 2000),
                                (sc.promoter,   "cyc3", 2500),):
        assert text(region(name)) in text(sc.locus(name, flank, flank))

    assert text(sc.terminator("cyc3")) == text(sc.promoter("CLN3"))

    assert text(sc.promoter("jen1")) == rc_text(sc.promoter("sry1"))
    assert text(sc.promoter("osm1")) == text(sc.terminator("isy1"))

    for region, name in ((sc.terminator, "cyc1"),
                         (sc.terminator, "utr1"),
                         (sc.promoter,   "cdc24"),
                         (sc.terminator, "cdc24"),):
        assert text(region(name)) in text(sc.locus(name))
