package apiserver

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"reflect"
	"strings"
	"time"

	"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
	uberZap "go.uber.org/zap"
	"golang.org/x/exp/slices"
	"google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/library/go/httputil/headers"
	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/library/go/yandex/blackbox/httpbb"
	"a.yandex-team.ru/library/go/yandex/tvm/tvmtool"
	"a.yandex-team.ru/tasklet/experimental/internal/consts"
	"a.yandex-team.ru/tasklet/experimental/internal/requestctx"
	"a.yandex-team.ru/tasklet/experimental/internal/xgrpc"
	"a.yandex-team.ru/tasklet/experimental/internal/yandex/sandbox"
)

// Exported constants shared with the rest of the API server.
const (
	// RequestIDLogField is the log field that carries the request ID.
	RequestIDLogField = "request_id"
	// TestDefaultUser is the subject assumed when auth is disabled and the
	// request names neither a test user nor a test execution.
	TestDefaultUser = "fake_default_user"
	// GrpcGatewayHeaderPrefix prefixes metadata keys holding HTTP headers
	// forwarded by grpc-gateway.
	GrpcGatewayHeaderPrefix = "grpcgateway-"
	// TaskletFeaturePrefix prefixes metadata keys that carry feature flags.
	TaskletFeaturePrefix = "tasklet-feature-"
)

// Authorization header schemes: the auth method name followed by one space.
const (
	oauthPrefix                  = string(consts.OAuthMethod) + " "
	sandboxSessionPrefix         = string(consts.SandboxSessionMethod) + " "
	sandboxExternalSessionPrefix = string(consts.ExternalSandboxSessionMethod) + " "
)

var (
	// logOmitMethods lists full gRPC method names whose request and response
	// payloads are never logged.
	logOmitMethods = []string{
		"/tasklet.api.v2.SchemaRegistryService/CreateSchema",
		"/tasklet.api.v2.SchemaRegistryService/GetSchema",
		"/tasklet.api.priv.v1.InternalService/GetExternalSession",
	}

	// errMissingMetadata is returned when a call has no incoming gRPC metadata.
	errMissingMetadata = status.Errorf(codes.InvalidArgument, "missing metadata")
	// stAuthFailed reports rejected or absent credentials.
	stAuthFailed = status.New(codes.Unauthenticated, "authentication failed")
	// stAuthError reports malformed credentials or a failure while checking them.
	stAuthError = status.New(codes.Unauthenticated, "authentication error")
)

// setStatus returns st as an error, with err's message attached as an
// errdetails.ErrorInfo detail (in its Reason field).
//
// The result is never nil. If the detail cannot be attached, st's own error is
// returned without details. This matters because callers return the result as
// the authentication error: a nil here would let an unauthenticated request
// through.
func setStatus(st *status.Status, err error) error {
	if err == nil {
		return st.Err()
	}
	// ErrorInfo is a proto3 message, so Reason must be valid UTF-8 or
	// marshalling the detail fails; error text may contain arbitrary bytes.
	withDetails, detailsErr := st.WithDetails(
		&errdetails.ErrorInfo{
			Reason: strings.ToValidUTF8(err.Error(), "\uFFFD"),
		},
	)
	if detailsErr != nil {
		// WithDetails returns a nil *Status on failure, and nil.Err() is nil.
		return st.Err()
	}
	return withDetails.Err()
}

// middleware holds the state shared by the API server's gRPC interceptor
// methods (authentication, feature flags, payload logging).
type middleware struct {
	// conf is the middleware configuration (Auth, LogRequests and
	// LogRequestsFancy are read by the methods in this file).
	conf                  *MiddlewareConf
	// logger is the base logger for interceptor messages.
	logger                log.Logger
	// bb is the blackbox client; it stays nil when conf.Auth is false.
	bb                    blackbox.Client
	// marshalOptions control how logPayload renders protobuf payloads as JSON.
	marshalOptions        protojson.MarshalOptions
	// sandboxSessionChecker validates sandbox task and external session credentials.
	sandboxSessionChecker SandboxSessionChecker
}

