# coding: utf-8

import logging

from .common import fetch_as_objects

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


class QueryHandlerContractViolation(Exception):
    """Base error for a query handler, query or fetcher breaking its contract.

    Also raised directly by ``sync_fetch_as_list`` when the cursor rows do not
    have exactly one column.
    """


class ExpectOneItemError(QueryHandlerContractViolation):
    """Raised by ``FetchHeadExpectOneRowAs`` when a cursor yields != 1 row."""


class QueryRequireDifferentArgumentsError(QueryHandlerContractViolation):
    """Raised when the supplied argument names differ from ``query.args``."""


class FetchAs(object):
    """Fetcher that materializes all rows of a cursor via ``fetch_as_objects``.

    Calling an instance with a cursor returns a ``list`` built from
    ``fetch_as_objects(cur, Result)``.
    """

    def __init__(self, Result):
        # Passed through unchanged to fetch_as_objects.
        self.Result = Result

    def __call__(self, cur):
        # A concrete list is returned on purpose: tests/steps cannot yet
        # consume a generator, so the result is materialized here.
        rows = fetch_as_objects(cur, self.Result)
        return list(rows)


class LazyFetchAs(object):
    """Fetcher that returns ``fetch_as_objects(cur, Result)`` unmaterialized.

    Unlike ``FetchAs``, the result is NOT wrapped in ``list()``. Assuming
    ``fetch_as_objects`` returns a lazy iterator (TODO confirm in ``.common``),
    rows are produced as the result is iterated. The cursor therefore
    presumably has to stay usable until the caller has consumed it.
    """

    def __init__(self, Result):
        # Passed through unchanged to fetch_as_objects.
        self.Result = Result

    def __call__(self, cur):
        # Deliberately not wrapped in list(): this is the lazy counterpart
        # of FetchAs.
        return fetch_as_objects(cur, self.Result)


class FetchHeadExpectOneRowAs(object):
    """Fetcher returning the single object built from a one-row cursor.

    All rows are converted with ``fetch_as_objects``. If their number is not
    exactly one, ``ExpectOneItemError`` is raised, listing what was fetched.
    """

    def __init__(self, Result):
        # Passed through unchanged to fetch_as_objects.
        self.Result = Result

    def __call__(self, cur):
        rows = list(fetch_as_objects(cur, self.Result))
        count = len(rows)
        if count == 1:
            return rows[0]
        raise ExpectOneItemError(
            'Expect one item while fetching %r, got %d '
            'items: %r' % (self.Result, count, rows)
        )


def sync_fetch_as_list(cur):
    """Return the first (and only) column of every row in ``cur`` as a list.

    Args:
        cur: an executed DB-API cursor; it is iterated to read the rows.

    Returns:
        ``[row[0] for row in cur]``.

    Raises:
        QueryHandlerContractViolation: if ``cur.description`` is ``None``
            (the statement produced no result set) or does not describe
            exactly one column.
    """
    description = cur.description
    # Per DB-API, description is None for statements that return no rows;
    # len(None) would otherwise surface as an unrelated TypeError.
    if description is None or len(description) != 1:
        # Wrap description in a 1-tuple: it is itself a tuple, and passing it
        # bare to % would be treated as the argument list (TypeError for any
        # column count other than one).
        raise QueryHandlerContractViolation(
            'sync_fetch_as_list expect cursor '
            'with one item in row, got %r' % (description,))
    return [row[0] for row in cur]


class QueryHandler(object):
    """Run named queries registered in ``handlers`` against ``conn``.

    ``handlers`` maps a handler name to a ``(query, fetcher)`` pair:

    * ``query`` exposes ``.query`` (passed to ``cursor.execute``) and
      ``.args`` (the exact set of keyword argument names it requires);
    * ``fetcher`` is a callable receiving the executed cursor and returning
      the handler's result.

    A handler is reached by attribute access (``qh.some_name(**args)``, eager)
    or explicitly via :meth:`eager` / :meth:`lazy`.
    """

    # Subclasses are expected to provide their own mapping.
    # NOTE(review): being a class-level dict, it is shared by every subclass
    # that does not override it.
    handlers = {}

    def __init__(self, conn):
        self.conn = conn

    @staticmethod
    def _join_qargs(query, qargs):
        """Return ``qargs`` (``{}`` if empty) after validating its keys.

        Raises:
            QueryRequireDifferentArgumentsError: if the argument names do not
                match ``query.args`` exactly.
        """
        qargs = qargs or {}
        if set(qargs) != set(query.args):
            raise QueryRequireDifferentArgumentsError(
                '{0} require different arguments: {0.args}, got: {1} '.format(
                    query, qargs))
        return qargs

    def _execute_query(self, query, qargs, cur):
        """Execute ``query.query`` with ``qargs`` on ``cur`` and return ``cur``.

        For connections with ``async_`` set, ``self.conn.wait()`` is called
        afterwards (presumably to wait for the asynchronous execute to finish).
        """
        cur.execute(query.query, qargs)
        if self.conn.async_:
            self.conn.wait()
        return cur

    def __getattr__(self, name):
        """Resolve unknown attributes as eager handlers (AttributeError if unknown)."""
        # Non-lazy by default, because lazy requires server-side cursor which itself has more requirements to use
        return self.eager(name)

    def eager(self, name):
        """Return handler ``name`` bound to a regular (unnamed) cursor."""
        return self._get_handler(name, lazy=False)

    def lazy(self, name):
        """Return handler ``name`` bound to a named (server-side) cursor."""
        return self._get_handler(name, lazy=True)

    def _get_handler(self, name, lazy=False):
        """Build a callable that runs handler ``name`` with keyword query args.

        Raises:
            AttributeError: immediately, if ``name`` is not in ``handlers``,
                so unknown attributes still behave like missing attributes.

        Any exception raised while running the handler is logged and
        re-raised. When its first argument is a string, the handler name and
        query are prepended to it.
        """
        if name not in self.handlers:
            raise AttributeError('unknown handler %s' % name)

        def handler_call(**qargs):
            # Bound before the try so the except clause can always use it,
            # even when unpacking the handler entry itself fails.
            query = None
            try:
                query, fetcher = self.handlers[name]
                qargs = self._join_qargs(query, qargs)
                if lazy:
                    # Named cursor; its name is derived from the handler name
                    # and the 'uid' argument (literal 'uid' when absent).
                    uid = qargs.get('uid', 'uid')
                    cur = self.conn.cursor('cursor_{}_{}'.format(name, uid))
                else:
                    cur = self.conn.cursor()
                return fetcher(self._execute_query(query, qargs, cur))
            except Exception as exc:
                prefix_message = '[handler: %s, query %r] ' % (name, query)
                log.warning('Got %s %s', exc, prefix_message)
                # Only prefix a string message: args[0] may be e.g. an errno
                # int, and str + int would raise TypeError here, masking the
                # original exception.
                if exc.args and isinstance(exc.args[0], str):
                    exc.args = (prefix_message + exc.args[0],) + exc.args[1:]
                raise
        return handler_call
