import datetime
import logging
import random
import re
import sys
import time

from collections import Counter
from functools import partial

import yt.wrapper as yt
from cached_property import cached_property

from crypta.lib.python import time_utils
from crypta.lib.python.bt.tasks import YtTask
from crypta.lib.python.bt.workflow import IndependentTask, Parameter
from crypta.lib.python.lb_pusher.storage import LBPusher, Batch
from crypta.graph.soup.config.proto.bigb_pb2 import TLinksBuffer
from ads.bsyeti.libs.events.proto import cleaner_pb2

from crypta.graph.bochka.lib import packers
from crypta.graph.bochka.lib.collector_mapper import CollectorMapper
import crypta.lib.python.bt.conf.conf as conf

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Binary byte-size units (kMB is used for job memory limits).
kMB = 1024 * 1024
kGB = 1024 * kMB


class YtPushTask(YtTask, IndependentTask):

    """Upload yt table to LB and/or the collector.

    Each destination is enabled by setting the matching packer on a subclass:

    * ``msg_packer(ts, row)``: rows are pushed to Logbroker with ``LBPusher``.
      The per-status counters it writes are checked afterwards, and the task
      fails when more than 1% of the records errored.
    * ``collector_msg_packer(ts, row)``: rows are mapped with
      ``CollectorMapper`` into a new table under ``Collector.FreshDir``.

    Both packers receive the same ``ts``, taken at the start of ``run``.
    """

    source_path = Parameter()
    # Batch implementation handed to LBPusher.
    batch_class = Batch
    collector_msg_packer = None
    msg_packer = None

    def run(self, **kwargs):
        """Push ``self.source`` to every destination whose packer is set.

        Raises:
            Exception: if more than 1% of the records sent to Logbroker failed.
        """
        super(YtPushTask, self).run(**kwargs)

        ts = time_utils.get_current_time()

        if self.msg_packer is not None:
            # ``internal`` is scratch space for prepare_source; ``lb_out``
            # receives the status rows produced by LBPusher.
            with self.yt.TempTable() as lb_out, self.yt.TempTable() as internal:
                self._push_to_lb(source=self.source, internal=internal, destination=lb_out, ts=ts)
                self._check_lb_status(lb_out)

        if self.collector_msg_packer is not None:
            self._push_to_collector(self.source, ts)

    def _check_lb_status(self, lb_out):
        """Sum LBPusher status rows by ``status`` and enforce the 1% error budget.

        Raises:
            Exception: if ``ok`` records are fewer than 99% of ``ok + errors``.
        """
        status_counter = Counter()
        for lb_record in self.yt.read_table(lb_out):
            status_counter[lb_record.get("status", "unknown")] += lb_record.get("count", 0)

        oks = status_counter.get("ok", 0)
        errs = status_counter.get("errors", 0)
        skipped = status_counter.get("skipped", 0)
        taked = status_counter.get("taked", 0)

        logger.info("ok=%d\terr=%d\ttake=%d\tskip=%d", oks, errs, taked, skipped)
        # Allow at most 1% failed records; skipped records are not counted.
        # With nothing sent (oks + errs == 0) the check passes.
        if oks < (oks + errs) * 0.99:
            raise Exception("Failed to upload some records: {} vs {}".format(errs, oks + errs))
        logger.info("Succesfully pushed %s records of %s", oks, oks + errs)

    def _push_to_lb(self, source, internal, destination, ts):
        """Run an LBPusher map over ``source`` after passing it through ``prepare_source``.

        Args:
            source: input table path(s).
            internal: temp table that ``prepare_source`` may use.
            destination: table receiving LBPusher's status rows.
            ts: timestamp forwarded to ``msg_packer``.

        Raises:
            ValueError: if partitions are configured but every one of them is locked.
        """
        total_partitions = conf.proto.LogbrokerPusher.TotalPartitions
        locked_partitions = set(conf.proto.LogbrokerPusher.LockedPartitions)
        # Config indexes partitions from 0. Compute the unlocked ones once instead
        # of rejection-sampling on every call. Rejection sampling never
        # terminated when every partition was locked.
        free_partitions = [pg for pg in range(total_partitions) if pg not in locked_partitions]
        if total_partitions and not free_partitions:
            raise ValueError(
                "All {} logbroker partitions are locked: {}".format(total_partitions, sorted(locked_partitions))
            )

        def select_pg(candidates):
            if not candidates:
                # No partitions configured: let LBPusher use its default.
                return None
            # Logbroker partition groups are numbered from 1, config indexes from 0.
            return random.choice(candidates) + 1

        pg_func = partial(select_pg, free_partitions)

        source = self.prepare_source(source, internal)

        self.map(
            mapper=LBPusher(
                logbroker_host=conf.proto.Logbroker.Url,
                logbroker_port=conf.proto.Logbroker.Port,
                logbroker_topic=conf.proto.LbTopic.TopicName,
                tvm_id=int(conf.proto.Tvm.TvmId),
                max_inflight=conf.proto.LogbrokerPusher.MaxInflight,
                batch_size=conf.proto.LogbrokerPusher.BatchSize,
                batch_byte_size=conf.proto.LogbrokerPusher.BatchByteSize,
                msg_packer=lambda row: self.msg_packer(ts, row),
                batch_class=self.batch_class,
                partition_group=pg_func,
            ),
            source=source,
            destination=destination,
            spec=dict(
                data_size_per_job=conf.proto.LbTopic.JobSize,
                resource_limits=dict(user_slots=conf.proto.LbTopic.MaxJobs),
                secure_vault=dict(CRYPTA_LBPUSH_TVM_SECRET=conf.proto.Tvm.TvmSecret),
                max_failed_job_count=conf.proto.LbTopic.MaxFailedJobs,
                mapper=dict(cpu_limit=0.2),
            ),
            memory_limit=512 * kMB,
        )

    def _push_to_collector(self, source, ts):
        """Map ``source`` with CollectorMapper into a new table in ``Collector.FreshDir``.

        The destination table name uses the time of this call, not ``ts``.
        """
        destination = yt.ypath_join(
            conf.proto.Collector.FreshDir,
            "{}_{}".format(conf.proto.Collector.TablePrefix, time_utils.get_current_time()),
        )
        self.map(
            mapper=CollectorMapper(msg_packer=lambda row: self.collector_msg_packer(ts, row)),
            source=source,
            destination=destination,
        )

    @cached_property
    def source(self):
        """Input of the push; by default the ``source_path`` parameter as is."""
        return self.source_path

    def prepare_source(self, source, internal):
        """Hook to preprocess ``source`` (optionally into ``internal``) before the LB push."""
        return source


