#!/usr/bin/env python
#
import sys
import argparse
import json
import pymongo
import yt.wrapper as ytw
from time import mktime, strptime, strftime, localtime, sleep
from bson.int64 import Int64
from collections import defaultdict
from threading import Thread
from traceback import format_exception


# YT table paths.
#   The "*_daily" values are directory prefixes: GatherStatsForRoute.run()
#   appends a date string (taken from the complaint log's "msgdate" field)
#   to "<route>_daily" to address a daily log table.
#   "tmp" is the prefix of the per-run intermediate tables built in the
#   __main__ section ("<tmp><DATETIME>_<route>_with_rules" / "_without_rules").
#   "compl_5min" is not referenced in the code visible in this file.
YT_LOG = {
    "compl_5min": "//home/logfeller/logs/mail-so-compl-log/stream/5min/",
    "in_daily":   "//home/logfeller/logs/mail-so-ml-log/1d/",
    "out_daily":  "//home/logfeller/logs/mail-so-out-log/1d/",
    "corp_daily": "//home/logfeller/logs/mail-so-corp-log/1d/",
    "tmp":        "//home/so_fml/nirvana/tmp/tmp_rules_complaints_log_"
}
# MongoDB settings consumed by getMongoDB()/mongoConnStr().
# 'passwd' is intentionally absent: it is injected at runtime from the
# --secret command-line argument. 'timeout' (milliseconds) is used for both
# connectTimeoutMS and socketTimeoutMS.
MONGO = {
    'db':      'rules',
    'hosts':   'vla-hwmeehtmq450wvke.db.yandex.net,sas-agixxin7fbr78u0o.db.yandex.net,vlx-g9q00jkxh959zk6m.db.yandex.net',
    'port':    27018,
    'user':    'solog',
    'timeout': 70000
}
# Routes to process. The order matters: complLogMapper writes route "in" to
# output tables 0/1, "out" to 2/3 and "corp" to 4/5 (with/without rules), and
# the __main__ section builds its output-table list in this same order.
ROUTES = ['in', 'out', 'corp']


def get_traceback():
    """Return the exception currently being handled, formatted for logging.

    Every fragment produced by traceback.format_exception() is stripped and
    re-emitted prefixed with a tab and terminated by a newline. Meant to be
    called from inside an ``except`` block; outside one, the result does not
    describe a real traceback.
    """
    tb = ''
    for step in format_exception(*sys.exc_info()):
        try:
            tb += "\t" + step.strip() + "\n"
        except Exception:
            # Best effort: skip a fragment that cannot be concatenated (e.g.
            # mixed byte/unicode text under Python 2) instead of losing the
            # whole traceback. Unlike a bare ``except:``, this does not
            # swallow KeyboardInterrupt/SystemExit.
            pass
    return tb


def log(s, isTB=False):
    """Print message ``s`` to stderr and flush right away.

    With ``isTB`` set, the formatted traceback of the exception currently
    being handled (see get_traceback) is appended to the message first.
    """
    message = s + get_traceback() if isTB else s
    print >>sys.stderr, message
    sys.stderr.flush()


def mongoConnStr(cfg):
    """Build a MongoDB connection URI from a config dict.

    Reads:
      cfg['db']     -- database name (required), appended as the URI path;
      cfg['hosts']  -- comma-separated host list (default '127.0.0.1');
                       surrounding whitespace and empty entries are dropped;
      cfg['user']   -- optional user name; credentials are added only if set;
      cfg['passwd'] -- optional password; omitted from the URI when missing.

    User name and password are percent-escaped, as the MongoDB URI format
    requires, so credentials containing ':', '@', '/' etc. do not corrupt
    the URI.

    Returns e.g. "mongodb://user:pw@h1,h2/db".
    """
    try:
        from urllib import quote_plus  # Python 2
    except ImportError:
        from urllib.parse import quote_plus  # Python 3

    rawHosts = cfg.get('hosts') or '127.0.0.1'
    hosts = [h.strip() for h in rawHosts.split(',') if h.strip()] or ['127.0.0.1']

    credentials = ''
    user = cfg.get('user')
    if user:
        credentials = quote_plus(user)
        passwd = cfg.get('passwd')
        if passwd is not None:
            credentials += ':' + quote_plus(passwd)
        credentials += '@'
    return "mongodb://%s%s/%s" % (credentials, ','.join(hosts), cfg['db'])


