import datetime
import logging
import os
import posixpath

import infra.callisto.controllers.slots as slot_controller
import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.sdk.notify as notify
import infra.callisto.controllers.utils.gencfg_api as gencfg_api


class SourceController(sdk.Controller):
    """Drives a group of slot controllers that share one deploy/search target.

    The target stored by set_target_state() is pushed to every slot controller
    on execute(); get_observed_state() reports what all slots have in common.
    """

    @property
    def path(self):
        """Name of this controller (the ``name`` given to the constructor)."""
        return self._name

    def __init__(self, name, slots_, deployer):
        """
        Args:
            name: controller name; also returned by ``path``.
            slots_: mapping of slot name -> slot controller.
            deployer: deploy controller, registered as a child alongside the
                slot controllers.
        """
        super(SourceController, self).__init__()
        self._deployer_target_states = set()
        self._searcher_target_state = slot_controller.NullSlotState
        self._deploy_ctrl = deployer
        self._slots = slots_
        self._name = name

        self.register(deployer, *self._slots.values())
        self.add_handler('/deploy_progress', self.deploy_progress)
        self.add_handler('/deploy_percentage', self.deploy_percentage)
        self.add_handler('/searchers_state', self.searchers_state)

    def notifications(self):
        """Return one IDLE-level text notification with the common observed search state."""
        return [notify.TextNotification(
            'observed-state [{}]: {}'.format(self._name, self.get_observed_state()[1]),
            notify.NotifyLevels.IDLE,
        )]

    def execute(self):
        """Push the current target to every slot controller.

        Each slot gets its own copy of the deploy target set, presumably so that
        slots never share (and mutate) the same set object.
        """
        for slot_ctrl in self._slots.values():
            slot_ctrl.set_target_state(self._deployer_target_states.copy(), self._searcher_target_state)

    def set_target_state(self, deployer_target_states, searcher_target_state):
        """Store the target; it is applied to the slots on the next execute()."""
        self._deployer_target_states = deployer_target_states
        self._searcher_target_state = searcher_target_state

    def get_observed_state(self):
        """Return the state observed on all slots.

        Returns:
            ``(deploy, search)`` where ``deploy`` is the intersection of every
            slot's observed deploy states and ``search`` is the common search
            state if all slots agree, otherwise ``NullSlotState``. With no
            slots, returns ``(set(), NullSlotState)``.
        """
        deploy, search = [], set()
        for slot_ctrl in self._slots.values():
            dep, sear = slot_ctrl.get_observed_state()
            deploy.append(dep)
            search.add(sear)
        # set.intersection() with no arguments raises TypeError, so an empty
        # slot mapping must be handled explicitly.
        common_deploy = set.intersection(*deploy) if deploy else set()
        common_search = search.pop() if len(search) == 1 else slot_controller.NullSlotState
        return common_deploy, common_search

    def deploy_progress(self, tier_name=None):
        """Delegate to the deployer's deploy_progress()."""
        return self._deploy_ctrl.deploy_progress(tier_name)

    def deploy_percentage(self, exactly_timestamp=None):
        """Aggregate per-slot deploy progress into per-tier completion ratios.

        Args:
            exactly_timestamp: forwarded to each slot controller's
                deploy_progress().

        Returns:
            ``{tier_name: [{'timestamp': ..., 'shards': float, 'replicas': float}, ...]}``
            with one entry per timestamp. ``shards`` is the ratio of distinct
            done entries to distinct total entries, ``replicas`` the ratio of
            the raw counts. A timestamp with an empty ``total`` reports 0.0
            instead of raising ZeroDivisionError.
        """
        tiers_states = {}
        for slot_ctrl in self._slots.values():
            for timestamp, state in slot_ctrl.deploy_progress(exactly_timestamp).iteritems():
                tier_name = slot_ctrl.slot.tier.name
                tier_state = tiers_states.setdefault(tier_name, {}).setdefault(
                    timestamp, {'done': [], 'total': [], 'timestamp': timestamp},
                )
                tier_state['done'] += state['done']
                tier_state['total'] += state['total']

        return {
            tier: [
                {
                    'timestamp': state['timestamp'],
                    'shards': self._ratio(len(set(state['done'])), len(set(state['total']))),
                    'replicas': self._ratio(len(state['done']), len(state['total'])),
                } for state in states.values()
            ]
            for tier, states in tiers_states.iteritems()
        }

    @staticmethod
    def _ratio(part, whole):
        """Return ``part / whole`` as a float, or 0.0 when ``whole`` is zero."""
        return float(part) / whole if whole else 0.0

    def searchers_state(self, slot_name_to_find=None):
        """Return ``{slot_name: searchers_state()}``.

        Covers all slots, or only ``slot_name_to_find`` when it is given.
        """
        result = {}
        for slot_name, slot_ctrl in self._slots.items():
            if slot_name_to_find in (slot_name, None):
                result[slot_name] = slot_ctrl.searchers_state()
        return result

    def deploy_namespace_percentage(self):
        """Delegate to the deployer's deploy_namespace_percentage()."""
        return self._deploy_ctrl.deploy_namespace_percentage()

    def json_view(self):
        """Return a dict with every slot's json_view() plus a ``'common'`` entry.

        ``'common'`` holds target vs. observed state for deployer and searcher.
        NOTE: a slot literally named ``'common'`` would be overwritten here.
        """
        result = {}
        for slot_id, slot_ctrl in self._slots.items():
            result[slot_id] = slot_ctrl.json_view()
        common_deploy, common_search = self.get_observed_state()
        result['common'] = {
            'deployer': {
                'target': [x.json() for x in self._deployer_target_states],
                'observed': [x.json() for x in common_deploy],
            },
            'searcher': {
                'target': self._searcher_target_state.json(),
                'observed': common_search.json(),
            }
        }
        return result

    def __str__(self):
        return '{}({})'.format(self.__class__.__name__, self._name)


