# coding=utf-8
from concurrent import futures
import datetime
import logging
import random
import socket
import threading

import grpc
import time

from tasklet.experimental.cli import consts
import tasklet.api.v2.schema_registry_service_pb2_grpc as schema_registry_service_grpc
import tasklet.api.v2.tasklet_service_pb2_grpc as tasklet_service_grpc

log = logging.getLogger(__name__)


class OAuthMetadataPlugin(grpc.AuthMetadataPlugin):
    """gRPC auth metadata plugin that attaches an ``authorization: OAuth <token>`` entry."""

    AUTH_HEADER = "authorization"
    AUTH_METHOD = "OAuth"

    def __init__(self, token: str):
        """Store *token*; it is embedded into the header value on every invocation."""
        self._token = token

    def __call__(self, context: grpc.AuthMetadataContext, callback: grpc.AuthMetadataPluginCallback) -> None:
        """
        Hand the authorization metadata to *callback*.

        *context* is accepted to satisfy the plugin signature but is not inspected.
        """
        header_value = f"{self.AUTH_METHOD} {self._token}"
        metadata = ((self.AUTH_HEADER, header_value),)
        # None is always passed as the callback's second argument.
        callback(metadata, None)


@property
def ci_endpoints(self):
    """
    Return the cached ``_ci_endpoints`` value, loading it on demand.

    When the cached value is falsy (unset or exhausted) it is re-populated from
    ``self._get_all_ci_endpoints()`` before being returned.

    NOTE(review): this property is defined at module level, so it is not bound to any
    class in this file - confirm where it is meant to live.
    """
    cached = self._ci_endpoints
    if cached:
        return cached

    self._ci_endpoints = self._get_all_ci_endpoints()
    return self._ci_endpoints


@property
def ci_endpoint_next(self):
    """
    Remove and return the last item of ``self.ci_endpoints``.

    NOTE(review): this property is defined at module level, so it is not bound to any
    class in this file - confirm where it is meant to live.
    """
    return self.ci_endpoints.pop()


