package grpcgateway

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/textproto"
	"path"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/grpc-ecosystem/grpc-gateway/runtime"
	"google.golang.org/grpc"

	"a.yandex-team.ru/library/go/core/resource"
	"a.yandex-team.ru/library/go/httputil/swaggerui"
	"a.yandex-team.ru/library/go/yandex/tvm"
	tvmutil "a.yandex-team.ru/travel/library/go/tvm"
)

// NoEnvConfig mirrors Config field-for-field but carries no `config` struct
// tags, so it is not picked up by the tag-driven loader that reads Config
// (presumably for embedding in configs populated some other way — confirm
// with callers). Convert it with ToConfig.
type NoEnvConfig struct {
	// Enabled corresponds to Config.Enabled.
	Enabled       bool
	// Address corresponds to Config.Address.
	Address       string
	// EnableBinary corresponds to Config.EnableBinary.
	EnableBinary  bool
	// EnableSwagger corresponds to Config.EnableSwagger.
	EnableSwagger bool
}

// ToConfig returns a newly allocated Config holding the same four settings
// as cfg. The result does not alias cfg, so later changes to cfg are not
// reflected in it.
func (cfg *NoEnvConfig) ToConfig() *Config {
	converted := new(Config)
	converted.Enabled = cfg.Enabled
	converted.Address = cfg.Address
	converted.EnableBinary = cfg.EnableBinary
	converted.EnableSwagger = cfg.EnableSwagger
	return converted
}

// Config controls the HTTP gateway. The `config` tags name the keys read by
// the project's configuration loader (assumed tag-driven — confirm); every key
// is tagged `optional`.
type Config struct {
	// Enabled gates Gateway.Run: when false, Run returns an error instead of serving.
	Enabled       bool   `config:"grpcgateway-enabled,optional"`
	// Address is the listen address of the HTTP server started by Gateway.Run.
	Address       string `config:"grpcgateway-address,optional"`
	// EnableBinary additionally registers BinaryProtoMarshaller for the
	// "application/octet-stream" content type.
	EnableBinary  bool   `config:"grpcgateway-binary,optional"`
	// EnableSwagger mounts the Swagger UI routes in Gateway.GetRouter.
	EnableSwagger bool   `config:"grpcgateway-enable-swagger,optional"`
}

// DefaultConfig holds the default gateway settings: disabled, listening on
// 127.0.0.1:9002 when enabled, binary marshaller off, Swagger UI on.
var DefaultConfig = Config{
	Enabled:       false,
	Address:       "127.0.0.1:9002",
	EnableBinary:  false,
	EnableSwagger: true,
}

// swaggerPath is joined onto a service's URL prefix to form the path under
// which its Swagger UI is served (see Service.getRoot). path.Join cleans the
// result, so the trailing slash here does not survive the join.
const swaggerPath = "/swagger/"

// registerCallback attaches a service's HTTP handlers to mux. Gateway.GetRouter
// calls it once per service with the service's gRPC address as endpoint and its
// dial options. Presumably this is a generated grpc-gateway
// Register<Service>HandlerFromEndpoint function — confirm against callers.
type registerCallback func(ctx context.Context, mux *runtime.ServeMux, endpoint string, options []grpc.DialOption) error

// Service describes one gRPC service exposed through the gateway. Build it
// with NewService and customise it with WithGrpcDialOptions / WithTvm.
type Service struct {
	// grpcAddr is passed to registerCallback as the backend endpoint.
	grpcAddr         string
	// serviceName forms the Swagger schema file name (getJSONName) and is
	// included in GetRouter error messages.
	serviceName      string
	// prefix is the URL prefix the Swagger UI root is derived from (getRoot).
	prefix           string
	// registerCallback wires the service into the gateway's ServeMux.
	registerCallback registerCallback
	// options are the dial options given to registerCallback; NewService seeds
	// them with grpc.WithInsecure().
	options          []grpc.DialOption
	// extraHeaderSpecs lists extra request headers to forward to the backend and
	// to add to the Swagger schema.
	extraHeaderSpecs []HeaderSpec
}

