#!/usr/bin/env python

'''
Quick runs Clojure code quickly.

Commands:

  eval          FORM                             Evals given form.
  repl          [PORT]                           Connects a repl to a running nREPL server.
  run           NAMESPACE[/FUNCTION] [ARGS...]   Runs existing defn.
  lein          [TASK ARGS...]                   Runs a Leiningen task.
  repload-run   NAMESPACE[/FUNCTION] [ARGS...]   Runs existing defn after running repload.
  repload-lein  [TASK ARGS...]                   Runs a Leiningen task after running repload.
  start         [PORT]                           Start an nREPL server.
  kill          [PORT]                           Kill a running nREPL server.
  restart       [PORT]                           Restart a running nREPL server.

Running with no arguments will read code from stdin.
'''

# for python 3 compatibility, in case nrepl author ever ports to python 3
from __future__ import (absolute_import, division, print_function, unicode_literals)
from future import standard_library
from future.builtins import *
from future.utils import native_str

import os, os.path, sys, subprocess, re, readline, warnings
from pprint import pformat
import nrepl

def find_root(cwd=None):
    ''' Walk upward from cwd and return the first directory holding a project.clj.

    cwd defaults to the current working directory of the process. A directory
    that is its own parent (where os.path.dirname stops changing the path) is
    never inspected; reaching it ends the search and None is returned.
    '''
    current = os.getcwd() if cwd is None else cwd
    while True:
        parent = os.path.dirname(current)
        if parent == current:
            # nothing further up to look at
            return None
        if os.path.isfile(os.path.join(current, 'project.clj')):
            return current
        current = parent

def get_port():
    ''' Figure out what port the nREPL server is set to.

    The following sources are consulted in order and the first one that holds
    a purely numeric value wins:

      1. the LEIN_REPL_PORT environment variable
      2. a .nrepl-port file in the project directory (see find_root)
      3. the ~/.lein/repl-port file

    Returns the port as an int, or None when no source yields a number.
    '''
    def parse(text):
        # Tolerate surrounding whitespace: a port file that ends in a newline
        # would otherwise fail the isdigit() check and be ignored.
        text = text.strip()
        return int(text) if text.isdigit() else None

    def read_port(path):
        if not os.path.isfile(path):
            return None
        with open(path) as handle:
            return parse(handle.read())

    # check to see if LEIN_REPL_PORT is set.
    port = parse(os.environ.get('LEIN_REPL_PORT', ''))
    if port is not None:
        return port

    # look for first project.clj in current directory and all parent directories to
    # find the project directory. Check to see if a .nrepl-port file exists in the
    # project directory.
    root = find_root()
    if root is not None:
        port = read_port(os.path.join(root, '.nrepl-port'))
        if port is not None:
            return port

    # Check to see if ~/.lein/repl-port file exists; None means no running
    # port number could be found.
    return read_port(os.path.expanduser('~/.lein/repl-port'))

# TODO: fix race condition
def start_nrepl(port=None):
    ''' Spawn "lein repl :headless" and return the port it reports.

    port, when given, is passed to Leiningen as the ":port" argument; it may be
    an int or a string. The child's stdout is read line by line until the
    "nREPL server started on port N on host H" banner shows up.

    Returns the announced port as an int.
    Raises RuntimeError if stdout ends before the banner is seen.
    '''
    command = ['lein', 'repl', ':headless']
    if port is not None:
        # Popen only accepts strings in the argument list
        command += [':port', str(port)]
    proc = subprocess.Popen(command, stdout=subprocess.PIPE)

    # Not anchored at the end of the line, so a banner followed by extra text
    # is still recognised.
    # NOTE(review): some Leiningen versions appear to append " - nrepl://host:port"
    # to this line -- confirm against the versions in use.
    banner = re.compile(r"^nREPL server started on port ([0-9]+) on host ([^ ]+)")

    # Other output (warnings, download notices) may precede the banner, so keep
    # reading instead of assuming it is the first line.
    for raw in iter(proc.stdout.readline, b''):
        line = raw.decode('utf-8', 'replace').rstrip()
        found = banner.match(line)
        if found is not None:
            return int(found.group(1))
    raise RuntimeError('lein repl :headless exited without reporting an nREPL port')

def start(port=None, quiet=True):
    ''' Return the port of a usable nREPL server, launching one if necessary.

    When get_port() reports a (truthy) port, that port is returned and, unless
    quiet is set, a notice is written to stderr. Otherwise a headless
    "lein repl" is spawned through start_nrepl(port) and its result is returned.
    '''
    existing = get_port()
    if not existing:
        # nothing discovered: spawn a server in the background and use the
        # port number it reports
        return start_nrepl(port)
    if not quiet:
        print('nREPL server is already running on port %s' % (existing), file=sys.stderr)
    return existing

