package handlers

import (
	"fmt"
	"net/url"
	"strings"

	"github.com/labstack/echo/v4"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/passport/infra/daemons/tvmtool/internal/errs"
	"a.yandex-team.ru/passport/infra/daemons/tvmtool/internal/tvmcontext"
	"a.yandex-team.ru/passport/infra/daemons/tvmtool/internal/tvmtypes"
)

// Special values of the required-roles query parameter, interpreted by
// checkRoles:
//   - requiredRolesAny:  the consumer must have at least one role;
//   - requiredRolesNone: the role check always succeeds.
//
// Any other value is treated by checkRoles as a comma-separated list of roles
// that must all be present.
const (
	requiredRolesAny  = "any"
	requiredRolesNone = "none"
)

// checkResponseV2 is the top-level JSON body produced by CheckHandlerV2.
//
// Status is StatusOK when every ticket that was checked passed, and
// StatusCheckFailed otherwise (see mergeCheckResponse). Error joins the
// short descriptions of the failed parts and is omitted when empty.
// Service and User hold the per-ticket results; each one is nil when the
// corresponding ticket header was absent from the request. Neither field
// uses omitempty, so a nil part is still emitted by encoding/json.
type checkResponseV2 struct {
	Status  string                `json:"status"`
	Error   string                `json:"error,omitempty"`
	Service *checkResponseService `json:"service"`
	User    *checkResponseUser    `json:"user"`
}

// checkResponseService is the result of checking a service ticket.
//
// The embedded checkResponseCommon carries the status and diagnostic strings.
// Src is filled from the SrcID of the successfully checked ticket. Roles is
// filled by checkServiceRoles only when roles are requested to be shown
// (see showRoles).
type checkResponseService struct {
	checkResponseCommon
	Src   tvm.ClientID       `json:"src"`
	Roles checkResponseRoles `json:"roles"`
}

// checkResponseUser is the result of checking a user ticket.
//
// The embedded checkResponseCommon carries the status and diagnostic strings.
// DefaultUID and UIDs are the ticket's UIDs rendered as decimal strings,
// Scopes is copied from the checked ticket. Roles maps a decimal UID string
// to the roles found for that UID; it is filled by checkUserRoles only when
// roles are requested to be shown (see showRoles).
type checkResponseUser struct {
	checkResponseCommon
	DefaultUID string                        `json:"default_uid"`
	UIDs       []string                      `json:"uids"`
	Scopes     []string                      `json:"scopes"`
	Roles      map[string]checkResponseRoles `json:"roles"`
}

// checkResponseCommon holds the fields shared by the service and user parts
// of the response.
//
// Status is the per-ticket status (StatusOK on success). Error is the detailed
// failure message and is omitted from JSON when empty. DebugString and
// LoggingString are diagnostic strings taken from the checked ticket or, for
// an invalid ticket, from the errs.Forbidden error. topLevelError is a short
// description of the failure that is never serialized; mergeCheckResponse
// uses it to build the top-level error message.
type checkResponseCommon struct {
	Status        string `json:"status"`
	Error         string `json:"error,omitempty"`
	DebugString   string `json:"debug_string"`
	LoggingString string `json:"logging_string"`
	topLevelError string `json:"-"`
}

// checkResponseRoles is the JSON representation of a consumer's roles: it is
// keyed by the keys of tvm.ConsumerRoles.GetRoles() (presumably role names -
// confirm against the tvm package) and holds the entities attached to each
// one, as built by convertRoles.
type checkResponseRoles map[string][]tvm.Entity

// checkProvider is the set of data sources the check handler needs: the
// context used to verify service tickets, the context used to verify user
// tickets, and the roles stored for a given IDM slug. CheckHandlerV2 receives
// it as its cache argument.
type checkProvider interface {
	GetServiceContext() (*tvmcontext.ServiceContext, error)
	GetUserContext() (*tvmcontext.UserContext, error)
	GetRoles(slug string) (*tvm.Roles, error)
}

