"""
Adapters
========

Adapters are used to communicate with external resources. In Mekansm's
case, with AWS's API and the Database.
"""

from abc import ABCMeta, abstractmethod
import json
import re

import boto3.session
import botocore.exceptions
import sqlalchemy
import sqlalchemy.exc

from mekansm.errors import ProfileNameNotFound


class BaseContext(object):
    """Interface that every data provider must implement.

    Concrete providers supply lookups for deployments and nodes; the
    methods declared here carry no behaviour of their own.
    """

    # NOTE: the ``__metaclass__`` attribute is only honoured by Python 2;
    # Python 3 ignores it when building the class.
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_deployment(self, deployment_id):  # pragma: no cover
        """Fetch a single raw deployment.

        :param int deployment_id: The deployment ID
        :return: A raw deployment dictionary
        :rtype: dict
        """
        return None

    @abstractmethod
    def get_deployments(self, filter_obj):  # pragma: no cover
        """Fetch every raw deployment matching a filter.

        :param dict filter_obj: A filter object
        :return: A list of raw deployment dictionaries
        :rtype: list
        """
        return None

    @abstractmethod
    def get_node(self, node_id):  # pragma: no cover
        """Fetch a single raw node.

        :param node_id: The ID of the node
        :return: A raw node dictionary
        :rtype: dict
        """
        return None


