package aggression

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/dchest/siphash"
	"github.com/gofrs/uuid"
	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/analytics/gotasks/exports"
	"a.yandex-team.ru/drive/analytics/gotasks/models/users"
	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/yt/go/ypath"
	"a.yandex-team.ru/yt/go/yson"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/zootopia/analytics/drive/api"
)

// init registers the "update-users" subcommand of AggressionCmd together with
// its flags. The ratio and count flags configure the safety check performed by
// updateUsersMain after its statistics-only pass over the scoring table.
func init() {
	updateUsersCmd := cobra.Command{
		Use: "update-users",
		Run: gotasks.WrapMain(updateUsersMain),
	}
	updateUsersCmd.Flags().Bool("dry-run", false, "Enables dry-run mode")
	updateUsersCmd.PersistentFlags().String("yt-proxy", "hahn", "YT proxy")
	updateUsersCmd.PersistentFlags().String("db", "analytics", "DB name")
	updateUsersCmd.Flags().Bool("simple", false, "Disable hard logic for tags")
	updateUsersCmd.Flags().Float64("blocked-ratio", 0.0001, "Maximum allowed ratio of blocked users to updated users")
	updateUsersCmd.Flags().Float64("price-up-ratio", 0.0001, "Maximum allowed ratio of price-up tag changes to updated users")
	updateUsersCmd.Flags().Int64("min-updated-count", 200, "Minimum number of updated users required before ratio limits are checked")
	updateUsersCmd.Flags().Bool("scaled-score", false, "Enable scaled score")
	updateUsersCmd.Flags().Bool("percentile", false, "Enable percentile")
	updateUsersCmd.Flags().Bool("mileage", false, "Enable mileage")
	AggressionCmd.AddCommand(&updateUsersCmd)
}

// getRobotStateTx loads the state of robot for userID inside tx. When no
// record exists yet (sql.ErrNoRows, possibly wrapped), it creates one with
// Robot and UserID filled in and returns it.
//
// Any other lookup error, or an error from creating the record, is returned
// together with a zero RobotState.
func getRobotStateTx(
	tx gosql.Runner, store *users.RobotStateStore,
	robot string, userID string,
) (users.RobotState, error) {
	state, err := store.GetByRobotUserTx(tx, robot, userID)
	if err == nil {
		return state, nil
	}
	// errors.Is also matches ErrNoRows wrapped by the store layer.
	if !errors.Is(err, sql.ErrNoRows) {
		return users.RobotState{}, err
	}
	state.Robot = robot
	state.UserID = userID
	if err := store.CreateTx(tx, &state); err != nil {
		return users.RobotState{}, err
	}
	return state, nil
}

// blockInfo describes one interval kept in userStateData: a block, a trial
// period or a high-price (price-up) period.
//
// BeginTime is when the interval started. EndTime is when it ended, or for an
// active block when it is due to end; 0 means open-ended (a "forever" block is
// never lifted automatically). Times written by this file come from
// time.Now().Unix(), i.e. Unix seconds. Tag is not set by the code in this
// file.
type blockInfo struct {
	Tag       string `json:"tag,omitempty"`
	BeginTime int64  `json:"begin_time"`
	EndTime   int64  `json:"end_time,omitempty"`
}

// scoredSession is one session kept in userStateData.ScoredSessions for
// mileage accounting (see AddSession and GetTrialMileage). The slice is
// ordered by FinishTime.
// NOTE(review): FinishTime is copied from sessionScoringRow and compared with
// Unix-second BeginTime values; presumably Unix seconds as well — confirm.
type scoredSession struct {
	SessionID  string  `json:"session_id"`
	FinishTime int64   `json:"finish_time"`
	Mileage    float64 `json:"mileage"`
}

// Mileage limits used by userStateData. Values are in the same unit as
// scoredSession.Mileage (NOTE(review): presumably kilometres — confirm).
const (
	// ScoredMileageLimit bounds the total mileage kept in ScoredSessions:
	// AddSession drops the oldest sessions while the remaining ones still sum
	// to more than this value.
	ScoredMileageLimit = 200
	// TrialMileageLimit is the trial-period mileage (see GetTrialMileage) at
	// which the trial-period tag is removed.
	TrialMileageLimit = 100
)

