package models

import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"a.yandex-team.ru/drive/library/go/gosql"
	"github.com/jmoiron/sqlx"
	"github.com/jmoiron/sqlx/types"
)

// TaskStatus describes where a task is in its lifecycle. It is persisted as
// text in the "status" column.
type TaskStatus string

// Statuses of a task that has not finished yet. KillByNode moves tasks in
// any of these statuses to FailedTask.
const (
	QueuedTask   TaskStatus = "Queued"
	StartingTask TaskStatus = "Starting"
	RunningTask  TaskStatus = "Running"
	AbortingTask TaskStatus = "Aborting"
)

// Statuses of a finished task, as matched by FindFinishedAssigned.
const (
	SucceededTask TaskStatus = "Succeeded"
	FailedTask    TaskStatus = "Failed"
	AbortedTask   TaskStatus = "Aborted"
)

// TaskOption represents a single typed task option.
//
// Type decides how Value is decoded from JSON: TaskOption.UnmarshalJSON hands
// the raw value to getOptionValue, so the dynamic Go type stored in Value
// depends on Type. A JSON null value is decoded as a nil Value.
type TaskOption struct {
	Type  OptionType  `json:""`
	Value interface{} `json:""`
}

// TaskOptions maps option names to their typed values.
type TaskOptions map[string]TaskOption

// Clone returns a new map holding the same entries as o.
//
// The copy is shallow. Each TaskOption is copied by value, but a Value whose
// dynamic type is a reference type (map, slice, pointer) is still shared with
// o. Cloning a nil TaskOptions yields an empty, non-nil map.
func (o TaskOptions) Clone() TaskOptions {
	// Pre-size the map; the local name avoids shadowing the builtin copy.
	clone := make(TaskOptions, len(o))
	for name, option := range o {
		clone[name] = option
	}
	return clone
}

// Task represents a task row.
//
// The db tags name the table columns. PlannerID, OwnerID and NodeID map to
// nullable columns: the hand-written TaskStore queries read them through
// COALESCE(..., 0), and they are tagged omitempty for JSON output. Options is
// stored as JSON text in the "options" column (see TaskOptions.Scan and
// TaskOptions.Value).
type Task struct {
	ID         int64       `db:"id"          json:""`
	ActionID   int         `db:"action_id"   json:""`
	PlannerID  NInt        `db:"planner_id"  json:",omitempty"`
	OwnerID    NInt        `db:"owner_id"    json:",omitempty"`
	NodeID     NInt        `db:"node_id"     json:",omitempty"`
	CreateTime int64       `db:"create_time" json:""`
	UpdateTime int64       `db:"update_time" json:""`
	Status     TaskStatus  `db:"status"      json:""`
	Options    TaskOptions `db:"options"     json:""`
}

// Clone returns a copy of the task that does not share its Options map with
// the original. The map is duplicated via TaskOptions.Clone; every other
// field is copied as part of the struct value.
func (o Task) Clone() Task {
	clone := o
	clone.Options = o.Options.Clone()
	return clone
}

// optionGenericValue holds the still-encoded JSON text of an option value.
// Decoding is deferred until the option type is known (see getOptionValue).
type optionGenericValue string

// UnmarshalJSON keeps the raw JSON bytes verbatim, without decoding them.
// The string conversion copies data, so the caller's buffer is not retained.
func (s *optionGenericValue) UnmarshalJSON(data []byte) error {
	raw := string(data)
	*s = optionGenericValue(raw)
	return nil
}

