# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import pytz
from dateutil.relativedelta import relativedelta
from iso8601 import parse_date, ParseError
from marshmallow import Schema, fields, validates, ValidationError, post_load
from marshmallow.validate import Range, OneOf
from rest_framework import status

from common.models.geo import StationType
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from common.xgettext.i18n import gettext

from travel.rasp.api_public.api_public.v3.core.api_errors import ApiError


class DateField(fields.Field):
    """Deserialize an ISO 8601 date/datetime limited to a window around "now".

    A blank (whitespace-only) value deserializes to ``None``. Accepted values
    lie between 30 days before and 11 months after the current moment.
    """

    def deserialize(self, value, attr=None, data=None):
        """Parse *value* and check that it falls inside the allowed window.

        :param value: raw request value (expected to be a string).
        :return: the parsed ``datetime`` (naive when the first parse attempt
            succeeds on a string without a UTC offset), or ``None`` for blank input.
        :raises ValidationError: if the value is not ISO 8601 or is out of range.
        """
        dt_str = value.strip()
        if not dt_str:
            return None

        try:
            dt = parse_date(dt_str, default_timezone=None)
        except ParseError:
            # Presumably a "+" in a UTC offset reached us as a space (form/URL
            # decoding), so retry with every space turned back into "+".
            try:
                dt = parse_date(dt_str.replace(" ", "+"))
            except Exception:
                raise ValidationError(gettext(u"Дата должна быть в формате ISO 8601"))

        # Read the clock once so both edges of the window share one reference point.
        now = environment.now_aware()
        start_date = now - relativedelta(days=30)
        end_date = now + relativedelta(months=11)

        # Aware values are compared as instants in their own zone; the
        # calendar-date check below applies to every value, aware or naive.
        outside_instant_window = bool(dt.tzinfo) and not (
            start_date.astimezone(dt.tzinfo) <= dt <= end_date.astimezone(dt.tzinfo)
        )
        outside_date_window = not (start_date.date() <= dt.date() <= end_date.date())

        if outside_instant_window or outside_date_window:
            # Translate the constant template first, then substitute the date,
            # so the gettext lookup key does not change per request.
            raise ValidationError(
                gettext(u"Указана недопустимая дата - {}. Доступен выбор даты на 30 дней назад и "
                        u"11 месяцев вперед от текущей даты").format(dt.date()),
                http_code=400,
            )
        return dt


class ResultTimezoneField(fields.Field):
    """Deserialize a time zone name into a ``pytz`` time zone object."""

    def deserialize(self, value, attr=None, data=None):
        """Return ``pytz.timezone(value)``, or ``None`` when *value* is blank.

        :raises ValidationError: if pytz does not recognise the zone name.
        """
        zone_name = value.strip()
        if zone_name:
            try:
                return pytz.timezone(zone_name)
            except pytz.exceptions.UnknownTimeZoneError:
                raise ValidationError(gettext(u"Указана неподдерживаемая таймзона."))
        return None


class TransportTypesField(fields.Field):
    """Deserialize a comma-separated list of transport type codes into ``TransportType`` objects."""

    def deserialize(self, value, attr=None, data=None):
        """Resolve each requested code to a ``TransportType``.

        Blank input, or input containing ``all``, yields ``None``. Requesting
        any of ``water``/``river``/``sea`` requests all three. ``river`` and
        ``sea`` are skipped when no matching ``TransportType`` exists.

        :raises ValidationError: for any other code with no ``TransportType``.
        """
        requested = set()
        for raw_code in value.strip().split(u","):
            code = unicode.strip(raw_code)
            if code:
                requested.add(code)

        if not requested or u"all" in requested:
            return None

        water_family = {"water", "river", "sea"}
        # Legacy water codes that are tolerated when the lookup finds nothing.
        legacy_water_codes = ("river", "sea")

        if requested & water_family:
            requested |= water_family

        found = []
        for code in requested:
            try:
                transport_type = TransportType.objects.get(code=code)
            except TransportType.DoesNotExist:
                if code in legacy_water_codes:
                    continue
                raise ValidationError(gettext(u"Неподдерживаемый тип транспорта {}").format(code))
            found.append(transport_type)
        return found


class StationTypesField(fields.Field):
    """Deserialize a comma-separated list of station type codes into ``StationType`` objects."""

    def deserialize(self, value, attr=None, data=None):
        """Map each code through ``StationType.STATION_TYPE_ID_BY_CODE`` and fetch the row.

        Blank input, or input containing ``all``, yields ``None``.

        :raises ValidationError: if a code is unknown or its ``StationType`` is missing.
        """
        names = [name for name in map(unicode.strip, value.strip().split(u",")) if name]
        if not names or u"all" in names:
            return None

        resolved = []
        for name in names:
            try:
                type_id = StationType.STATION_TYPE_ID_BY_CODE[name]
                resolved.append(StationType.objects.get(id=type_id))
            except (StationType.DoesNotExist, KeyError):
                raise ValidationError(gettext(u"Не поддерживаемый тип станции {}").format(name))
        return resolved


