import concurrent.futures
import logging
import os
import random
import socket
from threading import Lock
from typing import Callable

import kikimr.public.sdk.python.persqueue.auth as auth
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult

from retrying import retry
import ujson

# Module-level logger; all Logbroker start/stop/write events are logged here.
logger = logging.getLogger(__name__)

# Compatibility shim: on interpreters without a builtin ``TimeoutError``
# (Python 2) the bare name lookup raises NameError, so the name is aliased to
# ``socket.timeout`` to keep ``isinstance(exc, TimeoutError)`` checks valid.
# NOTE(review): on Python 3 before 3.11, ``concurrent.futures.TimeoutError``
# (raised by ``Future.result(timeout=...)``) is NOT a subclass of the builtin
# ``TimeoutError``, so this name alone does not match future timeouts there.
try:
    TimeoutError
except NameError:
    TimeoutError = socket.timeout


class LogbrokerWriter:
    """Send messages to one Logbroker topic through a PQ streaming-API producer.

    Intended to be used as a context manager::

        with LogbrokerWriter(token, topic, source_id) as writer:
            writer.write(payload)

    ``__enter__`` connects to ``endpoint`` and starts a retrying producer,
    retrying the whole start-up on ``RuntimeError`` and timeouts.
    ``write`` sends one message and blocks until the server acks it.
    ``__exit__`` stops the producer and then the API.
    """

    STREAMING_API_DEFAULT_PORT = 2135

    def __init__(self, token, topic, source_id, endpoint='logbroker.yandex.net', write_timeout=10):
        # type: (str, str, str, str, int) -> None
        """
        :param token: OAuth token used to build the producer credentials.
        :param topic: Logbroker topic to write to.
        :param source_id: source id passed to the producer configurator.
        :param endpoint: host of the PQ streaming API.
        :param write_timeout: seconds to wait for the result of each ``write``.
        :raises ValueError: if ``token`` or ``topic`` is empty.
        """
        if not token:
            raise ValueError('Logbroker token is not defined')
        if not topic:
            raise ValueError('Logbroker topic is not defined')

        self.topic = topic
        self.token = token
        self.endpoint = endpoint

        self.api = None
        self.producer = None
        self.source_id = source_id
        self.write_timeout = write_timeout
        # Last sequence number used; only known after a successful start.
        self.max_seq_no = None

        # Serializes writes so sequence numbers are assigned and sent in order.
        self._lock_ = Lock()

    @retry(
        # Future.result(timeout=...) raises concurrent.futures.TimeoutError, which
        # is not the builtin TimeoutError before Python 3.11, so match both.
        retry_on_exception=lambda exc: isinstance(
            exc, (RuntimeError, TimeoutError, concurrent.futures.TimeoutError)),
        wrap_exception=True,
        stop_max_attempt_number=10,
        wait_fixed=10000,
    )
    def _try_enter_(self):
        """Start the API and the producer; retried by the decorator above.

        Up to 10 attempts, 10 s apart; after the last failure ``retrying``
        raises its wrapping ``RetryError`` (``wrap_exception=True``).
        Whatever a failed attempt created is stopped before the exception
        propagates, so retries do not leak connections.

        :returns: ``self`` with a running producer and ``max_seq_no`` set.
        """
        # See https://a.yandex-team.ru/arc/trunk/arcadia/kikimr/public/sdk/python/persqueue/examples/producer/__main__.py
        # for example
        started = False
        try:
            self._start_()
            started = True
        finally:
            # try/finally (rather than except + re-raise) keeps the original
            # exception intact on both Python 2 and 3.
            if not started:
                self._stop_quietly_()
        logger.info('Producer started')
        return self

    def _start_(self):
        """Create and start the API and the producer, then record the server's max_seq_no.

        :raises RuntimeError: if the producer start fails or the reply has no ``init``.
        """
        self.max_seq_no = None
        self.api = pqlib.PQStreamingAPI(self.endpoint, self.STREAMING_API_DEFAULT_PORT)

        logger.info('Starting PqLib')
        api_start_future = self.api.start()

        result = api_start_future.result(timeout=10)
        logger.info('Api started with result: %s', result)

        credentials_provider = auth.OAuthTokenCredentialsProvider(self.token)

        logger.info('Configuring producer for topic %s with source id %s', self.topic, self.source_id)
        configurator = pqlib.ProducerConfigurator(self.topic, self.source_id)

        self.producer = self.api.create_retrying_producer(configurator, credentials_provider=credentials_provider)

        logger.info('Starting producer')
        start_future = self.producer.start()
        start_result = start_future.result(timeout=10)

        if isinstance(start_result, SessionFailureResult):
            raise RuntimeError('Error occurred on start of producer: {}.'.format(start_result))
        if not start_result.HasField('init'):
            raise RuntimeError('Unexpected producer start result from server: {}.'.format(start_result))

        # Subsequent writes number their messages from max_seq_no + 1.
        self.max_seq_no = start_result.init.max_seq_no
        logger.info(
            'Producer start result was: %s. Sequence number: %s',
            start_result, self.max_seq_no)

    def _stop_quietly_(self):
        """Best-effort stop of whatever a failed start attempt created; never raises."""
        for attr in ('producer', 'api'):
            component = getattr(self, attr)
            setattr(self, attr, None)
            if component is None:
                continue
            try:
                component.stop()
            except Exception:
                logger.warning('Failed to stop %s after unsuccessful start', attr, exc_info=True)

    def __enter__(self):
        return self._try_enter_()

    def write(self, data):
        """Send ``data`` as the next message and wait up to ``write_timeout`` s for the result.

        :raises RuntimeError: if the writer was not started (not entered) or
            the server reply has no ``ack``.
        """
        with self._lock_:
            if self.producer is None or self.max_seq_no is None:
                raise RuntimeError('Producer is not started; use LogbrokerWriter as a context manager')
            self.max_seq_no += 1
            response = self.producer.write(self.max_seq_no, data)
            write_result = response.result(timeout=self.write_timeout)
            if not write_result.HasField('ack'):
                raise RuntimeError('Message write failed with error {}'.format(write_result))

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stop the API even if stopping the producer raises.
        try:
            if self.producer is not None:
                self.producer.stop()
                logger.info('Producer stopped')
        finally:
            if self.api is not None:
                self.api.stop()
                logger.info('Api stopped')