// HeaderSpec declares an extra request header for a Service. GetRouter
// forwards headers whose canonical MIME name matches Name to the gRPC backend
// and passes the specs to addHeadersToSchema for the Swagger schema.
type HeaderSpec struct {
	// Name is the header name; it is canonicalised with
	// textproto.CanonicalMIMEHeaderKey before matching.
	Name        string
	// Description, Default and Enum are presumably rendered into the Swagger
	// parameter definition by addHeadersToSchema — confirm there.
	Description string
	Default     string
	Enum        []string
}

// NewService describes a gRPC service to expose through the gateway.
//
// serviceName is used for the Swagger schema file name and in error messages;
// prefix is the URL prefix under which the service's Swagger UI is mounted;
// grpcAddr is the backend endpoint handed to registerCallback, which attaches
// the service's handlers to the gateway mux. extraHeaderNames lists extra
// request headers that are forwarded to the backend and added to the Swagger
// schema (the slice is stored as is, not copied).
//
// The dial options start with an insecure transport; append more with
// WithGrpcDialOptions or WithTvm.
func NewService(serviceName string, prefix string, grpcAddr string, registerCallback registerCallback, extraHeaderNames []HeaderSpec) *Service {
	svc := new(Service)
	svc.serviceName = serviceName
	svc.prefix = prefix
	svc.grpcAddr = grpcAddr
	svc.registerCallback = registerCallback
	svc.extraHeaderSpecs = extraHeaderNames
	svc.options = []grpc.DialOption{grpc.WithInsecure()}
	return svc
}

// WithGrpcDialOptions appends options after the dial options already
// configured on s (including the default insecure transport). It mutates s
// and returns it so calls can be chained.
func (s *Service) WithGrpcDialOptions(options ...grpc.DialOption) *Service {
	if len(options) == 0 {
		return s
	}
	s.options = append(s.options, options...)
	return s
}

// WithTvm, when enabled is true, adds the tvmutil client interceptor for
// destination tvmID (using tc) to every unary call made by the service; it
// presumably attaches a TVM service ticket to outgoing requests. When enabled
// is false, s is returned unchanged and tc is not used. It mutates s and
// returns it so calls can be chained.
//
// The interceptor is registered with WithChainUnaryInterceptor rather than
// WithUnaryInterceptor. gRPC keeps only one interceptor set through
// WithUnaryInterceptor, so another such option passed via
// WithGrpcDialOptions would silently replace the TVM interceptor. Chained
// interceptors compose with it instead.
func (s *Service) WithTvm(tc tvm.Client, tvmID uint32, enabled bool) *Service {
	if !enabled {
		return s
	}
	interceptor := tvmutil.ClientInterceptor(tc, tvmutil.ClientWithDst(tvmID))
	s.options = append(s.options, grpc.WithChainUnaryInterceptor(interceptor))
	return s
}

// getRoot returns the URL path of this service's Swagger UI: prefix joined
// with swaggerPath. path.Join cleans the result, so it has no trailing slash
// (e.g. "" -> "/swagger", "/api" -> "/api/swagger").
func (s *Service) getRoot() string {
	root := path.Join(s.prefix, swaggerPath)
	return root
}

// getJSONName returns the Swagger schema file name for this service:
// "<serviceName>.swagger.json".
func (s *Service) getJSONName() string {
	return s.serviceName + ".swagger.json"
}

// Gateway serves a set of Services over HTTP through a single grpc-gateway
// ServeMux, optionally alongside a Swagger UI. Create it with NewGateway.
type Gateway struct {
	// cfg holds the gateway settings (enabled flag, address, marshallers, Swagger).
	cfg        *Config
	// services are registered on the mux in order by GetRouter.
	services   []*Service
	// registered records the paths marked by isPathRegistered; access is not
	// synchronised.
	registered map[string]bool
}

// NewGateway returns a Gateway configured by cfg that serves the given
// services. The services slice is kept as passed; the path registry starts
// out empty.
func NewGateway(cfg *Config, services ...*Service) *Gateway {
	gw := &Gateway{cfg: cfg, services: services}
	gw.registered = map[string]bool{}
	return gw
}

