import logging
import statistics
from asgiref.sync import async_to_sync
from collections import defaultdict
from datetime import datetime
from dateutil.relativedelta import relativedelta
from decimal import Decimal
from sqlalchemy import and_
from yt.wrapper import YtClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder
from crm.agency_cabinet.common.consts.common import START_FIN_YEAR_2021, END_FIN_YEAR_2021
from crm.agency_cabinet.common.consts.reward import RewardsTypes, reward_to_duration
from crm.agency_cabinet.common.consts.service import Services
from crm.agency_cabinet.common.yql.base import YqlModelLoader
from crm.agency_cabinet.common.yt.base import MethodExtractor
from crm.agency_cabinet.rewards.common import structs
from crm.agency_cabinet.rewards.common.schemas import calculator as calculator_schemas
from crm.agency_cabinet.rewards.server.config.clients import YT_CONFIG, YQL_CONFIG
from crm.agency_cabinet.rewards.server.src.celery.base import celery_app as celery
from crm.agency_cabinet.rewards.server.src.db import models
from crm.agency_cabinet.rewards.server.src.db.queries import build_contract_to_services_map
from .base import ServiceDataLoader, EarlyPaymentDataLoader, PredictDataLoader, get_predict_start_from

# Module-level logger shared by the loaders below (logger name: 'celery.calculator.direct').
LOGGER = logging.getLogger('celery.calculator.direct')


class DetailedDataLoaderBase(ServiceDataLoader):
    """Common helpers for the Direct loaders whose table path ends with a ``YYYYMM`` suffix."""

    service = Services.direct.value

    def _get_period_from(self):
        """Return the period start parsed from the tail of ``self.table_path``.

        Slashes are stripped from both ends of the path, then the last six
        characters are parsed with the ``%Y%m`` format.

        Raises:
            ValueError: if those characters are not a valid ``YYYYMM`` value.
        """
        trimmed_path = self.table_path.strip('/')
        month_suffix = trimmed_path[-6:]
        return datetime.strptime(month_suffix, '%Y%m')

    @staticmethod
    def _get_contract_id(yt_row):
        """Return ``yt_row['contract_id']`` when that key exists, else ``yt_row['orig_contract_id']``."""
        if 'contract_id' in yt_row:
            return yt_row['contract_id']
        return yt_row['orig_contract_id']


