# coding=utf-8
from typing import Callable, Optional

import copy
import enum
import os
import pathlib
import logging

import click
import grpc
import sys
import yaml
import inject

from google.protobuf import descriptor_pb2
from google.protobuf import struct_pb2

import tasklet.api.v2.schema_registry_service_pb2 as schema_registry

from tasklet.experimental.cli import interfaces
from tasklet.experimental.cli import consts
from tasklet.experimental.cli import driver
from tasklet.experimental.cli import tasklet_descriptor

log = logging.getLogger(__name__)


class TaskletContext(interfaces.ITaskletContext):
    """
    Tasklet service context

    Handles drivers and RPC: keeps the selected cluster and output settings,
    lazily creates the driver, resolves the auth token and wraps every RPC with
    request metadata and uniform CLI error reporting.
    """

    # Environment variables with this prefix are forwarded as `tasklet-feature-<name>` request headers.
    _DEV_HEADER_PREFIX = "TASKLET_INTERNAL_"

    def __init__(self):
        self.cluster: Optional[consts.ClusterConfig] = None
        self.output: consts.Output = consts.Output.TABLE
        self.quiet: bool = False  # when True, dump_proto_message prints nothing
        self.config: Optional["tasklet_descriptor.TaskletDescriptor"] = None
        self._token: Optional[str] = None  # cached by `token` once resolved
        self._driver: Optional[driver.Driver] = None  # created lazily by `driver`
        self._auth: bool = True  # when False, no token is resolved or sent

    @property
    def token(self) -> str:
        """
        Token used to authorize RPC calls; a non-empty result is cached.

        Lookup order:
        1. the ``TASKLET_TOKEN`` environment variable (used as-is, even if empty);
        2. the contents of ``~/.tasklet/token`` if the file is readable and non-empty;
        3. a token generated by ``library.python.oauth`` (via ssh-agent).

        :return: the token, or an empty string when auth is disabled.
        """
        if not self._auth:
            log.debug("Auth disabled")
            return ""

        if self._token:
            return self._token

        if "TASKLET_TOKEN" in os.environ:
            log.debug("Get token from environment")
            self._token = os.environ["TASKLET_TOKEN"]
            return self._token

        user_token_path = os.path.expanduser("~/.tasklet/token")
        if os.path.exists(user_token_path):
            log.debug(f"Get token from {user_token_path}")
            try:
                with open(user_token_path, "r", encoding="utf-8") as f:
                    file_token = f.read().strip()
            except (OSError, UnicodeError) as err:
                log.debug(f"Failed to load token from {user_token_path}: {err}")
            else:
                if file_token:
                    self._token = file_token
                    return self._token
                # An empty file is not a usable token: fall back to OAuth instead.
                log.debug(f"Token file {user_token_path} is empty")

        log.debug("Generating OAuth token via ssh-agent")
        from library.python import oauth

        self._token = oauth.get_token(consts.CLIENT_ID, consts.CLIENT_SECRET)
        return self._token

    @property
    def driver(self) -> driver.Driver:
        """Driver for the selected cluster, created on first access."""
        if self._driver is None:
            self._driver = driver.Driver(self.cluster)
        return self._driver

    @property
    # TODO: remove, temporary implementation
    def _metadata(self):
        """
        Auth metadata attached to every request.

        :return: list of ``(header, value)`` pairs; empty when auth is disabled.
        """
        if not self._auth:
            log.debug("Auth disabled")
            return []
        return [
            (
                driver.OAuthMetadataPlugin.AUTH_HEADER,
                f"{driver.OAuthMetadataPlugin.AUTH_METHOD} {self.token}",
            )
        ]

    @staticmethod
    def _parse_rich_details(err: grpc.Call) -> list:
        """
        Extract human-readable field violations from a rich gRPC error status.

        :return: ``"<field>: <description>"`` strings from the first ``BadRequest`` detail,
            or an empty list when the call carries no rich status or no such detail.
        """
        from grpc_status import rpc_status
        from google.rpc import status_pb2, error_details_pb2

        status: status_pb2.Status = rpc_status.from_call(err)
        if status is None:
            return []
        for detail in status.details:
            if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR):
                info = error_details_pb2.BadRequest()
                detail.Unpack(info)
                return [f"{fv.field}: {fv.description}" for fv in info.field_violations]
        return []

    # Schema registry methods
    def register_schema(
        self, fds: descriptor_pb2.FileDescriptorSet, annotations: Optional[dict]
    ) -> schema_registry.CreateSchemaResponse:
        """
        Register a schema in the schema registry.

        :param fds: file descriptor set describing the schema.
        :param annotations: JSON-like dict stored as a ``google.protobuf.Struct``; ``None`` means no annotations.
        :return: the registry response.
        """
        from google.protobuf import json_format

        pb_annotations = struct_pb2.Struct()
        json_format.ParseDict(annotations or {}, pb_annotations)

        req = schema_registry.CreateSchemaRequest(
            schema=fds,
            annotations=pb_annotations,
        )

        client = self.driver.get_schema_registry_client()

        response: schema_registry.CreateSchemaResponse = self.execute_request(client.CreateSchema, req)
        return response

    def get_schema(self, fds_hash: str) -> schema_registry.GetSchemaResponse:
        """Fetch a registered schema by its hash."""
        req = schema_registry.GetSchemaRequest(
            hash=fds_hash,
        )

        client = self.driver.get_schema_registry_client()

        response: schema_registry.GetSchemaResponse = self.execute_request(client.GetSchema, req)
        return response

    def execute_request(self, method, message, **kwargs):
        """
        Call ``method(message, **kwargs)`` with auth and dev metadata appended.

        Caller-supplied ``metadata`` (list, tuple or ``None``) is preserved and extended
        with the auth header, one ``tasklet-feature-<name>`` header per
        ``TASKLET_INTERNAL_<NAME>`` environment variable, and ``x-test-user`` when
        ``TASKLET_USER`` is set.

        gRPC errors that implement ``grpc.Call`` are reported to stderr and terminate
        the process (exit code 0 for ALREADY_EXISTS, 1 otherwise); any other
        ``grpc.RpcError`` is re-raised.

        :param method: RPC callable, e.g. a client stub method.
        :param message: request message.
        :return: whatever ``method`` returns.
        """
        # `kwargs` is already a fresh dict for this call; only `metadata` is replaced.
        kwargs["metadata"] = self._request_metadata(kwargs.get("metadata"))
        try:
            return method(message, **kwargs)
        except grpc.RpcError as error:
            if isinstance(error, grpc.Call):
                self._report_rpc_error(error)
            raise

    def _request_metadata(self, metadata=None) -> list:
        """Build request metadata: caller pairs, then auth, dev-feature and test-user headers."""
        result = list(metadata or ()) + self._metadata
        for key, value in os.environ.items():
            if not key.startswith(self._DEV_HEADER_PREFIX):
                continue
            # NB: dev zone, ok to crash
            feature = key[len(self._DEV_HEADER_PREFIX):].lower()
            if not feature:
                raise ValueError(f"Environment variable {key!r} has no feature name after the prefix")
            result.append(("tasklet-feature-" + feature, value))

        test_user = os.environ.get("TASKLET_USER", "")
        if test_user:
            result.append(("x-test-user", test_user))
        return result

    def _report_rpc_error(self, error: grpc.Call) -> None:
        """Print a failed call's status to stderr and exit the process; never returns."""
        error_code, error_msg = error.code(), error.details()
        if error_code in (
            grpc.StatusCode.PERMISSION_DENIED, grpc.StatusCode.ABORTED, grpc.StatusCode.NOT_FOUND
        ):
            click.echo(
                f"{click.style('Failed:', fg='red')} Request failed with code {error_code}: {error_msg}",
                err=True,
            )
            sys.exit(1)
        if error_code == grpc.StatusCode.ALREADY_EXISTS:
            # Not fatal for the CLI: warn and exit successfully.
            click.echo(
                f"{click.style('Warn:', fg='yellow')} Request failed with code {error_code}: {error_msg}",
                err=True,
            )
            sys.exit(0)

        click.echo(
            f"{click.style('Error:', fg='red')} Request failed with code {error_code}: {error_msg}",
            err=True,
        )
        # Field violations (e.g. for INVALID_ARGUMENT) are listed under the error, on stderr too.
        for detail in self._parse_rich_details(error):
            click.echo(" " * 4 + detail, err=True)
        sys.exit(1)

    def dump_proto_message(self, message):
        """
        Print a protobuf message in the configured output format.

        Prints nothing when ``quiet`` is set. TABLE is rendered by ``pretty_dumps``
        (on failure the message is reported to stderr as a dict and the error re-raised);
        YAML and JSON are produced via ``json_format``; any other value is treated as JSON.

        :type message: google.protobuf.message.Message
        """
        from google.protobuf import json_format

        if self.quiet:
            return
        if self.output == consts.Output.TABLE:
            from tasklet.experimental.cli import pretty_dumps

            try:
                result = pretty_dumps.get_table_for_message(message).get_string()
            except Exception:
                log.error(f"Failed to dump message of type {type(message)}", exc_info=True)
                click.echo(
                    f"Failed to dump message of type {type(message)}: {json_format.MessageToDict(message)}",
                    err=True,
                )
                raise
        elif self.output == consts.Output.YAML:
            result = yaml.safe_dump(json_format.MessageToDict(message), default_flow_style=False)
        else:  # self.output == consts.Output.JSON
            result = json_format.MessageToJson(message)
        click.echo(result)