// isPathRegistered reports whether p was already present in the registry. As a
// side effect it records p when it was not present, so the first call for a
// given path returns false and every later call returns true. This is a
// check-and-set, not a pure query, and the map access is unsynchronised.
func (g *Gateway) isPathRegistered(p string) bool {
	_, seen := g.registered[p]
	if !seen {
		g.registered[p] = true
	}
	return seen
}

// GetRouter builds the HTTP router that serves the gateway.
//
// All services are registered, in order, on a single grpc-gateway ServeMux
// mounted at "/*" behind chi's request logger. When cfg.EnableSwagger is set,
// a Swagger UI is also mounted for every distinct Swagger root (see
// Service.getRoot). Each UI serves the combined schema of the services that
// share the root, and a shared "/oauth2-redirect.html" page is added.
//
// It returns an error if a service fails to register, or if a Swagger schema
// cannot be loaded, extended with header parameters, or marshalled.
func (g *Gateway) GetRouter(ctx context.Context) (*chi.Mux, error) {
	mux := runtime.NewServeMux(g.serveMuxOptions()...)
	r := chi.NewRouter()
	r.Use(middleware.Logger)
	r.Handle("/*", mux)

	for _, s := range g.services {
		if err := s.registerCallback(ctx, mux, s.grpcAddr, s.options); err != nil {
			return nil, fmt.Errorf("unable to register service %s: %w", s.serviceName, err)
		}
	}

	if g.cfg.EnableSwagger {
		if err := g.mountSwaggerUI(r); err != nil {
			return nil, err
		}
	}
	return r, nil
}

// serveMuxOptions returns the ServeMux options shared by every service.
// JSON is marshalled with original proto field names and zero values emitted.
// "application/octet-stream" is handled by BinaryProtoMarshaller when
// cfg.EnableBinary is set. If any service declares extra headers, a matcher
// forwards those headers to the backend.
func (g *Gateway) serveMuxOptions() []runtime.ServeMuxOption {
	options := []runtime.ServeMuxOption{
		runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{OrigName: true, EmitDefaults: true}),
	}
	if g.cfg.EnableBinary {
		options = append(options, runtime.WithMarshalerOption("application/octet-stream", &BinaryProtoMarshaller{}))
	}

	// The matcher is mux-wide, so headers declared by any service are forwarded
	// for all of them.
	extraHeaderNameSet := make(map[string]bool)
	for _, s := range g.services {
		for _, h := range s.extraHeaderSpecs {
			extraHeaderNameSet[textproto.CanonicalMIMEHeaderKey(h.Name)] = true
		}
	}
	if len(extraHeaderNameSet) > 0 {
		options = append(options, runtime.WithIncomingHeaderMatcher(matchCustomHeaders(extraHeaderNameSet)))
	}
	return options
}

// mountSwaggerUI mounts one Swagger UI per distinct Swagger root on r and
// registers the shared OAuth2 redirect page.
func (g *Gateway) mountSwaggerUI(r chi.Router) error {
	// Group by the resolved root, not the raw prefix. Prefixes such as "" and "/"
	// (or "/api" and "/api/") resolve to the same root, and chi panics when a
	// second sub-router is mounted on an existing path. Roots are kept in
	// first-seen order so registration and error reporting are deterministic.
	var roots []string
	servicesByRoot := make(map[string][]*Service)
	for _, s := range g.services {
		root := s.getRoot()
		if _, seen := servicesByRoot[root]; !seen {
			roots = append(roots, root)
		}
		servicesByRoot[root] = append(servicesByRoot[root], s)
	}

	for _, root := range roots {
		schema, err := g.combinedSwaggerSchema(servicesByRoot[root])
		if err != nil {
			return err
		}
		fs := http.StripPrefix(
			root,
			http.FileServer(
				swaggerui.NewFileSystem(swaggerui.WithJSONScheme(schema)),
			),
		)
		r.Route(root, func(sub chi.Router) {
			sub.Get("/*", fs.ServeHTTP)
		})
	}

	// The redirect page is looked up on every request and served with the
	// current time as its modification time.
	r.Get("/oauth2-redirect.html", func(w http.ResponseWriter, req *http.Request) {
		res := resource.MustGet("swagger_resources/oauth2-redirect.html")
		http.ServeContent(w, req, "oauth2-redirect.html", time.Now(), bytes.NewReader(res))
	})
	return nil
}

