# standard library imports
from time import sleep
import logging
from typing import Dict, Any, Tuple
from datetime import datetime
from copy import deepcopy
from botocore.exceptions import ClientError
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DatabaseError
from zenyatta.aws import boto_client
from zenyatta.common import get_airflow_connection, xcoms


class RDSResource:
    """A named RDS resource (DB instance or Aurora cluster).

    Every operation is delegated to ``self.driver``; the driver decides which
    boto calls implement it for the specific kind of resource.
    """

    def __init__(self,
                 identifier: str = None,  # this is the identifier of the RDS instance itself
                 driver=None,
                 role_arn: str = None
                 ):
        """
        :param identifier: RDS identifier of the resource; also set on the driver
        :param driver: driver instance implementing the RDS operations; a bare
            RDSDriver is used when omitted
        :param role_arn: IAM role ARN, also set on the driver (used there when
            building boto clients)
        """
        self.identifier = identifier
        self.role_arn = role_arn
        if driver is not None:
            self.driver = driver
            self.driver.set_identifier(identifier)
            self.driver.set_role_arn(role_arn)
        else:
            # fall back to the base driver so calls land on its NotImplementedError stubs
            # instead of failing with
            # AttributeError: 'NoneType' object has no attribute <name of function not implemented>
            self.driver = RDSDriver(identifier=identifier, role_arn=role_arn)

    def destroy(self):
        """Destroy the RDS resource in AWS.

        :return: whatever the driver's ``destroy`` returns (its boto response(s))
        """
        return self.driver.destroy()

    def point_in_time_recovery(self, point_in_time: datetime, target):  # target: RDSResource
        """Perform a point in time recovery of this resource, creating new resources.

        :param point_in_time: point in time for the resource to be restored to
        :param target: RDSResource describing the resource to create
        :return: whatever the driver's ``point_in_time_recovery`` returns
        """
        return self.driver.point_in_time_recovery(point_in_time, target)

    def get_rds_metadata(self) -> Dict:
        """
        :return: the driver's description of this resource (something like one entry
            of boto3 rds ``describe_db_instances`` / ``describe_db_clusters``)
        """
        return self.driver.get_rds_metadata()

    def get_host_and_port(self) -> Tuple[str, str]:
        """Return the ``(host, port)`` endpoint reported by the driver."""
        return self.driver.get_host_and_port()

    def does_rds_host_exist(self) -> bool:
        """Return True if the driver can describe this resource.

        An empty result surfaces as IndexError and a missing resource as a
        botocore ClientError; both are reported as "does not exist".
        NOTE(review): every ClientError (throttling, access denied, ...) is treated
        as "does not exist" — consider checking the error code.
        """
        try:
            self.driver.get_rds_metadata()
        except IndexError:  # no instances returned
            return False
        except ClientError:  # raised when the instance doesn't exist
            return False
        # describing succeeded, so the host exists
        return True

    def get_rds_status(self) -> str:
        """Return the driver-reported status (same as ``status``)."""
        return self.driver.status()

    def security_groups(self):
        """Return the VPC security group ids attached to the resource."""
        return self.driver.security_groups()

    def subnet(self):
        """Return the DB subnet group of the resource."""
        return self.driver.subnet()

    def create_sql_engine(self, stream_results=True) -> Engine:
        """Build a SQLAlchemy engine connected to this resource's endpoint."""
        return self.driver.create_sql_engine(stream_results=stream_results)

    def wait_for_rds_status(self, desired_status: str) -> bool:
        """Block until the driver reports ``desired_status``."""
        return self.driver.wait_for_rds_status(desired_status)

    def reader_instance(self):
        """Return the identifier of the reader instance (Aurora only)."""
        return self.driver.reader_instance()

    def status(self):
        """Return the driver-reported status."""
        return self.driver.status()

    @staticmethod
    def db_instance_identifier(db_host: str):
        """Extract the DB instance identifier from an RDS host fqdn.

        The fqdn format is DBInstanceIdentifier.account-id.aws-region.amazonaws.com:port, e.g.
            discovery-staging-master-pitr-etl-20161129t200000.cifgffw7w2ar.us-west-2.rds.amazonaws.com:5432
        so the identifier is everything before the first period.
        """
        return db_host.split('.', 1)[0]


