import copy
import traceback
import logging
import uuid
import time
import codecs
import socket
import inspect
import hashlib
import json

import pymongo
import msgpack

from genisys.toiler import config


class Toil(object):
    """One processing attempt for a single locked ``volatile`` record.

    The processor class is called as
    ``processor_cls(database, record_copy, forced=forced)`` and must return
    a generator.  Every value the generator yields is treated as an optional
    log line (``None`` means "nothing to log"); the generator's return value
    must be a ``(value, new_meta)`` pair.  A ``None`` value means the
    calculation is postponed.

    While the generator is running the lock on the record is prolonged
    periodically; when it finishes, the outcome is written back to
    ``database.volatile`` and the lock is released.
    """

    # minimum number of seconds between two prolongations of the lock
    UPDATE_NOT_MORE_OFTEN_THAN = 20
    # if the last touch time is ahead of the last access time by more than
    # this many seconds, the record is deleted instead of processed (7 days)
    TTL_BEFORE_REMOVE = 60 * 60 * 24 * 7

    def __init__(self, database, stats, processor_cls,
                 record, lock_id, lock_ttl, forced=False):
        """Prepare the processor generator for ``record``.

        :param database: object exposing the ``volatile`` collection.
        :param stats: statistics sink; only ``incr(name)`` is used here.
        :param processor_cls: callable returning the processor generator.
        :param record: raw record document; its ``value`` and ``source``
            fields are deserialized in place.
        :param lock_id: id of the lock held on the record, used as a guard
            in every update issued by this object.
        :param lock_ttl: number of seconds the lock is prolonged by.
        :param forced: when true, a non-None result is stored even if it
            equals the previous value.
        """
        self.db = database
        self.stats = stats
        self.processor = processor_cls
        if record['value'] is not None:
            record['value'] = _deserialize(record['value'])
        record['source'] = _deserialize(record['source'])
        self.record = record
        self.lock_id = lock_id
        self.process_started = _get_ts()
        self.last_updated = self.process_started
        self.lock_ttl = lock_ttl
        self.logger = logging.getLogger(
            'genisys.toil.{}'.format(record['vtype'])
        )
        self.forced = forced
        # list of (timestamp, message) pairs, stored with the record
        self.proclog = [(self.process_started,
                         "started on {}".format(socket.gethostname()))]
        if forced:
            self.proclog.append((self.process_started, "doing forced update"))
        # the processor gets its own copy so it can not alter self.record
        procgen = processor_cls(database, copy.deepcopy(record), forced=forced)
        assert inspect.isgenerator(procgen)
        self.prociter = iter(procgen)
        self._log_info('created, lock_id=%(lock_id)s', lock_id=lock_id)
        self._log_debug('source=%(source)r', source=record['source'])
        self._log_debug('old value=%(value)r', value=self.record['value'])

    def _check_outdated(self):
        """Delete the record if it has not been accessed for too long.

        :returns: ``True`` if the record was deleted, ``False`` otherwise
            (including the case when ``ttime`` or ``atime`` is not set).
        """
        if not self.record['ttime'] or not self.record['atime']:
            return False
        gap = self.record['ttime'] - self.record['atime']
        # operands are listed in the order they are actually subtracted
        msg = "time between last access and last touch: {} = {} - {}".format(
            int(gap), int(self.record['ttime']), int(self.record['atime'])
        )
        self.proclog.append((_get_ts(), msg))
        self._log_info(msg)
        if gap <= self.TTL_BEFORE_REMOVE:
            return False
        msg = "record is outdated and is to be removed ({} > {})".format(
            gap, self.TTL_BEFORE_REMOVE
        )
        self.proclog.append((_get_ts(), msg))
        self._log_info(msg)
        self.stats.incr('deleted')
        self.db.volatile.delete_one({'vtype': self.record['vtype'],
                                     'key': self.record['key']})
        return True

    def _log(self, level, msg, **kwargs):
        """Log ``msg`` prefixed with the record key.

        ``msg`` is %-interpolated with ``kwargs`` only when ``kwargs`` is not
        empty, so messages passed without kwargs may contain a literal ``%``.
        """
        if kwargs:
            msg = msg % kwargs
        msg = "key={}: {}".format(self.record['key'], msg)
        self.logger.log(level, msg)

    def _log_info(self, msg, **kwargs):
        """Log ``msg`` at INFO level, see :meth:`_log`."""
        self._log(logging.INFO, msg, **kwargs)

    def _log_warning(self, msg, **kwargs):
        """Log ``msg`` at WARNING level, see :meth:`_log`."""
        self._log(logging.WARNING, msg, **kwargs)

    def _log_debug(self, msg, **kwargs):
        """Log ``msg`` at DEBUG level, see :meth:`_log`."""
        self._log(logging.DEBUG, msg, **kwargs)

    def _log_error(self, msg, **kwargs):
        """Log ``msg`` at ERROR level, see :meth:`_log`."""
        self._log(logging.ERROR, msg, **kwargs)

    def run(self):
        """Drive the processor generator to completion and store the result.

        Returns early without storing anything if the record was removed as
        outdated or if the lock on it was lost during processing.  A
        ``ProcError`` or any other exception raised by the processor is
        recorded in the processing log and registered with status "error".
        """
        if self._check_outdated():
            return
        value = None
        status = None
        # meta is kept as is unless the processor returns a new one
        new_meta = self.record['meta']
        try:
            while True:
                try:
                    logline = next(self.prociter)
                except StopIteration as exc:
                    # the generator's return value travels in the exception
                    value, new_meta = exc.value
                    status = "modified" if value is not None else "postponed"
                    break
                else:
                    if not self._handle_proc_iteration(logline):
                        # the lock is lost, the result would be discarded
                        self.prociter.close()
                        return
        except ProcError as exc:
            msg = str(exc)
            if msg:
                self._log_info('processor reports error: {}'.format(msg))
                self.proclog.append((_get_ts(), msg))
            else:
                self._log_info('processor reports error, postponing update')
            status = "error"
            value = None
        except Exception as exc:
            tb = traceback.format_exc()
            self._log_error(tb)
            self._log_info('error in processor occured, postponing update')
            self.proclog.append((_get_ts(), tb))
            status = "error"
            value = None

        self.register_result(value, new_meta, status)

    def _handle_proc_iteration(self, logline):
        """Record one processor log line and prolong the lock when due.

        :param logline: message yielded by the processor or ``None``.
        :returns: ``False`` if the record guarded by our lock id could not
            be found anymore (processing must stop), ``True`` otherwise.
        """
        now = _get_ts()

        if logline is not None:
            self._log_info("processor: %(logline)s", logline=logline)
            self.proclog.append((now, logline))

        time_since_updated = now - self.last_updated
        if time_since_updated < self.UPDATE_NOT_MORE_OFTEN_THAN:
            return True

        # matching on lock_id makes sure we only prolong our own lock
        result = self.db.volatile.update_one(
            {'_id': self.record['_id'], 'lock_id': self.lock_id},
            {'$set': {'etime': now + self.lock_ttl}}
        )
        if result.matched_count == 0:
            self._log_warning(
                'missing record after %(time_since_updated).1f seconds '
                'of processing, terminating the processor',
                time_since_updated=time_since_updated
            )
            return False

        now = _get_ts()
        self.last_updated = now

        log_msg = 'updated etime after {:.1f} seconds of processing'.format(
            time_since_updated
        )
        self._log_info(log_msg)
        return True

    def register_result(self, value, new_meta, status):
        """Write the processing outcome to the record and release the lock.

        :param value: new value or ``None`` if processing was postponed or
            failed.
        :param new_meta: meta to store with the record.
        :param status: "modified", "postponed" or "error"; "modified" is
            downgraded to "same" when the value did not change and the
            update is not forced.
        """
        now = _get_ts()
        time_spent = now - self.process_started
        value_changed = ((value is not None)
                         and ((self.record['value'] != value) or self.forced))
        assert status is not None
        if status == "modified" and not value_changed:
            status = "same"

        value_serialized = None
        if value is not None and value_changed:
            try:
                value_serialized = _serialize(value)
            except Exception:
                # an unserializable result is handled like a processor error;
                # KeyboardInterrupt/SystemExit are deliberately not caught
                tb = traceback.format_exc()
                self._log_error(tb)
                self._log_info('failed to serialize result, postponing update')
                self.proclog.append((_get_ts(), tb))
                status = "error"
                value = None
                value_changed = False

        if value is not None:
            if hasattr(self.processor, 'get_result_ttl'):
                ttl = self.processor.get_result_ttl(self.record)
            else:
                ttl = self.processor.RESULT_TTL
        else:
            # calculating progressive timeout depending on how many times
            # we have postponed calculation since last successful attempt
            base = config.POSTPONE_BACKOFF_BASE
            try:
                backoff_inc = int(base ** self.record.get('pcount', 0)) - 1
                ttl = min(config.MAX_POSTPONE_TIME,
                          config.MIN_POSTPONE_TIME + backoff_inc)
            except OverflowError:
                # a float base overflows for a large enough pcount; the
                # backoff is capped by the maximum anyway
                ttl = config.MAX_POSTPONE_TIME

        self.stats.incr('vtype.{} {}'.format(self.record['vtype'], status))
        log_msg = 'finished processing, spent {time_spent:.1f} seconds ' \
                  'in total, new ttl={ttl}, value {changed}, new ' \
                  'status={status}, new meta={meta}'.format(
            time_spent=time_spent, ttl=ttl, status=status, meta=new_meta,
            changed=('has changed' if value_changed else 'has not changed')
        )
        self._log_info(log_msg)
        self.proclog.append((now, log_msg))
        self._log_debug('value=%(value)r', value=value)

        update = {
            "$set": {
                'etime': now + ttl,
                'ttime': now,
                'locked': False,
                'lock_id': None,
                'proclog': self.proclog,
                'last_status': status,
                'meta': new_meta,
            },
            "$inc": {
                "tcount": 1,
            },
        }

        if value is not None:
            # set "update time" only if processing was not postponed
            update['$set']['utime'] = now
            update['$inc']['ucount'] = 1
            # resetting postponed count and error count
            update['$set']['ecount'] = 0
            update['$set']['pcount'] = 0
            if value_changed:
                update['$set']['value'] = value_serialized
                update['$set']['mtime'] = now
                update['$inc']['mcount'] = 1
        else:
            # processing was postponed
            update['$inc']['pcount'] = 1
            if status == "error":
                # processing failed
                update['$inc']['ecount'] = 1

        # Human-readable summary of the update for the processing log.  Bulky
        # fields are masked on a shallow copy, there is no need to deep-copy
        # the whole processing log just to throw the copy away.
        loggable_set = dict(update['$set'], proclog='<skipped>')
        if 'value' in loggable_set:
            loggable_set['value'] = '<skipped>'
        # default=repr: this string is diagnostic only, so a meta that json
        # can not encode must not prevent the record from being stored
        update_info = json.dumps(
            {'$set': loggable_set, '$inc': update['$inc']},
            sort_keys=True, default=repr
        )
        # self.proclog is the very list referenced from the update document,
        # so this last entry gets stored as well
        self.proclog.append((now, "updating record with %s" % (update_info, )))

        result = self.db.volatile.update_one(
            {'_id': self.record['_id'], 'lock_id': self.lock_id}, update
        )
        if result.matched_count == 0:
            self._log_warning('missing record when the result is finally '
                              'ready, sigh')
            return
        if value is not None and hasattr(self.processor, 'postprocess_value'):
            self.processor.postprocess_value(self.record, value)
        self._log_info('successfully stored the result, exiting')


