#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket
import os
import sys
from rediscluster import StrictRedisCluster
import rediscluster
import logging
import redis
import json

db_config = '/etc/yandex-direct/db-config.json'

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'

def usage():
    print('''
Скрипт для забывания ноды в редис-кластере.
Может принимать ноды в каческтве аргументов или автоматически их определять при подключении к кластеру.
Использование:
    -h этот хелп

    -r Использовать для работы редис(иначе общение происходит по сокету) / менее надежно

    redis-forget.py [-r] [0e8de8f88ab4b6918a13e9aa9f1fc68d0b04717a [7e286ab539ec4c129df68047c8cc8a6355e05a1b]]
    ''')

#Получаем список хостов из конфига с базами
def getHostsFromDBConfig(filename):
    try:
        with open(filename, 'r') as f:
            data = json.load(f)
        host_port = {}
        result = []
        for host in data[u'db_config'][u'CHILDS'][u'redis'][u'host']:
            host = host.encode("UTF-8")
            host_port = [k for k in host.split(":")]
            result.append({"host": host_port[0], "port": host_port[1]})
    except (IOError, IndexError, ValueError) as e:
        print("Can't load/parse config\n"+str(e))
        sys.exit(1)
    print(len(result))
#    print(dir(redis.StrictRedis))
    return result

def send_command(func):
    try:
        func()
    except redis.exceptions.TimeoutError:
        print('Timeout')
    except redis.exceptions.BusyLoadingError:
        print('LOADING')
    except redis.exceptions.ConnectionError:
        print('Can\'t connect')
    except redis.exceptions.ResponseError:
        print(func.__name__ +' - failed')
    else:
        print(func.__name__ + ' - ok')


def ssend(node, command):
    try:
        print(node.get('host'),node.get('port'))
        conn = socket.socket(socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_IP)
        conn.settimeout(1)
        conn.connect((node.get('host'), int(node.get('port'))))
        conn.sendall(command+'\n')
        ans = conn.recv(1024)
        conn.close()
    except socket.timeout:
        print('timeout')
    except socket.error as e:
        print('socket.error', e)
    else:
        return repr(ans)

def remove_node_by_id(nodes, fail_nodes):
    for fail_node in fail_nodes:
        for node in nodes:
            print("try remove %s from %s" % (fail_node, node))
            print(ssend(node, "CLUSTER FORGET " + fail_node))

if __name__ == '__main__':
    args = sys.argv[1:]
    if ('-h' in args) or (args == []):
        usage();
        sys.exit(0)
    use_redis = False
    fail_nodes = args
    start_nodes = getHostsFromDBConfig(db_config)
    if '-r' in fail_nodes:
        use_redis = True
        fail_nodes.remove('-r')

    if fail_nodes == []:
        try:
            rc = StrictRedisCluster(startup_nodes=start_nodes, decode_responses=True, socket_timeout=1)
        except rediscluster.exceptions.ConnectionError as e:
            print('Can\'t connect to cluster')
            exit(1)
        else:
            clnod = rc.cluster_nodes()
        for node in clnod:
            if ('noaddr' in node.get('flags', '') and node.get('link-state', '') == 'disconnected'):
                fail_nodes.append(node.get('id'))

    if use_redis:
        for node in fail_nodes:
            send_command(rc.cluster_forget(node))
    else:
        remove_node_by_id(start_nodes, fail_nodes)
