from __future__ import unicode_literals
import itertools
import logging

import gevent
import yp.data_model

from infra.swatlib.gevent import geventutil as gutil
from infra.mc_rsc.src import consts
from infra.mc_rsc.src import model
from infra.mc_rsc.src import sync_status
from infra.mc_rsc.src import yputil


# Value passed as `reason` to SyncCondition.set_error() when Reflector.start()
# fails to sync with YP.
CLUSTER_UNAVAILABLE_REASON = "CLUSTER_UNAVAILABLE"

# Template for the `message` passed alongside the reason above; the single
# '{}' placeholder is filled with the cluster name via str.format().
CLUSTER_UNAVAILABLE_MESSAGE_TPL = "Controller failed to reach '{}' cluster. Will retry..."


def split_iterable_into_batches(iterable, batch_size):
    """Lazily cut `iterable` into consecutive lists of at most `batch_size` items.

    Every yielded list is non-empty; only the last one may be shorter than
    `batch_size`. An empty iterable (or a `batch_size` of 0) yields nothing.
    """
    source = iter(iterable)
    while True:
        chunk = list(itertools.islice(source, batch_size))
        if not chunk:
            # The source is exhausted: stop the generator.
            return
        yield chunk


class Reflector(object):
    """Keeps a local storage in sync with YP objects of one type in one cluster.

    Every call to start() performs one sync pass at the given timestamp. A
    pass is either incremental (driven by watch events collected since the
    previous successful pass) or full (all matching object ids are selected
    and the objects are re-fetched).

    Subclasses override make_object_to_sync() to wrap raw YP objects before
    they are put into the storage.
    """

    # Maximum number of sync() attempts made by sync_retry() when sync()
    # keeps failing with AssertionError.
    YP_CLIENT_FAILED_ATTEMPTS_NUM = 5
    # Watch event types after which the object has to be (re)fetched.
    UPDATED_WATCH_EVENT_TYPES = (
        yp.data_model.ET_OBJECT_CREATED,
        yp.data_model.ET_OBJECT_UPDATED
    )
    # NOTE(review): not read anywhere in this class (a failed watch is not
    # retried, sync() falls back to a full select instead); kept because
    # code outside of this file may reference it.
    MAX_WATCH_RETRIES = 3

    def __init__(self, name, cluster,
                 obj_type, obj_class, obj_filter, selectors, watch_selectors,
                 use_watches, fetch_timestamps,
                 storage, client,
                 select_ids_batch_size, get_objects_batch_size,
                 watch_time_limit_secs, event_count_limit,
                 sleep_secs, select_threads_count,
                 metrics_registry,
                 full_sync_storage=None):
        """Store the reflector configuration.

        :param name: short reflector name, used in the logger and metric names.
        :param cluster: cluster name, used in the logger and metric names and
            in the sync condition returned by start().
        :param obj_type: passed as `object_type` to every client call.
        :param obj_class: passed as `object_class` to select/get client calls.
        :param obj_filter: passed as `query` to select_object_ids/watch_objects.
        :param selectors: passed as `selectors` to get_objects.
        :param watch_selectors: passed as `selectors` to watch_objects.
        :param use_watches: when true, passes after the first successful one
            try watch_objects before falling back to a full select.
        :param fetch_timestamps: stored as is; this class always requests
            timestamps from get_objects regardless of this value.
        :param storage: main storage; must provide remove(), replace(),
            sync_with_objects() and size().
        :param client: YP client providing select_object_ids(), get_objects(),
            watch_objects() and `loader`.
        :param select_ids_batch_size: passed as `batch_size` to
            select_object_ids.
        :param get_objects_batch_size: maximum number of ids requested by a
            single get_objects call.
        :param watch_time_limit_secs: passed as `time_limit_seconds` to
            watch_objects.
        :param event_count_limit: passed as `event_count_limit` to
            watch_objects.
        :param sleep_secs: stored as is; not used by this class itself.
        :param select_threads_count: size of the greenlet pool used for
            batched get_objects calls.
        :param metrics_registry: registry whose get_gauge(name).timer() is
            used to time sync phases.
        :param full_sync_storage: optional extra storage that is fully
            re-synced on every sync() call.
        """

        super(Reflector, self).__init__()
        self.name = name
        self.cluster = cluster
        self.log = logging.getLogger('{}_reflector({})'.format(self.name,
                                                               self.cluster))
        self.obj_type = obj_type
        self.obj_class = obj_class
        self.obj_filter = obj_filter
        self.selectors = selectors
        self.watch_selectors = watch_selectors
        self.fetch_timestamps = fetch_timestamps
        self.use_watches = use_watches
        self.storage = storage
        self.full_sync_storage = full_sync_storage
        self.client = client
        self.select_ids_batch_size = select_ids_batch_size
        self.get_objects_batch_size = get_objects_batch_size
        self.watch_time_limit_secs = watch_time_limit_secs
        self.event_count_limit = event_count_limit
        self.sleep_secs = sleep_secs
        # Timestamp of the last successful sync of self.storage; None until
        # the first pass succeeds, which forces that pass to be a full one.
        self.last_sync_timestamp = None
        # NOTE(review): only `import gevent` is visible at the top of this
        # file, so this line presumably relies on gevent.pool having been
        # imported by some other module - confirm.
        self._pool = gevent.pool.Pool(select_threads_count)
        self._metrics_registry = metrics_registry

    def _make_timestamped_object(self, obj, timestamps):
        """Return `(object_to_sync, updated_time)` for one fetched object.

        `timestamps` is the per-object mapping returned by get_objects; the
        entry under the "updated time" selector of this object type is
        converted with yputil.cast_yp_timestamp_to_seconds(). `updated_time`
        is None when that entry is missing or falsy.
        """
        selector = self.client.loader.get_updated_time_selector_by_object_type(self.obj_type)
        yp_time = timestamps.get(selector)
        if yp_time:
            updated_time = yputil.cast_yp_timestamp_to_seconds(yp_time)
        else:
            updated_time = None
        return self.make_object_to_sync(obj), updated_time

    def _get_objects_by_ids(self, ids, timestamp):
        """Fetch objects with the given ids as of `timestamp`.

        :param ids: sized collection of object ids (list or set).
        :param timestamp: YP timestamp to read the objects at.
        :returns: list of `(object_to_sync, updated_time)` tuples, see
            _make_timestamped_object(). Ids that do not exist are skipped
            because get_objects is called with `ignore_nonexistent=True`.
        :raises: whatever client.get_objects raises; for batched requests the
            error of a failed greenlet is re-raised by gevent.joinall().
        """
        rv = []
        if len(ids) == 0:
            # Nothing to fetch (e.g. a watch pass that saw no create/update
            # events): do not bother YP with an empty request.
            return rv
        # Arguments shared by every get_objects call below.
        request = dict(object_type=self.obj_type,
                       object_class=self.obj_class,
                       selectors=self.selectors,
                       timestamp=timestamp,
                       ignore_nonexistent=True,
                       fetch_timestamps=True)
        if len(ids) > self.get_objects_batch_size:
            # Too many ids for one request: fetch batches concurrently in the
            # greenlet pool.
            threads = []
            for batch_ids in split_iterable_into_batches(ids, self.get_objects_batch_size):
                t = self._pool.spawn(self.client.get_objects, ids=batch_ids, **request)
                threads.append(t)
            gevent.joinall(threads, raise_error=True)
            for t in threads:
                for obj, timestamps in t.get():
                    rv.append(self._make_timestamped_object(obj, timestamps))
                # We iterate over all objects here, so switch context after
                # every batch (t.get() above does not switch because we joined all
                # threads already).
                gevent.idle()
        else:
            objects = self.client.get_objects(ids=ids, **request)
            for obj, timestamps in objects:
                rv.append(self._make_timestamped_object(obj, timestamps))
        return rv

    def make_object_to_sync(self, obj):
        """Convert a raw fetched object into what is put into the storage.

        The base implementation returns `obj` unchanged; subclasses override
        it to wrap the object.
        """
        return obj

    def sync(self, timestamp):
        """Run one sync pass at `timestamp`.

        1. If `full_sync_storage` is set, it is fully re-synced.
        2. If a previous pass succeeded and watches are enabled, watch events
           since `last_sync_timestamp` are requested and applied to
           `storage`; on success the method returns here.
        3. Otherwise (first pass, watches disabled, or the watch request
           raised) `storage` is fully re-synced.

        `last_sync_timestamp` is advanced to `timestamp` only after `storage`
        was synced successfully. Errors other than a failed watch request
        propagate to the caller.
        """
        if self.full_sync_storage:
            m = 'sync_storage_time_{}_{}'.format(self.cluster, self.name)
            t = self._metrics_registry.get_gauge(m).timer()
            try:
                ids = self.client.select_object_ids(
                    object_type=self.obj_type,
                    object_class=self.obj_class,
                    timestamp=timestamp,
                    query=self.obj_filter,
                    batch_size=self.select_ids_batch_size
                )
                self.sync_storage_full(object_ids=ids,
                                       storage=self.full_sync_storage,
                                       timestamp=timestamp)
            finally:
                t.stop()

        if self.last_sync_timestamp and self.use_watches:
            m = 'watch_storage_time_{}_{}'.format(self.cluster, self.name)
            t = self._metrics_registry.get_gauge(m).timer()
            try:
                events = self.client.watch_objects(
                    object_type=self.obj_type,
                    start_timestamp=self.last_sync_timestamp,
                    timestamp=timestamp,
                    selectors=self.watch_selectors,
                    event_count_limit=self.event_count_limit,
                    query=self.obj_filter,
                    time_limit_seconds=self.watch_time_limit_secs,
                )
            except Exception as e:
                # Only the watch request itself is guarded: errors raised
                # while applying the delta (the `else` branch) propagate.
                self.log.error('failed to watch, will fallback to select: %s', e)
            else:
                self.sync_storage_delta(watch_events=events, timestamp=timestamp)
                self.last_sync_timestamp = timestamp
                # `finally` below still runs and stops the timer.
                return
            finally:
                t.stop()

        # Fallback: first pass, watches are disabled, or the watch request
        # above failed (it is not retried).
        ids = self.client.select_object_ids(
            object_type=self.obj_type,
            object_class=self.obj_class,
            query=self.obj_filter,
            timestamp=timestamp,
            batch_size=self.select_ids_batch_size
        )
        self.sync_storage_full(object_ids=ids, storage=self.storage, timestamp=timestamp)
        self.last_sync_timestamp = timestamp

    def sync_storage_delta(self, watch_events, timestamp):
        """Apply watch events to `self.storage`.

        Removals are applied to the storage immediately, in event order;
        created/updated objects are collected by id, fetched once as of
        `timestamp` and then replaced in the storage.

        :param watch_events: list of events with `timestamp`, `event_type`
            and `object_id` attributes; it is sorted in place.
        :param timestamp: YP timestamp to read updated objects at.
        """
        # NOTE: For events with the same timestamps we must apply remove
        # NOTE: firstly because when we remove and create pods in one
        # NOTE: transaction both remove and create events have the same
        # NOTE: timestamp.
        # NOTE: But events with different timestamps must be processed in
        # NOTE: order of timestamps.
        watch_events.sort(key=lambda e: (e.timestamp, e.event_type != yp.data_model.ET_OBJECT_REMOVED))
        updated_ids = set()
        removed = 0
        for e in gutil.gevent_idle_iter(watch_events):
            if e.event_type in self.UPDATED_WATCH_EVENT_TYPES:
                updated_ids.add(e.object_id)
            elif e.event_type == yp.data_model.ET_OBJECT_REMOVED:
                removed += 1
                removed_time = yputil.cast_yp_timestamp_to_seconds(e.timestamp)
                self.storage.remove(e.object_id, removed_time)
                # If updated_ids includes pod, then pod was updated in one
                # timestamp and remove in another timestamp (later).
                # We must remove pod from updated_ids.
                updated_ids.discard(e.object_id)
        objects = self._get_objects_by_ids(updated_ids, timestamp)
        for obj, updated_time in gutil.gevent_idle_iter(objects):
            self.storage.replace(obj, updated_time)
        self.log.info("storage delta synced: updated %d objects, removed %s, storage size: %d",
                      len(updated_ids), removed, self.storage.size())

    def sync_storage_full(self, object_ids, storage, timestamp):
        """Fetch all `object_ids` as of `timestamp` and sync `storage` with them.

        The fetched `(object, updated_time)` pairs are handed to
        `storage.sync_with_objects()` in a single call.
        """
        objects = self._get_objects_by_ids(object_ids, timestamp)
        storage.sync_with_objects(objs_with_timestamps=objects)
        self.log.info("storage full synced with %d objects: storage size: %d",
                      len(objects), storage.size())

    def sync_retry(self, timestamp):
        """Call sync(), retrying it when it fails with AssertionError.

        At most YP_CLIENT_FAILED_ATTEMPTS_NUM attempts are made, without any
        delay in between; the AssertionError of the last attempt is re-raised.
        Any other exception propagates immediately.
        """
        yp_client_retry = 1
        while True:
            try:
                self.sync(timestamp)
            except AssertionError as e:
                # YpClient may raise AssertionError because of thread-safety
                # problems. Just retry this error. Other errors are retried
                # by YpClient itself. Details:
                # https://st.yandex-team.ru/DEPLOY-3749
                if yp_client_retry >= self.YP_CLIENT_FAILED_ATTEMPTS_NUM:
                    raise
                self.log.warning("retryable error occured at syncing "
                                 "storage with YP (attempt %d from %d): %s",
                                 yp_client_retry, self.YP_CLIENT_FAILED_ATTEMPTS_NUM, str(e))
                yp_client_retry += 1
            else:
                break

    def start(self, timestamp):
        """Run one sync pass (with retries) and report its outcome.

        Never raises for sync failures: any Exception is logged with its
        traceback and turned into an error condition.

        :param timestamp: YP timestamp to sync at.
        :returns: sync_status.SyncCondition for this cluster, marked as
            success, or as error with CLUSTER_UNAVAILABLE_REASON.
        """
        sync_condition = sync_status.SyncCondition(self.cluster)
        m = 'sync_time_{}_{}'.format(self.cluster, self.name)
        t = self._metrics_registry.get_gauge(m).timer()
        try:
            self.sync_retry(timestamp)
        except Exception:
            self.log.exception("failed to sync storage with YP")
            sync_condition.set_error(
                reason=CLUSTER_UNAVAILABLE_REASON,
                message=CLUSTER_UNAVAILABLE_MESSAGE_TPL.format(self.cluster)
            )
        else:
            sync_condition.set_success()
        finally:
            t.stop()
        return sync_condition


