# Copyright 2014 NYBX Inc.
# All rights reserved.

"""
:module: ledgerx.api.client.test.test_client_protocol_0_0_1
:synopsis: Unit tests for client protocol messages.
:author: Amr Ali <amr@ledgerx.com>

This test module is only concerned with v0.0.1 of client protocol messages.
See <test_client_protocol> module for version-agnostic messages.
"""

import unittest

from ledgerx.api.client.protocol import ClientMessage
from ledgerx.api.client.v0_0_1 import (MessageStatus,
        MessageContractIDMixin, MessageEntryIDMixin, MessageSubmissionMixin,
        MessageLimitOrder, MessageMarketOrder, MessageQuote,
        MessageCancelReplace, MessageCancel, MessageReduceSize,
        MessageActionReport, MessageBookTop, MessageGetContractDetail,
        MessageContractDetail, MessageSCPOrder, MessageExerciseRequest,
        MessageTIMixin, UUID_HEX_NUM_BYTES, MessageActionReportMixin,
        MessageGetBookState, MessageBookState, MessageReplay)

class TestMessageStatus(unittest.TestCase):
    """Pin the numeric status codes and default version of MessageStatus."""

    def test_status_codes(self):
        # Ordered (attribute name, expected value) pairs. The 6xx entries all
        # name error conditions; the 7xx entries name successful operations.
        expected_codes = (
            ('STATUS_CONTRACT_NOT_FOUND', 600),
            ('STATUS_ENTRY_NOT_FOUND', 601),
            ('STATUS_ENTRYID_IS_INVALID', 602),
            ('STATUS_CANCEL_FAILURE', 603),
            ('STATUS_CANCEL_REPLACE_FAILURE', 604),
            ('STATUS_REDUCE_SIZE_UNDERFLOW', 605),
            ('STATUS_REDUCE_SIZE_FAILURE', 606),
            ('STATUS_MARKET_HOURS_DENIED', 607),
            ('STATUS_INVALID_MIN_INCREMENT', 608),
            ('STATUS_CANCEL_SUCCESS', 700),
            ('STATUS_CANCEL_REPLACE_SUCCESS', 701),
            ('STATUS_REDUCE_SIZE_SUCCESS', 702),
        )
        for attr_name, code in expected_codes:
            self.assertEqual(getattr(MessageStatus, attr_name), code)

    def test_default_version(self):
        # A freshly constructed status message reports protocol v0.0.1.
        status = MessageStatus()
        self.assertEqual(status.mversion, '0.0.1')

class TestClientMixins(unittest.TestCase):
    """Exercise each v0.0.1 mixin through a bare subclass that adds nothing.

    Each test defines a throwaway ``_Probe`` class deriving only from the
    mixin under test, then checks its defaults, its validation and the
    ``fullfills`` relationships it should report.
    """

    def test_message_contract_id_mixing(self):
        class _Probe(MessageContractIDMixin):
            pass

        probe = _Probe()
        # Defaults to 0; an empty bytes value is rejected.
        self.assertEqual(probe.contract_id, 0)
        with self.assertRaises(ValueError):
            probe.contract_id = b''
        # A bool is accepted and reads back as 1.
        probe.contract_id = True
        self.assertEqual(probe.contract_id, 1)
        self.assertTrue(probe.fullfills(MessageContractIDMixin))

    def test_message_entry_id_mixing(self):
        class _Probe(MessageEntryIDMixin):
            pass

        def zero_padded(raw):
            # Expected stored form: left-padded with b'0' to the UUID width.
            return raw.rjust(UUID_HEX_NUM_BYTES, b'0')

        probe = _Probe()
        self.assertEqual(probe.entry_id, zero_padded(b''))
        probe.entry_id = b'1'
        self.assertEqual(probe.entry_id, zero_padded(b'1'))
        self.assertTrue(probe.fullfills(MessageEntryIDMixin))

    def test_message_submission_mixing(self):
        class _Probe(MessageSubmissionMixin):
            pass

        probe = _Probe()

        # is_ask: falsy by default, truthy after assigning a truthy value.
        self.assertFalse(probe.is_ask)
        probe.is_ask = 1
        self.assertTrue(probe.is_ask)

        # size and price_in_cents share the same contract: default 0,
        # an empty string is rejected, an int round-trips.
        for field in ('size', 'price_in_cents'):
            self.assertEqual(getattr(probe, field), 0)
            with self.assertRaises(ValueError):
                setattr(probe, field, '')
            setattr(probe, field, 1)
            self.assertEqual(getattr(probe, field), 1)

        # On the bare mixin, mox2 is falsy and read-only.
        self.assertFalse(probe.mox2)
        with self.assertRaises(AttributeError):
            probe.mox2 = 1

        # A submission also carries a contract id.
        self.assertTrue(probe.fullfills(MessageContractIDMixin))
        self.assertTrue(probe.fullfills(MessageSubmissionMixin))

    def test_message_ti_mixing(self):
        class _Probe(MessageTIMixin):
            pass

        probe = _Probe()
        accepted = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ:0123456789'

        self.assertIsNone(probe.ti)
        # '$' characters are rejected; upper-case letters, ':' and digits
        # are accepted.
        with self.assertRaises(ValueError):
            probe.ti = '$$$$'
        probe.ti = accepted
        self.assertEqual(probe.ti, accepted)

        self.assertTrue(probe.fullfills(MessageTIMixin))

    def test_message_action_report_mixing(self):
        class _Probe(MessageActionReportMixin):
            pass

        probe = _Probe()
        # (field name, sample value) pairs, processed in this order.
        cases = (
            ('filled_size', 12),
            ('filled_price', 1500),
            ('status_type', 34),
            ('status_reason', 45),
        )

        # Every field starts at 0.
        for field, _ in cases:
            self.assertEqual(getattr(probe, field), 0)
        # Every field rejects non-numeric bytes.
        for field, _ in cases:
            with self.assertRaises(ValueError):
                setattr(probe, field, b'test')
        # Assign all samples first, then confirm each one round-trips.
        for field, sample in cases:
            setattr(probe, field, sample)
        for field, sample in cases:
            self.assertEqual(getattr(probe, field), sample)

        self.assertTrue(probe.fullfills(MessageActionReportMixin))
        self.assertTrue(probe.fullfills(MessageContractIDMixin))

