# -*- coding: utf-8 -*-
from __future__ import absolute_import

from marshmallow import Schema, fields

from travel.avia.backend.main.api.api_handler import ApiHandler
from travel.avia.backend.repository.iata_correction import iata_correction_repository
from travel.avia.backend.repository.company import company_repository
from travel.avia.library.python.common.lib.company_finder import CompanyFinder


class FlightNumbersSchema(Schema):
    """Request parameters schema: a required list of flight number strings."""

    # Required field; each list element is declared as a string.
    flight_numbers = fields.List(fields.String, required=True)


class FlightNumbersToCompanyIdsHandler(ApiHandler):
    """Map each requested flight number to the id of the company found for it.

    Lookups are delegated to a ``CompanyFinder`` that is built once, when the
    handler is constructed.
    """

    PARAMS_SCHEMA = FlightNumbersSchema
    # Consumed by ApiHandler; its exact meaning is defined there and is not
    # visible from this module.
    IS_RAW_SCHEMA = True

    def __init__(self, *args, **kwargs):
        """Initialise the base handler and build the company finder.

        All positional and keyword arguments are forwarded unchanged to
        ``ApiHandler.__init__``. The lookup tables handed to ``CompanyFinder``
        are fetched from the repositories here, at construction time.
        """
        super(FlightNumbersToCompanyIdsHandler, self).__init__(*args, **kwargs)

        self._company_finder = CompanyFinder(
            companies_by_iata=company_repository.get_companies_by_iata(),
            companies_by_sirena=company_repository.get_companies_by_sirena(),
            companies_by_icao=company_repository.get_companies_by_icao(),
            companies_by_icao_ru=company_repository.get_companies_by_icao_ru(),
            company_by_id=company_repository.get_companies_by_id(),
            corrections_by_iata=iata_correction_repository.get_corrections_by_iata(),
        )

    def process(self, params, fields):
        """Return a dict mapping flight numbers to company ids.

        :param params: mapping with a ``'flight_numbers'`` iterable
            (presumably already validated against ``PARAMS_SCHEMA`` by
            ``ApiHandler`` -- confirm there).
        :param fields: not used by this handler. NOTE: inside this method the
            name shadows the module-level marshmallow ``fields`` import.
        :return: ``{flight_number: company_id}`` with one entry per distinct,
            non-empty flight number; the value is ``None`` when the finder
            returns a falsy result for that flight number.
        """
        find_company = self._company_finder.find_by_flight_number

        company_ids = {}
        # frozenset collapses duplicate flight numbers so each is looked up once.
        for flight_number in frozenset(params['flight_numbers']):
            if not flight_number:
                # Empty values are skipped and do not appear in the result.
                continue
            company = find_company(flight_number)
            company_ids[flight_number] = company.id if company else None

        # Built in a single pass: avoids the Python 2-only dict.iteritems()
        # and the intermediate flight-number -> company dict.
        return company_ids
