#!/usr/bin/env python
# coding: utf-8
"""Chanjo CLI
Central CLI for various Chanjo functionality.
* "annotate": Process all or a subset of genes, get coverage from a BAM-file and commit changes to a SQLite database.
* "read": Quick access getting coverage metrics from an interval in a BAM-file.
* "peak": Peak into a SQLite database to look at coverage for a given gene.

Usage:
  chanjo annotate <store> using <source> [<gene>... | --read=<file>] [--ccds=<file>] [--cutoff=<int>] [--sample=<str>] [--group=<int>] [--json=<file>] [--verbose] [-p | --print] [--force]
  chanjo read <chrom> <start> <end> using <source> [--cutoff=<int>] [--verbose]
  chanjo peak <store> <gene>... [--verbose] [--sample=<str>] [--group=<int>]
  chanjo build <store> using <ccds_path> [--force]
  chanjo <store> import <txt_path>...
  chanjo -h | --help
  chanjo --version

Arguments:
  <store>  Path to new or existing SQLite database.
  <source>  Path to the BAM alignment file.
  <ccds_path> Path to CCDS database dump (text-file).
  <gene>      List of HGNC symbols for genes to annotate or peak at.
  <txt_path>  List of generated text-files to be imported into a datastore. See <store> for further info.
  <chrom>     Chromosome ID.
  <start>     Starting interval position (1-based).
  <end>       Ending interval position (1-based).

Options:
  -h --help          Show this screen.
  --version          Show version.
  -r --read=<file>   Path to txt-file with one HGNC symbol per line.
  --ccds=<file>      Path to CCDS database dump. Also functions as a flag to signal building a new SQLite database.
  -c --cutoff=<int>  Cutoff to use for calculating completeness [default: 10].
  -s --sample=<str>  The sample ID to annotate with or peak at [default: 0-0-0U].
  -g --group=<int>   The group ID to annotate with or peak at [default: 0].
  -j --json=<file>   Chanjo will write a temporary JSON-file that can be imported later. This is useful when parallelizing Chanjo.
  -f --force         Overwrite existing files without warning.
  -v --verbose       Show more extensive information about transcripts & exons.
  -p --print         Just print the variables to the console (debug).
"""
from __future__ import print_function
from docopt import docopt
import time
import itertools
import json
import sys
from path import path

from elemental.adapters import ccds

import chanjo
from chanjo import core, bam, sql


def fileCheck(file_path, force):
  """
  Guards against silently overwriting an existing file.

  Ends the program through ``sys.exit`` with an ANSI-colored warning when
  ``file_path`` already exists and ``force`` is falsy. In every other case
  the function simply returns ``None``.
  """
  target = path(file_path)

  # Nothing to protect, or the user explicitly allowed overwriting
  if not target.exists() or force:
    return

  warning = "\033[93mWARNING: {path} already exists. Use '--force' to overwrite.\033[0m"
  sys.exit(warning.format(path=target))

def read(bamPath, chrom, start, end, cutoff, verbose=False):
  """
  Prints coverage metrics for one interval of a BAM-file as a JSON string.

  ``bamPath`` is handed to ``bam.CoverageAdapter``; ``chrom``, ``start`` and
  ``end`` are forwarded to the adapter's ``read`` method and echoed back in
  the output. ``cutoff`` is forwarded to ``Hub.calculate``. With ``verbose``
  the metrics are rounded to 10 decimals instead of 2.
  """
  # Verbose output keeps more decimals
  precision = 10 if verbose else 2

  # The Hub pairs the BAM adapter with a throwaway in-memory element store
  coverageAdapter = bam.CoverageAdapter(bamPath)
  elementAdapter = sql.ElementAdapter(":memory:")
  hub = core.Hub(coverageAdapter, elementAdapter)

  # The third value returned by `calculate` is not needed here
  coverage, completeness, _ignored = hub.calculate(
    hub.cov.read(chrom, start, end), cutoff)

  interval = {
    "chrom": chrom,
    "start": start,
    "end": end
  }
  report = {
    "interval": interval,
    "coverage": round(coverage, precision),
    "completeness": round(completeness, precision)
  }

  print(json.dumps(report, indent=4))

def peak(storePath, genes, sample=None, group=None, verbose=False):
  """
  Prints a JSON string with the stored annotations for the requested genes.

  The output maps each entry of ``genes`` to a list of
  ``{"sample", "coverage", "completeness"}`` records. A record is included
  when ``sample`` equals the record's sample ID, or when ``sample`` is the
  special value ``"0-0-0U"`` (in which case every record is included).
  With ``verbose`` the metrics are rounded to 10 decimals instead of 2.

  NOTE(review): ``group`` is accepted but never read in this function.
  """
  store = sql.ElementAdapter(storePath)

  # Verbose output keeps more decimals
  precision = 10 if verbose else 2

  report = {}

  # One list of matching annotation records per HGNC symbol
  for symbol in genes:
    report[symbol] = [
      {
        "sample": record.sample_id,
        "coverage": round(record.coverage, precision),
        "completeness": round(record.completeness, precision)
      }
      for record in store.get("gene", symbol).data
      if sample == "0-0-0U" or sample == record.sample_id
    ]

  print(json.dumps(report, indent=4))

