import socket
import os
import logging
from subprocess import CalledProcessError
from time import sleep
from typing import Dict, Any
from airflow.models import Connection
from airflow.settings import Session
from sqlalchemy import and_
from zenyatta.common import get_airflow_connection, get_docker_metadata, get_work_directory
from zenyatta.common.util import run_command
from zenyatta.common.errors import ZenyattaError
from zenyatta.db.sql import PostgresSQL


class DockerContainer:
    """Minimal description of a Docker container that runs on some host.

    :param name: container name (used by the helpers in this module for
        ``docker run --name`` / ``docker stop`` / ``docker rm``)
    :param conn_id: airflow.models.Connection.conn_id associated with the container
    :param host: host the container runs on; required
    :param port: port of the container; stored as a string
    :raises ValueError: if ``host`` is empty or None
    """

    def __init__(self, name=None, conn_id=None, host=None, port=None):
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if not host:
            raise ValueError("DockerContainer requires a host")
        self.conn_id = conn_id
        self.host = host
        self.name = name
        # NOTE: str(None) == 'None', so callers are expected to pass a port.
        self.port = str(port)


class PostgresContainer(DockerContainer):
    """A Postgres server run with Docker on the host of an airflow connection.

    The container's host comes from the airflow connection ``conn_id``; the
    container is named after ``conn_id`` and its port is fixed at 5432.

    :param data_directory: host directory mounted into the container as PGDATA
    :param version: Postgres version string used to build in-container paths
    :param memory: value passed to ``docker run --memory``
    :param repository: image reference passed to ``docker run``
    :param driver: object exposing ``create_sql_engine(host=..., port=...)``;
        defaults to ``PostgresSQL(conn_id)``
    :param conn_id: airflow.models.Connection.conn_id
    :param kwargs: forwarded to DockerContainer.__init__ (which only accepts
        name/conn_id/host/port); host, port, name and conn_id are always
        overridden from the connection
    """

    def __init__(self, data_directory, version='9.5', memory='8g', repository=None, driver=None,
                 conn_id=None, **kwargs):

        # host/port/name/conn_id are derived from the airflow connection and
        # override anything the caller passed in kwargs
        connection = get_airflow_connection(conn_id)
        kwargs.update({'host': connection.host,
                       'port': 5432,
                       'name': conn_id,
                       'conn_id': conn_id})
        super().__init__(**kwargs)
        self.schema = connection.schema
        self.data_directory = data_directory
        self.version = version
        self.memory = memory
        self.repository = repository
        self.driver = driver if driver else PostgresSQL(self.conn_id)

    def create_sql_engine(self):
        """Return an engine from ``self.driver`` pointed at this container's host and port."""
        return self.driver.create_sql_engine(host=self.host, port=self.port)

    def pg_isready(self, attempts: int = 0, max_attempts: int = 50, delay: float = 120) -> bool:
        """Naive version of pg_isready: poll the container until a query succeeds.

        Each attempt builds an engine with ``create_sql_engine`` and lists its
        tables. Any exception counts as "not ready yet" and is followed by a
        ``delay``-second sleep. Attempts are numbered from ``attempts`` up to and
        including ``max_attempts``.

        :param attempts: number of the first attempt (kept for compatibility with
            the former recursive implementation)
        :param max_attempts: highest attempt number tried before giving up
        :param delay: seconds to sleep after a failed attempt
        :return: True once postgres answers
        :raises ZenyattaError: if postgres never became available
        """
        # Iterative instead of recursive: the recursive version discarded the
        # result of the retry and returned False even when postgres came up.
        for _ in range(attempts, max_attempts + 1):
            logging.info("checking {} for availability".format(self.name))
            db_engine = self.create_sql_engine()
            try:
                tables = db_engine.table_names()
            except Exception as e:
                # deliberately broad: any connection/driver error means "not ready yet"
                logging.info("while waiting for postgres caught {}, sleeping for {} seconds"
                             .format(str(e), delay))
                sleep(delay)
                continue
            logging.info("postgres is ready: {}".format(tables))
            return True

        raise ZenyattaError("postgres never became available")

    def engine_string(self, host: str = 'localhost', port: str = None) -> str:
        """Return a ``postgresql://login:password@host:port/database`` URL.

        Login, password and database come from the airflow connection
        ``self.conn_id``. Login and password are percent-escaped so characters
        such as '@' or '/' do not break URL parsing.

        :param host: host to put in the URL
        :param port: port to put in the URL; defaults to this container's port
        :return: the URL string
        """
        from urllib.parse import quote_plus

        connection = get_airflow_connection(self.conn_id)
        if port is None:
            # previously produced "host:None", which is not a valid URL
            port = self.port
        return "postgresql://{login}:{password}@{host}:{port}/{database}".format(
            login=quote_plus(str(connection.login)),
            password=quote_plus(str(connection.get_password())),
            host=host,
            port=port,
            database=connection.schema)

    def run_container(self):
        """``docker run`` this container (detached) on ``self.host`` over ssh with sudo.

        Mounts ``self.data_directory`` as PGDATA and ``/opt/postgres/`` as the
        postgres config directory, and publishes port 5432.
        """
        run_command(["ssh", self.host, "sudo",
                     "docker",
                     "run",
                     "--detach",
                     "--init",  # Use an init process to reap WAL recovery processes
                     "--memory", self.memory,  # set a default here of 8g
                     "--name", self.name,
                     "-e", "PGDATA=/var/lib/postgresql/{}/main".format(self.version),
                     "-v", "{}:/var/lib/postgresql/{}/main".format(self.data_directory, self.version),
                     "-v", "/opt/postgres/:/etc/postgresql/{}/main/".format(self.version),
                     "-p", "5432:5432",
                     self.repository])

    def start_container(self):
        """``docker start`` an existing container named ``self.name`` on ``self.host``.

        :return: whatever ``run_command`` returns
        """
        return run_command(["ssh", self.host, "sudo",
                            "docker",
                            "start",
                            self.name])


