#! /bin/env python
'''
A script for creating reference fasta file

It takes a meta.in file, downloads the necessary files from the given URLs,
and puts them into a single reference fasta file with proper description lines.

The input file, which is in csv format, requires two columns:
sequence (chromosome) name and url for the sequence.

==== Example of input file ====
chrY,http://hgdownload.cse.ucsc.edu/goldenPath/mm9/chromosomes/chrY.fa.gz
==== End of file ====

Created on Aug 21, 2013

@author: Shunping Huang
'''

from __future__ import print_function

import sys
import os.path
import gzip

try:
    # For Python 3
    from urllib.request import urlopen
except ImportError:
    # Fall back to Python 2's urllib2
    from urllib2 import urlopen

import pysam  # for faidx
import argparse as ap  # for argument passing
from modtools.fareader import FaReader
from modtools.alias import Alias
from modtools.utils import *

# Program description passed to the argparse parser in the __main__ block.
DESC = 'A FASTA downloader and generator for reference genome'

# (sequence name, local fasta path) pairs; filled by prepare() and
# consumed by write_fasta().
pool = []
# Alias object; assigned in the __main__ block and read by write_fasta().
alias = None
# Directory for downloaded/decompressed files; assigned in the __main__
# block and read by prepare().
tmp_dir = None


def download_from_url(url, target, block_sz=128 * 1024):
    '''Download a file from a given url

    The content behind `url` is streamed to the file `target` in chunks of
    `block_sz` bytes while a progress line is printed to stdout.

    Arguments:
        url      -- the address to fetch (anything urlopen accepts)
        target   -- path of the local file to (over)write
        block_sz -- number of bytes requested per read (default: 128 KiB)

    Any exception raised by urlopen, by reading the response or by writing
    the file is propagated.  In that case the partially written `target`
    is removed, so that an incomplete download is never mistaken for a
    finished one by a later "file exists" check.
    '''

    print("Connecting '%s' ..." % url)
    sys.stdout.flush()

    u = urlopen(url)
    partial = False  # True while `target` exists but is incomplete
    try:
        # The header object offers a case-insensitive get() on both
        # Python 2 and Python 3, so no manual header parsing is needed.
        try:
            file_size = int(u.info().get("Content-Length"))
        except (TypeError, ValueError):
            # Header missing or malformed: download without percentages.
            file_size = None

        print("Downloading: %s Bytes: %s" %
              (target, "unknown" if file_size is None else file_size))
        sys.stdout.flush()

        prev_progress = -1
        file_size_dl = 0
        with open(target, 'wb') as f:
            partial = True
            while True:
                chunk = u.read(block_sz)
                if not chunk:
                    break

                file_size_dl += len(chunk)
                f.write(chunk)

                if file_size:
                    # Integer percentage, so the status line is refreshed
                    # at most once per percent (also on Python 3, where
                    # '/' would yield a float that grows on every block).
                    progress = file_size_dl * 100 // file_size
                    if progress <= prev_progress:
                        continue
                    prev_progress = progress
                    status = r"%10d  [%3.0f%%]" % (file_size_dl, progress)
                else:
                    status = r"%10d" % file_size_dl

                status = status + chr(8) * (len(status) + 1)
                print(status, end="\r")
                sys.stdout.flush()
        partial = False
    finally:
        u.close()
        if partial and os.path.isfile(target):
            os.remove(target)

    print("")
    print("Download Completed.")
    sys.stdout.flush()


def gz_decompress(in_fn, out_fn):
    '''Decompress a gzipped file

    Arguments:
        in_fn  -- path of the gzip archive; its name must end with '.gz'
        out_fn -- path of the decompressed file to (over)write

    Raises ValueError if `in_fn` does not end with '.gz'.  Errors raised
    while reading the archive or writing the output are propagated; in
    that case the partially written `out_fn` is removed so that a later
    "file exists" check does not pick up a truncated file.
    '''

    # Explicit check instead of assert: asserts are stripped under -O.
    if not in_fn.endswith('.gz'):
        raise ValueError('Error: file extension not .gz.')

    print("Decompressing: %s to %s ..." % (in_fn, out_fn))
    sys.stdout.flush()

    partial = False  # True while `out_fn` exists but is incomplete
    try:
        with gzip.open(in_fn, 'rb') as in_fp:
            with open(out_fn, 'wb') as out_fp:
                partial = True
                # Copy in fixed-size chunks; the data is binary, so there
                # is no reason to split it on newline characters.
                for chunk in iter(lambda: in_fp.read(1024 * 1024), b''):
                    out_fp.write(chunk)
        partial = False
    finally:
        if partial and os.path.isfile(out_fn):
            os.remove(out_fn)

    print("Decompression Completed.")
    sys.stdout.flush()


