import logging
import ujson as json
from datetime import datetime

import re
from flask import request
from flask.views import MethodView
from ylog.context import log_context

from jafar import jafar_mongo, mongo_configs, cache, sentry
from jafar.exceptions import FeedException
from jafar.feed import schema
from jafar.feed.backend import FeedMongoBackend
from jafar.feed.base import Feed, FeedPage
from jafar.models import cards
from jafar.utils import get_all_subclasses
from jafar.utils.iter import take
from jafar.web import JsonResponse
from jafar.web.exceptions import NotFound, BadRequest
from jafar_yt.utils.helpers import stemmize, transliterate

logger = logging.getLogger(__name__)


class ExperimentNotFound(BadRequest):
    """Bad-request error reported when an experiment name cannot be resolved.

    Only ``details`` is populated here; the base initializer is not invoked.
    """

    def __init__(self, name):
        """Store an error payload that mentions the unknown experiment ``name``."""
        message = "No such experiment: {}".format(name)
        self.details = {"errors": [message]}


class BaseJafarView(MethodView):
    """Common base for Jafar views.

    Subclasses are expected to set ``experiment_config`` (the config model
    that is looked up by experiment name) and ``request_schema`` (the schema
    class used to validate incoming parameters).
    """
    experiment_config = None
    request_schema = None

    def get_experiment_config(self, name):
        """Return the parameters of the experiment called ``name``.

        The lookup goes through ``cache.memoize`` with a five minute timeout.

        :raises NotImplementedError: when ``experiment_config`` is not set.
        :raises ExperimentNotFound: when no config with that name exists.
        """
        config_cls = self.experiment_config
        if not config_cls:
            raise NotImplementedError

        @cache.memoize(timeout=5 * 60)
        def get_cached(cls, name):
            # cls is passed so that it takes part in memoization: the same
            # name might not be unique across config classes
            config = cls.objects.get(name=name)
            return config.get_params()

        try:
            return get_cached(config_cls, name)
        except config_cls.DoesNotExist:
            raise ExperimentNotFound(name)

    def parse_request_params(self, request_params):
        """Validate ``request_params`` with ``request_schema`` and return the loaded data.

        :raises NotImplementedError: when ``request_schema`` is not set.
        :raises BadRequest: when the schema reports errors (they are logged first).
        """
        schema_cls = self.request_schema
        if not schema_cls:
            raise NotImplementedError
        loaded_params, errors = schema_cls().load(request_params)
        if not errors:
            return loaded_params
        logger.error(json.dumps(errors))
        raise BadRequest(errors)

    @staticmethod
    def get_device_id(request_params):
        """Return the ``device_id`` entry of the raw request parameters, if any."""
        device_id = request_params.get('device_id')
        return device_id

    def dispatch_request(self, *args, **kwargs):
        """Set up logging context, validate the request and call the HTTP method handler.

        The raw parameters come from the JSON body or, when there is none,
        from the query string. The validated parameters are passed to the
        handler as its first positional argument.
        """
        raw_params = request.get_json(force=True, silent=True) or request.args
        request_id = request.headers.get('X_REQUESTID')
        device_id = self.get_device_id(raw_params)
        # first path segment, e.g. 'vanga' for the path '/vanga/rasputin'
        path_parts = request.path.split("/")
        endpoint = path_parts[1] if len(path_parts) > 1 else None
        log_params = {
            'endpoint': endpoint,
            'request_id': request_id,
            'device_id': device_id,
        }
        with log_context(query_string=args, **log_params):
            sentry.tags_context(log_params)
            logger.info('Request: {}'.format(json.dumps(raw_params)))
            validated_params = self.parse_request_params(raw_params)
            return super(BaseJafarView, self).dispatch_request(validated_params, *args, **kwargs)


