# vim: set foldmethod=marker

from __future__ import absolute_import, print_function, division

import os
import time
import zlib
import random
import functools
import datetime as dt

import msgpack
import gevent
import gevent.event
import gevent.monkey

# Make pymongo compatible with gevent (see http://api.mongodb.org/python/current/examples/gevent.html for details).
gevent.monkey.patch_all()  # noqa
import pymongo
import six

from kernel.util.console import setProcTitle
from kernel.util.errors import formatException

from library.sky.hostresolver.resolver import Resolver
from library.sky.hostresolver.errors import HrSyntaxError, HrHostCheckFailedError
from library.sky.hosts import braceExpansion

from api.config import useGevent

from .rpc.server import RPC
from .rpc.utils import Server
from .web import Web


class CachedValue(object):
    """
    Cached host-info query result that expires ``timeout`` seconds after it was created.

    :param timeout: Lifetime of the entry, in seconds.
    :param oses: OS table payload, stored as-is.
    :param hosts: Host list payload, stored as-is.
    """

    def __init__(self, timeout, oses, hosts):
        # Creation timestamp, used as the reference point for expiry.
        self._update_time = time.time()
        self._timeout = timeout
        self.oses, self.hosts = oses, hosts

    def is_invalid(self):
        """Return True once at least ``timeout`` seconds have elapsed since construction."""
        age = time.time() - self._update_time
        return self._timeout <= age


