import json
from datetime import datetime, timedelta

from airflow.models import Connection, Pool, Variable
from airflow.settings import Session

from zenyatta.common import get_connections


# Connection ids exempt from the bulk connection delete further down; every
# other Connection row is removed before connections are re-created.
AIRFLOW_CONNECTIONS = [
    "airflow_db",
    "airflow_ci",
    "beeline_default",
    "bigquery_default",
    "local_mysql",
    "presto_default",
    "hive_cli_default",
    "hiveserver2_default",
    "metastore_default",
    "mysql_default",
    "postgres_default",
    "sqlite_default",
    "http_default",
    "mssql_default",
    "vertica_default",
    "webhdfs_default",
    "ssh_default",
    "airflow-logs",
]


sesh = Session()

# nuke existing data: first every connection whose id is not listed in
# AIRFLOW_CONNECTIONS, then every pool. Building the Query objects runs no SQL;
# each .all() executes in this order, matching the original statement sequence.
_stale_queries = (
    sesh.query(Connection).filter(Connection.conn_id.notin_(AIRFLOW_CONNECTIONS)),
    sesh.query(Pool),
)
for _stale_query in _stale_queries:
    for _stale_row in _stale_query.all():
        sesh.delete(_stale_row)

sesh.commit()

# setup new data
connections = get_connections()

# Roles for the cohesion dbs, used to trigger the open_rds_instance_permission
# task. A few cohesion connections need a dedicated role; any other connection
# id containing 'cohesion' falls back to the generic 'cohesion' role.
COHESION_ROLES = {
    'cohesion-friends': 'friends',
    'cohesion-following': 'app',
    'cohesion-follow-games': 'cohesion',
    'cohesion-chat': 'chat',
}

# One lookup per connection id, so each id gets exactly one role. (A chain of
# independent `if`s ending in `else` lets the fallback overwrite the specific
# roles for every id except the one tested last.)
# NOTE(review): assumes connections['aws'] iterates connection-id strings.
roles = {
    conn: COHESION_ROLES.get(conn, 'cohesion')
    for conn in connections['aws']
    if 'cohesion' in conn
}

# Midnight (local, naive) of the day this script runs, taken from a single clock
# read. Reading year/month/day through three separate datetime.now() calls can
# mix fields from two days if the clock crosses midnight mid-expression, and
# computing it once keeps every connection on the same start date.
start_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)

# For each SQL connection: register the Connection, its dump pool (plus a
# wal-e fetch pool for non-RDS hosts) and a "<conn_id>_defaults" Variable.
for conn_id, conn_params in connections['sql'].items():
    conn = Connection(conn_id=conn_id,
                      conn_type=conn_params.get('conn_type', 'Postgres'),
                      host=conn_params['host'],
                      login=conn_params['user'],
                      password=conn_params['password'],
                      schema=conn_params['schema'],
                      port=conn_params['port'])
    sesh.add(conn)

    # non rds hosts have to perform the wal-e backup fetch which has a separate queue
    if '.rds.' not in conn.host:
        wal_e_pool = Pool(pool=conn_id + "-wal-fetch",
                          slots=8,
                          description="to limit number of wal-e backup fetches on a particular worker host")
        sesh.add(wal_e_pool)

    pool = Pool(pool=conn_id,
                slots=conn_params['slots'],
                description="to limit number of table dumps on a particular postgres host")
    sesh.add(pool)

    default_value = {
        'owner': 'd8a',
        'depends_on_past': False,
        'email': ['d8a@twitch.tv'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 10,
        'retry_delay': str(timedelta(minutes=5)),
        'start_date': str(start_date),
        'max_active_runs': 6,
        'schedule_interval': str(timedelta(hours=12)),
        'output': ['csv', 'parquet'],
    }

    # only cohesion connections have an entry in roles
    if conn_id in roles:
        default_value['role'] = roles[conn_id]

    # NOTE(review): the value is json.dumps()-ed here AND serialize_json=True
    # makes Variable.set encode it again, so the stored variable is a JSON
    # string containing JSON. Readers presumably decode twice; kept as-is so
    # existing consumers keep working — confirm before dropping one encoding.
    Variable.set(key=conn_id + "_defaults", value=json.dumps(default_value), serialize_json=True)

sesh.commit()

# make s3 connection for logs: delete every existing "airflow-logs" row and
# commit, then insert a fresh S3 connection (only conn_id/conn_type are set
# here) and commit again.
log_conn = Connection(conn_id="airflow-logs", conn_type="S3")

for stale_log_conn in sesh.query(Connection).filter_by(conn_id=log_conn.conn_id).all():
    print("removing log conn: {}".format(stale_log_conn.conn_id))
    sesh.delete(stale_log_conn)
sesh.commit()

sesh.add(log_conn)
print("adding {}".format(log_conn.conn_id))
sesh.commit()
