# -*- coding: utf-8 -*-

import logging
from string import Template

from at.common.utils import get_connection

# Module-level logger named after this module (not referenced elsewhere in this file).
_log = logging.getLogger(name=__name__)


class StrictMapper(object):
    """Dict-like access to one person's rows in a single SQL table.

    Every query is restricted to (or inserts) ``person_id``. The SQL uses
    MySQL syntax (backtick-quoted identifiers, ``INSERT IGNORE``).

    Slots that must be populated before any query method is called
    (``__init__`` only sets ``table_name``, ``insert_with``, ``person_id``):

    * ``table_name``    -- table to read/write (``__init__`` sets it to None).
    * ``key_fields``    -- key column names; may be empty.
    * ``key_template``  -- comma-separated ``%``-placeholders, one per key
                           field, used to render key values into SQL.
    * ``data_fields``   -- data column names.
    * ``data_template`` -- comma-separated ``%``-placeholders for the data
                           values in the INSERT statement.
    * ``insert_with``   -- statement prefix for writes (default
                           ``"INSERT IGNORE"``).

    Keys and values are tuples whose items line up with ``key_fields`` and
    ``data_fields`` respectively; a mapper without key fields uses ``()``.

    NOTE(review): key and value items are interpolated into the SQL text with
    Python ``%`` formatting rather than passed as bound parameters, so string
    values are not escaped here. Unless callers guarantee trusted input this
    is an SQL-injection risk; moving to parameterized queries requires the
    paramstyle of the connection returned by ``get_connection`` -- TODO confirm.
    """
    __slots__ = (
        'table_name', 'person_id',
        'data_fields', 'data_template',
        'key_fields', 'key_template', "insert_with"
        )

    def __init__(self, person_id):
        """Bind the mapper to ``person_id``.

        ``data_fields``, ``data_template``, ``key_fields`` and
        ``key_template`` are left unset and must be assigned afterwards
        (e.g. by a subclass ``__init__``).
        """
        self.table_name = None
        self.insert_with = "INSERT IGNORE"
        self.person_id = person_id

    def _get_data_fields_names(self):
        """Return the data column names joined into an SQL column list."""
        return ", ".join(self.data_fields)

    def _get_key_fields_names(self):
        """Return the key column names joined into an SQL column list."""
        return ", ".join(self.key_fields)

    def _get_select_fields_names(self):
        """Return the SELECT column list: key columns first, then data columns.

        Empty groups are skipped, so a mapper without key fields produces just
        the data columns instead of the invalid ``SELECT , ...``.
        """
        groups = (self._get_key_fields_names(), self._get_data_fields_names())
        return ", ".join(group for group in groups if group)

    def _get_where_key(self, key):
        """Render ``key`` as ``(field1=v1 and field2=v2 ...)``.

        Each value is %-formatted with the matching comma-separated
        placeholder from ``key_template``. ``zip`` stops at the shortest
        input, so a key with fewer items than ``key_fields`` constrains only
        the leading fields.
        """
        cond_list = [
            (name + "=" + tmpl) % val
            for name, tmpl, val in zip(self.key_fields, self.key_template.split(','), key)
        ]
        return "(" + " and ".join(cond_list) + ")"

    def _get_where_key2(self, key):
        """Return ``" AND (<key condition>)"``, or ``""`` when ``key`` is falsy.

        A falsy key (e.g. ``()``) adds no condition, so the surrounding query
        matches every row of ``person_id``.
        """
        if key:
            return " AND %s" % self._get_where_key(key)
        return ""

    def _get_data_length(self):
        """Return the number of data columns."""
        return len(self.data_fields)

    def _get_key_length(self):
        """Return the number of key columns (0 when there is no key)."""
        return len(self.key_fields)

    # The queries below are built in two steps: string.Template fills the
    # ``$name`` slots with identifiers and placeholder templates, then Python
    # ``%`` formatting fills the resulting ``%d``/``%s`` placeholders with values.

    def __setitem__(self, key, value):
        """Write ``value`` under ``key`` using the ``insert_with`` statement.

        :param key: tuple of key values, one per ``key_fields`` entry
            (``()`` when there are no key fields).
        :param value: tuple of data values, one per ``data_fields`` entry.
        """
        with get_connection() as conn:
            has_key = self._get_key_length()
            query = Template("""
            $insert_with
                $table_name (`person_id` $key_fields, $data_fields)
            VALUES
                ( %d $key_template, $data_template)
                """).substitute(
                    # Key columns/placeholders are omitted entirely when there
                    # are none, to avoid a dangling comma.
                    key_fields=", %s" % self._get_key_fields_names() if has_key else "",
                    key_template=", %s" % self.key_template if has_key else "",
                    data_fields=self._get_data_fields_names(),
                    data_template=self.data_template,
                    table_name=self.table_name,
                    insert_with=self.insert_with,
                )
            # key and value must be tuples: they are concatenated, in column
            # order, after person_id to fill the %-placeholders.
            args = (self.person_id,) + key + value
            conn.execute(query % args)

    def __delitem__(self, key):
        """Delete the rows of ``person_id`` matching ``key``.

        A falsy ``key`` adds no key condition and so deletes every row of
        ``person_id``; a key shorter than ``key_fields`` deletes every row
        matching its leading fields.
        """
        with get_connection() as conn:
            query = Template("""
            DELETE FROM $table_name
            WHERE `person_id`=%d %s
            """).substitute(table_name=self.table_name)
            conn.execute(query % (self.person_id, self._get_where_key2(key)))

    def delete(self, key):
        """Alias for ``del self[key]`` (dispatches to the subclass ``__delitem__``)."""
        return self.__delitem__(key)

    def __getitem__(self, key):
        """Return the data columns of the first row matching ``key``.

        Returns None when no row matches; unlike a dict, no KeyError is
        raised. A falsy or partial key may match several rows, in which case
        only the first one fetched is returned.
        """
        with get_connection() as conn:
            query = Template("""
            SELECT $data_fields FROM $table_name
            WHERE `person_id`=%d %s
            """).substitute(data_fields=self._get_data_fields_names(), table_name=self.table_name)

            cur = conn.execute(query % (self.person_id, self._get_where_key2(key)))
            row = cur.fetchone()

            return row[:self._get_data_length()] if row is not None else None

    def _row2kv(self, row):
        """Split a ``SELECT key..., data...`` row into ``(key_part, data_part)``."""
        n_keys = self._get_key_length()
        return row[:n_keys], row[n_keys:]

    def bulkget(self, keys):
        """Fetch several keys with a single query.

        :param keys: iterable of key tuples.
        :return: list of ``(key_part, data_part)`` pairs (see ``_row2kv``) for
            the rows found, in the order the database returns them; keys with
            no row are simply absent. ``[]`` when ``keys`` is empty.

        Falsy keys are skipped; if none remain, the key filter becomes
        ``TRUE`` and every row of ``person_id`` is returned.
        """
        if not keys:
            return []
        with get_connection() as conn:
            query = Template("""
            SELECT $select_fields FROM $table_name
            WHERE `person_id`=%d and ( %s )
            """).substitute(
                select_fields=self._get_select_fields_names(),
                table_name=self.table_name,
            )
            where_keys = " or ".join([self._get_where_key(k) for k in keys if k])
            if not where_keys:
                where_keys = "TRUE"
            cur = conn.execute(query % (self.person_id, where_keys))
            return [self._row2kv(row) for row in cur]

    def get_all(self):
        """Return ``(key_part, data_part)`` pairs for every row of ``person_id``.

        TODO: no LIMIT is applied, so the whole result set is materialised.
        """
        with get_connection() as conn:
            query = Template("""
            SELECT $select_fields
            FROM $table_name
            WHERE `person_id` = %d""").substitute(
                select_fields=self._get_select_fields_names(),
                table_name=self.table_name,
            )
            cur = conn.execute(query % self.person_id)
            return [self._row2kv(row) for row in cur]


class IntKeyMapper(StrictMapper):
    """StrictMapper whose key is a single integer column named ``key``.

    The public accessors take a bare scalar key and wrap it into a
    one-element tuple before delegating to the StrictMapper implementation;
    ``bulkget`` unwraps it again in the pairs it returns. ``get_all`` is
    inherited unchanged, so its pairs keep the one-element key part sliced
    from each row.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Single key column, rendered into SQL with the %d placeholder.
        self.key_template = '%d'
        self.key_fields = ['`key`']

    def __getitem__(self, key):
        single = (key,)
        return StrictMapper.__getitem__(self, single)

    def __setitem__(self, key, value):
        single = (key,)
        return StrictMapper.__setitem__(self, single, value)

    def __delitem__(self, key):
        single = (key,)
        return StrictMapper.__delitem__(self, single)

    def bulkget(self, keys):
        """Like StrictMapper.bulkget, but keys in and out are bare scalars."""
        wrapped = [(one,) for one in keys]
        found = StrictMapper.bulkget(self, wrapped)
        return [(key_part[0], data) for key_part, data in found]



