import argparse
import datetime
import os

from crypta.graph.rt.sklejka.michurin.proto import state_pb2
from crypta.lib.proto.identifiers.identifiers_pb2 import TGenericID as TGenericIDProto
from crypta.lib.python.identifiers.generic_id import GenericID

from ads.bsyeti.tests.test_lib.data_collector.codec_decompressor import get_decompressed_object
from yweb.antimalware.libs import farmhash

from google.protobuf.json_format import MessageToDict

import yt.wrapper as yt

# Built-in YT locations of the two state tables read by this tool.
DEFAULT_STATE_TABLE = "//home/crypta/production/rtsklejka/state/michurin_state"
CID_DEFAULT_STATE_TABLE = "//home/crypta/production/rtsklejka/state/cryptaid_state"

# Effective locations: the MICHURIN_STATE_TABLE / CID_STATE_TABLE environment
# variables win over the built-in defaults when set.
STATE_TABLE = os.environ.get('MICHURIN_STATE_TABLE', DEFAULT_STATE_TABLE)
CID_STATE_TABLE = os.environ.get('CID_STATE_TABLE', CID_DEFAULT_STATE_TABLE)


def farm_hash(*values):
    """Combine the farmhash fingerprints of ``values`` into a single integer.

    Starting from the seed ``0xDEADC0DE``, each value is fingerprinted and the
    pair ``(running_result, value_fingerprint)`` is fingerprinted again to form
    the new running result. Finally the number of values is XOR-ed in.

    The result is used as the ``Hash`` key when querying the cid state table,
    so the scheme presumably has to match whatever fills that column — do not
    change it without checking the writer side.
    """
    acc = 0xDEADC0DE
    for item in values:
        item_fp = farmhash.farm_fingerprint(item)
        acc = farmhash.farm_fingerprint((acc, item_fp))
    return acc ^ len(values)


def _format_id_column(raw):
    """Decode a column value holding a serialized TGenericID into ``<type>(<value>)``."""
    proto = TGenericIDProto.FromString(yt.yson.get_bytes(raw))
    return "{}({})".format(proto.WhichOneof('identifier'), GenericID(proto=proto).value)


def check_cid_state(client, id_hash, cid_path):
    """Print which CryptaId the identifier with hash ``id_hash`` belongs to.

    Selects rows with ``Hash == id_hash`` from the cid state table at
    ``cid_path``. Using the first row, it prints
    ``<type>(<value>) has <type>(<value>)`` for its ``Id`` and ``CryptaId``
    columns. When no row matches, it prints ``No data``.

    :param client: YT client used for the ``select_rows`` query.
    :param id_hash: integer key looked up in the table's ``Hash`` column.
    :param cid_path: YT path of the cid state table to query.
    """
    # Query the table the caller asked for (``--cid-path``), not the module
    # default; otherwise the command-line override would be silently ignored.
    rows = list(client.select_rows(
        "* FROM [{}] WHERE Hash={}".format(cid_path, id_hash),
        format="yson",
    ))
    if not rows:
        print("No data")
        return

    row = rows[0]
    print("{} has {}".format(_format_id_column(row['Id']), _format_id_column(row['CryptaId'])))


def time_fun(ts):
    """Format a Unix timestamp (seconds) as local time ``YYYY-MM-DD HH:MM``.

    ``ts`` may be an int or a decimal string. Protobuf's JSON mapping (used
    via MessageToDict for edges) renders 64-bit integer fields as strings, in
    case the timestamp field is 64-bit.
    """
    # %M is minutes; %m would print the month number in the minutes slot.
    return datetime.datetime.fromtimestamp(int(ts)).strftime("%Y-%m-%d %H:%M")


def _describe_vertex(gid_proto):
    """Render a TGenericID proto as ``<identifier type>(<value>)``."""
    return "{}({})".format(gid_proto.WhichOneof('identifier'), GenericID(proto=gid_proto).value)


def _message_field(as_dict, message, key):
    """Return ``as_dict[key]``, falling back to ``getattr(message, key)``.

    MessageToDict leaves out fields that hold their default value, so plain
    ``as_dict[key]`` raises KeyError for them. The fallback reads the value
    directly from the message. It assumes the JSON key equals the proto field
    name (true for field names without underscores) — TODO confirm for the
    edge message.
    """
    if key in as_dict:
        return as_dict[key]
    return getattr(message, key)


