# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging

from marshmallow import fields

from common.models.geo import Station, Settlement

from travel.rasp.export.export.v3.selling.train_api import poll_train_tariffs, prepare_train_selling_info
from travel.rasp.export.export.v3.views.base_view import BaseView
from travel.rasp.export.export.v3.views.request_schemas import BaseRequestSchema, DateField
from travel.rasp.export.export.v3.views.search_dynamic_data import PointField
from travel.rasp.export.export.v3.views.utils import get_points


# Logger named after this module, so records can be filtered per module.
log = logging.getLogger(name=__name__)


class SearchTrainTariffsSchema(BaseRequestSchema):
    """Request body schema for ``SearchTrainTariffsView``.

    All four fields are declared ``required=True``.
    """

    # Departure point. Presumably PointField deserializes to a Settlement or
    # Station model (the view branches on exactly those two types) -- TODO confirm.
    point_from = PointField(required=True)
    # Arrival point; deserialized the same way as ``point_from``.
    point_to = PointField(required=True)
    # Travel date; the view forwards it to the tariff poller under the 'date' key.
    date = DateField(required=True)
    # List of lists of segment-key strings. The view walks every inner key
    # and looks it up in the polled tariffs.
    segments_keys = fields.List(fields.List(fields.String()), required=True)


class SearchTrainTariffsView(BaseView):
    """Report train tariffs for the segment keys listed in a POST request.

    The body is validated with ``SearchTrainTariffsSchema``. Each endpoint is
    turned into a ``get_points`` argument:
      * a ``Settlement`` becomes ``city_<direction>_id``;
      * a ``Station`` becomes ``station_<direction>_id``.
    The ``point_key`` values of the resolved points are passed to
    ``poll_train_tariffs`` together with the requested date.

    Response shape:
      * ``{'train_tariffs_polling': True}`` when the poller's status is truthy.
        Presumably this means tariffs are still being collected and the
        client should retry -- TODO confirm.
      * otherwise ``{'train_tariffs_polling': False}`` plus, for every
        requested key with non-empty tariffs, ``key -> {'selling_info': ...}``.
        Keys without tariffs are left out.
    """

    http_method_names = ['post']

    def handle(self, request, *args, **kwargs):
        """Validate the request, poll train tariffs and build the response dict."""
        schema = SearchTrainTariffsSchema(context={'request': request})
        # NOTE(review): ``load`` is unpacked as (data, errors), the marshmallow
        # 2.x convention, and the errors are discarded. Unless
        # BaseRequestSchema is strict (raises on invalid input), a bad request
        # would surface below as a KeyError on ``query[...]`` -- confirm.
        query, _load_errors = schema.load(request.data)

        # get_points() always receives all four ids; the ones not set below
        # stay None. A point that is neither a Settlement nor a Station
        # leaves both of its ids as None.
        point_ids = dict.fromkeys(['station_from_id', 'station_to_id', 'city_from_id', 'city_to_id'])
        for direction in ('from', 'to'):
            endpoint = query['point_{}'.format(direction)]
            if isinstance(endpoint, Settlement):
                point_ids['city_{}_id'.format(direction)] = endpoint.id
            elif isinstance(endpoint, Station):
                point_ids['station_{}_id'.format(direction)] = endpoint.id

        # Only the first two (unreduced) points are needed here.
        departure, arrival, _, _ = get_points(**point_ids)

        poll_params = {
            'point_from': departure.point_key,
            'point_to': arrival.point_key,
            'date': query['date'],
        }
        tariffs_by_key, still_polling = poll_train_tariffs(poll_params)

        if still_polling:
            return {'train_tariffs_polling': True}

        response = {'train_tariffs_polling': False}
        # segments_keys is a list of lists; walk every inner key in order.
        requested_keys = (key for group in query['segments_keys'] for key in group)
        for key in requested_keys:
            key_tariffs = tariffs_by_key.get(key)
            if not key_tariffs:
                continue
            response[key] = {'selling_info': prepare_train_selling_info(key_tariffs)}

        return response
