#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Скрипт проверки состояния redis кластера. Умеет проверять состояние нод, связи слейв-мастер,
количество слейвов, нахождение мастеров на 1 ноде, разброс по портам.
"""

import argparse
import time
import socket
import sys
import re
import prettytable
import json
import redis
from rediscluster import StrictRedisCluster
import rediscluster
import subprocess

# Data-center name by one-letter code. With -dc, print_table takes the last
# character of the leading "name[-name]" token of a host's FQDN as the code.
DC = {
    "e": "IVA",
    "f": "MYT",
    "g": "FOL",
    "h": "UGR",
    "i": "SAS",
    "k": "MAN"
}

# Header of the first table column (host or DC name); print_table also sorts
# the printed table by this column.
TABLE_LEGEND = "dc\\slots"
# Column for nodes whose slot summary matches none of the cluster slot ranges.
TABLE_EMPTY = "Empty"


class BColors:
    """ANSI escape sequences used to color terminal output."""

    HEADER = '\x1b[95m'
    OKBLUE = '\x1b[94m'
    OKGREEN = '\x1b[92m'
    WARNING = '\x1b[93m'
    FAIL = '\x1b[91m'
    ENDC = '\x1b[0m'  # resets the color back to the terminal default


def get_hosts_from_dbconfig(filename, section_name):
    """
    Load the startup node list for `section_name` from the JSON db config.

    Every "host:port" string found at db_config -> CHILDS -> <section_name> -> host
    becomes a {"host": ..., "port": ...} dict (both values are strings). The
    port is split off at the LAST colon, so host names containing ":" keep
    their full value.

    Prints a message and exits with status 1 if the file can't be read or
    parsed, or if the expected keys are missing.
    """
    try:
        with open(filename, 'r') as f:
            hosts = json.load(f)[u'db_config'][u'CHILDS'][section_name][u'host']
        return [
            dict(zip(["host", "port"], host_with_port.rsplit(":", 1)))
            for host_with_port in hosts
        ]

    # KeyError: missing section/key (e.g. a wrong -r value);
    # TypeError/AttributeError: unexpected JSON structure.
    except (IOError, IndexError, KeyError, TypeError, AttributeError, ValueError) as e:
        print("Can't load/parse config\n%s %s" % (e, type(e)))
        sys.exit(1)


def get_fqdn_from_id(nodes, node_id):
    """
    Return (fqdn, port_as_string) for the node whose id is `node_id`.

    Prints a message and exits with status 1 when the id is not in `nodes`.
    """
    if node_id not in nodes:
        print("Can't determine unknown node %s. Exit" % node_id)
        sys.exit(1)

    node = nodes[node_id]
    return socket.getfqdn(node["host"]), str(node["port"])


def get_slots_summary(nodes, node_id):
    """
    Describe the slots served by a node as "first-last".

    Returns False when `node_id` is not in `nodes`, and an empty list when the
    node serves no slots.
    """
    if node_id not in nodes:
        return False

    served = nodes[node_id].get("slots")
    try:
        bounds = (served[0], served[-1])
    except IndexError:
        return []
    return "-".join([str(bounds[0]), str(bounds[1])])


def check_diff_masters(nodes):
    """
    Check master placement across hosts.

    Reports nodes that live on the same host (by FQDN) as their master, and
    hosts that carry more than one node with the 'master' flag.

    :return: (message, failed). The message lists the sorted problems joined by
        ";\\n", or the OK text when there are none; failed is True if any problem
        was found.
    """
    masters_by_host = {}
    problems = []

    for node_id in nodes:
        node = nodes[node_id]
        own_host, own_port = get_fqdn_from_id(nodes, node_id)

        master_id = node.get('master', '')
        if master_id and own_host == get_fqdn_from_id(nodes, master_id)[0]:
            problems.append("%s are on the same host with its master" % ":".join((own_host, own_port)))

        if u'master' in node["flags"]:
            masters_by_host.setdefault(node["host"], []).append(node["port"])

    for host, ports in masters_by_host.items():
        if len(ports) > 1:
            problems.append("Masters %s are placed on the same host %s" % (
                ", ".join(":" + str(port) for port in ports),
                socket.getfqdn(host),
            ))

    if problems:
        summary = ";\n".join(sorted(problems))
    else:
        summary = "Masters position - OK\nSlaves position - OK"
    return summary, bool(problems)


def check_ports(nodes):
    """
    Check that each non-master node listens on the same port as its master.

    :return: (message, failed). Mismatches are joined by "; ", or "Ports - OK"
        when there are none. A node without a known master is reported with
        "-" as its master's address.
    """
    mismatches = []

    for node in nodes.values():
        if u'master' in node["flags"]:
            continue

        master = nodes.get(node.get("master", ""), {})
        if node.get("port") == master.get("port"):
            continue

        node_address = "%s:%s" % (socket.getfqdn(node["host"]), str(node["port"]))
        master_address = "%s:%s" % (
            socket.getfqdn(master.get("host", "-")),
            str(master.get("port", "-")),
        )
        mismatches.append("node %s is from different master, now on %s" % (node_address, master_address))

    message = "; ".join(mismatches) if mismatches else "Ports - OK"
    return message, bool(mismatches)


def check_slots(slot_state, min_slaves_count=2):
    """
    Check that every slot range has exactly `min_slaves_count` replicas.

    :param slot_state: mapping (start, end) -> {"master": (host, port), "slaves": [...]}.
    :return: (message, failed). Each offending range produces one line saying
        whether it has more or fewer replicas than expected; "Slots - OK" otherwise.
    """
    problems = []

    for slot_range, info in slot_state.items():
        slaves_count = len(info['slaves'])
        if slaves_count == min_slaves_count:
            continue

        master = socket.getfqdn(info["master"][0]) + ":" + str(info["master"][1])
        direction = "more" if slaves_count > min_slaves_count else "less"
        problems.append("%s master(%s) has %s than %d slaves (%d)" % (
            "-".join(str(bound) for bound in slot_range),
            master,
            direction,
            min_slaves_count,
            slaves_count,
        ))

    if not problems:
        return "Slots - OK", False
    return "\n".join(problems), True


def check_failed_nodes(nodes):
    """
    List nodes carrying the 'fail' flag.

    :return: (message, failed). The message is "fail nodes:" followed by one
        indented "id(fqdn:port)" entry per failed node, or "Nodes - ok".
    """
    failed_entries = [
        "%s(%s:%s)" % (node["id"].encode("UTF-8"), socket.getfqdn(node["host"]), node["port"])
        for node in nodes.values()
        if u'fail' in node["flags"]
    ]

    if not failed_entries:
        return "Nodes - ok", False
    return "\n    ".join(["fail nodes:"] + failed_entries), True


def make_node_matrix(nodes):
    """
    Group node dicts by their "host" value.

    :return: {host: [node, ...]}, with nodes kept in the iteration order of `nodes`.
    """
    matrix = {}
    for node in nodes.values():
        matrix.setdefault(node["host"], []).append(node)
    return matrix


def get_nodes_info(nodes):
    """
    Describe every cluster node and report whether any node is failed.

    :param nodes: node dicts keyed by node id.
    :return: (info_by_id, any_failed). info_by_id maps node["id"] encoded as
        UTF-8 to [(fqdn, port), link state, slots, role text, 'good' | 'fail'],
        or to [] when the node matches none of the recognised roles.
        any_failed is True if some node carries the 'fail' flag.

    Side effect: node["slots_summary"] is set to the slots item of that list
    (or "" for an empty list); build_table reads it later.
    """
    def describe(node, state, slots, role, health):
        return [(socket.getfqdn(node["host"]), node["port"]), state, slots, role, health]

    info_by_id = {}
    any_failed = False

    for node_id, node in nodes.items():
        flags = node["flags"]
        state = 'ON' if node["link-state"].encode("UTF-8") == "connected" else 'Maybe shutdown'

        is_failed = u'fail' in flags
        if is_failed:
            any_failed = True
        health = 'fail' if is_failed else 'good'

        info = []
        if len(node.get("slots")) and u'master' in flags:
            info = describe(node, state, get_slots_summary(nodes, node_id), "Master", health)
        else:
            # Slot range of this node's master: a "first-last" string, [] when
            # the master serves no slots, False when the master is unknown.
            master_summary = get_slots_summary(nodes, node["master"]) if u'slave' in flags else False

            if master_summary:
                role_prefix = " last seen as slave of:" if is_failed else " slave of:"
                info = describe(
                    node, state, master_summary,
                    role_prefix + str(get_fqdn_from_id(nodes, node["master"])),
                    health
                )
            elif master_summary == []:
                info = describe(
                    node, state, node["slots"],
                    "Slave without slots/Dead Master " + str(get_fqdn_from_id(nodes, node["master"])),
                    health
                )
            elif u'master' in flags:
                info = describe(node, state, node["slots"], "Master without slots", health)

        node["slots_summary"] = info[2] if info else ""
        info_by_id[node["id"].encode("UTF-8")] = info

    return info_by_id, any_failed


def remove_dead_nodes_on_startup(startup_nodes):
    """
    Ping every startup node and keep only the ones that answer PONG.

    :param startup_nodes: list of {"host": str, "port": str} dicts.
    :return: (nodes, busy_nodes).
        nodes: the startup nodes that answered "+PONG". If none did, the
            original list is returned so the cluster client still has
            addresses to try and reports the connection failure itself.
        busy_nodes: maps "host:port" to 'LOAD' for nodes that are still
            loading their dataset, or to the raw reply / socket error for
            other non-PONG answers.
    """
    busy_nodes = {}
    result_nodes = []

    for node in startup_nodes:
        ans = ssend(node, 'PING\n')
        if ans == "+PONG\r\n":
            result_nodes.append(node)
            continue

        address = node.get('host') + ":" + node.get('port')
        if ans == '-LOADING Redis is loading the dataset in memory\r\n':
            busy_nodes[address] = 'LOAD'
            print(address + ' - LOADING')
        else:
            busy_nodes[address] = ans

    return result_nodes or startup_nodes, busy_nodes


def ssend(node, command, timeout=1):
    """
    Send a raw inline `command` to node["host"]:node["port"] over an IPv6 TCP
    socket and return up to 1024 bytes of the reply.

    A newline is appended to `command` before sending. Socket errors
    (including timeouts) are returned as the exception object instead of
    being raised. The socket is always closed, also on errors.

    :param node: dict with "host" and "port" (port may be a string).
    :param timeout: socket timeout in seconds.
    """
    try:
        conn = socket.socket(socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_IP)
    except socket.error as e:
        return e

    try:
        conn.settimeout(timeout)
        conn.connect((node.get('host'), int(node.get('port'))))
        conn.sendall(command + '\n')
        return conn.recv(1024)
    except (socket.timeout, socket.error) as e:
        return e
    finally:
        conn.close()


def build_table(clnod, clslot, busy_nodes):
    """
    Build the per-host table of nodes grouped by slot range.

    :param clnod: node dicts keyed by id; each must already carry
        "slots_summary" (set by get_nodes_info).
    :param clslot: slot map keyed by (start, end) tuples.
    :param busy_nodes: "fqdn:port" -> status from remove_dead_nodes_on_startup.
    :return: {'table': [row, ...], 'columns': [TABLE_LEGEND, "start-end"..., TABLE_EMPTY]}.
        Each row maps every column name to a list of node-info dicts, except
        TABLE_LEGEND, which maps to [host fqdn]. Rows are ordered by host.
    """
    columns = [TABLE_LEGEND]
    for slot_range in sorted(clslot):
        columns.append(str(slot_range[0]) + "-" + str(slot_range[1]))
    columns.append(TABLE_EMPTY)

    matrix = make_node_matrix(clnod)
    rows = []

    for host in sorted(matrix):
        instances = matrix[host]
        cells = dict((name, []) for name in columns)

        for inst in instances:
            address = ':'.join([socket.getfqdn(inst.get('host')), str(inst.get('port'))])
            summary = inst["slots_summary"]
            # Nodes whose summary is not a known slot range go to the Empty column.
            column_name = summary if summary in columns else TABLE_EMPTY
            cells[column_name].append({
                'is_master': 'master' in inst["flags"],
                'is_dead': 'fail' in inst["flags"],
                'is_loading': busy_nodes.get(address) == 'LOAD',
                'id': inst["id"],
                'host': inst["host"],
                'port': inst["port"],
                'slots_summary': column_name,
            })

        cells[TABLE_LEGEND] = [socket.getfqdn(instances[0]["host"].encode("UTF-8"))]
        rows.append(cells)

    return {'table': rows, 'columns': columns}

def _dc_name_for_host(host_name):
    """
    Map a host name to a DC name, falling back to the host name itself.

    The DC code is the last character of the leading "name[-name]" token of
    the host name; it is looked up in DC. If the regex does not match or the
    code is unknown, `host_name` is returned unchanged.
    """
    match = re.match(r"\A\w+-*\w+", host_name)
    if match is None:
        return host_name
    return DC.get(match.group(0)[-1], host_name)


def _format_table_node(node, skull):
    """Render one node for a table cell: '*' if master, skull if dead, '(L)' if loading, then the port."""
    return "%s%s%s%s" % (
        "*" if node['is_master'] else '',
        skull if node['is_dead'] else '',
        '(L)' if node['is_loading'] else '',
        str(node['port'])
    )


def print_table(table, show_dc=False, show_skull=False):
    """
    Print the cluster table built by build_table, followed by a legend.

    :param table: {'table': rows, 'columns': column names} from build_table.
    :param show_dc: label rows with DC names instead of host FQDNs (hosts
        whose DC can't be determined keep their FQDN).
    :param show_skull: mark dead nodes with a skull character instead of "X_X ".
    """
    skull = u"\u2620 " if show_skull else "X_X "

    columns = table['columns']
    table_result = prettytable.PrettyTable(columns)

    for host_row in table['table']:
        host_name = host_row[TABLE_LEGEND][0]
        row = [_dc_name_for_host(host_name) if show_dc else host_name]

        for column_name in columns:
            if column_name == TABLE_LEGEND:
                continue
            row.append(",".join(
                _format_table_node(node, skull) for node in host_row.get(column_name, [])
            ))

        table_result.add_row(row)

    print(table_result.get_string(sortby=TABLE_LEGEND))

    legend = skull + "- dead\n* - master\n(L) - Dataset loading in memory"
    print(legend)


def get_slots_list(table):
    """Return the table's slot-range column names (all columns except the legend and Empty ones), in order."""
    service_columns = (TABLE_LEGEND, TABLE_EMPTY)
    return list(filter(lambda name: name not in service_columns, table['columns']))


def find_best_masters_distribution(table):
    """
    Choose, for every slot range, the host row that should hold its master.

    Every assignment of slot ranges (in get_slots_list order) to rows of
    table['table'] is tried. The chosen assignment maximizes, in this order:
      1. the number of ranges whose chosen cell has an alive node,
      2. the number of distinct host rows used,
      3. the number of ranges whose chosen cell already has an alive master.
    Ties keep the first assignment in lexicographic order.

    Note: the search is O(rows ** ranges), so it is only practical for small clusters.

    :return: list with one host-row index per slot range; [] when there is
        nothing to assign (no slot ranges or no host rows).
    """
    import itertools

    slots_list = get_slots_list(table)
    hosts = table['table']

    best_distrib = []
    best_score = None

    # product() yields assignments in lexicographic order, with the first slot
    # range varying slowest.
    for cur_distrib in itertools.product(range(len(hosts)), repeat=len(slots_list)):
        cur_masters = 0
        cur_alive_nodes = 0
        for master_idx, slots in zip(cur_distrib, slots_list):
            alive_nodes = [node for node in hosts[master_idx][slots] if not node['is_dead']]
            cur_alive_nodes += int(bool(alive_nodes))
            cur_masters += int(any(node['is_master'] for node in alive_nodes))

        # Tuples compare lexicographically, matching the priority order above.
        cur_score = (cur_alive_nodes, len(set(cur_distrib)), cur_masters)
        if best_score is None or cur_score > best_score:
            best_score = cur_score
            best_distrib = list(cur_distrib)

    return best_distrib


def make_failover_commands(table, distrib):
    """
    Build the commands that put each slot range's master on the host row
    chosen in `distrib`. `table` is updated in place to show the result.

    :param table: table dict from build_table, normally already processed by
        move_free_nodes. Node dicts in it are mutated: their 'is_master'
        flags are rewritten.
    :param distrib: one index into table['table'] per slot range, in
        get_slots_list(table) order (see find_best_masters_distribution).
    :return: list of argv lists. Most are ``redis-cli ... cluster
        replicate|failover`` commands; an ``echo`` entry is emitted for a
        situation that needs manual intervention.
    """
    slots_list = get_slots_list(table)
    cmds = []
    for slots, master_idx in zip(slots_list, distrib):
        # Alive master of this range already sitting on the chosen host, if any.
        master_node = {}
        for node in table['table'][master_idx][slots]:
            if node['is_master'] and not node['is_dead']:
                master_node = node
                break

        # Remember the (last found) alive master of this range on any host and
        # clear the master flag of every alive master; the flag is set again
        # below for the chosen node.
        prev_master_node = {}
        for host in table['table']:
            for node in host[slots]:
                if node['is_master'] and not node['is_dead']:
                    prev_master_node = node
                    node['is_master'] = False

        # Mark the chosen master: the alive master found above, or, if the
        # chosen host had none, its first alive node.
        for node in table['table'][master_idx][slots]:
            if master_node and node['id'] == master_node['id'] or not master_node and not node['is_dead']:
                node['is_master'] = True
                master_node = node
                break

        if master_node.get('is_moved') and not prev_master_node:
            # The chosen node was relocated by move_free_nodes, but no alive
            # master exists for this range to replicate from, so only a note
            # for the operator is emitted.
            cmds.append([
                'echo',
                '"need to manually make the node a master because all nodes for slots (%s) are dead: %s:%d"' % (
                    slots, socket.getfqdn(master_node["host"]), master_node['port']
                )
            ])
        elif master_node:
            if master_node.get('is_moved'):
                # A relocated node first becomes a replica of the current
                # master of this range.
                cmds.append([
                    'redis-cli', '-h', socket.getfqdn(master_node["host"]), '-p', str(master_node['port']),
                    'cluster', 'replicate', prev_master_node['id']
                ])
            if prev_master_node.get('id') != master_node['id']:
                # Ask the chosen node to take over as master of the range.
                cmds.append([
                    'redis-cli', '-h', socket.getfqdn(master_node["host"]), '-p', str(master_node['port']),
                    'cluster', 'failover'
                ])

        # Every other relocated node in this range replicates the chosen master.
        for host in table['table']:
            for node in host[slots]:
                if node.get('is_moved') and master_node and node['id'] != master_node['id']:
                    cmds.append([
                        'redis-cli', '-h', socket.getfqdn(node["host"]), '-p', str(node['port']),
                        'cluster', 'replicate', master_node['id']
                    ])

    return cmds


def move_free_nodes(table):
    """
    Redistribute alive replica nodes across slot ranges, in place, host by host.

    For each host row, only these alive nodes stay in their cell:
      - alive masters;
      - in a range that is currently among the least-covered ones and has no
        alive master on this host, the first alive replica.
    All other alive non-master nodes, including those in the TABLE_EMPTY
    column, are marked ``is_moved = True`` and "freed". Freed nodes first fill
    least-covered ranges that have no alive node on this host. Any remaining
    freed nodes are spread round-robin over those ranges. Dead nodes always
    stay where they are.

    "Least-covered" is judged by the number of alive nodes already assigned
    to each range on previously processed hosts (ties broken by column order).

    Does nothing when the table has no slot-range columns.
    """
    slots_list = get_slots_list(table)
    if not slots_list:
        # With no slot ranges there is nowhere to put freed nodes: min_slots
        # below would be empty and indexing it would fail.
        return

    slots_order = {slots: idx for idx, slots in enumerate(slots_list)}
    slots_count = {slots: 0 for slots in slots_list}

    # First distribute nodes so that the minimum number of nodes per slot range is as large as possible
    for host in table['table']:
        nodes_alive = 0
        for slots in table['columns']:
            if slots == TABLE_LEGEND:
                continue
            for node in host[slots]:
                nodes_alive += int(not node['is_dead'])

        if not nodes_alive:
            continue

        # Take the slot ranges that currently have the fewest nodes
        min_slots = sorted(
            slots_count.keys(),
            key=lambda slots: (slots_count[slots], slots_order[slots])
        )[:min(len(slots_list), nodes_alive)]
        min_slots_set = set(min_slots)

        free_nodes = []
        for slots in table['columns']:
            if slots == TABLE_LEGEND:
                continue

            good_nodes = []
            # First check whether this cell already has an alive master
            has_alive_node = False
            for node in host[slots]:
                if not node['is_dead'] and node['is_master']:
                    has_alive_node = True
                    break

            # Move surplus nodes (alive non-masters), provided at least one alive node stays in the cell
            for node in host[slots]:
                if (not node['is_dead'] and not node['is_master'] and
                        (slots == TABLE_EMPTY or has_alive_node or slots not in min_slots_set)):
                    node['is_moved'] = True
                    free_nodes.append(node)
                else:
                    good_nodes.append(node)

                if not node['is_dead']:
                    has_alive_node = True

            host[slots] = good_nodes

        # Spread the "free" nodes over the slot ranges that need them
        free_idx = 0
        for slots in min_slots:
            if free_idx >= len(free_nodes):
                break

            skip = False
            for node in host[slots]:
                if not node['is_dead']:
                    skip = True
                    break

            if skip:
                continue

            host[slots].append(free_nodes[free_idx])
            free_idx += 1

        # Remaining free nodes go round-robin over the least-covered ranges
        slots_idx = 0
        for i in xrange(free_idx, len(free_nodes)):
            host[min_slots[slots_idx]].append(free_nodes[i])
            slots_idx = (slots_idx + 1) % len(min_slots)

        for slots in slots_list:
            for node in host[slots]:
                if not node['is_dead']:
                    slots_count[slots] += 1


def make_failover(table):
    """
    Plan a failover on `table` (mutated in place) and print the resulting commands.

    :return: (table, commands), where commands is a list of argv lists.
    """
    move_free_nodes(table)
    commands = make_failover_commands(table, find_best_masters_distribution(table))

    print("")
    if commands:
        print("Failover commands:\n%s" % "\n".join(" ".join(command) for command in commands))
    else:
        print("No failover commands (everything is OK or algorithm can't improve the situation)")
    return table, commands


def my_tee(cmd):
    """
    Run `cmd`, echoing each output line as it arrives, and return the lines.

    stderr is merged into stdout. This way error text is shown and returned,
    and an unread stderr pipe can never fill up and block the child process.
    Trailing whitespace is stripped from each line.

    Exits the script with "Command failed, stop" if the command returns a
    non-zero status.
    """
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        close_fds=True,
        universal_newlines=True,
    )
    lines = []
    for line in iter(proc.stdout.readline, ''):
        line = line.rstrip()
        lines.append(line)
        print(line)
    proc.stdout.close()
    proc.wait()
    if proc.returncode != 0:
        sys.exit("Command failed, stop")
    return lines


