import collections
import copy
import gevent
import logging
import random
import time
import json

import google.protobuf.json_format as json_format

import retry
import yt.wrapper as yt

import infra.callisto.controllers.sdk as sdk
import infra.callisto.protos.deploy.http_pb2 as http_pb2
import infra.callisto.protos.deploy.tables_pb2 as tables_pb2

import search.plutonium.deploy.proto.sources_pb2 as sources_pb2

from match_banned import match_banned_group

import config_pb2
import entities
import tables
import unistat

from collections import defaultdict
from datetime import datetime

# Short aliases for the config message classes from the config_pb2 module.
TShardConfig = config_pb2.TShardConfig
TCoordinatorConfig = config_pb2.TCoordinatorConfig

# Default `ready_threshold` of RsProxy: the share of listed pods that must
# report a state before it counts as prepared/active.
RSPROXY_READY_THRESHOLD = 0.89
# Default `progress_timeout` of ShardCtrl (presumably seconds, i.e. 20 minutes -- TODO confirm).
STARTUP_INTERVAL = 20 * 60
DEFAULT_ACTIVE_STATES_LIMIT = 2
# Default `states_limit` of ShardCtrl, which derives its active limit as `states_limit - 1`.
DEFAULT_STATES_LIMIT = DEFAULT_ACTIVE_STATES_LIMIT + 1
# Not referenced in the visible part of the file.
DEFAULT_STATES_LIMIT_FOR_HAMSTER = 5
# Not referenced in the visible part of the file (one day if the unit is seconds -- TODO confirm).
CLEANUP_AGE = 60 * 60 * 24
# Not referenced in the visible part of the file.
NAMESPACE_BATCH_SIZE = 10
# Passed to `_is_stuck` by Coordinator._replicate_states as the replication timeout.
REPLICATE_PROGRESS_TIMEOUT = 60 * 60  # 1h


class Location(collections.namedtuple('Location', ['endpoint_set', 'table'])):
    """Pair of an endpoint set and the states table that belongs to it.

    The methods are thin proxies to the wrapped ``table`` object.
    """

    def get(self, namespace, state_id):
        """Return whatever ``table.get`` yields for (namespace, state_id)."""
        states_table = self.table
        return states_table.get(namespace, state_id)

    def list(self, namespaces, filter_func=None):
        """List the table's states for ``namespaces``, filtered by ``filter_func``.

        ``filter_func`` is handed to the builtin ``filter`` unchanged, so
        ``None`` keeps only truthy items.
        """
        listed_states = self.table.list(namespaces)
        return filter(filter_func, listed_states)

    def __str__(self):
        """A location is printed as the path of its table."""
        states_table = self.table
        return states_table.path


class PlutoniumFS(object):
    """Describes files of a runtime FS snapshot as downloadable sources.

    File metadata is read from the ``<path>/runtime_fs/meta`` table, and the
    generated sources point at the ``<path>/runtime_fs/content`` table.
    """

    def __init__(self, yt_client, path, content_cluster=None,
                 fallback_clusters=None, error_cluster_probability=0, path_prefix='shard/'):
        """
        :param yt_client: client handed to ``tables.RuntimeMetaTable``.
        :param path: base path; the meta and content table paths are derived from it.
        :param content_cluster: cluster to put into sources; when falsy the
            cluster of the meta table is used instead.
        :param fallback_clusters: optional iterable of cluster names added to
            every source as fallback locations (``None`` means no fallbacks).
        :param error_cluster_probability: probability with which
            ``get_content_cluster`` returns the literal ``'error_cluster'``.
        :param path_prefix: only files whose path starts with it are exposed.
        """
        self._runtime_meta = tables.RuntimeMetaTable(
            yt_client, path + '/runtime_fs/meta'
        )
        self._runtime_content_path = path + '/runtime_fs/content'
        self._content_cluster = content_cluster
        self._fallback_clusters = fallback_clusters
        self.error_cluster_probability = error_cluster_probability
        self.path_prefix = path_prefix

    def get_content_cluster(self):
        """Return the cluster name to use as the primary content location.

        With probability ``error_cluster_probability`` the literal
        ``'error_cluster'`` is returned instead of the real cluster.
        """
        content_cluster = self._content_cluster or self._runtime_meta.cluster
        if random.random() < self.error_cluster_probability:
            return 'error_cluster'
        return content_cluster

    def _get_mappings(self, namespace, state_id):
        """Return meta records of the state whose path starts with ``path_prefix``."""
        return [
            file_desc
            for file_desc in self._runtime_meta.read(namespace, state_id)
            if file_desc.path.startswith(self.path_prefix)
        ]

    def _make_source(self, file_desc):
        """Build a ``TDynamicTablesSource`` for one meta record."""
        source_spec = sources_pb2.TDynamicTablesSource(
            Cluster=self.get_content_cluster(),
            Path=self._runtime_content_path,
            RowId=file_desc.file_id,
            Size=file_desc.size,
            Hash=file_desc.hash
        )
        # `fallback_clusters` defaults to None; iterating None would raise
        # TypeError, so treat a missing value as "no fallback locations".
        for cluster in self._fallback_clusters or ():
            source_spec.FallbackLocations.add(
                Cluster=cluster, Path=self._runtime_content_path
            )

        return source_spec

    def get_mappings_resource(self, namespace, state_id):
        """Return a ``TCompoundSource`` with one internal resource per exposed file."""
        mappings = self._get_mappings(namespace, state_id)
        return sources_pb2.TCompoundSource(
            Sources=[
                sources_pb2.TCompoundSource.TInternalResource(
                    Path=file_desc.path,
                    Source=sources_pb2.TSource(
                        DynamicTables=self._make_source(file_desc)
                    )
                )
                for file_desc in mappings
            ]
        )


class ChunkCtrl(object):
    """Facade over the chunk controller's status, target and topology tables."""

    def __init__(self, yt_client, path, tablet_cell_bundle=None, readonly=True):
        """Open the three tables that live under ``path``.

        A falsy ``tablet_cell_bundle`` is replaced with ``'cajuper'``.
        """
        bundle = tablet_cell_bundle or 'cajuper'
        self._status_table = tables.ChunkCtrlStatus(yt_client, path + '/chunks_status')
        self._target_table = tables.ChunkCtrlTarget(
            yt_client, path + '/chunks_target', readonly, bundle,
        )
        self._topology_table = tables.ChunkCtrlTopology(yt_client, path + '/configs')
        self._status = {}

    def update_status(self, namespaces):
        """Re-read the status table and cache rows by (Namespace, StateId)."""
        refreshed = {}
        for row in self._status_table.list(namespaces):
            refreshed[(row.Namespace, row.StateId)] = row
        # Swap the cache only after the whole listing has been consumed.
        self._status = refreshed

    def is_active(self, state):
        """Tell whether the cached status of ``state`` is ACTIVE_GENERATION_STATE.

        States that are absent from the cache are reported as not active.
        """
        active_status = tables_pb2.EGenerationState.ACTIVE_GENERATION_STATE
        cached = self._status.get((state.Namespace, state.StateId))
        if cached is None:
            return False
        return cached.State.State == active_status

    def set_target(self, states):
        """Write an ACTIVE_GENERATION_STATE target row for every given state."""
        targets = []
        for state in states:
            generation_state = tables_pb2.TGenerationState(
                State=tables_pb2.EGenerationState.ACTIVE_GENERATION_STATE
            )
            targets.append(tables_pb2.TGenerationTarget(
                Namespace=state.Namespace,
                StateId=state.StateId,
                State=generation_state,
            ))
        self._target_table.write(targets)

    def get_topology(self, states):
        """Return what the topology table yields for ``states``."""
        topology_table = self._topology_table
        return topology_table.get(states)


