import aioboto3
import datetime
import logging
import tempfile
import typing
import xlsxwriter
import uuid
from aiobotocore.config import AioConfig

from crm.agency_cabinet.ord.server.src.db import db, models
from crm.agency_cabinet.ord.server.src.db.queries import build_get_client_rows_query
from crm.agency_cabinet.common.server.common.config import MdsConfig
from crm.agency_cabinet.common.definitions import NamedTemporaryFileProtocol
from crm.agency_cabinet.common.server.common.structs import TaskStatuses

# Module-level logger shared by the report exporters; the logger name is a
# fixed string rather than being derived from ``__name__``.
LOGGER = logging.getLogger('celery.export_reports.base')


class BaseReportExporter:
    """Base class that renders a report into an .xlsx file and uploads it to MDS (S3).

    Subclasses are expected to:

    * define ``HEADERS`` - a mapping ``common header -> iterable of header
      groups``, where each group is an iterable of column titles;
    * implement ``_get_report_data`` as an async iterator yielding one
      sequence of cell values per spreadsheet row.

    The progress of the export is reflected in ``models.ReportExportInfo``
    (row with ``id == task_id``).
    """

    # Two-level header layout; see the class docstring for the expected shape.
    HEADERS = {}

    def __init__(self, mds_cfg: MdsConfig, report_mds_config, task_id: int, report_id: int):
        """Store the export parameters.

        :param mds_cfg: S3/MDS connection settings; read by key
            (``access_key_id``, ``secret_access_key``, ``endpoint_url``).
        :param report_mds_config: mapping with the target ``bucket`` and key ``prefix``.
        :param task_id: id of the ``ReportExportInfo`` row tracking this export.
        :param report_id: id of the report whose rows are exported.
        """
        self._task_id = task_id
        self._report_id = report_id
        self._bucket = report_mds_config['bucket']
        self._mds_cfg = mds_cfg
        self._prefix = report_mds_config['prefix']

    async def generate(self):
        """Build the report, upload it and keep the task status up to date.

        On any failure the task is marked as ``error`` and the original
        exception is re-raised. The temporary report file is always closed
        (and therefore removed) before returning.
        """
        report_file = None
        try:
            await self._update_status(TaskStatuses.in_progress.value)
            report_file = await self._make_report()
            await self._upload_to_mds(report_file)
        except Exception as ex:
            LOGGER.exception('Error during report export: %s', ex)
            await self._update_status(TaskStatuses.error.value)
            # Bare ``raise`` re-raises the active exception with its traceback intact.
            raise
        finally:
            # Do not rely on garbage collection to remove the temporary file.
            if report_file is not None:
                report_file.close()

    async def _make_report(self) -> NamedTemporaryFileProtocol:
        """Render headers and data rows into a temporary .xlsx file.

        :return: the open temporary file, positioned at the beginning; the
            caller owns it and is responsible for closing it.
        """
        report_file = tempfile.NamedTemporaryFile(suffix=".xlsx")

        try:
            # The workbook is written to the temporary file's path when the
            # ``with`` block exits.
            with xlsxwriter.Workbook(report_file.name) as workbook:
                worksheet = workbook.add_worksheet()
                COMMON_HEADERS_ROW = 0
                HEADERS_ROW = 1
                column = 0
                for common_header, group_headers in self.HEADERS.items():
                    # The common header is placed above the first column of its group.
                    worksheet.write(COMMON_HEADERS_ROW, column, common_header)

                    for headers in group_headers:
                        for header in headers:
                            worksheet.write(HEADERS_ROW, column, header)
                            column += 1

                # A1 notation is 1-based: ``A3`` is the row right below the
                # two header rows (0-based rows 0 and 1).
                row_index = 3
                async for row in self._get_report_data():
                    worksheet.write_row(f'A{row_index}', row)
                    row_index += 1
        except BaseException:
            # Generation failed: nobody will receive the file, so release it here.
            report_file.close()
            raise

        report_file.seek(0)
        return report_file

    async def _fetch_data_from_db(self) -> typing.AsyncIterator[models.ClientRow]:
        """Stream the report's client rows from the database.

        Rows are ordered by client id, campaign id and row id, and are
        iterated inside a transaction.
        """
        query = build_get_client_rows_query(report_id=self._report_id)
        query = query.order_by(models.ClientRow.client_id, models.ClientRow.campaign_id, models.ClientRow.id)

        async with db.transaction():
            async for row in query.gino.iterate():
                yield row

    async def _update_status(self, status: str):
        """Set ``status`` on the ``ReportExportInfo`` row of this task."""
        await models.ReportExportInfo.update.values(
            status=status
        ).where(
            models.ReportExportInfo.id == self._task_id
        ).gino.status()

    async def _upload_to_mds(self, report_file: NamedTemporaryFileProtocol):
        """Upload the report to S3/MDS and mark the export task as ready.

        Creates an ``S3MdsFile`` record for the uploaded object and stores its
        id, together with the ``ready`` status, on the ``ReportExportInfo`` row.

        :param report_file: open temporary file holding the rendered report.
        """
        random_uuid = str(uuid.uuid4())
        display_name = f'Отчет от {datetime.datetime.now().strftime("%d-%m-%Y")}'  # TODO: better naming?
        # The random UUID keeps object keys unique for reports created on the same day.
        mds_filename = f'{self._prefix}/{random_uuid}-{display_name}.xlsx'

        # TODO: don't create session each time
        # TODO: add aioboto3.session to BotoSessionBoundedMixin and pass it through task?
        boto3_session = aioboto3.Session(
            aws_access_key_id=self._mds_cfg['access_key_id'],
            aws_secret_access_key=self._mds_cfg['secret_access_key']
        )

        async with boto3_session.resource('s3', endpoint_url=self._mds_cfg['endpoint_url'], config=AioConfig(s3={'addressing_style': 'virtual'})) as s3:
            obj = await s3.Object(self._bucket, mds_filename)
            # NOTE(review): the Content-Disposition filename is unquoted, contains
            # spaces and non-ASCII characters, and the content type is the legacy
            # Excel one although the file is .xlsx - confirm this is what the
            # download clients expect before changing either value.
            await obj.put(
                Body=report_file.file,
                ContentType='application/vnd.ms-excel; charset=utf-8',
                ContentDisposition=f'attachment; filename={display_name}.xlsx'
            )

        file: models.S3MdsFile = await models.S3MdsFile.create(
            bucket=self._bucket,
            name=mds_filename,
            display_name=display_name
        )

        await models.ReportExportInfo.update.values(
            status=TaskStatuses.ready.value,
            file_id=file.id,
        ).where(models.ReportExportInfo.id == self._task_id).gino.status()

    async def _get_report_data(self):
        """Yield the report's data rows, one sequence of cell values per row.

        Must be overridden by subclasses as an async generator / async iterator.

        :raises NotImplementedError: when iterated on the base class.
        """
        raise NotImplementedError
        # Unreachable on purpose: the ``yield`` makes this an async generator,
        # so ``async for`` over the base implementation raises
        # NotImplementedError instead of a TypeError about a coroutine.
        yield
