package cachedresolver

import (
	"context"
	"errors"

	"a.yandex-team.ru/infra/yp_service_discovery/golang/resolver"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/core/xerrors"
)

// Compile-time assertion that CachedResolver satisfies resolver.Resolver.
var _ resolver.Resolver = (*CachedResolver)(nil)

// CachedResolver is a resolver.Resolver that serves responses from a
// resolver.Cacher when a usable entry exists and otherwise delegates to the
// wrapped resolver, storing successful responses back into the cache.
type CachedResolver struct {
	resolver resolver.Resolver // queried when the cache has no fresh entry
	cache    resolver.Cacher   // holds responses keyed by "<kind>@<name>@<cluster>"
	logger   log.Structured    // debug/warn output; New defaults it to a no-op logger

	// returnStale, when true, makes the Resolve* methods return a stale cached
	// response instead of the error from the underlying resolver.
	// NOTE(review): presumably enabled through a ResolverOpt defined elsewhere.
	returnStale bool
}

// New returns new instance of CachedResolver.
// It takes advantage from responses caching to reduce network requests
// and reduce errors chance if stale cache usage enabled.
//
// Both r and c are required; a nil value for either yields an error.
// Options are applied after validation. If an option leaves the logger nil,
// a no-op logger is used instead, since the Resolve* methods log unconditionally.
func New(r resolver.Resolver, c resolver.Cacher, opts ...ResolverOpt) (*CachedResolver, error) {
	// Validate required dependencies before doing any other work.
	if r == nil {
		return nil, errors.New("resolver must not be nil")
	}
	if c == nil {
		return nil, errors.New("cacher must not be nil")
	}

	cr := &CachedResolver{
		resolver: r,
		cache:    c,
		logger:   new(nop.Logger),
	}

	for _, opt := range opts {
		opt(cr)
	}

	// Calling a method on a nil log.Structured would panic in every Resolve* call.
	if cr.logger == nil {
		cr.logger = new(nop.Logger)
	}

	return cr, nil
}

// Close releases all resources by closing underlying resolver and cache.
//
// The cache is closed even when closing the resolver fails, so neither
// resource leaks. If exactly one Close fails, its error is returned as is;
// if both fail, the resolver error is wrapped (reachable via errors.Is/As)
// and the cache error is included in the message.
func (r CachedResolver) Close() error {
	resolverErr := r.resolver.Close()
	cacheErr := r.cache.Close()

	switch {
	case resolverErr != nil && cacheErr != nil:
		return xerrors.Errorf("close resolver: %w; close cache: %v", resolverErr, cacheErr)
	case resolverErr != nil:
		return resolverErr
	default:
		return cacheErr
	}
}

// ResolveEndpoints sends request to Service Discovery server to retrieve endpoints information.
//
// A fresh cached response for (cluster, endpointSet) is returned without
// querying the underlying resolver. Otherwise the underlying resolver is
// queried: on success its response is stored in the cache (best-effort, a
// failed write is only logged); on failure a stale cached response is returned
// instead of the error when stale cache usage is enabled.
// An error is returned if the cache holds a value of an unexpected type for the key.
func (r CachedResolver) ResolveEndpoints(ctx context.Context, cluster, endpointSet string) (*resolver.ResolveEndpointsResponse, error) {
	// search for response in cache
	key := "endpoint-set@" + endpointSet + "@" + cluster
	value, stale := r.cache.Get(key)
	cached, ok := value.(*resolver.ResolveEndpointsResponse)
	if value != nil && !ok {
		return nil, xerrors.Errorf("invalid cached value type for endpoint-set %q: %T", endpointSet, value)
	}

	// return fresh cache right away
	if cached != nil && !stale {
		r.logger.Debug("serving response from cache",
			log.String("key", key),
			log.Any("response", cached),
		)
		return cached, nil
	}

	resp, err := r.resolver.ResolveEndpoints(ctx, cluster, endpointSet)
	if err != nil {
		// return stale cache if possible in case of error
		if cached != nil && r.returnStale {
			r.logger.Warn("using stale cache due to resolve error",
				log.String("key", key),
				log.Error(err),
			)
			return cached, nil
		}
		return nil, err
	}

	r.logger.Debug("saving response to cache",
		log.String("key", key),
		log.Any("response", resp),
	)

	// Caching is best-effort: a failed write must not fail a successful
	// resolve, but it should not go unnoticed either.
	if setErr := r.cache.Set(key, resp); setErr != nil {
		r.logger.Warn("failed to save response to cache",
			log.String("key", key),
			log.Error(setErr),
		)
	}
	return resp, nil
}

