package etcd3

import (
	"path"
	"strings"

	"a.yandex-team.ru/infra/nanny2/pkg/storage"

	"context"
	"fmt"
	"strconv"
	"time"

	"go.etcd.io/etcd/clientv3"
	"go.etcd.io/etcd/mvcc/mvccpb"
)

// store is the etcd v3 backed implementation of storage.Interface; NewStore
// returns it behind that interface. Keys handed to its methods are resolved
// against pathPrefix before they reach etcd.
type store struct {
	// etcd v3 client used for KV and maintenance calls; also handed to the watcher.
	client *clientv3.Client
	// Allocates objects (New) and stamps their version string (SetVersion).
	creator storage.ObjectCreator
	// Root under which this store's keys live; see keyWithPrefix.
	pathPrefix string
	// Watch helper, built in newStore from the same client, creator and codec.
	watcher *watcher
	// Encodes objects into stored values and decodes them back.
	codec storage.Codec
}

// formatVersion renders an etcd revision as the decimal version string that is
// stored on objects. The value is formatted as a signed int64 so the result
// always round-trips through parseVersion, which parses with strconv.ParseInt;
// converting through uint64 would turn a negative value into a huge unsigned
// number that ParseInt rejects.
func formatVersion(rev int64) string {
	return strconv.FormatInt(rev, 10)
}

// parseVersion converts a decimal version string (as written by formatVersion)
// back into an etcd revision. It panics if version is not a valid base-10
// int64, so it must only be used on strings known to be well formed.
func parseVersion(version string) int64 {
	if rev, err := strconv.ParseInt(version, 10, 64); err == nil {
		return rev
	}
	panic(fmt.Sprintf("Failed to parse %s", version))
}

// NewStore returns an etcd v3 backed storage.Interface that keeps its keys
// under prefix, allocates objects through creator and serializes them with codec.
func NewStore(c *clientv3.Client, creator storage.ObjectCreator, prefix string, codec storage.Codec) storage.Interface {
	var iface storage.Interface = newStore(c, creator, prefix, codec)
	return iface
}

// newStore is the concrete constructor behind NewStore. The store's watcher is
// built from the same client, creator and codec as the store itself.
func newStore(c *clientv3.Client, creator storage.ObjectCreator, prefix string, codec storage.Codec) *store {
	s := &store{
		client:     c,
		creator:    creator,
		codec:      codec,
		pathPrefix: prefix,
	}
	s.watcher = newWatcher(c, creator, codec)
	return s
}

// Get implements storage.Interface.Get.
//
// It reads key (resolved against the store's path prefix) and decodes the
// stored value into out, setting out's version to the entry's ModRevision.
// When nothing is stored under the key, the error built by
// storage.NewKeyNotFoundError is returned.
func (s *store) Get(ctx context.Context, key string, out storage.Storable) error {
	fullKey := keyWithPrefix(s.pathPrefix, key)
	resp, err := s.client.KV.Get(ctx, fullKey)
	switch {
	case err != nil:
		return err
	case len(resp.Kvs) == 0:
		return storage.NewKeyNotFoundError(fullKey, 0)
	}
	entry := resp.Kvs[0]
	return decode(s.creator, s.codec, entry.Value, out, entry.ModRevision)
}

// Create implements storage.Interface.Create.
//
// obj is encoded and written under key only if the key does not exist yet;
// the existence check and the write happen in a single etcd transaction.
// On success obj's version is set to the revision of the write. If the key
// already exists, the stored value is decoded into out (when out is non-nil)
// and the error built by storage.NewKeyExistsError is returned.
func (s *store) Create(ctx context.Context, key string, obj storage.Storable, out storage.Storable) error {
	payload, err := s.codec.Encode(obj)
	if err != nil {
		return err
	}
	fullKey := keyWithPrefix(s.pathPrefix, key)

	resp, err := s.client.KV.Txn(ctx).
		If(notFound(fullKey)).
		Then(clientv3.OpPut(fullKey, string(payload))).
		Else(clientv3.OpGet(fullKey)).
		Commit()
	if err != nil {
		return err
	}
	if resp.Succeeded {
		// Stamp the caller's object with the revision it was stored at.
		s.creator.SetVersion(obj, formatVersion(resp.Header.Revision))
		return nil
	}
	// The compare failed, so the key existed when the transaction ran and the
	// Else-branch read returned it atomically.
	existing := resp.Responses[0].GetResponseRange().Kvs[0]
	if out != nil {
		if err := decode(s.creator, s.codec, existing.Value, out, existing.ModRevision); err != nil {
			return err
		}
	}
	return storage.NewKeyExistsError(fullKey, resp.Header.Revision)
}