class DB(BaseContext):
    """Implements an adapter API for PostgreSQL.

    The SQLAlchemy engine is created lazily, the first time the ``engine``
    property is read.
    """

    def __init__(self, dbconf, engineconf):
        """Store the configuration; no connection is opened here.

        :param dict dbconf: Connection settings. The ``user``, ``password``,
            ``host``, ``port`` and ``database`` keys are used to build the
            connection string.
        :param dict engineconf: Extra keyword arguments forwarded to
            ``sqlalchemy.create_engine``.
        """
        self.dbconf = dbconf
        self.engine_conf = engineconf
        self._engine = None

    @property
    def engine(self):
        """Return the sqlalchemy engine for the database."""
        if self._engine is None:
            connstring = (
                "postgresql://{user}:{password}@"
                "{host}:{port}/{database}").format(**self.dbconf)
            self._engine = sqlalchemy.create_engine(
                connstring, **self.engine_conf)
        return self._engine

    def disconnect(self):
        """Release the connections in the pool.

        Nothing happens when the engine was never created: there is no
        pool to release, and building an engine just to dispose of it
        would be wasted work.

        :return: None
        :rtype: NoneType
        """
        if self._engine is not None:
            self._engine.dispose()

    @staticmethod
    def _insert_deployment(transaction, deployment_data):
        """Create a deployment in the context of a transaction.

        :param transaction: SQLAlchemy transaction
        :param dict deployment_data: raw data to be inserted
        :return: the full row that was actually inserted
        :rtype: dict
        """
        # NOTE(review): the column names come straight from the dictionary
        # keys and are interpolated into the statement; only the values are
        # bound parameters. Keys must never come from untrusted input.
        keys = list(deployment_data.keys())
        names = ", ".join('"{}"'.format(key) for key in keys)
        placeholders = ", ".join("%({})s".format(key) for key in keys)
        query = (
            "INSERT INTO deployments ({}) VALUES ({}) RETURNING *;"
        ).format(names, placeholders)
        row = transaction.execute(query, deployment_data).fetchone()
        return dict(row.items())

    @staticmethod
    def _insert_or_find_node(transaction, node_id, datacenter):
        """Finds a node, or create it if it doesn't exist, transactionally.

        :param transaction: SQLAlchemy transaction
        :param str node_id: the ID of the node
        :param str datacenter: The datacenter where the node is.
        :return: None
        :rtype: NoneType
        """
        query = """
            INSERT INTO nodes (id, datacenter)
            SELECT %(node_id)s, %(datacenter)s
            WHERE NOT EXISTS (SELECT 1 FROM nodes WHERE id = %(node_id)s);
            """
        transaction.execute(query, node_id=node_id, datacenter=datacenter)

    @staticmethod
    def _insert_node_status(transaction, node_id, deployment_id, status):
        """Insert a node status to the database, transactionally.

        :param transaction: SQLAlchemy transaction
        :param str node_id: The ID of the node
        :param str deployment_id: The ID of the deployment
        :param status: The status the node will have.
        :return: None
        :rtype: NoneType
        """
        query = """
        INSERT INTO node_status (node, deployment, status)
        VALUES (%(node_id)s, %(deployment_id)s, %(status)s)
        """
        transaction.execute(
            query, node_id=node_id, deployment_id=deployment_id, status=status)

    @staticmethod
    def _insert_node_history(transaction, node_id, deployment_id, status):
        """Insert a node history entry, in the context of a transaction.

        :param transaction: SQLAlchemy transaction
        :param str node_id: The ID of the node
        :param str deployment_id: The ID of the deployment
        :param str status: The status of the event.
        :return: None
        :rtype: NoneType
        """
        query = """
        INSERT INTO node_history (node, deployment, status)
        VALUES (%(node_id)s, %(deployment_id)s, %(status)s)
        """
        transaction.execute(
            query, node_id=node_id, deployment_id=deployment_id, status=status)

    def _set_node_status(self, transaction, node_id, datacenter, deployment_id,
                         status):
        """Update the status of a node, in the context of a transaction.

        This method will update all the necessary tables for a status
        change: the node itself (created when missing), its status row and
        its history.

        :param transaction: SQLAlchemy transaction
        :param str node_id: The ID of the node
        :param str datacenter: The datacenter the node belongs to.
        :param str deployment_id: The ID of the deployment.
        :param str status: The new status.
        :return: None
        :rtype: NoneType
        """
        self._insert_or_find_node(transaction, node_id, datacenter)
        self._insert_node_status(transaction, node_id, deployment_id, status)
        self._insert_node_history(transaction, node_id, deployment_id, status)

    def create_deployment(self, deployment_data, nodes):
        """Create a deployment, dealing with all the tables involved.

        Every node is registered with the ``InProgress`` status, in the
        same transaction as the deployment row.

        :param dict deployment_data: Raw data to be inserted; its ``id``
            key is used to link the nodes to the deployment.
        :param list nodes: A list of raw nodes to be associated; each one
            must provide the ``id`` and ``datacenter`` keys.
        :return: The raw deployment row that was actually inserted.
        :rtype: dict
        """
        with self.engine.begin() as transaction:
            dep_row = self._insert_deployment(transaction, deployment_data)
            for node_data in nodes:
                self._set_node_status(
                    transaction, node_data["id"], node_data["datacenter"],
                    deployment_data["id"], "InProgress")
            return dep_row

    def set_deployment_status(self, deployment_id, status):
        """Change the status of a deployment and all related tables.

        :param str deployment_id: The ID of the deployment.
        :param str status: The new status
        :return: The number of affected nodes
        :rtype: int
        """
        with self.engine.begin() as transaction:
            transaction.execute(
                "UPDATE deployments SET status=%(status)s "
                "WHERE id=%(deployment_id)s",
                status=status, deployment_id=deployment_id)
            # Update every node status of the deployment and record one
            # history entry per updated row, in a single statement.
            query = """
                WITH updated AS (
                  UPDATE node_status
                  SET status=%(status)s, update_time=now()
                  WHERE deployment=%(deployment_id)s
                  RETURNING id
                )
                INSERT INTO node_history (node, deployment, status)
                SELECT node, deployment, %(status)s
                  FROM node_status
                  WHERE id IN (SELECT id FROM updated);
                """
            res = transaction.execute(
                query, deployment_id=deployment_id, status=status)
            return res.rowcount

    def get_deployment(self, deployment_id):
        """Get a deployment object

        :param str deployment_id: The ID of the deployment to get
        :return: The raw dictionary or None if not found
        :rtype: dict or NoneType
        """
        query = "SELECT * FROM deployments WHERE id=%s"
        row = self.engine.execute(query, deployment_id).fetchone()
        if row is None:
            return None
        return dict(row.items())

    def get_deployments(self, filter_obj):
        """Get matching deployment objects.

        Recognised filter keys are ``owner``, ``repo`` and ``environment``
        (matched on the deployment itself) and ``node``, ``status`` and
        ``datacenter`` (matched through the node status rows).

        :param dict filter_obj: A dictionary with fields to filter.
        :return: A list of raw deployment objects that can be empty
        :rtype: list
        """
        where = []
        subwhere = []
        need_subjoin = False
        if "owner" in filter_obj:
            where.append("owner=%(owner)s")
        if "repo" in filter_obj:
            where.append("repository=%(repo)s")
        if "environment" in filter_obj:
            where.append("environment=%(environment)s")
        if "node" in filter_obj:
            subwhere.append("node=%(node)s")
        if "status" in filter_obj:
            subwhere.append("status=%(status)s")
        if "datacenter" in filter_obj:
            # The datacenter lives on the nodes table, so it needs a join.
            need_subjoin = True
            subwhere.append("nodes.datacenter=%(datacenter)s")
        if subwhere:
            subq = "SELECT node_status.deployment FROM node_status"
            if need_subjoin:
                subq += " INNER JOIN nodes ON node_status.node = nodes.id"
            subq += " WHERE "
            subq += " AND ".join(subwhere)
            where.append("id IN ({})".format(subq))
        query = "SELECT * FROM deployments"
        if where:
            query += " WHERE " + " AND ".join(where) + ";"
        return [
            dict(row.items()) for row in
            self.engine.execute(query, filter_obj).fetchall()
        ]

    def get_nodes(self, filter_obj):
        """Get matching node objects.

        Recognised filter keys are ``deployment`` and ``deployment_status``
        (matched on the node status rows) and ``owner``, ``repo`` and
        ``environment`` (matched on the related deployment). The result is
        always ordered by datacenter and node ID.

        :param dict filter_obj: A dictionary with fields to filter.
        :return: A list of raw node objects that can be empty
        :rtype: list
        """
        where = []
        extra_join = False
        if "deployment" in filter_obj:
            where.append("node_status.deployment=%(deployment)s")
        if "deployment_status" in filter_obj:
            where.append("node_status.status=%(deployment_status)s")
        if "owner" in filter_obj:
            extra_join = True
            where.append("deployments.owner=%(owner)s")
        if "repo" in filter_obj:
            extra_join = True
            where.append("deployments.repository=%(repo)s")
        if "environment" in filter_obj:
            extra_join = True
            where.append("deployments.environment=%(environment)s")
        query = ("SELECT DISTINCT nodes.* FROM node_status"
                 " INNER JOIN nodes ON node_status.node=nodes.id")
        if extra_join:
            query += (" INNER JOIN deployments "
                      "ON node_status.deployment=deployments.id")
        if where:
            query += " WHERE " + " AND ".join(where)
        # The ordering applies whether or not a filter was given, so that
        # unfiltered listings come back in the same order as filtered ones.
        query += " ORDER BY nodes.datacenter, nodes.id;"
        return [
            dict(row.items()) for row in
            self.engine.execute(query, filter_obj).fetchall()
        ]

    def get_node(self, node_id):
        """Get a node object

        :param str node_id: The ID of the node to get
        :return: The raw dictionary or None if not found
        :rtype: dict or NoneType
        """
        query = "SELECT * FROM nodes WHERE id=%s"
        row = self.engine.execute(query, node_id).fetchone()
        if row is None:
            return None
        return dict(row.items())

    def get_deployment_instance_status(self, deployment_id, node_id):
        """Get a node status object

        :param str deployment_id: The ID of the deployment to get
        :param str node_id: The ID of the node to get
        :return: The raw dictionary or None if not found
        :rtype: dict or NoneType
        """
        query = (
            "SELECT * FROM node_status "
            "WHERE deployment=%(deployment)s AND node=%(node)s;")
        # A single query both detects the "not found" case and fetches the
        # row, instead of counting first and selecting afterwards.
        rows = self.engine.execute(
            query, deployment=deployment_id, node=node_id).fetchall()
        if not rows:
            return None
        # A (deployment, node) pair is expected to have a single status row.
        assert 1 == len(rows)
        return dict(rows[0].items())

    def set_deployment_instance_status(self, deployment_id, node_id, status):
        """Set the status of a node within a deployment.

        The existing status row is updated; when there is none, a new row
        is inserted instead.

        :param str deployment_id: The ID of the deployment
        :param str node_id: The ID of the node
        :param str status: New status of the node
        :return: The row count reported for the executed statement
        :rtype: int
        """
        with self.engine.begin() as transaction:
            query = """
                WITH updated AS (
                  UPDATE node_status
                  SET status=%(status)s, update_time=now()
                  WHERE deployment=%(deployment)s AND node=%(node)s
                  RETURNING id
                )
                INSERT INTO node_status (node, deployment, status)
                SELECT %(node)s, %(deployment)s, %(status)s
                WHERE NOT EXISTS (SELECT * FROM updated)
                """
            res = transaction.execute(
                query, deployment=deployment_id, node=node_id, status=status)
            return res.rowcount

    def health_check(self):
        """Return the status of the database connection

        If the status is OK, return basic counts to verify DB health

        :return: A string starting with ``OK`` or ``CRITICAL``
        :rtype: str
        """
        query = """
        SELECT
            (SELECT COUNT(1) FROM deployments) AS deployments,
            (SELECT COUNT(1) FROM nodes) AS nodes;
        """
        try:
            row = self.engine.execute(query).fetchone()
        except sqlalchemy.exc.SQLAlchemyError as exc:
            return "CRITICAL {}".format(exc)
        return "OK Found {} deployments and {} nodes".format(
            row["deployments"], row["nodes"])


