package persistent

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"log"
	"net/url"
	"sort"
	"strings"

	"code.justin.tv/release/trace"
	"code.justin.tv/release/trace/analysis/tx"
	"code.justin.tv/release/trace/api"
	"code.justin.tv/release/trace/api/report_v1"
	"github.com/golang/protobuf/proto"
	"github.com/pkg/errors"
	"github.com/syndtr/goleveldb/leveldb"
	"github.com/syndtr/goleveldb/leveldb/opt"
	"github.com/syndtr/goleveldb/leveldb/util"
)

// We store a few kinds of records in the db:
//
// - Transactions
//   - /tx/...
// - Program reports
//   - /report/v1/...
//
// The keyspace /tx/{txid}/{hash} contains Transaction messages, where "txid"
// is the transaction's hex-formatted transaction ID and "hash" is the sha256
// hash of the encoded data. The messages are
// code.justin.tv.release.trace.api.Transaction messages, as defined in
// code.justin.tv/release/trace/api/transaction.proto.
//
// The keyspace /report/v1/{program_name}/{hash} contains ProgramReport
// messages, where "program_name" is the fully-qualified name of the server
// program represented in the report (escaped for safe inclusion in a URL
// path) and "hash" is the sha256 hash of the encoded data. The messages are
// code.justin.tv.release.trace.api.report.v1.ProgramReport messages, as
// defined in code.justin.tv/release/trace/api/report_v1/report.proto.

// DB stores and retrieves Transaction and ProgramReport records. It wraps a
// leveldb database handle.
type DB struct {
	// ldb is the underlying leveldb handle used for every read and write.
	ldb *leveldb.DB
}

// OpenDB opens the leveldb database stored at path on the local filesystem
// and returns a DB wrapping it. A failure to open is returned wrapped with
// the offending path.
func OpenDB(path string) (*DB, error) {
	handle, err := leveldb.OpenFile(path, new(opt.Options))
	if err != nil {
		return nil, errors.Wrapf(err, "open leveldb path=%q", path)
	}
	return &DB{ldb: handle}, nil
}

// Close closes the underlying leveldb database and returns whatever error
// that close reports.
func (db *DB) Close() error {
	return db.ldb.Close()
}

// CompactAll asks leveldb to compact the database, passing a zero-value
// util.Range (no explicit start or limit key).
func (db *DB) CompactAll() error {
	return db.ldb.CompactRange(util.Range{})
}

// WriteTransaction encodes tx and stores it under /tx/{txid}/{hash}, where
// txid is derived from the transaction's ID and hash is the hex-encoded
// sha256 of the encoded bytes. Marshal and write failures are returned
// wrapped with the transaction ID or key respectively.
func (db *DB) WriteTransaction(tx *api.Transaction) error {
	txid := idForTx(tx)

	encoded, err := proto.Marshal(tx)
	if err != nil {
		return errors.Wrapf(err, "marshal transaction txid=%q", txid)
	}

	// The content hash is part of the key, so distinct encodings for the
	// same transaction ID land under distinct keys.
	digest := sha256.Sum256(encoded)
	hexDigest := hex.EncodeToString(digest[:])
	key := fmt.Sprintf("/tx/%s/%s", txid.String(), hexDigest)

	// TODO(rhys): batch writes
	if err := db.ldb.Put([]byte(key), encoded, nil); err != nil {
		return errors.Wrapf(err, "write transaction key=%q", key)
	}
	return nil
}

// ReadTransaction returns the first stored record found under /tx/{txid}/.
//
// Only the first record in key order is decoded; any additional records for
// the same transaction ID are currently ignored. It returns an error if the
// database read fails, if no record exists for txid, or if the stored bytes
// cannot be unmarshaled.
func (db *DB) ReadTransaction(txid *trace.ID) (*api.Transaction, error) {
	key := fmt.Sprintf("/tx/%s/", txid.String())
	it := db.ldb.NewIterator(util.BytesPrefix([]byte(key)), nil)
	defer it.Release()

	// TODO: merge transactions
	if !it.Next() {
		// Next reports false both when the prefix is empty and when the
		// iterator failed; distinguish the two so that a storage error is
		// not misreported as a missing transaction.
		if err := it.Error(); err != nil {
			return nil, errors.Wrapf(err, "read transaction prefix=%q", key)
		}
		return nil, errors.New("transaction not found")
	}

	var txn api.Transaction
	if err := proto.Unmarshal(it.Value(), &txn); err != nil {
		return nil, errors.Wrap(err, "transaction could not be unmarshaled")
	}
	return &txn, nil
}

