import functools
import itertools
import logging
import re
from collections import defaultdict
from datetime import datetime, timedelta
from contextlib import contextmanager

import yt.wrapper as yt
from yt import yson

from crypta.graph.v1.python.utils import utils
from crypta.graph.v1.python.utils import yt_clients

# Default number of records the in-memory reduce helpers (take_till_oom,
# split_left_right) buffer before giving up / raising OomLimitException.
OOM_LIMIT = 100

# Operation spec fragments setting the amount of input data per job.
# 2GB per job produces fewer output chunks and gives better job utilization.
DATA_SIZE_PER_JOB_2GB_SPEC = {"data_size_per_job": 2048 * 1024 * 1024}
DATA_SIZE_PER_JOB_1GB_SPEC = {"data_size_per_job": 1024 * 1024 * 1024}
DATA_SIZE_PER_JOB_20MB_SPEC = {"data_size_per_job": 20 * 1024 * 1024}

# Raises the per-job input data cap to 2 TiB; presumably meant for join
# reduces with very heavy keys (per the name) — NOTE(review): confirm usage.
JOIN_REDUCE_HEAVY_JOBS_SPEC = {"max_data_size_per_job": 2 * 1024 * 1024 * 1024 * 1024}

# NOTE(review): not referenced in this part of the file; presumably a
# partitioning hint used by callers elsewhere.
ROWS_PER_PARTITION = 1000000

# Logger named after this file's path (not the dotted module name).
logger = logging.getLogger(__file__)


def get_field_value(field_name, value, separator="\t", equal_sign="="):
    """Extract the value of ``field_name`` from a yamr-style ``name=value`` string.

    ``value`` holds ``name<equal_sign>val`` pairs joined by ``separator``.
    A ``separator``-prefixed occurrence of the field is searched first; only
    when there is none is a match at the very start of the string accepted.

    :return: the field's value, or "" if the field is absent or empty.
    """
    prefix = field_name + equal_sign
    pos = value.find(separator + prefix)
    if pos != -1:
        start = pos + len(separator) + len(prefix)
    elif value.startswith(prefix):
        start = len(prefix)
    else:
        return ""

    if start >= len(value):
        return ""

    end = value.find(separator, start)
    return value[start:end] if end > 0 else value[start:]


#
# --- frequently used operations (sorting, aggregation, filtering)
#


def sort_all(tables, sort_by, spec=None, yt_client=None):
    """Sort every table in ``tables`` by ``sort_by`` (no destination is passed).

    All sorts are started asynchronously. The call then blocks until every one
    of them finishes, via ``utils.wait_all``.

    :param spec: operation spec shared by all sorts; defaults to ``{"combine_chunks": True}``
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client
    sort_spec = {"combine_chunks": True} if spec is None else spec

    operations = []
    for table in tables:
        operations.append(client.run_sort(table, sort_by=sort_by, sync=False, spec=sort_spec))
    utils.wait_all(operations)


def avg_column(in_table, out_table, column, sync=True, yt_client=None):
    """Compute the mean of ``column`` over all rows of ``in_table``.

    Writes a single row ``{"avg": <mean>}`` to ``out_table``.

    Each mapper job emits its partial sum and row count, and the reducer
    combines them. The result is therefore the true mean over all rows.
    Averaging per-job averages would instead weight every job equally,
    whatever its row count.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    @yt.aggregator
    def map_avg(recs):
        recs_count = 0
        column_sum = 0
        for rec in recs:
            recs_count += 1
            column_sum += rec[column]
        if recs_count:  # a job without rows contributes nothing (and must not divide by 0)
            yield {"stubkey": "1", "sum": column_sum, "count": recs_count}

    def reduce_avg(key, recs):
        recs_count = 0
        column_sum = 0
        for rec in recs:
            recs_count += rec["count"]
            column_sum += rec["sum"]
        yield {"avg": column_sum / float(recs_count)}

    yt_client.run_map_reduce(map_avg, reduce_avg, in_table, out_table, reduce_by="stubkey", sync=sync)