def docker_login(connection: Connection, region: str = 'us-west-2') -> None:
    """Log the docker daemon on ``connection.host`` into the AWS elastic container registry.

    Runs ``aws ecr get-login`` locally. The docker login command it prints is
    then executed on the remote host via ssh, with sudo.

    :param connection: airflow connection whose ``host`` runs docker
    :param region: AWS region passed to ``aws ecr get-login``
    :return: None
    :raises ZenyattaError: if the remote login output does not contain 'Succeeded'
    """
    # NOTE(review): `aws ecr get-login` exists only in AWS CLI v1; v2 replaced it
    # with `get-login-password` + `docker login --password-stdin`. The printed
    # command passes the registry password as an argument, so it is visible in
    # the remote host's process list.
    login_command = run_command(["aws", "ecr", "get-login", "--region", region]).decode().split()
    login_results = run_command(['ssh', connection.host, "sudo"] + login_command).decode()
    if 'Succeeded' not in login_results:
        raise ZenyattaError("couldn't login to docker repo")


def docker_pull(connection: Connection, repository: str) -> None:
    """Pull the latest ``repository`` image on ``connection.host``.

    Runs ``docker pull`` over ssh with sudo. The command output is decoded but
    not used.

    :param connection: airflow connection whose ``host`` runs docker
    :param repository: image reference to pull
    :return: None
    """
    pull_cmd = ['ssh', connection.host, 'sudo', 'docker', 'pull', repository]
    raw_output = run_command(pull_cmd)
    _ = raw_output.decode()


def remove_container(container: DockerContainer, env: Dict[Any, Any]=None) -> None:
    """Stop and force-remove ``container`` on its host, best effort.

    Runs ``docker stop`` and then ``docker rm -f`` over ssh with sudo. A failure
    of either command (``CalledProcessError``) is logged and ignored, so this is
    safe to call when the container may not exist.

    :param container: container whose ``host`` and ``name`` identify it
    :param env: environment passed through to ``run_command``
    :return: None
    """
    stop = ["ssh", container.host, "sudo", "docker", "stop", container.name]

    # If we do not force removal, there's an edge case with docker when the host
    # is under load and the container's filesystem won't get removed. That
    # prevents a simple `docker rm container` from succeeding and thus prevents
    # new containers with the same name from ever being made.
    rm = ["ssh", container.host, "sudo", "docker", "rm", "-f", container.name]
    for command in (stop, rm):
        try:
            run_command(command, env=env)
        except CalledProcessError as e:
            # best effort: e.g. the container may already be stopped or absent
            logging.info("ignoring failure of {}: {}".format(command, e))


