import logging
import urllib
import urlparse
from oic.oic import OIDCONF_PATTERN
from oic.oic.message import ProviderConfigurationResponse, AuthorizationResponse
import requests
from oic.utils.keyio import KeyJar
from oic.utils.time_util import utc_time_sans_frac
from oic.oic.provider import secret
from oic.oic.provider import RegistrationEndpoint
from oic.oic.provider import Endpoint

from oic import oauth2
from oic.oauth2 import provider
from oic.oauth2 import VerificationError
from oic.oauth2 import rndstr
from oic.oauth2 import ErrorResponse
from oic.oauth2 import UnSupported
from oic.oauth2 import Message
from oic.oauth2 import message
from oic.oauth2 import SINGLE_REQUIRED_STRING
from oic.oauth2 import OPTIONAL_LIST_OF_SP_SEP_STRINGS
from oic.oauth2 import REQUIRED_LIST_OF_STRINGS
from oic.oauth2 import OPTIONAL_LIST_OF_STRINGS
from oic.oauth2 import SINGLE_OPTIONAL_STRING
from oic.oauth2 import SINGLE_OPTIONAL_INT
from oic.exception import UnknownAssertionType
from oic.exception import PyoidcError
from oic.exception import AuthzError

from oic.utils.authn.client import AuthnFailure
from oic.utils.http_util import Unauthorized, NoContent
from oic.utils.http_util import Response
from oic.utils.http_util import BadRequest
from oic.utils.http_util import Forbidden

__author__ = 'roland'

logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------


class InvalidRedirectUri(Exception):
    """Raised when a redirect_uri in a registration request is not
    acceptable (for instance because it contains a fragment)."""


class MissingPage(Exception):
    """Raised when a URI supplied in client metadata (client_uri, logo_uri,
    tos_uri, policy_uri) could not be fetched successfully."""


class ModificationForbidden(Exception):
    """Raised when a client configuration update attempts to change a value
    the client is not allowed to change (client_id / client_secret)."""


class RegistrationRequest(Message):
    """Client registration request carrying the client metadata.

    ``verify`` adds checks on top of the generic ``Message`` verification:
    redirect URIs must not contain a fragment, the informational URIs
    (client_uri, logo_uri, tos_uri, policy_uri) must be fetchable, and
    ``response_types`` is extended so that it is consistent with
    ``grant_types``.
    """
    c_param = {
        "redirect_uris": REQUIRED_LIST_OF_STRINGS,
        "client_name": SINGLE_OPTIONAL_STRING,
        "client_uri": SINGLE_OPTIONAL_STRING,
        "logo_uri": SINGLE_OPTIONAL_STRING,
        "contacts": OPTIONAL_LIST_OF_STRINGS,
        "tos_uri": SINGLE_OPTIONAL_STRING,
        "policy_uri": SINGLE_OPTIONAL_STRING,
        "token_endpoint_auth_method": SINGLE_OPTIONAL_STRING,
        "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS,
        "grant_types": OPTIONAL_LIST_OF_STRINGS,
        "response_types": OPTIONAL_LIST_OF_STRINGS,
        "jwks_uri": SINGLE_OPTIONAL_STRING,
        "software_id": SINGLE_OPTIONAL_STRING,
        "software_version": SINGLE_OPTIONAL_STRING,
    }

    # Seconds to wait for each informational URI probed in ``verify``.
    # Without a timeout an unresponsive host would block verification
    # (and thereby the registration request) indefinitely.
    uri_check_timeout = 10

    def verify(self, **kwargs):
        """Verify the registration request.

        :param kwargs: Passed on unchanged to ``Message.verify``.
        :return: The result of ``Message.verify``.
        :raises AssertionError: if ``initiate_login_uri`` does not start
            with ``https:``.
        :raises InvalidRedirectUri: if a redirect URI contains a fragment.
        :raises MissingPage: if an informational URI cannot be fetched or
            does not answer with status 200 or 201.
        """
        if "initiate_login_uri" in self:
            # NOTE(review): plain ``assert`` is stripped under ``python -O``;
            # kept to preserve the AssertionError callers may expect.
            assert self["initiate_login_uri"].startswith("https:")

        if "redirect_uris" in self:
            for uri in self["redirect_uris"]:
                if urlparse.urlparse(uri).fragment:
                    raise InvalidRedirectUri(
                        "redirect_uri contains fragment: %s" % uri)

        # NOTE(review): these URIs are client supplied and are fetched by
        # the server, so this is a potential SSRF vector.
        for key in ("client_uri", "logo_uri", "tos_uri", "policy_uri"):
            if key not in self:
                continue
            try:
                resp = requests.request("GET", self[key],
                                        allow_redirects=True,
                                        timeout=self.uri_check_timeout)
            except requests.RequestException:
                # Connection failures, timeouts and malformed URLs all mean
                # the page is unavailable.
                raise MissingPage(self[key])

            if resp.status_code not in (200, 201):
                raise MissingPage(self[key])

        if "grant_types" in self and "response_types" in self:
            # Each of these grant types implies a response type; add it
            # when the client left it out.
            implied = {"authorization_code": "code", "implicit": "token"}
            for grant in self["grant_types"]:
                rtype = implied.get(grant)
                if rtype and rtype not in self["response_types"]:
                    self["response_types"].append(rtype)

        return super(RegistrationRequest, self).verify(**kwargs)