# NB: Pass generic options like --cluster and --output to all subcommands
class ContextOption(click.Option):
    """Marker option type: GroupWithContextOptions copies options of this type onto every subcommand it registers."""


class GroupWithContextOptions(click.Group):
    """
    Click group that propagates its ``ContextOption`` parameters to registered subcommands.

    Only options already declared on the group when ``add_command`` runs are copied.
    """

    def add_command(self, cmd, name=None):
        super().add_command(cmd, name)
        # Skip options the command already carries, so registering the same command
        # again (e.g. under another name or in another such group) does not duplicate them.
        inherited = [
            param for param in self.params
            if isinstance(param, ContextOption) and param not in cmd.params
        ]
        cmd.params.extend(inherited)


def option_cluster(func: Callable):
    """
    Decorator adding the ``-c/--cluster`` context option to a click command.

    The option value is not passed to the command; its callback stores the selected
    cluster config on the injected ``ITaskletContext``. Without a value, PROD is used
    unless a cluster is already configured.
    """
    def callback(ctx: click.Context, param: click.Parameter, value: str):
        # noinspection PyTypeChecker
        tasklet_context: TaskletContext = inject.instance(interfaces.ITaskletContext)

        if not value:
            if tasklet_context.cluster is not None:
                # Already configured (e.g. by the same option on a parent group): keep it.
                return
            value = consts.PROD.name  # default

        if tasklet_context.cluster is not None and tasklet_context.cluster.name == consts.UNITTEST.name:
            # Cluster could be already defined in tests. Do not override it in that case.
            return

        if value == consts.ENV.name:
            endpoint = os.environ.get(consts.TASKLET_ENDPOINT_ENV_VAR)
            if not endpoint:
                # Report a usage error instead of crashing with a raw KeyError.
                raise click.BadParameter(
                    f"environment variable {consts.TASKLET_ENDPOINT_ENV_VAR} must be set to use this cluster",
                    ctx=ctx,
                    param=param,
                )
            tasklet_context.cluster = consts.ClusterConfig("env", endpoint=endpoint)
        else:
            tasklet_context.cluster = consts.CLUSTER_CONFIGS[value]

    return click.option(
        "-c", "--cluster",
        cls=ContextOption,
        type=click.Choice(consts.CLUSTER_CONFIGS),
        expose_value=False,
        callback=callback,
        help="Server instance.",
    )(func)


