import os
import sys
from threading import Lock
import time

from kikimr.yndx.api.protos.persqueue_pb2 import WriteResponse
from rtmapreduce.protos.data_pb2 import TBulk
import six
import yt.wrapper as yt

from crypta.lib.python.lb_pusher.logbroker import PQClient


# One mebibyte: 1024 * 1024 == 2 ** 20 bytes.
kMB = 1024 * 1024


class Batch(object):
    """Accumulate messages as bytes and join them into one payload.

    The batch reports itself as ``full`` once it holds at least
    ``batch_size`` items, or once the summed length of the stored items
    reaches ``batch_byte_size`` bytes. ``get`` joins the stored items with
    ``splitter``.
    """

    def __init__(self, batch_size=1000, batch_byte_size=1 * kMB, splitter=six.b("\n")):
        # Separator inserted between items when the payload is built.
        self.splitter = six.ensure_binary(splitter)
        # Limits at which the batch reports itself as full.
        self.capacity, self.max_byte_size = batch_size, batch_byte_size
        # Stored items (already bytes) and the running total of their lengths.
        self.chunk, self.byte_size = [], 0

    @property
    def size(self):
        """Number of items currently stored."""
        return len(self.chunk)

    @property
    def full(self):
        """True once either the item-count or the byte-size limit is reached."""
        if self.size >= self.capacity:
            return True
        return self.byte_size >= self.max_byte_size

    def add(self, what):
        """Store ``what`` converted to bytes.

        Returns True when the item was stored. Returns False (and stores
        nothing) when ``what`` is None, so the caller can count it as skipped.
        """
        if what is not None:
            encoded = six.ensure_binary(what)
            self.chunk.append(encoded)
            self.byte_size += len(encoded)
            return True
        return False

    def clear(self):
        """Drop all stored items and reset the byte counter."""
        self.chunk, self.byte_size = [], 0

    def get(self, with_reset=True):
        """Return the stored items joined by ``splitter``.

        When ``with_reset`` is true (the default), the batch is emptied
        after the payload is built.
        """
        payload = self.splitter.join(self.chunk)
        if not with_reset:
            return payload
        self.clear()
        return payload


class BulkBatch(Batch):
    """Batch that packs accumulated records into an RTMR ``TBulk`` message.

    Each added item must support ``item["key"]``, ``item["subkey"]``,
    ``item["value"]`` and ``item.get("table", ...)`` (a dict, presumably
    produced by the ``msg_packer`` used with this class -- confirm against
    callers). Each item becomes one ``TBulk`` chunk with a single entry.

    ``source_id`` is written to ``TBulk.SourceId``. ``table_name`` is the
    fallback table for items without a ``"table"`` key. Both default to None
    here and are presumably set by subclasses.
    """

    source_id = None
    table_name = None

    @staticmethod
    def _estimate_size(item):
        """Approximate the in-memory size of ``item`` including its direct values."""
        # sys.getsizeof is shallow: for a dict it does not count the objects
        # it references, so add the values in to keep batch_byte_size meaningful.
        size = sys.getsizeof(item)
        values = getattr(item, "values", None)
        if callable(values):
            size += sum(sys.getsizeof(value) for value in values())
        return size

    def add(self, what):
        """Store one record; return False (storing nothing) for None.

        This matches ``Batch.add``. A packer may return None to skip a record
        instead of storing an item that would make ``get`` fail.
        """
        if what is None:
            return False
        self.chunk.append(what)
        self.byte_size += self._estimate_size(what)
        return True

    def get(self, with_reset=True):
        """Serialize the stored records into a ``TBulk`` and return its bytes.

        When ``with_reset`` is true (the default), the batch is emptied
        after serialization.
        """
        bulk = TBulk()
        bulk.SourceId = self.source_id

        for item in self.chunk:
            chunk = bulk.Chunks.add()
            chunk.Key = item["key"]
            # Per-item table overrides the class-level default.
            chunk.Table = item.get("table", self.table_name)
            entry = chunk.Entries.add()
            entry.SubKey = item["subkey"]
            entry.Value = item["value"]

        if with_reset:
            self.clear()
        return bulk.SerializeToString()