def avg_group_by(in_table, out_table, column, sync=True, filter_rec=None, group_by=None, yt_client=None):
    """Compute the mean of ``column`` in ``in_table``, optionally per ``group_by`` value.

    :param filter_rec: optional predicate; rows for which it returns a falsy value are skipped
    :param group_by: optional grouping column. Without it a single row ``{"avg": ...}`` is
        written; with it one row ``{"avg": ..., group_by: <value>}`` is written per group.

    Mappers emit per-group partial sums and counts, and the reducer combines
    them. Each result is therefore the true mean of the group's rows.
    Averaging per-job averages would instead weight jobs equally, whatever
    their size.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    @yt.aggregator
    def map_avg(recs):
        recs_count_in_group = defaultdict(int)
        column_sum_in_group = defaultdict(int)
        for rec in recs:
            if filter_rec and not filter_rec(rec):
                continue
            reduce_key = rec[group_by] if group_by else "stub_key"
            recs_count_in_group[reduce_key] += 1
            column_sum_in_group[reduce_key] += rec[column]

        for reduce_key, recs_count in recs_count_in_group.items():
            yield {"stub_col": reduce_key, "sum": column_sum_in_group[reduce_key], "count": recs_count}

    def reduce_avg(key, recs):
        recs_count = 0
        column_sum = 0
        for rec in recs:
            recs_count += rec["count"]
            column_sum += rec["sum"]

        out_rec = {"avg": column_sum / float(recs_count)}
        if group_by:
            out_rec[group_by] = key["stub_col"]
        yield out_rec

    yt_client.run_map_reduce(map_avg, reduce_avg, in_table, out_table, reduce_by="stub_col", sync=sync)


def count_field_recs(
    in_table,
    out_table,
    columns,
    count_column="count",
    yamr_format=False,
    expect_large_keys=True,
    add_desc_count=False,
    sync=True,
    yt_client=None,
):
    """
    Count the rows of every distinct combination of ``columns`` values.

    Output rows contain the ``columns`` values plus ``count_column``.

    :param columns: column name or list of names (normalised with ``utils.flatten``)
    :param yamr_format: if true, the columns are read as ``name=value`` fields of the
        tab-separated ``value`` column (yamr records) instead of as real columns
    :param expect_large_keys:
        if true, each mapper pre-aggregates counts in memory (optimized for a few large column values);
        if false, every row goes to the reducer with count 1 (optimized for a lot of small column values)
    :param add_desc_count: yt can't sort descending. If True, a ``desc`` column equal to
        -count is added to the result so it can be sorted by it
    :return: the ``run_map_reduce`` result
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    columns = utils.flatten(columns)

    def extract_key(rec):
        # Grouping values of ``rec``, ordered like ``columns``; shared by both mappers
        # so ``yamr_format`` is honoured regardless of ``expect_large_keys``.
        if yamr_format:
            return tuple(get_field_value(c, rec["value"]) for c in columns)
        return tuple(rec.get(c) for c in columns)

    def map_add_count_many_small_keys(rec):
        # Only the key columns are needed by the reducer; don't shuffle the whole row.
        out_rec = dict(zip(columns, extract_key(rec)))
        out_rec[count_column] = 1
        out_rec["@table_index"] = 0
        yield out_rec

    @yt.aggregator
    def map_add_count_several_large_keys(recs):
        counts = defaultdict(int)
        for rec in recs:
            counts[extract_key(rec)] += 1

        for aggr_key, c in counts.items():
            out_rec = dict(zip(columns, aggr_key))
            out_rec[count_column] = c
            yield out_rec

    def reduce_sum_count(key, recs):
        c = 0
        for rec in recs:
            c += rec[count_column]

        out_rec = {col: key[col] for col in columns}
        out_rec[count_column] = c

        if add_desc_count:
            out_rec["desc"] = -c

        yield out_rec

    add_count_mapper = map_add_count_several_large_keys if expect_large_keys else map_add_count_many_small_keys

    return yt_client.run_map_reduce(
        add_count_mapper, reduce_sum_count, in_table, out_table, reduce_by=columns, sync=sync
    )


