package tirole

import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/labstack/echo/v4"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/library/go/yandex/tvm/tvmauth"
	"a.yandex-team.ru/passport/infra/daemons/tirole/internal/errs"
	"a.yandex-team.ru/passport/infra/daemons/tirole/internal/model"
	"a.yandex-team.ru/passport/infra/daemons/tirole/internal/model/unittest"
	"a.yandex-team.ru/passport/infra/daemons/tirole/internal/model/ytc"
	"a.yandex-team.ru/passport/infra/daemons/tirole_internal/keys"
	"a.yandex-team.ru/passport/shared/golibs/httpdaemon"
	"a.yandex-team.ru/passport/shared/golibs/httpdaemon/httpdtvm"
	"a.yandex-team.ru/passport/shared/golibs/httpdaemon/middlewares"
	"a.yandex-team.ru/passport/shared/golibs/logger"
	"a.yandex-team.ru/passport/shared/golibs/unistat"
)

// Tirole is the tirole HTTP service. Routes are registered in AddHandlers,
// the extra unistat endpoint in GetOptions; instances are built by
// Factory.NewService.
type Tirole struct {
	// tvm checks incoming service tickets (middlewareTvmAuth); its status is
	// also reported by /ping.
	tvm       tvm.Client
	// backend serves actual roles, revisions and slug -> tvmid mappings, and
	// is pinged by /ping.
	backend   model.BackendProvider
	// cfg is the parsed service configuration.
	cfg       Config
	// accessLog receives one tab-separated line per request from
	// middlewareAccessLog.
	accessLog log.Logger
	// unistat holds the service's monitoring signals.
	unistat   stats
	// keyMap holds the HMAC keys (looked up by key id) used by CheckSign.
	keyMap    *keys.KeyMap
}

// CommonConfig contains service-wide settings.
type CommonConfig struct {
	// ForceDownFile is a path; while a file exists there, /ping answers 503
	// ("Service is forced down").
	ForceDownFile string `json:"force_down_file"`
	// AccessLog is the file path the access log is written to.
	AccessLog     string `json:"access_log"`
}

// Config is the service configuration parsed by Factory.NewService.
type Config struct {
	Common   CommonConfig       `json:"common"`
	// Tvm configures the TVM client; SelfID is also used in unittest mode and
	// in "expected dst" diagnostics.
	Tvm      httpdtvm.TvmConfig `json:"tvm"`
	// Yt configures the YT backend (used only when Unittest is nil).
	Yt       ytc.Config         `json:"yt"`
	// Unittest, when non-nil, replaces the real TVM client and YT backend
	// with the unittest client and the unittest roles provider.
	Unittest *unittest.Config   `json:"unittest"`
	// KeyMap configures the HMAC keys used to verify role blobs.
	KeyMap   keys.Config        `json:"key_map"`
}

// stats groups the unistat signals exported by the service.
type stats struct {
	// err* are request error counters registered in unistat.DefaultChunk
	// under "errors.requests.*". They are not incremented in this file;
	// presumably the error-handling middleware does it — TODO confirm.
	errInvalidParams *unistat.SignalDiff
	errNoRoles       *unistat.SignalDiff
	errTmp           *unistat.SignalDiff
	errUnauthorized  *unistat.SignalDiff
	errUnknown       *unistat.SignalDiff
	errBadSign       *unistat.SignalDiff
	// rolesFullContent counts 200 replies of /v1/get_actual_roles that carry
	// the full roles blob.
	rolesFullContent *unistat.SignalDiff

	// consumersChunk is a separate chunk served at /unistat/consumers.
	consumersChunk *unistat.Chunk
	// consumers holds one signal per caller, named "tvmid_<src id>".
	consumers      *unistat.SignalSet
}

// Factory creates Tirole services for httpdaemon; see NewService.
type Factory struct{}

