from __future__ import print_function, absolute_import, division

import os
import time
import random
import hashlib
import itertools
import collections
import traceback as tb
import bz2
import zlib
import msgpack

import gevent
import gevent.queue
import gevent.socket

from kernel import util
from library.config import detect_hostname
from library.sky.hostresolver import Resolver

from .rpc.client import RPCClientGevent as RPCClient
from .utils.greendeblock import Deblock
from .daemon.config import AppConfig


class ReportSender(object):
    """Delivers reports to one of the configured sync hosts over RPC.

    Host addresses are taken from ``cfg.resolved_hosts`` (or resolved from
    ``cfg.hosts``), shuffled and cached for ``cfg.resolve_cache_time``
    seconds.  The sender sticks to one "active" host until a send leaves
    reports unprocessed or ``cfg.rotate_hosts_time`` seconds pass, and then
    rotates to the next host in the cached deque.
    """

    def __init__(self, name, log, cfg, rpc_cfg, db, deblock):
        """
        :param name: sender name; used in the ``sender.<name>`` log channel.
        :param log: parent logger.
        :param cfg: sender configuration (``hosts``, ``resolved_hosts``,
            ``hosts_port``, ``timeout``, ``resolve_cache_time``,
            ``rotate_hosts_time`` and an optional ``plugins`` entry).
        :param rpc_cfg: configuration handed to the RPC client.
        :param db: report database; processed reports are marked in it.
        :param deblock: helper whose ``apply`` is used to call potentially
            blocking functions (hostname detection, host resolution).
        """
        self.name = name
        self.log = log.getChild('sender.{}'.format(name))
        self._db = db
        self._cfg = cfg

        self._rpc = None
        self._rpc_cfg = rpc_cfg
        # [time the host list was cached, deque of hosts (or None)]
        self._hostsCache = [0, None]
        # [host, time chosen, queries count, clock skew (None until synced)]
        self._activeHost = None
        self._deblock = deblock
        self.log.debug('Sender "{}" initialized: hosts = {}, plugins = {}'.format(
            self.name,
            self._cfg.resolved_hosts if self._cfg.resolved_hosts else self._cfg.hosts,
            self._cfg.get('plugins', '')
        ))

    def __repr__(self):
        return '<Sender %r>' % self.name

    @property
    def max_attempts(self):
        """Maximum send attempts per report: a third of the cached host count.

        The module enables true division, so the value may be fractional.
        ``None`` until a host list has been cached at least once.
        """
        hosts = self._hostsCache[1]
        return len(hosts) / 3 if hosts is not None else None

    def _detect_host_clock_skew(self, host):
        """Ask the current RPC peer for its clock and return the skew in seconds.

        The skew is the remote timestamp minus the midpoint of the local
        request/response window (positive means the remote clock is ahead).

        :param host: host name, used for logging only.
        :raises ValueError: if the peer does not answer with a float.
        """
        log = self.log.getChild('timesync')
        log.debug('Sending time sync packet to %s', host)

        sent_at = time.time()
        result = self._rpc.call('timesync')
        received_at = time.time()

        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        if not isinstance(result, float):
            raise ValueError('Got invalid result for TIMESYNC packet: %r' % (result, ))

        clockSkew = result - (received_at + sent_at) / 2
        # Use the magnitude so a remote clock that lags behind is reported too.
        magnitude = abs(clockSkew)
        if magnitude > 60:
            func = log.warning
        elif magnitude > 3:
            func = log.info
        else:
            func = log.debug
        func('Detected clock skew %0.6f seconds with %s', clockSkew, host)

        return clockSkew

    def _get_hosts(self, now):
        """Return the cached host deque, re-resolving it once the cache expires.

        :param now: current local time, used for cache bookkeeping.
        :return: the cached deque of hosts; a freshly resolved but empty list;
            or ``None`` if nothing was ever resolved.
        """
        cache_time, hosts = self._hostsCache
        if now - cache_time > self._cfg.resolve_cache_time:
            try:
                if self._cfg.resolved_hosts:
                    # Copy, so the shuffle below does not reorder the config's own list.
                    hosts = list(self._cfg.resolved_hosts)
                else:
                    hosts = list(self._deblock.apply(Resolver().resolveHosts, self._cfg.hosts))
                self.log.debug('Resolved servers: %s', hosts)
                if hosts:
                    random.shuffle(hosts)
                    hosts = collections.deque(hosts)
                    self._hostsCache = [now, hosts]
            except Exception:
                # ``Exception`` (not ``BaseException``) so GreenletExit/Timeout still propagate.
                self.log.error('Failed to resolve servers. Use cached ones. %s', tb.format_exc())
                hosts = self._hostsCache[1]
                # Update the cache time to avoid hammering the resolver:
                # the next try happens after 1/4 of cfg.resolve_cache_time.
                self._hostsCache = [now - self._cfg.resolve_cache_time * 0.75, hosts]
        return hosts

    def send_reports(self, reports):
        """Send ``reports`` to the active host.

        :param reports: mapping of report type -> list of report dicts, each
            carrying at least an ``'id'`` key.
        :return: number of reports the remote side did not confirm as
            processed (0 means everything was delivered).
        """
        left = sum(len(batch) for batch in reports.itervalues())
        if not left:
            return left

        self.log.info(
            'Sending %d report(s): %r',
            left, dict((key, len(values)) for key, values in reports.items())
        )

        host = None
        # Local wall-clock time; used for host caching and rotation only.
        now = time.time()

        hostname = self._deblock.apply(detect_hostname)
        self.log.debug('Detected hostname: %s', hostname)

        hosts = self._get_hosts(now)
        # In case we didn't resolve anything
        if not hosts:
            self.log.warning('Got 0 sync hosts')
            return left

        def on_processing_completed(data):
            # RPC state callback; ``data`` presumably is one of the
            # (id, type, report) tuples sent below -- confirm with the server side.
            self.log.debug('Report #%d of type %r reported as processed. Updating the database.', data[0], data[1])
            self._db.query('UPDATE "report" SET "state" = 2, "sent" = ? WHERE "id" = ?', (time.time(), data[0], ))
            on_processing_completed.processed += 1

        on_processing_completed.processed = 0

        try:
            # Flatten to (id, type, report) tuples once; identical for every attempt.
            payload = [
                (item['id'], name, item)
                for name, batch in reports.iteritems()
                for item in batch
            ]

            if self._activeHost is None or self._rpc is None:
                self._activeHost = [hosts[0], now, 0, None]  # host, time chosen, queries count, clock skew
                self._rpc = RPCClient(hosts[0], self._cfg.hosts_port, self._rpc_cfg)
            host = self._activeHost[0]

            for try_ in range(2):
                # Re-read on every attempt: a TIMESYNC answer on the first
                # attempt stores a fresh skew that the retry must use.
                clock_skew = self._activeHost[3] if self._activeHost[3] is not None else 0
                self.log.debug('Request to %r (current clock skew is %d).', host, clock_skew)
                mode = '2' if self._activeHost[3] is None else '2S'
                # Timestamp expressed in the remote host's clock. Kept separate
                # from ``now`` so the rotation check below stays on local time.
                remote_now = time.time() + clock_skew

                with gevent.Timeout(self._cfg.timeout) as timeout:
                    try:
                        result = self._rpc.call(
                            'report',    # Remote method name
                            mode,        # mode, packet version
                            hostname,    # this hostname
                            remote_now,  # corrected (if needed) timestamp this reports are being sent
                            payload,
                            stateCallback=on_processing_completed
                        )
                    except gevent.Timeout as ex:
                        if ex is not timeout:
                            raise
                        self.log.warning(
                            'Failed to heartbeat to %s:%s: timeout (%d secs)',
                            host, self._cfg.hosts_port, self._cfg.timeout
                        )
                        break

                if try_ == 0 and result == 'TIMESYNC':
                    self.log.debug('Remote side asks us to sync time')
                    self._activeHost[3] = self._detect_host_clock_skew(host)
                    continue

                if result != 'OK':
                    self.log.warning('Server %r respond %r', host, result)

                break

        except gevent.GreenletExit:
            raise
        except Exception:
            self.log.warning('Failed to heartbeat to %s:%s: %s', host, self._cfg.hosts_port, tb.format_exc())

        left -= on_processing_completed.processed
        if left:
            self.log.warning('%d report(s) still not processed. Place them back to the queue.', left)
            # Drop the connection and move to the next host for the next send.
            self._rpc = self._activeHost = None
            hosts.rotate()
            return left

        self.log.info('All reports are processed successfully by %r', host)

        if self._activeHost[0] is not None:
            self._activeHost[2] += 1  # increase queries counter
        if now - self._activeHost[1] > self._cfg.rotate_hosts_time:
            self.log.info(
                'Made %d queries to host %s for past %d seconds, forcing rotate',
                self._activeHost[2], self._activeHost[0], now - self._activeHost[1]
            )
            # Rotate hosts deque and force host election next time
            self._rpc = self._activeHost = None
            hosts.rotate()
        return left


