package tables

import (
	"context"
	"fmt"
	"path"
	"sort"
	"strings"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/avia/personalization/internal/auth"
	"a.yandex-team.ru/travel/avia/personalization/internal/metrics"
)

type (
	// EventServiceTypeKey identifies one (service, event type) pair of user
	// events. It is comparable and is used as a map key.
	EventServiceTypeKey struct {
		Service uint8
		Type    uint8
	}

	// EventServiceTypeToLimit maps a (service, event type) pair to the maximum
	// number of most recent events to fetch for that pair.
	EventServiceTypeToLimit = map[EventServiceTypeKey]uint8
)

// UserEventsTable is a YDB-backed store of user events whose rows are keyed by
// (auth_value, auth_type, service, event_type, event_key).
type UserEventsTable struct {
	logger     log.Logger
	connection table.SessionProvider

	// dbName/tableName locate the table; upsertQuery is the UPSERT statement
	// pre-rendered for tableName.
	dbName, tableName, upsertQuery string

	// Timeouts intended for write and read operations respectively.
	upsertTimeout, selectTimeout time.Duration
}

// NewUserEventsTable returns a UserEventsTable for table tableName in database
// dbName, accessed through the given session provider. The UPSERT statement is
// rendered once here; selectTimeout defaults to 1s and upsertTimeout to 10s.
func NewUserEventsTable(logger log.Logger, connection table.SessionProvider, dbName, tableName string) *UserEventsTable {
	const (
		defaultSelectTimeout = time.Second
		defaultUpsertTimeout = 10 * time.Second
	)

	t := new(UserEventsTable)
	t.logger = logger
	t.connection = connection
	t.dbName, t.tableName = dbName, tableName
	t.upsertQuery = fmt.Sprintf(upsertQueryTemplate, tableName)
	t.selectTimeout, t.upsertTimeout = defaultSelectTimeout, defaultUpsertTimeout
	return t
}

