import logging
import pytz
import typing

from datetime import datetime
from dateutil.relativedelta import relativedelta

from sqlalchemy import and_

from yql.client.parameter_value_builder import YqlParameterValueBuilder

from crm.agency_cabinet.common.consts.reward import RewardsTypes, nearest_period_start_month, reward_to_duration
from crm.agency_cabinet.common.consts.service import Services, SERVICE_DISCOUNT_TYPE_MAP
from crm.agency_cabinet.common.yql.base import YqlModelLoader
from crm.agency_cabinet.common.yt.base import async_to_sync, ConstantExtractor, MethodExtractor

from crm.agency_cabinet.rewards.server import config
from crm.agency_cabinet.rewards.server.src.celery.tasks.rewards.common import (
    BaseRewardLoader, BaseServiceRewardLoader
)
from crm.agency_cabinet.rewards.server.src.db import models

# Module logger, registered under the 'celery' logger hierarchy.
LOGGER = logging.getLogger('celery.load_rewards')
# YT table read by both loaders: passed as the YQL `$table_name` parameter and as `table_path`.
BASE_SOURCE_PATH = '//home/balance/prod/bo/v_ar_rewards'

# Reward percent reported by ServiceRewardLoader for early-payment rows
# (instead of computing payment / revenue).
EARLY_PAYMENT_PERCENT = 2
# Source `reward_type` codes treated as early payments: they are mapped to the
# early-payment service and collapsed separately in YQL_SERVICE_REWARD_QUERY.
EARLY_PAYMENT_REWARD_TYPES = {311}

# Source `reward_type` codes selected by list_reward_types() per period kind.
# For monthly periods the reward query (reward_to_pay) uses the PAYMENT set,
# while the service-reward query (reward_to_charge) uses the ACCRUAL set.
MONTH_PAYMENT_REWARD_TYPES = {1, 10, 201, 210, 301, 310, 311}
MONTH_ACCRUAL_REWARD_TYPES = {1, 201, 301, 311}
QUARTER_REWARD_TYPES = {20, 220, 320}
SEMIYEAR_REWARD_TYPES = {2, 202, 302}

# Timezone in which reward periods are expressed; YT `from_dt` values are
# converted into it and default/explicit period starts are built in it.
MOSCOW_TIMEZONE = pytz.timezone('Europe/Moscow')

# YQL template for per-contract rewards: sums `reward_to_pay` into `payment`
# per (contract_id, from_dt AS period_from). It only reads rows with
# $period_from <= from_dt < $period_to; the bounds are declared as String, so
# the comparison is lexicographic on the string values.
# `{reward_types}` is a literal placeholder replaced by substitute_values()
# (it is not a str.format() field). The outer SELECT regroups by the same keys
# as the inner one and orders the result by period_from.
YQL_REWARD_QUERY = '''
USE hahn;

DECLARE $table_name AS String;
DECLARE $period_from AS String;
DECLARE $period_to AS String;

$load_rewards = (
SELECT
    contract_id,
    from_dt AS period_from,
    SUM(reward_to_pay) AS payment
FROM $table_name AS rewards
WHERE rewards.from_dt >= $period_from
  AND rewards.from_dt < $period_to
  AND rewards.reward_type IS NOT NULL
  AND rewards.reward_type IN {reward_types}
GROUP BY (rewards.contract_id, rewards.from_dt)
);

SELECT
    contract_id,
    period_from,
    SUM(payment) AS payment
FROM $load_rewards
GROUP BY (contract_id, period_from)
ORDER BY period_from;
'''

# YQL template for the per-service breakdown. It sums `reward_to_charge` into
# `payment` and `turnover_to_charge` into `revenue` for rows with
# $period_from <= from_dt < $period_to (string comparison). Only rows whose
# discount_type is listed in `{all_discount_types}` are read.
# Source reward types are collapsed into two codes:
#   - early-payment types (`{early_payment_reward_types}`) become
#     `{min_early_payment_reward_type}`;
#   - every other selected type becomes `{min_reward_type}`.
# The outer SELECT then sums over (contract_id, period_from, discount_type,
# collapsed reward_type). All `{...}` placeholders are literal text replaced by
# substitute_values().
YQL_SERVICE_REWARD_QUERY = '''
USE hahn;

DECLARE $table_name AS String;
DECLARE $period_from AS String;
DECLARE $period_to AS String;

$load_rewards = (
SELECT
    contract_id,
    from_dt AS period_from,
    discount_type,
    SUM(reward_to_charge) AS payment,
    SUM(turnover_to_charge) AS revenue,
    CASE
        WHEN rewards.reward_type IN {early_payment_reward_types} THEN {min_early_payment_reward_type}
        ELSE {min_reward_type}
    END AS reward_type
FROM $table_name AS rewards
WHERE rewards.discount_type IN {all_discount_types}
  AND rewards.from_dt >= $period_from
  AND rewards.from_dt < $period_to
  AND rewards.reward_type IS NOT NULL
  AND rewards.reward_type IN {reward_types}
GROUP BY (rewards.contract_id, rewards.from_dt, rewards.discount_type, rewards.reward_type)
);

SELECT
    contract_id,
    period_from,
    discount_type,
    SUM(payment) AS payment,
    SUM(revenue) AS revenue,
    reward_type
FROM $load_rewards
GROUP BY (contract_id, period_from, discount_type, reward_type)
ORDER BY period_from;
'''


