# -*- coding: utf-8 -*-
import os
import json

import vh
from datacloud.ml_utils.vh_wrapper.helpers.cubes import run_grid_search, process_gs_results
from datacloud.key_manager.key_helpers import get_key


class GraphBuilder(object):
    """Hold the settings used to launch vh graphs and wrap ``vh.run``.

    The Nirvana-related settings are forwarded to ``vh.run`` /
    ``vh.run_async`` by ``_make_run_params``. ``st_user_agent``, ``st_token``
    and ``secrets_file`` are only stored on the instance; they are not part
    of the default run arguments.

    When ``oauth_token`` is None, it is resolved with
    ``get_key(secrets_file, NIRVANA_TOKEN_NAME, default=<$NIRVANA_TOKEN>)``.
    NOTE(review): the env value is presumably used when the secrets file
    lacks the key -- confirm against ``get_key``.
    """

    DEFAULT_QUOTA = 'datacloud'
    DEFAULT_SANDBOX_TOKEN_SECRET = None
    DEFAULT_SANDBOX_OWNER = None
    DEFAULT_A_REVISION = 5666024  # default ``arcadia_revision`` passed to vh.run
    DEFAULT_YT_TOKEN_SECRET = 'robot_xprod_yt_token'
    DEFAULT_ST_USERAGENT = 'robot-xprod'
    # NOTE(review): looks like the name of a secret rather than a raw token -- confirm.
    DEFAULT_ST_TOKEN = 'robot_xprod_st_token'
    DEFAULT_SECRETS_FILE = 'pipeline_secrets'
    # Key looked up in the secrets file and name of the fallback env variable.
    NIRVANA_TOKEN_NAME = 'NIRVANA_TOKEN'

    def __init__(self, oauth_token=None, quota=DEFAULT_QUOTA, project=None, workflow_guid=None,
                 label=None, sandbox_oauth_token_secret=DEFAULT_SANDBOX_TOKEN_SECRET,
                 sandbox_owner=DEFAULT_SANDBOX_OWNER, arcadia_revision=DEFAULT_A_REVISION,
                 yt_token_secret=DEFAULT_YT_TOKEN_SECRET, st_user_agent=DEFAULT_ST_USERAGENT,
                 st_token=DEFAULT_ST_TOKEN, secrets_file=DEFAULT_SECRETS_FILE):
        # Values that _make_run_params forwards to vh.run / vh.run_async.
        self.oauth_token = oauth_token
        self.quota = quota
        self.project = project
        self.workflow_guid = workflow_guid
        self.label = label
        self.sandbox_oauth_token_secret = sandbox_oauth_token_secret
        self.sandbox_owner = sandbox_owner
        self.arcadia_revision = arcadia_revision
        self.yt_token_secret = yt_token_secret
        # Values kept on the instance only (not in the default run arguments).
        self.st_user_agent = st_user_agent
        self.st_token = st_token
        self.secrets_file = secrets_file

        # An explicitly supplied token wins; nothing more to resolve.
        if self.oauth_token is not None:
            return

        env_token = os.getenv(self.NIRVANA_TOKEN_NAME)
        self.oauth_token = get_key(self.secrets_file, self.NIRVANA_TOKEN_NAME, default=env_token)

    def _make_run_params(self, **kwargs):
        """Return the keyword arguments for ``vh.run`` / ``vh.run_async``.

        Defaults are taken from the instance settings, plus ``start=True`` and
        a freshly created ``vh.NirvanaBackend()``. Any entry in ``kwargs``
        replaces the default with the same name.
        """
        defaults = {
            'oauth_token': self.oauth_token,
            'quota': self.quota,
            'project': self.project,
            'workflow_guid': self.workflow_guid,
            'label': self.label,
            'sandbox_oauth_token_secret': self.sandbox_oauth_token_secret,
            'sandbox_owner': self.sandbox_owner,
            'arcadia_revision': self.arcadia_revision,
            'yt_token_secret': self.yt_token_secret,
            'start': True,
            'backend': vh.NirvanaBackend(),
        }
        return dict(defaults, **kwargs)

    def _run(self, **kwargs):
        """Call ``vh.run`` with the merged run parameters and return its result."""
        run_params = self._make_run_params(**kwargs)
        return vh.run(**run_params)

    def _run_async(self, **kwargs):
        """Call ``vh.run_async`` with the merged run parameters and return its result."""
        run_params = self._make_run_params(**kwargs)
        return vh.run_async(**run_params)


class InputPipelineGraphBuilder(GraphBuilder):
    """Build and launch the input-pipeline grid-search graph.

    Only ``workflow_guid`` and ``oauth_token`` are configurable here; every
    other setting keeps the ``GraphBuilder`` default.
    """

    def __init__(self, workflow_guid, oauth_token=None):
        super(InputPipelineGraphBuilder, self).__init__(
            workflow_guid=workflow_guid,
            oauth_token=oauth_token
        )

    @staticmethod
    def _make_grid_search_params(target_name, yt_folder, table_path, ticket_name):
        """Return the dict that is JSON-serialised into one grid-search op's ``params``."""
        return {
            'target_name': target_name,
            'stream_to_yt': True,
            'yt_folder': yt_folder,
            'table_path': table_path,
            'ticket_name': ticket_name
        }

    def run_input_pipeline_graph(self, table_path, target_names, yt_folder, ticket_name, ST_Logger_cls,
                                 features_tag='', comment_id='', write_st_message=True):
        """Launch one grid-search op per target, optionally followed by a results op.

        :param table_path: table path placed into each op's JSON ``params``.
        :param target_names: iterable of target names; one ``run_grid_search``
            op is created per element. Note that a bare string would be
            iterated character by character.
        :param yt_folder: folder placed into each op's JSON ``params``.
        :param ticket_name: forwarded to every grid-search op and to the
            ``process_gs_results`` op.
        :param ST_Logger_cls: bound into ``process_gs_results`` via ``partial``.
        :param features_tag: forwarded to every op.
        :param comment_id: forwarded to the ``process_gs_results`` op.
        :param write_st_message: when False, no ``process_gs_results`` op is
            added to the graph.
        :returns: the value returned by ``vh.run`` (via ``self._run()``).
            Earlier versions discarded it and returned None, so existing
            callers are unaffected.
        """
        with vh.Graph():
            gs_ops = [
                run_grid_search(
                    params=json.dumps(self._make_grid_search_params(
                        target_name, yt_folder, table_path, ticket_name)),
                    ticket_name=ticket_name,
                    target_name=target_name,
                    features_tag=features_tag,
                )
                for target_name in target_names
            ]

            if write_st_message:
                # Feed every grid-search op into a single process_gs_results op.
                process_gs_results.partial(ST_Logger_cls=ST_Logger_cls)(
                    gs_results=gs_ops, useragent=self.st_user_agent,
                    st_token=self.st_token, ticket_name=ticket_name,
                    comment_id=comment_id, features_tag=features_tag
                )

            # Launched while the graph context is still open, as before.
            # NOTE(review): presumably vh.run picks up the active graph -- confirm.
            # The run handle is now returned instead of being dropped.
            return self._run()
