package metrics

import (
	"context"
	"math"

	structpb "github.com/golang/protobuf/ptypes/struct"
	"github.com/montanaflynn/stats"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"

	"code.justin.tv/dta/rockpaperscissors/internal/junit"
	pb "code.justin.tv/dta/rockpaperscissors/proto"
)

func init() {
	info := &pb.MetricInfo{
		MetricId:        "tests_passing_percent",
		MetricName:      "% Tests Passing",
		Description:     "Median percent of tests passing in Jenkins jobs.",
		ValidForProject: true,
		KeyMetric:       true,
	}
	Registry().Register(info, NewTestsPassingPercentCalculator)
}

// TestsPassingPercentCalculator to calculate median tests passing in Jenkins jobs.
// It is registered in init under the metric ID "tests_passing_percent".
type TestsPassingPercentCalculator struct {
	// ProjectMetadataServer is passed to getJenkinsJobsForProject to find the
	// Jenkins jobs that belong to the requested project.
	ProjectMetadataServer pb.ProjectMetadataServiceServer
	// EventServer is queried for "JenkinsTestResults" events; each event body
	// is parsed as JUnit XML to produce one data point.
	EventServer           pb.EventServiceServer
}

// NewTestsPassingPercentCalculator is the factory for TestsPassingPercentCalculator
// values. It wires in the project-metadata and event servers the calculator
// queries, and it never returns an error.
func NewTestsPassingPercentCalculator(projectMetadataServer pb.ProjectMetadataServiceServer, eventServer pb.EventServiceServer) (Calculator, error) {
	calc := new(TestsPassingPercentCalculator)
	calc.ProjectMetadataServer = projectMetadataServer
	calc.EventServer = eventServer
	return calc, nil
}

// getDataPointForEvent parses the JUnit XML body of one test-results event
// and returns the fraction of tests that passed (1.0 means every test passed).
//
// If the top-level element reports a positive test count, its Fails and
// Errors totals are subtracted. Otherwise the per-suite Tests and Failures
// counts are summed instead.
//
// When that sum contains zero tests, the division yields NaN or an infinity.
// getDataPointsForJob discards such values.
//
// NOTE(review): the per-suite fallback does not subtract errored tests,
// unlike the top-level path. Confirm whether that is intentional.
func (p *TestsPassingPercentCalculator) getDataPointForEvent(event *pb.Event) (float64, error) {
	suites, parseErr := junit.Unmarshal(event.GetBody())
	if parseErr != nil {
		return 0.0, errf(codes.Internal,
			"Error parsing test results JUnit XML: %v", parseErr)
	}

	if suites.Tests > 0 {
		notPassing := float64(suites.Fails + suites.Errors)
		return 1.0 - notPassing/float64(suites.Tests), nil
	}

	// Top-level totals are missing; aggregate the individual suites.
	var total, failed float64
	for _, suite := range suites.TestSuites {
		total += float64(suite.Tests)
		failed += float64(suite.Failures)
	}
	return 1.0 - failed/total, nil
}

// getDataPointsForJob fetches the "JenkinsTestResults" events recorded for
// jenkinsJob within timerange and converts each one into a passing-fraction
// data point.
//
// A NotFound error from the event server means "no data" and yields (nil, nil).
// Data points that are NaN or infinite are dropped; these arise when an event
// reports no tests. A single unparsable event aborts the whole job.
func (p *TestsPassingPercentCalculator) getDataPointsForJob(ctx context.Context, jenkinsJob string, timerange *pb.TimeRange) ([]float64, error) {
	query := &pb.QueryEventsRequest{
		Timerange: timerange,
		Type:      "JenkinsTestResults",
		Filters: []*pb.QueryEventsRequest_AttributeFilter{
			{Key: "jenkins_job_name", Value: jenkinsJob},
		},
	}

	queryResp, err := p.EventServer.QueryEvents(ctx, query)
	if err != nil && grpc.Code(err) == codes.NotFound {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	var points []float64
	for _, ev := range queryResp.GetEvents() {
		point, pointErr := p.getDataPointForEvent(ev)
		if pointErr != nil {
			// TODO: it might be better to just log the error and skip it
			return nil, pointErr
		}
		if math.IsNaN(point) || math.IsInf(point, 0) {
			continue
		}
		points = append(points, point)
	}
	return points, nil
}

// calculateEntry sets entry.Value to the median passing fraction across every
// data point that jenkinsJobs produced within entry's time range.
//
// If no data points are found, entry is left unchanged and nil is returned.
// The first per-job error is returned as-is. req is currently unused.
func (p *TestsPassingPercentCalculator) calculateEntry(ctx context.Context, req *pb.GetMetricRequest, entry *pb.GetMetricResponse_TimeSeriesEntry, jenkinsJobs []string) error {
	timerange := entry.GetTimerange()

	var all []float64
	for _, job := range jenkinsJobs {
		points, err := p.getDataPointsForJob(ctx, job, timerange)
		if err != nil {
			return err
		}
		// Appending a nil/empty slice is a no-op, so no nil check is needed.
		all = append(all, points...)
	}
	if len(all) == 0 {
		return nil
	}

	median, err := stats.Median(all)
	if err != nil {
		return errf(codes.Internal, "Error calculating median value: %v", err)
	}

	entry.Value = &structpb.Value{
		Kind: &structpb.Value_NumberValue{NumberValue: median},
	}
	return nil
}

// Calculate the time series and fill out the response.
//
// The request's time range is split into buckets via makeTimeSeries, and
// resp.TimeSeries is set to them. Each bucket's value is then computed
// concurrently, one goroutine per bucket, from the project's Jenkins jobs.
// Buckets without data points are left unchanged.
//
// On the first bucket failure the shared context is cancelled, so that
// context-aware callees can stop early. Calculate still waits for every worker
// to finish before returning: workers write into the entries of
// resp.TimeSeries. Returning while they are still running would let them mutate
// resp concurrently with the caller, which is a data race. The first error
// received is returned.
func (p *TestsPassingPercentCalculator) Calculate(ctx context.Context, req *pb.GetMetricRequest, resp *pb.GetMetricResponse) error {
	timeSeries, err := makeTimeSeries(req.Timerange, req.BucketSize, req.IanaTimeZone)
	if err != nil {
		return errf(codes.InvalidArgument,
			"Found inappropriate time range: %v", err)
	}
	resp.TimeSeries = timeSeries
	resp.TimeSeriesUnits = "%"

	jenkinsJobs, err := getJenkinsJobsForProject(ctx, p.ProjectMetadataServer, req.ProjectId)
	if err != nil {
		return err
	}
	if len(jenkinsJobs) == 0 {
		return nil
	}

	// Cancelled on the first failure (and always on return) so outstanding
	// work is told to stop.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffered to len(timeSeries) so no worker ever blocks on send.
	results := make(chan error, len(timeSeries))
	for _, bucket := range timeSeries {
		go func(entry *pb.GetMetricResponse_TimeSeriesEntry) {
			results <- p.calculateEntry(ctx, req, entry, jenkinsJobs)
		}(bucket)
	}

	// Drain every result so that no goroutine touches resp after we return.
	var firstErr error
	for range timeSeries {
		if err := <-results; err != nil && firstErr == nil {
			firstErr = err
			cancel()
		}
	}
	return firstErr
}