def getMongoDB(cfg, params=None):
    """Return a pymongo Database handle for cfg['db'], creating it on first use.

    The handle is cached as an attribute of this function, keyed by database
    name. Later calls for the same cfg['db'] return the cached handle, and
    their cfg/params are ignored.

    :param cfg: connection config (see mongoConnStr). 'timeout' (ms,
        default 10000) is used for connectTimeoutMS/socketTimeoutMS unless
        those are given in ``params``; 'port' defaults to 27017.
    :param params: optional extra keyword arguments for pymongo.MongoClient.
        The dict is copied and never modified. The old ``params={}`` default
        was mutated, so one call's timeouts leaked into later calls.

    NOTE(review): the cache is not synchronized. Concurrent first calls (e.g.
    from several GatherStatsForRoute threads) may each create a client, and
    the last one assigned is kept.
    """
    cacheAttr = "%s_connection" % cfg['db']
    if not hasattr(getMongoDB, cacheAttr):
        clientParams = dict(params or {})
        timeout = cfg.get('timeout') or 10000
        # Explicit values in params take precedence over cfg['timeout'].
        clientParams.setdefault("connectTimeoutMS", timeout)
        clientParams.setdefault("socketTimeoutMS", timeout)
        client = pymongo.MongoClient(host=mongoConnStr(cfg), port=cfg.get('port', 27017), **clientParams)
        setattr(getMongoDB, cacheAttr, client[cfg['db']])
    return getattr(getMongoDB, cacheAttr)


def sendRules2DB(dbClient, route, data):
    """Add aggregated rule counters to MongoDB for one route.

    data['detailed'] (time -> rule -> counters) goes to collection
    "detailed_Rules_<route>", matched on (rule, time); data['daily']
    (date -> rule -> counters) goes to "Rules_<route>", matched on
    (rule, date). Each document is updated with ``$inc`` and upsert=True.

    Best effort: any exception is logged with its traceback and swallowed.
    Updates issued before the failure stay applied.
    """
    try:
        destinations = (
            ('detailed_Rules_%s' % route, 'time', 'detailed'),
            ('Rules_%s' % route, 'date', 'daily'),
        )
        for collectionName, keyField, section in destinations:
            for key, perRule in data[section].iteritems():
                for rule, counters in perRule.iteritems():
                    dbClient[collectionName].update_one(
                        {'rule': rule, keyField: key},
                        {'$inc': dict(counters)},
                        upsert=True)
        log("Saving of statistics to DB for route=%s done" % route)
    except Exception as e:
        log("Exception while DB operations in sendRules2DB: %s" % str(e), True)


def gatherRulesStatistics(ytClient, route, dt, tablePath, data):
    """Accumulate Mongo-bound counters from a rulesReducer output table.

    For each row, the columns ham, spam, ham_nopf and spam_nopf are added as
    cmpl_ham, cmpl_spam, cmpl_ham_nopf, cmpl_spam_nopf to
    data['detailed'][row['cmpl_time']][row['rule']] and then to
    data['daily'][dt][row['rule']]. A row that fails is logged and the loop
    moves on; increments already applied for that row are kept.

    :param route: unused; kept for interface compatibility.
    :returns: ``data`` (also updated in place).
    """
    columns = ('ham', 'spam', 'ham_nopf', 'spam_nopf')
    for row in ytClient.read_table(tablePath, format=ytw.JsonFormat(), raw=False):
        try:
            cmplTime = row["cmpl_time"]
            for scope, scopeKey in (('detailed', cmplTime), ('daily', dt)):
                for column in columns:
                    data[scope][scopeKey][row["rule"]]['cmpl_' + column] += row[column]
        except Exception as e:
            log("gatherRulesStatistics failed to process row: %s" % str(e), True)
    return data


