import collections
import copy
import json
import logging
import pprint
import time

import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.utils.yp_utils as yp_utils
import infra.callisto.controllers.utils.funcs as funcs
import infra.callisto.libraries.yt as yt_utils

import config_pb2  # noqa


# YT cluster holding the lock nodes when settings.Cluster is not set.
DEFAULT_LOCK_CLUSTER = 'locke'
# Lock freshness window in seconds (1.5 hours) when settings.LockTTL is not set.
DEFAULT_LOCK_LIVETIME = 1.5 * 60 * 60


class ReleaseController(sdk.Controller):
    """Propagate resource URLs from source services into YP Deploy stages.

    On every ``execute``:
      1. target resources are read from ``services_table`` for each deploy
         unit's ``SourceService``;
      2. matching resources in each stage spec are patched (and, for changed
         units flagged ``SpeedupSwitchUnderLock``, max-unavailable is
         overridden in locked clusters);
      3. the resources all ready deploy units agree on are written back to
         ``services_table`` under each unit's ``FeedbackService``.
    In readonly mode all writes are only logged.
    """

    def __init__(self, services_table, deploy_units, lock_checker=None,
                 yp_master=yp_utils.YpMasters.xdc, readonly=True):
        """
        :param services_table: object exposing ``head(service)`` (returns a
            dict with a JSON ``'status'`` string, or a falsy value) and
            ``write(service, status)``.
        :param deploy_units: deploy unit configs; fields used are ``Id``,
            ``StageId``, ``SourceService``, ``FeedbackService``,
            ``ResourceSyncList.ResourceIds`` and ``SpeedupSwitchUnderLock``.
        :param lock_checker: object with ``is_locked(cluster)``; when None,
            deploy strategy overrides are skipped.
        :param yp_master: YP master passed to ``yp_utils.client``.
        :param readonly: when True, intended writes are only logged.
        """
        super(ReleaseController, self).__init__()
        self.services_table = services_table
        self.deploy_units = deploy_units
        self.lock_checker = lock_checker
        self._yp_master = yp_master
        self.readonly = readonly
        self.log = logging.getLogger(__name__)

    @property
    def stages(self):
        """Distinct stage ids referenced by the configured deploy units."""
        return frozenset(
            deploy_unit.StageId for deploy_unit in self.deploy_units
        )

    def get_deploy_units(self, stage):
        """Return the configured deploy units that belong to ``stage``."""
        return [
            deploy_unit
            for deploy_unit in self.deploy_units
            if deploy_unit.StageId == stage
        ]

    def execute(self):
        """Run one sync iteration over all stages, then publish feedback."""
        target_resources = self.get_targets()

        for stage in self.stages:
            self.sync_stage(stage, target_resources)

        self._update_feedback()

    def sync_stage(self, stage, target_resources):
        """Patch resource URLs of ``stage`` towards ``target_resources``.

        :param target_resources: ``{source_service: [ResourceSpec, ...]}`` as
            returned by ``get_targets``.
        """
        check_deploy_units = self.get_deploy_units(stage)

        check_deploy_units_ids = [
            deploy_unit.Id
            for deploy_unit in check_deploy_units
        ]

        with self._yp_client() as client:
            patch_builder = ResourcesPatchBuilder()
            patch_builder.fill(
                self._read_stage(stage, check_deploy_units_ids,
                                 patch_builder.selector_templates)
            )
            for deploy_unit in check_deploy_units:
                self.set_target(
                    patch_builder,
                    deploy_unit,
                    target_resources.get(deploy_unit.SourceService),
                )

            if patch_builder.changed_deploy_units:
                updates = [
                    {'path': path, 'value': value}
                    for path, value in patch_builder.get_updates()
                ]
                self._update_stage(client, stage, updates)
                self._override_deploy_strategy(
                    client,
                    stage,
                    patch_builder.changed_deploy_units
                )
            else:
                self.log.info('No diff in stages')

    def _override_deploy_strategy(self, client, stage, changed_deploy_unit_ids):
        """Speed up rollout of changed deploy units in locked clusters.

        For each deploy unit of ``stage`` that was changed and has
        ``SpeedupSwitchUnderLock`` set, every cluster reported locked by
        ``lock_checker`` gets max-unavailable overridden to that cluster's
        own ``pod_count`` for the current spec revision.
        """
        SELECTORS_TEMPLATES = [
            '/spec/deploy_units/{}/revision',
            '/spec/deploy_units/{}/replica_set/per_cluster_settings'
        ]

        if self.lock_checker is None:
            # No lock checker configured: no cluster can be considered locked.
            return

        override_deploy_unit_ids = [
            du.Id
            for du in self.deploy_units
            # Restrict to this stage: deploy unit ids may repeat across stages.
            if du.StageId == stage
            and du.Id in changed_deploy_unit_ids
            and du.SpeedupSwitchUnderLock
        ]
        if not override_deploy_unit_ids:
            return

        deploy_units_state = self._read_stage(
            stage,
            override_deploy_unit_ids,
            SELECTORS_TEMPLATES
        )

        # Lock state depends only on the cluster; query it once per call.
        lock_cache = {}

        def is_locked(cluster):
            if cluster not in lock_cache:
                lock_cache[cluster] = self.lock_checker.is_locked(cluster)
            return lock_cache[cluster]

        for deploy_unit_id, (revision, deploy_policy) in deploy_units_state.items():
            # Group locked clusters by their own pod count so each cluster is
            # overridden with its own size, not that of an arbitrary cluster.
            clusters_by_pod_count = collections.defaultdict(set)
            for cluster, settings in (deploy_policy or {}).items():
                if is_locked(cluster):
                    # NOTE(review): a missing pod_count is treated as 0,
                    # assuming YP omits zero-valued fields — confirm.
                    clusters_by_pod_count[settings.get('pod_count', 0)].add(cluster)

            for pod_count, override_clusters in sorted(clusters_by_pod_count.items()):
                clusters = sorted(override_clusters)
                self.log.debug(
                    'Override strategy for %s @ %s to %s pods for rev %s',
                    deploy_unit_id,
                    clusters,
                    pod_count,
                    revision
                )
                if not self.readonly:
                    client.override_stage_max_unavailable(
                        stage,
                        deploy_unit_id,
                        clusters,
                        pod_count,
                        revision
                    )

    def get_targets(self):
        """Return ``{source_service: [ResourceSpec, ...]}`` from services_table.

        Each source service is queried once; services without a head record
        or with no resources are omitted.
        """
        target_resources = {}
        queried = set()

        for deploy_unit in self.deploy_units:
            service = deploy_unit.SourceService
            if service in queried:
                continue
            queried.add(service)

            head = self.services_table.head(service)
            if not head:
                continue
            resources = parse_nanny_status(head['status'])
            if resources:
                target_resources[service] = resources

        return target_resources

    def set_target(self, patch_builder, deploy_unit, resources):
        """Queue URL updates for resources listed in the unit's sync list."""
        for resource in resources or []:
            if resource.name in deploy_unit.ResourceSyncList.ResourceIds:
                patch_builder.update_resource(deploy_unit.Id, resource.name, resource.url)

    def _update_stage(self, client, stage_id, updates):
        """Apply ``set_updates`` to the stage, or only log them in readonly mode."""
        if not self.readonly:
            client.update_object('stage', stage_id, set_updates=updates)
        else:
            self.log.warning('[readonly] update stage %s\n%s', stage_id, pprint.pformat(updates))

    def _update_feedback(self):
        """Publish agreed-upon resources of ready deploy units.

        For every distinct non-empty ``FeedbackService`` a state holding the
        sorted union of synced resource ids and the common resource URLs
        (from ``_collect_resources``) is built; it is written as JSON to
        ``services_table`` only when it differs from the stored state and the
        controller is not readonly.
        """
        feedback_services = frozenset(
            deploy_unit.FeedbackService
            for deploy_unit in self.deploy_units
            if deploy_unit.FeedbackService
        )

        common_states = self._collect_resources()

        for feedback_service in feedback_services:
            current_state = {
                'RESOURCES_SYNCHRONIZE': sorted({
                    resource
                    for deploy_unit in self.deploy_units
                    if deploy_unit.FeedbackService == feedback_service
                    for resource in deploy_unit.ResourceSyncList.ResourceIds
                }),
                'RESOURCES': {
                    name: {'skynet_id': value}
                    for name, value in (common_states.get(feedback_service) or {}).items()
                }
            }

            head = self.services_table.head(feedback_service) or {}
            # A missing or empty status is compared as an empty state.
            head = json.loads(head.get('status') or '{}')

            if head != current_state:
                self.log.info('Stage %s state differ\n%s\n%s', feedback_service, head, current_state)
                if not self.readonly:
                    self.services_table.write(feedback_service, json.dumps(current_state))

    def _collect_resources(self):
        """Return ``{feedback_service: {resource_id: url}}`` across all stages.

        A feedback service maps to ``{}`` when any of its deploy units is not
        ready, or when ready units disagree on a synced resource URL.
        """
        SELECTORS_TEMPLATES = (
            '/spec/deploy_units/{}/replica_set/per_cluster_settings',
            '/spec/deploy_units/{}/revision',
            '/status/deploy_units/{}/target_revision',
            '/status/deploy_units/{}/ready',
            '/status/deploy_units/{}/progress',
            (
                '/status/deploy_units/{}/current_target/replica_set'
                '/replica_set_template/pod_template_spec/spec'
                '/pod_agent_payload/spec/resources'
            )
        )

        common_resources = {}
        # Feedback services already known to be not ready or in conflict.
        # Tracked separately so a legitimately empty merge does not block.
        blocked = set()

        for stage in self.stages:
            check_deploy_units = {
                deploy_unit.Id: deploy_unit
                for deploy_unit in self.get_deploy_units(stage)
            }

            deploy_units_state = self._read_stage(
                stage,
                list(check_deploy_units),
                SELECTORS_TEMPLATES
            )

            for deploy_unit_id, state in deploy_units_state.items():
                (deploy_policy, spec_revision, target_revision,
                 ready_status, progress, resources) = state
                deploy_unit = check_deploy_units[deploy_unit_id]
                service = deploy_unit.FeedbackService

                if not self.is_ready(deploy_policy, spec_revision, target_revision,
                                     ready_status, progress):
                    self.log.debug('Deploy unit %s is not ready', deploy_unit_id)
                    blocked.add(service)
                    common_resources[service] = {}
                    continue

                if service in blocked:
                    continue

                merged = self._merge_common(
                    resources,
                    deploy_unit.ResourceSyncList.ResourceIds,
                    common_resources.get(service, {})
                )
                if merged is None:
                    blocked.add(service)
                    common_resources[service] = {}
                else:
                    common_resources[service] = merged

        return common_resources

    @staticmethod
    def is_ready(deploy_policy, spec_revision, target_revision, ready_status, progress):
        """Return True if a deploy unit has settled on its spec revision.

        Requires the target revision to equal the spec revision, the ready
        status to be ``'true'``, and the number of not-ready pods to fit within
        the sum of per-cluster ``max_unavailable``. Missing counters are
        treated as 0 (assumed: YP may omit zero-valued fields — confirm).
        """
        if spec_revision != target_revision:
            return False

        global_max_unavailable = sum(
            location.get('deployment_strategy', {}).get('max_unavailable', 0)
            for location in deploy_policy.values()
        )

        pods_not_ready = progress.get('pods_total', 0) - progress.get('pods_ready', 0)

        return (
            ready_status.get('status') == 'true'
            and global_max_unavailable >= pods_not_ready
        )

    def _get_common(self, resources, sync_list, common=None):
        """Merge ``sync_list`` URLs from ``resources`` into a copy of ``common``.

        Returns ``{}`` when a resource URL conflicts with one already in
        ``common``. ``common`` is never modified.
        """
        merged = self._merge_common(resources, sync_list, common or {})
        return {} if merged is None else merged

    @staticmethod
    def _merge_common(resources, sync_list, common):
        """Like ``_get_common`` but return None on conflict instead of ``{}``."""
        merged = copy.deepcopy(common)
        resources = resources or {}

        urls = {
            resource['id']: resource.get('url')
            for resource in resources.get('layers', []) + resources.get('static_resources', [])
        }
        for name in sync_list:
            url = urls.get(name)
            if name not in merged:
                merged[name] = url
            elif merged[name] != url:
                return None

        return merged

    def _yp_client(self):
        """Open a YP client for the configured master."""
        return yp_utils.client(self._yp_master)

    def _read_stage(self, stage, deploy_units, selectors_templates):
        """Read selected fields of several deploy units of one stage.

        Each template is formatted with every deploy unit id.

        :returns: ``{deploy_unit_id: [value, ...]}`` where values follow the
            order of ``selectors_templates``; ``{}`` when no units are given.
        """
        # Materialize: the ids are iterated twice (selectors and zip).
        deploy_units = list(deploy_units)
        if not deploy_units:
            return {}

        with self._yp_client() as client:
            results = client.get_object(
                'stage',
                stage,
                selectors=[
                    selector.format(deploy_unit)
                    for deploy_unit in deploy_units
                    for selector in selectors_templates
                ]
            )
            selectors_number = len(selectors_templates)

            return dict(zip(
                deploy_units,
                [
                    results[i:i + selectors_number]
                    for i in range(0, len(results), selectors_number)
                ]
            ))


