#!/usr/bin/env python2
# coding: utf-8

from contextlib import closing
from collections import namedtuple
import opster
import psycopg2
import sys

# One broken message row, in the column order selected by
# fetch_broken_messages(): (id, object_id, description).
MessageInfo = namedtuple('MessageInfo', 'id object_id description')

def chunked(items, chunk_size):
    """Yield consecutive slices of *items*, each at most *chunk_size* long.

    Args:
        items: a sequence supporting ``len()`` and slicing (list, tuple,
            string, ...).
        chunk_size: maximum length of each yielded slice; must be positive.

    Yields:
        Slices of *items* in order; only the last one may be shorter than
        *chunk_size*.  Nothing is yielded for an empty sequence.

    Raises:
        ValueError: if *chunk_size* is not positive.  As this is a
            generator, the error surfaces when iteration starts.
    """
    # A non-positive size would either loop forever or silently drop all
    # the data, so reject it explicitly.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive, got %r" % (chunk_size,))
    total = len(items)
    start = 0
    # Plain while loop instead of xrange(): stays lazy and runs unchanged on
    # both Python 2 and Python 3.
    while start < total:
        yield items[start:start + chunk_size]
        start += chunk_size


def fetch_broken_messages(res_conn):
    """Return every message whose attributes carry an empty category id.

    Selects ``id``, ``object_id`` and ``description`` from
    ``diffalert.messages`` joined with ``diffalert.message_attributes`` and
    wraps each row in a MessageInfo.

    Args:
        res_conn: open database connection providing ``cursor()``.

    Returns:
        list of MessageInfo, one per matching row.
    """
    query = (
        "SELECT id, object_id, description"
        " FROM diffalert.messages"
        " JOIN diffalert.message_attributes"
        " USING (attributes_id)"
        " WHERE category_id = ''")
    with closing(res_conn.cursor()) as cur:
        cur.execute(query)
        broken = []
        for row in cur:
            broken.append(MessageInfo._make(row))
        return broken


def fetch_category_ids(tds_conn, oids):
    """Look up the category id for each object id in *oids*.

    The ids are queried in chunks of 1000 so that no single statement
    carries an unbounded ``IN (...)`` list.

    Args:
        tds_conn: open database connection (providing ``cursor()``) to the
            database holding the ``revision`` schema.
        oids: sequence of object ids.  They are rendered with ``str()``
            straight into the SQL text, so they must be numeric.

    Returns:
        dict mapping object id to category id, i.e. the ``cat:...`` key of
        the object's attributes with its 4-character prefix removed.
        Objects without a usable ``cat:`` key are reported on stderr and
        map to ``''``.  Objects for which the query returns no row are
        absent from the dict.
    """
    oid_category_ids = {}
    for oid_chunk in chunked(oids, 1000):
        with closing(tds_conn.cursor()) as cursor:
            # NOTE(review): DISTINCT ON (1) without ORDER BY keeps an
            # arbitrary revision per object -- confirm whether a specific
            # (e.g. the latest) revision is intended.
            cursor.execute(
                "SELECT DISTINCT ON (1) rev.object_id," +
                        " (SELECT key FROM each(att.contents) WHERE key LIKE 'cat:%')" +
                    " FROM revision.object_revision rev" +
                        " JOIN revision.attributes att" +
                        " ON rev.attributes_id = att.id" +
                    # Query only the current chunk; using the full id list
                    # here would re-run the whole lookup once per chunk.
                    " WHERE object_id IN (" + ','.join(map(str, oid_chunk)) + ")")
            for oid, category_key in cursor:
                # The scalar subquery yields NULL (None) when the object
                # has no 'cat:' key at all; treat that like a too-short key
                # instead of crashing on len(None).
                if category_key is None or len(category_key) < 5:
                    print >>sys.stderr, "bad category for oid %s: %s" % (oid, category_key)
                    category_key = category_key or ''
                # Drop the leading 'cat:' prefix.
                oid_category_ids[oid] = category_key[4:]
    return oid_category_ids


def update_category_ids(res_conn, msgs, oid_category_ids):
    """Point each message at attributes carrying its resolved category id.

    Messages are processed in chunks of 1000.  For every chunk one UPDATE
    is issued that sets ``attributes_id`` to the result of
    ``diffalert.insert_message_attributes(category, description)``.
    Nothing is committed here; the caller owns the transaction.

    Args:
        res_conn: open database connection providing ``cursor()``.
        msgs: sequence of MessageInfo-like objects exposing ``id``,
            ``object_id`` and ``description``.
        oid_category_ids: dict mapping object id to category id.

    Raises:
        KeyError: if a message's ``object_id`` is missing from
            *oid_category_ids*.
    """
    row_placeholder = "(%s, %s, %s)"
    for msg_chunk in chunked(msgs, 1000):
        # Values are handed to the driver as bound parameters rather than
        # spliced into the SQL text, so quotes or backslashes inside a
        # description cannot break (or inject into) the statement.  A None
        # description is therefore sent as NULL rather than as 'None'.
        params = []
        for m in msg_chunk:
            params.extend(
                (m.id, oid_category_ids[m.object_id], m.description))
        placeholders = ",".join([row_placeholder] * len(msg_chunk))
        sql = (
            "UPDATE diffalert.messages m"
            " SET attributes_id ="
            " diffalert.insert_message_attributes(v.cat, v.dsc)"
            " FROM (VALUES " + placeholders + ") AS v (id, cat, dsc)"
            " WHERE m.id = v.id")
        with closing(res_conn.cursor()) as cursor:
            cursor.execute(sql, params)


@opster.command()
def main(res_conn_str=('', '', 'validation connection string'),
         tds_conn_str=('', '', 'revision connection string')):
    """Fix missing category ids in messages"""

    # Steps: find the broken messages, resolve category ids for their
    # objects via the second connection, then update and commit on the
    # first connection.  Progress is reported on stderr.
    with closing(psycopg2.connect(res_conn_str)) as res_conn:
        print >>sys.stderr, "Looking for broken messages...",
        broken = fetch_broken_messages(res_conn)
        print >>sys.stderr, len(broken), "found"
        if not broken:
            return

        # Several messages may refer to the same object; look each one up
        # only once.
        unique_oids = set()
        for msg in broken:
            unique_oids.add(msg.object_id)
        oid_list = list(unique_oids)

        print >>sys.stderr, "Fetching category ids...",
        # The second connection is needed only for this lookup and is
        # closed as soon as it finishes, whether or not it succeeded.
        tds_conn = psycopg2.connect(tds_conn_str)
        try:
            categories = fetch_category_ids(tds_conn, oid_list)
        finally:
            tds_conn.close()
        print >>sys.stderr, "done"

        print >>sys.stderr, "Performing fix...",
        update_category_ids(res_conn, broken, categories)
        res_conn.commit()
        print >>sys.stderr, "done"


# Script entry point: invoke the opster-decorated main() when this file is
# executed directly.
if __name__ == '__main__':
    main()
