"""Miscellaneous utilities."""

import calendar
import datetime
import functools
import inspect
import itertools
import logging
import random
import time
import typing as tp
import uuid
from collections import deque
from contextlib import contextmanager

import gevent
import pytz
from gevent.greenlet import Greenlet
from gevent.pool import Pool
from gevent.queue import Queue, Empty

from sepelib.core import config
from sepelib.core.exceptions import LogicalError
from walle.util.gevent_tools import gevent_idle_iter

# Module-level logger; StopWatch reports execution times through it.
log = logging.getLogger(__name__)


class AllErrors:
    """Pseudo-container for which every membership test succeeds.

    ``code in AllErrors()`` is ``True`` for any ``code``, so an instance can stand in
    wherever a collection of "error" status codes is expected and everything should match.
    """

    def __contains__(self, status_code):
        """Report *status_code* (whatever it is) as contained."""
        return True


# Shared ready-made instance.
ALL_ERRORS = AllErrors()


class InvOrUUIDOrName:
    """Host identifier that may be an inventory number, a UUID or a name.

    On construction the raw value is normalized by trying these interpretations in order:

    * UUID -- stored as its 32-char lowercase hex string (``uuid.UUID(value).hex``);
    * inventory number -- stored as ``int(value)``;
    * name -- stored as ``value.lower()``.

    ``None`` is stored as is. A value matching none of the interpretations raises ``LogicalError``.
    ``==``/``!=`` normalize a non-``InvOrUUIDOrName`` operand the same way, so names compare
    case-insensitively.
    """

    # NB: defining __eq__ already makes instances unhashable; this is spelled out on purpose.
    # Set/dict membership can not be done in the same case-insensitive manner as __eq__,
    # so it is not supported at all.
    __hash__ = None

    def __init__(self, value):
        if value is None:
            self._value = None
        elif not self._try_set_uuid(value) and not self._try_set_inv(value) and not self._try_set_name(value):
            raise LogicalError

    def is_inv(self):
        """Return ``True`` if the normalized value is an inventory number (an ``int``)."""
        return isinstance(self._value, int)

    def is_uuid(self):
        """Return ``True`` if the normalized value parses as a UUID."""
        try:
            uuid.UUID(self._value)
            return True
        except (ValueError, AttributeError, TypeError):
            # TypeError is what uuid.UUID(None) raises, i.e. for InvOrUUIDOrName(None).
            return False

    def value(self):
        """Return the normalized value (hex UUID string, ``int``, lowercased name or ``None``)."""
        return self._value

    def key(self, inv_arg="inv", name_arg="name", uuid_arg="uuid"):
        """Return the argument name matching the kind of value held.

        ``None`` and names map to *name_arg*.
        """
        if self.is_uuid():
            return uuid_arg

        return inv_arg if self.is_inv() else name_arg

    def kwargs(self, inv_arg="inv", name_arg="name", uuid_arg="uuid", **kwargs):
        """Return ``{self.key(...): self.value()}`` merged with the extra *kwargs*."""
        return dict({self.key(inv_arg, name_arg, uuid_arg): self.value()}, **kwargs)

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            other = type(self)(other)

        return self.value() == other.value()

    def __ne__(self, other):
        return not (self == other)

    def _try_set_inv(self, value):
        """Store ``int(value)`` and return ``True``, or return ``False`` if it does not convert."""
        try:
            self._value = int(value)
        except (ValueError, TypeError, OverflowError):
            # TypeError: non-numeric, non-string objects; OverflowError: int(float("inf")).
            return False
        else:
            return True

    def _try_set_name(self, value):
        """Store ``value.lower()`` and return ``True``, or return ``False`` for non-string-like values."""
        try:
            self._value = value.lower()
        except AttributeError:
            return False
        else:
            return True

    def _try_set_uuid(self, value):
        """Store the UUID hex of *value* and return ``True``, or return ``False`` if it is not a UUID."""
        try:
            self._value = uuid.UUID(value).hex
        except (ValueError, AttributeError, TypeError):
            # TypeError: e.g. bytes input, which uuid.UUID can not parse as a hex string.
            return False
        else:
            return True


