package index

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"sync"

	txtool "code.justin.tv/release/trace/analysis/tx"
	"code.justin.tv/release/trace/api"
	"code.justin.tv/release/trace/internal/stream"
	"code.justin.tv/release/trace/persistent"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/firehose"
	"github.com/aws/aws-sdk-go/service/firehose/firehoseiface"
	"github.com/aws/aws-sdk-go/service/kinesis"
	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
	"github.com/golang/protobuf/proto"
	"github.com/syndtr/goleveldb/leveldb"
	"github.com/syndtr/goleveldb/leveldb/storage"
	"github.com/syndtr/goleveldb/leveldb/util"
	"golang.org/x/sync/errgroup"
)

// statLogger forwards metrics and error messages to optional callbacks,
// prepending prefix to every metric name and log line. A nil *statLogger, or
// one whose callbacks are nil, silently discards whatever it is given.
type statLogger struct {
	prefix   string
	logFunc  func(format string, v ...interface{})
	statFunc func(name string, val float64, units string)
}

// logError counts one occurrence of name as a metric and, when a log callback
// is configured, logs err tagged with the prefixed name.
func (l *statLogger) logError(name string, err error) {
	l.scalar(name, 1)
	if l == nil || l.logFunc == nil {
		return
	}
	l.logFunc("%s%s: %v", l.prefix, name, err)
}

// scalar reports a unitless value (unit string "None") under name.
func (l *statLogger) scalar(name string, val float64) {
	const unitless = "None"
	l.units(name, val, unitless)
}

// units reports val, measured in the given units, under the prefixed metric
// name. It does nothing when l or its stat callback is nil.
func (l *statLogger) units(name string, val float64, units string) {
	if l == nil || l.statFunc == nil {
		return
	}
	metric := l.prefix + name
	l.statFunc(metric, val, units)
}

// streamWriter is a stream processor that publishes each []byte payload it
// dequeues as one record on the Kinesis stream named streamName, counting and
// logging per-record success or failure through stats (which may be nil;
// statLogger methods tolerate a nil receiver).
type streamWriter struct {
	kinesis    kinesisiface.KinesisAPI
	streamName string
	stats      *statLogger
}

// ProcessStream drains in, writing every []byte payload to Kinesis. Each
// record's partition key is the hex-encoded SHA-256 of its payload. Write
// failures are counted and logged but never stop the loop; out is unused.
// It returns nil once Dequeue reports no more items or ctx is done.
func (w *streamWriter) ProcessStream(ctx context.Context, out stream.Enqueuer, in stream.Dequeuer) error {
	for {
		item, ok := in.Dequeue(ctx)
		if !ok || ctx.Err() != nil {
			return nil
		}
		payload := item.([]byte)

		digest := sha256.Sum256(payload)
		// TODO: retry rate-limit failures with a new key / different shard
		partitionKey := hex.EncodeToString(digest[:])

		_, err := w.kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{
			StreamName:   aws.String(w.streamName),
			Data:         payload,
			PartitionKey: aws.String(partitionKey),
		})
		if err != nil {
			w.stats.logError("KinesisPutRecordError", err)
			continue
		}
		w.stats.scalar("KinesisPutRecordSuccess", 1)
	}
}

// firehoseWriter is a stream processor that sends each []byte payload it
// dequeues, prefixed with its length as a protobuf varint, as one record to
// the Firehose delivery stream named streamName. Per-record success or
// failure is counted and logged through stats (nil is tolerated).
type firehoseWriter struct {
	firehose   firehoseiface.FirehoseAPI
	streamName string
	stats      *statLogger
}

// ProcessStream drains in and sends each []byte payload to the Firehose
// delivery stream as a single record. Each record is framed as a protobuf
// varint length followed by the payload. The framing presumably lets a
// reader split records apart after Firehose concatenates them (confirm with
// the consumer). Failures are counted and logged without stopping the loop;
// out is unused. It returns nil once Dequeue reports no more items or ctx
// is done.
func (w *firehoseWriter) ProcessStream(ctx context.Context, out stream.Enqueuer, in stream.Dequeuer) error {
	for {
		item, ok := in.Dequeue(ctx)
		if !ok || ctx.Err() != nil {
			return nil
		}
		payload := item.([]byte)

		framed := proto.EncodeVarint(uint64(len(payload)))
		framed = append(framed, payload...)

		record := &firehose.Record{Data: framed}
		_, err := w.firehose.PutRecordWithContext(ctx, &firehose.PutRecordInput{
			DeliveryStreamName: aws.String(w.streamName),
			Record:             record,
		})
		if err != nil {
			w.stats.logError("FirehosePutRecordError", err)
			continue
		}
		w.stats.scalar("FirehosePutRecordSuccess", 1)
	}
}

