# coding: utf-8
import logging

import pymysql.cursors
import concurrent.futures

from typing import Collection, Optional, Callable, NoReturn, Iterable, Union

_logger = logging.getLogger(__name__)
_logger.addHandler(logging.NullHandler())


class MysqlHostCmdResult:
    """
    Результат выполнения одной команды на одной mysql-реплике.

    Существование экземпляра этого класса означает, что команда либо успешно выполнилась (вернув данные или нет),
    либо выполнить команду пытались, но не удалось (from_failed_connection).
    """

    @classmethod
    def from_failed_connection(cls, command, error, connection):
        return cls(command, data=tuple(), error=error, column_names=tuple(), host=connection.host, port=connection.port)

    def __init__(self, command, data, error, column_names, host, port):
        """
        :param command: выполненная команда ("SELECT ..." |
        :param data: список строк со значениями вида ((value1, value2, ...), (...), ...)
        :param error: None or Exception
        :param column_names: ("col1", "col2", ...)
        """
        self.command = command
        self.data = data
        self.error = error
        self.column_names = column_names if column_names else tuple()
        self.host = host
        self.port = port

        self._column_to_id = {column_names[i]: i for i in range(len(self.column_names))}

    def is_ok(self):
        return self.error is None

    def get_status_string(self):
        status = "OK" if self.is_ok() else f"ERROR: {type(self.error)} {self.error}"
        return status

    def has_data(self):
        # я не знаю, могут ли быть данные и ошибка одновременно
        # стоит ли сделать тут except: return False?
        return len(self.data) > 0 and len(self.data[0]) > 0

    def get_data(self, row: int, columns: Optional[Iterable[Union[int, str]]] = None,
                 as_dict=False) -> Union[dict, tuple]:
        c_ids = range(len(self.column_names))
        c_names = self.column_names

        if isinstance(columns, int):
            return self.data[row][columns]
        elif isinstance(columns, str):
            return self.data[row][self._column_to_id[columns]]
        elif columns is not None:
            c_ids = []
            c_names = []
            for c in columns:
                i = c if isinstance(c, int) else self._column_to_id[c]
                c_ids.append(i)
                c_names.append(self.column_names[i])

        d = (self.data[row][i] for i in c_ids)
        if as_dict:
            return dict(zip(c_names, d))
        return tuple(d)

    def data_iter(self, rows=None, columns=None, as_dict=False):
        if isinstance(rows, int):
            rows = (rows,)
        elif rows is None:
            rows = range(len(self.data))

        for row in rows:
            yield self.get_data(row, columns, as_dict)

    def __iter__(self):
        return self.data_iter()

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return f"{self.host}:{self.port} '{self.command}' {self.get_status_string()}"

    def __bool__(self):
        return self.is_ok()  # данных при этом может и не быть, это ок


class MysqlClusterResults:
    """
    Результаты выполнения команды на нескольких mysql-репликах
    """

    def __init__(self, command, results: Collection[MysqlHostCmdResult]):
        self.results = tuple(results)
        self.command = command

        self.destination_to_result = {(r.host, r.port): r for r in results}

    def is_ok(self):
        return len(self.results) > 0 and all(r.is_ok() for r in self.results)

    def get_failed(self):
        return [r for r in self.results if not r.is_ok()]

    def get_status_string(self):
        status = "ALL OK"
        if not self.is_ok():
            if len(self.results) > 0:
                status = "ALL HOSTS FAILED" if all(not r.is_ok() for r in self.results) else "SOME HOSTS FAILED"
            else:
                status = "EMPTY HOST SET"

        return status

    def __iter__(self):
        return iter(self.results)

    def __len__(self):
        return len(self.results)

    def __repr__(self):
        children_status = []
        for r in sorted(self.results, key=lambda x: f"{x.host}:{x.port}"):
            children_status.append(f"{r.host}:{r.port} {r.get_status_string()}")
        children_status = "; ".join(children_status)
        return f"'{self.command}': {self.get_status_string()} ({children_status})"

    def __bool__(self):
        return self.is_ok()


class MysqlException(Exception):
    def __init__(self, result: MysqlHostCmdResult):
        self.result = result

    def __repr__(self):
        return f"{self.result}"


