package aggression

import (
	"context"
	"database/sql"
	"fmt"
	"sync"
	"time"

	"github.com/gofrs/uuid"
	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/analytics/gotasks/models/users"
	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/yt/go/ypath"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/zootopia/analytics/drive/api"
	"a.yandex-team.ru/zootopia/library/go/geom"
)

// init registers the "update-sessions" subcommand, together with all of its
// flags, on AggressionCmd.
func init() {
	cmd := &cobra.Command{
		Use: "update-sessions",
		Run: gotasks.WrapMain(updateSessionsMain),
	}
	flags := cmd.Flags()
	flags.Bool("dry-run", false, "Enables dry-run mode")
	flags.Bool("shadow", false, "Enables shadow mode")
	flags.String("yt-proxy", "hahn", "YT proxy")
	flags.String("db", "analytics", "DB name")
	flags.String("drive", "", "Name of Drive client")
	flags.Int("workers", 16, "Amount of parallel workers")
	flags.Bool("no-users", false, "Disable updates of users")
	flags.Duration("skip-after", 0, "Skip finished sessions before specified time")
	AggressionCmd.AddCommand(cmd)
}

// readerStateData is the progress marker of the sessions reader, stored in
// the robot state under the JSON key "last_time".
//
// LastTime is the key the scoring table is read from on the next run and the
// lower bound every read row "timestamp" is validated against.
type readerStateData struct {
	LastTime int64 `json:"last_time"`
}

// sessionScoringEvent is one entry of the "events" column of a scoring row.
//
// Kind, Position, Score and Time are decoded from the YSON fields "kind",
// "position", "score" and "ts". Time is kept as an untyped value and is
// forwarded unchanged as the "timestamp" of the event inside the
// "scoring_trace_tag" session tag.
type sessionScoringEvent struct {
	Kind     string      `yson:"kind"`
	Position geom.Vec2   `yson:"position"`
	Score    float64     `yson:"score"`
	Time     interface{} `yson:"ts"`
}

// sessionScoringRow is a single row of the session scoring YT table.
//
// Time ("timestamp") is the value the reader uses to track its progress: it
// is compared with the reader and user LastTime. StartTime and FinishTime are
// compared against time.Now().Unix(), i.e. they are treated as Unix seconds.
// Score and Threshold are passed to compareScore to decide whether the
// session is aggressive. Mileage is accumulated into the trial mileage of
// the user.
type sessionScoringRow struct {
	Time       int64   `yson:"timestamp"`
	SessionID  string  `yson:"session_id"`
	UserID     string  `yson:"user_id"`
	StartTime  int64   `yson:"start_timestamp"`
	FinishTime int64   `yson:"finish_timestamp"`
	Score      float64 `yson:"score"`
	Threshold  float64 `yson:"threshold"`
	Mileage    float64 `yson:"mileage"`
	// Events contains list of aggressive events.
	Events []sessionScoringEvent `yson:"events"`
}

// Fixed reference moments expressed as Unix timestamps in seconds.
// NOTE(review): judging by the names these mark the start of the experiment
// and the start of bans; neither is used in this part of the file — confirm.
var (
	// startExpTime is 2021-05-04 03:00:00 UTC.
	startExpTime = time.Date(2021, time.May, 4, 3, 0, 0, 0, time.UTC).Unix()
	// startBanTime is 2021-05-11 03:00:00 UTC.
	startBanTime = time.Date(2021, time.May, 11, 3, 0, 0, 0, time.UTC).Unix()
)

// compareScore reports whether score reaches threshold.
//
// A score strictly above the threshold always passes. A score equal to the
// threshold passes only when the threshold is non-zero, so a zero score
// never passes a zero threshold.
func compareScore(score, threshold float64) bool {
	if score == threshold {
		return threshold != 0
	}
	return score > threshold
}

// eventKinds maps the kind of a scoring event to the value written under
// "kind" in the session tag built by processSessionScoringRow. Events whose
// kind is missing from this map are sent without a "kind" field.
var eventKinds = map[string]string{
	"acceleration":                  "acceleration",
	"braking":                       "braking",
	"straight_lateral_acceleration": "straight_lateral_acceleration",
	"turning_lateral_acceleration":  "turning_lateral_acceleration",
}

