package status

import (
	"context"
	"fmt"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/logger"
	appMetrics "a.yandex-team.ru/travel/avia/shared_flights/status_importer/internal/metrics"
	"a.yandex-team.ru/travel/avia/shared_flights/status_importer/internal/objects"
	"a.yandex-team.ru/travel/avia/shared_flights/status_importer/pkg/logging/yt/updatestatuslog"
	"a.yandex-team.ru/travel/avia/wizard/pkg/wizard/helpers"
	"a.yandex-team.ru/travel/library/go/errutil"
	"a.yandex-team.ru/travel/library/go/metrics"
)

// UpdaterConfig holds the settings and dependencies consumed by
// NewStatusUpdater.
type UpdaterConfig struct {
	// Parallelism is the number of worker goroutines NewStatusUpdater starts.
	// NOTE(review): a value <= 0 starts no workers, so Update would block
	// forever on the hand-off channel — confirm callers never pass that.
	Parallelism   int16
	// BatchSenderFn is invoked once per batch that reaches the database-write
	// stage to obtain the BatchSender used for that write.
	BatchSenderFn BatchSenderFn
	// Objects is passed to status normalization, argument preparation and
	// the status writer.
	Objects       *objects.Objects
	// UpdateLogger receives one JSON-encoded update-log record per processed
	// status, written at Info level.
	UpdateLogger  log.Logger
}

// BatchSenderFn produces a BatchSender. The status updater calls it once for
// every batch that reaches the database-write stage.
type BatchSenderFn func() BatchSender

// statusUpdater is the Updater built by NewStatusUpdater. Units submitted via
// Update are handed over statusChan to a fixed number of worker goroutines.
type statusUpdater struct {
	// statusChan is the unbuffered hand-off channel between Update and workers.
	statusChan     chan ProcessingUnit
	// batchSenderFn supplies the BatchSender for each database write.
	batchSenderFn  BatchSenderFn
	// objects is shared by normalization, argument preparation and writing.
	objects        *objects.Objects
	// ytUpdateLogger receives the JSON update-log record of every status.
	ytUpdateLogger log.Logger
}

// NewStatusUpdater wires a statusUpdater from config, starts
// config.Parallelism worker goroutines that consume submitted units, and
// returns the updater. Workers are already running when it returns.
func NewStatusUpdater(config UpdaterConfig) Updater {
	su := new(statusUpdater)
	su.statusChan = make(chan ProcessingUnit)
	su.batchSenderFn = config.BatchSenderFn
	su.objects = config.Objects
	su.ytUpdateLogger = config.UpdateLogger

	su.run(config.Parallelism)
	return su
}

// Update hands unit to a worker. The send on the unbuffered status channel
// blocks until some worker accepts the unit. With blocking set, Update then
// waits for the result the worker publishes on unit.Finished and returns it;
// otherwise it returns nil right after the hand-off.
//
// NOTE(review): the worker always sends the result on unit.Finished before
// closing it. For non-blocking calls this presumably relies on Finished being
// buffered or drained elsewhere, otherwise the worker stalls — confirm.
func (su *statusUpdater) Update(unit ProcessingUnit, blocking bool) error {
	su.statusChan <- unit
	if !blocking {
		return nil
	}
	return <-unit.Finished
}

// maintainSingleWorker runs su.worker in the calling goroutine and restarts
// it whenever it panics. Each panic is logged together with the worker ID and
// a traceback. While a worker incarnation is running, the
// appMetrics.Worker/appMetrics.Count gauge is raised by one.
//
// su.worker only returns normally after su.statusChan has been closed and
// drained. In that case the function returns instead of restarting:
// restarting would range over the closed channel again, return immediately
// and spin forever.
func (su *statusUpdater) maintainSingleWorker(workerID int) {
	for {
		panicked := func() (recovered bool) {
			defer func() {
				if e := recover(); e != nil {
					recovered = true
					logger.Logger().Error(
						"Status worker panicked, restoring",
						log.Int("worker id", workerID),
						log.Reflect("panic", e),
						log.String("traceback", helpers.GetTraceback()),
					)
				}
			}()
			metrics.GlobalAppMetrics().GetOrCreateGauge(appMetrics.Worker, nil, appMetrics.Count).Add(1)
			defer func() {
				metrics.GlobalAppMetrics().GetOrCreateGauge(appMetrics.Worker, nil, appMetrics.Count).Add(-1)
			}()
			su.worker()
			return false
		}()
		if !panicked {
			// Normal return means the status channel is closed: stop for good.
			return
		}
	}
}

