#!/usr/bin/env python2
# -*- encoding: utf-8 -*-

import argparse
import json
import os
import logging
import multiprocessing as mp
import signal
import time

from threading import Timer
from subprocess import Popen, PIPE


DEFAULT_DBCONFIG = "/etc/yandex-direct/db-config-np/db-config.dev7.json"
USAGE = '''Программа для запуска ptkill демона для директовых баз. Вычитываем dbconfig, запускает отдельные процессы ptkill и следит за ними.
При обновлении хостов или смены мастера, происходит перезапуск ptkill у изменнеых шардов. Обновление лога происходит HUP сигналом:
перзапускается вся программа и потомки от неё.

%(prog)s --dbconfig <dbconfig_path> --debug

Например:
%(prog)s --dbconfig /etc/yandex-direct/db-config-np/db-config.devtest.json
'''


def startLoging(level, log_file=None):
    logger = logging.getLogger('steam logs to console')
    logger.setLevel(level=getattr(logging, level))
    # create console handler and set level to LEVEL
    chan = logging.StreamHandler() if log_file is None else logging.FileHandler(
        log_file, mode='a')
    chan.setLevel(level=getattr(logging, level))
    # create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # add formatter to ch
    chan.setFormatter(formatter)
    # add ch to logger
    logger.addHandler(chan)
    return logger, chan


def parseInstances(raw, result=list(), saved=dict()):
    '''Принимает на вход словарь, ищет вхождения searching_key и выводит список с набором хэшей.
       Например:
           [ {u'port': 9451, u'instance': u'ppcdata11', u'host': u'host11.yandex.ru', u'user': u'user1', u'pass': {u'file': u'/etc/tokens/mysql'}},
             {u'port': 9450, u'instance': u'ppcdata10', u'host': u'host10.yandex.ru', u'user': u'user1', u'pass': {u'file': u'/etc/tokens/mysql'}}
           ]
    '''
    saved = saved.copy()
    searching_keys = ['instance', 'host', 'port', 'pass', 'user']
    excluding_keys = ['extra_users']
    dict_list = []
    good_keys = set(raw.keys()) - set(excluding_keys)
    for value in raw:
        if value in searching_keys:
            saved[value] = raw[value]
        elif isinstance(raw[value], dict):
            dict_list.append(value)

    if len(saved) < len(searching_keys):
        for i in dict_list:
            parseInstances(raw[i], result, saved)
    else:
        logger.debug('[parseInstances] {0}'.format(saved))
        result.append(saved)
    return result


def readFile(f1le):
    try:
        passwd = open(f1le, 'r').read().strip()
    except Exception as err:
        passwd = ''
        logger.critical(err)
    return passwd


def readPasswd(config):
    '''Принимает на вход список словарей, находит в ней структуры вида {'pass': {'file': '/etc/secret'}},
       пытается прочитать и перезаписывает ключ 'pass'. Если произошла ошибка чтения файла, подставляется
       пустая строка.
    '''
    pass_files = [value['pass']['file'] for value in config if isinstance(value['pass'], dict) and
                  'file' in value['pass']]
    pass_files = tuple(set(pass_files))
    passwd_list = dict([(i, readFile(i)) for i in pass_files])
    logger.debug('[readPasswd] {0}'.format(passwd_list))
    for i in config:
        if isinstance(i['pass'], dict) and 'file' in i['pass']:
            i['pass'] = passwd_list[i['pass']['file']]
    return


def readConfig(config_mysql, cnf_data, exclude_instances=[]):
    '''Принимает на вход путь до конфига mysql, на выходе список с словарями. В каждом указывается порт,
       инстанс, хост, юзер и пароль. Пример вывода указан в описании parseInstances().
    '''
    mtime_current = int(
        os.stat(
            config_mysql).st_mtime) if os.path.exists(
        config_mysql) else 0
    if config_mysql in cnf_data:
        mtime_last = cnf_data[config_mysql]['mtime']
        if mtime_last == mtime_current:
            logger.debug('config dont changes. Using saving')
            return
    try:
        insts = []
        cnf_data[config_mysql] = {}
        logger.debug('read new config')
        with open(config_mysql, 'rb') as fd:
            raw = json.load(fd)
        parseInstances(raw, insts)
        insts = [i for i in insts if i['instance'] not in exclude_instances]
        readPasswd(insts)
        cnf_data[config_mysql]['data'] = insts
        cnf_data[config_mysql]['mtime'] = mtime_current
    except Exception as err:
        logger.critical('[readConfig] {0}'.format(err))
    return


