package lbmask

import (
	"context"
	"errors"
	"os"
	"regexp"
	"strings"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/log/corelogadapter"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/config"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/events"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/unistat"
)

// batchSize is the capacity, in bytes, of the in-memory batch buffer that
// events are accumulated into before being written as one message (2 MiB).
const batchSize uint32 = 2 * 1024 * 1024

// Patterns used by MaskArgs to recognise secrets in command-line arguments.
var (
	// preSecretRe flags an argument that looks like a secret-bearing option
	// name (case-insensitive keyword match anywhere in the argument).
	preSecretRe = regexp.MustCompile(`(?i)pass|token|secret|oauth|access|key`)
	// maskTvmRe keeps the "3:user:..:" / "3:serv:..:" prefix in group 1 and
	// matches the 300 characters that follow it.
	maskTvmRe = regexp.MustCompile(`(?:(3:(?:user|serv):[\w-]+:)[\w-]{300})`)
	// sessRe keeps a "|N.N." or ":N.N." prefix in group 1 and matches the
	// 27 characters after it.
	sessRe = regexp.MustCompile(`([|:]\d+\.\d+\.)[\w-]{27}`)
	// aqadRe keeps "AQAD-" plus 7 characters in group 1 and matches the
	// 27 characters after it.
	aqadRe = regexp.MustCompile(`(AQAD-[\w-]{7})[\w-]{27}`)
)

// Replacement values written by MaskArgs. The "${1}" templates re-emit the
// prefix captured by the corresponding pattern above.
var (
	currentSecret = "MASK_CS"
	nextSecret    = "MASK_NS"
	tvmSecret     = "${1}MASK_TVM"
	sessionSecret = "${1}MASK_SES"
	aqadSecret    = "${1}MASK_AQ"
)

// LbmaskSaver accepts events through LogEvent, batches them in memory and
// writes each batch to an lb topic via a persqueue writer. A single
// background goroutine (loop) owns the batch buffer.
type LbmaskSaver struct {
	// ctx is cancelled by Close to tell loop to do its final flush and exit.
	ctx       context.Context
	cancelCtx context.CancelFunc

	// cfg is the configuration the saver was built from.
	cfg config.Lbmask
	// writer sends batches to the topic.
	writer persqueue.Writer
	// unistat receives error and batch counters.
	unistat *unistat.Sensor
	// log is the saver's logger.
	log log.Logger

	// queue hands events from LogEvent to loop.
	queue chan events.Event
	// closed is closed by loop when it returns.
	closed chan struct{}
	// batchBuf accumulates events between flushes; only loop touches it.
	batchBuf batchBuffer
}

// NewLbmaskSaver builds a saver for cfg, initialises its persqueue writer
// with the given credentials and starts the background batching goroutine.
//
// Supported options: LoggerOption (replaces the default no-op logger) and
// SensorOption (sets the unistat sensor). Other option types are ignored.
// On error the returned saver is nil and no goroutine has been started.
func NewLbmaskSaver(cfg config.Lbmask, credentials ydb.Credentials, opts ...Option) (*LbmaskSaver, error) {
	ctx, cancelCtx := context.WithCancel(context.Background())

	saver := &LbmaskSaver{
		ctx:       ctx,
		cancelCtx: cancelCtx,
		cfg:       cfg,
		log:       &nop.Logger{},
		queue:     make(chan events.Event, 1000*3),
		closed:    make(chan struct{}),
		batchBuf:  newBatchBuffer(batchSize),
	}

	for _, opt := range opts {
		switch o := opt.(type) {
		case LoggerOption:
			saver.log = log.With(o.Logger)
		case SensorOption:
			saver.unistat = o.Unistat
		}
	}

	writer, err := newLbWriter(cfg, saver.log, credentials)
	if err != nil {
		// The saver is being discarded, so release the context created
		// above. Previously it was never cancelled on this path.
		cancelCtx()
		return nil, err
	}
	saver.writer = writer

	go saver.loop()
	return saver, nil
}

// newLbWriter creates a persqueue writer for cfg.Topic at cfg.Endpoint
// (Zstd codec, retry on failure) and runs its Init handshake.
// The source ID is taken from the DEPLOY_POD_ID environment variable.
// NOTE(review): Init uses context.Background(), so there is no deadline on
// the handshake; if DEPLOY_POD_ID is unset the source ID is empty. Confirm
// both are intended.
func newLbWriter(cfg config.Lbmask, logger log.Logger, credentials ydb.Credentials) (persqueue.Writer, error) {
	opts := persqueue.WriterOptions{
		Endpoint:       cfg.Endpoint,
		Credentials:    credentials,
		Logger:         corelogadapter.New(logger),
		Topic:          cfg.Topic,
		SourceID:       []byte(os.Getenv("DEPLOY_POD_ID")),
		Codec:          persqueue.Zstd,
		RetryOnFailure: true,
	}

	w := persqueue.NewWriter(opts)
	if _, err := w.Init(context.Background()); err != nil {
		return nil, err
	}
	return w, nil
}

