package storage

import (
	"fmt"
	"strings"
	"time"

	sq "github.com/Masterminds/squirrel"
	"github.com/golang/protobuf/ptypes"

	"a.yandex-team.ru/infra/hmserver/pkg/reporter/types"
	"a.yandex-team.ru/infra/hmserver/pkg/squtil"
	pb "a.yandex-team.ru/infra/hmserver/proto"
	yasaltpb "a.yandex-team.ru/infra/hostctl/proto"
	"a.yandex-team.ru/library/go/x/sql/pgx"
)

// replaceSQL adapts the (sql, args, err) triple returned by squirrel's ToSql.
// On success it passes q through ReplaceSQL(q, "?"), presumably rewriting the
// "?" placeholders into the driver's placeholder format (see ReplaceSQL).
// When err is non-nil, all three values are returned unchanged.
func replaceSQL(q string, args []interface{}, err error) (string, []interface{}, error) {
	if err == nil {
		q = ReplaceSQL(q, "?")
	}
	return q, args, err
}

// SQLQuery renders a SQL statement for a given set of call-site arguments.
// It returns the statement text, the positional arguments for its
// placeholders, and an error if the arguments do not match what the query
// expects.
type SQLQuery func(params ...interface{}) (string, []interface{}, error)

// getUnitsByNodeQuery prerenders a "select all unit rows for one node" query.
// The empty string passed to sq.Eq is only a stand-in that makes squirrel
// emit a placeholder. The real node value is the single string argument
// supplied when the returned SQLQuery is called.
func getUnitsByNodeQuery() (SQLQuery, error) {
	builder := sq.Select("*").From("units").Where(sq.Eq{"node": ""})
	rendered, _, err := builder.ToSql()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 1 {
			return "", nil, fmt.Errorf("len(args)=%d should be 1 [node]", n)
		}
		if _, isString := args[0].(string); !isString {
			return "", nil, fmt.Errorf("args[0] node should be string")
		}
		return replaceSQL(rendered, args, nil)
	}
	return bind, nil
}

// getUnitsQuery prerenders a query listing each distinct (name, kind) pair in
// the units table. The returned SQLQuery takes no arguments.
func getUnitsQuery() (SQLQuery, error) {
	rendered, _, err := sq.Select("name", "kind").
		Distinct().
		From("units").
		ToSql()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 0 {
			return "", nil, fmt.Errorf("len(args)=%d should be 0", n)
		}
		return replaceSQL(rendered, args, nil)
	}
	return bind, nil
}

// unitsCountGroupedByQuery prerenders a query that counts the units rows with
// a given name, grouped by column. The selected columns are count(*) and
// column. The empty string passed to sq.Eq only makes squirrel emit a
// placeholder. The returned SQLQuery expects exactly one string argument, the
// unit name.
func unitsCountGroupedByQuery(column string) (SQLQuery, error) {
	rendered, _, err := sq.Select("count(*)", column).
		From("units").
		Where(sq.Eq{"name": ""}).
		GroupBy(column).
		ToSql()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 1 {
			return "", nil, fmt.Errorf("len(args)=%d should be 1 [name]", n)
		}
		if _, isString := args[0].(string); !isString {
			return "", nil, fmt.Errorf("args[0] name should be string")
		}
		return replaceSQL(rendered, args, nil)
	}
	return bind, nil
}

// getUnitVersionsQuery counts the units with a given name per version.
func getUnitVersionsQuery() (SQLQuery, error) {
	return unitsCountGroupedByQuery("version")
}

// getUnitReadyQuery counts the units with a given name per ready value.
func getUnitReadyQuery() (SQLQuery, error) {
	return unitsCountGroupedByQuery("ready")
}

// getUnitPendingQuery counts the units with a given name per pending value.
func getUnitPendingQuery() (SQLQuery, error) {
	return unitsCountGroupedByQuery("pending")
}

// getUnitStagesQuery counts the units with a given name per stage.
func getUnitStagesQuery() (SQLQuery, error) {
	return unitsCountGroupedByQuery("stage")
}

