import contextlib
import datetime
import functools
import logging
import random

import yt.yson as yson

import crypta.lib.python.bt.conf.conf as conf  # noqa

from crypta.lib.python.bt.tasks import TransactionalYtTask
from crypta.lib.python.bt.workflow import IndependentTask, Parameter
from crypta.lib.python.identifiers.generic_id import GenericID

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class CreateCryptaIdTableTask(IndependentTask, TransactionalYtTask):

    """Create crypta id table with active identifiers.

    The table at ``dst`` is (re)created as a sorted static table and filled
    from ``src`` by a map -> sort -> merge pipeline.  Once the surrounding
    run context has exited, the table is altered to a dynamic table,
    configured and mounted (see ``run_context``).
    """

    # Source table with the identifier rows; the mapper reads the
    # ``id_type``, ``id`` and ``cryptaId`` columns of every row.
    src = Parameter(default="//home/crypta/production/state/graph/v2/export/ActiveIdentifiers")
    # Destination table path; any existing node there is replaced (force=True).
    dst = Parameter(default="//home/crypta/testing/rtsklejka/state/crypta_id")
    # A row is kept when ``random.random() <= sampling``.
    # SECURITY NOTE(review): the value is parsed with ``eval``, so whatever is
    # supplied for this parameter is executed as Python code.  Presumably only
    # trusted operators can set it -- confirm, or switch to ``float``.
    sampling = Parameter(parse=eval, default="1.0")

    def run(self, *args, **kwargs):
        """Validate parameters and build the table.

        Raises:
            ValueError: if ``src`` or ``dst`` is empty.
        """
        if not self.src or not self.dst:
            raise ValueError("Parameters required {src!r} -> {dst!r}".format(src=self.src, dst=self.dst))
        self._create_table(self.src, self.dst)

    @contextlib.contextmanager
    def run_context(self):
        """Wrap the parent run context and finalize the table afterwards.

        The statements after the ``with`` block execute only once the parent
        context has exited without an exception propagating.
        """
        with super(CreateCryptaIdTableTask, self).run_context() as ctx:
            yield ctx
        # make dynamic after transaction
        self.yt.alter_table(self.dst, dynamic=True)
        self.yt.set("{path}/@in_memory_mode".format(path=self.dst), "uncompressed")
        self.yt.set("{path}/@primary_medium".format(path=self.dst), "ssd_blobs")
        self.yt.set("{path}/@enable_lookup_hash_table".format(path=self.dst), True)
        self.yt.mount_table(self.dst, sync=True)

    def _create_table(self, src, dst):
        """Create ``dst`` with a sorted schema and fill it from ``src``.

        Args:
            src: path of the table to read identifier rows from.
            dst: path of the table to create and fill.
        """
        self.yt.create(
            "table",
            dst,
            attributes=dict(
                schema=yson.to_yson_type(
                    [
                        dict(name="Hash", type="uint64", expression="farm_hash(Id)", sort_order="ascending"),
                        dict(name="Id", type="string", required=True, sort_order="ascending"),
                        dict(name="CryptaId", type="string", required=True),
                    ],
                    attributes=dict(unique_keys=True),
                ),
                optimize_for="lookup",
            ),
            recursive=True,
            force=True,
        )

        # Same columns as the destination, but without sort order: the mapper
        # output is unordered and is sorted in a separate step below.
        tmp_unordered_schema = [
            dict(name="Hash", type="uint64", expression="farm_hash(Id)"),
            dict(name="Id", type="string", required=True),
            dict(name="CryptaId", type="string", required=True),
        ]

        # NOTE(review): ``or 1.0`` maps every falsy value -- including an
        # explicit 0 -- to 1.0, so sampling cannot be turned down to zero.
        proto_pack_fun = functools.partial(self._proto_pack, sampling=self.sampling or 1.0)
        # One writer config (256 KiB blocks, 100 MiB desired chunks) shared by
        # every job type of the operations below.
        table_writer = {"table_writer": {"block_size": 256 * 2 ** 10, "desired_chunk_size": 100 * 2 ** 20}}
        spec = {
            "job_io": table_writer,
            "map_job_io": table_writer,
            "merge_job_io": table_writer,
            "sort_job_io": table_writer,
        }

        with self.yt.TempTable(attributes={"schema": tmp_unordered_schema}) as tmp_crypta_id:
            self.yt.run_map(proto_pack_fun, src, tmp_crypta_id, spec=spec)
            self.yt.run_sort(tmp_crypta_id, tmp_crypta_id, sort_by=["Hash", "Id"], spec=spec)
            self.yt.run_merge(tmp_crypta_id, dst, spec=spec)

    @staticmethod
    def _proto_pack(row, sampling):
        """Mapper: yield the serialized (Id, CryptaId) pair for a sampled row.

        Args:
            row: input row with ``id_type``, ``id`` and ``cryptaId`` keys.
            sampling: the row is emitted when ``random.random() <= sampling``.

        Yields:
            A dict with ``Id`` and ``CryptaId`` holding serialized GenericID
            values, or nothing when the row is not sampled.
        """
        # Decide first, so rows that are dropped are never serialized.
        if random.random() <= sampling:
            yield {
                "Id": GenericID(row["id_type"], row["id"]).serialize(),
                "CryptaId": GenericID("cryptaid", str(row["cryptaId"])).serialize(),
            }

    @staticmethod
    def _unique_reduce(key, rows):
        """Reducer: pass rows through and fail if a key has not exactly one row.

        Not referenced by the operations in this class.

        Raises:
            RuntimeError: after the rows of a key were yielded, if their
                count differs from 1.
        """
        counter = 0
        for row in rows:
            counter += 1
            yield row
        if counter != 1:
            raise RuntimeError("Unique keys fail")


