import datetime
import logging
import typing
import uuid
import codecs
from dateutil import relativedelta
from tempfile import SpooledTemporaryFile
import yt.wrapper as yt
from yt.wrapper.errors import YtResolveError

from crm.agency_cabinet.rewards.server.src.celery.base import celery_app as celery, RewardsContextBoto3Task
from crm.agency_cabinet.rewards.server.src.db.models import ReportMetaInfo, Contract, S3MdsFile
from crm.agency_cabinet.common.consts.common import START_FIN_YEAR_2021, END_FIN_YEAR_2021, END_FIN_YEAR_2022, START_FIN_YEAR_2022
from crm.agency_cabinet.common.consts.report import ReportsStatuses, ReportsTypes
from crm.agency_cabinet.common.consts.service import name_to_service_id
from crm.agency_cabinet.common.consts import ContractType, ReportContractType
from crm.agency_cabinet.rewards.server.config.clients import YT_CONFIG
from crm.agency_cabinet.rewards.server.config.mds import REPORTS_MDS_SETTINGS
from crm.agency_cabinet.common.celery.base import async_to_sync, locked_task
from crm.agency_cabinet.common.chyt import make_request

LOGGER = logging.getLogger('celery.generate_report')

LOCK_PATH = 'generate_rewards_report/task'
LOCK_KEY = 'report_id'

START_FIN_YEAR_2021_TZ = START_FIN_YEAR_2021.replace(tzinfo=datetime.timezone.utc)
START_FIN_YEAR_2022_TZ = START_FIN_YEAR_2022.replace(tzinfo=datetime.timezone.utc)
END_FIN_YEAR_2021_TZ = END_FIN_YEAR_2021.replace(tzinfo=datetime.timezone.utc)
END_FIN_YEAR_2022_TZ = END_FIN_YEAR_2022.replace(tzinfo=datetime.timezone.utc)

AGENCIES_WITH_SPECIAL_DEAL = [42647729, 70388151, 6825417, 50705, 372455]


class BaseCreateReportException(Exception):
    default_message = 'UNKNOWN'

    def __init__(self, msg=None):
        self.message = msg or self.default_message


def check_params(period_from, period_to, report_type) -> typing.List[str]:
    errors = []
    if not START_FIN_YEAR_2021_TZ <= period_from < END_FIN_YEAR_2022_TZ or not START_FIN_YEAR_2021_TZ <= period_to < END_FIN_YEAR_2022_TZ:
        errors.append('UNSUPPORTED_PERIOD')

    if report_type not in (
        ReportsTypes.month.value, ReportsTypes.quarter_direct.value, ReportsTypes.quarter_video.value
    ):
        errors.append('UNSUPPORTED_TYPE')

    if report_type == ReportsTypes.month.value and period_from >= START_FIN_YEAR_2022_TZ:
        search_path = f'//home/balance/prod/yb-ar/rewards/reports/yandex/2022-base-m/detailed_report/{period_from.strftime("%Y%m")}/@schema'
        yt_client = yt.YtClient(proxy='hahn', token=YT_CONFIG['TOKEN'])
        try:
            yt_client.get(search_path)
        except YtResolveError:
            errors.append('UNSUPPORTED_PERIOD')

    return errors


class ReportBuilder:
    common_fields = '''
            '{0}' as "Начало периода",
            '{1}' as "Конец периода",
            {2} as "ID агентства",
            arrayElement(groupArray(DISTINCT agency_name), 1) as "Агентство",
            arrayElement(groupArray(DISTINCT contract_eids), 1) as "Договоры",
            r.client_id as "ID клиента",
            arrayElement(groupArray(DISTINCT CAST(client_logins as Nullable(String))), 1) as "Логин клиента",
            arrayElement(groupArray(DISTINCT client_name), 1) as "Название клиента",
            r.cid as "Номер РК в Директе",
            arrayElement(groupArray(DISTINCT cid_name), 1) as "Название РК",
            ROUND(SUM(COALESCE(show_amount, 0)), 2) as "Оборот РК всего, руб без НДС"'''

    def __init__(
        self, period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client
    ):
        self.period_to = period_to
        self.period_from = period_from
        self.report_type = report_type
        self.agency_id = agency_id
        self.contract_id = contract_id
        self.report_contract_type = report_contract_type
        self.clients_ids = clients_ids
        self.service = service
        self.yt_client = yt_client
        self.base_path = None

    def build_client_ids(self):
        client_ids_present = ''
        if self.clients_ids:
            s_tuple = ', '.join(str(i) for i in self.clients_ids)
            client_ids_present = f'AND r.client_id in ({s_tuple})'

        return client_ids_present

    def build_services(self):
        ids = name_to_service_id(self.service)  # TODO: check contracts services
        ids_present = ''
        if ids is not None:
            s_tuple = ', '.join((str(i) for i in ids)) if len(ids) > 1 else f'{ids[0]}, '
            ids_present = f"\nAND r.discount_type in ({s_tuple})"
        return ids_present

    def build_query(self):
        pass


