# coding: utf-8

import logging
import ctypes
import time

from bisect import bisect_right
from itertools import izip
from multiprocessing import Value, Array, Manager
from dateutil import parser


# Module-level logger registered under the 'unistat' name.
log = logging.getLogger('unistat')


class AccessLogCount(object):
    """Count access-log records whose ``request`` field contains ``/<endpoint>``.

    The count lives in a ``multiprocessing.Value`` (shared memory), presumably
    so that several worker processes can update the same counter.
    """

    def __init__(self, name=None, endpoint=None):
        # The raw ``endpoint`` goes into the signal name (None is dropped there);
        # the matcher needs a string, so None becomes ''.
        self.__endpoint = endpoint if endpoint is not None else ''
        base_name = name if name is not None else 'xxx'
        self.__name = join_signal_name(base_name, endpoint)
        self.__counter = Value(ctypes.c_size_t, 0)

    def update(self, record):
        """Increment the counter when ``record['request']`` mentions the endpoint."""
        needle = '/%s' % self.__endpoint
        if needle not in record['request']:
            return
        with self.__counter.get_lock():
            self.__counter.value += 1

    def get(self):
        """Return ``[<name>_summ, count]``."""
        with self.__counter.get_lock():
            signal = with_sigopt_suffix(self.__name, 'summ')
            return [signal, self.__counter.value]


class AccessLogCountByFirstStatusDigit(object):
    """Count matching access-log records grouped by the first digit of ``status``.

    Produces one ``<prefix>_<d>xx_summ`` signal per leading digit that has a
    non-zero count.
    """

    def __init__(self, name_prefix=None, endpoint=None):
        self.__endpoint = endpoint if endpoint is not None else ''
        prefix = name_prefix if name_prefix is not None else ''
        self.__name_prefix = join_signal_name(prefix, endpoint)
        # One shared-memory slot per possible leading digit 0..9.
        self.__counters = Array(ctypes.c_size_t, 10)

    def update(self, record):
        """Bump the slot for ``record['status'][0]`` if the request matches."""
        if ('/%s' % self.__endpoint) not in record['request']:
            return
        with self.__counters.get_lock():
            slots = self.__counters.get_obj()
            slots[int(record['status'][0])] += 1

    def get(self):
        """Return ``{digit: [signal_name, count]}`` for every non-zero digit."""
        result = {}
        with self.__counters.get_lock():
            for digit, hits in enumerate(self.__counters.get_obj()):
                if not hits:
                    continue
                signal = join_signal_name(self.__name_prefix, str(digit) + 'xx')
                result[digit] = [with_sigopt_suffix(signal, 'summ'), hits]
        return result


class CountByStatus(object):
    """Count records per exact integer status code in ``[0, 600)``.

    A record is considered only if ``record[endpoint_field]`` contains
    ``/<endpoint>``; status codes outside the range are silently ignored.
    """

    def __init__(self, endpoint_field, name_prefix=None, endpoint=None):
        self.__endpoint = endpoint if endpoint is not None else ''
        prefix = name_prefix if name_prefix is not None else ''
        self.__name_prefix = join_signal_name(prefix, endpoint)
        self.__size_array = 600
        self.__counters = Array(ctypes.c_size_t, self.__size_array)
        self.__endpoint_field = endpoint_field

    def update(self, record):
        """Increment the slot for ``int(record['status'])`` if the record matches."""
        if ('/%s' % self.__endpoint) not in record[self.__endpoint_field]:
            return
        status = int(record['status'])
        if not 0 <= status < self.__size_array:
            return
        with self.__counters.get_lock():
            self.__counters[status] += 1

    def get(self):
        """Return ``{status: [signal_name, count]}`` for every non-zero status."""
        result = {}
        with self.__counters.get_lock():
            for status, hits in enumerate(self.__counters):
                if hits:
                    signal = join_signal_name(self.__name_prefix, status)
                    result[status] = [with_sigopt_suffix(signal, 'summ'), hits]
        return result


class AccessLogCountByStatus(CountByStatus):
    """Per-status-code counter for access-log records.

    A ``CountByStatus`` that matches the endpoint against each record's
    ``request`` field.
    """

    def __init__(self, name_prefix=None, endpoint=None):
        base_kwargs = dict(endpoint_field='request', name_prefix=name_prefix, endpoint=endpoint)
        super(AccessLogCountByStatus, self).__init__(**base_kwargs)