class Driver:
    """
    Builds gRPC client stubs for the services described by a cluster config.

    The backend address is either the fixed ``cluster.endpoint`` or, when that is not
    set, a random pick from the list of ready endpoints obtained through YP service
    discovery. The resolved list is cached; once it is older than
    ``_REFRESH_INTERVAL`` a refresh is scheduled on a single-worker thread pool while
    callers keep using the cached list.

    NB: Not thread safe
    """

    # How many times the initial (blocking) resolve is attempted before giving up.
    _RESOLVE_ATTEMPTS = 10
    # Pause between two blocking resolve attempts, in seconds.
    _RESOLVE_RETRY_DELAY = 1
    # Age after which the cached endpoint list is refreshed in the background.
    _REFRESH_INTERVAL = datetime.timedelta(seconds=30)

    def __init__(self, cluster: consts.ClusterConfig):
        """
        :param cluster: config providing ``endpoint``, ``endpoint_set_id`` and ``clusters``.
        """
        self.cluster = cluster
        self._endpoints = []  # type: list[str]
        self._last_resolve = datetime.datetime.min
        self._resolver_tp = futures.ThreadPoolExecutor(max_workers=1)
        # Guards _endpoints / _last_resolve, which the background refresh also writes.
        self._mtx = threading.Lock()
        self._secure: bool = False

    def _do_resolve_endpoints(self) -> list[str]:
        """
        Query service discovery for every configured cluster.

        :return: shuffled ``"fqdn:port"`` strings of all endpoints reported as ready;
            may be empty.
        """
        # Imported lazily so the resolver is only required when discovery is actually used.
        from infra.yp_service_discovery.api import api_pb2
        from infra.yp_service_discovery.python.resolver import resolver

        client_name = socket.gethostname()
        yp_resolver = resolver.Resolver(client_name=client_name, timeout=5)

        available_grpc_hosts = []
        log.debug(
            "Resolving endpoint list for endpoint set %s in clusters %s",
            self.cluster.endpoint_set_id, self.cluster.clusters,
        )

        for cluster_name in self.cluster.clusters:
            request = api_pb2.TReqResolveEndpoints(
                cluster_name=cluster_name,
                endpoint_set_id=self.cluster.endpoint_set_id,
                client_name=client_name,
            )
            result = yp_resolver.resolve_endpoints(request)  # type: api_pb2.TRspResolveEndpoints
            # Endpoints not flagged as ready are skipped.
            available_in_cluster = [
                f"{endpoint.fqdn}:{endpoint.port}"
                for endpoint in result.endpoint_set.endpoints
                if endpoint.ready
            ]
            log.debug(
                "Resolved %s alive backends out of a total of %s in %s",
                len(available_in_cluster), len(result.endpoint_set.endpoints), cluster_name,
            )
            available_grpc_hosts.extend(available_in_cluster)

        random.shuffle(available_grpc_hosts)
        log.debug("Resolved %s alive backends in clusters %s", len(available_grpc_hosts), self.cluster.clusters)

        return available_grpc_hosts

    def _resolve_endpoints(self):
        """
        Best-effort refresh of the cached endpoint list.

        Submitted to the resolver thread pool by ``_get_grpc_endpoint``. Failures and
        empty results are logged and leave the existing cache untouched.
        """
        try:
            available_grpc_hosts = self._do_resolve_endpoints()
        except Exception:
            # Nobody awaits the future, so log here instead of losing the error.
            log.warning("Resolve failed", exc_info=True)
            return

        if not available_grpc_hosts:
            log.warning("Not updating backend list due to empty result")
            return

        with self._mtx:
            self._endpoints = available_grpc_hosts
            self._last_resolve = datetime.datetime.now()

    def _get_grpc_endpoint(self) -> str:
        """
        Pick the address to connect to.

        :return: ``cluster.endpoint`` when configured, otherwise a random entry of the
            cached endpoint list.
        :raises RuntimeError: when no endpoint could be resolved after all attempts.
        """
        if self.cluster.endpoint:
            return self.cluster.endpoint

        with self._mtx:
            # Block on service discovery only while nothing is cached. Resolving on every
            # call would keep _last_resolve permanently fresh and the background refresh
            # below would never be scheduled.
            if not self._endpoints:
                for attempt in range(self._RESOLVE_ATTEMPTS):
                    self._endpoints = self._do_resolve_endpoints()
                    self._last_resolve = datetime.datetime.now()
                    if self._endpoints:
                        break
                    # NB: poor man's retry backoff; pointless after the final attempt
                    if attempt + 1 < self._RESOLVE_ATTEMPTS:
                        time.sleep(self._RESOLVE_RETRY_DELAY)
                if not self._endpoints:
                    raise RuntimeError("Failed to locate Tasklets service")

            if datetime.datetime.now() - self._last_resolve > self._REFRESH_INTERVAL:
                _ = self._resolver_tp.submit(self._resolve_endpoints)
                # Bump the timestamp right away so subsequent calls do not queue more refreshes.
                self._last_resolve = datetime.datetime.now()

            # NB: using old endpoint list
            return random.choice(self._endpoints)

    def _get_channel(self) -> grpc.Channel:
        """Open an insecure gRPC channel to the endpoint chosen by ``_get_grpc_endpoint``."""
        grpc_host = self._get_grpc_endpoint()
        options = []

        # TODO: add ssl for secure channel
        # composite_credentials = grpc.composite_channel_credentials(
        #     grpc.ssl_channel_credentials(),
        #     grpc.metadata_call_credentials(OAuthMetadataPlugin(token))
        # )
        # channel = grpc.secure_channel(grpc_host, credentials=composite_credentials)
        # TODO: add wrapper for method calls, error handling and auth

        return grpc.insecure_channel(grpc_host, options)

    def get_tasklet_client(self) -> tasklet_service_grpc.TaskletServiceStub:
        """Return a ``TaskletServiceStub`` bound to a newly created channel."""
        channel = self._get_channel()
        return tasklet_service_grpc.TaskletServiceStub(channel)

    def get_schema_registry_client(self) -> schema_registry_service_grpc.SchemaRegistryServiceStub:
        """Return a ``SchemaRegistryServiceStub`` bound to a newly created channel."""
        channel = self._get_channel()
        return schema_registry_service_grpc.SchemaRegistryServiceStub(channel)
