from .resources import NirvanaRegularProcessState, NirvanaOperationsBinary
import sandbox.common.types.resource as ctr
import sandbox.common.types.task as ctt
import sandbox.common.types.notification as ctn
from sandbox.common import rest
import sandbox.sdk2 as sdk2
import sandbox.sandboxsdk.process
from sandbox.sandboxsdk.errors import SandboxSubprocessError
from sandbox.projects.common.juggler import jclient
from sandbox.projects.common.vcs import arc

from datetime import datetime, timedelta
import logging
import json
import sandbox.sdk2.helpers.process as process
import requests
import re
from textwrap import dedent
import dateutil.tz


# Timezone used to render flow timestamps (datetime.fromtimestamp(..., tz=TZ))
# and as the default "now" of a run (datetime.now(TZ)).
TZ = dateutil.tz.gettz("Europe/Moscow")


def get_url_to_nirvana_flow(flow_id):
    """Return the Nirvana web UI URL of the flow identified by ``flow_id``."""
    url_template = 'https://nirvana.yandex-team.ru/flow/{flow_id}'
    return url_template.format(flow_id=flow_id)


def get_nirvana_flow_url_pattern():
    """Return a regex that captures the flow id from a Nirvana flow URL.

    The single capture group is the flow id (letters, digits and dashes).
    Dots in the host name are escaped so they only match a literal ``.``.
    """
    return r'nirvana\.yandex-team\.ru/flow/([a-zA-Z0-9-]+)'


def html_href(url, text=None):
    """Render an HTML ``<a>`` element linking to ``url``.

    ``text`` is the link label and defaults to the URL itself when omitted
    (``None``). Neither argument is HTML-escaped.
    """
    label = url if text is None else text
    return '<a href="{0}">{1}</a>'.format(url, label)


def flows_info_to_html(flows_info):
    """Render each flow-info dict as one HTML string.

    A flow with a truthy ``workflow_id`` becomes a link to its Nirvana page
    labelled with the ISO-formatted ``timestamp`` (in ``TZ``); a flow without
    one is rendered as a plain-text placeholder with the same time.

    :param flows_info: iterable of dicts with at least a ``timestamp`` key.
    :return: list of HTML strings, one per flow, in input order.
    """
    rendered = []
    for flow in flows_info:
        launch_time = datetime.fromtimestamp(flow['timestamp'], tz=TZ).isoformat()
        workflow_id = flow.get('workflow_id')
        if workflow_id:
            rendered.append(html_href(get_url_to_nirvana_flow(workflow_id), launch_time))
        else:
            rendered.append('UNEXISTED WORKFLOW for ' + launch_time)
    return rendered


def _get_new_graphs_timestamps(last_time, period, from_time, to_time):
    cur = last_time + period if last_time is not None else from_time
    rv = []
    logging.info('generate from %s to %s with period %s', cur, to_time, period)
    if period == 0:
        raise ValueError('Bad period of tasks generating', period)
    while cur <= to_time:
        rv.append(cur)
        cur += period
    return rv


def empty_process_state():
    """Build the state of a process that has never run.

    ``last_time`` is ``None`` (nothing generated yet) and ``graphs_to_watch``
    is a fresh empty list on every call.
    """
    return dict(last_time=None, graphs_to_watch=[])


class FlowsState:
    """Buckets of flow-info dicts collected during one step of a regular process.

    Every attribute is a list of flow-info dicts; ``NirvanaRegularProcess.do_job``
    fills them in and the task reports each bucket separately.
    """

    def __init__(self):
        # Watched flows that completed without 'cancel'/'failure' result.
        self.completed_flows = []
        # Watched flows that completed with 'cancel' or 'failure'.
        self.failed_flows = []
        # Failed flows that will no longer be restarted (time/attempt limits hit).
        self.abandoned_flows = []
        # Failed flows whose attempt count reached the notification threshold.
        self.too_much_failed_flows = []
        # Flows that were relaunched in this step.
        self.restarted_flows = []
        # Flows launched for the first time in this step.
        self.new_flows = []
        # Flows left waiting because of the running-flows limit.
        self.queueing_flows = []
        # Flows whose graph creator failed to start them.
        self.failed_while_launching = []
        # Failed launches whose stderr still mentioned a workflow id.
        self.failed_while_launching_but_got_workflow_id = []
        # Watched flows that have not completed yet.
        self.running_flows = []

    def need_to_notify_about_fails(self):
        """Return True if any bucket that warrants an alert is non-empty."""
        alerting_buckets = (
            self.abandoned_flows,
            self.too_much_failed_flows,
            self.failed_while_launching,
            self.failed_while_launching_but_got_workflow_id,
        )
        return any(alerting_buckets)

    def to_dict(self):
        """Return a shallow copy of all buckets keyed by attribute name."""
        return dict(vars(self))


