"""Unit tests for code.justin.tv.dta.rockpaperscissors.ingest_queue module."""

from __future__ import absolute_import

import base64
import logging
import unittest

import mock

from . import errors
from . import ingest_queue


# Configure the root logger at INFO so log records emitted while the tests run
# are printed (to stderr by default). This is a no-op if the root logger
# already has handlers, e.g. when a test runner configured logging first.
logging.basicConfig(level=logging.INFO)


class IngestQueueTestCase(unittest.TestCase):
    """Tests for ingest_queue.IngestQueue.AddRequests.

    Each test replaces the queue's ``sqs`` attribute with an autospecced mock,
    so no real SQS calls are made.
    """

    @staticmethod
    def _ExpectedTwoBatchCalls(encoded_request):
        """Returns the send_messages calls expected for three equal requests.

        The batching tests set a limit that lets two requests share a batch,
        so the third request spills into a second batch. Entry Ids keep
        counting across batches ('0', '1', then '2').

        Args:
          encoded_request: The base64-encoded serialized request body.

        Returns:
          A list of two mock.call objects, in send order.
        """
        def _Entry(index):
            return {'Id': str(index), 'MessageBody': encoded_request}

        return [
          mock.call(Entries=[_Entry(0), _Entry(1)]),
          mock.call(Entries=[_Entry(2)]),
        ]

    def testAddRequests(self):
        """A single request is sent as one base64-encoded entry with Id '0'."""
        queue = ingest_queue.IngestQueue('queue_name')
        request = ingest_queue.IngestBlueprintRequest('repo', 'path')

        with mock.patch.object(queue, 'sqs', autospec=True) as mock_sqs:
            mock_queue = mock_sqs.get_queue_by_name.return_value
            mock_queue.send_messages.return_value = {}

            queue.AddRequests([request])

            mock_sqs.get_queue_by_name.assert_called_once_with(
              QueueName='queue_name')
            mock_queue.send_messages.assert_called_once_with(
              Entries=[
                {
                  'Id': '0',
                  'MessageBody': base64.b64encode(request.SerializeToString()),
                },
              ])

    def testAddRequestsFailure(self):
        """A response that reports failed entries raises an error."""
        queue = ingest_queue.IngestQueue('queue_name')
        request = ingest_queue.IngestBlueprintRequest('repo', 'path')

        with mock.patch.object(queue, 'sqs', autospec=True) as mock_sqs:
            mock_queue = mock_sqs.get_queue_by_name.return_value
            # SQS reports batch failures as a *list* of per-entry dicts, so
            # the fixture mirrors the real response shape.
            mock_queue.send_messages.return_value = {
              'Failed': [
                {
                  'Id': '0',
                  'SenderFault': True,
                  'Code': '404',
                  'Message': 'An error occurred',
                },
              ],
            }

            self.assertRaises(errors.IngestQueueAddRequestsFailures,
                              queue.AddRequests, [request])

    def testIngestQueueRequestTooLarge(self):
        """An oversized request raises before anything is sent."""
        queue = ingest_queue.IngestQueue('queue_name')
        request = ingest_queue.IngestGitHubStatsRequest(
          'dta/rockpaperscissors', 'a' * ingest_queue.MAX_MESSAGE_BYTES)

        with mock.patch.object(queue, 'sqs', autospec=True) as mock_sqs:
            mock_queue = mock_sqs.get_queue_by_name.return_value
            mock_queue.send_messages.return_value = {}

            self.assertRaises(errors.IngestQueueRequestTooLarge,
                              queue.AddRequests, [request])
            # Check the attribute directly rather than calling
            # assert_not_called(). That method is missing from older `mock`
            # backports, where calling it silently passes.
            self.assertFalse(mock_queue.send_messages.called)

    def testBatchByBytes(self):
        """Requests are split into batches by MAX_BATCH_BYTES."""
        queue = ingest_queue.IngestQueue('queue_name')
        request = ingest_queue.IngestBlueprintRequest('repo', 'path')
        encoded_request = base64.b64encode(request.SerializeToString())

        # Room for exactly two encoded requests per batch.
        with mock.patch.object(ingest_queue, 'MAX_BATCH_BYTES',
                               2*len(encoded_request)):
            with mock.patch.object(queue, 'sqs', autospec=True) as mock_sqs:
                mock_queue = mock_sqs.get_queue_by_name.return_value
                mock_queue.send_messages.return_value = {}

                queue.AddRequests([request, request, request])

                # Compare the full call list so that any extra send (e.g. an
                # empty trailing batch) fails the test; assert_has_calls
                # would accept it.
                self.assertEqual(mock_queue.send_messages.call_args_list,
                                 self._ExpectedTwoBatchCalls(encoded_request))

    def testBatchByMessages(self):
        """Requests are split into batches by MAX_BATCH_MESSAGES."""
        queue = ingest_queue.IngestQueue('queue_name')
        request = ingest_queue.IngestBlueprintRequest('repo', 'path')
        encoded_request = base64.b64encode(request.SerializeToString())

        with mock.patch.object(ingest_queue, 'MAX_BATCH_MESSAGES', 2):
            with mock.patch.object(queue, 'sqs', autospec=True) as mock_sqs:
                mock_queue = mock_sqs.get_queue_by_name.return_value
                mock_queue.send_messages.return_value = {}

                queue.AddRequests([request, request, request])

                # Exact comparison: extra or missing batches must fail.
                self.assertEqual(mock_queue.send_messages.call_args_list,
                                 self._ExpectedTwoBatchCalls(encoded_request))


if __name__ == '__main__':
    # When run as a script, load and run every TestCase in this module.
    # module='__main__' is unittest.main's default; it is spelled out here
    # for clarity.
    unittest.main(module='__main__')