def build(storePath, ccdsPath, db=None):
  """
  Creates the barebones database structure (elements and relationships, no
  annotations) from a CCDS database dump. Handy as a shared reference
  database when Chanjo is run in parallel.

  When ``db`` is given it is used as the element adapter and ``storePath``
  is not touched; otherwise a new ``sql.ElementAdapter`` is opened on
  ``storePath``.
  """
  # Only open the store ourselves when no adapter was handed in
  store = sql.ElementAdapter(storePath) if db is None else db

  # Read genes, transcripts and exons from the CCDS txt-file
  ccdsAdapter = ccds.CCDSAdapter()
  genes, txs, exons = ccdsAdapter.connect(ccdsPath).parse()

  # Each step works on whatever the previous step returned:
  # create the tables ...
  prepared = store.setup()
  # ... turn the parsed elements into ORM objects ...
  populated = prepared.convert(genes, txs, exons)
  # ... and persist everything added during this session
  populated.commit()

def import_(annotations, db, sample_id=None, group_id=None, cutoff=None):
  """
  Stores exon annotations in the datastore and extends them to transcripts
  and genes. Used for example to load annotations written to text-files
  when Chanjo was run in parallel.

  ``annotations`` is an iterable of mappings with the keys ``element_id``,
  ``coverage`` and ``completeness``. Every created record is tagged with
  ``sample_id`` and ``group_id``.

  NOTE(review): ``cutoff`` is accepted but never read in this function.
  """
  def extend(kind, stats):
    # Each stats row is indexed as (element id, coverage, completeness)
    rows = [db.create(kind,
              element_id=row[0],
              sample_id=sample_id,
              group_id=group_id,
              coverage=row[1],
              completeness=row[2]
            ) for row in stats]
    db.add(rows).commit()

  # Exon level: taken straight from the supplied annotations
  exonRows = []
  for anno in annotations:
    exonRows.append(db.create("exon_data",
      element_id=anno["element_id"],
      coverage=anno["coverage"],
      completeness=anno["completeness"],
      sample_id=sample_id,
      group_id=group_id
    ))

  # N.B. Each level is committed before the next one is queried, so that
  # the stats queries can see the rows that were just added.
  db.add(exonRows).commit()

  # Transcript level, based on the exon annotations
  extend("transcript_data", db.transcriptStats(sample_id))

  # Gene level
  extend("gene_data", db.geneStats(sample_id))

def annotate(storePath, bamPath, cutoff, sample_id, group_id, ccdsPath=None,
             geneIDs=None, verbose=False, jsonPath=None):
  """
  Annotates the exons of all or a subset of genes with coverage from a
  BAM-file and either persists the result to the datastore or dumps it to a
  JSON-file.

  Parameters:
    storePath: Path handed to ``sql.ElementAdapter`` (the datastore).
    bamPath:   Path handed to ``bam.CoverageAdapter`` (the BAM-file).
    cutoff:    Forwarded to ``Hub.annotate`` for every gene.
    sample_id: Sample ID stored with (or written next to) the annotations.
    group_id:  Group ID stored with (or written next to) the annotations.
    ccdsPath:  When truthy, the datastore is first built from this CCDS dump.
    geneIDs:   Gene IDs to annotate; ``None`` annotates every gene in the
               datastore. IDs that resolve to ``None`` are skipped.
    verbose:   Print the ID of each gene as it is annotated.
    jsonPath:  When truthy, write the exon annotations to this file as JSON
               instead of persisting them to the datastore. When ``None``,
               the ``--json`` value of a module-level ``args`` mapping is
               used if such a mapping exists (this is the case when the file
               is run as a script).
  """
  # Older callers didn't pass the JSON path; the function read it from the
  # module-level `args` that only exists when running as a script. Keep that
  # behaviour as a fallback, but don't crash when imported as a module.
  if jsonPath is None:
    cliArgs = globals().get("args")
    if cliArgs:
      jsonPath = cliArgs.get("--json")

  # Setup adapters and the main Hub
  hub = core.Hub(bam.CoverageAdapter(bamPath), sql.ElementAdapter(storePath))

  # We can set up a brand new database if the user wants to
  if ccdsPath:
    # Use the build command function but supply "our" Hub-db
    build(storePath, ccdsPath, db=hub.db)

  # =======================================================
  #   This is the part where we figure out which genes to
  #   annotate in this session.
  # -------------------------------------------------------
  if geneIDs is None:
    # Get all genes in the datastore
    genes = hub.db.find("gene")

  else:
    # The user was nice enough to supply a list of genes to annotate
    genes = hub.db.find("gene", query=geneIDs)

  # =======================================================
  #   This is where we actually annotate the exons based on
  #   the selection of genes above.
  # -------------------------------------------------------
  # This step filters out `None` that come from invalid HGNC symbols
  genes = [gene for gene in genes if gene is not None]

  exonData = []
  for gene in genes:
    # Annotate the gene; yields the exon annotations for that gene
    exonData.append(hub.annotate(gene, cutoff))

    if verbose:
      # "\r" rewinds the cursor so each gene overwrites the previous line
      print("Annotated: {}".format(gene.id), end="\r")

  # Flatten the per-gene lists into one list of exon annotations
  flatExonData = list(itertools.chain.from_iterable(exonData))

  if jsonPath:
    # We should save a temp version of the exon annotations to a text-file
    # rather than persisting to the datastore. We use JSON.
    with open(jsonPath, "w") as handle:
      json.dump({
        "cutoff": cutoff,
        "sample_id": sample_id,
        "group_id": group_id,
        "annotations": flatExonData
      }, handle)

  else:
    # Add to session and commit so that the upcoming queries will work
    import_(flatExonData, hub.db, sample_id, group_id, cutoff)

