import time
import logging
import datetime

import pymongo.errors

from . import config as config_format


class StorageError(RuntimeError):
    """Raised by this module's storage classes when a MongoDB operation fails."""


class Mongo(object):
    """Snapshot storage of per-(host, port) configs in MongoDB.

    Every ``store`` call writes one document per ``(host, port)`` into
    ``collection``, all tagged with the same whole-second epoch timestamp,
    and then records that timestamp (together with ``version``) in
    ``timestamps_collection``. ``load`` reads back the snapshot for a given
    timestamp, defaulting to the newest recorded one.

    Public methods translate ``pymongo.errors.PyMongoError`` into
    ``StorageError``.
    """

    # Storage format version, written into every timestamp record.
    version = 1

    def __init__(self, collection, timestamps_collection):
        self._coll = collection
        self._timestamps = timestamps_collection
        self._logger = logging.getLogger('MongoStorage({})'.format(collection.name))

    def store(self, configs):
        """Persist ``configs`` as a new snapshot.

        :param configs: mapping ``{(host, port): config}``; ``port`` must be an int.
        :raises StorageError: on any MongoDB failure.
        """
        try:
            self._store(configs)
        except pymongo.errors.PyMongoError:
            # Log through the per-instance logger (not the root logger) so the
            # message is attributed to this storage, like the rest of the class.
            self._logger.exception('Unable to store configs')
            raise StorageError()

    def load(self, timestamp=None):
        """Return the snapshot stored at ``timestamp`` as ``{(host, port): content}``.

        :param timestamp: epoch seconds of the snapshot to read. ``None`` (the
            default) selects the newest recorded timestamp; any other value,
            including ``0``, is used as given.
        :returns: the snapshot dict; ``{}`` when no snapshot has been recorded.
        :raises StorageError: on any MongoDB failure.
        """
        try:
            if timestamp is None:
                timestamp = self._head_timestamp()
                if timestamp is None:
                    # Nothing stored yet: no point querying for a null timestamp.
                    return {}
            return self._load(timestamp)
        except pymongo.errors.PyMongoError:
            self._logger.exception('Unable to load configs')
            raise StorageError()

    @property
    def head_timestamp(self):
        """Newest recorded snapshot timestamp, or ``None`` if there is none.

        :raises StorageError: on any MongoDB failure.
        """
        try:
            return self._head_timestamp()
        except pymongo.errors.PyMongoError:
            self._logger.exception('Unable to read head timestamp')
            raise StorageError()

    def _store(self, configs):
        """Write ``configs`` under a fresh timestamp, then register that timestamp.

        The config documents are written first and the timestamp record last,
        so ``load()`` (which follows the newest timestamp record) does not pick
        up a snapshot whose documents were never written.
        Note: pymongo refuses to execute an empty bulk operation
        (``InvalidOperation``), so empty ``configs`` ends up as ``StorageError``.
        """
        timestamp = _datetime_to_timestamp(datetime.datetime.now())
        bulk = self._coll.initialize_unordered_bulk_op()

        for (host, port), config in configs.items():
            assert isinstance(port, int)
            bulk.find({
                _Keys.Host: host,
                _Keys.Port: port,
                _Keys.Timestamp: timestamp,
            }).upsert().replace_one({
                _Keys.Host: host,
                _Keys.Port: port,
                _Keys.Timestamp: timestamp,
                _Keys.Content: config,
            })

        bulk.execute({'w': 3, 'wtimeout': 5 * 1000})  # ms
        self._timestamps.insert({
            _TimestampsKeys.Timestamp: timestamp,
            _TimestampsKeys.Version: self.version
        })

    def _load(self, timestamp):
        """Return ``{(host, port): content}`` for all documents at ``timestamp``."""
        configs = {}
        for rec in self._coll.find({_Keys.Timestamp: timestamp}):
            host, port = rec[_Keys.Host], rec[_Keys.Port]
            configs[host, port] = rec[_Keys.Content]
        return configs

    def _head_timestamp(self):
        """Return the largest recorded timestamp, or ``None`` if none exist."""
        for rec in self._timestamps.find({}, sort=[(_TimestampsKeys.Timestamp, -1)], limit=1):
            return rec[_TimestampsKeys.Timestamp]
        return None

    def make_cleaner(self):
        """Return a ``Cleaner`` over the same two collections."""
        return Cleaner(self._coll, self._timestamps)

    def __str__(self):
        return 'Mongo({})'.format(self._coll.name)


