# coding=utf-8
"""
This script computes connectivity statistics (per-node edge counts and topohash counts) for soup edge tables stored in YT.
"""

import collections
import json
import os
import sys

import argparse
import yt.wrapper as yt

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))

import rtcconf.config as config

from utils import mr_utils

from v2.soup import soup_config
from v2.soup import soup_edge_type


# Step 1: take every attribute name reported by dir(soup_config) and look its
# value up in the module dict. Values that are not already a list/tuple/set
# are wrapped in a one-element list so that everything can be flattened below.
SOURCE_COMBINATIONS = [
    soup_config.__dict__.get(name)
    if isinstance(soup_config.__dict__.get(name), (list, tuple, set))
    else [soup_config.__dict__.get(name)]
    for name in dir(soup_config)
]

# Step 2: flatten the nested collections and keep only EdgeType instances.
# NOTE(review): on Python 2 (which this file targets, see `iteritems` /
# `basestring` below) `filter` returns a list; on Python 3 it would be a
# one-shot iterator.
SOURCE_COMBINATIONS = filter(
    lambda x: isinstance(x, soup_edge_type.EdgeType),
    [item for l in SOURCE_COMBINATIONS for item in l]
)

# Maximum number of id2 values stored in the "id2Capped" column of a
# "too many relations" row.
MAX_IDS_PER_ROW = 128
# Records-per-key threshold used by the edges reducer; also part of the
# "more_than_<N>_relations_keys" temp table name.
MAX_RELATIONS_PER_KEY = 256

# Adds the parent of this file's directory to the import path
# (the grandparent directory was already appended before the project imports).
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

YT_CLUSTER = "hahn"
DEFAULT_YT_PROXY = "{}.yt.yandex.net".format(YT_CLUSTER)

TESTING_INFRASTRUCTURE = "testing"
PRODUCTION_INFRASTRUCTURE = "production"
DEFAULT_INFRASTRUCTURE = TESTING_INFRASTRUCTURE
# Default soup folder points at the testing infrastructure.
DEFAULT_SOURCE_FOLDER = "//home/crypta/{}/state/graph/v2/soup".format(TESTING_INFRASTRUCTURE)

# Values of the "direction" column written by the edges reducer.
DIRECTION_FROM = "from"
DIRECTION_TO = "to"

# Value of the Cypress "@type" attribute used to select tables in the source folder.
TABLE_TYPE = "table"

# Record fields whose values are appended (as JSON) to "id2Type", so edge
# counts are kept separate per combination of these fields.
TOPOHASHES_PIVOT_FIELDS = ["sourceType", "logSource"]

# Lower bound for the reducer "memory_limit" in the edges map-reduce jobs (1 GiB).
MINIMUM_BIG_REDUCER_JOB_MEMORY_BYTES = 1 * 1024 * 1024 * 1024


