package clickhouse

import (
	"database/sql/driver"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/url"
	"runtime/pprof"
	"strconv"
	"sync"
	"time"

	ch "github.com/ClickHouse/clickhouse-go"

	"a.yandex-team.ru/security/osquery/osquery-sender/util"
)

const (
	// defaultPort is the port NewPool connects to when PoolParams.Port is zero.
	// NOTE(review): 9440 is presumably the TLS-enabled native Clickhouse port — confirm against the server config.
	defaultPort = 9440
)

// PoolParams holds the settings NewPool uses to build a ClickhousePool.
type PoolParams struct {
	// Hosts lists the Clickhouse hosts; NewPool creates one per-host pool for each entry.
	Hosts     []string
	// Port is the TCP port used for every host; zero selects defaultPort.
	Port      int
	// Params are appended to each connection URL as query parameters. The "database" key is mandatory.
	Params    map[string]string
	// Size is the number of connections opened to each host.
	Size      int
	// ClusterID is not read anywhere in this file.
	// NOTE(review): presumably consumed by callers that build cluster-wide queries — confirm.
	ClusterID string

	// NumRetries is the total number of attempts ClickhousePool.Run makes for one call.
	NumRetries        int
	// RetryBackoff is the intended pause between retries and between reconnect attempts.
	RetryBackoff      time.Duration
	// WaitForConnection bounds how long a single attempt waits for an idle connection of one host.
	WaitForConnection time.Duration
}

// ConnProc is the callback ClickhousePool.Run executes on a pooled connection.
// Returning a *ch.Exception makes Run retry the callback (while attempts remain); any other non-nil error
// aborts Run immediately.
type ConnProc func(conn ch.Clickhouse) error

// ClickhousePool is a connection pool which maintains connections to multiple hosts in one Clickhouse cluster.
// Run starts at a random host and moves on to the next host (round-robin) on every retry.
type ClickhousePool struct {
	// hosts and pools are parallel slices: pools[i] holds the connections to hosts[i].
	hosts []string
	pools []*clickhouseSimplePool

	// database is the value of the mandatory "database" connection parameter.
	database string

	// numRetries is the total number of attempts Run makes for one call.
	numRetries   int
	// retryBackoff is how long runIter sleeps before reporting a retryable failure.
	retryBackoff time.Duration

	// waitForConnection bounds how long one attempt waits for an idle connection.
	waitForConnection time.Duration
}

// NewPool creates a ClickhousePool with params.Size connections to every host in params.Hosts.
//
// params.Params must contain the "database" key; all entries of params.Params are passed to the driver as
// URL query parameters. A zero params.Port selects defaultPort.
//
// An error is returned when "database" is missing or when fewer than params.Size connections could be
// established in total across all hosts. Hosts that could not be reached at construction time keep being
// retried in the background by their per-host pools.
func NewPool(params PoolParams) (*ClickhousePool, error) {
	database, ok := params.Params["database"]
	if !ok {
		return nil, errors.New("missing 'database' in Clickhouse parameters")
	}
	ret := &ClickhousePool{
		hosts:             params.Hosts,
		database:          database,
		numRetries:        params.NumRetries,
		retryBackoff:      params.RetryBackoff,
		waitForConnection: params.WaitForConnection,
	}
	usePort := params.Port
	if usePort == 0 {
		usePort = defaultPort
	}
	urlParams := url.Values{}
	for key, value := range params.Params {
		urlParams.Add(key, value)
	}
	// The query string is identical for every host, so encode it only once.
	query := urlParams.Encode()

	for _, host := range params.Hosts {
		churl := fmt.Sprintf("tcp://%s:%d?%s", host, usePort, query)
		pool := newPool(churl, host, params.Size, params.RetryBackoff)
		ret.pools = append(ret.pools, pool)
	}

	// Validate that we have at least one backend alive.
	totalConnections := 0
	for _, p := range ret.pools {
		totalConnections += p.numAliveConns()
	}
	if totalConnections < params.Size {
		// The caller never receives the pool, so release the connections that did get established.
		ret.Close()
		return nil, errors.New("could not establish connection to Clickhouse")
	}

	return ret, nil
}

// Close closes the connections of every per-host pool.
func (p *ClickhousePool) Close() {
	for i := range p.pools {
		p.pools[i].close()
	}
}

