package main

import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/darkcrux/gopherduty"
	"github.com/deckarep/golang-set"

	//"github.com/aws/aws-sdk-go/aws"
	//"github.com/aws/aws-sdk-go/aws/credentials"
	//"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	//"github.com/aws/aws-sdk-go/aws/session"
	//"github.com/aws/aws-sdk-go/service/sts"

	"code.justin.tv/systems/sandstorm/manager"
)

// ZkStatus is one server's entry in the cluster status JSON that pagerDuty
// fetches from Exhibitor and decodes with encoding/json.
type ZkStatus struct {
	// State is decoded from the "code" field; checkStatus counts a server as
	// up only when this equals 3.
	State       uint8  `json:"code"`
	// Description is decoded from the "description" field.
	Description string `json:"description"`
	// Hostname identifies the server; it is what checkStatus adds to the
	// up/failed server sets.
	Hostname    string `json:"hostname"`
	// IsLeader reports whether this server is flagged as leader in the
	// status payload.
	IsLeader    bool   `json:"isLeader"`
}

// ZkCluster is the decoded cluster status: one ZkStatus per reported server.
type ZkCluster []ZkStatus

// StatusObj summarizes one cluster check; checkStatus builds it and
// sendAlertZk turns it into the alert details text.
type StatusObj struct {
	// NumOfLeader is how many servers reported IsLeader.
	NumOfLeader     int8
	// FailedServerSet holds the hostnames whose state code was not 3.
	FailedServerSet mapset.Set
	// UpServerSet holds the hostnames whose state code was 3.
	UpServerSet     mapset.Set
}

// AlertObj carries the arguments sendAlert passes to the PagerDuty client's
// Trigger call, in the same order.
//
// Callers in this file construct it with positional (unkeyed) literals, so
// the field order below must not change.
type AlertObj struct {
	// IncidentKey is passed as the incident key.
	IncidentKey string
	// Description is passed as the incident description.
	Description string
	// Client is passed as the client name; callers here use the hostname.
	Client      string
	// ClientUrl is passed as the client URL.
	ClientUrl   string
	// Details is passed as the free-form details payload.
	Details     string
}

// Package-level configuration and alert bookkeeping shared by the functions
// in this file.
var (
	// environment is bound to the -environment flag (default "staging").
	environment	 string
	// servicekey is the PagerDuty service key handed to gopherduty.NewClient.
	servicekey       string
	// interval is the pause between checks in milliseconds (-interval flag).
	interval         int
	// vpcid is read from /etc/facter/facts.d/vpcid.txt, or "vpc-local".
	vpcid            string
	// hostname is the result of os.Hostname, or "localhost" on error.
	hostname         string
	// domainSuffix is read from /etc/facter/facts.d/domain_suffix.txt and is
	// appended to hostname when building client URLs.
	domainSuffix     string
	// timeout is the number of seconds before alerts may be sent again
	// (-timeout flag).
	timeout          int
	// timepassed holds the seconds counted by pagerDuty since the alert
	// flags were last reset; it is compared against timeout.
	timepassed       int
	// triggerExhibitor is set once an Exhibitor alert has been sent and
	// suppresses further ones until pagerDuty clears it.
	triggerExhibitor bool
	// triggerZk is set once a ZooKeeper alert has been sent and suppresses
	// further ones until pagerDuty clears it.
	triggerZk        bool
	// shouldsend enables sending alerts at all (-shouldsend flag).
	shouldsend       bool
	// preTriggerCheck counts unhealthy cluster checks seen by checkStatus
	// since the last reset.
	preTriggerCheck  int
	// numBeforeSend is how many unhealthy checks must accumulate before a
	// ZooKeeper alert is sent (-numBeforeSend flag).
	numBeforeSend    int
)

// sendAlert triggers a PagerDuty incident described by alertObj.
//
// A new gopherduty client is created from the package-level servicekey on
// every call, with MaxRetry set to 5. A response that reports errors is only
// logged; nothing is returned to the caller.
func sendAlert(alertObj AlertObj) {
	pd := gopherduty.NewClient(servicekey)
	pd.MaxRetry = 5

	// Trigger arguments, in order: incident key, description, client,
	// client URL, details.
	resp := pd.Trigger(
		alertObj.IncidentKey,
		alertObj.Description,
		alertObj.Client,
		alertObj.ClientUrl,
		alertObj.Details,
	)
	if !resp.HasErrors() {
		return
	}
	log.Println("Error response from pagerduty: ", resp)
}

