from __future__ import unicode_literals

import time

import aniso8601
from infra.watchdog.src.lib.metrics import Counter
from nanny_repo import repo_pb2


class ServiceStatus(object):
    """Service status summary strings.

    ``SimpleCountLimitPolicy`` treats both values as settled summaries.
    """
    ONLINE, OFFLINE = 'ONLINE', 'OFFLINE'


def by_transition_time(sn):
    """Sort key returning ``sn.last_transition_time.seconds``."""
    transition = sn.last_transition_time
    return transition.seconds


def build_status_index(ss):
    """Map each entry of ``ss.snapshot`` by its ``id``.

    If several entries share an id, the one appearing last wins.
    """
    return {entry.id: entry for entry in ss.snapshot}


class SimpleCountLimitPolicy(object):
    """Cleanup policy capping how many PREPARED/CREATED snapshots are kept.

    Snapshots are walked in order of most recent ``last_transition_time``
    first. Non-disposable and disposable snapshots are capped independently,
    by the spec's ``snapshots_count`` and ``disposable_count`` respectively.
    When ``stalled_ttl`` is set, snapshots that have stayed in a
    GENERATING/PREPARING status for longer than that duration are counted as
    removable too.
    """
    # Summary values for which count-limit cleanup is evaluated at all.
    SETTLED_SUMMARIES = [ServiceStatus.ONLINE, ServiceStatus.OFFLINE]
    # NOTE: despite the name, these values are matched against Snapshot.target.
    INTERESTING_STATUSES = [repo_pb2.Snapshot.PREPARED, repo_pb2.Snapshot.CREATED]
    # Snapshot targets eligible for the stalled check.
    STALLED_TARGETS = [repo_pb2.Snapshot.PREPARED, repo_pb2.Snapshot.CREATED]
    # Snapshot statuses in which a snapshot may be considered stalled.
    STALLED_STATUSES = [
        repo_pb2.SnapshotStatus.GENERATING,
        repo_pb2.SnapshotStatus.PREPARING,
    ]
    TYPE = repo_pb2.CleanupPolicySpec.SIMPLE_COUNT_LIMIT

    def __init__(self, policy_pb):
        """
        :param policy_pb: policy message whose ``spec.simple_count_limit`` is read
        """
        limits = policy_pb.spec.simple_count_limit
        self.max_count = limits.snapshots_count
        self.disposable_count = limits.disposable_count
        ttl = limits.stalled_ttl
        # An empty ttl string disables the stalled check (see is_stalled).
        self.stalled_ttl_seconds = aniso8601.parse_duration(ttl).total_seconds() if ttl else 0

    def is_stalled(self, sn, sn_status, now_seconds):
        """
        Tell whether a snapshot has been stuck too long in a stalled status.

        :type sn: repo_pb2.Snapshot
        :type sn_status: repo_pb2.SnapshotStatus|None
        :type now_seconds: int
        :rtype: bool
        """
        if not self.stalled_ttl_seconds or sn.target not in self.STALLED_TARGETS:
            return False
        if sn_status is None or sn_status.status not in self.STALLED_STATUSES:
            return False
        elapsed = now_seconds - sn_status.last_transition_time.ToSeconds()
        return elapsed > self.stalled_ttl_seconds

    def _collect_removable_snapshots(self, snapshots, max_count, condition):
        """
        Return ids of matching snapshots beyond the first ``max_count`` ones.

        A snapshot matches when its target is in INTERESTING_STATUSES and
        ``condition(sn)`` is truthy. ``snapshots`` is expected freshest first,
        so the freshest ``max_count`` matches are kept. A ``max_count`` of 0
        marks every match removable.

        :rtype: set
        """
        doomed = set()
        kept = 0
        limit_reached = max_count == 0
        for snapshot in snapshots:
            if snapshot.id in doomed:
                continue
            if snapshot.target not in self.INTERESTING_STATUSES:
                continue
            if not condition(snapshot):
                continue
            if limit_reached:
                doomed.add(snapshot.id)
                continue
            kept += 1
            limit_reached = kept >= max_count
        return doomed

    def count_unprocessed_snapshots(self, s_pb):
        """
        Count the snapshots of ``s_pb`` that this policy would clean up.

        Paused services get a Counter with no increments. Otherwise the Counter
        is incremented once per stalled snapshot. When the service summary is
        in SETTLED_SUMMARIES, it is also incremented by the number of snapshots
        exceeding the count limits.

        :type s_pb: repo_pb2.Service
        """
        counter = Counter()
        if s_pb.status.is_paused.value:
            return counter
        # Most recent transition first, so count limits keep the freshest ones.
        freshest_first = sorted(s_pb.spec.snapshot, key=by_transition_time, reverse=True)
        status_by_id = build_status_index(s_pb.status)
        now = int(time.time())
        # Stalled snapshots are counted and excluded from the count-limit pass.
        survivors = []
        for snapshot in freshest_first:
            if not self.is_stalled(snapshot, status_by_id.get(snapshot.id), now):
                survivors.append(snapshot)
                continue
            counter.inc()
        if s_pb.status.summary.value not in self.SETTLED_SUMMARIES:
            return counter
        regular = self._collect_removable_snapshots(
            survivors,
            max_count=self.max_count,
            condition=lambda item: not item.is_disposable,
        )
        disposable = self._collect_removable_snapshots(
            survivors,
            max_count=self.disposable_count,
            condition=lambda item: item.is_disposable,
        )
        counter.inc(value=len(regular | disposable))
        return counter


# All known cleanup policy classes.
CLEANUP_POLICIES = (SimpleCountLimitPolicy,)
# Lookup from a policy's TYPE value to the class implementing it.
POLICIES_REGISTRY = {policy_cls.TYPE: policy_cls for policy_cls in CLEANUP_POLICIES}