class SnapshotReplicator(object):
    """Writes replication targets and reads back replication statuses."""

    def __init__(self, target_table, status_table):
        self._status_table = status_table
        self._target_table = target_table

    def replicate(self, states):
        """Write one 'REPLICATE' target row per given state.

        The states are walked in ``_sort_states`` order; a ``ValueError`` is
        raised when two neighbouring states share a (non-empty) namespace.
        """
        rows = []
        previous_namespace = None

        for state in _sort_states(states):
            namespace = state.Namespace
            if previous_namespace and namespace == previous_namespace:
                raise ValueError('More than 1 state for namespace "{}"'.format(namespace))
            previous_namespace = namespace

            row = {
                'Stream': namespace,
                'SnapshotId': state.StateId,
                'State': 'REPLICATE',  # TODO: Use protobuf in State
            }
            rows.append(row)

        self._target_table.write(rows)

    def is_replicated(self, state):
        """Check whether the status row of ``state`` says 'DONE'.

        When there is no (truthy) status row, that falsy value itself is returned.
        """
        snapshot_status = self._status_table.get(state.Namespace, state.StateId)
        if not snapshot_status:
            return snapshot_status
        return snapshot_status['State'] == 'DONE'  # TODO: Use protobuf in State


class RsProxy(object):
    """Delivers mappings/topology targets to pods and tracks which pods picked them up.

    A state counts as prepared/active once at least ``ready_threshold`` of the
    listed pods report it.
    """

    def __init__(self, pods_providers, plutonium_fs, target_table, status_table,
                 ready_threshold=RSPROXY_READY_THRESHOLD):
        self._pods_providers = pods_providers
        self._plutonium_fs = plutonium_fs
        self._target_table = target_table
        self._status_table = status_table
        self._ready_threshold = ready_threshold

        def _empty_entry():
            return {'active_pods': set(), 'prepared_pods': set()}

        # snapshot key -> pods that activated / prepared that snapshot
        self._status = collections.defaultdict(_empty_entry)

    def update_status(self, banned_sets=()):
        """Rebuild the per-snapshot pod sets from the status and target tables.

        ``banned_sets`` is accepted but not used by this implementation.
        """
        accepted_statuses = {
            entities.MappingsTarget.namespace: (tables_pb2.EDownloadState.PREPARED, tables_pb2.EDownloadState.ACTIVE),
            entities.TopologyTarget.namespace: (tables_pb2.EDownloadState.ACTIVE,),
        }

        # Keys of the targets whose download status is good enough for their namespace.
        target_keys = []
        for status in self._status_table.list(accepted_statuses.keys(), self._list_pods_ids()):
            if status.ResourceState.Status not in accepted_statuses[status.Namespace]:
                continue
            target_keys.append(
                {'PodId': status.PodId, 'Namespace': status.Namespace, 'LocalPath': status.LocalPath}
            )

        fresh_status = collections.defaultdict(self._status.default_factory)
        for target in self._target_table.lookup(target_keys):
            if target.Namespace == entities.TopologyTarget.namespace:
                snapshot = entities.TopologyTarget.get_snapshot(target)
                fresh_status[snapshot]['active_pods'].add(target.PodId)
            elif target.Namespace == entities.MappingsTarget.namespace:
                snapshot = entities.MappingsTarget.get_snapshot(target)
                fresh_status[snapshot]['prepared_pods'].add(target.PodId)

        self._status = fresh_status

    def is_prepared(self, state):
        """Whether enough pods have prepared ``state``."""
        return self._count_statuses(state, 'prepared_pods', 'prepared', sample_size=3)

    def is_active(self, state, banned_sets=()):
        """Whether enough pods have activated ``state``; banned sets are not supported."""
        assert len(banned_sets) == 0
        return self._count_statuses(state, 'active_pods', 'activated', sample_size=3)

    def _count_statuses(self, state, status_key, status_name, sample_size=3):
        """Compare the number of ready pods with the threshold share of all pods.

        Up to ``sample_size`` of the pods that are not ready yet are logged.
        """
        expected_pods = self._list_pods_ids()
        ready_pods = self._status[state.Namespace, state.StateId][status_key]

        lagging_sample = list(expected_pods - ready_pods)[:sample_size]
        if lagging_sample:
            _log.debug('State (%s, %s) is still not %s on %s [ready on %s / %s]',
                       state.Namespace, state.StateId, status_name, lagging_sample,
                       len(ready_pods), len(expected_pods))

        required = len(expected_pods) * self._ready_threshold
        return len(ready_pods) >= required

    def prepare(self, states):
        """Deliver a mappings target for every state to every listed pod."""
        resources = []
        for state in states:
            mappings = self._plutonium_fs.get_mappings_resource(state.Namespace, state.StateId)
            mappings_target = entities.MappingsTarget(
                state.Namespace, state.StateId,
                sources_pb2.TSource(Compound=mappings)
            )
            resources.extend(mappings_target.get_targets(self._list_pods_ids()))

        self.deliver(resources)

    def activate(self, topologies):
        """Deliver a topology target for every topology; nothing is written when there are none."""
        resources = []
        for topology in topologies:
            topology_target = entities.TopologyTarget(
                topology.Namespace, topology.StateId,
                topology.ResourceSpec, topology.Revision,
            )
            resources.extend(topology_target.get_targets(self._list_pods_ids()))

        if not resources:
            return
        _log.debug('%s resources for deliver', len(resources))
        self.deliver(resources)

    def deliver(self, resources):
        """Write the given target rows to the target table."""
        self._target_table.update(resources)

    def _list_providers(self, banned_sets):
        """Return the providers that ``match_banned_group`` does not match."""
        allowed = []
        for provider in self._pods_providers:
            if match_banned_group(provider, banned_sets):
                _log.debug('skip provider because %s banned', str(provider.group_keys()))
                continue
            allowed.append(provider)
        return allowed

    def _list_pods_ids(self, banned_sets=()):
        """Return the set of pod ids of all non-banned providers."""
        pod_ids = set()
        for provider in self._list_providers(banned_sets):
            pod_ids.update(provider.ids)
        return pod_ids


class BadPods(object):
    """Tracks pods that are not ready and may be skipped while waiting.

    Lags are compared against ``_now()`` timestamps; BkStat passes values
    named ``*_seconds`` for them.
    """

    def __init__(self, max_bad_pods=10, min_mark_lag=10, max_mark_lag=60):
        """
        :param max_bad_pods: maximum size of the bad set for which skipping is allowed.
        :param min_mark_lag: a pod seen ready less than this long ago is not marked bad.
        :param max_mark_lag: bad pods last seen ready longer ago than this are
            evicted from the set on the next ``try_mark_bad`` call.
        """
        self._last_ready = defaultdict(int)  # pod id -> last ready timestamp
        self._bad_pods = set()
        self._max_bad_pods = max_bad_pods
        self._min_mark_lag = min_mark_lag
        self._max_mark_lag = max_mark_lag

    def mark_ok(self, pod):
        """Record that ``pod`` is ready now and drop it from the bad set."""
        self._last_ready[pod] = _now()
        self._bad_pods.discard(pod)

    # returns if pod can be skipped
    def try_mark_bad(self, pod):
        """Try to mark ``pod`` as bad; return True when it can be skipped.

        Returns False without marking when the pod was ready less than
        ``min_mark_lag`` ago. Otherwise stale entries are evicted, the pod is
        added, and the result tells whether the bad set still fits into
        ``max_bad_pods``.
        """
        now = _now()
        if self._last_ready[pod] + self._min_mark_lag > now:
            return False

        # Use a separate loop variable: reusing `pod` here would shadow the
        # argument, and the last iterated pod would be added instead of it.
        for bad_pod in list(self._bad_pods):
            if self._last_ready[bad_pod] + self._max_mark_lag < now:
                self._bad_pods.remove(bad_pod)

        self._bad_pods.add(pod)
        return len(self._bad_pods) <= self._max_bad_pods

    def get_bad_pods(self):
        """Return the bad pods as a list, or ``[]`` when there are more than ``max_bad_pods``."""
        if len(self._bad_pods) <= self._max_bad_pods:
            return list(self._bad_pods)
        return []


