import os
import json
import yaml
from datetime import timedelta, datetime
from typing import Dict, List
from dateutil import parser
import logging
from airflow.models import Variable, Connection
from airflow import settings
from airflow.settings import conf
from zenyatta.common.errors import ZenyattaError
from sqlalchemy import not_, and_
import zenyatta.common


def get_connections() -> Dict:
    """Load the puppet generated connections yaml file which contains the necessary info
    to generate the airflow.models.Connection objects.

    Candidate locations are probed in order and the first existing file wins:
    ``/etc/zenyatta/connections.yaml``, ``./connections.yaml``, ``../../connections.yaml``
    (the relative ones are resolved against the current working directory).

    :return: the parsed content of the first connections.yaml found
    :raises ZenyattaError: when none of the candidate files exists
    """
    candidate_paths = (
        '/etc/zenyatta/connections.yaml',
        './connections.yaml',
        '../../connections.yaml',
    )
    for path in candidate_paths:
        if not os.path.isfile(path):
            continue
        # the context manager closes the handle; the bare open() used before leaked it
        with open(path) as handle:
            # NOTE(review): this used to be yaml.load(stream) with no Loader, which is
            # deprecated and rejected outright by PyYAML >= 6. safe_load only builds plain
            # YAML types (dict/list/str/int/...); confirm the file carries no
            # python-specific tags.
            return yaml.safe_load(handle)
    raise ZenyattaError("could not find a connections.yaml in {}".format(os.getcwd()))


def get_docker_metadata(name: str) -> Dict[str, str]:
    """
    look up the entry called `name` under the 'docker' section of connections.yaml, e.g.
    {'repository': "1234.dkr.ecr.us-west-2.amazonaws.com/d8a/postgres-pitr"}

    :param name: key inside the 'docker' section
    :return: the entry stored under that key, or None when the key is absent
    :raises ZenyattaError: when the connections file or its 'docker' section cannot be read
    """
    try:
        all_conns = get_connections()
        docker_section = all_conns['docker']
        return docker_section.get(name)
    except Exception as err:
        # any failure while loading/looking up is reported as a single domain error
        raise ZenyattaError("couldn't load conns: {}".format(err))


def get_sql_metadata(name: str) -> Dict[str, str]:
    """
    look up the entry called `name` under the 'sql' section of connections.yaml, e.g.
    {'host': "tmi-postgres.justin.tv",
     'user': "backup",
     'password': "password",
     'schema': "chat_depot",
     'port': 5432,
     'slots': 3}

    :param name: key inside the 'sql' section
    :return: the entry stored under that key, or None when the key is absent
    :raises ZenyattaError: when the connections file or its 'sql' section cannot be read
    """
    try:
        all_conns = get_connections()
        sql_section = all_conns['sql']
        return sql_section.get(name)
    except Exception as err:
        # any failure while loading/looking up is reported as a single domain error
        raise ZenyattaError("couldn't load conns: {}".format(err))


def get_airflow_connection(conn_id: str) -> Connection:
    """
    fetch the airflow Connection row whose conn_id equals `conn_id`

    the lookup uses Query.one(), so it raises unless exactly one row matches
    :param conn_id: connection id to look up
    :return: the matching airflow.models.Connection
    """
    session = settings.Session()
    by_conn_id = session.query(Connection).filter_by(conn_id=conn_id)
    return by_conn_id.one()


def get_default_conns() -> List[str]:
    """
    conn_ids that the zenyatta connection queries in this module exclude
    :return: a new list on every call, so callers may modify it freely
    """
    default_conn_ids = [
        'airflow_db',
        'airflow_ci',
        'beeline_default',
        'bigquery_default',
        'local_mysql',
        'presto_default',
        'hive_cli_default',
        'hiveserver2_default',
        'metastore_default',
        'mysql_default',
        'postgres_default',
        'sqlite_default',
        'http_default',
        'mssql_default',
        'vertica_default',
        'webhdfs_default',
        'ssh_default',
        'airflow-logs',
        'fs_default',
        'aws_default',
        'spark_default',
        'emr_default',
    ]
    return default_conn_ids


def get_zenyatta_connections() -> List[Connection]:
    """
    return every Connection row whose conn_id is neither one of get_default_conns()
    nor contains the substring '-etl-'
    """
    excluded_ids = get_default_conns()
    session = settings.Session()
    not_a_default = Connection.conn_id.notin_(excluded_ids)
    not_an_etl = not_(Connection.conn_id.contains('-etl-'))
    matching = session.query(Connection).filter(and_(not_a_default, not_an_etl))
    return matching.all()


def get_zenyatta_base_connections() -> List[Connection]:
    """
    return every Connection row whose conn_id is not one of get_default_conns() and does
    not match the pattern "T[0-9]*$" under the SQL `~` operator
    """
    excluded_ids = get_default_conns()
    session = settings.Session()
    not_a_default = Connection.conn_id.notin_(excluded_ids)
    # NOTE(review): `~` is presumably the backend's regex-match operator (as on
    # PostgreSQL) - confirm for the metadata db in use
    no_t_suffix = not_(Connection.conn_id.op('~')("T[0-9]*$"))
    matching = session.query(Connection).filter(and_(not_a_default, no_t_suffix))
    return matching.all()


