import itertools
from collections import defaultdict
from functools import partial

import luigi
import yt.wrapper as yt
from yt.wrapper import ypath

from crypta.graph.v1.python.lib import graphite_sender
from crypta.graph.v1.python.lib.luigi import yt_luigi
from crypta.graph.v1.python.matching.common_stats import graph_stat
from crypta.graph.v1.python.rtcconf import config
from crypta.graph.v1.python.utils import mr_utils as mr
from crypta.graph.v1.python.utils import utils
from crypta.graph.v1.python.utils import yt_clients

# Maximum number of records reduce_cid_ua_yuid will load for a single crypta_id
# group; groups with more records than this are skipped entirely.
# NOTE(review): judging by the name this is an out-of-memory guard - confirm.
STAT_CID_SIZE_OOM = 10000


def reduce_add_cid(key, recs):
    """
    Reducer that attaches a crypta_id to desk / mob / app records of one key.

    Within the group, the first record carrying ``crypta_id`` supplies the id,
    and the first record of each ``device_type`` ("desk", "mob", "app") is
    emitted with that id into output table 0, 1 and 2 respectively.
    Nothing is emitted when no record in the group has ``crypta_id``.

    :param key: reduce key (unused)
    :param recs: iterable of dict records sharing the key
    :return: yields the selected records, mutated in place
    """
    # (device_type, output table index) in emission order.
    output_tables = (("desk", 0), ("mob", 1), ("app", 2))

    first_of_type = {}
    cid_holder = None
    for rec in recs:
        if cid_holder is None and "crypta_id" in rec:
            cid_holder = rec
        if "device_type" in rec:
            for device_type, _ in output_tables:
                if rec["device_type"] == device_type:
                    # setdefault keeps only the first record of each type
                    first_of_type.setdefault(device_type, rec)

    if cid_holder is None:
        return

    crypta_id = cid_holder["crypta_id"]
    for device_type, table_index in output_tables:
        matched = first_of_type.get(device_type)
        if matched is not None:
            matched["crypta_id"] = crypta_id
            matched["@table_index"] = table_index
            yield matched


def reduce_cid_ua(_, recs):
    """
    Reducer producing (crypta_id, ua_profile) rows for one key.

    Records with ``subkey == "yi"`` carry the UA info in ``value``; records
    with ``crypta_id`` are classified by the prefix of ``id_type``
    ("yuid..." or "deviceid...").

    * If the group has both a "yi" record and a yuid record, one row is emitted
      with the ``ua_profile`` field extracted from the first "yi" record.
    * If the group has a deviceid record, one row with ``ua_profile == "app"``
      is emitted.

    :param _: reduce key (unused)
    :param recs: iterable of dict records sharing the key
    """
    first_yi = None
    first_yuid = None
    first_devid = None

    for rec in recs:
        if first_yi is None and "subkey" in rec and rec["subkey"] == "yi":
            first_yi = rec
        if "crypta_id" in rec:
            id_type = rec["id_type"]
            is_yuid = id_type.startswith("yuid")
            is_devid = id_type.startswith("deviceid")
            if is_yuid and first_yuid is None:
                first_yuid = rec
            if is_devid and first_devid is None:
                first_devid = rec

    if first_yi is not None and first_yuid is not None:
        profile = mr.get_field_value("ua_profile", first_yi["value"])
        yield {"crypta_id": first_yuid["crypta_id"], "ua_profile": profile}

    if first_devid is not None:
        yield {"crypta_id": first_devid["crypta_id"], "ua_profile": "app"}


def reduce_cid_ua_yuid(_, recs, ua_type):
    """
    Reducer selecting ids whose crypta_id also has a given kind of UA profile.

    Input table 0 holds rows with ``ua_profile``, input table 1 holds id rows.
    If any table-0 row of the group has ``ua_profile`` starting with
    ``ua_type``, every table-1 row of the group is emitted (to output table 0).

    Groups with more than STAT_CID_SIZE_OOM records are skipped completely; at
    most STAT_CID_SIZE_OOM + 1 records are ever read from ``recs``.

    :param _: reduce key (unused)
    :param recs: iterable of dict records carrying ``@table_index``
    :param ua_type: required prefix of ``ua_profile``
    """
    size_limit = STAT_CID_SIZE_OOM
    # Read one record past the limit so an oversized group can be detected
    # without materialising all of it.
    group = list(itertools.islice(recs, size_limit + 1))

    if len(group) <= size_limit:
        ua_rows = []
        id_rows = []
        for row in group:
            source_table = row["@table_index"]
            if source_table == 0:
                ua_rows.append(row)
            elif source_table == 1:
                id_rows.append(row)

        if ua_rows and id_rows and any(row["ua_profile"].startswith(ua_type) for row in ua_rows):
            for row in id_rows:
                row["@table_index"] = 0
                yield row


