#!/usr/bin/python2
# encoding: utf-8
# kate: space-indent on; indent-width 4; replace-tabs on;
#
import sys, argparse, json
import yt.wrapper as ytw
import nirvana.mr_job_context as nv
from urllib import urlopen
from datetime import date
from traceback import format_exception

# Endpoint queried by LogsReducer.__init__; the route name is appended to it.
GET_POOL_FILTER_ITEMS_URL = "https://web.so.yandex-team.ru/ml/get_pool_filter_items?route="
# %-template filled in by the main script with (workflow uid, workflow instance uid, route).
SAVE_MODEL_STATUS_URL = "https://web.so.yandex-team.ru/ml/save_model_status/?workflow_id=%s&workflow_instance_id=%s&status=pool_gathering&route=%s"
# Default rules that force target 0 in LogsReducer; replaced by -w/--whiterules when given.
DLVLOG_WHITE_RULES = ["MN_DLVR"]
# Default rules that force target 1 in LogsReducer; replaced by -b/--blackrules when given.
DLVLOG_BLACK_RULES = []
# YT table read by DlvLogMapper to map rule names to positions of the rules vector.
YT_RULES_DICT_PATH = "//home/so_fml/nirvana/rules_dict"

def get_traceback():
    """Format the exception reported by ``sys.exc_info()`` as tab-indented text.

    Intended to be called while an exception is being handled, since the
    exception triple is taken from ``sys.exc_info()``.

    Returns:
        A string in which every chunk produced by ``format_exception`` is
        stripped, prefixed with a tab and terminated with a newline.
    """
    tb = ''
    for step in format_exception(*sys.exc_info()):
        try:
            # Appended chunk by chunk on purpose: a chunk that cannot be
            # concatenated is dropped on its own, the rest is still reported.
            tb += "\t" + step.strip() + "\n"
        except Exception:
            # Best effort only, but KeyboardInterrupt/SystemExit must still
            # propagate, hence no bare ``except``.
            continue
    return tb

def doRequest(url, prompt):
    """Open ``url`` with ``urlopen`` and return the response body.

    Args:
        url: address passed to ``urlopen``.
        prompt: short description of the request, used as the prefix of the
            diagnostics written to stderr.

    Returns:
        The response body when the reported HTTP code is 200; an empty string
        when the code differs or when the request raises. Failures are only
        logged to stderr, never re-raised.
    """
    try:
        f = urlopen(url)
        try:
            code = f.getcode()
            if code == 200:
                return f.read()
            # NOTE: despite the "body" label, f.info() is the response
            # meta-information (headers); the message text is kept unchanged.
            sys.stderr.write('{0} response HTTP code: {1}, body: {2}\n'.format(prompt, code, f.info()))
        finally:
            # Always release the connection, including on the early return.
            f.close()
    except Exception as e:
        sys.stderr.write('%s HTTP request failed: %s\n' % (prompt, str(e)))
    return ""