// NewService builds a Tirole service from the daemon-provided config.
//
// It parses Config, opens the access log, loads the HMAC key map and
// registers the unistat signals, then creates the TVM client and the roles
// backend. With cfg.Unittest == nil the real TVM client (httpdtvm.InitTvm)
// and the YT backend (ytc.InitYt) are used; otherwise the unittest TVM
// client and the unittest provider are created. The first error encountered
// is returned unchanged.
func (f *Factory) NewService(config httpdaemon.ServiceConfig) (httpdaemon.Service, error) {
	var cfg Config
	if err := httpdaemon.ParseServiceConfig(config, &cfg); err != nil {
		return nil, err
	}

	accessLog, err := logger.CreateLog(logger.Config{
		FilePath:             cfg.Common.AccessLog,
		DisablePrintingLevel: true,
	})
	if err != nil {
		return nil, err
	}

	keyMap, err := keys.InitKeyMap(cfg.KeyMap)
	if err != nil {
		return nil, err
	}

	perConsumer := unistat.NewChunk()
	counters := stats{
		errInvalidParams: unistat.DefaultChunk.CreateSignalDiff("errors.requests.invalid_params"),
		errNoRoles:       unistat.DefaultChunk.CreateSignalDiff("errors.requests.no_roles"),
		errTmp:           unistat.DefaultChunk.CreateSignalDiff("errors.requests.tmp_err"),
		errUnauthorized:  unistat.DefaultChunk.CreateSignalDiff("errors.requests.unauthorized"),
		errUnknown:       unistat.DefaultChunk.CreateSignalDiff("errors.requests.unknown"),
		errBadSign:       unistat.DefaultChunk.CreateSignalDiff("errors.requests.bad_sign"),
		rolesFullContent: unistat.DefaultChunk.CreateSignalDiff("roles.full_content"),
		consumersChunk:   perConsumer,
		consumers:        perConsumer.CreateSignalSet(""),
	}

	svc := &Tirole{
		cfg:       cfg,
		accessLog: accessLog,
		unistat:   counters,
		keyMap:    keyMap,
	}

	if cfg.Unittest != nil {
		// Unittest mode: unittest TVM client and unittest roles provider.
		settings := tvmauth.TvmUnittestSettings{SelfID: cfg.Tvm.SelfID}
		if svc.tvm, err = tvmauth.NewUnittestClient(settings); err != nil {
			return nil, err
		}
		if svc.backend, err = unittest.NewProvider(*cfg.Unittest, keyMap); err != nil {
			return nil, err
		}
		return svc, nil
	}

	if svc.tvm, err = httpdtvm.InitTvm(cfg.Tvm); err != nil {
		return nil, err
	}
	if svc.backend, err = ytc.InitYt(cfg.Yt); err != nil {
		return nil, err
	}
	return svc, nil
}

// AddHandlers registers the service routes on e.
//
// Middlewares added via e.Pre (before routing): access logging and echoing
// the request id in X-Request-Id. The error-handling middleware is added via
// e.Use. /v1/get_actual_roles additionally requires a valid TVM service
// ticket (middlewareTvmAuth).
func (t *Tirole) AddHandlers(e *echo.Echo) {
	accessLog := t.middlewareAccessLog()
	sendReqID := t.middlewareSendReqID()
	e.Pre(accessLog, sendReqID)

	e.Use(t.handleErrorMiddleware())

	e.GET("/ping", t.HandlePing())

	rolesHandler := t.HandleV1GetActualRoles()
	rolesAuth := t.middlewareTvmAuth()
	e.GET("/v1/get_actual_roles", rolesHandler, rolesAuth)
}

// GetOptions returns extra daemon options. It hooks /unistat/consumers
// (served by HandleConsumers) into httpdaemon's ExtendedUnistatHandler.
func (t *Tirole) GetOptions() *httpdaemon.Options {
	opts := &httpdaemon.Options{}
	opts.ExtendedUnistatHandler = func(e *echo.Echo) {
		e.GET("/unistat/consumers", t.HandleConsumers())
	}
	return opts
}

