import logging
import sys
from argparse import ArgumentParser, Namespace
from datetime import date, timedelta
from typing import Any, Optional, NamedTuple

# noinspection PyUnresolvedReferences
import travel.proto.commons_pb2 as commons_pb2
from travel.hotels.lib.python3.yql import yqllib
from travel.hotels.lib.python3.yt.ytlib import create_table, schema_from_dict, ypath_join
from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient

from travel.library.python.s3_client import S3Client
from travel.library.python.tools import replace_args_from_env
from travel.library.python.train_tariffs_snapshot_reader import SnapshotReader
from travel.hotels.tools.smart_banner_feed_trains_builder.xml_converter import TrainSmartBannersConverter


# touch here: 1


# Identity of one train on a route: the pair of stations it runs between
# plus its train number. Used as the de-duplication key for min prices.
TrainKey = NamedTuple('TrainKey', [
    ('station_id_from', int),
    ('station_id_to', int),
    ('number', str),
])


class TrainValue(NamedTuple):
    """Cheapest offer found for a train; becomes one row of the min-prices table.

    Field types follow the ``min_prices`` YT table schema the values are
    written into (``display_number`` is a string column, ``departure`` and
    ``arrival`` are int32 columns filled from ``.seconds``).
    """
    provider: str
    display_number: str
    # Departure / arrival taken from ``train.departure.seconds`` and
    # ``train.arrival.seconds`` (presumably epoch seconds — TODO confirm).
    departure: int
    arrival: int
    # Minimal place price in RUB (already scaled by the price precision).
    min_price: float


class SmartBannerFeedTrainsBuilder:
    """Build the trains smart-banner XML feed and upload it to S3.

    Pipeline (see ``run``):
      1. Read train tariffs from the S3 snapshot, keep those departing on
         ``target_date`` and write the cheapest price per train/route into the
         YT ``min_prices`` table.
      2. Run the ``smart_banner_feed_trains.yql`` query with ``min_prices`` as
         input and the ``trains`` / ``categories`` tables as outputs.
      3. Convert the ``categories`` and ``trains`` tables to XML and write the
         result to the destination S3 key.
    """

    # YT schema of the intermediate min-prices table; its columns match the
    # rows produced by ``_write_min_prices``.
    __min_prices_table_schema__ = schema_from_dict({
        'number': 'string',
        'display_number': 'string',
        'provider': 'string',
        'station_id_from': 'int32',
        'station_id_to': 'int32',
        'departure': 'int32',
        'arrival': 'int32',
        'min_price': 'double',
    })

    def __init__(self, args: Namespace):
        """Create the YT, YQL and S3 clients from parsed CLI arguments.

        :param args: parsed command-line arguments (see ``main``).
        """
        self.args = args
        self.yt_client = YtClient(args.yt_proxy, args.yt_token)
        self.yql_client = YqlClient(db=args.yt_proxy, token=args.yql_token)
        snapshots_s3_client = S3Client(
            endpoint=self.args.s3_endpoint,
            bucket=self.args.snapshots_s3_bucket,
            access_key=self.args.snapshots_s3_access_key,
            access_secret_key=self.args.snapshots_s3_access_secret_key,
        )
        self.snapshot_reader = SnapshotReader(snapshots_s3_client)
        self.result_s3_client = S3Client(
            endpoint=self.args.s3_endpoint,
            bucket=self.args.dst_s3_bucket,
            access_key=self.args.dst_s3_access_key,
            access_secret_key=self.args.dst_s3_access_secret_key,
        )
        self.xml = TrainSmartBannersConverter()
        # Default to tomorrow when no explicit date was requested.
        self.target_date = args.target_date or (date.today() + timedelta(days=1))

    def run(self) -> None:
        """Execute the whole pipeline; the two debug flags skip stages 1 and 2."""
        trains_table = ypath_join(self.args.results_path, 'trains')
        categories_table = ypath_join(self.args.results_path, 'categories')
        min_prices_table = ypath_join(self.args.results_path, 'min_prices')

        if self.args.skip_min_prices_calc:
            logging.info('Skipping min prices calculation')
        else:
            self._write_min_prices(min_prices_table)

        if self.args.skip_run_query:
            logging.info('Skipping run yt query')
        else:
            query_args = {
                '$input_min_prices_table': min_prices_table,
                '$output_trains_table': trains_table,
                '$output_categories_table': categories_table,
            }
            self._run_query('smart_banner_feed_trains.yql', query_args)

        data = self.xml.convert(self.yt_client.read_table(categories_table),
                                self.yt_client.read_table(trains_table))
        self.result_s3_client.write(self.args.dst_s3_key, data)

    def _write_min_prices(self, min_prices_table: str) -> None:
        """Collect the cheapest offer per train for ``target_date`` and store it.

        The table at ``min_prices_table`` is dropped and recreated inside a
        single YT transaction.
        """
        tariffs = self.snapshot_reader.get_tariffs(self.args.snapshots_s3_prefix)
        target = (self.target_date.year, self.target_date.month, self.target_date.day)

        trains_dict: dict[TrainKey, TrainValue] = {}
        for index, tariff in enumerate(tariffs):
            if index % 100_000 == 0:
                logging.info(f'tariffs processed: {index}')
            departure_date = tariff.departure_date
            if (departure_date.Year, departure_date.Month, departure_date.Day) != target:
                continue
            for train in tariff.data:
                key = TrainKey(train.departure_station_id, train.arrival_station_id, train.number)
                self._add_to_trains_dict(trains_dict, key, self._get_train_value(train))

        data = (
            {
                'number': k.number,
                'display_number': v.display_number,
                'provider': v.provider,
                'station_id_from': k.station_id_from,
                'station_id_to': k.station_id_to,
                'departure': v.departure,
                'arrival': v.arrival,
                'min_price': v.min_price,
            } for k, v in trains_dict.items()
        )
        with self.yt_client.Transaction():
            if self.yt_client.exists(min_prices_table):
                self.yt_client.remove(min_prices_table)
            create_table(min_prices_table, self.yt_client, self.__min_prices_table_schema__)
            self.yt_client.write_table(min_prices_table, data)

    @staticmethod
    def _add_to_trains_dict(trains_dict, train_key, train_value) -> None:
        """Store ``train_value`` under ``train_key`` unless a cheaper one is already there.

        A ``None`` value (train without a usable price) is ignored.
        """
        if train_value is None:
            return
        cached_value = trains_dict.get(train_key)
        if cached_value is None or cached_value.min_price > train_value.min_price:
            trains_dict[train_key] = train_value

    def _get_train_value(self, train) -> Optional[TrainValue]:
        """Return the cheapest RUB offer of ``train`` as a ``TrainValue``.

        Places with a zero amount are ignored (presumably "price unknown" —
        TODO confirm with the snapshot producer). Returns ``None`` when the
        train has no places or only zero-priced ones.

        :raises ValueError: if any place is priced in a currency other than RUB.
        """
        min_price: Optional[float] = None
        for place in train.places:
            if place.price.Currency != commons_pb2.ECurrency.C_RUB:
                raise ValueError(f'Wrong currency for {place}')
            price = self._get_price_value(place.price)
            # The former truthiness-based running minimum made the result
            # depend on place order when a zero price was present
            # ([0, 100] -> 100 but [100, 0] -> no value); skip zeros explicitly.
            if price == 0:
                continue
            if min_price is None or price < min_price:
                min_price = price
        if min_price is None:
            return None
        return TrainValue(train.provider, train.display_number, train.departure.seconds, train.arrival.seconds,
                          min_price)

    @staticmethod
    def _get_price_value(price) -> float:
        """Convert a fixed-point price message (``Amount`` / 10**``Precision``) to a float."""
        return price.Amount / 10 ** price.Precision

    def _run_query(self, query_name: str, query_args: dict[str, Any], transaction_id: Optional[str] = None) -> None:
        """Run the YQL file resource ``query_name`` with ``query_args`` as parameters."""
        logging.info(f'Running query {query_name}, Args: {query_args}')
        yqllib.run_yql_file(
            client=self.yql_client,
            resource_name=query_name,
            project_name=self.__class__.__name__,
            parameters=query_args,
            transaction_id=transaction_id,
        )
        logging.info('Query finished')
        logging.info('Query finished')


