package main

import (
	"database/sql"
	"flag"
	"fmt"
	"log"
	"time"

	"code.justin.tv/d8a/pg-stats/statscollection"
	"github.com/cactus/go-statsd-client/statsd"
	_ "github.com/lib/pq"
)

// Command-line configuration. Each variable is bound to a flag in init and
// populated when main calls flag.Parse.
var (
	// Metric destination and the names used in the statsd metric path.
	cluster  string // cluster name, eg db17.sfo01 or whatever
	server   string // statsd server as host:port
	hostname string // host name segment of the metric path
	nodeRole string // master, replica or another label; detected by roleString when empty

	// Polling intervals in seconds; main only starts a poller whose interval is > 0.
	globalInterval int64
	statInterval   int
	lagInterval    int

	// PostgreSQL connection settings.
	host   string // host name or unix-domain socket directory
	port   int
	user   string
	pass   string
	schema string // used as the dbname of the connection

	verbose bool // enables vlogf progress logging
)

func init() {
	flag.StringVar(&cluster, "cluster", "", "cluster name to log under in statsd")
	flag.StringVar(&server, "server", "graphite.internal.justin.tv:8125", "statsd server to send to with host:port")
	flag.Int64Var(&globalInterval, "global-interval", 10, "number of seconds between fetching global stats")
	flag.IntVar(&statInterval, "stat-interval", 10, "number of seconds between fetching stats")
	flag.IntVar(&lagInterval, "lag-interval", 0, "number of seconds between fetching lag")
	flag.StringVar(&host, "host", "/var/run/postgresql", "host to connect to or domain socket location")
	flag.StringVar(&hostname, "hostname", "", "host name to log under in statsd")
	flag.IntVar(&port, "port", 5432, "postgresql port on the host")
	flag.StringVar(&user, "user", "postgres", "database user name")
	flag.StringVar(&pass, "pass", "", "password for connection")
	flag.StringVar(&schema, "schema", "", "database schema for the connection")
	flag.BoolVar(&verbose, "verbose", false, "log progress as we go")
	flag.StringVar(&nodeRole, "role", "", "master, replica or some other label")
}

// main validates the command-line flags, builds the PostgreSQL connection
// string, determines the node role (unless -role was given) and starts one
// polling goroutine per enabled interval. Replication lag is never polled on a
// master. main then blocks forever; it exits via log.Fatal if validation fails
// or if no poller ends up running, rather than blocking with nothing to do.
func main() {
	flag.Parse()
	if cluster == "" {
		log.Fatal("Must specify a cluster. ie: usher")
	}
	if host == "" {
		log.Fatal("Must specify a host. ie: usher-postgres-1fa026")
	}
	if server == "" {
		log.Fatal("Must specify graphite server")
	}
	if schema == "" {
		log.Fatal("Must specify a schema")
	}
	if globalInterval == 0 && statInterval == 0 && lagInterval == 0 {
		log.Fatal("Must specify at least one valid interval to poll")
	}
	if globalInterval < 0 || statInterval < 0 || lagInterval < 0 {
		log.Fatal("Must specify a positive value for all intervals")
	}

	log.Printf("Starting pg-stats on %v to %v", host, server)

	dsn := buildDSN(host, port, user, schema, pass)

	roleString(dsn, &nodeRole)

	started := 0
	if globalInterval > 0 {
		go pollPostgresGlobals(globalInterval, dsn)
		started++
	}
	if statInterval > 0 {
		go pollPostgresStats(statInterval, dsn)
		started++
	}
	if lagInterval > 0 && nodeRole != "master" {
		go pollPostgresLag(lagInterval, dsn)
		started++
	}
	if started == 0 {
		// Reachable only when lag polling was the sole poller requested and
		// this node is a master: select{} below would then wait on nothing.
		log.Fatal("Nothing to poll: replication lag is not polled on a master and all other intervals are 0")
	}

	// The pollers run until the process is killed.
	select {}
}

