import infra.callisto.controllers.multibeta.observed as observed
import infra.callisto.controllers.multibeta.slots as slots
import infra.callisto.controllers.sdk.tier as tiers
import infra.callisto.controllers.utils.gencfg_api as gencfg_api
import infra.callisto.controllers.utils.yp_utils as yp_utils

import logging


class WebBaseProvidedCtrl(slots.BaseProvidedCtrl):
    """Provided-slot controller for web base deploy units.

    Identical to ``slots.BaseProvidedCtrl`` except that the observer is replaced
    with an ``observed.BasesearchYpObserved`` configured with fixed ratios.
    """

    def __init__(self, instance_provider, replication, slots_ids, porto_properties=None):
        super(WebBaseProvidedCtrl, self).__init__(
            instance_provider, replication, slots_ids, porto_properties
        )
        tier = instance_provider.tier
        # Only WebTier1 gets a non-zero allow_lack_shards_ratio (0.005);
        # every other tier passes 0.
        if tier != tiers.WebTier1:
            lack_shards_ratio = 0
        else:
            lack_shards_ratio = 0.005
        self._observed = observed.BasesearchYpObserved(
            tier,
            replication=replication,
            required_total_ratio=0.9,
            required_replicas_ratio=0.6,
            allow_lack_shards_ratio=lack_shards_ratio,
        )


class IntCtrl(slots.ComponentCtrl):
    """Multibeta controller for `int` instances with an explicit slot -> agents mapping.

    The agents for each slot are supplied by the caller (``make_int_ctrls`` builds
    them with ``gencfg_api.get_agents``).
    """

    def __init__(self, slots_agents):
        """
        :param slots_agents: mapping slot_id -> iterable of agents serving that slot.
        """
        super(IntCtrl, self).__init__()
        self._slots_agents = slots_agents
        # Flattened set of every agent across all slots; used to filter reports in update().
        self._agents = {agent for agents in slots_agents.values() for agent in agents}
        # NOTE(review): 0.90 is passed straight to IntObserved; presumably a required
        # readiness ratio -- confirm against observed.IntObserved.
        self._observed = observed.IntObserved(slots_agents, 0.90)
        self._tags = {'a_ctype_hamster', 'a_prj_web-main'}

    @property
    def tags(self):
        """Fixed set of report tags for this controller."""
        return self._tags

    @property
    def slots_ids(self):
        """Slot ids this controller manages (the keys of ``slots_agents``)."""
        return self._slots_agents.keys()

    def update(self, reports):
        """Forward the ``.data`` of reports that come from this controller's agents to the observer."""
        reports = {agent: report.data for agent, report in reports.items() if agent in self._agents}
        self._observed.update(reports)

    def gencfg(self):
        """Build per-agent configs for every targeted slot that this controller manages.

        :returns: dict agent -> {'slots': {slot_id: {'instances': ..., 'default': ...}}}.
            An agent that serves several slots gets one entry per slot under 'slots'.
        """
        result = {}

        for slot_id, configurations in self._slots_targets.items():
            if slot_id not in self._slots_agents:
                continue
            default_revision = self._slots_default_revisions.get(slot_id)
            for agent in self._slots_agents[slot_id]:
                # Merge rather than overwrite: assigning a fresh {'slots': ...} here
                # would drop configs of slots processed earlier for the same agent.
                agent_slots = result.setdefault(agent, {'slots': {}})['slots']
                agent_slots[slot_id] = {
                    'instances': self._generate_slot_config(agent, configurations),
                    'default': default_revision,
                }

        return result

    def _generate_configuration_config(self, agent, configuration):
        """Describe one slot configuration for ``agent``: revision, hash and resources.

        Presumably invoked by the base class's ``_generate_slot_config``; confirm in slots.
        """
        return {
            'timestamp': configuration.revision,  # TODO: legacy
            'revision': configuration.revision,
            'conf_hash': hash(configuration),
            'resources': [
                slots.resolve_resource('int.executable', configuration.int.executable),
                self._cfg_resource(agent, configuration.int)
            ],
        }

    @staticmethod
    def _cfg_resource(agent, conf_int):
        """Resolve the 'int.cfg' resource and set ``extract_file`` from ``conf_int.config_path``.

        ``config_path`` is formatted with the agent's ``short_host`` and ``port``.
        """
        resource = slots.resolve_resource('int.cfg', conf_int.config).copy()
        resource['extract_file'] = conf_int.config_path.format(short_host=agent.short_host, port=agent.port)
        return resource


