package worker

import (
	"context"
	"fmt"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go/service/sqs"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/library/go/yandex/tvm/tvmauth"
	"a.yandex-team.ru/security/libs/go/porto"
	"a.yandex-team.ru/security/libs/go/xtvm"
	"a.yandex-team.ru/security/xray/internal/db"
	"a.yandex-team.ru/security/xray/internal/queue"
	"a.yandex-team.ru/security/xray/internal/servers/worker/config"
	"a.yandex-team.ru/security/xray/internal/servers/worker/inspect"
	"a.yandex-team.ru/security/xray/internal/servers/worker/splunk"
	"a.yandex-team.ru/security/xray/internal/storage/layerstorage"
	"a.yandex-team.ru/security/xray/internal/storage/resstorage"
	"a.yandex-team.ru/security/xray/internal/storage/s3storage"
	"a.yandex-team.ru/security/xray/pkg/checks"
	"a.yandex-team.ru/security/xray/pkg/checks/check"
	"a.yandex-team.ru/security/xray/pkg/collectors"
	"a.yandex-team.ru/security/xray/pkg/collectors/collector"
	"a.yandex-team.ru/yp/go/yp"
)

const (
	// backoffTimeout is the pause the receive loop in Start takes when the
	// local message buffer is full or receiving from the queue failed.
	backoffTimeout = 10 * time.Second
	// tvmWaitTimeout is the timeout passed to xtvm.Wait for the TVM client in onStart.
	tvmWaitTimeout = 30 * time.Second
)

type (
	// Worker receives messages from the requests queue and hands them to a
	// pool of message processors. Its dependencies are created in onStart.
	Worker struct {
		// ctx is the worker-wide context created in NewWorker; it is canceled by shutDownFunc.
		ctx               context.Context
		// exitChan is closed when Start returns; Shutdown waits on it.
		exitChan          chan struct{}
		// shutDownFunc cancels ctx.
		shutDownFunc      context.CancelFunc
		cfg               *config.Config
		checks            checks.Checks
		collectors        collectors.Collectors
		queue             *queue.Queue
		porto             *porto.API
		layerStorage      *layerstorage.Storage
		resStorage        *resstorage.Storage
		s3Storage         *s3storage.Storage
		tvm               tvm.Client
		db                *db.DB
		yp                *yp.Client
		// splunk is a HEC sender when Splunk is enabled in the config, a NopSender otherwise.
		splunk            splunk.Sender
		// visibilityTimeout is expressed in seconds: 5 minutes plus the deadlines
		// of all checks and collectors. It is applied to every received message.
		visibilityTimeout int64
		// analyzeTimeout is visibilityTimeout minus 60 seconds; ShutdownTTL returns it.
		analyzeTimeout    time.Duration
		// syncLock is write-locked by syncChecks and read-locked by
		// messageProcessor around each processed message.
		syncLock          sync.RWMutex
		log               log.Logger
	}
)

// NewWorker builds a Worker around cfg and l. Only the cancelable worker
// context and the exit channel are prepared here; every other dependency is
// created later by onStart. The returned error is always nil.
func NewWorker(cfg *config.Config, l log.Logger) (*Worker, error) {
	w := &Worker{
		cfg:      cfg,
		log:      l,
		exitChan: make(chan struct{}),
	}
	w.ctx, w.shutDownFunc = context.WithCancel(context.Background())

	return w, nil
}

// newTVMClient creates a TVM API client from the TVM section of the config.
// When a cache directory is configured it is created first (mode 0700).
// A non-zero configured port overrides the port in the client settings.
func (w *Worker) newTVMClient() (tvm.Client, error) {
	cacheDir := w.cfg.TVM.CacheDir
	if cacheDir != "" {
		err := os.MkdirAll(cacheDir, 0o700)
		if err != nil {
			return nil, fmt.Errorf("unable to create tvm cache dir: %w", err)
		}
	}

	settings := tvmauth.TvmAPISettings{
		SelfID:                      w.cfg.TVM.ClientID,
		ServiceTicketOptions:        tvmauth.NewAliasesOptions(w.cfg.TVM.ClientSecret, w.cfg.TVM.Destinations),
		DiskCacheDir:                cacheDir,
		BlackboxEnv:                 &w.cfg.TVM.Env,
		EnableServiceTicketChecking: true,
	}

	if port := w.cfg.TVM.Port; port != 0 {
		settings.TVMPort = port
	}

	return tvmauth.NewAPIClient(settings, w.log)
}