// processSessionScoringRow applies one scoring row to the "user_aggression"
// robot state of row.UserID and sends the resulting tags to Drive.
//
// When skipAfter is positive, rows whose FinishTime is earlier than
// now-skipAfter are skipped. Everything else runs inside a single DB
// transaction bound to a context with a 30 second timeout:
//
//   - the user state is loaded for update; when it does not exist it is
//     created, unless row.Score is below row.Threshold, in which case the row
//     is ignored;
//   - rows older than the stored LastTime and sessions already listed in
//     LastSessions for the same time are ignored;
//   - unless noUsers is set, row.Mileage is added to the trial mileage and,
//     for users with a trial period, the "user_aggressive_trial_period" user
//     tag is sent with the mileage capped at TrialMileageLimit;
//   - when compareScore accepts the row and it has events, the
//     "scoring_trace_tag" session tag is sent and, unless noUsers is set, the
//     "user_aggressive_first_aggression" user tag is sent for a user without
//     previous aggression whose session started within the last 7 days;
//   - the updated state is stored.
//
// With dryRun the Drive calls and the state update are skipped, the tags are
// only logged. Any returned error is also counted in the
// "aggression.sessions.update_session.error_sum" signal.
func processSessionScoringRow(
	ctx *gotasks.Context,
	db *gosql.DB,
	store *users.RobotStateStore,
	drive *api.Client,
	row sessionScoringRow,
	dryRun bool, noUsers bool, skipAfter time.Duration,
) error {
	if skipAfter > 0 {
		if ts := time.Now().Add(-skipAfter).Unix(); row.FinishTime < ts {
			ctx.Logger.Debug("Skip too late row", log.Any("row", row))
			return nil
		}
	}
	newCtx, cancel := context.WithTimeout(ctx.Context, time.Second*30)
	defer cancel()
	if err := gosql.WithTxContext(
		newCtx, db, nil, func(tx *sql.Tx) error {
			userRawState, err := store.GetForUpdateByRobotUserTx(
				tx, "user_aggression", row.UserID,
			)
			if err != nil {
				if err != sql.ErrNoRows {
					return err
				}
				// No state yet: do not create one for a row below the threshold.
				// NOTE(review): this uses a plain "<" rather than compareScore, so
				// a row with score == threshold == 0 still creates a state —
				// confirm this is intended.
				if row.Score < row.Threshold {
					return nil
				}
				userRawState.UserID = row.UserID
				userRawState.Robot = "user_aggression"
				// NOTE(review): unlike UpdateTx below, CreateTx is not guarded by
				// dryRun — confirm that creating the state in dry-run mode is
				// intended.
				if err := store.CreateTx(tx, &userRawState); err != nil {
					ctx.Signal("aggression.sessions.update_user_state.error_sum", nil).Add(1)
					return err
				}
			}
			userState := userStateData{}
			if err := userRawState.ScanState(&userState); err != nil {
				return err
			}
			// Rows older than the last processed time were already handled.
			if userState.LastTime > row.Time {
				return nil
			}
			// A newer time starts a fresh list of sessions seen at that time.
			if userState.LastTime < row.Time {
				userState.LastTime = row.Time
				userState.LastSessions = nil
			}
			// Same time: skip sessions that were already processed.
			for _, session := range userState.LastSessions {
				if session == row.SessionID {
					return nil
				}
			}
			userState.LastSessions = append(
				userState.LastSessions, row.SessionID,
			)
			// NOTE(review): AddSession is defined elsewhere; its result is
			// presumably true when the session was not known before — confirm.
			isNewSession := userState.AddSession(row)
			if !noUsers {
				userState.TrialMileage += row.Mileage
				if userState.TrialPeriod != nil && isNewSession {
					// The value sent to Drive never exceeds TrialMileageLimit.
					mileage := userState.GetTrialMileage()
					if mileage > TrialMileageLimit {
						mileage = TrialMileageLimit
					}
					tag := api.UserTag{
						Tag:    "user_aggressive_trial_period",
						Data:   &api.GenericTagData{"value": mileage},
						UserID: row.UserID,
					}
					ctx.Logger.Debug(
						"Update user trial tag",
						log.Any("tag", tag),
					)
					if !dryRun {
						if err := drive.AddUserTag(tag); err != nil {
							ctx.Signal("aggression.sessions.update_user_aggressive_trial_period.error_sum", nil).Add(1)
							return err
						}
						ctx.Signal("aggression.sessions.update_user_aggressive_trial_period.ok_sum", nil).Add(1)
					}
				}
			}
			if compareScore(row.Score, row.Threshold) {
				// Keep the start time of the latest aggressive session.
				if row.StartTime >= userState.AggressionTime {
					userState.AggressionTime = row.StartTime
				}
				var tagEvents []interface{}
				for _, event := range row.Events {
					eventData := map[string]interface{}{
						"value":     event.Score,
						"location":  event.Position,
						"timestamp": event.Time,
					}
					// Only kinds listed in eventKinds are passed to the tag.
					if value, ok := eventKinds[event.Kind]; ok {
						eventData["kind"] = value
					}
					tagEvents = append(tagEvents, eventData)
				}
				// Tags are sent only for sessions that have at least one event.
				if len(row.Events) > 0 {
					tag := api.SessionTag{
						Tag: "scoring_trace_tag",
						Data: &api.GenericTagData{
							"value":     row.Score,
							"timestamp": row.Time,
							"events":    tagEvents,
						},
						SessionID: row.SessionID,
					}
					ctx.Logger.Debug(
						"Adding session tag",
						log.Any("tag", tag),
						log.String("user_id", row.UserID),
						log.Float64("threshold", row.Threshold),
						log.Bool("dryrun", dryRun),
					)
					if !dryRun {
						if err := drive.AddSessionTag(tag); err != nil {
							ctx.Signal("aggression.sessions.add_scoring_trace_tag.error_sum", nil).Add(1)
							return err
						}
						ctx.Signal("aggression.sessions.add_scoring_trace_tag.ok_sum", nil).Add(1)
					}
					if !noUsers {
						now := time.Now().Unix()
						// First aggression: the user has none recorded yet and the
						// session started within the last 7 days.
						if !userState.HasAggression && row.StartTime >= now-60*60*24*7 && isNewSession {
							tag := api.UserTag{
								Tag:    "user_aggressive_first_aggression",
								UserID: row.UserID,
							}
							ctx.Logger.Debug(
								"Adding user tag",
								log.Any("tag", tag),
								log.Float64("threshold", row.Threshold),
								log.Bool("dryrun", dryRun),
							)
							if !dryRun {
								if err := drive.AddUserTag(tag); err != nil {
									ctx.Signal("aggression.sessions.add_first_aggression_tag.error_sum", nil).Add(1)
									return err
								}
								ctx.Signal("aggression.sessions.add_first_aggression_tag.ok_sum", nil).Add(1)
							}
							userState.HasAggression = true
						}
					}
				}
			}
			if err := userRawState.SetState(userState); err != nil {
				return err
			}
			// In dry-run mode the modified state is serialized but not stored.
			if !dryRun {
				if err := store.UpdateTx(tx, userRawState); err != nil {
					ctx.Signal("aggression.sessions.update_user_state.error_sum", nil).Add(1)
					return err
				}
				ctx.Signal("aggression.sessions.update_user_state.ok_sum", nil).Add(1)
			}
			ctx.Signal("aggression.sessions.update_session.ok_sum", nil).Add(1)
			return nil
		},
	); err != nil {
		ctx.Signal("aggression.sessions.update_session.error_sum", nil).Add(1)
		return err
	}
	return nil
}