class IntRankingCtrl(slots.ComponentCtrl):
    """Multibeta controller for `int` ranking instances taken from an instance provider.

    Pods are split into ``pod_count // replication`` partitions. An instance's
    ``shard_number`` in its config path is ``pod_index % partition_count``.
    """

    # Porto container defaults; per-controller overrides are merged on top in __init__.
    _default_porto_properties = {
        'cpu_policy': 'idle',
        'respawn_delay': '15s',
    }

    def __init__(self, instance_provider, replication, slots_ids, porto_properties=None):
        """
        :param instance_provider: source of agents/instances (``make_deploy_intranking_ctrls``
            passes a ``yp_utils.InstanceProvider``).
        :param replication: replicas per partition; must evenly divide the pod count.
        :param slots_ids: multibeta slot ids served by this controller.
        :param porto_properties: optional overrides for ``_default_porto_properties``.
        :raises ValueError: if the pod count is not divisible by ``replication``.
        """
        super(IntRankingCtrl, self).__init__()
        self._instance_provider = instance_provider
        self._slots_ids = slots_ids

        total_pods = self._instance_provider.count_pods()
        if total_pods % replication:
            raise ValueError('Pod count ({}) should be divisible by replication ({})'.format(total_pods, replication))
        self._partition_count = total_pods // replication

        # The observer gets a snapshot of the agents taken at construction time;
        # every slot id maps to the same list object.
        snapshot = list(instance_provider.agents)
        self._observed = observed.IntObserved(dict.fromkeys(slots_ids, snapshot), 0.90)

        merged_properties = self._default_porto_properties.copy()
        merged_properties.update(porto_properties or {})
        self._porto_properties = merged_properties

    @property
    def tags(self):
        """Report tags, delegated to the instance provider."""
        return self._instance_provider.tags

    @property
    def slots_ids(self):
        """Multibeta slot ids this controller manages."""
        return self._slots_ids

    def update(self, reports):
        """Pass the ``.data`` of reports from the provider's current agents to the observer."""
        current_agents = frozenset(self._instance_provider.agents)
        own_reports = {}
        for agent, report in reports.iteritems():
            if agent in current_agents:
                own_reports[agent] = report.data
        self._observed.update(own_reports)

    def gencfg(self):
        """Build one config per provider agent that covers every targeted slot managed here.

        :returns: dict agent -> {'slots': {slot_id: {'instances': ..., 'default': ...}}}.
        """
        managed_ids = self.slots_ids
        result = {}
        for agent, instance in self._instance_provider.agents_instances.iteritems():
            per_slot = {}
            for slot_id, configurations in self._slots_targets.items():
                if slot_id not in managed_ids:
                    continue
                per_slot[slot_id] = {
                    'instances': self._generate_slot_config(instance, configurations),
                    'default': self._slots_default_revisions.get(slot_id),
                }
            result[agent] = {'slots': per_slot}
        return result

    def _generate_configuration_config(self, instance, configuration):
        """Describe one slot configuration for ``instance``: revision, hash, container, resources."""
        int_conf = configuration.int
        description = {
            'revision': configuration.revision,
            'conf_hash': hash(configuration),
            'container': self._porto_properties,
        }
        description['resources'] = [
            slots.resolve_resource('int.executable', int_conf.executable),
            slots.resolve_resource('int.models', int_conf.models),
            self._cfg_resource('int.cfg', instance, int_conf),
        ]
        return description

    def _cfg_resource(self, name, instance, config):
        """Resolve resource ``name``; if ``config.config_path`` is set, fill ``extract_file`` from it."""
        resource = slots.resolve_resource(name, config.config).copy()
        agent = instance.get_agent()
        if not config.config_path:
            return resource
        resource['extract_file'] = config.config_path.format(
            short_host=agent.short_host,
            port=instance.port,
            tier=None,
            # hasattr (not getattr-with-default) is kept on purpose: in Python 2 it
            # swallows any exception raised while reading the attribute.
            podset_id=instance.podset if hasattr(instance, 'podset') else None,
            shard_number=instance.pod_index % self._partition_count,
        )
        return resource

    def __str__(self):
        return '{}({})'.format(self.__class__.__name__, self._instance_provider.tier)


