# coding: utf8
from __future__ import absolute_import

import time
import logging
import itertools
import threading
import collections
import datetime as dt

import requests
import concurrent.futures

import common.utils
import common.patterns


# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)


class LogEntry(collections.namedtuple('LogEntry', 'rid ts login raddr method duration source hp')):
    """Immutable record of one request as reported by a server's status API."""
    __slots__ = ()

    @staticmethod
    def from_dict(req, time_correction=None):
        """Build a LogEntry from one raw request dict of the status payload.

        :param req: mapping with keys ``id``, ``received``, ``login``,
            ``raddr``, ``remote_method``, ``duration``, ``source`` and
            ``high_priority``.  ``received`` must match
            ``%Y-%m-%dT%H:%M:%S.%fZ``.
        :param time_correction: seconds added to the parsed timestamp;
            ``None`` (or any falsy value) means no shift.
        :returns: a LogEntry whose ``ts`` is a naive ``datetime``.
        :raises KeyError: if a required key is missing.
        :raises ValueError: if ``received`` does not match the format.
        """
        # Positional order mirrors the field list: rid, ts, login, raddr,
        # method, duration, source, hp.
        return LogEntry(
            req['id'],
            dt.datetime.strptime(req['received'], "%Y-%m-%dT%H:%M:%S.%fZ")
            + dt.timedelta(seconds=(time_correction or 0)),
            req['login'],
            req['raddr'],
            req['remote_method'],
            req['duration'],
            req['source'],
            req['high_priority'],
        )


# Immutable result of RefuellREST.snapshot(); ``data`` is a list of entries.
# The typename must equal the bound name so repr() is accurate and pickle can
# find the class by name (the previous 'Shapshot' typo broke both).
Snapshot = collections.namedtuple('Snapshot', ['data'])


class Window(object):
    """Time-bounded, de-duplicated store of samples keyed by ``rid``.

    A sample is accepted only if its ``rid`` is new and its ``ts`` is newer
    than ``utcnow() - period``.  Every time more than ``max_uncheck`` samples
    have been accepted since the last eviction, actualize() runs.

    The instance is a context manager around an internal ``threading.Lock``.
    None of the methods below take that lock themselves; callers that share
    a Window across threads must hold it explicitly.
    """

    def __init__(self, period, max_uncheck=1000):
        self._period = None
        self.period = period
        self.samples = {}
        self._max_uncheck = max_uncheck
        self._unchecked = 0
        self._lock = threading.Lock()

    @property
    def period(self):
        """Window length as a ``datetime.timedelta``."""
        return self._period

    @period.setter
    def period(self, period):
        # Accept either a ready-made timedelta or a number of seconds.
        if not isinstance(period, dt.timedelta):
            period = dt.timedelta(seconds=period)
        self._period = period

    def _timemark(self):
        """Return ``(threshold, now)`` with ``threshold = now - period`` (naive UTC)."""
        current = dt.datetime.utcnow()
        return current - self.period, current

    def __lshift__(self, value):
        """Insert ``value`` unless its rid is already stored or its ts is stale."""
        cutoff = self._timemark()[0]

        # Duplicate rids are checked first, then staleness.
        if value.rid in self.samples or value.ts <= cutoff:
            return

        self.samples[value.rid] = value
        self._unchecked += 1
        if self._unchecked > self._max_uncheck:
            self.actualize()

    def actualize(self):
        """Replace ``samples`` with only those whose ts is newer than the threshold."""
        cutoff = self._timemark()[0]
        kept = {}
        for rid, sample in self.samples.items():
            if sample.ts > cutoff:
                kept[rid] = sample
        self.samples = kept
        self._unchecked = 0

    def __enter__(self):
        self._lock.acquire()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Returns None, so exceptions raised inside the with-block propagate.
        self._lock.release()


