from __future__ import unicode_literals

import abc
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Iterable, Sequence, List

from travel.proto.dicts.rasp.carrier_pb2 import TCarrier
from travel.proto.dicts.rasp.timezone_pb2 import TTimeZone
from travel.proto.shared_flights.snapshots.blacklist_pb2 import TBlacklistRule
from travel.proto.shared_flights.snapshots.flight_merge_rule_pb2 import TFlightMergeRule
from travel.proto.shared_flights.snapshots.flight_status_pb2 import (
    TFlightStatus,
    TFlightStatusSource,
    TLastImportedInfo,
)
from travel.proto.shared_flights.snapshots.iata_correction_pb2 import TIataCorrectionRule
from travel.proto.shared_flights.snapshots.overrides_pb2 import TOverride
from travel.proto.shared_flights.snapshots.station_status_source_pb2 import TStationStatusSource
from travel.proto.shared_flights.snapshots.station_with_codes_pb2 import TStationWithCodes
from travel.proto.shared_flights.ssim.flights_pb2 import (
    EFlightBaseSource,
    TCarrierPopularScore,
    TDesignatedCarrier,
    TPopularScore,
)

from travel.avia.library.python.backend_client import BackendClient
from travel.avia.shared_flights.diff_builder.utils import ensure_string, convert_to_latin
from travel.avia.shared_flights.lib.python.settings import NATIONAL_VERSIONS


def _datetime_to_int_date(dtime: datetime):
    return dtime.year * 10000 + dtime.month * 100 + dtime.day


def datetime_to_string(dtime: datetime) -> str:
    """Render ``dtime`` as ``'YYYY-MM-DD HH:MM:SS'``; a falsy value (e.g. ``None``) yields ``''``."""
    if not dtime:
        return ''
    return dtime.strftime('%Y-%m-%d %H:%M:%S')


def datetime_to_date_string(dtime: datetime) -> str:
    """Render only the date part of ``dtime`` as ``'YYYY-MM-DD'``; a falsy value yields ``''``."""
    if not dtime:
        return ''
    return dtime.strftime('%Y-%m-%d')


# Language passed to BackendClient.popular_airlines() when PopularScoreFetcher
# requests popular carriers for each national version.
DEFAULT_LANGUAGE = 'ru'


class IDataFetcher(abc.ABC):
    """Interface for a source of serialized protobuf records.

    The implementations in this module are generators: ``fetch`` yields one
    ``SerializeToString()`` payload (``bytes``) per fetched record.
    """

    @abc.abstractmethod
    def fetch(self) -> Iterable[bytes]:
        """Yield serialized protobuf messages, one per fetched record."""


class TimezoneFetcher(IDataFetcher):
    """Stream the ``timezone`` table as serialized ``TTimeZone`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TTimeZone`` per row, ordered by ``id``."""
        self.logger.info('Start fetching timezones')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                code
            from
                timezone
            order by
                id;
            '''
        )
        for tz_id, tz_code in db_cursor:
            tz_message = TTimeZone()
            tz_message.Id = tz_id
            tz_message.Code = tz_code
            yield tz_message.SerializeToString()

        self.logger.info('Done fetching timezones')


class DesignatedCarrierFetcher(IDataFetcher):
    """Stream the ``designated_carrier`` table as serialized ``TDesignatedCarrier`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TDesignatedCarrier`` per row, ordered by ``id``."""
        self.logger.info('Start fetching designated carriers')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                title
            from
                designated_carrier
            order by
                id;
            '''
        )
        for carrier_id, carrier_title in db_cursor:
            message = TDesignatedCarrier()
            message.Id = carrier_id
            message.Title = carrier_title
            yield message.SerializeToString()

        self.logger.info('Done fetching designated carriers')