class KeyInvCtrl(WebBaseProvidedCtrl):
    """Provided-slot controller for keyinv instances (executable plus server/keyinv configs)."""

    def _generate_configuration_config(self, instance, configuration):
        """Describe one keyinv slot configuration for ``instance``."""
        keyinv = configuration.keyinv
        description = {
            'revision': configuration.revision,
            'conf_hash': hash(configuration),
            'container': self._porto_properties,
        }
        resources = [slots.resolve_resource('keyinv.executable', keyinv.executable)]
        for resource_name, cfg in (
            ('keyinv.configs.server', keyinv.configs.server),
            ('keyinv.configs.keyinv', keyinv.configs.keyinv),
        ):
            resources.append(self._cfg_resource(resource_name, instance, cfg))
        description['resources'] = resources
        return description

    def _cfg_resource(self, name, instance, config):  # TODO: make common
        """Resolve resource ``name``; if ``config.path`` is set, fill ``extract_file`` from it."""
        resource = slots.resolve_resource(name, config.resource).copy()
        agent = instance.get_agent()
        if not config.path:
            return resource
        tier = self._instance_provider.tier
        resource['extract_file'] = config.path.format(
            short_host=agent.short_host,
            port=agent.port,
            tier=tier.name,
            podset_id=instance.podset if hasattr(instance, 'podset') else None,
            # TODO: fix! NOTE(review): uses tier.shards_count, whereas IntRankingCtrl
            # derives it from pod_count // replication -- confirm which is intended.
            shard_number=instance.pod_index % tier.shards_count,
        )
        return resource


class InvIndexCtrl(WebBaseProvidedCtrl):
    """Provided-slot controller for invindex instances (executable plus server/invindex configs)."""

    def _generate_configuration_config(self, instance, configuration):
        """Describe one invindex slot configuration for ``instance``."""
        invindex = configuration.invindex
        description = {
            'revision': configuration.revision,
            'conf_hash': hash(configuration),
            'container': self._porto_properties,
        }
        executable = slots.resolve_resource('invindex.executable', invindex.executable)
        server_cfg = self._cfg_resource('invindex.configs.server', instance, invindex.configs.server)
        invindex_cfg = self._cfg_resource('invindex.configs.invindex', instance, invindex.configs.invindex)
        description['resources'] = [executable, server_cfg, invindex_cfg]
        return description

    def _cfg_resource(self, name, instance, config):  # TODO: make common
        """Resolve resource ``name``; if ``config.path`` is set, fill ``extract_file`` from it."""
        resource = slots.resolve_resource(name, config.resource).copy()
        agent = instance.get_agent()
        if not config.path:
            return resource
        tier = self._instance_provider.tier
        resource['extract_file'] = config.path.format(
            short_host=agent.short_host,
            port=agent.port,
            tier=tier.name,
            podset_id=instance.podset if hasattr(instance, 'podset') else None,
            # TODO: fix! NOTE(review): uses tier.shards_count, whereas IntRankingCtrl
            # derives it from pod_count // replication -- confirm which is intended.
            shard_number=instance.pod_index % tier.shards_count,
        )
        return resource


class EmbeddingCtrl(WebBaseProvidedCtrl):
    """Provided-slot controller for embedding instances (executable, models, three configs)."""

    def _generate_configuration_config(self, instance, configuration):
        """Describe one embedding slot configuration for ``instance``."""
        embedding = configuration.embedding
        description = {
            'revision': configuration.revision,
            'conf_hash': hash(configuration),
            'container': self._porto_properties,
        }
        resources = [
            slots.resolve_resource('embedding.executable', embedding.executable),
            slots.resolve_resource('embedding.models', embedding.models),
        ]
        for resource_name, cfg in (
            ('embedding.configs.server', embedding.configs.server),
            ('embedding.configs.storage', embedding.configs.storage),
            ('embedding.configs.replicas', embedding.configs.replicas),
        ):
            resources.append(self._cfg_resource(resource_name, instance, cfg))
        description['resources'] = resources
        return description

    def _cfg_resource(self, name, instance, config):  # TODO: make common
        """Resolve resource ``name``; if ``config.path`` is set, fill ``extract_file`` from it."""
        resource = slots.resolve_resource(name, config.resource).copy()
        agent = instance.get_agent()
        if not config.path:
            return resource
        tier = self._instance_provider.tier
        resource['extract_file'] = config.path.format(
            short_host=agent.short_host,
            port=agent.port,
            tier=tier.name,
            podset_id=instance.podset if hasattr(instance, 'podset') else None,
            # TODO: fix! NOTE(review): uses tier.shards_count, whereas IntRankingCtrl
            # derives it from pod_count // replication -- confirm which is intended.
            shard_number=instance.pod_index % tier.shards_count,
        )
        return resource