def t2_filterby_t1(in_table1, in_table2, out_table, key_field, yt_client=None):
    """Write to ``out_table`` the rows of ``in_table2`` whose ``key_field`` also occurs in ``in_table1``.

    Both inputs are first sorted by ``key_field`` (via ``sort_all``). Rows of
    ``in_table1`` are never written. Within a key group, the reducer expects
    table-0 rows to arrive before table-1 rows. NOTE(review): confirm that YT
    guarantees this ordering for plain reduce.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    sort_all([in_table1, in_table2], key_field, yt_client=yt_client)

    def reduce_filter(key, recs):
        seen_in_first = False
        for rec in recs:
            table_index = rec["@table_index"]
            if table_index == 0:
                seen_in_first = True
                continue
            if table_index != 1 or not seen_in_first:
                # table-1 row with no preceding table-0 row (or an unknown index): stop
                return
            rec["@table_index"] = 0
            yield rec

    yt_client.run_reduce(reduce_filter, [in_table1, in_table2], out_table, reduce_by=key_field)


# ------ additional methods ------


def mkdir(folder, recursive=True, yt_client=None):
    """Create the directory ``folder``; one trailing "/" is ignored.

    :param recursive: passed through to the client's ``mkdir``
    """
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()
    path = folder[:-1] if folder.endswith("/") else folder
    client.mkdir(path, recursive=recursive)


def row_count(table, yt_client=None):
    """Return the number of rows of ``table``, or 0 if it cannot be determined.

    Reads the ``row_count`` attribute first and falls back to ``chunk_row_count``.
    A missing node, or a node where neither attribute can be read, yields 0.
    Read failures are logged at debug level rather than silently discarded.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    if not yt_client.exists(table):
        return 0

    for attribute in ("row_count", "chunk_row_count"):
        path = "{path}/@{attr}".format(path=table, attr=attribute)
        try:
            # ``get`` takes no default value (its 2nd positional parameter is
            # ``max_size``); a missing attribute is handled by the except below.
            return yt_client.get(path)
        except Exception:
            logger.debug("Failed to read %s", path, exc_info=True)
    return 0


def set_attribute(path, attribute, value, yt_client=None):
    """Set attribute ``attribute`` of node ``path`` to ``value``.

    One trailing "/" on ``path`` is ignored. If the node does not exist,
    nothing is changed and a warning is logged.
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client
    node = path[:-1] if path.endswith("/") else path

    if not exists(node, yt_client=client):
        logger.warning("Failed to set attr %s=%s for %s. The latter does not exist", attribute, value, node)
        return
    client.set_attribute(node, attribute, value)


def get_generate_date(table, yt_client=None):
    """Return ``table``'s ``generate_date`` attribute, or "0000-00-00" when it is not set.

    Fetches the node's whole attribute map (``table + "/@"``) and looks the key up there.
    """
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()
    attributes = client.get(table + "/@")
    return attributes.get("generate_date", "0000-00-00")


def set_generate_date(table, date, yt_client=None):
    """Store ``date`` in ``table``'s ``generate_date`` attribute via ``set_attribute``."""
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()
    set_attribute(table, "generate_date", date, yt_client=client)


def drop(prefix, yt_client=None):
    """Recursively remove node ``prefix`` if it exists; a missing node is not an error.

    One trailing "/" on ``prefix`` is ignored.
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client
    path = prefix[:-1] if prefix.endswith("/") else prefix
    if not exists(path, yt_client=client):
        return
    client.remove(path, recursive=True)


def ls(prefix, absolute_path=True, yt_client=None):
    """List the children of node ``prefix``; one trailing "/" is ignored.

    :param absolute_path: if true, return full paths ``<prefix>/<name>``;
        otherwise return the child names exactly as the client's ``list`` does.
    """
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()
    folder = prefix[:-1] if prefix.endswith("/") else prefix

    names = client.list(folder)
    if not absolute_path:
        return names
    # TODO: yt.list has an ``absolute`` param, no need to build paths manually
    return [folder + "/" + name for name in names]


def exists(src, yt_client=None):
    """Return the client's answer to whether node ``src`` exists."""
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()
    return client.exists(src)