// NewMiddleware builds the interceptor set from configuration c, logging
// through l and validating sandbox sessions with sbSessionChecker. When c.Auth
// is enabled a blackbox client is created as well; failure to create it is
// returned as an error.
func NewMiddleware(c *MiddlewareConf, l log.Logger, sbSessionChecker SandboxSessionChecker) (*middleware, error) {
	var marshalOpts protojson.MarshalOptions
	if c.LogRequestsFancy {
		// Indented, multi-line JSON with enum names for human readers.
		marshalOpts = protojson.MarshalOptions{
			Multiline:      true,
			Indent:         "  ",
			UseEnumNumbers: false,
		}
	}

	mw := &middleware{
		conf:                  c,
		logger:                l,
		marshalOptions:        marshalOpts,
		sandboxSessionChecker: sbSessionChecker,
	}
	if err := mw.initBlackBox(); err != nil {
		return nil, err
	}
	return mw, nil
}

// initBlackBox creates the intranet blackbox client used by the OAuth and
// Session_id authenticators and stores it in mw.bb. It does nothing (leaving
// mw.bb nil) when authentication is disabled. The client is configured with a
// tvmtool deploy client that has caching enabled.
//
// Both construction errors are wrapped, so a startup failure names the
// dependency that could not be created.
func (mw *middleware) initBlackBox() error {
	if !mw.conf.Auth {
		return nil
	}

	tvmClient, err := tvmtool.NewDeployClient(
		tvmtool.WithCacheEnabled(true),
		tvmtool.WithLogger(mw.logger.WithName("tvmtool").Structured()),
	)
	if err != nil {
		return xerrors.Errorf("failed to create tvmtool client: %w", err)
	}

	bb, err := httpbb.NewIntranet(
		httpbb.WithLogger(mw.logger.WithName("blackbox").Structured()),
		httpbb.WithTVM(tvmClient),
	)
	if err != nil {
		return xerrors.Errorf("failed to create blackbox client: %w", err)
	}
	mw.bb = bb
	return nil
}

// authWithSessionID validates the Session_id cookie with blackbox for the
// client at userIP and the requested host. On success the subject is the
// blackbox user's login. A blackbox "unauthorized" answer maps to stAuthFailed;
// any other blackbox error is logged and maps to stAuthError.
func (mw *middleware) authWithSessionID(ctx context.Context, sessID *http.Cookie, userIP string, host string) (
	requestctx.AuthSubject,
	error,
) {
	req := blackbox.SessionIDRequest{
		SessionID: sessID.Value,
		UserIP:    userIP,
		Host:      host,
	}
	resp, err := mw.bb.SessionID(ctx, req)

	switch {
	case err == nil:
		return requestctx.NewUser(resp.User.Login), nil
	case blackbox.IsUnauthorized(err):
		return requestctx.NewInvalid(), setStatus(stAuthFailed, err)
	default:
		ctxlog.Error(ctx, mw.logger, "cannot authorize user", log.Error(err))
		return requestctx.NewInvalid(), setStatus(stAuthError, err)
	}
}

// authBB validates an OAuth token with blackbox, passing peerAddress as the
// user IP. On success the subject is the blackbox user's login. A blackbox
// "unauthorized" answer maps to stAuthFailed; any other blackbox error is
// logged and maps to stAuthError.
func (mw *middleware) authBB(ctx context.Context, oauthToken string, peerAddress string) (
	requestctx.AuthSubject,
	error,
) {
	resp, err := mw.bb.OAuth(ctx, blackbox.OAuthRequest{
		OAuthToken: oauthToken,
		UserIP:     peerAddress,
	})
	if err == nil {
		return requestctx.NewUser(resp.User.Login), nil
	}

	if blackbox.IsUnauthorized(err) {
		return requestctx.NewInvalid(), setStatus(stAuthFailed, err)
	}
	ctxlog.Error(ctx, mw.logger, "cannot authorize user", log.Error(err))
	return requestctx.NewInvalid(), setStatus(stAuthError, err)
}

