import os
import logging
import shutil
import json
import sys
from urllib.request import urlopen
from urllib.error import URLError
from time import sleep
from typing import List, Any, Dict, Optional
from subprocess import check_output

import yaml
from sqlalchemy import and_
from airflow.models import Variable, Connection, TaskInstance
from airflow import settings

from zenyatta.aws.rds import RDSResource, SingleInstanceDriver, AuroraDriver
from zenyatta.aws import boto_resource
from zenyatta.aws.s3 import remove_s3_object
from zenyatta.common.errors import ZenyattaError
from zenyatta.common import xcoms


def update_aws_resource_variable(resource_key: str, aws_type: str, resource_id: Any,
                                 is_append: bool) -> Any:
    """
    Store a resource id under ``aws_type`` in the JSON Airflow Variable ``resource_key``.

    The Variable is read (an empty dict if it does not exist yet), modified and
    written back as JSON.

    :param resource_key: key of the Airflow Variable
    :param aws_type: entry inside the Variable to update
    :param resource_id: value to store
    :param is_append: when True, append ``resource_id`` to the entry's list
                      (created empty if missing); otherwise replace the entry
                      with ``resource_id``
    :return: the full Variable content that was written back
    """
    stored = Variable.get(resource_key, default_var={}, deserialize_json=True)
    entry = stored.setdefault(aws_type, [])
    if not is_append:
        stored[aws_type] = resource_id
    else:
        entry.append(resource_id)
    Variable.set(resource_key, stored, serialize_json=True)
    return stored


def get_aws_resource_variable(resource_key: str, aws_type: str) -> List[Any]:
    """
    Read the entry ``aws_type`` from the JSON Airflow Variable ``resource_key``.

    :param resource_key: key of the Airflow Variable
    :param aws_type: entry inside the Variable to read
    :return: the stored entry, or an empty dict when the entry (or the whole
             Variable) is absent
    """
    # NOTE(review): the fallback is ``{}`` even though the annotation says List and
    # update_aws_resource_variable seeds new entries with ``[]``; kept as-is because
    # callers may depend on the dict fallback — confirm before changing.
    variable_content = Variable.get(resource_key, default_var={}, deserialize_json=True)
    fallback = {}
    return variable_content.get(aws_type, fallback)


def parse_config_file(file_path: str) -> Optional[Dict]:
    """
    Parse a YAML file with ``yaml.safe_load``.

    :param file_path: full path of the YAML file
    :return: the parsed document, or None if the file could not be read or parsed
             (``safe_load`` also yields None for an empty file)
    :raises ZenyattaError: if ``file_path`` is not an existing regular file
    """
    if not os.path.isfile(file_path):
        raise ZenyattaError("could not locate hardware config file {}".format(file_path))
    try:
        with open(file_path) as f:
            return yaml.safe_load(f)
    except Exception:
        # Best effort is kept (callers get None), but the failure is logged at
        # error level with the offending path and the traceback.
        logging.exception("could not load config file {}".format(file_path))
        return None


def get_task_instance(task_id: str, execution_date: str) -> Optional[TaskInstance]:
    """
    Look up a single TaskInstance by task id and execution date.

    :param task_id: airflow task id, e.g. 'create-ec2-sql-chat_depot'
    :param execution_date: task execution date as a string, e.g. '2017-06-29 16:00:00'
    :return: the matching TaskInstance (usable for xcom_pull / xcom_push), or None
             when the query fails — ``.one()`` raises if zero or several rows match
    """
    # Bound up front so a failed query returns None instead of raising
    # UnboundLocalError at the ``return`` below.
    task_instance = None
    try:
        sesh = settings.Session()
        task_instance = sesh.query(TaskInstance).filter(and_(
            TaskInstance.task_id == task_id,
            TaskInstance.execution_date == execution_date
        )).one()
    except Exception as e:
        logging.error("error query a task instance {0}".format(str(e)))
    return task_instance


def make_dir(dir_to_make: str) -> str:
    """
    Create an empty directory, wiping anything already at that path first.

    :param dir_to_make: full path of the directory (parents are created as needed)
    :return: ``dir_to_make``
    """
    already_there = os.path.exists(dir_to_make)
    if already_there:
        # ignore_errors: a failed wipe is not reported here; os.makedirs below
        # raises if the path still exists afterwards.
        shutil.rmtree(dir_to_make, ignore_errors=True)
    os.makedirs(dir_to_make)
    return dir_to_make