// Run executes proc on a pooled connection, making at most numRetries attempts.
//
// The first attempt goes to a randomly chosen host; every following attempt moves on to the next host in
// the list (wrapping around), so a retry is served by a different host whenever more than one is configured.
// Run returns nil as soon as one attempt succeeds. Otherwise it returns the error of the attempt that ended
// the sequence, or a descriptive error when no attempt could be made at all (no hosts, or numRetries <= 0).
func (p *ClickhousePool) Run(proc ConnProc) error {
	if len(p.pools) == 0 {
		// rand.Intn panics when its argument is not positive.
		return errors.New("ClickhousePool.Run: no hosts configured")
	}

	var lastErr error
	startIdx := rand.Intn(len(p.pools))
	for i := 0; i < p.numRetries; i++ {
		idx := (startIdx + i) % len(p.pools)
		retry, err := p.runIter(p.pools[idx], p.hosts[idx], proc, i)
		if err == nil {
			return nil
		}
		if !retry {
			return err
		}
		lastErr = err
	}
	if lastErr != nil {
		return lastErr
	}
	// The loop body never ran.
	return fmt.Errorf("ClickhousePool.Run: no attempts were made (numRetries=%d)", p.numRetries)
}

// runIter performs one attempt of proc against the given per-host pool.
//
// retryNum is the zero-based attempt number. The boolean result tells the caller whether another attempt
// should be made; it is true only when the attempt failed for a retryable reason (no connection available,
// or proc returned a *ch.Exception) and this was not the last attempt. For a retryable *ch.Exception the
// method sleeps for retryBackoff before returning.
func (p *ClickhousePool) runIter(pool *clickhouseSimplePool, host string, proc ConnProc, retryNum int) (bool, error) {
	lastAttempt := retryNum == p.numRetries-1

	conn := pool.get(p.waitForConnection)
	if conn == nil {
		return !lastAttempt, fmt.Errorf("no connections to %s", pool.host)
	}
	// The connection goes back to the pool whatever the outcome of proc.
	defer pool.put(conn)

	err := proc(conn)
	if err == nil {
		return false, nil
	}

	if _, isException := err.(*ch.Exception); !isException {
		log.Printf("ERROR: non-retryable error (%s): %v\n", host, err)
		return false, err
	}

	// We really should check if the error is transient.
	if lastAttempt {
		log.Printf("ERROR: giving up (%s): %v\n", host, err)
		return false, err
	}
	log.Printf("ERROR: retrying in %v (%s): %v\n", p.retryBackoff, host, err)
	time.Sleep(p.retryBackoff)
	return true, err
}

// A simple round-robin connection pool for a single host, with liveness checks when taking connection.
type clickhouseSimplePool struct {
	// url is the full connection URL passed to the driver; host is used in log and error messages.
	url  string
	host string

	// retryBackoff is the pause between consecutive reconnect attempts in addNewConnLoop.
	retryBackoff time.Duration

	// available holds the idle connections: get receives from it, put and addNewConn send to it.
	// Its capacity equals the pool size.
	available chan ch.Clickhouse

	// Store all created connections so that we can close them on program exit.
	// Connections currently handed out by get stay in this slice; connections found dead are removed.
	allConns   []ch.Clickhouse
	// allConnsMu guards allConns.
	allConnsMu sync.Mutex
}

// newPool creates the pool for one host and tries to open size connections to churl right away.
// For every connection that cannot be opened immediately a background goroutine is started which keeps
// retrying (pausing retryBackoff between tries) until that connection is established.
func newPool(churl string, host string, size int, retryBackoff time.Duration) *clickhouseSimplePool {
	pool := &clickhouseSimplePool{
		url:          churl,
		host:         host,
		retryBackoff: retryBackoff,
		available:    make(chan ch.Clickhouse, size),
		allConns:     make([]ch.Clickhouse, 0, size),
	}

	for slot := 0; slot < size; slot++ {
		err := pool.addNewConn()
		if err == nil {
			continue
		}
		log.Printf("ERROR: could not connect to %s: %v", host, err)
		// Continue trying to establish a connection.
		labels := pprof.Labels("name", "clickhouse-new-pool-try-connect-"+strconv.Itoa(slot))
		go util.RunWithLabels(labels, func() {
			pool.addNewConnLoop()
		})
	}

	return pool
}

// close closes every connection tracked by the pool, logging (but otherwise ignoring) close errors.
func (p *clickhouseSimplePool) close() {
	p.allConnsMu.Lock()
	defer p.allConnsMu.Unlock()

	for i := range p.allConns {
		if err := p.allConns[i].Close(); err != nil {
			log.Printf("ERROR: closing connection failed (%s): %v\n", p.host, err)
		}
	}
}

