import itertools
import math
import numbers
import threading

from contextlib2 import contextmanager

from sepelib.core.exceptions import LogicalError
from walle.models import timestamp

# Aggregation codes: each describes how a value is read from an instance, combined
# between instances and rolled up over time. They also serve as metric-name suffixes
# in the yasm export (e.g. "count_summ", "age_axxx").

# Absolute from instance; max between instances; max for time roll-up.
ABSOLUTE = 'axxx'
# Absolute from instance; min between instances; min for time roll-up.
ABSOLUTE_MINIMUM = 'annn'
# Absolute from instance; sum between instances; max for time roll-up.
DISTRIBUTED = 'ammx'
# Delta from instance; sum between instances; sum for time roll-up.
DELTA = 'summ'

# Aggregation codes accepted for Age meters.
AGE_AGGREGATORS = {ABSOLUTE_MINIMUM, ABSOLUTE}


class Counter:
    """Numeric metric exported to yasm as ``count_<aggregation_method>``.

    :param aggregation_method: aggregation code (DELTA, ABSOLUTE, ABSOLUTE_MINIMUM or
        DISTRIBUTED); it selects how :meth:`merge` combines two counters and is used as
        the yasm suffix.
    :param value: initial value.
    """

    def __init__(self, aggregation_method, value=0):
        self.value = value
        self._lock = threading.Lock()
        self.aggregation_method = aggregation_method

    def set(self, value):
        """Overwrite the current value."""
        with self._lock:
            self.value = value

    def add(self, value):
        """Add ``value`` (may be negative) to the current value."""
        with self._lock:
            self.value += value

    @classmethod
    def load(cls, storage_format):
        """Rebuild a counter from the ``(aggregation_method, value)`` pair produced by :meth:`dump`.

        :raises ValueError: if the stored value is not a number.
        """
        aggregation_method, value = storage_format
        counter = cls(aggregation_method)
        if not isinstance(value, numbers.Number):
            raise ValueError("Unknown storage format")
        counter.value = value
        return counter

    def to_yasm_format(self):
        """Return the ``(suffix, value)`` pair used when exporting to yasm."""
        return "count_" + self.aggregation_method, self.value

    def dump(self):
        """Return a serializable ``(aggregation_method, value)`` pair."""
        return self.aggregation_method, self.value

    def merge(self, counter):
        """Fold ``counter`` into this counter.

        DELTA and DISTRIBUTED values are summed (both are "sum between instances" codes),
        ABSOLUTE keeps the maximum and ABSOLUTE_MINIMUM keeps the minimum.

        :raises ValueError: if the aggregation methods differ or are not recognised.
        """
        with self._lock:
            if self.aggregation_method != counter.aggregation_method:
                raise ValueError("Different aggregation methods given")
            # NOTE(review): assumes merge combines meters of different instances (as
            # StatsManager.merge does with another manager), so DISTRIBUTED follows its
            # "between instances" rule (sum) rather than its time roll-up rule (max).
            if self.aggregation_method in (DELTA, DISTRIBUTED):
                self.value += counter.value
            elif self.aggregation_method == ABSOLUTE:
                self.value = max(self.value, counter.value)
            elif self.aggregation_method == ABSOLUTE_MINIMUM:
                self.value = min(self.value, counter.value)
            else:
                raise ValueError("Unknown aggregation method given")