// shadowProcessSessionScoringRow handles one scoring row in shadow mode.
//
// Rows whose session started more than 60*60*24*30*4 seconds ago and rows
// rejected by compareScore are ignored. For every other row the
// "scoring_trace_tag" session tag is logged and, unless dryRun is set, sent
// to Drive. The outcome of the Drive call is counted in the
// "aggression.sessions_shadow.add_scoring_trace_tag" signals.
func shadowProcessSessionScoringRow(
	ctx *gotasks.Context,
	drive *api.Client,
	row sessionScoringRow,
	dryRun bool,
) error {
	const lookbackSeconds = 60 * 60 * 24 * 30 * 4
	if oldest := time.Now().Unix() - lookbackSeconds; row.StartTime < oldest {
		return nil
	}
	if !compareScore(row.Score, row.Threshold) {
		return nil
	}
	// The list intentionally stays nil when the row has no events.
	var events []interface{}
	for i := range row.Events {
		item := row.Events[i]
		events = append(events, map[string]interface{}{
			"value":     item.Score,
			"location":  item.Position,
			"timestamp": item.Time,
		})
	}
	sessionTag := api.SessionTag{
		Tag:       "scoring_trace_tag",
		SessionID: row.SessionID,
		Data: &api.GenericTagData{
			"value":     row.Score,
			"timestamp": row.Time,
			"events":    events,
		},
	}
	ctx.Logger.Debug(
		"Adding session tag",
		log.Any("tag", sessionTag),
		log.String("user_id", row.UserID),
		log.Float64("threshold", row.Threshold),
		log.Bool("dryrun", dryRun),
	)
	if dryRun {
		return nil
	}
	if err := drive.AddSessionTag(sessionTag); err != nil {
		ctx.Signal("aggression.sessions_shadow.add_scoring_trace_tag.error_sum", nil).Add(1)
		return err
	}
	ctx.Signal("aggression.sessions_shadow.add_scoring_trace_tag.ok_sum", nil).Add(1)
	return nil
}

