# -*- coding: utf-8 -*-
import time
import logging
import itertools
import weakref

from kazoo.client import KazooClient
from kazoo.retry import KazooRetry
from kazoo.exceptions import ZookeeperError

from logbroker_client.utils import importobj
from logbroker_client.handlers.base import BaseHandler
from logbroker_client.runner.workers.base import Worker
from logbroker_client.logbroker.client import (
    LogbrokerConsumer,
    BaseLogbrokerClientException,
)
from logbroker_client.handlers.exceptions import HandlerException
from logbroker_client.consumers.exceptions import ExitException


# Module-level logger; DistributedLogbrokerWorker reports ZooKeeper/partitioner state through it.
log = logging.getLogger(__name__)


class PartitionerException(Exception):
    """Signal that the worker's partitioner no longer holds its acquired set.

    Raised by ``DistributedHandlerDecorator.process`` and caught by
    ``DistributedLogbrokerWorker`` as a request to renegotiate partitions.
    """


class DistributedHandlerDecorator(BaseHandler):
    """Gate a real handler on the state of the owning distributed worker.

    Before each ``process`` call the worker is consulted: an interrupted worker
    produces ``ExitException('Interrupted')``, and a worker whose partitioner is
    not in the acquired state produces ``PartitionerException``. Only when both
    checks pass is the call forwarded. ``flush`` is always forwarded as-is.
    """

    def __init__(self, worker, handler):
        # handler: the wrapped handler that actually receives data/flush calls.
        # worker: the owning worker (DistributedLogbrokerWorker passes a weakref.proxy).
        self.handler = handler
        self.worker = worker

    def process(self, header, data):
        owner = self.worker
        if owner.INTERRUPTED:
            raise ExitException('Interrupted')
        if owner.partitioner.acquired:
            return self.handler.process(header, data)
        # Partitions were lost (e.g. a rebalance started): make the consumer stop.
        raise PartitionerException()

    def flush(self, force=False):
        wrapped = self.handler
        return wrapped.flush(force)


