package xydb

import (
	"context"
	"encoding/json"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
	ydbTable "github.com/ydb-platform/ydb-go-sdk/v3/table"
	ydbOptions "github.com/ydb-platform/ydb-go-sdk/v3/table/options"
	ydbNamed "github.com/ydb-platform/ydb-go-sdk/v3/table/result/named"
	ydbTypes "github.com/ydb-platform/ydb-go-sdk/v3/table/types"

	testutils "a.yandex-team.ru/tasklet/experimental/internal/test_utils"
)

// ClientSuite is a testify suite holding the tests for Client.
type ClientSuite struct {
	suite.Suite
	// suiteClient is the client shared by every test in the suite;
	// TestClient populates it before running the suite.
	suiteClient *Client
}

// TestReadOneRow covers Client.ReadOneRow: the successful single-row read,
// the two row-count sentinels (ErrNoRows, ErrMoreThenOneRow) and several
// failures that must be reported as errors distinct from those sentinels.
func (es *ClientSuite) TestReadOneRow() {
	ctx := context.Background()

	const (
		singleRowQuery = `SELECT CAST(1 AS Int64) as val`
		emptyQuery     = `
SELECT * FROM (
        SELECT 1 AS val
)
WHERE val = 2
`
		twoRowsQuery = `
SELECT * FROM (
        SELECT 1 AS val
        UNION ALL
        SELECT 1 AS val
)
`
	)

	// assertOtherError checks that err is set and is neither row-count sentinel.
	assertOtherError := func(err error) {
		es.Error(err)
		es.NotErrorIs(err, ErrNoRows)
		es.NotErrorIs(err, ErrMoreThenOneRow)
	}

	es.Run("ok", func() {
		var got int64
		target := ReadValues{ydbNamed.Required("val", &got)}
		err := es.suiteClient.ReadOneRow(ctx, singleRowQuery, target)
		es.NoError(err)
		es.Equal(int64(1), got)
	})

	es.Run("read error", func() {
		// The query text is not valid, so the read itself is expected to fail.
		assertOtherError(es.suiteClient.ReadOneRow(ctx, `asdfad`, ReadValues{}))
	})

	es.Run("result set error", func() {
		// The requested column name does not exist in the result set.
		var got int64
		target := ReadValues{ydbNamed.OptionalWithDefault("bad_field_name", &got)}
		assertOtherError(es.suiteClient.ReadOneRow(ctx, `SELECT 1 AS val`, target))
	})

	es.Run("no rows", func() {
		var got int64
		target := ReadValues{ydbNamed.Required("val", &got)}
		err := es.suiteClient.ReadOneRow(ctx, emptyQuery, target)
		es.ErrorIs(err, ErrNoRows)
	})

	es.Run("two rows", func() {
		var got int64
		target := ReadValues{ydbNamed.Required("val", &got)}
		err := es.suiteClient.ReadOneRow(ctx, twoRowsQuery, target)
		es.ErrorIs(err, ErrMoreThenOneRow)
	})

	es.Run("scan error", func() {
		// An Int64 column is scanned into a string destination.
		var got string
		target := ReadValues{ydbNamed.OptionalWithDefault("val", &got)}
		assertOtherError(es.suiteClient.ReadOneRow(ctx, singleRowQuery, target))
	})
}

// TestWrite checks that a row inserted through Client.Write can be read back
// with Client.ReadOneRow.
//
// The table "sample_test_write" is reset before the subtests run, so the
// INSERT below does not collide with rows left by a previous run.
func (es *ClientSuite) TestWrite() {
	ctx := context.Background()
	r := es.Require()
	r.NoError(
		es.suiteClient.ResetTable(
			ctx,
			"sample_test_write",
			ydbOptions.WithColumn("id", ydbTypes.Optional(ydbTypes.TypeInt32)),
			ydbOptions.WithColumn("msg", ydbTypes.Optional(ydbTypes.TypeString)),
			ydbOptions.WithPrimaryKeyColumn("id"),
		),
	)
	// language=YQL
	writeQuery := `
		DECLARE $id AS Int32;
		DECLARE $msg AS Optional<String>;
		INSERT INTO sample_test_write (id, msg)
		VALUES ($id, $msg);
	`
	// language=YQL
	readQuery := `
		DECLARE $id AS Int32;
		SELECT id, msg FROM sample_test_write WHERE id == $id;
	`

	es.Run(
		"ok", func() {
			// Fetch Require inside the subtest so a fatal failure stops this
			// subtest's *testing.T rather than the parent's.
			r := es.Require()

			err := es.suiteClient.Write(
				ctx,
				writeQuery,
				ydbTable.ValueParam("$id", ydbTypes.Int32Value(1)),
				ydbTable.ValueParam("$msg", ydbTypes.OptionalValue(ydbTypes.StringValueFromString("boo"))),
			)
			// Reading back is meaningless if the write failed, so stop here.
			r.NoError(err)

			var v int32
			var msg string
			err = es.suiteClient.ReadOneRow(
				ctx, readQuery,
				ReadValues{ydbNamed.OptionalWithDefault("id", &v), ydbNamed.OptionalWithDefault("msg", &msg)},
				ydbTable.ValueParam("$id", ydbTypes.Int32Value(1)),
			)
			r.NoError(err)
			// Expected value goes first so failure output labels are correct.
			es.Equal(int32(1), v)
			es.Equal("boo", msg)
		},
	)
}

