#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import socket
import os
import sys
import copy
import re
import prettytable
import json
from rediscluster import StrictRedisCluster
import rediscluster
import logging
import redis

# Default path to the JSON DB config that lists the redis hosts; it can be
# overridden on the command line with "-c <file>".
db_config = '/etc/yandex-direct/db-config.json'

class bcolors:
    """ANSI terminal escape sequences for coloured console output."""

    # Bright foreground colours (SGR codes 9x).
    HEADER, OKBLUE, OKGREEN = '\033[95m', '\033[94m', '\033[92m'
    WARNING, FAIL = '\033[93m', '\033[91m'
    # SGR 0: reset every attribute back to the terminal default.
    ENDC = '\033[0m'


def usage():
    """Print the command-line help and exit the process with status 0.

    The help lists every flag the entry point understands, including
    ``-c <file>``, which was previously parsed but undocumented.
    """
    print("""Тулза, проверяющая каждую ноду редиса. 
    -r использовать протокол редиса
    -s подключаться напрямую через сокет
    -i вывод полной информации о кластере
    -c <файл> путь к json-конфигу с хостами редиса (по умолчанию /etc/yandex-direct/db-config.json)
    -h этот хелп""")
    sys.exit(0)

def getHostsFromDBConfig(filename):
    """Load the list of redis nodes from a JSON DB config file.

    Hosts are read from ``data['db_config']['CHILDS']['redis']['host']``,
    which must be a list of ``"host:port"`` strings. The port is split off at
    the LAST colon, so bare IPv6 addresses (``2a02:6b8::1:6379``) and the
    bracketed form (``[::1]:6379``) are both accepted.

    :param filename: path to the JSON config.
    :returns: list of ``{"host": str, "port": str}`` dicts in config order.
    On any read/parse/layout error prints a message and exits with status 1.
    """
    try:
        with open(filename, 'r') as f:
            data = json.load(f)
        result = []
        for entry in data[u'db_config'][u'CHILDS'][u'redis'][u'host']:
            # Python 2's json yields unicode: convert it to the native str
            # type. On Python 3 the value is already str and must NOT be
            # encoded (bytes.split(":") would raise TypeError).
            if not isinstance(entry, str):
                entry = entry.encode("UTF-8")
            host, port = entry.rsplit(":", 1)
            if host.startswith("[") and host.endswith("]"):
                host = host[1:-1]
            result.append({"host": host, "port": port})
    except (IOError, IndexError, KeyError, TypeError, AttributeError,
            ValueError) as e:
        # KeyError/TypeError: config does not have the expected layout;
        # ValueError: invalid JSON or an entry without ":port".
        print("Can't load/parse config\n"+str(e))
        sys.exit(1)
    return result

def check_per_host_use_redis(startup_nodes):
    """Check every configured node through the redis client libraries.

    Connects to the cluster with ``StrictRedisCluster`` (printing the error
    and exiting with status 1 if that fails) and sends PING to the cluster.
    If PING fails, prints the reason and returns. Otherwise, for each entry
    of *startup_nodes* prints ``host:port``, then ``ping - ok`` or
    ``ping - failed`` depending on whether that node appears in the PING
    result, then the outcome of a BGSAVE sent over a direct single-node
    connection (via ``send_command``), then a separator line.

    :param startup_nodes: list of ``{"host": str, "port": str}`` dicts, as
        produced by ``getHostsFromDBConfig``.
    """
    try:
        # The cluster client comes from the ``rediscluster`` package; the
        # plain ``redis`` module has no ``StrictRedisCluster`` attribute.
        cluster = StrictRedisCluster(startup_nodes=startup_nodes, socket_timeout=1)
    except rediscluster.exceptions.RedisClusterException as e:
        print(e)
        sys.exit(1)
    try:
        ping_result = cluster.ping()
    except redis.exceptions.BusyLoadingError:
        print('LOADING')
        return
    except (redis.exceptions.RedisError,
            rediscluster.exceptions.RedisClusterException) as e:
        print('ping - failed: %s' % e)
        return
    for node in startup_nodes:
        host, port = node.get('host'), node.get('port')
        print(host + ':' + port)
        if _cluster_node_name(host, port) in ping_result:
            print('ping - ok')
        else:
            print('ping - failed')
        node_client = redis.StrictRedis(host=host, port=port, socket_timeout=1)
        send_command(node_client.bgsave)
        print("------")


def _cluster_node_name(host, port):
    """Build the ``address:port`` key looked up in the cluster PING result.

    *host* is resolved to its first IPv6 address. If resolution fails, *host*
    is used verbatim, so the lookup reports ``ping - failed`` instead of
    crashing the whole check.
    """
    try:
        # getaddrinfo(host, port, family): the family must be the THIRD
        # argument; passing it second would treat AF_INET6 as a port number.
        address = socket.getaddrinfo(host, int(port), socket.AF_INET6)[0][4][0]
    except (socket.gaierror, ValueError):
        address = host
    return address + ':' + port


