package etcd3

import (
	"strings"
	"sync"

	//"a.yandex-team.ru/infra/nanny2/pkg/api"
	"context"
	"net/http"
	"strconv"

	"go.etcd.io/etcd/clientv3"
	"go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"

	"a.yandex-team.ru/infra/nanny2/pkg/log"
	"a.yandex-team.ru/infra/nanny2/pkg/storage"
	pb "a.yandex-team.ru/yp/go/proto/hq"
)

// Channel capacities used by watchChan. Buffering the channels reduces the
// number of goroutine context switches between the etcd reader, the event
// processor and the consumer.
const (
	// incomingBufSize is the capacity of incomingEventChan, the queue of raw
	// etcd events waiting to be transformed by processEvent.
	incomingBufSize = 500
	// outgoingBufSize is the capacity of resultChan, which the user reads.
	outgoingBufSize = 100
)

// watcher opens etcd watches (see Watch) and carries what is needed to turn
// the watched values into objects.
type watcher struct {
	client  *clientv3.Client      // etcd v3 client used for Get and Watch calls
	creator storage.ObjectCreator // allocates empty objects to decode into and sets their version
	codec   storage.Codec         // decodes raw etcd values into objects
}

// watchChan implements storage.WatchInterface for a single etcd watch.
type watchChan struct {
	watcher           *watcher            // parent watcher: etcd client, object creator and codec
	codec             storage.Codec       // NOTE(review): not set by createWatchChan in this file; transform uses watcher.codec instead
	key               string              // watched key or, when recursive, the prefix (ends in "/")
	initialRev        int64               // revision to watch after; 0 means read current state first (sync overwrites it)
	recursive         bool                // watch every key under the key prefix
	ctx               context.Context     // cancelled to stop every goroutine of this watch
	cancel            context.CancelFunc  // cancels ctx; used by Stop and run
	incomingEventChan chan *event         // raw events, startWatching -> processEvent
	resultChan        chan *storage.Event // events for the user; closed by run
	errChan           chan error          // errors from startWatching/transform, read once by run
}

// newWatcher returns a watcher that opens watches through client and decodes
// the watched values with codec into objects allocated by creator.
func newWatcher(client *clientv3.Client, creator storage.ObjectCreator, codec storage.Codec) *watcher {
	w := new(watcher)
	w.client, w.creator, w.codec = client, creator, codec
	return w
}

// Watch starts an asynchronous watch on key and returns it as a
// storage.WatchInterface; events are read from its ResultChan.
//
// With rev == 0 the current object(s) are emitted first and watching resumes
// from the revision of that read plus one. With a non-zero rev only events
// after rev are emitted. When recursive is false exactly key is watched; when
// it is true, key gets a trailing "/" (if it lacks one) and every key under
// that prefix is watched, while the bare key itself is not.
//
// The error result is always nil in the current implementation.
func (w *watcher) Watch(ctx context.Context, key string, rev int64, recursive bool) (storage.WatchInterface, error) {
	if recursive {
		// Ensure the key ends in "/" so the prefix matches children only.
		key = strings.TrimSuffix(key, "/") + "/"
	}
	ch := w.createWatchChan(ctx, key, rev, recursive)
	go ch.run()
	return ch, nil
}

// createWatchChan builds, but does not start, the watchChan for one Watch
// call. The watchChan gets its own cancellable context derived from ctx, so
// either cancelling ctx or calling Stop ends the watch. codec is copied from
// the watcher so the watchChan's codec field is not left nil.
func (w *watcher) createWatchChan(ctx context.Context, key string, rev int64, recursive bool) *watchChan {
	wc := &watchChan{
		watcher:           w,
		codec:             w.codec,
		key:               key,
		initialRev:        rev,
		recursive:         recursive,
		incomingEventChan: make(chan *event, incomingBufSize),
		resultChan:        make(chan *storage.Event, outgoingBufSize),
		errChan:           make(chan error, 1), // run only ever consumes one error
	}
	wc.ctx, wc.cancel = context.WithCancel(ctx)
	return wc
}