def copy(src, dst, yt_client=None):
    """Copy node ``src`` to ``dst``, first removing ``dst`` (recursively) if it exists.

    One trailing "/" is stripped from each path.
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client
    src, dst = [p[:-1] if p.endswith("/") else p for p in (src, dst)]

    if exists(dst, yt_client=client):
        drop(dst, client)
    client.copy(src, dst)


def merge(table, sync=True, yt_client=None):
    """Merge ``table`` into itself with ``combine_chunks`` enabled.

    :return: the ``run_merge`` result, or None when ``table`` does not exist
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client
    if not exists(table, yt_client=client):
        return None
    return client.run_merge(table, table, spec={"combine_chunks": True}, sync=sync)


def merge_chunks_all(tables, yt_client=None):
    """Merge every table in ``tables`` into itself with ``combine_chunks``.

    The merges are started asynchronously, and the call blocks until all of
    them finish (``utils.wait_all``).
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client
    pending = [client.run_merge(table, table, sync=False, spec={"combine_chunks": True}) for table in tables]
    utils.wait_all(pending)


def filter_left_by_right(key, recs, keep_not_found=False, right_columns_to_join=None):
    """Two-table reduce helper: keep left rows whose key group also has right rows.

    :param key: reduce key (unused)
    :param recs: rows of one key group. ``@table_index`` 0 marks the left table;
        any other index marks the right table.
    :param keep_not_found: if true, left rows without a right match are emitted
        with ``@table_index`` 1 instead of being dropped
    :param right_columns_to_join: columns copied from the right rows onto every
        matched left row. With several right rows, the last one wins per column.
    :yield: matched left rows with ``@table_index`` 0 (and unmatched ones with
        ``@table_index`` 1 when ``keep_not_found``)

    All left rows of the group are buffered in memory.
    """
    # None instead of a shared mutable ``[]`` default.
    columns_to_join = right_columns_to_join or ()

    right_found = False
    left_recs = []
    values_to_join = {}
    for rec in recs:
        if rec["@table_index"] == 0:
            left_recs.append(rec)
        else:
            right_found = True
            for col in columns_to_join:
                values_to_join[col] = rec[col]

    for rec in left_recs:
        if right_found:
            rec["@table_index"] = 0
            rec.update(values_to_join)
            yield rec
        elif keep_not_found:
            rec["@table_index"] = 1
            yield rec


class OomLimitException(Exception):
    """Raised by reduce helpers when a key group is too large to buffer in memory.

    :ivar recs_count: record count reported by the raiser for the offending group
    """

    def __init__(self, recs_count):
        self.recs_count = recs_count
        # Passing recs_count on to Exception stores it in ``args``. This makes
        # str(exc) informative, and lets pickle rebuild the exception as
        # OomLimitException(recs_count). With empty args, unpickling calls
        # OomLimitException() and fails.
        super(OomLimitException, self).__init__(recs_count)


def count_rest(recs):
    """Exhaust ``recs`` and return the number of items it produced."""
    return sum(1 for _ in recs)


def take_till_oom(recs, oom_limit=OOM_LIMIT):
    """Split ``recs`` into its first ``oom_limit`` items and a count of what remains.

    ``recs`` is expected to be an iterator, so the remainder is counted from
    where the slice stopped. For a list, the count covers the whole list again.

    :return: ``(head_list, remaining_count)``
    """
    head = list(itertools.islice(recs, oom_limit))
    remaining = 0
    for _ in recs:
        remaining += 1
    return head, remaining


def get_singe_rec(recs, table_index):
    """Return the only record of ``recs`` whose ``@table_index`` equals ``table_index``.

    :return: that record, or None when there is none
    :raises Exception: if more than one record matches
    """
    matches = [rec for rec in recs if rec["@table_index"] == table_index]
    if not matches:
        return None
    if len(matches) > 1:
        raise Exception("Should be single rec")
    return matches[0]


def split_left_right(recs, oom_check=True, oom_limit=OOM_LIMIT):
    """
    Helper for two-table reduce methods: split one key group by input table.

    :param recs: all recs of a two-table reduce key group (expected to be an iterator)
    :param oom_check: if true, raise OomLimitException once either side already
        holds more than ``oom_limit`` records
    :param oom_limit: per-side buffering limit
    :return: (records of table 0, records of table 1). A falsy ``@table_index``
        (0 or None) counts as table 0.
    :raises OomLimitException: ``recs_count`` is the total size of the group:
        buffered records + the current one + the unread rest
    :raises ValueError: on a table index other than 0/1
    """
    left = []
    right = []
    for r in recs:
        if oom_check and (len(left) > oom_limit or len(right) > oom_limit):
            # Count everything already buffered on both sides, the current record and the unread rest.
            raise OomLimitException(len(left) + len(right) + 1 + count_rest(recs))

        table_index = r["@table_index"]
        if not table_index:
            left.append(r)
        elif table_index == 1:
            right.append(r)
        else:
            raise ValueError("Unsupported table index: %s" % table_index)
    return left, right


def join_left_right(key, recs, l_cols, r_cols):
    """
    Reduce two tables into a single row per key. Useful for comparing the same table of different dates.

    The output row contains:
     - all reduce key columns
     - ``l_cols`` of the first left-table row, with suffix ``_1``
     - ``r_cols`` of the first right-table row, with suffix ``_2``
    Columns of a side that has no rows for this key are omitted.
    Oversized groups make ``split_left_right`` raise OomLimitException.
    """
    left, right = split_left_right(recs)
    out = dict(key)
    for rows, cols, suffix in ((left, l_cols, "_1"), (right, r_cols, "_2")):
        if not rows:
            continue
        first = rows[0]
        for col in cols:
            out[col + suffix] = first[col]
    yield out


def map_column_to_key(rec, column):
    """Mapper: copy ``rec[column]`` into ``rec["key"]``, route to output 0 and yield the (mutated) row."""
    rec.update({"key": rec[column], "@table_index": 0})
    yield rec


def map_key_to_column(rec, column):
    """Mapper: copy ``rec["key"]`` into ``rec[column]``, route to output 0 and yield the (mutated) row."""
    rec.update({column: rec["key"], "@table_index": 0})
    yield rec


def generate_dates_before(date, days):
    """Return ``days`` consecutive "%Y-%m-%d" dates ending at ``date``, newest first.

    ``date`` itself is the first element. The result is empty for ``days <= 0``.
    """
    fmt = "%Y-%m-%d"
    start = datetime.strptime(date, fmt)
    dates = []
    for offset in range(days):
        dates.append((start - timedelta(days=offset)).strftime(fmt))
    return dates


def list_dates_before(folder, dt, days, yt_client=None):
    """Return the latest ``days`` date-named children of ``folder`` that are <= ``dt``.

    Children are matched by a leading "YYYY-MM-DD" and compared with ``dt`` as
    strings. The result is sorted ascending. ``days <= 0`` yields [], because
    the slice ``dates[-0:]`` would otherwise return every date.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    if folder.endswith("/"):
        folder = folder[:-1]

    if days <= 0:
        return []

    dt_pattern = re.compile(r"(\d{4}-\d{2}-\d{2})")
    dates = sorted(x for x in yt_client.list(folder) if dt_pattern.match(x) and x <= dt)

    return dates[-days:]