class Hist(object):
    """Histogram with fixed bucket lower bounds kept in shared memory.

    ``buckets`` is searched with ``bisect_right``, so it is assumed to be
    sorted ascending. A value is counted in the bucket with the largest lower
    bound ``<= value``. Values below ``buckets[0]`` are dropped.
    """

    def __init__(self, buckets, name, get_value):
        self.__buckets = buckets
        self.__name = name
        self.__get_value = get_value
        self.__counters = Array(ctypes.c_size_t, len(buckets))

    def update(self, record):
        """Add ``get_value(record)`` to the histogram."""
        position = bisect_right(self.__buckets, self.__get_value(record))
        if position > 0:
            with self.__counters.get_lock():
                self.__counters.get_obj()[position - 1] += 1

    def get(self):
        """Return the ``hgram`` signal built from ``(bucket_bound, count)`` pairs."""
        with self.__counters.get_lock():
            pairs = zip(self.__buckets, self.__counters)
            return make_hist_from_counter(name=self.__name, counter=pairs)


class AccessLogRequestTimeHist(object):
    """Histogram of ``float(record['request_time'])`` for records hitting ``/<endpoint>``."""

    def __init__(self, buckets, name=None, endpoint=None):
        self.__endpoint = endpoint if endpoint is not None else ''
        base_name = name if name is not None else 'access_log_request'

        def request_time(record):
            return float(record['request_time'])

        self.__impl = Hist(
            buckets=buckets,
            name=join_signal_name(base_name, endpoint),
            get_value=request_time,
        )

    def update(self, record):
        """Feed ``record`` to the histogram when its request mentions the endpoint."""
        pattern = '/%s' % self.__endpoint
        if pattern in record['request']:
            self.__impl.update(record)

    def get(self):
        """Return the underlying histogram signal."""
        return self.__impl.get()


class StatServerXmlReactorsTrace(object):
    """Track the last reported integer value of each stat_server reactor (XML input).

    ``record`` presumably is an ElementTree-like element (it is navigated with
    ``.find`` and children expose ``.tag`` / ``.text``) — confirm against caller.
    Values are overwritten on each update and reported with the ``ammm`` suffix.
    """

    def __init__(self, name_prefix=None):
        self.__name_prefix = name_prefix if name_prefix is not None else 'reactors_'
        self.__counters = Manager().dict()

    def update(self, record):
        """Store ``int(child.text)`` for every child of ``modules/stat_server/reactors``."""
        reactors = record.find('modules').find('stat_server').find('reactors')
        for reactor in iter(reactors):
            self.__counters[reactor.tag] = int(reactor.text)

    def get(self):
        """Return ``{reactor_tag: [signal_name, value]}`` from a snapshot of the shared dict."""
        snapshot = self.__counters.copy()
        result = {}
        for tag, value in snapshot.iteritems():
            signal = join_signal_name(self.__name_prefix, tag)
            result[tag] = [with_sigopt_suffix(signal, 'ammm'), value]
        return result


class StatServerJsonReactorsTrace(object):
    """Track the last reported integer value of each stat_server reactor (JSON input).

    Reads ``record['stat']['modules']['stat_server']['reactors']`` (a mapping of
    reactor name to value). Values are overwritten on each update and reported
    with the ``ammm`` suffix.
    """

    def __init__(self, name_prefix=None):
        self.__name_prefix = name_prefix if name_prefix is not None else 'reactors_'
        self.__counters = Manager().dict()

    def update(self, record):
        """Store ``int(value)`` for every reactor in the record."""
        reactors = record['stat']['modules']['stat_server']['reactors']
        for reactor_name, trace in reactors.iteritems():
            self.__counters[reactor_name] = int(trace)

    def get(self):
        """Return ``{reactor_name: [signal_name, value]}`` from a snapshot of the shared dict."""
        snapshot = self.__counters.copy()
        result = {}
        for reactor_name, value in snapshot.iteritems():
            signal = join_signal_name(self.__name_prefix, reactor_name)
            result[reactor_name] = [with_sigopt_suffix(signal, 'ammm'), value]
        return result