class MonthPremiumDataLoader(DetailedDataLoaderBase):
    """Turns one monthly table into per-contract calculator data.

    For every known Direct contract the loader produces a single non-predict
    month containing revenue grades (D/C/B/A) and one ``'rsya'`` index.
    """

    def _init(self, **kwargs):
        """Set ``self.period_from`` and ``self.contract_ids``.

        ``self.contract_ids`` contains the ids of contracts whose services,
        taken from ``build_contract_to_services_map`` united with the
        contract's own ``services`` attribute, include the Direct service.
        """
        @async_to_sync
        async def _fetch_contracts():
            async with self.db_bind:
                return await models.Contract.query.gino.all()

        self.period_from = self._get_period_from()

        contracts = _fetch_contracts()
        services_by_contract = async_to_sync(build_contract_to_services_map)(
            (contract.id for contract in contracts)
        )

        direct_contract_ids = set()
        for contract in contracts:
            mapped_services = services_by_contract[contract.id]
            own_services = set(contract.services) if contract.services else set()
            if Services.direct.value in (mapped_services | own_services):
                direct_contract_ids.add(contract.id)

        self.contract_ids = direct_contract_ids

    @staticmethod
    def get_grade(amt):
        """Map a revenue amount to its grade letter.

        Below 150000 is 'D', below 1000000 is 'C', below 5000000 is 'B',
        everything else is 'A'.
        """
        upper_bounds = (
            (150000, 'D'),
            (1000000, 'C'),
            (5000000, 'B'),
        )
        for upper_bound, grade in upper_bounds:
            if amt < upper_bound:
                return grade
        return 'A'

    @staticmethod
    def get_rsya_revenue(yt_row):
        """Return the row's RSYA revenue.

        Uses ``rsya_amt`` when the row has that key; otherwise the whole
        ``amt`` counts when ``rsya_reward`` is positive, else 0.
        """
        if 'rsya_amt' in yt_row:
            return yt_row['rsya_amt']
        return yt_row['amt'] if yt_row['rsya_reward'] > 0 else 0

    @staticmethod
    def get_direct_revenue(yt_row):
        """Return the row's Direct revenue.

        Uses ``direct_amt`` when the row has that key; otherwise the whole
        ``amt`` counts when ``direct_reward`` is positive, else 0.
        """
        if 'direct_amt' in yt_row:
            return yt_row['direct_amt']
        return yt_row['amt'] if yt_row['direct_reward'] > 0 else 0

    @staticmethod
    def get_domain(yt_row):
        """Return the value under ``'domain'``, else under ``'domainid'``, else ``None``."""
        for key in ('domain', 'domainid'):
            if key in yt_row:
                return yt_row[key]
        return None

    def _make_table_row(self, contract_id, rows):
        """Build the ``{'contract_id', 'data'}`` record for one contract.

        Args:
            contract_id: id of the contract the rows belong to.
            rows: the table rows collected for that contract.

        Returns:
            A dict with the contract id and the dumped calculator data.
        """
        loader = MonthPremiumDataLoader

        # Revenue is accumulated per domain; rows without a (truthy) domain are
        # kept separately and each of them is graded on its own.
        revenue_by_domain = defaultdict(int)
        domainless_revenues = []
        for yt_row in rows:
            domain = loader.get_domain(yt_row)
            direct_revenue = loader.get_direct_revenue(yt_row)
            if not domain:
                domainless_revenues.append(direct_revenue)
                continue
            revenue_by_domain[domain] += direct_revenue

        # Per-domain totals go first, then the domainless amounts.
        revenues_by_grade = defaultdict(list)
        for amount in [*revenue_by_domain.values(), *domainless_revenues]:
            revenues_by_grade[loader.get_grade(amount)].append(amount)

        grades = []
        for grade_id, revenues in revenues_by_grade.items():
            grades.append(structs.CalculatorGradeData(
                grade_id=grade_id,
                domains_count=len(revenues),
                revenue_average=statistics.mean(revenues)
            ))

        rsya_total = sum(loader.get_rsya_revenue(yt_row) for yt_row in rows)
        rsya_index = structs.CalculatorIndexData(index_id='rsya', revenue=rsya_total)

        month = structs.CalculatorMonthData(
            period_from=self.period_from,
            predict=False,
            grades=grades,
            indexes=[rsya_index]
        )
        data = structs.CalculatorData(months=[month])

        return {
            'contract_id': contract_id,
            'data': calculator_schemas.CalculatorDataSchema().dump(data)
        }

    def _read_table(self, **kwargs):
        """Read the table, group rows by contract and fill ``self._table``.

        Rows of contracts outside ``self.contract_ids`` are ignored. A failure
        while building one contract's record is logged and that contract is
        skipped, so the remaining contracts are still processed.
        """
        # The last character of table_path is dropped before reading
        # (presumably a trailing slash kept by the base loader - TODO confirm).
        source_path = self.table_path[:-1]

        rows_by_contract = defaultdict(list)
        for yt_row in self.client.read_table(source_path, **kwargs):
            contract_id = MonthPremiumDataLoader._get_contract_id(yt_row)
            if contract_id not in self.contract_ids:
                continue
            rows_by_contract[contract_id].append(yt_row)

        self._table = []
        for contract_id, contract_rows in rows_by_contract.items():
            try:
                table_row = self._make_table_row(contract_id, contract_rows)
            except Exception as ex:
                LOGGER.exception('Something went wrong: %s', ex)
            else:
                self._table.append(table_row)


class DirectEarlyPaymentDataLoader(EarlyPaymentDataLoader):
    """Early-payment loader bound to the Direct service."""

    service = Services.direct.value

    def make_month_data(self, predict, period_from, index):
        """Wrap a single index into a month entry that has no grades.

        NOTE(review): ``predict`` is accepted but never used; the month is
        always created with ``predict=False``. Confirm this is intended.
        """
        month_fields = {
            'period_from': period_from,
            'predict': False,
            'grades': [],
            'indexes': [index],
        }
        return structs.CalculatorMonthData(**month_fields)

    def make_data(self, months):
        """Pack ``months`` into calculator data and return its schema dump."""
        payload = structs.CalculatorData(months=months)
        schema = calculator_schemas.CalculatorDataSchema()
        return schema.dump(payload)


