import csv
import logging
import sys
from argparse import ArgumentParser, Namespace
from io import StringIO
from typing import Any, Optional, NamedTuple

from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient

# noinspection PyUnresolvedReferences
import travel.proto.commons_pb2 as commons_pb2
from travel.library.python.tools import replace_args_from_env
from travel.hotels.lib.python3.yql import yqllib
from travel.hotels.lib.python3.yt.ytlib import create_table, schema_from_dict, ypath_join
from travel.library.python.s3_client import S3Client
from travel.library.python.train_tariffs_snapshot_reader import SnapshotReader


# touch here: 2


class DirectionKey(NamedTuple):
    """Hashable (origin station, destination station) pair used as a key when aggregating prices per direction."""

    # Station the direction starts from.
    station_id_from: int
    # Station the direction ends at.
    station_id_to: int


class MarketingFeedTrainsBuilder:
    """Build the trains marketing feed and upload it to S3 as CSV.

    Pipeline (see ``run``):
      1. Aggregate the minimum place price per (departure, arrival) station pair
         from the tariffs snapshot in S3 and write it to the ``min_prices`` YT table.
      2. Run ``prepare_marketing_feed_trains.yql``, which writes the ``directions``
         and ``directions_dropped`` tables.
      3. Export ``directions`` as CSV and upload it to ``dst_s3_bucket`` / ``dst_s3_key``.
    """

    # CSV export column renames: YT column name -> CSV header name.
    __field_map__ = {
        'settlement_id_from': 'settlement_id_from*',
        'settlement_id_to': 'settlement_id_to*',
    }

    # Schema of the intermediate per-direction min prices table.
    __min_prices_table_schema__ = schema_from_dict({
        'station_id_from': 'int32',
        'station_id_to': 'int32',
        'min_price': 'double',
    })

    # Sentinel meaning "no price seen yet"; any real price compares lower.
    __max_value__ = float('+inf')

    def __init__(self, args: Namespace):
        """Create the YT, YQL and S3 clients from parsed command-line arguments.

        :param args: parsed CLI namespace (options are declared in ``main``).
        """
        self.args = args
        self.yt_client = YtClient(args.yt_proxy, args.yt_token)
        self.yql_client = YqlClient(db=args.yt_proxy, token=args.yql_token)
        # Source of tariff snapshots.
        snapshots_s3_client = S3Client(
            endpoint=self.args.s3_endpoint,
            bucket=self.args.snapshots_s3_bucket,
            access_key=self.args.snapshots_s3_access_key,
            access_secret_key=self.args.snapshots_s3_access_secret_key,
        )
        self.snapshot_reader = SnapshotReader(snapshots_s3_client)
        # Destination of the resulting CSV feed.
        self.result_s3_client = S3Client(
            endpoint=self.args.s3_endpoint,
            bucket=self.args.dst_s3_bucket,
            access_key=self.args.dst_s3_access_key,
            access_secret_key=self.args.dst_s3_access_secret_key,
        )

    def run(self):
        """Run the full pipeline: min prices -> YQL feed preparation -> CSV upload to S3."""
        directions_table = ypath_join(self.args.results_path, 'directions')
        dropped_directions_table = ypath_join(self.args.results_path, 'directions_dropped')
        min_prices_table = ypath_join(self.args.results_path, 'min_prices')

        # Debug shortcut: reuse a previously computed table. If the table is missing,
        # it is computed regardless of the flag.
        if self.args.skip_min_prices_calc and self.yt_client.exists(min_prices_table):
            logging.info('Skipping min prices calculation')
        else:
            self._write_min_prices(min_prices_table)

        query_args = {
            '$input_popularity_interval_days': self.args.popularity_interval_days,
            '$input_actual_directions_table': self.args.actual_directions_table,
            '$input_min_prices_table': min_prices_table,
            '$output_directions_table': directions_table,
            '$output_dropped_directions_table': dropped_directions_table,
        }
        self._run_query('prepare_marketing_feed_trains.yql', query_args)

        data = self._get_csv_data(directions_table)
        self.result_s3_client.write(self.args.dst_s3_key, data)

    def _write_min_prices(self, min_prices_table):
        """Compute the minimum place price per (departure, arrival) station pair and write it to YT.

        Trains without any places contribute nothing. The target table is removed
        (if present), recreated and filled inside a single YT transaction.

        :param min_prices_table: YT path of the table to (re)create.
        :raises ValueError: if some place is priced in a currency other than RUB.
        """
        tariffs = self.snapshot_reader.get_tariffs(self.args.snapshots_s3_prefix)

        min_prices = {}
        for index, tariff in enumerate(tariffs):
            if index % 100_000 == 0:
                logging.info(f'tariffs processed: {index}')

            for train in tariff.data:
                train_min_price = self._get_train_min_price(train)
                if train_min_price == self.__max_value__:
                    # No places -> nothing to aggregate for this train.
                    continue
                # A direction goes FROM the departure station TO the arrival station.
                # Keyword arguments keep the mapping explicit and order-proof.
                direction_key = DirectionKey(
                    station_id_from=train.departure_station_id,
                    station_id_to=train.arrival_station_id,
                )
                if train_min_price < min_prices.get(direction_key, self.__max_value__):
                    min_prices[direction_key] = train_min_price

        rows = (
            {'station_id_from': key.station_id_from, 'station_id_to': key.station_id_to, 'min_price': price}
            for key, price in min_prices.items()
        )
        with self.yt_client.Transaction():
            if self.yt_client.exists(min_prices_table):
                self.yt_client.remove(min_prices_table)
            create_table(min_prices_table, self.yt_client, self.__min_prices_table_schema__)
            self.yt_client.write_table(min_prices_table, rows)

    def _get_train_min_price(self, train):
        """Return the cheapest place price of ``train``, or ``__max_value__`` if it has no places.

        :raises ValueError: if any place's price currency is not ``C_RUB``.
        """
        train_min_price = self.__max_value__
        for place in train.places:
            if place.price.Currency != commons_pb2.ECurrency.C_RUB:
                # ValueError subclasses Exception, so existing broad handlers still catch it.
                raise ValueError(f'Wrong currency for {place}')
            train_min_price = min(train_min_price, self._get_price_value(place.price))
        return train_min_price

    @staticmethod
    def _get_price_value(price) -> float:
        """Convert a fixed-point price (``Amount`` scaled by ``10 ** Precision``) to a float."""
        return price.Amount / 10 ** price.Precision

    def _get_csv_data(self, directions_table) -> str:
        """Dump ``directions_table`` to CSV text.

        Column order follows the table schema. Columns listed in ``__field_map__`` are
        renamed in both the header and the rows. Every field is quoted.

        :param directions_table: YT path of the table to export.
        :return: the whole CSV document as a string, header row included.
        """
        logging.info(f'Reading data from {directions_table}')
        data = StringIO()

        schema = self.yt_client.get_attribute(directions_table, 'schema')
        field_names = [self._replace_name(item['name']) for item in schema]

        writer = csv.DictWriter(
            data,
            fieldnames=field_names,
            restval='',
            dialect='excel',
            delimiter=',',
            quoting=csv.QUOTE_ALL,
            quotechar='"',
        )
        writer.writeheader()
        for row in self.yt_client.read_table(directions_table):
            writer.writerow(self._get_patched_row(row))
        logging.info('Data preparing finished')
        return data.getvalue()

    def _run_query(self, query_name: str, query_args: dict[str, Any], transaction_id: Optional[str] = None) -> None:
        """Run the YQL file ``query_name`` through ``yqllib.run_yql_file`` with the given parameters.

        :param query_name: passed to ``run_yql_file`` as ``resource_name``.
        :param query_args: YQL parameters (``$name`` -> value).
        :param transaction_id: optional transaction to run the query in.
        """
        logging.info(f'Running query {query_name}, Args: {query_args}')
        yqllib.run_yql_file(
            client=self.yql_client,
            resource_name=query_name,
            project_name=self.__class__.__name__,
            parameters=query_args,
            transaction_id=transaction_id,
        )
        logging.info('Query finished')

    def _replace_name(self, name: str) -> str:
        """Return the CSV column name for YT column ``name`` (unchanged unless in ``__field_map__``)."""
        return self.__field_map__.get(name, name)

    def _get_patched_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``row`` with keys renamed via ``_replace_name``."""
        return {self._replace_name(key): value for key, value in row.items()}


# Command-line options as (flag, add_argument keyword arguments), in declaration order.
_CLI_OPTIONS = (
    ('--yt-proxy', {'default': 'hahn'}),
    ('--yt-token', {'required': True}),
    ('--actual-directions-table', {'required': True}),
    ('--yql-token', {'required': True}),
    ('--s3-endpoint', {'default': 'https://s3.mds.yandex.net'}),
    ('--snapshots-s3-bucket', {'required': True}),
    ('--snapshots-s3-prefix', {'required': True}),
    ('--snapshots-s3-access-key', {'required': True}),
    ('--snapshots-s3-access-secret-key', {'required': True}),
    ('--dst-s3-bucket', {'required': True}),
    ('--dst-s3-key', {'required': True}),
    ('--dst-s3-access-key', {'required': True}),
    ('--dst-s3-access-secret-key', {'required': True}),
    ('--popularity-interval-days', {'type': int, 'default': 60}),
    ('--results-path', {'required': True}),
    ('--skip-min-prices-calc', {'action': 'store_true', 'help': '!!! Debug only feature'}),
)


def main():
    """CLI entry point.

    Sends INFO-level logs to stdout, parses the argument list returned by
    ``replace_args_from_env()`` and runs ``MarketingFeedTrainsBuilder``.
    """
    log_format = "%(asctime)-15s | %(module)s | %(levelname)s | %(message)s"
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format)

    parser = ArgumentParser()
    for flag, options in _CLI_OPTIONS:
        parser.add_argument(flag, **options)

    builder = MarketingFeedTrainsBuilder(parser.parse_args(args=replace_args_from_env()))
    builder.run()


if __name__ == "__main__":
    main()
