package main

import (
	"bufio"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"encoding/json"
	"fmt"
	"github.com/go-sql-driver/mysql"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"strings"
	"time"
)

// ==========================================================================================

// ConfigStruct is the schema of the JSON configuration file read by
// ReadConfig. Field names double as JSON keys (there are no struct tags), and
// json.Marshal emits them in the order declared here.
type ConfigStruct struct {
	// Reals lists the MySQL backends. Each entry is used as a TCP address
	// (host:port) both for health checks and for proxying client connections.
	Reals []string

	// Host is the listen host shared by the proxy (Port) and the state
	// reporter (ReportPort).
	Host             string
	Port, ReportPort uint

	// RootCAPath is the PEM file with the CA certificates trusted for MySQL TLS.
	RootCAPath string

	// Credentials come from PasswordFile ("user: password" lines, looked up by
	// User) or from GrafanaIni; main rejects setting both. User and Database,
	// when non-empty, override the values taken from GrafanaIni.
	PasswordFile   string
	User, Database string
	GrafanaIni     string

	// MasterCheckPeriod and ConnectTimeout are whole seconds.
	MasterCheckPeriod, ConnectTimeout uint

	// Verbose enables per-connection proxy logging.
	Verbose bool
}

// DualChan links MasterCheckRoutine with one MySQLCheckRoutine.
// Master carries requests ("CHECK", or "QUIT" to stop the checker) and Slave
// carries the replies ("RW", "RO", "ERR_NET", "ERR_VAR").
type DualChan struct {
	Master, Slave chan string
}

// NewDualChan returns a DualChan whose two channels are allocated and
// unbuffered, so every request and reply is a synchronous hand-off.
func NewDualChan() *DualChan {
	d := new(DualChan)
	d.Master = make(chan string)
	d.Slave = make(chan string)
	return d
}

// ==========================================================================================

// ReadConfig loads the JSON configuration file at path into a ConfigStruct.
//
// Keys are matched to ConfigStruct fields by encoding/json; keys absent from
// the file leave the field at its zero value. The process exits via
// log.Fatalf if the file cannot be read or does not decode into ConfigStruct.
func ReadConfig(path string) ConfigStruct {
	var config ConfigStruct

	data, err := ioutil.ReadFile(path)
	if err != nil {
		log.Fatalf("Bad config file path %s: %s", path, err.Error())
	}
	// Report the decoder error too, so the offending offset or type mismatch
	// is visible instead of a bare "bad json".
	if err := json.Unmarshal(data, &config); err != nil {
		log.Fatalf("Bad json in config file %s: %s", path, err.Error())
	}
	return config
}

// ReadGrafanaIni parses a Grafana-style ini file into a flat map keyed by
// "section.key". Keys that appear before any section header get the prefix
// "." (empty section name).
//
// Everything from the first '#' or ';' on a line is a comment, except inside
// a value wrapped in triple quotes (key = """value"""), which Grafana uses for
// values such as passwords that contain those characters; the triple quotes
// are stripped. Section names, keys and unquoted values are whitespace-trimmed.
// The process exits via log.Fatalf if the file cannot be opened or read.
func ReadGrafanaIni(iniFile string) map[string]string {
	var section string
	iniContent := make(map[string]string)

	file, err := os.Open(iniFile)
	if err != nil {
		log.Fatalf("Bad grafana password file path %s: %s", iniFile, err.Error())
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()

		// Triple-quoted values must be recognized before comment stripping,
		// otherwise a '#' or ';' inside them would truncate the value.
		if key, value, ok := grafanaIniQuotedEntry(line); ok {
			iniContent[fmt.Sprintf("%s.%s", section, key)] = value
			continue
		}

		if commentIdx := strings.IndexAny(line, "#;"); commentIdx >= 0 {
			line = line[:commentIdx]
		}
		line = strings.TrimSpace(line)

		if strings.HasPrefix(line, "[") {
			// Trim rather than slice: a bare "[" must not panic and an
			// unterminated "[name" must not lose its last character.
			section = strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(line, "["), "]"))
			continue
		}

		fields := strings.SplitN(line, "=", 2)
		if len(fields) == 2 {
			iniContent[fmt.Sprintf("%s.%s", section, strings.TrimSpace(fields[0]))] = strings.TrimSpace(fields[1])
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatalf("Failed to read grafana ini file %s: %s", iniFile, err.Error())
	}
	return iniContent
}