class MonthReportBuilder(ReportBuilder):
    MONTH_REPORT_COLUMNS = {'orig_contract_eid', 'contract_eid', 'agency_id', 'client_id', 'invoice_eid', 'amt',
                            'invoice_type', 'domain', 'rsya_amt', 'amt_rsya', 'rsya_reward', 'reward_rsya', 'yyyymm',
                            'ar_domain', 'rsya_search_amt'}

    SPECIAL_MONTH_REPORT_COLUMNS = {
        'reward': 'TODO'  # sum???
    }

    MONTH_COLUMNS_MAPPER = {
        'orig_contract_eid': 'Номер договора',
        'contract_eid': 'Номер договора',
        'agency_id': 'Agency ID (номер, идентифицирующий Агентство)',
        'client_id': 'Сlient ID (уникальный номер Клиента Агентства)',
        'invoice_eid': 'Номер счета',
        'invoice_type': 'Тип счета',
        'domain': 'Домен',
        'rsya_amt': 'Стоимость Услуг РСЯ по Приложению № 1 к Договору, за исключением Услуг Медийной рекламы в '
                    'Директе в рамках Client ID по Домену, руб. без учета НДС',
        'rsya_reward': 'Сумма ежемесячной премии по РСЯ по Домену, руб.',
        'reward_rsya': 'Сумма ежемесячной премии по РСЯ по Домену, руб.',
        'amt_rsya': 'Стоимость Услуг РСЯ по Приложению № 1 к Договору, за исключением Услуг Медийной рекламы в '
                    'Директе в рамках Client ID по Домену, руб. без учета НДС',
        'amt': 'Стоимость Услуг по Приложению № 1 к Договору, за исключением Услуг '
               'Медийной рекламы в Директе в рамках Client ID по Домену, руб. без учета НДС',
        'reward': 'Сумма ежемесячной премии по Домену, включая премию по РСЯ, руб.',
        'yyyymm': 'Месяц',
        'rsya_search_amt': 'Стоимость Услуг РСЯ по Приложению № 1 к Договору, за исключением Услуг Медийной рекламы в '
                    'Директе в рамках Client ID по Домену, руб. без учета НДС',
        'ar_domain': 'Домен',
    }

    SUFFIX_CONTRACT_TYPE_MAP_2021 = {
        'aggregator': '2021-base-m-2021_m_base',
        'base': '2021-base-m-2021_m_base',
        'crisis': '2021-prof_20-m-2021_m_prof',
        'prof': '2021-prof_20-m-2021_m_prof',
        'special': '2021-prof_20-m-2021_m_prof',
    }

    SUFFIX_CONTRACT_TYPE_MAP_2022 = {
        'aggregator': '2022-aggregator-m-2022_m_aggregator-m',
        'base': '2022-03-base-m',
        'crisis': '2022-03-premium-m',  # prof
        'prof': '2022-premium-m-2022_m_premium-m',
        'special': '2022-dan-m-m',
    }

    def __init__(self, period_from, period_to, report_type, report_contract_type, clients_ids, service, yt_client, agency_id, contract_id):
        super().__init__(period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client)
        self.base_path = '//home/balance/prod/yb-ar/rewards/reports/yandex/'
        self.suffix = self.SUFFIX_CONTRACT_TYPE_MAP_2021[report_contract_type] if period_to <= END_FIN_YEAR_2021_TZ \
            else self.SUFFIX_CONTRACT_TYPE_MAP_2022[report_contract_type]
        self.search_path = f'{self.base_path}{self.suffix}/detailed_report/{self.period_from.strftime("%Y%m")}'

    def build_query(self):
        available_columns = {}

        try:
            available_columns = {e.get('name') for e in self.yt_client.get(self.search_path + '/@schema')}
        except Exception as e:
            LOGGER.error('Something went wrong when getting available columns for month report: %s', e)

        if self.period_from.year == self.period_to.year and self.period_from.month == self.period_to.month:
            self.period_to += relativedelta.relativedelta(months=1)

        client_ids_present = self.build_client_ids()
        services_present = self.build_services()

        select_columns = available_columns & self.MONTH_REPORT_COLUMNS
        select_query = ', '.join(
            (f'{col} as \"{self.MONTH_COLUMNS_MAPPER.get(col, "FIX")}\"' for col in select_columns))
        if 'reward' in available_columns and self.report_contract_type != 'base':
            sum_reward_query = f'reward as \"{self.MONTH_COLUMNS_MAPPER["reward"]}\"'
            full_select_query = f'{select_query}, {sum_reward_query}'
        else:
            full_select_query = f'{select_query}'

        contract_id_col_name = 'orig_contract_id' if 'orig_contract_id' in available_columns else 'contract_id'
        # ТК в таблицах есть только даты с точностью до месяца надо брать все до след. месяца
        query = f"""
        SELECT {full_select_query} FROM `{self.search_path}` as r
        WHERE yyyymm >= '{self.period_from.strftime('%Y%m')}' AND yyyymm < '{self.period_to.strftime('%Y%m')}'
        AND {contract_id_col_name} = {self.contract_id} {client_ids_present} {services_present}
        """
        return query


