import logging
import time
from typing import Optional

from nanny_repo import repo_api_stub, repo_api_pb2, repo_pb2
from nanny_tickets import tickets_pb2, tickets_api_stub, tickets_api_pb2
from nanny_rpc_client import RetryingRpcClient
from infra.nanny.nanny_services_rest.nanny_services_rest.client import ServiceRepoClient

from infra.nanny_ci.common import common_pb2


log = logging.getLogger(__name__)


def set_snapshot_state(
    nanny_client: ServiceRepoClient,
    service_id: str,
    snapshot_id: str,
    target_state: str,
    ticket: str,
    prepare_recipe: str = '',
    activate_recipe: str = '',
    comment: str = '',
):
    """Post a SET_SNAPSHOT_STATE event for one snapshot of a Nanny service.

    The snapshot is always requested with ``set_as_current``. ``ticket`` is
    added to the tracked tickets only when ``target_state`` is ``'ACTIVE'``.
    The prepare/activate recipes are included only when non-empty; the
    activate recipe is sent under the ``'recipe'`` key.

    Args:
        nanny_client: REST client used to submit the event via ``create_event``.
        service_id: Nanny service the snapshot belongs to.
        snapshot_id: snapshot whose state is being changed.
        target_state: state name to request (e.g. ``'ACTIVE'``).
        ticket: ticket id to track when activating.
        prepare_recipe: optional prepare recipe name.
        activate_recipe: optional activation recipe name.
        comment: free-form event comment.
    """
    tracked = [{'ticket_id': ticket}] if target_state == 'ACTIVE' else []

    content = {
        'comment': comment,
        'set_as_current': True,
        'snapshot_id': snapshot_id,
        'state': target_state,
        'tracked_tickets': {'tickets': tracked},
    }
    # Empty recipe strings are omitted entirely rather than sent as ''.
    for key, value in (('prepare_recipe', prepare_recipe), ('recipe', activate_recipe)):
        if value:
            content[key] = value

    event = {'type': 'SET_SNAPSHOT_STATE', 'content': content}
    nanny_client.create_event(service_id, event)


def wait_for_snapshot_status(
    nanny_client: RetryingRpcClient,
    service_id: str,
    snapshot_id: str,
    state: str,
    timeout: int = 24 * 60 * 60,
) -> bool:
    """Poll a Nanny service until a snapshot's runtime status equals ``state``.

    Each poll reads the service with STRONG consistency. It first checks that
    the snapshot is present in ``service.spec.snapshot`` with a target whose
    enum name equals ``state``. It then looks up the snapshot in
    ``service.status.snapshot`` and compares its status enum name with
    ``state``. Polls are at most 5 seconds apart.

    The service is always polled at least once, and once more after the final
    sleep, so the whole ``timeout`` budget is usable.

    Args:
        nanny_client: RPC client used to build a ``RepoServiceStub``.
        service_id: Nanny service to inspect.
        snapshot_id: snapshot to wait for.
        state: required state name. It is compared against both the spec
            target name and the status name.
        timeout: maximum number of seconds to wait.

    Returns:
        True if the status reached ``state``; False if ``timeout`` elapsed first.

    Raises:
        Exception: if the snapshot is missing from the service spec, or its
            spec target differs from ``state``.
    """
    deadline = time.time() + timeout
    stub = repo_api_stub.RepoServiceStub(nanny_client)

    while True:
        service = stub.get_service(repo_api_pb2.GetServiceRequest(
            service_id=service_id,
            consistency=repo_api_pb2.STRONG,
        )).service

        spec_snapshot_valid = any(
            spec.id == snapshot_id
            and repo_pb2.Snapshot.Target.Name(spec.target) == state
            for spec in service.spec.snapshot
        )
        if not spec_snapshot_valid:
            raise Exception(
                f"Snapshot {snapshot_id!r} not found in service {service_id!r}"
                f" or its target state != {state}"
            )

        status = next((s for s in service.status.snapshot if s.id == snapshot_id), None)
        snapshot_state = None if status is None else repo_pb2.SnapshotStatus.Status.Name(status.status)

        log.info("state=%r, required=%r", snapshot_state, state)
        if snapshot_state == state:
            return True

        # Check the deadline *after* polling, so the final (clamped) sleep
        # is always followed by one more poll.
        remaining = deadline - time.time()
        if remaining <= 0:
            return False
        time.sleep(min(5, remaining))