class CarrierFetcher(IDataFetcher):
    """Stream the ``carrier`` table as serialized ``TCarrier`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TCarrier`` per row, ordered by ``id``.

        NULL codes are stored as empty strings.
        """
        self.logger.info('Start fetching carriers')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                iata,
                icao,
                sirena_id,
                icao_ru
            from
                carrier
            order by
                id
            '''
        )
        for carrier_id, iata, icao, sirena_id, icao_ru in db_cursor:
            message = TCarrier()
            message.Id = carrier_id
            message.Iata = iata or ''
            message.Icao = icao or ''
            message.SirenaId = sirena_id or ''
            message.IcaoRu = icao_ru or ''
            yield message.SerializeToString()

        self.logger.info('Done fetching carriers')


class PopularScoreFetcher(IDataFetcher):
    """Collect carrier popularity per national version from the avia backend."""

    def __init__(self, backend_client: BackendClient, logger: logging.Logger):
        self._backend_client = backend_client
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TCarrierPopularScore`` per carrier returned by the backend.

        ``BackendClient.popular_airlines`` is queried once for each entry of
        ``NATIONAL_VERSIONS`` with ``DEFAULT_LANGUAGE``; a carrier's message holds
        one ``TPopularScore`` for every national version it appeared in.
        """
        self.logger.info('Start fetching popular scores')
        # carrier id -> {national version -> popularity}
        scores_by_carrier = defaultdict(dict)
        for version in NATIONAL_VERSIONS:
            self.logger.info('Start fetching popular scores from avia_backend (national_version=%s)', version)
            airlines = self._backend_client.popular_airlines(version, DEFAULT_LANGUAGE)
            for airline in airlines:
                scores_by_carrier[airline['id']][version] = airline['popularity']

        for carrier_id, per_version in scores_by_carrier.items():
            message = TCarrierPopularScore()
            message.CarrierId = carrier_id
            for version, popularity in per_version.items():
                entry = TPopularScore()
                entry.NationalVersion = version
                entry.Score = popularity
                message.PopularScores.append(entry)
            yield message.SerializeToString()
        self.logger.info('Done fetching popular scores')


class IATACorrectionRuleFetcher(IDataFetcher):
    """Stream the ``iata_correction`` table as serialized ``TIataCorrectionRule`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TIataCorrectionRule`` per row, ordered by ``id``.

        NULL text columns become ``''`` and NULL numeric columns become ``0``.
        """
        self.logger.info('Start fetching iata correction rules')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                marketing_carrier_iata,
                carrier_sirena,
                flying_carrier_iata,
                designated_carrier,
                flight_number_regex,
                carrier_id,
                priority
            from
                iata_correction
            order by
                id;
            '''
        )
        for (
            rule_id,
            marketing_iata,
            sirena,
            flying_iata,
            designated,
            number_regex,
            carrier_id,
            priority,
        ) in db_cursor:
            message = TIataCorrectionRule()
            message.Id = rule_id
            message.MarketingCarrierIata = marketing_iata or ''
            message.CarrierSirena = sirena or ''
            message.FlyingCarrierIata = flying_iata or ''
            message.DesignatedCarrier = designated or ''
            message.FlightNumberRegex = number_regex or ''
            message.CarrierId = carrier_id or 0
            message.Priority = priority or 0
            yield message.SerializeToString()

        self.logger.info('Done fetching iata correction rules')


def fetch_stations(cursor, logger: logging.Logger) -> List[TStationWithCodes]:
    """Load every row of the ``station_with_codes`` view as ``TStationWithCodes``.

    Args:
        cursor: DB-API style cursor; this function runs its own query on it.
        logger: receives start/done progress messages.

    Returns:
        The stations ordered by id. Nullable columns are stored as the proto
        defaults (``0`` / ``''``), matching the other fetchers in this module.
        Python protobuf raises ``TypeError`` when ``None`` is assigned to a
        scalar field, so a single NULL code would otherwise abort the load.
    """
    logger.info('Start fetching stations')
    stations = []
    # iterate through stations
    cursor.execute(
        '''
        select
            id,
            settlement_id,
            country_id,
            time_zone_id,
            iata,
            icao,
            sirena,
            title_default
        from
            station_with_codes
        order by
            id;
        '''
    )
    for row in cursor:
        station = TStationWithCodes()
        station.Station.Id = row[0]
        station.Station.SettlementId = row[1] or 0
        station.Station.CountryId = row[2] or 0
        station.Station.TimeZoneId = row[3] or 0
        station.IataCode = row[4] or ''
        station.IcaoCode = row[5] or ''
        station.SirenaCode = row[6] or ''
        station.Station.TitleDefault = row[7] or ''

        stations.append(station)

    logger.info('Done fetching stations')
    return stations


