#!/usr/bin/env python

from jinja2 import Template
import json
import threading
import Queue
import os
import time
import socket

from log import create_logger

# Root directory of the deployed virtualenv; the template and config paths
# below are built from it.
BASE_PATH = '/opt/venvs/gen-tns/'
# Jinja2 template that generate_tns_names() renders into tnsnames.ora.
TEMPLATE_FILE = BASE_PATH + 'src/templates/tns_name.j2'
# File that update_monitoring_data() overwrites with the passive-check line.
MON_FILE = '/tmp/gen-tns.check'

# Environment name -> path of the JSON TNS map used in that environment.
env_config = {
        'development': BASE_PATH + 'config/tns_map.json.development',
        'testing': BASE_PATH + 'config/tns_map.json.testing',
        'production': BASE_PATH + 'config/tns_map.json.production'
}

# Map file for the environment named by QLOUD_ENVIRONMENT ('development' if
# unset). NOTE(review): an unrecognised name makes this None, and main() then
# fails only when it tries to open() it -- consider failing fast here.
TNS_MAP = env_config.get(os.environ.get('QLOUD_ENVIRONMENT', 'development'))

# Output file that main() regenerates from the TNS map.
TNS_NAMES_FILE = '/opt/oracle/instantclient_12_1/network/admin/tnsnames.ora'

# Module-wide logger used throughout this file.
log = create_logger()


def remove_empty_tns(tns_map, failed_hosts):
    """Return a copy of *tns_map* without failed hosts and without dead entries.

    Hosts listed in *failed_hosts* are first pruned via remove_failed_hosts().
    An entry is then dropped when every value of its ``'dcs'`` mapping is
    empty (falsy), or when it has no DCs at all. The caller's mapping is not
    modified.

    Assumes each entry is a dict holding a ``'dcs'`` dict, presumably
    DC name -> host list -- TODO confirm against the JSON TNS map.
    """
    pruned = remove_failed_hosts(tns_map, failed_hosts)
    survivors = pruned.copy()
    for tns_name, entry in pruned.items():
        dcs = entry['dcs']
        dc_total = len(dcs.keys())
        # NOTE: hosts have already been pruned at this point; this is the
        # raw number of DCs, before empty ones are discounted.
        log.debug('DC count {} before removing failed hosts'.format(dc_total))
        empty_dcs = sum(1 for dc_hosts in dcs.values() if not dc_hosts)
        dc_left = dc_total - empty_dcs
        log.debug('DC count after removing failed hosts: {}'.format(dc_left))
        if dc_left <= 0:
            del survivors[tns_name]
    return survivors


def generate_tns_names(tns_map):
    """Render the tnsnames.ora text for *tns_map* from TEMPLATE_FILE.

    The template context holds a single variable, ``tns_names``: the map's
    ``(name, data)`` pairs sorted by name.
    """
    with open(TEMPLATE_FILE, 'r') as template_src:
        template = Template(template_src.read())
    entries = sorted(tns_map.items())
    return template.render({'tns_names': entries})


def remove_failed_hosts(tns_map, failed_hosts):
    '''
    Return a copy of *tns_map* with every host in *failed_hosts* removed.

    Nested dicts are processed recursively and copied; lists are rebuilt
    without the failed hosts; any other value is carried over as-is. The
    input mapping is never mutated.

    >>> m = {'a': ['x', 'y', 'z'], 'b': {'c': ['z', 'y']}}
    >>> remove_failed_hosts(m, ['y']) == {'a': ['x', 'z'], 'b': {'c': ['z']}}
    True
    >>> m == {'a': ['x', 'y', 'z'], 'b': {'c': ['z', 'y']}}
    True
    '''
    pruned = tns_map.copy()
    for key, value in tns_map.items():
        if isinstance(value, dict):
            value = remove_failed_hosts(value, failed_hosts)
        elif isinstance(value, list):
            value = [h for h in value if h not in failed_hosts]
        else:
            # Scalars and other values are already present in the copy.
            continue
        pruned[key] = value
    return pruned


def get_uniq_hosts(tns_map):
    '''
    Collect the set of every host mentioned anywhere in *tns_map*.

    List values contribute their elements; dict values are searched
    recursively; any other value is ignored.

    >>> m = {'a': ['x', 'y', 'z'], 'b': {'c': ['z', 'y']}}
    >>> sorted(get_uniq_hosts(m))
    ['x', 'y', 'z']
    '''
    found = set()
    for value in tns_map.values():
        if isinstance(value, dict):
            found |= get_uniq_hosts(value)
        elif isinstance(value, list):
            found.update(value)
    return found


