#!/usr/bin/python2
# encoding: utf-8
# kate: space-indent on; indent-width 4; replace-tabs on;
#
import sys, re, argparse, subprocess, math, json
from urllib import urlopen, quote
import yt.wrapper as ytw
import nirvana.mr_job_context as nv
from traceback import format_exception

# Default YT directory holding the feature tables; replaced by -d/--directory.
YT_FEATURES_DIRECTORY = "//home/so_fml/nirvana/tmp/fml"
# FML download URL template; placeholders: formula ID, file name inside the formula.
GET_FORMULA_FILE_FROM_FML_URL = "https://fml.yandex-team.ru/download/computed/formula?id=%s&file=%s"
# Endpoint receiving the computed model parameters; "key=value" pairs are appended after "?".
SAVE_FORMULA_INFO_URL = "https://web.so.yandex-team.ru/ml/save_formula_info/?"
# Endpoint for the workflow's intermediate tables info; placeholders: formula ID,
# workflow ID, workflow instance ID, source operation name.
SAVE_WORKFLOW_INTERMIDIATES_INFO_URL = "https://web.so.yandex-team.ru/ml/save_intermidiate_tables_info/?formula_id=%s&workflow_id=%s&workflow_instance_id=%s&source_op=%s"
# Endpoint recording the model status (fixed to "models_params_calc"); placeholders:
# workflow ID, workflow instance ID, formula ID, route.
SAVE_MODEL_STATUS_URL = "https://web.so.yandex-team.ru/ml/save_model_status/?workflow_id=%s&workflow_instance_id=%s&formula_id=%s&status=models_params_calc&route=%s"
# Default path of the JSON file with the resulting model parameters; replaced by -o/--output_params.
OUTPUT_JSON_FILE = './model_params.json'
# Placeholder mx_ops path; the script requires -x/--mx_ops and always replaces it.
MX_OPS_EXECUTABLE = './mx_ops'