// combinedSwaggerSchema processes each of services in order: it loads the
// Swagger schema, adds the service's extra header parameters (if any), and
// merges the result into the others with combineMaps. It returns the merged
// schema encoded as JSON.
func (g *Gateway) combinedSwaggerSchema(services []*Service) ([]byte, error) {
	var combined map[string]interface{}
	for _, service := range services {
		schema, err := g.loadSwaggerSchema(service)
		if err != nil {
			return nil, fmt.Errorf("unable to load swagger schema for service %s: %w", service.serviceName, err)
		}
		if len(service.extraHeaderSpecs) > 0 {
			schema, err = addHeadersToSchema(schema, service.extraHeaderSpecs)
			if err != nil {
				return nil, fmt.Errorf("unable to add headers to swagger schema for service %s: %w", service.serviceName, err)
			}
		}
		combined = combineMaps(combined, schema)
	}

	joined, err := json.Marshal(combined)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal swagger schema: %w", err)
	}
	return joined, nil
}

// matchCustomHeaders returns an incoming-header matcher for the grpc-gateway
// ServeMux.
//
// Each header is first offered to runtime.DefaultHeaderMatcher, so the
// library's standard forwarding and prefixing are never overridden. Any other
// header is forwarded under its canonical MIME name when that name is in
// headerNamesSet (whose keys are expected to be canonical already), and
// dropped otherwise.
func matchCustomHeaders(headerNamesSet map[string]bool) runtime.HeaderMatcherFunc {
	return func(name string) (string, bool) {
		if key, ok := runtime.DefaultHeaderMatcher(name); ok {
			return key, ok
		}
		canonical := textproto.CanonicalMIMEHeaderKey(name)
		if _, known := headerNamesSet[canonical]; !known {
			return "", false
		}
		return canonical, true
	}
}

// Run serves the router built by GetRouter on cfg.Address and blocks until the
// server stops.
//
// It fails immediately if the gateway is disabled (cfg.Enabled == false) or the
// router cannot be built. When ctx is done (cancelled or past its deadline),
// the server is shut down gracefully. Run waits for that shutdown to finish and
// then returns its error, which is nil on a clean shutdown. Any other
// ListenAndServe error, such as a failure to bind cfg.Address, is returned
// unchanged.
func (g *Gateway) Run(ctx context.Context) error {
	if !g.cfg.Enabled {
		return fmt.Errorf("Gateway.Run: will not run due to enabled=false")
	}

	r, err := g.GetRouter(ctx)
	if err != nil {
		return err
	}

	server := &http.Server{Addr: g.cfg.Address, Handler: r}

	// The watcher shuts the server down once ctx is done.
	// - stopWatcher releases it when ListenAndServe returns for any other
	//   reason, so the goroutine never outlives Run.
	// - watcherDone lets Run wait for an in-flight Shutdown: ListenAndServe
	//   returns as soon as Shutdown starts, not when it completes.
	stopWatcher := make(chan struct{})
	watcherDone := make(chan struct{})
	var shutdownErr error
	go func() {
		defer close(watcherDone)
		select {
		case <-ctx.Done():
			shutdownErr = server.Shutdown(context.Background())
		case <-stopWatcher:
		}
	}()

	err = server.ListenAndServe()
	close(stopWatcher)
	<-watcherDone

	// ErrServerClosed is the expected result of the ctx-triggered shutdown, not a
	// failure; report only what Shutdown itself returned.
	if err == http.ErrServerClosed && ctx.Err() != nil {
		return shutdownErr
	}
	return err
}
