import itertools
from collections import defaultdict
from functools import partial

import luigi
import yt.wrapper as yt

from crypta.graph.v1.python.data_imports.import_logs.app_metrica_day import ImportAppMetrikaDayTask
from crypta.graph.v1.python.data_imports.import_logs.graph_import_fp import ImportFPDayTask
from crypta.graph.v1.python.data_imports.import_logs.webvisor.graph_webvisor import ImportWebvisorDayTask
from crypta.graph.v1.python.lib.luigi import base_luigi_task
from crypta.graph.v1.python.lib.luigi import yt_luigi
from crypta.graph.v1.python.rtcconf import config
from crypta.graph.v1.python.utils import mr_utils as mr
from crypta.graph.v1.python.utils import utils
from crypta.graph.v1.python.utils import yt_clients

# A yuid with more hits per day than this is put into the "robot" bucket
# (see reduce_filter_suspicious_split_by_device).
HIT_MAX = 10000

# Short codes of the sources a yuid / device was seen in; these are the values
# written to the "source_type" / "source_types" columns.
REQANS_SOURCE = "rq"
YABS_SOURCE = "yb"
ALL_SOURCE = "al"  # added to every yuid by reduce_filter_suspicious_split_by_device
MMETRICA_SOURCE = "mm"
FINGERPRINT_SOURCE = "fp"
METRICA_SOURCE = "m"

# Longer names for the short source codes above.
NAMES = {
    REQANS_SOURCE: "reqans",
    YABS_SOURCE: "yabs",
    ALL_SOURCE: "all",
    MMETRICA_SOURCE: "mmetr",
    FINGERPRINT_SOURCE: "fingerprint",
    METRICA_SOURCE: "metrica",
}


def map_fp(rec):
    """
    Mapper: emit one (yuid, source_type) row per known source of a fingerprint record.

    The comma-separated "sources" column is translated with the mapping
    "r" -> REQANS_SOURCE, "m" -> METRICA_SOURCE; unknown codes are dropped.
    Every record with a non-empty yuid additionally yields a FINGERPRINT_SOURCE row.
    Records with an empty yuid produce nothing.
    """
    # It is strange to yield a record per source only to match them back later.
    # TODO: use raw table from FP
    yuid = rec["yuid"]
    source_codes = rec["sources"].split(",")
    known_types = {"r": REQANS_SOURCE, "m": METRICA_SOURCE}

    if not yuid:
        return

    for code in source_codes:
        if code in known_types:
            yield {"yuid": yuid, "source_type": known_types[code], "@table_index": 0}

    yield {"yuid": yuid, "source_type": FINGERPRINT_SOURCE, "@table_index": 0}


def map_yabs(rec):
    """
    Mapper: emit a (yuid, YABS_SOURCE) row for a yabs log record.

    A record is skipped when its "fraud_fraudbits" field is set to anything but "0",
    when "selecttype" is "5" or "17", or when it carries no "yandexuid".
    All fields are extracted from rec["value"] with mr.get_field_value.
    """
    raw = rec["value"]
    fraud_bits = mr.get_field_value("fraud_fraudbits", raw)
    select_type = mr.get_field_value("selecttype", raw)

    if fraud_bits and fraud_bits != "0":
        return
    if select_type == "5" or select_type == "17":
        return

    yuid = mr.get_field_value("yandexuid", raw)
    if yuid:
        yield {"yuid": yuid, "source_type": YABS_SOURCE, "@table_index": 0}


