package timbersaw

import (
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	trpc "code.justin.tv/amzn/StarfruitTimbersawTwirp"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"code.justin.tv/video/invoker"
	"github.com/golang/protobuf/proto"
	pb "github.com/golang/protobuf/ptypes"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

// Timbersaw endpoint and HTTP timeout used when fetching channel-specific paths.
const (
	timbersawClientAddr = "https://vpce.us-west-2.prod.timbersaw.live-video.a2z.com"
	timbersawTimeout    = 1 * time.Minute
)

// S3 location of the default paths and config, read by assuming s3RoleARN.
// TODO: make multi-region aware
const (
	s3RoleARN           = "arn:aws:iam::407654553611:role/defaults-reader"
	s3DefaultPathBucket = "timbersaw-us-west-2-prod-defaults"
)

// Channel-path cache sizing and request pacing.
const (
	cacheRemoveThreshold = 2 * time.Minute  // second argument to NewCache for channelPathCache
	cacheMaxNumOfEntries = 100_000          // temporary upper bound on cached channel entries
	coalescingLimit      = 100              // max channel requests coalesced into one Timbersaw batch
	jitterPercentage     = 0.10             // +/-10% jitter applied to each path-fetch wait
	minRefreshInterval   = 15 * time.Second // lower bound on the delay between default-path refreshes
)

// defaultCache holds the per-origin default paths that fetchDefaultPaths loads from S3;
// GetPaths falls back to them when no channel-specific paths are cached.
// mu guards originToPaths and expiry: fetchDefaultPaths writes them under mu.Lock and
// GetPaths reads originToPaths under mu.RLock.
// NOTE(review): expiry is written but not read anywhere in the code visible here.
type defaultCache struct {
	originToPaths map[string][]Path
	expiry        time.Time
	mu            sync.RWMutex
}

// Client resolves the paths to use for a (channel, origin) pair. It serves channel-specific
// paths from channelPathCache, which runPathFetcher fills in batches from Timbersaw, and
// falls back to per-origin default paths loaded from S3. Create one with NewClient, load
// the initial config and defaults with Init, and start the background loops with Run.
type Client struct {
	// Dependencies and settings copied in by NewClient. sampleReporter, onError and
	// onPathFetch come straight from Config and may be nil.
	currentPop      string
	invoker         *invoker.Tasks
	timbersawClient trpc.StarfruitTimbersaw
	s3Client        *s3.S3
	sampleReporter  *telemetry.SampleReporter
	onError         func(error)
	onPathFetch     func(newPaths []Path)

	// Path caches. pendingFetch queues keys that missed channelPathCache in GetPaths;
	// runPathFetcher drains it into batched Timbersaw requests.
	defaultPathCache defaultCache
	channelPathCache *Cache
	pendingFetch     chan CacheKey

	// Runtime values. allowedRPS holds an int stored by fetchDefaultConfig and read by
	// pathFetchWait. firstDefaultPathTTL is set by Init and used by runRefreshDefaultPaths
	// as its first delay.
	allowedRPS          atomic.Value
	firstDefaultPathTTL time.Duration
}

// Config holds the options for NewClient. CurrentPop and Session must be set; every other
// field is optional.
type Config struct {
	CurrentPop      string                    // required: the PoP this client's callers live in, e.g. 'mia02' or 'lhr05'; sent to Timbersaw and used to pick the S3 default-path key
	Session         *session.Session          // required: AWS session used to assume the S3 reader role and to build the S3 client
	SampleReporter  *telemetry.SampleReporter // passed in from NydusPR to emit metrics; no metrics are sent if nil
	OnError         func(error)               // optional: called with errors as they occur; if nil, errors are meant to be ignored
	OnPathFetch     func(newPaths []Path)     // optional: called with each newly fetched path list, both channel-specific and per-origin defaults
	S3ClientTimeout time.Duration             // HTTP timeout (connect + fetch) for S3 requests; values <= 0 use the default of 5 seconds
	S3MaxRetries    int                       // max retries for S3 requests; values <= 0 use the default of 3
}

// NewClient builds a Client from conf.
//
// conf.CurrentPop and conf.Session are required; an error is returned if either is missing.
// The S3 client (used for default paths and config) uses credentials obtained by assuming
// s3RoleARN through conf.Session. It applies an HTTP timeout of conf.S3ClientTimeout
// (default 5s) and allows up to conf.S3MaxRetries retries (default 3). The Timbersaw client
// (used for channel-specific paths) talks to timbersawClientAddr with a timbersawTimeout
// HTTP timeout.
//
// The returned Client starts with empty path caches: call Init to load the default config
// and paths, then Run to start the background refresh loops.
func NewClient(conf Config) (*Client, error) {
	if conf.CurrentPop == "" {
		return nil, errors.New("must specify a pop that this client is operating out of")
	}
	// stscreds.NewCredentials and s3.New both use the session, so a nil session would panic
	// below; report it as a configuration error instead.
	if conf.Session == nil {
		return nil, errors.New("must specify an AWS session")
	}

	// S3 (default paths and config) is read through an assumed role.
	creds := stscreds.NewCredentials(conf.Session, s3RoleARN)

	// NOTE(review): this was intended as a retryer that does not wait between attempts, but
	// aws-sdk-go v1's DefaultRetryer substitutes its defaults for a zero MinRetryDelay /
	// MaxRetryDelay, so it most likely still backs off. Confirm against the SDK version in use.
	retryer := client.DefaultRetryer{
		MaxRetryDelay: 0,
		NumMaxRetries: 3,
	}
	if conf.S3MaxRetries > 0 {
		retryer.NumMaxRetries = conf.S3MaxRetries
	}

	// Bound each S3 HTTP round trip (connect + fetch).
	s3Timeout := 5 * time.Second
	if conf.S3ClientTimeout > 0 {
		s3Timeout = conf.S3ClientTimeout
	}
	httpClient := &http.Client{Timeout: s3Timeout}

	s3Client := s3.New(conf.Session, &aws.Config{Credentials: creds, HTTPClient: httpClient, Retryer: retryer})

	// Timbersaw is used for content-based (channel-specific) paths.
	timbersawClient := trpc.NewStarfruitTimbersawProtobufClient(
		timbersawClientAddr,
		&http.Client{Timeout: timbersawTimeout})

	c := &Client{
		currentPop:      conf.CurrentPop,
		invoker:         invoker.New(),
		timbersawClient: timbersawClient,
		s3Client:        s3Client,
		sampleReporter:  conf.SampleReporter,
		onError:         conf.OnError,
		onPathFetch:     conf.OnPathFetch,

		defaultPathCache: defaultCache{
			originToPaths: make(map[string][]Path),
		},
		channelPathCache: NewCache(cacheMaxNumOfEntries, cacheRemoveThreshold),
		pendingFetch:     make(chan CacheKey, 2*coalescingLimit),
	}
	// Seed allowedRPS so pathFetchWait has an int to read before fetchDefaultConfig has
	// succeeded (e.g. Run called without a successful Init). 1 matches pathFetchWait's
	// lower clamp.
	c.allowedRPS.Store(1)

	return c, nil
}

// bulkFetchContentPaths fetches paths for every key in channelKeys from Timbersaw in a
// single GetContentPathToOriginBatch call. Each response is stored in channelPathCache with
// an expiry of now + the TTL Timbersaw returned, and its paths are passed to onPathFetch
// (when set). An empty channelKeys is a no-op. channelKeys is only read, never modified.
//
// Nothing is returned or re-queued. If the batch call fails, the error goes to onError
// (when set) and the whole batch is dropped. A response whose TTL cannot be parsed is
// reported the same way and skipped. A key that stays uncached is enqueued again by
// GetPaths on its next cache miss. When a sample reporter is set, the call's duration is
// reported as TimbersawBulkFetchContentPathsDuration.
func (c *Client) bulkFetchContentPaths(ctx context.Context, channelKeys []CacheKey) {
	if len(channelKeys) == 0 { // nothing queued; skip the round trip
		return
	}

	startTime := time.Now()
	defer func() {
		if c.sampleReporter != nil {
			c.sampleReporter.ReportDurationSample("TimbersawBulkFetchContentPathsDuration", time.Since(startTime))
		}
	}()

	reqList := make([]*trpc.ContentPathToOriginRequest, 0, len(channelKeys))
	for _, key := range channelKeys {
		reqList = append(reqList, &trpc.ContentPathToOriginRequest{
			Channel:       key.ChannelARN,
			ChannelOrigin: key.Origin,
			CurrentPop:    c.currentPop,
		})
	}

	response, err := c.timbersawClient.GetContentPathToOriginBatch(ctx, &trpc.ContentPathToOriginBatchRequest{
		Requests: reqList,
	})
	if err != nil {
		if c.onError != nil {
			c.onError(err)
		}
		return
	}

	now := time.Now() // single reference time for every TTL in this batch
	// The generated getters tolerate a nil response or nil entries instead of panicking.
	for _, resp := range response.GetResponses() {
		ttl, err := pb.Duration(resp.GetTtl())
		if err != nil {
			if c.onError != nil {
				c.onError(err)
			}
			continue // skip entries whose TTL is unusable
		}

		paths := resp.GetPaths()
		pathList := make([]Path, 0, len(paths))
		for _, path := range paths {
			pathList = append(pathList, path.GetPopPath())
		}

		channelKey := CacheKey{
			ChannelARN: resp.GetChannel(),
			Origin:     resp.GetChannelOrigin(),
		}
		c.channelPathCache.Set(channelKey, pathList, now.Add(ttl))

		if c.onPathFetch != nil {
			c.onPathFetch(pathList)
		}
	}
}

// runPathFetcher sends one batched content-path request to Timbersaw (via
// bulkFetchContentPaths) per interval, with each interval taken from pathFetchWait. Each
// batch holds up to coalescingLimit keys. Newly requested channels from pendingFetch go
// first; the remaining room is filled with keys from channelPathCache.NextRefreshKeys.
// When a sample reporter is set, one of TimbersawRefreshAge, TimbersawRefreshSkipped or
// TimbersawRefreshEmpty is emitted per batch to help tune the request rate.
// It runs until ctx is cancelled and then returns ctx.Err(); run it in its own goroutine.
func (c *Client) runPathFetcher(ctx context.Context) error {
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	wait := time.NewTimer(c.pathFetchWait(rng))

	// takePending drains up to coalescingLimit keys from pendingFetch without blocking.
	takePending := func() []CacheKey {
		batch := make([]CacheKey, 0, coalescingLimit)
		for len(batch) < coalescingLimit {
			select {
			case key := <-c.pendingFetch:
				batch = append(batch, key)
			default:
				return batch
			}
		}
		return batch
	}

	for {
		select {
		case <-ctx.Done():
			wait.Stop()
			return ctx.Err()
		case <-wait.C:
		}

		// Arm the next interval before doing the work, so request time overlaps the wait.
		wait = time.NewTimer(c.pathFetchWait(rng))

		batch := takePending()
		refreshKeys, refreshAge := c.channelPathCache.NextRefreshKeys(coalescingLimit - len(batch))
		batch = append(batch, refreshKeys...)

		if c.sampleReporter != nil {
			switch {
			case len(refreshKeys) > 0:
				// Some cached keys are being refreshed. refreshAge is meant to be the max time
				// since their last refresh; use it to tune the request/second config.
				c.sampleReporter.ReportDurationSample("TimbersawRefreshAge", refreshAge)
			case len(batch) > 0:
				// Fetching new channels but refreshing none. Expected occasionally; if
				// frequent, cached paths are going stale.
				c.sampleReporter.Report("TimbersawRefreshSkipped", 1.0, telemetry.UnitCount)
			default:
				// Nothing to fetch and nothing to refresh. Expected only right after startup;
				// a steady stream of these indicates a problem.
				c.sampleReporter.Report("TimbersawRefreshEmpty", 1.0, telemetry.UnitCount)
			}
		}

		c.bulkFetchContentPaths(ctx, batch)
	}
}

// getFromS3 downloads the object stored under key path in s3DefaultPathBucket and
// unmarshals its body into resp as a binary protobuf message.
//
// Errors from GetObject, from reading the body, or from unmarshalling are returned as-is.
// An error from closing the body is reported through onError when it is set and is
// otherwise ignored; onError is optional, so it must not be called when nil.
func (c *Client) getFromS3(ctx context.Context, path string, resp proto.Message) error {
	s3Resp, err := c.s3Client.GetObjectWithContext(ctx, &s3.GetObjectInput{
		Bucket: aws.String(s3DefaultPathBucket),
		Key:    aws.String(path),
	})
	if err != nil {
		return err
	}
	defer func() {
		if cerr := s3Resp.Body.Close(); cerr != nil && c.onError != nil {
			c.onError(cerr)
		}
	}()

	body, err := ioutil.ReadAll(s3Resp.Body)
	if err != nil {
		return err
	}

	return proto.Unmarshal(body, resp)
}

// fetchDefaultPaths downloads the default paths for c.currentPop from S3 (key "path/<pop>"),
// stores them per origin in defaultPathCache, and returns the TTL sent with them.
//
// Origins present in the response replace their previous entries; origins absent from it
// keep whatever was cached before. onPathFetch (when set) is called once per origin with
// the new paths after the cache lock is released. A callback may therefore call back into
// the Client (e.g. GetPaths, which takes the read lock) without deadlocking.
// If the TTL cannot be parsed, the paths are still stored, the expiry is left unchanged,
// and (0, err) is returned. On an S3 error nothing is changed and (0, err) is returned.
// When a sample reporter is set, the duration is reported as
// TimbersawFetchDefaultPathsDuration.
func (c *Client) fetchDefaultPaths(ctx context.Context) (time.Duration, error) {
	startTime := time.Now()
	defer func() {
		if c.sampleReporter != nil {
			c.sampleReporter.ReportDurationSample("TimbersawFetchDefaultPathsDuration", time.Since(startTime))
		}
	}()

	resp := &trpc.DefaultPathToOriginsResponse{}
	if err := c.getFromS3(ctx, fmt.Sprintf("path/%s", c.currentPop), resp); err != nil {
		return 0, err
	}

	// Convert the proto paths before taking the lock to keep the critical section short.
	byOrigin := resp.GetDefaultPathsByOrigin()
	fetched := make(map[string][]Path, len(byOrigin))
	for origin, pathsResp := range byOrigin {
		protoPaths := pathsResp.GetPaths()
		pathList := make([]Path, 0, len(protoPaths))
		for _, path := range protoPaths {
			pathList = append(pathList, path.GetPopPath())
		}
		fetched[origin] = pathList
	}
	ttl, ttlErr := pb.Duration(resp.GetTtl())

	func() {
		c.defaultPathCache.mu.Lock()
		defer c.defaultPathCache.mu.Unlock()
		for origin, pathList := range fetched {
			c.defaultPathCache.originToPaths[origin] = pathList
		}
		if ttlErr == nil {
			c.defaultPathCache.expiry = time.Now().Add(ttl)
		}
	}()

	// User callbacks run outside the lock.
	if c.onPathFetch != nil {
		for _, pathList := range fetched {
			c.onPathFetch(pathList)
		}
	}

	if ttlErr != nil {
		return 0, ttlErr
	}
	return ttl, nil
}

// runRefreshDefaultPaths re-fetches the default paths from S3 on a timer until ctx is
// cancelled, then returns ctx.Err(). The first fetch waits firstDefaultPathTTL (recorded by
// Init). After each fetch, the next wait is the TTL that fetch returned, but never less
// than minRefreshInterval (so a failed fetch, which returns 0, retries after that
// interval). Fetch errors go to onError when set. Run it in its own goroutine.
func (c *Client) runRefreshDefaultPaths(ctx context.Context) error {
	refresh := time.NewTimer(c.firstDefaultPathTTL)
	defer refresh.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-refresh.C:
		}

		next, err := c.fetchDefaultPaths(ctx)
		if err != nil && c.onError != nil {
			c.onError(err)
		}
		if next < minRefreshInterval {
			next = minRefreshInterval
		}
		// The timer has fired and its channel was drained above, so Reset is safe.
		refresh.Reset(next)
	}
}

