from nose.tools import assert_raises
import sqlalchemy as sa
from sqlalchemy import orm
from formosa import validators
from formosa.validators import ValidationError

class TestSame(object):
    """Tests for ``validators.Same`` over the fields ``a``, ``b`` and ``c``."""

    def setup(self):
        self.validator = validators.Same(set(['a', 'b', 'c']))

    def test_sameness(self):
        # Whichever single field differs from the other two, validation
        # must fail; only fully equal values pass.
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=1, b=2, c=1))
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=1, b=1, c=3))
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=9, b=1, c=1))
        self.validator.validate(dict(a=1, b=1, c=1))

    def test_missing(self):
        # Each case has at least one field that is None or absent, and all
        # are expected to pass -- even where the present values differ
        # (e.g. a=1, b=9 in the first case).
        self.validator.validate(dict(a=1, b=9, c=None))
        self.validator.validate(dict(a=None, b=1, c=None))
        self.validator.validate(dict(a=None, b=None, c=None))
        self.validator.validate(dict(a=1, c=8))
        self.validator.validate(dict(a=5, b=3))
        self.validator.validate(dict(b=3, c=10))


class TestDifferent(object):
    """Exercise ``validators.Different`` across the fields a, b and c."""

    def setup(self):
        fields = set('abc')
        self.validator = validators.Different(fields)

    def test_difference(self):
        validate = self.validator.validate
        # Any repeated value among a, b, c is rejected.
        rejected = [
            dict(a=1, b=1, c=1),
            dict(a=1, b=2, c=1),
            dict(a=1, b=1, c=2),
            dict(a=2, b=2, c=2),
        ]
        for values in rejected:
            assert_raises(ValidationError, validate, values)
        validate(dict(a=2, b=1, c=3))

    def test_missing(self):
        # Every case has a field that is None or absent; all must pass,
        # even when the remaining values collide.
        accepted = [
            dict(a=1, b=1, c=None),
            dict(a=None, b=1, c=None),
            dict(a=None, b=None, c=None),
            dict(a=1, c=1),
            dict(a=1, b=1),
            dict(b=1, c=1),
        ]
        for values in accepted:
            self.validator.validate(values)


class TestAtLeast(object):
    """Exercise ``validators.AtLeast`` requiring two of a, b and c."""

    def setup(self):
        self.validator = validators.AtLeast(2, set('abc'))

    def test_min(self):
        validate = self.validator.validate
        # Fewer than two non-None values among a, b, c is rejected.
        for values in (dict(a=None, b=None, c=None),
                       dict(a='Foo', b=None, c=None)):
            assert_raises(ValidationError, validate, values)
        # With ``c`` absent altogether the check is expected to pass.
        validate(dict(a=None, b=None))
        validate(dict(a='foo', b=None, c='bar'))
        validate(dict(a='foo', b='baz', c='bar'))


class TestAtMost(object):
    """Exercise ``validators.AtMost`` allowing at most one of a, b and c."""

    def setup(self):
        self.validator = validators.AtMost(1, set('abc'))

    def test_max(self):
        validate = self.validator.validate
        # The second case shows that an empty string still counts as a
        # value here; only None is treated as "not given".
        too_many = (
            dict(a='Foo', b=None, c='Foo'),
            dict(a='Foo', b=None, c=''),
            dict(a='Foo', b='Asdf', c='Jkl;'),
        )
        for values in too_many:
            assert_raises(ValidationError, validate, values)
        validate(dict(a=None, b=None, c=None))
        validate(dict(a=None, b='Bar', c=None))
        # ``c`` absent: expected to pass even though a and b are both set.
        validate(dict(a='Foo', b='Bar'))


