package main

import (
	"flag"
	"fmt"
	"golang.org/x/net/context"
	"log"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/go-logfmt/logfmt"
	_ "github.com/lib/pq"

	"code.justin.tv/chat/db"
)

// repmgrConfig is the path of the repmgr configuration file, filled in from
// the -config command-line flag when flag.Parse runs.
var repmgrConfig string

// init registers the -config flag on the default flag set. The default value
// is the empty string.
func init() {
	const (
		flagName  = "config"
		flagUsage = "-config <config file path>"
	)
	flag.StringVar(&repmgrConfig, flagName, "", flagUsage)
}

// main re-points the local repmgr standby at the cluster's current master.
//
// It runs these steps in order and aborts the process with log.Fatal* on the
// first failure:
//
//  1. Load the repmgr config named by -config and read its cluster, node,
//     conninfo and pg_bindir entries.
//  2. Connect to the local instance, require it to be in recovery, and read
//     its data directory and server version.
//  3. Look up the active master in the cluster's repl_nodes table, connect to
//     it, and require that it is not in recovery and reports the same server
//     version as the local instance.
//  4. Rewrite recovery.conf and restart the local postgres with pg_ctl.
//  5. On the master, mark this node as an active standby whose upstream is
//     the master.
//  6. Poll through the local conninfo until that node record is visible,
//     giving up after ten attempts six seconds apart.
func main() {
	flag.Parse()

	if repmgrConfig == "" {
		log.Fatalln("-config must be specified!")
	}

	repmgrData, err := LoadConfig(repmgrConfig)
	if err != nil {
		log.Fatalln(err)
	}

	clusterName := repmgrData["cluster"]
	localIDStr := repmgrData["node"]
	connStr := repmgrData["conninfo"]
	// An empty pg_bindir makes filepath.Join below yield plain "pg_ctl".
	pgBinDir := repmgrData["pg_bindir"]

	localID, err := strconv.Atoi(localIDStr)
	if err != nil {
		log.Fatalln(err)
	}

	// Named ctx rather than context so the context package is not shadowed.
	ctx := context.Background()

	// Connect to the local instance.
	log.Println("Verifying local replica...")
	localHost, localConn, err := connectFromConnStr(connStr)
	if err != nil {
		log.Fatalln(err)
	}

	// The local instance must be a replica.
	if !getReplica(ctx, "local", localConn) {
		log.Fatalln("nodefollow must be run against a repmgr cluster standby")
	}

	// Read the local data directory; it is needed for recovery.conf and pg_ctl.
	var dataDir string
	row := localConn.QueryRow(ctx, "local-data-dir", "SELECT setting FROM pg_settings WHERE name='data_directory'")
	if err = row.Scan(&dataDir); err != nil {
		log.Fatalln(err)
	}
	if dataDir == "" {
		log.Fatalln("nodefollow could not find the postgres data directory of the local instance")
	}

	// Local postgres version, compared against the master's below.
	replicaVersion := getVersion(ctx, "local", localConn)

	// Find the active master's node ID and conninfo. Only the schema name is
	// formatted into the statement (an identifier cannot be bound); the
	// cluster value is passed as a bind parameter, as in the other repl_nodes
	// statements in this file.
	var masterID int
	var masterConnStr string
	row = localConn.QueryRow(ctx, "master-conn-fetch", fmt.Sprintf("SELECT id, conninfo FROM \"repmgr_%s\".repl_nodes WHERE type='master' AND cluster=$1 AND active", clusterName), clusterName)
	if err = row.Scan(&masterID, &masterConnStr); err != nil {
		log.Fatalln(err)
	}

	// Connect to the master.
	log.Println("Verifying master...")
	masterHost, masterConn, err := connectFromConnStr(masterConnStr)
	if err != nil {
		log.Fatalln(err)
	}

	// The node recorded as master must not itself be in recovery.
	if getReplica(ctx, "master", masterConn) {
		log.Fatalln("nodefollow found a repmgr cluster master that was actually a replica")
	}

	// Master and replica must report the same server version.
	masterVersion := getVersion(ctx, "master", masterConn)
	if masterVersion != replicaVersion {
		log.Fatalf("nodefollow requires the local replica and master to be on the same version of postgres- master was on %s, replica was on %s\n", masterVersion, replicaVersion)
	}

	// Close the local connection before postgres is restarted. A close error
	// is only logged because the connection is no longer needed.
	log.Println("Closing local connection...")
	if err = localConn.Close(); err != nil {
		log.Println(err)
	}

	// Update recovery.conf.
	log.Println("Updating recovery.conf file...")
	if err = writeRecoveryConf(dataDir, masterHost, localHost); err != nil {
		log.Fatalln(err)
	}

	// Restart postgres with pg_ctl: -w waits for the operation to finish and
	// -m fast selects the fast shutdown mode.
	// NOTE(review): pg_ctl's output is discarded here. Capturing it through
	// pipes may block if the restarted server inherits them - confirm before
	// changing this to CombinedOutput.
	log.Println("Restarting local postgres...")
	pgCtl := filepath.Join(pgBinDir, "pg_ctl")
	cmd := exec.Command(pgCtl, "-w", "-D", dataDir, "-m", "fast", "restart")
	if err = cmd.Run(); err != nil {
		log.Fatalln(err)
	}

	// On the master, record this node as an active standby whose upstream is
	// the master.
	log.Println("Updating repmgr node data...")
	res, err := masterConn.Exec(ctx, "update-local-node", fmt.Sprintf("UPDATE \"repmgr_%s\".repl_nodes SET type=$1, upstream_node_id=$2, active=$3 WHERE cluster=$4 AND id=$5", clusterName), "standby", masterID, true, clusterName, localID)
	if err != nil {
		log.Fatalln(err)
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		log.Fatalln(err)
	}
	if rowsAffected != 1 {
		log.Fatalf("Error updating node: Was hoping to update 1 row in the nodes table, but instead updated %d\n", rowsAffected)
	}

	// The master connection is not used past this point; release it the same
	// way the local connection was released.
	if err = masterConn.Close(); err != nil {
		log.Println(err)
	}

	// Wait up to 60s for the replica to come back up and start replicating.
	const (
		pollAttempts = 10
		pollInterval = 6 * time.Second
	)
	log.Println("Waiting for replica to come up...")
	for i := 0; i < pollAttempts; i++ {
		if testConnection(ctx, clusterName, localID, masterID, connStr) {
			log.Println("nodefollow successful.")
			return
		}
		time.Sleep(pollInterval)
	}

	log.Fatalln("Replica didn't come back up!  Something is seriously wrong.")
}