# Base class for the RDS drivers: holds the attributes and behavior they have in common.
class RDSDriver:
    """Base driver for RDS resources.

    Holds the identifier, SQL driver and IAM role shared by the concrete
    drivers. It implements the operations that can be expressed via others
    (``create_sql_engine``, ``wait_for_rds_status``, ``security_groups``).
    Everything that needs resource-specific boto calls raises ``NotImplementedError``.
    """

    def __init__(self, identifier=None, sql_driver=None, role_arn: str = None):
        self.identifier = identifier  # RDS identifier of the instance / cluster
        self.sql_driver = sql_driver  # object whose create_sql_engine builds the engine
        self.role_arn = role_arn  # IAM role passed to boto_client

    def set_identifier(self, value: str):
        self.identifier = value

    def set_role_arn(self, value: str):
        self.role_arn = value

    def get_boto(self):
        """Return the RDS client built by ``boto_client`` for ``self.role_arn``."""
        client, _ = boto_client('rds', self.role_arn)
        return client

    def create_sql_engine(self, stream_results=True) -> Engine:
        """Build a SQLAlchemy engine for this resource's endpoint via ``self.sql_driver``."""
        # these methods are implemented in zenyatta/db/sql.py
        host, port = self.get_host_and_port()
        return self.sql_driver.create_sql_engine(host=host, port=port, stream_results=stream_results)

    def point_in_time_recovery(self, point_in_time: datetime, target: RDSResource):
        raise NotImplementedError()

    def destroy(self):
        raise NotImplementedError()

    def get_rds_metadata(self):
        raise NotImplementedError()

    def status(self):
        raise NotImplementedError()

    def subnet(self):
        raise NotImplementedError()

    def get_host_and_port(self) -> Tuple[str, str]:
        raise NotImplementedError()

    def reader_instance(self):
        raise NotImplementedError()

    def wait_for_rds_status(self, desired_status: str, poll_interval: float = 60,
                            max_attempts: int = None) -> bool:
        """Block until the resource reports ``desired_status``.

        The match is a substring test (``desired_status in status``). Any status
        containing ``'failed'`` aborts the wait.

        :param desired_status: status to wait for, e.g. ``'available'``
        :param poll_interval: seconds to sleep between polls (default 60)
        :param max_attempts: give up after this many non-matching polls;
            ``None`` (the default) waits indefinitely
        :return: True once the status matches
        :raises ValueError: if the status contains ``'failed'``
        :raises TimeoutError: if ``max_attempts`` polls pass without a match
        """
        attempts = 0
        while True:
            status = self.status()
            if desired_status in status:
                return True
            if 'failed' in status:
                raise ValueError("{self.identifier} failed, current info: {status} and not "
                                 "{desired_status}".format(self=self, status=status,
                                                           desired_status=desired_status))
            attempts += 1
            if max_attempts is not None and attempts >= max_attempts:
                raise TimeoutError("{} did not reach status {} after {} polls, last status: {}"
                                   .format(self.identifier, desired_status, attempts, status))
            # get this into the logs in case it doesn't come up for some odd reason
            logging.info("sleeping due to status: {status} waiting for "
                         "{desired_status}".format(status=status, desired_status=desired_status))
            sleep(poll_interval)

    def security_groups(self):
        """Return the ``VpcSecurityGroupId`` of every VPC security group in the metadata."""
        return [sg['VpcSecurityGroupId'] for sg in self.get_rds_metadata()['VpcSecurityGroups']]