def run_command(command: List[str], env: Dict[Any, Any]=None) -> Any:
    """
    Log and run a command without a shell, returning its captured stdout.

    Because no shell is involved, a shell script passed as the program must
    start with a ``#!/bin/bash`` shebang.

    :param command: the program followed by its arguments
    :param env: environment for the child process; None inherits the current one
    :return: the command's stdout as bytes
    :raises subprocess.CalledProcessError: if the command exits non-zero
    """
    # Only variable names are logged: env frequently carries credentials
    # (passwords, access keys) that must not end up in task logs.
    env_keys = list(env) if env is not None else None
    logging.info("running command {} with env keys: {}".format(command, env_keys))
    return check_output(command, env=env)


def check_response(response: Dict) -> bool:
    """
    Check the HTTP status code of an AWS API response dict.

    :param response: response dict expected to carry
                     ``response['ResponseMetadata']['HTTPStatusCode']``
    :return: True when the status code is 200
    :raises ZenyattaError: for any other status code, including when
                           ``ResponseMetadata`` or the status code is missing
    """
    # ``or {}`` covers both an absent key and an explicit None, which previously
    # surfaced as an AttributeError instead of ZenyattaError.
    status_code = (response.get('ResponseMetadata') or {}).get('HTTPStatusCode')
    if status_code == 200:
        return True
    logging.error(status_code)
    raise ZenyattaError("failed emr API call HTTPStatusCode is %s" % str(status_code))


def dump_to_json(json_block: Any, file_name: str) -> str:
    """
    Serialize ``json_block`` into ``file_name`` as JSON indented by 4 spaces.

    The file is created or truncated.

    :return: ``file_name``
    """
    with open(file_name, 'w') as handle:
        json.dump(obj=json_block, fp=handle, indent=4)
    return file_name


def dump_to_file(content: str, file_name: str) -> str:
    """
    Write the string ``content`` to ``file_name``, creating or truncating it.

    :return: ``file_name``
    """
    with open(file_name, 'w') as out_file:
        out_file.write(content)
    return file_name


def _terminate_ec2_from_xcom(ec2: Any, task_instance: TaskInstance, ts_nodash: str):
    """Terminate the EC2 instance whose id this task pushed to XCom, if any (best effort)."""
    key = xcoms.XcomInstanceIdKey(task_instance, ts_nodash).get_key()
    ec2_id = task_instance.xcom_pull(key=key, task_ids=task_instance.task_id)
    if ec2_id is None:
        return
    logging.info("terminating {}".format(ec2_id))
    ec2_instance = ec2.Instance(ec2_id)
    try:
        resp = ec2_instance.terminate()
        logging.info("termination response: {}".format(resp))
        ec2_instance.wait_until_terminated()
        logging.info("terminated: {}".format(ec2_id))
    except Exception as e:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate.
        logging.info("failed to terminate: {}: {}".format(ec2_id, e))


def _delete_ebs_from_xcom(ec2: Any, task_instance: TaskInstance, ts_nodash: str):
    """Force-detach and delete the EBS volume whose id this task pushed to XCom, if any (best effort)."""
    key = xcoms.XcomVolumeIdKey(task_instance, ts_nodash).get_key()
    ebs_id = task_instance.xcom_pull(key=key, task_ids=task_instance.task_id)
    if ebs_id is None:
        return
    logging.info("terminating {}".format(ebs_id))
    ebs = ec2.Volume(ebs_id)
    try:
        # detaching an attached volume is odd
        # you often have to do it more than once, or wait a period or both
        logging.info("detaching {}".format(ebs_id))
        resp = ebs.detach_from_instance(Force=True)
        logging.info("detached resp: {}".format(resp))
        sleep(30)
        # now make sure that it's detached, this might raise but we don't care
        resp = ebs.detach_from_instance(Force=True)
        logging.info("detached resp: {}".format(resp))
    except Exception as e:
        logging.info("detach failed: {}".format(e))
    try:
        resp = ebs.delete()
        logging.info("deleted ebs: {}".format(resp))
    except Exception as e:
        logging.info("delete volume failed: {}".format(e))


