# -*- coding: utf-8 -*-
import numbers
import random
import time

from itertools import chain
from multiprocessing.pool import ThreadPool

import pymongo

import mpfs.engine.process

from mpfs.engine.process import get_default_log
from mpfs.config import settings
from mpfs.common.util import logger, chunks
from mpfs.core.services.mulca_service import Mulca
from mpfs.core.services.bazinga_service import OnetimeTask
from mpfs.core.filesystem.cleaner.models import DeletedStid, DeletedStidRetry, StorageCleanCheckStid, DeletedStidSources
from mpfs.metastorage.postgres.query_executer import PGQueryExecuter
from mpfs.core.filesystem.dao.legacy import CollectionRoutedDatabase, is_collection_uses_dao
from mpfs.core.versioning.dao.version_data import VersionDataDAO

# Storage cleaner settings, read once at import time from `settings.storage_cleaner`.
# Master switch: StorageCleanerWorker.run() does nothing when this is false.
STORAGE_CLEANER_ENABLED = settings.storage_cleaner['enabled']
# When true, cleanable stids are only marked with status 'dry_run' instead of being removed from storage.
STORAGE_CLEANER_WORKER_DRY_RUN = settings.storage_cleaner['worker']['dry_run']
# Name of the logger (passed to `logger.get`) that receives one TSKV record per processed stid.
STORAGE_CLEANER_WORKER_TSKV_LOGGER = settings.storage_cleaner['worker']['tskv_logger']
# Number of threads DbChecker uses to query shards/collections in parallel.
STORAGE_CLEANER_WORKER_DB_CHECKER_THREAD_COUNT = settings.storage_cleaner['worker']['db_checker_thread_count']

# Pause between the first and the second DB check, in milliseconds (converted to seconds by the worker).
STORAGE_CLEANER_WORKER_SLEEP_MSECS = settings.storage_cleaner['worker']['sleep_msecs']
# Enables the lock counters (StorageCleanCheckStid records) used by CleaningLocksManager.
STORAGE_CLEANER_WORKER_ENABLE_STID_ACCESS_CHECK = settings.storage_cleaner['worker']['enable_stid_access_check']


class CheckingStid(object):
    """State collected for a single stid while the cleaner decides whether to remove it.

    The flags below are filled in step by step by `CleaningLocksManager`, `DbChecker`
    and `StorageCleanerWorker`; ``None`` means "not checked yet".
    """
    def __init__(self, stid, size=None, stid_source=None):
        self.stid = stid
        # True if this worker created the lock counter, False if a counter already existed.
        self.counter_initialized = None
        # True unless the lock counter was found and still equal to zero at check time.
        self.counter_incremented = None
        # True if the stid was found in any of the checked DB collections.
        self.is_stid_in_db = None
        # Status returned by Mulca().remove (500 on API error), or the 'dry_run' marker.
        self.storage_status = None
        self.size = size
        self.stid_source = stid_source

    @property
    def can_remove_from_storage(self):
        """True when the stid may be deleted from storage.

        Requires that this worker initialized the lock counter, nobody incremented it
        in the meantime and the stid was not found in the DB. Fotki stids are never removed.
        """
        if self.is_fotki_stid:
            return False
        return (self.counter_initialized is True
                and self.counter_incremented is False
                and self.is_stid_in_db is False)

    @property
    def can_remove_from_deleted_stids(self):
        """True when the stid's DeletedStid record can be dropped.

        That is the case once the stid is known to be in the DB or a storage status has
        been recorded for it. Fotki stids are always dropped.
        """
        if self.is_fotki_stid:
            return True
        return self.is_stid_in_db is True or self.storage_status is not None

    @property
    def is_need_retry(self):
        """True when a storage removal was attempted and did not return a 2xx status.

        Only integer statuses count as a real removal attempt: ``None`` (nothing was
        sent) and the ``'dry_run'`` marker never require a retry. Comparing those with
        ints is meaningless on Python 2 and raises TypeError on Python 3.
        """
        if self.is_fotki_stid:
            return False
        status = self.storage_status
        if not isinstance(status, numbers.Integral):
            return False
        return self.can_remove_from_deleted_stids and not (200 <= status < 300)

    @property
    def is_fotki_stid(self):
        """Stids with the '103.' prefix are fotki stids and are never removed from storage here."""
        return self.stid.startswith('103.')

    def set_size(self, force=False):
        """Fetch the stid size from Mulca unless it is already known (or ``force`` is set).

        Mulca API errors are ignored deliberately: in this module the size is only used
        for the TSKV log, so a missing size must not stop cleaning.
        """
        if self.size is not None and not force:
            return

        try:
            self.size = Mulca().get_file_size(self.stid)
        except Mulca.api_error:
            pass

    def remove_from_storage(self):
        """Remove the stid from Mulca and remember the resulting status.

        API errors are recorded as status 500, which makes `is_need_retry` true.

        :raises RuntimeError: if `can_remove_from_storage` is false
        """
        if not self.can_remove_from_storage:
            raise RuntimeError("Cant remove stid %s" % self.stid)
        try:
            self.storage_status = Mulca().remove(self.stid)
        except Mulca.api_error:
            self.storage_status = 500

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.stid)


