# -*- coding: utf-8 -*-

import logging
import os
import re
import six.moves.queue
import subprocess
import sys
import threading

import yt.wrapper as yt  # for @yt.aggregator decorator


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# thread-safe rows iteration
class _safe_iterator(object):
    def __init__(self, data):
        self.lock = threading.Lock()
        self.it = iter(data)

    def __iter__(self):
        return self

    def next(self):
        with self.lock:
            return self.it.next()


@yt.aggregator
class SubprocessMapper(object):
    """Python mapper class for yt.wrapper.
    Forks custom command to process rows.
    Based on sashateh's SubprocessFieldedStreamMapper.

    Every input row is written to the stdin of one of the sub-processes as a
    single line of tab-separated ``key=value`` pairs (see ``_encode``). Every
    line a sub-process prints to stdout is parsed the same way (see
    ``_decode``) and converted into an output row according to
    ``dst_fields`` / ``dst_options``.
    """

    # Text that represents None in the line protocol (both directions).
    NONE_VALUE = '<NONE>'

    def __init__(self, command, **kwargs):
        """Arguments:
            command     - command for Popen (string or tuple of args)
        kwargs:
            process_count - number of sub-processes to run (default 1)
            dst_fields  - list of dicts defining field names/types for output tables, e.g.: [{"id":int,"name":str},...]
            dst_options - options for dst tables (list of dicts, one per table):
                            'unknown_as_string':  output unknown fields as string
            init_pause  - timeout in seconds for first output (default 2 * 3600)
            run_pause   - timeout in seconds for output (default 2 * 600)
            ...         - other kwargs for Popen
        Raises:
            Exception   - if neither dst_fields nor dst_options is given, or if a
                          table has no field description and 'unknown_as_string' is not set.
        """
        self._none_value = SubprocessMapper.NONE_VALUE

        self.command = command  # may be redefined

        dst_options = kwargs.pop('dst_options', None)
        dst_fields = kwargs.pop('dst_fields', None)

        if dst_fields is None and dst_options is None:
            raise Exception("Define dst_fields or dst_options!")
        if dst_fields is None:
            dst_fields = [None for _ in dst_options]
        else:
            # Work on a copy: missing entries are filled in below and the
            # caller's sequence must not be modified (it may also be a tuple).
            dst_fields = list(dst_fields)
        if dst_options is None:
            dst_options = [dict() for _ in dst_fields]

        for ti, opts in enumerate(dst_options):
            if dst_fields[ti] is None:
                if opts.get('unknown_as_string'):
                    dst_fields[ti] = {}  # no required fields
                else:
                    raise Exception("Can't derive dst_fields for table {}".format(ti))

        self.dst_options = dst_options
        self.dst_fields = dst_fields  # may be used in derived class

        self._process_count = kwargs.pop('process_count', 1)
        self._init_pause = kwargs.pop('init_pause', 2 * 3600)
        self._run_pause = kwargs.pop('run_pause', 2 * 600)
        self._kwargs = kwargs  # everything left is passed to Popen

    def __call__(self, rows):
        """Run the sub-processes over ``rows`` and yield converted output rows.

        The '@table_index' key of a decoded line (default '0') selects the
        entry of ``dst_fields`` / ``dst_options`` used for conversion. Every
        yielded dict carries '@table_index' as an int.
        """
        self.init_command()
        self._init_processes()
        self._safe_rows_iter = _safe_iterator(rows)
        for output in self.communicate():
            table_index = int(output.get('@table_index', '0'))
            table_fields = self.dst_fields[table_index]
            table_opts = self.dst_options[table_index]

            output_row = {}
            if table_opts.get('unknown_as_string'):
                # Pass every field through; convert only the described ones.
                for fname, fval in output.items():
                    if fval is None:
                        output_row[fname] = None
                    elif fname in table_fields:
                        output_row[fname] = table_fields[fname](fval)
                    else:
                        output_row[fname] = fval  # string
            else:
                # Only described fields are emitted; a missing one raises KeyError.
                for fname, ftype in table_fields.items():
                    val = output[fname]
                    output_row[fname] = ftype(val) if val is not None else None

            output_row['@table_index'] = table_index
            yield output_row

    def init_command(self):
        """Hook called at the start of ``__call__``; does nothing by default."""
        pass  # define in derived class

    def _init_processes(self):
        """Start ``process_count`` sub-processes with piped stdin/stdout.

        The environment is a copy of ``os.environ`` updated with
        ``self.command_env`` when that attribute exists. 'bufsize' defaults
        to -1 unless given in the Popen kwargs.
        """
        env = os.environ.copy()
        if hasattr(self, 'command_env'):
            env.update(self.command_env)

        sub_kwargs = self._kwargs.copy()
        sub_kwargs.setdefault('bufsize', -1)
        processes = {}
        for i in range(self._process_count):
            processes[i] = subprocess.Popen(self.command, stdout=subprocess.PIPE, stdin=subprocess.PIPE, env=env, **sub_kwargs)
        self.processes = processes

    # http://stackoverflow.com/questions/375427/non-blocking-read-on-a-subprocess-pipe-in-python/4896288#4896288
    def communicate(self):
        """Feed input rows to the sub-processes and yield their decoded output.

        One writer and one reader thread are started per sub-process; all of
        them report to the consuming generator through a bounded queue.

        Raises:
            Exception - when no queue item arrives within the current timeout
                        (``init_pause`` before the first output, ``run_pause``
                        afterwards), or when a worker thread reported an error.
                        In both cases the sub-processes are terminated first.
        """
        main_queue = six.moves.queue.Queue(maxsize=50000)

        # Write into external subprocesses
        def writer(sub_id):
            try:
                p = self.processes[sub_id]
                for row in self._safe_rows_iter:
                    if p.poll() is not None:
                        raise Exception("process %d finished too early" % sub_id)
                    line = self._encode(row)
                    # NOTE(review): a str is written to the pipe; on Python 3 this
                    # presumably requires text-mode Popen kwargs — confirm.
                    p.stdin.write(line + "\n")
                p.stdin.close()
            except Exception as e:
                logger.exception("Some error during writing: %s", e)
                main_queue.put({"exc": e, "exc_info": sys.exc_info()}, block=True)

        # Read from external subprocesses
        def reader(sub_id):
            try:
                p = self.processes[sub_id]
                for line in iter(p.stdout):
                    row = self._decode(line.rstrip("\n"))
                    main_queue.put({"output": row}, block=True)
                p.stdout.close()
                p.wait()
                if p.returncode != 0:
                    raise Exception("Subprocess exited with non-zero code!")
            except Exception as e:
                logger.exception("Some error during reading: %s", e)
                main_queue.put({"exc": e, "exc_info": sys.exc_info()}, block=True)

        # Signal the consumer once every writer and reader has finished.
        def waiter(threads):
            for t in threads:
                t.join()
            main_queue.put({"all_done": 1}, block=True)

        writer_threads = []
        for sub_id in self.processes:
            writer_thread = threading.Thread(target=writer, args=[sub_id])
            writer_thread.daemon = True
            writer_thread.start()
            writer_threads.append(writer_thread)

        reader_threads = []
        for sub_id in self.processes:
            reader_thread = threading.Thread(target=reader, args=[sub_id])
            reader_thread.daemon = True
            reader_thread.start()
            reader_threads.append(reader_thread)

        waiter_thread = threading.Thread(target=waiter, args=[writer_threads + reader_threads])
        waiter_thread.daemon = True
        waiter_thread.start()

        # Now read events from Queue
        timeout = self._init_pause
        while True:
            try:
                item = main_queue.get(block=True, timeout=timeout)
            except six.moves.queue.Empty:
                # Stop the sub-processes before failing so they are not left running.
                self._stop_and_raise(Exception("No new data in %d seconds" % timeout), None)
            except Exception as e:
                self._stop_and_raise(e, sys.exc_info()[2])

            if 'output' in item:
                yield item['output']
                main_queue.task_done()
                timeout = self._run_pause
            elif 'all_done' in item:
                break
            elif 'exc' in item:
                self._stop_and_raise(item['exc'], item['exc_info'][2])

    def _stop_and_raise(self, exc, tb):
        """Terminate and kill all sub-processes, then re-raise ``exc``.

        Arguments:
            exc - exception instance to raise
            tb  - traceback to attach to it (may be None)
        """
        for p in self.processes.values():
            try:
                p.terminate()
            except Exception as e:
                logger.exception("Exception: %s", e)
            try:
                p.kill()
            except Exception:
                # Best effort: the process may already be gone.
                logger.debug("kill() failed", exc_info=True)
        # This method may be called outside of an ``except`` block, so the
        # exception info is passed explicitly instead of using logger.exception.
        logger.error("External error", exc_info=(type(exc), exc, tb))
        # Re-raise with the original traceback on both Python 2 and 3.
        six.reraise(type(exc), exc, tb)

    def _encode(self, row):
        """Serialize a row dict into one tab-separated ``key=value`` line (no newline)."""
        return "\t".join([k + "=" + _escape_tsv_value(v, self._none_value) for k, v in row.items()])

    def _decode(self, line):
        """Parse one tab-separated ``key=value`` line into a dict.

        A value equal to the none marker becomes None. A field without '='
        raises ValueError.
        """
        none_string = str(self._none_value)
        row = {}
        for kv in line.split("\t"):
            k, v = kv.split("=", 1)
            row[k] = _unescape_tsv_value(v, none_string)
        return row