def list_reward_types(reward_type: RewardsTypes, is_loading_service_rewards: bool = False) -> set[int]:
    """Return the source ``reward_type`` codes to select for a reward period kind.

    :param reward_type: period kind (``month``, ``quarter`` or ``semiyear``).
    :param is_loading_service_rewards: for monthly periods, return the accrual
        codes used by the service-reward query instead of the payment codes.
        Ignored for other period kinds.
    :return: the matching module-level set of codes. It is shared, so callers
        must not mutate it.
    :raises ValueError: if ``reward_type`` is not a supported period kind.
    """
    if reward_type == RewardsTypes.month:
        return MONTH_ACCRUAL_REWARD_TYPES if is_loading_service_rewards else MONTH_PAYMENT_REWARD_TYPES
    if reward_type == RewardsTypes.quarter:
        return QUARTER_REWARD_TYPES
    if reward_type == RewardsTypes.semiyear:
        return SEMIYEAR_REWARD_TYPES

    # Fail loudly here rather than implicitly returning None, which would only
    # surface later as an opaque AttributeError in the set operations callers do.
    raise ValueError(f'Unsupported reward type: {reward_type!r}')


def list_discount_types() -> set:
    """Return every discount type listed in ``SERVICE_DISCOUNT_TYPE_MAP``.

    The map's keys (services) are ignored; the result is the union of all of its
    value collections, so a discount type shared by several services appears once.
    """
    # Iterate .values() directly: the keys are not needed.
    return {
        discount_type
        for discount_type_values in SERVICE_DISCOUNT_TYPE_MAP.values()
        for discount_type in discount_type_values
    }


class CommonPeriodFromPreprocessor(YqlModelLoader):
    """Shared base for loaders that read the ``period_from`` column of a YQL row."""

    def _preprocess_period_from(self, row) -> typing.Tuple[int, int, int]:
        """Turn the row's ``period_from`` into a Moscow-local ``(year, month, day)``.

        The raw value must match ``%Y-%m-%dT%H:%M:%S%z``, i.e. carry an explicit
        UTC offset. It is shifted into ``MOSCOW_TIMEZONE`` before the date parts
        are taken.
        """
        raw_value = self._get_column_value(row, 'period_from')
        moscow_moment = datetime.strptime(raw_value, '%Y-%m-%dT%H:%M:%S%z').astimezone(MOSCOW_TIMEZONE)
        return moscow_moment.year, moscow_moment.month, moscow_moment.day


class RewardLoader(CommonPeriodFromPreprocessor, BaseRewardLoader):
    """Loader for non-predicted rewards (``predict=False`` rows of ``models.Reward``).

    Its ``_before_start`` hook deletes the predicted rows (``predict IS TRUE``)
    that have the same reward type and period start as the rows being loaded.
    """

    IS_LOADING_PREDICTIONS = False

    def _init(self, period_from: datetime, **kwargs):
        """Remember the period start, then delegate the remaining kwargs to the base loader.

        :param period_from: period start; used to select the prediction rows to delete.
        """
        self.period_from = period_from

        super()._init(**kwargs)

    def _before_start(self):
        """Run the base hook, then delete predicted rewards for this type and period.

        Presumably invoked by the base loader before rows are written — confirm
        against BaseRewardLoader.
        """
        super()._before_start()

        LOGGER.debug(
            'Predictions clearing started... period_from: \'%s\', reward_type: \'%s\'',
            str(self.period_from), self.reward_type.value
        )

        @async_to_sync
        async def _clear_reward_predictions(reward_type: str, period_from: datetime):
            async with self.db_bind:
                return await models.Reward.delete.where(
                    and_(
                        models.Reward.type == reward_type,
                        models.Reward.period_from == period_from,
                        models.Reward.predict.is_(True)
                    )
                ).gino.status(read_only=False, reuse=False)

        _clear_reward_predictions(
            self.reward_type.value,
            self.period_from
        )

        # Normal completion is routine progress: log it at the same level as the
        # start message instead of ERROR, which would fire on every successful run.
        LOGGER.debug('Predictions clearing finished...')