# A resource to sync: ``name`` is the resource id, ``url`` the value written
# into the stage spec's resource ``url``; parse_nanny_status always sets
# ``meta`` to None. The typename matches the module-level name so instances
# can be pickled.
ResourceSpec = collections.namedtuple('ResourceSpec', ['name', 'url', 'meta'])


def parse_nanny_status(data):
    """Parse a JSON status string into a list of ResourceSpec.

    ``data`` is expected to be a JSON object mapping resource names to dicts;
    each dict's ``skynet_id`` (None when absent) becomes the resource URL.
    """
    return [
        ResourceSpec(name, description.get('skynet_id'), meta=None)
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        for name, description in json.loads(data).items()
    ]


class ResourcesPatchBuilder(object):
    """Narrow subset of stage spec.

    Holds each deploy unit's ``pod_agent_payload`` resources spec, applies
    URL updates to it in place, and records which resources changed so only
    the touched deploy units are written back.
    """
    SELECTOR_TEMPLATE = '/spec/deploy_units/{}/replica_set/replica_set_template' \
                        '/pod_template_spec/spec/pod_agent_payload/spec/resources'

    def __init__(self):
        # deploy unit id -> resources spec dict (mutated in place by updates)
        self._deploy_resources_specs = {}
        # deploy unit id -> set of resource ids whose URL was changed
        self.changeset = {}

    @property
    def selector_templates(self):
        """Selector templates to pass to the stage reader for ``fill``."""
        return [self.SELECTOR_TEMPLATE]

    @property
    def changed_deploy_units(self):
        """Ids of deploy units that have at least one changed resource."""
        return list(self.changeset)

    def fill(self, values):
        """Load specs from ``{deploy_unit_id: [spec, ...]}``.

        The last value of each list is used; the input lists are not
        modified. An empty list yields a None spec.
        """
        self._deploy_resources_specs = {
            deploy_unit: vals[-1] if vals else None
            for deploy_unit, vals in values.items()
        }

    def update_resource(self, deploy_unit_id, name, value):
        """Set the URL of resource ``name`` in the unit's layers/static resources.

        Falsy values are ignored (deletion is forbidden), as are deploy units
        with no loaded spec. Only actual URL changes are recorded.
        """
        if not value:
            return

        du_spec = self._deploy_resources_specs.get(deploy_unit_id)
        if not du_spec:
            # Unknown deploy unit or spec path absent in the stage.
            return

        for origin in ('layers', 'static_resources'):
            for resource in du_spec.get(origin, []):
                if resource['id'] == name and resource.get('url') != value:
                    resource['url'] = value
                    self.changeset.setdefault(deploy_unit_id, set()).add(name)

    def get_updates(self):
        """Return ``[(path, value), ...]`` for changed units plus a description."""
        return list(self._updates())

    def _updates(self):
        for deploy_unit, spec in self._deploy_resources_specs.items():
            if deploy_unit in self.changeset:
                yield self.SELECTOR_TEMPLATE.format(deploy_unit), spec

        if self.changeset:
            yield '/spec/revision_info/description', 'auto update {}'.format(self.changeset)


