# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

from marshmallow import Schema, fields, post_dump, pre_dump

from common.models.geo import StationCode
from common.models.schedule import RThreadType
from common.models.transport import TransportType
from common.serialization.fields import DictNestedField
from common.utils.railway import get_railway_tz_by_point
from travel.rasp.train_api.serialization.fields import AwareDateTime, FlagField, CallableStringField
from travel.rasp.train_api.serialization.segment_station import SegmentStationSchema
from travel.rasp.train_api.serialization.segment_transport import DeluxeTrainSchema, build_transport

# Module-level string constant. It is not referenced anywhere in this part of
# the module. NOTE(review): presumably imported elsewhere as an "invalid
# value" marker — confirm against its callers before changing it.
INVALID_VALUE = 'invalid_value'


class DaysSchema(Schema):
    """Dump days-related text/flag attributes of an object under camelCase keys.

    Every field reads the attribute of the same name and writes it under the
    key given by ``dump_to``.
    """
    days_text = fields.String(dump_to='text')
    days_text_short = fields.String(dump_to='shortText')
    schedule_plan_appendix = fields.String(dump_to='schedulePlanAppendix')
    except_days_text = fields.String(dump_to='exceptText')
    has_more_days = fields.Boolean(dump_to='hasMoreDays')
    # Note the rename: the ``cancel`` attribute is emitted as ``canceled``.
    cancel = fields.Boolean(dump_to='canceled')


class CompanySchema(Schema):
    """Dump a company object: id, localized titles, url and hidden flag."""
    id = fields.Integer()
    # Titles are read from the ``L_title`` / ``L_short_title`` attributes.
    # NOTE(review): CallableStringField presumably calls them to get the
    # localized string — confirm in travel.rasp.train_api.serialization.fields.
    title = CallableStringField(attribute='L_title', dump_to='title')
    short_title = CallableStringField(attribute='L_short_title', dump_to='shortTitle')
    url = fields.String()
    hidden = fields.Boolean()


class TrainSchedulePlanSchema(Schema):
    """Dump a schedule plan: its code, title and start/end dates."""
    code = fields.String()
    title = fields.String()
    # Dates are emitted under camelCase keys.
    start_date = fields.Date(dump_to='startDate')
    end_date = fields.Date(dump_to='endDate')


class ThreadSchema(Schema):
    """Dump a thread object under camelCase keys.

    A ``deluxeTrain`` value of ``None`` is removed from the output instead of
    being emitted as null.
    """
    uid = fields.String()
    title = CallableStringField(attribute='L_title', dump_to='title')
    number = fields.String()
    is_express = fields.Boolean(dump_to='isExpress', default=False)
    is_aeroexpress = fields.Boolean(dump_to='isAeroExpress', default=False)
    deluxe_train = fields.Nested(DeluxeTrainSchema, dump_to='deluxeTrain')
    begin_time = fields.Time(dump_to='beginTime')
    end_time = fields.Time(dump_to='endTime')
    density = fields.String()
    # A string ``only`` makes marshmallow emit the bare ``code`` value rather
    # than a nested dict.
    schedule_plan = fields.Nested(
        TrainSchedulePlanSchema,
        only='code',
        dump_to='schedulePlanCode',
    )
    is_basic = fields.Function(
        lambda thr: thr.type_id == RThreadType.BASIC_ID,
        dump_to='isBasic',
    )
    comment = fields.String()
    first_country_code = fields.String(dump_to='firstCountryCode')
    last_country_code = fields.String(dump_to='lastCountryCode')
    # Only bus threads report the supplier's flag; every other transport type
    # gets None.
    displace_yabus = fields.Function(
        lambda thr: (
            (thr.supplier and thr.supplier.displace_yabus)
            if thr.t_type_id == TransportType.BUS_ID
            else None
        ),
        dump_to='displaceYabus',
    )

    @post_dump
    def post_dump(self, data):
        """Remove a ``deluxeTrain`` key whose value is ``None``; return ``data``."""
        # A missing key yields the False default, which is not None, so
        # nothing is deleted.
        if data.get('deluxeTrain', False) is None:
            del data['deluxeTrain']
        return data