@ytw.with_context
class LogsReducer:
    """Reducer that merges the rows of one ``queueid`` into a labelled pool row.

    Rows with ``context.table_index == 0`` supply the copied fields and the
    initial target; rows with ``context.table_index == 1`` supply the
    complaint-based target. The resulting row gets ``target`` and ``cmpl``
    columns and is dropped when it matches the filter fetched from
    ``GET_POOL_FILTER_ITEMS_URL``.
    """

    def __init__(self, white_rules, black_rules, route):
        """Store the rule lists and download the pool filter for ``route``.

        Args:
            white_rules: rules whose presence forces target 0.
            black_rules: rules whose presence forces target 1 (applied after
                the white rules, so they win when both match).
            route: route name; appended to the filter URL, and ``'in'``
                additionally enables the ``sndr``/``sdmn`` fields and the
                ``dlv``/``gd`` columns.
        """
        self.white_rules = white_rules
        self.black_rules = black_rules
        self.route = route
        self.filter = {}
        content = doRequest(GET_POOL_FILTER_ITEMS_URL + self.route, 'Retrieving pool filter items')
        if content:
            try:
                self.filter.update(json.loads(content))
                sys.stderr.write("Obtained filter conditions: %s\n" % json.dumps(self.filter))
            except Exception as e:
                # A broken filter must not stop the job: continue unfiltered.
                sys.stderr.write("JSON-data parsing while retrieving pool filter items error \"%s\" for string: %s\n" % (str(e), content))
        self.fields = ['queueid', 'spam', 'rules', 'r_cancel', 'domain', 'rcpt_uid', 'geo', 'ip', 'rdns', 'msgdate', 'mnf', "nnsubj", "nnbody", "nnfromaddr", "nnfromname"]
        if self.route == 'in':
            self.fields += ['sndr', 'sdmn']

    def _filter_items(self, section, label):
        """Return the filter values of ``section`` for ``label``, or ``()`` when absent."""
        if not isinstance(self.filter, dict):
            return ()
        items = self.filter.get(section)
        if not items or not isinstance(items, dict):
            return ()
        # A section without the 'spam'/'ham' key is treated as empty instead
        # of raising KeyError inside the reduce job.
        return items.get(label) or ()

    def _is_excluded(self, label, rules, rec):
        """Tell whether ``rec`` matches any filter value configured for ``label``."""
        if any(rule in rules for rule in self._filter_items('rules', label)):
            return True
        if any(ip == rec['ip'] for ip in self._filter_items('ips', label)):
            return True
        if any(domain == rec['domain'] for domain in self._filter_items('from_domain', label)):
            return True
        # NOTE(review): the suffix test checks that the *filter* value ends
        # with the record's rdns, not the other way round - looks inverted for
        # a domain-suffix filter; kept as is, confirm the intent.
        if any(rdns == rec['rdns'] or rdns.endswith('.' + rec['rdns'])
               for rdns in self._filter_items('rdns', label)):
            return True
        return False

    def __call__(self, key, records, context):
        """Yield at most one merged row for the group ``key``.

        Args:
            key: reduce key of the group (not used directly).
            records: rows of the group coming from both input tables.
            context: job context; ``context.table_index`` tells which input
                table the current row came from.

        Yields:
            The copied fields plus ``target`` (0 or 1) and ``cmpl`` (the
            complaint-based target, or -1 when it is absent or not applied).
            Nothing is yielded when no table-0 row with a non-empty
            ``queueid`` was seen, or when the row matches the pool filter.
        """
        res_rec = {}
        queueid, target, cmpl_target, dlv, gd = "", 0, -1, 0, ""
        for rec in records:
            table_index = context.table_index
            if table_index == 0:
                queueid = rec["queueid"]
                target = 1 if rec.get("spam", "").startswith("yes") else 0
                if self.route == 'in':
                    dlv, gd = rec['dlv'], rec['gd']
                for f in self.fields:
                    res_rec[f] = rec.get(f, "")
            elif table_index == 1:
                # NOTE(review): "foo" looks like a placeholder complaint type - confirm.
                cmpl_target = 1 if rec.get("type", "").startswith("foo") else 0
        if not queueid:
            return
        if target == cmpl_target:
            # The complaint agrees with the initial target: nothing to override.
            cmpl_target = -1
        if dlv or gd:
            # A truthy dlv or gd forces target 0 and discards the complaint.
            target, cmpl_target = 0, -1
        elif cmpl_target > -1:
            target = cmpl_target
        rules = res_rec['rules'].split(';')
        if any(rule in rules for rule in self.white_rules):
            target = 0
        if any(rule in rules for rule in self.black_rules):
            target = 1
        if self._is_excluded('spam' if target else 'ham', rules, res_rec):
            return
        res_rec['target'] = target
        res_rec['cmpl'] = cmpl_target
        yield res_rec

class LogsReducer2:
    """Reducer that collapses a group of labelled rows into one representative.

    The representative is taken from the target (0 or 1) that has more rows
    in the group; on a tie, from the target with more complaint-bearing rows
    (``cmpl > -1``), and target 1 wins when that is a tie as well.
    """

    def __init__(self, do_multiplicate, is_vector):
        """
        Args:
            do_multiplicate: when true, the representative is yielded once per
                counted row of the group instead of once.
            is_vector: when false, the ``vector`` column is removed from the
                representative before it is yielded.
        """
        self.do_multiplicate = do_multiplicate
        self.is_vector = is_vector

    def __call__(self, key, records):
        """Yield the representative row of the group.

        Args:
            key: reduce key of the group (not used directly).
            records: rows of the group; each must have ``target`` and, for
                targets 0 and 1, ``cmpl``. Rows with any other target are
                ignored.

        Yields:
            A copy of the chosen row. Nothing is yielded when the group has no
            row with target 0 or 1.
        """
        # Per-target state, indexed by the target value (0 or 1).
        counts = [0, 0]          # rows seen
        complaints = [0, 0]      # rows with cmpl > -1
        first = [{}, {}]         # copy of the first row
        merged = [{}, {}]        # rows merged up to and including the first complaint
        for rec in records:
            target = rec['target']
            if target == 0:
                idx = 0
            elif target == 1:
                idx = 1
            else:
                continue
            if counts[idx] == 0:
                first[idx].update(rec)
            counts[idx] += 1
            if complaints[idx] == 0:
                merged[idx].update(rec)
            if rec['cmpl'] > -1:
                complaints[idx] += 1
        total = counts[0] + counts[1]
        if not total:
            # Nothing to represent: do not emit an empty row.
            return
        if counts[0] > counts[1]:
            chosen = first[0]
        elif counts[0] < counts[1]:
            chosen = first[1]
        elif complaints[0] > complaints[1]:
            chosen = merged[0]
        else:
            chosen = merged[1]
        r = dict(chosen)
        if not self.is_vector:
            # pop() instead of del: the column may legitimately be missing.
            r.pop('vector', None)
        if self.do_multiplicate:
            for _ in range(total):
                yield r
        else:
            yield r

