from sandbox import sdk2
from sandbox.sandboxsdk.process import run_process  # maybe sdk2.helpers.ProcessLog?
from sandbox.projects.common import error_handlers as eh
from sandbox.common.types import task as ctt
from sandbox.projects.common import apihelpers

from sandbox.projects import resource_types as rst
from sandbox.projects.common.utils import set_resource_attributes

from sandbox import common
import json

import sandbox.projects.common.constants as consts

import logging

from os.path import join as pj
import os
import shutil

from sandbox.sandboxsdk import environments

import datetime

# Logos task name -> profile logs handled by that task. For every log the task
# works with daily tables <yt_working_dir>/home/bs/logs/<log>/1d/<YYYY-MM-DD>:
# the table dated one step before the first date is the start state, the tables
# dated at the last date are the final result checked at the end of the run.
TASK_TO_LOGS = {
    "LtpProfilePrestableTask": [
        "LtpProfilePrestable",
        "LtpProfileBuildStatePrestable",
    ],
    "LtpProfileTask": [
        "LtpProfileWithOldVectors",
        "LtpProfileBuildState",
    ],
    "LtpAdvProfileTask": [
        "LtpAdvProfile",
    ],
}

# Logos task name -> action logs passed to `logos_tool autorun-tasks` as
# `--prod-inputs <log>=<from>..<to>` (presumably taken from production rather
# than recomputed by the dev graph).
TASK_TO_ACTION_LOGS = {
    "LtpProfilePrestableTask": [
        "LtpSearchClicks",
        "AdsUserSessionPrestableClickedDocs",
        "AdsUserSessionPrestableQueries",
        "AdsUserSessionPrestableMiscEvents",
    ],
    "LtpProfileTask": [
        "LtpSearchClicks",
        "AdsUserSessionClickedDocs",
        "AdsUserSessionQueries",
        "AdsUserSessionMiscEvents",
    ],
    "LtpAdvProfileTask": [
        "LtpRsyaClicks",
        "LtpRsyaConversions",
    ],
}

# Mode identifiers. USUAL and TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK are the
# values offered by the "Mode" task parameter; CUSTOM_YT_PREFIX_WITH_SRC_START_TABLES
# is not registered as a mode in this file (the custom prefix is a USUAL-mode
# parameter instead).
USUAL = "usual"
CUSTOM_YT_PREFIX_WITH_SRC_START_TABLES = "custom_yt_prefix_with_src_start_tables"
TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK = "take_progress_from_prev_successfull_task"

# Hour-precision timestamp format; used for the logos CLI --first-date/--last-date
# values and for the timestamp embedded in release ids.
HOUR_TABLE_NAME_FORMAT = "%Y-%m-%dT%H:%M:%S"
# Day-precision format of user-supplied dates and of daily (1d) table names.
DATE_FORMAT = "%Y-%m-%d"


def _format_datetime(value, fmt):
    """Render the ``datetime`` *value* with the strftime pattern *fmt*."""
    return datetime.datetime.strftime(value, fmt)


def to_date_datetime(dt_str):
    """Parse a ``YYYY-MM-DD`` string into a naive ``datetime`` at midnight."""
    parsed = datetime.datetime.strptime(dt_str, DATE_FORMAT)
    return parsed


def to_date_datetime_str(dt_dt):
    """Format *dt_dt* as ``YYYY-MM-DD``; any time-of-day part is dropped."""
    return _format_datetime(dt_dt, DATE_FORMAT)


def to_hour_datetime_str(dt_dt):
    """Format *dt_dt* as ``YYYY-MM-DDTHH:MM:SS``."""
    return _format_datetime(dt_dt, HOUR_TABLE_NAME_FORMAT)