def make_deploy_base_ctrls(base_slots):
    """Create one ``WebBaseProvidedCtrl`` per entry of ``base_slots``.

    :param base_slots: iterable of ``(multibeta_slot_ids, tier, deploy_slot, replication)``.
    :returns: list of controllers in input order.
    """
    return [
        WebBaseProvidedCtrl(
            yp_utils.InstanceProvider(
                [deploy_slot],
                report_tags={'a_itype_base', deploy_slot.deploy_unit},
                tier=tier,
                cache_for=1800,
            ),
            replication,
            slot_ids,
        )
        for slot_ids, tier, deploy_slot, replication in base_slots
    ]


def make_deploy_embedding_ctrls(embedding_slots):
    """Create one ``EmbeddingCtrl`` per entry of ``embedding_slots``.

    :param embedding_slots: iterable of ``(multibeta_slot_ids, tier, deploy_slot, replication)``.
    :returns: list of controllers in input order.
    """
    result = []
    for slot_ids, tier, deploy_slot, replication in embedding_slots:
        provider = yp_utils.InstanceProvider(
            [deploy_slot],
            report_tags={'a_itype_embedding', deploy_slot.deploy_unit},
            tier=tier,
            cache_for=1800,
        )
        result.append(EmbeddingCtrl(provider, replication, slot_ids))
    return result


def make_deploy_invindex_ctrls(invindex_slots):
    """Create one ``InvIndexCtrl`` per entry of ``invindex_slots``.

    :param invindex_slots: iterable of ``(multibeta_slot_ids, tier, deploy_slot, replication)``.
    :returns: list of controllers in input order.
    """
    return [
        InvIndexCtrl(
            yp_utils.InstanceProvider(
                [deploy_slot],
                report_tags={'a_itype_invindex', deploy_slot.deploy_unit},
                tier=tier,
                cache_for=1800,
            ),
            replication,
            slot_ids,
        )
        for slot_ids, tier, deploy_slot, replication in invindex_slots
    ]


def make_deploy_keyinv_ctrls(keyinv_slots):
    """Create one ``KeyInvCtrl`` per entry of ``keyinv_slots``.

    :param keyinv_slots: iterable of ``(multibeta_slot_ids, tier, deploy_slot, replication)``.
    :returns: list of controllers in input order.
    """
    result = []
    for slot_ids, tier, deploy_slot, replication in keyinv_slots:
        provider = yp_utils.InstanceProvider(
            [deploy_slot],
            report_tags={'a_itype_keyinv', deploy_slot.deploy_unit},
            tier=tier,
            cache_for=1800,
        )
        result.append(KeyInvCtrl(provider, replication, slot_ids))
    return result


def make_deploy_intranking_ctrls(slots):
    """Create one ``IntRankingCtrl`` per entry of ``slots``.

    NOTE: the parameter name ``slots`` shadows the module-level ``slots`` import
    inside this function; it is kept unchanged for keyword-argument callers.

    :param slots: iterable of ``(multibeta_slot_ids, deploy_slot, replication)``.
    :returns: list of controllers in input order.
    """
    result = []
    for slot_ids, deploy_slot, replication in slots:
        provider = yp_utils.InstanceProvider(
            [deploy_slot],
            report_tags={'a_itype_int', deploy_slot.deploy_unit},
            cache_for=1800,
        )
        result.append(IntRankingCtrl(provider, replication, slot_ids))
    return result


def make_int_ctrls(int_slots):
    """Wrap all ``int`` slots into a single ``IntCtrl``.

    :param int_slots: mapping slot_id -> gencfg groups, resolved via ``gencfg_api.get_agents``.
    :returns: a one-element list holding the controller.
    """
    agents_by_slot = {}
    for slot_id, groups in int_slots.items():
        agents_by_slot[slot_id] = gencfg_api.get_agents(groups)
    return [IntCtrl(agents_by_slot)]


def make_mmeta_ctrls(mmeta_slots):
    """Wrap all mmeta slots into a single ``slots.MmetaCompCtrl``.

    :param mmeta_slots: mapping slot_id -> a single gencfg group.
    :returns: a one-element list holding the controller.
    """
    agents_by_slot = {}
    for slot_id, group in mmeta_slots.iteritems():
        # get_agents expects a list of groups; each mmeta slot has exactly one.
        agents_by_slot[slot_id] = gencfg_api.get_agents([group])
    ctrl = slots.MmetaCompCtrl(agents_by_slot)
    return [ctrl]


# Module-level logger named after this module.
_log = logging.getLogger(__name__)
