import os
import csv
import json
import logging
import shutil
from datetime import timedelta
from typing import Any, Dict, Callable, List
from time import sleep
import random
from copy import deepcopy
from random import randrange
from pathlib import Path
from botocore.exceptions import ClientError
from sqlalchemy import inspect
from sqlalchemy.orm.exc import NoResultFound
from airflow import DAG
from airflow.models import Connection, TaskInstance
from airflow.settings import conf, Session
from airflow.operators.python_operator import PythonOperator, ShortCircuitOperator
from airflow.operators.dummy_operator import DummyOperator

from zenyatta.aws.ec2 import (EBSVolume, EC2Instance, get_free_device_from_self, create_ec2_instance,
                              get_free_device_from_block_mapings)
from zenyatta.aws.sns import send_sns_message_wrapper
from zenyatta.aws.rds import (create_rds_instance_from_point_in_time, delete_rds_resource,
                              generate_etl_db_instance_identifier, RDSResource, AuroraDriver,
                              SingleInstanceDriver, open_rds_instance_permissions)
from zenyatta.common import (get_airflow_connection, get_work_directory, xcoms, check_table_priority, get_connections)
from zenyatta.common.util import (run_command, make_dir, update_aws_resource_variable,
                                  get_aws_resource_variable, cleanup_aws_from_xcom,
                                  cleanup_connection_from_xcom, cleanup_csv_dump_from_xcom,
                                  cleanup_parquet_dump_from_xcom)
from zenyatta.common.slack import send_slack_message
from zenyatta.common.errors import ZenyattaError
from zenyatta.aws import (boto_resource, get_creds, get_role_from_conn_id)
from zenyatta.aws.s3 import (get_base_s3_path, ensure_s3_bucket_exists, check_s3_object_exist,
                             upload_file_to_s3, download_objects_from_s3, remove_s3_object)
from zenyatta.common.spark import (make_spark_table_config, generate_spark_script,
                                   get_yarn_app_by_name, get_yarn_app_by_id)
from zenyatta.db.sql import SQL, MySQL, PostgresSQL, get_ec2_sql_conn_id, run_table_update
from zenyatta.db.wal import wal_backup_command
from zenyatta.docker import create_docker_postgres, remove_docker_postgres, PostgresContainer
from zenyatta.aws.emr import EMR


def python_task_wrapper(callme: Callable, task_id: str, op_kwargs: Dict[Any, Any], pool: str=None,
                        dag: DAG=None, priority_weight: int=1, trigger_rule: str="all_success",
                        retries: int=11) -> PythonOperator:
    """Build a PythonOperator carrying this module's standard task settings.

    Every task built here receives the Airflow context (``provide_context=True``),
    uses handle_task_failure / handle_task_succeed as its failure / success
    callbacks, and is given an execution timeout of 24 hours.

    :param callme: python callable the operator runs
    :param task_id: airflow task id
    :param op_kwargs: keyword arguments forwarded to ``callme``
    :param pool: optional airflow pool name
    :param dag: DAG the task is attached to
    :param priority_weight: airflow priority weight (1 by default)
    :param trigger_rule: airflow trigger rule ("all_success" by default)
    :param retries: number of retries (11 by default)
    :return: the configured operator
    """
    operator_args = dict(
        python_callable=callme,
        task_id=task_id,
        op_kwargs=op_kwargs,
        provide_context=True,
        pool=pool,
        dag=dag,
        priority_weight=priority_weight,
        trigger_rule=trigger_rule,
        retries=retries,
        execution_timeout=timedelta(hours=24),
        on_failure_callback=handle_task_failure,
        on_success_callback=handle_task_succeed,
    )
    return PythonOperator(**operator_args)


def python_task_wrapper_sc(callme: Callable, task_id: str, op_kwargs: Dict[Any, Any], pool: str=None,
                           dag: DAG=None) -> ShortCircuitOperator:
    """Build a ShortCircuitOperator around ``callme``.

    The operator receives the Airflow context and uses handle_task_failure as
    its failure callback. Unlike python_task_wrapper, it sets no success
    callback, timeout, retries, priority or trigger rule.

    :param callme: python callable whose result decides whether downstream tasks run
    :param task_id: airflow task id
    :param op_kwargs: keyword arguments forwarded to ``callme``
    :param pool: optional airflow pool name
    :param dag: DAG the task is attached to
    :return: the configured operator
    """
    operator_args = {
        'python_callable': callme,
        'task_id': task_id,
        'op_kwargs': op_kwargs,
        'provide_context': True,
        'on_failure_callback': handle_task_failure,
        'pool': pool,
        'dag': dag,
    }
    return ShortCircuitOperator(**operator_args)


def python_dummy_wrapper(task_id: str, dag: DAG=None) -> DummyOperator:
    """Build a no-op placeholder task (DummyOperator) with the given id.

    :param task_id: airflow task id
    :param dag: DAG the task is attached to
    :return: the DummyOperator
    """
    placeholder = DummyOperator(task_id=task_id, dag=dag)
    return placeholder


def get_driver(connection: Connection) -> Any:
    """Pick the RDS driver wrapper that matches a connection's ``conn_type``.

    The conn_type is lower-cased and tested for substrings in this order:
    ``aurora`` -> AuroraDriver over MySQL, ``postgres`` -> SingleInstanceDriver
    over PostgresSQL, ``mysql`` -> SingleInstanceDriver over MySQL.

    :param connection: airflow connection whose conn_type selects the driver
    :return: the driver, or None when no known type matches
    """
    kind = connection.conn_type.lower()
    if 'aurora' in kind:
        return AuroraDriver(sql_driver=MySQL(connection.conn_id))
    if 'postgres' in kind:
        return SingleInstanceDriver(sql_driver=PostgresSQL(connection.conn_id))
    if 'mysql' in kind:
        return SingleInstanceDriver(sql_driver=MySQL(connection.conn_id))
    # NOTE(review): any other conn_type silently yields no driver.
    return None


def create_rds_resource(connection: Connection) -> RDSResource:
    """Describe the RDS instance behind an airflow connection.

    :param connection: airflow connection; the RDS instance identifier is encoded in its host name
    :return: RDSResource carrying the identifier, the driver chosen by get_driver and
        the role ARN registered for the connection
    """
    db_identifier = RDSResource.db_instance_identifier(connection.host)
    conn_role_arn = get_role_from_conn_id(connection.conn_id)
    conn_driver = get_driver(connection)
    return RDSResource(identifier=db_identifier, driver=conn_driver, role_arn=conn_role_arn)


def attach_ec2_volume_to_ebs_volume(instance: EC2Instance, volume: EBSVolume) -> Any:
    """Attach ``volume`` to ``instance`` on the first free block device of that instance.

    :param instance: EC2 instance that receives the volume
    :param volume: EBS volume to attach; its role_arn is used for the boto lookup
    :return: whatever ``instance.attach_volume`` returns
    """
    ec2_resource, _unused = boto_resource('ec2', volume.role_arn)
    mappings = ec2_resource.Instance(instance.identifier).block_device_mappings
    return instance.attach_volume(volume, get_free_device_from_block_mapings(mappings))


def link_to_task(task_instance: TaskInstance,
                 fqdn="zenyatta.production.twitch-web-aws.us-west2.justin.tv") -> str:
    """Build a link to the Airflow web UI page of ``task_instance``.

    The query-string values are percent-encoded. When ``execution_date`` is
    timezone-aware, its isoformat() contains a ``+`` (e.g. ``...T00:00:00+00:00``).
    An unescaped ``+`` would be decoded as a space by the web server and break
    the link.

    :param task_instance: task instance to link to (execution_date, dag_id, task_id are read)
    :param fqdn: host name of the airflow webserver, served on port 8080
    :return: the URL
    """
    from urllib.parse import urlencode

    # a list of pairs keeps the parameter order stable
    query = urlencode([
        ('execution_date', task_instance.execution_date.isoformat()),
        ('dag_id', task_instance.dag_id),
        ('task_id', task_instance.task_id),
    ])
    return "http://{}:8080/admin/airflow/task?{}".format(fqdn, query)


def get_db_tables(connection: Connection) -> List[str]:
    """
    Fetch the table list recorded for a connection's point-in-time restore (pitr).

    Reads the ``db_tables`` field of the ``<conn_id>-pitr-schema`` aws resource
    variable and logs the ``ts_nodash`` it was restored from.

    :param connection: airflow connection whose pitr schema is looked up
    :return: the stored table names
    :raises ZenyattaError: when no table list is recorded
    """
    conn_id = connection.conn_id
    schema_key = conn_id + '-pitr-schema'
    tables = get_aws_resource_variable(schema_key, 'db_tables')
    if tables is not None:
        restored_at = get_aws_resource_variable(schema_key, 'ts_nodash')
        logging.info("{} restored from {} : {}".format(conn_id, restored_at, tables))
        return tables
    failure = "failed to get schema for {}".format(conn_id)
    logging.error(failure)
    raise ZenyattaError(failure)