class Cleaner(Mongo):
    """``Mongo`` storage view used to prune old snapshots.

    Operates on the same two collections as the storage that created it
    (see ``Mongo.make_cleaner``) but logs under its own name.
    """

    def __init__(self, collection, timestamps_collection):
        super(Cleaner, self).__init__(collection, timestamps_collection)
        self._logger = logging.getLogger('StorageCleaner({})'.format(collection.name))

    def list_old_timestamps(self, time_to_skip, timestamps_to_skip):
        """Return candidate timestamps for removal, newest first.

        :param time_to_skip: ``datetime.timedelta``; timestamps at or after
            ``now - time_to_skip`` are never returned.
        :param timestamps_to_skip: how many of the newest remaining timestamps
            to keep (applied as the query's ``skip``).
        :returns: list of the stored timestamps (epoch seconds).
        """
        cutoff = _datetime_to_timestamp(datetime.datetime.now() - time_to_skip)
        cursor = self._timestamps.find(
            {_TimestampsKeys.Timestamp: {'$lt': cutoff}},
            sort=[(_TimestampsKeys.Timestamp, -1)],
            skip=timestamps_to_skip,
        )
        return [rec[_TimestampsKeys.Timestamp] for rec in cursor]

    def remove_timestamp(self, timestamp):
        """Delete the snapshot at ``timestamp`` together with its timestamp record.

        Refuses with ``AssertionError`` unless ``timestamp`` is an int,
        strictly older than the current head, and more than three days old.
        This is the same exception type as before. The checks are explicit
        ``raise`` statements rather than ``assert`` so they still guard this
        destructive operation when Python runs with ``-O``.

        :raises StorageError: if reading the head timestamp fails.
        """
        if not isinstance(timestamp, int):
            raise AssertionError('timestamp must be an int, got {!r}'.format(timestamp))
        head = self.head_timestamp
        if head is None or not timestamp < head:
            raise AssertionError(
                'refusing to remove timestamp {} (head is {})'.format(timestamp, head))
        days_ago = (int(time.time()) - timestamp) / 3600 / 24
        if not days_ago > 3:
            raise AssertionError(
                'refusing to remove timestamp {}: only {} days old'.format(timestamp, days_ago))
        # Config documents go first and the timestamp record last. If the
        # second call fails, the timestamp is still listed by
        # list_old_timestamps, so a later run can retry the removal.
        self._coll.remove({_Keys.Timestamp: timestamp})
        self._timestamps.remove({_TimestampsKeys.Timestamp: timestamp})
        self._logger.info('removed timestamp %i (%i days old)', timestamp, days_ago)


