# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from logging import getLogger

import six

from travel.rasp.bus.scripts.cache_reseter import channels
from travel.rasp.bus.scripts.cache_reseter import keys
from travel.rasp.bus.scripts.cache_reseter.redis_provider import redis_provider


class CacheReseter(object):
    """Removes cached bus data from Redis.

    Two kinds of entries are handled:

    * segment keys (``keys.SegmentsKey``): each matching key is deleted and a
      ``'FORCE'`` message is published to the connector's
      ``channels.SegmentsChannel``;
    * search-result keys (``keys.SearchResultKey``): matching keys are deleted.

    Every public method defaults to ``dry=True``: matching keys are found and
    logged, but nothing is deleted or published.
    """

    # Upper bound on the number of keys passed to one DEL command, so a large
    # reset is split into several moderate commands instead of a single huge
    # (and long-blocking) one.
    _DELETE_BATCH_SIZE = 1000

    def __init__(self, redis_provider, logger):
        """
        :param redis_provider: object whose ``get_write_client()`` returns the
            Redis client used for SCAN / DEL / PUBLISH.
        :param logger: logger used for progress output.
        """
        self._redis_provider = redis_provider
        self._logger = logger

    def reset_segments_for(self, suppliers, dry=True):
        """Delete segment keys of the given suppliers and publish a reset message.

        :param suppliers: connectors whose segment keys should be reset; an
            empty/falsy value selects every segment key.
        :param dry: when true, only log which keys would be deleted.
        """
        self._logger.info('Start resetting the segment cache: %r', suppliers)
        redis = self._redis_provider.get_write_client()

        for raw_segment_key in self._filter_keys(redis, [keys.SegmentsKey.create_universal_key()]):
            self._logger.info('\tFind key: %s', raw_segment_key)

            segment_key = keys.SegmentsKey.parse(raw_segment_key)
            if suppliers and segment_key.connector not in suppliers:
                self._logger.info('\t\tSkipping the key %s', segment_key)
                continue
            self._logger.info('\tKey %s has to be deleted', raw_segment_key)
            if dry:
                continue

            channel = channels.SegmentsChannel(segment_key.connector)
            redis.delete(str(segment_key))
            # Report the deletion only once the DEL has actually been issued.
            self._logger.info('\t\t%s has been deleted', segment_key)
            redis.publish(str(channel), 'FORCE')
            self._logger.info('\t\tThe message was published to %s', channel)
        self._logger.info('Finish resetting the segment cache: %r', suppliers)

    def reset_search_data(self, supplier, from_point, to_point, when, dry=True):
        """Delete search results of ``supplier`` for one direction and date.

        The arguments are placed into a ``keys.SearchResultKey`` whose string
        form is used as the SCAN match pattern.

        :param supplier: connector component of the key pattern.
        :param from_point: departure-point component of the key pattern.
        :param to_point: arrival-point component of the key pattern.
        :param when: date component of the key pattern.
        :param dry: when true, only log which keys would be deleted.
        """
        self._logger.info('Start resetting cache by %s:%s:%s for %s', from_point, to_point, when, supplier)
        redis = self._redis_provider.get_write_client()

        direction_filters = [
            keys.SearchResultKey(
                connector=supplier,
                from_point=from_point,
                to_point=to_point,
                when=when,
            ),
        ]
        self._reset_search_results(redis, direction_filters, dry)

        self._logger.info('Finish resetting cache by %s:%s:%s for %s', from_point, to_point, when, supplier)

    def reset_search_data_by_point(self, supplier, point, dry=True):
        """Delete every search result of ``supplier`` that departs from or arrives at ``point``.

        :param supplier: connector component of the key pattern.
        :param point: point matched as either the departure or the arrival.
        :param dry: when true, only log which keys would be deleted.
        """
        self._logger.info('Start resetting cache with %s point for %s connector', point, supplier)
        redis = self._redis_provider.get_write_client()

        filters = [
            keys.SearchResultKey(
                connector=supplier,
                from_point=point,
                to_point='*',
                when='*',
            ),
            keys.SearchResultKey(
                connector=supplier,
                from_point='*',
                to_point=point,
                when='*',
            ),
        ]
        self._reset_search_results(redis, filters, dry)

        self._logger.info('Finish resetting cache with %s point for %s connector', point, supplier)

    def reset_search_data_by_direction(self, supplier, from_point, to_point, dry=True, two_way=False):
        """Delete search results of ``supplier`` for a direction, on any date.

        :param supplier: connector component of the key pattern.
        :param from_point: departure-point component of the key pattern.
        :param to_point: arrival-point component of the key pattern.
        :param dry: when true, only log which keys would be deleted.
        :param two_way: when true, the reverse direction
            (``to_point`` -> ``from_point``) is reset as well.
        """
        self._logger.info('Start resetting cache with direction %s:%s for %s', from_point, to_point, supplier)
        redis = self._redis_provider.get_write_client()

        direction_filters = [
            keys.SearchResultKey(
                connector=supplier,
                from_point=from_point,
                to_point=to_point,
                when='*',
            ),
        ]
        if two_way:
            direction_filters.append(
                keys.SearchResultKey(
                    connector=supplier,
                    from_point=to_point,
                    to_point=from_point,
                    when='*',
                )
            )
        self._reset_search_results(redis, direction_filters, dry)

        self._logger.info('Finish resetting cache with direction %s:%s for %s', from_point, to_point, supplier)

    def _reset_search_results(self, redis, filters, dry):
        """Log and (unless ``dry``) delete the search-result keys matching ``filters``."""
        self._logger.info('Filters: %r', filters)

        search_result_keys = [
            keys.SearchResultKey.parse(raw_search_result_key)
            for raw_search_result_key in self._filter_keys(redis, filters)
        ]

        self._logger.info('\tWe are going to delete %d keys:', len(search_result_keys))
        for k in search_result_keys:
            self._logger.info('\t\t %s', k)

        if not dry:
            self._delete_keys(redis, [str(k) for k in search_result_keys])

    def _delete_keys(self, redis, raw_keys):
        """Delete ``raw_keys`` in chunks of at most ``_DELETE_BATCH_SIZE``; no-op for an empty list."""
        batch_size = self._DELETE_BATCH_SIZE
        for start in range(0, len(raw_keys), batch_size):
            redis.delete(*raw_keys[start:start + batch_size])

    def _filter_keys(self, redis, filters, count=100):
        """Yield, as text, each distinct key matching at least one of ``filters``.

        SCAN may return the same key more than once, and several filters may
        match one key (e.g. the two filters built by
        ``reset_search_data_by_point`` when departure and arrival are both
        ``point``), so keys that were already yielded are skipped.

        :param filters: objects whose ``str()`` is used as the SCAN match pattern.
        :param count: COUNT hint forwarded to ``scan_iter``.
        """
        seen = set()
        for f in filters:
            for k in redis.scan_iter(match=str(f), count=count):
                key = six.ensure_text(k)
                if key in seen:
                    continue
                seen.add(key)
                yield key


# Module-level instance bound to the imported Redis provider and to a logger
# named after this module.
cache_reseter = CacheReseter(redis_provider, getLogger(__name__))