class SingleInstanceDriver(RDSDriver):
    """Driver for a standalone (non-clustered) RDS DB instance.

    Postgres and MySQL RDS instances generally behave the same with respect to
    the boto RDS API, so this one driver covers both.
    """

    def get_rds_metadata(self) -> Dict:
        """Return this instance's entry from ``describe_db_instances``.

        An empty ``DBInstances`` list surfaces as IndexError; an unknown instance
        surfaces as the ClientError raised by boto.
        """
        client = self.get_boto()
        response = client.describe_db_instances(DBInstanceIdentifier=self.identifier)
        return response['DBInstances'][0]

    def status(self):
        """Return ``DBInstanceStatus`` from the metadata, or None when absent."""
        return self.get_rds_metadata().get('DBInstanceStatus')

    def subnet(self):
        """Return the name of the instance's DB subnet group."""
        subnet_group = self.get_rds_metadata()["DBSubnetGroup"]
        return subnet_group["DBSubnetGroupName"]

    def get_host_and_port(self) -> Tuple[str, str]:
        """Return the instance endpoint as ``(address, port)``.

        NOTE(review): the port is passed through unchanged from ``Endpoint.Port``.
        boto documents it as an integer despite the ``Tuple[str, str]`` annotation.
        """
        endpoint = self.get_rds_metadata()['Endpoint']
        return endpoint['Address'], endpoint['Port']

    def destroy(self):
        """Delete the DB instance without a final snapshot.

        :return: the raw ``delete_db_instance`` response
        """
        client = self.get_boto()
        response = client.delete_db_instance(DBInstanceIdentifier=self.identifier,
                                             SkipFinalSnapshot=True)
        logging.info("destroyed {self.identifier} with response: {resp}".format(self=self, resp=response))
        return response

    def point_in_time_recovery(self, point_in_time, target):
        """Restore this instance as of ``point_in_time`` into a new instance ``target``.

        The restored instance is placed in this instance's subnet group. It is then
        given this instance's VPC security groups and a backup retention of 0, and
        each step waits for ``target`` to become available.

        :param point_in_time: time to restore to
        :param target: RDSResource naming the instance to create
        :return: ``(restore response, modify response)``
        """
        client = self.get_boto()
        source_subnet_group = self.subnet()
        recovery_resp = client.restore_db_instance_to_point_in_time(
            SourceDBInstanceIdentifier=self.identifier,
            TargetDBInstanceIdentifier=target.identifier,
            RestoreTime=point_in_time,
            DBSubnetGroupName=source_subnet_group,
            Tags=[{'Key': 'project', 'Value': 'zenyatta'}])
        logging.info("point in time recovery response: {recovery_resp}".format(recovery_resp=recovery_resp))
        target.wait_for_rds_status('available')

        # the boto client is only good for an hour and the restore can outlast it
        client = self.get_boto()
        modify_resp = client.modify_db_instance(
            DBInstanceIdentifier=target.identifier,
            BackupRetentionPeriod=0,  # no backups for these ephemeral dbs
            ApplyImmediately=True,
            VpcSecurityGroupIds=self.security_groups())
        logging.info("modify db response: {modify_resp}".format(modify_resp=modify_resp))
        target.wait_for_rds_status('available')
        return recovery_resp, modify_resp