class MongoOverride(object):
    """MongoDB storage for per-(host, port) config overrides.

    Keeps no history. ``update`` upserts the document for each
    ``(host, port)`` in place, stamped with the current epoch second.
    Public methods translate ``pymongo.errors.PyMongoError`` into
    ``StorageError``.
    """

    def __init__(self, collection):
        self._coll = collection
        # Named logger, like Mongo/Cleaner, so failures are attributable.
        self._logger = logging.getLogger('MongoOverride({})'.format(collection.name))

    def update(self, configs):
        """Upsert ``configs`` (``{(host, port): config}``, int ports).

        An empty mapping is a no-op.

        :raises StorageError: on any MongoDB failure.
        """
        try:
            self._update(configs)
        except pymongo.errors.PyMongoError:
            self._logger.exception('Unable to update override configs')
            raise StorageError()

    def load(self):
        """Return all overrides as ``{(host, port): config_format.Config}``.

        :raises StorageError: on any MongoDB failure.
        """
        try:
            return self._load()
        except pymongo.errors.PyMongoError:
            self._logger.exception('Unable to load override configs')
            raise StorageError()

    def remove(self, hosts_ports):
        """Delete the overrides for every ``(host, port)`` in ``hosts_ports``.

        :raises StorageError: on any MongoDB failure.
        """
        try:
            return self._remove(hosts_ports)
        except pymongo.errors.PyMongoError:
            self._logger.exception('Unable to remove override configs')
            raise StorageError()

    def _update(self, configs):
        if not configs:
            # pymongo refuses to execute an empty bulk operation
            # (InvalidOperation). Without this guard a harmless no-op update
            # surfaced as StorageError.
            return
        bulk = self._coll.initialize_unordered_bulk_op()
        timestamp = _datetime_to_timestamp(datetime.datetime.now())

        for (host, port), config in configs.items():
            assert isinstance(port, int)
            bulk.find({
                _Keys.Host: host,
                _Keys.Port: port,
            }).upsert().replace_one({
                _Keys.Host: host,
                _Keys.Port: port,
                _Keys.Content: config,
                _Keys.Timestamp: timestamp,
            })
        bulk.execute({'w': 3, 'wtimeout': 5 * 1000})  # ms

    def _load(self):
        configs = {}
        for rec in self._coll.find():
            try:
                host, port = rec[_Keys.Host], rec[_Keys.Port]
                configs[host, port] = config_format.Config(rec[_Keys.Content], rec[_Keys.Timestamp])
            except KeyError:
                # Best effort: a record missing a field is skipped rather than
                # failing the whole load. It is now logged instead of dropped
                # silently.
                self._logger.warning('skipping malformed override record %r', rec.get('_id'))
        return configs

    def _remove(self, hosts_ports):
        for (host, port) in hosts_ports:
            self._coll.remove({
                _Keys.Host: host,
                _Keys.Port: port,
            })


def make_storage(collection):
    """Create a ``Mongo`` storage on ``collection`` and ``collection.timestamps``.

    In pymongo, attribute access on a collection yields the sub-collection
    ``<name>.timestamps``. Before returning, this ensures the unique indexes
    the storage relies on:
    - one config document per (timestamp, host, port);
    - one record per timestamp in the sub-collection.
    """
    config_index = [
        (_Keys.Timestamp, -1),
        (_Keys.Host, 1),
        (_Keys.Port, 1),
    ]
    timestamp_index = [(_TimestampsKeys.Timestamp, -1)]

    timestamps = collection.timestamps
    collection.ensure_index(config_index, unique=True)
    timestamps.ensure_index(timestamp_index, unique=True)
    return Mongo(collection, timestamps)


def make_override_storage(collection):
    """Create a ``MongoOverride`` on ``collection``.

    Ensures a unique (host, port) index first, so each endpoint has at most
    one override document.
    """
    host_port_index = [(_Keys.Host, 1), (_Keys.Port, 1)]
    collection.ensure_index(host_port_index, unique=True)
    return MongoOverride(collection)


def _datetime_to_timestamp(dt):
    """Convert ``dt`` to whole seconds since the Unix epoch (fraction truncated).

    Naive datetimes, such as ``datetime.datetime.now()``, are interpreted as
    local time. Aware datetimes are converted via their UTC offset.

    This replaces ``dt.strftime('%s')``. ``%s`` is a glibc extension, not a
    documented Python format code, so it fails on other platforms and
    silently ignores ``tzinfo`` on aware datetimes.
    """
    if dt.tzinfo is not None and dt.utcoffset() is not None:
        import calendar
        return calendar.timegm(dt.utctimetuple())
    # mktime interprets a naive time tuple as local time, matching the old
    # '%s' behaviour for naive datetimes.
    return int(time.mktime(dt.timetuple()))


class _Keys(object):
    """Field names used in config documents.

    Deliberately terse, presumably to keep per-document key overhead small.
    The values are persisted in stored documents, so never change them.
    """
    # The config payload itself.
    Content = 'c'
    # (host, port) identify which endpoint a document belongs to.
    Host = 'h'
    Port = 'p'
    # Epoch seconds of the store/update that wrote the document.
    Timestamp = 'ts'


class _TimestampsKeys(object):
    """Field names used in the timestamps collection. These values are persisted."""
    # Storage format version that wrote the snapshot.
    Version = 'version'
    # Epoch seconds identifying the snapshot.
    Timestamp = 'timestamp'
