# -*- coding: utf-8 -*-
import time
import zlib
from datetime import datetime, timedelta

import ujson
from sqlalchemy.sql import text
from typing import Dict, Set, Tuple

from travel.avia.library.python.flying_time.flying_times_cache import FlyingTimesCache
from travel.avia.library.python.shared_dicts.cache.settlement_cache import SettlementCache

from travel.avia.price_index.lib.currency_provider import currency_provider, CurrencyModel
from travel.avia.price_index.lib.db.storage import slave_storage
from travel.avia.price_index.lib.national_version_provider import national_version_provider
from travel.avia.price_index.lib.rates_provider import rates_provider


class Dumper(object):
    """Dump per-route minimum prices of one national version's landing routes into a YT table.

    Prices are read from the ``result`` table of the slave database. Only routes listed in the
    landing routes YT table are kept. Each kept row is enriched with settlement geo ids,
    search/route URLs and flying time, then written to a destination YT table.
    """

    def __init__(
        self, national_version, search_url_template, route_url_template, yt_client, landing_routes_table, logger
    ):
        """
        :param national_version: national version code. It is passed as-is to
            ``national_version_provider.get_by_code`` and lower-cased to match the
            ``national_version`` column of the landing routes table.
        :param search_url_template: ``str.format`` template with ``from_id``, ``to_id`` and ``when`` fields.
        :param route_url_template: ``str.format`` template with ``from_slug`` and ``to_slug`` fields.
        :param yt_client: YT client used to read landing routes and to write the results.
        :param landing_routes_table: YT table read to obtain the landing routes.
        :param logger: logger for progress messages and warnings.
        """
        self._national_version = national_version
        self._search_url_template = search_url_template
        self._route_url_template = route_url_template
        self._landing_routes_table = landing_routes_table
        self._yt_client = yt_client
        self._logger = logger

    def _get_landing_routes(self):
        # type: () -> Set[Tuple[int, int]]
        """Return the ``(from_id, to_id)`` settlement id pairs of this national version's landing routes."""
        # Hoisted out of the comprehension: rows are matched against the lower-cased code.
        national_version = self._national_version.lower()
        routes = {
            (int(row['from_id']), int(row['to_id']))
            for row in self._yt_client.read_table(self._landing_routes_table)
            if row['national_version'] == national_version
        }
        self._logger.info('Loaded %d landing routes', len(routes))

        return routes

    def _write_to_yt(self, result, destination_table, settlements, flying_times_cache, landing_routes):
        # type: (list, str, SettlementCache, FlyingTimesCache, Set[Tuple[int, int]]) -> None
        """Enrich ``result`` rows and replace ``destination_table`` with them.

        The following rows are skipped with a warning: rows whose departure or arrival
        settlement is unknown to ``settlements``, and rows whose settlement pair is not
        in ``landing_routes``.

        Kept rows are updated IN PLACE with the following keys, then written to a temporary
        table that is moved over ``destination_table`` (``force=True``) inside one YT
        transaction:
        ``departure_settlement_geo_id``, ``arrival_settlement_geo_id``, ``search_url``,
        ``route_url`` and ``flying_time``.

        :param result: rows produced by ``_get_data``.
        :param destination_table: YT path the results are moved to.
        :param settlements: settlement lookup (by id, and slug by id).
        :param flying_times_cache: source of ``flying_time`` per (departure, arrival, date).
        :param landing_routes: allowed ``(departure_id, arrival_id)`` pairs.
        """
        self._logger.info('Write to YT')

        yt_results = []

        self._logger.info('Add data from references to results ...')
        for r in result:

            departure_settlement = settlements.get_settlement_by_id(r['departure_settlement_id'])
            if not departure_settlement:
                self._logger.warning('Unknown departure settlement: %d', r['departure_settlement_id'])
                continue
            arrival_settlement = settlements.get_settlement_by_id(r['arrival_settlement_id'])
            if not arrival_settlement:
                self._logger.warning('Unknown arrival settlement: %d', r['arrival_settlement_id'])
                continue

            if (departure_settlement.Id, arrival_settlement.Id) not in landing_routes:
                self._logger.warning(
                    'Unexpected landing route: %s -- %s',
                    departure_settlement.Id,
                    arrival_settlement.Id,
                )
                continue

            r.update(
                {
                    'departure_settlement_geo_id': departure_settlement.GeoId,
                    'arrival_settlement_geo_id': arrival_settlement.GeoId,
                    'search_url': self._search_url_template.format(
                        from_id=departure_settlement.Id,
                        to_id=arrival_settlement.Id,
                        when=r['date'],
                    ),
                    'route_url': self._route_url_template.format(
                        from_slug=settlements.get_slug_by_id(departure_settlement.Id),
                        to_slug=settlements.get_slug_by_id(arrival_settlement.Id),
                    ),
                    'flying_time': flying_times_cache.get_flying_time(
                        departure_settlement.Id,
                        arrival_settlement.Id,
                        r['date'],
                    ),
                }
            )

            yt_results.append(r)

        # Write into a temp table and move it over the destination within one transaction,
        # so the destination is replaced as a whole rather than written row by row.
        with self._yt_client.Transaction():
            self._logger.info('Create temp table')
            temp_table = self._yt_client.create_temp_table(
                attributes={
                    'schema': [
                        {'name': 'departure_settlement_id', 'type': 'int64'},
                        {'name': 'departure_settlement_geo_id', 'type': 'int64'},
                        {'name': 'arrival_settlement_id', 'type': 'int64'},
                        {'name': 'arrival_settlement_geo_id', 'type': 'int64'},
                        {'name': 'date', 'type': 'string'},
                        {'name': 'price', 'type': 'double'},
                        {'name': 'currency', 'type': 'string'},
                        {'name': 'transfers', 'type': 'uint8'},
                        {'name': 'search_url', 'type': 'string'},
                        {'name': 'route_url', 'type': 'string'},
                        {'name': 'updated_at', 'type': 'datetime'},
                        {'name': 'flying_time', 'type': 'int32'},
                    ]
                }
            )

            self._logger.info('Write results to %s', temp_table)
            self._yt_client.write_table(temp_table, yt_results)

            self._logger.info('Move "%s" to "%s"', temp_table, destination_table)
            self._yt_client.move(temp_table, destination_table, force=True)

        self._logger.info('Results: %s', destination_table)

    @staticmethod
    def _pick_cheapest(variants, max_transfers, currency_rates):
        # type: (list, int, Dict[int, float]) -> tuple
        """Return ``(price, transfers)`` of the cheapest variant with at most ``max_transfers`` transfers.

        Each variant's price is ``variant['value'] * currency_rates[variant['currency_id']]``.
        Ties on price go to the variant with fewer transfers. On a full tie the earliest
        variant wins. Returns ``(None, None)`` when no variant qualifies.

        :raises KeyError: if a qualifying variant's ``currency_id`` has no rate.
        """
        min_price = None
        transfers = None
        for variant in variants:
            count_transfer = variant['count_transfer']
            if count_transfer > max_transfers:
                continue

            price_value = variant['value'] * currency_rates[variant['currency_id']]

            if (
                min_price is None
                or price_value < min_price
                or (price_value == min_price and count_transfer < transfers)
            ):
                min_price = price_value
                transfers = count_transfer

        return min_price, transfers

    def _get_data(self, national_version_id, from_date, to_date, max_transfers, base_currency, currency_rates):
        # type: (int, datetime, datetime, int, CurrencyModel, Dict[int, float]) -> list
        """Fetch one minimum price per route from the slave database.

        For every ``(from_id, to_id)`` pair, the SQL query keeps the single ``result`` row with
        the lowest ``base_value``. Candidate rows have:

        - a forward date in ``[from_date, to_date]`` (both inclusive);
        - 1 adult and no children or infants;
        - ``backward_date = '1900-01-01'``.

        That row's ``gzip_data`` is zlib-decompressed and parsed as a JSON list of variants. The
        cheapest variant with at most ``max_transfers`` transfers is kept (see ``_pick_cheapest``).

        NOTE(review): a route whose lowest-``base_value`` date has no variant within
        ``max_transfers`` is dropped entirely, even if another date would qualify.

        :param national_version_id: primary key of the national version.
        :param from_date: first forward date; only its date part is used.
        :param to_date: last forward date; only its date part is used.
        :param max_transfers: maximum allowed number of transfers.
        :param base_currency: currency whose ``code`` is stored in every output row.
        :param currency_rates: multiplier per currency id applied to variant values
            (presumably converting into ``base_currency``; confirm against rates_provider).
        :returns: list of dicts with ``departure_settlement_id``, ``arrival_settlement_id``,
            ``date`` (``YYYY-MM-DD``), ``price``, ``currency``, ``transfers`` and
            ``updated_at`` (integer Unix timestamp).
        """
        query = text(
            """
            select
                *
            from
                (
                    select
                        from_id,
                        to_id,
                        forward_date,
                        gzip_data,
                        updated_at,
                        row_number() over (partition by from_id, to_id order by base_value) as rn
                    from
                        (
                            select
                                from_id,
                                to_id,
                                forward_date,
                                base_value,
                                gzip_data,
                                updated_at
                            from
                                result
                            where
                                national_version_id = :national_version_id
                                and adults_count = 1
                                and children_count = 0
                                and infants_count = 0
                                and forward_date >= :min_forward_date
                                and forward_date <= :max_forward_date
                                and backward_date = '1900-01-01'
                        ) as tmp
                ) as tmp
            where
                rn=1
        """
        )
        params = {
            'national_version_id': national_version_id,
            'min_forward_date': from_date.strftime('%Y-%m-%d'),
            'max_forward_date': to_date.strftime('%Y-%m-%d'),
        }

        session = slave_storage.get_session()
        try:
            rs = session.execute(query, params)

            result = []
            for row in rs:
                # Columns: from_id, to_id, forward_date, gzip_data, updated_at, rn.
                gzip_data = row[3]
                if not gzip_data:
                    continue

                min_price, transfers = self._pick_cheapest(
                    ujson.loads(zlib.decompress(gzip_data)),
                    max_transfers,
                    currency_rates,
                )
                if min_price is None:
                    continue

                result.append(
                    {
                        'departure_settlement_id': row[0],
                        'arrival_settlement_id': row[1],
                        'date': row[2].strftime('%Y-%m-%d'),
                        'price': min_price,
                        'currency': base_currency.code,
                        'transfers': transfers,
                        # NOTE(review): time.mktime interprets updated_at as local time;
                        # confirm the DB column is stored in local time rather than UTC.
                        'updated_at': int(time.mktime(row[4].timetuple())),
                    }
                )

                if len(result) % 10000 == 0:
                    self._logger.info('Fetched %d directions', len(result))
        finally:
            # Release the session's connection/transaction once the result set has been consumed.
            session.close()

        self._logger.info('Total directions: %d', len(result))

        return result

    def dump(self, days, max_transfers, yt_table):
        """Dump minimum prices for forward dates from tomorrow to ``tomorrow + days`` into ``yt_table``.

        Both range ends are inclusive, so ``days + 1`` dates are covered. "Tomorrow" is computed
        from local ``datetime.now()``.

        :param days: number of days after tomorrow to include.
        :param max_transfers: maximum allowed number of transfers per variant.
        :param yt_table: destination YT table path.
        """
        tomorrow = datetime.now() + timedelta(days=1)

        landing_routes = self._get_landing_routes()

        self._logger.info('Fetching settlements')
        settlement_cache = SettlementCache(self._logger)
        settlement_cache.populate()

        self._logger.info('Fetching flying times')
        flying_times_cache = FlyingTimesCache(self._logger)
        flying_times_cache.populate()

        national_version_provider.fetch()
        national_version = national_version_provider.get_by_code(self._national_version)

        currency_provider.fetch()
        rates_provider.fetch()
        base_currency = currency_provider.get_by_id(
            rates_provider.get_base_currency_id(national_version.pk),
            national_version.pk,
        )
        currency_rates = rates_provider.get_rates_for(national_version.pk)

        self._write_to_yt(
            self._get_data(
                national_version.pk,
                tomorrow,
                tomorrow + timedelta(days=days),
                max_transfers,
                base_currency,
                currency_rates,
            ),
            yt_table,
            settlement_cache,
            flying_times_cache,
            landing_routes,
        )
