package pgclient

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	"golang.yandex/hasql"
	"gorm.io/gorm"
)

// PGClient gives access to a set of PostgreSQL hosts managed as a single
// hasql cluster, handing out gorm handles bound to individual nodes.
// Construct it with NewPGClient; the zero value has no cluster and is not usable.
type PGClient struct {
	// cluster is the hasql cluster built from every configured host;
	// it is assigned in NewPGClient.
	cluster *hasql.Cluster
}

// Sentinel errors reported when the cluster cannot supply a suitable node.
// Compare against them with errors.Is.
var (
	// ErrNodeIsUnavailable is returned when no node matching the requested
	// selection (standby-preferred or an explicit state criterion) is available.
	ErrNodeIsUnavailable = errors.New("postgres: node is unavailable")

	// ErrPrimaryNodeIsUnavailable is returned when the cluster currently
	// reports no primary node.
	ErrPrimaryNodeIsUnavailable = errors.New("postgres: Primary node is unavailable")
)

func NewPGClient(
	hosts []string,
	port int,
	dbName, user, password string,
	initTimeout time.Duration,
	clusterOptions []hasql.ClusterOption,
	poolOptions ...Option,
) (*PGClient, error) {
	var nodes []hasql.Node
	for _, host := range hosts {
		connString := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", host, port, user, password, dbName)
		connConfig, err := pgx.ParseConfig(connString)
		if err != nil {
			return nil, err
		}

		// workaround for https://github.com/jackc/pgx/issues/602
		connConfig.BuildStatementCache = nil
		connConfig.PreferSimpleProtocol = true

		db := stdlib.OpenDB(*connConfig)
		for _, opt := range append(defaultPoolOptions, poolOptions...) {
			opt(db)
		}

		nodes = append(nodes, hasql.NewNode(host, db))
	}

	client := &PGClient{}
	cluster, err := hasql.NewCluster(nodes, client.nodeChecker, clusterOptions...)
	if err != nil {
		return nil, err
	}
	client.cluster = cluster

	ctx, cancel := context.WithTimeout(context.Background(), initTimeout)
	defer cancel()
	if _, err := cluster.WaitForPrimary(ctx); err != nil {
		return nil, err
	}

	return client, nil
}

// GetPrimary returns a gorm handle bound to the node the cluster currently
// considers primary. It returns ErrPrimaryNodeIsUnavailable when the cluster
// reports no primary, and propagates any error from gormWrapper.
func (c *PGClient) GetPrimary() (*gorm.DB, error) {
	primary := c.cluster.Primary()
	if primary == nil {
		return nil, ErrPrimaryNodeIsUnavailable
	}
	return gormWrapper(primary)
}

// GetReadable returns a gorm handle for read traffic, preferring a standby
// node as selected by the cluster's StandbyPreferred policy. It returns
// ErrNodeIsUnavailable when no node is selected, and propagates any error
// from gormWrapper.
func (c *PGClient) GetReadable() (*gorm.DB, error) {
	reader := c.cluster.StandbyPreferred()
	if reader == nil {
		return nil, ErrNodeIsUnavailable
	}
	return gormWrapper(reader)
}

// ExecuteInTransaction selects a node matching nodeState and runs processFunc
// through gorm's Transaction helper on that node.
//
// It returns ErrNodeIsUnavailable when no node matches nodeState. Errors from
// gormWrapper are propagated, and otherwise the result of db.Transaction is
// returned unchanged.
func (c *PGClient) ExecuteInTransaction(
	nodeState hasql.NodeStateCriteria,
	processFunc func(*gorm.DB) error,
) error {
	target := c.cluster.Node(nodeState)
	if target == nil {
		return ErrNodeIsUnavailable
	}

	db, err := gormWrapper(target)
	if err != nil {
		return err
	}
	return db.Transaction(processFunc)
}
