import datetime
import threading
import time
import uuid
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import wraps
from typing import Optional

import grpc
from load.projects.cloud.cloud_helper import iam
from load.projects.cloud.loadtesting import config
from load.projects.cloud.loadtesting.db import DB
from load.projects.cloud.loadtesting.logan import Logan
from load.projects.cloud.loadtesting.server.api.common import utils
from load.projects.cloud.loadtesting.server.obfuscator import OBFUSCATOR
from load.projects.cloud.loadtesting.server.api.common.utils import PushMonitoringData

from prometheus_client import CollectorRegistry, Histogram

# Dedicated (non-default) prometheus registry that holds this module's histograms.
registry = CollectorRegistry()

# Handler latency in milliseconds, labelled by server instance, RPC method and
# resulting rpc code (observed in BaseHandler.observe_metrics).
APP_METRICS = Histogram(
    'request_latency_ms',
    'Description of histogram',
    labelnames=['instance', 'handle', 'rpc_code'],
    buckets=(5, 10, 25, 50, 75, 100, 250, 500, 750, 1000, 1500, 2000, 2500, 3000, 5000, 7500, 10000),
    registry=registry,
)

# Time in milliseconds spent observing / pushing the metrics themselves,
# labelled by server instance and a pseudo-handle naming the measured step.
GATEWAY_METRICS = Histogram(
    'pushgateway_latency_ms',
    'Description of histogram',
    labelnames=['instance', 'handle'],
    buckets=(5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100, 150, 200, 250, 300, 350, 400, 500, 1000),
    registry=registry,
)


def extract_x_request_id(context: grpc.ServicerContext) -> Optional[str]:
    """Return the value of the ``x-request-id`` invocation metadata entry.

    The metadata key is compared case-insensitively and the first matching
    entry wins. Returns ``None`` when no such entry is present.
    """
    candidates = (
        value
        for key, value in context.invocation_metadata()
        if key.lower() == 'x-request-id'
    )
    return next(candidates, None)


def hide_error_details(method):
    """Decorate a handler method so that failures end up as gRPC statuses.

    - ``grpc.RpcError`` raised by the handler aborts the current RPC with the
      error's own code and details.
    - Any other exception, when no status has been set on the context yet,
      aborts the RPC with ``INTERNAL``. In the PROD environment the details are
      a generic message carrying the request id; otherwise ``str(error)``.
    - When the context already carries a status (e.g. the handler called
      ``context.abort``), the exception is re-raised unchanged so grpc
      completes the RPC with that stored status.
    """
    @wraps(method)
    def wrapper(self, request, context):
        try:
            return method(self, request, context)
        except grpc.RpcError as error:
            raise context.abort(error.code(), error.details())
        except Exception as error:
            if context._state.code:
                # A status is already stored on the context, so grpc treats this as an
                # aborted RPC. Per the original authors, such errors are not masked
                # here because they come from other cloud services, which mask errors
                # by their own rules. Re-raise and let grpc finish the RPC with the
                # stored status. Falling through would make the wrapper return None,
                # which grpc would then try to send as the response.
                raise
            if config.ENV_CONFIG.ENV_TYPE == config.EnvType.PROD.value:
                raise context.abort(
                    grpc.StatusCode.INTERNAL,
                    'Server internal error. '
                    f'Please save request id "{self.request_id}" for support team to identify incident.'
                )
            raise context.abort(grpc.StatusCode.INTERNAL, str(error))

    return wrapper


def limit_response_time(handler):
    """Decorate a handler method so it gives up after ``self._response_time_limit`` seconds.

    The handler runs in a separate thread while the calling thread waits for
    it. If it finishes in time, its return value is returned or its exception
    is re-raised in the caller. Otherwise the RPC is aborted with
    ``DEADLINE_EXCEEDED``. The worker thread is not stopped and keeps running.
    """
    @wraps(handler)
    def time_limited_handler(self, request, context):
        done = threading.Event()
        outcome = {'response': None, 'exception': None}

        def run_handler():
            try:
                outcome['response'] = handler(self, request, context)
            except Exception as exc:
                outcome['exception'] = exc
            finally:
                done.set()

        worker = threading.Thread(target=run_handler)
        worker.start()
        worker.join(timeout=self._response_time_limit)

        if not done.is_set():
            raise context.abort(grpc.StatusCode.DEADLINE_EXCEEDED, 'Server-side Timeout!')

        failure = outcome['exception']
        if failure is not None:
            raise failure
        return outcome['response']

    return time_limited_handler


