# -*- coding: utf-8 -*-

import os
import six
import time
import signal
import traceback
import logging
import threading
import multiprocessing
import math

from sandbox.sandboxsdk import parameters as sp

from sandbox.common import errors as common_errors
from sandbox.common import system
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import utils


class ThreadPoolFailureError(common_errors.TaskFailure):
    """Raised by worker_results() when a worker failed with a TaskFailure."""


class TimeoutException(common_errors.TaskError):
    """Raised by join_workers() when workers are still alive after the timeout."""


_mp_title = 'Multiprocessing params'


# Selects the worker type used by process_data(): processes (True, default) or threads (False).
# Only read when process_data() is given a ctx; it then overrides the use_processes argument.
class UseMultiprocessing(sp.SandboxBoolParameter):
    name = 'use_multiprocessing'
    description = 'Use multiprocessing'
    group = _mp_title
    default_value = True


# Number of chunks (and therefore workers) used by process_data(); 0 falls back to the
# CPU count reported by system.get_sysparams(). Only read when process_data() is given a ctx.
class ProcessCount(sp.SandboxIntegerParameter):
    name = 'process_count'
    # State what the default 0 means: get_cpu_chunks() replaces a falsy count with ncpu.
    description = 'Process count (0 = use all CPUs)'
    group = _mp_title
    default_value = 0


# Percentage that process_data() applies to ProcessCount (or to the CPU count when
# ProcessCount is 0). 0 disables scaling. Only read when process_data() is given a ctx.
class ScaleProcessCount(sp.SandboxIntegerParameter):
    name = 'scale_process_count'
    # The previous text claimed "0 = use all CPUs", but a scale of 0 just leaves
    # ProcessCount untouched (which is all CPUs only when ProcessCount is 0 too).
    description = 'Scale process count (in percent, 0 = no scaling)'
    group = _mp_title
    default_value = 0


# When enabled, process_data() creates a misc resource directory via
# utils.create_misc_resource_and_dir() and every failing worker appends its traceback
# to "<dir>/<pid>.txt" there. Only read when process_data() is given a ctx.
class SaveMultiprocessingErrorsInFiles(sp.SandboxBoolParameter):
    name = 'save_multiprocessing_errors_in_files'
    description = 'Save errors in files'
    group = _mp_title
    default_value = False


# Every multiprocessing-related task parameter; process_data(ctx=...) reads each of them from ctx.
PARAMS = (UseMultiprocessing, ProcessCount, ScaleProcessCount, SaveMultiprocessingErrorsInFiles)


def _get_chunks(seq, count):
    """
        Split an indexable sequence into at most ``count`` consecutive chunks of equal size
        (the last one may be shorter). Each chunk is a lazy generator over the items.

        [1,2,3,4,5,6,7,8], count=3 -> generator([1,2,3], [4,5,6], [7,8]) of generators

        A zero or negative ``count`` yields a single chunk holding every element.
        Returns [] when ``seq`` is empty.
    """
    if not seq:
        return []

    total = len(seq)
    # Clamp to at least one chunk: a negative count used to give a negative step below,
    # i.e. no chunks at all, silently dropping every element.
    count = max(1, min(total, count))
    size = int(math.ceil(float(total) / count))
    # Log the number of chunks actually produced (may be fewer than requested,
    # e.g. 12 items with count=5 -> 4 chunks of 3).
    chunk_count = (total + size - 1) // size
    logging.info(
        "Partitioning: List len = %s, # of chunks = %s by %s elements",
        total, chunk_count, size
    )
    # The inner range is evaluated eagerly when each generator is created, so `start`
    # is bound per chunk; it is bounded by `total` instead of filtering out-of-range indices.
    return (
        (seq[i] for i in six.moves.xrange(start, min(start + size, total)))
        for start in six.moves.xrange(0, total, size)
    )


def get_cpu_chunks(seq, num_of_proc=None):
    """
        Split ``seq`` into chunks, one per process.

        num_of_proc - desired number of chunks; when falsy (None or 0) the
                      'ncpu' value from system.get_sysparams() is used instead
    """
    if not num_of_proc:
        num_of_proc = system.get_sysparams()['ncpu']
    return _get_chunks(seq, num_of_proc)


