package storage

import (
	"context"
	"fmt"
	"time"

	sq "github.com/Masterminds/squirrel"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/jmoiron/sqlx"

	pb "a.yandex-team.ru/infra/hmserver/proto"
)

// heartbeatSelectChunkLimit is the maximum number of hostnames GetHeartbeats
// places into a single SELECT; longer lists are split into several queries.
const heartbeatSelectChunkLimit = 2000

// Heartbeats looks up host heartbeat records (pb.HostHeartbeat) by FQDN.
type Heartbeats interface {
	// GetHeartbeats returns the heartbeat records found for the given FQDNs.
	GetHeartbeats(ctx context.Context, fqdns ...string) ([]*pb.HostHeartbeat, error)
}

// heartbeats is the database-backed implementation of Heartbeats.
type heartbeats struct {
	// db is the handle every heartbeat query is issued against.
	db *sqlx.DB
}

// NewHeartbeatStorage returns a heartbeat store whose queries run against db.
func NewHeartbeatStorage(db *sqlx.DB) *heartbeats {
	storage := new(heartbeats)
	storage.db = db
	return storage
}

// GetHeartbeats returns a heartbeat record for every hostname in fqdns that
// has a row in the hosts table. Hostnames without a row are absent from the
// result, and the result order is unspecified (the query has no ORDER BY).
//
// The lookup is split into chunks of at most heartbeatSelectChunkLimit
// hostnames, one query per chunk, to keep each statement's IN-list bounded.
// An empty fqdns issues no query and yields an empty, non-nil slice.
func (h *heartbeats) GetHeartbeats(ctx context.Context, fqdns ...string) ([]*pb.HostHeartbeat, error) {
	rv := make([]*pb.HostHeartbeat, 0, len(fqdns))
	// Stepping by the limit (rather than computing len/limit+1 chunks)
	// never produces a trailing empty chunk, so exact multiples of the
	// limit, including zero, do not trigger a useless database round trip.
	for start := 0; start < len(fqdns); start += heartbeatSelectChunkLimit {
		end := start + heartbeatSelectChunkLimit
		if end > len(fqdns) {
			end = len(fqdns)
		}
		chunk, err := getHeartbeats(ctx, h.db, fqdns[start:end]...)
		if err != nil {
			return nil, err
		}
		rv = append(rv, chunk...)
	}
	return rv, nil
}

// getHeartbeats runs a single SELECT of (hostname, report_time) from the
// hosts table for the given hostnames. Each returned row is converted into a
// pb.HostHeartbeat whose LastSeen is the row's report_time. Hostnames with no
// matching row are simply absent from the result. Callers are expected to
// bound len(fqdns); GetHeartbeats chunks by heartbeatSelectChunkLimit.
//
// Errors from building the query, executing it, scanning a row, or iterating
// the result set are wrapped with %w so callers can inspect the cause.
func getHeartbeats(ctx context.Context, db *sqlx.DB, fqdns ...string) ([]*pb.HostHeartbeat, error) {
	query, args, err := sq.Select("hostname", "report_time").From("hosts").
		Where(sq.Eq{"hostname": fqdns}).
		PlaceholderFormat(sq.Dollar).
		ToSql()
	if err != nil {
		return nil, fmt.Errorf("cannot build query: %w", err)
	}
	// Pass the arguments the builder generated alongside the SQL, so the
	// number and order of placeholders always match the values supplied.
	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("cannot execute query: %w", err)
	}
	defer rows.Close()
	rv := make([]*pb.HostHeartbeat, 0, len(fqdns))
	for rows.Next() {
		var hostname string
		var reportTime time.Time
		if err := rows.Scan(&hostname, &reportTime); err != nil {
			return nil, fmt.Errorf("cannot fetch results: %w", err)
		}
		rv = append(rv, &pb.HostHeartbeat{Hostname: hostname, LastSeen: timestamppb.New(reportTime)})
	}
	// rows.Next returns false both at the end of the set and on failure;
	// without this check a mid-stream error would look like a short result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("cannot fetch results: %w", err)
	}
	return rv, nil
}

// asInterface copies the strings of a, in order, into a slice of
// interface{} values, e.g. for spreading into a ...interface{} parameter.
// The result is never nil: a nil or empty input yields an empty slice.
func asInterface(a []string) []interface{} {
	out := make([]interface{}, 0, len(a))
	for _, s := range a {
		out = append(out, s)
	}
	return out
}