def iter_chunks(data, size):
    """Yield consecutive lists of up to *size* items taken from *data*.

    Every chunk is full except possibly the last one. An empty *data* yields nothing.
    """
    source = iter(data)
    chunk = list(itertools.islice(source, size))
    while chunk:
        yield chunk
        chunk = list(itertools.islice(source, size))


def iter_shuffle(target_list, random_func=None):
    """Shuffle *target_list* in place like ``random.shuffle``, yielding each element as it is placed.

    This is the Fisher-Yates algorithm ``random.shuffle`` uses, turned into a generator. Walking
    from the last position to the first, each step swaps a randomly chosen element of the
    not-yet-shuffled prefix into the current position and yields it. Stopping early leaves the
    list partially shuffled.
    NB: this is not thread-safe.

    :param target_list: mutable sequence shuffled in place.
    :param random_func: zero-argument callable returning a float in [0.0, 1.0);
        defaults to ``random.random``.
    """
    rand = random.random if random_func is None else random_func
    for position in range(len(target_list) - 1, -1, -1):
        # pick an index in target_list[:position + 1] to exchange with target_list[position]
        swap_with = int(rand() * (position + 1))
        target_list[position], target_list[swap_with] = target_list[swap_with], target_list[position]
        yield target_list[position]


def filter_dict_keys(d, allowed_keys):
    """Return a new dict holding only the items of *d* whose key is in *allowed_keys*."""

    def _is_allowed(item):
        return item[0] in allowed_keys

    return dict(filter(_is_allowed, d.items()))


def drop_none(d, none=None):
    """Return a copy of *d* without the items whose value is the *none* marker.

    Values are compared by identity (``is``), so only that exact object is dropped
    (``None`` by default); other falsy values are kept.
    """
    kept = {}
    for key, val in d.items():
        if val is none:
            continue
        kept[key] = val
    return kept


def concat_dicts(*dicts):
    """Merge *dicts* left to right into a new dict; later values win for duplicate keys."""
    merged = {}
    for mapping in dicts:
        merged.update(mapping.items())
    return merged


def first(iterable):
    """Return the first element of *iterable*, or ``None`` if it is empty."""
    for element in iterable:
        return element
    return None


def find_one(iterable, key, value):
    """Return the first item of *iterable* whose ``item.get(key)`` equals *value*, or ``None``.

    Items must support dict-style ``.get``. Items lacking *key* never match,
    even when *value* is ``None``.
    """
    missing = object()
    for item in iterable:
        if item.get(key, missing) == value:
            return item
    return None


def unique(iterable, clear=False):
    """Return the distinct items of *iterable* as a list, in arbitrary (set) order.

    :param clear: if true, falsy items (``None``, ``""``, ``0``, ...) are dropped as well.
    """
    distinct = set(iterable)
    return [item for item in distinct if item] if clear else list(distinct)


def values_filtered(d, allowed_keys, filter_func=None):
    """Return the values of *d* whose keys are in *allowed_keys* and that pass *filter_func*.

    As with the ``filter`` builtin, ``filter_func=None`` keeps only truthy values.
    """
    accept = bool if filter_func is None else filter_func
    return [val for key, val in d.items() if key in allowed_keys and accept(val)]


def values_ordered(mapping, keys):
    """Return the values of *mapping* listed in the same order as *keys*.

    A key absent from *mapping* raises ``KeyError``.
    """
    ordered = []
    for key in keys:
        ordered.append(mapping[key])
    return ordered


def fix_mongo_set_kwargs(**kwargs):
    """Turn ``set__<field>=None`` MongoEngine update kwargs into ``unset__<field>=True``.

    With MongoEngine, ``set__field=None`` stores ``None`` in the database instead of removing the
    field. This rewrites such kwargs into the matching ``unset__`` form; every other kwarg is
    passed through unchanged.
    """
    set_prefix = "set__"
    result = {}
    for key, val in kwargs.items():
        if val is None and key.startswith(set_prefix):
            result["unset__" + key[len(set_prefix) :]] = True
        else:
            result[key] = val
    return result