class ClientInfoResponse(RegistrationRequest):
    """Client information response: the registered client metadata
    extended with the identifiers and credentials issued by the server."""
    c_param = dict(
        RegistrationRequest.c_param,
        client_id=SINGLE_REQUIRED_STRING,
        client_secret=SINGLE_OPTIONAL_STRING,
        client_id_issued_at=SINGLE_OPTIONAL_INT,
        client_secret_expires_at=SINGLE_OPTIONAL_INT,
        registration_access_token=SINGLE_REQUIRED_STRING,
        registration_client_uri=SINGLE_REQUIRED_STRING,
    )


class ClientRegistrationError(ErrorResponse):
    """Error response returned by the registration and client
    configuration endpoints; ``error`` is limited to the codes below."""
    c_param = dict(ErrorResponse.c_param, state=SINGLE_OPTIONAL_STRING)
    c_allowed_values = dict(
        ErrorResponse.c_allowed_values,
        error=["invalid_redirect_uri",
               "invalid_client_metadata",
               "invalid_client_id"],
    )


class ClientUpdateRequest(RegistrationRequest):
    """Client configuration update request: the full client metadata plus
    the client's identifier (and optionally its current secret)."""
    c_param = dict(
        RegistrationRequest.c_param,
        client_id=SINGLE_REQUIRED_STRING,
        client_secret=SINGLE_OPTIONAL_STRING,
    )


# Message classes defined in this module, keyed by class name.
MSG = dict((cls.__name__, cls) for cls in (RegistrationRequest,
                                           ClientInfoResponse,
                                           ClientRegistrationError,
                                           ClientUpdateRequest))


def factory(msgtype):
    """Return the message class named ``msgtype``.

    Classes defined in this module take precedence; any other name is
    resolved by ``message.factory``.
    """
    if msgtype in MSG:
        return MSG[msgtype]
    return message.factory(msgtype)


# -----------------------------------------------------------------------------
class ClientInfoEndpoint(Endpoint):
    """Client configuration (read/update/delete) endpoint.

    ``etype`` is also used as the path segment of the
    registration_client_uri handed out at registration.
    """
    etype = "clientinfo"


