# -*- coding: utf-8 -*-
from functools import wraps
import logging
import random
import socket
import time

from passport.backend.logbroker_client.core.logbroker.decompress import (
    AutoDecompressor,
    BadArchive,
)
from passport.backend.logbroker_client.core.logbroker.tskv import parse_tskv
from passport.backend.logbroker_client.core.utils import retriable_n
import requests
import six


if six.PY3:
    from http.client import (
        BadStatusLine,
        HTTPConnection as _HTTPConnection,
        HTTPException,
        IncompleteRead,
        LineTooLong,
        ResponseNotReady,
    )


    class HTTPConnection(_HTTPConnection):
        """``http.client.HTTPConnection`` whose ``getresponse`` tolerates extra arguments."""

        def getresponse(self, *args, **kwargs):
            # Call sites in this module use ``getresponse(True)``; the arguments are
            # accepted and dropped because the stdlib method is invoked without any.
            # NOTE(review): presumably the extra argument only matters to the
            # Python 2 ``hhttplib`` implementation - confirm.
            return super(HTTPConnection, self).getresponse()


else:
    from passport.backend.logbroker_client.core.logbroker.hhttplib import (
        BadStatusLine,
        HTTPConnection,
        HTTPException,
        IncompleteRead,
        LineTooLong,
        ResponseNotReady,
    )


def _resp_str(value):
    """Return *value* as a native string.

    On Python 3 the value is decoded with ``value.decode()`` (default codec);
    on Python 2 it is handed back untouched.
    """
    return value.decode() if six.PY3 else value


def _resp_dict(value):
    """Return mapping *value* with every key and value passed through ``_resp_str``.

    On Python 3 a new dict is built; on Python 2 the very same mapping object
    is returned unchanged.
    """
    if not six.PY3:
        return value
    return {_resp_str(key): _resp_str(item) for key, item in value.items()}


def _read_chunked(response, amt):
    """Yield successive pieces of the body of *response*.

    Python 3: ``response.read(amt)`` is called repeatedly and iteration stops
    at the first empty result.
    Python 2: whatever ``response.read_chunked(amt)`` produces is re-yielded.
    """
    if not six.PY3:
        for piece in response.read_chunked(amt):
            yield piece
        return

    piece = response.read(amt)
    while piece:
        yield piece
        piece = response.read(amt)


# Logger shared by the whole module, registered under the name 'logbroker'.
log = logging.getLogger('logbroker')


class BaseLogbrokerClientException(Exception):
    """Root of the exception hierarchy defined in this module.

    Also raised directly by ``LogbrokerReader.chunk_reader`` for malformed or
    error-carrying read responses.
    """
    pass


class KilledSession(BaseLogbrokerClientException):
    """Raised by ``LogbrokerReader.commit`` when the server answers ``session was killed``."""
    pass


class LockException(BaseLogbrokerClientException):
    """Raised when ``/pull/session`` or ``/pull/read`` does not answer with HTTP 200,
    or when the session response carries no ``Session`` header."""
    pass


class SuggestException(BaseLogbrokerClientException):
    """Raised when ``/pull/suggest`` fails or returns an unexpected answer
    (bad status, unexpected chunk, wrong number of partitions)."""
    pass


class CommitException(BaseLogbrokerClientException):
    """Raised by ``LogbrokerReader.commit`` on a non-200 status or a body other than ``ok``."""
    pass


class TimeoutException(BaseLogbrokerClientException):
    """Client-level timeout; ``LogbrokerNetworkErrorsWrapper`` produces it from ``socket.timeout``."""
    pass


class NetworkException(BaseLogbrokerClientException):
    """Client-level network failure; ``LogbrokerNetworkErrorsWrapper`` produces it from
    ``socket.error`` and the HTTP library errors listed in its mapping."""
    pass


class OffsetsException(BaseLogbrokerClientException):
    """Raised by ``LogbrokerMeta.get_offsets_info`` when ``/pull/offsets`` does not answer with HTTP 200."""
    pass


class NoHostsException(BaseLogbrokerClientException):
    """Raised by ``LogbrokerMeta.hosts_info`` when ``/hosts_info`` does not answer with HTTP 200."""
    pass