// removeByNodeNamesQuery builds a DELETE of the units rows on one node whose
// name is in a given list.
//
// Expected args: node (string), names ([]string).
func removeByNodeNamesQuery() (SQLQuery, error) {
	// sq.Eq with a slice value expands to an IN list whose length depends on
	// len(names), so the statement is rendered on every call.
	build := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 2 {
			return "", nil, fmt.Errorf("len(args)=%d should be 2 [node, names]", n)
		}
		node, isString := args[0].(string)
		if !isString {
			return "", nil, fmt.Errorf("args[0] node should be string")
		}
		names, isSlice := args[1].([]string)
		if !isSlice {
			return "", nil, fmt.Errorf("args[1] name should be []string")
		}
		cond := sq.And{sq.Eq{"node": node}, sq.Eq{"name": names}}
		return replaceSQL(sq.Delete("units").Where(cond).ToSql())
	}
	return build, nil
}

// removeByNodesQuery builds a DELETE of all units rows whose node is in a
// given list.
//
// Expected args: nodes ([]string).
func removeByNodesQuery() (SQLQuery, error) {
	// The IN list produced by sq.Eq grows with len(nodes), so the statement
	// is rendered on every call.
	build := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 1 {
			return "", nil, fmt.Errorf("len(args)=%d should be 1 [nodes]", n)
		}
		nodes, isSlice := args[0].([]string)
		if !isSlice {
			return "", nil, fmt.Errorf("args[0] node should be []string")
		}
		stmt := sq.Delete("units").Where(sq.Eq{"node": nodes})
		return replaceSQL(stmt.ToSql())
	}
	return build, nil
}

// upsertUnitsQuery builds an upsert into the units table with one VALUES
// tuple per unit. Each tuple is (Node, Name, Version, Stage, Kind,
// int(Ready), int(Pending), LastTransition as time.Time).
//
// Expected args: units ([]*pb.Unit). Every element must be non-nil and carry
// a LastTransition that ptypes.Timestamp accepts.
//
// NOTE(review): this relies on squtil's builder accumulating values in place,
// because the result of builder.Values is not used. It also does not check
// for an empty units slice; what squtil renders then is not visible here.
// Confirm both.
func upsertUnitsQuery() (SQLQuery, error) {
	// The number of VALUES tuples depends on len(units), so the query cannot
	// be prerendered; it is built on every call.
	return func(args ...interface{}) (string, []interface{}, error) {
		if len(args) != 1 {
			return "", nil, fmt.Errorf("len(args)=%d should be 1 [units]", len(args))
		}
		units, ok := args[0].([]*pb.Unit)
		if !ok {
			return "", nil, fmt.Errorf("args[0] units should be []*pb.Unit")
		}
		builder := squtil.Upsert("units")
		for i, row := range units {
			if row == nil {
				// Reading fields of a nil *pb.Unit below would panic.
				return "", nil, fmt.Errorf("args[0] units[%d] is nil", i)
			}
			lastTransition, err := ptypes.Timestamp(row.LastTransition)
			if err != nil {
				return "", nil, fmt.Errorf("unit %q on node %q: converting last_transition: %w", row.Name, row.Node, err)
			}
			builder.Values(row.Node, row.Name, row.Version, row.Stage, row.Kind, int(row.Ready), int(row.Pending), lastTransition)
		}
		return replaceSQL(builder.ToSQL())
	}, nil
}

// getUnitsReportsQuery builds a paginated SELECT * over the units table.
//
// Expected args, in order:
//   - unit (*pb.Unit, non-nil): each non-empty Node, Name, Stage and Version
//     field adds an equality filter; empty fields are not filtered on;
//   - ready ([]pb.Status): allowed values of the ready column;
//   - pending ([]pb.Status): allowed values of the pending column;
//   - limit, offset (uint64): pagination.
//
// NOTE(review): ready and pending are always filtered on. In squirrel, an
// empty slice in sq.Eq yields a condition that matches no rows. Confirm that
// callers always pass at least one status.
func getUnitsReportsQuery() (SQLQuery, error) {
	// The WHERE clause depends on which unit fields are set and on the
	// lengths of ready/pending, so the query cannot be prerendered; it is
	// built on every call.
	return func(args ...interface{}) (string, []interface{}, error) {
		if len(args) != 5 {
			return "", nil, fmt.Errorf("len(args)=%d should be 5 [unit, []ready, []pending, limit, offset]", len(args))
		}
		unit, ok := args[0].(*pb.Unit)
		if !ok {
			return "", nil, fmt.Errorf("args[0] unit should be *pb.Unit")
		}
		if unit == nil {
			// A typed nil pointer passes the assertion above but would panic
			// on the field reads below.
			return "", nil, fmt.Errorf("args[0] unit should be non-nil *pb.Unit")
		}
		ready, ok := args[1].([]pb.Status)
		if !ok {
			return "", nil, fmt.Errorf("args[1] ready should be []pb.Status")
		}
		pending, ok := args[2].([]pb.Status)
		if !ok {
			return "", nil, fmt.Errorf("args[2] pending should be []pb.Status")
		}
		limit, ok := args[3].(uint64)
		if !ok {
			return "", nil, fmt.Errorf("args[3] limit should be uint64")
		}
		offset, ok := args[4].(uint64)
		if !ok {
			return "", nil, fmt.Errorf("args[4] offset should be uint64")
		}
		queryBuilder := sq.Select("*").From("units")
		if unit.Node != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"node": unit.Node})
		}
		if unit.Name != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"name": unit.Name})
		}
		if unit.Stage != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"stage": unit.Stage})
		}
		if unit.Version != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"version": unit.Version})
		}
		queryBuilder = queryBuilder.Where(sq.Eq{"ready": ready})
		queryBuilder = queryBuilder.Where(sq.Eq{"pending": pending})
		queryBuilder = queryBuilder.Limit(limit)
		queryBuilder = queryBuilder.Offset(offset)
		return replaceSQL(queryBuilder.ToSql())
	}, nil
}

