package main

import (
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"time"

	"github.com/cactus/go-statsd-client/statsd"
	"github.com/gorilla/mux"

	_ "expvar"
	_ "net/http/pprof"

	"code.justin.tv/d8a/pg-healthcheck/common"
	"code.justin.tv/d8a/pg-healthcheck/forcedstatus"
	//"code.justin.tv/d8a/pg-healthcheck/healthtiming"
	healthcheckstats "code.justin.tv/d8a/pg-healthcheck/stats"
)

// intslice collects the integer values of a flag that may be given several
// times on the command line. Its pointer type satisfies flag.Value, so it can
// be registered with flag.Var().
type intslice []int

// String renders the collected values in fmt's slice notation, for example
// "[5432 5433]". The flag package may call String on a zero-valued receiver,
// including a nil pointer, so nil is rendered like an empty slice instead of
// being dereferenced.
func (p *intslice) String() string {
	if p == nil {
		return "[]"
	}
	return fmt.Sprintf("%d", *p)
}

// Set parses value as a base-10 integer and appends it to the slice. A value
// that does not parse is reported through the returned error and leaves the
// slice untouched, so the flag package applies its own error handling (usage
// message and exit policy) rather than the process being killed from inside
// the parser.
func (p *intslice) Set(value string) error {
	port, err := strconv.Atoi(value)
	if err != nil {
		return fmt.Errorf("unable to parse port: %v", err)
	}
	*p = append(*p, port)
	return nil
}

// Command-line configuration. Each variable below is bound to a flag in
// init() and is read after flag.Parse() has run in main().

// Scheduling of the checks and the local HTTP endpoint.
var (
	// -interval: number of seconds between status checks.
	interval_s int
	// -local-port: listen port of the HTTP healthcheck server.
	local_port int
	// -stale-timeout: age in seconds after which the last recorded status is
	// treated as stale; queryPort skips the staleness test unless this is > 0.
	stale_s float64
)

// Labels passed on to the statistics writer.
var (
	// -cluster: cluster name to log under in statsd (required by main).
	cluster string
	// -role: free-form label such as master or replica.
	role string
)

// Connection settings for the database being checked.
var (
	// -host: host to connect to.
	host string
	// -port: ports on the host; the flag may be repeated, one port each time.
	ports intslice
	// -user and -pass: credentials for the connection.
	user, pass string
	// -database: database name for the connection (required by main).
	database string
)

// Settings for the timing-based health evaluation and metrics endpoints.
var (
	// -absolute-threshold: time in ms a check may take before the node is
	// eligible to be marked unhealthy.
	absoluteThreshold int
	// -relative-threshold: factor of the average check time used together
	// with the absolute threshold.
	relativeThreshold float64

	// -statsd-host-port: statsd address that health statistics are sent to.
	statsdHostPort string
	// -graphite-url: Graphite URL to query for average check timings.
	graphiteURL string
)

// init binds every command-line flag to its package-level variable and sets
// its default. Parsing is left to main.
func init() {
	// Numeric settings.
	flag.IntVar(&interval_s, "interval", 5, "number of seconds between checking status.")
	flag.IntVar(&local_port, "local-port", 6545, "listen port of the http healthcheck server.")
	flag.IntVar(&absoluteThreshold, "absolute-threshold", 50, "Time in ms the health check can take before the node is eligible to be marked unhealthy")
	flag.Float64Var(&stale_s, "stale-timeout", 9.5, "seconds since check has run to consider last status stale.")
	flag.Float64Var(&relativeThreshold, "relative-threshold", 2.0, "Factor of the average health check time before the node is marked unhealthy, if it is above the absolute threshold")

	// The string settings all have the same shape, so they are registered
	// from a table.
	stringFlags := []struct {
		target *string
		name   string
		value  string
		usage  string
	}{
		{&host, "host", "localhost", "host to connect to and check status."},
		{&user, "user", "", "database user name."},
		{&pass, "pass", "", "password for connection."},
		{&database, "database", "", "database name for the connection."},
		{&statsdHostPort, "statsd-host-port", "graphite.internal.justin.tv:8125", "statsd host to log health statistics to"},
		{&cluster, "cluster", "", "cluster name to log under in statsd"},
		{&role, "role", "", "master, replica, or some other label"},
		{&graphiteURL, "graphite-url", "graphite.internal.justin.tv", "Graphite URL to query for average health check timing statistics"},
	}
	for _, f := range stringFlags {
		flag.StringVar(f.target, f.name, f.value, f.usage)
	}

	// -port is repeatable; every occurrence adds one entry to ports.
	flag.Var(&ports, "port", "port on the host. one port can be specified per flag.")
}

