package compressor

import (
	"bytes"
	"compress/gzip"
	"io"

	"github.com/andybalholm/brotli"
	"github.com/segmentio/kafka-go/zstd"

	"a.yandex-team.ru/library/go/core/xerrors"
)

// CompressionCodecType identifies the compression algorithm used by the
// functions in this package.
type CompressionCodecType int

// Supported codecs. The numeric values are assigned by iota in declaration
// order, starting at zero.
const (
	// GZip selects the standard library compress/gzip implementation.
	GZip CompressionCodecType = iota
	// Brotli selects the github.com/andybalholm/brotli implementation.
	Brotli
	// ZStd selects the github.com/segmentio/kafka-go/zstd codec.
	ZStd
)

// Compress compresses toEncode with the given codec and returns the
// compressed bytes.
//
// No explicit level is supplied (a nil level is passed to compressImpl), so
// each codec is constructed through its level-less constructor. An error is
// returned for an unknown codec or when compression fails.
func Compress(toEncode []byte, codec CompressionCodecType) ([]byte, error) {
	return compressImpl(toEncode, codec, nil)
}

// CompressLevel compresses toEncode with the given codec at an explicit
// compression level and returns the compressed bytes.
//
// level is forwarded unchanged to the selected codec's level-aware
// constructor; its valid range is codec-specific and is not checked here.
// An error is returned for an unknown codec, when the gzip writer cannot be
// created for the given level, or when compression fails.
func CompressLevel(toEncode []byte, codec CompressionCodecType, level int) ([]byte, error) {
	return compressImpl(toEncode, codec, &level)
}

// compressImpl compresses toEncode with the selected codec and returns the
// compressed bytes.
//
// A nil level makes each codec use its level-less constructor; otherwise
// *level is passed unchanged to the codec's level-aware constructor. An error
// is returned for an unknown codec, when gzip rejects the requested level, or
// when writing or finalizing the compressed stream fails. Underlying errors
// are wrapped with %w so callers can inspect the cause.
func compressImpl(toEncode []byte, codec CompressionCodecType, level *int) ([]byte, error) {
	// The initial capacity (a tenth of the input) is only a starting guess;
	// the buffer grows on demand.
	encoded := bytes.NewBuffer(make([]byte, 0, len(toEncode)/10))

	var w io.WriteCloser
	switch codec {
	case GZip:
		if level == nil {
			w = gzip.NewWriter(encoded)
		} else {
			gz, err := gzip.NewWriterLevel(encoded, *level)
			if err != nil {
				return nil, xerrors.Errorf("failed to create gzip-writer: %w", err)
			}
			w = gz
		}
	case Brotli:
		if level == nil {
			w = brotli.NewWriter(encoded)
		} else {
			w = brotli.NewWriterLevel(encoded, *level)
		}
	case ZStd:
		if level == nil {
			w = zstd.NewCompressionCodec().NewWriter(encoded)
		} else {
			w = zstd.NewCompressionCodecWith(*level).NewWriter(encoded)
		}
	default:
		return nil, xerrors.Errorf("unknown compression codec: %v", codec)
	}

	if _, err := w.Write(toEncode); err != nil {
		// Close the writer even on failure so the encoder is not abandoned
		// half-used; the write error is the one worth reporting, so the
		// Close result is deliberately ignored.
		_ = w.Close()
		return nil, xerrors.Errorf("failed to compress blob: %w", err)
	}
	// Close flushes any buffered data and finalizes the stream; the output is
	// incomplete until it succeeds.
	if err := w.Close(); err != nil {
		return nil, xerrors.Errorf("failed to compress blob: %w", err)
	}

	return encoded.Bytes(), nil
}

// Decompress decodes encoded with the given codec and returns exactly
// expectedLen bytes of decompressed data.
//
// The output buffer of expectedLen bytes is allocated up front. An error is
// returned for an unknown codec, when the gzip reader cannot be created, or
// when the stream fails or ends before expectedLen bytes have been produced.
// Decompressed data beyond expectedLen is not read and is not reported as an
// error. Underlying errors are wrapped with %w so callers can inspect the
// cause.
func Decompress(expectedLen uint64, encoded []byte, codec CompressionCodecType) ([]byte, error) {
	decoded := make([]byte, expectedLen)
	src := bytes.NewReader(encoded)

	var r io.Reader
	switch codec {
	case GZip:
		gz, err := gzip.NewReader(src)
		if err != nil {
			return nil, xerrors.Errorf("failed to create gzip-reader: %w", err)
		}
		r = gz
	case Brotli:
		r = brotli.NewReader(src)
	case ZStd:
		r = zstd.NewCompressionCodec().NewReader(src)
	default:
		return nil, xerrors.Errorf("unknown compression codec: %v", codec)
	}

	// Some decoders (e.g. *gzip.Reader) expose Close. The static type
	// io.Reader hides it, so probe for it and close the decoder once reading
	// is done. Any read failure is already reported by ReadFull below, so the
	// Close result is deliberately ignored.
	if c, ok := r.(io.Closer); ok {
		defer func() { _ = c.Close() }()
	}

	if _, err := io.ReadFull(r, decoded); err != nil {
		return nil, xerrors.Errorf("failed to decompress blob: %w", err)
	}

	return decoded, nil
}
