package tracing

import (
	"net/http"

	"github.com/opentracing/opentracing-go"
)

// TracingMiddlewareBuilder collects the configuration for a TracingMiddleware.
// Create one with NewTracingMiddlewareBuilder, configure it with
// WithSpanCreater, WithFilter and WithExtractor (each returns the builder so
// calls can be chained), then call Build.
//
// extractors produce one span tag each, filters list request paths that are
// not traced, and tracer is used to start spans. NewTracingMiddlewareBuilder
// sets tracer to opentracing.GlobalTracer(); prefer it over the zero value,
// which has no tracer.
type TracingMiddlewareBuilder struct {
	extractors []TagExtractor
	filters    []*PathFilter
	tracer     SpanCreater
}

// SpanCreater is the subset of the opentracing.Tracer API this middleware
// needs: starting spans and moving span contexts in and out of carriers.
// Every opentracing.Tracer satisfies it. (Inject is part of the contract but
// is not called anywhere in this file.)
type SpanCreater interface {
	Extract(format, carrier interface{}) (opentracing.SpanContext, error)
	Inject(sc opentracing.SpanContext, format, carrier interface{}) error
	StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span
}

// PathFilter excludes requests from tracing when their URL path (r.URL.Path)
// is exactly equal to path; there is no prefix or pattern matching. Create one
// with NewPathFilter and register it with TracingMiddlewareBuilder.WithFilter.
type PathFilter struct {
	path string
}

// TracingMiddleware is net/http middleware that starts an OpenTracing span for
// every request whose path is not excluded by one of its filters, and tags
// that span using its extractors. tracer is used both to extract a parent span
// context from the request headers and to start the span. Construct it with
// TracingMiddlewareBuilder and wrap handlers with Handle.
type TracingMiddleware struct {
	extractors []TagExtractor
	filters    []*PathFilter
	tracer     SpanCreater
}

// HeaderTagExtractor is a TagExtractor that tags spans from a request header:
// the span tag is named tag and its value is the first value of the request
// header named header, or "" when that header is absent.
type HeaderTagExtractor struct {
	tag, header string
}

// GetParamTagExtractor is a TagExtractor that tags spans from a URL query
// parameter: the span tag is named tag and its value is the first value of the
// query parameter named field, or "" when that parameter is absent.
type GetParamTagExtractor struct {
	tag, field string
}

// TagExtractor derives one span tag from an incoming request: getTag names the
// tag and extract computes its value for r. Both methods are unexported, so
// only types declared in this package (or types embedding one) can implement
// it.
type TagExtractor interface {
	getTag() string
	extract(r *http.Request) string
}

// extract returns the first value of the query parameter e.field in r's URL,
// or "" when the parameter is not present.
func (e *GetParamTagExtractor) extract(r *http.Request) string {
	return r.URL.Query().Get(e.field)
}

// extract returns the first value of the request header e.header, or "" when
// the header is absent. http.Header.Get canonicalizes the name, so the match
// is case-insensitive.
func (e *HeaderTagExtractor) extract(r *http.Request) string {
	return r.Header.Get(e.header)
}

// getTag returns the span tag name this extractor populates.
func (e *HeaderTagExtractor) getTag() string {
	return e.tag
}

// getTag returns the span tag name this extractor populates.
func (e *GetParamTagExtractor) getTag() string {
	return e.tag
}

// NewPathFilter returns a PathFilter that excludes requests whose URL path is
// exactly path from tracing.
func NewPathFilter(path string) *PathFilter {
	return &PathFilter{path: path}
}

// NewHeaderTagExtractor returns a TagExtractor that sets the span tag named tag
// to the value of the request header named header.
func NewHeaderTagExtractor(tag, header string) *HeaderTagExtractor {
	return &HeaderTagExtractor{tag: tag, header: header}
}

// NewGetParamTagExtractor returns a TagExtractor that sets the span tag named
// tag to the value of the URL query parameter named field.
func NewGetParamTagExtractor(tag, field string) *GetParamTagExtractor {
	return &GetParamTagExtractor{tag: tag, field: field}
}

// NewTracingMiddlewareBuilder returns a builder with no extractors or filters
// whose tracer is opentracing.GlobalTracer() as of this call. The tracer is
// captured here, not at Build time, so register the global tracer first or
// pass one explicitly with WithSpanCreater.
func NewTracingMiddlewareBuilder() *TracingMiddlewareBuilder {
	return &TracingMiddlewareBuilder{tracer: opentracing.GlobalTracer()}
}

