import collections
from datetime import datetime
import io
import uuid

from google.protobuf.json_format import ParseDict
from library.python.framing import packer as packer_lib

from crypta.graph.engine.proto.graph_pb2 import (
    TEdge,
    TEdgeBetween,
    TGraph,
)
from crypta.graph.rt.events.proto import types_pb2
from crypta.graph.rt.events.proto.event_pb2 import TEventMessage
from crypta.graph.rt.events.proto.michurin_bookkeeping_pb2 import (
    TMichurinBookkeepingEvent,
)
from crypta.graph.rt.events.proto.soup_pb2 import TSoupEvent
from crypta.graph.rt.sklejka.michurin.proto.state_pb2 import TMichurinState
from crypta.graph.soup.config.proto import log_source_pb2, source_type_pb2
from crypta.lib.proto.identifiers.identifiers_pb2 import TGenericID as TGenericIDProto
from crypta.lib.python.identifiers.identifiers import Puid, MmDeviceId
from crypta.lib.python.identifiers.generic_id import GenericID


# Byte threshold for pack_shard_data: once the packed buffer grows past this
# size, the current chunk is flushed and a new one is started (100 KiB).
EVENT_SIZE_LIMIT = 102_400


def assert_cid_map(expected_cid_map, cid_rows):
    """Assert that ``cid_rows`` encode exactly ``expected_cid_map``.

    Each row must provide serialized ``TGenericID`` protos under the byte keys
    ``b"Id"`` (the identifier) and ``b"CryptaId"`` (its CryptaId). Both sides
    are compared as ``{crypta_id_value: {(type, normalized_value), ...}}``.

    Args:
        expected_cid_map: mapping ``crypta_id -> iterable of GenericID-like``
            objects exposing ``.type`` and ``.normalize``.
        cid_rows: iterable of row mappings as described above.

    Raises:
        AssertionError: if the two maps differ.
    """

    def _signature(generic_id):
        return generic_id.type, generic_id.normalize

    actual = collections.defaultdict(set)
    for row in cid_rows:
        id_proto = TGenericIDProto()
        id_proto.ParseFromString(row[b"Id"])

        crypta_id_proto = TGenericIDProto()
        crypta_id_proto.ParseFromString(row[b"CryptaId"])

        owner = crypta_id_proto.CryptaId.Value
        actual[owner].add(_signature(GenericID(proto=id_proto)))

    expected = {
        crypta_id: set(map(_signature, generic_ids))
        for crypta_id, generic_ids in expected_cid_map.items()
    }

    assert expected == dict(actual)


def assert_vults(expected, actual):
    """Check Vulture states; currently only that ``actual`` is not ``[]``.

    ``expected`` is accepted for the future comparison but is not used yet.
    """
    # TODO(mskorokhod): compare ``actual`` against ``expected``.
    empty_message = "Vulture states should not be empty"
    assert actual != [], empty_message


def pack_shard_data(shard_data):
    """Pack each shard's proto events into framed byte chunks.

    Events are written through ``packer_lib.Packer``. The size of the
    underlying buffer is checked after every event. Once it exceeds
    ``EVENT_SIZE_LIMIT``, the packer is flushed, the buffered bytes become one
    chunk, and the buffer is reset. A chunk can therefore be larger than the
    limit by the last event added. Any remainder is flushed into a final chunk.

    Args:
        shard_data: mapping ``shard -> iterable of protobuf messages``.

    Returns:
        dict ``shard -> list[bytes]`` of packed chunks, in event order.

    Raises:
        AssertionError: if a shard produced no packed bytes (e.g. its event
            list is empty).
    """
    data = {}
    for shard, events in shard_data.items():
        chunks = []
        with io.BytesIO() as output:
            packer = packer_lib.Packer(output)

            for event in events:
                packer.add_proto(event)
                if output.tell() > EVENT_SIZE_LIMIT:
                    packer.flush()
                    chunks.append(output.getvalue())
                    # Reuse the same buffer for the next chunk.
                    output.seek(0)
                    output.truncate()

            if output.tell():
                packer.flush()
                chunks.append(output.getvalue())

        # The previous check (`0 != data[shard]`) compared an int with a list
        # and could never fail; check the chunk list itself.
        assert chunks, f"shard {shard!r} produced no packed data"
        data[shard] = chunks
    return data