def fix_mongo_batch_update_dict(update):
    """Rewrite a raw MongoDB update document so that ``None`` values in ``$set`` unset the field.

    ``{"$set": {"a": None, "b": 1}}`` becomes ``{"$unset": {"a": True}, "$set": {"b": 1}}``.
    Other operators (``$inc``, ``$push``, an already present ``$unset``, ...) are carried over,
    the same way ``fix_mongo_set_kwargs`` passes non-``set__`` kwargs through. A ``$set`` with no
    remaining values is omitted. The input dict is not modified.
    """
    new_update = {op: fields for op, fields in update.items() if op != "$set"}
    if "$unset" in new_update:
        # Copied because None-valued $set fields are merged into it below.
        new_update["$unset"] = dict(new_update["$unset"])

    for k, v in update.get("$set", {}).items():
        if v is None:
            new_update.setdefault("$unset", {})[k] = True
        else:
            new_update.setdefault("$set", {})[k] = v
    return new_update


def args_as_dict(func, *args, **kwargs):
    """Map the arguments of a prospective ``func(*args, **kwargs)`` call to parameter names.

    The result is built in this order, later steps overriding earlier ones:

    1. *func*'s default values (positional-or-keyword and keyword-only parameters);
    2. positional *args* bound to *func*'s named positional parameters in order. Positionals beyond
       those are stored as a tuple under the name of *func*'s ``*args`` parameter;
    3. *kwargs* as-is.

    No validation is done: unknown keywords and missing required arguments are not reported.
    """
    # getfullargspec replaces inspect.getargspec, which was removed in Python 3.11 (getargspec was
    # implemented on top of it, so the positional parameter list is the same).
    spec = inspect.getfullargspec(func)
    positionals = spec.args

    fetched_args = {}
    if spec.defaults:
        # defaults belong to the *last* len(defaults) positional parameters
        fetched_args.update(zip(positionals[-len(spec.defaults) :], spec.defaults))
    if spec.kwonlydefaults:
        fetched_args.update(spec.kwonlydefaults)
    if args:
        fetched_args.update(zip(positionals, args))
        if len(args) > len(positionals):
            # the surplus positionals are exactly what func's *args parameter would receive
            fetched_args[spec.varargs] = args[len(positionals) :]
    fetched_args.update(kwargs)

    return fetched_args


@contextmanager
def closing_ctx(close, *args, **kwargs):
    """Context manager that calls ``close(*args, **kwargs)`` when the ``with`` block exits.

    It resembles ``contextlib.closing``, but runs an arbitrary callable with the given arguments
    instead of ``obj.close()``. The callable runs whether the block finishes normally or raises.
    The ``with`` target is ``None``.
    """
    try:
        yield None
    finally:
        close(*args, **kwargs)


@contextmanager
def dummy_context(obj=None, *args, **kwargs):
    """No-op context manager whose ``with`` target is *obj*.

    Extra positional and keyword arguments are accepted and ignored, so this can stand in for a
    context manager factory with a richer signature.
    """
    del args, kwargs  # intentionally unused
    yield obj


def add_interval_job(scheduler, func, interval, name=None, delay=None, trigger="interval"):
    """Add a job to *scheduler* that runs *func* every *interval* seconds.

    The first run is scheduled *delay* (default 0) seconds from now plus a random jitter of
    0..interval whole seconds. The job uses the Europe/Moscow timezone, coalescing, and a misfire
    grace time equal to *interval*.

    :param scheduler: object with ``add_job(func, trigger, **kwargs)``. The kwargs below presumably
        follow the APScheduler API -- confirm against the scheduler in use.
    :param interval: period in seconds. A float is accepted; the jitter uses its integer part.
    :param name: optional job name (left out of the ``add_job`` kwargs when ``None``).
    :param delay: optional extra seconds to wait before the first run.
    :param trigger: trigger alias passed to ``add_job`` together with ``seconds=interval``.
    :return: whatever ``scheduler.add_job`` returns.
    """
    timezone = pytz.timezone("Europe/Moscow")
    # int(): random.randint rejects non-integer bounds (TypeError on Python 3.12+).
    jitter = random.randint(0, int(interval))

    kwargs = drop_none(
        dict(
            name=name,
            # "now" is taken in the job's own timezone. A naive datetime.now() is host-local time,
            # and a scheduler that localizes naive start dates into `timezone` (as APScheduler 3.x
            # does) would shift the first run by the host-vs-Moscow offset.
            start_date=datetime.datetime.now(timezone) + datetime.timedelta(seconds=(delay or 0) + jitter),
            # Workaround gevent freezes
            misfire_grace_time=interval,
            coalesce=True,
            timezone=timezone,
        )
    )

    return scheduler.add_job(func, trigger, seconds=interval, **kwargs)