def handle_task_failure(context: dict) -> None:
    """Airflow ``on_failure_callback``: post a Slack alert, then clean up after the task.

    Each cleanup step runs only when its flag key is present in ``context``:

    * ``cleanup_aws``        -> cleanup_aws_from_xcom
    * ``cleanup_connection`` -> cleanup_connection_from_xcom
    * ``cleanup_csv_dump``   -> cleanup_csv_dump_from_xcom
    * ``cleanup_s3_dump``    -> cleanup_parquet_dump_from_xcom

    Task factories put these flags in ``op_kwargs`` (e.g. ``'cleanup_aws': True``).
    They presumably reach the callback context through the PythonOperator, which
    should be confirmed for the airflow version in use.

    The Slack alert and each cleanup step are best-effort. An exception is
    logged and the remaining steps still run, so one failing step cannot leak
    the other resources.

    :param context: airflow task context
    """
    logging.info("in handle task failure with context: {}".format(context))
    ti = context.get('task_instance')
    try:
        message = "dag: {} task: {} failed @ {}".format(ti.dag_id, ti.task_id, link_to_task(ti))
        resp = send_slack_message(message)
        logging.info("slack resp is: {}".format(resp))
    except Exception as e:
        logging.info("unexpected error: {}".format(e))

    def run_cleanup(description, cleanup):
        # log and continue so that later cleanup steps still run
        logging.info(description)
        try:
            cleanup(context)
        except Exception:
            logging.exception("cleanup step failed: %s", description)

    if 'cleanup_aws' in context:
        run_cleanup("cleaning up aws resources", cleanup_aws_from_xcom)

    if 'cleanup_connection' in context:
        run_cleanup("cleaning up connection object in airflow database", cleanup_connection_from_xcom)

    if 'cleanup_csv_dump' in context:
        run_cleanup("cleaning up csv file on ebs", cleanup_csv_dump_from_xcom)

    if 'cleanup_s3_dump' in context:
        run_cleanup("cleaning up parqeut files on s3", cleanup_parquet_dump_from_xcom)


def handle_task_succeed(context: dict) -> None:
    """Airflow ``on_success_callback``: remove intermediate dump files of the finished task.

    Each cleanup step runs only when its flag key is present in ``context``:

    * ``succeed_cleanup_csv_dump`` -> cleanup_csv_dump_from_xcom
    * ``succeed_cleanup_s3_dump``  -> cleanup_parquet_dump_from_xcom

    Each step is best-effort. An exception is logged and the other step still runs.

    :param context: airflow task context
    """
    logging.info("in handle task succeed with context: {}".format(context))

    def run_cleanup(description, cleanup):
        # log and continue so that later cleanup steps still run
        logging.info(description)
        try:
            cleanup(context)
        except Exception:
            logging.exception("cleanup step failed: %s", description)

    if 'succeed_cleanup_csv_dump' in context:
        run_cleanup("cleaning up csv file on ebs", cleanup_csv_dump_from_xcom)

    if 'succeed_cleanup_s3_dump' in context:
        run_cleanup("cleaning up parqeut files on s3", cleanup_parquet_dump_from_xcom)


def check_table_exists(task_instance: TaskInstance, ts_nodash: str, conn_id: str, table: str) -> bool:
    """
    Report whether ``table`` is part of the table list for the current dag run.

    Tables are restored from a previous pitr, so the schema may differ from the
    one the DAG was built with. The current list is pulled from the XCom
    (XcomTablesKey) pushed by the ``update_db_tables-<conn_id>`` task.

    :param task_instance: current task instance
    :param ts_nodash: dag run timestamp
    :param conn_id: airflow connection id of the source database
    :param table: table name to look up
    :return: True when the table is in the pulled list
    """
    known_tables = task_instance.xcom_pull(
        key=xcoms.XcomTablesKey(task_instance, ts_nodash).get_key(),
        task_ids="update_db_tables-{}".format(conn_id))
    return table in known_tables


def get_pitr_resource_connection(conn_id: str, is_rds: bool,
                                 ts_nodash: str, task_instance: TaskInstance) -> tuple:
    """Resolve the point-in-time restored (pitr) database used by this dag run.

    :param conn_id: airflow connection id of the source database
    :param is_rds: True for RDS sources, False for self-managed (PostgresContainer) copies
    :param ts_nodash: dag run timestamp that names the restored copy
    :param task_instance: current task instance; its dag_id locates the work directory
    :return: ``(resource, connection)``. For RDS, ``resource`` is an RDSResource for
        the ETL copy and ``connection`` is the source connection. Otherwise
        ``resource`` is a PostgresContainer and ``connection`` is the
        ``<conn_id>-<ts_nodash>`` connection.
    """
    if not is_rds:
        connection = get_airflow_connection(conn_id + '-' + ts_nodash)
        resource = PostgresContainer(get_work_directory(task_instance.dag_id, ts_nodash),
                                     conn_id=connection.conn_id)
        return resource, connection

    connection = get_airflow_connection(conn_id)
    source = create_rds_resource(connection)
    resource = RDSResource(
        identifier=generate_etl_db_instance_identifier(source.identifier, ts_nodash),
        role_arn=source.role_arn,
        # copy so the ETL copy does not share driver state with the source resource
        driver=deepcopy(source.driver))
    return resource, connection


# Task wrappers (operator factories) and the callables they run follow.
def create_rds_point_in_time_recover_task(connection: Connection,
                                          dag: DAG) -> PythonOperator:
    """Build the task that runs create_rds_instance_from_point_in_time for the connection's RDS instance.

    :param connection: airflow connection of the source RDS instance
    :param dag: DAG the task is attached to
    :return: the task, limited to the connection's pool
    """
    conn_id = connection.conn_id
    recovery_kwargs = {
        'source': create_rds_resource(connection),
        'role_arn': get_role_from_conn_id(conn_id),
        'cleanup_aws': True,
    }
    return python_task_wrapper(
        create_rds_instance_from_point_in_time,
        "rds-point-in-time-recovery-for-{}".format(conn_id),
        recovery_kwargs,
        dag=dag,
        pool=conn_id)


def delete_rds_instance_task(connection: Connection,
                             dag: DAG=None) -> PythonOperator:
    """Build the task that deletes this dag run's ETL copy of the connection's RDS instance.

    Uses ``trigger_rule="all_done"`` so the delete also runs when upstream tasks failed.

    :param connection: airflow connection of the source RDS instance
    :param dag: DAG the task is attached to
    :return: the task
    """
    conn_id = connection.conn_id
    delete_kwargs = {
        'source': create_rds_resource(connection),
        'role_arn': get_role_from_conn_id(conn_id),
        'cleanup_aws': True,
    }
    return python_task_wrapper(
        delete_rds_resource_wrapper,
        "delete-rds-instance-{}-etl".format(conn_id),
        delete_kwargs,
        dag=dag,
        trigger_rule="all_done")


def delete_rds_resource_wrapper(source: RDSResource, **kwargs: Dict[Any, Any]) -> None:
    """
    Delete the ETL copy of ``source`` that belongs to the current dag run.

    This thin wrapper keeps delete_rds_resource generic. The run-specific
    identifier is derived here from ``ts_nodash`` in the airflow context.

    :param source: metadata of the source RDS instance
    :param kwargs: airflow task context (``ts_nodash`` is read)
    :return: None
    """
    run_ts = kwargs.get('ts_nodash')
    etl_identifier = generate_etl_db_instance_identifier(source.identifier, run_ts)
    delete_rds_resource(RDSResource(identifier=etl_identifier,
                                    role_arn=source.role_arn,
                                    driver=deepcopy(source.driver)))


def get_s3_bucket_check_task(connection: Connection, dag: DAG=None) -> PythonOperator:
    """Return the DAG's ``s3-bucket-check`` task, creating it on first use.

    A multi-host SQL ETL DAG typically already has a bucket check, so an existing
    task is reused instead of being added a second time.

    :param connection: airflow connection whose conn_id is passed to the check
    :param dag: DAG holding the task; its ``default_args['s3_output']`` describes the bucket
    :return: the existing or newly created task
    :raises ValueError: when ``dag`` is None (the bucket description lives on the DAG)
    """
    task_id = 's3-bucket-check'
    if dag is None:
        raise ValueError("get_s3_bucket_check_task needs a dag to read default_args['s3_output']")
    if dag.has_task(task_id):
        return dag.task_dict.get(task_id)
    # s3_output is only needed when a new task actually has to be built
    s3_meta = dag.default_args['s3_output']
    return python_task_wrapper(
        ensure_s3_bucket_exists,
        task_id,
        {'s3_meta': s3_meta,
         'conn_id': connection.conn_id},
        dag=dag)