// getHostQuery prerenders a SELECT of hostsColumns for the hosts row with a
// given hostname. The empty string passed to sq.Eq only makes squirrel emit a
// placeholder. The returned SQLQuery expects exactly one string argument, the
// hostname.
func getHostQuery() (SQLQuery, error) {
	builder := sq.Select(hostsColumns...).
		From("hosts").
		Where(sq.Eq{"hostname": ""})
	rendered, _, err := builder.ToSql()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 1 {
			return "", nil, fmt.Errorf("len(args)=%d should be 1 [node]", n)
		}
		if _, isString := args[0].(string); !isString {
			return "", nil, fmt.Errorf("args[0] node should be string")
		}
		return replaceSQL(rendered, args, nil)
	}
	return bind, nil
}

// updateHostQuery prerenders an UPDATE of the descriptive columns of the
// hosts row with a given hostname. The values passed to Set() at render time
// are stand-ins that only make squirrel emit placeholders. The real values
// are taken from the host_info argument when the returned SQLQuery is called.
//
// Expected args: host_info (*yasaltpb.HostInfo, non-nil).
func updateHostQuery() (SQLQuery, error) {
	query, _, err := sq.Update("hosts").
		Set("num", 0).
		Set("walle_project", "").
		Set("walle_tags", pgx.Array(make([]string, 0))).
		Set("net_switch", "").
		Set("gencfg_groups", pgx.Array(make([]string, 0))).
		Set("location", "").
		Set("dc", "").
		Set("kernel", "").
		Set("cpu_model", "").
		Set("mem_total_mib", 0).
		Set("dc_queue", "").
		Set("os_codename", "").
		Set("os_arch", "").
		Where(sq.Eq{"hostname": ""}).
		ToSql()
	if err != nil {
		return nil, err
	}
	return func(args ...interface{}) (string, []interface{}, error) {
		if len(args) != 1 {
			return "", nil, fmt.Errorf("len(args)=%d should be 1 [host_info]", len(args))
		}
		info, ok := args[0].(*yasaltpb.HostInfo)
		if !ok {
			return "", nil, fmt.Errorf("args[0] host_info should be *yasaltpb.HostInfo")
		}
		if info == nil {
			// A typed nil pointer passes the assertion above but would panic
			// on the field reads below.
			return "", nil, fmt.Errorf("args[0] host_info should be non-nil *yasaltpb.HostInfo")
		}
		sqlArgs := []interface{}{
			// Query args MUST follow the order of Set()s above
			info.Num, info.WalleProject, info.WalleTags, info.NetSwitch, info.GencfgGroups,
			info.Location, info.Dc, info.KernelRelease, info.CpuModel, info.MemTotalMib, info.DcQueue, info.OsCodename, info.OsArch,
			// Hostname arg MUST be the last one
			info.Hostname}
		return replaceSQL(query, sqlArgs, nil)
	}, nil
}