# encode/decode tsv
# Characters that must be escaped inside a value, and their escaped forms.
TSV_ESCAPE_DICT = {
    "\n": "\\n",
    "\t": "\\t",
    "\\": "\\\\",
}
# Character following a backslash -> the character it stands for.
TSV_UNESCAPE_DICT = {
    "n": "\n",
    "t": "\t",
    "\\": "\\",
}
# Matches one character that needs escaping.
TSV_ESCAPE_RE = re.compile(r'(\n|\t|\\)')
# Matches a backslash followed by any single character (captured).
TSV_UNESCAPE_RE = re.compile(r'\\(.)')


def _escape_tsv_match(match):
    """Regex substitution callback: map the captured character to its escaped form."""
    special_char = match.group(1)
    return TSV_ESCAPE_DICT.get(special_char)


def _escape_tsv_value(val, none_value):
    """Convert a value to its escaped text form.

    None is rendered as ``str(none_value)``; anything else is passed through
    ``str()`` and has newline, tab and backslash characters escaped.
    """
    if val is None:
        return str(none_value)
    text = str(val)
    # Nothing to escape: return the text untouched.
    if TSV_ESCAPE_RE.search(text) is None:
        return text
    return TSV_ESCAPE_RE.sub(_escape_tsv_match, text)


def _unescape_tsv_match(match):
    """Regex substitution callback: resolve the character that followed a backslash.

    Known escape letters are translated; any other character is returned as is.
    """
    escaped = match.group(1)
    if escaped in TSV_UNESCAPE_DICT:
        return TSV_UNESCAPE_DICT[escaped]
    return escaped


def _unescape_tsv_value(val, none_string):
    """Decode an escaped text value; text equal to ``none_string`` yields None."""
    is_none_marker = (val == none_string)
    return None if is_none_marker else TSV_UNESCAPE_RE.sub(_unescape_tsv_match, val)