def open_rds_instance_permissions_task(connection: Connection,
                                       role: str,
                                       dag: DAG) -> PythonOperator:
    """Build the task that runs open_rds_instance_permissions for the connection's RDS instance.

    :param connection: airflow connection of the source RDS instance
    :param role: role handed to open_rds_instance_permissions
    :param dag: DAG the task is attached to
    :return: the task
    """
    conn_id = connection.conn_id
    permission_kwargs = dict(
        conn_id=conn_id,
        source=create_rds_resource(connection),
        role=role,
        role_arn=get_role_from_conn_id(conn_id),
        cleanup_aws=True,
    )
    return python_task_wrapper(
        open_rds_instance_permissions,
        "open-rds-instance-permissions-{}".format(conn_id),
        permission_kwargs,
        dag=dag)


def get_table_dump_task(table: str,
                        connection: Connection,
                        is_rds: bool=None,
                        dag: DAG=None) -> PythonOperator:
    """
    Build the CSV ETL task for one table (runs dump_table_upload_it_delete_it).

    :param table: table name
    :param connection: airflow connection of the source database
    :param is_rds: whether the source is an RDS instance
    :param dag: DAG the task belongs to; ``default_args['s3_output']`` supplies the S3 target
    :return: the task, limited to the connection's pool and to 3 retries
    """
    conn_id = connection.conn_id
    dump_kwargs = dict(
        table=table,
        conn_id=conn_id,
        is_rds=is_rds,
        s3_meta=dag.default_args['s3_output'],
        cleanup_csv_dump=True,
        succeed_cleanup_csv_dump=True,
    )
    return python_task_wrapper(
        dump_table_upload_it_delete_it,
        "{}-{}-etl".format(conn_id, table),
        dump_kwargs,
        pool=conn_id,
        dag=dag,
        retries=3)


def dump_table_upload_it_delete_it(table: str, conn_id: str, is_rds: bool,
                                   s3_meta: Dict[str, Any], **kwargs: Dict[Any, Any]) -> None:
    """Extract one table of the point-in-time restored database to CSV and upload it to S3.

    Steps:
        1. ``SELECT *`` from the table,
        2. write the rows to a UTF-8 CSV file whose first line holds the column names,
        3. upload the file to ``<base s3 path>/<table>.csv``.

    The column names are taken from the first row, so an empty table produces an
    empty file. The local file is not removed here. The CSV path is pushed to
    XCom, and removal is left to the csv-dump cleanup callbacks (the task is
    created with ``cleanup_csv_dump`` and ``succeed_cleanup_csv_dump``). When
    the table is not in this dag run's schema, the task logs and does nothing.

    :param table: table to dump
    :param conn_id: the connection id in the airflow database
    :param is_rds: db is an RDS store or self managed
    :param s3_meta: S3 output description; reads ``s3_key``, ``bucket`` and ``role_arn``
    :param kwargs: airflow task context (``ts_nodash``, ``task_instance``)
    :return: None
    """
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    if not check_table_exists(task_instance, ts_nodash, conn_id, table):
        logging.info('{} does not exist in this dag. skipping task.'.format(table))
        return

    # dump table to local dir
    resource, connection = get_pitr_resource_connection(conn_id, is_rds, ts_nodash, task_instance)
    work_dir = get_work_directory()
    table_dump_path = "{}/{}-{}-{}.csv".format(work_dir, ts_nodash, connection.conn_id, table)
    # pushed before dumping, presumably so the cleanup callbacks can also find a partial file
    key = xcoms.XcomCSVKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}"
                 .format(key, table_dump_path, task_instance.task_id))
    task_instance.xcom_push(key=key, value=table_dump_path)

    sql_command = "SELECT * from {}".format(table)
    db_engine = resource.create_sql_engine()
    with db_engine.begin() as db_connection:
        results = db_connection.execute(sql_command)
        # the csv module requires newline='' on the file objects it writes
        with open(table_dump_path, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f)
            header_written = False
            for row in results:
                if not header_written:
                    writer.writerow(row.keys())
                    header_written = True
                writer.writerow(row.values())
    # os.stat returns a stat_result; the size in bytes is its st_size field
    file_size = os.stat(table_dump_path).st_size
    logging.info("table dumped to {} sized: {} bytes".format(table_dump_path, file_size))

    # move file to s3 via boto3
    base_path = get_base_s3_path(s3_meta['s3_key'], ts_nodash, connection.conn_id)
    s3_path = "{}/{}.csv".format(base_path, table)
    logging.info("uploading {} to {}".format(table_dump_path, s3_path))
    upload_file_to_s3(s3_meta['bucket'], table_dump_path, s3_path, s3_meta['role_arn'])


def get_make_ebs_volume_ec2_instance_and_run_wal_backup_task(connection: Connection,
                                                             dag: DAG=None) -> PythonOperator:
    """
    Build the task that restores the WAL backup for ``connection`` onto a new EBS volume.

    :param connection: airflow connection of the source postgres
    :param dag: DAG the task belongs to; ``default_args['wal_aws_creds']`` is passed to the task
    :return: the task, run in the ``<conn_id>-wal-fetch`` pool
    """
    conn_id = connection.conn_id
    wal_kwargs = {
        'wal_aws_creds': dag.default_args['wal_aws_creds'],
        'role_arn': get_role_from_conn_id(conn_id),
        'cleanup_aws': True,
    }
    return python_task_wrapper(
        make_ebs_volume_ec2_instance_and_run_wal_backup,
        "wal-backup-fetch-to-new-ebs-{}".format(conn_id),
        wal_kwargs,
        pool="{}-wal-fetch".format(conn_id),
        dag=dag)


def make_ebs_volume_ec2_instance_and_run_wal_backup(wal_aws_creds: Dict[str, str], **kwargs) -> List[str]:
    """Restore a WAL backup onto a new EBS volume and hand the volume id to the next task.

    Steps:
        1. create a gp2 EBS volume in this worker's availability zone and wait until available,
        2. attach it to this worker (the volume is destroyed if the attach fails),
        3. format it as ext4 and mount it on the dag run's work directory,
        4. run the WAL backup command against that directory,
        5. unmount and detach the volume,
        6. push the volume id to XCom (XcomVolumeIdKey) for downstream tasks.

    :param wal_aws_creds: WAL settings; reads ``ec2_arn``, ``s3_arn``, ``bucket`` and
        ``region``. The dict is not modified. It is the DAG's shared
        ``default_args['wal_aws_creds']``, so the credentials from get_creds
        are merged into a copy.
    :param kwargs: airflow task context (``ts_nodash``, ``task_instance``)
    :return: the output of run_command for the WAL backup command
    """
    # get some info on ec2 instance
    myself = EC2Instance.instance_from_self(wal_aws_creds['ec2_arn'])
    # make ebs volume
    volume = EBSVolume(size=conf.get('ec2', 'ebs_size'),
                       volume_type='gp2', availability_zone=myself.availability_zone(),
                       role_arn=wal_aws_creds['ec2_arn'])
    logging.info("creating ebs volume")
    volume.create()
    while not volume.available():
        sleep(randrange(0, 15))
    logging.info("done creating ebs volume")

    # attach it to myself
    logging.info("attaching volume to self")
    device = get_free_device_from_self(EC2Instance.get_block_devices_from_id(
        myself.identifier,
        wal_aws_creds['ec2_arn']))
    try:
        myself.attach_volume(volume, device)
    except BaseException:  # deliberately broad so volumes don't pile up
        volume.destroy()
        raise

    logging.info("formatting volume")
    run_command(["sudo", "mkfs", "-t", "ext4", device])

    # uniquely identify the mount point by dag id and run timestamp
    task_instance = kwargs['task_instance']
    ts_nodash = kwargs.get('ts_nodash')
    work_directory = Path(get_work_directory(dag_id=task_instance.dag_id,
                                             ts_nodash=ts_nodash,
                                             prefix_dir=conf.get('ec2', 'worker_mount_point')))
    # Directories need the execute bit to be traversable. Python < 3.7 also
    # applies this mode to newly created parent directories.
    work_directory.mkdir(parents=True, exist_ok=True, mode=0o755)
    logging.info("mounting {} to {}".format(device, work_directory))
    run_command(["sudo", "mount", device, str(work_directory)])
    run_command(["sudo", "chown", "airflow:airflow", str(work_directory)])

    # Work on copies. Updating wal_aws_creds in place would write credentials
    # into the DAG's shared default_args. Updating os.environ would leak
    # WALE_S3_PREFIX / AWS_REGION into the rest of this worker process.
    creds = dict(wal_aws_creds)
    creds.update(get_creds(creds['s3_arn']))
    env = os.environ.copy()
    env.update({"WALE_S3_PREFIX": creds['bucket'],
                "AWS_REGION": creds['region']})

    command = wal_backup_command(creds, str(work_directory), ts_nodash, env)
    logging.info("running postgres setup script: {}".format(command))
    results = run_command(command, env)

    # detach ebs volume
    logging.info("unmounting {}".format(device))
    run_command(["sudo", "umount", "-d", device])
    logging.info("detaching volume: {}".format(device))
    myself.detach_volume(volume, device)
    logging.info("detached volume")

    # forward VolumeId to next task
    key = xcoms.XcomVolumeIdKey(task_instance, ts_nodash).get_key()
    task_instance.xcom_push(key=key, value=volume.identifier)
    return results


