from ply import yacc
import json
import os

from .lexer import Lexer
from . import ast

try:
    from . import lextab, yacctab
except ImportError:
    lextab, yacctab = 'lextab', 'yacctab'


class Parser(object):
    """Parse task source text into an ``ast.Program`` using PLY (LALR).

    IMPORTANT: the docstring of every ``p_*`` method except ``p_error`` is a
    grammar production that PLY reads to build its parse tables, and PLY
    resolves reduce/reduce conflicts in favour of the earlier rule. Do not
    reword those docstrings or reorder those methods; explanatory notes for
    grammar actions are therefore written as ``#`` comments.

    Instance state touched while parsing:
      * ``self.outputs`` -- filled by ``p_output`` (OUTPUT token -> "_<lineno>").
      * ``self.futures`` -- serialised by ``js()``; not written in this class,
        presumably populated by AST nodes that receive ``parser=self``
        (TODO confirm).
      * ``self._scope_stack`` -- pushed/popped through the two callbacks that
        are handed to the ``Lexer``.
    """

    def __init__(self, lex_optimize=False, lextab=lextab, yacc_optimize=False, yacctab=yacctab, yacc_debug=False):
        """Build the lexer and the PLY parser for this grammar.

        :param lex_optimize: forwarded to ``Lexer.build(optimize=...)``.
        :param lextab: lexer table module or module name for ``Lexer.build``.
        :param yacc_optimize: forwarded to ``yacc.yacc(optimize=...)``.
        :param yacctab: parser table module or module name; tables are
            written to this file's directory.
        :param yacc_debug: forwarded to ``yacc.yacc(debug=...)``.
        """
        # The lexer receives the scope push/pop callbacks.
        self.lexer = Lexer(self._unshift, self._shift)
        self.lexer.build(optimize=lex_optimize, lextab=lextab)
        # PLY requires the token list on the grammar module (here: self).
        self.tokens = self.lexer.tokens
        self.parser = yacc.yacc(module=self,
                                tabmodule=yacctab,
                                start="program",
                                outputdir=os.path.dirname(__file__),
                                optimize=yacc_optimize,
                                debug=yacc_debug)
        self._scope_stack = [dict()]
        self.outputs = {}
        self.output_variables = {}
        self.futures = {}

    def parse(self, source, filename="", debug=0):
        """Parse *source* and return the tree built by ``p_program``.

        :param source: program text; a falsy value is replaced by ``"\\n"``.
        :param filename: currently unused; kept for interface compatibility.
        :param debug: forwarded to PLY's ``parse(debug=...)``.
        :raises SyntaxError: from ``p_error`` on unexpected input.
        """
        if not source:
            source = "\n"
        # NOTE(review): outputs/futures/_scope_stack are not reset here, so
        # they accumulate if one Parser instance parses several sources --
        # confirm whether instances are meant to be reused.
        return self.parser.parse(source, lexer=self.lexer, debug=debug)

    def _unshift(self):
        """Push a fresh, empty scope dict onto the scope stack (lexer callback)."""
        self._scope_stack.append(dict())

    def _shift(self):
        """Pop the innermost scope dict off the scope stack (lexer callback)."""
        self._scope_stack.pop()

    def js(self, scope="_"):
        """Render the parsed task as a JavaScript snippet.

        :param scope: name of the JS variable holding the ``Task`` instance.
        :returns: JS source that creates the task, registers the futures
            (with ``<scope>.get(`` rewritten to ``<scope>.path(``) and the
            outputs (JSON-encoded ``self.outputs``).
        """
        futures = self._dump_js(self.futures, scope).replace(scope + '.get(', scope + '.path(')
        return ("var " + scope + " = new Task(arguments);"
                + scope + ".futures(" + futures + ");"
                + scope + ".outputs(" + json.dumps(self.outputs) + ");")

    def _dump_js(self, variables, scope):
        """Serialise ``{group: {path: var}}`` into a JS object literal string.

        Each group key ``k`` becomes ``'_k'``; each member becomes
        ``'path':<var.js(scope)>``.

        :raises AssertionError: if a group is not a dict (internal invariant).
        """
        groups = {}
        for key, group in variables.items():
            assert isinstance(group, dict), "Internal error. Variable group not valid."
            members = ",".join("'%s':%s" % (path, var.js(scope)) for path, var in group.items())
            groups["_" + str(key)] = "{%s}" % members
        return "{%s}" % ",".join("'%s':%s" % (name, body) for name, body in groups.items())

    def p_error(self, p):
        """Abort parsing with a SyntaxError describing the offending token.

        PLY passes ``None`` when the input ends unexpectedly; that case gets
        its own message instead of failing with an AttributeError.
        """
        if p is None:
            raise SyntaxError("Syntax error at end of input")
        raise SyntaxError("Syntax error at '%s' at line %d" % (p.value or p.type, p.lineno))

    # ------------------
    # Tasks
    # ------------------
    def p_program(self, p):
        '''program : placeholders task'''
        p[0] = ast.Program(parser=self,
                           placeholders=p[1],
                           task=p[2])

    def p_task_run(self, p):
        '''task : RUN suite EOF
                | WS ID''' # Just to hide the ply warnings :)
        p[0] = ast.Run(self, lineno=p.lineno(1), suite=p[2])

    def p_varpaths(self, p):
        '''varpaths : VARPATH
                    | varpaths VARPATH'''
        # Left-recursive list: the first VARPATH creates the node, later ones
        # are appended via .add(), whose return value becomes the result.
        if len(p) == 2:
            p[0] = ast.Varpath(varpath=p[1])
        else:
            p[0] = p[1].add(varpath=p[2])

    def p_task_bfs(self, p):
        '''task : AFTER CHECKOUT suite EOF
                | BEFORE CHECKOUT suite EOF
                | AFTER ADDTOCART suite EOF
                | BEFORE ADDTOCART suite EOF
                | AFTER UNSUBSCRIBE suite EOF
                | AFTER SUBSCRIBE suite EOF
                | AFTER SIGNUP suite EOF
                | AFTER LOGIN suite EOF
                | AFTER LOGOUT suite EOF
                | BEFORE LOGOUT suite EOF
                | AFTER STOREOPENS suite EOF
                | AFTER STORECLOSES suite EOF
                | BEFORE STORECLOSES suite EOF'''
        # The event token's value, capitalised, names the ast class to build.
        p[0] = getattr(ast, p[2].capitalize())(parser=self, lineno=p.lineno(1), when=p[1], suite=p[3])

    def p_periodically(self, p):
        '''periodically : HOURLY
                        | DAILY
                        | WEEKLY
                        | MONTHLY
                        | QUARTERLY
                        | YEARLY
                        | '''
        # Yields (period, lineno). The line number is captured here because
        # the enclosing task rule otherwise ends up with line number 0 for
        # this nonterminal.
        # The empty alternative has no p[1]; reading it raised IndexError, so
        # it now yields a None period.
        # NOTE(review): confirm ast.Periodically handles period=None.
        if len(p) == 1:
            p[0] = (None, p.lineno(0))
        else:
            p[0] = (p[1], p.lineno(1))

    def p_task_periodically(self, p):
        '''task : periodically ATON string suite EOF
                | periodically suite EOF'''
        # p[1] is the (period, lineno) tuple from p_periodically.
        if len(p) == 4:
            p[0] = ast.Periodically(period=p[1][0], parser=self, lineno=p[1][1], suite=p[2])
        else:
            p[0] = ast.Periodically(period=p[1][0], parser=self, lineno=p[1][1], args=[p[3]], suite=p[4])

    def p_task_every(self, p):
        '''task : EVERY string suite EOF'''
        p[0] = ast.Periodically(period="every", parser=self, lineno=p.lineno(1), args=[p[2]], suite=p[3])

    def p_task_on(self, p):
        '''task : ATON string suite EOF
                | ATON varpaths suite EOF'''
        p[0] = ast.On(parser=self,
                      lineno=p.lineno(1),
                      args=[p[2]],
                      suite=p[3])

    def p_placeholders(self, p):
        '''placeholders : placeholder
                        | placeholders placeholder
                        | '''
        # The empty alternative leaves p[0] as None (no placeholders).
        if len(p) == 2:
            p[0] = ast.Placeholders(placeholder=p[1])
        elif len(p) == 3:
            p[0] = p[1].add(placeholder=p[2])

    def p_placeholder(self, p):
        '''placeholder : SET VARPATH TO PLACEHOLDER NEWLINE'''
        p[0] = ast.Placeholder(path=p[2], value=p[4])

    # ---------------------------
    # Statements
    # ---------------------------
    def p_stmts(self, p):
        '''stmts : stmt
                 | stmts stmt'''
        if len(p) == 3:
            p[0] = p[1].add(p[2])
        else:
            p[0] = ast.StmtList(stmt=p[1])

    def p_optional_stmts(self, p):
        '''optional_stmts : stmts
                          |'''
        # Empty alternative leaves p[0] as None.
        if len(p) == 2:
            p[0] = p[1]

    # ------------------
    # Operations
    # ------------------
    def p_output(self, p):
        '''output : OUTPUT
                  |'''
        # Side effect only: record the output name against its line number.
        if len(p) == 2:
            self.outputs[p[1]] = "_" + str(p.lineno(1))

    def p_suite(self, p):
        '''suite : NEWLINE INDENT stmts DEDENT'''
        p[0] = p[3]

    # ---------------------------
    # IF ... [ELSE IF ..] ... ELSE
    # ---------------------------
    def p_stmt_if(self, p):
        '''stmt : IF expressions suite
                | IF expressions suite else_if
                | IF expressions suite else_if ELSE suite
                | IF expressions suite ELSE suite'''
        # The If node is wrapped in a StmtList so the else-if / else nodes
        # can be appended after it as siblings.
        p[0] = ast.StmtList(stmt=ast.If(parser=self, lineno=p.lineno(1), expressions=p[2], suite=p[3]))
        if len(p) == 5:
            p[0].add(*p[4])
        elif len(p) == 6:
            p[0].add(ast.Else(parser=self, lineno=p.lineno(4), suite=p[5]))
        elif len(p) == 7:
            p[0].add(*p[4])
            p[0].add(ast.Else(parser=self, lineno=p.lineno(5), suite=p[6]))

    def p_elseif(self, p):
        '''else_if : ELSEIF expressions suite
                   | else_if ELSEIF expressions suite'''
        # Produces a plain Python list of ast.Elseif nodes.
        if len(p) == 4:
            p[0] = [ast.Elseif(parser=self, lineno=p.lineno(1), expressions=p[2], suite=p[3])]
        else:
            p[1].append(ast.Elseif(parser=self, lineno=p.lineno(2), expressions=p[3], suite=p[4]))
            p[0] = p[1]

    # ------------------
    # Delicious Juice
    # ------------------
    def p_juice(self, p):
        '''juice : args NEWLINE
                 | args NEWLINE INDENT kwargs optional_stmts DEDENT
                 | NEWLINE INDENT kwargs optional_stmts DEDENT
                 | NEWLINE'''
        # Produces [args list, kwargs dict, suite or None]. args/kwargs may
        # be None (their grammars have empty alternatives), hence the guards.
        p[0] = [[], {}, None]
        if len(p) == 3:
            if p[1] is not None:
                p[0][0] = p[1].args
        elif len(p) == 7:
            p[0][0] = p[1].args if p[1] is not None else []
            p[0][1] = p[4].kwargs if p[4] is not None else {}
            p[0][2] = p[5]
        elif len(p) == 6:
            p[0][1] = p[3].kwargs if p[3] is not None else {}
            p[0][2] = p[4]

    # ------------------
    # Events
    # ------------------
    def p_events(self, p):
        '''stmt : PRINT output juice
                | EMAIL output juice
                | DIALOG output juice
                | SMS output juice
                | RESTART output juice
                | TWEET output juice'''
        # The keyword's value, capitalised, names the ast class to build.
        event = getattr(ast, p[1].capitalize())
        p[0] = event(parser=self,
                     lineno=p.lineno(1),
                     args=p[3][0],
                     kwargs=p[3][1],
                     suite=p[3][2])

    def p_stmt_reward(self, p):
        '''stmt : REWARD VARPATH expressions output NEWLINE'''
        p[0] = ast.Reward(parser=self,
                          lineno=p.lineno(1),
                          args=[ast.Varpath(varpath=p[2]), p[3]])

    # ------------------
    # System Events
    # ------------------
    def p_stmt_wait(self, p):
        '''stmt : WAIT string output suite'''
        p[0] = ast.Wait(parser=self,
                        lineno=p.lineno(1),
                        date=p[2],
                        suite=p[4])

    def p_stmt_wait_until(self, p):
        '''stmt : WAIT UNTIL expression output suite'''
        p[0] = ast.WaitUntil(parser=self,
                             lineno=p.lineno(1),
                             expression=p[3],
                             suite=p[5])

    def p_stmt_log(self, p):
        '''stmt : LOG juice'''
        p[0] = ast.Log(parser=self,
                       lineno=p.lineno(1),
                       args=p[2][0],
                       kwargs=p[2][1],
                       suite=p[2][2])

    def p_stmt_pass(self, p):
        '''stmt : PASS NEWLINE'''
        p[0] = ast.Pass()

    def p_stmt_quit(self, p):
        '''stmt : QUIT NEWLINE'''
        p[0] = ast.Quit()

    def p_foreach(self, p):
        '''stmt : FOREACH expression suite
                | FOREACH setter suite'''
        p[0] = ast.Foreach(parser=self,
                           lineno=p.lineno(1),
                           this=p[2],
                           suite=p[3])

    # ----------------------
    # Variables Statements
    # ----------------------
    def p_varexp_randunique(self, p):
        '''randunique : UNIQUE
                      | RANDOM
                      | UNIQUE RANDOM
                      | RANDOM UNIQUE
                      | '''
        # Produces the list of modifier token values (possibly empty).
        if len(p) == 2:
            p[0] = [p[1]]
        elif len(p) == 3:
            p[0] = [p[1], p[2]]
        else:
            p[0] = []

    def p_varexp_offset(self, p):
        '''offset : FIRST DIGITS
                  | LAST DIGITS
                  | DIGITS DIGITS
                  | DIGITS
                  | FIRST
                  | LAST
                  | '''    
        # Produces an int index (first -> 0, last -> -1, N -> N), a
        # [start, stop] pair (first N -> [None, N], last N -> [-N, None],
        # A B -> [A, B]) or [] when absent. Keywords are told apart by token
        # type so spelling/case of the keyword's value does not matter.
        if len(p) == 2:
            kind = p.slice[1].type
            if kind == 'LAST':
                p[0] = -1
            elif kind == 'FIRST':
                p[0] = 0
            else:
                p[0] = int(p[1])
        elif len(p) == 3:
            kind = p.slice[1].type
            if kind == 'LAST':
                p[0] = [int(p[2])*-1, None]
            elif kind == 'FIRST':
                p[0] = [None, int(p[2])]
            else:
                p[0] = [int(p[1]), int(p[2])]
        else:
            p[0] = []

    def p_varexp_agg(self, p):
        '''agg : AVERAGE
               | COUNT
               | AVG
               | NUMBER OF
               | MAX
               | SUM
               | MIN
               | SMALLEST
               | NEWEST
               | LOWEST
               | HIGHEST
               | LARGEST
               | OLDEST
               | '''
        # Single keywords pass their value through; "number of" maps to
        # 'count'; the empty alternative leaves p[0] as None.
        if len(p) == 2:
            p[0] = p[1]
        elif len(p) == 3:
            p[0] = 'count'

    def p_varexp_time(self, p):
        '''time : FOR string
                | '''
        if len(p) == 3:
            p[0] = p[2]

    def p_varexp(self, p):
        '''expression : agg offset randunique expressions WHERE expressions SORTBY varpaths ASCDESC time
                      | agg offset randunique expressions SORTBY varpaths ASCDESC time
                      | agg offset randunique expressions WHERE expressions time
                      | agg offset randunique expressions time'''
        # Optional clauses are collected by production length and forwarded
        # as keyword arguments; the four lengths are exhaustive.
        if len(p) == 11:
            extra = dict(where=p[6], sortby=(p[8], p[9]), time=p[10])
        elif len(p) == 9:
            extra = dict(sortby=(p[6], p[7]), time=p[8])
        elif len(p) == 8:
            extra = dict(where=p[6], time=p[7])
        elif len(p) == 6:
            extra = dict(time=p[5])

        p[0] = ast.AdvancedExpression(expression=p[4],
                                      randunique=p[3],
                                      offset=p[2],
                                      agg=p[1],
                                      **extra)

    def p_varexp_simple(self, p):
        '''setter : varpaths WHERE expressions SORTBY varpaths ASCDESC time
                  | varpaths SORTBY varpaths ASCDESC time
                  | varpaths WHERE expressions time
                  | varpaths time'''
        # Same clause handling as p_varexp, without agg/offset/randunique.
        if len(p) == 8:
            extra = dict(where=p[3], sortby=(p[5], p[6]), time=p[7])
        elif len(p) == 6:
            extra = dict(sortby=(p[3], p[4]), time=p[5])
        elif len(p) == 5:
            extra = dict(where=p[3], time=p[4])
        elif len(p) == 3:
            extra = dict(time=p[2])

        p[0] = ast.AdvancedExpression(expression=ast.Expression(expression=p[1]),
                                      **extra)

    def p_stmt_set_varpath(self, p):
        '''stmt : SET VARPATH TO expressions NEWLINE
                | SET VARPATH TO setter NEWLINE'''
        p[0] = ast.Set(parser=self, lineno=p.lineno(1),
                       varpath=p[2], value=p[4])

    def p_stmt_unset_varpath(self, p):
        '''stmt : UNSET VARPATH NEWLINE'''
        p[0] = ast.Unset(parser=self, lineno=p.lineno(1), varpath=p[2])

    def p_stmt_push(self, p):
        '''stmt : PUSH variable INTO VARPATH NEWLINE'''
        # p[1] is the PUSH keyword itself; the pushed value is p[2].
        p[0] = ast.Push(parser=self, lineno=p.lineno(1),
                        varpath=ast.Varpath(p[4]), value=p[2])

    def p_stmt_with(self, p):
        '''stmt : WITH varpaths suite'''
        p[0] = ast.With(parser=self, lineno=p.lineno(1),
                        varpath=p[2], suite=p[3])

    # ------------------
    # Expressions
    # ------------------
    def p_expression_varpath(self, p):
        '''expression : varpaths'''
        p[0] = ast.Expression(expression=p[1])

    def p_expression_num(self, p):
        '''expression : DIGITS'''
        p[0] = ast.Expression(expression=ast.Number(str(p[1])))

    def p_expression_var(self, p):
        '''expression : variable'''
        p[0] = ast.Expression(p[1])

    def p_expressions(self, p):
        '''expressions : expression
                       | expressions AND expression
                       | expressions OR expression'''
        # Boolean chains are folded into the first Expression via .add().
        if len(p) == 4:
            p[0] = p[1].add(method=p[2], expression=p[3])
        else:
            p[0] = p[1]

    # ------------------
    # Expressions > Method
    # ------------------
    def p_expression_has(self, p):
        '''expression : varpaths HAS VARPATH'''
        # Explicit raise instead of `assert` so the check survives `python -O`.
        # AssertionError (not SyntaxError) is kept on purpose: PLY treats a
        # SyntaxError raised inside a rule action as a request for error
        # recovery instead of propagating it.
        if p[3].endswith(']'):
            raise AssertionError("Index lookup must not be a list")
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="has",
                                                    right=ast.Varpath(p[3])))

    def p_expression_in(self, p):
        '''expression : VARPATH IN varpaths'''
        p[0] = ast.Expression(expression=ast.Method(left=ast.Varpath(varpath=p[1]),
                                                    method="in",
                                                    right=p[3]))

    def p_expression_contains(self, p):
        '''expression : expression CONTAINS variable'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="contains",
                                                    right=p[3]))

    def p_expression_likere(self, p):
        '''expression : expression LIKE REGEX'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="like",
                                                    right=ast.Regexp(p[3])))

    def p_expression_likestr(self, p):
        '''expression : expression LIKE variable'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="like",
                                                    right=p[3]))

    def p_expression_isnotequal(self, p):
        '''expression : expression IS NOT EQUAL TO variable'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="isnt",
                                                    right=p[6]))

    def p_expression_isequal(self, p):
        '''expression : expression IS EQUAL TO variable'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="is",
                                                    right=p[5]))

    def p_expression_isgtltthen(self, p):
        '''expression : expression IS GREATER THEN expression
                      | expression IS GREATER THAN expression
                      | expression IS LESS THEN expression
                      | expression IS LESS THAN expression'''
        # Decide by token type: comparing the token's value to 'less' would
        # silently turn any other spelling of LESS into '>'.
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method='<' if p.slice[3].type == 'LESS' else '>',
                                                    right=p[5]))

    def p_expression_equals(self, p):
        '''expression : expression EQUALS variable'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="is",
                                                    right=p[3]))

    def p_expression_is(self, p):
        '''expression : expression IS expression'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="is",
                                                    right=p[3]))

    def p_expression_isnt(self, p):
        '''expression : expression ISNT expression'''
        p[0] = ast.Expression(expression=ast.Method(left=p[1],
                                                    method="isnt",
                                                    right=p[3]))

    # ------------------
    # Expressions > Math
    # ------------------
    def p_expression_math(self, p):
        '''expression : expression PLUS expression
                      | expression MINUS expression
                      | expression LT expression
                      | expression LE expression
                      | expression GT expression
                      | expression GE expression
                      | expression EQ expression
                      | expression NE expression
                      | expression TIMES expression
                      | expression DIVIDE expression'''
        # The operator token value is appended to the left Expression.
        p[0] = p[1].add(method=p[2], expression=p[3])

    def p_expression_group(self, p):
        '''expression : LPAREN expressions RPAREN'''
        # Parentheses are kept as literal "(" / ")" entries around the
        # grouped expression list.
        p[2].expressions.insert(0, ("", "("))
        p[2].expressions.append(("", ")"))
        p[0] = p[2]

    # ------------------
    # Strings
    # ------------------
    def p_string_tags(self, p):
        '''string_tags : VARIABLE_TAG
                       | string_tags VARIABLE_TAG'''
        # Produces a plain list of tag token values.
        if len(p) == 2:
            p[0] = [p[1]]
        else:
            p[1].append(p[2])
            p[0] = p[1]

    def p_string_vt(self, p):
        '''string_vt : VARIABLE_VARPATH
                     | VARIABLE_VARPATH string_tags'''
        p[0] = ast.Varpath(p[1])
        if len(p) == 3:
            p[0].tags = p[2]

    def p_string_content(self, p):
        '''string_content : string_vt
                          | STRING_CONTINUE'''
        p[0] = p[1]

    def p_string_inner(self, p):
        '''string_inner : string_content
                        | string_inner string_content'''
        if len(p) == 2:
            p[0] = ast.String(data=p[1])
        else:
            p[0] = p[1].add(p[2])

    def p_string(self, p):
        '''string : STRING_START_SINGLE string_inner STRING_END
                  | STRING_START_TRIPLE string_inner STRING_END'''
        p[0] = p[2]

    # ------------------
    # Variable
    # ------------------
    def p_variable_varpath(self, p):
        '''variable : VARPATH'''
        p[0] = ast.Varpath(varpath=p[1])

    def p_variable_string(self, p):
        '''variable : string'''
        p[0] = p[1]

    # ------------------------
    # Arguments
    # ------------------------
    def p_args(self, p):
        '''args : variable
                | args variable
                |'''
        # Empty alternative leaves p[0] as None (handled in p_juice).
        if len(p) == 2:
            p[0] = ast.Args(arg=p[1])
        elif len(p) == 3:
            p[0] = p[1].add(arg=p[2])

    def p_kwarg(self, p):
        '''kwarg : TAG NEWLINE
                 | TAG variable NEWLINE'''
        # A bare tag maps to None; otherwise to its value.
        if len(p) == 3:
            p[0] = {p[1]: None}
        else:
            p[0] = {p[1]: p[2]}

    def p_kwargs(self, p):
        '''kwargs : kwarg 
                  | kwargs kwarg
                  |'''
        # Empty alternative leaves p[0] as None (handled in p_juice).
        if len(p) == 2:
            p[0] = ast.Kwargs(kwarg=p[1])
        elif len(p) == 3:
            p[0] = p[1].add(kwarg=p[2])