class AuroraDriver(RDSDriver):
    """Driver for an Aurora DB cluster and its single ``<cluster>-reader`` instance.

    ``self.identifier`` is the DB cluster identifier.
    """

    # DB instance class of the reader created by point_in_time_recovery;
    # override in a subclass to use a different size
    READER_INSTANCE_CLASS = 'db.r3.8xlarge'

    def subnet(self):
        """Return the cluster's DB subnet group.

        ``describe_db_clusters`` reports ``DBSubnetGroup`` directly as the group name,
        unlike ``describe_db_instances``, which nests it in a dict.
        """
        return self.get_rds_metadata()["DBSubnetGroup"]

    def status(self):
        """Return the cluster ``Status`` from the metadata, or None when absent."""
        return self.get_rds_metadata().get('Status')

    def destroy(self):
        """Delete the reader instance (best effort) and then the cluster.

        Both deletions skip the final snapshot.

        :return: ``(instance delete response or {}, cluster delete response)``
        """
        boto = self.get_boto()
        reader = self.reader_instance()
        instance_resp = None
        try:
            instance_resp = boto.delete_db_instance(DBInstanceIdentifier=reader,
                                                    SkipFinalSnapshot=True)
            logging.info("deleted {} from aurora cluster {} with response: {}"
                         .format(reader, self.identifier, instance_resp))
            sleep(15)  # make sure delete happens before the cluster delete
        except ClientError as e:
            # it's possible the db doesn't exist but the cluster does
            logging.info("could not delete reader {} of aurora cluster {}: {}"
                         .format(reader, self.identifier, e))

        cluster_resp = boto.delete_db_cluster(DBClusterIdentifier=self.identifier, SkipFinalSnapshot=True)
        logging.info("deleted cluster: {} with {}".format(self.identifier, cluster_resp))
        return instance_resp or {}, cluster_resp

    def point_in_time_recovery(self, point_in_time: datetime, target):  # target: RDSResource w/ AuroraDriver
        """Restore this cluster as of ``point_in_time`` into ``target`` and add a reader.

        Steps:
            1. Restore the cluster with this cluster's security groups and subnet group.
            2. Set backup retention to 1 day.
            3. Create the ``<target>-reader`` instance in the new cluster.

        Steps 1 and 2 wait for the cluster to become available.

        :param point_in_time: time to restore to
        :param target: RDSResource naming the cluster to create
        :return: ``(restore response, modify response, create instance response)``
        """
        boto = self.get_boto()
        restore_resp = boto.restore_db_cluster_to_point_in_time(
            SourceDBClusterIdentifier=self.identifier,
            DBClusterIdentifier=target.identifier,
            RestoreToTime=point_in_time,
            VpcSecurityGroupIds=self.security_groups(),
            DBSubnetGroupName=self.subnet(),
            Tags=[{'Key': 'project', 'Value': 'zenyatta'}])
        logging.info("point in time recovery response: {}".format(restore_resp))
        target.wait_for_rds_status('available')

        # refresh boto client: it's only good for an hour and the wait can outlast it
        boto = self.get_boto()
        # minimum backups
        modify_response = boto.modify_db_cluster(DBClusterIdentifier=target.identifier,
                                                 ApplyImmediately=True,
                                                 BackupRetentionPeriod=1)
        target.wait_for_rds_status('available')
        logging.info("modified cluster: {}".format(modify_response))

        # setup reader, again with a fresh client after the wait
        boto = self.get_boto()
        reader = target.reader_instance()
        logging.info("adding db instance: {} to cluster {}".format(reader, target.identifier))
        create_resp = boto.create_db_instance(
            DBInstanceIdentifier=reader,
            DBInstanceClass=self.READER_INSTANCE_CLASS,
            Engine='aurora',
            DBSubnetGroupName=target.subnet(),
            DBClusterIdentifier=target.identifier)
        # TODO need to add something here to wait for new instance to be ready
        return restore_resp, modify_response, create_resp

    def get_host_and_port(self) -> Tuple[str, str]:
        """Return the cluster endpoint and port as strings.

        The port comes from the cluster metadata; '3306' is used when it is absent.
        """
        cluster = self.get_rds_metadata()
        host = cluster['Endpoint']
        port = str(cluster.get('Port', 3306))
        return host, port

    def get_rds_metadata(self, identifier: str = None) -> Dict:
        """Return the ``describe_db_clusters`` entry for a cluster.

        :param identifier: cluster to describe; defaults to ``self.identifier``
        """
        cluster_id = identifier if identifier is not None else self.identifier
        boto = self.get_boto()
        return boto.describe_db_clusters(DBClusterIdentifier=cluster_id)['DBClusters'][0]

    def reader_instance(self):
        """Return the identifier of the cluster's reader instance."""
        return self.identifier + "-reader"


