import argparse
import io
import os

from crypta.lib.proto.identifiers import id_type_pb2

from crypta.graph.engine.proto.graph_pb2 import TEdgeBetween
from crypta.graph.rt.events.proto import types_pb2
from crypta.graph.rt.events.proto.event_pb2 import TEventMessage
from crypta.graph.rt.events.proto.soup_pb2 import TSoupEvent
from crypta.graph.soup.config.proto import log_source_pb2, source_type_pb2
from crypta.graph.rt.events.proto.michurin_bookkeeping_pb2 import TMichurinBookkeepingEvent as BK_event

from crypta.lib.proto.identifiers.identifiers_pb2 import TGenericID as TGenericIDProto
from crypta.lib.python.identifiers.generic_id import GenericID

from ads.bsyeti.big_rt.py_lib import YtQueue

from library.python.framing import packer as packer_lib
from yweb.antimalware.libs import farmhash
import murmurhash

import yt.wrapper as yt

# Buffer size threshold in bytes: pack_shard_data cuts a chunk once the
# output buffer grows past this value.
EVENT_SIZE_LIMIT = 100 * 1024

# Default value of the ``--path`` command-line option.
DEFAULT_QUEUE = '//home/crypta/production/rtsklejka/qyt/sharded_events'
# Table queried by resolve(); the CID_STATE_TABLE environment variable
# overrides the default path.
DEFAULT_STATE_TABLE = "//home/crypta/production/rtsklejka/state/cryptaid_state"
CID_STATE_TABLE = os.environ.get('CID_STATE_TABLE', DEFAULT_STATE_TABLE)

# Module-level YT client; the proxy is taken from YT_PROXY, else 'markov'.
client = yt.YtClient(proxy=os.environ.get('YT_PROXY', 'markov'))


def pack_shard_data(shard_data):
    """Pack per-shard protobuf events into lists of framed byte chunks.

    Args:
        shard_data: mapping of shard -> iterable of protobuf messages.

    Returns:
        dict mapping every shard of ``shard_data`` to a list of ``bytes``
        chunks. The size check runs after an event has been added, so a
        chunk is cut once the buffer has grown past ``EVENT_SIZE_LIMIT``
        and may therefore be larger than the limit.

    Raises:
        AssertionError: if a shard ends up with no chunks at all.
    """
    data = {}
    for shard, events in shard_data.items():
        chunks = []
        # The context manager closes the buffer even when packing raises.
        with io.BytesIO() as output:
            packer = packer_lib.Packer(output)

            for event in events:
                packer.add_proto(event)
                if output.tell() > EVENT_SIZE_LIMIT:
                    packer.flush()
                    chunks.append(output.getvalue())
                    # Rewind and empty the buffer so it can hold the next chunk.
                    output.seek(0)
                    output.truncate()

            # Emit whatever is left after the last full chunk.
            if output.tell() != 0:
                packer.flush()
                chunks.append(output.getvalue())

        # Check emptiness via truthiness: comparing the list itself with 0
        # is always "not equal" and would never fail.
        assert chunks, "no data packed for shard {}".format(shard)
        data[shard] = chunks
    return data


def generate_new_cryptaid(gid):
    """Compute a cryptaid value for ``gid`` with ``murmurhash.hash64``.

    The hashed string has the form ``"<value>(<id type name, lowercased>)"``,
    where the type name is looked up in ``EIdType`` by ``gid.type``.
    """
    id_type_name = id_type_pb2.EIdType.Name(gid.type)
    seed = "{}({})".format(gid.value, id_type_name.lower())
    return murmurhash.hash64(seed)


def generate_event_message(
    gid1,
    gid2,
    cid1,
    cid2=None,
    timestamp=100,
    log_source=log_source_pb2.OAUTH_LOG,
    source_type=source_type_pb2.APP_PASSPORT_AUTH,
    counter=0,
    merge=False,
):
    """Build a ``TEventMessage`` of type SOUP describing one edge.

    Args:
        gid1, gid2: edge vertices; each must provide ``to_proto()``.
        cid1: cryptaid of the first vertex.
        cid2: cryptaid of the second vertex; ``None`` means "same as cid1".
        timestamp: stored both as the message timestamp and the soup unixtime.
        log_source, source_type: enum values stored on the edge.
        counter, merge: copied into the soup event as is.

    Returns:
        The message with the serialized ``TSoupEvent`` as its body.
    """
    second_cid = cid1 if cid2 is None else cid2

    # Edge between the two identifiers.
    edge = TEdgeBetween()
    edge.LogSource = log_source
    edge.SourceType = source_type
    edge.Vertex1.CopyFrom(gid1.to_proto())
    edge.Vertex2.CopyFrom(gid2.to_proto())

    # Soup payload that carries the edge.
    soup_event = TSoupEvent()
    soup_event.CryptaId1 = cid1
    soup_event.CryptaId2 = second_cid
    soup_event.Unixtime = timestamp
    soup_event.Counter = counter
    soup_event.Merge = merge
    soup_event.Edge.CopyFrom(edge)

    # Envelope; its cryptaid is the first truthy one of the pair.
    envelope = TEventMessage()
    envelope.Type = types_pb2.SOUP
    envelope.TimeStamp = timestamp
    envelope.CryptaId = cid1 or second_cid
    envelope.Body = soup_event.SerializeToString()
    return envelope


