from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk import sandboxapi
from sandbox.sandboxsdk import task

import sandbox.common.types.client as ctc
from sandbox.projects import resource_types
from sandbox.projects.cajuper import resources
from sandbox.projects.common import apihelpers
from sandbox.projects.common.search import settings as media_settings
from sandbox.projects.common import utils
from sandbox.projects.common.search import BaseTestSuperMindTask as supermind_task
from sandbox.projects.common.search import database as search_database

from sandbox.projects.images.resources import ImagesGenerateBasesearchRequests as generate_task
from sandbox.projects.images.metasearch import task as metasearch_task
from sandbox.projects.images.rq.resources import IMAGES_RQ_SEARCH_SHARD_MAP

import numpy
import logging


_SHARD_MAPS = 'shard_maps'
_SHARD_CACHE = 'shard_cache'
_SHARD_CATALOG = 'shard_catalog'
_PLAN_CACHE = 'plan_cache'
_PERFORMANCE_CACHE = 'performance_cache'
_PERFORMANCE_STATS = 'performance_stats'

_PLAN_GENERATOR_GROUP = 'Plan generation options'
_DATABASE_SHARDS_GROUP = 'Database shards'


class OldShardmapParameter(parameters.ResourceSelector):
    # Optional shardmap resource of the baseline ("old") database.
    # Accepts image, fast image, RQ and video base shardmaps.
    group = _DATABASE_SHARDS_GROUP
    required = False
    name = "old_shardmap_resource_id"
    description = "Old shardmap"
    resource_type = [
        resource_types.IMAGES_BASE_SHARDMAP,
        resource_types.IMAGES_FAST_BASE_SHARDMAP,
        IMAGES_RQ_SEARCH_SHARD_MAP,
        resource_types.VIDEO_BASE_SHARDMAP,
    ]

class NewShardmapParameter(parameters.ResourceSelector):
    # Optional shardmap resource of the candidate ("new") database.
    # Accepts the same shardmap types as the old-shardmap parameter.
    group = _DATABASE_SHARDS_GROUP
    required = False
    name = "new_shardmap_resource_id"
    description = "New shardmap"
    resource_type = [
        resource_types.IMAGES_BASE_SHARDMAP,
        resource_types.IMAGES_FAST_BASE_SHARDMAP,
        IMAGES_RQ_SEARCH_SHARD_MAP,
        resource_types.VIDEO_BASE_SHARDMAP,
    ]


class OldProductionIndexState(parameters.ResourceSelector):
    # Optional production index state resource for the "old" side; the task
    # can derive a shardmap from it instead of reading a shardmap resource.
    group = _DATABASE_SHARDS_GROUP
    required = False
    name = "old_production_index_state_resource_id"
    description = "Old production index state"
    resource_type = [resources.ImagesProductionIndexStateResource]


class NewProductionIndexState(parameters.ResourceSelector):
    # Optional production index state resource for the "new" side.
    group = _DATABASE_SHARDS_GROUP
    required = False
    name = "new_production_index_state_resource_id"
    description = "New production index state"
    resource_type = [resources.ImagesProductionIndexStateResource]


class ShardsCountParameter(parameters.SandboxIntegerParameter):
    # Upper bound on the number of shards per index type taken into the test.
    default_value = 1
    group = _DATABASE_SHARDS_GROUP
    description = "Number of shards to test"
    name = "shards_count"


class IndexTypeParameter(parameters.SandboxStringParameter):
    # Index flavour ("main" or "garbage"). Not read by the base task in this
    # file -- presumably consumed by subclasses; verify before relying on it.
    default_value = "main"
    group = _DATABASE_SHARDS_GROUP
    description = "Type of index (main/garbage)"
    name = "index_type"


class ForceShardLoadParameter(parameters.SandboxBoolParameter):
    # When set, existing database resources are not looked up: a download
    # subtask is always created for every shard.
    default_value = False
    group = _DATABASE_SHARDS_GROUP
    description = "Force shard load"
    name = "force_shard_load"


class RequestsNumberParameter(parameters.SandboxIntegerParameter):
    # Number of middle-metasearch requests for plan generation. Not read by
    # the base task in this file -- presumably used by subclasses' hooks.
    name = "requests_number"
    description = "Mmeta request count"
    default_value = 100000
    group = _PLAN_GENERATOR_GROUP


