package app

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/klauspost/compress/zstd"
	"golang.org/x/sync/errgroup"

	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/log/corelogadapter"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/yandex/tvm/tvmtool"
	"a.yandex-team.ru/security/gideon/internal/protoseq"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/cleaner"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/config"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/db"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/events"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/lbmask"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/unistat"
	"a.yandex-team.ru/security/libs/go/chim"
	"a.yandex-team.ru/security/libs/go/ydbtvm"
)

const (
	// lbReadSize is the MaxReadSize given to every Logbroker reader (2 MiB).
	lbReadSize = 2 << 20
	// lbMessagesCount is the MaxReadMessagesCount given to every Logbroker reader.
	lbMessagesCount = 100
	// lbMaxReadPartitions is the MaxReadPartitionsCount given to every Logbroker reader.
	lbMaxReadPartitions = 2
	// chanSize is the buffer length of App.eventCh.
	chanSize = 1000
	// maxMsgSize is the initial capacity (20 MiB) of each consumer's zstd output buffer.
	maxMsgSize = 20 << 20
)

// App consumes events from Logbroker readers, masks exec arguments, passes the
// events to lbmask and to per-kind ClickHouse EventSavers, and serves the
// /ping, /unistat and /debug HTTP endpoints. Build it with NewApp.
type App struct {
	// Configuration and logging.
	cfg config.Config
	log log.Logger

	// Lifetime of background components; Shutdown calls cancelCtx.
	ctx       context.Context
	cancelCtx context.CancelFunc

	// Ingestion: readers feed decoded events into eventCh; zstd is the
	// decoder shared by all readers.
	lbReaders []persqueue.Reader
	zstd      *zstd.Decoder
	eventCh   chan events.Event

	// Sinks, housekeeping and metrics.
	dbs     map[events.EventKind]*db.EventSaver
	lbmask  *lbmask.LbmaskSaver
	cleaner *cleaner.Cleaner
	unistat *unistat.Sensor
}

// NewApp builds an App from cfg.
//
// It creates, but does not start, one Logbroker reader per configured
// endpoint/topic pair. It also creates one ClickHouse EventSaver per kind in
// events.AllEventKinds, the lbmask saver, the DB cleaner, and a zstd decoder
// shared by all readers. Readers are started later by Start.
//
// If a step fails, the EventSavers and the LbmaskSaver created so far are
// closed before the error is returned.
func NewApp(cfg config.Config, l log.Logger) (*App, error) {
	tvmClient, err := tvmtool.NewDeployClient()
	if err != nil {
		return nil, fmt.Errorf("can't create tvm client: %w", err)
	}

	lbCredentials := &ydbtvm.TvmCredentials{
		DstID:     ydbtvm.LbClientID,
		TvmClient: tvmClient,
	}
	lbReaders := newLBReaders(cfg, l, lbCredentials)

	sensor := &unistat.Sensor{}
	dbs, err := newEventSavers(cfg, l, sensor)
	if err != nil {
		return nil, err
	}

	// Named masker (not lbmask) so the lbmask package is not shadowed.
	masker, err := lbmask.NewLbmaskSaver(
		cfg.Lbmask,
		lbCredentials,
		lbmask.WithLogger(l),
		lbmask.WithUnistat(sensor),
	)
	if err != nil {
		closeEventSavers(context.Background(), l, dbs)
		return nil, fmt.Errorf("failed to create lbmask: %w", err)
	}

	// releaseSavers closes everything created above; used on the remaining
	// error paths so a failed NewApp does not leak open savers.
	releaseSavers := func() {
		closeEventSavers(context.Background(), l, dbs)
		if err := masker.Close(context.Background()); err != nil {
			l.Error("can't close LbmaskSaver", log.Error(err))
		}
	}

	cl, err := cleaner.NewCleaner(cfg, cleaner.WithLog(l))
	if err != nil {
		releaseSavers()
		return nil, fmt.Errorf("failed to create DB cleaner: %w", err)
	}

	// Older klauspost/compress releases reject a concurrency of 0, which is
	// what len(lbReaders) gives when no endpoints/topics are configured.
	decoderConcurrency := len(lbReaders)
	if decoderConcurrency < 1 {
		decoderConcurrency = 1
	}
	zstdDec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(decoderConcurrency))
	if err != nil {
		releaseSavers()
		return nil, fmt.Errorf("failed to create zstd decoder: %w", err)
	}

	// Created last so that no error path above has to cancel it.
	ctx, cancel := context.WithCancel(context.Background())
	return &App{
		cfg:       cfg,
		log:       l,
		ctx:       ctx,
		cancelCtx: cancel,
		dbs:       dbs,
		lbmask:    masker,
		cleaner:   cl,
		eventCh:   make(chan events.Event, chanSize),
		lbReaders: lbReaders,
		zstd:      zstdDec,
		unistat:   sensor,
	}, nil
}