// buildDSN assembles the lib/pq keyword/value connection string. host, user,
// dbname and pass are quoted with quoteDSNValue so values containing spaces,
// single quotes or backslashes (typically passwords or socket paths) reach the
// driver intact instead of corrupting the string. pass is omitted when empty.
func buildDSN(host string, port int, user, dbname, pass string) string {
	dsn := fmt.Sprintf("sslmode=disable host=%s port=%d user=%s dbname=%s binary_parameters=yes",
		quoteDSNValue(host), port, quoteDSNValue(user), quoteDSNValue(dbname))
	if pass != "" {
		dsn += " password=" + quoteDSNValue(pass)
	}
	return dsn
}

// quoteDSNValue wraps v in single quotes for a libpq keyword/value connection
// string, backslash-escaping embedded single quotes and backslashes as that
// syntax requires. Working byte-wise is safe for UTF-8 input because both
// escaped characters are ASCII and never occur inside a multi-byte sequence.
func quoteDSNValue(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	return string(append(quoted, '\''))
}

// vlogf forwards format and args to log.Printf, but only when the -verbose
// flag is set; otherwise it does nothing.
//
// The parameter is named format (not fmt) so it does not shadow the imported
// fmt package inside the function.
func vlogf(format string, args ...interface{}) {
	if verbose {
		log.Printf(format, args...)
	}
}

// query is the body of one polling pass executed by doQuery. It receives an
// open database handle, a statsd client to report through, and the pass's
// timestamp as Unix seconds, and returns any error that aborted the pass.
type query func(db *sql.DB, stats statsd.Statter, now int64) error

// doQuery runs one polling pass on behalf of purpose (a short label such as
// "globals", "stats" or "lag"). It creates a buffered statsd client for the
// -server address and a database handle for dsn, calls fn with both plus the
// current Unix time in seconds, and closes both before returning.
//
// Returned errors are prefixed with purpose so log lines from the different
// pollers can be told apart. A failure to close (flush) the statsd client is
// reported only when nothing else failed first.
//
// NOTE(review): a fresh statsd client and sql.DB pool are built on every
// call; reusing them across ticks would avoid reconnecting each pass.
func doQuery(dsn string, purpose string, fn query) (err error) {
	vlogf("creating statsd client for %v at %v\n", purpose, server)
	stats, err := statsd.NewBufferedClient(server, "", 10*time.Second, 0)
	if err != nil {
		return fmt.Errorf("%v: creating statsd client: %v", purpose, err)
	}
	defer func() {
		if cerr := stats.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("%v: closing statsd client: %v", purpose, cerr)
		}
	}()

	// Deliberately do not log dsn: it contains the -pass password.
	vlogf("opening database connection for %v\n", purpose)
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return fmt.Errorf("%v: opening database: %v", purpose, err)
	}
	defer db.Close()

	if err := fn(db, stats, time.Now().Unix()); err != nil {
		return fmt.Errorf("%v: %v", purpose, err)
	}
	return nil
}

// roleString sets *nodeRole to "replica" or "master" according to the
// server's pg_is_in_recovery(). If *nodeRole is already non-empty (for
// example set via -role) it is left untouched and no query is made.
// Connection, query and scan failures are all fatal, so an unreadable answer
// is never mistaken for "master".
func roleString(dsn string, nodeRole *string) {
	if *nodeRole != "" {
		return
	}

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// pg_is_in_recovery() without FROM always yields exactly one row.
	var inRecovery bool
	if err := db.QueryRow("select pg_is_in_recovery()").Scan(&inRecovery); err != nil {
		log.Fatal(err)
	}

	if inRecovery {
		*nodeRole = "replica"
	} else {
		*nodeRole = "master"
	}
	vlogf("Determined role to be %v", *nodeRole)
}

// pollPostgresGlobals calls logPostgresGlobals every interval seconds, forever,
// sharing one stats collection across passes. When a pass fails, the error is
// logged and the collection's Reset method is called before the next tick.
func pollPostgresGlobals(interval int64, dsn string) {
	log.Printf("Polling globals on %v second interval\n", interval)
	tracker := statscollection.NewStatsCollection()
	period := time.Duration(interval) * time.Second
	for range time.Tick(period) {
		err := logPostgresGlobals(dsn, tracker)
		if err == nil {
			continue
		}
		log.Println(err)
		tracker.Reset()
	}
}

