# -*- coding: utf-8 -*-

import logging
import os
import shutil
import tarfile
import tempfile

from sandbox import sdk2

from sandbox.projects.common.arcadia import sdk as arc
from sandbox.projects.common.constants import constants as arcc
from sandbox.sdk2.helpers import subprocess as sp

from sandbox.common.types import client as ctc

from sandbox.projects.dj.process_wrapper import process_wrapper
from sandbox.projects.dj.resource_types import DjAnaconda


# Module logger; named after the task class so its records are easy to filter.
logger = logging.getLogger(name="DjBuildTorchPackages")


class DjBuildTorchPackages(sdk2.Task):
    """Build PyTorch packages from Arcadia"""

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.LINUX_PRECISE | ctc.Tag.LINUX_TRUSTY | ctc.Tag.LINUX_XENIAL

    class Parameters(sdk2.Task.Parameters):
        checkout_arcadia_from_url = sdk2.parameters.ArcadiaUrl("Arcadia URL", required=True)
        # Paths (relative to the Arcadia root) of directories that contain a setup.py.
        packages_paths = sdk2.parameters.List('Packages paths', sdk2.parameters.String, required=True,
                                              default=["ads/pytorch/packages/ads_pytorch",
                                                       "ads/pytorch/packages/pytorch_embedding_model",
                                                       "dj/torch/djtorch"])
        # Name of the sdk2 resource type used for the produced tar archive.
        packages_resource_type = sdk2.parameters.String('Packages resource type', required=True,
                                                        default="OTHER_RESOURCE")

    def python_command(self, anaconda_directory):
        """Unpack the newest stable DjAnaconda resource and return its interpreter path.

        :param anaconda_directory: existing directory the Anaconda archive is extracted into.
        :return: path to ``bin/python`` inside ``anaconda_directory``. This assumes the
            archive stores the interpreter at that relative location (not verified here).
        :raises Exception: if no DjAnaconda resource with ``released=stable`` exists.
        """
        # The highest resource id is taken as the most recent stable release.
        anaconda = sdk2.Resource.find(type=DjAnaconda, attrs={"released": "stable"}).order(-sdk2.Resource.id).first()
        if not anaconda:
            raise Exception('Anaconda not found')
        command = ["tar", "-xvf", str(sdk2.ResourceData(anaconda).path), "-C", anaconda_directory]
        logger.info("Running command: %s", command)
        with process_wrapper(self, logger='tar_anaconda') as pl:
            sp.check_call(command, stdout=pl.stdout, stderr=pl.stderr)
        return os.path.join(anaconda_directory, "bin", "python")

    def python_include(self, python):
        """Return the C header include directory reported by ``python``'s sysconfig.

        :param python: path to the interpreter to query.
        :return: the include path as a native ``str`` without surrounding whitespace.
        """
        command = [python, "-c", "import sysconfig; print(sysconfig.get_paths()['include'])"]
        logger.info("Running command: %s", command)
        with process_wrapper(self, logger='python_include') as pl:
            output = sp.check_output(command, stderr=pl.stderr)
        # check_output returns bytes on Python 3 and always includes the newline emitted
        # by print(); both would corrupt the value once it is stored in os.environ.
        if not isinstance(output, str):
            output = output.decode("utf-8")
        return output.strip()

    def _build_wheel(self, python, arcadia_dir, package_path, package_index, build_directory, result_directory):
        """Run ``setup.py bdist_wheel`` for one package and copy the wheels to ``result_directory``.

        ``build_directory`` is created before the build and removed after a successful
        collection, so each package starts from an empty dist directory.

        :raises Exception: if the build produced no ``.whl`` file.
        """
        package_dir = os.path.join(arcadia_dir, package_path)
        os.makedirs(build_directory)
        command = [python, os.path.join(package_dir, "setup.py"),
                   "bdist_wheel", "--dist-dir", build_directory]
        logger.info("Running command: %s", command)
        with process_wrapper(self, logger='python_{}_wheel'.format(package_index)) as pl:
            sp.check_call(command, cwd=package_dir, stdout=pl.stdout, stderr=pl.stderr)
        packages = []
        for name in os.listdir(build_directory):
            source = os.path.join(build_directory, name)
            if os.path.isfile(source) and name.endswith('.whl'):
                packages.append(name)
                shutil.copy(source, os.path.join(result_directory, name))
        if not packages:
            raise Exception('Package not found')
        shutil.rmtree(build_directory)

    def _pack_result(self, result_directory):
        """Create the output resource and tar every file of ``result_directory`` into it.

        Entries are stored under their base names, in sorted order, so the archive
        layout does not depend on directory listing order.
        """
        logger.info("Creating result tar archive")
        resource = sdk2.Resource[self.Parameters.packages_resource_type](self, "Tar archive with PyTorch packages", "packages.tar")
        resource_data = sdk2.ResourceData(resource)
        with tarfile.open(str(resource_data.path), 'w') as tar:
            for name in sorted(os.listdir(result_directory)):
                tar.add(os.path.join(result_directory, name), arcname=name)

    def on_execute(self):
        """Build a wheel for every configured package and publish them as one tar resource.

        All intermediate files live in a temporary directory that is removed on exit,
        whether the build succeeds or fails.
        """
        temp_directory = tempfile.mkdtemp()
        try:
            anaconda_directory = os.path.join(temp_directory, "anaconda")
            build_directory = os.path.join(temp_directory, "build")
            result_directory = os.path.join(temp_directory, "result")

            os.makedirs(result_directory)
            os.makedirs(anaconda_directory)

            python = self.python_command(anaconda_directory)

            with arc.mount_arc_path(self.Parameters.checkout_arcadia_from_url) as arcadia_dir:
                # Exported for the child setup.py processes, which inherit os.environ.
                # Presumably the package build scripts read these; confirm against setup.py.
                os.environ["ARCADIA_ROOT"] = arcadia_dir
                os.environ["PYTORCH_EMBEDDING_PYTHON_INCLUDE_PATH"] = self.python_include(python)

                logger.info("ARCADIA_ROOT=%s", os.environ["ARCADIA_ROOT"])
                logger.info("PYTORCH_EMBEDDING_PYTHON_INCLUDE_PATH=%s", os.environ["PYTORCH_EMBEDDING_PYTHON_INCLUDE_PATH"])

                # NOTE(review): an empty .arc directory is created when missing, presumably
                # because the build tooling looks for it to locate the Arcadia root — confirm.
                arc_marker = os.path.join(arcadia_dir, ".arc")
                if not os.path.exists(arc_marker):
                    os.mkdir(arc_marker)

                for package_index, package_path in enumerate(self.Parameters.packages_paths):
                    self._build_wheel(python, arcadia_dir, package_path, package_index,
                                      build_directory, result_directory)

            self._pack_result(result_directory)

            logger.info("Finished")
        finally:
            shutil.rmtree(temp_directory)
