# coding: utf-8

"""
Snooze: a backend-agnostic REST API provider for Flask.

e.g.

    from flask import Flask, Blueprint
    from flask.ext.sqlalchemy import SQLAlchemy
    from flask.ext.snooze import Snooze, SqlAlchemyEndpoint
    from my_model import sqlalchemy_db, Book

    app = Flask(__name__)
    api = Blueprint('api_v1', __name__)
    apimgr = Snooze(api)
    apimgr.add(SqlAlchemyEndpoint(sqlalchemy_db, Book, ['author', 'title']))

    app.register_blueprint(api, url_prefix='/api_v1')
"""

from flask import request, make_response
import re

try:
    import simplejson as json
except ImportError:
    import json


class NotFoundError(Exception):

    """
    Resource not found.

    Raised by endpoint read() implementations (e.g. SqlAlchemyEndpoint.read)
    when no object of class ``cls`` has the ID ``obj_id``.
    """

    def __init__(self, cls, obj_id):
        """
        cls:    class of the object that was looked up
        obj_id: the ID that could not be found
        """
        message = 'No %(cls)s exists with an ID of %(obj_id)s' % dict(
            cls=cls.__name__,
            obj_id=obj_id
        )
        # Hand the message to Exception so str(e) and e.args describe the
        # error, instead of being empty.
        super(NotFoundError, self).__init__(message)

        self.cls = cls
        self.obj_id = obj_id
        self.message = message


def error_dict(etype, message, **kwargs):
    """
    Build the payload used for API error responses.

    Returns {'type': etype, 'message': message}; when any extra keyword
    arguments are given they are attached under the 'detail' key.
    """
    payload = {'type': etype, 'message': message}
    if not kwargs:
        return payload
    payload['detail'] = kwargs
    return payload


def wrap_verb_call(call, pass_obj_id, endpoint, data_in, data_out):
    """
    Construct a callback that will wrap a given HTTP Verb call, optionally
    passing an object ID.

    call:        Snooze verb method, called as call(endpoint, data) or, when
                 pass_obj_id is true, call(endpoint, obj_id, data)
    pass_obj_id: whether the view's obj_id argument is passed on to call
    endpoint:    the Endpoint the verb operates on
    data_in:     decoder applied to the raw request body
    data_out:    encoder applied to the result / error payload

    The returned view decodes the request body (an empty body becomes an
    empty dict) and calls ``call``. It then encodes the result with
    data_out. NotFoundError becomes a 404 response. Any other exception
    becomes a ``(body, '500')`` tuple whose body describes the error.
    """
    def f(obj_id=None):
        try:
            # Decode inside the try so malformed or non-dict bodies produce
            # the same structured error response as handler failures.
            # Truthiness (rather than != '') also treats an empty bytes
            # body as empty.
            data = data_in(request.data) if request.data else dict()
            assert isinstance(data, dict), "Data must be a dict"
            if pass_obj_id:
                res = call(endpoint, obj_id, data)
            else:
                res = call(endpoint, data)
            try:
                # NB. error_data used because Flask stringifies stuff we put
                #     into res.data, which isn't good for us
                res.data = data_out(res.error_data)
            except AttributeError:
                try:
                    res.data = data_out(res.data)
                except AttributeError:
                    # Plain results (dict, list, None) are encoded directly
                    res = data_out(res)
        except NotFoundError as e:
            res = make_response()
            res.status = '404'
            res.data = data_out(error_dict(**{
                'etype': type(e).__name__,
                'message': e.message,
                'class': e.cls.__name__,
                'id': e.obj_id
            }))
        except Exception as e:
            import sys
            from traceback import extract_tb
            # extract_tb yields FrameSummary objects on newer Pythons, which
            # the JSON encoder rejects; lists serialise like the old tuples.
            # NOTE(review): sending tracebacks to clients leaks server
            # internals; consider restricting this to debug mode.
            frames = [list(frame) for frame in extract_tb(sys.exc_info()[2])]
            # Python 3 exceptions have no .message attribute
            message = getattr(e, 'message', None) or str(e)
            res = data_out(error_dict(type(e).__name__,
                                      message,
                                      traceback=frames)), '500'
        return res
    return f


def response_redirect(endpoint, o, code):
    """
    Build an empty response with the given status code and a Location
    header.

    The Location header is the current request path with its last segment
    replaced by the object's ID (read from the attribute endpoint.id_key).
    """
    response = make_response()
    collection_path = re.sub('[^/]*$', '', request.path)
    object_id = getattr(o, endpoint.id_key)
    response.headers['Location'] = '%s%s' % (collection_path, object_id)
    response.status = str(code)
    return response