class YplatformLogHttpRequest(object):
    """Join ``start``/``fin`` event pairs of one HTTP request and forward the merge.

    Records are paired by ``(record['conn'], record['req'])``. On ``fin`` the
    stored ``start`` record is removed. If ``uri_filter(start.get('uri'))`` is
    truthy, a dict holding the ``start`` fields overridden by the ``fin``
    fields is passed to ``impl.update``.

    NOTE(review): a ``start`` that never receives a ``fin`` stays in the
    in-process ``__starts`` dict indefinitely.
    """

    def __init__(self, impl, uri_filter):
        self.__uri_filter = uri_filter if uri_filter is not None else (lambda _: True)
        self.__impl = impl
        self.__starts = {}

    def update(self, record):
        """Remember ``start`` events; on ``fin`` merge with its ``start`` and forward."""
        # The key is computed for every record, before the event type is checked.
        key = self.__get_key(record)
        event = record.get('event')
        if event == 'start':
            self.__starts[key] = record
        elif event == 'fin':
            start = self.__starts.pop(key, None)
            if start is not None and self.__uri_filter(start.get('uri')):
                merged = dict(start)
                merged.update(record)
                self.__impl.update(merged)

    def get(self):
        """Delegate to ``impl.get()``."""
        return self.__impl.get()

    @staticmethod
    def __get_key(record):
        return (record['conn'], record['req'])


class YplatformLogHttpRequestCountByStatus(object):
    """Per-status-code counts of joined yplatform HTTP request records.

    The ``start``/``fin`` pairs are merged by ``YplatformLogHttpRequest`` and
    counted by ``CountByStatus`` using the ``uri`` field.
    """

    def __init__(self, name_prefix, uri_filter=None):
        counter = CountByStatus(endpoint_field='uri', name_prefix=name_prefix)
        self.__impl = YplatformLogHttpRequest(impl=counter, uri_filter=uri_filter)

    def update(self, record):
        """Forward one raw log record to the joiner."""
        self.__impl.update(record)

    def get(self):
        """Return the per-status signals."""
        return self.__impl.get()


class YplatformLogHttpRequestHist(object):
    """Histogram of ``float(merged[self.key])`` over joined yplatform HTTP requests.

    Subclasses set ``key``. It is looked up each time a record is measured,
    not when the object is constructed.
    """

    key = None

    def __init__(self, name, buckets, uri_filter=None):
        def extract(merged):
            return float(merged[self.key])

        hist = Hist(name=name, buckets=buckets, get_value=extract)
        self.__impl = YplatformLogHttpRequest(uri_filter=uri_filter, impl=hist)

    def update(self, record):
        """Forward one raw log record to the joiner."""
        self.__impl.update(record)

    def get(self):
        """Return the histogram signal."""
        return self.__impl.get()


class YplatformLogHttpRequestTotalTimeHist(YplatformLogHttpRequestHist):
    """Histogram of the ``total_time`` field of joined yplatform HTTP requests."""

    key = 'total_time'


class YplatformLogHttpRequestBytesInHist(YplatformLogHttpRequestHist):
    """Histogram of the ``bytes_in`` field of joined yplatform HTTP requests."""

    key = 'bytes_in'


class YplatformLogHttpRequestBytesOutHist(YplatformLogHttpRequestHist):
    """Histogram of the ``bytes_out`` field of joined yplatform HTTP requests."""

    key = 'bytes_out'


def make_hist_from_counter(name, counter):
    """Build an ``hgram`` signal: ``[<name>_hgram, [[bound, count], ...]]``.

    :param counter: iterable of ``(bucket_bound, count)`` pairs.
    :return: the pairs converted to two-element lists and sorted ascending.
    """
    signal = with_sigopt_suffix(name, 'hgram')
    points = [[bound, hits] for bound, hits in counter]
    points.sort()
    return [signal, points]


