import os
import traceback
import logging

import sandbox.sandboxsdk.process as sdk_process

from sandbox.sandboxsdk import parameters

from sandbox.projects.common import utils
from sandbox.projects.common import apihelpers
from sandbox.projects.common.search import database as search_database

from sandbox.projects.common.search import settings as media_settings
from sandbox.projects.common.search import config as search_config


class IndexTypeParameter(parameters.SandboxStringParameter):
    """
        String task parameter selecting which images index type to load.

        The value is stored in the task context under ``index_type`` and
        defaults to the main index.
    """

    name = 'index_type'
    description = 'Index type'
    default_value = media_settings.INDEX_MAIN
    # Each entry pairs a human-readable string with a media_settings index
    # constant (presumably (label, value) -- confirm against the SDK).
    choices = [
        ('Main', media_settings.INDEX_MAIN),
        ('Quick', media_settings.INDEX_QUICK),
        ('CBIR', media_settings.INDEX_CBIR_MAIN),
        ('Middle', media_settings.INDEX_MIDDLE),
        ('Main thumbs', media_settings.INDEX_THUMB_MAIN),
        ('Quick thumbs', media_settings.INDEX_THUMB_QUICK),
        ('Related queries', media_settings.INDEX_RQ),
        ('RIM', media_settings.INDEX_RIM),
        ('InvertedIndex', media_settings.INDEX_INVERTED),
        ('Embedding', media_settings.INDEX_EMBEDDING)
    ]


class IndexShardChecker(parameters.LastReleasedResource):
    """
        Resource parameter pointing at the shard checker executable.

        The task syncs this resource and runs it with ``--index <shard dir>``
        when the main index is loaded.
    """

    name = 'shard_checker'
    description = 'Shard checker binary'
    resource_type = 'IMGSEARCH_SHARD_CHECKER_EXECUTABLE'

# Prefix of the per-file size attributes set on the output resource:
# the attribute name is this prefix followed by the file name.
FILE_SIZES_PREFIX = 'file_sizes_'


