package models

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"path"
	"time"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb/table"
)

// tableName is the YDB table that holds report rows. The YQL statements refer
// to it unqualified (after setting TablePathPrefix), and createTables joins it
// onto the database path.
const tableName = "virus_total"

// DB is a YDB-backed store of Report records. It keeps the context it was
// built with and uses that context for every later operation. It also keeps
// the table session pool and the YQL statements pre-rendered for the
// configured database path.
type DB struct {
	ctx               context.Context
	sp                *table.SessionPool
	selectQuery       string
	replaceQuery      string
	countUpdatedQuery string
}

// DBConfig holds the connection settings consumed by NewDB: the auth token
// used as credentials, the database path (also used as the table path prefix)
// and the endpoint to dial.
type DBConfig struct {
	AuthToken string
	Database  string
	Endpoint  string
}

var (
	ErrNotFound = errors.New("not found")
)

// selectQuery renders the YQL statement that reads at most one row from
// tableName, matched on sha256. pathPrefix is substituted verbatim into the
// TablePathPrefix pragma.
func selectQuery(pathPrefix string) string {
	const tmpl = `
PRAGMA TablePathPrefix("%s");

DECLARE $sha256 AS "String";

SELECT sha256, md5, sha1, found, positives, total, created_at, updated_at, scans
  FROM %s
WHERE sha256 = $sha256
LIMIT 1;
`
	query := fmt.Sprintf(tmpl, pathPrefix, tableName)
	return query
}

// replaceQuery renders the YQL REPLACE statement that writes every report
// column into tableName. scans is declared as an optional Json value; all
// other parameters are declared non-optional. pathPrefix is substituted
// verbatim into the TablePathPrefix pragma.
func replaceQuery(pathPrefix string) string {
	const tmpl = `
PRAGMA TablePathPrefix("%s");

DECLARE $sha256 AS "String";
DECLARE $md5 AS "String";
DECLARE $sha1 AS "String";
DECLARE $found AS "Bool";
DECLARE $positives AS "Uint32";
DECLARE $total AS "Uint32";
DECLARE $created_at AS "Uint64";
DECLARE $updated_at AS "Uint64";
DECLARE $scans AS "Json?";

REPLACE INTO %s
  (sha256, md5, sha1, found, positives, total, created_at, updated_at, scans)
VALUES
  ($sha256, $md5, $sha1, $found, $positives, $total, $created_at, $updated_at, $scans);
`
	query := fmt.Sprintf(tmpl, pathPrefix, tableName)
	return query
}

// countUpdatedQuery renders the YQL statement that counts rows of tableName
// whose updated_at is at least $after. pathPrefix is substituted verbatim into
// the TablePathPrefix pragma.
//
// NOTE(review): $after is declared Int64, while createTables creates
// updated_at as Uint64. This relies on YQL comparing signed and unsigned
// integers; confirm.
func countUpdatedQuery(pathPrefix string) string {
	const tmpl = `
PRAGMA TablePathPrefix("%s");

DECLARE $after AS "Int64";

SELECT count(1) AS count
  FROM %s
WHERE updated_at >= $after
;
`
	query := fmt.Sprintf(tmpl, pathPrefix, tableName)
	return query
}