class ReplicaSetReflector(Reflector):
    """Reflector bound to YP replica sets.

    Fixes the object type/class to OT_REPLICA_SET / TReplicaSet and wraps
    every fetched object into model.ReplicaSet.
    """

    def __init__(self, name, cluster,
                 obj_filter, selectors, watch_selectors, use_watches,
                 fetch_timestamps, storage, client,
                 select_ids_batch_size, get_objects_batch_size,
                 watch_time_limit_secs, event_count_limit,
                 sleep_secs, select_threads_count,
                 metrics_registry):
        """Forward all settings to Reflector, pinning the object type/class."""
        # Only obj_type/obj_class are decided here; everything else is
        # passed through exactly as received.
        base_kwargs = dict(
            name=name,
            cluster=cluster,
            obj_type=yp.data_model.OT_REPLICA_SET,
            obj_class=yp.data_model.TReplicaSet,
            obj_filter=obj_filter,
            selectors=selectors,
            watch_selectors=watch_selectors,
            use_watches=use_watches,
            fetch_timestamps=fetch_timestamps,
            storage=storage,
            client=client,
            select_ids_batch_size=select_ids_batch_size,
            get_objects_batch_size=get_objects_batch_size,
            watch_time_limit_secs=watch_time_limit_secs,
            event_count_limit=event_count_limit,
            sleep_secs=sleep_secs,
            select_threads_count=select_threads_count,
            metrics_registry=metrics_registry,
        )
        super(ReplicaSetReflector, self).__init__(**base_kwargs)

    def make_object_to_sync(self, obj):
        """Wrap a raw replica set into model.ReplicaSet tagged with this cluster."""
        wrapped = model.ReplicaSet(obj=obj, cluster=self.cluster)
        return wrapped