def _wrapper_func(func, chunk, params):
    results = None
    error = None

    try:
        results = func(chunk, params)

    except Exception as e:
        msg = traceback.format_exc()
        logging.error("threadPool - exception in worker: %s", msg)

        error = (e, msg, time.time())

        if params.SaveErrorsInDir:
            file_name = os.path.join(params.SaveErrorsInDir, "{}.txt".format(os.getpid()))
            fu.append_file(file_name, "{}\n{}".format(msg, "-" * 100))

    return results, error


class ThreadWorker(threading.Thread):
    """
        Thread-based worker: runs func(chunk, params) through _wrapper_func
        and keeps the outcome in Results / Error.

        Deprecated but will not be deleted:
        may be useful for debug purposes
    """

    def __init__(self, func, chunk, params):
        threading.Thread.__init__(self)

        self.Func = func
        self.Chunk = chunk
        self.Params = params

        # Set by run(); remain [] / None until the thread has run.
        self.Results = []
        self.Error = None

    def run(self):
        logging.info("threadPool - thread %s started", self)
        self.Results, self.Error = _wrapper_func(self.Func, self.Chunk, self.Params)
        logging.info("threadPool - thread %s ended", self)

    def IsAlive(self):
        # Thread.isAlive() was removed in Python 3.9; is_alive() exists since Python 2.6.
        return self.is_alive()

    def GetResults(self):
        return self.Results

    def GetError(self):
        """Return None on success, else (exception, traceback text, timestamp)."""
        return self.Error


def _proc_wrapper_func(func, chunk, params, remote_dict):
    """
        Child-process entry point used by ProcessWorker.

        Restores the default SIGTERM/SIGINT handlers, runs func(chunk, params) via
        _wrapper_func and publishes the outcome into remote_dict (the dict proxy the
        parent reads) under the "error" and "results" keys.
    """
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, signal.SIG_DFL)
    results, error = _wrapper_func(func, chunk, params)

    try:
        # Publish the error before the results, so a worker error is not lost
        # when the results cannot be stored.
        remote_dict["error"] = error
        remote_dict["results"] = results
    except Exception as e:
        logging.error("Can't fill worker results:\n%s", e)
        # Report the storage failure to the parent; otherwise it only sees a generic
        # "No stack trace" error and the real reason stays in this process's log.
        details = traceback.format_exc()
        if error is not None:
            # Keep the worker's own traceback, which is the more useful one.
            details = "{}\nFailed to pass this error to the parent process:\n{}".format(error[1], details)
        try:
            remote_dict["error"] = (RuntimeError(str(e)), details, time.time())
        except Exception as report_error:
            logging.error("Can't report worker failure:\n%s", report_error)


class ProcessWorker(multiprocessing.Process):
    """
        Process-based worker: the child runs _proc_wrapper_func(func, chunk, params, ServerDict).
        ServerDict is a dict proxy from a multiprocessing.Manager created for this worker;
        the child writes its "error" and "results" there for the parent to read.
    """

    def __init__(self, func, chunk, params):
        shared = multiprocessing.Manager().dict()
        self.ServerDict = shared

        multiprocessing.Process.__init__(
            self,
            target=_proc_wrapper_func,
            args=(func, chunk, params, shared),
        )

    def __repr__(self):
        state = "alive" if self.is_alive() else "dead"
        return '<ProcessWorker({}, {})>'.format(self.pid, state)

    def IsAlive(self):
        return self.is_alive()

    def GetResults(self):
        return self.ServerDict["results"]

    def GetError(self):
        """
            Return the (exception, traceback text, timestamp) tuple stored by the child,
            None if it stored no error, or a synthetic TaskFailure tuple if the child
            never wrote the "error" key at all.
        """
        if "error" in self.ServerDict:
            return self.ServerDict["error"]
        placeholder = "ProcessWorker: No stack trace"
        return common_errors.TaskFailure(placeholder), placeholder, time.time()


def join_workers(workers, timeout):
    """
        Block until every worker has finished.

        workers - ThreadWorker / ProcessWorker instances (anything with join(timeout) and IsAlive())
        timeout - overall limit in seconds, or None to wait indefinitely;
                  TimeoutException is raised once it is exceeded
    """
    logging.info("threadPool - joining workers: %s", workers)

    started_at = time.time()
    pending = workers

    while True:
        for w in pending:
            w.join(0.5)

        pending = [w for w in pending if w.IsAlive()]
        logging.info("threadPool - alive workers: %s", pending)

        if not pending:
            break

        if timeout is not None and time.time() - started_at > timeout:
            raise TimeoutException("timed out. alive workers: {}".format(pending))

        time.sleep(10)

    logging.info("threadPool - joined all workers")


