# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import multiprocessing
import os
import sys
import threading
import traceback
from contextlib import contextmanager

import mongoengine
from django import db
from django.conf import settings

from common.data_api.deploy.instance import deploy_client
from common.db.mongo.mongo import ConnectionProxy
from common.utils.dcutils import ResourceExplorer


log = logging.getLogger(__name__)


def on_worker_start():
    # Не трогая соединения родителя, открываем свои новые.
    # Для mysql это важно, чтобы не ломать транзакции родителя
    db.connections._connections = threading.local()


def close_mongo_connections():
    """
    Закрываем все соединения, иначе воркеры в пуле пошарят соединение при fork.

    Нельзя забывать, что коллекция в pymongo хранит внутри себя инстанс MongoClient,
    и сделать .close на таких клиентах недостаточно - т.к. там может остаться и форкнуться
    залоченный threading.Lock (например, в topology или periodic_executor).
    Лучше всего никогда не хранить коллекции отдельно, а всегда обращаться к ним через databases,
    тогда при форках ссылки на все MongoClient'ы исчезнут, и создадутся новые кленты с новыми локами.

    https://st.yandex-team.ru/RASPFRONT-6810#5d2c0e14701665001c532b50
    """

    # Необходимо делать disconnect до форка, т.к. при его вызове использутся локи,
    # которые точно так же могут повиснуть при форке
    for db_name in settings.MONGO_DATABASES.keys():
        mongoengine.connection.disconnect(db_name)
    ConnectionProxy.close_connections()


@contextmanager
def get_pool(pool_size):
    close_mongo_connections()

    pool = multiprocessing.Pool(processes=pool_size, initializer=on_worker_start)
    try:
        yield pool
    finally:
        pool.close()
        pool.join()


def error_catcher(args):
    """
    Ловим ошибки из процессов, чтобы видеть реальный трейсбек из функции,
    а не просто ошибку в недрах multiprocessing
    """

    func, func_args = args
    try:
        return func(func_args)
    except BaseException:
        raise Exception("".join(traceback.format_exception(*sys.exc_info())))


def run_parallel(func, args_list, pool_size=None):
    """
    Запускаем функцию параллельно в нескольких процессах.

    :param func: функция
    :param args_list: список аргументов, с которыми нужно запускать функцию параллельно
    :param pool_size: количество процессов
    :return: возвращет результаты выполнения функции по мере получения

    def mult(args):
        a, b = args
        print(a, b)
        return a * b

    for result in run_parallel(mult, [(1, 2), (3, 4), (5, 6)]):
        print(result)
    """

    log.info('Start {} on {} processes'.format(func, pool_size))
    with get_pool(pool_size) as pool:
        for result in pool.imap_unordered(error_catcher, ((func, args) for args in args_list)):
            yield result


# global state for forks
_run_instance_method_fork_state = {}


def _run_instance_method(args):
    instance_id, method_name, method_args = args
    instance = _run_instance_method_fork_state[instance_id]

    return getattr(instance, method_name)(*method_args)


def run_instance_method_parallel(method, args_list, pool_size=None):
    """
    Run `method` in parallel on `pool_size` workers, using arguments from args_list,
    with no sending any data into subprocesses, just using same instance after subprocess fork.

    :param method: bound-method of some instance
    :param args_list: list of arguments for each method call: [method_args1, method_args2, ...]
    :param pool_size: max number of subprocesses
    :return: generator of method calls results

    Example:
    class A(object):
        def __init__(self, heavy_value):
            self.heavy_value = heavy_value

        def foo(self, i):
            return self.heavy_value[i] * self.heavy_value[i]

    a = A()
    for res in run_instance_method_parallel(a.foo, [1, 2, 10], pool_size=3):
        print(res)

    The problem of running instance method in subprocess is this:
    When we have a lot of data in memory and want to do some calculations with it in parallel,
    we definitely don't want to send all this data inside each subprocess.
    But if we try to send instance (or its bound method) into subprocess - it should be fully serialized (pickled),
    with all its data. Either it will fail trying to serialize method/data - or data will be serialized and sent,
    both cases are really bad for us.

    To avoid this, instance will be registered in global `fork_state` and only its id and method name
    will be sent to subprocesses. Child process then use id to find instance and call a method by name.
    """

    instance = method.im_self
    instance_id = id(instance)
    _run_instance_method_fork_state[instance_id] = instance

    method_name = method.im_func.func_name
    args_list = [[instance_id, method_name, args] for args in args_list]

    return run_parallel(_run_instance_method, args_list, pool_size)


def get_cpu_count():
    log.info('Running on host with {} cores'.format(multiprocessing.cpu_count()))

    cpu_count_str = None

    if ResourceExplorer.is_run_in_qloud():
        cpu_count_str = os.getenv('QLOUD_CPU_GUARANTEE')
    if ResourceExplorer.is_run_in_sandbox():
        cpu_count_str = os.getenv('SANDBOX_CPU_GUARANTEE')
    elif ResourceExplorer.is_run_in_deploy():
        box_cpu = deploy_client.get_current_box_requirements().get('cpu')
        if box_cpu:
            # https://st.yandex-team.ru/RTCSUPPORT-8679
            cpu_count = box_cpu.get('cpu_guarantee_millicores') or box_cpu.get('cpu_limit_millicores')
            if cpu_count:
                cpu_count_str = cpu_count / 1000

    if cpu_count_str:
        cpu_count = int(float(cpu_count_str))
    else:
        cpu_count = multiprocessing.cpu_count()

    return max(1, cpu_count)
