import collections
import hashlib
import logging
import random
import time
from urllib import urlencode
from concurrent import futures
from contextlib import contextmanager

import json
import requests
import yt.wrapper as yt
from library.python import retry


def generate_reqid():
    """Build a request id of the form '<unix time in ms>-<random integer>'.

    The random part is drawn from [0, 2 ** 63 - 1] after re-seeding the
    module-level random generator.
    """
    random.seed()
    millis = int(time.time() * 1000)
    nonce = random.randint(0, 2 ** 63 - 1)
    return '{}-{}'.format(millis, nonce)


def rotate(cont, shift):
    """Return a copy of the sequence ``cont`` rotated left by ``shift`` positions.

    ``shift`` is taken modulo ``len(cont)``, so values larger than the length
    and negative values (rotate right) are both accepted.  An empty sequence
    is returned as an empty copy instead of failing on the modulo by zero.

    :param cont: sliceable sequence supporting ``+`` (list, tuple, str, ...).
    :param shift: integer number of positions to rotate by.
    :return: new sequence of the same type as ``cont``.
    """
    size = len(cont)
    if size == 0:
        # Nothing to rotate; ``shift % 0`` would raise ZeroDivisionError.
        return cont[:]
    shift %= size
    return cont[shift:] + cont[:shift]


@retry.retry(retry.RetryConf(logger=logging.getLogger()).waiting(1, backoff=2).upto_retries(5))
def get(*args, **kwargs):
    """Perform ``requests.get`` with the given arguments under the retry decorator.

    Responses with a 4xx status are handed back to the caller untouched.
    For any other status ``raise_for_status()`` is invoked, so an error
    status outside the 4xx range raises from this function (and is then,
    presumably, retried by the decorator - see the retry library for the
    exact semantics of the configuration above).
    """
    reply = requests.get(*args, **kwargs)
    is_client_error = 400 <= reply.status_code <= 499
    if is_client_error:
        return reply
    reply.raise_for_status()
    return reply


@contextmanager
def graceful_shutdown(pool):
    """Context manager that yields ``pool`` and always closes and joins it on exit.

    ``pool.close()`` followed by ``pool.join()`` run whether the managed
    block finishes normally or raises.
    """
    def _finish():
        # Order matters: close first, then wait for completion.
        pool.close()
        pool.join()

    try:
        yield pool
    finally:
        _finish()


class Chunk(collections.namedtuple('Chunk', 'messages start end')):
    """A batch of serialized messages together with the row range it covers.

    Fields:
        messages: list of encoded messages to be written.
        start: index of the first row of the range (inclusive).
        end: index one past the last row of the range (exclusive).
    """

    def __str__(self):
        # Half-open range notation; the closing ')' was missing, which left
        # an unbalanced bracket in every log line mentioning a chunk.
        return 'Chunk [#{}: #{})'.format(self.start, self.end)


class Producer:
    """Wraps a started logbroker producer and tracks the last used sequence number."""

    def __init__(self, producer, seq_no):
        self.seq_no = seq_no
        self.producer = producer

    def write(self, chunk):
        """Send every message of ``chunk`` and block until all writes are acknowledged.

        Each message gets the next sequence number (``seq_no`` is advanced
        before the write is issued).
        """
        pending = []
        for payload in chunk.messages:
            self.seq_no += 1
            future = self.producer.write(self.seq_no, payload)
            pending.append(future)
        self.wait_ack(pending, chunk)

    @staticmethod
    def wait_ack(responses, chunk):
        """Wait for all write futures in ``responses``, logging progress on every pass.

        Raises RuntimeError as soon as a finished write is found whose
        result carries no 'ack' field.
        """
        total = len(responses)
        finished = 0
        remaining = responses
        while remaining:
            # Wake up at least once per second so progress gets logged.
            completed, remaining = futures.wait(remaining, timeout=1, return_when=futures.ALL_COMPLETED)
            finished += len(completed)
            logging.info('{} {}/{} writes finished'.format(chunk, finished, total))
            for future in completed:
                outcome = future.result(timeout=0)
                if outcome.HasField('ack'):
                    continue
                raise RuntimeError('logbroker write error: {}'.format(outcome))