class DistributedLogbrokerWorker(Worker):
    """Logbroker consumer that shares a partition set with peers through ZooKeeper.

    Each ``loop()`` call takes one task from ``self.tasks_queue`` (presumably
    provided by the base ``Worker``), opens a ZooKeeper session, negotiates this
    worker's share of ``task['partitions']`` with a kazoo ``SetPartitioner`` and
    consumes Logbroker with that many partitions. When the share is lost the
    wrapped handler raises ``PartitionerException`` and the share is renegotiated.
    """

    ZK_TIMEOUT = 10  # session timeout passed to KazooClient (seconds)
    ZK_START_TIMEOUT = 60  # how long KazooClient.start() may wait to connect (seconds)
    TIME_FOR_ALLOCATION = 120  # max seconds spent in the "allocating" state before giving up

    def __init__(self, handler):
        """Build the worker from a handler description.

        :param handler: mapping with ``'class'`` (import path resolved via
            ``importobj``) and optional ``'args'`` (constructor keyword arguments).
        """
        handler_cls = importobj(handler.get('class'))
        # Missing/None 'args' means "no arguments"; `**None` would raise TypeError.
        handler_args = handler.get('args') or {}
        self.handler = DistributedHandlerDecorator(
            # A proxy avoids a strong worker <-> decorator reference cycle.
            worker=weakref.proxy(self),
            handler=handler_cls(**handler_args),
        )
        self.zk_client = None
        self.partitioner = None

        self.INTERRUPTED = False
        super(DistributedLogbrokerWorker, self).__init__()

    def handle_sigint(self, signum, frame):
        """Request a stop; the consume and negotiation loops check this flag."""
        self.INTERRUPTED = True

    def read_task(self):
        """Return the next task from ``self.tasks_queue``.

        ``loop()`` treats an ``IOError`` raised here as "no task available".
        """
        return self.tasks_queue.get()

    @staticmethod
    def _partitioner(identifier, members, partitions):
        """Partition function handed to kazoo's ``SetPartitioner``.

        Orders the workers and deals the sorted partitions out round-robin,
        returning ``identifier``'s share (``partitions[i::len(workers)]``).
        When there are more members than partitions, workers are interleaved
        host by host (first worker of every host, then the second, ...) so the
        partitions that exist are spread across hosts.

        Member ids are expected as ``'<hostname>-<pid>'``; an id without ``'-'``
        is treated as a host of its own.

        :raises ValueError: if ``identifier`` is not one of ``members``.
        """
        all_partitions = sorted(partitions)

        if len(members) <= len(all_partitions):
            workers = sorted(members)
        else:
            by_host = {}
            for member in members:
                hostname, sep, _ = member.rpartition('-')
                by_host.setdefault(hostname if sep else member, []).append(member)
            host_groups = [sorted(by_host[host]) for host in sorted(by_host)]
            # Portable equivalent of flattening izip_longest(*host_groups) minus padding.
            depth = max(len(group) for group in host_groups)
            workers = [
                group[idx]
                for idx in range(depth)
                for group in host_groups
                if idx < len(group)
            ]

        position = workers.index(identifier)
        # Start at our own slot and skip over the other workers' slots.
        return all_partitions[position::len(workers)]

    def create_partitioner(self, partitions, client, key):
        """Create a kazoo ``SetPartitioner`` for ``partitions``.

        The partitioner lives under ``/partitions_queue/<client>/<key>`` and
        uses :meth:`_partitioner` to compute this worker's share.
        """
        path = '/partitions_queue/{client}/{key}'.format(client=client, key=key)
        return self.zk_client.SetPartitioner(
            path=path,
            set=partitions,
            partition_func=DistributedLogbrokerWorker._partitioner,
        )

    def get_partitions_count(self, partitions, client, lock_key):
        """Drive the partitioner until it has acquired this worker's share.

        Creates the partitioner on first use, recreates it after it fails and
        releases the set when it is in the ``release`` state. ``ZookeeperError``
        is logged and the state machine is polled again after one second.

        :returns: number of partitions now owned (may be 0), or ``None`` if the
            worker was interrupted before acquisition.
        :raises IOError: if allocation takes longer than ``TIME_FOR_ALLOCATION``.
        """
        if self.partitioner is None:
            self.partitioner = self.create_partitioner(partitions, client, lock_key)

        start_allocation_time = None
        while not self.INTERRUPTED:
            try:
                if self.partitioner.failed:
                    log.warning('Failed partitioner')
                    self.partitioner = self.create_partitioner(partitions, client, lock_key)
                    log.info('Recreated partitioner after fail')
                elif self.partitioner.release:
                    self.partitioner.release_set()
                    log.info('Released partitioner set')
                elif self.partitioner.acquired:
                    partitions_count = len(list(self.partitioner))
                    log.info('Acquired %s partitions', partitions_count)
                    return partitions_count
                elif self.partitioner.allocating:
                    log.info('Allocating partitions...')
                    # Workaround: a flapping ZooKeeper can break allocation,
                    # so give up after TIME_FOR_ALLOCATION seconds in total.
                    if start_allocation_time is None:
                        start_allocation_time = time.time()
                    elif time.time() - start_allocation_time > self.TIME_FOR_ALLOCATION:
                        raise IOError('Allocation failed')
                    self.partitioner.wait_for_acquire()
            except ZookeeperError as e:
                log.warning(e)
            time.sleep(1)
        return None

    def loop(self):
        """Take one task from the queue and consume it until stopped.

        Whatever way consumption ends (interrupt, ``ExitException`` or an
        exception such as allocation timeout), the partitioner is finished and
        the ZooKeeper client is stopped and closed, so neither is leaked nor
        reused for a later task.
        """
        log.info('Run worker')
        try:
            task = self.read_task()
        except IOError:
            log.warning("Couldn't receive task from queue")
            return

        log.info('Got task %s', task)

        # Defensive: never carry a client/partitioner over from an earlier task.
        self._shutdown_zk()
        self.zk_client = self._create_zk_client(task['zk_hosts'])
        try:
            self.zk_client.start(timeout=self.ZK_START_TIMEOUT)
            self._consume(task)
        finally:
            self._shutdown_zk()

    def _create_zk_client(self, hosts):
        """Build (without starting) a KazooClient with the worker's retry policy."""
        return KazooClient(
            hosts=hosts,
            # Separate retry objects for connection and commands, as before.
            connection_retry=KazooRetry(max_tries=5, delay=0.5, backoff=2),
            command_retry=KazooRetry(max_tries=5, delay=0.5, backoff=2),
            timeout=self.ZK_TIMEOUT,
        )

    def _consume(self, task):
        """Consume Logbroker with the currently owned partition count until stopped."""
        partitions = task['partitions']
        partitions_count = 0
        need_reallocation = True
        while not self.INTERRUPTED:
            if need_reallocation or (self.partitioner is not None and not self.partitioner.acquired):
                partitions_count = self.get_partitions_count(
                    partitions, task['client'], task['lock_key'])
            if not partitions_count:
                # None means we were interrupted while negotiating: exit now
                # rather than sleeping first.
                if self.INTERRUPTED:
                    break
                # Zero partitions assigned to this worker; check again later.
                time.sleep(5)
                continue
            need_reallocation = False

            consumer = LogbrokerConsumer(
                task['hosts'],
                task['client'],
                task['topics'],
                partitions_count,
                task['data_port'],
            )
            try:
                consumer.read_unpacked(self.handler)
            except PartitionerException:
                # The handler decorator saw the partitioner leave "acquired".
                need_reallocation = True
            except (BaseLogbrokerClientException, HandlerException) as e:
                log.warning('%s: %s', e.__class__.__name__, e)
                time.sleep(0.1)
            except ExitException:
                log.info('Worker terminated')
                break

    def _shutdown_zk(self):
        """Best-effort: finish the partitioner and stop/close the ZooKeeper client."""
        partitioner, self.partitioner = self.partitioner, None
        zk_client, self.zk_client = self.zk_client, None
        if partitioner is not None:
            try:
                partitioner.finish()
            except Exception as e:  # cleanup must not mask the original error
                log.warning('Failed to finish partitioner: %s', e)
        if zk_client is not None:
            try:
                zk_client.stop()
                zk_client.close()
            except Exception as e:  # cleanup must not mask the original error
                log.warning('Failed to stop ZooKeeper client: %s', e)