class Age:
    """Timestamp metric exported to yasm as the time elapsed since ``last_timestamp``.

    :param aggregation_method: ABSOLUTE (merge keeps the larger timestamp) or
        ABSOLUTE_MINIMUM (merge keeps the smaller one).
    :param last_timestamp: initial timestamp, 0 by default.
    """

    def __init__(self, aggregation_method=ABSOLUTE, last_timestamp=0):
        self.aggregation_method = aggregation_method
        self._lock = threading.Lock()
        self.last_timestamp = last_timestamp

    def set(self, value):
        """Record ``value`` as the last timestamp, or the current ``timestamp()`` when None."""
        with self._lock:
            self.last_timestamp = timestamp() if value is None else value

    @classmethod
    def load(cls, storage_format):
        """Inverse of :meth:`dump`: build an Age from ``(aggregation_method, last_timestamp)``.

        :raises ValueError: when the stored timestamp is not a number.
        """
        aggregation_method, stored = storage_format
        if not isinstance(stored, numbers.Number):
            raise ValueError("Unknown storage format")
        restored = cls(aggregation_method)
        restored.last_timestamp = stored
        return restored

    def age(self):
        """Return ``timestamp() - last_timestamp``."""
        now = timestamp()
        return now - self.last_timestamp

    def to_yasm_format(self):
        """Return the ``(suffix, age)`` pair used when exporting to yasm."""
        suffix = "age_" + self.aggregation_method
        return suffix, self.age()

    def dump(self):
        """Return a serializable ``(aggregation_method, last_timestamp)`` pair."""
        return (self.aggregation_method, self.last_timestamp)

    def merge(self, age):
        """Combine ``age`` into this meter: max timestamp for ABSOLUTE, min for ABSOLUTE_MINIMUM.

        :raises ValueError: if the aggregation methods differ or are not recognised.
        """
        with self._lock:
            if self.aggregation_method != age.aggregation_method:
                raise ValueError("Different aggregation methods given")
            for method, choose in ((ABSOLUTE, max), (ABSOLUTE_MINIMUM, min)):
                if self.aggregation_method == method:
                    self.last_timestamp = choose(self.last_timestamp, age.last_timestamp)
                    break
            else:
                raise ValueError("Unknown aggregation method given")


class Histogram:
    """Base class for thread-safe bucketed histograms.

    Subclasses set ``_BORDERS`` (one border per bucket) and implement :meth:`add`, which
    maps a value to a bucket index. In the subclasses of this module each border is the
    lower edge of its bucket.
    """

    _BORDERS = []

    def __init__(self):
        self._lock = threading.Lock()
        self._buckets = [0] * len(self._BORDERS)

    # pylint: disable=protected-access
    @classmethod
    def load(cls, storage_format):
        """Rebuild a histogram from the bucket counts produced by :meth:`dump`.

        :raises ValueError: unless ``storage_format`` is a list/tuple with one count per bucket.
        """
        hist = cls()
        if not isinstance(storage_format, (list, tuple)) or len(storage_format) != len(cls._BORDERS):
            raise ValueError("Unknown storage format")
        hist._buckets = list(storage_format)
        return hist

    def dump(self):
        """Return a copy of the bucket counts."""
        return list(self._buckets)

    def add(self, value):
        """Count ``value`` in its bucket; implemented by subclasses."""
        raise NotImplementedError

    def add_multiple(self, iterator):
        """Call :meth:`add` for every value of ``iterator``."""
        for value in iterator:
            self.add(value)

    @classmethod
    def from_list(cls, values_list):
        """Return a new histogram containing every value of ``values_list``."""
        hist = cls()
        hist.add_multiple(values_list)
        return hist

    def get_difference_from(self, other_hist):
        """Return a new histogram with per-bucket ``self - other_hist`` counts, floored at 0.

        :raises ValueError: if the two histograms use different bucket borders.
        """
        self._check_compatible(other_hist)
        hist = type(self)()
        hist._buckets = [max(left - right, 0) for left, right in zip(self._buckets, other_hist._buckets)]
        return hist

    def get_percentile(self, percentile):
        """Estimate the value at ``percentile`` (a fraction, e.g. 0.99).

        Finds the first bucket where the cumulative count reaches ``percentile`` of the total
        and returns the midpoint between its border and the next one (the last border is
        reused for the last bucket). An empty histogram yields the midpoint of the first bucket.
        """
        with self._lock:
            items_count = sum(self._buckets)
            items_required = items_count * percentile
            items_so_far = 0
            pct_bucket = 0
            for idx, value in enumerate(self._buckets):
                pct_bucket = idx
                items_so_far += value
                if items_so_far >= items_required:
                    break
        return (self._BORDERS[pct_bucket] + self._BORDERS[min(pct_bucket + 1, len(self._BORDERS) - 1)]) / 2.0

    def to_yasm_format(self):
        """Return ``("hgram_dhhh", [[border, count], ...])`` without leading/trailing empty buckets.

        If every bucket is empty nothing is trimmed, so all borders are reported with zero counts.
        """
        first_nonzero = next((idx for idx, value in enumerate(self._buckets) if value > 0), 0)
        # actually next after the last non-zero
        last_nonzero = len(self._buckets) - next(
            (idx for idx, value in enumerate(reversed(self._buckets)) if value > 0), 0
        )
        borders = self._BORDERS[first_nonzero:last_nonzero]
        values = self._buckets[first_nonzero:last_nonzero]
        return "hgram_dhhh", [[edge, count] for edge, count in zip(borders, values)]

    # pylint: disable=protected-access
    def merge(self, hist):
        """Add the bucket counts of ``hist`` to this histogram.

        :raises ValueError: if ``hist`` uses different bucket borders.
        """
        # Validate before touching any bucket: mismatched layouts would otherwise either
        # raise IndexError halfway through the loop or be silently truncated.
        self._check_compatible(hist)
        with self._lock:
            for idx, count in enumerate(hist._buckets):
                self._buckets[idx] += count

    # pylint: disable=protected-access
    def _check_compatible(self, other):
        """Raise ValueError unless ``other`` has the same bucket borders as this histogram."""
        # Identity check first: same-class histograms share one _BORDERS list.
        if self._BORDERS is not other._BORDERS and self._BORDERS != other._BORDERS:
            raise ValueError("Different histogram borders given")