def main(args):
  """
  Runs the sub-command selected on the command line.

  ``args`` is the mapping produced by docopt from the module docstring. The
  commands are checked in the order "read", "peak", "build", "import";
  anything else is treated as "annotate", for which the runtime in minutes
  is printed when done.
  """
  # First command will determine which path to take
  if args["read"]:

    # Prepare the input arguments
    chrom = args["<chrom>"]
    start = int(args["<start>"])
    end = int(args["<end>"])
    cutoff = int(args["--cutoff"])

    # Read the BAM-file, calculate coverage, and print to console.
    read(args["<source>"], chrom, start, end, cutoff, args["--verbose"])

  elif args["peak"]:

    # Prepare the input arguments
    group = int(args["--group"])

    peak(args["<store>"], args["<gene>"], args["--sample"], group,
         args["--verbose"])

  elif args["build"]:

    # Don't clobber an existing datastore unless forced to
    fileCheck(args["<store>"], args["--force"])

    # Build up a barebones database structure
    build(args["<store>"], args["<ccds_path>"])

  elif args["import"]:
    # Set up the store
    db = sql.ElementAdapter(args["<store>"])

    # Import one text-file at a time
    for txtPath in args["<txt_path>"]:

      with open(txtPath, "r") as handle:
        data = json.load(handle)

      import_(data["annotations"], db, data["sample_id"], data["group_id"],
              data["cutoff"])

  else:

    # Start a timer to be able to print out the runtime
    started = time.time()

    if args["--ccds"]:
      # "--ccds" signals building a new datastore, so the file at risk of
      # being overwritten is the store. The CCDS dump itself is an input
      # that has to exist and must not trigger the warning.
      fileCheck(args["<store>"], args["--force"])

    if args["--json"]:
      # The JSON-file is an output; don't overwrite it unless forced to
      fileCheck(args["--json"], args["--force"])

    # Lets annotate some genes!
    # Prepare the input arguments
    cutoff = int(args["--cutoff"])
    sample_id = args["--sample"]
    group_id = int(args["--group"])

    # This part pre-processes the list of HGNC symbols/Gene IDs corresponding
    # to the genes the user wants to annotate.
    if args["--read"]:
      # Read HGNC symbols from a txt-file, one per line
      with open(args["--read"], "r") as f:
        genes = [hgnc.strip() for hgnc in f]

    else:
      # The user either supplied a list of HGNC symbols or wants to annotate
      # all genes (signalled by `None`).
      genes = args["<gene>"] or None

    # Call the function which does the annotation of the genes
    annotate(args["<store>"], args["<source>"], cutoff, sample_id, group_id,
             args["--ccds"], genes, verbose=args["--verbose"])

    # Stop the timer and report the runtime in minutes
    finished = time.time()
    runtime = round((finished - started)/float(60), 2)
    print("Runtime: {time} min".format(time=runtime))

if __name__ == "__main__":
  # docopt derives the argument mapping from the usage text in the module
  # docstring. Keep the name `args`: annotate() looks it up as a module-level
  # global.
  args = docopt(__doc__, version="Chanjo {v}".format(v=chanjo.__version__))

  if not args["--print"]:
    # Regular run: dispatch to the requested command
    main(args)

  else:
    # Debug mode: only show how the command line was parsed
    print(args)
