package httpreplay

import (
	"bytes"
	"context"
	"crypto/sha1"
	"crypto/tls"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"time"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/libs/go/boombox/httpmodel"
	"a.yandex-team.ru/security/libs/go/boombox/tape"
)

// Compile-time check that *Replay can be mounted anywhere an http.Handler is
// expected; the build fails if ServeHTTP's signature ever drifts.
var (
	_ http.Handler = (*Replay)(nil)
)

// DefaultNamespace is the tape namespace a Replay starts with in NewReplay;
// options applied there may replace it.
const DefaultNamespace = "default"

// Replay is an http.Handler that answers requests from a tape of recorded
// HTTP responses (doReply). It can also forward requests to an upstream and
// record the replies (doProxy); which of the two is used is decided by doFunc.
// NOTE(review): presumably doProxy is selected through an Option — confirm.
// Construct it with NewReplay.
type Replay struct {
	storage   *tape.Tape       // recorded responses; read by doReply, written by doProxy
	namespace string           // tape namespace for every key; DefaultNamespace unless an option overrides it
	upstream  *url.URL         // base URL that doProxy forwards requests to
	httpc     *http.Client     // client used by doProxy; as built in NewReplay it returns 3xx responses instead of following them
	https     *http.Server     // server started by ListenAndServe, nil otherwise
	httpt     *httptest.Server // server started by TestServer/TestURL, nil otherwise
	mode      Mode             // ModeReply by default; NOTE(review): not read by the code shown here, presumably used by options
	keyFunc   KeyFunc          // maps a request to its tape key; DefaultKeyFunc by default
	doFunc    doer             // resolves a request to a response; doReply by default
	log       log.Logger       // logger, tagged with boombox_namespace in NewReplay
}

// doer resolves one incoming request to the response that should be sent
// back, e.g. by reading it from the tape or by fetching it upstream.
type doer func(context.Context, *httpmodel.Request) (*httpmodel.Response, error)