class ShowSystemsField(fields.Field):
    """Deserialize a comma-separated list of code systems to include in a response."""

    ALLOWED_SYSTEMS = ("yandex", "esr", "iata", "sirena", "express")

    def deserialize(self, value, attr=None, data=None):
        """Validate the requested code systems.

        :return: ``ALLOWED_SYSTEMS`` when ``all`` is present; otherwise the list
            of requested names in request order (possibly empty, duplicates kept).
        :raises ValidationError: listing every unsupported name.
        """
        show_systems = [s for s in map(unicode.strip, value.strip().split(u",")) if s]
        if "all" in show_systems:
            return self.ALLOWED_SYSTEMS

        # Collect unsupported names in request order, without duplicates, so the
        # error text is deterministic (joining a set gave an arbitrary order).
        not_supported_systems = []
        for system in show_systems:
            if system not in self.ALLOWED_SYSTEMS and system not in not_supported_systems:
                not_supported_systems.append(system)

        if not_supported_systems:
            raise ValidationError(
                gettext(u"Неподдерживаемые системы кодирования: {}").format(u", ".join(not_supported_systems))
            )
        return show_systems


class BaseRequestSchema(Schema):
    """Base for request schemas: shared error reporting and ``fields.Method`` helpers."""

    def handle_error(self, exceptions, data):
        """Forward the collected validation messages to ``gen_validate_error``, which raises ``ApiError``."""
        messages = exceptions.messages
        return gen_validate_error(messages)

    def get_str_value(self, value):
        """Return *value* with leading and trailing whitespace removed."""
        stripped = value.strip()
        return stripped

    def get_bool_param(self, value):
        """Treat a string as true only if it equals ``"true"`` case-insensitively; pass other values through."""
        if not isinstance(value, basestring):
            return value
        return value.lower() == "true"


class CarrierRequestSchema(BaseRequestSchema):
    """Query parameters for a carrier lookup: ``code``, ``system``, ``city`` and ``show_systems``."""

    VALID_SYSTEMS = ("iata", "yandex", "sirena", "icao")

    show_systems = ShowSystemsField(missing="")
    code = fields.Method(deserialize="get_str_value", required=True,
                         error_messages={u"required": u"Код перевозчика не указан."})
    system = fields.Method(
        deserialize="get_str_value", missing=u"yandex",
        validate=OneOf(
            VALID_SYSTEMS,
            # Translate the constant template, then fill in the list, so the
            # gettext lookup key is the template itself.
            # NOTE(review): this runs once, when the class is defined; unless
            # gettext returns a lazy string, the text is fixed to the
            # language active at import time — confirm.
            error=gettext(u"Система кодирования должна быть из списка [{}].").format(", ".join(VALID_SYSTEMS))
        )
    )
    city = fields.Method(deserialize="get_str_value", missing=u"")

    @validates("city")
    def validate_city(self, value):
        """Check that a non-empty ``city`` has the form ``c<ID>``.

        :raises ValidationError: unless the value is ``"c"`` followed by text ``int()`` accepts.
        """
        if value:
            int_part_invalid = False
            try:
                int(value[1:])
            except ValueError:
                int_part_invalid = True
            # NOTE(review): int() also tolerates a sign or inner whitespace, so
            # values such as "c-5" or "c 5" pass; tighten if IDs must be plain digits.
            if len(value) < 2 or value[0] != "c" or int_part_invalid:
                raise ValidationError(gettext(u"Город должен указываться в формате cID."))


class NearestPointRequestSchema(BaseRequestSchema):
    """Query parameters for a nearest-point search around (``lat``, ``lng``).

    ``distance`` defaults to 10 and must lie in [0, 50]; its unit is not
    stated here (presumably kilometres — confirm with the consumer).
    """

    show_systems = ShowSystemsField(missing="")
    transport_types = TransportTypesField(missing=u"")
    station_types = StationTypesField(missing=u"")
    lat = fields.Float(
        required=True,
        error_messages={u"required": u"Не указан параметр lat."},
        validate=Range(
            min=-90,
            max=90,
            error=gettext(u"Параметр lat должен быть числом в диапазоне от -90 до 90"),
        ),
    )
    lng = fields.Float(
        required=True,
        error_messages={u"required": u"Не указан параметр lng."},
        validate=Range(
            min=-180,
            max=180,
            error=gettext(u"Параметр lng должен быть числом в диапазоне от -180 до 180"),
        ),
    )
    distance = fields.Float(
        missing=10,
        validate=Range(
            min=0,
            max=50,
            error=gettext(u"Параметр distance должен быть числом от 0 до 50, по умолчанию 10"),
        ),
    )