def prepare(list_fn):
    '''Read a list and prepare (download, decompress) necessary files

    Every non-blank line of `list_fn` must hold exactly two comma-separated
    fields: the sequence (chromosome) name and the url of its fasta file.
    Whitespace around the fields is ignored and blank lines are skipped.

    For each entry, '<tmp_dir>/<name>.fa' is used if it already exists.
    Otherwise the file named by the last path component of the url is
    downloaded into tmp_dir (unless already present) and, if it ends with
    '.gz', decompressed next to it (unless already present).  The pair
    (name, local path) is appended to the module-level `pool`.

    Raises ValueError on a malformed line, an unsupported or incomplete
    url, or an empty sequence name.
    '''

    global pool
    print("Using temporary directory '%s'." % tmp_dir)
    with open(list_fn, 'r') as fp:
        for nLines, line in enumerate(fp, 1):
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. an empty line at end of file.
                continue

            parts = [field.strip() for field in line.split(',')]
            if len(parts) != 2:
                print("Number of columns should be two")
                raise ValueError("Error occured when parsing line %d in '%s'" %
                                 (nLines, list_fn))

            chrom, url = parts
            if not url.startswith(('http://', 'https://', 'ftp://')):
                raise ValueError("Error: url field error in Line %d." % nLines)

            if chrom == "":
                raise ValueError("Empty chromosome field in Line %d." %
                                 nLines)

            chrom_fn = tmp_dir + '/' + chrom + '.fa'
            if os.path.isfile(chrom_fn):
                print("File '%s' exists. Skipped." % chrom_fn)
            else:
                url_basename = url.split('/')[-1]
                if not url_basename:
                    # A url ending in '/' names no file to store locally.
                    raise ValueError("Error: url field error in Line %d." %
                                     nLines)

                downloaded_fn = tmp_dir + '/' + url_basename
                print("File '%s' not found. Try to download '%s'." %
                      (chrom_fn, downloaded_fn))
                if os.path.isfile(downloaded_fn):
                    print("File '%s' exists. Skipped." % downloaded_fn)
                else:
                    download_from_url(url, downloaded_fn)

                if downloaded_fn.endswith('.gz'):
                    decompressed_fn = downloaded_fn[:-3]
                    if not os.path.isfile(decompressed_fn):
                        gz_decompress(downloaded_fn, decompressed_fn)
                    chrom_fn = decompressed_fn
                else:
                    chrom_fn = downloaded_fn

            pool.append((chrom, chrom_fn))


def write_fasta(output_fn):
    '''Write output to a fasta file

    Every (name, path) pair in the module-level `pool` contributes one
    record to `output_fn`: a '>name' description line followed by the
    upper-cased sequence rows copied from the file at `path`.  When that
    file holds more than one sequence, the names returned by
    alias.getAliases(name) are tried in order and the first one that
    FaReader can locate is used; otherwise the single sequence is taken
    as is.  A ValueError is raised when no candidate is found.  Finally
    a fasta index is built for `output_fn` with pysam.faidx.
    '''

    print("Writing fasta file '%s' ..." % output_fn)
    with open(output_fn, 'w') as sink:
        for chrom, chrom_fn in pool:
            print("Processing '%s' in '%s' ..." % (chrom, chrom_fn))
            n_bases = 0
            reader = FaReader(chrom_fn)

            with open(chrom_fn, 'r') as source:
                if len(reader.chrom_names) <= 1:
                    names = reader.chrom_names
                    print("Only one sequence found. Using it as a shortcut.")
                else:
                    names = alias.getAliases(chrom)
                print("Using name candidates: %s" % ','.join(names))

                print("Writing '%s' ..." % (chrom))
                sys.stdout.flush()

                # Stop at the first candidate name that can be located.
                found = False
                for name in names:
                    try:
                        source.seek(reader.chrom_offset(name))
                        sink.write('>%s\n' % chrom)
                        for row in source:
                            if row.startswith('>'):
                                # Next record begins: sequence is done.
                                break
                            sink.write(row.upper())
                            # Count residues only, not the line terminator.
                            if row.endswith('\r\n'):
                                terminator = 2
                            elif row.endswith('\n'):
                                terminator = 1
                            else:
                                terminator = 0
                            n_bases += len(row) - terminator
                        found = True
                        break
                    except ValueError:
                        continue

                if not found:
                    raise ValueError("Sequence '%s' not found in '%s'." %
                                     (chrom, chrom_fn) +
                                     "You may need to specify an alias.")

            print("Length of '%s': %d" % (chrom, n_bases))
            sys.stdout.flush()

    print("Fasta output Finished.")

    print("Building fasta index ...")
    pysam.faidx(output_fn)
    print("Fasta index Finished.")


if __name__ == '__main__':
    # Command-line driver: parse the options, validate the paths, then
    # fetch the sequences and assemble the reference fasta.
    #
    # Usage:
    # refmaker [-f][-t tmp_dir][-a alias.csv][-i list.in] ref.fa

    p = ap.ArgumentParser(description=DESC,
                          formatter_class=ap.RawTextHelpFormatter)

    p.add_argument("-f", dest='force', action='store_true',
                   help='overwrite existing fasta output')

    p.add_argument('-t', metavar='tmp_dir', dest='tmp_dir',
                   default='./tmp',
                   help='path for storing temparary files'
                   ' (default: ./tmp)')

    p.add_argument('-a', metavar='alias.csv', dest='alias_fn',
                   default=None,
                   help='the csv file for alias classes of sequence name'
                   ' (default: None)')

    p.add_argument('-i', metavar='list.in', dest='list_fn',
                   default="list.in",
                   help='the input list file (default: list.in)')

    p.add_argument('output_fn', metavar='ref.fa',
                   help='the output reference in fasta format')

    args = p.parse_args()

    if args.alias_fn is not None:
        is_file_readable(args.alias_fn)

    alias = Alias()
    try:
        alias.load(args.alias_fn)
    except Exception as e:
        # Loading stays best-effort, but a failure on a file the user
        # explicitly supplied is reported instead of silently ignored.
        # Catching Exception (not a bare except) lets KeyboardInterrupt
        # and SystemExit propagate.
        if args.alias_fn is not None:
            print("Warning: cannot load alias file '%s': %s" %
                  (args.alias_fn, e), file=sys.stderr)

    is_file_readable(args.list_fn)
    is_file_writable(args.output_fn, args.force)

    # Drop trailing slashes; prepare() joins paths with its own '/'.
    tmp_dir = args.tmp_dir.rstrip('/')
    is_dir_writable(tmp_dir, True)

    prepare(args.list_fn)
    write_fasta(args.output_fn)