// fetchDefaultConfig downloads the "config" object from S3 and stores its RpsPerHost value
// (as an int) in allowedRPS, where pathFetchWait reads it. When a sample reporter is set,
// the duration is reported as TimbersawFetchDefaultConfigDuration.
func (c *Client) fetchDefaultConfig(ctx context.Context) error {
	began := time.Now()
	defer func() {
		if c.sampleReporter == nil {
			return
		}
		c.sampleReporter.ReportDurationSample("TimbersawFetchDefaultConfigDuration", time.Since(began))
	}()

	cfg := &trpc.DefaultConfigResponse{}
	if err := c.getFromS3(ctx, "config", cfg); err != nil {
		return err
	}
	c.allowedRPS.Store(int(cfg.RpsPerHost))
	return nil
}

// runRefreshDefaultConfig re-reads the default config from S3 once a minute until ctx is
// cancelled, then returns ctx.Err(). It does not fetch immediately on start; the initial
// load is done by Init. Fetch errors go to onError when set and do not stop the loop.
func (c *Client) runRefreshDefaultConfig(ctx context.Context) error {
	every := time.NewTicker(1 * time.Minute)
	defer every.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-every.C:
		}

		if err := c.fetchDefaultConfig(ctx); err != nil && c.onError != nil {
			c.onError(err)
		}
	}
}

