# -*- coding: utf-8 -*-
u"""
Script for exporting user collections as JSON (rows are uploaded to a YT table).
"""
import sys
import pymongo
import hashlib
import logging
import ujson
import zlib
import datetime
import time
import multiprocessing
import socket
import collections
import yt.wrapper as yt

from argparse import ArgumentParser


# Number of worker processes that process the data.
PROCESS_WORKER_NUM = 3
# How often the main process polls its children with `is_alive` (seconds).
MAIN_PROC_ALIVE_CHECK_DELTA = 0.1
# How often the main process logs per-process and queue-size statistics (seconds).
MAIN_PROC_STAT_DELTA = 30
# How often the data-generating process logs its statistics (seconds).
DOC_SOURCE_STAT_DELTA = 30
# Maximum size of the inter-process communication queue.
# https://docs.python.org/2/library/multiprocessing.html#multiprocessing.Queue
QUEUE_MAX_SIZE = 500000
# Approximate number of documents buffered before one batch is written to YT.
YT_BATCH_SIZE = 100000
# Number of attempts to upload a batch to YT.
YT_WRITE_ATTEMPS = 10
# Sleep between attempts is: attempt_number * YT_FAIL_DELAY (seconds).
YT_FAIL_DELAY = 10
# Logging setup: everything goes to stderr at INFO level.
log_format = '%(asctime)s [%(process)d] [%(processName)s] [%(module)s] [%(levelname)s] %(message)s'
logging.basicConfig(stream=sys.stderr, format=log_format, level=logging.INFO)
logging.getLogger('yt.packages.requests.packages.urllib3.connectionpool').setLevel(logging.ERROR)
# YT config
# NOTE(review): `settings` is not imported anywhere in the visible part of this file;
# unless it is injected elsewhere this line raises NameError -- confirm the intended import.
yt.config['proxy']['url'] = settings.mrstat['yt_proxy']


# At least one worker is required, otherwise nothing would ever consume the queue.
if PROCESS_WORKER_NUM < 1:
    raise ValueError()


class Speedometer(object):
    """Iterator wrapper that periodically reports progress and throughput.

    Wraps any iterable and yields its items unchanged. While iterating, a
    metrics dict is passed to `metric_hook` at most once per `period` seconds,
    and once more (unconditionally) when the wrapped iterable is exhausted.

    The metrics dict has the keys:
        percent        -- share of `total_items` consumed, or -1.0 if the total is unknown
        count          -- number of items requested so far
        instant_speed  -- items per second since the previous hook call
        overall_speed  -- items per second since the first `next()` call

    :param iterable: any iterable to wrap
    :param metric_hook: callable taking the metrics dict; defaults to printing `msg_format`
    :param total_items: expected number of items (used for `percent`), optional
    :param period: minimal interval between periodic hook calls, in seconds
    """
    msg_format = "progress: %(percent)5.2f%%, "\
        "instant_speed: %(instant_speed)0.2f(items/s), "\
        "arvg_speed: %(overall_speed)0.2f(items/s), items: %(count)i"

    def __init__(self, iterable, metric_hook=None, total_items=None, period=5):
        self._iter = iter(iterable)
        self._total_items = total_items
        self._metric_hook = metric_hook
        if metric_hook is None:
            self._metric_hook = self.default_metric_hook
        self._hook_period = period
        self._count = 0
        self._start_time = None
        self._last_hook_time = None
        self._last_hook_count = 0

    @classmethod
    def default_metric_hook(cls, metrics):
        """Print the metrics formatted with `msg_format`."""
        # A parenthesised single argument prints the same text whether `print`
        # is a statement (Python 2) or a function (Python 3).
        print(cls.msg_format % metrics)

    def __iter__(self):
        return self

    def next(self):
        """Return the next item of the wrapped iterable, reporting metrics when due."""
        if self._count == 0:
            # The clock starts on the first request, not at construction time.
            self._start_time = time.time()
            self._last_hook_time = self._start_time

        self.call_hook()
        try:
            self._count += 1
            return next(self._iter)
        except StopIteration:
            # The optimistic increment above did not correspond to a real item.
            self._count -= 1
            self.call_hook(force=True)
            raise

    # Python 3 spelling of the iterator protocol; `next` stays for Python 2 callers.
    __next__ = next

    @staticmethod
    def _rate(items, elapsed):
        """Return items/elapsed, or 0.0 when no measurable time has passed."""
        if elapsed > 0:
            return items / elapsed
        return 0.0

    def get_metrics(self):
        """Build the metrics dict described in the class docstring."""
        metrics = {
            'percent': -1.0,
            'count': self._count,
            'instant_speed': 0.0,
            'overall_speed': 0.0,
        }
        if self._total_items:
            metrics['percent'] = self._count / float(self._total_items) * 100

        cur_time = time.time()
        if self._last_hook_time is not None:
            metrics['instant_speed'] = self._rate(self._count - self._last_hook_count,
                                                  cur_time - self._last_hook_time)
        if self._start_time is not None:
            metrics['overall_speed'] = self._rate(self._count, cur_time - self._start_time)
        return metrics

    def call_hook(self, force=False):
        """Invoke the metric hook if `force` is set or the reporting period has elapsed."""
        # Before the first `next()` there is no reference time yet, so a
        # periodic (non-forced) call has nothing to compare against.
        period_elapsed = (self._last_hook_time is not None and
                          time.time() - self._last_hook_time > self._hook_period)
        if force or period_elapsed:
            metrics = self.get_metrics()
            self._metric_hook(metrics)
            self._last_hook_time = time.time()
            self._last_hook_count = self._count