// updateReaderState moves the reader position to lastTime and stores it.
//
// A lastTime smaller than the current state.LastTime is rejected with an
// error. Otherwise state.LastTime is overwritten, the state is serialized
// into rawState and, unless dryRun is set, rawState is written through store.
// Every failure is counted in "aggression.sessions.update_state.error_sum",
// a successful write in "aggression.sessions.update_state.ok_sum".
func updateReaderState(
	ctx *gotasks.Context,
	db *gosql.DB,
	store *users.RobotStateStore,
	rawState *users.RobotState,
	state *readerStateData,
	lastTime int64,
	dryRun bool,
) error {
	// fail counts the failure and hands the error back to the caller.
	fail := func(cause error) error {
		ctx.Signal("aggression.sessions.update_state.error_sum", nil).Add(1)
		return cause
	}
	ctx.Logger.Info(
		"Update reader state",
		log.Int64("last_time", lastTime),
		log.Bool("dryrun", dryRun),
	)
	if state.LastTime > lastTime {
		return fail(fmt.Errorf(
			"last time is too small: %d < %d",
			lastTime, state.LastTime,
		))
	}
	state.LastTime = lastTime
	if err := rawState.SetState(state); err != nil {
		return fail(err)
	}
	if dryRun {
		return nil
	}
	if err := store.UpdateTx(db, *rawState); err != nil {
		return fail(err)
	}
	ctx.Signal("aggression.sessions.update_state.ok_sum", nil).Add(1)
	return nil
}

// processSessionScoringRows reads scoring rows from table, starting from the
// key state.LastTime, and processes them with a pool of worker goroutines.
//
// Each row is handled by processSessionScoringRow or, when shadow is set, by
// shadowProcessSessionScoringRow. Failures of individual rows are logged and
// are not returned as an error.
//
// The returned lastTime is the largest row "timestamp" that was read (or
// state.LastTime when nothing was read), lowered to the smallest timestamp
// that still has unfinished rows. In normal mode a failed row stays
// unfinished, so the reader position is not moved past it; in shadow mode
// failed rows are counted as finished.
//
// An error is returned when workers is less than 1, when the table cannot be
// read or scanned, when a row has a "timestamp" below state.LastTime, or when
// the context is cancelled. The returned lastTime is meaningless then.
func processSessionScoringRows(
	ctx *gotasks.Context,
	db *gosql.DB,
	yc yt.Client,
	store *users.RobotStateStore,
	drive *api.Client,
	state readerStateData,
	table ypath.Path,
	workers int,
	shadow bool,
	dryRun bool,
	noUsers bool,
	skipAfter time.Duration,
) (lastTime int64, err error) {
	if workers < 1 {
		// Without workers nobody receives from the unbuffered rows channel, so
		// the first row read from the table would block until the context is
		// cancelled.
		return 0, fmt.Errorf("invalid amount of workers: %d", workers)
	}
	ctx.Logger.Info(
		"Reading sessions",
		log.Int64("last_time", state.LastTime),
	)
	// timesCount holds the amount of unfinished rows per row timestamp.
	var timesMutex sync.Mutex
	timesCount := map[int64]int{}
	adjustPending := func(ts int64, delta int) {
		timesMutex.Lock()
		defer timesMutex.Unlock()
		timesCount[ts] += delta
	}
	lastTime = state.LastTime
	// Deferred calls run in reverse order: the table reader is closed, the
	// rows channel is closed, the workers are awaited and only then lastTime
	// is lowered to the smallest timestamp that still has unfinished rows.
	defer func() {
		for ts, count := range timesCount {
			if count > 0 && ts < lastTime {
				lastTime = ts
			}
		}
	}()
	var waiter sync.WaitGroup
	defer waiter.Wait()
	rows := make(chan sessionScoringRow)
	defer close(rows)
	for i := 0; i < workers; i++ {
		waiter.Add(1)
		go func() {
			defer waiter.Done()
			for row := range rows {
				var rowErr error
				if shadow {
					rowErr = shadowProcessSessionScoringRow(
						ctx, drive, row, dryRun,
					)
				} else {
					rowErr = processSessionScoringRow(
						ctx, db, store, drive, row, dryRun, noUsers, skipAfter,
					)
				}
				if rowErr != nil {
					ctx.Logger.Error(
						"Unable to process session scoring row",
						log.Any("row", row),
						log.Error(rowErr),
					)
				}
				// A failed row stays unfinished in normal mode; in shadow mode
				// every row is marked as finished.
				if rowErr == nil || shadow {
					adjustPending(row.Time, -1)
				}
			}
		}()
	}
	in, err := yc.ReadTable(
		ctx.Context,
		table.Rich().AddRange(
			ypath.StartingFrom(ypath.Key(state.LastTime)),
		),
		nil,
	)
	if err != nil {
		return 0, err
	}
	defer func() {
		_ = in.Close()
	}()
	for in.Next() {
		select {
		case <-ctx.Context.Done():
			return lastTime, ctx.Context.Err()
		default:
		}
		var row sessionScoringRow
		if err := in.Scan(&row); err != nil {
			return lastTime, err
		}
		if row.Time < state.LastTime {
			return lastTime, fmt.Errorf("invalid update time of record")
		}
		if row.Time > lastTime {
			lastTime = row.Time
		}
		// The row is registered before it is handed to a worker, so a row that
		// was never delivered also counts as unfinished.
		adjustPending(row.Time, 1)
		select {
		case rows <- row:
		case <-ctx.Context.Done():
			return lastTime, ctx.Context.Err()
		}
	}
	return lastTime, in.Err()
}