def get_date_table(folder, date, table):
    """Build the path "<folder>/<date>/<table>"; one trailing "/" on ``folder`` is ignored."""
    base = folder[:-1] if folder.endswith("/") else folder
    return "/".join((base, date, table))


def get_date_tables(folder, table, ndates, before_date=None, yt_client=None):
    """Return the paths of ``table`` in the latest ``ndates`` date folders of ``folder``.

    :param table: table name under each date folder. If empty, the date nodes
        themselves are returned (for folders whose children are date-named tables).
    :param ndates: how many of the latest dates to use. ``<= 0`` yields [],
        because ``[-0:]`` would otherwise select every date.
    :param before_date: if given, only dates <= ``before_date`` (string comparison) are used
    :return: list of paths, ascending by date. Always a list, so it can be
        iterated repeatedly, unlike a py3 ``map`` object.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    if folder.endswith("/"):
        folder = folder[:-1]
    if ndates <= 0:
        return []

    dt_pattern = re.compile(r"(\d{4}-\d{2}-\d{2})")
    dates = [x for x in yt_client.list(folder) if dt_pattern.match(x)]
    if before_date:
        dates = [x for x in dates if x <= before_date]
    dates = sorted(dates)[-ndates:]
    table_postfix = "/" + table if table else ""  # if table has date format itself
    return [folder + "/" + dt + table_postfix for dt in dates]


def get_existing_date_tables(folder, table, ndates, before_date=None, yt_client=None):
    """Like ``get_date_tables``, but keep only the paths that currently exist.

    :return: list of existing paths, ascending by date. It is a list rather than
        a lazy ``filter`` object, so it behaves the same on Python 2 and 3.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    tables = get_date_tables(folder, table, ndates, before_date, yt_client=yt_client)
    return [t for t in tables if exists(t, yt_client=yt_client)]


