package database

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"strings"
	"time"

	"a.yandex-team.ru/infra/skyboned/go/src/auth"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/pgxpool"
	"go.uber.org/zap"
)

// DBConfGlobal is the package-wide database configuration. DBFlags binds the
// command-line flags to its fields, and IsMaster reads its credentials.
var DBConfGlobal DBConf

// DBConf holds the settings used to build a PostgreSQL connection string.
// DBhost may be a comma-separated list of hosts (DBSetup splits it to pick the
// master). DBmasterPool and DBreplicaPool are the pool_max_conns values used
// by Connect for the master and replica pools respectively.
//
// Field order is DBhost, DBuser, DBpassword, DBname, DBmasterPool, DBport,
// DBreplicaPool.
type DBConf struct {
	DBhost, DBuser, DBpassword, DBname  string
	DBmasterPool, DBport, DBreplicaPool int
}

// DB holds the connection pools created by DBSetup.
type DB struct {
	Master  *pgxpool.Pool // pool on the master host; Exec and TxStart use it
	Replica *pgxpool.Pool // replica pool (master host if it is the only one); Query and QueryRow use it
	Hosts   []string      // hosts split from DBConf.DBhost
}

// DBFlags registers the database command-line flags on the default flag set,
// binding each one to a field of DBConfGlobal. Call it once, before
// flag.Parse (registering the same flag name twice panics).
//
// NOTE(review): -db-password makes the password visible in the process list;
// consider reading it from the environment or a file instead.
func DBFlags() {
	flag.StringVar(&DBConfGlobal.DBhost, "db-host", "", "database host(s), comma-separated")
	flag.IntVar(&DBConfGlobal.DBport, "db-port", 0, "database port")
	flag.StringVar(&DBConfGlobal.DBuser, "db-user", "", "database user")
	flag.StringVar(&DBConfGlobal.DBpassword, "db-password", "", "database password")
	flag.StringVar(&DBConfGlobal.DBname, "db-name", "", "database name")
	// These two previously claimed to be "database port"; they size the pools.
	flag.IntVar(&DBConfGlobal.DBmasterPool, "db-master-pool", 20, "max connections in the master pool")
	flag.IntVar(&DBConfGlobal.DBreplicaPool, "db-replica-pool", 100, "max connections in the replica pool")
}

// Connect opens a pgx connection pool using the settings in dbconf.
//
// When master is true the pool is capped at dbconf.DBmasterPool connections
// (pool_max_conns), otherwise at dbconf.DBreplicaPool. dbconf.DBhost may be a
// comma-separated host list; it is passed through to pgx after trimming
// whitespace around each entry. The pool also sets
// pool_max_conn_idle_time=60s and statement_cache_mode=describe.
//
// String settings are single-quoted and backslash-escaped, so values with
// spaces, quotes or backslashes (typically passwords), and empty values, are
// parsed as written instead of corrupting the rest of the connection string.
// Errors are logged and returned wrapped with the target host(s).
func Connect(dbconf *DBConf, master bool) (*pgxpool.Pool, error) {
	poolSize := dbconf.DBreplicaPool
	if master {
		poolSize = dbconf.DBmasterPool
	}

	host := trimHostList(dbconf.DBhost)
	connString := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s "+
		"pool_max_conns=%d pool_max_conn_idle_time=60s statement_cache_mode=describe",
		quoteConnValue(host), dbconf.DBport, quoteConnValue(dbconf.DBuser),
		quoteConnValue(dbconf.DBpassword), quoteConnValue(dbconf.DBname), poolSize)

	poolConfig, err := pgxpool.ParseConfig(connString)
	if err != nil {
		err = fmt.Errorf("parsing connection config for %q: %w", host, err)
		zap.S().Error(err)
		return nil, err
	}
	pool, err := pgxpool.ConnectConfig(context.Background(), poolConfig)
	if err != nil {
		err = fmt.Errorf("connecting to %q: %w", host, err)
		zap.S().Error(err)
		return nil, err
	}
	return pool, nil
}