// authSandboxTaskSession validates a sandbox task session token with the
// sandbox session checker. On success the subject is the session's sandbox
// task (sessionInfo.TaskID). sandbox.ErrSandboxNotFound maps to stAuthFailed;
// any other checker error is logged and maps to stAuthError.
func (mw *middleware) authSandboxTaskSession(ctx context.Context, session string) (
	requestctx.AuthSubject,
	error,
) {
	sessionInfo, err := mw.sandboxSessionChecker.CheckSandboxSession(ctx, sandbox.SandboxSession(session))
	if err != nil {
		if xerrors.Is(err, sandbox.ErrSandboxNotFound) {
			return requestctx.NewInvalid(), setStatus(stAuthFailed, err)
		}
		ctxlog.Error(ctx, mw.logger, "cannot check sandbox task session", log.Error(err))
		return requestctx.NewInvalid(), setStatus(stAuthError, err)
	}
	return requestctx.NewSandboxTask(sessionInfo.TaskID), nil
}

// authSandboxExternalSession validates a sandbox external session token with
// the sandbox session checker. On success the subject is the execution the
// session belongs to (sessionInfo.ExecutionID). sandbox.ErrSandboxNotFound maps
// to stAuthFailed; any other checker error is logged and maps to stAuthError.
func (mw *middleware) authSandboxExternalSession(ctx context.Context, session string) (
	requestctx.AuthSubject,
	error,
) {
	sessionInfo, err := mw.sandboxSessionChecker.CheckExternalSession(ctx, sandbox.SandboxExternalSession(session))
	if err != nil {
		if xerrors.Is(err, sandbox.ErrSandboxNotFound) {
			return requestctx.NewInvalid(), setStatus(stAuthFailed, err)
		}
		// The message names the external session so it is distinguishable
		// from task-session failures in the logs.
		ctxlog.Error(ctx, mw.logger, "cannot check sandbox external session", log.Error(err))
		return requestctx.NewInvalid(), setStatus(stAuthError, err)
	}
	return requestctx.NewExecutionID(sessionInfo.ExecutionID), nil
}

// authWithToken dispatches an Authorization header value on its scheme prefix.
// OAuth tokens are checked with blackbox, using peerAddress as the user IP.
// Sandbox task and external session tokens are checked with the sandbox
// session checker. The part after the prefix is whitespace-trimmed.
//
// An unknown scheme fails with stAuthError, and the error detail lists every
// accepted prefix.
func (mw *middleware) authWithToken(ctx context.Context, authorization string, peerAddress string) (
	requestctx.AuthSubject,
	error,
) {
	switch {
	case strings.HasPrefix(authorization, oauthPrefix):
		return mw.authBB(ctx, strings.TrimSpace(authorization[len(oauthPrefix):]), peerAddress)
	case strings.HasPrefix(authorization, sandboxSessionPrefix):
		return mw.authSandboxTaskSession(ctx, strings.TrimSpace(authorization[len(sandboxSessionPrefix):]))
	case strings.HasPrefix(authorization, sandboxExternalSessionPrefix):
		return mw.authSandboxExternalSession(ctx, strings.TrimSpace(authorization[len(sandboxExternalSessionPrefix):]))
	default:
		// The header value itself is deliberately not logged or echoed: it may
		// contain a credential.
		ctxlog.Error(ctx, mw.logger, "invalid authorization header format")
		return requestctx.NewInvalid(), setStatus(
			stAuthError,
			xerrors.Errorf(
				"invalid auth header prefix, expected one of %q, %q, %q",
				oauthPrefix, sandboxSessionPrefix, sandboxExternalSessionPrefix,
			),
		)
	}
}