class Yt2BBPushTask(YtPushTask):

    """Upload yt table to BB through the collector.

    Rows are packed with ``packers.tskv_to_serialized_proto``. ``msg_packer`` is
    inherited as None, so no Logbroker push takes place.
    """

    # staticmethod keeps the packer unbound: it is called as packer(ts, row).
    collector_msg_packer = staticmethod(packers.tskv_to_serialized_proto)


class Yt2BBPushValueTask(YtPushTask):

    """Upload yt table to BB through the collector, VALUE flavour.

    Rows are packed with ``packers.value_to_serialized_proto``. ``msg_packer`` is
    inherited as None, so no Logbroker push takes place.
    """

    # staticmethod keeps the packer unbound: it is called as packer(ts, row).
    collector_msg_packer = staticmethod(packers.value_to_serialized_proto)


class BaseYt2LBPushInspectTask(YtPushTask):

    """Upload tables that were not uploaded yet, then mark them as uploaded.

    ``source_path`` is a comma-separated list. Each item may be:

    * a table, which is selected if unmarked;
    * a map node, whose direct non-empty, unmarked child tables are selected;
    * a path ending in ``/{dt}``, whose date-named child directories are inspected
      like map nodes.

    ``force_upload`` ignores the mark. After a successful push, every selected
    table gets ``kMARK`` set to the current unix time.
    """

    force_upload = Parameter(parse=lambda value: value.lower() in {"+", "yes", "true", "t", "1"}, default="False")
    # Maximum number of tables handled per run; 0 means no limit.
    semaphore = Parameter(parse=int, default=0)
    # A YYYY-MM-DD date anywhere in a node path.
    kDTRE = re.compile(r"\d{4}-\d{2}-\d{2}")
    # Attribute stamped on uploaded tables; subclasses override it per destination.
    kMARK = "logbroker_timestamp"
    # Source-path suffix that means "inspect every dated subdirectory".
    kDATE_DIR_SUFFIX = "/{dt}"

    def run(self, *args, **kwargs):
        """Push pending tables and mark them; a no-op when nothing is pending."""
        if not self.source:
            logger.info("No tables to process")
            return
        super(BaseYt2LBPushInspectTask, self).run(*args, **kwargs)
        # Reached only when the push did not raise.
        self._set_mark()

    @cached_property
    def source(self):
        """Allow multi source path: resolve each comma-separated item to pending tables."""
        tables = [
            table
            for item in self.source_path.split(",")
            for table in self._inspect_path(item.strip())
        ]
        if self.semaphore:
            return tables[: self.semaphore]
        return tables

    def _inspect_date_dir(self, path):
        """Yield pending tables from every date-named child directory of ``path``."""
        # Cut the literal suffix. str.rstrip("/{dt}") would strip any trailing run
        # of the characters '/', '{', 'd', 't', '}' and turn ".../cooked/{dt}"
        # into ".../cooke".
        if path.endswith(self.kDATE_DIR_SUFFIX):
            parent = path[: -len(self.kDATE_DIR_SUFFIX)]
        else:
            parent = path
        table_iterator = self.yt.search(parent, node_type=["map_node"], depth_bound=1, follow_links=False)
        for node in table_iterator:
            table_path = str(node)
            if self.kDTRE.search(table_path):
                for sub_path in self._inspect_dir(table_path):
                    yield sub_path

    def _inspect_dir(self, path):
        """Yield direct child tables of ``path`` that are non-empty and unmarked (or forced)."""
        table_iterator = self.yt.search(
            path, node_type=["table"], attributes=["row_count", self.kMARK], depth_bound=1, follow_links=False
        )
        for node in table_iterator:
            if (node.attributes.get("row_count", 0) > 0) and (
                self.force_upload or (not node.attributes.get(self.kMARK, False))
            ):
                yield str(node)

    def _inspect_path(self, path):
        """Yield the pending tables for one ``source_path`` item (see the class docstring)."""
        if not path:
            return
        if path.endswith(self.kDATE_DIR_SUFFIX):
            for sub_path in self._inspect_date_dir(path):
                yield sub_path
        elif (self.yt.get("{}/@type".format(path)) == "table") and (
            self.force_upload or (not self.yt.has_attribute(path, self.kMARK))
        ):
            # for table just upload
            yield path
        else:
            # Otherwise, inspect the map node for child tables. A table that is
            # already marked also lands here and yields nothing.
            for sub_path in self._inspect_dir(path):
                yield sub_path

    def _set_mark(self):
        """Best-effort: set ``kMARK`` to the current unix time on every processed table."""
        timestamp = int(time.time())
        for table in self.source:
            try:
                self.yt.set("{path}/@{mark}".format(path=table, mark=self.kMARK), timestamp)
            except Exception:
                # The table may have been removed meanwhile. Do not fail the task,
                # but leave a trace. If the table still exists, it stays unmarked
                # and is selected again on the next run.
                logger.warning("Failed to set %s on %s", self.kMARK, table, exc_info=True)


