#!/usr/bin/env python3
# coding: utf-8
"""
Демон для слежения за доступностью zk-кластера с mysql-машин для
предотвращения записи в мастер при сетевой изоляции.

В зависимости от статуса (реплика/мастер) и доступности кластера
выполняет различные действия (закрывает/открывает фаервол и тд).

Предполагается, что эта программа ничего не должна знать про
репликацию, переключения и тд.
"""

import argparse
import json
import logging
import os
import queue
import signal
import socket
import time
import direct.infra.mysql_manager.libs.juggler as dj

from direct.infra.mysql_manager.libs.helpers import init_logger, tcp_ping
from direct.infra.mysql_manager.libs.zookeeper import zk_init
from direct.infra.mysql_manager.libs.configs import DtMysqlFailoverManagerConfig
from subprocess import check_output, STDOUT, CalledProcessError
from kazoo.protocol.states import ZnodeStat, KazooState, KeeperState
from kazoo.exceptions import KazooException


_logger = logging.getLogger(__name__)
_logger.addHandler(logging.NullHandler())


class DtMysqlGuard:
    JUGGLER_SERVICE_WORKING = "zkguard.working"
    JUGGLER_DESC_PREFIX = "from: dt-mysql-zkguard; see logs in /var/log/mysql/zkguard.log"
    INVALID_ZSTAT_VERSION = -1
    DEFAULT_STATE = {
        "is_master": None,  # None - особенное состояние "не знаю, кто мастер". Когда невозможно сказать наверняка
        "connected": False,
        "zk_state_ver": INVALID_ZSTAT_VERSION,
        "zkguard_enabled": False,
    }
    STORABLE_STATE_KEYS = ["zkguard_enabled"]
    ZK_STATE_KEYS = ["is_master", "zkguard_enabled"]
    HEARTBEAT_INTERVAL = 30

    def __init__(self, mymgr_config: DtMysqlFailoverManagerConfig, logger=None):
        self.logger = _logger if not logger else logger
        self.logger = self.logger.getChild(self.__class__.__name__)

        self.mymgr_conf = mymgr_config

        self.instance = self.mymgr_conf.instance

        self.zk_state_path = self.mymgr_conf.guard_zk_state_path
        self.zk_heartbeat_path = None
        self.ext_tools = self.mymgr_conf.ext_tools

        self.fqdn = None
        self.event_queue = None
        self.state = None
        self.last_zk_state_update_event_time = 0
        self.last_close_master_time = 0
        self.last_heartbeat_timestamp = None
        self.zk = None

    def prepare(self):
        self.fqdn = socket.getfqdn()
        self.event_queue = queue.Queue()
        self.state = self.DEFAULT_STATE.copy()
        self.state.update(self.load_state_from_status_file())
        # это время последнего успешного получения сырых данных о состоянии из zk - оно никак не связанно
        # с валидностью этого стейта, просто хотим детектить невозможность чтения из zk
        self.last_zk_state_update_event_time = time.time()
        # для мониторинга случайных закрытий
        self.last_close_master_time = 0

        self.zk = zk_init(zk_token=self.mymgr_conf.zk_token, disable_kazoo_retries=True, **self.mymgr_conf.kazoo_params)
        self.zk.add_listener(self.connection_listener)
        self.zk.DataWatch(self.zk_state_path, self.state_watch)

        self.zk_heartbeat_path = self.mymgr_conf.guard_zk_heartbeat_path + '/' + self.fqdn

    def run(self):
        self.logger.info("run")
        signal.signal(signal.SIGINT, self.exit_handler)  # потоки watcher'ов могут отправлять нам sigint, runit - term
        signal.signal(signal.SIGTERM, self.exit_handler)  # в обоих случаях ничего не закрываем, просто выходим

        self.prepare()
        self.put_event("connect")

        seconds_since_last_heartbeat = None
        while True:
            if self.last_heartbeat_timestamp is not None:
                seconds_since_last_heartbeat = time.time() - self.last_heartbeat_timestamp
            if self.state['connected'] and (seconds_since_last_heartbeat is None or seconds_since_last_heartbeat > self.HEARTBEAT_INTERVAL):
                self.zk.ensure_path(self.zk_heartbeat_path)
                self.logger.debug(f"last heartbeat sent {seconds_since_last_heartbeat} seconds ago, sending heartbeat")
                self.send_heartbeat()

            try:
                self.logger.info("waiting for next event ...")
                item = self.event_queue.get(timeout=self.mymgr_conf.guard_iteration_sleep)
            except queue.Empty:
                self.logger.info("queue is empty")
                self.fetch_zk_state()
                continue

            try:
                if not self.process_one_event(**item):
                    break
            finally:
                self.event_queue.task_done()

    def fetch_zk_state(self):
        self.logger.info(f"fetch new state from zk; current state: {self.state}; "
                         f"zk state: {self.zk.state}, {self.zk.client_state}")
        try:
            if self.zk.exists(self.zk_state_path):
                zdata, zstat = self.zk.get(self.zk_state_path)
                self.put_event("zk_state_update", zdata=zdata, zstat=zstat, source="fetcher")
        except (KazooException, self.zk.handler.timeout_exception) as e:
            # если вдруг zk залипнет - периодическое обновление стейта + выход по таймауту помогут
            self.logger.warning("can't fetch new state from zk, ignoring: %s %s" % (type(e), e))

    def put_event(self, event, **kwargs):
        evt = {"event": event, **kwargs}
        self.logger.info(f"put {event} event: {kwargs}")
        self.event_queue.put(evt)

    def send_heartbeat(self):
        heartbeat_time = time.time()
        self.zk.set(self.zk_heartbeat_path, str(heartbeat_time).encode())
        self.last_heartbeat_timestamp = heartbeat_time

    def connect_zk(self):
        self.logger.info("start new iteration ...")

        try:
            self.zk.start(self.mymgr_conf.zk_start_timeout)  # если уже подключен к zk, второй раз не будет
        except (KazooException, self.zk.handler.timeout_exception) as e:
            # используем свои "ретраи", чтобы проще встроить все в схему изменяющихся состояний
            self.put_event("connect_failed")
            self.logger.warning("can't connect to zk cluster: %s %s" % (type(e), e))

    def process_one_event(self, **kwargs):
        self.logger.info(f"process event: {kwargs}")
        event = kwargs.pop("event")

        stop_and_exit = False
        new_state = self.state.copy()  # нужна именно копия!

        # event'ы из watcher'ов приходят в том порядке, в котором произошли в zk
        # т.к. watcher'ы запускаются в одном потоке, последовательно. Иначе нужно смотреть на "source"

        # from fetch_zk_state and state_watch
        if event == "zk_state_update":
            zstat_ver = self.INVALID_ZSTAT_VERSION
            if isinstance(kwargs.get("zstat"), ZnodeStat):
                zstat_ver = kwargs["zstat"].version

            if kwargs["source"] == "fetcher" and zstat_ver < self.state["zk_state_ver"]:
                self.logger.warning("new state version %s is older than current %s" %
                                    (kwargs["zstat"].version, self.state["zk_state_ver"]))
                # такое может быть, если более старые данные от fetch_zk_state попали в очередь после state_watch
                return True  # и не надо такое обрабатывать (ничего не делаем со стейтом)

            new_state.update(self.parse_raw_state(kwargs["zdata"]))
            new_state["zk_state_ver"] = zstat_ver
            self.last_zk_state_update_event_time = time.time()
        # from zk connection_listener
        elif event == "connection_established":
            new_state["connected"] = True
        # suspended не обрабатываем, считаем, что бесконечно в нем не будем (max_tries=0) + есть last_zk_state_update
        elif event == "connection_lost":
            new_state["is_master"] = None  # не знаем, где сейчас мастер
            new_state["connected"] = False
            # боюсь, в очередь прилетят неожиданные event'ы (connection listener и watcher'ы вроде как в разных тредах)
            # поэтому надежнее обработать стейт и выйти сразу, а не посылать отдельный exit event
            stop_and_exit = True
        # from connect_zk
        elif event == "connect_failed":
            # после - обрабатываем стейт, если слишком долго не могли подключиться, можно закрыть
            self.put_event("connect")
        elif event == "connect":
            self.connect_zk()
        # from exit_handler
        elif event == "exit":
            stop_and_exit = True

        self.change_state(new_state)
        self.process_current_state()

        if stop_and_exit:
            self.zk.stop()
            self.zk.close()
            return False

        return True

    def change_state(self, new_state):
        self.logger.info(f"change state to {new_state}")

        if new_state == self.state:
            self.logger.info("state not changed, proceeding")
            return
        self.logger.info("state changed from %s to %s" % (self.state, new_state))

        cur_state_master_is_valid = self.state["connected"] and self.state["is_master"] is not None \
                                    and self.state["is_master"]
        new_state_master_was_switched = new_state["connected"] and new_state["is_master"] is not None \
                                        and not new_state["is_master"]

        # были мастером с коннектом до zk, коннект пропал и кворум zk недоступен - закрываемся (с killall)
        if cur_state_master_is_valid and not new_state["connected"]:
            self.logger.info("looks like a fenced master, check it")
            if not self.is_zk_quorum_reachable():
                self.logger.info("fenced master, killall and close")
                self.close_master()
            else:
                self.logger.info("zk hosts are reachable, seems like zk cluster failure, nothing to do")
        # были мастером и перестали (по указанию из zk)
        elif cur_state_master_is_valid and new_state_master_was_switched:
            self.logger.info("not a master any more (switched in zk), killall and close")
            self.close_master()

        # остальные ситуации обработаем в process_current_state
        self.state = new_state

    def process_current_state(self):
        self.logger.info(f"processing current state {self.state} ...")
        # тут нигде не делаем killall, т.к. не знаем, были ли мы мастером раньше (см. change_state)

        if time.time() - self.last_zk_state_update_event_time > self.mymgr_conf.zk_start_timeout:
            self.logger.info("current state is too old - zk connection may be lost, check it")
            if not self.is_zk_quorum_reachable():
                self.logger.info("fenced replica (not sure slave or master) - close lfw")
                self.close_lfw()
            else:
                self.logger.info("zk hosts are reachable, seems like zk cluster failure, nothing to do")
            return

        if not self.state["connected"]:
            self.logger.info("can't process current state, zk is not connected")
            return

        if self.state["is_master"] is None:
            self.logger.info("can't process current state, master is unknown yet")
            # если очень долго не можем узнать мастер (невалидный json в zk-стейте, или не можем подключиться)
            return
        elif self.state["is_master"]:
            self.logger.info("I'm master, open lfw")
            self.open_lfw()
        else:
            self.logger.info("I'm slave, close lfw")
            self.close_lfw()

        self.update_status_file()
        # при нормальной работе стейт регулярно обрабатывается, шлем ОК
        dj.send_ok(host=self.fqdn, service=self.JUGGLER_SERVICE_WORKING + f".{self.instance}",
                   description=self.JUGGLER_DESC_PREFIX)

    def parse_raw_state(self, raw_state):
        """
        :param raw_state: состояние из zk (bytes, '{"master_fqdn": "none", "zkguard_enabled": true}')
        :return: dict, {"is_master": False, "zkguard_enabled": True}
        """
        self.logger.info(f"parse raw state: {raw_state}")
        # дефолтный стейт - будет при невалидном конфиге или отсутствущей ноде, в process state все обработаем
        state = {k: v for k, v in self.DEFAULT_STATE.items() if k in self.ZK_STATE_KEYS}

        if raw_state is not None:  # state-нода существует
            try:
                raw_state = json.loads(raw_state)  # loads понимает как str, так и bytes из zk
                self.logger.info(f"current master fqdn: {raw_state['master_fqdn']}, my fqdn: {self.fqdn}")

                state["is_master"] = raw_state["master_fqdn"] == self.fqdn
                state["zkguard_enabled"] = raw_state["zkguard_enabled"]
            except (json.JSONDecodeError, KeyError) as e:
                self.logger.error(f"can't decode state: {type(e)} {e}")

        self.logger.info(f"parsed state: {state}")
        return state

    def state_watch(self, *args):
        try:
            self.logger.info(f"state watch: {args}")

            zdata, zstat = args[:2]
            self.put_event("zk_state_update", zdata=zdata, zstat=zstat, source="watcher")
        except Exception:
            # watcher'ы в отдельных тредах, если один из них вываливается по exception, никто об этом не узнает
            # восстановить тред можно только пересоздав zkclient - проще отключиться совсем
            # все равно с багами в watcher'ах работать нельзя
            harakiri()
            raise

    def connection_listener(self, new_state):
        try:
            self.logger.info(f"connection listener: {new_state}")
            self.logger.info("current connection state: %s, %s" % (self.zk.state, self.zk.client_state))
            if new_state == KazooState.CONNECTED and self.zk.client_state == KeeperState.CONNECTED:
                self.put_event("connection_established")
            elif new_state == KazooState.LOST:
                self.put_event("connection_lost")
            # еще может быть suspended или readonly - их не обрабатываем, активный fetch должен помочь
        except Exception:
            harakiri()
            raise

    def is_zk_quorum_reachable(self):
        # проверять через icmp не очень честно (и нет либы, придется дергать ping)
        # компромисс - tcp check на zk порт, если кластер развалился, коннекты все равно должны проходить

        check_results = {(h, p): [] for h, p in self.mymgr_conf.zk_hosts}
        for host, port in self.mymgr_conf.zk_hosts:
            # TODO: параллельно + несколько пингов
            check_results[(host, port)] = tcp_ping(host, port, timeout=self.mymgr_conf.zk_ping_timeout)

        ok_hosts = sum(check_results.values())
        quorum = int(len(self.mymgr_conf.zk_hosts) / 2) + 1

        result = False
        if ok_hosts >= quorum:
            result = True

        self.logger.info(f"zk quorum reachable: {result}; check results: {check_results}; "
                         f"ok hosts: {ok_hosts}; quorum: {quorum};")
        return result

    def exit_handler(self, signum, _):
        self.logger.info(f"got signal {signum}, exiting")
        self.put_event("exit")

    def load_state_from_status_file(self):
        self.logger.info("load state from status file")
        # дефолтный стейт - будет при невалидном конфиге или отсутствущей ноде
        state = {k: v for k, v in self.DEFAULT_STATE.items() if k in self.STORABLE_STATE_KEYS}
        if not os.path.isfile(self.mymgr_conf.guard_status_file_path):
            self.logger.info("no status file found")
            return state

        try:
            with open(self.mymgr_conf.guard_status_file_path, 'r') as f:
                status = json.load(f)
        except (json.JSONDecodeError, KeyError) as e:
            self.logger.error(f"can't load state: {type(e)} {e}")

        # стейт и статус-файл - не одно и то же, в стейт идут только определенные ключи
        state = {k: v for k, v in status.items() if k in self.STORABLE_STATE_KEYS}
        self.logger.info(f"loaded state: {state}")
        return state

    def update_status_file(self):
        self.logger.info("update status file")
        status = {k: v for k, v in self.state.items() if k in self.STORABLE_STATE_KEYS}
        status["last_close_master_time"] = self.last_close_master_time
        with open(self.mymgr_conf.guard_status_file_path, 'w') as f:
            # честную атомарную запись не пытаемся делать, ничего критичного при отложенной записи не будет
            json.dump(status, f)

    def close_master(self):
        self.logger.info("close master from applicatons")
        self.last_close_master_time = time.time()
        self.update_status_file()
        # до закрытия фаервола могут пройти секунды, если iptables тормозит
        self.close_lfw()
        self.mysql_killall()

    def close_lfw(self):
        cmd = [self.ext_tools["lfw"], self.instance, "-write"]
        self.run_cmd(cmd, sudo=True)

    def open_lfw(self):
        cmd = [self.ext_tools["lfw"], self.instance, "+write"]
        self.run_cmd(cmd, sudo=True)

    def mysql_killall(self):
        cmd = [self.ext_tools["lm"], self.instance, "killall"]
        self.run_cmd(cmd, sudo=True)

    def run_cmd(self, cmd, sudo=False):
        if sudo:
            cmd = [self.ext_tools["sudo"], *cmd]
        if not self.state["zkguard_enabled"]:
            self.logger.info(f"dry-run cmd: {cmd}")
            return

        self.logger.info(f"run cmd: {cmd}")
        try:
            out = check_output(cmd, stderr=STDOUT, shell=False, timeout=self.mymgr_conf.ext_tools_timeout)
        except CalledProcessError as e:
            self.logger.error(f"cmd failed, output: {e.output}")
            raise
        self.logger.info(f"cmd finished, output: {out}")


def harakiri():
    os.kill(os.getpid(), signal.SIGINT)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__)
    parser.add_argument("-c", "--config", default=DtMysqlFailoverManagerConfig.default_path,
                        help=DtMysqlFailoverManagerConfig.description)
    parser.add_argument("-i", "--instance", required=True, help="local mysql instance to guard")
    args = parser.parse_args()

    logger = init_logger()
    logger.info(f"running with args: {args}")

    g = DtMysqlGuard(DtMysqlFailoverManagerConfig(args.instance, args.config), logger=logger)
    g.run()
    logger.info("finished")


if __name__ == "__main__":
    main()