// newLBReaders builds one persqueue reader per (endpoint, topic) pair from
// cfg.Logbroker, each with its own tagged logger. The readers are not started.
func newLBReaders(cfg config.Config, l log.Logger, creds *ydbtvm.TvmCredentials) []persqueue.Reader {
	readers := make([]persqueue.Reader, 0, len(cfg.Logbroker.Endpoints)*len(cfg.Logbroker.Topics))
	for _, endpoint := range cfg.Logbroker.Endpoints {
		for _, topic := range cfg.Logbroker.Topics {
			lbLogger := corelogadapter.New(log.With(l,
				log.String("source", "lb"),
				log.String("endpoint", endpoint),
				log.String("topic", topic),
			))

			readers = append(readers, persqueue.NewReader(persqueue.ReaderOptions{
				Endpoint:               endpoint,
				Consumer:               cfg.Logbroker.Consumer,
				Logger:                 lbLogger,
				MaxReadSize:            lbReadSize,
				MaxReadMessagesCount:   lbMessagesCount,
				MaxReadPartitionsCount: lbMaxReadPartitions,
				//ForceRebalance:         true,
				DecompressionDisabled: true, // payloads are decoded by App.decompressMsg
				RetryOnFailure:        true,
				ReadOnlyLocal:         true,
				//CommitsDisabled:        true,
				Topics: []persqueue.TopicInfo{{
					Topic: topic,
				}},
				Credentials: creds,
			}))
		}
	}
	return readers
}

// newEventSavers creates one EventSaver per kind in events.AllEventKinds.
// CreateTables is invoked only through the first saver (as before);
// NOTE(review): presumably it creates the tables for every kind — confirm.
// If creation fails, every saver created so far is closed.
func newEventSavers(cfg config.Config, l log.Logger, sensor *unistat.Sensor) (map[events.EventKind]*db.EventSaver, error) {
	dbs := make(map[events.EventKind]*db.EventSaver, len(events.AllEventKinds))
	for i, kind := range events.AllEventKinds {
		d, err := db.NewEventSaver(cfg.ClickHouse, kind, db.WithLogger(l), db.WithUnistat(sensor))
		if err != nil {
			closeEventSavers(context.Background(), l, dbs)
			return nil, fmt.Errorf("failed to create EventSaver: %w", err)
		}
		dbs[kind] = d

		if i == 0 {
			if err := d.CreateTables(); err != nil {
				closeEventSavers(context.Background(), l, dbs)
				return nil, fmt.Errorf("failed to initialize tables: %w", err)
			}
		}
	}
	return dbs, nil
}

// closeEventSavers closes every saver in dbs, logging (not returning) errors.
func closeEventSavers(ctx context.Context, l log.Logger, dbs map[events.EventKind]*db.EventSaver) {
	for kind, d := range dbs {
		if err := d.Close(ctx); err != nil {
			l.Error("can't close EventSaver", log.String("kind", kind.String()), log.Error(err))
		}
	}
}

// Start runs all App components and blocks until every one of them returns:
//   - the DB cleaner, driven by a.ctx;
//   - one consumer goroutine per Logbroker reader (see startLBConsumer);
//   - the loop draining a.eventCh into lbmask and ClickHouse (see processEvents);
//   - the HTTP server on :80, stopped when a.ctx is cancelled (see serveHTTP).
//
// Per errgroup.Group semantics, the first non-nil error is returned only after
// all components have finished. For example, a failure to bind :80 surfaces
// once Shutdown has stopped the rest.
func (a *App) Start() error {
	var g errgroup.Group

	g.Go(func() error {
		return a.cleaner.Start(a.ctx)
	})

	for _, r := range a.lbReaders {
		reader := r // per-goroutine copy; required before Go 1.22 loop semantics
		g.Go(func() error {
			err := a.startLBConsumer(reader)
			if err != nil {
				a.log.Error("failed to start consumer", log.Error(err))
			}
			return err
		})
	}

	g.Go(a.processEvents)
	g.Go(a.serveHTTP)

	a.log.Info("beaver started")
	return g.Wait()
}

