# -*- coding: utf-8 -*-
from __future__ import absolute_import

import logging
from datetime import timedelta

from marshmallow import fields, validate
from more_itertools import unique_everseen

from travel.avia.library.python.avia_data.models.air_traffic_recovery import AirTrafficRecoveryStat, TransportType
from travel.avia.library.python.common.utils.iterrecipes import group_by

from travel.avia.backend.main.api.api_handler import ApiHandler
from travel.avia.backend.main.api.api_schema import TypeSchema
from travel.avia.backend.main.api_types.reference import FormBaseParams
from travel.avia.backend.main.lib.covid_restrictions import get_country_restrictions, get_region_restrictions
from travel.avia.backend.repository import settlement_repository, country_repository


# Module-level logger; the handler below logs how many directions each request returns.
logger = logging.getLogger(__name__)


class AirTrafficRecoveryStatParams(FormBaseParams):
    """Request parameters accepted by ``AirTrafficRecoveryStatHandler``."""
    # Departure settlement whose outgoing directions are returned.
    departure_settlement_id = fields.Integer(required=True)
    # Date the statistics are requested for.
    date_forward = fields.Date(required=True)
    # When true, the handler recomputes ``coefficient`` over a window of dates around date_forward.
    new_coefficient = fields.Boolean(required=False, default=False, missing=False)
    # Transport filter; 'all' disables filtering by transport type.
    transport = fields.String(required=False, validate=validate.OneOf(('plane', 'train', 'all')), missing='plane')
    # TODO: the 'plane' fallback (missing='plane') is only needed during the transition period,
    # until the frontend is ready to support the other transport types.


class AirTrafficRecoveryStatParamsSchema(TypeSchema):
    """Output schema for the ``AirTrafficRecoveryStat`` records returned by the handler.

    Despite the name, this describes the response type (it is used as ``TYPE_SCHEMA``),
    not the request parameters.
    """
    # Serialized departure / arrival settlements, attached to each record by the handler.
    from_city = fields.Dict()
    to_city = fields.Dict()
    departure_settlement_id = fields.Integer()
    arrival_settlement_id = fields.Integer()
    date_forward = fields.Date()
    last_search_date = fields.Date()
    direct_flights = fields.Integer()
    connecting_flights = fields.Integer()
    historical_max_direct_flights = fields.Integer()
    historical_max_connecting_flights = fields.Integer()
    coefficient = fields.Float()
    direct_flights_coefficient = fields.Float()
    connecting_flights_coefficient = fields.Float()
    min_price_changes = fields.Float()
    last_min_price = fields.Integer()
    historical_min_price = fields.Integer()
    city_coefficient = fields.Float()
    # Restrictions for the arrival settlement's region (or, failing that, its country),
    # attached to each record by the handler.
    restrictions = fields.Dict()
    next_direct_flight_day = fields.Date()
    # Serialized as the TransportType member name rather than the stored value.
    transport = fields.Function(lambda obj: TransportType(obj.transport).name)


