from crypta.graph.rt.sklejka.michurin.proto import state_pb2
from crypta.lib.python.identifiers.generic_id import GenericID

from ads.bsyeti.tests.test_lib.data_collector.codec_decompressor import get_decompressed_object
import yt.wrapper as yt
import yt.yson as yson

import argparse
import time
import os

# Default source table: the Michurin state table in the testing environment.
MIC_TABLE = '//home/crypta/testing/rtsklejka/state/michurin_state'
# Default destination table. The ``_CHANGE_ME`` suffix presumably marks a
# placeholder meant to be overridden (e.g. via ``--to-table``) — confirm.
DST_TABLE = '//home/crypta/testing/rtsklejka/state/cryptaid_state_CHANGE_ME'


def rebuild_cryptaid_state(client, from_table=MIC_TABLE, to_table=DST_TABLE):
    """Rebuild the CryptaId lookup state from a Michurin state table.

    Recreates ``to_table`` as a sorted table keyed by ``(Hash, Id)`` with
    ``Hash = farm_hash(Id)`` and unique keys. It then fills the table with one
    ``(Id, CryptaId)`` row per graph vertex found in ``from_table``. Finally it
    converts the table to a dynamic, in-memory table in the ``crypta-graph``
    bundle and mounts it.

    Args:
        client: YT client (``yt.wrapper.YtClient``) used for every Cypress
            call and operation.
        from_table: Path of the Michurin state table. Each row's ``Id``,
            ``State`` (compressed ``TMichurinState`` proto) and ``Codec``
            columns are read.
        to_table: Destination path. It is created with ``force=True``, so any
            existing node at that path is replaced.
    """
    print("Creating table {}".format(to_table))
    print("Reading from {}".format(from_table))
    # Pause before ``force=True`` below replaces whatever is at ``to_table``.
    # Presumably this gives the operator time to abort (Ctrl-C) after seeing
    # the paths printed above.
    time.sleep(10)
    client.create(
        "table",
        to_table,
        attributes=dict(
            schema=yson.to_yson_type(
                [
                    dict(name="Hash", type="uint64", expression="farm_hash(Id)", sort_order="ascending"),
                    dict(name="Id", type="string", required=True, sort_order="ascending"),
                    dict(name="CryptaId", type="string", required=True),
                ],
                attributes=dict(unique_keys=True),
            ),
            optimize_for="lookup",
        ),
        recursive=True,
        force=True,
    )

    # Same columns as the destination, but unsorted, so the map_reduce output
    # does not need to be sorted yet. ``Hash`` is still a computed column.
    tmp_unordered_schema = [
        dict(name="Hash", type="uint64", expression="farm_hash(Id)"),
        dict(name="Id", type="string", required=True),
        dict(name="CryptaId", type="string", required=True),
    ]

    def proto_pack_fun(row):
        """Explode one Michurin state row into one output row per graph vertex.

        Decompresses ``row["State"]`` with ``row["Codec"]`` and parses it as a
        ``TMichurinState``. For every vertex in ``Graph.Vertices`` it yields the
        serialized vertex id as ``Id``, paired with a ``cryptaid`` GenericID
        built from the state row's own ``Id``.
        """
        state = row["State"]
        codec = row["Codec"]
        state_data = get_decompressed_object(yt.yson.get_bytes(state), codec.encode())
        st = state_pb2.TMichurinState().FromString(state_data)
        for vertex in st.Graph.Vertices:
            yield {
                "Id": GenericID(proto=vertex).serialize(),
                "CryptaId": GenericID("cryptaid", str(row["Id"])).serialize(),
            }

    def unique_reducer(_, recs):
        """Keep only the first row of each key group.

        This enforces the destination's ``unique_keys``. When the same ``Id``
        is produced more than once, nothing here controls which ``CryptaId``
        wins.
        """
        yield next(recs)

    # Chunk/block sizing for every phase that writes chunks. Each YT operation
    # reads a different ``*_job_io`` key:
    #   merge      -> job_io
    #   map_reduce -> map_job_io, sort_job_io, reduce_job_io
    #   sort       -> partition_job_io, sort_job_io, merge_job_io
    # One spec is shared by all three operations below, so it carries the
    # union of those keys. ``reduce_job_io`` is the phase that writes the
    # map_reduce output; without it the writer config would not be applied there.
    table_writer = {"table_writer": {"block_size": 256 * 2 ** 10, "desired_chunk_size": 100 * 2 ** 20}}
    spec = {
        "job_io": table_writer,
        "map_job_io": table_writer,
        "partition_job_io": table_writer,
        "sort_job_io": table_writer,
        "merge_job_io": table_writer,
        "reduce_job_io": table_writer,
    }

    with client.TempTable(attributes={"schema": tmp_unordered_schema}) as tmp_crypta_id:
        print("Running map_reduce from {} to {}".format(from_table, tmp_crypta_id))
        # NOTE(review): mapper rows carry no ``Hash`` column (it only exists as
        # a computed column of the output schema). Grouping by ["Hash", "Id"]
        # therefore presumably groups by ``Id`` alone here — confirm.
        client.run_map_reduce(proto_pack_fun, unique_reducer, from_table, tmp_crypta_id, reduce_by=["Hash", "Id"], spec=spec)
        print("Running sort for {}".format(tmp_crypta_id))
        # Sort in place by the destination key, so the merge below can fill
        # the sorted, unique-key destination table.
        client.run_sort(tmp_crypta_id, tmp_crypta_id, sort_by=["Hash", "Id"], spec=spec)
        print("Running merge between {} and {}".format(tmp_crypta_id, to_table))
        client.run_merge(tmp_crypta_id, to_table, spec=spec)

    print("Making {} dynamic".format(to_table))
    client.alter_table(to_table, dynamic=True)
    # Serving attributes are set before mounting, so the tablets come up with them.
    client.set("{path}/@in_memory_mode".format(path=to_table), "uncompressed")
    client.set("{path}/@enable_lookup_hash_table".format(path=to_table), True)
    client.set("{path}/@tablet_cell_bundle".format(path=to_table), "crypta-graph")
    print("Mounting {}".format(to_table))
    client.mount_table(to_table, sync=True)


def main():
    """Command-line entry point: parse table paths and rebuild the CryptaId state.

    The YT cluster comes from the ``YT_PROXY`` environment variable (default
    ``hahn``). The token comes from ``YT_TOKEN``.
    """
    # Pass ``description`` by keyword. The first positional argument of
    # ArgumentParser is ``prog``; passing the text there would replace the
    # program name in the usage line instead of describing the tool.
    parser = argparse.ArgumentParser(description="Rebuild cryptaid")
    parser.add_argument('-f', '--from-table', default=MIC_TABLE,
                        help='source Michurin state table (default: %(default)s)')
    parser.add_argument('-t', '--to-table', default=DST_TABLE,
                        help='destination CryptaId state table, replaced if it exists (default: %(default)s)')
    args = parser.parse_args()

    client = yt.YtClient(proxy=os.getenv('YT_PROXY', 'hahn'), token=os.getenv('YT_TOKEN'))
    rebuild_cryptaid_state(client,
                           from_table=args.from_table,
                           to_table=args.to_table)


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()