class LockChecker(object):
    """Tell whether a datacenter is covered by a fresh YT lock node.

    Each configured lock has a ``Dc`` and a ``Path``; it counts as held when
    the node at ``Path`` exists and was created less than ``ttl`` seconds
    ago.
    """

    def __init__(self, yt_client, locks, ttl):
        self._yt_client = yt_client
        self._locks = locks
        self._ttl = ttl

    def is_locked(self, dc):
        """Return True if any lock configured for ``dc`` exists and is fresh.

        Every lock for the datacenter is checked, so a stale lock node does
        not hide a fresh one configured for the same datacenter.
        """
        return any(
            self._yt_client.exists(lock.Path) and self._is_actual(lock.Path)
            for lock in self._locks
            if lock.Dc == dc
        )

    def _is_actual(self, node):
        """Return True if ``node`` was created less than ``ttl`` seconds ago."""
        create_time = self._yt_client.get_attribute(node, 'creation_time')
        node_age = time.time() - funcs.iso_time_to_timestamp(create_time)
        return node_age < self._ttl


def make_lock_checher(settings):
    """Build a LockChecker from lock settings.

    A YT client is created only when ``settings.Locks`` is non-empty, on
    ``settings.Cluster`` or DEFAULT_LOCK_CLUSTER when that is empty. The TTL
    is ``settings.LockTTL`` or DEFAULT_LOCK_LIVETIME when that is falsy.
    """
    # NOTE: the misspelled name ("checher") is kept as-is; callers outside
    # this view may reference it.
    yt_client = None
    if settings.Locks:
        cluster = settings.Cluster or DEFAULT_LOCK_CLUSTER
        yt_client = yt_utils.create_yt_client(cluster)

    ttl = settings.LockTTL or DEFAULT_LOCK_LIVETIME
    return LockChecker(yt_client, settings.Locks, ttl)
