# -*- coding: utf-8 -*-
import threading
import time

from itertools import groupby
from functools import partial
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log

from intranet.yandex_directory.src.yandex_directory import app
from intranet.yandex_directory.src.yandex_directory.common.datatools import flatten_lists
from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_shard_numbers,
    get_main_connection,
)
from intranet.yandex_directory.src.yandex_directory.core.task_queue.worker import Worker
from .base import TaskType
from ..utils.ycrid_manager import ycrid_manager


def run_worker(queue, shard=None, theta=0.1, max_interval=30):
    """
    Выбираем задачи из очереди и выполняем
    :param queue: имя очереди задач
    :param shard: шард очереди (опционально)
    :param theta: коэффициент, который будем умножать на степень двойки, чтобы вычислить задержку между попытками поллинга базы
    """
    worker_name = Worker.get_worker_name()

    shards = get_shard_numbers()
    if shard and shard not in shards:
        raise RuntimeError('Unknown shard')

    if shard:
        shards = [shard]

    with log.name_and_fields('task_worker', worker=worker_name, queue=queue):
        with log.fields(tasks=list(TaskType.task_types.keys()), shards=shards):
            log.info('Starting worker')

        threads = [
            threading.Thread(
                target=process_tasks,
                args=(worker_name,
                      shard,
                      queue,
                      theta,
                      max_interval,
                )
            )
            for shard in shards
        ]
        for thread in threads:
            thread.daemon = True
            thread.start()

        # В основном потоке просто спим, потому что если сделать join на
        # потоки из threads, то тогда процесс нельзя будет прервать по Ctrl-C
        while True:
            time.sleep(10)


def process_tasks(worker_name, shard, queue, theta, max_interval):
    polling_theta = theta or app.config['TASK_POLLING_THETA']
    max_interval = max_interval or app.config['TASK_POLLING_MAX_INTERVAL']

    i = 0
    wait_interval = 0
    task = None
    worker = Worker(queue)

    with log.fields(worker=worker_name, queue=queue, shard=shard):
        while True:
            try:
                task = worker.lock_task(shard)
                if task:
                    # Выполняем таск в отдельной транзакции
                    ycrid_manager.set(task.ycrid)

                    with get_main_connection(shard=shard, for_write=True) as main_connection:
                        task.process(main_connection)

                    ycrid_manager.reset()
            except:
                log.trace().error('Error during task processing')
            finally:
                # Если таск только что был обработан, то велика вероятность, что
                # в очереди есть ещё. Поэтому тут мы не ждём, а сразу ломимся в
                # базу за следующей задачей
                if task:
                    wait_interval = 0
                    i = 0
                    log.debug('Going to the next task')
                else:
                    # если нет задач то увеличиваем интервал ожидания по экспоненте
                    if wait_interval < max_interval:
                        wait_interval = polling_theta * 2 ** i
                        wait_interval = min(max_interval, wait_interval)
                        i += 1

                    with log.fields(interval=wait_interval):
                        log.debug('Sleeping')

                    time.sleep(wait_interval)


def print_tasks_tree(main_connection, root_task_id=None):
    """Печатает таск и все подадачи, которых он ждёт.
       Если root_task_id не передан, то выводит все задачи.

       Эта функция может быть полезна при отладке тестов.

       from intranet.yandex_directory.src.yandex_directory.core.task_queue.utils import print_tasks_tree
       print_tasks_tree(self.main_connection)

    """
    from intranet.yandex_directory.src.yandex_directory.core.models.task import TaskModel
    from intranet.yandex_directory.src.yandex_directory.core.models.task_relations import TaskRelationsModel

    all_tasks = list(TaskModel(main_connection).filter().all())
    id_to_task = dict((task['id'], task) for task in all_tasks)

    # Теперь собирём их в дерево, используя информацию о зависимостях
    relations = list(TaskRelationsModel(main_connection).filter().all())

    # В этом словарике будем держать списки id подзадач
    get_task_id = lambda item: item['task_id']
    relations.sort(key=get_task_id)
    grouped =  groupby(relations, key=get_task_id)

    id_to_dependencies = dict(
        (task_id, [dep['dependency_task_id'] for dep in dependencies])
        for task_id, dependencies in grouped
    )

    flattened = flatten_lists(list(id_to_dependencies.values()))
    all_dependencies = set(flattened)

    root_tasks = [
        task_id
        for task_id in id_to_task
        if task_id not in all_dependencies
    ]
    root_tasks.sort()

    def print_task(task_id, parents=tuple()):
        """Рекурсивно печатаем инфу про таск и все его зависимости.
        """
        task = id_to_task[task_id]
        print('{ident} {id} {name} {state}'.format(
            ident='  ' * len(parents),
            id=task['id'],
            name=task['task_name'].split('.')[-1],
            state=task['state'],
        ))
        subtasks = id_to_dependencies.get(task_id, [])
        list(map(
            partial(print_task, parents=parents + (task_id,)),
            subtasks,
        ))

    list(map(print_task, root_tasks))