def farm_hash(*values):
    """Fold the farm fingerprints of ``values`` into one combined hash.

    Starting from ``0xDEADC0DE``, each value's fingerprint is paired with the
    running result and fingerprinted again; the final result is XOR-ed with
    the number of values. With no values this is just ``0xDEADC0DE``.
    """
    combined = 0xDEADC0DE
    count = 0
    for count, item in enumerate(values, start=1):
        item_fingerprint = farmhash.farm_fingerprint(item)
        combined = farmhash.farm_fingerprint((combined, item_fingerprint))
    return combined ^ count


def resolve(client, gid):
    """Look up the cryptaid recorded for ``gid`` in ``CID_STATE_TABLE``.

    Rows are selected by ``Hash`` equal to ``farm_hash(gid.serialize())``.

    Returns:
        0 when no row matches; otherwise ``CryptaId.Value`` parsed from the
        ``CryptaId`` column of the first returned row.
    """
    id_hash = farm_hash(gid.serialize())
    query = "* FROM [{}] WHERE Hash={}".format(CID_STATE_TABLE, id_hash)
    matches = list(client.select_rows(query, format="yson"))
    if matches:
        # The column holds a serialized TGenericID message.
        raw_cid = yt.yson.get_bytes(matches[0]['CryptaId'])
        return TGenericIDProto.FromString(raw_cid).CryptaId.Value
    return 0


def set_cid(args_cid, gid, should_resolve, client):
    """Choose the cryptaid for ``gid`` and print which source was used.

    Priority: an explicitly supplied ``args_cid`` (anything but ``None``),
    then a lookup through ``resolve`` when ``should_resolve`` is true,
    otherwise 0.
    """
    label = "{}({})".format(gid.to_proto().WhichOneof('identifier'), gid.value)

    if args_cid is None and not should_resolve:
        print("Using cid 0 for {}".format(label))
        return 0

    if args_cid is None:
        resolved = resolve(client, gid)
        print("Resolved {} to cid {}".format(label, resolved))
        return resolved

    print("Using cid {} for {}".format(args_cid, label))
    return args_cid


def parse_args():
    """Define the ``inject`` command line and parse the process arguments.

    Global options: ``--dry-run``, ``-t/--timestamp``, ``-p/--path``.
    A sub-command is mandatory and is stored in ``event_type``:
    ``soup_event`` (edge options plus four positional ``gids``) or
    ``bk_event`` (bookkeeping type plus a required ``--cid``).

    Returns:
        The namespace produced by ``ArgumentParser.parse_args``.
    """
    root = argparse.ArgumentParser(prog='inject')
    root.add_argument('--dry-run', action='store_true', help='Dry run')
    root.add_argument('-t', '--timestamp', type=int, default=100, help='timestamp')
    root.add_argument('-p', '--path', default=DEFAULT_QUEUE, help='path to queue')
    commands = root.add_subparsers(
        dest='event_type',
        required=True,
        help='Event type: soup_event or bk_event',
    )

    # soup_event sub-command
    soup_cmd = commands.add_parser('soup_event', help='soup help')
    soup_cmd.add_argument('-c', '--cid1', type=int, default=None, required=False, help='cryptaid1')
    soup_cmd.add_argument('--cid2', type=int, default=None, required=False, help='cryptaid2')
    soup_cmd.add_argument('--counter', type=int, default=0, help='counter')
    soup_cmd.add_argument('-m', '--merge', action='store_true', help='merge')
    soup_cmd.add_argument('-l', '--log-source', default='OAUTH_LOG', help='log_source')
    soup_cmd.add_argument('-s', '--source-type', default='APP_PASSPORT_AUTH', help='source_type')
    # --no-resolve flips the ``resolve`` attribute, which is True by default.
    soup_cmd.add_argument('--no-resolve', dest='resolve', action='store_false', default=True, help='Do not resolve')
    soup_cmd.add_argument('gids', nargs=4, help='TYPE VALUE')

    # bk_event sub-command
    bk_cmd = commands.add_parser('bk_event', help='bk help')
    bk_cmd.add_argument(
        'bookkeeping_type',
        # The first enum name is excluded from the accepted choices.
        choices=BK_event.EBookkeepingType.keys()[1:],
        help='type of bookkeeping message',
    )
    bk_cmd.add_argument('-c', '--cid', type=int, default=None, required=True, help='cryptaid')

    return root.parse_args()