class KPIDataLoader(DetailedDataLoaderBase):
    """Loads per-contract KPI index sums via YQL and spreads them over months.

    The YQL query sums the KPI columns of one table per ``orig_contract_id``.
    Each sum is divided by the contract's total monthly Direct revenue (read
    from the database for the same period) to get a coefficient; every month's
    revenue is then multiplied by that coefficient to get the month's index
    revenue.
    """

    # Index ids emitted for every month; each one is also a column alias in YQL_QUERY.
    INDEX_IDS = {
        'conversion_autostrategy',
        'key_goals',
        'metrica',
        'video_cpc',
        'smart_banners',
        'rmp',
        'k50'
    }

    # The {placeholders} are filled with source column names from one of the maps below.
    YQL_QUERY = '''
    USE hahn;

    DECLARE $detailed_report_table_name AS String;
    SELECT
        orig_contract_id,
        SUM({conversion_autostrategy}) AS conversion_autostrategy,
        SUM({key_goals}) AS key_goals,
        SUM({metrica}) AS metrica,
        SUM({video_cpc}) AS video_cpc,
        SUM({smart_banners}) AS smart_banners,
        SUM({rmp}) AS rmp,
        SUM({k50}) AS k50
    FROM $detailed_report_table_name
    GROUP BY orig_contract_id;
    '''

    # Column names used when the table's period equals START_FIN_YEAR_2021.
    INDEX_ID_TO_YT_COLUMN_MAP_Q1 = {
        'conversion_autostrategy': 'amt_autostrategy',
        'key_goals': 'amt_goals',
        'metrica': 'amt_metrika',
        'video_cpc': 'amt_cpc_video',
        'smart_banners': 'amt_smart_banner',
        'rmp': 'amt_rmp',
        'k50': 'amt_k50'
    }

    # Column names used for every other period.
    INDEX_ID_TO_YT_COLUMN_MAP_Q2 = {
        'conversion_autostrategy': 'autostrategy_amt',
        'key_goals': 'goals_amt',
        'metrica': 'metrika_amt',
        'video_cpc': 'cpc_video_amt',
        'smart_banners': 'smart_banner_amt',
        'rmp': 'rmp_amt',
        'k50': 'k50_amt'
    }

    def _init(self, **kwargs):
        """Prepare the YQL loader and the per-contract monthly revenue map.

        Args:
            **kwargs: must contain ``yql_token``.

        Raises:
            KeyError: if ``yql_token`` is not passed.
        """
        period_from = self._get_period_from()
        # presumably reward_to_duration() returns the quarter length in months - TODO confirm
        period_to = period_from + relativedelta(months=reward_to_duration(RewardsTypes.quarter))

        yql_parameters = {
            # The last character of table_path is dropped and '//' is prepended
            # (presumably the base loader stores the path without the leading
            # slashes and with a trailing one - TODO confirm).
            '$detailed_report_table_name': YqlParameterValueBuilder.make_string(
                f'//{self.table_path[:-1]}'
            ),
        }
        if period_from == START_FIN_YEAR_2021:
            index_to_column_map = self.INDEX_ID_TO_YT_COLUMN_MAP_Q1
        else:
            index_to_column_map = self.INDEX_ID_TO_YT_COLUMN_MAP_Q2

        client_config = {
            'cluster': 'hahn',
            'token': YT_CONFIG['TOKEN'],
            'config': {
                'pool': 'advagencyportal'
            }
        }

        self.yql_loader = YqlModelLoader(
            table_path=self.table_path,
            model=self.model,
            columns_mapper=self.columns_mapper,
            default_columns=self.default_columns,
            client_config=client_config,
            force_load=self.force_load,
            yql_token=kwargs['yql_token'],
            yql_query=KPIDataLoader.YQL_QUERY.format(**index_to_column_map),
            yql_parameters=yql_parameters
        )

        @async_to_sync
        async def _get_all_data():
            async with self.db_bind:
                return await models.CalculatorData.select('contract_id').gino.all()

        @async_to_sync
        async def _get_revenues(contract_ids, period_from, period_to):
            async with self.db_bind:
                service_rewards = await models.ServiceReward.join(models.Reward).select().where(
                    and_(
                        models.ServiceReward.service == Services.direct.value,
                        models.Reward.contract_id.in_(contract_ids),
                        models.Reward.period_from >= period_from,
                        models.Reward.period_from < period_to,
                        models.Reward.type == RewardsTypes.month.value
                    )
                ).gino.all()

                # contract_id -> {period_from: revenue}
                result = defaultdict(dict)
                for r in service_rewards:
                    result[r.contract_id].update({
                        r.period_from: r.revenue
                    })
                return result

        # Only contracts that already have calculator data are processed.
        self.contract_ids = {c.contract_id for c in _get_all_data()}
        self.contract_revenue_map = _get_revenues(self.contract_ids, period_from, period_to)

    def _read_table(self, **kwargs):
        """Run the YQL query and build ``self._table`` for the known contracts."""
        self.yql_loader._read_table()
        self._table = []
        for row in self.yql_loader._table:
            contract_id = self.yql_loader._get_column_value(row, 'orig_contract_id')
            if contract_id in self.contract_ids:
                self._table.append(self._make_table_row(contract_id, row))

    def _make_table_row(self, contract_id, row):
        """Build the ``{'contract_id', 'data'}`` record for one contract.

        Args:
            contract_id: id of the contract.
            row: the YQL result row holding the contract's summed index columns.

        Returns:
            A dict with the contract id and the dumped calculator data. The
            data has one month per entry of the contract's revenue map; a
            contract without revenue entries gets an empty month list.
        """
        monthly_revenues = self.contract_revenue_map[contract_id]
        total_revenue = sum(monthly_revenues.values())

        # Sorted so that the order of emitted indexes does not depend on set
        # iteration order, which may differ between interpreter processes.
        index_ids = sorted(self.INDEX_IDS)

        if total_revenue:
            # A coefficient does not depend on the month, so it is computed
            # once per contract rather than once per month.
            coefficients = {
                index_id: Decimal(self.yql_loader._get_column_value(row, index_id)) / total_revenue
                for index_id in index_ids
            }
        else:
            # Dividing by a zero total would raise and abort the whole load;
            # with no revenue to distribute every index share is zero.
            coefficients = dict.fromkeys(index_ids, Decimal(0))

        months = []
        for period_from, revenue in monthly_revenues.items():
            indexes = [
                structs.CalculatorIndexData(
                    index_id=index_id,
                    revenue=revenue * coefficients[index_id]
                )
                for index_id in index_ids
            ]

            months.append(
                structs.CalculatorMonthData(
                    # tzinfo is dropped so the stored period is a naive datetime
                    period_from=period_from.replace(tzinfo=None),
                    predict=False,
                    grades=[],
                    indexes=indexes
                )
            )

        data = structs.CalculatorData(months=months)

        return {
            'contract_id': contract_id,
            'data': calculator_schemas.CalculatorDataSchema().dump(data)
        }