// upsertMetaQuery prerenders an upsert into hosts of the columns hostname,
// units_ts, host_ts and report_time. The values given to Values() at render
// time are stand-ins that only produce placeholders.
//
// Expected args, in order: node (string), units_ts (types.UnitsTS),
// host_ts (types.HostTS), report_time (time.Time).
func upsertMetaQuery() (SQLQuery, error) {
	rendered, _, err := squtil.Upsert("hosts").
		Columns("hostname", "units_ts", "host_ts", "report_time").
		Values("", time.Now(), time.Now(), time.Now()).
		ToSQL()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 4 {
			return "", nil, fmt.Errorf("len(args)=%d should be 4 [node, units_ts, host_ts, report_time]", n)
		}
		_, nodeOK := args[0].(string)
		_, unitsTSOK := args[1].(types.UnitsTS)
		_, hostTSOK := args[2].(types.HostTS)
		_, reportTimeOK := args[3].(time.Time)
		// Report the first mismatching argument, in positional order.
		switch {
		case !nodeOK:
			return "", nil, fmt.Errorf("args[0] node should be string")
		case !unitsTSOK:
			return "", nil, fmt.Errorf("args[1] units_ts should be types.UnitsTS")
		case !hostTSOK:
			return "", nil, fmt.Errorf("args[2] host_ts should be types.HostTS")
		case !reportTimeOK:
			return "", nil, fmt.Errorf("args[3] report_time should be time.Time")
		}
		return replaceSQL(rendered, args, nil)
	}
	return bind, nil
}

// updateReportTimeQuery prerenders an UPDATE that sets report_time on the
// hosts row with a given hostname and returns that row's host_ts and units_ts.
// The time.Now() and "" render-time values are stand-ins for placeholders.
//
// Expected args, in order: report_time (time.Time), node (string).
func updateReportTimeQuery() (SQLQuery, error) {
	builder := sq.Update("hosts").
		Set("report_time", time.Now()).
		Where(sq.Eq{"hostname": ""}).
		Suffix("RETURNING host_ts, units_ts")
	rendered, _, err := builder.ToSql()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 2 {
			return "", nil, fmt.Errorf("len(args)=%d should be 2 [report_time, node]", n)
		}
		_, reportTimeOK := args[0].(time.Time)
		_, nodeOK := args[1].(string)
		switch {
		case !reportTimeOK:
			return "", nil, fmt.Errorf("args[0] report_time should be time.Time")
		case !nodeOK:
			return "", nil, fmt.Errorf("args[1] node should be string")
		}
		return replaceSQL(rendered, args, nil)
	}
	return bind, nil
}

// getHostsQuery builds a paginated query that selects hostnames from the
// hosts table.
//
// Expected args, in order:
//   - host_info (*yasaltpb.HostInfo, non-nil): each set field adds a filter.
//     Hostname is a literal prefix match. Num is an upper bound (num <= Num).
//     WalleTags and GencfgGroups match rows whose array column shares at
//     least one element with the requested list. WalleProject, NetSwitch,
//     Location, Dc and KernelRelease are equality filters;
//   - limit, offset (uint64): pagination.
//
// Filter values are always passed as query arguments and never spliced into
// the SQL text.
func getHostsQuery() (SQLQuery, error) {
	// Escapes LIKE metacharacters so that the requested hostname is matched
	// as a literal prefix. Backslash is PostgreSQL's default LIKE escape
	// character.
	// NOTE(review): previously '%' and '_' inside Hostname acted as
	// wildcards; confirm no caller relied on that.
	likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
	// The WHERE clause depends on which filter fields are set, so the query
	// cannot be prerendered; it is built on every call.
	return func(args ...interface{}) (string, []interface{}, error) {
		if len(args) != 3 {
			return "", nil, fmt.Errorf("len(args)=%d should be 3 [host_info, limit, offset]", len(args))
		}
		info, ok := args[0].(*yasaltpb.HostInfo)
		if !ok {
			return "", nil, fmt.Errorf("args[0] host_info should be *yasaltpb.HostInfo")
		}
		if info == nil {
			// A typed nil pointer passes the assertion above but would panic
			// on the field reads below.
			return "", nil, fmt.Errorf("args[0] host_info should be non-nil *yasaltpb.HostInfo")
		}
		limit, ok := args[1].(uint64)
		if !ok {
			return "", nil, fmt.Errorf("args[1] limit should be uint64")
		}
		offset, ok := args[2].(uint64)
		if !ok {
			return "", nil, fmt.Errorf("args[2] offset should be uint64")
		}
		queryBuilder := sq.Select("hostname").From("hosts")
		if info.Hostname != "" {
			queryBuilder = queryBuilder.Where(sq.Like{"hostname": likeEscaper.Replace(info.Hostname) + "%"})
		}
		if info.Num > 0 {
			queryBuilder = queryBuilder.Where(sq.LtOrEq{"num": info.Num})
		}
		if info.WalleProject != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"walle_project": info.WalleProject})
		}
		if len(info.WalleTags) > 0 {
			// Rows whose walle_tags overlap the requested tags (array && array).
			// The tags are bound as a single []string argument instead of being
			// formatted into an array['...'] literal. The literal form broke on
			// quotes and allowed SQL injection.
			queryBuilder = queryBuilder.Where(sq.Expr("walle_tags && ?", info.WalleTags))
		}
		if info.NetSwitch != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"net_switch": info.NetSwitch})
		}
		if len(info.GencfgGroups) > 0 {
			// Rows whose gencfg_groups overlap the requested groups; bound as
			// a []string argument for the same reason as walle_tags.
			queryBuilder = queryBuilder.Where(sq.Expr("gencfg_groups && ?", info.GencfgGroups))
		}
		if info.Location != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"location": info.Location})
		}
		if info.Dc != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"dc": info.Dc})
		}
		if info.KernelRelease != "" {
			queryBuilder = queryBuilder.Where(sq.Eq{"kernel": info.KernelRelease})
		}
		queryBuilder = queryBuilder.Limit(limit).Offset(offset)
		return replaceSQL(queryBuilder.ToSql())
	}, nil
}