def get_agg_edges_count(reverse=False, out_tables_ids=None, do_split=False):
    """Build a reducer that aggregates edges for one (type1, type2, id) key.

    The returned reducer emits rows into several output tables, selected via
    "@table_index" looked up in ``out_tables_ids``:

    * ``out_tables_ids["edges_count"]``: one row per distinct combination of
      TOPOHASHES_PIVOT_FIELDS values with columns
      ``id, id1Type, id2Type, direction, edgesCount``. ``id`` is the key id
      lower-cased with '-' removed; ``id2Type`` is the opposite type with the
      JSON-encoded pivot values appended.
    * ``out_tables_ids["too_many_relations"]``: one row for a key that has more
      than MAX_RELATIONS_PER_KEY records, holding up to MAX_IDS_PER_ROW of the
      opposite ids and their total distinct count.
    * ``out_tables_ids[(source, id1_type, id2_type, "one"|"many")]``: only when
      ``do_split`` is true; one row per (id pair, source). "one" means the pair
      was seen with exactly one sourceType, "many" with several. If the tuple
      is not registered, the tuple with swapped types is tried (and the row is
      written with swapped ids); otherwise the tuple is reported on stderr.

    :param reverse: when true the key side is id2/id2Type and the direction is
        DIRECTION_TO, otherwise id1/id1Type and DIRECTION_FROM.
    :param out_tables_ids: mapping from the table keys described above to
        output table indexes; required in practice although it defaults to None.
    :param do_split: whether to emit the per-source split rows.
    :return: reducer function ``agg_edges_count(keys, recs)`` (a generator).
    """
    id_value_key1 = "id2" if reverse else "id1"
    id_value_key2 = "id1" if reverse else "id2"
    id_type_key1 = "id2Type" if reverse else "id1Type"
    id_type_key2 = "id1Type" if reverse else "id2Type"
    direction = DIRECTION_TO if reverse else DIRECTION_FROM

    def agg_edges_count(keys, recs):
        value1 = keys[id_value_key1]
        type1 = keys[id_type_key1]
        type2 = keys[id_type_key2]

        # Keys with an empty id or type produce no output at all.
        if not (value1 and type1 and type2):
            return

        normalized_id = value1.lower().replace('-', '')  # normalize
        pivot_values_hashes = collections.Counter()
        id1_to_id2_sources_agg = collections.defaultdict(set)
        ids2 = set()
        too_many_relations = False

        for rec_idx, rec in enumerate(recs):
            value2 = rec[id_value_key2]
            ids2.add(value2)
            pivot_values_hashes[tuple(rec[pivot_field] for pivot_field in TOPOHASHES_PIVOT_FIELDS)] += 1
            # rec_idx is zero-based, so index MAX_RELATIONS_PER_KEY is already
            # the (MAX_RELATIONS_PER_KEY + 1)-th record, i.e. "more than" the
            # threshold (matches the "more_than_<N>_relations_keys" table).
            if rec_idx >= MAX_RELATIONS_PER_KEY:
                too_many_relations = True
            elif do_split:
                # Only the first MAX_RELATIONS_PER_KEY records take part in
                # the split, which bounds the size of this dict.
                id1_to_id2_sources_agg[(value1, value2)].add(rec["sourceType"])

        if too_many_relations:
            yield {
                "id1Type": type1,
                "id1": normalized_id,
                "id2Type": type2,
                "id2Capped": list(ids2)[:MAX_IDS_PER_ROW],
                "id2Count": len(ids2),
                "relationsCountThreshold": MAX_RELATIONS_PER_KEY,
                "@table_index": out_tables_ids["too_many_relations"]
            }

        for pivot_fields_values, count in pivot_values_hashes.items():
            yield {
                "id": normalized_id,
                "id1Type": type1,
                "id2Type": type2 + json.dumps(
                    dict(zip(TOPOHASHES_PIVOT_FIELDS, pivot_fields_values)),
                    sort_keys=True
                ),
                "direction": direction,
                "edgesCount": count,
                "@table_index": out_tables_ids["edges_count"]
            }

        if not do_split:
            return

        # Split rows are emitted once per key. They do not depend on the pivot
        # values, so they must not be repeated for every pivot combination.
        keys_not_defined_in_config = set()
        for (id_value1, id_value2), sources in id1_to_id2_sources_agg.items():
            cardinality = "one" if len(sources) == 1 else "many"
            for source in sources:
                direct_key = (source, type1, type2, cardinality)
                swapped_key = (source, type2, type1, cardinality)
                if direct_key in out_tables_ids:
                    yield {
                        "id1Type": type1,
                        "id1": id_value1,
                        "id2Type": type2,
                        "id2": id_value2,
                        "source": source,
                        "@table_index": out_tables_ids[direct_key]
                    }
                elif swapped_key in out_tables_ids:
                    # The combination is configured in the opposite
                    # orientation: write the edge with its ends swapped.
                    yield {
                        "id1Type": type2,
                        "id1": id_value2,
                        "id2Type": type1,
                        "id2": id_value1,
                        "source": source,
                        "@table_index": out_tables_ids[swapped_key]
                    }
                else:
                    # add() keeps the tuple intact; update() would scatter
                    # its four elements into the set individually.
                    keys_not_defined_in_config.add(direct_key)

        if keys_not_defined_in_config:
            sys.stderr.write('Bad supometr keys: {}\n'.format(sorted(keys_not_defined_in_config)))

    return agg_edges_count


