# -*- coding: utf-8 -*-
import os
import time
import traceback
import threading
import socket
from concurrent.futures import (
    ThreadPoolExecutor,
    TimeoutError,
)

import psycopg2

from intranet.yandex_directory.src.yandex_directory import app
from intranet.yandex_directory.src.yandex_directory.common.db import engines
from intranet.yandex_directory.src.yandex_directory.common import backpressure
from intranet.yandex_directory.src.yandex_directory.core.utils import (
    except_fields,
)
from intranet.yandex_directory.src.yandex_directory.common.utils import (
    get_host_and_port_from_url,
    get_user_data_from_blackbox_by_login,
)
from functools import reduce
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log


class SmokeTestsException(Exception):
    pass


class VitalServiceNotConfigured(SmokeTestsException):
    pass


class SmokeTestSkipped(SmokeTestsException):
    pass


# Здесь будут храниться результаты предыдущей проверки
# живости сервисов
_checks_result = {
    'environment': None,
    'default_timeout': None,
    'services': [],
    'errors': [],
    'has_errors_in_vital_services': False,
}
# Для проверки сервисов будет запускаться отдельный daemon поток.
# За запуск потока отвечает функция ensure_thread_is_running
_checker_thread = None


def check_databases(config, vital):
    results = []
    with ThreadPoolExecutor(10) as executor:
        for alias, shards in list(engines.items()):
            for shard, roles in list(shards.items()):
                for engine in list(roles.values()):
                    results.append(
                        executor.submit(
                            psycopg2.connect,
                            **engine.db_info['connection_info']
                        ).result
                    )
    # получаем результаты попыток соединения к БД
    # сами результаты нам не важны
    # важно были-ли ошибки
    # если было исключение то прокинется первое попавшееся
    [func() for func in results]


def check_blackbox_grants(config, vital):
    return get_user_data_from_blackbox_by_login('web-chib')


def check_config_url_tcp_connection(key):
    def func(config, vital):
        url_from_config = _get_deep_key(config, key)
        if url_from_config:
            host, port = get_host_and_port_from_url(url_from_config)
            _check_tcp_connection(host, port)
        elif vital:
            raise VitalServiceNotConfigured('%s has not been set in settings for this environment but it is vital' % key)
        else:
            raise SmokeTestSkipped('%s has not been set in settings for this environment' % key)
    return func


def _check_tcp_connection(host, port):
    socket.create_connection((host, port))


def _get_deep_key(source, key):
    return reduce(lambda val, x: val.get(x, {}), key.split('.'), source)


SERVICES = [
    {
        'name': 'Databases',
        'func': check_databases,
        'vital': True
    },
    {
        'name': 'Staff connection',
        'func': check_config_url_tcp_connection('STAFF_API_URL'),
        'vital': False
    },
    {
        'name': 'Blackbox connection',
        'func': check_config_url_tcp_connection('BLACKBOX.url'),
        'vital': False
    },
    {
        'name': 'Blackbox grants',
        'func': check_blackbox_grants,
        'vital': False
    },
]

def service_name_to_envvar(name):
    """Переводит имя в uppercase и заменяет пробелы на подчеркивания,
       а так же добавляет префикс SMOKE_ и суффикс _TIMEOUT.
    """
    return 'SMOKE_' + name.upper().replace(' ', '_') + '_TIMEOUT'

# установим таймауты из переменных окружения
# переменные окружения для таймаутов на smoke тесты сервисов
# должны формироваться, из названия, примерно так:
# SMOKE_CONDICTOR_CONNECTION_TIMEOUT
# и содержат float значение в секундах.
for service in SERVICES:
    env_var = service_name_to_envvar(service['name'])
    env_value = os.getenv(env_var)
    if env_value:
        with log.fields(service=except_fields(service, 'func'), new_timeout=env_value):
            try:
                env_value = float(env_value)
            except:
                log.trace().error('unable to convert value to float')
            else:
                service['timeout'] = env_value


def _run_smoke_tests(default_timeout=10):
    pool = ThreadPoolExecutor(len(SERVICES) * 2)
    threads = _get_smoke_test_threads(
        SERVICES,
        pool,
        default_timeout=default_timeout,
    )
    results = _get_smoke_test_results(threads)

    error_in_vital_services = any(
        i['status'] == 'error'
        for i in results
        if i['vital']
    )

    response = {
        'environment': app.config['ENVIRONMENT'],
        'default_timeout': default_timeout,
        'has_errors_in_vital_services': error_in_vital_services,
    }

    response['services'] = results
    return response


