package runsql

import (
	"context"
	"fmt"
	"io/ioutil"
	"strings"

	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/library/go/core/log"
)

// SQLCmd is the parent "sql" command. It carries the persistent --db flag
// and groups the SQL subcommands that init attaches to it.
var SQLCmd = cobra.Command{
	Use: "sql",
}

// init registers the "sql" command on gotasks.RootCmd and attaches its
// "query" subcommand, which runs sqlQueryMain.
func init() {
	// --db is persistent so every sql subcommand can select the connection.
	SQLCmd.PersistentFlags().String("db", "", "Name of database connection")
	gotasks.RootCmd.AddCommand(&SQLCmd)

	// Register subcommands.
	queryCmd := &cobra.Command{
		Use: "query",
		Run: gotasks.WrapMain(sqlQueryMain),
	}
	queryCmd.Flags().String("file", "", "Path to SQL file")
	// This is a time.Duration flag: values need units ("30s", "5m"), and a
	// bare number such as "30" is rejected by time.ParseDuration. The
	// timeout bounds execution of all statements in the file together.
	queryCmd.Flags().Duration("timeout", 0, "Timeout for executing the whole SQL file, with units (e.g. 30s, 5m); 0 means no timeout")
	SQLCmd.AddCommand(queryCmd)
}

// sqlQueryMain implements "sql query". It reads the SQL file named by
// --file and executes it against the connection named by --db.
//
// For ClickHouse connections the file is split on ';' and each non-empty
// statement is executed separately. For other drivers the whole (trimmed)
// file is sent as one statement. The split is naive: a ';' inside a string
// literal or comment also splits the text.
//
// A positive --timeout bounds the total time spent executing all statements.
// Zero disables the timeout, and negative values are rejected. Errors are
// returned wrapped with context; the underlying error stays reachable via
// errors.Is/errors.As.
func sqlQueryMain(ctx *gotasks.Context) error {
	flags := ctx.Cmd.Flags()
	dbName, err := flags.GetString("db")
	if err != nil {
		return err
	}
	file, err := flags.GetString("file")
	if err != nil {
		return err
	}
	timeout, err := flags.GetDuration("timeout")
	if err != nil {
		return err
	}
	// Without this check an unset --file surfaces as "open : no such file".
	if file == "" {
		return fmt.Errorf("--file must be set")
	}
	// A negative duration would create an already-expired context.
	if timeout < 0 {
		return fmt.Errorf("--timeout must not be negative, got %v", timeout)
	}
	db, ok := ctx.DBs[dbName]
	if !ok {
		return fmt.Errorf("db %q does not exist", dbName)
	}
	content, err := ioutil.ReadFile(file)
	if err != nil {
		return fmt.Errorf("reading SQL file: %w", err)
	}
	// One context is shared by every statement, so the timeout is a budget
	// for the whole file rather than a per-statement limit.
	queryCtx := context.Background()
	if timeout > 0 {
		var cancel context.CancelFunc
		queryCtx, cancel = context.WithTimeout(queryCtx, timeout)
		defer cancel()
	}
	statements := []string{string(content)}
	if db.Driver == gosql.ClickHouseDriver {
		statements = strings.Split(string(content), ";")
	}
	executed := 0
	for _, stmt := range statements {
		stmt = strings.TrimSpace(stmt)
		if stmt == "" {
			continue
		}
		executed++
		res, execErr := db.ExecContext(queryCtx, stmt)
		if execErr != nil {
			return fmt.Errorf("executing statement #%d: %w", executed, execErr)
		}
		// Reporting affected rows is best-effort. An error from RowsAffected
		// is deliberately ignored, presumably because not every driver
		// supports it (NOTE(review): confirm for the drivers in use).
		if affected, rowsErr := res.RowsAffected(); rowsErr == nil {
			ctx.Logger.Info("Affected rows", log.Int64("affected", affected))
		}
	}
	return nil
}
