import logging
import urllib.parse

from infra.nanny_ci import util
from infra.nanny_ci.common import common_pb2
from infra.nanny_ci.set_snapshot_state.proto import set_snapshot_state_tasklet

from ci.tasklet.common.proto import service_pb2 as ci
from tasklet.services.yav.proto import yav_pb2 as yav

from nanny_rpc_client import RetryingRpcClient
from infra.nanny.nanny_services_rest.nanny_services_rest.client import ServiceRepoClient


# Module-scoped logger, named after this module.
log: logging.Logger = logging.getLogger(name=__name__)


class SetSnapshotStateImpl(set_snapshot_state_tasklet.SetSnapshotStateBase):
    """Tasklet that moves a Nanny snapshot to the configured target status.

    ``run`` selects one snapshot from the input, builds Nanny clients, then
    executes the steps from ``prepare_steps`` in order. Each step is reported
    to CI through ``save_progress``, and the final outcome is written to
    ``self.output.state``.
    """

    # Fallback Nanny UI base URL when a snapshot carries no installation.
    _DEFAULT_NANNY_URL = 'https://nanny.yandex-team.ru'

    # (bound step method, human-readable description) pairs executed by ``run``.
    steps: list
    # Clients bound to the selected snapshot's installation (see ``prepare_clients``).
    nanny_rest_client: ServiceRepoClient
    nanny_repo_client: RetryingRpcClient
    nanny_tickets_client: RetryingRpcClient

    # The single snapshot chosen from ``self.input.snapshots`` by ``select_snapshot``.
    snapshot: common_pb2.Snapshot

    @staticmethod
    def snapshot_link(
        service_id: str,
        snapshot_id: str,
        base_url: str = _DEFAULT_NANNY_URL,
    ) -> str:
        """Build the Nanny UI URL of a snapshot's runtime-attrs history page.

        :param service_id: Nanny service id.
        :param snapshot_id: snapshot id within that service.
        :param base_url: Nanny installation URL; the relative UI path is
            resolved against it with ``urllib.parse.urljoin``.
        :returns: absolute URL string.
        """
        return urllib.parse.urljoin(
            base_url,
            f'ui/#/services/catalog/{service_id}/runtime_attrs_history/{snapshot_id}/',
        )

    def save_progress(
        self,
        progress_ratio: float,
        message: str,
        step_id: str,
        status=ci.TaskletProgress.Status.RUNNING,
        snapshot_id: str = '',
    ):
        """Send a ``TaskletProgress`` update for this job to CI.

        :param progress_ratio: value stored in ``TaskletProgress.progress``;
            ``run`` passes fractions in ``[0, 1]``.
        :param message: text shown for the progress entry.
        :param step_id: id of the progress entry (``run`` always uses
            ``'progress'``, presumably so CI updates a single entry — confirm).
        :param status: ``TaskletProgress`` status to report.
        :param snapshot_id: snapshot to link to; when empty, the selected
            snapshot is used.
        """
        url = self.snapshot_link(
            base_url=self.snapshot.nanny_installation or self._DEFAULT_NANNY_URL,
            service_id=self.snapshot.service_id,
            # Honour an explicit override; otherwise link the selected snapshot.
            snapshot_id=snapshot_id or self.snapshot.snapshot_id,
        )
        module = 'NANNY'

        progress = ci.TaskletProgress()
        progress.job_instance_id.CopyFrom(self.input.context.job_instance_id)
        progress.id = step_id
        progress.progress = progress_ratio
        progress.text = message
        progress.module = module
        progress.url = url
        progress.status = status
        self.ctx.ci.UpdateProgress(progress)

    def get_token(self):
        """Read the Nanny OAuth token from Yav.

        Uses the secret uid from the job context and the ``nanny_ci.token`` key.
        """
        secret_uid = self.input.context.secret_uid
        secret_key = 'nanny_ci.token'

        spec = yav.YavSecretSpec(uuid=secret_uid, key=secret_key)
        return self.ctx.yav.get_secret(spec, default_key=secret_key).secret

    def select_snapshot(self):
        """Pick exactly one snapshot from ``self.input.snapshots``.

        The matcher is ``config.snapshot`` when set, otherwise ``input.snapshot``.
        Each non-empty matcher field (installation, service id, snapshot id)
        must equal the candidate's field; empty fields match anything.

        :raises Exception: if zero or more than one snapshot matches.
        """
        if self.input.config.HasField('snapshot'):
            matcher = self.input.config.snapshot
        else:
            matcher = self.input.snapshot

        snapshots = [
            snp
            for snp in self.input.snapshots
            if (not matcher.nanny_installation or snp.nanny_installation == matcher.nanny_installation)
            and (not matcher.service_id or snp.service_id == matcher.service_id)
            and (not matcher.snapshot_id or snp.snapshot_id == matcher.snapshot_id)
        ]
        if len(snapshots) > 1:
            raise Exception(f"Multiple snapshots match filter: {matcher}")
        elif not snapshots:
            raise Exception(f"No matching snapshots found for filter: {matcher}")

        self.snapshot = snapshots[0]

    def prepare_clients(self):
        """Create the REST and RPC Nanny clients for the selected snapshot's installation."""
        token = self.get_token()
        self.nanny_rest_client = ServiceRepoClient(
            self.snapshot.nanny_installation,
            token=token,
            timeout=300,
        )
        self.nanny_repo_client = RetryingRpcClient(
            urllib.parse.urljoin(self.snapshot.nanny_installation, '/api/repo'),
            oauth_token=token,
            request_timeout=300,
            retry_5xx=True,
            retry_connection_errors=True,
        )
        self.nanny_tickets_client = RetryingRpcClient(
            urllib.parse.urljoin(self.snapshot.nanny_installation, '/api/tickets'),
            oauth_token=token,
            request_timeout=300,
            retry_5xx=True,
            retry_connection_errors=True,
        )

    def prepare_steps(self):
        """Define the ordered list of steps executed by ``run``."""
        self.steps = [
            (self.validate, "Validating config"),
            (self.set_snapshot_state, "Setting snapshot state"),
            (self.activate_snapshot, "Activating snapshot if target_status=ACTIVE"),
            (self.collect_locations, "Getting affected locations"),
            (self.wait_for_state, "Waiting for snapshot state"),
        ]

    def validate(self):
        """Reject an UNKNOWN target status or a snapshot without service/snapshot ids.

        :raises Exception: on any misconfiguration.
        """
        if self.input.config.target_status == common_pb2.UNKNOWN:
            raise Exception("Tasklet misconfigured: `target_status` must not be UNKNOWN")

        if not self.snapshot.service_id:
            raise Exception("Tasklet misconfigured: `service_id` must be set")

        if not self.snapshot.snapshot_id:
            raise Exception("Tasklet misconfigured: `snapshot_id` must be set")

    def set_snapshot_state(self):
        """Set a non-ACTIVE target state on the snapshot via the REST client.

        ACTIVE is handled separately by ``activate_snapshot``.
        """
        if self.input.config.target_status == common_pb2.ACTIVE:
            return

        # NOTE(review): recipe parameters are computed but not forwarded here,
        # unlike ``activate_snapshot``; confirm util.set_snapshot_state does not need them.
        activate_recipe, _parameters = util.get_activate_recipe(self.input.config)

        util.set_snapshot_state(
            nanny_client=self.nanny_rest_client,
            service_id=self.snapshot.service_id,
            snapshot_id=self.snapshot.snapshot_id,
            target_state=common_pb2.TargetStatus.Name(self.input.config.target_status),
            prepare_recipe=self.input.config.prepare_recipe,
            activate_recipe=activate_recipe,
            ticket=self.snapshot.ticket_id,
            comment=f"CI Build: {util.make_build_name(self.input.context)}",
        )

    def activate_snapshot(self):
        """Activate the snapshot through its ticket when target status is ACTIVE."""
        if self.input.config.target_status != common_pb2.ACTIVE:
            return

        prepare_recipe = self.input.config.prepare_recipe
        activate_recipe, parameters = util.get_activate_recipe(self.input.config)

        util.activate_snapshot(
            self.nanny_tickets_client,
            self.snapshot.ticket_id,
            prepare_recipe,
            activate_recipe,
            recipe_parameters=parameters,
        )
        log.info("activated snapshot for ticket %r", self.snapshot.ticket_id)

    def _get_known_locations(self) -> set[str]:
        """Return clusters already listed in the input approval locations for this snapshot's service.

        Matches on the selected snapshot's installation and service id. These
        are the same values ``collect_locations`` writes into new locations.
        """
        return {
            location.cluster
            for location in self.input.approval_locations
            if location.nanny_installation == self.snapshot.nanny_installation
            and location.service_id == self.snapshot.service_id
        }

    def collect_locations(self):
        """Record affected clusters not yet known into ``self.output.approval_locations``.

        Only runs for ACTIVE targets. Clusters come from
        ``util.collect_locations`` for the snapshot's activation tasks.
        """
        if self.input.config.target_status != common_pb2.ACTIVE:
            return

        activate_prefix = util.get_activate_prefix(self.input.config)

        known_locations = self._get_known_locations()
        for taskgroup_id, cluster, task_id in util.collect_locations(
            self.nanny_rest_client,
            self.snapshot.service_id,
            self.snapshot.snapshot_id,
            activate_prefix=activate_prefix,
        ):
            if cluster in known_locations:
                continue

            location = self.output.approval_locations.add()
            location.nanny_installation = self.snapshot.nanny_installation
            location.service_id = self.snapshot.service_id
            location.taskgroup_id = taskgroup_id
            location.cluster = cluster
            location.task_id = task_id

        log.info(
            "collected clusters: %s",
            [location.cluster for location in self.output.approval_locations],
        )

    def wait_for_state(self):
        """Wait (up to 24h) until the snapshot reaches the target status.

        Skips waiting when the target is ACTIVE and some locations still need
        approval.

        :raises Exception: if ``util.wait_for_snapshot_status`` reports failure.
        """
        if self.output.approval_locations and self.input.config.target_status == common_pb2.ACTIVE:
            log.info(
                "Do not waiting for activation: approves required for locations %s",
                [location.cluster for location in self.output.approval_locations],
            )
            return

        timeout = 24 * 60 * 60
        deployed = util.wait_for_snapshot_status(
            self.nanny_repo_client,
            self.snapshot.service_id,
            self.snapshot.snapshot_id,
            state=common_pb2.TargetStatus.Name(self.input.config.target_status),
            timeout=timeout,
        )
        if not deployed:
            raise Exception(f"Wait for deploy failed (not finished in {timeout} seconds)")

    def run(self):
        """Tasklet entry point: select the snapshot, build clients, and execute all steps.

        On a step failure, reports FAILED progress, fills ``output.state`` and
        re-raises. On success, reports SUCCESSFUL progress.
        """
        self.select_snapshot()
        self.prepare_clients()
        self.prepare_steps()

        for idx, (step, message) in enumerate(self.steps):
            log.info("[%2d] %s", idx, message)
            self.save_progress(progress_ratio=idx / len(self.steps), message=message, step_id='progress')
            try:
                step()
            except Exception as e:
                self.save_progress(
                    progress_ratio=1.,
                    message=f"Release failed: {e}",
                    step_id='progress',
                    status=ci.TaskletProgress.Status.FAILED,
                )
                self.output.state.success = False
                self.output.state.message = f"Step {idx} failed: {e}"
                raise

        self.save_progress(
            progress_ratio=1.,
            message="Done.",
            step_id='progress',
            status=ci.TaskletProgress.Status.SUCCESSFUL,
        )
        self.output.state.success = True
        self.output.state.message = "Success"