// CheckHandlerV2 returns an echo handler that validates the service ticket
// and/or the user ticket found in the request headers (SrvTicketHeader and
// UsrTicketHeader).
//
// Each header that is present and non-empty is checked independently, the
// service ticket first. An error returned by either check aborts the request
// and is returned to echo as is. When neither header is set the handler
// returns *errs.InvalidParam. Otherwise both results are merged into a
// checkResponseV2 and written as JSON with the HTTP code that
// statusToHTTPCode maps to the merged status.
func CheckHandlerV2(cfg *tvmtypes.OptimizedConfig, cache checkProvider) echo.HandlerFunc {
	return func(ctx echo.Context) error {
		req := ctx.Request()
		params := req.URL.Query()

		var (
			servicePart *checkResponseService
			userPart    *checkResponseUser
		)

		if srvTicket := req.Header.Get(SrvTicketHeader); srvTicket != "" {
			checked, err := checkServiceTicket(params, srvTicket, cfg, cache)
			if err != nil {
				return err
			}
			servicePart = checked
		}

		if usrTicket := req.Header.Get(UsrTicketHeader); usrTicket != "" {
			checked, err := checkUserTicket(params, usrTicket, cfg, cache)
			if err != nil {
				return err
			}
			userPart = checked
		}

		// Neither ticket was supplied: there is nothing to report on.
		if servicePart == nil && userPart == nil {
			return &errs.InvalidParam{
				Message: "nothing to do: there is no Service- or User-Ticket",
			}
		}

		merged := mergeCheckResponse(servicePart, userPart)
		httpCode := statusToHTTPCode[merged.Status]
		return ctx.JSON(httpCode, merged)
	}
}

// checkServiceTicket validates the service ticket passed in header on behalf
// of the client selected by the "self" query parameter.
//
// A non-nil error is returned when the check could not be performed at all:
// the query arguments are invalid, the roles for the client cannot be
// obtained (*errs.InvalidParam), or the service context is unavailable
// (*errs.Temporary). A failed check is not an error: the function then
// returns a result whose Status is StatusInvalidTicket or StatusNoRoles.
// On success the result has Status StatusOK.
func checkServiceTicket(
	query url.Values,
	header string,
	cfg *tvmtypes.OptimizedConfig,
	cache checkProvider,
) (*checkResponseService, error) {
	self, wantedRoles, err := getRequiredArgsForCheck(query, cfg, "required_service_roles")
	if err != nil {
		return nil, err
	}
	// An empty value means no role requirement was resolved for this client.
	rolesRequested := wantedRoles != ""

	// Roles are loaded before the ticket is verified, so a roles failure is
	// reported even for a ticket that would turn out to be invalid.
	var roles *tvm.Roles
	if rolesRequested {
		if roles, err = cache.GetRoles(self.IdmSlug); err != nil {
			return nil, &errs.InvalidParam{
				Message: fmt.Sprintf("failed to get roles for alias: '%s': %s", self.Alias, err),
			}
		}
	}

	srvctx, err := cache.GetServiceContext()
	switch {
	case err != nil:
		return nil, &errs.Temporary{Message: err.Error()}
	case srvctx == nil:
		return nil, &errs.Temporary{Message: "internal error: missing srvctx"}
	}

	ticket, err := srvctx.CheckTicket(header, self.SelfTvmID)
	if err != nil {
		rejected := new(checkResponseService)
		setInvalidTicketError(err, "invalid service ticket", &rejected.checkResponseCommon)
		return rejected, nil
	}

	result := &checkResponseService{Src: ticket.SrcID}
	result.DebugString = ticket.DbgInfo
	result.LoggingString = ticket.LogInfo

	if rolesRequested {
		rolesErr := checkServiceRoles(wantedRoles, showRoles(query), roles, ticket, result)
		if rolesErr != nil {
			setNoRolesError(rolesErr, "no roles for service", &result.checkResponseCommon)
			return result, nil
		}
	}

	result.Status = StatusOK
	return result, nil
}