def gatherRulesStatistics2(ytClient, route, tablePath, data, uniqs):
    """Accumulate YaStat counters and unique-uid sets from a rulesReducer table.

    For every row, ham/spam/ham_nopf/spam_nopf are added to
    data[row['cmpl_date']][route][row['rule']] and to the same date/route
    under the pseudo-rule "TOTAL". The uids listed in the uniq_<counter>_uids
    columns are merged into the matching sets of ``uniqs`` (same key layout).

    Empty entries in the comma-separated uid lists are ignored:
    "".split(",") yields [""], which would otherwise count as one unique
    user for every counter that had no uids.

    NOTE(review): "TOTAL" sums counters over all rules, so a complaint about
    a message that triggered N rules adds N to the TOTAL counts. Unique-uid
    counts are unaffected. Confirm this is the intended meaning.

    Rows that fail are logged and skipped.

    :returns: ``data``; ``uniqs`` is updated in place as well.
    """
    counters = ('ham', 'spam', 'ham_nopf', 'spam_nopf')
    for r in ytClient.read_table(tablePath, format=ytw.JsonFormat(), raw=False):
        try:
            t = r["cmpl_date"]
            for target in (r["rule"], "TOTAL"):
                for k in counters:
                    data[t][route][target][k] += r[k]
            for target in (r["rule"], "TOTAL"):
                for k in counters:
                    uids = r["uniq_%s_uids" % k].split(",")
                    uniqs[t][route][target][k] |= set(uid for uid in uids if uid)
        except Exception as e:
            log("gatherRulesStatistics2 failed to process row: %s" % str(e), True)
    return data


