'''
    mania.types
    ~~~~~~~~~~~

    :copyright: 2010 by Bjoern Schulz <bjoern.schulz.92@gmail.com>
    :license: MIT, see LICENSE for more details
'''

from itertools import izip
from mania.ast import ArgumentList, OptionalArgument, Node
from mania.env import Environment

class Type(object):
    '''Base class of all mania runtime values.

    Subclasses keep their Python payload in ``self.value``.  The
    comparison helpers below compare those payloads and answer with the
    mania booleans ``true``/``false`` rather than Python ``bool``.
    '''

    def equals(self, other):
        '''Return ``true`` when both payloads compare equal with ``==``.'''
        return true if self.value == other.value else false

    def not_equals(self, other):
        '''Return the negation of :meth:`equals`.'''
        # Read the payload of the mania Boolean explicitly so the result
        # does not depend on the wrapper's truth hook.
        return false if self.equals(other).value else true

    def greater(self, other):
        '''Return ``true`` when this payload is ``>`` the other payload.'''
        return true if self.value > other.value else false

    def greater_equals(self, other):
        '''Return ``true`` when this payload is ``>=`` the other payload.'''
        return true if self.value >= other.value else false

    def lower(self, other):
        '''Return ``true`` when this payload is ``<`` the other payload.'''
        return true if self.value < other.value else false

    def lower_equals(self, other):
        '''Return ``true`` when this payload is ``<=`` the other payload.'''
        return true if self.value <= other.value else false

    def equivalent(self, other):
        '''Return ``true`` when both payloads are the same Python object.'''
        return true if self.value is other.value else false

    def contains(self, other):
        '''Return ``true`` when the other payload is ``in`` this payload.'''
        return true if other.value in self.value else false

    def to_boolean(self):
        '''Return the truthiness of the payload as a mania boolean.'''
        return true if self.value else false

    @classmethod
    def to_mania(cls, value):
        '''Convert a native Python value into the matching mania type.

        int/long become Integer, float Real, str/unicode String, bool the
        ``true``/``false`` instances and None ``nil``.  Lists and dicts
        are converted recursively into List and Struct.  A value whose
        exact type is not in the table (e.g. an existing mania object) is
        returned unchanged.
        '''
        dispatcher = {
            int: Integer,
            # Integer arithmetic can overflow into ``long`` on Python 2
            # (e.g. Integer.mod goes through this function).
            long: Integer,
            float: Real,
            str: String,
            unicode: String,
            bool: lambda o: true if o else false,
            type(None): lambda o: nil,
            list: lambda o: List([Type.to_mania(i) for i in o]),
            # Convert the lambda's own argument ``o`` instead of closing
            # over the outer ``value``.
            dict: lambda o: Struct(
                dict((k, Type.to_mania(v)) for k, v in o.items())
            ),
        }
        converter = dispatcher.get(type(value))
        return converter(value) if converter is not None else value