class ProducerPositiveSeqNo(Exception):
    """Raised when a producer session reports a positive max sequence number."""

    def __init__(self, msg='Resending the same YT table with the same start position is prohibited'):
        Exception.__init__(self, msg)


def set_up_producer(lb_token, topic, source_id, worker_id):
    """Connect to logbroker and return a started :class:`Producer`.

    Endpoints are tried one after another, starting from the one selected
    by ``worker_id`` (the endpoint list is rotated by that amount).  The
    first endpoint on which both the API and the producer start
    successfully wins.

    :param lb_token: OAuth token (text); encoded before being passed to the SDK.
    :param topic: logbroker topic to write to.
    :param source_id: source id as bytes.
    :param worker_id: integer used to pick the first endpoint to try.
    :return: ``Producer`` wrapping the started SDK producer and the
        ``max_seq_no`` reported by the session.
    :raises ProducerPositiveSeqNo: the session reports ``max_seq_no > 0``;
        no further endpoints are tried in this case.
    :raises RuntimeError: none of the endpoints could be used.
    """
    logging.info('Setting up producer for topic {} source_id {}'.format(topic, source_id.decode()))

    import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
    import kikimr.public.sdk.python.persqueue.auth as auth
    from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult

    # Verbose SDK logging for the duration of the connection attempts only.
    sdk_logger = logging.getLogger('kikimr.public.sdk.python.persqueue')
    sdk_logger.setLevel(logging.DEBUG)

    try:
        endpoints = tuple('{}.logbroker.yandex.net'.format(dc) for dc in ('sas', 'vla', 'man', 'myt', 'iva'))
        for endpoint in rotate(endpoints, shift=worker_id):
            # noinspection PyBroadException
            try:
                # NOTE(review): ``lb`` is not shut down when a later step of
                # this attempt fails - confirm whether the SDK requires an
                # explicit stop to release the connection.
                lb = pqlib.PQStreamingAPI(endpoint, 2135)
                result = lb.start().result(timeout=10)
                if not result:
                    raise RuntimeError('{} connect start result: {}'.format(endpoint, result))

                credentials_provider = auth.OAuthTokenCredentialsProvider(lb_token.encode())
                configurator = pqlib.ProducerConfigurator(topic, source_id)

                producer = lb.create_retrying_producer(configurator, credentials_provider=credentials_provider)
                result = producer.start().result(timeout=10)

                if isinstance(result, SessionFailureResult):
                    raise RuntimeError('Producer start error: {}'.format(result))
                if not result.HasField('init'):
                    raise RuntimeError('Unexpected producer start result: {}'.format(result))

                logging.info('Producer start result: {}'.format(result))
                seq_no = result.init.max_seq_no
                if seq_no > 0:
                    raise ProducerPositiveSeqNo()
                return Producer(producer, seq_no)

            except ProducerPositiveSeqNo:
                logging.exception('Resending the same YT table with the same start position')
                # Bare raise keeps the original traceback (``raise e`` resets
                # it on Python 2) and aborts the endpoint loop.
                raise
            except Exception:
                # Any other failure is specific to this endpoint: try the next one.
                logging.exception('Cannot connect to {}'.format(endpoint))

        raise RuntimeError('Cannot connect to logbroker')
    finally:
        sdk_logger.setLevel(logging.INFO)


class Worker:
    """Callable that writes chunks through a lazily created producer."""

    id = 0

    def __init__(self, lb_token, lb_topic, source_id):
        self._producer = None
        self.parent_source_id = source_id
        self.lb_topic = lb_topic
        self.lb_token = lb_token

    @property
    def source_id(self):
        """Bytes of the form '<parent source id>.<worker id>'."""
        parent = self.parent_source_id.decode()
        return '{}.{}'.format(parent, self.id).encode()

    @property
    def producer(self):
        """Producer for this worker, set up on first access."""
        if self._producer:
            return self._producer
        self._producer = set_up_producer(self.lb_token, self.lb_topic, self.source_id, self.id)
        return self._producer

    def __call__(self, chunk):
        """Write ``chunk``; return True on success, False (after logging) on any error."""
        # noinspection PyBroadException
        try:
            self.producer.write(chunk)
        except Exception:
            logging.exception('{} processing failed'.format(chunk))
            return False
        return True


