from abc import ABC, abstractmethod
from typing import Union
import aiohttp
import json
import numpy as np
from catboost import CatBoostClassifier
import sys
from .utils import check_template

from .models import (
    InputRequest,
    ErrorResponse,
    ClassifyResponse
)


class ClassifierBase(ABC):
    """Abstract base for classifiers constructed from a config object.

    Subclasses must implement :meth:`predict`.
    """

    def __init__(self, cfg):
        # Keep the raw config so subclasses can read their own settings.
        self.config = cfg

    @abstractmethod
    def predict(self, X):
        """Return predictions for input ``X``.

        The original declaration omitted ``self``, which made the abstract
        method's signature inconsistent with every override and broke any
        call routed through the base class.
        """


class PassportClassifier(ClassifierBase):
    """Classifier backed by a CatBoost model loaded from ``cfg['model_path']``."""

    def __init__(self, cfg):
        super().__init__(cfg)
        # Build the CatBoost estimator, then populate it from the saved model file.
        catboost_model = CatBoostClassifier()
        self.model = catboost_model
        catboost_model.load_model(cfg['model_path'])

    def predict(self, X):
        """Delegate prediction for ``X`` to the loaded CatBoost model."""
        predictions = self.model.predict(X)
        return predictions


class InferenceModelBase(ABC):
    """Abstract inference pipeline that holds an embedder and a classifier.

    Subclasses supply request validation, preprocessing and the async
    ``inference`` entry point. They may also override ``_init_embedder``
    to set up ``self.embedder``.
    """

    def __init__(self, cfg):
        self.config = cfg
        self.embedder = None
        self.classifier = None
        # Initialisation hooks: embedder first, then classifier.
        self._init_embedder()
        self._init_classifier()

    def _init_classifier(self):
        """Instantiate the classifier class named by ``config['classifier']['name']``.

        The class is resolved by name among this module's globals and is
        constructed with the ``config['classifier']`` section.
        """
        classifier_cfg = self.config['classifier']
        this_module = sys.modules[__name__]
        self.classifier = getattr(this_module, classifier_cfg['name'])(classifier_cfg)

    def _init_embedder(self):
        """Hook for subclasses; by default no embedder is configured."""

    @abstractmethod
    def is_valid_request(self, request: InputRequest) -> ErrorResponse:
        """Return an ``ErrorResponse`` describing whether ``request`` is acceptable."""

    @abstractmethod
    def preprocess_request(self, request: InputRequest) -> InputRequest:
        """Return ``request`` prepared for inference."""

    @abstractmethod
    async def inference(self, request: InputRequest) -> Union[ClassifyResponse, ErrorResponse]:
        """Run the full pipeline on ``request``."""


class PassportInferenceModel(InferenceModelBase):
    """Inference pipeline that embeds user text via an HTTP embedder and classifies it.

    Reads ``cfg['embedder_url']`` as the default embedding endpoint. The
    classifier is built by the base class from ``cfg['classifier']``.
    """

    # Authors accepted for entries of ``request.messages``.
    _ALLOWED_AUTHORS = ('оператор', 'пользователь', 'робот')

    def _init_embedder(self):
        # The "embedder" here is just the URL of an external embedding service.
        self.embedder = self.config['embedder_url']

    def is_valid_request(self, request: InputRequest) -> ErrorResponse:
        """Check that ``request`` carries text or messages with known authors.

        Returns:
            ``ErrorResponse(False, <reason>)`` on failure, otherwise
            ``ErrorResponse(True, 'OK')``.
        """
        if not request.text and not request.messages:
            return ErrorResponse(False, 'Error: Provide either text or messages in request')

        # Message authors only matter when the text will be built from messages.
        if not request.text and request.messages:
            for message in request.messages:
                if message.author not in self._ALLOWED_AUTHORS:
                    return ErrorResponse(False, "Error: Author must be one of 'оператор', 'пользователь', 'робот'")
                # NOTE(review): this branch is a no-op (it only `continue`s at the
                # end of the loop body). It looks like unfinished filtering logic.
                # It is kept because check_template is still evaluated here.
                if message.author == 'робот' or check_template(message.text):
                    continue

        return ErrorResponse(True, 'OK')

    def preprocess_request(self, request: InputRequest) -> InputRequest:
        """Ensure ``request.text`` is populated, mutating and returning ``request``.

        If ``request.text`` is already set, the request is returned unchanged.
        Otherwise the texts of the leading run of 'пользователь' messages, up
        to the first message from any other author, are joined with newlines.
        """
        if request.text:
            return request

        leading_user_texts = []
        for message in request.messages or ():
            if message.author != 'пользователь':
                break
            # Join the message text, not the message object (the latter raised TypeError).
            leading_user_texts.append(message.text)

        request.text = '\n'.join(leading_user_texts)
        return request

    async def inference(self, request: InputRequest) -> Union[ClassifyResponse, ErrorResponse]:
        """Validate, preprocess, embed and classify ``request``.

        Returns:
            ``ClassifyResponse(text, embedding, prediction)`` on success, or an
            ``ErrorResponse`` when validation fails, no user text remains, or
            the embedder replies with an HTTP error status.
        """
        validation = self.is_valid_request(request)
        if not validation.passed:
            return validation

        request = self.preprocess_request(request)
        if not request.text:
            return ErrorResponse(False, 'Error: No messages from user')

        # NOTE(review): the endpoint can be overridden per request. If requests
        # come from untrusted clients, this lets them make the server POST to
        # arbitrary URLs (SSRF). Consider an allow-list.
        if request.custom_embedder_api_url is None:
            url = self.embedder
        else:
            url = request.custom_embedder_api_url

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json={"Context": [request.text]}) as response:
                if response.status >= 400:
                    return ErrorResponse(False, f'Error: Embedder returned HTTP {response.status}')
                data = await response.read()

        # Add a leading batch axis so the classifier receives a single-row input.
        embed = np.array(json.loads(data)["Embedding"])[None, :]
        # CPU-bound prediction runs after the HTTP session has been released.
        y_pred = self.classifier.predict(embed)
        return ClassifyResponse(request.text, embed, y_pred)