// removeHostsQuery builds a DELETE of at most limit hosts rows whose
// report_time is earlier than time.Now().Add(ttl). The statement returns the
// hostname of every deleted row.
//
// Expected args, in order: ttl (time.Duration), limit (uint64).
//
// NOTE(review): the cutoff is now+ttl. Unless callers pass a negative ttl,
// this matches every host reported before the present moment. Confirm the
// sign convention.
// NOTE(review): stock PostgreSQL does not accept LIMIT on DELETE. Confirm
// that the target database supports it.
func removeHostsQuery() (SQLQuery, error) {
	build := func(args ...interface{}) (string, []interface{}, error) {
		if n := len(args); n != 2 {
			return "", nil, fmt.Errorf("len(args)=%d should be 2 [ttl, limit]", n)
		}
		ttl, isDuration := args[0].(time.Duration)
		if !isDuration {
			return "", nil, fmt.Errorf("args[0] ttl should be time.Duration")
		}
		limit, isUint := args[1].(uint64)
		if !isUint {
			return "", nil, fmt.Errorf("args[1] limit should be uint64")
		}
		cutoff := time.Now().Add(ttl)
		stmt := sq.Delete("hosts").
			Where(sq.Lt{"report_time": cutoff}).
			Limit(limit).
			Suffix("RETURNING (hostname)")
		return replaceSQL(stmt.ToSql())
	}
	return build, nil
}

// hostsCountQuery prerenders a count(*) over the hosts table, passed through
// replaceSQL. The returned SQLQuery does not validate its arguments; it
// returns them unchanged alongside the fixed query text.
func hostsCountQuery() (SQLQuery, error) {
	raw, rawArgs, err := sq.Select("count(*)").From("hosts").ToSql()
	rendered, _, err := replaceSQL(raw, rawArgs, err)
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		return rendered, args, nil
	}
	return bind, nil
}

// kernelVersionsQuery prerenders a per-kernel host count: the selected
// columns are kernel and count(*), grouped by kernel. The query has no
// placeholders and is not passed through replaceSQL. The returned SQLQuery
// does not validate its arguments; it returns them unchanged alongside the
// fixed query text.
func kernelVersionsQuery() (SQLQuery, error) {
	builder := sq.Select("kernel", "count(*)").
		From("hosts").
		GroupBy("kernel")
	rendered, _, err := builder.ToSql()
	if err != nil {
		return nil, err
	}
	bind := func(args ...interface{}) (string, []interface{}, error) {
		return rendered, args, nil
	}
	return bind, nil
}