// sendAlertZk sends a PagerDuty alert describing an unhealthy ZooKeeper
// cluster.
//
// It does nothing unless shouldsend is set and triggerZk is still clear.
// After handing the alert to sendAlert it sets triggerZk so the alert is not
// repeated until that flag is cleared again.
//
// statusObj supplies the leader count and the sets of failed and healthy
// hosts that are written into the alert details.
func sendAlertZk(statusObj StatusObj) {
	if triggerZk || !shouldsend {
		return
	}

	log.Println("Sending Zk alert to pagerduty and Zk error count before send: ", preTriggerCheck)

	// Assemble the human-readable details text.
	var buffer bytes.Buffer
	buffer.WriteString("VPC_ID: ")
	buffer.WriteString(vpcid)

	if statusObj.NumOfLeader != 1 {
		// Format the count as decimal text. Converting the integer with
		// string(...) would instead produce the character with that code
		// point (for example "\x00" for zero leaders).
		fmt.Fprintf(&buffer, ", there is %d leader in the cluster.", statusObj.NumOfLeader)
	}

	buffer.WriteString(" The problem servers are: ")
	buffer.WriteString(statusObj.FailedServerSet.String())

	buffer.WriteString(" The normal servers are: ")
	buffer.WriteString(statusObj.UpServerSet.String())

	// Client URL points at this host's ZooKeeper port.
	var clientURLBuffer bytes.Buffer
	clientURLBuffer.WriteString("http://")
	clientURLBuffer.WriteString(hostname)
	clientURLBuffer.WriteString(".")
	clientURLBuffer.WriteString(domainSuffix)
	clientURLBuffer.WriteString(":2181")

	alertObj := AlertObj{"zk_incident_key", "zk problem", hostname, clientURLBuffer.String(), buffer.String()}
	sendAlert(alertObj)

	triggerZk = true
}

// sendAlertExhibitor sends a PagerDuty alert about the local Exhibitor
// endpoint, using errorString as the incident description.
//
// It does nothing unless shouldsend is set and triggerExhibitor is still
// clear. After handing the alert to sendAlert it sets triggerExhibitor so the
// alert is not repeated until that flag is cleared again.
func sendAlertExhibitor(errorString string) {
	// Skip when sending is disabled or an alert has already gone out.
	if triggerExhibitor || !shouldsend {
		return
	}

	log.Println("Trigger exhibitor")

	// Details text identifying where the failing Exhibitor lives.
	var details bytes.Buffer
	details.WriteString("VPC_ID: " + vpcid)
	details.WriteString(" hostname: " + hostname)
	details.WriteString(" port: " + "8080")

	// Client URL points at this host's Exhibitor port.
	var clientURL bytes.Buffer
	for _, part := range []string{"http://", hostname, ".", domainSuffix, ":8080"} {
		clientURL.WriteString(part)
	}

	sendAlert(AlertObj{
		IncidentKey: "exhibitor_incident_key",
		Description: errorString,
		Client:      hostname,
		ClientUrl:   clientURL.String(),
		Details:     details.String(),
	})

	triggerExhibitor = true
}

// checkStatus evaluates one decoded cluster status and, when the cluster
// looks unhealthy often enough, asks sendAlertZk to page.
//
// A server counts as up when its state code is 3 (presumably Exhibitor's
// "serving" code -- confirm against the Exhibitor API). The cluster is
// considered healthy only when exactly one server is leader and exactly five
// servers are up. Each unhealthy check increments preTriggerCheck; once that
// counter reaches numBeforeSend the alert is handed to sendAlertZk.
func checkStatus(clusterStats ZkCluster) {
	var (
		upCount     int8
		leaderCount int8
	)
	down := mapset.NewSet()
	up := mapset.NewSet()

	for _, node := range clusterStats {
		switch node.State {
		case 3:
			upCount++
			up.Add(node.Hostname)
		default:
			down.Add(node.Hostname)
		}

		if node.IsLeader {
			leaderCount++
		}
	}

	if leaderCount == 1 && upCount == 5 {
		return
	}

	preTriggerCheck++
	if preTriggerCheck < numBeforeSend {
		return
	}

	sendAlertZk(StatusObj{
		NumOfLeader:     leaderCount,
		FailedServerSet: down,
		UpServerSet:     up,
	})
}