@yt.aggregator
class LBPusher(object):
    """Aggregator that pushes all records of a job to a Logbroker topic.

    Records are converted with ``msg_packer`` and accumulated in
    ``batch_class`` batches. Each full (or final partial) batch is written
    asynchronously, with at most ``max_inflight`` writes outstanding at once.
    After everything is pushed and all writes have completed, four summary
    rows are yielded:

    - ``ok``: writes acknowledged.
    - ``error``: writes that failed; both are counted per batch.
    - ``skipped``: records whose batch ``add`` returned False.
    - ``taked``: records accepted into a batch.
    """

    def __init__(
        self,
        logbroker_host,
        logbroker_port,
        logbroker_topic,
        tvm_id,
        max_inflight=100,
        batch_size=1000,
        batch_byte_size=1 * kMB,
        msg_packer=lambda msg: msg,
        batch_class=Batch,
        partition_group=None,
    ):
        self.logbroker_host = logbroker_host
        self.logbroker_port = logbroker_port
        self.logbroker_topic = logbroker_topic
        self.tvm_id = tvm_id
        self.max_inflight = max_inflight
        self.batch_size = batch_size
        self.batch_byte_size = batch_byte_size
        # Number of writes submitted but not yet completed.
        self.inflight = 0
        # Guards inflight/oks/errors, which completion callbacks also update.
        self.lock = Lock()
        self.oks = 0
        self.errors = 0
        self.taked = 0
        self.skipped = 0
        self.msg_packer = msg_packer
        self.batch_class = batch_class
        # Either a partition group value or a zero-argument callable producing one.
        self.partition_group = partition_group

    def future_callback(self, future):
        """Account for one completed write.

        Counts the write as ok when its result is a ``WriteResponse`` with an
        ``ack`` field. Otherwise, including when ``future.result()`` raises,
        it is counted as an error. The callback may run on a thread other
        than the one in ``__call__`` (presumably the writer's -- confirm), so
        all counter updates happen under ``self.lock``. ``inflight`` is
        decremented last, so once ``__call__`` observes zero in-flight writes,
        the ok/error counters already include every result.
        """
        try:
            result = future.result()
        except Exception:
            # A failed write must still be accounted for, not lost inside
            # the future's callback machinery.
            acked = False
        else:
            acked = isinstance(result, WriteResponse) and result.HasField("ack")
        with self.lock:
            if acked:
                self.oks += 1
            else:
                self.errors += 1
            self.inflight -= 1

    def __call__(self, records):
        """Push ``records`` to Logbroker and yield per-status summary rows.

        :param records: iterable of input rows; each is passed through
            ``msg_packer`` before being added to the current batch.
        :yields: dicts ``{"status": ..., "count": ...}`` for ``ok``,
            ``error``, ``skipped`` and ``taked``, in that order.
        """
        def send(batch):
            # Serialize (and reset) the batch, then track the write as
            # in flight until future_callback sees it complete.
            write_future = writer.write(batch.get())
            with self.lock:
                self.inflight += 1
            write_future.add_done_callback(self.future_callback)

        tvm_secret = os.environ.get("YT_SECURE_VAULT_CRYPTA_LBPUSH_TVM_SECRET")
        batch = self.batch_class(batch_size=self.batch_size, batch_byte_size=self.batch_byte_size)

        with PQClient(self.logbroker_host, self.logbroker_port, tvm_id=self.tvm_id, tvm_secret=tvm_secret) as client:
            pg = self.partition_group() if callable(self.partition_group) else self.partition_group
            with client.get_writer(self.logbroker_topic, partition_group=pg) as writer:
                for msg in records:
                    # Back-pressure: wait for a free write slot.
                    while self.inflight >= self.max_inflight:
                        time.sleep(0.05)

                    if batch.add(self.msg_packer(msg)):
                        self.taked += 1
                    else:
                        self.skipped += 1

                    if batch.full:
                        send(batch)

                # Flush the final, partially filled batch.
                if batch.size:
                    send(batch)

                # Wait until every outstanding write has completed.
                while self.inflight > 0:
                    time.sleep(0.05)

        # Read callback-updated counters under the lock for a consistent snapshot.
        with self.lock:
            oks, errors = self.oks, self.errors

        yield dict(status="ok", count=oks)
        yield dict(status="error", count=errors)
        yield dict(status="skipped", count=self.skipped)
        yield dict(status="taked", count=self.taked)