// userStateData is the per-user JSON state stored for the "user_aggression"
// robot (read and written by processUserScoringRow and unblockUsers).
// Timestamps written by this file come from time.Now().Unix(), i.e. Unix
// seconds.
type userStateData struct {
	// HasAggression contains flag that user has aggression.
	// NOTE(review): not set anywhere in this file; presumably written by
	// another job — confirm.
	HasAggression bool `json:"has_aggression"`
	// AggressionTime contains time of last aggression. It is compared with
	// time.Now().Unix() values, so it is expected in Unix seconds.
	AggressionTime int64 `json:"aggression_time,omitempty"`
	// Mileage contains the mileage last sent in the "user_aggression_mileage"
	// tag; nil until such a tag has been sent.
	Mileage *float64 `json:"mileage,omitempty"`
	// TrialMileage contains user mileage in trial period. The code in this
	// file only resets it to 0; GetTrialMileage derives the value from
	// ScoredSessions instead.
	TrialMileage float64 `json:"trial_mileage"`
	// Score contains the last applied user score, ScaledScore the last scaled
	// score (nil when the row had none) and Stage the last aggression status
	// taken from the scoring table.
	Score       float64  `json:"score"`
	ScaledScore *float64 `json:"scaled_score,omitempty"`
	Stage       string   `json:"stage"`
	// ScoreTime contains the timestamp of the last applied scoring row; rows
	// with an older or equal timestamp are skipped.
	ScoreTime int64 `json:"score_time"`
	// Blocks contains all lifted blocks; its length selects the duration of
	// the next block (3, 7, 10 days, then forever).
	Blocks []blockInfo `json:"blocks"`
	// TrialPeriods contains all finished trial periods.
	TrialPeriods []blockInfo `json:"trial_periods"`
	// HighPrices contains all finished high-price (price-up) periods.
	HighPrices []blockInfo `json:"high_prices"`
	// Block contains information about current block; nil when the user is
	// not blocked.
	Block *blockInfo `json:"block,omitempty"`
	// TrialPeriod contains information about the active trial period, if any.
	TrialPeriod *blockInfo `json:"trial_period,omitempty"`
	// HighPrice contains information about the active high-price period, if
	// any.
	HighPrice *blockInfo `json:"high_price,omitempty"`
	// LastSessions and LastTime are not used by the code in this file.
	LastSessions []string `json:"last_sessions"`
	LastTime     int64    `json:"last_time"`
	// ExpGroup contains id of experiment group.
	ExpGroup int `json:"exp_group"` // Deprecated: not read by the code in this file.
	// ScoredSessions contains recent sessions ordered by FinishTime, trimmed
	// to about ScoredMileageLimit of mileage (see AddSession).
	ScoredSessions []scoredSession `json:"scored_sessions"`
}

// AddSession inserts the session described by row into ScoredSessions, keeping
// the slice ordered by FinishTime, and reports whether the new session is
// still present afterwards.
//
// It returns false without modifying the state when the SessionID is already
// present among the sessions finishing no earlier than row.FinishTime (older
// sessions are not fully searched). After insertion the oldest sessions are
// dropped while the rest still sum to more than ScoredMileageLimit; at least
// one session is always kept. If the new session itself is dropped, the
// result is false.
func (s *userStateData) AddSession(row sessionScoringRow) bool {
	entry := scoredSession{
		SessionID:  row.SessionID,
		Mileage:    row.Mileage,
		FinishTime: row.FinishTime,
	}
	sessions := s.ScoredSessions
	// Walk back from the tail past every session that finished strictly later;
	// idx ends up as the insertion index.
	idx := len(sessions)
	for idx > 0 {
		prev := sessions[idx-1]
		if prev.SessionID == entry.SessionID {
			return false
		}
		if entry.FinishTime >= prev.FinishTime {
			break
		}
		idx--
	}
	// Sessions with the same FinishTime may sit before idx; check them too.
	for j := idx - 1; j >= 0 && sessions[j].FinishTime >= entry.FinishTime; j-- {
		if sessions[j].SessionID == entry.SessionID {
			return false
		}
	}
	// Grow by one, shift the tail right and place the new entry at idx.
	sessions = append(sessions, scoredSession{})
	copy(sessions[idx+1:], sessions[idx:])
	sessions[idx] = entry

	total := 0.0
	for _, sess := range sessions {
		total += sess.Mileage
	}
	kept := true
	for len(sessions) > 1 && total-sessions[0].Mileage > ScoredMileageLimit {
		if sessions[0].SessionID == entry.SessionID {
			kept = false
		}
		total -= sessions[0].Mileage
		sessions = sessions[1:]
	}
	s.ScoredSessions = sessions
	return kept
}

// GetTrialMileage returns the total mileage of the scored sessions that
// finished at or after the start of the active trial period, or 0 when no
// trial period is active.
func (s *userStateData) GetTrialMileage() float64 {
	total := 0.0
	if trial := s.TrialPeriod; trial != nil {
		for i := range s.ScoredSessions {
			if s.ScoredSessions[i].FinishTime >= trial.BeginTime {
				total += s.ScoredSessions[i].Mileage
			}
		}
	}
	return total
}

// scoringDetails is the decoded "aggregated_profile" column of the user scoring
// table. When it is present and valid, it is attached as "details" to the
// scoring tags. Fields are pointers, so missing values stay nil and are
// encoded as null in JSON.
// NOTE(review): the fields look like per-category counts of driving events
// (cornering, braking, speeding, ...); their exact semantics are not visible
// here — confirm with the table producer.
type scoringDetails struct {
	LateralAccelerationCornering *int64 `yson:"lateral_acceleration_cornering" json:"lateral_acceleration_cornering"`
	LateralAccelerationStraight  *int64 `yson:"lateral_acceleration_straight" json:"lateral_acceleration_straight"`
	HighEngineSpeed              *int64 `yson:"high_engine_speed" json:"high_engine_speed"`
	HarshBraking                 *int64 `yson:"harsh_braking" json:"harsh_braking"`
	SharpAcceleration            *int64 `yson:"sharp_acceleration" json:"sharp_acceleration"`
	SectionsWithSpeeding         *int64 `yson:"sections_with_speeding" json:"sections_with_speeding"`
}

