# -*- coding: utf-8 -*-


import logging
import re
import traceback
import math

import lxml.etree as ET

from at.common import Types
from at.common.utils import (
    log_exception, stopwatch, et2xml, get_connection
)
from at.common import dbswitch

# Module-level logger; code in this file that uses `_log` logs under the
# name 'aux/items'.
_log = logging.getLogger('aux/items')


def split(s):
    """Split a string on runs of commas/semicolons.

    Leading and trailing whitespace of *s* is stripped first. Any run of
    ``,`` / ``;`` separators (with optional whitespace around each of them)
    counts as a single delimiter.

    :param s: the string to split.
    :return: list of the pieces between delimiters; an empty or
        whitespace-only input yields ``['']``.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence and
    # triggers a DeprecationWarning/SyntaxWarning on modern Python.
    return re.split(r'(?:\s*[,;]\s*)+', s.strip())


def getTagCountScaleFunc(min, max, low, top):
    """Build a function mapping the interval [min, max] onto [low, top].

    The returned callable interpolates linearly (``min`` maps to ``low``,
    ``max`` maps to ``top``) and rounds the result to an int.

    :param min: lower bound of the input interval.
    :param max: upper bound of the input interval.
    :param low: output produced for ``min``.
    :param top: output produced for ``max``.
    :return: a one-argument function returning an int.
    """
    span = float(max - min)

    if span != 0.0:
        def interpolate(x):
            weighted = (x - min) * top + (max - x) * low
            return int(round(weighted / span))

        return interpolate

    # Zero-width interval: there is nothing to interpolate, so every input
    # maps to the rounded upper bound (the argument is ignored).
    def constant(x):
        return int(round(max))

    return constant


def scaleTagCount(countIdPairs, low=1, top=5):
    """Rescale raw tag usage counts to the range [low, top] on a log scale.

    Each count is replaced by a value obtained from
    ``getTagCountScaleFunc`` applied to the natural logarithm of the count.
    When the smallest count is below 1, every count is shifted by 1 before
    taking the logarithm (presumably to avoid ``log(0)``); negative counts
    are not handled and still raise ``ValueError`` from ``math.log``.

    :param countIdPairs: sequence of ``(count, tag_id)`` pairs.
    :param low: scaled value for the smallest count.
    :param top: scaled value for the largest count.
    :return: list of ``(scaled_count, tag_id)`` in the input order; an empty
        input yields an empty list.
    """
    # min()/max() below raise ValueError on an empty sequence.
    if not countIdPairs:
        return []

    counts = [count for count, _ in countIdPairs]
    bias = 1 if min(counts) < 1 else 0
    countLogs = [math.log(count + bias) for count in counts]
    scaleFunc = getTagCountScaleFunc(min(countLogs), max(countLogs), low, top)
    return [(scaleFunc(countLog), tag_id)
            for countLog, (_, tag_id) in zip(countLogs, countIdPairs)]


def appendTagCountsToElement(countIdPairs, parentElement):
    """Append one <tag> child per (count, tag_id) pair to *parentElement*.

    Every <tag> gets two children, in this order: <id> holding the tag id
    and <count> holding the count, both converted with ``str``.

    :param countIdPairs: iterable of ``(count, tag_id)`` pairs.
    :param parentElement: element that receives the new <tag> children.
    """
    for pair in countIdPairs:
        tag_count, tag_id = pair
        tag_node = ET.SubElement(parentElement, 'tag')
        id_node = ET.SubElement(tag_node, 'id')
        id_node.text = str(tag_id)
        count_node = ET.SubElement(tag_node, 'count')
        count_node.text = str(tag_count)


@stopwatch
def buildTagListElement(tagRows):
    """Build a <tag-list> element describing the given tag rows.

    Each row is unpacked as ``(feed_id, id, title)``. Only the id and the
    title are written, as the <id> and <title-tag> children of a <tag>
    element; the feed id is ignored.

    :param tagRows: iterable of 3-item rows.
    :return: the new <tag-list> element.
    """
    tag_list = ET.Element('tag-list')
    for row in tagRows:
        _feed_id, tag_id, tag_title = row
        tag_node = ET.SubElement(tag_list, 'tag')
        id_node = ET.SubElement(tag_node, 'id')
        id_node.text = str(tag_id)
        title_node = ET.SubElement(tag_node, 'title-tag')
        title_node.text = str(tag_title)
    return tag_list


class Tags(object):
    """Queries and XML builders for a user's post tags.

    Tag definitions are read from the ``PostCategory`` table and the
    post-to-tag links from ``PostCategories``. Every method that touches the
    database opens its own connection through ``get_connection()``.
    """

    def getTags(self, uid, count=None, page=None):
        """Return the tags attached to a user's visible posts.

        Only posts with ``deleted = 0`` that are not on moderation are
        considered; rows are ordered by tag title.

        :param uid: person id owning the posts and the tags.
        :param count: page size; a falsy value disables paging.
        :param page: zero-based page index, used only together with *count*.
        :return: ``(rows, total)`` where each row is
            ``(feed_id, cat_id, title_tag)`` and *total* is the result of
            ``SELECT FOUND_ROWS()`` (``None`` when paging is disabled).
        """
        with get_connection() as con:
            sql = 'SELECT ' + ('SQL_CALC_FOUND_ROWS' if count else '') + '''
                    pc.feed_id, pc.cat_id, pp.title_tag
                from Posts p, PostCategories pc, PostCategory pp
                where p.person_id = %s and p.deleted = 0 and not p.on_moderation and
                    pc.post_no = p.post_no and pc.feed_id = p.person_id and
                    pp.id = pc.cat_id
                    and pp.person_id = p.person_id
                group by pc.cat_id
                order by pp.title_tag
                '''
            params = (uid,)
            if count:
                if page:
                    sql += ' LIMIT %s, %s'
                    params = params + (count * page, count)
                else:
                    sql += ' LIMIT %s'
                    params = params + (count,)
            # Lazy logging arguments: the message is only formatted when
            # debug output is actually enabled.
            _log.debug("sql=%r, params=%r", sql, params)
            rows = con.execute(sql, params).fetchall()
            # FOUND_ROWS() is read on the same connection as the query above.
            total = con.scalar("SELECT FOUND_ROWS()") if count else None
            return rows, total

    def getTagByName(self, uid, tag_name):
        """Return the id of the user's tag with the given text, or None."""
        with get_connection() as con:
            res = con.execute(
                'SELECT id from PostCategory WHERE person_id = %s AND title_tag = %s',
                (uid, tag_name))
            one_row = res.fetchone() if res else None
            return one_row[0] if one_row else None

    def deleteTag(self, feed_id, tag_name):
        """Delete a tag definition and all of its links to posts.

        Does nothing when the feed has no tag called *tag_name*.

        :param feed_id: owner of the tag.
        :param tag_name: text of the tag to delete.
        """
        tag_id = self.getTagByName(feed_id, tag_name)
        if tag_id is None:
            # Nothing to delete; the statements below would match no rows.
            return

        with get_connection() as connection:
            connection.execute(
                'DELETE FROM PostCategory WHERE person_id = %s AND id = %s',
                (feed_id, tag_id))
            connection.execute(
                'DELETE FROM PostCategories WHERE feed_id=%s AND cat_id=%s',
                (feed_id, tag_id))

    def renameTag(self, uid, old_name, new_name):
        """Change the text of the user's tag *old_name* to *new_name*."""
        rename_query = """
            UPDATE PostCategory SET title_tag = %s
            WHERE person_id = %s AND title_tag = %s
        """
        with get_connection() as connection:
            connection.execute(rename_query, (new_name, uid, old_name))

    def resolveUserTags(self, uid, tag_list, as_tuple=False):
        """Map tag texts to tag ids, creating missing tags in the DB.

        Texts are stripped, blanks are dropped and duplicates are collapsed
        before the lookup.

        :param uid: owner of the tags.
        :param tag_list: iterable of tag texts.
        :param as_tuple: when true, return ``(id, title)`` tuples instead of
            bare ids.
        :return: list of ids, or of ``(id, title)`` tuples with *title* as
            ``str``; an empty list when no usable text was given.
        """
        names = tuple({tag_str.strip() for tag_str in tag_list
                       if tag_str.strip()})
        if not names:
            return []

        # All values, including uid, are bound as parameters; only the
        # placeholder lists are formatted into the SQL text.
        name_placeholders = ','.join(['%s'] * len(names))
        select_query = ('SELECT id, title_tag FROM PostCategory '
                        'WHERE title_tag in (' + name_placeholders + ') '
                        'AND person_id=%s')
        select_params = names + (uid,)

        with get_connection() as connection:
            rows = connection.execute(select_query, select_params).fetchall()
            if len(rows) < len(names):
                value_placeholders = ','.join(['(%s,%s)'] * len(names))
                insert_query = ('INSERT IGNORE INTO PostCategory '
                                '(person_id, title_tag) VALUES '
                                + value_placeholders)
                insert_params = tuple(
                    value for name in names for value in (uid, name))
                connection.execute(insert_query, insert_params)
                rows = connection.execute(select_query,
                                          select_params).fetchall()

        if not as_tuple:
            return [tag_id for tag_id, _ in rows]
        # Decode only when the driver hands back bytes; str(x, 'utf-8') on
        # a value that is already str raises TypeError.
        return [
            (tag_id,
             title_tag.decode('utf-8') if isinstance(title_tag, bytes)
             else title_tag)
            for tag_id, title_tag in rows
        ]

    def resolveUserTagsByNames(self, uid, tag_list):
        """Map tag texts to ``(tag_id, tag_name)`` tuples.

        Same as ``resolveUserTags`` with ``as_tuple=True``; missing tags are
        added to the DB.
        """
        return self.resolveUserTags(uid, tag_list, True)

    def getFeedItemsTagList(self, feed_items):
        """Return the tags of the given posts.

        :param feed_items: list of post ids formatted as
            ``'feed_id.item_no'``. Entries that do not parse as two integers
            are logged and skipped.
        :return: list of ``(feed_item, tag_id, title_tag)`` tuples where
            *feed_item* is again a ``'feed_id.item_no'`` string.
        """
        if not feed_items:
            return []

        feed_item_tuples = []
        for feed_item in feed_items:
            try:
                # The unpacking is inside the try block as well: an id with
                # a missing or extra '.' raises ValueError just like a
                # non-numeric part does.
                feed_id, item_no = feed_item.split('.')
                feed_item_tuples.append((int(feed_id), int(item_no)))
            except ValueError:
                _log.exception('Invalid post id: %s', feed_item)

        if not feed_item_tuples:
            return []

        query = """SELECT p.person_id, p.post_no, pcs.cat_id, pc.title_tag
                    FROM Posts p JOIN PostCategories pcs JOIN PostCategory pc
                    WHERE ( %s )
                        AND pcs.feed_id = p.person_id AND pcs.post_no = p.post_no
                        AND pcs.deleted = 0
                        AND pc.id = pcs.cat_id
                        AND pc.person_id = p.person_id"""
        # Safe to format in: both values were converted with int() above.
        condition = " OR ".join(
            "(p.person_id=%d AND p.post_no=%d)" % t for t in feed_item_tuples
        )
        with get_connection() as conn:
            # Fetch while the connection is still open.
            db_rows = conn.execute(query % condition, None).fetchall()
        return [
            ("%d.%d" % (feed_id, item_no), tag_id, title_tag)
            for feed_id, item_no, tag_id, title_tag in db_rows
        ]

    @classmethod
    def getSingleFeedTagList(cls, feed_id, tag_ids):
        """Return ``(person_id, id, title_tag)`` rows for the given tag ids.

        The tags in *tag_ids* must belong to the feed *feed_id*; rows are
        ordered by title.

        :param feed_id: owner of the tags.
        :param tag_ids: iterable of tag ids.
        :return: list of ``(person_id, id, title_tag)`` tuples; empty when
            *tag_ids* is empty.
        """
        tag_ids = tuple(tag_ids) if tag_ids else ()
        if not tag_ids:
            return []

        # Ids and feed_id are bound as parameters rather than formatted into
        # the statement text.
        placeholders = ','.join(['%s'] * len(tag_ids))
        query = ("""SELECT person_id, id, title_tag FROM PostCategory
                        WHERE id in (""" + placeholders + """) and person_id = %s
                        ORDER BY title_tag""")
        with get_connection() as connection:
            cursor = connection.execute(query, tag_ids + (feed_id,))
            return [(person_id, tag_id, title_tag)
                    for person_id, tag_id, title_tag in cursor.fetchall()]

    # corba methods:
    @et2xml
    @log_exception
    @stopwatch
    def GetTagsXMLPaged(self, person_id, count=None, page=None):
        """Build the <tag-list> tree of a user's tags, optionally paged.

        Negative pages are treated as page 0 and counts below 1 disable
        paging. When paging is active the root element carries ``page-size``
        and ``total`` attributes.

        :return: an ``ElementTree`` rooted at <tag-list> (before the
            decorators are applied).
        """
        count = count or 0
        page = page or 0
        if page < 0:
            page = 0
        if count < 1:
            count = 0
        rows, total = self.getTags(person_id, count, page)
        rootElement = buildTagListElement(rows)
        if count:
            rootElement.attrib['page-size'] = str(count)
            rootElement.attrib['total'] = str(total)
        return ET.ElementTree(rootElement)

    @et2xml
    @log_exception
    @stopwatch
    def GetTagSuggest(self, person_id, part):
        """Build a <tags> tree with up to 10 of the user's tags that start
        with *part*, ordered by title.

        NOTE(review): '%' and '_' inside *part* reach LIKE unescaped, so
        they act as wildcards.
        """
        sql = 'SELECT title_tag FROM PostCategory WHERE title_tag LIKE %s AND person_id = %s ORDER BY title_tag LIMIT 10'
        root = ET.Element('tags')
        with get_connection() as connection:
            found = connection.execute(sql, (part + '%', person_id)).fetchall()
            for (title_tag,) in found:
                ET.SubElement(root, 'tag').text = title_tag
        return ET.ElementTree(root)

    @et2xml
    @log_exception
    @stopwatch
    def GetTagID(self, person_id, tag_name):
        """Build a <tag-id> tree whose text is the id of the named tag, or
        an empty string when the lookup yields no (truthy) id."""
        et = ET.Element('tag-id')
        et.text = str(self.getTagByName(person_id, tag_name) or '')
        return ET.ElementTree(et)

    @et2xml
    @log_exception
    @stopwatch
    def GetTopTags(self, uid, pageLength=5, tag_name='top-tags', scale=True,
                   types=None, pageIndex=0):
        """Build the XML tree of a user's most used tags.

        Tags are counted over the user's posts that are not deleted and not
        on moderation, and ordered by usage (descending), then by tag id.

        :param uid: person id; when falsy the string ``'<tag_name/>'`` is
            returned instead of a tree.
        :param pageLength: page size; a falsy value disables paging.
        :param tag_name: name of the element holding the tag counts.
        :param scale: when true, raw counts are passed through
            ``scaleTagCount``.
        :param types: optional iterable of post type codes to restrict the
            counted posts to.
        :param pageIndex: zero-based page index.
        :return: an ``ElementTree`` rooted at <aux> containing the
            *tag_name* element and a <tag-list> element (before the
            decorators are applied).
        """
        if not uid:
            return '<' + tag_name + '/>'

        # Type codes are bound as parameters; only placeholders go into the
        # statement text.
        type_params = tuple(types) if types else ()
        types_cond = ""
        if type_params:
            types_cond = ' AND p.post_type IN (%s) ' % ', '.join(
                ['%s'] * len(type_params))
        query = 'SELECT ' + ('SQL_CALC_FOUND_ROWS ' if pageLength else '') + """
                    count(*) as c, pc.cat_id
                FROM PostCategories pc, Posts p
                WHERE p.person_id = %s AND p.deleted = 0 AND not p.on_moderation
                    AND pc.feed_id = p.person_id AND pc.post_no = p.post_no
                    AND pc.cat_id <> 0""" + types_cond + """
                GROUP BY pc.cat_id HAVING c > 0
                ORDER BY c DESC, pc.cat_id ASC"""
        # Parameter order follows placeholder order: uid, types, then LIMIT.
        params = (uid,) + type_params
        if pageLength:
            if pageIndex:
                query += ' LIMIT %s, %s'
                params = params + (pageLength * pageIndex, pageLength)
            else:
                query += ' LIMIT %s'
                params = params + (pageLength,)
        with get_connection() as connection:
            top_tags = connection.execute(query, params).fetchall()
            tag_count = connection.scalar(
                "SELECT FOUND_ROWS()") if pageLength else None

        if scale and top_tags:
            top_tags = scaleTagCount(top_tags)
        rootElement = ET.Element('aux')
        topTagsElement = ET.SubElement(rootElement, tag_name)
        if pageLength:
            if tag_count > pageLength:
                # More tags exist than fit on one page.
                ET.SubElement(topTagsElement, 'more')
            topTagsElement.attrib['page-size'] = str(pageLength)
            topTagsElement.attrib['total'] = str(tag_count)
        appendTagCountsToElement(top_tags, topTagsElement)

        # The tag named 'Я' is reported separately as <me-tag>, if present.
        meTagId = self.getTagByName(uid, 'Я')
        if meTagId:
            ET.SubElement(topTagsElement, 'me-tag').text = str(meTagId)

        tag_ids = [tag_id for _, tag_id in top_tags]
        tag_rows = self.getSingleFeedTagList(uid, tag_ids)
        rootElement.append(buildTagListElement(tag_rows))
        return ET.ElementTree(rootElement)

    def GetTagCloud2XML(self, uid, count):
        """Return ``GetTopTags`` output for *count* tags under the element
        name 'tag-cloud'."""
        return self.GetTopTags(uid, pageLength=count, tag_name='tag-cloud')

    def GetTagCloudTyped(self, uid, count=None, posttypes=None):
        """Return a 'tag-cloud' restricted to the given post types.

        :param uid: person id.
        :param count: number of tags; defaults to 20 when falsy.
        :param posttypes: comma-separated post type names, converted with
            ``Types.typenames2codes``. If the conversion fails the error is
            logged and no type filter is applied.
        """
        count = count or 20
        posttypes = posttypes or ''
        try:
            types = Types.typenames2codes(posttypes.split(','))
        except Exception:
            # Best effort: fall back to an unfiltered cloud. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            _log.error('Failed to convert posttypes for GetTagCloudTyped:\n' + \
                       traceback.format_exc())
            types = None
        return self.GetTopTags(uid, pageLength=count, tag_name='tag-cloud',
                               types=types)


# Module-level Tags instance created at import time; presumably the object
# other modules use to call these methods -- verify against callers.
t = Tags()