// UpdateIfMatch writes obj under key only if the key's current ModRevision
// equals version (a decimal revision as produced by formatVersion).
//
// On success obj's version is set to the revision of the write, mirroring
// Create and Update. If the revisions differ, the current value is decoded
// into out (when out is non-nil) and the error built by
// storage.NewResourceVersionConflictsError is returned. If the key no longer
// exists, the error built by storage.NewKeyNotFoundError is returned.
// A malformed version yields an error rather than a panic.
func (s *store) UpdateIfMatch(ctx context.Context, key string, obj storage.Storable, version string, out storage.Storable) error {
	// version comes from the caller, so it is parsed here with an error return
	// instead of via parseVersion, which panics on bad input.
	rev, err := strconv.ParseInt(version, 10, 64)
	if err != nil {
		return fmt.Errorf("parsing version %q: %w", version, err)
	}
	data, err := s.codec.Encode(obj)
	if err != nil {
		return err
	}
	key = keyWithPrefix(s.pathPrefix, key)
	// Compare-and-swap: write only if the stored revision still matches,
	// otherwise read the current value within the same transaction.
	txnResp, err := s.client.KV.Txn(ctx).If(
		clientv3.Compare(clientv3.ModRevision(key), "=", rev),
	).Then(
		clientv3.OpPut(key, string(data)),
	).Else(
		clientv3.OpGet(key),
	).Commit()
	if err != nil {
		return err
	}
	if txnResp.Succeeded {
		// Keep obj's version in sync with what is now stored, as Create/Update do.
		s.creator.SetVersion(obj, formatVersion(txnResp.Header.Revision))
		return nil
	}
	getResp := txnResp.Responses[0].GetResponseRange()
	if len(getResp.Kvs) == 0 {
		return storage.NewKeyNotFoundError(key, rev)
	}
	// Hand the current value back so the caller can retry against it.
	if out != nil {
		kv := getResp.Kvs[0]
		if err := decode(s.creator, s.codec, kv.Value, out, kv.ModRevision); err != nil {
			return err
		}
	}
	return storage.NewResourceVersionConflictsError(key, rev)
}

// GuaranteedUpdate repeatedly applies tryUpdate to the current value of key
// and writes the result back with a compare-and-swap on ModRevision. Whenever
// another writer changed the key in between, it retries with the value read in
// the failed transaction.
//
// It returns the error built by storage.NewKeyNotFoundError if the key does
// not exist (or disappears between attempts). Errors from tryUpdate, decoding,
// encoding or etcd are returned unchanged.
func (s *store) GuaranteedUpdate(ctx context.Context, key string, tryUpdate storage.UpdateFunc) error {
	key = keyWithPrefix(s.pathPrefix, key)
	// A serializable read is served by the local member without a quorum
	// round-trip. Stale data is harmless because the CAS below detects it, but
	// a lagging member may not have the key at all. So an apparent miss is
	// confirmed with a linearizable read before being reported.
	getResp, err := s.client.KV.Get(ctx, key, clientv3.WithSerializable())
	if err != nil {
		return err
	}
	if len(getResp.Kvs) == 0 {
		getResp, err = s.client.KV.Get(ctx, key)
		if err != nil {
			return err
		}
		if len(getResp.Kvs) == 0 {
			return storage.NewKeyNotFoundError(key, 0)
		}
	}
	kv := getResp.Kvs[0]
	for {
		m := s.creator.New()
		if err := decode(s.creator, s.codec, kv.Value, m, kv.ModRevision); err != nil {
			return err
		}
		if err := tryUpdate(m); err != nil {
			return err
		}
		data, err := s.codec.Encode(m)
		if err != nil {
			return err
		}
		txnResp, err := s.client.KV.Txn(ctx).If(
			clientv3.Compare(clientv3.ModRevision(key), "=", kv.ModRevision),
		).Then(
			clientv3.OpPut(key, string(data)),
		).Else(
			clientv3.OpGet(key),
		).Commit()
		if err != nil {
			return err
		}
		if txnResp.Succeeded {
			return nil
		}
		// Lost the race: the Else branch read the current value atomically,
		// so retry against it.
		current := txnResp.Responses[0].GetResponseRange()
		if len(current.Kvs) == 0 {
			return storage.NewKeyNotFoundError(key, kv.ModRevision)
		}
		kv = current.Kvs[0]
	}
}