// unsafe represents a value that may fail to decode.
//
// Its UnmarshalYSON stores a decoding error in Err instead of returning it,
// so one malformed column does not fail the scan of the whole row. Callers
// must check Err before using Value.
type unsafe[T any] struct {
	Value T
	Err   error
}

// UnmarshalYSON decodes data into v.Value and records any decoding error in
// v.Err. It always returns nil so that decoding of the enclosing row continues.
func (v *unsafe[T]) UnmarshalYSON(data []byte) error {
	err := yson.Unmarshal(data, &v.Value)
	v.Err = err
	return nil
}

// Compile-time check that unsafe implements yson.Unmarshaler.
var _ yson.Unmarshaler = (*unsafe[float64])(nil)

// userScoringRow is one row of the user scoring YT table.
//
// Time is the row timestamp. Rows are read from ypath.Key(LastTime) onwards,
// and a row is applied only if Time is newer than the user's stored ScoreTime.
// Threshold is the price-up threshold for Score and BanThreshold the block
// threshold. Status is stored as userStateData.Stage. With ScaledScore mode
// enabled, rows without ScaledScore are skipped and InvertedScale must not be
// false. ScaledNormalThreshold, ScaledAggressionThreshold and
// ScaledBanThreshold are not read by the code in this file.
type userScoringRow struct {
	Time                      int64                   `yson:"timestamp"`
	UserID                    uuid.UUID               `yson:"user_id"`
	Score                     float64                 `yson:"score"`
	Rank                      *float64                `yson:"rank"`
	Threshold                 float64                 `yson:"aggression_threshold"`
	BanThreshold              float64                 `yson:"ban_threshold"`
	ScaledScore               *float64                `yson:"scaled_score"`
	ScaledNormalThreshold     *float64                `yson:"scaled_normal_threshold"`
	ScaledAggressionThreshold *float64                `yson:"scaled_aggression_threshold"`
	ScaledBanThreshold        *float64                `yson:"scaled_ban_threshold"`
	InvertedScale             *bool                   `yson:"inverted_scale"`
	Status                    string                  `yson:"aggression_status"`
	Mileage                   *float64                `yson:"mileage"`
	Percentile                *int64                  `yson:"percentile"`
	AggregatedProfile         unsafe[*scoringDetails] `yson:"aggregated_profile"`
}

// hashUserID maps userID to a stable bucket in [0, 100) using SipHash with an
// all-zero 16-byte key.
func hashUserID(userID string) int {
	h := siphash.New(make([]byte, 16))
	_, _ = h.Write([]byte(userID)) // hash.Hash.Write never returns an error.
	bucket := h.Sum64() % 100
	return int(bucket)
}

// ptr returns a pointer to a fresh copy of x.
func ptr[T any](x T) *T {
	p := new(T)
	*p = x
	return p
}

// processUserScoringRowRetry calls processUserScoringRow up to three times,
// logging every failure and waiting 1s, then 2s, before the retries.
//
// It returns nil on the first success, the context error if the context is
// cancelled while waiting, or the last processing error once all attempts
// have failed. No time is spent waiting after the final attempt.
func (m *userScoringImpl) processUserScoringRowRetry(
	row userScoringRow,
	stats *userScoringStats,
) error {
	const attempts = 3
	var err error
	for i := 0; i < attempts; i++ {
		if i > 0 {
			// Back off only before a retry, so a row that keeps failing does not
			// hold a worker for an extra pointless sleep.
			select {
			case <-m.Context.Context.Done():
				return m.Context.Context.Err()
			case <-time.After(time.Duration(i) * time.Second):
			}
		}
		if err = m.processUserScoringRow(row, stats); err == nil {
			return nil
		}
		m.Context.Logger.Warn(
			"Unable to process user scoring row",
			log.Any("row", row),
			log.Error(err),
		)
	}
	return err
}