class BkStat(object):
    """Delivers topology targets to pods and decides when a state is active on them.

    Unlike RsProxy there is no separate prepare stage: ``prepare`` is a no-op
    and ``is_prepared`` is always true.
    """

    def __init__(self, pods_providers, target_table, status_table,
                 max_not_ready_pods=2,
                 max_bad_pods=10,
                 min_mark_bad_seconds=10,
                 max_mark_bad_seconds=60,
                 min_ready_threshold=0.75):
        """
        :param pods_providers: providers exposing ``ids`` and ``agents_instances``.
        :param target_table: table the topology targets are written to / looked up in.
        :param status_table: table with per-pod download statuses.
        :param max_not_ready_pods: default limit of alive pods we still wait for.
        :param max_bad_pods: see ``BadPods``.
        :param min_mark_bad_seconds: see ``BadPods`` (``min_mark_lag``).
        :param max_mark_bad_seconds: see ``BadPods`` (``max_mark_lag``).
        :param min_ready_threshold: default minimal share of ready pods.
        """
        self._pods_providers = pods_providers
        self._target_table = target_table
        self._status_table = status_table
        self._bad_pods = BadPods(max_bad_pods, min_mark_bad_seconds, max_mark_bad_seconds)
        self._min_ready_threshold = min_ready_threshold
        self._ignored_pods = []
        self._dead_pods = []
        self._max_not_ready_pods = max_not_ready_pods
        self._switching_percent = 1.0

        # snapshot key -> pods that activated it; every listed pod counts as prepared
        self._status = collections.defaultdict(
            lambda: {'active_pods': set(), 'prepared_pods': self._list_pods_ids()}
        )

    def update_status(self, banned_sets=()):
        """Rebuild the per-snapshot sets of pods whose topology target is ACTIVE."""
        ready_status = (tables_pb2.EDownloadState.ACTIVE,)
        target_keys = [
            {'PodId': status.PodId, 'Namespace': status.Namespace, 'LocalPath': status.LocalPath}
            for status in self._status_table.list([entities.TopologyTarget.namespace], self._list_pods_ids(banned_sets))
            if status.ResourceState.Status in ready_status
        ]
        global_timer.stamp_delta('BkStat pods statuses ready')

        result = collections.defaultdict(self._status.default_factory)
        for target in self._target_table.lookup(target_keys):
            if target.Namespace == entities.TopologyTarget.namespace:
                key = entities.TopologyTarget.get_snapshot(target)
                result[key]['active_pods'].add(target.PodId)
        global_timer.stamp_delta('BkStat status ready')

        self._status = result

    def is_prepared(self, state):
        """There is no prepare stage, so every state is prepared."""
        return True

    def prepare(self, states):
        """No-op: nothing has to be prepared."""
        pass

    def is_active(self, state, realtime_config, banned_sets=()):
        """Decide whether ``state`` is active on enough of the non-banned pods.

        :param state: object with ``Namespace`` and ``StateId``.
        :param realtime_config: mapping that may override the limits via the
            ``saas2_bkstat_max_not_ready_pods`` and
            ``saas2_bkstat_min_ready_threshold`` keys.
        :param banned_sets: passed to ``match_banned_group`` to exclude providers.
        :return: True when the number of alive pods we must still wait for
            and the share of ready pods both satisfy the limits.

        Also refreshes the ignored/dead pod lists and, on success, the
        switching percent.
        """
        check_pods = self._list_pods_ids(banned_sets)
        alive_pods = self._list_alive_pods(banned_sets)
        ready_pods = self._status[state.Namespace, state.StateId]['active_pods']
        waiting_pods = list()
        self._ignored_pods = list(check_pods - alive_pods - ready_pods)
        self._dead_pods = list(check_pods - alive_pods)
        for pod in list(ready_pods):
            self._bad_pods.mark_ok(pod)
        for pod in list(alive_pods - ready_pods):
            # Pods that cannot be skipped as "bad" have to be waited for.
            if not self._bad_pods.try_mark_bad(pod):
                waiting_pods.append(pod)
        # With no pods to check (e.g. every provider is banned) the ratio is
        # undefined; count it as 0% ready instead of raising ZeroDivisionError.
        if check_pods:
            ready_percent = len(ready_pods) / float(len(check_pods))
        else:
            ready_percent = 0.0
        if len(waiting_pods) > realtime_config.get("saas2_bkstat_max_not_ready_pods", self._max_not_ready_pods) \
                or ready_percent < realtime_config.get("saas2_bkstat_min_ready_threshold", self._min_ready_threshold):
            sample_size = 3
            _log.debug('State (%s, %s) is still not active. Waiting for %s [ready on %s / %s]',
                       state.Namespace, state.StateId, waiting_pods[:sample_size],
                       len(ready_pods), len(check_pods))
            return False
        else:
            self._switching_percent = ready_percent
            return True

    def get_bad_pods(self):
        """Return the pods currently marked bad (see ``BadPods.get_bad_pods``)."""
        return self._bad_pods.get_bad_pods()

    def get_ignored_pods(self):
        """Return pods that were neither alive nor ready at the last ``is_active`` call."""
        return self._ignored_pods

    def get_dead_pods(self):
        """Return pods that were not alive at the last ``is_active`` call."""
        return self._dead_pods

    def get_switching_percent(self):
        """Return the ready share recorded by the last successful ``is_active`` (1.0 initially)."""
        return self._switching_percent

    def activate(self, topologies):
        """Deliver a topology target for every topology to every listed pod."""
        resources = []
        pods = self._list_pods_ids()
        for topology in topologies:
            target = entities.TopologyTarget(topology.Namespace, topology.StateId,
                                             topology.ResourceSpec, topology.Revision, topology.State.Timestamp)
            resources.extend(target.get_targets(pods))

        if resources:
            _log.debug('%s resources for deliver', len(resources))
            self.deliver(resources)

    def deliver(self, resources):
        """Write the given target rows to the target table."""
        self._target_table.update(resources)

    def _list_providers(self, banned_sets):
        """Return the providers that ``match_banned_group`` does not match."""
        result = []
        for provider in self._pods_providers:
            if match_banned_group(provider, banned_sets):
                _log.debug('skip provider because %s banned', str(provider.group_keys()))
            else:
                result.append(provider)
        return result

    def _list_alive_pods(self, banned_sets=()):
        """Return ids of agents of non-banned providers whose ``is_alive`` is truthy."""
        pods = set()
        for provider in self._list_providers(banned_sets):
            for agent in provider.agents_instances.itervalues():
                if agent.is_alive:
                    pods.add(agent.id)
        return pods

    def _list_pods_ids(self, banned_sets=()):
        """Return the set of pod ids of all non-banned providers."""
        return {
            pod_id
            for provider in self._list_providers(banned_sets)
            for pod_id in provider.ids
        }