class YtDrivenSourceController(SourceController):
    """SourceController whose target is read from a YT target table.

    update() applies the target found in the target table's head() row;
    save_status() optionally records the observed state in a status table.
    """

    def __init__(self, name, slots_, deployer, target_table, status_table=None):
        super(YtDrivenSourceController, self).__init__(name, slots_, deployer)
        self._target_table, self._status_table = target_table, status_table
        # Time of the last applied target row; starts at the earliest datetime.
        self._target_modification_time = datetime.datetime.min

    def update(self, reports):
        """Apply the target from the target table's head() row.

        ``reports`` is not used by this implementation.
        NOTE(review): assumes head() always returns a row here; an empty
        target table would raise AttributeError — confirm that is intended.
        """
        latest = self._target_table.head()
        wanted = latest.target
        deploy_states = set(slot_controller.SlotState(ts) for ts in wanted.deploy)
        self.set_target_state(deploy_states, slot_controller.SlotState(wanted.search))
        self._target_modification_time = latest.time

    def save_status(self):
        """Write the observed state to the status table when it has changed.

        Does nothing when no status table was configured.
        """
        if not self._status_table:
            return
        observed_deploy, observed_search = self.get_observed_state()
        deployed_timestamps = set(state.timestamp for state in observed_deploy)
        if observed_search:
            searched_timestamp = observed_search.timestamp
        else:
            searched_timestamp = None
        current = Status(deployed_timestamps, searched_timestamp)
        last_written = self._status_table.head()
        if not last_written or last_written.status != current:
            self._status_table.write(current)

    def html_view(self):
        """Render a header, a link to the target table and all non-empty child views."""
        header = sdk.blocks.Header('YtDrivenSourceController')
        links = sdk.blocks.HrefList([sdk.blocks.Href('target-table', self._target_table.gui_url)])
        child_views = [child.html_view() for child in self.children]
        return sdk.blocks.wrap(header, links, *[view for view in child_views if view])


class Target(sdk.table.Target):
    """Row of the target (control) table: timestamps to deploy and to search.

    Both columns are strings: ``Deploy`` holds whitespace-separated integer
    timestamps, ``Search`` a single integer timestamp or null.
    """

    schema = [
        {"name": "Deploy", "type": "string"},
        {"name": "Search", "type": "string"},
    ]

    def __init__(self, deploy, search):
        """
        Args:
            deploy: set/list/tuple of integer timestamps to deploy.
            search: integer timestamp to search on, or None.
        """
        assert isinstance(deploy, (set, list, tuple))
        assert all(isinstance(item, (int, long)) for item in deploy)
        assert search is None or isinstance(search, (int, long))

        self.deploy = set(deploy)
        self.search = search

    @classmethod
    def load_row(cls, row):
        """Build a Target from a table row.

        Any run of whitespace separates ``Deploy`` timestamps, so hand-edited
        rows with extra spaces still parse. A null or blank ``Search`` means
        no search target.

        Raises:
            AssertionError: if ``Deploy`` contains no timestamps.
            ValueError: if a timestamp is not an integer.
        """
        deploy_ts = {int(ts) for ts in (row['Deploy'] or '').split()}
        assert deploy_ts, 'Deploy column of a target row must not be empty'
        search_raw = (row['Search'] or '').strip()
        search_ts = int(search_raw) if search_raw else None
        return cls(deploy_ts, search_ts)

    def dump_row(self):
        """Serialize to a row: sorted space-joined ``Deploy``, ``Search`` as str or None."""
        deploy = ' '.join(str(ts) for ts in sorted(self.deploy))
        search = str(self.search) if self.search is not None else None
        return {'Deploy': deploy, 'Search': search}