// auth resolves the caller of a gRPC request to an AuthSubject.
//
// When authentication is disabled (conf.Auth is false), no credentials are
// checked. A single xgrpc.TestUserMetadataKey value selects a user; otherwise a
// single xgrpc.TestExecutionMetadataKey value selects an execution; otherwise
// TestDefaultUser is used.
//
// Otherwise the first credential source present wins:
//  1. headers.AuthorizationKey metadata (direct gRPC clients), checked against
//     peerAddress;
//  2. the same header forwarded by grpc-gateway
//     (GrpcGatewayHeaderPrefix + headers.AuthorizationKey), checked against
//     the first forwarded-for address if any, else peerAddress;
//  3. a "Session_id" cookie forwarded by grpc-gateway, checked with the
//     forwarded client address and requested host.
//
// Multiple authorization values fail with stAuthError. Missing credentials fail
// with stAuthFailed. Every failure path returns requestctx.NewInvalid() as the
// subject.
func (mw *middleware) auth(ctx context.Context, md metadata.MD, peerAddress string) (requestctx.AuthSubject, error) {
	if !mw.conf.Auth {
		if v := md.Get(xgrpc.TestUserMetadataKey); len(v) == 1 {
			return requestctx.NewUser(v[0]), nil
		}
		if e := md.Get(xgrpc.TestExecutionMetadataKey); len(e) == 1 {
			return requestctx.NewExecutionID(consts.ExecutionID(e[0])), nil
		}
		return requestctx.NewUser(TestDefaultUser), nil
	}

	// getRemoteAddress returns the client address of a proxied request: the
	// first forwarded-for entry if present, otherwise defaultAddress.
	getRemoteAddress := func(md metadata.MD, defaultAddress string) string {
		forwarded := md.Get(consts.ForwardedForHeader)
		if len(forwarded) == 0 {
			return defaultAddress
		}
		// NB: special handling for grpc-gateway. Generated code snippet:
		// 	if addr := req.RemoteAddr; addr != "" {
		//		if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
		//			if fwd := req.Header.Get(xForwardedFor); fwd == "" {
		//				pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
		//			} else {
		//				pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
		//			}
		//		}
		//	}
		// TODO: check request came from grpc-gateway?
		forwardedAddress := forwarded[0]
		if strings.Contains(forwardedAddress, ",") {
			forwardedAddress = strings.TrimSpace(strings.SplitN(forwardedAddress, ",", 2)[0])
		}
		return forwardedAddress
	}

	// getRequestedHost returns the single forwarded host if present, otherwise defaultHost.
	getRequestedHost := func(md metadata.MD, defaultHost string) string {
		host := defaultHost
		if hosts := md.Get(consts.ForwardedHostHeader); len(hosts) == 1 {
			host = hosts[0]
		}
		// NB: dirty hack for UI failing to provide correct host
		if net.ParseIP(strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")) != nil {
			host = "tasklets.in.yandex-team.ru"
		}
		return host
	}

	// gatewayHTTPRequest rebuilds the HTTP headers grpc-gateway forwarded as
	// "grpcgateway-<name>" metadata, so net/http can parse their cookies.
	// NB: https://github.com/grpc/grpc-go/issues/1174#issuecomment-292923366
	gatewayHTTPRequest := func(md metadata.MD) *http.Request {
		gatewayPrefix := strings.ToLower(GrpcGatewayHeaderPrefix)
		req := &http.Request{Header: http.Header{}}
		for k, values := range md {
			lowerKey := strings.ToLower(k)
			if !strings.HasPrefix(lowerKey, gatewayPrefix) {
				continue
			}
			restoredHeader := strings.TrimPrefix(lowerKey, gatewayPrefix)
			for _, val := range values {
				req.Header.Add(restoredHeader, val)
			}
		}
		return req
	}

	// Direct gRPC clients send the authorization header as plain metadata.
	if v := md.Get(headers.AuthorizationKey); len(v) > 0 {
		if len(v) != 1 {
			return requestctx.NewInvalid(), setStatus(
				stAuthError,
				xerrors.Errorf("multiple values for authorization key: %v", len(v)),
			)
		}
		return mw.authWithToken(ctx, v[0], peerAddress)
	}

	// Requests proxied by grpc-gateway carry the header under the gateway prefix.
	if v := md.Get(GrpcGatewayHeaderPrefix + headers.AuthorizationKey); len(v) > 0 {
		if len(v) != 1 {
			return requestctx.NewInvalid(), setStatus(
				stAuthError,
				xerrors.Errorf("multiple values for authorization key: %v", len(v)),
			)
		}
		// TODO: check remote host parse correctness
		// https://a.yandex-team.ru/arc/trunk/arcadia/vendor/github.com/grpc-ecosystem/grpc-gateway/runtime/context.go?rev=r7604278#L129
		// TODO: check peerAddress is localhost for proxyfied requests?
		return mw.authWithToken(ctx, v[0], getRemoteAddress(md, peerAddress))
	}

	if sessID, err := gatewayHTTPRequest(md).Cookie("Session_id"); err == nil {
		return mw.authWithSessionID(ctx, sessID, getRemoteAddress(md, peerAddress), getRequestedHost(md, "localhost"))
	}

	return requestctx.NewInvalid(), setStatus(stAuthFailed, xerrors.New("no credentials"))
}

