# coding: utf-8
import copy
import dataclasses
import datetime
import enum
import logging
import json
import operator
import random
import socket
import time
import uuid
from collections import namedtuple

import kazoo
import kazoo.exceptions

from typing import Union, Optional, Set, Iterable, Tuple, List, Dict, Any

from direct.infra.mysql_manager.libs.configs import DtAllDBInstanceConfig, DtMysqlFailoverManagerConfig
from direct.infra.mysql_manager.libs.helpers import DEFAULT_RETRY_CONF, retry_callable, tcp_ping, get_json_value,\
    wait_until, wait_while
from direct.infra.mysql_manager.libs.mysql import mysql_init_connection, MysqlClusterResults, mysql_exec_cmd, \
    mysql_exec_cmd_multi, mysql_cmd_on_result_ready, mysql_cmd_on_retry, MysqlHostCmdResult
from direct.infra.mysql_manager.libs.zk_managers import DtMysqlGuardManager, DtDBConfigManager
from direct.infra.mysql_manager.libs.zookeeper import zk_make_default_acls, ZKClient, zk_jdumps, zk_check_acls,\
    zk_generate_lock_data
from direct.infra.mysql_manager.libs.gtid import GtidSet

# Module-level logger; classes below derive child loggers from it (or from a caller-supplied logger).
# The NullHandler keeps the library silent unless the embedding application configures logging.
_logger = logging.getLogger(__name__)
_logger.addHandler(logging.NullHandler())


class HostsSelector(enum.Enum):
    """Symbolic host selectors resolved by DtMysqlReplicaSet.parse_hosts_selector."""

    # resolves to an empty set of hosts
    EMPTY_SET = 1
    # resolves to every ready replica host of the instance
    ALL = 2


# Replication position of one replica, built from its SHOW SLAVE STATUS row.
HostSlavePos = namedtuple("HostSlavePos", [
    "host", "gtid_retrieved", "io_read", "gtid_executed", "sql_executed", "relay_pos", "sql_thread_state",
])
# gtid_executed of one host (SELECT @@global.gtid_executed).
HostMasterPos = namedtuple("HostMasterPos", ["host", "gtid_executed"])
# Identity of the master a replica replicates from; field names are SHOW SLAVE STATUS column names.
MasterId = namedtuple("MasterId", "Master_Host Master_Port Master_Server_Id Master_UUID")
# A replica host paired with the MasterId it replicates from.
HostToMasterId = namedtuple("HostToMasterId", ["host", "master_id"])


