import ipaddress
import typing as tp
import logging

from yandex.cloud.priv.dns.v1.dns_zone_pb2 import RecordSet
from yandex.cloud.priv.dns.v1.dns_zone_service_pb2 import (
    ListDnsZoneRecordSetsRequest,
    UpsertRecordSetsRequest,
    GetDnsZoneRequest,
)
from yandex.cloud.priv.dns.v1.dns_zone_service_pb2_grpc import DnsZoneServiceStub

from infra.walle.server.contrib.yc_python_sdk import yandexcloud
from infra.walle.server.contrib.yc_python_sdk.yandexcloud import operations as yandexcloud_operations
from walle import yc
from walle.constants import DEFAULT_DNS_TTL
from .interface import DnsClientInterface, DnsMultipleEndpoints, DnsError
from .operations import DnsApiOperation, DnsOperationAdd, DnsOperationDelete


# Module-level logger; RurikkDnsClient.apply_operations uses it to report failed upsert requests.
logger = logging.getLogger(__name__)


class RurikkDnsClient(DnsClientInterface):
    """DNS client that reads and updates records of a single Rurikk (Yandex Cloud DNS) zone.

    Every request goes through the SDK's ``DnsZoneServiceStub`` and is scoped to the zone
    ``dns_zone_id`` given to the constructor.
    """

    def __init__(self, key_id, service_account_id, private_key, dns_zone_id, dns_endpoint, iam_endpoint):
        """Create an SDK client authenticated with a service account key.

        :param key_id: id of the service account key.
        :param service_account_id: id of the service account that owns the key.
        :param private_key: private part of the service account key.
        :param dns_zone_id: id of the DNS zone all lookups and updates are applied to.
        :param dns_endpoint: DNS API endpoint passed to the SDK.
        :param iam_endpoint: IAM API endpoint passed to the SDK.
        """
        sa_key = {
            "id": key_id,
            "service_account_id": service_account_id,
            "private_key": private_key,
        }
        self._sdk = yandexcloud.SDK(service_account_key=sa_key, dns_endpoint=dns_endpoint, iam_endpoint=iam_endpoint)
        self._dns_zone_service = self._sdk.client(DnsZoneServiceStub)
        self._dns_zone_id = dns_zone_id

    def _get_record(self, name: str, record_type: str) -> tp.Optional[str]:
        """Return the single value of the ``record_type`` record set named ``name``.

        :return: the record value, or ``None`` when no matching record set exists.
        :raises DnsMultipleEndpoints: if several record sets match, or the matching record set
            holds more than one value.
        """
        # NOTE(review): ``name`` is interpolated into the filter unescaped; a name containing a
        # single quote would break the filter expression.
        response = self._dns_zone_service.ListRecordSets(
            ListDnsZoneRecordSetsRequest(
                dns_zone_id=self._dns_zone_id,
                filter=f"type='{record_type}' AND name='{name}'",
            )
        )
        record_sets = list(response.record_sets)

        if not record_sets:
            return None
        if len(record_sets) > 1 or len(record_sets[0].data) > 1:
            raise DnsMultipleEndpoints(name, record_sets, record_type=record_type)
        return record_sets[0].data[0]

    def get_aaaa(self, hostname):
        """Return the AAAA record of ``hostname`` in exploded (fully expanded) IPv6 form, or ``None``."""
        address = self._get_record(hostname, "AAAA")
        if not address:
            return None
        # Normalize to the fully expanded notation so callers can compare addresses as strings.
        return ipaddress.ip_address(address).exploded

    def get_a(self, hostname):
        """Return the A record value of ``hostname``, or ``None`` if there is none."""
        return self._get_record(hostname, "A")

    def get_ptr(self, ip_address):
        """Always return ``None``: PTR records are not looked up through this client."""
        # NOTE(rocco66): rurikk_dns has automatic for PTR creation https://st.yandex-team.ru/CLOUD-78777
        return None

    def _mk_record(self, operation: DnsApiOperation, zone_host_suffix: tp.Optional[str] = None) -> RecordSet:
        """Convert ``operation`` into a ``RecordSet`` for an ``UpsertRecordSets`` request.

        For PTR operations the operation's name and data are swapped. A trailing
        ``.<zone host suffix>`` is then stripped from the record name.

        :param operation: the DNS operation to convert.
        :param zone_host_suffix: host suffix of the zone. When omitted it is fetched with
            :meth:`get_dns_zone_info` (one API call); pass it explicitly when converting a batch.
        """
        if zone_host_suffix is None:
            zone_host_suffix = self.get_dns_zone_info().get_host_suffix()

        name, value = operation.name, operation.data
        if operation.type == "PTR":
            name, value = value, name

        # Require a dot boundary: a bare ``endswith`` would also match e.g. "myexample.com"
        # against the suffix "example.com" and cut the name in the middle of a label. An empty
        # suffix must not strip anything.
        dotted_suffix = "." + zone_host_suffix
        if zone_host_suffix and name.endswith(dotted_suffix):
            name = name[: -len(dotted_suffix)]

        return RecordSet(
            name=name,
            type=operation.type,
            ttl=DEFAULT_DNS_TTL,
            data=[value],
        )

    def apply_operations(self, operations: list[DnsApiOperation]) -> None:
        """Apply add/delete operations to the zone with a single ``UpsertRecordSets`` request.

        Add operations are sent as replacements and delete operations as deletions. Operations of
        any other class are ignored. The method waits up to 60 seconds for the resulting cloud
        operation to finish.

        :raises DnsError: if the request cannot be sent or the resulting operation fails.
        """
        add_operations = [o for o in operations if isinstance(o, DnsOperationAdd)]
        # A replacement overwrites the whole record set identified by (name, type), so a deletion
        # of that same record set is redundant. Deletions of *other* types for the same name
        # (e.g. dropping an old A record while adding an AAAA one) must still be sent.
        replaced_keys = {(o.name, o.type) for o in add_operations}
        delete_operations = [
            o for o in operations if isinstance(o, DnsOperationDelete) and (o.name, o.type) not in replaced_keys
        ]

        replacements = []
        deletions = []
        if add_operations or delete_operations:
            # Fetch the zone once per batch instead of once per record.
            zone_host_suffix = self.get_dns_zone_info().get_host_suffix()
            replacements = [self._mk_record(o, zone_host_suffix) for o in add_operations]
            deletions = [self._mk_record(o, zone_host_suffix) for o in delete_operations]

        try:
            result_operation = self._dns_zone_service.UpsertRecordSets(
                UpsertRecordSetsRequest(
                    dns_zone_id=self._dns_zone_id,
                    replacements=replacements,
                    deletions=deletions,
                )
            )
        except Exception as exc:
            raise DnsError(str(exc)) from exc

        try:
            self._sdk.wait_operation_and_get_result(result_operation, timeout=60)
        except yandexcloud_operations.OperationError as exc:
            logger.error(
                "Rurikk DNS upsert request was failed: zone_id=%s, replacements=%s, deletions=%s",
                self._dns_zone_id,
                replacements,
                deletions,
            )
            raise DnsError(exc.message) from exc

    def get_dns_zone_info(self) -> yc.YcDnsZone:
        """Fetch the configured zone with a ``Get`` request and wrap it into ``yc.YcDnsZone``."""
        dns_zone_response = self._dns_zone_service.Get(GetDnsZoneRequest(dns_zone_id=self._dns_zone_id))
        return yc.YcDnsZone(self._dns_zone_id, dns_zone_response.zone, dns_zone_response.folder_id)