def get_prev_table(folder, date, table, raise_exc=True, yt_client=None):
    """Return the path of ``table`` in the latest existing date folder strictly before ``date``.

    Date folders are children of ``folder`` that start with "YYYY-MM-DD". They
    are compared with ``date`` as strings and tried newest first until an
    existing path is found.

    :param table: table name inside the date folder; if empty/None, the date node itself is returned
    :param raise_exc: raise when nothing is found; otherwise return ""
    :raises Exception: if no existing path is found and ``raise_exc`` is true
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    dt_pattern = re.compile(r"(\d{4}-\d{2}-\d{2})")
    if folder.endswith("/"):
        folder = folder[:-1]
    prev_dates = sorted(x for x in yt_client.list(folder) if dt_pattern.match(x) and x < date)
    for prev_date in reversed(prev_dates):
        path = folder + "/" + prev_date
        if table:
            path += "/" + table
        if yt_client.exists(path):
            return path

    if raise_exc:
        # %s formatting: string concatenation raised TypeError when table is None
        raise Exception("No prev table found in folder %s,%s" % (folder, table))
    return ""


def get_last_table(folder, table="", raise_exc=True, check_size=False, yt_client=None):
    """Return the path of ``table`` in the latest existing date folder of ``folder``.

    Date folders are children that start with "YYYY-MM-DD". They are tried
    newest first.

    :param table: table name inside the date folder; if empty/None, the date node itself is returned
    :param check_size: also require ``row_count(path) > 0``
    :param raise_exc: raise when nothing is found; otherwise return None
    :raises Exception: if nothing is found and ``raise_exc`` is true
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    dt_pattern = re.compile(r"(\d{4}-\d{2}-\d{2})")
    if folder.endswith("/"):
        folder = folder[:-1]
    dates = sorted(x for x in yt_client.list(folder) if dt_pattern.match(x))
    for date in reversed(dates):
        path = folder + "/" + date
        if table:
            path += "/" + table
        if not yt_client.exists(path):
            continue
        if not check_size or row_count(path, yt_client=yt_client) > 0:
            return path

    if raise_exc:
        # %s formatting: string concatenation raised TypeError when table is None
        raise Exception("No last table found in folder %s,%s" % (folder, table))
    return None


def distinct_by(distinct_key, src_tables, dst_table, sync=True, additional_fields=None, yt_client=None):
    """Write one row per distinct ``distinct_key`` value across ``src_tables`` to ``dst_table``.

    :param distinct_key: key column(s), normalised with ``utils.flatten``
    :param src_tables: input table(s), normalised with ``utils.flatten``
    :param additional_fields: optional list of ``(in_field, out_field, is_list, func)``
        tuples aggregated by ``distinct_mr``
    :return: the ``run_map_reduce`` result
    """
    client = yt_clients.get_yt_client() if yt_client is None else yt_client

    keys = utils.flatten(distinct_key)
    tables = utils.flatten(src_tables)
    extra_fields = additional_fields if additional_fields else []

    # A mapper-less map_reduce is used instead of a manual sort followed by a
    # reduce, so the whole operation can run asynchronously (sync=False).
    # TODO: drop this note once map_reduce performance is confirmed sufficient.
    reducer = functools.partial(distinct_mr, distinct_keys=keys, additional_fields=extra_fields)
    return client.run_map_reduce(None, reducer, tables, dst_table, reduce_by=keys, sync=sync)