class ServiceRewardLoader(CommonPeriodFromPreprocessor, BaseServiceRewardLoader):
    """Loader for the per-service reward breakdown (``models.ServiceReward``).

    Rows whose ``reward_type`` is in ``EARLY_PAYMENT_REWARD_TYPES`` are reported
    under the early-payment service, with ``EARLY_PAYMENT_PERCENT`` as their
    percent and a NULL revenue.
    """

    IS_LOADING_PREDICTIONS = False

    def __is_early_payment(self, row) -> bool:
        """Return True if the row's (collapsed) ``reward_type`` is an early-payment code."""
        return int(self._get_column_value(row, 'reward_type')) in EARLY_PAYMENT_REWARD_TYPES

    def _extract_discount_type(self, row) -> int:
        """Return the row's ``discount_type`` as an int."""
        return int(self._get_column_value(row, 'discount_type'))

    def _extract_service_name(self, row) -> str:
        """Return the early-payment service for early payments, else defer to the base loader."""
        if self.__is_early_payment(row):
            return Services.early_payment.value

        return super()._extract_service_name(row)

    def _extract_revenue(self, row) -> typing.Optional[typing.Any]:
        """Return the row's ``revenue``, or None for early-payment rows."""
        # for early payments revenue is 0 in YT tables, but we need NULLs
        if self.__is_early_payment(row):
            return None

        return self._get_column_value(row, 'revenue')

    def _extract_reward_percent(self, row) -> typing.Optional[typing.Any]:
        """Return the reward as a percentage of revenue.

        Returns ``EARLY_PAYMENT_PERCENT`` for early payments, None when payment or
        revenue is missing, and 0 when revenue is zero (avoids division by zero).
        """
        if self.__is_early_payment(row):
            return EARLY_PAYMENT_PERCENT
        payment = self._get_column_value(row, 'payment')
        revenue = self._get_column_value(row, 'revenue')
        if payment is None or revenue is None:
            return None
        return 100 * payment / revenue if revenue else 0


def substitute_values(reward_type: RewardsTypes, query: str, is_loading_service_rewards: bool) -> str:
    """Fill the literal ``{...}`` placeholders of a YQL query template.

    Collections are rendered as parenthesised, comma-separated lists such as
    ``(1, 201)``, and scalars as plain numbers. Placeholders absent from
    ``query`` are simply left untouched.

    :param reward_type: period kind used to pick the reward type codes.
    :param query: template containing some of the supported placeholders.
    :param is_loading_service_rewards: forwarded to ``list_reward_types``.
    :return: the query with every supported placeholder replaced.
    """
    def as_tuple_literal(values) -> str:
        return '({})'.format(', '.join(str(value) for value in values))

    selected_types = list_reward_types(reward_type, is_loading_service_rewards)
    regular_types = selected_types.difference(EARLY_PAYMENT_REWARD_TYPES)

    # Applied in insertion order, one placeholder at a time.
    replacements = {
        '{reward_types}': as_tuple_literal(selected_types),
        '{all_discount_types}': as_tuple_literal(list_discount_types()),
        '{early_payment_reward_types}': as_tuple_literal(EARLY_PAYMENT_REWARD_TYPES),
        '{min_reward_type}': str(min(regular_types)),
        '{min_early_payment_reward_type}': str(min(EARLY_PAYMENT_REWARD_TYPES)),
    }
    for placeholder, literal in replacements.items():
        query = query.replace(placeholder, literal)

    return query


def _resolve_period_from(reward_type: RewardsTypes, period_from: typing.Optional[str]) -> datetime:
    """Return the Moscow-local start of the reward period to load.

    An explicit ``YYYY-MM-DD`` string wins. Otherwise the month lying
    ``reward_to_duration(reward_type)`` months before now is fed to
    ``nearest_period_start_month``; the year is decremented when the returned
    start month is later than that month.
    """
    if period_from:
        naive_start = datetime.strptime(period_from, '%Y-%m-%d')
    else:
        # Use Moscow wall-clock time so the chosen period does not depend on the
        # worker's local timezone (matters around month boundaries).
        shifted_now = datetime.now(MOSCOW_TIMEZONE) - relativedelta(months=reward_to_duration(reward_type))
        start_year, start_month = shifted_now.year, shifted_now.month

        nearest_start_month = nearest_period_start_month(reward_type=reward_type, month=start_month)
        if start_month < nearest_start_month:
            start_year -= 1

        naive_start = datetime(year=start_year, month=nearest_start_month, day=1)

    # pytz zones must be attached with localize(); passing them via tzinfo= or
    # replace() picks the zone's historical LMT offset instead of MSK.
    return MOSCOW_TIMEZONE.localize(naive_start)