def kill(port=None, quiet=True):
    ''' Shut down an nREPL server by having it evaluate (System/exit 0).

    port:  port of the server to stop (int or numeric string). When omitted the
           port is discovered with get_port().
    quiet: when false, a confirmation line is printed after the request.

    Exits the process with status 1 if no port was given and none could be
    discovered.
    '''
    # Honour an explicitly requested port; previously the argument was accepted
    # but silently ignored.
    target = get_port() if port is None else int(port)
    if target is None:
        print('No nREPL session is currently running')
        sys.exit(1)
    evaluate('(System/exit 0)', port=target)
    if not quiet:
        print('Terminated nREPL instance on port %s' % (target))

def restart(port=None):
    ''' Stop the currently running nREPL server and launch a new one.

    kill() is called without arguments, so the server to stop is the one it
    discovers itself; port is only forwarded to start_nrepl() for the
    replacement server. The port reported by start_nrepl() is printed.
    '''
    kill()
    new_port = start_nrepl(port)
    message = "Restarted nREPL on port %s" % (new_port)
    print(message)

class EvalError(Exception):
    ''' Raised by evaluate() when the server reports an "eval-error" status. '''

# TODO: do we need to pass an ID to get out messages after a need-input?
def evaluate(code, port=None, **kwargs):
    ''' Send code to an nREPL server as an "eval" request and relay the replies.

    code:   the Clojure source to evaluate.
    port:   port of the server on localhost (int or numeric string). When
            omitted, start() is used to find or launch a server.
    kwargs: extra entries merged into the request message (e.g. ns='user').

    Replies are read until one carries the status "done". 'out' payloads are
    written to stdout and 'err' payloads to stderr. On a "need-input" status,
    stdin is forwarded to the server: a single character when stdin is a
    terminal, otherwise everything that can be read.

    Returns the last reply that contained a 'value' key (a dict), or None if
    there was none.
    Raises EvalError on an "eval-error" status and Exception on any other
    unrecognised status.
    '''
    if port is None: port = start()
    # Ports taken from the command line arrive as strings, which '%i' rejects.
    url = 'nrepl://localhost:%i' % (int(port))
    con = nrepl.connect(native_str(url))
    request = {'op': 'eval', 'code': code}
    request.update(kwargs)
    con.write(request)

    value = None
    while True:
        message = con.read()
        # NOTE(review): every raw reply is dumped to stderr; this looks like
        # leftover debugging output -- confirm before removing.
        print(pformat(message), file=sys.stderr)
        if message is None:
            # NOTE(review): if read() keeps returning None (e.g. after the
            # server went away) this loop never ends -- verify against the
            # nrepl client's behaviour on a closed connection.
            warnings.warn('Got None message')
        elif 'status' in message and isinstance(message['status'],list) and len(message['status'])>0:
            status = message['status'][0]
            if status == 'done':
                break
            elif status == 'eval-error':
                # ask the same session to print the exception it just recorded
                con.write({'op': 'eval',
                           'code': '(.printStackTrace *e)',
                           'session': message['session']})
                stacktrace = con.read()
                if stacktrace is not None and 'err' in stacktrace:
                    detail = stacktrace['err']
                else:
                    # The next reply is not guaranteed to carry the trace;
                    # still fail with EvalError rather than a KeyError.
                    detail = 'Evaluation failed (no stack trace received): %s' % (pformat(message))
                raise EvalError(detail)
            elif status == 'need-input':
                con.write({'op': 'stdin',
                           'session': message['session'],
                           'stdin': sys.stdin.read(1) if sys.stdin.isatty() else sys.stdin.read()})
            else:
                raise Exception('Eval returned bad exit status: %s' % (status))
        elif 'value' in message:
            value = message
        elif 'out' in message:
            print(message['out'], end='')
        elif 'err' in message:
            print(message['err'], file=sys.stderr, end='')
        else:
            warnings.warn("Don't know how to handle message: %s" % (message))
    return value
    
def repl(port=None):
    ''' Run an interactive read-eval-print loop against an nREPL server.

    port is forwarded to evaluate() for every line. The prompt shows the
    namespace reported by the most recent evaluation, and that namespace is
    sent along with the next request so that namespace switches carry over
    between lines. The loop ends on end-of-file or Ctrl-C; evaluation errors
    are printed to stderr and the loop continues.
    '''
    # send a dummy form to get the current namespace
    first = evaluate('nil', port)
    # NOTE(review): 'user' is assumed to be the server's default namespace when
    # the probe yields no value -- confirm.
    ns = first['ns'] if first is not None and 'ns' in first else 'user'
    # start the repl
    while True:
        try:
            line = input('%s=>' % (ns))
            # Every evaluate() call opens its own connection, so the namespace
            # has to be passed explicitly or the prompt and the actual
            # evaluation namespace drift apart.
            result = evaluate(line, port, ns=ns)
            if result is not None:
                if 'ns' in result:
                    ns = result['ns']
                print(result['value'])
        except (EOFError, KeyboardInterrupt):
            print()
            break
        except EvalError as e:
            print(str(e), file=sys.stderr)
        except Exception:
            print('Unexpected error:', sys.exc_info()[0])
            raise