// indexGenerator is a stream processor that turns queued *kinesisRecord
// values into zipped LevelDB indexes (via buildIndex). stats receives its
// counters and error logs and may be nil.
type indexGenerator struct {
	stats *statLogger
}

// ProcessStream converts each *kinesisRecord dequeued from in into a zipped
// index and enqueues the resulting bytes on out. A record's payload is
// unmarshaled as an api.TransactionSet and indexed by buildIndex. Records
// that fail to unmarshal, or whose index cannot be built, are counted,
// logged and skipped. It returns nil once Dequeue reports no more items or
// ctx is done.
func (gen *indexGenerator) ProcessStream(ctx context.Context, out stream.Enqueuer, in stream.Dequeuer) error {
	for {
		item, ok := in.Dequeue(ctx)
		if !ok || ctx.Err() != nil {
			return nil
		}
		rec := item.(*kinesisRecord)

		txs := new(api.TransactionSet)
		if err := proto.Unmarshal(rec.data, txs); err != nil {
			gen.stats.logError("UnmarshalError", err)
			continue
		}
		gen.stats.scalar("EntryCount", float64(len(txs.GetTransaction())))

		zipped, err := buildIndex(txs, rec.shardId, rec.sequenceNumber)
		if err != nil {
			gen.stats.logError("BuildIndexError", err)
			continue
		}
		gen.stats.units("ZipSize", float64(len(zipped)), "Bytes")

		// NOTE(review): Enqueue's result is ignored. If the index cannot be
		// handed off, nothing records the drop.
		out.Enqueue(ctx, zipped)
	}
}

// kinesisRecord is a record captured by TxLocation.OnPutRecord and queued for
// indexGenerator. data is unmarshaled there as an api.TransactionSet.
// shardId and sequenceNumber (presumably the Kinesis shard and sequence
// number the record was written at) are copied into every index key that
// buildIndex produces for the record.
type kinesisRecord struct {
	data           []byte
	shardId        string
	sequenceNumber string
}

// TxLocation builds transaction-location indexes from Kinesis records and
// publishes them. Records arrive via OnPutRecord. Run drives the pipeline
// that turns each record into a zipped LevelDB index and writes it to the
// configured destinations.
//
// Set the exported fields before the first call to Run or OnPutRecord. Both
// call setup, which captures LogFunc and StatFunc exactly once.
type TxLocation struct {
	// Destinations for built indexes; a nil client disables that destination.
	// LogFunc and StatFunc are optional sinks for error logs and metrics.
	Kinesis  kinesisiface.KinesisAPI
	Firehose firehoseiface.FirehoseAPI
	LogFunc  func(format string, v ...interface{})
	StatFunc func(name string, val float64, units string)

	// setupOnce guards setup's lazy initialization of the fields below.
	setupOnce sync.Once

	// stats reports with the "TransactionIndex/" prefix; it is built by setup.
	stats *statLogger

	// setup sets these to fixed defaults ("tx-index", 1). streamName is used
	// as both the Kinesis stream name and the Firehose delivery stream name.
	// workers is the goroutine count for index generation and for each
	// destination's writers.
	streamName string
	workers    int

	// stop ensures txSetRecords is finished at most once. txSetRecords is the
	// queue of *kinesisRecord values fed by OnPutRecord; setup creates it
	// with a capacity of workers.
	stop         sync.Once
	txSetRecords stream.Stream
}

// setup lazily initializes TxLocation's unexported state exactly once. It
// builds the stat logger, which captures LogFunc and StatFunc at that
// moment. It applies the default stream name and worker count. It also
// creates the input queue with a capacity equal to the worker count.
func (idx *TxLocation) setup() {
	idx.setupOnce.Do(func() {
		const (
			defaultStreamName = "tx-index"
			defaultWorkers    = 1
		)

		idx.streamName, idx.workers = defaultStreamName, defaultWorkers
		idx.stats = &statLogger{
			prefix:   "TransactionIndex/",
			logFunc:  idx.LogFunc,
			statFunc: idx.StatFunc,
		}
		idx.txSetRecords = make(stream.UntypedStream, idx.workers)
	})
}