// main parses the flags, validates the required settings, starts the HTTP
// status server and launches one polling goroutine per configured port. It
// then blocks forever; the process ends only through log.Fatal or an external
// signal.
func main() {
	flag.Parse()

	if cluster == "" {
		log.Fatal("Must specify a cluster name.")
	}
	if database == "" {
		log.Fatal("Must specify a database name.")
	}
	if len(ports) == 0 {
		log.Fatal("Must specify at least one port to check.")
	}
	// pollHealth drives its loop from time.Tick, which returns a nil channel
	// for a non-positive duration. The pollers would then block forever
	// without running a single check, so reject such an interval up front.
	if interval_s <= 0 {
		log.Fatal("Must specify a positive interval.")
	}

	log.Printf("Starting pg-healthcheck: %v second interval on %v %v.", interval_s, host, ports)

	// The HTTP server runs in its own goroutine and serves whatever the
	// pollers have recorded in the tracker so far.
	st := NewStatusTracker(ports)
	go listenAndServe(st)

	reasonService := forcedstatus.NewFileReasonService()
	statusWatcher, err := forcedstatus.NewStatusWatcher(ports, reasonService)
	if err != nil {
		log.Fatalln(err)
	}
	// NOTE: main does not return (see the select at the end) and log.Fatal
	// exits without running deferred calls, so this close only takes effect
	// if main ever gains a normal exit path.
	defer common.CloseResource(statusWatcher)

	roleContainer := healthcheckstats.NewRoleContainer(role)

	for _, port := range ports {
		// Each port gets its own chain of BackendChecker wrappers; the
		// outermost one is what pollHealth ends up calling.
		be := NewLogProxy(NewBackend(host, port, database, user, pass))

		//query := healthtiming.NewGraphiteQuery(graphiteURL, cluster, os.Getenv("ENVIRONMENT"), roleContainer, getHostname(), port)
		//be = healthtiming.NewBackendTimingChecker(be, time.Duration(absoluteThreshold), relativeThreshold, query)

		// A separate buffered statsd client is created for every port.
		stats, err := statsd.NewBufferedClient(statsdHostPort, "", time.Second, 0)
		if err != nil {
			log.Fatal(err)
		}
		be = healthcheckstats.NewStatsWriter(be, cluster, port, roleContainer, stats)
		be = NewLogProxy(be)

		be = forcedstatus.NewForcedStatusChecker(be, statusWatcher, port)

		go pollHealth(interval_s, port, be, st)
	}

	// Park the main goroutine; all work happens in the goroutines above.
	select {}
}

// handler builds the HTTP handler behind /port/{port}. The {port} route
// variable is either a single port number or the literal "all".
//
// For a single port the response carries the status code and message that
// queryPort returns. For "all", every configured port is queried, the
// messages are joined one per line, and the numerically highest status code
// is used, so a single non-200 port makes the combined response non-200.
//
// Only GET is accepted; any other method is answered with 405 and an Allow
// header. A port value that cannot be parsed is answered with 400.
func handler(st common.StatusTracker) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "GET" {
			w.Header().Set("Allow", "GET")
			http.Error(w, "Unsupported method", http.StatusMethodNotAllowed)
			return
		}

		portStr := mux.Vars(r)["port"]

		msg := ""
		// Start from 200 so the running maximum below is always a valid HTTP
		// status, even when there is nothing to iterate over; 0 is not a
		// status code that net/http can write.
		statusCode := http.StatusOK
		if portStr == "all" {
			for _, port := range ports {
				portStatus, portMsg := queryPort(st, port)

				if portStatus > statusCode {
					statusCode = portStatus
				}
				msg += fmt.Sprintf("Port %d - %s\n", port, portMsg)
			}
		} else {
			port, err := strconv.Atoi(portStr)
			if err != nil {
				// The value comes from the request path, so failing to parse
				// it is a problem with the request, not a server fault.
				http.Error(w, "Unable to parse port.", http.StatusBadRequest)
				return
			}

			statusCode, msg = queryPort(st, port)
		}

		if statusCode != http.StatusOK {
			http.Error(w, msg, statusCode)
			return
		}
		if _, err := w.Write([]byte(msg)); err != nil {
			log.Println(err)
		}
	}
}