def option_output(func: Callable):
    """
    Decorator adding the ``-o/--output`` context option to a click command.

    The value is not passed to the command; a given format is stored on the injected
    ``ITaskletContext``, while an empty value leaves the current format unchanged.
    """
    def store_output(ctx: click.Context, param: click.Parameter, value: consts.Output):
        del ctx, param  # required by click's callback signature, unused here
        # noinspection PyTypeChecker
        tasklet_context: TaskletContext = inject.instance(interfaces.ITaskletContext)
        if not value:
            return
        tasklet_context.output = value

    decorate = click.option(
        "-o",
        "--output",
        cls=ContextOption,
        type=click.Choice(consts.Output),
        expose_value=False,
        callback=store_output,
        help="Output format.",
    )
    return decorate(func)


# Tasklet argument helpers

class TaskletArgumentType(enum.IntEnum):
    """Which tasklet coordinates a command reads from t.yaml and/or its positional argument."""

    NONE = 0  # Use t.yaml only
    NAMESPACE = 1  # Get namespace from t.yaml or argument
    TASKLET = 2  # As for NAMESPACE + get tasklet name from t.yaml or argument
    LABEL = 3  # As for TASKLET + get label from argument
    BUILD = 4  # As for TASKLET + get build_id from argument
    OPTIONAL_BUILD = 5  # As for TASKLET + get optional build_id from argument

    @property
    def metavar(self) -> str:
        """
        Usage placeholder for the positional target argument.

        :raises KeyError: for ``NONE``, which has no positional argument.
        """
        members = type(self)
        placeholders = {
            members.NAMESPACE: "<namespace>",
            members.TASKLET: "<namespace>/<tasklet>",
            members.LABEL: "[<namespace>/<tasklet>:]<label>",
            members.BUILD: "[<namespace>/<tasklet>:]<build_id>",
            members.OPTIONAL_BUILD: "[<namespace>/<tasklet>:<build_id>]",
        }
        return placeholders[self]