class Snooze(object):

    """
    The API manager.

    All endpoints added to one Snooze instance share the same data_in and
    data_out hooks, so every verb takes in and gives out data in the same
    way (JSON by default).
    """

    def __init__(self, app, hooks=None):
        """
        app:   object whose route() decorator registers the views (e.g. a
               Flask app or Blueprint)
        hooks: optional dict; 'data_in' decodes request bodies (default
               json.loads) and 'data_out' encodes responses (default
               json.dumps)
        """
        self._app = app
        hooks = dict() if hooks is None else hooks
        self._hook_data_in = hooks.get('data_in', json.loads)
        self._hook_data_out = hooks.get('data_out', json.dumps)
        # Route rule -> list of HTTP verbs registered on it (served by OPTIONS)
        self._routes = {}

    def add(self, endpoint, name=None, methods=(
            'OPTIONS', 'LIST', 'POST',
            'GET', 'PUT', 'PATCH', 'DELETE')):
        """
        Add an endpoint for a class. The name defaults to a lowercase version
        of the class name but can be overridden.

        Collection verbs (OPTIONS, LIST, POST) are routed at /<name>/ and
        item verbs (GET, PUT, PATCH, DELETE) at /<name>/<obj_id>. LIST is
        served as GET on the collection route.

        Methods can be specified. Note that HEAD is automatically generated
        by Flask to execute the GET method without returning a body.
        """
        obj_name = endpoint.cls.__name__.lower() if name is None else name
        methods = [m.upper() for m in methods]

        for verb in 'OPTIONS', 'LIST', 'POST':
            if verb not in methods:
                continue

            view = wrap_verb_call(call=getattr(self, '_%s' % verb.lower()),
                                  pass_obj_id=False,
                                  endpoint=endpoint,
                                  data_in=self._hook_data_in,
                                  data_out=self._hook_data_out)

            # A bit of a hack, but this use of GET seems like a separate verb
            # to me, hence LIST
            self._register(obj_name=obj_name,
                           needs_id=False,
                           verb='GET' if verb == 'LIST' else verb,
                           func=view)

        for verb in 'GET', 'PUT', 'PATCH', 'DELETE':
            if verb not in methods:
                continue

            view = wrap_verb_call(call=getattr(self, '_%s' % verb.lower()),
                                  pass_obj_id=True,
                                  endpoint=endpoint,
                                  data_in=self._hook_data_in,
                                  data_out=self._hook_data_out)

            self._register(obj_name=obj_name,
                           needs_id=True,
                           verb=verb,
                           func=view)

    #
    # Verbs
    #

    def _options(self, endpoint, data):
        """HTTP Verb endpoint: map of every registered route to its verbs"""
        return self._routes

    def _list(self, endpoint, data):
        """HTTP Verb endpoint (GET without an ID): all object IDs"""
        # NB. the query itself lives in endpoint.list_ids()
        return endpoint.list_ids()

    def _post(self, endpoint, data):
        """HTTP Verb endpoint: create an object, reply 201 with a Location"""
        o = endpoint.create()
        if data is not None:
            self._fill(endpoint, o, data)

        return response_redirect(endpoint, o, 201)

    def _get(self, endpoint, obj_id, data):
        """HTTP Verb endpoint: the object with the given ID, as a dict"""
        o = endpoint.read(obj_id)

        if isinstance(o, dict):
            return o

        # Mappings and iterables of (key, value) pairs convert themselves
        if hasattr(o, 'keys') or hasattr(o, '__iter__'):
            return dict(o)

        # dict() would raise TypeError here (e.g. a model instance that
        # defines no __iter__), so expose the ID and the writeable keys.
        keys = [endpoint.id_key] + [k for k in endpoint.writeable_keys
                                    if k != endpoint.id_key]
        return dict((k, getattr(o, k)) for k in keys)

    def _put(self, endpoint, obj_id, data):
        """
        HTTP Verb endpoint: replace the object with the given ID, creating it
        if endpoint.read raises NotFoundError. Replies 201 with a Location
        when created; otherwise returns None.
        """
        created = False
        try:
            o = endpoint.read(obj_id)
        except NotFoundError:
            o = endpoint.create(obj_id)
            created = True

        self._fill(endpoint, o, data)

        if created:
            return response_redirect(endpoint, o, 201)

    def _patch(self, endpoint, obj_id, data):
        """HTTP Verb endpoint: update only the supplied keys"""
        o = endpoint.read(obj_id)
        self._update(endpoint, o, data)

    def _delete(self, endpoint, obj_id, data):
        """HTTP Verb endpoint"""
        endpoint.delete(obj_id)

    #
    # Tools
    #

    def _update(self, endpoint, o, data):
        """
        Copy every key of data onto o as an attribute, then finalize o.

        Raises AssertionError if any key is not in endpoint.writeable_keys.
        It is raised explicitly (not via assert) so the check survives
        ``python -O``. All keys are validated before any is set, so a
        rejected update leaves o untouched.
        """
        for k in data:
            if k not in endpoint.writeable_keys:
                raise AssertionError(
                    "Cannot update key %s, valid keys for update: %s" %
                    (k, ', '.join(endpoint.writeable_keys)))
        for k in data:
            setattr(o, k, data[k])
        endpoint.finalize(o)

    def _fill(self, endpoint, o, data):
        """
        Like _update, but data must supply exactly the writeable keys.

        Raises AssertionError (explicitly, so it survives ``python -O``)
        when the key sets differ. The keys in the message are sorted so it
        is deterministic.
        """
        items_set = set(endpoint.writeable_keys)
        keys_set = set(data.keys())
        if items_set != keys_set:
            raise AssertionError(
                "The provided keys (%s) do not match the expected items (%s)"
                % (', '.join(sorted(keys_set)), ', '.join(sorted(items_set))))

        self._update(endpoint, o, data)

    def _register(self, obj_name, needs_id, verb, func):
        """Register func for verb on /<obj_name>/[<obj_id>] and record it."""
        # Snooze serves OPTIONS itself (_options), so stop Flask from
        # adding its automatic OPTIONS handling to this view.
        func.provide_automatic_options = False
        id_part = '<obj_id>' if needs_id else ''
        route = '/%s/%s' % (obj_name, id_part)
        self._app.route(route,
                        methods=(verb,),
                        endpoint='%s:%s/%s' % (verb, obj_name, id_part))(func)
        verbs = self._routes.setdefault(route, [])
        verbs.append(verb)
        if verb == 'GET':
            # Flask adds 'HEAD' for GET
            verbs.append('HEAD')