class BaseSegmentSchema(Schema):
    """Dump a segment together with its stations, thread and company.

    For the duration of a single ``dump`` call the schema keeps a
    ``station id -> express code`` mapping in ``self.context`` under the
    ``'express_by_station_id_cache'`` key. It is loaded with one query over
    all stations of the dumped segment(s) and removed again after dumping.
    NOTE(review): the mapping is presumably read by SegmentStationSchema,
    which is built with this schema's context — confirm there.
    """
    # Transport types for which a station additionally reports its railway
    # timezone.
    _RAILWAY_T_TYPE_IDS = frozenset([TransportType.TRAIN_ID, TransportType.SUBURBAN_ID])

    title = CallableStringField(attribute='L_title', dump_to='title')
    arrival = AwareDateTime()
    departure = AwareDateTime()
    duration = fields.TimeDelta()
    number = fields.String()
    station_from = fields.Method('build_station_from', dump_to='stationFrom')
    station_to = fields.Method('build_station_to', dump_to='stationTo')
    days_by_tz = DictNestedField(DaysSchema, dump_to='daysByTimezone')
    run_days = fields.Dict(dump_to='runDays')  # {'2016': {'1': [1, 1, 0, 1,]}}
    start_date = fields.Date(dump_to='startDate')

    sales_limit_in_days = fields.Integer(dump_to='salesLimitInDays')
    transport = fields.Function(build_transport)
    thread = fields.Nested(ThreadSchema)
    company = fields.Nested(CompanySchema)
    is_through_train = fields.Function(lambda obj: obj.thread and obj.thread.type_id == RThreadType.THROUGH_TRAIN_ID,
                                       dump_to='isThroughTrain')
    stops = CallableStringField(attribute='L_stops')
    do_not_sell = FlagField(dump_to='doNotSell')
    old_ufs_order = fields.Boolean(dump_to='oldUfsOrder')

    url = fields.String(dump_to='url')

    def build_station_from(self, segment):
        """Dump the departure station (with platform from ``rtstation_from`` if present)."""
        return self._build_station(segment.station_from, getattr(segment, 'rtstation_from', None), segment.t_type)

    def build_station_to(self, segment):
        """Dump the arrival station (with platform from ``rtstation_to`` if present)."""
        return self._build_station(segment.station_to, getattr(segment, 'rtstation_to', None), segment.t_type)

    def _build_station(self, station, rtstation, segment_t_type):
        """Dump ``station`` via SegmentStationSchema and enrich the result.

        :param station: station object to dump.
        :param rtstation: optional object providing ``L_platform()``. When it
            is truthy, its result is stored under ``'platform'``.
        :param segment_t_type: transport type of the segment. For train and
            suburban segments, ``'railwayTimezone'`` is added when a railway
            timezone is found for the station.
        :return: dict of dumped station data.
        """
        station_data = SegmentStationSchema(context=self.context).dump(station).data
        if rtstation:
            station_data['platform'] = rtstation.L_platform()

        if segment_t_type.id in self._RAILWAY_T_TYPE_IDS:
            railway_timezone = get_railway_tz_by_point(station)
            if railway_timezone:
                station_data['railwayTimezone'] = railway_timezone.zone
        return station_data

    @pre_dump(pass_many=True)
    def prepare_express_by_station_id_cache(self, segments, many):
        """Load express codes for all stations of the dumped segment(s).

        The input is returned unchanged. marshmallow replaces the data being
        dumped with a processor's return value, so the original object (not
        the temporary one-element list built for ``many=False``) is returned.
        """
        segment_list = segments if many else [segments]

        station_ids = set()
        for segment in segment_list:
            station_ids.add(segment.station_from.id)
            station_ids.add(segment.station_to.id)

        # Skip the query entirely when there is nothing to dump.
        if station_ids:
            express_codes = (
                StationCode.objects
                .filter(station__id__in=list(station_ids), system__code='express')
                .values_list('station__id', 'code')
            )
            self.context['express_by_station_id_cache'] = dict(express_codes)
        return segments

    @post_dump(pass_many=True)
    def _clear_express_by_station_id_cache(self, _data, many):
        """Drop the per-dump express-code cache and return the dumped data unchanged."""
        self.context.pop('express_by_station_id_cache', None)
        return _data