// run logs the requested worker count and starts that many goroutines, each
// executing maintainSingleWorker with its own zero-based ID. It does not wait
// for them; a non-positive parallelism starts no goroutines.
func (su *statusUpdater) run(parallelism int16) {
	logger.Logger().Info("Starting status update workers", log.Int16("count", parallelism))
	workers := int(parallelism)
	for id := 0; id < workers; id++ {
		go su.maintainSingleWorker(id)
	}
}

// worker consumes processing units from su.statusChan until the channel is
// closed. For each unit it runs update on the unit's statuses, sends the
// resulting error (nil on success) on unit.Finished and then closes
// unit.Finished. While a unit is being handled, the
// appMetrics.Worker/appMetrics.Busy gauge is raised by one.
func (su *statusUpdater) worker() {
	for unit := range su.statusChan {
		func() {
			metrics.GlobalAppMetrics().GetOrCreateGauge(appMetrics.Worker, nil, appMetrics.Busy).Add(1)
			// Deferred so the busy gauge is restored even if delivering the
			// result panics (e.g. send on an already closed Finished channel).
			// The panic itself still propagates to maintainSingleWorker.
			defer func() {
				metrics.GlobalAppMetrics().GetOrCreateGauge(appMetrics.Worker, nil, appMetrics.Busy).Add(-1)
			}()
			unit.Finished <- su.update(unit.Statuses)
			close(unit.Finished)
		}()
	}
}

// UpdateResults buckets statuses by the update-log result assigned to them
// (for example updatestatuslog.Success or updatestatuslog.InvalidStatus).
// writeLog emits one update-log entry per status it contains.
type UpdateResults map[updatestatuslog.RESULT]Statuses

// appendStatuses appends statuses to the bucket for result. The bucket entry
// is assigned even when statuses is empty, so the key is present afterwards.
func (ur UpdateResults) appendStatuses(result updatestatuslog.RESULT, statuses Statuses) {
	bucket := ur[result]
	bucket = append(bucket, statuses...)
	ur[result] = bucket
}

// addFrom merges other into ur: for every result bucket in other, its
// statuses are appended after whatever ur already holds for that result.
func (ur UpdateResults) addFrom(other UpdateResults) {
	for result := range other {
		ur.appendStatuses(result, other[result])
	}
}

// writeLogEntry serializes the update-log record of status for the given
// result and writes it as one Info line to the YT update logger. If the record
// cannot be serialized, the failure is reported to the application logger and
// nothing is written to the update log.
func (su *statusUpdater) writeLogEntry(status *tStatus, result updatestatuslog.RESULT) {
	record := status.GetUpdateLogRecord(result)
	payload, err := record.JSON()
	if err == nil {
		su.ytUpdateLogger.Info(string(payload))
		return
	}
	logger.Logger().Error(
		"Cannot log update status log entry",
		log.Error(err),
		log.Any("status", status),
	)
}

// writeLog emits one update-log entry for every status in results, tagged
// with the bucket it belongs to. Buckets are visited in Go map iteration
// order, so the relative order of entries from different buckets varies.
func (su *statusUpdater) writeLog(results UpdateResults) {
	for result := range results {
		bucket := results[result]
		for i := range bucket {
			su.writeLogEntry(bucket[i], result)
		}
	}
}

