package users

import (
	"container/list"
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"sync"
	"text/template"

	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/gobase/models"
	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/analytics/gotasks/exports"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/yt/go/ypath"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/zootopia/analytics/drive/api"
)

// init registers the "update-achievements" subcommand together with its
// command-line flags on UsersCmd.
func init() {
	cmd := &cobra.Command{
		Use: "update-achievements",
		Run: gotasks.WrapMain(updateAchievementsMain),
	}
	flags := cmd.Flags()
	flags.Bool("dry-run", false, "Enables dry-run mode")
	flags.String("yt-proxy", "hahn", "YT proxy")
	flags.String("yt-table", "", "Path to YT table")
	flags.String("drive", "", "Name of Drive client")
	flags.Int("workers", 16, "Amount of parallel workers")
	flags.String("achievement", "", "Name of achievement")
	flags.Bool("stateless", false, "Enables stateless mode")
	flags.String("first-push", "", "Name of first push")
	flags.String("push", "", "Name of push")
	UsersCmd.AddCommand(cmd)
}

// tableProcessorState is the progress of TableProcessor in the form it is
// scanned from and saved to the table state.
//
// Finish is the JSON-encoded key of the last row that was handed out for
// processing. InProgress lists JSON-encoded keys of rows that still have to
// be retried; an entry is either a tableProcessorStateKey envelope or a bare
// key without an attempt counter.
type tableProcessorState struct {
	Finish     json.RawMessage   `json:"finish"`
	InProgress []json.RawMessage `json:"in_progress"`
}

// tableProcessorStateKey is the JSON envelope of a single in-progress entry:
// the encoded key plus the number of the attempt.
//
// Attempt is a pointer so that decoding can tell an envelope apart from a
// bare key: when it is nil after unmarshalling, TableProcessor parses the
// whole entry as a key and assumes attempt 1.
type tableProcessorStateKey struct {
	Key     json.RawMessage `json:"key"`
	Attempt *int            `json:"attempt"`
}

// TableKey is an ordered key of a table row.
type TableKey interface {
	// Less reports whether the key is ordered strictly before rhs.
	Less(rhs TableKey) bool
	// RawKey returns the key values; ytTableImpl passes them to ypath.Key to
	// build the lower read limit.
	RawKey() []any
}

// attemptTableKey couples a row key with the number of the processing
// attempt for that row.
type attemptTableKey struct {
	key     TableKey
	attempt int
}

// MarshalJSON encodes the key as a tableProcessorStateKey envelope: the
// wrapped key is marshalled on its own and embedded as raw JSON next to the
// attempt counter.
func (k attemptTableKey) MarshalJSON() ([]byte, error) {
	encodedKey, err := json.Marshal(k.key)
	if err != nil {
		return nil, err
	}
	attempt := k.attempt
	envelope := tableProcessorStateKey{Key: encodedKey, Attempt: &attempt}
	return json.Marshal(envelope)
}

// attemptTableKeySorter implements sort.Interface and orders entries by their
// key; the attempt counter does not take part in the comparison.
type attemptTableKeySorter []attemptTableKey

// Len returns the number of entries.
func (a attemptTableKeySorter) Len() int {
	return len(a)
}

// Less reports whether the key of entry i is ordered before the key of
// entry j.
func (a attemptTableKeySorter) Less(i, j int) bool {
	lhs, rhs := a[i].key, a[j].key
	return lhs.Less(rhs)
}

// Swap exchanges entries i and j.
func (a attemptTableKeySorter) Swap(i, j int) {
	moved := a[i]
	a[i] = a[j]
	a[j] = moved
}

// taskTableRow is a unit of work tracked by TableProcessor: a parsed row
// together with its key and attempt number.
type taskTableRow struct {
	key attemptTableKey
	row TableRow
}

// TableRow is a parsed table row that can report its own key.
type TableRow interface {
	// Key returns the key of the row.
	Key() TableKey
}

// Scanner is the part of a Reader handed to TableImpl.ParseRow.
type Scanner interface {
	// Scan fills row; implementations in this file pass a pointer to the
	// destination struct. NOTE(review): presumably decodes the current row of
	// the underlying reader — confirm against the reader implementation.
	Scan(row any) error
}