class ImagesLoadBasesearchDatabase(search_database.BaseLoadDatabaseTask):
    """
        Download database for Yandex.Images service
    """

    # On top of the base task's on_execute this task:
    #   * stores total / memory-mapped / per-file sizes as attributes of the
    #     output resource,
    #   * stores the tier type read from stamp.TAG as a resource attribute,
    #   * runs the shard checker binary for the main index.

    type = 'IMAGES_LOAD_BASESEARCH_DATABASE'

    input_parameters = (IndexTypeParameter, IndexShardChecker) + \
        search_database.BaseLoadDatabaseTask.input_parameters
    execution_space = 150000

    def _get_expected_files(self):
        """
            Return the list of file names expected for the selected index type.

            Index types without an entry in the table below (the middle index
            among them) yield an empty list. A fresh list is returned on every
            call, so callers may modify it freely.
        """
        thumb_files = ('thdb.alive', 'thdb.index', 'thdb.info')
        expected_files = {
            media_settings.INDEX_MAIN: ('indexpanther.key.wad',),
            media_settings.INDEX_QUICK: ('indeximg3',),
            media_settings.INDEX_CBIR_MAIN: ('indexgeompacked', 'indexmeta', 'indexfeatures'),
            media_settings.INDEX_THUMB_MAIN: thumb_files,
            media_settings.INDEX_THUMB_QUICK: thumb_files,
            media_settings.INDEX_RIM: ('rimdb.alive', 'rimdb.index', 'rimdb.info'),
            media_settings.INDEX_INVERTED: ('invertedindexstorage.block.wad',),
            media_settings.INDEX_EMBEDDING: ('embeddings',),
        }
        index_type = self.ctx[IndexTypeParameter.name]
        return list(expected_files.get(index_type, ()))

    def _get_database_resource_type(self):
        """
            Return the database resource type for the selected index type,
            as reported by media_settings.ImagesSettings.
        """
        return media_settings.ImagesSettings.basesearch_database_resource(self.ctx[IndexTypeParameter.name])

    def on_execute(self):
        """
            Run the base task, then annotate the output resource and check
            the shard.

            The shard checker is executed only for the main index.
        """
        search_database.BaseLoadDatabaseTask.on_execute(self)

        self._compute_statistics()
        self._query_info()

        if self.ctx[IndexTypeParameter.name] == media_settings.INDEX_MAIN:
            self._check_shard()

    def _mmapped_file_names(self):
        """
            Return names of index files listed in the basesearch config
            parameter "Collection/PrefetchIndexFiles".

            The config is taken from the last released config resource for
            the selected index type. An empty tuple is returned when no config
            resource type can be determined or the parameter is empty;
            otherwise a list of the non-empty ';'-separated names.
        """
        index_type = self.ctx[IndexTypeParameter.name]
        try:
            # Middlesearch index type leads to exception
            config_resource_type = media_settings.ImagesSettings.basesearch_config_resource(index_type)
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit still propagate.
            logging.info("No basesearch config resource type for index type %s", index_type)
            return ()

        config_resource = apihelpers.get_last_released_resource(config_resource_type)
        config_resource_path = self.sync_resource(config_resource.id)
        config = search_config.SearchConfig.get_config_from_file(config_resource_path)

        # Note -- expects one collection for proper work
        prefetched_files = config.get_parameter("Collection/PrefetchIndexFiles")
        if not prefetched_files:
            return ()
        # Drop empty names produced by a leading/trailing/doubled ';'.
        return [name for name in prefetched_files.split(";") if name]

    def _compute_statistics(self):
        """
            Store shard size statistics as attributes of the output resource.

            Sets the total size of all files in the shard directory, the
            total size of the memory-mapped files among them, and one
            FILE_SIZES_PREFIX + <file name> attribute per file. All values
            are stored as strings.
        """
        file_sizes = self._compute_file_sizes(self._get_database_local_path())

        mmapped_files = set(self._mmapped_file_names())
        if mmapped_files:
            # Sorted so that the log line does not depend on set ordering.
            logging.info("Memory mapped files: %s", ", ".join(sorted(mmapped_files)))
        else:
            logging.info("No memory mapped files")

        total_size = sum(file_sizes.itervalues())
        mapped_size = sum(size for (name, size) in file_sizes.iteritems() if name in mmapped_files)
        statistics = {
            media_settings.SHARD_TOTAL_SIZE_ATTRIBUTE_NAME: str(total_size),
            media_settings.SHARD_MAPPED_SIZE_ATTRIBUTE_NAME: str(mapped_size)
        }

        for name, size in file_sizes.iteritems():
            statistics[FILE_SIZES_PREFIX + name] = str(size)

        out_resource_id = self.ctx[search_database.OUT_RESOURCE_KEY]
        utils.set_resource_attributes(out_resource_id, statistics)

    def _query_info(self):
        """
            Copy the tier type from the shard's stamp.TAG file to the output
            resource.

            Reads "key=value" lines from stamp.TAG in the shard directory and
            maps:
              TierType --> 'tier_type'
                expected values: 'unk' (tier0), 'grb' (tier1)

            This step is best-effort: any error while reading the file or
            setting the attribute is logged and does not fail the task.
        """
        out_resource_id = self.ctx[search_database.OUT_RESOURCE_KEY]

        try:
            info = {}
            with open(os.path.join(self._get_database_local_path(), "stamp.TAG")) as stamp:
                for line in stamp:
                    # A line without '=' yields the whole line as the key and
                    # an empty value.
                    key, _, value = line.partition("=")
                    info[key.strip()] = value.strip()

            if "TierType" in info:
                utils.set_resource_attributes(out_resource_id, {"tier_type": info["TierType"]})
            else:
                logging.info("Shard does not have TierType entry in stamp.TAG file")
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit are not swallowed by this best-effort step.
            logging.info(traceback.format_exc())

    def _compute_file_sizes(self, dir_path):
        """
            Return a dict mapping file name to size in bytes for every regular
            file located directly in dir_path (subdirectories are skipped and
            not descended into).
        """
        result = {}
        for file_name in os.listdir(dir_path):
            file_path = os.path.join(dir_path, file_name)
            if os.path.isfile(file_path):
                result[file_name] = os.stat(file_path).st_size
        return result

    def _check_shard(self):
        """
            Run the shard checker binary against the downloaded shard
            directory.

            The binary comes from the IndexShardChecker resource parameter.
            NOTE(review): timeout is presumably in seconds and check=True
            presumably fails the task on a non-zero exit code -- confirm
            against sandboxsdk.process.run_process.
        """
        shard_checker_path = self.sync_resource(self.ctx[IndexShardChecker.name])
        sdk_process.run_process(
            [shard_checker_path, "--index", self._get_database_local_path()],
            log_prefix=IndexShardChecker.name,
            timeout=600, wait=True, check=True)


# Module-level alias of the task class (presumably the name the Sandbox
# task loader looks up -- confirm against the SDK).
__Task__ = ImagesLoadBasesearchDatabase