// pathFetchWait returns how long runPathFetcher should wait before its next batch request.
// The mean wait is one second divided by allowedRPS (clamped to at least 1). Uniform jitter
// of +/-jitterPercentage, drawn from r, is applied to it.
//
// allowedRPS is populated by fetchDefaultConfig. If nothing has been stored yet (e.g. Run
// was called without a successful Init), the comma-ok assertion yields 0. That value is
// clamped to 1 instead of the code panicking on a nil interface.
func (c *Client) pathFetchWait(r *rand.Rand) time.Duration {
	rps, _ := c.allowedRPS.Load().(int)
	if rps < 1 {
		rps = 1
	}

	averageWait := time.Second / time.Duration(rps)
	adjustmentRange := float64(averageWait) * jitterPercentage
	// r.Float64() is in [0, 1), so the result lies in [avg-range, avg+range).
	return averageWait - time.Duration(adjustmentRange) + time.Duration(2*adjustmentRange*r.Float64())
}

// Init verifies that the client is set up correctly by loading the default config and the
// default paths, run concurrently on a fresh invoker. It returns the error reported by the
// invoker, if any. The TTL returned by fetchDefaultPaths is recorded in
// firstDefaultPathTTL, which runRefreshDefaultPaths uses as its first delay.
func (c *Client) Init(ctx context.Context) error {
	initial := invoker.New()
	initial.Add(c.fetchDefaultConfig)
	initial.Add(func(ctx context.Context) error {
		ttl, err := c.fetchDefaultPaths(ctx)
		c.firstDefaultPathTTL = ttl
		return err
	})
	return initial.Run(ctx)
}