class Toiler(object):
    """Worker that repeatedly locks a ``volatile`` record and processes it.

    ``registry`` maps a vtype to the processor class handling it; only
    records whose vtype is in the registry are picked up.  When
    ``forced_update_from_mtime`` is given, only records with a smaller
    ``mtime`` are selected and each of them is processed in forced mode.
    With ``is_eager`` set, the worker falls back to an unlocked record that
    has not expired yet when no expired one is available.
    """

    # number of seconds a freshly acquired lock stays valid
    LOCK_TTL = 60

    def __init__(self, database, registry, stats,
                 forced_update_from_mtime=None, is_eager=False):
        """Remember the collaborators and log the supported vtypes."""
        log = logging.getLogger('genisys.toiler')
        self.db, self.registry, self.stats = database, registry, stats
        self.logger = log
        log.info('ready to toil. supported vtypes are: %(vtypes)r',
                 {'vtypes': list(registry.keys())})
        self.is_eager = is_eager
        self.forced_update_from_mtime = forced_update_from_mtime
        if forced_update_from_mtime is not None:
            log.info('doing forced update of records with mtime < %s',
                     forced_update_from_mtime)
        self._running = True

    def _get_lock_id(self):
        """Return a new random lock id."""
        lock_id = uuid.uuid4()
        return lock_id

    def _step(self):
        """Process one record, or sleep if there is nothing to work on."""
        record, lock_id = self._acquire_record()
        if record is not None:
            self._process_record(record, lock_id)
            return
        self.logger.info('no record available to work on')
        time.sleep(config.TOILER_SLEEP_ON_NOTHING_TO_DO)

    def stop(self):
        """Ask :meth:`run` to return once the current step is finished."""
        self._running = False

    def run(self):
        """Keep stepping until :meth:`stop` is called."""
        while True:
            if not self._running:
                break
            self._step()

    def _acquire_record(self):
        """Lock the next record to work on.

        :returns: ``(record, lock_id)`` where ``record`` is the document as
            it was before locking, or ``None`` as the record if nothing
            could be locked.
        """
        lock_id = self._get_lock_id()
        now = _get_ts()
        lock_update = {'$set': {
            'etime': now + self.LOCK_TTL,
            'locked': True,
            'lock_id': lock_id,
        }}
        ordering = [('locked', pymongo.ASCENDING),
                    ('etime', pymongo.ASCENDING)]
        base_filter = {'vtype': {'$in': sorted(self.registry.keys())}}
        if self.forced_update_from_mtime is not None:
            base_filter['mtime'] = {'$lt': self.forced_update_from_mtime}

        def claim(**conditions):
            # atomically lock the first matching record, return its old state
            return self.db.volatile.find_one_and_update(
                dict(base_filter, **conditions),
                update=lock_update, sort=ordering,
                return_document=pymongo.ReturnDocument.BEFORE
            )

        # stalled records (locked and expired) come first,
        # unlocked expired records second
        for is_locked in (True, False):
            found = claim(locked=is_locked, etime={'$lt': now})
            if found is None:
                continue
            if is_locked:
                self.stats.incr('stalled_records')
            return found, lock_id

        if self.is_eager:
            # at least get an unlocked record with soonest expiration time
            return claim(locked=False), lock_id
        return None, None

    def _process_record(self, record, lock_id):
        """Report scheduling lag for ``record`` and run a Toil on it."""
        self.logger.info(
            'got a record to work on: vtype=%(vtype)r key=%(key)r '
            'locked=%(locked)r lock_id=%(lock_id)r etime=%(etime)r',
            record
        )
        now = _get_ts()
        if not record['etime'] < now:
            self.logger.info('we are %.1f seconds ahead of schedule',
                             record['etime'] - now)
        else:
            seconds_behind = now - record['etime']
            self.stats.incr('lag', seconds_behind)
            if seconds_behind > 10:
                self.logger.warning('we are %.1f seconds behind schedule!',
                                    seconds_behind)

        # a forced run is requested by the mtime threshold being set
        forced = self.forced_update_from_mtime is not None
        vtype = record['vtype']
        toil = Toil(self.db, self.stats, self.registry[vtype],
                    record=record, lock_id=lock_id,
                    lock_ttl=self.LOCK_TTL, forced=forced)
        with self.stats.timer('vtype.{} process'.format(record['vtype'])):
            toil.run()


