from sandbox import common
from sandbox import sdk2

from sandbox.projects.dj.unity.resources import DjUnityPackage
from sandbox.projects.dj.unity.utils import nirvana
from sandbox.projects.release_machine.helpers.startrek_helper import STHelper
from sandbox.sandboxsdk import process

import sandbox.projects.release_machine.core.const as rm_const
import sandbox.projects.release_machine.components.all as rm_comp

import json
import logging
import os
import requests
import tarfile

# Lifecycle of ProcessState.state: DRAFT -> EXECUTING -> FINISHED, or BROKEN
# when the Nirvana instance completes without success.
DRAFT, EXECUTING, FINISHED, BROKEN = "draft", "executing", "finished", "broken"

# Keys of a process entry in AcceptanceState.processes; also stored as ProcessState.env.
PRODUCTION, RELEASE = "production", "release"

class ProcessState(object):
    """JSON-serializable state of one acceptance process run in one environment.

    Holds the Nirvana template workflow/instance the run is cloned from, the
    cloned acceptance instance id (``None`` until started), the blocks to
    remove, the global params to override, the description of the result
    stats, the names of processes it still waits for (``deps``), its
    lifecycle ``state`` and its environment ``env``.
    """

    def __init__(
            self,
            name,
            template_workflow_id,
            template_instance_id,
            acceptance_instance_id,
            removed_block_ids,
            global_params,
            result_stats,
            deps,
            state,
            env
    ):
        self.name = name
        self.template_workflow_id = template_workflow_id
        self.template_instance_id = template_instance_id
        self.acceptance_instance_id = acceptance_instance_id
        self.removed_block_ids = removed_block_ids
        self.global_params = global_params
        self.result_stats = result_stats
        # Names of processes that must finish before this one may start;
        # entries are removed in place as dependencies complete.
        self.deps = deps
        self.state = state
        self.env = env

    def to_json(self):
        """Serialize into a plain dict accepted by ``from_json``.

        ``deps`` is copied so that later in-place dependency removal does not
        alter the returned dict.
        """
        return {
            "name": self.name,
            "template_workflow_id": self.template_workflow_id,
            "template_instance_id": self.template_instance_id,
            "acceptance_instance_id": self.acceptance_instance_id,
            "removed_block_ids": self.removed_block_ids,
            "global_params": self.global_params,
            "result_stats": self.result_stats,
            "deps": list(self.deps),
            "state": self.state,
            "env": self.env
        }

    @classmethod
    def from_json(cls, json):
        """Restore a state previously produced by ``to_json``.

        ``deps`` is copied so the caller's dict is not mutated when
        dependencies are later removed from this state.
        """
        return cls(
            json["name"],
            json["template_workflow_id"],
            json["template_instance_id"],
            json["acceptance_instance_id"],
            json["removed_block_ids"],
            json["global_params"],
            json["result_stats"],
            list(json["deps"]),
            json["state"],
            json["env"]
        )

    @classmethod
    def from_process_json(cls, json, env):
        """Create a fresh DRAFT state from a process description.

        Reads ``json["name"]``, ``json["nirvana"]`` (workflow/instance ids) and
        ``json["acceptance"]`` (removed blocks, global params, result stats and
        the optional ``depends_on`` list).

        :param env: environment key (RELEASE or PRODUCTION) stored on the state.
        """
        return cls(
            json["name"],
            json["nirvana"]["workflow_id"],
            json["nirvana"]["instance_id"],
            None,
            json["acceptance"]["removed_block_ids"],
            json["acceptance"]["global_params"],
            json["acceptance"]["result_stats"],
            # Copy: states built from the same description must not share
            # (and jointly mutate) one dependency list.
            list(json["acceptance"].get("depends_on", [])),
            DRAFT,
            env
        )


