package tvmcontext

import (
	"encoding/base64"
	"fmt"
	"strings"
	"time"

	"a.yandex-team.ru/library/cpp/tvmauth/src/protos"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/passport/infra/daemons/tvmtool/internal/crypto"
	"a.yandex-team.ru/passport/infra/daemons/tvmtool/internal/errs"
)

// UserContext holds the blackbox public keys and the default blackbox
// environment used to check user tickets.
type UserContext struct {
	// keys maps a key id to the blackbox public key registered under it.
	keys         publicBbKeys
	// keyIDExample is the id of one of the loaded keys (the last one handled
	// by NewUserContext). It is passed to checkKeyEnvironmentMismatch when a
	// ticket refers to a key id that is missing from keys.
	keyIDExample keyID
	// envtype is the environment CheckTicket validates tickets against.
	envtype      tvm.BlackboxEnv
}

// CheckedUserTicket is the result of a successful user ticket check.
type CheckedUserTicket struct {
	// DefaultUID is the default uid stored in the ticket.
	DefaultUID    tvm.UID
	// Uids lists the uid of every user stored in the ticket.
	Uids          []tvm.UID
	// Scopes lists the scopes stored in the ticket.
	Scopes        []string
	// DebugString is the description of the ticket produced by
	// makeDebugStringForUserTicket.
	DebugString   string
	// LoggingString is the Msg part of the parsed ticket.
	// NOTE(review): presumably the ticket without its signature - confirm
	// against parseFromStr.
	LoggingString string
	// Env is the blackbox environment the ticket was checked against.
	Env           tvm.BlackboxEnv

	// scopesIndex is a set of Scopes built lazily by HasScope; only the keys
	// carry information.
	scopesIndex map[string]interface{}
}

// bbPublicKey couples a blackbox public key with the blackbox environment
// it was published for.
type bbPublicKey struct {
	// rw is the public key used to verify ticket signatures.
	rw  *crypto.RWPublicKey
	// env is the blackbox environment the key belongs to.
	env tvm.BlackboxEnv
}

// publicBbKeys indexes blackbox public keys by their key id.
type publicBbKeys map[keyID]bbPublicKey

// NewUserContext builds a UserContext from serialized public keys.
//
// keys is handed to ParseKeys; every blackbox key found in the result is
// parsed and stored under its key id. envtype becomes the default environment
// used by CheckTicket.
//
// An error is returned when the keys cannot be parsed, when a key refers to a
// blackbox environment outside of the supported range, when there are no
// blackbox keys at all, or when any supported environment has no key.
func NewUserContext(keys string, envtype tvm.BlackboxEnv) (*UserContext, error) {
	// envCount is the number of blackbox environments (enum values
	// 0..envCount-1) that must each be covered by at least one key.
	const envCount = 5

	parsedKeys, err := ParseKeys(keys)
	if err != nil {
		return nil, err
	}

	bbKeys := parsedKeys.GetBb()

	var (
		id         keyID
		rwEnvCount [envCount]int
	)
	rwKeys := make(publicBbKeys, len(bbKeys))
	for _, k := range bbKeys {
		rw, err := parseKey(k.GetGen())
		if err != nil {
			return nil, err
		}

		// The env value comes from external data: reject anything that does
		// not fit the counter table instead of panicking on the index.
		envIdx := int(k.GetEnv())
		if envIdx < 0 || envIdx >= len(rwEnvCount) {
			return nil, fmt.Errorf("public key %d has unsupported blackbox env: %d", k.GetGen().GetId(), envIdx)
		}
		rwEnvCount[envIdx]++

		// id keeps the last seen key id; it is stored as keyIDExample below.
		id = keyID(k.GetGen().GetId())
		rwKeys[id] = bbPublicKey{
			rw:  rw,
			env: tvm.BlackboxEnv(k.GetEnv()),
		}
	}

	if len(rwKeys) == 0 {
		return nil, errorNoPublicKeysUser
	}

	for idx, count := range rwEnvCount {
		if count == 0 {
			return nil, fmt.Errorf("there is no one public key for env: %s", protos.BbEnvType(idx))
		}
	}

	return &UserContext{keys: rwKeys, keyIDExample: id, envtype: envtype}, nil
}

// GetDefaultEnv returns the blackbox environment the context was created
// with. CheckTicket validates tickets against this environment.
func (u *UserContext) GetDefaultEnv() tvm.BlackboxEnv {
	return u.envtype
}

// checkTicketSignatureWithEnv verifies that ticket is signed with a known
// blackbox key that belongs to env.
//
// It returns nil when the signature is valid. Otherwise it returns an error
// describing the first failed step: the key id is unknown, the key belongs to
// a different environment than env, the signature is not valid raw URL-safe
// base64, or the signature verification fails.
func (u *UserContext) checkTicketSignatureWithEnv(ticket *parsedTicket, env tvm.BlackboxEnv) error {
	keyid := keyID(ticket.Ticket.GetKeyId())
	key, ok := u.keys[keyid]
	if !ok {
		// Give checkKeyEnvironmentMismatch a chance to produce a more specific
		// error by comparing the unknown id with an id of a loaded key.
		if err := checkKeyEnvironmentMismatch(keyid, u.keyIDExample); err != nil {
			return err
		}
		return fmt.Errorf("key id for user ticket not found: %d. Maybe keys are too old", ticket.Ticket.GetKeyId())
	}

	if key.env != env {
		// Report the environment that was actually requested: env may differ
		// from the context default when the caller overrides it.
		return fmt.Errorf("user ticket is accepted from wrong blackbox enviroment. Env expected=%s, got=%s", env, key.env)
	}

	signature, err := base64.RawURLEncoding.DecodeString(ticket.Signature)
	if err != nil {
		return errorInvalidSignatureBase64
	}

	result, err := key.rw.VerifyMsg(ticket.Msg, signature)
	if err != nil {
		logThrottled(err, ticket.Msg+ticket.Signature)
		return err
	}

	if !result {
		logThrottled(errorIncorrectSignature, ticket.Msg+ticket.Signature)
		return errorIncorrectSignature
	}

	return nil
}

