package storage

import (
	"context"
	"fmt"
	"time"

	"github.com/golang/protobuf/ptypes"
	"github.com/jmoiron/sqlx"

	"a.yandex-team.ru/infra/hmserver/pkg/reporter/types"
	yasaltpb "a.yandex-team.ru/infra/hostctl/proto"
	"a.yandex-team.ru/library/go/x/sql/pgx"
)

// Hosts is the storage API for host records and their report metadata.
type Hosts interface {
	// GetHost returns the stored HostInfo for node.
	GetHost(ctx context.Context, node string) (*yasaltpb.HostInfo, error)
	// GetHosts returns hostnames selected by the query built from info,
	// paginated with limit and offset.
	GetHosts(ctx context.Context, info *yasaltpb.HostInfo, limit, offset int32) ([]string, error)
	// UpdateHost executes the update statement built from info.
	UpdateHost(ctx context.Context, info *yasaltpb.HostInfo) error
	// UpsertMeta stores the report time and host/units timestamps for node.
	UpsertMeta(ctx context.Context, node string, reportTime time.Time, hostTS types.HostTS, unitsTS types.UnitsTS) error
	// RemoveHosts runs the removal query for hosts older than ttl (bounded by
	// limit) and returns the hostnames the query yields.
	RemoveHosts(ctx context.Context, ttl time.Duration, limit int) ([]string, error)
	// UpdateReportTime records reportTime for hostname and returns the host
	// and units timestamps, in that order, read back from the database.
	UpdateReportTime(ctx context.Context, hostname string, reportTime time.Time) (time.Time, time.Time, error)
	// StartCursor returns a cursor for paging through stored hosts.
	StartCursor(ctx context.Context) *HostsCursor
	// HostsCount returns the number of stored hosts.
	HostsCount(ctx context.Context) (int, error)
	// KernelVersions returns a map from kernel version to its count.
	KernelVersions(ctx context.Context) (map[string]int32, error)
}

// hosts is the sqlx-backed implementation of Hosts.
type hosts struct {
	// db is the database handle every query runs against.
	db *sqlx.DB
	// queries builds the SQL text and arguments for each operation.
	queries *hostsQueries
}

// NewHosts executes hostsScheme against db (presumably the table DDL — see its
// definition), prepares the host query builders, and returns a Hosts backed by db.
func NewHosts(db *sqlx.DB) (Hosts, error) {
	if _, err := db.Exec(hostsScheme); err != nil {
		return nil, err
	}
	q, err := prepareHostQueries()
	if err != nil {
		return nil, err
	}
	return &hosts{db: db, queries: q}, nil
}

// GetHosts returns the hostnames selected by the query that
// h.queries.getHosts builds from info (presumably used as a filter — see that
// helper), paginated by limit and offset.
func (h *hosts) GetHosts(ctx context.Context, info *yasaltpb.HostInfo, limit, offset int32) ([]string, error) {
	type row struct {
		Node string `db:"hostname"`
	}
	query, args, err := h.queries.getHosts(info, uint64(limit), uint64(offset))
	if err != nil {
		return nil, err
	}
	// SelectContext (not Select) so the caller's cancellation and deadline
	// apply to the query.
	var rows []row
	if err = h.db.SelectContext(ctx, &rows, query, args...); err != nil {
		return nil, err
	}
	names := make([]string, len(rows))
	for i, r := range rows {
		names[i] = r.Node
	}
	return names, nil
}

// rowScanner is the Scan method shared by *sql.Row and *sql.Rows, letting
// scanHostRow serve both single-row (GetHost) and multi-row (cursor) queries.
type rowScanner interface {
	Scan(dest ...interface{}) error
}

// hostsColumns lists the columns selected for a host row. Order matters:
// scanHostRow scans positionally, so any change here must be mirrored there.
var hostsColumns = []string{
	"hostname",      // HostInfo.Hostname
	"num",           // HostInfo.Num
	"walle_project", // HostInfo.WalleProject
	"walle_tags",    // HostInfo.WalleTags (array)
	"net_switch",    // HostInfo.NetSwitch
	"gencfg_groups", // HostInfo.GencfgGroups (array)
	"location",      // HostInfo.Location
	"dc",            // HostInfo.Dc
	"kernel",        // HostInfo.KernelRelease
	"cpu_model",     // HostInfo.CpuModel
	"mem_total_mib", // HostInfo.MemTotalMib
	"host_ts",       // converted to HostInfo.Mtime
	"dc_queue",      // HostInfo.DcQueue
	"os_codename",   // HostInfo.OsCodename
	"os_arch",       // HostInfo.OsArch
}