class ReplicaCtl(object):
    """Pod-less variant of the orchestrated service interface.

    Targets are written for the single pod id ``""``; the pod diagnostics
    getters therefore always report empty lists.
    """

    def __init__(self, target_table, status_table):
        self._status_table = status_table
        self._target_table = target_table
        self._ready_states = set()

    def update_status(self, banned_sets=()):
        """Refresh the set of snapshots whose topology target is ACTIVE.

        ``banned_sets`` is accepted but not used.
        """
        accepted = (tables_pb2.EDownloadState.ACTIVE,)
        target_keys = []
        for status in self._status_table.list([entities.TopologyTarget.namespace], pods=[""]):
            if status.ResourceState.Status not in accepted:
                continue
            target_keys.append(
                {'PodId': status.PodId, 'Namespace': status.Namespace, 'LocalPath': status.LocalPath}
            )
        global_timer.stamp_delta('ReplicaCtl pods statuses ready')

        snapshots = set(
            entities.TopologyTarget.get_snapshot(target)
            for target in self._target_table.lookup(target_keys)
        )
        global_timer.stamp_delta('ReplicaCtl status ready')
        self._ready_states = snapshots

    def is_prepared(self, state):
        """There is no prepare stage, so every state is prepared."""
        return True

    def prepare(self, states):
        """No-op: nothing has to be prepared."""
        return None

    def is_active(self, state, _realtime_config, _banned_sets=()):
        """Whether (Namespace, StateId) of ``state`` is among the ready snapshots."""
        snapshot_key = (state.Namespace, state.StateId)
        return snapshot_key in self._ready_states

    def get_bad_pods(self):
        """Always an empty list."""
        return []

    def get_ignored_pods(self):
        """Always an empty list."""
        return []

    def get_dead_pods(self):
        """Always an empty list."""
        return []

    def get_switching_percent(self):
        """Always 1.0."""
        return 1.0

    def activate(self, topologies):
        """Write a topology target for every topology; skip the write when there are none."""
        pods = [""]
        resources = []
        for topology in topologies:
            topology_target = entities.TopologyTarget(
                topology.Namespace, topology.StateId,
                topology.ResourceSpec, topology.Revision, topology.State.Timestamp,
            )
            resources.extend(topology_target.get_targets(pods))

        if not resources:
            return
        _log.debug('%s resources for deliver', len(resources))
        self._target_table.update(resources)


