import logging
from argparse import Namespace
from copy import copy
from dataclasses import asdict, dataclass
from datetime import timezone, datetime
from decimal import Decimal, ROUND_UP
from time import sleep
from uuid import uuid4

# noinspection PyUnresolvedReferences
from yt.wrapper import YtClient

from travel.hotels.lib.python3.yql.yqllib import create_client, run_query, wait_results
from travel.hotels.lib.python3.yt.ytlib import ensure_table_exists, ypath_join, schema_from_dict
from travel.marketing.tools.lototron.event import Event
from travel.marketing.tools.lototron.media_billing_client import MediaBillingClient, TopupRequest, TopupStatus
from travel.marketing.tools.lototron.order import CpaOrder, Order


@dataclass
class OrderTopupStatus:
    """Topup bookkeeping record for a single order.

    Instances are built from rows of the topup status table and written back
    to it as plain dicts, so the field names match the table columns.
    """

    partner_order_id: str  # key under which the status is tracked
    transaction_id: str  # id passed to the billing client for this topup
    processed: bool  # set to True once billing reports the topup as completed
    processed_at: str  # string timestamp of completion; '' until processed
    topup_points: int  # points recorded for the order when the status was created


class PrizeBox:
    """Requests point topups for won orders and tracks their status in YT.

    Winners are read from the results path, their current state is looked up
    in the CPA order tables, and topups are requested through the media
    billing client. Per-order progress is kept in a topup status table.
    """

    # Upper bound on the total number of points for one order.
    __points_max__ = 50_000

    # Pause (seconds, passed to sleep()) between requesting topups and
    # re-checking their billing status.
    __time_to_topup__ = 60

    # Schema of the topup status table; the columns mirror OrderTopupStatus.
    __topup_status_schema__ = schema_from_dict({
        'partner_order_id': 'string',
        'transaction_id': 'string',
        'processed': 'boolean',
        'processed_at': 'string',
        'topup_points': 'int64',
    })

    # Columns selected from every CPA orders table.
    __common_fields__ = [
        'category',
        'partner_order_id',
        'status',
        'order_amount',
        'profit_amount',
    ]

    # CPA orders table path -> extra columns selected from that table in
    # addition to the common ones.
    __cpa_order_tables__ = {
        '//home/travel/prod/cpa/avia/orders': [
            'date_forward',
            'date_backward',
        ],
        '//home/travel/prod/cpa/hotels/orders': [
            'check_out',
            'yandex_plus_topup_points',
        ],
        '//home/travel/prod/cpa/train/orders': [
            'arrival',
        ],
    }

    def __init__(self, args: Namespace, tz: timezone):
        """Store run settings and create the YQL, YT and media billing clients.

        :param args: parsed arguments; supplies the campaign id, results path,
            YQL/YT credentials and proxy, and the media billing URL and TVM
            settings. It is also handed as a whole to ``Event``.
        :param tz: time zone used for the run timestamp and passed to ``Event``.
        """
        self.billing_campaign_id = args.billing_campaign_id
        self.results_path = args.results_path
        self.tz = tz

        # A single timestamp is taken once and reused for the whole run.
        current_time = datetime.now(tz)
        self.now = current_time
        self.event = Event(args, current_time, tz)

        self.yql_client = create_client(token=args.yql_token, db=args.yt_proxy)
        self.yt_client = YtClient(args.yt_proxy, args.yt_token)

        billing_client_settings = dict(
            url=args.media_billing_url,
            tvm_service_id=args.billing_tvm_service_id,
            tvm_client_id=args.billing_tvm_client_id,
            tvm_secret=args.billing_tvm_secret,
        )
        self.media_billing_client = MediaBillingClient(**billing_client_settings)

    def run(self):
        """Run one full pass: collect winners, pick orders to top up, request topups, save statuses."""
        topup_status_table = ypath_join(self.results_path, 'data', 'topup_status')

        won_orders = self._get_won_orders()
        logging.info(f'{len(won_orders)} won orders')

        # Current state of the won orders as stored in the CPA order tables.
        actual_orders = self._get_actual_orders(won_orders)
        logging.info(f'{len(actual_orders)} actual orders')

        # Logged for information only; the value is not used further in this method.
        logging.info(f'topup_finish_time {self._get_topup_finish_time(actual_orders.values())}')

        orders_to_topup = self._get_orders_to_topup(actual_orders, won_orders)
        logging.info(f'{len(orders_to_topup)} orders to topup')

        # First transaction: load saved statuses, add a fresh status (with a new
        # transaction id) for every order that has none yet, and save them.
        # NOTE(review): presumably this is committed separately so that the generated
        # transaction ids survive a failure in the billing step below — confirm.
        with self.yt_client.Transaction():
            topup_status = self._get_topup_status(topup_status_table)
            logging.info(f'{len(topup_status)} loaded topup statuses')

            self._update_topup_status(orders_to_topup.values(), topup_status)
            logging.info(f'{len(topup_status)} total topup statuses')

            self._save_topup_status(topup_status.values(), topup_status_table)
            logging.info('Topup status saved')

        # Second transaction: request topups, refresh the statuses from billing
        # and save the updated statuses.
        with self.yt_client.Transaction():
            self._process_orders(orders_to_topup, topup_status.values())
            logging.info('Order processing finished')

            self._save_topup_status(topup_status.values(), topup_status_table)
            logging.info('Topup status saved')

        logging.info('All done')

    def _get_actual_orders(self, won_orders: [Order]) -> dict[str, CpaOrder]:
        """Look the won orders up in every configured CPA orders table.

        :param won_orders: orders whose current state is requested.
        :return: found orders keyed by ``partner_order_id``; if the same id is
            found in several tables, the table listed later wins.
        """
        actual_orders = {}
        for table_path, extra_fields in self.__cpa_order_tables__.items():
            selected_fields = self.__common_fields__ + extra_fields
            found = self._get_actual_orders_at_path(won_orders, table_path, selected_fields)
            actual_orders.update(found)
        return actual_orders

    def _get_won_orders(self) -> list[Order]:
        """Collect winners from the ``results`` and ``manual_results`` directories.

        :return: one order per ``partner_order_id``. When an id occurs more than
            once the last occurrence is kept, so ``manual_results`` entries
            replace ``results`` entries with the same id.
        """
        winners = self._get_won_orders_at_path(ypath_join(self.results_path, 'results'))
        winners += self._get_won_orders_at_path(ypath_join(self.results_path, 'manual_results'))
        unique_winners = {winner.partner_order_id: winner for winner in winners}
        return list(unique_winners.values())

    def _get_won_orders_at_path(self, path: str):
        """Read the ``winners`` table of every node under *path*.

        :param path: YT directory whose child nodes each contain a ``winners`` table.
        :return: list of ``Order`` objects, each built with ``order_status`` set
            to ``'UNKNOWN'``.
        """
        winners = []
        for node in self.yt_client.list(path):
            winner_rows = self.yt_client.read_table(ypath_join(path, node, 'winners'))
            winners.extend(Order(**{**row, 'order_status': 'UNKNOWN'}) for row in winner_rows)
        return winners

    def _get_orders_to_topup(self, actual_orders: dict[str, CpaOrder], won_orders: [Order]) -> dict[str, CpaOrder]:
        """Select the won orders that are ready for a topup right now.

        Each won order is merged with its actual CPA order. It is kept only if it
        has a passport id, its status is ``'confirmed'`` and its topup time is
        not later than the run timestamp. Every rejected order is logged with
        the reason.

        :param actual_orders: current CPA orders keyed by ``partner_order_id``.
        :param won_orders: orders that won a prize.
        :return: merged orders eligible for topup, keyed by ``partner_order_id``.
        :raises KeyError: if a won order is missing from *actual_orders*.
        """
        orders = {}
        for won_order in won_orders:
            merged_order = self._get_merged_order(actual_orders[won_order.partner_order_id], won_order)
            if not merged_order.passport_id:
                logging.info(f'Has no passport_id: {merged_order}')
            elif merged_order.order_status != 'confirmed':
                logging.info(f'Not confirmed: {merged_order}')
            elif merged_order.topup_time > self.now:
                logging.info(f'Waiting for finish: {merged_order}')
            else:
                orders[merged_order.partner_order_id] = merged_order
        return orders

    def _get_actual_orders_at_path(
        self,
        orders_to_topup: list[Order],
        path: str,
        fields: list[str],
    ) -> dict[str, CpaOrder]:
        """Fetch the current state of the given orders from one CPA orders table.

        A YQL query copies the requested columns of the matching rows into a
        temporary YT table, which is then read back and converted to ``CpaOrder``.

        :param orders_to_topup: orders whose ``partner_order_id`` values are looked up.
        :param path: YT path of the CPA orders table to query.
        :param fields: column names to select.
        :return: found orders keyed by ``partner_order_id``; ids that are not in
            the table are simply absent from the result.
        """
        logging.info(f'Getting orders from {path}')
        order_ids = ', '.join(self._get_quoted_field(o.partner_order_id) for o in orders_to_topup)
        if not order_ids:
            # With no ids the filter would render as `in ()`, which is not a valid
            # query; nothing could match anyway, so skip the temp table and the query.
            return dict()

        query_fields = ', '.join(fields)
        orders = dict()
        temp = self.yt_client.create_temp_table()
        query = f'''
            INSERT INTO `{temp}` WITH TRUNCATE
            SELECT
            {query_fields}
            FROM `{path}`
            WHERE partner_order_id in ({order_ids})
        '''
        request = run_query(self.yql_client, query)
        wait_results(request)
        for row in self.yt_client.read_table(temp):
            order = CpaOrder.from_raw_order(row, self.tz)
            orders[order.partner_order_id] = order
        return orders

    @staticmethod
    def _get_quoted_field(f: str):
        """Return *f* as a double-quoted string literal for use in a YQL query.

        Backslashes and double quotes inside the value are backslash-escaped so
        that the value cannot end the literal early and change the query.
        Values without those characters are rendered exactly as before.

        :param f: raw string value.
        :return: the value wrapped in double quotes.
        """
        # Escape backslashes first so the backslashes added for quotes are not doubled.
        escaped = f.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'

    def _get_merged_order(self, actual_order: CpaOrder, won_order: Order) -> CpaOrder:
        """Combine the actual CPA order with the data of the won order.

        The result is a shallow copy of *actual_order* with the winner's passport
        id, the smaller of the two order amounts, and ``points_to_topup``: the
        capped amount minus the points already recorded on the order, rounded
        up to a whole number.

        :param actual_order: current order state from the CPA table.
        :param won_order: the winning order record.
        :return: the merged copy; *actual_order* itself is not modified.
        """
        merged = copy(actual_order)
        merged.passport_id = won_order.passport_id
        merged.order_amount = min(actual_order.order_amount, won_order.order_amount)

        capped_total = min(won_order.order_amount, actual_order.order_amount, self.__points_max__)
        remaining = capped_total - merged.topup_points
        # NOTE(review): nothing here prevents a negative value when topup_points
        # exceeds the capped total — confirm that cannot happen upstream.
        has_fraction = remaining > int(remaining)
        merged.points_to_topup = int(remaining + 1) if has_fraction else int(remaining)
        return merged

    def _get_topup_status(self, topup_status_table: str) -> dict[str, OrderTopupStatus]:
        """Load saved topup statuses from YT.

        :param topup_status_table: YT path of the topup status table.
        :return: statuses keyed by ``partner_order_id``; an empty dict when the
            table does not exist.
        """
        if not self.yt_client.exists(topup_status_table):
            return dict()
        saved_statuses = (OrderTopupStatus(**row) for row in self.yt_client.read_table(topup_status_table))
        return {saved.partner_order_id: saved for saved in saved_statuses}

    @staticmethod
    def _update_topup_status(
        orders_to_topup: [Order],
        topup_status: dict[str, OrderTopupStatus],
    ) -> None:
        """Add a fresh, unprocessed status for every order that has none yet.

        Existing statuses are left untouched. New ones get a newly generated
        UUID as the transaction id and the order's ``points_to_topup``.

        :param orders_to_topup: orders selected for topup.
        :param topup_status: statuses keyed by ``partner_order_id``; updated in place.
        """
        for order in orders_to_topup:
            order_id = order.partner_order_id
            if topup_status.get(order_id) is not None:
                continue
            topup_status[order_id] = OrderTopupStatus(
                partner_order_id=order_id,
                transaction_id=str(uuid4()),
                processed=False,
                processed_at='',
                topup_points=order.points_to_topup,
            )

    def _process_orders(
        self,
        orders_to_topup: dict[str, CpaOrder],
        topup_status: [OrderTopupStatus],
    ) -> None:
        """Request topups for statuses that need one, then refresh all statuses.

        First pass: each status is checked against billing, and a topup is
        requested only for those billing has not started yet. After a pause of
        ``__time_to_topup__`` seconds every status is checked once more so that
        completed topups are marked as processed.

        :param orders_to_topup: orders eligible for topup in this run, keyed by
            ``partner_order_id``.
        :param topup_status: all known statuses, including ones saved by earlier
            runs; updated in place. Iterated twice, so it must be re-iterable.
        :raises RuntimeError: propagated from the status check when billing
            reports a failed or unknown status.
        """
        for order_topup_status in topup_status:
            if not self._update_order_topup_status(order_topup_status):
                continue
            order = orders_to_topup.get(order_topup_status.partner_order_id)
            if order is None:
                # A status saved by an earlier run may belong to an order that is no
                # longer eligible; skip it rather than abort the whole run on a KeyError.
                logging.warning(f'No order to topup for {order_topup_status}, skipping topup request')
                continue
            topup_request = self._get_topup_request(order, order_topup_status.transaction_id)
            logging.info(f'Requesting topup {topup_request}')
            self.media_billing_client.request_topup(topup_request)
            logging.info('Topup request finished successfully')

        # Wait before re-checking the billing statuses.
        sleep(self.__time_to_topup__)
        for order_topup_status in topup_status:
            self._update_order_topup_status(order_topup_status)

    def _update_order_topup_status(self, order_topup_status: OrderTopupStatus) -> bool:
        """Sync one status with billing and tell whether a topup must be requested.

        A status already marked as processed is not queried again. When billing
        reports completion, the status is marked processed and stamped with the
        run timestamp.

        :param order_topup_status: status to check; may be updated in place.
        :return: ``True`` only when billing reports the topup as not started.
        :raises RuntimeError: when billing reports a failed or unrecognised status.
        """
        if order_topup_status.processed:
            logging.info(f'Already processed {order_topup_status}')
            return False

        billing_topup_status = self.media_billing_client.get_topup_status(order_topup_status.transaction_id)

        if billing_topup_status == TopupStatus.COMPLETED:
            order_topup_status.processed = True
            order_topup_status.processed_at = str(self.now)
            logging.info(f'Processed now {order_topup_status}')
            return False

        if billing_topup_status == TopupStatus.RUNNING:
            logging.info(f'Still processing {order_topup_status}')
            return False

        if billing_topup_status == TopupStatus.FAILED:
            raise RuntimeError(f'Got FAILED billing topup status for {order_topup_status}')

        if billing_topup_status == TopupStatus.NOT_STARTED:
            logging.info(f'Topup not requested yet {order_topup_status}')
            return True

        raise RuntimeError(f'Unknown billing topup status for {order_topup_status}')

    def _save_topup_status(self, topup_status: [OrderTopupStatus], topup_status_table: str) -> None:
        """Write the given statuses to the topup status table.

        The table is ensured to exist with the topup status schema first, then
        each status is written as a plain dict.

        :param topup_status: statuses to store.
        :param topup_status_table: YT path of the destination table.
        """
        ensure_table_exists(topup_status_table, self.yt_client, self.__topup_status_schema__)
        self.yt_client.write_table(
            topup_status_table,
            (asdict(status) for status in topup_status),
        )

    def _get_topup_request(self, order: CpaOrder, transaction_id: str) -> TopupRequest:
        """Build the billing topup request for one order.

        :param order: merged order; supplies the passport id, points and payload data.
        :param transaction_id: transaction id stored in the order's topup status.
        :return: request for the configured billing campaign.
        """
        request_fields = dict(
            transaction_id=transaction_id,
            passport_id=order.passport_id,
            campaign_id=self.billing_campaign_id,
            amount=order.points_to_topup,
            payload=self._get_payload(order),
        )
        return TopupRequest(**request_fields)

    @staticmethod
    def _get_payload(order: CpaOrder) -> dict[str, str]:
        """Build the string payload attached to a topup request.

        :param order: order supplying ``order_amount`` and ``profit_amount``.
        :return: dict with ``base_amount``, ``commission_amount`` and
            ``vat_commission_amount`` (the commission divided by 1.2, rounded up
            to two decimal places), all as strings.
        """
        commission_amount = order.profit_amount
        # Floats go through str() so Decimal gets the printed value (e.g. 12.3)
        # rather than the binary approximation, which ROUND_UP would turn into
        # an extra 0.01.
        if isinstance(commission_amount, float):
            commission = Decimal(str(commission_amount))
        else:
            commission = Decimal(commission_amount)
        # The divisor is built from a string for the same reason: Decimal(1.2) is
        # slightly below 1.2, so 120 / 1.2 used to round up to 100.01.
        # NOTE(review): 1.2 presumably corresponds to a 20% VAT rate — confirm.
        vat_commission_amount = (commission / Decimal('1.2')).quantize(Decimal('1.00'), ROUND_UP)
        return {
            'base_amount': str(order.order_amount),
            'commission_amount': str(commission_amount),
            'vat_commission_amount': str(vat_commission_amount),
        }

    def _get_topup_finish_time(self, actual_orders: [CpaOrder]) -> datetime:
        """Return the latest topup time among the orders, but not earlier than the event start.

        :param actual_orders: orders whose ``topup_time`` values are considered.
        :return: the maximum of ``event_start`` and all order topup times.
        """
        candidate_times = [self.event.event_start]
        candidate_times.extend(order.topup_time for order in actual_orders)
        return max(candidate_times)