// grafanaIniQuotedEntry recognizes a `key = """value"""` line and returns the
// trimmed key and the raw text between the triple quotes. ok is false for any
// other line, including commented-out ones and unterminated quotes.
func grafanaIniQuotedEntry(line string) (key string, value string, ok bool) {
	const tripleQuote = `"""`

	eq := strings.Index(line, "=")
	if eq < 0 {
		return "", "", false
	}
	key = strings.TrimSpace(line[:eq])
	// A comment character or '[' before '=' means a comment or header line.
	if key == "" || strings.ContainsAny(key, "#;[") {
		return "", "", false
	}
	rest := strings.TrimSpace(line[eq+1:])
	if !strings.HasPrefix(rest, tripleQuote) {
		return "", "", false
	}
	rest = rest[len(tripleQuote):]
	end := strings.Index(rest, tripleQuote)
	if end < 0 {
		return "", "", false
	}
	return key, rest[:end], true
}

// ReadPassword returns the password stored for user in passwordFile.
//
// Each line has the form "user: password". Everything from the first '#' or
// ';' onward is a comment, so passwords cannot contain those characters.
// Whitespace around both fields is ignored and the first matching line wins.
// It returns "" when no line matches. The process exits via log.Fatalf if the
// file cannot be opened or read (a read error is no longer mistaken for
// "user not found").
func ReadPassword(passwordFile string, user string) string {
	file, err := os.Open(passwordFile)
	if err != nil {
		log.Fatalf("Bad password file path %s: %s", passwordFile, err.Error())
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		if commentIdx := strings.IndexAny(line, "#;"); commentIdx >= 0 {
			line = line[:commentIdx]
		}
		sepIdx := strings.Index(line, ":")
		if sepIdx < 0 {
			continue
		}
		if strings.TrimSpace(line[:sepIdx]) == user {
			return strings.TrimSpace(line[sepIdx+1:])
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatalf("Failed to read password file %s: %s", passwordFile, err.Error())
	}
	return ""
}

// ==========================================================================================

// MySQLRegisterCustomTLS registers a TLS configuration named "custom" with the
// MySQL driver. The configuration trusts only the CA certificates in the PEM
// file at rootCAPath (RootCAs is non-nil, so the host's root set is not used).
// Connections opened with "tls=custom" in their DSN use it.
//
// The process exits via log.Fatalf if the file cannot be read, contains no
// usable certificate, or the driver rejects the registration.
func MySQLRegisterCustomTLS(rootCAPath string) {
	rootCertPool := x509.NewCertPool()
	pem, err := ioutil.ReadFile(rootCAPath)
	if err != nil {
		log.Fatalf("Failed to read root CA file %s: %s", rootCAPath, err.Error())
	}
	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
		log.Fatalln("Failed to append PEM to root CA pool.")
	}

	if err := mysql.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCertPool}); err != nil {
		log.Fatalf("Failed to register TLS config for root CA %s: %s", rootCAPath, err.Error())
	}
}

