"""Unit tests for code.justin.tv.dta.rockpaperscissors.event_bus module."""

from __future__ import absolute_import

import base64
import logging
import unittest

import mock

from . import errors
from . import event_bus
from . import large_event_store


# Configure the root logger at INFO level for the test run.
logging.basicConfig(level=logging.INFO)

# Override the event_bus record size limit with a small value; the oversized
# event test in this file sizes its body from this constant.
# NOTE(review): this assignment runs at import time and is never restored in
# this file, so it stays in effect for anything else sharing the process.
event_bus.MAX_RECORD_BYTES = 50


class EventBusTestCase(unittest.TestCase):
    """Tests for EventBus.PutEvent.

    Each test gets a fresh EventBus for stream 'stream_name' whose large
    event store is an autospecced LargeEventStore mock; the bus's `kinesis`
    attribute is patched per test so no real client calls are made.
    """

    def setUp(self):
        """Create the mocked large event store and the bus under test."""
        self.mock_large_event_store = mock.create_autospec(
            large_event_store.LargeEventStore, instance=True)
        self.bus = event_bus.EventBus(
            'stream_name', self.mock_large_event_store)

    def testBasic(self):
        # A small event must be sent with exactly one put_record call that
        # carries the serialized event and the base64-encoded uuid as the
        # partition key.
        event = event_bus.Event(
            uuid='abcdef',
            timestamp=0.0,
            event_type='type',
            body='body',
            attributes={'key': 'value'})

        with mock.patch.object(
                self.bus, 'kinesis', autospec=True) as mock_kinesis:
            self.bus.PutEvent(event)

            mock_kinesis.put_record.assert_called_once_with(
                StreamName='stream_name',
                Data=event.SerializeToString(),
                PartitionKey=base64.b64encode('abcdef'))

    def testEventBusRecordTooLarge(self):
        # A body of MAX_RECORD_BYTES characters makes the event exceed the
        # limit, so it must be handed to the large event store and never
        # passed to put_record.
        event = event_bus.Event(
            event_type='type',
            body='a' * event_bus.MAX_RECORD_BYTES)

        with mock.patch.object(
                self.bus, 'kinesis', autospec=True) as mock_kinesis:
            self.bus.PutEvent(event)

            self.mock_large_event_store.PutEvent.assert_called_once_with(
                event)
            # Inspect `.called` directly instead of assert_not_called():
            # on mock releases that predate that helper, a plain MagicMock
            # auto-creates the attribute and the "assertion" silently
            # passes without checking anything.
            self.assertFalse(mock_kinesis.put_record.called)


class DecodeKinesisRecordDataTestCase(unittest.TestCase):
    """Tests for event_bus.DecodeKinesisRecordData."""

    def testBasic(self):
        # Round trip: serialize an event, base64-encode it, and expect the
        # decoder to hand back an equal event.
        original = event_bus.Event(event_type='type', body='body')
        encoded = base64.b64encode(original.SerializeToString())

        result = event_bus.DecodeKinesisRecordData(encoded)

        self.assertEqual(original, result)

    def testBadBase64(self):
        # Input that is not valid base64 must raise EventDecodeError.
        with self.assertRaises(errors.EventDecodeError):
            event_bus.DecodeKinesisRecordData('badbase64')

    def testBadProtobuf(self):
        # Well-formed base64 whose payload is not a serialized event must
        # also raise EventDecodeError.
        encoded = base64.b64encode('badprotobuf')
        with self.assertRaises(errors.EventDecodeError):
            event_bus.DecodeKinesisRecordData(encoded)


class ValidateEventTestCase(unittest.TestCase):
    """Tests for event_bus.ValidateEvent."""

    def _AssertInvalidWithout(self, field_name):
        """Clear `field_name` on a fresh event and expect EventInvalid.

        Args:
            field_name: name of the Event field to clear before validating.
        """
        candidate = event_bus.Event('type', 'body')
        candidate.ClearField(field_name)
        # Only the validation call sits inside the assertRaises scope, so an
        # error from building or mutating the event is not mistaken for the
        # expected failure.
        with self.assertRaises(errors.EventInvalid):
            event_bus.ValidateEvent(candidate)

    def testPassed(self):
        # Succeeds as long as validating an untouched event does not raise.
        candidate = event_bus.Event('type', 'body')
        event_bus.ValidateEvent(candidate)

    def testBadUuid(self):
        self._AssertInvalidWithout('uuid')

    def testBadTimestamp(self):
        self._AssertInvalidWithout('timestamp')

    def testBadType(self):
        self._AssertInvalidWithout('type')

    def testBadBody(self):
        self._AssertInvalidWithout('body')


# Run this module's tests when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
