import cPickle
import logging
import datetime as dt

from sandbox.yasandbox.database import mapping
from sandbox.yasandbox.controller import statistics as statistics_controller


class State(object):
    """
    Controller for the ``mapping.State`` collection: one document per server,
    holding the shards reported for it by ``Statistics.db_shards_status()``.
    """

    Model = mapping.State
    SERVER_TRESHOLD = 1  # in days; servers/shards not updated within this window are dropped
    # Class-level default so the classmethods below work even if ``initialize`` was not called.
    logger = logging.getLogger(__name__)

    @classmethod
    def initialize(cls):
        """ Ensure collection indexes exist and (re)bind the class logger. """
        cls.Model.ensure_indexes()
        cls.logger = logging.getLogger(__name__)

    @classmethod
    def get(cls, id_):
        """
        Return state model

        :param id_: server name
        :return: mapping.State object with id_ or None if there is no server with id_
        """
        return cls.Model.objects.with_id(id_)

    @classmethod
    def _new_shard(cls, name, info, now):
        """ Build a Shard entry named ``name`` with ``info`` pickled into ``information``. """
        return cls.Model.Shard(name=name, updated=now, information=cPickle.dumps(info))

    @classmethod
    def update_model(cls, model, data):
        """
        Update model from data in place (the model is NOT saved here).

        Shards already present in ``model.shards`` get ``update_info`` called and their
        ``updated`` timestamp refreshed; unknown shards are added as new Shard entries.

        :param model: state model
        :param data: dict with data about state; shard info is read from ``data["shards"]``
        """
        now = dt.datetime.utcnow()
        model.updated = now
        for shard_name, shard in data.get("shards", {}).iteritems():
            if shard_name in model.shards:
                existing = model.shards[shard_name]
                existing.update_info(shard)
                existing.updated = now
            else:
                model.add_shard(cls._new_shard(shard_name, shard, now))

    @classmethod
    def create_model(cls, name, data):
        """
        Create new model from data (the model is NOT saved here).

        :param name: server name
        :param data: dict with data about state; shard info is read from ``data["shards"]``
        :return: mapping.State object
        """
        now = dt.datetime.utcnow()
        server = cls.Model(name=name, updated=now, shards_info=[])
        for shard_name, shard in data.get("shards", {}).iteritems():
            server.add_shard(cls._new_shard(shard_name, shard, now))
        return server

    @classmethod
    def remove_unused_servers(cls):
        """
        Remove servers and shards with old updated fields

        Servers whose ``updated`` is older than ``SERVER_TRESHOLD`` days are deleted from
        the database. For the remaining servers, stale entries are dropped from
        ``shards_info`` in memory only -- the caller must ``save()`` them to persist it.

        :return: dict mapping server name to each surviving mapping.State object
        """
        cutoff = dt.datetime.utcnow() - dt.timedelta(days=cls.SERVER_TRESHOLD)
        available_servers = {}

        # Snapshot the query result up front, since documents are deleted while looping.
        for server in list(cls.Model.objects()):
            if server.updated < cutoff:
                cls.logger.info("Server %s is out of date. Remove it from state.", server.name)
                server.delete()
                continue

            available_servers[server.name] = server
            fresh_shards = []
            for shard in server.shards_info:
                if shard.updated >= cutoff:
                    fresh_shards.append(shard)
                else:
                    cls.logger.info("Shard %s is out of date. Remove it from state.", shard.name)
            if len(fresh_shards) != len(server.shards_info):
                server.shards_info = fresh_shards
        return available_servers

    @classmethod
    def group_mongo_map(cls, data):
        """
        Group replica set members by server.

        :param data: iterable of dicts with "replicaset" and "members" keys; each member is
            a dict with a "name" key whose part before the first "." is the server name
        :return: dict ``{server_name: {"shards": {member_name: member_info}}}`` where
            member_info is a copy of the member dict with "replicaset" added
        """
        mongo_map = {}
        for shard in data:
            for member in shard["members"]:
                # Copy so the caller's member dicts are not mutated.
                info = dict(member)
                info["replicaset"] = shard["replicaset"]
                server_name = info["name"].split(".", 1)[0]
                mongo_map.setdefault(server_name, {}).setdefault("shards", {})[info["name"]] = info
        return mongo_map

    @classmethod
    def update_state_infromation(cls):
        """
        Actualize all objects in state collection

        Fetches current shard status, drops out-of-date servers/shards, updates or
        creates a model for every reported server and saves all surviving models.
        A failure while fetching shard status is logged and aborts the update.
        """
        try:
            data = statistics_controller.Statistics.db_shards_status()
        except Exception:
            cls.logger.exception("Can't get information about shards.")
            return

        mongo_map = cls.group_mongo_map(data)

        servers = cls.remove_unused_servers()

        for server_name, server_info in mongo_map.iteritems():
            if server_name in servers:
                cls.update_model(servers[server_name], server_info)
            else:
                servers[server_name] = cls.create_model(server_name, server_info)

        # Saving here also persists shard trimming done by remove_unused_servers.
        for server in servers.itervalues():
            server.save()

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    update_state_information = update_state_infromation