// scanHostRow reads one row laid out as hostsColumns into a new HostInfo.
// The host_ts column is converted to a protobuf Timestamp and stored as Mtime.
// The Scan destinations below are positional and must match hostsColumns.
func scanHostRow(scanner rowScanner) (*yasaltpb.HostInfo, error) {
	var (
		out    = &yasaltpb.HostInfo{}
		hostTS time.Time // overwritten by a successful Scan
	)
	dest := []interface{}{
		&out.Hostname,
		&out.Num,
		&out.WalleProject,
		pgx.Array(&out.WalleTags),
		&out.NetSwitch,
		pgx.Array(&out.GencfgGroups),
		&out.Location,
		&out.Dc,
		&out.KernelRelease,
		&out.CpuModel,
		&out.MemTotalMib,
		&hostTS,
		&out.DcQueue,
		&out.OsCodename,
		&out.OsArch,
	}
	if scanErr := scanner.Scan(dest...); scanErr != nil {
		return nil, fmt.Errorf("failed to scan row to HostInfo: %w", scanErr)
	}
	mtime, convErr := ptypes.TimestampProto(hostTS)
	if convErr != nil {
		return nil, fmt.Errorf("failed to convert TS to ProtoTS: %w", convErr)
	}
	out.Mtime = mtime
	return out, nil
}

// GetHost runs the single-host query for node and converts the row with
// scanHostRow.
//
// Errors are returned, not printed: a storage layer should not write to
// stdout, and the caller decides how to report them. If the query matches no
// row, the returned error wraps sql.ErrNoRows (from Row.Scan) and can be
// detected with errors.Is.
func (h *hosts) GetHost(ctx context.Context, node string) (*yasaltpb.HostInfo, error) {
	query, args, err := h.queries.getHost(node)
	if err != nil {
		return nil, err
	}
	return scanHostRow(h.db.QueryRowContext(ctx, query, args...))
}

// UpdateHost builds the update statement for p and executes it under ctx.
func (h *hosts) UpdateHost(ctx context.Context, p *yasaltpb.HostInfo) error {
	query, args, err := h.queries.updateHost(p)
	if err == nil {
		err = exec(h.db, ctx, query, args)
	}
	return err
}

// UpsertMeta builds the metadata upsert statement for node from the report
// time and the host/units timestamps, and executes it under ctx.
func (h *hosts) UpsertMeta(ctx context.Context, node string, reportTime time.Time, hostTS types.HostTS, unitsTS types.UnitsTS) error {
	// Note: the builder takes units before host, with the report time last.
	query, args, err := h.queries.upsertMeta(node, unitsTS, hostTS, reportTime)
	if err == nil {
		err = exec(h.db, ctx, query, args)
	}
	return err
}

// UpdateReportTime runs the report-time update for hostname and returns the
// host and units timestamps, in that order, from the first row the statement
// yields. If the statement yields no row, both timestamps are the Unix epoch
// and the error is nil.
func (h *hosts) UpdateReportTime(ctx context.Context, hostname string, reportTime time.Time) (time.Time, time.Time, error) {
	unitsTS := time.Unix(0, 0)
	hostTS := time.Unix(0, 0)
	query, args, err := h.queries.updateReportTime(reportTime, hostname)
	if err != nil {
		return hostTS, unitsTS, err
	}
	rows, err := h.db.QueryxContext(ctx, query, args...)
	if err != nil {
		return hostTS, unitsTS, err
	}
	defer rows.Close()
	if !rows.Next() {
		// Next reports false both for "no rows" and for an iteration/driver
		// failure; only rows.Err distinguishes them, so an error here must
		// not be mistaken for the legitimate empty result.
		return hostTS, unitsTS, rows.Err()
	}
	err = rows.Scan(&hostTS, &unitsTS)
	return hostTS, unitsTS, err
}