// checkUserTicket validates the user ticket passed in header on behalf of the
// client selected by the "self" query parameter.
//
// A non-nil error is returned when the check could not be performed at all:
// the query arguments are invalid, the roles for the client cannot be
// obtained, the "override_env" value is rejected by getEnv
// (*errs.InvalidParam), or the user context is unavailable (*errs.Temporary).
// A failed check is not an error: the function then returns a result whose
// Status is StatusInvalidTicket, StatusNoRoles or StatusNoScopes. Roles are
// verified before the scopes listed in "required_user_scopes". On success the
// result has Status StatusOK.
func checkUserTicket(
	query url.Values,
	header string,
	cfg *tvmtypes.OptimizedConfig,
	cache checkProvider,
) (*checkResponseUser, error) {
	self, wantedRoles, err := getRequiredArgsForCheck(query, cfg, "required_user_roles")
	if err != nil {
		return nil, err
	}
	// An empty value means no role requirement was resolved for this client.
	rolesRequested := wantedRoles != ""

	// Roles are loaded before the ticket is verified, so a roles failure is
	// reported even for a ticket that would turn out to be invalid.
	var roles *tvm.Roles
	if rolesRequested {
		if roles, err = cache.GetRoles(self.IdmSlug); err != nil {
			return nil, &errs.InvalidParam{
				Message: fmt.Sprintf("failed to get roles for alias: '%s': %s", self.Alias, err),
			}
		}
	}

	usrctx, err := cache.GetUserContext()
	switch {
	case err != nil:
		return nil, &errs.Temporary{Message: err.Error()}
	case usrctx == nil:
		return nil, &errs.Temporary{Message: "internal error: missing usrctx"}
	}

	env, err := getEnv(usrctx.GetDefaultEnv(), query.Get("override_env"))
	if err != nil {
		return nil, &errs.InvalidParam{Message: err.Error()}
	}

	ticket, err := usrctx.CheckTicketOverriddenEnv(header, env)
	if err != nil {
		rejected := new(checkResponseUser)
		setInvalidTicketError(err, "invalid user ticket", &rejected.checkResponseCommon)
		return rejected, nil
	}

	result := &checkResponseUser{
		DefaultUID: fmt.Sprintf("%d", ticket.DefaultUID),
		UIDs:       convertUIDs(ticket.Uids),
		Scopes:     ticket.Scopes,
	}
	result.DebugString = ticket.DebugString
	result.LoggingString = ticket.LoggingString

	if rolesRequested {
		rolesErr := checkUserRoles(wantedRoles, showRoles(query), roles, ticket, result)
		if rolesErr != nil {
			setNoRolesError(rolesErr, "no roles for user", &result.checkResponseCommon)
			return result, nil
		}
	}

	if wantedScopes := query.Get("required_user_scopes"); wantedScopes != "" {
		for _, name := range strings.Split(wantedScopes, ",") {
			if ticket.HasScope(name) {
				continue
			}

			// The first missing scope ends the check.
			result.Status = StatusNoScopes
			result.Error = fmt.Sprintf("missing scope '%s'", name)
			result.topLevelError = "no scopes for user"
			return result, nil
		}
	}

	result.Status = StatusOK
	return result, nil
}

// mergeCheckResponse combines the per-ticket results into the top-level
// response. Either argument may be nil, meaning that ticket was not checked.
//
// The merged Status is StatusOK unless at least one non-nil part has a status
// other than StatusOK; in that case it is StatusCheckFailed and Error is the
// topLevelError of the failed parts (service first, then user) joined
// with ". ".
func mergeCheckResponse(service *checkResponseService, user *checkResponseUser) *checkResponseV2 {
	merged := &checkResponseV2{
		Status:  StatusOK,
		Service: service,
		User:    user,
	}

	var failures []string
	if service != nil && service.Status != StatusOK {
		failures = append(failures, service.topLevelError)
	}
	if user != nil && user.Status != StatusOK {
		failures = append(failures, user.topLevelError)
	}

	if len(failures) > 0 {
		merged.Status = StatusCheckFailed
		merged.Error = strings.Join(failures, ". ")
	}

	return merged
}