// TableImpl supplies the table-specific pieces used by TableProcessor.
type TableImpl interface {
	// ParseKey decodes a key from the JSON form stored in the processor
	// state.
	ParseKey(raw json.RawMessage) (TableKey, error)
	// ParseRow reads one row through scanner.
	ParseRow(scanner Scanner) (TableRow, error)
	// ProcessRow handles a single row. TableProcessor calls it from its
	// worker goroutines; when it returns an error the row is logged and kept
	// for a retry. dryRun is passed through from the TableProcessor caller.
	ProcessRow(ctx *gotasks.Context, row TableRow, dryRun bool) error
}

// fakeState wraps a models.State value so that it offers the
// ScanState/SaveState pair TableProcessor calls on its table state. It is
// used by updateAchievementsMain in stateless mode.
type fakeState struct {
	state models.State
}

// ScanState delegates to ScanState of the wrapped models.State.
func (s *fakeState) ScanState(state interface{}) error {
	return s.state.ScanState(state)
}

// SaveState delegates to SetState of the wrapped models.State.
// NOTE(review): nothing in this file writes the wrapped state anywhere else,
// so the saved value presumably lives only for the duration of the run.
func (s *fakeState) SaveState(state interface{}) error {
	return s.state.SetState(state)
}

// Reader is a cursor over table rows. TableProcessor calls Next before each
// row, hands the reader to TableImpl.ParseRow as a Scanner, checks Err once
// Next has returned false and always calls Close when it is done.
type Reader interface {
	Next() bool
	Scan(row any) error
	Err() error
	Close() error
}

// Table is a source of rows for TableProcessor.
type Table interface {
	// ReadFrom opens a reader positioned at begin. TableProcessor passes a
	// nil begin when there is no saved finish key.
	ReadFrom(ctx context.Context, begin TableKey) (Reader, error)
}

// ytTableImpl is a Table backed by a YT table: yt is the client used for
// reading and path is the path of the table.
type ytTableImpl struct {
	yt   yt.Client
	path ypath.Path
}

// ReadFrom opens a YT reader for the table. With a nil begin the whole table
// is read; otherwise a range starting from the key built out of
// begin.RawKey() is added to the path before reading.
func (t *ytTableImpl) ReadFrom(ctx context.Context, begin TableKey) (Reader, error) {
	source := t.path.Rich()
	if begin == nil {
		return t.yt.ReadTable(ctx, source, nil)
	}
	lowerLimit := ypath.Key(begin.RawKey()...)
	source = source.AddRange(ypath.StartingFrom(lowerLimit))
	return t.yt.ReadTable(ctx, source, nil)
}

