from pytest import mark
import sqlalchemy as sa
from tests import TestCase
from sqlalchemy import inspect
from sqlalchemy_utils.types import password
from sqlalchemy_utils import Password, PasswordType


@mark.skipif('password.passlib is None')
class TestPasswordType(TestCase):
    """Tests for the ``PasswordType`` column and its ``Password`` value.

    Skipped entirely when passlib is unavailable
    (``password.passlib is None``).
    """

    def create_models(self):
        """Define a ``User`` model with a ``PasswordType`` column.

        The column is configured with three schemes. ``md5_crypt`` is
        marked deprecated, so stored md5_crypt hashes are expected to be
        upgraded to the preferred scheme (``pbkdf2_sha512``) on a
        successful comparison -- see ``test_check_and_update``.
        """
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            password = sa.Column(PasswordType(
                schemes=[
                    'pbkdf2_sha512',
                    'pbkdf2_sha256',
                    'md5_crypt'
                ],
                deprecated=['md5_crypt']
            ))

            def __repr__(self):
                return 'User(%r)' % self.id

        self.User = User

    def test_encrypt(self):
        """Should encrypt the password on setting the attribute."""
        obj = self.User()
        obj.password = b'b'

        # The stored hash is bytes, so compare against bytes: on Python 3
        # ``bytes != str`` is always True, which would make this vacuous.
        assert obj.password.hash != b'b'
        assert obj.password.hash.startswith(b'$pbkdf2-sha512$')

    def test_check(self):
        """
        Should be able to compare the plaintext against the
        encrypted form, both before and after a database round-trip.
        """
        obj = self.User()
        obj.password = 'b'

        assert obj.password == 'b'
        assert obj.password != 'a'

        self.session.add(obj)
        self.session.commit()

        obj = self.session.query(self.User).get(obj.id)

        # Compare with a bytes plaintext here so that both str and bytes
        # inputs are exercised.
        assert obj.password == b'b'
        assert obj.password != 'a'

    def test_check_and_update(self):
        """
        Should be able to compare the plaintext against a deprecated
        encrypted form and have it auto-update to the preferred version.
        """

        from passlib.hash import md5_crypt

        obj = self.User()
        # NOTE(review): ``encrypt`` is deprecated in passlib >= 1.7 in
        # favour of ``hash``; kept for compatibility with older passlib.
        obj.password = Password(md5_crypt.encrypt('b'))

        # The hash is stored as bytes, so the prefixes must be bytes too;
        # ``bytes.startswith(str)`` raises TypeError on Python 3.
        assert obj.password.hash.startswith(b'$1$')
        # A successful comparison against a deprecated scheme re-hashes.
        assert obj.password == 'b'
        assert obj.password.hash.startswith(b'$pbkdf2-sha512$')

    def test_auto_column_length(self):
        """Should derive the correct column length from the specified schemes.
        """

        from passlib.hash import pbkdf2_sha512

        impl = inspect(self.User).c.password.type.impl

        # name + rounds + salt + hash + ($ * 4) of largest hash
        # (this mirrors the length formula PasswordType is expected to use
        # for the largest configured scheme, pbkdf2_sha512).
        expected_length = len(pbkdf2_sha512.name)
        expected_length += len(str(pbkdf2_sha512.max_rounds))
        expected_length += pbkdf2_sha512.max_salt_size
        expected_length += pbkdf2_sha512.encoded_checksum_size
        expected_length += 4

        assert impl.length == expected_length

    def test_without_schemes(self):
        """With no schemes, the column length should be 1024."""
        assert PasswordType(schemes=[]).length == 1024
