# -*- coding: utf-8 -*-

import logging
from threading import Lock
from at.common.utils import get_connection

_log = logging.getLogger(__name__)


def MagicTable(table, fields, constructor=lambda x: x, quiet=True):
    """Build a lazily-loaded, read-only lookup class over a database table.

    The returned class runs ``SELECT <fields> FROM <table>`` the first time
    any instance touches ``rows`` and caches the result on the class, so all
    instances share one copy. Each row is kept both as a plain
    ``{field: value}`` dict and as ``constructor(that_dict)``.

    For every field ``FIELD`` the class gets:

    * ``by_FIELD(value)`` -> the constructed object of the first row whose
      ``FIELD == value``. If no row matches, returns
      ``constructor({every field: None})`` when ``quiet`` is true, otherwise
      raises ``IndexError``.
    * ``filter_FIELD(value)`` -> list of the constructed objects of all rows
      whose ``FIELD == value``.
    * ``all_FIELDs()`` -> list of the raw ``FIELD`` values of every row.

    Iterating an instance yields the constructed object of every row.

    :param table: table name, interpolated verbatim into the SQL -- must come
        from trusted code, never from user input.
    :param fields: comma-separated column names (``"id, name"``) or an
        iterable of names; also interpolated verbatim into the SQL.
    :param constructor: callable turning a row dict into the object returned
        by the lookups; identity by default.
    :param quiet: whether ``by_FIELD`` returns a placeholder instead of
        raising ``IndexError`` when nothing matches.
    :raises ValueError: if no field names are given.
    """
    if isinstance(fields, str):
        fields = fields.split(',')
    # Strip whitespace and drop empty entries (e.g. from a trailing comma),
    # which would otherwise produce invalid SQL such as "SELECT a, FROM t".
    fields = [f.strip() for f in fields if f.strip()]
    if not fields:
        raise ValueError('MagicTable(%r): no fields given' % (table,))

    class S(object):
        # Cached list of (row_dict, constructed_object) pairs. It is written
        # to the class, so every instance shares the same loaded rows.
        _rows = None
        # Serializes the one-time load so concurrent first accesses query once.
        lock = Lock()

        def __iter__(self):
            return (obj for _, obj in self.rows)

        def _getrows(self):
            # Cheap check without the lock once rows are cached; re-checked
            # under the lock so only one thread actually runs the query.
            if self._rows is not None:
                return self._rows
            with self.lock:
                if self._rows is not None:
                    return self._rows
                # Delayed import to avoid recursion (an import cycle). The name
                # is unused here; presumably the import matters for its side
                # effects -- NOTE(review): confirm before removing it.
                from . import dbswitch  # noqa: F401
                sql = 'SELECT ' + ','.join(fields) + ' FROM ' + table
                with get_connection() as conn:
                    loaded = []
                    for row in conn.execute(sql):
                        # Store both the plain dict (used for lookups) and the
                        # constructed object (returned to callers). This costs
                        # some memory, but MagicTable is meant for very small
                        # tables.
                        row_dict = dict(zip(fields, row))
                        loaded.append((row_dict, constructor(row_dict)))
                    self.__class__._rows = loaded
            return self._rows

        rows = property(_getrows)

    def _install(name, func):
        # Give each generated method its public name for tracebacks and help().
        func.__name__ = name
        setattr(S, name, func)

    def _add_lookups(field):
        # A separate function so each closure captures its own `field` rather
        # than the late-bound loop variable.
        def by(self, value):
            for row_dict, obj in self.rows:
                if row_dict[field] == value:
                    return obj
            if quiet:
                return constructor(dict.fromkeys(fields))
            raise IndexError('%s.by_%s(%r): no matching row'
                             % (table, field, value))

        def filter_(self, value):
            return [obj for row_dict, obj in self.rows
                    if row_dict[field] == value]

        def all_(self):
            return [row_dict[field] for row_dict, _ in self.rows]

        _install('by_' + field, by)
        _install('filter_' + field, filter_)
        _install('all_' + field + 's', all_)

    for field in fields:
        _add_lookups(field)
    return S



