# coding: utf-8

import logging

from at.common import dbswitch
from at.common import utils

logger = logging.getLogger(__name__)


class NumericPager(object):
    """Offset-based pager.

    ``before`` is the offset of the next page and is set only when
    ``has_more`` is truthy. ``after`` is the offset of the previous page,
    clamped at zero, and is set only when the current offset ``tb`` is
    truthy. Both stay ``None`` otherwise.
    """

    before = None
    after = None

    def __init__(self, tb, limit, has_more):
        offset = tb
        self.has_more = has_more

        if self.has_more:
            self.before = limit + offset

        if offset:
            previous = offset - limit
            self.after = previous if previous > 0 else 0


class TimestampPager(object):
    """One-directional pager keyed by a post timestamp.

    ``before`` receives ``last_timestamp`` when ``has_more`` is truthy and
    stays ``None`` otherwise.
    """

    before = None

    # Never set: this pager only moves in one direction, but the attribute
    # is kept so the interface matches NumericPager.
    after = None

    def __init__(self, last_timestamp, has_more):
        self.has_more = has_more
        if not self.has_more:
            return
        self.before = last_timestamp


class PostsLoader(object):
    """Base loader that fetches post rows with a hand-assembled SQL query.

    The query is built from the ``get_*_lines`` hooks, which subclasses
    override (usually together with ``select_fields``). Rows can be
    post-processed in ``post_handle_rows``. Every keyword argument given to
    the constructor is stored as an instance attribute, so per-request
    settings such as ``limit``, ``offset``, ``order``, ``tb`` or
    ``pager_type`` can be supplied that way.
    """

    # 'numeric', 'timestamp' or None (no pager).
    pager_type = None
    # Filled in by ``posts_data`` once the rows are loaded.
    _has_more = None
    _posts_data = None

    select_fields = ('Posts.person_id', 'Posts.post_no')
    limit = None
    offset = None
    # Column for ``ORDER BY ... DESC``. None (the default) means no ORDER BY
    # clause; without this default get_order_lines raised AttributeError.
    order = None

    # FROM-clause fragment for subclasses: joins the viewer's memberships in
    # the author's friend groups whose id is >= the post's access_group and
    # < 100. Expects a ``viewer_id`` query argument.
    ACCESS_JOIN = """
        LEFT JOIN FriendGroupMember fgm ON (
            fgm.uid = %(viewer_id)s
            AND fgm.person_id = Posts.person_id
            AND fgm.fgroup_id >= Posts.access_group
            AND fgm.fgroup_id < 100  -- special accesses
        )
    """
    # WHERE-clause fragment (starts with AND) meant to be used with
    # ACCESS_JOIN. It keeps non-deleted posts that the viewer authored, that
    # have access_group = 0, or that matched a row in ACCESS_JOIN.
    ACCESS_CONDITION = """
        AND Posts.deleted = 0
        AND (
            Posts.author_uid = %(viewer_id)s
            OR Posts.access_group = 0
            OR fgm.fgroup_id is NOT NULL
        )
    """

    # Applied in __init__ to names the instance does not have yet. Callables
    # are invoked to produce the value. Because ``limit`` is already a class
    # attribute, only ``tb`` is effectively defaulted by this base class.
    # NOTE(review): utils.usec presumably returns the current time in
    # microseconds -- confirm.
    defaults = {
        'limit': None,
        'tb': utils.usec,
    }

    def __init__(self, viewer_id, **params):
        self.viewer_id = viewer_id

        for key, value in params.items():
            setattr(self, key, value)

        # hasattr() also sees class attributes, so a default is applied only
        # when neither the params nor the class hierarchy provide the name.
        for key in self.defaults:
            if not hasattr(self, key):
                setattr(self, key, self._get_default(key))

    @classmethod
    def _get_default(cls, key):
        """Return ``cls.defaults[key]``, calling it first if it is callable."""
        value = cls.defaults[key]
        if callable(value):
            value = value()
        return value

    @property
    def posts_data(self):
        """Rows for the current page, loaded lazily on first access and cached.

        When a pager is active, ``get_query_limit`` requests one extra row.
        If that row comes back, it is dropped and ``_has_more`` is set to
        True. Otherwise ``_has_more`` is set to False.
        """
        if self._posts_data is None:
            rows = self.post_handle_rows(self.get_post_rows())

            if self.limit and self.pager_type and len(rows) > self.limit:
                self._posts_data = rows[:self.limit]
                self._has_more = True
            else:
                self._posts_data = rows
                self._has_more = False
        return self._posts_data

    def get_post_rows(self):
        """Run ``self.query`` with ``get_query_args()`` and return the rows."""
        return self.select_rows(
            query=self.query,
            query_args=self.get_query_args(),
        )

    @staticmethod
    def select_rows(query, query_args):
        """Execute ``query`` with ``query_args`` and return all rows as a list."""
        # The interpolated query is only used for the log message, so it is
        # built only when debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('Loading posts with query :\n %s', query % query_args)
        with utils.get_connection() as conn:
            rows = conn.execute(query, query_args).fetchall()
        return list(rows)

    def post_handle_rows(self, rows):
        """Hook for subclasses to transform loaded rows; identity by default."""
        return rows

    @property
    def pager(self):
        """Pager for the current page, or None.

        Returns a NumericPager when ``pager_type == 'numeric'`` and a
        TimestampPager when it is ``'timestamp'``. For the timestamp pager,
        None is returned when 'Posts.store_time_usec' is not among
        ``select_fields`` or when the page is empty. Rows are loaded first
        if needed, so ``has_more`` always reflects the actual query.
        """
        if self.pager_type not in ('numeric', 'timestamp'):
            return None

        if self.pager_type == 'numeric':
            # Reading posts_data loads the rows (once) and sets _has_more.
            # Without it, has_more would be the class-level None placeholder.
            self.posts_data
            return NumericPager(
                tb=self.tb,
                limit=self.limit,
                has_more=self._has_more,
            )

        timestamp_field = 'Posts.store_time_usec'
        fields = list(self.select_fields)
        if timestamp_field not in fields:
            return None

        posts = self.posts_data
        if not posts:
            return None

        last_timestamp = posts[-1][fields.index(timestamp_field)]
        return TimestampPager(
            last_timestamp=last_timestamp,
            has_more=self._has_more,
        )

    # Crudely assemble the SQL statement from its pieces.
    @property
    def query(self):
        """Full SQL text built by joining all ``get_*_lines`` fragments."""
        return ' '.join(
            self.get_select_lines() +
            self.get_from_lines() +
            self.get_where_lines() +
            self.get_group_by_lines() +
            self.get_order_lines() +
            self.get_limit_lines() +
            self.get_offset_lines()
        )

    def get_select_lines(self, fields=None):
        """SELECT clause for ``fields``, or ``select_fields`` when not given."""
        return [
            'SELECT %s' % ', '.join(fields or self.select_fields)
        ]

    def get_from_lines(self):
        """FROM/JOIN fragments; empty in the base class."""
        return []

    def get_where_lines(self):
        """WHERE fragments; empty in the base class."""
        return []

    def get_order_lines(self):
        """``ORDER BY <order> DESC`` when ``order`` is set, otherwise nothing."""
        if self.order:
            return ['ORDER BY %s DESC' % self.order]
        return []

    def get_query_limit(self):
        """Number of rows to request from the database.

        If a pager is used, one extra row is requested so that we can tell
        whether more records exist.
        """
        if self.limit and self.pager_type:
            return self.limit + 1
        return self.limit

    def get_limit_lines(self, limit=None):
        """LIMIT clause for ``limit``, or ``get_query_limit()`` when not given."""
        query_limit = limit or self.get_query_limit()
        if query_limit:
            return ['LIMIT %s' % query_limit]
        return []

    def get_offset_lines(self):
        """OFFSET fragments; empty in the base class."""
        return []

    def get_group_by_lines(self):
        """GROUP BY fragments; empty in the base class."""
        return []

    def get_query_args(self):
        """Named parameters passed along with the query.

        Always contains ``viewer_id``. ``limit`` and ``offset`` are added
        when truthy. The base query inlines LIMIT directly, so these are
        presumably for subclasses' fragments -- verify before removing.
        """
        args = {
            'viewer_id': self.viewer_id,
        }
        if self.limit:
            args['limit'] = self.limit
        if self.offset:
            args['offset'] = self.offset
        return args

