package httprpc

import (
	"context"
	"io/ioutil"
	"log"
	"net/http"
	"strconv"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/rs/cors"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"

	pb "a.yandex-team.ru/infra/fwmanager/proto"
	"a.yandex-team.ru/library/go/httputil/headers"
)

// New returns an RPC bound to the given HTTP method and route path.
//
// The returned RPC logs to a logger that discards all output until
// WithLogger replaces it. Its reader, handler and writer are still unset:
// configure them through the With* methods before requests are served,
// because Mount calls them without nil checks.
func New(method, path string) *RPC {
	rpc := new(RPC)
	rpc.method = method
	rpc.path = path
	rpc.l = log.New(ioutil.Discard, "", 0)
	return rpc
}

type (
	// Handler implements the business logic of an RPC. Mount calls it with
	// the HTTP request's context, the RequestCtx produced by the configured
	// reader and the decoded request message; the returned message is
	// encoded by the configured writer. A non-nil error makes Mount answer
	// with HTTP 500 and the error text.
	Handler func(ctx context.Context, reqCtx *pb.RequestCtx, req proto.Message) (proto.Message, error)

	// reader decodes an incoming HTTP request into the request message and
	// the RequestCtx handed to the Handler.
	reader func(*http.Request) (proto.Message, *pb.RequestCtx, error)

	// writer encodes a response message onto the HTTP response.
	writer func(http.ResponseWriter, proto.Message) error
)

// RPC describes one HTTP endpoint that decodes a protobuf request, passes it
// to a Handler and encodes the protobuf response. Build it with New,
// configure it through the chainable With*/CorsAllowAll methods and register
// it on a router with Mount.
type RPC struct {
	// l is the logger set by WithLogger; New installs one that discards all
	// output.
	l *log.Logger

	r         reader  // request decoder (WithJSONPbReader, WithMultiTypeReader, WithRequestReader)
	h         Handler // business logic invoked by Mount (WithHandler)
	w         writer  // response encoder (WithJSONPbWriter, WithWriterByAccepted)
	errWriter writer  // NOTE(review): neither set nor read by the code visible here; confirm it is still needed

	cors     *cors.Cors // CORS policy applied by Mount when non-nil (CorsAllowAll)
	method   string     // HTTP method the route is registered for in Mount
	path     string     // route pattern the route is registered for in Mount
	accepted string     // per-request media-type slot; NOTE(review): *RPC is shared by every request Mount serves, so per-request data must not live here — check for remaining users
}

// WithLogger replaces the RPC's logger (New installs a discarding one) and
// returns r so calls can be chained.
func (r *RPC) WithLogger(logger *log.Logger) *RPC {
	r.l = logger
	return r
}

// CorsAllowAll installs the policy returned by cors.AllowAll(), which Mount
// applies to every request it serves, and returns r for chaining.
func (r *RPC) CorsAllowAll() *RPC {
	policy := cors.AllowAll()
	r.cors = policy
	return r
}

// WithHandler sets the business-logic function Mount invokes for each
// decoded request and returns r for chaining.
func (r *RPC) WithHandler(handler Handler) *RPC {
	r.h = handler
	return r
}

// WithJSONPbReader configures the RPC to decode the request body as protobuf
// JSON (protojson) into a message of the same type as p.
//
// p is used only as a template. Each request decodes into its own deep copy
// (proto.Clone), so concurrent requests never share or race on one message.
// An empty body yields an unmodified copy of p rather than whatever the
// previous request left behind.
//
// A body read or parse error makes Mount answer with HTTP 400.
func (r *RPC) WithJSONPbReader(p proto.Message) *RPC {
	r.r = func(req *http.Request) (proto.Message, *pb.RequestCtx, error) {
		content, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return nil, nil, err
		}
		msg := proto.Clone(p)
		if len(content) > 0 {
			// protojson.Unmarshal resets msg first, so only the body's fields remain.
			if err := protojson.Unmarshal(content, msg); err != nil {
				return nil, nil, err
			}
		}
		return msg, RequestCtxFromRequest(req), nil
	}
	return r
}
// WithMultiTypeReader configures the RPC to decode the request body into a
// message of the same type as p. The wire format is picked from the
// request's Content-Type header via readMultiType: "application/x-protobuf"
// selects binary protobuf, and JSON is the fallback.
//
// p is used only as a template. Each request decodes into its own deep copy
// (proto.Clone), so concurrent requests never share or race on one message.
// An empty body yields an unmodified copy of p rather than whatever the
// previous request left behind. A read or decode error makes Mount answer
// with HTTP 400.
func (r *RPC) WithMultiTypeReader(p proto.Message) *RPC {
	r.r = func(req *http.Request) (proto.Message, *pb.RequestCtx, error) {
		content, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return nil, nil, err
		}
		msg := proto.Clone(p)
		if len(content) > 0 {
			if err := readMultiType(content, msg, req.Header.Get("Content-Type")); err != nil {
				return nil, nil, err
			}
		}
		return msg, RequestCtxFromRequest(req), nil
	}
	return r
}

