import logging
import ujson
from datetime import datetime
from typing import Any

import time
import zlib
from sqlalchemy.sql import text
from yt.wrapper import YtClient

from travel.avia.library.python.shared_dicts.cache.settlement_cache import SettlementCache
from travel.avia.library.python.urls.route_landing import TravelAviaRouteLanding
from travel.avia.library.python.urls.search import TravelAviaSearch
from travel.avia.price_index.lib.currency_provider import CurrencyModel
from travel.avia.price_index.lib.currency_provider import currency_provider
from travel.avia.price_index.lib.db.storage import Storage
from travel.avia.price_index.lib.db.storage import slave_storage
from travel.avia.price_index.lib.national_version_provider import national_version_provider
from travel.avia.price_index.lib.rates_provider import rates_provider

# Module-level logger shared by the functions and the Handler in this file.
logger = logging.getLogger(__name__)

# Type alias for settlement identifiers, e.g. the integer `from_id` / `to_id`
# pairs returned by `get_landing_routes`.
SettlementID = int


def get_landing_routes(
    yt_client: YtClient, landing_routes_table: str, nv: str
) -> set[tuple[SettlementID, SettlementID]]:
    """Return the ``(from_id, to_id)`` pairs listed in ``landing_routes_table`` for ``nv``.

    Every row of the table is read; a row is kept when its ``national_version``
    equals ``nv`` lower-cased. Only ``nv`` is lower-cased, so the table values
    are compared as-is (presumably they are stored lower-case — verify).

    :param yt_client: client used to read the table.
    :param landing_routes_table: YT path of the landing routes table.
    :param nv: national version code.
    :return: set of ``(from_id, to_id)`` pairs converted to ``int``.
    """
    # Loop-invariant: lower-case the requested version once, not once per row.
    wanted_nv = nv.lower()
    return {
        (int(row['from_id']), int(row['to_id']))
        for row in yt_client.read_table(landing_routes_table)
        if row['national_version'] == wanted_nv
    }


def write_to_yt(results: list[dict[str, Any]], destination_table: str, yt_client: YtClient) -> None:
    """Replace ``destination_table`` with ``results`` inside one YT transaction.

    The rows are first written into a freshly created temporary table that
    carries the feed schema below. That table is then moved onto
    ``destination_table`` with ``force=True``, so an existing node at that
    path is overwritten.

    :param results: feed rows (as produced by ``prepare_rows``) carrying the
        columns listed in ``feed_columns``.
    :param destination_table: YT path that receives the feed.
    :param yt_client: client used for the transaction, write and move.
    """
    # Column name -> YT type, in output order.
    feed_columns = (
        ('departure_settlement_id', 'int64'),
        ('departure_settlement_title', 'string'),
        ('departure_settlement_title_from', 'string'),
        ('arrival_settlement_id', 'int64'),
        ('arrival_settlement_title', 'string'),
        ('arrival_settlement_title_to', 'string'),
        ('forward_date', 'string'),
        ('backward_date', 'string'),
        ('price', 'double'),
        ('currency', 'string'),
        ('transfers', 'uint8'),
        ('search_url', 'string'),
        ('search_url_no_date', 'string'),
        ('route_url', 'string'),
        ('updated_at', 'datetime'),
    )
    schema = [{'name': column, 'type': yt_type} for column, yt_type in feed_columns]

    logger.info('Write to YT')
    with yt_client.Transaction():
        logger.info('Create temp table')
        staging_table = yt_client.create_temp_table(attributes={'schema': schema})

        logger.info('Write results to %s', staging_table)
        yt_client.write_table(staging_table, results)

        logger.info('Move "%s" to "%s"', staging_table, destination_table)
        yt_client.move(staging_table, destination_table, force=True)

    logger.info('Results: %s', destination_table)


