package eds

import (
	"context"
	"fmt"
	"log"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/status"

	envoyConfigCore "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
	envoyConfigEndpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
	envoyServiceDiscovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	envoyServiceEndpoint "github.com/envoyproxy/go-control-plane/envoy/service/endpoint/v3"

	"a.yandex-team.ru/solomon/libs/go/cache"
	"a.yandex-team.ru/solomon/libs/go/workerpool"
)

// ==========================================================================================

// EdsClient fetches endpoint assignments from Envoy EDS (Endpoint Discovery
// Service) servers over gRPC. One connection per server address is dialed
// lazily on first use and reused by later calls.
//
// Get may be called from many goroutines at once (EdsCache resolves requests
// on a pool of workers), so the connection map is guarded by mu.
type EdsClient struct {
	timeout time.Duration

	mu    sync.Mutex                  // guards conns
	conns map[string]*grpc.ClientConn // dial target -> connection
}

// NewEdsClient returns a client that uses timeout both as the deadline of
// every FetchEndpoints call and as the keepalive ping timeout. The error
// result is always nil.
func NewEdsClient(timeout time.Duration) (*EdsClient, error) {
	return &EdsClient{
		timeout: timeout,
		conns:   make(map[string]*grpc.ClientConn),
	}, nil
}

// connFor returns the cached connection to endpoint, dialing and caching a
// new one on first use.
func (e *EdsClient) connFor(endpoint string) (*grpc.ClientConn, error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if conn, ok := e.conns[endpoint]; ok {
		return conn, nil
	}
	// Without grpc.WithBlock, Dial does not wait for the connection to be
	// established, so holding the lock across it is cheap.
	conn, err := grpc.Dial(endpoint,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithKeepaliveParams(
			keepalive.ClientParameters{Timeout: e.timeout},
		),
	)
	if err != nil {
		return nil, err
	}
	if e.conns == nil {
		// Tolerate a zero-value EdsClient that was not built by NewEdsClient.
		e.conns = make(map[string]*grpc.ClientConn)
	}
	e.conns[endpoint] = conn
	return conn, nil
}

// Get asks the EDS server at endpoint for the resources named by pattern and
// returns a map from hostname (trailing dot removed) to socket address.
//
// A gRPC NotFound status is reported as an error with the message
// "not found". The whole call fails if a resource cannot be unpacked as a
// ClusterLoadAssignment, if an LbEndpoint carries no inline Endpoint, or if
// an endpoint has no socket address.
func (e *EdsClient) Get(pattern string, endpoint string) (map[string]string, error) {
	conn, err := e.connFor(endpoint)
	if err != nil {
		return nil, err
	}
	client := envoyServiceEndpoint.NewEndpointDiscoveryServiceClient(conn)

	ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
	defer cancel()

	resp, err := client.FetchEndpoints(ctx, &envoyServiceDiscovery.DiscoveryRequest{
		Node:          &envoyConfigCore.Node{UserAgentName: "solomon-discovery"},
		ResourceNames: []string{pattern},
	})
	if err != nil {
		if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound {
			return nil, fmt.Errorf("not found")
		}
		return nil, err
	}

	result := map[string]string{}
	for _, res := range resp.Resources {
		var cluster envoyConfigEndpoint.ClusterLoadAssignment
		if err := res.UnmarshalTo(&cluster); err != nil {
			return nil, fmt.Errorf("cannot unpack resource: %w", err)
		}
		for _, locality := range cluster.Endpoints {
			for _, lbEndpoint := range locality.LbEndpoints {
				ep := lbEndpoint.GetEndpoint()
				if ep == nil {
					return nil, fmt.Errorf("no endpoints")
				}
				addr := ep.GetAddress().GetSocketAddress()
				if addr == nil {
					return nil, fmt.Errorf("no address for %s", ep.Hostname)
				}
				host := strings.TrimSuffix(ep.Hostname, ".")
				result[host] = addr.Address
			}
		}
	}
	return result, nil
}

// Shutdown closes every cached connection and forgets it, so a later Get
// dials afresh instead of reusing a closed connection. Close errors are
// deliberately ignored during teardown.
func (e *EdsClient) Shutdown() {
	e.mu.Lock()
	defer e.mu.Unlock()

	for endpoint, conn := range e.conns {
		_ = conn.Close()
		delete(e.conns, endpoint)
	}
}

// ==========================================================================================

// EdsRequest identifies one EDS lookup: which resources to ask for and which
// EDS server to ask. It is compared by value: EdsCache uses it as the cache
// key and as the key of the map returned by EdsCache.Get.
type EdsRequest struct {
	Pattern  string // sent as the single entry of DiscoveryRequest.ResourceNames
	Endpoint string // gRPC dial target of the EDS server
}

// EdsData holds the outcome of resolving one EdsRequest.
type EdsData struct {
	Hosts map[string]string // hostname (trailing dot trimmed) -> socket address
	Error error             // error returned by the cache lookup, if any
}

// workerJob is the unit of work handed to the worker pool: one request plus
// the EdsData slot that poolWorker fills in for the caller of EdsCache.Get.
type workerJob struct {
	Request       *EdsRequest
	Data          *EdsData   // written by poolWorker
	MinLastUpdate *time.Time // when non-nil, poolWorker uses cache.GetForceFresh with this bound
}

// ==========================================================================================

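// EdsCache caches EDS lookups. It can serve stale items and prefetch entries
// before they expire.
//
// Its public method set is:
//
// type EdsCache interface {
//     Get(reqs []*EdsRequest, minLastUpdate *time.Time) map[EdsRequest]*EdsData
//     Purge()
//     Destroy()
//     Dump(onlyFresh bool) ([]byte, error)
//     Restore([]byte) error
// }
//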

// EdsCache resolves EdsRequests through a cache.Cache whose fetch callback
// queries EDS servers, fanning batches of requests out over a worker pool.
type EdsCache struct {
	LogPrefix    string // prepended to every log line
	VerboseLevel int    // a log call at level lvl is emitted only when VerboseLevel >= lvl
	workerpool   *workerpool.WorkerPool
	cache        *cache.Cache
	client       *EdsClient
}

// NewEdsCache wires up an EdsCache. It creates:
//   - an EdsClient whose requests time out after requestTimeout;
//   - a cache.Cache named "eds" that fetches misses through cacheEds, using
//     the given good/bad/prefetch/cleanup timings, stale-serving flag and size;
//   - a worker pool named "eds" with the given number of workers running
//     poolWorker, verbose when verboseLevel > 1.
//
// Logging uses the prefix "[eds] " at the given verbosity.
func NewEdsCache(goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval, requestTimeout time.Duration,
	cacheSize, workers int,
	serveStale bool,
	verboseLevel int) (*EdsCache, error) {

	client, err := NewEdsClient(requestTimeout)
	if err != nil {
		return nil, err
	}
	c := &EdsCache{
		LogPrefix:    "[eds] ",
		VerboseLevel: verboseLevel,
		client:       client,
	}

	// Cache-miss handler: keys are EdsRequest values.
	fetch := func(req interface{}) (interface{}, error) {
		return c.cacheEds(req.(EdsRequest))
	}
	c.cache = cache.NewCache(
		"eds", fetch,
		goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval,
		serveStale, verboseLevel, cacheSize,
	)

	// Pool jobs are *workerJob values built by Get.
	work := func(job interface{}) {
		c.poolWorker(job.(*workerJob))
	}
	c.workerpool = workerpool.NewWorkerPool("eds", workers, work, verboseLevel > 1)

	return c, nil
}

// log writes one line through the standard logger when VerboseLevel >= lvl.
// The line is LogPrefix followed by format expanded with v; when ts is
// non-nil, ", <time elapsed since *ts>" is appended.
//
// Only format is treated as a printf template. LogPrefix is an exported,
// caller-settable field, so it (and the elapsed-time suffix) is appended
// verbatim instead of being spliced into the format string, where a stray
// '%' would be misread as a verb.
func (e *EdsCache) log(lvl int, ts *time.Time, format string, v ...interface{}) {
	if e.VerboseLevel < lvl {
		return
	}
	suffix := ""
	if ts != nil {
		suffix = ", " + time.Since(*ts).String()
	}
	log.Print(e.LogPrefix + fmt.Sprintf(format, v...) + suffix)
}

// cacheEds is the cache's fetch callback. It queries the EDS server named in
// req and returns its hostname -> address map. On success it logs, at
// verbosity 2+, how many records arrived and how long the lookup took.
// Errors from the client are returned unchanged and are not logged here.
func (e *EdsCache) cacheEds(req EdsRequest) (map[string]string, error) {
	started := time.Now()
	hosts, err := e.client.Get(req.Pattern, req.Endpoint)
	if err == nil {
		e.log(2, &started, "got %d host records for %v", len(hosts), req)
		return hosts, nil
	}
	return nil, err
}

// poolWorker resolves one job on a worker-pool goroutine and stores the
// outcome in job.Data. When job.MinLastUpdate is set, the cache is asked for
// an entry at least that fresh (GetForceFresh); otherwise a plain Get is used.
//
// The cached value is type-asserted with the comma-ok form. An untyped nil
// (presumably what the cache returns alongside some errors) would otherwise
// panic and take down the worker goroutine. A value of any other unexpected
// type is reported through job.Data.Error instead.
func (e *EdsCache) poolWorker(job *workerJob) {
	var resp interface{}
	var cacheError error

	if job.MinLastUpdate != nil {
		resp, cacheError = e.cache.GetForceFresh(*job.Request, job.MinLastUpdate)
	} else {
		resp, cacheError = e.cache.Get(*job.Request)
	}

	hosts, ok := resp.(map[string]string)
	if !ok && resp != nil && cacheError == nil {
		cacheError = fmt.Errorf("unexpected cache value type %T for %v", resp, *job.Request)
	}
	job.Data.Hosts = hosts
	job.Data.Error = cacheError
}

// Get resolves reqs on the worker pool and returns one *EdsData per distinct
// request, keyed by request value. Requests that compare equal share one
// job, so duplicates in reqs do not trigger redundant lookups whose results
// would only be overwritten in the returned map. When minLastUpdate is
// non-nil it is attached to every job, and poolWorker then asks the cache
// for an entry at least that fresh. reqs must not contain nil pointers.
//
// NOTE(review): results are returned right after workerpool.Do, which
// assumes Do blocks until every job has finished. Confirm in workerpool.
func (e *EdsCache) Get(reqs []*EdsRequest, minLastUpdate *time.Time) map[EdsRequest]*EdsData {
	reqTime := time.Now()
	jobs := make([]interface{}, 0, len(reqs))
	result := make(map[EdsRequest]*EdsData, len(reqs))

	for _, req := range reqs {
		if _, dup := result[*req]; dup {
			continue
		}
		d := &EdsData{}
		result[*req] = d
		jobs = append(jobs, &workerJob{
			Request:       req,
			Data:          d,
			MinLastUpdate: minLastUpdate,
		})
	}
	e.workerpool.Do(jobs)
	e.log(2, &reqTime, "unrolled %d records", len(reqs))
	return result
}

// Purge logs "purging" at verbosity 1+ and then drops every entry from the
// underlying cache.
func (e *EdsCache) Purge() {
	e.log(1, nil, "purging")

	e.cache.Purge()
}

// Destroy tears the cache down for good. After logging "destroying" at
// verbosity 1+, it releases resources in dependency order:
//  1. the worker pool, because its workers read from the cache;
//  2. the cache, because its fetch callback uses the EDS client;
//  3. the client's gRPC connections.
func (e *EdsCache) Destroy() {
	e.log(1, nil, "destroying")

	// NOTE(review): Stop(true) presumably waits for in-flight jobs. Confirm in workerpool.
	e.workerpool.Stop(true)
	e.cache.Destroy()
	e.client.Shutdown()
}

// Dump serializes the cache through cache.Dump, passing onlyFresh through.
// At verbosity 1+ it first logs how many records the cache currently holds.
func (e *EdsCache) Dump(onlyFresh bool) ([]byte, error) {
	records := e.cache.Len()
	e.log(1, nil, "dumping %d records (only fresh = %v)", records, onlyFresh)

	return e.cache.Dump(onlyFresh)
}

// Restore loads a blob produced by Dump back into the cache. EdsRequest{} and
// map[string]string{} are passed to cache.Restore, presumably as key/value
// type templates for decoding (confirm in the cache package). Restored
// entries are not given a bumped end-of-life. Progress is logged at
// verbosity 1+.
func (e *EdsCache) Restore(inData []byte) error {
	e.log(1, nil, "restoring")

	const bumpEOL = false
	if err := e.cache.Restore(inData, bumpEOL, EdsRequest{}, map[string]string{}); err != nil {
		return err
	}

	e.log(1, nil, "restored %d records", e.cache.Len())
	return nil
}