class DataSource(object):
    """Base class for raw data sources.

    Subclasses are expected to set `self._name` and to override
    `total_items` and `data_stream`.
    """

    # Read-only view of the `_name` attribute set by the subclass.
    name = property(lambda self: self._name)

    def data_stream(self):
        """Return an iterable over the source items. Must be overridden."""
        raise NotImplementedError()

    def total_items(self):
        """Return the number of items in the source. Must be overridden."""
        raise NotImplementedError()


class MongoDBSource(DataSource):
    """Data source that reads every document of one MongoDB collection.

    :param host: MongoDB host (passed to `pymongo.MongoClient` as is)
    :param db_name: database name
    :param collection_name: collection name inside `db_name`
    """
    def __init__(self, host, db_name, collection_name):
        self.client = pymongo.MongoClient(host, read_preference=pymongo.ReadPreference.SECONDARY_PREFERRED)
        self.collection_name = collection_name
        self.collection = self.client[db_name][collection_name]
        self._name = '%s.%s.%s' % (host, db_name, collection_name)
        try:
            # Prefer the replica-set name as the shard identifier.
            self.shard_name = self.client['admin'].command('replSetGetStatus')['set']
        except Exception:
            # Best-effort fallback when the replica-set status is unavailable:
            # a local server is identified by this machine's hostname, any
            # other server by the host string it was addressed with.
            if 'localhost' in host or '127.0.0.1' in host:
                self.shard_name = socket.gethostname()
            else:
                self.shard_name = host

    def total_items(self):
        """Return the number of documents in the collection."""
        return self.collection.count()

    def data_stream(self):
        """Yield every document, tagged with `_shard` and `_collection` fields."""
        cursor = self.collection.find(cursor_type=pymongo.CursorType.EXHAUST, batch_size=100000000)
        for doc in cursor:
            doc[u'_shard'] = self.shard_name
            doc[u'_collection'] = self.collection_name
            yield doc


class TestDataSource(DataSource):
    """Debug-only source producing `items_num` dicts of the form {"i": <index>}."""

    def __init__(self, items_num=100):
        self._name = 'xrange(%s)' % items_num
        self.items_num = items_num

    def data_stream(self):
        """Return a generator over {"i": 0} ... {"i": items_num - 1}."""
        def _emit(indices):
            for idx in indices:
                yield {"i": idx}

        # The index range is built right away; only the dicts are produced lazily.
        return _emit(xrange(self.items_num))

    def total_items(self):
        """Return the configured number of items."""
        return self.items_num


class SlowTestDataSource(TestDataSource):
    """Debug-only source: same items as `TestDataSource`, with a 0.1 s pause after each."""

    def data_stream(self):
        """Yield {"i": <index>} dicts, sleeping after each one is consumed."""
        pause = 0.1
        for idx in xrange(self.items_num):
            item = {"i": idx}
            yield item
            time.sleep(pause)