def reduce_filter_suspicious_split_by_device(yuid_key, recs, sources, date):
    """
    Reducer: classify one yuid by its daily activity and route it to an output table.

    Three kinds of input rows are recognised by the columns they carry:
    - rows with "ua_profile": "id_count" is summed per device kind, chosen by the
      "d|" (desktop) / "m|" (mobile) prefix of the profile;
    - rows with "rus": the yuid counts as Russian once any such row has rus == 1;
    - rows with "source_type": the type is remembered if it is listed in `sources`.

    Layout of the emitted "@table_index":
    - 12: no desktop and no mobile hits at all (no ua);
    - 13: too many distinct ua profiles within the day;
    - 0..11: hits bucket (0 single-clicker, 1 robot, 2 good)
             + device offset (0 mob, 3 desk) + country offset (0 rus, 6 not rus).

    :param yuid_key: reduce key, must contain "yuid"
    :param recs: all input rows of this yuid
    :param sources: source types that may appear in the output "source_types"
    :param date: processing date, handed to utils.get_yuid_activity_type
    """
    seen_source_types = {ALL_SOURCE}
    seen_browsers = set()
    desktop_profiles = set()
    mobile_profiles = set()
    desktop_hit_total = 0
    mobile_hit_total = 0
    from_russia = False

    for row in recs:
        if "ua_profile" in row:
            profile = row["ua_profile"]
            profile_with_browser = "|".join((profile, row["browser"], row["browser_version"]))
            if profile.startswith("d|"):
                desktop_hit_total += row["id_count"]
                desktop_profiles.add(profile_with_browser)
            elif profile.startswith("m|"):
                mobile_hit_total += row["id_count"]
                mobile_profiles.add(profile_with_browser)
            seen_browsers.add(row["browser"])
            continue

        if "rus" in row:
            if row["rus"] == 1:
                from_russia = True
            continue

        if "source_type" in row and row["source_type"] in sources:
            seen_source_types.add(row["source_type"])

    source_types = ",".join(sorted(seen_source_types))
    yuid = yuid_key["yuid"]
    every_profile = desktop_profiles | mobile_profiles

    if not (desktop_hit_total or mobile_hit_total):
        yield {"yuid": yuid, "@table_index": 12}  # no ua
        return

    # Desktop yuids shouldn't have several ua profiles per day, but two are allowed
    # just in case; several versions of one browser (compatibility mode) are fine too.
    desktop_looks_fake = not mobile_profiles and len(every_profile) > 2 and len(seen_browsers) > 1
    # Mobile yuids sometimes pretend to be several desktops, so the limit is higher.
    mobile_looks_fake = mobile_profiles and len(every_profile) > 5
    if desktop_looks_fake or mobile_looks_fake:
        yield {
            "key": yuid,
            "source_types": source_types,
            "desk_hits": desktop_hit_total,
            "mob_hits": mobile_hit_total,
            "desk_ua_profile": list(desktop_profiles),
            "mob_ua_profile": list(mobile_profiles),
            "@table_index": 13,
        }
        return

    if mobile_hit_total:
        # Mixed desktop + mobile activity is assumed to be a mobile pretending to be
        # a desktop; the desktop total is zero when only mobile hits were seen.
        device_type = "mob"
        hits = desktop_hit_total + mobile_hit_total
    else:
        device_type = "desk"
        hits = desktop_hit_total

    if hits == 1:  # single-clickers
        hits_bucket = 0
    elif hits > HIT_MAX:  # robots?
        hits_bucket = 1
    else:  # good guys
        hits_bucket = 2

    created_at = utils.get_yuid_creation_date(yuid)
    if utils.get_yuid_activity_type([date], created_at) == "private":
        hits_bucket = 0  # counted together with single-clickers

    device_offset = 0 if device_type == "mob" else 3
    country_offset = 0 if from_russia else 6
    yield {
        "key": yuid,
        "device_type": device_type,
        "hits": hits,
        "source_types": source_types,
        "ua_profiles": list(every_profile),
        "@table_index": hits_bucket + device_offset + country_offset,
    }


def map_device(rec):
    """
    Mapper: turn a device-info record into a one-hit "app" row.

    Only records whose "main_region_country" equals 225 are kept; everything else
    is dropped. The output row is keyed by "device_id" and always has hits == 1 and
    MMETRICA_SOURCE as its source type.
    """
    if rec.get("main_region_country") != 225:
        return

    device_row = {"key": rec["device_id"], "os": rec.get("os")}
    device_row["device_type"] = "app"
    device_row["hits"] = 1
    device_row["source_types"] = MMETRICA_SOURCE
    yield device_row


def flatten_sources(rec, columns, separator=","):
    """
    Unwraps all values that are stored as a separated list in the specified columns
    and yields a copy of the original rec for every combination of unwrapped values.

    input example:
        {'key': 'key1', 'a': '1,2', 'b': '4,5', 'c':'other,values'}
    output example (columns=['a', 'b']):
        {'key': 'key1', 'a': '1', 'b': '4', 'c': 'other,values'}
        {'key': 'key1', 'a': '1', 'b': '5', 'c': 'other,values'}
        {'key': 'key1', 'a': '2', 'b': '4', 'c': 'other,values'}
        {'key': 'key1', 'a': '2', 'b': '5', 'c': 'other,values'}

    Every yielded record is an independent dict, so the result can safely be
    materialised (e.g. with list()); the input rec is left untouched.

    :param rec: input record (dict)
    :param columns: columns to unwrap; nothing is yielded when it is empty
    :param separator: separator of the stored values; when falsy, the column value
        is used as an iterable of values as is
    :return: generator of flattened records, the first column varies slowest
    """
    columns = list(columns)
    if not columns:
        return

    value_lists = [rec[col].split(separator) if separator else rec[col] for col in columns]

    # The cartesian product yields nothing as soon as any column has no values.
    for values in itertools.product(*value_lists):
        flat_rec = dict(rec)
        flat_rec.update(zip(columns, values))
        yield flat_rec