class CleaningLocksManager(object):
    """Manages cleaning locks: per-stid `StorageCleanCheckStid` records holding a counter.

    The worker creates a zero counter before checking a stid and verifies later that it
    is still zero. Presumably other code increments it when it touches the stid (TODO
    confirm). When ``ENABLED`` is false no records are touched and the flags are set
    as if the lock was acquired and never incremented.
    """
    ENABLED = STORAGE_CLEANER_WORKER_ENABLE_STID_ACCESS_CHECK

    @classmethod
    def init_counters(cls, stids):
        """Initialize the lock counters.

        Walks the `CheckingStid` objects (with ``stid`` filled in). It sets
        ``counter_initialized = True`` when this call creates the counter, and ``False``
        when a counter for the stid already existed.

        :param list[CheckingStid] stids: `CheckingStid` objects with the ``stid`` field filled in
        """
        if not stids:
            return
        if not cls.ENABLED:
            for stid in stids:
                stid.counter_initialized = True
            return

        stid_values = [s.stid for s in stids]
        existing = StorageCleanCheckStid.controller.filter(**{'_id': {'$in': stid_values}})
        # set: O(1) membership checks in the loop below
        already_initialized_stids = {c.stid for c in existing}

        for stid in stids:
            stid.counter_initialized = stid.stid not in already_initialized_stids

        items_to_create = [StorageCleanCheckStid(stid=s.stid, counter=0) for s in stids if s.counter_initialized]
        # every counter may already exist; never issue an empty bulk insert
        if items_to_create:
            StorageCleanCheckStid.controller.bulk_create(items_to_create)

    @classmethod
    def clean_counters(cls, stids):
        """Remove the locks.

        Deletes from the DB the counters of stids with ``counter_initialized == True``,
        i.e. only the counters this worker created.

        :param list[CheckingStid] stids: `CheckingStid` objects with ``stid`` and ``counter_initialized`` filled in
        """
        if not stids:
            return
        if not cls.ENABLED:
            return

        stid_values = [s.stid for s in stids if s.counter_initialized]
        if not stid_values:
            return
        StorageCleanCheckStid.controller.filter(**{'_id': {'$in': stid_values}}).delete()

    @classmethod
    def check_counters(cls, stids):
        """Check the locks.

        For every stid, sets ``counter_incremented = False`` only if this worker
        initialized its counter and the counter is still zero. Otherwise it is ``True``
        (not initialized by us, missing, or non-zero). Non-zero counters are logged.

        :param list[CheckingStid] stids: `CheckingStid` objects with ``stid`` and ``counter_initialized`` filled in
        """
        if not stids:
            return
        if not cls.ENABLED:
            for stid in stids:
                stid.counter_incremented = False
            return

        stid_values = [s.stid for s in stids if s.counter_initialized]
        if stid_values:
            checked_stids = list(StorageCleanCheckStid.controller.filter(**{'_id': {'$in': stid_values}}))
        else:
            checked_stids = []
        zero_counter_stids = {s.stid for s in checked_stids if s.counter == 0}

        for stid in stids:
            stid.counter_incremented = stid.stid not in zero_counter_stids

        for check in checked_stids:
            if check.counter != 0:
                logging_data = {
                    'unixtime': int(time.time()),
                    'stid': check.stid,
                    'counter_value': check.counter,
                }
                get_default_log().info('Found incremented counter stid. '
                                       'STID: "%(stid)s"; counter_value: "%(counter_value)s"; '
                                       'unixtime: "%(unixtime)s";' % logging_data)