def prepare_rows(
    result: list[dict[str, Any]],
    landing_routes: set[tuple[SettlementID, SettlementID]],
    route_landing_url: TravelAviaRouteLanding,
    search_url: TravelAviaSearch,
    settlements: SettlementCache,
):
    """Enrich rows returned by ``dump`` with settlement titles and URLs.

    A row is dropped, with a warning, when its departure or arrival settlement
    id is unknown to ``settlements``. Every kept row is updated *in place* and
    collected into the returned list. It gains these keys:

    * ``departure_settlement_title`` / ``arrival_settlement_title`` — the
      settlements' ``TitleDefault``;
    * ``departure_settlement_title_from`` / ``arrival_settlement_title_to`` —
      the Russian genitive / accusative titles, falling back to
      ``TitleDefault`` when empty;
    * ``search_url`` — search link for the row's forward/backward dates;
    * ``search_url_no_date`` — the same search with an empty ``when``;
    * ``route_url`` — a route landing link built from settlement slugs when
      the ``(departure id, arrival id)`` pair is in ``landing_routes``,
      otherwise ``None``.

    Before the URLs are built, a ``backward_date`` equal to ``'1900-01-01'``
    is replaced with ``None``. This is presumably the stored "no return date"
    sentinel — confirm.
    """
    enriched = []
    logger.info('Add data from references to results ...')
    for row in result:
        departure = settlements.get_settlement_by_id(row['departure_settlement_id'])
        if not departure:
            logger.warning('Unknown departure settlement: %d', row['departure_settlement_id'])
            continue
        arrival = settlements.get_settlement_by_id(row['arrival_settlement_id'])
        if not arrival:
            logger.warning('Unknown arrival settlement: %d', row['arrival_settlement_id'])
            continue

        if row['backward_date'] == '1900-01-01':
            row['backward_date'] = None

        # TODO(mikhailche): use national version or add `language` parameter to select translation
        departure_title = departure.TitleDefault
        departure_title_from = departure.Title.Ru.Genitive or departure.TitleDefault
        arrival_title = arrival.TitleDefault
        arrival_title_to = arrival.Title.Ru.Accusative or arrival.TitleDefault
        # END(TODO)

        dated_search = search_url.url(
            from_id=departure.Id,
            to_id=arrival.Id,
            when=row['forward_date'],
            return_date=row['backward_date'],
        )
        undated_search = search_url.url(from_id=departure.Id, to_id=arrival.Id, when='')

        row.update(
            {
                'departure_settlement_title': departure_title,
                'departure_settlement_title_from': departure_title_from,
                'arrival_settlement_title': arrival_title,
                'arrival_settlement_title_to': arrival_title_to,
                'search_url': dated_search,
                'search_url_no_date': undated_search,
                'route_url': None,
            }
        )

        pair = (departure.Id, arrival.Id)
        if pair in landing_routes:
            row['route_url'] = route_landing_url.url(
                from_slug=settlements.get_slug_by_id(departure.Id),
                to_slug=settlements.get_slug_by_id(arrival.Id),
            )

        enriched.append(row)
    return enriched


def dump(
    national_version_id: int,
    from_date: datetime,
    to_date: datetime,
    updated_at: datetime,
    max_transfers: int,
    base_currency: CurrencyModel,
    currency_rates: dict[int, float],
    storage: Storage,
):
    """Fetch the cheapest stored price per direction from ``storage``.

    The SQL query filters ``result`` rows by these conditions:

    * 1 adult, no children and no infants;
    * the given ``national_version_id``;
    * ``forward_date`` within ``[from_date, to_date]``;
    * ``updated_at`` not older than the ``updated_at`` argument.

    All bounds are passed as ``'%Y-%m-%d'`` strings, i.e. at day precision.
    For every ``(from_id, to_id)`` pair, the query keeps only the row with the
    lowest ``base_value``.

    That row's ``gzip_data`` is ``zlib``-decompressed and parsed as a JSON list
    of variants. Variants with ``count_transfer > max_transfers`` are skipped.
    Each remaining ``value`` is multiplied by ``currency_rates[currency_id]``.
    The lowest product wins, with ties broken by fewer transfers. Rows with
    empty ``gzip_data`` or no admissible variant are skipped.

    NOTE(review): transfer filtering happens only after SQL has already picked
    one row per direction. A direction is therefore dropped when its cheapest
    row has no variant within ``max_transfers``, even if another row for that
    direction could qualify.
    NOTE(review): the session from ``storage.get_session()`` is never closed
    here — confirm whether ``Storage`` manages its lifetime.

    :return: list of dicts, one per direction, with these keys:
        ``departure_settlement_id``, ``arrival_settlement_id``,
        ``forward_date`` and ``backward_date`` (``'%Y-%m-%d'`` strings),
        ``price``, ``currency`` (``base_currency.code``), ``transfers``, and
        ``updated_at``. ``updated_at`` is Unix seconds from ``time.mktime``,
        which treats the DB timestamp as local time.
    """
    day_format = '%Y-%m-%d'
    session = storage.get_session()

    query = text(
        """
          SELECT from_id,
                 to_id,
                 forward_date,
                 backward_date,
                 updated_at,
                 gzip_data
          FROM (
                   SELECT from_id,
                          to_id,
                          forward_date,
                          backward_date,
                          base_value,
                          updated_at,
                          gzip_data,
                          row_number() over (partition by from_id, to_id order by base_value) as rn
                   FROM result
                   WHERE adults_count = 1
                     AND children_count = 0
                     AND infants_count = 0
                     AND national_version_id = :national_version_id
                     AND forward_date >= :min_forward_date
                     AND forward_date <= :max_forward_date
                     AND updated_at >= :min_updated_at
                   ORDER BY base_value
               ) tmp
          WHERE rn = 1
          ORDER BY base_value
      """
    )
    params = {
        'national_version_id': national_version_id,
        'min_forward_date': from_date.strftime(day_format),
        'max_forward_date': to_date.strftime(day_format),
        'min_updated_at': updated_at.strftime(day_format),
    }
    rows = session.execute(query, params)

    def cheapest_variant(variants):
        # (price, transfers) of the cheapest admissible variant, or (None, None).
        best_price = None
        best_transfers = None
        for variant in variants:
            hops = variant['count_transfer']
            if hops > max_transfers:
                continue
            price = variant['value'] * currency_rates[variant['currency_id']]
            # Short-circuit order matters: `best_transfers > hops` must not run while it is None.
            if (
                best_price is None
                or best_price > price
                or (best_price == price and best_transfers > hops)
            ):
                best_price, best_transfers = price, hops
        return best_price, best_transfers

    directions = []
    # Column order mirrors the outer SELECT list of the query above.
    for from_id, to_id, forward_day, backward_day, row_updated_at, blob in rows:
        forward_date = forward_day.strftime(day_format)
        backward_date = backward_day.strftime(day_format)

        if not blob:
            continue
        price, transfers = cheapest_variant(ujson.loads(zlib.decompress(blob)))
        if price is None:
            continue

        directions.append(
            {
                'departure_settlement_id': from_id,
                'arrival_settlement_id': to_id,
                'forward_date': forward_date,
                'backward_date': backward_date,
                'price': price,
                'currency': base_currency.code,
                'transfers': transfers,
                'updated_at': int(time.mktime(row_updated_at.timetuple())),
            }
        )

        if not len(directions) % 10000:
            logger.info('Fetched %d directions', len(directions))

    logger.info('Total directions: %d', len(directions))

    return directions


