package aggression

import (
	"math"

	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/goback/models"
	"a.yandex-team.ru/drive/analytics/goback/models/tags"
	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/yt/go/mapreduce"
	"a.yandex-team.ru/yt/go/mapreduce/spec"
	"a.yandex-team.ru/yt/go/schema"
	"a.yandex-team.ru/yt/go/ypath"
)

// init registers the mapper and reducer job types with mapreduce and attaches
// the build-sessions-state-diff subcommand (with its YT path flags) to
// AggressionCmd.
func init() {
	mapreduce.Register(mapper{})
	mapreduce.Register(reducer{})

	cmd := &cobra.Command{
		Use: "build-sessions-state-diff",
		Run: gotasks.WrapMain(buildSessionsStateDiffMain),
	}
	flags := cmd.Flags()
	flags.String("yt-proxy", "hahn", "YT proxy")
	flags.String("tags-path", "//home/carsharing/production/data/exports/trace_tags", "")
	flags.String("scores-path", "//home/carsharing/ml/agressive_driving/v2/applier/session_scores", "")
	flags.String("diff-path", "//home/carsharing/iudovin/aggressive_sessions_state_diff", "")
	AggressionCmd.AddCommand(cmd)
}

// mapper is the map stage of build-sessions-state-diff: it filters trace tags
// and session scores and re-keys both by session ID (see mapper.Do).
type mapper struct{ mapreduce.Untyped }

// mapperRow is the intermediate row passed from mapper to reducer, keyed by
// SessionID. For input table 0 the mapper sets only TagRow; for input table 1
// it sets only StateRow. It never sets both on the same row.
type mapperRow struct {
	SessionID string             `yson:"session_id"`
	TagRow    *models.Tag        `yson:"tag_row"`
	StateRow  *sessionScoringRow `yson:"state_row"`
}

// Do converts rows from the two input tables into mapperRow values keyed by
// session ID.
//
// Table 0 holds trace tags. Only "scoring_trace_tag" tags are kept, keyed by
// the tag's ObjectID. Table 1 holds session scoring states. Only states whose
// Score is not below their Threshold are kept, keyed by their SessionID.
// Every other row is dropped.
func (mapper) Do(
	ctx mapreduce.JobContext, in mapreduce.Reader, out []mapreduce.Writer,
) error {
	for in.Next() {
		var r mapperRow
		switch in.TableIndex() {
		case 0: // trace tags
			if err := in.Scan(&r.TagRow); err != nil {
				return err
			}
			tag := r.TagRow
			if tag == nil || tag.Tag != "scoring_trace_tag" {
				continue
			}
			r.SessionID = tag.ObjectID
		case 1: // session scoring states
			if err := in.Scan(&r.StateRow); err != nil {
				return err
			}
			state := r.StateRow
			if state == nil || state.Score < state.Threshold {
				continue
			}
			r.SessionID = state.SessionID
		}
		if err := out[0].Write(r); err != nil {
			return err
		}
	}
	return nil
}

// diffRow is one output row of build-sessions-state-diff. A session gets a row
// when its tag score and state score differ, or when only one of the two exists.
//
//   - Time is copied from the session's state row; it stays 0 when there is none.
//   - TagScore is the Value of the session's scoring_trace_tag; nil when no such
//     tag carried a value.
//   - StateScore is the state row's Score; nil when the session had no state row
//     that passed the mapper's threshold filter.
type diffRow struct {
	Time       int64    `yson:"time"`
	SessionID  string   `yson:"session_id"`
	TagScore   *float64 `yson:"tag_score"`
	StateScore *float64 `yson:"state_score"`
	// State is the full scoring state row that StateScore was taken from; nil
	// when StateScore is nil.
	State *sessionScoringRow `yson:"state"`
}

// reducer is the reduce stage of build-sessions-state-diff: it merges the tag
// and state rows of each session and emits a diffRow when they disagree.
type reducer struct{ mapreduce.Untyped }