def execute(**kwargs):
    instance = kwargs.get('instance', None)
    db = 'ppc' if instance.count('ppcdata') else ''
    db = 'ppcdict' if instance.count('ppcdict') else db
    db = 'monitor' if instance.count('monitor') else db
    cmdfmt = "/usr/bin/pt-kill2 --host '{host}' -D '{db}' --port {port} --user {user} --password {pass} --filter /etc/ptkill/ppcdata.pl " + \
        " --busy-time 30 --ignore-command '^Sleep$' --kill-query --print --victims all --shard-name {instance}"
    cmd = cmdfmt.format(db=db, **kwargs)
    logger.debug("starting command {0}".format(cmd))
    return Popen(cmd, shell=True, stdout=chan.stream.fileno(), stderr=chan.stream.fileno())


def ptkiller(args):
    logger.debug("start execute pt-kill with args {0}".format(args))
    return execute(**args)


class TerminateError(Exception):

    def __init__(self, msg):
        self.message = msg

    def __str__(self):
        return self.message


def sighup_handler(signum, frame):
    msg = "recieve {0} signal".format(signal.getsignal(signum))
    raise TerminateError(msg)

def sigterm_handler(signum, frame):
    msg = "recieve {0} signal".format(signal.getsignal(signum))
    raise TerminateError(msg)


def main(cnf_name, cnf_data={}):
    #  хэш таблица {'instance': {"host": <hostname>, "proc": <popen process>}}
    allprocs = {}
    q = mp.JoinableQueue()

    try:
        signal.signal(signal.SIGHUP, sighup_handler)
        signal.signal(signal.SIGTERM, sigterm_handler)
        while True:
            readConfig(cnf_name, cnf_data)
            logger.debug("current config {}".format(cnf_data))
            for i in cnf_data[cnf_name]['data']:
                instance = i.get('instance')
                process_by_instance = allprocs.get(instance, None)
                # если ptkill не запущен для инстанса - запускаем
                if instance not in allprocs:
                    allprocs[instance] = {
                        'host': i.get('host'), 'proc': ptkiller(i)}
                    continue
                # если поменялся у инстанса хост, то ptkill прибиваем и обновляем хэш
                if not i.get('host').count(allprocs[instance]['host']):
                    if allprocs[instance]['proc'].poll() is None:
                        allprocs[instance]['proc'].kill()
                        logger.info(
                            "changed database for instance {0}: {1} --> {2}. Send TERM {3}".format(instance,
                                                                                                   allprocs[instance]['host'], i.get('host'), allprocs[instance]['proc'].pid))
                    allprocs[instance] = {
                        'host': i.get('host'), 'proc': ptkiller(i)}
                    continue
                else:
                    # если процесс помер, то запускаем новый процесс
                    if allprocs[instance]['proc'].poll() is not None:
                        allprocs[instance] = {
                            'host': i.get('host'), 'proc': ptkiller(i)}
                    continue
            time.sleep(1)
    except TerminateError as err:
        logger.info(err)
    except Exception as err:
        logger.error('error main(), {0}'.format(err))
        raise
    finally:
        for shard_name in allprocs:
            try:
                meta = allprocs[shard_name]
                proc = meta.get('proc')
                if proc is None:
                    continue
                proc.terminate()
                logger.info(
                    "{0} process terminated for shard {1}".format(proc, shard_name))
            except Exception as err:
                logger.critical(
                    "error process terminated for shard {1}: {0}".format(err, shard_name))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(usage=USAGE)
    parser.add_argument("-d", "--debug", action='store_true',
                        dest='debug', help="enable debug mode")
    parser.add_argument(
        "-c", "--dbconfig", action='store', default=DEFAULT_DBCONFIG,
        dest="dbconfig", help="direct database config")
    parser.add_argument(
        "-l", "--log-file", action='store', type=str, default=None,
        dest="logfile", help="log file path")
    opts = parser.parse_args()

    logger, chan = startLoging(
        "DEBUG", opts.logfile) if opts.debug else startLoging("INFO", opts.logfile)
    main(opts.dbconfig)