class Endpoint(object):

    """
    Base Endpoint object.

    Subclasses adapt a storage backend to Snooze by implementing list_ids,
    create, read, finalize and delete; every method here raises
    NotImplementedError.
    """

    def __init__(self, cls, id_key, writeable_keys):
        """
        cls:            Class of object being represented by this endpoint
        id_key:         Identifying key of an object
        writeable_keys: A list of keys that may be written to on an object
        """
        self.cls = cls
        self.id_key = id_key
        self.writeable_keys = writeable_keys

    def list_ids(self):
        """List all accessible ids"""
        raise NotImplementedError()

    def create(self, obj_id=None):
        """
        Create a new object.

        obj_id: ID to give the new object. Snooze passes it when a PUT
                targets an ID that does not exist yet, and omits it for POST.
        """
        raise NotImplementedError()

    def read(self, obj_id):
        """
        Load an existing object.

        Must raise NotFoundError when no object has this ID; Snooze relies
        on that to answer 404 and to let PUT create missing objects.
        """
        raise NotImplementedError()

    def finalize(self, obj):
        """Save an object (if required); called after every update"""
        raise NotImplementedError()

    def delete(self, obj_id):
        """Delete the data for the provided ID"""
        raise NotImplementedError()


def row2dict(row):
    """
    Return a dict mapping every column name of row.__table__ to the value
    of that attribute on row.

    Adapted from:
    http://stackoverflow.com/questions/
        1958219/convert-sqlalchemy-row-object-to-python-dict
    """
    column_names = row.__table__.columns.keys()
    return dict((name, getattr(row, name)) for name in column_names)


class SqlAlchemyEndpoint(Endpoint):

    """
    Endpoint backed by a SQLAlchemy-mapped class.

    db:    database object whose .session is used for queries and commits
    cls:   the mapped class; reads go through cls.query (as provided by
           Flask-SQLAlchemy models)
    items: the writeable keys (see Endpoint.writeable_keys)
    """

    def __init__(self, db, cls, items):
        from sqlalchemy.orm import class_mapper
        self.db = db
        # Only the first primary-key column is used as the ID, so composite
        # primary keys are not supported.
        # NOTE(review): id_key is the column name; this assumes the mapped
        # attribute is named like the column.
        self.pk = class_mapper(cls).primary_key[0]
        super(SqlAlchemyEndpoint, self).__init__(cls, self.pk.name, items)

    def list_ids(self):
        """All primary-key values in the table"""
        return [row[0] for row in self.db.session.query(self.pk).all()]

    def create(self, obj_id=None):
        """New, unsaved instance; obj_id (if given) is set as its ID"""
        o = self.cls()
        if obj_id is not None:
            setattr(o, self.id_key, obj_id)
        return o

    def read(self, obj_id):
        """The instance with this ID; raises NotFoundError if there is none"""
        # The filter is on the primary key, so fetching the first match is
        # enough; no need to materialise a list.
        o = self.cls.query.filter(self.pk == obj_id).first()
        if o is None:
            raise NotFoundError(self.cls, obj_id)
        return o

    def finalize(self, obj):
        """Add obj to the session and commit"""
        self.db.session.add(obj)
        self.db.session.commit()

    def delete(self, obj_id):
        """Delete the instance with this ID and commit"""
        o = self.read(obj_id)
        self.db.session.delete(o)
        # Commit like finalize() does; otherwise the deletion only reaches
        # the database if something else later commits this session.
        self.db.session.commit()