// logPayload writes obj, a request or response named by kind, to the debug log
// as protojson when conf.LogRequests is enabled. A nil obj (typically the
// response of a failed call) has nothing to log and is skipped, instead of
// being reported as an invalid type. Non-proto values and marshal failures are
// logged as errors.
func (mw *middleware) logPayload(ctx context.Context, obj interface{}, kind string) {
	if !mw.conf.LogRequests || obj == nil {
		return
	}

	pb, ok := obj.(proto.Message)
	if !ok {
		ctxlog.Errorf(ctx, mw.logger, "Invalid %s type: %v", kind, reflect.TypeOf(obj))
		return
	}

	payload, err := mw.marshalOptions.Marshal(pb)
	if err != nil {
		ctxlog.Errorf(ctx, mw.logger, "Failed to marshal %s to log: %+v", kind, err)
		return
	}
	ctxlog.Debugf(ctx, mw.logger, "Full %s: %s", kind, string(payload))
}

// setFeatureFlags attaches per-request feature flags to ctx and returns the
// updated context. Each metadata key that starts with TaskletFeaturePrefix
// (after lowercasing) names a feature; the key's first value is parsed as JSON
// and stored with requestctx.WithFeature. Keys with an empty feature name or no
// values are skipped silently; values that are not valid JSON are logged and
// skipped.
func (mw *middleware) setFeatureFlags(ctx context.Context, md metadata.MD) context.Context {
	// NB: To pass feature flags via http proxy use modified headers:
	// curl -v -X GET  -H 'Grpc-Metadata-tasklet-feature-foo: {...} http://...'
	for key, vals := range md {
		lowered := strings.ToLower(key)
		if !strings.HasPrefix(lowered, TaskletFeaturePrefix) || len(vals) == 0 {
			continue
		}

		featureName := lowered[len(TaskletFeaturePrefix):]
		if featureName == "" {
			// silently drop empty feature
			continue
		}

		raw := vals[0]
		var parsed interface{}
		if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
			ctxlog.Infof(ctx, mw.logger, "Failed to parse feature. FeatureName: %s, Error: %v", featureName, err)
			continue
		}

		ctx = requestctx.WithFeature(ctx, featureName, parsed)
		ctxlog.Infof(ctx, mw.logger, "Setup request feature. FeatureName: %s, Payload: %v", featureName, raw)
	}
	return ctx
}