// processUserScoringRow applies one scoring row to the stored state of the
// user and to the user's Drive tags, inside a single DB transaction.
//
// The "user_aggression" state of the user is loaded for update and created
// when missing. Rows not newer than the stored ScoreTime are ignored, as are
// rows without a scaled score when ScaledScore mode is enabled. Otherwise the
// scoring tags are written. Unless Simple is set, the price-up, block and
// trial-period tags are also maintained. Tag writes and state updates are
// skipped when DryRun is set.
//
// Counters in stats are updated atomically, so several workers may call this
// concurrently. They are incremented on every attempt, so a retried row may
// be counted more than once.
//
// NOTE(review): Drive tag calls happen before the transaction commits. If a
// later step fails, the DB changes are presumably rolled back by
// gosql.WithTxContext, but tags already sent stay in place.
func (m *userScoringImpl) processUserScoringRow(
	row userScoringRow,
	stats *userScoringStats,
) error {
	return gosql.WithTxContext(
		m.Context.Context, m.DB, nil, func(tx *sql.Tx) error {
			// Presumably SELECT ... FOR UPDATE, serializing concurrent updates of
			// the same user.
			userRawState, err := m.States.GetForUpdateByRobotUserTx(
				tx, "user_aggression", row.UserID.String(),
			)
			if err != nil {
				if !errors.Is(err, sql.ErrNoRows) {
					return err
				}
				userRawState.UserID = row.UserID.String()
				userRawState.Robot = "user_aggression"
				if err := m.States.CreateTx(tx, &userRawState); err != nil {
					m.Context.Signal("aggression.users.create_user_state.error_sum", nil).Add(1)
					return err
				}
			}
			userState := userStateData{}
			if err := userRawState.ScanState(&userState); err != nil {
				return err
			}
			// Rows are re-read from the reader checkpoint, so skip the ones that
			// were already applied to this user.
			if row.Time <= userState.ScoreTime {
				return nil
			}
			if m.ScaledScore && row.ScaledScore == nil {
				return nil
			}
			if m.ScaledScore && row.InvertedScale != nil && !*row.InvertedScale {
				// Invariant check: a non-inverted scale is not supported in
				// ScaledScore mode. NOTE(review): this panics inside a worker
				// goroutine and terminates the whole process.
				panic("InvertedScale == false")
			}
			scoreData := api.GenericTagData{
				"value":     row.Score,
				"timestamp": row.Time,
			}
			scaledScoreData := api.GenericTagData{
				"value":     row.ScaledScore,
				"timestamp": row.Time,
			}
			if row.AggregatedProfile.Err != nil {
				// A broken profile is not fatal: tags are sent without details.
				m.Context.Logger.Warn(
					"Invalid aggregated profile",
					log.String("user_id", row.UserID.String()),
					log.Error(row.AggregatedProfile.Err),
				)
			} else if row.AggregatedProfile.Value != nil {
				scoreData["details"] = *row.AggregatedProfile.Value
				scaledScoreData["details"] = *row.AggregatedProfile.Value
			}
			if row.Rank != nil {
				scoreData["rank"] = row.Rank
			}
			if userState.Score > 0 {
				scoreData["previous_value"] = userState.Score
			}
			if userState.ScaledScore != nil {
				scaledScoreData["previous_value"] = *userState.ScaledScore
			}
			now := time.Now().Unix()
			userState.Score = row.Score
			userState.ScaledScore = row.ScaledScore
			userState.Stage = row.Status
			userState.ScoreTime = row.Time
			scoreTag := api.UserTag{
				Tag:    "scoring_user_tag",
				Data:   &scoreData,
				UserID: row.UserID.String(),
			}
			scaledScoreTag := api.UserTag{
				Tag:    "user_aggression_scoring",
				Data:   &scaledScoreData,
				UserID: row.UserID.String(),
			}
			percentileTag := api.UserTag{
				Tag: "user_scoring_insurance_tag",
				Data: &api.GenericTagData{
					"value":     row.Percentile,
					"timestamp": row.Time,
				},
				UserID: row.UserID.String(),
			}
			m.Context.Logger.Debug(
				"Updating user tag",
				log.Any("tag", scoreTag),
				log.Any("scaled_tag", scaledScoreTag),
				log.Any("percentile_tag", percentileTag),
				log.Float64("threshold", row.Threshold),
			)
			atomic.AddInt64(&stats.UpdatedCount, 1)
			if !m.DryRun {
				if err := m.Context.Drive.AddUserTag(scoreTag); err != nil {
					m.Context.Signal("aggression.users.update_scoring_user_tag.error_sum", nil).Add(1)
					return err
				}
				m.Context.Signal("aggression.users.update_scoring_user_tag.ok_sum", nil).Add(1)
			}
			if !m.DryRun && m.ScaledScore {
				if err := m.Context.Drive.AddUserTag(scaledScoreTag); err != nil {
					m.Context.Signal("aggression.users.update_user_aggression_scoring.error_sum", nil).Add(1)
					return err
				}
				m.Context.Signal("aggression.users.update_user_aggression_scoring.ok_sum", nil).Add(1)
			}
			if !m.DryRun && m.Percentile && row.Percentile != nil {
				if err := m.Context.Drive.AddUserTag(percentileTag); err != nil {
					m.Context.Signal("aggression.users.update_user_scoring_insurance_tag.error_sum", nil).Add(1)
					return err
				}
				m.Context.Signal("aggression.users.update_user_scoring_insurance_tag.ok_sum", nil).Add(1)
			}
			// The mileage tag is only resent when the value actually changed.
			if !m.DryRun && m.Mileage && row.Mileage != nil && (userState.Mileage == nil || *userState.Mileage != *row.Mileage) {
				mileageTag := api.UserTag{
					Tag: "user_aggression_mileage",
					Data: &api.GenericTagData{
						"value":     *row.Mileage,
						"timestamp": row.Time,
					},
					UserID: row.UserID.String(),
				}
				if err := m.Context.Drive.AddUserTag(mileageTag); err != nil {
					m.Context.Signal("aggression.users.update_user_aggression_mileage.error_sum", nil).Add(1)
					return err
				}
				m.Context.Signal("aggression.users.update_user_aggression_mileage.ok_sum", nil).Add(1)
				userState.Mileage = ptr(*row.Mileage)
			}
			if m.Simple {
				// Simple mode: persist the score only, no price-up/block logic.
				if err := userRawState.SetState(userState); err != nil {
					return err
				}
				if !m.DryRun {
					if err := m.States.UpdateTx(tx, userRawState); err != nil {
						m.Context.Signal("aggression.users.update_user_state.error_sum", nil).Add(1)
						return err
					}
					m.Context.Signal("aggression.users.update_user_state.ok_sum", nil).Add(1)
				}
				return nil
			}
			hasTrialPeriod := false
			{
				tags, err := m.Context.Drive.GetUserTags(row.UserID.String())
				if err != nil {
					return err
				}
				for _, tag := range tags {
					if tag.Tag == "user_aggressive_trial_period" {
						hasTrialPeriod = true
						break
					}
				}
			}
			// Lift the price-up once the score drops below the threshold.
			if userState.HighPrice != nil && row.Score < row.Threshold {
				tag := "user_aggressive_price_up"
				m.Context.Logger.Info(
					"Remove user tag",
					log.String("user_id", row.UserID.String()),
					log.String("tag", tag),
					log.Float64("threshold", row.Threshold),
				)
				if !m.DryRun {
					if err := m.Context.Drive.RemoveUserTags(
						row.UserID.String(), tag,
					); err != nil {
						m.Context.Signal("aggression.users.remove_user_aggressive_price_up.error_sum", nil).Add(1)
						return err
					}
					m.Context.Signal("aggression.users.remove_user_aggressive_price_up.ok_sum", nil).Add(1)
				}
				userState.HighPrice.EndTime = now
				userState.HighPrices = append(
					userState.HighPrices, *userState.HighPrice,
				)
				userState.HighPrice = nil
			}
			// An expired timed block turns into a trial period.
			if userState.Block != nil && userState.Block.EndTime != 0 && now >= userState.Block.EndTime {
				if err := unblockUser(m.Context, row.UserID.String(), &userState, now, m.DryRun); err != nil {
					return err
				}
			}
			// End the trial once enough mileage was driven, or drop a trial tag
			// that the stored state does not know about.
			if (userState.TrialPeriod != nil && userState.GetTrialMileage() >= TrialMileageLimit) || (hasTrialPeriod && userState.TrialPeriod == nil) {
				m.Context.Logger.Info(
					"Remove user trial tag",
					log.String("user_id", row.UserID.String()),
					log.Float64("threshold", row.BanThreshold),
				)
				if !m.DryRun {
					if err := m.Context.Drive.RemoveUserTags(
						row.UserID.String(), "user_aggressive_trial_period",
					); err != nil {
						m.Context.Signal("aggression.users.remove_user_aggressive_trial_period.error_sum", nil).Add(1)
						return err
					}
					m.Context.Signal("aggression.users.remove_user_aggressive_trial_period.ok_sum", nil).Add(1)
				}
				if userState.TrialPeriod != nil {
					userState.TrialPeriod.EndTime = now
					userState.TrialPeriods = append(
						userState.TrialPeriods, *userState.TrialPeriod,
					)
					userState.TrialPeriod = nil
					userState.TrialMileage = 0
				}
			}
			// Price-up: score above the threshold and an aggression during the
			// last 7 days (and not before startBanTime).
			if userState.HighPrice == nil && row.Score >= row.Threshold &&
				userState.HasAggression &&
				userState.AggressionTime >= now-60*60*24*7 &&
				userState.AggressionTime >= startBanTime &&
				now >= startBanTime {
				tag := api.UserTag{
					Tag: "user_aggressive_price_up",
					Data: &api.GenericTagData{
						"value":     row.Score,
						"timestamp": row.Time,
					},
					UserID: row.UserID.String(),
				}
				m.Context.Logger.Info(
					"Adding user tag",
					log.Any("tag", tag),
					log.Float64("threshold", row.Threshold),
				)
				// Count price-up additions (like BlockedCount counts new blocks),
				// so the --price-up-ratio guard limits mass price increases.
				atomic.AddInt64(&stats.PriceUpCount, 1)
				if !m.DryRun {
					if err := m.Context.Drive.AddUserTag(tag); err != nil {
						m.Context.Signal("aggression.users.add_user_aggressive_price_up.error_sum", nil).Add(1)
						return err
					}
					m.Context.Signal("aggression.users.add_user_aggressive_price_up.ok_sum", nil).Add(1)
				}
				userState.HighPrice = &blockInfo{BeginTime: now}
			}
			// Block: score above the ban threshold, not blocked and not in a
			// trial period, with a recent aggression.
			if row.Score >= row.BanThreshold && userState.Block == nil &&
				userState.TrialPeriod == nil && userState.HasAggression &&
				userState.AggressionTime >= now-60*60*24*7 &&
				userState.AggressionTime >= startBanTime &&
				now >= startBanTime {
				userState.Block = &blockInfo{BeginTime: now}
				userState.TrialMileage = 0
				// Escalation: 3, 7, 10 days, then forever (EndTime stays 0).
				days := 0
				tagName := "blocked_user_aggressive_forever"
				switch len(userState.Blocks) {
				case 0:
					days = 3
				case 1:
					days = 7
				case 2:
					days = 10
				}
				if days > 0 {
					userState.Block.EndTime = userState.Block.BeginTime + int64(days)*24*60*60
					tagName = fmt.Sprintf("blocked_user_aggressive_%ddays", days)
				}
				tag := api.UserTag{
					Tag: tagName,
					Data: &api.GenericTagData{
						"comment": fmt.Sprintf(
							"Превышение порога: %f >= %f",
							row.Score, row.BanThreshold,
						),
					},
					UserID: row.UserID.String(),
				}
				m.Context.Logger.Info(
					"Adding user tag",
					log.Any("tag", tag),
					log.Float64("threshold", row.BanThreshold),
				)
				atomic.AddInt64(&stats.BlockedCount, 1)
				if !m.DryRun {
					if err := m.Context.Drive.AddUserTag(tag); err != nil {
						m.Context.Signal(fmt.Sprintf("aggression.users.add_%s.error_sum", tagName), nil).Add(1)
						return err
					}
					m.Context.Signal(fmt.Sprintf("aggression.users.add_%s.ok_sum", tagName), nil).Add(1)
				}
			}
			if err := userRawState.SetState(userState); err != nil {
				return err
			}
			if !m.DryRun {
				if err := m.States.UpdateTx(tx, userRawState); err != nil {
					m.Context.Signal("aggression.users.update_user_state.error_sum", nil).Add(1)
					return err
				}
				m.Context.Signal("aggression.users.update_user_state.ok_sum", nil).Add(1)
			}
			return nil
		},
	)
}