def create_ec2_sql_task(connection: Connection, dag: DAG=None) -> PythonOperator:
    """
    Build the task that runs create_ec2_sql_instance for ``connection``.

    :param connection: airflow connection of the source postgres
    :param dag: DAG the task is attached to
    :return: the task
    """
    conn_id = connection.conn_id
    instance_kwargs = dict(
        conn_id=conn_id,
        role_arn=get_role_from_conn_id(conn_id),
        cleanup_aws=True,
    )
    return python_task_wrapper(
        create_ec2_sql_instance,
        "create-ec2-sql-{}".format(conn_id),
        instance_kwargs,
        dag=dag)


def create_ec2_sql_instance(conn_id: str=None, role_arn: str=None, **kwargs) -> None:
    """
    Launch an EC2 instance for the restored database and register an airflow connection for it.

    1: pull the volume id pushed by ``wal-backup-fetch-to-new-ebs-<schema>``
    2: launch an instance (worker AMI, this worker's instance type and security
       groups) in the subnet of the volume's availability zone and attach the volume
    3: create or update the ``get_ec2_sql_conn_id(conn_id, ts_nodash)`` connection
       so it points at the instance's private ip (port 5432 for a new row)
    4: push the new connection id and the instance id to XCom for downstream tasks

    :param conn_id: airflow connection id of the source database
    :param role_arn: role used for the EC2 calls
    :param kwargs: airflow task context. NOTE(review): the whole context is also
        forwarded to create_ec2_instance; confirm that is intended.
    :return: None
    """
    connection = get_airflow_connection(conn_id)
    task_instance = kwargs.get('task_instance')
    ts_nodash = kwargs.get('ts_nodash')

    # get volume id from xcom
    key = xcoms.XcomVolumeIdKey(task_instance, ts_nodash).get_key()
    task_ids = 'wal-backup-fetch-to-new-ebs-' + connection.schema
    logging.info("pulling {} from tasks: {}".format(key, task_ids))
    volume_id = task_instance.xcom_pull(key=key, task_ids=task_ids)
    logging.info("pulled volume_id: {}".format(volume_id))
    volume = EBSVolume.volume_from_id(volume_id, role_arn)

    myself = EC2Instance.instance_from_self(role_arn)
    subnet_zones_map = myself.sunbets_zone({'key': 'Name', 'value': 'Data - '})
    instance_ids = create_ec2_instance(role_arn=role_arn,
                                       image_id=conf.get('ec2', 'worker_ami'),
                                       instance_type=myself.instance_type,
                                       security_groups=myself.security_groups,
                                       subnet_id=subnet_zones_map[volume.availability_zone],
                                       availability_zone=volume.availability_zone,
                                       name_tag=conn_id,
                                       **kwargs)
    instance_id = instance_ids[0]
    instance = EC2Instance.instance_from_id(instance_id, role_arn)
    attach_ec2_volume_to_ebs_volume(instance, volume)

    new_conn_id = get_ec2_sql_conn_id(connection.conn_id, ts_nodash)
    sesh = Session()
    try:
        try:
            new_conn = sesh.query(Connection).filter_by(conn_id=new_conn_id).one()
            new_conn.host = instance.private_ip_address
        except NoResultFound:
            sesh.add(Connection(conn_id=new_conn_id,
                                conn_type=connection.conn_type,
                                host=instance.private_ip_address,
                                login=connection.login,
                                password=connection.get_password(),
                                schema=connection.schema,
                                port=5432))
        sesh.commit()
    except Exception:
        sesh.rollback()
        raise
    finally:
        # release the session; only plain strings are used after this point
        sesh.close()

    # alert next tasks what the connection and instance are
    key = xcoms.XcomConnIdKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}".format(key, new_conn_id, task_instance.task_id))
    task_instance.xcom_push(key=key, value=new_conn_id)
    key = xcoms.XcomInstanceIdKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}".format(key, instance_id, task_instance.task_id))
    task_instance.xcom_push(key=key, value=instance_id)
    key = xcoms.XcomDBConnIdKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task:{}".format(key, new_conn_id, task_instance.task_id))
    task_instance.xcom_push(key=key, value=new_conn_id)


def remove_ec2_sql_task(connection: Connection, dag: DAG=None):
    """Build the task that runs remove_ec2_sql for ``connection``.

    Uses ``trigger_rule="all_done"`` so teardown also runs when upstream tasks failed.

    :param connection: airflow connection of the source postgres
    :param dag: DAG the task is attached to
    :return: the task, limited to the connection's pool
    """
    conn_id = connection.conn_id
    teardown_kwargs = {
        'conn_id': conn_id,
        'role_arn': get_role_from_conn_id(conn_id),
        'cleanup_aws': True,
    }
    return python_task_wrapper(
        remove_ec2_sql,
        "remove-ec2-sql-{}".format(conn_id),
        teardown_kwargs,
        pool=conn_id,
        dag=dag,
        trigger_rule="all_done")


def remove_ec2_sql(conn_id: str=None, role_arn: str=None, **kwargs) -> None:
    """Tear down the EC2-hosted SQL copy created for this dag run.

    Pulls the instance id and connection id pushed by ``create-ec2-sql-<schema>``
    and the volume id pushed by ``wal-backup-fetch-to-new-ebs-<schema>``, then:
        1. best-effort: over ssh, stops the docker container and unmounts the work directory,
        2. force-detaches the volume (re-checking every 60s) and deletes it,
        3. terminates the instance,
        4. deletes the matching connection rows from the airflow database.

    :param conn_id: airflow connection id of the source database
    :param role_arn: role used for the EC2 calls
    :param kwargs: airflow task context (``ts_nodash``, ``task_instance``)
    :return: None
    """
    connection = get_airflow_connection(conn_id)
    task_instance = kwargs.get('task_instance')
    ts_nodash = kwargs.get('ts_nodash')

    task_ids = 'create-ec2-sql-' + connection.schema
    key = xcoms.XcomInstanceIdKey(task_instance, ts_nodash).get_key()
    logging.info("pulling xcom: {} from task: {}".format(key, task_ids))
    instance_id = task_instance.xcom_pull(key=key, task_ids=task_ids)
    key = xcoms.XcomConnIdKey(task_instance, ts_nodash).get_key()
    logging.info("pulling xcom: {} from task: {}".format(key, task_ids))
    connection_id = task_instance.xcom_pull(key=key, task_ids=task_ids)
    key = xcoms.XcomVolumeIdKey(task_instance, ts_nodash).get_key()
    task_ids = 'wal-backup-fetch-to-new-ebs-' + connection.schema
    logging.info("pulling xcom: {} from task: {}".format(key, task_ids))
    volume_id = task_instance.xcom_pull(key=key, task_ids=task_ids)
    logging.info("about to remove {} and {}".format(instance_id, volume_id))

    ec2, _ = boto_resource('ec2', role_arn)
    instance = ec2.Instance(instance_id)
    try:
        # attempt to play nice: stop the container, then unmount the drive
        work_dir = get_work_directory(task_instance.dag_id, ts_nodash)
        run_command(['ssh', instance.private_ip_address, 'sudo', 'docker', 'stop', connection_id])
        run_command(['ssh', instance.private_ip_address, 'sudo', 'umount', work_dir])
    except Exception:
        # best-effort only: the volume is force-detached below regardless
        logging.exception("failed to unmount drive")

    volume = ec2.Volume(volume_id)
    while volume.attachments:
        try:
            response = instance.detach_volume(VolumeId=volume_id, Force=True)
            volume = ec2.Volume(volume_id)  # re-read to refresh the attachment list
            logging.info("volume detach status attempt: {}. current attachments: {}"
                         .format(response, volume.attachments))
        except ClientError:
            # instance is available if this happens
            break
        sleep(60)

    logging.info("deleting volume: {}".format(volume.volume_id))
    response = volume.delete()
    logging.info("volume deleted: {}".format(response))

    logging.info("deleting instance: {}".format(instance.instance_id))
    response = instance.terminate()
    logging.info("instance deleted: {}".format(response))

    logging.info("deleting {}".format(connection_id))
    sesh = Session()
    try:
        # NOTE(review): contains() is a substring (LIKE) match, so every conn_id that
        # contains connection_id is deleted. Confirm this is intended.
        stale_conns = sesh.query(Connection).filter(Connection.conn_id.contains(connection_id)).all()
        for stale in stale_conns:
            sesh.delete(stale)
        sesh.commit()
    except Exception:
        sesh.rollback()
        raise
    finally:
        sesh.close()
    logging.info("{} deleted".format(connection_id))