// readTableConsumer is a test helper that collects rows; TestReadTable
// passes it to Client.StreamReadTable as the consumer.
type readTableConsumer struct {
	// columns is the column list returned by GetColumnSet.
	columns []string
	// rows accumulates one entry per NextReadValues call, in call order.
	rows []any
	// rowMaker builds a fresh row object together with the ReadValues that
	// scan into it; it is invoked once per NextReadValues call.
	rowMaker func() (any, ReadValues)
}

// GetColumnSet returns the column names configured on the consumer.
func (r *readTableConsumer) GetColumnSet() []string {
	configured := r.columns
	return configured
}

// NextReadValues asks rowMaker for a new row, records that row in r.rows and
// returns the ReadValues that scan into it.
func (r *readTableConsumer) NextReadValues() ReadValues {
	newRow, scanTargets := r.rowMaker()
	// Keep the row so the test can inspect it after the read has finished.
	r.rows = append(r.rows, newRow)
	return scanTargets
}

// ConsumeReadValues always returns nil: the row was already stored by
// NextReadValues, so there is nothing left to do here.
func (r *readTableConsumer) ConsumeReadValues() error {
	return nil
}

// Done does nothing.
//
// It takes a pointer receiver like every other readTableConsumer method, so
// the type has one consistent method set and the struct is not copied on call.
func (r *readTableConsumer) Done() {
	// noop
}

// TestReadTable checks that Client.StreamReadTable hands every row of a table
// to the consumer.
//
// It creates the "sample" table, writes four rows covering each NULL/non-NULL
// combination of msg and payload, streams the table through a
// readTableConsumer and verifies each collected row by its id. NULL columns
// are expected to arrive as "" for msg and a nil slice for payload.
func (es *ClientSuite) TestReadTable() {
	ctx := context.Background()
	r := es.Require()
	r.NoError(
		es.suiteClient.CreateTable(
			ctx,
			"sample",
			ydbOptions.WithColumn("id", ydbTypes.Optional(ydbTypes.TypeInt32)),
			ydbOptions.WithColumn("msg", ydbTypes.Optional(ydbTypes.TypeString)),
			ydbOptions.WithColumn("payload", ydbTypes.Optional(ydbTypes.TypeJSONDocument)),
			ydbOptions.WithPrimaryKeyColumn("id"),
		),
	)
	// language=YQL
	query := `
		REPLACE INTO sample (id, msg, payload)
		VALUES
            (1, "foo", CAST(@@{"foo": "bar"}@@ AS JsonDocument)),
            (2, NULL, CAST(@@{"x": 17}@@ AS JsonDocument)),
            (3, "bar", NULL),
            (4, NULL, NULL)
        ;
	`
	{
		_, err := es.suiteClient.Do(ctx, query, es.suiteClient.WriteTxControl)
		r.NoError(err)
	}

	es.Run(
		"ok", func() {
			// es.Run rebinds the suite to the subtest's *testing.T, while the
			// outer r is still tied to the parent test. Re-fetch it so fatal
			// assertions stop this subtest and not the parent.
			r := es.Require()

			type rowStruct struct {
				id  int32
				msg string
				js  []byte
			}
			consumer := &readTableConsumer{
				columns: []string{"id", "msg", "payload"},
				rowMaker: func() (any, ReadValues) {
					v := &rowStruct{}
					return v, ReadValues{
						ydbNamed.OptionalWithDefault("id", &v.id),
						ydbNamed.OptionalWithDefault("msg", &v.msg),
						ydbNamed.OptionalWithDefault("payload", &v.js),
					}
				},
			}
			err := es.suiteClient.StreamReadTable(
				ctx,
				"sample",
				consumer,
			)
			r.NoError(err)

			// assertJSON decodes raw and reports a diff if it differs from expected.
			assertJSON := func(raw []byte, expected map[string]any) {
				var decoded any
				r.NoError(json.Unmarshal(raw, &decoded))
				if !cmp.Equal(expected, decoded) {
					es.Fail("not equal", cmp.Diff(expected, decoded))
				}
			}

			r.Len(consumer.rows, 4)
			for _, item := range consumer.rows {
				row := item.(*rowStruct)
				switch row.id {
				case 1:
					r.Equal("foo", row.msg)
					assertJSON(row.js, map[string]any{"foo": "bar"})
				case 2:
					r.Equal("", row.msg)
					// encoding/json decodes numbers into float64 when the target is any.
					assertJSON(row.js, map[string]any{"x": float64(17)})
				case 3:
					r.Equal("bar", row.msg)
					r.Equal(([]byte)(nil), row.js)
				case 4:
					r.Equal("", row.msg)
					r.Equal(([]byte)(nil), row.js)
				default:
					es.Failf("unexpected", "id: %v", row.id)
				}
			}
		},
	)
}

// TestClient is the go test entry point for ClientSuite. It builds a YDB
// client with query logging enabled, runs the suite against it and closes the
// client afterwards, failing the test if Close returns an error.
func TestClient(t *testing.T) {
	setupCtx, cancelSetup := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelSetup()

	logDir := testutils.TwistTmpDir(t)
	ydbClient := MustGetYdbClient(setupCtx, testutils.TwistMakeLogger(logDir, "client.log"), t.Name())
	ydbClient.SetLogQueries(true)

	// Close gets its own deadline, independent of the setup context.
	closeClient := func() {
		closeCtx, cancelClose := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancelClose()
		require.NoError(t, ydbClient.Close(closeCtx))
	}
	defer closeClient()

	suite.Run(t, &ClientSuite{suiteClient: ydbClient})
}