class Integer(Type):
    '''Integral number.

    Arithmetic with another Integer stays an Integer.  When a Real is
    involved, or when the exact result is fractional (true division,
    negative exponents), the result is a Real.
    '''

    def __init__(self, value):
        self.value = value

    def add(self, other):
        '''Add; promotes to Real when ``other`` is a Real.'''
        if isinstance(other, Real):
            return Real(self.value + other.value)
        return Integer(self.value + other.value)

    def sub(self, other):
        '''Subtract; promotes to Real when ``other`` is a Real.'''
        if isinstance(other, Real):
            return Real(self.value - other.value)
        return Integer(self.value - other.value)

    def mul(self, other):
        '''Multiply; an Integer times a String repeats the string.'''
        if isinstance(other, String):
            return String(self.value * other.value)
        elif isinstance(other, Real):
            return Real(self.value * other.value)
        return Integer(self.value * other.value)

    def div(self, other):
        '''Floor division (Python ``//``).'''
        if isinstance(other, Real):
            return Real(self.value // other.value)
        return Integer(self.value // other.value)

    def truediv(self, other):
        '''True division; the quotient is always a Real.'''
        # The quotient is a float even for two Integers, so it must be
        # wrapped as Real rather than Integer.
        return Real(float(self.value) / float(other.value))

    def mod(self, other):
        '''Modulo; the result type follows the Python result type.'''
        return Type.to_mania(self.value % other.value)

    def pow(self, other):
        '''Exponentiation.

        Returns a Real for a Real exponent or whenever Python produces a
        float (e.g. a negative integer exponent); otherwise an Integer.
        '''
        result = self.value ** other.value
        if isinstance(other, Real) or isinstance(result, float):
            return Real(result)
        return Integer(result)

    def pos(self):
        return Integer(+self.value)

    def neg(self):
        return Integer(-self.value)

    def bitwise_or(self, other):
        return Integer(self.value | other.value)

    def bitwise_xor(self, other):
        return Integer(self.value ^ other.value)

    def bitwise_and(self, other):
        return Integer(self.value & other.value)

    def bitwise_not(self):
        return Integer(~self.value)

    def lshift(self, other):
        return Integer(self.value << other.value)

    def rshift(self, other):
        return Integer(self.value >> other.value)

    def to_integer(self):
        return self

    def to_real(self):
        return Real(float(self.value))

    def to_string(self):
        return String(str(self.value))

    def to_native(self):
        return self.value

class Real(Type):
    '''Floating point number.

    Every arithmetic operation applies the Python operator to both
    payloads and wraps the outcome in a new Real.
    '''

    def __init__(self, value):
        self.value = value

    def add(self, other):
        total = self.value + other.value
        return Real(total)

    def sub(self, other):
        difference = self.value - other.value
        return Real(difference)

    def mul(self, other):
        product = self.value * other.value
        return Real(product)

    def div(self, other):
        '''Floor division (Python ``//``), still returned as a Real.'''
        floored = self.value // other.value
        return Real(floored)

    def truediv(self, other):
        quotient = self.value / other.value
        return Real(quotient)

    def mod(self, other):
        remainder = self.value % other.value
        return Real(remainder)

    def pow(self, other):
        power = self.value ** other.value
        return Real(power)

    def pos(self):
        return Real(+self.value)

    def neg(self):
        return Real(-self.value)

    def to_integer(self):
        '''Convert with Python ``int`` (truncates toward zero).'''
        truncated = int(self.value)
        return Integer(truncated)

    def to_real(self):
        return self

    def to_string(self):
        text = str(self.value)
        return String(text)

    def to_native(self):
        return self.value

class String(Type):
    '''Text value wrapping a Python string.'''

    def __init__(self, value):
        self.value = value

    def join(self, *args):
        '''Join the textual forms of ``args`` with this string between them.'''
        # to_string() returns a String wrapper; str.join needs raw text.
        return String(self.value.join(a.to_string().value for a in args))

    def format(self, *args):
        '''Apply Python ``str.format`` to the native values of ``args``.'''
        # Unwrap mania objects so that placeholders and format specs
        # (e.g. ``{:.2f}``) act on the Python value, not on the wrapper.
        native = [a.to_native() if isinstance(a, Type) else a for a in args]
        return String(self.value.format(*native))

    def add(self, other):
        return String(self.value + other.value)

    def mul(self, other):
        '''Repeat the string ``other.value`` times.'''
        return String(self.value * other.value)

    def first(self):
        '''Return the first character, or an empty String when empty.'''
        return String(self.value[0] if self.value else '')

    def rest(self):
        '''Return everything after the first character.'''
        return String(self.value[1:])

    def length(self):
        return Integer(len(self.value))

    def to_integer(self):
        return Integer(int(self.value))

    def to_real(self):
        return Real(float(self.value))

    def to_name(self):
        return Name(self.value)

    def to_list(self):
        '''Return a List of one-character Strings.'''
        # The characters are plain Python strings, so wrap them directly.
        return List([String(c) for c in self.value])

    def to_string(self):
        return self

    def to_native(self):
        return self.value

class Name(Type):
    '''Identifier/symbol value; its truth value is always ``true``.'''

    def __init__(self, value):
        self.value = value

    def to_boolean(self):
        return true

    def to_native(self):
        return self.value

    def to_string(self):
        text = self.value
        return String(text)

class Nil(Type):
    '''The absent value.

    The module-level instance ``nil`` defined right after this class is
    the one Type.to_mania maps ``None`` to.
    '''

    def __init__(self):
        self.value = None

    def to_boolean(self):
        return false

    def to_integer(self):
        '''Return Integer 0 (a mania value, like every other to_integer).'''
        return Integer(0)

    def to_real(self):
        '''Return Real 0.0 (a mania value, like every other to_real).'''
        return Real(0.0)

    def to_string(self):
        return String('#n')

    def to_native(self):
        return None

nil = Nil()

class Boolean(Type):
    '''Truth value; the module-level ``true``/``false`` instances below
    are what the comparison helpers return.'''

    def __init__(self, value):
        self.value = value

    def __nonzero__(self):
        return self.value

    # Python 3 name of the truth hook, so ``if boolean:`` behaves the
    # same on both interpreters.
    __bool__ = __nonzero__

    def to_boolean(self):
        return self

    def to_integer(self):
        '''Return Integer 1 for true, Integer 0 for false.'''
        return Integer(1 if self.value else 0)

    def to_real(self):
        '''Return Real 1.0 for true, Real 0.0 for false.'''
        return Real(1.0 if self.value else 0.0)

    def to_string(self):
        return String('#t' if self.value else '#f')

    def to_native(self):
        return self.value

true = Boolean(True)
false = Boolean(False)

class List(Type):
    '''Mutable sequence of mania values backed by a Python list.'''

    def __init__(self, value):
        self.value = value

    def __iter__(self):
        return iter(self.value)

    def index(self, i):
        '''Return the element at Integer position ``i``, or ``nil``.

        Negative positions count from the end as in Python.  Any position
        outside the list yields ``nil`` rather than raising IndexError.
        '''
        position = i.value
        size = len(self.value)
        if -size <= position < size:
            return self.value[position]
        return nil

    def push(self, value):
        '''Append ``value`` in place and return this list.'''
        self.value.append(value)
        return self

    def pop(self, value):
        '''Remove the element at position ``value`` in place; return self.

        ``value`` may be a mania Integer or a plain Python int.  The
        removed element is discarded.
        '''
        # list.pop needs a native int, so unwrap mania values first.
        position = value.value if isinstance(value, Type) else value
        self.value.pop(position)
        return self

    def first(self):
        '''Return the first element, or ``nil`` for an empty list.'''
        return self.value[0] if self.value else nil

    def rest(self):
        '''Return a new List of all elements after the first.'''
        return List(self.value[1:])

    def length(self):
        return Integer(len(self.value))

    def concat(self, other):
        '''Extend this list in place with ``other``'s elements; return self.'''
        self.value += other.value
        return self

    def to_struct(self):
        '''Build a Struct from this list's contents.

        ``izip(*self.value)`` transposes the elements, so this reads
        self.value as ``[keys, values]`` and keys on ``key.value``.
        NOTE(review): a list of ``(key, value)`` pairs (the shape
        Struct.to_list produces) would be transposed incorrectly; confirm
        which shape callers pass.
        '''
        return Struct(dict((k.value, v) for k, v in izip(*self.value)))

    def to_string(self):
        return String('(%s)' % ' '.join(
            v.to_string().value for v in self.value)
        )

    def to_native(self):
        return [i.to_native() for i in self.value]

class Struct(Type):
    '''Mutable record mapping attribute names to mania values.'''

    def __init__(self, value):
        self.value = value

    def get_attribute(self, name):
        '''Return attribute ``name``, or ``nil`` when it is not set.'''
        fields = self.value
        if name in fields:
            return fields[name]
        return nil

    def del_attribute(self, name):
        '''Remove attribute ``name`` (KeyError if absent); return self.'''
        self.value.pop(name)
        return self

    def set_attribute(self, name, value):
        '''Bind attribute ``name`` to ``value``; return self.'''
        self.value.update({name: value})
        return self

    def length(self):
        count = len(self.value)
        return Integer(count)

    def to_list(self):
        '''Return a List of ``(Name(key) value)`` pair Lists.'''
        pairs = []
        for key, item in self.value.items():
            pairs.append(List([Name(key), item]))
        return List(pairs)

    def to_string(self):
        entries = []
        for key, item in self.value.items():
            entries.append('(%s %s)' % (key, item.to_string().value))
        return String('(%s)' % ' '.join(entries))

    def to_native(self):
        return dict((key, item.to_native()) for key, item in self.value.items())

class Module(Type):
    '''Named namespace whose attributes live in the ``value`` dict.'''

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def get_attribute(self, name):
        '''Return attribute ``name``, or ``nil`` when it is missing.'''
        found = self.value.get(name, nil)
        return found

    def del_attribute(self, name):
        '''Remove attribute ``name`` (KeyError if absent); return self.'''
        namespace = self.value
        del namespace[name]
        return self

    def set_attribute(self, name, value):
        '''Bind attribute ``name`` to ``value``; return self.'''
        namespace = self.value
        namespace[name] = value
        return self

    def length(self):
        return Integer(len(self.value))

    def to_list(self):
        '''Return a List of ``(Name(key) value)`` pair Lists.'''
        pairs = [List([Name(key), item]) for key, item in self.value.items()]
        return List(pairs)

    def to_string(self):
        label = '<module %s 0x%x>' % (self.name, id(self))
        return String(label)

    def to_native(self):
        native = {}
        for key, item in self.value.items():
            native[key] = item.to_native()
        return native

class Function(Type):
    '''User-defined function: parameter nodes, body AST and closure env.'''

    error_messages = {
        'exactly': 'function takes exactly %d argument%s (%d given)',
        'no args': 'function takes no arguments (%d given)',
        'at least': 'function takes at least %d argument%s (%d given)'
    }

    def __init__(self, args, body, env):
        self.args, self.body, self.env = args, body, env

    @staticmethod
    def _param_name(param):
        '''Return the key under which ``param`` is bound in the call env.'''
        return param.name if isinstance(param, OptionalArgument) else param

    def call(self, *args):
        '''Invoke the function with already-evaluated mania arguments.

        Checks the argument count against the declared parameters.  It
        then binds positional values in a fresh Environment whose parent
        is the defining env, and binds omitted optional parameters to
        ``nil``.  When the last parameter is an ArgumentList, the
        remaining arguments are collected into a List.  Returns the
        result of evaluating the body.

        Raises TypeError when too many or too few arguments are given.
        '''
        is_list = isinstance(self.args[-1], ArgumentList) if self.args else False
        len_args, len_self_args = len(args), len(self.args)
        # One past the last positionally-bound parameter; a trailing
        # ArgumentList is bound separately below.
        fixed_end = len_self_args - 1 if is_list else len_self_args
        if len_args > len_self_args:
            if not self.args:
                raise TypeError(self.error_messages['no args'] % len_args)
            if not is_list:
                raise TypeError(self.error_messages['exactly'] % (
                    len_self_args, 's' if len_self_args > 1 else '',
                    len_args
                ))
        elif len_args < len_self_args:
            # Required count: non-optional parameters, excluding the
            # trailing ArgumentList.
            args_len = sum(
                1 for a in self.args if not isinstance(a, OptionalArgument)
            ) - (1 if is_list else 0)
            missing = self.args[len_args:fixed_end]
            if not all(isinstance(a, OptionalArgument) for a in missing):
                raise TypeError(self.error_messages['at least'] % (
                    args_len, 's' if args_len > 1 else '', len_args
                ))
        env = Environment(self.env)
        if self.args:
            for key, value in izip(self.args[:fixed_end], args):
                env.define(self._param_name(key), value)
            if len_args < len_self_args:
                # Omitted parameters are optional (checked above); bind
                # them under their name, not the OptionalArgument node.
                for key in self.args[len_args:fixed_end]:
                    env.define(self._param_name(key), nil)
            if is_list:
                key = self.args[-1].eval(env)
                # ``args`` is a tuple; give List a real list so its
                # in-place methods (push, concat) work.
                env.define(key, List(list(args[len_self_args - 1:])))
        return self.body.eval(env)

    def to_boolean(self):
        return true

    def to_string(self):
        return String('<function 0x%x>' % id(self))

    def to_native(self):
        '''Return a Python callable that converts arguments to mania
        values, calls this function, and converts the result back.'''
        def wrapper(*args):
            args = [Type.to_mania(a) for a in args]
            return self.call(*args).to_native()
        return wrapper

class FunctionWrapper(Type):
    '''Expose a native Python callable as a mania function value.'''

    def __init__(self, func):
        self.func = func

    def call(self, *args):
        '''Forward the arguments unchanged to the wrapped callable.'''
        target = self.func
        return target(*args)

    def to_boolean(self):
        return true

    def to_string(self):
        text = repr(self)
        return String(text)

    def to_native(self):
        return self.func