class RecommendationFeedView(BaseJafarView):
    """Serve a single page of the recommendation feed for a named experiment."""
    experiment_config = mongo_configs.FeedExperimentConfig
    request_schema = schema.FeedRequestSchema

    def post(self, request_params, experiment_name):
        """Build the feed for the user and return the requested page.

        :param request_params: parameters already validated by ``request_schema``.
        :param experiment_name: name of the experiment config to use.
        :raises NotFound: when the page number is past the configured
            ``page_limit`` (the payload is an empty feed page), or when the
            feed raises ``FeedException``.
        """
        # Work on a shallow copy: the params come from a memoized lookup and
        # are modified below (pop / block_count override), which must not
        # leak back into whatever object the cache may be holding.
        config = dict(self.get_experiment_config(experiment_name))
        page_limit = config.pop('page_limit')
        if page_limit > 0 and request_params['page'] >= page_limit + 1:
            empty_page = FeedPage(blocks=[], expire_at=datetime.utcnow())
            # Raised (not returned) so that it is handled the same way as the
            # NotFound raised for FeedException below.
            raise NotFound(schema.FeedPageSchema().dump(empty_page)[0])
        user = request_params['user_info']
        # Block counts override is used for old 1.x.x launchers that send required groups_count in request
        config['block_count'] = request_params['block_count_override'] or config['block_count']
        feed = Feed(
            user=user, backend_type=FeedMongoBackend, cache_key=request_params['cache_key'],
            categories=request_params['categories'], place=request_params['place'],
            promo_placeholders=request_params['promo_placeholders'], **config
        )
        try:
            page = feed.get_page(page_number=request_params['page'])
        except FeedException as e:
            logger.error(e.message, extra=e.extra)
            raise NotFound({"errors": [e.message]})
        response, _errors = schema.FeedPageSchema().dump(page)
        return JsonResponse(response)

    @staticmethod
    def get_device_id(request_params):
        """Return ``user_info.device_id`` from the raw request, or ``None``.

        ``dispatch_request`` calls this before the request has been validated,
        so ``user_info`` may be missing or malformed; in that case ``None`` is
        returned and the schema validation that follows reports the problem.
        """
        user_info = request_params.get('user_info')
        if isinstance(user_info, dict):
            return user_info.get('device_id')
        return None


def make_regex(value):
    """Build a ``$regex`` condition matching strings that start with ``value``.

    ``value`` is escaped, so it is matched literally.
    """
    anchored_pattern = '^' + re.escape(value)
    return {'$regex': anchored_pattern}


def make_query_expression(field, words):
    """Build a query expression on ``field`` from a sequence of ``words``.

    A single word becomes a prefix match. With several words, every word but
    the last must match exactly, and the last one is added as a prefix match
    only when it is at least three characters long; the conditions are
    combined with ``$and``.
    """
    if len(words) != 1:
        exact_words, last_word = words[:-1], words[-1]
        clauses = []
        for exact_word in exact_words:
            clauses.append({field: exact_word})
        if not len(last_word) < 3:
            clauses.append({field: make_regex(last_word)})
        return {'$and': clauses}
    return {field: make_regex(words[0])}


class SearchView(BaseJafarView):
    """Search view backed by the ``apps_search_index_expanded`` collection."""
    experiment_config = mongo_configs.FeedExperimentConfig
    request_schema = schema.SearchSchema

    @staticmethod
    def get_results_from_mongo(count, query):
        """Return up to ``count`` package names matching ``query``.

        The query is matched against the ``stems`` field using both its stems
        and its transliterations; results are sorted by ``popularity``
        descending. An empty result is returned when the query has no stems.
        """
        stems = stemmize(query)
        if not stems:
            return []

        # Build a real list: it is measured, indexed and iterated more than
        # once below, which a lazy filter object would not support.
        translits = [word for word in transliterate(query) if word]
        logger.debug(u'    stems: %s', u', '.join(stems))
        logger.debug(u'translits: %s', u', '.join(translits))

        # A lone transliteration shorter than three characters is ignored.
        if not translits or (len(translits) == 1 and len(translits[0]) < 3):
            search_query = make_query_expression('stems', stems)
        else:
            search_query = {'$or': [
                make_query_expression('stems', stems),
                make_query_expression('stems', translits),
            ]}

        cursor = jafar_mongo.db.apps_search_index_expanded.find(
            filter=search_query,
            projection={'_id': False, 'package_name': True},
            sort=[('popularity', -1)],
        ).limit(count)
        apps = (app['package_name'] for app in cursor)
        return take(apps, count)

    def get(self, request_params, experiment_name):
        """Handle a search request and return the matches as JSON.

        ``experiment_name`` is accepted but not used by the search itself.
        """
        query = request_params['query']
        count = request_params['count']
        package_names = self.get_results_from_mongo(count, query)
        search_results = [{'package_name': package_name} for package_name in package_names]
        return JsonResponse({'search_results': search_results})


class RecConfigView(MethodView):
    """Expose every ``cards.CardConfig`` subclass, serialized with ``RecConfigSchema``."""

    def get(self):
        """Return the serialized list of card config classes as JSON."""
        serializer = schema.RecConfigSchema(many=True)
        card_config_classes = get_all_subclasses(cards.CardConfig)
        payload, _errors = serializer.dump(card_config_classes)
        return JsonResponse(payload)
