# coding: utf-8

import time
from contextlib import contextmanager
import logging

import psycopg2.extensions as PGE

from pymail.log_helpers import LOG_SQL, LOG_SQL_NOTICE

# Logger for this module's own diagnostics (SQL statements go to the 'sql'
# logger used by LoggingConnection instead).
log = logging.getLogger(name=__name__)


class LoggingCursor(PGE.cursor):
    """Cursor that reports every statement it runs to its connection.

    Each operation stores its start time in ``self.timestamp`` and then,
    whether the operation returns or raises, passes a description of the
    statement plus this cursor to ``self.connection.write_log``.  The
    connection must provide ``write_log`` (LoggingConnection does).
    """
    # pylint: disable=redefined-builtin, attribute-defined-outside-init

    def _run_logged(self, operation, describe):
        """Call ``operation()`` with timing, then log ``describe()``.

        *describe* is evaluated only after *operation* has finished (or
        raised), so it may read state the operation updates, such as
        ``self.query``.
        """
        self.timestamp = time.time()
        try:
            return operation()
        finally:
            self.connection.write_log(describe(), self)

    def execute(self, query, vars=None):
        return self._run_logged(
            lambda: super(LoggingCursor, self).execute(query, vars),
            lambda: self.query,
        )

    def executemany(self, query, vars=None):
        return self._run_logged(
            lambda: super(LoggingCursor, self).executemany(query, vars),
            lambda: self.query,
        )

    def callproc(self, procname, vars=None):
        return self._run_logged(
            lambda: super(LoggingCursor, self).callproc(procname, vars),
            lambda: self.query,
        )

    def copy_from(self, file, table, *args, **kwargs):
        return self._run_logged(
            lambda: super(LoggingCursor, self).copy_from(
                file, table, *args, **kwargs),
            lambda: "COPY %s FROM %r " % (table, file),
        )

    def copy_to(self, file, table, *args, **kwargs):
        return self._run_logged(
            lambda: super(LoggingCursor, self).copy_to(
                file, table, *args, **kwargs),
            lambda: "COPY %s TO %r " % (table, file),
        )

    def copy_expert(self, sql, file, *args, **kwargs):
        return self._run_logged(
            lambda: super(LoggingCursor, self).copy_expert(
                sql, file, *args, **kwargs),
            lambda: sql,
        )


class LoggingConnection(PGE.connection):
    """psycopg2 connection that logs SQL statements and server notices.

    Cursors created through :meth:`cursor` default to LoggingCursor, which
    reports each statement via :meth:`write_log`.  Server notices collected
    in ``self.notices`` are flushed to the log before each statement is
    logged and when the connection is closed.  All records are emitted on
    the ``sql`` logger.
    """

    sql_logger = logging.getLogger('sql')

    # Maximum length, in characters, of a logged statement; longer statements
    # are cut and suffixed with '...'.  None disables truncation.
    SQL_LINE_LIMIT = None

    def _do_log(self, level, message, **extra):
        """Emit *message* at *level* on the SQL logger unless muted.

        Keyword arguments are attached to the log record via ``extra``.
        Nothing is logged while the ``skip_logs`` attribute is truthy
        (``unlogged`` sets it).
        """
        if getattr(self, 'skip_logs', False):
            return
        self.sql_logger.log(level, message, extra=extra)

    def _print_notices(self):
        """Drain ``self.notices`` into the log (LOG_SQL_NOTICE), oldest first."""
        while self.notices:
            self._do_log(LOG_SQL_NOTICE, self.notices.pop(0).strip())

    def write_log(self, msg, curs):
        """Log a statement executed through *curs*, with timing information.

        :param msg: statement text.  ``bytes`` (psycopg2 exposes
            ``cursor.query`` as bytes) is decoded as UTF-8, replacing
            undecodable bytes.  Any other non-text object is converted to
            the native text type.
        :param curs: the cursor that ran the statement.  Its ``timestamp``
            attribute (set by LoggingCursor before running) is used to
            compute ``execution_time_ms``.

        Pending server notices are flushed first.  The message is collapsed
        onto one line and, if SQL_LINE_LIMIT is set, truncated to that many
        characters plus '...'.
        """
        self._print_notices()
        # ``unicode`` on Python 2, ``str`` on Python 3.
        text_type = type(u'')
        if isinstance(msg, bytes):
            # Never raise while decoding: this runs inside the cursor
            # methods' ``finally`` clauses, where an exception would hide
            # the database error the caller needs to see.
            msg = msg.decode('utf-8', 'replace')
        elif not isinstance(msg, text_type):
            msg = text_type(msg)  # pylint: disable=redefined-variable-type
        msg = u' '.join(line.strip() for line in msg.split(u'\n'))
        limit = self.SQL_LINE_LIMIT
        if limit is not None and len(msg) > limit:
            msg = msg[:limit] + u'...'

        # Only ask for the backend PID while the connection is still open.
        backend_pid = None if self.closed else self.get_backend_pid()
        self._do_log(
            LOG_SQL,
            msg,
            execution_time_ms=int((time.time() - curs.timestamp) * 1000),
            pg_backend_pid=backend_pid,
        )

    def cursor(self, *args, **kwargs):
        """Create a cursor, using LoggingCursor unless a factory is given."""
        kwargs.setdefault('cursor_factory', LoggingCursor)
        return super(LoggingConnection, self).cursor(*args, **kwargs)

    def close(self):
        """Flush pending server notices to the log, then close."""
        self._print_notices()
        return super(LoggingConnection, self).close()


@contextmanager
def unlogged(conn):
    """Suppress SQL logging on *conn* for the duration of a ``with`` block.

    Sets ``conn.skip_logs`` to True (LoggingConnection skips emitting
    records while it is truthy) and yields *conn*.  On exit, including exit
    by exception, the previous value is restored instead of being forced to
    False.  Nested ``unlogged`` blocks therefore do not re-enable logging
    while an outer block is still active.
    """
    previous = getattr(conn, 'skip_logs', False)
    conn.skip_logs = True
    try:
        yield conn
    finally:
        conn.skip_logs = previous


# logging.Formatter format strings: timestamp, level name padded to 8
# characters, source file:line, then the message.
# NOTE(review): there is no separator after the padded level name, so an
# 8-character level name such as CRITICAL runs straight into the filename.
DEFAULT_FMT = "%(asctime)s %(levelname)-8s%(filename)s:%(lineno)d: %(message)s"

# Same layout with a %(log_color)s field before the level name.  This
# presumably targets a colorizing formatter that supplies ``log_color``;
# confirm with the code that installs it.
COLORED_FMT = (
    "%(asctime)s %(log_color)s%(levelname)-8s"
    "%(filename)s:%(lineno)d: %(message)s"
)
