# -*- coding: utf-8 -*-
import time
import socket
import logging
import random
from functools import wraps
from hashlib import md5

import requests

from .hhttplib import (
    HTTPConnection,
    HTTPException,
    BadStatusLine,
    IncompleteRead,
    ResponseNotReady,
    LineTooLong,
)
from .tskv import parse_tskv
from .decompress import AutoDecompressor
from ..utils import (retriable_n, backoff_retriable_n)

log = logging.getLogger('logbroker')


class BaseLogbrokerClientException(Exception):
    """Common base class for every error raised by this Logbroker client."""


class KilledSession(BaseLogbrokerClientException):
    """The server answered a commit with 'session was killed'."""


class LockException(BaseLogbrokerClientException):
    """Locking partitions failed: bad /pull/session or /pull/read response."""


class SuggestException(BaseLogbrokerClientException):
    """The /pull/suggest call returned an error or an unusable answer."""


class CommitException(BaseLogbrokerClientException):
    """The /pull/commit call returned a bad status or a body other than 'ok'."""


class TimeoutException(BaseLogbrokerClientException):
    """Client-level replacement for ``socket.timeout``."""


class NetworkException(BaseLogbrokerClientException):
    """Client-level replacement for socket and HTTP protocol errors."""


class OffsetsException(BaseLogbrokerClientException):
    """Reading offsets from /pull/offsets failed."""