try:
    _STRING_TYPES = (basestring,)  # Python 2: covers both str and unicode
except NameError:
    _STRING_TYPES = (str, bytes)  # Python 3


def to_sources_group_str(rec, columns):
    """
    Turns the group of sources stored in every given column into a single
    comma-separated string (in place) and yields the record once.

    A column that already holds a string is kept as is: joining a string would put
    a comma between each of its characters ("al,fp" -> "a,l,,,f,p").

    :param rec: input record (dict), modified in place
    :param columns: columns holding a collection of sources or a ready string
    :return: generator yielding the same rec exactly once
    """
    for col in columns:
        value = rec[col]
        if not isinstance(value, _STRING_TYPES):
            rec[col] = ",".join(value)
    yield rec


def count_uniques_and_sum_hits_by_sources(in_table, out_table, source_columns, flatten, sync):
    """
    Sums hits and counts uniques grouped by the source columns in one map-reduce.

    The mapper pre-aggregates partial sums over the rows it receives; the reducer
    adds the partial sums of every key together.

    :param in_table: input table; a row may carry "count" (defaults to 1 when absent)
        and "hits" (defaults to 0 when absent)
    :param out_table: output table with the source columns, "count" and "hits"
    :param source_columns: one or more grouping columns
    :param flatten: each source column may contain several sources per yuid.
    - if True, the result is grouped by each source separately (flatten_sources)
    - if False, the result is grouped by each whole group of sources
      (to_sources_group_str)
    :param sync: passed to run_map_reduce as is
    :return: whatever yt_client.run_map_reduce returns for the given `sync`
    """
    uniques_column = "count"
    hits_column = "hits"

    @yt.aggregator
    def map_pre_aggregate_uniques_and_hits_by_source(recs):
        split_by_sources = flatten_sources if flatten else to_sources_group_str
        extract_sources_func = partial(split_by_sources, columns=source_columns)

        seen_keys = set()
        partial_uniques = defaultdict(int)
        partial_hits = defaultdict(int)

        for rec in recs:
            for expanded in extract_sources_func(rec):
                group_key = tuple(expanded[col] for col in source_columns)
                seen_keys.add(group_key)

                # A row without its own counter stands for a single unique.
                if uniques_column in rec:
                    row_uniques = expanded[uniques_column]
                else:
                    row_uniques = 1
                partial_uniques[group_key] += row_uniques

                if hits_column in rec:
                    row_hits = expanded[hits_column]
                else:
                    row_hits = 0
                partial_hits[group_key] += row_hits

        for group_key in seen_keys:
            partial_row = dict(zip(source_columns, group_key))
            partial_row[uniques_column] = partial_uniques[group_key]
            partial_row[hits_column] = partial_hits[group_key]
            yield partial_row

    def reduce_uniques_and_hits_by_source(key, recs):
        total_uniques = 0
        total_hits = 0
        for partial_row in recs:
            total_uniques += partial_row[uniques_column]
            total_hits += partial_row[hits_column]

        total_row = dict(key)
        total_row[uniques_column] = total_uniques
        total_row[hits_column] = total_hits
        yield total_row

    return yt_clients.get_yt_client().run_map_reduce(
        map_pre_aggregate_uniques_and_hits_by_source,
        reduce_uniques_and_hits_by_source,
        in_table,
        out_table,
        reduce_by=source_columns,
        sync=sync,
    )


def sum_by_sources(workdir, yuids_table, sources_columns, flatten, sync=False):
    """
    Starts the per-source aggregation of `workdir + yuids_table`.

    Each incoming row contains an id [yuid/ip] with multiple sources in each source
    column and the number of hits from this sources combination.
    Two types of count can be calculated:
    - flatten = False: hits per combination of source types (unique per id),
      written to "<table>_count"
    - flatten = True: hits per each single source (a hit is counted for every
      source), written to "<table>_flatten_count"

    The result table is placed in `workdir + "sum_by_<columns joined by _>/"`,
    which is created first.

    :return: result of count_uniques_and_sum_hits_by_sources
    """
    result_dir = workdir + "sum_by_" + "_".join(sources_columns) + "/"
    mr.mkdir(result_dir)

    suffix = "_flatten_count" if flatten else "_count"
    return count_uniques_and_sum_hits_by_sources(
        workdir + yuids_table,
        result_dir + yuids_table + suffix,
        sources_columns,
        flatten,
        sync,
    )