class Coordinator(sdk.Controller):
    """Controller that maintains the contour states table.

    Each ``execute`` call adds public states that the contour does not know
    yet, optionally advances snapshot replication, and retires/removes states
    that are no longer needed. The class also registers request handlers for
    freezing and unfreezing states and for writing banned groups.
    """

    def __init__(self, namespaces, public_states, locations, contour_states, snapshot_replicator=None, banned_groups=tables.BannedGroupsStub()):
        """
        :param namespaces: namespaces this coordinator is responsible for.
        :param public_states: table of published states (listed, and cleaned in `_cleanup_public`).
        :param locations: iterable of objects with ``list(namespaces, filter_func)``.
        :param contour_states: table of contour states.
        :param snapshot_replicator: optional object with ``replicate`` / ``is_replicated``;
            replication is skipped when it is falsy.
        :param banned_groups: object whose ``write`` receives the '/group_state' payload.
        """
        super(Coordinator, self).__init__()
        self._namespaces = frozenset(namespaces)
        self._public_states = public_states
        self._locations = locations
        self._contour_states = contour_states
        # Freeze handlers work through their own deep copy of the contour table object
        # (presumably to keep their transactions apart from `execute` -- TODO confirm).
        self._freeze_states = copy.deepcopy(contour_states)
        self._snapshot_replicator = snapshot_replicator

        self._banned_groups = banned_groups

        self.add_handler('/freeze', self._freeze)
        self.add_handler('/list_frozen_states', self._list_freeze_states)
        self.add_handler('/freeze_last_state', self._freeze_last_state)
        self.add_handler('/unfreeze_state', self._unfreeze_state)
        self.add_handler('/group_state', self._group_state)

    def execute(self):
        """Run one iteration: add fresh states, replicate (if configured), clean up."""
        global_timer.reset()
        public_states = self._public_states.list(self._namespaces)

        self._add_fresh_states(public_states)
        global_timer.stamp_delta('fresh added')

        if self._snapshot_replicator:
            self._replicate_states()
            global_timer.stamp_delta('replicate state handled')

        self._cleanup()
        global_timer.stamp_delta('cleaned')
        global_timer.stamp_delta('snapshots stats have got')

    @retry.retry(exceptions=yt.errors.YtNoSuchTransaction, tries=3, delay=1)
    def _freeze(self):
        """Handle '/freeze': make the frozen set match the request's SnapshotMetas.

        The request body is parsed as a JSON ``TFreezeRequest``. Every listed
        state gets ``Demand.Active = True`` unless it is already force-active;
        every other force-active state of our namespaces gets
        ``Demand.Active = False``.

        :raises RuntimeError: when a requested state is missing or not actual.
        :return: an empty dict.
        """
        request = sdk.request.current_request()
        freeze_request = json_format.Parse(request.stream.read(), http_pb2.TFreezeRequest())

        with self._freeze_states._transaction():
            changed_states = []
            # Pass 1: freeze everything that the request asks for.
            for state_meta in freeze_request.SnapshotMetas:
                state = self._freeze_states.get(state_meta.Stream, state_meta.Id)
                if state and _is_actual(state):
                    if not _is_force_active(state):
                        _log.debug('Freeze state %s', (state.Namespace, state.StateId))
                        state.Demand.Active = True
                        changed_states.append(state)
                    else:
                        _log.debug('Keep freeze state %s', (state.Namespace, state.StateId))
                else:
                    # TODO: check
                    raise RuntimeError('Unknown state {}'.format(state_meta))

            # Pass 2: unfreeze force-active states that the request does not mention.
            for state in self._freeze_states.list(self._namespaces):
                if _is_force_active(state):
                    state_meta = http_pb2.TSnapshotMeta(
                        Stream=state.Namespace, Id=state.StateId
                    )
                    if state_meta not in freeze_request.SnapshotMetas:
                        _log.debug('UnFreeze state %s', (state.Namespace, state.StateId))
                        state.Demand.Active = False
                        changed_states.append(state)
            self._freeze_states.update(changed_states)

        return {}

    def _group_state(self):
        """Handle '/group_state': write the JSON request body to the banned groups table."""
        # Parse options request
        request = json.loads(sdk.request.current_request().stream.read())
        self._banned_groups.write(request)

    def _get_shardctl_states(self):
        """Collect actual, stable-active states from all locations.

        A state present in several locations is returned once; the copy from
        the first location that lists it wins.
        """
        state_filter = lambda x: _is_actual(x) and _is_stable_active(x)
        states = {}
        for loc in self._locations:
            for state in loc.list(self._namespaces, state_filter):
                _key = (state.Namespace, state.StateId)
                if _key not in states:
                    states[_key] = state
        return states.values()

    @retry.retry(exceptions=yt.errors.YtNoSuchTransaction, tries=3, delay=1)
    def _list_freeze_states(self):
        """
        Handle '/list_frozen_states': return the force-active states.

        Without a request body all namespaces of the coordinator are listed.

        Optional JSON Data:
        {
            "namespaces": [
                "<namespace1>",
                "<namespace2>"
            ]
        }
        """
        freeze_namespaces = self._namespaces

        # Parse options request
        request = sdk.request.current_request().stream.read()
        if request:
            request_json = json.loads(request)
            freeze_namespaces = request_json['namespaces']

        _log.debug('Request to list freeze states in namespaces: %s', ','.join(freeze_namespaces))

        freezed_states = []
        with self._freeze_states._transaction():
            for namespace, namespace_states in _grouped_states(self._freeze_states.list(freeze_namespaces)):
                for state in namespace_states:
                    if _is_force_active(state):
                        freezed_states.append(state)

        return {'states': [{'Namespace': state.Namespace, 'StateId': state.StateId} for state in freezed_states]}

    @retry.retry(exceptions=yt.errors.YtNoSuchTransaction, tries=3, delay=1)
    def _freeze_last_state(self):
        """
        Handle '/freeze_last_state': freeze at most one state per namespace.

        For each requested namespace the stable-active states reported by the
        locations are walked in the order produced by `_grouped_states`; the
        first one that is known and actual gets frozen. Nothing is changed in
        a namespace once an already force-active state is met.
        Returns the states that were frozen by this call.

        Optional JSON Data:
        {
            "namespaces": [
                "<namespace1>",
                "<namespace2>"
            ]
        }
        """

        # Parse options request
        request = sdk.request.current_request().stream.read()
        if request:
            request_json = json.loads(request)
            freeze_namespaces = set(request_json['namespaces'])
        else:
            freeze_namespaces = set(self._namespaces)

        _log.debug('Request to freeze last state in namespaces: %s', ','.join(freeze_namespaces))

        grouped_records = _grouped_states(
            self._get_shardctl_states(),
            group_key=lambda x: x.Namespace,
            sort_key=lambda x: (x.Namespace, x.StateId),
        )

        changed_states = []
        with self._freeze_states._transaction():
            for stream, stream_records in grouped_records:
                if stream not in freeze_namespaces:
                    continue

                _log.debug('Try freeze state in %s', stream)

                for record in stream_records:
                    if not _is_stable_active(record):
                        continue
                    state = self._freeze_states.get(record.Namespace, record.StateId)
                    if not state:
                        continue
                    elif _is_force_active(state):
                        # The namespace already has a frozen state: leave it alone.
                        break
                    elif not _is_actual(state):
                        continue

                    state.Demand.Active = True
                    changed_states.append(state)
                    _log.debug('Freeze state %s', (state.Namespace, state.StateId))
                    # Only one state per namespace is frozen.
                    break

            self._freeze_states.update(changed_states)

        return {'states': [{'Namespace': s.Namespace, 'StateId': s.StateId} for s in changed_states]}

    @retry.retry(exceptions=yt.errors.YtNoSuchTransaction, tries=3, delay=1)
    def _unfreeze_state(self):
        """
        Handle '/unfreeze_state': drop the freeze from the listed states.

        States that are unknown or not force-active are silently skipped.
        Returns the states that were actually unfrozen.

        {
            "states": [
                {
                    "Namespace": "<namespace1>",
                    "StateId": "<state_id1>",
                },
                {
                    "Namespace": "<namespace2>",
                    "StateId": "<state_id2>",
                }
            ]
        }
        """

        # Parse options request
        unfreeze_states = json.loads(sdk.request.current_request().stream.read())['states']
        _log.debug('Request to unfreeze state: %s', unfreeze_states)

        changed_states = []
        with self._freeze_states._transaction():
            for unfreeze_state in unfreeze_states:
                state = self._freeze_states.get(unfreeze_state['Namespace'], unfreeze_state['StateId'])
                if not state or not _is_force_active(state):
                    continue

                state.Demand.Active = False
                changed_states.append(state)
                _log.debug('UnFreeze state %s', (state.Namespace, state.StateId))

            self._freeze_states.update(changed_states)

        return {'states': [{'Namespace': s.Namespace, 'StateId': s.StateId} for s in changed_states]}

    def _add_fresh_states(self, public_states):
        """Add to the contour table the public states it does not contain yet."""
        fresh_states = []
        with self._contour_states._transaction():
            known_states = {
                (state.Namespace, state.StateId)
                for state in self._contour_states.list(self._namespaces)
            }
            for state in public_states:
                if (state.Namespace, state.StateId) not in known_states:
                    fresh_states.append(state)
            if fresh_states:
                _log.debug(
                    'Add states to contour: %s',
                    [(state.Namespace, state.StateId) for state in fresh_states]
                )
                self._contour_states.update(fresh_states)

    def _replicate_states(self):
        """Advance snapshot replication: at most one target state per namespace.

        For every namespace the first relevant state becomes the candidate and
        the first state marked for replication is inspected. A finished target
        is stamped, an unfinished one is kept (or dropped when it is stuck and
        a candidate exists), and a candidate is promoted when no target is in
        progress. Changed states are written back and the resulting targets
        are handed to the snapshot replicator.
        """
        with self._contour_states._transaction():
            changed_states = []
            to_replicate_states = []

            for namespace, namespace_states in _grouped_states(self._contour_states.list(self._namespaces)):
                # Yield to other greenlets between namespaces.
                gevent.sleep(0)
                candidate_state = None  # Candidate State for be a next replication target
                in_progress_state = None  # State in an incomplete replication process (current target)
                for state in namespace_states:
                    if not _is_actual(state) and not _is_replicate(state):
                        continue
                    if candidate_state is None:
                        candidate_state = state
                    if _is_replicate(state) and not in_progress_state:
                        if candidate_state.StateId <= state.StateId:
                            # Candidate State already is replication target
                            candidate_state = None
                        if not _is_replicated(state):
                            # Update target State status
                            if self._snapshot_replicator.is_replicated(state):
                                _log.info('State %s is replicated', (state.Namespace, state.StateId))
                                state.Status.ReplicateTimestamp = _now()
                                state.Demand.KeepInPlutonium = False
                                changed_states.append(state)
                            else:
                                _log.debug('State %s in replication progress', (state.Namespace, state.StateId))
                                in_progress_state = state
                        # Only the first replication-marked state of the namespace is inspected.
                        break

                if in_progress_state:
                    if candidate_state and _is_stuck(
                        in_progress_state.Target.ReplicateTimestamp,
                        in_progress_state.Status.ReplicateTimestamp,
                        REPLICATE_PROGRESS_TIMEOUT
                    ):
                        # Drop stuck target State
                        _log.error(
                            'State %s is stuck in replication - drop (target of %s)',
                            (in_progress_state.Namespace, in_progress_state.StateId),
                            in_progress_state.Target.ReplicateTimestamp,
                        )
                        in_progress_state.Demand.KeepInPlutonium = False
                        changed_states.append(in_progress_state)
                        # Clearing it lets the candidate be promoted right below.
                        in_progress_state = None
                    else:
                        # Keep target State no change
                        _log.debug(
                            'State %s keep in replication target (target of %s)',
                            (in_progress_state.Namespace, in_progress_state.StateId),
                            in_progress_state.Target.ReplicateTimestamp,
                        )
                        to_replicate_states.append(in_progress_state)

                if not in_progress_state:
                    if candidate_state:
                        # Add new target State
                        _log.debug('State %s is new replication target', (candidate_state.Namespace, candidate_state.StateId))
                        candidate_state.Target.ReplicateTimestamp = _now()
                        candidate_state.Demand.KeepInPlutonium = True
                        changed_states.append(candidate_state)
                        to_replicate_states.append(candidate_state)
                    else:
                        # Nothing to replicate
                        _log.debug('No states in namespace %s for replication', namespace)

            self._contour_states.update(changed_states)
            self._snapshot_replicator.replicate(to_replicate_states)

    def _cleanup_contour(self):
        """Revoke actual contour states that no location lists any more, then purge.

        In `_sort_states` order the first actual state of every namespace is
        always kept, and force-active states are never revoked.
        """
        global_timer.stamp_delta('begin cleanup contour')
        with self._contour_states._transaction():
            to_delete = []
            ns_checked = set()

            contour_states = [
                state
                for state in self._contour_states.list(self._namespaces)
                if _is_actual(state)
            ]
            global_timer.stamp_delta('contour states read')
            locations_states = {
                (state.Namespace, state.StateId)
                for location in self._locations
                for state in location.list(self._namespaces, _is_actual)
            }
            global_timer.stamp_delta('locations states read')
            for state in _sort_states(contour_states):
                if state.Namespace not in ns_checked:
                    # First state seen for this namespace: keep it unconditionally.
                    ns_checked.add(state.Namespace)
                    continue
                if _is_force_active(state):
                    continue
                if (state.Namespace, state.StateId) not in locations_states:
                    _log.info('Retire state %s', (state.Namespace, state.StateId))
                    _revoke(state)
                    to_delete.append(state)
            global_timer.stamp_delta('to delete states are collected')

            if to_delete:
                self._contour_states.update(to_delete)
            global_timer.stamp_delta('contour states are updated')

        purge_states(self._contour_states, self._namespaces)

    def _cleanup_public(self):
        """Delete from the public table states that the contour no longer needs.

        A state qualifies when its contour record is not actual, has no
        DropTimestamp yet and is not marked KeepInPlutonium; the contour
        record is stamped with DropTimestamp at the same time.
        """
        with self._public_states._transaction():
            to_delete = []
            contour_states = {
                (state.Namespace, state.StateId): state
                for state in self._contour_states.list(self._namespaces)
                if not _is_actual(state) and not state.Status.DropTimestamp and not state.Demand.KeepInPlutonium
            }
            for state in self._public_states.list(self._namespaces):
                key = (state.Namespace, state.StateId)
                if key in contour_states:
                    contour_states[key].Status.DropTimestamp = _now()
                    to_delete.append(contour_states[key])

            if to_delete:
                _log.info('Remove states %s from public',
                          [(state.Namespace, state.StateId) for state in to_delete])
                self._public_states.delete(to_delete)
                self._contour_states.update(to_delete)

    def _cleanup(self):
        """Clean up the contour table first, then the public table."""
        self._cleanup_contour()
        self._cleanup_public()