class Try(object):
    """Call ``fn`` immediately and remember either its result or its exception.

    After construction exactly one of the following holds:

    * ``result`` is the value returned by ``fn`` and ``exception`` is None;
    * ``exception`` is the caught ``Exception`` instance and ``result`` is None.
    """

    def __init__(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)``, capturing any ``Exception`` it raises."""
        self.fn = fn
        self.result = None
        self.exception = None
        try:
            self.result = fn(*args, **kwargs)
        # ``as`` form works on Python 2.6+ and Python 3 alike.
        except Exception as e:
            self.exception = e

    def getOrElse(self, value):
        """Return the stored result, or ``value`` when the result is None.

        Note that ``value`` is returned both when ``fn`` raised and when it
        completed successfully but returned None.
        """
        if self.result is None:
            return value
        return self.result


class ExceptionsWrapper(object):
    """Decorator that converts selected exception types into other exceptions.

    ``mapping`` maps an exception class to a callable that receives the caught
    exception instance and returns the exception to raise instead.

    The lookup uses the exact type of the caught exception; subclasses of a
    mapped class are NOT translated unless they have their own entry.
    """

    def __init__(self, mapping):
        self.mapping = mapping

    def __call__(self, f):
        """Wrap ``f`` so that mapped exceptions are replaced on the way out."""
        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            # ``as`` form works on Python 2.6+ and Python 3 alike.
            except Exception as e:
                mapper = self.mapping.get(type(e))
                if mapper:
                    raise mapper(e)
                # Unmapped exceptions propagate unchanged.
                raise

        return wrapper


# Decorator translating low-level socket / HTTP failures into this client's
# own exceptions. ExceptionsWrapper looks the caught exception up by its exact
# type, which is why the HTTPException subclasses are listed one by one.
LogbrokerNetworkErrorsWrapper = ExceptionsWrapper({
    # A socket timeout gets a dedicated exception type.
    socket.timeout: lambda e: TimeoutException(str(e)),
    # Any other socket or HTTP protocol failure is a generic network error.
    socket.error: lambda e: NetworkException(str(e)),
    HTTPException: lambda e: NetworkException(str(e)),
    BadStatusLine: lambda e: NetworkException(str(e)),
    IncompleteRead: lambda e: NetworkException(str(e)),
    LineTooLong: lambda e: NetworkException(str(e)),
    ResponseNotReady: lambda e: NetworkException(str(e)),
})


class LogbrokerReader(object):
    """Reads chunks from Logbroker over HTTP within one locked session.

    Intended to be used as a context manager::

        with reader as chunk_stream:
            for header, body in chunk_stream:
                ...

    ``__enter__`` asks ``/pull/suggest`` for partitions, locks them through
    ``/pull/session`` and returns the ``chunk_reader`` generator.
    ``__exit__`` closes the read connection and forgets the session.
    """

    SUGGEST_PATH = "/pull/suggest?topic=%s&client=%s&count=%s"
    SESSION_PATH = "/pull/session?topic=%s&client=%s&timeout=30000"
    STREAM_PATH = "/pull/read?topic=%s&client=%s&limit=1000&format=raw"
    COMMIT_PATH = '/pull/commit?client=%s'
    DATA_PORT = 8998
    SUGGEST_TIMEOUT = 10
    COMMIT_TIMEOUT = 5

    def __init__(self, host, client, topics, partitions_count, data_port=None, suggest_timeout=None, timeout=30,
                 suggest_balance=None):
        """Store connection parameters; no network activity happens here.

        :param host: host name used for the suggest request.
        :param client: client id substituted into every request path.
        :param topics: iterable of topic names (joined with ',' for suggest).
        :param partitions_count: number of partitions suggest must return.
        :param data_port: port appended to ``host``; ``DATA_PORT`` if falsy.
        :param suggest_timeout: suggest connection timeout;
            ``SUGGEST_TIMEOUT`` if falsy.
        :param timeout: timeout of the session/read/commit connection.
        :param suggest_balance: when truthy, sent as ``balance`` parameter and
            the chunked ("balanced") suggest protocol is used.
        """
        self.host = '%s:%s' % (host, data_port if data_port else self.DATA_PORT)
        self.client = client
        self.topics = topics
        self.partitions_count = partitions_count
        self.timeout = timeout
        self.suggest_timeout = suggest_timeout or self.SUGGEST_TIMEOUT
        self.suggest_balance = suggest_balance
        self._read_connection = None
        self._session = None

    @retriable_n(retry_count=2, time_sleep=1, exceptions=(TimeoutException, NetworkException))
    @backoff_retriable_n(delay_max=30, delay_min=0.5, exceptions=(SuggestException,))
    @LogbrokerNetworkErrorsWrapper
    def suggest(self):
        """Ask ``/pull/suggest`` which partitions this client should read.

        :returns: list of partition descriptors, each a list of string fields
            (``__enter__`` uses field 0 as the host and field 1 as the
            partition id).
        :raises SuggestException: on a non-200 status or an unusable answer.
        """
        log.info('Call /suggest/ with host: %s', self.host)
        log.info('Suggest params topic=%s count=%s client=%s balance=%s',
                 self.topics, self.partitions_count, self.client, self.suggest_balance)
        suggest_path = self.SUGGEST_PATH % (','.join(self.topics), self.client, self.partitions_count)
        balanced = bool(self.suggest_balance)
        if balanced:
            suggest_path += '&balance=%s' % self.suggest_balance
        log.info('Suggest path: %s', suggest_path)
        if balanced:
            return self._suggest_balanced_request(suggest_path)
        return self._suggest_request(suggest_path)

    def _suggest_request(self, suggest_path):
        """Plain suggest: one whitespace-separated descriptor per body line."""
        suggest_connection = HTTPConnection(self.host, timeout=self.suggest_timeout)
        try:
            # request() is inside the try so a failed send still closes the
            # connection.
            suggest_connection.request('GET', suggest_path)
            response = suggest_connection.getresponse()
            data = response.read()
            if response.status != 200:
                raise SuggestException('Suggest response code=%s data=%s' % (response.status, data))
            return [line.split() for line in data.split('\n') if line.strip()]
        finally:
            suggest_connection.close()

    def _suggest_balanced_request(self, suggest_path):
        """Balanced suggest: read chunks until an ``ans ...`` chunk arrives.

        ``ping`` chunks are skipped. The fields following ``ans`` form the
        single returned descriptor.

        :raises SuggestException: on a non-200 status, an empty or unexpected
            chunk, or when the stream ends without an answer.
        """
        suggest_connection = HTTPConnection(self.host, timeout=self.suggest_timeout)
        try:
            suggest_connection.request('GET', suggest_path)
            response = suggest_connection.getresponse()
            if response.status != 200:
                raise SuggestException('Suggest response code=%s data=%s' % (response.status, response.read()))

            for chunk in response.read_chunked(None):
                if not chunk:
                    raise SuggestException('Empty suggest chunk')
                chunk = chunk.strip()

                if chunk == 'ping':
                    log.info('Suggest returned ping, waiting')
                    continue
                parts = chunk.split()
                # ``parts`` is empty for a whitespace-only chunk.
                if parts and parts[0] == 'ans':
                    return [parts[1:]]
                raise SuggestException('Suggest unexpected chunk %s' % chunk)

            # Without this the method would return None and the caller would
            # fail with an unrelated TypeError.
            raise SuggestException('Suggest stream ended without an answer')
        finally:
            suggest_connection.close()

    @LogbrokerNetworkErrorsWrapper
    def __enter__(self):
        """Lock the suggested partitions and return the chunk generator.

        :raises SuggestException: if suggest returned a wrong partition count.
        :raises LockException: if ``/pull/session`` answered with a non-200
            status or without a ``Session`` header.
        """
        partitions = self.suggest()
        if len(partitions) != self.partitions_count:
            raise SuggestException(
                'Suggest return %s partitions but required %s' % (len(partitions), self.partitions_count),
            )

        hosts = [x[0] for x in partitions]
        partitions = [x[1] for x in partitions]
        # This may be bad when there are several partitions, but we try to
        # run as many workers as there are partitions.
        read_host = random.choice(hosts)

        # Try to create a session.
        log.info('Got partitions: %s use host: %s', partitions, read_host)
        self._read_connection = HTTPConnection(read_host, timeout=self.timeout)
        session_locked = False
        try:
            session_path = self.SESSION_PATH % (','.join(partitions), self.client)
            self._read_connection.request('GET', session_path)
            response = self._read_connection.getresponse()
            log.info('Got headers from /session/')
            data = response.read()
            if response.status != 200:
                raise LockException('/session/ returned bad status: %s (%s)' % (response.status, data))
            session = response.getheader('Session')
            if not session:
                raise LockException('Bad session value: %s' % session)
            session_locked = True
        finally:
            # __exit__ is not invoked when __enter__ raises, so the connection
            # has to be released here on failure.
            if not session_locked:
                self._read_connection.close()

        self._session = session
        log.info('Partitions were locked with session: %s', session)

        return self.chunk_reader(partitions)

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the read connection and drop the session id."""
        log.info('Close connection: %s %s', exc_type, exc_value)
        if self._read_connection is not None:
            self._read_connection.close()
        self._session = None

    def chunk_reader(self, partitions):
        """Generator issuing ``/pull/read`` requests in an endless loop.

        Yields ``(header, body)`` for every data chunk, where ``header`` is
        the first line of the chunk parsed by ``parse_tskv`` and ``body`` is
        the rest. Yields ``(None, None)`` each time a read response has been
        fully consumed.

        :raises LockException: if ``/pull/read`` returns a non-200 status.
        :raises BaseLogbrokerClientException: on a server-side Java exception
            in a chunk, a malformed chunk header, or an ``X-Logbroker-Error``
            trailer.
        """
        self.partitions = partitions
        read_path = self.STREAM_PATH % (','.join(partitions), self.client)
        log.info('Reader path: %s', read_path)
        while True:
            self._read_connection.request('GET', read_path, headers={'Session': self._session})
            response = self._read_connection.getresponse()
            if response.status != 200:
                # Read the body once: a second read() would return nothing
                # and the exception would lose the server message.
                error_body = response.read()
                log.error('response:%s, code:%s, session:%s', error_body, response.status, self._session)
                raise LockException('Cant lock while /read/', error_body)
            log.info('Reader session from response (%s): %s', partitions, response.getheader('Session'))
            log.info('Reader created with session (%s): %s', partitions, self._session)

            gen = response.read_chunked(None)
            while True:
                start_get_chunk = time.time()
                try:
                    chunk = next(gen)
                except StopIteration:
                    break
                log.info('Reader get-chunk-time (%s): %s size=%s', partitions, time.time() - start_get_chunk, len(chunk))
                if not chunk:
                    log.info('Empty chunk')
                    break
                if 'java.lang.Exception:' in chunk:
                    raise BaseLogbrokerClientException(chunk)
                try:
                    # A chunk without a newline fails the unpacking with
                    # ValueError and is reported like any other bad header.
                    raw_header, body = chunk.split('\n', 1)
                    header = parse_tskv(raw_header)
                except ValueError:
                    log.error('Bad header\n`%s`\n`%s`', chunk.split('\n', 1)[0], chunk)
                    raise BaseLogbrokerClientException('Bad read response')
                log.info(
                    'Chunk info: server=%s file=%s topic=%s partition=%s md5=%s size=%s',
                    header.get('server', ''),
                    header.get('path', ''),
                    header.get('topic', ''),
                    header.get('partition'),
                    md5(body).hexdigest(),
                    len(chunk),
                )
                yield header, body

            # The attribute is ``trailer``; a misspelled name here used to
            # disable this check entirely.
            trailer = getattr(response, 'trailer', None)
            if trailer:
                trailer_content = ', '.join(trailer)
                log.info('Trailer headers: %s', trailer_content)
                if 'X-Logbroker-Error' in trailer_content:
                    raise BaseLogbrokerClientException('Logbroker error in trailer headers: %s' % trailer_content)
            yield None, None

    @retriable_n(retry_count=1, time_sleep=0.2, exceptions=(TimeoutException, NetworkException, CommitException))
    @LogbrokerNetworkErrorsWrapper
    def commit(self, offsets):
        """Commit ``offsets`` within the current session.

        :param offsets: dict mapping a partition key to its offset; each item
            is sent as a ``key:offset`` line. Nothing is sent when empty.
        :raises CommitException: on a non-200 status or a body other than
            ``ok``.
        :raises KilledSession: if the server reports the session was killed.
        """
        if not offsets:
            return
        path = self.COMMIT_PATH % self.client
        body = '\n'.join('%s:%s' % (topic, offset) for topic, offset in offsets.items())

        log.debug('Start commit: %s', self._session)
        self._read_connection.request('POST', path, body=body, headers={'Session': self._session})
        response = self._read_connection.getresponse()

        data = response.read().strip()
        if response.status != 200:
            raise CommitException('Commit response code=%s data=%s' % (response.status, data))
        if data == 'session was killed':
            log.info('Commit for killed session')
            raise KilledSession('Commit for killed session')
        if data != 'ok':
            log.warning('Commit error: %s', repr(data))
            raise CommitException(data)
        log.info('Commited (%s): %s', self._session, offsets)