def activate_snapshot(
    nanny_client: RetryingRpcClient,
    ticket_id: str,
    prepare_recipe: str,
    activate_recipe: str,
    recipe_parameters: Optional[dict[str, str]] = None,
) -> tuple[str, str]:
    """Create an ACTIVATE_SERVICE_SNAPSHOT event on a Nanny ticket.

    The same ticket is added to the event's tracked tickets.

    Args:
        nanny_client: RPC client used to build a ``TicketServiceStub``.
        ticket_id: ticket to post the activation event on.
        prepare_recipe: prepare recipe name.
        activate_recipe: activation recipe name.
        recipe_parameters: extra recipe parameters copied into the event's
            ``recipe_parameters`` map. ``None`` means no parameters.

    Returns:
        ``(event_id, snapshot_id)``. Both are taken from the
        ``create_ticket_event`` response; the snapshot id comes from
        ``ticket.spec.service_deployment``.
    """
    # None default instead of a shared mutable {} default.
    if recipe_parameters is None:
        recipe_parameters = {}

    req = tickets_api_pb2.CreateTicketEventRequest()
    req.ticket_id = ticket_id
    req.spec.type = tickets_pb2.EventSpec.ACTIVATE_SERVICE_SNAPSHOT

    activation = req.spec.activate_snapshot
    activation.prepare_recipe = prepare_recipe
    activation.activate_recipe = activate_recipe
    activation.tracked_tickets.extend([ticket_id])
    activation.recipe_parameters.update(recipe_parameters)

    stub = tickets_api_stub.TicketServiceStub(nanny_client)
    response = stub.create_ticket_event(req)

    snapshot_id = response.ticket.spec.service_deployment.snapshot_id
    log.info("activated snapshot %r", snapshot_id)
    return response.event.id, snapshot_id


def get_activate_recipe(pb) -> tuple[str, dict[str, str]]:
    """Resolve the activation recipe name and its parameters from ``pb``.

    If the ``activate_recipe_group`` oneof holds ``activate_recipe``, that
    recipe is returned with no parameters. Otherwise the default recipe is
    used. An unknown/unset default is treated as LOCATIONWISE, which maps to
    the Nanny locationwise auto recipe. Its ``manual`` parameter is ``'true'``
    when ``default_activate_recipe.manual_confirmation.value`` is truthy,
    otherwise ``'false'``.

    Raises:
        Exception: for any default recipe other than unknown/LOCATIONWISE.
    """
    if pb.WhichOneof('activate_recipe_group') == 'activate_recipe':
        return pb.activate_recipe, {}

    default = pb.default_activate_recipe
    supported = (
        common_pb2.DefaultRecipe.RECIPE_UNKNOWN,
        common_pb2.DefaultRecipe.LOCATIONWISE,
    )
    if default.recipe not in supported:
        raise Exception("Unsupported default recipe")

    manual = 'true' if default.manual_confirmation.value else 'false'
    return (
        'nanny_default._activate_service_configuration_locationwise_auto',
        {'manual': manual},
    )


def get_activate_prefix(pb) -> str:
    """Return the task-name prefix that marks activation tasks for ``pb``.

    With an explicit ``activate_recipe``, the configured ``activate_prefix`` is
    used, falling back to ``'activate_'`` when it is empty. Default recipes
    (unknown/unset, treated as LOCATIONWISE) always use ``'activate_'``.

    Raises:
        Exception: for any default recipe other than unknown/LOCATIONWISE.
    """
    if pb.WhichOneof('activate_recipe_group') == 'activate_recipe':
        custom_prefix = pb.activate_prefix
        return custom_prefix if custom_prefix else 'activate_'

    recipe = pb.default_activate_recipe.recipe
    if recipe not in (
        common_pb2.DefaultRecipe.RECIPE_UNKNOWN,
        common_pb2.DefaultRecipe.LOCATIONWISE,
    ):
        raise Exception("Unsupported default recipe")
    return 'activate_'