class ShardCtrl(sdk.Controller):
    """Controller that moves shard states of a set of namespaces through their lifecycle.

    Every `execute()` call refreshes state statuses from the chunk controller
    and the orchestrated service, syncs demands from the contour table,
    recalculates targets (prepare / activate / drop), picks up fresh states,
    stores the result, forwards prepare/activate targets to the underlying
    services, purges old dropped states and publishes unistat signals.
    """

    def __init__(self, namespaces, chunk_ctrl, orchestrated_service,
                 states, contour_states,
                 banned_groups=tables.BannedGroupsStub(),
                 realtime_config_table=tables.RealtimeConfigStub(),
                 enable_freeze=False,
                 states_limit=DEFAULT_STATES_LIMIT,
                 progress_timeout=STARTUP_INTERVAL):
        """Create the controller.

        :param namespaces: iterable of namespaces to manage (stored as a frozenset).
        :param chunk_ctrl: object used via `update_status`, `is_active`,
            `set_target` and `get_topology`.
        :param orchestrated_service: object used via `update_status`,
            `is_active`, `is_prepared`, `prepare`, `activate` and the
            `get_*_pods` / `get_switching_percent` accessors.
        :param states: table with this controller's own states
            (`list`, `update`; also passed to `purge_states`).
        :param contour_states: table used via `lookup` and `get_last_states`.
        :param banned_groups: object whose `list()` result is passed to the
            orchestrated service.
        :param realtime_config_table: object providing `get_last_version()`.
        :param enable_freeze: when true, `Demand.Active` is synced from the
            contour states and kept on fresh states.
        :param states_limit: a fresh state is added to a namespace only while
            it has fewer actual states than this; the active-states limit is
            `states_limit - 1`.
        :param progress_timeout: timeout in seconds forwarded to
            `_update_target` for stuck-state detection.
        """
        super(ShardCtrl, self).__init__()
        self._namespaces = frozenset(namespaces)
        self._chunk_ctrl = chunk_ctrl
        self._orchestrated_service = orchestrated_service
        self._states = states
        self._contour_states = contour_states

        self._banned_groups = banned_groups
        self._realtime_config_table = realtime_config_table
        self._realtime_config = {}

        self._freeze_enabled = enable_freeze

        self._states_limit = states_limit
        self._active_states_limit = self._states_limit - 1
        self._progress_timeout = progress_timeout

        self._unistat = unistat.Unistat()
        # (signal name, signal type); 'aggregation*' types are drilled and
        # flushed as aggregations, any other type is passed to drill_float().
        self._signals = [
            ('iteration_time', 'txxx'),
            ('create_to_collect_time', 'aggregation'),
            ('prepare_time', 'aggregation'),
            ('collect_to_preparing_time', 'aggregation'),
            ('collect_to_prepare_time', 'aggregation'),
            ('activate_time', 'aggregation'),
            ('collect_to_activating_time', 'aggregation'),
            ('collect_to_active_time', 'aggregation'),
            ('prepare_to_activating_time', 'aggregation'),
            ('prepare_to_active_time', 'aggregation'),
            ('active_state_age', 'aggregation'),
            ('force_active_state_age', 'aggregation'),
            ('active_states', 'aggregation'),
            ('activating_states', 'aggregation'),
            ('prepared_states', 'aggregation'),
            ('preparing_states', 'aggregation'),
            ('removed_states', 'aggregation_delta'),
            ('stuck_states', 'aggregation_delta'),
            ('blacklist_size', 'txxx'),
            ('dead_pods_count', 'txxx'),
            ('switching_percent', 'txxx'),
        ]
        self.init_unistat()

        # Timestamp of the previous _update_unistat() call (None before the first one).
        self._iteration_time = None
        # Per-namespace running totals pushed as delta aggregations.
        self._removed_states = defaultdict(int)
        self._stuck_states = defaultdict(int)

    def execute(self):
        """Run one controller iteration (see the class docstring for the steps)."""
        # Actual states only, sorted by (Namespace, StateId) in descending order.
        all_states = _sort_states(
            state
            for state in self._states.list(self._namespaces)
            if _is_actual(state)
        )

        self._realtime_config = self._realtime_config_table.get_last_version()

        global_timer.reset()

        self._chunk_ctrl.update_status(self._namespaces)
        global_timer.stamp_delta('chunk controller status updated')
        banned = self._banned_groups.list()

        self._orchestrated_service.update_status(banned)
        global_timer.stamp_delta('orchestrated service status updated')
        for state in all_states:
            self._update_status(state, banned)
            # Yield to other greenlets between states.
            gevent.sleep(0)
        global_timer.stamp_delta('states statuses updated')
        _log.debug('Pods %s marked bad, pods %s ignored (not available in endpointsets)',
                   self._orchestrated_service.get_bad_pods(),
                   self._orchestrated_service.get_ignored_pods())

        self._update_demands(all_states)
        global_timer.stamp_delta('demands updated')

        last_states = self._contour_states.get_last_states(self._namespaces)

        for ns in self._namespaces:
            namespace_states = [
                state
                for state in all_states
                if state.Namespace == ns
            ]

            # all_states is sorted in descending order, so the first element
            # is the state with the greatest StateId known for this namespace.
            fresh_state = _get_fresh_state(
                last_states,
                ns,
                namespace_states[0] if namespace_states else None,
                self._freeze_enabled
            )

            _update_target(
                namespace_states,
                self._active_states_limit,
                fresh_state_exists=bool(fresh_state),
                progress_timeout=self._progress_timeout,
            )

            # Checked after _update_target(): states revoked there free a place.
            if fresh_state and _is_place_available(namespace_states, self._states_limit):
                _log.debug(
                    'Add fresh state %s',
                    (fresh_state.Namespace, fresh_state.StateId)
                )
                fresh_state.Target.CollectTimestamp = _now()
                all_states.insert(0, fresh_state)

            gevent.sleep(0)
        global_timer.stamp_delta('states targets calculated')

        prepare_states = []
        active_states = []

        for state in all_states:
            if not _is_actual(state):
                continue
            elif state.Target.PrepareTimestamp:
                prepare_states.append(state)
                if _is_active(state):
                    active_states.append(state)
        global_timer.stamp_delta('prepared/activate states is choosen')

        self._states.update(all_states)
        global_timer.stamp_delta('states is stored')

        global_timer.reset()
        self._prepare(prepare_states)
        global_timer.stamp_delta('prepare targets are set')
        self._activate(active_states)
        global_timer.stamp_delta('activate targets are set')

        self._cleanup()
        self._update_unistat(all_states)
        global_timer.stamp_delta('cleanup and update_unistat are called')

    def _update_demands(self, states):
        """Sync demands of `states` (in place) from the matching actual contour states.

        `Demand.Skip` is always synced; `Demand.Active` only when freeze is enabled.
        """
        contour_demands = {
            (state.Namespace, state.StateId): state.Demand
            for state in self._contour_states.lookup(states)
            if _is_actual(state) and state.Demand
        }

        for state in states:
            key = (state.Namespace, state.StateId)
            if key in contour_demands:
                if self._freeze_enabled and contour_demands[key].Active != state.Demand.Active:
                    state.Demand.Active = contour_demands[key].Active
                if state.Demand.Skip != contour_demands[key].Skip:
                    state.Demand.Skip = contour_demands[key].Skip

    def _update_status(self, state, banned_groups):
        """Stamp status timestamps of `state` for every target that has been reached.

        Activation is confirmed by the orchestrated service; preparation needs
        both the chunk controller and the orchestrated service; collection is
        acknowledged as soon as it is targeted.
        """
        if state.Target.ActivateTimestamp and not state.Status.ActivateTimestamp:
            if self._orchestrated_service.is_active(state, self._realtime_config, banned_groups):
                _log.info('State (%s, %s) has been activated',
                          state.Namespace, state.StateId)
                state.Status.ActivateTimestamp = _now()
            else:
                _log.debug('State (%s, %s) not active',
                           state.Namespace, state.StateId)

        if state.Target.PrepareTimestamp and not state.Status.PrepareTimestamp:
            chunks_is_active = self._chunk_ctrl.is_active(state)
            rs_proxy_is_prepared = self._orchestrated_service.is_prepared(state)
            if chunks_is_active and rs_proxy_is_prepared:
                _log.info('State (%s, %s) has been prepared',
                          state.Namespace, state.StateId)
                state.Status.PrepareTimestamp = _now()
            else:
                _log.debug('State (%s, %s) is not ready: rs=%s, rs_proxy=%s',
                           state.Namespace, state.StateId, chunks_is_active,
                           rs_proxy_is_prepared)

        if state.Target.CollectTimestamp and not state.Status.CollectTimestamp:
            state.Status.CollectTimestamp = _now()

    def _prepare(self, states):
        """Pass the states to prepare to the chunk controller and the orchestrated service."""
        self._chunk_ctrl.set_target(states)
        self._orchestrated_service.prepare(states)

    def _activate(self, states):
        """Ask the orchestrated service to activate the topologies of `states`."""
        topologies = self._chunk_ctrl.get_topology(states)
        self._orchestrated_service.activate(topologies)

    def _cleanup(self):
        """Delete long-dropped states of the managed namespaces from the states table."""
        purge_states(self._states, self._namespaces)

    def _update_unistat(self, states):
        """Push per-iteration and per-namespace signals computed from `states`, then flush."""
        if self._iteration_time:
            self._unistat.push('iteration_time_txxx', _now() - self._iteration_time)
        self._iteration_time = _now()
        self._unistat.push('blacklist_size_txxx', len(self._orchestrated_service.get_bad_pods()))
        self._unistat.push('dead_pods_count_txxx', len(self._orchestrated_service.get_dead_pods()))
        self._unistat.push('switching_percent_txxx', self._orchestrated_service.get_switching_percent())

        for namespace, namespace_states in _grouped_states(states):
            # States come in descending (Namespace, StateId) order, so the
            # "last_*" variables hold the first (greatest StateId) match.
            last_collected = None
            last_prepared = None
            last_active = None
            last_force_active = None
            active_count = 0
            activating_count = 0
            prepared_count = 0
            preparing_count = 0
            removed_count = 0
            stuck_count = 0

            for state in namespace_states:
                if state.Status.CollectTimestamp and not last_collected:
                    last_collected = state

                if state.Status.PrepareTimestamp:
                    if not last_prepared:
                        last_prepared = state
                    prepared_count += 1
                elif state.Target.PrepareTimestamp:
                    preparing_count += 1

                if _is_stable_active(state):
                    if not last_active and not _is_force_active(state):
                        last_active = state
                    elif not last_force_active and _is_force_active(state):
                        last_force_active = state
                    active_count += 1
                elif _is_active(state):
                    activating_count += 1

                if not _is_actual(state):
                    removed_count += 1
                    # Dropped after preparation had already been targeted.
                    if not _is_fresh(state):
                        stuck_count += 1

            if last_collected:
                try:
                    # For debug banner_*
                    # NOTE(review): mktime() treats the parsed time as local time; the
                    # +10800 (3h) shift presumably assumes hosts run in UTC+3 - confirm.
                    created_at = time.mktime(datetime.strptime(last_collected.StateId, '%Y%m%dT%H%M%SZ_%f').timetuple()) + 10800
                    create_to_collect_time = float(last_collected.Status.CollectTimestamp - created_at)
                    self._unistat.push_aggregation('create_to_collect_time', create_to_collect_time)
                    if create_to_collect_time > 10 * 60:
                        _log.debug('Create to collect time ({}, {}): {}s'.format(last_collected.Namespace, last_collected.StateId, create_to_collect_time))
                except Exception:
                    # Best-effort debug metric: e.g. StateIds that do not match the
                    # timestamp format are skipped.  A bare `except:` is avoided on
                    # purpose so that BaseException-derived signals (KeyboardInterrupt,
                    # SystemExit, gevent.GreenletExit, gevent.Timeout) still propagate.
                    pass

            if last_prepared:
                self._unistat.push_aggregation(
                    'collect_to_preparing_time',
                    float(last_prepared.Target.PrepareTimestamp - last_prepared.Target.CollectTimestamp)
                )
                self._unistat.push_aggregation(
                    'collect_to_prepare_time',
                    float(last_prepared.Status.PrepareTimestamp - last_prepared.Target.CollectTimestamp)
                )
                self._unistat.push_aggregation(
                    'prepare_time',
                    float(last_prepared.Status.PrepareTimestamp - last_prepared.Target.PrepareTimestamp)
                )
            if last_active:
                self._unistat.push_aggregation(
                    'collect_to_activating_time',
                    float(last_active.Target.ActivateTimestamp - last_active.Target.CollectTimestamp)
                )
                self._unistat.push_aggregation(
                    'collect_to_active_time',
                    float(last_active.Status.ActivateTimestamp - last_active.Target.CollectTimestamp)
                )
                self._unistat.push_aggregation(
                    'prepare_to_activating_time',
                    float(last_active.Target.ActivateTimestamp - last_active.Status.PrepareTimestamp)
                )
                self._unistat.push_aggregation(
                    'prepare_to_active_time',
                    float(last_active.Status.ActivateTimestamp - last_active.Status.PrepareTimestamp)
                )
                self._unistat.push_aggregation(
                    'activate_time',
                    float(last_active.Status.ActivateTimestamp - last_active.Target.ActivateTimestamp)
                )
                self._unistat.push_aggregation('active_state_age', _now() - last_active.Target.CollectTimestamp)
            if last_force_active:
                self._unistat.push_aggregation('force_active_state_age', _now() - last_force_active.Target.CollectTimestamp)

            self._unistat.push_aggregation('active_states', active_count)
            self._unistat.push_aggregation('activating_states', activating_count)

            self._unistat.push_aggregation('prepared_states', prepared_count)
            self._unistat.push_aggregation('preparing_states', preparing_count)

            # Running totals are kept per namespace and pushed to delta aggregations.
            self._removed_states[namespace] += removed_count
            self._unistat.push_aggregation('removed_states', self._removed_states[namespace])

            self._stuck_states[namespace] += stuck_count
            self._unistat.push_aggregation('stuck_states', self._stuck_states[namespace])

        self.flush_unistat()

    def unistat(self):
        """Return the collected signals as produced by `Unistat.to_list()`."""
        return self._unistat.to_list()

    def init_unistat(self):
        """Register every signal from `self._signals` according to its type."""
        for signal_name, signal_type in self._signals:
            if signal_type == 'aggregation':
                self._unistat.drill_float_aggregation(signal_name)
            elif signal_type == 'aggregation_delta':
                self._unistat.drill_float_aggregation(signal_name, delta=True)
            else:
                self._unistat.drill_float(signal_name, signal_type)

    def flush_unistat(self):
        """Flush every aggregation-type signal."""
        for signal_name, signal_type in self._signals:
            if signal_type == 'aggregation' or signal_type == 'aggregation_delta':
                self._unistat.flush_aggregation(signal_name)