// TableProcessor reads rows of table and processes them with impl on a pool
// of worker goroutines, keeping resumable progress in tableState.
//
// The persisted state consists of the finish key (the key of the last row
// that was handed out for processing) and the list of in-progress keys (rows
// whose processing failed and has to be retried). Reading starts from the
// smaller of the finish key and the first in-progress key. Rows with a key
// greater than finish are processed as new rows with attempt 1; rows at or
// before finish are processed only when their key is in the in-progress list.
// A failure of impl.ProcessRow is logged and the row is kept for a later run;
// a row is dropped from the state once its attempt counter would exceed 32.
//
// Once reading has started, the state is rebuilt on every return path and
// saved unless dryRun is set. dryRun is also forwarded to impl.ProcessRow.
//
// An error is returned when workers is not positive, the state cannot be
// decoded, the table cannot be read or a row cannot be parsed, ctx is
// cancelled, more than 512 rows are pending or failed at the same time, or
// the state cannot be encoded or saved.
func TableProcessor(
	ctx *gotasks.Context, tableState exports.TableState, table Table,
	workers int, impl TableImpl, dryRun bool,
) (errRes error) {
	// Rows are handed to workers over an unbuffered channel, so without at
	// least one worker the first send would block forever.
	if workers <= 0 {
		return fmt.Errorf("invalid amount of workers: %d", workers)
	}
	var state tableProcessorState
	if err := tableState.ScanState(&state); err != nil {
		return fmt.Errorf("unable to scan raw state: %w", err)
	}
	// finish represents last key that was already processed.
	var finish TableKey
	if len(state.Finish) != 0 {
		key, err := impl.ParseKey(state.Finish)
		if err != nil {
			return err
		}
		finish = key
	}
	// queue holds the keys that have to be retried, sorted by key.
	var queue []attemptTableKey
	for _, raw := range state.InProgress {
		if len(raw) == 0 {
			continue
		}
		var rawKey tableProcessorStateKey
		if err := json.Unmarshal(raw, &rawKey); err != nil {
			return err
		}
		if rawKey.Attempt != nil {
			key, err := impl.ParseKey(rawKey.Key)
			if err != nil {
				return err
			}
			queue = append(queue, attemptTableKey{
				key:     key,
				attempt: *rawKey.Attempt,
			})
		} else {
			// Entry without an attempt counter: the whole entry is the key.
			key, err := impl.ParseKey(raw)
			if err != nil {
				return err
			}
			queue = append(queue, attemptTableKey{
				key:     key,
				attempt: 1,
			})
		}
	}
	sort.Sort(attemptTableKeySorter(queue))
	// start represents first row, that should be started to process.
	start := finish
	if len(queue) > 0 && start != nil && queue[0].key.Less(start) {
		start = queue[0].key
	}
	in, err := table.ReadFrom(ctx.Context, start)
	if err != nil {
		return err
	}
	defer func() {
		_ = in.Close()
	}()
	// tasks contains every row that was handed to a worker and has not been
	// processed successfully yet; it is guarded by mutex.
	tasks := list.New()
	var mutex sync.Mutex
	queuePos := 0
	// Deferred calls run in reverse order: the task channel is closed first,
	// then the workers are awaited and only after that the state is rebuilt
	// and saved, so tasks is final when it is inspected here.
	defer func() {
		mutex.Lock()
		defer mutex.Unlock()
		if finish == nil || finish.Less(start) {
			state.Finish, err = json.Marshal(start)
			if err != nil {
				errRes = err
				return
			}
		}
		state.InProgress = nil
		for it := tasks.Front(); it != nil; it = it.Next() {
			task := it.Value.(taskTableRow)
			key := task.key
			// TODO(iudovin@): Replace constant with configurable value.
			if key.attempt++; key.attempt > 32 {
				ctx.Logger.Error(
					"Attempts limit reached",
					log.Any("row", task.row),
					log.Int("attempt", key.attempt),
				)
				continue
			}
			raw, err := json.Marshal(key)
			if err != nil {
				errRes = err
				return
			}
			state.InProgress = append(state.InProgress, raw)
		}
		// Retry keys that were not reached during this run are kept as is.
		for _, key := range queue[queuePos:] {
			raw, err := json.Marshal(key)
			if err != nil {
				errRes = err
				return
			}
			state.InProgress = append(state.InProgress, raw)
		}
		if !dryRun {
			if err := tableState.SaveState(state); err != nil {
				errRes = err
				return
			}
		}
	}()
	var waiter sync.WaitGroup
	defer waiter.Wait()
	taskQueue := make(chan *list.Element)
	defer close(taskQueue)
	for i := 0; i < workers; i++ {
		waiter.Add(1)
		go func() {
			defer waiter.Done()
			for it := range taskQueue {
				var task taskTableRow
				func() {
					mutex.Lock()
					defer mutex.Unlock()
					task = it.Value.(taskTableRow)
				}()
				if err := impl.ProcessRow(ctx, task.row, dryRun); err != nil {
					// The failed task stays in tasks and is written to the
					// in-progress list when the state is saved.
					ctx.Logger.Error(
						"Unable to process row",
						log.Error(err),
						log.Any("row", task.row),
						log.Int("attempt", task.key.attempt),
					)
				} else {
					func() {
						mutex.Lock()
						defer mutex.Unlock()
						tasks.Remove(it)
					}()
				}
			}
		}()
	}
	for in.Next() {
		select {
		case <-ctx.Context.Done():
			return ctx.Context.Err()
		default:
		}
		tasksLen := 0
		func() {
			mutex.Lock()
			defer mutex.Unlock()
			tasksLen = tasks.Len()
		}()
		// TODO(iudovin@): Replace constant with configurable value.
		if tasksLen > 512 {
			return fmt.Errorf("too many skipped tasks: %d", tasksLen)
		}
		row, err := impl.ParseRow(in)
		if err != nil {
			return err
		}
		key := row.Key()
		// Skip retry keys ordered before the current row: the table has no
		// row for them at this position.
		for queuePos < len(queue) && queue[queuePos].key.Less(key) {
			queuePos++
		}
		if finish == nil || finish.Less(key) {
			// New row that is located after the saved finish key.
			var it *list.Element
			func() {
				mutex.Lock()
				defer mutex.Unlock()
				task := taskTableRow{
					key: attemptTableKey{key: key, attempt: 1},
					row: row,
				}
				it = tasks.PushBack(task)
				start = key
			}()
			taskQueue <- it
		} else if queuePos < len(queue) && !key.Less(queue[queuePos].key) {
			// Row at or before finish whose key equals the current retry key.
			var it *list.Element
			func() {
				mutex.Lock()
				defer mutex.Unlock()
				task := taskTableRow{
					key: queue[queuePos],
					row: row,
				}
				it = tasks.PushBack(task)
			}()
			taskQueue <- it
		}
	}
	if err := in.Err(); err != nil {
		// Reading stopped before the end of the table, so the remaining
		// retry keys were never reached: keep them in the saved state
		// instead of pruning them below.
		return err
	}
	// The table was read to the end: retry keys that are not after finish or
	// start have been passed and are removed from the saved state.
	for queuePos < len(queue) && (finish == nil || !finish.Less(queue[queuePos].key)) {
		queuePos++
	}
	for queuePos < len(queue) && (start == nil || !start.Less(queue[queuePos].key)) {
		queuePos++
	}
	return nil
}