class HeartBeatMaster(object):
    """
    Heartbeat master daemon: owns the MongoDB handles, the RPC server and the web frontend.

    Lifecycle is ``start()`` -> ``join()`` -> ``stop()``. ``join()`` returns once the stop flag is set,
    either by ``stop()``, by the ``stop`` RPC call, or through the callback handed to ``Web``.
    """

    # Types accepted for numeric RPC arguments (also admits ``long`` on Python 2).
    _NUMERIC_TYPES = six.integer_types + (float, )

    # Initialization {{{
    def __init__(self, ctx):
        self.ctx = ctx
        self.log = ctx.log.getChild('heartbeat-server')
        self.log.debug('Initializing')

        self._stopFlag = gevent.event.Event()   # set when the daemon should shut down
        self._startFlag = gevent.event.Event()  # set once start() has completed

        self._db = None          # database handle with SECONDARY_PREFERRED read preference
        self._primary_db = None  # database handle with PRIMARY read preference
        self._server = None
        self._rpc = None
        self._web = None
        self._cache = {}         # host-expression key (None = all hosts) -> CachedValue
    # Initialization }}}

    # Management (start, stop, join) {{{
    def start(self):
        """
        Connect to MongoDB, start the RPC server and the web frontend, and register RPC handlers.

        :return: ``self``, to allow chaining.
        :raises Exception: if the configured database URI is not a ``mongodb://`` URI naming a database.
        """
        assert not self._stopFlag.isSet()

        dbURI = self.ctx.cfg.database.uri
        self.log.info('Connecting mongoDB at following URI: %r', dbURI)
        url = six.moves.urllib_parse.urlparse(dbURI)

        # Python 2.6's `urlparse` does not move query parameters out of `path` into `query`,
        # so cut them off manually before extracting the database name.
        dbname = url.path.split('?', 1)[0].strip('/')
        if url.scheme != 'mongodb' or not dbname:
            raise Exception("Database URI %r passed is invalid." % dbURI)

        client = pymongo.MongoReplicaSetClient if 'replicaSet' in dbURI else pymongo.MongoClient
        self.log.info('Using %r client', client)

        # One connection whose reads may be served by secondaries, one pinned to the primary.
        conn = client(dbURI, read_preference=pymongo.ReadPreference.SECONDARY_PREFERRED)
        primary_conn = client(dbURI, read_preference=pymongo.ReadPreference.PRIMARY)
        self._db = conn[dbname]
        self._primary_db = primary_conn[dbname]

        self._server = Server(self.ctx).start()
        self._rpc = RPC(self.ctx).start()

        # Register server connection handlers
        self._server.registerConnectionHandler(self._rpc.getConnectionHandler())

        # Register RPC callbacks
        self._rpc.registerHandler('ping', self._ping)
        self._rpc.registerHandler('stop', self._stop)
        self._rpc.registerHandler('shardState', self._shardState)
        self._rpc.registerHandler('shardStateV3', self._shardStateV3)
        self._rpc.registerHandler('instanceStateV3', self._instanceStateV3)
        self._rpc.registerHandler('hostInfo', self._hostInfo)
        self._rpc.registerHandler('hostInfoEx', self._hostInfoEx)
        self._rpc.registerHandler('listHosts', self._listHosts)

        self._web = Web(self.ctx, self._stopFlag.set, self._db, self._primary_db).start()

        self._stopFlag.clear()
        self._startFlag.set()
        return self

    def stop(self):
        """
        Stop the web frontend, the server and the RPC layer (waiting for each), then set the stop flag.

        :return: ``self``, to allow chaining.
        """
        assert self._startFlag.isSet()

        self._web.stop().join()
        self._server.stop().join()
        self._rpc.stop().join()
        self._stopFlag.set()
        self._startFlag.clear()
        return self

    def join(self):
        """Block until the stop flag is set."""
        self._stopFlag.wait()
    # Management (start, stop, join) }}}

    @staticmethod
    def _dataChunk(job, compressor, packer, data):
        """
        Pack, compress and emit one item of a streamed RPC response.

        :param job: RPC job; compressed bytes are sent via ``job.state()``.
        :param compressor: zlib compression object shared by the whole response.
        :param packer: msgpack packer used to serialize ``data``.
        :param data: Item to send, or ``None`` to flush the compressor and emit the stream tail.
        """
        # Only ``None`` means "flush": a falsy payload (e.g. an empty header dict) must still be
        # packed, otherwise the zlib stream is finished early and later compress() calls fail.
        if data is None:
            packed = compressor.flush()
        else:
            packed = compressor.compress(packer.pack(data))
        # zlib may buffer input and return nothing yet; only non-empty output is sent.
        if packed:
            job.state(packed)

    # RPC Meths {{{
    @RPC.full
    def _ping(job, self):
        """RPC ``ping``: wait up to 30 seconds for start-up, then delegate to the web frontend."""
        if self._startFlag.wait(timeout=30):
            return self._web.ping()
        job.log.warning('Ping timed out waiting startFlag to become setted (wait 30 seconds)')
        return False

    @RPC.simple
    def _stop(self):
        """RPC ``stop``: request daemon shutdown by setting the stop flag."""
        self._stopFlag.set()

    @staticmethod
    def _shortHostname(host):
        """
        Host name shortener. Cuts off common FQDN suffixes ('.yandex.ru', '.search.yandex.net').
        Note that the substrings are removed wherever they occur, not only at the end.
        :param host: Host name to be shortened.
        :return: Short host name.
        """
        return host.replace('.yandex.ru', '').replace('.search.yandex.net', '')

    def _streamState(self, job, modifiedSince, compress, collection, header):
        """
        Stream ``(short host, last update timestamp, str(state))`` tuples from ``collection``.

        The response is one zlib stream of msgpack objects: ``header`` first, then one tuple per
        document that has a ``state`` field, then the compressor flush.

        :param job: RPC job used to emit chunks.
        :param modifiedSince: Unix timestamp; when > 0 only documents with ``last_update`` >= it are sent.
        :param compress: zlib compression level.
        :param collection: MongoDB collection to read from.
        :param header: First object written to the stream.
        :return: ``True`` once the whole stream has been emitted.
        """
        assert isinstance(modifiedSince, self._NUMERIC_TYPES), \
            'modifiedSince must be numeric (got %r)' % (type(modifiedSince), )
        assert modifiedSince >= 0, 'modifiedSince must be >= 0 (got %r)' % (modifiedSince, )

        packer = msgpack.Packer(use_bin_type=False)
        compressor = zlib.compressobj(compress)
        chunk = functools.partial(self._dataChunk, job, compressor, packer)
        chunk(header)

        query = {'state': {'$exists': True}}
        if modifiedSince > 0:
            # NOTE(review): fromtimestamp() yields a naive local datetime, mirroring the mktime()
            # conversion below; assumes last_update is written the same way - confirm.
            query['last_update'] = {'$gte': dt.datetime.fromtimestamp(modifiedSince)}
            collection.ensure_index('last_update')

        for entry in collection.find(query):
            chunk((
                self._shortHostname(entry['host']),
                int(time.mktime(entry['last_update'].timetuple())),
                str(entry['state']),
            ))

        chunk(None)  # flush the compressor
        return True

    def _genericShardState(self, job, modifiedSince, compress, state):
        """
        Stream shard states from the ``state`` collection.

        The header is a dict built from ``STATE_FIELDS`` that maps each entry's ``value[0]`` to its key.
        """
        from .utils.shardstate import STATE_FIELDS

        header = dict((value[0], key) for key, value in STATE_FIELDS.items())
        return self._streamState(job, modifiedSince, compress, state, header)

    @RPC.full
    def _shardState(job, self, modifiedSince, compress=3):
        """RPC ``shardState``: stream states from the ``shardstate`` collection."""
        return self._genericShardState(job, modifiedSince, compress, self._db.shardstate)

    @RPC.full
    def _shardStateV3(job, self, modifiedSince, compress=3):
        """RPC ``shardStateV3``: stream states from the ``shardstatev3`` collection."""
        return self._genericShardState(job, modifiedSince, compress, self._db.shardstatev3)

    def _genericInstanceState(self, job, modifiedSince, compress, state):
        """Stream instance states from the ``state`` collection, prefixed by the ``'ENDMODELS'`` marker."""
        return self._streamState(job, modifiedSince, compress, state, 'ENDMODELS')

    @RPC.full
    def _instanceStateV3(job, self, modifiedSince, compress=3):
        """RPC ``instanceStateV3``: stream states from the ``instancestatev3`` collection."""
        return self._genericInstanceState(job, modifiedSince, compress, self._db.instancestatev3)

    @RPC.full
    def _hostInfo(job, self, compress=3):
        """RPC ``hostInfo``: stream OS/host info for all hosts."""
        return self._hostInfoInternal(job, compress, None)

    @RPC.full
    def _hostInfoEx(job, self, compress=3, user_hosts=None):
        """RPC ``hostInfoEx``: stream OS/host info, optionally limited to a host expression."""
        return self._hostInfoInternal(job, compress, user_hosts)

    def _loadHostInfo(self, resolved_hosts):
        """
        Read OS and host records from the ``hostinfo`` collection.

        :param resolved_hosts: Iterable of host names to restrict to, or a falsy value for all hosts.
        :return: ``(oses, hosts)``: ``oses`` maps an OS id to ``(name, arch, version)``;
            ``hosts`` is a list of ``(hostname, osid)`` tuples.
        """
        query = {'os': {'$exists': True}}
        if resolved_hosts:
            query['host'] = {'$in': list(resolved_hosts)}

        oses = {}
        hosts = []
        for doc in self._db.hostinfo.find(query):
            os_info = (doc['os']['name'], doc['os']['arch'], doc['os']['version'], )
            # NOTE(review): hash() of str is randomized per process on Python 3; confirm clients
            # do not rely on osids being stable across daemon restarts.
            osid = hash(os_info)
            oses.setdefault(osid, os_info)
            hosts.append((doc['host'], osid, ))
        return oses, hosts

    def _hostInfoInternal(self, job, compress, user_hosts):
        """
        Stream OS and host information from ``hostinfo``, using a time-limited cache.

        Stream layout (msgpack objects in one zlib stream): one ``[osid, name, arch, version]`` list
        per distinct OS, the ``'ENDOS'`` marker, one ``(hostname, osid)`` tuple per host, the flush.

        :param user_hosts: Optional host expression; when given it is brace-expanded and resolved,
            and only the resolved hosts are reported.
        :return: ``True`` on success; ``False`` when ``user_hosts`` fails to parse/resolve or
            resolves to nothing.
        """
        resolved_hosts = None
        if user_hosts:
            try:
                command = braceExpansion([user_hosts], True)
                resolved_hosts = Resolver(check_hosts=False).resolveHosts(command)
            except (HrSyntaxError, HrHostCheckFailedError):
                return False
            if not resolved_hosts:
                return False
            cache_key = user_hosts
        else:
            # ``None`` cannot collide with a user expression (those are always truthy); a string
            # key such as 'all' would be shared with an explicit user_hosts='all' request.
            cache_key = None

        # Drop expired entries so the cache does not grow without bound.
        expired = [key for key, value in six.iteritems(self._cache) if value.is_invalid()]
        for key in expired:
            del self._cache[key]

        cached = self._cache.get(cache_key)
        if cached is not None:
            oses, hosts = cached.oses, cached.hosts
        else:
            oses, hosts = self._loadHostInfo(resolved_hosts)
            self._cache[cache_key] = CachedValue(
                self.ctx.cfg.rpc.calls.host_info.data_cache_timeout, oses, hosts
            )

        packer = msgpack.Packer(use_bin_type=False)
        compressor = zlib.compressobj(compress)
        chunk = functools.partial(self._dataChunk, job, compressor, packer)

        # OS
        for osid, os_info in six.iteritems(oses):
            chunk([osid] + list(os_info))
        chunk('ENDOS')

        # HOSTS
        for hostname, osid in hosts:
            chunk((hostname, osid, ))

        chunk(None)  # flush the compressor
        return True

    @RPC.simple
    def _listHosts(self, minReportInterval=None, maxReportInterval=None, fields=('hostname', 'lastContact')):
        """
        RPC ``listHosts``: list hosts from ``hostinfo``, optionally filtered by last report time.

        :param minReportInterval: If truthy, only hosts with ``last_update`` >= now - minReportInterval.
        :param maxReportInterval: If truthy, only hosts with ``last_update`` <= now - maxReportInterval.
        :param fields: Requested columns, any of ``'hostname'``, ``'lastContact'``, ``'version'``.
        :return: One list of values per host, in ``fields`` order; hosts lacking a requested value
            are skipped.
        """
        query = {}
        now = time.time()
        hostinfo = self._db.hostinfo

        for field in fields:
            assert field in ('hostname', 'lastContact', 'version'), 'Field %r is not supported' % (field, )
        fieldsMap = {
            'hostname': 'host',
            'lastContact': 'last_update',
            'version': 'skynet.svn.url',
        }
        dbFields = [fieldsMap[field] for field in fields]

        if minReportInterval:
            assert isinstance(minReportInterval, self._NUMERIC_TYPES), \
                'minReportInterval must be float or int, got %r' % (minReportInterval, )
            query.setdefault('last_update', {})['$gte'] = dt.datetime.fromtimestamp(now - minReportInterval)

        if maxReportInterval:
            assert isinstance(maxReportInterval, self._NUMERIC_TYPES), \
                'maxReportInterval must be float or int, got %r' % (maxReportInterval, )
            query.setdefault('last_update', {})['$lte'] = dt.datetime.fromtimestamp(now - maxReportInterval)

        if 'last_update' in query:
            hostinfo.ensure_index('last_update')

        getters = {
            'host': lambda doc: doc['host'],
            'last_update': lambda doc: time.mktime(doc['last_update'].timetuple()),
            'skynet.svn.url': lambda doc: doc['skynet']['svn']['url'],
        }
        rows = []
        for doc in hostinfo.find(query, dbFields):
            try:
                row = [getters[name](doc) for name in dbFields]
            except KeyError:
                # A nested projection may return a partial parent (e.g. ``skynet`` without
                # ``svn.url``), so missing values are detected here rather than by key count.
                continue
            rows.append(row)
        return rows
    # RPC Meths }}}