def create_rds_instance_from_point_in_time(source: RDSResource, **kwargs: Dict[Any, Any]):
    """Do a PITR of ``source`` into a fresh ETL resource.

    This function is a callback inside a PythonOperator inside airflow. The
    target identifier is derived from ``source.identifier`` and ``ts_nodash`` and
    pushed as an xcom. Any existing resource with that identifier is deleted
    first.

    :param source: source of point in time recovery
    :param kwargs: kwargs provided by TaskInstance from airflow; ``ts``, ``ts_nodash``
        and ``task_instance`` are used
    :return: RDSResource for the restored target
    """
    ts_nodash = kwargs.get('ts_nodash')
    target = RDSResource(
        identifier=generate_etl_db_instance_identifier(source.identifier, ts_nodash),
        role_arn=source.role_arn,
        # copy so that RDSResource setting the target identifier leaves the source driver untouched
        driver=deepcopy(source.driver))

    task_instance = kwargs.get('task_instance')
    key = xcoms.XcomRDSIdKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}".format(key, target.identifier, task_instance.task_id))
    task_instance.xcom_push(key=key, value=target.identifier)

    if target.does_rds_host_exist():
        # it's much easier to just delete it and recreate it than attempt to figure out
        # what state it's in, etc
        try:
            deleted = delete_rds_resource(target)
        except ClientError as e:
            # cannot delete the host for some reason; the restore below will surface it
            logging.warning("could not delete host for some reason: {}".format(e))
        else:
            if not deleted:
                logging.warning("{} still exists after waiting for its deletion"
                                .format(target.identifier))

    cleaned_ts = kwargs.get('ts').split('.')[0]  # in case it's like this 2016-08-25T00:21:58.262914
    point_in_time = datetime.strptime(cleaned_ts, "%Y-%m-%dT%H:%M:%S")
    logging.info("point_in_time={} source_db={} destination_db={}"
                 .format(point_in_time, source.identifier, target.identifier))

    # drivers return different numbers of responses (2 for single instances, 3 for
    # Aurora), so don't unpack a fixed shape; the responses are logged by the drivers
    source.point_in_time_recovery(point_in_time, target)
    return target


def delete_rds_resource(resource: RDSResource, max_attempts: int = 20,
                        poll_interval: float = 120) -> bool:
    """Delete an ephemeral rds resource and wait until it no longer exists.

    The drivers skip the final snapshot because these are ephemeral instances.

    :param resource: metadata for rds resource
    :param max_attempts: number of existence checks that may be followed by a sleep
        (default 20)
    :param poll_interval: seconds to sleep between checks (default 120, i.e. 40 minutes max)
    :return: True if the resource is gone by the last check, False otherwise
    """
    logging.info("deleting rds instance: {}".format(resource.identifier))
    resource.destroy()

    for _ in range(max_attempts):
        if not resource.does_rds_host_exist():
            return True
        sleep(poll_interval)
    # one final check so a deletion that completes during the last sleep counts as success
    return not resource.does_rds_host_exist()


def _execute_permission_sql(db_engine, host: str, sql: str,
                            tolerate_database_error: bool = False) -> None:
    """Run one statement in its own transaction on ``db_engine``, logging the outcome.

    :param db_engine: SQLAlchemy engine to run the statement on
    :param host: RDS host, used only in log messages
    :param sql: statement to execute
    :param tolerate_database_error: when True, a sqlalchemy DatabaseError is logged
        and swallowed instead of re-raised
    :raises: any exception raised while executing, except a tolerated DatabaseError
    """
    with db_engine.begin() as db_connection:
        try:
            logging.info("{}: modifying {} with: {}".format(db_engine, host, sql))
            results = db_connection.execute(sql)
            logging.info("results of modification: {}".format(results))
        except DatabaseError as de:
            if not tolerate_database_error:
                logging.exception("failed to modify {} with {}".format(host, sql))
                raise
            logging.info("failed to modify {} with {} due to {} -- this is likely ok"
                         .format(host, sql, de))
        except Exception:
            # unknown failure: log it with a traceback and let the task fail
            logging.exception("failed to modify {} with {}".format(host, sql))
            raise