def make_topohash(id_type, edges_count_from, edges_count_to):
    """Render a node signature of the form ``<id_type>:<part>|<part>|...``.

    Outgoing parts come first and look like ``<count>><id2type>``; incoming
    parts follow and look like ``<count><<id2type>``. Within each group the
    parts are ordered by id2type.

    :param id_type: type of the node itself.
    :param edges_count_from: mapping id2type -> number of outgoing edges.
    :param edges_count_to: mapping id2type -> number of incoming edges.
    :return: the topohash string.
    """
    parts = []
    for outgoing_type in sorted(edges_count_from.keys()):
        parts.append("{}>{}".format(edges_count_from[outgoing_type], outgoing_type))
    for incoming_type in sorted(edges_count_to.keys()):
        parts.append("{}<{}".format(edges_count_to[incoming_type], incoming_type))

    signature = "|".join(parts)
    return "{}:{}".format(id_type, signature)


def count_topohashes_for_nodes(keys, recs):
    """Reducer: fold all edge-count rows of one node into a single topohash row.

    Sums "edgesCount" per "id2Type" separately for DIRECTION_FROM and
    DIRECTION_TO rows (rows with any other direction are ignored) and yields
    one ``{"id", "topohash"}`` row to output table 0.
    """
    node_type = keys["id1Type"]
    outgoing = collections.Counter()
    incoming = collections.Counter()

    for row in recs:
        neighbour_type = row["id2Type"]
        row_direction = row["direction"]
        if row_direction == DIRECTION_FROM:
            outgoing[neighbour_type] += row["edgesCount"]
        elif row_direction == DIRECTION_TO:
            incoming[neighbour_type] += row["edgesCount"]

    result = {
        "id": keys["id"],
        "topohash": make_topohash(node_type, outgoing, incoming),
        "@table_index": 0
    }
    yield result


def reduce_topohashes(keys, recs):
    """Reducer: yield one row with the topohash and the number of records that share it."""
    topohash = keys["topohash"]
    total = 0
    for _ in recs:
        total += 1
    yield {"topohash": topohash, "count": total}


def merge_mapper(rec):
    """Mapper: pass the record through unchanged except for routing it to output table 0.

    The incoming record object itself is modified and yielded.
    """
    output_table_index = 0
    rec['@table_index'] = output_table_index
    yield rec


def format_h1(h):
    """Return ``h`` decorated as a first-level header, preceded by a newline."""
    template = "\n===== {} ====="
    return template.format(h)


def format_h2(h):
    """Return ``h`` decorated as a second-level header, preceded by a newline."""
    template = "\n----- {} -----"
    return template.format(h)


def format_list(l):
    """Join the given strings one per line, each indented by four spaces.

    A falsy argument (``None``, empty list) yields just the indent.
    """
    indent = "    "
    items = l if l else []
    separator = "\n" + indent
    return indent + separator.join(items)


def format_tables_list(tables_names, rows_counts=None, only_names=False):
    """Format table paths (optionally with their row counts) as an indented list.

    :param tables_names: a single table path or an iterable of table paths
        (any objects convertible with ``str``).
    :param rows_counts: optional mapping ``str(table) -> row count``; tables
        missing from it get an empty count column.
    :param only_names: print only the last path component instead of the
        full path.
    :return: text with one table per line, produced by ``format_list``.
    """
    # A single path may arrive as `str` or, on Python 2, as `unicode`;
    # `type(u"")` covers both interpreters without naming `unicode`.
    if isinstance(tables_names, (str, type(u""))):
        tables_names = [tables_names]
    else:
        # Materialize: the names are iterated twice (width computation and
        # formatting), which would exhaust a generator.
        tables_names = list(tables_names)

    # max() below raises ValueError on an empty sequence.
    if not tables_names:
        return format_list([])

    rows_counts = rows_counts or {}
    max_name_len = max(len(str(n)) for n in tables_names)

    if rows_counts:
        max_rows_count_len = max(
            len(str(c)) for c in rows_counts.values()
        )
    else:
        max_rows_count_len = 0

    def format_rows_counts(table_name):
        rc = rows_counts.get(str(table_name), None)
        if rc is None:
            return ""
        template = "{} records"
        return template.format(rc).rjust(max_rows_count_len + len(template))

    return format_list([
        "{}    {}".format(
            (str(t).split('/')[-1] if only_names else str(t)).ljust(max_name_len),
            format_rows_counts(t)
        )
        for t in tables_names
    ])


