import argparse
import base64
import errno
import hashlib
import json
import logging
import logging.handlers
import math
import mmap
import msgpack
import os
import queue
import random
import socket
import struct
import subprocess as subproc
import sys
import textwrap
import threading
import time
import traceback
import yaml

from path import Path

from ...rpc import errors as rpc_errors
from ...rpc.client import RPCClient

import fallocate
import lxml.etree
import requests


# Binary size units, expressed in bytes.
MB = 1 << 20
GB = MB << 10

# Lower and upper bounds for a buffer count.
# NOTE(review): not referenced in the visible part of this file -- presumably
# used to clamp the number of upload buffers; confirm against the caller.
MIN_BUFFERS = 2
MAX_BUFFERS = 16

# Program version, reported by the ``--version`` command line option.
__version__ = "0.1"


def parse_args():
    """Define the command line interface and parse ``sys.argv``.

    Global options configure the QDM endpoint, the log file and the progress
    file.  Two subcommands are available, ``upload`` and ``download``; the
    chosen one is stored in the ``operation`` attribute of the result.

    Returns:
        argparse.Namespace with the parsed options.
    """
    def add_switch(target, name):
        # Boolean option which is off unless it is given on the command line
        target.add_argument(name, default=False, action='store_true')

    program = os.path.basename(sys.argv[0])
    epilog = textwrap.dedent('''
            examples:
              {program} upload --qdm-key <key> current.qcow2
              {program} download --rev <revision_key> .

            notes:
              If --logfile not specified or set to "-" will log to stdout.
              If multiple files specified for uploading, they will be ordered by name.

        '''.format(program=program))

    root = argparse.ArgumentParser(
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Options shared by every subcommand
    root.add_argument('--qdm-host', default='qdm.yandex-team.ru')
    root.add_argument('--qdm-port', type=int, default=2183)
    add_switch(root, '--qdm-host-direct')
    root.add_argument('--logfile', default='-', help='log file to write')
    root.add_argument('--progress-file', default='-', help='progress file')
    root.add_argument('--version', action='version', version='%(prog)s {version}'.format(version=__version__))

    commands = root.add_subparsers(title='subcommands')

    # upload: push local files into an existing qdm session
    upload = commands.add_parser('upload', help='upload files to mds (you should know qdm key for session)')
    upload.add_argument('--qdm-key', required=True, help='qdm session id')
    upload.add_argument('--cwd', '-C', default=None, help='cwd to search files')
    add_switch(upload, '--mds-use-proxy')
    upload.add_argument('--spec')
    add_switch(upload, '--background')
    upload.add_argument('path', nargs='*', help='filenames to upload')
    upload.set_defaults(operation='upload')

    # download: fetch the files of a stored revision
    download = commands.add_parser('download', help='download files from qdm revision key')
    download.add_argument('--rev', required=True, help='qdm revision key')
    for optional in ('--vm-id', '--cluster', '--node-id'):
        download.add_argument(optional)
    download.add_argument('dest', help='destination path')
    download.set_defaults(operation='download')

    return root.parse_args()


class IOReader(object):
    """Pool of daemon threads reading file blocks into caller-supplied buffers.

    Each reader takes ``(full_fn, fn, buff, idx, global_idx)`` jobs from a
    queue, reads block number ``idx`` (``block_size`` bytes per block) of
    ``full_fn`` into ``buff`` and publishes
    ``(fn, idx, read_bytes, buff, global_idx)`` to the results queue.  A
    ``None`` job stops the reader.  Any failure terminates the whole process.
    """

    def __init__(self, threads, block_size):
        self.log = logging.getLogger('ioreader')

        self.block_size = block_size
        self.read_by = 4 * 1024 * 1024

        self.total_threads = threads
        self.threads = []

    def _reader(self, jobs, results):
        """Reader thread body: serve jobs until a ``None`` stop flag arrives."""
        while True:
            try:
                task = jobs.get()
                if task is None:
                    return

                full_fn, fn, buff, idx, global_idx = task

                with open(full_fn, mode='rb') as src:
                    src.seek(idx * self.block_size)
                    read_bytes = src.readinto(buff)

                if read_bytes:
                    results.put((fn, idx, read_bytes, buff, global_idx))
                    continue

                # Seek past EOF, just ignore. Probably the file shrank while we were reading it
                self.log.warning(
                    'Attempt to read past EOF (%s) (idx %d, %d byte)',
                    fn, idx, idx * self.block_size
                )
            except BaseException:
                # Nothing can be recovered here: report and kill the process
                traceback.print_exc()
                sys.stderr.write('Unable to read data, emergency exit')
                sys.stderr.flush()
                os._exit(1)

    def _watcher(self, results, count):
        """Wait for all readers, then put ``count`` stop flags into ``results``."""
        try:
            for reader in self.threads:
                reader.join()

            self.log.info('IOReader finished, sending stopflag %d times', count)

            for _ in range(count):
                results.put(None)
        except BaseException:
            traceback.print_exc()
            sys.stderr.write('IOReader watcher thread died, emergency exit')
            sys.stderr.flush()
            os._exit(1)

    def run(self, io_queue, work_queue):
        """Create and start ``total_threads`` readers from ``io_queue`` to ``work_queue``."""
        for _ in range(self.total_threads):
            reader = threading.Thread(target=self._reader, args=(io_queue, work_queue))
            reader.daemon = True
            self.threads.append(reader)

        for reader in self.threads:
            reader.start()

    def run_watcher(self, results, count):
        """Start the daemon thread which forwards the stop flag once readers are done."""
        watcher = threading.Thread(target=self._watcher, args=(results, count))
        watcher.daemon = True
        self._watcher_thr = watcher
        watcher.start()


class DataWorker(object):
    """Hash data blocks and upload to MDS the ones QDM does not have yet.

    Worker threads take ``(fn, block_idx, read_bytes, buff, global_idx)`` jobs
    from a work queue.  For every job the first ``read_bytes`` bytes of
    ``buff`` are hashed (blake2b-256 plus md5), the QDM job is asked whether
    the block is already known and, if not, the block is uploaded and its MDS
    key is stored through the QDM job.  A ``None`` job stops the worker.
    """

    # 70 tries with N*3 secs pause between is about 2hrs total
    upload_total_tries = 70
    upload_pause_mult_between_tries = 3

    def __init__(self, qdm_upload_job, threads, mds_use_proxy):
        self.qdm_upload_job = qdm_upload_job
        self.total_threads = threads
        self.threads = []
        self.mds_use_proxy = mds_use_proxy
        self.log = logging.getLogger('dataworker')

    def _worker(self, work_queue, result_queue):
        """Worker thread body.

        Puts ``(fn, global_idx, read_bytes, hexdigest, buff)`` into
        ``result_queue`` for every processed job.  On the stop flag the thread
        removes itself from ``self.threads`` -- that is how ``run`` tells a
        clean exit from a crash.
        """
        while True:
            job = work_queue.get()

            if job is None:
                self.threads.remove(threading.current_thread())
                break

            fn, block_idx, read_bytes, buff, global_idx = job

            digest = hashlib.blake2b(digest_size=32)
            digest_read_by = 4 * 1024 * 1024

            md5_csum = hashlib.md5()

            # Feed the hashes piece by piece instead of copying the whole block at once
            for idx in range(math.ceil(read_bytes / digest_read_by)):
                data_block = buff[idx * digest_read_by:min((idx + 1) * digest_read_by, read_bytes)]
                digest.update(data_block)
                md5_csum.update(data_block)

            hexdigest = digest.hexdigest()

            if not self.qdm_upload_job.check_data_block(global_idx, read_bytes, 'b2', hexdigest):
                self.log.debug('%s (%d bytes): need to upload', hexdigest, read_bytes)

                mds_key = self._upload_with_retries(digest, md5_csum, buff, read_bytes)

                self.log.info('%s (%d bytes): uploaded with key %s', hexdigest, read_bytes, mds_key)
                self.qdm_upload_job.data_block_store(
                    global_idx, read_bytes, 'b2', hexdigest, mds_key, ttl=365 * 86400
                )
            else:
                self.log.debug('%s (%d bytes): no need to upload', hexdigest, read_bytes)

            result_queue.put((fn, global_idx, read_bytes, hexdigest, buff))

    def _upload_with_retries(self, digest, md5_csum, buff, read_bytes):
        """Upload one block, retrying up to ``upload_total_tries`` times.

        The pause before retry number N is
        ``upload_pause_mult_between_tries * N`` seconds.

        Returns:
            The MDS key of the uploaded block.

        Raises:
            The exception of the last attempt once all tries are used up.
        """
        total_tries = self.upload_total_tries

        for idx in range(total_tries):
            # Compute random digest with original hash + random 8 byte integer and float timestamp
            rnd = random.randint(0, 2**64 - 1)
            fn_digest = hashlib.blake2b(
                digest.digest() + struct.pack('Qd', rnd, time.time()),
                digest_size=32
            )
            mds_fn = 'qdmblock:1:%s' % (fn_digest.hexdigest(), )
            try:
                return self.upload_block_to_mds(
                    mds_fn, buff, read_bytes, self.qdm_upload_job, md5_csum.digest()
                )
            except Exception as ex:
                # Give up only after the configured number of tries, not after
                # a hard-coded attempt number
                if idx >= total_tries - 1:
                    raise

                wait = self.upload_pause_mult_between_tries * idx
                self.log.warning(
                    'Failed to upload (%s) will try %d more times, next in %ds',
                    ex, total_tries - 1 - idx, wait
                )
                if wait > 0:
                    time.sleep(wait)

        raise Exception('No upload attempts were made (upload_total_tries=%r)' % (total_tries, ))

    def _convert_mmap_to_uploadable_obj(self, buff, size):
        """Wrap the first ``size`` bytes of ``buff`` into an iterable of 1 MiB chunks.

        The wrapper has a ``len`` attribute and is passed as the ``data`` of
        the upload request.
        """
        class _Uploadable(object):
            def __init__(self, buff, size):
                self.buff = buff
                self.len = size
                self.pos = 0

            def __iter__(self):
                readby = 1024 * 1024

                while True:
                    end = min(self.pos + readby, self.len)
                    data = self.buff[self.pos:end]
                    self.pos = end

                    yield data

                    if end == self.len:
                        break

        return _Uploadable(buff, size)

    def upload_block_to_mds(self, fn, buff, size, qdm_upload_job, md5_csum):
        """POST the first ``size`` bytes of ``buff`` to MDS under the name ``fn``.

        Args:
            fn: name of the object in MDS.
            buff: buffer holding the block data.
            size: number of valid bytes in ``buff``.
            qdm_upload_job: object providing ``get_tvm_mds_ticket()``.
            md5_csum: raw md5 digest of the data, sent as ``Content-MD5``.

        Returns:
            The ``key`` attribute of the root element of the XML response.

        Raises:
            Exception: if the response status code is not 200.
        """
        if self.mds_use_proxy:
            mds_hostname = 'storage-int.mds.yandex.net'
        else:
            # Ask the balancer for a concrete host to upload to
            mds_hostname = requests.get(
                'http://storage-int.mds.yandex.net:1111/hostname',
                timeout=60
            ).text

        tvm_ticket = qdm_upload_job.get_tvm_mds_ticket()

        self.log.debug('%s: uploading', fn)

        # storage-int.mdst.yandex.net
        # NOTE(review): this request has no timeout, so a stalled connection
        # blocks the worker thread -- confirm that this is acceptable
        req = requests.post(
            'http://%s:1111/upload-qdm/%s' % (mds_hostname, fn),
            data=self._convert_mmap_to_uploadable_obj(buff, size),
            headers={
                'X-Ya-Service-Ticket': tvm_ticket,
                'Content-MD5': base64.encodebytes(md5_csum).decode('utf-8').rstrip()
            }
        )

        self.log.debug('%s: uploaded with status code %r', fn, req.status_code)

        if req.status_code == 200:
            tree = lxml.etree.fromstring(req.text.encode('utf-8'))
            return tree.get('key')

        else:
            raise Exception('Failed to post data, return code was %r' % (req.status_code, ))

    def _has_crashed_threads(self):
        """Return True if a worker died without removing itself from ``self.threads``.

        A worker which exits cleanly removes itself from the list before it
        terminates, so a dead thread is re-checked against the list to avoid
        reporting a thread which finished between the two looks.
        """
        dead = [thr for thr in list(self.threads) if not thr.is_alive()]
        return any(thr in self.threads for thr in dead)

    def run(self, work_queue, poll_interval=1):
        """Start the workers and yield their results until all of them stop.

        Args:
            work_queue: queue with jobs; it must receive one ``None`` per
                worker thread after the last job.
            poll_interval: seconds to wait for a result before the state of
                the worker threads is checked again.

        Yields:
            ``(fn, global_idx, read_bytes, hexdigest, buff)`` tuples.

        Raises:
            Exception: if a worker thread died unexpectedly.
        """
        result_queue = queue.Queue(maxsize=1)

        for i in range(self.total_threads):
            thr = threading.Thread(target=self._worker, args=(work_queue, result_queue))
            thr.daemon = True
            self.threads.append(thr)

        # Iterate over a copy: a started worker may remove itself from
        # self.threads right away, which would make us skip another thread
        for thr in list(self.threads):
            thr.start()

        while True:
            if self._has_crashed_threads():
                raise Exception('Some threads died!')

            try:
                result = result_queue.get(timeout=poll_interval)
            except queue.Empty:
                pass
            else:
                yield result

            # This check has to run after a timeout as well, otherwise we
            # would wait forever once the last result has been consumed.
            if not self.threads and result_queue.empty():
                # If all workers threads was stopped and removed and
                # we have no data in result queue anymore -- this indicates
                # successful stop
                self.log.info('All worker threads finished')
                return


class DataUploader(object):
    """Drive the upload of all files of a QDM upload job.

    The files listed in ``qdm_upload_job.filemap`` are split into blocks of
    ``block_size`` bytes.  Blocks are read by ``IOReader`` threads into
    reusable ``mmap`` buffers, processed by ``DataWorker`` threads and finally
    described by a filemap passed to ``qdm_upload_job.finalize_job``.
    """

    block_size = 512 * 1024 * 1024

    def __init__(self, qdm_upload_job, mds_use_proxy):
        self.qdm_upload_job = qdm_upload_job

        self.final_filemap = []

        self.mds_use_proxy = mds_use_proxy

        self.dirty_hashmap = {}
        self.dirty_revision = None

        self.log = logging.getLogger('datauploader')

        self.total_size = sum(info['source'].getsize() for info in self.qdm_upload_job.filemap)

    def read_data(self, work_queue):
        # NOTE(review): takes one item from the queue and discards it; looks
        # like an unfinished stub, kept only for interface compatibility
        fn, start, length = work_queue.get()

    def _io_scheduler(self, buffer_queue, io_queue, io_threads):
        """Scheduler thread body: hand a free buffer to every block to read.

        Puts ``(source, path, buff, idx, total_block_idx)`` jobs into
        ``io_queue`` for every block of every file and then ``io_threads``
        stop flags.  Any failure terminates the whole process.
        """
        total_block_idx = 0

        try:
            for info in self.qdm_upload_job.filemap:
                for idx in range(math.ceil(info['source'].getsize() / self.block_size)):
                    # Blocks until some buffer is released by the consumer
                    buff = buffer_queue.get()
                    assert len(buff) == self.block_size

                    io_queue.put((info['source'], info['path'], buff, idx, total_block_idx))
                    total_block_idx += 1

            # Send stop indicators to all io threads
            self.log.info('IO scheduler finished -- sendint stop flag %d times', io_threads)

            for i in range(io_threads):
                io_queue.put(None)

        except BaseException:
            traceback.print_exc()
            sys.stderr.write('Unable to read data, emergency exit')
            sys.stderr.flush()
            os._exit(1)

    @staticmethod
    def _reestimate_duration(now, start_time, duration, done_blocks, total_blocks):
        """Extrapolate the remaining time from the speed observed so far.

        Returns:
            The new remaining duration in seconds if it pushes the finish time
            past ``start_time + duration`` and more than 30% of the blocks are
            done; ``None`` if the current estimation should be kept.
        """
        if done_blocks <= 0 or total_blocks <= 0:
            return None

        remaining = (now - start_time) * ((total_blocks - done_blocks) / done_blocks)

        # Extrapolation from the first few blocks is too noisy, so the ratio
        # of done blocks (not its inverse, which is always >= 1) is checked
        if now + remaining > start_time + duration and done_blocks / total_blocks > 0.3:
            return remaining

        return None

    def _build_filemap(self, file_blocks):
        """Build the version 1 filemap from per-file block statistics.

        Args:
            file_blocks: dict ``path -> {'bcnt': blocks, 'size': bytes}``.

        Files missing from ``file_blocks`` (no block was produced for them,
        e.g. an empty file) are reported with 0 blocks and 0 bytes.
        """
        files = []

        for info in self.qdm_upload_job.filemap:
            stats = file_blocks.get(info['path'], {'bcnt': 0, 'size': 0})

            files.append({
                'name': info['path'],
                'blocks': stats['bcnt'],
                'size': stats['size'],
                'meta': info['meta']
            })

        return {'qdm_filemap_version': 1, 'files': files}

    def upload(self, io_threads, upload_threads, buffers_count):
        """Upload all files, yielding progress events along the way.

        Args:
            io_threads: number of reader threads.
            upload_threads: number of hashing/uploading threads.
            buffers_count: number of ``block_size`` buffers to allocate.

        Yields:
            ``'init'``, ``('eta', start_time, duration)``,
            ``('total_bytes', n)``, ``('total_blocks', n)``,
            ``('done_block', percent, done_blocks)``, ``('all_done', )`` and
            ``('finish', revision)``.
        """
        duration = self.qdm_upload_job.get_estimated_duration()
        start_time = time.time()
        yield 'init'
        yield 'eta', start_time, duration
        # self.buff = [mmap.mmap(-1, self.block_size) for idx in range(io_threads)]
        #
        # max buffer usage == 8
        # 1 in io_thread queue
        # 1 in io_threads
        # 1 in work_thread queue
        # 4 in work_threads

        total_blocks = sum(
            math.ceil(
                info['source'].getsize() / self.block_size
            )
            for info in self.qdm_upload_job.filemap
        )

        yield 'total_bytes', self.total_size
        yield 'total_blocks', total_blocks

        # calculate needed buffers
        # io + work + 2 for queues (io_queue and work_queue)
        # 1 + 4 + 2 == 7 for generic use
        buffer_queue = queue.Queue(maxsize=buffers_count)
        for i in range(buffer_queue.maxsize):
            buffer_queue.put(mmap.mmap(-1, self.block_size))

        io_queue = queue.Queue(maxsize=1)
        work_queue = queue.Queue(maxsize=1)

        io_scheduler_thr = threading.Thread(
            target=self._io_scheduler, args=(
                buffer_queue, io_queue,
                io_threads  # used to determine how much stop indicators to send
            )
        )
        io_scheduler_thr.daemon = True
        io_scheduler_thr.start()

        io_reader = IOReader(io_threads, block_size=self.block_size)
        io_reader.run(io_queue, work_queue)
        io_reader.run_watcher(work_queue, upload_threads)

        data_worker = DataWorker(self.qdm_upload_job, upload_threads, self.mds_use_proxy)

        done_blocks = 0

        file_blocks = {}  # fn - block count

        self.qdm_upload_job.allow_reinit(True)

        for result in data_worker.run(work_queue):
            fn, idx, read_bytes, digest, buff = result

            # idx is a global idx here

            file_blocks.setdefault(fn, {'bcnt': 0, 'size': 0})
            file_blocks[fn]['bcnt'] += 1
            file_blocks[fn]['size'] += read_bytes

            done_blocks += 1

            self.log.debug('done block #%d (total %d blocks), fn %s', idx, done_blocks, fn)

            yield 'done_block', (done_blocks / total_blocks) * 100, done_blocks

            # if duration based on speed of previous uploads is greater than estimated eta -> change eta
            dur = self._reestimate_duration(time.time(), start_time, duration, done_blocks, total_blocks)
            if dur is not None:
                self.log.info('changing eta to {}'.format(dur))
                yield 'eta', time.time(), dur

            # Allow buffer to be reused
            buffer_queue.put(buff)

        self.log.info('all blocks done')
        yield ('all_done', )

        io_scheduler_thr.join()

        filemap = self._build_filemap(file_blocks)

        self.log.info('storing filemap (version=1):')

        for rec in filemap['files']:
            self.log.info(
                '  - %s: %d blocks, %d bytes total (meta: %r)',
                rec['name'], rec['blocks'], rec['size'], rec['meta']
            )

        revno = self.qdm_upload_job.finalize_job(filemap)
        yield 'finish', revno


class DataDownloader(object):
    """Download the files of a QDM revision into a destination directory.

    The blockmap returned by the QDM download job maps every file name to a
    list of blocks; block fields 0..5 are used as hash type, hash, size, MDS
    key, MDS type and byte offset inside the file.
    """

    # How many times one block is tried and the pause between tries (seconds)
    max_tries = 10
    retry_pause = 3

    def __init__(self, qdm_download_job, dest):
        self.qdm_download_job = qdm_download_job
        self.dest = Path(dest)
        self.log = logging.getLogger('datadownloader')
        self.io_sem = threading.Semaphore(1)

        self.threads = []
        self.total_size = 0

    def _download_piece(self, fn, hashtype, hash_, size, mds_key, mds_type, byte_offset):
        """Download one block from MDS and write it into ``fn`` at ``byte_offset``.

        The blake2b-256 digest of the received data is compared with ``hash_``
        (``hashtype`` and ``size`` are not used for the check).

        Raises:
            ValueError: if the digest of the downloaded data does not match.
            requests exceptions: on HTTP or network errors.
        """
        fn.parent.makedirs_p()

        # Make sure the file exists without truncating it
        open(fn, 'a+b').close()

        digest = hashlib.blake2b(digest_size=32)

        with open(fn, 'r+b') as fp:
            fp.seek(byte_offset)

            if mds_type == 'storage-int-mdst':
                mds_uri = 'storage-int.mdst'
                mds_ns = 'qyp'
            else:
                mds_uri = 'storage-int.mds'
                mds_ns = 'qdm'

            url = 'http://%s.yandex.net:80/get-%s/%s' % (
                mds_uri, mds_ns, mds_key
            )

            tvm_ticket = self.qdm_download_job.get_tvm_mds_ticket()

            with requests.get(
                url, headers={
                    'X-Ya-Service-Ticket': tvm_ticket
                }, timeout=60, stream=True
            ) as req:
                req.raise_for_status()

                for chunk in req.iter_content(chunk_size=1 * 1024 * 1024):
                    digest.update(chunk)
                    # Only one thread writes to the disk at a time
                    with self.io_sem:
                        fp.write(chunk)
                        fp.flush()

                # An explicit check: an assert would be dropped under ``python -O``
                if digest.hexdigest() != hash_:
                    raise ValueError('Download hash mismatch!')

    def _worker_thread(self, job_queue, result_queue):
        """Worker thread body: download blocks until a ``None`` stop flag arrives.

        Every block is tried up to ``max_tries`` times.  On the stop flag the
        thread removes itself from ``self.threads`` -- that is how
        ``download`` tells a clean exit from a crash.
        """
        tries = max(1, self.max_tries)

        while True:
            job = job_queue.get()
            if job is None:
                self.threads.remove(threading.current_thread())
                return

            fn, block_idx, global_idx, hashtype, hash_, size, mds_key, mds_type, byte_offset = job

            for attempt in range(1, tries + 1):
                try:
                    self._download_piece(
                        self.dest.joinpath(fn), hashtype, hash_, size, mds_key, mds_type, byte_offset
                    )
                except Exception as ex:
                    self.log.warning('Error during downloading: %s', ex)
                    if attempt >= tries:
                        self.log.warning('Unable to download block, will not try anymore')
                        raise

                    self.log.warning('Unable to download block, will retry %d more times', tries - attempt)
                    time.sleep(self.retry_pause)
                else:
                    break

            self.log.info('%s: done chunk offset:%d (%d bytes)', fn, byte_offset, size)
            result_queue.put((fn, block_idx, global_idx))

    @staticmethod
    def _blocks_size(blocks):
        """Total size in bytes of a list of blocks."""
        return sum(block[2] for block in blocks)

    def _select_files(self, files, filter_filenames):
        """Return the set of file names from ``files`` allowed by the filter."""
        if filter_filenames is None:
            return set(files.keys())

        self.log.debug('  filtering %d files using filter %r:', len(files), filter_filenames)
        allowed_files = set()
        for fn in files:
            if fn in filter_filenames:
                self.log.debug('    [x] %s', fn)
                allowed_files.add(fn)
            else:
                self.log.debug('    [ ] %s', fn)

        return allowed_files

    def _preallocate_files(self, files, allowed_files):
        """Create every allowed file and reserve the disk space for it."""
        self.log.info('Precreating all files with fallocate...')
        for fn, blocks in files.items():
            if fn not in allowed_files:
                self.log.info('  %-30s: ignoring...', fn)
                continue

            rfn = self.dest.joinpath(fn)
            rfn.parent.makedirs_p()
            with open(rfn, 'ab') as fp:  # using ab mode to force O_CREAT but do not truncate files
                size = self._blocks_size(blocks)
                self.log.info('  %-30s: calling fallocate to %d bytes...', fn, size)
                ts = time.time()
                try:
                    fallocate.fallocate(fp, 0, size)
                except OSError as ex:
                    if ex.errno == errno.ENOTSUP:
                        self.log.info('  %-30s: not supported in this system (%s: %s)', fn, type(ex).__name__, ex)
                    else:
                        raise
                else:
                    self.log.info('  %-30s: done, took %.2fs', fn, time.time() - ts)

    def _finalize_files(self, files, allowed_files):
        """Truncate every downloaded file to its real size and fsync it."""
        self.log.info('Truncating files to their real size')

        for fn, blocks in files.items():
            if fn not in allowed_files:
                continue
            rfn = self.dest.joinpath(fn)
            with open(rfn, 'rb+') as fp:
                size = self._blocks_size(blocks)
                self.log.info('  %-30s: truncating to %d bytes...', fn, size)
                ts = time.time()
                fp.truncate(size)
                self.log.info('  %-30s: done, took %.2fs', fn, time.time() - ts)

        self.log.info('Calling fsync for each downloaded file...')
        for fn in files:
            if fn not in allowed_files:
                continue
            rfn = self.dest.joinpath(fn)
            with open(rfn, 'rb+') as fp:
                self.log.info('  %-30s: fsync...', fn)
                ts = time.time()
                os.fsync(fp.fileno())
                self.log.info('  %-30s: done, took %.2fs', fn, time.time() - ts)

    def _has_crashed_threads(self):
        """Return True if a worker died without removing itself from ``self.threads``.

        A dead thread is re-checked against the list, because a worker which
        exits cleanly removes itself from the list just before it terminates.
        """
        dead = [thr for thr in list(self.threads) if not thr.is_alive()]
        return any(thr in self.threads for thr in dead)

    def download(self, filter_filenames, worker_threads):
        """Download the revision, yielding progress events along the way.

        Args:
            filter_filenames: container of file names to download, or ``None``
                to download every file of the revision.
            worker_threads: number of downloading threads.

        Yields:
            ``('init', )``, ``('done_block', percent, done_blocks)`` for every
            downloaded block and finally ``('finished', True)``.

        Raises:
            Exception: if the filter leaves no files or a worker thread died.
        """
        yield ('init', )

        if not self.dest.exists():
            self.dest.makedirs()

        self.log.info('Requesting blockmap...')
        blockmap = self.qdm_download_job.get_blockmap()
        files = blockmap['files']

        total_blocks = sum(len(blocks) for blocks in files.values())

        self.log.info(
            'Got blockmap with %d total blocks', total_blocks
        )
        self.log.info(
            '  with %d files', len(files)
        )
        self.total_size = sum(self._blocks_size(blocks) for blocks in files.values())
        self.log.info(
            '  with total %d bytes',
            self.total_size
        )

        job_queue = queue.Queue()
        result_queue = queue.Queue()

        self.qdm_download_job.allow_reinit(True)

        allowed_files = self._select_files(files, filter_filenames)

        # Validate the filter before any thread is started, so nothing is
        # left running if we bail out here
        if not allowed_files:
            raise Exception('Left 0 files to download after supplied filter')

        filtered_blocks = sum(len(blocks) for fn, blocks in files.items() if fn in allowed_files)

        if filtered_blocks != total_blocks:
            self.log.info(
                '  with %d blocks left after filenames filter', filtered_blocks
            )
            self.total_size = sum(
                self._blocks_size(blocks) for fn, blocks in files.items() if fn in allowed_files
            )

        self._preallocate_files(files, allowed_files)

        for i in range(worker_threads):
            thr = threading.Thread(target=self._worker_thread, args=(job_queue, result_queue))
            thr.daemon = True
            self.threads.append(thr)

        # Iterate over a copy: workers remove themselves from self.threads
        for thr in list(self.threads):
            thr.start()

        self.log.info('Scheduling blocks for download...')

        global_idx = 0
        for fn, blocks in files.items():
            if fn not in allowed_files:
                continue

            for block_idx, block in enumerate(blocks):
                job_queue.put((
                    fn, block_idx, global_idx, block[0], block[1], block[2], block[3], block[4], block[5]
                ))
                global_idx += 1

        for i in range(worker_threads):
            job_queue.put(None)
        done_blocks = 0

        while True:
            if self._has_crashed_threads():
                raise Exception('Some threads died!')

            # Finish only when the queued results are consumed as well,
            # otherwise the last progress events would be lost
            if not self.threads and result_queue.empty():
                self.log.info('All worker threads finished')

                self._finalize_files(files, allowed_files)

                self.qdm_download_job.finish()
                yield 'finished', True
                return

            try:
                fn, block_idx, global_idx = result_queue.get(timeout=1)
            except queue.Empty:
                continue

            done_blocks += 1
            progress = (done_blocks / filtered_blocks) * 100
            yield 'done_block', progress, done_blocks


class QdmJob(object):
    """Base class of a job session with the QDM server.

    Holds the RPC client and the server-side job object (``self._job``).
    This class calls ``self.init()`` after a reconnect but does not define
    it: subclasses have to provide ``init()``, which opens the session and
    sets ``self._job``.
    """

    def __init__(self, qdm_host, qdm_port):
        self._qdm_host = qdm_host
        self._qdm_port = qdm_port

        self._cli = None

        self._lock = threading.RLock()

        self._log = logging.getLogger('qdmjob')
        self._job = None

        self._retryable = False
        self._allow_reinit = False

    def _job_call(self, meths):
        """Call a sequence of methods of ``self._job``, reconnecting on failure.

        Args:
            meths: iterable of ``(method_name, args, kwargs)``; the methods
                are called in order.

        Returns:
            The result of the last call.

        Raises:
            rpc_errors.CallFail: always propagated.
            OSError, rpc_errors.RPCError: propagated when reinit is not
                allowed (the job is then marked retryable) or when the one
                hour deadline has passed; otherwise the connection is
                re-established and the whole sequence is repeated.
        """
        deadline = time.time() + 3600

        attempt = 0

        while True:
            attempt += 1
            try:
                result = None
                for meth, args, kwargs in meths:
                    result = getattr(self._job, meth)(*args, **kwargs)
                return result

            except rpc_errors.CallFail as ex:
                self._log.critical('RPC call (#%d) failed: %s: %s', attempt, type(ex).__name__, ex)
                raise

            except (OSError, rpc_errors.RPCError) as ex:
                # BADFileDesc, etc. Reconnect should fix that
                self._log.critical('RPC call (#%d) failed: %s: %s', attempt, type(ex).__name__, ex)

                if not self._allow_reinit:
                    self._retryable = True
                    raise

                if time.time() > deadline:
                    raise

                self._log.info('We allowed to reinit connection, try that in 30s...')
                time.sleep(30)  # wait some time before reinit, this will allow server to update
                try:
                    if not self.connect():
                        raise rpc_errors.RPCError('Failed to connect QDM server')
                    self.init()
                except rpc_errors.CallFail as reinit_ex:
                    # Report the reinit failure itself, not the error which triggered the reinit
                    self._log.critical(
                        '  unable to reinit: RPC call (#%d) failed: %s: %s',
                        attempt, type(reinit_ex).__name__, reinit_ex
                    )
                    raise
                except (OSError, rpc_errors.RPCError) as reinit_ex:
                    self._log.warning('  unable to reinit: %s: %s', type(reinit_ex).__name__, reinit_ex)
                    time.sleep(1)
                else:
                    self._log.debug('  connection was reinitialized')
                continue

    def _get_blocked(self, meth, response_hint, *args):
        """Send ``(meth, args)`` to the job and wait for the response.

        Returns:
            The payload (second item) of the response.

        Raises:
            AssertionError: if the first item of the response is not
                ``response_hint``.
        """
        with self._lock:
            response = self._job_call((
                ('send', ((meth, args), ), {}),
                ('next', (), {})
            ))

            # An explicit check: an assert would be dropped under ``python -O``
            if response[0] != response_hint:
                raise AssertionError('Invalid response for %r, expected %r, got %r' % (
                    meth, response_hint, response
                ))
            return response[1]

    def _notify(self, meth, *args):
        """Send ``(meth, *args)`` to the job without waiting for a response."""
        with self._lock:
            self._job_call((
                ('send', ((meth, ) + args, ), {}),
            ))

    def connect(self):
        """(Re)create the RPC client and connect it, retrying for up to 10 minutes.

        Returns:
            True if connected, False if the deadline passed.
        """
        if self._cli:
            self._cli.stop()

        self._cli = RPCClient(self._qdm_host, self._qdm_port)
        deadline = time.time() + 600

        with self._lock:
            while time.time() < deadline:
                # Tries are started at least 30 seconds apart
                next_try_min = time.time() + 30

                try:
                    self._log.info('Trying to connect QDM server...')
                    self._cli.connect()
                    self._log.debug('  connected')
                except (socket.error, rpc_errors.HandshakeError) as ex:
                    self._log.critical('  unable to connect QDM server: %s', ex)
                    if time.time() < deadline:
                        self._log.info('  will retry more for next %ds', deadline - time.time())
                        wait = next_try_min - time.time()
                        if wait > 0:
                            self._log.debug('    waiting %fs before retry...', wait)
                            time.sleep(wait)
                        continue
                    else:
                        return False
                else:
                    return True

        return False

    def finish(self):
        """Tell the server that the job is finished and wait for its result.

        Raises:
            AssertionError: if the job returned a false value.
        """
        # The call must not live inside an ``assert`` statement: under
        # ``python -O`` the whole statement, including the call, is removed
        result = self._job_call((
            ('send', (('finish', ), ), {}),
            ('wait', (), {'timeout': 60}),
        ))
        if not result:
            raise AssertionError('QDM job finished with bad return code!')

    def is_retryable(self):
        """Return True if the failed job may be retried from scratch."""
        return self._retryable

    def allow_reinit(self, flag):
        """Allow or forbid to reconnect and reinit the session on RPC errors."""
        self._allow_reinit = flag


class QdmUploadJob(QdmJob):
    """Upload session with the QDM server.

    Holds the upload spec (vm spec and the list of files to upload) and wraps
    the RPC messages used during an upload.
    """

    # Seconds a cached TVM ticket is reused before a new one is requested
    tvm_ticket_ttl = 3600

    def __init__(self, qdm_host, qdm_port, qdm_key):
        super(QdmUploadJob, self).__init__(qdm_host, qdm_port)

        self._key = qdm_key
        self._rev_key = None

        self._tvm_mds_ticket = None
        self._get_tvm_ticket_lock = threading.Lock()

        self._vmspec = None
        self._filemap = []

        self.log = logging.getLogger('uploadjob')

    @staticmethod
    def _require(condition, message):
        """Raise AssertionError(message) unless ``condition`` is true.

        Used instead of ``assert`` so that the validation is kept under
        ``python -O``; the exception type stays the same.
        """
        if not condition:
            raise AssertionError(message)

    def load_spec(self, spec_raw):
        """Parse and load a JSON upload spec (only version 1 is supported).

        Raises:
            AssertionError: if the spec is malformed.
            Exception: if the vm spec can not be serialized with msgpack.
        """
        spec = json.loads(spec_raw)

        self._require('qdm_spec_version' in spec, 'qdm_spec_version key not found in spec')
        self._require(spec['qdm_spec_version'] == 1, 'only qdm_spec_version=1 is supported')

        self._load_spec_v1(spec)

    def _load_spec_v1(self, spec):
        """Validate a version 1 spec and store its vm spec and filemap."""
        self._require('vmspec' in spec, 'vmspec key not found in spec')
        self._require('filemap' in spec, 'filemap key not found in spec')

        vmspec = spec['vmspec']
        filemap = spec['filemap']

        self._require(isinstance(vmspec, (dict, type(None))), 'vmspec should be a dict or null')
        self._require(isinstance(filemap, (list, tuple)), 'filemap should be a list')

        try:
            msgpack.dumps(vmspec)
        except Exception as ex:
            raise Exception('vmspec should be msgpack-dumpable, but it is not') from ex

        self._vmspec = vmspec

        self._require(len(filemap) >= 1, 'filemap should not be empty')

        for fileinfo in filemap:
            self._require(isinstance(fileinfo, dict), 'filemap entry should be a dict')

            self._require('path' in fileinfo, 'path key not found in filemap entry')
            self._require('source' in fileinfo, 'source key not found in filemap entry')

            path = fileinfo['path']
            source = fileinfo['source']
            meta = fileinfo.get('meta', {})

            self._require(isinstance(path, str), 'path should be a string')
            self._require(isinstance(source, str), 'source should be a string')

            # A missing or null meta is stored as an empty dict
            if meta is None:
                meta = {}
            self._require(isinstance(meta, dict), 'meta should be a dict or null')

            self._filemap.append({
                'path': path,
                'source': Path(source),
                'meta': meta
            })

        self._require(len(self._filemap) >= 1, 'filemap should not be empty')

        self.log.info('Loaded upload spec v1')
        self.log.info('  filemap:')

        for fileinfo in self._filemap:
            self.log.info('    - %s (src: %s)', fileinfo['path'], fileinfo['source'])
            self.log.info('      (meta %r)', fileinfo['meta'])

        if self._vmspec is None:
            self.log.info('  vmspec: None')
        else:
            self.log.info('  vmspec:')
            for line in yaml.dump(self._vmspec, default_flow_style=False).split('\n'):
                self.log.info('    %s', line)

    def generate_spec_v0(self, cwd, paths):
        """Build a JSON spec (version 1, no vm spec) for plain file names.

        Args:
            cwd: directory the file names are relative to.
            paths: list or tuple of file names; they are sorted by name.

        Returns:
            The spec as a JSON string suitable for ``load_spec``.
        """
        self._require(isinstance(cwd, str), 'cwd should be a string')
        self._require(isinstance(paths, (list, tuple)), 'paths should be a list')

        cwd = Path(cwd)

        # for auto-generated spec we always sort file list
        filemap = [
            {
                'path': fn,
                'source': str(cwd.joinpath(fn)),
                'meta': {}
            }
            for fn in sorted(paths)
        ]

        return json.dumps({
            'qdm_spec_version': 1,
            'vmspec': None,
            'filemap': filemap
        })

    def init(self):
        """Open (or continue) the upload session and remember its revision key.

        Raises:
            rpc_errors.CallFail: if the server rejected the session; the job
                is marked retryable when the server could not find the
                revision to continue.
        """
        self._job = self._cli.call('upload_session', self._key, self._rev_key)

        try:
            self._rev_key = self._get_blocked('init', 'init', self.vmspec)
        except rpc_errors.CallFail as ex:
            self._log.warning('We cant continue this session anymore, will retry whole job')
            if 'Unable to find storage revision to continue' in str(ex):
                self._retryable = True
            raise

    def check_data_block(self, idx, size, hashtype, hash_):
        """Ask the server whether the block is already stored.

        Returns:
            True if the block is known (no upload needed), False otherwise.
        """
        # Run server-side check for specified block/hash/size to determine
        # should we (False) or not (True) to upload this to MDS storage
        return self._get_blocked('check_data_block', 'data_block_checked', idx, size, hashtype, hash_)

    def data_block_store(self, idx, size, hashtype, hash_, mds_key, ttl):
        """Report to the server that the block was uploaded under ``mds_key``."""
        return self._get_blocked(
            'data_block_stored', 'data_block_stored',
            idx, size, hashtype, hash_, mds_key, ttl
        )

    def finalize_job(self, filemap):
        """Store the final filemap, finish the job and return the revision."""
        revision = self._get_blocked('finalize_revision', 'revision_done', filemap)
        self.finish()
        return revision

    def set_progress(self, progress):
        """Send the current progress to the server (no response is awaited)."""
        return self._notify('set_progress', progress)

    def get_tvm_mds_ticket(self, force=False):
        """Return a TVM ticket for MDS, cached for ``tvm_ticket_ttl`` seconds.

        Args:
            force: request a new ticket even if the cached one is still fresh.
        """
        with self._get_tvm_ticket_lock:
            cached = self._tvm_mds_ticket
            if force or not cached or time.time() - cached[1] >= self.tvm_ticket_ttl:
                self._tvm_mds_ticket = (self._get_blocked('get_tvm_ticket', 'tvm_ticket'), time.time())

            return self._tvm_mds_ticket[0]

    def get_volume_io_by_storage_class(self, volume_name):
        """Return the io limit of the storage class of the named volume.

        The value is looked up in ``spec.qemu.io_limits_per_storage`` of the
        vm spec.  Returns 0 if there is no vm spec, no such volume or no limit
        for its storage class.
        """
        if self._vmspec is None:
            return 0

        # Any level of the spec may be missing or null
        q_spec = (self._vmspec.get('spec') or {}).get('qemu') or {}

        for v in q_spec.get('volumes') or []:
            if v.get('name') == volume_name:
                io_lim = q_spec.get('io_limits_per_storage') or {}
                disk_type = v.get('storage_class', '')
                return int(io_lim.get(disk_type) or 0)
        return 0

    def get_estimated_duration(self, factor=1.5):
        """Estimate the upload duration from file sizes and volume io limits.

        Files whose volume has no known io limit do not contribute.

        Returns:
            The sum of ``size / io_limit`` over the files, multiplied by
            ``factor``.
        """
        res = 0
        for f in self._filemap:
            meta = f.get('meta', {})
            volume_name = meta.get('volume_name', '')
            bandwidth = self.get_volume_io_by_storage_class(volume_name)
            if not bandwidth:
                self.log.warning('Not able to get bandwidth for filespec: {}'.format(f))
            else:
                res += f['source'].getsize() / bandwidth
        return res * factor

    @property
    def filemap(self):
        """List of ``{'path', 'source', 'meta'}`` dicts loaded from the spec."""
        return self._filemap

    @property
    def vmspec(self):
        """The vm spec loaded from the upload spec (dict or None)."""
        return self._vmspec


class QdmDownloadJob(QdmJob):
    """Job that opens a ``'download_session'`` on the QDM server for one revision."""

    def __init__(self, qdm_host, qdm_port, rev, vm_id, cluster, node_id):
        """Remember the session parameters; nothing is sent until ``init()``.

        :param qdm_host: server host, passed to the base class.
        :param qdm_port: server port, passed to the base class.
        :param rev: revision key to download.
        :param vm_id: forwarded to the ``'download_session'`` call.
        :param cluster: forwarded to the ``'download_session'`` call.
        :param node_id: forwarded to the ``'download_session'`` call.
        """
        super(QdmDownloadJob, self).__init__(qdm_host, qdm_port)
        self._rev = rev

        # Cached as a (ticket, time obtained) tuple; the lock guards refreshes.
        self._tvm_mds_ticket = None
        self._get_tvm_ticket_lock = threading.Lock()

        # Filled in by init() from the server's 'init' reply.
        self._session_key = None
        self._vm_id = vm_id
        self._cluster = cluster
        self._node_id = node_id

    def init(self):
        """Start the download session and store the session key from the reply.

        :return: always ``True``.
        """
        self._job = self._cli.call(
            'download_session', self._rev, self._session_key,
            self._vm_id, self._cluster, self._node_id
        )
        self._session_key = self._get_blocked('init', 'init')
        return True

    def set_progress(self, progress):
        """Send ``progress`` via ``self._notify`` under the ``'set_progress'`` name."""
        return self._notify('set_progress', progress)

    def get_blockmap(self):
        """Request ``'get_blockmap'`` and return the ``'blockmap'`` reply."""
        return self._get_blocked('get_blockmap', 'blockmap')

    def get_tvm_mds_ticket(self, force=False):
        """Return the TVM ticket, using a cached value when possible.

        The cached ticket is reused for up to 3600 seconds.

        :param force: when true, request a new ticket even if the cached one
            has not expired yet (previously this flag was accepted but
            ignored).
        :return: the ticket value returned by ``_get_blocked``.
        """
        with self._get_tvm_ticket_lock:
            cached = self._tvm_mds_ticket
            expired = not cached or time.time() - cached[1] >= 3600
            if force or expired:
                self._tvm_mds_ticket = (self._get_blocked('get_tvm_ticket', 'tvm_ticket'), time.time())

            return self._tvm_mds_ticket[0]


def get_porto_memory_limit():
    """Return the ``memory_limit_total`` reported by ``portoctl`` for ``self``.

    Best effort: if ``portoctl`` is missing, fails, or prints something that
    is not an integer, ``0`` is returned.

    :return: the limit as ``int``, or ``0`` when it cannot be determined.
    """
    try:
        return int(subproc.check_output(['portoctl', 'get', 'self', 'memory_limit_total']))
    except Exception:
        # Deliberately broad, but not a bare ``except``: KeyboardInterrupt and
        # SystemExit must still propagate instead of being read as "no limit".
        return 0


def get_qdm_server_lsb_hostname(host):
    """Fetch ``/api/v1/hostname`` from ``host`` over HTTP and return the body.

    :param host: host (optionally with port) to query.
    :return: the response text, unmodified.
    :raises requests.HTTPError: on a non-success HTTP status.
    """
    url = 'http://%s/api/v1/hostname' % (host, )
    reply = requests.get(url, timeout=60)
    reply.raise_for_status()
    return reply.text


def write_progress_data(path: str, operation_type: str, progress: int, total_bytes: int, done_bytes: int,
                        start_time=None, duration=None) -> None:
    """Overwrite the progress file at ``path`` with a JSON snapshot.

    The snapshot holds the given values plus the current ``time.time()``
    under ``'time'``. Nothing is written when ``path`` is ``'-'``.

    :param path: progress file path, or ``'-'`` to disable.
    :param operation_type: stored under ``'operation'``.
    :param progress: stored under ``'progress'``.
    :param total_bytes: stored under ``'total_bytes'``.
    :param done_bytes: stored under ``'done_bytes'``.
    :param start_time: stored under ``'start_time'``.
    :param duration: stored under ``'duration'``.
    """
    if path != '-':
        with open(path, 'w') as out:
            json.dump(
                {
                    'operation': operation_type,
                    'progress': progress,
                    'time': time.time(),
                    'total_bytes': total_bytes,
                    'done_bytes': done_bytes,
                    'start_time': start_time,
                    'duration': duration,
                },
                out,
            )


def real_main():
    """Parse the command line and run the requested upload or download job.

    Steps:
      * configure logging (rotating file, or the ``basicConfig`` default
        stream when ``--logfile`` is ``-``);
      * log the parsed arguments (``qdm_key`` is hidden);
      * unless ``--qdm-host-direct`` is given, replace ``args.qdm_host`` with
        the hostname returned by ``get_qdm_server_lsb_hostname``;
      * run the upload job when ``args.operation == 'upload'``, otherwise the
        download job, mirroring every progress event to the progress file
        and to the server.

    If connecting to the server fails the process exits with status 1 via
    ``os._exit``. If the job fails and reports itself as retryable the
    process re-executes itself with ``os.execv``.

    :raises KeyboardInterrupt: after a non-retryable failure has been logged,
        so that the caller does not print the traceback a second time.
    """
    args = parse_args()

    log_format = '%(asctime)s %(levelname)-8s [%(name)-16s]  %(message)s'

    if args.logfile != '-':
        handler = logging.handlers.TimedRotatingFileHandler(
            args.logfile,
            when='midnight',
            backupCount=7,
        )

        logging.basicConfig(
            level=logging.DEBUG,
            format=log_format,
            handlers=[handler]
        )
    else:
        logging.basicConfig(
            level=logging.DEBUG,
            format=log_format
        )

    log = logging.getLogger('main')

    # Disable DEBUG requests logging
    logging.getLogger('requests').setLevel(logging.WARNING)
    logging.getLogger('urllib3').setLevel(logging.WARNING)

    log.info('Initialized QDM Client')
    log.info('Arguments:')
    for key, value in sorted(args.__dict__.items()):
        if key == 'qdm_key':
            value = '<hidden>'
        elif key == 'logfile':
            # Pointless to log logfile path :)
            continue

        log.info('  %s: %s', key, value)

    if not args.qdm_host_direct:
        log.info('Attempt to convert LSB hostname (%s) to direct...', args.qdm_host)
        args.qdm_host = get_qdm_server_lsb_hostname(args.qdm_host)
        log.info('  converted to %s', args.qdm_host)

    if args.operation == 'upload':
        qdm_upload_job = QdmUploadJob(args.qdm_host, args.qdm_port, args.qdm_key)

        if args.spec and args.path:
            log.warning('--spec was specified, thus all paths (%r) are ignored', args.path)

        if args.spec and args.cwd:
            log.warning('--spec was specified together with --cwd, thus --cwd was ignored')

        if not args.spec:
            if args.cwd is None:
                args.cwd = '.'
            spec = qdm_upload_job.generate_spec_v0(args.cwd, args.path)
        else:
            spec = args.spec

        qdm_upload_job.load_spec(spec)

        if not qdm_upload_job.connect():
            os._exit(1)

        try:
            qdm_upload_job.init()

            uploader = DataUploader(
                qdm_upload_job, args.mds_use_proxy
            )

            finished = False

            # Choose amount of buffers wisely
            log.info('Determining how much memory we can use...')

            if args.background:
                buffers_count = 1
                log.info('  will use 1 buffer (background working mode)')
            else:
                memory_limit = get_porto_memory_limit()

                if memory_limit == 0:
                    log.info('  detected no memory limit')
                    buffers_count = int(MAX_BUFFERS / 2)
                else:
                    log.info('  detected %sG memory limit', int(memory_limit / GB))
                    buffers_count = int(memory_limit / (512 * MB)) - 2  # each buff 512Mb, count 2 less just for sake
                    buffers_count = max(MIN_BUFFERS, buffers_count)     # cap min
                    buffers_count = min(MAX_BUFFERS, buffers_count)     # cap max

                log.info('  will use %d buffers', buffers_count)

            done_blocks_percent = 0
            start_time, duration = None, None
            for progress in uploader.upload(io_threads=2, upload_threads=12, buffers_count=buffers_count):
                if progress[0] == 'eta':
                    start_time, duration = progress[1], progress[2]
                if progress[0] == 'done_block':
                    if progress[1] < 100:
                        log.debug('[%2d%%] PROGRESS: done block %d', progress[1], progress[2])
                    else:
                        finished = True
                        log.debug('[FIN] PROGRESS: done block %d', progress[2])
                    done_blocks_percent = progress[1]
                else:
                    if not finished:
                        log.debug('[   ] PROGRESS: %r', progress)
                    else:
                        done_blocks_percent = 100
                        log.debug('[FIN] PROGRESS: %r', progress)

                if progress[0] == 'finish':
                    # Print final rev id with qdm: prefix, it will be saved in YP info as-is
                    done_blocks_percent = 100
                    sys.stdout.write('qdm:%s\n' % (progress[1], ))

                write_progress_data(
                    args.progress_file, args.operation, int(done_blocks_percent),
                    uploader.total_size, int((uploader.total_size * done_blocks_percent)/100),
                    start_time=start_time, duration=duration
                )
                qdm_upload_job.set_progress(progress)
        except Exception as ex:
            log.critical('Got error: %s: %s', type(ex).__name__, ex)
            for line in traceback.format_exc().split('\n'):
                if line:
                    log.critical(line)

            if qdm_upload_job.is_retryable():
                log.info('  above error indicated as retryable... restarting whole job')
                os.execv(sys.argv[0], sys.argv)

            raise KeyboardInterrupt  # we do not need to print traceback to stderr again
    else:
        if '/' in args.rev:
            args.rev, fn_filter = args.rev.split('/', 1)
            log.info('Will filter only one file: %s', fn_filter)
        else:
            fn_filter = None

        qdm_download_job = QdmDownloadJob(
            args.qdm_host, args.qdm_port,
            args.rev, args.vm_id, args.cluster, args.node_id
        )
        if not qdm_download_job.connect():
            os._exit(1)

        try:
            qdm_download_job.init()

            downloader = DataDownloader(qdm_download_job, args.dest)

            finished = False

            # Initialised once, before the loop (as in the upload branch), so
            # that events other than 'done_block' keep the last reported
            # percentage instead of rewriting the progress file with 0%.
            done_blocks_percent = 0
            for progress in downloader.download(
                filter_filenames=[fn_filter] if fn_filter is not None else None,
                worker_threads=4
            ):
                if progress[0] == 'done_block':
                    if progress[1] < 100:
                        log.debug('[%2d%%] PROGRESS: done block %d', progress[1], progress[2])
                    else:
                        finished = True
                        log.debug('[FIN] PROGRESS: done block %d', progress[2])
                    done_blocks_percent = progress[1]
                else:
                    if not finished:
                        log.debug('[   ] PROGRESS: %r', progress)
                    else:
                        done_blocks_percent = 100
                        log.debug('[FIN] PROGRESS: %r', progress)

                write_progress_data(
                    args.progress_file, args.operation, int(done_blocks_percent),
                    downloader.total_size, int((downloader.total_size * done_blocks_percent)/100),
                )
                qdm_download_job.set_progress(progress)
        except Exception as ex:
            log.critical('Got error: %s: %s', type(ex).__name__, ex)
            for line in traceback.format_exc().split('\n'):
                if line:
                    log.critical(line)

            if qdm_download_job.is_retryable():
                log.info('  above error indicated as retryable... restarting whole job')
                os.execv(sys.argv[0], sys.argv)

            raise KeyboardInterrupt  # we do not need to print traceback to stderr again


def main():
    """Run ``real_main`` and turn ``KeyboardInterrupt`` into exit status 1.

    On interruption ``'Interrupted'`` is logged through the ``'main'``
    logger (falling back to stderr if logging itself fails) and the process
    is terminated with ``os._exit(1)``.

    :return: whatever ``real_main`` returns when it completes normally.
    """
    try:
        outcome = real_main()
    except KeyboardInterrupt:
        try:
            logging.getLogger('main').error('Interrupted')
        except:
            # Catch-all on purpose: whatever goes wrong while logging, the
            # process must still reach os._exit(1) below.
            sys.stderr.write('Interrupted\n')
            sys.stderr.flush()
        os._exit(1)
    else:
        return outcome