class ServerStat(common.patterns.Abstract):
    # Mutable per-host polling state used by RefuellREST.
    # NOTE(review): common.patterns.Abstract is not visible here; it presumably
    # pairs each name in __slots__ with the value at the same position in
    # __defs__ and accepts positional constructor arguments (RefuellREST calls
    # ServerStat(host)) -- confirm before reordering either tuple.
    #   host        -- host string interpolated into the status URL
    #   inprogress  -- list of LogEntry for requests the host reports as in flight
    #   window      -- not read or written by the code in this module
    #   window_size -- seconds, copied from the host's last status payload
    #   last_update -- naive-UTC datetime of the last successfully processed fetch
    #   last_orphan -- the previous last_update when it was older than window_size
    #   future      -- Future of the in-flight fetch, or None when idle
    # NOTE(review): the [] default for inprogress may be one list shared by all
    # instances unless Abstract copies defaults; RefuellREST only reassigns it,
    # so never mutate it in place.
    __slots__ = ('host', 'inprogress', 'window', 'window_size', 'last_update', 'last_orphan', 'future')
    __defs__ = (None, [], None, None, None, None, None)


class REQUESTS(common.utils.Enum):
    # Selects which view RefuellREST.snapshot() returns.
    # NOTE(review): lookup by key (REQUESTS[reqs]) is provided by the opaque
    # common.utils.Enum base -- presumably by member name; confirm.
    INPROGRESS = 1  # concatenation of every host's latest in-flight requests
    LAST = 2        # recent requests held in the shared time-bounded Window