class CoreDumpCountByNameAndSignal(object):
    """Count core dumps keyed by ``(name, signal)`` taken from ``record.name``.

    The last two dot-separated components of ``record.name`` form the key.
    Presumably file names look like ``...<binary>.<signal>`` — confirm against
    the producer. A name with no dot raises ``ValueError`` on unpacking.
    """

    def __init__(self, name_prefix=None):
        self.__name_prefix = name_prefix if name_prefix is not None else 'cores_'
        self.__counters = Manager().dict()

    def update(self, record):
        """Increment the counter for this record's ``(name, signal)`` pair."""
        name, signal = record.name.split('.')[-2:]
        key = (name, signal)
        # NOTE(review): the membership test and the write are separate proxy
        # calls, so concurrent updates from several processes may race.
        if key in self.__counters:
            self.__counters[key] += 1
        else:
            self.__counters[key] = 1

    def get(self):
        """Return ``{(name, signal): [signal_name, count]}`` from a snapshot."""
        result = {}
        for key, hits in self.__counters.copy().iteritems():
            signal_name = join_signal_name(self.__name_prefix, key[0], key[1])
            result[key] = [with_sigopt_suffix(signal_name, 'summ'), hits]
        return result


def with_sigopt_suffix(value, suffix):
    """Append the aggregation suffix ``suffix`` to the signal name ``value``.

    A ``'_'`` separator is inserted unless ``value`` already ends with one or
    ``suffix`` already starts with one::

        ('rps', 'summ')       -> 'rps_summ'
        ('reactors_', 'ammm') -> 'reactors_ammm'

    ``endswith``/``startswith`` are used instead of indexing so that an empty
    ``value`` or ``suffix`` no longer raises ``IndexError``.
    """
    if value.endswith('_') or suffix.startswith('_'):
        return value + suffix
    return '%s_%s' % (value, suffix)