// WithRequestReader installs a custom function that decodes the incoming
// *http.Request into the request message and its RequestCtx. Any error it
// returns makes Mount answer with HTTP 400.
func (r *RPC) WithRequestReader(re reader) *RPC {
	r.r = re
	return r
}

// mrshlr is the protojson configuration used for JSON responses. It produces
// indented, multi-line output that also includes fields holding their zero
// value.
var mrshlr = protojson.MarshalOptions{
	Multiline:       true,
	Indent:          "  ",
	EmitUnpopulated: true,
}

// WithJSONPbWriter makes the RPC encode every response as protobuf JSON
// (formatted by mrshlr) with an application/json Content-Type.
func (r *RPC) WithJSONPbWriter() *RPC {
	r.w = func(w http.ResponseWriter, respPb proto.Message) error {
		body, err := mrshlr.Marshal(respPb)
		if err != nil {
			return err
		}
		// Claim JSON only once there is JSON to send. Otherwise an error reply
		// that Mount writes after a marshal failure would be labelled JSON.
		w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationJSON.String())
		_, err = w.Write(body)
		return err
	}
	return r
}

// acceptedHeader is the request header whose value selects the response
// encoding for WithWriterByAccepted.
// NOTE(review): this is not the standard HTTP "Accept" header; it is kept
// unchanged for compatibility with existing clients.
const acceptedHeader = "Accepted"

// acceptedResponseWriter is the http.ResponseWriter that Mount hands to the
// configured writer. It carries the current request's acceptedHeader value,
// so per-request data never has to be stored on the shared *RPC, where it
// would race between concurrent requests.
type acceptedResponseWriter struct {
	http.ResponseWriter
	accepted string
}

// WithWriterByAccepted makes the RPC choose the response encoding per
// request. The value of the request's "Accepted" header (supplied by Mount)
// is passed to marshalProtoByAccepted, which returns the encoded body and the
// Content-Type to send with it.
func (r *RPC) WithWriterByAccepted() *RPC {
	r.w = func(w http.ResponseWriter, respPb proto.Message) error {
		accepted := ""
		if aw, ok := w.(*acceptedResponseWriter); ok {
			accepted = aw.accepted
		}
		body, contentType, err := marshalProtoByAccepted(respPb, accepted)
		if err != nil {
			return err
		}
		w.Header().Set(headers.ContentTypeKey, contentType)
		_, err = w.Write(body)
		return err
	}
	return r
}

// Mount registers the RPC on mux for r.method and r.path and returns r.
//
// For every request, the registered handler runs these steps in order:
// apply the CORS policy (if one was configured), decode the request with the
// configured reader, call the Handler, and encode its result with the
// configured writer. A reader error is answered with HTTP 400; Handler and
// writer errors are answered with HTTP 500. Every error is also logged
// through the configured logger, if any.
//
// The reader, Handler and writer are called without nil checks, so all three
// must be configured before requests arrive.
//
// NOTE(review): only r.method is registered, so a CORS preflight (OPTIONS)
// request reaches this handler only when r.method is OPTIONS. Confirm that
// preflight is handled elsewhere (e.g. by router-level middleware).
func (r *RPC) Mount(mux *chi.Mux) *RPC {
	logf := func(format string, args ...interface{}) {
		if r.l != nil {
			r.l.Printf(format, args...)
		}
	}
	httpHandler := http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
		if r.cors != nil {
			r.cors.HandlerFunc(w, request)
		}
		req, reqCtx, err := r.r(request)
		// A reader failure usually means the body could not be turned into the
		// protobuf message (e.g. invalid fields), i.e. a client error.
		if err != nil {
			logf("%s %s: reading request: %v", r.method, r.path, err)
			fail(err, w, http.StatusBadRequest)
			return
		}
		res, err := r.h(request.Context(), reqCtx, req)
		if err != nil {
			logf("%s %s: handling request: %v", r.method, r.path, err)
			fail(err, w, http.StatusInternalServerError)
			return
		}
		// The requested format travels with this request's writer instead of
		// being stored on the shared *RPC.
		out := &acceptedResponseWriter{
			ResponseWriter: w,
			accepted:       request.Header.Get(acceptedHeader),
		}
		if err := r.w(out, res); err != nil {
			logf("%s %s: writing response: %v", r.method, r.path, err)
			fail(err, w, http.StatusInternalServerError)
		}
	})
	mux.Method(r.method, r.path, httpHandler)
	return r
}