// getOptionValue decodes the raw JSON text v according to the option type t.
//
// The dynamic type of the result depends on t:
//   - StringOption: string
//   - IntegerOption, ConfigOption: int64
//   - FloatOption: float64
//   - SecretOption: int64 if v is an integer, otherwise SecretValue
//   - JSONOption: JSON
//   - JSONQuery: types.JSONText
//
// Any other type yields an "unsupported type" error.
func getOptionValue(t OptionType, v optionGenericValue) (interface{}, error) {
	raw := []byte(v)
	switch t {
	case IntegerOption, ConfigOption:
		// Config options are encoded exactly like integer options.
		var i int64
		err := json.Unmarshal(raw, &i)
		return i, err
	case StringOption:
		var s string
		err := json.Unmarshal(raw, &s)
		return s, err
	case FloatOption:
		var f float64
		err := json.Unmarshal(raw, &f)
		return f, err
	case SecretOption:
		// Secrets are migrating to secrets models, so both encodings are
		// accepted: a bare integer first, then a SecretValue.
		var id int64
		if err := json.Unmarshal(raw, &id); err == nil {
			return id, nil
		}
		var secret SecretValue
		err := json.Unmarshal(raw, &secret)
		return secret, err
	case JSONOption:
		var doc JSON
		err := json.Unmarshal(raw, &doc)
		return doc, err
	case JSONQuery:
		var text types.JSONText
		err := json.Unmarshal(raw, &text)
		return text, err
	}
	return nil, fmt.Errorf("unsupported type '%s'", t)
}

