"""Provides various tools for working with limits."""

import dataclasses

from sepelib.core import constants
from sepelib.core.exceptions import Error, LogicalError
from walle.models import timestamp


@dataclasses.dataclass(init=True, repr=True, eq=True)
class TimedLimit:
    """A threshold on how many events may occur within a time window ending now."""

    # Length of the window, in seconds.
    period: int
    # Event-count threshold for the window (see ``check_limits`` for inclusive/exclusive semantics).
    limit: int


def parse_timed_limits(limits_config: list[dict]) -> list[TimedLimit]:
    """Parse every entry of a timed-limits config list.

    :param limits_config: list of mappings, each accepted by ``parse_timed_limit``.
    :return: the parsed ``TimedLimit`` objects, in config order.
    :raises Error: propagated from ``parse_timed_limit`` for the first invalid entry.
    """
    return list(map(parse_timed_limit, limits_config))


def parse_period(string):
    """Convert a time period specification into a number of seconds.

    Accepted forms (case-insensitive, surrounding whitespace ignored):
    a bare positive integer, meaning seconds, or a positive integer followed by one
    unit letter: ``s`` (seconds), ``m`` (minutes), ``h`` (hours), ``d`` (days), ``w`` (weeks).

    :param string: the specification; converted with ``str()`` first.
    :return: the period length in seconds.
    :raises Error: if the specification is malformed, non-positive or uses an unknown unit.
    """
    seconds_per_unit = {
        "s": 1,
        "m": constants.MINUTE_SECONDS,
        "h": constants.HOUR_SECONDS,
        "d": constants.DAY_SECONDS,
        "w": constants.WEEK_SECONDS,
    }

    spec = str(string).strip().lower()

    try:
        if spec.isdigit():
            amount, suffix = int(spec), "s"
        else:
            # Last character is the unit; everything before it must be an integer.
            amount, suffix = int(spec[:-1]), spec[-1:]

        if amount <= 0 or suffix not in seconds_per_unit:
            raise ValueError()
    except ValueError:
        raise Error("Invalid time period specification: {}", string)

    return amount * seconds_per_unit[suffix]


def nice_period(period):
    """Return ``period`` (a number of seconds) in human-readable form.

    The period is split into weeks, days, hours, minutes and seconds. Only the non-zero
    components are listed, largest first, e.g. ``"1 day 2 hours"``.

    A period with no non-zero component (e.g. ``0``) is rendered as ``"0 seconds"``
    instead of an empty string, so messages that embed it stay readable.
    """
    parts = []
    remainder = period
    for interval, singular, plural in (
        (constants.WEEK_SECONDS, 'week', 'weeks'),
        (constants.DAY_SECONDS, 'day', 'days'),
        (constants.HOUR_SECONDS, 'hour', 'hours'),
        (constants.MINUTE_SECONDS, 'minute', 'minutes'),
        (1, 'second', 'seconds'),
    ):
        count, remainder = divmod(remainder, interval)
        if count:
            parts.append('{} {}'.format(count, singular if count == 1 else plural))

    if not parts:
        return '0 seconds'
    return ' '.join(parts)


class CheckResult:
    """Outcome of a limits check; truthy when the check passed.

    :param result: whether the check passed (any truthy/falsy value is accepted).
    :param info: optional details about a failure (``check_limits`` passes a ``_CheckLimitsInfo``).
    """

    def __init__(self, result, info=None):
        self.result = result
        self.info = info

    def __bool__(self):
        # __bool__ must return a real bool; coerce so a truthy non-bool result doesn't raise TypeError.
        return bool(self.result)

    def __nonzero__(self):
        # Python 2 spelling of __bool__, kept for callers that invoke it explicitly.
        return self.__bool__()

    def __repr__(self):
        return '{}(result={!r}, info={!r})'.format(type(self).__name__, self.result, self.info)


class _CheckLimitsInfo:
    """Description of the timed limit that made a limits check fail.

    Attached as ``info`` to a failed ``CheckResult``. ``str()`` renders it as a message fragment
    such as ``"3 times during the last 1 hour while limit is 2"``.
    """

    def __init__(self, period, limit, current=None, objects=None):
        self.period = period  # in seconds
        self.limit = limit
        # Number of objects counted in the window, or None when the check failed without counting.
        self.current = current
        # Objects passed to the check (may be None).
        self.objects = objects

    def __str__(self):
        period_msg = nice_period(self.period)

        if self.current is None:
            # Nothing was counted (check_limits does this when the limit admits no objects at all).
            # Report just the limit instead of rendering "None times".
            return 'limit is {limit} during the last {period}'.format(limit=self.limit, period=period_msg)

        countable = "time" if self.current == 1 else "times"

        # Only mention the limit when it differs from the reported count.
        limit_msg = ""
        if self.current != self.limit:
            limit_msg = " while limit is {limit}".format(limit=self.limit)

        return '{current} {countable} during the last {period}{limit_msg}'.format(
            current=self.current, countable=countable, period=period_msg, limit_msg=limit_msg
        )