// ShutdownTTL returns the analyze timeout computed in onStart (the message
// visibility timeout minus 60 seconds). It is zero until onStart has run.
func (w *Worker) ShutdownTTL() time.Duration {
	return w.analyzeTimeout
}

// onStart creates every dependency the worker needs, in order: checks and
// collectors (synced once, their deadlines summed into the visibility
// timeout), the YP client, layer and resource storages, the TVM client, S3
// storage, the SQS queue, the database, the Porto API and the Splunk sender.
// It finishes by removing leftover Porto volumes and layers.
// The first failing step aborts initialization and its error is returned.
func (w *Worker) onStart() (err error) {
	w.checks = checks.NewChecks(check.Config{
		ContainerDir: w.cfg.Worker.Checks.ContainerDir,
		HostDir:      w.cfg.Worker.Checks.HostDir,
		AuthToken:    w.cfg.Worker.Checks.AuthToken,
		Logger:       w.log,
	})

	w.collectors = collectors.NewCollectors(collector.Config{
		ContainerDir: w.cfg.Worker.Collectors.ContainerDir,
		// NOTE(review): HostDir comes from the Checks section, not Collectors — confirm this is intended.
		HostDir: w.cfg.Worker.Checks.HostDir,
	})

	// The visibility timeout (in seconds) starts with a 5 minute penalty for
	// building/starting containers and grows by the deadline of every check
	// and collector. Each of them is also synced once with its local DB.
	w.visibilityTimeout = 5 * 60
	for _, group := range w.checks {
		for _, c := range group {
			if syncErr := c.Sync(w.ctx); syncErr != nil {
				return fmt.Errorf("failed to sync check %q: %w", c.Name(), syncErr)
			}

			w.visibilityTimeout += int64(c.Deadline().Seconds())
		}
	}

	for _, group := range w.collectors {
		for _, c := range group {
			if syncErr := c.Sync(w.ctx); syncErr != nil {
				return fmt.Errorf("failed to sync collector %q: %w", c.Name(), syncErr)
			}

			w.visibilityTimeout += int64(c.Deadline().Seconds())
		}
	}

	// the analyze timeout is one minute shorter than the visibility timeout
	w.analyzeTimeout = time.Duration(w.visibilityTimeout-60) * time.Second

	if w.yp, err = yp.NewClient("xdc", yp.WithAuthToken(w.cfg.YpToken)); err != nil {
		return fmt.Errorf("create YP client: %w", err)
	}

	layersMaxSize := int64(w.cfg.Worker.LayerStorage.Size.Bytes())
	w.layerStorage, err = layerstorage.NewStorage(
		w.cfg.Worker.LayerStorage.Dir,
		layerstorage.WithMaxSize(layersMaxSize),
		layerstorage.WithLogger(w.log.WithName("layer-storage")),
	)
	if err != nil {
		return fmt.Errorf("create layer storage: %w", err)
	}

	resourcesMaxSize := int64(w.cfg.Worker.ResourceStorage.Size.Bytes())
	w.resStorage, err = resstorage.NewStorage(
		w.cfg.Worker.ResourceStorage.Dir,
		resstorage.WithMaxSize(resourcesMaxSize),
		resstorage.WithLogger(w.log.WithName("resource-storage")),
	)
	if err != nil {
		return fmt.Errorf("create resource storage: %w", err)
	}

	// the TVM client is created first: S3, SQS and the database all take it
	if w.tvm, err = w.newTVMClient(); err != nil {
		return fmt.Errorf("create tvm client: %w", err)
	}

	if err = xtvm.Wait(w.tvm, tvmWaitTimeout); err != nil {
		return fmt.Errorf("wait tvm: %w", err)
	}

	if w.s3Storage, err = s3storage.NewS3Storage(w.tvm, w.cfg.S3StorageConfig()); err != nil {
		return fmt.Errorf("create s3 storage: %w", err)
	}

	w.queue, err = queue.New(w.cfg.SQS.Endpoint, queue.WithAuthTVM(w.cfg.SQS.Account, w.tvm))
	if err != nil {
		return fmt.Errorf("create SQS client: %w", err)
	}

	// the database gets 5 seconds to come up; the timer is released when onStart returns
	dbCtx, cancelDB := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelDB()
	if w.db, err = db.New(dbCtx, w.tvm, w.cfg.DBConfig()); err != nil {
		return fmt.Errorf("create YDB storage: %w", err)
	}

	portoOpts := &porto.APIOpts{
		MaxConnections: w.cfg.Worker.Concurrency * 2,
	}
	if w.porto, err = porto.NewAPI(portoOpts); err != nil {
		return fmt.Errorf("create Porto client: %w", err)
	}

	if !w.cfg.Splunk.Enabled {
		w.splunk = &splunk.NopSender{}
	} else {
		w.splunk = splunk.NewHecSender(
			splunk.WithIndex(w.cfg.Splunk.Index),
			splunk.WithAuthToken(w.cfg.Splunk.AuthToken),
			splunk.WithSourceType(w.cfg.Splunk.SourceType),
			splunk.WithLogger(w.log),
		)
		if err = w.splunk.Start(); err != nil {
			return fmt.Errorf("start Splunk sender: %w", err)
		}
	}

	// drop volumes and layers left over from earlier runs to prevent resource leakage
	w.cleanup()

	return nil
}