// Update unconditionally writes obj under key with a plain put (no revision
// check, last writer wins). It then sets obj's version to the revision of
// that write.
func (s *store) Update(ctx context.Context, key string, obj storage.Storable) error {
	payload, err := s.codec.Encode(obj)
	if err != nil {
		return err
	}
	resp, err := s.client.KV.Put(ctx, keyWithPrefix(s.pathPrefix, key), string(payload))
	if err != nil {
		return err
	}
	s.creator.SetVersion(obj, formatVersion(resp.Header.Revision))
	return nil
}

// Delete implements storage.Interface.Delete. The key is resolved against the
// store's path prefix and removed without any precondition. See
// unconditionalDelete for how out is populated.
func (s *store) Delete(ctx context.Context, key string, out storage.Storable) error {
	fullKey := keyWithPrefix(s.pathPrefix, key)
	return s.unconditionalDelete(ctx, fullKey, out)
}

// unconditionalDelete removes key (which must already carry the path prefix).
// When out is non-nil, it decodes the removed value into out. The read and the
// delete run in one comparison-free transaction, so the decoded value is
// exactly the one that was deleted. If nothing was stored under key, the error
// built by storage.NewKeyNotFoundError is returned.
func (s *store) unconditionalDelete(ctx context.Context, key string, out storage.Storable) error {
	resp, err := s.client.KV.Txn(ctx).
		If().
		Then(clientv3.OpGet(key), clientv3.OpDelete(key)).
		Commit()
	if err != nil {
		return err
	}
	deleted := resp.Responses[0].GetResponseRange().Kvs
	if len(deleted) == 0 {
		return storage.NewKeyNotFoundError(key, 0)
	}
	if out == nil {
		return nil
	}
	return decode(s.creator, s.codec, deleted[0].Value, out, deleted[0].ModRevision)
}

// List implements storage.Interface.List.
//
// It returns every object stored under the store's path prefix, treating the
// prefix as a directory. The lookup prefix always ends in "/", so for prefix
// "/a" the key "/a/b" is included while "/ab" is not. Each returned object's
// version is overwritten with the cluster revision at which the range was read,
// not its own ModRevision.
func (s *store) List(ctx context.Context) ([]storage.Storable, error) {
	dir := s.pathPrefix
	if !strings.HasSuffix(dir, "/") {
		dir += "/"
	}
	resp, err := s.client.KV.Get(ctx, dir, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}
	items, err := decodeList(resp.Kvs, s.creator, s.codec)
	if err != nil {
		return nil, err
	}
	listVersion := formatVersion(resp.Header.Revision)
	for i := range items {
		s.creator.SetVersion(items[i], listVersion)
	}
	return items, nil
}

// Watch implements storage.Interface.Watch. It watches key non-recursively
// (recursive=false), starting from resourceVersion as parsed by watch.
func (s *store) Watch(ctx context.Context, key string, resourceVersion string) (storage.WatchInterface, error) {
	const recursive = false
	return s.watch(ctx, key, resourceVersion, recursive)
}

