import datetime
import logging
import tempfile
import typing
import urllib.parse
import uuid
from dataclasses import dataclass
from decimal import Decimal

import aioboto3
import xlsxwriter
from aiobotocore.config import AioConfig
from babel.dates import get_month_names

from smb.common.pgswim import PoolType, SwimEngine

from crm.agency_cabinet.common.consts.report import ReportsStatuses
from crm.agency_cabinet.common.server.common.config import MdsConfig
from crm.agency_cabinet.client_bonuses.common.structs import ClientType
from crm.agency_cabinet.common.definitions import NamedTemporaryFileProtocol


# Module logger; its dotted name places it under the 'celery' logger hierarchy.
LOGGER = logging.getLogger(name='celery.generate_report')


class ReportGeneratorException(Exception):
    """Base class for errors raised while generating a bonuses report."""


class UnknownClientType(ReportGeneratorException):
    """Raised when a report's client type matches none of the supported ``ClientType`` values."""


class ReportMonthsData:
    """Aggregate client bonus activity per month for the monthly report sheets.

    ``data`` maps each month key passed to :meth:`add_client_month_data` to a
    :class:`MonthData`. Amounts reported several times for the same client and month
    (and, for gains, the same program) are summed.
    """

    @dataclass
    class MonthData:
        # client_id -> ClientMonthData of every client with activity in the month
        clients: dict
        # ids of the programs with gains in the month
        programs: set

    @dataclass
    class ClientMonthData:
        client_id: int
        client_login: str
        # program_id -> amount gained under that program in the month
        programs: dict
        # amount spent in the month
        spent_total: Decimal

    def __init__(self):
        self.data = dict()

    def add_client_month_data(
        self,
        month: datetime.datetime,
        client_id: int, client_login: str,
        program_id: typing.Optional[int],
        amount: typing.Optional[Decimal],
        spent_total: Decimal
    ):
        """Record one gained and/or spent entry of ``client_id`` for ``month``.

        :param month: month key; entries with equal keys are grouped together.
        :param client_id: client the entry belongs to.
        :param client_login: login stored when the client is first seen in the month.
        :param program_id: program the bonus was gained under, or ``None`` for a
            spent-only entry (``amount`` is ignored then).
        :param amount: gained amount added to the client's total for ``program_id``;
            ``None`` counts as nothing.
        :param spent_total: spent amount added to the client's monthly spent total;
            pass ``0`` for gained-only entries.
        """
        month_data = self._get_or_create_month_data(month)

        client_data = month_data.clients.get(client_id)
        if client_data is None:
            client_data = ReportMonthsData.ClientMonthData(client_id, client_login, dict(), Decimal(0))
            month_data.clients[client_id] = client_data

        # Compare with None explicitly: 0 is a valid program id.
        if program_id is not None:
            month_data.programs.add(program_id)
            # Accumulate rather than overwrite: a client can gain under the same program
            # several times within one month.
            client_data.programs[program_id] = self._add_amounts(client_data.programs.get(program_id, 0), amount)

        # Accumulate rather than overwrite: several spendings can fall into one month.
        client_data.spent_total = self._add_amounts(client_data.spent_total, spent_total)

    @staticmethod
    def _add_amounts(current, addition):
        """Return ``current + addition``, treating a ``None``/zero operand as absent.

        Skipping zero operands (such as the ``Decimal(0)`` placeholder) avoids mixing
        ``Decimal`` with ``float`` in case amounts arrive as floats (e.g. decoded from JSON).
        """
        if not addition:
            return current
        if not current:
            return addition
        return current + addition

    def _get_or_create_month_data(self, month: datetime.datetime) -> MonthData:
        """Return the ``MonthData`` for ``month``, creating an empty one on first use."""
        if month not in self.data:
            self.data[month] = ReportMonthsData.MonthData(dict(), set())
        return self.data[month]


