import logging
import collections
import os
import tempfile
import datetime
from functools import partial
from collections import namedtuple

import google.protobuf.text_format as text_format

import yt.wrapper as yt_wrapper

import infra.callisto.controllers.shard.ctrl as shardctrl
import infra.callisto.controllers.shard.tables as tables
import infra.callisto.controllers.sdk.registry as registry
import infra.callisto.controllers.shard.entities as entities
import infra.callisto.libraries.yt as yt
import infra.callisto.protos.deploy.tables_pb2 as tables_pb2
import infra.callisto.deploy.deployer.plutonium as plutonium

from search.plutonium.deploy.proto.state_schema_pb2 import TStateSchema
from infra.callisto.controllers.shard.utils import make_coordinator, make_instance_providers, make_pod_target, \
    make_pod_status
import adfox_ctrl_config_pb2


# Config message class for the shard controller; passed as `config_type` when
# the controller is registered at the bottom of this module.
TAdfoxShardConfig = adfox_ctrl_config_pb2.TAdfoxShardConfig
# (cluster, stage) pair; EnginePodsProvider builds one per configured pod set
# from its `Cluster` and `Stage` fields.
EngineContourPodSet = namedtuple('EngineContourPodSet', ['cluster', 'stage'])
# (cluster, path) pair describing one place the state schema may be downloaded from.
Location = namedtuple('Location', ['cluster', 'path'])
# Default `ready_threshold` handed to the base class by EngineContour.
# NOTE(review): the name is spelled "COUNTOUR" (sic); it is referenced as a
# default argument below, so it is left as-is. Presumably 1.0 means "all pods"
# -- confirm against shardctrl.RsProxy.
COUNTOUR_READY_THRESHOLD = 1.0
# Fallback heartbeat age limit, in minutes, used when the config does not set
# EngineLivenessMinutes (converted to a timedelta in make_controller).
DEFAULT_ENGINE_LIVENESS_MINUTES = 40
# Fallback YT path of the engine pod list, used when the config does not set
# EngineListPath.
DEFAULT_ENGINE_LIST_PATH = '//home/adfox/engine/list'


class EnginePodsProvider(object):
    """Provide ids of live engine pods listed under a YT node.

    The children of `engine_list_path` are read together with their `value`
    attribute. A pod counts as alive when its value has role `engine`, a stage
    that belongs to one of the configured pod sets, and a heartbeat younger
    than `engine_liveness_timedelta`. The pod id is the first label of the
    pod's `host`.
    """

    def __init__(self, pod_sets, engine_list_path, engine_liveness_timedelta):
        """Remember the configured pod sets and liveness settings.

        :param pod_sets: iterable of objects exposing `Cluster` and `Stage`.
        :param engine_list_path: YT path whose children describe engine pods.
        :param engine_liveness_timedelta: `datetime.timedelta`; a pod whose
            heartbeat is at least this old is treated as dead.
        """
        # NOTE(review): this rewrites the module-level yt_wrapper config, so
        # every later user of the default yt_wrapper client in this process is
        # pointed at the 'locke' proxy as well -- confirm this is intended.
        yt_wrapper.config['proxy']['url'] = 'locke'
        self._pod_sets = [
            EngineContourPodSet(pod_set.Cluster, pod_set.Stage)
            for pod_set in pod_sets
        ]
        self._engine_list_path = engine_list_path
        self._engine_liveness_timedelta = engine_liveness_timedelta

    def _pod_is_alive(self, pod_node, current_ts, stages_list):
        """Tell whether a single pod description passes the liveness filter.

        :param pod_node: mapping read from the pod's `value` attribute.
        :param current_ts: naive `datetime` expressing the current time in UTC.
        :param stages_list: collection of acceptable stage names.
        :return: True when the pod is an engine of a known stage with a fresh
            heartbeat, False otherwise.
        """
        # Entries lacking 'role' or 'stage' are simply "not alive", in the same
        # way a missing heartbeat already is; one malformed entry must not
        # abort the whole listing with a KeyError.
        if pod_node.get('role') != 'engine' or pod_node.get('stage') not in stages_list:
            return False

        # Both sides of the subtraction are naive UTC datetimes, so the result
        # is the real heartbeat age even when the local UTC offset changes
        # between the heartbeat and now (e.g. a DST switch). The naive-UTC
        # helpers keep the code usable on Python 2 as well as Python 3.
        heartbeat = datetime.datetime.utcfromtimestamp(pod_node.get('heartbeat_timestamp', 0))
        return (current_ts - heartbeat) < self._engine_liveness_timedelta

    def _filter_pods_by_liveness(self, pods_list):
        """Return a list with the ids of the alive pods among `pods_list`.

        A pod that passes the liveness check must carry a `host`; its id is
        the part of the host name before the first dot.
        """
        stages_list = [pod_set.stage for pod_set in self._pod_sets]
        current_ts = datetime.datetime.utcnow()

        # A real list (not a lazy `map` object) so the result can be iterated
        # more than once regardless of the Python version.
        return [
            pod['host'].split('.')[0]
            for pod in pods_list
            if self._pod_is_alive(pod, current_ts=current_ts, stages_list=stages_list)
        ]

    @property
    def ids(self):
        """Ids of the currently alive engine pods; reads YT on every access."""
        pods = yt_wrapper.get(self._engine_list_path, attributes=['value'])
        return self._filter_pods_by_liveness(
            [engine_pod.attributes['value'] for engine_pod in pods.values()]
        )