// achievementKey is the key of an achievements table row. Keys are ordered by
// UpdateTime first and by UserID second; the JSON form is what is stored in
// the processor state.
type achievementKey struct {
	UpdateTime int64  `json:"update_time"`
	UserID     string `json:"user_id"`
}

// Less reports whether k is ordered strictly before o, comparing UpdateTime
// first and UserID second. o must hold an achievementKey, otherwise the type
// assertion panics.
func (k achievementKey) Less(o TableKey) bool {
	other := o.(achievementKey)
	switch {
	case k.UpdateTime < other.UpdateTime:
		return true
	case k.UpdateTime > other.UpdateTime:
		return false
	default:
		return k.UserID < other.UserID
	}
}

// RawKey returns the key values in key order: UpdateTime followed by UserID.
func (k achievementKey) RawKey() []any {
	columns := make([]any, 0, 2)
	columns = append(columns, k.UpdateTime, k.UserID)
	return columns
}

// achievementRow is a row of the achievements table as it is decoded from
// YSON. UpdateTime and UserID form the row key; Level, LevelValue and
// LevelTime are optional and are copied into the achievement tag data by
// ProcessRow.
type achievementRow struct {
	UpdateTime int64    `yson:"update_timestamp"`
	UserID     string   `yson:"user_id"`
	Level      *int     `yson:"level"`
	LevelValue *float64 `yson:"level_value"`
	LevelTime  *int64   `yson:"level_timestamp"`
}

// Key returns the achievementKey built from the UpdateTime and UserID of the
// row.
func (r achievementRow) Key() TableKey {
	var key achievementKey
	key.UpdateTime, key.UserID = r.UpdateTime, r.UserID
	return key
}

// achievementsImpl is the TableImpl that applies achievements table rows to
// user tags.
//
// drive is the client used to read and write user tags, achievement is the
// name of the tag being maintained, firstPush and push are templates that
// produce the name of the push tag; they are executed with a "level" value.
type achievementsImpl struct {
	drive       *api.Client
	achievement string
	firstPush   *template.Template
	push        *template.Template
}