def ensure_table_path(tp):
    """Make sure the folder containing table path ``tp`` exists in YT.

    :param tp: table path; everything before the last '/' is treated as the
        folder.
    :return: ``tp`` unchanged, so the call can be used inline.
    """
    fp = '/'.join(tp.split('/')[:-1])
    if not yt.exists(fp):
        sys.stdout.write("Creating folder {} for table {}...\n".format(fp, tp))
        # recursive=True creates missing intermediate folders too; a plain
        # mkdir fails as soon as more than the last level is missing.
        yt.mkdir(fp, recursive=True)
        sys.stdout.write("...done\n")
    return tp


def ensure_folder_path(fp):
    """Make sure folder ``fp`` exists in YT, creating it if necessary.

    :param fp: folder path.
    :return: ``fp`` unchanged, so the call can be used inline.
    """
    if not yt.exists(fp):
        sys.stdout.write("Creating folder {}...\n".format(fp))
        # recursive=True creates missing parents as well (e.g. "<tmp>/split"
        # when "<tmp>" itself does not exist yet).
        yt.mkdir(fp, recursive=True)
        sys.stdout.write("...done\n")
    return fp


def norm_path(p):
    """Return ``p`` without surrounding whitespace and without trailing slashes."""
    trimmed = p.strip()
    return trimmed.rstrip('/')