def generate_soup_event(
    gid1,
    gid2,
    cid1,
    cid2=None,
    timestamp=100,
    merge=False,
    log_source=log_source_pb2.OAUTH_LOG,
    source_type=source_type_pb2.APP_PASSPORT_AUTH,
    counter=0,
):
    """Build a SOUP ``TEventMessage`` that carries one edge between two ids.

    Args:
        gid1, gid2: identifier objects exposing ``to_proto()``. They become
            the edge's ``Vertex1`` and ``Vertex2``.
        cid1: value written to ``TSoupEvent.CryptaId1``.
        cid2: value written to ``TSoupEvent.CryptaId2``. ``None`` means
            "same as ``cid1``".
        timestamp: used both as the envelope ``TimeStamp`` and as the soup
            ``Unixtime``.
        merge: value of ``TSoupEvent.Merge``.
        log_source, source_type: edge provenance enum values.
        counter: value of ``TSoupEvent.Counter``.

    Returns:
        A ``TEventMessage`` of type ``SOUP`` whose ``Body`` is the serialized
        ``TSoupEvent``. Its ``CryptaId`` is ``cid1``, or the second cid when
        ``cid1`` is falsy.
    """
    second_cid = cid1 if cid2 is None else cid2

    edge = TEdgeBetween()
    edge.Vertex1.CopyFrom(gid1.to_proto())
    edge.Vertex2.CopyFrom(gid2.to_proto())
    edge.LogSource = log_source
    edge.SourceType = source_type

    soup = TSoupEvent()
    soup.CryptaId1 = cid1
    soup.CryptaId2 = second_cid
    soup.Unixtime = timestamp
    soup.Counter = counter
    # TODO: test without merge
    soup.Merge = merge
    soup.Edge.CopyFrom(edge)

    envelope = TEventMessage()
    envelope.Type = types_pb2.SOUP
    envelope.CryptaId = cid1 or second_cid
    envelope.TimeStamp = timestamp
    envelope.Body = soup.SerializeToString()
    return envelope


def generate_bookkeeping_event(eventType, cid, timestamp):
    """Wrap a ``TMichurinBookkeepingEvent`` into a bookkeeping message.

    Args:
        eventType: ``TMichurinBookkeepingEvent`` type enum value (e.g.
            ``CID_UPDATE``, ``TOMBSTONE_DELETE``, ``SPLIT``).
        cid: CryptaId written to both the envelope and the payload.
        timestamp: envelope ``TimeStamp``.

    Returns:
        A ``TEventMessage`` of type ``MICHURIN_BOOKKEEPING`` whose ``Body`` is
        the serialized payload.
    """
    payload = TMichurinBookkeepingEvent()
    payload.Type = eventType
    payload.CryptaId = cid

    envelope = TEventMessage()
    envelope.Type = types_pb2.MICHURIN_BOOKKEEPING
    envelope.CryptaId = cid
    envelope.TimeStamp = timestamp
    envelope.Body = payload.SerializeToString()
    return envelope


def generate_shards_data_sunflower(shards_count, edge_limit=100, old_edges_count=10):
    """Generate one "sunflower" graph per shard: one puid linked to many devices.

    Each shard ``i`` uses CryptaId ``100 + i`` and a fresh puid. It receives:

    * ``old_edges_count`` edges with timestamp 100. These are meant to be
      replaced because of ``edge_limit``, so their devices are expected under
      CryptaId ``0``.
    * ``edge_limit`` edges with timestamp 1000, whose devices are expected
      under the shard's CryptaId.

    All events are sent with ``cid2=0`` and ``merge=True``.

    Returns:
        ``(shard_data, expected_cid_map, vertex_data_len)``, where
        ``vertex_data_len`` is the number of generated events plus
        ``old_edges_count * shards_count`` (one reset per replaced edge).
    """
    shard_data = collections.defaultdict(list)
    expected_cid_map = collections.defaultdict(set)

    def _device_event(puid, crypta_id, ts):
        device = MmDeviceId(str(uuid.uuid4()))
        event = generate_soup_event(
            puid,
            device,
            cid1=crypta_id,
            cid2=0,
            timestamp=ts,
            merge=True,
        )
        return device, event

    for shard_id in range(shards_count):
        crypta_id = 100 + shard_id
        puid = Puid(Puid.next())

        # Older edges: expected to be pushed out by the newer ones.
        for _ in range(old_edges_count):
            device, event = _device_event(puid, crypta_id, 100)
            shard_data[shard_id].append(event)
            expected_cid_map[0].add(device)

        # Newer edges filling the graph up to edge_limit.
        for _ in range(edge_limit):
            device, event = _device_event(puid, crypta_id, 1000)
            shard_data[shard_id].append(event)
            expected_cid_map[crypta_id].add(device)

    generated = sum(len(events) for events in shard_data.values())
    resets = old_edges_count * shards_count
    return shard_data, expected_cid_map, generated + resets