def _get_fresh_state(last_states, namespace, last_known_state, enable_freeze):
    for state in last_states:
        if state.Namespace == namespace:
            if not last_known_state or state.StateId > last_known_state.StateId:
                if _is_actual(state):
                    fresh_state = tables_pb2.TShardCtrlState(
                        Namespace=state.Namespace,
                        StateId=state.StateId,
                        Demand=state.Demand,
                        State=state.State,
                    )
                    if not enable_freeze:
                        fresh_state.Demand.ClearField('Active')
                    return fresh_state


def purge_states(states_table, namespaces):
    """Delete states whose drop target is older than `CLEANUP_AGE` seconds.

    Runs inside a transaction of `states_table`; the states to delete are
    streamed lazily to `states_table.delete()`.
    """
    def _expired(candidates):
        # Lazily yields dropped states that outlived the retention period.
        for state in candidates:
            if _is_actual(state):
                continue
            if _now() - state.Target.DropTimestamp > CLEANUP_AGE:
                yield state

    with states_table._transaction():
        listed = iter(states_table.list(namespaces))
        states_table.delete(_expired(listed))


def _sort_states(states, key=lambda s: (s.Namespace, s.StateId)):
    return sorted(states, key=key, reverse=True)


def _grouped_states(states, group_key=lambda s: s.Namespace, sort_key=lambda s: (s.Namespace, s.StateId)):
    """Yield `(key, states)` pairs for runs of equal `group_key` values.

    States are first ordered with `_sort_states(states, key=sort_key)`;
    consecutive states sharing the same group key form one group.  Nothing is
    yielded for empty input.
    """
    current_key = None
    bucket = []
    for state in _sort_states(states, key=sort_key):
        state_key = group_key(state)
        if current_key and current_key != state_key:
            # The key changed: hand over the finished group and start a new one.
            yield current_key, bucket
            bucket = []
        current_key = state_key
        bucket.append(state)
    if bucket:
        yield current_key, bucket


