package processor

import (
	"context"
	"runtime/debug"
	"time"

	"github.com/gofrs/uuid"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/library/go/core/metrics"
	"a.yandex-team.ru/library/go/core/metrics/solomon"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/tasklet/api/v2"
	"a.yandex-team.ru/tasklet/experimental/internal/apiclient"
	"a.yandex-team.ru/tasklet/experimental/internal/locks"
	"a.yandex-team.ru/tasklet/experimental/internal/storage"
	"a.yandex-team.ru/tasklet/experimental/internal/yandex/sandbox"
	"a.yandex-team.ru/tasklet/experimental/internal/yandex/ytdriver"
)

// Keys shared by the processor's logging and execution routing:
//   - tickIDField: log field holding the random ID generated for every tick.
//   - executionIDField: log field holding the ID of the execution being handled.
//   - ytSpawnAnnotation: boolean annotation on an execution's status; when true
//     the execution is run through the YT steps instead of the Sandbox steps.
const (
	tickIDField       = "processor_tick_id"
	executionIDField  = "execution_id"
	ytSpawnAnnotation = "yt_spawn"
)

// Processor periodically loads active executions from storage and drives each
// of them through its handling steps. Work is only done while locker reports
// that the lock is held (see Serve).
type Processor struct {
	// conf holds processor settings; conf.Enabled switches processing on/off.
	conf *Config
	// executorConf is not read in this file chunk; presumably consumed by the
	// execution handler via the Processor — TODO confirm.
	executorConf *apiclient.Config
	logger       log.Logger
	// storage is the source of active executions (ListActiveExecutions).
	storage storage.IStorage
	// stop is closed by Stop to make Serve exit.
	stop chan struct{}
	// sbx is the Sandbox client, presumably used by the Sandbox handling steps.
	sbx *sandbox.Client
	// ytc is the YT driver; nil means YT is disabled and yt_spawn executions
	// are skipped.
	ytc *ytdriver.YTDriver
	// locker gates processing: Serve only ticks while locker.IsLocked().
	locker  locks.Locker
	metrics *processorMetrics
}

// processorMetrics groups the metrics registered by New.
type processorMetrics struct {
	// iterations is incremented once per started tick.
	iterations metrics.Counter
	// panics counts tick-level panics recovered in Tick.
	panics metrics.Counter
	// fails counts ticks whose doTick returned an error.
	fails metrics.Counter
	// durations records the wall-clock duration of each tick.
	durations metrics.Timer
	// activeExecutions is set to the number of executions loaded by the
	// latest enabled tick.
	activeExecutions metrics.Gauge
}

// New assembles a Processor from its dependencies and registers its metrics
// (iteration, panic and failure counters, a tick-duration histogram and an
// active-executions gauge) in mr. The counters and the histogram are marked
// as rated for Solomon. The returned error is currently always nil.
func New(
	conf *Config,
	executorConf *apiclient.Config,
	logger log.Logger,
	s storage.IStorage,
	sbx *sandbox.Client,
	ytc *ytdriver.YTDriver,
	locker locks.Locker,
	mr metrics.Registry,
) (*Processor, error) {
	// Registration order matches the field order of processorMetrics.
	iterationsCounter := mr.Counter("iterations")
	panicsCounter := mr.Counter("panics")
	failsCounter := mr.Counter("fails")
	tickBuckets := metrics.NewDurationBuckets(
		100*time.Millisecond,
		500*time.Millisecond,
		1*time.Second,
		2*time.Second,
		5*time.Second,
		10*time.Second,
	)
	tickDurations := mr.DurationHistogram("duration_buckets", tickBuckets)
	activeGauge := mr.Gauge(mr.ComposeName("executions", "active"))

	solomon.Rated(iterationsCounter)
	solomon.Rated(panicsCounter)
	solomon.Rated(failsCounter)
	solomon.Rated(tickDurations)

	proc := &Processor{
		conf:         conf,
		executorConf: executorConf,
		logger:       logger,
		storage:      s,
		stop:         make(chan struct{}),
		sbx:          sbx,
		ytc:          ytc,
		locker:       locker,
		metrics: &processorMetrics{
			iterations:       iterationsCounter,
			panics:           panicsCounter,
			fails:            failsCounter,
			durations:        tickDurations,
			activeExecutions: activeGauge,
		},
	}
	return proc, nil
}