// updateSessionsMain is the entry point of the "update-sessions" command.
//
// It reads the command flags, resolves the DB and the Drive client by name,
// loads the reader state ("aggressive_session_reader", or
// "aggressive_session_shadow_reader" in shadow mode), processes the scoring
// table starting from that state and finally stores the new reader position.
// Outside of shadow mode the distance between now and the new position is
// reported in the "aggression.sessions.update_lag_last" signal.
//
// The "yt-proxy" flag is not read here; the YT client is taken from
// ctx.GetYT.
func updateSessionsMain(ctx *gotasks.Context) error {
	flags := ctx.Cmd.Flags()
	dryRun, err := flags.GetBool("dry-run")
	if err != nil {
		return err
	}
	shadow, err := flags.GetBool("shadow")
	if err != nil {
		return err
	}
	ytClient, err := ctx.GetYT()
	if err != nil {
		return err
	}
	dbName, err := flags.GetString("db")
	if err != nil {
		return err
	}
	database, found := ctx.DBs[dbName]
	if !found {
		return fmt.Errorf("invalid DB name %q", dbName)
	}
	workers, err := flags.GetInt("workers")
	if err != nil {
		return err
	}
	driveName, err := flags.GetString("drive")
	if err != nil {
		return err
	}
	driveClient, found := ctx.Drives[driveName]
	if !found {
		return fmt.Errorf("invalid Drive name %q", driveName)
	}
	noUsers, err := flags.GetBool("no-users")
	if err != nil {
		return err
	}
	skipAfter, err := flags.GetDuration("skip-after")
	if err != nil {
		return err
	}
	robotStore := users.NewRobotStateStore(database, "user_robot_state")
	var stateName string
	if shadow {
		ctx.Logger.Info("Running in shadow mode")
		stateName = "aggressive_session_shadow_reader"
	} else {
		stateName = "aggressive_session_reader"
	}
	rawState, err := getRobotStateTx(
		database, robotStore, stateName, uuid.Nil.String(),
	)
	if err != nil {
		return err
	}
	var position readerStateData
	if err := rawState.ScanState(&position); err != nil {
		return err
	}
	lastTime, err := processSessionScoringRows(
		ctx, database, ytClient, robotStore, driveClient, position,
		ctx.Config.YTPaths.AggressionSessionScoreTable,
		workers, shadow, dryRun, noUsers, skipAfter,
	)
	if err != nil {
		return err
	}
	if !shadow {
		lag := float64(time.Now().Unix() - lastTime)
		ctx.Signal("aggression.sessions.update_lag_last", nil).Set(lag)
	}
	return updateReaderState(
		ctx, database, robotStore, &rawState, &position, lastTime, dryRun,
	)
}