// processUserScoringRows reads the scoring table from the reader checkpoint
// state.LastTime onwards and applies every row with a pool of 32 worker
// goroutines (each row through processUserScoringRowRetry).
//
// It returns the timestamp the next run should start from. That is the
// largest row Time read, lowered to the smallest Time of any row whose
// processing failed, so that row is read again next time. Failed rows are
// only logged and do not make this function return an error.
//
// An error is returned when:
//   - the table cannot be read,
//   - a row cannot be decoded,
//   - a row older than state.LastTime is met, or
//   - the context is cancelled.
func (m *userScoringImpl) processUserScoringRows(
	state readerStateData,
	table ypath.Path,
	stats *userScoringStats,
) (lastTime int64, err error) {
	m.Context.Logger.Info(
		"Reading users",
		log.Int64("last_time", state.LastTime),
	)
	// timesCount[t] is the number of rows with Time t that were dispatched but
	// not yet processed successfully; guarded by timesMutex.
	var timesMutex sync.Mutex
	timesCount := map[int64]int{}
	lastTime = state.LastTime
	// Deferred calls run in reverse order: in.Close, close(rows),
	// waiter.Wait, and only then this adjustment. So all workers have finished
	// and timesCount can be read without the mutex.
	// NOTE(review): the variable `time` shadows the time package here.
	defer func() {
		for time, count := range timesCount {
			if count > 0 && time < lastTime {
				lastTime = time
			}
		}
	}()
	var waiter sync.WaitGroup
	defer waiter.Wait()
	rows := make(chan userScoringRow)
	// Closing rows ends the workers' range loops.
	defer close(rows)
	for i := 0; i < 32; i++ {
		waiter.Add(1)
		go func() {
			defer waiter.Done()
			for row := range rows {
				if err := m.processUserScoringRowRetry(row, stats); err != nil {
					// The row's counter stays positive, which pulls lastTime
					// back to this row's Time in the deferred adjustment.
					m.Context.Logger.Error(
						"Unable to process user scoring row",
						log.Any("row", row),
						log.Error(err),
					)
				} else {
					func() {
						timesMutex.Lock()
						defer timesMutex.Unlock()
						timesCount[row.Time]--
					}()
				}
			}
		}()
	}
	// The start key is inclusive, so rows at exactly state.LastTime are read
	// again; processUserScoringRow skips already-applied rows via ScoreTime.
	in, err := m.YT.ReadTable(
		m.Context.Context,
		table.Rich().AddRange(
			ypath.StartingFrom(ypath.Key(state.LastTime)),
		),
		nil,
	)
	if err != nil {
		return 0, err
	}
	defer func() {
		_ = in.Close()
	}()
	for in.Next() {
		select {
		case <-m.Context.Context.Done():
			err = m.Context.Context.Err()
			return
		default:
		}
		var row userScoringRow
		if err = in.Scan(&row); err != nil {
			return
		}
		if row.Time < state.LastTime {
			err = fmt.Errorf("invalid update time of record")
			return
		}
		if row.Time > lastTime {
			lastTime = row.Time
		}
		// Count the row before handing it to a worker, so the worker's
		// decrement can never happen first.
		func() {
			timesMutex.Lock()
			defer timesMutex.Unlock()
			timesCount[row.Time]++
		}()
		select {
		case rows <- row:
		case <-m.Context.Context.Done():
			err = m.Context.Context.Err()
			return
		}
	}
	err = in.Err()
	return
}