// UnaryInterceptor authenticates every unary call and wraps the handler with
// logging. It proceeds in this order:
//   - log the request payload (unless the method is in logOmitMethods);
//   - require incoming metadata and a peer address with a parseable host;
//   - authenticate via auth, storing the subject and feature flags in ctx;
//   - call handler and log the outcome with the elapsed time.
//
// The response payload is logged only when the handler succeeded; a failed
// call's response carries nothing to log. Missing metadata returns
// errMissingMetadata. A missing or unparseable peer address returns
// xgrpc.ErrGeneralError. Authentication errors are returned unchanged.
func (mw *middleware) UnaryInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (
	interface{},
	error,
) {
	ctxlog.Debug(ctx, mw.logger, "Request started")
	logPayloads := !slices.Contains(logOmitMethods, info.FullMethod)
	if logPayloads {
		mw.logPayload(ctx, req, "request")
	}

	// Example metadata of a request proxied through grpc-gateway:
	// msg='request started' scheme=http proto=HTTP/1.1 method=GET remote-addr='[2a02:6b8:c23:36c9:0:696:48f9:0]:46748' user-agent='Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36' uri=http://tasklet-test.in.yandex-team.ru/v1/namespaces
	// msg='MD: key=grpcgateway-user-agent, value=[Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36]'
	// msg='MD: key=x-forwarded-host, value=[tasklet-test.in.yandex-team.ru]'
	// msg='MD: key=grpcgateway-accept, value=[text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9]'
	// msg='MD: key=content-type, value=[application/grpc]'
	// msg='MD: key=user-agent, value=[grpc-go/1.33.2]'
	// msg='MD: key=grpcgateway-accept-language, value=[en-US,en;q=0.9,ru;q=0.8]'
	// msg='MD: key=grpcgateway-cache-control, value=[max-age=0]'
	// msg='MD: key=x-forwarded-for, value=[2a02:6b8:c23:36c9:0:696:48f9:0]'
	// msg='MD: key=grpcgateway-cookie, value=[<cookie value>]'
	// msg='MD: key=:authority, value=[localhost:8080]'
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, errMissingMetadata
	}

	// A peer may be present without an address; guard before calling Addr.String().
	p, ok := peer.FromContext(ctx)
	if !ok || p.Addr == nil {
		ctxlog.Error(ctx, mw.logger, "No peer address in request")
		return nil, xgrpc.ErrGeneralError
	}
	peerAddress, _, err := net.SplitHostPort(p.Addr.String())
	if err != nil {
		ctxlog.Errorf(ctx, mw.logger, "Failed to parse host and port. Peer: %q", p.Addr.String())
		return nil, xgrpc.ErrGeneralError
	}

	user, err := mw.auth(ctx, md, peerAddress)
	if err != nil {
		return nil, err
	}
	ctxlog.Infof(ctx, mw.logger, "Authorized subject: %q", user.ToString())
	ctx = requestctx.WithObject(ctx, user)
	ctx = ctxlog.WithFields(ctx, log.String("subject", user.ToString()))

	ctx = mw.setFeatureFlags(ctx, md)

	// Logged inline rather than in a defer, so a panicking handler is never
	// reported as "Request completed".
	start := time.Now()
	resp, handlerErr := handler(ctx, req)
	elapsed := time.Since(start)

	if handlerErr != nil {
		ctxlog.Info(
			ctx, mw.logger, "Request finished with error",
			log.String("method", info.FullMethod),
			log.Duration("elapsed", elapsed),
			log.String("error", handlerErr.Error()),
		)
		return resp, handlerErr
	}

	ctxlog.Debug(
		ctx, mw.logger, "Request completed",
		log.String("method", info.FullMethod),
		log.Duration("elapsed", elapsed),
	)
	if logPayloads {
		mw.logPayload(ctx, resp, "response")
	}
	return resp, nil
}

// UnaryRequestIDGenerator returns an interceptor that gives every call a
// request ID. A single valid consts.RequestIDHeader value from the incoming
// metadata is reused; otherwise a fresh ID is generated and the reason is
// logged through l. The ID is then:
//   - stored in the context via requestctx;
//   - added to the ctxzap fields and to the ctxlog fields (together with the
//     gRPC method name);
//   - sent back to the client as a response header.
//
// Calls without incoming metadata fail with errMissingMetadata.
func UnaryRequestIDGenerator(l log.Logger) grpc.UnaryServerInterceptor {
	// resolveRequestID picks the caller-supplied ID or generates a new one.
	resolveRequestID := func(md metadata.MD) consts.RequestID {
		values := md.Get(consts.RequestIDHeader)
		if len(values) != 1 {
			l.Infof(
				"Request id not found in request. Generating internal request id",
			)
			return consts.NewRequestID()
		}
		parsed, err := consts.RequestIDFromString(values[0])
		if err != nil {
			l.Infof(
				"Request ID is invalid. Generating internal request id",
			)
			return consts.NewRequestID()
		}
		return parsed
	}

	return func(
		ctx context.Context,
		req interface{},
		info *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (interface{}, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return nil, errMissingMetadata
		}

		requestID := resolveRequestID(md)
		idText := requestID.String()

		withID := requestctx.WithRequestID(ctx, requestID)
		// For GRPC log
		ctxzap.AddFields(withID, uberZap.String(RequestIDLogField, idText))
		withID = ctxlog.WithFields(
			withID,
			log.String(RequestIDLogField, idText),
			log.String("grpc_method", info.FullMethod),
		)

		header := metadata.Pairs(consts.RequestIDHeader, idText)
		if err := grpc.SetHeader(withID, header); err != nil {
			ctxlog.Errorf(withID, l, "Error setting %s header: %+v", consts.RequestIDHeader, err)
		}

		return handler(withID, req)
	}
}