def get_create_container_task(connection: Connection, dag: DAG=None):
    """Build the task that runs create_docker_postgres for ``connection``.

    :param connection: airflow connection of the source postgres
    :param dag: DAG the task is attached to
    :return: the task, limited to the connection's pool
    """
    conn_id = connection.conn_id
    container_kwargs = dict(
        conn_id=conn_id,
        role_arn=get_role_from_conn_id(conn_id),
        cleanup_aws=True,
        cleanup_connection=True,
    )
    return python_task_wrapper(
        create_docker_postgres,
        "creating-postgres-container-{}".format(conn_id),
        container_kwargs,
        pool=conn_id,
        dag=dag)


def remove_create_container_task(connection: Connection, dag: DAG=None):
    """Build the task that runs remove_docker_postgres for ``connection``.

    :param connection: airflow connection of the source postgres (passed as master_conn_id)
    :param dag: DAG the task is attached to
    :return: the task, limited to the connection's pool
    """
    conn_id = connection.conn_id
    return python_task_wrapper(
        remove_docker_postgres,
        "remove-postgres-container-{}".format(conn_id),
        dict(master_conn_id=conn_id),
        pool=conn_id,
        dag=dag)


def send_done_message_task(connection: Connection, dag: DAG):
    """Build the task that sends the "done" SNS message for ``connection``.

    Uses ``trigger_rule="all_done"`` so the message is sent even when upstream tasks failed.

    :param connection: airflow connection of the source database
    :param dag: DAG whose ``default_args`` supply ``s3_output`` and ``sns_output``
    :return: the task
    """
    conn_id = connection.conn_id
    defaults = dag.default_args
    s3_meta = defaults['s3_output']
    sns_output = defaults['sns_output']
    message_kwargs = {
        'conn_id': conn_id,
        's3_meta': s3_meta,
        'sns_arn': sns_output['sns_arn'],
        'role_arn': sns_output['role_arn'],
    }
    return python_task_wrapper(
        send_sns_message_wrapper,
        "{}-is-done-message".format(conn_id),
        message_kwargs,
        dag=dag,
        trigger_rule="all_done")


def check_postgres_pg_isready_task(connection: Connection,
                                   dag: DAG) -> PythonOperator:
    """Build the task that runs check_postgres_pg_isready for ``connection``.

    :param connection: airflow connection of the source postgres
    :param dag: DAG the task is attached to
    :return: the task, limited to the connection's pool
    """
    conn_id = connection.conn_id
    isready_kwargs = {
        'conn_id': conn_id,
        'role_arn': get_role_from_conn_id(conn_id),
        'cleanup_aws': True,
    }
    return python_task_wrapper(
        check_postgres_pg_isready,
        "check_pg_isready-for-{}".format(conn_id),
        isready_kwargs,
        dag=dag,
        pool=conn_id)


def check_postgres_pg_isready(conn_id: str, attempts: int=30, **kwargs):
    """Wait for this dag run's EC2-hosted postgres and publish its connection id.

    Resolves the connection id pushed by ``create-ec2-sql-<schema>`` and polls it
    with poll_pg_isready. On success, that connection id is pushed under
    XcomPGReadyKey. If pg_isready reports no response (exit code 2), the
    postgres container is started again and the check is retried with one
    attempt fewer, so repeated restarts are bounded by ``attempts``.

    :param conn_id: conn_id in airflow metadatabase
    :param attempts: number of polls allowed; also bounds the number of container restarts
    :param kwargs: airflow provided context for this task
    :return: None
    :raises ZenyattaError: on timeout or any other pg_isready failure
    """
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    key = xcoms.XcomConnIdKey(task_instance, ts_nodash).get_key()
    tmp_conn = get_airflow_connection(conn_id)
    task_ids = 'create-ec2-sql-' + tmp_conn.schema
    connection = get_airflow_connection(task_instance.xcom_pull(key=key, task_ids=task_ids))
    key = xcoms.XcomPGReadyKey(task_instance, ts_nodash).get_key()
    try:
        if poll_pg_isready(connection, attempts) is False:
            logging.info("time out and database still refuses connection.")
            raise ZenyattaError("time out and database is in recovery mode")
        logging.info("pushing xcom: {}/{} from task: {}"
                     .format(key, connection.conn_id, task_instance.task_id))
        task_instance.xcom_push(key=key, value=connection.conn_id)
    except ZenyattaError as ze:
        message = str(ze)
        # poll_pg_isready raises "pg_isready return code is 2 status is no_response".
        # Match only its stable parts so the exact connective wording does not matter.
        if 'return code is 2' in message and 'no_response' in message:
            # it's likely the container needs to be restarted.
            container = PostgresContainer(get_work_directory(task_instance.dag_id, ts_nodash),
                                          conn_id=connection.conn_id)
            logging.info("attempting to start container: {}".format(container.name))
            response = container.start_container()
            logging.info("response from container: {}".format(response))
            # one attempt fewer each time: once attempts hits 0 the poll times out and
            # the resulting error is re-raised instead of recursing forever
            check_postgres_pg_isready(conn_id, attempts - 1, **kwargs)
        else:
            raise  # restart


def poll_pg_isready(connection: Connection, attempts: int, wait_time: int=600) -> bool:
    """Poll ``pg_isready`` over ssh until the database accepts connections with WAL replay paused.

    :param connection: airflow connection of the EC2-hosted postgres; ``host`` is the ssh target
    :param attempts: maximum number of polls
    :param wait_time: seconds to sleep between polls
    :return: True once replay is paused; False after ``attempts`` polls without that
    :raises ZenyattaError: when pg_isready reports anything other than 0 (running)
        or 1 (rejecting), including unparseable output
    """
    pg_status = {'0': 'running', '1': 'rejecting', '2': 'no_response', '3': 'unknown'}
    while attempts > 0:
        # ssh hands "pg_isready ...; echo $?" to the remote shell, so the first
        # output line is pg_isready's exit code
        result = run_command(["ssh",
                              connection.host,
                              "pg_isready",
                              "--quiet",
                              "--host=localhost;",
                              "echo",
                              "$?"]).decode().split('\n')[0]
        if result == '0':
            logging.info("database is accepting connections")
            engine = PostgresSQL(connection.conn_id).create_sql_engine()
            try:
                with engine.connect() as db:
                    if db.execute("SELECT pg_is_xlog_replay_paused()").fetchone()[0]:
                        return True
            finally:
                # release the connection pool built for this single check
                engine.dispose()
            logging.info("database recovery not yet paused. retry in %d seconds...", wait_time)
        elif result == '1':
            logging.info("database is rejecting connections. retry in %d seconds...", wait_time)
        else:
            # unexpected output (e.g. empty when ssh fails) must not surface as a KeyError
            status = pg_status.get(result, 'unknown')
            logging.info("pg_isready return code is %s and status is %s", result, status)
            raise ZenyattaError("pg_isready return code is %s status is %s" % (result, status))
        sleep(wait_time)
        attempts -= 1
    return False


def generate_table_spark_script_task(table: str, connection: Connection, is_rds: bool=None,
                                     priority_weight: int=1, dag: DAG=None) -> PythonOperator:
    """
    Build the parquet ETL prep task for one table (runs generate_table_spark_script).

    :param table: table name
    :param connection: airflow connection of the source database
    :param is_rds: whether the source is an RDS instance
    :param priority_weight: airflow priority weight of the task
    :param dag: DAG the task is attached to
    :return: the task, limited to the connection's pool
    """
    conn_id = connection.conn_id
    prep_kwargs = dict(table=table, conn_id=conn_id, is_rds=is_rds)
    return python_task_wrapper(
        generate_table_spark_script,
        "{}-{}-etl-prep".format(conn_id, table),
        prep_kwargs,
        priority_weight=priority_weight,
        pool=conn_id,
        dag=dag)


