import logging
import json
from sandbox import sdk2
from sandbox.common.errors import TaskFailure, TaskError
from sandbox.common.types import task as task_type
from sandbox.projects.yabs.partner_share.tasks.base_startrek_push_status import TacmanBaseStartrekPushStatus
import sandbox.common.types.task as ctt
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.projects.yabs.partner_share.lib.st_helper import get_sandbox_link
from sandbox.projects.yabs.partner_share.lib.config.config import (
    get_config,
    render_spec,
    MAX_PARTNER_SHARE_FIELD,
    APPROVAL_TYPE_FIELD,
    APPROVAL_STAGES_FIELD
)

# YT proxy given to the oneshot child tasks in the EXECUTE (direct) and
# REVERT (revert) stages, i.e. when they run outside test mode.
ONESHOT_YT_CLUSTER = 'markov'
# YT proxy given to the oneshot child tasks in every other stage (test mode).
TEST_ONESHOT_YT_CLUSTER = 'hahn'

# Logger for this module.
logger = logging.getLogger(name=__name__)


class NoBackupDir(Exception):
    """Raised when the direct oneshot returns an unusable backup dir in test mode."""


def default_check_tasks_callback(self, task, failure_message=''):
    """Raise unless *task* finished successfully.

    :param self: unused here; presumably kept so the function matches a
        callback signature - TODO confirm against callers.
    :param task: finished task object; only its ``status`` is read.
    :param failure_message: message attached to the raised exception.
    :raises TaskError: the status belongs to the BREAK status group.
    :raises TaskFailure: the status is outside the SUCCEED status group.
    """
    status_groups = task_type.Status.Group
    if task.status in status_groups.BREAK:
        raise TaskError(failure_message)
    elif task.status not in status_groups.SUCCEED:
        raise TaskFailure(failure_message)


def on_changes_finish(self, task):
    """Check the task that generates the changes table.

    Builds a failure message that names the task id and status, then delegates
    to ``default_check_tasks_callback``. That call raises when the task did not
    succeed.
    """
    template = 'Failed to generate changes table. Task #{task.id} finished with status {task.status}'
    message = template.format(task=task)
    return default_check_tasks_callback(self, task, failure_message=message)