class NirvanaRegularProcess(sdk2.Task):
    """Regularly launch Nirvana graphs and watch them until they complete.

    Each run of the task:

    * loads the process state (if ``use_state`` is set) from the explicit
      ``process_state`` resource or the latest READY
      ``NirvanaRegularProcessState`` with the same ``process_id``;
    * polls the execution state of every watched flow;
    * restarts failed flows unless ``max_attempt_time`` / ``max_attempts``
      are exceeded;
    * runs the graph creator for every new timestamp up to ``end_time``
      (default: now), respecting ``max_count_of_running_flows``;
    * saves the new state as a ``NirvanaRegularProcessState`` resource and
      reports problems via task info, e-mail notifications and juggler.
    """

    class Requirements(sdk2.Requirements):
        cores = 1  # vCores
        ram = 1024  # Mb
        disk_space = 6 * 1024

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('Graph launcher description') as param_block:
            nirvana_oauth_token = sdk2.parameters.String(
                """
                Nirvana OAuth token name (in sandbox vault). Supported formats:
                    "secret_name" or "OWNER@secret_name"
                """,
                required=True
            )

            graph_creator = sdk2.parameters.Resource(
                'Program creating graph',
                state=(ctr.State.READY, ),
                required=False,
            )

            process_runner_search_config = sdk2.parameters.String(
                "graph_creator_search_config",
                description=dedent('''\
                    Json with parameters of searching graph creator program
                    Example: {"type": "ML_ENGINE_LOG_CREATOR",
                              "state": "READY",
                              "attrs": {"released": "stable"}}
                '''),
                multiline=True
            )

            cmd_template = sdk2.parameters.String(
                '''
                Commandline template. Available variables
                {binary}
                {timestamp}
                {config}
                {process_id}
                {arcadia}
                ''',
                required=True,
                default='{binary} --time {timestamp}',
                multiline=True
            )

            program_config = sdk2.parameters.String(
                '''Program config, will be written to file and available in commandline template''',
                multiline=True
            )

            nirvana_cmd_binary = sdk2.parameters.LastReleasedResource(
                'nirvana_cmd binary',
                resource_type=NirvanaOperationsBinary,
                state=(ctr.State.READY, ),
                required=True,
            )

            arcadia_revision = sdk2.parameters.Integer(
                "Arcadia revision for {arcadia}, use HEAD by default",
                required=False,
                default=None
            )

        with sdk2.parameters.Group('Process parameters') as param_block:
            process_state = sdk2.parameters.Resource(
                'Optional process state',
                resource_type=NirvanaRegularProcessState,
                required=False,
                state=(ctr.State.READY, ),
            )

            process_id = sdk2.parameters.String(
                "Regular process id, will be used to retrieve state",
                required=True
            )

            max_count_of_running_flows = sdk2.parameters.Integer(
                'Max count of running flows',
                default=None,
                required=False
            )

            period = sdk2.parameters.Integer(
                "Period to run graphs, in seconds",
                required=True
            )

            start_time = sdk2.parameters.Integer(
                "First time of graph, if no state found",
                required=False
            )

            end_time = sdk2.parameters.Integer(
                "Last time of graph to be generated, default is now",
                required=False
            )

            use_state = sdk2.parameters.Bool(
                'Use state file',
                required=True,
                default=True,
            )

            is_one_shot = sdk2.parameters.Bool(
                'Kill scheduler after success or after max_attemp_time',
                required=False,
                default=False
            )

            max_attempt_time = sdk2.parameters.Integer(
                "Time in seconds to restart failed tasks",
                required=False
            )

            max_attempts = sdk2.parameters.Integer(
                "Max attemps count",
                required=False,
                default=0
            )

            attempts_to_notify = sdk2.parameters.Integer(
                "Attemps count to send notifications about failed flows",
                required=False,
                default=0
            )

            juggler_host = sdk2.parameters.String(
                "Name of jugler host to notify about problems",
                required=False,
                default=None
            )

    def smart_get_vault_data(self, token_or_ownertoken):
        """Read a secret from the Sandbox vault.

        :param token_or_ownertoken: ``"secret_name"`` or ``"OWNER@secret_name"``;
            only the first ``@`` separates owner from name.
        :return: the secret value returned by ``sdk2.Vault.data``.
        """
        if '@' in token_or_ownertoken:
            owner, token = token_or_ownertoken.split('@', 1)
            return sdk2.Vault.data(owner, token)
        return sdk2.Vault.data(token_or_ownertoken)

    def get_execution_state(self, nirvana_operations_cmd, flow_id, path_to_token, log_path):
        """Return the parsed JSON that ``nirvana_cmd get_execution_state`` prints for a flow.

        Stdout is kept in ``log_path/check_state.<flow_id>.out``; stderr goes to
        the task's process log. Callers read ``status`` and ``result`` from the
        returned dict.

        :raises subprocess.CalledProcessError: if the command exits non-zero.
        """
        with (log_path / 'check_state.{}.out'.format(flow_id)).open('w+') as stdout:
            with process.ProcessLog(self, logger=logging.getLogger('check_state.' + flow_id), set_action=False) as pl:
                process.subprocess.check_call(
                    '{binary} --token {token_path} get_execution_state {flow_id}'.format(binary=nirvana_operations_cmd, token_path=path_to_token, flow_id=flow_id),
                    stdout=stdout,
                    stderr=pl.stdout,
                    shell=True,
                )
            stdout.seek(0)
            return json.load(stdout)

    def on_save(self):
        """Resolve ``graph_creator`` from ``process_runner_search_config`` when it is set.

        The config is a JSON object passed as keyword arguments to
        ``sdk2.Resource.find``; the matching resource with the highest id wins.
        """
        logging.info('on_save called')
        if self.Parameters.process_runner_search_config:
            params = json.loads(self.Parameters.process_runner_search_config)

            logging.info('params: %s', params)
            graph_creator = sdk2.Resource.find(**params).order(-sdk2.Resource.id).first()
            if graph_creator is None:
                # on_execute cannot run without a graph creator; make this visible early.
                logging.warning('no graph creator resource found for search config %s', params)
            self.Parameters.graph_creator = graph_creator
        else:
            logging.info('no config found: %s', self.Parameters.__getstate__())

    @staticmethod
    def _with_next_attempt(flow_info):
        """Return a copy of ``flow_info`` with its ``attempt`` counter incremented by one."""
        return dict(flow_info, attempt=flow_info.get('attempt', 0) + 1)

    def do_job(self, process_state, process_runner_bin, nirvana_operations_bin, token_file, process_runner_config_file, process_id, now=None, log_path=None):
        """Advance the regular process by one step.

        Polls every watched flow that has a ``workflow_id``, sorts them into
        :class:`FlowsState` buckets, decides which failed flows to restart,
        creates flow infos for new timestamps and launches as many of them as
        ``max_count_of_running_flows`` allows.

        :param process_state: dict with ``last_time`` and ``graphs_to_watch``;
            the caller's dict is not modified.
        :param process_runner_bin: path substituted for ``{binary}`` in ``cmd_template``.
        :param nirvana_operations_bin: path to the nirvana_cmd binary used for polling.
        :param token_file: path to the file with the Nirvana OAuth token.
        :param process_runner_config_file: path substituted for ``{config}``.
        :param process_id: value substituted for ``{process_id}``.
        :param now: current time as a timezone-aware datetime; defaults to ``datetime.now(TZ)``.
        :param log_path: directory for polling logs; defaults to ``self.log_path()``.
        :return: ``(new_process_state, flows_state)``.
        """
        process_state = process_state.copy()
        if now is None:
            now = datetime.now(TZ)

        if log_path is None:
            log_path = self.log_path()
        now_timestamp = int((now - datetime(1970, 1, 1, tzinfo=dateutil.tz.gettz("UTC"))).total_seconds())
        new_graphs_timestamps = _get_new_graphs_timestamps(
            process_state["last_time"],
            self.Parameters.period,
            self.Parameters.start_time or now_timestamp,
            self.Parameters.end_time or now_timestamp
        )
        logging.info('timestamps: %s', str(new_graphs_timestamps))

        old_graphs_states = [
            (
                flow_info,
                self.get_execution_state(nirvana_operations_bin, flow_info['workflow_id'], token_file, log_path),
            )
            for flow_info in process_state["graphs_to_watch"] if flow_info.get('workflow_id')
        ]

        flows_state = FlowsState()
        flows_state.running_flows = [flow_info for flow_info, state in old_graphs_states if state['status'] != 'completed']

        # Graphs that were not successfully launched and did not get a
        # workflow_id (including new flows queued on a previous run).
        flows_to_restart = [
            self._with_next_attempt(flow_info)
            for flow_info in process_state["graphs_to_watch"]
            if not flow_info.get('workflow_id')
        ]
        for flow_info, state in old_graphs_states:
            logging.info('graph %s state is %s, result %s', flow_info['workflow_id'], state['status'], state['result'])
            if state['status'] != 'completed':
                continue
            if state['result'] not in {'cancel', 'failure'}:
                flows_state.completed_flows.append(flow_info)
                continue

            flows_state.failed_flows.append(flow_info)
            if (self.Parameters.max_attempt_time and
                    datetime.fromtimestamp(flow_info.get('first_launch_time', flow_info['timestamp']), tz=TZ) < now - timedelta(seconds=self.Parameters.max_attempt_time)):
                logging.error('stop to watch graph %s for timestamp %s, max_attempt_time=%s reached',
                              flow_info['workflow_id'], datetime.fromtimestamp(flow_info['timestamp'], tz=TZ), self.Parameters.max_attempt_time)
                flows_state.abandoned_flows.append(flow_info)
                continue

            if self.Parameters.max_attempts and flow_info.get('attempt', 0) + 1 >= self.Parameters.max_attempts:
                logging.error('stop to watch graph %s for timestamp %s, max_attemps=%s reached',
                              flow_info['workflow_id'], datetime.fromtimestamp(flow_info['timestamp'], tz=TZ), self.Parameters.max_attempts)
                flows_state.abandoned_flows.append(flow_info)
                continue

            if self.Parameters.attempts_to_notify and flow_info.get('attempt', 0) + 1 >= self.Parameters.attempts_to_notify:
                logging.error('too much fail of graph %s for timestamp %s',
                              flow_info['workflow_id'], datetime.fromtimestamp(flow_info['timestamp'], tz=TZ))
                flows_state.too_much_failed_flows.append(flow_info)

            flows_to_restart.append(self._with_next_attempt(flow_info))

        def _get_empty_flowinfo(timestamp):
            return {'timestamp': timestamp, 'first_launch_time': now_timestamp, 'attempt': 0}

        flows_to_be_launched = sorted(flows_to_restart + [_get_empty_flowinfo(t) for t in new_graphs_timestamps], key=lambda flow_info: flow_info['timestamp'])
        flows_state.queueing_flows = []
        max_running = self.Parameters.max_count_of_running_flows
        if max_running is not None:
            running_count = len(flows_state.running_flows)
            if max_running <= running_count:
                logging.info('already running %s flows, nothing will be launched', running_count)
                flows_state.queueing_flows = flows_to_be_launched
                flows_to_be_launched = []
            else:
                count_to_start = max_running - running_count
                flows_state.queueing_flows = flows_to_be_launched[count_to_start:]
                flows_to_be_launched = flows_to_be_launched[:count_to_start]
        else:
            logging.info('not restricted count of running flows')

        new_graphs, flows_state.failed_while_launching, flows_state.failed_while_launching_but_got_workflow_id = self.start_workflows(
            flows_to_be_launched,
            self.Parameters.cmd_template,
            dict(config=process_runner_config_file, token=token_file, binary=process_runner_bin, process_id=process_id)
        )
        # last_time is None for a fresh state: drop it instead of relying on
        # Python 2's None-sorts-first ordering inside max().
        known_times = [process_state["last_time"]] + [f["timestamp"] for f in flows_to_restart] + new_graphs_timestamps
        known_times = [t for t in known_times if t is not None]
        process_state["last_time"] = max(known_times) if known_times else None
        process_state["graphs_to_watch"] = flows_state.running_flows + new_graphs + flows_state.failed_while_launching + flows_state.queueing_flows
        # New flows start with attempt 0 and every restart goes through
        # _with_next_attempt, so attempt > 0 identifies restarted flows exactly.
        # (Slicing new_graphs by len(flows_to_restart) broke once launches failed
        # or restarts were queued, because new_graphs is sorted by timestamp.)
        flows_state.restarted_flows = [f for f in new_graphs if f.get('attempt', 0)]
        flows_state.new_flows = [f for f in new_graphs if not f.get('attempt', 0)]

        return process_state, flows_state

    def on_execute(self):
        """Load the process state, run one ``do_job`` step, save the state and report."""

        binary = str(sdk2.ResourceData(self.Parameters.graph_creator).path)
        nirvana_operations = str(sdk2.ResourceData(self.Parameters.nirvana_cmd_binary).path)
        process_id = self.Parameters.process_id
        state_file = "run_nirvana_regular.state.json"

        token = self.smart_get_vault_data(self.Parameters.nirvana_oauth_token)
        token_file = 'nirvana_token'
        with open(token_file, 'w') as out:
            out.write(token)
        config_file = 'process_runner_config'
        if self.Parameters.program_config:
            with open(config_file, 'w') as out:
                out.write(self.Parameters.program_config)

        # Explicit state resource wins; otherwise take the first READY state
        # resource tagged with this process_id.
        state_resource = self.Parameters.process_state
        if not state_resource:
            resource = NirvanaRegularProcessState.find(
                attrs={"process_id": process_id},
                state=(ctr.State.READY, ),
            ).first()
            if resource:
                state_resource = resource

        if state_resource and self.Parameters.use_state:
            with sdk2.ResourceData(state_resource).path.open() as infile:
                process_state = json.load(infile)
        else:
            process_state = empty_process_state()

        process_state, graphs_lists = self.do_job(process_state, binary, nirvana_operations, token_file, config_file, process_id)

        if graphs_lists.failed_while_launching:
            self.set_flows_info('failed dates', graphs_lists.failed_while_launching, color="red")
        if graphs_lists.failed_while_launching_but_got_workflow_id:
            self.set_flows_info('failed while launching but got workflow_id dates', graphs_lists.failed_while_launching_but_got_workflow_id, color="red")
        if graphs_lists.abandoned_flows:
            self.set_flows_info('abandoned flows', graphs_lists.abandoned_flows, color="red")
        if graphs_lists.too_much_failed_flows:
            self.set_flows_info('too much failed flows', graphs_lists.too_much_failed_flows, color="yellow")
        self.set_flows_info('failed flows', graphs_lists.failed_flows)
        self.set_flows_info('completed flows', graphs_lists.completed_flows)
        self.set_flows_info('running flows', graphs_lists.running_flows)
        self.set_flows_info('restarted flows', graphs_lists.restarted_flows)
        self.set_flows_info('new flows', graphs_lists.new_flows)
        self.set_flows_info('queueing flows', graphs_lists.queueing_flows)

        new_state = sdk2.ResourceData(
            NirvanaRegularProcessState(
                self,
                "Nirvana regular process state of process %s" % process_id,
                path=state_file,
                process_id=process_id
            )
        )
        with new_state.path.open('wb') as out:
            json.dump(process_state, out)

        if self.Parameters.is_one_shot and not process_state["graphs_to_watch"]:
            self.stop_self_scheduler()

        if graphs_lists.need_to_notify_about_fails():
            self.notify_about_fails(graphs_lists)
        else:
            self.notify_juggler(graphs_lists, status='OK')

    def get_task_url(self):
        """Return the Sandbox web UI URL of this task."""
        return '{}/task/{}/view'.format(sandbox.common.utils.server_url(), self.id)

    def stop_self_scheduler(self):
        """Delete the Sandbox scheduler that created this task (one-shot mode).

        If the scheduler id cannot be obtained (HTTP error or no scheduler in
        the task info), or the scheduler no longer exists, this is logged and
        the method returns without raising.
        """
        # if all graphs are filtered by self.Parameters.max_attempt_time
        # in one-shot mode it is time to kill scheduler
        # FIXED: forbid ads/ml_engine/regular_task/schedule
        # in --one-shot mode without --max-attempt-time parameter
        # But it seems to be that it is better to forbid it here -
        # because not only ./schedule can use these params

        sandbox_client = rest.Client()
        # here also some weird code - scheduler id can not be accessed
        # by rest.Client so lets get task_url from rest.Client
        # and go to this url by requests.get() to get scheduler id
        task_info = sandbox_client.task[self.id].read()
        try:
            # Without a timeout a stuck server would hang the task forever.
            response = requests.get(task_info['url'], timeout=60)
        except requests.RequestException:
            logging.warning("Request to %s failed", task_info['url'], exc_info=True)
            return
        if response.status_code != 200:
            logging.warning("Request to %s returned with status code %s", task_info['url'], response.status_code)
            return
        j = json.loads(response.text)
        # "scheduler" may be absent or null for tasks not started by a scheduler.
        scheduler_id = (j.get('scheduler') or {}).get('id')
        if scheduler_id is None:
            logging.warning('Task %s has no scheduler, nothing to stop', self.id)
            return
        logging.info('Here we are going to stop scheduler with id %s', scheduler_id)
        try:
            logging.info('Entering to code that will try to delete scheduler')
            # weird way to check that this scheduler exists -
            # try to throw exception while accessing scheduler info
            scheduler_info = sandbox_client.scheduler[scheduler_id].read()
            logging.info(scheduler_info)
            logging.info('Deleting scheduler %s', scheduler_id)
            del sandbox_client.scheduler[scheduler_id]
        except sandbox_client.HTTPError:
            logging.warning('Seems like you are trying to delete scheduler that does not exist')

    def notify_about_fails(self, graphs_lists):
        """Send an e-mail report and a CRIT juggler event about failed flows."""
        self.notify_email(graphs_lists)
        self.notify_juggler(graphs_lists, status='CRIT')

    def get_short_message(self, graphs_lists):
        """Return ``"<task url> : <bucket>: <count>; ..."`` for non-empty buckets.

        Falls back to ``'all done'`` when every bucket is empty.
        """
        return (
            self.get_task_url() +
            " : " +
            (
                '; '.join(
                    '{name}: {count}'.format(
                        name=name,
                        count=len(flows)
                    )
                    for name, flows in
                    graphs_lists.to_dict().items()
                    if len(flows)
                ) or 'all done'
            )
        )

    def notify_juggler(self, graphs_lists, status):
        """Send a juggler event with ``status`` if ``juggler_host`` is configured.

        Failures to send are logged and reported in task info, never raised.
        """
        if self.Parameters.juggler_host:
            try:
                jclient.send_events_to_juggler(
                    self.Parameters.juggler_host,
                    jclient.to_juggler_service_name(self.Parameters.process_id),
                    status,
                    self.get_short_message(graphs_lists)
                )
            except Exception:
                self.set_info('Failed to send event to juggler')
                logging.error('Failed to send event to juggler', exc_info=True)
        else:
            logging.info("No juggler host specified in task, so not sending event")

    def notify_email(self, graphs_lists):
        """Send an HTML report about problem flows.

        The report goes to every notification rule of this task that
        subscribes to FAILURE or EXCEPTION statuses.
        """
        if graphs_lists.abandoned_flows:
            subject = 'Abandoned flows in process {}'.format(self.Parameters.process_id)
        elif graphs_lists.too_much_failed_flows:
            subject = 'Too much times failed flows in process {}'.format(self.Parameters.process_id)
        else:
            subject = 'Problems with {}'.format(self.Parameters.process_id)
        body = ["Task " + html_href(self.get_task_url(), str(self.id))]
        if graphs_lists.abandoned_flows:
            body += ['Abandoned flows:'] + flows_info_to_html(graphs_lists.abandoned_flows)
        if graphs_lists.too_much_failed_flows:
            body += ['Too much times failed flows:'] + flows_info_to_html(graphs_lists.too_much_failed_flows)
        if graphs_lists.failed_while_launching:
            body += ['Failed while launching flows:'] + flows_info_to_html(graphs_lists.failed_while_launching)
        if graphs_lists.failed_while_launching_but_got_workflow_id:
            body += ['Failed while launching but got workflowid flows:'] + flows_info_to_html(graphs_lists.failed_while_launching_but_got_workflow_id)
        for notification_rule in self.Parameters.notifications:
            if set(notification_rule.statuses).intersection(set([ctt.Status.FAILURE, ctt.Status.EXCEPTION])):
                self.server.notification(
                    subject=subject,
                    body='<br>\n'.join(body),
                    recipients=notification_rule.recipients,
                    transport=notification_rule.transport,
                    type=ctn.Type.HTML,
                    charset=ctn.Charset.UTF,
                    task_id=self.id,
                    view=ctn.View.DEFAULT
                )

    def set_flows_info(self, title, flows, color=None):
        """Append an unescaped HTML block ``title:`` plus one line per flow to task info."""
        if color:
            title = '<font color="{color}">{title}</font>'.format(color=color, title=title)
        html_strings = flows_info_to_html(flows)
        self.set_info(
            '{}:\n{}'.format(title, '\n'.join(html_strings)),
            do_escape=False
        )

    def get_cmd(self, cmd_template, timestamp, format_ctx):
        """Render ``cmd_template`` for ``timestamp``.

        When the template references ``{arcadia}`` and ``format_ctx`` does not
        provide it, Arcadia is mounted via Arc once (at ``arcadia_revision`` or
        HEAD), and the mount is reused for subsequent calls.
        """
        try:
            return cmd_template.format(timestamp=timestamp, **format_ctx)
        except KeyError as e:
            if e.args[0] == 'arcadia':
                if getattr(self, 'arcadia_path', None) is None:
                    revision = (
                        "r{}".format(self.Parameters.arcadia_revision)
                        if self.Parameters.arcadia_revision else
                        None
                    )
                    self.arcadia_path = arc.Arc().mount_path(None, revision, fetch_all=False)
                return cmd_template.format(timestamp=timestamp, arcadia=self.arcadia_path._work_path, **format_ctx)
            raise

    def start_workflows(self, flow_infos, cmd_template, format_ctx):
        """Run the graph creator for every flow info and collect the results.

        All processes are started first, then awaited in order. A launch
        succeeds when the process exits with code 0 and prints a JSON object
        with a non-empty ``workflow_id``; that object is merged into a copy of
        the flow info.

        :return: ``(started, failed_while_launching, failed_while_launching_but_got_workflow_id)``.
            The last list holds copies of launches that exited non-zero but
            whose stderr mentioned a Nirvana flow URL, annotated with that id.
        """
        procs = []
        for flow_info in flow_infos:
            procs.append(
                (
                    sandbox.sandboxsdk.process.run_process(
                        self.get_cmd(cmd_template, flow_info['timestamp'], format_ctx),
                        shell=True,
                        outputs_to_one_file=False,
                        wait=False,
                        log_prefix=datetime.fromtimestamp(flow_info['timestamp'], tz=TZ).strftime('run_regular.%Y-%m-%dT%H:%M:%S')
                    ),
                    flow_info
                )
            )

        rv = []
        failed_while_launching = []
        failed_while_launching_but_got_workflow_id = []  # this one is subset of failed_while_launching with flows that has workflow_id in stderr. Used for more info in stderr

        for proc, flow_info in procs:
            proc.wait()
            launch_date = datetime.fromtimestamp(flow_info['timestamp'], tz=TZ).isoformat()
            with open(proc.stdout_path) as fin, open(proc.stderr_path) as ferr:
                started_flow_info = flow_info.copy()
                try:
                    sandbox.sandboxsdk.process.check_process_return_code(proc)
                except SandboxSubprocessError as err:
                    logging.error('Subprocess for date %s raised exception %s', launch_date, err)
                    logging.error('Look for more info in nonempty stderr')
                    failed_while_launching.append(started_flow_info)

                    flows_in_stderr = re.findall(get_nirvana_flow_url_pattern(), ferr.read())
                    if flows_in_stderr:
                        failed_flow_info = started_flow_info.copy()
                        failed_flow_info.update({'workflow_id': flows_in_stderr[0]})
                        failed_while_launching_but_got_workflow_id.append(failed_flow_info)

                    continue

                try:
                    flow_info_patch = json.load(fin)
                    # Must be a dict AND carry a workflow_id: otherwise the flow
                    # cannot be watched. (The previous ``and`` accepted dicts
                    # without workflow_id and crashed on non-dicts.)
                    if not isinstance(flow_info_patch, dict) or not flow_info_patch.get('workflow_id'):
                        # nirvana_regular_process can watch only one graph, not easy to restart if one fails
                        raise ValueError('output of graph createor doesnt have workflow_id', flow_info_patch)
                except ValueError:
                    # json.load also raises ValueError on malformed output.
                    logging.error('Bad output of graph creator for date %s', launch_date, exc_info=True)
                    failed_while_launching.append(started_flow_info)
                    continue

                started_flow_info.update(flow_info_patch)
                rv.append(started_flow_info)

        if getattr(self, 'arcadia_path', None) is not None:
            if self.arcadia_path.mounted:
                self.arcadia_path.unmount()
                self.arcadia_path = None
        return rv, failed_while_launching, failed_while_launching_but_got_workflow_id