// unblockUser lifts the current block of the user and starts a trial period.
//
// It does nothing when userState has no active block. Otherwise, unless
// dryRun is set, it removes the timed block tags and adds the
// "user_aggressive_trial_period" tag. It then moves the block into
// userState.Blocks and opens a trial period starting at now. An error from
// Drive is returned before userState is modified.
func unblockUser(
	ctx *gotasks.Context, userID string, userState *userStateData, now int64, dryRun bool,
) error {
	block := userState.Block
	if block == nil {
		return nil
	}
	ctx.Logger.Info("Remove user block tags", log.String("user_id", userID))
	if !dryRun {
		blockTags := []string{
			"blocked_user_aggressive_3days",
			"blocked_user_aggressive_7days",
			"blocked_user_aggressive_10days",
		}
		if err := ctx.Drive.RemoveUserTagList(userID, blockTags); err != nil {
			ctx.Signal("aggression.users.remove_blocked_user_aggressive.error_sum", nil).Add(1)
			return err
		}
		ctx.Signal("aggression.users.remove_blocked_user_aggressive.ok_sum", nil).Add(1)
	}
	trialTag := api.UserTag{
		Tag:    "user_aggressive_trial_period",
		Data:   &api.GenericTagData{"value": 0},
		UserID: userID,
	}
	ctx.Logger.Info("Adding user trial tag", log.Any("tag", trialTag))
	if !dryRun {
		if err := ctx.Drive.AddUserTag(trialTag); err != nil {
			ctx.Signal("aggression.users.add_user_aggressive_trial_period.error_sum", nil).Add(1)
			return err
		}
		ctx.Signal("aggression.users.add_user_aggressive_trial_period.ok_sum", nil).Add(1)
	}
	userState.Blocks = append(userState.Blocks, *block)
	userState.Block = nil
	userState.TrialPeriod = &blockInfo{BeginTime: now}
	userState.TrialMileage = 0
	return nil
}