def main(ctx, stop_timeout=2):
    """
    Daemon entry point: start HeartBeatMaster, wait until it is asked to stop, stop it and exit.

    The process terminates via ``os._exit``:
    - 0 after a graceful stop;
    - 1 when an unexpected exception escapes the wait loop;
    - 1 when stopping takes longer than ``stop_timeout`` seconds.
    A ``SystemExit`` raised while waiting is logged and re-raised.

    :param ctx: Application context (provides ``cfg`` and ``log``).
    :param stop_timeout: Seconds allowed for ``HeartBeatMaster.stop()`` plus the final join.
    """
    useGevent(True)

    log = ctx.log

    setProcTitle(ctx.cfg.ProgramName)
    log.info('Initializing {0}'.format(ctx.cfg.ProgramName))

    random.seed()

    app = HeartBeatMaster(ctx)
    app.start()

    try:
        app.join()
    except KeyboardInterrupt:
        log.info('Caught SIGINT, stopping daemon')
    except SystemExit as ex:
        if ex.args:
            # A SystemExit carrying arguments is expected to hold a meaningful message.
            log.warning(ex)
        else:
            log.warning('Got SystemExit exception in main loop')
        raise
    except BaseException:
        log.critical('Unhandled exception: %s, exit immidiately!' % formatException())
        os._exit(1)  # pylint: disable=W0212

    # Bound the whole shutdown, not just the final join. stop() sets the stop flag before it
    # returns, so a timeout around join() alone could never fire, and a hung stop() would
    # block forever.
    with gevent.Timeout(stop_timeout) as timeout:
        try:
            app.stop()
            app.join()
        except gevent.Timeout as err:
            if err is not timeout:
                raise
            log.error('Failed to stop {0!r} gracefully -- timeout occured.'.format(app))
            os._exit(1)  # pylint: disable=W0212

    log.info('Stopped {0}'.format(ctx.cfg.ProgramName))
    os._exit(0)
    return 0