class DtMysqlReplicaSet:
    """Run MySQL commands on the ready replicas of a single DB instance.

    Every command is routed through host selectors (each a ``HostsSelector`` or an iterable of host names):

    * ``include_hosts`` / ``exclude_hosts`` -- where the command runs;
    * ``raise_on`` / ``ignore_errors_on`` -- on which hosts a failed command raises;
    * ``require_data_on`` / ``ignore_nodata_on`` -- on which hosts an empty result set raises.

    Per-call selectors are merged (set union) with ``global_hosts_selectors``, and every resulting
    set is restricted to ``instance_config.ready_replica_hosts``.
    """
    DEFAULT_HOSTS_SELECTORS = dict(
        include_hosts=HostsSelector.ALL, exclude_hosts=HostsSelector.EMPTY_SET,
        raise_on=HostsSelector.ALL, ignore_errors_on=HostsSelector.EMPTY_SET,
        require_data_on=HostsSelector.ALL, ignore_nodata_on=HostsSelector.EMPTY_SET,
    )

    def __init__(self, instance_config: DtAllDBInstanceConfig, logger=None, mysql_params=None,
                 mysql_command_max_tries=1, mysql_command_on_result_ready=None, mysql_command_on_retry=None,
                 **hosts_selectors):
        """
        :param instance_config: instance configuration; one connection object is prepared per host in its
            ``ready_replica_hosts``.
        :param logger: parent logger; a child named after this class is used.
        :param mysql_params: extra connection parameters merged over the super-user credentials
            (ignored unless it is a dict).
        :param mysql_command_max_tries: total tries per command (retries = tries - 1).
        :param mysql_command_on_result_ready: callback passed to ``mysql_exec_cmd_multi``.
        :param mysql_command_on_retry: error handler passed to ``retry_callable``.
        :param hosts_selectors: initial global host selectors.
        """
        self.logger = (logger or _logger).getChild(self.__class__.__name__)

        self.inst_conf = instance_config
        self.mysql_params = {
            "user": self.inst_conf.mysql_user_super,
            "password": self.inst_conf.mysql_pass_super,
        }
        if isinstance(mysql_params, dict):
            self.mysql_params.update(mysql_params)

        self.retry_conf = DEFAULT_RETRY_CONF
        self.mysql_command_on_retry = mysql_command_on_retry
        self.mysql_command_retry = retry_callable(self.retry_conf.upto_retries(mysql_command_max_tries - 1),
                                                  handle_error=mysql_command_on_retry)
        self.mysql_command_on_result_ready = mysql_command_on_result_ready

        # defer_connect=True: presumably no network connection is opened here -- TODO confirm
        self.host_to_connection = {
            host: mysql_init_connection(host, self.inst_conf.mysql_port, defer_connect=True, **self.mysql_params)
            for host in self.inst_conf.ready_replica_hosts
        }

        self.global_hosts_selectors = hosts_selectors

    @property
    def global_hosts_selectors(self):
        """Selectors applied to every command on top of the per-call ones (a dict, possibly empty)."""
        return self._global_hosts_selectors

    @global_hosts_selectors.setter
    def global_hosts_selectors(self, value):
        # any falsy value (None, {}) resets the global selectors to an empty dict
        self._global_hosts_selectors = value if value else {}

    def parse_hosts_selector(self, hosts: Optional[Union[HostsSelector, Iterable[str]]]) -> Set[str]:
        """Turn a selector into a fresh set of host names.

        :param hosts: ``HostsSelector.EMPTY_SET``, ``HostsSelector.ALL`` (all ready replica hosts), or an
            iterable of host names (a bare string is rejected rather than split into characters).
        :raises ValueError: for any other value, including None.
        """
        if hosts is HostsSelector.EMPTY_SET:
            return set()
        elif hosts is HostsSelector.ALL:
            return set(self.inst_conf.ready_replica_hosts)
        elif isinstance(hosts, Iterable) and not isinstance(hosts, str):
            return set(hosts)
        else:
            raise ValueError(f"Bad hosts selector {hosts}")

    def _prepare_selectors(self, **hosts_selectors):
        """Resolve per-call and global selectors into concrete host sets.

        :return: ``(hosts, raise_on, require_data_on)``: hosts to run on, hosts whose failure raises, hosts
            whose empty result raises. Each set is restricted to the ready replica hosts.
        """
        selectors = {}
        for selector_name, default_selector in self.DEFAULT_HOSTS_SELECTORS.items():
            selectors[selector_name] = self.parse_hosts_selector(hosts_selectors.get(selector_name, default_selector))

            # global selectors take priority over the ones passed as parameters: they are always added (union)
            selectors[selector_name].update(self.parse_hosts_selector(
                self.global_hosts_selectors.get(selector_name, HostsSelector.EMPTY_SET)
            ))

            self.logger.debug(f"prepared selector {selector_name}: {selectors[selector_name]}")

        all_hosts = self.parse_hosts_selector(HostsSelector.ALL)
        hosts = all_hosts.intersection(selectors["include_hosts"]).difference(selectors["exclude_hosts"])
        raise_on = all_hosts.intersection(selectors["raise_on"]).difference(selectors["ignore_errors_on"])
        require_data_on = all_hosts.intersection(selectors["require_data_on"]) \
            .difference(selectors["ignore_nodata_on"])
        return hosts, raise_on, require_data_on

    def exec(self, command, raise_if_all_excluded=True, **hosts_selectors) -> MysqlClusterResults:
        """Run ``command`` on the selected hosts and validate the results.

        :param command: SQL text, or ``[sql, params]`` (see get_processlist_remote_clients).
        :param raise_if_all_excluded: raise if the selectors leave no host to run on.
        :param hosts_selectors: per-call host selectors (keys of DEFAULT_HOSTS_SELECTORS).
        :raises Exception: when no host is selected (and ``raise_if_all_excluded``), when a host in
            ``raise_on`` failed, or when a host in ``require_data_on`` that did not fail returned no rows.
        """
        self.logger.debug(f"exec {command} with local selectors: {hosts_selectors}, "
                          f"global selectors: {self.global_hosts_selectors}")

        hosts, raise_on, require_data_on = self._prepare_selectors(**hosts_selectors)
        self.logger.debug(f"selectors: exec on {hosts}, raise_on {raise_on}, require_data_on {require_data_on}")

        if not hosts and raise_if_all_excluded:
            raise Exception("No hosts for execution")

        connections = [self.host_to_connection[h] for h in hosts]
        cluster_results = mysql_exec_cmd_multi(command, connections, ready_callback=self.mysql_command_on_result_ready,
                                               retry_call=self.mysql_command_retry)

        if len(cluster_results) == 0:
            assert not raise_if_all_excluded
            self.logger.debug("empty host set, not executed")
            return cluster_results

        failed_hosts = set()
        if not cluster_results:
            failed_hosts = set(r.host for r in cluster_results.get_failed())
            fatal_failures = failed_hosts.intersection(raise_on)
            if fatal_failures:
                raise Exception("Some hosts failed", fatal_failures)
            if failed_hosts:
                self.logger.debug(f"ignoring errors on {failed_hosts} as requested")

        # Check for empty data on the hosts that did not fail. This also runs when failures on other
        # hosts were ignored above. Previously the check was skipped then, and a host that returned
        # nothing (e.g. not a slave) could pass unnoticed.
        nodata_hosts = set(r.host for r in cluster_results
                           if r is not None and r.host not in failed_hosts and not r.has_data())
        missing_required = nodata_hosts.intersection(require_data_on)
        if missing_required:
            raise Exception("Some hosts returned empty data", missing_required)
        if nodata_hosts:
            self.logger.debug(f"ignoring nodata on {nodata_hosts} as requested")

        return cluster_results

    def exec_on_host(self, host, command, **kwargs) -> Optional[MysqlHostCmdResult]:
        """Run ``command`` on a single host.

        :return: that host's result, or None if the selectors excluded it
            (only reachable with ``raise_if_all_excluded=False``).
        """
        cluster_results = self.exec(command, include_hosts=[host], **kwargs)
        assert len(cluster_results) <= 1
        if len(cluster_results) == 0:
            return None
        return cluster_results.results[0]

    def check_mysql(self, host, port):
        """Run ``SELECT NOW()`` on a new connection to host:port.

        :return: True if ``mysql_exec_cmd`` returned a truthy result.
        """
        self.logger.info(f"check mysql availability at {host}:{port}")
        # NOTE(review): this connection is not closed here -- confirm mysql_exec_cmd releases it
        connection = mysql_init_connection(host, port, defer_connect=True, **self.mysql_params)
        return bool(mysql_exec_cmd("SELECT NOW()", connection))

    def get_hosts_slave_pos_sorted(self, **kwargs) -> List[HostSlavePos]:
        """Build a HostSlavePos for every host that returned a SHOW SLAVE STATUS row, in a deterministic order."""
        slaves_status = self.get_slave_status(**kwargs)
        slaves_pos = []
        for host, status in slaves_status.items():
            io_read = "%s:%s" % (status["Master_Log_File"], status["Read_Master_Log_Pos"])
            sql_executed = "%s:%s" % (status["Relay_Master_Log_File"], status["Exec_Master_Log_Pos"])
            relay_pos = "%s:%s" % (status["Relay_Log_File"], status["Relay_Log_Pos"])

            gtid_retrieved = GtidSet(status["Retrieved_Gtid_Set"], keep_sorted=True)
            gtid_executed = GtidSet(status["Executed_Gtid_Set"], keep_sorted=True)
            slaves_pos.append(HostSlavePos(host, gtid_retrieved, io_read, gtid_executed, sql_executed, relay_pos,
                                           status["Slave_SQL_Running_State"]))

        # this sort does NOT guarantee that [-1] is the most up-to-date slave; it only provides a stable order
        slaves_pos.sort(key=operator.itemgetter(1, 2, 3, 4, 5, 6, 0))
        return slaves_pos

    def get_slave_status(self, **kwargs):
        """Return ``{host: SHOW SLAVE STATUS row as dict}``; hosts with a failed or empty result are skipped."""
        cluster_results = self.exec("SHOW SLAVE STATUS", **kwargs)

        slaves_status = {}
        for (host, _), host_result in cluster_results.destination_to_result.items():
            if not host_result or not host_result.has_data():
                continue
            slaves_status[host] = host_result.get_data(row=0, as_dict=True)

        return slaves_status

    def get_hosts_master_pos_sorted(self, **kwargs) -> Tuple[List[List[HostMasterPos]], List[HostMasterPos]]:
        """Collect gtid_executed from the selected hosts and group replicas that have identical sets.

        Hosts with a failed or empty result are skipped.

        :return: First element -
        lists of replicas and positions sorted ascending by gtid_executed, e.g. (r1..r4 are replica fqdns):
        [
            [
                (r1, <GtidSet [<Gtid sid_r2:1-100>, <Gtid sid_r1:1-100>, <Gtid sid_r4:1-100>, <Gtid sid_r3:1-100>]>),
                (r3, <GtidSet [<Gtid sid_r2:1-100>, <Gtid sid_r1:1-100>, <Gtid sid_r4:1-100>, <Gtid sid_r3:1-100>]>),
            ],
            [
                (r2, <GtidSet [<Gtid sid_r2:1-100>, <Gtid sid_r1:1-100>, <Gtid sid_r4:1-107>, <Gtid sid_r3:1-100>]>),
                (r4, <GtidSet [<Gtid sid_r2:1-100>, <Gtid sid_r1:1-100>, <Gtid sid_r4:1-107>, <Gtid sid_r3:1-100>]>),
            ],
        ]
        The GtidSets in each nested list are guaranteed to be identical to each other, or the list holds a
        single replica. Inside each list of identical gtids, replicas are sorted by fqdn.
        No consistency check is done here!

        Second element - flat list of (replica_fqdn, <GtidSet>) pairs sorted by GtidSet, then by fqdn.
        """
        hosts_master_pos = []
        cluster_results = self.exec("SELECT @@global.gtid_executed", **kwargs)
        for (host, _), host_result in cluster_results.destination_to_result.items():
            if not host_result or not host_result.has_data():
                continue
            # hosts_master_pos:
            # (r1, <GtidSet [<Gtid sid_r2:1-100>, <Gtid sid_r1:1-100>, ...]>),
            # (r2, <GtidSet [<Gtid sid_r2:1-100>, <Gtid sid_r1:1-105>, ...]>),
            hosts_master_pos.append(HostMasterPos(host, GtidSet(host_result.data[0][0], keep_sorted=True)))

        # sort by GtidSet, then by fqdn
        hosts_master_pos.sort(key=operator.itemgetter(1, 0))

        # group consecutive entries whose gtid sets are identical (is_clone_of)
        by_executed = []
        previous = None
        for current in hosts_master_pos:
            if previous is not None and previous.gtid_executed.is_clone_of(current.gtid_executed):
                by_executed[-1].append(current)
            else:
                by_executed.append([current])
            previous = current

        self.logger.debug("sorted replicas: %s", by_executed)
        return by_executed, hosts_master_pos

    def get_processlist_remote_clients(self, **kwargs):
        """Query PROCESSLIST on the selected hosts for client connections that are not replication/internal.

        Excluded rows: 'system user', this query itself (tagged with a unique req_id comment), Binlog Dump
        threads, localhost connections, and connections coming from any host in ``replica_hosts``.
        """
        # Connections from ALL replica_hosts are filtered out here, including excluded ones (in
        # maintenance), not just the ready hosts.
        # NOTE(review): `info NOT LIKE ...` evaluates to NULL for idle connections (info IS NULL), so
        # sleeping clients are not returned -- confirm this is intended.
        req_id = "%s_%s" % (self.__class__.__name__, uuid.uuid4())
        cmd_template = f"SELECT /* {req_id} */ * FROM INFORMATION_SCHEMA.PROCESSLIST WHERE " + \
                       "user != 'system user' AND info NOT LIKE %s AND command NOT LIKE %s AND " + \
                       "host != 'localhost' AND left(host, length(host) - locate(':', reverse(host))) NOT IN " + \
                       "(" + ",".join(["%s"] * len(self.inst_conf.replica_hosts)) + ")"

        cmd = [cmd_template, (f"%{req_id}%", "Binlog Dump%", *self.inst_conf.replica_hosts)]
        return self.exec(cmd, **kwargs)

    def set_total_readonly(self, **kwargs):
        """Set ``super_read_only=1`` on the selected hosts (no result rows required)."""
        kwargs.update(require_data_on=HostsSelector.EMPTY_SET)
        # enabling super_read_only also enables regular read_only
        self.exec("SET global super_read_only=1", **kwargs)

    def unset_total_readonly(self, **kwargs):
        """Set ``read_only=0`` on the selected hosts (no result rows required)."""
        kwargs.update(require_data_on=HostsSelector.EMPTY_SET)
        # disabling read_only also disables super_read_only
        self.exec("SET global read_only=0", **kwargs)


