package metrics

import (
	"context"
	"fmt"
	"sync"
	"time"

	"a.yandex-team.ru/library/go/core/metrics"
	"a.yandex-team.ru/library/go/core/metrics/solomon"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Option customizes a metricsInterceptor. NewGrpcMetricsInterceptor applies
// the supplied options in order before building the interceptor.
type Option func(*metricsInterceptor)

// metricsInterceptor holds the configuration and the per-endpoint metric sets
// shared by all invocations of one interceptor.
type metricsInterceptor struct {
	// r is the registry in which all metrics are created.
	r metrics.Registry

	// endpointKey derives the endpoint key from the call info; an empty
	// result selects defaultEndpoint.
	endpointKey func(info *grpc.UnaryServerInfo) string

	// durationBuckets are the buckets of the request duration histogram.
	durationBuckets metrics.DurationBuckets

	// defaultEndpoint collects metrics for calls whose key is empty.
	defaultEndpoint endpointMetrics

	// endpoints maps a non-empty endpoint key to its *endpointMetrics.
	endpoints sync.Map

	// codes lists the status codes that get a dedicated counter.
	codes []codes.Code

	// addCodeTag selects how per-code counters are named: a shared
	// "code_count" metric with a "code" tag when true, a separate
	// "code_<n>_count" metric per code when false.
	addCodeTag bool
}

// endpointMetrics is the set of metrics kept for a single endpoint. The
// metric fields are filled in once, under registerOnce.
type endpointMetrics struct {
	// registerOnce guards the one-time creation of the metrics below.
	registerOnce sync.Once

	requestCount     metrics.Counter
	requestDuration  metrics.Timer
	panicsCount      metrics.Counter
	errorsCount      metrics.Counter
	okCount          metrics.Counter
	inflightRequests metrics.Gauge

	// codesCount holds one counter per configured status code.
	codesCount map[codes.Code]metrics.Counter
}

// register creates the metrics of endpoint in the interceptor's registry.
//
// The work is guarded by endpoint.registerOnce, so it runs at most once per
// endpointMetrics value. A non-empty key is attached to every metric as the
// "endpoint" tag; with an empty key the metrics are created on the registry
// as is.
func (m *metricsInterceptor) register(endpoint *endpointMetrics, key string) {
	endpoint.registerOnce.Do(func() {
		registry := m.r
		if key != "" {
			registry = registry.WithTags(map[string]string{"endpoint": key})
		}

		// ratedCounter creates a counter and marks it as rated right away.
		ratedCounter := func(name string) metrics.Counter {
			counter := registry.Counter(name)
			solomon.Rated(counter)
			return counter
		}

		endpoint.requestCount = ratedCounter("request_count")

		endpoint.requestDuration = registry.DurationHistogram("request_duration", m.durationBuckets)
		solomon.Rated(endpoint.requestDuration)

		endpoint.panicsCount = ratedCounter("panics_count")
		endpoint.errorsCount = ratedCounter("errors_count")
		endpoint.okCount = ratedCounter("ok_count")

		// The in-flight gauge is the only metric that is not marked as rated.
		endpoint.inflightRequests = registry.Gauge("inflight_requests")

		endpoint.codesCount = make(map[codes.Code]metrics.Counter)
		for _, code := range m.codes {
			var counter metrics.Counter
			if m.addCodeTag {
				// One shared metric name, distinguished by the "code" tag.
				tagged := registry.WithTags(map[string]string{"code": code.String()})
				counter = tagged.Counter("code_count")
			} else {
				// The numeric code is embedded in the metric name itself.
				counter = registry.Counter(fmt.Sprintf("code_%d_count", code))
			}
			endpoint.codesCount[code] = counter
			solomon.Rated(counter)
		}
	})
}

// finishRequest closes the bookkeeping for one request: it records the time
// elapsed since startTime and decrements the in-flight gauge.
//
// It calls recover itself, so it has to be the deferred function for the
// panic accounting to take effect. A recovered panic is counted in
// panicsCount and then raised again with the same value.
func (e *endpointMetrics) finishRequest(startTime time.Time) {
	elapsed := time.Since(startTime)
	e.requestDuration.RecordDuration(elapsed)
	e.inflightRequests.Add(-1)

	recovered := recover()
	if recovered == nil {
		return
	}

	e.panicsCount.Inc()
	panic(recovered)
}

// NewGrpcMetricsInterceptor returns a unary server interceptor that records
// request metrics in registry.
//
// Metrics are kept per endpoint key, which is computed for every call by the
// interceptor's endpointKey function (fullMethodEndpoint unless an option
// replaces it). Calls with an empty key share a single default metric set.
// The metrics of an endpoint are created the first time it is called.
//
// For each call the interceptor increments the request counter, tracks the
// call in the in-flight gauge, records its duration, and counts it as ok or
// as an error depending on the handler's returned error. If a counter was
// configured for the resulting status code, it is incremented as well. When
// the handler panics, the duration and in-flight gauge are still updated and
// the panic counter is incremented before the panic is raised again; the
// ok, error and per-code counters are not touched in that case.
//
// opts are applied in order on top of the defaults.
func NewGrpcMetricsInterceptor(registry metrics.Registry, opts ...Option) grpc.UnaryServerInterceptor {
	m := &metricsInterceptor{
		r:               registry,
		durationBuckets: metrics.NewDurationBuckets(defaultDurationBuckets()...),
		endpointKey:     fullMethodEndpoint,
	}
	for _, opt := range opts {
		opt(m)
	}
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		key := m.endpointKey(info)

		endpoint := &m.defaultEndpoint
		if key != "" {
			// Look the key up first: after the first call for an endpoint
			// the entry already exists, and going straight to LoadOrStore
			// would allocate a throwaway endpointMetrics on every request.
			value, ok := m.endpoints.Load(key)
			if !ok {
				value, _ = m.endpoints.LoadOrStore(key, &endpointMetrics{})
			}
			endpoint = value.(*endpointMetrics)
		}
		m.register(endpoint, key)

		endpoint.requestCount.Inc()
		endpoint.inflightRequests.Add(1)
		startTime := time.Now()
		// finishRequest must be the deferred call itself so that its
		// recover sees a panic raised by the handler.
		defer endpoint.finishRequest(startTime)

		resp, err := handler(ctx, req)
		if err != nil {
			endpoint.errorsCount.Inc()
		} else {
			endpoint.okCount.Inc()
		}

		// Only codes that were configured have a counter; others are skipped.
		if counter, ok := endpoint.codesCount[status.Code(err)]; ok {
			counter.Inc()
		}

		return resp, err
	}
}