// ParseKey decodes an achievementKey from its JSON form. The decoded value is
// returned together with the error even when decoding fails.
func (t *achievementsImpl) ParseKey(raw json.RawMessage) (TableKey, error) {
	parsed := achievementKey{}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return parsed, err
	}
	return parsed, nil
}

// ParseRow scans one achievementRow from scanner. The row value is returned
// together with the error even when scanning fails.
func (t *achievementsImpl) ParseRow(scanner Scanner) (TableRow, error) {
	parsed := achievementRow{}
	if err := scanner.Scan(&parsed); err != nil {
		return parsed, err
	}
	return parsed, nil
}

// ProcessRow applies one achievements table row to the tags of its user.
//
// It fetches the user's tags, remembers the current tag of the configured
// achievement (if there is one) and then creates or updates that tag with the
// level data of the row. When the new level is at least 1, a push tag is added
// as well: the first-push template is used if the previous level was missing
// or below 1, the regular push template if the level has grown. A template
// that renders to an empty string sends no push.
//
// With dryRun set nothing is written and no signals are reported; the
// intended tags are only logged. rawRow must hold an achievementRow.
func (t *achievementsImpl) ProcessRow(ctx *gotasks.Context, rawRow TableRow, dryRun bool) error {
	row := rawRow.(achievementRow)
	labels := map[string]string{"achievement": t.achievement}
	userTags, err := t.drive.GetUserTags(row.UserID)
	if err != nil {
		return fmt.Errorf("unable to fetch user tags: %w", err)
	}
	// Find the tag that is currently stored for this achievement.
	var (
		oldTag  api.UserTag
		oldData api.AchievementTagData
	)
	for _, userTag := range userTags {
		if !strings.HasPrefix(userTag.Tag, "user_achievement_") {
			continue
		}
		// Data of every achievement tag is decoded, so a tag of another
		// achievement that cannot be decoded fails the row as well.
		var data api.AchievementTagData
		if err := userTag.Data.(*api.RawTagData).Scan(&data); err != nil {
			return err
		}
		if userTag.Tag != t.achievement {
			continue
		}
		oldTag, oldData = userTag, data
	}
	newData := api.AchievementTagData{
		Level:      row.Level,
		LevelValue: row.LevelValue,
		LevelTime:  row.LevelTime,
	}
	// Reusing the ID of the existing tag turns the write into an update.
	newTag := api.UserTag{
		ID:     oldTag.ID,
		Tag:    t.achievement,
		Data:   &newData,
		UserID: row.UserID,
	}
	ctx.Logger.Debug(
		"Updating user tag",
		log.Any("tag", newTag),
		log.Bool("dry_run", dryRun),
	)
	if !dryRun {
		var updateErr error
		if newTag.ID == "" {
			if err := t.drive.AddUserTag(newTag); err != nil {
				updateErr = fmt.Errorf("unable to create achievement tag: %w", err)
			}
		} else if err := t.drive.UpdateUserTag(newTag); err != nil {
			updateErr = fmt.Errorf("unable to update achievement tag: %w", err)
		}
		if updateErr != nil {
			ctx.Signal("users.achievements.update_user_tag.error_sum", labels).Add(1)
			return updateErr
		}
		ctx.Signal("users.achievements.update_user_tag.ok_sum", labels).Add(1)
	}
	// Pushes are sent only for levels starting from 1.
	if newData.Level == nil || *newData.Level < 1 {
		return nil
	}
	pushData := api.GenericTagData{
		"attachments":   []interface{}{},
		"template_args": map[string]interface{}{},
	}
	// The level is known to be set at this point.
	templateData := map[string]interface{}{"level": *newData.Level}
	switch {
	case oldData.Level == nil || *oldData.Level < 1:
		var name strings.Builder
		if err := t.firstPush.Execute(&name, templateData); err != nil {
			return fmt.Errorf("unable to build first push tag: %w", err)
		}
		if name.Len() == 0 {
			return nil
		}
		pushTag := api.UserTag{
			Tag:    name.String(),
			Data:   &pushData,
			UserID: row.UserID,
		}
		ctx.Logger.Debug("First achievement push", log.Any("tag", pushTag), log.Bool("dry_run", dryRun))
		if dryRun {
			return nil
		}
		if err := t.drive.AddUserTag(pushTag); err != nil {
			ctx.Signal("users.achievements.first_achievement_push.error_sum", labels).Add(1)
			return fmt.Errorf("unable to add first achievement push tag: %w", err)
		}
		ctx.Signal("users.achievements.first_achievement_push.ok_sum", labels).Add(1)
	case *oldData.Level < *newData.Level:
		var name strings.Builder
		if err := t.push.Execute(&name, templateData); err != nil {
			return fmt.Errorf("unable to build push tag: %w", err)
		}
		if name.Len() == 0 {
			return nil
		}
		pushTag := api.UserTag{
			Tag:    name.String(),
			Data:   &pushData,
			UserID: row.UserID,
		}
		ctx.Logger.Debug("New achievement push", log.Any("tag", pushTag), log.Bool("dry_run", dryRun))
		if dryRun {
			return nil
		}
		if err := t.drive.AddUserTag(pushTag); err != nil {
			ctx.Signal("users.achievements.new_achievement_push.error_sum", labels).Add(1)
			return fmt.Errorf("unable to add new achievement push tag: %w", err)
		}
		ctx.Signal("users.achievements.new_achievement_push.ok_sum", labels).Add(1)
	}
	return nil
}