class ReportGenerator:
    """Generate a client-bonuses XLSX report, upload it to MDS (S3) and register it in the DB.

    The workbook has a "Всего" sheet with one row per client matching the report's
    client-type filter, followed by one sheet per month (chronological order) breaking
    gained bonuses down by cashback program.
    """

    MDS_PREFIX = 'bonuses_reports'

    REPORT_COLUMNS = ['client id', 'client login', 'Накоплено бонусов (₽, с НДС)', 'Потрачено бонусов (₽, с НДС)']

    # xlsxwriter always writes the Office Open XML format, so the object key suffix and the
    # MIME type must describe .xlsx rather than legacy binary .xls (otherwise Excel warns
    # that the file format and extension don't match).
    REPORT_EXTENSION = 'xlsx'
    REPORT_CONTENT_TYPE = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'

    # Format of the 'date' values inside the *_details JSON objects built by _fetch_data_from_db.
    DETAILS_DATE_FORMAT = '%Y-%m-%dT%H:%M:%S%z'

    def __init__(self, db: SwimEngine, mds_cfg: MdsConfig, report_id: int):
        self._db = db
        self._report_id = report_id
        self._bucket = mds_cfg.bucket
        self._mds_cfg = mds_cfg

    async def generate(self):
        """Build the report, upload it to MDS and mark the report as ready.

        On any error the report status is set back to ``requested`` and the original
        exception is re-raised.
        """
        try:
            await self._update_status(ReportsStatuses.in_progress.value)
            report_file, report_name = await self._make_report()
            with report_file:
                mds_filename = await self._upload_to_mds(report_file, report_name)
                await self._update_status_and_save_mds_filename(mds_filename, report_name)
        except Exception as ex:
            LOGGER.exception('Error during report generation: %s', ex)
            try:
                await self._update_status(ReportsStatuses.requested.value)
            except Exception:
                # A failure here must not replace the original error being propagated.
                LOGGER.exception('Failed to reset status of report %s', self._report_id)
            raise

    async def _update_status(self, status: str):
        """Set ``status`` on this report's meta-info row and bump ``updated_at``."""
        async with self._db.acquire(PoolType.master) as con:
            await con.execute(
                """
                UPDATE report_meta_info
                SET status = $1,
                    updated_at = NOW()
                WHERE id = $2
                """,
                status,
                self._report_id,
            )

    async def _make_report(self) -> typing.Tuple[NamedTemporaryFileProtocol, str]:
        """Write the report workbook into a new temporary file.

        :return: the temporary file rewound to its start (the caller must close it; closing
            a default ``NamedTemporaryFile`` deletes it) and the report name from
            ``report_meta_info``.
        """
        client_type, agency_id, period_from, period_to, meta_report_name = await self._fetch_parameters_from_db()
        cashback_programs = await self._fetch_programs_from_db()

        report_file = tempfile.NamedTemporaryFile()
        try:
            with xlsxwriter.Workbook(report_file.name) as workbook:
                formats = self._create_formats(workbook)
                months_data = await self._write_total_sheet(
                    workbook, formats, client_type, agency_id, period_from, period_to
                )
                # Dict order follows DB row order, not time: sort so month sheets are chronological.
                for month in sorted(months_data.data):
                    await self._write_month_sheet(
                        workbook, formats, month, months_data.data[month], cashback_programs
                    )
        except BaseException:
            # Don't leak the temporary file if building the workbook fails.
            report_file.close()
            raise

        report_file.seek(0)
        return report_file, meta_report_name

    @staticmethod
    def _create_formats(workbook) -> typing.Dict[str, typing.Any]:
        """Register the cell formats shared by all sheets of ``workbook``."""
        return {
            'general': workbook.add_format(
                {'font': 'Arial', 'font_size': 11}),
            'headers': workbook.add_format(
                {'bg_color': '#fdf3d0', 'text_wrap': True, 'font': 'Arial', 'font_size': 11}),
            'currency': workbook.add_format(
                {'font': 'Arial', 'font_size': 11, 'num_format': '#,##0.00'}),
        }

    @classmethod
    def _parse_month(cls, value: str) -> datetime.datetime:
        """Parse a month key from a *_details JSON object."""
        return datetime.datetime.strptime(value, cls.DETAILS_DATE_FORMAT)

    async def _write_total_sheet(self, workbook, formats: dict, client_type: str, agency_id: int,
                                 period_from: datetime.datetime,
                                 period_to: datetime.datetime) -> ReportMonthsData:
        """Write the "Всего" sheet (one row per client) and collect per-month details.

        :return: gained/spent details of every client, grouped by month.
        """
        worksheet = workbook.add_worksheet('Всего')
        self._format_sheet(worksheet, 0, formats['general'], formats['currency'])
        worksheet.write_row('A1', self.REPORT_COLUMNS, formats['headers'])

        months_data = ReportMonthsData()
        row_index = 2
        async for row in self._fetch_data_from_db(client_type, agency_id, period_from, period_to):
            client_id, login, gained_total, spent_total, gained_details, spent_details = row

            worksheet.write_row(f'A{row_index}', [client_id, login, gained_total, spent_total])
            row_index += 1

            for detail in gained_details:
                months_data.add_client_month_data(
                    month=self._parse_month(detail['date']),
                    client_id=client_id,
                    client_login=login,
                    program_id=detail['program_id'],
                    amount=detail['amount'],
                    spent_total=Decimal(0)
                )

            for detail in spent_details:
                months_data.add_client_month_data(
                    month=self._parse_month(detail['date']),
                    client_id=client_id,
                    client_login=login,
                    program_id=None,
                    amount=None,
                    spent_total=detail['amount']
                )

        return months_data

    async def _write_month_sheet(self, workbook, formats: dict, month: datetime.datetime,
                                 month_data: ReportMonthsData.MonthData, cashback_programs: dict) -> None:
        """Add the sheet for ``month`` with one row per client that had activity in it."""
        worksheet = workbook.add_worksheet(self._get_month_sheet_title(month))
        self._format_sheet(worksheet, len(month_data.programs), formats['general'], formats['currency'])
        worksheet.write_row(
            'A1',
            await self._make_month_columns_title(month_data, cashback_programs),
            formats['headers']
        )
        for row_index, client_month_data in enumerate(month_data.clients.values(), start=2):
            worksheet.write_row(f'A{row_index}', self._make_month_row(client_month_data, month_data.programs))

    def _format_sheet(self, sheet, programs_num, general_formatter, currency_formatter) -> None:
        """Set column widths/formats and freeze the header row.

        Columns A-B (id, login) use ``general_formatter``; the ``programs_num`` program
        columns plus the two total columns that follow use ``currency_formatter``.
        """
        sheet.set_column(first_col=0, last_col=0, width=10, cell_format=general_formatter)
        sheet.set_column(first_col=1, last_col=1, width=25, cell_format=general_formatter)
        # 2 leading columns + programs_num + 2 totals -> last index is 3 + programs_num.
        sheet.set_column(first_col=2, last_col=3 + programs_num, width=15, cell_format=currency_formatter)
        sheet.freeze_panes(1, 0)

    def _get_month_sheet_title(self, date: datetime.datetime) -> str:
        """Return a sheet title made of the capitalized Russian month name and the year."""
        month_name = get_month_names(context='stand-alone', locale='ru')[date.month]
        return f'{month_name.capitalize()} {date.year}'

    async def _make_month_columns_title(self, month_data: ReportMonthsData.MonthData, programs: dict) -> list:
        """Return the header row of a month sheet.

        Program columns are in ascending program-id order, which must match
        ``_make_month_row``. ``programs`` maps program id to display name; an id missing
        from it is shown as the id itself instead of failing.
        """
        program_titles = [programs.get(p, str(p)) for p in sorted(month_data.programs)]
        return self.REPORT_COLUMNS[:2] + program_titles + self.REPORT_COLUMNS[2:]

    def _make_month_row(self, client_month_data: ReportMonthsData.ClientMonthData, programs: set) -> list:
        """Return a client's row for a month sheet.

        Per-program amounts (0 when the client had none) in ascending program-id order —
        the order used by ``_make_month_columns_title`` — then their sum and the amount spent.
        """
        amounts = [client_month_data.programs.get(p, 0) for p in sorted(programs)]
        return [
            client_month_data.client_id,
            client_month_data.client_login,
            *amounts,
            sum(amounts),
            client_month_data.spent_total,
        ]

    async def _fetch_parameters_from_db(self) -> typing.Tuple[str, int, datetime.datetime, datetime.datetime, str]:
        """Return ``(client_type, agency_id, period_from, period_to, name)`` of this report.

        :raises ReportGeneratorException: if no ``report_meta_info`` row has this id.
        """
        async with self._db.acquire(PoolType.replica) as con:
            meta_info = await con.fetchrow(
                """
                SELECT * FROM report_meta_info
                WHERE id = $1
                """,
                self._report_id,
            )

        if meta_info is None:
            raise ReportGeneratorException(f'Report meta info not found: id={self._report_id}')

        return (
            meta_info['client_type'],
            meta_info['agency_id'],
            meta_info['period_from'],
            meta_info['period_to'],
            meta_info['name']
        )

    async def _fetch_data_from_db(self, client_type: str, agency_id: int, period_from: datetime.datetime,
                                  period_to: datetime.datetime) -> typing.AsyncIterable[typing.Tuple[int, str, Decimal, Decimal, typing.List[typing.Dict], typing.List[typing.Dict]]]:
        """Yield one tuple per client of ``agency_id`` matching ``client_type``, ordered by id.

        Each tuple is ``(id, login, gained_total, spent_total, gained_details, spent_details)``
        over ``[period_from, period_to]``; missing totals become 0 and missing details ``[]``.

        :raises UnknownClientType: if ``client_type`` is not a supported ``ClientType`` value.
        """
        if client_type == ClientType.ALL.value:
            activity_cond = ''
        elif client_type == ClientType.ACTIVE.value:
            activity_cond = 'AND clients.is_active = true'
        elif client_type == ClientType.EXCLUDED.value:
            activity_cond = 'AND clients.is_active = false'
        else:
            raise UnknownClientType(client_type)

        # Only the static fragment above is formatted into the SQL; all values are bound.
        # The period bounds are bound separately for each CTE so Postgres infers each
        # parameter's type from its own column.
        # LEFT JOINs: rows without a client would be dropped by the WHERE on clients anyway.
        sql = f"""
            WITH gained AS (
                SELECT client_id,
                    SUM(amount) AS total,
                    array_agg(jsonb_build_object(
                        'date', date_trunc('month', gained_at),
                        'program_id', program_id,
                        'amount', amount)
                    ) AS gained_details
                FROM gained_client_bonuses
                WHERE gained_at BETWEEN $2 AND $3
                GROUP BY client_id
            ),

            spent AS (
                SELECT client_id,
                    SUM(amount) AS total,
                    array_agg(jsonb_build_object(
                        'date', date_trunc('month', spent_at),
                        'amount', amount)
                    ) AS spent_details
                FROM spent_client_bonuses
                WHERE spent_at BETWEEN $4 AND $5
                GROUP BY client_id
            )
            SELECT
                clients.id,
                login,
                gained.total AS gained_total,
                spent.total AS spent_total,
                gained_details,
                spent_details
            FROM clients
            LEFT JOIN gained ON clients.id = gained.client_id
            LEFT JOIN spent ON clients.id = spent.client_id
            WHERE clients.agency_id = $1 {activity_cond}
            ORDER BY clients.id
            """

        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(sql, agency_id, period_from, period_to, period_from, period_to)

        # Rows are yielded after the connection is released, so it isn't held while the
        # caller writes the workbook.
        for row in rows:
            yield (
                row['id'],
                row['login'],
                row['gained_total'] or 0,
                row['spent_total'] or 0,
                row['gained_details'] or [],
                row['spent_details'] or []
            )

    async def _fetch_programs_from_db(self) -> dict:
        """Return ``{program id: name_ru}`` for every row of ``cashback_programs``."""
        async with self._db.acquire(PoolType.replica) as con:
            programs = await con.fetch('SELECT id, name_ru FROM cashback_programs')
        return {program['id']: program['name_ru'] for program in programs}

    async def _upload_to_mds(self, report_file: NamedTemporaryFileProtocol, report_filename: str) -> str:
        """Upload ``report_file`` to the configured MDS bucket and return the object key.

        The key is ``<MDS_PREFIX>/<uuid4>-<report_filename>.xlsx``; the random UUID keeps
        keys distinct for reports with the same name.
        """
        download_name = f'{report_filename}.{self.REPORT_EXTENSION}'
        mds_filename = f'{self.MDS_PREFIX}/{uuid.uuid4()}-{download_name}'

        # RFC 6266 / RFC 5987 form: report names may contain spaces, quotes, ';' or non-ASCII
        # (e.g. Cyrillic) characters that are invalid in a bare `filename=` parameter.
        encoded_name = urllib.parse.quote(download_name, safe='')
        content_disposition = f"attachment; filename*=UTF-8''{encoded_name}"

        # TODO: don't create session each time
        # TODO: add aioboto3.session to BotoSessionBoundedMixin and pass it through task?
        boto3_session = aioboto3.Session(
            aws_access_key_id=self._mds_cfg.access_key_id,
            aws_secret_access_key=self._mds_cfg.secret_access_key
        )

        s3_config = AioConfig(s3={'addressing_style': 'virtual'})
        async with boto3_session.resource('s3', endpoint_url=self._mds_cfg.endpoint_url, config=s3_config) as s3:
            obj = await s3.Object(self._bucket, mds_filename)
            await obj.put(
                Body=report_file.file,
                ContentType=self.REPORT_CONTENT_TYPE,
                ContentDisposition=content_disposition,
            )

        return mds_filename

    async def _update_status_and_save_mds_filename(self, mds_filename: str, display_name: str):
        """Register the uploaded file in ``s3_mds_file`` and mark the report ready with it."""
        async with self._db.acquire(PoolType.master) as con:
            # Both writes commit together, so a failed status update can't leave an orphan file row.
            async with con.transaction():
                row = await con.fetchrow(
                    """
                    INSERT INTO s3_mds_file (bucket, name, display_name)
                    VALUES ($1, $2, $3)
                    RETURNING id
                    """,
                    self._bucket,
                    mds_filename,
                    display_name,
                )

                await con.execute(
                    """
                    UPDATE report_meta_info
                    SET status = $1, file_id = $2, updated_at = NOW()
                    WHERE id = $3
                    """,
                    ReportsStatuses.ready.value,
                    row['id'],
                    self._report_id,
                )