// idForTx builds a tx.TransactionID from the leading elements of
// t.TransactionId. With two or more elements the first two are used in
// order; with exactly one, it fills the second position and the first is
// zero; with none, both positions are zero.
func idForTx(t *api.Transaction) tx.TransactionID {
	words := t.TransactionId
	if len(words) >= 2 {
		return tx.TransactionID{words[0], words[1]}
	}
	if len(words) == 1 {
		return tx.TransactionID{0, words[0]}
	}
	return tx.TransactionID{0, 0}
}

// WriteProgramReport encodes rep and stores it under
// /report/v1/{program_name}/{hash}, where program_name is the query-escaped
// program name from the report and hash is the hex-encoded sha256 of the
// encoded bytes. Marshal and write failures are returned wrapped with the
// program name or key respectively.
func (db *DB) WriteProgramReport(rep *report_v1.ProgramReport) error {
	encoded, err := proto.Marshal(rep)
	if err != nil {
		return errors.Wrapf(err, "marshal report content server=%q", rep.GetProgram().GetName())
	}

	digest := sha256.Sum256(encoded)
	escapedName := url.QueryEscape(rep.GetProgram().GetName())
	hexDigest := hex.EncodeToString(digest[:])
	key := fmt.Sprintf("/report/v1/%s/%s", escapedName, hexDigest)

	if err := db.ldb.Put([]byte(key), encoded, nil); err != nil {
		return errors.Wrapf(err, "write report content key=%q", key)
	}
	return nil
}

// ReadProgramReport returns the first stored report found under
// /report/v1/{serverName}/, with serverName query-escaped the same way
// WriteProgramReport escapes it.
//
// Only the first record in key order is decoded; any additional reports for
// the same server are currently ignored. It returns an error if the database
// read fails, if no report exists for serverName, or if the stored bytes
// cannot be unmarshaled.
func (db *DB) ReadProgramReport(serverName string) (*report_v1.ProgramReport, error) {
	key := fmt.Sprintf("/report/v1/%s/", url.QueryEscape(serverName))
	it := db.ldb.NewIterator(util.BytesPrefix([]byte(key)), nil)
	defer it.Release()

	// TODO: merge reports
	if !it.Next() {
		// Next reports false both when the prefix is empty and when the
		// iterator failed; distinguish the two so that a storage error is
		// not misreported as a missing report.
		if err := it.Error(); err != nil {
			return nil, errors.Wrapf(err, "read report prefix=%q", key)
		}
		return nil, errors.New("report not found")
	}

	var rep report_v1.ProgramReport
	if err := proto.Unmarshal(it.Value(), &rep); err != nil {
		return nil, errors.Wrap(err, "report could not be unmarshaled")
	}
	return &rep, nil
}

// ListServers returns the sorted, unescaped names of every server that has
// at least one report stored under /report/v1/.
//
// Keys that do not have the shape /report/v1/{name}/{hash}, or whose name
// segment cannot be unescaped, are logged and skipped. If the underlying
// iterator reports an error, that error is returned instead of a partial
// list.
func (db *DB) ListServers() ([]string, error) {
	const prefix = "/report/v1/"

	it := db.ldb.NewIterator(util.BytesPrefix([]byte(prefix)), nil)
	defer it.Release()

	// We iterate in reverse, so we can seek to the correct end bound to skip
	// an entire server's report listings. We plan to have only one report for
	// each service name, but this code is prepared for that to change.

	var servers []string
	for ok := it.Last(); ok; ok = it.Prev() {
		key := string(it.Key())

		parts := strings.SplitN(strings.TrimPrefix(key, prefix), "/", 3)
		if len(parts) != 2 {
			log.Printf("malformed key=%s", key)
			continue
		}
		encodedName := parts[0]

		// Jump to the first key for this server so that the following Prev
		// lands on the previous server's last key.
		positioned := it.Seek([]byte(fmt.Sprintf("%s%s/", prefix, encodedName)))

		if name, err := url.QueryUnescape(encodedName); err != nil {
			log.Printf("malformed key=%s", key)
		} else {
			servers = append(servers, name)
		}

		// A failed Seek leaves the iterator without a valid position; stop
		// here and let the error check below report what went wrong.
		if !positioned {
			break
		}
	}
	if err := it.Error(); err != nil {
		return nil, errors.Wrap(err, "iterate report keys")
	}

	sort.Strings(servers)

	return servers, nil
}
