package pgclient

import (
	"context"
	"database/sql"
	"fmt"
	"time"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	"golang.yandex/hasql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"

	"a.yandex-team.ru/library/go/core/xerrors"
)

// PGClient hands out gorm handles for the nodes of a hasql-managed
// PostgreSQL cluster. Construct it with NewPGClient; the zero value has a nil
// cluster and is not usable.
//
// Fields:
//   - cluster: the hasql cluster whose nodes GetPrimary and
//     ExecuteInTransaction select from.
//   - onCheckedNode: callback defaulted to a no-op by NewPGClient before
//     clientOptions run. NOTE(review): presumably invoked by nodeChecker with
//     the node's DB and the measured check latency — confirm in nodeChecker.
//   - logLevel: gorm log level passed to gormWrapper for every handle; left at
//     its zero value unless a ClientOption sets it (assumed — verify options).
type PGClient struct {
	cluster       *hasql.Cluster
	onCheckedNode func(db *sql.DB, latency time.Duration)
	logLevel      logger.LogLevel
}

var (
	// ErrPrimaryNodeIsUnavailable is returned by GetPrimary when the cluster
	// currently reports no primary node.
	ErrPrimaryNodeIsUnavailable = xerrors.New("postgres: primary node is unavailable")
	// ErrNodeIsUnavailable is returned by ExecuteInTransaction when no node
	// matches the requested hasql.NodeStateCriteria.
	ErrNodeIsUnavailable = xerrors.New("postgres: node is unavailable")
)

// NewPGClient creates one PostgreSQL node per entry in hosts, groups them into
// a hasql cluster and blocks until a primary node is found or initTimeout
// elapses.
//
// Every node uses the same port, dbName, user and password. Each node's
// *sql.DB has defaultPoolOptions applied first and poolOptions after them.
// clientOptions are applied to the returned PGClient before the cluster is
// created; clusterOptions are passed through to hasql.NewCluster.
//
// On any error nothing is leaked: database handles opened so far are closed,
// and once the cluster exists it is shut down with Cluster.Close.
func NewPGClient(
	hosts []string,
	port int,
	dbName, user, password string,
	initTimeout time.Duration,
	clientOptions []ClientOption,
	clusterOptions []hasql.ClusterOption,
	poolOptions ...Option,
) (*PGClient, error) {
	// Build the combined option list once, in a fresh slice: appending straight
	// onto defaultPoolOptions could write into its spare capacity and mutate
	// the shared package-level slice.
	allPoolOptions := make([]Option, 0, len(defaultPoolOptions)+len(poolOptions))
	allPoolOptions = append(allPoolOptions, defaultPoolOptions...)
	allPoolOptions = append(allPoolOptions, poolOptions...)

	nodes := make([]hasql.Node, 0, len(hosts))
	opened := make([]*sql.DB, 0, len(hosts))
	closeOpened := func() {
		for _, db := range opened {
			// Best-effort cleanup on an error path; the caller needs the
			// original error, not a close error.
			_ = db.Close()
		}
	}

	for _, host := range hosts {
		connConfig, err := pgx.ParseConfig(buildPGConnString(host, port, dbName, user, password))
		if err != nil {
			closeOpened()
			return nil, fmt.Errorf("postgres: parsing connection config for host %q: %w", host, err)
		}

		// workaround for https://github.com/jackc/pgx/issues/602
		connConfig.BuildStatementCache = nil
		connConfig.PreferSimpleProtocol = true

		db := stdlib.OpenDB(*connConfig)
		opened = append(opened, db)
		for _, opt := range allPoolOptions {
			opt(db)
		}

		nodes = append(nodes, hasql.NewNode(host, db))
	}

	client := &PGClient{onCheckedNode: func(*sql.DB, time.Duration) {}}
	for _, opt := range clientOptions {
		opt(client)
	}

	cluster, err := hasql.NewCluster(nodes, client.nodeChecker, clusterOptions...)
	if err != nil {
		closeOpened()
		return nil, err
	}
	client.cluster = cluster

	ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
	defer cancel()
	if _, err := cluster.WaitForPrimary(ctx); err != nil {
		// hasql documents Cluster.Close as stopping node updates and closing
		// the nodes' databases, so no goroutine or connection outlives us.
		_ = cluster.Close()
		return nil, err
	}

	return client, nil
}

// buildPGConnString renders a libpq keyword/value connection string. String
// values are single-quoted and escaped so that empty values (such as an empty
// password) and values containing spaces, quotes or backslashes are parsed
// verbatim instead of swallowing the following keyword.
func buildPGConnString(host string, port int, dbName, user, password string) string {
	return fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s",
		quotePGConnValue(host), port, quotePGConnValue(user),
		quotePGConnValue(password), quotePGConnValue(dbName),
	)
}

// quotePGConnValue wraps v in single quotes, backslash-escaping embedded
// backslashes and single quotes as the keyword/value format requires.
func quotePGConnValue(v string) string {
	buf := make([]byte, 0, len(v)+2)
	buf = append(buf, '\'')
	// Byte-wise iteration is safe for UTF-8: '\\' and '\'' are ASCII and never
	// occur inside a multi-byte sequence.
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\\' || c == '\'' {
			buf = append(buf, '\\')
		}
		buf = append(buf, v[i])
	}
	buf = append(buf, '\'')
	return string(buf)
}

// GetPrimary returns a gorm handle bound to the cluster's current primary
// node. It returns ErrPrimaryNodeIsUnavailable when the cluster reports no
// primary; any error from building the gorm handle is returned unchanged.
func (c *PGClient) GetPrimary() (*gorm.DB, error) {
	primary := c.cluster.Primary()
	if primary == nil {
		return nil, ErrPrimaryNodeIsUnavailable
	}
	return gormWrapper(primary, c.logLevel)
}

// ExecuteInTransaction selects a cluster node matching nodeState, wraps it in
// a gorm handle and runs processFunc inside gorm's DB.Transaction on it.
// Per gorm's Transaction semantics, returning nil from processFunc commits
// and returning an error rolls back.
//
// It returns ErrNodeIsUnavailable when no node matches nodeState. Errors from
// building the gorm handle or from the transaction are returned unchanged.
func (c *PGClient) ExecuteInTransaction(
	nodeState hasql.NodeStateCriteria,
	processFunc func(*gorm.DB) error,
) error {
	target := c.cluster.Node(nodeState)
	if target == nil {
		return ErrNodeIsUnavailable
	}

	gormDB, err := gormWrapper(target, c.logLevel)
	if err != nil {
		return err
	}
	return gormDB.Transaction(processFunc)
}
