import logging
import typing

from datetime import datetime, date
from dateutil.relativedelta import relativedelta

from yt.wrapper import YtClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder

from crm.agency_cabinet.common.consts.reward import reward_to_duration, nearest_period_start_month, RewardsTypes, \
    get_end_of_current_fin_year, BASE_PREDICTS_SOURCE_PATH
from crm.agency_cabinet.common.yql.base import YqlModelLoader
from crm.agency_cabinet.common.yt.base import async_to_sync, ConstantExtractor, MethodExtractor

from crm.agency_cabinet.rewards.server import config
from crm.agency_cabinet.rewards.server.src.celery.tasks.rewards.common import (
    BaseRewardLoader, BaseServiceRewardLoader
)
from crm.agency_cabinet.rewards.server.src.db import models

# Indentation unit used when rendering the generated CASE branches (see construct_cases).
TAB = ' ' * 4
LOGGER = logging.getLogger('celery.load_predicted_rewards')

# Prediction-table column whose values are summed into `payment` for each reward type
# (substituted into the `{predict_reward_column}` placeholder of the YQL templates below).
REWARD_TYPE_TO_COLUMN_NAME = {
    RewardsTypes.month: 'predict_reward_M',
    RewardsTypes.quarter: 'predict_reward_Q',
    RewardsTypes.semiyear: 'predict_reward_H',
}

# YQL template for per-contract predicted rewards.
# Placeholders (replaced textually by substitute_values):
#   {cases}                 -- `WHEN '<dt>' THEN '<period start>'` branches mapping each `dt`
#                              value to the start of its reward period; unmatched `dt` -> NULL;
#   {predict_reward_column} -- prediction column summed (NULLs as 0) into `payment`.
# Parameters: $prediction_table_name, $period_from (inclusive), $period_to (exclusive),
# compared against `dt` as strings — presumably ISO dates; confirm against the table schema.
# Rows with discount_types = 'other' are skipped.
# Result columns: contract_id, period_from, payment (ordered by period_from).
YQL_REWARD_QUERY = '''
USE hahn;

DECLARE $prediction_table_name AS String;
DECLARE $period_from AS String;
DECLARE $period_to AS String;

$reward_predicts = (
SELECT
    MAX(contract_id) AS contract_id,
    CASE dt
        {cases}
        ELSE NULL
    END AS period_from,
    SUM(COALESCE({predict_reward_column}, 0)) AS payment
FROM $prediction_table_name AS predictions
WHERE predictions.dt >= $period_from AND predictions.dt < $period_to AND predictions.discount_types != 'other'
GROUP BY (predictions.contract_id, predictions.dt)
);

SELECT
    contract_id,
    period_from,
    SUM(payment) AS payment
FROM $reward_predicts
GROUP BY (contract_id, period_from)
ORDER BY period_from;
'''

# YQL template for per-contract, per-service (discount_name) predicted rewards.
# Uses the same {cases} / {predict_reward_column} placeholders and the same
# $prediction_table_name / $period_from / $period_to parameters as YQL_REWARD_QUERY.
# Rows with discount_types = 'other' are skipped.
# Result columns:
#   contract_id, period_from,
#   payment (summed prediction column), revenue (summed predict_amt),
#   service (discount_name), discount_type (MAX of discount_types, a string).
YQL_SERVICE_REWARD_QUERY = '''
USE hahn;

DECLARE $prediction_table_name AS String;
DECLARE $period_from AS String;
DECLARE $period_to AS String;

$reward_predicts = (
SELECT
    MAX(contract_id) AS contract_id,
    CASE dt
        {cases}
        ELSE NULL
    END AS period_from,
    SUM(COALESCE({predict_reward_column}, 0)) AS payment,
    SUM(predict_amt) AS revenue,
    MAX(discount_name) AS service,
    MAX(discount_types) AS discount_type
FROM $prediction_table_name AS predictions
WHERE predictions.dt >= $period_from AND predictions.dt < $period_to AND predictions.discount_types != 'other'
GROUP BY (predictions.contract_id, predictions.dt, predictions.discount_name)
);

SELECT
    contract_id,
    period_from,
    SUM(payment) AS payment,
    SUM(revenue) AS revenue,
    service,
    MAX(discount_type) AS discount_type
FROM $reward_predicts
GROUP BY (contract_id, period_from, service)
ORDER BY period_from;
'''


class CommonPeriodFromPreprocessor(YqlModelLoader):
    """Mixin that parses the ``period_from`` column produced by the prediction YQL queries."""

    def _preprocess_period_from(self, row) -> typing.Tuple[int, int, int]:
        """Split the ``'YYYY-MM-DD'`` value of ``period_from`` into integers.

        Surrounding whitespace is ignored.

        Returns:
            A ``(year, month, day)`` tuple.
        """
        raw_value = self._get_column_value(row, 'period_from')
        date_parts = [int(chunk) for chunk in raw_value.strip().split('-')]
        year, month, day = date_parts
        return year, month, day