class DlvLogMapper:
    """Mapper that adds a tab-separated 0/1 ``vector`` column to every record.

    Each position of the vector corresponds to a rule number taken from the
    rules dictionary table; position 0 collects unknown rules and is always
    reset to "0".
    """
    def __init__(self, rules_dict_path):
        """Load the rule -> number map from the YT table ``rules_dict_path``.

        Only rows with a truthy ``act`` are stored in the map, but every row
        is counted in ``rcount``, which defines the vector length.
        """
        self.rules = {}
        self.rcount = 0
        rows = ytw.read_table(rules_dict_path, format = ytw.JsonFormat(), raw = False)
        for row in rows:
            if row['act']:
                self.rules.update({row['rule']: row['num']})
            self.rcount += 1
    def __call__(self, record):
        """Yield ``record`` with the ``vector`` column filled in.

        Rules are read from the ``rules`` column and, when present, from the
        ``r_cancel`` column; both are ';'-separated lists.
        """
        slots = ["0"] * (self.rcount + 1)
        def mark(joined):
            # Unknown rule names fall into slot 0, which is cleared below.
            for name in joined.split(";"):
                slots[int(self.rules.get(name, 0))] = "1"
        mark(record.get("rules", ""))
        if "r_cancel" in record:
            mark(record["r_cancel"])
        slots[0] = "0"
        record['vector'] = "\t".join(slots)
        yield record