def generate_shards_data_merged_to(shards_count, reset_after_count, msg_count=10):
    """Generate per-shard events whose vertices are expected under ``cid * 10``.

    Each shard ``i`` uses CryptaId ``100 + i`` and a fresh puid, and emits two
    batches of ``msg_count`` puid-device edges (all ``merge=True``):

    * a first batch with the default counter. These devices are not recorded
      in the expected map.
    * a second batch with ``counter=reset_after_count + 1``. These devices,
      together with the puid, are expected under CryptaId ``(100 + i) * 10``.

    Returns:
        ``(shard_data, expected_cid_map, vertex_data_len)``, where
        ``vertex_data_len`` is the total number of generated events.
    """
    shard_data = collections.defaultdict(list)
    expected_cid_map = collections.defaultdict(set)

    for shard_id in range(shards_count):
        crypta_id = 100 + shard_id
        target_cid = crypta_id * 10
        puid = Puid(Puid.next())

        for _ in range(msg_count):
            device = MmDeviceId(str(uuid.uuid4()))
            shard_data[shard_id].append(
                generate_soup_event(puid, device, cid1=crypta_id, merge=True)
            )

        for _ in range(msg_count):
            device = MmDeviceId(str(uuid.uuid4()))
            shard_data[shard_id].append(
                generate_soup_event(
                    puid,
                    device,
                    cid1=crypta_id,
                    counter=reset_after_count + 1,
                    merge=True,
                )
            )
            expected_cid_map[target_cid].add(device)

        expected_cid_map[target_cid].add(puid)

    vertex_data_len = sum(map(len, shard_data.values()))
    return shard_data, expected_cid_map, vertex_data_len


