# copyright 2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr -- mailto:contact@logilab.fr
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program. If not, see <http://www.gnu.org/licenses/>.

"""cubicweb-signedrequest automatic tests for authentication
"""
import hmac
from datetime import datetime, timedelta
from operator import itemgetter
from StringIO import StringIO

from cubicweb.devtools.testlib import CubicWebTC
from cubicweb.web.controller import Controller
from cubicweb.predicates import authenticated_user

HEADERS_TO_SIGN = ('Content-MD5', 'Content-Type', 'Date')

class TestController(Controller):
    """Controller registered as 'testauth', selectable for authenticated users.

    It lets the tests check which user a request ended up authenticated as.
    """
    __regid__ = 'testauth'
    __select__ = authenticated_user()

    def publish(self, rset):
        """Tell whether the connected user is the expected one.

        The expected login is read from the ``expected`` form parameter and
        defaults to ``'admin'``. Return ``u'VALID'`` when the login of the
        request's user equals it, ``u'INVALID'`` otherwise.
        """
        actual_login = self._cw.user.login
        expected_login = self._cw.form.get('expected', 'admin')
        if actual_login != expected_login:
            return u'INVALID'
        return u'VALID'


class TrustedAuthTC(CubicWebTC):
    """Functional tests for authentication through a signed HTTP request.

    Every test sends a request to :class:`TestController` carrying an
    ``Authorization`` header of the form ``<scheme> <token id>:<signature>``
    and checks the HTTP status set on the request (and, on success, the
    controller's answer).
    """

    # Format of the ``Date`` header sent with the test requests.
    http_date_format = '%a, %d %b %Y %H:%M:%S GMT'

    def setup_database(self):
        """Create an enabled AuthToken with id "admin" for the admin user."""
        req = self.request()
        req.execute('INSERT AuthToken T: T token "my precious", '
                    '                    T token_for_user U, '
                    '                    T id "admin", '
                    '                    T enabled True'
                    ' WHERE U login "admin"')

    def _http_date(self, moment=None):
        """Return `moment` formatted with :attr:`http_date_format`.

        :param moment: a ``datetime``; defaults to ``datetime.utcnow()``.
        """
        if moment is None:
            moment = datetime.utcnow()
        return moment.strftime(self.http_date_format)

    def _build_string_to_sign(self, headers, method='GET'):
        """Return the string the client signs for a request to the controller.

        The result is the concatenation of `method`, the URL of the test
        controller and the values of the ``HEADERS_TO_SIGN`` headers, in that
        order.

        :param headers: dict that must hold every header of
            ``HEADERS_TO_SIGN`` (``KeyError`` otherwise).
        :param method: HTTP verb placed at the start of the string.
        """
        get_headers = itemgetter(*HEADERS_TO_SIGN)
        url = self.request().build_url(TestController.__regid__)
        return method + url + ''.join(get_headers(headers))

    def _build_signature(self, id, string_to_sign):
        """Return the hex HMAC of `string_to_sign` keyed by a stored token.

        :param id: value of the ``id`` attribute of the AuthToken whose
            ``token`` attribute is used as the HMAC key.
        :param string_to_sign: the message to sign.
        """
        req = self.request()
        rset = req.execute('Any K WHERE T id %(id)s, T token K',
                           {'id': id})
        assert rset, 'no AuthToken found with id %r' % (id,)
        return hmac.new(str(rset[0][0]), string_to_sign).hexdigest()

    def _test_header_format(self, method, login, signature, http_method='GET',
                            headers=None):
        """Send a signed request to the test controller.

        :param method: authentication scheme put in the ``Authorization``
            header (e.g. ``'Cubicweb'``); this is NOT the HTTP verb.
        :param login: token identifier put before the ``:`` in the header.
        :param signature: signature put after the ``:`` in the header.
        :param http_method: HTTP verb of the request; a ``'POST'`` request
            gets an arbitrary url-encoded body.
        :param headers: extra raw request headers, as a dict.
        :return: ``(result, req)`` -- the value returned by
            ``self.app.handle_request`` and the request object.
        """
        if headers is None:
            headers = {}
        with self.temporary_appobjects(TestController):
            req = self.requestcls(self.vreg, url=TestController.__regid__,
                                  method=http_method)
            req.form['expected'] = 'admin'
            # Fill an arbitrary body content if POST.
            if http_method == 'POST':
                req.content = StringIO("rql=Any+X+WHERE+X+is+Player")
            self.set_auth_mode('http')
            req.set_request_header('Authorization',
                                   '%s %s:%s' % (method, login, signature),
                                   raw=True)
            for name, value in headers.items():
                req.set_request_header(name, value, raw=True)
            # Saved before entering the ``try`` so that the ``finally`` clause
            # can never hit an unbound name and mask the original error.
            fake_error_handler = self.app.error_handler
            try:
                # re-enable normal error handling (presumably the attribute
                # removed here is a test-only handler set on the instance)
                del self.app.error_handler
                result = self.app.handle_request(req, 'testauth')
            finally:
                self.app.error_handler = fake_error_handler
        return result, req

    def get_valid_authdata(self, headers=None):
        """Return ``(signature, headers)`` for a correctly signed GET request.

        :param headers: optional dict of header values overriding the
            defaults; it is copied, never modified.
        :return: the signature computed with the "admin" token over the
            completed headers, and the completed headers dict.
        """
        headers = dict(headers) if headers else {}
        headers.setdefault('Content-MD5', 'aa3d66a90f73242ef6f679ce26b3691e')
        headers.setdefault('Content-Type', 'application/xhtml+xml')
        headers.setdefault('Date', self._http_date())
        string_to_sign = self._build_string_to_sign(headers)
        signature = self._build_signature('admin', string_to_sign)
        return signature, headers

    def test_login(self):
        """A correctly signed request is accepted and run as admin."""
        signature, headers = self.get_valid_authdata()
        result, req = self._test_header_format(method='Cubicweb',
                                               login='admin',
                                               signature=signature,
                                               headers=headers)
        self.assertEqual(200, req.status_out)
        self.assertEqual('VALID', result)

    def test_bad_date(self):
        """Dates 1000s in the future or past, or unparsable, give a 401."""
        now = datetime.utcnow()
        skew = timedelta(seconds=1000)
        bad_dates = (self._http_date(now + skew),
                     self._http_date(now - skew),
                     'toto')
        for date in bad_dates:
            signature, headers = self.get_valid_authdata({'Date': date})
            result, req = self._test_header_format(method='Cubicweb',
                                                   login='admin',
                                                   signature=signature,
                                                   headers=headers)
            self.assertEqual(401, req.status_out,
                             'unexpected status %s for Date %r'
                             % (req.status_out, date))

    def test_bad_http_auth_method(self):
        """An authentication scheme other than 'Cubicweb' gives a 401."""
        signature = self._build_signature('admin', '')
        result, req = self._test_header_format(method='AWS', login='admin',
                                               signature=signature)
        self.assertEqual(401, req.status_out)

    def test_bad_signature(self):
        """A garbage signature gives a 401."""
        result, req = self._test_header_format(method='Cubicweb',
                                               login='admin',
                                               signature='YYY')
        self.assertEqual(401, req.status_out)

    def test_deactivated_token(self):
        """A valid signature made with a disabled token gives a 401."""
        req = self.request()
        req.execute('SET T enabled False WHERE T token_for_user U, U login %(l)s',
                    {'l': 'admin'})
        self.commit()
        signature, headers = self.get_valid_authdata()
        result, req = self._test_header_format(method='Cubicweb',
                                               login='admin',
                                               signature=signature,
                                               headers=headers)
        self.assertEqual(401, req.status_out)

    def test_bad_signature_url(self):
        """A signature computed without the method and URL gives a 401."""
        # Only the headers are needed here; the valid signature is discarded.
        _valid_signature, headers = self.get_valid_authdata()
        # Sign the headers alone, leaving out the HTTP method and the URL
        # that `_build_string_to_sign` normally prepends.
        get_headers = itemgetter(*HEADERS_TO_SIGN)
        string_to_sign = ''.join(get_headers(headers))
        signature = self._build_signature('admin', string_to_sign)
        result, req = self._test_header_format(method='Cubicweb',
                                               login='admin',
                                               signature=signature,
                                               headers=headers)
        self.assertEqual(401, req.status_out)

    def test_post_http_request_signature(self):
        """A correctly signed POST request is accepted and run as admin."""
        headers = {'Content-MD5': '43115f65c182069f76b56df967e5c3fd',
                   'Content-Type': 'application/x-www-form-urlencoded',
                   'Date': self._http_date()}
        string_to_sign = self._build_string_to_sign(headers, method='POST')
        signature = self._build_signature('admin', string_to_sign)
        result, req = self._test_header_format(method='Cubicweb',
                                               login='admin',
                                               signature=signature,
                                               http_method='POST',
                                               headers=headers)
        self.assertEqual(200, req.status_out)
        self.assertEqual('VALID', result)

if __name__ == "__main__":
    # Executed as a script: hand control to logilab's unittest entry point.
    import logilab.common.testlib as _testlib
    _testlib.unittest_main()