// HandlePing returns the /ping handler. Checks, in order:
//   - TVM client status cannot be fetched, or is tvm.ClientError -> 500;
//   - the force-down file exists -> 503;
//   - backend Ping fails -> 503 ("YT is unavailable");
//   - otherwise 200 with an empty body.
func (t *Tirole) HandlePing() echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()

		tvmStatus, statusErr := t.tvm.GetStatus(ctx)
		switch {
		case statusErr != nil:
			logger.Log().Errorf("Ping: Failed to get status from TVM client: %s", statusErr)
			return c.String(http.StatusInternalServerError, "Failed to get status from TVM client")
		case tvmStatus.Status == tvm.ClientError:
			logger.Log().Errorf("Ping: bad TVM client status: %s", tvmStatus.LastError)
			return c.String(http.StatusInternalServerError, fmt.Sprintf("TVM: %s", tvmStatus.LastError))
		}

		// Any successful Stat means the operator asked to take the instance down.
		if _, statErr := os.Stat(t.cfg.Common.ForceDownFile); statErr == nil {
			logger.Log().Debugf("Ping: Service is forced down")
			return c.String(http.StatusServiceUnavailable, "Service is forced down")
		}

		if pingErr := t.backend.Ping(ctx); pingErr != nil {
			logger.Log().Warnf("Ping: YT is unavailable: %s", pingErr)
			return c.String(http.StatusServiceUnavailable, "YT is unavailable")
		}

		logger.Log().Debugf("Ping: service is up")
		return c.String(http.StatusOK, "")
	}
}

// HTTP-level constants used by the handlers and middlewares in this file.
const (
	// contentTypeBlob is the Content-Type of roles replies (both 200 and 304).
	contentTypeBlob = "application/octet-stream"

	// Header names.
	headerContentLength     = "Content-Length"
	headerETag              = "ETag"
	headerIfNoneMatch       = "If-None-Match"
	headerRequestID         = "X-Request-Id"
	headerTiroleCompression = "X-Tirole-Compression"

	// paramSystemSlug is the required query parameter of /v1/get_actual_roles.
	paramSystemSlug = "system_slug"
)

// CheckSign verifies the HMAC-SHA256 of r.Blob against r.Meta.EncodedHmac.
//
// EncodedHmac must have the form "<key id>:<hex hmac>" with both parts
// non-empty; the key is looked up in t.keyMap. Returns nil when the computed
// HMAC matches, otherwise an error describing the first problem found
// (bad format, unknown key id, non-hex body, or mismatch).
func (t *Tirole) CheckSign(r *model.ActualRoles) error {
	encoded := r.Meta.EncodedHmac

	parts := strings.Split(encoded, ":")
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return xerrors.Errorf("Invalid encoded_hmac format: '%s'", encoded)
	}
	keyID, signatureHex := parts[0], parts[1]

	secret := t.keyMap.GetKey(keyID)
	if secret == nil {
		return xerrors.Errorf("Invalid encoded_hmac key id: '%s'", keyID)
	}

	expected, decodeErr := hex.DecodeString(signatureHex)
	if decodeErr != nil {
		return xerrors.Errorf("Invalid encoded_hmac body format: '%s'", signatureHex)
	}

	mac := hmac.New(sha256.New, secret)
	// hash.Hash.Write never returns an error.
	mac.Write(r.Blob)

	// Constant-time comparison.
	if !hmac.Equal(mac.Sum(nil), expected) {
		return xerrors.Errorf("encoded_hmac check mismatch")
	}
	return nil
}

