#!/usr/bin/python2
# encoding: utf-8
# kate: space-indent on; indent-width 4; replace-tabs on;
#
import sys, argparse, re
from urllib import urlopen
import yt.wrapper as ytw
import nirvana.mr_job_context as nv
from datetime import date
from traceback import format_exception

# Status-reporting endpoint. The three %s slots are filled, in order, with the
# workflow id, the workflow instance id and the mail route (see the
# doRequest call in the __main__ section).
SAVE_MODEL_STATUS_URL = "https://web.so.yandex-team.ru/ml/save_model_status/?workflow_id=%s&workflow_instance_id=%s&status=pool_gathering&route=%s"
# Base YT path of the rules dictionary table. The __main__ section reassigns
# it, either from --rulesdict or to a per-route "rules_dict_<route>" table.
YT_RULES_DICT_PATH = "//home/so_fml/nirvana/rules_dict"

def get_traceback():
    """Return the exception currently being handled as a tab-indented string.

    Every entry produced by traceback.format_exception is stripped, prefixed
    with a tab and terminated with a newline, so the result can be appended
    to a one-line stderr message.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    tb = ''
    for step in format_exception(exc_type, exc_value, exc_traceback):
        try:
            tb += "\t" + step.strip() + "\n"
        except UnicodeError:
            # Python 2: mixing a non-ASCII byte str with a unicode entry fails
            # on concatenation; drop that entry instead of the whole trace.
            pass
    return tb

def doRequest(url, prompt):
    """GET url and return the response body, or "" on any failure.

    A non-200 status or any exception is reported to stderr, prefixed with
    prompt; the caller always receives a string.
    """
    try:
        f = urlopen(url)
        try:
            if f.getcode() == 200:
                return f.read()
            print >>sys.stderr, '{0} response HTTP code: {1}, body: {2}'.format(prompt, f.getcode(), f.read())
        finally:
            # Release the underlying connection on every path.
            f.close()
    except Exception as e:
        print >>sys.stderr, '%s HTTP request failed: %s' % (prompt, str(e))
    return ""

def sndrReputationC(s, stype = 'sndr'):
    """Derive reputation ratio features from a string of integer counters.

    Underscores are treated as separators and every character other than a
    digit or a space is dropped before the counters are parsed. The ratios
    read counters up to position 19, so a string with fewer than 19 counters
    yields an empty dict.

    Args:
        s: counters string.
        stype: 'sndr' -> keys prefixed 'MNF_S'; anything else -> 'MNF_D'.

    Returns:
        dict mapping feature name to float (empty if s has too few counters).
    """
    # Highest counter index read below. m[0] is a padding zero, so m[k] is
    # the k-th parsed counter and m must have more than MAX_INDEX items.
    MAX_INDEX = 19
    pref = 'MNF_S' if stype == 'sndr' else 'MNF_D'
    digits_only = re.sub(r"[^0-9 ]+", "", s.replace("_", ' '))
    m = [0] + [int(x) for x in digits_only.split()]
    if len(m) <= MAX_INDEX:
        return {}
    return {
        pref + '2W_CS' : m[6] / (m[3] + 10.0),
        pref + '2W_CSA': m[6] / (m[3] + m[4] + 10.0),
        pref + '2W_CH' : m[5] / (m[4] + 10.0),
        pref + '2W_CHA': m[5] / (m[3] + m[4] + 10.0),
        pref + '2W_CA' : (m[5] + m[6]) / (m[3] + m[4] + 10.0),
        pref + '_CS'   : m[13] / (m[10] + 50.0),
        pref + '_CSA'  : m[13] / (m[10] + m[11] + 50.0),
        pref + '_CH'   : m[12] / (m[11] + 50.0),
        pref + '_CHA'  : m[12] / (m[10] + m[11] + 50.0),
        pref + '_CA'   : (m[12] + m[13]) / (m[10] + m[11] + 50.0),
        pref + '2W_DH' : m[18] / (m[3] + 10.0),
        pref + '2W_DA' : m[18] / (m[3] + m[4] + 10.0),
        pref + '_DH'   : m[19] / (m[10] + 10.0),
        pref + '_DA'   : m[19] / (m[10] + m[11] + 10.0),
    }

def getFloat(v, default = 0.0):
    """Convert v to float, salvaging a leading number from strings if needed.

    float(v) is tried first. If that fails and v is a string, its leading
    signed decimal number is parsed instead (e.g. "2.5abc" -> 2.5).
    Anything else (no leading number, None or other non-string values)
    yields default.
    """
    try:
        return float(v)
    except (TypeError, ValueError):
        pass
    try:
        # Optional sign, digits, at most one decimal point: every match is a
        # valid float literal. A looser pattern such as \d[\d.]* would also
        # match "1.2.3", which float() then rejects.
        m = re.match(r'^([-+]?\d+(?:\.\d*)?)', v)
    except TypeError:
        # v is not a string (e.g. None): nothing to salvage.
        return default
    return float(m.group(1)) if m else default

@ytw.with_context
class DlvLogMapper:
    """YT mapper that turns one delivery-log row into a pool feature vector.

    Vector slots hold "1" for every rule named in the row's 'rules' /
    'r_cancel' fields. Numeric features from 'mnf', 'vw_w_log' and
    'clusters' are written as "%.5f". Every other slot is "0".

    Args:
        rules_dict: mapping feature name -> slot index in the vector.
        rules_count: number of feature slots; the MatrixNet layout gets one
            extra leading slot that is always "0".
        is_filter: True -> emit structured rows for the de-duplicating
            map-reduce; False -> emit final key/value pool rows directly.
        route: mail route; stored, not read by this mapper.
        model_type: "matrixnet" or any other value (CatBoost layout).
    """
    def __init__(self, rules_dict, rules_count, is_filter = True, route = 'in', model_type = "matrixnet"):
        self.rules = rules_dict
        self.rules_count, self.is_filter, self.route, self.model_type = rules_count, is_filter, route, model_type

    @staticmethod
    def _parse_mnf(mnf):
        """Parse ';'-separated 'name:value' pairs into a dict of floats.

        Malformed pairs are reported to stderr and skipped.
        """
        features = {}
        if not mnf:
            return features
        for pair in mnf.split(';'):
            if not pair:
                continue
            try:
                name, value = pair.split(':')
                features[name] = getFloat(value)
            except Exception as e:
                print >>sys.stderr, "Exception (for MNF record: '%s'): '%s'. %s" % (pair, str(e), get_traceback())
        return features

    def __call__(self, record, context):
        if 'target' not in record or 'queueid' not in record:
            return
        is_matrixnet = self.model_type == "matrixnet"
        # MatrixNet reserves slot 0, so its vector is one element longer.
        vector = ["0"] * (self.rules_count + (1 if is_matrixnet else 0))

        # Only values present in 'mnf' are used; sender-reputation features
        # (sndrReputationC) are not derived here.
        add_ff = self._parse_mnf(record.get("mnf"))
        if is_matrixnet:
            add_ff["VW_SCORE"] = 0.0
        # `or` guards against null column values, which getFloat may not accept.
        add_ff["VW_L_SCORE"] = getFloat(record.get("vw_w_log") or 0.0)

        for field in ('rules', 'r_cancel'):
            if field in record:
                for rule in (record.get(field) or "").split(";"):
                    rule_index = self.rules.get(rule)
                    if rule_index is not None:
                        vector[rule_index] = "1"
        for name, value in add_ff.items():
            rule_index = self.rules.get(name)
            if rule_index is not None:
                vector[rule_index] = "%.5f" % getFloat(value)
        for i, cluster_score in enumerate((record.get("clusters") or "").split(',')):
            rule_index = self.rules.get("VW_SCORE_CLUSTER_%s" % i)
            if cluster_score and rule_index is not None:
                vector[rule_index] = "%.5f" % getFloat(cluster_score)

        if is_matrixnet:
            vector[0] = "0"
            row = {"key": str(context.row_index)}
            if self.is_filter:
                row.update({"target": record['target'], "queueid": record['queueid'], "vector": "\t".join(vector)})
            else:
                row["value"] = "%i\t%s\t0\t%s" % (record['target'], record['queueid'], "\t".join(vector))
        elif self.is_filter:
            # CatBoost filter rows need only target and vector; DlvLogMapper2
            # rebuilds the final key/value from them after the reduce.
            row = {"target": record['target'], "vector": "\t".join(vector)}
        else:
            row = {'key': str(record['target']), 'value': '\t'.join(vector)}
        yield row

def logsReducer(key, records):
    """Collapse all rows sharing one reduce key into a single row.

    The result carries the columns of the first row of the group plus
    "targets_cnt": the number of rows in the group (assuming incoming rows
    have no "targets_cnt" column of their own). An empty group yields
    {"targets_cnt": 0}.
    """
    merged = dict(targets_cnt=0)
    for row in records:
        # Columns are copied only while nothing has been counted yet,
        # i.e. from the first row of the group.
        if merged["targets_cnt"] == 0:
            merged.update(row)
        merged["targets_cnt"] = merged["targets_cnt"] + 1
    yield merged

class DlvLogMapper2:
    """Second-stage mapper: emit final pool rows for vectors seen exactly once.

    Rows whose "targets_cnt" is not 1 are dropped. MatrixNet rows become
    key -> "target\\tqueueid\\t0\\tvector"; other model types become
    str(target) -> vector.
    """
    def __init__(self, model_type = "matrixnet"):
        self.model_type = model_type

    def __call__(self, record):
        if record["targets_cnt"] != 1:
            return
        if self.model_type != "matrixnet":
            yield {"key": str(record['target']), "value": record['vector']}
            return
        pool_key = record['key']
        pool_value = "%i\t%s\t0\t%s" % (record['target'], record['queueid'], record['vector'])
        yield {"key": pool_key, "value": pool_value}

def readRules(rules_dict_path):
    """Load the rules dictionary table from YT.

    Only rows with a non-zero 'act' column are kept.

    Returns a 4-tuple:
        rules_all_dict: rule name -> int(row['num']) for active rules;
        rules_compact_dict: rule name -> consecutive 0-based index among
            active rules;
        rules_list: active rule names in table order;
        rules_count: total number of rows read, active or not.
    """
    rules_all_dict = {}
    rules_compact_dict = {}
    rules_list = []
    rules_count = 0
    rows = ytw.read_table(rules_dict_path, format = ytw.JsonFormat(), raw = False)
    for row in rows:
        rules_count += 1
        if not ('act' in row and int(row.get('act', 0))):
            continue
        num = int(row['num'])
        name = row['rule']
        rules_all_dict[name] = num
        # The compact index is the position this rule takes in rules_list.
        rules_compact_dict[name] = len(rules_list)
        rules_list.append(name)
    return rules_all_dict, rules_compact_dict, rules_list, rules_count

def isNumerical(name):
    """Return True for feature names that carry numeric (non-binary) values.

    These are the MNF_* features, the VW_SCORE_CLUSTER_* features and the
    two VW score columns.
    """
    if name in ('VW_SCORE', 'VW_L_SCORE'):
        return True
    return name.startswith(("MNF_", "VW_SCORE_CLUSTER_"))

if __name__ == "__main__":
    # Unknown arguments are tolerated (parse_known_args), presumably because
    # the workflow runner may pass extra options.
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--rulesdict',         type = str, help = "Input YT table with rules dictionary")
    parser.add_argument('-f', '--filter', action = 'store_true', help = "Whether filter delivery-log by targets or not?")
    parser.add_argument('-d', '--dlvlog',            type = str, help = "Input YT table with mapped delivery-log and VW resolutions log", default = "")
    parser.add_argument('-p', '--matrixnet_pool',    type = str, help = "Path to prepared pool in YT for train MatrixNet model")
    parser.add_argument('-c', '--catboost_pool',     type = str, help = "Path to prepared pool in YT for train CatBoost model")
    parser.add_argument('-t', '--rules_tsv',         type = str, help = "Path to tsv-file with features for CatBoost model")
    parser.add_argument('-m', '--route',             type = str, help = "The type of mail for which the model is calculated")
    args = parser.parse_known_args()[0]
    ROUTE = args.route if args.route else 'in'
    if args.rulesdict:
        YT_RULES_DICT_PATH = args.rulesdict
    else:
        # Pick the per-route fallback first so the warning names the table
        # that will actually be read.
        YT_RULES_DICT_PATH = "//home/so_fml/nirvana/rules_dict_%s" % ROUTE
        print >>sys.stderr, "Input YT table with rules dict path must be set. Default path: %s" % YT_RULES_DICT_PATH
    if args.dlvlog:
        YT_DLVLOG_PATH = ytw.TablePath(args.dlvlog, columns=['target', 'queueid', 'rules', 'r_cancel', 'vw_w_log', 'mnf', "clusters"])
    else:
        # There is no default delivery-log table to fall back to.
        print >>sys.stderr, "Input YT table with mapped dlvlog must be set."
        sys.exit(1)
    # Both pools and the tsv file are written unconditionally below; fail
    # before doing any YT work rather than after it.
    missing = [flag for flag, value in (('--matrixnet_pool', args.matrixnet_pool),
                                        ('--catboost_pool', args.catboost_pool),
                                        ('--rules_tsv', args.rules_tsv)) if not value]
    if missing:
        print >>sys.stderr, "Required output arguments are not set: %s" % ", ".join(missing)
        sys.exit(1)
    YT_POOL_PATH = {"matrixnet": args.matrixnet_pool, "catboost": args.catboost_pool}
    ctx = nv.context()
    meta = ctx.get_meta()
    doRequest(SAVE_MODEL_STATUS_URL % (meta.get_workflow_uid(), meta.get_workflow_instance_uid(), ROUTE), 'Saving model status')

    ytw.config["read_parallel"]["enable"] = True
    rules_dict = {}
    rules_dict["matrixnet"], rules_dict["catboost"], rules_list, rules_count = readRules(YT_RULES_DICT_PATH)

    for model_type in ["matrixnet", "catboost"]:
        # MatrixNet indexes by the table's 'num' column over all rules;
        # CatBoost uses the compact index over active rules only.
        rules_cnt = rules_count if model_type == "matrixnet" else len(rules_list)
        if args.filter:
            # Group rows by identical vector; DlvLogMapper2 then keeps only
            # vectors that occurred exactly once.
            TMP_YT_POOL_PATH = "%s_%s" % (YT_POOL_PATH[model_type], date.today().isoformat())
            ytw.run_map_reduce(DlvLogMapper(rules_dict[model_type], rules_cnt, True, ROUTE, model_type), logsReducer, YT_DLVLOG_PATH, TMP_YT_POOL_PATH, reduce_by = ["vector"], spec = {"map_job_io": {"control_attributes": {"enable_row_index": True}}, "map_job_count": 1024})
            ytw.run_map(DlvLogMapper2(model_type), TMP_YT_POOL_PATH, YT_POOL_PATH[model_type], job_count = 1024)
            ytw.remove(TMP_YT_POOL_PATH, force = True)
        else:
            ytw.run_map(DlvLogMapper(rules_dict[model_type], rules_cnt, False, ROUTE, model_type), YT_DLVLOG_PATH, YT_POOL_PATH[model_type], job_count = 1024, job_io = {"control_attributes": {"enable_row_index": True}})
        ytw.run_sort(YT_POOL_PATH[model_type], sort_by = "key")

    # Column description for the CatBoost pool: column 0 is the label,
    # followed by one numeric column per active rule.
    with open(args.rules_tsv, "w") as f:
        f.write('0\tLabel\n')
        f.writelines(['{index}\t{type}\t{name}\n'.format(index=index+1, type="Num", name=key) for index, key in enumerate(rules_list)])