def cross_device_stats_by_vertices(all_f, out_f, vertices):
    """
    Build cross-device coverage statistics under ``out_f``.

    Steps, in order:
      1. attach crypta_id from ``vertices`` to the desk / mob / devices tables
         taken from ``all_f``;
      2. build ``cid_ua`` (crypta_id, ua_profile) from ``vertices`` and the
         ``yuid_ua`` dict table;
      3. select ids whose crypta_id also has a UA profile starting with
         "d" or "m" (four reduces started asynchronously and awaited together);
      4. sum the selected tables by ``source_types`` and compute totals.

    :param all_f: folder prefix holding ``desk_yuids_rus``, ``mob_yuids_rus``
        and ``devices`` (table names are appended as-is)
    :param out_f: folder prefix for all produced tables
    :param vertices: path of the vertices table; it is passed to
        ``mr.sort_all`` with sort column ``key``
    """
    # Join cid
    mr.sort_all([vertices], "key")

    yt_client = yt_clients.get_yt_client()

    desk_cid = out_f + "desk_yuids_rus_cid"
    mob_cid = out_f + "mob_yuids_rus_cid"
    devices_cid = out_f + "devices_cid"
    cid_ua = out_f + "cid_ua"

    id_sources = [all_f + name for name in ("desk_yuids_rus", "mob_yuids_rus", "devices")]
    cid_tables = [desk_cid, mob_cid, devices_cid]
    yt_client.run_reduce(reduce_add_cid, id_sources + [vertices], cid_tables, reduce_by="key")
    mr.sort_all(cid_tables, "crypta_id")

    # Join ua to find cids containing specific UA type (desk/mob/device)
    ua_dict = config.GRAPH_YT_DICTS_FOLDER + "yuid_ua"
    yt_client.run_reduce(reduce_cid_ua, [vertices, ua_dict], [cid_ua], reduce_by="key")
    yt_client.run_sort(cid_ua, sort_by="crypta_id")

    # Find cross-device: (required ua_profile prefix, id table, result table name)
    cross_device_joins = (
        ("d", mob_cid, "mob_yuids_with_desk"),
        ("m", desk_cid, "desk_yuids_with_mob"),
        ("m", devices_cid, "devices_with_mob"),
        ("d", devices_cid, "devices_with_desk"),
    )
    pending_joins = [
        yt_client.run_reduce(
            partial(reduce_cid_ua_yuid, ua_type=ua_prefix),
            [cid_ua, id_table],
            [out_f + result_name],
            reduce_by="crypta_id",
            sync=False,
        )
        for ua_prefix, id_table, result_name in cross_device_joins
    ]
    utils.wait_all(pending_joins)

    # traffic: all flattened sums first, then all non-flattened ones
    stat_tables = ("mob_yuids_with_desk", "desk_yuids_with_mob", "devices_with_desk", "devices_with_mob")
    pending_sums = [
        graph_stat.sum_by_sources(out_f, table, ["source_types"], flatten=flatten)
        for flatten in (True, False)
        for table in stat_tables
    ]
    utils.wait_all(pending_sums)

    graph_stat.sum_sources_to_total_in_dir(out_f, ["source_types"])


def split_devices_indev(key, recs):
    """
    Reducer splitting devices (and their in-device matches) by OS.

    Records with ``crypta_id`` are treated as vertex records, all others as
    device records; when several of a kind are present the last one wins.
    Nothing is emitted unless the group contains a device record.

    Output tables: 0/1/2 - in-device matches for ios / android / other,
    3/4/5 - all devices for ios / android / other.

    :param key: reduce key, must contain ``key``
    :param recs: iterable of dict records sharing the key
    """
    vertex = None
    device = None
    for row in recs:
        if "crypta_id" in row:
            vertex = row
        else:
            device = row

    if not device:
        return

    platform = device.get("os", "")
    if platform == "ios":
        os_index = 0
    elif platform == "android":
        os_index = 1
    else:
        os_index = 2

    if vertex and "sources" in vertex and "indev" in vertex["sources"]:
        # split in-device matches into 3 tables by os
        yield {"key": key["key"], "sources": vertex["sources"], "@table_index": os_index}

    # split devices into 3 tables by os
    yield {"key": key["key"], "@table_index": os_index + 3}