// quoteConnValue single-quotes s for a keyword/value connection string,
// escaping backslashes and single quotes with a backslash.
func quoteConnValue(s string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(s) + "'"
}

// trimHostList removes whitespace around each entry of a comma-separated host
// list, e.g. " h1, h2 " becomes "h1,h2".
func trimHostList(hosts string) string {
	parts := strings.Split(hosts, ",")
	for i, p := range parts {
		parts[i] = strings.TrimSpace(p)
	}
	return strings.Join(parts, ",")
}

// IsMaster reports whether host is writable, i.e. whether
// `SELECT pg_is_in_recovery()` returns false there. It connects with the
// credentials in DBConfGlobal through a one-connection pool that is closed
// before returning. Any connection or query failure yields false.
func IsMaster(host string) bool {
	return isMasterHost(&DBConfGlobal, host)
}

// isMasterHost is IsMaster with explicit settings: it probes host using a copy
// of base (base itself is never modified).
func isMasterHost(base *DBConf, host string) bool {
	conf := *base
	conf.DBhost = host
	conf.DBmasterPool = 1
	pool, err := Connect(&conf, true)
	if err != nil {
		// Connect has already logged the failure.
		return false
	}
	defer pool.Close()

	var inRecovery bool
	if err := pool.QueryRow(context.Background(), "SELECT pg_is_in_recovery()").Scan(&inRecovery); err != nil {
		zap.S().Warnf("could not determine role of db host %v: %v", host, err)
		return false
	}
	return !inRecovery
}

// DBSetup connects to the database cluster described by dbconf.
//
// dbconf.DBhost is split on commas (whitespace around entries is trimmed).
// The first host that reports itself as master, probed with dbconf's
// credentials, gets a pool of dbconf.DBmasterPool connections. A replica pool
// of dbconf.DBreplicaPool connections is opened against the remaining hosts,
// passed to pgx as a comma-separated list; with a single host the replica pool
// also targets the master. dbconf itself is not modified.
//
// It returns an error when no host is a master or the master pool cannot be
// opened. A replica connection failure is only logged, leaving Replica nil.
func DBSetup(dbconf *DBConf) (*DB, error) {
	hosts := strings.Split(dbconf.DBhost, ",")
	for i := range hosts {
		hosts[i] = strings.TrimSpace(hosts[i])
	}

	// Stop at the first master so no extra pools are opened (and leaked).
	masterIdx := -1
	for i, h := range hosts {
		if isMasterHost(dbconf, h) {
			masterIdx = i
			break
		}
	}
	if masterIdx < 0 {
		return nil, fmt.Errorf("no master found among db hosts %q", hosts)
	}
	masterHost := hosts[masterIdx]

	// Work on a copy so the caller's configuration keeps its host list.
	conf := *dbconf
	conf.DBhost = masterHost
	master, err := Connect(&conf, true)
	if err != nil {
		return nil, fmt.Errorf("master connection failed: %w", err)
	}
	zap.S().Infof("master chosen: %v pool: %v", masterHost, conf.DBmasterPool)

	// Build the replica list in a fresh slice: appending to hosts[:masterIdx]
	// would overwrite hosts' backing array, which is returned as DB.Hosts.
	replicas := make([]string, 0, len(hosts)-1)
	replicas = append(replicas, hosts[:masterIdx]...)
	replicas = append(replicas, hosts[masterIdx+1:]...)
	replicaHosts := strings.Join(replicas, ",")
	if replicaHosts == "" {
		replicaHosts = masterHost
		zap.S().Warnf("could not separate master and replica - are you running one db host?")
	}
	conf.DBhost = replicaHosts
	replica, err := Connect(&conf, false)
	if err != nil {
		zap.S().Errorf("replica connection failed: %v", err)
	}
	zap.S().Infof("slave(s) chosen: %v pool: %v", replicaHosts, conf.DBreplicaPool)

	return &DB{
		Master:  master,
		Replica: replica,
		Hosts:   hosts,
	}, nil
}