def parallelize_execution(*functions, pool: tp.Optional[Pool] = None):
    """Run each of *functions* (called with no arguments) in its own greenlet; return their results.

    Results are listed in the order of *functions*. If some function raises, the exception of the
    first failing one (in *functions* order) propagates from here. This happens only after the
    ``finally`` below has waited for the spawned greenlets (or, with *pool*, for the whole pool).
    """
    greenlets: tp.List[Greenlet] = []

    try:
        spawn = gevent.spawn if pool is None else pool.spawn
        for func in functions:
            greenlets.append(spawn(func))

        results = []
        for greenlet in greenlets:
            results.append(greenlet.get())
        return results
    finally:
        # Wait even when a greenlet failed, so nothing keeps running after we return or raise.
        if pool is not None:
            pool.join()
        else:
            gevent.joinall(greenlets)


def parallelize_processing(function, args, threads, cpu_intensive=False, pool: tp.Optional[Pool] = None):
    """Call *function* once for every item of *args*, running at most *threads* calls concurrently.

    :param function: one-argument callable applied to each item.
    :param args: items to process. Anything other than a tuple or list is first materialized into a
        list via ``gevent_idle_iter``.
    :param threads: maximum number of concurrently running greenlets; values below 1 are treated as 1.
    :param cpu_intensive: passed to ``gevent_idle_iter`` in the worker loop. It is used only when
        there are more items than *threads*.
    :param pool: optional gevent pool to spawn greenlets in (forwarded to ``parallelize_execution``).
    :return: list of results. If ``len(args) <= threads`` they are in the order of *args*. Otherwise
        each worker's results are grouped together and the overall order need not match *args*.
    """
    if not isinstance(args, (tuple, list)):
        args = list(gevent_idle_iter(args))

    if not args:
        return []

    # With threads < 1 the queue branch below would start no workers at all and silently return an
    # empty result, dropping every item; fall back to a single worker instead.
    threads = max(threads, 1)

    if len(args) <= threads:
        # Few enough items: one greenlet per item, results come back in item order.

        def get_executor(arg):
            # Bind `arg` now: closures created directly in the generator below would all late-bind
            # to the last item.
            @functools.wraps(function)
            def executor():
                return function(arg)

            return executor

        all_results = parallelize_execution(*(get_executor(arg) for arg in args), pool=pool)
    else:
        # More items than threads: `threads` workers pull items from a shared queue until it is empty.
        queue = Queue(items=args)

        @functools.wraps(function)
        def executor():
            results = deque()

            # Endless counter driven through gevent_idle_iter; presumably it yields to other
            # greenlets between items. The loop ends when the queue is drained.
            for _ in gevent_idle_iter(itertools.count(), cpu_intensive=cpu_intensive):
                try:
                    arg = queue.get_nowait()
                except Empty:
                    break

                results.append(function(arg))

            return results

        all_results = list(
            itertools.chain.from_iterable(parallelize_execution(*(executor for _ in range(threads)), pool=pool))
        )

    gevent.idle()

    return all_results