// update pushes one batch of statuses through the import pipeline and records
// the outcome of every stage in an UpdateResults:
//
//  1. validation: statuses whose Validate() is false go to InvalidStatus;
//  2. normalization against su.objects: failures go to AbnormalStatus;
//  3. argument preparation: statuses from unknown sources go to
//     UnknownStatusSource, statuses failing for other reasons go to Error;
//  4. the remaining statuses are written by a tStatusWriter, whose own
//     UpdateResults are merged in afterwards.
//
// The collected results are flushed to the YT update log by a deferred
// writeLog call, which also runs when the body panics.
//
// It returns:
//   - an error wrapping errNoStatuses when statuses is empty;
//   - a NotImportantError when no status survives validation or normalization;
//   - an error when argument preparation leaves nothing to write and at least
//     one failure was not an unknown-source error;
//   - an error describing a recovered panic (combined with any pending error);
//   - nil otherwise. Write-stage outcomes are reported through the results,
//     metrics and logs, not through the returned error.
func (su *statusUpdater) update(statuses Statuses) (err error) {
	// Deferred calls run LIFO: writeLog (registered below) runs first, then the
	// panic recovery, then errutil.Wrap, which presumably adds the
	// "statusUpdater.update" context to whatever error is finally returned.
	defer errutil.Wrap(&err, "statusUpdater.update")
	defer func() {
		if r := recover(); r != nil {
			err = errOrGroup(
				err,
				xerrors.Errorf(
					"panic: %v, panic trace: %+v",
					fmt.Sprint(r),
					helpers.GetTraceback(),
				),
			)
		}
	}()

	if len(statuses) == 0 {
		return xerrors.Errorf("no statuses to write: %w", errNoStatuses)
	}

	startTime := time.Now()

	logger.Logger().Info(
		"Statuses arrived",
		log.Int("len", len(statuses)),
	)

	updateResults := make(UpdateResults)
	defer func() {
		su.writeLog(updateResults)
	}()

	// Validation
	validStatuses, invalidStatuses := statuses.validate()
	if len(invalidStatuses) > 0 {
		updateResults.appendStatuses(updatestatuslog.InvalidStatus, invalidStatuses)
		invalidStatusMetrics(invalidStatuses)
	}
	if len(validStatuses) == 0 {
		return NotImportantError{xerrors.Errorf("all statuses did not pass validation")}
	}

	// Normalization
	normalizedStatuses, normalizeErrors := validStatuses.normalize(su.objects)
	if len(normalizeErrors) > 0 {
		updateResults.appendStatuses(updatestatuslog.AbnormalStatus, normalizeErrors.Statuses())
		abnormalStatusMetrics(normalizeErrors)
	}
	if len(normalizedStatuses) == 0 {
		return NotImportantError{xerrors.Errorf("all statuses did not pass normalization")}
	}

	// Argument substitution
	var statusesWithArgs []tStatusWithArguments
	err = func() error {
		var argumentErrors []statusWithError
		statusesWithArgs, argumentErrors = prepareStatusArguments(normalizedStatuses, su.objects)

		unknownStatusSources := map[string]struct{}{}
		var unknownStatusSourceErrorStatuses Statuses
		var otherArgumentErrors error
		var otherArgumentErrorStatuses Statuses
		for _, prepareArgumentError := range argumentErrors {
			var unknownStatusSource unknownStatusSourceError
			if xerrors.As(prepareArgumentError.err, &unknownStatusSource) {
				unknownStatusSources[unknownStatusSource.SourceName] = struct{}{}
				unknownStatusSourceErrorStatuses = append(unknownStatusSourceErrorStatuses, prepareArgumentError.status)
				continue
			}
			otherArgumentErrors = errOrGroup(otherArgumentErrors, prepareArgumentError.err)
			otherArgumentErrorStatuses = append(otherArgumentErrorStatuses, prepareArgumentError.status)
		}
		updateResults.appendStatuses(updatestatuslog.UnknownStatusSource, unknownStatusSourceErrorStatuses)
		updateResults.appendStatuses(updatestatuslog.Error, otherArgumentErrorStatuses)

		if len(unknownStatusSources) > 0 {
			unknownStatusSourcesList := make([]string, 0, len(unknownStatusSources))
			for k := range unknownStatusSources {
				unknownStatusSourcesList = append(unknownStatusSourcesList, k)
			}
			logger.Logger().Warn("Unknown status sources", log.Strings("sources", unknownStatusSourcesList))
		}

		// Look at what survived argument preparation (the input batch is known
		// to be non-empty here). If nothing is left, surface the real failures;
		// statuses dropped only for unknown sources are not treated as an error.
		if len(statusesWithArgs) == 0 {
			if otherArgumentErrors != nil {
				return xerrors.Errorf("no valid statuses: %w", otherArgumentErrors)
			}
			return nil
		}
		if otherArgumentErrors != nil {
			logger.Logger().Error("Some statuses will not be written", log.Error(otherArgumentErrors))
		}
		return nil
	}()
	if err != nil {
		return xerrors.Errorf("cannot get arguments for all statuses: %w", err)
	}
	if len(statusesWithArgs) == 0 {
		logger.Logger().Debug("No statuses after getting arguments")
		return nil
	}

	// Write to database
	updateResultsTransact := make(UpdateResults)
	statusWriter := tStatusWriter{
		batchSender:      su.batchSenderFn(),
		objects:          su.objects,
		minibatchTimeout: 10 * time.Second,
	}
	statusWriter.Write(context.Background(), statusesWithArgs, updateResultsTransact)

	elapsed := time.Since(startTime)
	updateResults.addFrom(updateResultsTransact)

	// A missing key yields a nil slice, so a single length check suffices.
	if v := updateResultsTransact[updatestatuslog.Error]; len(v) > 0 {
		failedStatusWriteMetrics(elapsed, v)
	}
	if v := updateResultsTransact[updatestatuslog.Success]; len(v) > 0 {
		successStatusWriteMetrics(elapsed, v)
	}
	if v := updateResultsTransact[updatestatuslog.IgnoredByUpdatedAt]; len(v) > 0 {
		ignoredStatusWriteMetrics(elapsed, v)
	}

	logger.Logger().Info(
		"Wrote statuses",
		log.Int("arrived", len(statuses)),
		log.Int("affected", len(updateResults[updatestatuslog.Success])),
		log.Int("ignored-by-received-at", len(updateResults[updatestatuslog.IgnoredByUpdatedAt])),
		log.Int("invalid", len(updateResults[updatestatuslog.InvalidStatus])),
		log.Int("abnormal", len(updateResults[updatestatuslog.AbnormalStatus])),
		log.Int("error", len(updateResults[updatestatuslog.Error])),
	)
	return nil
}

// validate partitions ss by each status's Validate() result. It returns the
// statuses that passed, then those that failed, both in input order. A slice
// that receives no statuses is returned as nil.
func (ss Statuses) validate() (Statuses, Statuses) {
	var (
		passed Statuses
		failed Statuses
	)
	for i := range ss {
		s := ss[i]
		if !s.Validate() {
			failed = append(failed, s)
			continue
		}
		passed = append(passed, s)
	}
	return passed, failed
}

// normalize calls Normalize(objs) on every status in ss. Statuses that
// normalize without error are returned first. The others are returned as
// tInvalidStatus entries carrying the error text. Input order is kept in both
// outputs, and an output that receives nothing is nil.
func (ss Statuses) normalize(objs *objects.Objects) (Statuses, tInvalidStatuses) {
	var (
		normalized Statuses
		rejected   tInvalidStatuses
	)
	for i := range ss {
		s := ss[i]
		if err := s.Normalize(objs); err != nil {
			rejected = append(rejected, tInvalidStatus{s, err.Error()})
			continue
		}
		normalized = append(normalized, s)
	}
	return normalized, rejected
}
