# coding=utf-8
from __future__ import unicode_literals

import io
import logging
from zipfile import ZIP_DEFLATED, ZipFile

from travel.avia.library.python.sirena_client import SirenaClient
from travel.avia.shared_flights.tasks.sirena_parser.airlines_importer import AirlinesImporter
from travel.avia.shared_flights.tasks.sirena_parser.routes_importer import RoutesImporter
from travel.avia.shared_flights.tasks.sirena_parser.transport_models_importer import TransportModelsImporter
from travel.library.python.dicts.file_util import write_binary_string

logger = logging.getLogger(__name__)


class SirenaFetcher(SirenaClient):
    """Fetch Sirena reference data and schedules and pack them into a zip archive.

    Airlines, per-carrier normative schedules (routes) and transport models are
    requested through the methods inherited from ``SirenaClient`` and parsed by
    ``AirlinesImporter``, ``RoutesImporter`` and ``TransportModelsImporter``.
    """

    def __init__(self, host, port_name, client_id, carrier_codes, oauth_token):
        """
        :param host: passed through to ``SirenaClient``.
        :param port_name: passed through to ``SirenaClient``.
        :param client_id: passed through to ``SirenaClient``.
        :param carrier_codes: comma-separated carrier codes whose normative
            schedules are fetched, as UTF-8 ``bytes`` or text; may be empty/None,
            in which case an error is logged and no routes are fetched.
        :param oauth_token: passed to
            ``TransportModelsImporter.compare_with_rasp_dict``.
        """
        super(SirenaFetcher, self).__init__(host, port_name, client_id)
        self._carrier_codes = self._parse_carrier_codes(carrier_codes)
        self._oauth_token = oauth_token

        if not self._carrier_codes:
            logger.error('Please specify the carrier codes id')

    @staticmethod
    def _parse_carrier_codes(carrier_codes):
        """Split a comma-separated carrier code list into a list of text codes.

        Accepts UTF-8 ``bytes`` or text. Whitespace around each code is
        stripped, empty entries (e.g. from a trailing comma) are dropped and
        duplicates are removed keeping first-seen order, so ``b'S7, U6,S7,'``
        yields ``['S7', 'U6']``. Returns ``[]`` for ``None`` or an empty value.
        """
        if not carrier_codes:
            return []
        if isinstance(carrier_codes, bytes):
            carrier_codes = carrier_codes.decode('utf-8')

        codes = []
        for raw_code in carrier_codes.split(','):
            code = raw_code.strip()
            # Skip blanks and repeats so no empty or redundant schedule request is made.
            if code and code not in codes:
                codes.append(code)
        return codes

    def fetch(self):
        """Fetch all Sirena data and build the import archive.

        :returns: a ``(mem_file, new_transport_models)`` tuple.

            - ``mem_file`` is an ``io.BytesIO`` rewound to position 0. It holds
              a zip archive with ``airlines.pb2.bin`` and ``routes.pb2.bin``
              entries; each entry is the concatenation of the messages'
              ``SerializeToString()`` output, written with
              ``write_binary_string``.
            - ``new_transport_models`` is whatever
              ``TransportModelsImporter.compare_with_rasp_dict`` returns.
        """
        airlines = self._fetch_airlines()
        # Routes reference airlines by their Sirena code; on duplicate codes the last airline wins.
        airlines_by_code = {airline.SirenaCode: airline for airline in airlines}
        routes = self._fetch_routes(airlines_by_code)
        new_transport_models = self._fetch_new_transport_models()
        mem_file = self._build_archive(airlines, routes)
        return mem_file, new_transport_models

    def _fetch_airlines(self):
        """Download the airlines reference and parse it with ``AirlinesImporter``."""
        logger.info('Started fetching airlines')
        airlines_data = self.get_airlines_reference()
        logger.info('Fetched %d bytes', len(airlines_data))
        airlines = AirlinesImporter().parse(airlines_data)
        logger.info('Done with airlines')
        return airlines

    def _fetch_routes(self, airlines_by_code):
        """Download each carrier's normative schedule and parse them with ``RoutesImporter``."""
        logger.info('Started fetching routes')
        routes_data = {}
        fetched_bytes = 0
        for carrier_code in self._carrier_codes:
            routes_xml = self.get_normative_schedule(carrier_code)
            routes_data[carrier_code] = routes_xml
            fetched_bytes += len(routes_xml)

        logger.info('Fetched %d xml bytes for routes', fetched_bytes)
        routes = RoutesImporter().parse(routes_data, airlines_by_code)
        logger.info('Done with routes')
        return routes

    def _fetch_new_transport_models(self):
        """Download transport models and return ``compare_with_rasp_dict``'s result for them."""
        logger.info('Started fetching transport models')
        transport_models_xml = self.get_transport_models_reference()
        logger.info('Fetched %d bytes', len(transport_models_xml))
        transport_models_importer = TransportModelsImporter()
        transport_models = transport_models_importer.parse(transport_models_xml)
        new_transport_models = transport_models_importer.compare_with_rasp_dict(
            transport_models, self._oauth_token)
        logger.info('Done with transport models')
        return new_transport_models

    @staticmethod
    def _serialize_messages(messages):
        """Write each message's ``SerializeToString()`` via ``write_binary_string``; return the bytes."""
        buffer = io.BytesIO()
        for message in messages:
            write_binary_string(buffer, message.SerializeToString())
        return buffer.getvalue()

    @classmethod
    def _build_archive(cls, airlines, routes):
        """Pack serialized airlines and routes into an in-memory deflated zip archive."""
        mem_file = io.BytesIO()
        with ZipFile(mem_file, mode='w', compression=ZIP_DEFLATED) as zip_file:
            zip_file.writestr("airlines.pb2.bin", cls._serialize_messages(airlines))
            zip_file.writestr("routes.pb2.bin", cls._serialize_messages(routes))
        # Closing the ZipFile leaves the buffer positioned at its end; rewind it
        # so callers that read() the returned file object get the whole archive.
        mem_file.seek(0)
        return mem_file