// testConnection reports whether the server reachable through connStr shows
// node localNode of cluster clusterName as an active standby whose upstream
// node is upstreamNode.
//
// A new connection is opened on every call and closed before returning. A
// failure to connect or to read the node row is logged and reported as false
// rather than treated as fatal, so the caller can keep polling while the
// server is still coming up.
func testConnection(context context.Context, clusterName string, localNode int, upstreamNode int, connStr string) bool {
	// Named conn rather than db so the db package is not shadowed.
	_, conn, err := connectFromConnStr(connStr)
	if err != nil {
		log.Printf("nodefollow: connecting to replica: %v", err)
		return false
	}
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	}()

	// Only the schema name is formatted into the statement; the values are
	// passed as bind parameters.
	query := fmt.Sprintf("SELECT type, upstream_node_id, active FROM \"repmgr_%s\".repl_nodes WHERE cluster=$1 AND id=$2", clusterName)

	var (
		nodeType   string
		upstreamID int
		active     bool
	)
	row := conn.QueryRow(context, "test-local-conn", query, clusterName, localNode)
	if err = row.Scan(&nodeType, &upstreamID, &active); err != nil {
		log.Printf("nodefollow: reading replica node record: %v", err)
		return false
	}

	return nodeType == "standby" && upstreamID == upstreamNode && active
}

// connectFromConnStr parses connStr as a single logfmt-style record of
// key=value pairs and opens a database handle from it.
//
// Only the host, port, dbname and user keys are used; every other key is
// ignored. The handle is always opened with the "postgres" driver and fixed
// pool and timeout settings.
//
// It returns the value of the host key (empty if the key is absent) together
// with the opened handle. On any failure (a scanner error, a malformed
// key/value pair, a non-numeric port, or an error from db.Open) it returns
// "", nil and the error.
func connectFromConnStr(connStr string) (outputHost string, outputConn db.DB, err error) {
	decoder := logfmt.NewDecoder(strings.NewReader(connStr))

	// The boolean result is not used: an empty connection string is not
	// treated as an error here, only a scanner failure is.
	decoder.ScanRecord()
	if scanErr := decoder.Err(); scanErr != nil {
		return "", nil, scanErr
	}

	dbOptions := []db.Option{
		db.DriverName("postgres"),
		db.MaxOpenConns(5),
		db.MaxIdleConns(5),
		db.RequestTimeout(5 * time.Minute),
		db.MaxConnAge(5 * time.Minute),
		db.ConnAcquireTimeout(1 * time.Second),
	}

	host := ""
	for decoder.ScanKeyval() {
		key := string(decoder.Key())
		value := string(decoder.Value())

		switch key {
		case "host":
			host = value
			dbOptions = append(dbOptions, db.Host(value))
		case "port":
			port, convErr := strconv.Atoi(value)
			if convErr != nil {
				return "", nil, fmt.Errorf("invalid port %q in connection string: %v", value, convErr)
			}
			dbOptions = append(dbOptions, db.Port(port))
		case "dbname":
			dbOptions = append(dbOptions, db.DBName(value))
		case "user":
			dbOptions = append(dbOptions, db.User(value))
		}
	}

	// ScanKeyval returns false both at the end of the record and when it hits
	// a syntax error, so the error has to be checked after the loop. Checking
	// it inside the loop body can never observe a failure.
	if parseErr := decoder.Err(); parseErr != nil {
		return "", nil, parseErr
	}

	conn, openErr := db.Open(dbOptions...)
	if openErr != nil {
		return "", nil, openErr
	}
	return host, conn, nil
}

// getReplica runs pg_is_in_recovery() on conn and returns its result; callers
// in this file treat true as "this server is a replica".
//
// tag is prepended to "-role-check" to form the query name passed to
// QueryRow. A failure to read the result terminates the process via
// log.Fatalln.
func getReplica(context context.Context, tag string, conn db.DB) bool {
	queryName := tag + "-role-check"

	var inRecovery bool
	if scanErr := conn.QueryRow(context, queryName, "SELECT * FROM pg_is_in_recovery();").Scan(&inRecovery); scanErr != nil {
		log.Fatalln(scanErr)
	}
	return inRecovery
}

// getVersion returns the server_version setting reported by the server behind
// conn.
//
// tag is prepended to "-version-check" to form the query name passed to
// QueryRow. A failure to read the result terminates the process via
// log.Fatalln.
func getVersion(context context.Context, tag string, conn db.DB) string {
	queryName := tag + "-version-check"

	var serverVersion string
	if scanErr := conn.QueryRow(context, queryName, "SELECT current_setting('server_version');").Scan(&serverVersion); scanErr != nil {
		log.Fatalln(scanErr)
	}
	return serverVersion
}
