package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Shopify/sarama"
	"github.com/cactus/go-statsd-client/statsd"
	"github.com/samuel/go-zookeeper/zk"
)

// Command-line flags configuring the monitor. The values are only valid after
// flag.Parse has been called (done at the top of main).
var (
	statsdAddr  = flag.String("statsd", "localhost:8125", "hostport of a statsd server to receive metrics")
	statsPrefix = flag.String("statsd-prefix", "trace.kafka.consumer-monitoring", "prefix in front of statsd metrics")
	zkAddr      = flag.String("zk", "localhost:2181", "comma-separated hostports of zookeeper servers with kafka consumer data")
	interval    = flag.Duration("interval", time.Second*5, "polling interval to look up consumer positions")

	// noop defaults to true: unless -noop=false is passed, main uses the
	// logging statter instead of creating a real statsd client.
	noop = flag.Bool("noop", true, "just log values, don't actually send them to statsds")
)

// broker describes a kafka broker that is the leader of at least one
// partition, as assembled by getBrokerData.
type broker struct {
	// id is the broker id, taken from the leader field of a partition's state.
	id         int32
	// addr is the host:port built from the broker's registration in zookeeper.
	addr       string
	// conn is the connection used to query offsets. getBrokerData closes it
	// before returning, so callers must not use it afterwards.
	conn       *sarama.Broker
	// partitions lists the partitions this broker leads; their offset field
	// holds the newest offset reported by the broker.
	partitions []*partition
}

// consumer describes one member of a consumer group together with the
// partitions it currently owns, as assembled by getConsumerData.
type consumer struct {
	// group is the consumer group name (a child of /consumers in zookeeper).
	group      string
	// id is the raw content of the partition owner node, used as the
	// consumer's identity within its group.
	id         string
	// partitions lists the owned partitions; their offset field holds the
	// offset stored for the group in zookeeper.
	partitions []*partition
}

// countConsumersPerTopicPerGroup reports, for every topic, how many consumers
// of each group own at least one partition of that topic.
//
// The result is keyed by topic first and by group second. A consumer is
// counted once per topic no matter how many of the topic's partitions it
// owns; consumers without partitions do not appear in the result at all.
func countConsumersPerTopicPerGroup(consumers []*consumer) map[string]map[string]int {
	counts := map[string]map[string]int{}

	for _, member := range consumers {
		// Topics already credited to this consumer, so that owning several
		// partitions of one topic still counts as a single consumer.
		seen := map[string]bool{}

		for _, owned := range member.partitions {
			if seen[owned.topic] {
				continue
			}
			seen[owned.topic] = true

			perGroup, found := counts[owned.topic]
			if !found {
				perGroup = map[string]int{}
				counts[owned.topic] = perGroup
			}
			perGroup[member.group]++
		}
	}

	return counts
}

// partition identifies a single topic partition and carries one offset for it.
// The meaning of offset depends on who holds the value: for partitions hanging
// off a broker it is the newest offset reported by that broker, for partitions
// hanging off a consumer it is the offset stored for the consumer's group.
type partition struct {
	topic  string
	id     int
	offset int64
}