def generate_table_spark_script(table: str, conn_id: str, is_rds: bool, **kwargs: Dict[Any, Any]) -> None:
    """First ETL task for dumping a table to parquet.

    Steps:
        1. resolve the point-in-time (PITR) host to read from and build a spark app config
        2. generate the spark script and upload it to S3
        3. push the script location, spark config, bucket and role to xcom for later tasks,
           then delete the local copy of the script

    :param table: table to dump
    :param conn_id: the connection id in the airflow database
    :param is_rds: True for an RDS store, False for a self managed postgres
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are read
    :raises ZenyattaError: if the spark script could not be generated
    """
    aws_conns = get_connections()['aws']
    s3_input = aws_conns.get('s3-input')
    s3_output = aws_conns.get('s3-output')
    s3_meta = {'bucket': s3_input['bucket'], 's3_key': s3_output['s3_key'], 'role_arn': s3_input['role_arn']}
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    if not check_table_exists(task_instance, ts_nodash, conn_id, table):
        logging.info('{} does not exist in this dag. skipping task.'.format(table))
        return

    # handle diff between rds and twitch managed postgres
    if is_rds:
        source = create_rds_resource(get_airflow_connection(conn_id))
        resource = RDSResource(
            identifier=generate_etl_db_instance_identifier(source.identifier, ts_nodash),
            role_arn=source.role_arn,
            driver=deepcopy(source.driver))
        pitr_host, _ = resource.get_host_and_port()
    else:
        # self managed PITR copies are registered as "<conn_id>-<ts_nodash>"
        pitr_host = get_airflow_connection(conn_id + '-' + ts_nodash).host

    logging.info("making spark app config for {}.{} on host {}".format(conn_id, table, pitr_host))
    spark_app_config = make_spark_table_config(table, conn_id, pitr_host, ts_nodash, s3_meta)

    logging.info("creating spark script for {}.{}".format(conn_id, table))
    if generate_spark_script(spark_app_config, s3_meta) is False:
        raise ZenyattaError('failed to create spark script for {}.{}'.format(conn_id, table))
    spark_script_s3_fullpath = 's3://{bucket}/'.format(**s3_meta) + spark_app_config['s3_script_path']
    logging.info("spark script is created on {}".format(spark_script_s3_fullpath))

    local_script_path = spark_app_config['local_script_path']
    # (xcom key class, value) pairs, pushed in this order
    xcom_values = [
        (xcoms.XcomSparkScriptKey, spark_script_s3_fullpath),
        (xcoms.XcomSparkConfigKey, spark_app_config),
        (xcoms.XcomBucketKey, s3_meta['bucket']),
        (xcoms.XcomRoleArnKey, s3_meta['role_arn']),
        (xcoms.XcomS3ScriptKey, spark_app_config['s3_script_path']),
        (xcoms.XcomLocalScriptKey, local_script_path),
    ]
    try:
        for key_cls, value in xcom_values:
            key = key_cls(task_instance, ts_nodash).get_key()
            logging.info("pushing xcom: {}/{} from task: {}".format(key, value, task_instance.task_id))
            task_instance.xcom_push(key=key, value=value)
    finally:
        # don't leave the generated script on local disk even if an xcom push fails
        os.remove(local_script_path)


def submit_spark_application_task(table: str, connection: Connection,
                                  priority_weight: int=1, dag: DAG=None) -> PythonOperator:
    """Build the ``<conn_id>-<table>-etl-submit`` task that runs ``submit_spark_application``.

    The task is created with ``retries=3``.

    :param table: table name
    :param connection: source connection; its ``conn_id`` names the task and is the pool
    :param priority_weight: airflow priority weight of the task
    :param dag: DAG to attach the task to (optional)
    :return: the operator produced by ``python_task_wrapper``
    """
    conn_id = connection.conn_id
    task_id = '{}-{}-etl-submit'.format(conn_id, table)
    op_kwargs = dict(table=table, conn_id=conn_id)
    return python_task_wrapper(submit_spark_application, task_id, op_kwargs,
                               priority_weight=priority_weight, pool=conn_id,
                               dag=dag, retries=3)


def submit_spark_application(table: str, conn_id: str, attempts: int=20,
                             **kwargs: Dict[Any, Any]) -> None:
    """Second ETL task for dumping a table to parquet: run the table's spark script on EMR.

    Adds a ``spark-submit`` step for the script produced by the ``<conn_id>-<table>-etl-prep``
    task, locates the resulting YARN application by the script name, and polls it until it
    reaches a terminal state.

    When the app is still in progress after the polling budget, a flag is pushed to xcom under
    ``XcomSparkSubmitKey`` before failing. A later run of this task that can pull that flag
    only logs and returns, so it does not submit another step.

    :param table: table to dump
    :param conn_id: the connection id in the airflow database
    :param attempts: not used; the polling limits below are fixed. Kept for interface
        compatibility.
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are read
    :raises ZenyattaError: if the YARN app cannot be found, is still in progress after the
        polling budget, or ends in any state other than FINISHED/SUCCEEDED
    """
    # YARN application states that are final
    terminal_states = ('FINISHED', 'FAILED', 'KILLED')

    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    if not check_table_exists(task_instance, ts_nodash, conn_id, table):
        logging.info('{} does not exist in this dag. skipping task.'.format(table))
        return

    submit_key = xcoms.XcomSparkSubmitKey(task_instance, ts_nodash).get_key()
    logging.info("xcom_pull key:{} task_ids:{}".format(submit_key, task_instance.task_id))
    if task_instance.xcom_pull(key=submit_key, task_ids=task_instance.task_id) is not None:
        logging.info("this task has been submitted or set to fail, skipping...")
        return

    priority_weight = check_table_priority(conn_id, table)

    prep_task_id = "{}-{}-etl-prep".format(conn_id, table)
    key = xcoms.XcomSparkScriptKey(task_instance, ts_nodash).get_key()
    logging.info("xcom_pull key:{} task_ids:{}".format(key, prep_task_id))
    spark_script_s3_fullpath = task_instance.xcom_pull(key=key, task_ids=prep_task_id)
    logging.info("spark_script_path: {}".format(spark_script_s3_fullpath))

    key = xcoms.XcomSparkConfigKey(task_instance, ts_nodash).get_key()
    logging.info("xcom_pull key:{} task_ids:{}".format(key, prep_task_id))
    spark_app_config = task_instance.xcom_pull(key=key, task_ids=prep_task_id)
    logging.info("spark_app_config: {}".format(spark_app_config))

    # an explicit [spark] emr_name wins; two connections get a per-connection cluster prefix
    emr_name = conf.get('spark', 'emr_name')
    if emr_name:
        emr_prefix = emr_name
    elif conn_id in ('cohesion-following', 'justintv_prod'):
        emr_prefix = 'zenyatta-spark-' + conn_id
    else:
        emr_prefix = 'zenyatta-spark'
    logging.info("fetch EMR cluster for {}".format(emr_prefix))
    emr = EMR(emr_prefix)
    emr_cluster = emr.fetch_emr_cluster()
    logging.info("ERM cluster {Id} {Name} is {Status}".format(**emr_cluster))
    cluster_pub_dns = emr.get_cluster_description(emr_cluster['Id'])['MasterPublicDnsName']

    args = ['spark-submit', '--deploy-mode', 'cluster', '--master', 'yarn', '--jars',
            '/home/hadoop/custom_jars/mysql-connector-java-6.0.6.jar,'
            '/home/hadoop/custom_jars/postgresql-42.1.0.jre7.jar',
            '--conf', 'spark.yarn.submit.waitAppCompletion=false',
            '--conf', 'spark.network.timeout=10000000',
            '--conf', 'spark.executor.heartbeatInterval=10000000',
            '--conf', 'spark.dynamicAllocation.enabled=false',
            '--conf', 'spark.executor.instances={}'.format(priority_weight),
            spark_script_s3_fullpath]
    steps = [{'Name': '{}.{}.{}'.format(conn_id, table, ts_nodash),
              'ActionOnFailure': 'CONTINUE',
              'HadoopJarStep': {'Jar': 'command-runner.jar', 'Args': args}}]
    # random jitter before submitting -- presumably to spread out concurrent table submissions
    sleep(random.randint(1, 300))
    step_ids = emr.add_step_to_cluster(steps)
    logging.info("step {} is added to EMR".format(step_ids))

    # wait 5 to 10 mins for application to kick off, then get its application id from yarn
    sleep(random.randint(300, 600))
    script_name = spark_app_config['script_name']
    app_detail = get_yarn_app_by_name(cluster_pub_dns, script_name)
    lookups_left = 4
    while app_detail is None and lookups_left > 0:
        sleep(60)
        app_detail = get_yarn_app_by_name(cluster_pub_dns, script_name)
        lookups_left -= 1
    if app_detail is None:
        raise ZenyattaError("could not find app {} on cluster".format(script_name))
    app_id = app_detail['id']

    # first poll plus up to 4 re-polls; each poll_app_status call has its own wait budget
    app_state, app_final_status = poll_app_status(cluster_pub_dns, app_id)
    repolls_left = 4
    while app_state not in terminal_states and repolls_left > 0:
        logging.info("app is running or accepted, polling app status again...")
        app_state, app_final_status = poll_app_status(cluster_pub_dns, app_id)
        repolls_left -= 1

    # decide on the final observed state only, so an app that finishes during the
    # last re-poll is not mistaken for a timeout
    if app_state == 'FINISHED' and app_final_status == 'SUCCEEDED':
        logging.info("app completed")
        return
    if app_state not in terminal_states:
        logging.info("updating xcom key: {} value: True".format(submit_key))
        task_instance.xcom_push(key=submit_key, value=True)
        raise ZenyattaError("app state is running for 4 hours. failing ")
    if app_state == 'FINISHED' and app_final_status == 'FAILED':
        raise ZenyattaError("app finished but failed. airflow will retry.")
    # FAILED / KILLED states, or FINISHED with a KILLED final status, must not count as success
    raise ZenyattaError("app ended with state {} and final status {}. airflow will retry."
                        .format(app_state, app_final_status))