class DirectPredictDataLoader(PredictDataLoader):
    """Converts rows of a predict table into predicted calculator months for Direct."""

    service = Services.direct.value

    # index id -> column of the predict row holding its predicted revenue
    INDEX_ID_TO_YT_COLUMN_MAP = {
        'early_payment': 'predict_amt',
        'rsya': 'predict_m_amt_rsya',
        'conversion_autostrategy': 'predict_q_amt_auto_Q3',
        'key_goals': 'predict_q_amt_goal_Q3',
        'metrica': 'predict_q_amt_metrika',
        'video_cpc': 'predict_q_amt_video_cpc_Q3',
        'smart_banners': 'predict_q_amt_smart_banner',
        'rmp': 'predict_q_amt_rmp',
        'retargeting': 'predict_q_amt_retargeting',
        'k50': 'predict_q_amt_k50'
    }

    # grade id -> columns holding the grade's domain count and revenue amount
    GRADE_TO_YT_COLUMN_MAP = {
        'D': {
            'domains_count': 'predict_m_uniqd_d_150k',
            'revenue_average': 'predict_m_amt_d_150k',
        },
        'C': {
            'domains_count': 'predict_m_uniqd_c_150k_1m',
            'revenue_average': 'predict_m_amt_c_150k_1m',
        },
        'B': {
            'domains_count': 'predict_m_uniqd_b_1m_5m',
            'revenue_average': 'predict_m_amt_b_1m_5m',
        },
        'A': {
            'domains_count': 'predict_m_uniqd_a_5m',
            'revenue_average': 'predict_m_amt_a_5m'
        }
    }

    # grade id -> value added to the per-domain amount when computing the average
    GRADE_META_INFO = {
        'D': {'threshold_start': 0},
        'C': {'threshold_start': 150000},
        'B': {'threshold_start': 1000000},
        'A': {'threshold_start': 5000000},
    }

    @staticmethod
    def _make_table_row(row):
        """Build the ``{'contract_id', 'data'}`` record for one predict row.

        A grade is emitted only when its rounded domain count is non-zero.
        An index column missing from the row yields revenue 0.

        Args:
            row: a mapping with the predict columns for one contract and period.

        Returns:
            A dict with the contract id and the dumped calculator data, which
            contains a single month flagged ``predict=True``.
        """
        loader = DirectPredictDataLoader
        period_from = PredictDataLoader._get_period_from(row)
        contract_id = PredictDataLoader._get_contract_id(row)

        grades = []
        for grade_id, columns in loader.GRADE_TO_YT_COLUMN_MAP.items():
            raw_count = row[columns['domains_count']]
            rounded_count = round(raw_count)
            if not rounded_count:
                continue

            # The count is reported rounded, so the average is scaled by
            # raw/rounded to compensate for the rounding.
            scale = raw_count / rounded_count
            per_domain_amount = row[columns['revenue_average']] / raw_count
            threshold_start = loader.GRADE_META_INFO[grade_id]['threshold_start']

            grades.append(structs.CalculatorGradeData(
                grade_id=grade_id,
                domains_count=rounded_count,
                revenue_average=scale * (per_domain_amount + threshold_start)
            ))

        indexes = [
            structs.CalculatorIndexData(
                index_id=index_id,
                revenue=row.get(yt_column, 0)
            )
            for index_id, yt_column in loader.INDEX_ID_TO_YT_COLUMN_MAP.items()
        ]

        predicted_month = structs.CalculatorMonthData(
            period_from=period_from,
            predict=True,
            grades=grades,
            indexes=indexes
        )
        data = structs.CalculatorData(months=[predicted_month])

        return {
            'contract_id': contract_id,
            'data': calculator_schemas.CalculatorDataSchema().dump(data)
        }