def check_timed_limits(model, query, time_field, limits: list[TimedLimit], start_time=0, inclusive=True):
    """Check timed limits against objects of ``model`` stored in its collection.

    Fetches the objects matching ``query`` whose time field is within the longest limit
    period (and not earlier than ``start_time``), newest first. Then delegates to
    ``check_limits``.

    :param model: class whose ``_get_collection()`` is queried (presumably a mongoengine
        document class; confirm).
    :param query: collection filter dict; must not already constrain the time field.
    :param time_field: field object whose ``db_field`` names the timestamp key.
    :param limits: limits to check; an empty list always passes.
    :param start_time: objects older than this timestamp are ignored.
    :param inclusive: see ``check_limits``.
    :raises LogicalError: if ``query`` already filters on the time field.
    """
    if not limits:
        return CheckResult(True)

    field_name = time_field.db_field
    longest_period = get_max_period(limits)

    # The time filter is added below; a caller-supplied one would be silently overwritten.
    if field_name in query:
        raise LogicalError()

    filtered_query = query.copy()
    window_start = max(timestamp() - longest_period, start_time)
    filtered_query[field_name] = {"$gte": window_start}

    # Only the timestamp is needed, sorted newest first as check_limits expects.
    projection = {"_id": False, field_name: True}
    cursor = model._get_collection().find(filtered_query, projection).sort(field_name, -1)
    recent_objects = list(cursor)

    return check_limits(recent_objects, field_name, limits, inclusive=inclusive)


def check_limits(objs, time_field_name, limits: list[TimedLimit], inclusive=True):
    """Check already-fetched objects against a list of timed limits.

    Limits are checked in the given order, and the first violated one is reported.

    :param objs: mappings holding a timestamp under ``time_field_name``. They must be ordered
        newest first, because counting for a limit stops at the first object older than its window.
    :param time_field_name: key of the timestamp in each object.
    :param limits: the limits to check.
    :param inclusive: if true, ``limit`` objects are still allowed and ``limit + 1`` violates the
        limit; if false, reaching ``limit`` objects already violates it.
    :return: ``CheckResult(True)`` if no limit is violated, otherwise ``CheckResult(False, info)``
        where ``info`` describes the first violated limit.
    """
    now = timestamp()

    for timed_limit in limits:
        window_start = now - timed_limit.period
        if inclusive:
            threshold = timed_limit.limit + 1
        else:
            threshold = timed_limit.limit

        # A non-positive threshold fails regardless of how many objects there are.
        if threshold <= 0:
            return CheckResult(False, _CheckLimitsInfo(timed_limit.period, timed_limit.limit))

        in_window = 0
        for obj in objs:
            if obj[time_field_name] < window_start:
                break  # objs are newest first, so every remaining object is older as well

            in_window += 1
            if in_window >= threshold:
                info = _CheckLimitsInfo(timed_limit.period, timed_limit.limit, current=in_window, objects=objs)
                return CheckResult(False, info)

    return CheckResult(True)


def get_max_period(limits: list[TimedLimit]):
    """Return the longest ``period`` among ``limits``, or ``None`` if ``limits`` is empty."""
    return max((timed_limit.period for timed_limit in limits), default=None)


def parse_timed_limit(limit_config) -> TimedLimit:
    """Build a ``TimedLimit`` from a single ``{"period": ..., "limit": ...}`` mapping.

    ``period`` is parsed with ``parse_period``, and ``limit`` is converted with ``int()``.

    :raises Error: if the mapping has any keys other than exactly ``period`` and ``limit``,
        is not a mapping at all, or either value fails to parse.
    """
    try:
        present_keys = set(limit_config.keys())
        if present_keys != {"period", "limit"}:
            raise ValueError()

        period_seconds = parse_period(limit_config["period"])
        max_count = int(limit_config["limit"])
        parsed = TimedLimit(period_seconds, max_count)
    except Exception:
        # Any failure (wrong type, missing/extra keys, bad values) is reported uniformly.
        raise Error("Invalid timed limit specification: {}", limit_config)

    return parsed
