package storage

import (
	"a.yandex-team.ru/infra/hmserver/pkg/reporter/types"
	"context"
	"database/sql"
	"fmt"
	"sort"
	"time"

	"github.com/golang/protobuf/ptypes"
	_ "github.com/jackc/pgx/v4/stdlib"
	"github.com/jmoiron/sqlx"

	pb "a.yandex-team.ru/infra/hmserver/proto"
)

// DBRecord is one stored unit report row. The db tags name the columns sqlx
// maps when selecting into it (GetReports uses sqlx Select); getUnitsByNode
// instead scans columns positionally and never fills ReportTime. Proto copies
// every field except ReportTime into a *pb.Unit.
type DBRecord struct {
	Node           types.Node    `db:"node"`
	Name           types.Unit    `db:"name"`
	Version        types.Version `db:"version"`
	Stage          types.Stage   `db:"stage"`
	Kind           string        `db:"kind"`
	Ready          pb.Status     `db:"ready"`
	Pending        pb.Status     `db:"pending"`
	LastTransition time.Time     `db:"last_transition"`
	ReportTime     time.Time     `db:"report_time"`
}

// Config groups the database connection settings. It is only declared here;
// NOTE(review): MaxIdleConns / MaxConnLifetime presumably feed
// sql.DB.SetMaxIdleConns / SetConnMaxLifetime in the code that opens the
// pool — confirm there.
type Config struct {
	Addr, AdminAddr, DBName string
	User, Password          string
	MaxIdleConns            int
	MaxConnLifetime         time.Duration
}

// Proto converts the row into a *pb.Unit. Node, Name, Version, Stage, Kind,
// Ready, Pending and LastTransition are copied; ReportTime is not. The only
// possible error is the one returned by ptypes.TimestampProto when
// LastTransition cannot be represented as a protobuf Timestamp.
func (r *DBRecord) Proto() (*pb.Unit, error) {
	lastTransition, convErr := ptypes.TimestampProto(r.LastTransition)
	if convErr != nil {
		return nil, convErr
	}
	unit := &pb.Unit{
		Node:    string(r.Node),
		Name:    string(r.Name),
		Version: string(r.Version),
		Stage:   string(r.Stage),
		Kind:    r.Kind,
		Ready:   r.Ready,
		Pending: r.Pending,
	}
	unit.LastTransition = lastTransition
	return unit, nil
}

// reportsProto converts each record with Proto, preserving order. It stops at
// the first failing conversion and returns that error with a nil slice; for
// an empty input it returns a non-nil empty slice.
func reportsProto(dbRecords []*DBRecord) ([]*pb.Unit, error) {
	converted := make([]*pb.Unit, 0, len(dbRecords))
	for _, rec := range dbRecords {
		unit, err := rec.Proto()
		if err != nil {
			return nil, err
		}
		converted = append(converted, unit)
	}
	return converted, nil
}

// Units is the storage API for unit reports collected from nodes.
type Units interface {
	// GetReports returns stored reports filtered by node, unit, stage,
	// version and the ready/pending status lists, paginated by limit and
	// offset. How empty filter values are treated is up to the query builder.
	GetReports(ctx context.Context, node types.Node, unit types.Unit, stage types.Stage, version types.Version, ready []pb.Status, pending []pb.Status, limit int32, offset int32) ([]*pb.Unit, error)
	// GetUnits returns stored unit names grouped by kind.
	GetUnits(ctx context.Context) (map[string][]string, error)
	// RemoveByNodes executes the remove-by-nodes query for the given nodes.
	RemoveByNodes(ctx context.Context, nodes []string) error
	// UpdateHostReports reconciles the stored reports of node with reports.
	UpdateHostReports(ctx context.Context, node string, reports []*pb.Unit) error
	// GetUnitVersions returns a count per version for the named unit.
	GetUnitVersions(ctx context.Context, name types.Unit) (map[types.Version]int, error)
	// GetUnitReady returns a count per Ready status for the named unit.
	GetUnitReady(ctx context.Context, name types.Unit) (map[pb.Status]int, error)
	// GetUnitPending returns a count per Pending status for the named unit.
	GetUnitPending(ctx context.Context, name types.Unit) (map[pb.Status]int, error)
	// GetUnitStages returns a count per stage for the named unit.
	GetUnitStages(ctx context.Context, name types.Unit) (map[types.Stage]int, error)
}

// units is the SQL-backed implementation of Units: db executes the statements
// and unitsQueries builds their text and arguments. NewUnits constructs it
// with a positional composite literal, so the field order is significant.
type units struct {
	db           *sqlx.DB
	unitsQueries *unitsQueries
}