class BaseHandler(ABC):
    """Base class for gRPC request handlers.

    Subclasses implement :meth:`proceed`. The :meth:`handle` entry point binds
    request attributes to the loggers, authenticates the caller through IAM,
    opens a ``DB`` session around ``proceed()``, writes entry/access/obfuscated
    log records and observes latency metrics.

    Per-request state (request, context, ids, timings) is stored on the
    instance and ``_request_time`` is captured in ``__init__``, so an instance
    is presumably created per RPC. NOTE(review): confirm with the servicer code.
    """
    _handler_name = 'handler'  # name of the child logger created in __init__
    _response_time_limit = 20  # seconds; read by @limit_response_time on handle()

    def __init__(self, parent_logger: Logan):
        self._logger: Logan = parent_logger.getChild(self._handler_name)

        self._request_id = None
        self._user_token = None
        self._user_id = None
        self._db = None
        self._request = None
        self._context = None
        self._response = None
        self._error = None
        # Naive UTC timestamp. Elapsed time is measured from here, i.e. from handler construction.
        self._request_time = datetime.datetime.utcnow()
        self._response_time = None
        self._elapsed_time = None
        self._method = None
        self._rpc_code = None

        self._entry_logger = self.logger.entry_logger()
        self._access_logger = self.logger.access_logger()

        self._access_logger.bind('request_time', self._request_time.isoformat())

    @contextmanager
    def _resources(self):
        """Open a ``DB`` session, expose it as ``self.db`` and yield ``self``."""
        with DB() as db:
            self._db = db
            yield self

    @property
    def db(self) -> DB:
        """DB session set by ``_resources()``; ``None`` until it has been entered."""
        return self._db

    @property
    def logger(self):
        """Handler logger (child of the parent logger passed to ``__init__``)."""
        return self._logger

    @property
    def user_token(self):
        """IAM token extracted from the request metadata in ``handle()``."""
        return self._user_token

    @property
    def request(self):
        """The incoming request message."""
        return self._request

    @request.setter
    def request(self, value):
        self._request = value
        for l in [self._entry_logger, self._access_logger]:
            l.bind('request_size', self._request.ByteSize())

    @property
    def context(self):
        """The ``grpc.ServicerContext`` of the current RPC."""
        return self._context

    @context.setter
    def context(self, value):
        self._context = value
        # Relies on grpc's private `_rpc_event` attribute to get the method name and client host.
        self._method = self._context._rpc_event.call_details.method.decode()
        for l in [self._entry_logger, self._access_logger]:
            l.bind('client_host', self._context._rpc_event.call_details.host.decode())
            l.bind('method', self._method)

    @property
    def request_id(self):
        """Request id taken from ``x-request-id`` metadata (or generated)."""
        return self._request_id

    @request_id.setter
    def request_id(self, value):
        self._request_id = value
        for l in [self._logger, self._entry_logger, self._access_logger]:
            l.bind('x_request_id', self._request_id)

    @property
    def user_id(self):
        """Authenticated user id, or ``'unknown'`` when authentication failed."""
        return self._user_id

    @user_id.setter
    def user_id(self, value):
        self._user_id = value
        for l in [self._logger, self._entry_logger, self._access_logger]:
            l.bind('user_id', self.user_id)

    @property
    def lang(self):
        """``'en'`` only if metadata contains exactly ``('accept-language', 'en')``; otherwise ``'ru'``."""
        return 'en' if ('accept-language', 'en') in self.context.invocation_metadata() else 'ru'

    def _rpc_response_summary(self):
        """Return ``(rpc_code_name, rpc_details, response_size)`` for the finished RPC."""
        if self._response is not None:
            return 'OK', '', self._response.ByteSize()
        # aborted by context.abort
        if (code := self.context._state.code) is not None:
            return code.name, self.context._state.details or '', 0
        # aborted by exception. details will be in "error" field of logs
        return 'UNKNOWN', '', 0

    def _bind_response_details(self):
        """Compute response timing/status and bind them to the handler and access loggers."""
        self._response_time = datetime.datetime.utcnow()
        self._elapsed_time = self._response_time - self._request_time
        self._rpc_code, rpc_details, response_size = self._rpc_response_summary()

        for logger in [self._logger, self._access_logger]:
            logger.bind('response_time', self._response_time.isoformat())
            logger.bind('elapsed_time_ms', int(self._elapsed_time.total_seconds() * 1000))
            logger.bind('rpc_code', self._rpc_code)
            logger.bind('rpc_details', rpc_details)
            logger.bind('response_size', response_size)
            logger.bind('error', self._error)

    def observe_metrics(self):
        """Record the request latency and push monitoring data.

        Observes the elapsed request time (ms) in ``APP_METRICS``, labelled by
        server instance, method and rpc code, then calls
        ``PushMonitoringData.push()`` (presumably a Pushgateway push; confirm).
        The duration of each of the two steps is recorded in ``GATEWAY_METRICS``.
        Requires ``_bind_response_details()`` to have run first, because that
        sets ``_elapsed_time`` and ``_rpc_code``.
        """
        instance_id = Logan.binded_global()['server_name']
        request_time = int(self._elapsed_time.total_seconds() * 1000)

        # perf_counter is monotonic, so these durations are not skewed by wall-clock adjustments.
        started = time.perf_counter()
        APP_METRICS.labels(instance=instance_id, handle=self._method, rpc_code=self._rpc_code).observe(request_time)
        observe_duration = time.perf_counter() - started
        GATEWAY_METRICS.labels(instance=instance_id, handle='/histogram-observe-time').observe(int(observe_duration * 1000))

        started = time.perf_counter()
        PushMonitoringData.push()
        push_duration = time.perf_counter() - started
        GATEWAY_METRICS.labels(instance=instance_id, handle='/histogram-push-time').observe(int(push_duration * 1000))

    @hide_error_details
    @limit_response_time
    def handle(self, request, context: grpc.ServicerContext):
        """Serve one RPC: log, authenticate, run ``proceed()`` and report.

        Binds request/context attributes to the loggers and sets the request id
        from ``x-request-id`` metadata, or ``unset_<uuid4>`` when absent. It then
        extracts the user IAM token, aborting with ``UNAUTHENTICATED`` when there
        is none, authenticates it and logs the obfuscated request. Finally it
        calls ``proceed()`` inside ``_resources()``. Success and failure are
        written to the access log. The obfuscated response is always logged,
        and metrics are observed on a best-effort basis.

        Returns:
            The value returned by ``proceed()``.

        Raises:
            Any exception raised by the steps above is re-raised (the decorators
            then turn it into a gRPC status).
        """

        self.request = request
        self.context = context
        try:
            self.request_id = extract_x_request_id(context) or f'unset_{str(uuid.uuid4())}'
            try:
                self._user_token = utils.user_iam_token(context.invocation_metadata())
                if not self._user_token:
                    context.abort(grpc.StatusCode.UNAUTHENTICATED, 'Authentication required')
                self.user_id = iam.IAM.authenticate(self._user_token, self.request_id)
            except Exception:
                self.user_id = 'unknown'
                raise
            finally:
                # The entry record is sent whether or not authentication succeeded.
                self._entry_logger.send('received')

            obfuscated_request_logger = self.logger.getChild('obfuscated')
            obfuscated_request_logger.bind('request', message_obfuscated_data(self.request))
            obfuscated_request_logger.bind('request_class', _class_info(self.request))
            obfuscated_request_logger.info('received')

            with self._resources():
                self._response = self.proceed()
        except Exception as e:
            self._error = e
            self._bind_response_details()
            self._logger.exception(str(e))
            self._access_logger.send('error')
            raise
        else:
            self._bind_response_details()
            self._access_logger.send('success')
            return self._response
        finally:
            obfuscated_response_logger = self.logger.getChild('obfuscated')
            obfuscated_response_logger.bind('response', message_obfuscated_data(self._response))
            obfuscated_response_logger.bind('response_class', _class_info(self._response))
            obfuscated_response_logger.info('sent')
            try:
                self.observe_metrics()
            except Exception:
                # Metrics are best-effort. A failure to observe or push them must not
                # replace the handler's response or mask its original exception.
                self._logger.exception('failed to observe metrics')

    @abstractmethod
    def proceed(self):
        """Handle the request (available as ``self.request``) and return the response message."""
        pass