// Serve runs the processing loop until Stop is called and always returns nil.
//
// It starts the locker and stops it on exit. Ticks are started at most once
// every two seconds, and only while the locker reports the lock as held. Each
// tick is awaited before the next one starts. When Stop is called during a
// tick, the tick's context is cancelled. Serve then waits for that tick to
// finish before returning, so the deferred locker.Stop (which presumably
// releases the lock) never runs while executions are still being handled.
func (p *Processor) Serve() error {
	p.locker.Start()
	defer p.locker.Stop()

	// Fire (almost) immediately on the first iteration; later iterations are
	// throttled by the Reset below.
	timer := time.NewTimer(time.Nanosecond)
	defer timer.Stop()

	for {
		select {
		case <-p.stop:
			return nil
		case <-timer.C:
			// Respawn limit: minimum interval between tick starts.
			timer.Reset(2 * time.Second)
		}
		if !p.locker.IsLocked() {
			continue
		}

		tickCtx, cancel := context.WithCancel(context.Background())
		tickDone := p.Tick(tickCtx)
		select {
		case <-tickDone:
			cancel()
		case <-p.stop:
			cancel()
			// Let the in-flight tick observe cancellation and finish before the
			// deferred locker.Stop runs. tickDone is always closed by Tick's
			// goroutine, even after a panic.
			<-tickDone
			return nil
		}
	}
}

// Stop signals Serve to exit by closing the stop channel. Calling it again
// after the channel is closed is a no-op instead of a "close of closed
// channel" panic. Concurrent calls are not synchronized with each other.
func (p *Processor) Stop() {
	select {
	case <-p.stop:
		// Already stopped.
	default:
		close(p.stop)
	}
}

// Tick starts one processing pass (doTick) in a new goroutine. It returns a
// channel that is closed when the pass finishes, whether it succeeds, returns
// an error, or panics.
//
// Each tick's logs carry a fresh random ID under tickIDField. The tick
// increments the iterations counter and records its duration. A failure
// increments the fails counter. A recovered panic increments the panics
// counter and is logged together with a stack trace.
func (p *Processor) Tick(ctx context.Context) <-chan struct{} {
	tickCtx := ctxlog.WithFields(
		ctx,
		log.String(tickIDField, uuid.Must(uuid.NewV4()).String()),
	)
	ctxlog.Info(tickCtx, p.logger, "New tick")
	p.metrics.iterations.Inc()
	ts := time.Now()
	rv := make(chan struct{})
	go func() {
		// Deferred calls run in reverse order: the duration is recorded first,
		// then any panic is recovered, and finally rv is closed.
		defer close(rv)
		defer func() {
			if r := recover(); r != nil {
				// Include the stack so tick-level panics are as debuggable as
				// the per-execution panics logged in handleExecution.
				ctxlog.Error(
					tickCtx,
					p.logger,
					"Recovered tick panic",
					log.Any("panic", r),
					log.Any("trace", string(debug.Stack())),
				)
				p.metrics.panics.Inc()
			}
		}()
		defer func() {
			p.metrics.durations.RecordDuration(time.Since(ts))
		}()

		err := p.doTick(tickCtx)
		duration := log.Duration("tick_duration", time.Since(ts))
		if err != nil {
			p.metrics.fails.Inc()
			ctxlog.Error(tickCtx, p.logger, "Tick failed", log.Error(err), duration)
			return
		}
		ctxlog.Debug(tickCtx, p.logger, "Tick done", duration)
	}()
	return rv
}

