import sys
import argparse
import json
import logging
from datetime import timedelta, datetime
from airflow.models import Connection, Variable, Pool


def create_connection(conn_id: str, conn_type: str, host: str, login: str, password: str,
                      schema: str, port: str):
    """Build an unsaved Airflow ``Connection`` for a supported RDS engine.

    ``conn_type`` is matched case-insensitively and rewritten to its canonical
    spelling ('Postgres', 'MySQL' or 'Aurora') before the Connection is built.
    All other arguments are passed through to ``Connection`` unchanged.

    Raises:
        ValueError: if ``conn_type`` is not one of the supported engines.
    """
    # NOTE(review): Airflow's built-in hooks key on lowercase conn_type values
    # ('postgres', 'mysql'); confirm downstream code expects these capitalized names.
    canonical_types = {'postgres': 'Postgres', 'mysql': 'MySQL', 'aurora': 'Aurora'}
    normalized = canonical_types.get(conn_type.lower())
    if normalized is None:
        raise ValueError("connection type must be: Postgres, MySQL, Aurora")

    return Connection(conn_id=conn_id, conn_type=normalized, host=host, login=login,
                      password=password, schema=schema, port=port)


def create_pool(conn_id: str, slots: int):
    """Build an unsaved Airflow ``Pool`` named after ``conn_id`` with ``slots`` slots.

    The pool caps how many table dumps may run concurrently against that connection.
    """
    description = "to limit number of table dumps on a {conn_id}".format(conn_id=conn_id)
    return Pool(pool=conn_id, slots=slots, description=description)


def build_default_args(retry_delay=5, schedule_interval=4, task_retries=10, max_active_runs=6):
    """Return the default DAG arguments that ``create_defaults`` stores.

    Args:
        retry_delay: minutes between task retries.
        schedule_interval: hours between DAG runs.
        task_retries: number of retries per task.
        max_active_runs: maximum concurrently running DAG instances.

    Numeric arguments may also be numeric strings (e.g. raw argparse values). They
    are coerced, so ``timedelta`` receives numbers and the JSON payload holds ints.
    Durations are stored as ``str(timedelta)`` (e.g. '4:00:00'). ``start_date`` is
    today's local midnight as ``'YYYY-MM-DD 00:00:00'``.
    """
    midnight = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    return {
        'owner': 'd8a',
        'depends_on_past': False,
        'email': ['d8a@twitch.tv'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': int(task_retries),
        'retry_delay': str(timedelta(minutes=float(retry_delay))),
        'start_date': str(midnight),
        'max_active_runs': int(max_active_runs),
        'schedule_interval': str(timedelta(hours=float(schedule_interval))),
        'output': ['csv', 'parquet'],
    }


def create_defaults(conn: "Connection", retry_delay: int=5, schedule_interval: int=4, task_retries: int=10,
                    max_active_runs: int=6):
    """Store default DAG arguments for ``conn`` in the Airflow Variable '<conn_id>_defaults'.

    See ``build_default_args`` for the meaning of the numeric parameters.
    """
    # Pass the dict itself: serialize_json=True makes Variable.set json-encode the
    # value. Pre-encoding it with json.dumps stored a JSON *string*, so
    # Variable.get(..., deserialize_json=True) returned a str instead of a dict.
    Variable.set(key=conn.conn_id + "_defaults",
                 value=build_default_args(retry_delay=retry_delay,
                                          schedule_interval=schedule_interval,
                                          task_retries=task_retries,
                                          max_active_runs=max_active_runs),
                 serialize_json=True)


def build_parser():
    """Return the command-line parser for registering an RDS instance as an ETL source."""
    parser = argparse.ArgumentParser(description="add an etl dag to zenyatta",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--rds_id', required=True, dest='conn_id', action='store', help='RDS identifier')
    parser.add_argument('--type', required=True, dest='conn_type', action='store',
                        help='Postgres, Mysql, or Aurora')
    parser.add_argument('--host', required=True, dest='host', action='store', help='uri to RDS instance')
    parser.add_argument('--user', required=True, dest='login', action='store',
                        help='username for RDS instance, should be a super user')
    # NOTE(review): a password given on the command line is visible in process
    # listings and shell history; consider reading it from the environment instead.
    parser.add_argument('--password', required=True, dest='password', action='store',
                        help='password for instance')
    parser.add_argument('--schema', required=True, dest='schema', action='store', help='schema to ETL')
    parser.add_argument('--port', required=True, dest='port', action='store', help='port to connect to')
    # type=int: without it, values given on the command line arrive as str and
    # timedelta(hours='4') raises TypeError when the defaults are built.
    parser.add_argument('--schedule-interval', required=False, dest='schedule_interval', action='store',
                        type=int, help='how often to run the DAG in hours.', default=4)
    parser.add_argument('--max-active-runs', required=False, dest='max_active_runs', action='store',
                        type=int, help='how many instances of the DAG allowed to run.', default=6)
    parser.add_argument('--pool-slots', required=False, dest='pool_slots', action='store',
                        type=int, help='number of slots in the pool created for this connection.',
                        default=25)
    return parser


def _remove_existing(sesh, conn_id):
    """Mark every Connection and Pool registered under ``conn_id`` for deletion (no commit)."""
    for existing in sesh.query(Connection).filter(Connection.conn_id == conn_id).all():
        print("{} already exists, pointed at: {}, and removing it.".format(conn_id, existing.host))
        sesh.delete(existing)
    # Look pools up independently of connections. A pool left behind without its
    # connection would otherwise clash with the new pool of the same name.
    for pool in sesh.query(Pool).filter(Pool.pool == conn_id).all():
        sesh.delete(pool)


def main(argv=None):
    """Register an RDS instance: its Connection, a Pool of the same name, and DAG defaults.

    Any existing Connection/Pool with the same id is replaced.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    from airflow.settings import Session

    args = build_parser().parse_args(argv)

    # Build (and thereby validate) everything before touching the database. An
    # unsupported --type then fails before the existing connection is deleted.
    conn = create_connection(args.conn_id, args.conn_type, args.host, args.login, args.password,
                             args.schema, args.port)
    pool = create_pool(args.conn_id, args.pool_slots)

    # create_defaults only reads conn.conn_id. Calling it while conn is still a plain
    # unsaved object avoids using conn after the session below has been closed.
    create_defaults(conn, schedule_interval=args.schedule_interval, max_active_runs=args.max_active_runs)
    logging.info("set defaults for {}-etl".format(args.conn_id))

    sesh = Session()
    try:
        _remove_existing(sesh, args.conn_id)
        # Send the DELETEs first. SQLAlchemy's unit of work issues INSERTs before
        # DELETEs within one flush, which could clash on a unique name.
        sesh.flush()
        sesh.add(conn)
        sesh.add(pool)
        # A single commit covers delete + insert, so a failure keeps the old entries.
        sesh.commit()
    finally:
        # Closing also discards any uncommitted work if an exception escaped.
        sesh.close()
    logging.info("created DAG: {}-etl".format(args.conn_id))


if __name__ == '__main__':
    main()