// run drives one watch from start to finish and owns closing resultChan.
//
// It starts startWatching (etcd -> incomingEventChan) and processEvent
// (incomingEventChan -> resultChan) in their own goroutines. It then blocks
// until the first of: an error on errChan, watchClosedCh being closed because
// the etcd watch channel ended, or wc.ctx being done (Stop or parent context
// cancellation). An error other than context.Canceled is converted with
// parseError and delivered on resultChan unless wc.ctx is done first.
// Finally it cancels wc.ctx to stop both goroutines and waits for processEvent
// to return, so nothing else can send on resultChan. Only then does it close
// resultChan. startWatching is not waited for; it never sends on resultChan.
func (wc *watchChan) run() {
	watchClosedCh := make(chan struct{})
	go wc.startWatching(watchClosedCh)

	var resultChanWG sync.WaitGroup
	resultChanWG.Add(1)
	go wc.processEvent(&resultChanWG)

	select {
	case err := <-wc.errChan:
		// sendError never forwards context.Canceled, so this is a defensive
		// check; break only leaves the select.
		if err == context.Canceled {
			break
		}
		errResult := parseError(err)
		if errResult != nil {
			// error result is guaranteed to be received by user before closing ResultChan.
			select {
			case wc.resultChan <- errResult:
			case <-wc.ctx.Done(): // user has given up all results
			}
		}
	case <-watchClosedCh: // the etcd watch channel was closed on the client side
	case <-wc.ctx.Done(): // Stop was called or the parent context was cancelled
	}
	// We use wc.ctx to reap all goroutines. Under whatever condition, we should stop them all.
	// It's fine to double cancel.
	wc.cancel()
	// Wait until processEvent has returned, so resultChan is no longer sent to.
	resultChanWG.Wait()
	close(wc.resultChan)
}

// Stop cancels the watch; run then stops the worker goroutines and closes the
// result channel. Calling Stop more than once is harmless.
func (wc *watchChan) Stop() { wc.cancel() }

// ResultChan returns the receive-only channel that delivers watch events. It
// is closed once the watch has stopped.
func (wc *watchChan) ResultChan() <-chan *storage.Event { return wc.resultChan }

// sync reads the current value(s) of key (every key under the prefix when
// recursive) and queues each one as an event via sendEvent. It records the
// revision of the read in initialRev so that watching can resume right after
// it. A Get error is returned unchanged.
func (wc *watchChan) sync() error {
	var getOpts []clientv3.OpOption
	if wc.recursive {
		getOpts = []clientv3.OpOption{clientv3.WithPrefix()}
	}

	resp, err := wc.watcher.client.Get(wc.ctx, wc.key, getOpts...)
	if err != nil {
		return err
	}

	wc.initialRev = resp.Header.Revision
	for i := range resp.Kvs {
		wc.sendEvent(parseKV(resp.Kvs[i]))
	}
	return nil
}

