# import time
import logging

import flask
import psycopg2.extras
from sqlalchemy import text

from .abstract import Subscriptions


def count(iterator):
    """Return how many items *iterator* yields (the iterator is consumed)."""
    total = 0
    for _ in iterator:
        total += 1
    return total


def parse_settings_triples(records: list):
    """Group flat settings records into ``{kind: {setting: value}}``.

    Each record must expose ``kind``, ``setting`` and ``value`` attributes.
    A later record with the same kind and setting overwrites an earlier one.
    """
    grouped = {}
    for record in records:
        if record.kind not in grouped:
            grouped[record.kind] = {}
        grouped[record.kind][record.setting] = record.value
    return grouped


def remove_negation(tag_name):
    """Return *tag_name* with every leading ``!`` negation marker removed."""
    negation_marker = "!"
    return tag_name.lstrip(negation_marker)


def remove_tags_negation(tags):
    """Return a new list holding each tag with its ``!`` negation markers stripped."""
    return list(map(remove_negation, tags))


def has_negation(tag):
    """Return True when *tag* is negated, i.e. it begins with ``!``."""
    negated = tag.startswith("!")
    return negated


class SubscriptionsCache(Subscriptions):
    """Subscriptions storage backed by a database connection pool.

    The queries below use four tables:

    * ``subscriptions``: one row per rule, with ``positive_clauses`` and
      ``negative_clauses`` counts.
    * ``tags``: known tag names.
    * ``subscriptions_tags``: tag membership of each subscription; the
      ``invert`` flag marks ``!tag`` clauses.
    * ``settings``: general settings of a user/group when ``subscription_id``
      is NULL, per-subscription settings otherwise.
    """

    def __init__(self, db, log=None):
        # db must provide ``connection()`` usable as a context manager; rows
        # are read by attribute (row.name, ...), presumably via a namedtuple
        # cursor factory configured on the pool -- TODO confirm.
        self.log = log or logging.getLogger('subs')
        self.db = db

    @classmethod
    def create(cls, options, log):
        """Build an instance on top of the current Flask app's ``database_pool``.

        ``options`` is only type-checked. When ``log`` is ``None``, the
        default ``'subs'`` logger is used.
        """
        assert isinstance(options, dict), "options should be dict but %r" % (type(options),)
        db = flask.current_app.database_pool
        return cls(db, log.getChild('subs') if log is not None else None)

    def key(self):
        return "subscriptions_cache"

    def name(self):
        return "Database subscriptions cache"

    @staticmethod
    def _subscribers_query(tags, fetch_rule, exact):
        """Build ``(sql, named_args)`` for :meth:`subscribers_for`."""
        rule_field = 'g.rule AS rule, ' if fetch_rule else ''
        query = (
            'SELECT g.subscription_id AS sub_id, g.name AS name, g.is_group AS is_group, '
            f'{rule_field}'
            's.kind AS kind, s.setting AS setting, s.value AS value '
            'FROM subscriptions AS g'
        )
        args = {}
        for tag_n, tag in enumerate(tags):
            # Exact-match and negated tags must be linked to the subscription
            # (inner join); other tags are optional (left join), so that
            # subscriptions using only a subset of ``tags`` still match.
            join_type = 'JOIN' if exact or has_negation(tag) else 'LEFT JOIN'
            query += (
                f' {join_type} tags AS tag_{tag_n} ON (tag_{tag_n}.name = %(tag_{tag_n})s)'
                f' {join_type} subscriptions_tags AS st_{tag_n}'
                f' ON (g.subscription_id = st_{tag_n}.subscription_id AND st_{tag_n}.tag_id = tag_{tag_n}.tag_id)'
            )
            args[f'tag_{tag_n}'] = remove_negation(tag)

        # General settings of the subscriber; per-subscription ones are
        # fetched by a second query in subscribers_for.
        query += (
            ' LEFT JOIN settings AS s'
            ' ON (g.name = s.name AND g.is_group = s.is_group AND s.subscription_id IS NULL)'
        )

        clauses = []
        positives = []
        for tag_n, tag in enumerate(tags):
            is_neg = int(has_negation(tag))
            if exact or is_neg:
                clauses.append(f'(st_{tag_n}.invert = {is_neg})')
            else:
                # Reject subscriptions that negate this tag; count how many of
                # the optional tags the subscription actually uses.
                clauses.append(f'(st_{tag_n}.invert = 0 OR st_{tag_n}.invert IS NULL)')
                positives.append(f'CAST(st_{tag_n}.invert IS NOT NULL AS INTEGER)')

        query += ' WHERE ' + ' AND '.join(clauses)
        if exact:
            negatives_total = count(filter(has_negation, tags))
            positives_total = len(tags) - negatives_total
            query += f' AND g.positive_clauses = {positives_total}'
            query += f' AND g.negative_clauses = {negatives_total}'
        else:
            # With only negated tags there is nothing to sum; an empty "()"
            # would be a SQL syntax error, so compare against 0 instead.
            matched_positives = ' + '.join(positives) or '0'
            query += f' AND ({matched_positives}) = g.positive_clauses'
        query += ' ORDER BY g.positive_clauses ASC'
        return query, args

    def subscribers_for(self, tags, fetch_rule=False, exact=False):
        """Find subscribers whose subscriptions match ``tags``.

        :param tags: non-empty list of tag names; a leading ``!`` marks a
            negated tag.
        :param fetch_rule: group the result per subscription and include each
            subscription's rule.
        :param exact: only match subscriptions whose positive/negative clause
            counts equal those of ``tags`` and which contain every tag with
            the same negation. Otherwise, a subscription matches when all of
            its positive tags are among the non-negated ``tags``, none of the
            non-negated ``tags`` appears negated in it, and every negated
            entry of ``tags`` appears negated in it.
        :returns: ``(users, groups)``.

            * Without ``fetch_rule``, each maps a name to
              ``{(kind, setting): value}``.
            * With ``fetch_rule``, each maps a name to
              ``{sub_id: {'tags': [...], 'options': {kind: {setting: value}}}}``.

            Per-subscription settings override the subscriber's general ones.
        """
        assert tags, "list of tags mustn't be empty"
        query, args = self._subscribers_query(tags, fetch_rule, exact)

        sub_ids = set()
        users = {}
        groups = {}
        options = {}

        with self.db.connection() as conn, conn.cursor() as cur:
            cur.execute(query, args)
            for row in cur.fetchall():
                sub_ids.add(row.sub_id)
                target = groups if row.is_group else users

                if fetch_rule:
                    subscription = target.setdefault(row.name, {}).setdefault(
                        row.sub_id, {'tags': row.rule.split('&'), 'options': {}})
                    options[row.sub_id] = subscription['options']
                    if row.kind:
                        options[row.sub_id].setdefault(row.kind, {}).setdefault(row.setting, row.value)
                else:
                    # setdefault, not assignment: the settings LEFT JOIN yields
                    # one row per setting (and a subscriber may match through
                    # several subscriptions), so resetting the dict here would
                    # keep only the last row's setting.
                    subscriber = target.setdefault(row.name, {})
                    if row.kind:
                        subscriber.setdefault((row.kind, row.setting), row.value)

            if not sub_ids:
                return users, groups

            cur.execute("SELECT * FROM settings WHERE subscription_id IN %s", (tuple(sub_ids),))
            for row in cur.fetchall():
                if fetch_rule:
                    options[row.subscription_id].setdefault(row.kind, {})[row.setting] = row.value
                else:
                    target = groups if row.is_group else users
                    target.setdefault(row.name, {})[(row.kind, row.setting)] = row.value

        return users, groups

    def find_subscriptions_by_tags(self, tags, exact):
        """Find all subscriptions for all users by tags.

        Equivalent to :meth:`subscribers_for` with ``fetch_rule=True``.
        """
        return self.subscribers_for(tags, fetch_rule=True, exact=exact)

    def settings_for(self, name, is_group):
        """Return the general settings of a user or group as ``{kind: {setting: value}}``."""
        query = "SELECT * FROM settings WHERE name = %s AND is_group = %s AND subscription_id IS NULL"
        with self.db.connection() as conn, conn.cursor() as cur:
            cur.execute(query, (name, bool(is_group)))
            return parse_settings_triples(cur.fetchall())

    def update_settings_for(self, name, is_group, settings, replace_all):
        """Upsert/delete the general settings of a user or group and commit.

        See :meth:`_add_settings` for the meaning of ``settings`` and
        ``replace_all``.
        """
        with self.db.connection() as conn, conn.cursor() as cur:
            self._add_settings(cur, name, is_group, None, settings, replace_all=replace_all)
            conn.commit()

    def _add_settings(self, cursor, name, is_group, subscription_id, settings, replace_all=False):
        """Upsert or delete settings of a user/group in one scope (no commit).

        :param settings: ``{kind: {setting: value}}``. A ``None`` value deletes
            that setting; any other value is stored as ``str(value)``.
        :param subscription_id: ``None`` selects the general settings;
            otherwise only that subscription's settings are touched.
        :param replace_all: first delete every existing setting in the same
            scope.
        """
        if subscription_id is None:
            scope_sql = "subscription_id IS NULL"
            scope_args = ()
            conflict_target = "(name, is_group, kind, setting) WHERE subscription_id IS NULL "
        else:
            scope_sql = "subscription_id = %s"
            scope_args = (subscription_id,)
            conflict_target = "(subscription_id, name, is_group, kind, setting) WHERE subscription_id IS NOT NULL "

        if replace_all:
            # Restrict the wipe to the scope being written. Previously this
            # always cleared the general settings, so adding a subscription
            # silently erased the subscriber's defaults.
            cursor.execute(
                "DELETE FROM settings WHERE name = %s AND is_group = %s AND " + scope_sql,
                (name, bool(is_group)) + scope_args,
            )

        upsert = (
            "INSERT INTO settings (subscription_id, name, is_group, kind, setting, value) "
            "VALUES (%s, %s, %s, %s, %s, %s) ON CONFLICT "
            + conflict_target
            + "DO UPDATE SET value = excluded.value"
        )
        delete = (
            "DELETE FROM settings WHERE "
            "name = %s AND kind = %s AND setting = %s AND is_group = %s AND " + scope_sql
        )
        for kind, options in settings.items():
            for key, value in options.items():
                if value is None:
                    cursor.execute(delete, (name, kind, key, bool(is_group)) + scope_args)
                else:
                    cursor.execute(upsert, (subscription_id, name, bool(is_group), kind, key, str(value)))

    def _list_subscriptions(self, name, is_group):
        """Yield ``(rule_tags, settings)`` for each subscription of a user or group.

        ``rule_tags`` is the stored rule split on ``&``; ``settings`` holds the
        subscription's own settings as ``{kind: {setting: value}}`` (empty if
        it has none). This is a generator, so the query runs on first iteration.
        """
        self.log.debug("=> _list_subscriptions(name=%r, is_group=%r)", name, is_group)
        settings = {}
        rules = {}
        query = (
            "SELECT s.subscription_id, s.rule, t.kind, t.setting, t.value FROM subscriptions s "
            "LEFT OUTER JOIN settings t USING (subscription_id) "
            "WHERE s.name = %s AND s.is_group = %s"
        )
        with self.db.connection() as conn, conn.cursor() as cur:
            cur.execute(query, (name, bool(is_group)))
            for row in cur.fetchall():
                if row.subscription_id not in rules:
                    rules[row.subscription_id] = row.rule.split('&')
                if row.setting is not None:
                    settings.setdefault(row.subscription_id, {}).setdefault(row.kind, {})[row.setting] = row.value

        for subscription_id, rule_tags in rules.items():
            yield rule_tags, settings.get(subscription_id, {})

    def subscriptions_for_user(self, login):
        """Return ``[(rule_tags, settings), ...]`` for a user."""
        return list(self._list_subscriptions(login, False))

    def subscriptions_for_group(self, name):
        """Return ``[(rule_tags, settings), ...]`` for a group."""
        return list(self._list_subscriptions(name, True))

    def _add_tags(self, tags):
        """Insert the tag names not yet in ``tags`` and commit.

        Returns ``{name: tag_id}`` for every name in ``tags``.
        """
        if not tags:
            # "IN ()" is invalid SQL and there is nothing to insert anyway.
            return {}
        with self.db.connection() as conn, conn.cursor() as cur:
            psycopg2.extras.execute_values(
                cur,
                "INSERT INTO tags (name) VALUES %s ON CONFLICT (name) DO NOTHING",
                [(tag,) for tag in tags],
                template=None,
                page_size=500,
            )
            conn.commit()

            cur.execute("SELECT name, tag_id FROM tags WHERE name IN %s", (tuple(tags),))
            return {row.name: row.tag_id for row in cur.fetchall()}

    def _find_subscription(self, cursor, name, is_group, tags, add_missing_tags=False):
        """Return the id of the subscription of ``name`` whose rule is ``tags``.

        Matching ignores tag order. Returns ``None`` when no such subscription
        exists.

        With ``add_missing_tags``, unknown tags are first inserted through
        :meth:`_add_tags`, which uses its own connection from ``self.db`` and
        commits immediately. Otherwise, unknown tags raise ``KeyError``.
        """
        plain_tags = remove_tags_negation(tags)
        if add_missing_tags:
            self.log.debug("name=%s, is_group=%r: adding tags: %s", name, is_group, tags)
            self._add_tags(plain_tags)
        else:
            missing = set(plain_tags)
            if missing:
                cursor.execute("SELECT name FROM tags WHERE name in %s", (tuple(missing),))
                for row in cursor.fetchall():
                    missing.discard(row.name)

            if missing:
                raise KeyError("tags not found: %s" % (', '.join(repr(m) for m in sorted(missing)),))

        negative = count(filter(has_negation, tags))
        positive = len(tags) - negative

        joins = []
        filters = []
        args = []
        for idx, tag in enumerate(tags):
            tag_alias = f't{idx}'
            st_alias = f'st{idx}'
            joins.append(
                f'JOIN tags {tag_alias} ON ({tag_alias}.name = %s) '
                f'JOIN subscriptions_tags {st_alias} ON '
                f'(s.subscription_id = {st_alias}.subscription_id AND {st_alias}.tag_id = {tag_alias}.tag_id) '
            )
            args.append(remove_negation(tag))
            # The inlined value is always the literal 0 or 1 computed here.
            filters.append(f'{st_alias}.invert = {int(has_negation(tag))}')

        # Join placeholders come before the WHERE placeholders, so ``args``
        # (tag names) must precede the WHERE arguments appended below.
        query = (
            "SELECT s.subscription_id FROM subscriptions s "
            + ''.join(joins)
            + "WHERE s.is_group = %s AND s.name = %s AND s.positive_clauses = %s AND s.negative_clauses = %s"
            + ''.join(' AND ' + condition for condition in filters)
        )
        args.extend((bool(is_group), name, positive, negative))

        cursor.execute(query, args)
        row = cursor.fetchone()
        return row.subscription_id if row is not None else None

    def _add_subscription(self, name, is_group, settings, tags):
        """Create the subscription ``tags`` for ``name`` with its own settings and commit.

        Returns the new subscription id. If a subscription with the same rule
        already exists, returns its id and leaves its settings untouched.
        """
        rule = '&'.join(tags)

        with self.db.connection() as conn, conn.cursor() as cur:
            subscription_id = self._find_subscription(cur, name, is_group, tags, add_missing_tags=True)
            if subscription_id is not None:
                self.log.debug("[%r, %r] subscription already exists: %r", name, is_group, subscription_id)
                return subscription_id

            # Same counting as _find_subscription so later lookups match.
            negative = count(filter(has_negation, tags))
            positive = len(tags) - negative

            cur.execute(
                "INSERT INTO subscriptions (rule, is_group, name, positive_clauses, negative_clauses) "
                "VALUES (%s, %s, %s, %s, %s) RETURNING subscription_id",
                (rule, bool(is_group), name, positive, negative),
            )
            sub_id = cur.fetchone().subscription_id

            tag_inversions = {remove_negation(tag): int(has_negation(tag)) for tag in tags}

            cur.execute("SELECT * FROM tags WHERE name IN %s", (tuple(tag_inversions),))
            tag_ids = {row.name: row.tag_id for row in cur.fetchall()}

            self._add_settings(
                cursor=cur,
                name=name,
                is_group=is_group,
                subscription_id=sub_id,
                settings=settings,
                replace_all=True,
            )

            psycopg2.extras.execute_values(
                cur,
                "INSERT INTO subscriptions_tags (tag_id, subscription_id, invert) VALUES %s",
                [(tag_ids[tag], sub_id, inverted) for tag, inverted in tag_inversions.items()],
                template=None,
                page_size=500,
            )

            conn.commit()

        return sub_id

    def add_subscription(self, name, is_group, tags, settings):
        """Public wrapper around :meth:`_add_subscription`."""
        return self._add_subscription(name=name, is_group=is_group, tags=tags, settings=settings)

    def _remove_subscription(self, name, is_group, tags):
        """Delete the subscription ``tags`` of ``name``, its tag links and settings, then commit.

        :raises KeyError: if a tag or the subscription does not exist.
        """
        with self.db.connection() as conn, conn.cursor() as cur:
            subscription_id = self._find_subscription(cur, name, is_group, tags)
            if subscription_id is None:
                raise KeyError("subscription not found")

            # Delete dependent rows before the subscription itself. This order
            # also works if settings/subscriptions_tags reference subscriptions
            # via a foreign key (schema not visible here).
            for table in ("settings", "subscriptions_tags", "subscriptions"):
                cur.execute(f"DELETE FROM {table} WHERE subscription_id = %s", (subscription_id,))

            conn.commit()

    def remove_subscription(self, name, is_group, tags):
        """Public wrapper around :meth:`_remove_subscription`."""
        return self._remove_subscription(name=name, is_group=is_group, tags=tags)

    def update_subscription(self, name, is_group, tags, settings):
        """Upsert/delete settings of an existing subscription and commit.

        :raises KeyError: if a tag or the subscription does not exist.
        """
        with self.db.connection() as conn, conn.cursor() as cur:
            subscription_id = self._find_subscription(cur, name, is_group=is_group, tags=tags)
            if subscription_id is None:
                raise KeyError("subscription not found: name=%r, is_group=%r, tags=%r" % (name, is_group, tags))

            self._add_settings(
                cursor=cur,
                name=name,
                is_group=is_group,
                settings=settings,
                subscription_id=subscription_id,
                replace_all=False,
            )

            conn.commit()