// UnmarshalJSON decodes a {"Type": ..., "Value": ...} object.
//
// Value is first captured as raw JSON and then decoded by getOptionValue
// according to Type. A missing or null Value becomes nil. On any error the
// receiver is left unchanged.
func (o *TaskOption) UnmarshalJSON(data []byte) error {
	var raw struct {
		Type  OptionType          `json:""`
		Value *optionGenericValue `json:""`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	var value interface{}
	if raw.Value != nil {
		decoded, err := getOptionValue(raw.Type, *raw.Value)
		if err != nil {
			return err
		}
		value = decoded
	}
	o.Type, o.Value = raw.Type, value
	return nil
}

// Scan implements sql.Scanner by decoding JSON text from the "options" column.
//
// src may be a string or []byte holding JSON. SQL NULL resets the options to
// nil, and any other source type is rejected.
//
// Decoding goes into a fresh map, which replaces *o only on success. This
// matters because json.Unmarshal into a non-nil map keeps the entries already
// in it. Scanning into a reused Task (e.g. in Reload) would otherwise merge
// stale options with the stored ones, and a decode error would leave *o
// half-updated.
func (o *TaskOptions) Scan(src interface{}) error {
	var data []byte
	switch t := src.(type) {
	case string:
		data = []byte(t)
	case []byte:
		data = t
	case nil:
		*o = nil
		return nil
	default:
		return errors.New("incompatible type for TaskOptions")
	}
	var options TaskOptions
	if err := json.Unmarshal(data, &options); err != nil {
		return err
	}
	*o = options
	return nil
}

// Value implements driver.Valuer by encoding the options as a JSON string.
// A nil TaskOptions encodes as the JSON literal null, not SQL NULL.
func (o TaskOptions) Value() (driver.Value, error) {
	encoded, err := json.Marshal(o)
	if err != nil {
		return nil, err
	}
	return string(encoded), nil
}

// TaskStore provides access to tasks kept in a single table.
//
// It holds two handles to the same database, both set up in NewTaskStore.
// The gosql handle serves the query-builder based methods; the sqlx handle
// wraps db.DB and serves the methods with hand-written SQL.
type TaskStore struct {
	db  *gosql.DB
	dbx *sqlx.DB
	// table is the tasks table name; it is interpolated into every query.
	table string
}

// DB returns the gosql database handle the store was created with.
func (s *TaskStore) DB() (db *gosql.DB) {
	db = s.db
	return
}

// Get loads the task with the given id.
//
// NULL planner_id, owner_id and node_id are read as 0. When no row matches,
// the error reported by sqlx's Get is returned (presumably sql.ErrNoRows;
// confirm against the sqlx version in use).
func (s *TaskStore) Get(id int) (task Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "id" = $1 LIMIT 1`,
		s.table,
	)
	err = s.dbx.Get(&task, query, id)
	return task, err
}

// CreateTx inserts a new task using runner tx.
//
// The inserted row always has Status QueuedTask, and CreateTime and
// UpdateTime both set to the current Unix time. task.ID is ignored; the
// database assigns the ID via RETURNING "id". Only on success is *task
// replaced by the stored values, including the new ID.
func (s *TaskStore) CreateTx(
	tx gosql.Runner, task *Task,
) error {
	created := *task
	now := time.Now().Unix()
	created.CreateTime = now
	created.UpdateTime = now
	created.Status = QueuedTask

	names, values := gosql.StructNameValues(created, false, "id")
	query, args := s.db.Insert(s.table).
		Names(names...).Values(values...).Build()
	row := tx.QueryRow(query+` RETURNING "id"`, args...)
	if err := row.Scan(&created.ID); err != nil {
		return err
	}
	*task = created
	return nil
}

// Create inserts a new task through the store's own database handle.
// It behaves like CreateTx.
func (s *TaskStore) Create(task *Task) error {
	var runner gosql.Runner = s.db
	return s.CreateTx(runner, task)
}

// UpdateTx writes task (matched by task.ID) using runner tx.
//
// UpdateTime is always set to the current Unix time. When columns are given,
// only those columns plus "update_time" are written; names that do not match
// any task column are ignored. Without columns, every column except "id" is
// written. Unless exactly one row is affected, sql.ErrNoRows is returned. On
// success *task receives the new UpdateTime.
func (s *TaskStore) UpdateTx(
	tx gosql.Runner, task *Task, columns ...string,
) error {
	updated := *task
	updated.UpdateTime = time.Now().Unix()
	names, values := gosql.StructNameValues(updated, false, "id")

	if len(columns) != 0 {
		keep := make(map[string]bool, len(columns)+1)
		keep["update_time"] = true
		for _, column := range columns {
			keep[column] = true
		}
		// Compact the selected name/value pairs in place, preserving order.
		n := 0
		for i, name := range names {
			if keep[name] {
				names[n], values[n] = name, values[i]
				n++
			}
		}
		names, values = names[:n], values[:n]
	}

	query, args := s.db.Update(s.table).
		Names(names...).Values(values...).
		Where(gosql.Column("id").Equal(task.ID)).Build()
	result, execErr := tx.Exec(query, args...)
	if execErr != nil {
		return execErr
	}
	affected, countErr := result.RowsAffected()
	if countErr != nil {
		return countErr
	}
	if affected != 1 {
		return sql.ErrNoRows
	}
	*task = updated
	return nil
}

// Update writes task through the store's own database handle.
// It behaves like UpdateTx.
func (s *TaskStore) Update(task *Task, columns ...string) error {
	var runner gosql.Runner = s.db
	return s.UpdateTx(runner, task, columns...)
}

// FindTx returns all tasks matching where, using runner tx.
//
// No ORDER BY is applied, so the result order is whatever the database
// returns. On a scan error nil is returned; otherwise the collected tasks are
// returned together with any iteration error from rows.Err.
func (s *TaskStore) FindTx(
	tx gosql.Runner, where gosql.BoolExpr,
) ([]Task, error) {
	columns := gosql.StructNames(Task{})
	query, args := s.db.Select(s.table).
		Names(columns...).
		Where(where).
		Build()
	rows, err := tx.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	var found []Task
	for rows.Next() {
		var task Task
		if scanErr := rows.Scan(gosql.StructValues(&task, true)...); scanErr != nil {
			return nil, scanErr
		}
		found = append(found, task)
	}
	return found, rows.Err()
}

// FindPageTx returns one page of tasks matching where, newest ID first.
//
// When begin is non-zero, only tasks with ID <= begin are considered. At most
// limit tasks are returned. The second result is the ID of the first task
// after the page, which callers pass as the next begin; it is 0 when there
// are no more tasks.
func (s *TaskStore) FindPageTx(
	tx gosql.Runner, where gosql.BoolExpr, begin NInt, limit int,
) ([]Task, NInt, error) {
	cond := where
	if begin != 0 {
		cond = cond.And(gosql.Column("id").LessEqual(begin))
	}
	// Fetch one extra row: if present, its ID starts the next page.
	query, args := s.db.Select(s.table).
		Names(gosql.StructNames(Task{})...).
		Where(cond).
		OrderBy(gosql.SortDesc("id")).
		Limit(limit + 1).
		Build()
	rows, err := tx.Query(query, args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	var page []Task
	for rows.Next() {
		var task Task
		if scanErr := rows.Scan(gosql.StructValues(&task, true)...); scanErr != nil {
			return nil, 0, scanErr
		}
		if len(page) != limit {
			page = append(page, task)
			continue
		}
		return page, NInt(task.ID), rows.Err()
	}
	return page, 0, rows.Err()
}

// FindPageByPlannerTx pages through the tasks of planner using runner tx.
// Paging follows FindPageTx.
func (s *TaskStore) FindPageByPlannerTx(
	tx gosql.Runner, planner int, begin NInt, limit int,
) ([]Task, NInt, error) {
	byPlanner := gosql.Column("planner_id").Equal(planner)
	return s.FindPageTx(tx, byPlanner, begin, limit)
}

// FindPageByPlanner pages through the tasks of planner through the store's
// own database handle. It behaves like FindPageByPlannerTx.
func (s *TaskStore) FindPageByPlanner(
	planner int, begin NInt, limit int,
) ([]Task, NInt, error) {
	tasks, next, err := s.FindPageByPlannerTx(s.db, planner, begin, limit)
	return tasks, next, err
}

// FindPageByActionTx pages through the tasks of action using runner tx.
// Paging follows FindPageTx.
func (s *TaskStore) FindPageByActionTx(
	tx gosql.Runner, action int, begin NInt, limit int,
) ([]Task, NInt, error) {
	byAction := gosql.Column("action_id").Equal(action)
	return s.FindPageTx(tx, byAction, begin, limit)
}

// FindPageByAction pages through the tasks of action through the store's own
// database handle. It behaves like FindPageByActionTx.
func (s *TaskStore) FindPageByAction(
	action int, begin NInt, limit int,
) ([]Task, NInt, error) {
	tasks, next, err := s.FindPageByActionTx(s.db, action, begin, limit)
	return tasks, next, err
}

// FindQueuedByNode returns the tasks in status QueuedTask that are assigned
// to node.
func (s *TaskStore) FindQueuedByNode(node int) (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "status" = $1 AND "node_id" = $2`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, QueuedTask, node)
	return tasks, err
}