// unblockUsers lifts every aggression block whose end time has passed.
//
// It selects the "user_aggression" robot states whose
// state->'block'->'end_time' is not later than now. For each one it calls
// unblockUser (remove the block tags, add the trial-period tag) and writes the
// updated state back unless dryRun is set. The first error stops the
// iteration and is returned. yc is currently unused.
func unblockUsers(
	ctx *gotasks.Context, db *gosql.DB, yc yt.Client,
	store *users.RobotStateStore, dryRun bool,
) error {
	ctx.Logger.Debug("Unblock users")
	now := time.Now().Unix()
	var row users.RobotState
	// values are the Scan destinations for row's columns; each Scan refills row.
	names, values := gosql.StructNameValues(&row, true)
	var columns strings.Builder
	for i, name := range names {
		if i > 0 {
			columns.WriteString(", ")
		}
		// %q produces a double-quoted (PostgreSQL-style) identifier.
		fmt.Fprintf(&columns, "%q", name)
	}
	rows, err := db.Query(
		fmt.Sprintf(
			`SELECT %s FROM "user_robot_state"`+
				` WHERE robot = 'user_aggression'`+
				` AND state->'block'->'end_time' <= $1`,
			columns.String(),
		),
		now,
	)
	if err != nil {
		return err
	}
	// Without this, any early return below would leak the result set and the
	// connection it holds.
	defer func() {
		_ = rows.Close()
	}()
	ctx.Logger.Debug("Iterating blocked users")
	for rows.Next() {
		if err := rows.Scan(values...); err != nil {
			return err
		}
		userRawState := row
		userState := userStateData{}
		if err := userRawState.ScanState(&userState); err != nil {
			return err
		}
		// Re-check in Go: the SQL filter alone does not exclude EndTime == 0.
		if userState.Block != nil && userState.Block.EndTime != 0 && now >= userState.Block.EndTime {
			if err := unblockUser(ctx, row.UserID, &userState, now, dryRun); err != nil {
				return err
			}
		}
		if err := userRawState.SetState(userState); err != nil {
			return err
		}
		if !dryRun {
			if err := store.UpdateTx(db, userRawState); err != nil {
				ctx.Signal("aggression.users.update_user_state.error_sum", nil).Add(1)
				return err
			}
			ctx.Signal("aggression.users.update_user_state.ok_sum", nil).Add(1)
		}
	}
	return rows.Err()
}

// userScoringImpl holds the dependencies and mode switches used to apply the
// user scoring table.
//
// DryRun skips the tag writes and state saves guarded by it. Simple limits
// processing to the scoring tags and the stored score. ScaledScore,
// Percentile and Mileage enable the corresponding extra tags. State is the
// reader checkpoint of the scoring table.
// NOTE(review): the Drive field is filled by updateUsersMain, but the code in
// this file calls Context.Drive instead.
type userScoringImpl struct {
	Context     *gotasks.Context
	DB          *gosql.DB
	YT          yt.Client
	Drive       *api.Client
	States      *users.RobotStateStore
	State       exports.TableState
	DryRun      bool
	Simple      bool
	ScaledScore bool
	Percentile  bool
	Mileage     bool
}

// userScoringStats collects counters while a table is processed.
//
// Workers update the fields with sync/atomic, so read them only after
// processing has finished. updateUsersMain compares BlockedCount and
// PriceUpCount, relative to UpdatedCount, with the --blocked-ratio and
// --price-up-ratio limits.
type userScoringStats struct {
	UpdatedCount int64
	BlockedCount int64
	PriceUpCount int64
}