// logPostgresGlobals performs one database-wide collection pass through
// doQuery. It calls stats.ClearOldHistory(now, globalInterval), records the
// transaction counters and the summed per-table activity, then hands
// everything to stats.FormatStats for emission through statsd.
func logPostgresGlobals(dsn string, stats statscollection.StatsTracker) (err error) {
	// client (not "statsd") so the statsd package name is not shadowed.
	collect := func(db *sql.DB, client statsd.Statter, now int64) error {
		stats.ClearOldHistory(now, globalInterval)
		if qerr := queryTransactions(db, stats); qerr != nil {
			return qerr
		}
		if qerr := queryAggregate(db, stats); qerr != nil {
			return qerr
		}
		return stats.FormatStats(client, now, cluster, nodeRole, hostname, schema, verbose)
	}
	return doQuery(dsn, "globals", collect)
}

// queryTransactions reads xact_commit and xact_rollback for the current
// database from pg_stat_database and records each with stats.RecordDelta
// under the "all" key. Values are scanned as strings because that is what
// RecordDelta is passed.
//
// Returns any query, scan or row-iteration error; a result set cut short by a
// driver/connection error is reported rather than treated as success.
func queryTransactions(db *sql.DB, stats statscollection.StatsTracker) (err error) {
	rows, err := db.Query("select xact_commit, xact_rollback " +
		"from pg_stat_database where datname = current_database() /* pg-stats global transactions */")
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var xactCommit, xactRollback string
		if err := rows.Scan(&xactCommit, &xactRollback); err != nil {
			return err
		}
		stats.RecordDelta("all", "xact_commit", xactCommit)
		stats.RecordDelta("all", "xact_rollback", xactRollback)
	}
	// rows.Next returning false may mean an error, not end of data.
	return rows.Err()
}

// queryAggregate sums per-table activity across pg_stat_user_tables into a
// single row named 'all' and records it with statscollection.RecordLoadRow.
//
// The column list has the same order as the per-table statsQuery, so the same
// RecordLoadRow decoder can parse both. The size column (0) and the two
// last-vacuum columns ('0') are constant placeholders here.
//
// Returns any query, decode or row-iteration error.
func queryAggregate(db *sql.DB, stats statscollection.StatsTracker) (err error) {
	query := "select 'all', 0, " +
		"coalesce(sum(seq_tup_read), 0), coalesce(sum(idx_tup_fetch),0), " +
		"coalesce(sum(n_tup_ins), 0), coalesce(sum(n_tup_upd), 0), coalesce(sum(n_tup_del), 0), " +
		"coalesce(sum(n_tup_hot_upd), 0), coalesce(sum(n_live_tup), 0), coalesce(sum(n_dead_tup), 0), " +
		"'0', '0', " +
		"coalesce(sum(vacuum_count), 0), coalesce(sum(autovacuum_count), 0) " +
		"from pg_stat_user_tables " +
		"/* pg-stats global load aggregate */"
	rows, err := db.Query(query)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		if err = statscollection.RecordLoadRow(rows, stats); err != nil {
			return err
		}
	}
	// rows.Next returning false may mean an error, not end of data.
	return rows.Err()
}

// statsQuery selects per-table activity for the 25 largest non-index
// relations by pg_total_relation_size. The columns are, in order:
//   - name and total size;
//   - tuple read/write counters;
//   - live and dead tuple counts;
//   - last vacuum / autovacuum times as Unix epoch seconds;
//   - vacuum / autovacuum counts.
//
// The original note here says dead and live tuples are reported as gauges and
// everything else as counters. Presumably that happens in
// statscollection.RecordLoadRow; confirm there.
//
// pg_class is joined to pg_stat_user_tables on the table OID (S.relid) rather
// than on relname. relname is only unique within a schema, so a name join
// duplicates rows when two schemas hold same-named tables. It can also match
// non-table relations (sequences, views, ...) that happen to share a table's
// name.
const statsQuery = "select C.relname, pg_total_relation_size(C.oid), " +
	"coalesce(seq_tup_read, 0), coalesce(idx_tup_fetch,0), " +
	"coalesce(n_tup_ins, 0), coalesce(n_tup_upd, 0), coalesce(n_tup_del, 0), " +
	"coalesce(n_tup_hot_upd, 0), coalesce(n_live_tup, 0), coalesce(n_dead_tup, 0), " +
	"coalesce(extract(epoch from last_vacuum)::integer, '0'), " +
	"coalesce(extract(epoch from last_autovacuum)::integer, '0'), " +
	"coalesce(vacuum_count, 0), coalesce(autovacuum_count, 0) " +
	"from pg_class C " +
	"join pg_stat_user_tables S ON C.oid = S.relid " +
	"where C.relkind <> 'i' " +
	"order by pg_total_relation_size(C.oid) desc " +
	"limit 25 /* pg-stats logPostgresStats */"