// Run starts the background goroutines that the Timbersaw Client relies on and returns
// whatever c.invoker.Run returns. The registered tasks are channelPathCache.Run (gc for the
// content paths cache), runPathFetcher, runRefreshDefaultPaths and runRefreshDefaultConfig.
// The last three return ctx.Err() once ctx is cancelled.
//
// Init should be called first. runRefreshDefaultPaths uses the TTL Init records as its
// first delay, and pathFetchWait reads allowedRPS, which Init populates via
// fetchDefaultConfig.
// NOTE(review): each call re-adds the tasks to c.invoker, so Run is presumably meant to be
// called at most once per Client. Confirm.
func (c *Client) Run(ctx context.Context) error {
	c.invoker.Add(c.channelPathCache.Run) // gc for content paths cache
	c.invoker.Add(c.runPathFetcher)
	c.invoker.Add(c.runRefreshDefaultPaths)
	c.invoker.Add(c.runRefreshDefaultConfig)

	return c.invoker.Run(ctx)
}

// GetPaths returns the paths to use for channelARN on origin.
//
// Channel-specific paths cached in channelPathCache are preferred. On a cache miss the key
// is queued, without blocking, for the next batched Timbersaw fetch. If the queue is full,
// an error is reported through onError (when set) and the key is not queued. The origin's
// default paths are then returned instead. An error is returned only when neither source
// has an entry.
//
// NOTE(review): every miss enqueues the key again, so a hot uncached channel can occupy
// many pendingFetch slots before its first fetch completes. Consider de-duplicating.
func (c *Client) GetPaths(channelARN string, origin string) ([]Path, error) {
	key := CacheKey{ChannelARN: channelARN, Origin: origin}
	if cached, hit := c.channelPathCache.Get(key); hit {
		return cached, nil
	}

	select {
	case c.pendingFetch <- key:
	default:
		if c.onError != nil {
			c.onError(errors.New("pendingFetches buffer full"))
		}
	}

	c.defaultPathCache.mu.RLock()
	defaults, found := c.defaultPathCache.originToPaths[origin]
	c.defaultPathCache.mu.RUnlock()

	if !found {
		return nil, errors.New("no path exists for channel nor origin")
	}
	return defaults, nil
}