def _is_error_line(line):
    """Return True if a line of redis-cli output looks like an error reply ("(error) ...", "ERR ...")."""
    text = line.strip().upper()
    return "ERROR" in text or text.startswith("ERR ")


def run_cmds_failover(cmds):
    """
    Run the failover commands one by one, pausing `delay` seconds between them.

    Stops the script with "Command failed, stop" when a command exits
    non-zero (checked by my_tee) or when its output contains an error reply.
    redis-cli can exit with status 0 while printing an error, so the output
    is inspected too; the match is case-insensitive.
    """
    print("\nStarting failover...")
    delay = 3
    for idx, cmd in enumerate(cmds):
        print("Running %s" % (cmd,))
        output = my_tee(cmd)

        if any(_is_error_line(line) for line in output):
            sys.exit("Command failed, stop")

        if idx + 1 != len(cmds):
            print("waiting for %d seconds (cause redis is too slow...)" % delay)
            time.sleep(delay)


def run_checks(clnod, clslot, section_name, skip_masters_check=False, ports_verify=False, verbose=False,
               verbose_ext=False):
    """
    Run the cluster checks and print the result.

    When every check passes, "0;OK;..." is printed. When any check fails,
    the individual check messages are printed if verbose or verbose_ext is
    set; otherwise a single "2;...; cluster state wrong; ..." monitoring
    line is printed.

    With verbose_ext, the info of every node is printed as well: failed nodes
    in red, healthy "Master" entries uncolored, all others in blue.

    get_nodes_info is always called because it fills node["slots_summary"],
    which build_table reads later.
    """
    fail_result = False

    clnod_info, fail = get_nodes_info(clnod)
    fail_result |= fail

    slots_status, fail = check_slots(clslot)
    fail_result |= fail

    nodes_status, fail = check_failed_nodes(clnod)
    fail_result |= fail

    if not skip_masters_check:
        masters_status, fail = check_diff_masters(clnod)
        fail_result |= fail
    else:
        masters_status = ''

    if ports_verify:
        ports_status, fail = check_ports(clnod)
        fail_result |= fail
    else:
        ports_status = ''

    if fail_result:
        if verbose or verbose_ext:
            print("\n".join(status for status in [masters_status, slots_status, nodes_status, ports_status] if status))
        else:
            print("2;from:redis-check.py -r %s; cluster state wrong; more info: redis-check.py -r %s -vv -t" % (
                section_name, section_name
            ))
    else:
        print("0;OK;from:redis-check.py -r %s; more info: redis-check.py -r %s -vv -t" % (
            section_name, section_name
        ))

    if verbose_ext:
        # Nodes matching no role in get_nodes_info have an empty info list.
        # Sorting by a one-element slice (rather than info[0]) keeps them
        # from raising IndexError; they sort first.
        for node_id, node_info in sorted(clnod_info.items(), key=lambda item: item[1][:1]):
            if 'fail' in node_info:
                print("%s%s: %s%s" % (BColors.FAIL, node_id, node_info, BColors.ENDC))
            elif 'Master' in node_info:
                print("%s: %s" % (node_id, node_info))
            else:
                print("%s%s: %s%s" % (BColors.OKBLUE, node_id, node_info, BColors.ENDC))