class QuarterDirectReportBuilder(ReportBuilder):
    prof_aggr_quarter_fields = None
    auto_strategy = None
    conversion_strategy = None

    def __init__(self, period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client):
        super().__init__(period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client)
        self.quarter_query_base = f'''
        SELECT
            {self.common_fields},
            ROUND(SUM(COALESCE(rsya_show_amount, 0)), 2) as "Оборот домена в РСЯ",
            intDivOrZero((SUM(COALESCE(show_amount, 0)) + "Оборот домена в РСЯ"),  "Оборот домена в РСЯ") as "Доля РСЯ в обороте по домену",
            arrayElement(groupArray(DISTINCT counters_rate), 1) as "Номера счетчиков Метрики: доля покрытия",
            arrayElement(groupArray(DISTINCT goals_rate), 1) as "Номера ключевых целей: доля целевых визитов",
            arrayElement(groupArray(DISTINCT round(click_visit_rate, -6)), 1) as "Общая доля кликов (Метрика)",
            ROUND(SUM(COALESCE(metrika_and_goals_show_amount, 0)), 2) as "Оборот РК с целями Метрики, прошедшими проверку",
            ROUND(SUM(COALESCE(metrika_show_amount, 0)), 2) as "Оборот РК со счетчиками Метрики, прошедшими проверку",
            ROUND(SUM(COALESCE(autotarget_search_show_amount, 0)), 2) as "Автотаргетинг на поиске, оборот"
            {str(self.prof_aggr_quarter_fields or "")}
            {str(self.auto_strategy or "")}
            {str(self.conversion_strategy or "")}''' + '''
        FROM `{3}` as r
            LEFT JOIN (
        SELECT
            client_id,
            k50_effective_actions_rate,
            SUM(COALESCE(k50_efficient_dayse_show_amount, 0)) as k50_efficient_amount
        FROM `{3}`
        WHERE `date` BETWEEN '{0}' AND '{1}'
            AND agency_id = {2}
        GROUP BY client_id, k50_effective_actions_rate
                    ) as k on k.client_id = r.client_id
                    WHERE `date` BETWEEN '{0}' AND '{1}'
        AND agency_id = {2} and cpc_direct_flag {4}
                    GROUP BY r.client_id, r.cid
                    ORDER BY client_id, cid;
        '''  # noqa

    def build_query(self):
        client_ids_present = self.build_client_ids()
        query = self.quarter_query_base.format(
            self.period_from.strftime("%Y-%m-%d"),
            self.period_to.strftime("%Y-%m-%d"),
            self.agency_id,
            self.base_path,
            client_ids_present
        )
        return query


class QuarterVideoReportBuilder(ReportBuilder):
    videostream_field = None

    def __init__(self, period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client):
        super().__init__(period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client)
        self.quarter_videostream_query_base = f'''
        SELECT
            {self.common_fields}
            {str(self.videostream_field or "")}''' + '''
        FROM `{3}` as r
        WHERE `date` BETWEEN '{0}' AND '{1}'
            AND agency_id = {2} and media_flag {4}
        GROUP BY r.client_id, r.cid
        ORDER BY client_id, cid;
    '''  # noqa

    def build_query(self):
        client_ids_present = self.build_client_ids()

        query = self.quarter_videostream_query_base.format(
            self.period_from.strftime("%Y-%m-%d"),
            self.period_to.strftime("%Y-%m-%d"),
            self.agency_id,
            self.base_path,
            client_ids_present
        )
        return query


