from collections import defaultdict
from datetime import date, datetime
from typing import Any, AsyncIterable, Dict, Optional

from aiohttp.client import ClientResponse
from aiohttp.client_exceptions import ContentTypeError
from pytz.tzinfo import BaseTzInfo

from sendr_utils import without_none

from mail.ciao.ciao.interactions.base import BaseInteractionClient
from mail.ciao.ciao.interactions.calendar.entities import Event, EventType
from mail.ciao.ciao.interactions.calendar.exceptions import get_error_by_name


class BaseCalendarClient(BaseInteractionClient[dict]):
    """Async client for the Calendar internal HTTP API.

    Each public method accepts an optional user ticket, which is forwarded in the
    ``X-Ya-User-Ticket`` header. Responses are inspected for an ``error`` key,
    because the service reports failures inside a 200 OK body.
    """

    # Timestamp format sent to Calendar. NOTE: for naive datetimes ``%z`` renders
    # as an empty string, so callers should pass timezone-aware values.
    _STRFTIME = '%Y-%m-%dT%H:%M:%S%z'

    @staticmethod
    def map_event(event_data: dict) -> Event:
        """Build an ``Event`` from one item of a Calendar ``events`` list.

        ``startTs``/``endTs`` are parsed with ``datetime.fromisoformat``.
        A missing key raises ``KeyError``.
        """
        return Event(
            event_id=event_data['id'],
            external_id=event_data['externalId'],
            name=event_data['name'],
            description=event_data['description'],
            start_ts=datetime.fromisoformat(event_data['startTs']),
            end_ts=datetime.fromisoformat(event_data['endTs']),
            others_can_view=event_data['othersCanView'],
            sequence=event_data['sequence'],
            all_day=event_data['isAllDay'],
        )

    async def _make_request(self, *args: Any, **kwargs: Any) -> ClientResponse:
        """Send a request, translating the ``user_ticket`` kwarg into a header.

        ``user_ticket`` is always removed from ``kwargs`` before delegating to the
        base client. When it is truthy it is sent as ``X-Ya-User-Ticket``.
        """
        user_ticket = kwargs.pop('user_ticket', None)
        if user_ticket:
            # Work on a copy so a headers mapping owned by the caller is not mutated
            # in place (otherwise the ticket could leak into later requests reusing it).
            headers = kwargs.get('headers')
            headers = {} if headers is None else headers.copy()
            headers['X-Ya-User-Ticket'] = user_ticket
            kwargs['headers'] = headers
        return await super()._make_request(*args, **kwargs)

    async def _process_response(self, response: ClientResponse, interaction_method: str) -> dict:
        """Decode the response via the base client and raise on embedded errors.

        Raises:
            The exception class returned by ``get_error_by_name`` for the error's
            ``name`` (or for ``'Missing error name'`` when no name is available),
            whenever the decoded body contains an ``error`` key.
        """
        # Calendar always responds with OK 200, need to check for "error" key in response.
        # https://wiki.yandex-team.ru/calendar/api/new-web/#errors
        data = await super()._process_response(response, interaction_method)
        if 'error' in data:
            error = data['error']
            error_name = 'Missing error name'
            # Tolerate a malformed (non-object) error payload: still raise the mapped
            # calendar error rather than an unrelated AttributeError.
            if isinstance(error, dict):
                error_name = error.get('name', error_name)
            raise get_error_by_name(error_name)(
                status_code=response.status,
                response_status=None,
                service=self.SERVICE,
                method=response.method,
                message=error_name,
            )
        return data

    async def count_events(self,
                           uid: int,
                           user_ticket: Optional[str],
                           from_date: date,
                           to_date: date,
                           timezone: BaseTzInfo,
                           ) -> Dict[date, int]:
        """Return the number of events per date, summed over all layers.

        Dates are sent as ISO strings and the timezone by its pytz ``zone`` name.
        Only dates present in the response appear in the result.
        """
        layers_data = await self.get(
            'count_events',
            f'{self.BASE_URL}/internal/count-events',
            user_ticket=user_ticket,
            params={
                'uid': uid,
                'from': from_date.isoformat(),
                'to': to_date.isoformat(),
                'tz': timezone.zone,
            }
        )
        count: Dict[date, int] = defaultdict(int)
        for layer in layers_data['layers']:
            for date_str, date_count in layer['counts'].items():
                count[date.fromisoformat(date_str)] += date_count
        # Return a plain dict so missing dates do not silently materialize as 0.
        return dict(count)

    async def create_event(self,
                           uid: int,
                           user_ticket: Optional[str],
                           start_datetime: datetime,
                           end_datetime: datetime,
                           external_id: Optional[str] = None,
                           name: Optional[str] = None,
                           description: Optional[str] = None,
                           all_day: bool = False,
                           ) -> int:
        """Create a user event and return the ``showEventId`` from the response.

        The payload is passed through ``without_none``, presumably dropping the
        optional fields left as ``None``.
        """
        result = await self.post(
            'create_event',
            f'{self.BASE_URL}/internal/create-event',
            user_ticket=user_ticket,
            params={
                'uid': uid,
            },
            json=without_none({
                'externalId': external_id,
                'startTs': start_datetime.strftime(self._STRFTIME),
                'endTs': end_datetime.strftime(self._STRFTIME),
                'name': name,
                'description': description,
                'type': EventType.USER.value,
                'isAllDay': all_day,
            }),
        )

        return result['showEventId']

    async def delete_event(self, uid: int, user_ticket: Optional[str], event_id: int) -> None:
        """Delete the event ``event_id`` on behalf of ``uid``."""
        await self.post(
            'delete_event',
            f'{self.BASE_URL}/internal/delete-event',
            user_ticket=user_ticket,
            params={
                'uid': uid,
                'id': event_id,
            },
        )

    async def get_events(self,
                         uid: int,
                         user_ticket: Optional[str],
                         from_datetime: datetime,
                         to_datetime: datetime,
                         ) -> AsyncIterable[Event]:
        """Fetch events for the given time range and yield them as ``Event`` objects.

        The single request is made when iteration starts. ``dateFormat=zoned`` is
        requested so timestamps can be parsed by ``map_event``.
        """
        events_data = await self.get(
            'get_events',
            f'{self.BASE_URL}/internal/get-events',
            user_ticket=user_ticket,
            params={
                'uid': uid,
                'from': from_datetime.strftime(self._STRFTIME),
                'to': to_datetime.strftime(self._STRFTIME),
                'dateFormat': 'zoned',
            }
        )

        for event_data in events_data['events']:
            yield self.map_event(event_data)

    async def ping(self) -> None:
        """Call the ``/ping`` endpoint; any error other than ``ContentTypeError`` propagates."""
        try:
            await self.get('ping', f'{self.BASE_URL}/ping')
        except ContentTypeError:
            # NOTE(review): presumably /ping answers with a non-JSON body, so a content
            # type mismatch while decoding still means the service is reachable.
            pass

    async def update_event(self,
                           uid: int,
                           user_ticket: Optional[str],
                           event_id: int,
                           start_ts: datetime,
                           end_ts: datetime,
                           all_day: bool,
                           ) -> None:
        """Update the start/end time and all-day flag of event ``event_id``."""
        await self.post(
            'update_event',
            f'{self.BASE_URL}/internal/update-event',
            user_ticket=user_ticket,
            params={
                'uid': uid,
                'id': event_id,
            },
            json={
                'startTs': start_ts.strftime(self._STRFTIME),
                'endTs': end_ts.strftime(self._STRFTIME),
                'isAllDay': all_day,
            },
        )
