package ranker

import (
	"sync"
	"time"

	"code.justin.tv/feeds/clients/feeddataflow"
	"code.justin.tv/feeds/feeds-common/entity"
	"code.justin.tv/feeds/service-common"
	"golang.org/x/net/context"
	"golang.org/x/sync/errgroup"
)

// SerialTraitLoader loads traits as the requests come in, via lots of goroutines
//
// Each field is a dependency used while loading the traits of one story batch:
//   - ActorRelationshipTraitLoader is queried for relationship traits between
//     feed owners and story actors (ForMultipleFromActors / ForMultipleToActors).
//   - ActorTraitLoader is queried once per unique story actor (ForActor).
//   - EntityTraitLoader is queried once per story entity (ForEntity).
//   - FeedTraitLoader is queried once per feed ID (ForFeed).
//   - Stats receives the load_traits timing and the load_traits_err counter
//     emitted by LoadTraits.
type SerialTraitLoader struct {
	ActorRelationshipTraitLoader ActorRelationshipTraitLoader
	ActorTraitLoader             ActorTraitLoader
	EntityTraitLoader            EntityTraitLoader
	FeedTraitLoader              FeedTraitLoader
	Stats                        *service_common.StatSender
}

// setupFeedLoading schedules feed trait and actor relationship trait loading
// on eg, choosing between the two batching strategies by comparing sizes:
// when there are more story actors than feeds the work is batched per feed,
// otherwise it is batched per story actor. Presumably this keeps the number of
// batched relationship lookups as small as possible.
//
// Nothing is awaited here; the caller is expected to call eg.Wait.
func (t *SerialTraitLoader) setupFeedLoading(egCtx context.Context, eg *errgroup.Group, tp *StoryBatchTraits, storyActors []entity.Entity, feedIDs []string, metadata *feeddataflow.Metadata) {
	// Guard clause: with no more actors than feeds, batch per story actor.
	if len(storyActors) <= len(feedIDs) {
		t.setupFeedLoadingByActorID(egCtx, eg, tp, storyActors, feedIDs, metadata)
		return
	}
	t.setupFeedLoadingByFeedID(egCtx, eg, tp, storyActors, feedIDs, metadata)
}

// setupFeedLoadingByActorID schedules, on origineg, a single task that loads
// feed traits and then actor relationship traits using one batched lookup per
// story actor.
//
// The task runs in two phases, each with its own errgroup derived from
// egCtxOriginal:
//  1. Load the traits of every feed in feedIDs concurrently and record them on
//     tp. The feed owners have to be known before any relationship lookup.
//  2. For every story actor, look up its relationship traits against all feed
//     owners in a single ForMultipleFromActors call and record them on tp.
//
// Feeds whose owner has an empty ID are left out of phase 2, matching
// setupFeedLoadingByFeedID, which skips the relationship lookup for such
// feeds. If no feed has an owner, phase 2 is skipped entirely.
//
// The first error from either phase is returned to origineg; nothing is
// awaited here, the caller is expected to call origineg.Wait.
func (t *SerialTraitLoader) setupFeedLoadingByActorID(egCtxOriginal context.Context, origineg *errgroup.Group, tp *StoryBatchTraits, storyActors []entity.Entity, feedIDs []string, metadata *feeddataflow.Metadata) {
	origineg.Go(func() error {
		// Phase 1: need feed traits first to load by actor.
		feedEg, feedCtx := errgroup.WithContext(egCtxOriginal)
		var mu sync.Mutex // guards traitsPerFeed
		traitsPerFeed := make(map[string]*FeedTraits, len(feedIDs))
		for _, feedID := range feedIDs {
			feedID := feedID
			feedEg.Go(func() error {
				ftraits, err := t.FeedTraitLoader.ForFeed(feedCtx, feedID)
				if err != nil {
					return err
				}
				tp.setFeedTraits(feedID, ftraits)
				mu.Lock()
				defer mu.Unlock()
				traitsPerFeed[feedID] = ftraits
				return nil
			})
		}
		if err := feedEg.Wait(); err != nil {
			return err
		}

		// Collect the owners to query, in feedIDs order. Ownerless feeds are
		// skipped so that no relationship traits are requested or recorded
		// for an empty owner.
		feedOwners := make([]entity.Entity, 0, len(feedIDs))
		for _, feedID := range feedIDs {
			owner := traitsPerFeed[feedID].Owner
			if owner.ID() == "" {
				continue
			}
			feedOwners = append(feedOwners, owner)
		}
		if len(feedOwners) == 0 {
			return nil
		}

		// Phase 2: one batched relationship lookup per story actor. A fresh
		// errgroup is required because feedCtx is cancelled once feedEg.Wait
		// returns.
		actorEg, actorCtx := errgroup.WithContext(egCtxOriginal)
		for _, storyActor := range storyActors {
			storyActor := storyActor
			actorEg.Go(func() error {
				allTraits, err := t.ActorRelationshipTraitLoader.ForMultipleFromActors(actorCtx, feedOwners, storyActor, metadata)
				if err != nil {
					return err
				}
				// Relies on allTraits being index-aligned with feedOwners.
				for idx, feedOwner := range feedOwners {
					tp.addActorRelationshipTraits(feedOwner, storyActor, &allTraits[idx])
				}
				return nil
			})
		}
		return actorEg.Wait()
	})
}