def get_redis_cluster(startup_nodes, section_name):
    """
    Connect to the cluster and fetch its node and slot maps.

    :return: (nodes_by_id, slots). nodes_by_id comes from cluster_nodes(),
        keyed by node id, with entries whose id is "0" skipped. slots comes
        from cluster_slots().

    Two attempts are made, one second apart. If both fail, a monitoring
    status line is printed and the script exits with status 1.
    """
    attempts = 2
    clnod = None
    clslot = None
    attempt = 0

    while attempt < attempts:
        attempt += 1
        try:
            client = StrictRedisCluster(startup_nodes=startup_nodes, decode_responses=True, socket_timeout=1)
            clnod = dict((node["id"], node) for node in client.cluster_nodes() if node["id"] != "0")
            clslot = client.cluster_slots()
        except (rediscluster.exceptions.RedisClusterException, redis.exceptions.ConnectionError):
            if attempt == attempts:
                print("2;from:redis-check.py -r %s; Can't connect to cluster. Use %s for info about nodes" % (
                    section_name,
                    "redis-per-host-check.py"
                ))
                sys.exit(1)
            time.sleep(1)
        else:
            break

    return clnod, clslot


def parse_options():
    """
    Parse the command-line options.

    Exits with "There are unknown parameters" if any argument is not recognised.
    """
    parser = argparse.ArgumentParser(description=__doc__.decode('utf-8'))

    # (option strings, add_argument keyword arguments), in help-output order.
    option_specs = [
        (('-c',), dict(
            dest='db_config', type=str,
            help=u'use different config (default "/etc/yandex-direct/db-config.json")',
            default='/etc/yandex-direct/db-config.json',
        )),
        (('-r',), dict(
            dest='section_name', type=str,
            help=u'use different section name (default "redis")',
            default='redis',
        )),
        (('-v',), dict(dest='verbose', action='store_true', help='verbose mode')),
        (('-vv',), dict(dest='verbose_ext', action='store_true', help='more verbose (info about every node)')),
        (('-t',), dict(dest='table', action='store_true', help='show table')),
        (('-p',), dict(dest='ports_verify', action='store_true', help='use check ports verify')),
        (('-s',), dict(dest='skip_masters_check', action='store_true', help='skip masters check (standalone redis)')),
        (('-dc',), dict(dest='show_dc', action='store_true', help='show dc names instead of fqdns in table (-t)')),
        (('-sk',), dict(action='store_true', help=argparse.SUPPRESS)),
        (('--repair',), dict(
            dest='repair', action='store_true',
            help='print commands to repair cluster (add --apply to run commands)',
        )),
        (('--apply',), dict(dest='apply', action='store_true', help='run repair commands')),
    ]
    for option_strings, options in option_specs:
        parser.add_argument(*option_strings, **options)

    opts, unknown = parser.parse_known_args()
    if unknown:
        sys.exit("There are unknown parameters")

    return opts


def main():
    """
    Script entry point.

    Loads the startup nodes, runs the checks, and optionally prints the table
    and/or plans (and with --apply, runs) a repair.
    """
    opts = parse_options()
    section = opts.section_name

    configured_nodes = get_hosts_from_dbconfig(opts.db_config, section)
    alive_nodes, busy_nodes = remove_dead_nodes_on_startup(configured_nodes)

    clnod, clslot = get_redis_cluster(alive_nodes, section)

    run_checks(
        clnod, clslot, section,
        skip_masters_check=opts.skip_masters_check,
        ports_verify=opts.ports_verify,
        verbose=opts.verbose,
        verbose_ext=opts.verbose_ext
    )

    if not (opts.table or opts.repair):
        return

    table = build_table(clnod, clslot, busy_nodes)
    if opts.table:
        print_table(table, show_dc=opts.show_dc, show_skull=opts.sk)

    if not opts.repair:
        return

    table_after_failover, cmds_failover = make_failover(table)
    if not cmds_failover:
        return

    print("After failover table should look like this:")
    print_table(table_after_failover, show_dc=opts.show_dc, show_skull=opts.sk)
    if opts.apply:
        run_cmds_failover(cmds_failover)


# Run the checks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