// queryPort translates the tracked state of one port into an HTTP status
// code and a human-readable message:
//
//   - 404 with the error text when st.Status returns an error,
//   - 500 when staleness checking is enabled (stale_s > 0) and the last
//     recorded check is older than stale_s seconds,
//   - 503 when the last recorded status is not ok,
//   - 200 otherwise.
func queryPort(st common.StatusTracker, port int) (int, string) {
	current, err := st.Status(port)
	if err != nil {
		return http.StatusNotFound, err.Error()
	}

	// A stale result means this service is not checking frequently enough,
	// which is a failure here rather than necessarily a failure in postgres.
	// Answer with an internal server error to make it clear that the problem
	// is in this service.
	staleCheckEnabled := stale_s > 0.0
	if staleCheckEnabled && time.Since(current.At()).Seconds() > stale_s {
		return http.StatusInternalServerError, fmt.Sprintf("Stale check last recorded at %v.", current.At())
	}

	// XXX AGB: Consider setting a cache header based on status.at and
	// interval instead of printing the information in the payload. We can
	// make that determination once the HAProxy interface matures. 2015-01-06
	report := fmt.Sprintf("%v. Last checked at %v.", current.Message(), current.At())

	code := http.StatusOK
	if !current.Ok() {
		code = http.StatusServiceUnavailable
	}
	return code, report
}

// listenAndServe mounts the status handler at /port/{port} (a number or
// "all") and serves HTTP on local_port. It only returns control by
// terminating the process: whatever error ListenAndServe returns is passed to
// log.Fatal.
//
// The server is started with a nil handler, so requests go through
// http.DefaultServeMux. The router is mounted there at "/", alongside the
// handlers that the blank-imported expvar and net/http/pprof packages
// register.
func listenAndServe(st common.StatusTracker) {
	router := mux.NewRouter()
	router.HandleFunc("/port/{port:(all|[0-9]+)}", handler(st))
	http.Handle("/", router)

	listenAddr := ":" + strconv.Itoa(local_port)
	log.Fatal(http.ListenAndServe(listenAddr, nil))
}

// pollBackend runs a single check against be and records the outcome for
// port in st: MarkUp with the message from the check when it succeeds,
// MarkDown with the error text when it fails. The returned error is the one
// from the tracker call, not the one from the check.
func pollBackend(port int, be common.BackendChecker, st common.StatusTracker) (err error) {
	msg, checkErr := be.Check()
	if checkErr == nil {
		return st.MarkUp(port, msg)
	}
	return st.MarkDown(port, checkErr.Error())
}

// pollHealth calls pollBackend for port on every tick of a timer with a
// period of interval seconds and logs any error it returns. The first check
// happens one period after the call. The loop never ends, so callers run it
// in its own goroutine.
func pollHealth(interval int, port int, be common.BackendChecker, st common.StatusTracker) {
	period := time.Duration(interval) * time.Second
	ticks := time.Tick(period)
	for {
		<-ticks
		err := pollBackend(port, be, st)
		if err != nil {
			log.Println(err)
		}
	}
}

// getHostname returns the host name reported by the operating system, or the
// placeholder "unknown" when it cannot be determined.
func getHostname() string {
	if name, err := os.Hostname(); err == nil {
		return name
	}
	return "unknown"
}