class Status(sdk.table.Status):
    """Row of the status table: timestamps observed as deployed and searched.

    Unlike Target, an empty ``Deploy`` column is allowed (nothing deployed).
    """

    schema = [
        {"name": "Deploy", "type": "string"},
        {"name": "Search", "type": "string"},
    ]

    def __init__(self, deploy, search):
        """
        Args:
            deploy: iterable of deployed timestamps.
            search: searched timestamp, or None.
        """
        self.deploy = set(deploy)
        self.search = search

    @classmethod
    def load_row(cls, row):
        """Build a Status from a table row.

        A null, empty or whitespace-only ``Deploy`` yields an empty set; a null
        or blank ``Search`` yields None.

        Raises:
            ValueError: if a timestamp is not an integer.
        """
        deploy_ts = {int(ts) for ts in (row['Deploy'] or '').split()}
        search_raw = (row['Search'] or '').strip()
        search_ts = int(search_raw) if search_raw else None
        return cls(deploy_ts, search_ts)

    def dump_row(self):
        """Serialize to a row: sorted space-joined ``Deploy``, ``Search`` as str or None."""
        deploy = ' '.join(str(ts) for ts in sorted(self.deploy))
        search = str(self.search) if self.search is not None else None
        return {'Deploy': deploy, 'Search': search}


class _TargetTable(sdk.table.ControlTable):
    """Control table bound to the ``Target`` row class."""

    target_class = Target


class _StatusTable(sdk.table.StatusTable):
    """Status table bound to the ``Status`` row class."""

    status_class = Status


def get_yt_target_table_absolute_path(yt_client, path, readonly):
    """Open the target table at the absolute YT ``path``.

    When ``readonly`` is true, the table is additionally wrapped in
    ``sdk.table.Readonly``.
    """
    target_table = _TargetTable(yt_client, path, readonly)
    if not readonly:
        return target_table
    return sdk.table.Readonly(target_table)


def get_yt_status_table_absolute_path(yt_client, path, readonly):
    """Open the status table at the absolute YT ``path``.

    When ``readonly`` is true, the table is additionally wrapped in
    ``sdk.table.Readonly``.
    """
    status_table = _StatusTable(yt_client, path, readonly)
    if not readonly:
        return status_table
    return sdk.table.Readonly(status_table)


def get_yt_target_table(yt_client, path, readonly, root='//home/cajuper/user'):
    """Open a target table addressed relative to ``root``.

    Args:
        yt_client: client passed through to the table.
        path: table path relative to ``root``; leading slashes are ignored.
        readonly: wrap the table in sdk.table.Readonly when true.
        root: YT directory that ``path`` is resolved against.
    """
    # YT paths are '/'-separated on every host, so use posixpath rather than
    # the platform-dependent os.path.
    path = posixpath.join(root, path.lstrip('/'))
    return get_yt_target_table_absolute_path(yt_client, path, readonly)


def get_yt_status_table(yt_client, path, readonly, root='//home/cajuper/user'):
    """Open a status table addressed relative to ``root``.

    Args:
        yt_client: client passed through to the table.
        path: table path relative to ``root``; leading slashes are ignored.
        readonly: wrap the table in sdk.table.Readonly when true.
        root: YT directory that ``path`` is resolved against.
    """
    # YT paths are '/'-separated on every host, so use posixpath rather than
    # the platform-dependent os.path.
    path = posixpath.join(root, path.lstrip('/'))
    return get_yt_status_table_absolute_path(yt_client, path, readonly)


def make_slot_controllers(slots, deploy_ctrl, deploy_only=False, namespace_prefix=None):
    """Create one slot controller per slot.

    Args:
        slots: iterable of slots; each must expose ``name``, ``group``,
            ``topology`` and ``use_mtn``.
        deploy_ctrl: deploy controller passed to every slot controller.
        deploy_only: forwarded to slot_controller.make_slot_controller.
        namespace_prefix: forwarded to slot_controller.make_slot_controller.

    Returns:
        dict mapping ``slot.name`` to its slot controller.
    """
    # Materialize first: _prefetch_slots iterates its argument (via sorted()),
    # so a generator passed here would otherwise be exhausted before the loop.
    slots = list(slots)
    _prefetch_slots(slots)  # hack to fetch all gencfg groups in parallel

    slot_ctrls = {}
    for slot in slots:
        slot_ctrls[slot.name] = slot_controller.make_slot_controller(
            slot,
            deploy_ctrl,
            gencfg_api.agent_shard_number_mapping(slot.group, slot.topology, slot.use_mtn),
            deploy_only=deploy_only,
            namespace_prefix=namespace_prefix,
        )
    return slot_ctrls


def _prefetch_slots(slots):
    """Issue gencfg_api.agent_shard_number_mapping for all slots concurrently.

    Results are discarded; the point is presumably to warm a cache inside
    gencfg_api so later sequential calls are fast — TODO confirm it caches.
    """
    import gevent.pool
    slots = sorted(slots)
    # Pass use_mtn as well so each prefetch uses exactly the arguments that
    # make_slot_controllers later uses; otherwise MTN slots would prefetch a
    # different mapping (and an argument-keyed cache would miss).
    for _ in gevent.pool.Pool().imap_unordered(
        gencfg_api.agent_shard_number_mapping,
        [slot.group for slot in slots],
        [slot.topology for slot in slots],
        [slot.use_mtn for slot in slots],
    ):
        pass


_log = logging.getLogger('search_source')