class ProfReportBuilder(ReportBuilder):
    def __init__(self, period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client):
        super().__init__(period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client)
        self.base_path = '//home/search-research/yateika/rewards/ga-247/report_all'
        self.prof_aggr_quarter_fields = ''',
         arrayElement(groupArray(DISTINCT ar_domain), 1) as "Домен:",
         ROUND(SUM(COALESCE(show_amount, 0)), 2) as "Оборот домена, всего:",
         ROUND(SUM(COALESCE(cpc_video_show_amount, 0)), 2) as "Видеообъявления и видеодополнения:",
         ROUND(SUM(COALESCE(mobile_content_show_amount, 0)), 2) as "РМП: оборот:",
         ROUND(SUM(COALESCE(performance_show_amount, 0)), 2) as "Смарт-баннеры: оборот:",
         ROUND(SUM(COALESCE(product_gallery_show_amount, 0)), 2) as "Товарная галерея: оборот:",
         SUM(COALESCE(k50_effective_actions_rate, 0)) as "К50: доля эффективных действий (по агентству)",
         ROUND(SUM(COALESCE(k50_efficient_dayse_show_amount, 0)), 2) as "К50: оборот (по логину):",
         ROUND(SUM(COALESCE(retarget_amount, 0)), 2) as "Ретаргетинг, оборот"'''
        self.auto_strategy = ''', ROUND(SUM(COALESCE(autobudget_avg_show_amount, 0)), 2) as "Автостратегии: оборот:"'''
        self.videostream_field = ''', ROUND(SUM(COALESCE(outstream_show_amount, 0)), 2) AS "Видеореклама OutStream: оборот"'''


class AggregatorReportBuilder(ReportBuilder):
    def __init__(self, period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client):
        super().__init__(period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client)
        self.base_path = '//home/search-research/yateika/rewards/ga-247/report_aggregate'
        self.videostream_field = None
        self.conversion_strategy = ''', ROUND(SUM(COALESCE(autobudget_avg_show_amount, 0)), 2) as "Конверсионные стратегии: оборот"'''


class BaseReportBuilder(ReportBuilder):
    def __init__(self, period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client):
        super().__init__(period_from, period_to, report_type, agency_id, contract_id, report_contract_type, clients_ids, service, yt_client)
        self.base_path = '//home/search-research/yateika/rewards/ga-247/report_base'
        self.videostream_field = ''', ROUND(SUM(COALESCE(outstream_show_amount, 0)), 2) AS "Видеореклама OutStream: оборот"'''
        self.conversion_strategy = ''', ROUND(SUM(COALESCE(autobudget_avg_show_amount, 0)), 2) as "Конверсионные стратегии: оборот"'''


class ProfDirectReportBuilder(QuarterDirectReportBuilder, ProfReportBuilder):
    ...


class ProfVideoReportBuilder(QuarterVideoReportBuilder, ProfReportBuilder):
    ...


class AggregatorDirectReportBuilder(QuarterDirectReportBuilder, AggregatorReportBuilder):
    ...


class AggregatorVideoReportBuilder(QuarterVideoReportBuilder, AggregatorReportBuilder):
    ...


class BaseDirectReportBuilder(QuarterDirectReportBuilder, BaseReportBuilder):
    ...


class BaseVideoReportBuilder(QuarterVideoReportBuilder, BaseReportBuilder):
    ...