def in_device_stats_by_browser_prepare_map(rec):
    """
    Mapper extracting (device, browser, platform) from in-device edges.

    A record is used only when its ``match_type`` contains "indev", its
    ``pair_type`` is "d_y" and both ``id2_browser`` and ``id1_ua`` are
    non-empty. Missing or ``None`` values of ``match_type``, ``id2_browser``
    and ``id1_ua`` are all treated as empty strings.

    Records whose ``id1_ua`` has fewer than four "|"-separated fields are
    skipped: the platform is taken from the fourth field, so it cannot be
    determined for them.

    :param rec: edge record (dict)
    :return: yields at most one dict with ``devid``, ``browser`` and ``platform``
    """
    # "or ''" also covers keys that are present but hold None.
    match_type = rec.get("match_type") or ""
    browser_info = rec.get("id2_browser") or ""
    device_ua = rec.get("id1_ua") or ""

    if "indev" not in match_type or rec["pair_type"] != "d_y":
        return
    if not (browser_info and device_ua):
        return

    ua_fields = device_ua.split("|")  # m|phone|apple|ios|10.2.1
    if len(ua_fields) < 4:
        # Malformed UA: skip the record instead of failing the whole job.
        return

    browser = browser_info.split("|")[0]  # mobilesafari|10.0
    yield dict(devid=rec["id1"], browser=browser, platform=ua_fields[3])


def in_device_stats_by_browser_prepare_reduce(key, recs):
    """
    Reducer de-duplicating (browser, platform) pairs of a single device.

    Emits one row per distinct pair, with the device id from ``key["devid"]``
    stored under ``key``.

    :param key: reduce key, must contain ``devid``
    :param recs: iterable of dicts with ``browser`` and ``platform``
    """
    distinct_pairs = {(row["browser"], row["platform"]) for row in recs}
    for browser, platform in distinct_pairs:
        yield {"key": key["devid"], "browser": browser, "platform": platform}


def in_device_stats_by_browser_filter_devices_map(rec):
    """
    Mapper remembering the input table of a record in the ``jord`` column.

    The original ``@table_index`` is copied to ``jord`` and then reset to 0,
    so every record goes to the single output while keeping track of where it
    came from. The record is modified in place.

    :param rec: input record (dict) carrying ``@table_index``
    """
    source_table = rec["@table_index"]
    rec["jord"] = source_table
    rec["@table_index"] = 0
    yield rec


def in_device_stats_by_browser_filter_devices_reduce(key, recs):
    """
    Reducer keeping only the records that follow a ``jord == 0`` record.

    Records with ``jord == 0`` are never emitted themselves; any other record
    is emitted only if a ``jord == 0`` record was already seen in the group.
    This relies on the group being ordered by ``jord`` (the operation in
    in_device_stats_by_vertices sorts by ["key", "jord"]).

    :param key: reduce key (unused)
    :param recs: iterable of dict records carrying ``jord``
    """
    device_seen = False
    for row in recs:
        if row["jord"] == 0:
            device_seen = True
        elif device_seen:
            yield row


@yt.aggregator
def in_device_stats_by_browser_count_map(recs):
    """
    Aggregating mapper pre-counting records per (platform, browser).

    Consumes the whole ``recs`` iterable, then emits one row per distinct
    (platform, browser) pair with the number of input records seen for it.

    :param recs: iterable of dicts with ``platform`` and ``browser``
    """
    per_platform = {}
    for row in recs:
        per_browser = per_platform.setdefault(row["platform"], {})
        browser = row["browser"]
        per_browser[browser] = per_browser.get(browser, 0) + 1

    for platform, per_browser in per_platform.items():
        for browser, total in per_browser.items():
            yield {"browser": browser, "platform": platform, "count": total}


def in_device_stats_by_browser_count_reduce(key, recs):
    """
    Reducer summing partial counts per (platform, browser).

    Adds up the ``count`` values of the incoming rows and emits one row per
    distinct (platform, browser) pair found in the group.

    :param key: reduce key (unused; pairs are taken from the records)
    :param recs: iterable of dicts with ``platform``, ``browser`` and ``count``
    """
    per_platform = {}
    for row in recs:
        per_browser = per_platform.setdefault(row["platform"], {})
        browser = row["browser"]
        per_browser[browser] = per_browser.get(browser, 0) + row["count"]

    for platform, per_browser in per_platform.items():
        for browser, total in per_browser.items():
            yield {"browser": browser, "platform": platform, "count": total}


def _graphite_safe(name):
    """Return ``name`` with spaces and dots replaced by underscores."""
    return name.replace(" ", "_").replace(".", "_")