class Provider(provider.Provider):
    """OAuth2 provider with dynamic client registration and a client
    configuration endpoint (read, update and delete of registrations)."""

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey="", urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed="", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400):
        """
        :param seed: Seed passed to ``secret`` when deriving client secrets.
        :param client_authn_methods: Mapping from client authentication
            method name to the class used to verify a client.
        :param authn_at_registration: If set, the key in
            ``client_authn_methods`` clients must authenticate with at the
            registration endpoint.
        :param client_info_url: Prefix for the registration_client_uri
            handed out to newly registered clients.
        :param secret_lifetime: Seconds a newly issued client secret is valid.
        :raises ValueError: if ``authn_at_registration`` is not a key of
            ``client_authn_methods``.
        The remaining parameters are passed on to ``provider.Provider``.
        """
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle)

        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint])

        # dictionary of client authentication methods
        if client_authn_methods is None:
            client_authn_methods = {}
        self.client_authn_methods = client_authn_methods
        if authn_at_registration and \
                authn_at_registration not in client_authn_methods:
            raise ValueError("Unknown client authentication method: %s" %
                             authn_at_registration)
        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime

    @staticmethod
    def _uris_to_tuples(uris):
        """Split each URI into a ``(base, query)`` tuple; ``query`` is ""
        when the URI has no query part."""
        tup = []
        for uri in uris:
            base, query = urllib.splitquery(uri)
            tup.append((base, query or ""))
        return tup

    @staticmethod
    def _tuples_to_uris(items):
        """Inverse of ``_uris_to_tuples``: rebuild URIs from
        ``(base, query)`` tuples."""
        _uri = []
        for url, query in items:
            if query:
                _uri.append("%s?%s" % (url, query))
            else:
                _uri.append(url)
        return _uri

    @staticmethod
    def _registration_error(error, err):
        """Build a BadRequest response carrying a ClientRegistrationError
        with error code ``error`` and ``err`` as description."""
        msg = ClientRegistrationError(error=error,
                                      error_description="%s" % err)
        return BadRequest(msg.to_json(), content="application/json")

    def _verify_registration_request(self, request):
        """Run ``request.verify()`` and map known failures to responses.

        :param request: A RegistrationRequest or ClientUpdateRequest.
        :return: None if the request verified, otherwise a BadRequest
            response describing the problem.
        """
        try:
            request.verify()
        except InvalidRedirectUri as err:
            return self._registration_error("invalid_redirect_uri", err)
        except (MissingPage, VerificationError) as err:
            return self._registration_error("invalid_client_metadata", err)
        return None

    def create_new_client(self, request):
        """
        Store a new client built from a registration request.

        :param request: The Client registration request
        :return: The client_id
        """

        _cinfo = request.to_dict()

        # create new id and secret; retry until the id is unused
        _id = rndstr(12)
        while _id in self.cdb:
            _id = rndstr(12)

        # Read the clock once so the expiry is exactly issued_at + lifetime
        issued_at = utc_time_sans_frac()
        _cinfo["client_id"] = _id
        _cinfo["client_secret"] = secret(self.seed, _id)
        _cinfo["client_id_issued_at"] = issued_at
        _cinfo["client_secret_expires_at"] = issued_at + self.secret_lifetime

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            _cinfo["registration_client_uri"] = "%s%s?client_id=%s" % (
                self.client_info_url, ClientInfoEndpoint.etype, _id)

        # redirect_uris are stored as (base, query) tuples
        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[_id] = _cinfo

        return _id

    def client_info(self, client_id):
        """Return a JSON ClientInfoResponse for a stored client.

        :param client_id: Key into ``self.cdb``.
        :return: A Response instance.
        """
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass

        msg = ClientInfoResponse(**_cinfo)
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """Replace a stored client's metadata with the update request.

        Metadata not present in the request is removed. Values issued by the
        server (id, secret, timestamps, registration access token/URI) are
        kept even when the request omits them.

        :param client_id: Key into ``self.cdb``.
        :param request: A verified ClientUpdateRequest.
        :raises ModificationForbidden: if the request tries to change the
            client_id or client_secret.
        """
        _cinfo = self.cdb[client_id].copy()

        for key, value in request.items():
            if key in ("client_secret", "client_id"):
                # The client may repeat these values but never change them
                if value != _cinfo.get(key):
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        preserved = ("client_id", "client_secret", "client_id_issued_at",
                     "client_secret_expires_at", "registration_access_token",
                     "registration_client_uri")
        # list() so deleting while iterating is safe
        for key in list(_cinfo.keys()):
            if key not in preserved and key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Authenticate the client using the named authentication method.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method, a key in
            ``self.client_authn_methods``
        :param client_id: If not given it is derived from the request and
            the Authorization header.
        :return: Whatever the authentication method's ``verify`` returns.
        :raises UnSupported: if ``authn_method`` is not configured.
        """

        if not client_id:
            # NOTE(review): raises KeyError if there is no Authorization
            # header in environ.
            client_id = self.get_client_id(areq, environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    def registration_endpoint(self, request, environ, **kwargs):
        """
        Register a new client.

        :param request: The request body (JSON)
        :param environ: WSGI environ, used for client authentication
        :param kwargs: extra keyword arguments
        :return: A Response instance
        """

        try:
            _request = RegistrationRequest().deserialize(request, "json")
        except ValueError:
            return BadRequest()

        error = self._verify_registration_request(_request)
        if error is not None:
            return error

        # authenticated client
        if self.authn_at_registration:
            try:
                self.verify_client(environ, _request,
                                   self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_id = self.create_new_client(_request)

        return self.client_info(client_id)

    def client_info_endpoint(self, request, environ,
                             method="GET", query="", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods: GET reads, PUT updates and DELETE removes a client.

        :param request: The request body
        :param environ: WSGI environ, used for client authentication
        :param method: HTTP method used for the request
        :param query: The query part of the URL used, this is where the
            client_id is supposed to reside.
        :param kwargs: extra keyword arguments
        :return: A Response instance
        """

        _query = urlparse.parse_qs(query)
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            self.verify_client(environ, request, "bearer_header",
                               client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)

        if method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(request)
            except ValueError:
                return BadRequest()

            error = self._verify_registration_request(_request)
            if error is not None:
                return error

            try:
                self.client_info_update(_id, _request)
            except ModificationForbidden:
                return Forbidden()
            return self.client_info(_id)

        if method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            return NoContent()

        # Always answer with a Response, even for unsupported methods
        return BadRequest("Unsupported HTTP method: %s" % method)

    def providerinfo_endpoint(self):
        """Not implemented here; does nothing."""
        pass