// processEvents drains a.eventCh until the channel is closed. For ProcExec
// and ExecveAt events it masks every argument except the first with
// lbmask.MaskArgs. It then hands each event to the lbmask saver and to the
// EventSaver registered for its kind. Failures are logged per event and never
// stop the loop. The returned error is always nil.
func (a *App) processEvents() error {
	for e := range a.eventCh {
		switch e.Kind {
		case events.EventKindProcExec:
			// Events are decoded from Logbroker payloads; a kind without its
			// details must not crash the service with a nil dereference.
			// (Assumes the getter returns nil when details are absent.)
			if details := e.GetProcExec(); details != nil && len(details.Args) > 1 {
				lbmask.MaskArgs(details.Args[1:])
			}
		case events.EventKindExecveAt:
			if details := e.GetExecveAt(); details != nil && len(details.Args) > 1 {
				lbmask.MaskArgs(details.Args[1:])
			}
		}

		if err := a.lbmask.LogEvent(e); err != nil {
			a.log.Error("failed to process with lbmask", log.String("kind", e.Kind.String()), log.Error(err))
		}

		d, ok := a.dbs[e.Kind]
		if !ok {
			a.log.Error("unexpected event - skip it", log.String("kind", e.Kind.String()))
			continue
		}

		if err := d.LogEvent(e); err != nil {
			a.log.Error("failed to save events", log.String("kind", e.Kind.String()), log.Error(err))
		}
	}
	return nil
}

// serveHTTP serves /ping, /unistat and /debug on :80 until a.ctx is
// cancelled. The http.ErrServerClosed produced by that shutdown is not
// reported as an error. Any other ListenAndServe error is returned.
func (a *App) serveHTTP() error {
	r := chi.NewRouter()

	r.Get("/ping", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("OK"))
	})

	r.Get("/unistat", func(w http.ResponseWriter, _ *http.Request) {
		// Best effort: an encode error only means the client went away.
		_ = json.NewEncoder(w).Encode(a.unistat.FlushPull())
	})

	r.Mount("/debug", chim.Profiler(""))

	srv := &http.Server{Addr: ":80", Handler: r}

	// Stop the server when the App is cancelled; done releases the watcher
	// if ListenAndServe fails on its own first.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-a.ctx.Done():
			_ = srv.Shutdown(context.Background())
		case <-done:
		}
	}()

	err := srv.ListenAndServe()
	if err == http.ErrServerClosed {
		return nil
	}
	return err
}

// Shutdown stops the App. It asks every Logbroker reader to shut down, closes
// the event channel, and cancels the App context. It then closes the
// per-kind EventSavers, the LbmaskSaver and the zstd decoder, passing ctx to
// the savers' Close. Close errors are only logged; Shutdown always returns nil.
// It must be called at most once: a second call panics on close(a.eventCh).
//
// NOTE(review): the ordering below looks racy — TODO confirm:
//   - r.Shutdown() is presumably non-blocking. If so, a consumer goroutine
//     may still be in startLBConsumer sending to a.eventCh after it is closed
//     here, and a send on a closed channel panics. It may also still be using
//     a.zstd when it is closed.
//   - The goroutine in Start that drains a.eventCh may still be calling
//     a.lbmask.LogEvent / d.LogEvent on buffered events while the savers are
//     being closed below.
//
// A fix needs Start to signal when consumers and the drain loop have
// finished, so that Shutdown can wait for that before closing anything.
func (a *App) Shutdown(ctx context.Context) error {
	for _, r := range a.lbReaders {
		r.Shutdown()
	}

	// Ends the drain loop in Start once buffered events are consumed.
	close(a.eventCh)
	// Stops ctx-driven components (cleaner, readers started with a.ctx).
	a.cancelCtx()

	for kind, d := range a.dbs {
		if err := d.Close(ctx); err != nil {
			a.log.Error("can't close EventSaver", log.String("kind", kind.String()), log.Error(err))
		}
	}

	err := a.lbmask.Close(ctx)
	if err != nil {
		a.log.Error("can't close LbmaskSaver", log.Error(err))
	}

	a.zstd.Close()
	return nil
}