def generate_shards_for_merge(shards_count, msg_count=10):
    """Generate per-shard graphs plus a second event batch that merges shard pairs.

    Shard ``i`` uses CryptaId ``100 + i`` and a fresh puid, linked to
    ``msg_count`` fresh mm_device_id vertices (``cid2=0``, ``merge=True``).
    Shards are paired even/odd: (0, 1), (2, 3), ...

    Returns a 6-tuple:
        shard_data: initial events per shard.
        expected_cid_map: ``cid -> devices`` expected after ``shard_data``
            (puid vertices are not recorded here).
        merge_shard_data: events that merge each even shard into the odd
            shard that follows it.
        expected_cid_update_requests: one per event in ``shard_data``.
        merge_expected_cid_update_requests: ``2 + msg_count`` per merged pair.
        merged_expected_cid_map: ``cid -> vertices`` expected after the merge.
            Every device of a pair is expected under the pair's odd CryptaId.

    NOTE(review): with ``msg_count <= 0`` the inner loop never runs, so
    ``mm_device_id_vertex`` stays ``None``. Merge events are then built with a
    ``None`` vertex, which ``generate_soup_event`` cannot convert
    (``gid2.to_proto()``). With an odd ``shards_count`` the last (even) shard
    has no partner, but its vertices are still expected under
    ``identifier + 1``. Confirm callers only use an even ``shards_count``.
    """
    shard_data = collections.defaultdict(list)
    merge_shard_data = collections.defaultdict(list)
    expected_cid_map = collections.defaultdict(set)
    merged_expected_cid_map = collections.defaultdict(set)

    # Puid and CryptaId of the previous shard. Read only on odd shards,
    # where they refer to the even partner shard.
    prev_puid = prev_id = None
    expected_cid_update_requests = merge_expected_cid_update_requests = 0
    for shard_id in range(shards_count):
        identifier = 100 + shard_id
        puid_vertex = Puid(Puid.next())

        # After the loop this holds the LAST generated device. It is reused as
        # the endpoint of the merge events below.
        mm_device_id_vertex = None
        for _ in range(msg_count):
            mm_device_id_vertex = MmDeviceId(str(uuid.uuid4()))
            event = generate_soup_event(
                puid_vertex, mm_device_id_vertex, cid1=identifier, cid2=0, merge=True
            )
            shard_data[shard_id].append(event)
            expected_cid_map[identifier].add(mm_device_id_vertex)
            # Round up to the odd CryptaId of the pair: 100 -> 101, 101 -> 101.
            merged_expected_cid_map[identifier // 2 * 2 + 1].add(mm_device_id_vertex)
            expected_cid_update_requests += 1

        # merge even to odd
        if shard_id % 2:
            # Edge from the even partner's puid to this shard's last device,
            # spanning both CryptaIds. It is sent on the even partner's shard.
            merge_shard_data[shard_id - 1].append(
                generate_soup_event(
                    prev_puid,
                    mm_device_id_vertex,
                    cid1=prev_id,
                    cid2=identifier,
                    timestamp=1000,
                    merge=True,
                )
            )
            # michurin will make two setCidRequest for edge between graphs
            # and one more for each vertex in fromGraph
            merge_expected_cid_update_requests += 2 + msg_count
            # Something has to be in a shard
            merge_shard_data[shard_id].append(
                generate_soup_event(
                    puid_vertex,
                    mm_device_id_vertex,
                    cid1=identifier,
                    timestamp=1000,
                    merge=True,
                )
            )
        else:
            # The even shard's puid is expected to move to the odd partner's
            # CryptaId. NOTE(review): odd shards' puids are not recorded in
            # either expected map. Confirm this is intended.
            merged_expected_cid_map[identifier + 1].add(puid_vertex)

        prev_puid, prev_id = puid_vertex, identifier

    return (
        shard_data,
        expected_cid_map,
        merge_shard_data,
        expected_cid_update_requests,
        merge_expected_cid_update_requests,
        merged_expected_cid_map,
    )


def generate_michurin_state(
    cid,
    number_of_vertices_per_component,
    number_of_components=1,
    **kwargs,
):
    """Build a ``TMichurinState`` whose graph is made of star-shaped components.

    The state is first populated from ``kwargs`` via ``ParseDict`` (e.g.
    ``MergedToCryptaId=...``, ``TouchedAt=...``). Then ``Graph`` is set to a
    ``TGraph`` with ``Id = cid``. Component ``k`` consists of one puid vertex
    (value ``str(cid + k)``) at index ``k * number_of_vertices_per_component``,
    followed by ``number_of_vertices_per_component - 1`` fresh mm_device_id
    vertices. Each of those is connected to the puid by an OAUTH_LOG /
    APP_PASSPORT_AUTH edge.

    NOTE(review): when ``number_of_vertices_per_component < 1`` each component
    still appends its puid, so vertex indices stop matching
    ``k * number_of_vertices_per_component``. This is only consistent when
    ``number_of_components`` is 0 or the size is at least 1.
    """
    # TODO add options make graph with different components sizes
    state = ParseDict(kwargs, TMichurinState())

    graph = TGraph()
    graph.Id = cid
    for component in range(number_of_components):
        hub_index = component * number_of_vertices_per_component
        graph.Vertices.append(GenericID("puid", str(cid + component)).to_proto())

        last_index = hub_index + number_of_vertices_per_component
        for leaf_index in range(hub_index + 1, last_index):
            graph.Vertices.append(
                GenericID("mm_device_id", str(uuid.uuid4())).to_proto()
            )

            edge = TEdge()
            edge.Vertex1 = hub_index
            edge.Vertex2 = leaf_index
            edge.LogSource = log_source_pb2.OAUTH_LOG
            edge.SourceType = source_type_pb2.APP_PASSPORT_AUTH
            graph.Edges.append(edge)

    state.Graph.CopyFrom(graph)
    return state


def generate_data_for_cid_update(shards_count):
    """Generate one 3-vertex Michurin state and one CID_UPDATE event per shard.

    Shard ``i`` uses CryptaId ``100 + 1000 * i``. Every vertex of its state is
    expected under that CryptaId.

    Returns:
        ``(shard_data, expected_cid_map, michurin_states, expected_offsets)``,
        where ``michurin_states`` maps CryptaId to the serialized state and
        ``expected_offsets`` is ``3 * shards_count``.
    """
    vertices_per_state = 3
    event_timestamp = 100
    shard_data = collections.defaultdict(list)
    expected_cid_map = collections.defaultdict(set)
    michurin_states = {}

    for shard_id in range(shards_count):
        crypta_id = 100 + 1000 * shard_id
        state = generate_michurin_state(crypta_id, vertices_per_state)
        michurin_states[crypta_id] = state.SerializeToString()

        expected_cid_map[crypta_id].update(
            GenericID(proto=vertex) for vertex in state.Graph.Vertices
        )
        shard_data[shard_id].append(
            generate_bookkeeping_event(
                TMichurinBookkeepingEvent.CID_UPDATE, crypta_id, event_timestamp
            )
        )

    return (
        shard_data,
        expected_cid_map,
        michurin_states,
        vertices_per_state * shards_count,
    )


def generate_data_for_tombstone_delete(shards_count, merged_state_TTL):
    """Generate Michurin states and TOMBSTONE_DELETE events, three per shard.

    For shard ``i`` the CryptaIds are ``100 + 10 * i + {0, 1, 2}``:

    * ``+0``: empty, merged away, touched before the TTL. Should be deleted.
    * ``+1``: empty, merged away, touched now. Too young, should be kept.
    * ``+2``: non-empty, not merged, touched before the TTL. Should be kept.

    Returns:
        ``(shard_data, michurin_states, graphs_number,
        graphs_to_delete_number)``, where ``michurin_states`` maps CryptaId
        to the serialized state.
    """
    shard_data = collections.defaultdict(list)
    michurin_states = {}
    number_of_vertices = 2
    current_timestamp = int(datetime.now().timestamp())

    graphs_number = 0
    graphs_to_delete_number = 0
    for shard_id in range(shards_count):
        expired_at = current_timestamp - merged_state_TTL - 100
        # (cid offset, vertices, components, MergedToCryptaId, TouchedAt, delete?)
        cases = (
            # Old and empty tombstone: should be deleted.
            (0, 0, 0, 12345, expired_at, True),
            # Empty but young: should be kept.
            (1, 0, 0, 12345, current_timestamp, False),
            # Old but non-empty: should be kept.
            (2, number_of_vertices, 1, 0, expired_at, False),
        )
        for offset, vertices, components, merged_to, touched_at, doomed in cases:
            crypta_id = 100 + shard_id * 10 + offset
            state = generate_michurin_state(
                crypta_id,
                vertices,
                number_of_components=components,
                MergedToCryptaId=merged_to,
                TouchedAt=touched_at,
            )
            michurin_states[crypta_id] = state.SerializeToString()
            graphs_number += 1
            if doomed:
                graphs_to_delete_number += 1
            shard_data[shard_id].append(
                generate_bookkeeping_event(
                    TMichurinBookkeepingEvent.TOMBSTONE_DELETE,
                    crypta_id,
                    current_timestamp,
                )
            )

    return shard_data, michurin_states, graphs_number, graphs_to_delete_number


def generate_data_for_split(shards_count):
    """Generate Michurin states and SPLIT events, two states per shard.

    For shard ``i``:

    * CryptaId ``100 + 10 * i`` holds two disconnected 2-vertex components and
      should be split into two graphs.
    * CryptaId ``100 + 10 * i + 1`` holds a single 2-vertex component and
      should not be split.

    NOTE(review): ``generate_michurin_state`` gives component ``k`` the puid
    ``cid + k``. The second component of the first state therefore has the
    same puid value as the unsplit graph. Confirm this overlap is intended.

    Returns:
        ``(shard_data, michurin_states, graphs_number, expected_graphs_number,
        expected_vertices_number)``.
    """
    shard_data = collections.defaultdict(list)
    michurin_states = {}
    vertices_per_component = 2
    current_timestamp = int(datetime.now().timestamp())
    graphs_number = 0
    expected_graphs_number = 0
    expected_vertices_number = 0

    # (cid offset, component count): two components -> split, one -> kept whole.
    layouts = ((0, 2), (1, 1))

    for shard_id in range(shards_count):
        for offset, components in layouts:
            crypta_id = 100 + shard_id * 10 + offset
            state = generate_michurin_state(
                crypta_id,
                vertices_per_component,
                number_of_components=components,
            )
            michurin_states[crypta_id] = state.SerializeToString()

            graphs_number += 1
            expected_graphs_number += components
            expected_vertices_number += vertices_per_component * components

            shard_data[shard_id].append(
                generate_bookkeeping_event(
                    TMichurinBookkeepingEvent.SPLIT, crypta_id, current_timestamp
                )
            )

    return (
        shard_data,
        michurin_states,
        graphs_number,
        expected_graphs_number,
        expected_vertices_number,
    )