def _destroy_rds_from_xcom(task_instance: TaskInstance, ts_nodash: str, role_arn: Any,
                           context: dict):
    """
    Destroy the RDS resource whose identifier this task pushed to XCom, if any (best effort).

    Tries the single-instance driver first and falls back to the Aurora driver.
    """
    key = xcoms.XcomRDSIdKey(task_instance, ts_nodash).get_key()
    rds_identifier = task_instance.xcom_pull(key=key, task_ids=task_instance.task_id)
    if rds_identifier is None:
        logging.warning("no rds identifier in xcom, context: {}".format(context))
        return
    resource = RDSResource(identifier=rds_identifier,
                           driver=SingleInstanceDriver(),
                           role_arn=role_arn)
    try:
        resource.destroy()
    except Exception as top_e:
        logging.info("couldn't destroy {} as single instance: {}".format(resource.identifier, top_e))
        resource.driver = AuroraDriver()
        try:
            resource.destroy()
        except Exception as nested_e:
            logging.warning("couldn't destroy {} as aurora: {}".format(resource.identifier, nested_e))


def cleanup_aws_from_xcom(context: dict):
    """
    Clean up AWS resources recorded in this task's XComs: EC2 instance, EBS volume, RDS.

    EC2/EBS cleanup needs 'role_arn' in the context and is skipped (with a warning)
    without it. RDS cleanup always runs, passing ``context.get('role_arn')``
    (possibly None). Every step is best effort: failures are logged, not raised.

    :param context: Airflow task context; reads 'task_instance', 'ts_nodash' and,
                    optionally, 'role_arn'
    """
    task_instance = context.get('task_instance')
    ts_nodash = context.get('ts_nodash')
    role_arn = context.get('role_arn')
    if 'role_arn' in context:
        ec2, _ = boto_resource('ec2', role_arn)
        _terminate_ec2_from_xcom(ec2, task_instance, ts_nodash)
        _delete_ebs_from_xcom(ec2, task_instance, ts_nodash)
    else:
        logging.warning("no role_arn in context: {}".format(context))
    _destroy_rds_from_xcom(task_instance, ts_nodash, role_arn, context)


def cleanup_connection_from_xcom(context: dict):
    """
    Delete the Airflow Connection rows whose conn_id(s) this task pushed to XCom.

    Best effort: failures are logged and the session is rolled back, nothing is raised.

    :param context: Airflow task context; reads 'task_instance' and 'ts_nodash'
    """
    task_instance = context.get('task_instance')
    ts_nodash = context.get('ts_nodash')
    key = xcoms.XcomDBConnIdKey(task_instance, ts_nodash).get_key()
    conn_id = task_instance.xcom_pull(key=key, task_ids=task_instance.task_id)
    if conn_id is None:
        return
    # ``in_`` needs a collection; a bare string is either rejected or iterated
    # character by character depending on the SQLAlchemy version.
    conn_ids = [conn_id] if isinstance(conn_id, str) else conn_id
    logging.info("removing connection objects {}".format(conn_id))
    sesh = None
    try:
        sesh = settings.Session()
        to_remove = sesh.query(Connection).filter(Connection.conn_id.in_(conn_ids)).all()
        for remove in to_remove:
            sesh.delete(remove)
        sesh.commit()
    except Exception as e:
        logging.info("failed to remove connection: {}".format(e))
        if sesh is not None:
            # leave the session usable for later callers after a failed flush/commit
            try:
                sesh.rollback()
            except Exception as rollback_e:
                logging.info("rollback failed: {}".format(rollback_e))


def cleanup_csv_dump_from_xcom(context: dict):
    """
    Delete the local CSV dump file whose path this task pushed to XCom.

    :param context: Airflow task context; reads 'task_instance' and 'ts_nodash'
    """
    task_instance = context.get('task_instance')
    ts_nodash = context.get('ts_nodash')
    key = xcoms.XcomCSVKey(task_instance, ts_nodash).get_key()
    csv_file = task_instance.xcom_pull(key=key, task_ids=task_instance.task_id)
    # xcom_pull returns None when nothing was pushed; os.path.exists(None)
    # would raise TypeError.
    if csv_file is not None and os.path.exists(csv_file):
        logging.info("deleting csv file {}".format(csv_file))
        os.remove(csv_file)
    else:
        logging.info("could not find {}".format(csv_file))