// RemoveHosts runs the removal query for hosts older than ttl (the builder
// receives -ttl as an offset from now, presumably) limited to limit rows, and
// returns the hostnames the query yields.
//
// On error the returned slice is nil. A partially scanned result must not be
// mistaken for the set of hosts that were actually removed.
func (h *hosts) RemoveHosts(ctx context.Context, ttl time.Duration, limit int) ([]string, error) {
	query, args, err := h.queries.removeHosts(-ttl, uint64(limit))
	if err != nil {
		return nil, err
	}
	removed := make([]string, 0)
	if err = h.db.SelectContext(ctx, &removed, query, args...); err != nil {
		return nil, err
	}
	return removed, nil
}

// HostsCount returns the value in the first row of the hosts count query.
// It fails if the query returns no rows.
func (h *hosts) HostsCount(ctx context.Context) (int, error) {
	query, args, err := h.queries.hostsCountQuery()
	if err != nil {
		return 0, err
	}
	// sqlx appends every scanned row to the destination slice, so start from
	// an empty slice and read the first element. Pre-seeding the slice would
	// shift the real value to a later index.
	var counts []int
	if err = h.db.SelectContext(ctx, &counts, query, args...); err != nil {
		return 0, err
	}
	if len(counts) == 0 {
		return 0, fmt.Errorf("failed to get count")
	}
	return counts[0], nil
}

// DBKernelVersions is one row of the kernel versions query: a kernel version
// string and its associated count (presumably the number of hosts running it).
type DBKernelVersions struct {
	// Kernel is the kernel version, read from the "kernel" column.
	Kernel string `db:"kernel"`
	// Count is read from the "count" column.
	Count int `db:"count"`
}

// KernelVersions runs the kernel versions query and returns its rows as a
// map from kernel version to count. If a kernel appears in several rows, the
// last row wins.
func (h *hosts) KernelVersions(ctx context.Context) (map[string]int32, error) {
	query, args, err := h.queries.kernelVersionsQuery()
	if err != nil {
		return nil, err
	}
	var rows []DBKernelVersions
	if err = h.db.SelectContext(ctx, &rows, query, args...); err != nil {
		return nil, err
	}
	counts := make(map[string]int32, len(rows))
	for i := range rows {
		counts[rows[i].Kernel] = int32(rows[i].Count)
	}
	return counts, nil
}

// StartCursor returns a cursor at offset 0 whose snapshot time is the moment
// of this call. ctx is currently unused; each Next call takes its own context.
func (h *hosts) StartCursor(ctx context.Context) *HostsCursor {
	cur := &HostsCursor{db: h.db, queries: h.queries}
	cur.ts = time.Now()
	return cur
}

// HostsCursor pages through stored hosts using offset-based queries. Create
// one with Hosts.StartCursor and call Next repeatedly.
type HostsCursor struct {
	// db is the database handle the page queries run against.
	db *sqlx.DB
	// ts is the snapshot time taken at cursor creation and passed to every
	// page query (presumably to bound which hosts are visible — verify in
	// getHostsCursor).
	ts time.Time
	// queries builds the page query for each Next call.
	queries *hostsQueries
	// offset is the row offset of the next page.
	offset int
}

// Next fetches up to limit hosts starting at the cursor's current offset,
// using the snapshot time captured when the cursor was created. On success it
// advances the offset by limit. On error the offset is left unchanged and the
// returned slice is nil.
func (c *HostsCursor) Next(ctx context.Context, limit int) ([]*yasaltpb.HostInfo, error) {
	query, args, err := c.queries.getHostsCursor(uint64(c.offset), uint64(limit), c.ts)
	if err != nil {
		return nil, err
	}
	rows, err := c.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// Closing the result set returns its connection to the pool; it must run
	// on every path, including the early return on a scan error.
	defer rows.Close()

	infos := make([]*yasaltpb.HostInfo, 0)
	for rows.Next() {
		info, scanErr := scanHostRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		infos = append(infos, info)
	}
	// Next returns false both at the end of the result set and on failure;
	// Err reports which one happened.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	c.offset += limit
	return infos, nil
}