// TxStart begins a transaction on the master pool using ctx. On failure it
// returns a nil pgx.Tx together with the error.
func (db *DB) TxStart(ctx context.Context) (pgx.Tx, error) {
	var (
		tx  pgx.Tx
		err error
	)
	if tx, err = db.Master.Begin(ctx); err != nil {
		return nil, err
	}
	return tx, nil
}

// QueryRow runs query with args on the replica pool and returns the row for
// the caller to Scan.
//
// ctx is now passed to pgx (previously context.Background() was used), so
// cancelling ctx or hitting its deadline aborts the query. ctx is also handed
// to LogQueryTime to tag slow-query warnings. The measured duration ends when
// this call returns; any work done later in Scan is not counted.
func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) pgx.Row {
	defer LogQueryTime(ctx, time.Now(), query)
	return db.Replica.QueryRow(ctx, query, args...)
}

// Query runs query with args on the replica pool and returns the result rows.
//
// ctx is now passed to pgx (previously context.Background() was used), so
// cancelling ctx or hitting its deadline aborts the query. ctx is also handed
// to LogQueryTime to tag slow-query warnings. The measured duration ends when
// this call returns, so time spent iterating the rows is not counted.
func (db *DB) Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) {
	defer LogQueryTime(ctx, time.Now(), query)
	return db.Replica.Query(ctx, query, args...)
}

// Exec runs a statement on the master pool and returns its command tag.
// Slow statements are reported via LogQueryTime, and the resulting error is
// passed to checkDBError, which may terminate the process.
//
// NOTE(review): the statement runs under context.Background(), so ctx is used
// only for logging and cancelling it does not abort the write. This may be
// deliberate (not aborting writes when a client goes away); confirm before
// switching to ctx.
func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (ct pgconn.CommandTag, err error) {
	// Deferred calls run in reverse order: checkDBError sees the final err
	// first, then LogQueryTime reports the elapsed time. time.Now() is
	// evaluated right here, when the defer statement executes.
	defer LogQueryTime(ctx, time.Now(), query)
	defer checkDBError(&err)
	ct, err = db.Master.Exec(context.Background(), query, args...)
	return ct, err
}

// TxExec runs a statement inside tx and returns its command tag. Slow
// statements are reported via LogQueryTime, and the resulting error is passed
// to checkDBError, which may terminate the process.
//
// NOTE(review): like Exec, the statement runs under context.Background(), so
// ctx only feeds logging; confirm this is intended before changing it.
func (db *DB) TxExec(ctx context.Context, tx pgx.Tx, query string, args ...interface{}) (ct pgconn.CommandTag, err error) {
	// time.Now() is captured when the defer is registered, i.e. before the call.
	defer LogQueryTime(ctx, time.Now(), query)
	defer checkDBError(&err)
	ct, err = tx.Exec(context.Background(), query, args...)
	return ct, err
}

func LogQueryTime(ctx context.Context, start time.Time, query string) {
	if rt := time.Since(start); rt > time.Second {
		zap.S().Warnw(ctx.Value(auth.RequestID{}).(string),
			"Query", query,
			"Time", rt)
	}
}

// checkDBError terminates the process via zap Fatal when *err wraps a
// PostgreSQL error with SQLSTATE 25006 (read_only_sql_transaction). That code
// appears when a write reaches a read-only connection, which here means the
// master has moved to another host. Presumably the process is expected to be
// restarted and re-run DBSetup (TODO confirm). Any other error, or nil, is
// left untouched.
func checkDBError(err *error) {
	var pgErr *pgconn.PgError
	if !errors.As(*err, &pgErr) || pgErr.Code != "25006" {
		return
	}
	zap.S().Fatal(*err)
}