// MySQL returns a *sql.DB handle for database dbname on host (a TCP address)
// using the "mysql" driver.
//
// The DSN sets 3s read/write timeouts, uses timeout as the dial timeout and
// selects the TLS configuration registered as "custom" (see
// MySQLRegisterCustomTLS). timeout is passed in Go duration syntax so
// sub-second values are not truncated to "0s". sql.Open does not connect; the
// first query or Ping does.
//
// sql.Open returns a nil *DB on error, so the process exits via log.Fatalf
// rather than returning a handle that would panic on first use.
func MySQL(host string,
	user string,
	password string,
	dbname string,
	timeout time.Duration) *sql.DB {

	dataSourceName := fmt.Sprintf("%s:%s@tcp(%s)/%s?readTimeout=3s&writeTimeout=3s&timeout=%s&tls=custom",
		user,
		password,
		host,
		dbname,
		timeout)

	db, err := sql.Open("mysql", dataSourceName)
	if err != nil {
		log.Fatalf("Failed to open DB %s@%s/%s: %s", user, host, dbname, err.Error())
	}
	// NOTE(review): the connect timeout doubles as the max connection lifetime,
	// so pooled connections are recycled that often; confirm this is intended.
	db.SetConnMaxLifetime(timeout)

	return db
}

// MySQLCheckRoutine is the health checker for one real server.
//
// It opens a DB handle for realServer and then, for every request read from
// channel.Master, sends exactly one reply on channel.Slave:
//   - "ERR_NET" if the server cannot be pinged,
//   - "ERR_VAR" if @@GLOBAL.read_only cannot be read (query error, no row,
//     or a value that does not scan into a bool),
//   - "RO" if read_only is set, "RW" otherwise.
//
// It returns, closing the handle, when it receives "QUIT".
func MySQLCheckRoutine(realServer,
	user string,
	password string,
	database string,
	channel *DualChan,
	timeout time.Duration) {

	db := MySQL(realServer, user, password, database, timeout)
	defer db.Close()

	for {
		signal := <-channel.Master
		if signal == "QUIT" {
			break
		}
		if err := db.Ping(); err != nil {
			log.Printf("Failed to ping MySQL server at %s: %s\n", realServer, err.Error())
			channel.Slave <- "ERR_NET"
			continue
		}

		// QueryRow reports query errors, an empty result (sql.ErrNoRows) and
		// scan failures through Scan, so an unreadable value can never be
		// mistaken for a writable ("RW") server.
		var ro bool
		if err := db.QueryRow("SELECT @@GLOBAL.read_only").Scan(&ro); err != nil {
			log.Printf("Failed to get variables from MySQL server at %s: %s\n", realServer, err.Error())
			channel.Slave <- "ERR_VAR"
			continue
		}
		if ro {
			channel.Slave <- "RO"
		} else {
			channel.Slave <- "RW"
		}
	}
}

// MasterCheckRoutine periodically decides which real server should be master.
//
// Every period it sends "CHECK" to each real's checker, collects the replies,
// logs a "[server]:state " summary and sends it on realsStateChannel. A "RW"
// server is preferred as candidate; otherwise the first "RO" one seen (map
// iteration order, so not deterministic). The master is replaced with the
// candidate:
//   - immediately, when there is no master yet, or when the master is not
//     "RW" while the candidate is "RW";
//   - after maxDiffCount consecutive rounds in which the master reports an
//     "ERR*" state and the candidate does not.
//
// Each new master is sent on masterChannel. verbose is currently unused.
// It never returns.
func MasterCheckRoutine(realChannels map[string]*DualChan,
	masterChannel chan string,
	realsStateChannel chan string,
	period time.Duration,
	verbose bool) {

	const maxDiffCount uint = 3
	var (
		master    string
		diffCount uint
	)
	realsState := make(map[string]string)

	for {
		// Fan the request out first so all servers are probed concurrently.
		for _, ch := range realChannels {
			ch.Master <- "CHECK"
		}

		candidate := ""
		var summary strings.Builder
		for realServer, ch := range realChannels {
			state := <-ch.Slave
			realsState[realServer] = state
			switch {
			case state == "RW":
				candidate = realServer
			case state == "RO" && candidate == "":
				candidate = realServer
			}
			fmt.Fprintf(&summary, "[%s]:%s ", realServer, state)
		}

		stateSummary := summary.String()
		log.Printf("Reals state: %s", stateSummary)
		realsStateChannel <- stateSummary

		if candidate == "" {
			log.Println("No online servers available")
			time.Sleep(period)
			continue
		}

		masterState := realsState[master]
		candidateState := realsState[candidate]
		switch {
		case master == "" || (masterState != "RW" && candidateState == "RW"):
			diffCount += maxDiffCount
		case strings.HasPrefix(masterState, "ERR") && !strings.HasPrefix(candidateState, "ERR"):
			diffCount++
		default:
			diffCount = 0
		}

		if diffCount >= maxDiffCount {
			master = candidate
			log.Printf("Sending new master: %s", master)
			masterChannel <- master
			diffCount = 0
		}
		time.Sleep(period)
	}
}

