# coding: utf-8

import time
from contextlib import contextmanager
import logging

import psycopg2.extensions as PGE
from six import binary_type, text_type

# Custom numeric log levels. Both values sit between the standard library's
# INFO (20) and WARNING (30) levels.
LOG_SQL = 25
LOG_NOTICE = 27
# Register display names so formatted records show "SQL" / "NOTICE" as the
# level name instead of "Level 25" / "Level 27".
logging.addLevelName(LOG_SQL, "SQL")
logging.addLevelName(LOG_NOTICE, "NOTICE")

# Module-level logger named after this module.
log = logging.getLogger(__name__)


class LoggingCursor(PGE.cursor):
    """Cursor that reports every statement to its connection's ``write_log``.

    Each overridden method stores the start time in ``self.timestamp``, runs
    the parent implementation and then hands a description of the statement
    to ``self.connection.write_log`` -- also when the parent call raises.
    """
    # pylint: disable=redefined-builtin, attribute-defined-outside-init

    def _timed(self, operation, describe):
        """Run ``operation()`` and log ``describe()`` afterwards.

        Both arguments are zero-argument callables.  ``describe`` is only
        evaluated once the operation has finished (or failed), so it can
        read state the operation updates, such as ``self.query``.
        """
        self.timestamp = time.time()
        try:
            outcome = operation()
        finally:
            self.connection.write_log(describe(), self)
        return outcome

    def execute(self, query, vars=None):
        """Execute ``query`` and log the statement text held in ``self.query``."""
        parent = super(LoggingCursor, self)
        return self._timed(
            lambda: parent.execute(query, vars),
            lambda: self.query,
        )

    def executemany(self, query, vars=None):
        """Run ``executemany`` and log whatever ``self.query`` holds afterwards."""
        parent = super(LoggingCursor, self)
        return self._timed(
            lambda: parent.executemany(query, vars),
            lambda: self.query,
        )

    def callproc(self, procname, vars=None):
        """Call a stored procedure and log the resulting ``self.query``."""
        parent = super(LoggingCursor, self)
        return self._timed(
            lambda: parent.callproc(procname, vars),
            lambda: self.query,
        )

    def copy_from(self, file, table, *args, **kwargs):
        """Run ``copy_from`` and log a synthetic ``COPY ... FROM`` line."""
        parent = super(LoggingCursor, self)
        return self._timed(
            lambda: parent.copy_from(file, table, *args, **kwargs),
            lambda: "COPY %s FROM %r " % (table, file),
        )

    def copy_to(self, file, table, *args, **kwargs):
        """Run ``copy_to`` and log a synthetic ``COPY ... TO`` line."""
        parent = super(LoggingCursor, self)
        return self._timed(
            lambda: parent.copy_to(file, table, *args, **kwargs),
            lambda: "COPY %s TO %r " % (table, file),
        )

    def copy_expert(self, sql, file, *args, **kwargs):
        """Run ``copy_expert`` and log the ``sql`` argument as given."""
        parent = super(LoggingCursor, self)
        return self._timed(
            lambda: parent.copy_expert(sql, file, *args, **kwargs),
            lambda: sql,
        )