class AwsClient:
    """Implements a subset of the API using Code Deploy."""

    # EC2 instance IDs: "i-" followed by 8 or 17 lowercase alphanumerics.
    instance_re = re.compile(r'^i-([a-z0-9]{8}|[a-z0-9]{17})$')

    def __init__(self, session, profile_name):
        """Code Deploy API initializer.

        :param session: A boto3 session; the ``codedeploy`` and ``ec2``
            clients are created from it.
        :param profile_name: The AWS profile used for this session
        """
        self.session = session
        self.cd = session.client("codedeploy")
        self.ec2 = session.client("ec2")
        self.profile_name = profile_name

    def get_deployment(self, deployment_id):
        """Get deployment information from AWS

        Return None if not found

        :param str deployment_id: The Deployment ID
        :return: The raw deployment object, or None if not found
        :rtype: dict or NoneType
        """
        try:
            return self.cd.get_deployment(deploymentId=deployment_id)
        except botocore.exceptions.ClientError as exc:
            if exc.response["Error"]["Code"] in (
                "DeploymentDoesNotExistException",
                "InvalidDeploymentIdException"
            ):
                return None
            raise

    @classmethod
    def is_instance_id(cls, instance_id):
        """Is the argument a valid EC2 instance ID"""
        return cls.instance_re.match(instance_id) is not None

    def get_deployment_instance_status(self, deployment_id, node_id):
        """Get deployment instance status data from AWS

        Return None if not found

        :param str deployment_id: The Deployment ID
        :param str node_id: The Node ID; either an EC2 instance ID or a
            tag value that identifies the instance.
        :return: The raw node status object, or None if not found
        :rtype: dict or NoneType
        """
        if self.is_instance_id(node_id):
            instance_id = node_id
        else:
            # Not an instance ID: resolve it through the EC2 tag whose
            # value equals the node ID.
            try:
                resp = self.ec2.describe_tags(
                    Filters=[{"Name": "value", "Values": [node_id]}])
            except botocore.exceptions.ClientError:
                return None
            if not resp["Tags"]:
                return None
            # The tag value is expected to identify a single resource.
            assert len(resp["Tags"]) == 1
            instance_id = resp["Tags"][0]["ResourceId"]
        try:
            resp = self.cd.get_deployment_instance(
                deploymentId=deployment_id, instanceId=instance_id)
            return {
                "node": node_id,
                "deployment": deployment_id,
                "update_time": resp["instanceSummary"]["lastUpdatedAt"],
                "status": resp["instanceSummary"]["status"],
                "events": resp["instanceSummary"]["lifecycleEvents"]
            }
        except botocore.exceptions.ClientError as exc:
            if exc.response["Error"]["Code"] in (
                "DeploymentDoesNotExistException",
                "InstanceDoesNotExistException"
            ):
                # return None if not found, let the broker handle it
                return None
            raise

    def get_node_ids(self, filter_obj):
        """Return a list of node ids

        :param dict filter_obj: A dictionary with fields to filter; the
            ``deployment`` key is required.
        :return: A list of instance IDs; empty when the deployment does
            not exist.
        :rtype: list
        """
        kwargs = {"deploymentId": filter_obj["deployment"]}
        instances = []
        while True:
            try:
                res = self.cd.list_deployment_instances(**kwargs)
            except botocore.exceptions.ClientError as exc:
                # Compare with ``==``: ``in`` against a bare string would
                # be a substring test and match partial or empty codes.
                code = exc.response["Error"]["Code"]
                if code == "DeploymentDoesNotExistException":
                    return []
                raise
            instances.extend(res["instancesList"])
            if "nextToken" not in res:
                break
            kwargs["nextToken"] = res["nextToken"]
        return instances

    def get_hostnames_from_instance_ids(self, instance_ids):
        """Map instance IDs to the value of their ``Name`` tag.

        Any ``$ID`` placeholder found in the tag value is replaced by the
        instance ID without its first two characters.

        :param list instance_ids: The instance IDs to look up
        :return: A dictionary of instance ID to host name; instances
            without a ``Name`` tag are left out.
        :rtype: dict
        """
        if not instance_ids:
            # Nothing to look up; skip the API call instead of sending a
            # filter without values.
            return {}
        return {
            tag["ResourceId"]: tag["Value"].replace("$ID", tag["ResourceId"][2:])
            for tag in self.ec2.describe_tags(
                Filters=[{"Name": "resource-id", "Values": instance_ids}]
            )["Tags"]
            if tag["Key"] == "Name"
        }

    @staticmethod
    def decorate_instance_id(instance_id):
        """Return the part of ``instance_id`` after its last ``/``.

        A value without any ``/`` is returned unchanged.
        """
        return instance_id.rsplit("/", 1)[-1]

    def get_deployment_nodes_status(self, deployment_id):
        """Return the status of all instances in a deployment.

        Each returned status gets a ``node`` key holding the host name of
        the instance, or its instance ID when no host name is known.

        :param deployment_id: The deployment ID to analyze.
        :return: A list of nodes statuses.
        """
        node_ids = self.get_node_ids({"deployment": deployment_id})
        if not node_ids:
            return []
        kwargs = {"deploymentId": deployment_id, "instanceIds": node_ids}
        instances = []
        while True:
            res = self.cd.batch_get_deployment_instances(**kwargs)
            instances.extend(res["instancesSummary"])
            if "nextToken" not in res:
                break
            kwargs["nextToken"] = res["nextToken"]
        if not instances:
            return []
        # Only identifiers that look like EC2 instance IDs can be resolved
        # to a host name through the tags.
        short_ids = [
            self.decorate_instance_id(status["instanceId"])
            for status in instances
        ]
        iids = [iid for iid in short_ids if self.is_instance_id(iid)]
        hostnames = self.get_hostnames_from_instance_ids(iids)
        for status, instance_id in zip(instances, short_ids):
            status["node"] = hostnames.get(instance_id, instance_id)
        return instances

    def create_deployment(self, deployment_data):
        """Create a deployment in Code Deploy.

        :param dict deployment_data: Raw data to be inserted. The
            ``s3location`` value is split at its first ``/`` into bucket
            and key.
        :return: The deployment ID that was created
        """
        s3loctoks = deployment_data["s3location"].split("/", 1)
        # Fields of the deployment stored as JSON in the description.
        desc_json_fields = ["sha"]
        return self.cd.create_deployment(
            applicationName=deployment_data["application"],
            deploymentGroupName=deployment_data["group"],
            revision={
                "revisionType": "S3",
                "s3Location": {
                    "bucket": s3loctoks[0],
                    "key": s3loctoks[1],
                    "bundleType": deployment_data["bundletype"]
                }
            },
            deploymentConfigName=deployment_data["config"],
            description=json.dumps({
                field: deployment_data[field] for field in desc_json_fields
            })
        )["deploymentId"]

    @classmethod
    def from_profile_and_region(cls, profile, region):
        """Create a client from an AWS profile name and region.

        :param str profile: The AWS profile
        :param str region: The AWS region
        :return: a client bound to a new boto3 session
        :rtype: AwsClient
        """
        session = boto3.session.Session(region_name=region,
                                        profile_name=profile)
        return cls(session, profile)

    def health_check(self):
        """Return the status of the AWS client connection.

        The format is a tuple with the profile name and the status.
        Failures are reported in the status rather than raised.
        """
        name = "aws " + self.profile_name
        try:
            self.cd.list_deployments()
            return name, "OK"
        except botocore.exceptions.ClientError as exc:
            return name, "CRITICAL {}: {}".format(
                exc.response["Error"]["Code"],
                exc.response["Error"]["Message"])
        except botocore.exceptions.BotoCoreError as exc:
            # Failures raised before a service response exists (connection
            # or credential problems) are reported too.
            return name, "CRITICAL {}".format(exc)