class Yt2Lb:
    """Copy the 'push' column of a YT table into a logbroker topic.

    Rows with indexes in [start, start + count) are read from ``yt_table``,
    each non-empty 'push' value is wrapped into a JSON message together with
    the request id and the user, messages are grouped into chunks and the
    chunks are written through a :class:`Worker`.  For the production topic
    the leading rows are validated first and acceptance of the request is
    logged after the first successfully written chunk.
    """

    # Topic for which validation and acceptance logging are performed.
    PROD_TOPIC = '/sup/pushes-batch'
    # A chunk is emitted once the total size of its messages exceeds this.
    MAX_CHUNK_BYTES = 8 * 1024 * 1024
    # Maximum number of leading rows checked by validate().
    VALIDATE_ROWS = 10

    def __init__(self, lb_token, yt_token, yt_proxy, yt_table, lb_topic, user, start=0, count=None, concurrency=1):
        """Store the settings and generate a request id.

        :param start: index of the first row to send.
        :param count: number of rows to send; ``None`` means no upper bound.
        :param concurrency: integer in [1, 16].  NOTE(review): the value is
            validated and stored, but run() below uses a single Worker.
        :raises RuntimeError: ``concurrency`` is not an int in [1, 16].
        """
        self.lb_token = lb_token
        self.yt_token = yt_token
        self.yt_proxy = yt_proxy
        self.yt_table = yt_table
        self.lb_topic = lb_topic
        self.user = user.encode()
        self.start = start
        self.end = None if count is None else self.start + count
        if not isinstance(concurrency, int) or not 1 <= concurrency <= 16:
            raise RuntimeError('Invalid concurrency {}'.format(concurrency))
        self.concurrency = concurrency

        self.reqid = generate_reqid()

    @property
    def source_id(self):
        """First 8 hex digits (as bytes) of md5 over proxy, table '@path' attribute and start.

        Queries YT on every access.
        """
        yt_realpath = yt.get(self.yt_table + '/@path')
        return hashlib.md5('{}{}{}'.format(self.yt_proxy, yt_realpath, self.start).encode()).hexdigest()[:8].encode()

    @property
    def is_prod_topic(self):
        """True when the target topic is the production one."""
        return self.lb_topic == self.PROD_TOPIC

    def _iter_chunks(self, recs):
        """Yield :class:`Chunk` objects built from the rows in ``recs``.

        Rows are numbered from ``self.start``.  Rows without a 'push' column
        or with an empty one are skipped with a warning.  A chunk is emitted
        as soon as its accumulated message size exceeds MAX_CHUNK_BYTES;
        whatever is left after the last row is emitted as a final chunk.
        """
        messages = []
        size = 0
        chunk_start = self.start
        # One past the index of the last row consumed so far; serves as the
        # exclusive end of the chunk being emitted.
        next_index = self.start

        for i, rec in enumerate(recs, start=self.start):
            next_index = i + 1

            if 'push' not in rec:
                logging.warning("Row #{} skipped: required column 'push' missed".format(i))
                continue

            push = rec['push']
            if not push:
                logging.warning("Row #{} skipped: empty 'push' column".format(i))
                continue

            batch = {
                'push': push.encode(),
                'reqid': self.reqid,
                'user': self.user
            }

            message = json.dumps(batch).encode()
            messages.append(message)
            size += len(message)

            if size > self.MAX_CHUNK_BYTES:
                yield Chunk(messages, chunk_start, next_index)
                chunk_start = next_index
                messages = []
                size = 0

            processed = i - self.start + 1
            if processed % 1000 == 0:
                logging.info('Processed {} rows'.format(processed))

        if messages:
            yield Chunk(messages, chunk_start, next_index)

    def get_row_count(self):
        """Number of table rows inside the requested [start, end) range (never negative)."""
        row_count = yt.row_count(self.yt_table)
        if self.end is None:
            end = row_count
        else:
            end = min(self.end, row_count)
        return max(0, end - self.start)

    def log_accept(self):
        """Report acceptance of the request; a no-op for non-production topics.

        :raises: whatever the HTTP call raises, including an HTTP error for
            an unsuccessful status.
        """
        if not self.is_prod_topic:
            return
        row_count = self.get_row_count()
        params = {
            'state': 'accept',
            'user': self.user,
            'reqid': self.reqid,
            'path': self.yt_table,
            'row_count': row_count
        }
        url = 'http://sup.yandex.net/pushes/batch/log?{}'.format(urlencode(params))
        # Go through the module's retrying get() helper, as validate() does:
        # this runs after data has already been written, so a transient
        # failure here should not abort the whole run.
        response = get(url)
        response.raise_for_status()

    def validate(self):
        """Check the leading rows against the validation endpoint (production topic only).

        At most VALIDATE_ROWS rows starting at ``self.start`` are checked,
        never going past the requested end of the range.  Rows that
        _iter_chunks would skip (missing or empty 'push') are skipped here
        as well.

        :raises: the HTTP error for the first row rejected by the endpoint.
        """
        if not self.is_prod_topic:
            return
        head_end = self.start + self.VALIDATE_ROWS
        if self.end is not None:
            # Do not validate rows that are outside the range being sent.
            head_end = min(head_end, self.end)
        head = yt.TablePath(self.yt_table, start_index=self.start, end_index=head_end)
        for i, rec in enumerate(yt.read_table(head), self.start):
            push = rec['push'] if 'push' in rec else None
            if not push:
                logging.warning("Row #{} not validated: missing or empty 'push' column".format(i))
                continue
            response = get(
                'http://sup.yandex.net/pushes/validate',
                data=push.encode('utf-8'),
                headers={'Content-Type': 'application/json;charset=utf-8'}
            )
            try:
                response.raise_for_status()
            except Exception:
                logging.error('Invalid push at row #{}'.format(i))
                logging.error('Push request: {}'.format(push))
                logging.error('Response: {}'.format(response.text))
                # Bare raise preserves the original traceback.
                raise

    def run(self):
        """Configure logging and YT, validate, then send all chunks.

        Everything that touches YT happens inside a single YT transaction.

        :return: the request id of this run.
        :raises RuntimeError: writing a chunk failed.
        """
        logging.basicConfig(
            format='%(asctime)s %(processName)-17s %(levelname)-8s %(message)s',
            level=logging.INFO,
            datefmt='%Y-%m-%d %H:%M:%S'
        )

        logging.info('Request id {}'.format(self.reqid))

        self.setup_yt(self.yt_token, self.yt_proxy)

        with yt.Transaction():
            self.validate()

            table_path = yt.TablePath(self.yt_table, start_index=self.start, end_index=self.end)
            recs = yt.read_table(table_path, unordered=True, enable_read_parallel=True)
            chunks = self._iter_chunks(recs)

            worker = Worker(self.lb_token, self.lb_topic, self.source_id)
            for i, chunk in enumerate(chunks):
                success = worker(chunk)
                if not success:
                    raise RuntimeError('Writing failed. Inspect logs for more information')
                if i == 0:
                    # Log request acceptance after 1st successful write to logbroker
                    self.log_accept()

        return self.reqid

    @staticmethod
    def setup_yt(yt_token, yt_proxy):
        """Point the YT client at ``yt_proxy`` with ``yt_token`` and tune parallel reads."""
        yt.update_config({'token': yt_token, 'proxy': {'url': yt_proxy}})
        yt.config['read_parallel']['max_thread_count'] = 16
        yt.config['read_parallel']['data_size_per_thread'] = 8 * 1024 * 1024