// checkServiceRoles looks up the roles of the service that issued the checked
// ticket and verifies them against requiredRoles (see checkRoles for the
// accepted values).
//
// roles may be nil, in which case the service is treated as having no roles.
// When showRoles is true the roles found are also stored in result.Roles,
// whether or not the check succeeds. The returned error describes the missing
// role; nil means the requirement is satisfied.
func checkServiceRoles(
	requiredRoles string,
	showRoles bool,
	roles *tvm.Roles,
	checked *tvmcontext.CheckedServiceTicket,
	result *checkResponseService,
) error {
	// The roles API takes the tvm package's ticket type, so build one that
	// carries only the source ID.
	ticket := &tvm.CheckedServiceTicket{SrcID: checked.SrcID}

	var granted *tvm.ConsumerRoles
	if roles != nil {
		granted = roles.GetRolesForService(ticket)
	}

	if showRoles {
		result.Roles = convertRoles(granted)
	}

	return checkRoles(requiredRoles, granted, nil)
}

// checkUserRoles looks up the roles of the checked user ticket and verifies
// them against requiredRoles (see checkRoles for the accepted values).
//
// The verification uses the roles returned by GetRolesForUser with a nil UID
// selector (presumably the ticket's default UID - confirm against the tvm
// package). roles may be nil, in which case the user is treated as having no
// roles. When showRoles is true, result.Roles is filled with the roles of
// every UID in the ticket, whether or not the check succeeds. The returned
// error is either the lookup error or a description of the missing role; nil
// means the requirement is satisfied.
func checkUserRoles(
	requiredRoles string,
	showRoles bool,
	roles *tvm.Roles,
	checked *tvmcontext.CheckedUserTicket,
	result *checkResponseUser,
) error {
	// The roles API takes the tvm package's ticket type, so copy over the
	// fields it is given here.
	ticket := &tvm.CheckedUserTicket{
		DefaultUID: checked.DefaultUID,
		UIDs:       checked.Uids,
		Env:        checked.Env,
	}

	var (
		granted   *tvm.ConsumerRoles
		lookupErr error
	)
	if roles != nil {
		granted, lookupErr = roles.GetRolesForUser(ticket, nil)
	}

	if showRoles {
		result.Roles = convertUserRoles(ticket, roles)
	}

	// The lookup error is handed to checkRoles so that a "none" requirement
	// can still pass.
	return checkRoles(requiredRoles, granted, lookupErr)
}

// checkRoles verifies consumerRoles against the requirement in requiredRoles.
//
// The cases are evaluated in this order:
//   - requiredRolesNone: always succeeds, even when rolesErr is set;
//   - rolesErr != nil: rolesErr is returned unchanged;
//   - requiredRolesAny: succeeds when the consumer has at least one role;
//   - otherwise requiredRoles is a comma-separated list and every listed role
//     must be present; the error names the first one that is missing.
func checkRoles(requiredRoles string, consumerRoles *tvm.ConsumerRoles, rolesErr error) error {
	switch {
	case requiredRoles == requiredRolesNone:
		return nil
	case rolesErr != nil:
		return rolesErr
	case requiredRoles == requiredRolesAny:
		if len(consumerRoles.GetRoles()) > 0 {
			return nil
		}
		return xerrors.Errorf("missing any role")
	}

	for _, name := range strings.Split(requiredRoles, ",") {
		if !consumerRoles.HasRole(name) {
			return xerrors.Errorf("missing role '%s'", name)
		}
	}

	return nil
}

// setInvalidTicketError marks result as a failed ticket check.
//
// Status becomes StatusInvalidTicket, Error is the text of err and
// topLevelError is the short description used in the merged response. err
// must be non-nil.
//
// When err is, or wraps, an *errs.Forbidden, its DebugString and
// LoggingString are copied to result so that the caller still gets the ticket
// diagnostics. The lookup goes through xerrors.As rather than a direct type
// assertion: a type assertion only matches when err itself is the
// *errs.Forbidden and would silently drop the diagnostics as soon as some
// layer wraps the error.
func setInvalidTicketError(err error, description string, result *checkResponseCommon) {
	result.Status = StatusInvalidTicket
	result.Error = err.Error()
	result.topLevelError = description

	var forbidden *errs.Forbidden
	if xerrors.As(err, &forbidden) {
		result.DebugString = forbidden.DebugString
		result.LoggingString = forbidden.LoggingString
	}
}