// pollPostgresStats calls logPostgresStats every interval seconds, forever,
// sharing one stats collection across passes. A failed pass is logged and
// polling simply continues with the next tick.
//
// NOTE(review): unlike pollPostgresGlobals, this loop does not call Reset on
// the collection after an error — confirm that is intentional.
func pollPostgresStats(interval int, dsn string) {
	log.Printf("Polling postgresql stats on %v second interval\n", interval)
	tracker := statscollection.NewStatsCollection()
	period := time.Duration(interval) * time.Second
	for range time.Tick(period) {
		err := logPostgresStats(dsn, tracker)
		if err == nil {
			continue
		}
		log.Println(err)
	}
}

// logPostgresStats performs one per-table collection pass through doQuery. It
// runs statsQuery, feeds every returned row to statscollection.RecordLoadRow,
// and then hands the collection to stats.FormatStats for emission via statsd.
//
// Any query, decode or row-iteration error aborts the pass before anything is
// emitted, so a truncated result set is never reported as complete.
func logPostgresStats(dsn string, stats statscollection.StatsTracker) (err error) {
	return doQuery(dsn, "stats", func(db *sql.DB, client statsd.Statter, now int64) error {
		// Pass the query as an argument, not as the format string, so a '%'
		// added to it later cannot garble the log line.
		vlogf("%s", statsQuery)
		rows, err := db.Query(statsQuery)
		if err != nil {
			return err
		}
		defer rows.Close()

		for rows.Next() {
			if err = statscollection.RecordLoadRow(rows, stats); err != nil {
				return err
			}
		}
		// rows.Next returning false may mean an error, not end of data.
		if err = rows.Err(); err != nil {
			return err
		}
		return stats.FormatStats(client, now, cluster, nodeRole, hostname, schema, verbose)
	})
}

// pollPostgresLag calls logPostgresLag every interval seconds, forever. A
// failed sample is logged and polling carries on with the next tick.
func pollPostgresLag(interval int, dsn string) {
	log.Printf("Polling replication lag on %v second interval\n", interval)
	for range time.Tick(time.Duration(interval) * time.Second) {
		err := logPostgresLag(dsn)
		if err == nil {
			continue
		}
		log.Println(err)
	}
}

// logPostgresLag takes one replication-lag sample through doQuery. It asks
// the server for the seconds elapsed since pg_last_xact_replay_timestamp().
// The value is truncated to whole seconds and sent to statsd as
// postgres.<cluster>.<role>.<hostname>.<schema>.replication_lag.
//
// When the query yields NULL (the server is not in recovery, or the replay
// timestamp is NULL), the sample is logged as not valid and nothing is sent.
// Query, scan and statsd send errors are returned to the caller.
func logPostgresLag(dsn string) (err error) {
	return doQuery(dsn, "lag", func(db *sql.DB, client statsd.Statter, _ int64) error {
		vlogf("querying for rep_lag")
		// A SELECT without FROM always returns exactly one row.
		var replicationLag sql.NullFloat64
		err := db.QueryRow("SELECT (CASE pg_is_in_recovery() WHEN false THEN NULL " +
			"ELSE extract(epoch FROM now() - pg_last_xact_replay_timestamp()) END) AS replication_lag " +
			"/* pg-stats logPostgresLag */").Scan(&replicationLag)
		if err != nil {
			return err
		}
		if !replicationLag.Valid {
			log.Println("replication lag is not valid.")
			return nil
		}
		line := fmt.Sprintf("postgres.%v.%v.%v.%v.replication_lag", cluster, nodeRole, hostname, schema)
		// NOTE(review): lag is a point-in-time value, and statsd counters (Inc)
		// are aggregated per flush, so a Gauge may be the intended type.
		// Left as Inc because changing the metric type likely changes how it is
		// stored downstream — confirm with dashboard owners first.
		return client.Inc(line, int64(replicationLag.Float64), 1.0)
	})
}