class AwsBroker:
    """Dispatches calls over every configured Code Deploy account."""

    def __init__(self, config):
        """Build one client per profile that declares a region.

        :param dict config: Profile name to profile settings; profiles
            without a ``region`` key are skipped.
        """
        self.config = config
        clients = []
        for profile_name, value in config.items():
            if "region" not in value:
                continue
            clients.append(
                AwsClient.from_profile_and_region(
                    profile_name, value["region"]))
        self.accounts = clients

    def get_deployment(self, deployment_id):
        """Look a deployment up, account by account.

        :param str deployment_id: The Deployment ID
        :return: The raw deployment object, or None if not found
        :rtype: dict or NoneType
        """
        # The generator is lazy, so accounts after the first hit are
        # never queried.
        lookups = (
            acct.get_deployment(deployment_id) for acct in self.accounts)
        return next((hit for hit in lookups if hit is not None), None)

    def get_deployment_instance_status(self, deployment_id, node_id):
        """Look an instance status up, account by account.

        :param str deployment_id: The Deployment ID
        :param str node_id: The Node ID
        :return: The raw node status object, or None if not found
        :rtype: dict or NoneType
        """
        lookups = (
            acct.get_deployment_instance_status(deployment_id, node_id)
            for acct in self.accounts)
        return next((hit for hit in lookups if hit is not None), None)

    def create_deployment(self, deployment_data, nodes, db):
        """Create a deployment in Code Deploy and record it in the database.

        The account is picked by matching ``deployment_data["account"]``
        against the profile names; the ID of the created deployment is
        stored under ``deployment_data["id"]`` before the row is written.

        :param dict deployment_data: Raw data to be inserted.
        :param list nodes: A list of raw nodes to be associated
        :param db: The database adapter used to store the deployment
        :return: The raw deployment row that was actually inserted.
        :rtype: dict
        :raises ProfileNameNotFound: when no account matches
        """
        wanted = deployment_data["account"]
        target = next(
            (acct for acct in self.accounts if acct.profile_name == wanted),
            None)
        if target is None:
            raise ProfileNameNotFound(wanted)
        deployment_data["id"] = target.create_deployment(deployment_data)
        return db.create_deployment(deployment_data, nodes)

    def get_node_ids(self, filter_obj):
        """Return the node ids from the first account that has any.

        :param dict filter_obj: A dictionary with fields to filter.
        :return: A list of instance IDs that can be empty
        :rtype: list
        """
        assert {"deployment"} >= set(filter_obj), "Unhandled filters"
        answers = (acct.get_node_ids(filter_obj) for acct in self.accounts)
        return next((ids for ids in answers if ids), [])

    def get_deployment_nodes_status(self, deployment_id):
        """Return the status of all instances in a deployment.

        :param deployment_id: The deployment ID to analyze.
        :return: A list of nodes statuses.
        """
        answers = (
            acct.get_deployment_nodes_status(deployment_id)
            for acct in self.accounts)
        return next((statuses for statuses in answers if statuses), [])

    def health_check(self):
        """Return a dictionary of statuses for AWS connections
        """
        report = {}
        for acct in self.accounts:
            name, state = acct.health_check()
            report[name] = state
        return report