def _remove_s3_object_best_effort(bucket: Any, s3_path: Any, role_arn: Any):
    """Delete ``s3_path`` from ``bucket`` via remove_s3_object, logging instead of raising on failure."""
    try:
        remove_s3_object(bucket, s3_path, role_arn)
    except KeyError:
        # presumably remove_s3_object signals a missing object with KeyError — confirm
        logging.info("couldn't find {} on s3".format(s3_path))
    except Exception as e:
        logging.info("failed to remove {} {} on s3: {}".format(bucket, s3_path, e))


def cleanup_parquet_dump_from_xcom(context: dict):
    """
    Clean up parquet-related temporary files on the local disk and in S3.

    Local spark script / parquet output are removed if they exist. S3 objects are
    removed only when both the bucket and the role ARN were pushed to XCom.

    :param context: Airflow task context; reads 'task_instance' and 'ts_nodash'
    """
    task_instance = context.get('task_instance')
    ts_nodash = context.get('ts_nodash')

    def pull(key_cls):
        # read this task's XCom value for the key built by ``key_cls``
        key = key_cls(task_instance, ts_nodash).get_key()
        return task_instance.xcom_pull(key=key, task_ids=task_instance.task_id)

    bucket = pull(xcoms.XcomBucketKey)
    role_arn = pull(xcoms.XcomRoleArnKey)
    s3_script_path = pull(xcoms.XcomS3ScriptKey)
    local_script = pull(xcoms.XcomLocalScriptKey)
    local_parquet = pull(xcoms.XcomLocalPQKey)
    s3_parquet = pull(xcoms.XcomS3PQKey)

    if local_script is not None and os.path.exists(local_script):
        logging.info("deleting local spark script {}".format(local_script))
        os.remove(local_script)
    if local_parquet is not None and os.path.exists(local_parquet):
        logging.info("deleting local parquet output {}".format(local_parquet))
        shutil.rmtree(local_parquet, ignore_errors=True)

    if bucket is None or role_arn is None:
        logging.info("could not get valid variables to clean up")
        return
    if s3_script_path is not None:
        logging.info("deleting spark script on s3 {}".format(s3_script_path))
        _remove_s3_object_best_effort(bucket, s3_script_path, role_arn)
    if s3_parquet is not None:
        logging.info("deleting parquet output on s3 {}".format(s3_parquet))
        _remove_s3_object_best_effort(bucket, s3_parquet, role_arn)


def retrieve_http_response(url: str) -> Optional[Dict]:
    """
    Fetch ``url`` and decode its body as UTF-8 JSON.

    :param url: URL to fetch (any scheme ``urlopen`` supports)
    :return: the decoded JSON document, or None if the URL could not be opened
             (``URLError``, which includes ``HTTPError``)
    :raises ValueError: if the body is not valid UTF-8 or not valid JSON
    """
    try:
        # ``with`` closes the response (and its socket) even if reading fails.
        with urlopen(url) as response:
            data = response.read().decode("utf-8")
    except URLError as e:
        logging.info("URLlib error {0}".format(e))
        return None
    return json.loads(data)


def query_yes_no(question: str, default: str="yes") -> bool:
    """Ask a yes/no question on the console and return the answer as a bool.

    "question" is written to stdout, followed by a [y/n] hint.
    "default" is the answer used when the user just presses <Enter>; it must be
        "yes" (the default), "no", or None (an explicit answer is required).

    Answers are read with input() and compared case-insensitively against
    "yes"/"y"/"no"/"n"; anything else re-prompts. Raises ValueError for any
    other ``default``.
    """
    answers = {"yes": True, "y": True, "no": False, "n": False}
    if default is None:
        hint = " [y/n] "
    elif default in ("yes", "no"):
        hint = " [Y/n] " if default == "yes" else " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + hint)
        reply = input().lower()
        if reply == '' and default is not None:
            return answers[default]
        if reply in answers:
            return answers[reply]
        sys.stdout.write("Please respond with 'yes' or 'no' "
                         "(or 'y' or 'n').\n")