// setupFeedLoadingByFeedID schedules one task per feed ID on eg. Each task
// loads the feed's traits and records them on tp; when the feed owner has a
// non-empty ID it then loads the relationship traits between that owner and
// every story actor with a single ForMultipleToActors call and records those
// on tp as well.
//
// Nothing is awaited here; the caller is expected to call eg.Wait.
func (t *SerialTraitLoader) setupFeedLoadingByFeedID(egCtx context.Context, eg *errgroup.Group, tp *StoryBatchTraits, storyActors []entity.Entity, feedIDs []string, metadata *feeddataflow.Metadata) {
	for i := range feedIDs {
		// Per-iteration copy so each goroutine sees its own feed ID.
		id := feedIDs[i]
		eg.Go(func() error {
			feedTraits, loadErr := t.FeedTraitLoader.ForFeed(egCtx, id)
			if loadErr != nil {
				return loadErr
			}
			tp.setFeedTraits(id, feedTraits)

			// Ownerless feeds have no relationship traits to load.
			if feedTraits.Owner.ID() == "" {
				return nil
			}

			relTraits, relErr := t.ActorRelationshipTraitLoader.ForMultipleToActors(egCtx, feedTraits.Owner, storyActors, metadata)
			if relErr != nil {
				return relErr
			}
			// Relies on relTraits being index-aligned with storyActors.
			for j := range storyActors {
				tp.addActorRelationshipTraits(feedTraits.Owner, storyActors[j], &relTraits[j])
			}
			return nil
		})
	}
}

// LoadTraits loads all story batch traits for a single story batch.
//
// Actor traits (one lookup per unique story actor), entity traits (one lookup
// per story) and feed plus actor relationship traits are all scheduled on one
// errgroup derived from ctx and awaited together. On the first failure the
// error is returned with a nil result.
//
// A load_traits timing is always emitted on return, and load_traits_err is
// incremented when an error is returned (the trailing .1 argument is
// presumably the sample rate).
func (t *SerialTraitLoader) LoadTraits(ctx context.Context, s *StoryBatch) (_ *StoryBatchTraits, retErr error) {
	begin := time.Now()
	defer func() {
		// retErr is the named result, so it reflects whatever was returned.
		if retErr != nil {
			t.Stats.IncC("load_traits_err", 1, .1)
		}
		t.Stats.TimingDurationC("load_traits", time.Since(begin), .1)
	}()

	traits := &StoryBatchTraits{}
	group, groupCtx := errgroup.WithContext(ctx)

	// One actor trait lookup per unique story actor.
	actors := s.uniqueActors()
	for i := range actors {
		who := actors[i]
		group.Go(func() error {
			actorTraits, err := t.ActorTraitLoader.ForActor(groupCtx, who)
			if err != nil {
				return err
			}
			traits.setActor(who, actorTraits)
			return nil
		})
	}

	// One entity trait lookup per story.
	for i := range s.Stories {
		item := s.Stories[i]
		group.Go(func() error {
			entityTraits, err := t.EntityTraitLoader.ForEntity(groupCtx, item.Entity, s.Metadata)
			if err != nil {
				return err
			}
			traits.setEntity(item.Entity, entityTraits)
			return nil
		})
	}

	// Feed traits and actor relationship traits share the same group.
	t.setupFeedLoading(groupCtx, group, traits, actors, s.FeedIDs, s.Metadata)

	err := group.Wait()
	if err != nil {
		return nil, err
	}
	return traits, nil
}
