package sessions

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/analytics/gotasks/exports"
	"a.yandex-team.ru/drive/analytics/gotasks/users"
	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/zootopia/analytics/drive/api"
	"a.yandex-team.ru/zootopia/analytics/drive/models"
	"a.yandex-team.ru/zootopia/library/go/db"
	"a.yandex-team.ru/zootopia/library/go/db/events"
)

// init registers the "update-speeding" subcommand under SessionsCmd and
// declares the flags read by updateSpeedingMain.
func init() {
	cmd := &cobra.Command{
		Use: "update-speeding",
		Run: gotasks.WrapMain(updateSpeedingMain),
	}

	flags := cmd.Flags()
	flags.Bool("dry-run", false, "Enables dry-run mode")
	flags.String("drive", "", "Name of Drive client")
	flags.String("backend-db", "backend", "Backend DB name")
	flags.Int("workers", 8, "Amount of parallel workers")
	flags.Duration("delay", 0, "Delay for shadow sessions")
	flags.Bool("leasing", false, "Leasing backend")

	SessionsCmd.AddCommand(cmd)
}

// readerStateData is a JSON document holding a single event ID under
// "begin_event_id".
//
// NOTE(review): not referenced anywhere in this chunk — presumably the
// persisted shape of an older reader state; confirm it is still used before
// relying on it or removing it.
type readerStateData struct {
	BeginEventID int64 `json:"begin_event_id"`
}

// getDialect returns the db.Dialect matching the driver of conn.
// Only the Postgres and SQLite drivers are supported; any other driver
// causes a panic with an error naming the driver.
func getDialect(conn *gosql.DB) db.Dialect {
	driver := conn.Driver
	if driver == gosql.PostgresDriver {
		return db.Postgres
	}
	if driver == gosql.SQLiteDriver {
		return db.SQLite
	}
	panic(fmt.Errorf("unsupported %q driver", driver))
}

// must unwraps a (value, error) pair: it returns value when err is nil and
// panics with err otherwise. It is meant for calls that can only fail on a
// programming error, such as reading a flag that init declared.
func must[T any](value T, err error) T {
	if err == nil {
		return value
	}
	panic(err)
}

// dbEventTable exposes an events.ROStore kept in a gosql database as a table
// that users.TableProcessor can read from (see ReadFrom).
//
// Fields:
//   - db: connection used to open the read transaction in ReadFrom.
//   - events: read-only store the ordered consumer pulls events from.
//   - delay: events whose EventTime is less than delay in the past are not
//     read yet.
type dbEventTable struct {
	db     *gosql.DB
	events events.ROStore
	delay  time.Duration
}

// dbEventReader is an in-memory users.Reader over a batch of events that
// dbEventTable.ReadFrom has already fully loaded.
//
// pos is a 1-based cursor: 0 means Next has not been called yet, and after a
// successful Next the current row is rows[pos-1].
type dbEventReader struct {
	rows []events.Event
	pos  int
}

// Next advances to the next buffered event and reports whether one exists.
func (r *dbEventReader) Next() bool {
	if r.pos >= len(r.rows) {
		return false
	}
	r.pos++
	return true
}

// Scan copies the current event into v, which must be a non-nil
// *events.Event. It returns an error, rather than panicking, when v has the
// wrong type or when there is no current row (Next not called yet).
func (r *dbEventReader) Scan(v any) error {
	dst, ok := v.(*events.Event)
	if !ok || dst == nil {
		return fmt.Errorf("unsupported scan destination %T", v)
	}
	if r.pos < 1 || r.pos > len(r.rows) {
		return errors.New("no current row: call Next before Scan")
	}
	*dst = r.rows[r.pos-1]
	return nil
}

// Err always returns nil: all rows were loaded before the reader was built,
// so iteration itself cannot fail.
func (r *dbEventReader) Err() error {
	return nil
}

// Close is a no-op; the reader holds no external resources.
func (r *dbEventReader) Close() error {
	return nil
}