# Error message classes that may be returned instead of a given message type.
RESPONSE2ERROR = dict(
    (name, [ClientRegistrationError])
    for name in ("ClientInfoResponse", "ClientUpdateRequest")
)


class Client(oauth2.Client):
    """OAuth2 client supporting dynamic registration and client
    configuration management."""

    def __init__(self, client_id=None, ca_certs=None,
                 client_authn_method=None, keyjar=None, verify_ssl=True):
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)
        # Relaxed checks; handle_provider_config skips the issuer comparison
        # when the key "issuer_mismatch" is present.
        self.allow = {}
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint"
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """Construct a registration request message from ``request_args``."""

        if request_args is None:
            request_args = {}

        return self.construct_request(request, request_args, extra_args)

    def _do_client_request(self, request, body_type, method, request_args,
                           extra_args, http_args, response_cls, **kwargs):
        """Build, send and parse one registration related request.

        HTTP arguments produced by ``request_info`` (for instance client
        authentication headers) are merged into a copy of the caller's
        ``http_args`` so they are not lost; the caller's dict is not
        modified. The merge is shallow: on key clashes ``request_info``
        wins.
        """
        url, body, ht_args, _ = self.request_info(request, method,
                                                  request_args, extra_args,
                                                  **kwargs)

        if http_args is None:
            http_args = ht_args
        else:
            http_args = dict(http_args)
            if ht_args:
                http_args.update(ht_args)

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """Send a client registration request and parse the response.

        NOTE(review): registration is normally a POST; the default method
        is kept as "GET" for backward compatibility.
        """
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """Read the client's current registration."""
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """Update the client's registration."""
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """Delete the client's registration."""
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response
        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
        as attributes in self.
        :raises PyoidcError: if the issuers differ and "issuer_mismatch" is
            not in ``self.allow``.
        """

        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]
            # Compare ignoring a trailing "/": adapt the expected issuer to
            # the trailing-slash convention used in the response.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            if "issuer_mismatch" not in self.allow and \
                    _issuer != _pcr_issuer:
                raise PyoidcError(
                    "provider info issuer mismatch '%s' != '%s'" % (
                        _issuer, _pcr_issuer))

            self.provider_info[_pcr_issuer] = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN, max_redirects=10):
        """Fetch and process the provider configuration for ``issuer``.

        :param max_redirects: Maximum number of 302 redirects followed, so
            a redirect loop cannot hang the client.
        :return: The parsed provider configuration.
        :raises PyoidcError: if no 200 response was obtained.
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer

        url = serv_pattern % _issuer

        r = self.http_request(url)
        redirects = 0
        while r.status_code == 302 and redirects < max_redirects:
            r = self.http_request(r.headers["location"])
            redirects += 1

        if r.status_code != 200:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        pcr = response_cls().from_json(r.text)
        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """Remember the registration response and the credentials in it."""
        self.registration_response = reginfo
        # client_secret is optional in ClientInfoResponse (public clients)
        if "client_secret" in reginfo:
            self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """Parse a registration HTTP response and store the result.

        A successful registration may be answered with 200 or, per
        RFC 7591, 201 Created.

        :raises PyoidcError: if the registration failed.
        """
        if response.status_code in (200, 201):
            resp = ClientInfoResponse().deserialize(response.text, "json")
            self.store_registration_info(resp)
            return resp

        err = ErrorResponse().deserialize(response.text, "json")
        raise PyoidcError("Registration failed: %s" % err.to_json())

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The parsed ClientInfoResponse
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """Parse an urlencoded authorization response.

        :raises AuthzError: if the response is an ErrorResponse.
        """
        aresp = self.parse_response(AuthorizationResponse,
                                    info=query,
                                    sformat="urlencoded",
                                    keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s", aresp)
            raise AuthzError(aresp["error"])

        logger.info("Aresp: %s", aresp)

        return aresp