class BucketsCounters(object):
    """Ring buffer of per-time-window sums kept in shared memory.

    Window ``k >= 1`` covers ``[start + (k - 1) * bucket_size,
    start + k * bucket_size)``. Timestamps earlier than the current window
    (including those before ``start``) are added to the current window. Only
    the newest ``buckets_count`` windows are retained.
    """

    def __init__(self, start, bucket_size, buckets_count):
        assert bucket_size > 0
        assert buckets_count > 0
        self.__starts = start
        self.__bucket_size = bucket_size
        # Index of the newest window; shared so every process sees the same ring position.
        self.__offset = Value(ctypes.c_size_t, 0)
        self.__buckets = Array(ctypes.c_size_t, buckets_count)

    def add(self, current_time, value):
        """Add ``value`` to the window that contains ``current_time``.

        A timestamp at or beyond the end of the current window moves the ring
        straight to the window containing it. Windows jumped over received no
        records, so their slots are reset to zero instead of keeping sums from
        a previous lap. Previously the ring advanced one window per record,
        which after a gap put each following record in its own window.
        """
        with self.__offset.get_lock():
            offset = self.__offset.value
            with self.__buckets.get_lock():
                slots = self.__buckets.get_obj()
                count = len(slots)
                if current_time >= self.__starts + offset * self.__bucket_size:
                    target = int((current_time - self.__starts) // self.__bucket_size) + 1
                    # Guard against float rounding ever landing on the current window.
                    target = max(target, offset + 1)
                    # Only the last (count - 1) skipped windows still occupy ring slots.
                    for skipped in range(max(offset + 1, target - count + 1), target):
                        slots[skipped % count] = 0
                    slots[target % count] = value
                    self.__offset.value = target
                else:
                    slots[offset % count] += value

    def get(self):
        """Return window sums newest first (current window, then older ones).

        At most ``buckets_count`` values, fewer while fewer windows have been
        opened. Both locks are held so the offset and the slots are read
        consistently (same lock order as :meth:`add`).
        """
        with self.__offset.get_lock():
            offset = self.__offset.value
            with self.__buckets.get_lock():
                slots = self.__buckets.get_obj()
                count = len(slots)
                return [slots[(offset - i) % count] for i in range(min(count, offset + 1))]


def make_buckets(left, mid, right, timeout):
    """Build 50 ascending histogram bucket bounds around reference timings.

    The series is piecewise uniform and made of floats (not necessarily
    integers):

    * 5 bounds evenly spaced in ``[0, left)``,
    * 20 bounds in ``[left, mid)``,
    * 15 bounds in ``[mid, right)``,
    * 5 bounds in ``[right, timeout)``,
    * 5 bounds ``timeout, 1.2*timeout, 1.4*timeout, 1.6*timeout, 1.8*timeout``.

    Requires ``0 < left < mid < right < timeout``; each segment checks its own
    ordering with ``assert``.

    :return: ``[0, ..., left, ..., mid, ..., right, ..., timeout, ..., 1.8*timeout]``
    """
    def stepped_range(_left, _right, count):
        # ``count`` evenly spaced points starting at _left; _right itself is excluded.
        assert _right > _left
        step = (_right - _left) / float(count)
        return [_left + step * v for v in range(count)]

    return (
        stepped_range(0, left, 5)
        + stepped_range(left, mid, 20)
        + stepped_range(mid, right, 15)
        + stepped_range(right, timeout, 5)
        + [(1 + 0.2 * i) * timeout for i in range(5)]
    )


class TskvMonitor(object):
    """A named predicate over TSKV records, consumed by ``TskvLogCount``."""

    def __init__(self, name, condition):
        self.__name, self.__condition = name, condition

    def check_record(self, record):
        """Return ``condition(record)`` unchanged (a truthy result means a match)."""
        predicate = self.__condition
        return predicate(record)

    @property
    def name(self):
        """Counter key for this monitor (read-only)."""
        return self.__name


class TskvLogCount(object):
    """Count TSKV records matched by each ``TskvMonitor``.

    Keeps one counter per monitor name in a ``multiprocessing.Manager`` dict and
    reports them as ``<name_prefix>_<monitor>_summ`` signals.
    """

    def __init__(self, name_prefix=None, monitors=None):
        self.__name_prefix = 'errors' if name_prefix is None else name_prefix
        self._monitors = self._checked_monitors(monitors)
        self._counters = self._make_counters(self._monitors)

    def update(self, record):
        """Increment the counter of every monitor whose condition matches ``record``."""
        for monitor in self._monitors:
            if monitor.check_record(record):
                # NOTE(review): ``+=`` on a Manager dict proxy is a separate read
                # and write; concurrent updates from several processes may lose
                # increments.
                self._counters[monitor.name] += 1

    def get(self):
        """Return ``{monitor_name: [signal_name, count]}`` from a snapshot."""
        return {k: [with_sigopt_suffix(join_signal_name(self.__name_prefix, k), 'summ'), v]
                for k, v in self._counters.copy().items()}

    @staticmethod
    def _checked_monitors(monitors):
        """Validate ``monitors`` and return them as a new list.

        The input is materialized with ``list()`` before validation. Iterating
        a one-shot iterable (e.g. a generator) to type-check it used to exhaust
        it, which left no monitors and no counters.

        :raises TypeError: if any element is not a ``TskvMonitor``.
        """
        if not monitors:
            return []
        checked = list(monitors)
        for m in checked:
            if not isinstance(m, TskvMonitor):
                # ``(m,)`` so tuple-valued monitors format correctly.
                raise TypeError('The monitor %s is not TskvMonitor type' % (m,))
        return checked

    def _make_counters(self, monitors):
        """Create the shared dict with a zero counter per monitor name."""
        counters = {m.name: 0 for m in monitors}
        return Manager().dict(counters)


class TskvLogErrorsCount(TskvLogCount):
    """Count ``level == 'error'`` TSKV records per matching monitor.

    Error records that match no monitor are counted under ``DEFAULT_ERR``.
    NOTE(review): a monitor literally named ``'unknown'`` would share that counter.
    """

    DEFAULT_ERR = 'unknown'

    def __init__(self, name_prefix=None, monitors=None):
        super(TskvLogErrorsCount, self).__init__(name_prefix, monitors)

    def update(self, record):
        """Count an error record against every matching monitor, or ``DEFAULT_ERR``."""
        if record.get('level') != 'error':
            return
        # Track matches with a flag rather than the last monitor name. A monitor
        # with a falsy name (e.g. '') used to bump DEFAULT_ERR as well.
        matched = False
        for monitor in self._monitors:
            if monitor.check_record(record):
                matched = True
                self._counters[monitor.name] += 1
        if not matched:
            self._counters[self.DEFAULT_ERR] += 1

    def _make_counters(self, monitors):
        """Create the shared dict with a zero counter per monitor plus ``DEFAULT_ERR``."""
        counters = {m.name: 0 for m in monitors}
        counters[self.DEFAULT_ERR] = 0
        return Manager().dict(counters)


class AccessLogMonitor(object):
    """Per-time-window 5xx/4xx error-ratio monitor over access-log records.

    Records are placed into ``bucket_size``-second windows by their
    ``time_local`` field (the format looks like nginx ``$time_local`` — confirm).
    ``get`` reports, for 5xx and 4xx separately, how many retained windows have
    ``errors / total >= threshold``.
    """

    def __init__(self, bucket_size, buckets_count, threshold):
        self.threshold = threshold
        start = time.time()
        # All three counters share ``start`` so their windows stay aligned.
        self.__all = BucketsCounters(
            start=start,
            bucket_size=bucket_size,
            buckets_count=buckets_count,
        )
        self.__5xx = BucketsCounters(
            start=start,
            bucket_size=bucket_size,
            buckets_count=buckets_count,
        )
        self.__4xx = BucketsCounters(
            start=start,
            bucket_size=bucket_size,
            buckets_count=buckets_count,
        )

    @staticmethod
    def __parse_local_time(time_local):
        """Parse e.g. ``'10/Oct/2000:13:55:36 +0300'`` into a ``time.struct_time``.

        The trailing zone is not applied: the wall-clock fields are returned
        as-is, and ``update`` passes them to ``time.mktime``, which interprets
        them in this host's local time zone.

        :raises ValueError: if the timestamp, or its zone, cannot be parsed.
        """
        try:
            return time.strptime(time_local, '%d/%b/%Y:%H:%M:%S %Z')
        except ValueError:
            # Python 2's time.strptime has no %z. The old fallback accepted only
            # the literal '+0300'. Accept any numeric [+-]HHMM offset and parse
            # the wall-clock part, which gives the same result for '+0300'.
            clock, _, zone = time_local.rpartition(' ')
            if len(zone) != 5 or zone[0] not in '+-' or not zone[1:].isdigit():
                raise
            return time.strptime(clock, '%d/%b/%Y:%H:%M:%S')

    def update(self, record):
        """Add one record to the total counter and to its 5xx/4xx counter."""
        record_time = time.mktime(self.__parse_local_time(record['time_local']))
        self.__all.add(current_time=record_time, value=1)
        self.__5xx.add(current_time=record_time, value=record['status'].startswith('5'))
        self.__4xx.add(current_time=record_time, value=record['status'].startswith('4'))

    def get(self):
        """Return ``{500: [name, n_bad_windows], 400: [name, n_bad_windows]}``."""
        total_buckets = self.__all.get()
        errors_5xx_buckets = self.__5xx.get()
        errors_4xx_buckets = self.__4xx.get()
        return {
            500: [
                with_sigopt_suffix('monitor_5xx', 'ammm'),
                get_buckets_with_errors(
                    errors_buckets=errors_5xx_buckets,
                    total_buckets=total_buckets,
                    threshold=self.threshold,
                ),
            ],
            400: [
                with_sigopt_suffix('monitor_4xx', 'ammm'),
                get_buckets_with_errors(
                    errors_buckets=errors_4xx_buckets,
                    total_buckets=total_buckets,
                    threshold=self.threshold,
                ),
            ]
        }


class AccessTskvMonitor(object):
    """Per-time-window 5xx/4xx error-ratio monitor over TSKV access records.

    Timestamps come from ``record['timestamp']`` (parsed with dateutil; the
    tz info is dropped by ``timetuple`` + ``time.mktime``). Statuses come from
    ``record['status_code']``.
    """

    def __init__(self, bucket_size, buckets_count, threshold):
        self.threshold = threshold
        start = time.time()
        # Identical window layout for all three counters keeps them aligned.
        window = dict(start=start, bucket_size=bucket_size, buckets_count=buckets_count)
        self.__all = BucketsCounters(**window)
        self.__5xx = BucketsCounters(**window)
        self.__4xx = BucketsCounters(**window)

    def update(self, record):
        """Add one record to the total counter and to its 5xx/4xx counter."""
        parsed = parser.parse(record['timestamp'])
        record_time = time.mktime(parsed.timetuple())
        self.__all.add(current_time=record_time, value=1)
        status = record['status_code']
        self.__5xx.add(current_time=record_time, value=status.startswith('5'))
        self.__4xx.add(current_time=record_time, value=status.startswith('4'))

    def get(self):
        """Return ``{500: [name, n_bad_windows], 400: [name, n_bad_windows]}``."""
        total = self.__all.get()
        errors_5xx = self.__5xx.get()
        errors_4xx = self.__4xx.get()

        def signal(name, errors):
            bad_windows = get_buckets_with_errors(
                errors_buckets=errors,
                total_buckets=total,
                threshold=self.threshold,
            )
            return [with_sigopt_suffix(name, 'ammm'), bad_windows]

        return {
            500: signal('monitor_5xx', errors_5xx),
            400: signal('monitor_4xx', errors_4xx),
        }


def get_buckets_with_errors(errors_buckets, total_buckets, threshold):
    """Count windows whose error ratio ``errors / total`` is at least ``threshold``.

    The two sequences are paired positionally and truncated to the shorter one.
    A window with ``total == 0`` has ratio 0, so it counts only when
    ``threshold <= 0``.
    """
    hits = 0
    for errors, total in zip(errors_buckets, total_buckets):
        ratio = errors / float(total) if total else 0
        if ratio >= threshold:
            hits += 1
    return hits


def join_signal_name(*nodes):
    """Join the non-empty parts of a signal name with ``'_'``.

    ``None`` and empty strings are skipped, so optional parts (e.g. an absent
    endpoint) leave no stray separator. Every other value is rendered with
    ``str``. Numeric ``0`` is kept: a plain truthiness filter dropped it, so
    e.g. a status code of 0 merged into the bare prefix, or produced ``''``
    when the prefix was empty.

    >>> join_signal_name('http', None, 0)
    'http_0'
    """
    parts = (str(node) for node in nodes if node is not None)
    return '_'.join(part for part in parts if part)


class AccessTskvCount(object):
    """Count TSKV access records whose ``request`` contains ``/<endpoint>`` (``/`` by default)."""

    def __init__(self, name=None, endpoint=None):
        self.__endpoint = '/%s' % endpoint if endpoint is not None else '/'
        base_name = name if name is not None else 'xxx'
        self.__name = join_signal_name(base_name, endpoint)
        self.__counter = Value(ctypes.c_size_t, 0)

    def update(self, record):
        """Increment the counter if the record has a matching ``request`` field."""
        if 'request' not in record or self.__endpoint not in record['request']:
            return
        with self.__counter.get_lock():
            self.__counter.value += 1

    def get(self):
        """Return ``[<name>_summ, count]``."""
        with self.__counter.get_lock():
            signal = with_sigopt_suffix(self.__name, 'summ')
            return [signal, self.__counter.value]


class AccessTskvCountByFirstStatusDigit(object):
    """Count matching TSKV access records grouped by the first digit of ``status_code``."""

    def __init__(self, name_prefix=None, endpoint=None):
        self.__endpoint = '/%s' % endpoint if endpoint is not None else '/'
        prefix = name_prefix if name_prefix is not None else ''
        self.__name_prefix = join_signal_name(prefix, endpoint)
        # One shared-memory slot per possible leading digit 0..9.
        self.__counters = Array(ctypes.c_size_t, 10)

    def update(self, record):
        """Bump the slot for ``status_code[0]`` when request and status are present and match."""
        if ('request' not in record
                or 'status_code' not in record
                or self.__endpoint not in record['request']):
            return
        with self.__counters.get_lock():
            slots = self.__counters.get_obj()
            slots[int(record['status_code'][0])] += 1

    def get(self):
        """Return ``{digit: [signal_name, count]}`` for every non-zero digit."""
        result = {}
        with self.__counters.get_lock():
            for digit, hits in enumerate(self.__counters.get_obj()):
                if hits:
                    signal = join_signal_name(self.__name_prefix, str(digit) + 'xx')
                    result[digit] = [with_sigopt_suffix(signal, 'summ'), hits]
        return result


class AccessTskvCountByStatus(object):
    """Count matching TSKV access records per exact ``status_code`` in ``[0, 600)``.

    Records without ``request``/``status_code``, or with a status outside the
    range, are ignored.
    """

    def __init__(self, name_prefix=None, endpoint=None):
        self.__endpoint = '/%s' % endpoint if endpoint is not None else '/'
        prefix = name_prefix if name_prefix is not None else ''
        self.__name_prefix = join_signal_name(prefix, endpoint)
        self.__size_array = 600
        self.__counters = Array(ctypes.c_size_t, self.__size_array)

    def update(self, record):
        """Increment the slot for ``int(record['status_code'])`` if the record matches."""
        if ('request' not in record
                or 'status_code' not in record
                or self.__endpoint not in record['request']):
            return
        status = int(record['status_code'])
        if not 0 <= status < self.__size_array:
            return
        with self.__counters.get_lock():
            self.__counters[status] += 1

    def get(self):
        """Return ``{status: [signal_name, count]}`` for every non-zero status."""
        result = {}
        with self.__counters.get_lock():
            for status, hits in enumerate(self.__counters):
                if hits:
                    signal = join_signal_name(self.__name_prefix, status)
                    result[status] = [with_sigopt_suffix(signal, 'summ'), hits]
        return result


class AccessTskvRequestTimeHist(object):
    """Histogram of ``float(record['profiler_total'])`` for TSKV records hitting ``/<endpoint>``."""

    def __init__(self, buckets, name=None, endpoint=None):
        self.__endpoint = '/%s' % endpoint if endpoint is not None else '/'
        base_name = name if name is not None else 'access_log_request'

        def profiler_total(record):
            return float(record['profiler_total'])

        self.__impl = Hist(
            buckets=buckets,
            name=join_signal_name(base_name, endpoint),
            get_value=profiler_total,
        )

    def update(self, record):
        """Feed the record to the histogram if it has the needed fields and matches."""
        if 'request' not in record or 'profiler_total' not in record:
            return
        if self.__endpoint in record['request']:
            self.__impl.update(record)

    def get(self):
        """Return the underlying histogram signal."""
        return self.__impl.get()


class HttpClientTskvRequestCountByStatus(object):
    """Per-status counts of HTTP-client TSKV records whose ``uri`` passes ``uri_filter``."""

    def __init__(self, name_prefix, uri_filter=None):
        self.__impl = CountByStatus(endpoint_field='uri', name_prefix=name_prefix)
        self.__uri_filter = uri_filter if uri_filter is not None else (lambda _: True)

    def update(self, record):
        """Count the record if it has a ``uri`` accepted by the filter."""
        if 'uri' not in record:
            return
        if self.__uri_filter(record.get('uri')):
            self.__impl.update(record)

    def get(self):
        """Return the per-status signals."""
        return self.__impl.get()


class HttpClientTskvRequestHist(object):
    """Histogram of ``float(record['total_time'])`` for HTTP-client TSKV records."""

    def __init__(self, name, buckets, uri_filter=None):
        def total_time(record):
            return float(record['total_time'])

        self.__impl = Hist(name=name, buckets=buckets, get_value=total_time)
        self.__uri_filter = uri_filter if uri_filter is not None else (lambda _: True)

    def update(self, record):
        """Measure the record if it has ``uri`` and ``total_time`` and passes the filter."""
        has_fields = 'uri' in record and 'total_time' in record
        if has_fields and self.__uri_filter(record.get('uri')):
            self.__impl.update(record)

    def get(self):
        """Return the histogram signal."""
        return self.__impl.get()


class HttpClientTskvRequestBytesInHist(object):
    """Histogram of ``float(record['bytes_in'])`` for HTTP-client TSKV records."""

    def __init__(self, name, buckets, uri_filter=None):
        def bytes_in(record):
            return float(record['bytes_in'])

        self.__impl = Hist(name=name, buckets=buckets, get_value=bytes_in)
        self.__uri_filter = uri_filter if uri_filter is not None else (lambda _: True)

    def update(self, record):
        """Measure the record if it has ``uri`` and ``bytes_in`` and passes the filter."""
        has_fields = 'uri' in record and 'bytes_in' in record
        if has_fields and self.__uri_filter(record.get('uri')):
            self.__impl.update(record)

    def get(self):
        """Return the histogram signal."""
        return self.__impl.get()