class IntegerLinearHistogram(Histogram):
    """Histogram with fixed-width integer buckets over ``[MIN, MAX]``.

    Borders are ``MIN, MIN + STEP, ...`` followed by ``MAX``. Values below MIN are counted
    in the first bucket and values above MAX are counted like MAX.
    """

    STEP = 1
    MIN = 0
    MAX = 5000
    # Computed once at class creation: a subclass overriding STEP/MIN/MAX must recompute _BORDERS.
    _BORDERS = list(range(MIN, MAX, STEP)) + [MAX]

    def add(self, value):
        """Count ``value`` (truncated with ``int``) in its bucket, clamping out-of-range values."""
        # Clamp the bucket *index* at 0 (not at MIN, which is a value) so values below MIN
        # land in the first bucket even when MIN != 0; min(..., MAX) bounds the index from above.
        offset = max(0, (min(int(value), self.MAX) - self.MIN) // self.STEP)
        with self._lock:
            self._buckets[offset] += 1

    def get_percentile(self, percentile):
        """Return :meth:`Histogram.get_percentile` truncated to ``int``."""
        return int(super().get_percentile(percentile))


_LOGARITHMIC_BASE = 1.5


class LogarithmicHistogram(Histogram):
    """Histogram whose buckets grow geometrically by a factor of ``_LOGARITHMIC_BASE``.

    Bucket 0 collects non-positive values and values below ``_LOGARITHMIC_BASE ** MIN_LOG``;
    bucket ``i > 0`` starts at ``_LOGARITHMIC_BASE ** (MIN_LOG + i - 1)``; values at or above
    ``_LOGARITHMIC_BASE ** MAX_LOG`` go to the last bucket.
    """

    MIN_LOG = -50
    MAX_LOG = 50
    _BORDERS = [0.0] + [_LOGARITHMIC_BASE**x for x in range(MIN_LOG, MAX_LOG + 1)]

    @classmethod
    def _bucket_index(cls, value):
        """Return the bucket index for ``value``; non-positive values map to bucket 0."""
        # math.log is undefined for value <= 0, so those values share the zero bucket.
        if not value > 0:
            return 0
        log_value = math.log(value, _LOGARITHMIC_BASE)
        clamped = max(cls.MIN_LOG - 1, min(cls.MAX_LOG, log_value))
        return int(math.floor(clamped - cls.MIN_LOG + 1))

    @classmethod
    def fast_from_list(cls, iterator):
        """Build a histogram from ``iterator``, taking the lock once instead of per value."""
        hist = cls()
        bucket_index = cls._bucket_index
        with hist._lock:
            buckets = hist._buckets
            for value in iterator:
                buckets[bucket_index(value)] += 1
        return hist

    def add(self, value):
        """Count ``value``; non-positive values are counted in bucket 0, as in fast_from_list."""
        index = self._bucket_index(value)
        with self._lock:
            self._buckets[index] += 1


class StatsManager:
    """Registry of named counters, ages and histograms.

    Metric keys are strings or tuples/lists of strings; sequence keys are rendered as
    dot-joined names by :meth:`dump` and :meth:`to_yasm_format`. Registering and replacing
    meters is guarded by ``self._lock``; each meter guards its own updates. Mutating methods
    return ``self`` so calls can be chained.
    """

    # Numeric tags identifying the meter class of each entry in dump()/load() state.
    _idx_to_meters = {1: Counter, 2: LogarithmicHistogram, 3: IntegerLinearHistogram, 4: Age}
    _meters_to_idx = {cls: idx for idx, cls in _idx_to_meters.items()}

    def __init__(self):
        self._counters = {}
        self._histograms = {}
        self._ages = {}
        self._lock = threading.Lock()

    def reset(self):
        """Drop every registered metric."""
        with self._lock:
            self._counters = {}
            self._histograms = {}
            self._ages = {}
        return self

    def increment_counter(self, key, value=1, aggregation=DELTA):
        """Add ``value`` to counter ``key``, creating it with ``aggregation`` if needed."""
        self._get_counter(key, aggregation).add(value)
        return self

    def decrement_counter(self, key, value=1, aggregation=DELTA):
        """Subtract ``value`` from counter ``key``, creating it with ``aggregation`` if needed."""
        self._get_counter(key, aggregation).add(-value)
        return self

    def set_counter_value(self, key, value, aggregation=DELTA):
        """Overwrite counter ``key`` with ``value``."""
        self._get_counter(key, aggregation).set(value)
        return self

    def set_age_timestamp(self, key, value=None, aggregation=ABSOLUTE):
        """Set age ``key`` to ``value``, or to the current timestamp when ``value`` is None."""
        self._get_age(key, aggregation).set(value)
        return self

    def get_counter_value(self, key):
        """Return the value of counter ``key``, or 0 if it is not registered."""
        counter = self._counters.get(key)
        return 0 if counter is None else counter.value

    def get_age_timestamp(self, key):
        """Return the last timestamp of age ``key``, or 0 if it is not registered."""
        age = self._ages.get(key)
        return 0 if age is None else age.last_timestamp

    def add_sample(self, key, value, hist_cls=LogarithmicHistogram):
        """Add ``value`` to histogram ``key`` (created as ``hist_cls`` on first use)."""
        self._get_histogram(key, hist_cls).add(value)
        return self

    def add_sample_multiple(self, key, values_list, hist_cls=LogarithmicHistogram):
        """Add every value of ``values_list`` to histogram ``key``."""
        self._get_histogram(key, hist_cls).add_multiple(values_list)
        return self

    def get_sample(self, key, hist_cls=LogarithmicHistogram):
        """Return histogram ``key``, creating an empty ``hist_cls`` if it does not exist."""
        return self._get_histogram(key, hist_cls)

    def has_sample(self, key):
        """Return whether histogram ``key`` is registered."""
        with self._lock:
            return key in self._histograms

    # pylint: disable=protected-access
    def merge(self, manager):
        """Fold every metric of ``manager`` into this manager.

        Metrics present in both managers are combined with the meter's own ``merge``; metrics
        only present in ``manager`` are copied, so ``manager``'s meters are never shared.

        :raises ValueError: from a meter's ``merge`` (e.g. mismatched aggregation) or from
            ``load`` while copying a meter; metrics processed before the failure stay merged.
        """
        with self._lock:
            for key, meter in manager._iterate_metrics():
                container = self._get_container(type(meter))
                if key in container:
                    container[key].merge(meter)
                else:
                    # Copy through dump()/load(): storing ``meter`` itself would make both
                    # managers mutate one shared meter afterwards.
                    container[key] = type(meter).load(meter.dump())
        return self

    # pylint: disable=protected-access
    @classmethod
    def load(cls, state):
        """Rebuild a manager from the mapping produced by :meth:`dump`.

        String keys are split on "." into tuples.
        """
        manager = cls()
        for key, (meter_idx, storage_format) in state.items():
            key = cls._parse_string_key(key)
            meter_cls = cls._idx_to_meters[meter_idx]
            manager._get_container(meter_cls)[key] = meter_cls.load(storage_format)
        return manager

    def dump(self):
        """Return ``{string_key: (meter_idx, meter_state)}`` for every metric.

        :raises ValueError: if two metrics render to the same string key.
        """
        state = {}
        for key, meter in self._iterate_metrics():
            string_key = self._format_string_key(key)
            if string_key in state:
                raise ValueError("Metrics with identical names found: {}".format(string_key))
            state[string_key] = (self._meters_to_idx[type(meter)], meter.dump())
        return state

    def to_yasm_format(self):
        """Return ``[[name_suffix, value], ...]`` for every metric."""
        metrics_list = []
        for key, meter in self._iterate_metrics():
            suffix, value = meter.to_yasm_format()
            metrics_list.append(["{}_{}".format(self._format_string_key(key), suffix), value])
        return metrics_list

    @staticmethod
    def _format_string_key(key):
        """Join tuple/list keys with "."; other keys are returned unchanged."""
        if isinstance(key, (tuple, list)):
            return ".".join(key)
        return key  # assume string otherwise

    @staticmethod
    def _parse_string_key(key):
        """Split a dotted string key into a tuple."""
        return tuple(key.split("."))

    def _iterate_metrics(self):
        """Return an iterator of ``(key, meter)`` pairs: ages, then counters, then histograms."""
        return itertools.chain.from_iterable(
            [
                # Snapshot each dict with list() so concurrent registration cannot
                # invalidate the iteration.
                list(self._ages.items()),
                list(self._counters.items()),
                list(self._histograms.items()),
            ]
        )

    def _get_counter(self, key, aggregation) -> Counter:
        """Return counter ``key``, creating it with ``aggregation`` on first use.

        :raises ValueError: if the existing counter uses a different aggregation method.
        """
        with self._lock:
            # Fetch under the lock: a concurrent reset() swaps the dict, so re-reading
            # self._counters after releasing the lock could raise KeyError.
            counter = self._counters.get(key)
            if counter is None:
                counter = self._counters[key] = Counter(aggregation)
        if counter.aggregation_method != aggregation:
            raise ValueError(
                "Cannot change aggregation method from {} to {}".format(counter.aggregation_method, aggregation)
            )
        return counter

    def _get_age(self, key, aggregation):
        """Return age ``key``, creating it with ``aggregation`` on first use.

        :raises ValueError: if ``aggregation`` is not allowed for ages or differs from the
            existing age's method.
        """
        if aggregation not in AGE_AGGREGATORS:
            raise ValueError("Cannot use {} aggregation method on age".format(aggregation))
        with self._lock:
            age = self._ages.get(key)
            if age is None:
                age = self._ages[key] = Age(aggregation)
        if age.aggregation_method != aggregation:
            raise ValueError(
                "Cannot change aggregation method from {} to {}".format(age.aggregation_method, aggregation)
            )
        return age

    def _get_histogram(self, key, cls=LogarithmicHistogram):
        """Return histogram ``key``, creating a ``cls`` instance on first use.

        An existing histogram is returned as-is even if it is not a ``cls`` instance.
        """
        with self._lock:
            hist = self._histograms.get(key)
            if hist is None:
                hist = self._histograms[key] = cls()
        return hist

    def _get_container(self, meter_cls):
        """Return the dict that stores meters of ``meter_cls``.

        :raises RuntimeError: for an unknown meter class.
        """
        if issubclass(meter_cls, Counter):
            return self._counters
        elif issubclass(meter_cls, Histogram):
            return self._histograms
        elif issubclass(meter_cls, Age):
            return self._ages
        else:
            raise RuntimeError("Unknown metric class {!r}".format(meter_cls))


class Timing:
    """Submit stopwatch readings as histogram samples under a common key prefix.

    :param key: prefix (str, or tuple/list of str) prepended to every submitted metric key.
    :param stopwatch: object providing ``split()``, ``reset()`` and ``get()``; presumably each
        returns an elapsed duration — TODO confirm units against the stopwatch implementation.
    :param stats: manager receiving the samples; the module-level ``stats_manager`` when None.
    """

    def __init__(self, key, stopwatch, stats=None):
        self.key = self._normalize_key(prime_key=(), key=key)
        self.stopwatch = stopwatch
        # Explicit None check: only a missing argument falls back to the global manager.
        self.stats_manager = stats if stats is not None else stats_manager

    def split(self, key):
        """Submit ``stopwatch.split()`` under ``key``."""
        self._submit(key, self.stopwatch.split)

    def reset(self, key):
        """Submit ``stopwatch.reset()`` under ``key``."""
        self._submit(key, self.stopwatch.reset)

    def submit(self, key):
        """Submit ``stopwatch.get()`` under ``key``."""
        self._submit(key, self.stopwatch.get)

    def _submit(self, key, getter):
        self.stats_manager.add_sample(self._normalize_key(self.key, key), getter())

    @contextmanager
    def measure(self, name, success=None, error=None):
        """
        Measure time that execution of context have taken.

        On exit the stopwatch is split under ``name``. Before that split, the current
        stopwatch reading (``get()``) is submitted under ``name`` + ``success`` when the
        body completes, or under ``name`` + ``error`` when it raises (the exception is
        re-raised).

        :param name: name of resulting metric
        :param success: name suffix of the metric for a successful run, or None to skip it
        :param error: name suffix of the metric for a failed run, or None to skip it

        :type name: str|tuple
        :type success: str|tuple|None
        :type error: str|tuple|None
        """
        # Normalize up front so tuple names are flattened (("a", "b"), "ok") -> ("a", "b", "ok")
        # instead of producing nested keys that cannot be dot-joined later.
        base_key = self._normalize_key((), name)
        try:
            yield
        except Exception:
            if error is not None:
                self.submit(self._normalize_key(base_key, error))
            raise
        else:
            if success is not None:
                self.submit(self._normalize_key(base_key, success))
        finally:
            self.split(key=base_key)

    @staticmethod
    def _normalize_key(prime_key, key):
        """Return the tuple ``prime_key`` extended by ``key`` (a str, tuple or list of parts).

        :raises LogicalError: for any other key type.
        """
        if isinstance(key, str):
            return prime_key + (key,)
        if isinstance(key, (tuple, list)):
            return prime_key + tuple(key)
        raise LogicalError


# Global stats manager: module-wide StatsManager instance; Timing falls back to it when
# constructed without a ``stats`` argument.
stats_manager = StatsManager()
