package storage

import (
	"errors"
	"fmt"
	"github.com/golang/protobuf/proto"
	"github.com/golang/snappy"
	"sync"
)

// encodingSnappy is the marker written as the first byte of values stored in
// the custom (snappy-compressed) format. Plain protobuf output can never start
// with this byte. A non-empty protobuf message begins with a field tag,
// (field_number << 3) | wire_type, and field numbers start at 1, so that first
// byte is always at least 8. Decode relies on this to tell the two formats
// apart.
const encodingSnappy = 1
// storageCodecSnappyReadOnly names the codec whose Encode writes plain
// protobuf while Decode still understands snappy-encoded values.
const storageCodecSnappyReadOnly = "r-snappy"

// storageCodecSnappyReadWrite names the codec whose Encode writes
// snappy-encoded values and whose Decode understands both formats.
const storageCodecSnappyReadWrite = "rw-snappy"

// bufferCapacity is the capacity, in bytes, of each scratch slice newly
// allocated by a codec's buffer pool (4 KiB).
const bufferCapacity = 4 << 10

// Codec converts Storable values to and from a byte representation.
type Codec interface {
	// Encode serializes obj and returns the resulting bytes.
	Encode(obj Storable) ([]byte, error)
	// Decode parses data into the caller-supplied obj.
	Decode(data []byte, obj Storable) error
}

// NewCodec returns the Codec selected by storageCodec:
//
//   - "r-snappy": Encode writes plain protobuf; Decode accepts both plain
//     protobuf and snappy-encoded values.
//   - "rw-snappy": Encode writes snappy-encoded values; Decode accepts both
//     formats.
//
// Any other value results in an error.
func NewCodec(storageCodec string) (Codec, error) {
	var readOnly bool
	switch storageCodec {
	case storageCodecSnappyReadOnly:
		readOnly = true
	case storageCodecSnappyReadWrite:
		// readOnly keeps its zero value, so Encode compresses.
	default:
		// %q keeps empty or whitespace-only values visible in the message.
		return nil, fmt.Errorf("unknown etcd codec: %q", storageCodec)
	}
	return &snappyCodec{
		readOnly: readOnly,
		pool:     newBufferPool(bufferCapacity),
	}, nil
}

// bufferPool recycles byte slices through a sync.Pool so that the encode and
// decode paths do not allocate a fresh scratch slice on every call.
type bufferPool struct {
	items sync.Pool
}

// newBufferPool builds a bufferPool whose newly allocated slices have length 0
// and capacity size.
func newBufferPool(size int) bufferPool {
	alloc := func() interface{} {
		return make([]byte, 0, size)
	}
	return bufferPool{items: sync.Pool{New: alloc}}
}

// Get hands out a zero-length slice. A newly allocated slice has the pool's
// configured capacity. A recycled one keeps the capacity it had when it was
// passed to Put.
func (bp *bufferPool) Get() []byte {
	b := bp.items.Get().([]byte)
	return b
}

// Put gives b back to the pool for reuse, resetting its length to 0 while
// keeping its backing array.
func (bp *bufferPool) Put(b []byte) {
	b = b[:0]
	bp.items.Put(b)
}

// snappyCodec is the Codec implementation returned by NewCodec.
//
// readOnly selects the write format. When true, Encode emits plain protobuf.
// When false, Encode emits snappy-compressed protobuf prefixed with the
// encodingSnappy marker byte. Decode accepts both formats in either mode.
//
// pool supplies scratch byte slices used by Encode and Decode.
type snappyCodec struct {
	readOnly bool
	pool     bufferPool
}

// Encode serializes obj with protobuf.
//
// A read-only codec returns the plain protobuf bytes. Otherwise the protobuf
// bytes are snappy-compressed and prefixed with the encodingSnappy marker byte,
// which Decode uses to recognise the format.
//
// Errors from protobuf marshaling are returned unchanged. snappy.ErrTooLarge is
// returned if the marshaled payload is too large for a snappy block.
func (c *snappyCodec) Encode(obj Storable) ([]byte, error) {
	if c.readOnly {
		// Read-only mode never writes the snappy format; it only reads it.
		return proto.Marshal(obj)
	}

	// Marshal into a pooled scratch slice. Only that original slice goes back
	// to the pool. If marshaling outgrows it, the larger replacement is left to
	// the GC, so pooled slices keep their original capacity.
	buf := c.pool.Get()
	defer c.pool.Put(buf)
	pbuf := proto.NewBuffer(buf)
	if err := pbuf.Marshal(obj); err != nil {
		return nil, err
	}
	raw := pbuf.Bytes()

	maxLen := snappy.MaxEncodedLen(len(raw))
	if maxLen < 0 {
		// snappy cannot represent a block this large. Without this check,
		// make would produce a zero-length slice and r[0] below would panic.
		return nil, snappy.ErrTooLarge
	}

	// One extra leading byte holds the format marker. r is freshly allocated
	// (not pooled) because it is handed back to the caller.
	r := make([]byte, maxLen+1)
	r[0] = encodingSnappy
	// r[1:] is at least MaxEncodedLen bytes long, so snappy writes into it and
	// returns a prefix of it rather than a separate allocation.
	d := snappy.Encode(r[1:], raw)
	return r[:len(d)+1], nil
}

// Decode fills obj from data.
//
// data is either plain protobuf bytes or, when its first byte equals
// encodingSnappy, that marker followed by a snappy block holding the protobuf
// bytes. Both read-only and read-write codecs accept both forms. An empty data
// slice is rejected with an error. snappy and protobuf errors are returned
// unchanged.
func (c *snappyCodec) Decode(data []byte, obj Storable) error {
	switch {
	case len(data) == 0:
		return errors.New("cannot decode empty data slice")
	case data[0] != encodingSnappy:
		// No marker byte: data is already plain protobuf.
		return proto.Unmarshal(data, obj)
	}

	// Decompress into a pooled scratch slice stretched to its full capacity.
	// snappy.Decode returns a sub-slice of it when the decoded block fits.
	scratch := c.pool.Get()
	scratch = scratch[:cap(scratch)]
	defer c.pool.Put(scratch)

	plain, err := snappy.Decode(scratch, data[1:])
	if err != nil {
		return err
	}
	// NOTE(review): scratch is recycled as soon as Decode returns, so this
	// relies on proto.Unmarshal copying out of plain rather than aliasing it.
	// Confirm this for the concrete Storable types.
	return proto.Unmarshal(plain, obj)
}
