# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

from datetime import datetime, time, timedelta

from marshmallow import Schema, ValidationError, fields, post_dump, post_load, validates_schema
from marshmallow_enum import EnumField

from common.serialization.common_schemas import PriceSchema
from common.serialization.fields import DictNestedField
from travel.rasp.library.python.common23.date.environment import now_aware
from common.utils.railway import get_railway_tz_by_point
from travel.rasp.train_api.serialization.experiment import ExperimentQuerySchema
from travel.rasp.train_api.serialization.fields import FlagField
from travel.rasp.train_api.serialization.schema_bases import MultiValueDictSchemaMixin, PointsQuerySchema
from travel.rasp.train_api.tariffs.serialization import SegmentTariffsSchema, TariffSchema, TariffsSegmentSchema
from travel.rasp.train_api.train_purchase.core.enums import TrainOrderUrlOwner
from travel.rasp.train_api.train_purchase.core.models import TrainPartner


class TrainQuerySchema(PointsQuerySchema, MultiValueDictSchemaMixin, ExperimentQuerySchema):
    """Query parameters of the train search request.

    The search window is given either explicitly, as ``startTime`` and
    ``endTime`` (both must carry a UTC offset), or implicitly through one or
    more ``date`` values. In the latter case ``process_dates`` converts the
    dates into a ``start_time``/``end_time`` pair localized to the departure
    point's timezone, or to its railway timezone when ``useRailwayTZ`` is set.
    """
    dates = fields.List(fields.Date(), load_from='date')
    expanded_day = fields.Boolean(missing=False)
    start_time = fields.DateTime(load_from='startTime')
    end_time = fields.DateTime(load_from='endTime')
    national_version = fields.String(required=True)
    include_price_fee = fields.Boolean(load_from='includePriceFee', missing=False)
    use_railway_tz = fields.Boolean(load_from='useRailwayTZ', missing=False)
    experiment = fields.Boolean(missing=None)
    price_exp_id = fields.String(load_from='priceExpId', missing=None)
    service = fields.String(missing=None)
    utm_source = fields.String(load_from='utmSource', missing=None)
    yandex_uid = fields.String(missing=None)
    force_ufs_order = fields.Boolean(load_from='forceUfsOrder', missing=False)
    allow_international_routes = fields.Boolean(load_from='allowInternationalRoutes', missing=False)
    asker = fields.String(load_from='asker', missing=None)
    ytp_referer = fields.String(load_from='ytp_referer', missing=None)
    mock_im_path = fields.String(load_from='mockImPath', missing=None)
    mock_im_auto = fields.Boolean(load_from='mockImAuto', missing=False)

    @validates_schema(skip_on_field_errors=False)
    def validate_dates(self, data):
        """Require either a valid startTime/endTime pair or at least one date.

        :raises ValidationError: when either explicit bound lacks tzinfo, when
            the bounds do not form a non-empty interval, or when neither both
            bounds nor any ``date`` value were supplied.
        """
        start_time = data.get('start_time')
        end_time = data.get('end_time')

        if not (start_time and end_time):
            # Without both explicit bounds the window must be derived from ``dates``.
            if not data.get('dates'):
                raise ValidationError('startTime/endTime or date parameters required')
            return

        if start_time.tzinfo is None or end_time.tzinfo is None:
            raise ValidationError('startTime and endTime should have UTC offset')
        if end_time <= start_time:
            raise ValidationError('startTime should be less than endTime')

    @post_load
    def process_dates(self, data):
        """Fill ``start_time``/``end_time`` from ``dates`` unless both were given.

        When both explicit bounds are present, ``data`` is returned as is.
        Otherwise a copy of ``data`` is returned. In that copy the window
        starts at local midnight of the earliest date and ends at local
        midnight after the latest date. With ``expanded_day`` the end moves
        4 hours later still.
        """
        if data.get('start_time') and data.get('end_time'):
            return data

        point_from = data['point_from']
        if data['use_railway_tz']:
            tz = get_railway_tz_by_point(point_from)
        else:
            tz = point_from.pytz

        ordered_dates = sorted(data['dates'])
        first_day, last_day = ordered_dates[0], ordered_dates[-1]

        day_span = timedelta(days=1)
        if data['expanded_day']:
            day_span += timedelta(hours=4)

        midnight = time(0)
        return dict(
            data,
            start_time=tz.localize(datetime.combine(first_day, midnight)),
            end_time=tz.localize(datetime.combine(last_day, midnight) + day_span),
        )