class Handler:
    """Builds the price feed and writes it to YT through the given client."""

    def __init__(self, yt_client: YtClient):
        self._yt_client = yt_client

    def create_feed(
        self,
        from_date: datetime,
        to_date: datetime,
        updated_at: datetime,
        max_transfers: int,
        destination_table: str,
        national_version: str,
        landing_routes_table: str,
        host: str,
    ) -> None:
        """Fetch, enrich and publish the cheapest-direction feed.

        :param from_date: lower bound of ``forward_date`` (passed to ``dump``).
        :param to_date: upper bound of ``forward_date`` (passed to ``dump``).
        :param updated_at: minimum ``updated_at`` of stored results.
        :param max_transfers: variants with more transfers are ignored.
        :param destination_table: YT path overwritten with the feed.
        :param national_version: national version code. It filters the
            landing routes (lower-cased) and is resolved through
            ``national_version_provider.get_by_code``.
        :param landing_routes_table: YT path of the landing routes table.
        :param host: host used to build search and route-landing URLs.

        NOTE(review): what ``get_by_code`` does for an unknown code is not
        visible here. If it returns ``None``, the ``.pk`` access below raises
        ``AttributeError``.
        """
        search_url = TravelAviaSearch(host)
        route_landing_url = TravelAviaRouteLanding(host)
        landing_routes = get_landing_routes(self._yt_client, landing_routes_table, national_version)

        settlement_cache = SettlementCache(logger)
        settlement_cache.populate()

        national_version_provider.fetch()
        nv_model = national_version_provider.get_by_code(national_version)

        currency_provider.fetch()
        rates_provider.fetch()
        nv_id = nv_model.pk
        base_currency = currency_provider.get_by_id(rates_provider.get_base_currency_id(nv_id), nv_id)
        currency_rates = rates_provider.get_rates_for(nv_id)

        raw_rows = dump(
            national_version_id=nv_id,
            from_date=from_date,
            to_date=to_date,
            updated_at=updated_at,
            max_transfers=max_transfers,
            base_currency=base_currency,
            currency_rates=currency_rates,
            storage=slave_storage,
        )
        feed_rows = prepare_rows(
            raw_rows,
            landing_routes=landing_routes,
            route_landing_url=route_landing_url,
            search_url=search_url,
            settlements=settlement_cache,
        )
        write_to_yt(feed_rows, destination_table, yt_client=self._yt_client)