def escape(*symbols):
    ''' Render each argument as a double-quoted Clojure string literal.

    Backslashes and double quotes inside an argument are backslash-escaped so
    the argument cannot terminate the literal early. The literals are joined
    with single spaces; calling with no arguments yields an empty string.
    '''
    def quote(text):
        # str.replace matches literally (it is not a regex), so escape the two
        # special characters one at a time -- backslashes first, otherwise the
        # backslashes added for the quotes would be doubled again.
        return '"' + text.replace('\\', '\\\\').replace('"', '\\"') + '"'
    return ' '.join(quote(symbol) for symbol in symbols)

def run(name, repload=False, *args):
    ''' Call an existing Clojure function on the nREPL server.

    name:    "NAMESPACE" or "NAMESPACE/FUNCTION"; when no function is given,
             -main of that namespace is called.
    repload: when true, (repload) is evaluated first.
    args:    passed to the function as string literals.

    Note that repload is the second positional parameter: callers that want to
    pass args must supply repload explicitly before them.

    Returns the 'value' of the evaluation result, or None if there was none.
    '''
    template = """%s
              (do
                (require '[clojure.stacktrace :refer [print-cause-trace]])
                (let [raw (symbol %s)
                      ns (symbol (or (namespace raw) raw))
                      m-sym (if (namespace raw) (symbol (name raw)) '-main)]
                  (require ns)
                  (try ((ns-resolve ns m-sym) %s)
                  (catch Exception e
                    (let [c (:exit-code (ex-data e))]
                      (when-not (and (number? c) (zero? c))
                        (print-cause-trace e)))))))"""
    # The formatting is done in its own statement: a '%' left dangling at the
    # end of a line outside of brackets is a syntax error.
    code = template % ('(repload)' if repload else '',
                       escape(name),
                       escape(*args))
    result = evaluate(code, ns='user')
    return None if result is None else result['value']

def lein(repload=False, *args):
    ''' Run a Leiningen task inside the nREPL server.

    repload: when true, (repload) is evaluated first.
    args:    the task name and its arguments, passed to -main as string
             literals.

    Note that repload is the first positional parameter: callers that want to
    pass args must supply repload explicitly before them.

    The code is evaluated in the leiningen.core.main namespace with *cwd* bound
    to the project root (or the current directory when no project.clj is
    found) and "leiningen.original.pwd" set to the current directory.

    Returns the 'value' of the evaluation result, or None if there was none.
    '''
    cwd = os.getcwd()
    root = find_root()
    template = """%s
              (binding [*cwd* %s, *exit-process?* false]
                       (System/setProperty "leiningen.original.pwd" %s)
                       (defmethod leiningen.core.eval/eval-in :default
                         [project form]
                         (leiningen.core.eval/eval-in
                           (assoc project :eval-in :nrepl) form))
                       (defmethod leiningen.core.eval/eval-in :trampoline
                         [& _] (throw (Exception. "trampoline disabled")))
                       (try (-main %s)
                         (catch clojure.lang.ExceptionInfo e
                           (let [c (:exit-code (ex-data e))]
                             (when-not (and (number? c) (zero? c))
                               (throw e))))))"""
    # The formatting is done in its own statement: a '%' left dangling at the
    # end of a line outside of brackets is a syntax error.
    code = template % ('(repload)' if repload else '',
                       escape(cwd if root is None else root),
                       escape(cwd),
                       escape(*args))
    result = evaluate(code, ns='leiningen.core.main')
    return None if result is None else result['value']


# parse command-line arguments
def _print_value(result):
    ''' Print the 'value' entry of an evaluate() result, if there is a result. '''
    if result is not None:
        print(result['value'])

command = sys.argv[1] if len(sys.argv) > 1 else None
rest = sys.argv[2:]

# run() and lein() take `repload` as a positional parameter ahead of *args, so
# it is always passed explicitly here. Splatting the command-line arguments
# straight in would bind the first user argument to `repload`, and combining
# that with repload=True raises "got multiple values" as soon as any argument
# is present.
#
# 'eval', 'run' and 'repload-run' need at least one argument; without it the
# chain falls through to the stdin / usage branches.
if command == 'eval' and rest: _print_value(evaluate(rest[0]))
elif command == 'start': start(*rest[:1], quiet=False)
elif command == 'kill': kill(*rest[:1], quiet=False)
elif command == 'restart': restart(*rest[:1])
elif command == 'repl': repl(*rest[:1])
elif command == 'run' and rest: run(rest[0], False, *rest[1:])
elif command == 'lein': lein(False, *rest)
elif command == 'repload-run' and rest: run(rest[0], True, *rest[1:])
elif command == 'repload-lein': lein(True, *rest)
elif not sys.stdin.isatty(): _print_value(evaluate(sys.stdin.read()))
else:
    print(__doc__, file=sys.stderr)
    sys.exit(1)

sys.exit(0)