// EnsureExists issues CreateTable for path.Join(dbName, tableName) through
// table.Retry.
//
// All columns are Optional; the primary key is
// (auth_value, auth_type, service, event_type, event_key). TTL is configured
// on the expires_at column with ExpireAfterSeconds = 0 (presumably rows expire
// as soon as expires_at has passed).
//
// NOTE(review): whether CreateTable succeeds when the table already exists is
// up to YDB — confirm this is safe to call on every startup.
func (t *UserEventsTable) EnsureExists(ctx context.Context) error {
	tablePath := path.Join(t.dbName, t.tableName)

	createTable := func(ctx context.Context, s *table.Session) error {
		return s.CreateTable(ctx, tablePath,
			table.WithColumn("auth_value", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("auth_type", ydb.Optional(ydb.TypeUint8)),
			table.WithColumn("service", ydb.Optional(ydb.TypeUint8)),
			table.WithColumn("event_type", ydb.Optional(ydb.TypeUint8)),
			table.WithColumn("event_key", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("event_data", ydb.Optional(ydb.TypeUTF8)),
			table.WithColumn("created_at", ydb.Optional(ydb.TypeDatetime)),
			table.WithColumn("expires_at", ydb.Optional(ydb.TypeDatetime)),
			table.WithPrimaryKeyColumn("auth_value", "auth_type", "service", "event_type", "event_key"),
			table.WithTimeToLiveSettings(table.TimeToLiveSettings{
				ColumnName:         "expires_at",
				ExpireAfterSeconds: 0,
				Mode:               table.TimeToLiveModeDateType,
			}),
		)
	}

	return table.Retry(ctx, t.connection, table.OperationFunc(createTable))
}

const (
	// selectByServiceAndTypePairsTemplate is the header of the batched select
	// query built by SelectByServiceTypePairs: it switches to YQL syntax v1,
	// sets a pragma and declares the $auth_type/$auth_value parameters shared
	// by every per-pair statement. Despite its name it has no format verbs;
	// one rendered selectByServiceAndTypeSubqueryTemplate is appended per pair.
	selectByServiceAndTypePairsTemplate = `
		--!syntax_v1
		PRAGMA AnsiInForEmptyOrNullableItemsCollections;

		DECLARE $auth_type AS Uint8;
		DECLARE $auth_value AS Utf8;
	`
	// selectByServiceAndTypeSubqueryTemplate selects the newest rows (ordered
	// by created_at descending) of one (service, event_type) pair for the
	// $auth_type/$auth_value owner. Format arguments, in order: table name
	// (%s), service (%d), event type (%d), row limit (%d).
	selectByServiceAndTypeSubqueryTemplate = `
		SELECT
			auth_type,
			auth_value,
			service,
			event_type,
			event_key,
			event_data,
			created_at,
			expires_at
		FROM %s
		WHERE
			auth_type = $auth_type
			AND auth_value = $auth_value
			AND service = %d
			AND event_type = %d
		ORDER BY created_at DESC
		LIMIT %d;
	`

	// upsertQueryTemplate writes a single event row. Its only format argument,
	// %[1]s, is the table name; every column value is bound through a declared
	// parameter rather than interpolated into the text.
	upsertQueryTemplate = `
		--!syntax_v1

		DECLARE $auth_type AS Uint8;
		DECLARE $auth_value AS Utf8;
		DECLARE $service AS Uint8;
		DECLARE $event_type AS Uint8;
		DECLARE $event_key AS Utf8;
		DECLARE $event_data AS Utf8;
		DECLARE $created_at AS Datetime;
		DECLARE $expires_at AS Datetime;

		UPSERT INTO %[1]s (auth_type, auth_value, service, event_key, event_type, event_data, created_at, expires_at)
		VALUES ($auth_type, $auth_value, $service, $event_key, $event_type, $event_data, $created_at, $expires_at);
	`
)

// Upsert UPSERTs entry into the table by its primary key. It prepares the
// statement on the session and executes it in a serializable read-write
// transaction that commits immediately.
//
// The whole operation runs through table.Retry and is bounded by
// t.upsertTimeout. Its duration is reported to metrics under
// request_type "Upsert".
func (t *UserEventsTable) Upsert(
	ctx context.Context,
	entry UserEventEntry,
) (err error) {
	defer metrics.WriteTimings("UserEventsTable", time.Now(), map[string]string{"request_type": "Upsert"})

	ctx, cancel := context.WithTimeout(ctx, t.upsertTimeout)
	defer cancel()

	upsertOnce := func(ctx context.Context, s *table.Session) error {
		stmt, prepareErr := s.Prepare(ctx, t.upsertQuery)
		if prepareErr != nil {
			return prepareErr
		}
		params := table.NewQueryParameters(
			table.ValueParam("$auth_type", ydb.Uint8Value(entry.AuthType)),
			table.ValueParam("$auth_value", ydb.UTF8Value(entry.AuthValue)),
			table.ValueParam("$service", ydb.Uint8Value(entry.Service)),
			table.ValueParam("$event_type", ydb.Uint8Value(entry.EventType)),
			table.ValueParam("$event_key", ydb.UTF8Value(entry.EventKey)),
			table.ValueParam("$event_data", ydb.UTF8Value(entry.EventData)),
			table.ValueParam("$created_at", ydb.DatetimeValue(entry.CreatedAt)),
			table.ValueParam("$expires_at", ydb.DatetimeValue(entry.ExpiresAt)),
		)
		txc := table.TxControl(table.BeginTx(table.WithSerializableReadWrite()), table.CommitTx())
		_, _, execErr := stmt.Execute(ctx, txc, params)
		return execErr
	}

	err = table.Retry(ctx, t.connection, table.OperationFunc(upsertOnce))
	return err
}

// SelectByServiceTypePairs fetches, in a single query, the most recent events
// of the (authType, authValue) owner for every (service, event type) pair in
// serviceTypeKeys. For each pair it returns at most serviceTypeKeys[pair] rows,
// ordered by created_at descending.
//
// One SELECT is generated per pair, and result sets are read back in the
// order the statements were emitted. Pairs are sorted before the query is
// built, so the same set of pairs always yields the same query text. This lets
// prepared statements be reused instead of varying with map iteration order.
//
// An empty serviceTypeKeys yields an empty, non-nil map without contacting
// YDB. The operation runs through table.Retry and is bounded by
// t.selectTimeout. If fewer result sets come back than pairs were requested,
// an error is returned rather than a partial map.
func (t *UserEventsTable) SelectByServiceTypePairs(
	ctx context.Context,
	authType auth.Type,
	authValue string,
	serviceTypeKeys EventServiceTypeToLimit,
) (map[EventServiceTypeKey]UserEventEntries, error) {
	start := time.Now()
	defer metrics.WriteTimings("UserEventsTable", start, map[string]string{"request_type": "SelectByServiceTypePairs"})

	eventsByServiceType := make(map[EventServiceTypeKey]UserEventEntries, len(serviceTypeKeys))
	if len(serviceTypeKeys) == 0 {
		// Without any pair the query would consist of declarations only;
		// there is nothing to fetch, so skip the round trip.
		return eventsByServiceType, nil
	}

	// The order of this slice fixes both the statement order in the query
	// and the order in which result sets are read back.
	serviceTypeKeysList := make([]EventServiceTypeKey, 0, len(serviceTypeKeys))
	for serviceType := range serviceTypeKeys {
		serviceTypeKeysList = append(serviceTypeKeysList, serviceType)
	}
	sort.Slice(serviceTypeKeysList, func(i, j int) bool {
		a, b := serviceTypeKeysList[i], serviceTypeKeysList[j]
		if a.Service != b.Service {
			return a.Service < b.Service
		}
		return a.Type < b.Type
	})

	var queryBuilder strings.Builder
	queryBuilder.WriteString(selectByServiceAndTypePairsTemplate)
	for _, serviceType := range serviceTypeKeysList {
		fmt.Fprintf(&queryBuilder, selectByServiceAndTypeSubqueryTemplate,
			t.tableName, serviceType.Service, serviceType.Type, serviceTypeKeys[serviceType])
	}
	query := queryBuilder.String()

	// Reads use the dedicated select timeout (previously upsertTimeout was
	// applied here by mistake, leaving selectTimeout unused).
	ctx, cancel := context.WithTimeout(ctx, t.selectTimeout)
	defer cancel()

	var queryResult *table.Result
	err := table.Retry(
		ctx,
		t.connection,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			readTx := table.TxControl(table.BeginTx(table.WithStaleReadOnly()), table.CommitTx())
			preparedSelectStatement, err := s.Prepare(ctx, query)
			if err != nil {
				return err
			}
			_, queryResult, err = preparedSelectStatement.Execute(
				ctx, readTx,
				table.NewQueryParameters(
					table.ValueParam("$auth_type", ydb.Uint8Value(authType)),
					table.ValueParam("$auth_value", ydb.UTF8Value(authValue)),
				),
			)
			return err
		}),
	)
	if err != nil {
		return nil, err
	}
	if err = queryResult.Err(); err != nil {
		return nil, err
	}

	for i, serviceType := range serviceTypeKeysList {
		if !queryResult.NextResultSet(ctx) {
			// One result set is expected per generated SELECT; returning what
			// we have would silently drop the remaining pairs.
			return nil, fmt.Errorf("user events select returned %d result sets, expected %d", i, len(serviceTypeKeysList))
		}
		entries := make(UserEventEntries, 0)
		if err := entries.Scan(queryResult); err != nil {
			return nil, err
		}
		eventsByServiceType[serviceType] = entries
	}
	return eventsByServiceType, nil
}
