#!/usr/bin/env python
#
import sys, re, argparse, json
import yt.wrapper as ytw
from datetime import date, datetime, timedelta
from time import localtime, mktime, strftime, strptime, sleep, localtime
from collections import defaultdict
from threading import Thread
from traceback import format_exception

# YT directories with the input logs read by this script.
YT_LOG_DIR_1D = "//home/logfeller/logs/axis-all-persfacts-log/1d"
YT_LOG_DIR_5MIN = "//home/logfeller/logs/axis-all-persfacts-log/stream/5min"
YT_LOG_DIR_ML_OUT = "//home/logfeller/logs/mail-so-out-log/1d"
YT_LOG_DIR_ML_OUT_30MIN = "//home/logfeller/logs/mail-so-out-log/30min"
# Common prefix of every temporary table this script creates.
YT_TMP_LOG_PATH = "//home/so_fml/nirvana/tmp/tmp_bounceslog"
# NOTE(review): not referenced in the visible part of this file.
USERINFO_URL = "https://web.so.yandex-team.ru/tools/get_user_info?"
# Message ids are truncated to this many characters before being used as keys.
MSGID_SIZE_LIMIT = 16382
# Event times are rounded down to a multiple of this interval.
RECALC_TIME = 300    # in seconds
# NOTE(review): not referenced in the visible part of this file.
MILLENNIUM = date(2000, 1, 1)
# Bounce record field -> prefix of the "<prefix>_total" / "<prefix>_uniq"
# columns written to the totals statistics.
KEYS = {
    "sender":           "senders",
    "sender_domain":    "senders_domains",
    "recipient":        "recipients",
    "recipient_domain": "recipients_domains"
}

# Patterns applied to a bounce's diagnostic code (or reporting MTA for the
# SMTP one) to classify it when the log row carries no explicit bounce traits.
BOUNCE_SPAM_RE = re.compile(r"(?:spam|sender\s+domain\s+must\s+exist|rejected\s+due\s+to\s+(?:DMARC\s+policy|content\s+restrictions)|"
                            + r"Mail\s+content\s+denied|Sender\s+denied)", re.S | re.I)
BOUNCE_REGULAR_RE = re.compile(r"(?:quota|blocked|inactive|relay)", re.S | re.I)
#BOUNCE_SMTP_RE = re.compile(r"(?:\nX-Mailer: Yamail|Received:\s+from.*by smtp[^\n]+?\((?:nw)?smtp(?:corp)?\/Yandex\) with E?SMTP)", re.S | re.I)
BOUNCE_SMTP_RE = re.compile(r"(?:\b(?:mx)?smtp(?:corp)?\b)", re.S | re.I)
BOUNCE_UNKNOWN_RE = re.compile(r"((?:no\s+such|unknown|not\s+stored\s+this)\s+user|(?:unknown|bad)\s+recipient|undeliverable\s+address|"
                               + r"(recipient|addressee|user)\s+(undeliverable|.*?\bunknown\b|.*?\bnot\s+found\b)|invalid\s+(mailbox|recipient)|"
                               + r"unknown\s+or\s+illegal\s+alias|5\d\d.*?unknown|no\s+(?:mailbox\s+here|valid\s+recipients)|"
                               + r"user\s+doesn't\s+have\s+a\s+account|(account|mailbox).*?\b(unavailable|not\s+(?:exists?|found))\b)", re.S | re.I)
BOUNCE_FORWARD_RE = re.compile(r"(?:\bX-Yandex-F(?:orwar|w)d\:)", re.S | re.I)
# Group 1 is the local part, group 2 the domain (plain or punycode labels).
EMAIL_RE = re.compile(r"([^\n@]+)\@((?:(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?|xn\-\-[a-z0-9-]+)\.)+(?:xn\-\-[a-z0-9-]+|[a-z]+))", re.I)