// ResolvePods retrieves pods information for podSet in cluster, consulting the cache first.
//
// A fresh cached response is returned without querying the underlying
// resolver. Otherwise the underlying resolver is queried: on success its
// response is stored in the cache (best-effort, a failed write is only logged);
// on failure a stale cached response is returned instead of the error when
// stale cache usage is enabled.
// An error is returned if the cache holds a value of an unexpected type for the key.
func (r CachedResolver) ResolvePods(ctx context.Context, cluster, podSet string) (*resolver.ResolvePodsResponse, error) {
	// search for response in cache
	key := "pod-set@" + podSet + "@" + cluster
	value, stale := r.cache.Get(key)
	cached, ok := value.(*resolver.ResolvePodsResponse)
	if value != nil && !ok {
		return nil, xerrors.Errorf("invalid cached value type for pod-set %q: %T", podSet, value)
	}

	// return fresh cache right away
	if cached != nil && !stale {
		r.logger.Debug("serving response from cache",
			log.String("key", key),
			log.Any("response", cached),
		)
		return cached, nil
	}

	resp, err := r.resolver.ResolvePods(ctx, cluster, podSet)
	if err != nil {
		// return stale cache if possible in case of error
		if cached != nil && r.returnStale {
			r.logger.Warn("using stale cache due to resolve error",
				log.String("key", key),
				log.Error(err),
			)
			return cached, nil
		}
		return nil, err
	}

	r.logger.Debug("saving response to cache",
		log.String("key", key),
		log.Any("response", resp),
	)

	// Caching is best-effort: a failed write must not fail a successful
	// resolve, but it should not go unnoticed either.
	if setErr := r.cache.Set(key, resp); setErr != nil {
		r.logger.Warn("failed to save response to cache",
			log.String("key", key),
			log.Error(setErr),
		)
	}
	return resp, nil
}

// ResolveNode retrieves information about node in cluster, consulting the cache first.
//
// A fresh cached response is returned without querying the underlying
// resolver. Otherwise the underlying resolver is queried: on success its
// response is stored in the cache (best-effort, a failed write is only logged);
// on failure a stale cached response is returned instead of the error when
// stale cache usage is enabled.
// An error is returned if the cache holds a value of an unexpected type for the key.
func (r CachedResolver) ResolveNode(ctx context.Context, cluster, node string) (*resolver.ResolveNodeResponse, error) {
	// search for response in cache
	key := "node@" + node + "@" + cluster
	value, stale := r.cache.Get(key)
	cached, ok := value.(*resolver.ResolveNodeResponse)
	if value != nil && !ok {
		return nil, xerrors.Errorf("invalid cached value type for node %q: %T", node, value)
	}

	// return fresh cache right away
	if cached != nil && !stale {
		r.logger.Debug("serving response from cache",
			log.String("key", key),
			log.Any("response", cached),
		)
		return cached, nil
	}

	resp, err := r.resolver.ResolveNode(ctx, cluster, node)
	if err != nil {
		// return stale cache if possible in case of error
		if cached != nil && r.returnStale {
			r.logger.Warn("using stale cache due to resolve error",
				log.String("key", key),
				log.Error(err),
			)
			return cached, nil
		}
		return nil, err
	}

	r.logger.Debug("saving response to cache",
		log.String("key", key),
		log.Any("response", resp),
	)

	// Caching is best-effort: a failed write must not fail a successful
	// resolve, but it should not go unnoticed either.
	if setErr := r.cache.Set(key, resp); setErr != nil {
		r.logger.Warn("failed to save response to cache",
			log.String("key", key),
			log.Error(setErr),
		)
	}
	return resp, nil
}
