import json
import requests
from travel.avia.library.python.shared_dicts.rasp import get_repository, ResourceType
from travel.avia.shared_flights.data_importer.storage.station_codes import StationCodesStorage
from travel.proto.dicts.rasp.station_pb2 import TStation
from travel.proto.dicts.rasp.transport_pb2 import TTransport
from travel.proto.shared_flights.snapshots.station_with_codes_pb2 import TStationWithCodes
from typing import Dict


class StationStorage:
    """In-memory index of rasp stations, enriched with codes from StationCodesStorage.

    Attributes:
        by_id: station id -> TStationWithCodes. When the stations resource was
            fetched, this also holds a synthetic "Unknown" station under id -1.
            It is empty if the resource could not be fetched.
        by_iata: IATA code -> TStationWithCodes, for every station in ``by_id``
            with a non-empty IataCode. If several stations share a code, the
            one iterated last wins.
    """

    def __init__(self, logger, sandbox, station_codes: StationCodesStorage, oauth_token=None):
        self._logger = logger
        stations = self.fetch_stations(sandbox, station_codes, oauth_token)
        if stations is None:
            # fetch_stations has already logged the failure; use an empty index
            # instead of crashing on ``None.values()`` below.
            stations = {}
        self.by_id: Dict[int, TStationWithCodes] = stations

        # Map stations to their IATA codes
        self.by_iata: Dict[str, TStationWithCodes] = {
            station.IataCode: station
            for station in self.by_id.values()
            if station.IataCode
        }

    def fetch_stations(self, sandbox, station_codes: StationCodesStorage, oauth_token):
        """Load stations from the rasp stations resource and attach their codes.

        Plane and helicopter stations are always kept. Stations of any other
        transport type are kept only if ``station_codes.by_id`` has an entry
        for them. A synthetic "Unknown" station with id -1 is always added.

        Args:
            sandbox: Not used by this method; kept for interface compatibility.
            station_codes: Per-station IATA/ICAO/Sirena codes, keyed by station id.
            oauth_token: OAuth token passed to ``get_repository``.

        Returns:
            Dict mapping station id to TStationWithCodes, or None when the
            station repository could not be obtained (the error is logged).
        """
        station_repository = get_repository(
            ResourceType.TRAVEL_DICT_RASP_STATION_PROD,
            oauth=oauth_token,
        )
        if not station_repository:
            # Must use ``self._logger``; ``self.logger`` does not exist.
            self._logger.error('Unable to fetch the stations data.')
            return None

        stations = {-1: self._make_unknown_station()}

        air_transport_types = (
            TTransport.EType.TYPE_PLANE,
            TTransport.EType.TYPE_HELICOPTER,
        )
        skipped_stations_count = 0
        stations_without_code_count = 0
        for station in station_repository.itervalues():
            # Keep non-air terminals (rail, bus, ...) only if station_codes has
            # an entry for them.
            if (
                station.TransportType not in air_transport_types and
                station.Id not in station_codes.by_id
            ):
                skipped_stations_count += 1
                continue

            station_with_codes = TStationWithCodes()
            station_with_codes.Station.MergeFrom(station)

            station_codes_data = station_codes.by_id.get(station.Id)
            if not station_codes_data:
                stations_without_code_count += 1
            else:
                station_with_codes.IataCode = station_codes_data.iata
                station_with_codes.IcaoCode = station_codes_data.icao
                station_with_codes.SirenaCode = station_codes_data.sirena
            stations[station.Id] = station_with_codes

        self._logger.info('Parsed stations: %s', format_number(len(stations)))
        self._logger.info('Skipped stations: %s', format_number(skipped_stations_count))
        self._logger.info('No-code stations: %s', format_number(stations_without_code_count))
        return stations

    @staticmethod
    def _make_unknown_station():
        """Build the hidden placeholder station stored under id -1."""
        unknown_station = TStationWithCodes()
        unknown_station.Station.Id = -1
        unknown_station.Station.IsHidden = True
        unknown_station.Station.TitleDefault = 'Unknown'
        unknown_station.Station.Type = TStation.EType.TYPE_UNKNOWN
        return unknown_station


def format_number(num):
    """Return ``num`` rendered with ',' as the thousands separator (e.g. 1234567 -> '1,234,567')."""
    thousands_spec = ','
    return format(num, thousands_spec)