class TestOnlyIf(object):
    """Exercise ``validators.OnlyIf`` with one and with several dependencies."""

    def setup(self):
        single = validators.OnlyIf(['hire_date'], ['interview_date'])
        multiple = validators.OnlyIf(set('xy'), set('abc'))
        self.single_dependency_validator = single
        self.multiple_dependency_validator = multiple

    def test_single_dependency(self):
        validate = self.single_dependency_validator.validate
        # A hire date without an interview date is rejected...
        assert_raises(ValidationError, validate,
                      dict(hire_date='20071126', interview_date=None))
        # ...while every other combination is accepted.
        for hire, interview in ((None, None),
                                (None, '20071126'),
                                ('20071126', '20071112')):
            validate(dict(hire_date=hire, interview_date=interview))

    def test_multiple_dependencies(self):
        validate = self.multiple_dependency_validator.validate
        # With both x and y given, any None among a, b, c is rejected.
        for a, b, c in ((None, None, None),
                        (1, None, None),
                        (1, 2, None),
                        (1, None, 2),
                        (None, 1, 2)):
            assert_raises(ValidationError, validate,
                          dict(x='foo', y='bar', a=a, b=b, c=c))
        # ``y`` absent: expected to pass even with a, b, c all None.
        validate(dict(x='foo', a=None, b=None, c=None))
        validate(dict(x='foo', y='bar', a=1, b=2, c=3))


class TestAscending(object):
    """Tests for ``validators.Ascending`` over the ordered fields a, b, c."""

    def setup(self):
        self.validator = validators.Ascending(['a', 'b', 'c'])

    def test_order(self):
        self.validator.validate(dict(a=1, b=2, c=3))
        # With ``b`` absent the check is expected to pass even though a > c.
        self.validator.validate(dict(a=5, c=3))
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=1, b=3, c=2))


class TestDescending(object):
    """Tests for ``validators.Descending`` over the ordered fields a, b, c."""

    def setup(self):
        self.validator = validators.Descending(['a', 'b', 'c'])

    def test_order(self):
        self.validator.validate(dict(a=3, b=2, c=1))
        # With ``b`` absent the check is expected to pass even though a < c.
        self.validator.validate(dict(a=3, c=5))
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=5, b=1, c=2))


class TestEqualTo(object):
    """Tests for ``validators.EqualTo(42, ...)`` over the fields a, b, c."""

    def setup(self):
        self.validator = validators.EqualTo(42, set(['a', 'b', 'c']))

    def test_sameness(self):
        # Each field in turn is made the odd one out; every such case must
        # be rejected, and only all-42 passes.
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=42, b=43, c=42))
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=42, b=42, c=43))
        assert_raises(ValidationError, self.validator.validate,
                      dict(a=43, b=42, c=42))
        self.validator.validate(dict(a=42, b=42, c=42))

    def test_missing(self):
        # Each case has at least one field that is None or absent, and all
        # are expected to pass -- even where present values are not 42.
        self.validator.validate(dict(a=1, b=9, c=None))
        self.validator.validate(dict(a=None, b=42, c=None))
        self.validator.validate(dict(a=None, b=None, c=None))
        self.validator.validate(dict(a=1, c=8))
        self.validator.validate(dict(a=5, b=3))
        self.validator.validate(dict(b=3, c=10))


class Record(object):
    """Plain row object that ``TestUnique`` maps onto the ``records`` table.

    ``name`` is the primary key; ``a`` and ``b`` are optional integers
    that default to None.
    """

    def __init__(self, name, a=None, b=None):
        self.name, self.a, self.b = name, a, b