// NewDB dials cfg.Endpoint with cfg.AuthToken as credentials, creates the
// tableName table under cfg.Database, and returns a DB whose queries use
// cfg.Database as their table path prefix.
//
// ctx is used for dialing and table creation. It is also stored on the
// returned DB and used for all later operations. Dial and table-creation
// errors are wrapped with %w, so callers can inspect the underlying cause.
// If table creation fails, the session pool and driver opened here are
// closed before returning, so nothing leaks.
func NewDB(ctx context.Context, cfg DBConfig) (*DB, error) {
	config := new(ydb.DriverConfig)

	config.Credentials = ydb.AuthTokenCredentials{
		AuthToken: cfg.AuthToken,
	}
	config.Database = cfg.Database

	driver, err := (&ydb.Dialer{
		DriverConfig: config,
	}).Dial(ctx, cfg.Endpoint)
	if err != nil {
		return nil, fmt.Errorf("dial error: %w", err)
	}

	tableClient := table.Client{
		Driver: driver,
	}

	sp := table.SessionPool{
		IdleThreshold: 10 * time.Second,
		Builder:       &tableClient,
	}

	if err := createTables(ctx, &sp, cfg.Database); err != nil {
		// Best-effort cleanup. The creation error is what the caller needs,
		// so close errors are deliberately discarded.
		_ = sp.Close(ctx)
		_ = driver.Close()
		return nil, fmt.Errorf("create tables error: %w", err)
	}

	return &DB{
		ctx:               ctx,
		sp:                &sp,
		selectQuery:       selectQuery(cfg.Database),
		replaceQuery:      replaceQuery(cfg.Database),
		countUpdatedQuery: countUpdatedQuery(cfg.Database),
	}, nil
}

// Reset closes the table session pool, using the context the DB was built
// with, and returns the pool's Close error.
//
// NOTE(review): DB holds no reference to the YDB driver dialed in NewDB, so
// the driver itself is not closed here.
func (c *DB) Reset() error {
	pool, ctx := c.sp, c.ctx
	return pool.Close(ctx)
}

// LookupRecord fetches the report stored under the sha256 key.
//
// It runs selectQuery in an online read-only transaction via table.Retry.
// It returns ErrNotFound when the query yields no row. Errors recorded while
// reading the result are returned as-is, and scan JSON decode failures are
// wrapped. On any error the returned *Report is nil, never partially filled.
func (c *DB) LookupRecord(key string) (result *Report, err error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)
	var res *table.Result
	err = table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (err error) {
			stmt, err := s.Prepare(ctx, c.selectQuery)
			if err != nil {
				return err
			}

			_, res, err = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$sha256", ydb.StringValue([]byte(key))),
			))
			return err
		}),
	)
	if err != nil {
		return nil, err
	}

	if !res.NextSet() || !res.NextRow() {
		// Don't report a failed read as a missing record.
		if err = res.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNotFound
	}

	// TODO(buglloc): migrate DB types!
	report := &Report{}
	// Columns are read positionally, in SELECT-list order:
	// sha256, md5, sha1, found, positives, total, created_at, updated_at, scans
	res.SeekItem("sha256")
	report.Sha256 = string(res.OString())

	res.NextItem()
	report.Md5 = string(res.OString())

	res.NextItem()
	report.Sha1 = string(res.OString())

	res.NextItem()
	report.Found = res.OBool()

	res.NextItem()
	report.Positives = int(res.OUint32())

	res.NextItem()
	report.Total = int(res.OUint32())

	// Timestamps are stored as Unix seconds.
	res.NextItem()
	report.CreatedAt = time.Unix(int64(res.OUint64()), 0)

	res.NextItem()
	report.UpdatedAt = time.Unix(int64(res.OUint64()), 0)

	res.NextItem()
	scans := res.OJSON()

	// Check for result read errors before trusting any of the values above.
	if err = res.Err(); err != nil {
		return nil, err
	}

	if scans != "" {
		if err = json.Unmarshal([]byte(scans), &report.Scans); err != nil {
			return nil, fmt.Errorf("decode scans for %q: %w", key, err)
		}
	}

	return report, nil
}

// CountUpdatedAfter returns how many rows have updated_at >= after.
// InsertRecord stores updated_at as Unix seconds, so after is expected in
// the same unit.
//
// It runs countUpdatedQuery in an online read-only transaction via
// table.Retry. It returns ErrNotFound only when the query yields no row and
// no result error was recorded. On any error the returned count is 0.
func (c *DB) CountUpdatedAfter(after int64) (count uint64, err error) {
	readTx := table.TxControl(
		table.BeginTx(
			table.WithOnlineReadOnly(),
		),
		table.CommitTx(),
	)
	var res *table.Result
	err = table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) (err error) {
			stmt, err := s.Prepare(ctx, c.countUpdatedQuery)
			if err != nil {
				return err
			}

			_, res, err = stmt.Execute(ctx, readTx, table.NewQueryParameters(
				table.ValueParam("$after", ydb.Int64Value(after)),
			))
			return err
		}),
	)
	if err != nil {
		return 0, err
	}

	if !res.NextSet() || !res.NextRow() {
		// Don't report a failed read as "not found".
		if err = res.Err(); err != nil {
			return 0, err
		}
		return 0, ErrNotFound
	}

	res.SeekItem("count")
	count = res.Uint64()
	if err = res.Err(); err != nil {
		return 0, err
	}
	return count, nil
}