// init registers the command-line flags, resets the alert bookkeeping, loads
// the PagerDuty service key from Sandstorm into the package-level servicekey,
// and discovers the VPC id, hostname and domain suffix used in alert texts.
//
// It panics if the service key cannot be fetched.
//
// NOTE(review): flags are parsed later, in main, so while init runs
// environment still holds its default ("staging"). The -environment flag
// therefore cannot influence the role ARN or the secret name chosen here;
// fixing that requires moving the secret lookup to after flag.Parse.
func init() {
	flag.IntVar(&interval, "interval", 5000, "number of milliseconds between each check.")
	flag.IntVar(&timeout, "timeout", 1800, "number of seconds before sending alert again")
	flag.IntVar(&numBeforeSend, "numBeforeSend", 5, "number of alertable events before sending to pagerduty")
	flag.BoolVar(&shouldsend, "shouldsend", false, "should be sending alert to pagerduty")
	flag.StringVar(&environment, "environment", "staging", "should reflect which enviroment you are")

	triggerZk = false
	triggerExhibitor = false
	preTriggerCheck = 0

	// RoleARN is only consumed by the disabled AssumeRole setup below. The
	// blank assignment keeps the compiler from rejecting it as unused.
	var RoleARN string
	if environment == "staging" {
		RoleARN = "arn:aws:iam::734326455073:role/sandstorm/production/templated/role/zookeeper-staging"
	} else {
		RoleARN = "arn:aws:iam::734326455073:role/sandstorm/production/templated/role/zookeeper-production"
	}
	_ = RoleARN

	//stsclient := sts.New(session.New(&aws.Config{Region: aws.String("us-west-2")}))
	//
	//arp := &stscreds.AssumeRoleProvider{
	//	Duration:     900 * time.Second,
	//	ExpiryWindow: 10 * time.Second,
	//	RoleARN:      RoleARN,
	//	Client:       stsclient,
	//}
	//
	//credentials := credentials.NewCredentials(arp)
	//awsConfig := &aws.Config{
	//	Region: aws.String("us-west-2"),
	//	Credentials: credentials,
	//}

	m := manager.New(manager.Config{
		//AWSConfig: awsConfig,
		TableName: "sandstorm-production", // need to change to the sandstorm dynamo table
		KeyID:     "alias/sandstorm-production",
	})

	key := fmt.Sprint("d8a-eng/zookeeper/", environment, "/servicekey")

	// Use a distinct local name: `servicekey, err := ...` would declare a new
	// local and leave the package-level servicekey (read by sendAlert) empty.
	secret, err := m.Get(key)
	if err != nil {
		panic(err)
	}
	servicekey = string(secret.Plaintext)
	// Log only the secret's name; the plaintext must never reach the output.
	log.Println("Loaded pagerduty service key from sandstorm: ", key)

	// getting vpc id
	vpcid = readFacterFact("/etc/facter/facts.d/vpcid.txt", "vpc-local")

	// getting hostname
	name, err := os.Hostname()
	if err != nil {
		log.Println("can not get hostname: ", err)
		hostname = "localhost"
	} else {
		hostname = name
	}

	// getting fqdn suffix
	domainSuffix = readFacterFact("/etc/facter/facts.d/domain_suffix.txt", "us-west2.justin.tv")
}

// readFacterFact returns the value from a "name=value" facter fact file.
//
// The file content is trimmed and split on "="; the second piece is the
// result. If the file cannot be read or contains no "=", fallback is returned
// and the reason is logged together with the fallback actually used.
func readFacterFact(path, fallback string) string {
	content, err := ioutil.ReadFile(path)
	if err != nil {
		log.Printf("Can not find %s file -- default to %s", path, fallback)
		return fallback
	}

	parts := strings.Split(strings.TrimSpace(string(content)), "=")
	if len(parts) < 2 {
		log.Printf("Can not parse content of %s -- default to %s", path, fallback)
		return fallback
	}
	return parts[1]
}

func pagerDuty() {

	for {
		resp, err := http.Get("http://localhost:8080/exhibitor/v1/cluster/status")
		if err != nil {
			// handle error
			//log.Println("can not connect to exhibitor")
			sendAlertExhibitor("can not connect to exhibitor")

		} else {

			defer resp.Body.Close()

			body, err := ioutil.ReadAll(resp.Body)

			if err != nil {
				// handle error
				sendAlertExhibitor("can not read return status from exhibitor")
			} else {

				log.Println("The result json is: ", string(body))
				zkstateArray := ZkCluster{}
				json.Unmarshal([]byte(string(body)), &zkstateArray)

				checkStatus(zkstateArray)
			}
		}

		// amount of time to sleep
		time.Sleep(time.Duration(interval) * time.Millisecond)

		timepassed += interval / 1000

		// Reset the flags so they will alert again
		if timepassed >= timeout {
			triggerZk = false
			triggerExhibitor = false
			timepassed = 0
			preTriggerCheck = 0
		}
	}
}

// main parses the command-line flags, starts pagerDuty and consulCheck in
// their own goroutines, and then blocks forever.
func main() {
	// parse the command-line arguments
	flag.Parse()

	go pagerDuty()

	go consulCheck()

	// An empty select never becomes ready, so the main goroutine parks here
	// and the process stays alive for the background goroutines.
	select {}

}