class DBOraclePinger(object):
    """Probe a collection of hosts over TCP from a small pool of threads.

    Each host gets an IPv6 TCP connect to port 13 with a 3-second timeout.
    After :meth:`run` returns, the results are split into:

    * ``alive_hosts`` -- hosts that accepted the connection;
    * ``failed_hosts_with_errors`` -- ``(host, reason)`` pairs for timeouts
      and ``socket.herror``;
    * ``unknown_status_hosts`` -- ``(host, reason)`` pairs for any other
      error.
    """

    def __init__(self, hosts, nthreads=5):
        """
        :param hosts: iterable of host names/addresses to probe.
        :param nthreads: number of worker threads started by :meth:`run`.
        """
        self.queue = Queue.Queue()
        self._hosts = hosts
        log.debug('Hosts to be checked: {}'.format(self._hosts))
        self.alive_hosts = []
        self.failed_hosts_with_errors = []
        self.unknown_status_hosts = []
        self.nthreads = nthreads

    def _fill_queue(self):
        """Put every host on the work queue."""
        for host in self._hosts:
            self.queue.put(host)

    @staticmethod
    def _describe_error(exc):
        """Return a human-readable reason for *exc*.

        Prefers ``strerror`` (set on socket/OS errors built from an errno) and
        falls back to ``str(exc)``. Neither ``.message`` (Python 2 only, and
        empty for multi-argument exceptions) nor ``.strerror`` (absent on
        most exception types) is safe to use on its own.
        """
        return getattr(exc, 'strerror', None) or str(exc)

    def _ping(self):
        """Worker loop: probe hosts taken from the queue until it is empty."""
        port = 13
        timeout = 3
        while True:
            try:
                host = self.queue.get(False)  # no block
            except Queue.Empty:
                # Nothing left to check: let this thread finish.
                log.debug('Queue is empty')
                return
            log.debug('Processing host: {}'.format(host))

            # Bound before the try so that the finally clause never touches an
            # unassigned (or a previous iteration's) socket when
            # socket.socket() itself raises.
            s = None
            try:
                s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
                s.settimeout(timeout)
                s.connect((host, port))

                log.debug('Host is OK: {}'.format(host))
                self.alive_hosts.append(host)
            except (socket.timeout, socket.herror) as e:
                reason = self._describe_error(e)
                log.exception('Host is Failing: {} {}'.format(host, reason))
                self.failed_hosts_with_errors.append((host, reason))
            except Exception as e:
                # Any other failure means reachability could not be decided.
                reason = self._describe_error(e)
                log.exception('Pinger failed: {} {}'.format(host, reason))
                self.unknown_status_hosts.append((host, reason))
            finally:
                if s is not None:
                    s.close()

    def run(self):
        """Queue all hosts, start ``nthreads`` workers and wait for them all."""
        self._fill_queue()
        workers = []
        for _ in range(self.nthreads):
            t = threading.Thread(target=self._ping)
            workers.append(t)
            t.start()
        for t in workers:
            t.join()

    @property
    def failed_hosts(self):
        """Hosts from ``failed_hosts_with_errors`` without their reasons.

        Logged at debug level on every access.
        """
        failed_hosts = [host for host, err in self.failed_hosts_with_errors]
        log.debug('Unreachable hosts: {}'.format(failed_hosts))

        return failed_hosts


def replace_file_contents(file_name, new_contents):
    """Write *new_contents* to *file_name* only if they differ from what is there.

    The new text is first written to a temporary file in the same directory
    and then renamed over the target. On POSIX this rename replaces the file
    atomically, so concurrent readers never see a truncated or half-written
    file. The existing file's permission bits are copied to the replacement.
    A missing target is treated as having no contents and is created
    (mode 0644).

    NOTE: because the file is replaced rather than rewritten in place, a
    symlink at *file_name* is replaced by a regular file, and the new file is
    owned by the current user.

    :param file_name: path of the file to (re)write.
    :param new_contents: complete text the file should contain.
    :returns: True if the file was rewritten, False if it already matched.
    :raises IOError/OSError: on read errors other than "file not found", or
        if writing or renaming the temporary file fails.
    """
    import errno
    import stat
    import tempfile

    try:
        with open(file_name) as f:
            old_contents = f.read()
    except (IOError, OSError) as e:
        if e.errno != errno.ENOENT:
            raise
        old_contents = None
    if new_contents == old_contents:
        return False

    # mkstemp creates files with mode 0600; keep whatever mode the target had.
    try:
        mode = stat.S_IMODE(os.stat(file_name).st_mode)
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise
        mode = 0o644

    # Same directory => same filesystem, which rename needs to be atomic.
    directory = os.path.dirname(os.path.abspath(file_name))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='.tmp-')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(new_contents)
        os.chmod(tmp_path, mode)
        os.rename(tmp_path, file_name)
    except BaseException:
        # Do not leave stray temp files behind; re-raise the original error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return True