def remove_docker_postgres(master_conn_id: str, **kwargs) -> None:
    """Remove the postgres container for a TaskInstance and delete its PGDATA directory.

    Wrapper around remove_container that uses the airflow context to rebuild the
    container description. The data directory is computed the same way as in
    create_docker_postgres: ``get_work_directory(dag_id, ts_nodash)``.

    :param master_conn_id: conn_id in airflow metadatabase
    :param kwargs: airflow provided context for this task; ``ts_nodash`` and
        ``task_instance`` are read
    :return: None
    """
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    # PostgresContainer requires data_directory and does not accept extra
    # keywords such as ts_nodash; the previous call failed with a TypeError.
    data_directory = get_work_directory(task_instance.dag_id, ts_nodash)
    container = PostgresContainer(data_directory, conn_id=master_conn_id)
    remove_container(container=container)
    # cleanup after container
    # NOTE(review): this rm runs on the local machine, not over ssh on
    # container.host where create_docker_postgres mounts the directory — confirm intent.
    run_command(["sudo", "rm", "-rf", container.data_directory])


def _ssh_sudo(host: str, *command: str):
    """Run ``command`` with sudo on ``host`` via ssh and return run_command's result."""
    return run_command(["ssh", host, "sudo"] + list(command))


def create_docker_postgres(conn_id: str, **kwargs):
    """Create and start a postgres container for a specific TaskInstance.

    The target connection's conn_id is pulled from XCom:
    key ``<dag_id>-conn_id-<ts_nodash>``, task ``create-ec2-sql-<schema of conn_id>``.

    On that connection's host this function:
    - logs docker into the registry and pulls the image;
    - removes any stale container with the same name;
    - mounts ``/dev/xvdc`` on the work directory;
    - clears ``pg_replslot``;
    - copies ``/opt/postgres/*`` in and runs the container.

    NOTE: this does not wait for postgres to accept connections;
    ``PostgresContainer.pg_isready`` does that.

    :param conn_id: conn_id in airflow metadatabase; its schema is used to build
        the upstream task id
    :param kwargs: airflow provided context for this task; ``ts_nodash`` and
        ``task_instance`` are read
    :return: None
    """
    ts_nodash = kwargs.get('ts_nodash')
    task_instance = kwargs.get('task_instance')
    dag_id = task_instance.dag_id
    key = dag_id + '-conn_id-' + ts_nodash
    tmp_conn = get_airflow_connection(conn_id)
    task_ids = 'create-ec2-sql-' + tmp_conn.schema
    connection = get_airflow_connection(task_instance.xcom_pull(key=key, task_ids=task_ids))

    work_dir = get_work_directory(dag_id, ts_nodash)
    docker_meta = get_docker_metadata('aws')  # Dict[str, str]
    container = PostgresContainer(work_dir, conn_id=connection.conn_id, **docker_meta)
    # must login to docker registry as part of docker/aws workflow to use containers
    docker_login(connection)
    # pull the latest version of the image
    docker_pull(connection, container.repository)
    # pre-emptively remove container in case previous attempt failed
    remove_container(container)

    # make sure mount point is there
    _ssh_sudo(connection.host, "mkdir", "-p", work_dir)

    # mount ebs volume
    try:
        _ssh_sudo(connection.host, "mount", "/dev/xvdc", work_dir)
    except CalledProcessError as e:
        # this can fail if the drive is already mounted, and that's fine
        logging.info("mount of /dev/xvdc on {} failed, continuing: {}".format(work_dir, e))

    # The '*' below is not expanded locally; these commands rely on the remote
    # shell that ssh runs them through to expand it.
    # cleanup replica information if it exists
    _ssh_sudo(connection.host, "rm", "-rf", os.path.join(work_dir, "pg_replslot", "*"))
    # copy postgres conf files over
    _ssh_sudo(connection.host, "cp", "/opt/postgres/*", work_dir)
    # fire up container
    container.run_container()


def is_port_open(port):
    """Return True if ``port`` is free locally and not claimed by an airflow Connection.

    The port counts as open when both hold:
    - a TCP socket can be bound to ``('localhost', port)``;
    - no airflow Connection with host 'localhost' uses that port.

    The probe socket and the metadata session are closed before returning.

    :param port: TCP port number to check
    :return: bool
    """
    # the socket is closed on exit from the with-block instead of being leaked
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind(('localhost', port))
        except OSError:
            return False

    # make sure this port isn't taken by an existing connection
    sesh = Session()
    try:
        port_conflicts = sesh.query(Connection).filter(and_(Connection.port == port,
                                                            Connection.host == 'localhost')).all()
    finally:
        sesh.close()
    return not port_conflicts