def in_device_stats_by_vertices(all_f, out_f, vertices, edges, date):
    """
    Compute in-device matching statistics and send them to graphite.

    Steps, in order:
      1. split devices and their in-device matches into per-OS tables
         (``indev_*`` / ``dev_*``) under ``out_f``;
      2. extract distinct (device, browser, platform) triples from ``edges``;
      3. keep only the triples that follow a record from ``all_f + "devices"``
         with the same key;
      4. count devices per (platform, browser);
      5. send per-OS row counts plus the ``indev_all`` / ``dev_all`` totals and
         the per-browser counts under the ``graph_stat_in_device`` metric group.

    :param all_f: folder prefix holding the ``devices`` table
    :param out_f: folder prefix for all produced tables
    :param vertices: path of the vertices table; it is passed to
        ``mr.sort_all`` with sort column ``key``
    :param edges: path of the edges table
    :param date: date passed through to ``graphite_sender.to_graphite_sender``
    """
    indev_names = ["indev_ios", "indev_android", "indev_other"]
    devices_names = ["dev_ios", "dev_android", "dev_other"]
    split_names = indev_names + devices_names

    mr.sort_all([vertices], "key")
    yt_client = yt_clients.get_yt_client()
    yt_client.run_reduce(
        split_devices_indev,
        [all_f + "devices", vertices],
        [out_f + t for t in split_names],
        reduce_by="key",
    )

    yt_client.run_map_reduce(
        in_device_stats_by_browser_prepare_map,
        in_device_stats_by_browser_prepare_reduce,
        edges,
        out_f + "device_browsers",
        reduce_by="devid",
    )

    # Sorting by "jord" (the source table index) puts device records ahead of
    # browser records inside each key, which the filtering reducer relies on.
    yt_client.run_map_reduce(
        in_device_stats_by_browser_filter_devices_map,
        in_device_stats_by_browser_filter_devices_reduce,
        [all_f + "devices", out_f + "device_browsers"],
        out_f + "device_browsers_filtered",
        sort_by=["key", "jord"],
        reduce_by="key",
    )

    yt_client.run_map_reduce(
        in_device_stats_by_browser_count_map,
        in_device_stats_by_browser_count_reduce,
        out_f + "device_browsers_filtered",
        out_f + "device_browsers_counts",
        reduce_by=["platform", "browser"],
    )

    counts = {name: yt_client.row_count(out_f + name) for name in split_names}
    counts["indev_all"] = sum(counts[k] for k in indev_names)
    counts["dev_all"] = sum(counts[k] for k in devices_names)

    # dict.items() exists on both Python 2 and 3 (iteritems() is Python 2 only).
    metrics = [("graph_stat_in_device", k, str(v)) for k, v in counts.items()]
    graphite_sender.to_graphite_sender(metrics, date)

    # Platform and browser are joined with a dot, so spaces and dots inside
    # each of them are replaced first.
    by_browser = [
        (
            "graph_stat_in_device",
            "%s.%s" % (_graphite_safe(r["platform"]), _graphite_safe(r["browser"])),
            r["count"],
        )
        for r in yt_client.read_table(out_f + "device_browsers_counts")
    ]
    graphite_sender.to_graphite_sender(by_browser, date)


class VerticesMatchingCoverageStats(yt_luigi.BaseYtTask):
    """
    Luigi task computing matching coverage statistics for a vertices table.

    Runs the cross-device and the in-device statistics over the total usage
    stats of ``vertices_config.date`` and writes the results to the
    ``stat/matching/`` subfolder of the vertices folder.
    """

    vertices_config = luigi.Parameter()
    tags = ["v1"]

    def input_folders(self):
        date = self.vertices_config.date
        return {"total_stats": config.YT_OUTPUT_FOLDER + date + "/stat_new/all/"}

    def output_folders(self):
        vertices_folder = self.vertices_config.get_vertices_folder()
        return {"mathing_stats": vertices_folder + "stat/matching/"}

    def requires(self):
        usage_stats_task = graph_stat.PrepareTodayTotalUsageStats(date=self.vertices_config.date)
        return [usage_stats_task, self.vertices_config.producing_task]

    def before_run(self):
        # Make sure both result subfolders exist before any table is written.
        stats_root = self.out_f("mathing_stats")
        for subfolder in ("cross_device", "in_device"):
            mr.mkdir(ypath.ypath_join(stats_root, subfolder))

    def run(self):
        all_f = self.in_f("total_stats")
        vertices = self.vertices_config.get_vertices_table()
        edges = self.vertices_config.get_edges_table()
        stats_root = self.out_f("mathing_stats")

        cross_device_stats_by_vertices(all_f, stats_root + "cross_device/", vertices)

        in_device_stats_by_vertices(all_f, stats_root + "in_device/", vertices, edges, self.vertices_config.date)

    def output(self):
        totals_folder = self.out_f("mathing_stats") + "cross_device/sum_by_source_types/"
        # TODO: add real number of tables
        return [
            yt_luigi.YtTarget(totals_folder + table)
            for table in ("mob_yuids_with_desk_count_total", "desk_yuids_with_mob_count_total")
        ]