class PredictRewardLoader(CommonPeriodFromPreprocessor, BaseRewardLoader):
    """Reward loader for prediction rows.

    Combines the ``period_from`` parsing mixin with ``BaseRewardLoader``.
    """

    # Presumably read by BaseRewardLoader to switch into prediction-loading mode — verify in common.py.
    IS_LOADING_PREDICTIONS = True


class PredictServiceRewardLoader(CommonPeriodFromPreprocessor, BaseServiceRewardLoader):
    """Service-reward loader for prediction rows.

    Combines the ``period_from`` parsing mixin with ``BaseServiceRewardLoader``.
    """

    # Presumably read by the base loader to switch into prediction-loading mode — verify in common.py.
    IS_LOADING_PREDICTIONS = True

    def _extract_discount_type(self, row) -> int:
        """Return the first discount type id from the row's ``discount_type`` column.

        The column holds a comma-separated list such as ``'1, 2, 13'``; only the
        first id is kept (``1`` in that example). Splitting on a bare comma and
        letting ``int()`` drop surrounding whitespace accepts ``'1, 2'``, ``'1,2'``
        and ``' 1 ,2'`` alike. A single value such as ``'7'`` is returned as-is.
        """
        raw_value = self._get_column_value(row, 'discount_type')
        first_item, _, _ = raw_value.partition(',')
        return int(first_item)


def construct_cases(reward_type: RewardsTypes, date_from: date, date_to: date, tabs: int = 2) -> str:
    """Build the ``WHEN ... THEN ...`` branches of the YQL ``CASE dt`` expression.

    Starting at ``date_from`` and stepping one month at a time while the month is
    before ``date_to``, each month is mapped to the first month of the reward period
    containing it. Periods are ``reward_to_duration(reward_type)`` months long and
    the first one starts at ``date_from``.

    Args:
        reward_type: reward kind whose period length groups the months.
        date_from: first month emitted and start of the first period.
        date_to: exclusive upper bound for the emitted months.
        tabs: number of ``TAB`` units placed before every branch after the first.

    Returns:
        Branches separated by a newline plus indentation; an empty string when
        ``date_from >= date_to``.
    """
    months_per_period = reward_to_duration(reward_type)

    branches = []
    bucket_start = date_from
    month_start = date_from
    months_seen = 0
    while month_start < date_to:
        branches.append(f"WHEN '{month_start.isoformat()}' THEN '{bucket_start.isoformat()}'")
        month_start += relativedelta(months=1)
        months_seen += 1
        # A full period has elapsed: the following months belong to the next period.
        if not months_seen % months_per_period:
            bucket_start += relativedelta(months=months_per_period)

    if not branches:
        return ''
    return ('\n' + TAB * tabs).join(branches)


def substitute_values(reward_type: RewardsTypes, query: str, date_from: date, date_to: date) -> str:
    """Fill the ``{predict_reward_column}`` and ``{cases}`` placeholders of a YQL template.

    Placeholders are replaced textually with ``str.replace``, so any other braces in
    ``query`` are left untouched.

    Args:
        reward_type: selects the prediction column (``REWARD_TYPE_TO_COLUMN_NAME``)
            and the period length used by ``construct_cases``.
        query: YQL template containing the placeholders.
        date_from: first month for the generated CASE branches.
        date_to: exclusive end for the generated CASE branches.

    Returns:
        The query with both placeholders substituted.

    Raises:
        KeyError: if ``reward_type`` has no entry in ``REWARD_TYPE_TO_COLUMN_NAME``.
    """
    column_name = REWARD_TYPE_TO_COLUMN_NAME[reward_type]
    case_branches = construct_cases(reward_type, date_from, date_to)
    return query.replace('{predict_reward_column}', column_name).replace('{cases}', case_branches)


def _yt_client_config() -> dict:
    """Return a new YT client config for the ``hahn`` cluster.

    A fresh dict is built on each call, so every loader gets its own copy.
    """
    return {
        'cluster': 'hahn',
        'token': config.YT_CONFIG['TOKEN'],
        'config': {},
    }