class DbChecker(object):
    """Looks up stids in the MPFS databases (all Mongo and Postgres shards)."""
    FAKE_VERSION_DATA_NAME = 'fake_version_data'
    mongo_verifiable_collections = ('user_data', 'trash', 'attach_data', 'misc_data', 'narod_data', 'hidden_data',
                                    'photounlim_data', FAKE_VERSION_DATA_NAME, 'additional_data', 'client_data',
                                    'notes_data')
    postgres_verifiable_collections = ('user_data', 'misc_data', 'storage_duplicates')
    version_data_dao = VersionDataDAO()

    def __init__(self):
        self._is_connections_initialized = False
        self._mapper = None
        self._shards_names = None
        self._routed_db = None

    def is_stids_in_db(self, checking_stids):
        """Check whether stids are present in the MPFS databases.

        Walks the `CheckingStid` objects (with ``stid`` filled in). It sets
        ``is_stid_in_db = True`` if the stid was found in any checked shard/collection,
        and ``False`` otherwise.

        :param list[CheckingStid] checking_stids: stids to check
        :return list[CheckingStid]: the same objects with ``is_stid_in_db`` set
        """
        if not checking_stids:
            return checking_stids

        self._setup_connections()

        # One task per (chunk, shard, collection); chunking bounds the size of each query.
        # https://st.yandex-team.ru/CHEMODAN-38491
        args = []
        for chunk in chunks(checking_stids, 200000):
            # built once per chunk and shared read-only by all tasks of the chunk
            chunk_stids = [checking_stid.stid for checking_stid in chunk]
            args.extend(
                (shard_name, collection_name, chunk_stids)
                for shard_name, collection_name in self._shard_collection_generator()
            )

        pool = ThreadPool(processes=STORAGE_CLEANER_WORKER_DB_CHECKER_THREAD_COUNT)
        try:
            per_task_found = pool.map(self._find_many, args)
        except BaseException:
            # drop queued tasks so a failure does not leave worker threads behind
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

        # merge the per-task results into one set of found stids
        found_stids = set(chain.from_iterable(per_task_found))
        for checking_stid in checking_stids:
            checking_stid.is_stid_in_db = checking_stid.stid in found_stids

        return checking_stids

    def _setup_connections(self):
        """Connect to all shards once.

        Covers the Mongo replica sets (secondary-preferred reads) and the Postgres shards
        reported by PGQueryExecuter. The shard list is shuffled, presumably to spread the
        load between runs.

        :raises RuntimeError: if the Postgres shard list is empty
        """
        if self._is_connections_initialized:
            return

        self._mapper = mpfs.engine.process.dbctl().mapper
        self._routed_db = CollectionRoutedDatabase()
        rspool = self._mapper.rspool

        rspool.clean_connection_pool()
        self._shards_names = set()
        for shard_name in rspool.get_all_shards_names():
            rspool.get_connection_for_rs_name(
                shard_name,
                read_preference=pymongo.ReadPreference.SECONDARY_PREFERRED
            )
            self._shards_names.add(shard_name)

        all_shards = PGQueryExecuter().get_all_shard_ids()
        if not all_shards:
            raise RuntimeError('Received empty shard list from sharpei')
        for shard_name in all_shards:
            self._shards_names.add(shard_name)

        self._shards_names = list(self._shards_names)
        random.shuffle(self._shards_names)

        self._is_connections_initialized = True

    def _find_many(self, args):
        """Find which of ``stids`` are present in one collection of one shard.

        :param tuple args: ``(shard_name, collection_name, stids)`` (packed for ThreadPool.map)
        :return set: found stids
        """
        shard_name, collection_name, stids = args

        try:
            if collection_name == self.FAKE_VERSION_DATA_NAME:
                # Search stids in versions. Only meaningful for Mongo: in PG the version
                # binaries live in storage_files and are found while searching user_data.
                find_method = self._find_in_version_data
            else:
                find_method = self._find_in_files
                if not self._routed_db.is_collection_uses_dao(collection_name):
                    raise RuntimeError('Access to non-DAO collection `%s` using CollectionRoutedDatabase' % collection_name)

            found_stids = set()
            prev_candidate_stids = None
            candidate_stids = set(stids)
            # Deduplication ratio is about 30%, so the limit is three times the number of stids.
            stids_per_query_limit = len(candidate_stids) * 3
            while len(candidate_stids) > 0:
                # defensive guard against an endless loop
                if prev_candidate_stids is not None and prev_candidate_stids == candidate_stids:
                    raise RuntimeError('Candidate stids did not change on last iteration: %s' % candidate_stids)

                found = find_method(shard_name, collection_name, list(candidate_stids), stids_per_query_limit)

                if len(found) == 0:
                    break

                found_stids = found_stids.union(found)
                prev_candidate_stids = set(candidate_stids)
                candidate_stids = candidate_stids.difference(found)

            return found_stids
        except Exception:
            get_default_log().exception(
                'Got exception on db request. Shard "%s"; collection: "%s"; stids: "%s"' % (shard_name, collection_name, stids)
            )
            raise

    def _find_in_version_data(self, shard_name, collection_name, stids, limit):
        """Find stids among file/preview/digest stids of version data on the shard."""
        found = set()
        dao_items_gen = self.version_data_dao.fetch_by_stids(shard_name, stids, limit)
        for dao_item in dao_items_gen:
            found.add(dao_item.file_stid)
            found.add(dao_item.preview_stid)
            found.add(dao_item.digest_stid)
        return found & set(stids)

    def _find_in_files(self, shard_name, collection_name, stids, limit):
        """Find stids in file collections/tables of the shard."""
        docs = self._routed_db[collection_name].find_stids_on_shard(stids, shard_name, limit=limit)
        found = {s['stid'] for doc in docs for s in doc['data']['stids']}
        return found & set(stids)

    def _shard_collection_generator(self):
        """Yield ``(shard_name, collection_name)`` pairs to check.

        Shards named ``disk*`` use the Mongo collection list, all others the Postgres one.
        """
        self._setup_connections()
        for shard_name in self._shards_names:
            if shard_name.startswith('disk'):
                verifiable_collections = self.mongo_verifiable_collections
            else:
                verifiable_collections = self.postgres_verifiable_collections

            for collection_name in verifiable_collections:
                yield shard_name, collection_name


