import cPickle
import logging

from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk import sandboxapi

from sandbox.projects import resource_types
from sandbox.projects.common import decorators
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import utils
from sandbox.projects.common.base_search_quality import basesearch_response_parser
from sandbox.projects.common.base_search_quality import response_saver
from sandbox.projects.common.search.response import state as response_state
from sandbox.projects.common.search import requester_compat as search_requester
from sandbox.projects.common.search import response_patcher
from sandbox.projects.common.base_search_quality.tree.node import WalkPath

# Response-saver parameter set; queries are taken from plain-text query resources.
RESPONSE_SAVER_PARAMS = response_saver.create_response_saver_params(
    queries_resource=(resource_types.PLAIN_TEXT_QUERIES,)
)

# Rate statistics computed by ResponsesTask._calc_stats:
# stat key -> (human-readable title, predicate over one parsed response).
STATS_VARIANTS = dict(
    empty=("Empty responses", response_state.is_answer_empty),
    error=("Errors", response_state.is_answer_erroneous),
    notfetched=("Not fetched", response_state.is_answer_not_fetched),
    unanswer=("Unanswers", lambda response: not response_state.is_answer_full(response)),
)

# "Key" values of SearcherProp nodes filtered out by _remove_unstable_searcher_props.
_UNSTABLE_SEARCHER_PROPS = [
    "SourceTimestamp",  # Sometimes get diffs on Sandbox
]


class MaxEmptyResponsesRate(parameters.SandboxFloatParameter):
    """Upper bound for the share of empty responses accepted by ResponsesTask."""

    name = 'max_empty_responses_rate'
    description = 'Maximum allowed empty responses rate'
    group = response_saver.GROUP_NAME
    # ResponsesTask._verify_stats skips the check for a falsy limit, so the
    # default of 0 means "no limit".
    default_value = 0