class Context(BaseContext):
    """Implements the current application context for adapters.

    This is a base interface for the business logic to access the
    underlying resources using adapters.
    """
    def __init__(self, config):
        """Create the database and AWS adapters.

        :param dict config: Application configuration; its ``database``
            entry is used to build the DB adapter. The AWS profiles are
            read from the botocore session configuration.
        """
        self.config = config
        self.db = self.get_db(config["database"])
        aws_config = botocore.session.Session().full_config["profiles"]
        self.aws = AwsBroker(aws_config)

    @staticmethod
    def get_db(dbconf):
        """Return a DB object given a database configuration.

        The optional ``engine`` entry holds the engine settings; it is
        taken out of a copy so that the caller's dictionary is left
        untouched and can be reused.

        :param dict dbconf: Database configuration as given to the initializer
        :return: a Context DB object
        :rtype: mekansm.adapters.DB
        """
        dbconf = dict(dbconf)
        engine_conf = dbconf.pop("engine", {})
        return DB(dbconf, engine_conf)

    def destroy(self):
        """Destroy associates context instances.

        :return: None
        :rtype: NoneType
        """
        self.db.disconnect()

    def create_deployment(self, deployment_data, nodes):
        """Create a deployment.

        :param dict deployment_data: Raw data to be inserted.
        :param list nodes: A list of raw nodes to be associated
        :return: The raw deployment row that was actually inserted.
        :rtype: dict
        """
        return self.aws.create_deployment(deployment_data, nodes, self.db)

    def get_deployment(self, deployment_id):
        """Return a raw deployment dictionary.

        The status reported by AWS is written to the database when it
        differs from the stored one. When AWS does not know the
        deployment, the stored row is returned as is.

        :param int deployment_id: The deployment ID
        :return: A raw deployment dictionary, or None if not found
        :rtype: dict or NoneType
        """
        raw = self.aws.get_deployment(deployment_id)
        known = self.db.get_deployment(deployment_id)
        if raw is None:
            # No AWS account knows this deployment: there is no status to
            # synchronise, so fall back to what the database has.
            return known
        status = raw["deploymentInfo"]["status"]
        if not known or status != known["status"]:
            self.set_deployment_status(deployment_id, status)
            # Read the row again to pick up the status just written.
            return self.db.get_deployment(deployment_id)
        return known

    def set_deployment_status(self, deployment_id, status):
        """Change the status of a deployment.

        :param str deployment_id: The ID of the deployment.
        :param str status: The new status
        :return: The number of affected nodes
        :rtype: int
        """
        return self.db.set_deployment_status(deployment_id, status)

    def get_deployments(self, filter_obj):
        """Return a list of raw deployment dictionaries

        :param dict filter_obj: A filter object
        :return: A list of raw deployment dictionaries
        :rtype: list
        """
        return self.db.get_deployments(filter_obj)

    def get_node_ids(self, filter_obj):
        """Return a list of node ids

        :param dict filter_obj: A filter object
        :return: A list of node ids
        :rtype: list
        """
        return self.aws.get_node_ids(filter_obj)

    def get_nodes(self, filter_obj):
        """Return a list of node objects

        :param dict filter_obj: A filter object
        :return: A list of node objects
        :rtype: list
        """
        return self.db.get_nodes(filter_obj)

    def get_node(self, node_id):
        """Return a raw node dictionary.

        :param node_id: The ID of the node
        :return: A raw node dictionary
        :rtype: dict
        """
        return self.db.get_node(node_id)

    def get_deployment_instance_status(self, deployment_id, node_id):
        """Return an instance status object. Update local cache from AWS.

        :param str deployment_id: The Deployment ID
        :param str node_id: The Node ID
        :return: The raw node status object, or None if not found
        :rtype: dict or NoneType
        """
        value = self.aws.get_deployment_instance_status(deployment_id, node_id)
        if value:
            self.db.set_deployment_instance_status(
                deployment_id, node_id, value["status"])
        return value

    def get_deployment_nodes_status(self, deployment_id):
        """Return the status of all instances in a deployment.

        :param deployment_id: The deployment ID to analyze.
        :return: A list of nodes statuses.
        """
        return self.aws.get_deployment_nodes_status(deployment_id)

    def health_check(self):
        """Return a dictionary of health information for db and aws conns.

        The AWS statuses are keyed by connection name; the database status
        is stored under the ``db`` key.
        """
        res = self.aws.health_check()
        res["db"] = self.db.health_check()
        return res