// updateAchievementsMain is the entry point of the "update-achievements"
// command. It reads the command flags, builds the achievements TableImpl and
// the YT table, and runs TableProcessor over the table.
//
// In stateless mode the processor works on a throwaway state that starts as
// JSON null; otherwise the state stored under
// "achievements/reader/<achievement>" is used.
//
// An error is returned when a flag cannot be read, the drive name is
// unknown, the achievement name does not start with "user_achievement_", a
// push template cannot be parsed, the state cannot be fetched, or
// TableProcessor fails.
func updateAchievementsMain(ctx *gotasks.Context) (errOut error) {
	flags := ctx.Cmd.Flags()
	dryRun, err := flags.GetBool("dry-run")
	if err != nil {
		return err
	}
	yc, err := ctx.GetYT()
	if err != nil {
		return err
	}
	driveName, err := flags.GetString("drive")
	if err != nil {
		return err
	}
	drive, ok := ctx.Drives[driveName]
	if !ok {
		return fmt.Errorf("invalid drive name %q", driveName)
	}
	ytTable, err := flags.GetString("yt-table")
	if err != nil {
		return err
	}
	achievement, err := flags.GetString("achievement")
	if err != nil {
		return err
	}
	firstPush, err := flags.GetString("first-push")
	if err != nil {
		return err
	}
	push, err := flags.GetString("push")
	if err != nil {
		return err
	}
	workers, err := flags.GetInt("workers")
	if err != nil {
		return err
	}
	stateless, err := flags.GetBool("stateless")
	if err != nil {
		return err
	}
	if !strings.HasPrefix(achievement, "user_achievement_") {
		return fmt.Errorf(
			"invalid achievement name %q: must start with %q",
			achievement, "user_achievement_",
		)
	}
	// The templates come from command-line flags, so a malformed template is
	// reported as an error instead of a panic.
	firstPushTemplate, err := template.New("first_push").Parse(firstPush)
	if err != nil {
		return fmt.Errorf("unable to parse first push template: %w", err)
	}
	pushTemplate, err := template.New("push").Parse(push)
	if err != nil {
		return fmt.Errorf("unable to parse push template: %w", err)
	}
	impl := &achievementsImpl{
		drive:       drive,
		achievement: achievement,
		firstPush:   firstPushTemplate,
		push:        pushTemplate,
	}
	table := &ytTableImpl{yt: yc, path: ypath.Path(ytTable)}
	if stateless {
		ctx.Logger.Warn("Running in stateless mode")
		return TableProcessor(ctx, &fakeState{state: models.State{State: models.JSON("null")}}, table, workers, impl, dryRun)
	}
	state, err := exports.GetTableState(ctx.States, "achievements/reader/"+achievement)
	if err != nil {
		return fmt.Errorf("unable to fetch state: %w", err)
	}
	return TableProcessor(ctx, state, table, workers, impl, dryRun)
}