def poll_app_status(cluster_pub_dns: str, app_id: str, attempts: int=60,
                    wait_time: int=300) -> List[str]:
    """Poll a YARN application until it reaches a terminal state or the lookups run out.

    :param cluster_pub_dns: public DNS name of the EMR master, passed to ``get_yarn_app_by_id``
    :param app_id: YARN application id
    :param attempts: maximum number of status lookups; at least one lookup is always made
    :param wait_time: seconds to sleep between lookups (no sleep after the last one)
    :return: ``[state, finalStatus]`` from the most recent lookup
    """
    # YARN application states that are final; polling further cannot change the result
    terminal_states = ('FINISHED', 'FAILED', 'KILLED')
    lookups_left = max(attempts, 1)
    while True:
        app_detail = get_yarn_app_by_id(cluster_pub_dns, app_id)
        logging.info("polling app status. app detail is {}".format(app_detail))
        lookups_left -= 1
        if app_detail['state'] in terminal_states or lookups_left <= 0:
            break
        sleep(wait_time)
    return [app_detail['state'], app_detail['finalStatus']]


def check_spark_output_complete_task(table: str, connection: Connection,
                                     priority_weight: int=1, dag: DAG=None) -> PythonOperator:
    """Build the ``<conn_id>-<table>-etl-check-complete`` task running ``check_spark_output_status``.

    :param table: table name
    :param connection: source connection; its ``conn_id`` names the task and is the pool
    :param priority_weight: airflow priority weight of the task
    :param dag: DAG to attach the task to (optional)
    :return: the operator produced by ``python_task_wrapper``
    """
    conn_id = connection.conn_id
    task_id = '{}-{}-etl-check-complete'.format(conn_id, table)
    op_kwargs = dict(table=table, conn_id=conn_id)
    return python_task_wrapper(check_spark_output_status, task_id, op_kwargs,
                               priority_weight=priority_weight, pool=conn_id, dag=dag)


def check_spark_output_status(table: str, conn_id: str, attempts: int=20,
                              **kwargs: Dict[Any, Any]) -> bool:
    """Last ETL task of the parquet dump: wait for the spark job's ``_SUCCESS`` marker in S3.

    The marker path comes from the spark app config pushed by the ``<conn_id>-<table>-etl-prep``
    task. Each of up to ``attempts`` rounds sleeps 60 seconds and then checks for the marker.
    Once it exists, the spark script object is removed from the bucket.

    :param table: table to dump
    :param conn_id: the connection id in the airflow database
    :param attempts: number of 60-second check rounds
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are read
    :return: True once the marker exists; None when the table is not part of this dag run;
        False only if ``attempts`` is negative
    :raises ZenyattaError: when every round passes without the marker appearing
    """
    s3_input, s3_output = (get_connections()['aws'].get(name) for name in ('s3-input', 's3-output'))
    s3_meta = dict(bucket=s3_input['bucket'], s3_key=s3_output['s3_key'], role_arn=s3_input['role_arn'])
    bucket, role_arn = s3_meta['bucket'], s3_meta['role_arn']
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    if not check_table_exists(task_instance, ts_nodash, conn_id, table):
        logging.info('{} does not exist in this dag. skipping task.'.format(table))
        return None

    config_key = xcoms.XcomSparkConfigKey(task_instance, ts_nodash).get_key()
    prep_task_id = '{}-{}-etl-prep'.format(conn_id, table)
    logging.info('xcom_pull key:{} task_ids:{}'.format(config_key, prep_task_id))
    spark_app_config = task_instance.xcom_pull(key=config_key, task_ids=prep_task_id)
    logging.info('spark_app_config: {}'.format(spark_app_config))

    for _ in range(attempts):
        sleep(60)
        logging.info("checking parquet output")
        # the parquet output counts as complete once its _SUCCESS object exists
        if not check_s3_object_exist(bucket, role_arn, spark_app_config['pq_success_path']):
            continue
        logging.info("spark job completed. {pq_success_path} is created."
                     .format(**spark_app_config))
        remove_s3_object(bucket, spark_app_config['s3_script_path'], role_arn)
        return True

    if attempts < 0:
        # the check loop never ran
        return False
    raise ZenyattaError("timeout and spark job has not completed.")


def prep_source_table_task(table: str, connection: Connection, is_rds: bool=None,
                           priority_weight: int=1, dag: DAG=None) -> PythonOperator:
    """Build the ``prep-source-table-<conn_id>-<table>`` task running ``prep_source_table_if_needed``.

    The task prepares the source table before it is exported. Currently only
    ``chat_depot.channel_rooms`` needs work: NULL array values are replaced with ``'{}'`` so that
    the JDBC ArrayType column can be handled by dataframe operations.

    :param table: table name
    :param connection: source postgres connection; its ``conn_id`` names the task and is the pool
    :param is_rds: whether the source is an RDS instance
    :param priority_weight: airflow priority weight of the task
    :param dag: DAG to attach the task to (optional)
    :return: the operator produced by ``python_task_wrapper``
    """
    conn_id = connection.conn_id
    task_id = 'prep-source-table-{}-{}'.format(conn_id, table)
    op_kwargs = dict(table=table, conn_id=conn_id, is_rds=is_rds)
    return python_task_wrapper(prep_source_table_if_needed, task_id, op_kwargs,
                               priority_weight=priority_weight, pool=conn_id, dag=dag)


def prep_source_table_if_needed(table: str, conn_id: str, is_rds: bool,
                                **kwargs: Dict[Any, Any]):
    """Patch the source table before export when it is known to need it.

    Only ``chat_depot.channel_rooms`` is handled: NULL ``chat_rules`` are set to ``'{}'`` on the
    database behind the ``<conn_id>-<ts_nodash>`` airflow connection, via ``run_table_update``.
    Every other table is left alone. ``is_rds`` is accepted but not used.

    :param table: table name
    :param conn_id: the connection id in the airflow database
    :param is_rds: unused
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are read
    :raises ZenyattaError: if ``run_table_update`` reports failure
    """
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    if not check_table_exists(task_instance, ts_nodash, conn_id, table):
        logging.info('{} does not exist in this dag. skipping task.'.format(table))
        return
    if (conn_id, table) != ('chat_depot', 'channel_rooms'):
        logging.info("no prep needed for source table {}.{}".format(conn_id, table))
        return

    pitr_connection = get_airflow_connection(conn_id + '-' + ts_nodash)
    update_sql = "update channel_rooms set chat_rules = '{}' where chat_rules is null"
    succeeded = run_table_update(pitr_connection.host, pitr_connection.schema,
                                 pitr_connection.login, pitr_connection.get_password(), update_sql)
    if not succeeded:
        raise ZenyattaError("failed to prep source table {}.{}".format(conn_id, table))


def upload_db_snapshots_to_s3_task(table: str, connection: Connection, priority_weight: int=1,
                                   dag: DAG=None) -> PythonOperator:
    """Build the ``<conn_id>-<table>-upload-spark-output`` task running ``upload_db_snapshots_to_s3``.

    The task copies the parquet output from the staging S3 bucket to the science bucket. It is
    created with ``trigger_rule="all_done"`` and ``retries=3``, and asks the wrapper for S3 dump
    cleanup (``cleanup_s3_dump`` / ``succeed_cleanup_s3_dump``).

    :param table: table name
    :param connection: source postgres connection; its ``conn_id`` names the task and is the pool
    :param priority_weight: airflow priority weight of the task
    :param dag: DAG to attach the task to (optional)
    :return: the operator produced by ``python_task_wrapper``
    """
    conn_id = connection.conn_id
    task_id = '{}-{}-upload-spark-output'.format(conn_id, table)
    op_kwargs = dict(table=table, conn_id=conn_id,
                     cleanup_s3_dump=True, succeed_cleanup_s3_dump=True)
    return python_task_wrapper(upload_db_snapshots_to_s3, task_id, op_kwargs,
                               priority_weight=priority_weight, pool=conn_id, dag=dag,
                               trigger_rule="all_done", retries=3)


