#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
Файл с тасками fabric для автоматизации переключения мастер-базы в Директе
Для переключения использует существующий инструментарий:
 zk-db-config - для получения текущего мастера и записи нового
 clus - для получения списка хостов
 lm - вся работа с локальным инстансом mysql
 lfw - локальный файрвол

 штатное переключение:
 fab --disable-known-hosts --timeout=5 --command-timeout=60 --connection-attempts=2 --colorize-errors -f /etc/yandex-direct/fab_switch_master.py switch-master:ppc:1,ppcdata1-01t.ppc.yandex.ru,paranoid=False,parallel=True,ignore_bad_repl=False,wait_slaves=True,force_change_master=False,use_gtid=True

 paranoid - спрашивать подтверждения всех деструктивных действий
 parallel - параллельный режим ssh
 use_gtid - при возможности (на всех репликах правильные настройки), переключаться с master_auto_position=1
 sec_behind - игнорировать отставание реплик в пределах sec_behind

 Менять дефолтные значения с крайней осторожностью:
 ignore_bad_repl - игнорировать проблемы в репликации (почти никогда не должно быть нужно, кроме случаев, когда реплики пытаются реплицироваться с недоступного мастера)
 wait_slaves - ждать, пока слейвы догонят мастера
               --exclude old.master.host частично решает эту проблему, но полезно, если до мастера уже не достучаться по ssh, и gtid-реплики остановились на разных позициях.
               даже при wait_slaves=False скрипт попытается отследить неправильное переключение при разных позициях слейвов (и проверит, что везде включен gtid)
               сейчас слейвы не останавливают репликацию перед переключением, при проблемах в сети все равно можно получить странную конфигурацию =(
 force_change_master - не проверять, что позиция реплики совпадает с позицией старого мастера при переключении, нужно при переключении на наиболее свежую gtid-реплику
"""

import json, sys, socket, time, re

import fabric.state
from fabric.api import *
from fabric.utils import error
from fabric.tasks import Task
import fabric.contrib.console as console
import fabric.colors as colors
from fabric.exceptions import NetworkError

# Explicit remote login user override, currently disabled.
#env.user = 'cloud-user'
# Commands issued through sudo() in this file are run as root.
env.sudo_user = 'root'
# Hide fabric's "running" output level by default; the tasks below switch it
# back on with show('running') around the commands that should be echoed.
fabric.state.output.running = False


def sleep_gen(sleep_time=0.1, timeout=None):
    """
    Pause generator for polling loops.

    Yields None immediately and then again after every ``sleep_time`` seconds,
    so that ``for _ in sleep_gen(...)`` repeats its body at that interval.
    When ``timeout`` is given and more than ``timeout`` seconds have passed
    since the generator started, the next iteration calls abort("Timeout")
    instead of yielding.
    """
    began_at = time.time()
    has_deadline = timeout is not None
    while True:
        # The deadline is checked before each yield, i.e. before each loop body run.
        if has_deadline and (time.time() - began_at) > timeout:
            abort("Timeout")
        yield
        time.sleep(sleep_time)


class DirectDbInfo(object):
    """
    Information about one database: its merged db_config entry and its hosts.

    ``dbname`` is a colon-separated path into the db-config tree, optionally
    ending in a numeric shard (for example ``ppc:1``).  Constructing an
    instance runs the ``clus`` and ``zk-db-config`` commands locally.
    """
    CLUS = '/usr/local/bin/clus --config=/etc/clusrc.d/yandex-direct'
    ZK_DB_CONFIG = 'zk-db-config -c /etc/zk-delivery/ppc.cfg -n /direct/db-config.json'

    def __init__(self, dbname):
        self.dbname = dbname
        self.hosts = self._clus_hosts()
        self.db_config = self._db_config()


    def set_config_host(self, host):
        """
        Set ``host`` as the new host of this database in db_config.
        """
        with show('running'):
            local(" ".join(["sudo", self.ZK_DB_CONFIG, self.dbname, host]))


    def _db_config(self):
        """
        Read the full db-config JSON with zk-db-config and return the merged
        settings for ``self.dbname`` (see ``_resolve_db_config``).
        """
        full_db_config_text = local(self.ZK_DB_CONFIG, capture=True)
        return self._resolve_db_config(json.loads(full_db_config_text))


    def _resolve_db_config(self, full_db_config):
        """
        Merge the settings for ``self.dbname`` out of the parsed db-config tree.

        Starts from a copy of the top-level ``db_config`` dict and walks down
        through ``CHILDS`` following each colon-separated part of the name;
        every level's keys override those inherited from the levels above.
        If the final node still has ``CHILDS``, its ``'_'`` child is applied
        as well.  The result never contains ``CHILDS`` and always has
        ``dbname``, ``db`` (default: first part of the name) and ``instance``
        (default: ``db``).

        Raises ValueError when a part of the name is missing from the tree,
        or when the final node is not a leaf and has no ``'_'`` child.
        """
        cur_cfg = full_db_config['db_config']
        ret = cur_cfg.copy()
        for part in self.dbname.split(":"):
            childs = cur_cfg.get('CHILDS', {})
            if part not in childs:
                # A bare string cannot be raised; use a real exception type.
                raise ValueError("Can't find db config for %s" % self.dbname)
            cur_cfg = childs[part]
            ret.update(cur_cfg)

        if 'CHILDS' in cur_cfg:
            if '_' not in cur_cfg['CHILDS']:
                raise ValueError("Can't find db config for %s (not leaf)" % self.dbname)
            ret.update(cur_cfg['CHILDS']['_'])

        # ret may still carry a CHILDS subtree copied from an upper level.
        ret.pop('CHILDS', None)
        ret['dbname'] = self.dbname
        ret.setdefault('db', self.dbname.split(':')[0])
        ret.setdefault('instance', ret['db'])

        return ret


    def _clus_macro(self):
        """
        Build the clus macro name for ``self.dbname``.

        ``ppc`` maps to ``PPCDATA`` and ``monitor`` to ``PPCMONITOR``; any
        other database name is upper-cased.  A trailing ``:<digits>`` shard
        is appended to the macro as-is.
        """
        m = re.match(r'^(.+):(\d+)$', self.dbname)
        if m:
            db, shard = m.groups()
        else:
            db, shard = self.dbname, None

        special_macros = {'ppc': 'PPCDATA', 'monitor': 'PPCMONITOR'}
        macro = special_macros.get(db, db.upper())
        if shard is not None:
            macro += shard
        return macro


    def _clus_hosts(self):
        """
        Ask clus for the hosts of this database's macro and return their FQDNs.
        """
        with hide('everything'):
            ret = local(self.CLUS + " " + self._clus_macro() + ' --hosts', capture=True)
        return self._parse_host_list(ret)


    @staticmethod
    def _parse_host_list(text):
        """
        Turn a comma-separated host list into a list of FQDNs.

        Empty entries are skipped: socket.getfqdn('') resolves to the local
        host, so empty clus output would otherwise add the machine running
        fab to the list of database hosts.
        """
        names = [h.strip() for h in text.strip().split(',')]
        return [socket.getfqdn(h) for h in names if h]


class TaskBase(Task):
    """
    Base class for the tasks in this file; provides small console helpers.
    """
    # When True, confirm() asks the operator before continuing.
    paranoid = False

    def confirm(self, text):
        """
        In paranoid mode ask the operator to confirm ``text`` and abort on "no".

        Does nothing when ``self.paranoid`` is false.
        """
        if not self.paranoid:
            return
        with settings(parallel=False):
            approved = console.confirm(text)
            if not approved:
                abort("user interruption")


    def print_info(self, text):
        """Print ``text`` prefixed with a red " -> " marker (no newline is added)."""
        marker = colors.red(" -> ")
        fastprint(marker + text)


    def print_ok(self):
        """Print a green "ok" followed by a newline."""
        ok_text = colors.green("ok")
        fastprint(ok_text + "\n")


class StatusTask(TaskBase):
    """
    Print the lm and lfw status for one host or a list of hosts.
    """
    name = "status"

    def get_printable_status(self, instance):
        """
        Run on a single host; return the outputs of ``lm <instance> status``
        and ``lfw <instance>`` (the latter via sudo) as a pair.
        """
        lm_output = run("lm %s status" % instance)
        lfw_output = sudo("lfw %s" % instance)
        return lm_output, lfw_output

    def run(self, instance, hosts):
        """
        Collect the status from every host and print it in host order.

        For each host the lm status is printed in full; from the lfw output
        only the lines containing "write" are kept, joined without a separator.
        """
        self.print_info("current status:\n")
        with hide('commands'):
            per_host = execute(self.get_printable_status, instance, hosts=hosts)
        for _host, statuses in sorted(per_host.items()):
            lm_status, lfw_status = statuses
            write_lines = [line for line in lfw_status.splitlines() if "write" in line]
            fastprint("%s\n%s\n\n" % (lm_status, ''.join(write_lines)))


class SwitchMasterTask(TaskBase):
    """
    Switch the master of a database to another host.
    """
    name = "switch-master"
    new_master = None

    @runs_once
    def run(self, dbname, new_master, paranoid=False, parallel=True, ignore_bad_repl=False, wait_slaves=True, force_change_master=False, use_gtid=True, sec_behind=60):
        """
        Task entry point: move the master role of ``dbname`` to ``new_master``.

        Sequence: drop unreachable hosts, validate old/new master, check the
        replication state, close the old master, wait for (or just read) the
        slave positions, repoint the slaves, open the new master, and finally
        write the new host into db-config.

        Parameters:
          dbname              - database name as understood by DirectDbInfo
          new_master          - host that becomes the master (resolved to FQDN)
          paranoid            - ask for confirmation before each step
          parallel            - value passed to fabric's ``parallel`` setting
          ignore_bad_repl     - replication check failures only warn
          wait_slaves         - wait until slaves reach the target position
          force_change_master - use ``change-master --force`` on every slave
          use_gtid            - use the auto-position string when gtid_mode is
                                on for every host
          sec_behind          - pre-check requires Seconds_Behind_Master to be
                                strictly below this value
        """
        start_time = time.time()

        # Task arguments given on the fab command line arrive as strings, so a
        # flag is on only for the exact text 'True' (or the bool True).
        self.paranoid = str(paranoid) == 'True'
        self.parallel = str(parallel) == 'True'
        self.wait_slaves = str(wait_slaves) == 'True'
        self.ignore_bad_repl = str(ignore_bad_repl) == 'True'
        self.force_change_master = str(force_change_master) == 'True'
        self.use_gtid = str(use_gtid) == 'True'
        self.sec_behind = int(sec_behind)

        self.db_info = DirectDbInfo(dbname)
        self.instance = self.db_info.db_config['instance']

        # hosts excluded by hand via -x
        self.hosts = [h for h in self.db_info.hosts if h not in env.exclude_hosts]
        # additionally exclude unreachable hosts
        self.exclude_unreach_hosts()
        if env.exclude_hosts:
            self.print_info("Excluded hosts: %s\n" % (env.exclude_hosts,))

        self.new_master = socket.getfqdn(new_master)
        self.old_master = self.db_info.db_config['host']

        self.ignore_old_master = self.old_master in env.exclude_hosts
        self.old_slaves = [h for h in self.hosts if h != self.old_master]
        self.new_slaves = [h for h in self.hosts if h != self.new_master]

        if self.new_master not in self.hosts:
            abort("new master %s not in hosts list %s" % (self.new_master, str(self.hosts)))
        elif self.old_master not in self.hosts and not self.ignore_old_master:
            abort("old master %s not in hosts list %s" % (self.old_master, str(self.hosts)))
        elif self.new_master == self.old_master:
            abort("New master and old master are same server: %s" % (self.old_master))

        self.print_info("switch master for " + dbname + ": " + colors.red(self.old_master) + ' -> ' + colors.green(self.new_master) + "\n")
        self.print_info("  hosts: %s\n" % ','.join(self.hosts))

        with settings(parallel=self.parallel):
            execute('status', self.instance, self.hosts)
            self.check_rpl_status()

            downtime_start_time = time.time()
            self.confirm("close old master?")
            # always try to close the old master, even when it is in
            # exclude-hosts (in that case a failure here is not fatal)
            with settings(warn_only=self.ignore_old_master, skip_bad_hosts=self.ignore_old_master):
                execute(self.close_master, host=self.old_master)

            if self.wait_slaves:
                stat = self.wait_slaves_catch_up(self.old_slaves)
            else:
                stat = execute(self.get_status, hosts=self.hosts)

            new_master_pos = self.get_new_master_pos(stat)
            self.check_new_master_pos_is_most_recent(stat, self.old_slaves)
            execute('status', self.instance, self.hosts)

            self.confirm("change master to %s?" % (new_master_pos,))
            self.print_info("change master to %s\n" % new_master_pos)
            execute(self.change_master, stat, new_master_pos, hosts=self.new_slaves)

            self.confirm("open new master?")
            execute(self.open_master, host=self.new_master)

            self.confirm("change db-config?")
            self.db_info.set_config_host(self.new_master)

            self.print_info("total time: %.3f sec, downtime time %.3f sec" % (time.time() - start_time, time.time() - downtime_start_time))


    def exclude_unreach_hosts(self):
        """
        Probe every host in ``self.hosts`` and handle the unreachable ones.

        A host counts as unreachable when its probe result is a NetworkError.
        Outside paranoid mode this aborts with a hint to pass the hosts via
        --exclude-hosts.  In paranoid mode the operator is asked, then the
        hosts are appended to ``env.exclude_hosts`` and ``self.hosts`` is
        rebuilt without them.
        """
        with settings(parallel=True, warn_only=True, skip_bad_hosts=True), hide('everything'):
            stat = execute(self.get_hostname, hosts=self.hosts)
        # host -> error text, collected in one pass over the probe results
        err = {h: str(res) for h, res in stat.items() if isinstance(res, NetworkError)}
        # sorted so that the message and the exclude list are deterministic
        unreach = sorted(err)

        if unreach:
            if not self.paranoid:
                abort("some hosts are unreachable. You can exclude them via --exclude-hosts=" + ",".join(unreach))

            self.confirm("unreachable hosts: %s\nExclude?" % (json.dumps(err, indent=2, sort_keys=True),))
            env.exclude_hosts.extend(unreach)
            self.hosts = [h for h in self.db_info.hosts if h not in env.exclude_hosts]


    def get_new_master_pos(self, stat):
        """
        Return the position string the slaves should be pointed at.

        ``stat`` is the per-host status already collected for ``self.hosts``.
        With ``use_gtid`` and gtid_mode "on" on every host the new master's
        ``master_status['str_auto']`` is returned.  With ``use_gtid`` and
        gtid_mode on for the new master but not for all hosts the task
        aborts.  Otherwise ``master_status['str']`` is returned.
        """
        gtid_on = {h: (stat[h]['vars'].get('gtid_mode', 'off').lower() == 'on') for h in stat}

        if self.use_gtid:
            if all(gtid_on.values()):
                return stat[self.new_master]['master_status']['str_auto']
            if gtid_on[self.new_master]:
                abort("gtid_mode enabled on new master %s but disabled on replicas %s" % (self.new_master, str([ h for h in gtid_on if not gtid_on[h] ])))
        return stat[self.new_master]['master_status']['str']


    def check_rpl_status(self):
        """
        Validate the current replication layout before anything is changed.

        Replication problems (different Master_Host values, stopped threads,
        lag of ``sec_behind`` seconds or more) are reported through error(),
        which only warns when ``ignore_bad_repl`` is set.  A master with
        sql_log_bin other than "on" or non-unique server_id values always abort.
        """
        self.print_info("check current replication schema... ")

        stat = execute(self.get_status, hosts=self.hosts)

        slaves_stat = [stat[host]['slave_status'] for host in self.old_slaves]
        old_master_stat = stat.get(self.old_master, None)
        new_master_stat = stat[self.new_master]

        with settings(warn_only=self.ignore_bad_repl):
            if not all([s['Master_Host'] == new_master_stat['slave_status']['Master_Host'] for s in slaves_stat]):
                error("not all slaves replicated from same master %s" % self.old_master)
            if not all([s['Slave_SQL_Running'] == 'Yes' and s['Slave_IO_Running'] == 'Yes' for s in slaves_stat]):
                error("not all slaves has running replication")
                # Reached only in warn mode: give stopped slaves without a lag
                # value a lag of 0 so the numeric check below can run.
                for s in slaves_stat:
                    if not s.get('Seconds_Behind_Master') and (s['Slave_SQL_Running'] != 'Yes' or s['Slave_IO_Running'] != 'Yes'):
                        s['Seconds_Behind_Master'] = 0
            if not all([int(s['Seconds_Behind_Master']) < self.sec_behind for s in slaves_stat]):
                error("not all slaves have seconds_behind_master < " + str(self.sec_behind))

        if old_master_stat and old_master_stat['vars']['sql_log_bin'].lower() != 'on':
            abort("old master %s don't write binary logs" % self.old_master)
        if new_master_stat['vars']['sql_log_bin'].lower() != 'on':
            abort("new master %s don't write binary logs" % self.new_master)

        if len(set([s['vars']['server_id'] for s in stat.values()])) != len(stat):
            abort("server_id values is not unique")

        self.print_ok()


    def get_status(self):
        """Run ``lm <instance> status-json`` on the current host and return the parsed JSON."""
        with hide('everything'):
            json_status = run("lm %s status-json" % self.instance)
        ret = json.loads(json_status)
        return ret


    def get_hostname(self):
        """Run ``hostname`` on the current host; used as a reachability probe."""
        return run('hostname')


    def close_master(self):
        """
        Close the old master (runs on ``env.host``).

        Issues ``lfw <instance> -write`` and ``lm <instance> killall``, then
        polls the status every 0.2 s, for at most 10 s, until ``proc_num``
        is 0, and finally prints the host status.
        """
        self.print_info("close old master %s\n" % env.host)
        with show('running'), hide('stdout'):
            sudo("lfw %s -write" % self.instance)
            run("lm %s killall" % self.instance)

        self.print_info("wait for sessions death on old_master %s... " % env.host)
        for _ in sleep_gen(0.2, 10):
            r = execute(self.get_status, host=env.host)
            if r[env.host]['proc_num'] == 0:
                break
        self.print_ok()

        execute('status', self.instance, [env.host])


    def open_master(self):
        """
        Open the new master (runs on ``env.host``): issue
        ``lfw <instance> +write`` and print the status of all hosts.
        """
        self.print_info("open new master %s\n" % env.host)
        with show('running'), hide('stdout'):
            sudo("lfw %s +write" % self.instance)

        execute('status', self.instance, self.hosts)


    def change_master(self, stat, new_pos_str):
        """
        Point the current host (``env.host``) at the new master position.

        The old master itself, or every host when ``force_change_master`` is
        set, uses ``change-master --force``.  Other hosts use
        ``change-master-safe`` with an expected current position: the old
        master's ``master_status`` when it is available, otherwise the new
        master's ``slave_status``.
        """
        with show('running'), hide('stdout'):
            if env.host == self.old_master or self.force_change_master:
                run("lm %s change-master --force %s" % (self.instance, new_pos_str))
            elif not self.ignore_old_master:
                run("lm %s change-master-safe %s %s" % (self.instance, stat[self.old_master]['master_status']['str'], new_pos_str))
            else:
                run("lm %s change-master-safe %s %s" % (self.instance, stat[self.new_master]['slave_status']['str'], new_pos_str))


    def get_read_up_to_info(self, slave_status):
        """
        Return ``(binlog number, binlog position)`` for a slave status dict.

        The number is the numeric suffix after the last dot of
        ``Relay_Master_Log_File``; the position is ``Exec_Master_Log_Pos``.
        Aborts when either value cannot be converted to an integer.
        """
        read_up_to_bin = slave_status['Relay_Master_Log_File']
        try:
            # rsplit: the numeric suffix is the last component even when the
            # binlog base name itself contains dots
            read_up_to_bin_num = int(read_up_to_bin.rsplit('.', 1)[1])
            read_up_to_bin_pos = int(slave_status['Exec_Master_Log_Pos'])
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            abort("can't convert %s to integer binlog number and binlog pos" % (slave_status['str'],))
        return read_up_to_bin_num, read_up_to_bin_pos


    def check_new_master_pos_is_most_recent(self, stat, slaves):
        """
        Make sure no slave is ahead of the new master in the old master's binlog.

        Aborts when a slave's (binlog number, position) is greater than the
        new master's.  When ``use_gtid`` is off it also aborts for a slave
        that is behind the new master.
        """
        self.print_info("check if new master is on most recent pos of old master binlog ...\n")

        # The new master's position does not change inside the loop.
        new_master_status = stat[self.new_master]['slave_status']
        new_master_read_up_to = self.get_read_up_to_info(new_master_status)

        for slave in slaves:
            # (number, position) pairs are compared as tuples
            slave_read_up_to = self.get_read_up_to_info(stat[slave]['slave_status'])

            if slave_read_up_to > new_master_read_up_to:
                abort("slave %s pos (%s) is more recent than new master %s pos (%s)\nyou should restart fab with the most recent slave as new_master, and only with use_gtid=True" %
                        (slave, stat[slave]['slave_status']['str'], self.new_master, new_master_status['str']))
            elif slave_read_up_to < new_master_read_up_to and not self.use_gtid:
                abort("slave %s will not be able to sync with new master %s - exclude slave or enable gtid" % (slave, self.new_master))

        self.print_ok()


    def wait_slaves_catch_up(self, slaves):
        """
        Poll until every host in ``slaves`` has reached the target position.

        The target is the old master's ``master_status`` position, or the new
        master's ``slave_status`` position when the old master is ignored.
        Polls every 0.3 s for at most 300 s and prints the lagging hosts
        about every 5 s.  Returns the last collected status of ``self.hosts``.
        """
        self.print_info("wait while slaves catchs up...\n")

        print_time = time.time()
        for _ in sleep_gen(0.3, 300):
            stat = execute(self.get_status, hosts=self.hosts)

            # wait until all slaves catch up with the current master, or until
            # they reach the position of the new master
            if self.ignore_old_master:
                target_pos = stat[self.new_master]['slave_status']['str']
            else:
                target_pos = stat[self.old_master]['master_status']['str']
            wrong_slaves = [s for s in slaves if stat[s]['slave_status']['str'] != target_pos]

            if not wrong_slaves:
                break

            if time.time() - print_time > 5:
                wrong_slave_str = ', '.join(["%s - %s sec" % (s, str(stat[s]['slave_status']['Seconds_Behind_Master'])) for s in wrong_slaves])
                self.print_info("  wait for hosts: %s\n" % wrong_slave_str)
                print_time = time.time()

        self.print_ok()
        return stat


# Module-level task instances. The code above invokes the status task by its
# name through execute('status', ...).
# NOTE(review): presumably fab discovers Task-subclass instances bound at
# module level - confirm against the Fabric version in use.
smt_inst = SwitchMasterTask()
status_task_inst = StatusTask()

 