def get_smoke_tests_results(only_errors=False):
    """Возвращает результат проверок.
       При этом, так же удостоверяется, что поток, в котором выполняются проверки,
       работает.
    """
    ensure_thread_is_running()

    response = _checks_result.copy()
    response['environment'] = app.config['ENVIRONMENT']
    services = response['services']

    def format_message(item):
        message = item['message']
        vital = '!' if item['vital'] else ''

        if message == 'TimeoutError':
            return '{0}{1}(timeout={2})'.format(vital, message, item['timeout'])
        if message.startswith('SmokeTestSkipped'):
            return '{0}{1}'.format(vital, message)
        return '{0}{1} (look for traceback in logs)'.format(vital, message)

    # если нужно показать только ошибки, то переформатируем результаты
    if only_errors:
        errors = dict(
            (
                item['name'].lower().replace(' ', '-'),
                format_message(item)
            )
            for item in services
            if item['status'] != 'ok'
        )
        response['errors'] = errors
        del response['services']

    return response


def _get_smoke_test_threads(services, pool, default_timeout):
    """
    Запускаем smoke-тесты в разных тредах. Каждому треду назначаем timeout=default_timeout
    """
    threads = []
    for service in services:
        # установим таймаут в объект сервиса, если его там нет
        # это нам пригодится при формировании ошибки, если таймаут
        # всё-же случится
        service.setdefault('timeout', default_timeout)

        # Почему тут pool.submit вызывается дважды:
        # pool.submit возвращает объект Future, в метод result которого мы должны передать timeout
        # если результат не получилось получить за указанное время, будет вызвано исключение TimeoutError.
        # Нам нужно запустить N функций с одним и тем же таймаутом.
        # таким образом, мы устанавливаем timeout на получение результата для каждого треда сразу при помещении функции в пул

        threads.append(
            (
                pool.submit(
                    pool.submit(
                        service['func'],
                        config=app.config,
                        vital=service['vital']
                    ).result,
                    timeout=service['timeout'],
                ),
                service
            )
        )
    return threads


def _get_smoke_test_results(threads):
    """Собирает статусы проверок, выполненные в отдельных потоках,
       в результирующий список.
    """
    results = []
    for future, service in threads:
        try:
            future.result()
            response = {
                'status': 'ok',
            }
        except Exception as error:
            fields = service.copy()
            del fields['func']

            with log.fields(service=fields):
                if service['vital']:
                    log.trace().error('Fire. Dear Sir/Madam, I am writing to inform you of a fire...')
                else:
                    log.trace().warning('A little smoke was noticed')

            response = {
                'status': 'error'
            }
            if isinstance(error, SmokeTestSkipped):
                response['status'] = 'skipped'

            # таймаут нужно сохранить в ответе, чтобы потом его можно было добавить к
            # сообщению в отчёте о smoke тесте
            if isinstance(error, TimeoutError):
                response['timeout'] = service['timeout']

            response['message'] = ' '.join(
                # каждый элемент списка, который возвращает
                # format_exception_only, заканчивается на newline
                # и надо его обрезать
                item.strip()
                for item in traceback.format_exception_only(
                    type(error),
                    error
                )
            )
        response.update({
            'name': service['name'],
            'vital': service['vital'],
        })
        results.append(response)
    return results


def _smoke_check():
    """Выполняет проверку и сохраняет результаты в переменной и backpressure.
    """
    global _checks_result

    _checks_result = _run_smoke_tests()
    has_errors = _checks_result['has_errors_in_vital_services']
    backpressure.save_smoke_results(
        has_errors=has_errors,
    )


def _smoke_checker():
    """Бесконечный цикл для проверки живости сервисов, в которые
       ходит Директория.

       В процессе, обновляется словарь со статусом проверок.
       Информацию из этого словаря использует ручка /ping/
       и backpressure механизм.
    """
    with log.name_and_fields('smoke-tests'):
        while True:
            try:
                with app.stats_aggregator.log_work_time('ping_view_smoke_tests'):
                    _smoke_check()
            except Exception:
                log.trace().error('Error during smoke tests execution')

            # Раньше в ручку пинг Qloud приходил каждую секунду, но мне кажется
            # что для проверок это слишком часто, поэтому пусть будет раз в 15s.
            time.sleep(5)


def ensure_thread_is_running():
    """Запускает поток для проверки важных сервисов.
       Если поток уже запущен, то не делает ничего.
    """
    global _checker_thread

    if _checker_thread is None or not _checker_thread.is_alive():
        with log.name_and_fields('smoke-tests'):
            log.info('Starting a thread for smoke tests')
            _checker_thread = threading.Thread(target=_smoke_checker)
            _checker_thread.daemon = True
            _checker_thread.start()