class MpfsDocProcessor(object):
    """Helpers for turning MPFS documents into JSON strings."""

    @staticmethod
    def path_obfuscation(path, hash_len=6):
        """Obfuscate a path by hashing each of its components.

        Leading and trailing slashes are stripped, the path is split on "/",
        and every component is replaced with the first `hash_len` hex digits
        of its SHA-1 digest.

        :param path: path as a byte string or a text string (text is encoded as UTF-8)
        :param hash_len: number of hex digits kept per component
        :return: obfuscated path as a string of the form "/<hash>/<hash>/..."
        """
        # Hash bytes, never text: this keeps digests identical for `str` and
        # `unicode` input on Python 2 and also works where `str` is text.
        if not isinstance(path, bytes):
            path = path.encode('utf-8')

        parts = path.strip(b'/').split(b'/')
        hashed = [hashlib.sha1(part).hexdigest()[:hash_len] for part in parts]
        return "/%s" % '/'.join(hashed)

    @classmethod
    def process_doc(cls, doc):
        """Convert an MPFS user-data document to a JSON string.

        The document is modified in place: `zdata` (if present) is
        zlib-decompressed and JSON-decoded, `key` (if present) is obfuscated
        with `path_obfuscation`, and `iso_eventtime` is set to the current
        local time.

        :param doc: document dict
        :return: JSON-serialised document
        """
        if 'zdata' in doc:
            doc[u'zdata'] = ujson.loads(zlib.decompress(doc['zdata']))
        if 'key' in doc:
            doc[u'key'] = cls.path_obfuscation(doc['key'])
        # "%H:%M:%S" is the portable spelling of the platform-specific "%T".
        doc[u'iso_eventtime'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return ujson.dumps(doc, ensure_ascii=False)


class DataProcessor(object):
    """Base class for raw-data handlers; subclasses override both methods."""

    def process(self, item):
        """Handle a single raw item. Must be overridden."""
        raise NotImplementedError()

    def flush(self):
        """Finish any pending work. Must be overridden."""
        raise NotImplementedError()


class StdoutMpfsDocProcessor(MpfsDocProcessor, DataProcessor):
    """Processor that writes every converted document to stdout, one JSON per line."""

    def process(self, doc):
        """Convert `doc` with `process_doc` and write it to stdout right away."""
        line = "%s\n" % self.process_doc(doc)
        out = sys.stdout
        out.write(line)
        out.flush()

    def flush(self):
        """Nothing is buffered by this processor, so there is nothing to do."""
        return None


class YtMpfsDocProcessor(MpfsDocProcessor, DataProcessor):
    """Processor that buffers converted documents and appends them to a YT table in batches.

    :param table_path: path of the destination YT table
    """
    # Number of buffered documents that triggers a write.
    CACHE_SIZE = YT_BATCH_SIZE
    # YT settings
    YT_TABLE_TMPL = '<append=true>%s'
    YT_FORMAT = '<encode_utf8=false>json'

    def __init__(self, table_path):
        super(YtMpfsDocProcessor, self).__init__()
        self._cache = []
        self.dst_table = self.YT_TABLE_TMPL % table_path

    def flush(self):
        """Write the buffered documents to YT and empty the buffer.

        The write is attempted up to `YT_WRITE_ATTEMPS` times, sleeping
        `attempt_number * YT_FAIL_DELAY` seconds between attempts. If every
        attempt fails the batch is dropped (as before), but this is now
        reported as an error instead of being logged as a successful send.
        """
        if not self._cache:
            logging.info('Empty cache. No data to YT.')
            return
        start_dt = datetime.datetime.now()

        written = False
        for try_num in xrange(YT_WRITE_ATTEMPS):
            try:
                yt.write_table(self.dst_table, self._cache, format=self.YT_FORMAT)
            except Exception:
                if try_num == YT_WRITE_ATTEMPS - 1:
                    # No attempts left: sleeping again would only delay the failure report.
                    logging.exception("Catch error at `yt.write_table`. No attempts left.")
                    break
                sleep_time = try_num * YT_FAIL_DELAY
                logging.exception("Catch error at `yt.write_table`. Try again: %s attemp. sleeping: %s(sec)" % (try_num, sleep_time))
                time.sleep(sleep_time)
            else:
                written = True
                break

        duration = datetime.datetime.now() - start_dt
        stat = {
            'lines': len(self._cache),
            # Total length of the serialised rows; `sys.getsizeof` of the list
            # would only measure the list object itself, not its contents.
            'size': sum(len(line) for line in self._cache),
            'duration': str(duration),
        }
        if written:
            logging.info('Send data to YT. Stat: %s.' % stat)
        else:
            logging.error('Failed to send data to YT after %s attempts. Batch is dropped. Stat: %s.'
                          % (YT_WRITE_ATTEMPS, stat))
        self._cache = []

    def process(self, doc):
        """Convert `doc` to JSON, buffer it, and flush once `CACHE_SIZE` rows are buffered."""
        json_doc = self.process_doc(doc)
        self._cache.append(json_doc)
        if len(self._cache) >= self.CACHE_SIZE:
            self.flush()


class DataSourceProcess(multiprocessing.Process):
    """
    Process wrapper around a `DataSource`: reads the source in a separate
    process and puts every item into the shared queue.
    """
    def __init__(self, queue, data_source):
        source_is_valid = isinstance(data_source, DataSource)
        if not source_is_valid:
            raise TypeError("`data_source` should by `DataSource` instance. Got: %s" % type(data_source))
        super(DataSourceProcess, self).__init__()
        self.data_source = data_source
        self.queue = queue

    @staticmethod
    def log_metrics(metrics):
        """Metric hook for `Speedometer`: log the progress line instead of printing it."""
        progress_line = Speedometer.msg_format % metrics
        logging.info(progress_line)

    def run(self, *args, **kwargs):
        """Pump all items of the data source into the queue, logging progress."""
        source = self.data_source
        expected = source.total_items()
        if expected == 0:
            raise ValueError('No data in source: "%s"' % source.name)
        logging.info('Documents in "%s": %i.' % (source.name, expected))

        metered = Speedometer(
            source.data_stream(),
            metric_hook=self.log_metrics,
            total_items=expected,
            period=DOC_SOURCE_STAT_DELTA,
        )
        for item in metered:
            self.queue.put(item)
        # This process will not put anything else into the queue.
        self.queue.close()

        summary = "Arvarege speed: %(overall_speed)0.2f(items/s), processed items: %(count)i" % metered.get_metrics()
        logging.info('Job done! Data source: "%s". %s' % (source.name, summary))


class DataProcessorProcess(multiprocessing.Process):
    """Worker process: takes items from the queue and feeds them to a `DataProcessor`.

    :param queue: queue the items are read from
    :param no_more_data_event: event that is set once no more items will be put into the queue
    :param data_processor: `DataProcessor` instance that handles every item
    """
    def __init__(self, queue, no_more_data_event, data_processor):
        if not isinstance(data_processor, DataProcessor):
            raise TypeError("`data_processor` should by `DataProcessor` instance. Got: %s" % type(data_processor))
        super(DataProcessorProcess, self).__init__()
        self.queue = queue
        self.no_more_data_event = no_more_data_event
        self.data_processor = data_processor

    def run(self):
        """Process queue items until the producer is done and the queue is empty, then flush."""
        # `Empty` lives in a differently named stdlib module on Python 2 and 3.
        try:
            from Queue import Empty  # Python 2
        except ImportError:
            from queue import Empty  # Python 3

        stat = collections.Counter()
        while True:
            if self.no_more_data_event.is_set() and self.queue.empty():
                break
            try:
                doc = self.queue.get(True, timeout=1.0)
            except Empty:
                # Only a real timeout is counted as one; any other queue
                # failure propagates instead of being silently retried.
                stat['queue_get_timeouts'] += 1
            else:
                self.data_processor.process(doc)
                stat['doc_processed'] += 1
            stat['iterations'] += 1
        self.data_processor.flush()
        logging.info('Job done! Stat: %s' % dict(stat))


def log_main_stat(child_proccesses, queue, no_more_data_event, prefix='In progress'):
    """Log liveness of every child process, the queue size and the end-of-data flag.

    :param child_proccesses: iterable of process objects (`name`, `is_alive()`)
    :param queue: queue whose `qsize()` is reported
    :param no_more_data_event: event whose `is_set()` is reported
    :param prefix: text put in front of the statistics
    """
    alive_by_name = {}
    for proc in child_proccesses:
        alive_by_name[proc.name] = proc.is_alive()

    stat = {'is_alive': alive_by_name}
    stat['qsize'] = queue.qsize()
    stat['no_more_data'] = no_more_data_event.is_set()
    logging.info('%s. Stat: %s' % (prefix, stat))


def check_or_create_table(dst):
    """Create the YT table at `dst` unless it already exists."""
    table_exists = yt.exists(dst)
    if table_exists:
        return
    yt.create('table', dst)


def main(host, collection_name, dst):
    """Export one MongoDB collection to a YT table using several processes.

    One process reads the collection and fills a shared queue, while
    `PROCESS_WORKER_NUM` worker processes take documents from the queue and
    upload them to YT. The main process only supervises the children and
    periodically logs statistics.

    :param host: MongoDB host
    :param collection_name: collection to export (also used as the database name)
    :param dst: destination YT table path
    """
    # Synchronisation primitives shared with the child processes.
    no_more_data_event = multiprocessing.Event()
    queue = multiprocessing.Queue(QUEUE_MAX_SIZE)

    # Make sure the destination table exists on YT. This is done in a separate
    # process because TCP connections are inherited on fork, which may lead to errors.
    check_table_proc = multiprocessing.Process(target=check_or_create_table, args=(dst,))
    check_table_proc.start()
    check_table_proc.join()
    if check_table_proc.exitcode != 0:
        # Keep going as before, but do not let the failure pass unnoticed.
        logging.error('Table check for "%s" failed with exit code %s.' % (dst, check_table_proc.exitcode))

    # Start the processes that handle the data.
    # For debugging, `StdoutMpfsDocProcessor()` can be used instead of the YT processor.
    processors = []
    for _ in range(PROCESS_WORKER_NUM):
        data_processor = YtMpfsDocProcessor(dst)
        p = DataProcessorProcess(queue, no_more_data_event, data_processor)
        processors.append(p)
        p.start()

    # Create the data-source process.
    # For debugging, `TestDataSource` / `SlowTestDataSource` can be used instead.
    # NOTE(review): the collection name is passed as the database name as well -- confirm.
    # NOTE(review): the MongoClient is created here, in the parent, and then inherited
    # by the forked source process -- confirm this is safe for the pymongo version in use.
    data_source = MongoDBSource(host, collection_name, collection_name)
    data_source_proc = DataSourceProcess(queue, data_source)
    data_source_proc.start()

    child_processes = [data_source_proc] + processors
    stat_count = 0
    stat_every = int(MAIN_PROC_STAT_DELTA / MAIN_PROC_ALIVE_CHECK_DELTA)
    # Supervise the child processes.
    while any(p.is_alive() for p in child_processes):
        if not data_source_proc.is_alive() and not no_more_data_event.is_set():
            # Tell the workers that nothing else will be added to the queue.
            no_more_data_event.set()

        if data_source_proc.is_alive() and not any(p.is_alive() for p in processors):
            # Workers only exit normally after the event above is set, i.e. after
            # the source has finished. If all of them are gone earlier, nobody
            # consumes the queue and the source would block on a full queue forever.
            logging.error('All data processors have exited while the data source is still running. '
                          'Terminating the data source.')
            data_source_proc.terminate()

        # Periodic statistics of the main process.
        if stat_count > stat_every:
            log_main_stat(child_processes, queue, no_more_data_event)
            stat_count = 0

        time.sleep(MAIN_PROC_ALIVE_CHECK_DELTA)
        stat_count += 1
    log_main_stat(child_processes, queue, no_more_data_event, prefix='Job done')


if __name__ == '__main__':
    # Collections that are allowed to be exported.
    user_collections = ['user_data', 'trash', 'attach_data', 'misc_data', 'narod_data', 'hidden_data']

    parser = ArgumentParser(description="Export users data to json")
    parser.add_argument('--host', required=True, dest='host', help='Host name with port.')
    parser.add_argument('--collection', required=True, choices=user_collections, dest='collection', help='Collection name.')
    parser.add_argument('--dst-table', required=True, dest='dst', help='YT table full path. Ex: "//home/mpfs-stat/storage/mpfs_db_dump/2016-06-06"')
    args = parser.parse_args()
    try:
        main(args.host, args.collection, args.dst)
    except (NotImplementedError, ValueError) as e:
        # `e.message` is deprecated and Python 2-only; let logging format the exception itself.
        logging.error("%s", e)
        # Report the failure to the caller instead of exiting with status 0.
        sys.exit(1)