func (m *userScoringImpl) ProcessTable(path ypath.Path) (userScoringStats, error) {
	state := readerStateData{}
	m.Context.Logger.Debug("Scan reader state")
	if err := m.State.ScanState(&state); err != nil {
		return userScoringStats{}, err
	}
	stats := userScoringStats{}
	lastTime, err := m.processUserScoringRows(state, path, &stats)
	if err != nil {
		return userScoringStats{}, err
	}
	m.Context.Logger.Info(
		"Update reader state",
		log.Int64("last_time", lastTime),
		log.Bool("dryrun", m.DryRun),
	)
	if lastTime < state.LastTime {
		return userScoringStats{}, fmt.Errorf("last time is too small: %d < %d", lastTime, state.LastTime)
	}
	state.LastTime = lastTime
	if !m.DryRun {
		if err := m.State.SaveState(state); err != nil {
			m.Context.Signal("aggression.users.update_state.error_sum", nil).Add(1)
			return userScoringStats{}, err
		}
		m.Context.Signal("aggression.users.update_lag_last", nil).Set(float64(time.Now().Unix() - lastTime))
		m.Context.Signal("aggression.users.update_state.ok_sum", nil).Add(1)
	}
	return stats, nil
}

// must returns val, or panics with err when it is non-nil. It is meant for
// calls that cannot fail in practice, such as reading flags registered in
// init.
func must[T any](val T, err error) T {
	if err == nil {
		return val
	}
	panic(err)
}

// updateUsersMain is the entry point of the "update-users" command.
//
// Unless --simple is set, it:
//   - unblocks users whose timed block has expired,
//   - processes the scoring table once with DryRun forced to true to collect
//     statistics, and
//   - aborts, raising the ratio alert signals, when the blocked or price-up
//     share of updated rows exceeds --blocked-ratio / --price-up-ratio (only
//     checked once at least --min-updated-count rows were updated).
//
// Finally it processes the table again, honoring --dry-run.
func updateUsersMain(ctx *gotasks.Context) error {
	flags := ctx.Cmd.Flags()
	dryRun := must(flags.GetBool("dry-run"))
	dbName := must(flags.GetString("db"))
	simple := must(flags.GetBool("simple"))
	blockedRatio := must(flags.GetFloat64("blocked-ratio"))
	priceUpRatio := must(flags.GetFloat64("price-up-ratio"))
	minUpdatedCount := must(flags.GetInt64("min-updated-count"))
	scaledScore := must(flags.GetBool("scaled-score"))
	percentile := must(flags.GetBool("percentile"))
	mileage := must(flags.GetBool("mileage"))
	yc, err := ctx.GetYT()
	if err != nil {
		return err
	}
	db, ok := ctx.DBs[dbName]
	if !ok {
		return fmt.Errorf("invalid DB name %q", dbName)
	}
	drive := ctx.Drive
	if drive == nil {
		return fmt.Errorf("drive API required")
	}
	stateStore := users.NewRobotStateStore(db, "user_robot_state")
	if !simple {
		if err := unblockUsers(ctx, db, yc, stateStore, dryRun); err != nil {
			return err
		}
	}
	rawReaderState, err := exports.GetTableState(ctx.States, "aggression/users/reader")
	if err != nil {
		return err
	}
	table := ctx.Config.YTPaths.AggressionUserScoreTable
	// DryRun starts as true: the first ProcessTable call only collects
	// statistics and does not save the reader checkpoint. The flag value
	// applies to the final pass.
	impl := userScoringImpl{
		Context:     ctx,
		DB:          db,
		YT:          yc,
		Drive:       drive,
		States:      stateStore,
		State:       rawReaderState,
		DryRun:      true,
		Simple:      simple,
		ScaledScore: scaledScore,
		Percentile:  percentile,
		Mileage:     mileage,
	}
	if !simple {
		stats, err := impl.ProcessTable(table)
		if err != nil {
			return err
		}
		ctx.Logger.Info(
			"Process table stats",
			log.Int64("price_up_count", stats.PriceUpCount),
			log.Int64("blocked_count", stats.BlockedCount),
			log.Int64("updated_count", stats.UpdatedCount),
		)
		if stats.UpdatedCount >= minUpdatedCount {
			updated := float64(stats.UpdatedCount)
			if ratio := float64(stats.BlockedCount) / updated; ratio > blockedRatio {
				ctx.Signal("aggression.users.blocked_count_ratio_alert_last", nil).Set(1)
				return fmt.Errorf("block ratio %g greater than allowed %g", ratio, blockedRatio)
			}
			ctx.Signal("aggression.users.blocked_count_ratio_alert_last", nil).Set(0)
			if ratio := float64(stats.PriceUpCount) / updated; ratio > priceUpRatio {
				ctx.Signal("aggression.users.price_up_count_ratio_alert_last", nil).Set(1)
				return fmt.Errorf("price up ratio %g greater than allowed %g", ratio, priceUpRatio)
			}
			ctx.Signal("aggression.users.price_up_count_ratio_alert_last", nil).Set(0)
		} else {
			// Too few updates for a meaningful ratio: clear both alerts.
			ctx.Signal("aggression.users.blocked_count_ratio_alert_last", nil).Set(0)
			ctx.Signal("aggression.users.price_up_count_ratio_alert_last", nil).Set(0)
		}
	}
	impl.DryRun = dryRun
	_, err = impl.ProcessTable(table)
	return err
}