// WithSpanCreater replaces the tracer used to extract parent span contexts and
// start request spans (by default the one captured in
// NewTracingMiddlewareBuilder). It returns the builder for chaining.
func (builder *TracingMiddlewareBuilder) WithSpanCreater(tracer SpanCreater) *TracingMiddlewareBuilder {
	builder.tracer = tracer
	return builder
}

// WithFilter adds filter to the set of excluded paths; filters accumulate
// across calls. filter must be non-nil, because the middleware dereferences it
// on every request. It returns the builder for chaining.
func (builder *TracingMiddlewareBuilder) WithFilter(filter *PathFilter) *TracingMiddlewareBuilder {
	builder.filters = append(builder.filters, filter)
	return builder
}

// WithExtractor appends extractor to the list of tag extractors. Each one
// contributes a tag to every traced request's span. If two extractors use the
// same tag name, the one added later wins. It returns the builder for
// chaining.
func (builder *TracingMiddlewareBuilder) WithExtractor(extractor TagExtractor) *TracingMiddlewareBuilder {
	builder.extractors = append(builder.extractors, extractor)
	return builder
}

// Build returns a TracingMiddleware that uses the builder's extractors,
// filters and tracer.
//
// The extractor and filter slices are copied, so the middleware never shares
// backing storage with the builder. If the builder's tracer is nil (for
// example after WithSpanCreater(nil), or when the builder is a zero value),
// opentracing.GlobalTracer() is used instead. Otherwise every traced request
// would panic on a method call through a nil interface. A non-nil interface
// that wraps a nil pointer is not detected here.
func (builder *TracingMiddlewareBuilder) Build() *TracingMiddleware {
	tracer := builder.tracer
	if tracer == nil {
		tracer = opentracing.GlobalTracer()
	}
	return &TracingMiddleware{
		// Appending onto a nil slice copies the elements into fresh storage
		// and leaves the result nil when the builder has none.
		extractors: append([]TagExtractor(nil), builder.extractors...),
		filters:    append([]*PathFilter(nil), builder.filters...),
		tracer:     tracer,
	}
}

// filter reports whether r should be traced. It returns false when r.URL.Path
// is exactly equal to the path of any configured PathFilter. It returns true
// otherwise, including when no filters are configured. Scanning stops at the
// first match.
func (m *TracingMiddleware) filter(r *http.Request) bool {
	path := r.URL.Path
	excluded := false
	for i := 0; !excluded && i < len(m.filters); i++ {
		excluded = m.filters[i].path == path
	}
	return !excluded
}

// extract runs every configured TagExtractor against r and returns the results
// keyed by tag name. Every extractor contributes an entry, even when its value
// is "". If two extractors share a tag name, the later one in the list wins.
func (m *TracingMiddleware) extract(r *http.Request) map[string]string {
	// At most one entry per extractor, so size the map up front.
	tags := make(map[string]string, len(m.extractors))
	for _, extractor := range m.extractors {
		tags[extractor.getTag()] = extractor.extract(r)
	}
	return tags
}

// Handle wraps next so that every request not excluded by a PathFilter is
// traced. Excluded requests are passed to next unchanged.
//
// For a traced request it:
//   - extracts a parent span context from the request headers
//     (opentracing.HTTPHeaders format); if extraction returns any error, the
//     span is started without a parent,
//   - starts a span named after r.URL.Path,
//   - sets one tag per configured TagExtractor,
//   - stores the span in the request context via opentracing.ContextWithSpan
//     so that next can start child spans from it, and
//   - finishes the span when the handler returns or panics.
func (m *TracingMiddleware) Handle(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !m.filter(r) {
			next.ServeHTTP(w, r)
			return
		}

		// Named spanCtx rather than "context" so the standard-library package
		// name is not shadowed in this scope.
		var opts []opentracing.StartSpanOption
		carrier := opentracing.HTTPHeadersCarrier(r.Header)
		if spanCtx, err := m.tracer.Extract(opentracing.HTTPHeaders, carrier); err == nil {
			opts = append(opts, opentracing.ChildOf(spanCtx))
		}

		span := m.tracer.StartSpan(r.URL.Path, opts...)
		// Deferred immediately so the span is finished even if a tag
		// extractor or the downstream handler panics.
		defer span.Finish()

		for key, value := range m.extract(r) {
			span.SetTag(key, value)
		}

		next.ServeHTTP(w, r.WithContext(opentracing.ContextWithSpan(r.Context(), span)))
	})
}