// doTick performs one processing pass.
//
// When processing is disabled (conf.Enabled is false), it idles for up to a
// minute so Serve does not spin, then returns nil. It returns early if ctx is
// cancelled. Otherwise it loads all active executions, publishes their count
// to the activeExecutions gauge, and handles them one by one. Handling stops
// as soon as ctx is cancelled; ctx.Err() is returned in that case.
func (p *Processor) doTick(ctx context.Context) error {
	if !p.conf.Enabled {
		idle := time.NewTimer(60 * time.Second)
		defer idle.Stop()
		select {
		case <-idle.C:
		case <-ctx.Done():
			// Nothing was scheduled to run, so cancellation is not a failure.
		}
		return nil
	}

	executions, err := p.storage.ListActiveExecutions(ctx)
	if err != nil {
		return xerrors.Errorf("list active executions: %w", err)
	}
	ctxlog.Infof(ctx, p.logger, "Loaded executions. Count: %v", len(executions))
	p.metrics.activeExecutions.Set(float64(len(executions)))

	for _, execution := range executions {
		// Do not start handling further executions once the tick is cancelled
		// (e.g. Serve is shutting down and about to stop the locker).
		if err := ctx.Err(); err != nil {
			return err
		}
		p.handleExecution(ctx, execution)
	}
	return nil
}

// handleExecution runs one round of handling steps for a single execution. It
// never returns an error. Any error the handler records (handler.err) is
// logged. A panic is recovered and logged with a stack trace, so one broken
// execution cannot abort the rest of the tick.
//
// Every execution first goes through stepResolveResources and finally through
// stepArchiveExecution. In between:
//   - executions whose status annotation yt_spawn is true go through the YT
//     steps (stepSpawnYT, stepAwaitYT). They are skipped with an error log
//     when no YT driver is configured;
//   - all others go through the Sandbox steps (create, start, handle abortion,
//     await).
func (p *Processor) handleExecution(ctx context.Context, execution *taskletv2.Execution) {
	// Use the generated getters rather than execution.Meta.Id. taskletv2.Execution
	// is assumed to be a protobuf message whose Get* accessors are nil-safe. A
	// missing Meta then yields an empty ID here instead of a panic that would
	// escape before the recover below is installed and abort the whole tick.
	newCtx, cancel := context.WithCancel(
		ctxlog.WithFields(
			ctx,
			log.String(executionIDField, execution.GetMeta().GetId()),
		),
	)
	defer cancel()

	defer func() {
		if r := recover(); r != nil {
			err := xerrors.Errorf("Handling failed. Panic: %v", r)
			fields := []log.Field{
				log.Error(err),
				log.Any("trace", string(debug.Stack())),
			}
			ctxlog.Error(newCtx, p.logger, "Recovered panic", fields...)
		}
	}()

	ts := time.Now()
	ctxlog.Debug(newCtx, p.logger, "Execution handling started")
	defer func() {
		ctxlog.Debug(
			newCtx,
			p.logger,
			"Execution handling finished", log.Duration("handling_duration", time.Since(ts)),
		)
	}()

	useYTSpawn := execution.GetStatus().GetAnnotations().GetFields()[ytSpawnAnnotation].GetBoolValue()
	if useYTSpawn && p.ytc == nil {
		ctxlog.Error(newCtx, p.logger, "Skipped task handling: YT disabled")
		return
	}

	handler := newExecutionHandler(execution, p)
	handler.stepResolveResources(newCtx)

	if useYTSpawn {
		handler.stepSpawnYT(newCtx)
		handler.stepAwaitYT(newCtx)
	} else {
		handler.stepCreateSandbox(newCtx)
		handler.stepStartSandbox(newCtx)
		handler.stepHandleAbortionSandbox(newCtx)
		handler.stepAwaitSandbox(newCtx)
	}
	handler.stepArchiveExecution(newCtx)

	if handler.err != nil {
		ctxlog.Error(newCtx, p.logger, "Execution handling failed", log.Error(handler.err))
	}
}