def distinct_mr(key, recs, distinct_keys, additional_fields=None):
    """Reducer for ``distinct_by``: one row per key, optionally with aggregated extra fields.

    :param distinct_keys: key columns copied from ``key`` into the output row
    :param additional_fields: sequence of ``(in_field, out_field, is_list, func)``.
        For every row, ``in_field`` is collected into the list ``out_field``:
        extended by the value when ``is_list``, otherwise appended. A missing
        field contributes [] or None respectively. Afterwards ``func``, if
        truthy, is applied to the collected list. ``recs`` is only read when
        this argument is non-empty.
    """
    row = dict((name, key[name]) for name in distinct_keys)
    if not additional_fields:
        yield row
        return

    for _, out_field, _, _ in additional_fields:
        row[out_field] = []
    for rec in recs:
        for in_field, out_field, is_list, _ in additional_fields:
            collected = rec.get(in_field, []) if is_list else [rec.get(in_field, None)]
            row[out_field] += collected
    for _, out_field, _, func in additional_fields:
        if func:
            row[out_field] = func(row[out_field])
    yield row


def count_by_column(key, recs, column, count_column="count", add_desc_count=False):
    """Reduce helper: count the rows of a key group.

    Yields ``{column: key[column], count_column: n}``. With ``add_desc_count``,
    ``"desc": -n`` is added as well, since yt can't sort descending.
    """
    n = sum(1 for _ in recs)
    out = {column: key[column], count_column: n}
    if add_desc_count:
        out["desc"] = -n
    yield out


def count_by_columns(key, recs, columns):
    """Reduce helper: yield ``{"count": <rows in group>}`` plus ``key[c]`` for each ``c`` in ``columns``.

    Key columns are written after "count", so a key column named "count" overrides it.
    """
    result = {"count": sum(1 for _ in recs)}
    result.update((c, key[c]) for c in columns)
    yield result


def sum_column(key, recs, sum_column):
    """Reduce helper: yield all key columns plus the total of ``sum_column`` over the group."""
    total = sum(rec[sum_column] for rec in recs)
    result = dict(key)
    result[sum_column] = total
    yield result


def sum_and_count_unique_by_column(key, recs, group_by, sum_column="hits", count_column="count", id_column="key"):
    """Reduce helper: sum ``sum_column`` and count distinct ``id_column`` values in a key group.

    Assumes ``recs`` is sorted by ``id_column`` within the group: a new id is
    counted whenever it differs from the previous row's id. Rows without
    ``sum_column`` contribute 0.

    :param group_by: key columns copied from ``key`` into the output row
    :yield: ``{sum_column: total, count_column: unique ids, <group_by columns>}``
    """
    # A private sentinel (instead of None) ensures the first row is always
    # counted, even if its id happens to be None.
    no_id = object()
    total = 0
    unique_count = 0
    last_id = no_id

    for rec in recs:
        total += rec[sum_column] if sum_column in rec else 0
        rec_id = rec[id_column]
        if last_id is no_id or rec_id != last_id:
            last_id = rec_id
            unique_count += 1

    sum_rec = {sum_column: total, count_column: unique_count}
    for group_col in group_by:
        sum_rec[group_col] = key[group_col]

    yield sum_rec


def sum_and_count_by_column(key, recs, group_by, sum_column="hits", count_column="count"):
    """Reduce helper: total ``sum_column`` and ``count_column`` over a key group.

    A row without ``sum_column`` adds 0, and a row without ``count_column``
    adds 1. Presumably this lets the helper combine both raw rows and its own
    partial outputs.

    :param group_by: key columns copied from ``key`` into the output row
    """
    total = 0
    count = 0
    for rec in recs:
        total += rec.get(sum_column, 0)
        count += rec.get(count_column, 1)

    out = {sum_column: total, count_column: count}
    out.update((col, key[col]) for col in group_by)
    yield out