// NewUnits executes the unitsScheme DDL on db, prepares the query builders
// and returns a Units backed by db. Errors from either step are returned
// unchanged.
func NewUnits(db *sqlx.DB) (Units, error) {
	if _, schemeErr := db.Exec(unitsScheme); schemeErr != nil {
		return nil, schemeErr
	}
	prepared, prepErr := prepareUnitsQueries()
	if prepErr != nil {
		return nil, prepErr
	}
	return &units{db: db, unitsQueries: prepared}, nil
}

// UpdateHostReports reconciles the stored reports of node with incomingUnits.
//
// Inside one transaction bound to ctx, it:
//   - loads the node's current reports;
//   - deletes those whose Name is absent from incomingUnits;
//   - upserts the incoming units that are new, or that differ (per
//     UnitChanged) from their stored counterpart.
//
// The transaction is committed only when every step succeeds and is rolled
// back otherwise. An empty incomingUnits is a no-op. incomingUnits itself is
// not reordered; a sorted copy is used.
func (s *units) UpdateHostReports(ctx context.Context, node string, incomingUnits []*pb.Unit) error {
	// No need to form empty SQL request, which fails at the moment.
	if len(incomingUnits) == 0 {
		return nil
	}
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op returning sql.ErrTxDone,
	// so this only undoes work on the error paths below.
	defer func() { _ = tx.Rollback() }()

	oldUnits, err := s.getUnitsByNode(ctx, tx, node)
	if err != nil {
		return err
	}
	// getUnitsToExecute does a merge walk that requires both lists in
	// ascending CompareUnit order. Sort a copy so the caller's slice keeps
	// its order.
	incoming := make([]*pb.Unit, len(incomingUnits))
	copy(incoming, incomingUnits)
	sortByName := func(list []*pb.Unit) {
		sort.Slice(list, func(i, j int) bool {
			return CompareUnit(list[i], list[j]) < 0
		})
	}
	sortByName(oldUnits)
	sortByName(incoming)

	toUpsert, toRemove := getUnitsToExecute(oldUnits, incoming)
	if len(toRemove) > 0 {
		names := make([]string, len(toRemove))
		for i, unit := range toRemove {
			names[i] = unit.Name
		}
		if err := s.removeByNodeNames(ctx, tx, node, names); err != nil {
			return fmt.Errorf("failed to remove reports: %w", err)
		}
	}
	if len(toUpsert) > 0 {
		if err := s.upsertUnits(ctx, tx, toUpsert); err != nil {
			return fmt.Errorf("failed to upsert reports: %w", err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit reports: %w", err)
	}
	return nil
}

// getUnitsByNode loads all stored reports of node within tx and converts them
// to *pb.Unit. Columns are scanned positionally, so the query built by
// unitsQueries.getUnitsByNode must select node, name, version, stage, kind,
// ready, pending and last_transition in exactly that order. ReportTime is not
// read.
func (s *units) getUnitsByNode(ctx context.Context, tx *sql.Tx, node string) ([]*pb.Unit, error) {
	query, args, err := s.unitsQueries.getUnitsByNode(node)
	if err != nil {
		return nil, err
	}
	rows, err := tx.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	reports := make([]*DBRecord, 0)
	for rows.Next() {
		rep := &DBRecord{}
		if err := rows.Scan(&rep.Node, &rep.Name, &rep.Version, &rep.Stage, &rep.Kind, &rep.Ready, &rep.Pending, &rep.LastTransition); err != nil {
			return nil, err
		}
		reports = append(reports, rep)
	}
	// rows.Next returns false both at the end of the result set and on an
	// iteration failure; only rows.Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return reportsProto(reports)
}

// removeByNodeNames builds the delete statement for the given unit names of
// node and executes it inside tx. Any error from building or executing the
// statement is returned unchanged.
func (s *units) removeByNodeNames(ctx context.Context, tx *sql.Tx, node string, names []string) error {
	stmt, stmtArgs, buildErr := s.unitsQueries.removeByNodeNames(node, names)
	if buildErr == nil {
		_, buildErr = tx.ExecContext(ctx, stmt, stmtArgs...)
	}
	return buildErr
}

// upsertUnits builds the upsert statement for the given units and executes it
// inside tx. Unlike the other statements in this file, its text is passed
// through ReplaceSQL(query, "?") first. NOTE(review): presumably this rewrites
// "?" placeholders into the driver's positional form; confirm in ReplaceSQL.
func (s *units) upsertUnits(ctx context.Context, tx *sql.Tx, batch []*pb.Unit) error {
	rawQuery, args, err := s.unitsQueries.upsertUnits(batch)
	if err != nil {
		return err
	}
	finalQuery := ReplaceSQL(rawQuery, "?")
	if _, execErr := tx.ExecContext(ctx, finalQuery, args...); execErr != nil {
		return execErr
	}
	return nil
}

// getUnitsToExecute diffs the stored units against the incoming ones with a
// single merge pass. Both slices must be sorted in ascending CompareUnit
// order (by Name) for the walk to pair equal names.
//
// Stored units with no incoming counterpart go to toRemove. Incoming units
// with no stored counterpart go to toUpsert. Units present on both sides go
// to toUpsert only when UnitChanged reports a difference.
func getUnitsToExecute(oldUnits, incomingUnits []*pb.Unit) (toUpsert, toRemove []*pb.Unit) {
	toUpsert = make([]*pb.Unit, 0, len(incomingUnits))
	toRemove = make([]*pb.Unit, 0, len(oldUnits))
	oi, ni := 0, 0
	for oi < len(oldUnits) && ni < len(incomingUnits) {
		stored, fresh := oldUnits[oi], incomingUnits[ni]
		switch order := CompareUnit(stored, fresh); {
		case order < 0:
			toRemove = append(toRemove, stored)
			oi++
		case order > 0:
			toUpsert = append(toUpsert, fresh)
			ni++
		default:
			if UnitChanged(stored, fresh) {
				toUpsert = append(toUpsert, fresh)
			}
			oi++
			ni++
		}
	}
	// Whatever remains on either side has no counterpart on the other.
	toRemove = append(toRemove, oldUnits[oi:]...)
	toUpsert = append(toUpsert, incomingUnits[ni:]...)
	return toUpsert, toRemove
}

// CompareUnit orders units by Name alone. It returns -1, 0 or +1 when a.Name
// is respectively less than, equal to, or greater than b.Name, using Go's
// byte-wise string ordering.
func CompareUnit(a, b *pb.Unit) int {
	switch {
	case a.Name < b.Name:
		return -1
	case a.Name > b.Name:
		return 1
	default:
		return 0
	}
}

// UnitChanged reports whether b differs from a in Ready, Pending, Version,
// Stage, Kind or LastTransition. Name and Node are not compared;
// getUnitsToExecute calls it only for units whose Names are equal.
func UnitChanged(a, b *pb.Unit) bool {
	if a.Ready != b.Ready || a.Pending != b.Pending {
		return true
	}
	if a.Version != b.Version || a.Stage != b.Stage || a.Kind != b.Kind {
		return true
	}
	// LastTransition is a message pointer. Compare it by value, because two
	// separately decoded timestamps are never the same pointer.
	ta, tb := a.LastTransition, b.LastTransition
	if (ta == nil) != (tb == nil) {
		return true
	}
	// The generated getters are nil-safe, so both-nil compares equal here.
	// NOTE(review): if last_transition is stored with coarser precision than
	// nanoseconds (e.g. Postgres microseconds), finer incoming timestamps will
	// still compare as changed — consider truncating; confirm column type.
	return ta.GetSeconds() != tb.GetSeconds() || ta.GetNanos() != tb.GetNanos()
}

// GetReports returns stored reports matching the node, unit name, stage and
// version filters and the ready/pending status lists, paginated by limit and
// offset. The SQL is built by unitsQueries.getUnitsReports, which decides how
// empty filter values are treated. The query honours ctx cancellation.
//
// NOTE(review): limit and offset are converted with uint64(...), so a
// negative value wraps to a huge number; confirm callers never pass one.
func (s *units) GetReports(ctx context.Context, n types.Node, u types.Unit, stage types.Stage, version types.Version, ready []pb.Status, pending []pb.Status, limit, offset int32) ([]*pb.Unit, error) {
	filter := &pb.Unit{
		Node:    string(n),
		Name:    string(u),
		Version: string(version),
		Stage:   string(stage),
	}
	query, args, err := s.unitsQueries.getUnitsReports(filter, ready, pending, uint64(limit), uint64(offset))
	if err != nil {
		return nil, err
	}
	reports := make([]*DBRecord, 0)
	if err := s.db.SelectContext(ctx, &reports, query, args...); err != nil {
		return nil, err
	}
	return reportsProto(reports)
}

// GetUnits returns stored unit names grouped by kind. Each map key is a Kind
// value and each value lists the Name of every returned row of that kind, in
// result-set order. The query honours ctx cancellation.
func (s *units) GetUnits(ctx context.Context) (map[string][]string, error) {
	query, args, err := s.unitsQueries.getUnits()
	if err != nil {
		return nil, err
	}
	type resp struct {
		Name string `db:"name"`
		Kind string `db:"kind"`
	}
	res := make([]*resp, 0)
	if err := s.db.SelectContext(ctx, &res, query, args...); err != nil {
		return nil, err
	}
	// Named byKind so the local does not shadow the units type.
	byKind := make(map[string][]string)
	for _, r := range res {
		// append allocates on a nil slice, so no explicit initialisation.
		byKind[r.Kind] = append(byKind[r.Kind], r.Name)
	}
	return byKind, nil
}

// RemoveByNodes builds the remove-by-nodes statement for nodes and runs it
// through the file's exec helper against s.db. Build and execution errors are
// returned unchanged.
func (s *units) RemoveByNodes(ctx context.Context, nodes []string) error {
	stmt, stmtArgs, buildErr := s.unitsQueries.removeByNodes(nodes)
	if buildErr == nil {
		buildErr = exec(s.db, ctx, stmt, stmtArgs)
	}
	return buildErr
}

// DBVersionCount is one row of the aggregate selected by GetUnitVersions: a
// version and its count, mapped by the db tags. NOTE(review): presumably the
// query groups by version; the SQL lives in unitsQueries.getUnitVersions.
type DBVersionCount struct {
	Version types.Version `db:"version"`
	Count   int           `db:"count"`
}

// GetUnitVersions returns, for the unit called name, the count of each
// version as produced by the getUnitVersions query. If a version appears in
// more than one row, the last row wins. The query honours ctx cancellation.
func (s *units) GetUnitVersions(ctx context.Context, name types.Unit) (map[types.Version]int, error) {
	query, args, err := s.unitsQueries.getUnitVersions(string(name))
	if err != nil {
		return nil, err
	}
	dbResp := make([]*DBVersionCount, 0)
	if err := s.db.SelectContext(ctx, &dbResp, query, args...); err != nil {
		return nil, err
	}
	versionsCount := make(map[types.Version]int, len(dbResp))
	for _, r := range dbResp {
		versionsCount[r.Version] = r.Count
	}
	return versionsCount, nil
}

// DBReadyCount is one row of the aggregate selected by GetUnitReady: a Ready
// status and its count, mapped by the db tags. NOTE(review): presumably the
// query groups by ready; the SQL lives in unitsQueries.getUnitReady.
type DBReadyCount struct {
	Ready pb.Status `db:"ready"`
	Count int       `db:"count"`
}

// GetUnitReady returns, for the unit called name, the count of each Ready
// status as produced by the getUnitReady query. If a status appears in more
// than one row, the last row wins. The query honours ctx cancellation.
func (s *units) GetUnitReady(ctx context.Context, name types.Unit) (map[pb.Status]int, error) {
	query, args, err := s.unitsQueries.getUnitReady(string(name))
	if err != nil {
		return nil, err
	}
	dbResp := make([]*DBReadyCount, 0)
	if err := s.db.SelectContext(ctx, &dbResp, query, args...); err != nil {
		return nil, err
	}
	readyCount := make(map[pb.Status]int, len(dbResp))
	for _, r := range dbResp {
		readyCount[r.Ready] = r.Count
	}
	return readyCount, nil
}

// DBPendingCount is one row of the aggregate selected by GetUnitPending: a
// Pending status and its count, mapped by the db tags. NOTE(review):
// presumably the query groups by pending; the SQL lives in
// unitsQueries.getUnitPending.
type DBPendingCount struct {
	Pending pb.Status `db:"pending"`
	Count   int       `db:"count"`
}

// GetUnitPending returns, for the unit called name, the count of each Pending
// status as produced by the getUnitPending query. If a status appears in more
// than one row, the last row wins. The query honours ctx cancellation.
func (s *units) GetUnitPending(ctx context.Context, name types.Unit) (map[pb.Status]int, error) {
	query, args, err := s.unitsQueries.getUnitPending(string(name))
	if err != nil {
		return nil, err
	}
	dbResp := make([]*DBPendingCount, 0)
	if err := s.db.SelectContext(ctx, &dbResp, query, args...); err != nil {
		return nil, err
	}
	pendingCount := make(map[pb.Status]int, len(dbResp))
	for _, r := range dbResp {
		pendingCount[r.Pending] = r.Count
	}
	return pendingCount, nil
}

// DBStagesCount is one row of the aggregate selected by GetUnitStages: a stage
// and its count, mapped by the db tags. Stage is a plain string here and is
// converted to types.Stage by GetUnitStages.
type DBStagesCount struct {
	Stage string `db:"stage"`
	Count int    `db:"count"`
}

// GetUnitStages returns, for the unit called name, the count of each stage as
// produced by the getUnitStages query. If a stage appears in more than one
// row, the last row wins. The query honours ctx cancellation.
func (s *units) GetUnitStages(ctx context.Context, name types.Unit) (map[types.Stage]int, error) {
	query, args, err := s.unitsQueries.getUnitStages(string(name))
	if err != nil {
		return nil, err
	}
	dbResp := make([]*DBStagesCount, 0)
	if err := s.db.SelectContext(ctx, &dbResp, query, args...); err != nil {
		return nil, err
	}
	stagesCount := make(map[types.Stage]int, len(dbResp))
	for _, r := range dbResp {
		stagesCount[types.Stage(r.Stage)] = r.Count
	}
	return stagesCount, nil
}
