# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import logging
from collections import defaultdict
from functools import partial

from marshmallow import fields, pre_load, ValidationError, post_load

from common.apps.suburban_events.api import get_states_by_segment_keys_plain
from common.models.geo import Station, Point
from travel.rasp.library.python.common23.logging import log_run_time

from route_search.models import RThreadSegment

from travel.rasp.export.export.v3.core.suburban_events import get_segment_state_dict, DYNAMIC_DATA_STATE_NAME_BY_TYPE
from travel.rasp.export.export.v3.core.teasers import get_search_teasers
from travel.rasp.export.export.v3.views.base_view import BaseView
from travel.rasp.export.export.v3.views.request_schemas import BaseRequestSchema

# Module-level logger; ``log_run_time`` below is pre-bound to it at DEBUG level so
# call sites only have to pass a description of the measured section.
log = logging.getLogger(__name__)
log_run_time = partial(log_run_time, log_level=logging.DEBUG, logger=log)


class StationField(fields.Field):
    """Field that deserializes a station id into a ``Station`` model instance.

    ``None`` is passed through unchanged. Any failure of the lookup (unknown id,
    malformed id, ...) is logged and reported as a ``ValidationError`` carrying
    ``http_code=400``.
    """

    def deserialize(self, value, attr=None, data=None):
        if value is None:
            return None

        try:
            station = Station.objects.get(id=value)
        except Exception as ex:
            log.error('Не удалось получить станцию по id {}. Ошибка: {}'.format(value, repr(ex)))
            # Format ``value`` directly (as PointField does): wrapping it in ``str()``
            # raises UnicodeEncodeError on Python 2 for non-ASCII unicode input, which
            # would turn this client error into a server error.
            raise ValidationError('Невалидный id станции {}.'.format(value), http_code=400)

        return station


class PointField(fields.Field):
    """Field that deserializes a point key into a point via ``Point.get_by_key``.

    ``None`` is returned as-is; if the lookup raises, the failure is logged and
    re-reported as a ``ValidationError`` with ``http_code=400``.
    """

    def deserialize(self, value, attr=None, data=None):
        if value is not None:
            try:
                return Point.get_by_key(value)
            except Exception as err:
                log.error('Не удалось получить точку по ключу {}. Ошибка: {}'.format(value, repr(err)))
                raise ValidationError('Невалидный код точки {}.'.format(value), http_code=400)
        return None


class SegmentKeyField(BaseRequestSchema):
    """Key of one search segment whose current state is requested.

    Every component is an optional string that defaults to ``None`` when absent.
    """

    arrival = fields.String(missing=None)
    departure = fields.String(missing=None)
    thread_key = fields.String(missing=None)


class SegmentField(BaseRequestSchema):
    """One search segment, given by its departure/arrival stations and thread uid.

    All fields are optional, but the two stations must be supplied together.
    """

    station_from = StationField(missing=None)
    station_to = StationField(missing=None)
    thread_uid = fields.String(missing=None)

    @post_load
    def check_stations(self, data, **kwargs):
        """Reject a segment that has exactly one of the two stations.

        Returns ``data`` unchanged. It is returned explicitly rather than relying on
        the library substituting the input when a processor returns ``None``.
        ``**kwargs`` absorbs extra arguments newer marshmallow versions pass to
        processors.
        """
        # .get: a station key that is absent from ``data`` must not raise KeyError here.
        has_from = bool(data.get('station_from'))
        has_to = bool(data.get('station_to'))
        if has_from != has_to:
            raise ValidationError('Должны быть указаны обе станции сегмента.', http_code=400)
        return data


class SearchContextSchema(BaseRequestSchema):
    """Search context the teasers are selected for.

    ``point_from``/``point_to`` are point keys resolved by ``PointField`` (``None``
    when absent); ``segments`` is a list of ``SegmentField`` entries.
    """

    point_from = PointField(missing=None)
    point_to = PointField(missing=None)
    segments = fields.Nested(SegmentField, many=True)


def _loads_json_param(raw, param_name):
    """Decode the JSON-encoded request parameter ``raw``.

    Malformed JSON is reported as a ``ValidationError`` with ``http_code=400``
    instead of letting ``json``'s ``ValueError`` escape as a server error.
    """
    try:
        return json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass on Python 3
        raise ValidationError('Невалидный JSON в параметре {}.'.format(param_name), http_code=400)


