# -*- coding: utf-8 -*-
from __future__ import absolute_import

from logging import getLogger, Logger  # noqa
from marshmallow import Schema, fields, ValidationError, post_load

from travel.avia.backend.repository.direction import direction_repository, DirectionRepository, Direction  # noqa
from travel.avia.backend.repository.settlement import settlement_repository, SettlementRepository  # noqa
from travel.avia.backend.main.rest.helpers import CommonView


class DirectionIndexForm(Schema):
    """Schema for the query parameters of the direction index view.

    All fields are optional; ``limit`` falls back to 10 when it is missing.
    After loading, a settlement geo id (departure and/or arrival) is looked up
    in the settlement repository and the matching ``*_settlement_id`` entry is
    written into the loaded data, replacing any value that was passed
    explicitly.
    """

    national_version = fields.String(required=False)
    departure_settlement_id = fields.Integer(required=False)
    arrival_settlement_id = fields.Integer(required=False)
    departure_settlement_geo_id = fields.Integer(required=False)
    arrival_settlement_geo_id = fields.Integer(required=False)
    limit = fields.Integer(required=False, missing=10)

    def __init__(self, repository_settlement, *args, **kwargs):
        # type: (SettlementRepository, list, dict) -> None
        """Create the schema and remember the settlement repository.

        :param repository_settlement: repository queried via ``get_by_geo_id``
            by the post-load hooks.
        :param args: positional arguments forwarded to ``Schema.__init__``.
        :param kwargs: keyword arguments forwarded to ``Schema.__init__``.
        """
        super(DirectionIndexForm, self).__init__(*args, **kwargs)
        self._repository_settlement = repository_settlement

    def _apply_settlement_geo_id(self, parsed_data, geo_id_key, settlement_id_key):
        # type: (dict, str, str) -> dict
        """Translate the geo id stored under ``geo_id_key`` into a settlement id.

        Does nothing when ``geo_id_key`` is absent. Otherwise the settlement is
        fetched from the repository and its ``id`` is stored under
        ``settlement_id_key`` in ``parsed_data`` (mutated in place).

        :raises ValidationError: if the repository returns a falsy result for
            the given geo id.
        :return: the same ``parsed_data`` object.
        """
        if geo_id_key not in parsed_data:
            return parsed_data

        geo_id = parsed_data[geo_id_key]
        settlement = self._repository_settlement.get_by_geo_id(geo_id)
        if not settlement:
            raise ValidationError('Unknown settlement with geo_id: {}'.format(geo_id))

        parsed_data[settlement_id_key] = settlement.id
        return parsed_data

    @post_load
    def process_departure_settlement_geo_id(self, parsed_data):
        """Fill ``departure_settlement_id`` from ``departure_settlement_geo_id``."""
        return self._apply_settlement_geo_id(
            parsed_data,
            geo_id_key='departure_settlement_geo_id',
            settlement_id_key='departure_settlement_id',
        )

    @post_load
    def process_arrival_settlement_geo_id(self, parsed_data):
        """Fill ``arrival_settlement_id`` from ``arrival_settlement_geo_id``."""
        return self._apply_settlement_geo_id(
            parsed_data,
            geo_id_key='arrival_settlement_geo_id',
            settlement_id_key='arrival_settlement_id',
        )


class DirectionIndexView(CommonView):
    """View that lists directions from the direction repository.

    The repository call is chosen from the loaded request data, the result is
    truncated to ``limit`` items and every direction is converted to a plain
    dict.
    """

    def __init__(self, form, repository_direction, logger):
        # type: (DirectionIndexForm, DirectionRepository, Logger) -> None
        """Create the view.

        :param form: schema passed on to ``CommonView.__init__``.
        :param repository_direction: source of direction objects.
        :param logger: logger passed on to ``CommonView.__init__``.
        """
        super(DirectionIndexView, self).__init__(form, logger)
        self._repository_direction = repository_direction

    def _process(self, parsed_data):
        """Return the prepared directions matching ``parsed_data``.

        Selection rules, in order:

        * ``national_version`` and ``departure_settlement_id`` are both truthy:
          directions from that settlement;
        * ``national_version`` and ``arrival_settlement_id`` are both truthy:
          directions to that settlement;
        * otherwise: all directions for ``national_version`` (which may be
          ``None``).

        :param parsed_data: mapping with the loaded request parameters
            (presumably produced by the form -- confirm against CommonView).
        :return: list of dicts built by :meth:`_prepare`, at most ``limit``
            items long; no cap is applied when ``limit`` is absent/``None``.
        """
        national_version = parsed_data.get('national_version')
        from_id = parsed_data.get('departure_settlement_id')
        to_id = parsed_data.get('arrival_settlement_id')

        # NOTE(review): when both settlement ids are supplied only the
        # departure filter is applied, and without a national_version both ids
        # are ignored -- confirm this is the intended contract.
        if national_version and from_id:
            directions = self._repository_direction.get_from_settlement(national_version, from_id)
        elif national_version and to_id:
            directions = self._repository_direction.get_to_settlement(national_version, to_id)
        else:
            directions = self._repository_direction.get_all(national_version)

        # A negative slice bound would mean "everything except the last N"
        # rather than "at most N", so a negative limit is treated as zero.
        # The explicit None check keeps "no limit" working (and avoids the
        # Python 2 quirk where ``None < 0`` is True).
        limit = parsed_data.get('limit')
        if limit is not None and limit < 0:
            limit = 0

        return [
            self._prepare(d) for d in directions[:limit]
        ]

    @staticmethod
    def _prepare(direction):
        # type: (Direction) -> dict
        """Convert a direction object into a dict of its public attributes."""
        return {
            'departure_settlement_id': direction.departure_settlement_id,
            'arrival_settlement_id': direction.arrival_settlement_id,
            'direct_flights': direction.direct_flights,
            'connecting_flights': direction.connecting_flights,
            'national_version': direction.national_version,
            'popularity': direction.popularity,
        }


# Module-level view object, built once at import time from the imported
# `settlement_repository` / `direction_repository` objects and a logger named
# after this module.
direction_index_view = DirectionIndexView(
    logger=getLogger(__name__),
    repository_direction=direction_repository,
    form=DirectionIndexForm(repository_settlement=settlement_repository),
)
