#!/usr/bin/python2
# encoding: utf-8
# kate: space-indent on; indent-width 4; replace-tabs on;
#
import sys, argparse, json
import yt.wrapper as ytw
import nirvana.mr_job_context as nv
from urllib import urlopen
from datetime import date
from time import time
from math import ceil
from random import random
from traceback import format_exception

SAVE_MODEL_STATUS_URL = "https://so-web.n.yandex-team.ru/ml/save_model_status/?workflow_id=%s&workflow_instance_id=%s&status=pool_gathering&route=%s"
GET_DOMAINS_FILTER_ITEMS_URL = "https://so-web.n.yandex-team.ru/ml/get_domains_filter_items?route="

def get_traceback():
    """Format the exception currently being handled as tab-indented text.

    Every chunk returned by ``traceback.format_exception`` is stripped of
    surrounding whitespace, prefixed with a tab and terminated by a newline.

    Returns:
        The concatenated chunks. A chunk that cannot be appended is skipped,
        so the function itself does not raise while formatting.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    tb = ''
    for step in format_exception(exc_type, exc_value, exc_traceback):
        try:
            tb += "\t" + step.strip() + "\n"
        except Exception:
            # Best effort: drop a chunk that cannot be concatenated. Unlike a
            # bare ``except``, this lets KeyboardInterrupt/SystemExit through.
            continue
    return tb

def doRequest(url, prompt):
    """Fetch ``url`` and return the response body, or "" on any failure.

    Args:
        url: Address passed to ``urlopen``.
        prompt: Short description of the request; used as the prefix of the
            diagnostics written to stderr.

    Returns:
        The response body when the HTTP status code is 200, otherwise an
        empty string. Failures are reported on stderr and never raised.
    """
    try:
        f = urlopen(url)
        try:
            code = f.getcode()
            if code == 200:
                return f.read()
            # NOTE: the value logged under "body" is f.info(), i.e. the
            # response meta-information, not the payload itself.
            print >>sys.stderr, '{0} response HTTP code: {1}, body: {2}'.format(prompt, code, f.info())
        finally:
            # Release the underlying connection on every path.
            f.close()
    except Exception as e:
        print >>sys.stderr, '%s HTTP request failed: %s' % (prompt, str(e))
    return ""

def getDomainsLimits(route):
    """Download and decode the domains filter items for ``route``.

    Args:
        route: Value appended to GET_DOMAINS_FILTER_ITEMS_URL.

    Returns:
        A dict filled from the JSON response. It is empty when the request
        returned nothing, and holds whatever was merged before the error when
        decoding fails (the error is reported on stderr).
    """
    limits = {}
    payload = doRequest(GET_DOMAINS_FILTER_ITEMS_URL + route, 'Retrieving domains limits')
    if not payload:
        return limits
    try:
        decoded = json.loads(payload)
        limits.update(decoded)
        print >>sys.stderr, "Obtained domains limits: %s" % str(limits)
    except Exception as err:
        print >>sys.stderr, 'Domains limits JSON-data parsing failed: %s.%s' % (str(err), get_traceback())
    return limits

@ytw.with_context
class LogSplitMapper:
    """Mapper that routes every record to one of two output tables by domain.

    Records whose ``domain`` is listed in the limits for ``output_type`` are
    written to output table 1; all other records go to output table 0.
    """
    def __init__(self, output_type, route):
        """Load the domains limits for ``route``.

        Args:
            output_type: Key selecting the per-type section of the limits.
            route: Route passed to getDomainsLimits.
        """
        self.output_type = output_type
        self.route = route
        self.domains = getDomainsLimits(route)
        # getDomainsLimits returns {} when the request or parsing fails, and
        # the response may lack this output type. Fall back to "no limited
        # domains" instead of raising KeyError on every record.
        self.limited_domains = self.domains.get(output_type, {})
    def __call__(self, record, context):
        """Yield a table switch followed by the unchanged record."""
        yield ytw.create_table_switch(1 if record['domain'] in self.limited_domains else 0)
        yield record

@ytw.with_context
def logComplaintsMapper(record, context):
    """Route a record by its ``cmpl`` field.

    Records with ``cmpl`` greater than -1 go to output table 1, the rest to
    output table 0. The record itself is emitted unchanged after the switch.
    """
    if record['cmpl'] > -1:
        table_index = 1
    else:
        table_index = 0
    yield ytw.create_table_switch(table_index)
    yield record

@ytw.with_context
class LogMapper:
    """Mapper that thins a table by keeping rows at a fixed step.

    A row is kept when the integer part of ``row_index % step`` is zero;
    ``step`` may be fractional.
    """
    def __init__(self, step):
        """Remember the thinning step."""
        self.step = step

    def __call__(self, record, context):
        """Yield ``record`` only when its row index falls on the step."""
        remainder = context.row_index % self.step
        if int(remainder) != 0:
            return
        yield record

@ytw.with_context
class LogDomainsReducer:
    """Reducer that caps the number of records emitted per domain.

    Domains without a limit, or with no more records than their limit, pass
    through untouched. Otherwise the records are thinned at a regular step;
    a limit that is not positive drops every record of that domain.
    """
    def __init__(self, output_type, route):
        """Load the limits for ``route`` and pick the ``output_type`` section.

        When the loaded limits have no ``output_type`` key, the whole loaded
        mapping is kept as is.
        """
        self.output_type = output_type
        self.route = route
        limits = getDomainsLimits(route)
        self.domains = limits[output_type] if output_type in limits else limits
        print >>sys.stderr, "Route: %s. Resulted domains limits: %s" % (route, str(self.domains))

    def __call__(self, key, records, context):
        """Yield the (possibly thinned) records of the domain in ``key``."""
        # All rows are materialised because the total is needed up front.
        rows = list(records)
        total, name = len(rows), key['domain']
        over_limit = name in self.domains and self.domains[name] < total
        if not over_limit:
            for row in rows:
                yield row
            return
        limit = self.domains[name]
        print >>sys.stderr, "Domain: %s, Total records: %s, Limit: %s" % (name, total, limit)
        if not limit > 0:
            return
        # total/limit rounded up to the nearest multiple of 0.5.
        step = ceil(total * 2.0 / limit) * 1.0 / 2.0
        print >>sys.stderr, "Domain: %s, Total records: %s, Limit: %s, Step: %s" % (name, total, limit, step)
        for idx, row in enumerate(rows):
            if int(idx % step) == 0:
                yield row

def safeRemoveTable(yt, table):
    """Remove ``table`` through ``yt`` if it exists; otherwise do nothing.

    Args:
        yt: Object providing ``exists(path)`` and ``remove(path, force=...)``.
        table: Path of the table to delete.
    """
    if not yt.exists(table):
        return
    yt.remove(table, force=True)

if __name__ == "__main__":
    # Script entry point. Steps:
    #   1. split the input delivery-log into rows with and without a limited domain;
    #   2. thin the limited-domain rows per domain (LogDomainsReducer);
    #   3. for the 'mn' output type, split by complaints and thin both parts so
    #      that the records limit and the max complaints percent are respected;
    #      for other types, thin everything uniformly and sort by queueid;
    #   4. write the output table and remove the temporary tables.
    d = date.today().isoformat()
    # Defaults; each may be overridden by a command-line option below.
    FML_DLVLOG_RECORDS_LIMIT = 10000000
    MAX_COMPLAINTS_PART = 30.0
    OUTPUT_DLVLOG_TYPE = "mn"
    YT_OUTPUT_DLVLOG_PATH = "//home/so_fml/nirvana/tmp/dlvlog_truncated_tmp_%s_%s" % (d, int(1000 * random()))
    YT_TABLE_SCHEMA = '<schema = <strict=%false>[{name = queueid; type = string}; {name = cmpl; type = int64}; {name = target; type = int64}; {name = spam; type = string}; {name = rules; type = string}; {name = r_cancel; type = string}; {name = geo; type = string}; {name = domain; type = string}; {name = ip; type = string}; {name = rdns; type = string}]>'
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dlvlog',          type = str,   help = "input YT table with mapped delivery-log")
    parser.add_argument('-c', '--cmpl_maxpercent', type = float, help = "input max percent of complaints in output YT table")
    parser.add_argument('-n', '--max_total',       type = int,   help = "input max records count in output YT table")
    parser.add_argument('-t', '--output_type',     type = str,   help = "output delivery-log type: 'vw' or 'mn'")
    parser.add_argument('-o', '--output_dlvlog',   type = str,   help = "path to output dlvlog in YT")
    parser.add_argument('-r', '--route',           type = str,   help = "The type of mail for which the model is calculated")
    # Unknown options are tolerated: only the recognised namespace is used.
    args = parser.parse_known_args()[0]
    if not args.dlvlog:
        # There is no default input table, so the message must not try to
        # format one (doing so raised NameError before the message was shown).
        print >>sys.stderr, "Input YT table with mapped dlvlog must be set (option -d/--dlvlog)"
        sys.exit(1)
    YT_DLVLOG_PATH = args.dlvlog
    # A zero/empty option value is treated as "not given" and keeps the default.
    if args.cmpl_maxpercent:
        MAX_COMPLAINTS_PART = args.cmpl_maxpercent
    if args.max_total:
        FML_DLVLOG_RECORDS_LIMIT = args.max_total
    if args.output_type:
        OUTPUT_DLVLOG_TYPE = args.output_type
    if args.output_dlvlog:
        YT_OUTPUT_DLVLOG_PATH = args.output_dlvlog
    ROUTE = args.route if args.route else 'in'

    # Temporary tables live next to the input table; the timestamp and random
    # suffix keep their names distinct between runs.
    YT_DLVLOG_FOLDER = YT_DLVLOG_PATH[:YT_DLVLOG_PATH.rfind('/')]
    YT_DLVLOG_NODMNS_PATH = YT_DLVLOG_FOLDER + '/dlvlog_without_domains_%s_%s' % (time(), int(100 * random()))
    YT_DLVLOG_DOMAINS_PATH = YT_DLVLOG_FOLDER + '/dlvlog_with_domains_%s_%s' % (time(), int(100 * random()))
    YT_DLVLOG_TRUNCATED_DOMAINS_PATH = YT_DLVLOG_FOLDER + '/dlvlog_truncated_domains_%s_%s' % (time(), int(100 * random()))
    YT_DLVLOG_NOCMPL_PATH = YT_DLVLOG_FOLDER + '/dlvlog_without_complaints_%s_%s' % (time(), int(100 * random()))
    YT_DLVLOG_CMPL_PATH = YT_DLVLOG_FOLDER + '/dlvlog_with_complaints_%s_%s' % (time(), int(100 * random()))
    ctx = nv.context()
    meta = ctx.get_meta()
    # Report the "pool_gathering" status for this workflow instance.
    doRequest(SAVE_MODEL_STATUS_URL % (meta.get_workflow_uid(), meta.get_workflow_instance_uid(), ROUTE), 'Saving model status')
    # Output 0: rows without a limited domain; output 1: rows with one.
    ytw.run_map(LogSplitMapper(OUTPUT_DLVLOG_TYPE, ROUTE), YT_DLVLOG_PATH, [YT_DLVLOG_NODMNS_PATH, YT_DLVLOG_DOMAINS_PATH], job_count = 1000, spec = {"job_io": {"control_attributes": {"enable_table_index": True}}})
    if not ytw.is_sorted(YT_DLVLOG_DOMAINS_PATH):
        ytw.run_sort(YT_DLVLOG_DOMAINS_PATH, sort_by = "domain")
    ytw.run_reduce(LogDomainsReducer(OUTPUT_DLVLOG_TYPE, ROUTE), YT_DLVLOG_DOMAINS_PATH, YT_DLVLOG_TRUNCATED_DOMAINS_PATH,
                   reduce_by = "domain", job_count = 1000, spec = {"job_io": {"control_attributes": {"enable_table_index": True}}, 'reducer': {'memory_limit': 34359738368}})
    safeRemoveTable(ytw, YT_DLVLOG_DOMAINS_PATH)
    truncated_dlvlog_cnt = ytw.row_count(YT_DLVLOG_TRUNCATED_DOMAINS_PATH) + ytw.row_count(YT_DLVLOG_NODMNS_PATH)
    print >>sys.stderr, "Truncated delivery-log has %d records, original delivery-log has %d records" % (truncated_dlvlog_cnt, ytw.row_count(YT_DLVLOG_PATH))
    if OUTPUT_DLVLOG_TYPE == 'mn':
        # Output 0: rows without complaints; output 1: rows with complaints.
        ytw.run_map(logComplaintsMapper, [YT_DLVLOG_TRUNCATED_DOMAINS_PATH, YT_DLVLOG_NODMNS_PATH], [YT_DLVLOG_NOCMPL_PATH, YT_DLVLOG_CMPL_PATH], job_count = 1000, spec = {"job_io": {"control_attributes": {"enable_table_index": True}}})
        no_cmpl_count, cmpl_count = ytw.row_count(YT_DLVLOG_NOCMPL_PATH), ytw.row_count(YT_DLVLOG_CMPL_PATH)
        total_count = no_cmpl_count + cmpl_count
        # Both tables may be empty; avoid dividing by zero in that case.
        cmpl_percent = cmpl_count * 100.0 / total_count if total_count > 0 else 0.0
        if total_count > FML_DLVLOG_RECORDS_LIMIT or cmpl_percent > MAX_COMPLAINTS_PART:
            # Largest complaints count allowed by both the records limit and
            # the max complaints percent relative to the non-complaint rows.
            r = min(MAX_COMPLAINTS_PART * FML_DLVLOG_RECORDS_LIMIT / 100.0, MAX_COMPLAINTS_PART * no_cmpl_count / (100.0 - MAX_COMPLAINTS_PART))
            if cmpl_count > r:
                # NOTE(review): r is 0 when there are no non-complaint rows,
                # which makes this division fail - confirm the desired
                # behaviour for that case.
                # Steps are ratios rounded up to the nearest multiple of 0.5.
                step = ceil(cmpl_count * 2.0 / r) * 1.0 / 2.0
                print >>sys.stderr, "Complaints count for precast delivery-log is %s. So we'll thin it by step %.2f" % (cmpl_count, step)
                ytw.run_map(LogMapper(step), YT_DLVLOG_CMPL_PATH, YT_DLVLOG_CMPL_PATH, job_count = 1000, spec = {"job_io": {"control_attributes": {"enable_row_index": True}}})
                cmpl_count = ytw.row_count(YT_DLVLOG_CMPL_PATH)
            if no_cmpl_count + cmpl_count > FML_DLVLOG_RECORDS_LIMIT:
                step = ceil(no_cmpl_count * 2.0 / (FML_DLVLOG_RECORDS_LIMIT - cmpl_count)) * 1.0 / 2.0
                print >>sys.stderr, "Non-complaints rows count for precast delivery-log is %s. So we'll thin it by step %.2f" % (no_cmpl_count, step)
                ytw.run_map(LogMapper(step), YT_DLVLOG_NOCMPL_PATH, YT_DLVLOG_NOCMPL_PATH, job_count = 1000, spec = {"job_io": {"control_attributes": {"enable_row_index": True}}})
        ytw.run_merge([YT_DLVLOG_NOCMPL_PATH, YT_DLVLOG_CMPL_PATH], YT_TABLE_SCHEMA + YT_OUTPUT_DLVLOG_PATH, spec = {'combine_chunks': 'true', 'data_size_per_job': 1453936477})
        safeRemoveTable(ytw, YT_DLVLOG_NOCMPL_PATH)
        safeRemoveTable(ytw, YT_DLVLOG_CMPL_PATH)
    else:
        if truncated_dlvlog_cnt > FML_DLVLOG_RECORDS_LIMIT:
            step = ceil(truncated_dlvlog_cnt * 2.0 / FML_DLVLOG_RECORDS_LIMIT) * 1.0 / 2.0
            ytw.run_map(LogMapper(step), [YT_DLVLOG_TRUNCATED_DOMAINS_PATH, YT_DLVLOG_NODMNS_PATH], YT_TABLE_SCHEMA + YT_OUTPUT_DLVLOG_PATH, job_count = 1000, spec = {"job_io": {"control_attributes": {"enable_row_index": True}}})
        else:
            ytw.run_merge([YT_DLVLOG_TRUNCATED_DOMAINS_PATH, YT_DLVLOG_NODMNS_PATH], YT_TABLE_SCHEMA + YT_OUTPUT_DLVLOG_PATH, spec = {'combine_chunks': 'true', 'data_size_per_job': 1453936477})
        ytw.run_sort(YT_OUTPUT_DLVLOG_PATH, sort_by = "queueid")    # shuffle VW train pool: SODEV-1297
    safeRemoveTable(ytw, YT_DLVLOG_NODMNS_PATH)
    safeRemoveTable(ytw, YT_DLVLOG_TRUNCATED_DOMAINS_PATH)