// CheckTicket validates the serialized user ticket str against the default
// blackbox environment of the context. It delegates to
// CheckTicketOverriddenEnv with u.envtype.
func (u *UserContext) CheckTicket(str string) (*CheckedUserTicket, error) {
	return u.CheckTicketOverriddenEnv(str, u.envtype)
}

// CheckTicketOverriddenEnv validates the serialized user ticket str against
// overriddenEnv instead of the default environment of the context. The
// expiration check uses the current time (time.Now()).
func (u *UserContext) CheckTicketOverriddenEnv(str string, overriddenEnv tvm.BlackboxEnv) (*CheckedUserTicket, error) {
	return u.checkTicketOverriddenEnvImpl(str, overriddenEnv, time.Now())
}

// checkTicketOverriddenEnvImpl parses str and validates it as a user ticket
// for overriddenEnv, treating now as the current moment.
//
// The checks run in this order: parsing, ticket type, expiration, signature.
// The first failed check is reported as *errs.Forbidden. On success the
// returned CheckedUserTicket carries the uids and scopes of the ticket and
// has Env set to overriddenEnv.
func (u *UserContext) checkTicketOverriddenEnvImpl(str string, overriddenEnv tvm.BlackboxEnv, now time.Time) (*CheckedUserTicket, error) {
	parsed, parseErr := parseFromStr(str)
	if parseErr != nil {
		// Nothing was parsed, so the raw input is the only thing left to log.
		return nil, &errs.Forbidden{Message: parseErr.Error(), LoggingString: str}
	}

	if parsed.Type != TicketTypeUserTicket {
		return nil, &errs.Forbidden{
			Message:       "wrong ticket type, user-ticket is expected",
			LoggingString: parsed.Msg,
		}
	}

	tkt := parsed.Ticket

	// A ticket is rejected only once its expiration time is strictly in the past.
	if expiresAt, nowUnix := tkt.GetExpirationTime(), now.Unix(); expiresAt < nowUnix {
		return nil, &errs.Forbidden{
			Message:       fmt.Sprintf("expired ticket, exp_time %d, now %d", expiresAt, nowUnix),
			DebugString:   makeDebugStringForUserTicket(tkt),
			LoggingString: parsed.Msg,
		}
	}

	if signErr := u.checkTicketSignatureWithEnv(parsed, overriddenEnv); signErr != nil {
		return nil, &errs.Forbidden{
			Message:       signErr.Error(),
			DebugString:   makeDebugStringForUserTicket(tkt),
			LoggingString: parsed.Msg,
		}
	}

	user := tkt.User
	users := user.GetUsers()
	uids := make([]tvm.UID, 0, len(users))
	for i := range users {
		uids = append(uids, tvm.UID(users[i].GetUid()))
	}

	checked := &CheckedUserTicket{
		DefaultUID:    tvm.UID(user.GetDefaultUid()),
		Uids:          uids,
		Scopes:        user.GetScopes(),
		DebugString:   makeDebugStringForUserTicket(tkt),
		LoggingString: parsed.Msg,
		Env:           overriddenEnv,
	}
	return checked, nil
}

// makeDebugStringForUserTicket renders the fields of a user ticket as a
// single "key=value;" line: ticket type, expiration time, comma-separated
// scopes, default uid, comma-separated uids and the ticket environment.
// A nil ticket yields an empty string.
func makeDebugStringForUserTicket(ticket *protos.Ticket) string {
	if ticket == nil {
		return ""
	}

	user := ticket.User
	users := user.GetUsers()

	uidStrings := make([]string, 0, len(users))
	for i := range users {
		uidStrings = append(uidStrings, fmt.Sprintf("%d", users[i].GetUid()))
	}

	scopeList := strings.Join(user.GetScopes(), ",")
	uidList := strings.Join(uidStrings, ",")
	envName := user.GetEnv().String()

	return fmt.Sprintf(
		"ticket_type=user;expiration_time=%d;scope=%s;default_uid=%d;uid=%s;env=%s;",
		ticket.GetExpirationTime(),
		scopeList,
		user.GetDefaultUid(),
		uidList,
		envName,
	)
}

// HasScope reports whether scope is one of the ticket's Scopes.
//
// The first call builds a lookup set from Scopes and caches it in the
// ticket; later calls reuse the cached set.
func (t *CheckedUserTicket) HasScope(scope string) bool {
	index := t.scopesIndex
	if index == nil {
		index = make(map[string]interface{}, len(t.Scopes))
		for i := range t.Scopes {
			index[t.Scopes[i]] = nil
		}
		t.scopesIndex = index
	}

	_, ok := index[scope]
	return ok
}