def RunProcess(cmd, env, log_prefix=None, exception_if_nonzero_code=True):
    """Run *cmd* through the shell and report a non-zero exit code.

    The elements of *cmd* are stringified and joined with single spaces, then
    executed by sandboxsdk ``run_process`` with ``shell=True`` and
    ``check=False`` (the exit code is inspected here instead).

    Args:
        cmd: sequence of command-line fragments.
        env: environment mapping for the child process.
        log_prefix: prefix for the process log files.
        exception_if_nonzero_code: if true, raise on a non-zero exit code;
            otherwise return the failure description.

    Returns:
        ``""`` on success. On failure (when not raising) a non-empty message with
        the exit code and the last 30 lines of stderr ("error") and stdout
        ("result") when their log paths are available.

    Raises:
        Exception: on a non-zero exit code if *exception_if_nonzero_code* is true.
    """
    cmd_str = ' '.join(str(cmd_elem) for cmd_elem in cmd)
    process = run_process(cmd_str,
                          wait=True,
                          outs_to_pipe=False,
                          close_fds=True,
                          check=False,
                          shell=True,
                          log_prefix=log_prefix,
                          environment=env
                          )
    process.communicate()
    if process.returncode == 0:
        return ""

    # Always start with the exit code: callers treat an empty return value as
    # success, so the failure message must never be empty even when no log
    # paths are available.
    exc_msg = 'return code: {}\n'.format(process.returncode)
    for label, log_path in (('error', process.stderr_path), ('result', process.stdout_path)):
        if log_path is None:
            continue
        with open(log_path, 'r') as log_file:
            # readlines() keeps the line terminators, so join with '' (joining
            # with '\n' double-spaced the output).
            tail = ''.join(log_file.readlines()[-30:])
        exc_msg += '{}: {}\n'.format(label, tail)

    if exception_if_nonzero_code:
        raise Exception(exc_msg)
    return exc_msg