// StateReporter serves the latest reals-state summary over TCP.
//
// It listens on host:port. For each accepted connection it writes the most
// recent string received on realsStateChannel (nothing until the first one
// arrives) and then closes the connection. The process exits via log.Fatalf if
// the listener cannot be created. It never returns.
func StateReporter(realsStateChannel chan string,
	host string,
	port uint) {

	addr := net.JoinHostPort(host, fmt.Sprint(port))
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("Failed to start listening at %s (state reporter): %s", addr, err.Error())
	}

	// Accept in a separate goroutine so a pending Accept never delays
	// receiving state updates.
	accepted := make(chan net.Conn)
	go func() {
		for {
			conn, acceptErr := listener.Accept()
			if acceptErr != nil {
				log.Printf("Failed to accept connection at %s: %s", addr, acceptErr.Error())
				continue
			}
			accepted <- conn
		}
	}()

	var latest []byte
	for {
		select {
		case conn := <-accepted:
			if _, werr := conn.Write(latest); werr != nil {
				log.Printf("Failed to write state: %s", werr.Error())
			}
			if cerr := conn.Close(); cerr != nil {
				log.Printf("Failed to close state reporter connection: %s", cerr.Error())
			}
		case summary := <-realsStateChannel:
			latest = []byte(summary)
		}
	}
}

// ==========================================================================================

// ProxyConn is one client connection being proxied to a master server.
//
// NewProxyConn starts a goroutine that dials Master and copies bytes in both
// directions until either direction finishes or Stop is called. It then
// closes both connections and sends the ProxyConn on finChannel.
//
// Dst and DstAddr are written by that goroutine after the dial succeeds. Read
// them from another goroutine only after the ProxyConn has been received from
// finChannel; reading them earlier is a data race.
type ProxyConn struct {
	Src         net.Conn
	Dst         net.Conn
	SrcAddr     string
	DstAddr     string
	Master      string
	finChannel  chan<- *ProxyConn
	dialTimeout time.Duration
	verbose     bool
	// stopChannel is closed by Stop. It is assigned once, before the proxy
	// goroutine starts, and never reassigned, so that goroutine can read it
	// without synchronization.
	stopChannel chan bool
	// stopped records that Stop has already closed stopChannel.
	stopped   bool
	startTime time.Time
}

// NewProxyConn starts proxying src to master in a new goroutine and returns
// immediately. timeout bounds the dial to master. When proxying ends, the
// ProxyConn is sent on finChannel, so the receiver must keep draining it.
func NewProxyConn(src net.Conn,
	master string,
	finChannel chan<- *ProxyConn,
	timeout time.Duration,
	verbose bool) *ProxyConn {

	p := &ProxyConn{
		Src:         src,
		SrcAddr:     src.RemoteAddr().String(),
		Master:      master,
		finChannel:  finChannel,
		dialTimeout: timeout,
		verbose:     verbose,
		stopChannel: make(chan bool),
		startTime:   time.Now(),
	}
	go p.proxyRoutine()
	return p
}

// Stop asks the proxy goroutine to close both connections. Further calls are
// no-ops. Stop is not safe for concurrent use: call it from one goroutine.
func (p *ProxyConn) Stop() {
	// stopChannel is nil only for a ProxyConn not built by NewProxyConn.
	if p.stopped || p.stopChannel == nil {
		return
	}
	p.stopped = true
	close(p.stopChannel)
}