def generate_soup_event(args):
    """Create a soup event message from parsed ``soup_event`` arguments.

    ``args.gids`` is read as ``[type1, value1, type2, value2]``; the types
    are lowercased before building the identifiers. Cryptaids come from
    ``set_cid`` for each identifier.

    Returns:
        ``(message, shard_cid)`` where ``shard_cid`` is the first truthy
        cryptaid of the pair, or a newly generated one for ``gid1`` when
        both are falsy.
    """
    ids = args.gids
    gid1 = GenericID(ids[0].lower(), ids[1])
    gid2 = GenericID(ids[2].lower(), ids[3])

    cid1 = set_cid(args.cid1, gid1, args.resolve, client)
    cid2 = set_cid(args.cid2, gid2, args.resolve, client)

    # A new cryptaid is generated only when neither side has one.
    shard_cid = cid1 or cid2 or generate_new_cryptaid(gid1)

    event = generate_event_message(
        gid1,
        gid2,
        cid1,
        cid2,
        timestamp=args.timestamp,
        # Enum values are looked up by upper-cased name on the proto modules.
        log_source=getattr(log_source_pb2, args.log_source.upper()),
        source_type=getattr(source_type_pb2, args.source_type.upper()),
        counter=args.counter,
        merge=args.merge,
    )
    return event, shard_cid


def generate_bk_event(args):
    """Create a MICHURIN_BOOKKEEPING message from parsed ``bk_event`` arguments.

    ``args.cid`` is stored both on the envelope and in the bookkeeping
    payload; ``args.bookkeeping_type`` is converted from its enum name.

    Returns:
        The ``TEventMessage`` with the serialized bookkeeping event as body.
    """
    payload = BK_event()
    payload.CryptaId = args.cid
    payload.Type = BK_event.EBookkeepingType.Value(name=args.bookkeeping_type)

    envelope = TEventMessage()
    envelope.Type = types_pb2.MICHURIN_BOOKKEEPING
    envelope.TimeStamp = args.timestamp
    envelope.CryptaId = args.cid
    envelope.Body = payload.SerializeToString()
    return envelope


def inject(args, event, shard_cid):
    """Write one event to the queue shard selected by ``shard_cid``.

    Args:
        args: parsed CLI namespace. ``path`` and ``dry_run`` are read;
            ``event_type``, when present, labels the printed message.
        event: protobuf message to pack and write.
        shard_cid: integer; the target shard is ``shard_cid % shard_count``.

    In dry-run mode the shard is still computed (which requires querying
    the queue for its shard count) but nothing is written.
    """
    queue = YtQueue({'path': args.path, 'cluster': client.config["proxy"]["url"]})
    shard = shard_cid % queue.get_shard_count()

    # Report the kind of event actually being sent; this function is used for
    # soup events as well, so a hard-coded "bk_event" label would be wrong.
    event_name = getattr(args, 'event_type', None) or 'event'

    if args.dry_run:
        print(f"Dry run. Would have written {event_name} to {shard} shard (cid: {shard_cid})")
        return

    print(f"Will write {event_name} to {shard} shard (cid: {shard_cid}, queue: {args.path})")
    # NOTE(review): "zstd_6" is presumably the compression codec - confirm against YtQueue.write.
    queue.write(pack_shard_data({shard: [event]}), "zstd_6")


def main():
    """Entry point: parse arguments, build the requested event, inject it."""
    args = parse_args()
    kind = args.event_type
    if kind == 'bk_event':
        # Bookkeeping events are sharded by the cryptaid given on the CLI.
        event, cid = generate_bk_event(args), args.cid
    elif kind == 'soup_event':
        event, cid = generate_soup_event(args)
    inject(args, event, cid)


# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
