from sandbox import sdk2
from sandbox.common.types import task
import sandbox.projects.common.constants as consts
import sandbox.projects.kwyt.resources as res
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from os.path import join as pj

import logging
import os
import shutil
import subprocess


_KWYTQL_BIN_PATH = 'robot/kwyt/tools/kwytql'


class KwytqlTask(sdk2.Task):
    """
    Start kwytql task.

    If no prebuilt kwytql binary resource is supplied, the binary is first
    built by a YA_MAKE_2 child task. After the child finishes, the query texts
    from the parameters are written to files and kwytql is run with them.
    """

    class Requirements(sdk2.Task.Requirements):
        pass

    class Parameters(sdk2.Task.Parameters):
        arcadia_url = sdk2.parameters.ArcadiaUrl(
            "Svn url for arcadia (you can add '@<commit_number>')",
            required=True,
            default_value='arcadia:/arc/trunk/arcadia'
        )

        kwytqlBinResource = sdk2.parameters.Resource(
            "Resource id of kwytql (if empty - will be built)",
            resource_type=res.KwytqlBin
        )
        yqlToken = sdk2.parameters.YavSecret("YQL token", default="", multiline=False)
        yqlTokenKey = sdk2.parameters.String("YQL token key", default="yql.token", multiline=False)
        pool = sdk2.parameters.String("Pool", default="", multiline=False)
        query = sdk2.parameters.String("Query", default="", multiline=True)
        preQuery = sdk2.parameters.String("Pre-query", default="", multiline=True)
        finalQuery = sdk2.parameters.String("Final query", default="", multiline=True)
        shardsList = sdk2.parameters.String("Shards list", default="", multiline=False)
        kwytPath = sdk2.parameters.String("Path to kwyt", default="//home/kwyt/pages/", multiline=False)
        twoOperations = sdk2.parameters.Bool("Use two simultaneous operations")
        cluster = sdk2.parameters.String("Cluster", default="arnold", multiline=False)

    def buildKwytql(self, kwytqlBinPath):
        """
        Create and enqueue a YA_MAKE_2 child task that builds kwytql.

        :param kwytqlBinPath: arcadia-relative path of the kwytql target;
            the built artifact is expected at ``<kwytqlBinPath>/kwytql``.
        :return: id of the enqueued child task.
        """
        subtask = sdk2.Task["YA_MAKE_2"](
            self,
            description="Child of {}".format(self.id),
            owner=self.owner,
            checkout_arcadia_from_url=self.Parameters.arcadia_url,
            targets=kwytqlBinPath,
            arts=pj(kwytqlBinPath, 'kwytql'),
            result_single_file=True,
            result_rt='KWYTQL_BIN',
            result_rd='kwytql binary',
            use_aapi_fuse=True,
            aapi_fallback=True,
            build_system=consts.SEMI_DISTBUILD_BUILD_SYSTEM,
            checkout_mode='auto'
        )
        subtask.enqueue()

        return subtask.id

    class Context(sdk2.Context):
        # Error messages about child tasks that did not succeed.
        errors = []
        # Informational messages about child tasks that succeeded.
        messages = []
        # Label ('KWYTQL_BIN') -> id of the child task that produced it.
        subtasks = {}

    def checkSubtaskStatus(self, subtask):
        """
        Record the outcome of a finished child task in the context.

        The child is awaited with ``FINISH | BREAK``. Besides FAILURE and
        EXCEPTION, any status from the BREAK group is therefore a failure
        too. Otherwise a broken build would only surface later, as a missing
        resource.
        """
        failed = (
            subtask.status in (task.Status.FAILURE, task.Status.EXCEPTION)
            or subtask.status in task.Status.Group.BREAK
        )
        if failed:
            self.Context.errors.append('Subtask {} failed'.format(subtask.id))
        else:
            self.Context.messages.append('Subtask {} have finished successfully'.format(subtask.id))

    def _resolveKwytqlBin(self):
        """
        Return the local filesystem path of the kwytql binary.

        Uses the resource given in the parameters. Otherwise it looks up the
        KWYTQL_BIN resource produced by the build child task.

        :raises SandboxTaskFailureError: if the child produced no such resource.
        """
        if self.Parameters.kwytqlBinResource:
            kwytqlBinRes = self.Parameters.kwytqlBinResource
        else:
            subtaskId = self.Context.subtasks['KWYTQL_BIN']
            kwytqlBinRes = sdk2.Resource.find(
                type='KWYTQL_BIN',
                task_id=subtaskId
            ).first()
            if kwytqlBinRes is None:
                raise SandboxTaskFailureError(
                    'KWYTQL_BIN resource not found for subtask {}'.format(subtaskId))
            logging.info("kwytql resource: {}".format(kwytqlBinRes.id))
        return str(sdk2.ResourceData(kwytqlBinRes).path)

    @staticmethod
    def _writeQueryFile(queriesDir, name, text):
        """Write ``text`` to ``queriesDir/name`` and return that file's path."""
        path = pj(queriesDir, name)
        with open(path, 'w') as queryFile:
            queryFile.write(text)
        return path

    def _buildCmd(self, kwytqlBin, queriesDir):
        """
        Write the query files into ``queriesDir`` and assemble the kwytql argv.

        Optional queries and flags are only passed when the matching
        parameter is non-empty or true.
        """
        params = self.Parameters
        cmd = [
            kwytqlBin,
            '--query', self._writeQueryFile(queriesDir, 'query', params.query)]

        if params.preQuery:
            cmd.extend(['--pre-query', self._writeQueryFile(queriesDir, 'preQuery', params.preQuery)])
        if params.finalQuery:
            cmd.extend(['--final-query', self._writeQueryFile(queriesDir, 'finalQuery', params.finalQuery)])
        if params.shardsList:
            cmd.extend(['--shards-list', params.shardsList])
        if params.kwytPath:
            cmd.extend(['--kwyt-path', params.kwytPath])
        if params.twoOperations:
            cmd.append('--use-two-operations')
        if params.cluster:
            cmd.extend(['--cluster', params.cluster])
        return cmd

    def on_execute(self):
        """
        Build kwytql if needed, then run it with the configured queries.

        :raises SandboxTaskFailureError: if a child task did not succeed or
            the built kwytql resource cannot be found.
        :raises subprocess.CalledProcessError: if kwytql exits non-zero.
        """
        if not self.Parameters.kwytqlBinResource:
            # The memoize stage guarantees the build is enqueued only once,
            # even though execution resumes here after the wait.
            with self.memoize_stage.build_kwytql:
                subtask = self.buildKwytql(_KWYTQL_BIN_PATH)
                self.Context.subtasks['KWYTQL_BIN'] = subtask
                raise sdk2.WaitTask([subtask], task.Status.Group.FINISH | task.Status.Group.BREAK, wait_all=True)

        for subtask in self.find():
            self.checkSubtaskStatus(subtask)

        if self.Context.errors:
            raise SandboxTaskFailureError('One of child tasks failed unexpectedly')

        kwytqlBin = self._resolveKwytqlBin()

        # Pass the secret only to the kwytql process instead of exporting it
        # into this process's global environment.
        env = dict(os.environ)
        env['YQL_TOKEN'] = self.Parameters.yqlToken.data()[self.Parameters.yqlTokenKey]

        queriesDir = pj(os.getcwd(), 'queries')
        # A previous failed attempt may have left the directory behind.
        if os.path.isdir(queriesDir):
            shutil.rmtree(queriesDir)
        os.mkdir(queriesDir)
        try:
            cmd = self._buildCmd(kwytqlBin, queriesDir)
            logging.info("Run cmd: {}".format(' '.join(cmd)))
            subprocess.check_call(cmd, env=env)
        finally:
            shutil.rmtree(queriesDir, ignore_errors=True)
