import unittest
from unittest.mock import patch
from zenyatta.common.spark import make_spark_table_config
from moto import mock_emr
import boto3
import os
import pytest

# Shared fake configuration values for the tests in this module.
# NOTE(review): in this chunk only 'aws', 'conf', 'partition' and 'ts_nodash'
# are referenced; the other sections are presumably used by tests elsewhere
# (or are leftovers) — verify before relying on them.
dummys = {
    # Passed as the last argument of make_spark_table_config and stored on
    # each test as self.meta in setUp.
    'aws': {
        'role_arn': 'arn:aws:iam::465369119046:role/airflow',
        'bucket': 'twitch-d8a-test',
        's3_key': 'dbsnapshot',
    },
    'emr': {
        'prefix': 'zenyatta',
        'id': 'j-2GEGTTROZ7QGT',
    },
    # Connection id / table pairs; judging by the key names, presumably a
    # partitioned ("pt") and a non-partitioned table — TODO confirm.
    'conn_pt': {
        'conn_id': 'vinyl',
        'table': 'audible_magic_responses',
    },
    'conn_nonpt': {
        'conn_id': 'vinyl',
        'table': 'vod_appeals',
    },
    # Timestamp passed to make_spark_table_config.
    # NOTE(review): Airflow's ts_nodash macro renders an upper-case 'T'
    # (e.g. 20170715T000000); confirm the lower-case form here is intended.
    'ts_nodash': '20170715t000000',
    's3_object': {
        'prefix': 'etl/output/cohesion-friends-associations-pq/_SUCCESS',
        'fake_prefix': 'etl/output/fake',
    },
    # Return value of the mocked `conf.get` (presumably a Spark conf path).
    'conf': {
        'spark': 'all/spark/conf',
    },
    # Return value of the mocked `check_table_partition`.
    'partition': {
        'col': 'col_one',
        'min': 1,
        'max': 100,
    },
}


class Conn:
    """Minimal stand-in for an Airflow connection object used by these tests.

    It provides only the attributes assigned in ``__init__`` and a
    ``get_password()`` accessor. The password is stored privately as
    ``_password``.
    """

    def __init__(self, login, password, host, port, schema, conn_type):
        # Credentials first, then network location/database, then the type tag.
        self.login, self._password = login, password
        self.host, self.port, self.schema = host, port, schema
        self.conn_type = conn_type

    def get_password(self):
        """Return the password supplied at construction time."""
        return self._password


def airflow_connection_postgres_mock():
    """Build a fake Postgres connection: psql/password @ localhost:5432, schema psql."""
    settings = {
        'login': 'psql',
        'password': 'password',
        'host': 'localhost',
        'port': '5432',
        'schema': 'psql',
        'conn_type': 'Postgres',
    }
    return Conn(**settings)


def airflow_connection_mysql_mock():
    """Build a fake MySQL connection: mysql/password @ localhost:3306, schema mysql."""
    settings = {
        'login': 'mysql',
        'password': 'password',
        'host': 'localhost',
        'port': '3306',
        'schema': 'mysql',
        'conn_type': 'MySQL',
    }
    return Conn(**settings)


class ParquetTestCase(unittest.TestCase):
    """Tests for ``make_spark_table_config`` plus a moto-backed EMR smoke test."""

    def setUp(self):
        """Expose the dummy AWS settings to each test as ``self.meta``."""
        self.meta = dummys['aws']

    # `patch` decorators apply bottom-up, so the mocks are passed in the order
    # (get_airflow_connection, conf.get, check_table_partition).
    @patch('zenyatta.common.spark.check_table_partition', return_value=dummys['partition'])
    @patch('zenyatta.common.spark.conf.get', return_value=dummys['conf']['spark'])
    @patch('zenyatta.common.spark.get_airflow_connection', return_value=airflow_connection_postgres_mock())
    def test_make_spark_table_config(self, conn_mock, conf_mock, check_mock):
        """A Postgres connection with partition info gives a postgresql JDBC URL."""
        config = make_spark_table_config("dummy_table", "psql", "pitr_host",
                                         dummys['ts_nodash'], dummys['aws'])
        # Use TestCase assertions: bare `assert` is stripped under `python -O`.
        self.assertIs(config['partitionFlag'], True)
        self.assertEqual(config['jdbc_url'], 'jdbc:postgresql://pitr_host:5432/psql')

    @patch('zenyatta.common.spark.check_table_partition', return_value=dummys['partition'])
    @patch('zenyatta.common.spark.conf.get', return_value=dummys['conf']['spark'])
    @patch('zenyatta.common.spark.get_airflow_connection', return_value=airflow_connection_mysql_mock())
    def test_make_spark_table_mysql(self, conn_mock, conf_mock, check_mock):
        """A MySQL connection gives a mysql JDBC URL built from the pitr host."""
        config = make_spark_table_config("dummy_table", "mysql", "pitr_host",
                                         dummys['ts_nodash'], dummys['aws'])
        self.assertEqual(config['jdbc_url'], 'jdbc:mysql://pitr_host:3306/mysql')

    def test_emr_config(self):
        """Create two clusters against moto's mocked EMR and check both exist.

        This is a plain TestCase method rather than a pytest fixture. Fixtures
        cannot be injected into unittest.TestCase methods, and a fixture named
        ``test_*`` was collected as a test that never actually ran its body.
        """
        mock = mock_emr()
        mock.start()
        try:
            # patch.dict restores the original AWS_DEFAULT_REGION (or its
            # absence) on exit, instead of overwriting it with a fixed value.
            with patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'}):
                client = boto3.client('emr')
                job_flow_ids = []
                for i in range(2):
                    response = client.run_job_flow(
                        Name='cluster-{:02d}'.format(i),
                        Instances={
                            'MasterInstanceType': 'c3.xlarge',
                            'SlaveInstanceType': 'c3.xlarge',
                            'InstanceCount': 3,
                            # Must be an availability zone, not a region name.
                            'Placement': {'AvailabilityZone': 'us-east-1a'},
                            'KeepJobFlowAliveWhenNoSteps': True,
                        },
                        VisibleToAllUsers=True,
                    )
                    job_flow_ids.append(response['JobFlowId'])

                self.assertEqual(len(set(job_flow_ids)), 2)
                listed = client.list_clusters()['Clusters']
                self.assertEqual(sorted(c['Name'] for c in listed),
                                 ['cluster-00', 'cluster-01'])
        finally:
            # Always stop the mock so a failure doesn't leak it into other tests.
            mock.stop()

# When executed as a script, run the TestCase classes with unittest's runner.
if __name__ == '__main__':
    unittest.main()