def fetch_rows_by_id(table, search_keys, search_values, yt_client=None):
    """Read the rows of ``table`` selected by the ranged path ``table["v1","v2",...]``.

    ``table`` must be sorted with ``search_keys`` as a prefix of its sort
    columns.

    NOTE(review): in YPath, each comma-separated element is a separate key. This
    presumably selects rows whose first sort column equals any of
    ``search_values``; confirm. Values are wrapped in double quotes without
    escaping, so only string values without '"' are supported.

    :raises Exception: if the table is not sorted with ``search_keys`` as a prefix
    :return: the ``read_table(..., raw=False)`` result
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    quoted_values = ['"' + value + '"' for value in utils.flatten(search_values)]
    table_path = yt_client.TablePath(table + "[" + ",".join(quoted_values) + "]")

    from yt.wrapper import table_commands

    sorted_by = utils.flatten(table_commands.get_sorted_by(table_path, default=[], client=yt_client))
    search_keys = utils.flatten(search_keys)
    keys_are_sort_prefix = (
        bool(sorted_by)
        and len(search_keys) <= len(sorted_by)
        and sorted_by[: len(search_keys)] == search_keys
    )
    if not keys_are_sort_prefix:
        raise Exception(
            "Table %s should be sorted by %s to allow quick search, instead sorted by %s"
            % (table, search_keys, sorted_by)
        )

    return yt_client.read_table(table_path, raw=False)


def safe_get_attribute(path, attr, default=None, yt_client=None):
    """Return attribute ``attr`` of ``path``, or ``default``.

    ``default`` is also returned when reading the attribute raises any exception.
    """
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()
    try:
        value = client.get_attribute(path, attr, default)
    except Exception:
        value = default
    return value


def create_table_with_schema(
    table, schema, transaction=None, strict=False, recreate_if_exists=True, sorted_by=None, yt_client=None
):
    """Create ``table`` with a schema built from ``schema`` (column name -> type).

    :param transaction: if truthy, no transaction is opened here (presumably the
        caller already runs inside one); otherwise the work runs in a new
        ``yt_client.Transaction``
    :param strict: value of the schema's ``strict`` attribute
    :param recreate_if_exists: remove an existing ``table`` first. Otherwise an
        existing table is kept as is (``ignore_existing=True``).
    :param sorted_by: columns placed first in the schema with ascending sort
        order. Each must be a key of ``schema`` (KeyError otherwise).

    The table is created with ``optimize_for="scan"`` and ``recursive=True``.
    """
    if yt_client is None:
        yt_client = yt_clients.get_yt_client()

    @contextmanager
    def no_transaction():
        yield

    transaction_context = yt_client.Transaction if not transaction else no_transaction

    with transaction_context():
        remaining = schema.copy()

        # sorted columns go first in the schema; pop them so they appear only once
        sorted_columns = []
        for col_name in sorted_by or []:
            sorted_columns.append({"name": col_name, "type": remaining.pop(col_name), "sort_order": "ascending"})
        plain_columns = [{"name": col_name, "type": col_type} for col_name, col_type in remaining.items()]

        schema_yson = yson.YsonList(sorted_columns + plain_columns)
        schema_yson.attributes["strict"] = strict
        # NOTE: the "unique_keys" schema attribute reportedly can't be set at creation time.

        if recreate_if_exists and exists(table, yt_client=yt_client):
            yt_client.remove(table, force=True)
        yt_client.create_table(
            table, attributes={"schema": schema_yson, "optimize_for": "scan"}, ignore_existing=True, recursive=True
        )


def calculate_optimized_mr_partition_count(table, rows_per_job=100000, yt_client=None):
    """Return a map-reduce partition count of roughly one partition per ``rows_per_job`` rows.

    The count is the truncated ratio of ``table``'s row count to
    ``rows_per_job``, with a ratio of 0 replaced by 1.
    """
    client = yt_client if yt_client is not None else yt_clients.get_yt_client()

    # optimizes map reduce performance
    partitions = int(client.row_count(table) / rows_per_job)
    return partitions or 1