def check_per_host_use_socket(start_nodes):
    """Probe every node over a raw socket and print the PING/BGSAVE replies.

    For each node prints ``host:port``, then ``ping - <reply>`` and
    ``bgsave - <reply>`` (replies come from ``ssend``), then a separator.
    """
    probes = (("ping", "PING"), ("bgsave", "BGSAVE"))
    for entry in start_nodes:
        print(':'.join([entry.get('host'), entry.get('port')]))
        for label, cmd in probes:
            print(label + " - " + ssend(entry, cmd))
        print("------")
        
        
def ssend(node, command):
    """Send one inline redis command over a raw TCP socket and return the reply.

    Opens a fresh connection to ``node['host']:node['port']`` with a 1-second
    timeout, sends *command* terminated by CRLF, reads a single chunk of up
    to 1024 bytes and returns it with surrounding whitespace and the leading
    reply-type byte (``+``/``-``/``:``...) removed.

    :param node: ``{"host": str, "port": str}`` dict.
    :param command: inline command text, e.g. ``"PING"``.
    :returns: the reply text, or a string starting with ``'failed. '`` on a
        socket error, a timeout, or an empty reply.
    """
    try:
        # create_connection tries every address getaddrinfo returns (IPv6 and
        # IPv4), unlike a socket hard-wired to AF_INET6.
        conn = socket.create_connection(
            (node.get('host'), int(node.get('port'))), timeout=1)
    except (socket.timeout, socket.error) as e:
        return 'failed. %s' % e
    try:
        payload = command + '\r\n'
        if not isinstance(payload, bytes):
            # Python 3 (or unicode on Python 2): sockets need bytes.
            payload = payload.encode('utf-8')
        conn.sendall(payload)
        ans = conn.recv(1024)
    except (socket.timeout, socket.error) as e:
        return 'failed. %s' % e
    finally:
        # Always release the socket, including on timeouts/errors.
        conn.close()
    if not ans:
        return 'failed. connection closed without reply'
    if not isinstance(ans, str):
        # Python 3 returns bytes; callers concatenate the result with str.
        ans = ans.decode('utf-8', 'replace')
    return ans.strip()[1:]



def send_command(func):
    """Call *func* with no arguments and print how the redis call went.

    Prints ``Timeout``, ``LOADING`` or ``Can't connect`` for the matching
    redis-py exceptions, ``<name> - failed`` on a ``ResponseError`` and
    ``<name> - ok`` on success, where ``<name>`` is ``func.__name__``.
    Any other exception propagates to the caller.
    """
    try:
        func()
    except redis.exceptions.TimeoutError:
        outcome = 'Timeout'
    except redis.exceptions.BusyLoadingError:
        # Tested before ConnectionError (its base class in redis-py) so a
        # node that is still loading is reported as such.
        outcome = 'LOADING'
    except redis.exceptions.ConnectionError:
        outcome = "Can't connect"
    except redis.exceptions.ResponseError:
        outcome = func.__name__ + ' - failed'
    else:
        outcome = func.__name__ + ' - ok'
    print(outcome)


def clnod(rc):
    """Yield the ``id`` of every node reported by ``rc.cluster_nodes()``.

    Nodes whose dict has no ``id`` key are skipped.

    :param rc: client exposing ``cluster_nodes()``; the code assumes it
        returns an iterable of dicts.
    """
    for node in rc.cluster_nodes():
        # Direct membership test instead of scanning every key for 'id'
        # (and no local variable shadowing this function's name).
        if 'id' in node:
            yield node['id']


if __name__ == '__main__':
    # Command-line entry point; usage() describes the supported flags.
    args = sys.argv[1:]
    if '-h' in args or not args:
        usage()
    if '-c' in args:
        # -c takes the following argument as the path to the JSON config.
        try:
            db_config = args[args.index('-c')+1]
        except IndexError:
            print("Option -c require filename with json dbconfig")
            sys.exit(1)
    startup_nodes = getHostsFromDBConfig(db_config)
    # -s and -r are alternatives; -s wins when both are given.
    if '-s' in args:
        check_per_host_use_socket(startup_nodes)
    elif '-r' in args:
        check_per_host_use_redis(startup_nodes)
    if '-i' in args:
        # Connect to the cluster only when the INFO dump was requested, so a
        # down cluster cannot crash the -s/-r checks above. The cluster client
        # lives in rediscluster (redis has no StrictRedisCluster).
        try:
            rc = StrictRedisCluster(startup_nodes=startup_nodes, decode_responses=True, socket_timeout=1)
            redis_info = rc.info()
        except (rediscluster.exceptions.RedisClusterException,
                redis.exceptions.RedisError) as e:
            print(e)
            sys.exit(1)
        # The result maps each node to that node's INFO dict; print the node
        # name so sections from different nodes can be told apart.
        for node_name, node_info in sorted(redis_info.items()):
            print(node_name)
            for key, value in node_info.items():
                # Format explicitly: print(a, b) would print a tuple on Python 2.
                print('%s %s' % (key, value))