class LogbrokerProducer:
    """Serialize dicts to JSON and push each one to a Logbroker topic.

    Every ``write`` opens a fresh ``LogbrokerWriter`` session for a single
    message and closes it afterwards, so each call pays the full start-up cost.
    """

    def __init__(self, token, topic, source_id_generator, producer_id, endpoint='logbroker.yandex.net'):
        # type: (str, str, Callable[[str], str], str, str) -> None
        """
        :param token: Logbroker token forwarded to ``LogbrokerWriter``.
        :param topic: Logbroker topic to write to.
        :param source_id_generator: called once with ``producer_id`` to build the source id.
        :param producer_id: identifier passed to ``source_id_generator``.
        :param endpoint: Logbroker host forwarded to ``LogbrokerWriter``.
        :raises ValueError: if ``token`` or ``topic`` is empty (checked before
            ``source_id_generator`` is called).
        """
        required = (
            (token, 'Logbroker token is not defined'),
            (topic, 'Logbroker topic is not defined'),
        )
        for value, complaint in required:
            if not value:
                raise ValueError(complaint)

        self.token, self.topic, self.endpoint = token, topic, endpoint
        self.source_id = source_id_generator(producer_id)

    def write(self, data):
        # type: (dict) -> None
        """Dump ``data`` with ujson and send it as one message in its own session."""
        payload = ujson.dumps(data)
        session = LogbrokerWriter(self.token, self.topic, self.source_id, self.endpoint)
        with session as writer:
            return writer.write(payload)


def get_qloud_source_id(producer_id):
    """Build a Logbroker source id for a process running in Qloud.

    Dot-joins the Qloud project, application, environment, component and
    instance (from ``QLOUD_*`` environment variables, with fallbacks), then
    ``producer_id``, ``pid-<current pid>`` and a random digit 0-9.
    The result is returned UTF-8 encoded.
    """
    env = os.environ
    qloud_defaults = (
        ('QLOUD_PROJECT', 'avia'),
        ('QLOUD_APPLICATION', 'flight-status-fetcher'),
        ('QLOUD_ENVIRONMENT', 'dev'),
        ('QLOUD_COMPONENT', 'dev'),
        ('QLOUD_INSTANCE', 'dev-1'),
    )
    pieces = [env.get(variable, fallback) for variable, fallback in qloud_defaults]
    pieces.append(producer_id)
    pieces.append('pid-{}'.format(os.getpid()))
    pieces.append(str(random.randint(0, 9)))
    return '.'.join(pieces).encode('utf-8')


def get_deploy_source_id(producer_id):
    """Build a Logbroker source id for a process running under Deploy.

    Dot-joins the stage id and persistent pod FQDN (from ``DEPLOY_*``
    environment variables, with fallbacks), ``producer_id``,
    ``pid-<current pid>`` and a random digit 0-9. The result is returned
    UTF-8 encoded.
    """
    stage_id = os.environ.get('DEPLOY_STAGE_ID', 'avia-flight-status-fetcher-dev')
    pod_fqdn = os.environ.get('DEPLOY_POD_PERSISTENT_FQDN', 'jllad2k6utvtqbzj.sas.yp-c.yandex.net')
    pid_tag = 'pid-{}'.format(os.getpid())
    random_digit = str(random.randint(0, 9))
    source_id = '.'.join((stage_id, pod_fqdn, producer_id, pid_tag, random_digit))
    return source_id.encode('utf-8')