class MultiClusterReplicaSetReflector(Reflector):
    """Reflector bound to YP multi-cluster replica sets.

    Fixes the object type/class to OT_MULTI_CLUSTER_REPLICA_SET /
    TMultiClusterReplicaSet and wraps every fetched object into
    model.MultiClusterReplicaSet.
    """

    def __init__(self, name, cluster,
                 obj_filter, selectors, watch_selectors, use_watches,
                 fetch_timestamps, storage, client,
                 select_ids_batch_size, get_objects_batch_size,
                 watch_time_limit_secs, event_count_limit,
                 sleep_secs, select_threads_count,
                 metrics_registry):
        """Forward all settings to Reflector, pinning the object type/class."""
        # Only obj_type/obj_class are decided here; everything else is
        # passed through exactly as received.
        base_kwargs = dict(
            name=name,
            cluster=cluster,
            obj_type=yp.data_model.OT_MULTI_CLUSTER_REPLICA_SET,
            obj_class=yp.data_model.TMultiClusterReplicaSet,
            obj_filter=obj_filter,
            selectors=selectors,
            watch_selectors=watch_selectors,
            use_watches=use_watches,
            fetch_timestamps=fetch_timestamps,
            storage=storage,
            client=client,
            select_ids_batch_size=select_ids_batch_size,
            get_objects_batch_size=get_objects_batch_size,
            watch_time_limit_secs=watch_time_limit_secs,
            event_count_limit=event_count_limit,
            sleep_secs=sleep_secs,
            select_threads_count=select_threads_count,
            metrics_registry=metrics_registry,
        )
        super(MultiClusterReplicaSetReflector, self).__init__(**base_kwargs)

    def make_object_to_sync(self, obj):
        """Wrap a raw object into model.MultiClusterReplicaSet tagged with this cluster."""
        wrapped = model.MultiClusterReplicaSet(obj=obj, cluster=self.cluster)
        return wrapped


