import asyncio
import logging
import re
from concurrent.futures import TimeoutError
from contextlib import contextmanager, asynccontextmanager
from time import monotonic

import asyncpg

# Custom numeric logging levels for SQL output.  Both fall between the
# standard INFO (20) and WARNING (30) levels.
LOG_SQL = 25
LOG_SQL_NOTICE = 27
# Level name -> numeric level.  Presumably registered with the logging module
# (e.g. via logging.addLevelName) somewhere else — confirm against callers.
ADDITIONAL_LOG_LEVELS = dict(
    SQL=LOG_SQL,
    SQL_NOTICE=LOG_SQL_NOTICE,
)
# Logger that receives the per-query timing records.
sql_log = logging.getLogger(name='sql_log')


@contextmanager
def profiled(f):
    """Time the body of a ``with`` block and pass the duration to ``f``.

    :param f: callable invoked exactly once, when the block exits, with the
        elapsed time in seconds as measured by :func:`time.monotonic`.

    The callback runs whether the block finishes normally or raises, so
    failed or timed-out operations are reported as well.  If ``f`` returns
    normally, any exception from the block propagates unchanged.
    """
    start = monotonic()
    try:
        yield
    finally:
        # With @contextmanager an exception raised in the block is re-raised
        # at ``yield``; code placed after ``yield`` would then never run.
        f(monotonic() - start)


class LoggingConn(asyncpg.connection.Connection):
    """asyncpg connection that logs the queries it runs, with timings.

    Each wrapped method logs one record on ``sql_log`` at level ``LOG_SQL``
    containing the backend PID, the method name, a query label, the
    arguments and the elapsed time.  The label is the name from a leading
    ``-- name`` comment line when present, otherwise the full query text.
    """

    # Matches a query that starts (after optional whitespace) with a
    # ``-- name`` comment line; group 1 is used as the log label.
    NAME_RE = re.compile(r'^\s*--\s*(\w+)\s*$', flags=re.MULTILINE)

    # Queries containing this text are never logged.  Presumably this is the
    # LISTEN cleanup in asyncpg's connection-reset query — confirm.
    _UNLOGGED_MARKER = 'PERFORM * FROM pg_listening_channels() LIMIT 1'

    def _log(self, method, query, args):
        """Build a :func:`profiled` callback that logs one call.

        :param method: label for the wrapped asyncpg method.
        :param query: SQL text (for COPY, a table description instead).
        :param args: query arguments, included verbatim in the message.
        :returns: a function taking the elapsed seconds that emits the log
            record, or does nothing if ``query`` contains the unlogged marker.
        """
        def impl(elapsed):
            if self._UNLOGGED_MARKER in query:
                return
            match = self.NAME_RE.match(query)
            name = match.group(1) if match else query
            sql_log.log(
                LOG_SQL,
                '%s %s: [%s] with args %s took %0.3f seconds',
                self.get_server_pid(), method, name, args, elapsed,
            )
        return impl

    async def fetch(self, query, *args, timeout=None, record_class=None):
        """Logged :meth:`asyncpg.connection.Connection.fetch`."""
        with profiled(self._log('fetch', query, args)):
            return await super().fetch(
                query, *args, timeout=timeout, record_class=record_class)

    async def fetchrow(self, query, *args, timeout=None, record_class=None):
        """Logged :meth:`asyncpg.connection.Connection.fetchrow`."""
        with profiled(self._log('fetchrow', query, args)):
            return await super().fetchrow(
                query, *args, timeout=timeout, record_class=record_class)

    async def fetchval(self, query, *args, column=0, timeout=None):
        """Logged :meth:`asyncpg.connection.Connection.fetchval`."""
        with profiled(self._log('fetchval', query, args)):
            return await super().fetchval(
                query, *args, column=column, timeout=timeout)

    async def execute(self, query: str, *args, timeout: float = None):
        """Logged :meth:`asyncpg.connection.Connection.execute`."""
        with profiled(self._log('execute', query, args)):
            return await super().execute(query, *args, timeout=timeout)

    async def executemany(self, query: str, args, *, timeout: float = None):
        """Logged :meth:`asyncpg.connection.Connection.executemany`."""
        with profiled(self._log('executemany', query, args)):
            return await super().executemany(query, args, timeout=timeout)

    async def copy_records_to_table(self, table_name, *, records,
                                    columns=None, schema_name=None,
                                    timeout=None):
        """Logged :meth:`asyncpg.connection.Connection.copy_records_to_table`.

        The log label is ``[schema.]table[(col, col, ...)]``; ``records`` is
        logged as the arguments.
        """
        # ``columns`` and ``schema_name`` both default to None, so each is
        # only added to the label when actually given.
        label = f'{schema_name}.{table_name}' if schema_name else table_name
        if columns is not None:
            label = f'{label}({", ".join(columns)})'
        with profiled(self._log('copy records', label, records)):
            return await super().copy_records_to_table(
                table_name,
                records=records,
                columns=columns,
                schema_name=schema_name,
                timeout=timeout,
            )

    @asynccontextmanager
    async def transaction(self, *, isolation=None, readonly=False,
                          deferrable=False):
        """Run the ``async with`` body inside a transaction.

        Commits when the body completes and rolls back when it raises.  On a
        timeout (in start, body, rollback or commit) the connection is
        terminated instead and the timeout is re-raised.  Keyword arguments
        are forwarded to :meth:`asyncpg.connection.Connection.transaction`.
        """
        # ``TimeoutError`` is the module-level concurrent.futures import.  On
        # Python 3.8-3.10 asyncio.TimeoutError (raised by asyncpg) is a
        # distinct class, so both must be caught.
        timeouts = (TimeoutError, asyncio.TimeoutError)
        tx = super().transaction(
            isolation=isolation, readonly=readonly, deferrable=deferrable)
        try:
            # A failed start() leaves nothing to roll back, so it sits
            # outside the rollback handler below.
            await tx.start()
            try:
                yield tx
            except timeouts:
                raise  # no rollback: the outer handler terminates instead
            except BaseException:
                # Also covers task cancellation (BaseException on 3.8+).
                await tx.rollback()
                raise
            # Outside the rollback handler: after a failed commit() asyncpg
            # marks the transaction failed, and a rollback attempt would only
            # raise a second error masking the commit failure.
            await tx.commit()
        except timeouts:
            self.terminate()
            raise