// ReadFrom loads the next batch of events, starting at the event ID carried
// by key, and returns them through an in-memory reader.
//
// A key that is not an int64Key (including nil) starts from event ID 1.
// Events are consumed in order inside a single transaction. Consumption
// stops early, without an error, at the first event whose EventTime is less
// than t.delay in the past, or once the batch holds batchLimit events.
// Context cancellation aborts the read with ctx.Err().
func (t *dbEventTable) ReadFrom(ctx context.Context, key users.TableKey) (users.Reader, error) {
	// batchLimit caps how many events a single call buffers in memory.
	const batchLimit = 10000

	beginID := int64(1)
	if k, ok := key.(int64Key); ok {
		beginID = int64(k)
	}
	consumer := events.NewOrderedConsumer(t.events, beginID)

	// Events newer than cutoff are left for a later run. Computing it once
	// gives the whole batch a consistent "now".
	cutoff := time.Now().Add(-t.delay)

	var rows []events.Event
	err := gosql.WithTxContext(ctx, t.db, nil, func(tx *sql.Tx) error {
		return consumer.ConsumeEvents(tx, func(event events.Event) error {
			if err := ctx.Err(); err != nil {
				return err
			}
			if event.EventTime().After(cutoff) {
				// sql.ErrNoRows is used as a "stop consuming" sentinel and is
				// filtered out below.
				return sql.ErrNoRows
			}
			rows = append(rows, event)
			if len(rows) >= batchLimit {
				return sql.ErrNoRows
			}
			return nil
		})
	})
	// errors.Is also matches the stop sentinel if the consumer or the
	// transaction helper wraps the callback's error.
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, err
	}
	return &dbEventReader{rows: rows}, nil
}

// int64Key is a users.TableKey holding an event ID; rideRow uses the ride's
// HistoryEventID, and dbEventTable.ReadFrom starts consuming from it.
type int64Key int64

// Less orders keys numerically. It panics if o is not an int64Key.
func (k int64Key) Less(o users.TableKey) bool {
	return k < o.(int64Key)
}

// RawKey returns the key as a one-element slice holding a plain int64.
func (k int64Key) RawKey() []any {
	raw := int64(k)
	return []any{raw}
}

// rideRow adapts models.CompiledRide to the users.TableRow interface.
type rideRow models.CompiledRide

// Key returns the ride's HistoryEventID wrapped as an int64Key.
func (r rideRow) Key() users.TableKey {
	id := int64Key(r.HistoryEventID)
	return id
}

// speedingImpl holds the per-row logic of the update-speeding command. It
// parses keys and rows (ParseKey, ParseRow) and tags sessions with speeding
// violations (ProcessRow).
//
// Fields:
//   - drive: Drive API client used to fetch analyzed tracks and add tags.
//   - leasing: when true, tracks come from GetLeasingAnalyzedTrack instead
//     of GetAnalyzedTrack.
//   - delay: rides that finished less than delay ago are rejected as too new.
type speedingImpl struct {
	drive   *api.Client
	leasing bool
	delay   time.Duration
}

// ParseKey decodes a JSON number into an int64Key. The decode error, if any,
// is returned together with whatever value the key holds at that point.
func (t *speedingImpl) ParseKey(raw json.RawMessage) (users.TableKey, error) {
	var id int64Key
	if err := json.Unmarshal(raw, &id); err != nil {
		return id, err
	}
	return id, nil
}

// ParseRow scans one event from scanner and converts it to a rideRow.
//
// Both models.CompiledRide and a non-nil *models.CompiledRide are accepted.
// Any other event type, or a nil pointer, yields an error instead of the
// panic an unchecked type assertion would cause.
func (t *speedingImpl) ParseRow(scanner users.Scanner) (users.TableRow, error) {
	var event events.Event
	if err := scanner.Scan(&event); err != nil {
		return nil, err
	}
	switch ride := event.(type) {
	case models.CompiledRide:
		return rideRow(ride), nil
	case *models.CompiledRide:
		if ride != nil {
			return rideRow(*ride), nil
		}
	}
	return nil, fmt.Errorf("unexpected event type %T", event)
}