def mysql_init_connection(host, port, defer_connect=True, **mysql_params):
    timeouts = {
        # pymysql не может с бесконечным connect_timeout, default 10
        "connect_timeout": mysql_params.get("connect_timeout", 10),
        "read_timeout": mysql_params.get("read_timeout", mysql_params.get("timeout", None)),
        "write_timeout": mysql_params.get("write_timeout", mysql_params.get("timeout", None)),
    }
    for k in ("host", "port", "defer_connect", "connect_timeout", "read_timeout", "write_timeout", "timeout"):
        mysql_params.pop(k, None)  # удаляются только из локальной копии mysql_params
    _logger.debug(f"connecting to mysql at {host}:{port}, defer_connect: {defer_connect}")
    conn = pymysql.connect(host=host, port=port, defer_connect=defer_connect, **timeouts, **mysql_params)
    return conn


def mysql_exec_cmd_raw(command, connection):
    _logger.debug(f"going to run command {command} on {connection.host}:{connection.port}")
    connection.ping(reconnect=True)

    column_names = None
    with connection.cursor() as cursor:
        if isinstance(command, str):
            _logger.debug("run sql [%s] on %s:%s", cursor.mogrify(command), connection.host, connection.port)
            cursor.execute(command)
        else:  # команда с параметрами
            _logger.debug("run sql [%s] on %s:%s", cursor.mogrify(*command), connection.host, connection.port)
            cursor.execute(*command)

        data = cursor.fetchall()
        if cursor.description is not None:
            column_names = tuple(x[0] for x in cursor.description)

        _logger.debug(f"command {command} finished on {connection.host}:{connection.port}; fetched data: {data}")
    return data, column_names


def mysql_exec_cmd_retriable(command, connection):
    try:
        return mysql_exec_cmd_raw(command, connection)
    except Exception as e:
        raise MysqlException(MysqlHostCmdResult.from_failed_connection(command, e, connection))


def mysql_exec_cmd(command, connection: pymysql.Connection,
                   retry_call: Optional[Callable] = None) -> MysqlHostCmdResult:
    """
    Всегда должна возвращать MysqlHostCmdResult с Exception (или дочерними классами) в result.error
    """

    # не хочу DictCursor, MysqlHostCmdResult ждет списка строк и колонки отдельно
    if connection.cursorclass is not pymysql.cursors.Cursor and connection.cursorclass is not pymysql.cursors.SSCursor:
        raise TypeError("Only Cursor or SSCursor classes supported for MysqlHostCmdResult")

    func = mysql_exec_cmd_retriable
    args = [command, connection]
    try:
        if callable(retry_call):
            data, column_names = retry_call(func, *args)
        else:
            data, column_names = func(*args)
    except MysqlException as e:
        return e.result

    return MysqlHostCmdResult(command=command, data=data, error=None, column_names=column_names,
                              host=connection.host, port=connection.port)


def mysql_exec_cmd_multi(command, connections: Collection[pymysql.Connection],
                         ready_callback: Optional[Callable[[MysqlHostCmdResult], NoReturn]] = None,
                         retry_call: Optional[Callable] = None) -> MysqlClusterResults:
    results = []
    if not connections:
        return MysqlClusterResults(command, results)

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(connections)) as executor:
        worker_to_connection = {}
        for connection in connections:
            worker = executor.submit(mysql_exec_cmd, command, connection, retry_call)
            worker_to_connection[worker] = connection

        # запускаем без таймаута - он уже должен быть задан в connection
        # если очень хочется общий таймаут - можно, но надо ловить исключения as_completed
        for worker in concurrent.futures.as_completed(worker_to_connection):
            connection = worker_to_connection[worker]
            try:
                res = worker.result()  # таймаут тут не имеет смысла, потому что возвращаются уже completed futures
            except Exception as e:
                # MysqlException быть не должно, exec_cmd перехватывает его и возвращает result
                assert not isinstance(e, MysqlException)
                res = MysqlHostCmdResult.from_failed_connection(command, e, connection)

            if callable(ready_callback):
                ready_callback(res)
            results.append(res)

    return MysqlClusterResults(command, results)


def mysql_cmd_on_result_ready(result, logger=_logger, level=logging.INFO, prefix="", fmt="%(prefix)s%(result)s"):
    logger.log(level, fmt, {"prefix": prefix, "result": result})


def mysql_cmd_on_retry(exception, attempt_number, raised_after, logger=_logger, level=logging.INFO, prefix="",
                       fmt="%(prefix)sretrying %(exception)s: attempt %(attempt_number)s,"
                           "raised_after %(raised_after)s"):
    logger.log(level, fmt, {"prefix": prefix, "exception": exception, "attempt_number": attempt_number,
                            "raised_after": raised_after})