// FindQueuedUnassigned returns the tasks in status QueuedTask that have no
// node assigned (node_id IS NULL).
func (s *TaskStore) FindQueuedUnassigned() (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "status" = $1 AND "node_id" IS NULL`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, QueuedTask)
	return tasks, err
}

// GetAssignedByPlanner returns the tasks of planner that have a node
// assigned, regardless of their status.
func (s *TaskStore) GetAssignedByPlanner(planner int) (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "planner_id" = $1 AND "node_id" IS NOT NULL`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, planner)
	return tasks, err
}

// GetQueuedByPlanner returns the tasks of planner in status QueuedTask,
// whether or not a node is assigned.
func (s *TaskStore) GetQueuedByPlanner(planner int) (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "status" = $1 AND "planner_id" = $2`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, QueuedTask, planner)
	return tasks, err
}

// FindFinishedAssigned returns the tasks that still have a node assigned
// but have already reached SucceededTask, FailedTask or AbortedTask.
func (s *TaskStore) FindFinishedAssigned() (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "status" IN ($1, $2, $3) AND "node_id" IS NOT NULL`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, SucceededTask, FailedTask, AbortedTask)
	return tasks, err
}

// GetLatestByPlanner returns the task of planner with the highest ID.
// When planner has no tasks, sqlx's Get reports an error (presumably
// sql.ErrNoRows; confirm against the sqlx version in use).
func (s *TaskStore) GetLatestByPlanner(planner int) (task Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "planner_id" = $1 ORDER BY "id" DESC LIMIT 1`,
		s.table,
	)
	err = s.dbx.Get(&task, query, planner)
	return task, err
}

// FindByAction returns the 100 most recent tasks (by ID, descending) of
// action.
func (s *TaskStore) FindByAction(action int) (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "action_id" = $1 ORDER BY "id" DESC LIMIT 100`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, action)
	return tasks, err
}