// startWatching is the producer side of incomingEventChan. If initialRev is 0,
// it first calls sync, which queues the existing object(s) and sets initialRev
// to the revision of that read. It then opens an etcd watch on key (prefix
// match when recursive) starting at initialRev+1 and queues every received
// event via sendEvent.
//
// On a sync error or a watch response error it logs the error, passes it to
// sendError and returns without closing watchClosedCh. run is then expected to
// wake up on errChan, or on wc.ctx when sendError drops a cancellation error.
// watchClosedCh is closed only when the etcd watch channel itself is closed.
func (wc *watchChan) startWatching(watchClosedCh chan struct{}) {
	if wc.initialRev == 0 {
		if err := wc.sync(); err != nil {
			log.Errorf("failed to sync with latest state: %s", err.Error())
			wc.sendError(err)
			// NOTE(review): if sendError drops a gRPC Canceled error while wc.ctx
			// is not done, run only wakes on wc.ctx — confirm that cannot hang.
			return
		}
	}
	opts := []clientv3.OpOption{clientv3.WithRev(wc.initialRev + 1)}
	if wc.recursive {
		opts = append(opts, clientv3.WithPrefix())
	}
	wch := wc.watcher.client.Watch(wc.ctx, wc.key, opts...)
	for wres := range wch {
		if wres.Err() != nil {
			err := wres.Err()
			log.Errorf("watch chan error: %s", err.Error())
			// If there is an error on server (e.g. compaction), the channel will return it before closed.
			wc.sendError(wres.Err())
			// watchClosedCh is intentionally left open: run is released via errChan.
			return
		}
		for _, e := range wres.Events {
			wc.sendEvent(parseEvent(e))
		}
	}
	// When we come to this point, it's only possible that client side ends the watch.
	// e.g. cancel the context, close the client.
	// If this watch chan is broken and context isn't cancelled, other goroutines will still hang.
	// We should notify the main thread that this goroutine has exited.
	close(watchClosedCh)
}

// processEvent is the consumer side of incomingEventChan. It converts each raw
// event with transform and forwards the non-nil results to resultChan. It
// returns as soon as wc.ctx is done and calls wg.Done on exit, so run knows
// this goroutine will no longer send on resultChan.
func (wc *watchChan) processEvent(wg *sync.WaitGroup) {
	defer wg.Done()

	stop := wc.ctx.Done()
	for {
		var ev *event
		select {
		case ev = <-wc.incomingEventChan:
		case <-stop:
			return
		}

		out := wc.transform(ev)
		if out == nil {
			// transform has already handed the failure to sendError.
			continue
		}

		// Block while the user is not draining resultChan. This back-pressures
		// the etcd reader through incomingEventChan instead of accumulating
		// events in memory. The worst case would be closing the fast watcher.
		select {
		case wc.resultChan <- out:
		case <-stop:
			return
		}
	}
}

// transform turns a raw etcd event into the storage.Event delivered to the
// user. The event is Deleted (carrying the object as it was before deletion
// plus the deletion revision), Added, or Modified. If the object(s) cannot be
// prepared, the error is passed to sendError and nil is returned so the event
// is skipped.
func (wc *watchChan) transform(e *event) (res *storage.Event) {
	w := wc.watcher
	curObj, oldObj, err := prepareObjs(wc.ctx, e, w.client, w.creator, w.codec)
	if err != nil {
		wc.sendError(err)
		return nil
	}

	if e.isDeleted {
		return &storage.Event{
			Type:     storage.Deleted,
			Object:   oldObj,
			Revision: strconv.FormatInt(e.rev, 10),
		}
	}
	// NOTE(review): unlike Deleted, Added/Modified events leave Revision empty;
	// the revision is only available as the object's version. Confirm intended.
	if e.isCreated {
		return &storage.Event{Type: storage.Added, Object: curObj}
	}
	return &storage.Event{Type: storage.Modified, Object: curObj}
}

// parseError converts an error reported by the watch into a storage.Error
// event for the user. rpctypes.ErrCompacted (the requested revision has been
// compacted away) maps to HTTP 410 Gone with reason "Expired". Every other
// error maps to HTTP 500 with reason "InternalError". A nil err yields a nil
// event, which run treats as nothing to deliver; previously it would panic on
// err.Error().
func parseError(err error) *storage.Event {
	if err == nil {
		return nil
	}

	status := &pb.Status{
		Status:  "InternalError",
		Message: err.Error(),
		Code:    http.StatusInternalServerError,
		Reason:  "InternalError",
	}
	if err == rpctypes.ErrCompacted {
		status.Code = http.StatusGone
		status.Reason = "Expired"
	}

	return &storage.Event{
		Type:   storage.Error,
		Status: status,
	}
}