// getHostsCursor builds one page (LIMIT/OFFSET) of a SELECT of hostsColumns
// over the whole hosts table.
//
// Expected args, in order: offset (uint64), limit (uint64), ts. The third
// argument is counted but is neither type-checked nor used when building the
// query.
//
// NOTE(review): the query has no ORDER BY, so successive pages are not
// guaranteed to be disjoint or exhaustive. Confirm callers tolerate this.
func getHostsCursor() (SQLQuery, error) {
	return func(args ...interface{}) (string, []interface{}, error) {
		if len(args) != 3 {
			return "", nil, fmt.Errorf("len(args)=%d should be 3 [offset, limit, ts]", len(args))
		}
		offset, ok := args[0].(uint64)
		if !ok {
			// Offset is the first argument; the message previously said args[1].
			return "", nil, fmt.Errorf("args[0] offset should be uint64")
		}
		limit, ok := args[1].(uint64)
		if !ok {
			return "", nil, fmt.Errorf("args[1] limit should be uint64")
		}
		return replaceSQL(sq.Select(hostsColumns...).
			From("hosts").
			Limit(limit).
			Offset(offset).
			ToSql())
	}, nil
}

// unitsQueries holds the SQLQuery builders for the units table.
// prepareUnitsQueries populates each field from the constructor with the
// matching name (getUnitsByNode from getUnitsByNodeQuery, and so on).
type unitsQueries struct {
	getUnitsByNode    SQLQuery
	getUnits          SQLQuery
	getUnitVersions   SQLQuery
	getUnitReady      SQLQuery
	getUnitPending    SQLQuery
	getUnitStages     SQLQuery
	removeByNodeNames SQLQuery
	removeByNodes     SQLQuery
	upsertUnits       SQLQuery
	getUnitsReports   SQLQuery
}

// hostsQueries holds the SQLQuery builders for the hosts table.
// prepareHostQueries populates each field from the constructor with the
// matching name (getHost from getHostQuery, getHostsCursor from
// getHostsCursor, and so on).
type hostsQueries struct {
	getHost             SQLQuery
	updateHost          SQLQuery
	upsertMeta          SQLQuery
	updateReportTime    SQLQuery
	getHosts            SQLQuery
	removeHosts         SQLQuery
	getHostsCursor      SQLQuery
	hostsCountQuery     SQLQuery
	kernelVersionsQuery SQLQuery
}

// prepareUnitsQueries runs every units query constructor and collects the
// results into a unitsQueries. It returns nil and the first constructor error
// if any constructor fails.
func prepareUnitsQueries() (*unitsQueries, error) {
	queries := &unitsQueries{}
	// Each step pairs a constructor with the field that receives its result.
	// Steps run in slice order.
	steps := []struct {
		dst   *SQLQuery
		build func() (SQLQuery, error)
	}{
		{&queries.getUnitsByNode, getUnitsByNodeQuery},
		{&queries.getUnits, getUnitsQuery},
		{&queries.getUnitVersions, getUnitVersionsQuery},
		{&queries.getUnitReady, getUnitReadyQuery},
		{&queries.getUnitPending, getUnitPendingQuery},
		{&queries.getUnitStages, getUnitStagesQuery},
		{&queries.removeByNodeNames, removeByNodeNamesQuery},
		{&queries.removeByNodes, removeByNodesQuery},
		{&queries.upsertUnits, upsertUnitsQuery},
		{&queries.getUnitsReports, getUnitsReportsQuery},
	}
	for _, step := range steps {
		q, err := step.build()
		if err != nil {
			return nil, err
		}
		*step.dst = q
	}
	return queries, nil
}

// prepareHostQueries runs every hosts query constructor and collects the
// results into a hostsQueries. It returns nil and the first constructor error
// if any constructor fails.
func prepareHostQueries() (*hostsQueries, error) {
	queries := &hostsQueries{}
	// Each step pairs a constructor with the field that receives its result.
	// Steps run in slice order.
	steps := []struct {
		dst   *SQLQuery
		build func() (SQLQuery, error)
	}{
		{&queries.getHost, getHostQuery},
		{&queries.updateHost, updateHostQuery},
		{&queries.upsertMeta, upsertMetaQuery},
		{&queries.updateReportTime, updateReportTimeQuery},
		{&queries.getHosts, getHostsQuery},
		{&queries.removeHosts, removeHostsQuery},
		{&queries.getHostsCursor, getHostsCursor},
		{&queries.hostsCountQuery, hostsCountQuery},
		{&queries.kernelVersionsQuery, kernelVersionsQuery},
	}
	for _, step := range steps {
		q, err := step.build()
		if err != nil {
			return nil, err
		}
		*step.dst = q
	}
	return queries, nil
}