class DevNullReportSender(object):
    """Sender stand-in that never transmits anything.

    Every report handed to :meth:`send_reports` is written to the log and
    its database row is immediately flagged as sent (state 2).
    """

    def __init__(self, log, db):
        """
        :param log: parent logger; a ``devnullsender`` child is used.
        :param db: report database whose rows get marked as sent.
        """
        self.log = log.getChild('devnullsender')
        self._db = db

    @property
    def max_attempts(self):
        """Each report is tried exactly once."""
        return 1

    def send_reports(self, data):
        """Log every report in ``data`` and mark its row as sent.

        :param data: mapping of report type -> list of report dicts, each
            carrying an ``'id'`` key.
        :return: number of reports left unsent -- always 0.
        """
        mark_sent = 'UPDATE "report" SET "state" = 2, "sent" = ? WHERE "id" = ?'
        for _report_type, batch in data.iteritems():
            for item in batch:
                self.log.info('Report: %s', item)
                self._db.query(mark_sent, (time.time(), item['id'],))
        return 0


class Reporter(object):
    """Stores generated reports in the ``report`` table and ships them.

    Report ``state`` values used here: 0 = queued, 1 = being sent,
    2 = processed by the remote side.  A background greenlet
    (:meth:`_workerLoop`) batches queued reports and dispatches them through
    per-plugin senders (falling back to the ``'default'`` sender).
    """

    # Initialization {{{
    def __init__(self, ctx, db):
        """
        :param ctx: application context providing ``cfg`` and ``log``.
        :param db: report database handle.
        """
        self._db = db
        self.ctx = ctx
        self.cfg = ctx.cfg.reporter
        self.log = ctx.log.getChild('reporter')
        self.log.debug('Initializing')

        self._workerGrn = None
        # Items are (report type, send deadline or None) tuples, see report().
        self._queue = gevent.queue.Queue()
        self._deblock = Deblock(logger=self.log.getChild('deblock'))
        # report type (plugin name) or sender name -> sender instance
        self._senders = {}

        self._init_senders(ctx.cfg)

    # Initialization }}}

    # Management (start, stop, join) {{{
    def start(self):
        """Spawn the worker greenlet and requeue reports left in-flight."""
        assert self._workerGrn is None
        self._workerGrn = gevent.spawn(self._workerLoop)
        # Reset sending state on the reports queue and queue them.
        self._db.query('UPDATE "report" SET "state" = 0 WHERE "state" = 1')
        self._repair_db(self.log)
        return self

    def stop(self):
        """Ask the worker greenlet to exit."""
        assert self._workerGrn is not None
        self._workerGrn.kill(gevent.GreenletExit)
        return self

    def join(self):
        """Wait for the worker greenlet to finish."""
        self._workerGrn.join()
    # Management (start, stop, join) }}}

    def list(self):
        """Return summary rows for all stored reports, soonest-expiring first."""
        return self._db.query(
            'SELECT "id", "type", "sent", "expires", "attempts", "state" ' +
            'FROM "report" ORDER BY "expires" ASC'
        )

    def details(self, id_):
        """Return all columns of report ``id_`` (buffers as ``str``), or ``None``."""
        res = self._db.query(
            'SELECT ' +
            ' "id", "type", "added", "updated", "sent", "expires", "format", "data", ' +
            ' "compression", "checksum", "attempts", "state" ' +
            'FROM "report" WHERE "id" = ?',
            (id_, )
        )
        return map(lambda x: str(x) if isinstance(x, buffer) else x, res[0]) if res else None

    def plugin_report(self, plugin_type):
        """Return the decoded payload of the latest stored report of ``plugin_type``.

        Only ``msgpack``/``raw`` formats are decoded (with msgpack).  A missing
        report, an unknown format, or any decoding failure yields ``{}``.
        """
        # Find the most recently added report for the specified plugin.
        report_raw = self._db.query(
            'SELECT "compression", "format", "data" ' +
            'FROM "report" WHERE "type" = ? ORDER BY "added" DESC LIMIT 1',
            (plugin_type, ))

        # Default report is an empty dict.
        report = dict()
        if not report_raw:
            return report

        decompressors = {
            None: lambda x: x,
            'bz2': bz2.decompress,
            'zip': zlib.decompress,
        }
        try:
            compression, fmt, payload = report_raw[0]
            data = decompressors[compression](payload)
            # Currently only msgpack-encoded payloads are supported.
            if fmt in ('msgpack', 'raw'):
                report = msgpack.unpackb(data)
        except Exception as ex:
            # Best effort: a broken report is reported as empty, but
            # GreenletExit (a BaseException) is no longer swallowed.
            self.log.debug('Failed to decode %r report: %s', plugin_type, ex)
            report = dict()

        return report

    def schedulerState(self):
        """Return ``{type: last update time}`` (``updated``, else ``added``) per report type."""
        return dict(self._db.query(
            'SELECT "type", MAX(CASE WHEN "updated" IS NOT NULL THEN "updated" ELSE "added" END) ' +
            'FROM "report" GROUP BY "type";'
        ) or [])

    def report(self, data, sendDelay=None):
        """Store a freshly generated report and queue it for sending.

        ``data`` keys read here: ``name``, ``report`` (payload), ``incremental``,
        ``end``, ``valid``, ``format`` and optionally ``compression``,
        ``force`` and ``discards``.

        A non-incremental report whose checksum equals the stored (not
        in-flight) one only refreshes that row's ``updated`` time, unless
        ``force`` seconds have passed since it was added (``-inf`` = always
        resend).  ``discards`` lists report types whose queued rows are deleted.

        :param sendDelay: if given, the report should be sent no later than
            ``data['end'] + sendDelay``.
        """
        now = time.time()
        inc = data['incremental']
        checksum = hashlib.sha1(data['report']).hexdigest()
        with self._db:
            row = self._db.query(
                'SELECT "id", "checksum", "added" FROM "report" WHERE "type" = ? AND "state" <> 1',
                (data['name'], )
            )
            if not inc and row and checksum == row[0][1]:
                self.log.debug(
                    'Similar report #%d of type %r generated %s ago found.',
                    row[0][0], data['name'], util.td2str(now - row[0][2])
                )
                force = data.get('force', None)
                if not force or force + row[0][2] > now:
                    self._db.query('UPDATE "report" SET "updated" = ? WHERE "id" = ?', (now, row[0][0], ))
                    return
                passed = util.td2str(now - row[0][2])
                if force == float('-inf'):
                    self.log.info('Force report %r sending (%s passed from last send).', data['name'], passed)
                else:
                    self.log.info(
                        'Force report %r sending (passed %s out of %s).',
                        data['name'], passed, util.td2str(force)
                    )

            cmd = 'INSERT'
            fields = ['"%s"' % x for x in 'type added expires data format compression checksum state'.split()]
            values = [
                data['name'], data['end'], data['end'] + data['valid'],
                buffer(data['report']), data['format'], data.get('compression', None), checksum, 0
            ]

            if row and not inc:
                # Overwrite the previous non-incremental report of this type.
                cmd = 'REPLACE'
                fields.insert(0, '"id"')
                values.insert(0, row[0][0])

            self._db.query(
                '%s INTO "report" (%s) VALUES (%s)' % (cmd, ', '.join(fields), ', '.join('?' * len(fields))),
                tuple(values)
            )

            discards = data.get('discards', None)
            if discards:
                self.log.debug("The report %r discards following report type(s): %r", data['name'], discards)
                cur = self._db.query('SELECT "id", "type" FROM "report" WHERE "state" == 0 AND "type" IN (??)', (discards, ))
                # Distinct loop names: Py2 list comprehensions leak and would rebind ``row``.
                discards = [(rec[0], str(rec[1])) for rec in (cur or [])]
                if discards:
                    self.log.info(
                        'Following reports are going to be discarded by %r: %s',
                        data['name'], ', '.join('#%d(%r)' % rec for rec in discards)
                    )
                    self._db.query('DELETE FROM "report" WHERE "id" IN (??)', ([rec[0] for rec in discards], ))

        self._queue.put((data['name'], None if sendDelay is None else data['end'] + sendDelay, ))

    def _init_senders(self, config):
        """Build the default, dummy and per-plugin ``reporter:<name>`` senders."""
        # initialize default sender
        default_cfg = AppConfig()
        default_cfg.load(config.reporter)
        self._senders['default'] = ReportSender(
            'default',
            self.log,
            default_cfg,
            config.rpc,
            self._db,
            self._deblock)
        self._senders['dummy'] = DevNullReportSender(self.log, self._db)

        for key, value in config.items():
            if key.startswith('reporter:'):
                # Section config overrides the default reporter config.
                merged_cfg = AppConfig()
                merged_cfg.load(default_cfg)
                merged_cfg.load(value)
                sender = ReportSender(
                    key[len('reporter:'):],
                    self.log,
                    merged_cfg,
                    config.rpc,
                    self._db,
                    self._deblock
                )
                try:
                    plugins = merged_cfg.get('plugins', [])
                    # register sender for specified plugins
                    for plugin in plugins:
                        self._senders[plugin] = sender
                except TypeError:
                    # ignore this sender if 'plugins' is not iterable
                    pass

        for plugin, sender in self._senders.items():
            self.log.debug('<Plugin %r> -> %r', plugin, sender)

    def _getMoreReports(self, deadline=None):
        """Consume queue notifications until the (possibly shrinking) deadline passes.

        A queued item with an earlier send deadline moves ``deadline`` closer.
        With ``deadline=None`` this waits indefinitely for the first item.
        """
        log = self.log.getChild('getreports')

        try:
            while 1:
                name, reportDeadline = self._queue.get(
                    timeout=(
                        max(0, deadline - time.time())
                        if deadline is not None else
                        None
                    )
                )
                log.debug('Got new %r report.', name)

                # Change current delivery deadline
                if reportDeadline is not None and (deadline is None or reportDeadline < deadline):
                    deadline = reportDeadline

                now = time.time()
                # Guard: ``deadline`` may still be None when no deadline was supplied.
                if deadline is not None and deadline - now > 0:
                    log.debug('Will wait for more reports max %s.', util.td2str(deadline - now))
        except gevent.queue.Empty:
            # We cant wait more for other reports, since earliest
            # deadline was met
            return

    def _workerLoop(self):
        """Background loop that batches queued reports and dispatches them.

        Each iteration waits for new reports, marks queued rows (state 0) as
        in-flight (state 1), drops expired or over-attempted ones, hands the
        rest to their senders, then requeues unsent rows with ``attempts``
        incremented and deletes processed rows older than the newest one of
        the same type.  Stops on ``GreenletExit``; any other error terminates
        the process via ``os._exit(1)``.
        """
        log = self.log.getChild('worker')
        log.info('Started')

        wait_period = self.cfg.send_retry_timeout
        while 1:
            try:
                self._getMoreReports(deadline=time.time() + wait_period)

                drop = []
                now = time.time()
                # sender -> report type -> list of report dicts
                reports = collections.defaultdict(lambda: collections.defaultdict(list))
                with self._db:
                    repaired = False
                    reports_raw = None
                    self._db.query('UPDATE "report" SET "state" = 1 WHERE "state" = 0')
                    for _ in range(2):
                        try:
                            reports_raw = self._db.query(
                                'SELECT "id", "type", "added", "expires", "data", "format", "compression", "attempts" ' +
                                'FROM "report" WHERE "state" = 1 ORDER BY "expires" ASC'
                            )
                            break
                        except UnicodeDecodeError:
                            if repaired:
                                # we have tried to repair but still failed
                                raise Exception('We have tried but failed to repair database')

                            self._repair_db(log)
                            repaired = True

                    if reports_raw is None:
                        raise Exception('Unexpected reports (None)')

                    for row in reports_raw:
                        report = dict(zip(
                            ['id', 'name', 'end', 'expires', 'report', 'format', 'compression', 'attempts'], row
                        ))
                        # select proper sender
                        sender = self._senders.get(report['name'], self._senders['default'])

                        maxAttempts = sender.max_attempts

                        if report['expires'] < now:
                            log.warning('Dropping expired report #%d of type %r', report['id'], report['name'])
                        elif maxAttempts is not None and report['attempts'] > maxAttempts:
                            log.warning(
                                'Dropping report #%d of type %r - max attempts (%d of %d) reached.',
                                report['id'], report['name'], report['attempts'], maxAttempts
                            )
                        else:
                            report['valid'] = report['expires'] - report['end']
                            report['report'] = str(report['report'])
                            reports[sender][report['name']].append(report)
                            continue
                        drop.append(report['id'])

                    if drop:
                        self._db.query('DELETE FROM "report" WHERE "id" IN (??)', (drop, ))

                if not reports:
                    wait_period = self.cfg.collect_period
                    continue

                left = 0
                for sender, data in reports.items():
                    left += sender.send_reports(data)

                wait_period = self.cfg.send_retry_timeout if left else self.cfg.collect_period

                with self._db:
                    # Everything still in-flight was not processed: requeue it.
                    self._db.query(
                        'UPDATE "report" SET "state" = 0, "attempts" = "attempts" + 1 WHERE "state" = 1'
                    )
                    rows = self._db.query(
                        'SELECT "type", COUNT("id"), MAX("added") FROM "report" WHERE "state" = 2 GROUP BY "type"'
                    )
                    drop = []
                    for row in rows or []:
                        if row[1] > 1:
                            log.info('Dropping out %d extra reports of type %r.', row[1] - 1, row[0])
                            drop.append((row[0], row[2], ))

                    if drop:
                        self._db.query(
                            'DELETE FROM "report" WHERE "state" = 2 AND (%s)' %
                            (' OR '.join(['("type" = ? AND "added" < ?)'] * len(drop))),
                            # Materialize: a one-shot iterator is not a valid parameter sequence.
                            tuple(itertools.chain(*drop))
                        )

            except gevent.GreenletExit:
                log.info('Received stop signal')
                break
            except Exception:
                log.error('Unhandled exception: %s', tb.format_exc())
                os._exit(1)

    def _repair_db(self, log):
        """Delete report rows whose columns cannot be decoded.

        Every row is read individually; a row raising ``UnicodeDecodeError``
        is logged (with whatever metadata is still readable) and deleted.
        """
        ids = self._db.query_col('SELECT "id" FROM "report"')
        for row_id in ids:
            try:
                self._db.query(
                    'SELECT "type", "added", "expires", "data", "format", "compression", "attempts" '
                    'FROM "report" WHERE "id" = ?',
                    (row_id, )
                )
            except UnicodeDecodeError:
                # try to get more info about broken error
                try:
                    row = self._db.query(
                        'SELECT "type", "added", "expires" FROM "report" WHERE "id" = ?',
                        (row_id, )
                    )
                except UnicodeDecodeError:
                    row = None
                log.error('Broken record: ID = {}, info = {} removed'.format(row_id, row[0] if row else 'Unknown'))
                self._db.query('DELETE FROM "report" WHERE "id" = ?', (row_id, ))