class BasePublicHandler(BaseHandler, ABC):
    """Handler base whose logger is replaced by ``public_api_logger()`` of the base logger.

    NOTE(review): ``_handler_name`` reads 'private_v2 handler' although this is
    the public handler base. Confirm that the name is intentional.
    """
    _handler_name = 'private_v2 handler'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        public_logger = self.logger.public_api_logger()
        self._logger = public_logger


class BasePrivateHandler(BaseHandler, ABC):
    """Handler base whose logger is replaced by ``private_api_logger()`` of the base logger."""
    _handler_name = 'private_v1 handler'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        private_logger = self.logger.private_api_logger()
        self._logger = private_logger


def _class_info(instance) -> str:
    return f'{instance.__class__.__module__}.{instance.__class__.__name__}'


def message_obfuscated_data(grpc_message) -> dict:
    """Return an obfuscated, log-friendly representation of a gRPC message.

    The obfuscator is looked up in ``OBFUSCATOR`` by the message's exact class.

    Args:
        grpc_message: message instance, or ``None``.

    Returns:
        ``{}`` for ``None``, or the obfuscator's ``data()`` on success.
        Otherwise a dict with the single key ``'representation failed'`` whose
        value is a string describing the failure. This function never raises
        because of a failing obfuscator.
    """
    if grpc_message is None:
        return {}
    if (obfuscator := OBFUSCATOR.get(grpc_message.__class__, None)) is None:
        return {'representation failed': 'obfuscator has not been found'}
    try:
        return obfuscator(grpc_message).data()
    except Exception as e:
        # Store a string, not the exception object. The result is bound into log
        # records next to plain strings, and a raw exception may not be
        # serializable by the log backend.
        return {'representation failed': repr(e)}