def _unique_segments(segments):
    """Return ``segments`` without duplicate (station_from, station_to) pairs.

    The first segment of each pair wins and the input order is preserved. Items
    that are not dicts, or whose station ids are unhashable, are kept untouched
    so the nested schema's own validation can reject them.
    """
    unique = []
    seen_pairs = set()
    for segment in segments:
        if not isinstance(segment, dict):
            unique.append(segment)
            continue
        pair = segment.get('station_from'), segment.get('station_to')
        try:
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
        except TypeError:
            # Unhashable station id (e.g. a JSON list/object): leave it for StationField.
            pass
        unique.append(segment)
    return unique


class SearchDynamicDataSchema(BaseRequestSchema):
    """Request schema of the search dynamic-data endpoint.

    Both parameters arrive as JSON-encoded strings in the POST form and are
    decoded in ``params_to_json`` before field deserialization.
    """

    segments_keys = fields.Nested(SegmentKeyField, many=True, missing=None)
    search_context = fields.Nested(SearchContextSchema, missing=None)

    @pre_load
    def params_to_json(self, data, **kwargs):
        """Decode the JSON parameters in place and drop duplicate context segments.

        Mutates and returns ``data``. ``**kwargs`` absorbs extra arguments newer
        marshmallow versions pass to processors.

        :raises ValidationError: (``http_code=400``) if a parameter is not valid JSON.
        """
        keys, context = data.get('segments_keys'), data.get('search_context')
        if keys:
            data['segments_keys'] = _loads_json_param(keys, 'segments_keys')
        if context:
            context = _loads_json_param(context, 'search_context')
            data['search_context'] = context

            # Only dedupe well-shaped input; anything else is left for the nested
            # schemas to reject instead of crashing here with AttributeError.
            if isinstance(context, dict):
                segments = context.get('segments')
                if segments and isinstance(segments, list):
                    context['segments'] = _unique_segments(segments)

        return data


class SearchDynamicDataView(BaseView):
    """
    https://st.yandex-team.ru/RASPEXPORT-292
    Endpoint for auto-refreshing the dynamic data of the search endpoint.
    It does two independent things:
    1) returns the current states for the given list of segment keys;
    2) returns the list of teasers for the given search context.
    """

    http_method_names = ['post']

    def handle(self, request, *args, **kwargs):
        schema = SearchDynamicDataSchema(context={'request': request})
        # NOTE(review): load errors are not inspected here; presumably BaseRequestSchema
        # raises on invalid input (strict mode) -- confirm.
        query, _load_errors = schema.load(request.POST.dict())
        national_version = request.national_version
        lang = request.language_code
        keys, context = query['segments_keys'], query['search_context']

        response_dict = {}
        if keys:
            response_dict['segments_states'] = self._build_segments_states(keys)
        if context:
            response_dict['teasers'] = self._build_teasers(context, national_version, lang)
        return response_dict

    @staticmethod
    def _build_segments_states(keys):
        """Fetch states for ``keys`` and group the state dicts by event type, then by their 'key'."""
        with log_run_time('get_states_by_segment_keys_plain for {} keys'.format(len(keys))):
            states = get_states_by_segment_keys_plain(keys)

        by_event_type = defaultdict(dict)
        with log_run_time('create result segments states for {} states'.format(len(states))):
            for state in states:
                state_dicts = get_segment_state_dict(state)
                for state_type, state_dict in state_dicts.items():
                    # Known state types are renamed; unknown ones keep their own name.
                    event_type = DYNAMIC_DATA_STATE_NAME_BY_TYPE.get(state_type) or state_type
                    by_event_type[event_type][state_dict['key']] = state_dict
        return by_event_type

    @staticmethod
    def _build_teasers(context, national_version, lang):
        """Build RThreadSegment stubs for segments with both stations and fetch their teasers."""
        with log_run_time('get teasers'):
            point_from = context.get('point_from')
            point_to = context.get('point_to')

            segments = []
            for item in context.get('segments', []):
                if not (item['station_from'] and item['station_to']):
                    continue
                segment = RThreadSegment()
                segment.station_from = item['station_from']
                segment.station_to = item['station_to']
                segments.append(segment)

            return get_search_teasers(
                segments, point_from, point_to, national_version=national_version, lang=lang)