# Number of consecutive reads that returned only an empty chunk after which
# LogbrokerConsumer.read_unpacked fetches the current offsets and commits them.
EMPTY_READS_COUNT_BEFORE_COMMIT = 100


class LogbrokerConsumer(object):
    """Reads, decodes and hands chunks to a handler, committing after each read."""

    def __init__(self, hosts, client, topics, partitions_count, data_port=None, suggest_timeout=None,
                 suggest_balance=None, decoder_class=AutoDecompressor,
                 stop_on_commit=False):
        """Store consumer settings.

        :param hosts: sequence of hosts; one is picked at random per
            ``read_unpacked`` call.
        :param client: client id passed to the reader.
        :param topics: topics passed to the reader.
        :param partitions_count: partition count passed to the reader.
        :param data_port: port passed to the reader (reader default if None).
        :param suggest_timeout: suggest timeout passed to the reader.
        :param suggest_balance: suggest balance value passed to the reader.
        :param decoder_class: zero-argument factory of an object with a
            ``decode(rawdata)`` method; ``AutoDecompressor`` if falsy.
        :param stop_on_commit: when true, ``read_unpacked`` returns after the
            first commit.
        """
        self.hosts = hosts
        self.data_port = data_port
        self.client = client
        self.topics = topics
        self.partitions_count = partitions_count
        self.suggest_timeout = suggest_timeout
        self.suggest_balance = suggest_balance
        self.decoder_class = decoder_class or AutoDecompressor
        self.stop_on_commit = stop_on_commit

    @LogbrokerNetworkErrorsWrapper
    def read_unpacked(self, handler):
        """Consume chunks in one session and feed them to ``handler``.

        ``handler.process(header, unpacked)`` is called for every data chunk
        and should return a truthy value when it flushed. At the end of every
        read ``handler.flush(force=True)`` is called and the offsets seen
        during that read are committed.

        Returns only when ``stop_on_commit`` is set; otherwise runs until an
        exception is raised.
        """
        working_host = random.choice(self.hosts)
        reader = LogbrokerReader(
            working_host,
            self.client,
            self.topics,
            self.partitions_count,
            self.data_port,
            suggest_timeout=self.suggest_timeout,
            suggest_balance=self.suggest_balance,
        )
        decoder = self.decoder_class()
        with reader as chunk_stream:
            offsets = {}
            flush_count = 0
            empty_successive_reads_count = 0
            got_non_empty_chunk = False
            # In the loop /pull/read and /pull/commit are called alternately
            # within a single session.
            for header, rawdata in chunk_stream:
                if header is not None and rawdata is not None:
                    unpacked = decoder.decode(rawdata)
                    log.debug('Packed: %s Unpacked: %s', len(rawdata), len(unpacked))
                    is_flushed = handler.process(header, unpacked)

                    topic = header['topic']
                    partition = header['partition']
                    offsets['%s:%s' % (topic, partition)] = header['offset']
                    got_non_empty_chunk = True

                    if is_flushed:
                        flush_count += 1
                    continue

                # (None, None) marks the end of a read: flush and commit after
                # the read finished successfully.
                handler.flush(force=True)
                log.info('Flush count per read: %s', flush_count + 1)
                reader.commit(offsets)
                offsets = {}
                flush_count = 0

                if self.stop_on_commit:
                    return

                # Count consecutive reads that returned only an empty chunk.
                if got_non_empty_chunk:
                    empty_successive_reads_count = 0
                    got_non_empty_chunk = False
                else:
                    empty_successive_reads_count += 1

                if empty_successive_reads_count < EMPTY_READS_COUNT_BEFORE_COMMIT:
                    continue

                # Threshold reached: read the current offsets and commit them
                # so that they do not expire (STATADMIN-4304).
                log.info('Too many empty reads, commiting offsets')
                meta = LogbrokerMeta(working_host, self.client)
                # Use the same port fallback as the reader; passing None
                # through would produce the address "host:None".
                offsets_port = self.data_port or LogbrokerReader.DATA_PORT
                try:
                    info = meta.get_offsets_info(self.client, reader.partitions, offsets_port)
                except Exception as e:
                    # Best effort: the counter is left untouched, so the next
                    # empty read tries again.
                    log.error('Failed to get offsets info', exc_info=e)
                else:
                    offsets = {item['partition']: item['offset'] for item in info}
                    reader.commit(offsets)
                    offsets = {}
                    empty_successive_reads_count = 0