class FakeObj(object):
    """Empty attribute holder that process_data() uses as ``params`` when given None."""


def process_data(
    func,
    list_to_process,
    params,
    timeout=None,
    default_exception_type=common_errors.TaskError,
    use_processes=True,
    process_count=None,
    ctx=None,
):
    """
        Split list_to_process into chunks, run func(chunk, params) for each chunk in its own
        worker and return the merged results (see worker_results).

        func - callable(chunk, params); chunk is a generator over consecutive items
        list_to_process - indexable sequence to split
        params - object handed to func, or None for an empty FakeObj;
                 its SaveErrorsInDir attribute is overwritten here
        timeout - limit in seconds for waiting on all workers (None = no limit)
        default_exception_type - raised for worker errors that are not TaskFailure
        use_processes - ProcessWorker (True) or ThreadWorker (False)
        process_count - number of chunks/workers; None or 0 = CPU count
        ctx - task context; when given, UseMultiprocessing / ProcessCount /
              ScaleProcessCount / SaveMultiprocessingErrorsInFiles read from it
              override use_processes and process_count

        Raises TimeoutException, ThreadPoolFailureError or default_exception_type.
    """
    if params is None:
        params = FakeObj()
    params.SaveErrorsInDir = False

    if ctx:
        use_processes = utils.get_or_default(ctx, UseMultiprocessing)
        process_count = utils.get_or_default(ctx, ProcessCount)
        scale_num_of_proc = utils.get_or_default(ctx, ScaleProcessCount)
        if scale_num_of_proc:
            if not process_count:
                process_count = system.get_sysparams()['ncpu']
            # Floor division: on Python 3 "/" gives a float (8 * 30 / 100 == 2.4), which makes
            # _get_chunks produce more chunks (processes) than requested and defeats the
            # "at least one worker" fallback for fractions. Clamp to one worker minimum;
            # a negative value would otherwise produce no chunks and drop all data.
            process_count = max(1, process_count * scale_num_of_proc // 100)

        if utils.get_or_default(ctx, SaveMultiprocessingErrorsInFiles):
            params.SaveErrorsInDir = utils.create_misc_resource_and_dir(
                ctx, 'multiprocessing_errors_resource_id', 'multiprocessing errors', 'multiprocessing_errors')

    logging.info("threadPool - processing started")

    workers = []

    for chunk in get_cpu_chunks(list_to_process, process_count):
        if use_processes:
            worker = ProcessWorker(func, chunk, params)
        else:
            worker = ThreadWorker(func, chunk, params)

        worker.start()
        workers.append(worker)

    logging.info("threadPool - workers started")
    join_workers(workers, timeout)
    logging.info("threadPool - workers ended")
    results = worker_results(workers, use_processes, default_exception_type)
    logging.info("threadPool - processing ended")

    return results


def worker_results(workers, use_processes, default_exception_type):
    """
        Collect the results of finished workers into one flat list.

        A worker result that is exactly a list is spliced in; any other value is
        appended as a single item. The first worker that reports an error aborts
        collection: ThreadPoolFailureError is raised for TaskFailure errors,
        default_exception_type for everything else.
    """
    collected = []
    for w in workers:
        failure = w.GetError()
        if failure:
            exc, details, _ = failure
            if isinstance(exc, common_errors.TaskFailure):
                exc_type = ThreadPoolFailureError
            else:
                exc_type = default_exception_type
            raise exc_type("An error occurred in worker ({}):\n{}".format(w, details))

        payload = w.GetResults()
        # Exact type check on purpose: list subclasses are appended, not spliced.
        if type(payload) is list:
            collected.extend(payload)
            payload_descr = 'list of {} items'.format(len(payload))
        else:
            collected.append(payload)
            payload_descr = str(type(payload))

        worker_descr = '{} (PID: {})'.format(w, w.pid) if use_processes else str(w)
        logging.info('results from worker %s: %s', worker_descr, payload_descr)

    return collected