def collect_locations(
    nanny_client: RetryingRpcClient,
    service_id: str,
    snapshot_id: str,
    activate_prefix: str = 'activate_',
    *,
    timeout: float = 15 * 60,
) -> list[tuple[str, str, str]]:
    """List the auto-run activation tasks of a snapshot's alemate taskgroup.

    Polls the service's current state every 2 seconds until ``snapshot_id``
    appears in ``active_snapshots`` with a non-empty ``taskgroup_id`` and a
    state of ACTIVE or ACTIVATING. It then fetches that taskgroup's children
    and keeps tasks that:

    - belong to the taskgroup,
    - have a last id component starting with ``activate_prefix``, and
    - have ``runtimeOptions.hasConfirm`` false.

    Args:
        nanny_client: client exposing ``get_current_state`` and
            ``_request_with_retries``. NOTE(review): those look like
            ServiceRepoClient (REST) methods rather than RetryingRpcClient's;
            confirm the annotation.
        service_id: Nanny service to inspect.
        snapshot_id: snapshot whose taskgroup is wanted.
        activate_prefix: task-name prefix identifying activation tasks. The
            remainder of the name is returned as the location.
        timeout: seconds to wait for the taskgroup to become available.

    Returns:
        Sorted, de-duplicated ``(taskgroup_id, location, task_name)`` tuples.

    Raises:
        TimeoutError: if no suitable taskgroup appears within ``timeout``.
    """
    log.info("retrieving taskgroup for service %r snapshot %r", service_id, snapshot_id)
    deadline = time.time() + timeout
    while True:
        current_state = nanny_client.get_current_state(service_id)
        snapshots = current_state['content']['active_snapshots']
        target_snapshot = next((s for s in snapshots if s['snapshot_id'] == snapshot_id), None)
        if (
            target_snapshot is not None
            and target_snapshot.get('taskgroup_id')
            and target_snapshot.get('state') in ('ACTIVE', 'ACTIVATING')
        ):
            taskgroup_id = target_snapshot['taskgroup_id']
            log.info("got taskgroup_id %r", taskgroup_id)
            break

        # Fail loudly rather than querying the children of a non-existent
        # taskgroup (the URL would otherwise contain "None").
        if time.time() >= deadline:
            raise TimeoutError(
                f"taskgroup for service {service_id!r} snapshot {snapshot_id!r}"
                f" was not ready within {timeout} seconds"
            )
        log.info('taskgroup_id is not ready yet')
        time.sleep(2)

    log.info("collecting children for taskgroup %r", taskgroup_id)
    response = nanny_client._request_with_retries('GET', f'/v1/alemate/task_groups/{taskgroup_id}/children/')

    clusters = set()
    for task in response:
        parts = task['id'].split('/')
        task_name = parts[-1]
        if (
            parts[0] == taskgroup_id
            and task_name.startswith(activate_prefix)
            and not task['runtimeOptions']['hasConfirm']
        ):
            clusters.add((task_name[len(activate_prefix):], task_name))

    return [(taskgroup_id, *cluster) for cluster in sorted(clusters)]


def make_build_name(context) -> str:
    """Build a human-readable name for the revision in ``context``.

    Uses ``"Pull request <id> hash <hash>"`` when the target revision has a
    truthy ``pull_request_id``. Otherwise it uses
    ``"Commit <number> hash <hash>"``.
    """
    revision = context.target_revision
    pr_id = revision.pull_request_id
    label = f"Pull request {pr_id}" if pr_id else f"Commit {revision.number}"
    return f"{label} hash {revision.hash}"