// startLBConsumer starts reader r with the App context and consumes its
// events until r.C() is closed. Event types are handled as follows:
//   - Disconnect: counted in unistat and logged.
//   - Data: every message of every batch is handled by handleLBMessage. The
//     whole Data event is then committed, even if some messages failed to
//     decode.
//   - Anything else is ignored.
//
// It returns the start error, or r.Err() once the event channel is closed.
func (a *App) startLBConsumer(r persqueue.Reader) error {
	if _, err := r.Start(a.ctx); err != nil {
		return fmt.Errorf("failed to start lb reader: %w", err)
	}

	var pseq protoseq.Decoder
	// Scratch space reused for zstd output. DecodeAll appends to it and grows
	// past this capacity on its own if a message needs more.
	decBuf := make([]byte, 0, maxMsgSize)
	for msg := range r.C() {
		switch v := msg.(type) {
		case *persqueue.Disconnect:
			a.unistat.LbRestarts(1)
			a.log.Error("lb connection lost", log.Error(v.Err))
		case *persqueue.Data:
			for _, b := range v.Batches() {
				a.unistat.Batch(1)
				source := topicToSource(b.Topic)
				for _, m := range b.Messages {
					a.handleLBMessage(&pseq, m, source, decBuf[:0])
				}
			}
			v.Commit()
		}
	}

	return r.Err()
}

// handleLBMessage decompresses m into buf and decodes each protoseq-framed
// events.Event it contains. Each event gets Host (the message SourceID) and
// Source set, then is sent to a.eventCh; the send blocks while the channel is
// full. Failures are logged with the message SourceID/SeqNo and never abort
// the caller.
func (a *App) handleLBMessage(pseq *protoseq.Decoder, m persqueue.ReadMessage, source string, buf []byte) {
	data, err := a.decompressMsg(m, buf)
	if err != nil {
		a.log.Error("failed to decompress event",
			log.ByteString("source_id", m.SourceID),
			log.UInt64("seq_no", m.SeqNo),
			log.Error(err))
		return
	}

	host := string(m.SourceID)
	pseq.Reset(bytes.NewReader(data))
	for pseq.More() {
		a.unistat.Event(1)
		var e events.Event
		if err := pseq.Decode(&e); err != nil {
			a.log.Error("failed to parse event",
				log.ByteString("source_id", m.SourceID),
				log.UInt64("seq_no", m.SeqNo),
				log.Error(err))
			continue
		}

		e.Host = host
		e.Source = source

		a.eventCh <- e
	}

	if err := pseq.Err(); err != nil {
		a.log.Error("failed to iterate protoseq",
			log.ByteString("source_id", m.SourceID),
			log.UInt64("seq_no", m.SeqNo),
			log.Error(err))
	}
}

// decompressMsg returns the payload of msg decoded according to its codec.
// Raw payloads are returned as-is, so the result aliases msg.Data. Zstd
// payloads are decoded with the shared decoder and appended to buf. Any other
// codec yields an error.
func (a *App) decompressMsg(msg persqueue.ReadMessage, buf []byte) ([]byte, error) {
	codec := msg.Codec
	if codec == persqueue.Raw {
		return msg.Data, nil
	}
	if codec == persqueue.Zstd {
		return a.zstd.DecodeAll(msg.Data, buf)
	}
	return nil, fmt.Errorf("unsupported lb codec: %d", codec)
}

// topicToSource derives an event source name from a Logbroker topic. The name
// is the text after the last "--" separator. A topic without a separator, or
// whose last separator ends the string, is returned unchanged.
func topicToSource(topic string) string {
	const sep = "--"
	idx := strings.LastIndex(topic, sep)
	if idx < 0 {
		return topic
	}
	if suffix := topic[idx+len(sep):]; suffix != "" {
		return suffix
	}
	return topic
}