// HandleV1GetActualRoles returns the handler for GET /v1/get_actual_roles.
//
// Flow:
//  1. Read the required query parameter system_slug.
//  2. Check that the caller's TVM id is mapped to that slug (checkMapping).
//     The id comes from the service ticket that middlewareTvmAuth stores in
//     the request context.
//  3. If If-None-Match carries the current revision, reply 304 with the same
//     ETag a 200 would carry.
//  4. Otherwise load the roles and verify their HMAC (CheckSign). Then reply
//     200 with the raw blob plus ETag, Content-Length and X-Tirole-Compression.
//
// Backend/parameter errors are returned as-is; a signature failure is
// returned as *errs.BadSignError.
func (t *Tirole) HandleV1GetActualRoles() echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()

		st := tvm.ContextServiceTicket(ctx)
		if st == nil {
			// Only reachable if the route is registered without
			// middlewareTvmAuth; fail cleanly instead of a nil dereference.
			return &errs.UnauthorizedError{
				Message: "missing service ticket",
			}
		}

		systemSlug, err := getRequiredStringParam(c, paramSystemSlug)
		if err != nil {
			return err
		}

		if err := t.checkMapping(ctx, systemSlug, st.SrcID); err != nil {
			return err
		}

		revisionToMatch, err := getRevisionFromHeader(headerIfNoneMatch, c.Request().Header.Get(headerIfNoneMatch))
		if err != nil {
			return err
		}
		if revisionToMatch != "" {
			revision, err := t.backend.GetActualRevision(ctx, systemSlug)
			if err != nil {
				return err
			}

			if revision == revisionToMatch {
				// RFC 7232 §4.1: a 304 must carry the ETag a 200 would have sent.
				c.Response().Header().Set(headerETag, fmt.Sprintf(`"%s"`, revision))
				// Use Blob() to trigger callback in c.Response().After()
				return c.Blob(http.StatusNotModified, contentTypeBlob, nil)
			}
		}

		roles, err := t.backend.GetActualRoles(ctx, systemSlug)
		if err != nil {
			return err
		}
		revision := roles.Meta.RevisionExt

		if err := t.CheckSign(&roles); err != nil {
			return &errs.BadSignError{
				Message: fmt.Sprintf("Sign check failed, slug '%s', revision %v: %s", systemSlug, revision, err.Error()),
			}
		}

		// For HTTP beauty
		c.Response().Header().Set(headerETag, fmt.Sprintf(`"%s"`, revision))

		// Blob generates 'Transfer-Encoding: chunked' by default.
		// 'Content-Length' should be provided to use efficient data transferring.
		c.Response().Header().Set(headerContentLength, fmt.Sprintf("%d", len(roles.Blob)))

		// Client relies on this format
		c.Response().Header().Set(headerTiroleCompression, prepareTiroleCompressionHeader(&roles.Meta))

		t.unistat.rolesFullContent.Inc()
		return c.Blob(http.StatusOK, contentTypeBlob, roles.Blob)
	}
}

// checkMapping verifies that the TVM application tvmid may read the roles of
// system slug, according to the backend's slug -> tvmid mapping.
// Backend errors are returned unchanged; an unmapped tvmid yields
// *errs.AccessDenied.
func (t *Tirole) checkMapping(ctx context.Context, slug string, tvmid tvm.ClientID) error {
	// TODO: add cache
	allowed, err := t.backend.GetMapping(ctx, slug)
	if err != nil {
		return err
	}

	if _, mapped := allowed[tvmid]; mapped {
		return nil
	}

	return &errs.AccessDenied{
		Message: fmt.Sprintf("system_slug '%s' is not mapped to tvmid=%d", slug, tvmid),
	}
}

// prepareTiroleCompressionHeader renders the X-Tirole-Compression value as
// "1:<codec>:<decoded size>:<decoded sha256>". Clients parse this exact
// layout. The leading literal "1" is presumably a format version — TODO confirm.
func prepareTiroleCompressionHeader(m *model.Meta) string {
	const prefix = "1"
	return fmt.Sprintf(prefix+":%s:%d:%s", m.Codec, m.DecodedSize, m.DecodedSha256)
}

// HandleConsumers serves the per-consumer request counters (one signal per
// caller TVM id, maintained by middlewareTvmAuth) with a JSON content type.
func (t *Tirole) HandleConsumers() echo.HandlerFunc {
	return func(c echo.Context) error {
		payload := t.unistat.consumersChunk.Serialize()
		return c.JSONBlob(http.StatusOK, payload)
	}
}

