# coding: utf-8
import datetime
import json
import logging.handlers
import socket
import time
import direct.infra.mysql_manager.libs.juggler as dj
from typing import Dict, Any

import kazoo
from direct.infra.mysql_manager.libs.configs import DtMysqlFailoverManagerConfig, DtAllDBInstanceConfig
from direct.infra.mysql_manager.libs.failover import DtMysqlFailoverManager, HostsSelector, DtMysqlReplicaSet

from direct.infra.mysql_manager.libs.zookeeper import ZKClient, zk_jdumps, zk_generate_lock_data

# Module-level logger named after this module. A NullHandler is attached so the
# logger always has a handler of its own that discards records; actual output
# is left to whatever logging configuration the application installs.
_logger = logging.getLogger(__name__)
_logger.addHandler(logging.NullHandler())


class DtMysqlAutoFailover(ZKClient):
    """Detect a failed MySQL master of one instance and run an automatic failover.

    A single :meth:`run` call connects to ZooKeeper, tries to take the failover
    lock without blocking and, if the lock was obtained, pings the replica set,
    reports the result to juggler and calls ``DtMysqlFailoverManager.switchover()``
    when the master is considered dead and all prechecks pass.

    ``self.zk``, ``self.logger`` and ``zk_start()`` are not defined in this class;
    presumably they are provided by ``ZKClient`` -- confirm against the base class.
    """

    # Templates of juggler service names; "{instance}" is substituted in __init__.
    JUGGLER_SERVICE = {
        "working": "auto-failover.working.{instance}",
        "can_failover": "can_failover.{instance}"
    }
    # Templates of juggler event descriptions; "{instance}" is substituted in __init__.
    JUGGLER_DESC = {
        "working": "from: dt-mymgr-auto-failover",
        "can_failover":
            "from: dt-mymgr-auto-failover; check zkguard heartbeat, 'dt-mymgr {instance} show' output and mysql status: "
            "https://obs.direct.yandex-team.ru/storages/get_replication_schema_graph?instance={instance}"
    }
    # Initial content of the failover state node kept in ZooKeeper
    # (init_failover_state() additionally adds a "started_on" key).
    BASE_ZK_STATE: Dict[str, Any] = {
        "started": False,
        "start_ts": 0,
        "start_time": "",
        "finished": False,
        "finish_ts": 0,
        "finish_time": "",
    }

    def __init__(self, mymgr_config: DtMysqlFailoverManagerConfig, instance_config: DtAllDBInstanceConfig,
                 *args, **kwargs):
        """Store the configs and prepare juggler names; no connections are opened here.

        Args:
            mymgr_config: failover manager config (ZooKeeper paths, ping and cooldown settings).
            instance_config: config of the database instance being watched.
            *args: forwarded unchanged to the base class initializer.
            **kwargs: forwarded unchanged to the base class initializer.
        """
        super().__init__(*args, **kwargs)
        self.fqdn = socket.getfqdn()

        self.inst_conf = instance_config
        self.mymgr_conf = mymgr_config
        self.lock_data = zk_generate_lock_data()
        # Expose the formatted templates as attributes named
        # juggler_service_<key> and juggler_description_<key>.
        for k, v in self.JUGGLER_SERVICE.items():
            setattr(self, "juggler_service_" + k, v.format(instance=self.inst_conf.instance))
        for k, v in self.JUGGLER_DESC.items():
            setattr(self, "juggler_description_" + k, v.format(instance=self.inst_conf.instance))

        # All of these are filled in by run() / check_and_failover().
        self.lock = None
        self.mymgr = None
        self.mysql_rs_no_retries = None
        self.is_repl_alive = None
        self.master_fqdn = None

    def run(self) -> None:
        """Perform one failure-detection pass guarded by the ZooKeeper failover lock.

        When the lock cannot be taken immediately, nothing is checked and only
        the current lock contenders are logged. The "working" juggler event is
        sent only if check_and_failover() returned without raising.
        """
        self.zk_start()
        self.init_failover_state(force=False)
        self.lock = self.zk.Lock(self.mymgr_conf.failover_zk_lock_path, identifier=self.lock_data)

        self.mymgr = DtMysqlFailoverManager(self.mymgr_conf, self.inst_conf)
        self.mymgr.zk_client.zk_start()

        # Remember the master seen at start; is_master_failed() compares against it later.
        self.master_fqdn = self.mymgr.guard.get_master()

        # Separate replica set handle with a single try per command, used for pings.
        self.mysql_rs_no_retries = DtMysqlReplicaSet(
            instance_config=self.inst_conf, logger=self.logger,
            mysql_params=self.mymgr_conf.mysql_params, mysql_command_max_tries=1
        )

        if self.lock.acquire(blocking=False):
            try:
                self.logger.info("failover lock acquired, check if master failed ...")
                self.check_and_failover()

                dj.send_ok(host=self.fqdn, service=self.juggler_service_working,
                           description=self.juggler_description_working)
            finally:
                self.logger.info("release failover lock")
                self.lock.release()
        else:
            self.logger.info("can't acquire failover lock, current contenders: %s", self.lock.contenders())

    def check_and_failover(self) -> bool:
        """Collect liveness data, report failover readiness and switch over if needed.

        Returns:
            True when no failover was required/allowed or the switchover succeeded;
            False when a switchover was attempted and ``switchover()`` returned a falsy value.
        """
        self.is_repl_alive = self.get_availability_data()

        # the master must be present in this mapping too
        dead_repls = [host for host, alive in self.is_repl_alive.items() if not alive]
        self.logger.debug("repl liveliness status: %s; dead repls: %s", self.is_repl_alive, dead_repls)
        self.mymgr.prepare(new_master=self.mymgr.MASTER_AUTO, simple_switchover=False,
                           exclude_hosts=dead_repls)

        can_failover = self.failover_precheck() and self.mymgr.switchover_precheck()
        need_failover = self.is_master_failed()

        self.logger.info("can failover: %s; need failover: %s", can_failover, need_failover)
        dj.send_one(host=self.fqdn, service=self.juggler_service_can_failover,
                    status="OK" if can_failover else "CRIT",
                    description=self.juggler_description_can_failover)

        if need_failover and can_failover:
            # Reset the state node, then mark the failover as started before switching.
            self.init_failover_state(force=True)
            self.update_failover_state(started=True)
            if self.mymgr.switchover():
                self.update_failover_state(finished=True)
                return True
            return False
        return True

    @staticmethod
    def _need_more_pings(host_to_pings_stat, pings_window) -> bool:
        """Tell whether the per-host ping windows still have to be filled.

        Args:
            host_to_pings_stat: mapping host -> list of ping outcomes collected so far.
            pings_window: number of outcomes required per host.

        Returns:
            True if every host has fewer than ``pings_window`` outcomes, False if
            every host has at least that many.

        Raises:
            RuntimeError: there are no hosts at all, or only some of the hosts have
                a full window. This is an explicit raise rather than ``assert`` so
                the check is still performed when Python runs with ``-O``.
        """
        incomplete = {len(pings) < pings_window for pings in host_to_pings_stat.values()}
        if len(incomplete) != 1:
            raise RuntimeError(
                "inconsistent ping windows (window size %s): %s" % (pings_window, host_to_pings_stat)
            )
        return incomplete.pop()

    def get_availability_data(self) -> Dict[Any, bool]:
        """Ping every ready replica host repeatedly and classify each host as alive or dead.

        Each round runs ``SELECT NOW()`` through ``self.mysql_rs_no_retries``; a
        ping counts as successful when the host's result is truthy and
        ``has_data()`` is true. Rounds are repeated until every host has
        ``failover_pings_window`` outcomes.

        Returns:
            Mapping host -> True if the host has fewer than
            ``failover_min_failed_pings`` failed pings, else False.

        Raises:
            RuntimeError: ``ready_replica_hosts`` is empty or the hosts ended up
                with unevenly filled windows (see ``_need_more_pings``).
        """
        self.logger.debug("get availability data")
        pings_window = self.mymgr_conf.failover_pings_window
        host_to_pings_stat = {x: [] for x in self.inst_conf.ready_replica_hosts}

        need_more = self._need_more_pings(host_to_pings_stat, pings_window)
        while need_more:
            # A monotonic clock is used so that wall-clock adjustments cannot skew the duration.
            started_at = time.monotonic()
            cluster_results = self.mysql_rs_no_retries.exec("SELECT NOW()", raise_on=HostsSelector.EMPTY_SET,
                                                            require_data_on=HostsSelector.EMPTY_SET)
            ping_duration = time.monotonic() - started_at

            for (host, _), host_result in cluster_results.destination_to_result.items():
                host_to_pings_stat[host].append(bool(host_result and host_result.has_data()))

            need_more = self._need_more_pings(host_to_pings_stat, pings_window)
            if need_more:
                # Pace the rounds; no pause is needed once the windows are full.
                time.sleep(max(self.mymgr_conf.failover_pings_iteration_sleep - ping_duration, 0))

        self.logger.debug("pings stat: %s", host_to_pings_stat)
        repl_to_status = {host: pings.count(False) < self.mymgr_conf.failover_min_failed_pings
                          for host, pings in host_to_pings_stat.items()}

        return repl_to_status

    def is_master_failed(self) -> bool:
        """Return True if the master remembered by run() is marked dead in ``self.is_repl_alive``.

        Returns False without looking at the liveness data when the guard now
        reports a different master than the one seen at the start of run().

        Raises:
            KeyError: the remembered master has no entry in ``self.is_repl_alive``.
        """
        # if ready-made statistics were used instead, the master change check could be skipped
        cur_master = self.mymgr.guard.get_master()
        if self.master_fqdn != cur_master:
            self.logger.info("master changed from %s to %s, stop failure detection",
                             self.master_fqdn, cur_master)
            return False

        return not self.is_repl_alive[self.master_fqdn]

    def failover_precheck(self) -> bool:
        """Check the cooldown and datacenter conditions for an automatic failover.

        Returns:
            False if the previous failover started less than
            ``failover_cooldown_time`` seconds ago, or if fewer than
            ``auto_failover_min_alive_dc`` datacenters other than the master's
            contain an alive replica; True otherwise.
        """
        self.logger.info(f"starting failover prechecks for {self.inst_conf.instance}")
        current_state = self.get_failover_state()

        time_from_last_failover = time.time() - current_state["start_ts"]
        if time_from_last_failover < self.mymgr_conf.failover_cooldown_time:
            self.logger.info("last auto failover initiated recently, wait for some time")
            return False

        master_dc = self.inst_conf.get_dc(self.master_fqdn)
        alive_repls_dc = {self.inst_conf.get_dc(host) for host, alive in self.is_repl_alive.items() if alive}
        alive_repls_dc_except_master = alive_repls_dc - {master_dc}
        self.logger.debug("datacenters with alive replicas: %s; master dc: %s", alive_repls_dc, master_dc)

        if len(alive_repls_dc_except_master) < self.inst_conf.auto_failover_min_alive_dc:
            self.logger.info("only %s datacenters with alive replicas (except master's dc), can't do failover",
                             len(alive_repls_dc_except_master))
            return False

        self.logger.info("failover prechecks passed")
        return True

    def init_failover_state(self, force=False):
        """Make sure the failover lock path and state node exist in ZooKeeper.

        Args:
            force: when true, overwrite the state node with the initial state
                even if the node already existed.
        """
        self.logger.info(f"init failover nodes for {self.mymgr_conf.instance}")
        state = self.BASE_ZK_STATE.copy()
        state["started_on"] = None

        self.zk.ensure_path(self.mymgr_conf.failover_zk_lock_path)
        try:
            self.zk.create(self.mymgr_conf.failover_zk_state_path, zk_jdumps(state), makepath=True)
            self.logger.debug(f"new failover state: {state}")
        except kazoo.exceptions.NodeExistsError:
            # An existing node is kept as is unless force is set below.
            self.logger.info(f"already initialized at {self.mymgr_conf.failover_zk_state_path}")

        if force:
            self.logger.info(f"forced set failover state for {self.mymgr_conf.instance}")
            self.zk.set(self.mymgr_conf.failover_zk_state_path, zk_jdumps(state))
            self.logger.debug(f"new failover state: {state}")

    def get_failover_state(self):
        """Read the failover state node from ZooKeeper and return it parsed from JSON."""
        return json.loads(self.zk.get(self.mymgr_conf.failover_zk_state_path)[0])

    def update_failover_state(self, started=None, finished=None):
        """Read, modify and write back the failover state node.

        The write passes the version obtained by the read, so it is conditional
        on the node not having been modified in between.

        Args:
            started: if not None, stored as "started" together with the start
                timestamps; "started_on" is set to this host's fqdn.
            finished: if not None, stored as "finished" together with the finish timestamps.
        """
        state, zstat = self.zk.get(self.mymgr_conf.failover_zk_state_path)
        state = json.loads(state)
        self.update_base_state(state, started, finished)
        if started is not None:
            state["started_on"] = self.fqdn
        self.zk.set(self.mymgr_conf.failover_zk_state_path, zk_jdumps(state), zstat.version)

    @staticmethod
    def update_base_state(state, started=None, finished=None):
        """Update the start/finish fields of ``state`` in place.

        Args:
            state: mutable mapping holding the failover state.
            started: if not None, written to "started"; "start_ts"/"start_time" get the current time.
            finished: if not None, written to "finished"; "finish_ts"/"finish_time" get the current time.
        """
        # One clock reading feeds both representations so *_ts and *_time describe
        # the same instant. The *_time values are datetime objects; presumably
        # zk_jdumps knows how to serialise them -- TODO confirm.
        now_ts = time.time()
        now_dt = datetime.datetime.fromtimestamp(now_ts)
        if started is not None:
            state["started"] = started
            state["start_ts"] = now_ts
            state["start_time"] = now_dt
        if finished is not None:
            state["finished"] = finished
            state["finish_ts"] = now_ts
            state["finish_time"] = now_dt