// ProcessRow checks one compiled ride for speeding violations and, when any
// are found, attaches a "speeding_trace_tag" tag to its session.
//
// Behavior:
//   - a row that is not a rideRow yields an error;
//   - a ride that finished less than t.delay ago yields an error
//     ("ride is too new");
//   - nothing is tagged, and nil is returned, in three cases: the ride meta
//     reports a non-positive riding duration, the track is missing
//     (api.NoTrackError), or the ride has no violations;
//   - the tag value is the largest (Peak - Limit) over the violations of
//     tracks that belong to this session, floored at 0;
//   - in dry-run mode the tag is only logged. Otherwise it is sent, and the
//     "sessions" signal is incremented with status "ok" or "error".
func (t *speedingImpl) ProcessRow(ctx *gotasks.Context, row users.TableRow, dryRun bool) error {
	rr, ok := row.(rideRow)
	if !ok {
		return fmt.Errorf("unexpected row type %T", row)
	}
	ride := models.CompiledRide(rr)
	if ride.Finish+int64(t.delay.Seconds()) >= time.Now().Unix() {
		return errors.New("ride is too new")
	}
	rideData, err := ride.ParseData()
	if err != nil {
		return err
	}
	if meta := rideData.Meta; meta != nil && meta.GetRidingDuration() <= 0 {
		ctx.Logger.Debug(
			"Empty session",
			log.String("session_id", ride.SessionID),
		)
		return nil
	}

	var tracks []api.AnalyzedTrack
	if t.leasing {
		tracks, err = t.drive.GetLeasingAnalyzedTrack(ride.SessionID, "")
	} else {
		tracks, err = t.drive.GetAnalyzedTrack(ride.SessionID, ride.HistoryUserID)
	}
	// errors.As also matches a NoTrackError that the client has wrapped.
	var noTrack api.NoTrackError
	if errors.As(err, &noTrack) {
		ctx.Logger.Warn(
			"Track not found",
			log.String("session_id", ride.SessionID),
			log.Error(err),
		)
		return nil
	}
	if err != nil {
		ctx.Logger.Error(
			"Error",
			log.String("session_id", ride.SessionID),
			log.Error(err),
		)
		return err
	}

	violations := 0
	var diff float64
	for _, track := range tracks {
		// Tracks for other sessions are ignored, not counted.
		if track.SessionID != ride.SessionID {
			ctx.Logger.Warn(
				"Wrong session track",
				log.String("session_id", ride.SessionID),
				log.String("track_session_id", track.SessionID),
			)
			continue
		}
		violations += len(track.Violations)
		for _, violation := range track.Violations {
			if vDiff := violation.Peak - violation.Limit; vDiff > diff {
				diff = vDiff
			}
		}
	}
	ctx.Logger.Debug(
		"Track info",
		log.String("session_id", ride.SessionID),
		log.Int("violations", violations),
		log.Float64("diff", diff),
	)
	if violations == 0 {
		return nil
	}

	tag := api.SessionTag{
		Tag: "speeding_trace_tag",
		Data: &api.GenericTagData{
			"value":     diff,
			"timestamp": ride.HistoryTimestamp,
		},
		SessionID: ride.SessionID,
	}
	ctx.Logger.Debug(
		"Adding session tag",
		log.Any("tag", tag),
		log.Bool("dryrun", dryRun),
	)
	if dryRun {
		return nil
	}
	addErr := t.drive.AddSessionTag(tag)
	status := "ok"
	if addErr != nil {
		status = "error"
	}
	ctx.Signal("sessions", map[string]string{
		"type":   "add_speeding_trace_tag",
		"status": status,
	}).Add(1)
	if addErr != nil {
		return fmt.Errorf("adding speeding tag to session %q: %w", ride.SessionID, addErr)
	}
	return nil
}

// updateSpeedingMain is the entry point of the "update-speeding" command.
//
// It reads the flags declared in init and resolves the named Drive client and
// backend DB, returning an error if either is unknown. It then runs
// users.TableProcessor over the "compiled_rides" event store, keyed by
// "history_event_id", resuming from the "speeding/sessions/reader" state.
// A non-positive --workers value is rejected up front instead of being
// handed to TableProcessor.
func updateSpeedingMain(ctx *gotasks.Context) error {
	flags := ctx.Cmd.Flags()
	// Flags are declared in init, so lookup failures are programming errors.
	dryRun := must(flags.GetBool("dry-run"))
	driveName := must(flags.GetString("drive"))
	workers := must(flags.GetInt("workers"))
	backendDBName := must(flags.GetString("backend-db"))
	delay := must(flags.GetDuration("delay"))
	leasing := must(flags.GetBool("leasing"))

	if workers <= 0 {
		return fmt.Errorf("invalid workers count %d: must be positive", workers)
	}
	drive, ok := ctx.Drives[driveName]
	if !ok {
		return fmt.Errorf("invalid drive name %q", driveName)
	}
	backendDB, ok := ctx.DBs[backendDBName]
	if !ok {
		return fmt.Errorf("invalid DB name %q", backendDBName)
	}

	impl := &speedingImpl{
		drive:   drive,
		leasing: leasing,
		delay:   delay,
	}
	state, err := exports.GetTableState(ctx.States, "speeding/sessions/reader")
	if err != nil {
		return fmt.Errorf("unable to fetch state: %w", err)
	}
	table := &dbEventTable{
		db: backendDB,
		events: events.NewStore(
			models.CompiledRide{},
			"history_event_id",
			"compiled_rides",
			getDialect(backendDB),
		),
		delay: delay,
	}
	return users.TableProcessor(ctx, state, table, workers, impl, dryRun)
}