// middlewareTvmAuth authenticates requests by the X-Ya-Service-Ticket header.
//
// A missing or invalid ticket yields *errs.UnauthorizedError. When the
// failure is a *tvm.TicketError, its status is reported; an invalid-dst
// status also mentions the expected dst. On success the checked ticket is
// stored in the request context (tvm.WithServiceTicket) for downstream
// handlers and the access log, and the consumer counter "tvmid_<src>" is
// incremented.
func (t *Tirole) middlewareTvmAuth() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			st := c.Request().Header.Get("X-Ya-Service-Ticket")
			if st == "" {
				return &errs.UnauthorizedError{
					Message: "missing service ticket",
				}
			}

			ticket, err := t.tvm.CheckServiceTicket(c.Request().Context(), st)
			if err != nil {
				e := &errs.UnauthorizedError{
					Message:     "service ticket is invalid",
					Description: err.Error(),
				}
				if ticket != nil {
					e.LoggablePart = ticket.LogInfo
				}
				// CheckServiceTicket returns a plain error, so it may fail with
				// something other than *tvm.TicketError. An unchecked type
				// assertion would panic the request in that case.
				if ticketErr, ok := err.(*tvm.TicketError); ok {
					e.TicketStatus = ticketErr.Status.String()
					if ticketErr.Status == tvm.TicketInvalidDst {
						e.Description = fmt.Sprintf(
							"%s; expected dst is %d",
							e.Description, t.cfg.Tvm.SelfID,
						)
					}
				}

				return e
			}

			c.SetRequest(c.Request().WithContext(
				tvm.WithServiceTicket(c.Request().Context(), ticket),
			))
			t.unistat.consumers.CreateOrIncSignal(
				fmt.Sprintf("tvmid_%d", ticket.SrcID),
			)

			return next(c)
		}
	}
}

const (
	// emptyString is the access-log placeholder for absent values.
	emptyString = "-"
)

// middlewareAccessLog writes one tab-separated access-log line per request.
// The line is written from a c.Response().After callback, i.e. once the
// response is committed.
//
// Fields, in order: request id, client IP, consumer TVM id, URL path, latency
// in ms, HTTP status, response size, If-None-Match, raw query, response ETag.
// Missing values are rendered as emptyString.
func (t *Tirole) middlewareAccessLog() echo.MiddlewareFunc {
	orDash := func(s string) string {
		if s == "" {
			return emptyString
		}
		return s
	}

	// Named so it does not shadow the receiver t.
	consumerID := func(ticket *tvm.CheckedServiceTicket) string {
		if ticket != nil {
			return fmt.Sprintf("%d", ticket.SrcID)
		}
		return emptyString
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			c.Response().After(func() {
				req := c.Request()
				ctx := req.Context()
				resp := c.Response()

				started := middlewares.ContextStartInstant(ctx)
				ticket := tvm.ContextServiceTicket(ctx)

				t.accessLog.Debugf(
					"%s\t%s\t%s\t%s\t%.1fms\t%d\t%d\t%s\t%s\t%s",
					middlewares.ContextReqID(ctx),
					c.RealIP(),
					consumerID(ticket),
					req.URL.Path,
					float64(time.Since(started).Microseconds())/1000,
					resp.Status,
					resp.Size,
					orDash(req.Header.Get(headerIfNoneMatch)),
					orDash(req.URL.RawQuery),
					orDash(resp.Header().Get(headerETag)),
				)
			})

			return next(c)
		}
	}
}

// middlewareSendReqID copies the request id from the request context into the
// X-Request-Id response header before calling the next handler.
func (t *Tirole) middlewareSendReqID() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			reqID := middlewares.ContextReqID(c.Request().Context())
			c.Response().Header().Set(headerRequestID, reqID)
			return next(c)
		}
	}
}