# Entry point: builds the pool table from the delivery log and the complaints log.
#   1. join both logs by queueid and assign targets (LogsReducer);
#   2. add the rules vector and collapse rows by (vector, rcpt_uid), then by vector (LogsReducer2);
#   3. rewrite the result with an explicit schema and sort it by domain.
if __name__ == "__main__":
    d = date.today().isoformat()
    YT_CMPLLOG_PATH = "//home/logfeller/logs/mail-so-compl-log/1d/" + d
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dlvlog',        type = str, help = "input YT table with mapped delivery-log")
    parser.add_argument('-c', '--cmpllog',       type = str, help = "input YT table with mapped complaints' log")
    parser.add_argument('-m', '--multiplicate',  action = 'store_true', help = "whether it's needed to muctiplicate records with same targets")
    parser.add_argument('-w', '--whiterules',    type = str, help = "(comma separated) list of rules, for which target in pool will be 0")
    parser.add_argument('-b', '--blackrules',    type = str, help = "(comma separated) list of rules, for which target in pool will be 1")
    parser.add_argument('-o', '--output_dlvlog', type = str, help = "path to output dlvlog in YT")
    parser.add_argument('-r', '--route',         type = str, help = "The type of mail for which the model is calculated")
    # Unknown arguments are tolerated on purpose (parse_known_args).
    args = parser.parse_known_args()[0]
    ROUTE = args.route if args.route else 'in'
    if args.dlvlog:
        YT_DLVLOG_PATH = args.dlvlog
    else:
        # The default is computed before it is reported, so the message shows
        # the table that is really used.
        # NOTE(review): the previous code printed
        # "//home/logfeller/logs/mail-so-ml-log/1d/<date>" here but never used
        # it - confirm which of the two defaults is the intended one.
        YT_DLVLOG_PATH = "//statbox/mail-so-%s-log/%s" % (ROUTE if ROUTE == 'out' else 'ml', d)
        sys.stderr.write("Input YT table with mapped dlvlog must be set. Default path: %s\n" % YT_DLVLOG_PATH)
    if args.cmpllog:
        YT_CMPLLOG_PATH = args.cmpllog
    else:
        sys.stderr.write("Input YT table with mapped complaints log must be set. Default path: %s\n" % YT_CMPLLOG_PATH)
    # List comprehensions keep these real lists (LogsReducer iterates them more than once).
    if args.whiterules:
        DLVLOG_WHITE_RULES = [rule.strip() for rule in args.whiterules.split(',')]
    if args.blackrules:
        DLVLOG_BLACK_RULES = [rule.strip() for rule in args.blackrules.split(',')]
    YT_OUTPUT_DLVLOG_PATH = args.output_dlvlog if args.output_dlvlog else "//home/so_fml/nirvana/tmp/dlvlog_reduced_tmp_%s_%s" % (ROUTE, d)
    # Report the "pool_gathering" status for this workflow instance.
    ctx = nv.context()
    meta = ctx.get_meta()
    doRequest(SAVE_MODEL_STATUS_URL % (meta.get_workflow_uid(), meta.get_workflow_instance_uid(), ROUTE), 'Saving model status')
    # The reduce below needs both inputs sorted by queueid; unsorted inputs are sorted in place.
    if not ytw.is_sorted(YT_DLVLOG_PATH):
        ytw.run_sort(YT_DLVLOG_PATH, sort_by = "queueid")
    if not ytw.is_sorted(YT_CMPLLOG_PATH):
        ytw.run_sort(YT_CMPLLOG_PATH, sort_by = "queueid")
    # The temporary table lives in the same folder as the output table.
    YT_TMP_DLVLOG_FOLDER = YT_OUTPUT_DLVLOG_PATH[:YT_OUTPUT_DLVLOG_PATH.rfind('/') + 1]
    YT_TMP_DLVLOG_PATH = YT_TMP_DLVLOG_FOLDER + 'dlvlog_reduced_by_vector_' + d
    # Input order matters: LogsReducer treats table index 0 as the delivery log and 1 as the complaints log.
    ytw.run_reduce(LogsReducer(DLVLOG_WHITE_RULES, DLVLOG_BLACK_RULES, ROUTE), [YT_DLVLOG_PATH, YT_CMPLLOG_PATH], YT_OUTPUT_DLVLOG_PATH,
                   reduce_by = "queueid", job_count = 1024, job_io = {"control_attributes": {"enable_table_index": True}})
    sys.stderr.write("Delivery-log has %d records, Complaints records count: %d\n" % (ytw.row_count(YT_DLVLOG_PATH), ytw.row_count(YT_CMPLLOG_PATH)))
    # First collapse keeps the vector column (it is the key of the second collapse).
    ytw.run_map_reduce(DlvLogMapper(YT_RULES_DICT_PATH), LogsReducer2(False, True), YT_OUTPUT_DLVLOG_PATH, YT_TMP_DLVLOG_PATH, reduce_by = ["vector", "rcpt_uid"])
    ytw.run_sort(YT_TMP_DLVLOG_PATH, sort_by = "vector")
    # Second collapse drops the vector column and optionally repeats rows (-m).
    ytw.run_reduce(LogsReducer2(args.multiplicate, False), YT_TMP_DLVLOG_PATH, YT_OUTPUT_DLVLOG_PATH, reduce_by = "vector", job_count = 1024)
    ytw.remove(YT_TMP_DLVLOG_PATH, force = True)
    # Rewrite the output table onto itself with an explicit, non-strict schema.
    ytw.set_attribute(YT_OUTPUT_DLVLOG_PATH, 'optimize_for', 'scan')
    ytw.run_merge(YT_OUTPUT_DLVLOG_PATH,
        '<schema = <strict=%false>[{name = queueid; type = string}; {name = cmpl; type = int64}; {name = target; type = int64}; {name = spam; type = string}; {name = rules; type = string}; {name = r_cancel; type = string}; {name = geo; type = string}; {name = domain; type = string}; {name = rcpt_uid; type = string}; {name = ip; type = string}; {name = rdns; type = string}; {name = msgdate; type = string}; {name = mnf; type = string}; {name = sndr; type = string}; {name = sdmn; type = string}; {name = nnsubj; type = string}; {name = nnbody; type = string}; {name = nnfromaddr; type = string}; {name = nnfromname; type = string}]>'
        + YT_OUTPUT_DLVLOG_PATH, mode = 'ordered', spec = {'combine_chunks': 'true', 'data_size_per_job': 1453936477, 'schema_inference_mode': 'from_output', 'force_transform': True})
    ytw.run_sort(YT_OUTPUT_DLVLOG_PATH, sort_by = "domain")