def ellipsis_string(string, max_length):
    """Shorten *string* to at most *max_length* characters by cutting out its middle.

    The removed middle is replaced with the ``">...<"`` marker. The same number of characters is
    kept from each end, so the result can be one character shorter than *max_length*. The string
    is returned unchanged when it already fits, or when *max_length* is too small to keep at least
    one character on each side of the marker.
    """
    marker = ">...<"
    if len(string) > max_length:
        keep = (max_length - len(marker)) // 2
        if keep >= 1:
            return "".join((string[:keep], marker, string[-keep:]))
    return string


def format_long_list_for_logging(items, limit=10):
    """Join up to *limit* of *items* with ``", "`` and append ``", ..."`` if there are more.

    *items* may be any iterable, including a one-shot iterator or generator; it is consumed for at
    most ``limit + 1`` elements. Each shown element is converted with ``str``.
    """
    # One extra element is fetched only to learn whether anything was left out.
    head = list(itertools.islice(items, limit + 1))
    human_list = ", ".join(map(str, head[:limit]))

    if len(head) > limit:
        human_list += ", ..."

    return human_list


def format_time(time, format="%Y.%m.%d %H:%M:%S"):
    """Render the Unix timestamp *time* as host-local time using the ``strftime`` *format*."""
    local_moment = datetime.datetime.fromtimestamp(time)
    return local_moment.strftime(format)


def timetuple_to_time(timetuple, strip_to_epoch=False):
    """Convert a UTC time tuple to a Unix timestamp (via ``calendar.timegm``).

    Times before the epoch raise ``ValueError``, unless *strip_to_epoch* is set, in which case
    they are clamped to 0.
    """
    timestamp = calendar.timegm(timetuple)
    if timestamp >= 0:
        return timestamp
    if not strip_to_epoch:
        raise ValueError("Time is earlier than 01.01.1970")
    return 0


def datetime_to_timestamp(dt):
    """Return *dt* as a float Unix timestamp, reading its fields as host-local time (``time.mktime``).

    NOTE(review): any tzinfo on *dt* only affects the DST flag passed to mktime, not the offset.
    """
    whole_seconds = time.mktime(dt.timetuple())
    return whole_seconds + dt.microsecond / 1e6


class StopWatch:
    """Wall-clock stopwatch (based on ``time.time()``) with lap support.

    Timing starts when the object is created, not when a ``with`` block is entered. Used as a
    context manager, it logs the total time since the start (or last reset) at INFO level on exit,
    if *log_str* was given.
    """

    def __init__(self, log_str=None):
        self._start = time.time()
        self._lap = self._start  # moment of the last get()/split()/reset(); split() measures from here
        self._log_str = log_str

    def get(self):
        """Return seconds since the start (or last reset) and mark a new lap."""
        self._lap = time.time()
        return self._lap - self._start

    def split(self):
        """Return seconds since the previous lap mark and mark a new lap."""
        this_lap = time.time()
        result = this_lap - self._lap
        self._lap = this_lap
        return result

    def reset(self):
        """Restart the watch; return seconds elapsed since the previous start."""
        new_start = time.time()
        result = new_start - self._start
        self._start = new_start
        self._lap = new_start
        return result

    def __enter__(self):
        # Return the watch so `with StopWatch(...) as sw:` binds it (it used to bind None).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._log_str:
            log.info(f"{self._log_str}. Execution time {time.time() - self._start} seconds")


class LocationSegment:
    """Names of the location attributes, declared in the order get_location_path joins them."""

    COUNTRY = "country"
    CITY = "city"
    DATACENTER = "datacenter"
    QUEUE = "queue"
    RACK = "rack"
    UNIT = "unit"


def get_location_path(location, stop_at=None):
    """Join *location*'s segment attributes with ``|``, from country down to unit.

    Falsy attribute values (e.g. ``None`` or ``""``) are rendered as ``-``. If *stop_at* is given,
    the path ends with that segment. A value that is not a ``LocationSegment`` name raises
    ``ValueError``.
    """
    path_segments = [
        LocationSegment.COUNTRY,
        LocationSegment.CITY,
        LocationSegment.DATACENTER,
        LocationSegment.QUEUE,
        LocationSegment.RACK,
        LocationSegment.UNIT,
    ]

    if stop_at is not None:
        del path_segments[path_segments.index(stop_at) + 1 :]

    return "|".join([getattr(location, segment) or "-" for segment in path_segments])