def check_michurin_state(client, cid, show_vertices, show_edges, convert_time, michurin_path, cid_path):
    """Print the Michurin state stored for CryptaId ``cid``.

    Reads the row with ``Id == cid`` from the table at ``michurin_path``,
    decompresses its ``State`` column and parses it as ``TMichurinState``.
    It then prints summary counters, and optionally the graph edges and the
    vertices. For each vertex, the cid state table at ``cid_path`` is queried.

    :param client: YT client used for ``select_rows`` queries.
    :param cid: CryptaId to look up (interpolated into the query as-is).
    :param show_vertices: print every vertex together with its cid-state row.
    :param show_edges: print every edge with both endpoints.
    :param convert_time: render timestamps with ``time_fun`` instead of raw values.
    :param michurin_path: YT path of the Michurin state table.
    :param cid_path: YT path of the cid state table.
    :return: the state's ``MergedToCryptaId``, or ``None`` when no row exists.
    """
    t = time_fun if convert_time else lambda x: x
    print("Checking '{}'".format(cid))
    rows = list(client.select_rows(
        "* FROM [{}] WHERE Id={}".format(michurin_path, cid),
        format="yson",
    ))
    if not rows:
        print("No state for {}".format(cid))
        return None

    row = rows[0]
    # State is a compressed blob; the row's Codec value is passed to the decompressor.
    state_data = get_decompressed_object(yt.yson.get_bytes(row["State"]), row["Codec"].encode())
    st = state_pb2.TMichurinState.FromString(state_data)
    print("ByteSize: {}".format(st.ByteSize()))
    print("Merged to: {}".format(st.MergedToCryptaId))
    print("LimitReset: {}".format(st.LimitResetCount))
    print("TouchedAt: {}, TouchCount: {}".format(t(st.TouchedAt), st.TouchCount))
    print("Edges: {}, Vertices: {}".format(len(st.Graph.Edges), len(st.Graph.Vertices)))
    print("")
    if not st.Graph.Vertices and not st.Graph.Edges:
        print("Graph is empty")
        return st.MergedToCryptaId

    if show_edges:
        print("\nEdges:")
        vertices = st.Graph.Vertices
        for edge in st.Graph.Edges:
            # MessageToDict omits default-valued fields, so no key may be
            # indexed directly. Vertex indexes fall back to 0 (the proto
            # default). int() also copes with a string-encoded 64-bit index.
            e_dict = MessageToDict(edge)
            v1 = _describe_vertex(vertices[int(e_dict.get('Vertex1', 0))])
            v2 = _describe_vertex(vertices[int(e_dict.get('Vertex2', 0))])
            print("{{v1:{},\tv2:{},\tts:{},\tls:{},\tst:{},\tseen:{}}}".format(
                v1, v2,
                t(e_dict.get("TimeStamp", 0)),
                _message_field(e_dict, edge, "LogSource"),
                _message_field(e_dict, edge, "SourceType"),
                e_dict.get("SeenCount", 0),
            ))

    if show_vertices:
        print("\nVertices:")
        for i, v in enumerate(st.Graph.Vertices):
            # Index prefix; check_cid_state completes the line.
            print("{},".format(i), end="")
            id_hash = farm_hash(GenericID(proto=v).serialize())
            check_cid_state(client, id_hash, cid_path)
    return st.MergedToCryptaId


def _build_parser(michurin_default, cid_default):
    """Create the command-line parser; the two table-path defaults are passed in."""
    parser = argparse.ArgumentParser(description='Read michurin state')
    parser.add_argument('cryptaid', help='cryptaid', nargs='+')
    parser.add_argument('--vertices', action="store_true", dest='vertices', help='show vertices')
    parser.add_argument('--no-vertices', action="store_false", dest='vertices', help='do not show vertices')
    parser.add_argument('--edges', action="store_true", dest='edges', help='show edges')
    parser.add_argument('--no-edges', action="store_false", dest='edges', help='do not show edges')
    parser.add_argument('--follow', action="store_true", dest='follow', help='follow merges')
    # store_false: with store_true (as before) --no-follow could never disable following.
    parser.add_argument('--no-follow', action="store_false", dest='follow', help='do not follow merges')
    parser.add_argument('-t', '--convert-time', action="store_true", help='convert time to human-readable')
    parser.add_argument('--michurin-path', help='Path to michurinstate table', default=michurin_default)
    parser.add_argument('--cid-path', help='Path to cid state table', default=cid_default)
    # Defaults of each --X / --no-X pair, stated once instead of relying on
    # which of the two actions was registered first.
    parser.set_defaults(vertices=False, edges=True, follow=True)
    return parser


def main():
    """Dump the Michurin state for every CryptaId given on the command line.

    With following enabled (the default), the tool keeps printing the state of
    ``MergedToCryptaId`` until it is falsy.
    """
    args = _build_parser(STATE_TABLE, CID_STATE_TABLE).parse_args()

    print("Using MICHURIN_STATE_TABLE: {}\n".format(args.michurin_path))
    print("Using CID_STATE_TABLE: {}\n".format(args.cid_path))

    client = yt.YtClient(proxy=os.getenv('YT_PROXY', 'hahn'))

    for cid in args.cryptaid:
        # Guard against an infinite loop should merge pointers ever form a cycle.
        # Compare as strings because CLI ids are str while merge targets are ints.
        visited = set()
        while cid:
            if str(cid) in visited:
                print("Merge cycle detected at {}, stopping".format(cid))
                break
            visited.add(str(cid))
            merged_to = check_michurin_state(
                client, cid, args.vertices, args.edges, args.convert_time, args.michurin_path, args.cid_path)
            print("\n" + "=" * 20)
            cid = merged_to if args.follow else 0


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