class YtSoup2LBPushTask(BaseYt2LBPushInspectTask):

    """Push not-yet-uploaded soup tables to Logbroker.

    Each row is packed with ``packers.row_to_soup_edge``.
    """

    # staticmethod: invoked as self.msg_packer(ts, row) without binding self.
    msg_packer = staticmethod(packers.row_to_soup_edge)


class YtIdsStorage2LBPushTask(YtSoup2LBPushTask):

    """Push idstorage tables to Logbroker as eternal-info messages.

    With a positive ``active_days``, rows whose ``date_end`` lies before
    ``minimal_day`` are dropped.
    """

    active_days = Parameter(parse=int, default=0)

    def msg_packer(self, ts, row):
        """Return the packed row, or None for a stale record when filtering is on."""
        # ``and`` short-circuits: with active_days unset neither date_end nor
        # minimal_day is looked at.
        is_stale = bool(self.active_days) and row["date_end"] < self.minimal_day
        if not is_stale:
            return packers.row_to_eternal_info(ts, row)
        return None

    @cached_property
    def minimal_day(self):
        """Oldest still-active day as a YYYY-MM-DD string (local clock)."""
        cutoff = datetime.datetime.now() - datetime.timedelta(days=self.active_days)
        return cutoff.strftime("%Y-%m-%d")