def _find_latest_predict_table(client: YtClient):
    """Return the greatest (in sort order) table path under ``BASE_PREDICTS_SOURCE_PATH``.

    Raises:
        RuntimeError: if no table exists there. The previous behaviour was a bare
            ``IndexError`` from indexing an empty list.
    """
    predict_tables = sorted(client.search(BASE_PREDICTS_SOURCE_PATH, node_type=['table']))
    if not predict_tables:
        raise RuntimeError(f'No prediction tables found under {BASE_PREDICTS_SOURCE_PATH}')
    # NOTE(review): this is path order, not modification time. It picks the freshest
    # table only if table names sort chronologically (e.g. date-stamped names) — confirm.
    return predict_tables[-1]


def common_predicted_rewards_info_load(reward_type: RewardsTypes):
    """Load predicted rewards and predicted per-service rewards for ``reward_type``.

    Steps:

    1. Take the greatest table path under ``BASE_PREDICTS_SOURCE_PATH``.
    2. The first month to predict is the month after the latest non-predicted
       ``Reward.period_from``, or the current month when there is none. It is moved
       to ``nearest_period_start_month`` (previous year when that month is later).
    3. Run YQL_REWARD_QUERY and YQL_SERVICE_REWARD_QUERY for dates from that month
       (inclusive) to ``get_end_of_current_fin_year(date_from)`` (exclusive).
    4. Load the results with PredictRewardLoader, then PredictServiceRewardLoader.

    Raises:
        RuntimeError: when no prediction table is found.
    """
    client = YtClient(proxy='hahn', token=config.YT_CONFIG['TOKEN'])
    last_predict_table = _find_latest_predict_table(client)

    @async_to_sync
    async def __get_start_period():
        # NOTE(review): picks the latest non-predicted reward of ANY reward type — confirm intended.
        last_actual_reward = await models.Reward.query.where(
            models.Reward.predict.is_(False)
        ).order_by(
            models.Reward.period_from.desc()
        ).gino.first()

        if last_actual_reward is None:
            current_datetime = datetime.now()
            return current_datetime.year, current_datetime.month

        start_datetime = last_actual_reward.period_from + relativedelta(months=1)
        return start_datetime.year, start_datetime.month

    start_year, start_month = __get_start_period()
    nearest_start_month = nearest_period_start_month(reward_type=reward_type, month=start_month)
    # A period start month after start_month means the period began in the previous year.
    if start_month < nearest_start_month:
        start_year -= 1

    date_from = date(year=start_year, month=nearest_start_month, day=1)
    date_to = get_end_of_current_fin_year(date_from)

    LOGGER.info(
        'Loading %s reward predictions from %s for [%s, %s)',
        reward_type, last_predict_table, date_from, date_to,
    )

    reward_query = substitute_values(
        reward_type=reward_type,
        query=YQL_REWARD_QUERY,
        date_from=date_from,
        date_to=date_to,
    )
    service_reward_query = substitute_values(
        reward_type=reward_type,
        query=YQL_SERVICE_REWARD_QUERY,
        date_from=date_from,
        date_to=date_to,
    )

    # Values for the DECLAREd parameters shared by both YQL templates.
    parameters = {
        '$prediction_table_name': YqlParameterValueBuilder.make_string(
            str(last_predict_table)
        ),
        '$period_from': YqlParameterValueBuilder.make_string(
            date_from.isoformat()
        ),
        '$period_to': YqlParameterValueBuilder.make_string(
            date_to.isoformat()
        ),
    }

    predicted_rewards_loader = PredictRewardLoader(
        table_path=last_predict_table,
        model=models.Reward,
        columns_mapper={
            'contract_id': 'contract_id',
            'type': ConstantExtractor(reward_type.value),
            'is_prof': ConstantExtractor(True),
            'payment': 'payment',
            'period_from': MethodExtractor('_extract_period_from'),
            'predict': ConstantExtractor(True),
        },
        default_columns={},
        client_config=_yt_client_config(),
        yql_token=config.YQL_CONFIG['TOKEN'],
        yql_query=reward_query,
        yql_parameters=parameters,
        reward_type=reward_type,
    )
    predicted_rewards_loader.load()

    # Must be built after the reward load: force_load depends on the reward loader's
    # is_loaded_last_time (its exact semantics live in the base loaders).
    predicted_service_rewards_loader = PredictServiceRewardLoader(
        table_path=last_predict_table,
        model=models.ServiceReward,
        columns_mapper={
            'reward_id': MethodExtractor('_extract_reward_id'),
            'service': MethodExtractor('_extract_service_name'),
            'discount_type': MethodExtractor('_extract_discount_type'),
            'payment': 'payment',
            'revenue': 'revenue',
        },
        default_columns={},
        client_config=_yt_client_config(),
        force_load=predicted_rewards_loader.is_loaded_last_time,
        yql_token=config.YQL_CONFIG['TOKEN'],
        yql_query=service_reward_query,
        yql_parameters=parameters,
        reward_type=reward_type,
    )
    predicted_service_rewards_loader.load()