@ytw.with_context
def complLogMapper(record, context):
    """Split complaint-log records into per-route output tables.

    A record is emitted only if:
      - its 'type' ends with "foo";
      - 'route' is longer than 1 char, 'queueid' longer than 7 and
        'actdate' longer than 8;
      - 'skipped' is present and equal to "-" or "";
      - 'uid' is present and not "-".

    The output table index depends on the route and on whether the record
    lists any non-empty rule:
        in -> 0 (with rules) / 1 (without), out -> 2 / 3, corp -> 4 / 5.
    Records with any other route are dropped.

    Emitted columns: uid, queueid, type, route, rules ("" when none),
    cmpl_time (actdate as unix seconds, rounded down to 5 minutes) and date
    (first whitespace-separated token of 'msgdate', "" when absent or "-").

    Records whose 'actdate' matches neither "%Y-%m-%d %H:%M:%S" nor
    "%d.%m.%Y %H:%M:%S" are skipped rather than failing the whole job.
    """
    accepted = (
        record.get('type', "").endswith("foo")
        and len(record.get("route", "")) > 1
        and len(record.get("queueid", "")) > 7
        and len(record.get("actdate", "")) > 8
        and record.get("skipped") in ("-", "")
        and record.get("uid", "-") != "-"
    )
    if not accepted:
        return

    baseTable = {"in": 0, "out": 2, "corp": 4}.get(record["route"])
    if baseTable is None:
        return

    rulesStr = record.get("rules", "").strip()
    if rulesStr == "-":
        rulesStr = ""
    # "".split(';') is [""], so check for a real (non-empty) rule name.
    hasRules = any(rulesStr.split(';'))

    # actdate is longer than 8 chars (checked above), so it is never "-".
    cmplDate = record["actdate"]
    t = None
    for fmt in ("%Y-%m-%d %H:%M:%S", "%d.%m.%Y %H:%M:%S"):
        try:
            t = int(mktime(strptime(cmplDate, fmt)))
            break
        except (ValueError, OverflowError):
            continue
    if t is None:
        return

    msgDate = record.get("msgdate", "")
    if msgDate == "-":
        msgDate = ""
    # An absent msgdate used to crash here with IndexError; dateReducer
    # already ignores empty dates, so emit "" instead.
    msgDateParts = msgDate.split()

    yield ytw.create_table_switch(baseTable if hasRules else baseTable + 1)
    yield {
        "uid":       record["uid"],
        "queueid":   record["queueid"],
        "type":      record["type"],
        "route":     record["route"],
        "rules":     rulesStr,
        "cmpl_time": (t // 300) * 300,
        "date":      msgDateParts[0] if msgDateParts else ""
    }


@ytw.with_context
def dateReducer(sortKey, records, context):
    """Yield a single {"date": ...} row for each non-empty date key.

    Used with reduce_by=["date"] to list the distinct dates of a table; the
    records of the group are not read.
    """
    groupDate = sortKey["date"]
    if not groupDate:
        return
    yield {"date": groupDate}


def parseDates(ytClient, table):
    """Read the "date" column of every row in ``table`` and return it sorted."""
    rows = ytClient.read_table(table, format=ytw.JsonFormat(), raw=False)
    return sorted(row["date"] for row in rows)


@ytw.with_context
def logMapper(record, context):
    """Project complaint rows and log rows onto one common schema.

    Rows without a "type" column are labelled "dlv". The rule list is taken
    from "rules", falling back to "r_sp"; the queue id from "queueid",
    falling back to "x-yandex-queueid". Other missing values default to ""
    (0 for cmpl_time).
    """
    projected = {}
    projected["type"] = record.get("type", "dlv")
    projected["uid"] = record.get("uid", "")
    projected["rules"] = record.get("rules", record.get("r_sp", ""))
    projected["cmpl_time"] = record.get("cmpl_time", 0)
    projected["queueid"] = record.get("queueid", record.get("x-yandex-queueid", ""))
    yield projected


@ytw.with_context
def logReducer(sortKey, records, context):
    """Attach the rule list of a "dlv" row to the complaint rows of its group.

    Input is logMapper output grouped by "queueid". From the "dlv" row, the
    first whitespace-separated token of each ";"-separated entry of its
    "rules" is kept (presumably the rule name; the r_sp format is not visible
    here). Blank entries are skipped; they used to raise IndexError.

    Changes from the previous behavior:
      - One row is emitted per complaint row, so several complaints about the
        same message are all counted, as in the "with rules" path.
      - Groups without any complaint row emit nothing. They previously
        produced rows with empty type/cmpl_time.
    """
    rules = ""
    complaints = []
    for record in records:
        if record["type"] == "dlv":
            rules = ";".join(entry.split()[0] for entry in record["rules"].split(";") if entry.strip())
        else:
            # Buffer complaints: the "dlv" row may come after them in the group.
            complaints.append((record["type"], record["uid"], record["cmpl_time"]))
    for (complType, uid, cmplTime) in complaints:
        yield {
            "type":      complType,
            "uid":       uid,
            "rules":     rules,
            "cmpl_time": cmplTime
        }


@ytw.with_context
def rulesMapper(record, context):
    """Explode one complaint row into one row per non-empty rule name.

    Each output row has spam=1 for type "foo", ham=1 for type "antifoo",
    and pf=1 when the rule list contains "PERSONAL_CORRECT". That marker is
    itself emitted as a rule too.
    """
    ruleNames = record["rules"].split(";")
    pfFlag = 1 if "PERSONAL_CORRECT" in ruleNames else 0
    for ruleName in (name for name in ruleNames if name):
        yield {
            "cmpl_time": record["cmpl_time"],
            "rule":      ruleName,
            "spam":      int(record["type"] == "foo"),
            "ham":       int(record["type"] == "antifoo"),
            "pf":        pfFlag,
            "uid":       record["uid"]
        }


@ytw.with_context
def rulesReducer(sortKey, records, context):
    """Aggregate rulesMapper rows for one (cmpl_time, rule) key.

    The reducer:
      - sums the ham/spam flags;
      - sums the same flags over rows with pf == 0 (the "nopf" counters);
      - collects the distinct non-empty uids behind each counter.

    In the output, cmpl_time is rounded down to a 10-minute boundary, while
    cmpl_date is the local-time rendering of the unrounded key. Uid lists are
    emitted as sorted, comma-separated strings, so the output is
    deterministic.
    """
    ham = spam = hamNoPF = spamNoPF = 0
    uidsHam, uidsSpam, uidsHamNoPF, uidsSpamNoPF = set(), set(), set(), set()
    for record in records:
        isHam, isSpam, withPF, uid = record["ham"], record["spam"], record["pf"], record["uid"]
        ham += isHam
        spam += isSpam
        if not withPF:
            hamNoPF += isHam
            spamNoPF += isSpam
        if not uid:
            continue
        if isHam:
            uidsHam.add(uid)
            if not withPF:
                uidsHamNoPF.add(uid)
        if isSpam:
            uidsSpam.add(uid)
            if not withPF:
                uidsSpamNoPF.add(uid)
    yield {
        "cmpl_time":           int(sortKey["cmpl_time"] / 600) * 600,
        "cmpl_date":           strftime("%Y-%m-%d %H:%M:%S", localtime(sortKey["cmpl_time"])),
        "rule":                sortKey["rule"],
        "ham":                 ham,
        "spam":                spam,
        "ham_nopf":            hamNoPF,
        "spam_nopf":           spamNoPF,
        "uniq_ham_uids":       ",".join(sorted(uidsHam)),
        "uniq_spam_uids":      ",".join(sorted(uidsSpam)),
        "uniq_ham_nopf_uids":  ",".join(sorted(uidsHamNoPF)),
        "uniq_spam_nopf_uids": ",".join(sorted(uidsSpamNoPF))
    }


class GatherStatsForRoute(Thread):
    """Worker thread computing per-rule complaint statistics for one table.

    Usage::

        worker = GatherStatsForRoute()
        worker(route, dt, inputTable, dbConfig, withRules)  # configures + starts

    The worker builds the YT table "<inputTable>_rules_tmp" with
    rulesMapper/rulesReducer and adds the resulting counters to MongoDB via
    gatherRulesStatistics/sendRules2DB, using route.capitalize() as the
    collection suffix. It then removes "<inputTable>" and "<inputTable>_tmp".
    "<inputTable>_rules_tmp" is left for the caller to read and delete.
    """

    def __init__(self):
        Thread.__init__(self)

    def __call__(self, route, dt, inputTable, dbConfig, withRules=True):
        """Store the job parameters and start the thread as a daemon.

        :param route: one of ROUTES; selects the "<route>_daily" log prefix
            and the Mongo collection suffix.
        :param dt: date key used for the daily counters.
        :param inputTable: complLogMapper output table for this route.
        :param dbConfig: MongoDB config dict passed to getMongoDB().
        :param withRules: True if inputTable rows already list their rules;
            False for rows without rules (see the NOTE in _gather).
        """
        self.daemon = True
        self.route = route
        self.date = dt
        self.dbConfig = dbConfig
        self.withRules = withRules
        self.inputTable = inputTable
        self.intermidiateTable = "{}_tmp".format(inputTable)
        self.outputTable = "{}_rules_tmp".format(inputTable)
        self.start()

    @staticmethod
    def _mapReduceSpec():
        """Return the YT spec for run_map_reduce as a new dict on every call."""
        return {
            "max_data_size_per_job": 6442450944,
            "reducer": {"data_size_per_sort_job": 671088640, "memory_limit": 4294967296},
            "owners": ['robot-mailspam']
        }

    def run(self):
        """Thread entry point: run the pipeline, logging failures with context.

        The thread's default exception report does not say which route/table
        failed, so the failure is logged explicitly here. Temporary tables are
        not cleaned up after a failure.
        """
        try:
            self._gather()
        except Exception as e:
            log("GatherStatsForRoute failed for route=%s, table=%s: %s"
                % (self.route, self.inputTable, str(e)), True)

    def _gather(self):
        """Run the YT pipeline, push counters to MongoDB, remove temp tables."""
        self.ytClient = ytw.YtClient(proxy="hahn")
        self.db = getMongoDB(self.dbConfig)
        tablePath = self.inputTable
        if not self.withRules:
            # Collect the distinct message dates present in the input table.
            self.ytClient.run_sort(self.inputTable, sort_by=["date"])
            self.ytClient.run_reduce(dateReducer, self.inputTable, self.intermidiateTable, reduce_by=["date"])
            logs = [self.inputTable]
            dates = parseDates(self.ytClient, self.intermidiateTable)
            logs += [YT_LOG["{}_daily".format(self.route)] + d for d in dates]
            log("GatherStatsForRoute: route=%s, dates_for_get_rules=%s" % (self.route, dates))
            # NOTE(review): `logs` (input table + daily logs for the dates
            # above) is built but NOT used. The map-reduce below reads only
            # self.inputTable, whose rows all carry a complaint "type", so
            # logReducer never sees a "dlv" row and this branch yields no
            # rules. Passing `logs` looks intended, but logReducer must first
            # handle queue-id groups that contain only delivery rows
            # (confirm before switching).
            self.ytClient.run_map_reduce(logMapper, logReducer, self.inputTable, self.intermidiateTable,
                                         reduce_by=["queueid"], spec=self._mapReduceSpec())
            tablePath = self.intermidiateTable
        self.ytClient.run_map_reduce(rulesMapper, rulesReducer, tablePath, self.outputTable,
                                     reduce_by=["cmpl_time", "rule"], spec=self._mapReduceSpec())
        data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(Int64))))
        gatherRulesStatistics(self.ytClient, self.route, self.date, self.outputTable, data)
        sendRules2DB(self.db, self.route.capitalize(), data)
        for table in (self.inputTable, self.intermidiateTable):
            if self.ytClient.exists(table):
                self.ytClient.remove(table)