// WatchList implements storage.Interface.WatchList. It watches the empty key,
// which watch resolves against the store's path prefix, with recursive=true.
// The watch starts from resourceVersion.
func (s *store) WatchList(ctx context.Context, resourceVersion string) (storage.WatchInterface, error) {
	const recursive = true
	return s.watch(ctx, "", resourceVersion, recursive)
}

// Status checks etcd health by requesting the status of the first configured
// client endpoint. The request is bounded by a 5-second timeout derived from ctx.
// It returns an error if the client has no endpoints or the status call fails.
func (s *store) Status(ctx context.Context) error {
	endpoints := s.client.Endpoints()
	// Guard the index below: an empty endpoint list would otherwise panic.
	if len(endpoints) == 0 {
		return fmt.Errorf("etcd3: client has no endpoints configured")
	}
	statusCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	_, err := s.client.Maintenance.Status(statusCtx, endpoints[0])
	return err
}

// watch parses the resource version rv and starts a watch on key (resolved
// against the path prefix) through the store's watcher. An empty rv is passed
// to the watcher as revision 0. Otherwise rv must be a non-negative base-10
// integer that fits in int64; anything else is rejected with an error.
func (s *store) watch(ctx context.Context, key string, rv string, recursive bool) (storage.WatchInterface, error) {
	var rev int64
	if rv != "" {
		// ParseInt rather than ParseUint: values above MaxInt64 must be
		// rejected, not wrapped to a negative revision by an int64 conversion.
		parsed, err := strconv.ParseInt(rv, 10, 64)
		if err != nil {
			return nil, err
		}
		if parsed < 0 {
			return nil, fmt.Errorf("invalid resource version %q: must be non-negative", rv)
		}
		rev = parsed
	}
	return s.watcher.Watch(ctx, keyWithPrefix(s.pathPrefix, key), rev, recursive)
}

// keyWithPrefix resolves key against prefix.
//
// A key is considered already resolved when it equals prefix or starts with
// prefix followed by "/"; such keys are returned unchanged. Any other key is
// joined onto prefix with path.Join, which also cleans the result. The
// directory-boundary check mirrors List: with prefix "/a", the key "/ab" is not
// treated as already prefixed and becomes "/a/ab". An empty prefix leaves every
// key untouched.
func keyWithPrefix(prefix, key string) string {
	if prefix == "" || key == prefix {
		return key
	}
	dir := prefix
	if !strings.HasSuffix(dir, "/") {
		dir += "/"
	}
	if strings.HasPrefix(key, dir) {
		return key
	}
	return path.Join(prefix, key)
}

// decode unmarshals value into obj using codec and, only if that succeeds,
// sets obj's version to rev via creator. On a decode error obj's version is
// left untouched and the error is returned.
func decode(creator storage.ObjectCreator, codec storage.Codec, value []byte, obj storage.Storable, rev int64) error {
	if err := codec.Decode(value, obj); err != nil {
		return err
	}
	creator.SetVersion(obj, formatVersion(rev))
	return nil
}

// decodeList decodes every kv value into a fresh object obtained from creator.
// Each object's version is set to that entry's ModRevision. The result
// preserves the order of kvs; on the first decode error it returns (nil, err).
func decodeList(kvs []*mvccpb.KeyValue, creator storage.ObjectCreator, codec storage.Codec) ([]storage.Storable, error) {
	decoded := make([]storage.Storable, 0, len(kvs))
	for _, entry := range kvs {
		obj := creator.New()
		if err := decode(creator, codec, entry.Value, obj, entry.ModRevision); err != nil {
			return nil, err
		}
		decoded = append(decoded, obj)
	}
	return decoded, nil
}

// notFound builds a transaction comparison that holds only while key has a
// ModRevision of zero, i.e. while nothing is stored under it.
func notFound(key string) clientv3.Cmp {
	const absentRevision = 0
	return clientv3.Compare(clientv3.ModRevision(key), "=", absentRevision)
}
