import logging
from datetime import date, datetime
from itertools import product
from typing import Any, Callable, Iterable, Iterator

from dateutil.relativedelta import relativedelta

from yt.wrapper import YtClient

from travel.hotels.lib.python3.yt.ytlib import create_table, schema_from_dict, ypath_join
from travel.hotels.tools.affiliate_data_builder.lib.overridable_config import ConfigItem, OverridableConfig


class CommissionConfig:
    """Resolve affiliate commission config tables per day and snapshot them to YT.

    The config consists of three tables: partners, partner commission and user
    commission. ``read_config`` reads the raw tables from a config directory and
    resolves them through ``OverridableConfig`` for every day from ``date_from``
    up to ``now.date()`` inclusive. Each resolved row gets a ``Date`` field, and
    the result is kept in memory. ``dump_config`` writes that state into a new
    snapshot directory and re-points the ``latest`` link at it, but only when
    the state differs from what ``latest`` currently holds.
    """

    # Table names, relative to both the config directory and each snapshot directory.
    __partners_table__ = 'affiliate_partners'
    __partner_commission_table__ = 'affiliate_partner_commission'
    __user_commission_table__ = 'affiliate_user_commission'

    # Fields identifying a row within one date; rows are kept sorted by (Date, *keys).
    __partners_keys__ = ('AffiliatePartnerName', 'Category', 'OrderType',)
    __partner_commission_keys__ = ('AffiliatePartnerName', 'Category', 'OrderType',)
    __user_commission_keys__ = ('AffiliatePartnerName', 'Category', 'OrderType', 'UserId',)

    # Schemas used when creating the snapshot tables.
    __partners_schema__ = schema_from_dict({
        'Date': 'string',
        'AffiliatePartnerName': 'string',
        'Category': 'string',
        'OrderType': 'string',
    })

    __partner_commission_schema__ = schema_from_dict({
        'Date': 'string',
        'AffiliatePartnerName': 'string',
        'Category': 'string',
        'OrderType': 'string',
        'Commission': 'double',
    })

    __user_commission_schema__ = schema_from_dict({
        'Date': 'string',
        'AffiliatePartnerName': 'string',
        'Category': 'string',
        'OrderType': 'string',
        'UserId': 'string',
        'Commission': 'double',
    })

    def __init__(self, yt_client: YtClient, date_from: date, now: datetime):
        """
        :param yt_client: client used for every read and write.
        :param date_from: first date (inclusive) the config is resolved for.
        :param now: current moment. Its date is the last resolved date. The full
            timestamp names the snapshot directory when one for today already exists.
        """
        self.yt_client = yt_client
        self.date_from = date_from
        self.now = now
        self.current_date = now.date()
        self.partners: list[dict[str, Any]] = []
        self.partner_commission: list[dict[str, Any]] = []
        self.user_commission: list[dict[str, Any]] = []

    def read_config(self, config_path: str) -> None:
        """Load and resolve all three config tables from ``config_path``.

        Partners are loaded first: ``_get_config_data`` hands the current
        ``self.partners`` to ``OverridableConfig``, so the commission tables are
        resolved against the freshly loaded partners.
        """
        self.partners = self._get_config_data(
            config_path, self.__partners_table__, self._get_partners_flat, self.__partners_keys__
        )
        self.partner_commission = self._get_config_data(
            config_path, self.__partner_commission_table__, list, self.__partner_commission_keys__
        )
        self.user_commission = self._get_config_data(
            config_path, self.__user_commission_table__, list, self.__user_commission_keys__
        )

    def dump_config(self, dump_path: str) -> None:
        """Snapshot the in-memory config under ``dump_path`` if it differs from ``dump_path/latest``."""
        latest_path = ypath_join(dump_path, 'latest')
        if not self._is_config_changed(latest_path):
            logging.info('Config not changed. Nothing to dump')
            return
        self._dump_current_state(dump_path, latest_path)

    def _get_config_data(
        self,
        config_path: str,
        table_name: str,
        raw_converter: Callable[[Iterable[dict[str, Any]]], list[dict[str, Any]]],
        key_fields: tuple[str, ...],
    ) -> list[dict[str, Any]]:
        """Read one raw config table and resolve it for every day in the date range.

        :param config_path: YT directory containing the raw config tables.
        :param table_name: table name inside ``config_path``.
        :param raw_converter: turns the raw rows into the list of items passed to
            ``OverridableConfig.apply``. Callers pass ``list`` or ``_get_partners_flat``.
        :param key_fields: fields identifying an item within one date.
        :return: items resolved for each day, each carrying a ``Date`` field,
            sorted by ``(Date, *key_fields)``.
        """
        logging.info(f'Getting {table_name} config')
        table_path = ypath_join(config_path, table_name)
        raw_items = raw_converter(self._read_table(table_path))
        # NOTE(review): while the partners table itself is being loaded,
        # self.partners still holds its previous value ([] on a fresh instance).
        config = OverridableConfig(key_fields, self.partners)
        data = []
        for application_date in self._iter_day_from(self.date_from):
            data.extend(self._get_config_for_date(config, raw_items, application_date))
        return sorted(data, key=self._get_config_item_key(key_fields))

    @staticmethod
    def _get_config_item_key(
        key_fields: tuple[str, ...],
    ) -> Callable[[ConfigItem], tuple[str, ...]]:
        """Build a sort key ``(str(Date), *item_key)`` for resolved config items.

        This is the single ordering for both freshly resolved data and data read
        back from a snapshot, so the two can be compared with ``==``.
        """

        def key_getter(item: ConfigItem) -> tuple[str, ...]:
            return (str(item['Date']), *OverridableConfig.get_item_key(key_fields, item))

        return key_getter

    def _iter_day_from(self, date_from: date) -> Iterator[date]:
        """Yield every date from ``date_from`` to ``self.current_date``, both inclusive."""
        day = date_from
        while day <= self.current_date:
            yield day
            day += relativedelta(days=+1)

    @staticmethod
    def _get_config_for_date(
        config: OverridableConfig,
        raw_items: list[ConfigItem],
        application_date: date,
    ) -> list[ConfigItem]:
        """Resolve ``raw_items`` for ``application_date`` and stamp each result with ``Date``.

        ``StartAt``/``EndAt`` are dropped from the output rows. The snapshot
        schemas have no such columns; presumably they only bound an item's
        validity for ``OverridableConfig.apply``.
        """
        items = []
        for resolved in config.apply(raw_items, application_date):
            # Shallow copy so the mutations below cannot leak into objects that
            # apply() might share across calls (e.g. the raw items themselves).
            item = dict(resolved)
            item.pop('StartAt', None)
            item.pop('EndAt', None)
            item['Date'] = str(application_date)
            items.append(item)
        return items

    def _get_partners_flat(self, partners: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
        """Expand every partner record into its flat variants (list-valued fields unrolled)."""
        return [flat for partner in partners for flat in self._get_partner_record_flat(partner)]

    @staticmethod
    def _get_partner_record_flat(partner_record: dict[str, Any]) -> list[dict[str, Any]]:
        """Expand list-valued fields of a record into the Cartesian product of records.

        Each field whose value is a ``list`` is replaced by one of its elements,
        and one record is produced per combination. A record without list fields
        is returned unchanged (the same object). A record with an empty list
        field yields no records.

        >>> CommissionConfig._get_partner_record_flat({'A': 'x', 'B': [1, 2]})
        [{'A': 'x', 'B': 1}, {'A': 'x', 'B': 2}]
        """
        list_fields = {key: value for key, value in partner_record.items() if isinstance(value, list)}
        if not list_fields:
            return [partner_record]

        keys = list(list_fields)
        return [
            {**partner_record, **dict(zip(keys, combination))}
            for combination in product(*list_fields.values())
        ]

    def _is_config_changed(self, latest_path: str) -> bool:
        """Return True if the in-memory state differs from the snapshot at ``latest_path``.

        A missing snapshot counts as changed. Tables are checked in order, and
        checking stops at the first difference.
        """
        if not self.yt_client.exists(latest_path):
            return True
        tables = (
            (self.__partners_table__, self.__partners_keys__, self.partners),
            (self.__partner_commission_table__, self.__partner_commission_keys__, self.partner_commission),
            (self.__user_commission_table__, self.__user_commission_keys__, self.user_commission),
        )
        return any(
            self._is_config_table_changed(latest_path, table_name, key_fields, current_data)
            for table_name, key_fields, current_data in tables
        )

    def _is_config_table_changed(
        self,
        latest_path: str,
        table_name: str,
        key_fields: tuple[str, ...],
        current_data: list[dict[str, Any]],
    ) -> bool:
        """Return True if ``latest_path/table_name`` differs from ``current_data``.

        ``current_data`` is expected in the order produced by ``_get_config_data``,
        that is sorted by ``(Date, *key_fields)``.
        """
        table_path = ypath_join(latest_path, table_name)
        logging.info(f'Checking {table_path}')
        if not self.yt_client.exists(table_path):
            # Incomplete snapshot: treat as changed so a full snapshot is written.
            return True
        # Use the same key as _get_config_data. Sorting by the item key alone
        # would order rows (key, Date) instead of (Date, key), and multi-day data
        # would always compare unequal.
        latest_data = sorted(self._read_table(table_path), key=self._get_config_item_key(key_fields))
        return latest_data != current_data

    def _dump_current_state(self, dump_path: str, latest_path: str) -> None:
        """Write the in-memory state to a new snapshot directory and point ``latest`` at it.

        The snapshot is named after the current date. If that directory already
        exists (e.g. a second dump on the same day), it is named after the full
        ``now`` timestamp instead. A pre-existing directory with the chosen name
        is removed before writing.
        """
        date_path = ypath_join(dump_path, str(self.current_date))
        if self.yt_client.exists(date_path):
            date_path = ypath_join(dump_path, str(self.now))

        if self.yt_client.exists(date_path):
            self.yt_client.remove(date_path, recursive=True)

        self._write_table(date_path, self.__partners_table__, self.__partners_schema__, self.partners)
        self._write_table(
            date_path, self.__partner_commission_table__, self.__partner_commission_schema__, self.partner_commission
        )
        self._write_table(
            date_path, self.__user_commission_table__, self.__user_commission_schema__, self.user_commission
        )

        logging.info(f'Switching latest to {date_path}')
        self.yt_client.link(date_path, latest_path, force=True)

    def _read_table(self, path: str) -> Iterable[dict[str, Any]]:
        """Return the rows of the YT table at ``path`` as produced by ``yt_client.read_table``."""
        return self.yt_client.read_table(path)

    def _write_table(
        self,
        date_path: str,
        table_name: str,
        table_schema: Any,
        data: list[dict[str, Any]],
    ) -> None:
        """Create ``date_path/table_name`` with ``table_schema`` and write ``data`` into it.

        :param table_schema: schema object as built by ``schema_from_dict``.
        """
        table_path = ypath_join(date_path, table_name)
        logging.info(f'Writing {table_name} config to {table_path}')
        create_table(table_path, self.yt_client, table_schema)
        self.yt_client.write_table(table_path, data)