class TrainEarliestDateQuerySchema(PointsQuerySchema, ExperimentQuerySchema):
    """Query parameters for the earliest-train-date lookup.

    The optional ``date`` parameter is loaded into ``minimum_date``. It must be
    strictly after today, as computed from ``now_aware()``. When omitted it
    defaults to tomorrow.
    """
    minimum_date = fields.Date(load_from='date', required=False, missing=None)

    @validates_schema(skip_on_field_errors=False)
    def validate_dates(self, data):
        """Reject a ``date`` that is today or earlier.

        :raises ValidationError: if ``minimum_date`` is set and not after today.
        """
        minimum_date = data.get('minimum_date')
        if minimum_date is not None and minimum_date <= now_aware().date():
            # Name the actual query parameter: this schema has no ``startTime``.
            raise ValidationError('date should be greater than today')

    @post_load
    def process_date(self, data):
        """Default ``minimum_date`` to tomorrow when it was not supplied.

        Returns ``data`` itself when a date was given, otherwise a copy with
        ``minimum_date`` filled in.
        """
        if data.get('minimum_date') is not None:
            return data
        return dict(data, minimum_date=now_aware().date() + timedelta(days=1))


class TrainEarliestDateResponseSchema(Schema):
    """Response of the earliest-train-date lookup: a single ``date`` field."""
    date = fields.Date()


class TrainSegmentTariffsSchema(SegmentTariffsSchema):
    """Segment tariffs extended with per-class tariffs and broken classes.

    ``classes`` is dumped through ``DictNestedField(TariffSchema)``
    (presumably a mapping of class key to tariff; verify against
    ``DictNestedField``). ``broken_classes`` is emitted verbatim as
    ``brokenClasses``.
    """
    classes = DictNestedField(TariffSchema)
    broken_classes = fields.Raw(dump_to='brokenClasses')


class TrainTariffsSegmentSchema(TariffsSegmentSchema):
    """Train segment with tariffs, dumped with extra train and UFS fields."""
    can_supply_segments = fields.Boolean(dump_to='canSupplySegments')
    tariffs = fields.Nested(TrainSegmentTariffsSchema)
    raw_train_category = fields.String(dump_to='rawTrainCategory')
    raw_train_name = fields.String(dump_to='rawTrainName')
    has_dynamic_pricing = fields.Boolean(dump_to='hasDynamicPricing')
    two_storey = fields.Boolean(dump_to='twoStorey')
    is_suburban = fields.Boolean(dump_to='isSuburban')
    provider = fields.String()

    @post_dump(pass_many=True, pass_original=True)
    def add_ufs_titles(self, data, many, original_data):
        """Attach ``ufsTitle`` keys taken from the original segment objects.

        For every dumped segment this sets:

        * ``ufsTitle``: the segment's ``ufs_title`` attribute, or ``''`` if absent;
        * ``stationFrom.ufsTitle`` / ``stationTo.ufsTitle``: the segment's
          ``station_from_ufs_title`` / ``station_to_ufs_title``;
        * ``company.ufsTitle``: ``coach_owners`` joined with ``', '``. A
          ``company`` dict is created when the key is absent or ``None``.

        Assumes the dumped data already contains ``stationFrom`` and
        ``stationTo`` dicts (presumably from the parent schema; verify).
        """
        if many:
            pairs = zip(original_data, data)
        else:
            pairs = [(original_data, data)]

        for original_segment, dumped in pairs:
            dumped['ufsTitle'] = getattr(original_segment, 'ufs_title', '')
            dumped['stationFrom']['ufsTitle'] = original_segment.station_from_ufs_title
            dumped['stationTo']['ufsTitle'] = original_segment.station_to_ufs_title
            # A null company must be replaced too, so setdefault() is not enough here.
            if dumped.get('company') is None:
                dumped['company'] = {}
            dumped['company']['ufsTitle'] = ', '.join(original_segment.coach_owners)
        return data


class TrainResponseSchema(Schema):
    """Train search response: a ``querying`` flag and the list of segments."""
    querying = fields.Boolean()
    segments = fields.Nested(TrainTariffsSegmentSchema, many=True)


