import csv
import io
from typing import Any, ClassVar, Dict

from typing_extensions import Protocol

from mail.ipa.ipa.conf import settings
from mail.ipa.ipa.core.actions.base import BaseDBAction
from mail.ipa.ipa.core.entities.collector import Collector
from mail.ipa.ipa.core.entities.enums import UserImportError


class AsyncWritable(Protocol):
    """Structural type for any object exposing an awaitable ``write(bytes)``.

    Only ``write`` is required. Callers in this module await it and discard
    whatever it returns.
    """

    async def write(self, chunk: bytes) -> Any:
        """Consume the next chunk of already-encoded output bytes."""


class WriteCSVReportAction(BaseDBAction):
    """Stream a CSV report of an organization's users that have import errors.

    Rows are rendered by ``csv.DictWriter`` into an in-memory text buffer and
    pushed to ``output`` as UTF-8 encoded bytes. Users come from
    ``storage.user.get_all_with_collectors(..., only_errors=True)``.
    """

    # Users fetched from storage per batch; also the number of CSV rows
    # accumulated in memory before they are written to ``output``.
    BATCH_SIZE: ClassVar[int] = settings.CSV_REPORT_BATCH_SIZE

    def __init__(self, org_id: int, output: AsyncWritable):
        """
        :param org_id: id of the organization whose users are reported.
        :param output: async sink that receives the encoded CSV chunks.
        """
        super().__init__()
        self.org_id: int = org_id
        self.output: AsyncWritable = output
        self.buf: io.StringIO = io.StringIO()
        self.writer: csv.DictWriter = csv.DictWriter(self.buf, fieldnames=[
            'login',
            'error',
            'collected',
            'total',
            'errors',
        ])

    async def flush(self) -> None:
        """Send buffered CSV text to ``output`` as UTF-8 and reset the buffer.

        Does nothing when the buffer is empty, so it is safe to call
        unconditionally at the end of the report.
        """
        text = self.buf.getvalue()
        if not text:
            return
        # Reset the buffer before awaiting, so the next rows start at offset 0.
        self.buf.seek(0)
        self.buf.truncate()
        await self.output.write(text.encode('utf-8'))

    def _build_row(self, user: Any) -> Dict[str, Any]:
        """Map one user (and its collector, when present) to a CSV row dict.

        Columns absent from the returned dict are written as empty cells,
        because ``DictWriter``'s default ``restval`` is ``''``.
        """
        row: Dict[str, Any] = {'login': user.login}
        collector = user.collector
        if isinstance(collector, Collector):
            row.update({
                'collected': collector.collected,
                'total': collector.total,
                'errors': collector.errors,
                'error': UserImportError.get_error_str(
                    user_error=user.error,
                    collector_status=collector.status,
                    logger=self.logger,
                ),
            })
        else:
            row['error'] = UserImportError.get_error_str(
                user_error=user.error,
                logger=self.logger,
            )
        return row

    async def handle(self) -> None:
        """Write the header, then one row per user with import errors.

        The header is sent immediately. Rows are sent every ``BATCH_SIZE``
        rows plus once at the end, rather than with one awaited write per
        row. The emitted bytes are the same either way.
        """
        self.writer.writeheader()
        await self.flush()

        pending = 0
        async for user in self.storage.user.get_all_with_collectors(
            org_id=self.org_id,
            batch_size=self.BATCH_SIZE,
            only_errors=True,
        ):
            self.writer.writerow(self._build_row(user))
            pending += 1
            if pending >= self.BATCH_SIZE:
                await self.flush()
                pending = 0

        # Send any rows left over from the last partial batch.
        await self.flush()