class TestClientMessages(unittest.TestCase):
    """Tests for the concrete v0.0.1 client message classes.

    Each test checks the reported protocol version, any message-specific
    fields (defaults, rejected values, accepted values), and the mixins or
    base classes the message should report via ``fullfills``.
    """

    def test_message_limit_order(self):
        msg = MessageLimitOrder()
        self.assertEqual(msg.mversion, '0.0.1')
        # Checking fullfills(MessageLimitOrder) on a MessageLimitOrder would
        # be vacuous; assert the submission mixin like the other order tests.
        self.assertTrue(msg.fullfills(MessageSubmissionMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_scp_order(self):
        msg = MessageSCPOrder()

        self.assertEqual(msg.mversion, '0.0.1')
        # All SCP-specific fields are unset (None) by default.
        self.assertEqual(msg.scp_sweeper_mpid, None)
        self.assertEqual(msg.scp_sweeper_cid, None)
        self.assertEqual(msg.scp_conversation_id, None)
        with self.assertRaises(ValueError):
            msg.scp_sweeper_mpid = b'test'
        with self.assertRaises(ValueError):
            msg.scp_sweeper_cid = b'test'
        with self.assertRaises(ValueError):
            msg.scp_conversation_id = b'test'
        msg.scp_sweeper_mpid = 12
        msg.scp_sweeper_cid = 34
        msg.scp_conversation_id = 79
        self.assertEqual(msg.scp_sweeper_mpid, 12)
        self.assertEqual(msg.scp_sweeper_cid, 34)
        self.assertEqual(msg.scp_conversation_id, 79)

        self.assertTrue(msg.fullfills(MessageSubmissionMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_market_order(self):
        msg = MessageMarketOrder()
        self.assertEqual(msg.mversion, '0.0.1')
        self.assertTrue(msg.fullfills(MessageSubmissionMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_quote(self):
        msg = MessageQuote()

        self.assertEqual(msg.mversion, '0.0.1')
        # Unlike the bare submission mixin, a quote allows setting mox2,
        # which reads back with the truthiness of the assigned value.
        self.assertFalse(msg.mox2)
        msg.mox2 = b''
        self.assertFalse(msg.mox2)
        msg.mox2 = 1
        self.assertTrue(msg.mox2)

        self.assertTrue(msg.fullfills(MessageSubmissionMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_cancel_replace(self):
        msg = MessageCancelReplace()
        self.assertEqual(msg.mversion, '0.0.1')
        self.assertTrue(msg.fullfills(MessageSubmissionMixin))
        self.assertTrue(msg.fullfills(MessageEntryIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_cancel(self):
        msg = MessageCancel()
        self.assertEqual(msg.mversion, '0.0.1')
        self.assertTrue(msg.fullfills(MessageEntryIDMixin))
        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_reduce_size(self):
        msg = MessageReduceSize()

        self.assertEqual(msg.mversion, '0.0.1')
        self.assertEqual(msg.size_decrement, 0)
        with self.assertRaises(ValueError):
            msg.size_decrement = b'test'
        # Although 0 is the default, explicitly assigning 0 is rejected.
        with self.assertRaises(ValueError):
            msg.size_decrement = 0
        msg.size_decrement = 33
        self.assertEqual(msg.size_decrement, 33)

        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(MessageEntryIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_action_report(self):
        msg = MessageActionReport()
        self.assertEqual(msg.mversion, '0.0.1')
        self.assertTrue(msg.fullfills(MessageActionReportMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_get_book_state(self):
        msg = MessageGetBookState()
        self.assertEqual(msg.mversion, '0.0.1')
        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_book_state(self):
        msg = MessageBookState()

        self.assertEqual(msg.mversion, '0.0.1')
        self.assertEqual(msg.price, 0)
        self.assertEqual(msg.size, 0)
        self.assertEqual(msg.is_ask, False)
        self.assertEqual(msg.is_end, False)
        with self.assertRaises(ValueError):
            msg.price = b'test'
        with self.assertRaises(ValueError):
            msg.size = b'test'
        # The boolean flags read back with the truthiness of the assigned
        # value: b'' -> False, 1 -> True.
        msg.is_end = msg.is_ask = b''
        self.assertEqual(msg.is_ask, False)
        self.assertEqual(msg.is_end, False)
        msg.is_end = msg.is_ask = 1
        self.assertEqual(msg.is_ask, True)
        self.assertEqual(msg.is_end, True)
        msg.price = 12
        msg.size = 55
        self.assertEqual(msg.price, 12)
        self.assertEqual(msg.size, 55)

        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_book_top(self):
        msg = MessageBookTop()

        self.assertEqual(msg.mversion, '0.0.1')
        self.assertEqual(msg.ask, 0)
        self.assertEqual(msg.bid, 0)
        with self.assertRaises(ValueError):
            msg.ask = b'test'
        with self.assertRaises(ValueError):
            msg.bid = b'test'
        msg.ask = 12
        msg.bid = 23
        self.assertEqual(msg.ask, 12)
        self.assertEqual(msg.bid, 23)

        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_get_contract_detail(self):
        msg = MessageGetContractDetail()

        self.assertEqual(msg.mversion, '0.0.1')
        self.assertEqual(msg.all_contracts, False)
        # Any truthy value (here non-empty bytes) reads back as True.
        msg.all_contracts = b'test'
        self.assertEqual(msg.all_contracts, True)

        # The derived detail message carries the request's message id.
        cd = msg.contract_detail()
        self.assertEqual(cd.mid, msg.mid)

        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_contract_detail(self):
        msg = MessageContractDetail()

        self.assertEqual(msg.mversion, '0.0.1')
        self.assertEqual(msg.expiration, 0)
        self.assertEqual(msg.strike_price, 0)
        self.assertIsNone(msg.contract_type)
        # An arbitrary integer that is not a known contract type is rejected.
        with self.assertRaises(ValueError):
            msg.contract_type = 0x31337
        with self.assertRaises(ValueError):
            msg.expiration = b'test'
        with self.assertRaises(ValueError):
            msg.strike_price = b'test'
        msg.expiration = 12
        msg.strike_price = 23
        msg.contract_type = msg.CONTRACT_TYPE_CALL
        self.assertEqual(msg.expiration, 12)
        self.assertEqual(msg.strike_price, 23)
        self.assertEqual(msg.contract_type, msg.CONTRACT_TYPE_CALL)

        self.assertTrue(msg.fullfills(MessageContractIDMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_exercise_request(self):
        msg = MessageExerciseRequest()
        self.assertEqual(msg.mversion, '0.0.1')
        self.assertTrue(msg.fullfills(MessageTIMixin))
        self.assertTrue(msg.fullfills(ClientMessage))

    def test_message_replay(self):
        msg = MessageReplay()

        self.assertEqual(msg.mversion, '0.0.1')
        # Both replay bounds are unset (None) until assigned.
        self.assertIsNone(msg.begin_timestamp)
        self.assertIsNone(msg.end_timestamp)
        with self.assertRaises(ValueError):
            msg.begin_timestamp = 'test'
        with self.assertRaises(ValueError):
            msg.end_timestamp = 'test'
        msg.begin_timestamp = 1
        msg.end_timestamp = 2
        self.assertEqual(msg.begin_timestamp, 1)
        self.assertEqual(msg.end_timestamp, 2)

        self.assertTrue(msg.fullfills(ClientMessage))