@celery.task(bind=True)
def load_calculator_direct_prof_data_task(self):
    """Celery task that runs all Direct calculator loaders one after another.

    Order of work: every monthly detailed-report table, the early-payment
    table, every quarterly detailed-report table, and finally the predict
    table that comes last in sorted order. An exception raised by any loader
    propagates and stops the task.
    """
    yt_token = YT_CONFIG['TOKEN']
    # One config object is shared by all loaders created below.
    client_config = {
        'cluster': 'hahn',
        'token': yt_token,
        'config': {
            'pool': 'advagencyportal'
        }
    }

    def _service_defaults():
        # A fresh dict for each loader.
        return {'service': Services.direct.value}

    client = YtClient(proxy='hahn', token=yt_token)

    monthly_tables = client.search(
        '//home/balance/prod/yb-ar/rewards/reports/yandex/2021-prof_20-m-2021_m_prof/detailed_report',
        node_type=['table'],
    )
    for monthly_table in sorted(monthly_tables):
        MonthPremiumDataLoader(
            table_path=monthly_table,
            model=models.CalculatorData,
            columns_mapper={
                'contract_id': MethodExtractor('_get_contract_id'),
                'data': 'data'
            },
            default_columns=_service_defaults(),
            client_config=client_config
        ).load()

    DirectEarlyPaymentDataLoader(
        period_from=START_FIN_YEAR_2021,
        table_path='//home/balance/prod/bo/v_ar_rewards',
        model=models.CalculatorData,
        columns_mapper={
            'contract_id': 'contract_id',
            'data': 'data'
        },
        default_columns=_service_defaults(),
        client_config=client_config
    ).load()

    quarterly_tables = client.search(
        '//home/balance/prod/yb-ar/rewards/reports/yandex/2021-prof_20-q-2021_q_prof/detailed_report',
        node_type=['table'],
    )
    for quarterly_table in sorted(quarterly_tables):
        KPIDataLoader(
            table_path=quarterly_table,
            model=models.CalculatorData,
            columns_mapper={
                'contract_id': 'contract_id',
                'data': 'data'
            },
            default_columns=_service_defaults(),
            client_config=client_config,
            force_load=True,
            yql_token=YQL_CONFIG['TOKEN']
        ).load()

    predict_tables = client.search(
        '//home/search-research/ga/agency_rewards/forecast/v1/predict/amt/release',
        node_type=['table'],
    )

    predict_start = get_predict_start_from()
    # NOTE(review): raises IndexError when the search finds no tables.
    latest_predict_table = sorted(predict_tables)[-1]

    DirectPredictDataLoader(
        period_from=predict_start,
        period_to=END_FIN_YEAR_2021,
        table_path=latest_predict_table,
        model=models.CalculatorData,
        columns_mapper={
            'contract_id': MethodExtractor('_get_contract_id'),
            'data': 'data'
        },
        default_columns=_service_defaults(),
        client_config=client_config,
        force_load=True
    ).load()