class StationFetcher(IDataFetcher):
    """Stream the stations loaded by ``fetch_stations`` as serialized messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield each ``TStationWithCodes`` returned by ``fetch_stations``, serialized."""
        self.logger.info('Start fetching stations')
        stations = fetch_stations(cursor=self._conn.cursor(), logger=self.logger)
        yield from (station.SerializeToString() for station in stations)


class StationStatusSourceFetcher(IDataFetcher):
    """Stream ``station_status_source`` rows as serialized ``TStationStatusSource`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[bytes]:
        """Yield one serialized ``TStationStatusSource`` per (station, status source) row.

        Rows are ordered by ``station_id`` then ``status_source_id``. NULL ids
        become ``0`` and NULL flags become ``False``.
        """
        # The start message previously said "carriers" (copy-paste); it now
        # matches the done message below.
        self.logger.info('Start fetching station status source mapping')
        cursor = self._conn.cursor()
        # iterate through the station -> status source mapping
        cursor.execute(
            '''
            select
                station_id,
                status_source_id,
                whitelist,
                blacklist,
                trusted
            from
                station_status_source
            order by
                station_id, status_source_id
            '''
        )
        for row in cursor:
            station_status_source = TStationStatusSource()
            station_status_source.StationID = row[0] or 0
            station_status_source.StatusSourceID = row[1] or 0
            station_status_source.Whitelist = row[2] or False
            station_status_source.Blacklist = row[3] or False
            station_status_source.Trusted = row[4] or False

            yield station_status_source.SerializeToString()

        self.logger.info('Done fetching station status source mapping')


class FlightStatusSourceFetcher(IDataFetcher):
    """Stream the ``status_source`` table as serialized ``TFlightStatusSource`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TFlightStatusSource`` per row, ordered by ``id``.

        A NULL name becomes ``''`` and a NULL priority becomes ``0``.
        """
        self.logger.info('Start fetching flight status sources')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                name,
                priority
            from
                status_source
            order by
                id;
            '''
        )
        for source_id, source_name, source_priority in db_cursor:
            message = TFlightStatusSource()
            message.Id = source_id
            message.Name = source_name or ''
            message.Priority = source_priority or 0
            yield message.SerializeToString()

        self.logger.info('Done fetching flight status sources')


def get_station(stations_map, route_point):
    """Return ``stations_map[route_point]``, or ``None`` if the code is empty/falsy or unknown."""
    return stations_map.get(route_point) if route_point else None


class FlightStatusFetcher(IDataFetcher):
    """Stream serialized ``TFlightStatus`` messages for recent flights.

    Airport and route-point codes from ``flight_status`` rows are resolved to
    station ids. The lookup is built from the ``stop_point`` table and the codes
    of the ``stations`` passed to the constructor. The constructor stations are
    written into the map last, so they win on conflicting codes.
    """

    def __init__(self, conn, stations: Sequence[TStationWithCodes], logger: logging.Logger):
        self._conn = conn
        self._stations = stations
        self.logger = logger

    def fetch(self) -> Iterable[bytes]:
        """Yield one serialized ``TFlightStatus`` per exported ``flight_status`` row.

        Only rows whose ``flightdate`` is later than 32 days before now are read,
        ordered by ``updatedatutc``. A row with a negative leg number is skipped
        in two cases:
        - its departure and arrival route points contradict each other, or
        - neither a scheduled departure nor a scheduled arrival time is set.
        """
        self.logger.info('Start fetching flight statuses')
        # exported legs where at least one endpoint could not be resolved to a station
        legs_with_unknown_endpoints = 0
        # when arrival and departure flight sources do not agree on what airport is
        legs_with_contradictory_endpoints = 0
        # dop legs with neither a scheduled departure nor a scheduled arrival time
        nodates_dop_flight_count = 0
        dop_flights_count = 0

        stations_map = self._build_stations_map(self._conn.cursor())

        last_day_to_fetch = (datetime.now() - timedelta(days=32)).strftime('%Y-%m-%d')
        cursor = self._conn.cursor()
        # iterate through flight statuses
        cursor.execute(
            '''
            select
                airlineid,
                flightnumber,
                legnumber,
                flightdate,
                statussourceid,
                createdatutc,
                updatedatutc,
                departuretimeactual,
                departuretimescheduled,
                departurestatus,
                departuregate,
                departureterminal,
                departurediverted,
                departuredivertedairportcode,
                departurecreatedatutc,
                departurereceivedatutc,
                departureupdatedatutc,
                arrivaltimeactual,
                arrivaltimescheduled,
                arrivalstatus,
                arrivalgate,
                arrivalterminal,
                arrivaldiverted,
                arrivaldivertedairportcode,
                arrivalcreatedatutc,
                arrivalreceivedatutc,
                arrivalupdatedatutc,
                checkindesks,
                baggagecarousels,
                departureroutepointfrom,
                departureroutepointto,
                arrivalroutepointfrom,
                arrivalroutepointto,
                airlinecode,
                departureairport,
                arrivalairport
            from
                flight_status
            where
                flightdate > %(last_day_to_fetch)s
            order by
                updatedatutc;
            ''',
            {'last_day_to_fetch': last_day_to_fetch},
        )
        # Indices of the scheduled-time columns; they are read twice per row
        # (as strings and as YYYYMMDD dates for the dop-flight checks).
        scheduled_departure_time_column = 8
        scheduled_arrival_time_column = 18
        self.logger.info('Done executing SQL for flight statuses')
        for row in cursor:
            flight_status = TFlightStatus()
            flight_status.AirlineId = row[0]
            flight_status.FlightNumber = row[1]
            flight_status.LegNumber = row[2]
            flight_status.FlightDate = datetime_to_date_string(row[3])
            flight_status.ArrivalSourceId = row[4]
            flight_status.DepartureSourceId = row[4]
            flight_status.CreatedAtUtc = datetime_to_string(row[5])
            flight_status.UpdatedAtUtc = datetime_to_string(row[6])
            flight_status.DepartureTimeActual = datetime_to_string(row[7])
            flight_status.DepartureTimeScheduled = datetime_to_string(row[scheduled_departure_time_column])
            flight_status.DepartureStatus = ensure_string(row[9])
            flight_status.DepartureGate = ensure_string(row[10])
            flight_status.DepartureTerminal = convert_to_latin(ensure_string(row[11]))
            flight_status.DepartureDiverted = bool(row[12])
            flight_status.DepartureDivertedAirportCode = ensure_string(row[13])
            flight_status.DepartureCreatedAtUtc = datetime_to_string(row[14])
            flight_status.DepartureReceivedAtUtc = datetime_to_string(row[15])
            flight_status.DepartureUpdatedAtUtc = datetime_to_string(row[16])
            flight_status.ArrivalTimeActual = datetime_to_string(row[17])
            flight_status.ArrivalTimeScheduled = datetime_to_string(row[scheduled_arrival_time_column])
            flight_status.ArrivalStatus = ensure_string(row[19])
            flight_status.ArrivalGate = ensure_string(row[20])
            flight_status.ArrivalTerminal = convert_to_latin(ensure_string(row[21]))
            flight_status.ArrivalDiverted = bool(row[22])
            flight_status.ArrivalDivertedAirportCode = ensure_string(row[23])
            flight_status.ArrivalCreatedAtUtc = datetime_to_string(row[24])
            flight_status.ArrivalReceivedAtUtc = datetime_to_string(row[25])
            flight_status.ArrivalUpdatedAtUtc = datetime_to_string(row[26])
            flight_status.CheckInDesks = ensure_string(row[27])
            flight_status.BaggageCarousels = ensure_string(row[28])

            departure_route_point_from = ensure_string(row[29])
            departure_route_point_to = ensure_string(row[30])
            arrival_route_point_from = ensure_string(row[31])
            arrival_route_point_to = ensure_string(row[32])
            flight_status.CarrierCode = ensure_string(row[33])

            departure_airport = row[34]
            departure_airport_id = get_station(stations_map, departure_airport)
            arrival_airport = row[35]
            arrival_airport_id = get_station(stations_map, arrival_airport)

            departure_route_point_from_id = get_station(stations_map, departure_route_point_from)
            arrival_route_point_from_id = get_station(stations_map, arrival_route_point_from)
            departure_route_point_to_id = get_station(stations_map, departure_route_point_to)
            arrival_route_point_to_id = get_station(stations_map, arrival_route_point_to)

            departure_date = 0
            if row[scheduled_departure_time_column] is not None:
                departure_date = _datetime_to_int_date(row[scheduled_departure_time_column])
            arrival_date = 0
            if row[scheduled_arrival_time_column] is not None:
                arrival_date = _datetime_to_int_date(row[scheduled_arrival_time_column])

            # Prefer the explicit airport, then fall back to either source's route point.
            route_point_from = departure_airport_id or departure_route_point_from_id or arrival_route_point_from_id or 0
            route_point_to = arrival_airport_id or departure_route_point_to_id or arrival_route_point_to_id or 0

            if flight_status.LegNumber < 0:
                if (
                    departure_route_point_from_id
                    and arrival_route_point_from_id
                    and departure_route_point_from_id != arrival_route_point_from_id
                ):
                    legs_with_contradictory_endpoints += 1
                    continue
                if (
                    departure_route_point_to_id
                    and arrival_route_point_to_id
                    and departure_route_point_to_id != arrival_route_point_to_id
                ):
                    legs_with_contradictory_endpoints += 1
                    continue

                if not departure_date and not arrival_date:
                    nodates_dop_flight_count += 1
                    continue

                dop_flights_count += 1

            # Previously this counter was never incremented and always logged 0.
            # Such legs are still exported; they are only counted for visibility.
            if not route_point_from or not route_point_to:
                legs_with_unknown_endpoints += 1

            if route_point_from:
                flight_status.DepartureStation = route_point_from
            if route_point_to:
                flight_status.ArrivalStation = route_point_to

            yield flight_status.SerializeToString()

        self.logger.info('Legs with unknown endpoints count: %d', legs_with_unknown_endpoints)
        self.logger.info('Legs with contradictory endpoints count: %d', legs_with_contradictory_endpoints)
        self.logger.info('No dates dop flights count: %d', nodates_dop_flight_count)
        self.logger.info('Exported dop flights count: %d', dop_flights_count)

        self.logger.info('Done fetching flight statuses')

    def _build_stations_map(self, cursor):
        """Map every known code/title to a station id.

        Starts from the ``stop_point`` codes. It then adds each constructor
        station's IATA, ICAO and Sirena codes and default title, overriding
        existing entries.
        """
        stations_map = self._fetch_routepoints(cursor)
        for station in self._stations:
            station_id = station.Station.Id
            for code in (station.IataCode, station.IcaoCode, station.SirenaCode, station.Station.TitleDefault):
                if code:
                    stations_map[code] = station_id
        return stations_map

    def _fetch_routepoints(self, cursor):
        """Build a ``{station_code or city_code: station_id}`` map from ``stop_point``.

        Station codes always overwrite earlier entries. A city code is only
        added if that key is not already present, so the first stop point seen
        (by ``id`` order) for a city wins.
        """
        self.logger.info('Start fetching route points')
        # iterate through route points
        stations_map = {}
        cursor.execute(
            '''
            select
                station_code,
                station_id,
                city_code
            from
                stop_point
            where
                station_id >= 0 and station_code is not null
            order by
                id;
            '''
        )
        for row in cursor:
            station_code = row[0]
            station_id = row[1]
            city_code = row[2]
            if station_code:
                stations_map[station_code] = station_id
            if city_code and city_code not in stations_map:
                stations_map[city_code] = station_id

        self.logger.info('Done fetching route points')
        return stations_map


class OverrideFetcher(IDataFetcher):
    """Stream ``flight_pattern_correction`` rows as serialized ``TOverride`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TOverride`` per row.

        Rows are ordered by bucket key, operating start and operating days.
        ``IsDeleted`` is always written as ``False``.
        """
        self.logger.info('Start fetching overrides')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                bucket_key,
                marketing_carrier,
                marketing_carrier_iata,
                marketing_flight_number,
                leg_number,
                operating_from,
                operating_until,
                operating_on_days,
                overrides,
                created_at,
                updated_at
            from
                flight_pattern_correction
            order by
                bucket_key,
                operating_from,
                operating_on_days;
            '''
        )
        for (
            override_id,
            bucket_key,
            carrier,
            carrier_iata,
            flight_number,
            leg_number,
            operating_from,
            operating_until,
            operating_on_days,
            overrides,
            created_at,
            updated_at,
        ) in db_cursor:
            message = TOverride()
            message.Id = override_id
            message.BucketKey = ensure_string(bucket_key)
            message.MarketingCarrier = carrier
            message.MarketingCarrierIata = ensure_string(carrier_iata)
            message.MarketingFlightNumber = ensure_string(flight_number)
            message.LegNumber = leg_number
            message.OperatingFrom = datetime_to_string(operating_from)
            message.OperatingUntil = datetime_to_string(operating_until)
            message.OperatingOnDays = operating_on_days
            message.Overrides = ensure_string(overrides)
            message.CreatedAt = datetime_to_string(created_at)
            message.UpdatedAt = datetime_to_string(updated_at)
            # TODO(u-jeen): support isDeleted properly
            message.IsDeleted = False
            yield message.SerializeToString()

        self.logger.info('Done fetching overrides')


class BlackListRuleFetcher(IDataFetcher):
    """Stream ``blacklist`` rows as serialized ``TBlacklistRule`` messages."""

    # Value of the ``blacklist.source`` column -> proto enum. Built once at class
    # definition rather than on every row; unknown values map to TYPE_UNKNOWN.
    _SOURCE_BY_NAME = {
        'INNOVATA': EFlightBaseSource.TYPE_INNOVATA,
        'SIRENA': EFlightBaseSource.TYPE_SIRENA,
        'APM': EFlightBaseSource.TYPE_APM,
        'DOP': EFlightBaseSource.TYPE_DOP,
        'AMADEUS': EFlightBaseSource.TYPE_AMADEUS,
    }

    def __init__(self, conn, stations: Sequence[TStationWithCodes], logger: logging.Logger):
        self._conn = conn
        # Stored but not read by this class; kept for constructor compatibility.
        self._stations = stations
        self.logger = logger

    def fetch(self) -> Iterable[bytes]:
        """Yield one serialized ``TBlacklistRule`` per row, ordered by ``force`` then ``id``.

        Optional id/code/text columns are only set on the message when truthy.
        Date columns are always written via ``datetime_to_string``, so NULL
        dates become ``''``.
        """
        self.logger.info('Start fetching blacklist rules')
        cursor = self._conn.cursor()
        # iterate through blacklist rules
        cursor.execute(
            '''
            select
                id,
                station_from_id,
                settlement_from_id,
                country_from_id,
                station_to_id,
                settlement_to_id,
                country_to_id,
                carrier_id,
                flight_number,
                flight_date_from,
                flight_date_to,
                national_version,
                language,
                reason,
                force,
                active_since,
                active_until,
                created_at,
                updated_at,
                source
            from
                blacklist
            order by
                force,
                id
            '''
        )
        for row in cursor:
            rule = TBlacklistRule()
            rule.Id = row[0]
            if row[1]:
                rule.StationFromId = row[1]
            if row[2]:
                rule.StationFromSettlement = row[2]
            if row[3]:
                rule.StationFromCountry = row[3]

            if row[4]:
                rule.StationToId = row[4]
            if row[5]:
                rule.StationToSettlement = row[5]
            if row[6]:
                rule.StationToCountry = row[6]

            if row[7]:
                rule.MarketingCarrierId = row[7]
            if row[8]:
                rule.MarketingFlightNumber = row[8]
            rule.FlightDateSince = datetime_to_string(row[9])
            rule.FlightDateUntil = datetime_to_string(row[10])
            if row[11]:
                rule.NationalVersion = row[11]
            if row[12]:
                rule.Language = row[12]
            if row[13]:
                rule.Reason = row[13]
            rule.ForceMode = row[14]
            rule.ActiveSince = datetime_to_string(row[15])
            rule.ActiveUntil = datetime_to_string(row[16])
            rule.CreatedAt = datetime_to_string(row[17])
            rule.UpdatedAt = datetime_to_string(row[18])
            rule.Source = self._string_to_source(row[19])
            yield rule.SerializeToString()

        self.logger.info('Done fetching black list rules')

    def _string_to_source(self, value: str):
        """Translate a ``source`` column value to ``EFlightBaseSource`` (TYPE_UNKNOWN if unrecognized)."""
        return self._SOURCE_BY_NAME.get(value, EFlightBaseSource.TYPE_UNKNOWN)


class FlightMergeRuleFetcher(IDataFetcher):
    """Stream the ``flight_merge_rule`` table as serialized ``TFlightMergeRule`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TFlightMergeRule`` per row, ordered by ``id``.

        NULL ids become ``0`` and NULL texts become ``''``. Boolean columns count
        as true only when the value is ``True`` or the string ``'t'``.
        """

        def as_flag(value):
            # Accept both a native boolean and the textual 't' form.
            return value == 't' or value is True

        self.logger.info('Start fetching flight merge rules')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                operating_carrier_id,
                operating_flight_number_regex,
                marketing_carrier_id,
                marketing_flight_number_regex,
                excluded_carrier_id,
                should_merge,
                is_rule_active,
                updated_at,
                comment
            from
                flight_merge_rule
            order by
                id;
            '''
        )
        for (
            rule_id,
            operating_carrier,
            operating_regex,
            marketing_carrier,
            marketing_regex,
            excluded_carrier,
            should_merge,
            is_active,
            updated_at,
            comment,
        ) in db_cursor:
            message = TFlightMergeRule()
            message.Id = rule_id
            message.OperatingCarrier = operating_carrier or 0
            message.OperatingFlightRegexp = operating_regex or ''
            message.MarketingCarrier = marketing_carrier or 0
            message.MarketingFlightRegexp = marketing_regex or ''
            message.ExcludedCarrier = excluded_carrier or 0
            message.ShouldMerge = as_flag(should_merge)
            message.IsActive = as_flag(is_active)
            message.UpdatedAt = datetime_to_string(updated_at)
            message.Comment = comment or ''
            yield message.SerializeToString()

        self.logger.info('Done fetching flight merge rules')


class LastImportedFetcher(IDataFetcher):
    """Stream ``last_imported_info`` rows as serialized ``TLastImportedInfo`` messages."""

    def __init__(self, conn, logger: logging.Logger):
        # ``conn`` is only used through its ``cursor()`` method.
        self._conn = conn
        self.logger = logger

    def fetch(self) -> Iterable[str]:
        """Yield one serialized ``TLastImportedInfo`` per selected row, ordered by ``id``.

        Only rows imported after 2020-02-01 whose ``resource_type`` ends with
        ``RESOURCE`` are read. NULL dates become ``''`` (``datetime_to_string``
        already handles that), a NULL resource id becomes ``0`` and a NULL
        type becomes ``''``.
        """
        self.logger.info('Start fetching last imported dates')
        db_cursor = self._conn.cursor()
        db_cursor.execute(
            '''
            select
                id,
                imported_date,
                imported_resource_id,
                updated_at,
                resource_type
            from
                last_imported_info
            where
                imported_date > '2020-02-01' and resource_type like '%RESOURCE'
            order by
                id;
            '''
        )
        for info_id, imported_date, resource_id, updated_at, resource_type in db_cursor:
            message = TLastImportedInfo()
            message.Id = info_id
            message.ImportedDate = datetime_to_string(imported_date)
            message.ImportedResourceId = resource_id or 0
            message.UpdatedAt = datetime_to_string(updated_at)
            message.ResourceType = resource_type or ''
            yield message.SerializeToString()

        self.logger.info('Done fetching last imported dates')