// Run starts the index pipeline and blocks until ctx is done or a stage
// fails. It returns the first stage error, or nil on a clean shutdown.
//
// The pipeline has two stages:
//   - idx.workers index generators read the *kinesisRecord values queued by
//     OnPutRecord, build zipped indexes, and fan each one out to every
//     configured destination;
//   - for each non-nil Kinesis / Firehose client, idx.workers writers
//     publish those indexes.
//
// When Run's context ends, the input queue is finished so the generators
// stop. Once every generator has returned, each destination's input is
// finished so its writers stop. Only the producing side ever finishes a
// stream, so no generator can enqueue onto a destination input that has
// already been finished.
func (idx *TxLocation) Run(ctx context.Context) error {
	idx.setup()

	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Finish the input queue once Run's context ends (at the latest when Run
	// returns) so generators waiting in Dequeue wake up. runCtx is never
	// reassigned, so this goroutine cannot race with the context variables
	// declared below.
	go func() {
		<-runCtx.Done()
		idx.stop.Do(func() { idx.txSetRecords.Finish() })
	}()

	eg, egCtx := errgroup.WithContext(runCtx)

	var (
		indexZips  stream.FanoutEnqueue
		destInputs []stream.UntypedStream
	)
	addIndexDest := func(sp stream.StreamProcessor) {
		in := make(stream.UntypedStream)
		indexZips = append(indexZips, in)
		destInputs = append(destInputs, in)
		for i := 0; i < idx.workers; i++ {
			eg.Go(func() error { return sp.ProcessStream(egCtx, nil, in) })
		}
	}

	if idx.Kinesis != nil {
		addIndexDest(&streamWriter{stats: idx.stats, streamName: idx.streamName, kinesis: idx.Kinesis})
	}
	if idx.Firehose != nil {
		addIndexDest(&firehoseWriter{stats: idx.stats, streamName: idx.streamName, firehose: idx.Firehose})
	}

	gen := &indexGenerator{stats: idx.stats}
	gens, genCtx := errgroup.WithContext(egCtx)
	for i := 0; i < idx.workers; i++ {
		gens.Go(func() error { return gen.ProcessStream(genCtx, indexZips, idx.txSetRecords) })
	}

	// The generators are the only producers for the destination inputs, so
	// finish those inputs (letting the writers exit) only after all
	// generators have returned.
	eg.Go(func() error {
		err := gens.Wait()
		for _, in := range destInputs {
			in.Finish()
		}
		return err
	})

	return eg.Wait()
}

// OnPutRecord queues one Kinesis record for indexing without waiting on the
// pipeline. Every call counts PutRecordInput. If the queue does not accept
// the record, the record is dropped and counted as PutRecordBackpressure.
//
// NOTE(review): an already-canceled context is passed to Enqueue, which is
// presumably meant to make it a single non-blocking attempt. If Enqueue is a
// plain select over the send and ctx.Done(), Go picks randomly among ready
// cases. Records could then be dropped even when the queue has room.
// Confirm against the stream package.
func (idx *TxLocation) OnPutRecord(data []byte, shardId string, sequenceNumber string) {
	idx.setup()
	idx.stats.scalar("PutRecordInput", 1)

	nonBlocking, cancel := context.WithCancel(context.Background())
	cancel()

	rec := &kinesisRecord{
		data:           data,
		shardId:        shardId,
		sequenceNumber: sequenceNumber,
	}
	if !idx.txSetRecords.Enqueue(nonBlocking, rec) {
		idx.stats.scalar("PutRecordBackpressure", 1)
	}
}

// buildIndex builds an in-memory LevelDB database holding one empty-valued
// key per transaction in txs. Each key is the leveldbKey of a
// kinesisLocation made from the transaction's ID, shardId and seqNum. The
// function compacts and closes the database, then returns its storage
// zipped by persistent.NewZipWriter. Any LevelDB or zip error is returned
// unchanged.
func buildIndex(txs *api.TransactionSet, shardId, seqNum string) ([]byte, error) {
	batch := new(leveldb.Batch)
	for _, tx := range txs.GetTransaction() {
		loc := &kinesisLocation{
			txid:   txtool.IDForTx(tx).String(),
			shard:  shardId,
			seqNum: seqNum,
		}
		// Only the keys carry information; every value is empty.
		batch.Put([]byte(loc.leveldbKey()), []byte{})
	}

	mem := storage.NewMemStorage()
	defer mem.Close()

	db, err := leveldb.Open(mem, nil)
	if err != nil {
		return nil, err
	}
	// Covers the early returns below; its result is discarded. The explicit
	// Close further down is the one whose error is checked.
	defer db.Close()

	if err := db.Write(batch, nil); err != nil {
		return nil, err
	}
	// A zero Range (nil Start and Limit) compacts the whole key space.
	if err := db.CompactRange(util.Range{}); err != nil {
		return nil, err
	}
	// Close before zipping, presumably so all database state is in mem
	// before it is read.
	if err := db.Close(); err != nil {
		return nil, err
	}

	var zipped bytes.Buffer
	if _, err := persistent.NewZipWriter(mem).WriteTo(&zipped); err != nil {
		return nil, err
	}
	return zipped.Bytes(), nil
}