def is_not_none_arg_validator(instance, attribute, value):
    """Validator that rejects ``None`` by raising ``AttributeError``.

    The ``(instance, attribute, value)`` signature looks like an attrs-style validator; confirm
    against its users. Only ``attribute.name`` is read, for the error message.
    """
    if value is not None:
        return
    raise AttributeError("'{}' is essential for task, but here a None value was found.".format(attribute.name))


def merge_iterators_by_uuid(a_iterator, b_iterator):
    """Merge-join two iterators on the ``uuid`` attribute of their items.

    Yields pairs: ``(a, b)`` when the current items of both sides have the same uuid,
    ``(a, None)`` / ``(None, b)`` for an item without a counterpart. Items are pulled lazily, one
    at a time per side, and on a mismatch the side with the smaller uuid is emitted and advanced.
    The pairing is therefore only correct if both iterators are sorted by ascending uuid.
    NOTE(review): this is not checked here -- confirm callers sort.

    :param a_iterator: iterator (``next()`` is called on it), or a falsy value such as ``None``
        to treat that side as empty.
    :param b_iterator: same as *a_iterator*.

    NB: "no pending item" is tracked by truthiness, so a falsy item from either iterator
    (e.g. ``None``) is silently skipped rather than yielded.
    """
    # Pending (not yet yielded) item of each side; a falsy value means "none pending".
    a = None
    b = None

    # An iterator is replaced with None once exhausted; stop when both are.
    while a_iterator or b_iterator:
        # Refill each side that has nothing pending.
        if not a and a_iterator:
            try:
                a = next(a_iterator)
            except StopIteration:
                a = None
                a_iterator = None
        if not b and b_iterator:
            try:
                b = next(b_iterator)
            except StopIteration:
                b = None
                b_iterator = None

        if a and b and a.uuid == b.uuid:
            # Matched: emit the pair and advance both sides.
            yield a, b
            a = None
            b = None
        elif a and (not b or a.uuid < b.uuid):
            # `a` is behind (or `b` is exhausted): it has no counterpart.
            yield a, None
            a = None
        elif b and (not a or b.uuid < a.uuid):
            # `b` is behind (or `a` is exhausted): it has no counterpart.
            yield None, b
            b = None


def replace_group_with_its_members(owners, group_to_members):
    """Expand group entries of *owners* into their members; return a sorted list without duplicates.

    Owners for which ``is_group`` is true are replaced by ``group_to_members[owner]``; a group
    missing from the mapping contributes nobody. Other owners are kept as they are.
    """
    from walle.clients.staff import is_group

    expanded = set()
    for owner in owners:
        if is_group(owner):
            expanded.update(group_to_members.get(owner, []))
        else:
            expanded.add(owner)

    return sorted(expanded)


class DummyContextManager:
    """No-op context manager.

    ``with DummyContextManager() as cm`` binds the manager itself. Exceptions raised inside the
    block are never suppressed.
    """

    def __enter__(self):
        """Return the manager itself as the ``with`` target."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Do nothing; the falsy return value lets any exception propagate."""
        return False


@functools.cache
def get_existing_tiers():
    """Return the ``tier_configuration.existing_tiers`` config value.

    Memoized with ``functools.cache``: config is read only on the first call.
    """
    tiers = config.get_value("tier_configuration.existing_tiers")
    return tiers


@functools.cache
def get_expert_shards_count(tier):
    """Return ``expert_system.shards_num.<tier>`` from config, or ``expert_system.shards_num_default``.

    A float *tier* is truncated to ``int`` so that e.g. ``2.0`` maps to the ``...shards_num.2`` key.
    Results are memoized per *tier* with ``functools.cache``.
    NOTE(review): the default value is read even when the tier-specific key exists.
    """
    if isinstance(tier, float):
        tier = int(tier)
    fallback = config.get_value("expert_system.shards_num_default")
    return config.get_value(f"expert_system.shards_num.{tier}", fallback)
