# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import logging
import time

import jaeger_client

logger = logging.getLogger(__name__)

def _env_flag(name, default=False):
    """Read a boolean switch from the environment.

    Returns ``default`` when the variable is unset. An empty value or one of
    ``0``/``false``/``no``/``off`` (case-insensitive, surrounding whitespace
    ignored) is False; any other value is True.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ('', '0', 'false', 'no', 'off')


# When enabled, Span.finish() may report spans that were not sampled.
# Environment values are strings, so parse explicitly: a plain truthiness test
# would treat "false" or "0" as enabled.
SAMPLING_LONG_SPANS = _env_flag('TRACING_SAMPLING_LONG_SPANS', False)
# Duration threshold in seconds, compared with (end_time - start_time) in
# Span.finish(); converted to float so the comparison is numeric.
SAMPLING_TIME_THRESHOLD = float(os.environ.get('TRACING_SAMPLING_TIME_THRESHOLD', 0.15))
# Passed as max_traces_per_second to the module's RateLimitingSampler.
SAMPLING_LONG_SPANS_RATE_LIMIT = float(os.environ.get('TRACING_SAMPLING_LONG_SPANS_RATE_LIMIT', 1.0))

# Module-wide limiter consulted by Span.finish() before reporting a short,
# unsampled span; its budget comes from SAMPLING_LONG_SPANS_RATE_LIMIT.
rate_limit_sampler = jaeger_client.RateLimitingSampler(
    max_traces_per_second=float(SAMPLING_LONG_SPANS_RATE_LIMIT),
)


class Span(jaeger_client.Span):
    """jaeger_client.Span that records tags and may be reported when unsampled.

    Spans created by the base tracer are converted in place via ``cast``.
    """

    @classmethod
    def cast(cls, span):
        """Re-class an existing jaeger span object in place and return it."""
        span.__class__ = cls
        return span

    def set_tag(self, key, value):
        """Set a tag even when the span is not sampled.

        The context flags are temporarily forced to SAMPLED_FLAG around the
        base ``set_tag`` call and then restored. Returns ``self``.
        """
        sampled_flag = jaeger_client.span.SAMPLED_FLAG
        original_flags = self.context.flags
        self.context.flags = sampled_flag
        try:
            super(Span, self).set_tag(key, value)
        finally:
            # Always restore, even if the base call raised. The exception is a
            # base call that changed the flags itself: keep its decision
            # rather than silently undoing it.
            # NOTE(review): some jaeger_client versions do this for the
            # sampling.priority tag - verify against the installed version.
            if self.context.flags == sampled_flag:
                self.context.flags = original_flags
        return self

    def finish(self, finish_time=None):
        """Finish the span and report it, like jaeger_client.Span.finish().

        Unlike the base class, an unsampled span is still reported when
        SAMPLING_LONG_SPANS is enabled and either:

        * it lasted at least SAMPLING_TIME_THRESHOLD seconds, or
        * ``rate_limit_sampler`` grants it a credit.

        :param finish_time: optional end timestamp in seconds; ``time.time()``
            is used when it is falsy.
        """
        # Compute once so the threshold check and the recorded end_time agree.
        end_time = finish_time or time.time()

        if not self.is_sampled():
            if not SAMPLING_LONG_SPANS:
                return
            if (end_time - self.start_time) < SAMPLING_TIME_THRESHOLD:
                # NOTE(review): the limiter is configured by
                # TRACING_SAMPLING_LONG_SPANS_RATE_LIMIT, yet it gates the
                # *short* spans here while long ones bypass it - confirm intent.
                decision = rate_limit_sampler.is_sampled(
                    trace_id=self.trace_id,
                    operation=self.operation_name,
                )
                # jaeger_client samplers return a (sampled, tags) tuple, which
                # is always truthy; only the boolean decision matters here.
                if isinstance(decision, tuple):
                    decision = decision[0]
                if not decision:
                    return

        with self.update_lock:
            if self.finished:
                logger.warning('Span has already been finished; will not be reported again.')
                return
            self.finished = True
            self.end_time = end_time

        self.tracer.report_span(self)


class Tracer(jaeger_client.Tracer):
    """jaeger_client.Tracer whose spans are converted to the local Span class."""

    def start_span(self, *args, **kwargs):
        """Start a span via the base tracer, then re-class it as ``Span``."""
        base_span = super(Tracer, self).start_span(*args, **kwargs)
        return Span.cast(base_span)


class Config(jaeger_client.Config):
    """jaeger_client.Config that builds the local ``Tracer`` subclass."""

    def create_tracer(self, reporter, sampler, throttler=None):
        """Construct a ``Tracer`` from this configuration.

        Same arguments as the base implementation; every setting is forwarded
        as a keyword argument to the local ``Tracer`` class.
        """
        tracer_kwargs = dict(
            service_name=self.service_name,
            reporter=reporter,
            sampler=sampler,
            metrics_factory=self._metrics_factory,
            trace_id_header=self.trace_id_header,
            generate_128bit_trace_id=self.generate_128bit_trace_id,
            baggage_header_prefix=self.baggage_header_prefix,
            debug_id_header=self.debug_id_header,
            tags=self.tags,
            max_tag_value_length=self.max_tag_value_length,
            max_traceback_length=self.max_traceback_length,
            extra_codecs=self.propagation,
            throttler=throttler,
            scope_manager=self.scope_manager,
        )
        return Tracer(**tracer_kwargs)