class LoggingConnection(PGE.connection):
    """Connection that logs executed statements and server notices.

    Cursors created through :meth:`cursor` default to ``LoggingCursor``,
    which calls :meth:`write_log` after every statement.  Records are sent
    to ``sql_logger`` at the ``LOG_SQL`` / ``LOG_NOTICE`` levels.  Setting a
    truthy ``skip_logs`` attribute on an instance silences all output.
    """

    sql_logger = logging.getLogger('sql')

    # Maximum number of characters of a statement to log; longer statements
    # are cut and suffixed with "...".  ``None`` disables truncation.
    SQL_LINE_LIMIT = None
    # Template for statement records.  ``msg``, ``ms`` and ``pid`` are
    # passed to ``format``; the default template does not use ``pid``.
    LOG_FMT = u"{msg} (execution time: {ms} ms)"

    def _do_log(self, level, message):
        """Emit ``message`` at ``level`` unless ``skip_logs`` is set and truthy."""
        if getattr(self, 'skip_logs', False):
            return
        self.sql_logger.log(level, message)

    def _print_notices(self):
        """Drain ``self.notices`` in order, logging each one stripped."""
        while self.notices:
            self._do_log(
                LOG_NOTICE,
                self.notices.pop(0).strip()
            )

    def write_log(self, msg, curs):
        """Log pending notices, then the statement ``msg`` with its duration.

        :param msg: statement description; bytes are decoded as UTF-8, any
            other non-text object is converted with ``str``.
        :param curs: cursor whose ``timestamp`` attribute marks the start of
            the statement.
        """
        # Take the measurement first so that the time spent writing notice
        # records is not counted as statement execution time.
        ms = (time.time() - curs.timestamp) * 1000

        self._print_notices()

        if isinstance(msg, binary_type):
            # This method runs inside a ``finally`` of the cursor, so it must
            # not raise on undecodable input: a strict decode would mask the
            # real database error or fail an otherwise successful statement
            # when the bytes are not valid UTF-8.
            msg = text_type(msg, 'utf-8', 'replace')
        elif not isinstance(msg, text_type):
            msg = str(msg)  # pylint: disable=redefined-variable-type
        if self.SQL_LINE_LIMIT is not None and len(msg) > self.SQL_LINE_LIMIT:
            msg = msg[:self.SQL_LINE_LIMIT] + u'...'

        # The backend PID is only queried while the connection is open.
        backend_pid = None
        if not self.closed:
            backend_pid = self.get_backend_pid()
        self._do_log(
            LOG_SQL,
            self.LOG_FMT.format(
                msg=msg,
                ms=ms,
                pid=backend_pid
            )
        )

    def cursor(self, *args, **kwargs):
        """Create a cursor, using ``LoggingCursor`` unless a factory is given."""
        kwargs.setdefault('cursor_factory', LoggingCursor)
        return super(LoggingConnection, self).cursor(*args, **kwargs)

    def close(self):
        """Log any notices still queued, then close the connection."""
        self._print_notices()
        return super(LoggingConnection, self).close()


@contextmanager
def unlogged(conn):
    """Temporarily silence logging on ``conn``.

    Sets ``conn.skip_logs`` to ``True`` for the duration of the ``with``
    block and yields ``conn``.  On exit the previous value is restored
    (``False`` when the attribute did not exist), so nested ``unlogged``
    blocks on the same connection stay silent until the outermost one ends.
    """
    previous = getattr(conn, 'skip_logs', False)
    conn.skip_logs = True
    try:
        yield conn
    finally:
        # Restore rather than force False: an enclosing ``unlogged`` block
        # must not have its logging switched back on by an inner one.
        conn.skip_logs = previous

# Record format used by init_logging() when it falls back to a plain
# logging.Formatter and no explicit ``fmt`` is given.
DEFAULT_FMT = "%(asctime)s %(levelname)-8s%(filename)s:" \
              "%(lineno)d: %(message)s"

# Same layout with a ``%(log_color)s`` field before the level name; used by
# init_logging() as the default format for colorlog's ColoredFormatter.
COLORED_FMT = "%(asctime)s %(log_color)s%(levelname)-8s%(filename)s:" \
              "%(lineno)d: %(message)s"


def init_logging(level, fmt=None, colored_fmt=None):
    """Attach a stderr handler to the root logger and set its level.

    A colored formatter is used when stderr is a terminal or the
    ``FORCE_COLORS`` environment variable is set to a non-empty value, and
    the optional ``colorlog`` package can be imported; otherwise a plain
    ``logging.Formatter`` is used.

    :param level: level applied to the root logger.
    :param fmt: format for the plain formatter (defaults to ``DEFAULT_FMT``).
    :param colored_fmt: format for the colored formatter (defaults to
        ``COLORED_FMT``).
    """
    import os
    import sys

    def build_colored():
        # colorlog is optional; report its absence with None.
        try:
            from colorlog import ColoredFormatter
        except ImportError:
            return None

        palette = {
            'SQL':      'cyan',
            'NOTICE':   'blue',
            'INFO':     'green',
            'WARNING':  'yellow',
            'ERROR':    'red',
            'CRITICAL': 'red'
        }
        return ColoredFormatter(colored_fmt or COLORED_FMT, log_colors=palette)

    wants_color = sys.stderr.isatty() or os.environ.get('FORCE_COLORS')
    chosen = build_colored() if wants_color else None
    if chosen is None:
        chosen = logging.Formatter(fmt or DEFAULT_FMT)

    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setFormatter(chosen)

    root = logging.getLogger()
    root.addHandler(stream_handler)
    root.setLevel(level)