// LogEvent hands e to the background goroutine for batching. It blocks while
// the queue is full. Once the saver has been closed it returns the context's
// error (context.Canceled) instead of accepting the event.
func (s *LbmaskSaver) LogEvent(e events.Event) error {
	// Check cancellation first. When both select cases below are ready, Go
	// picks one at random, so a closed saver could otherwise accept an event
	// that nothing will ever read and report success.
	if err := s.ctx.Err(); err != nil {
		return err
	}

	select {
	case s.queue <- e:
		return nil
	case <-s.ctx.Done():
		return s.ctx.Err()
	}
}

// loop is the only consumer of s.queue and the only user of s.batchBuf.
// It adds queued events to the batch and flushes the batch every 30 seconds.
// When s.ctx is cancelled it first processes the events already queued, then
// performs a final flush. s.closed is closed when loop returns.
// Write/flush failures are logged and counted; they do not stop the loop.
func (s *LbmaskSaver) loop() {
	defer close(s.closed)

	// Report asynchronous write results; this goroutine runs until the
	// writer's response channel is closed.
	go func(p persqueue.Writer) {
		for rsp := range p.C() {
			switch m := rsp.(type) {
			case *persqueue.Ack:
			case *persqueue.Issue:
				s.log.Error("lb write issue", log.Error(m.Err))
				s.unistat.MaskError(1)
			}
		}
	}(s.writer)

	save := func(event *events.Event) {
		if err := s.writeEvent(event); err != nil {
			s.log.Error("failed to write event", log.Error(err))
			s.unistat.MaskError(1)
		}
	}

	flush := func() {
		if err := s.flush(); err != nil {
			s.log.Error("failed to flush events", log.Error(err))
			s.unistat.MaskError(1)
		}
	}

	t := time.NewTicker(30 * time.Second)
	defer t.Stop()

	for {
		select {
		case <-t.C:
			flush()
		case event := <-s.queue:
			save(&event)
		case <-s.ctx.Done():
			// Events accepted by LogEvent before cancellation may still be
			// queued. Previously they were dropped. This goroutine is the only
			// consumer, so exactly len(s.queue) receives are guaranteed not to
			// block, and bounding the drain by that count keeps shutdown finite
			// even if producers keep calling LogEvent.
			for pending := len(s.queue); pending > 0; pending-- {
				event := <-s.queue
				save(&event)
			}
			flush()
			return
		}
	}
}

// writeEvent appends event to the current batch. If the batch has no room,
// the batch is flushed first and the push is retried once. An event that does
// not fit even in the emptied buffer is rejected with an error, as is any
// flush error.
func (s *LbmaskSaver) writeEvent(event *events.Event) error {
	size := uint32(event.Size())

	if !s.batchBuf.canPush(size) {
		if err := s.flush(); err != nil {
			return err
		}
		if !s.batchBuf.canPush(size) {
			return errors.New("event can't fit buffer")
		}
	}

	return s.batchBuf.push(size, event)
}

// flush writes the accumulated batch as a single message and, on success,
// bumps the batch counter and clears the buffer. An empty batch is a no-op.
// If Write fails the buffer is left intact, so the same data is retried on
// the next flush.
func (s *LbmaskSaver) flush() error {
	payload := s.batchBuf.getBuf()
	if len(payload) == 0 {
		return nil
	}

	// NOTE(review): the persqueue writer reports results asynchronously via
	// C(). If getBuf returns the buffer's backing slice and clear() reuses it,
	// the writer could still be reading payload while new events overwrite
	// it. Confirm that getBuf copies or that clear() allocates fresh storage.
	if err := s.writer.Write(&persqueue.WriteMessage{Data: payload}); err != nil {
		return err
	}

	s.unistat.MaskBatch(1)
	s.batchBuf.clear()
	return nil
}

func MaskArgs(args []string) {
	isNextSecret := false
	for i := range args {
		if len(args[i]) < 3 {
			isNextSecret = false
			continue
		}
		if args[i][0] == '/' {
			// skip path
			isNextSecret = false
			continue
		}
		if !isNextSecret {
			if preSecretRe.MatchString(args[i]) {
				if strings.IndexByte(args[i], '=') > -1 {
					args[i] = currentSecret
					continue
				}
				isNextSecret = true
			}
		} else {
			isNextSecret = false
			args[i] = nextSecret
			continue
		}
		if len(args[i]) < 30 {
			continue
		}
		args[i] = maskTvmRe.ReplaceAllString(args[i], tvmSecret)
		args[i] = sessRe.ReplaceAllString(args[i], sessionSecret)
		args[i] = aqadRe.ReplaceAllString(args[i], aqadSecret)
	}
}

// Close stops the saver. It cancels the internal context, waits for the
// background goroutine to do its final flush and exit, and then closes the
// writer. If ctx is done first, ctx.Err() is returned and the writer is not
// closed.
func (s *LbmaskSaver) Close(ctx context.Context) error {
	s.cancelCtx()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-s.closed:
		return s.writer.Close()
	}
}