// setNoRolesError marks result as a failed role check: Status becomes
// StatusNoRoles, Error is the text of err and topLevelError is the short
// description used in the merged response. The other fields of result are
// left untouched. err must be non-nil.
func setNoRolesError(err error, description string, result *checkResponseCommon) {
	result.Status = StatusNoRoles

	reason := err.Error()
	result.Error, result.topLevelError = reason, description
}

// showRoles reports whether the roles found should be included in the
// response. Roles are shown unless the first "show_roles" query value is
// exactly "no"; the comparison is case-sensitive and a missing parameter
// means "show".
func showRoles(args url.Values) bool {
	switch args.Get("show_roles") {
	case "no":
		return false
	default:
		return true
	}
}

// convertRoles turns a consumer's roles into their JSON representation: one
// entry per key of tvmRoles.GetRoles(), holding the entities returned by
// GetEntitiesWithAttrs(nil) for it. The result is never nil; a nil tvmRoles
// yields an empty map.
func convertRoles(tvmRoles *tvm.ConsumerRoles) checkResponseRoles {
	if tvmRoles == nil {
		return checkResponseRoles{}
	}

	byRole := tvmRoles.GetRoles()
	converted := make(checkResponseRoles, len(byRole))
	for role, entities := range byRole {
		converted[role] = entities.GetEntitiesWithAttrs(nil)
	}

	return converted
}

// convertUserRoles builds the JSON representation of the roles of every UID
// listed in ticket. The result is keyed by the decimal UID and is never nil.
// A nil roles, or a UID whose lookup yields nothing, produces an empty entry
// for that UID.
func convertUserRoles(ticket *tvm.CheckedUserTicket, roles *tvm.Roles) map[string]checkResponseRoles {
	perUID := make(map[string]checkResponseRoles, len(ticket.UIDs))

	for i := range ticket.UIDs {
		// Work on a copy so the pointer handed to GetRolesForUser never
		// aliases the ticket's slice.
		uid := ticket.UIDs[i]

		var granted *tvm.ConsumerRoles
		if roles != nil {
			// The lookup error is deliberately ignored here: this map is only
			// informational, and whatever roles value came back is converted.
			granted, _ = roles.GetRolesForUser(ticket, &uid)
		}

		perUID[fmt.Sprintf("%d", uid)] = convertRoles(granted)
	}

	return perUID
}

// convertUIDs renders every UID as a decimal string, preserving order. The
// result is never nil, so an empty input is encoded as an empty JSON array
// rather than null.
func convertUIDs(nums []tvm.UID) []string {
	rendered := make([]string, 0, len(nums))

	for _, uid := range nums {
		rendered = append(rendered, fmt.Sprintf("%d", uid))
	}

	return rendered
}

// getRequiredArgsForCheck extracts the arguments every ticket check needs.
//
// It reads the mandatory "self" query parameter and resolves it to a client
// from cfg; an unknown alias yields *errs.InvalidParam. The parameter named
// by requiredParam is read - and is then mandatory - only for a client that
// has an IDM slug; for any other client the returned roles string is empty
// and requiredParam is not looked at.
func getRequiredArgsForCheck(
	query url.Values,
	cfg *tvmtypes.OptimizedConfig,
	requiredParam string,
) (*tvmtypes.Client, string, error) {
	alias, err := getRequiredStringParam(query, "self")
	if err != nil {
		return nil, "", err
	}

	client := cfg.FindClientByAlias(alias)
	if client == nil {
		return nil, "", &errs.InvalidParam{
			Message: fmt.Sprintf("couldn't find client in config by alias: '%s'", alias),
		}
	}

	// Without an IDM slug no role requirement is read for this client.
	if client.IdmSlug == "" {
		return client, "", nil
	}

	wantedRoles, err := getRequiredStringParam(query, requiredParam)
	if err != nil {
		return nil, "", err
	}

	return client, wantedRoles, nil
}