class DtMysqlFailoverManager:
    # Value of new_master meaning "pick the new master automatically" (by switchover weight).
    MASTER_AUTO = "auto"
    # Template of the overall switchover state (start/finish flags, timestamps and formatted times).
    # NOTE(review): shared mutable class attribute -- presumably copied before use; confirm callers never mutate it.
    BASE_ZK_STATE: Dict[str, Any] = {
        "started": False,
        "start_ts": 0,
        "start_time": "",
        "finished": False,
        "finish_ts": 0,
        "finish_time": "",
    }
    ZKGUARD_HEARTBEAT_TIMEOUT_SECONDS = 90  # = 3 * HEARTBEAT_INTERVAL in zkguard/dt-mysql-zkguard.py

    @dataclasses.dataclass
    class SwitchoverParams:
        """Parameters and intermediate results of one switchover/failover run."""

        # host that is master before the switchover
        old_master: Optional[str] = None
        # requested new master host, or MASTER_AUTO
        new_master: Optional[str] = None
        # True: planned switchover with a live old master; False: failover mode
        simple_switchover: Optional[bool] = None
        # server_uuid of the old master (set by infer_old_master_uuid)
        old_master_uuid: Optional[str] = None
        # gtid_executed at which the old master stopped (set by get_old_master_pos)
        old_master_pos: Optional[GtidSet] = None
        # replicas grouped by identical gtid_executed, ascending (set by check_slaves_consistency)
        sorted_replicas: Optional[List[List[HostMasterPos]]] = None

    def __init__(self, mymgr_config: DtMysqlFailoverManagerConfig, instance_config: DtAllDBInstanceConfig,
                 callbacks=None, logger=None):
        """Wire up ZooKeeper access, MySQL access and the step pipelines for one DB instance.

        :param mymgr_config: failover-manager configuration (ZK params/paths, MySQL params, waiting timeouts).
        :param instance_config: configuration of the DB instance being switched over.
        :param callbacks: optional dict; the keys "mysql_command_on_result_ready" and "mysql_command_on_retry"
            override the default MySQL command callbacks.
        :param logger: parent logger; a child named after this class is used.
        """
        self.logger = _logger if not logger else logger
        self.logger = self.logger.getChild(self.__class__.__name__)

        self.inst_conf = instance_config
        self.mymgr_conf = mymgr_config

        if callbacks is None:
            callbacks = {}

        self.zk_acls = zk_make_default_acls(self.mymgr_conf.get_zk_params()["zk_token"])
        # why not inherit from ZKClient here as well?
        # created with start_zk_now=False -- presumably the connection is started elsewhere; confirm
        self.zk_client = ZKClient(start_zk_now=False, logger=self.logger, disable_kazoo_retries=False,
                                  default_acl=self.zk_acls, **self.mymgr_conf.get_zk_params())
        self.zk = self.zk_client.zk
        # ZK lock held for the duration of switchover() (acquired non-blocking there)
        self.lock = self.zk.Lock(self.mymgr_conf.switchover_zk_lock_path, identifier=zk_generate_lock_data())

        self.guard = DtMysqlGuardManager(mymgr_config=self.mymgr_conf,
                                         logger=self.logger,
                                         zk_client=self.zk, **self.mymgr_conf.get_zk_params())
        self.db_config = DtDBConfigManager(db_config_zk_path=self.inst_conf.db_config_zk,
                                           logger=self.logger,
                                           zk_client=self.zk, **self.mymgr_conf.get_zk_params())

        self.mysql_rs = DtMysqlReplicaSet(
            instance_config=self.inst_conf, logger=self.logger,
            mysql_params=self.mymgr_conf.mysql_params,
            mysql_command_max_tries=self.mymgr_conf.mysql_command_max_tries,
            mysql_command_on_result_ready=callbacks.get("mysql_command_on_result_ready", mysql_cmd_on_result_ready),
            mysql_command_on_retry=callbacks.get("mysql_command_on_retry", mysql_cmd_on_retry),
        )

        # decorators for steps that must be polled; presumably they re-run the wrapped check until it is
        # truthy (waiting_until) / falsy (waiting_while) or switchover_waiting_steps_timeout expires -- confirm
        self.waiting_until = wait_until(timeout=self.mymgr_conf.switchover_waiting_steps_timeout,
                                        iteration_sleep=self.mymgr_conf.switchover_waiting_iteration_sleep)

        self.waiting_while = wait_while(timeout=self.mymgr_conf.switchover_waiting_steps_timeout,
                                        iteration_sleep=self.mymgr_conf.switchover_waiting_iteration_sleep)

        # steps may change state and depend on each other, so their order matters
        self.precheck_steps = self.fill_step_indexes([
            {"name": "mymgr_healthcheck", "func": self.mymgr_healthcheck},
            {"name": "check_switchover_parameters", "func": self.check_switchover_parameters},
            {"name": "check_interrupted_switchover", "func": self.check_interrupted_switchover},
            {"name": "mysql_cluster_healthcheck", "func": self.mysql_cluster_healthcheck},
            {"name": "check_zkguard_on_old_master", "func": self.check_zkguard_on_old_master},
            {"name": "infer_old_master_uuid", "func": self.infer_old_master_uuid},
            {"name": "check_slaves_consistency", "func": self.check_slaves_consistency},
        ])

        self.switchover_steps = self.fill_step_indexes([
            # always
            {"name": "close_lfw_and_wait", "func": self.close_old_master_lfw},

            # planned switchover only (fenced + get_pos)
            {"name": "wait_old_master_fenced", "func": self.waiting_until(self.old_master_fenced)},
            {"name": "wait_and_get_old_master_pos", "func": self.waiting_until(self.get_old_master_pos)},

            # failover only (running these during a planned switchover would be harmless, but pointless)
            {"name": "wait_slaves_relay_logs_stalled", "func": self.waiting_until(self.slaves_relay_logs_stalled)},
            {"name": "stop_io_thread", "func": self.stop_io_thread},
            {"name": "wait_slaves_applied_all_relay_logs",
             "func": self.waiting_until(self.slaves_applied_all_relay_logs)},

            # planned switchover only (because the old master's position is known)
            {"name": "wait_for_slaves", "func": self.waiting_until(self.slaves_in_sync_with_master)},

            # always
            {"name": "infer_old_master_uuid", "func": self.infer_old_master_uuid},
            {"name": "check_slaves_consistency", "func": self.check_slaves_consistency},
            {"name": "select_best_slave", "func": self.select_best_slave},
            {"name": "change_master", "func": self.change_mysql_master},
            {"name": "set_semisync_master", "func": self.set_semisync_master},
            {"name": "set_semisync_slave", "func": self.set_semisync_slave},
            {"name": "start_slaves", "func": self.start_slaves},
            {"name": "open_lfw_and_wait", "func": self.open_new_master_lfw},
            {"name": "change_db_config", "func": self.change_db_config},
        ])
        self.params = self.SwitchoverParams()
        self._backup_params = self.SwitchoverParams()

    # ==public methods

    def prepare(self, new_master, old_master=None, simple_switchover=True, **hosts_selectors):
        """Record the switchover parameters and install global host selectors.

        :param new_master: target host, or MASTER_AUTO.
        :param old_master: current master; when None it is taken from ``self.guard.get_master()``.
        :param simple_switchover: True for a planned switchover, False for failover mode.
        :param hosts_selectors: selectors (e.g. ``exclude_hosts``) applied to every replica-set command.
        """
        if old_master is None:
            old_master = self.guard.get_master()

        params = self.params
        params.old_master = old_master
        params.new_master = new_master
        params.simple_switchover = simple_switchover
        self.store_params()

        # global selectors, e.g. a global exclude_hosts
        self.mysql_rs.global_hosts_selectors = hosts_selectors

    def switchover_precheck(self) -> bool:
        """Restore the saved parameters and run the precheck pipeline.

        :return: whatever ``run_steps`` reports for ``self.precheck_steps`` (no per-step state tracking).
        """
        self.restore_params()

        instance_name = self.inst_conf.instance
        self.logger.info(f"starting switchover prechecks for {instance_name}")
        steps = self.precheck_steps
        return self.run_steps(steps, on_start_and_finish=None)

    def switchover(self) -> bool:
        """Run the switchover pipeline while holding the ZK switchover lock.

        The lock is acquired without blocking. Prechecks are NOT run here.

        :return: True if every switchover step succeeded. False if the lock is held by someone else or
            ``run_steps`` reported failure; in that case the state stays "started" but not "finished",
            which check_interrupted_switchover reports later. The lock is released even if a step raises.
        """
        self.restore_params()

        if not self.lock.acquire(blocking=False):
            self.logger.info("can't acquire switchover lock, current contenders: %s", self.lock.contenders())
            return False

        try:
            self.logger.debug("switchover lock acquired")
            self.init_switchover_state(force=True)
            self.update_switchover_state(started=True)

            self.logger.info(f"switching {self.inst_conf.instance} master from '{self.params.old_master}' "
                             f"to '{self.params.new_master}'")
            if not self.run_steps(self.switchover_steps, on_start_and_finish=self.update_step_state):
                return False

            self.update_switchover_state(finished=True)
            return True
        finally:
            self.logger.debug("release switchover lock")
            self.lock.release()

    # ==internal methods
    def is_zkguard_alive(self, host):
        """Tell whether zkguard on ``host`` has sent a heartbeat recently.

        The heartbeat node ``<guard_zk_heartbeat_path>/<host>`` holds a unix timestamp. Its age is
        compared with ZKGUARD_HEARTBEAT_TIMEOUT_SECONDS.

        :return: True if the heartbeat is fresher than the timeout. False if it is stale or the
            heartbeat node does not exist (zkguard never reported for this host).
        """
        heartbeat_node = self.mymgr_conf.guard_zk_heartbeat_path + '/' + host
        try:
            raw_heartbeat, _ = self.zk.get(heartbeat_node)
        except kazoo.exceptions.NoNodeError:
            # A missing node means no heartbeat at all. Report "not alive" so callers raise their
            # descriptive error instead of a bare NoNodeError.
            self.logger.warning("zkguard heartbeat node %s not found", heartbeat_node)
            return False

        heartbeat_timestamp = float(raw_heartbeat.decode())
        return time.time() - heartbeat_timestamp < self.ZKGUARD_HEARTBEAT_TIMEOUT_SECONDS

    # ===switchover prechecks
    def mymgr_healthcheck(self):
        """Verify ACLs on the ZK nodes this manager uses and log the master recorded in the DB config.

        :raises Exception: on the first path whose ACLs don't match ``self.zk_acls``.
        """
        acl_checked_paths = (
            self.mymgr_conf.zk_prefix,
            self.mymgr_conf.guard_zk_state_path,
            self.mymgr_conf.guard_zk_lock_path,
            self.mymgr_conf.switchover_zk_state_path,
            self.mymgr_conf.switchover_zk_lock_path,
            self.inst_conf.db_config_zk,
        )
        for zk_path in acl_checked_paths:
            acls_ok = zk_check_acls(self.zk, zk_path, self.zk_acls)
            if not acls_ok:
                raise Exception(f"Bad acls in path {zk_path}, see logs")

        raw_db_config, _ = self.zk.get(self.inst_conf.db_config_zk)
        master = get_json_value(json.loads(raw_db_config), self.inst_conf.db_config_master_node)
        self.logger.debug(f"current master in dbconfig: {master}")

    def check_interrupted_switchover(self):
        """Refuse to continue if the previously recorded switchover did not finish.

        :raises Exception: if any recorded step, or the switchover as a whole, was started but not finished.
        """
        last_state = self.get_switchover_state()

        def _unfinished(state):
            return state["started"] and not state["finished"]

        for step in last_state["switchover_steps"]:
            if _unfinished(step["state"]):
                raise Exception(f"Last switchover failed on step '{step['name']}', fix state (and mysql) manually")

        if _unfinished(last_state):
            raise Exception("Last switchover failed, fix state (and replicas) manually")

    def check_switchover_parameters(self):
        """Validate old/new master parameters against the instance's ready replicas.

        :raises Exception: if the old master is not a ready replica, the new master equals the old one,
            or an explicitly requested new master is not a ready replica or has zero switchover weight.
        """
        old_master = self.params.old_master
        new_master = self.params.new_master

        if old_master not in self.inst_conf.ready_replica_hosts:
            raise Exception(f"Old master '{old_master}' not found in {self.inst_conf.instance} replicas",
                            self.inst_conf.ready_replica_hosts)

        if new_master == old_master:
            raise Exception("New master '%s' is same as current master '%s'" % (new_master, old_master))

        if new_master == self.MASTER_AUTO:
            # nothing more to validate: the new master will be chosen automatically
            return

        if new_master not in self.inst_conf.ready_replica_hosts:
            raise Exception(f"New master '{new_master}' not found in {self.inst_conf.instance} replicas",
                            self.inst_conf.ready_replica_hosts)

        # this is not the same check as nonzero_weight_slaves in mysql_cluster_healthcheck
        if self.inst_conf.get_switchover_weight(new_master) == 0:
            raise Exception(f"New master '{new_master}' switchover weight is zero")

    def mysql_cluster_healthcheck(self):
        """Check every replica except the old master before switching.

        For each replica that returned a SHOW SLAVE STATUS row: the SQL thread runs, lag is within
        ``switchover_max_seconds_behind_master``, Auto_Position is 1, and zkguard heartbeats are fresh.
        In addition, at least one replica must have nonzero switchover weight, and all replicas must
        report the same MasterId.

        :raises Exception: on the first violated condition.
        """
        slaves_status = self.mysql_rs.get_slave_status(exclude_hosts=[self.params.old_master])
        replica_masters = []
        weighted_slaves = []
        for host, status in slaves_status.items():
            if self.inst_conf.get_switchover_weight(host) > 0:
                weighted_slaves.append(host)

            sql_running = status["Slave_SQL_Running"]
            if sql_running != "Yes":
                raise Exception("Bad SQL thread status '%s' on slave %s" % (sql_running, host))

            lag = status["Seconds_Behind_Master"]
            if lag is not None and int(lag) > self.mymgr_conf.switchover_max_seconds_behind_master:
                raise Exception(f"Slave {host} is too far from master, "
                                f"seconds_behind: {lag}")

            auto_position = status["Auto_Position"]
            if auto_position is None or int(auto_position) != 1:
                raise Exception(f"Auto position disabled on slave {host}")

            if not self.is_zkguard_alive(host):
                raise Exception(f"zkguard hasn't sent heartbeat for more than "
                                f"{self.ZKGUARD_HEARTBEAT_TIMEOUT_SECONDS} seconds on slave {host}")

            master_id = MasterId(*(status[field] for field in MasterId._fields))
            replica_masters.append(HostToMasterId(host, master_id))

        if not weighted_slaves:
            raise Exception("Need at least one slave with nonzero weight")

        distinct_masters = {entry.master_id for entry in replica_masters}
        if len(distinct_masters) != 1:
            reference_master = replica_masters[0].master_id
            raise Exception("Not all slaves replicate same master",
                            [entry.host for entry in replica_masters if entry.master_id != reference_master])

    def check_zkguard_on_old_master(self):
        """In planned switchover mode, require a fresh zkguard heartbeat from the old master.

        :return: True right away in failover mode, otherwise None when the heartbeat is fresh.
        :raises Exception: if the old master's zkguard heartbeat is stale or missing.
        """
        if not self.params.simple_switchover:
            return True

        old_master_host = self.params.old_master
        if self.is_zkguard_alive(old_master_host):
            return None

        raise Exception(f"zkguard hasn't sent heartbeat for more than {self.ZKGUARD_HEARTBEAT_TIMEOUT_SECONDS} "
                        f"seconds on old master {old_master_host}")

    # ===switchover steps
    def close_old_master_lfw(self):
        """Ask the guard to close the master lfw, then wait while tcp_ping to the old master's lfw_port succeeds."""
        self.guard.close_master_lfw()

        wait_for_lfw_closed = self.waiting_while(tcp_ping)
        wait_for_lfw_closed(self.params.old_master, self.inst_conf.lfw_port,
                            timeout=self.mymgr_conf.mysql_params["connect_timeout"])

        # TODO: in failover mode (not simple_switchover), set localhost:0 in dbconfig?

    def old_master_fenced(self):
        """Planned switchover: succeed once the old master has no open InnoDB transactions and no remote clients.

        When it is quiet, it is switched to super_read_only before returning.

        :return: True when fenced (always True in failover mode), False while activity remains.
        """
        if not self.params.simple_switchover:
            self.logger.info("failover mode active, skip step")
            return True

        old_master = self.params.old_master
        trx = self.mysql_rs.exec("SELECT * FROM INFORMATION_SCHEMA.INNODB_TRX", include_hosts=[old_master],
                                 require_data_on=HostsSelector.EMPTY_SET)
        clients = self.mysql_rs.get_processlist_remote_clients(include_hosts=[old_master],
                                                               require_data_on=HostsSelector.EMPTY_SET)

        def _has_rows(cluster_result):
            # a falsy cluster result short-circuits without touching .results
            return cluster_result and cluster_result.results[0].has_data()

        if _has_rows(trx) or _has_rows(clients):
            return False

        self.mysql_rs.set_total_readonly(include_hosts=[self.params.old_master])
        return True

    def get_old_master_pos(self):
        """Planned switchover: record the old master's gtid_executed once it stops changing.

        Reads gtid_executed twice, ``switchover_min_trx_silence_time`` seconds apart. If the sets are
        identical, the position is stored in ``params.old_master_pos``.

        :return: True in failover mode; otherwise the result of comparing the two readings.
        """
        if not self.params.simple_switchover:
            self.logger.info("failover mode active, skip step")
            return True

        def _read_gtid_executed():
            result = self.mysql_rs.exec_on_host(self.params.old_master, "SELECT @@global.gtid_executed")
            return GtidSet(result.data[0][0])

        pos_before = _read_gtid_executed()

        silence_time = self.mymgr_conf.switchover_min_trx_silence_time
        self.logger.info("wait for %d sec and make sure master has stopped sql execution", silence_time)
        time.sleep(silence_time)

        pos_after = _read_gtid_executed()
        self.logger.debug("master gtid_executed before: %s; after: %s", pos_before, pos_after)

        unchanged = pos_before.is_clone_of(pos_after)
        if not unchanged:
            return unchanged

        self.params.old_master_pos = pos_after
        self.logger.info("master stopped at %s", self.params.old_master_pos)
        return unchanged

    def slaves_in_sync_with_master(self):
        """Planned switchover: check that every selected host's gtid_executed equals ``params.old_master_pos``.

        :return: True when all hosts match (always True in failover mode), False while some still lag.
        :raises Exception: if a host has transactions not contained in the old master's position.
        """
        if not self.params.simple_switchover:
            self.logger.info("failover mode active, skip step")
            return True

        _, hosts_master_pos = self.mysql_rs.get_hosts_master_pos_sorted()
        self.logger.debug("check slaves in sync with master, hosts_master_pos: %s", hosts_master_pos)

        master_pos = self.params.old_master_pos
        if all(entry.gtid_executed.is_clone_of(master_pos) for entry in hosts_master_pos):
            return True

        lagging_repls = [entry.host for entry in hosts_master_pos if entry.gtid_executed < master_pos]
        errant_repls = [entry.host for entry in hosts_master_pos if entry.gtid_executed not in master_pos]
        if errant_repls:
            raise Exception("slave %s transactions not fully included in master's" % (errant_repls,))

        self.logger.info("waiting for slaves %s", lagging_repls)
        return False

    def slaves_relay_logs_stalled(self):
        """Failover mode: check that replicas (except the old master) stopped receiving relay logs.

        Compares two snapshots of (host, gtid_retrieved, io_read) taken ``switchover_min_trx_silence_time``
        seconds apart.

        :return: True if the snapshots are equal (always True in planned switchover mode), else False.
        """
        if self.params.simple_switchover:
            self.logger.info("switchover mode active, skip step")
            return True

        def _io_positions():
            slave_positions = self.mysql_rs.get_hosts_slave_pos_sorted(exclude_hosts=[self.params.old_master])
            return tuple((pos.host, pos.gtid_retrieved, pos.io_read) for pos in slave_positions)

        pos_before = _io_positions()

        silence_time = self.mymgr_conf.switchover_min_trx_silence_time
        self.logger.info("wait for %d sec and make sure all slaves have stopped relay logs download", silence_time)
        time.sleep(silence_time)

        pos_after = _io_positions()
        self.logger.debug("slaves (fqd, gtid_retrieved, io_thread) pos before: %s; after: %s", pos_before, pos_after)

        # comparing GtidSets as plain sets (==) is enough here
        if pos_before != pos_after:
            return False

        self.logger.info("slaves stopped at %s", pos_after)
        return True

    def stop_io_thread(self):
        """Failover mode: run STOP SLAVE IO_THREAD on all selected hosts (no result rows required)."""
        if not self.params.simple_switchover:
            # NOTE(review): unlike the neighbouring failover steps, old_master is not excluded here;
            # it relies on the global selectors -- confirm this is intended
            self.mysql_rs.exec("STOP SLAVE IO_THREAD", require_data_on=HostsSelector.EMPTY_SET)
            return

        self.logger.info("switchover mode active, skip step")

    def slaves_applied_all_relay_logs(self):
        """Failover mode: check that every replica (except the old master) has applied its whole relay log.

        A replica is done when its SQL position equals its IO position, its retrieved gtids are contained
        in its executed gtids, and the SQL thread reports it has read all relay log.

        :return: True when all replicas are done (always True in planned switchover mode), else False.
        """
        if self.params.simple_switchover:
            self.logger.info("switchover mode active, skip step")
            return True

        all_applied_state = "Slave has read all relay log; waiting for more updates"
        for pos in self.mysql_rs.get_hosts_slave_pos_sorted(exclude_hosts=[self.params.old_master]):
            self.logger.debug("slave pos: %s", pos)

            still_applying = (pos.io_read != pos.sql_executed
                              or pos.gtid_retrieved not in pos.gtid_executed
                              or pos.sql_thread_state != all_applied_state)
            if still_applying:
                self.logger.info("wait while %s applies relay log (%s from %s applied)",
                                 pos.host, pos.sql_executed, pos.io_read)
                return False

        return True

    def infer_old_master_uuid(self):
        """Determine the old master's server_uuid and store it in ``params.old_master_uuid``.

        Sources: each replica's Master_UUID (old master excluded) and, in planned switchover mode, the
        old master's own ``@@global.server_uuid``. All sources must agree.

        :raises Exception: if the collected uuids are not exactly one distinct value.
        """
        uuid_by_source = {}

        if self.params.simple_switchover:
            master_uuid_res = self.mysql_rs.exec_on_host(self.params.old_master, "SELECT @@global.server_uuid")
            reported_uuid = master_uuid_res.data[0][0]
            uuid_by_source["old_master:" + self.params.old_master] = reported_uuid
            self.logger.debug("fetched server_uuid from old master: %s", reported_uuid)

        slaves_status = self.mysql_rs.get_slave_status(exclude_hosts=[self.params.old_master])
        uuid_by_source.update((host, status["Master_UUID"]) for host, status in slaves_status.items())

        distinct_uuids = set(uuid_by_source.values())
        if len(distinct_uuids) != 1:
            raise Exception(f"Not all slaves have same Master_UUID: {uuid_by_source}")

        (self.params.old_master_uuid,) = distinct_uuids
        self.logger.info(f"use {self.params.old_master_uuid} as old master uuid")

    def check_slaves_consistency(self):
        """Make sure replicas' gtid_executed sets differ only in the old master's transactions.

        Stores the replicas, grouped by identical gtid_executed, in ``self.params.sorted_replicas``.
        Every replica is then compared with the last entry of the sorted flat list (the "base").

        :raises Exception: if no replica positions were collected; if a replica's GtidSet lists the same
            uuid more than once; if a replica and the base differ in any uuid other than
            ``params.old_master_uuid``; or if they differ in the old master's uuid by anything other than
            a pair of intervals where one contains the other.
        """
        self.params.sorted_replicas, hosts_master_pos = self.mysql_rs.get_hosts_master_pos_sorted()
        if not self.params.sorted_replicas:
            raise Exception("Empty gtid-sorted replicas list")

        self.logger.debug("use base gtids from %s", hosts_master_pos[-1])
        base_host = hosts_master_pos[-1].host
        base_gtids = hosts_master_pos[-1].gtid_executed
        # the base's Gtid entries, built once instead of once per replica
        base_gtid_entries = set(base_gtids)

        for host, gtids in hosts_master_pos:
            # log through the instance logger (not the root `logging` module) so records carry this
            # manager's logger name and handlers
            self.logger.debug("checking gtids on %s: %s", host, gtids)
            uuids = [gtid.uuid for gtid in gtids]
            if len(uuids) != len(set(uuids)):
                raise Exception(f"Duplicate uuids on replica {host}")

            differed_gtids = base_gtid_entries.symmetric_difference(gtids)
            errant_gtids = sorted(x for x in differed_gtids if x.uuid != self.params.old_master_uuid)
            master_gtids = sorted(x for x in differed_gtids if x.uuid == self.params.old_master_uuid)
            self.logger.debug("master_gtids: %s; errant_gtids: %s", master_gtids, errant_gtids)

            if errant_gtids:
                raise Exception(f"Errant transactions found on {host} and {base_host}: {errant_gtids}")
            # the old master's interval may only differ as a pair where one contains the other
            # (e.g. sid_m:1-100 on this replica vs sid_m:1-107 on the base)
            if not (len(master_gtids) == 0 or (len(master_gtids) == 2 and master_gtids[0] in master_gtids[1])):
                raise Exception("Bad master_gtids, must be 0 or 2 sets where one is superset of another: " +
                                str(master_gtids))

    def select_best_slave(self):
        """Choose the new master among the most up-to-date replicas.

        Candidates are the hosts of the last group in ``self.params.sorted_replicas``
        (presumably the group with the most advanced GTIDs), excluding the old master.

        * If ``self.params.new_master`` was specified manually (differs from
          ``self.MASTER_AUTO``), it must be one of the candidates and is kept.
        * Otherwise hosts with zero switchover weight are skipped, and one of the hosts
          sharing the highest remaining weight is picked at random and stored in
          ``self.params.new_master``.

        :return: the selected new master host (also stored in ``self.params.new_master``
            in the automatic case). Both branches now return it.
        :raises Exception: if there are no candidates, the manually specified master is
            not a candidate, or every candidate has zero switchover weight.
        """
        self.logger.info("selecting best master as requested")

        most_recent_slaves = [x.host for x in self.params.sorted_replicas[-1] if x.host != self.params.old_master]
        self.logger.debug("most recent slaves: %s", most_recent_slaves)

        if not most_recent_slaves:
            raise Exception("All slaves behind old master. Can't select best new master")

        if self.params.new_master != self.MASTER_AUTO:
            if self.params.new_master not in most_recent_slaves:
                raise Exception("Manually specified master %s not found in most_recent_slaves %s" %
                                (self.params.new_master, most_recent_slaves))
            self.logger.info("manually specified new master: %s", self.params.new_master)
            return self.params.new_master

        # Look each weight up once instead of re-querying the config for every comparison.
        weights = {host: self.inst_conf.get_switchover_weight(host) for host in most_recent_slaves}
        eligible = [host for host in most_recent_slaves if weights[host] != 0]
        if not eligible:
            raise Exception("All recent slaves have zero switchover weight. Can't select best new master")

        # Keep every eligible host sharing the highest weight, in candidate order.
        best_weight = max(weights[host] for host in eligible)
        top_priority_slaves = [host for host in eligible if weights[host] == best_weight]
        self.logger.debug("top priority slaves: %s", top_priority_slaves)

        self.params.new_master = random.choice(top_priority_slaves)
        self.logger.info("random choice from %s: %s", top_priority_slaves, self.params.new_master)
        return self.params.new_master

    def change_mysql_master(self):
        """Re-point replication at ``self.params.new_master``.

        Stops replication on every host, issues a GTID auto-positioned
        ``CHANGE MASTER TO`` (with the replication credentials passed as query
        parameters) on every host except the new master, then clears replication
        settings on the new master with ``RESET SLAVE ALL``. No host is required to
        return rows for any of these statements.
        """
        new_master = self.params.new_master
        conf = self.inst_conf
        no_rows_expected = HostsSelector.EMPTY_SET

        change_master_sql = ("CHANGE MASTER TO master_host=%s, master_port=%s, "
                             "master_user=%s, master_password=%s, master_auto_position=1")
        change_master_args = (new_master, conf.mysql_port, conf.mysql_user_rplcat, conf.mysql_pass_rplcat)

        self.mysql_rs.exec("STOP SLAVE", require_data_on=no_rows_expected)
        self.mysql_rs.exec([change_master_sql, change_master_args], exclude_hosts=[new_master],
                           require_data_on=no_rows_expected)
        self.mysql_rs.exec_on_host(new_master, "RESET SLAVE ALL", require_data_on=no_rows_expected)

    def set_semisync_master(self):
        """Switch the new master to the master side of semi-sync replication.

        Does nothing unless ``self.inst_conf.semisync_enable`` is set. Otherwise, on
        ``self.params.new_master`` only, first disables the slave-side semi-sync
        variable and then enables the master-side one.
        """
        if not self.inst_conf.semisync_enable:
            return

        for statement in ("SET GLOBAL rpl_semi_sync_slave_enabled=OFF",
                          "SET GLOBAL rpl_semi_sync_master_enabled=ON"):
            self.mysql_rs.exec(statement, include_hosts=[self.params.new_master],
                               require_data_on=HostsSelector.EMPTY_SET)

    def set_semisync_slave(self):
        """Switch every host except the new master to the slave side of semi-sync replication.

        Does nothing unless ``self.inst_conf.semisync_enable`` is set. Otherwise, on all
        hosts other than ``self.params.new_master``, first enables the slave-side
        semi-sync variable and then disables the master-side one.
        """
        if not self.inst_conf.semisync_enable:
            return

        for statement in ("SET GLOBAL rpl_semi_sync_slave_enabled=ON",
                          "SET GLOBAL rpl_semi_sync_master_enabled=OFF"):
            self.mysql_rs.exec(statement, exclude_hosts=[self.params.new_master],
                               require_data_on=HostsSelector.EMPTY_SET)

    def start_slaves(self):
        """Start replication on every host except the new master.

        For a simple switchover the old master stays part of the replica set, so its
        total read-only mode is cleared first via ``unset_total_readonly`` (assumed to
        undo a read-only switch applied earlier in the switchover — confirm).
        """
        params = self.params
        replica_set = self.mysql_rs

        if params.simple_switchover:
            replica_set.unset_total_readonly(include_hosts=[params.old_master])

        replica_set.exec("START SLAVE", exclude_hosts=[params.new_master],
                         require_data_on=HostsSelector.EMPTY_SET)

    def open_new_master_lfw(self):
        """Register the new master with the guard, then check it on the lfw port.

        The ``check_mysql`` probe is wrapped by ``self.waiting_until`` (presumably
        retrying until the probe succeeds or a timeout expires — confirm).
        """
        new_master = self.params.new_master
        self.guard.set_master(new_master)
        wait_for_mysql = self.waiting_until(self.mysql_rs.check_mysql)
        wait_for_mysql(new_master, self.inst_conf.lfw_port)

    def change_db_config(self):
        """Record ``self.params.new_master`` as the master in the DB config for this instance."""
        new_master = self.params.new_master
        self.db_config.set_master(self.inst_conf, new_master)

    # ===helper functions
    def check_master(self):
        """Check that the master currently registered with the guard answers on the lfw port.

        :return: True without probing when the guard reports ``NO_MASTER``; otherwise
            whatever ``self.mysql_rs.check_mysql`` returns for that host.
        """
        current_master = self.guard.get_master()
        # Someone may already be switching over by hand? Ignore it and report healthy.
        if current_master != DtMysqlGuardManager.NO_MASTER:
            return self.mysql_rs.check_mysql(current_master, port=self.inst_conf.lfw_port)
        return True

    def init_switchover_state(self, force=False):
        """Create the ZooKeeper node holding the switchover state for this instance.

        The initial state is ``BASE_ZK_STATE`` plus ``started_on=None``, every field of
        ``self.params`` and a ``switchover_steps`` list with one entry (name + fresh copy
        of ``BASE_ZK_STATE``) per configured step. The lock path is ensured first.

        If the state node already exists it is left untouched, unless ``force`` is set,
        in which case it is overwritten. Note that with ``force`` a freshly created node
        is written a second time.
        """
        conf = self.mymgr_conf
        state_path = conf.switchover_zk_state_path
        self.logger.info(f"init switchover nodes for {conf.instance}")

        switchover_state = self.BASE_ZK_STATE.copy()
        switchover_state["started_on"] = None
        switchover_state.update(dataclasses.asdict(self.params))
        switchover_state["switchover_steps"] = [
            {"name": step["name"], "state": self.BASE_ZK_STATE.copy()}
            for step in self.switchover_steps
        ]

        self.zk.ensure_path(conf.switchover_zk_lock_path)
        try:
            self.zk.create(state_path, zk_jdumps(switchover_state), makepath=True)
        except kazoo.exceptions.NodeExistsError:
            self.logger.info(f"already initialized at {state_path}")
        else:
            self.logger.debug(f"new switchover state: {switchover_state}")

        if not force:
            return

        self.logger.info(f"forced set switchover state for {conf.instance}")
        self.zk.set(state_path, zk_jdumps(switchover_state))
        self.logger.debug(f"new switchover state: {switchover_state}")

    def get_switchover_state(self):
        """Read the switchover state node from ZooKeeper and return it decoded from JSON."""
        raw_state, _zstat = self.zk.get(self.mymgr_conf.switchover_zk_state_path)
        return json.loads(raw_state)

    def update_switchover_state(self, started=None, finished=None):
        """Read-modify-write the top-level switchover state node.

        Applies the started/finished markers via ``update_base_state``, overlays every
        field of ``self.params`` and, whenever ``started`` is not None (False included),
        records this machine's FQDN in ``started_on``. The write passes the version
        obtained by the read, so kazoo rejects it if the node changed in between.
        """
        state_path = self.mymgr_conf.switchover_zk_state_path
        raw_state, zstat = self.zk.get(state_path)
        current = json.loads(raw_state)

        self.update_base_state(current, started, finished)
        current.update(dataclasses.asdict(self.params))
        if started is not None:
            current["started_on"] = socket.getfqdn()

        self.zk.set(state_path, zk_jdumps(current), zstat.version)

    def update_step_state(self, step_index, started=None, finished=None):
        """Read-modify-write the state of one entry of ``switchover_steps`` in ZooKeeper.

        :param step_index: position of the step in the stored ``switchover_steps`` list.
        :param started: passed through to ``update_base_state``.
        :param finished: passed through to ``update_base_state``.

        The write passes the version obtained by the read, so kazoo rejects it if the
        node changed in between.
        """
        state_path = self.mymgr_conf.switchover_zk_state_path
        raw_state, zstat = self.zk.get(state_path)
        current = json.loads(raw_state)

        self.update_base_state(current["switchover_steps"][step_index]["state"], started, finished)

        self.zk.set(state_path, zk_jdumps(current), zstat.version)

    def run_steps(self, steps, on_start_and_finish=None) -> bool:
        """Run ``steps`` in order, stopping at the first failure.

        :param steps: iterable of dicts with ``"name"``, ``"func"`` (no-argument callable)
            and ``"index"`` keys.
        :param on_start_and_finish: optional callable, invoked as
            ``cb(index, started=True)`` before a step and ``cb(index, finished=True)``
            after it succeeds. An exception raised by the callback counts as a failure
            of that step. Non-callable values are ignored.
        :return: True if every step succeeded, False right after the first failing step.
        """
        notify = on_start_and_finish if callable(on_start_and_finish) else None

        for step in steps:
            step_name, func = step["name"], step["func"]
            # Lazy %-args (not f-strings) so a '%' in a step name cannot break log formatting.
            self.logger.info("%s: starting step '%s'", step["index"], step_name)

            try:
                if notify is not None:
                    notify(step["index"], started=True)

                func()

                if notify is not None:
                    notify(step["index"], finished=True)
            except Exception as e:
                # Deliberate boundary: any step failure aborts the run and is reported
                # through the return value rather than propagated.
                self.logger.info("step '%s' failed: %s %s", step_name, type(e), e)
                # debug output usually ends up in the log file
                self.logger.debug("step %s exception info", step_name, exc_info=True)
                return False

            self.logger.info("step '%s' done", step_name)
        return True

    def restore_params(self):
        """Replace ``self.params`` with a deep copy of the snapshot saved by ``store_params``."""
        snapshot = self._backup_params
        self.params = copy.deepcopy(snapshot)

    def store_params(self):
        """Save a deep copy of ``self.params`` so ``restore_params`` can roll back to it."""
        current = self.params
        self._backup_params = copy.deepcopy(current)

    @staticmethod
    def update_base_state(state, started=None, finished=None):
        if started is not None:
            state["started"] = started
            state["start_ts"] = time.time()
            state["start_time"] = datetime.datetime.now()
        if finished is not None:
            state["finished"] = finished
            state["finish_ts"] = time.time()
            state["finish_time"] = datetime.datetime.now()

    @staticmethod
    def fill_step_indexes(steps):
        idx = 0
        for step in steps:
            step["index"] = idx
            idx += 1
        return steps