def main():
    """CLI entry point: configure logging, parse arguments and build the feed."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)-15s | %(module)s | %(levelname)s | %(message)s",
        stream=sys.stdout,
    )

    parser = ArgumentParser()
    parser.add_argument('--yt-proxy', default='hahn')
    parser.add_argument('--yt-token', required=True)
    parser.add_argument('--yql-token', required=True)
    parser.add_argument('--s3-endpoint', default='https://s3.mds.yandex.net')
    parser.add_argument('--snapshots-s3-bucket', required=True)
    parser.add_argument('--snapshots-s3-prefix', required=True)
    parser.add_argument('--snapshots-s3-access-key', required=True)
    parser.add_argument('--snapshots-s3-access-secret-key', required=True)
    parser.add_argument('--dst-s3-bucket', required=True)
    parser.add_argument('--dst-s3-key', required=True)
    parser.add_argument('--dst-s3-access-key', required=True)
    parser.add_argument('--dst-s3-access-secret-key', required=True)
    # `type=date` would call date('YYYY-MM-DD') and fail with TypeError;
    # parse the ISO string instead. When omitted, the builder uses tomorrow.
    parser.add_argument('--target-date', type=date.fromisoformat, default=None,
                        help='Departure date in YYYY-MM-DD format (default: tomorrow)')
    parser.add_argument('--results-path', required=True)
    parser.add_argument('--skip-min-prices-calc', action='store_true', help='!!! Debug only feature')
    parser.add_argument('--skip-run-query', action='store_true', help='!!! Debug only feature')

    args = parser.parse_args(args=replace_args_from_env())
    SmartBannerFeedTrainsBuilder(args).run()


# Run the CLI only when executed as a script, never on import.
if __name__ == "__main__":
    main()