// sendError hands err to run through errChan so it can be reported on
// resultChan. The error is discarded if wc.ctx is done before errChan accepts
// it.
//
// A nil err is dropped: there is nothing to report, and parseError would call
// Error on it. Cancellation errors (context.Canceled, or a gRPC status with
// code Canceled) are expected when the watch is stopped. They are dropped so
// the goroutines simply wind down without surfacing an error.
func (wc *watchChan) sendError(err error) {
	if err == nil {
		return
	}
	// TODO: etcd client should return context.Canceled instead of grpc specific error.
	if grpc.Code(err) == codes.Canceled || err == context.Canceled { //nolint:staticcheck // SA1019: grpc.Code is deprecated in favor of status.Code.
		return
	}
	select {
	case wc.errChan <- err:
	case <-wc.ctx.Done():
	}
}

// sendEvent queues e for processEvent, blocking while incomingEventChan is
// full, and gives up silently once wc.ctx is done. If the queue looks full
// (a best-effort length snapshot), an informational message is logged. That
// means events arrive faster than they are processed.
func (wc *watchChan) sendEvent(e *event) {
	if len(wc.incomingEventChan) == incomingBufSize {
		log.Infof("Fast watcher, slow processing. Number of buffered events: %d. "+
			"Probably caused by slow decoding, user not receiving fast, or other processing logic",
			incomingBufSize)
	}
	select {
	case wc.incomingEventChan <- e:
	case <-wc.ctx.Done():
	}
}

// prevValueNotFoundError is returned by prepareObjs when etcd has no value for
// a deleted key at the revision just before its deletion.
type prevValueNotFoundError struct {
	key string
	rev int64
}

func (e *prevValueNotFoundError) Error() string {
	return "previous value of deleted key " + strconv.Quote(e.key) +
		" not found at revision " + strconv.FormatInt(e.rev, 10)
}

// prepareObjs decodes the object(s) carried by a watch event.
//
// For a non-delete event, e.value is decoded into curObj and oldObj is nil.
// A delete event carries no value, so the key is read back from etcd at
// revision e.rev-1, just before the deletion, and decoded into oldObj; curObj
// is nil. Either way the decoded object's version is set to e.rev. A Deleted
// event therefore carries the *old* object stamped with the deletion revision;
// users need that larger revision to tell the ordering.
//
// Get and decode errors are returned as is. If the read-back finds no key, a
// *prevValueNotFoundError is returned instead of indexing an empty result.
func prepareObjs(ctx context.Context, e *event, client *clientv3.Client, creator storage.ObjectCreator, codec storage.Codec) (curObj storage.Storable, oldObj storage.Storable, err error) {
	if !e.isDeleted {
		curObj, err = decodeObj(creator, codec, e.value, e.rev)
		if err != nil {
			return nil, nil, err
		}
		return curObj, nil, nil
	}

	prevRev := e.rev - 1
	getResp, err := client.Get(ctx, e.key, clientv3.WithRev(prevRev))
	if err != nil {
		return nil, nil, err
	}
	if len(getResp.Kvs) == 0 {
		// Not expected for a real deletion, but indexing Kvs[0] would panic.
		return nil, nil, &prevValueNotFoundError{key: e.key, rev: prevRev}
	}
	oldObj, err = decodeObj(creator, codec, getResp.Kvs[0].Value, e.rev)
	if err != nil {
		return nil, nil, err
	}
	return nil, oldObj, nil
}

// decodeObj decodes data with codec into a fresh object from creator. It then
// stamps the object with rev as its version. The version is formatted as a
// signed decimal, matching how transform formats Event.Revision; the previous
// uint64 conversion would wrap a negative revision into a huge number.
func decodeObj(creator storage.ObjectCreator, codec storage.Codec, data []byte, rev int64) (storage.Storable, error) {
	obj := creator.New()
	if err := codec.Decode(data, obj); err != nil {
		return nil, err
	}
	// Ensure the resource version is set on the object we load from etcd.
	creator.SetVersion(obj, strconv.FormatInt(rev, 10))
	return obj, nil
}