def make_mc_rs_reflector(deploy_engine,
                         cluster, obj_filter, selectors, watch_selectors, use_watches,
                         fetch_timestamps, storage, client,
                         select_ids_batch_size, get_objects_batch_size,
                         watch_time_limit_secs, event_count_limit,
                         sleep_secs, select_threads_count,
                         metrics_registry):
    """Build the reflector matching `deploy_engine`.

    consts.RSC_DEPLOY_ENGINE gives a ReplicaSetReflector named 'rs',
    consts.MCRSC_DEPLOY_ENGINE gives a MultiClusterReplicaSetReflector named
    'mcrs'. All remaining arguments are forwarded to the chosen class.

    :raises ValueError: if `deploy_engine` is neither of the two engines.
    """
    if deploy_engine == consts.RSC_DEPLOY_ENGINE:
        reflector_cls, reflector_name = ReplicaSetReflector, 'rs'
    elif deploy_engine == consts.MCRSC_DEPLOY_ENGINE:
        reflector_cls, reflector_name = MultiClusterReplicaSetReflector, 'mcrs'
    else:
        raise ValueError('unknown deploy engine {}'.format(deploy_engine))

    # Settings that do not depend on the engine.
    shared_kwargs = dict(
        cluster=cluster,
        obj_filter=obj_filter,
        selectors=selectors,
        watch_selectors=watch_selectors,
        use_watches=use_watches,
        fetch_timestamps=fetch_timestamps,
        storage=storage,
        client=client,
        select_ids_batch_size=select_ids_batch_size,
        get_objects_batch_size=get_objects_batch_size,
        watch_time_limit_secs=watch_time_limit_secs,
        event_count_limit=event_count_limit,
        sleep_secs=sleep_secs,
        select_threads_count=select_threads_count,
        metrics_registry=metrics_registry,
    )
    return reflector_cls(name=reflector_name, **shared_kwargs)
