from unittest.mock import patch
import unittest
from airflow.operators.python_operator import PythonOperator

from zenyatta.aws.ec2 import EC2Instance
from zenyatta.aws.rds import AuroraDriver, SingleInstanceDriver, RDSResource
from zenyatta.db.sql import PostgresSQL, MySQL
from zenyatta.tasks import (get_driver, create_rds_resource, create_rds_point_in_time_recover_task,
                            delete_rds_instance_task, get_s3_bucket_check_task,
                            open_rds_instance_permissions_task, get_table_dump_task,
                            get_make_ebs_volume_ec2_instance_and_run_wal_backup_task,
                            get_create_container_task, remove_create_container_task,
                            send_done_message_task)

from tests.util import get_connection, get_dag

ROLE_ARN = "arn:aws:iam::465369119046:role/airflow"


def test_get_driver():
    conn = get_connection()
    driver = get_driver(conn)
    assert type(driver) is SingleInstanceDriver
    assert type(driver.sql_driver) is PostgresSQL
    conn = get_connection(conn_type='Aurora')
    driver = get_driver(conn)
    assert type(driver) is AuroraDriver
    assert type(driver.sql_driver) is MySQL
    conn = get_connection(conn_type='MySQL')
    driver = get_driver(conn)
    assert type(driver) is SingleInstanceDriver
    assert type(driver.sql_driver) is MySQL


@patch('zenyatta.tasks.get_role_from_conn_id', return_value=ROLE_ARN)
def test_create_rds_resource(mock_yea):
    conn = get_connection()
    rds = create_rds_resource(conn)
    assert rds.identifier == conn.host.split('.').pop(0)
    assert type(rds.driver.sql_driver) is PostgresSQL
    assert type(rds.driver) is SingleInstanceDriver


@patch('zenyatta.tasks.get_role_from_conn_id', return_value=ROLE_ARN)
def test_create_rds_point_in_time_recover_task(mock_yea):
    conn = get_connection()
    task = create_rds_point_in_time_recover_task(conn, get_dag())
    assert type(task) is PythonOperator
    assert task.task_id == "rds-point-in-time-recovery-for-{conn.conn_id}".format(**locals())
    assert 'source' in task.op_kwargs
    assert type(task.op_kwargs.get('source')) is RDSResource
    assert task.pool == conn.conn_id


@patch('zenyatta.tasks.get_role_from_conn_id', return_value=ROLE_ARN)
def test_delete_rds_instance_task(mock_yea):
    conn = get_connection()
    task = delete_rds_instance_task(conn)
    assert type(task) is PythonOperator
    assert task.task_id == "delete-rds-instance-{conn.conn_id}-etl".format(**locals())
    assert 'source' in task.op_kwargs
    assert type(task.op_kwargs.get('source')) is RDSResource
    assert task.pool is None


@unittest.skip("pass and skip")
def test_generate_task_for_sql_table_type_metadata():
    conn = get_connection()
    dag = get_dag()
    task = generate_task_for_sql_table_type_metadata(conn, dag)
    assert type(task) is PythonOperator
    assert task.task_id == "get-table-types-{conn.conn_id}".format(**locals())
    assert 'conn_id' in task.op_kwargs
    assert 's3_meta' in task.op_kwargs
    assert task.pool is None


def test_get_s3_bucket_check_task():
    conn = get_connection()
    dag = get_dag()
    task = get_s3_bucket_check_task(conn, dag)
    assert task.task_id == 's3-bucket-check'
    assert 's3_meta' in task.op_kwargs
    assert 'conn_id' in task.op_kwargs
    assert task.op_kwargs['s3_meta']['bucket'] == dag.default_args['s3_output']['bucket']


@patch('zenyatta.tasks.get_role_from_conn_id', return_value=ROLE_ARN)
def test_open_rds_instance_permissions_task(mock_yea):
    conn = get_connection()
    dag = get_dag()
    role = 'test-role'
    task = open_rds_instance_permissions_task(conn, role, dag)
    assert task.task_id == "open-rds-instance-permissions-{conn.conn_id}".format(**locals())
    assert 'source' in task.op_kwargs
    assert 'conn_id' in task.op_kwargs
    assert 'role' in task.op_kwargs
    assert task.op_kwargs['role'] == role
    assert task.op_kwargs['conn_id'] == conn.conn_id


def test_get_table_dump_task():
    table = 'test'
    conn = get_connection()
    dag = get_dag()
    task = get_table_dump_task(table, conn, True, dag)
    assert task.task_id == "{conn.conn_id}-{table}-etl".format(**locals())
    assert 's3_meta' in task.op_kwargs
    assert 'table' in task.op_kwargs
    assert 'conn_id' in task.op_kwargs
    assert task.pool == conn.conn_id


@patch('zenyatta.tasks.get_role_from_conn_id', return_value='arn:aws:iam::465361111111:role/test')
def test_get_wal_backup_task(conn_mock):
    conn = get_connection()
    dag = get_dag()
    dag.default_args['wal_aws_creds'] = dag.default_args['sns_creds'].copy()
    task = get_make_ebs_volume_ec2_instance_and_run_wal_backup_task(conn, dag)
    assert task.task_id == "wal-backup-fetch-to-new-ebs-{conn.conn_id}".format(**locals())
    assert 'wal_aws_creds' in task.op_kwargs
    assert task.pool == "{conn.conn_id}-wal-fetch".format(**locals())


@patch('zenyatta.tasks.get_role_from_conn_id', return_value='arn:aws:iam::465361111111:role/test')
def test_get_create_container_task(conn_mock):
    conn = get_connection()
    task = get_create_container_task(conn)
    assert task.task_id == "creating-postgres-container-{conn.conn_id}".format(**locals())


def test_remove_create_container_task():
    conn = get_connection()
    task = remove_create_container_task(conn)
    assert task.task_id == "remove-postgres-container-{conn.conn_id}".format(**locals())
    assert 'master_conn_id' in task.op_kwargs
    assert task.pool == conn.conn_id


@patch('zenyatta.tasks.get_role_from_conn_id', return_value=ROLE_ARN)
def test_send_done_message_task(mock_yea):
    conn = get_connection()
    dag = get_dag()
    task = send_done_message_task(conn, dag)
    assert 'conn_id' in task.op_kwargs
    assert task.op_kwargs['conn_id'] == conn.conn_id
    assert 'sns_arn' in task.op_kwargs


@patch('zenyatta.tasks.get_airflow_connection', return_value=get_connection())
def test_get_table_type_metadata(conn_mock):
    # TODO this requires some mocks of the sqlalchemy inspect function
    # mock output files
    # mock s3 files
    pass


@patch('zenyatta.tasks.get_airflow_connection', return_value=get_connection())
def test_dump_table_upload_it_delete_it(conn_mock):
    # dump_table_upload_it_delete_it()
    # TODO this requires some mocks of sqlalchemy db_engine.begin(), db_connection.execute(),
    # opening output path
    # mock output files
    # mock s3 files

    pass