// onEnd releases what onStart created: it resets the database, closes the
// layer and resource storages and stops the Splunk sender.
//
// Every step runs even when an earlier one failed, so a database reset error
// no longer leaves the storages open and the Splunk sender running.
// Resources that were never created (nil) are skipped.
//
// It returns the database reset error if there was one (annotated with the
// Splunk stop error when both failed), otherwise the Splunk stop error,
// otherwise nil.
func (w *Worker) onEnd() (err error) {
	if w.db != nil {
		// remember the failure, but keep releasing the remaining resources
		err = w.db.Reset(context.Background())
	}

	if w.layerStorage != nil {
		w.layerStorage.Close()
	}
	if w.resStorage != nil {
		w.resStorage.Close()
	}

	if w.splunk != nil {
		if stopErr := w.splunk.Stop(); stopErr != nil {
			if err == nil {
				return stopErr
			}
			// only the first error can be wrapped; mention the second one in the message
			err = fmt.Errorf("%w (also failed to stop Splunk sender: %v)", err, stopErr)
		}
	}

	return err
}

// Start initializes the worker (onStart), spawns cfg.Worker.Concurrency
// message processors and runs the receive loop until the worker context is
// canceled by Shutdown. On the way out it closes the message buffer, waits
// for the processors to drain it, runs onEnd and finally closes exitChan.
//
// An error is returned only when onStart fails; otherwise Start returns nil
// after the loop has ended.
//
// Backoff pauses are interruptible: a shutdown request ends them immediately
// instead of being delayed by up to backoffTimeout.
func (w *Worker) Start() error {
	// deferred first, so it runs last: Shutdown is unblocked only after onEnd
	defer close(w.exitChan)

	err := w.onStart()
	if err != nil {
		return fmt.Errorf("failed to start worker: %w", err)
	}

	defer func() {
		err := w.onEnd()
		if err != nil {
			w.log.Error("failed to stop worker", log.Error(err))
		}
	}()

	w.log.Info("worker started")
	queueURL := w.cfg.RequestsQueueURL()
	opts := &queue.ReceiveOptions{
		QueueURL:            queueURL,
		MaxNumberOfMessages: 1,
	}

	messagesQueue := make(chan *sqs.Message, w.cfg.Worker.Concurrency)
	var wg sync.WaitGroup
	wg.Add(w.cfg.Worker.Concurrency)
	for i := 0; i < w.cfg.Worker.Concurrency; i++ {
		go w.messageProcessor(&wg, messagesQueue)
	}

	inShutdown := func() bool {
		return w.ctx.Err() != nil
	}

	// backoff pauses for backoffTimeout, or returns early once shutdown is
	// requested; the loop re-checks inShutdown on its next iteration.
	backoff := func() {
		timer := time.NewTimer(backoffTimeout)
		defer timer.Stop()

		select {
		case <-w.ctx.Done():
		case <-timer.C:
		}
	}

	nextSync := time.Now().Add(w.cfg.Worker.SyncPeriod).Round(w.cfg.Worker.SyncPeriod)
loop:
	for {
		if inShutdown() {
			break
		}

		// syncChecks takes the write side of syncLock, so it waits for
		// in-flight messages to finish before re-syncing the checks
		if now := time.Now(); now.After(nextSync) {
			w.syncChecks()
			nextSync = now.Add(w.cfg.Worker.SyncPeriod).Round(w.cfg.Worker.SyncPeriod)
		}

		if len(messagesQueue) == w.cfg.Worker.Concurrency {
			// deal with full messagesQueue
			backoff()
			continue
		}

		messages, err := w.queue.ReceiveMessage(w.ctx, opts)
		if err != nil {
			if inShutdown() {
				break
			}

			w.log.Error("failed to receive responses", log.Error(err))
			backoff()
			continue
		}

		for _, msg := range messages {
			// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/working-with-messages.html#processing-messages-timely-manner
			err := w.queue.ChangeMessageVisibility(w.ctx, &queue.ChangeMessageVisibilityOptions{
				QueueURL:          queueURL,
				ReceiptHandle:     msg.ReceiptHandle,
				VisibilityTimeout: w.visibilityTimeout,
			})

			if err != nil {
				if inShutdown() {
					break loop
				}

				// best effort: the message is still processed with its current visibility timeout
				w.log.Error("failed to change visibility timeout", log.Error(err))
			}

			messagesQueue <- msg
		}
	}

	// no more messages will be sent; let the processors drain the buffer and exit
	close(messagesQueue)
	wg.Wait()

	w.log.Info("worker stopped")
	return nil
}

