package tvm

import (
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	tvmutil "a.yandex-team.ru/library/go/httputil/middleware/tvm"
	yatvm "a.yandex-team.ru/library/go/yandex/tvm"
	"context"
	"fmt"
	"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

type (
	// Option configures the interceptor built by CheckServiceTicketInterceptor.
	Option func(*interceptor)

	// interceptor holds the configuration shared by every call to wrap.
	interceptor struct {
		// l receives a record for every call rejected by the ticket check.
		l log.Logger

		// tvm is used to validate incoming service tickets.
		tvm yatvm.Client

		// whiteList is the set of TVM client IDs allowed through; a nil or
		// empty set rejects every caller.
		whiteList map[yatvm.ClientID]struct{}
	}
)

// CheckServiceTicketInterceptor returns a gRPC unary server interceptor that
// requires every incoming call to carry, in the incoming metadata under the
// tvmutil.XYaServiceTicket key, a service ticket that tvmClient accepts and
// whose source client ID is on the allow-list.
//
// The allow-list is configured with WithAllowedClients. If that option is not
// supplied, the list is empty and every call is rejected with
// codes.Unauthenticated. Rejections are logged through the logger set by
// WithLogger; by default a no-op logger is used.
func CheckServiceTicketInterceptor(tvmClient yatvm.Client, opts ...Option) grpc.UnaryServerInterceptor {
	m := &interceptor{
		tvm: tvmClient,
		l:   &nop.Logger{},
	}
	for _, opt := range opts {
		opt(m)
	}
	// WithLogger(nil) would otherwise leave m.l nil, and the first rejected
	// call would panic inside wrap. Fall back to the silent default instead.
	if m.l == nil {
		m.l = &nop.Logger{}
	}
	return m.wrap
}

// wrap is the grpc.UnaryServerInterceptor produced by
// CheckServiceTicketInterceptor. It reads the service ticket from the incoming
// metadata and validates it. Only when that check passes is handler invoked.
// On failure, the detailed error goes to the logger, and the caller receives
// just a generic Unauthenticated status.
func (m *interceptor) wrap(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	ticket := metautils.ExtractIncoming(ctx).Get(tvmutil.XYaServiceTicket)
	if checkErr := checkServiceTicket(m.tvm, m.whiteList, ctx, ticket); checkErr != nil {
		m.l.Error("tvm error", log.Error(checkErr))
		return nil, status.Error(codes.Unauthenticated, "tvm error")
	}
	return handler(ctx, req)
}

// checkServiceTicket validates serviceTicket with tvmClient and verifies that
// the ticket's source client ID is present in whiteList. It returns nil when
// the call may proceed. Otherwise it returns an error in one of three cases:
//   - the ticket is empty;
//   - tvmClient rejects the ticket (the underlying error is wrapped);
//   - the source ID is missing from whiteList.
//
// A nil whiteList therefore rejects every ticket.
func checkServiceTicket(tvmClient yatvm.Client, whiteList map[yatvm.ClientID]struct{}, ctx context.Context, serviceTicket string) error {
	if len(serviceTicket) == 0 {
		return fmt.Errorf("missing service ticket")
	}

	checked, checkErr := tvmClient.CheckServiceTicket(ctx, serviceTicket)
	if checkErr != nil {
		return fmt.Errorf("service ticket check failed: %w", checkErr)
	}

	src := checked.SrcID
	if _, allowed := whiteList[src]; allowed {
		return nil
	}
	return fmt.Errorf("client %v authorization failed", src)
}

// WithLogger returns an Option that makes the interceptor report rejected
// calls to l instead of the default no-op logger.
func WithLogger(l log.Logger) Option {
	var setLogger Option = func(target *interceptor) {
		target.l = l
	}
	return setLogger
}

// WithAllowedClients returns an Option that sets the TVM client IDs permitted
// to call the server. Each time the option is applied, the set is rebuilt from
// allowedClients. If the option is passed more than once, only the last
// occurrence takes effect. An empty or nil slice permits no clients.
func WithAllowedClients(allowedClients []yatvm.ClientID) Option {
	return func(m *interceptor) {
		// The final size is known, so allocate the set once up front.
		allowed := make(map[yatvm.ClientID]struct{}, len(allowedClients))
		for _, id := range allowedClients {
			allowed[id] = struct{}{}
		}
		m.whiteList = allowed
	}
}
