package db

import (
	"context"
	"crypto/tls"
	"database/sql/driver"
	"errors"
	"fmt"
	"net/url"
	"time"

	"github.com/ClickHouse/clickhouse-go"
	"github.com/cenkalti/backoff/v4"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/gideon/speedy-beaver/internal/config"
)

// Timing knobs used when talking to ClickHouse.
const (
	// pingTimeout bounds a single connectivity check (see pingConn) and the
	// initial reachability check performed by newChDB.
	pingTimeout = 200 * time.Millisecond
	// retryTimeout is the constant delay Run waits between failed attempts.
	retryTimeout = 100 * time.Millisecond
)

// chFn is a unit of work that chDB.Run executes against a live ClickHouse
// connection; a non-nil error marks the attempt as failed so Run may retry it.
type chFn func(clickhouse.Clickhouse) error

// chDB wraps a single, lazily (re)established ClickHouse connection and runs
// work against it with retries (see Run).
//
// NOTE(review): chConn is read and replaced without any synchronization in
// Run and conn, so a chDB presumably must not be used from several goroutines
// at once — confirm with callers.
type chDB struct {
	// cfg holds the connection settings; Run also reads cfg.Retries from it.
	cfg config.ClickHouse
	// chURI is the DSN derived from cfg in newChDB, with the "mdb" TLS config
	// and the "time_random" connection open strategy.
	chURI string
	// chConn is the cached connection: nil until first use, and reset to nil
	// by Run when the driver reports driver.ErrBadConn.
	chConn clickhouse.Clickhouse
	// log receives retry errors from Run; newChDB sets it to a no-op logger.
	log log.Logger
}

// newChDB builds a chDB for cfg. It:
//   - registers the "mdb" TLS config, which trusts the internal CA pool;
//   - derives the connection URI from cfg using that TLS config and the
//     "time_random" open strategy;
//   - pings ClickHouse, bounded by pingTimeout, to make sure at least one
//     backend is reachable.
//
// If the ping fails, any connection opened while pinging is closed before
// the error is returned, so a failed constructor does not leak a socket.
func newChDB(cfg config.ClickHouse) (*chDB, error) {
	caCertPool, err := certifi.NewCertPoolInternal()
	if err != nil {
		return nil, fmt.Errorf("failed to setup internal CA cert pool: %w", err)
	}

	if err := clickhouse.RegisterTLSConfig("mdb", &tls.Config{
		RootCAs: caCertPool,
	}); err != nil {
		return nil, fmt.Errorf("failed to register TLS config: %w", err)
	}

	out := &chDB{
		cfg: cfg,
		log: &nop.Logger{},
		chURI: cfg.URI(url.Values{
			"tls_config":               {"mdb"},
			"connection_open_strategy": {"time_random"},
		}),
	}

	// Validate that we have at least one backend alive.
	pingCtx, pingCancel := context.WithTimeout(context.Background(), pingTimeout)
	defer pingCancel()
	if err := out.Ping(pingCtx); err != nil {
		// Ping may have opened and cached a connection before failing; out is
		// discarded here, so release that connection instead of leaking it.
		if out.chConn != nil {
			_ = out.chConn.Close()
			out.chConn = nil
		}
		return nil, fmt.Errorf("ping failed: %w", err)
	}

	return out, nil
}

// Close releases the cached ClickHouse connection, if there is one.
//
// It is safe to call when no connection is cached, for example after Run
// dropped a connection the driver reported as bad. Calling it again after a
// successful Close is a no-op rather than a double close.
func (p *chDB) Close() error {
	if p.chConn == nil {
		return nil
	}

	err := p.chConn.Close()
	p.chConn = nil
	return err
}

// Ping checks that ClickHouse is reachable. It runs the check through Run,
// so it uses the cached connection when there is one and is retried like any
// other operation.
//
// If the connection does not implement driver.Pinger, Ping returns an error
// instead of panicking. That failure is marked permanent so Run does not
// retry a check that can never succeed.
func (p *chDB) Ping(ctx context.Context) error {
	return p.Run(ctx, func(conn clickhouse.Clickhouse) error {
		pinger, ok := conn.(driver.Pinger)
		if !ok {
			return backoff.Permanent(
				fmt.Errorf("clickhouse connection %T does not implement driver.Pinger", conn),
			)
		}
		return pinger.Ping(ctx)
	})
}

// Run executes fn against the cached ClickHouse connection, opening one on
// demand via conn.
//
// Failed attempts are retried with a constant retryTimeout delay. Retries
// stop after at most p.cfg.Retries retries, or earlier once ctx is done.
//
// When fn fails with driver.ErrBadConn, the cached connection is closed and
// forgotten so the next attempt dials a fresh one. Such failures are not
// logged; every other failure is logged before the retry delay.
func (p *chDB) Run(ctx context.Context, fn chFn) error {
	policy := backoff.WithContext(
		backoff.WithMaxRetries(backoff.NewConstantBackOff(retryTimeout), p.cfg.Retries),
		ctx,
	)

	attempt := func() error {
		conn, err := p.conn()
		if err != nil {
			return err
		}

		opErr := fn(conn)
		if opErr != nil && errors.Is(opErr, driver.ErrBadConn) {
			// ClickHouse dropped the connection: discard it so the next
			// attempt reconnects.
			_ = p.chConn.Close()
			p.chConn = nil
		}
		return opErr
	}

	notify := func(err error, sleep time.Duration) {
		if !errors.Is(err, driver.ErrBadConn) {
			p.log.Error("query error", log.Error(err), log.Duration("sleep", sleep))
		}
	}

	return backoff.RetryNotify(attempt, policy, notify)
}

// conn returns the cached ClickHouse connection. When none is cached, it
// dials p.chURI and checks the new connection with pingConn before caching it.
//
// A freshly opened connection that fails the check is closed rather than
// leaked, and it is not cached, so the next call dials again.
func (p *chDB) conn() (clickhouse.Clickhouse, error) {
	if p.chConn != nil {
		return p.chConn, nil
	}

	conn, err := clickhouse.OpenDirect(p.chURI)
	if err != nil {
		return nil, fmt.Errorf("open connection: %w", err)
	}

	if err := pingConn(conn); err != nil {
		// The connection is about to be discarded; release its socket.
		_ = conn.Close()
		return nil, fmt.Errorf("new connection is not ok: %w", err)
	}

	p.chConn = conn
	return conn, nil
}

// pingConn checks a single connection and gives up after pingTimeout. The
// check is independent of any caller context: it derives its own deadline
// from context.Background.
//
// It returns an error instead of panicking when conn does not implement
// driver.Pinger.
func pingConn(conn clickhouse.Clickhouse) error {
	pinger, ok := conn.(driver.Pinger)
	if !ok {
		return fmt.Errorf("clickhouse connection %T does not implement driver.Pinger", conn)
	}

	pingCtx, pingCancel := context.WithTimeout(context.Background(), pingTimeout)
	defer pingCancel()
	return pinger.Ping(pingCtx)
}