@celery.task(bind=True, base=RewardsContextBoto3Task, time_limit=3 * 60 * 60 + 1)
@locked_task(lock_path=LOCK_PATH, key=LOCK_KEY, block_timeout=1, block=True, timeout=3 * 60 * 60)
def generate_report_task(self: RewardsContextBoto3Task, report_id):
    # TODO: retry, add lock by report_id, reset requested on error
    @async_to_sync
    async def _get_report(report_id) -> ReportMetaInfo:
        return await ReportMetaInfo.get(report_id)

    @async_to_sync
    async def _get_contract(report: ReportMetaInfo) -> ContractType:
        contract = await Contract.get(report.contract_id)
        return contract

    def _get_report_contract_type(contract: Contract, report_type: ReportsTypes):
        t = contract.type
        if report_type == ReportsTypes.month.value:
            if contract.agency_id in AGENCIES_WITH_SPECIAL_DEAL:
                t = ReportContractType.special.value
            elif contract.is_crisis is True and contract.type == ContractType.prof.value:
                t = ReportContractType.crisis.value
        return t

    @async_to_sync
    async def _update_report_status(status):
        await report.update(status=status).apply()

    @async_to_sync
    async def _create_mds_file_model_and_update_report_status() -> S3MdsFile:
        file: S3MdsFile = await S3MdsFile.create(bucket=bucket, name=filename, display_name=report.name)
        await report.update(status=ReportsStatuses.ready.value, file_id=file.id).apply()
        return file

    try:
        LOGGER.info('Start processing report (id: %s)', report_id)
        report = _get_report(report_id)

        if report is None:
            LOGGER.warning('Can\'t find report (id: %s)', report_id)
            return
        if report.status != ReportsStatuses.requested.value:
            LOGGER.warning('Request to generate task for incorrect report status: %s (id: %s)', report.status, report.id)
            return
        client = yt.YtClient(proxy='hahn', token=YT_CONFIG['TOKEN'])

        agency_contract = _get_contract(report)
        agency_id = agency_contract.agency_id
        report_contract_type = _get_report_contract_type(agency_contract, report.type)

        params = {
            'period_from': report.period_from,
            'period_to': report.period_to,
            'agency_id': agency_id,
            'contract_id': report.contract_id,
            'report_type': report.type,
            'clients_ids': report.clients_ids,
            'service': report.service,
            'yt_client': client,
            'report_contract_type': report_contract_type,
        }

        builders_map = {
            ReportsTypes.month.value: MonthReportBuilder,
            ReportsTypes.quarter_direct.value: {
                ContractType.aggregator.value: AggregatorDirectReportBuilder,
                ContractType.prof.value: ProfDirectReportBuilder,
                ContractType.base.value: BaseDirectReportBuilder,
            },
            ReportsTypes.quarter_video.value: {
                ContractType.aggregator.value: AggregatorVideoReportBuilder,
                ContractType.prof.value: ProfVideoReportBuilder,
                ContractType.base.value: BaseVideoReportBuilder,
            }
        }

        report_type_to_builder = builders_map.get(report.type)

        if report_type_to_builder is None:
            raise NotImplementedError

        if isinstance(report_type_to_builder, dict):
            contract_type_to_builder = report_type_to_builder.get(report_contract_type)

            if contract_type_to_builder is None:
                raise NotImplementedError

            builder = contract_type_to_builder(**params)

        else:
            builder = report_type_to_builder(**params)

        errors = check_params(
            report.period_from,
            report.period_from,
            report.type
        )

        if errors:
            LOGGER.warning('Bad parameters for report: %s (id: %s)', ';'.join(errors), report.id)
            return

        _update_report_status(ReportsStatuses.in_progress.value)

        query = builder.build_query()

        LOGGER.debug('Build query for report: %s', query)

        # TODO: retry requests
        response = make_request(
            query,
            alias="*chyt_advagencyportal",
            client=client,
            format='CSVWithNames',
            settings={'format_csv_delimiter': ';'}
        )

        random_uuid = str(uuid.uuid4())
        bucket = REPORTS_MDS_SETTINGS['bucket']
        prefix = REPORTS_MDS_SETTINGS['prefix']
        filename = f'{prefix}/{random_uuid}-{report.name}.csv'
        with SpooledTemporaryFile(max_size=65536 * 4) as f:
            iterator = response.iter_content(chunk_size=65536)
            first_line = next(iterator)
            if not first_line.startswith(codecs.BOM_UTF8):
                first_line = codecs.BOM_UTF8 + first_line
            f.write(first_line)
            for line in iterator:
                f.write(line)
            f.seek(0)
            self.s3_resource.meta.client.upload_fileobj(
                f,
                bucket,
                filename,
                ExtraArgs={
                    'ContentType': 'text/csv; charset=utf-8',
                    'ContentDisposition': f'attachment; filename={report.name}.csv'
                }
            )

        _create_mds_file_model_and_update_report_status()
    except Exception as ex:
        LOGGER.exception('Error during report generation: %s; retries %s/%s', ex, self.request.retries, self.max_retries)
        if self.request.retries == self.max_retries:
            _update_report_status(ReportsStatuses.requested.value)
        else:
            self.retry(exc=ex)


@celery.task(bind=True)
def check_requested_reports_task(self):
    @async_to_sync
    async def _get_reports() -> typing.List[ReportMetaInfo]:
        return await ReportMetaInfo.query.where(ReportMetaInfo.status == ReportsStatuses.requested.value).gino.all()

    reports = _get_reports()
    for report in reports:
        generate_report_task.delay(report_id=report.id)