// oneWay copies src to dst until EOF or error and reports the byte count on
// fin. Errors are logged only in verbose mode.
func (p *ProxyConn) oneWay(src net.Conn,
	dst net.Conn,
	fin chan<- int64) {

	n, err := io.Copy(dst, src)
	if p.verbose && err != nil {
		log.Printf("Connection error: %s", err.Error())
	}
	fin <- n
}

// proxyRoutine dials the master, runs both copy directions and tears
// everything down once one direction ends or Stop is called. It always
// finishes by sending p on finChannel.
func (p *ProxyConn) proxyRoutine() {
	// Read the field once. Stop never reassigns it, so this cannot race.
	stop := p.stopChannel

	dst, err := net.DialTimeout("tcp", p.Master, p.dialTimeout)
	if err != nil {
		log.Printf("Failed to connect to master %s (proxying connection for %s, %v): %s", p.Master, p.SrcAddr, time.Since(p.startTime), err.Error())
		// Best effort: the client connection is abandoned either way.
		_ = p.Src.Close()
		p.finChannel <- p
		return
	}
	p.Dst = dst
	p.DstAddr = dst.RemoteAddr().String()
	if p.verbose {
		log.Printf("New proxy connection %s->%s(%s)", p.SrcAddr, p.Master, p.DstAddr)
	}

	// closeBoth unblocks both copy goroutines. Close errors are ignored on
	// purpose: one side may already be closed by the peer or by us.
	closeBoth := func() {
		_ = p.Dst.Close()
		_ = p.Src.Close()
	}

	// transferred[0] is the byte count of whichever direction finished
	// first, transferred[1] that of the other one.
	var transferred [2]int64
	fin := make(chan int64, 2)
	go p.oneWay(p.Dst, p.Src, fin)
	go p.oneWay(p.Src, p.Dst, fin)

	// Wait for exactly one event. Unlike a loop over a closed stop channel,
	// this cannot spin after Stop.
	select {
	case transferred[0] = <-fin:
		closeBoth()
	case <-stop:
		closeBoth()
		transferred[0] = <-fin
	}
	transferred[1] = <-fin

	if p.verbose {
		log.Printf("Finished proxy connection %s->%s(%s), %d<->%d bytes (%v)", p.SrcAddr, p.Master, p.DstAddr, transferred[0], transferred[1], time.Since(p.startTime))
	}
	p.finChannel <- p
}

// ==========================================================================================

// proxyConnector owns the set of live ProxyConns.
//
// It blocks until the first master arrives on masterChannel. It then:
//   - proxies every connection from connChannel to the current master;
//   - calls Stop on connections whose Master differs from a newly announced
//     master;
//   - forgets connections once they report on the internal done channel;
//   - logs the number of live connections once a minute.
//
// timeout is the dial timeout for each proxied connection. It never returns,
// so the time.Tick ticker is never leaked.
//
// NOTE(review): the verbose "Stopping" log reads p.DstAddr while the proxy
// goroutine may still be writing it; this looks like a data race to confirm
// (run with -race).
func proxyConnector(masterChannel chan string,
	connChannel chan net.Conn,
	timeout time.Duration,
	verbose bool) {

	ticker := time.Tick(1 * time.Minute)
	active := make(map[*ProxyConn]struct{})
	// Buffered so finishing proxies rarely block while this loop is busy.
	done := make(chan *ProxyConn, 100)

	current := <-masterChannel
	for {
		select {
		case newMaster := <-masterChannel:
			current = newMaster
			log.Printf("Master changed to %s", current)
			for p := range active {
				if p.Master == current {
					continue
				}
				if verbose {
					log.Printf("Stopping: %s->%s(%s)", p.SrcAddr, p.Master, p.DstAddr)
				}
				p.Stop()
			}
		case finished := <-done:
			if verbose {
				log.Printf("Done: %s->%s(%s)", finished.SrcAddr, finished.Master, finished.DstAddr)
			}
			delete(active, finished)
		case conn := <-connChannel:
			active[NewProxyConn(conn, current, done, timeout, verbose)] = struct{}{}
		case <-ticker:
			log.Printf("Current number of proxy connections: %d", len(active))
		}
	}
}