class AirTrafficRecoveryStatHandler(ApiHandler):
    """Return air traffic recovery statistics for directions from one departure settlement.

    Each returned record is an ``AirTrafficRecoveryStat`` for the requested departure
    settlement and date, unique per (arrival settlement, transport) pair, and enriched
    with serialized ``from_city`` / ``to_city`` and destination ``restrictions``.
    """
    PARAMS_SCHEMA = AirTrafficRecoveryStatParams
    TYPE_SCHEMA = AirTrafficRecoveryStatParamsSchema
    MULTI = True

    # Fields serialized when the client does not request an explicit set.
    DEFAULT_FIELDS = (
        'departure_settlement_id', 'arrival_settlement_id', 'date_forward',
        'last_min_price', 'direct_flights', 'connecting_flights', 'last_search_date',
        'coefficient', 'direct_flights_coefficient', 'connecting_flights_coefficient',
        'historical_max_direct_flights', 'historical_max_connecting_flights',
        'city_coefficient', 'restrictions', 'next_direct_flight_day', 'from_city', 'to_city', 'transport'
    )

    def preprocess_fields(self, fields):
        """Return ``fields`` unchanged, or ``DEFAULT_FIELDS`` when it is empty/None."""
        return fields or self.DEFAULT_FIELDS

    @staticmethod
    def statistic_with_updated_coefficients(params, filter_query, window_days=3):
        """Return stats for ``params['date_forward']`` with ``coefficient`` recomputed over a date window.

        For every arrival settlement the coefficient is
        ``round(sum(direct_flights) / sum(historical_max_direct_flights) * 5, 1)``, where both
        sums run over all rows dated within ``window_days`` days of the requested date
        (inclusive on both ends). Only rows with ``historical_max_direct_flights > 0`` are used.

        :param params: parsed request parameters (reads ``departure_settlement_id`` and ``date_forward``).
        :param filter_query: extra ``AirTrafficRecoveryStat`` filter kwargs (e.g. ``transport``).
        :param window_days: half-width of the aggregation window in days; defaults to 3.
        :return: list of stat objects dated exactly ``date_forward``; empty list if there are none.
        """
        date_forward = params['date_forward']
        window = timedelta(days=window_days)
        # Materialize once: the rows are used both to select the requested date and to aggregate.
        statistic = list(AirTrafficRecoveryStat.objects.filter(
            departure_settlement_id=params['departure_settlement_id'],
            date_forward__gte=date_forward - window,
            date_forward__lte=date_forward + window,
            historical_max_direct_flights__gt=0,
            **filter_query
        ))
        result_statistic = [s for s in statistic if s.date_forward == date_forward]
        if not result_statistic:
            return result_statistic

        # Accumulate per-arrival totals in a dict so the result does not depend on the
        # rows being ordered or pre-grouped by arrival_settlement_id.
        # NOTE(review): with transport='all', rows of different transports for the same
        # arrival settlement are summed together; confirm this is intended.
        totals = {}
        for s in statistic:
            direct, historical = totals.get(s.arrival_settlement_id, (0, 0))
            totals[s.arrival_settlement_id] = (
                direct + s.direct_flights,
                historical + s.historical_max_direct_flights,
            )

        for s in result_statistic:
            # Every row in result_statistic is also in statistic, so the key is always present.
            direct, historical = totals[s.arrival_settlement_id]
            coefficient = 0.
            if historical:
                coefficient = round((float(direct) / historical) * 5, 1)
            s.coefficient = coefficient

        return result_statistic

    def process(self, params, fields):
        """Fetch, deduplicate and enrich statistics for the requested departure settlement and date.

        :param params: parsed ``AirTrafficRecoveryStatParams``.
        :param fields: requested output fields (not used here).
        :return: list of stat objects, one per (arrival_settlement_id, transport) pair,
            keeping the first occurrence.
        """
        filter_query = {}
        if params['transport'] != 'all':
            filter_query['transport'] = getattr(TransportType, params['transport']).value

        if params['new_coefficient']:
            statistic = self.statistic_with_updated_coefficients(params, filter_query)
        else:
            statistic = AirTrafficRecoveryStat.objects.filter(
                departure_settlement_id=params['departure_settlement_id'],
                date_forward=params['date_forward'],
                **filter_query
            )

        # Deduplicate first: duplicates are dropped anyway, so enriching them only wastes
        # lookups, and the logged count should match what is actually returned.
        directions = list(unique_everseen(statistic, key=lambda k: (k.arrival_settlement_id, k.transport)))

        self.fill_statistic(directions, params['departure_settlement_id'])

        logger.info(
            'Return %d directions for settlement_id=%d date_forward=%s',
            len(directions), params['departure_settlement_id'], params['date_forward']
        )

        return directions

    def fill_statistic(self, statistic, departure_settlement_id):
        """Attach ``from_city``, ``to_city`` and ``restrictions`` to every stat object in place.

        For each arrival settlement, first look for restrictions of its region and apply them
        if present. If there are no regional restrictions, look for restrictions of the
        country as a whole and apply those instead (an empty dict if neither exists).
        """
        if not statistic:
            return

        # NOTE(review): settlement_repository.get is assumed to always find the settlement;
        # if it can return None, _serialize_settlement / region_id access will raise.
        departure_settlement = self._serialize_settlement(
            settlement_repository.get(departure_settlement_id)
        )

        for direction_stat in statistic:
            arrival_settlement = settlement_repository.get(direction_stat.arrival_settlement_id)
            direction_stat.from_city = departure_settlement
            direction_stat.to_city = self._serialize_settlement(arrival_settlement)

            restrictions = {}
            if arrival_settlement.region_id:
                restrictions = get_region_restrictions(arrival_settlement.region_id)

            # Fall back to country-wide restrictions only when the region has none.
            if not restrictions and arrival_settlement.country_id:
                restrictions = get_country_restrictions(arrival_settlement.country_id)

            direction_stat.restrictions = restrictions

    @staticmethod
    def _serialize_settlement(settlement):
        """Serialize a settlement into the dict used for ``from_city`` / ``to_city``.

        ``countryCode`` is None when the settlement's country is not found in the repository.
        """
        country = country_repository.get(settlement.country_id)
        return {
            'latitude': settlement.latitude,
            'key': settlement.point_key,
            'longitude': settlement.longitude,
            'majorityId': settlement.majority_id,
            'countryCode': country.code if country else None
        }
