from typing import ClassVar, Dict, Iterable, List, Optional, Tuple, TypeVar, Union

import pytz
import ujson
from aiohttp import web
from marshmallow import Schema
from pytz.tzinfo import BaseTzInfo

from mail.ciao.ciao.api.handlers.base import BaseHandler
from mail.ciao.ciao.conf import settings
from mail.ciao.ciao.core.entities.button import UriButton
from mail.ciao.ciao.core.entities.enums import FrameName
from mail.ciao.ciao.core.entities.missing import MissingType
from mail.ciao.ciao.core.entities.scenario_response import IRRELEVANT_RESPONSE, ScenarioResponse
from mail.ciao.ciao.core.entities.state import State
from mail.ciao.ciao.core.entities.state_stack import StateStack, StateStackItem
from mail.ciao.ciao.core.entities.user import User
from mail.ciao.ciao.core.scenario_runner import ScenarioRunner
from mail.ciao.ciao.core.scenarios.base import BaseScenario
from mail.ciao.ciao.utils.gettext import gettext
from mail.ciao.ciao.utils.logging import LOGGER, SensitiveDataHolder

T = TypeVar('T')


class BaseMegamindHandler(BaseHandler):
    """
    Accepts both protobuf and json. Response has the same content-type as the request.

    ``PROTOBUF_SCHEMAS``, ``JSON_SCHEMAS`` and ``REQUEST_TYPE`` are only declared here
    (no values), and ``get_frames_data`` raises ``NotImplementedError``, so concrete
    handlers are expected to supply all of them.
    """

    # Content types: the first one selects the protobuf schemas, and both are used
    # to label the response.
    PROTOBUF_CONTENT_TYPE = 'application/protobuf'
    JSON_CONTENT_TYPE = 'application/json'

    # (request schema, response schema) pairs, one pair per supported content type.
    PROTOBUF_SCHEMAS: ClassVar[Tuple[Schema, Schema]]
    JSON_SCHEMAS: ClassVar[Tuple[Schema, Schema]]

    # Forwarded to ScenarioRunner as its ``commit`` argument.
    COMMIT: ClassVar[bool] = False
    # Pushed into the logging context as ``request_type`` at the start of ``post``.
    REQUEST_TYPE: ClassVar[str]

    @property
    def is_protobuf(self) -> bool:
        return self.request.content_type == self.PROTOBUF_CONTENT_TYPE

    @property
    def request_schema(self) -> Schema:
        return self.PROTOBUF_SCHEMAS[0] if self.is_protobuf else self.JSON_SCHEMAS[0]

    @property
    def response_schema(self) -> Schema:
        return self.PROTOBUF_SCHEMAS[1] if self.is_protobuf else self.JSON_SCHEMAS[1]

    @staticmethod
    def name_items(item_name: str, items: Iterable[T]) -> Dict[str, T]:
        return {
            f'{item_name}_{i}': item
            for i, item in enumerate(items)
        }

    @staticmethod
    def is_voice(data: dict) -> bool:
        return 'voice' in data.get('input', {})

    @staticmethod
    def get_timezone(data: dict) -> BaseTzInfo:
        """
        Builds a pytz timezone from ``base_request.client_info.timezone``.

        Every key on that path is required; a missing one raises ``KeyError``.
        """
        client_info = data['base_request']['client_info']
        timezone_name = client_info['timezone']
        return pytz.timezone(timezone_name)

    @staticmethod
    def get_request_id(data: dict) -> Optional[str]:
        return data.get('base_request', {}).get('request_id')

    @staticmethod
    def get_uuid(data: dict) -> str:
        return data['base_request']['client_info']['uuid']

    @staticmethod
    def get_app_id(data: dict) -> Optional[str]:
        return data.get('base_request', {}).get('client_info', {}).get('app_id')

    @staticmethod
    def get_text(data: dict) -> Optional[str]:
        input_data = data.get('input', {})
        if 'text' in input_data:
            return input_data['text'].get('utterance')
        elif 'voice' in input_data:
            return input_data['voice'].get('utterance')
        return None

    @staticmethod
    def get_frames_data(data: dict) -> List[dict]:
        raise NotImplementedError

    @staticmethod
    def get_frame_name(frame_data: dict) -> FrameName:
        """Converts the frame's ``name`` (None when absent) with ``FrameName.from_string``."""
        raw_name = frame_data.get('name')
        return FrameName.from_string(raw_name)

    @classmethod
    def get_state(cls, data: dict) -> State:
        """
        Restores the state from ``base_request.state``.

        A fresh ``State()`` is returned when the request carries no state or when the
        restored state reports itself as expired.

        :param data: parsed request data
        """
        logger = LOGGER.get()
        base_request = data.get('base_request', {})
        state_data = base_request.get('state', {})
        if not state_data:
            logger.info('Failed to parse state. Creating a new one.')
            return State()

        state_stack_data = state_data.get('state_stack', {})
        # Both expiration fields are mandatory once any state data is present.
        soft_expire = state_data['soft_expire']
        hard_expire = state_data['hard_expire']
        stack_items = []
        for item_data in state_stack_data.get('stack_items', []):
            stack_items.append(StateStackItem(
                scenario_name=item_data['scenario_name'],
                params=item_data['params'],
                arg_name=item_data.get('arg_name'),
            ))
        state = State(
            soft_expire=soft_expire,
            hard_expire=hard_expire,
            state_stack=StateStack(stack_items=stack_items),
        )
        # The isoformat() calls below require both fields not to be the MISSING sentinel.
        assert state.hard_expire is not MissingType.MISSING
        assert state.soft_expire is not MissingType.MISSING
        if not state.expired:
            logger.info('State is restored.')
            return state

        with logger:
            logger.context_push(
                hard_expire=state.hard_expire.isoformat(),
                soft_expire=state.soft_expire.isoformat(),
            )
            logger.info('State is expired. Creating a new one.')
        return State()

    @classmethod
    def get_slots(cls, frame_data: dict) -> dict:
        return {
            slot_data['name']: slot_data['parsed']
            for slot_data in frame_data.get('slots', [])
        }

    @classmethod
    def make_response_data(cls,
                           scenario_response: ScenarioResponse,
                           frame_name: Optional[FrameName] = None,
                           request_data: Optional[dict] = None,
                           ) -> dict:
        """
        Builds response.

        An error response short-circuits to ``{'error': {'message': ...}}``. Otherwise the
        result holds a ``response_body`` (state, frame actions, layout and optional
        semantic frame / analytics info), which is wrapped into ``commit_candidate``
        when the scenario asks for a commit.

        :param scenario_response: response returned by scenario
        :param frame_name: name of handled frame. Must be given when the scenario requests a slot.
        :param request_data: original request data. Used for passing state and frames to commit request.
            Must be given when the scenario requests a slot or a commit.
        :return: raw response data, ready to be dumped with the response schema
        """
        if scenario_response.error:
            return {'error': {'message': scenario_response.error}}

        expected_frames = cls.name_items('expected_frame', scenario_response.expected_frames)
        buttons = cls.name_items('button', scenario_response.buttons)
        suggests = cls.name_items('suggest', scenario_response.suggests)

        response_body: dict = {
            'state': BaseScenario.context.state,
            'frame_actions': {**expected_frames, **buttons, **suggests},
        }
        layout: dict = {
            # Listening only makes sense as a follow-up to a voice request.
            'should_listen': (
                scenario_response.should_listen
                and request_data is not None and cls.is_voice(request_data)
            ),
        }
        response_body['layout'] = layout
        data: dict = {'response_body': response_body}

        if scenario_response.irrelevant:
            data['features'] = {'is_irrelevant': True}

        if scenario_response.text is not None:
            card: dict
            if scenario_response.buttons:
                card = {
                    'text_with_buttons': {
                        'text': scenario_response.text,
                        'buttons': buttons,
                    }
                }
            else:
                card = {'text': scenario_response.text}
            layout.update({
                'output_speech': scenario_response.speech,
                'cards': [card],
                'suggest_buttons': suggests,
            })

        if scenario_response.requested_slot:
            assert frame_name is not None and request_data is not None
            response_frame_name = scenario_response.frame_name or frame_name
            slot_name, slot_type = scenario_response.requested_slot
            response_body['semantic_frame'] = {
                'name': response_frame_name.value,
                'slots': [
                    {'name': slot_name, 'accepted_types': [slot_type], 'is_requested': True},
                ],
            }

        if scenario_response.analytics is not None:
            response_body['analytics_info'] = scenario_response.analytics

        # Must happen before the commit wrapping below: once wrapped, the layout is no
        # longer reachable through data['response_body'] and setting the flag there
        # used to fail with KeyError for commit responses.
        if scenario_response.contains_sensitive_data:
            layout['contains_sensitive_data'] = True

        if scenario_response.commit:
            # https://wiki.yandex-team.ru/mail/swat/ciao/Alice-gotchas/#m-kakrabotaetkommit
            assert request_data is not None
            data = {
                'commit_candidate': {
                    **data,
                    'arguments': {
                        'semantic_frames': cls.get_frames_data(request_data),
                    },
                },
            }

        return data

    async def get_data(self) -> dict:
        """
        Reads the request body and loads it with the request schema.

        Protobuf requests are read as raw bytes, everything else as JSON. The parsed
        data is also logged when ``settings.DEBUG`` is set.
        """
        payload: Union[bytes, dict]
        if not self.is_protobuf:
            payload = await self.request.json()
        else:
            payload = await self.request.read()

        # NOTE(review): the second element of the load() result is discarded; if the
        # schemas are non-strict marshmallow 2 schemas this drops validation errors -
        # confirm against the schema definitions.
        parsed_data, _ = self.request_schema.load(payload)
        if not settings.DEBUG:
            return parsed_data

        with self.logger:
            self.logger.context_push(parsed_data=parsed_data)
            self.logger.info('Data parsed.')
        return parsed_data

    def make_response(self,
                      scenario_response: ScenarioResponse,
                      frame_name: Optional[FrameName] = None,
                      request_data: Optional[dict] = None,
                      ) -> web.Response:
        """
        Builds the HTTP response for a scenario response.

        The raw data from ``make_response_data`` is dumped with the response schema and
        sent back with the same content type as the request: protobuf as the body,
        otherwise as JSON text.

        :param scenario_response: response returned by scenario
        :param frame_name: name of handled frame
        :param request_data: original request data
        """
        response_data = self.make_response_data(scenario_response, frame_name, request_data)
        if settings.DEBUG:
            with self.logger:
                self.logger.context_push(raw_response_data=response_data)
                self.logger.info('Raw response.')

        serialized, _ = self.response_schema.dump(response_data)
        if not self.is_protobuf:
            return web.Response(
                text=ujson.dumps(serialized),
                content_type=self.JSON_CONTENT_TYPE,
            )
        return web.Response(
            body=serialized,
            content_type=self.PROTOBUF_CONTENT_TYPE,
        )

    async def post(self):
        """
        Handles a megamind request.

        Steps: check that the user is authenticated, parse the request, check the
        app id, fill the scenario context and run the scenario runner frame by frame.
        The response for the first frame that is not irrelevant is returned; if there is
        none, the irrelevant response is returned.
        """
        with self.logger:
            self.logger.context_push(request_type=self.REQUEST_TYPE)

            # Without a user uid nothing can be done: ask the user to log in.
            uid: Optional[int] = self.request['tvm'].default_uid
            if uid is None:
                self.logger.info('User is not authenticated.')
                login_prompt = gettext('Please, log in and repeat your request.')
                login_button = UriButton(
                    title=gettext('Log in.'),
                    uri='yandex-auth://',
                )
                return self.make_response(ScenarioResponse(
                    text=login_prompt,
                    buttons=[login_button],
                ))
            self.logger.context_push(user_uid=uid)

            data = await self.get_data()
            app_id = self.get_app_id(data)

            self.logger.context_push(
                megamind_request_id=self.get_request_id(data),
                text=SensitiveDataHolder(self.get_text(data)),
                app_id=app_id,
                uuid=self.get_uuid(data),
            )

            if app_id not in settings.ALLOWED_APP_IDS:
                self.logger.info('app_id is not allowed.')
                return self.make_response(IRRELEVANT_RESPONSE)

            # Context shared with the scenarios.
            BaseScenario.context.logger = self.logger
            BaseScenario.context.request_id = self.request_id
            BaseScenario.context.user = User(
                uid=uid,
                timezone=self.get_timezone(data),
                user_ticket=self.request.headers.get('X-Ya-User-Ticket'),
            )

            # Request might contain multiple frames. Only the first relevant frame is handled.
            for frame_data in self.get_frames_data(data):
                with self.logger:
                    # The state is parsed from the request again for every frame.
                    state = self.get_state(data)
                    BaseScenario.context.state = state
                    frame_name = self.get_frame_name(frame_data)
                    slots = self.get_slots(frame_data)
                    self.logger.context_push(
                        state=SensitiveDataHolder(state),
                        frame_name=frame_name.value,
                        slots=SensitiveDataHolder(slots),
                    )

                    self.logger.debug('State, frame, slots initialized. Calling stack handle.')
                    runner = ScenarioRunner(
                        frame_name=frame_name,
                        slots=slots,
                        commit=self.COMMIT,
                    )
                    scenario_response = await runner.run()
                    if not scenario_response.irrelevant:
                        return self.make_response(scenario_response, frame_name, data)
        return self.make_response(IRRELEVANT_RESPONSE)