// ==========================================================================================

// main loads the JSON configuration named on the command line, resolves the
// MySQL credentials and registers the TLS config. It then starts one
// MySQLCheckRoutine per real server plus MasterCheckRoutine, StateReporter
// and proxyConnector, and accepts client connections on Host:Port forever,
// handing each to proxyConnector.
func main() {
	var user string
	var database string
	var password string

	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
	if len(os.Args) != 2 {
		log.Fatalf("Usage: %s <config_file>", os.Args[0])
	}
	configFileName := os.Args[1]

	config := ReadConfig(configFileName)
	// ConfigStruct holds only strings, string slices, uints and bools, so
	// Marshal cannot fail.
	cfg, _ := json.Marshal(config)
	log.Printf("Using config: %s", string(cfg))

	// Reject configs that cannot work. Without reals no master is ever
	// chosen. With a zero period, MasterCheckRoutine would poll every server
	// back-to-back without sleeping.
	if len(config.Reals) == 0 {
		log.Fatalf("No Reals configured in config file %s", configFileName)
	}
	if config.MasterCheckPeriod == 0 {
		log.Fatalf("MasterCheckPeriod must be greater than 0 in config file %s", configFileName)
	}

	checkPeriod := time.Duration(config.MasterCheckPeriod) * time.Second
	connectTimeout := time.Duration(config.ConnectTimeout) * time.Second

	// Auth block: GrafanaIni supplies defaults; User/Database/PasswordFile
	// from the config override them.
	if len(config.GrafanaIni) > 0 {
		if len(config.PasswordFile) > 0 {
			log.Fatalf("Cannot use both GrafanaIni and PasswordFile in config file")
		}
		grafanaIni := ReadGrafanaIni(config.GrafanaIni)
		user = grafanaIni["database.user"]
		database = grafanaIni["database.name"]
		password = grafanaIni["database.password"]
	}
	if len(config.User) > 0 {
		user = config.User
		if len(config.PasswordFile) > 0 {
			password = ReadPassword(config.PasswordFile, user)
		}
	}
	if len(config.Database) > 0 {
		database = config.Database
	}
	if len(password) == 0 || len(user) == 0 || len(database) == 0 {
		log.Fatalf("Failed to set password, user, database from configuration file")
	}
	MySQLRegisterCustomTLS(config.RootCAPath)

	serverHostPort := net.JoinHostPort(config.Host, fmt.Sprint(config.Port))
	tcpConnServer, err := net.Listen("tcp", serverHostPort)
	if err != nil {
		log.Fatalf("Failed to start listening at %s: %s", serverHostPort, err.Error())
	}

	realChannels := make(map[string]*DualChan)
	for _, r := range config.Reals {
		channel := NewDualChan()
		realChannels[r] = channel
		go MySQLCheckRoutine(r, user, password, database, channel, connectTimeout)
	}
	realsStateChannel := make(chan string)
	masterChannel := make(chan string)
	go MasterCheckRoutine(realChannels, masterChannel, realsStateChannel, checkPeriod, config.Verbose)
	go StateReporter(realsStateChannel, config.Host, config.ReportPort)

	connChannel := make(chan net.Conn)
	go proxyConnector(masterChannel, connChannel, connectTimeout, config.Verbose)

	// Back off on consecutive Accept failures (e.g. file-descriptor
	// exhaustion) instead of spinning and flooding the log.
	const maxAcceptDelay = time.Second
	var acceptDelay time.Duration
	for {
		tcpConn, err := tcpConnServer.Accept()
		if err != nil {
			log.Printf("Failed to accept connection at %s: %s", serverHostPort, err.Error())
			if acceptDelay == 0 {
				acceptDelay = 5 * time.Millisecond
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0
		connChannel <- tcpConn
	}
}
