'''
    mania.parser
    ~~~~~~~~~~~~

    :copyright: 2010 by Bjoern Schulz <bjoern.schulz.92@gmail.com>
    :license: MIT, see LICENSE for more details
'''

import mania.types
from mania.ast import (
    Value, Reference, Assign, Struct, GetAttribute, SetAttribute, Lambda, Let,
    ArgumentList, NamedLet, Call, If, Cond, List, OptionalArgument, Sequence,
    DelAttribute, And, Or, Require
)

class Parser(object):
    '''Recursive-descent parser turning a token stream into mania AST nodes.

    The scanner given to the constructor must provide a ``next()`` method
    that returns ``(type, value)`` pairs and raises ``StopIteration`` once
    the input is exhausted.  The parser keeps one token of lookahead in
    ``self.t`` (token type) and ``self.v`` (token value).  Newline tokens
    are consumed transparently and only used to maintain ``self.lineno``
    for error messages.

    Special forms (``define``, ``lambda``, ``if``, ...) are dispatched via
    ``self.methods``; any other name at the head of a parenthesised
    expression is parsed as a function call.
    '''

    def __init__(self, scanner):
        self.scanner = scanner
        self.lineno = 1
        # Prime the one-token lookahead (self.t / self.v).
        self.advance()
        # Special forms: the head name of a parenthesised expression maps to
        # the method that parses the rest of the form, up to but not
        # including the closing ')' (parse_parentheses consumes that).
        self.methods = {
            'define': self.parse_define,
            'lambda': self.parse_lambda,
            'if': self.parse_if,
            'cond': self.parse_cond,
            'struct': self.parse_struct,
            'set!': self.parse_set_attribute,
            'del!': self.parse_del_attribute,
            'list': self.parse_list,
            'and': self.parse_and,
            'or': self.parse_or,
            'do': self.parse_sequence,
            'require': self.parse_require
        }

    def advance(self):
        '''Load the next significant token into ``self.t`` / ``self.v``.

        Newline tokens are skipped and counted in ``self.lineno``.  At the
        end of input both ``self.t`` and ``self.v`` become ``None``.
        '''
        # Iterate instead of recursing so that long runs of blank lines
        # cannot exceed the interpreter's recursion limit.
        while True:
            try:
                self.t, self.v = self.scanner.next()
            except StopIteration:
                self.t, self.v = None, None
            if self.t != '\n':
                return
            self.lineno += 1

    def expect(self, t):
        '''Consume a token of type *t* and return its value.

        :raises SyntaxError: if the current token is not of type *t*.
        '''
        if self.t != t:
            raise SyntaxError('%s expected, got %s (line %d)' % (
                t, self.t, self.lineno
            ))
        value = self.v
        self.advance()
        return value

    def _parse_parameters(self):
        '''Parse parameter names up to and including the closing ')'.

        After a ``.`` token every following parameter is wrapped in
        ``OptionalArgument``.  A name followed by ``...`` is wrapped in
        ``ArgumentList`` and must be the last parameter.

        :raises SyntaxError: on a repeated parameter name.
        '''
        params = []
        # Track bare names separately: ``params`` also holds OptionalArgument
        # wrappers, which need not compare equal to the bare name.
        seen = set()
        optional = False
        while self.t != ')':
            name = self.expect('name')
            if name in seen:
                raise SyntaxError(
                    "duplicate argument '%s' in function definition "
                    "(line %d)" % (name, self.lineno)
                )
            seen.add(name)
            if self.t == '...':
                self.advance()
                params.append(ArgumentList(name))
                break
            params.append(OptionalArgument(name) if optional else name)
            if self.t == '.':
                self.advance()
                optional = True
        self.expect(')')
        return params

    def parse_lambda(self):
        '''Parse ``(params...) body`` following the ``lambda`` keyword.'''
        self.expect('(')
        params = self._parse_parameters()
        return Lambda(params, self.parse_any())

    def parse_call(self, func):
        '''Parse call arguments up to (not including) the closing ')'.

        An argument followed by ``...`` is wrapped in ``ArgumentList`` and
        ends the argument list.
        '''
        args = []
        while self.t != ')':
            value = self.parse_any()
            if self.t == '...':
                self.advance()
                args.append(ArgumentList(value))
                break
            args.append(value)
        return Call(func, args)

    def parse_sequence(self):
        '''Parse the expressions of a ``do`` form into a ``Sequence``.'''
        body = []
        while self.t != ')':
            body.append(self.parse_any())
        return Sequence(*body)

    def parse_if(self):
        '''Parse ``test then [else]``; a missing else branch becomes None.'''
        condition = self.parse_any()
        consequence = self.parse_any()
        alternative = self.parse_any() if self.t != ')' else None
        return If(condition, consequence, alternative)

    def parse_cond(self):
        '''Parse ``(test expr)`` clauses; an ``(else expr)`` clause sets
        the default.'''
        cases = []
        default = None
        while self.t != ')':
            self.expect('(')
            if self.t == 'name' and self.v == 'else':
                self.advance()
                default = self.parse_any()
            else:
                cases.append((self.parse_any(), self.parse_any()))
            self.expect(')')
        return Cond(cases, default)

    def parse_and(self):
        '''Parse exactly two operands of ``and``.'''
        return And(self.parse_any(), self.parse_any())

    def parse_or(self):
        '''Parse exactly two operands of ``or``.'''
        return Or(self.parse_any(), self.parse_any())

    def parse_define(self):
        '''Parse the rest of a ``define`` form.

        ``define name expr`` binds a value.  ``define (name params...) body``
        binds ``name`` to a ``Lambda`` built from the parameters and body.
        '''
        if self.t == 'name':
            return Assign(self.expect('name'), self.parse_any())
        self.expect('(')
        func_name = self.expect('name')
        params = self._parse_parameters()
        return Assign(func_name, Lambda(params, self.parse_any()))

    def _parse_bindings(self):
        '''Parse ``((name expr) ...)`` into a list of ``(name, node)`` pairs.'''
        self.expect('(')
        bindings = []
        while self.t != ')':
            self.expect('(')
            bindings.append((self.expect('name'), self.parse_any()))
            self.expect(')')
        self.expect(')')
        return bindings

    def parse_let(self):
        '''Parse ``((name expr) ...) body`` following ``let``.'''
        bindings = self._parse_bindings()
        return Let(bindings, self.parse_any())

    def parse_named_let(self):
        '''Parse ``name ((name expr) ...) body`` following ``let``.'''
        name = self.expect('name')
        bindings = self._parse_bindings()
        return NamedLet(name, bindings, self.parse_any())

    def parse_struct(self):
        '''Parse ``(name expr)`` field pairs into a ``Struct``.'''
        fields = []
        while self.t != ')':
            self.expect('(')
            fields.append((self.expect('name'), self.parse_any()))
            self.expect(')')
        return Struct(fields)

    def _parse_attribute_target(self, form):
        '''Parse the ``obj.attr`` target of *form* (``set!`` / ``del!``).

        :returns: ``(object node, attribute name)``
        :raises SyntaxError: if the target is not an attribute reference.
        '''
        target = self.parse_any()
        try:
            return target.object, target.name
        except AttributeError:
            # Previously surfaced as a bare AttributeError on e.g. (set! x 1).
            raise SyntaxError(
                '%s expects an attribute reference such as obj.attr '
                '(line %d)' % (form, self.lineno)
            )

    def parse_set_attribute(self):
        '''Parse ``obj.attr expr`` following ``set!``.'''
        obj, name = self._parse_attribute_target('set!')
        return SetAttribute(obj, name, self.parse_any())

    def parse_del_attribute(self):
        '''Parse ``obj.attr`` following ``del!``.'''
        obj, name = self._parse_attribute_target('del!')
        return DelAttribute(obj, name)

    def parse_get_attribute(self, object):
        '''Parse one or more ``.name`` suffixes applied to *object*.

        Builds nested ``GetAttribute`` nodes left to right, so ``a.b.c``
        becomes ``GetAttribute(GetAttribute(a, 'b'), 'c')``.
        '''
        self.expect('.')
        attr = GetAttribute(object, self.expect('name'))
        while self.t == '.':
            self.advance()
            attr = GetAttribute(attr, self.expect('name'))
        return attr

    def parse_value(self):
        '''Convert a literal token into a ``Value`` wrapping a mania type.'''
        value = self.expect('value')
        if value is None:
            return Value(mania.types.nil)
        # bool must be tested before the integer fallback: bool is an int
        # subclass in Python.
        elif isinstance(value, bool):
            return Value(mania.types.true if value else mania.types.false)
        elif isinstance(value, basestring):
            return Value(mania.types.String(value))
        elif isinstance(value, float):
            return Value(mania.types.Real(value))
        return Value(mania.types.Integer(value))

    def parse_symbol(self):
        '''Parse the quoted datum that follows a ``:`` token.

        A parenthesised group becomes a ``List`` of symbols; nested groups
        recurse and stray ``:`` tokens inside a group are skipped.  A
        literal is parsed with ``parse_value``.  A bare name becomes
        ``Value(mania.types.Name(name))``.
        '''
        if self.t == 'value':
            return self.parse_value()
        if self.t != '(':
            return Value(mania.types.Name(self.expect('name')))
        self.expect('(')
        items = []
        while self.t != ')':
            if self.t == ':':
                self.advance()
            else:
                items.append(self.parse_symbol())
        self.expect(')')
        return List(items)

    def parse_list(self):
        '''Parse ``list`` items; an item followed by ``...`` is wrapped in
        ``ArgumentList`` (parsing continues afterwards).'''
        items = []
        while self.t != ')':
            item = self.parse_any()
            if self.t == '...':
                self.advance()
                item = ArgumentList(item)
            items.append(item)
        return List(items)

    def parse_require(self):
        '''Parse the single operand of ``require``.'''
        return Require(self.parse_any())

    def _parse_call_head(self, func):
        '''Parse a call on *func*, first applying any ``.attr`` suffixes.'''
        if self.t == '.':
            func = self.parse_get_attribute(func)
        return self.parse_call(func)

    def parse_parentheses(self):
        '''Parse a complete ``( ... )`` expression including both parens.

        A special-form name at the head dispatches through ``self.methods``.
        ``let`` is a named let when followed by a name, else a plain let.
        Any other head, either a name or a nested parenthesised expression,
        is parsed as a call.
        '''
        self.expect('(')
        if self.t == 'name':
            name = self.expect('name')
            # Look the handler up separately from calling it: a try/except
            # KeyError around the call would also swallow KeyErrors raised
            # while parsing the form's body and re-parse it as a call.
            method = self.methods.get(name)
            if method is not None:
                node = method()
            elif name == 'let':
                if self.t == 'name':
                    node = self.parse_named_let()
                else:
                    node = self.parse_let()
            else:
                node = self._parse_call_head(Reference(name))
        else:
            node = self._parse_call_head(self.parse_parentheses())
        self.expect(')')
        return node

    def parse_any(self):
        '''Parse one expression, including trailing ``.attr`` accesses.'''
        if self.t == 'name':
            node = Reference(self.expect('name'))
        elif self.t == 'value':
            node = self.parse_value()
        elif self.t == ':':
            self.advance()
            node = self.parse_symbol()
        else:
            node = self.parse_parentheses()
        if self.t == '.':
            return self.parse_get_attribute(node)
        return node

    def parse(self):
        '''Parse the whole input.

        :returns: the single top-level node, or a ``Sequence`` of all
                  top-level nodes when there are zero or several.
        '''
        nodes = []
        while self.t:
            nodes.append(self.parse_any())
        if len(nodes) == 1:
            return nodes[0]
        return Sequence(*nodes)