import copy

import yt.wrapper as yt_wrapper
from yt.wrapper import TablePath

from datacloud.crypta_utils.make_join_tables_v2 import CryptaExtractTablesV2


class YuidTables(object):
    """Paths of the output tables written under ``root``.

    ``all_cid`` and ``all_yuid`` are ``TablePath`` objects whose ``schema``
    attribute is a fixed list of columns followed by ``target_cols_schema``.
    """

    def __init__(self, root, target_cols_schema):
        self.root = root

        def column(name, col_type='string'):
            # One schema entry in the {'name': ..., 'type': ...} form.
            return {'name': name, 'type': col_type}

        cid_columns = [
            column('cid'),
            column('id_type'),
            column('id_value'),
            column('external_id'),
            column('timestamp', 'int64'),
        ]
        yuid_columns = [
            column('external_id'),
            column('cid'),
            column('id_type'),
            column('id_value'),
            column('yuid'),
            column('timestamp', 'int64'),
        ]

        # Caller-supplied extra columns always go after the fixed ones.
        self.all_cid = yt_wrapper.TablePath(
            yt_wrapper.ypath_join(root, 'all_cid'),
            attributes={'schema': cid_columns + target_cols_schema},
        )
        self.all_yuid = yt_wrapper.TablePath(
            yt_wrapper.ypath_join(root, 'all_yuid'),
            attributes={'schema': yuid_columns + target_cols_schema},
        )


def get_all_yuid_v2(prefix, local_crypta_folder=None, add_devids=False,
                    add_yuids=False, yt_client=yt_wrapper, tmp_path=None, target_cols_schema=None):
    """Resolve input identifiers to crypta ids (cid) and then to yandexuids.

    Steps, all inside one YT transaction:

    1. Sort ``prefix + "input"`` by (id_type, id_value) into a temp table.
    2. Reduce-join it with the crypta ``id_value_to_cid`` table, restricted
       to the phone_md5 and email_md5 id types (plus yandexuid when
       ``add_yuids``). Write ``all_cid`` and sort it by ``cid``.
    3. Reduce-join ``all_cid`` with ``cid_to_all`` by ``cid``, keeping only
       yandexuid rows. Write ``all_yuid`` and sort it by ``yuid``.

    Args:
        prefix: Path prefix. The input table is ``prefix + "input"``. The
            output tables go into everything before the last ``/`` of
            ``prefix``.
        local_crypta_folder: Root of the crypta tables. When None, uses
            ``//home/x-products/production`` with the
            ``/crypta_v2/crypta_db_last`` sub-folder.
        add_devids: Accepted for compatibility; not used by this function.
        add_yuids: Also join input rows against crypta yandexuid entries.
        yt_client: YT client (or the ``yt.wrapper`` module) used for every
            operation.
        tmp_path: Folder for the temporary sorted copy of the input.
        target_cols_schema: Extra schema columns appended to both output
            tables.
    """
    # NOTE(review): add_devids is never read here. Confirm whether device ids
    # were meant to be joined as well.
    target_cols_schema = target_cols_schema or []
    result_folder = "/".join(prefix.split("/")[:-1])
    input_table = prefix + "input"

    if local_crypta_folder is None:
        base_root = "//home/x-products/production"
        sub_folder = "/crypta_v2/crypta_db_last"
    else:
        base_root = local_crypta_folder
        sub_folder = ""

    crypta_tables = CryptaExtractTablesV2(
        base_root=base_root,
        sub_folder=sub_folder
    )
    result_tables = YuidTables(result_folder, target_cols_schema=target_cols_schema)

    def crypta_ids(id_type):
        # Foreign slice of id_value_to_cid for a single id_type. The upper
        # key is id_type plus a trailing space, which sorts right after
        # id_type itself, so the range spans the rows of that id_type.
        return TablePath(
            crypta_tables.id_value_to_cid,
            lower_key=(id_type, ""),
            upper_key=(id_type + " ", ""),
            attributes={'foreign': True},
        )

    @yt_wrapper.with_context
    def join_cid(key, recs, context):
        # Rows from the input table (table_index 0) are buffered. Each later
        # crypta row emits one output per buffered row, carrying its cid.
        # Assumes the join delivers input rows before the foreign rows of the
        # same key (the early break below depends on it).
        base_recs = []
        for rec in recs:
            if context.table_index == 0:
                base_recs.append(dict(rec))
            elif not base_recs:
                # No input rows for this key: nothing to join.
                break
            else:
                for base_rec in base_recs:
                    # Yield a fresh dict per match, as join_yuid does. Then
                    # several cids for one key never share (and overwrite)
                    # the same row object.
                    new_rec = copy.copy(base_rec)
                    new_rec["cid"] = rec["cid"]
                    yield new_rec

    @yt_wrapper.with_context
    def join_yuid(key, recs, context):
        # Same buffering scheme as join_cid, keyed by cid. Only crypta rows
        # whose id_type is "yandexuid" produce output, with yuid = id_value.
        base_recs = []
        for rec in recs:
            if context.table_index == 0:
                base_recs.append(dict(rec))
            elif not base_recs:
                break
            elif rec["id_type"] == "yandexuid":
                for base_rec in base_recs:
                    new_rec = copy.copy(base_rec)
                    new_rec["yuid"] = rec["id_value"]
                    yield new_rec

    with yt_client.Transaction(), \
            yt_client.TempTable(path=tmp_path, prefix='input') as prepared_input:
        yt_client.run_sort(input_table, prepared_input, sort_by=("id_type", "id_value"))

        # The sorted input must stay first: the reducers treat table_index 0
        # as the primary side of the join.
        input_tables = [
            prepared_input,
            crypta_ids("phone_md5"),
            crypta_ids("email_md5"),
        ]
        if add_yuids:
            input_tables.append(crypta_ids("yandexuid"))

        yt_client.run_reduce(
            join_cid,
            input_tables,
            result_tables.all_cid,
            reduce_by=("id_type", "id_value"),
            join_by=("id_type", "id_value")
        )
        yt_client.run_sort(result_tables.all_cid, sort_by="cid")

        yt_client.run_reduce(
            join_yuid,
            [
                result_tables.all_cid,
                TablePath(
                    crypta_tables.cid_to_all,
                    attributes={'foreign': True}
                ),
            ],
            result_tables.all_yuid,
            reduce_by="cid",
            join_by="cid"
        )
        yt_client.run_sort(result_tables.all_yuid, sort_by="yuid")
