import json
import logging
from typing import Iterable, Optional, Type

from ylog.context import LogContext as YLogLogContext, get_log_context
from ylog.format import QloudJsonFormatter, IS_DEPLOY


# Module-level logger; LogContext.__exit__ reports escaping exceptions through it.
logger = logging.getLogger(name=__name__)


class LogFormatter(QloudJsonFormatter):
    """JSON formatter that embeds the current ``ylog`` log context in every record.

    Each record is rendered as one JSON object containing ``message`` and
    ``level``, extra keys when ``IS_DEPLOY`` is set, an optional
    ``stackTrace``, and an ``@fields`` object holding the standard record
    fields (``std``) and the active log context (``context``).
    """

    # Log-context keys that are additionally promoted to the top level of the
    # JSON object (in addition to being present under ``@fields.context``).
    _CONTEXT_FIELDS_TO_COPY = ('request_id', 'original_request_id')

    def format(self, record: logging.LogRecord) -> str:
        """Render *record* as a single JSON string.

        :param record: the log record to format; ``record.message`` is set as
            a side effect, as the stdlib ``logging.Formatter.format`` does.
        :returns: the JSON-encoded record.

        Values the ``json`` module cannot encode natively (for example a
        non-string ``record.msg`` or an arbitrary object stored in the log
        context) are rendered with ``str`` instead of making the whole call
        fail with ``TypeError``.
        """
        record.message = record.getMessage()

        log_data = {
            'message': record.message,
            'level': record.levelname,
        }
        if IS_DEPLOY:
            log_data['levelStr'] = record.levelname
            log_data['loggerName'] = record.name
            # Under Deploy ``level`` carries the numeric level; the textual
            # name is kept in ``levelStr`` instead.
            log_data['level'] = record.levelno

        if record.exc_info:
            # Calls the stdlib implementation directly rather than
            # ``self.formatException``, bypassing any override in the base
            # class. NOTE(review): presumably intentional — confirm.
            log_data['stackTrace'] = logging.Formatter.formatException(self, record.exc_info)

        fields = {}

        standard_fields = self._get_standard_fields(record)
        if standard_fields:
            # Keep the unformatted message object alongside the rendered one.
            standard_fields['orig_msg'] = record.msg
            fields['std'] = standard_fields

        log_context_fields = get_log_context()
        if log_context_fields:
            fields['context'] = log_context_fields
            for field in self._CONTEXT_FIELDS_TO_COPY:
                if field in log_context_fields:
                    log_data[field] = log_context_fields[field]

        if fields:
            log_data['@fields'] = fields

        # ``record.msg`` may be any object (``logger.info(obj)`` is legal) and
        # context values are caller-supplied; fall back to ``str`` so an
        # unserializable value degrades gracefully instead of dropping the
        # whole log line with a TypeError.
        return json.dumps(log_data, default=str)


class LogContext(YLogLogContext):
    """``ylog`` log context that also logs exceptions escaping its block.

    Exceptions whose type is a subclass of one of *expected_exceptions* are
    logged at INFO level; any other exception is logged at ERROR level. Both
    include the traceback. Whether the exception is suppressed is decided
    solely by the parent ``__exit__``, whose return value is passed through.
    """

    def __init__(self, expected_exceptions: Optional[Iterable[Type[BaseException]]] = None, **context):
        """Create the log context.

        :param expected_exceptions: exception classes that are considered part
            of normal operation and are logged at INFO instead of ERROR. Any
            iterable is accepted; it is materialized once, so a one-shot
            iterator (e.g. a generator) keeps working for every check.
        :param context: key/value pairs forwarded to the ``ylog`` log context.
        """
        super().__init__(**context)
        # A tuple can be re-iterated safely and is accepted directly by issubclass().
        self._expected_exceptions = tuple(expected_exceptions or ())

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log an escaping exception, then delegate to the parent ``__exit__``."""
        if exc_type is not None:
            # Pass the exception explicitly instead of relying on
            # sys.exc_info(): when __exit__ is driven by contextlib.ExitStack
            # (or called manually) the "currently handled" exception can be a
            # different one, or none at all, yielding a wrong traceback.
            exc_info = (exc_type, exc_val, exc_tb)
            if self._is_expected_exceptions(exc_type):
                logger.info('An expected exception %s occurred', exc_type, exc_info=exc_info)
            else:
                logger.error('Uncaught exception %s leaves log context', exc_type, exc_info=exc_info)

        return super().__exit__(exc_type, exc_val, exc_tb)

    def _is_expected_exceptions(self, exc_type: Type[BaseException]) -> bool:
        """Return True if *exc_type* subclasses any of the expected exception types."""
        # issubclass() with an empty tuple is simply False.
        return issubclass(exc_type, self._expected_exceptions)


# Lower-case alias of LogContext; presumably kept for callers that import it
# under this name and use it as ``with log_context(**ctx):`` — verify usages.
log_context = LogContext