class ResponsesTask:
    """Mixin class to get responses from some target.

    Saves the target's responses to the task's queries into a responses
    resource. Optionally re-fetches them ``RecheckResultsNTimesParameter``
    times to verify that the target answers stably. For text (non-binary)
    responses it also computes rate statistics (see ``STATS_VARIANTS``) and
    checks them against ``MaxEmptyResponsesRate``.

    The host task is expected to provide ``ctx``, ``descr``,
    ``create_resource`` and ``sync_resource``.
    """

    # Response-saver params, the empty-rate limit, and requester params.
    input_parameters = (
        RESPONSE_SAVER_PARAMS.params +
        (MaxEmptyResponsesRate,) +
        search_requester.create_params()
    )

    @staticmethod
    def _get_queries_parameter():
        """Hook, used by database autodetection.

        :return: parameter class whose ctx value is the queries resource id
        """
        return RESPONSE_SAVER_PARAMS.QueriesParameter

    def _create_responses_resource(self, resource_key, file_prefix=''):
        """Create a responses resource and store its id in ``self.ctx``.

        :param resource_key: ctx key that receives the new resource id
        :param file_prefix: prefix for the resource file name
            (see ``_get_responses_resource_path``)
        """
        self.ctx[resource_key] = self.create_resource(
            self.descr,
            self._get_responses_resource_path(file_prefix),
            self._get_responses_resource_type(),
            arch=sandboxapi.ARCH_ANY
        ).id

    def _get_responses_resource_type(self):
        """Return the proto or human-readable responses type, per UseBinaryResponses."""
        if self.ctx[RESPONSE_SAVER_PARAMS.UseBinaryResponses.name]:
            return resource_types.SEARCH_PROTO_RESPONSES
        else:
            return resource_types.BASESEARCH_HR_RESPONSES

    def _get_responses_resource_suffix(self):
        """Return the file extension: "bin" for binary responses, "txt" otherwise."""
        if self.ctx[RESPONSE_SAVER_PARAMS.UseBinaryResponses.name]:
            return "bin"
        else:
            return "txt"

    def _get_responses_resource_path(self, file_prefix=''):
        """Return the responses file name, e.g. ``responses.txt`` or ``1-responses.txt``."""
        return '{}responses.{}'.format(file_prefix, self._get_responses_resource_suffix())

    def _get_responses(self, target, resource_key, stats_key=None, prepare_session_callback=None):
        """Fetch responses from ``target`` into the resource whose id is ``ctx[resource_key]``.

        :param target: component that is queried (via its ``use_component``)
        :param resource_key: ctx key holding the responses resource id
            (e.g. one created by ``_create_responses_resource``)
        :param stats_key: if given and responses are text, ctx key that receives
            the statistics computed by ``_calc_stats``
        :param prepare_session_callback: not used by this implementation
        :raises errors.SandboxTaskFailureError: if the empty-responses rate
            exceeds ``MaxEmptyResponsesRate``
        """
        responses_resource = channel.sandbox.get_resource(self.ctx[resource_key])
        queries_path = self.sync_resource(self.ctx[self._get_queries_parameter().name])
        self.__save_responses(queries_path, responses_resource.path, target)
        recheck_n_times = utils.get_or_default(self.ctx, RESPONSE_SAVER_PARAMS.RecheckResultsNTimesParameter)
        self.__recheck_responses(queries_path, responses_resource.path, target, recheck_n_times)

        # Statistics are only computed for text (human-readable) responses.
        if not self.ctx[RESPONSE_SAVER_PARAMS.UseBinaryResponses.name] and stats_key is not None:
            stats = self._calc_stats(responses_resource.path)
            self.ctx[stats_key] = stats
            self._verify_stats(stats)

    def _need_test_with_disabled_cache(self):
        """Return the TestWithCacheDisabled ctx value (None when it is absent)."""
        return self.ctx.get(RESPONSE_SAVER_PARAMS.TestWithCacheDisabled.name)

    def _verify_stats(self, stats):
        """Fail the task when the empty-responses rate exceeds MaxEmptyResponsesRate.

        A falsy limit (the parameter default is 0) disables the check.

        :raises errors.SandboxTaskFailureError: on too many empty responses
        """
        max_empty = utils.get_or_default(self.ctx, MaxEmptyResponsesRate)
        if max_empty and stats["empty"] > max_empty:
            raise errors.SandboxTaskFailureError("Too many empty responses ({} > {})".format(stats["empty"], max_empty))

    @classmethod
    def _calc_stats(cls, responses_path):
        """Calculate various statistics on responses.

        :param responses_path: path of a text responses file
        :return: dict mapping every ``STATS_VARIANTS`` key to the fraction of
            responses satisfying its predicate (0 when there are no responses)
        """
        counts = {key: 0 for key in STATS_VARIANTS}
        total = 0

        for pickled_response in cls.__parse_responses(responses_path):
            # Parsed items come back pickled; restore each one before testing it.
            response = cPickle.loads(pickled_response)
            for key, (_title, predicate) in STATS_VARIANTS.iteritems():
                if predicate(response):
                    counts[key] += 1
            total += 1

        return {
            key: float(count) / total if total else 0
            for key, count in counts.iteritems()
        }

    @staticmethod
    @decorators.memoize
    def __parse_responses(responses_path):
        """Memoized parse of a responses file.

        Used for the reference responses, which are parsed both by
        ``__recheck_responses`` and by ``_calc_stats``.
        """
        return ResponsesTask.__parse_responses_uncached(responses_path)

    @staticmethod
    def __parse_responses_uncached(responses_path):
        """Parse a responses file with unstable properties removed and patchers applied."""
        return basesearch_response_parser.parse_responses(
            responses_path,
            remove_unstable_props=True,
            response_patchers=get_response_patchers(),
            use_processes=True
        )

    def __save_responses(self, queries_path, responses_path, target):
        """Ask ``target`` for responses to ``queries_path`` and save them to ``responses_path``."""
        target.use_component(lambda: search_requester.save_responses_old(self.ctx, queries_path, responses_path, target))

    def __recheck_responses(self, queries_path, old_responses_path, target, recheck_n_times):
        """Re-fetch responses ``recheck_n_times`` times and compare each run with the original.

        Rechecking is skipped for binary responses or when the count is falsy.
        On the first mismatch, the mismatching responses are published as a
        resource under ctx key ``"<n>_output_resource"``. An unstable-diff HTML
        is written, and ``eh.check_failed`` is called (presumably raising).
        """
        if self.ctx[RESPONSE_SAVER_PARAMS.UseBinaryResponses.name] or not recheck_n_times:
            return

        # Loop-invariant: the reference responses are parsed (and memoized) once.
        old_responses_data = self.__parse_responses(old_responses_path)

        for n in xrange(1, recheck_n_times + 1):
            logging.info("Rechecking results. attempt #%s", n)
            new_responses_prefix = "{}-".format(n)
            new_responses_path = self._get_responses_resource_path(new_responses_prefix)
            self.__save_responses(queries_path, new_responses_path, target)

            # The file was just (re)written, so never use the memo cache here.
            # A repeated _get_responses call reuses the same "<n>-responses.*"
            # path, and caching every recheck parse would only hold memory.
            new_responses_data = self.__parse_responses_uncached(new_responses_path)
            diff_indexes = []
            if basesearch_response_parser.compare_responses(old_responses_data, new_responses_data, diff_indexes=diff_indexes):
                logging.info("Additional check finished. No diff found.")
                continue

            self._create_responses_resource("{}_output_resource".format(n), new_responses_prefix)
            unstable_diff = response_saver.write_unstable_html_diff(
                list(search_requester.generate_queries_old(self.ctx, queries_path, target, need_dbgrlv=False)),
                old_responses_data, new_responses_data,
                custom_node_types_dict=None,
                diff_indexes=diff_indexes
            )
            eh.check_failed(
                "Search is not stable, "
                "see resource:{} for unstable diff".format(unstable_diff.id)
            )


def get_response_patchers():
    """Return a new list of in-place response patchers.

    The list is passed as ``response_patchers`` to the response parser. Each
    patcher mutates a parsed response to drop data that differs between runs.
    """
    patchers = (
        _remove_eventlog_frame,
        _remove_unstable_searcher_props,
        _remove_unstable_group_props,
    )
    return list(patchers)


def _remove_eventlog_frame(response):
    for debuginfo in response._nodes.get("DebugInfo", []):
        if "EventLogFrame" in debuginfo._nodes:
            response_patcher.purge(debuginfo, "EventLogFrame")


def _remove_unstable_searcher_props(response):
    response._nodes["SearcherProp"] = [
        p for p in response._nodes.get("SearcherProp", [])
        if p.GetPropValue("Key") not in _UNSTABLE_SEARCHER_PROPS
    ]


def _remove_unstable_group_props(response):
    """Strip priority properties from nodes reached via Grouping/Group and Grouping/Group/Document."""

    def _drop(node, prop_names):
        # Delete only the properties actually present on this node.
        for prop_name in prop_names:
            if prop_name in node._props:
                del node._props[prop_name]

    WalkPath(response, ["Grouping", "Group"],
             lambda group: _drop(group, ("InternalPriority",)))
    WalkPath(response, ["Grouping", "Group", "Document"],
             lambda document: _drop(document, ("InternalPriority", "SInternalPriority")))
