import copy
from unittest.mock import patch

import pytest

from zenyatta.aws.rds import RDSResource, SingleInstanceDriver
from zenyatta.aws.sns import send_sns_message_wrapper
from zenyatta.aws.s3 import get_base_s3_path
from zenyatta.db.sql import PostgresSQL
from tests.static.constants import TS_NODASH
from tests.util import AWS_CREDS, get_connection
from tests.static.responses import DESCRIBE_DB_INSTANCES


@patch('zenyatta.aws.sns.send_sns_message', return_value=True)
@patch('zenyatta.aws.sns.get_airflow_connection', return_value=get_connection())
def test_send_sns_message_does_not_raise_type_error(conn_mock, msg_mock):
    """Regression test: send_sns_message_wrapper must not raise TypeError.

    The wrapper used to call generate_s3_path_for_finished_message() without its
    ``ts_nodash`` and ``conn_id`` arguments, which failed in Airflow with:

        [2017-01-05 20:30:42,788] {models.py:1286} ERROR - generate_s3_path_for_finished_message()
            missing 2 required positional arguments: 'ts_nodash' and 'conn_id'
        Traceback (most recent call last):
          File "/opt/virtualenvs/airflow/lib/python3.5/site-packages/airflow/models.py",
            line 1245, in run
            result = task_copy.execute(context=context)
          File "/opt/virtualenvs/airflow/lib/python3.5/site-packages/airflow/operators/python_operator.py",
            line 66, in execute
            return_value = self.python_callable(*self.op_args, **self.op_kwargs)
          File "/opt/twitch/zenyatta/releases/1483644932-3b41ae46dfd2c255f9cfe5dfc3f876a006e6bcf4/zenyatta/aws/sns.py",
            line 30, in send_sns_message_wrapper
            message = generate_etl_finished_message(generate_s3_path_for_finished_message(s3_creds),
        TypeError: generate_s3_path_for_finished_message() missing 2 required positional arguments:
            'ts_nodash' and 'conn_id'
    """
    # Stacked @patch decorators are applied bottom-up, so the bottom patch
    # (get_airflow_connection) arrives as the first mock argument.
    try:
        send_sns_message_wrapper('test-conn',
                                 {'s3_key': 'key', 'bucket': 'bucket'},
                                 'arn:1234', ts_nodash=TS_NODASH)
    except TypeError as error:
        pytest.fail("we've regressed to {error}".format(error=error))


# Host placed on the patched Airflow (parent) connection; the SQL engine built
# from an RDSResource must never end up pointing at it.
TEST_SQL_ENGINE_HOST = "better.not.be.this"


@patch('zenyatta.aws.rds.RDSDriver.get_host_and_port', return_value=('lawl.test.host', 5432))
@patch('zenyatta.aws.rds.SingleInstanceDriver.get_rds_metadata',
       return_value=copy.deepcopy(DESCRIBE_DB_INSTANCES['DBInstances'][0]))
@patch('zenyatta.db.sql.get_airflow_connection', return_value=get_connection(host=TEST_SQL_ENGINE_HOST))
def test_sql_engine_host_and_port(conn_mock, rds_metadata_mock, host_port_mock):
    """Regression test: the engine's host must not come from the parent connection.

    It was discovered that the host and identifier used by the SQL engine could be
    taken from the parent Airflow connection (in production that is the production
    databases, which must not be targeted). The parent connection is patched to use
    TEST_SQL_ENGINE_HOST, so that host must not appear in the engine's URL.
    """
    # Stacked @patch decorators are applied bottom-up, so mocks arrive in reverse
    # decorator order. The RDS metadata is deep-copied: a shallow copy of the outer
    # response would still hand out the shared nested instance dict.
    resource = RDSResource(identifier='test',
                           driver=SingleInstanceDriver(sql_driver=PostgresSQL('test')))

    engine = resource.create_sql_engine()
    assert TEST_SQL_ENGINE_HOST not in str(engine)


@patch('zenyatta.aws.rds.RDSDriver.get_host_and_port', return_value=('lawl.test.host', 5432))
@patch('zenyatta.aws.rds.SingleInstanceDriver.get_rds_metadata',
       return_value=copy.deepcopy(DESCRIBE_DB_INSTANCES['DBInstances'][0]))
@patch('zenyatta.db.sql.get_airflow_connection', return_value=get_connection(host=TEST_SQL_ENGINE_HOST))
def test_sql_stream_results(conn_mock, rds_metadata_mock, host_port_mock):
    """create_sql_engine must default to stream_results=True and allow turning it off.

    With stream_results enabled, statements are wrapped in a server-side cursor,
    which breaks non-SELECT statements such as GRANT. The bug manifested as:

        [2017-01-06 01:49:42,389] {models.py:1327} ERROR - (psycopg2.ProgrammingError)
        syntax error at or near "GRANT"
        LINE 1: ...ECLARE "c_7fd56aa02390_5" CURSOR WITHOUT HOLD FOR GRANT 'cha...
                                                                 ^
         [SQL: "GRANT 'chat' TO 'cohesion'"]
    """
    # Stacked @patch decorators are applied bottom-up, so mocks arrive in reverse
    # decorator order. The RDS metadata is deep-copied: a shallow copy of the outer
    # response would still hand out the shared nested instance dict.
    resource = RDSResource(identifier='test',
                           driver=SingleInstanceDriver(sql_driver=PostgresSQL('test')))
    # test default is true
    engine = resource.create_sql_engine()
    assert 'stream_results' in engine._execution_options
    assert engine._execution_options.get('stream_results') is True
    # test false
    engine = resource.create_sql_engine(stream_results=False)
    assert 'stream_results' in engine._execution_options
    assert engine._execution_options.get('stream_results') is False


def test_base_s3_path_replaces_timestamp():
    """get_base_s3_path drops the ``-<ts_nodash>`` suffix from the connection id.

    The expected path is ``<key>/<ts_nodash>/<conn_id without the timestamp>``.
    """
    snapshot_key = "dbsnapshots"
    timestamp = "20170531T080000"
    connection_id = "justintv_prod-" + timestamp

    expected = "/".join([snapshot_key, timestamp, "justintv_prod"])
    assert get_base_s3_path(snapshot_key, timestamp, connection_id) == expected