def create_path(file_path):
    """Ensure *file_path* and its parent directories exist.

    Missing parent directories are created with ``os.makedirs``. A missing
    file is created empty with mode 0600 (subject to the umask), as
    ``os.mknod`` did by default. Existing directories and an existing file
    are left untouched.

    :raises OSError: for any failure other than "already exists".
    """
    import errno

    parent_dir = os.path.dirname(file_path)

    # A bare file name has no directory component; makedirs('') would fail.
    if parent_dir:
        try:
            os.makedirs(parent_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
            log.debug('{} already exists. Skiping creation.'.format(parent_dir))

    # O_CREAT | O_EXCL creates the file atomically and fails with EEXIST if
    # the path already exists. Unlike os.mknod (Unix-only), os.open is
    # available on every platform.
    try:
        fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
        log.debug('{} already exists. Skiping creation.'.format(file_path))
    else:
        os.close(fd)


def update_monitoring_data(status_code, status_desc):
    """Overwrite MON_FILE with a single passive-check record.

    The record has the form
    ``PASSIVE-CHECK:gen_tns_check;<status_code>;<status_desc>``;
    *status_desc* must already be a string.
    """
    fields = ('gen_tns_check', str(status_code), status_desc)
    record = 'PASSIVE-CHECK:' + ';'.join(fields)
    log.debug(record)
    with open(MON_FILE, 'w') as mon:
        mon.write(record)


def sleep(t, interval):
    """Block until *interval* seconds have passed since timestamp *t*.

    Returns without sleeping when that moment is already in the past.
    """
    remaining = t + interval - time.time()
    if remaining < 0:
        remaining = 0
    log.info('Sleeping for a while... ({:.3f} sec.)'.format(remaining))
    time.sleep(remaining)


def _process_status(status, status_code, status_desc):
    """Log the probe result and publish it via update_monitoring_data().

    The published description reads ``Hosts in <status> status: <details>``.
    """
    message = 'Hosts in {} status: {}'.format(status, status_desc)
    log.debug('{} {}'.format(status_code, message))
    update_monitoring_data(status_code, message)


def main():
    """Regenerate TNS_NAMES_FILE forever, one probing round per iteration.

    Each round loads the environment's JSON TNS map and probes every host it
    references. It then publishes a passive-check status line. Unless some
    host ended in an unknown state, it rewrites TNS_NAMES_FILE without the
    unreachable hosts. Rounds start roughly ``interval`` seconds apart.
    """
    log.debug('#### App {} started ####'.format(__name__))
    create_path(TNS_NAMES_FILE)

    # Workers pull hosts from a shared queue, so fewer threads than hosts
    # still covers every host. A round just takes longer, because each
    # unreachable host can cost a full connect timeout.
    nthreads = 10
    interval = 5.0  # seconds from the start of one round to the next

    while True:
        round_start = time.time()
        log.info('Start {} regeneration'.format(TNS_NAMES_FILE))

        with open(TNS_MAP, 'r') as map_file:
            tns_map = json.load(map_file)
        log.debug('{}'.format(tns_map))
        log.info('TNS_MAP read successfully: {}'.format(TNS_MAP))

        pinger = DBOraclePinger(get_uniq_hosts(tns_map), nthreads=nthreads)
        pinger.run()

        # Undecided probe outcomes: report them and keep the current file.
        if pinger.unknown_status_hosts:
            _process_status('unknown', 2, pinger.unknown_status_hosts)
            log.info('Stop regenerate tns_names due to unknown status')
            sleep(round_start, interval)
            continue

        if pinger.failed_hosts:
            _process_status('failed', 2, pinger.failed_hosts_with_errors)
        else:
            _process_status('alive', 0, pinger.alive_hosts)

        tns_map = remove_empty_tns(tns_map, pinger.failed_hosts)
        log.debug('tns_map={}'.format(repr(tns_map)))
        if replace_file_contents(TNS_NAMES_FILE, generate_tns_names(tns_map)):
            log.info('{} successfully regenerated'.format(TNS_NAMES_FILE))
        else:
            log.info('{} left unchanged'.format(TNS_NAMES_FILE))

        sleep(round_start, interval)


# Start the endless regeneration loop only when executed as a script.
if __name__ == '__main__':
    main()