// FindByPlanner returns the 100 most recent tasks (by ID, descending) of
// planner.
func (s *TaskStore) FindByPlanner(planner int) (tasks []Task, err error) {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "planner_id" = $1 ORDER BY "id" DESC LIMIT 100`,
		s.table,
	)
	err = s.dbx.Select(&tasks, query, planner)
	return tasks, err
}

// Reload re-reads the task with ID m.ID and replaces *m with the stored row.
//
// The row is scanned into a fresh Task and copied over *m only on success.
// Scanning straight into m would hand m's already-populated Options map to
// TaskOptions.Scan, and json.Unmarshal keeps existing entries of a non-nil
// map, so options removed in the database could survive the reload. A failed
// query or scan also leaves *m untouched.
func (s *TaskStore) Reload(m *Task) error {
	query := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "id" = $1 LIMIT 1`,
		s.table,
	)
	var fresh Task
	if err := s.dbx.Get(&fresh, query, m.ID); err != nil {
		return err
	}
	*m = fresh
	return nil
}

// KillByNode marks every unfinished task on node nodeID as FailedTask and
// sets its update time to now. Unfinished means QueuedTask, StartingTask,
// RunningTask or AbortingTask.
//
// The table name is quoted as "%s" like every other query in this store.
// The previous %q applied Go string escaping (e.g. a backslash becomes \\),
// which is not SQL identifier quoting.
func (s *TaskStore) KillByNode(nodeID int) error {
	query := fmt.Sprintf(
		`UPDATE "%s" SET`+
			` "status" = $1, "update_time" = $2`+
			` WHERE "node_id" = $3 AND "status" IN ($4, $5, $6, $7)`,
		s.table,
	)
	_, err := s.dbx.Exec(
		query,
		FailedTask, time.Now().Unix(), nodeID,
		QueuedTask, StartingTask, RunningTask, AbortingTask,
	)
	return err
}

// ListByUser returns page number page (0-based, perPage tasks per page) of
// the tasks owned by userID, newest ID first. It also returns n, the total
// number of tasks owned by userID.
//
// The page and the count come from two separate queries that do not share a
// transaction, so they may disagree under concurrent writes. When the page
// query fails, n is left at 0 and the count query is not run.
func (s *TaskStore) ListByUser(
	userID int, page, perPage int,
) (l []Task, n int, err error) {
	listQuery := fmt.Sprintf(
		`SELECT "id", "action_id",`+
			` COALESCE("planner_id", 0) AS "planner_id",`+
			` COALESCE("owner_id", 0) AS "owner_id",`+
			` COALESCE("node_id", 0) AS "node_id",`+
			` "create_time", "update_time", "status", "options"`+
			` FROM "%s"`+
			` WHERE "owner_id" = $1`+
			` ORDER BY "id" DESC LIMIT $2 OFFSET $3`,
		s.table,
	)
	if err = s.dbx.Select(&l, listQuery, userID, perPage, page*perPage); err != nil {
		return l, n, err
	}
	countQuery := fmt.Sprintf(
		`SELECT COUNT(*) FROM "%s"`+
			` WHERE "owner_id" = $1 LIMIT 1`,
		s.table,
	)
	err = s.dbx.Get(&n, countQuery, userID)
	return l, n, err
}

// NewTaskStore creates a store for the tasks kept in table of db.
//
// Besides the gosql handle, the store wraps db.DB in an sqlx handle for the
// "pgx" driver, which the methods with hand-written SQL use.
func NewTaskStore(db *gosql.DB, table string) *TaskStore {
	store := &TaskStore{db: db, table: table}
	store.dbx = sqlx.NewDb(db.DB, "pgx")
	return store
}
