import logging
import os
import subprocess
import time
from datetime import timedelta
from multiprocessing import cpu_count

from sandbox import sdk2
from sandbox.common.share import skynet_get
from sandbox.common.errors import TaskFailure
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.projects.common.arcadia import sdk
from sandbox.projects.yabs.qa.resource_types import YABS_SERVER_B2B_BINARY_BASE


# Working directory (relative to the task's cwd) into which all bases are downloaded.
BASES_DIR = 'bases'
# Module-level logger shared by the helpers and the task class below.
logger = logging.getLogger(__name__)


def get_unpacked_size(base_path, arcadia_root):
    """Return the number of bytes ``ya tool uc -d`` emits for ``base_path``.

    The tool is taken from the Arcadia checkout at ``arcadia_root``. It is run
    in decompress mode (``-d``), and the byte count of its stdout is treated as
    the unpacked size of the base. Stdout is consumed in fixed-size chunks
    rather than buffered whole (as ``subprocess.check_output`` would), so
    memory use does not grow with the size of the base.

    Raises:
        subprocess.CalledProcessError: if the tool exits with a non-zero status.
    """
    cmd = [
        sdk._get_ya_tool(arcadia_root),
        'tool',
        'uc',
        '-d',
        '-f={}'.format(base_path),
        '-j=4',
    ]
    chunk_size = 1 << 20  # 1 MiB per read keeps memory bounded per worker thread
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    unpacked_size = 0
    try:
        # b'' sentinel works on both py2 (str) and py3 (bytes) pipes.
        for chunk in iter(lambda: proc.stdout.read(chunk_size), b''):
            unpacked_size += len(chunk)
    except BaseException:
        # Mirror check_output: do not leave the decompressor running on error.
        proc.kill()
        raise
    finally:
        proc.stdout.close()
        returncode = proc.wait()
    if returncode:
        raise subprocess.CalledProcessError(returncode, cmd)
    return unpacked_size


def download_base(base_path, rbtorrent, local_path, arcadia, resource_id):
    """Fetch one base, measure its unpacked size and mark its resource ready.

    Args:
        base_path: file name of the base inside ``local_path``.
        rbtorrent: skynet torrent id to download.
        local_path: directory the torrent is downloaded into.
        arcadia: Arcadia checkout root, used to locate ``ya tool uc``.
        resource_id: id of the YABS_SERVER_B2B_BINARY_BASE resource to publish.

    Returns:
        The unpacked size reported by ``get_unpacked_size``.

    Every step logs its failure with a traceback and re-raises the exception.
    """
    started_at = time.time()
    try:
        skynet_get(rbtorrent, local_path, fallback_to_bb=True)
    except Exception as error:
        logger.exception('Cannot download %s due to exception: %s', base_path, error)
        raise
    finally:
        # Logged on both success and failure of the download.
        elapsed = timedelta(seconds=time.time() - started_at)
        logger.info('Download %s task finished, time elapsed %s', base_path, elapsed)

    try:
        size = get_unpacked_size(os.path.join(local_path, base_path), arcadia)
    except Exception as error:
        logger.exception('Cannot get unpacked_size for %s due to exception: %s', base_path, error)
        raise

    try:
        sdk2.ResourceData(YABS_SERVER_B2B_BINARY_BASE[resource_id]).ready()
    except Exception as error:
        logger.exception('Cannot publish resource for %s due to exception: %s', base_path, error)
        raise

    return size