// fail replies to the client with statusCode and err's message as a
// plain-text body, setting Content-Type and Content-Length to match.
//
// The headers are set before WriteHeader because net/http ignores changes
// to w.Header() made after the status line has been written.
func fail(err error, w http.ResponseWriter, statusCode int) {
	resp := []byte(err.Error())
	h := w.Header()
	h.Set(headers.ContentTypeKey, headers.TypeTextPlain.String())
	h.Set(headers.ContentLength, strconv.Itoa(len(resp)))
	w.WriteHeader(statusCode)
	// Best effort: the status is already sent, so a failed body write
	// (e.g. the client went away) cannot be reported anywhere.
	_, _ = w.Write(resp)
}

// RequestCtxFromRequest collects caller-identifying data from an HTTP
// request into a pb.RequestCtx. It reads:
//   - the value of the "Session_id" cookie, or an empty string when the
//     cookie is absent;
//   - the "OAuth" request header;
//   - the peer address as reported by net/http (r.RemoteAddr).
func RequestCtxFromRequest(r *http.Request) *pb.RequestCtx {
	var sessionID string
	// A missing cookie is not an error here; it just leaves sessionID empty.
	if cookie, _ := r.Cookie("Session_id"); cookie != nil {
		sessionID = cookie.Value
	}

	oauthToken := r.Header.Get("OAuth")
	return &pb.RequestCtx{
		RemoteAddr: r.RemoteAddr,
		OauthToken: oauthToken,
		SessionId:  sessionID,
	}
}

// Media types recognised when decoding request bodies (readMultiType) and
// encoding responses (marshalProtoByAccepted).
const (
	headerValueProtobuf = "application/x-protobuf" // binary protobuf wire format
	headerValueJSON     = "application/json"       // protobuf JSON mapping (protojson)
)

// readMultiType decodes specBytes into msg using the wire format named by the
// request's Content-Type value. "application/x-protobuf" selects binary
// protobuf. "application/json", and any other, empty or unrecognised value,
// selects protobuf JSON (protojson).
//
// Media types are case-insensitive and may carry parameters, so anything
// after ';', surrounding whitespace and letter case are ignored. For example,
// "Application/X-Protobuf; charset=binary" still selects binary protobuf.
func readMultiType(specBytes []byte, msg proto.Message, contentType string) error {
	mediaType := contentType
	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
		mediaType = mediaType[:i]
	}
	switch strings.ToLower(strings.TrimSpace(mediaType)) {
	case headerValueProtobuf:
		return proto.Unmarshal(specBytes, msg)
	default:
		// headerValueJSON and every unrecognised value are decoded as JSON.
		return protojson.Unmarshal(specBytes, msg)
	}
}

// marshalProtoByAccepted encodes m in the format requested by contentType
// (the client's "Accepted" header value). It returns the encoded bytes along
// with the Content-Type that describes them:
//
//   - "application/x-protobuf": binary protobuf wire format;
//   - "application/json" and any other or empty value: protobuf JSON as
//     formatted by mrshlr.
//
// Media-type parameters (anything after ';'), surrounding whitespace and
// letter case are ignored. When err is non-nil the returned bytes must not
// be used.
func marshalProtoByAccepted(m proto.Message, contentType string) ([]byte, string, error) {
	mediaType := contentType
	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
		mediaType = mediaType[:i]
	}
	switch strings.ToLower(strings.TrimSpace(mediaType)) {
	case headerValueProtobuf:
		marshaled, err := proto.Marshal(m)
		return marshaled, headerValueProtobuf, err
	default:
		// headerValueJSON is deliberately handled here. A Go case with an
		// empty body does not fall through, so a separate bare
		// `case headerValueJSON:` would skip marshaling entirely and return
		// an empty body with an empty content type.
		marshaled, err := mrshlr.Marshal(m)
		return marshaled, headerValueJSON, err
	}
}