class UpdateCryptaIdTableTask(IndependentTask, TransactionalYtTask):

    """Insert new values into table.

    ``run`` only validates parameters.  The actual work happens in
    ``run_context`` after the parent context has exited: the table is
    mounted, its ``max_data_ttl`` is set, and a reduce operation over ``src``
    inserts rows into ``dst`` directly from the jobs.
    """

    # Source table with the identifier rows; the reducer reads the
    # ``id_type``, ``id`` and ``cryptaId`` columns and reduces by ``cryptaId``.
    src = Parameter(default="//home/crypta/production/state/graph/v2/export/ActiveIdentifiers")
    # Table the rows are inserted into.
    dst = Parameter(default="//home/crypta/testing/rtsklejka/state/crypta_id")
    # A row is kept when ``random.random() <= sampling``.
    # SECURITY NOTE(review): the value is parsed with ``eval``, so whatever is
    # supplied for this parameter is executed as Python code.  Presumably only
    # trusted operators can set it -- confirm, or switch to ``float``.
    sampling = Parameter(parse=eval, default="1.0")

    def run(self, *args, **kwargs):
        """Validate parameters.

        Raises:
            ValueError: if ``src`` or ``dst`` is empty.
        """
        if not self.src or not self.dst:
            raise ValueError("Parameters required {src!r} -> {dst!r}".format(src=self.src, dst=self.dst))

    @contextlib.contextmanager
    def run_context(self):
        """Wrap the parent run context and perform the update afterwards.

        The statements after the ``with`` block execute only once the parent
        context has exited without an exception propagating.
        """
        with super(UpdateCryptaIdTableTask, self).run_context() as ctx:
            yield ctx
        self.yt.config.update(allow_http_requests_to_yt_from_job=True)
        self.yt.mount_table(self.dst, sync=True)
        # ``timedelta.seconds`` is only the sub-day component (12h here) and
        # silently drops the day; ``total_seconds()`` gives the full 36h.
        # The value is converted to whole milliseconds.
        max_data_ttl = datetime.timedelta(days=1, hours=12)
        self.yt.set(
            "{path}/@max_data_ttl".format(path=self.dst),
            int(max_data_ttl.total_seconds() * 1000),
        )
        # run without transaction
        self._update_table(self.src, self.dst)

    def _update_table(self, src, dst):
        """Run a reduce over ``src`` whose jobs insert rows into ``dst``.

        Args:
            src: path of the table to read identifier rows from.
            dst: path of the table the jobs insert rows into.
        """
        # Separate client switched to the "rpc" backend; it is bound into the
        # reducer and used there for ``insert_rows``.
        rpc_client = self._init_yt()
        rpc_client.config.update(backend="rpc", allow_http_requests_to_yt_from_job=True)
        # NOTE(review): ``or 1.0`` maps every falsy value -- including an
        # explicit 0 -- to 1.0, so sampling cannot be turned down to zero.
        proto_pack_fun = functools.partial(
            self._proto_pack, client=rpc_client, dst=dst, sampling=self.sampling or 1.0
        )

        tmp_unordered_schema = [
            dict(name="Hash", type="uint64", expression="farm_hash(Id)"),
            dict(name="Id", type="string", required=True),
            dict(name="CryptaId", type="string", required=True),
        ]

        # The reducer yields no rows; the temporary table only serves as the
        # operation's output and is discarded when the ``with`` block exits.
        with self.yt.TempTable(attributes={"schema": tmp_unordered_schema}) as tmp_crypta_id:
            self.yt.run_reduce(
                proto_pack_fun,
                src,
                tmp_crypta_id,
                spec=dict(auto_merge=dict(mode="relaxed"), resource_limits=dict(user_slots=50)),
                reduce_by=["cryptaId"],
            )

    @staticmethod
    def _proto_pack(key, rows, client, dst, sampling):
        """Reducer: insert the sampled rows of one key into ``dst`` in batches.

        Args:
            key: reduce key (unused).
            rows: iterable of input rows with ``id_type``, ``id`` and
                ``cryptaId`` keys.
            client: client object providing ``insert_rows(path, rows)``.
            dst: path of the table to insert into.
            sampling: a row is kept when ``random.random() <= sampling``.

        Returns:
            None; rows are written through ``client`` and nothing is emitted
            to the operation output.
        """
        data = []
        for row in rows:
            # Decide first, so rows that are dropped are never serialized.
            if random.random() <= sampling:
                data.append(
                    {
                        "Id": GenericID(row["id_type"], row["id"]).serialize(),
                        "CryptaId": GenericID("cryptaid", str(row["cryptaId"])).serialize(),
                    }
                )

                # Flush once more than 100 rows are buffered.
                if len(data) > 100:
                    client.insert_rows(dst, data)
                    data = []

        # Flush whatever is left over.
        if data:
            client.insert_rows(dst, data)