class YabsServerDownloadBases(sdk2.Task):
    """Download yabs_server binary bases by rbtorrent.

    For every entry of the ``base_bstr_info`` parameter, a
    YABS_SERVER_B2B_BINARY_BASE resource is created. The base is then fetched
    into ``BASES_DIR`` by ``download_base`` on a thread pool. Each successfully
    downloaded base gets an ``unpacked_size`` resource attribute and is
    recorded in ``Context.base_resources`` (tag -> resource id). The output
    parameters are filled from Context by the finish hooks.
    """

    class Requirements(sdk2.Requirements):
        # NOTE(review): presumably MiB (i.e. 128 GiB) — confirm against sdk2.Requirements.
        ram = 128 * 1024
        environments = (
            # Provides concurrent.futures, which on_execute imports lazily.
            PipEnvironment('futures'),
        )

        # Empty override of the Caches requirement.
        # NOTE(review): presumably opts the task out of shared caches — confirm.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        kill_timeout = 1 * 60 * 60  # seconds (1 hour)

        # Base tag -> {'name': <file name>, 'torrent': <rbtorrent>, 'mtime': ...}.
        base_bstr_info = sdk2.parameters.JSON('Bases config to download')

        with sdk2.parameters.Group('Misc') as misc:
            # Upper bound on worker threads; also capped by CPU count and number of bases.
            n_jobs = sdk2.parameters.Integer('Process pool size to use for download', default=8)

        with sdk2.parameters.Output:
            # Base tag -> YABS_SERVER_B2B_BINARY_BASE resource id.
            base_resources = sdk2.parameters.JSON('Dict with bases')
            # Sorted tags from base_bstr_info that have no published resource.
            failed_to_download_bases = sdk2.parameters.JSON('List of bases that task failed to download')

    @staticmethod
    def _dict_or_empty(value):
        """Return ``value`` if it is a dict, otherwise an empty dict.

        Guards against unset Context values / JSON parameters (e.g. when a
        finish hook runs before on_execute populated Context).
        """
        return value if isinstance(value, dict) else {}

    def _failed_bases(self):
        """Return the sorted tags from ``base_bstr_info`` missing in ``Context.base_resources``."""
        requested = self._dict_or_empty(self.Parameters.base_bstr_info)
        downloaded = self._dict_or_empty(self.Context.base_resources)
        return sorted(set(requested) - set(downloaded))

    def _create_base_resource(self, base_name, info):
        """Create the output resource for one base; download_base marks it ready later."""
        return YABS_SERVER_B2B_BINARY_BASE(
            self,
            description='{} production base'.format(base_name),
            path=os.path.join(BASES_DIR, info['name']),
            generation_id=self.id,
            tag=base_name,
            mtime=info['mtime'],
            torrent=info['torrent'],
            source='production',
            ttl=4,
        )

    def publish_results(self):
        """Copy download results from Context into the output parameters.

        Guarded by ``memoize_stage`` so that, presumably, the body runs only
        once even though several finish hooks call it.
        """
        with self.memoize_stage.publish_results:
            self.Parameters.base_resources = self._dict_or_empty(self.Context.base_resources)
            self.Parameters.failed_to_download_bases = self._failed_bases()

    def on_break(self, *args, **kwargs):
        self.publish_results()

    def on_failure(self, *args, **kwargs):
        self.publish_results()

    def on_success(self, *args, **kwargs):
        self.publish_results()

    def on_execute(self):
        """Download every configured base in parallel and publish it as a resource.

        Raises:
            TaskFailure: if at least one base from ``base_bstr_info`` was not
                downloaded and published.
        """
        # Imported lazily: the module is supplied by the `futures` pip
        # environment declared in Requirements.
        from concurrent.futures import ThreadPoolExecutor, as_completed

        bases_info = self._dict_or_empty(self.Parameters.base_bstr_info)
        base_resources = {}
        self.Context.base_resources = {}
        if not os.path.isdir(BASES_DIR):
            os.makedirs(BASES_DIR)

        # ThreadPoolExecutor raises ValueError for max_workers <= 0 (e.g. an
        # empty config), so always keep at least one worker.
        max_workers = max(1, min(self.Parameters.n_jobs, cpu_count(), len(bases_info)))
        with ThreadPoolExecutor(max_workers=max_workers) as pool, sdk.mount_arc_path("arcadia-arc:/#trunk") as arcadia:
            future_to_base = {}
            for base_name, info in bases_info.items():
                resource_id = self._create_base_resource(base_name, info).id
                future = pool.submit(download_base, info['name'], info['torrent'], BASES_DIR, arcadia, resource_id)
                future_to_base[future] = (base_name, resource_id)

            for future in as_completed(future_to_base):
                base_name, resource_id = future_to_base[future]
                exc = future.exception()
                if exc is not None:
                    # download_base logs the traceback inside its own except
                    # blocks; exc_info=True here (outside any except block)
                    # would not log this exception's traceback.
                    logger.error('Base %s is broken: %s', base_name, exc)
                    continue

                unpacked_size = future.result()
                self.server.resource[resource_id].attribute.create(name='unpacked_size', value=unpacked_size)

                base_resources[base_name] = resource_id
                # Reassigned after each success so publish_results sees partial
                # progress if a later step raises.
                self.Context.base_resources = base_resources

        failed_to_download_bases = self._failed_bases()
        if failed_to_download_bases:
            raise TaskFailure('Failed to download bases: {}'.format(failed_to_download_bases))