// syncChecks re-syncs every check while holding the write side of syncLock,
// so it runs only when no message processor holds the read side.
// A failing check is logged and the remaining ones are still synced.
// Only checks are iterated here; collectors are not re-synced.
func (w *Worker) syncChecks() {
	w.log.Info("syncing checks")

	w.syncLock.Lock()
	defer w.syncLock.Unlock()

	for _, group := range w.checks {
		for _, c := range group {
			syncErr := c.Sync(w.ctx)
			if syncErr == nil {
				continue
			}

			w.log.Error("failed to sync check", log.String("check_name", c.Name()), log.Error(syncErr))
		}
	}

	w.log.Info("sync finished")
}

// messageProcessor handles messages from the channel one at a time until the
// channel is closed and drained, then signals wg. Each message is processed
// under the read side of syncLock.
func (w *Worker) messageProcessor(wg *sync.WaitGroup, messages <-chan *sqs.Message) {
	defer wg.Done()

	for {
		msg, ok := <-messages
		if !ok {
			// channel closed and fully drained
			return
		}

		w.syncLock.RLock()
		w.processMessage(msg)
		w.syncLock.RUnlock()
	}
}

// Shutdown cancels the worker context and waits until Start has returned
// (exitChan is closed) or ctx is done, whichever happens first.
//
// It returns nil when the worker stopped, otherwise an error wrapping
// ctx.Err(), so callers can tell a deadline from a cancellation with errors.Is.
func (w *Worker) Shutdown(ctx context.Context) error {
	w.shutDownFunc()

	// gracefully wait for Start to finish
	select {
	case <-w.exitChan:
		// completed normally
		return nil
	case <-ctx.Done():
		// the caller gave up waiting
		return fmt.Errorf("timed out: %w", ctx.Err())
	}
}

// cleanup removes Porto leftovers in two best-effort passes: first the
// volumes of the "self" container whose path starts with the configured work
// dir, then the layers matching the inspect layer prefix mask.
// Failures are only logged; a failed listing skips that pass.
func (w *Worker) cleanup() {
	// pass 1: volumes
	volumes, volumesErr := w.porto.ListVolumes(porto.ListVolumesOpts{
		Container: "self",
	})
	if volumesErr != nil {
		w.log.Error("failed to list porto volumes", log.Error(volumesErr))
	} else {
		for _, vol := range volumes {
			if !strings.HasPrefix(vol.Path(), w.cfg.Worker.WorkDir) {
				// not under the work dir: leave it alone
				w.log.Info("skip volume cleanup", log.String("path", vol.Path()))
				continue
			}

			destroyErr := vol.Destroy()
			if destroyErr == nil {
				w.log.Info("volume destroyed", log.String("path", vol.Path()))
				continue
			}

			w.log.Warn("failed to destroy volume", log.String("path", vol.Path()), log.Error(destroyErr))
		}
	}

	// pass 2: layers
	layers, layersErr := w.porto.ListLayers(porto.ListLayersOpts{
		Mask: fmt.Sprintf("%s-***", inspect.LayerPrefix),
	})
	if layersErr != nil {
		w.log.Error("failed to list porto layers", log.Error(layersErr))
	} else {
		for _, layer := range layers {
			removeErr := w.porto.RemoveLayer(layer.Name)
			if removeErr == nil {
				w.log.Info("layer destroyed", log.String("id", layer.Name))
				continue
			}

			w.log.Warn("failed to destroy layer", log.String("id", layer.Name), log.Error(removeErr))
		}
	}
}