def _deserialize(data):
    """Decode a stored blob: undo the 'zip' codec, then unpack msgpack."""
    # NOTE(review): the `encoding` keyword is not accepted by every msgpack
    # release - confirm against the version this project pins.
    packed = codecs.decode(data, 'zip')
    return msgpack.loads(packed, encoding='utf-8')

def _serialize(data):
    """Encode a value for storage: pack with msgpack, then 'zip' it."""
    # NOTE(review): the `encoding` keyword is not accepted by every msgpack
    # release - confirm against the version this project pins.
    packed = msgpack.dumps(data, encoding='utf-8')
    return codecs.encode(packed, 'zip')

def _get_ts():
    """Return the current time as given by ``time.time()``."""
    now = time.time()
    return now


def get_volatiles(db, vtype, keys, with_values=True):
    """Fetch volatile records of one vtype and mark them as accessed.

    The access time of all requested records is updated first.  The ``_id``
    field is never returned; the ``value`` field is left out as well when
    ``with_values`` is false.  ``source`` and non-None ``value`` fields are
    deserialized.

    :returns: dict mapping record key to the record document.
    """
    update_volatiles_atime(db, vtype, keys)

    fields = {'_id': False}
    if not with_values:
        fields['value'] = False
    query = {"vtype": vtype, "key": {"$in": keys}}

    found = {}
    for doc in db.volatile.find(query, fields):
        if with_values:
            raw_value = doc['value']
            if raw_value is not None:
                doc['value'] = _deserialize(raw_value)
        doc['source'] = _deserialize(doc['source'])
        found[doc['key']] = doc
    return found

def update_volatiles_atime(db, vtype, keys):
    """Set ``atime`` of the given records of one vtype to the current time."""
    selector = {'vtype': vtype, 'key': {'$in': keys}}
    touch = {'$set': {'atime': _get_ts()}}
    db.volatile.update_many(selector, touch)

def volatile_key_hash(val):
    """Return the SHA-1 hex digest of a key value.

    ``str`` values are hashed as their UTF-8 bytes, ``int``/``bool`` values
    as the UTF-8 bytes of ``str(val)``, and ``bytes`` are hashed as is.

    :raises ValueError: for a value of any other type.
    """
    if isinstance(val, str):
        raw = val.encode('utf8')
    elif isinstance(val, bytes):
        raw = val
    elif isinstance(val, (int, bool)):
        raw = str(val).encode('utf8')
    else:
        raise ValueError('can not hash value of type {}'.format(type(val)))
    digest = hashlib.sha1(raw)
    return digest.hexdigest()

class ProcError(Exception):
    """Error reported by a processor; ``Toil.run`` turns it into an "error"
    status and logs its message (if any) instead of a traceback."""
