import cProfile
import io
import json
import logging
import pstats

import six
from django.conf import settings
from django.db import connections
from django.utils.encoding import force_text
from monotonic import monotonic
from ylog.context import log_context

logger = logging.getLogger(__name__)


# @todo move this module into the shared libraries (django_tools_log_context)


def wrap_cursor(connection):
    """Replace ``connection.cursor`` with a factory producing logging cursors.

    The original bound method is stashed on ``connection._orig_cursor``; its
    presence marks the connection as already instrumented, so calling this
    more than once is a no-op.
    """
    if hasattr(connection, '_orig_cursor'):
        return

    connection._orig_cursor = connection.cursor

    def cursor():
        # Look the original up at call time, exactly like the stashed attribute.
        return WikiCursorWrapper(connection._orig_cursor(), connection)

    connection.cursor = cursor


def unwrap_cursor(connection):
    """Undo ``wrap_cursor``: drop the instance-level ``cursor`` override.

    Deleting the instance attribute re-exposes the class-level ``cursor``
    method. Connections that were never wrapped are left untouched.
    """
    if not hasattr(connection, '_orig_cursor'):
        return
    del connection._orig_cursor
    del connection.cursor


def enable_instrumentation():
    """Install the logging cursor wrapper on every connection in ``connections.all()``."""
    all_connections = connections.all()
    for db_connection in all_connections:
        wrap_cursor(db_connection)


def disable_instrumentation():
    """Remove the logging cursor wrapper from every connection in ``connections.all()``."""
    all_connections = connections.all()
    for db_connection in all_connections:
        unwrap_cursor(db_connection)


class WikiCursorWrapper(object):
    """Cursor proxy that logs every executed statement with timing data.

    ``execute``/``executemany`` are timed. Once the statement finishes,
    successfully or not, one log record is emitted inside a ``log_context``
    carrying ``execution_time`` and a ``profiling`` dict. Every other
    attribute access is forwarded to the wrapped cursor.
    """

    def __init__(self, cursor, db):
        # cursor: the raw cursor produced by the connection's original cursor().
        # db: the connection object it came from; its .connection, .vendor and
        # .ops are read when logging.
        self.cursor = cursor
        self.db = db
        self.logger = logging.getLogger(__name__)

    def _quote_expr(self, element):
        """Render one query parameter as an SQL-like literal for logging."""
        if isinstance(element, six.string_types):
            # Wrap in single quotes and double any embedded single quotes.
            return "'%s'" % force_text(element, errors='ignore').replace("'", "''")
        else:
            return repr(element)

    def _quote_params(self, params):
        """Quote every value of a parameter sequence or mapping.

        Falsy ``params`` (``None``, empty list/dict) are returned unchanged;
        a dict keeps its keys; any other iterable becomes a list.
        """
        if not params:
            return params
        if isinstance(params, dict):
            return {key: self._quote_expr(value) for key, value in params.items()}
        return list(map(self._quote_expr, params))

    def _record(self, method, raw_sql, params):
        """Call ``method(raw_sql, params)``, then log the query.

        The query's return value or exception always propagates to the
        caller. Logging is best-effort: if building the log record fails, the
        failure is logged and swallowed, so it never replaces the query's own
        outcome.
        """
        start_time = monotonic()
        try:
            return method(raw_sql, params)
        finally:
            duration = (monotonic() - start_time) * 1000
            try:
                self._log_query(raw_sql, params, duration)
            except Exception:
                self.logger.exception('Failed to log SQL query')

    def _log_query(self, raw_sql, params, duration):
        """Emit one log record describing a finished query.

        :param raw_sql: the SQL exactly as passed to execute/executemany.
        :param params: the parameters passed alongside it.
        :param duration: elapsed time in milliseconds.
        """
        conn = self.db.connection
        vendor = getattr(self.db, 'vendor', 'unknown')

        # Normalise to text first: on Python 3, bytes SQL would otherwise make
        # startswith('select') raise TypeError.
        is_select = force_text(raw_sql, errors='ignore').lower().strip().startswith('select')

        if is_select or settings.TOOLS_LOG_CONTEXT_ALWAYS_SUBSTITUTE_SQL_PARAMS:
            # Let the backend render the statement with the (quoted) parameters
            # inlined, so the log line is readable.
            sql = self.db.ops.last_executed_query(self.cursor, raw_sql, self._quote_params(params))
        else:
            sql = raw_sql

        profiling = {
            'is_select': is_select,
            'is_slow': duration > settings.TOOLS_LOG_CONTEXT_SQL_WARNING_THRESHOLD,
            # NOTE(review): a fixed category label, not the DB vendor computed
            # above -- confirm this is intended.
            'vendor': 'database',
            'query_to_analyse': raw_sql,
        }

        if vendor == 'postgresql':
            # If an erroneous query was run on the connection, it might be in
            # a state where checking isolation_level raises an exception.
            try:
                iso_level = conn.isolation_level
            except conn.InternalError:
                iso_level = 'unknown'
            profiling.update(
                {
                    'encoding': conn.encoding,
                    'iso_level': iso_level,
                    'trans_status': conn.get_transaction_status(),
                }
            )

        with log_context(execution_time=int(duration), profiling=profiling):
            self.logger.info('(%.3f msec) %s', duration, sql)

    def execute(self, sql, params=None):
        """Execute a single statement through the wrapped cursor, with logging."""
        return self._record(self.cursor.execute, sql, params)

    def executemany(self, sql, param_list):
        """Execute a statement for each parameter set, with logging."""
        return self._record(self.cursor.executemany, sql, param_list)

    def __getattr__(self, attr):
        # Only reached for attributes not found on the wrapper: delegate.
        return getattr(self.cursor, attr)

    def __iter__(self):
        return iter(self.cursor)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close the underlying cursor (forwarded via __getattr__).
        self.close()


class ProfileCtx:
    """Context manager that profiles a request carrying ``?__profile_me``.

    When active, it runs cProfile and installs the SQL-logging cursor wrapper
    for the duration of the block. On exit it prints the top 50 entries,
    sorted by cumulative time, to stdout as a JSON list of text lines.

    NOTE(review): activation depends solely on a client-supplied query
    parameter, so any caller can switch profiling on. Consider gating it,
    e.g. on DEBUG or staff users.
    """

    def __init__(self, request):
        # Any value of the parameter (even empty) activates profiling.
        self.active = request.GET.get('__profile_me') is not None
        self.prof = None

    def __enter__(self):
        if self.active:
            self.prof = cProfile.Profile()
            self.prof.enable()
            logger.info('Profiling enabled')
            enable_instrumentation()
        # Returning self makes ``with ProfileCtx(request) as ctx:`` useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.active:
            return
        try:
            disable_instrumentation()
        finally:
            # Stop the profiler even if removing the instrumentation failed.
            self.prof.disable()
        # six.StringIO accepts what pstats writes on both Python 2 (byte str)
        # and Python 3 (text); io.StringIO rejects byte str on Python 2.
        output = six.StringIO()
        stats = pstats.Stats(self.prof, stream=output)
        stats.sort_stats('cumulative')
        stats.print_stats(50)
        data = output.getvalue().split('\n')
        print(json.dumps(data))


class RequestProfileMiddleware:
    """New-style Django middleware running each request under ``ProfileCtx``.

    Profiling is only active when the request carries ``?__profile_me``;
    otherwise the request passes straight through.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        profile_ctx = ProfileCtx(request)
        with profile_ctx:
            # __exit__ (and hence the stats dump) runs before the response is returned.
            return self.get_response(request)
