import asyncio
import concurrent.futures
import logging

import yt_yson_bindings
from typing import Any, AsyncIterator, Awaitable, List, Optional, Tuple
from yp import data_model
from yp.common import GrpcResourceExhaustedError, YpRowsAlreadyTrimmedError, YpNoSuchObjectError
from yp_proto.yp.client.api.proto import object_service_pb2
from yt.common import YtResponseError

from infra.deploy_notifications_controller.lib.models.event_meta import EventMeta
from infra.deploy_notifications_controller.lib.models.stage import Stage, Meta, Spec, Status
from infra.deploy_notifications_controller.lib.models.stage_history_change import StageHistoryChange, \
    StageHistoryCreate, StageHistoryUpdate, StageHistoryRemove
from infra.deploy_notifications_controller.lib.models.notification_policy import NotificationPolicy


class YpClient:
    """Asynchronous facade over a synchronous YP object-service client.

    Every RPC is executed on a private thread pool (see ``_apply``) so that the
    blocking ``slave`` calls do not stall the asyncio event loop.
    """

    def __init__(
        self,
        slave,
        stage_filter: Optional[str] = None,
        thread_pool_size: int = 10,
    ):
        """
        :param slave: synchronous YP object-service stub; its ``SelectObjects``,
            ``UpdateObject``, ``UpdateObjects``, ``WatchObjects``,
            ``SelectObjectHistory`` and ``GenerateTimestamp`` methods are called
            from worker threads.
        :param stage_filter: optional YP query filter used by ``select_stages``
            when full stage info is requested.
        :param thread_pool_size: number of worker threads used for RPCs.
        """
        self.log = logging.getLogger('yp_client')

        self.slave = slave
        # NOTE(review): not used inside this class; ``_apply`` looks the loop up
        # on every call instead.
        self.loop = asyncio.get_event_loop()
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=thread_pool_size,
            thread_name_prefix='yp-client',
        )
        self.stage_filter = stage_filter

    def _apply(self, fn, *args) -> Awaitable[Any]:
        """Run blocking ``fn(*args)`` on the client's thread pool; return an awaitable for its result."""
        return asyncio.get_event_loop().run_in_executor(self.thread_pool, fn, *args)

    @staticmethod
    def yt_error_not_enough_memory(e):
        """Return True if the message of error ``e`` reports a failed memory allocation."""
        return 'Not enough memory to serve allocation' in e.message

    @staticmethod
    def decrease_limit_and_continue(e, req):
        """Decide whether a failed paged request should be retried with a halved limit.

        A retry makes sense only for "too much data" failures (a
        ``GrpcResourceExhaustedError`` or an error tree containing a YT
        "not enough memory" error) and only while ``req.options.limit`` can
        still be halved to a positive value.

        :param e: the exception raised by the request.
        :param req: the request whose ``options.limit`` the caller would halve.
        :return: True if the caller should halve the limit and retry.
        """
        if req.options.limit < 2:
            return False

        if isinstance(e, GrpcResourceExhaustedError):
            return True

        # Walk the (possibly nested) error tree only when the cheap checks above did not decide.
        return e.find_matching_error(predicate=YpClient.yt_error_not_enough_memory) is not None

    async def _select_objects_values(
        self,
        object_type: int,
        batch_size: int,
        selectors: Optional[List[str]] = None,
        filter: Optional[str] = None,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[List[Any]]:
        """Page through ``SelectObjects`` results and yield each object's value payloads.

        Pages of ``batch_size`` objects are requested using the continuation
        token. If a page fails with a "too much data" error (see
        ``decrease_limit_and_continue``), the page limit is halved and the same
        page is requested again; the reduced limit is kept for later pages.

        :param object_type: YP object type (``data_model.OT_*``).
        :param batch_size: initial page size.
        :param selectors: attribute paths to fetch; defaults to the whole object (``''``).
        :param filter: optional YP query filter.
        :param timestamp: optional read timestamp.
        :return: async iterator over ``value_payloads`` (a list-like protobuf
            repeated field, one payload per selector, in selector order).
        """
        req = object_service_pb2.TReqSelectObjects()
        req.object_type = object_type
        if timestamp is not None:
            req.timestamp = timestamp
        if filter is not None:
            req.filter.query = filter
        selectors = selectors or ['']
        req.selector.paths.extend(selectors)
        req.format = object_service_pb2.PF_YSON
        req.options.limit = batch_size

        while True:
            try:
                resp = await self._apply(self.slave.SelectObjects, req)
            except (GrpcResourceExhaustedError, YtResponseError) as e:
                if YpClient.decrease_limit_and_continue(e, req):
                    req.options.limit //= 2
                    continue

                raise

            read = 0

            for r in resp.results:
                read += 1
                yield r.value_payloads

            # A short page means there is nothing more to read.
            if read < req.options.limit:
                break
            else:
                req.options.continuation_token = resp.continuation_token

    async def generate_timestamp(self) -> int:
        """Ask YP for a fresh timestamp and return it."""
        req = object_service_pb2.TReqGenerateTimestamp()
        resp = await self._apply(self.slave.GenerateTimestamp, req)
        return resp.timestamp

    @staticmethod
    def stage_from_annotation(stage_yson) -> Optional[Stage]:
        """Build a ``Stage`` from the ``notifications_last_state`` annotation.

        :param stage_yson: value payloads for the selectors ``/meta/id``,
            ``/meta/uuid`` and ``/annotations/notifications_last_state``, in that order.
        :return: the stage, or None when the annotation is empty or has no ``meta``.
        """
        stage_id = yt_yson_bindings.loads(stage_yson[0].yson)
        stage_uuid = yt_yson_bindings.loads(stage_yson[1].yson)
        data = yt_yson_bindings.loads(stage_yson[2].yson)
        if data and data.get('meta'):
            return Stage(
                meta=Meta(values=data['meta'], id=stage_id, uuid=stage_uuid),
                spec=Spec(values=data['spec']),
                status=Status(values=data['status'])
            )
        else:
            return None

    @staticmethod
    def stage_full_info(stage_yson) -> Stage:
        """Build a ``Stage`` from full-object payloads.

        :param stage_yson: value payloads for ``/meta``, ``/spec``, ``/status``,
            ``/labels/notifications_last_timestamp``, ``/labels/infra_service`` and
            ``/labels/infra_environment``, in that order.
        """
        last_timestamp = yt_yson_bindings.loads(stage_yson[3].yson)

        return Stage(
            meta=Meta(values=yt_yson_bindings.loads(stage_yson[0].yson)),
            spec=Spec(values=yt_yson_bindings.loads(stage_yson[1].yson)),
            status=Status(values=yt_yson_bindings.loads(stage_yson[2].yson)),
            last_timestamp=last_timestamp or None,  # getting rid of YsonEntity
            infra_service=yt_yson_bindings.loads(stage_yson[4].yson) or None,  # same here
            infra_environment=yt_yson_bindings.loads(stage_yson[5].yson) or None,  # same here
        )

    @staticmethod
    def notification_policy(policy_yson) -> NotificationPolicy:
        """Build a ``NotificationPolicy`` from payloads for ``/meta/stage_id`` and ``/spec``."""
        return NotificationPolicy(
            stage_id=yt_yson_bindings.loads(policy_yson[0].yson),
            spec=yt_yson_bindings.loads(policy_yson[1].yson),
        )

    async def select_stages(
        self,
        batch_size: int,
        timestamp: Optional[int] = None,
        state_from_annotation: bool = False,
        stage_ids: Optional[List[str]] = None,
    ) -> AsyncIterator[Stage]:
        """Yield stages, either in full or as stored in their notification-state annotation.

        :param batch_size: page size for ``SelectObjects``.
        :param timestamp: optional read timestamp.
        :param state_from_annotation: when True, only ``stage_ids`` are read and
            stages are rebuilt from ``/annotations/notifications_last_state``
            (stages without a usable annotation are skipped); when False, full
            stage info is read, filtered by ``self.stage_filter``.
        :param stage_ids: ids to read in annotation mode; nothing is yielded if empty.
        """
        if state_from_annotation:
            if not stage_ids:
                return

            selectors = [
                '/meta/id',
                '/meta/uuid',
                '/annotations/notifications_last_state',
            ]

            # NOTE(review): ids are quoted with Python repr(); this assumes stage ids
            # contain no characters that repr() escapes differently from the YP query language.
            stage_filter = "[/meta/id] IN (%s)" % (",".join(repr(stage_id) for stage_id in stage_ids))

            stage_factory = self.stage_from_annotation
        else:
            selectors = [
                '/meta',
                '/spec',
                '/status',
                '/labels/notifications_last_timestamp',
                '/labels/infra_service',
                '/labels/infra_environment',
            ]

            stage_filter = self.stage_filter

            stage_factory = self.stage_full_info

        stages = self._select_objects_values(
            object_type=data_model.OT_STAGE,
            batch_size=batch_size,
            selectors=selectors,
            filter=stage_filter,
            timestamp=timestamp,
        )

        async for s in stages:
            stage = stage_factory(s)
            if stage is not None:
                yield stage

    @staticmethod
    def fill_req_last_timestamp_label_update(
        req,
        object_id: str,
        last_timestamp: int,
    ):
        """Fill ``req`` (an update request or subrequest) to set a stage's ``notifications_last_timestamp`` label."""
        req.object_type = data_model.OT_STAGE
        req.object_id = object_id

        update = req.set_updates.add()
        update.path = '/labels/notifications_last_timestamp'
        update.value_payload.yson = yt_yson_bindings.dumps(last_timestamp)

    async def save_stage_state(
        self,
        object_id: str,
        timestamp: int,
        state: Optional[dict],
    ):
        """Save the last processed timestamp and, optionally, the notification state of one stage.

        Both values are written in one ``UpdateObject`` call. If the stage no
        longer exists, the failure is logged and swallowed.

        :param object_id: stage id.
        :param timestamp: value for ``/labels/notifications_last_timestamp``.
        :param state: value for ``/annotations/notifications_last_state``; skipped when None.
        """
        req = object_service_pb2.TReqUpdateObject()
        YpClient.fill_req_last_timestamp_label_update(
            req,
            object_id=object_id,
            last_timestamp=timestamp,
        )

        if state is not None:
            update = req.set_updates.add()
            update.path = '/annotations/notifications_last_state'
            update.value_payload.yson = yt_yson_bindings.dumps(state)

        try:
            await self._apply(self.slave.UpdateObject, req)
        except YpNoSuchObjectError:
            # FIXME we will still need to ensure state save when it's moved into separate NotificationState object
            self.log.error(
                "[%r] timestamp %r was not saved due to YpNoSuchObjectError",
                object_id,
                timestamp,
            )

    async def save_stages_last_timestamp_label(
        self,
        stages: List[Stage],
        batch_size: int,
    ):
        """Write each stage's ``last_timestamp`` into its ``notifications_last_timestamp`` label.

        Updates are sent in ``UpdateObjects`` batches of up to ``batch_size``
        subrequests. A failing batch is bisected until the offending update is
        isolated. That single update is logged and skipped, and the remaining
        stages are still saved.

        :param stages: stages whose ``id`` / ``last_timestamp`` are written.
        :param batch_size: maximum subrequests per call; values below 1 are treated as 1.
        """
        updates = [(stage.id, stage.last_timestamp) for stage in stages]

        # A non-positive batch size would send empty requests and never advance.
        initial_batch_size = max(1, batch_size)
        batch_size = initial_batch_size

        start = 0
        total = len(updates)

        while start < total:
            # The tail chunk may be shorter than batch_size.
            chunk = updates[start:start + batch_size]

            req = object_service_pb2.TReqUpdateObjects()
            for stage_id, last_timestamp in chunk:
                subreq = req.subrequests.add()
                YpClient.fill_req_last_timestamp_label_update(
                    subreq,
                    object_id=stage_id,
                    last_timestamp=last_timestamp,
                )

            try:
                await self._apply(self.slave.UpdateObjects, req)
            except (GrpcResourceExhaustedError, YpNoSuchObjectError, YtResponseError) as e:
                # we assume that it's always some problem with a single object, and not global error
                # so we can safely spend time bisecting the batch
                if len(chunk) < 2:
                    stage_id, last_timestamp = chunk[0]

                    self.log.error(
                        "[%r] last_timestamp %r was not saved due to error %r",
                        stage_id,
                        last_timestamp,
                        e.message,
                    )

                    start += 1
                    batch_size = initial_batch_size
                    continue
                # Halve the chunk that actually failed, so an identical request is never re-sent.
                batch_size = len(chunk) // 2
            else:
                start += len(chunk)
                batch_size = initial_batch_size

    async def watch_stages(
        self,
        from_timestamp: int,
        batch_size: int,
    ) -> AsyncIterator[Tuple[int, str]]:
        """Yield ``(timestamp, object_id)`` for stage update events since ``from_timestamp``.

        Only ``ET_OBJECT_UPDATED`` events are yielded; creates and removes are
        left to polling. Currently a single ``WatchObjects`` call is made with
        ``event_count_limit = 0``. As a result, ``batch_size`` is unused, and a
        ``GrpcResourceExhaustedError`` is always re-raised. If YP reports
        ``YpRowsAlreadyTrimmedError``, nothing is yielded.
        """
        req = object_service_pb2.TReqWatchObjects()
        req.start_timestamp = from_timestamp
        req.object_type = data_model.OT_STAGE
        # req.event_count_limit = batch_size
        req.event_count_limit = 0

        while True:
            try:
                resp = await self._apply(self.slave.WatchObjects, req)
            except GrpcResourceExhaustedError:
                if req.event_count_limit < 2:
                    raise
                req.event_count_limit //= 2
                continue
            except YpRowsAlreadyTrimmedError:
                break

            read = 0
            req.continuation_token = resp.continuation_token

            for ev in resp.events:
                read += 1
                if ev.event_type == data_model.ET_OBJECT_UPDATED:  # creates and removes will be processed in poll
                    yield (
                        ev.timestamp,
                        # ev.event_type,
                        ev.object_id,
                    )

            # TODO after YP problems fixed (currently read == limit indicates a broken response)
            # if read < req.event_count_limit:
            #     break
            break

    async def select_stage_last_timestamp(
        self,
        object_id: str,
        object_uuid: str,
    ) -> Optional[int]:
        """Return the time of the most recent history event of one stage incarnation.

        :param object_id: stage id.
        :param object_uuid: stage uuid, which selects one incarnation of the id.
        :return: the event time in nanoseconds, or None when YP returned no
            history events for this id/uuid pair.
        """
        req = object_service_pb2.TReqSelectObjectHistory()
        req.object_type = data_model.OT_STAGE
        req.selector.paths.extend(['/meta/id'])
        req.format = object_service_pb2.PF_YSON
        req.options.limit = 1
        req.options.descending_time_order = True

        req.object_id = object_id
        req.options.uuid = object_uuid

        resp = await self._apply(self.slave.SelectObjectHistory, req)
        ev = next(iter(resp.events), None)
        if ev is None:
            # No history events for this stage incarnation.
            return None
        return ev.time.ToNanoseconds()

    async def select_stage_history(
        self,
        # batch_size: int,
        object_id: str,
        object_uuid: str,
        from_timestamp: int,
        to_timestamp: int,
        limit: int = 20,
    ) -> AsyncIterator[Optional[StageHistoryChange]]:
        """Yield history changes of one stage incarnation between two timestamps, oldest first.

        Only one page of at most ``limit`` events is read (pagination is
        disabled, see the FIXME below). If the request fails with a "too much
        data" error, the limit is halved and the request is retried.

        :param object_id: stage id.
        :param object_uuid: stage uuid.
        :param from_timestamp: interval begin, in nanoseconds.
        :param to_timestamp: interval end, in nanoseconds.
        :param limit: maximum number of events to request.
        :return: one item per event: ``StageHistoryCreate``, ``StageHistoryUpdate``,
            ``StageHistoryRemove``, or None for any other event type.
        """
        req = object_service_pb2.TReqSelectObjectHistory()
        req.object_type = data_model.OT_STAGE
        req.object_id = object_id
        req.selector.paths.extend(['/meta/project_id', '/meta/acl', '/spec', '/status'])
        req.format = object_service_pb2.PF_YSON
        req.options.uuid = object_uuid
        # req.options.limit = batch_size
        # req.options.limit = 0
        # FIXME it's temporary solution that can probably result in duplicate results.
        # We have to use it since native pagination in YP is broken and some stages
        # have single snapshot size ≈15MB, so they do not fit into memory nor grpc result
        req.options.limit = limit
        req.options.interval.begin.FromNanoseconds(from_timestamp)
        req.options.interval.end.FromNanoseconds(to_timestamp)
        req.options.descending_time_order = False

        while True:
            try:
                resp = await self._apply(self.slave.SelectObjectHistory, req)
            except (GrpcResourceExhaustedError, YtResponseError) as e:
                if YpClient.decrease_limit_and_continue(e, req):
                    req.options.limit //= 2
                    continue

                raise

            read = 0

            for ev in resp.events:
                read += 1

                event = EventMeta(author=ev.user, timestamp=ev.time.ToNanoseconds())

                change = None

                # Payload order follows the selectors: project_id, acl, spec, status.
                if ev.event_type == data_model.ET_OBJECT_CREATED:
                    meta = Meta(project_id=yt_yson_bindings.loads(ev.results.value_payloads[0].yson))

                    change = StageHistoryCreate(
                        event=event,
                        meta=meta,
                    )

                elif ev.event_type == data_model.ET_OBJECT_REMOVED:
                    change = StageHistoryRemove(event=event)

                elif ev.event_type == data_model.ET_OBJECT_UPDATED:
                    meta = Meta(
                        project_id=yt_yson_bindings.loads(ev.results.value_payloads[0].yson),
                        acl=yt_yson_bindings.loads(ev.results.value_payloads[1].yson),
                    )

                    spec = Spec(values=yt_yson_bindings.loads(ev.results.value_payloads[2].yson))
                    status = Status(values=yt_yson_bindings.loads(ev.results.value_payloads[3].yson))

                    change = StageHistoryUpdate(
                        event=event,
                        meta=meta,
                        spec=spec,
                        status=status,
                    )

                yield change

            # TODO after YP problems fixed
            # if read < req.options.limit:
            #     break
            # else:
            #     req.options.continuation_token = resp.continuation_token
            break

    async def select_notification_policies(
        self,
        batch_size: int,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[NotificationPolicy]:
        """Yield all notification policies, reading ``/meta/stage_id`` and ``/spec`` of each.

        :param batch_size: page size for ``SelectObjects``.
        :param timestamp: optional read timestamp.
        """
        selectors = [
            '/meta/stage_id',
            '/spec'
        ]
        policies = self._select_objects_values(
            object_type=data_model.OT_NOTIFICATION_POLICY,
            batch_size=batch_size,
            selectors=selectors,
            timestamp=timestamp,
        )

        async for p in policies:
            policy = self.notification_policy(p)
            if policy is not None:
                yield policy