// getBrokerData reads the kafka cluster layout from zookeeper and asks every
// partition leader for the newest offset of each partition it leads.
//
// It walks /brokers/topics/<topic>/partitions/<id>/state to find the leader
// of every partition, resolves each leader's address from /brokers/ids/<id>,
// opens a connection to it and queries all leaders in parallel.
//
// The returned brokers carry their partitions with the offset field filled
// in (0 when the broker returned no offset for a partition). All broker
// connections are closed before the function returns, so the conn field of
// the returned brokers must not be used. On any error nil is returned
// together with the first error encountered.
func getBrokerData(z *zk.Conn) ([]*broker, error) {
	brokers := make(map[int32]*broker)

	topics, _, err := z.Children("/brokers/topics")
	if err != nil {
		return nil, err
	}

	// retrieve partition ownership state
	type partitionState struct {
		Leader int32
	}

	for _, t := range topics {
		partitionIDs, _, err := z.Children(fmt.Sprintf("/brokers/topics/%s/partitions", t))
		if err != nil {
			return nil, err
		}
		for _, p := range partitionIDs {
			raw, _, err := z.Get(fmt.Sprintf("/brokers/topics/%s/partitions/%s/state", t, p))
			if err != nil {
				return nil, err
			}

			state := &partitionState{}
			if err := json.Unmarshal(raw, state); err != nil {
				return nil, err
			}

			b, ok := brokers[state.Leader]
			if !ok {
				b = &broker{id: state.Leader}
				brokers[state.Leader] = b
			}

			partitionInt, err := strconv.Atoi(p)
			if err != nil {
				return nil, err
			}

			b.partitions = append(b.partitions, &partition{topic: t, id: partitionInt})
		}
	}

	// retrieve information on hostport of each broker and establish
	// connections to the brokers.
	type brokerRegistration struct {
		Host string
		Port int
	}
	for id, b := range brokers {
		raw, _, err := z.Get(fmt.Sprintf("/brokers/ids/%d", id))
		if err != nil {
			return nil, err
		}

		reg := &brokerRegistration{}
		if err := json.Unmarshal(raw, reg); err != nil {
			return nil, err
		}
		b.addr = net.JoinHostPort(reg.Host, strconv.Itoa(reg.Port))

		b.conn = sarama.NewBroker(b.addr)
		if err := b.conn.Open(nil); err != nil {
			return nil, err
		}
		// Deliberately deferred to function exit rather than loop iteration:
		// the connection is still needed by the offset queries below.
		defer b.conn.Close()
	}

	// Query brokers in parallel, since they are different servers. The error
	// channel has one slot per broker so that a failing goroutine can always
	// deliver its error and terminate, no matter how many others fail too.
	var wg sync.WaitGroup
	errs := make(chan error, len(brokers))

	wg.Add(len(brokers))
	for _, b := range brokers {
		go func(b *broker) {
			defer wg.Done()

			// Construct the request to send to kafka
			request := &sarama.OffsetRequest{}
			for _, p := range b.partitions {
				request.AddBlock(p.topic, int32(p.id), sarama.OffsetNewest, 1)
			}

			// send the request
			resp, err := b.conn.GetAvailableOffsets(request)
			if err != nil {
				errs <- err
				return
			}

			// parse the response
			for _, p := range b.partitions {
				block := resp.GetBlock(p.topic, int32(p.id))
				// A block may be missing or carry no offsets (for example
				// when the broker reported an error for the partition);
				// indexing Offsets[0] unguarded would panic.
				if block == nil || len(block.Offsets) == 0 {
					p.offset = 0
					continue
				}
				p.offset = block.Offsets[0]
			}
		}(b)
	}

	// Wait for every query to finish before inspecting errors, so that no
	// goroutine is still using a connection when the deferred Close calls run.
	wg.Wait()
	close(errs)
	if err, ok := <-errs; ok {
		return nil, err
	}

	brokerSlice := make([]*broker, 0, len(brokers))
	for _, b := range brokers {
		brokerSlice = append(brokerSlice, b)
	}
	return brokerSlice, nil
}

// getConsumerData reads consumer group state from zookeeper.
//
// For every group under /consumers it walks
// /consumers/<group>/owners/<topic>/<partition>, whose content identifies the
// consumer owning the partition, and reads the matching stored offset from
// /consumers/<group>/offsets/<topic>/<partition>.
//
// It returns one consumer per distinct (group, owner id) pair, each carrying
// the partitions it owns with the stored offset. Any zookeeper or parse
// error aborts the whole lookup and is returned with a nil slice.
func getConsumerData(z *zk.Conn) ([]*consumer, error) {
	consumers := make(map[string]*consumer)

	groups, _, err := z.Children("/consumers")
	if err != nil {
		return nil, err
	}

	for _, g := range groups {
		topics, _, err := z.Children(fmt.Sprintf("/consumers/%s/owners", g))
		if err != nil {
			return nil, err
		}
		for _, t := range topics {
			partitions, _, err := z.Children(fmt.Sprintf("/consumers/%s/owners/%s", g, t))
			if err != nil {
				return nil, err
			}
			for _, p := range partitions {
				raw, _, err := z.Get(fmt.Sprintf("/consumers/%s/owners/%s/%s", g, t, p))
				if err != nil {
					return nil, err
				}

				id := string(raw)

				// Owner ids are only unique within a group, so the group is
				// part of the lookup key.
				consumerKey := fmt.Sprintf("[%s][%s]", g, id)
				c, ok := consumers[consumerKey]
				if !ok {
					c = &consumer{id: id, group: g}
					consumers[consumerKey] = c
				}

				partitionInt, err := strconv.Atoi(p)
				if err != nil {
					return nil, err
				}

				raw, _, err = z.Get(fmt.Sprintf("/consumers/%s/offsets/%s/%s", g, t, p))
				if err != nil {
					return nil, err
				}

				// Offsets are int64; parse with an explicit 64-bit size so
				// large offsets do not fail with a range error on platforms
				// where int is 32 bits wide.
				offset, err := strconv.ParseInt(string(raw), 10, 64)
				if err != nil {
					return nil, err
				}

				c.partitions = append(c.partitions, &partition{topic: t, id: partitionInt, offset: offset})
			}
		}
	}

	consumerSlice := make([]*consumer, 0, len(consumers))
	for _, c := range consumers {
		consumerSlice = append(consumerSlice, c)
	}
	return consumerSlice, nil
}