// NewReplay builds a Replay that serves responses from tape.
//
// Defaults: ModeReply with doReply as the backend, DefaultNamespace,
// DefaultKeyFunc and a no-op logger. Each Option in opts is applied in order
// and may override them; the first Option that fails aborts construction and
// its error is returned.
//
// Upstream requests use a client that does not follow redirects. When
// certifi.NewCertPoolSystem succeeds, the client trusts the pool it returns.
// Otherwise the client falls back to http.DefaultTransport and a warning is
// logged.
func NewReplay(tape *tape.Tape, opts ...Option) (*Replay, error) {
	httpc := &http.Client{
		// Return 3xx responses as-is so they are recorded/replayed verbatim.
		CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	certPool, certErr := certifi.NewCertPoolSystem()
	if certErr == nil {
		// Clone the default transport rather than starting from a zero
		// http.Transport. A zero transport would drop proxy-from-environment,
		// dial and TLS-handshake timeouts, idle-connection limits and HTTP/2.
		transport := &http.Transport{}
		if def, ok := http.DefaultTransport.(*http.Transport); ok {
			transport = def.Clone()
		}
		transport.TLSClientConfig = &tls.Config{RootCAs: certPool}
		httpc.Transport = transport
	}

	httpr := &Replay{
		storage:   tape,
		httpc:     httpc,
		mode:      ModeReply,
		namespace: DefaultNamespace,
		keyFunc:   DefaultKeyFunc,
		log:       &nop.Logger{},
	}
	httpr.doFunc = httpr.doReply

	for _, opt := range opts {
		if err := opt(httpr); err != nil {
			return nil, err
		}
	}

	httpr.log = log.With(httpr.log, log.String("boombox_namespace", httpr.namespace))
	if certErr != nil {
		// Best effort: without a cert pool, httpc keeps a nil Transport,
		// i.e. http.DefaultTransport and its default roots.
		httpr.log.Warn("can't load cert pool, upstream TLS falls back to http.DefaultTransport", log.Error(certErr))
	}
	return httpr, nil
}

// ServeHTTP implements http.Handler.
//
// It converts r into an httpmodel.Request and resolves it through h.doFunc
// (doReply unless an option installs another backend). It then writes the
// resulting response to w. Every failure is logged with the stage it happened
// in and answered with 500 Internal Server Error carrying the error text.
// Successful requests are logged with status, body size and latency.
func (h *Replay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	start := time.Now()

	// fail records err server-side, then reports it to the client.
	fail := func(stage string, err error) {
		h.log.Error("request failed",
			log.String("uri", r.RequestURI),
			log.String("stage", stage),
			log.Error(err),
			log.Duration("elapsed", time.Since(start)),
		)
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}

	outReq, err := httpmodel.RequestFromHTTP(r)
	if err != nil {
		fail("convert request", err)
		return
	}

	rsp, err := h.doFunc(r.Context(), outReq)
	if err != nil {
		fail("resolve response", err)
		return
	}

	if err := rsp.WriteTo(w); err != nil {
		// NOTE(review): if WriteTo already sent the status line, http.Error
		// can only append its text to the body. It is kept for the case
		// where WriteTo fails before writing anything.
		fail("write response", err)
		return
	}

	h.log.Info("request",
		log.String("uri", r.RequestURI),
		log.Int32("status", rsp.StatusCode),
		log.Int("bytes", len(rsp.Body)),
		log.Duration("elapsed", time.Since(start)),
	)
}

// ListenAndServe listens on the TCP address addr and serves h on it until
// the server is stopped via Shutdown. Like http.Server.Serve it always returns
// a non-nil error; after Shutdown that error is http.ErrServerClosed.
func (h *Replay) ListenAndServe(addr string) error {
	// Bound how long a client may take to send its request headers so that
	// stalled connections can't pin goroutines forever (Slowloris).
	const readHeaderTimeout = 30 * time.Second

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("can't listen '%s': %w", addr, err)
	}

	h.https = &http.Server{
		Handler:           h,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	h.log.Infof("starting boombox: %s", listener.Addr().String())
	return h.https.Serve(listener)
}

// TestServer starts a new httptest.Server backed by h and returns its base
// URL. The server is remembered so TestURL and Shutdown can find it.
// NOTE(review): calling TestServer again replaces the remembered server
// without closing the previous one.
func (h *Replay) TestServer() string {
	srv := httptest.NewServer(h)
	h.httpt = srv
	return srv.URL
}

// TestURL returns the base URL of the test server serving h, starting one
// with TestServer the first time it is needed.
func (h *Replay) TestURL() string {
	if srv := h.httpt; srv != nil {
		return srv.URL
	}
	return h.TestServer()
}
// Shutdown stops every server this Replay started.
//
// The test server from TestServer/TestURL is closed first.
// httptest.Server.Close waits for in-flight requests and does not consult
// ctx. The ListenAndServe server is then shut down gracefully within ctx, and
// its error, if any, is returned. Both servers are handled even when both
// were started.
func (h *Replay) Shutdown(ctx context.Context) error {
	if h.httpt != nil {
		h.httpt.Close()
	}

	if h.https != nil {
		return h.https.Shutdown(ctx)
	}

	return nil
}

// doProxy forwards r to h.upstream and stores the upstream response in the
// tape under h.namespace and h.keyFunc(r). It then returns that response.
// Responses are recorded whatever their status code. With the client built
// by NewReplay, redirects are returned rather than followed, so 3xx replies
// are recorded verbatim too.
func (h *Replay) doProxy(ctx context.Context, r *httpmodel.Request) (*httpmodel.Response, error) {
	// Fail with an error instead of a nil-pointer panic if proxying was
	// enabled without an upstream.
	if h.upstream == nil {
		return nil, errors.New("boombox: no upstream configured for proxying")
	}

	req, err := r.NewHTTPRequest(ctx)
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}

	// Re-target the request at the upstream. The path is joined with the
	// upstream base path via singleJoiningSlash. net/http takes the Host
	// sent on the wire from req.Host; the explicit Host header is redundant
	// for net/http but harmless.
	req.URL.Scheme = h.upstream.Scheme
	req.URL.Host = h.upstream.Host
	req.URL.Path = singleJoiningSlash(h.upstream.Path, req.URL.Path)
	req.Host = h.upstream.Host
	req.Header.Set("Host", h.upstream.Host)
	if _, ok := req.Header["User-Agent"]; !ok {
		// An explicitly empty User-Agent stops net/http from adding its
		// default one, so upstream sees no User-Agent, as the client sent.
		req.Header.Set("User-Agent", "")
	}

	httpRsp, err := h.httpc.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() {
		// Drain a bounded remainder so the connection can be reused, then close.
		_, _ = io.CopyN(io.Discard, httpRsp.Body, 128<<10)
		_ = httpRsp.Body.Close()
	}()

	rsp, err := httpmodel.ResponseFromHTTP(httpRsp)
	if err != nil {
		return nil, fmt.Errorf("read upstream response: %w", err)
	}

	if err := h.storage.SaveHTTPResponse(h.namespace, h.keyFunc(r), rsp); err != nil {
		return nil, fmt.Errorf("record upstream response: %w", err)
	}

	return rsp, nil
}

// doReply looks up the recorded response for r in the tape, under
// h.namespace and the key produced by h.keyFunc.
//
// A missing record is not an error: it is answered with a bare 404 (no
// headers, no body). tape.ErrRecordNotFound is matched with errors.Is, so a
// wrapped not-found error also yields the 404. Any other storage error is
// returned unchanged.
func (h *Replay) doReply(_ context.Context, r *httpmodel.Request) (*httpmodel.Response, error) {
	rsp, err := h.storage.GetHTTPResponse(h.namespace, h.keyFunc(r))
	switch {
	case err == nil:
		return rsp, nil
	case errors.Is(err, tape.ErrRecordNotFound):
		return &httpmodel.Response{StatusCode: http.StatusNotFound}, nil
	default:
		return nil, err
	}
}

// DefaultKeyFunc keys a request by method and URL only, as "METHOD@URL".
// Requests that differ only in their body therefore share a tape entry.
func DefaultKeyFunc(r *httpmodel.Request) string {
	const sep = "@"
	method, rawURL := r.Method, r.Url
	return method + sep + rawURL
}

// BodyKeyFunc keys a request by method, body and URL, as
// "METHOD@<hex SHA-1 of body>@URL". Requests with the same method and URL
// but different bodies get distinct tape entries; an empty body hashes to
// the SHA-1 of zero bytes.
func BodyKeyFunc(r *httpmodel.Request) string {
	digest := sha1.Sum(r.Body)

	var key bytes.Buffer
	key.WriteString(r.Method)
	key.WriteByte('@')
	key.WriteString(hex.EncodeToString(digest[:]))
	key.WriteByte('@')
	key.WriteString(r.Url)
	return key.String()
}
