import asyncio
import logging
from typing import Any, AsyncIterator, Optional, List, Tuple

import yt_yson_bindings
from yp import data_model
from yp.common import GrpcResourceExhaustedError, YtResponseError
from yp_proto.yp.client.api.proto import object_service_pb2

from infra.deploy_queue_controller.lib.models import Stage, Ticket, ReleaseRule, Release
from infra.deploy_queue_controller.lib.yputil import yson_to_proto


class YpClient:
    """Async facade over a blocking YP object-service client.

    Every remote call goes through ``_apply``, which runs the blocking
    ``slave`` method in the event loop's default executor so coroutines do
    not block the loop.
    """

    class NoAcl(Exception):
        """Presumably signals missing ACL permissions; not raised by this class itself."""

    def __init__(self, slave):
        """Create the client.

        :param slave: object exposing ``SelectObjects``, ``GetObjects``,
            ``UpdateObjects`` and ``GenerateTimestamp`` -- presumably the YP
            gRPC object-service stub (assumption; confirm at the call site).
        """
        self.slave = slave
        # NOTE(review): calling asyncio.get_event_loop() outside a running loop is
        # deprecated in recent Python versions; no method of this class reads self.loop.
        self.loop = asyncio.get_event_loop()
        # self.loop.set_default_executor(concurent.futures.ThreadPoolExecutor(max_workers=10))  # FIXME

    async def _apply(self, fn, *args) -> Any:
        """Run the blocking ``fn(*args)`` in the default executor and return its result."""
        return await asyncio.get_event_loop().run_in_executor(None, fn, *args)

    @staticmethod
    def _check_batch_size(batch_size: int) -> None:
        """Reject non-positive batch sizes, which would make the batching loops never terminate."""
        if batch_size < 1:
            raise ValueError(f'batch_size must be positive, got {batch_size!r}')

    async def _select_objects_values(
        self,
        object_type: int,
        batch_size: int,
        selectors: Optional[List[str]] = None,
        filter: Optional[str] = None,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[List[object_service_pb2.TPayload]]:
        """Page through ``SelectObjects`` and yield each object's value payloads.

        Payloads of one object come in ``selectors`` order. Pages are chained via
        the continuation token; iteration stops at the first page shorter than
        the current limit. On ``GrpcResourceExhaustedError`` the page limit is
        halved and the page retried; at a limit of 1 the error propagates.

        :param object_type: YP object type (``data_model.OT_*``).
        :param batch_size: initial page size.
        :param selectors: attribute paths to fetch; defaults to ``['']``.
        :param filter: optional YP query filter expression.
        :param timestamp: optional read timestamp.
        :raises ValueError: if ``batch_size`` is not positive.
        """
        self._check_batch_size(batch_size)

        req = object_service_pb2.TReqSelectObjects()
        req.object_type = object_type
        if timestamp is not None:
            req.timestamp = timestamp
        if filter is not None:
            req.filter.query = filter
        selectors = selectors or ['']
        req.selector.paths.extend(selectors)
        req.format = object_service_pb2.PF_YSON
        req.options.limit = batch_size

        while True:
            try:
                resp = await self._apply(self.slave.SelectObjects, req)
            except GrpcResourceExhaustedError:
                if req.options.limit < 2:
                    raise
                req.options.limit //= 2
                continue

            for r in resp.results:
                yield r.value_payloads

            # A page shorter than the limit means there is nothing left to read.
            if len(resp.results) < req.options.limit:
                break
            req.options.continuation_token = resp.continuation_token

    async def _get_objects_by_ids(
        self,
        object_type: int,
        object_ids: List[str],
        batch_size: int,
        selectors: Optional[List[str]] = None,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[List[object_service_pb2.TPayload]]:
        """Fetch objects by id in batches and yield each found object's value payloads.

        Nonexistent ids are skipped (``ignore_nonexistent`` is set, and
        subresponses without payloads are not yielded). On
        ``GrpcResourceExhaustedError`` the batch is halved and retried; once a
        single-id batch fails the error propagates.

        :param object_type: YP object type (``data_model.OT_*``).
        :param object_ids: ids to fetch.
        :param batch_size: initial number of ids per request.
        :param selectors: attribute paths to fetch; defaults to ``['']``.
        :param timestamp: optional read timestamp.
        :raises ValueError: if ``batch_size`` is not positive.
        """
        self._check_batch_size(batch_size)

        req = object_service_pb2.TReqGetObjects()
        req.object_type = object_type
        if timestamp is not None:
            req.timestamp = timestamp

        selectors = selectors or ['']
        req.selector.paths.extend(selectors)
        req.format = object_service_pb2.PF_YSON
        req.options.ignore_nonexistent = True
        req.options.fetch_values = True

        shift = 0
        while shift < len(object_ids):
            chunk = object_ids[shift:shift + batch_size]
            del req.subrequests[:]
            for oid in chunk:
                subreq = req.subrequests.add()
                subreq.object_id = oid

            try:
                resp = await self._apply(self.slave.GetObjects, req)
            except GrpcResourceExhaustedError:
                if len(chunk) < 2:
                    raise
                # Halve the batch actually sent (it may be a short tail), so the retry is smaller.
                batch_size = len(chunk) // 2
                continue

            shift += len(chunk)

            for r in resp.subresponses:
                # A protobuf sub-message is always truthy, so test the payloads instead:
                # with ignore_nonexistent a missing object comes back with no payloads.
                payloads = r.result.value_payloads
                if payloads:
                    yield payloads

    async def generate_timestamp(self) -> int:
        """Request a fresh timestamp from YP and return it."""
        req = object_service_pb2.TReqGenerateTimestamp()
        resp = await self._apply(self.slave.GenerateTimestamp, req)
        return resp.timestamp

    async def select_stages(
        self,
        batch_size: int,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[Stage]:
        """Yield all stages with their ``/meta``, ``/spec`` and ``/status`` parsed into protos."""
        stages = self._select_objects_values(
            object_type=data_model.OT_STAGE,
            batch_size=batch_size,
            selectors=['/meta', '/spec', '/status'],
            timestamp=timestamp,
        )
        async for s in stages:
            yield Stage(
                meta=yson_to_proto(s[0], data_model.TStageMeta),
                spec=yson_to_proto(s[1], data_model.TStageSpec),
                status=yson_to_proto(s[2], data_model.TStageStatus),
            )

    async def select_release_rules(
        self,
        batch_size: int,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[Tuple[str, ReleaseRule]]:
        """Yield ``(rule_id, ReleaseRule)`` for rules with an auto-commit policy type set and not "none".

        A falsy (e.g. absent) deployment termination policy defaults to
        ``WAIT_FOR_COMPLETION``.
        """
        rules = self._select_objects_values(
            object_type=data_model.OT_RELEASE_RULE,
            batch_size=batch_size,
            selectors=[
                '/meta/id',
                '/spec/auto_commit_policy/type',
                '/spec/auto_commit_policy/maintain_active_trunk_options/deployment_termination_policy',
            ],
            timestamp=timestamp,
            filter='[/spec/auto_commit_policy/type] != "none" AND [/spec/auto_commit_policy/type] != #',
        )
        async for rule in rules:
            termination_policy = (
                yt_yson_bindings.loads(rule[2].yson)
                or ReleaseRule.DeploymentTerminationPolicy.WAIT_FOR_COMPLETION
            )
            yield (
                yt_yson_bindings.loads(rule[0].yson),
                ReleaseRule(
                    mode=ReleaseRule.CommitMode(yt_yson_bindings.loads(rule[1].yson)),
                    trunk_deployment_termination_policy=ReleaseRule.DeploymentTerminationPolicy(termination_policy),
                ),
            )

    async def select_releases(
        self,
        release_ids: List[str],
        batch_size: int,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[Release]:
        """Yield releases for the given ids (``/meta`` and ``/spec``); nonexistent ids are skipped."""
        releases = self._get_objects_by_ids(
            object_type=data_model.OT_RELEASE,
            object_ids=release_ids,
            batch_size=batch_size,
            selectors=['/meta', '/spec'],
            timestamp=timestamp,
        )
        async for r in releases:
            yield Release(
                meta=yson_to_proto(r[0], data_model.TReleaseMeta),
                spec=yson_to_proto(r[1], data_model.TReleaseSpec),
            )

    async def select_tickets(
        self,
        batch_size: int,
        timestamp: Optional[int] = None,
    ) -> AsyncIterator[Ticket]:
        """Yield deploy tickets not marked closed and whose current action type is not "skip"."""
        tickets = self._select_objects_values(
            object_type=data_model.OT_DEPLOY_TICKET,
            batch_size=batch_size,
            selectors=['/meta', '/spec', '/status'],
            timestamp=timestamp,
            # TODO we also need to use some index on ticket.release_rule.autocommit == true
            filter=(
                '[/status/progress/closed/status] != "true"'
                ' AND [/status/action/type] != "skip"'
                # ' AND [/status/action/type] != "commit"'
            ),
        )
        async for t in tickets:
            yield Ticket(
                meta=yson_to_proto(t[0], data_model.TDeployTicketMeta),
                spec=yson_to_proto(t[1], data_model.TDeployTicketSpec),
                status=yson_to_proto(t[2], data_model.TDeployTicketStatus),
            )

    async def submit_ticket_actions(
        self,
        batch_size: int,
        cancels: List[str],
        commits: List[Tuple[str, str]],
        waits: List[Tuple[str, str]],
        timestamp: int,
    ) -> None:
        """Write skip/commit/wait actions to deploy tickets in batched ``UpdateObjects`` calls.

        Every update carries a ``/spec`` timestamp prerequisite at ``timestamp``.

        :param batch_size: initial number of updates per request.
        :param cancels: ticket ids that receive a ``/control/skip`` (SUPERSEDED) action.
        :param commits: ``(ticket_id, message)`` pairs that receive a ``/control/commit``
            (AUTOCOMMIT) action with the message embedded.
        :param waits: ``(ticket_id, message)`` pairs whose ``/status/action`` is set to a
            DPAT_WAIT action carrying ``message``.
        :param timestamp: prerequisite timestamp for ``/spec``.
        :raises ValueError: if ``batch_size`` is not positive.
        :raises GrpcResourceExhaustedError: if even a single-update request is rejected.

        A ``YtResponseError`` is bisected down to a single update, which is then
        logged and skipped.
        """
        self._check_batch_size(batch_size)

        updates = []
        initial_batch_size = batch_size

        cancel_action = data_model.TDeployTicketControl.TSkipAction()
        cancel_action.options.message = 'Ticket cancelled because more recent has been applied.'
        cancel_action.options.reason = 'SUPERSEDED'
        cancel_action.options.patch_selector.type = data_model.DTPST_FULL
        cancel_action_value = yt_yson_bindings.dumps_proto(cancel_action)
        cancel_path = "/control/skip"

        commit_action = data_model.TDeployTicketControl.TCommitAction()
        commit_action.options.reason = 'AUTOCOMMIT'
        commit_action.options.patch_selector.type = data_model.DTPST_FULL
        commit_path = "/control/commit"

        # TODO do we have control for waits?
        wait_action = data_model.TDeployPatchAction()
        wait_action.type = data_model.DPAT_WAIT
        wait_action.reason = 'WAITING_IN_QUEUE'
        wait_path = '/status/action'

        for ticket_id in cancels:
            updates.append((ticket_id, cancel_path, cancel_action_value))

        for ticket_id, message in commits:
            commit_action.options.message = f'Ticket applied due to autocommit policy: {message}'
            commit_action_value = yt_yson_bindings.dumps_proto(commit_action)
            updates.append((ticket_id, commit_path, commit_action_value))

        # FIXME we need to think of a way to reduce memory usage here
        # Serialized wait actions are shared between tickets with the same message.
        wait_actions_cache = {}
        for ticket_id, message in waits:
            if message not in wait_actions_cache:
                wait_action.message = message
                wait_actions_cache[message] = yt_yson_bindings.dumps_proto(wait_action)
            updates.append((ticket_id, wait_path, wait_actions_cache[message]))

        start = 0
        total = len(updates)

        while start < total:
            # Never build a batch larger than what is left, so that halving on
            # failure actually shrinks the request instead of re-sending the same tail.
            batch_size = min(batch_size, total - start)
            req = object_service_pb2.TReqUpdateObjects()

            for ticket_id, path, action in updates[start:start + batch_size]:
                subreq = req.subrequests.add()
                subreq.object_type = data_model.OT_DEPLOY_TICKET
                subreq.object_id = ticket_id

                prereq = subreq.attribute_timestamp_prerequisites.add()
                prereq.path = '/spec'
                prereq.timestamp = timestamp

                update = subreq.set_updates.add()
                update.path = path
                update.value_payload.yson = action

            try:
                await self._apply(self.slave.UpdateObjects, req)
            except GrpcResourceExhaustedError:
                if batch_size < 2:
                    raise
                batch_size //= 2
            except YtResponseError as exc:
                # we assume that it's always some problem with a single object, and not global error
                # so we can safely spend time bisecting the batch
                if batch_size < 2:
                    skipped_id, skipped_path, _ = updates[start]
                    logging.getLogger(__name__).warning(
                        'Skipping %s update of deploy ticket %s: %s',
                        skipped_path, skipped_id, exc,
                    )
                    start += 1
                    batch_size = initial_batch_size
                else:
                    batch_size //= 2
            else:
                start += batch_size
                batch_size = initial_batch_size