// writeStats publishes the collected broker and consumer state as gauges on s.
//
// The following metrics are emitted, in this order:
//
//	<topic>.<partition>.broker.offset   newest offset reported by the broker
//	<topic>.<partition>.<group>.offset  offset stored for the consumer group
//	<topic>.<partition>.<group>.lag     broker offset minus consumer offset
//	<topic>.active-consumers.<group>    consumers of the group owning the topic
//
// The lag gauge is only written when a broker offset is known for exactly
// that topic and partition. The first error returned by s.Gauge aborts the
// run and is returned to the caller.
func writeStats(s statsd.Statter, brokers []*broker, consumers []*consumer) error {
	// topic -> partition id -> newest broker offset
	brokerOffsets := make(map[string]map[int]int64)
	for _, b := range brokers {
		for _, p := range b.partitions {
			metric := fmt.Sprintf("%s.%d.broker.offset", p.topic, p.id)
			if err := s.Gauge(metric, p.offset, 1.0); err != nil {
				return err
			}

			o, ok := brokerOffsets[p.topic]
			if !ok {
				o = make(map[int]int64)
				brokerOffsets[p.topic] = o
			}
			o[p.id] = p.offset
		}
	}

	for _, c := range consumers {
		for _, p := range c.partitions {
			metric := fmt.Sprintf("%s.%d.%s.offset", p.topic, p.id, c.group)
			if err := s.Gauge(metric, p.offset, 1.0); err != nil {
				return err
			}

			// Indexing the (possibly nil) inner map is safe. Requiring the
			// partition itself to be present avoids reporting a bogus lag of
			// -offset when only other partitions of the topic are known.
			brokerOffset, ok := brokerOffsets[p.topic][p.id]
			if !ok {
				continue
			}
			metric = fmt.Sprintf("%s.%d.%s.lag", p.topic, p.id, c.group)
			if err := s.Gauge(metric, brokerOffset-p.offset, 1.0); err != nil {
				return err
			}
		}
	}

	counts := countConsumersPerTopicPerGroup(consumers)
	for t, groups := range counts {
		for group, count := range groups {
			metric := fmt.Sprintf("%s.active-consumers.%s", t, group)
			if err := s.Gauge(metric, int64(count), 1.0); err != nil {
				return err
			}
		}
	}

	return nil
}

// main parses the flags, connects to zookeeper, sets up the statsd client
// (or the logging stand-in when -noop is set) and then, once per polling
// interval, collects broker and consumer data and publishes it as metrics.
//
// Failures inside a polling round are logged and the round is skipped; only
// failures while setting up the zookeeper or statsd clients are fatal.
func main() {
	flag.Parse()

	zkAddrs := strings.Split(*zkAddr, ",")

	log.Printf("attempting to connect to zookeeper at %v", zkAddrs)
	// The polling interval doubles as the zookeeper session timeout.
	zkClient, _, err := zk.Connect(zkAddrs, *interval)
	if err != nil {
		log.Fatalf("Unable to connect to ZooKeeper; err=%q", err)
	}
	defer zkClient.Close()
	log.Printf("zookeeper connection established")

	var statsdClient statsd.Statter
	if *noop {
		log.Printf("using no-op logger instead of statsd")
		statsdClient = &statsdLogger{*statsPrefix}
	} else {
		log.Printf("creating statsd client, writing to %v", *statsdAddr)
		statsdClient, err = statsd.NewClient(*statsdAddr, *statsPrefix)
		if err != nil {
			log.Fatalf("Unable to connect to statsd; err=%q", err)
		}
		log.Printf("statsd client created")
	}
	defer statsdClient.Close()

	log.Printf("monitoring loop start")
	ticker := time.NewTicker(*interval)
	defer ticker.Stop()

	// Ranging over the ticker channel runs one round per tick. The channel is
	// never closed, so in practice the loop runs until the process is killed.
	for range ticker.C {
		log.Printf("getting broker data")
		brokers, err := getBrokerData(zkClient)
		if err != nil {
			log.Printf("[ERROR] failed retrieving broker data! err=%q", err)
			continue
		}

		log.Printf("getting consumer data")
		consumers, err := getConsumerData(zkClient)
		if err != nil {
			log.Printf("[ERROR] failed retrieving consumer data! err=%q", err)
			continue
		}

		log.Printf("writing stats")
		if err := writeStats(statsdClient, brokers, consumers); err != nil {
			log.Printf("[ERROR] failed writing stats! err=%q", err)
		}
	}
	log.Printf("graceful exit")
}