// Do passes each key group yielded by mapreduce.GroupKeys to reduceGroup. The
// key is session_id, the job's ReduceBy column. All groups share the same
// output writers.
func (r reducer) Do(
	ctx mapreduce.JobContext, in mapreduce.Reader, out []mapreduce.Writer,
) error {
	reduceOne := func(group mapreduce.Reader) error {
		return r.reduceGroup(group, out)
	}
	return mapreduce.GroupKeys(in, reduceOne)
}

// reduceGroup folds all mapper rows of one session into a single diffRow and
// writes it to out[0] when the tag score and the state score disagree. Having
// only one of the two scores counts as disagreeing. A group with neither score
// produces no output.
//
// NOTE(review): if a group holds several tag rows or several state rows, the
// last one read wins. SortBy is only session_id, so the order of rows inside a
// group is presumably not fixed. Confirm that duplicates cannot occur.
func (r reducer) reduceGroup(
	in mapreduce.Reader, out []mapreduce.Writer,
) error {
	var outRow diffRow
	for in.Next() {
		var row mapperRow
		if err := in.Scan(&row); err != nil {
			return err
		}
		if tag := row.TagRow; tag != nil {
			var data tags.ScoringTagData
			if err := tag.ScanData(&data); err != nil {
				return err
			}
			// A tag without a value leaves TagScore untouched.
			if data.Value != nil {
				score := *data.Value
				outRow.TagScore = &score
			}
		}
		if state := row.StateRow; state != nil {
			score := state.Score
			outRow.StateScore = &score
			outRow.Time = state.Time
			outRow.State = state
		}
		outRow.SessionID = row.SessionID
	}
	if !sessionScoresDiffer(outRow.TagScore, outRow.StateScore) {
		return nil
	}
	return out[0].Write(outRow)
}

// sessionScoresDiffer reports whether a tag score and a state score disagree.
// A missing (nil) score differs from a present one; two missing scores do not.
// Two present scores differ when they are more than 1e-9 apart, or when exactly
// one of them is NaN. A plain math.Abs(a-b) > eps check would miss that case,
// because every comparison involving NaN is false.
func sessionScoresDiffer(tagScore, stateScore *float64) bool {
	const eps = 1e-9
	if tagScore == nil || stateScore == nil {
		return (tagScore == nil) != (stateScore == nil)
	}
	a, b := *tagScore, *stateScore
	if aNaN, bNaN := math.IsNaN(a), math.IsNaN(b); aNaN || bNaN {
		return aNaN != bNaN
	}
	return math.Abs(a-b) > eps
}

// buildSessionsStateDiffMain runs a YT MapReduce over the --tags-path and
// --scores-path tables (input tables 0 and 1 respectively). It uses mapper and
// reducer, groups and sorts by session_id, and writes the result to
// --diff-path with a schema inferred from diffRow. It blocks until the
// operation finishes.
func buildSessionsStateDiffMain(ctx *gotasks.Context) error {
	var tagsPath, scoresPath, diffPath string
	flags := ctx.Cmd.Flags()
	for _, f := range []struct {
		name string
		dst  *string
	}{
		{"tags-path", &tagsPath},
		{"scores-path", &scoresPath},
		{"diff-path", &diffPath},
	} {
		value, err := flags.GetString(f.name)
		if err != nil {
			return err
		}
		*f.dst = value
	}

	yc, err := ctx.GetYT()
	if err != nil {
		return err
	}
	mr := mapreduce.New(yc)

	outSchema, err := schema.Infer(diffRow{})
	if err != nil {
		return err
	}

	// The input order matters: mapper.Do treats table 0 as tags and table 1
	// as scores.
	inputs := []ypath.YPath{
		ypath.Path(tagsPath),
		ypath.Path(scoresPath),
	}
	output := ypath.Path(diffPath).Rich().SetSchema(outSchema)
	mrSpec := spec.Spec{
		InputTablePaths:  inputs,
		OutputTablePaths: []ypath.YPath{output},
		ReduceBy:         []string{"session_id"},
		SortBy:           []string{"session_id"},
	}

	op, err := mr.MapReduce(mapper{}, reducer{}, mrSpec.MapReduce())
	if err != nil {
		return err
	}
	return op.Wait()
}
