package models

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	"golang.yandex/hasql"
	"golang.yandex/hasql/checkers"
)

// DB is the package's database handle: a hasql cluster of PostgreSQL nodes.
// The *hasql.Cluster is embedded, so its methods (for example WaitForPrimary
// and WaitForStandby) are promoted and callable directly on *DB.
// Construct it with NewDB.
type DB struct {
	*hasql.Cluster
}

// DBConfig carries the connection settings NewDB uses to open one pool per
// PostgreSQL host. Field order is significant for callers that build it with
// an unkeyed composite literal, so it must not be rearranged.
type DBConfig struct {
	// User and Password are the credentials presented to every host.
	User, Password string

	// DBName is the database opened on every host.
	DBName string

	// Hosts is a comma-separated list of PostgreSQL hosts, e.g. "db1,db2".
	Hosts string

	// Port is shared by all hosts; SSLMode is a libpq sslmode value such
	// as "disable" or "verify-full".
	Port, SSLMode string

	// PreferSimpleProtocol corresponds to pgx's
	// ConnConfig.PreferSimpleProtocol (simple query protocol instead of
	// the extended protocol with prepared statements).
	PreferSimpleProtocol bool
}

// NewDB builds a hasql cluster with one *sql.DB pool per host in
// config.Hosts and blocks until both a primary and a standby are available.
//
// config.Hosts is split on ","; surrounding whitespace is trimmed and empty
// entries are ignored. All hosts share config.Port, User, Password, DBName
// and SSLMode. Each pool is capped at 20 open connections and uses pgx's
// simple query protocol only when config.PreferSimpleProtocol is set.
//
// Node roles are detected with checkers.PostgreSQL. The combined wait for a
// primary and then a standby is bounded by 30 seconds.
//
// On failure every pool opened so far is closed, and nil is returned with a
// wrapped error.
func NewDB(config DBConfig) (*DB, error) {
	var nodes []hasql.Node

	// closeNodes releases the pools opened so far. It is used on failure
	// paths before a cluster exists to take ownership of them.
	closeNodes := func() {
		for _, node := range nodes {
			_ = node.DB().Close() // best effort: the triggering error is what gets reported
		}
	}

	for _, host := range strings.Split(config.Hosts, ",") {
		host = strings.TrimSpace(host)
		if host == "" {
			continue // tolerate "a,,b", "a, b" and trailing commas
		}

		connConfig, err := pgx.ParseConfig(buildPGConnString(host, config))
		if err != nil {
			closeNodes()
			return nil, fmt.Errorf("NewDB(): parse connection string: %w", err)
		}
		// Honor the caller's setting instead of unconditionally forcing the
		// simple protocol.
		connConfig.PreferSimpleProtocol = config.PreferSimpleProtocol

		db, err := sql.Open("pgx", stdlib.RegisterConnConfig(connConfig))
		if err != nil {
			closeNodes()
			return nil, fmt.Errorf("NewDB(): failed to connect to %s: %w", host, err)
		}
		db.SetMaxOpenConns(20)

		nodes = append(nodes, hasql.NewNode(host, db))
	}

	c, err := hasql.NewCluster(nodes, checkers.PostgreSQL)
	if err != nil {
		closeNodes()
		return nil, fmt.Errorf("NewDB(): failed to create cluster: %w", err)
	}

	// One shared deadline covers both waits below.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	if _, err := c.WaitForPrimary(ctx); err != nil {
		// The cluster now owns the pools; closing it stops its background
		// checks and closes the node databases.
		_ = c.Close()
		return nil, fmt.Errorf("NewDB(): wait for master: %w", err)
	}

	if _, err := c.WaitForStandby(ctx); err != nil {
		_ = c.Close()
		return nil, fmt.Errorf("NewDB(): wait for replica: %w", err)
	}

	return &DB{c}, nil
}

// buildPGConnString renders a libpq keyword/value connection string for a
// single host. Every value is quoted so that passwords or names containing
// spaces, quotes or backslashes (and empty values) cannot corrupt the parse.
func buildPGConnString(host string, config DBConfig) string {
	return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
		quotePGConnValue(host),
		quotePGConnValue(config.Port),
		quotePGConnValue(config.User),
		quotePGConnValue(config.Password),
		quotePGConnValue(config.DBName),
		quotePGConnValue(config.SSLMode),
	)
}

// quotePGConnValue wraps v in single quotes. Backslashes and single quotes
// are escaped with a backslash, as the libpq keyword/value format requires.
// Backslashes are escaped first so the quote escapes are not doubled.
func quotePGConnValue(v string) string {
	v = strings.ReplaceAll(v, `\`, `\\`)
	v = strings.ReplaceAll(v, `'`, `\'`)
	return "'" + v + "'"
}
