from __future__ import unicode_literals
from __future__ import print_function

import logging
from time import sleep
from sandbox.projects.yabs.partner_share.lib.config.config import (
    get_daemons,
    APPROVAL_TYPE_FIELD,
    APPROVAL_STAGES_FIELD,
    MAX_PARTNER_SHARE_FIELD,
)
from yt.wrapper import YtClient, ypath_join
from sandbox.projects.yabs.partner_share.lib.ok_helper import OkHelper
from sandbox.projects.yabs.partner_share.lib.st_helper import (
    StartrekHelper,
    get_queue_from_issue,
    get_sandbox_link,
)
from sandbox.projects.yabs.partner_share.lib.sandbox_helper import SandboxHelper

# Module-level logger shared by every function and method in this file.
l = logging.getLogger(__name__)


def full_flow(config):
    """Return the stage names reachable from 'FILTER' via 'next' transitions.

    The walk starts at 'FILTER' and stops at the first stage whose
    ``next_states`` mapping has no 'next' key. That stage is the last
    element of the returned list.
    """
    stages = config['stages']
    flow = ['FILTER']
    current = flow[0]
    while True:
        transitions = stages[current]['next_states']
        if 'next' not in transitions:
            return flow
        current = transitions['next']
        flow.append(current)


class Integratest:
    """Drive a ticket through its configured stages and verify each one.

    For every stage in the flow, the test writes the stage name into the
    ticket's ``tacman_status`` field. It then waits for the status to move
    to the stage's ``process`` state (if configured) and ``success`` state,
    and checks the YT outputs and stage-specific side effects (approval
    fields, OK approval status, ticket followers).
    """

    # OK API statuses that are acceptable after each approval-related stage.
    _ALLOWED_OK_STATUSES = {
        'APPROVE': ('closed', 'suspended', 'in_progress'),
        'APPROVE_SUSPEND': ('closed', 'suspended'),
        'APPROVE_RESUME': ('closed', 'in_progress'),
    }

    # Sandbox task statuses meaning "still running"; a new test must not
    # start while any stage task is in one of them.
    _ACTIVE_TASK_STATUSES = ('EXECUTING', 'ASSIGNED', 'ENQUEUED', 'FINISHING', 'WAIT_TASK')

    def __init__(self, config, start_stage, end_stage, ticket, st_token, ok_token, yql_token, yt_cluster, sandbox_token):
        """Create API helpers and read the ticket data the test needs.

        Args:
            config: parsed configuration. Keys read by this class include
                ``stages``, ``queues``, ``constants``, ``approvals``,
                ``ribbons`` and ``operations``.
            start_stage: name of the first stage to run.
            end_stage: name of the last stage to run.
            ticket: Startrek issue key under test.
            st_token: Startrek API token.
            ok_token: OK (approvals) API token.
            yql_token: token passed to the YT client.
            yt_cluster: YT proxy name.
            sandbox_token: Sandbox API token.
        """
        self.config = config
        self.start_stage = start_stage
        self.end_stage = end_stage
        self.ticket = ticket
        # Ticket that field reads/writes go to; switched to the revert
        # ticket for REVERT* stages.
        self.current_ticket = ticket
        self.st_token = st_token
        self.ok_token = ok_token
        self.yql_token = yql_token
        self.yt_cluster = yt_cluster
        self.sandbox_token = sandbox_token

        # Per-stage state, filled in by run_integratest().
        self.stage_name = None
        self.stage = None
        self.sb_id = None

        self.queue = get_queue_from_issue(ticket)

        self.ticket_url = self.config['queues'][self.queue]['front_url'] + '/request/'

        self.st_helper = StartrekHelper(
            useragent="sandbox",
            startrek_api_url=self.config['constants']['STARTREK_API_ENDPOINT'],
            st_token=st_token,
            local_fields_prefix=config['queues'][self.queue]['local_fields_prefix'],
        )

        self.ok_helper = OkHelper(self.ok_token, self.config['approvals'])
        self.sandbox_helper = SandboxHelper(sandbox_token)

        self.filters = self.st_helper.get_attachment(
            self.ticket,
            self.config['constants']['FILTER_FILE']
        )

        self.author = self.st_helper.get_field(self.ticket, 'createdBy').id

        # The revert-issue field is optional. If it cannot be read, the
        # revert stages simply cannot be tested.
        try:
            self.revert_ticket = self.get_field('tacman_revert_issue')
        except Exception:
            self.revert_ticket = None

        self.yt = YtClient(proxy=self.yt_cluster, token=self.yql_token)

    def get_flow(self):
        """Return the stage names from start_stage to end_stage (inclusive).

        The walk follows 'next' transitions. It stops early when a stage has
        no 'next' transition, i.e. when end_stage is not reachable.
        """
        state = self.start_stage
        states = [state]
        while 'next' in self.config['stages'][state]['next_states'] and state != self.end_stage:
            state = self.config['stages'][state]['next_states']['next']
            states.append(state)
        return states

    def set_sb_id(self, id):
        """Write *id* into the current stage's Sandbox task id field."""
        field = self.stage['task_id_field']
        l.warning('Setting %s = "%s"', field, id)
        self.st_helper.change_local_field(self.current_ticket, field, id)

    def get_sb_id(self):
        """Read the current stage's Sandbox task id field."""
        field = self.stage['task_id_field']
        return self.st_helper.get_local_field(self.current_ticket, field)

    def set_field(self, field, id):
        """Write local *field* = *id* on the current ticket."""
        l.warning('Setting %s = "%s"', field, id)
        self.st_helper.change_local_field(self.current_ticket, field, id)

    def get_field(self, field):
        """Read local *field* from the current ticket."""
        return self.st_helper.get_local_field(self.current_ticket, field)

    def wait_state_change(self, previous_state, wait_state, timeout_minutes):
        """Poll ``tacman_status`` until it moves away from *previous_state*.

        Returns:
            The new state, which equals *wait_state*.

        Raises:
            RuntimeError: if the status changes to a value other than
                *wait_state*, or does not change within the timeout. The
                timeout counts one-second polls, so request latency makes the
                real wait somewhat longer than *timeout_minutes*.
        """
        l.warning('Waiting for state change...')
        for _ in range(timeout_minutes * 60):
            new_state = self.get_field('tacman_status')
            if new_state and new_state != previous_state:
                if new_state != wait_state:
                    raise RuntimeError('Got state %s, expecting %s' % (new_state, wait_state))
                l.warning('Got state %s', new_state)
                return new_state
            sleep(1)
        # Report the state we were waiting for. The last observed state is
        # unbound when timeout_minutes == 0.
        raise RuntimeError('Timeout %s minutes waiting for state %s' % (timeout_minutes, wait_state))

    def wait_sb_id(self, timeout_minutes):
        """Poll the stage's task id field until it is non-empty and return it.

        Raises:
            RuntimeError: if no task id appears within about
                *timeout_minutes* (counted as one-second polls).
        """
        field = self.stage['task_id_field']
        for i in range(timeout_minutes * 60):
            sb_id = self.get_field(field)
            if sb_id:
                l.warning(
                    '%s SB task: %s',
                    self.stage_name,
                    get_sandbox_link(sb_id),
                )
                return sb_id
            if i == 2:
                # Only announce the wait once it is noticeably long.
                l.warning('Waiting for sandbox task...')
            sleep(1)
        raise RuntimeError('Timeout %s minutes waiting for sb_id' % timeout_minutes)

    def run_stage(self):
        """Trigger the current stage and wait for it to reach 'success'.

        Steps: clear the stage's task id, set ``tacman_status`` to the stage
        name, wait for the optional 'process' state, wait for a task id, then
        wait for the 'success' state (limited by the stage's
        ``timeout_minutes``).
        """
        self.set_sb_id('')
        self.set_field("tacman_status", self.stage_name)

        previous_state = self.stage_name
        if 'process' in self.stage['next_states']:
            previous_state = self.wait_state_change(
                previous_state=previous_state,
                wait_state=self.stage['next_states']['process'],
                timeout_minutes=2,
            )

        self.sb_id = self.wait_sb_id(timeout_minutes=2)
        self.wait_state_change(
            previous_state=previous_state,
            wait_state=self.stage['next_states']['success'],
            timeout_minutes=self.stage['timeout_minutes'],
        )

    def protect_stage(self):
        """Refuse to EXECUTE unless the filter targets only PageID 414314.

        Returns silently when the ticket has exactly one filter that contains
        a PageID == '414314' condition. Otherwise it raises for the EXECUTE
        stage and only logs a warning for any other stage.
        """
        # Protect from executing dangerous oneshot in production
        if len(self.filters['filters']) == 1:
            for condition in self.filters['filters'][0]['conditions']:
                if condition['field'] == 'PageID' and condition['value'] == '414314':
                    return

        if self.stage_name == 'EXECUTE':
            raise RuntimeError('Only single filters with PageID=414314 are allowed to EXECUTE in test')

        l.warning('WARNING: Only single filters with PageID=414314 are allowed to EXECUTE in test')

    def remove_repeat_oneshot_protector(self):
        """Delete the stage's 'changes_applied' YT node if it exists.

        Presumably this marker stops a oneshot from running twice; removing
        it lets EXECUTE run again on the same ticket.
        """
        path = ypath_join(
            self.config['constants']['TACMAN_REQUESTS_DIR'],
            self.current_ticket,
            self.stage['stage_folder'],
            'changes_applied'
        )
        if self.yt.exists(path):
            l.warning('Found, removing oneshot protector: %s', path)
            self.yt.remove(path)

    def _remove_ticket_dir(self, kind):
        """Recursively remove the current ticket's YT request dir, if any."""
        path = ypath_join(
            self.config['constants']['TACMAN_REQUESTS_DIR'],
            self.current_ticket,
        )
        l.warning('Removing %s ticket dir: %s', kind, path)
        # force=True: a fresh ticket has no directory yet; that is not an error.
        self.yt.remove(path, recursive=True, force=True)

    def prepare_stage(self):
        """Reset ticket/YT state so the current stage starts from scratch."""
        self.set_field('tacman_executor', self.author)

        if self.stage_name == 'FILTER':
            self._remove_ticket_dir('direct')

        if self.stage_name == 'REVERT':
            self._remove_ticket_dir('revert')

        if self.stage_name == 'TEST':
            # Cleared here so validate_stage() can check TEST sets them again.
            self.set_field(APPROVAL_TYPE_FIELD, '')
            self.set_field(APPROVAL_STAGES_FIELD, '')
            self.set_field(MAX_PARTNER_SHARE_FIELD, '')

        if self.stage_name == 'EXECUTE':
            self.remove_repeat_oneshot_protector()

        # Approve only with developers
        if self.stage_name == 'APPROVE':
            self.set_field(APPROVAL_TYPE_FIELD, 'DEV')

            # Remove DEV followers from the ticket so validate_stage() can
            # check that the APPROVE task adds them back.
            dev_followers = set(self.config['approvals']['DEV']['followers'])
            followers = set(self.st_helper.get_followers(self.current_ticket))
            self.st_helper.set_followers(
                self.current_ticket,
                list(followers - dev_followers),
            )
            l.warning('Followers removed: %s', list(followers & dev_followers))

    def validate_stage(self):
        """Check stage-specific side effects after the stage succeeded.

        Raises:
            RuntimeError: if a TEST field is missing or invalid, if the OK
                approval status is unexpected, or if APPROVE did not add the
                DEV followers.
        """
        if self.stage_name == 'TEST':
            approval_type = self.get_field(APPROVAL_TYPE_FIELD)
            if approval_type not in self.config['approvals']:
                raise RuntimeError('Unknown approval type %s' % approval_type)
            l.warning('Approval type is valid')

            # prepare_stage() cleared these to ''. An unset field may
            # presumably also read back as None, so any falsy value counts
            # as "not set".
            if not self.get_field(APPROVAL_STAGES_FIELD):
                raise RuntimeError('Approval stages not set')
            l.warning('Approval stages are set')

            if not self.get_field(MAX_PARTNER_SHARE_FIELD):
                raise RuntimeError('Max partner share not set')
            l.warning('Max partner share is set')

        allowed_statuses = self._ALLOWED_OK_STATUSES.get(self.stage_name)
        if allowed_statuses is not None:
            ok_status = self.ok_helper.get_approve(self.ticket, 1, self.author)['status']
            if ok_status not in allowed_statuses:
                raise RuntimeError('Unexpected OK status: %s' % ok_status)
            l.warning('OK API status: %s', ok_status)

        if self.stage_name == 'APPROVE':
            # Test that followers removed in prepare_stage() were added back.
            approve_config = self.config['approvals']['DEV']
            followers = self.st_helper.get_followers(self.current_ticket)
            missing = set(approve_config['followers']) - set(followers)
            if missing:
                raise RuntimeError('Followers were not added: %s' % list(missing))
            l.warning('Followers successfully added: %s', followers)

    def validate_outputs(self):
        """Check that every declared output table of the stage exists on YT.

        Raises:
            RuntimeError: naming the first missing output table.
        """
        if 'operations' not in self.stage:
            return

        for operation_name in self.stage['operations']:
            operation = self.config['operations'][operation_name]
            if 'outputs' not in operation:
                continue

            for output_name in operation['outputs']:
                path = ypath_join(
                    self.config['constants']['TACMAN_REQUESTS_DIR'],
                    self.current_ticket,
                    self.stage['stage_folder'],
                    output_name
                )

                if not self.yt.exists(path):
                    raise RuntimeError(
                        'Output table {} of operation {} not found'.format(
                            path,
                            operation_name,
                        )
                    )
                l.warning('Output table exists on YT: %s', path)

    def check_initial_status(self):
        """Refuse to start while the ticket looks like it is mid-stage.

        Raises:
            RuntimeError: if ``tacman_status`` equals the name of a stage that
                has a 'performer', or equals any stage's 'process' state.
        """
        initial_status = self.get_field('tacman_status')
        l.warning('Current state is: %s', initial_status)
        for stage_name, stage in self.config['stages'].items():
            if initial_status == stage_name and 'performer' in stage:
                raise RuntimeError('Will not start new test when ticket state is stage start: %s' % initial_status)

            if 'next_states' not in stage:
                continue

            if stage['next_states'].get('process') == initial_status and 'process' in stage['next_states']:
                raise RuntimeError('Will not start new test when ticket state is process: %s' % initial_status)

    def check_ticket_tasks(self):
        """Refuse to start while any non-daemon stage task is still running.

        Looks up each stage's recorded Sandbox task on the direct or revert
        ticket (chosen by ``config['ribbons']``). Daemon task types (from
        ``get_daemons``) are skipped. ``current_ticket`` is always restored
        to the direct ticket, even on error.

        Raises:
            RuntimeError: if a task is in an active status.
        """
        try:
            for stage_name, stage in self.config['stages'].items():
                if 'task_id_field' not in stage:
                    continue

                if stage_name in self.config['ribbons']['direct']:
                    self.current_ticket = self.ticket
                elif stage_name in self.config['ribbons']['revert']:
                    self.current_ticket = self.revert_ticket
                else:
                    self.current_ticket = None
                if not self.current_ticket:
                    continue

                # A missing or unreadable field is treated as "no task".
                try:
                    task_id = self.get_field(stage['task_id_field'])
                except Exception:
                    task_id = None
                if not task_id:
                    l.warning('Ticket %s task not found', stage_name)
                    continue

                task = self.sandbox_helper.get_task(task_id)

                # Skip QUEUE and FILTER_SERVER tasks, because they can run after ticket completion
                if task['type'] in get_daemons(self.config):
                    continue

                if task['status'] in self._ACTIVE_TASK_STATUSES:
                    raise RuntimeError('Will not start new test when {} task is {}: {}'.format(
                        stage_name,
                        task['status'],
                        get_sandbox_link(task['id']),
                    ))
                l.warning(
                    'Ticket %s task is %s: %s',
                    stage_name,
                    task['status'],
                    get_sandbox_link(task['id']),
                )
        finally:
            self.current_ticket = self.ticket

    def show_skipped_stages(self):
        """Warn about stages with a performer that the full flow never reaches."""
        flow = full_flow(self.config)
        for stage_name, stage in self.config['stages'].items():
            if stage_name not in flow and stage.get('performer'):
                l.warning('This stage will never be tested. Check config.json if needed: %s', stage_name)

    def run_integratest(self):
        """Run pre-flight checks, then prepare/run/validate each stage in order."""
        self.show_skipped_stages()
        l.warning('Will run flow in this ticket: %s%s', self.ticket_url, self.ticket)
        l.warning('Revert ticket: %s%s', self.ticket_url, self.revert_ticket)
        self.check_initial_status()
        self.check_ticket_tasks()

        ok_status = self.ok_helper.get_approve(self.ticket, 1, self.author)['status']
        l.warning('Current OK API status: %s', ok_status)

        stages = self.get_flow()
        l.warning('Will execute flow: %s', stages)

        for stage_name in stages:
            l.warning('== Starting stage: %s', stage_name)

            self.stage_name = stage_name
            self.stage = self.config['stages'][self.stage_name]
            if self.stage_name.startswith('REVERT'):
                if not self.revert_ticket:
                    raise RuntimeError('Need revert ticket to test this stage')
                self.current_ticket = self.revert_ticket
            else:
                self.current_ticket = self.ticket

            self.protect_stage()
            self.prepare_stage()
            self.run_stage()
            self.validate_outputs()
            self.validate_stage()

        l.warning('== Finished integration test successfuly. Please check reports in frontend')