def get_traceback():
    """Return the exception currently being handled as tab-indented text.

    Every element produced by ``traceback.format_exception`` for
    ``sys.exc_info()`` is stripped, prefixed with a tab and terminated with a
    newline.

    Returns:
        The concatenated text.  Elements that cannot be appended are skipped,
        so the result may be partial but the call itself does not raise.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    tb = ''
    for step in format_exception(exc_type, exc_value, exc_traceback):
        try:
            tb += "\t" + step.strip() + "\n"
        except Exception:
            # Best effort: drop a step that cannot be concatenated (e.g. a
            # byte/unicode mix on Python 2) instead of losing the whole report.
            # Unlike a bare ``except`` this lets KeyboardInterrupt/SystemExit
            # propagate.
            continue
    return tb


def log(s, isTB=False):
    """Print a message to standard error.

    Args:
        s: Message text.
        isTB: When true, the text returned by ``get_traceback()`` is appended
            to the message before it is printed.
    """
    message = s + get_traceback() if isTB else s
    print >>sys.stderr, message


def unquote(s):
    """Undo one level of double escaping in *s*.

    Applied in this order: two backslashes followed by a double quote become
    a double quote, two backslashes followed by a single quote become a
    single quote, and four backslashes become two.
    """
    replacements = (
        (r'\\"', r'"'),
        (r"\\'", r"'"),
        (r"\\\\", r"\\"),
    )
    for escaped, plain in replacements:
        s = s.replace(escaped, plain)
    return s


@ytw.with_context
def axisLogMapper(record, context):
    """Map one input log row to a normalized bounce record.

    Rows whose ``type`` is not ``"bounce"`` or whose ``data`` is empty are
    skipped.  ``data`` is passed through ``unquote`` and parsed as JSON; rows
    that fail to parse are logged and skipped, and so are rows whose parsed
    payload has no non-empty ``facts`` list.  Only ``facts[0]`` is used.

    Bounce types are taken (lower-cased) from ``facts[0]["bounce_traits"]``
    when present and non-empty; otherwise they are derived from the
    diagnostic code, the reporting MTA and the number of original message ids.

    Args:
        record: Input row; reads ``type``, ``data`` and ``iso_eventtime``.
        context: Job context passed by the YT wrapper; not used here.

    Yields:
        At most one dict with the keys ``uid``, ``date``, ``stid``,
        ``bounce_types``, ``sender``, ``recipient``, ``sender_domain``,
        ``recipient_domain`` and ``message_id``.  ``date`` is the event time
        rounded down to a multiple of ``RECALC_TIME`` seconds and formatted as
        ``%Y-%m-%d %H:%M``.
    """
    if record.get("type") != "bounce" or not record.get("data"):
        return
    # Bound before the ``try`` so the error message below can always show it.
    data = record["data"]
    try:
        data = unquote(data)
        data = json.loads(data)
    except Exception as e:
        log("Exception: %s. Data: %s" % (str(e), data), True)
        return
    # A payload without a usable ``facts`` list is skipped rather than
    # allowed to fail the whole job.
    facts = data.get("facts") if isinstance(data, dict) else None
    if not isinstance(facts, list) or not facts or not isinstance(facts[0], dict):
        return
    fact = facts[0]

    bounce_types, sender_domain, recipient_domain = [], "", ""
    original_from = unquote(fact.get("original_from", "").strip().strip("'").lower())
    original_recipient = unquote(fact.get("original_recipient", "").strip().strip("'").lower())
    original_message_ids = fact.get("original_message_id", [])
    message_ids = [msg_id.strip()[:MSGID_SIZE_LIMIT] for msg_id in original_message_ids]

    traits = fact.get("bounce_traits")
    if traits:
        # ``.lower()`` works for both byte and unicode strings.
        bounce_types = [trait.lower() for trait in traits]
    else:
        diagnostic_code = fact.get("diagnostic_code", "")
        if BOUNCE_SPAM_RE.search(diagnostic_code):
            bounce_types.append("spam")
        if BOUNCE_REGULAR_RE.search(diagnostic_code):
            bounce_types.append("regular")
        if BOUNCE_UNKNOWN_RE.search(diagnostic_code):
            bounce_types.append("unknown")
        if BOUNCE_SMTP_RE.search(fact.get("reporting_mta", "")):
            bounce_types.append("smtp")
        if len(original_message_ids) > 1:
            bounce_types.append("forward")
        # NOTE(review): "received-date" is read from the top-level payload
        # while every other field comes from facts[0] - confirm the schema.
        has_required_fields = original_from and original_recipient and data.get("received-date", "")
        if "regular" not in bounce_types and "forward" not in bounce_types and not has_required_fields:
            bounce_types.append("invalid")

    # Round the event time down to the RECALC_TIME grid.
    t = mktime(strptime(record["iso_eventtime"], "%Y-%m-%d %H:%M:%S"))
    t = strftime("%Y-%m-%d %H:%M", localtime(int(t) // RECALC_TIME * RECALC_TIME))

    m = EMAIL_RE.match(original_from)
    if m:
        sender_domain = m.group(2)
    m = EMAIL_RE.match(original_recipient)
    if m:
        recipient_domain = m.group(2)
    yield {
        "uid":              data.get("uid", ""),
        "date":             t,
        "stid":             data.get("stid", ""),
        "bounce_types":     bounce_types,
        "sender":           original_from,
        "recipient":        original_recipient,
        "sender_domain":    sender_domain,
        "recipient_domain": recipient_domain,
        "message_id":       message_ids
    }


@ytw.with_context
class UniqBouncesReducer:
    """Reducer that keeps only the first record of every reduce group.

    For the daily scale (``scale == 'd'``) it additionally writes, to output
    table 1, one ``{stid, original_from, message_id}`` row per message id of
    the kept record, unless the record's bounce types contain ``"invalid"``.
    The kept record itself always goes to output table 0.
    """

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, key, records, context):
        first = next(records)
        daily = self.scale == 'd'
        if daily:
            yield ytw.create_table_switch(1)
            valid = "invalid" not in first["bounce_types"]
            if valid:
                for msg_id in first["message_id"]:
                    yield {
                        "stid":          first["stid"],
                        "original_from": first["sender"],
                        "message_id":    msg_id
                    }
            yield ytw.create_table_switch(0)
        yield first


@ytw.with_context
def mlOutLogMapper(record, context):
    """Map a mail-out log row to a ``{message_id, from}`` pair.

    Rows with a missing or empty ``msid`` or ``from`` produce nothing.  The
    message id is stripped and truncated to ``MSGID_SIZE_LIMIT`` characters.
    """
    msid = record.get("msid")
    sender = record.get("from")
    if msid and sender:
        yield {"message_id": msid.strip()[:MSGID_SIZE_LIMIT], "from": sender}


@ytw.with_context
def mlOutLogReducer(key, records, context):
    """Emit bounce message ids that have no counterpart in the mail-out log.

    Input rows are of two kinds: rows with a non-empty ``stid`` (bounce
    message ids) and rows with a ``from`` field and a non-empty
    ``message_id`` (mail-out log).  For the mail-out rows the message ids are
    collected and the last seen ``from`` is remembered.

    Yields:
        ``{"message_id", "stid"}`` for every bounce row whose message id is
        not among the collected mail-out ids and whose ``original_from`` is
        empty or equal to the remembered mail-out ``from``.
    """
    bounce_rows, mlout_msgids, mlout_from = [], [], ""
    for record in records:
        if record.get("stid"):
            bounce_rows.append(record)
        elif "from" in record and record["message_id"]:
            mlout_msgids.append(record["message_id"])
            mlout_from = record["from"]
    # NOTE(review): when the rows of a group all share one message_id, the
    # first condition holds only if the group has no mail-out row, in which
    # case mlout_from is still "" - so only bounces with an empty
    # original_from can pass.  Confirm this is the intended filter.
    for record in bounce_rows:
        sender = record["original_from"]
        if record["message_id"] not in mlout_msgids and (not sender or sender == mlout_from):
            yield {"message_id": record["message_id"], "stid": record["stid"]}


@ytw.with_context
def bouncesMsgIdsReducer(key, records, context):
    """Mark a bounce as invalid when all of its message ids are absent.

    Rows with a non-null ``uid`` are treated as the bounce record (their
    fields are merged into one dict); all other rows contribute their
    ``message_id`` to the list of absent ids.

    If the bounce is not already ``"invalid"`` and the number of its message
    ids equals the number of absent ids, the unmodified bounce is written to
    output table 1 and ``"invalid"`` is added to its bounce types.  The
    (possibly updated) bounce is always written to output table 0.

    NOTE: a bounce with no message ids and no absent ids also satisfies the
    equality (0 == 0) and is therefore marked invalid.
    """
    bounce_info, absent_msgids = {}, []
    for record in records:
        if record.get("uid") is not None:
            bounce_info.update(record)
        else:
            absent_msgids.append(record["message_id"])
    if not bounce_info:
        # No bounce row in this group (e.g. its uid is null): there is
        # nothing to annotate, so report and skip instead of failing the job
        # with a KeyError below.
        sys.stderr.write("bouncesMsgIdsReducer: no bounce record for key %r, group skipped\n" % (key,))
        return
    if "invalid" not in bounce_info["bounce_types"] and len(bounce_info["message_id"]) == len(absent_msgids):
        yield ytw.create_table_switch(1)
        # Hand out a copy and rebind the list afterwards, so the row already
        # yielded for table 1 is never modified behind the consumer's back.
        yield dict(bounce_info)
        bounce_info["bounce_types"] = list(bounce_info["bounce_types"]) + ["invalid"]
        yield ytw.create_table_switch(0)
    yield bounce_info


@ytw.with_context
class BouncesMapper:
    """Mapper that fans a bounce record out into one row per bounce type.

    For every input record it emits a row with ``bounce_type`` set to
    ``"TOTAL"`` followed by one row per entry of ``record["bounce_types"]``.
    Each row carries ``uid``, ``date`` and the field named by ``key``.  For
    the daily scale (``scale == 'd'``) only the part of ``date`` before the
    first whitespace is kept.
    """

    def __init__(self, scale, key):
        self.scale = scale
        self.key = key

    def _row(self, record, bounce_type):
        # Build one output row for the given bounce type.
        return {
            "uid":         record["uid"],
            "date":        record["date"].split()[0] if self.scale == 'd' else record["date"],
            "bounce_type": bounce_type,
            self.key:      record[self.key]
        }

    def __call__(self, record, context):
        yield self._row(record, "TOTAL")
        for bounce_type in record["bounce_types"]:
            yield self._row(record, bounce_type)


@ytw.with_context
class BouncesDomainsReducer:
    """Count the rows of one (uid, date, bounce_type, <key>) group.

    Emits two rows with the same count in ``total``: one whose ``<key>``
    column is the synthetic value ``"TOTAL"`` and one carrying the group's
    real ``<key>`` value.
    """

    def __init__(self, key):
        self.key = key

    def __call__(self, sortKey, records, context):
        # ``records`` can only be walked once, so count it a single time and
        # reuse the number; counting it again for the second row would see an
        # exhausted iterator and report 0.
        total = sum(1 for _ in records)
        for key_value in ("TOTAL", sortKey[self.key]):
            yield {
                "uid":         sortKey["uid"],
                "date":        sortKey["date"],
                "bounce_type": sortKey["bounce_type"],
                self.key:      key_value,
                "total":       total
            }


@ytw.with_context
class BouncesReducer:
    """Count the rows of one (date, bounce_type, <key>) group."""

    def __init__(self, key):
        self.key = key

    def __call__(self, sortKey, records, context):
        yield {
            "date":        sortKey["date"],
            "bounce_type": sortKey["bounce_type"],
            self.key:      sortKey[self.key],
            "total":       sum(1 for _ in records)
        }


@ytw.with_context
class BouncesUltimateReducer:
    """Sum ``total`` and count rows within a reduce group.

    When ``<key>`` is part of the reduce key, the result row carries that key
    value and the columns ``total`` / ``uniq``.  Otherwise (aggregation over
    all key values of a date and bounce type) the columns are named
    ``<key>_total`` / ``<key>_uniq``.
    """

    def __init__(self, key):
        self.key = key

    def __call__(self, sortKey, records, context):
        per_key = self.key in sortKey
        if per_key:
            key_total, key_uniq = "total", "uniq"
        else:
            key_total, key_uniq = "{}_total".format(self.key), "{}_uniq".format(self.key)
        counters = {"date": sortKey["date"], "bounce_type": sortKey["bounce_type"], key_total: 0, key_uniq: 0}
        if per_key:
            counters[self.key] = sortKey[self.key]
        for record in records:
            # When aggregating over all key values, leave out the synthetic
            # "TOTAL" row produced by BouncesDomainsReducer: it duplicates the
            # per-value rows and would double the sum and inflate the count.
            if not per_key and record.get(self.key) == "TOTAL":
                continue
            counters[key_total] += record["total"]
            counters[key_uniq] += 1
        yield counters


class CalcUltimately(Thread):
    """Daemon thread that aggregates the bounces table by one key.

    Calling the instance stores the parameters, derives the table names from
    ``inputTable`` and ``key`` and starts the thread.  The thread leaves two
    tables behind: ``<inputTable>_<key>_tmp`` (per key value) and
    ``<inputTable>_<key>`` (per date and bounce type).
    """

    def __init__(self):
        Thread.__init__(self)

    def __call__(self, key, scale, inputTable):
        self.daemon = True
        self.key = key
        self.scale = scale
        self.inputTable = inputTable
        self.intermidiateTable1 = "{}_{}_auxiliary_tmp".format(inputTable, key)
        self.intermidiateTable = "{}_{}_tmp".format(inputTable, key)
        self.outputTable = "{}_{}".format(inputTable, key)
        self.start()

    @staticmethod
    def _map_reduce_spec():
        # Operation spec used by both map-reduce variants; a fresh dict is
        # built for every call.
        return {
            "max_data_size_per_job": 6442450944,
            "reducer": {"data_size_per_sort_job": 671088640, "memory_limit": 4294967296},
            "owners": ['robot-mailspam']
        }

    def run(self):
        self.ytClient = ytw.YtClient(proxy="hahn")
        mapper = BouncesMapper(self.scale, self.key)
        if self.key.endswith("domain"):
            # Domain keys: count per (uid, date, bounce_type, key) first, then
            # fold the uid dimension away in a second reduce.
            per_key = ["date", "bounce_type", self.key]
            self.ytClient.run_map_reduce(mapper, BouncesDomainsReducer(self.key), self.inputTable, self.intermidiateTable1,
                                         reduce_by=["uid"] + per_key, spec=self._map_reduce_spec())
            self.ytClient.run_sort(self.intermidiateTable1, sort_by=per_key)
            self.ytClient.run_reduce(BouncesUltimateReducer(self.key), self.intermidiateTable1, self.intermidiateTable,
                                     reduce_by=per_key)
            # Use the thread's own client here too, like every other call in
            # this method.
            self.ytClient.remove(self.intermidiateTable1, force=True)
        else:
            self.ytClient.run_map_reduce(mapper, BouncesReducer(self.key), self.inputTable, self.intermidiateTable,
                                         reduce_by=["date", "bounce_type", self.key], spec=self._map_reduce_spec())
        self.ytClient.run_sort(self.intermidiateTable, sort_by=["date", "bounce_type"])
        self.ytClient.run_reduce(BouncesUltimateReducer(self.key), self.intermidiateTable, self.outputTable,
                                 reduce_by=["date", "bounce_type"])


def sortTuplesWithTotal(x, y):
    """Compare two tuples by their first element, placing 'TOTAL' last.

    This is a cmp-style comparator: it returns a negative number, zero or a
    positive number when ``x`` sorts before, equal to or after ``y``.

    Args:
        x: Tuple whose first element is the sort name.
        y: Tuple whose first element is the sort name.
    """
    left, right = x[0], y[0]
    left_is_total, right_is_total = left == 'TOTAL', right == 'TOTAL'
    if left_is_total or right_is_total:
        # Handle 'TOTAL' on either side so the ordering is consistent:
        # f(a, b) must be the negation of f(b, a).
        return int(left_is_total) - int(right_is_total)
    # Same result as the Python 2 ``cmp`` builtin, without depending on it.
    return (left > right) - (left < right)


if __name__ == "__main__":
    # Script entry point.  Steps, in order:
    #   1. parse the command line and resolve the reference date;
    #   2. build the table of unique bounces on YT (for the daily scale the
    #      bounce message ids are also checked against the mail-out log);
    #   3. aggregate that table for every entry of KEYS in parallel threads;
    #   4. read the aggregates back and dump them into two JSON files.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--refdate',      type=str, help="Reference date for calculating statistics")
    parser.add_argument('-p', '--period',       type=str, help="Days period for recalculating data for YaStat")
    parser.add_argument('-s', '--scale',        type=str, help="Scale of output data ('m', 'h' or 'd')")
    parser.add_argument('-t', '--total_stats',  type=str, help="Path to output JSON with general statistics to upload to YaStat")
    parser.add_argument('-o', '--domain_stats', type=str, help="Path to output JSON with domains statistics to upload to YaStat")
    args = parser.parse_known_args()[0]
    TOTALS_JSON_FILE = args.total_stats if args.total_stats else './total_stats.json'
    DOMAINS_JSON_FILE = args.domain_stats if args.domain_stats else './domain_stats.json'
    DAYS = int(args.period) if args.period else 1
    top_date = args.refdate.split()[0] if args.refdate else date.today().isoformat()
    try:
        topd = datetime.strptime(top_date, '%Y-%m-%d').date()
    except Exception as e:
        log("Error while parsing date string '%s': %s" % (top_date, str(e)), True)
        sys.exit(1)
    ytw.config["read_parallel"]["enable"] = True

    # Names of the temporary tables; all of them share the bouncesTable prefix.
    inputTables = []
    bouncesTable = "{}_{}_{}".format(YT_TMP_LOG_PATH, args.scale, top_date if args.scale == 'd' else args.refdate)
    ML_OUT_MSG_IDS_TABLE = "{}_mlout_msgids".format(bouncesTable)
    BOUNCES_MSG_IDS_TABLE = "{}_bounces_msgids".format(bouncesTable)
    MSG_IDS_TABLE = "{}_msgids".format(bouncesTable)
    NO_MSG_IDS_TABLE = "{}_no_msgids".format(bouncesTable)
    INVALID_BOUNCES_TABLE = "{}_invalid_bounces".format(bouncesTable)
    bouncesTable2 = "{}_tmp".format(bouncesTable)
    yesterday = (topd - timedelta(days=1))

    # Select the input log tables.
    if args.scale == 'd':
        # Daily tables of the DAYS days preceding the reference date.
        for i in range(DAYS, 0, -1):
            d = (topd - timedelta(days=i)).isoformat()
            inputTables.append("{}/{}".format(YT_LOG_DIR_1D, d))
    else:
        # 5-minute tables starting from midnight of the reference date, or of
        # the previous day when the local clock is before 02:00.
        d = yesterday if datetime.now().hour < 2 else topd
        ytMap, minDateTime = ytw.get(YT_LOG_DIR_5MIN, attributes=['key']), "{}T00:00:00".format(d.isoformat())
        for table in ytMap.keys():
            if ytMap[table].attributes['key'] >= minDateTime:
                inputTables.append("{}/{}".format(YT_LOG_DIR_5MIN, table))
    log("Scale: %s. Date: %s. Period: %s. Input tables: %s." % (args.scale, top_date, DAYS, str(inputTables)))

    # Build the table of unique bounces.
    if args.scale == 'd':
        # Mail-out log: yesterday's daily table plus the 30-minute tables
        # starting from midnight of the reference date.
        mlOutTables = ["{}/{}".format(YT_LOG_DIR_ML_OUT, yesterday.isoformat())]
        ytMap, minDateTime = ytw.get(YT_LOG_DIR_ML_OUT_30MIN, attributes=['key']), "{}T00:00:00".format(topd.isoformat())
        for table in ytMap.keys():
            if ytMap[table].attributes['key'] >= minDateTime:
                mlOutTables.append("{}/{}".format(YT_LOG_DIR_ML_OUT_30MIN, table))
        ytw.run_map(mlOutLogMapper, mlOutTables, ML_OUT_MSG_IDS_TABLE, job_count=1000)
        ytw.run_map_reduce(axisLogMapper, UniqBouncesReducer(args.scale), inputTables, [bouncesTable, BOUNCES_MSG_IDS_TABLE], reduce_by=["stid"])
        log("Total count of bounces: %s" % ytw.row_count(bouncesTable))
        # Find bounce message ids that are absent from the mail-out log ...
        ytw.run_sort([BOUNCES_MSG_IDS_TABLE, ML_OUT_MSG_IDS_TABLE], MSG_IDS_TABLE, sort_by=["message_id"])
        ytw.remove(BOUNCES_MSG_IDS_TABLE)
        ytw.remove(ML_OUT_MSG_IDS_TABLE)
        ytw.run_reduce(mlOutLogReducer, MSG_IDS_TABLE, NO_MSG_IDS_TABLE, reduce_by=["message_id"])
        ytw.remove(MSG_IDS_TABLE)
        # ... and use them to mark the affected bounces as invalid.
        ytw.run_sort([bouncesTable, NO_MSG_IDS_TABLE], bouncesTable2, sort_by=["stid"])
        ytw.remove(NO_MSG_IDS_TABLE)
        ytw.run_reduce(bouncesMsgIdsReducer, bouncesTable2, [bouncesTable, INVALID_BOUNCES_TABLE], reduce_by=["stid"])
        ytw.remove(bouncesTable2)
        log("Invalid msg ids count: %s. Total count of bounces2: %s." % (ytw.row_count(INVALID_BOUNCES_TABLE), ytw.row_count(bouncesTable)))
        ytw.remove(INVALID_BOUNCES_TABLE)
    else:
        ytw.run_map_reduce(axisLogMapper, UniqBouncesReducer(args.scale), inputTables, bouncesTable, reduce_by=["stid"])

    # Aggregate per key, one thread per entry of KEYS.
    threads = []
    for key in KEYS:
        t = CalcUltimately()
        t(key, args.scale, bouncesTable)
        threads.append(t)
    for t in threads:
        # Poll with a timeout instead of a bare join() so the main thread
        # wakes up regularly while waiting.
        while t.is_alive():
            t.join(1)
    ytw.remove(bouncesTable, force=True)

    # Read the aggregates back.  For scales other than 'd' the dates carry
    # minutes only, so ":00" is appended as the seconds part.
    date_suffix = "" if args.scale == 'd' else ":00"
    totals_data, domains_data, totals_stats = [], [], defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    domains_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(int)))))
    for key in KEYS:
        per_key_table = "{}_{}_tmp".format(bouncesTable, key)
        summary_table = "{}_{}".format(bouncesTable, key)
        domain_type = key if key.endswith("domain") else ""
        if domain_type:
            for row in ytw.read_table(per_key_table, format=ytw.JsonFormat(), raw=False):
                dt = row["date"] + date_suffix
                # An empty domain is reported as '_'.
                domain = row[key] if row[key] else '_'
                domain_stat = domains_stats[dt][row["bounce_type"]][domain_type][domain]
                domain_stat["total"] = row["total"]
                domain_stat["uniq"] = row["uniq"]
        ytw.remove(per_key_table, force=True)
        key_total, key_uniq = "{}_total".format(key), "{}_uniq".format(key)
        keys_total, keys_uniq = "{}_total".format(KEYS[key]), "{}_uniq".format(KEYS[key])
        for row in ytw.read_table(summary_table, format=ytw.JsonFormat(), raw=False):
            dt = row["date"] + date_suffix
            totals_stats[dt][row["bounce_type"]][keys_total] = row[key_total]
            totals_stats[dt][row["bounce_type"]][keys_uniq] = row[key_uniq]
        ytw.remove(summary_table, force=True)

    # Flatten the nested statistics into lists of rows.
    for (d, bounce_stat) in domains_stats.iteritems():
        for (bounce_type, domain_stat) in sorted(bounce_stat.iteritems(), cmp=sortTuplesWithTotal):
            for (domain_type, domain_type_stat) in domain_stat.iteritems():
                for (domain, stat) in sorted(domain_type_stat.iteritems(), cmp=sortTuplesWithTotal):
                    dataRow = {"fielddate": d, "bounce_type": bounce_type, "domain_type": domain_type, "domain": domain}
                    dataRow.update(stat)
                    domains_data.append(dataRow)
    for (d, bounce_stat) in totals_stats.iteritems():
        for (bounce_type, stat) in sorted(bounce_stat.iteritems(), cmp=sortTuplesWithTotal):
            dataRow = {"fielddate": d, "bounce_type": bounce_type}
            dataRow.update(stat)
            totals_data.append(dataRow)

    # Write the JSON files.  A failure is logged and does not stop the other
    # file from being written; ``with`` closes the file even on error.
    outputs = (
        ('Totals', TOTALS_JSON_FILE, totals_data),
        ('Domains', DOMAINS_JSON_FILE, domains_data),
    )
    for label, path, payload in outputs:
        try:
            with open(path, 'wt') as f:
                f.write(json.dumps(payload) + "\n")
        except Exception as e:
            log('%s data file saving error: %s.' % (label, str(e)), True)