class ChunkCtrlMock(object):
    """Stateless stand-in for a chunk controller.

    Every state is reported as active, targets and status updates are
    ignored, and no topology is ever produced.
    """

    def __init__(self):
        """Nothing to initialise: the mock keeps no state."""

    def is_active(self, state):
        """Report `state` as active, whatever it is."""
        return True

    def set_target(self, states):
        """Accept the target states and discard them."""
        return None

    def update_status(self, namespaces):
        """Accept the namespaces and do nothing with them."""
        return None

    def get_topology(self, states):
        """Return None: the mock has no topology to offer."""
        return None


class CustomPlutoniumFS(shardctrl.PlutoniumFS):
    """PlutoniumFS extension exposing the files needed by the engine contour.

    It selects `shard/` and `chunk/` files and the `kvrs.chunkctrl.schema`
    file out of the runtime meta of a state, and packs the former into a
    compound source with paths rewritten relative to given source roots.
    `_runtime_meta` and `_make_source` are used as provided by the base class.
    """

    def __init__(self, yt_client, path, content_cluster=None,
                 fallback_clusters=None, error_cluster_probability=0):
        """Forward all arguments unchanged to `shardctrl.PlutoniumFS`."""
        super(CustomPlutoniumFS, self).__init__(yt_client, path, content_cluster, fallback_clusters,
                                                error_cluster_probability)

    def _get_mappings_and_chunks(self, namespace, state_id):
        """Return runtime-meta file descriptors located under `shard/` or `chunk/`."""
        return [
            file_desc
            for file_desc in self._runtime_meta.read(namespace, state_id)
            if file_desc.path.startswith(('shard/', 'chunk/'))
        ]

    def get_state_schema_meta(self, namespace, state_id):
        """Return the single descriptor whose path starts with `kvrs.chunkctrl.schema`.

        :raises RuntimeError: when the state has no such file or more than one.
        """
        state_schema = [file_desc
                        for file_desc in self._runtime_meta.read(namespace, state_id)
                        if file_desc.path.startswith('kvrs.chunkctrl.schema')]
        # An explicit check instead of `assert`: asserts vanish under `-O`,
        # which would turn a missing schema into an IndexError and silently
        # pick the first of several schemas.
        if len(state_schema) != 1:
            raise RuntimeError(
                'Expected exactly one "kvrs.chunkctrl.schema" file for namespace {} state {}, found {}'.format(
                    namespace, state_id, len(state_schema)
                )
            )
        return state_schema[0]

    def _substitute_file_path(self, path, sources_paths_list):
        """Make `path` relative to the first source path that contains it.

        A source path contains `path` only when the match ends on a path
        component boundary, so `shard/1` does not claim `shard/10/file`
        (which a plain prefix test would rewrite to `../10/file`).
        Paths not covered by any source are returned unchanged.
        """
        for source_path in sources_paths_list:
            if not path.startswith(source_path):
                continue

            remainder = path[len(source_path):]
            ends_on_boundary = (
                not source_path
                or source_path.endswith('/')
                or not remainder
                or remainder.startswith('/')
            )
            if ends_on_boundary:
                return os.path.relpath(path, source_path)

        return path

    def get_mappings_and_chunks_preprocessed_resource(self, namespace, state_id, sources_paths_list):
        """Build a compound source from the state's `shard/` and `chunk/` files.

        Each file becomes one internal resource whose `Path` is the file path
        rewritten by `_substitute_file_path` and whose source is the dynamic
        tables source produced by `_make_source`.
        """
        mappings = self._get_mappings_and_chunks(namespace, state_id)
        return shardctrl.sources_pb2.TCompoundSource(
            Sources=[
                shardctrl.sources_pb2.TCompoundSource.TInternalResource(
                    Path=self._substitute_file_path(file_desc.path, sources_paths_list),
                    Source=shardctrl.sources_pb2.TSource(
                        DynamicTables=self._make_source(file_desc)
                    )
                )
                for file_desc in mappings
            ]
        )