def sum_sources_to_total_in_dir(workdir, source_types):
    """
    Builds a "<table>_total" table next to every "*_count" table found in
    `workdir + "sum_by_<source_types joined by _>"`.

    "*_flatten_count" tables are skipped (the original author notes that a total
    over them doesn't make sense). All operations are started asynchronously and
    awaited together.
    """
    yt_client = yt_clients.get_yt_client()
    sum_dir = workdir + "sum_by_" + "_".join(source_types)

    per_group_tables = [
        name
        for name in yt_client.list(sum_dir)
        if name.endswith("_count") and not name.endswith("_flatten_count")
    ]

    count_and_sum_total = partial(mr.sum_and_count_by_column, group_by=["total"])
    running_ops = [
        yt_client.run_map_reduce(
            None,
            count_and_sum_total,
            sum_dir + "/" + name,
            sum_dir + "/" + name + "_total",
            reduce_by="total",
            sync=False,
        )
        for name in per_group_tables
    ]
    utils.wait_all(running_ops)


def export_todays_statbox_tables(dt, out):
    """
    Builds `out + "yuids_fp"`: distinct (yuid, source_type) pairs produced by map_fp
    from `config.GRAPH_FOLDER + dt + "/yuids_ua_day"`.

    The analogous yabs export (map_yabs over config.STATBOX_YABS_FOLDER + dt into
    out + 'yuids_yabs') is currently disabled.
    """
    fp_table = out + "yuids_fp"
    dedup_keys = ["yuid", "source_type"]

    yt_client = yt_clients.get_yt_client()
    yt_client.run_map(map_fp, config.GRAPH_FOLDER + dt + "/yuids_ua_day", fp_table)

    # The table is sorted by the distinct keys before it is de-duplicated in place.
    mr.sort_all([fp_table], dedup_keys)
    dedup_ops = [mr.distinct_by(dedup_keys, fp_table, fp_table, sync=False)]
    utils.wait_all(dedup_ops)


def prepare_all_yuids(graph, in_f, out, date):
    """
    Splits today's yuids into per-device / per-country / per-activity tables and
    calculates per-source counts for the Russian mobile and desktop yuids.

    :param graph: folder with "yuids_desk_day", "yuids_mob_day", "yuids_rus_ip_day"
    :param in_f: folder with the "yuids_fp" table (the yabs table is not used now)
    :param out: folder the result tables are written to
    :param date: processing date, passed to the reducer
    """
    yuid_inputs = [in_f + "yuids_fp"] + [
        graph + name for name in ("yuids_desk_day", "yuids_mob_day", "yuids_rus_ip_day")
    ]

    # The position in this list is the "@table_index" emitted by
    # reduce_filter_suspicious_split_by_device, so the order must not change.
    split_tables = [
        out + name
        for name in (
            "yuids_single_mob_rus",
            "yuids_robot_mob_rus",
            "mob_yuids_rus",
            "yuids_single_desk_rus",
            "yuids_robot_desk_rus",
            "desk_yuids_rus",
            "yuids_single_mob_not_rus",
            "yuids_robot_mob_not_rus",
            "yuids_good_mob_not_rus",
            "yuids_single_desk_not_rus",
            "yuids_robot_desk_not_rus",
            "yuids_good_desk_not_rus",
            "yuids_no_ua",
            "yuids_like_test_robots",
        )
    ]

    mr.sort_all(yuid_inputs, sort_by="yuid")

    split_reducer = partial(
        reduce_filter_suspicious_split_by_device,
        sources=[REQANS_SOURCE, METRICA_SOURCE, FINGERPRINT_SOURCE],
        date=date,
    )
    yt_client = yt_clients.get_yt_client()
    yt_client.run_reduce(split_reducer, yuid_inputs, split_tables, reduce_by="yuid")

    mr.merge_chunks_all(split_tables)

    # Flattened counts first, then grouped counts; mobile before desktop in both.
    count_ops = [
        sum_by_sources(out, table, ["source_types"], flatten=flatten)
        for flatten in (True, False)
        for table in ("mob_yuids_rus", "desk_yuids_rus")
    ]
    utils.wait_all(count_ops)

    mr.sort_all([out + "desk_yuids_rus", out + "mob_yuids_rus"], "key")

    sum_sources_to_total_in_dir(out, ["source_types"])