class BasePriemkaDatabaseTask(task.SandboxTask):
    """Base acceptance task comparing search performance on old vs new shards.

    ``on_execute`` runs in phases. Each phase stores its result in
    ``self.ctx``, and a phase whose key is already present is skipped when
    ``on_execute`` runs again:

    1. ``_SHARD_MAPS``: ``{age: {index_type: [shard_name, ...]}}`` for the
       'old' and 'new' ages;
    2. ``_SHARD_CACHE`` / ``_SHARD_CATALOG``: database resource id per shard
       (found in storage or produced by a download subtask);
    3. ``_PLAN_CACHE``: request plan resource ids per shard and plan type;
    4. ``_PERFORMANCE_CACHE``: ids of old-vs-new performance subtasks;
    5. ``_PERFORMANCE_STATS``: per-test result rows shown in the footer.

    Subclasses provide the search-specific pieces by overriding the hooks
    that raise ``NotImplementedError``.
    """

    input_parameters = (
        OldShardmapParameter,
        NewShardmapParameter,
        OldProductionIndexState,
        NewProductionIndexState,
        IndexTypeParameter,
        ShardsCountParameter,
        ForceShardLoadParameter,
        RequestsNumberParameter,
    )

    _OLD_AGE = 'old'
    _NEW_AGE = 'new'
    _ALL_AGES = (_OLD_AGE, _NEW_AGE)

    # Pairs of (name shown in the footer, key in the performance subtask's
    # ctx['new_stats']['diff']).
    _STATS = (
        ("median", "shooting.rps_0.5"),
        ("stddev", "shooting.rps_stddev"),
        ("errors", "shooting.errors"),
    )
    client_tags = ctc.Tag.Group.LINUX

    # --- Test matrix and shard lookup ----------------------------------------
    def _get_tests(self):
        """Return the tests as ``(index_type, query_type, supermind_mult)`` tuples.

        The result is iterated several times, so it must be re-iterable
        (e.g. a list or tuple, not a generator). ``supermind_mult`` may be
        None to run without SuperMind.
        """
        raise NotImplementedError()

    def _get_shard_resource_ids(self, age, index_type):
        """Return an iterator over database resource ids cataloged for *age*/*index_type*."""
        return self.ctx[_SHARD_CATALOG][age][index_type].itervalues()

    def _get_index_type_for_tier(self, tier):
        """Map a shardmap tier to an index type, a list/tuple of them, or None to skip it."""
        raise NotImplementedError()

    # --- Database download ----------------------------------------------------
    def _get_database_task(self, index_type):
        """Return the task type used to download a shard of *index_type*."""
        raise NotImplementedError()

    def _get_database_args(self, index_type):
        """Return a fresh dict of input parameters for the download subtask (it is mutated)."""
        raise NotImplementedError()

    def _get_database_resource(self, index_type):
        """Return the resource type searched for an already downloaded shard."""
        raise NotImplementedError()

    # --- Basesearch -----------------------------------------------------------
    def _get_basesearch_executable(self, index_type):
        """Return the value for the basesearch ``Binary`` parameter."""
        raise NotImplementedError()

    def _get_basesearch_config(self, index_type):
        """Return the value for the basesearch ``Config`` parameter."""
        raise NotImplementedError()

    def _get_basesearch_models(self, index_type):
        """Return the value for the ``ArchiveModel`` parameter, or a falsy value to omit it."""
        raise NotImplementedError()

    def _get_basesearch_snippetizer_index_type(self, index_type):
        """Return the index type of the snippet shards paired with *index_type*."""
        raise NotImplementedError()

    def _get_basesearch_performance_task(self, index_type):
        """Return ``(task_type, basesearch_params, plan_params)`` for the performance subtask.

        ``basesearch_params`` and ``plan_params`` are zipped with the
        (old shard, new shard) pair, so their first entries describe the old
        side and their second entries the new side.
        """
        raise NotImplementedError()

    def _get_basesearch_performance_args(self, index_type, query_type):
        """Return a fresh dict of input parameters for the performance subtask (it is mutated)."""
        raise NotImplementedError()

    def _get_basesearch_shooting_parameters(self, index_type):
        """Return a fresh dict of input parameters for the plan generation subtask (it is mutated)."""
        raise NotImplementedError()

    # --- Middlesearch ---------------------------------------------------------
    def _get_middlesearch_executable(self, index_type):
        """Return the value for the middlesearch ``Binary`` parameter."""
        raise NotImplementedError()

    def _get_middlesearch_config(self, index_type):
        """Return the value for the middlesearch ``Config`` parameter."""
        raise NotImplementedError()

    def _get_middlesearch_data(self, index_type):
        """Return the value for the middlesearch ``Data`` parameter."""
        raise NotImplementedError()

    def _get_middlesearch_models(self):
        """Return the value for the middlesearch ``ArchiveModel`` parameter."""
        raise NotImplementedError()

    def _get_middlesearch_plan(self, meta_index_type, base_index_type):
        """Return the source plan passed to the request generation subtask."""
        raise NotImplementedError()

    def _get_fake_shardmap(self, state):
        """Build ``{index_type: [shard_name, ...]}`` from a production YT_STATE value."""
        raise NotImplementedError()

    def _get_middlesearch_index_type(self):
        """Return the index type of the middlesearch shard (``INDEX_MIDDLE`` by default)."""
        return media_settings.INDEX_MIDDLE

    def on_execute(self):
        """Run every phase not yet recorded in the context, then collect stats."""
        logging.debug("Parsing shardmap resources...")
        if _SHARD_MAPS not in self.ctx:
            shard_maps = {}
            for age in self._ALL_AGES:
                shard_maps[age] = self.__get_shard_map(age)
            self.ctx[_SHARD_MAPS] = shard_maps

        logging.debug("Getting database resources...")
        if _SHARD_CACHE not in self.ctx:
            for args in self.__database_foreach():
                self.__database_subtask(*args)

        logging.debug("Generating request plans...")
        if _PLAN_CACHE not in self.ctx:
            for args in self.__plan_foreach():
                self.__plan_subtask(*args)

        logging.debug("Starting performance tests...")
        if _PERFORMANCE_CACHE not in self.ctx:
            for test in self._get_tests():
                for args in self.__performance_foreach(*test):
                    self.__performance_subtask(*args)

        # NOTE(review): presumably waits for the children and fails on the
        # first failed one -- see sandbox.projects.common.utils.
        utils.check_subtasks_fails(stop_on_broken_children=False, fail_on_first_failure=True)

        stats = self.ctx.setdefault(_PERFORMANCE_STATS, {})
        for test in self._get_tests():
            key = _make_ctx_key(*test)
            results = [self.__performance_stats(*args) for args in self.__performance_foreach(*test)]
            results.append(self.__performance_aggregate(results))
            stats[key] = results

    @property
    def footer(self):
        """Return one result table per test for the task page.

        The footer has three states:
        - before performance subtasks exist, it shows a placeholder;
        - while final stats are not stored, stats are read live from the
          subtasks;
        - afterwards, the stored stats are shown, ordered by test key.
        """
        def _format(fmt, value, default=""):
            return fmt.format(value) if value is not None else default

        def _footer(title, results):
            return {
                "<h4>{}</h4>".format(title): {
                    "header": [
                        {"key": "title",   "title": "Test"},
                        {"key": "status",   "title": "Status"},
                        {"key": "old_shard",   "title": "Old shard"},
                        {"key": "new_shard",   "title": "New shard"},
                        {"key": "median", "title": "Median RPS"},
                        {"key": "stddev", "title": "Standard deviation"},
                        {"key": "errors", "title": "Errors"},
                    ],
                    "body":  {
                        "title": [
                            _format("<a href='/task/{0}/view'>{0}</a>", v.get("task_id"), v.get("title"))
                            for v in results
                        ],
                        "status": [utils.colored_status(v.get("task_status", "")) for v in results],
                        "old_shard": [v.get("old_shard", "") for v in results],
                        "new_shard": [v.get("new_shard", "") for v in results],
                        "median": [_format("{:0.2f}%", v.get("median")) for v in results],
                        "stddev": [_format("{:0.2f}%", v.get("stddev")) for v in results],
                        "errors": [_format("{}%", v.get("errors")) for v in results],
                    },
                }
            }

        items = []

        if _PERFORMANCE_CACHE not in self.ctx:
            items.append({"&nbsp;": "Calculating..."})
        elif _PERFORMANCE_STATS not in self.ctx:
            for test in self._get_tests():
                results = [self.__performance_stats(*args) for args in self.__performance_foreach(*test)]
                items.append(_footer(_make_ctx_key(*test), results))
        else:
            stored_stats = self.ctx[_PERFORMANCE_STATS]
            # Sort so the tables keep a stable order between page loads.
            for title in sorted(stored_stats):
                items.append(_footer(title, stored_stats[title]))

        return [{"content": item} for item in items]

    def _get_index_types(self):
        """Return the set of base index types used by the tests."""
        return set(index_type for index_type, query_type, supermind_mult in self._get_tests())

    def __database_foreach(self):
        """Yield ``(age, index_type, shard_name)`` for every shard to fetch.

        For every tested index type and age, this yields:
        - the first middlesearch shard (if any);
        - up to ``shards_count`` base shards;
        - up to ``shards_count`` snippet shards.

        Duplicates are possible and are handled by the shard cache.
        """
        shard_count = self.ctx[ShardsCountParameter.name]
        meta_index_type = self._get_middlesearch_index_type()
        for index_type in self._get_index_types():
            for age in self._ALL_AGES:
                middle_shards = self._get_shard_name(age, meta_index_type)
                if middle_shards:
                    yield age, meta_index_type, middle_shards[0]

                for base_shard in self._get_shard_name(age, index_type)[:shard_count]:
                    yield age, index_type, base_shard

                snippetizer_index_type = self._get_basesearch_snippetizer_index_type(index_type)
                for snip_shard in self._get_shard_name(age, snippetizer_index_type)[:shard_count]:
                    yield age, snippetizer_index_type, snip_shard

    def __database_subtask(self, age, index_type, shard_name):
        """Record a database resource id for *shard_name* in the shard cache and catalog.

        The shard is reused from the context cache, or otherwise from the last
        sandbox resource tagged with the shard name (unless force load is set).
        Failing both, a download subtask is created and the id of the
        resource it will produce is recorded.
        """
        shard_cache = self.ctx.setdefault(_SHARD_CACHE, {})
        shard_catalog = self.ctx.setdefault(_SHARD_CATALOG, {})

        def _update_cache(shard_resource_id):
            shard_cache[shard_name] = shard_resource_id
            shard_catalog.setdefault(age, {}).setdefault(index_type, {})[shard_name] = shard_resource_id

        if shard_name in shard_cache:
            logging.debug('shard: {} is cached in ctx'.format(shard_name))
            return _update_cache(shard_cache[shard_name])

        force = self.ctx[ForceShardLoadParameter.name]
        if not force:
            logging.debug('searching for database resource in sandbox storage...')

            database_resource = apihelpers.get_last_resource_with_attribute(
                self._get_database_resource(index_type),
                media_settings.SHARD_INSTANCE_ATTRIBUTE_NAME,
                shard_name
            )
            if database_resource:
                logging.debug('database resource found. Updating cache and getting resource from cache')
                return _update_cache(database_resource.id)

        logging.debug('don\'t use database resource from cache/sandbox. Creating download subtask instead')
        sub_ctx = self._get_database_args(index_type)
        sub_ctx.update({
            search_database.ShardNameParameter.name: shard_name,
        })
        sub_task = self.create_subtask(
            task_type=self._get_database_task(index_type),
            description='{}, {}, {}'.format(self.descr, index_type, shard_name),
            input_parameters=sub_ctx
        )
        # The resource id is known upfront, before the subtask finishes.
        return _update_cache(sub_task.ctx[search_database.OUT_RESOURCE_KEY])

    def _get_plan_descriptions(self):
        """Yield ``(index_type, "qt1,qt2,...")``: sorted, comma-joined query types per index type."""
        plan_data = {}
        for index_type, query_type, supermind_mult in self._get_tests():
            plan_data.setdefault(index_type, set()).add(query_type)
        for index_type, query_types in plan_data.iteritems():
            yield index_type, ",".join(sorted(query_types))

    def __plan_foreach(self):
        """Yield ``(index_type, plan_types, middle_shard, base_shard, snip_shard)`` per plan subtask.

        Base and snippet shards are paired by position. ``middle_shard`` is
        None when the shardmap has no middlesearch shard.
        """
        shard_count = self.ctx[ShardsCountParameter.name]
        meta_index_type = self._get_middlesearch_index_type()
        for index_type, plan_types in self._get_plan_descriptions():
            snippetizer_index_type = self._get_basesearch_snippetizer_index_type(index_type)
            for age in self._ALL_AGES:
                middle_shard = self._get_shard_name(age, meta_index_type)
                middle_shard = middle_shard[0] if middle_shard else None
                base_shards, snip_shards = [
                    self._get_shard_name(age, shard_type)[:shard_count]
                    for shard_type in (index_type, snippetizer_index_type)
                ]
                for base_shard, snip_shard in zip(base_shards, snip_shards):
                    yield index_type, plan_types, middle_shard, base_shard, snip_shard

    def __plan_subtask(self, base_index_type, plan_types, middle_shard, base_shard, snip_shard):
        """Create a request generation subtask and record its plan resource ids.

        This is skipped when all four plans (all/search/factors/snippets) for
        the shards are already cached.
        """
        plan_cache = self.ctx.setdefault(_PLAN_CACHE, {})
        all_key = _make_ctx_key(base_index_type, "all", base_shard)
        base_key = _make_ctx_key(base_index_type, "search", base_shard)
        factors_key = _make_ctx_key(base_index_type, "factors", base_shard)
        snip_key = _make_ctx_key(base_index_type, "snippets", snip_shard)
        if all(key in plan_cache for key in (all_key, base_key, factors_key, snip_key)):
            return

        def _update_cache(all_plan, base_plan, factors_plan, snip_plan):
            plan_cache[all_key] = all_plan
            plan_cache[base_key] = base_plan
            plan_cache[factors_key] = factors_plan
            plan_cache[snip_key] = snip_plan

        shard_data = self.ctx[_SHARD_CACHE]
        snippetizer_index_type = self._get_basesearch_snippetizer_index_type(base_index_type)
        meta_index_type = self._get_middlesearch_index_type()

        sub_ctx = self._get_basesearch_shooting_parameters(meta_index_type)
        sub_ctx.update({
            metasearch_task.MIDDLESEARCH_PARAMS.Binary.name: self._get_middlesearch_executable(meta_index_type),
            metasearch_task.MIDDLESEARCH_PARAMS.Config.name: self._get_middlesearch_config(meta_index_type),
            metasearch_task.MIDDLESEARCH_PARAMS.Data.name: self._get_middlesearch_data(meta_index_type),

            metasearch_task.BASESEARCH_PARAMS.Binary.name: self._get_basesearch_executable(base_index_type),
            metasearch_task.BASESEARCH_PARAMS.Config.name: self._get_basesearch_config(base_index_type),
            metasearch_task.BASESEARCH_PARAMS.Database.name: shard_data[base_shard],

            metasearch_task.SNIPPETIZER_PARAMS.Binary.name: self._get_basesearch_executable(snippetizer_index_type),
            metasearch_task.SNIPPETIZER_PARAMS.Config.name: self._get_basesearch_config(snippetizer_index_type),
            metasearch_task.SNIPPETIZER_PARAMS.Database.name: shard_data[snip_shard],

            generate_task.PlanParameter.name: self._get_middlesearch_plan(meta_index_type, base_index_type),
            generate_task.DemandedPlanTypesParameter.name: plan_types,
        })

        if meta_index_type == media_settings.INDEX_MIDDLE:
            # NOTE(review): raises KeyError when no middle shard was found
            # (middle_shard is None) -- confirm this cannot happen here.
            sub_ctx.update({
                metasearch_task.MIDDLESEARCH_PARAMS.ArchiveModel.name: self._get_middlesearch_models(),
                metasearch_task.MIDDLESEARCH_PARAMS.Index.name: shard_data[middle_shard],
            })

        basesearch_models = self._get_basesearch_models(base_index_type)
        if basesearch_models:
            sub_ctx[metasearch_task.BASESEARCH_PARAMS.ArchiveModel.name] = basesearch_models
        snippetizer_models = self._get_basesearch_models(snippetizer_index_type)
        if snippetizer_models:
            sub_ctx[metasearch_task.SNIPPETIZER_PARAMS.ArchiveModel.name] = snippetizer_models

        description = "{}, {}, {}, {}, {}".format(
            self.descr,
            base_index_type,
            middle_shard,
            base_shard,
            snip_shard
        )
        sub_task = self.create_subtask(
            task_type=generate_task.ImagesGenerateBasesearchRequests.type,
            description=description,
            input_parameters=sub_ctx,
            arch=sandboxapi.ARCH_LINUX
        )

        return _update_cache(
            sub_task.ctx['all_plan_resource_id'],
            sub_task.ctx['search_plan_resource_id'],
            sub_task.ctx['factors_plan_resource_id'],
            sub_task.ctx['snippets_plan_resource_id'],
        )

    def __performance_foreach(self, index_type, query_type, supermind_mult):
        """Yield the test parameters for each (old, new) base shard pair, paired by position."""
        shard_count = self.ctx[ShardsCountParameter.name]
        old_base_shards, new_base_shards = [
            self._get_shard_name(age, index_type)[:shard_count] for age in self._ALL_AGES
        ]
        for old_base_shard, new_base_shard in zip(old_base_shards, new_base_shards):
            yield index_type, query_type, supermind_mult, old_base_shard, new_base_shard

    def __performance_subtask(self, index_type, query_type, supermind_mult, old_base_shard, new_base_shard):
        """Create an old-vs-new performance subtask and record its id.

        Raises:
            errors.SandboxTaskFailureError: the same test was already created.
        """
        performance_cache = self.ctx.setdefault(_PERFORMANCE_CACHE, {})
        performance_key = _make_ctx_key(index_type, query_type, supermind_mult, old_base_shard, new_base_shard)

        if performance_key in performance_cache:
            raise errors.SandboxTaskFailureError("Same test already exists {}".format(performance_key))

        task_type, basesearch_params, plan_params = self._get_basesearch_performance_task(index_type)
        shard_names = (old_base_shard, new_base_shard)

        sub_ctx = self._get_basesearch_performance_args(index_type, query_type)
        if supermind_mult is not None:
            sub_ctx.update({
                supermind_task.EnableSuperMindParameter.name: True,
                supermind_task.SuperMindModeParameter.name: 'mind',
                supermind_task.MultParameter.name: supermind_mult
            })

        # Both sides share the binary, config and models; only the database
        # and the plan differ, so resolve the shared values once.
        executable = self._get_basesearch_executable(index_type)
        config = self._get_basesearch_config(index_type)
        models = self._get_basesearch_models(index_type)

        shard_data = self.ctx[_SHARD_CACHE]
        plan_data = self.ctx[_PLAN_CACHE]
        for basesearch_param, plan_param, shard_name in zip(basesearch_params, plan_params, shard_names):
            sub_ctx.update({
                basesearch_param.Binary.name: executable,
                basesearch_param.Config.name: config,
                basesearch_param.Database.name: shard_data[shard_name],
                basesearch_param.StartTimeout.name: 2400,
                plan_param.name: plan_data[_make_ctx_key(index_type, query_type, shard_name)],
            })
            if models:
                sub_ctx[basesearch_param.ArchiveModel.name] = models

        description = "{}, {}, mult={}, {} vs {}".format(
            index_type,
            query_type,
            supermind_mult,
            old_base_shard,
            new_base_shard,
        )
        sub_task = self.create_subtask(
            task_type=task_type,
            input_parameters=sub_ctx,
            description=description,
            arch=sandboxapi.ARCH_LINUX,
            model='E5-2650'
        )
        performance_cache[performance_key] = sub_task.id

    def __performance_aggregate(self, results):
        """Return a summary row with the median of every stat present in all *results*.

        An empty *results* list yields a row with only the title, because
        ``numpy.median([])`` would be NaN.
        """
        summary = {"title": "Median"}
        for stat, _ in self._STATS:
            if results and all(stat in v for v in results):
                summary[stat] = float(numpy.median([v[stat] for v in results]))
        return summary

    def __performance_stats(self, index_type, query_type, supermind_mult, old_base_shard, new_base_shard):
        """Fetch one performance subtask and return its footer row.

        The row contains the shard names and the task id/status, plus every
        ``_STATS`` value found in the subtask's ``ctx['new_stats']['diff']``.
        """
        performance_key = _make_ctx_key(index_type, query_type, supermind_mult, old_base_shard, new_base_shard)
        task_id = self.ctx[_PERFORMANCE_CACHE][performance_key]
        # Named sub_task so it does not shadow the module-level `task` import.
        sub_task = channel.sandbox.get_task(task_id)
        task_stats = sub_task.ctx.get('new_stats', {}).get('diff', {})

        stats = {
            "old_shard": old_base_shard,
            "new_shard": new_base_shard,
            "task_id": sub_task.id,
            "task_status": sub_task.status,
        }
        for old_name, new_name in self._STATS:
            if new_name in task_stats:
                stats[old_name] = task_stats[new_name]
        return stats

    @staticmethod
    def __add_index_type(results, tabs, index_type):
        """Append the shard name from column 1 (without any "(...)" suffix) under *index_type*."""
        results.setdefault(index_type, []).append(tabs[1].split("(", 1)[0])

    def __get_shard_map(self, age):
        """Return ``{index_type: [shard_name, ...]}`` for *age*.

        For the old age, when an old production index state resource is set,
        the map is derived from that resource. Otherwise the age's shardmap
        resource is parsed: every line with at least three whitespace-separated
        columns adds column 1 as a shard under the index type(s) returned by
        ``_get_index_type_for_tier(column 2)``.

        Raises:
            errors.SandboxTaskFailureError: the shardmap resource is not set.
        """
        # Check the parameter value: the parameter class itself is always truthy.
        if age == self._OLD_AGE and self.ctx.get(OldProductionIndexState.name):
            logging.debug('Generating fake shardmap since OldProductionIndexState specified, there is no shardmap in production anymore')
            return self.__get_fake_cajuper_shardmap(age)

        shardmap_parameter = OldShardmapParameter if age == self._OLD_AGE else NewShardmapParameter
        shardmap_resource_id = self.ctx.get(shardmap_parameter.name)
        if not shardmap_resource_id:
            raise errors.SandboxTaskFailureError(
                "Parameter '{}' is not specified".format(shardmap_parameter.name)
            )
        shardmap_path = self.sync_resource(shardmap_resource_id)
        results = {}
        with open(shardmap_path) as shardmap_file:
            for line in shardmap_file:
                tabs = line.strip().split()
                if len(tabs) <= 2:
                    continue
                index_types = self._get_index_type_for_tier(tabs[2])
                if isinstance(index_types, (tuple, list)):
                    for index_type in index_types:
                        self.__add_index_type(results, tabs, index_type)
                elif index_types is not None:
                    self.__add_index_type(results, tabs, index_types)
        return results

    def __get_fake_cajuper_shardmap(self, age):
        """Build the shardmap for *age* from its production index state resource.

        The file is scanned for ``<words...> YT_STATE=<state>`` lines
        (presumably ``export YT_STATE=...``). The last state found is passed
        to ``_get_fake_shardmap``. If no state is found, an empty map is
        returned.
        """
        index_state_parameter = OldProductionIndexState if age == self._OLD_AGE else NewProductionIndexState
        index_state_path = self.sync_resource(self.ctx[index_state_parameter.name])
        state = None
        with open(index_state_path) as index_state_file:
            for line in index_state_file:
                # Split on the first '=' only, so values that contain '=' are kept.
                name_part, sep, value = line.strip().partition('=')
                words = name_part.split()
                # Tolerate unrelated lines instead of failing on their word count.
                if sep and words and words[-1] == 'YT_STATE':
                    state = value
        if state is None:
            logging.warning('YT_STATE not found in %s, fake shardmap is empty', index_state_path)
            return {}
        return self._get_fake_shardmap(state)

    def _get_shard_name(self, age, index_type):
        """Return the list of shard names of *index_type* for *age* (KeyError if absent)."""
        return self.ctx[_SHARD_MAPS][age][index_type]


def _make_ctx_key(*args):
    """Helper method to create a key for context"""

    return ','.join(str(a) for a in args)
