package functions

import (
	"context"

	"code.justin.tv/safety/datastore/interfaces"

	"code.justin.tv/safety/datastore/models"
	"code.justin.tv/safety/datastore/models/report/audit"
	"code.justin.tv/safety/datastore/models/report/status"

	"github.com/pkg/errors"
	"golang.org/x/sync/errgroup"
)

// AutoResolveReports resolves every report that matches the given
// auto-resolve rule and is currently in status.New, skipping any report that
// has an active hold.
//
// Matching reports are fetched in batches of up to 1000. Within a batch each
// non-held report is handled in its own goroutine: an audit entry with action
// audit.AutoResolve (attributed to autoResolve.CreatedBy and linked to
// autoResolve.ID) is created, then the report is updated with
// Status = status.Resolved and AssignedTo = autoResolve.CreatedBy.
//
// It returns the number of reports actually resolved; held reports are not
// counted. If any step fails it returns 0 and the first error encountered.
func AutoResolveReports(ctx context.Context, tx interfaces.Transaction, autoResolve *models.AutoResolve) (int, error) {
	const batchSize uint64 = 1000
	resolved := 0
	for {
		// Every page is read from the start: resolved reports are updated to
		// status.Resolved and therefore (assuming the transaction sees its own
		// writes) no longer match the status.New filter. Held reports are left
		// untouched, so they keep showing up in later pages.
		reports, _, err := tx.ReportPage(ctx, autoResolveToReportFilter(autoResolve), nil, batchSize, 0)
		if err != nil {
			return 0, errors.WithMessage(err, "fetching reports")
		}
		if len(reports) == 0 {
			// Nothing left to do; also avoids querying holds for an empty ID list.
			break
		}

		reportIDs := make([]int64, 0, len(reports))
		for _, report := range reports {
			reportIDs = append(reportIDs, report.ID)
		}

		holds, err := tx.ActiveReportHolds(ctx, reportIDs)
		if err != nil {
			return 0, errors.WithMessage(err, "fetching report holds")
		}

		onHold := make(map[int64]struct{}, len(holds))
		for _, hold := range holds {
			onHold[hold.ReportID] = struct{}{}
		}

		// gctx is canceled once any goroutine in this batch fails or Wait
		// returns, so it is kept separate from ctx, which is still needed for
		// the next page.
		g, gctx := errgroup.WithContext(ctx)
		batchResolved := 0
		for _, report := range reports {
			if _, ok := onHold[report.ID]; ok {
				continue
			}
			batchResolved++

			oldReport := report // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
			g.Go(func() error {
				newReport := *oldReport
				newReport.AssignedTo = autoResolve.CreatedBy
				newReport.Status = strPtr(status.Resolved)
				diff, err := models.AuditDiff(oldReport, &newReport)
				if err != nil {
					return errors.WithMessage(err, "computing diffs between reports")
				}

				action := audit.AutoResolve
				_, err = tx.CreateReportAudit(gctx, models.ReportAudit{
					Action:        &action,
					ActionBy:      autoResolve.CreatedBy,
					AutoResolveID: &autoResolve.ID,
					ReportID:      &newReport.ID,
					Diff:          &diff,
				})
				if err != nil {
					return errors.WithMessage(err, "creating report audits")
				}

				if err := tx.UpdateReport(gctx, newReport); err != nil {
					return errors.WithMessage(err, "updating report")
				}
				return nil
			})
		}

		// Wait for every update in this batch before fetching the next page.
		if err := g.Wait(); err != nil {
			return 0, err
		}
		resolved += batchResolved

		// Stop on a short page (no more matches). Also stop when a full page
		// consisted entirely of held reports: nothing changed, so the next
		// fetch from the start would presumably return the same held reports
		// and the loop would never end.
		// NOTE(review): in that case non-held reports sorted after those
		// batchSize held ones are not reached; paginating past holds needs an
		// offset, pending confirmation of ReportPage's paging arguments.
		if uint64(len(reports)) < batchSize || batchResolved == 0 {
			break
		}
	}

	return resolved, nil
}

// autoResolveUserToUserID encodes an optional auto-resolve user constraint as
// a signed user ID. It returns nil for a nil constraint, a pointer to
// u.UserID itself for models.AutoResolveUserIs, and a pointer to a fresh
// negated copy of the ID for any other operation type (presumably decoded as
// "not equal" by the report filter — confirm against callers).
// TODO replace negative id with something more sensible by SAFETY-1523
func autoResolveUserToUserID(u *models.AutoResolveUser) *int {
	if u == nil {
		return nil
	}
	if u.OperationType != models.AutoResolveUserIs {
		negated := -u.UserID
		return &negated
	}
	return &u.UserID
}

// autoResolveToReportFilter builds the report filter for an auto-resolve rule.
//
// The filter always requires status.New. Content and Reason, when set, add an
// equality constraint (the pointer itself is stored as the filter value).
// FromUserID and TargetUserID, when set, add a user constraint.
//
// TODO replace negative id with something more sensible by SAFETY-1523:
// a negative user ID is decoded as "not equal to -id", while a
// non-negative one means "equal to id".
func autoResolveToReportFilter(autoResolve *models.AutoResolve) *models.ReportFilter {
	var filter models.ReportFilter
	filter.Status = &models.Filter{Type: models.FilterEqual, Value: status.New}

	if content := autoResolve.Content; content != nil {
		filter.Content = &models.Filter{Type: models.FilterEqual, Value: content}
	}

	if reason := autoResolve.Reason; reason != nil {
		filter.Reason = &models.Filter{Type: models.FilterEqual, Value: reason}
	}

	if from := autoResolve.FromUserID; from != nil {
		switch id := *from; {
		case id < 0:
			filter.FromUserID = &models.Filter{Type: models.FilterNotEqual, Value: -id}
		default:
			filter.FromUserID = &models.Filter{Type: models.FilterEqual, Value: id}
		}
	}

	if target := autoResolve.TargetUserID; target != nil {
		switch id := *target; {
		case id < 0:
			filter.TargetUserID = &models.Filter{Type: models.FilterNotEqual, Value: -id}
		default:
			filter.TargetUserID = &models.Filter{Type: models.FilterEqual, Value: id}
		}
	}

	return &filter
}

// strPtr returns a pointer to a new copy of s, which is handy for filling
// optional *string fields such as a report's Status.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
