import logging
import click
from typing import Union
from fastapi import FastAPI
import uvicorn
import json

from customer_service.ml.chats.api.lib.inference_model import PassportInferenceModel

from customer_service.ml.chats.api.lib.models import (
    InputRequest,
    ErrorResponse,
    ClassifyResponse
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level state shared between the CLI entry point and the FastAPI
# handlers. ``inference_config`` is filled in by ``main()`` from the JSON file
# passed via ``--config``; ``inference_model`` is built from it by the startup
# hook below. Both stay ``None`` until then. (A bare ``global`` statement at
# module scope does not define anything, so the names are initialised here.)
inference_config = None
inference_model: Union[PassportInferenceModel, None] = None

app = FastAPI()


@app.on_event("startup")
def startup():
    """Build the inference model from the loaded config when the app starts.

    Raises:
        RuntimeError: if no configuration has been loaded, e.g. when the app
            is served directly instead of through ``main()`` with ``--config``.
    """
    global inference_model
    if inference_config is None:
        # Fail fast with an actionable message instead of handing ``None``
        # to the model constructor.
        raise RuntimeError(
            "inference_config is not set; start the API through main() "
            "with --config"
        )
    inference_model = PassportInferenceModel(inference_config)


@app.post("/classify")
async def classify(request: InputRequest) -> Union[ClassifyResponse, ErrorResponse]:
    """Validate, preprocess and classify one incoming request.

    When the model's validity check does not pass, the check result itself is
    returned to the client. Otherwise the request is preprocessed, run through
    the model, and the model's output is returned.
    """
    # NOTE(review): these model calls are not awaited, so they run directly on
    # the event loop inside this ``async def``; a slow inference blocks all
    # other requests. A plain ``def`` endpoint (run in FastAPI's threadpool)
    # may be preferable if the model is thread-safe — TODO confirm.
    validation = inference_model.is_valid_request(request)
    if validation.passed:
        prepared = inference_model.preprocess_request(request)
        return inference_model.inference(prepared)
    return validation


@click.command()
@click.option('--host', default='::', help='API host')
@click.option('--port', default=7777, help='API port')
@click.option(
    '--config',
    required=True,
    # Let click reject missing paths / directories with a proper usage error.
    type=click.Path(exists=True, dir_okay=False),
    help='API config',
)
def main(host: str, port: int, config: str):
    """Load the JSON inference config and serve the classifier API.

    Args:
        host: Interface for uvicorn to bind to (default ``::``).
        port: TCP port for uvicorn to listen on.
        config: Path to a JSON file. Its parsed contents are stored in the
            module-level ``inference_config``, which the startup hook passes
            to ``PassportInferenceModel``.

    Raises:
        click.BadParameter: if the config file does not contain valid JSON.
    """
    global inference_config
    logger.info('Passport messages classifier API')
    logger.info('Loading inference config from %s', config)
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config, encoding='utf-8') as cfg:
        try:
            inference_config = json.load(cfg)
        except json.JSONDecodeError as err:
            raise click.BadParameter(
                f'invalid JSON: {err}', param_hint="'--config'"
            ) from err

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # Script entry point: hand control to the click command.
    main()