// Sample gRPC metadata and context of requests proxied through grpc-gateway (log excerpts):
// msg='request started' scheme=http proto=HTTP/1.1 method=GET remote-addr='[2a02:6b8:c23:36c9:0:696:48f9:0]:46748' user-agent='Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36' uri=http://tasklet-test.in.yandex-team.ru/v1/namespaces
// msg='MD: key=grpcgateway-user-agent, value=[Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36]'
// msg='MD: key=x-forwarded-host, value=[tasklet-test.in.yandex-team.ru]'
// msg='MD: key=grpcgateway-accept, value=[text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9]'
// msg='MD: key=content-type, value=[application/grpc]'
// msg='MD: key=user-agent, value=[grpc-go/1.33.2]'
// msg='MD: key=grpcgateway-accept-language, value=[en-US,en;q=0.9,ru;q=0.8]'
// msg='MD: key=grpcgateway-cache-control, value=[max-age=0]'
// msg='MD: key=x-forwarded-for, value=[2a02:6b8:c23:36c9:0:696:48f9:0]'
// msg='MD: key=grpcgateway-cookie, value=[<cookie value>]'
// msg='MD: key=:authority, value=[localhost:8080]'

// Same, for a curl request proxied through a grpc-gateway running on localhost:
// "msg":"MD: key=x-forwarded-for, value=[127.0.0.1]"
// "msg":"MD: key=:authority, value=[localhost:6666]"
// "msg":"MD: key=content-type, value=[application/grpc]"
// "msg":"MD: key=user-agent, value=[grpc-go/1.33.2]"
// "msg":"MD: key=grpcgateway-user-agent, value=[curl/7.74.0]"
// "msg":"MD: key=grpcgateway-accept, value=[*/*]"
// "msg":"MD: key=x-forwarded-host, value=[localhost:8080]"
// "msg":"grpcCtx: key=peer.address, value=127.0.0.1:47116"

// Sample metadata and context of a direct gRPC call made with grpcurl (log excerpts):
// "msg":"MD: key=content-type, value=[application/grpc]"
// "msg":"MD: key=user-agent, value=[grpcurl/dev-build (no version set) grpc-go/1.37.0]"
// "msg":"MD: key=:authority, value=[localhost:6666]"
// "msg":"grpcCtx: key=peer.address, value=127.0.0.1:47128"

// End of metadata samples.

// LogMetadata returns a debugging interceptor that logs every incoming
// metadata key at debug level. Values are logged only for keys that contain
// consts.ForwardedForHeader (case-insensitively); all other values are left out
// of the log. A call without metadata is logged as an error and still passed
// to the handler.
func LogMetadata(l log.Logger) grpc.UnaryServerInterceptor {
	forwardedFor := strings.ToLower(consts.ForwardedForHeader)

	return func(
		ctx context.Context,
		req interface{},
		info *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (interface{}, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			ctxlog.Error(ctx, l, "no metadata in context")
			return handler(ctx, req)
		}

		for key, vals := range md {
			if !strings.Contains(strings.ToLower(key), forwardedFor) {
				ctxlog.Debugf(ctx, l, "MD: key=%v", key)
				continue
			}
			ctxlog.Debugf(ctx, l, "MD: key=%v, v=%+v", key, vals)
		}
		return handler(ctx, req)
	}
}
