# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

from functools import partial, wraps
from collections import Iterator, Sized
from logging import getLogger

from gevent import Timeout
from gevent.pool import Pool
from ylog.context import get_log_context, LogContext

from yabus.util.connector_context import connector_context, copy_current_connector_context

# Module-scoped logger; used by ThreadPool.pimap to report ignored exceptions.
logger = getLogger(name=__name__)


def copy_current_ylog_context(f):
    """Return a wrapper that runs *f* inside the ylog context captured now.

    The context is read once, at the moment this function is called, and is
    re-entered through ``LogContext`` around every later invocation of the
    returned wrapper (for example from a worker greenlet that would otherwise
    start without it).

    :param f: callable to wrap; may be a ``functools.partial``.
    :return: a function that forwards all arguments to *f* and returns its
        result.
    """
    captured_ctx = get_log_context()
    # partial objects carry no __name__ of their own, so the wrapper's
    # metadata is borrowed from the underlying function instead.
    metadata_source = f.func if isinstance(f, partial) else f

    @wraps(metadata_source)
    def run_in_captured_context(*args, **kwargs):
        with LogContext(**captured_ctx):
            return f(*args, **kwargs)

    return run_in_captured_context


def copy_context(f):
    """Wrap *f* so that it later runs under the caller's current contexts.

    The ylog context is always captured. The connector context is layered on
    top (outermost wrapper), but only when ``connector_context`` is truthy at
    wrap time.
    """
    # NOTE(review): presumably ``connector_context`` is a context-local that
    # is falsy when no connector context is active -- confirm in yabus.util.
    in_ylog = copy_current_ylog_context(f)
    return copy_current_connector_context(in_ylog) if connector_context else in_ylog


class ThreadPool(Pool):
    """gevent ``Pool`` that propagates the caller's context into its workers.

    Callables handed to ``imap``, ``imap_unordered``, ``map`` and ``pimap``
    are wrapped with ``copy_context`` at call time. Workers therefore run
    under the ylog context (and, when active, the connector context) of the
    greenlet that scheduled them. Used as a context manager, the pool calls
    ``kill()`` on exit.

    NOTE(review): ``spawn``/``apply*``/``*_async`` are not overridden, so
    callables submitted through them do not get the context copied.
    """

    def imap(self, func, *iterables, **kwargs):
        """``Pool.imap`` with *func* wrapped by ``copy_context``."""
        return super(ThreadPool, self).imap(copy_context(func), *iterables, **kwargs)

    def imap_unordered(self, func, *iterables, **kwargs):
        """``Pool.imap_unordered`` with *func* wrapped by ``copy_context``."""
        return super(ThreadPool, self).imap_unordered(copy_context(func), *iterables, **kwargs)

    def map(self, func, iterable):
        """Return the list of ``func(item)`` results computed on the pool.

        Routed through this class's ``imap`` so *func* is wrapped with
        ``copy_context`` exactly once. gevent's ``Pool.map`` is itself built
        on ``self.imap``, so wrapping here and delegating to it applied the
        context twice.
        """
        return list(self.imap(func, iterable))

    def pimap(self, func, iterable, timeout=None, safe=False):
        """Lazily yield ``func(item)`` results as they complete.

        :param func: callable applied to each item (context-copied through
            ``imap_unordered``).
        :param iterable: items to process.
        :param timeout: overall budget in seconds, counted from the first
            ``next()`` on this generator. Once it is used up the generator
            stops quietly without raising. ``None`` means no limit.
        :param safe: when true, an exception surfaced by the result iterator
            (e.g. one raised by *func*) is logged and skipped instead of
            being re-raised to the consumer.
        """
        from time import time  # local import: only pimap needs a clock

        deadline = None if timeout is None else time() + timeout
        results = self.imap_unordered(func, iterable)
        while True:
            remaining = None
            if deadline is not None:
                remaining = deadline - time()
                if remaining <= 0:
                    return  # budget exhausted: stop quietly
            # Arm the timer only while waiting for the next result and cancel
            # it before yielding. A single Timeout spanning the whole generator
            # would be thrown into the consumer's greenlet if it expired while
            # the consumer was blocked in its own code between items.
            timer = Timeout(remaining)  # Timeout(None) never fires
            timer.start()
            try:
                item = next(results)
            except StopIteration:
                return
            except Timeout as exc:
                if exc is not timer:
                    raise  # some other timeout: not ours to swallow
                return
            except Exception:
                if not safe:
                    raise
                logger.warning("Ignored exception in pimap", exc_info=1)
                continue
            finally:
                timer.cancel()
            yield item

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Kill every greenlet still running in the pool. Returns None, so any
        # exception from the ``with`` body still propagates.
        self.kill()


def pmap(func, iterable, **kwargs):
    """Apply *func* to every item of *iterable* on a pool and return a list.

    A one-shot iterator is materialised into a tuple first so its length can
    be inspected. Empty and single-item inputs are handled directly in the
    calling greenlet, without creating a pool. Everything else goes through
    ``ThreadPool(**kwargs).map``.

    :param func: callable invoked once per item.
    :param iterable: any iterable, including sets, dicts (iterated by key),
        dict views and generators.
    :param kwargs: forwarded to the ``ThreadPool`` constructor.
    :return: list of ``func(item)`` results.
    """
    if isinstance(iterable, Iterator):
        iterable = tuple(iterable)
    if isinstance(iterable, Sized):
        size = len(iterable)
        if size == 0:
            return []
        if size == 1:
            # Not every Sized container is indexable (set, frozenset, dict
            # views), and ``some_dict[0]`` is a key lookup rather than "first
            # element", so fetch the lone item via the iteration protocol.
            return [func(next(iter(iterable)))]
    with ThreadPool(**kwargs) as pool:
        return pool.map(func, iterable)


def pimap(func, iterable, timeout=None, safe=False, **kwargs):
    """Stream ``ThreadPool.pimap`` results from a pool owned by this call.

    A fresh ``ThreadPool(**kwargs)`` is created when iteration starts. Its
    ``__exit__`` runs when the generator is exhausted or closed early.
    *timeout* and *safe* are passed positionally to ``ThreadPool.pimap``.
    """
    with ThreadPool(**kwargs) as workers:
        stream = workers.pimap(func, iterable, timeout, safe)
        for value in stream:
            yield value