class RecalcLtp(sdk2.Task):
    """Recalculate LTP profiles over a date range using a dev logos graph.

    ``on_execute`` may run several times for one task (it raises
    ``sdk2.WaitTask`` while the build subtask runs), so every one-shot step is
    guarded by a flag stored in ``Context``:

    1. ``Initialize`` resolves dates, release id and the YT working dir;
    2. a cached ``logos_tool`` binary is looked up, otherwise a YA_MAKE subtask
       builds it;
    3. ``Prepare`` puts the start tables into the YT working dir;
    4. ``logos_tool autorun-tasks`` is launched once; afterwards the final
       tables are checked and, if all of them exist, the reactor graph is
       deleted (unless ``keep_reactor_graph`` is set).
    """

    class Requirements(sdk2.Requirements):
        environments = [
            environments.PipEnvironment('yandex-yt', version='0.10.8'),
        ]

    class Parameters(sdk2.Parameters):
        with sdk2.parameters.String("Mode", required=True) as mode:
            mode.values[USUAL] = USUAL
            mode.values[TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK] = TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK

        with mode.value[TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK]:
            prev_task_from_which_to_continue = sdk2.parameters.Task("prev_task_from_which_to_continue. Id of successfull task from which we can take progress and continue", required=True)

        with mode.value[USUAL]:
            custom_prefix_with_tables_from_which_to_start = sdk2.parameters.String(
                "For neshtat case: yt prefix with src tables from which we should start. For example if "
                " previous task is not entirely successfull but had progress which you don't want to lose", default="")

            with sdk2.parameters.String(
                "logos task name", required=True
            ) as TaskName:
                for task_name in TASK_TO_LOGS.keys():
                    TaskName.values[task_name] = task_name

            logos_release_id_prefix = sdk2.parameters.String("logos_release_id_prefix. Datetime-suffix will be added. Reactor path will be "
                                                             " /logos/graphs/ads/dev/<login>/<logos_release_id_prefix><suf>)", required=True)

        login = sdk2.parameters.String("login. Logos and Yt user login: will be used for forming logos-token-secret-name in nirvana-vault <login>_logos_token. And also will be used for forming "
                                       " working directory on Yt //home/ads/logos/dev/<login>/<logos_release_id>", required=True)

        yt_pool = sdk2.parameters.String("Yt pool. <login> has to have access to it. If not specified then default pool will be used", default="")
        nirvana_quota = sdk2.parameters.String("Nirvana quota. <login> has to have access to it. If not specified then default one will be used", default="")
        arcadia_url = sdk2.parameters.ArcadiaUrl('Arcadia url for binaries (must be with @<rev>). For example: arcadia:/arc/trunk/arcadia@<rev>', required=True)
        arcadia_patch = sdk2.parameters.String("Arcadia patch", default="")

        keep_reactor_graph = sdk2.parameters.Bool("Keep reactor logos dev graph", default=False)

        with sdk2.parameters.Group(
                "Range of dates to recalc [first, last] with step, but first and last will be included and (last-first) div step must be 0. Step must be 1 or 15. "
                "If 15, then first and last has to be in the 15-grid {2021-02-16, 2021-03-03, etc.} - grid can be seen in logos-prod-path."
                " Src start table with be with date first-step.") as range_token_block:
            first_datetime_str = sdk2.parameters.String("First for profile in format YYYY-MM-DD. ", required=True)
            last_datetime_str = sdk2.parameters.String("Last for profile in format YYYY-MM-DD. ", required=True)
            step = sdk2.parameters.Integer("Step", default=15)

        with sdk2.parameters.Group("Logos token - has to be the token of <login>") as logos_token_block:
            logos_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with logos token. If empty - task's owner will be used", default="")
            logos_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with logos token", required=True)

        with sdk2.parameters.Group("Yt token - any token able to write to //home/ads/logos/dev/<login>") as yt_token_block:
            yt_token_secret_owner = sdk2.parameters.String("Owner of sb-vault-secret with yt token. If empty - task's owner will be used", default="")
            yt_token_secret_name = sdk2.parameters.String("Name of sb-vault-secret with yt token", required=True)

    class Context(sdk2.Context):
        cluster = "hahn"  # YT cluster for the YT client and the reactor task path
        initialized = False  # set once Initialize() has completed
        prepared = False  # set once Prepare() has completed
        build_logos_subtask_id = None  # YA_MAKE subtask building logos_tool, if one was created
        logos_bin_res_id = None  # ARCADIA_PROJECT resource holding the logos_tool binary
        logos_bin_dir = "logos/projects/ads/graph/bin"  # Arcadia target built by the subtask
        max_history_days = 180  # first date may be at most this many days (+10 slack) old
        logos_launch_happened = False  # set right before launching logos so a restart does not relaunch

    def _Require(self, condition, message):
        """Fail the task with *message* unless *condition* holds.

        Used instead of ``assert`` so the check survives ``python -O`` and the
        user sees a readable reason.
        """
        if not condition:
            eh.check_failed(message)

    def GetEnv(self, need_logos_user_config=False):
        """Build the environment for running ``logos_tool``.

        Copies ``os.environ`` and adds ``YT_TOKEN_PATH`` (requires
        ``ConfigureYtTokenAndCluster`` to have set ``self.yt_token_path``),
        ``LOGOS_TOKEN`` read from the vault, an optional
        ``LOGOS_USER_CUSTOM_CONFIG`` with the step, and ``YT_SPEC`` raising the
        table writer ``max_row_weight`` to 128 MiB.
        """
        env = dict(os.environ)
        env['YT_TOKEN_PATH'] = self.yt_token_path

        secret_owner = self.Parameters.logos_token_secret_owner
        secret_name = self.Parameters.logos_token_secret_name
        env['LOGOS_TOKEN'] = sdk2.Vault.data(secret_owner, secret_name)

        if need_logos_user_config:
            env['LOGOS_USER_CUSTOM_CONFIG'] = "step={},use_profile_with_vectors=False".format(self.Parameters.step)

        yt_spec = {
            "job_io": {
                "table_writer": {
                    "max_row_weight": 128 * 1024 * 1024
                }
            }
        }
        env['YT_SPEC'] = json.dumps(yt_spec)

        return env

    def CreateBuildSubtask(self):
        """Create and enqueue a YA_MAKE subtask building the ``logos_tool`` binary.

        The result is a single-file ARCADIA_PROJECT resource built from
        ``arcadia_url`` (plus ``arcadia_patch``). The subtask's requirements are
        raised via the server API before it is enqueued.

        Returns:
            The enqueued subtask.
        """
        subtask_type = sdk2.Task["YA_MAKE"]

        params = {
                'arch': 'linux',
                'checkout_arcadia_from_url': self.Parameters.arcadia_url,
                'arcadia_patch': self.Parameters.arcadia_patch,
                'targets': self.Context.logos_bin_dir,
                'arts': pj(self.Context.logos_bin_dir, "logos_tool"),
                'result_rt': "ARCADIA_PROJECT",
                'build_system': 'semi_distbuild',
                'build_type': 'release',
                'use_aapi_fuse': True,
                'aapi_fallback': True,
                'check_return_code': True,
                'result_single_file': True,
                'result_ttl': '30',
                'build_output_ttl': 1,
                'build_output_html_ttl': 1,
                'allure_report_ttl': 1,
                consts.STRIP_BINARIES: True,
        }

        subtask = subtask_type(self, description='Building logos_tool binary', **params)
        sdk2.Task.server.task[subtask.id].update({'requirements': {'disk_space': 80737418240, 'ram': 4096}})
        subtask.enqueue()
        return subtask

    def ValidateFirstAndLastAndStep(self):
        """Check the requested date range and step; fail the task otherwise.

        Requires: step is 1 or 15; first <= last; last is earlier than now minus
        one day; first is at most ``max_history_days + 10`` days old; with step
        15 both dates lie on the 15-day grid that contains 2021-07-01.
        """
        step = self.Context.step
        first_str = self.Context.first_ltp_profile_datetime
        last_str = self.Context.last_ltp_profile_datetime

        self._Require(step in [1, 15], "Step must be 1 or 15, got {}".format(step))

        first_dt = to_date_datetime(first_str)
        last_dt = to_date_datetime(last_str)
        now = datetime.datetime.now()
        self._Require(first_dt <= last_dt,
                      "First date {} is after last date {}".format(first_str, last_str))
        self._Require(last_dt < now - datetime.timedelta(days=1),
                      "Last date {} is too recent: it must be earlier than now minus one day".format(last_str))
        max_age_days = self.Context.max_history_days + 10  # 10 is just in case
        self._Require((now - first_dt).days <= max_age_days,
                      "First date {} is more than {} days in the past".format(first_str, max_age_days))

        if step == 15:
            date_defining_15_grid = "2021-07-01"
            dt_dt_from_15_grid = to_date_datetime(date_defining_15_grid)
            for label, date_str, dt in (("Last", last_str, last_dt), ("First", first_str, first_dt)):
                self._Require((dt - dt_dt_from_15_grid).days % 15 == 0,
                              "{} date {} is not on the 15-day grid (which contains {})".format(
                                  label, date_str, date_defining_15_grid))

    def Initialize(self):
        """Resolve the run configuration into ``Context`` (called once per task).

        Copies the date range, step and Arcadia parameters into ``Context``,
        computes ``src_start_date`` (first date minus one step) and validates
        the range.

        In TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK mode the logos task name and
        release-id prefix come from the previous task, whose start tables must
        already exist; its release id, reactor path and YT working dir are
        reused when step, Arcadia url/patch and login all match. Otherwise a new
        release id (prefix, current time, revision, step) and the matching
        reactor path and YT working dir are derived.

        Finally records the final tables (one per log, dated at the last date)
        whose existence is checked at the end of ``on_execute``.

        Raises:
            Exception: if a start table of the previous task is missing.
        """
        import yt.wrapper as yt

        self.Context.login = self.Parameters.login
        self.Context.first_ltp_profile_datetime = self.Parameters.first_datetime_str
        self.Context.last_ltp_profile_datetime = self.Parameters.last_datetime_str
        self.Context.step = self.Parameters.step
        # The recalculation starts from the tables dated one step before the first date.
        self.Context.src_start_date = to_date_datetime_str(
            to_date_datetime(self.Context.first_ltp_profile_datetime) - datetime.timedelta(days=self.Context.step))
        self.ValidateFirstAndLastAndStep()

        self.Context.arcadia_url = self.Parameters.arcadia_url
        self.Context.arcadia_patch = self.Parameters.arcadia_patch or None
        self.Context.revision = self.Parameters.arcadia_url.rsplit('@', 1)[-1]

        if self.Parameters.mode == TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK:
            prev_task = self.Parameters.prev_task_from_which_to_continue

            copy_ctx_params = ["TaskName", "logos_release_id_prefix"]
            for ctx_param in copy_ctx_params:
                setattr(self.Context, ctx_param, getattr(prev_task.Context, ctx_param))

            self._Require(
                to_date_datetime(prev_task.Context.last_ltp_profile_datetime) >= to_date_datetime(self.Context.src_start_date),
                "Previous task {} stopped at {}, which is before the required start date {}".format(
                    prev_task.id, prev_task.Context.last_ltp_profile_datetime, self.Context.src_start_date))

            # The previous graph/release can only be reused if it was built with
            # the same step, sources and user.
            should_be_equal_for_reuse_of_graph = ["step", "arcadia_url", "arcadia_patch", "login"]
            self.Context.take_existing_graph = all(
                getattr(self.Context, ctx_param) == getattr(prev_task.Context, ctx_param)
                for ctx_param in should_be_equal_for_reuse_of_graph)

            for log in TASK_TO_LOGS[self.Context.TaskName]:
                src_table_path = os.path.join(prev_task.Context.yt_working_dir, "home/bs/logs", log, "1d", self.Context.src_start_date)
                if not yt.exists(src_table_path):
                    raise Exception("Not exist {}".format(src_table_path))
        else:
            self.Context.take_existing_graph = False
            self.Context.TaskName = self.Parameters.TaskName
            self.Context.logos_release_id_prefix = self.Parameters.logos_release_id_prefix

        if self.Context.take_existing_graph:
            prev_task = self.Parameters.prev_task_from_which_to_continue
            copy_ctx_params = ["now_from_logos_release_id", "logos_release_id", "reactor_task_path", "yt_working_dir"]
            for ctx_param in copy_ctx_params:
                setattr(self.Context, ctx_param, getattr(prev_task.Context, ctx_param))
        else:
            self.Context.now_from_logos_release_id = to_hour_datetime_str(datetime.datetime.now())
            self.Context.logos_release_id = "{}_{}_r{}_step{}".format(self.Context.logos_release_id_prefix, self.Context.now_from_logos_release_id, self.Context.revision, self.Context.step)
            self.Context.reactor_task_path = "/logos/graphs/ads/dev/{}/{}/tasks/{}/{}".format(self.Parameters.login, self.Context.logos_release_id, self.Context.cluster, self.Context.TaskName)
            self.Context.yt_working_dir = "//home/ads/logos/dev/{}/{}".format(self.Parameters.login, self.Context.logos_release_id)

        self.Context.last_date_ltp_profile_tables = [
            os.path.join(self.Context.yt_working_dir, "home/bs/logs", log, "1d", self.Context.last_ltp_profile_datetime)
            for log in TASK_TO_LOGS[self.Context.TaskName]
        ]

    def Prepare(self):
        """Put the start tables (dated ``src_start_date``) into the YT working dir.

        * TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK mode without graph reuse:
          move the previous task's whole working dir to this task's one. With
          graph reuse nothing needs to be done.
        * USUAL mode: link each log's start table from ``//home/bs/logs`` or,
          if set, from ``custom_prefix_with_tables_from_which_to_start``. A
          missing production start table is replaced by a sorted merge of a
          zero-row range (start_index=end_index=0) of some existing production
          table of that log; a missing custom table is an error.

        Raises:
            Exception: if a start table under the custom prefix is absent.
        """
        import yt.wrapper as yt

        if self.Parameters.mode == TAKE_PROGRESS_FROM_PREV_SUCCESSFULL_TASK:
            if not self.Context.take_existing_graph:
                prev_task = self.Parameters.prev_task_from_which_to_continue
                yt.move(prev_task.Context.yt_working_dir, self.Context.yt_working_dir, recursive=True)
            return

        if self.Parameters.custom_prefix_with_tables_from_which_to_start:
            logs_dir = pj(self.Parameters.custom_prefix_with_tables_from_which_to_start, "home/bs/logs")
            prod_src = False
        else:
            logs_dir = "//home/bs/logs"
            prod_src = True

        for log in TASK_TO_LOGS[self.Context.TaskName]:
            src_table_dir = pj(logs_dir, log, "1d")
            src_table_path = pj(src_table_dir, self.Context.src_start_date)
            dst_table_path = os.path.join(self.Context.yt_working_dir, "home/bs/logs", log, "1d", self.Context.src_start_date)
            if yt.exists(src_table_path):
                yt.link(src_table_path, dst_table_path, recursive=True)
            elif prod_src:
                some_prod_table = pj(src_table_dir, yt.list(src_table_dir)[0])
                # TODO check run_merge
                # NOTE(review): confirm run_merge creates the destination's parent directories.
                # `mode` must be the string "sorted" (it used to be the builtin `sorted`).
                yt.run_merge(yt.TablePath(some_prod_table, start_index=0, end_index=0), dst_table_path, mode="sorted")
            else:
                raise Exception("Table is absent: {}".format(src_table_path))

    def DoSyncBinary(self, res_id, bin_name):
        """Sync resource *res_id* and expose its file as ``<task dir>/<bin_name>``.

        Returns:
            Local path of the binary, named *bin_name*.
        """
        resource = sdk2.Resource.find(id=res_id).limit(1).first()
        synced_path = str(sdk2.ResourceData(resource).path)
        local_path = pj(str(self.path()), bin_name)
        # Do not rename the synced file in place: that mutates the resource cache,
        # so a later sync of the same resource (e.g. after a restart) may not find
        # it. Hard-link when possible (no extra disk), otherwise copy; both keep
        # the executable bit.
        if os.path.lexists(local_path):
            os.remove(local_path)
        try:
            os.link(synced_path, local_path)
        except OSError:
            shutil.copy(synced_path, local_path)
        return local_path

    def ConfigureYtTokenAndCluster(self):
        """Configure the YT client and store the YT token in a private file.

        Sets the YT client token and proxy (``Context.cluster``) and writes the
        token to ``<task dir>/yt_token_file``, whose path is kept in
        ``self.yt_token_path`` for ``GetEnv``.
        """
        secret_owner = self.Parameters.yt_token_secret_owner
        secret_name = self.Parameters.yt_token_secret_name
        yt_token = sdk2.Vault.data(secret_owner, secret_name)

        import yt.wrapper as yt
        yt.config["token"] = yt_token
        yt.config["proxy"]["url"] = self.Context.cluster

        self.yt_token_path = pj(str(self.path()), 'yt_token_file')
        # Write the secret directly rather than via `echo <token> > file` in a
        # shell, which exposed it on the command line and broke on shell
        # metacharacters. Owner-only permissions keep it private; the trailing
        # newline matches what `echo` produced.
        fd = os.open(self.yt_token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        os.chmod(self.yt_token_path, 0o600)  # in case the file already existed with wider permissions
        with os.fdopen(fd, 'w') as token_file:
            token_file.write(yt_token + '\n')

    def GetBinResourceAttrDict(self):
        """Attributes identifying a logos_tool build (used both to tag and to look up the resource)."""
        return {
            "target": self.Context.logos_bin_dir,
            "build_url": self.Context.arcadia_url,

            # "None" is fine. We can't use None and can't omit key "arcadia_patch" - in this cases we might have problems with correct fetching from cache
            "arcadia_patch": str(self.Context.arcadia_patch),
        }

    def TryFindCachedBin(self):
        """Return the id of the latest ARCADIA_PROJECT resource with matching build attributes, or None."""
        resource = apihelpers.get_last_resource_with_attrs(rst.ARCADIA_PROJECT, self.GetBinResourceAttrDict(), all_attrs=True)
        return resource.id if resource else None

    def GetProdInputs(self, first_date_for_logs, last_date_for_logs):
        """Return the ``--prod-inputs`` value: `` <log>=<first>..<last> `` for every action log of the task."""
        return "".join(
            " {log}={first_date_for_logs}..{last_date_for_logs} ".format(
                log=log, first_date_for_logs=first_date_for_logs, last_date_for_logs=last_date_for_logs)
            for log in TASK_TO_ACTION_LOGS[self.Context.TaskName])

    def on_execute(self):
        """Sandbox entry point; may run several times for one task (see the class docstring).

        Raises:
            sdk2.WaitTask: while the logos_tool build subtask is still running.
            common.errors.TaskError: if the build subtask did not succeed.

        Collected errors (logos run failure, missing final tables, failed graph
        deletion) are reported through ``eh.check_failed`` at the end.
        """
        logging.info('RecalcLtpTask: Start')

        self.ConfigureYtTokenAndCluster()

        if not self.Context.initialized:
            self.Initialize()
            self.Context.initialized = True  # has to be set after and not before Initialize() in case Initialize raises exception

        if self.Context.logos_bin_res_id is None and self.Context.build_logos_subtask_id is None:
            self.Context.logos_bin_res_id = self.TryFindCachedBin()
        if self.Context.logos_bin_res_id is None and self.Context.build_logos_subtask_id is None:
            self.Context.build_logos_subtask_id = self.CreateBuildSubtask().id

        if not self.Context.prepared:
            self.Prepare()
            self.Context.prepared = True  # has to be set after and not before Prepare() in case Prepare raises exception

        if self.Context.logos_bin_res_id is None:
            subtask = sdk2.Task[self.Context.build_logos_subtask_id]
            if subtask.status not in ctt.Status.Group.FINISH + ctt.Status.Group.BREAK:
                raise sdk2.WaitTask([subtask], ctt.Status.Group.FINISH | ctt.Status.Group.BREAK, wait_all=True)
            # Any BREAK-group status leaves no usable binary either; checking only
            # FAILURE/EXCEPTION let those fall through to ap_packs below.
            build_failed = (
                subtask.status in (ctt.Status.FAILURE, ctt.Status.EXCEPTION)
                or subtask.status in ctt.Status.Group.BREAK
            )
            if build_failed:
                self.Context.build_logos_subtask_id = None
                raise common.errors.TaskError("logos build subtask failed - its status is {}. You may restart me (the priemka-task) and I will create new build subtask".format(str(subtask.status)))
            res_id = subtask.Context.ap_packs["project"]
            set_resource_attributes(int(res_id), self.GetBinResourceAttrDict())
            self.Context.logos_bin_res_id = res_id

        bin_local_path = self.DoSyncBinary(self.Context.logos_bin_res_id, "logos_tool")

        last_date = to_hour_datetime_str(to_date_datetime(self.Context.last_ltp_profile_datetime))
        # TODO get rid of wait-localy? or not
        cmd = [
            bin_local_path, " autorun-tasks  --subgraph {task_name} --tasks {task_name} --release-id {release_id}  --nirvana-secret-name {login}_logos_token --user {login}"
            " --first-date {first_date} --last-date {last_date}  {yt_pool_option} --no-diff --step {step} --lookup-inputs self {nirvana_quota_option} --prod-inputs "
            " {prod_inputs} --wait-localy".format(
                task_name=self.Context.TaskName,
                release_id=self.Context.logos_release_id,
                login=self.Parameters.login,
                first_date=to_hour_datetime_str(to_date_datetime(self.Context.first_ltp_profile_datetime)),
                last_date=last_date,
                yt_pool_option="--yt-pool {}".format(self.Parameters.yt_pool) if self.Parameters.yt_pool else "",
                step=self.Context.step,
                nirvana_quota_option="--nirvana-quota {}".format(self.Parameters.nirvana_quota) if self.Parameters.nirvana_quota else "",
                prod_inputs=self.GetProdInputs(
                    first_date_for_logs=to_hour_datetime_str(to_date_datetime(self.Context.src_start_date) + datetime.timedelta(days=1)),
                    last_date_for_logs=last_date,
                ),
            )]

        # TODO Maybe we should not wait for it to finish and instead we should regularly check whether the final table exists. I mean we should pass wait=False and poll. But the task
        # has to be executing anyway!!!
        error_msg = ""
        if not self.Context.logos_launch_happened:
            # Persist the flag before launching so a restarted task does not launch logos twice.
            self.Context.logos_launch_happened = True
            self.Context.save()  # dump context
            error_msg = RunProcess(cmd, self.GetEnv(True), log_prefix="run_logos", exception_if_nonzero_code=False)

        import yt.wrapper as yt
        all_exist = True
        for path_to_check in self.Context.last_date_ltp_profile_tables:
            if not yt.exists(path_to_check):
                all_exist = False
                # One parenthesized expression, so .format() fills both placeholders
                # of the concatenated literal (previously the second half was a
                # separate, discarded statement).
                error_msg += (
                    "\n\n\nFinal table {} doesn't exist! Maybe it's still being built. "
                    " You might check in reactor - {}. \n\n\n That's why your logos graph in reactor hasn't been deleted, do it yourself when it's not needed anymore".format(
                        path_to_check, self.Context.reactor_task_path)
                )
        if all_exist and not self.Parameters.keep_reactor_graph:
            cmd = [bin_local_path,
                   " delete-release --release-id {release_id} --nirvana-secret-name {login}_logos_token".format(release_id=self.Context.logos_release_id, login=self.Parameters.login)]
            delete_error_msg = RunProcess(cmd, self.GetEnv(), log_prefix="delete_logos_graph", exception_if_nonzero_code=False)
            if delete_error_msg:
                error_msg += "FINAL TABLE EXISTS AND ONLY DELETION OF REACTOR-LOGOS-GRAPH FAILED: {}".format(delete_error_msg)

        if error_msg:
            eh.check_failed(error_msg)