def sortTuples(x, y):
    """cmp-style comparator for (key, value) pairs.

    Orders ascending by key, except that the key 'TOTAL' always sorts last.
    Both argument positions are checked, so the ordering is antisymmetric,
    as sort() requires. Returns -1, 0 or 1.
    """
    xTotal, yTotal = x[0] == 'TOTAL', y[0] == 'TOTAL'
    if xTotal or yTotal:
        # False < True, so a 'TOTAL' key compares greater than any other key.
        return int(xTotal) - int(yTotal)
    # Same result as cmp(x[0], y[0]); cmp() does not exist in Python 3.
    return (x[0] > y[0]) - (x[0] < y[0])


if __name__ == "__main__":
    # Pipeline: split the complaint table per route with complLogMapper, run
    # one GatherStatsForRoute worker per non-empty table (each writes counters
    # to MongoDB), then merge the workers' *_rules_tmp tables into the
    # YaStat JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--table_path',  type=str, required=True, help="Artifact: new path to table, which will be parsed here")
    parser.add_argument('-t', '--total_stats', type=str, help="Path to output JSON with general statistics to upload to YaStat")
    parser.add_argument('-x', '--secret',      type=str, help="Secret for DB access")
    args = parser.parse_known_args()[0]
    TOTALS_JSON_FILE = args.total_stats if args.total_stats else './total_stats.json'
    MONGO["passwd"] = args.secret
    # Last path component of the input table, with ':' replaced by '.'.
    DATETIME = args.table_path.rsplit("/", 1)[-1].replace(":", ".")
    # Part before the first 'T'. find() returning -1 used to cut off the
    # last character when there was no 'T'.
    dt = DATETIME.split('T', 1)[0]
    ytw.config["read_parallel"]["enable"] = True

    # Output tables, two per route in ROUTES order: [with_rules, without_rules].
    # The order must match complLogMapper's table indices.
    complLogs, totals_data = [], []
    for route in ROUTES:
        complLogs.append(YT_LOG["tmp"] + DATETIME + "_" + route + "_with_rules")
        complLogs.append(YT_LOG["tmp"] + DATETIME + "_" + route + "_without_rules")
    ytw.run_map(complLogMapper, args.table_path, complLogs, job_count=1000)

    log("Go to parallel calculations")
    threads = []
    for i, route in enumerate(ROUTES):
        for table, withRules in ((complLogs[2 * i], True), (complLogs[2 * i + 1], False)):
            if ytw.row_count(table) > 0:
                worker = GatherStatsForRoute()
                worker(route, dt, table, MONGO, withRules)
                threads.append(worker)
            else:
                log("Log %s has no rows" % table)
                if ytw.exists(table):
                    ytw.remove(table)
    # Wait for all workers. Join with a timeout so the main thread stays
    # responsive to Ctrl-C under Python 2.
    for worker in threads:
        while worker.is_alive():
            worker.join(1)

    data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(Int64))))
    uniqs = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set))))
    for i, route in enumerate(ROUTES):
        log("Retrieving statistics from tables for route=%s" % route)
        for table in complLogs[2 * i:2 * i + 2]:
            rulesTable = "{}_rules_tmp".format(table)
            if ytw.exists(rulesTable):
                gatherRulesStatistics2(ytw, route, rulesTable, data, uniqs)
                ytw.remove(rulesTable)

    log("Preparing data for YaStat")
    for (fielddate, rulesStats) in sorted(data.iteritems(), cmp=sortTuples):
        for (route, routeStats) in rulesStats.iteritems():
            for (rule, ruleStats) in sorted(routeStats.iteritems(), cmp=sortTuples):
                dataRow = {"fielddate": fielddate, "route": route, "rule": rule}
                for (k, v) in ruleStats.iteritems():
                    dataRow[k] = v
                    dataRow["uniq_{}".format(k)] = len(uniqs[fielddate][route][rule][k])
                totals_data.append(dataRow)
    try:
        with open(TOTALS_JSON_FILE, 'wt') as f:
            f.write(json.dumps(totals_data) + "\n")
    except Exception as e:
        log('Totals data file saving error: %s.' % str(e), True)
    log("[ DONE ]")