def open_rds_instance_permissions(source: RDSResource=None, role: str=None,
                                  conn_id: str=None, **kwargs: Dict[Any, Any]):
    """Open permissions for an airflow connection's login on the ETL copy of ``source``.

    Originally created for cohesion due to roles not being established for the rds_superuser.

    On the ETL database derived from ``source`` and ``ts_nodash`` it does two things:
        1: ``GRANT <role> TO <login>`` (a DatabaseError here is logged and tolerated)
        2: ``ALTER ROLE <login> SET statement_timeout=0``

    :param source: resource the ETL database was restored from
    :param role: database role to grant to the connection's login
    :param conn_id: airflow connection id whose ``login`` receives the grant
    :param kwargs: airflow context; ``ts_nodash`` and ``task_instance`` are used
    :raises ValueError: if ``source`` or ``role`` is missing
    """
    if source is None:
        raise ValueError("open_rds_instance_permissions requires a source RDSResource")
    if role is None:
        raise ValueError("open_rds_instance_permissions requires a role to grant")

    connection = get_airflow_connection(conn_id)
    ts_nodash = kwargs.get('ts_nodash')
    target = RDSResource(
        identifier=generate_etl_db_instance_identifier(source.identifier, ts_nodash),
        role_arn=source.role_arn,
        driver=deepcopy(source.driver))

    task_instance = kwargs.get('task_instance')
    key = xcoms.XcomRDSIdKey(task_instance, ts_nodash).get_key()
    logging.info("pushing xcom: {}/{} from task: {}".format(key, target.identifier, task_instance.task_id))
    task_instance.xcom_push(key=key, value=target.identifier)
    host, _ = target.get_host_and_port()

    # stream_results=False is essential: there are no results to stream, and psycopg
    # would otherwise make a cursor, which is an error for a grant/alter statement
    db_engine = target.create_sql_engine(stream_results=False)

    # NOTE(review): role and connection.login are interpolated into DDL unquoted, so
    # they must come from trusted configuration (DAG args / airflow connection).
    # TODO break each statement out into its own task
    grant_sql = "GRANT {} TO {}".format(role, connection.login)
    _execute_permission_sql(db_engine, host, grant_sql, tolerate_database_error=True)

    statement_timeout_sql = "ALTER ROLE {} SET statement_timeout=0".format(connection.login)
    _execute_permission_sql(db_engine, host, statement_timeout_sql)


def generate_etl_db_instance_identifier(identifier: str, ts_nodash: str, max_length: int = 63) -> str:
    """Name the etl instance for a source so that tasks can find their database.

    The result is ``<identifier>-etl-<ts_nodash>`` lowercased, with periods removed
    from ``ts_nodash``. When that exceeds ``max_length`` (63, the AWS identifier limit),
    the identifier is first cut to its first three hyphen-separated components. If it
    is still too long, the identifier is hard-truncated (dropping trailing hyphens).
    Any name that already fit within the limit is unchanged from earlier versions.

    :param identifier: string id of name of rds instance
    :param ts_nodash: a string date in the format 20160909T1200
    :param max_length: maximum identifier length (default 63)
    :return: string
    """
    # cannot have periods in the identifier name
    ts_nodash = ts_nodash.replace('.', '')
    suffix = "-etl-{}".format(ts_nodash)
    if len(identifier) + len(suffix) > max_length:
        # take the first 3 elements, because it will get the longest cohesion instance uniquely
        identifier = "-".join(identifier.split('-')[:3])
    if len(identifier) + len(suffix) > max_length:
        # still too long: hard-truncate, avoiding a trailing hyphen ("--etl-" is not allowed)
        identifier = identifier[:max(0, max_length - len(suffix))].rstrip('-')
    return "{}{}".format(identifier, suffix).lower()