class EngineContour(shardctrl.RsProxy):
    """Orchestrated service delivering per-item-type resources to engine pods.

    Two sets of pod providers are kept: the "old" ones are handed to the base
    class, the "new" ones are stored here. The pod ids actually used come from
    the new providers; the old ones are only consulted to log discrepancies
    (the log labels them "SD pods" and "Locke pods" respectively).
    Attributes such as `_plutonium_fs`, `_target_table`, `_status_table`,
    `_pods_providers`, `_status` and the `deliver` method are expected to come
    from `shardctrl.RsProxy`.
    """

    def __init__(self, pods_providers_new, pods_providers_old, plutonium_fs, target_table, status_table,
                 ready_threshold=COUNTOUR_READY_THRESHOLD):
        """Initialise the base class with the old providers and keep the new ones."""
        super(EngineContour, self).__init__(pods_providers_old, plutonium_fs, target_table, status_table, ready_threshold)
        self._pods_providers_new = pods_providers_new

    def activate(self, topologies):
        """Do nothing; `topologies` is ignored."""
        pass

    def _get_state_schema_content(self, state_schema_meta):
        """Download and parse the state schema, trying every known cluster.

        The content cluster is tried first, then each fallback cluster. The
        first location that downloads and parses successfully wins.

        :param state_schema_meta: descriptor providing the `file_id` to fetch.
        :return: the parsed `TStateSchema`.
        :raises RuntimeError: when every location fails.
        """
        content_path = self._plutonium_fs._runtime_content_path
        clusters = [self._plutonium_fs.get_content_cluster()]
        # The fallback list may legitimately be unset (None); treat it as empty.
        clusters.extend(self._plutonium_fs._fallback_clusters or [])
        locations = [Location(cluster, content_path) for cluster in clusters]

        for location in locations:
            # A fresh message per attempt: text_format.Parse fills the message
            # it is given without clearing it first, so anything parsed before
            # a failure must not be carried over into the next attempt.
            state_schema = TStateSchema()
            try:
                with tempfile.NamedTemporaryFile(dir='./') as temporary_destination:
                    plutonium.download(location.cluster,
                                       yt_wrapper._get_token(),
                                       location.path,
                                       state_schema_meta.file_id,
                                       temporary_destination.name)
                    # NOTE(review): the content is read back through the handle
                    # opened before the download; this presumes
                    # plutonium.download writes into the existing file rather
                    # than replacing it -- confirm.
                    text_format.Parse(temporary_destination.read(), state_schema)

                return state_schema
            except Exception as e:
                # Best effort: a failed location is only logged, the next one is tried.
                _log.debug('Plutonium state schema download failed: {}'.format(e))

        raise RuntimeError('Error getting "kvrs.chunkctrl.schema" from runtime content for all locations.')

    def prepare(self, states):
        """Build per-pod targets for every resource mapping of `states` and deliver them.

        For each state the schema is downloaded, and for each of its resource
        mappings a compound resource (with file paths made relative to the
        mapping's source paths) is wrapped into an `ItemTypeTarget` and
        expanded into targets for the current pods.
        """
        resources = []
        # Listing pods queries every provider, so it is done at most once per
        # call; this also gives all targets of one call the same pod snapshot.
        # It stays lazy so that nothing is listed when there is nothing to target.
        pods_ids = None
        for state in states:
            state_schema_meta = self._plutonium_fs.get_state_schema_meta(state.Namespace, state.StateId)
            for resource_mapping in self._get_state_schema_content(state_schema_meta).ResourceMapping:
                item_type = resource_mapping.DataKey.ItemType

                sources_paths_list = [resource.SourcePath for resource in resource_mapping.Resources]
                compound_resource = self._plutonium_fs.get_mappings_and_chunks_preprocessed_resource(
                    state.Namespace, state.StateId, sources_paths_list
                )

                target = entities.ItemTypeTarget(
                    state.Namespace, state.StateId,
                    shardctrl.sources_pb2.TSource(Compound=compound_resource),
                    item_type
                )

                if pods_ids is None:
                    pods_ids = self._list_pods_ids()

                resources.extend(target.get_targets(pods_ids))

        self.deliver(resources)

    def update_status(self, banned_groups=None):
        """Recompute `self._status` from the status and target tables.

        Status rows of the `ItemTypeTarget` namespace that are PREPARED or
        ACTIVE are looked up in the target table; each found target adds its
        pod to `prepared_pods` of its snapshot, and ACTIVE ones additionally to
        `active_pods`. `banned_groups` is accepted but not used.
        """
        prepared = tables_pb2.EDownloadState.PREPARED
        active = tables_pb2.EDownloadState.ACTIVE
        ready_statuses = {
            entities.ItemTypeTarget.namespace: (prepared, active),
        }

        # Target-table keys of the ready rows, grouped by their download status.
        keys_by_status = {prepared: [], active: []}
        for status in self._status_table.list(ready_statuses.keys(), self._list_pods_ids()):
            download_status = status.ResourceState.Status
            if download_status in ready_statuses[status.Namespace]:
                keys_by_status[download_status].append(
                    {'PodId': status.PodId, 'Namespace': status.Namespace, 'LocalPath': status.LocalPath}
                )

        result = collections.defaultdict(self._status.default_factory)

        for target in self._target_table.lookup(keys_by_status[prepared]):
            key = entities.ItemTypeTarget.get_snapshot(target)
            result[key]['prepared_pods'].add(target.PodId)

        for target in self._target_table.lookup(keys_by_status[active]):
            key = entities.ItemTypeTarget.get_snapshot(target)
            result[key]['prepared_pods'].add(target.PodId)  # Active resource is already prepared.
            result[key]['active_pods'].add(target.PodId)

        self._status = result

    def _list_pods_ids(self):
        """Return the set of pod ids reported by the new providers.

        The old providers are queried too, but only so that a mismatch between
        the two sets can be logged.
        """
        pods_set_old = {
            pod_id
            for provider in self._pods_providers
            for pod_id in provider.ids
        }
        pods_set_new = {
            pod_id
            for provider in self._pods_providers_new
            for pod_id in provider.ids
        }

        if pods_set_old != pods_set_new:
            _log.debug("Pods sets differs!")
            _log.debug("SD pods: {}".format(pods_set_old))
            _log.debug("Locke pods: {}".format(pods_set_new))

        return pods_set_new