def _build_yql_parameters(period_start: datetime) -> dict:
    """Build the YQL parameters selecting ``from_dt`` within the month starting at ``period_start``.

    Both bounds are UTC calendar dates rendered as ISO strings. The queries
    declare them as String and compare them with ``from_dt`` as strings:
    ``$period_from <= from_dt < $period_to``.
    """
    # Step one month in Moscow wall-clock time and re-localize, so the upper
    # bound is the next Moscow month start in UTC. Adding a month to the UTC
    # date instead gets clamped by relativedelta (e.g. 28 Feb -> 28 Mar rather
    # than 31 Mar).
    period_end = MOSCOW_TIMEZONE.localize(period_start.replace(tzinfo=None) + relativedelta(months=1))
    period_from_param = period_start.astimezone(pytz.utc).date()
    period_to_param = period_end.astimezone(pytz.utc).date()

    return {
        '$table_name': YqlParameterValueBuilder.make_string(
            BASE_SOURCE_PATH
        ),
        '$period_from': YqlParameterValueBuilder.make_string(
            period_from_param.isoformat()
        ),
        '$period_to': YqlParameterValueBuilder.make_string(
            period_to_param.isoformat()
        )
    }


def common_rewards_info_load(reward_type: RewardsTypes, period_from: str, force_load: bool):
    """Load rewards and their per-service breakdown for one period from YT into the DB.

    :param reward_type: period kind to load.
    :param period_from: optional ``YYYY-MM-DD`` period start, interpreted in
        Moscow time. When empty, the period is derived from the current date.
    :param force_load: passed to the reward loader. The service-reward loader's
        ``force_load`` is taken from the reward loader's ``is_loaded_last_time``.
    """
    period_start = _resolve_period_from(reward_type, period_from)

    reward_query = substitute_values(
        reward_type=reward_type,
        query=YQL_REWARD_QUERY,
        is_loading_service_rewards=False
    )
    service_reward_query = substitute_values(
        reward_type=reward_type,
        query=YQL_SERVICE_REWARD_QUERY,
        is_loading_service_rewards=True
    )

    parameters = _build_yql_parameters(period_start)

    rewards_loader = RewardLoader(
        table_path=BASE_SOURCE_PATH,
        model=models.Reward,
        columns_mapper={
            'contract_id': 'contract_id',
            'type': ConstantExtractor(reward_type.value),
            'is_accrued': ConstantExtractor(True),
            'is_prof': ConstantExtractor(False),
            'payment': 'payment',
            'period_from': MethodExtractor('_extract_period_from'),
            'predict': ConstantExtractor(False),
        },
        default_columns={},
        client_config={
            'cluster': 'hahn',
            'token': config.YT_CONFIG['TOKEN'],
            'config': {}
        },
        force_load=force_load,
        yql_token=config.YQL_CONFIG['TOKEN'],
        yql_query=reward_query,
        yql_parameters=parameters,
        reward_type=reward_type,
        period_from=period_start
    )
    rewards_loader.load()

    service_rewards_loader = ServiceRewardLoader(
        table_path=BASE_SOURCE_PATH,
        model=models.ServiceReward,
        columns_mapper={
            'reward_id': MethodExtractor('_extract_reward_id'),
            'service': MethodExtractor('_extract_service_name'),
            'discount_type': MethodExtractor('_extract_discount_type'),
            'payment': 'payment',
            'revenue': MethodExtractor('_extract_revenue'),
            'reward_percent': MethodExtractor('_extract_reward_percent'),
        },
        default_columns={},
        client_config={
            'cluster': 'hahn',
            'token': config.YT_CONFIG['TOKEN'],
            'config': {}
        },
        force_load=rewards_loader.is_loaded_last_time,
        yql_token=config.YQL_CONFIG['TOKEN'],
        yql_query=service_reward_query,
        yql_parameters=parameters,
        reward_type=reward_type,
    )
    service_rewards_loader.load()
