# coding=utf-8
from __future__ import unicode_literals

import six
from infra.swatlib import pbutil
from infra.watchdog.src.lib.cleanup_policies import POLICIES_REGISTRY
from infra.watchdog.src.lib.metrics import Counter
from infra.watchdog.src.lib.service_pbutil import get_current_and_target, iter_snapshots
from nanny_repo import repo_pb2

# Snapshot states regarded as "settled". count_stuck_snapshots() counts a
# snapshot as stuck when it has no status entry or its current state is one of
# these, while its current state name differs from its target state name.
SETTLED_SNAPSHOT_STATES = (repo_pb2.SnapshotStatus.ACTIVE,
                           repo_pb2.SnapshotStatus.PREPARED,
                           repo_pb2.SnapshotStatus.CREATED)


def count_uncleaned_snapshots(cleanup_policy, s_pb):
    """Count snapshots of ``s_pb`` that the cleanup policy has not processed yet.

    The policy implementation class is looked up in ``POLICIES_REGISTRY`` by
    ``cleanup_policy.spec.type`` and instantiated with ``cleanup_policy``.

    :param cleanup_policy: cleanup policy object; a falsy value means "no policy".
    :param s_pb: service state message, handed to the policy processor as-is.
    :return: the result of the processor's ``count_unprocessed_snapshots``, or
        ``None`` when there is no policy, no implementation is registered for
        its type, or the constructed processor is falsy.
    """
    processor = None
    if cleanup_policy:
        policy_class = POLICIES_REGISTRY.get(cleanup_policy.spec.type)
        if policy_class:
            processor = policy_class(cleanup_policy)
    # One exit point for every "nothing to count" situation above.
    if not processor:
        return None
    return processor.count_unprocessed_snapshots(s_pb)


def count_stuck_snapshots(s_pb):
    """Count snapshots that rest in a settled state without having reached their target.

    A snapshot is counted when its current state name differs from its target
    state name and it either has no status entry (``sn_status is None``) or its
    current state is one of ``SETTLED_SNAPSHOT_STATES``. A paused service
    (``s_pb.status.is_paused.value`` is truthy) always reports zero.

    :param s_pb: service state message exposing ``spec.snapshot``,
        ``status.snapshot`` and ``status.is_paused``.
    :return: a ``Counter`` holding the number of stuck snapshots.
    """
    stuck_snapshots = Counter()
    # The pause flag belongs to the whole service, not to an individual
    # snapshot: evaluate it once instead of on every iteration. A paused
    # service never contributes stuck snapshots.
    if s_pb.status.is_paused.value:
        return stuck_snapshots
    for _, sn, sn_status in iter_snapshots(s_pb.spec.snapshot, s_pb.status.snapshot):
        current, current_name, _, target_name = get_current_and_target(sn, sn_status)
        has_settled_state = sn_status is None or current in SETTLED_SNAPSHOT_STATES
        if current_name != target_name and has_settled_state:
            stuck_snapshots.inc()
    return stuck_snapshots


def get_snapshot_statuses(s_pb):
    """Lazily yield ``(snapshot_id, current_state)`` for each snapshot of ``s_pb``.

    ``current_state`` is the first value returned by ``get_current_and_target``
    for the snapshot's spec entry and its (possibly missing) status entry.
    """
    snapshot_entries = iter_snapshots(s_pb.spec.snapshot, s_pb.status.snapshot)
    for snapshot_id, snapshot_spec, snapshot_status in snapshot_entries:
        current_state, _current_name, _target, _target_name = get_current_and_target(
            snapshot_spec, snapshot_status)
        yield snapshot_id, current_state


def check_awaited_snapshot_states(zk_client, expected_snapshot_states):
    """Sum the weights of snapshots that are currently in an awaited state.

    :param zk_client: client whose ``get_service_state(service_id)`` returns the
        service state message to inspect.
    :param expected_snapshot_states: mapping
        ``service_id -> {snapshot_id -> {state_name -> weight}}``. For every
        snapshot of a listed service whose id appears in the inner mapping,
        the weight stored under the snapshot's current status name is added
        to the result (missing names contribute 0). A snapshot without a
        status entry is looked up under the name ``'DESTROYED'``.
    :return: a ``Counter`` with the accumulated weights.
    """
    count = Counter()
    # Loop-invariant: the enum descriptor used to turn status values into names.
    status_enum = repo_pb2.SnapshotStatus.Status.DESCRIPTOR
    for s_id, snapshots in six.iteritems(expected_snapshot_states):
        s_pb = zk_client.get_service_state(s_id)
        for sn_id, _, sn_status in iter_snapshots(s_pb.spec.snapshot, s_pb.status.snapshot):
            if sn_id not in snapshots:
                continue
            # sn_status may be None for a snapshot without a status entry
            # (count_stuck_snapshots checks `is None` as well); test for None
            # explicitly rather than relying on protobuf message truthiness.
            if sn_status is not None:
                current_name = pbutil.enum_value_to_name(status_enum, sn_status.status)
            else:
                current_name = 'DESTROYED'
            count.inc(value=snapshots[sn_id].get(current_name, 0))
    return count


def collect_snapshots_info(cleanup_policy, s_pb):
    """Run the per-service snapshot checks and return their results.

    :param cleanup_policy: cleanup policy forwarded to ``count_uncleaned_snapshots``.
    :param s_pb: service state message inspected by both checks.
    :return: a ``(stuck_snapshots, unprocessed_cleanup_snapshots)`` pair where
        the first item is check #1 (snapshots of a non-paused service stuck in
        a settled state, from ``count_stuck_snapshots``) and the second is
        check #2.1 (snapshots still pending for the cleanup policy, from
        ``count_uncleaned_snapshots``; may be ``None``).
    """
    stuck = count_stuck_snapshots(s_pb)
    unprocessed = count_uncleaned_snapshots(cleanup_policy, s_pb)
    return stuck, unprocessed