class TLinksBufferBatch(Batch):

    """Batch that packs the accumulated links into one serialized ``TLinksBuffer``."""

    def add(self, what):
        """Append one link message; ``None`` (a row skipped by the packer) is rejected.

        Returns:
            True if ``what`` was added, False otherwise.
        """
        if what is None:
            return False
        self.chunk.append(what)
        # Track the serialized payload size. The batch is emitted via
        # SerializeToString(), and sys.getsizeof() only measures the shallow
        # Python wrapper object.
        self.byte_size += what.ByteSize()
        return True

    def get(self, with_reset=True):
        """Return the accumulated links serialized as a single ``TLinksBuffer``.

        Args:
            with_reset: clear the batch after building the buffer.
        """
        buff = TLinksBuffer()
        # extend() copies each message, so clearing the chunk below is safe.
        buff.Links.extend(self.chunk)

        if with_reset:
            self.clear()
        return buff.SerializeToString()


class YtSoup2VultureTask(YtSoup2LBPushTask):

    """Upload fresh soup to vulture immediately"""

    # batch_class = TLinksBufferBatch
    # NOTE(review): ``eval`` executes arbitrary Python taken from the parameter
    # value. It presumably accepts inputs like "None" or "0.1". Consider a
    # restricted parser such as ast.literal_eval if the value can be
    # user-controlled.
    flow = Parameter(parse=eval, default="None")
    vulture_prod = Parameter(parse=lambda value: value.lower() in {"+", "yes", "true", "t", "1"}, default="False")
    vulture_exp = Parameter(parse=lambda value: value.lower() in {"+", "yes", "true", "t", "1"}, default="False")
    kMARK = "vulture_timestamp"

    def msg_packer(self, ts, row):
        """Pack ``row`` for vulture, randomly dropping rows when ``flow`` is set.

        A row is skipped (None returned) when ``random.random() > flow``, so
        ``flow`` presumably is the kept fraction in [0, 1] (TODO confirm).
        """
        if self.flow is not None and random.random() > self.flow:
            return
        return packers.row_to_vulture(ts, row, self.vulture_prod, self.vulture_exp)

    def prepare_source(self, source, internal):
        """Push ``source`` as is; ``internal`` is unused.

        A filter/deduplicate map_reduce used to sit here after an early
        ``return`` and was unreachable, so it was removed. If deduplication is
        reintroduced, note that filtering through ``msg_packer`` applies the
        random ``flow`` sampling, and the push applies it a second time.
        """
        return source


class YtVavilov2LBPushTask(YtPushTask):

    """Push a vavilov table to Logbroker, tagging messages with its ``generate_date``."""

    def run(self, **kwargs):
        """Capture the source's ``generate_date`` attribute, then do the regular push."""
        generate_date_attr = "{}/@generate_date".format(self.source)
        self.date = self.yt.get(generate_date_attr)
        super(YtVavilov2LBPushTask, self).run(**kwargs)

    def msg_packer(self, ts, row):
        """Pack ``row`` using the date captured in ``run``."""
        return packers.row_to_vavilov(ts, self.date, row)


class Yt2BBProfileCleanerTask(BaseYt2LBPushInspectTask):

    """Upload styx table to bb.

    Rows are packed with ``packers.row_to_profile_cleaner`` and the configured
    ``reason``, then pushed through the collector only (``msg_packer`` stays
    None). Processed tables are marked with ``kMARK``.
    """

    # Parsed from the name of a TProfileCleanerConfig.EReason enum value.
    reason = Parameter(parse=lambda value: cleaner_pb2.TProfileCleanerConfig.EReason.Value(name=value))

    kMARK = "processed_timestamp"

    def run(self, **kwargs):
        """Bind ``reason`` into the collector packer, then run the inspect-and-push flow."""
        # Set on the instance because the packer depends on the parsed parameter.
        self.collector_msg_packer = partial(packers.row_to_profile_cleaner, reason=self.reason)
        super(Yt2BBProfileCleanerTask, self).run(**kwargs)