def tasklet_descriptor_argument(
    func: Optional[Callable] = None, *, argument_type: TaskletArgumentType = TaskletArgumentType.NONE
):
    """
    Decorator wiring tasklet coordinates into a click command.

    Always adds ``-f/--file <t.yaml>`` (default ``./t.yaml``); unless ``argument_type``
    is ``NONE`` it also adds an optional positional target argument. Neither value is
    passed to the command directly: the callbacks fill ``ctx.params["namespace"]``,
    ``ctx.params["tasklet"]`` and, for LABEL / BUILD / OPTIONAL_BUILD,
    ``ctx.params["label"]`` or ``ctx.params["build_id"]``. Values given in the argument
    take precedence over t.yaml; t.yaml is loaded for ``NONE`` or when namespace/tasklet
    are not yet known when its callback runs.

    Usable both as ``@tasklet_descriptor_argument`` and
    ``@tasklet_descriptor_argument(argument_type=...)``.
    """
    def tasklet_descriptor_argument_wrapper(target: Callable):
        # Callback evaluation order is user-dependent. Both config and tasklet arguments affect the same
        #   parameters ('namespace' and 'tasklet'), so they use ternary logic to determine status of parameter:
        # 1. if <parameter> not in ctx.params - callback wasn't called yet
        # 2. else if ctx.params[<parameter>] is None - callback was called, but no value was set
        # 3. else - callback was called, value is set
        def config_argument(ctx: click.Context, param: click.Parameter, value: pathlib.Path):
            _ = param
            # noinspection PyTypeChecker
            config: "tasklet_descriptor.TaskletDescriptor" = inject.instance(interfaces.ITaskletDescriptor)

            # Load t.yaml only if it can contribute: always for NONE, otherwise while
            # namespace/tasklet are still unset or empty.
            if (
                argument_type == TaskletArgumentType.NONE
                or (argument_type >= TaskletArgumentType.NAMESPACE and not ctx.params.get("namespace"))
                or (argument_type >= TaskletArgumentType.TASKLET and not ctx.params.get("tasklet"))
            ):
                log.debug(f"Loading t.yaml from {value.absolute()}")
                config.load_from_yaml(value.absolute())
                state = config.state
                # Fill only what the tasklet argument has not provided.
                if argument_type >= TaskletArgumentType.NAMESPACE and ctx.params.get("namespace") is None:
                    ctx.params["namespace"] = state.meta.namespace
                if argument_type >= TaskletArgumentType.TASKLET and ctx.params.get("tasklet") is None:
                    ctx.params["tasklet"] = state.meta.name

        def tasklet_argument(ctx: click.Context, param: click.Parameter, value: Optional[str]):
            _ = param
            # There is no need to check namespace/tasklet values:
            #   it is either set by a fresh value here or (will be) set by the config argument
            value = value or ""
            if argument_type > TaskletArgumentType.TASKLET:
                if argument_type in (TaskletArgumentType.LABEL, TaskletArgumentType.BUILD):
                    # Split on the last ':'; a bare "<label>"/"<build_id>" is accepted,
                    # but the suffix itself is mandatory.
                    value, _, label_or_build = value.rpartition(":")
                    if not label_or_build:
                        raise click.BadArgumentUsage(f"Target should be defined in format '{argument_type.metavar}'")
                else:  # argument_type == TaskletArgumentType.OPTIONAL_BUILD
                    # Split on the first ':'; without one, the whole value is the tasklet path.
                    value, _, label_or_build = value.partition(":")
                parameter_name = "label" if argument_type == TaskletArgumentType.LABEL else "build_id"
                ctx.params[parameter_name] = label_or_build or None
            if argument_type >= TaskletArgumentType.TASKLET:
                # "<namespace>/<tasklet>" or a bare "<tasklet>" (namespace is then left to t.yaml).
                value, _, tasklet = value.rpartition("/")
                ctx.params["tasklet"] = tasklet or ctx.params.get("tasklet")
            if argument_type >= TaskletArgumentType.NAMESPACE:
                ctx.params["namespace"] = value or ctx.params.get("namespace")
            # TODO: print coordinates

        if argument_type != TaskletArgumentType.NONE:
            target = click.argument(
                "tasklet_name", metavar=argument_type.metavar, expose_value=False, required=False,
                callback=tasklet_argument
            )(target)

        return click.option(
            "-f", "--file", "config", metavar="<t.yaml>",
            type=click.Path(path_type=pathlib.Path),
            expose_value=False, default=pathlib.Path("./t.yaml"), callback=config_argument,
            help="Path to tasklet descriptor (t.yaml)."
        )(target)

    if func:
        return tasklet_descriptor_argument_wrapper(func)

    return tasklet_descriptor_argument_wrapper