class AcceptanceState(object):
    """Orchestrate the acceptance Nirvana processes of one release.

    ``self.processes`` is a list of dicts, each mapping ``RELEASE`` (always
    present) and optionally ``PRODUCTION`` to a ``ProcessState``. The object
    clones and starts Nirvana instances whose dependencies are satisfied,
    polls running instances, and comments the resulting stats (and
    release-vs-production diffs) to the release Startrek ticket. Its state
    round-trips through ``to_json`` / ``from_json``.
    """

    # Stats type (``result_stats[...]["type"]``) -> diff tool subcommand.
    _DIFF_MODES = {
        'actions': 'action-stat-diff',
        'profiles': 'profile-stat-diff',
    }

    # Non-empty diffs shorter than this many characters are inlined into the
    # ticket comment; longer ones are only linked.
    _MAX_INLINE_DIFF_LENGTH = 1024

    def create_nirvana_client(self):
        """Return a Nirvana client built from the ``self.nirvana_token_vault`` secret."""
        logging.info('Creating nirvana client')
        nirvana_token = sdk2.Vault.data(self.nirvana_token_vault)
        return nirvana.NirvanaClient(nirvana_token)

    def __init__(
        self,
        resource_ids,
        nirvana_token_vault,
        global_params,
        component_name,
        release_number,
        base_dir,
        yt_server,
        yt_token_vault,
        processes=None
    ):
        """Create the state together with its Nirvana and Startrek clients.

        :param resource_ids: package resource ids; ``get_diff_tool`` indexes it by ``RELEASE``.
        :param nirvana_token_vault: key passed to ``sdk2.Vault.data`` for the Nirvana token.
        :param global_params: ``{env: {placeholder: replacement}}`` used by ``substitute_params``.
        :param component_name: release machine component; empty/None disables ticket comments.
        :param release_number: passed to ``STHelper.comment`` when commenting.
        :param base_dir: YT directory where diff files are written.
        :param yt_server: YT proxy used to read stats and write diffs.
        :param yt_token_vault: key passed to ``sdk2.Vault.data`` for the YT token.
        :param processes: list of ``{env: ProcessState}`` dicts; a new empty list when omitted.
        """
        self.package_resource_ids = resource_ids
        self.nirvana_token_vault = nirvana_token_vault
        self.nirvana_client = self.create_nirvana_client()
        self.global_params = global_params
        self.component_name = component_name
        self.release_number = release_number
        self.base_dir = base_dir
        self.yt_server = yt_server
        self.yt_token_vault = yt_token_vault
        # A fresh list per instance: a ``[]`` default would be shared by every
        # instance created without ``processes`` and ``add()`` appends to it.
        self.processes = [] if processes is None else processes
        self.st_helper = STHelper(sdk2.Vault.data(rm_const.COMMON_TOKEN_OWNER, rm_const.COMMON_TOKEN_NAME))

    def to_json(self):
        """Serialize into a plain dict accepted by ``from_json`` (clients are not stored)."""
        result = {
            "package_resources": self.package_resource_ids,
            "nirvana_token_vault": self.nirvana_token_vault,
            "component_name": self.component_name,
            "release_number": self.release_number,
            "global_params": self.global_params,
            "base_dir": self.base_dir,
            "yt_server": self.yt_server,
            "yt_token_vault": self.yt_token_vault
        }

        serialized_processes = []
        for process_states in self.processes:
            serialized = {RELEASE: process_states[RELEASE].to_json()}
            if PRODUCTION in process_states:
                serialized[PRODUCTION] = process_states[PRODUCTION].to_json()
            serialized_processes.append(serialized)
        result["processes"] = serialized_processes

        return result

    @classmethod
    def from_json(cls, json):
        """Restore an instance from ``to_json`` output, recreating the clients."""
        processes = []
        for process_states in json["processes"]:
            states = {RELEASE: ProcessState.from_json(process_states[RELEASE])}
            if PRODUCTION in process_states:
                states[PRODUCTION] = ProcessState.from_json(process_states[PRODUCTION])
            processes.append(states)

        return cls(
            json["package_resources"],
            json["nirvana_token_vault"],
            json["global_params"],
            json["component_name"],
            json["release_number"],
            json["base_dir"],
            json["yt_server"],
            json["yt_token_vault"],
            processes
        )

    def add(self, release_process_json, production_process_json):
        """Register a process: always a RELEASE run, plus a PRODUCTION run if given."""
        states = {RELEASE: ProcessState.from_process_json(release_process_json, RELEASE)}
        if production_process_json is not None:
            states[PRODUCTION] = ProcessState.from_process_json(production_process_json, PRODUCTION)
        self.processes.append(states)

    def substitute_params(self, value, env):
        """Replace every placeholder of ``global_params[env]`` occurring in ``value``.

        Non-string values are returned unchanged. Replacements are applied in
        dict iteration order, so overlapping placeholders may interact.

        :raises KeyError: if ``env`` is missing from ``global_params``.
        """
        replacements = self.global_params[env]
        if not isinstance(value, basestring):
            return value
        for placeholder, replacement in replacements.iteritems():
            value = value.replace(placeholder, replacement)
        return value

    def prepare_graph_global_params(self, state):
        """Build the Nirvana global-parameter payload for ``state`` with placeholders substituted."""
        return [
            {
                "parameter": param["key"],
                "value": self.substitute_params(param["value"], state.env)
            }
            for param in state.global_params
        ]

    @staticmethod
    def create_blocks_description(block_ids):
        """Wrap block codes into the ``[{"code": id}, ...]`` form the Nirvana client calls take."""
        return [{"code": block_id} for block_id in block_ids]

    def start_process(self, state):
        """Clone the template instance, configure it, start it and mark ``state`` EXECUTING."""
        logging.info('Creating new workflow instance')
        state.acceptance_instance_id = self.nirvana_client.clone_workflow_instance(
            state.template_workflow_id,
            state.template_instance_id
        )
        self.nirvana_client.add_comment_to_workflow_instance(
            state.acceptance_instance_id,
            'Acceptance graph: {}'.format(state.env)
        )
        logging.info('New instance: {}/{}'.format(state.template_workflow_id, state.acceptance_instance_id))
        new_global_params = self.prepare_graph_global_params(state)
        logging.info('Setting new params: {}'.format(new_global_params))
        self.nirvana_client.set_workflow_parameters(
            state.template_workflow_id,
            state.acceptance_instance_id,
            new_global_params
        )
        blocks_to_remove = self.create_blocks_description(state.removed_block_ids)
        logging.info('Removing blocks {}'.format(blocks_to_remove))
        self.nirvana_client.remove_blocks(
            state.template_workflow_id,
            state.acceptance_instance_id,
            blocks_to_remove
        )
        self.nirvana_client.approve_workflow_instance(
            state.template_workflow_id,
            state.acceptance_instance_id
        )
        logging.info('Starting instance')
        self.nirvana_client.start_workflow_instance(
            state.template_workflow_id,
            state.acceptance_instance_id
        )
        state.state = EXECUTING

    def start_processes(self):
        """Start every DRAFT process that has no outstanding dependencies."""
        logging.info('Trying to start processes')
        for process_states in self.processes:
            for state in process_states.values():
                if state.state == DRAFT and not state.deps:
                    self.start_process(state)

    def check_process_finished(self, state):
        """Poll Nirvana for ``state``'s instance.

        :returns: True (and marks FINISHED) if it completed successfully,
                  False if it has not completed yet.
        :raises common.errors.TaskFailure: if it completed without success
                  (``state`` is marked BROKEN first).
        """
        instance_state = self.nirvana_client.get_workflow_instance_state(
            state.template_workflow_id,
            state.acceptance_instance_id
        )
        if instance_state['status'] != 'completed':
            return False
        if instance_state['result'] != 'success':
            state.state = BROKEN
            raise common.errors.TaskFailure(
                'Instance {}/{} failed. Status: {}'.format(
                    state.template_workflow_id,
                    state.acceptance_instance_id,
                    instance_state['result']
                )
            )
        state.state = FINISHED
        logging.info(
            'Instance {}/{} finished successfully.'.format(
                state.template_workflow_id,
                state.acceptance_instance_id
            )
        )
        return True

    def download_result(self, state):
        """Download the stats outputs described by ``state.result_stats``.

        ``result_stats`` is either one description dict or a list of them;
        each is expected to carry ``name``, ``block_id``, ``output_id`` and
        ``type``.

        :returns: ``{name: (parsed_json, type)}``, or None if any description
                  lacks a ``name`` (treated as "no stats configured").
        """
        stats_list = state.result_stats
        if not isinstance(stats_list, list):
            stats_list = [stats_list]

        # Checked up front so nothing is downloaded when the answer is None anyway.
        if any("name" not in stats for stats in stats_list):
            return None

        result = {}
        for stats in stats_list:
            block_results = self.nirvana_client.get_blocks_result(
                state.template_workflow_id,
                state.acceptance_instance_id,
                self.create_blocks_description([stats["block_id"]])
            )
            for output in block_results[0]["results"]:
                if output["endpoint"] != stats["output_id"]:
                    continue
                # NOTE(review): TLS certificate verification is disabled and no
                # timeout is set; confirm this is intended for the storage host.
                response = requests.get(output["directStoragePath"], verify=False)
                response.raise_for_status()
                result[stats["name"]] = (json.loads(response.text), stats["type"])

        return result

    def make_comment_in_st_ticket(self, message):
        """Comment ``message`` on the release ticket; no-op when no component is set."""
        if not self.component_name:
            return

        c_info = rm_comp.COMPONENTS[self.component_name]()
        self.st_helper.comment(self.release_number, message, c_info)

    @staticmethod
    def create_process_stats_message(stats_name, env, process_name, cluster, path):
        """Format a Startrek-markup line linking to a stats table in the YT UI."""
        return '((http://yt.yandex-team.ru/{}/navigation?path={} Stats)) {} for {} process {}\n'.format(
            cluster,
            path,
            stats_name,
            env,
            process_name
        )

    def notify_new_process(self, release_state):
        """Report stats of a process that has no production counterpart."""
        result_stats_files = self.download_result(release_state)
        if result_stats_files is None:
            message = 'No stats for new process {}'.format(release_state.name)
            logging.info(message)
            self.make_comment_in_st_ticket(message)
            return

        for name, stats_file in result_stats_files.iteritems():
            message = self.create_process_stats_message(
                name,
                'new',
                release_state.name,
                stats_file[0]["cluster"],
                stats_file[0]["path"]
            )
            logging.info(message)
            self.make_comment_in_st_ticket(message)

    def get_diff_tool(self):
        """Unpack the RELEASE package into the working directory and return the diff tool path."""
        package_resource_data = sdk2.ResourceData(sdk2.Resource[self.package_resource_ids[RELEASE]])
        # NOTE: extractall trusts member paths; the archive is our own package resource.
        with tarfile.open(str(package_resource_data.path), "r:gz") as tar:
            tar.extractall(path=".")
        return "./tools/yt/yt_rule_applier"

    def run_diff_tool(self, diff_tool_path, release_stats_file_path, production_stats_file, mode):
        """Run the diff tool on two local stats files and return the output file path.

        :param mode: stats type; selects the tool subcommand via ``_DIFF_MODES``.
        """
        diff_mode = self._DIFF_MODES.get(mode, '')
        if not diff_mode:
            logging.warning('Unknown stats type %r: running diff tool without a diff mode', mode)

        out_file = "./stats_diff"

        cmd = [
            diff_tool_path,
            diff_mode,
            '--prev', production_stats_file,
            '--new', release_stats_file_path,
            '-d', str(1),
            '-o', out_file
        ]

        process.run_process(cmd, work_dir=".", log_prefix='diff_tool')

        # Best effort: pretty-print the JSON output so the ticket diff is
        # readable; on failure the raw output is kept as-is.
        try:
            with open(out_file, 'r') as f:
                diff = json.load(f)
            with open(out_file, 'w') as f:
                json.dump(diff, f, indent=2)
        except (IOError, OSError, ValueError):
            logging.exception("Can not load json")

        return out_file

    def build_diff_file(self, release_stats_file, production_stats_file, mode, process_name, stats_name):
        """Diff two YT stats files and upload the result to ``base_dir``.

        :returns: ``(yt_path_of_diff, diff_text)``.
        """
        import yt.wrapper as yt
        yt.config.set_proxy(self.yt_server)
        yt.config['token'] = sdk2.Vault.data(self.yt_token_vault)

        release_stats_file_path = "./release.stats"
        production_stats_file_path = "./production.stats"
        with open(release_stats_file_path, 'wb') as f:
            f.write(yt.read_file(release_stats_file).read())
        with open(production_stats_file_path, 'wb') as f:
            f.write(yt.read_file(production_stats_file).read())

        diff_tool_path = self.get_diff_tool()

        out_file = self.run_diff_tool(diff_tool_path, release_stats_file_path, production_stats_file_path, mode)
        yt_file = "{}/{}_{}.diff".format(self.base_dir, process_name, stats_name)
        with open(out_file, "rb") as f:
            yt.write_file(yt_file, f)
        with open(out_file, "r") as f:
            content = f.read()
        return yt_file, content

    def notify_processes_diff(self, release_state, production_state):
        """Comment release and production stats plus their diffs in one ticket message."""
        release_stats_file = self.download_result(release_state)
        production_stats_file = self.download_result(production_state)

        message = ""

        if release_stats_file is None:
            message += 'No stats for release process {}\n'.format(release_state.name)
        else:
            for name, stats_file in release_stats_file.iteritems():
                if message:
                    message += '\n'
                message += self.create_process_stats_message(
                    name,
                    'release',
                    release_state.name,
                    stats_file[0]["cluster"],
                    stats_file[0]["path"]
                )
                if production_stats_file is None or name not in production_stats_file:
                    continue
                production_stats = production_stats_file[name]
                message += self.create_process_stats_message(
                    name,
                    'production',
                    production_state.name,
                    production_stats[0]["cluster"],
                    production_stats[0]["path"]
                )
                diff_file, diff = self.build_diff_file(
                    stats_file[0]["path"],
                    production_stats[0]["path"],
                    stats_file[1],
                    release_state.name,
                    name
                )
                message += '((http://yt.yandex-team.ru/{}/navigation?path={} {})) {} for process {}.\n'.format(
                    stats_file[0]["cluster"],
                    diff_file,
                    'Diff' if diff else 'No diff',
                    name,
                    release_state.name
                )
                if 0 < len(diff) < self._MAX_INLINE_DIFF_LENGTH:
                    message += "<{Show diff\n" + diff + "}>\n"

        if production_stats_file is None:
            message += 'No stats for production process {}\n'.format(production_state.name)
        else:
            # Production-only stats have no release counterpart to diff against.
            for name, stats_file in production_stats_file.iteritems():
                if release_stats_file is None or name not in release_stats_file:
                    message += self.create_process_stats_message(
                        name,
                        'production',
                        production_state.name,
                        stats_file[0]["cluster"],
                        stats_file[0]["path"]
                    )

        logging.info(message)
        self.make_comment_in_st_ticket(message)

    def update_dependencies(self, state):
        """Remove ``state.name`` from the deps of every process in the same environment."""
        name = state.name
        env = state.env

        for process_states in self.processes:
            dependent = process_states.get(env)
            if dependent is not None and name in dependent.deps:
                logging.info('Remove dependency {} from process {} {}'.format(name, dependent.name, env))
                dependent.deps.remove(name)

    def on_process_finished(self, process_states, env):
        """Unblock dependents, start ready processes and report once the pair is done."""
        self.update_dependencies(process_states[env])
        self.start_processes()

        if process_states[RELEASE].state != FINISHED:
            return

        if PRODUCTION not in process_states:
            self.notify_new_process(process_states[RELEASE])
            return

        # Report only when both runs are done; whichever finishes last triggers it.
        if process_states[PRODUCTION].state == FINISHED:
            self.notify_processes_diff(process_states[RELEASE], process_states[PRODUCTION])

    def has_running_process(self):
        """Return True if any process is currently EXECUTING."""
        logging.info('Checking for running processes...')
        return any(
            state.state == EXECUTING
            for process_states in self.processes
            for state in process_states.values()
        )

    def check_processes(self):
        """Poll executing processes, handle completions; return True while any still runs."""
        logging.info('Checking for completed processes...')
        for process_states in self.processes:
            for state in process_states.values():
                if state.state == EXECUTING and self.check_process_finished(state):
                    self.on_process_finished(process_states, state.env)
        return self.has_running_process()