class TestUnique(object):
    """Tests for ``validators.Unique`` against an in-memory SQLite table.

    ``setup`` creates the table, maps ``Record`` onto it and inserts the
    fixture rows; ``teardown`` releases the session and engine and removes
    the mapping again.
    """

    def setup(self):
        metadata = sa.MetaData()
        records_table = sa.Table('records', metadata,
                                 sa.Column('name', sa.Text, primary_key=True),
                                 sa.Column('a', sa.Integer),
                                 sa.Column('b', sa.Integer),
                                 sa.UniqueConstraint('a', 'b'))

        # Keep a reference to the engine so teardown can dispose of it.
        self.engine = sa.create_engine('sqlite://')
        metadata.create_all(self.engine)

        orm.mapper(Record, records_table)
        # The rows containing NULL do not collide under the (a, b) unique
        # constraint because SQLite treats NULLs as distinct there.
        records = {'foo': Record('foo', 1, 1),
                   'bar': Record('bar', 1, 2),
                   'baz': Record('baz', 2, 1),
                   'quux': Record('quux', 1, None),
                   'quuux': Record('quuux', None, 1),
                   'quuuux': Record('quuuux', None, None)}

        self.session = orm.create_session(self.engine, autocommit=False)
        # Stage every fixture row, then commit them in a single transaction.
        for record in records.values():
            self.session.save(record)
        self.session.commit()

        self.name_validator = validators.Unique(self.session, Record,
                                                {'name': 'record_name'})
        self.name_validator_for_foo \
            = validators.Unique(self.session, Record, {'name': 'record_name'},
                                records['foo'])
        self.a_b_validator = validators.Unique(self.session, Record,
                                               {'a': 'a', 'b': 'b'})
        self.a_b_validator_for_baz \
            = validators.Unique(self.session, Record, {'a': 'a', 'b': 'b'},
                                records['baz'])

    def teardown(self):
        # Release the session's connection and the engine's pool before
        # removing the mapping, so nothing leaks into the next test.
        self.session.close()
        self.engine.dispose()
        orm.clear_mappers()

    def test_single_prop(self):
        assert_raises(ValidationError, self.name_validator.validate,
                      {'record_name': 'foo'})
        self.name_validator.validate({'record_name': 'xyz'})

    def test_single_prop_with_inst(self):
        # The validator bound to 'foo' accepts foo's own name.
        assert_raises(ValidationError, self.name_validator_for_foo.validate,
                      {'record_name': 'bar'})
        self.name_validator_for_foo.validate({'record_name': 'foo'})

    def test_multi_prop(self):
        assert_raises(ValidationError, self.a_b_validator.validate,
                      dict(a=1, b=1))
        assert_raises(ValidationError, self.a_b_validator.validate,
                      dict(a=1, b=2))
        # No fixture row has (a=2, b=2), so this combination is unique.
        self.a_b_validator.validate(dict(a=2, b=2))

    def test_multi_prop_with_inst(self):
        # The validator bound to 'baz' accepts baz's own (a=2, b=1).
        assert_raises(ValidationError, self.a_b_validator_for_baz.validate,
                      dict(a=1, b=2))
        self.a_b_validator_for_baz.validate(dict(a=2, b=1))

    def test_missing_params(self):
        self.a_b_validator.validate(dict(a=None, b=None))
        self.a_b_validator.validate(dict(a=1, b=None))
        self.a_b_validator.validate(dict(a=1))
        self.a_b_validator.validate(dict())


class TestIfAnyThen(object):
    """Exercise ``validators.IfAnyThen``.

    As these cases expect, supplying any of a, b, c makes all of x, y, z
    required.
    """

    def setup(self):
        self.validator = validators.IfAnyThen(list('abc'), list('xyz'))

    def test_requirements(self):
        validate = self.validator.validate
        # A trigger field is present but at least one of x, y, z is missing.
        rejected = (
            dict(a=1),
            dict(b=1, x=1, y=2),
            dict(c=1, z=1, y=2),
            dict(a=5, c=1, z=1, y=2),
        )
        for values in rejected:
            assert_raises(ValidationError, validate, values)
        # Either no trigger field is present, or x, y and z are all given.
        accepted = (
            dict(),
            dict(x=1, z=3),
            dict(a=1, x=1, y=2, z=3),
            dict(b=5, x=1, y=2, z=3),
            dict(c=9, x=1, y=2, z=3),
            dict(a=1, b=7, c=9, x=1, y=2, z=3),
        )
        for values in accepted:
            validate(values)