def _update_target(states, active_states_limit, fresh_state_exists,
                   progress_timeout=STARTUP_INTERVAL):
    """Recalculate targets of `states` in place: drop stuck/skipped/excess states, promote the rest.

    States are processed in the given order (the caller in this file passes
    one namespace's states sorted by StateId, greatest first).

    :param states: states to process; their `Target` fields are mutated.
    :param active_states_limit: once a stably active state is collected and
        this many active states are counted, further non-force-active states
        are revoked.
    :param fresh_state_exists: when true, the stuck check is applied to the
        first state as well; otherwise the first state is exempt from it.
    :param progress_timeout: seconds a targeted step may stay unconfirmed
        before the state is considered stuck.
    """
    # Force-active (Demand.Active) states are counted as active from the start.
    active_states = [state for state in states if _is_force_active(state)]

    first_state_is_checked = False

    for state in states:
        if fresh_state_exists or first_state_is_checked:
            # The third argument disables the "stuck on activation" check while
            # no stably active state has been collected in active_states yet.
            _remove_stuck(state, progress_timeout, not _contains_stable_active(active_states))

        if _is_actual(state):
            if _is_skipped(state):
                _log.warning('Skip state (%s, %s) on demand',
                             state.Namespace, state.StateId)
                _revoke(state)

        # Re-checked on purpose: the state may have just been revoked above.
        if _is_actual(state):
            _promote(state)

            if not _contains_stable_active(active_states) or len(active_states) < active_states_limit:
                # NOTE(review): a force-active state is already in active_states and is
                # appended again here when it has an activate target - confirm intended.
                if _is_active(state):
                    active_states.append(state)
            elif not _is_force_active(state):
                _log.info('Retire state (%s, %s)', state.Namespace, state.StateId)
                _revoke(state)

        if not first_state_is_checked:
            first_state_is_checked = True


def _is_place_available(namespace_states, states_limit):
    assert len(set(state.Namespace for state in namespace_states)) <= 1, 'Multiple namespaces'
    actual_states = filter(
        _is_actual,
        namespace_states
    )
    return len(actual_states) < states_limit


def _is_fresh(state):
    return not state.Target.PrepareTimestamp


def _is_actual(state):
    return not state.Target.DropTimestamp


def _is_stuck(target_timestamp, status_timestamp, progress_timeout):
    if target_timestamp and not status_timestamp:
        return target_timestamp + progress_timeout < _now()
    return False


def _is_active(state):
    return bool(state.Target.ActivateTimestamp)


def _is_replicate(state):
    return bool(state.Target.ReplicateTimestamp)


def _is_skipped(state):
    return state.Demand.Skip


def _is_force_active(state):
    return state.Demand.Active


def _is_stable_active(state):
    return state.Status.ActivateTimestamp


def _is_replicated(state):
    return bool(state.Status.ReplicateTimestamp)


def _remove_stuck(state, progress_timeout, is_last_active):
    """Revoke the state if its preparation or activation exceeded `progress_timeout`.

    The activation check is skipped when `is_last_active` is true.
    """
    target, status = state.Target, state.Status

    prepare_is_stuck = _is_stuck(target.PrepareTimestamp, status.PrepareTimestamp, progress_timeout)
    if prepare_is_stuck:
        _log.warning('Remove state (%s, %s): could not prepare during %s sec',
                     state.Namespace, state.StateId, progress_timeout)
        _revoke(state)

    if is_last_active:
        return

    activate_is_stuck = _is_stuck(target.ActivateTimestamp, status.ActivateTimestamp, progress_timeout)
    if activate_is_stuck:
        _log.warning('Remove state (%s, %s): could not activate during %s sec',
                     state.Namespace, state.StateId, progress_timeout)
        _revoke(state)


def _promote(state):
    """Advance the state's targets by one step where the previous step is confirmed.

    A prepared state without an activate target gets one; a collected state
    without a prepare target gets one.
    """
    target, status = state.Target, state.Status

    needs_activation = not _is_active(state) and status.PrepareTimestamp
    if needs_activation:
        _log.info('Activate state (%s, %s)', state.Namespace, state.StateId)
        target.ActivateTimestamp = _now()

    needs_preparation = not target.PrepareTimestamp and status.CollectTimestamp
    if needs_preparation:
        _log.info('Prepare state (%s, %s)', state.Namespace, state.StateId)
        target.PrepareTimestamp = _now()


def _revoke(state):
    """Stamp the drop target with the current time unless the state is already dropped."""
    if not _is_actual(state):
        return
    state.Target.DropTimestamp = _now()


def _now():
    return int(time.time())


def _contains_stable_active(states):
    return any([
        _is_stable_active(state)
        for state in states
    ])


class Timer(object):
    """Stopwatch that logs the time elapsed between consecutive marks."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the measurement from the current moment."""
        self._time = time.time()

    def stamp_delta(self, marker=''):
        """Log, under `marker`, the seconds passed since the last reset/stamp and restart."""
        started_at = self._time
        self.reset()
        elapsed = self._time - started_at
        _log.debug('TIMER %s: +%s', marker, elapsed)


# Module-level timer instance; code in this module resets it and calls
# stamp_delta() to log per-step durations at debug level.
global_timer = Timer()

# Module logger. It is bound at the end of the module, which is fine because
# the definitions above only look it up at call time.
_log = logging.getLogger(__name__)