// InsertRecord upserts report, keyed by its Sha256, with replaceQuery.
// The write runs in a serializable read-write transaction that commits
// immediately, and goes through table.Retry.
//
// report.Scans is JSON-encoded once, before any database work. An encoding
// failure is therefore returned (wrapped) without a DB round-trip, and
// retries do not re-encode. CreatedAt and UpdatedAt are stored as Unix
// seconds.
func (c *DB) InsertRecord(report *Report) error {
	// Loop-invariant: encode outside the retry closure.
	scans, err := json.Marshal(report.Scans)
	if err != nil {
		return fmt.Errorf("encode scans: %w", err)
	}

	writeTx := table.TxControl(
		table.BeginTx(
			table.WithSerializableReadWrite(),
		),
		table.CommitTx(),
	)

	return table.Retry(c.ctx, c.sp,
		table.OperationFunc(func(ctx context.Context, s *table.Session) error {
			stmt, err := s.Prepare(ctx, c.replaceQuery)
			if err != nil {
				return err
			}

			_, _, err = stmt.Execute(ctx, writeTx, table.NewQueryParameters(
				table.ValueParam("$sha256", ydb.StringValue([]byte(report.Sha256))),
				table.ValueParam("$md5", ydb.StringValue([]byte(report.Md5))),
				table.ValueParam("$sha1", ydb.StringValue([]byte(report.Sha1))),
				table.ValueParam("$found", ydb.BoolValue(report.Found)),
				table.ValueParam("$positives", ydb.Uint32Value(uint32(report.Positives))),
				table.ValueParam("$total", ydb.Uint32Value(uint32(report.Total))),
				table.ValueParam("$created_at", ydb.Uint64Value(uint64(report.CreatedAt.Unix()))),
				table.ValueParam("$updated_at", ydb.Uint64Value(uint64(report.UpdatedAt.Unix()))),
				table.ValueParam("$scans", ydb.OptionalValue(ydb.JSONValue(string(scans)))),
			))
			return err
		}),
	)
}

// createTables creates the tableName table at path.Join(prefix, tableName),
// running CreateTable through table.Retry on sp. Every column is declared
// Optional, and sha256 is the primary key.
//
// NOTE(review): no "already exists" handling is visible here. Whether a call
// against an existing table succeeds depends on YDB CreateTable semantics;
// confirm.
func createTables(ctx context.Context, sp *table.SessionPool, prefix string) (err error) {
	fullPath := path.Join(prefix, tableName)

	create := func(ctx context.Context, s *table.Session) error {
		return s.CreateTable(ctx, fullPath,
			table.WithColumn("sha256", ydb.Optional(ydb.TypeString)),
			table.WithColumn("md5", ydb.Optional(ydb.TypeString)),
			table.WithColumn("sha1", ydb.Optional(ydb.TypeString)),
			table.WithColumn("found", ydb.Optional(ydb.TypeBool)),
			table.WithColumn("positives", ydb.Optional(ydb.TypeUint32)),
			table.WithColumn("total", ydb.Optional(ydb.TypeUint32)),
			table.WithColumn("created_at", ydb.Optional(ydb.TypeUint64)),
			table.WithColumn("updated_at", ydb.Optional(ydb.TypeUint64)),
			table.WithColumn("scans", ydb.Optional(ydb.TypeJSON)),
			table.WithPrimaryKeyColumn("sha256"),
		)
	}

	return table.Retry(ctx, sp, table.OperationFunc(create))
}