class CodesAndSystemsMixin(object):
    """Schema mixin adding ``from``/``to`` point codes and the systems they are expressed in.

    ``system`` is the shared default; ``system_from``/``system_to`` override it
    per side and are filled in from ``system`` after loading when absent.
    """

    system = fields.Method(deserialize="get_str_value", missing=u"yandex")
    system_from = fields.Method(deserialize="get_str_value", missing=None)
    system_to = fields.Method(deserialize="get_str_value", missing=None)
    code_from = fields.Method(deserialize="get_str_value", missing=u"", load_from="from")
    code_to = fields.Method(deserialize="get_str_value", missing=u"", load_from="to")

    @post_load
    def add_from_to_systems(self, data):
        """Default ``system_from``/``system_to`` to ``system`` when they are empty.

        :return: *data*, updated in place. It is returned explicitly because
            marshmallow uses a post_load hook's return value as the load result.
        """
        data["system_from"] = data.get("system_from") or data["system"]
        data["system_to"] = data.get("system_to") or data["system"]
        return data


class ThreadInfoRequestSchema(BaseRequestSchema, CodesAndSystemsMixin):
    """Query parameters for thread info: required ``uid`` plus optional date, currency and time zone.

    The ``from``/``to`` codes and their systems come from ``CodesAndSystemsMixin``.
    """

    show_systems = ShowSystemsField(missing="")
    currency = fields.Method(missing=u"", deserialize="get_str_value")
    result_timezone = ResultTimezoneField(missing=u"")
    uid = fields.Str(
        required=True,
        error_messages={u"required": u"Не указан uid нитки."},
    )
    # Read from the "date" query parameter.
    dt = DateField(load_from="date", missing=u"")


class ScheduleRequestSchema(BaseRequestSchema):
    """Query parameters for the schedule endpoint (station, date, event, direction, paging)."""

    VALID_SYSTEMS = ["yandex", "iata", "sirena", "express", "esr"]

    # NOTE(review): in marshmallow 2 the ``missing`` default is substituted
    # before the required check, so ``required=True`` together with ``missing``
    # (page, event, direction) appears to have no effect — confirm before relying on it.
    show_systems = ShowSystemsField(missing="")
    page = fields.Integer(required=True, missing=1, error_messages={u"invalid": u"Страница должна быть целым числом."})
    event = fields.Str(required=True, missing=u"")
    dt = DateField(missing=u"", load_from="date")
    system = fields.Method(
        deserialize="get_str_value", missing=u"yandex",
        validate=OneOf(
            VALID_SYSTEMS,
            # Translate the constant template, then fill in the list, so the
            # gettext lookup key is the template itself.
            error=gettext(u"Система кодирования должна быть из списка [{}].").format(", ".join(VALID_SYSTEMS))
        )
    )
    station = fields.Method(deserialize="get_str_value", missing=u"")
    direction = fields.Str(required=True, missing=None)
    transport_types = TransportTypesField(missing=u"")
    result_pytz = ResultTimezoneField(missing=u"", load_from="result_timezone")
    tablo = fields.Method(deserialize="get_bool_param", missing=False)


class SearchRequestSchema(BaseRequestSchema, CodesAndSystemsMixin):
    """Query parameters for a point-to-point search.

    Redeclares the mixin's ``code_from``/``code_to`` as required fields, read
    from the ``from`` and ``to`` query parameters.
    """

    show_systems = ShowSystemsField(missing="")
    add_days_mask = fields.Method(missing=False, deserialize="get_bool_param")
    transfers = fields.Method(missing=False, deserialize="get_bool_param")
    currency = fields.Str(missing=u"")
    dt = DateField(load_from="date", missing=u"")
    transport_types = TransportTypesField(missing=u"")
    result_pytz = ResultTimezoneField(load_from="result_timezone", missing=u"")
    code_from = fields.Method(
        required=True,
        load_from="from",
        deserialize="get_str_value",
        error_messages={u"required": u"Не указан параметр from."},
    )
    code_to = fields.Method(
        required=True,
        load_from="to",
        deserialize="get_str_value",
        error_messages={u"required": u"Не указан параметр to."},
    )


def gen_validate_error(errors):
    """Raise an HTTP 400 ``ApiError`` describing every field's validation messages.

    Each field contributes ``"<field>: <msg> <msg> "`` and the pieces are
    joined with a single space. This function never returns normally.

    :param errors: mapping of field name to an iterable of message strings.
    :raises ApiError: always.
    """
    chunks = []
    for param, param_errors in errors.items():
        chunks.append(u"{}: {} ".format(param, u" ".join(param_errors)))
    # NOTE(review): gettext receives per-request text here, which is unlikely
    # to match any catalog entry — confirm whether this call is still useful.
    message = gettext(u" ".join(chunks))
    raise ApiError(text=message, http_code=status.HTTP_400_BAD_REQUEST)
