package dbcache

import (
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/avia/price_prediction/internal/models"
	"a.yandex-team.ru/travel/avia/price_prediction/internal/pgclient"
	"context"
	"golang.yandex/hasql"
	"gorm.io/gorm"
	"sync/atomic"
	"time"
)

// VariantsPriceStatsRepository keeps an in-memory copy of the
// "variant_price_stats" table, keyed by models.PriceStatKey, and serves
// lookups from that copy.
type VariantsPriceStatsRepository struct {
	// variantsPriceStats holds the current
	// map[models.PriceStatKey]models.VariantPriceStat snapshot. update stores
	// a freshly built map here and GetByKey loads it.
	variantsPriceStats atomic.Value
	// pgClient runs the transaction that reads the table.
	pgClient           *pgclient.PGClient
	// debug, when true, makes update run its queries through gorm's Debug mode.
	debug              bool
	// cfg supplies InitTimeout (deadline for one load) and UpdateAtHour (hour
	// of the planned daily refresh).
	cfg                Config
	// logger receives the load summary and refresh errors.
	logger             log.Logger
	// lastUpdatedAt is the time of the last successful update. It is written
	// by update and read by needToUpdate without synchronization.
	// NOTE(review): presumably both only run on the RunUpdater goroutine once
	// the constructor has returned — confirm.
	lastUpdatedAt      time.Time
}

// NewVariantsPriceStatsRepository builds a repository and performs the first
// load synchronously. It returns the error from that load, and a nil
// repository, if the initial read fails.
func NewVariantsPriceStatsRepository(
	pgClient *pgclient.PGClient,
	debug bool,
	cfg Config,
	logger log.Logger,
) (*VariantsPriceStatsRepository, error) {
	repo := &VariantsPriceStatsRepository{
		pgClient: pgClient,
		debug:    debug,
		cfg:      cfg,
		logger:   logger,
	}
	// Populate the cache before handing the repository out, so GetByKey
	// always finds a stored map.
	if err := repo.update(); err != nil {
		return nil, err
	}
	return repo, nil
}

// update reads the whole "variant_price_stats" table into a fresh map and, on
// success, publishes it in place of the previous snapshot and records the
// refresh time. The table is streamed through a server-side cursor in batches
// of 500 rows. The outcome is reported via writeUpdateMetric either way; on
// failure the previous snapshot and lastUpdatedAt are left untouched.
func (r *VariantsPriceStatsRepository) update() error {
	ctx, cancel := context.WithTimeout(context.Background(), r.cfg.InitTimeout)
	defer cancel()

	loaded := make(map[models.PriceStatKey]models.VariantPriceStat)

	// readAll runs inside a single transaction: the cursor it declares only
	// exists for the lifetime of that transaction.
	readAll := func(db *gorm.DB) error {
		tx := db.WithContext(ctx)
		if r.debug {
			tx = tx.Debug()
		}
		if declareErr := tx.Exec("DECLARE cur CURSOR FOR SELECT * FROM \"variant_price_stats\";").Error; declareErr != nil {
			return declareErr
		}
		for {
			batch := make([]models.VariantPriceStat, 0)
			if fetchErr := tx.Raw("FETCH 500 FROM cur;").Scan(&batch).Error; fetchErr != nil {
				return fetchErr
			}
			// An empty batch means the cursor is exhausted.
			if len(batch) == 0 {
				return tx.Exec("CLOSE cur;").Error
			}
			for _, stat := range batch {
				statKey := models.PriceStatKey{
					PointFromType:    stat.PointFromType,
					PointFromID:      stat.PointFromID,
					PointToType:      stat.PointToType,
					PointToID:        stat.PointToID,
					RouteUID:         stat.RouteUID,
					DepartureWeekday: stat.DepartureWeekday,
					DaysToFlight:     stat.DaysToFlight,
				}
				loaded[statKey] = stat
			}
		}
	}

	loadErr := r.pgClient.ExecuteInTransaction(hasql.Alive, readAll)
	writeUpdateMetric(loadErr == nil)
	if loadErr != nil {
		return loadErr
	}

	r.variantsPriceStats.Store(loaded)
	r.lastUpdatedAt = time.Now()
	r.logger.Infof("Loaded %v records", len(loaded))
	return nil
}

// GetByKey looks priceStatKey up in the current snapshot. The boolean reports
// whether the key was present. The returned pointer is never nil: it refers to
// a copy of the stored value, or to a zero VariantPriceStat when the key is
// missing.
func (r *VariantsPriceStatsRepository) GetByKey(priceStatKey *models.PriceStatKey) (*models.VariantPriceStat, bool) {
	snapshot := r.variantsPriceStats.Load().(map[models.PriceStatKey]models.VariantPriceStat)
	found, ok := snapshot[*priceStatKey]
	return &found, ok
}

// RunUpdater wakes up once a minute and reloads the cache whenever
// needToUpdate says it is stale. A failed reload is logged and retried on a
// later tick. The loop has no exit condition, so this call blocks for as long
// as the ticker keeps firing.
func (r *VariantsPriceStatsRepository) RunUpdater() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for {
		<-ticker.C
		if !r.needToUpdate(time.Now()) {
			continue
		}
		if updateErr := r.update(); updateErr != nil {
			r.logger.Error("failed to read DB content", log.Error(updateErr))
		}
	}
}

// needToUpdate reports whether the cache was last refreshed before the most
// recent planned refresh slot. Slots occur daily at cfg.UpdateAtHour:00 in
// now's location: today's slot if now has reached it, otherwise yesterday's.
func (r *VariantsPriceStatsRepository) needToUpdate(now time.Time) bool {
	year, month, day := now.Date()
	loc := now.Location()

	latestSlot := time.Date(year, month, day, r.cfg.UpdateAtHour, 0, 0, 0, loc)
	if now.Before(latestSlot) {
		// Today's slot has not arrived yet, so fall back to yesterday's.
		// Build it from calendar fields instead of subtracting 24 hours:
		// across a DST transition a 24-hour shift lands on a different
		// wall-clock hour, which could mark a fresh cache as stale.
		// time.Date normalizes day-1, including day 0 at a month boundary.
		latestSlot = time.Date(year, month, day-1, r.cfg.UpdateAtHour, 0, 0, 0, loc)
	}
	return r.lastUpdatedAt.Before(latestSlot)
}
