// Package api provides utilities for defining our RPC server without resorting to code generation.
package api

import (
	"a.yandex-team.ru/infra/nanny2/pkg/hq/validation"
	"net/http"
	"reflect"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/golang/protobuf/proto"
	"golang.org/x/net/context"

	"a.yandex-team.ru/infra/nanny2/pkg/concur"
	pb "a.yandex-team.ru/yp/go/proto/hq"
)

// Handler is the function signature every RPC method implementation must
// provide. It receives the request context, the decoded request value and the
// HTTP request headers, and returns the response message together with a
// status.
//
// When invoked through wrapMethod, req is a pointer to a freshly allocated
// value of the type registered with Method.Accepts.
type Handler func(ctx context.Context, req interface{}, header http.Header) (proto.Message, *pb.Status)

// Method describes a single RPC method registered on an rpcService. It is
// created by rpcService.Route and configured through its chainable setters.
//
// Fields:
//   - name: path segment, appended to the service prefix, the method is
//     mounted at.
//   - input: type of the request value allocated for every call.
//   - handler: function invoked with the decoded request.
//   - limit: optional limiter for concurrent requests; nil disables limiting.
//   - requestTimeout: timeout applied to the per-request context; zero means
//     no timeout.
//   - desiredLogin: login passed to validation.ValidateUserAccess for every
//     request.
type Method struct {
	name           string
	input          reflect.Type
	handler        Handler
	limit          *concur.CapacityLimiter
	requestTimeout time.Duration
	desiredLogin   string
}

// rpcService is a named group of RPC methods that share a URL prefix.
//
// Prefix ends with "/" when the service is built through NewRPCService.
// methods holds the routes registered via Route, in registration order; Mount
// attaches each of them to a router.
type rpcService struct {
	Name    string
	Prefix  string
	methods []*Method
}

// WithLimit attaches a new capacity limiter, created for concurrentRequests,
// to the method and returns the method so calls can be chained. Any limiter
// set earlier is replaced.
func (m *Method) WithLimit(concurrentRequests int) *Method {
	limiter := concur.NewCapacityLimiter(concurrentRequests)
	m.limit = limiter
	return m
}

// WithTimeout sets the timeout applied to the context of every request served
// by the method and returns the method so calls can be chained. A zero
// duration leaves requests without a timeout.
func (m *Method) WithTimeout(duration time.Duration) *Method {
	timeout := duration
	m.requestTimeout = timeout
	return m
}

// Accepts records the dynamic type of input as the request type of the method
// and returns the method so calls can be chained. Only the type of input is
// kept; its value is not stored.
func (m *Method) Accepts(input interface{}) *Method {
	requestType := reflect.TypeOf(input)
	m.input = requestType
	return m
}

// To binds handler as the implementation of the method and returns the method
// so calls can be chained. A handler bound earlier is replaced.
func (m *Method) To(handler Handler) *Method {
	impl := handler
	m.handler = impl
	return m
}

// DemandsLogin stores login as the login handed to
// validation.ValidateUserAccess for every request to the method, and returns
// the method so calls can be chained.
func (m *Method) DemandsLogin(login string) *Method {
	required := login
	m.desiredLogin = required
	return m
}

// NewRPCService creates an empty service called name whose methods will be
// mounted under prefix. A trailing "/" is appended to prefix when it is
// missing, so the stored Prefix always ends with a slash.
func NewRPCService(name, prefix string) *rpcService {
	svc := &rpcService{Name: name, Prefix: prefix}
	if strings.HasSuffix(svc.Prefix, "/") {
		return svc
	}
	svc.Prefix += "/"
	return svc
}

// Route registers a new method called name on the service and returns it so
// that it can be configured with the chainable Method setters. Methods are
// kept in registration order.
func (r *rpcService) Route(name string) *Method {
	method := new(Method)
	method.name = name
	r.methods = append(r.methods, method)
	return method
}

// wrapMethod adapts m to an http.HandlerFunc. For every request the returned
// handler:
//
//  1. reserves a slot in m.limit when a limiter is configured, answering with
//     StatusLimitExceeded if no slot is available;
//  2. allocates a new value of type m.input and fills it via ReadRequest;
//  3. builds a context, bounded by m.requestTimeout when that is non-zero;
//  4. checks the caller with validation.ValidateUserAccess against
//     m.desiredLogin, answering with FailBadRequest on error;
//  5. invokes m.handler and writes its result with WriteResponse.
//
// NOTE(review): the context is rooted at context.Background() rather than
// r.Context(), so a client disconnect does not cancel the handler — confirm
// this is intentional.
func wrapMethod(m *Method) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if m.limit != nil {
			if !m.limit.Add() {
				WriteResponse(w, r, nil, StatusLimitExceeded)
				return
			}
			// Release the slot once the request has been fully served.
			defer m.limit.Done()
		}

		req := reflect.New(m.input).Interface()
		if status := ReadRequest(r, req); status != nil {
			WriteResponse(w, r, nil, status)
			return
		}

		ctx := context.Background()
		if m.requestTimeout != 0 {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, m.requestTimeout)
			// Always release the timer and context resources when the request
			// ends instead of holding them until the deadline fires.
			defer cancel()
		}

		if err := validation.ValidateUserAccess(ctx, r.Header, m.desiredLogin); err != nil {
			res, status := FailBadRequest(err.Error())
			WriteResponse(w, r, res, status)
			return
		}

		res, status := m.handler(ctx, req, r.Header)
		WriteResponse(w, r, res, status)
	}
}

// Mount registers every method of the service on mux as a POST route. The
// route pattern is the service prefix followed by the method name and a
// trailing "/"; requests are served by the handler built by wrapMethod.
func (r *rpcService) Mount(mux *chi.Mux) {
	for i := range r.methods {
		m := r.methods[i]
		pattern := r.Prefix + m.name + "/"
		mux.Post(pattern, wrapMethod(m))
	}
}