class BadPullListException(BaseLogbrokerClientException):
    """Raised by ``LogbrokerMeta.show_parts`` when ``/pull/list`` does not answer with HTTP 200."""
    pass


class Try(object):
    """Run a callable once at construction time and remember how it ended.

    After construction ``result`` holds the return value (``None`` if the call
    raised) and ``exception`` holds the caught ``Exception`` instance (``None``
    if the call succeeded). The callable itself is kept in ``fn``.
    """

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.result, self.exception = None, None
        try:
            outcome = fn(*args, **kwargs)
        except Exception as err:
            self.exception = err
        else:
            self.result = outcome

    def getOrElse(self, value):
        """Return the stored result, or *value* when the stored result is ``None``."""
        return value if self.result is None else self.result


class ExceptionsWrapper(object):
    """Decorator that translates exceptions raised by the wrapped callable.

    ``mapping`` is a dict ``{exception class: factory}``; ``factory`` receives the
    caught exception instance and returns the exception to raise instead.

    The lookup walks the MRO of the caught exception's class, so the most
    specific mapped class wins and subclasses of a mapped class are translated
    too (e.g. ``ConnectionRefusedError`` when ``OSError`` is mapped). Exceptions
    with no mapped class in their MRO are re-raised untouched.
    """

    def __init__(self, mapping):
        self.mapping = mapping

    def _find_mapper(self, exc):
        """Return the factory for the most specific mapped class of *exc*, or ``None``."""
        for klass in type(exc).__mro__:
            mapper = self.mapping.get(klass)
            if mapper is not None:
                return mapper
        return None

    def __call__(self, f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except Exception as e:
                mapper = self._find_mapper(e)
                if mapper is None:
                    raise
                raise mapper(e)

        return wrapper


# Decorator instance used throughout this module: turns low-level socket/HTTP
# errors into client exceptions. ``socket.timeout`` becomes TimeoutException,
# every other listed error becomes NetworkException; the new exception carries
# ``str()`` of the original one as its message.
LogbrokerNetworkErrorsWrapper = ExceptionsWrapper(
    {
        socket.error: lambda e: NetworkException(str(e)),
        socket.timeout: lambda e: TimeoutException(str(e)),
        HTTPException: lambda e: NetworkException(str(e)),
        IncompleteRead: lambda e: NetworkException(str(e)),
        ResponseNotReady: lambda e: NetworkException(str(e)),
        BadStatusLine: lambda e: NetworkException(str(e)),
        LineTooLong: lambda e: NetworkException(str(e)),
    },
)


class LogbrokerReader(object):
    """Reads data from Logbroker over the HTTP pull API within one session.

    Used as a context manager: ``__enter__`` asks ``/pull/suggest`` for partitions,
    opens a session through ``/pull/session`` and returns the ``chunk_reader``
    generator; ``__exit__`` closes the read connection. ``commit`` acknowledges
    offsets over the same connection and session.
    """

    SUGGEST_PATH = "/pull/suggest?topic=%s&client=%s&count=%s"
    SESSION_WHOLE_TIMEOUT = 1800000  # LBOPS-122
    SESSION_PATH = "/pull/session?topic=%s&client=%s&timeout=30000"
    STREAM_PATH = "/pull/read?topic=%s&client=%s&limit=1000&format=raw"
    COMMIT_PATH = '/pull/commit?client=%s'
    DATA_PORT = 8998
    SUGGEST_TIMEOUT = 5
    COMMIT_TIMEOUT = 5

    def __init__(self, host, client, topics, partitions_count, data_port=None, suggest_timeout=None, timeout=25,
                 suggest_balance=None):
        """
        :param host: host name used for the suggest request (port is appended here)
        :param client: client identifier sent in every request
        :param topics: list of topic names
        :param partitions_count: partitions to request; forced to 1 unless exactly one topic is given
        :param data_port: port of the data API, ``DATA_PORT`` when falsy
        :param suggest_timeout: socket timeout for suggest, ``SUGGEST_TIMEOUT`` when falsy
        :param timeout: socket timeout of the read connection
        :param suggest_balance: when truthy, appended as ``balance`` to suggest and the streaming
            suggest protocol is used
        """
        self.host = '%s:%s' % (host, data_port if data_port else self.DATA_PORT)
        self.client = client
        self.topics = topics
        self.partitions_count = 1 if len(self.topics) != 1 else partitions_count
        self.timeout = timeout
        self.suggest_timeout = suggest_timeout or self.SUGGEST_TIMEOUT
        self.suggest_balance = suggest_balance
        # Filled in by chunk_reader once iteration starts
        self.partitions = None
        self._read_connection = None
        self._session = None

    @retriable_n(retry_count=2, time_sleep=0.2, exceptions=(TimeoutException, NetworkException, SuggestException))
    @LogbrokerNetworkErrorsWrapper
    def suggest(self):
        """Call ``/pull/suggest`` and return a list of ``[host, partition, ...]`` string lists.

        :raises SuggestException: on a bad status or an unexpected answer
        """
        log.info('Call /suggest/ with host: %s', self.host)
        log.info('Suggest params topic=%s count=%s client=%s balance=%s',
                 self.topics, self.partitions_count, self.client, self.suggest_balance)
        suggest_path = self.SUGGEST_PATH % (','.join(self.topics), self.client, self.partitions_count)
        if self.suggest_balance:
            suggest_path += '&balance=%s' % self.suggest_balance
            do_request = self._suggest_balanced_request
        else:
            do_request = self._suggest_request
        log.info('Suggest path: %s', suggest_path)
        return do_request(suggest_path)

    def _suggest_request(self, suggest_path):
        """Plain suggest: one whitespace-separated record per non-blank response line."""
        suggest_connection = HTTPConnection(self.host, timeout=self.suggest_timeout)
        try:
            # The request is inside the try so the connection is closed even if sending fails
            suggest_connection.request('GET', suggest_path)
            response = suggest_connection.getresponse(True)
            data = response.read()
            if response.status != 200:
                raise SuggestException('Suggest response code=%s data=%s' % (response.status, data))
            # Decode field by field: _resp_str works on a single bytes value, not on a list
            return [
                [_resp_str(field) for field in line.split()]
                for line in data.split(b'\n')
                if line.strip()
            ]
        finally:
            suggest_connection.close()

    def _suggest_balanced_request(self, suggest_path):
        """Balanced suggest: skip ``ping`` chunks until an ``ans <host> <partition> ...`` chunk arrives.

        :raises SuggestException: on a bad status, an empty or unexpected chunk, or when the
            response ends without an ``ans`` chunk
        """
        suggest_connection = HTTPConnection(self.host, timeout=self.suggest_timeout)
        try:
            suggest_connection.request('GET', suggest_path)
            # Avoid DDosing suggest
            time.sleep(random.randrange(3))
            response = suggest_connection.getresponse(True)
            if response.status != 200:
                raise SuggestException('Suggest response code=%s data=%s' % (response.status, response.read()))

            # NOTE(review): on Python 3 ``read(None)`` returns the whole body at once, so several
            # server chunks may arrive glued together - confirm the server sends a single answer.
            for chunk in _read_chunked(response, None):
                chunk = chunk.strip()
                if not chunk:
                    raise SuggestException('Empty suggest chunk')

                # Chunks are bytes; a str literal would never compare equal on Python 3
                if chunk == b'ping':
                    log.info('Suggest returned ping, waiting')
                    continue
                parts = [_resp_str(x) for x in chunk.split()]
                if parts[0] == 'ans':
                    return [parts[1:]]
                raise SuggestException('Suggest unexpected chunk %s' % chunk)

            # Stream finished without an answer: fail explicitly (and retriably) instead of returning None
            raise SuggestException('Suggest response ended without an answer')
        finally:
            suggest_connection.close()

    @LogbrokerNetworkErrorsWrapper
    def __enter__(self):
        """Pick partitions via suggest, open a session and return the chunk generator.

        :raises SuggestException: when suggest returns a wrong number of partitions
        :raises LockException: when the session cannot be established
        """
        partitions = self.suggest()
        if len(partitions) != self.partitions_count:
            raise SuggestException(
                'Suggest return %s partitions but required %s' % (len(partitions), self.partitions_count),
            )

        hosts = [x[0] for x in partitions]
        partitions = [x[1] for x in partitions]
        # This may be suboptimal with several partitions, but we try to
        # keep the number of workers equal to the number of partitions
        read_host = random.choice(hosts)

        # Try to create the session
        log.info('Got partitions: %s use host: %s', partitions, read_host)
        self._read_connection = HTTPConnection(read_host, timeout=self.timeout)
        session_path = self.SESSION_PATH % (','.join(partitions), self.client)
        try:
            self._read_connection.request('GET', session_path, headers={'Timeout': self.SESSION_WHOLE_TIMEOUT})
            response = self._read_connection.getresponse(True)
            data = response.read()
            if response.status != 200:
                raise LockException('/session/ returned bad status: %s (%s)' % (response.status, data))
            session = response.getheader('Session')
            if not session:
                raise LockException('Bad session value: %s' % session)
        except Exception:
            # __exit__ is not called when __enter__ fails, so clean up here
            log.info('Close connection (session not established)')
            self._read_connection.close()
            raise

        self._session = session

        return self.chunk_reader(partitions)

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the read connection and forget the session; exceptions are not suppressed."""
        log.info('Close connection: %s %s', exc_type, exc_value)
        self._read_connection.close()
        self._session = None

    def chunk_reader(self, partitions):
        """Generator that endlessly issues ``/pull/read`` requests within the session.

        Yields ``(header, body)`` for every data chunk, where ``header`` is the parsed TSKV
        header line as a dict and ``body`` is the rest of the chunk as bytes, and
        ``(None, None)`` each time one read response has been fully consumed.

        :raises LockException: when ``/pull/read`` does not answer with HTTP 200
        :raises BaseLogbrokerClientException: on a malformed chunk or a server-side error
        """
        self.partitions = partitions
        read_path = self.STREAM_PATH % (','.join(partitions), self.client)
        log.info('Reader path: %s', read_path)
        while 1:
            self._read_connection.request('GET', read_path, headers={'Session': self._session})
            response = self._read_connection.getresponse(True)
            if response.status != 200:
                # Read the body once: a second read() on a consumed response returns nothing
                error_body = response.read()
                log.error('response:%s, code:%s, session:%s', error_body, response.status, self._session)
                raise LockException('Cant lock while /read/', error_body)

            for chunk in _read_chunked(response, None):
                if not chunk:
                    break
                if b'java.lang.Exception:' in chunk:
                    raise BaseLogbrokerClientException(chunk)
                header, separator, body = chunk.partition(b'\n')
                if not separator:
                    log.error('Chunk without header line: `%s`', chunk)
                    raise BaseLogbrokerClientException('Bad read response')
                try:
                    header = _resp_dict(parse_tskv(header))
                except UnicodeDecodeError as err:
                    # Must precede ValueError: UnicodeDecodeError is a subclass of it
                    log.error('Bad header encoding: %s\n`%s`\n`%s`', err, header, chunk)
                    raise BaseLogbrokerClientException('Bad read response')
                except ValueError as err:
                    log.error('Bad header: %s\n`%s`\n`%s`', err, header, chunk)
                    raise BaseLogbrokerClientException('Bad read response')
                yield header, body

            if six.PY2:
                if response.trailer:
                    trailer_content = b', '.join(response.trailer)
                    log.info('Trailer headers: %s', trailer_content)
                    if b'X-Logbroker-Error' in trailer_content:
                        raise BaseLogbrokerClientException('Logbroker error in trailer headers: %s' % trailer_content)
            # Marks the end of one read response
            yield None, None

    @retriable_n(retry_count=1, time_sleep=0.2, exceptions=(TimeoutException, NetworkException, CommitException))
    @LogbrokerNetworkErrorsWrapper
    def commit(self, offsets):
        """POST ``offsets`` (dict ``{key: offset}``, sent as ``key:offset`` lines) to ``/pull/commit``.

        Does nothing when ``offsets`` is empty.

        :raises CommitException: on a non-200 status or a body other than ``ok``
        :raises KilledSession: when the server answers ``session was killed``
        """
        if not offsets:
            return
        path = self.COMMIT_PATH % self.client
        body = '\n'.join('%s:%s' % (topic, offset) for topic, offset in offsets.items())

        self._read_connection.request('POST', path, body=body, headers={'Session': self._session})
        response = self._read_connection.getresponse(True)

        data = response.read().strip()
        if response.status != 200:
            raise CommitException('Commit response code=%s data=%s' % (response.status, data))
        if data == b'session was killed':
            log.info('Commit for killed session')
            raise KilledSession('Commit for killed session')
        elif data != b'ok':
            log.warning('Commit error: %s', repr(data))
            raise CommitException(data)


# Number of consecutive reads that delivered no data chunk after which
# LogbrokerConsumer.read_unpacked fetches the current offsets and commits them.
EMPTY_READS_COUNT_BEFORE_COMMIT = 100


class LogbrokerConsumer(object):
    """Reads, decompresses and hands Logbroker data to a handler, committing offsets after every read."""

    def __init__(self, hosts, client, topics, partitions_count, data_port=None, suggest_timeout=None,
                 suggest_balance=None):
        """
        :param hosts: candidate hosts; one is picked at random per ``read_unpacked`` call
        :param client: client identifier
        :param topics: list of topic names
        :param partitions_count: number of partitions to request
        :param data_port: data API port, passed to ``LogbrokerReader`` (which has its own default)
        :param suggest_timeout: passed to ``LogbrokerReader``
        :param suggest_balance: passed to ``LogbrokerReader``
        """
        self.hosts = hosts
        self.data_port = data_port
        self.client = client
        self.topics = topics
        self.partitions_count = partitions_count
        self.suggest_timeout = suggest_timeout
        self.suggest_balance = suggest_balance

    @LogbrokerNetworkErrorsWrapper
    def read_unpacked(self, handler):
        """Run one read session until it fails.

        ``handler`` must provide ``process(header, data)``, whose truthy result is counted as
        a flush, and ``flush(force=True)``. Chunks that cannot be decompressed are logged and
        skipped. The loop only ends through an exception.
        """
        working_host = random.choice(self.hosts)
        reader = LogbrokerReader(
            working_host,
            self.client,
            self.topics,
            self.partitions_count,
            self.data_port,
            suggest_timeout=self.suggest_timeout,
            suggest_balance=self.suggest_balance,
        )
        decompressor = AutoDecompressor()
        with reader as chunk_stream:
            offsets = {}
            flush_count = 0
            empty_successive_reads_count = 0
            got_non_empty_chunk = False
            # /pull/read and /pull/commit are called in turn within a single session
            for header, rawdata in chunk_stream:
                if header is not None and rawdata is not None:
                    try:
                        unpacked = decompressor.decompress(rawdata)
                    except BadArchive:
                        log.warning('Bad archive: %s', header)
                        continue

                    log.debug('Packed: %s Unpacked: %s', len(rawdata), len(unpacked))
                    is_flushed = handler.process(header, _resp_str(unpacked))

                    offsets['%s:%s' % (header['topic'], header['partition'])] = header['offset']
                    got_non_empty_chunk = True

                    if is_flushed:
                        flush_count += 1
                    continue

                # End of a read response: flush and commit what was read successfully
                handler.flush(force=True)
                log.info('Flush count per read: %s', flush_count + 1)
                reader.commit(offsets)
                offsets = {}
                flush_count = 0

                # Count consecutive reads that returned nothing but the end marker
                if got_non_empty_chunk:
                    empty_successive_reads_count = 0
                    got_non_empty_chunk = False
                else:
                    empty_successive_reads_count += 1

                if empty_successive_reads_count < EMPTY_READS_COUNT_BEFORE_COMMIT:
                    continue

                # Threshold reached: fetch the current offsets and commit them so that
                # they do not expire (STATADMIN-4304)
                log.info('Too many empty reads, commiting offsets')
                meta = LogbrokerMeta(working_host, self.client)
                # LogbrokerReader falls back to its DATA_PORT when no port is configured;
                # do the same here, otherwise the offsets URL would be built as 'host:None'
                offsets_port = self.data_port or LogbrokerReader.DATA_PORT
                try:
                    info = meta.get_offsets_info(self.client, reader.partitions, offsets_port)
                except Exception as e:
                    # Best effort: the counter is left as is, so the next empty read tries again
                    log.error('Failed to get offsets info', exc_info=e)
                else:
                    reader.commit({item['partition']: item['offset'] for item in info})
                    empty_successive_reads_count = 0


class LogbrokerMeta(object):
    """Queries Logbroker metadata endpoints (``/hosts_info``, ``/pull/list``, ``/pull/offsets``) via ``requests``."""

    def __init__(self, balancer_host, client):
        self.balancer_host = balancer_host
        self.client = client

    @retriable_n(
        12,
        time_sleep=10.0,
        exceptions=(NoHostsException, requests.exceptions.Timeout, requests.exceptions.ConnectionError),
    )
    def hosts_info(self, dc, timeout=10):
        """Return the first tab-separated field of every non-empty line of ``/hosts_info`` for ``dc``.

        :raises NoHostsException: when the response status is not 200
        """
        response = requests.get(
            'http://%s/hosts_info' % self.balancer_host,
            params={'dc': dc},
            timeout=timeout,
        )
        if response.status_code != 200:
            raise NoHostsException(
                'Bad hosts_info response in dc %s: %s (%s)' % (dc, response.status_code, response.content),
            )
        return [_resp_str(line.split(b'\t')[0]) for line in response.content.split(b'\n') if line]

    @retriable_n(
        12,
        time_sleep=10.0,
        exceptions=(BadPullListException, requests.exceptions.Timeout, requests.exceptions.ConnectionError),
    )
    def show_parts(self, dc=None, ident=None, timeout=10):
        """Return the stripped non-blank lines of ``/pull/list`` requested with ``show-parts=1``.

        :raises BadPullListException: when the response status is not 200
        """
        response = requests.get(
            'http://%s/pull/list' % self.balancer_host,
            params={'ident': ident, 'dc': dc, 'show-parts': '1'},
            timeout=timeout,
        )
        if response.status_code != 200:
            raise BadPullListException('Bad /pull/list response: %s (%s)' % (response.status_code, response.content))
        stripped_lines = (line.strip() for line in response.content.split(b'\n'))
        return [_resp_str(line) for line in stripped_lines if line]

    @retriable_n(
        12,
        time_sleep=10.0,
        exceptions=(OffsetsException, requests.exceptions.Timeout, requests.exceptions.ConnectionError),
    )
    def get_offsets_info(self, client, topics, data_port, dc=None, timeout=10):
        """Fetch ``/pull/offsets`` for ``client`` and ``topics`` and parse it.

        The request goes to a random host of ``dc`` when ``dc`` is given, otherwise to the
        balancer host; ``data_port`` is appended in both cases.

        :param timeout: ``requests`` timeout in seconds; without it a stalled server would
            block forever and the ``Timeout`` retry above could never trigger
        :return: list of dicts with keys ``partition``, ``offset``, ``start``, ``size``, ``lag``
            and ``owner``, sorted by ``partition``
        :raises OffsetsException: when the response status is not 200
        """
        if dc:
            hosts = self.hosts_info(dc)
            host = '%s:%s' % (random.choice(hosts), data_port)
        else:
            host = '%s:%s' % (self.balancer_host, data_port)
        response = requests.get(
            'http://%s/pull/offsets' % host,
            params={'client': client, 'topic': ','.join(topics)},
            timeout=timeout,
        )

        if response.status_code != 200:
            log.error(
                'Failed to read offsets (host %s): code %s, content "%s"',
                host,
                response.status_code,
                response.content,
            )
            raise OffsetsException('Failed to read offsets (host %s): code %s' % (host, response.status_code))

        info = []
        for line in response.content.split(b'\n'):
            line = line.strip()
            if not line:
                continue
            # Every record must consist of exactly six whitespace-separated fields
            partition, offset, start, size, lag, owner = line.split()
            info.append({
                'partition': _resp_str(partition),
                'offset': int(offset),
                'start': int(start),
                'size': int(size),
                'lag': int(lag),
                # NOTE(review): unlike 'partition', 'owner' stays undecoded (bytes on Python 3)
                'owner': owner,
            })
        info.sort(key=lambda item: item['partition'])
        return info