// get takes an idle connection out of the pool, waiting at most maxWait for one to become available.
//
// It returns nil when the pool tracks no connections at all, when no connection became idle within maxWait,
// or when the connection taken failed its liveness check. In the last case the dead connection is dropped
// and a background goroutine is started to replace it. A nil result means the caller should do another
// retry (with another server).
func (p *clickhouseSimplePool) get(maxWait time.Duration) ch.Clickhouse {
	// Immediately fail if all connections are dead (e.g. the server is offline).
	if p.numAliveConns() == 0 {
		return nil
	}

	// Use an explicit timer instead of time.After so that it is released as soon as we return rather than
	// staying pending for the whole maxWait.
	timer := time.NewTimer(maxWait)
	defer timer.Stop()

	var conn ch.Clickhouse
	select {
	case conn = <-p.available:
	case <-timer.C:
		return nil
	}
	if p.connIsOk(conn) {
		return conn
	}
	p.removeFromAllConns(conn)
	// The connection has just failed its liveness check, so a close error carries no extra information.
	_ = conn.Close()

	// The available channel serves as a semaphore: we never spawn more goroutines than the size of the pool.
	go util.RunWithLabels(pprof.Labels("name", "clickhouse-pool-reconnect-on-error"), func() {
		p.addNewConnLoop()
	})
	return nil
}

// put returns a connection previously obtained from get to the set of idle connections.
func (p *clickhouseSimplePool) put(conn ch.Clickhouse) {
	p.available <- conn
}

// addNewConnLoop keeps calling addNewConn until it succeeds, sleeping retryBackoff after every failure.
func (p *clickhouseSimplePool) addNewConnLoop() {
	for err := p.addNewConn(); err != nil; err = p.addNewConn() {
		log.Printf("ERROR: connect to %s failed, sleeping %v: %v", p.host, p.retryBackoff, err)
		time.Sleep(p.retryBackoff)
	}
}

// addNewConn opens one new connection, verifies it with connIsOk and, on success, registers it in allConns
// and makes it available to get. A connection that fails the check is closed and reported as an error.
func (p *clickhouseSimplePool) addNewConn() error {
	fresh, err := ch.OpenDirect(p.url)
	switch {
	case err != nil:
		return err
	case !p.connIsOk(fresh):
		_ = fresh.Close()
		return fmt.Errorf("new connection is not ok (%s)", p.host)
	}

	p.allConnsMu.Lock()
	defer p.allConnsMu.Unlock()

	p.allConns = append(p.allConns, fresh)
	p.available <- fresh
	return nil
}

// connIsOk reports whether conn is usable by running "SELECT 1" inside a transaction and checking that the
// single returned value prints as "1". Every failure is logged together with the pool's host.
func (p *clickhouseSimplePool) connIsOk(conn ch.Clickhouse) bool {
	const query = "SELECT 1"

	if _, err := conn.Begin(); err != nil {
		log.Printf("ERROR: begin failed (%s): %v", p.host, err)
		return false
	}

	stmt, err := conn.Prepare(query)
	if err != nil {
		log.Printf("ERROR: prepare %s failed (%s): %v\n", query, p.host, err)
		return false
	}
	defer func() {
		if closeErr := stmt.Close(); closeErr != nil {
			log.Printf("ERROR: closing statement failed (%s): %v\n", p.host, closeErr)
		}
	}()

	//goland:noinspection GoDeprecation
	rows, err := stmt.Query([]driver.Value{})
	if err != nil {
		log.Printf("ERROR: query %s failed (%s): %v\n", query, p.host, err)
		return false
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: closing rows failed (%s): %v\n", p.host, closeErr)
		}
	}()

	var row [1]driver.Value
	if nextErr := rows.Next(row[:]); nextErr != nil {
		log.Printf("ERROR: query %s failed (%s): %v\n", query, p.host, nextErr)
		return false
	}

	// Don't judge me, this is the easiest way.
	if got := fmt.Sprintf("%v", row[0]); got != "1" {
		log.Printf("ERROR: got %v instead of 1 when checking connection (%s)\n", row[0], p.host)
		return false
	}

	if commitErr := conn.Commit(); commitErr != nil {
		log.Printf("ERROR: commit %s failed (%s): %v\n", query, p.host, commitErr)
		return false
	}
	return true
}

// removeFromAllConns drops conn from the list of tracked connections. The order of the remaining
// connections is not preserved. A message is logged when conn is not tracked.
func (p *clickhouseSimplePool) removeFromAllConns(conn ch.Clickhouse) {
	p.allConnsMu.Lock()
	defer p.allConnsMu.Unlock()

	last := len(p.allConns) - 1
	for i, tracked := range p.allConns {
		if tracked != conn {
			continue
		}
		// Unordered remove: move the last element into the freed slot and shrink the slice by one.
		p.allConns[i] = p.allConns[last]
		p.allConns[last] = nil
		p.allConns = p.allConns[:last]
		return
	}
	log.Printf("ERROR: connection %s not in allCons?", p.url)
}

// numAliveConns returns the number of connections currently tracked by the pool, i.e. connections that were
// established and have not been removed as dead. Connections handed out by get are included.
func (p *clickhouseSimplePool) numAliveConns() int {
	p.allConnsMu.Lock()
	tracked := len(p.allConns)
	p.allConnsMu.Unlock()
	return tracked
}