class MinTariffsQuerySchema(PointsQuerySchema):
    """Query parameters for the minimum-tariffs request.

    ``partner`` is loaded by enum value and falls back to ``TrainPartner.IM``
    when omitted. ``national_version`` is mandatory.
    """
    # No ``required=True``: the ``missing`` default means the value is never absent,
    # so the flag would be inert (and marshmallow 3 rejects required + missing).
    partner = EnumField(TrainPartner, by_value=True, missing=TrainPartner.IM)
    national_version = fields.String(required=True)
    experiment = fields.Boolean(missing=None)


class MinTariffSchema(Schema):
    """Minimum tariff of a single car class.

    ``train_order_url_owner`` is dumped by enum value. ``several_prices`` goes
    through the project's ``FlagField``; its output format is defined there.
    """
    price = fields.Nested(PriceSchema)
    seats = fields.Integer()
    train_order_url = fields.String(dump_to='trainOrderUrl')
    train_order_url_owner = EnumField(TrainOrderUrlOwner, by_value=True, dump_to='trainOrderUrlOwner')
    several_prices = FlagField(dump_to='severalPrices')


class MinTariffsSchema(Schema):
    """Minimum tariffs of a train: per-class tariffs plus an electronic-ticket flag."""
    classes = DictNestedField(MinTariffSchema)
    electronic_ticket = FlagField(dump_to='electronicTicket')


class MinTariffsTrainSchema(Schema):
    """A train entry of the minimum-tariffs response."""
    display_number = fields.String(dump_to='displayNumber')
    original_number = fields.String(dump_to='originalNumber')
    departure = fields.DateTime()
    tariffs = fields.Nested(MinTariffsSchema)


class MinTariffsResponseSchema(Schema):
    """Minimum-tariffs response: a list of trains with their class tariffs."""
    trains = fields.Nested(MinTariffsTrainSchema, many=True)

    @classmethod
    def fast_dump(cls, data):
        """Build the response dict by hand, bypassing marshmallow serialization.

        Produces the same key layout as the declared schema. Unlike the
        schema, ``departure`` is returned as the raw object, and the flag
        fields are returned without passing through ``FlagField``.

        :param data: mapping with a ``trains`` iterable. Each train exposes
            ``display_number``, ``original_number``, ``departure`` and a
            ``tariffs`` mapping with ``classes`` and ``electronic_ticket``.
        """
        def dump_tariff(tariff):
            # One class tariff; a falsy price or owner is emitted as None.
            price = tariff.price
            return {
                'price': {'value': price.value, 'currency': price.currency} if price else None,
                'seats': tariff.seats,
                'trainOrderUrl': tariff.train_order_url,
                'trainOrderUrlOwner': (
                    tariff.train_order_url_owner.value if tariff.train_order_url_owner else None
                ),
                'severalPrices': tariff.several_prices,
            }

        def dump_train(train):
            classes = {
                class_name: dump_tariff(tariff)
                for class_name, tariff in train.tariffs['classes'].items()
            }
            return {
                'displayNumber': train.display_number,
                'originalNumber': train.original_number,
                'departure': train.departure,
                'tariffs': {
                    'electronicTicket': train.tariffs['electronic_ticket'],
                    'classes': classes,
                },
            }

        return {'trains': [dump_train(train) for train in data['trains']]}


class MinTariffsShortSchema(Schema):
    """Compact minimum-tariffs entry: a key, per-class tariffs and sales depth."""
    key = fields.String()
    classes = DictNestedField(TariffSchema)
    sales_depth = fields.Integer(dump_to='salesDepth')


class MinTariffsShortQuerySchema(PointsQuerySchema):
    """Query parameters for the compact minimum-tariffs request."""
    national_version = fields.String(required=True)
    # ``missing`` is the load-time default (``default`` only applies on dump), so an
    # omitted ``withSalesDepth`` loads as False instead of leaving the key out.
    with_sales_depth = fields.Boolean(default=False, missing=False, load_from='withSalesDepth')


class MinTariffsShortResponseSchema(Schema):
    """Compact minimum-tariffs response: a list of ``MinTariffsShortSchema`` entries."""
    tariffs = fields.Nested(MinTariffsShortSchema, many=True)
