# coding=utf-8
from datetime import datetime, timedelta
from logging import Logger, getLogger

from sqlalchemy.orm import Session

from travel.avia.price_index.db_models.direction_price import DirectionPrice
from travel.avia.price_index.db_models.dynamic_row import DynamicRow
from travel.avia.price_index.db_models.history import History
from travel.avia.price_index.lib import flag_dependent_settings
from travel.avia.price_index.lib.constants import NULL_DATE
from travel.avia.price_index.lib.currency_provider import currency_provider, CurrencyProvider
from travel.avia.price_index.lib.indexer.index_builder import FareIndexBuilder, fare_index_builder
from travel.avia.price_index.lib.price_converter import price_converter, PriceConverter
from travel.avia.price_index.lib.rates_provider import rates_provider, RatesProvider
from travel.avia.price_index.models.flight import Flight
from travel.avia.price_index.models.query import Query


class Indexer(object):
    """Turn a finished search result into stored price-index rows.

    For one search ``query`` the indexer builds ``DynamicRow`` candidates from
    the result's fares, keeps a representative subset of them, diffs that
    subset against the rows already stored for the same query and applies the
    resulting insert/update/delete patch. It also keeps the query's ``History``
    record and its ``DirectionPrice`` (price of the cheapest kept row) in sync.
    """

    # Upper bound on the number of rows ``_filter_dynamic_items`` keeps per query.
    _MAX_FILTERED_ROWS = 100

    # Itineraries with more flights than this in either direction are skipped
    # unless the stops filter is disabled via ``flag_dependent_settings``.
    _MAX_FLIGHTS_PER_DIRECTION = 3

    # Format of the ``local`` timestamps in the result's flight reference data.
    _LOCAL_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%S'

    def __init__(self, fare_index_builder, price_converter, rates_provider, currency_provider, logger):
        # type: (FareIndexBuilder, PriceConverter, RatesProvider, CurrencyProvider, Logger) -> None
        self._fare_index_builder = fare_index_builder
        self._price_converter = price_converter
        self._rates_provider = rates_provider
        self._currency_provider = currency_provider
        self._logger = logger

    def index(self, session, query, result):
        """Index a search ``result`` for ``query`` using ``session``.

        Any exception is logged with its traceback and then re-raised.
        """
        self._logger.info('start index')
        try:
            self._index(session, query, result)
            self._logger.info('finish index')
        except Exception as e:
            self._logger.exception('Can not index: %r', e)
            raise

    def _fetch_dynamic_rows(self, session, query):
        """Return the ``DynamicRow`` records already stored for ``query``.

        Rows are matched on direction, passenger counts, national version and
        dates; a missing backward date is looked up as ``NULL_DATE``.
        """
        return list(
            session.query(DynamicRow).filter(
                DynamicRow.from_id == query.from_id,
                DynamicRow.to_id == query.to_id,
                DynamicRow.adults_count == query.adults_count,
                DynamicRow.children_count == query.children_count,
                DynamicRow.infants_count == query.infants_count,
                DynamicRow.national_version_id == query.national_version_id,
                DynamicRow.forward_date == query.forward_date,
                DynamicRow.backward_date == (query.backward_date or NULL_DATE),
            )
        )

    def _has_baggage(self, price):
        """Return True unless some baggage entry of ``price`` denotes no baggage.

        ``price['baggage']`` holds two sequences (presumably forward and
        backward segments -- TODO confirm). An entry that is None or starts
        with '0' makes the whole price count as "without baggage".
        """
        return not any(
            entry is None or entry.startswith('0')
            for entry in price['baggage'][0] + price['baggage'][1]
        )

    def _get_min_price_from(self, query, fare, base_currency_id, rate_by_currency_id):
        """Return the cheapest offer of ``fare`` without and with baggage.

        Each offer is a ``(base_value, value, currency_id)`` tuple, where
        ``base_value`` is ``value`` converted to ``base_currency_id``. Offers in a
        currency unknown to the currency provider, or whose conversion returns
        None, are skipped. Each element of the returned pair
        ``(without_baggage, with_baggage)`` is None when no offer qualifies.
        """
        with_baggage = []
        without_baggage = []

        for price in fare['prices']:
            tariff = price['tariff']
            value = tariff['value']

            currency = self._currency_provider.get_by_code(tariff['currency'], query.national_version_id)
            if currency is None:
                continue
            currency_id = currency.pk

            base_value = self._price_converter.convert_to_base_currency_id(
                value, currency_id, base_currency_id, rate_by_currency_id
            )
            if base_value is None:
                continue

            offer = (base_value, value, currency_id)
            if self._has_baggage(price):
                with_baggage.append(offer)
            else:
                without_baggage.append(offer)

        def _cheapest(offers):
            # Compare by the base-currency value only.
            return min(offers, key=lambda offer: offer[0]) if offers else None

        return _cheapest(without_baggage), _cheapest(with_baggage)

    def _build_diff(self, new_dynamic_rows, old_dynamic_rows):
        """Compare freshly built rows with the stored ones.

        Rows are matched on every index attribute (baggage, airlines, airports,
        time-of-day types and transfer properties). Returns a dict with:

        * ``insert_data`` -- new rows that have no stored counterpart;
        * ``update_data`` -- ``(old, new)`` pairs whose ``base_value`` differs
          by more than 0.01;
        * ``delete_data`` -- stored rows that have no new counterpart;
        * ``min_data`` -- the first new row (the cheapest one when the rows come
          from ``_filter_dynamic_items``), or None when there are no new rows.
        """

        def _make_key(row):
            # List-valued attributes are sorted so that ordering differences
            # do not create spurious inserts/deletes.
            return (
                row.has_baggage,
                tuple(sorted(row.airlines)),
                row.forward_departure_airport,
                tuple(sorted(row.forward_transfer_airports)),
                row.forward_arrival_airport,
                row.backward_departure_airport,
                tuple(sorted(row.backward_transfer_airports)),
                row.backward_arrival_airport,
                row.forward_departure_time_type,
                row.forward_arrival_time_type,
                row.backward_departure_time_type,
                row.backward_arrival_time_type,
                row.count_transfer,
                row.duration_transfer,
                row.has_airport_change,
                row.has_night_transfer,
            )

        new_by_key = {_make_key(row): row for row in new_dynamic_rows}
        old_by_key = {_make_key(row): row for row in old_dynamic_rows}
        new_keys = set(new_by_key)
        old_keys = set(old_by_key)

        updated_keys = [
            key
            for key in new_keys & old_keys
            # Ignore sub-cent noise; only a real price change triggers an UPDATE.
            if abs(old_by_key[key].base_value - new_by_key[key].base_value) > 0.01
        ]

        return {
            'insert_data': [new_by_key[key] for key in new_keys - old_keys],
            'update_data': [(old_by_key[key], new_by_key[key]) for key in updated_keys],
            'delete_data': [old_by_key[key] for key in old_keys - new_keys],
            'min_data': new_dynamic_rows[0] if new_dynamic_rows else None,
        }

    def _apply_insert_changes(self, session, rows):
        """Bulk-insert ``rows`` and flush; no-op (besides a log line) if empty."""
        if not rows:
            self._logger.info('Nothing insert')
            return
        self._logger.info('Start: Process insert dynamic rows')

        self._logger.info('Finish: Process insert dynamic rows')
        self._logger.info('Start: Flush insert changes')

        session.bulk_save_objects(rows)
        session.flush()

        self._logger.info('Finish: Flush insert changes')

    def _apply_update_changes(
        self,
        session,
        old_and_new_rows,
    ):
        """Copy the price fields of each new row onto its stored counterpart.

        ``old_and_new_rows`` is a sequence of ``(old, new)`` pairs; the stored
        row is addressed by ``old.pk``. Updates bypass session synchronisation.
        """
        if not old_and_new_rows:
            self._logger.info('Nothing update')
            return
        self._logger.info('Start: Process update dynamic rows')
        for old, new in old_and_new_rows:
            session.query(DynamicRow).filter(DynamicRow.pk == old.pk).update(
                {'base_value': new.base_value, 'value': new.value, 'currency_id': new.currency_id},
                synchronize_session=False,
            )
        self._logger.info('Finish: Process update dynamic rows')
        self._logger.info('Start: Flush update dynamic rows')
        session.flush()
        self._logger.info('Finish: Flush update dynamic rows')

    def _apply_delete_changes(self, session, rows):
        """Delete the stored ``rows`` (by primary key) in one statement and flush."""
        if not rows:
            self._logger.info('Nothing delete')
            return

        self._logger.info('Start: Process delete dynamic rows')
        session.query(DynamicRow).filter(DynamicRow.pk.in_(tuple(row.pk for row in rows))).delete(
            synchronize_session=False
        )
        self._logger.info('Finish: Process delete dynamic rows')
        self._logger.info('Start: Flush delete dynamic rows')
        session.flush()
        self._logger.info('End: Flush delete dynamic rows')

    def _direction_price_criteria(self, query):
        """Return the filter criteria selecting the ``DirectionPrice`` row(s) of ``query``."""
        # NOTE(review): unlike DynamicRow/History, passenger counts are not part
        # of this filter even though they are written on insert -- confirm this
        # matches the table's uniqueness constraint.
        return (
            DirectionPrice.national_version_id == query.national_version_id,
            DirectionPrice.from_id == query.from_id,
            DirectionPrice.to_id == query.to_id,
            DirectionPrice.forward_date == query.forward_date,
            DirectionPrice.backward_date == (query.backward_date or NULL_DATE),
        )

    def _apply_minimal_price(self, session, query, min_data):
        """Synchronise the ``DirectionPrice`` of ``query`` with ``min_data``.

        The stored row is deleted when ``min_data`` is None, created when it
        does not exist yet, and otherwise overwritten with ``min_data``'s price.
        The session is flushed afterwards.
        """
        self._logger.info('Start: Calculating direction price')
        direction_prices = session.query(DirectionPrice).filter(*self._direction_price_criteria(query))
        has = bool(direction_prices.count())

        if min_data is None and has:
            self._logger.info('Delete min price')
            direction_prices.delete()
        elif min_data and not has:
            self._logger.info('Insert min price')
            session.add(
                DirectionPrice(
                    national_version_id=query.national_version_id,
                    from_id=query.from_id,
                    to_id=query.to_id,
                    adults_count=query.adults_count,
                    children_count=query.children_count,
                    infants_count=query.infants_count,
                    forward_date=query.forward_date,
                    backward_date=(query.backward_date or NULL_DATE),
                    base_value=min_data.base_value,
                    value=min_data.value,
                    currency_id=min_data.currency_id,
                    created_at=datetime.now(),
                )
            )
        elif min_data and has:
            self._logger.info('Update min price')
            direction_prices.update(
                {'base_value': min_data.base_value, 'value': min_data.value, 'currency_id': min_data.currency_id}
            )
        self._logger.info('Finish: Calculating direction price')
        self._logger.info('Start: Flush update direction price')
        session.flush()
        self._logger.info('Finish: Flush update direction price')

    def _apply_patch(self, session, query, diff):
        """Apply a ``_build_diff`` result: inserts, updates, deletes, then the minimum price."""
        self._apply_insert_changes(session=session, rows=diff['insert_data'])
        self._apply_update_changes(session=session, old_and_new_rows=diff['update_data'])
        self._apply_delete_changes(session=session, rows=diff['delete_data'])

        self._apply_minimal_price(session=session, query=query, min_data=diff['min_data'])

    def build_dynamics_item(self, query, index, has_baggage, price):
        """Build an unsaved ``DynamicRow`` for ``query`` from a fare index and a price.

        :param index: fare index produced by the fare index builder
            (airline, airport, transfer and time sub-indexes).
        :param has_baggage: whether ``price`` includes baggage.
        :param price: ``(base_value, value, currency_id)`` tuple.
        """
        airport_index = index.airport_index
        transfer_index = index.transfer_index
        time_index = index.time_index
        base_value, value, currency_id = price

        return DynamicRow(
            national_version_id=query.national_version_id,
            from_id=query.from_id,
            to_id=query.to_id,
            adults_count=query.adults_count,
            children_count=query.children_count,
            infants_count=query.infants_count,
            forward_date=query.forward_date,
            backward_date=query.backward_date or NULL_DATE,
            has_baggage=has_baggage,
            base_value=base_value,
            value=value,
            currency_id=currency_id,
            created_at=datetime.now(),
            airlines=list(index.airline_index),
            forward_departure_airport=airport_index.forward.departure,
            forward_transfer_airports=list(airport_index.forward.transfer),
            forward_arrival_airport=airport_index.forward.arrival,
            backward_departure_airport=airport_index.backward.departure,
            backward_transfer_airports=list(airport_index.backward.transfer),
            backward_arrival_airport=airport_index.backward.arrival,
            forward_departure_time_type=time_index.forward_departure,
            forward_arrival_time_type=time_index.forward_arrival,
            backward_departure_time_type=time_index.backward_departure,
            backward_arrival_time_type=time_index.backward_arrival,
            count_transfer=transfer_index.count,
            duration_transfer=transfer_index.duration,
            has_airport_change=transfer_index.has_airport_change,
            has_night_transfer=transfer_index.has_night_transfer,
        )

    def _build_dynamics_items(self, query, fares, flight_by_key, base_currency_id, rate_by_currency_id):
        """Yield up to two ``DynamicRow`` candidates (without/with baggage) per fare.

        Fares are skipped when their forward route is empty, when either
        direction has more than ``_MAX_FLIGHTS_PER_DIRECTION`` flights (unless
        the stops filter is disabled), or when they have no usable price.
        """
        for fare in fares:
            route = fare['route']
            forward_keys, backward_keys = route[0], route[1]

            if not forward_keys:
                continue

            if not flag_dependent_settings.disable_stops_filter():
                limit = self._MAX_FLIGHTS_PER_DIRECTION
                if len(forward_keys) > limit or len(backward_keys) > limit:
                    continue

            # Prices are computed only for fares that survived the route checks.
            price_without_baggage, price_with_baggage = self._get_min_price_from(
                query=query, fare=fare, base_currency_id=base_currency_id, rate_by_currency_id=rate_by_currency_id
            )
            if price_without_baggage is None and price_with_baggage is None:
                continue

            forward_flights = tuple(flight_by_key[key] for key in forward_keys)
            backward_flights = tuple(flight_by_key[key] for key in backward_keys)
            fare_index = self._fare_index_builder.index(forward_flights, backward_flights)

            if price_without_baggage is not None:
                yield self.build_dynamics_item(query, fare_index, False, price_without_baggage)
            if price_with_baggage is not None:
                yield self.build_dynamics_item(query, fare_index, True, price_with_baggage)

    def _build_flight(self, raw_flight):
        """Build a ``Flight`` from one entry of the result's ``reference.flights``."""
        arrival = raw_flight['arrival']
        departure = raw_flight['departure']
        arrival_local = datetime.strptime(arrival['local'], self._LOCAL_DATETIME_FORMAT)
        departure_local = datetime.strptime(departure['local'], self._LOCAL_DATETIME_FORMAT)

        # NOTE(review): the UTC time is computed as local + offset. If ``offset``
        # is the usual UTC offset in minutes (local = UTC + offset), this should
        # be local - offset -- confirm the upstream sign convention.
        return Flight(
            departure=departure_local,
            arrival=arrival_local,
            departure_utc=departure_local + timedelta(minutes=departure['offset']),
            arrival_utc=arrival_local + timedelta(minutes=arrival['offset']),
            airline_id=raw_flight['company'],
            from_id=raw_flight['from'],
            to_id=raw_flight['to'],
            number=raw_flight['number'],
        )

    def _filter_dynamic_items(self, dynamics_rows):
        """Yield a representative subset of ``dynamics_rows``, cheapest first.

        Rows are visited in ascending ``base_value`` order. A row is kept when
        at least one of its keys (one per airline; see ``_make_keys``) has not
        been seen on an earlier row. As a result, each combination of baggage,
        transfer count, airline, endpoint airports and time-of-day types is
        represented by its cheapest row. At most ``_MAX_FILTERED_ROWS`` rows
        are yielded.
        """
        if not dynamics_rows:
            return

        def _make_keys(dynamic_row):
            # A row without airlines still produces exactly one key.
            airlines = dynamic_row.airlines or [None]
            for airline in airlines:
                yield (
                    dynamic_row.has_baggage,
                    dynamic_row.count_transfer,
                    airline,
                    # departure and arrival airports
                    dynamic_row.forward_departure_airport,
                    dynamic_row.forward_arrival_airport,
                    dynamic_row.backward_departure_airport,
                    dynamic_row.backward_arrival_airport,
                    # departure and arrival time-of-day types
                    dynamic_row.forward_departure_time_type,
                    dynamic_row.forward_arrival_time_type,
                    dynamic_row.backward_departure_time_type,
                    dynamic_row.backward_arrival_time_type,
                )

        by_price = sorted(dynamics_rows, key=lambda d: d.base_value)
        cheapest = by_price[0]
        direct = [row for row in by_price if row.count_transfer == 0]
        with_transfers = [row for row in by_price if row.count_transfer != 0]

        # A stable sort of this concatenation orders rows by price. Among equally
        # priced rows, it puts the overall cheapest row first and direct
        # itineraries before ones with transfers. The second copy of ``cheapest``
        # is harmless: all of its keys are already seen by then.
        candidates = sorted([cheapest] + direct + with_transfers, key=lambda d: d.base_value)

        seen_keys = set()
        yielded = 0
        for item in candidates:
            unseen = set(_make_keys(item)) - seen_keys
            if not unseen:
                continue
            seen_keys |= unseen
            yielded += 1
            yield item
            if yielded >= self._MAX_FILTERED_ROWS:
                return

    def _history_criteria(self, query):
        """Return the filter criteria selecting the ``History`` row(s) of ``query``."""
        return (
            History.national_version_id == query.national_version_id,
            History.from_id == query.from_id,
            History.to_id == query.to_id,
            History.adults_count == query.adults_count,
            History.children_count == query.children_count,
            History.infants_count == query.infants_count,
            History.forward_date == query.forward_date,
            History.backward_date == (query.backward_date or NULL_DATE),
        )

    def _update_history(self, session, query, has_variants):
        # type: (Session, Query, bool) -> None
        """Create or refresh the ``History`` record of ``query`` and flush.

        An existing record gets a fresh ``updated_at`` and ``has_variants``. A
        new record gets ``created_at`` and ``updated_at`` set to the same instant.
        """
        history = session.query(History).filter(*self._history_criteria(query))
        now = datetime.now()

        if history.count():
            history.update({'updated_at': now, 'has_variants': has_variants}, synchronize_session=False)
        else:
            session.add(
                History(
                    national_version_id=query.national_version_id,
                    from_id=query.from_id,
                    to_id=query.to_id,
                    adults_count=query.adults_count,
                    children_count=query.children_count,
                    infants_count=query.infants_count,
                    forward_date=query.forward_date,
                    backward_date=query.backward_date or NULL_DATE,
                    created_at=now,
                    has_variants=has_variants,
                    updated_at=now,
                )
            )
        session.flush()

    def _index(self, session, query, result):
        """Index ``result`` for ``query``; see ``index`` for the public entry point.

        Results whose status is not ``'success'``, or whose progress is
        incomplete, are logged and ignored.

        :raises Exception: if the base currency or the currency rates for the
            query's national version cannot be fetched.
        """
        status = result['status']
        if status != 'success':
            self._logger.warning('unexpected data status: [%r]', status)
            return

        progress = result['data']['progress']
        if progress['all'] != progress['current']:
            self._logger.warning(
                'Can not index variants, because progress status is broken %d/%d',
                progress['all'],
                progress['current'],
            )
            return

        national_version_id = query.national_version_id
        base_currency_id = self._rates_provider.get_base_currency_id(national_version_id)
        if base_currency_id is None:
            raise Exception('Can not fetch base currency id for [{}]'.format(national_version_id))
        rates_by_currency_id = self._rates_provider.get_rates_for(national_version_id)
        if rates_by_currency_id is None:
            raise Exception('Can not fetch rates by currency id for [{}]'.format(national_version_id))

        data = result['data']
        fare_list = data['variants']['fares']
        flight_by_key = {raw['key']: self._build_flight(raw) for raw in data['reference']['flights']}

        self._logger.info('Before build dynamics items')
        dynamics_rows = list(
            self._build_dynamics_items(query, fare_list, flight_by_key, base_currency_id, rates_by_currency_id)
        )
        self._logger.info('Finish build dynamics items')

        self._logger.info('Before filter dynamic items')
        new_dynamic_rows = list(self._filter_dynamic_items(dynamics_rows))
        self._logger.info('Finish filter dynamic items')

        self._update_history(session, query, bool(new_dynamic_rows))

        old_dynamic_rows = self._fetch_dynamic_rows(session=session, query=query)
        self._logger.info('Rows count exists %r, fares received %r', len(old_dynamic_rows), len(fare_list))

        diff = self._build_diff(new_dynamic_rows=new_dynamic_rows, old_dynamic_rows=old_dynamic_rows)

        self._logger.info(
            'Diff[I/U/D]: [%d]/[%d]/[%d]',
            len(diff['insert_data']),
            len(diff['update_data']),
            len(diff['delete_data']),
        )

        self._apply_patch(session=session, query=query, diff=diff)


# Module-level Indexer instance wired to the module-level collaborators imported
# above, logging under this module's logger name.
indexer = Indexer(
    fare_index_builder=fare_index_builder,
    price_converter=price_converter,
    rates_provider=rates_provider,
    currency_provider=currency_provider,
    logger=getLogger(__name__),
)