def get_args():
    """Parse the command-line arguments of the script.

    The attribute names of the returned namespace match the keyword
    parameters of ``do_supometriya`` so that ``vars(args)`` can be passed to
    it directly.

    :return: ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser(prog='PROG')
    parser.add_argument(
        '-d', '--date',
        help="Date to process")
    parser.add_argument(
        '-f',
        '--filter-tables',
        type=str,
        nargs='+',
        default=[],
        help="Filter tables to process by given table names. "
             "Only names should be define, not the full tables paths. "
             "Otherwise paths will be truncated to the name.")
    parser.add_argument(
        '-T',
        '--skip-topohashes',
        action='store_true',
        default=False,
        help="Skip topohashes calculation stage")
    parser.add_argument(
        '-C',
        '--skip-cleanup',
        action='store_true',
        default=False,
        help="Skip cleanup stage")
    parser.add_argument(
        '--hide-rows-counts',
        action='store_true',
        default=False,
        # The flag hides the counts; the previous help text claimed the opposite.
        help="Hide rows counts for tables"
    )
    parser.add_argument(
        '--split-data',
        dest='split_data',
        action='store_true',
        default=True,
        help="Splits data not confirmed by other sources"
    )
    # --split-data is on by default, so a store_true flag alone can never
    # switch it off; this companion flag makes the option controllable.
    parser.add_argument(
        '--no-split-data',
        dest='split_data',
        action='store_false',
        help="Do not split data not confirmed by other sources"
    )
    parser.add_argument(
        '-t',
        '--target',
        help="Optional target table path"
    )
    parser.add_argument(
        '--temp-folder',
        help="Temporary folder path"
    )
    parser.add_argument(
        '-s',
        '--source',
        default=DEFAULT_SOURCE_FOLDER,
        help="Root path to the soup folder")

    parser.add_argument(
        '-y',
        '--yt-proxy',
        default=DEFAULT_YT_PROXY,
        help="YT Proxy URL")

    parser.add_argument(
        '-m',
        '--yt-max-memory-bytes',
        # The option name alone would produce `yt_max_memory_bytes`, which
        # do_supometriya() does not accept; map it to the real parameter name.
        dest='yt_job_max_memory_bytes',
        type=int,
        default=config.YT_JOB_MAX_MEMORY_BYTES,
        help="Maximum YT job memory in bytes (will be used for split reducers)")

    # processing args
    return parser.parse_args()


def do_supometriya(
        date,
        source=DEFAULT_SOURCE_FOLDER,
        target=None,
        temp_folder=None,
        filter_tables=None,
        skip_topohashes=False,
        skip_cleanup=False,
        hide_rows_counts=False,
        split_data=True,
        yt_proxy=None,
        yt_job_max_memory_bytes=MINIMUM_BIG_REDUCER_JOB_MEMORY_BYTES
):
    """Run the whole pipeline: aggregate edges, count topohashes, clean up.

    Stages:

    1. Edges reduce: two map-reduce jobs over the selected source tables, one
       keyed by the id1 side and one by the id2 side. Both append to the
       edges-count table and the "too many relations" table in the temp
       folder; the first one also fills the per-source split tables when
       ``split_data`` is true.
    2. Topohashes (unless ``skip_topohashes``): builds one topohash per node
       and writes ``topohash, count`` rows to the target table.
    3. Cleanup (unless ``skip_cleanup``): removes the temp folder.

    :param date: only used in the "no tables found" message.
    :param source: folder with the source tables.
    :param target: result table path; defaults to
        ``<source>/<source folder name>.supometriya``.
    :param temp_folder: temp folder path; defaults to
        ``<target table>.supometr_tmp``.
    :param filter_tables: table names (or paths, truncated to names) to
        process; empty means all tables in ``source``.
    :param skip_topohashes: skip stage 2.
    :param skip_cleanup: skip stage 3.
    :param hide_rows_counts: do not query and print source row counts.
    :param split_data: produce the per-source split tables.
    :param yt_proxy: YT proxy URL; the yt config is left untouched when empty.
    :param yt_job_max_memory_bytes: reducer memory limit (int or numeric
        string); never lower than MINIMUM_BIG_REDUCER_JOB_MEMORY_BYTES.
    """
    if isinstance(yt_job_max_memory_bytes, basestring):
        yt_job_max_memory_bytes = int(yt_job_max_memory_bytes)

    filter_tables = filter_tables or []
    # The newline goes after the header, not inside it.
    sys.stdout.write(format_h1("Initializing") + '\n')
    sys.stdout.write("Supometriya results will be separates by following fields:\n" +
                     (", ".join('"{}"'.format(d) for d in TOPOHASHES_PIVOT_FIELDS)) + "\n" +
                     "This means if we have k1->k2 rel with sourceType=x and k1->k2 rel with sourceType=y "
                     "we will get *not*:\n"
                     "|| sourceType | topohash || ...\n"
                     "||      ?     | k1:2>k2  || ...\n"
                     "||      ?     | k2:2<k1  || ...\n"
                     "But the separate records:\n"
                     "|| sourceType | topohash || ...\n"
                     "||      x     | k1:1>k2  || ...\n"
                     "||      x     | k2:1<k1  || ...\n"
                     "||      y     | k1:1>k2  || ...\n"
                     "||      y     | k2:1<k1  || ...\n"
                     "You have to aggregate it yourself (see Supometr Web UI code how to do this properly)\n")
    if yt_proxy:
        sys.stdout.write("Using YT proxy: {}\n".format(yt_proxy))
        yt.update_config({"proxy": {"url": yt_proxy}})

    source_folder = norm_path(source)
    target_table = (
        target or ensure_table_path(
            "{}/{}.supometriya".format(
                source_folder,
                source_folder.split("/")[-1]
            )
        )
    ).rstrip('/')
    target_folder = '/'.join(target_table.split('/')[:-1])
    sys.stdout.write("\nSource: {}\nTarget table: {}\n Creating folders...\n".format(source, target_table))
    mr_utils.mkdir(source)
    mr_utils.mkdir(target_folder)
    sys.stdout.write("...done!\n")

    do_topohashes = not skip_topohashes
    do_cleanup = not skip_cleanup

    filter_tables_names = [
        norm_path(t).split('/')[-1] if t else None for t in set(filter_tables)
    ]

    # Temp folder. A user-supplied folder is normalized and created as well,
    # otherwise the jobs below would write into a non-existent folder.
    tmp_folder = ensure_folder_path(
        norm_path(temp_folder) if temp_folder
        else "{}/{}.supometr_tmp".format(target_folder, target_table.split('/')[-1])
    )

    sys.stdout.write("Looking for tables in:\n    {}\n".format(source_folder))
    source_tables_all = [
        source_folder + '/' + f
        for f in yt.list(source_folder)
        if yt.get(source_folder + '/' + f + "/@type") == TABLE_TYPE
    ]

    rows_counts = None
    if not hide_rows_counts:
        rows_counts = {
            str(t): yt.get(t + '/@row_count')
            for t in source_tables_all
        }

    if source_tables_all:
        sys.stdout.write(
            "Found following possible sources tables:\n{}\n"
            "".format(format_tables_list(source_tables_all, rows_counts))
        )
    else:
        sys.stderr.write(
            "No appropriate sources tables found in:\n{}\n"
            "Try select another day instead of {}.\n"
            "".format(format_list([source_folder]), date)
        )
        sys.stderr.write("Exiting\n")
        # Exit code 0 is kept from the original behavior; sys.exit is used
        # because the site-provided `exit` is not guaranteed to exist.
        sys.exit(0)

    # noinspection PyTypeChecker
    source_tables_filtered = [
        f for f in source_tables_all
        if not filter_tables_names or f.split('/')[-1] in filter_tables_names
    ]

    if not source_tables_filtered:
        sys.stderr.write(
            "No appropriate tables found, try to change tables "
            "names filtering settings instead of:"
            "\n{}\n"
            "Possible names for the current date folder are:"
            "\n{}\n"
            "Exiting\n".format(
                format_tables_list(filter_tables_names),
                format_tables_list(source_tables_all, rows_counts, only_names=True)
            )
        )
        sys.exit(0)

    sys.stdout.write(format_h2("Reduce edges stage") + '\n')

    # id, id1Type, id2Type, edgesCount, direction
    edges_count_path = "/".join([tmp_folder, 'id_id1type_id2type_edges_counts_total.unsorted'])
    too_many_relations_path = "/".join(
        [tmp_folder, 'more_than_{}_relations_keys'.format(MAX_RELATIONS_PER_KEY)])

    # Both tables are written in append mode by two jobs, so leftovers from a
    # previous run (e.g. one executed with skip_cleanup) must be dropped first.
    for stale_path in (edges_count_path, too_many_relations_path):
        if yt.exists(stale_path):
            sys.stdout.write(
                "Removing existing table {}\n".format(format_list([stale_path])))
            yt.remove(stale_path, recursive=True, force=True)

    # Tables shared by both jobs; their positions must match the indexes below.
    aggregate_tables = [
        yt.TablePath(edges_count_path, append=True),
        yt.TablePath(too_many_relations_path, append=True)
    ]
    edges_agg_out_tables_ids = {
        "edges_count": 0,
        "too_many_relations": 1
    }
    edges_agg_target_tables = list(aggregate_tables)

    if split_data:
        tables_ids_offset = len(edges_agg_target_tables)
        ensure_folder_path("{}/split".format(tmp_folder))
        # One split table per (source, id1 type, id2 type, one/many)
        split_tables_params = sorted({
            (combination.source, combination.id1_type, combination.id2_type, onemany)
            for combination in SOURCE_COMBINATIONS
            for onemany in ("one", "many")
        })

        edges_agg_target_tables.extend(
            "{}/split/{}".format(tmp_folder, "-".join(params))
            for params in split_tables_params
        )

        edges_agg_out_tables_ids.update({
            params: idx + tables_ids_offset
            for idx, params
            in enumerate(split_tables_params)
        })

    sys.stdout.write(
        "Edges aggregation target tables:\n{}\n"
        "Processing edges...\n".format(
            format_tables_list(edges_agg_target_tables))
    )

    reducer_memory_limit = max(yt_job_max_memory_bytes, MINIMUM_BIG_REDUCER_JOB_MEMORY_BYTES)

    # Forward pass (keyed by the id1 side); the only one that fills split tables.
    yt.run_map_reduce(
        merge_mapper,
        get_agg_edges_count(out_tables_ids=edges_agg_out_tables_ids, reverse=False, do_split=split_data),
        source_tables_filtered,
        edges_agg_target_tables,
        sort_by=["id1Type", "id2Type", "id1"],
        reduce_by=["id1Type", "id2Type", "id1"],
        spec={
            "reducer": {
                "memory_limit": reducer_memory_limit
            }
        }
    )
    # Backward pass (keyed by the id2 side). It writes only to the two
    # append-mode tables; listing the split tables as outputs here would
    # overwrite them with this job's empty output.
    yt.run_map_reduce(
        merge_mapper,
        get_agg_edges_count(out_tables_ids=edges_agg_out_tables_ids, reverse=True, do_split=False),
        source_tables_filtered,
        aggregate_tables,
        sort_by=["id2Type", "id1Type", "id2"],
        reduce_by=["id2Type", "id1Type", "id2"],
        spec={
            "reducer": {
                "memory_limit": reducer_memory_limit // 2
            }
        }
    )
    sys.stdout.write("...done\n")
    sys.stdout.write("Reduce is done\n")

    # Topohashes
    sys.stdout.write(format_h2("Topohashes calculation stage") + '\n')

    # One side of types bijection is enough.
    if do_topohashes:
        # Edge counts sorted by node, input for the per-node topohash reducer
        edges_count_sorted_by_node = yt.TablePath(
            "{}/id1_id1type_id2type.by_id1_id1type".format(tmp_folder))
        tmp_id_topohash = yt.TablePath(
            "{}/id_topohash.by_topohash".format(tmp_folder))
        yt.run_sort(
            # "id" "id1Type" "id2Type" "edgesCount", direction
            edges_count_path,
            edges_count_sorted_by_node,
            sort_by=["id1Type", "id"]
        )
        # todo: make not temp but result table
        yt.run_reduce(
            count_topohashes_for_nodes,
            # "id" "id1Type" "id2Type" "edgesCount", direction
            edges_count_sorted_by_node,
            # id, topohash
            tmp_id_topohash,
            reduce_by=["id1Type", "id"],
            job_count=256
        )
        yt.run_sort(
            tmp_id_topohash,
            sort_by=["topohash"]
        )
        yt.run_reduce(
            reduce_topohashes,
            tmp_id_topohash,
            target_table,  # topohash, count
            reduce_by=["topohash"],
            job_count=256
        )
        # "count" exists only in the result table; the temp id/topohash table
        # has no such column, so sorting it by "count" was a wasted operation.
        yt.run_sort(target_table, sort_by=["count"])
        sys.stdout.write("...done\n")
    else:
        sys.stdout.write("Ignoring\n")

    sys.stdout.write(format_h2("Cleanup stage") + '\n')
    if do_cleanup:
        sys.stdout.write("Running...\n")
        sys.stdout.write("Removing:\n{}\n".format(format_tables_list(tmp_folder)))
        yt.remove(tmp_folder, recursive=True)
        sys.stdout.write("...done\n")
    else:
        sys.stdout.write("Ignoring cleanup\n")

    sys.stdout.write(
        "Everything is done right!\n"
        "Result:\n"
        "{}\n".format(format_tables_list(target_table))
    )


if __name__ == '__main__':
    # Parse the CLI options and forward them as keyword arguments.
    cli_options = vars(get_args())
    do_supometriya(**cli_options)