def get_traceback():
    """Return the exception currently being handled, formatted as text.

    Every entry produced by traceback.format_exception() is stripped,
    prefixed with a tab and terminated with a newline, so the result can be
    appended to an error message as an indented block. Intended to be called
    from inside an ``except`` clause, since it reads sys.exc_info().
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    steps = format_exception(exc_type, exc_value, exc_traceback)
    # format_exception() always yields strings, so no per-step error handling
    # is needed; join() avoids quadratic string concatenation.
    return ''.join("\t" + step.strip() + "\n" for step in steps)

def fw_score(tp, fp, tn, fn, alpha):
    """Return a weighted F-measure computed from confusion-matrix counts.

    Computes
        (1 + alpha**2) * tp / ((1 + alpha**2) * tp + fn + alpha**2 * fp)
    i.e. the `fp` count is weighted by alpha**2 and the `fn` count by 1.
    With alpha == 1.0 this is the ordinary F1 score. `tn` does not take part
    in the formula; it is accepted so callers can pass the whole matrix.

    Returns 0.0 when the denominator is zero (e.g. tp, fp and fn are all 0)
    instead of raising ZeroDivisionError.
    """
    aa = alpha * alpha
    weighted_tp = (aa + 1.0) * tp
    denominator = weighted_tp + fn + aa * fp
    if denominator == 0:
        return 0.0
    return weighted_tp / denominator

def matrixnet(mx_ops, matrixnetfile, src, dst):
    p = subprocess.Popen("%s info %s" % (mx_ops, matrixnetfile), shell = True, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
    m = re.search(r"Slices:\t(.*?)\n", p.stdout.read())
    slice_info = ''
    if m and m.group(1):
        slice_info = m.group(1)
        f = open("./slicesinfo", "w")
        f.write(slice_info)
        f.close()
        slice_info = "--slices-info ./slicesinfo"
    cmd = '%s calc -s 4 --mr-server hahn.yt.yandex.net --mr-user so_fml --mr-src "%s" --mr-dst "%s" %s %s' % (mx_ops, src, dst, slice_info, matrixnetfile)
    p = subprocess.Popen(cmd, shell = True, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
    p.wait()
    mx_ops_stdout = p.stdout.read()
    if mx_ops_stdout:
        print "MX_OPS stdout: %s\nMX_OPS output's end" % mx_ops_stdout
    mx_ops_stderr = p.stderr.read()
    if mx_ops_stderr:
        print "MX_OPS stderr: %s\nMX_OPS errors output's end" % mx_ops_stderr

def doRequest(url, prompt):
    """Perform an HTTP GET request and return the response body.

    url    -- full request URL.
    prompt -- human-readable label used as the prefix of error messages.

    Returns the body when the server answers with HTTP 200. On any other
    status code, or when the request raises, an error is written to stderr
    and "" is returned, so callers can treat the request as best-effort.
    """
    try:
        response = urlopen(url)
        try:
            code = response.getcode()
            body = response.read()
        finally:
            # Release the underlying socket even when read() fails.
            response.close()
    except Exception as e:
        print >>sys.stderr, '%s HTTP request failed: %s' % (prompt, str(e))
        return ""
    if code == 200:
        return body
    # The message advertises the body, so log the body (not the headers).
    print >>sys.stderr, '{0} response HTTP code: {1}, body: {2}'.format(prompt, code, body)
    return ""

def _query_value(value):
    """Return `value` URL-quoted for use in a query string.

    unicode values are encoded to UTF-8 first: under Python 2, str() on a
    non-ASCII unicode object raises UnicodeEncodeError.
    """
    if isinstance(value, unicode):
        value = value.encode('utf-8')
    return quote(str(value))

def _best_threshold(scored, res_count, alpha = 1.0):
    """Choose the score threshold with the highest F-measure.

    scored    -- list of [target, score] pairs sorted by score ascending.
    res_count -- {0: number of target-0 records, 1: number of target-1
                 records} in `scored`.
    alpha     -- weight passed through to fw_score().

    A record counts as predicted positive when its score is strictly greater
    than the threshold; every record with score <= threshold is predicted
    negative. Candidate thresholds are the distinct scores in `scored`.
    NOTE(review): precision/recall/fp/fn follow this convention (target 1 is
    the positive class); confirm consumers of model_params.json expect it.

    Returns [fw, threshold, tp, tn, fp, fn, precision, recall, f1] for the
    best candidate, or [] when `scored` is empty.
    """
    best, best_fw = [], -1.0
    below = {1: 0, 0: 0}  # records with score <= current candidate, per target
    last = len(scored) - 1
    for i, (target, score) in enumerate(scored):
        below[target] += 1
        # Records with equal scores always fall on the same side of a
        # threshold, so only the last record of such a run is a real cut.
        if i < last and scored[i + 1][1] == score:
            continue
        # Target-1 records below the cut are missed positives (fn);
        # target-0 records above it are false alarms (fp).
        tn, fn = below[0], below[1]
        fp, tp = res_count[0] - tn, res_count[1] - fn
        pre = tp * 1.0 / (tp + fp) if tp + fp > 0 else 0
        rec = tp * 1.0 / (tp + fn) if tp + fn > 0 else 0
        fw = fw_score(tp, fp, tn, fn, alpha)
        if fw > best_fw:
            best_fw = fw
            f1 = 2.0 * pre * rec / (pre + rec) if pre + rec > 0 else 0.0
            best = [fw, score, tp, tn, fp, fn, pre, rec, f1]
    return best

if __name__ == "__main__":
    # Flow: fetch the formula's matrixnet.info from FML, apply it with mx_ops
    # to the YT features table, pick the threshold maximising the F-measure,
    # then report the parameters to web.so and write them to a JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--formula',        type = str, help = "Output formula after training")
    parser.add_argument('-d', '--directory',      type = str, help = "Features' directory path in YT")
    parser.add_argument('-o', '--output_params',  type = str, help = "Output parameters of the model in JSON-format: threshold, model type, F-measure and so on")
    parser.add_argument('-x', '--mx_ops',         type = str, help = "Path to mx_ops executable")
    parser.add_argument('-v', '--vw_resource_id', type = str, help = "Resource ID of VW model")
    parser.add_argument('-r', '--route',          type = str, help = "The type of mail for which the model is calculated")
    args, formula_id, vw_resource_id = parser.parse_known_args()[0], '', ''
    ROUTE = args.route if args.route else 'in'
    if not args.formula:
        print >>sys.stderr, "Must be formula's file name specified!"
        sys.exit(1)
    if args.mx_ops:
        MX_OPS_EXECUTABLE = args.mx_ops
    else:
        print >>sys.stderr, "Path to mx_ops executable must be specified!"
        sys.exit(1)
    if args.directory:
        YT_FEATURES_DIRECTORY = args.directory
    if args.output_params:
        OUTPUT_JSON_FILE = args.output_params
    try:
        with open(args.formula) as f:
            formula_id = json.loads(f.read())['id']
    except Exception as e:
        print >>sys.stderr, "Load formula ID from local file error: %s.%s" % (str(e), get_traceback())
        # Without the ID the FML download and the YT output table path
        # (features_test_<id>, removed below) would be wrong.
        sys.exit(1)
    # The VW resource ID is optional: only read it when a file was given.
    if args.vw_resource_id:
        try:
            with open(args.vw_resource_id) as f:
                vw_resource_id = f.read().strip()
        except Exception as e:
            print >>sys.stderr, "Load VW model resource ID from local file error: %s.%s" % (str(e), get_traceback())

    MATRIXNET_FILE = './matrixnet.info'
    YT_FEATURES = YT_FEATURES_DIRECTORY + '/features'
    YT_FEATURES_TEST = YT_FEATURES_DIRECTORY + '/features_test_' + str(formula_id)

    ctx = nv.context()
    meta = ctx.get_meta()
    matrixnet_info = doRequest(GET_FORMULA_FILE_FROM_FML_URL % (formula_id, "matrixnet.info"), "Retrieving formula's file from FML")
    if not matrixnet_info:
        print >>sys.stderr, "Formula's file matrixnet.info could not be retrieved from FML!"
        sys.exit(1)
    with open(MATRIXNET_FILE, 'w') as f:
        print >>f, matrixnet_info
    doRequest(SAVE_MODEL_STATUS_URL % (meta.get_workflow_uid(), meta.get_workflow_instance_uid(), formula_id, ROUTE), 'Saving model status')
    if ytw.exists(YT_FEATURES_TEST):
        ytw.remove(YT_FEATURES_TEST, force = True)
    matrixnet(MX_OPS_EXECUTABLE, MATRIXNET_FILE, YT_FEATURES, YT_FEATURES_TEST)
    if not (ytw.exists(YT_FEATURES_TEST) and ytw.row_count(YT_FEATURES_TEST) == ytw.row_count(YT_FEATURES)):
        print >>sys.stderr, "MX_OPS error: output YT table does not exist or its row count differs from the input table!"
        sys.exit(1)

    # Each output row's "value" is tab-separated: target first, model score last.
    matrixnet_result_list, res_count = [], {1: 0, 0: 0}
    for record in ytw.read_table(YT_FEATURES_TEST, format = ytw.JsonFormat(), raw = False):
        rec_list = str(record["value"]).split("\t")
        target, score = int(rec_list[0]), float(rec_list[-1])
        matrixnet_result_list.append([target, score])
        res_count[target] += 1
    # With no target-1 rows no threshold is meaningful (and F-measures would
    # divide by zero); this also covers an empty table.
    if res_count[1] == 0:
        print >>sys.stderr, "No records with target 1 in %s: cannot choose a threshold!" % YT_FEATURES_TEST
        sys.exit(1)
    matrixnet_result_list.sort(key = lambda l: l[-1])
    param_max = _best_threshold(matrixnet_result_list, res_count, 1.0)

    workflow_url = meta.get_workflow_url()
    if workflow_url.endswith(meta.get_workflow_uid()):
        workflow_url += '/' + meta.get_workflow_instance_uid()
    output_params = {
        'formula_id':           formula_id,
        'route':                ROUTE,
        'r_alpha':              param_max[0],
        'threshold':            param_max[1],
        'tp':                   int(param_max[2]),
        'tn':                   int(param_max[3]),
        'fp':                   int(param_max[4]),
        'fn':                   int(param_max[5]),
        'precision':            param_max[6],
        'recall':               param_max[7],
        'f1_score':             param_max[8],
        'rules_dict':           YT_FEATURES_DIRECTORY + '/rules_full_dict',
        'workflow_id':          meta.get_workflow_uid(),
        'workflow_instance_id': meta.get_workflow_instance_uid(),
        'workflow_url':         workflow_url,
        'workflow_description': meta.get_description(),
        'vw_model_resource_id': vw_resource_id if vw_resource_id else '',
        'model_type':           'matrixnet'
    }
    query = '&'.join("%s=%s" % (key, _query_value(value)) for key, value in output_params.items())
    doRequest(SAVE_FORMULA_INFO_URL + query, 'Saving model\'s info')
    doRequest(SAVE_WORKFLOW_INTERMIDIATES_INFO_URL % (formula_id, meta.get_workflow_uid(), meta.get_workflow_instance_uid(), 'calc_threshold'), 'Saving workflow intermidiate tables info')
    try:
        with open(OUTPUT_JSON_FILE, 'wt') as f:
            print >>f, json.dumps(output_params)
    except Exception as e:
        print >>sys.stderr, 'Saving result file error: %s.%s' % (str(e), get_traceback())
