package server

import (
	"context"
	"fmt"
	"time"

	"code.justin.tv/live/plucky/api/parser"

	"code.justin.tv/live/plucky/api/db"

	"code.justin.tv/creator-collab/log"

	"code.justin.tv/live/plucky/api/rpc"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// updateDatabase fetches the occurrences for service that fall inside the
// parts of queryRange reported as still needing a query by getQueryTimeRange,
// then stores them together with the updated cached ranges in a single
// InsertOccurrences call.
//
// Any error from computing the ranges, from a fetch, or from the insert is
// returned unchanged; the first failing fetch aborts the whole update and
// nothing is inserted.
func (s *Server) updateDatabase(ctx context.Context, service *Service, queryRange db.TimeRange, skipPartialQueryRefetch bool) error {
	id := service.ID

	pending, cached, err := s.getQueryTimeRange(ctx, id, queryRange, skipPartialQueryRefetch)
	if err != nil {
		return err
	}

	// The AWS SDK wants *string log group names.
	groupNames := make([]*string, 0, len(service.LogGroups))
	for _, name := range service.LogGroups {
		groupNames = append(groupNames, aws.String(name))
	}

	// Non-nil even when there is nothing to fetch, so the insert always
	// receives an empty (not nil) slice.
	all := []*rpc.Occurrence{}
	for _, window := range pending {
		found, fetchErr := s.fetchOccurrences(ctx, id, window, groupNames)
		if fetchErr != nil {
			return fetchErr
		}
		all = append(all, found...)
	}

	return s.database.InsertOccurrences(ctx, id, all, cached)
}

// fetchOccurrences runs a CloudWatch Logs Insights query over logGroups for
// queryRange and converts each returned row into an *rpc.Occurrence tagged
// with serviceID.
//
// Rows missing any of the @message, @timestamp, @log or @logStream fields, or
// rejected by parser.ParseCloudWatchLogsInsightsQueryResult, are logged and
// skipped. Only a failed query is returned as an error.
//
// The query returns at most maxQueryResults rows, newest first. When that cap
// is hit, the oldest matching events in queryRange are not returned, and an
// error is logged so the truncation is visible instead of silent.
func (s *Server) fetchOccurrences(ctx context.Context, serviceID string, queryRange db.TimeRange, logGroups []*string) ([]*rpc.Occurrence, error) {
	// Row cap for the Insights `limit` command (10,000 is the documented maximum).
	const maxQueryResults = 10000

	s.logger.Debug("starting fetchOccurrences", log.Fields{
		"query_range": queryRange.String(),
		"service":     serviceID,
	})
	logStartTime := time.Now()
	defer func() {
		// Include the identifying fields so concurrent fetches can be told apart.
		s.logger.Debug("fetchOccurrences", log.Fields{
			"duration":    time.Since(logStartTime),
			"query_range": queryRange.String(),
			"service":     serviceID,
		})
	}()

	// Shelfie's events are selected by the presence of an `error` field; every
	// other service is selected by the presence of `fingerprint`.
	filterField := "fingerprint"
	if serviceID == ServiceIDShelfie {
		filterField = "error"
	}
	query := fmt.Sprintf("fields @timestamp, @message, @log, @logStream \n"+
		"| filter ispresent(%s) \n"+
		"| sort @timestamp desc \n"+
		"| display @timestamp, @message, @log, @logStream \n"+
		"| limit %d \n", filterField, maxQueryResults)

	getQueryResultsOutput, err := s.cwlogsClient.runQuery(ctx, &cloudwatchlogs.StartQueryInput{
		StartTime:     aws.Int64(queryRange.Start.Unix()),
		EndTime:       aws.Int64(queryRange.End.Unix()),
		LogGroupNames: logGroups,
		QueryString:   aws.String(query),
	})
	if err != nil {
		return nil, fmt.Errorf("querying logs for service %q over %s: %w", serviceID, queryRange.String(), err)
	}

	if n := len(getQueryResultsOutput.Results); n >= maxQueryResults {
		// Results are sorted newest first, so anything past the cap is the
		// oldest part of the window. NOTE(review): callers still mark this
		// range as cached, so the dropped events will not be refetched.
		s.logger.Error(fmt.Errorf("fetchOccurrences: query for service %q over %s returned %d rows (limit %d); older occurrences in this range were not fetched",
			serviceID, queryRange.String(), n, maxQueryResults))
	}

	occurrences := make([]*rpc.Occurrence, 0, len(getQueryResultsOutput.Results))
	for _, fieldGroup := range getQueryResultsOutput.Results {
		fm := resultFieldsToMap(fieldGroup)

		jsonMessage, err := fm.getStringField("@message")
		if err != nil {
			s.logger.Error(err)
			continue
		}

		timestamp, err := fm.getStringField("@timestamp")
		if err != nil {
			s.logger.Error(err)
			continue
		}

		logGroup, err := fm.getStringField("@log")
		if err != nil {
			s.logger.Error(err)
			continue
		}

		logStream, err := fm.getStringField("@logStream")
		if err != nil {
			s.logger.Error(err)
			continue
		}

		queryResult := &parser.QueryResult{
			LogGroup:  logGroup,
			LogStream: logStream,
			Message:   jsonMessage,
			Timestamp: timestamp,
		}
		logEvent, err := parser.ParseCloudWatchLogsInsightsQueryResult(queryResult)
		if err != nil {
			s.logger.Error(err)
			continue
		}

		occurrences = append(occurrences, logEventToOccurrence(serviceID, logEvent))
	}

	s.logger.Debug("finished fetchOccurrences", log.Fields{
		"query_range":     queryRange.String(),
		"service":         serviceID,
		"num_occurrences": len(occurrences),
	})

	return occurrences, nil
}

// logEventToOccurrence builds the rpc.Occurrence for a parsed log event and
// attributes it to serviceID. Every stack frame is copied field by field into
// an rpc.StackTrace, and the event time is converted with timestamppb.New.
// logEvent must be non-nil.
func logEventToOccurrence(serviceID string, logEvent *parser.LogEvent) *rpc.Occurrence {
	frames := make([]*rpc.StackTrace, len(logEvent.StackTrace))
	for i, frame := range logEvent.StackTrace {
		frames[i] = &rpc.StackTrace{
			FullFilePath:  frame.FullFilePath,
			Line:          frame.Line,
			Method:        frame.Method,
			ShortFilePath: frame.ShortFilePath,
			BrazilPkgName: frame.BrazilPackageName,
			Relevant:      frame.Relevant,
			LineUrl:       frame.LineURL,
			PkgUrl:        frame.PackageURL,
		}
	}

	return &rpc.Occurrence{
		OccurrenceId: logEvent.LogEventID,
		Service:      serviceID,
		Message:      logEvent.Message,
		RawJson:      logEvent.Raw,
		StackTrace:   frames,
		RequestId:    logEvent.RequestID,
		Operation:    logEvent.Operation,
		Timestamp:    timestamppb.New(logEvent.Timestamp),
		Fingerprint:  logEvent.Fingerprint,
		Level:        logEvent.Level,
	}
}

// getQueryTimeRange decides which parts of queryRange must be queried for
// serviceID, and what the cached time ranges should become once the results
// of those queries are stored.
//
// The database keeps two lists of cached ranges:
//   - Real: every range that has been fetched.
//   - Safe: fetched ranges, but only up to safeDelay before the time of the
//     fetch. Presumably this lets events that arrive late be picked up by a
//     later refetch (NOTE(review): confirm intent).
//
// If no part of queryRange falls outside the Real ranges and
// skipPartialDataRefetch is true, it returns no query ranges and the cached
// ranges unchanged.
//
// Otherwise it returns:
//   - the parts of queryRange not covered by Safe ranges;
//   - new cached ranges, with queryRange merged into Real and queryRange
//     (trimmed to end at safeEnd) merged into Safe.
//
// The only error returned is one from loading the cached ranges.
func (s *Server) getQueryTimeRange(ctx context.Context, serviceID string, queryRange db.TimeRange, skipPartialDataRefetch bool) (
	[]db.TimeRange, *db.CachedTimeRanges, error) {

	// How far behind "now" data must be before its range is recorded as safe.
	const safeDelay = 2 * time.Minute

	cachedRanges, err := s.database.GetCachedTimeRanges(ctx, serviceID)
	if err != nil {
		return nil, nil, err
	}

	s.logger.Debug("cachedRanges.Real:")
	for _, qr := range cachedRanges.Real {
		s.logger.Debug(qr.String())
	}
	s.logger.Debug("cachedRanges.Safe:")
	for _, qr := range cachedRanges.Safe {
		s.logger.Debug(qr.String())
	}

	// Parts of queryRange that have never been fetched.
	queryRanges := removeCachedOverlaps(queryRange, cachedRanges.Real)

	s.logger.Debug("filtered queryRanges:", log.Fields{
		"skip_partial_data_refetch": skipPartialDataRefetch,
	})
	for _, qr := range queryRanges {
		s.logger.Debug(qr.String())
	}

	if len(queryRanges) == 0 && skipPartialDataRefetch {
		s.logger.Debug("getQueryTimeRange - no query")
		return queryRanges, cachedRanges, nil
	}

	// Refetch everything not yet covered by a safe range. This includes
	// ranges that were fetched but were too recent to be marked safe.
	queryRanges = removeCachedOverlaps(queryRange, cachedRanges.Safe)

	mergedRealRanges := mergeCachedOverlaps(queryRange, cachedRanges.Real)

	safeEnd := s.clock.NowUTC().Add(-safeDelay)
	s.logger.Debug(fmt.Sprintf("getQueryTimeRange - safeEnd - %s", safeEnd.String()))
	safeQueryRange := queryRange
	// NOTE(review): when TrimEnd returns nil, the whole untrimmed queryRange is
	// treated as safe. Confirm that nil means "already ends before safeEnd"
	// and not "nothing left after trimming".
	if trimmed := queryRange.TrimEnd(safeEnd); trimmed != nil {
		safeQueryRange = *trimmed
	}
	s.logger.Debug(fmt.Sprintf("getQueryTimeRange - safeQueryRange - %s", safeQueryRange.String()))

	mergedSafeRanges := mergeCachedOverlaps(safeQueryRange, cachedRanges.Safe)

	for _, qr := range queryRanges {
		s.logger.Debug(fmt.Sprintf("getQueryTimeRange - queries - %s", qr.String()))
	}
	for _, qr := range mergedRealRanges {
		s.logger.Debug(fmt.Sprintf("getQueryTimeRange - mergedRealRanges - %s", qr.String()))
	}
	for _, qr := range mergedSafeRanges {
		s.logger.Debug(fmt.Sprintf("getQueryTimeRange - mergedSafeRanges - %s", qr.String()))
	}
	return queryRanges, &db.CachedTimeRanges{
		Real: mergedRealRanges,
		Safe: mergedSafeRanges,
	}, nil
}

// removeCachedOverlaps starts from the single range r and, for each entry of
// cached in order, replaces every current piece with the result of
// RemoveOverlaps against that entry. It returns the pieces that remain, or
// []db.TimeRange{r} when cached is empty.
func removeCachedOverlaps(r db.TimeRange, cached []db.TimeRange) []db.TimeRange {
	ranges := []db.TimeRange{r}
	for _, c := range cached {
		next := make([]db.TimeRange, 0, len(ranges))
		for _, cur := range ranges {
			next = append(next, cur.RemoveOverlaps(c)...)
		}
		ranges = next
	}
	return ranges
}

// mergeCachedOverlaps starts from the single range r and, for each entry of
// cached in order, replaces every current range with the result of
// MergeIfOverlaps against that entry. It returns []db.TimeRange{r} when cached
// is empty.
func mergeCachedOverlaps(r db.TimeRange, cached []db.TimeRange) []db.TimeRange {
	ranges := []db.TimeRange{r}
	for _, c := range cached {
		next := make([]db.TimeRange, 0, len(ranges))
		for _, cur := range ranges {
			next = append(next, cur.MergeIfOverlaps(c)...)
		}
		ranges = next
	}
	return ranges
}