class TacmanApplyChanges(TacmanBaseStartrekPushStatus):
    """Run one configured stage (list of operations) for a Startrek issue."""

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 512
        environments = [
            PipEnvironment('yql', version='1.2.91', use_wheel=True),
            PipEnvironment('yandex-yt', version='0.9.17', use_wheel=True),
            PipEnvironment('startrek_client', version='2.5.0', use_wheel=True)
        ]

        class Caches(sdk2.Requirements.Caches):
            pass  # Do not use any shared caches (required for running on multislot agent)

    class Parameters(TacmanBaseStartrekPushStatus.Parameters):
        # 5 * 3600 seconds, i.e. 5 hours.
        kill_timeout = 5 * 3600
        ignore_partner_share_above = sdk2.parameters.Integer(
            'Ignore records with PartnerShare above this value (0 to 1 000 000)',
            default=1000000
        )
        stage_name = sdk2.parameters.String(label='Stage to run: FILTER or TEST or...')
        direct_issue = sdk2.parameters.String(
            label='Startrek issue of a direct oneshot',
            default=''
        )
        user = sdk2.parameters.String(label='Oneshot author')
        filter_json_startrek_file = sdk2.parameters.String(
            "Filter json file name in startrek",
            default="filter.json",
            required=True
        )
        yt_cluster = sdk2.parameters.String(label='YT cluster for requests', default='hahn')
        chyt_cluster = sdk2.parameters.String(label='CHYT cluster', default='chyt.hahn/lelby_ckique')

    # ------------------------------------------------------------------
    # Reporting helpers
    # ------------------------------------------------------------------

    def _yt_navigation_link(self, path):
        """Return an HTML link to *path* in the YT web UI of ``Parameters.yt_cluster``."""
        return '<a href=https://yt.yandex-team.ru/{cluster}/navigation?path={path}>{path}</a>'.format(
            cluster=self.Parameters.yt_cluster,
            path=path,
        )

    def show_yql_operation(self, title, yql_url):
        """Add an info line that links to a YQL operation."""
        self.set_info('YQL operation: <a href={url}>{title}</a>'.format(
            url=yql_url,
            title=title
        ), do_escape=False)

    def show_yt_path(self, title, path):
        """Add an info line that links to *path* in the YT web UI."""
        self.set_info('{title}: {link}'.format(
            title=title,
            link=self._yt_navigation_link(path),
        ), do_escape=False)

    def set_warn(self, title, path, warning):
        """Append a ``'<warning>: <title>: <path>'`` entry to ``Context.warnings``."""
        self.Context.warnings.append(
            '{warning}: {title}: {path}'.format(
                warning=warning,
                title=title,
                path=path,
            )
        )

    def check_table(self, info):
        """Report an output table's row count and warn about unexpected states.

        :param info: output description from an operation spec. Reads ``name``
            and ``path``, plus the optional ``need_empty`` / ``need_nonempty``
            flags.

        If the ``@chunk_row_count`` attribute cannot be read, the table is
        reported as missing.
        """
        link = self._yt_navigation_link(info['path'])
        try:
            rows = self.yt.get(info['path'] + '/@chunk_row_count')
        except Exception:
            # Any failure to read the attribute is reported as a missing table.
            # The reason goes to the log instead of being silently swallowed.
            logger.warning('Failed to get row count of %s', info['path'], exc_info=True)
            self.set_info(
                '{title}: {link} <span style="background-color:lightpink"><b>(Table is missing)</b></span>'.format(
                    title=info['name'],
                    link=link,
                ),
                do_escape=False,
            )
            self.set_warn(info['name'], info['path'], 'Table is missing')
            return

        color = "lightgreen"
        if info.get('need_empty') and rows:
            self.set_warn(info['name'], info['path'], 'Table should be empty')
            color = "lightpink"
        if info.get('need_nonempty') and not rows:
            self.set_warn(info['name'], info['path'], 'Table should not be empty')
            color = "lightpink"
        self.set_info(
            '{title}: {link} <span style="background-color:{color}"><b>({rows} rows)</b></span>'.format(
                title=info['name'],
                link=link,
                rows=rows,
                color=color,
            ),
            do_escape=False,
        )

    def show_yt_path_rowcount(self, title, path):
        """Add an info line that links to *path*, with its row count when readable."""
        try:
            rows = self.yt.get(path + '/@chunk_row_count')
        except Exception:
            # Best effort: fall back to a plain link without the row count.
            logger.warning('Failed to get row count of %s', path, exc_info=True)
            self.show_yt_path(title, path)
            return

        self.set_info('{title}: {link} ({rows} rows)'.format(
            title=title,
            link=self._yt_navigation_link(path),
            rows=rows,
        ), do_escape=False)

    def show_started_task(self, title, task_id):
        """Add an info line that links to a just-started child task."""
        self.set_info('Started <a href={url}>"{title}" task</a>'.format(
            title=title,
            url=get_sandbox_link(task_id),
        ), do_escape=False)

    # ------------------------------------------------------------------
    # Startrek ticket fields
    # ------------------------------------------------------------------

    def _set_ticket_field(self, field_name, value):
        """Log, then set the local Startrek field *field_name* of ``Parameters.issue``."""
        logger.debug('Setting %s field of ticket %s to %s', field_name, self.Parameters.issue, value)
        self.st_helper.change_local_field(
            ticket=self.Parameters.issue,
            local_field_name=field_name,
            value=value,
        )

    def set_max_partner_share_in_ticket(self, spec):
        """Store the max PlannedPartnerShare and the derived approval data in the ticket.

        :param spec: operation spec. ``spec['inputs']['changes_target']`` is the
            table whose ``PlannedPartnerShare`` column is scanned.

        The maximum starts at 0. Approval type is 2 above
        ``MAX_PARTNER_SHARE_WITH_APPROVE_1``, otherwise 1. It is ``'DEV'`` for
        TESTTACCHANGES issues. The matching ``config['approvals']`` entry is
        written as JSON.
        """
        from sandbox.projects.yabs.partner_share.lib.fast_changes.fast_changes import MAX_PARTNER_SHARE_WITH_APPROVE_1

        max_partner_share = 0
        for row in self.yt.read_table(spec['inputs']['changes_target']):
            if row['PlannedPartnerShare'] > max_partner_share:
                max_partner_share = row['PlannedPartnerShare']

        self._set_ticket_field(MAX_PARTNER_SHARE_FIELD, max_partner_share)

        if 'TESTTACCHANGES' in self.Parameters.issue:
            approval_type = 'DEV'
        elif max_partner_share > MAX_PARTNER_SHARE_WITH_APPROVE_1:
            approval_type = 2
        else:
            approval_type = 1

        self._set_ticket_field(APPROVAL_TYPE_FIELD, approval_type)
        self._set_ticket_field(
            APPROVAL_STAGES_FIELD,
            json.dumps(self.config['approvals'][str(approval_type)]),
        )

    # ------------------------------------------------------------------
    # Child tasks
    # ------------------------------------------------------------------

    def wait_task(self, task_id):
        """Suspend this task until *task_id* reaches a FINISH or BREAK status.

        Does nothing when *task_id* is falsy.
        """
        if task_id:
            raise sdk2.WaitTask(task_id, ctt.Status.Group.FINISH | ctt.Status.Group.BREAK, wait_all=True)

    def check_task(self, task_id):
        """Return the task's REST representation.

        :raises TaskFailure: the task status is not ``'SUCCESS'``.
        """
        task = self.server.task[task_id].read()
        if task['status'] != 'SUCCESS':
            raise TaskFailure('Task {} failed: {}'.format(
                task_id,
                task['status']
            ))
        return task

    def _start_oneshot_task(self, description, owner, custom_fields):
        """Create an ``EXECUTE_YT_ONESHOT`` child task, start it and return its id.

        :param custom_fields: mapping of the child task's parameter names to values.
        """
        task = self.server.task({
            'type': 'EXECUTE_YT_ONESHOT',
            'description': description,
            'children': True,
            'owner': owner,
            'custom_fields': [{'name': k, 'value': v} for k, v in custom_fields.items()],
        })
        task_id = task['id']
        self.server.batch.tasks.start.update([task_id])
        return task_id

    def direct_oneshot(self, spec):
        """Run the direct oneshot child task and record the backup dir it reports.

        In the EXECUTE stage the oneshot runs outside test mode on
        ``ONESHOT_YT_CLUSTER``. In any other stage it runs in test mode on
        ``TEST_ONESHOT_YT_CLUSTER``, with backup tables.

        :raises NoBackupDir: outside EXECUTE, when the reported backup dir is
            shorter than 3 characters.
        """
        if 'TESTTACCHANGES' in self.Parameters.issue:
            self.protect_stage()

        is_execute = self.Parameters.stage_name == 'EXECUTE'
        with self.memoize_stage.direct_oneshot_task(commit_on_entrance=False):
            task_id = self._start_oneshot_task(
                description='Execute direct oneshot for issue "{}"'.format(self.Parameters.issue),
                owner='YABS-YT-ONESHOT-EXECUTOR',
                custom_fields=dict(
                    oneshot_path=self.config['constants']['DIRECT_ONESHOT_PATH'],
                    oneshot_args=json.dumps({
                        'user': self.Parameters.user,
                        'ticket': self.Parameters.issue,
                        'oneshot': self.Parameters.issue,
                        'items': [],
                        'changes': spec['inputs']['changes']
                    }),
                    run_in_test_mode=not is_execute,
                    print_tables_only=False,
                    backup_dir_ttl=3600000,
                    oneshot_binary_ttl=90,
                    startrek_ticket=self.Parameters.issue,
                    yt_proxy=ONESHOT_YT_CLUSTER if is_execute else TEST_ONESHOT_YT_CLUSTER,
                    reuse_oneshot_binary=True,
                    force_data_flush=not is_execute,
                    use_testing_startrek=self.Parameters.use_testing_startrek,
                    run_backup_tables=not is_execute,
                    allow_access_to_production_data=is_execute,
                ),
            )
            self.Context.direct_task_id = task_id
            self.show_started_task('Direct oneshot', task_id)
            self.wait_task(self.Context.direct_task_id)

        task = self.check_task(self.Context.direct_task_id)
        # A missing, None or empty backup_dir falls back to the '//' root, so
        # the length check below cannot crash on None.
        self.yt_root = task['output_parameters'].get('backup_dir') or '//'

        # '//' and shorter mean "no backup dir".
        has_backup_dir = len(self.yt_root) > 2
        if not has_backup_dir and not is_execute:
            raise NoBackupDir('Direct oneshot task returned wrong backup dir {} in test mode'.format(self.yt_root))

        with self.memoize_stage.direct_oneshot_results(commit_on_entrance=False):
            self.yt.create(
                'map_node',
                spec['outputs']['oneshot_finished']['path'],
                recursive=True,
                force=True,
            )
            if has_backup_dir:
                self.show_yt_path('Backup dir', self.yt_root)

    def revert_oneshot(self, spec):
        """Run the revert oneshot child task for ``self.direct_issue``.

        In the REVERT stage the oneshot runs outside test mode on
        ``ONESHOT_YT_CLUSTER``. In any other stage it runs in test mode on
        ``TEST_ONESHOT_YT_CLUSTER``. ``self.yt_root`` is passed as
        ``ypath_prefix``.
        """
        is_revert = self.Parameters.stage_name == 'REVERT'
        with self.memoize_stage.revert_oneshot_task(commit_on_entrance=False):
            task_id = self._start_oneshot_task(
                description='Execute revert oneshot for issue {} (direct issue {})'.format(
                    self.Parameters.issue,
                    self.direct_issue,
                ),
                owner=self.owner,
                custom_fields=dict(
                    oneshot_path=self.config['constants']['REVERT_ONESHOT_PATH'],
                    oneshot_args=json.dumps({
                        'user': self.Parameters.user,
                        'ticket': self.direct_issue,
                        'oneshot': self.direct_issue
                    }),
                    run_in_test_mode=not is_revert,
                    print_tables_only=False,
                    backup_dir_ttl=3600000,
                    oneshot_binary_ttl=90,
                    startrek_ticket=self.Parameters.issue,
                    yt_proxy=ONESHOT_YT_CLUSTER if is_revert else TEST_ONESHOT_YT_CLUSTER,
                    reuse_oneshot_binary=True,
                    force_data_flush=not is_revert,
                    use_testing_startrek=self.Parameters.use_testing_startrek,
                    run_backup_tables=False,
                    ypath_prefix=self.yt_root,
                    allow_access_to_production_data=is_revert,
                ),
            )
            self.Context.revert_task_id = task_id
            self.show_started_task('Revert oneshot', task_id)
            self.wait_task(self.Context.revert_task_id)

        self.check_task(self.Context.revert_task_id)

        with self.memoize_stage.revert_oneshot_results(commit_on_entrance=False):
            self.yt.create(
                'map_node',
                spec['outputs']['revert_finished']['path'],
                recursive=True,
                force=True,
            )

    def protect_stage(self):
        """Protect from executing a dangerous oneshot in testing.

        Returns quietly when ``self.filters`` holds exactly one filter with a
        ``PageID == '414314'`` condition. Otherwise an EXECUTE stage raises
        ``RuntimeError``; other stages only get the warning info line.
        """
        self.set_info('WARNING: Only single filters with PageID=414314 are allowed to EXECUTE in test')
        filters = self.filters['filters']
        if len(filters) == 1 and any(
            condition['field'] == 'PageID' and condition['value'] == '414314'
            for condition in filters[0]['conditions']
        ):
            return

        if self.Parameters.stage_name == 'EXECUTE':
            raise RuntimeError('Only single filters with PageID=414314 are allowed to EXECUTE in test')

    # ------------------------------------------------------------------
    # Stage execution
    # ------------------------------------------------------------------

    def init(self):
        """Set up clients, config, stage, Startrek helper and filters for this run."""
        from yt.wrapper import YtClient

        tokens = self.Parameters.tokens.data()
        self.yql_token = tokens['yql_token']
        self.yt = YtClient(proxy=self.Parameters.yt_cluster, token=self.yql_token)

        self.config = get_config()
        self.stage = self.config['stages'][self.Parameters.stage_name]
        self.yt_root = '//'

        self.init_startrek()

        # Without an explicit direct issue, the task's own issue is the direct one.
        self.direct_issue = self.Parameters.direct_issue or self.Parameters.issue
        self.filters = self.st_helper.get_attachment(
            self.direct_issue,
            self.Parameters.filter_json_startrek_file
        )

        with self.memoize_stage.init(commit_on_entrance=False):
            self.Context.warnings = []

    def memoize_operation(self, operation_name):
        """Run *operation_name* inside a memoize stage of the same name."""
        with self.memoize_stage[operation_name](commit_on_entrance=False):
            self.run_operation(operation_name)

    def run_operation(self, operation_name):
        """Render the spec of *operation_name*, run it and check its outputs.

        The spec's ``function`` key (default: the operation name) selects the
        implementation. A method of this task with that name is called first.
        Otherwise the shared ``run_imported_operation`` is used.
        """
        from sandbox.projects.yabs.partner_share.lib.create_tables.create_tables import create_output_tables
        from sandbox.projects.yabs.partner_share.lib.operations import run_imported_operation

        spec = render_spec(
            task=self,
            config=self.config,
            stage_name=self.Parameters.stage_name,
            operation_name=operation_name,

            yql_token=self.yql_token,
            yt_cluster=self.Parameters.yt_cluster,
            chyt_cluster=self.Parameters.chyt_cluster,

            issue=self.Parameters.issue,
            direct_issue=self.direct_issue,
            ignore_partner_share_above=self.Parameters.ignore_partner_share_above,
            filters=self.filters,

            yt_root=self.yt_root,
        )

        if spec.get('clickhouse'):
            create_output_tables(spec)

        function_name = spec.get('function', operation_name)

        logger.debug('Starting operation %s with function %s, spec: %s', operation_name, function_name, spec)

        if hasattr(self, function_name):
            getattr(self, function_name)(spec)
        else:
            run_imported_operation(function_name, spec)

        self.check_output_tables(spec)

    def check_output_tables(self, operation):
        """Run ``check_table`` on every visible output of *operation*.

        Outputs flagged ``hidden``, or listing the current stage in
        ``hidden_in_stages``, are skipped.
        """
        from six import itervalues

        if 'outputs' not in operation:
            return

        for info in itervalues(operation['outputs']):
            if info.get('hidden'):
                continue

            hidden_in_stages = info.get('hidden_in_stages')
            if hidden_in_stages and self.Parameters.stage_name in hidden_in_stages:
                continue

            self.check_table(info)

    def run_stage(self):
        """Run the current stage's operations in order.

        Operations whose config sets ``memoize_stage`` run through
        ``memoize_operation``.
        """
        for operation_name in self.stage['operations']:
            if self.config['operations'][operation_name].get('memoize_stage'):
                self.memoize_operation(operation_name)
            else:
                self.run_operation(operation_name)

    def on_execute(self):
        self.init()

        self.run_stage()