class RefuellREST(object):
    """Background poller that gathers request logs from several servers.

    Each host's ``/api/v1.0/service/status/server`` endpoint is fetched on a
    thread pool.  The payload's ``inprogress`` requests replace that host's
    ``ServerStat.inprogress``.  Its ``last`` requests are merged into one
    shared, time-bounded ``Window``.  snapshot() exposes one of the two views,
    chosen by ``reqs``.
    """

    def __init__(self, window_size, hosts, poll_period=None, fetch_timeout=None, reqs=None):
        """
        :param window_size: length of the shared ``last`` window, in seconds or
            as a ``datetime.timedelta``.
        :param hosts: iterable of host strings interpolated into
            ``http://{}/api/v1.0/service/status/server``.
        :param poll_period: minimum seconds between the end of one fetch of a
            host and the start of the next (default 2).
        :param fetch_timeout: HTTP timeout in seconds per fetch (default 5).
        :param reqs: key looked up as ``REQUESTS[reqs]`` that selects what
            snapshot() returns (default ``REQUESTS.LAST``).
        """
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
        self._server_stats = [ServerStat(h) for h in hosts]
        self._thread = None
        # Marked running by start() and resolved by stop().  The poll loop
        # waits on it too, so stop() wakes that wait immediately.
        self.running = concurrent.futures.Future()
        self._window = Window(window_size)
        self.poll_period = dt.timedelta(seconds=poll_period or 2)
        self.fetch_timeout = fetch_timeout or 5
        self.requests = REQUESTS[reqs] if reqs else REQUESTS.LAST

    @property
    def window_period(self):
        """Length of the shared ``last`` window as a ``datetime.timedelta``."""
        return self._window.period

    @window_period.setter
    def window_period(self, value):
        # Window's setter accepts seconds or a timedelta.
        self._window.period = value

    def snapshot(self):
        """Return a Snapshot of the view selected by ``self.requests``.

        ``REQUESTS.LAST``: stale samples are evicted first, then the remaining
        window contents are returned.  ``REQUESTS.INPROGRESS``: every host's
        latest ``inprogress`` list, concatenated.  Any other value returns None.
        """
        if self.requests == REQUESTS.LAST:
            # The poll thread inserts into the window concurrently.  Holding
            # the window lock keeps eviction and the copy from iterating a
            # dict that is being mutated.
            with self._window:
                self._window.actualize()
                data = list(self._window.samples.values())
            return Snapshot(data=data)
        elif self.requests == REQUESTS.INPROGRESS:
            return Snapshot(
                data=list(itertools.chain.from_iterable(st.inprogress for st in self._server_stats))
            )

    def start(self):
        """Start the poll loop on a new thread.

        :raises ValueError: if a poll thread is already attached.

        NOTE: ``self.running`` is never reset, so calling start() again after
        stop() fails inside ``set_running_or_notify_cancel()``.
        """
        if self._thread:
            raise ValueError('Already running')

        logger.info('Start rest refueller')

        self.running.set_running_or_notify_cancel()
        self._thread = threading.Thread(target=self.run)
        self._thread.start()

    def stop(self):
        """Signal the poll loop to exit and join its thread; no-op if not started."""
        if self._thread:
            self.running.set_result(True)
            self._thread.join()
            self._thread = None

    @staticmethod
    def _load_url(host, timeout):
        """GET ``host``'s status endpoint and return the decoded JSON body.

        Exceptions from ``requests`` (connection, timeout, non-2xx status) and
        JSON decoding propagate to the caller through the Future.
        """
        url = 'http://{}/api/v1.0/service/status/server'.format(host)
        r = requests.get(url, params={'inprogress': True, 'last': True}, timeout=timeout)
        r.raise_for_status()
        return r.json()

    def _apply_result(self, stat, res):
        """Fold one host's decoded status payload into ``stat`` and the window.

        Reads ``res['requests']['inprogress']``, ``['last']`` and
        ``['window_size']`` (seconds).  If the previous successful update is
        older than ``window_size``, it is recorded as ``stat.last_orphan``
        (presumably because requests between the two updates may have been
        missed).
        """
        # TODO: use /service/time/current and time_correction = time.time() - res['now']
        time_correction = None
        reqs = res['requests']

        stat.inprogress = [LogEntry.from_dict(req, time_correction)
                           for req in reqs['inprogress']]

        # Parse everything before touching the shared window, so a malformed
        # entry never leaves a partial merge behind.
        last = [LogEntry.from_dict(req, time_correction) for req in reqs['last']]
        with self._window:
            for entry in last:
                self._window << entry

        stat.window_size = reqs['window_size']

        threshold = dt.datetime.utcnow() - dt.timedelta(seconds=stat.window_size)
        if stat.last_update is not None and stat.last_update < threshold:
            stat.last_orphan = stat.last_update

        stat.last_update = dt.datetime.utcnow()

    def run(self):
        """Poll loop; runs on the thread created by start() until stop().

        A host is fetched again once more than ``poll_period`` has passed since
        its previous fetch finished, whether that fetch succeeded or failed.
        The loop waits on all outstanding fetches together with
        ``self.running``, so stop() ends the wait immediately.
        """
        stats = self._server_stats
        # Time each host's latest fetch finished, on success or failure.
        # Scheduling on this instead of on last_update (which only moves on
        # success) stops a failing host from being re-requested in a tight
        # loop.
        last_done = [s.last_update for s in stats]

        try:
            while self.running.running():
                # Including self.running lets stop() interrupt as_completed().
                futures = {self.running: None}
                timeout = self.poll_period
                now = dt.datetime.utcnow()

                for i, s in enumerate(stats):
                    due = last_done[i] is None or now - last_done[i] > self.poll_period
                    if not s.future and due:
                        s.future = self._executor.submit(self._load_url, s.host, self.fetch_timeout)

                    if s.future:
                        futures[s.future] = i
                    else:
                        # Wake no later than when this idle host becomes due.
                        timeout = min(timeout, self.poll_period - (now - last_done[i]))

                try:
                    for future in concurrent.futures.as_completed(
                            futures, timeout=max(timeout.total_seconds(), 0)):
                        if future is self.running:
                            break

                        i = futures[future]
                        s = stats[i]
                        s.future = None
                        last_done[i] = dt.datetime.utcnow()

                        error = future.exception()
                        if error is not None:
                            logger.error('%r generated an exception: %s', s.host, error)
                            continue

                        try:
                            self._apply_result(s, future.result())
                        except Exception:
                            # One malformed payload must not stop polling of
                            # every host; log it and keep going.
                            logger.exception('%r returned an unusable status payload', s.host)
                except concurrent.futures.TimeoutError:
                    # Nothing finished in time; loop to submit newly due fetches.
                    pass
        except Exception:
            logger.exception('RefuellREST exploded')
        finally:
            logger.info('RefuellREST exit')