def make_custom_plutonium_fs(plutonium_fs, plutonium_fs_error_cluster_probability, use_rpc=True):
    """Create a CustomPlutoniumFS from its config section.

    :param plutonium_fs: config object exposing `MetaCluster`, `Path`,
        `ContentCluster` and `FallbackClusters`.
    :param plutonium_fs_error_cluster_probability: passed through as the error
        cluster probability; a falsy value becomes 0.
    :param use_rpc: forwarded to `yt.create_yt_client` for the meta cluster client.
    :return: the configured `CustomPlutoniumFS`.
    """
    meta_client = yt.create_yt_client(plutonium_fs.MetaCluster, use_rpc=use_rpc)
    # Unset optional values fall back to "no fallbacks" / "zero probability".
    fallback_clusters = plutonium_fs.FallbackClusters or []
    error_probability = plutonium_fs_error_cluster_probability or 0

    return CustomPlutoniumFS(
        yt_client=meta_client,
        path=plutonium_fs.Path,
        content_cluster=plutonium_fs.ContentCluster,
        fallback_clusters=fallback_clusters,
        error_cluster_probability=error_probability,
    )


def make_controller(readonly, config):
    """Assemble the shard controller described by `config`.

    The pieces are created in a fixed order: the local state table, the chunk
    controller mock, the engine contour (engine pods provider first, then the
    instance providers, the plutonium FS and the deployer tables) and finally
    the read-only coordinator state table. The order is kept deliberately
    because constructing `EnginePodsProvider` changes the module-level
    `yt_wrapper` configuration.

    :param readonly: forwarded to the local state table and the deployers target table.
    :param config: controller config exposing `ShardConfig`, `EngineStage`,
        `EngineListPath` and `EngineLivenessMinutes`.
    :return: the configured `shardctrl.ShardCtrl`.
    """
    shard_config = config.ShardConfig

    state_table_config = shard_config.ShardCtrlState
    local_states = tables.ShardCtrlState(
        yt.create_yt_client(state_table_config.Cluster, use_rpc=True),
        state_table_config.Path,
        readonly,
        state_table_config.TabletCellBundle or 'cajuper'
    )

    chunk_ctrl = ChunkCtrlMock()

    # Unset optional engine settings fall back to the module defaults.
    liveness_minutes = config.EngineLivenessMinutes or DEFAULT_ENGINE_LIVENESS_MINUTES
    engine_pods_provider = EnginePodsProvider(
        config.EngineStage,
        config.EngineListPath or DEFAULT_ENGINE_LIST_PATH,
        datetime.timedelta(minutes=liveness_minutes)
    )
    instance_providers = make_instance_providers(shard_config.RSProxy)
    plutonium_fs = make_custom_plutonium_fs(
        shard_config.PlutoniumFS, shard_config.PlutoniumFSErrorClusterProbability
    )
    deployers_target = make_pod_target(shard_config.DeployersTarget, readonly)
    deployers_status = make_pod_status(shard_config.DeployersStatus)

    orchestrated_service = EngineContour(
        [engine_pods_provider],
        instance_providers,
        plutonium_fs,
        deployers_target,
        deployers_status,
    )

    coordinator_config = shard_config.CoordinatorStates
    contour_states = tables.ShardCtrlState(
        yt.create_yt_client(coordinator_config.Cluster, use_rpc=True),
        coordinator_config.Path,
        readonly=True
    )

    return shardctrl.ShardCtrl(
        namespaces=shard_config.Namespaces,
        chunk_ctrl=chunk_ctrl,
        orchestrated_service=orchestrated_service,
        states=local_states,
        contour_states=contour_states,
        enable_freeze=shard_config.EnableFreezing or False,
        states_limit=shard_config.StatesLimit or shardctrl.DEFAULT_STATES_LIMIT,
        progress_timeout=shard_config.StateProgressTimeout or shardctrl.STARTUP_INTERVAL,
    )


# Module logger. It is defined below the classes that use it, which works
# because `_log` is only looked up when their methods run, not when they are
# defined.
_log = logging.getLogger(__name__)

# Register the two controller factories of this module with the registry under
# their names. NOTE(review): `sleep_time=4` is presumably the pause between
# controller iterations, in seconds -- confirm against registry.register2.
registry.register2('adfox/saas2/shard', make_controller, sleep_time=4, config_type=TAdfoxShardConfig)
registry.register2('adfox/saas2/coordinator', make_coordinator, sleep_time=4, config_type=shardctrl.TCoordinatorConfig)
