import airflow
from zenyatta.common import skip_table, check_table_priority, check_dag_output_format
from zenyatta.tasks import (get_s3_bucket_check_task,
                            get_make_ebs_volume_ec2_instance_and_run_wal_backup_task,
                            get_create_container_task, send_done_message_task, get_table_dump_task,
                            create_rds_point_in_time_recover_task, delete_rds_instance_task,
                            open_rds_instance_permissions_task, create_ec2_sql_task, remove_ec2_sql_task,
                            check_postgres_pg_isready_task, generate_table_spark_script_task,
                            submit_spark_application_task, prep_source_table_task,
                            upload_db_snapshots_to_s3_task, update_db_tables_task, get_db_tables,
                            dummy_task)


def _add_rds_setup_tasks(connection, dag, role, s3_bucket_check):
    """Add the RDS setup chain after ``s3_bucket_check``; return ``(last_setup_task, cleanup_task)``."""
    create_rds = create_rds_point_in_time_recover_task(connection, dag)
    create_rds.set_upstream(s3_bucket_check)

    last_setup_task = create_rds
    if role:
        # Only insert the permissions task when a role was supplied.
        last_setup_task = open_rds_instance_permissions_task(connection, role, dag=dag)
        last_setup_task.set_upstream(create_rds)

    cleanup_task = delete_rds_instance_task(connection, dag)
    return last_setup_task, cleanup_task


def _add_unmanaged_db_setup_tasks(connection, dag, s3_bucket_check):
    """Add the unmanaged-db setup chain after ``s3_bucket_check``; return ``(last_setup_task, cleanup_task)``."""
    setup_pgdata_dir = get_make_ebs_volume_ec2_instance_and_run_wal_backup_task(connection, dag)
    setup_pgdata_dir.set_upstream(s3_bucket_check)

    ec2_sql = create_ec2_sql_task(connection, dag)
    ec2_sql.set_upstream(setup_pgdata_dir)

    container_task = get_create_container_task(connection, dag)
    container_task.set_upstream(ec2_sql)

    pg_isready_task = check_postgres_pg_isready_task(connection, dag)
    pg_isready_task.set_upstream(container_task)

    cleanup_task = remove_ec2_sql_task(connection, dag)
    return pg_isready_task, cleanup_task


def initialize_etl_dag(connection: airflow.models.Connection, dag: airflow.DAG, role: str = None) -> None:
    """
    Add the ETL pipeline tasks for one database connection to ``dag``.

    Merges the former initialize_unmanaged_db_etl_dag and initialize_rds_etl_dag.
    When the DAG's output format is 'no_output' only a dummy task is added.
    Otherwise the tasks are wired in this order:

    1. check the existence of an s3 bucket
    2. for an unmanaged db (``dag.default_args['rds']`` is not True):
         1) fetch wal backup to ebs volume
         2) create ec2 sql container
         3) create postgres container
         4) check pg_isready
       for an rds instance (``dag.default_args['rds'] is True``):
         1) PITR recovery
         2) open rds permissions (only when ``role`` is given)
    3. update db tables (schema for PITR, spark partition and output meta data)
    4. per-table ETL tasks: a table dump when 'csv' is in the output format, and
       a prep -> spark script -> spark submit -> s3 upload chain when 'parquet' is
    5. delete the EC2/EBS or RDS resources and send the done notification

    If no per-table task is created (table list unavailable or empty, every
    table skipped, or no matching output format), step 5 is chained directly
    after step 3 so that cleanup never runs before the setup tasks.

    :param connection: host or rds instance to ETL
    :param dag: dag to add new tasks to; ``dag.default_args`` must contain 'rds'
    :param role: optional - role to alter to give broader permissions
    :return: None
    """
    import logging

    out_fmt = check_dag_output_format(dag.default_args)

    if out_fmt == 'no_output':
        dummy_task(dag)
        return

    # note from mylons 7/13/2018: when adding the teams db, get_db_tables fails
    # because there is no existing task to pull from, so a failure here is
    # expected in that case and the table list falls back to empty.
    # Only Exception is caught so KeyboardInterrupt/SystemExit still propagate,
    # and the cause is logged instead of being silently discarded.
    try:
        db_tables = get_db_tables(connection)
    except Exception:
        logging.getLogger(__name__).warning(
            "get_db_tables failed for connection %s; building DAG without per-table tasks",
            connection.conn_id, exc_info=True)
        db_tables = []

    s3_bucket_check = get_s3_bucket_check_task(connection, dag)

    is_rds = dag.default_args['rds']
    if is_rds is True:
        # rds management pipeline
        before_update_db_tables, cleanup_resources = _add_rds_setup_tasks(
            connection, dag, role, s3_bucket_check)
    else:
        # unmanaged db pipeline
        before_update_db_tables, cleanup_resources = _add_unmanaged_db_setup_tasks(
            connection, dag, s3_bucket_check)

    update_db_tables = update_db_tables_task(connection, is_rds=is_rds, dag=dag)
    update_db_tables.set_upstream(before_update_db_tables)

    sns_task = send_done_message_task(connection, dag=dag)

    # Last task of every per-table chain; cleanup and notification hang off these.
    final_tasks = []

    for table in db_tables:
        if skip_table(table):
            continue

        if 'csv' in out_fmt:
            dump_table = get_table_dump_task(table, connection, is_rds=is_rds, dag=dag)
            dump_table.set_upstream(update_db_tables)
            final_tasks.append(dump_table)

        if 'parquet' in out_fmt:
            weight = check_table_priority(connection.conn_id, table)
            prep_source_table = prep_source_table_task(
                table, connection, is_rds=is_rds, priority_weight=weight, dag=dag)
            spark_script = generate_table_spark_script_task(
                table, connection, is_rds=is_rds, priority_weight=weight, dag=dag)
            spark_application = submit_spark_application_task(
                table, connection, priority_weight=weight, dag=dag)
            upload_db_snapshots_to_s3 = upload_db_snapshots_to_s3_task(
                table, connection, priority_weight=weight, dag=dag)

            prep_source_table.set_upstream(update_db_tables)
            spark_script.set_upstream(prep_source_table)
            spark_application.set_upstream(spark_script)
            upload_db_snapshots_to_s3.set_upstream(spark_application)
            final_tasks.append(upload_db_snapshots_to_s3)

    if not final_tasks:
        # Without any table task, cleanup and notification would have no
        # upstream and could start alongside resource creation; anchor them
        # behind the last setup step instead.
        final_tasks.append(update_db_tables)

    for final_task in final_tasks:
        final_task.set_downstream(cleanup_resources)
        final_task.set_downstream(sns_task)