class StorageCleanerWorker(OnetimeTask):
    """Storage cleaner worker.

    Checks a batch of stids and removes from storage those that are not referenced in
    the DB and whose lock counters stayed untouched. Typical usage:
        worker = StorageCleanerWorker([<stid_1>, ..., <stid_N>])
        worker.run()
    """
    log = mpfs.engine.process.get_default_log()
    tskv_log = logger.get(STORAGE_CLEANER_WORKER_TSKV_LOGGER)

    ENABLED = STORAGE_CLEANER_ENABLED
    BAZINGA_TASK_NAME = 'StorageCleanerWorker'
    DRY_RUN = STORAGE_CLEANER_WORKER_DRY_RUN
    # NOTE: despite the name this value is in seconds -- it is passed straight to time.sleep().
    SLEEP_MSECS = STORAGE_CLEANER_WORKER_SLEEP_MSECS / 1000.0

    def __init__(self, stids):
        """
        :param list[str] stids: ids of Mulca files to check and, if safe, delete
        :raises ValueError: if ``stids`` is not a list
        """
        super(StorageCleanerWorker, self).__init__()

        if not isinstance(stids, list):
            raise ValueError('stids must be a list, got %s' % type(stids).__name__)

        self.stids = []
        self.raw_stids = stids
        self._db_checker = None

    def build_command_parameters(self):
        return self.raw_stids

    def run(self):
        """Check all stids, clean the storage and write one TSKV record per stid.

        Errors during the check/clean phase are logged, not raised; in that case no
        TSKV records are written.
        """
        if not self.ENABLED:
            self.log.info('Disabled. Check config.')
            return
        if not self.raw_stids:
            self.log.info('No stids')
            return

        self.stids = self._build_checking_stids()

        self._db_checker = DbChecker()
        try:
            self._run_stids_check()
            self._set_stids_size()
            self._clean_storage()
            self._post_clean()
        except Exception as e:
            self.log.exception('Got exception (%s) on main handler. Stids: "%s"' % (e.__class__.__name__, self.stids))
        else:
            for stid in self.stids:
                self._add_to_log(stid)

    def _build_checking_stids(self):
        """Wrap raw stids into `CheckingStid`, taking size/source from DeletedStid records when present."""
        deleted_stids = {s.stid: s for s in DeletedStid.controller.filter(**{'_id': {'$in': self.raw_stids}})}
        checking_stids = []
        for raw_stid in self.raw_stids:
            deleted = deleted_stids.get(raw_stid)
            if deleted is None:
                checking_stids.append(CheckingStid(raw_stid))
            else:
                checking_stids.append(
                    CheckingStid(raw_stid, size=deleted.size, stid_source=deleted.stid_source)
                )
        return checking_stids

    def _add_to_log(self, stid):
        """Write the TSKV record describing what happened to ``stid``."""
        logging_data = {
            'tskv_format': 'disk-mulca-clean-log',
            'unixtime': int(time.time()),
            'stid': stid.stid,
            'status': stid.storage_status,
            'can_clean': stid.can_remove_from_storage,
            'dry_run': self.DRY_RUN,
            'stid_size': stid.size,
            'stid_source': stid.stid_source
        }

        self.tskv_log.info(logger.TSKVMessage(**logging_data))

    def _run_stids_check(self):
        """Take the locks, check the DB twice and verify the locks; always release the locks."""
        CleaningLocksManager.init_counters(self.stids)
        counter_initialized_stids = [stid for stid in self.stids if stid.counter_initialized]
        try:
            # first check
            self._db_checker.is_stids_in_db(counter_initialized_stids)
            not_in_db_stids = [stid for stid in counter_initialized_stids if not stid.is_stid_in_db]

            # second check after a pause -- presumably to catch stids being written concurrently (TODO confirm)
            time.sleep(self.SLEEP_MSECS)
            self._db_checker.is_stids_in_db(not_in_db_stids)
            not_in_db_stids = [stid for stid in not_in_db_stids if not stid.is_stid_in_db]

            CleaningLocksManager.check_counters(not_in_db_stids)
        finally:
            CleaningLocksManager.clean_counters(counter_initialized_stids)

    def _set_stids_size(self):
        """Fill in missing sizes (best effort, used for the TSKV log)."""
        for stid in self.stids:
            stid.set_size()

    def _clean_storage(self):
        """Remove cleanable stids from storage and queue failed removals for retry.

        In dry-run mode the stids are only marked 'dry_run': neither storage nor the
        retry queue is touched (matching `_post_clean`, which skips DB writes too).
        """
        stids_for_clean = [s for s in self.stids if s.can_remove_from_storage]
        if self.DRY_RUN:
            for stid in stids_for_clean:
                stid.storage_status = 'dry_run'
            return

        for stid in stids_for_clean:
            stid.remove_from_storage()
        stids_for_retry = [DeletedStidRetry(stid=s.stid) for s in stids_for_clean if s.is_need_retry]
        if stids_for_retry:
            DeletedStidRetry.controller.bulk_create(stids_for_retry)

    def _post_clean(self):
        """Drop DeletedStid records that no longer need processing (skipped in dry run)."""
        if self.DRY_RUN:
            return
        remove_stids = [s.stid for s in self.stids if s.can_remove_from_deleted_stids]
        if remove_stids:
            DeletedStid.controller.filter(**{'_id': {'$in': remove_stids}}).delete()

        fotki_stids = [s.stid for s in self.stids if s.is_fotki_stid]
        for fotki_stid in fotki_stids:
            self.log.info('Ignore fotki stid: %s', fotki_stid)
