import logging

import infra.callisto.libraries.yt as yt_utils

import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.slots.state as slot_state

import infra.callisto.protos.deploy.tables_pb2 as tables_pb2


def _get_yt_client(use_rpc=True):
    """Create a YT client for the 'arnold' cluster.

    :param use_rpc: forwarded unchanged to ``yt_utils.create_yt_client``.
    :return: whatever ``yt_utils.create_yt_client`` returns.
    """
    cluster = 'arnold'
    client = yt_utils.create_yt_client(cluster, use_rpc=use_rpc)
    return client


def location_target(loc):
    """Return the YT path of the ``chunks_target`` table for location *loc*."""
    ctl_root = '//home/cajuper/user/web/prod/chunks/ctl'
    return '%s/%s/%s' % (ctl_root, loc, 'chunks_target')


def location_status(loc):
    """Return the YT path of the ``chunks_status`` table for location *loc*."""
    ctl_root = '//home/cajuper/user/web/prod/chunks/ctl'
    return '%s/%s/%s' % (ctl_root, loc, 'chunks_status')


class Controller(sdk.Controller):
    """Keeps a YT target table in sync with the desired set of target timestamps.

    For every stream (shard name) of ``tier`` and every target timestamp the
    controller ensures that ``target_table`` has a row
    ``{'Namespace': stream, 'StateId': str(timestamp)}`` whose ``State`` is an
    ACTIVE ``TGenerationState``. It deletes rows of the tier's streams whose
    ``StateId`` is no longer a target. Observed state is derived from
    ``status_table``.

    :param status_table: YT path of the table read to compute observed state.
    :param target_table: YT path of the table this controller writes to.
    :param tier: object providing ``name`` and ``list_shards(timestamp)``.
    :param yt_client: YT client; created via ``_get_yt_client()`` when omitted.
    :param optional: if true, every target is reported as observed without
        consulting ``status_table``.
    :param readonly: if true, intended inserts/deletes are only logged.
    :param threshold: fraction of checked streams that must be ACTIVE in
        ``status_table`` for a target to count as observed.
    :param shard_number_limit: if set, only shards whose ``shard_number`` is
        below this value are checked when computing observed state.
    """

    def __init__(self, status_table, target_table, tier, yt_client=None, optional=True, readonly=True, threshold=1.0, shard_number_limit=None):
        self._name = '{}ExtChunks'.format(tier.name)
        self.optional = optional
        self.status_table = status_table
        self.target_table = target_table
        self.readonly = readonly
        self.tier = tier
        self.threshold = threshold
        # Stays None until set_targets()/set_target_state() is first called.
        self._targets = None

        self.shard_number_limit = shard_number_limit

        if yt_client is None:
            yt_client = _get_yt_client()
        self.yt_client = yt_client

        self._searcher_target = slot_state.NullSlotState
        self._logger = logging.getLogger('{} chunks-controller'.format(self._name))

        super(Controller, self).__init__()

    def __str__(self):
        return 'TargetController({})'.format(self._name)

    @property
    def path(self):
        return self._name

    def set_targets(self, targets):
        """Replace the set of target timestamps, logging when it changes."""
        if self._targets != targets:
            self._logger.info('%s: %s -> %s', self, self._targets, targets)
            self._targets = targets

    def _observed_timestamps(self):
        """Return the subset of target timestamps considered deployed.

        Returns an empty set while no targets have been set yet.
        """
        if self._targets is None:
            return set()
        if self.optional:
            return set(self._targets)
        return {target for target in self._targets if self._is_target_observed(target)}

    def _is_target_observed(self, target):
        """Check whether enough of the target's streams are ACTIVE in ``status_table``."""
        streams = [
            shard.name for shard in self.tier.list_shards(target)
            if self.shard_number_limit is None or shard.shard_number < self.shard_number_limit
        ]
        keys = [{'Namespace': stream, 'StateId': str(target)} for stream in streams]
        rows = self.yt_client.lookup_rows(self.status_table, keys)

        active = 0
        for row in rows:
            state = tables_pb2.TGenerationState()
            state.ParseFromString(row['State'])
            if state.State == tables_pb2.ACTIVE_GENERATION_STATE:
                active += 1

        # Streams without a status row simply do not count as active.
        # NOTE(review): with no streams at all this is 0 <= 0, i.e. "observed" -- confirm intended.
        return len(streams) * self.threshold <= active

    def _delete_target(self, stream, snapshot):
        """Delete the target row for (stream, snapshot) unless readonly."""
        self._logger.info('going to delete %s %s from target', stream, snapshot)
        if not self.readonly:
            self.yt_client.delete_rows(self.target_table, [{'Namespace': stream, 'StateId': snapshot}])

    def _add_target(self, stream, snapshot):
        """Insert an ACTIVE target row for (stream, snapshot) unless readonly."""
        state = tables_pb2.TGenerationState()
        state.State = tables_pb2.ACTIVE_GENERATION_STATE
        state_str = state.SerializeToString()
        self._logger.info('going to deploy %s %s', stream, snapshot)
        if not self.readonly:
            self.yt_client.insert_rows(self.target_table, [{'Namespace': stream, 'StateId': snapshot, 'State': state_str}])

    def execute(self):
        """Synchronise ``target_table`` with the current targets in one tablet transaction.

        Rows of the tier's streams whose ``StateId`` is not a current target are
        deleted. Every (stream, target) pair without an ACTIVE row is
        (re)inserted. Rows of streams outside the tier are left untouched.
        Does nothing until targets have been set.
        """
        if self._targets is None:
            # Treating "not set yet" as "no targets" would delete every managed row; skip instead.
            self._logger.info('%s: targets are not set yet, skipping', self)
            return

        # NOTE(review): streams come from list_shards(0), whereas observation uses
        # list_shards(target) filtered by shard_number_limit -- confirm the shard set
        # does not depend on the timestamp.
        streams = {shard.name for shard in self.tier.list_shards(0)}
        snapshots = {str(target) for target in self._targets}

        with self.yt_client.Transaction(type='tablet'):
            to_insert = {(stream, snapshot) for stream in streams for snapshot in snapshots}
            rows = self.yt_client.select_rows('* from [%s]' % (self.target_table,))
            for row in rows:
                stream = row['Namespace']
                snapshot = row['StateId']

                if stream not in streams:
                    continue
                if snapshot not in snapshots:
                    self._delete_target(stream, snapshot)
                    continue

                state = tables_pb2.TGenerationState()
                state.ParseFromString(row['State'])
                if state.State == tables_pb2.ACTIVE_GENERATION_STATE:
                    # discard, not remove: a duplicate row must not raise KeyError.
                    to_insert.discard((stream, snapshot))

            for stream, snapshot in to_insert:
                self._add_target(stream, snapshot)

    @property
    def targets(self):
        return self._targets

    def set_target_state(self, deployer_target_states, searcher_target_state):
        """Set targets from deployer states' timestamps and remember the searcher target."""
        self.set_targets({state.timestamp for state in deployer_target_states})
        self._searcher_target = searcher_target_state

    def get_observed_state(self):
        """Return (set of observed SlotState, searcher target state)."""
        return {slot_state.SlotState(timestamp) for timestamp in self._observed_timestamps()}, self._searcher_target

    def searchers_state(self, slot_name_to_find=None):
        return {}

    def deploy_progress(self, timestamp=None):
        return {}

    def html_view(self):
        return sdk.blocks.SlotView(self.path, super(Controller, self).html_view(), sdk.blocks.Block())

    def json_view(self):
        """Return a JSON-serialisable view of observed vs. target state."""
        deployer_state, searcher_state = self.get_observed_state()
        targets = self.targets if self.targets is not None else ()
        return {
            'deployer': {
                'observed': [state.json() for state in deployer_state],
                'target': [slot_state.SlotState(target).json() for target in targets]
            },
            'searcher': {
                'observed': searcher_state.json(),
                'target': self._searcher_target.json()
            }
        }