def get_work_directory(dag_id: str=None, ts_nodash: str=None, prefix_dir: str=None) -> str:
    """
    pick the work directory:
      - "/<prefix_dir>/<dag_id>/<ts_nodash>" when all three arguments are truthy
      - "/<dag_id>/<ts_nodash>" when only dag_id and ts_nodash are truthy
      - otherwise the [output] work_directory option of the airflow config when it is set
      - otherwise "/mnt"
    """
    if dag_id and ts_nodash:
        if prefix_dir:
            return "/{prefix_dir}/{dag_id}/{ts_nodash}".format(
                prefix_dir=prefix_dir, dag_id=dag_id, ts_nodash=ts_nodash)
        return "/{dag_id}/{ts_nodash}".format(dag_id=dag_id, ts_nodash=ts_nodash)

    # without both dag_id and ts_nodash, fall back to configuration, then the hard default
    if conf.has_option('output', 'work_directory'):
        return conf.get('output', 'work_directory')
    return "/mnt"


def get_dag_defaults(conn_id: str) -> (dict, dict):
    """
    return the default variable value and dag arguments of a given conn_id.
    it looks for conn_id+"_defaults" as the key in Variable and pulls out the value, which is set up
    by init_db.py.

    'retry_delay' and 'schedule_interval' values written as "HH:MM:SS" are converted to
    timedelta in place. a 'schedule_interval' that is not in that form is left untouched; a
    'retry_delay' that is not in that form is an error. 'start_date' (parsed with
    dateutil) and 'schedule_interval' are then moved out of the defaults into dag_args.

    :param conn_id: connection id whose "<conn_id>_defaults" Variable is read
    :return: (defaults, dag_args) - the remaining defaults and the extracted dag arguments
    :raises TypeError, ValueError: when 'retry_delay' is present but not an "HH:MM:SS" string
    """
    raw = Variable.get(conn_id+"_defaults", deserialize_json=True)
    # deserialize_json=True already decodes once. a second decode is only valid when that
    # first pass yields text, i.e. the stored value is JSON wrapped in a JSON string
    # (NOTE(review): presumably how init_db.py writes it - confirm). a value that is already
    # decoded is used as is instead of failing inside json.loads.
    defaults = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
    dag_args = {}
    # pre process
    for key in ('retry_delay', 'schedule_interval'):
        if key not in defaults:
            continue
        try:
            t = datetime.strptime(defaults[key], "%H:%M:%S")
        except (TypeError, ValueError):
            # only parse failures are tolerated, and only for schedule_interval; the former
            # bare except also swallowed KeyboardInterrupt/SystemExit
            if key != 'schedule_interval':
                raise
        else:
            defaults[key] = timedelta(hours=t.hour, minutes=t.minute, seconds=t.second)
    # start poppin
    if 'start_date' in defaults:
        dag_args['start_date'] = parser.parse(defaults.pop('start_date'))
    if 'schedule_interval' in defaults:
        dag_args['schedule_interval'] = defaults.pop('schedule_interval')
    return defaults, dag_args


def skip_table(table, filters={'killme_'}) -> bool:
    """
    tell whether a table should be left out when building a dag

    :param table: table name
    :param filters: prefixes that mark a table as skippable
    :return: True (after logging it) when the name starts with any prefix, else False
    """
    for prefix in filters:
        if table.startswith(prefix):
            logging.info("skipping {}".format(table))
            return True
    return False


def check_table_priority(conn_id: str, table: str) -> int:
    """higher priority for resource consuming tables

    :return: the value stored under "<conn_id>.<table>" in the 'priority_map' resource, or 1
             when the map is missing or has no such key
    """
    aws_key = 'priority_map'
    # NOTE(review): `util` is not bound by any import visible in this file, so this call
    # looks like it raises NameError - confirm where get_aws_resource_variable lives
    priority_map = util.get_aws_resource_variable(aws_key, aws_key)
    if priority_map is None:
        return 1
    lookup = '{}.{}'.format(conn_id, table)
    if lookup not in priority_map:
        return 1
    return priority_map[lookup]


def check_dag_output_format(default_args: dict) -> str:
    """
    inspect the 'output' entry of a dag's defaults, e.g. {'output': ['csv', 'parquet']}

    :return: 'no_output' when there is no 'output' key,
             'csv_parquet' when it holds exactly csv and parquet (in any order),
             'csv' / 'parquet' when it is exactly the one-element list,
             'unknown_format' for anything else
    """
    if 'output' not in default_args:
        return 'no_output'

    requested = default_args['output']
    # sorting makes the two-format check order independent
    if sorted(requested) == ['csv', 'parquet']:
        return 'csv_parquet'
    if requested == ['csv']:
        return 'csv'
    if requested == ['parquet']:
        return 'parquet'
    return 'unknown_format'