def upload_db_snapshots_to_s3(table: str, conn_id: str, **kwargs: Dict[Any, Any]) -> None:
    """Copy a table's parquet output from the staging bucket to the science bucket.

    The spark job's output is under ``object_prefix`` in the ``s3-input`` bucket. It is downloaded
    to ``/mnt/<object_prefix>`` and uploaded under the same prefix to the ``s3-output`` (science)
    bucket with the ``s3-output`` role. The local dir and the prefix are pushed to xcom before the
    copy starts (presumably for the cleanup configured on the task -- TODO confirm).

    :param table: table being dumped
    :param conn_id: the connection id in the airflow database
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are read
    :raises ZenyattaError: if the upload to the science bucket fails
    """
    aws_conns = get_connections()['aws']
    s3_input = aws_conns.get('s3-input')
    s3_output = aws_conns.get('s3-output')
    s3_meta = {'bucket': s3_input['bucket'], 'role_arn': s3_input['role_arn'],
               'science_bucket': s3_output['bucket'], 'science_role': s3_output['role_arn']}
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    if not check_table_exists(task_instance, ts_nodash, conn_id, table):
        logging.info('{} does not exist in this dag. skipping task.'.format(table))
        return

    key = xcoms.XcomSparkConfigKey(task_instance, ts_nodash).get_key()
    task_ids = "{}-{}-etl-prep".format(conn_id, table)
    logging.info("xcom_pull key:{} task_ids:{}".format(key, task_ids))
    spark_app_config = task_instance.xcom_pull(key=key, task_ids=task_ids)
    logging.info("get spark_app_config: {}".format(spark_app_config))

    s3_source_path = spark_app_config['s3_pq_path']
    object_prefix = spark_app_config['object_prefix']
    local_staging_prefix = "/mnt/"
    local_staging_dir = local_staging_prefix + object_prefix

    key = xcoms.XcomLocalPQKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}"
                 .format(key, local_staging_dir, task_instance.task_id))
    task_instance.xcom_push(key=key, value=local_staging_dir)
    key = xcoms.XcomS3PQKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}".format(key, object_prefix, task_instance.task_id))
    task_instance.xcom_push(key=key, value=object_prefix)

    make_dir(local_staging_dir)

    logging.info("downloading parquet output from s3 {} to "
                 "local staging dir {}".format(s3_source_path, local_staging_dir))
    download_objects_from_s3(s3_meta['bucket'], object_prefix,
                             local_staging_prefix, s3_meta['role_arn'], is_dir=True)

    logging.info("uploading parquet output from local staging dir to s3 {science_bucket}".format(**s3_meta))
    if not upload_file_to_s3(s3_meta['science_bucket'], local_staging_dir,
                             object_prefix, s3_meta['science_role'], is_dir=True):
        # same error type as the other ETL tasks in this module
        raise ZenyattaError("Could not upload snapshot.")


def update_db_tables_task(connection: Connection, is_rds: bool=None, dag: DAG=None) -> PythonOperator:
    """Build the ``update_db_tables-<conn_id>`` task running ``update_db_tables``.

    The S3 destination is read from ``dag.default_args['s3_output']``. Although ``dag`` defaults
    to None, it is dereferenced here, so callers must pass a DAG with that default arg.
    The role is resolved now via ``get_role_from_conn_id``, and ``cleanup_aws`` is requested.

    :param connection: source connection; its ``conn_id`` names the task and is the pool
    :param is_rds: whether the source is an RDS instance
    :param dag: DAG to attach the task to (effectively required, see above)
    :return: the operator produced by ``python_task_wrapper``
    """
    conn_id = connection.conn_id
    task_id = 'update_db_tables-{}'.format(conn_id)
    op_kwargs = {'conn_id': conn_id, 'is_rds': is_rds}
    op_kwargs['s3_meta'] = dag.default_args['s3_output']
    op_kwargs['role_arn'] = get_role_from_conn_id(conn_id)
    op_kwargs['cleanup_aws'] = True
    return python_task_wrapper(update_db_tables, task_id, op_kwargs, dag=dag, pool=conn_id)


def _collect_int_pk_partitions(db_engine: Any, inspector: Any, partitions: Dict[str, Any]) -> None:
    """Add ``{'col', 'min', 'max'}`` of an INTEGER/BIGINT primary key to ``partitions`` (in place)
    for each table not already present whose key max exceeds 1,000,000."""
    # quote identifiers so reserved words / mixed-case names produce valid SQL
    quote = db_engine.dialect.identifier_preparer.quote
    for table in db_engine.table_names():
        if table in partitions:
            continue
        pk_constraint = inspector.get_pk_constraint(table)
        pk_columns = pk_constraint.get('constrained_columns') or []
        if pk_constraint.get('name') is None or not pk_columns:
            continue
        pk_col = pk_columns[0]
        is_int_pk = any(col['name'] == pk_col and str(col['type']) in ('INTEGER', 'BIGINT')
                        for col in inspector.get_columns(table))
        if not is_int_pk:
            continue
        sql_stmt = "select min({col}), max({col}) from {tbl};".format(col=quote(pk_col), tbl=quote(table))
        with db_engine.begin() as db_connection:
            row = db_connection.execute(sql_stmt).first()
        if row is None:
            continue
        pk_min, pk_max = row
        if pk_max is not None and pk_max > 1000000:
            partitions[table] = {'col': pk_col, 'min': pk_min, 'max': pk_max}


def _describe_db_tables(db_engine: Any, inspector: Any) -> List[Dict[str, Any]]:
    """Return ``[{table: {columns, indexes, primary_key, foreign_keys}}]`` for every table,
    with column types rendered as strings."""
    described = []
    for table in db_engine.table_names():
        # copy each column dict rather than mutating what the inspector returned
        columns = [dict(col, type=str(col['type'])) for col in inspector.get_columns(table)]
        indexes = [{ind['name']: {'columns': ind['column_names'], 'unique': ind['unique']}}
                   for ind in inspector.get_indexes(table)]
        pk_constraint = inspector.get_pk_constraint(table)
        foreign_keys = [{fk['name']: {'constrained_columns': fk['constrained_columns'],
                                      'referred_table': fk['referred_table'],
                                      'referred_columns': fk['referred_columns']}}
                        for fk in inspector.get_foreign_keys(table)]
        described.append({table: {'columns': columns,
                                  'indexes': indexes,
                                  'primary_key': {'columns': pk_constraint['constrained_columns']},
                                  'foreign_keys': foreign_keys}})
    return described


def update_db_tables(conn_id: str, is_rds: bool, s3_meta: Dict[str, str], **kwargs) -> bool:
    """Record the PITR database's tables and partition info, and upload its schema metadata.

    Steps:
        1. store the table list in the ``<conn_id>-pitr-schema`` variable and push it to xcom.
           For connection ids containing ``justintv``, the previously stored list is reused
           instead of being read from the database.
        2. for tables not yet in ``<conn_id>-pitr-table-partition``, record min/max of an
           INTEGER/BIGINT primary key whose max exceeds 1,000,000.
        3. write columns, indexes, primary and foreign keys of every table to a local JSON file,
           upload it as ``<base_path>/metadata.v1.json`` and delete the local file.

    :param conn_id: the connection id in the airflow database
    :param is_rds: whether the source is an RDS instance
    :param s3_meta: S3 destination; ``bucket``, ``s3_key`` and ``role_arn`` are read
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are read
    :return: True when all steps completed
    """
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    try:
        resource, connection = get_pitr_resource_connection(conn_id, is_rds, ts_nodash, task_instance)
        db_engine = resource.create_sql_engine()

        var_key = conn_id + '-pitr-schema'
        if 'justintv' not in conn_id:
            tables = db_engine.table_names()
            update_aws_resource_variable(var_key, 'db_tables', tables, is_append=False)
        else:
            tables = get_aws_resource_variable(var_key, 'db_tables')
        update_aws_resource_variable(var_key, 'ts_nodash', ts_nodash, is_append=False)
        logging.info("updated {} schema as {}".format(var_key, tables))
        key = xcoms.XcomTablesKey(task_instance, ts_nodash).get_key()
        logging.info("pushing xcom: {}/{} from task: {}".format(key, tables, task_instance.task_id))
        task_instance.xcom_push(key=key, value=tables)

        var_key = conn_id + '-pitr-table-partition'
        # NOTE(review): `or {}` guards a missing variable; assumes a dict is returned otherwise
        cur_partitions = get_aws_resource_variable(var_key, 'partitions') or {}
        inspector = inspect(db_engine)
        _collect_int_pk_partitions(db_engine, inspector, cur_partitions)
        update_aws_resource_variable(var_key, 'partitions', cur_partitions, is_append=False)
        update_aws_resource_variable(var_key, 'ts_nodash', ts_nodash, is_append=False)
        logging.info("updated {} partitions as {}".format(var_key, cur_partitions))

        meta = {'tables': _describe_db_tables(db_engine, inspector)}

        work_dir = get_work_directory()
        output_path = "{}/{}-{}-{}.json".format(work_dir, connection.conn_id, connection.schema, ts_nodash)
        base_path = get_base_s3_path(s3_meta['s3_key'], ts_nodash, connection.conn_id)
        s3_path = "{}/metadata.v1.json".format(base_path)

        # close (and flush) the file before it is uploaded
        with open(output_path, 'w') as schema_file:
            json.dump({connection.schema: meta}, schema_file)
        try:
            upload_file_to_s3(s3_meta['bucket'], output_path, s3_path, s3_meta['role_arn'])
        finally:
            os.remove(output_path)
        return True
    except Exception as e:
        logging.exception("unexpected error: %s", e)
        raise


def dummy_task(dag: DAG=None) -> DummyOperator:
    """Return the placeholder task with id ``dummy_task``, built by ``python_dummy_wrapper``.

    :param dag: DAG to attach the task to (optional)
    """
    task_id = 'dummy_task'
    return python_dummy_wrapper(task_id, dag)