class LogbrokerMeta(object):
    """Thin client for Logbroker metadata endpoints (hosts, parts, offsets)."""

    def __init__(self, balancer_host, client):
        """Remember the balancer host and the client id."""
        self.balancer_host = balancer_host
        self.client = client

    def hosts_info(self, dc, timeout=10):
        """Return host names reported by ``/hosts_info`` for datacenter ``dc``.

        The first tab-separated field of every non-empty response line is
        taken as the host name.

        :raises Exception: if the response status is not 200.
        """
        response = requests.get(
            'http://%s/hosts_info' % self.balancer_host,
            params={'dc': dc},
            timeout=timeout,
        )
        if response.status_code != 200:
            raise Exception('Bad hosts_info response in dc %s: %s (%s)' % (dc, response.status_code, response.content))
        data = response.content
        return [x.split('\t')[0] for x in data.split('\n') if x]

    def show_parts(self, dc=None, ident=None, timeout=10):
        """Return the stripped non-empty lines of ``/pull/list?show-parts=1``.

        :raises Exception: if the response status is not 200.
        """
        response = requests.get(
            'http://%s/pull/list' % self.balancer_host,
            params={'ident': ident, 'dc': dc, 'show-parts': '1'},
            timeout=timeout,
        )
        if response.status_code != 200:
            raise Exception('Bad /pull/list response: %s (%s)' % (response.status_code, response.content))
        data = response.content
        return [x.strip() for x in data.split('\n') if x.strip()]

    def get_offsets_info(self, client, topics, data_port, dc=None, timeout=10):
        """Fetch per-partition offsets for ``client`` from ``/pull/offsets``.

        :param client: client id sent as the ``client`` parameter.
        :param topics: iterable of strings joined with ',' into ``topic``.
        :param data_port: port of the queried host; omitted from the address
            when falsy.
        :param dc: when given, a random host from ``hosts_info(dc)`` is
            queried instead of the balancer host.
        :param timeout: timeout in seconds for the HTTP requests.
        :returns: list of dicts with keys ``partition``, ``offset``, ``start``,
            ``size``, ``lag`` (ints except ``partition``) and ``owner``,
            sorted by ``partition``.
        :raises OffsetsException: if no host is available in ``dc`` or the
            response status is not 200.
        """
        if dc:
            hosts = self.hosts_info(dc, timeout=timeout)
            if not hosts:
                raise OffsetsException('No hosts found in dc %s' % dc)
            target = random.choice(hosts)
        else:
            target = self.balancer_host
        # Never build "host:None" when no port was supplied.
        host = '%s:%s' % (target, data_port) if data_port else target

        response = requests.get(
            'http://%s/pull/offsets' % host,
            params={'client': client, 'topic': ','.join(topics)},
            timeout=timeout,
        )

        if response.status_code != 200:
            log.error(
                'Failed to read offsets (host %s): code %s, content "%s"',
                host,
                response.status_code,
                response.content,
            )
            raise OffsetsException(
                'Failed to read offsets (host %s): code %s' % (host, response.status_code),
            )

        info = []
        for line in response.content.split('\n'):
            fields = line.split()
            if not fields:
                continue
            partition, offset, start, size, lag, owner = fields
            # The values are unused; the unpacking only checks that the
            # partition id contains exactly one ':' (ValueError otherwise).
            _topic, _number = partition.split(':')
            info.append({
                'partition': partition,
                'offset': int(offset),
                'start': int(start),
                'size': int(size),
                'lag': int(lag),
                'owner': owner,
            })
        info.sort(key=lambda item: item['partition'])
        return info
