#!/usr/bin/env python
'''
A script for mapping positions based on a MOD file.


Created on Sep 8, 2013

@author: Shunping Huang
'''

from __future__ import print_function

import os
import gc
import argparse as ap  # for argument passing

from modtools.mod import Mod
from modtools.alias import Alias
from modtools.utils import *

# Module-level Mod instance: assigned by the script entry point at the bottom
# of this file and read by process() to look up per-chromosome position maps.
mod = None
# Placeholder for a sequence-name alias table; nothing in this script assigns
# it at the moment (the alias option is commented out in the entry point).
alias = None
# Description shown at the top of the command-line help.
DESC = 'Forward (backward) mapping positions based on a MOD file.'


def _parse_column_pairs(raw_pairs):
    '''
    Convert 'chrom_col,pos_col' strings into 0-based (chrom, pos) index tuples.

    Raises ValueError when a pair is empty, does not consist of exactly two
    comma-separated integers, or uses a column id smaller than 1 (ids are
    1-based on the command line).
    '''
    parsed = []
    for raw in raw_pairs:
        if not raw:
            raise ValueError("Error occured in parsing pairs '%s'" % raw)
        parts = raw.split(',')
        if len(parts) != 2:
            raise ValueError("Wrong pair format '%s'" % raw)
        chrom_col, pos_col = int(parts[0]), int(parts[1])
        # Reject ids < 1: after the shift below they would turn into negative
        # indices and silently address columns from the end of the row.
        if chrom_col < 1 or pos_col < 1:
            raise ValueError("Column ids must be >= 1 in pair '%s'" % raw)
        parsed.append((chrom_col - 1, pos_col - 1))  # 1-based -> 0-based
    return parsed


def process(args):
    '''
    Map the position columns of a delimited text file and write the result.

    Reads args.posin_fn line by line, splits each non-blank line on
    args.delim and, for every (chrom_col, pos_col) pair in args.pairs,
    replaces the position field with the value mapped through the
    module-level `mod` object (backward mapping when args.is_bmap is true,
    forward mapping otherwise). Each processed line is written to
    args.posout_fn followed by a newline; blank lines are skipped.

    Raises ValueError for malformed column pairs (before any file is opened)
    and for lines with fewer columns than a pair requires. Any exception
    raised while processing a line is reported with its line number and then
    re-raised.
    '''
    print("Processing file '%s' ..." % args.posin_fn)
    # Validate the pairs first so a bad argument cannot truncate the output.
    pairs = _parse_column_pairs(args.pairs)

    sep = args.delim
    is_bmap = args.is_bmap

    # When the delimiter itself contains whitespace (e.g. a tab), a plain
    # strip() would eat empty leading/trailing columns and shift the column
    # indices. In that case only trim whitespace that is not part of the
    # delimiter; otherwise keep the plain strip().
    if any(ch.isspace() for ch in sep):
        trimmable = ''.join(ch for ch in ' \t\n\r\x0b\x0c' if ch not in sep)
    else:
        trimmable = None

    gc_was_enabled = gc.isenabled()
    with open(args.posin_fn, 'r') as in_fp:
        with open(args.posout_fn, 'w') as out_fp:
            n = 0
            try:
                gc.disable()
                for line in in_fp:
                    n += 1
                    if not line.strip():
                        continue
                    if trimmable is None:
                        line = line.strip()
                    else:
                        line = line.rstrip('\r\n').strip(trimmable)
                    fields = line.split(sep)
                    n_cols = len(fields)
                    for chrom_col, pos_col in pairs:
                        if chrom_col >= n_cols or pos_col >= n_cols:
                            raise ValueError(
                                'Not enough columns at line: %s (%d,%d,%d)'
                                % (line, chrom_col, pos_col, n_cols))
                        chrom = fields[chrom_col]
                        pos = int(fields[pos_col])
                        posmap = mod.get_posmap(chrom)
                        mapper = posmap.bmap if is_bmap else posmap.fmap
                        # Element 1 of the mapper's result is used as the
                        # new position.
                        fields[pos_col] = str(mapper((chrom, pos))[1])

                    out_fp.write(sep.join(fields))
                    out_fp.write("\n")
                    if n % 100000 == 0:
                        # Progress report; the collector is switched back on
                        # only around the print.
                        gc.enable()
                        print(n)
                        gc.disable()
            except Exception as e:
                print("Error occured at line %d: %s" % (n, e))
                raise  # bare raise keeps the original traceback
            finally:
                # Do not leave the collector disabled for the caller.
                if gc_was_enabled:
                    gc.enable()

    print("All Done!")


if __name__ == '__main__':
    # Script entry point: parse the command line, check the input/output
    # files, load the MOD file and run the position mapping.
    #
    # Usage:
    # modmap [-f] [-b] [-d delimiter] in.mod pos_in.txt pos_out.txt
    #        chrom_col,pos_col [chrom_col,pos_col ...]
    p = ap.ArgumentParser(description=DESC,
                          formatter_class=ap.RawTextHelpFormatter)

    p.add_argument("-f", dest='force', action='store_true',
                   help='overwrite existing output')

    p.add_argument("-b", dest='is_bmap', action='store_true',
                   help='activate backward mapping (destination -> source)')

    # Alias support is currently disabled.
    #p.add_argument('-a', metavar='alias.csv', dest='alias_fn',
    #               default=None,
    #               help='the csv file for alias classes of sequence name'
    #               ' (default: None)')

    p.add_argument('mod_fn', metavar='in.mod',
                   help='the input MOD file')

    p.add_argument('-d', metavar='delimiter', dest='delim',
                   default=',',
                   help='the delimiter in the text input for position mapping'
                   ' (default: \',\')')

    p.add_argument('posin_fn', metavar='pos_in.txt',
                   help='the text input before mapping positions')

    p.add_argument('posout_fn', metavar='pos_out.txt',
                   help='the text output after mapping positions')

    p.add_argument('pairs', metavar='chrom_col,pos_col',
                   nargs='+',
                   help='a pair of column ids for chromosome and position'
                   ' in the input')
    args = p.parse_args()

    # Handle tab delimiter: a literal backslash-t typed on the command line
    # is turned into a real tab character.
    args.delim = args.delim.replace('\\t', '\t')

    # str.split() rejects an empty separator, so fail early with a usage
    # message instead of erroring on the first data line.
    if not args.delim:
        p.error('the delimiter (-d) must not be empty')

    #if args.alias_fn is not None:
    #    is_file_readable(args.alias_fn)

    #alias = Alias()
    #try:
    #    alias.load(args.alias_fn)
    #except:
    #    pass

    is_file_readable(args.posin_fn)
    is_file_writable(args.posout_fn, args.force)

    # process() reads this module-level name.
    mod = Mod(args.mod_fn)

    process(args)