def prepare_all_devices(graph, out):
    """
    Builds `out + "devices"` from `graph + "mobile/dev_info_yt"` with map_device,
    sorts it by "key" and calculates both flattened and grouped per-source counts.
    """
    # TODO: dev info
    devices_table = out + "devices"

    yt_client = yt_clients.get_yt_client()
    yt_client.run_map(map_device, graph + "mobile/dev_info_yt", devices_table)
    yt_client.run_sort(devices_table, sort_by="key")

    count_ops = [
        sum_by_sources(out, "devices", ["source_types"], flatten=flatten)
        for flatten in (True, False)
    ]
    utils.wait_all(count_ops)


def reduce_count_hits_per_acc(key, recs, key_type):
    """
    Reducer: count the rows of every "acc" value within one key and collect the
    geo source types seen for it.

    :param key: reduce key, must contain the `key_type` column
    :param recs: rows with "acc" and a comma-separated "geo_source_types"
    :param key_type: name of the key column; also written to "key_type"
    :return: generator with one row per distinct acc: the key column, "key_type",
        "acc", "hits" (number of rows) and "geo_source_types" (sorted, de-duplicated,
        comma-separated)
    """
    hits_by_acc = defaultdict(int)
    geo_sources_by_acc = defaultdict(set)
    for rec in recs:
        acc = rec["acc"]
        hits_by_acc[acc] += 1
        geo_sources_by_acc[acc].update(rec["geo_source_types"].split(","))

    # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
    for acc, hits in hits_by_acc.items():
        yield {
            key_type: key[key_type],
            "key_type": key_type,
            "acc": acc,
            "hits": hits,
            # Sorted so that the output does not depend on set iteration order.
            "geo_source_types": ",".join(sorted(geo_sources_by_acc[acc])),
        }


def group_by_and_count_hits(key, recs, group_columns):
    """
    Reducer: yield a single row with the number of input rows in "hits" and the
    values of `group_columns` copied from the reduce key.

    :param key: reduce key, must contain every column of `group_columns`
    :param recs: rows of the group; only their number matters
    :param group_columns: key columns to copy into the output row
    """
    row_count = sum(1 for _ in recs)

    result = {"hits": row_count}
    # Key columns are written last, so a grouping column named "hits" wins.
    result.update((column, key[column]) for column in group_columns)
    yield result


class PrepareTodayTotalUsageStats(base_luigi_task.BaseTask):
    """
    Calculates the number of mobile and desktop cookies and devices and their activity
    that have been active for today in different logs. Used as base for other stats.

    Everything is written under `config.YT_OUTPUT_FOLDER + date + "/stat_new/"`.
    """

    date = luigi.Parameter()
    tags = ["v1"]

    def requires(self):
        """Daily FP, webvisor and AppMetrika imports for `date`."""
        day_imports = (ImportFPDayTask, ImportWebvisorDayTask, ImportAppMetrikaDayTask)
        return [task_cls(date=self.date, run_date=self.date) for task_cls in day_imports]

    def run(self):
        """Export the fingerprint yuids, then build the yuid and device stats."""
        day_folder = config.YT_OUTPUT_FOLDER + self.date + "/"
        stat_folder = day_folder + "stat_new/"
        all_folder = stat_folder + "all/"

        mr.mkdir(stat_folder)
        mr.mkdir(all_folder)

        export_todays_statbox_tables(self.date, stat_folder)
        prepare_all_yuids(day_folder, stat_folder, all_folder, self.date)
        prepare_all_devices(day_folder, all_folder)

        # NOTE(review): presumably marks the "sorted_by_key" attribute targets as
        # done - confirm against yt_luigi.YtAttrTarget.
        for attr_target in self.sorted_output():
            attr_target.complete()

    def sorted_output(self):
        """"sorted_by_key" attribute targets of the tables sorted by "key"."""
        all_folder = config.YT_OUTPUT_FOLDER + self.date + "/stat_new/all/"
        sorted_tables = ("desk_yuids_rus", "mob_yuids_rus", "devices")
        return [yt_luigi.YtAttrTarget(all_folder + table, "sorted_by_key") for table in sorted_tables]

    def output(self):
        """Count tables of the task followed by the targets of sorted_output()."""
        counts_folder = config.YT_OUTPUT_FOLDER + self.date + "/stat_new/" + "all/sum_by_source_types/"
        count_tables = ("devices_count", "desk_yuids_rus_count_total", "mob_yuids_rus_count_total")
        count_targets = [yt_luigi.YtTarget(counts_folder + table) for table in count_tables]
        return count_targets + self.sorted_output()
