# -*- coding: utf-8 -*-
import datetime
import logging
import re
import time

import requests
from sandbox import sdk2
from sandbox.common import errors
from sandbox.projects.common import binary_task
from sandbox.projects.inventori.common.resources import DUPLICATION_MODES
from sandbox.projects.inventori.common.resources import RunMode
from sandbox.projects.inventori.common import binary_task
from sandbox.projects.inventori.common import InventoryRunTaskTemplate
from sandbox.projects.inventori.common import resources
from sandbox.projects.common.arcadia import sdk as arcadiasdk
import os
from sandbox.projects.inventori.common.utils import report_status

#############################################################################

# Default number of attempts `_access_api` makes when the YQL API answers
# with an HTTP error status.
RETRIES_COUNT = 5
# Default pause between those attempts, in seconds (passed to time.sleep).
RETRIES_INTERVAL = 5

# Solomon project / service names. NOTE(review): not referenced by the code
# visible in this file — presumably consumed elsewhere; confirm before removing.
SOLOMON_PROJECT = 'inventori'
SOLOMON_SERVICE = 'scheduler'

# Pause between two polls of YQL operation statuses, in seconds.
SLEEP_PERIOD_SECONDS = 180  # i.e. 3 minutes

#############################################################################

# Alias for the project's base parameter set; it is the template passed to
# InventoryRunTaskTemplate.get_run_params when RunInventoriYqlTask declares
# its Parameters.
TaskParams = resources.InventoriBaseTaskParams


class ReleaseParameters(sdk2.Parameters):
    """Binary-release selection parameters used as the base of the task Parameters.

    The parameter list is produced by
    ``binary_task.binary_release_parameters_list`` with ``stable=True``
    (presumably: pick the stable binary release — confirm in that helper).
    """

    ext_params = binary_task.binary_release_parameters_list(
        stable=True,
    )


class RunInventoriYqlTask(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Run a YQL query on the YT cluster(s) picked by ``ClusterResolver``.

    The query comes either from the ``query`` parameter (``manual`` source,
    with placeholder substitution, see ``_form_query``) or from a file in
    Arcadia trunk (``arcadia`` source, used verbatim). ``run_mode`` selects
    where it runs:

    * SIMPLE - only on ``yt_yql_cluster``, if that cluster is available;
    * ASYNC - independently on every available cluster;
    * REPLICATION - on the master cluster; afterwards ``output_tables`` are
      copied to the slave cluster (if any) via
      ``ClusterReplication.tm_copy_tables``.

    Every operation is reported through ``report_status`` with 0 on start,
    1 on success and -1 on failure.
    """

    YQL_API_BASE_URL = 'https://yql.yandex.net/api/v2'
    YQL_API_PROXIED_BASE_URL = 'https://yql.yandex-team.ru/api/v2'
    YQL_WEBUI_BASE_URL = 'https://yql.yandex-team.ru'
    SUCCESS_YQL_STATUSES = [
        'COMPLETED',
    ]
    FAIL_YQL_STATUSES = [
        'ABORTED',
        'ERROR'
    ]
    FINAL_YQL_STATUSES = SUCCESS_YQL_STATUSES + FAIL_YQL_STATUSES

    class Requirements(sdk2.Task.Requirements):
        # Use multislot hosts
        # https://wiki.yandex-team.ru/sandbox/cookbook/#cores1multislot
        disk_space = 2 * 1024  # 2 GiB

        cores = 1
        ram = 512

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(InventoryRunTaskTemplate.get_run_params(TaskParams, base_class=ReleaseParameters)):
        kill_timeout = 5 * 60 * 60  # 5 hours by default
        do_not_restart = True

        task_name = sdk2.parameters.String('Task name',
                                           required=True,
                                           description='task_name for monitoring (selected task type name by default; '
                                                       'also will be added as scheduler/task tag)')

        with sdk2.parameters.RadioGroup("Source type", description='Manual for experiments, Arcadia for final version') as source_type:
            source_type.values['manual'] = source_type.Value('manual', default=True)
            source_type.values['arcadia'] = source_type.Value('arcadia')

        with source_type.value['manual']:
            query = sdk2.parameters.String(
                'YQL Query',
                multiline=True,
                required=True,
                description='Pls do NOT specify cluster here!',
            )

        with source_type.value['arcadia']:
            # A selectable revision parameter was considered and dropped:
            # Arcadia queries are always read from #trunk (see on_execute).
            query_arcadia_path = sdk2.parameters.String(
                'YQL query path',
                required=True,
                description='Paste path to your arc script. Usually inventori/scripts',
            )

        with source_type.value['manual']:
            custom_placeholders = sdk2.parameters.Dict(
                'Custom placeholders (keys like %KEY%)',
                description='There are helpful default placeholders like %ENVIRONMENT_TYPE% %YEAR% %MONTH% %DAY% etc'
            )

        # override default parameters because in YQL task we want use placeholders
        output_tables = sdk2.parameters.Dict(
            'Output table (table_name->path; name will be transformed to placeholders like %TABLE_NAME%)',
            required=True
        )

        with sdk2.parameters.Output:
            result_operation_id = sdk2.parameters.String('Operation id')
            result_additional_operation_id = sdk2.parameters.String('Additional operation id')

            master_cluster = sdk2.parameters.String('Master cluster')
            slave_cluster = sdk2.parameters.String('Slave cluster')

    def _save_operation(self, operation_id, additional_operation_id=None):
        """Persist submitted YQL operation ids in Context and output parameters.

        :param operation_id: id of the operation run on ``Context.master_cluster``
        :param additional_operation_id: id of the operation run on
            ``Context.slave_cluster``, if one was submitted
        """
        self.Context.result_operation_id = self.Parameters.result_operation_id = operation_id
        if additional_operation_id:
            self.Parameters.result_additional_operation_id = additional_operation_id
            self.Context.result_additional_operation_id = additional_operation_id

        self.Context.save()

    def _get_operation_ids(self):
        """Return a tuple of the saved operation id(s), master first."""
        return self._get_operation_ids_from_context()

    def _get_clusters_to_operation_id(self):
        """Map each cluster with a saved operation to that operation's id."""
        result = {self.Context.master_cluster: self.Context.result_operation_id}
        if self.Context.result_additional_operation_id:
            result[self.Context.slave_cluster] = self.Context.result_additional_operation_id
        return result

    def _get_operation_ids_from_context(self):
        """Return a tuple of the operation id(s) stored in Context, master first."""
        if self.Context.result_additional_operation_id:
            return (self.Context.result_operation_id,
                    self.Context.result_additional_operation_id)
        return self.Context.result_operation_id,

    def on_execute(self):
        """Submit the YQL query, wait for it and replicate the results if needed.

        :raises errors.TaskFailure: on an unsupported ``run_mode``, when there is
            no cluster to run on, or when a YQL operation fails.
        :raises ValueError: in REPLICATION mode when the query already contains a
            ``USE <cluster>;`` statement.
        """
        from inventori.pylibs.utils.cluster_resolver import ClusterResolver
        from inventori.pylibs.utils.cluster_replication import ClusterReplication

        super().on_execute()

        logging.info('Send request stage')

        if self.Parameters.source_type == 'manual':
            query = self._form_query()
        else:
            # Arcadia queries are always read from trunk and used verbatim
            # (no placeholder substitution).
            query = self._get_query_from_arcanum(self.Parameters.query_arcadia_path, '#trunk')

        api_type = 'SQLv1'
        attributes = {
            'user_agent': 'YQL Sandbox ({} task)'.format(self.__class__.__name__)
        }
        request = {
            'action': 'RUN',
            'type': api_type,
            'attributes': attributes,
        }

        cluster_resolver = ClusterResolver(
            yt_token=str(self.Parameters.yt_token.data()[self.Parameters.yt_token.default_key]),
            infra_token=str(self.Parameters.oauth_token.data()[self.Parameters.oauth_token.default_key]))

        self.set_info('For ClusterResolver will use next tables: {using_tables}'.format(
            using_tables=self.Parameters.using_tables))

        master_slave_cluster_pair = cluster_resolver.get_master_slave_cluster_pair(
            self.Parameters.using_tables)
        available_clusters = [c for c in master_slave_cluster_pair if c is not None]

        requests_to_run = []

        if self.Parameters.run_mode in DUPLICATION_MODES:
            self.Parameters.master_cluster, self.Parameters.slave_cluster = master_slave_cluster_pair
            self.Context.master_cluster, self.Context.slave_cluster = master_slave_cluster_pair

            if self.Parameters.run_mode == RunMode.REPLICATION.value:
                self._raise_use_cluster_in_query(query)

                # Without a master the request would be built as 'USE None;'.
                if not self.Context.master_cluster:
                    raise errors.TaskFailure('Master cluster is not available, nothing to run')

                logging.info('Master cluster is {master_cluster}, slave is {slave_cluster}'.format(
                    master_cluster=self.Context.master_cluster,
                    slave_cluster=self.Context.slave_cluster if self.Context.slave_cluster else None))

                self.set_info(
                    'Master cluster is {master_cluster}, '
                    'Slave cluster is {slave_cluster}'.format(
                        master_cluster=self.Context.master_cluster,
                        slave_cluster=self.Context.slave_cluster
                    ),
                    do_escape=False,
                )

                cluster_request = request.copy()
                cluster_request['content'] = 'USE {};\n{}'.format(self.Context.master_cluster, query)
                requests_to_run = [(self.Context.master_cluster, cluster_request)]

            elif self.Parameters.run_mode == RunMode.ASYNC.value:
                # Operation ids are stored positionally (first -> master_cluster,
                # second -> slave_cluster), so the cluster labels must follow the
                # order of the submitted requests, even if the resolver gave no master.
                master_cluster, slave_cluster = (available_clusters + [None, None])[:2]
                self.Parameters.master_cluster, self.Parameters.slave_cluster = master_cluster, slave_cluster
                self.Context.master_cluster, self.Context.slave_cluster = master_cluster, slave_cluster

                requests_to_run = [
                    (cluster, dict(request, content='USE {};\n{}'.format(cluster, query)))
                    for cluster in available_clusters
                ]

        elif self.Parameters.run_mode == RunMode.SIMPLE.value:
            chosen_cluster = self.Parameters.yt_yql_cluster

            if chosen_cluster not in available_clusters:
                self.set_info(
                    "Cluster {cluster} is not available. I won't run anything".format(
                        cluster=chosen_cluster))
                return

            cluster_request = request.copy()
            cluster_request['content'] = 'USE {};\n{}'.format(chosen_cluster, query)

            self.Parameters.master_cluster = chosen_cluster
            self.Parameters.slave_cluster = None
            self.Context.master_cluster = chosen_cluster
            self.Context.slave_cluster = None

            requests_to_run = [(chosen_cluster, cluster_request)]

        else:
            raise errors.TaskFailure('Unsupported run mode: {}'.format(self.Parameters.run_mode))

        if not requests_to_run:
            raise errors.TaskFailure(
                'No YQL requests to run (run mode {run_mode}, available clusters: {clusters})'.format(
                    run_mode=self.Parameters.run_mode, clusters=available_clusters))

        operations = []

        for cluster_name, cluster_request in requests_to_run:
            response = self._access_api(
                method='POST',
                url='{yql_api_base_url}/operations'.format(yql_api_base_url=self.YQL_API_BASE_URL),
                json=cluster_request
            )

            operation_id = response['id']
            operations.append(operation_id)
            # Save right away so on_failure/on_break can abort an already started
            # operation even if a later submission fails.
            self._save_operation(*operations)

            detail = '({})'.format(cluster_name)
            self._add_link('Operation{}'.format(detail), operation_id)

            self._publish_query(operation_id, detail)

        self._report_status(0, self.Context.master_cluster)
        if self.Context.slave_cluster:
            self._report_status(0, self.Context.slave_cluster)

        logging.info('Waiting until YQL tasks will finish')
        self.Context.save()
        self._wait_for_yql_tasks()

        # if self.Parameters.slave_cluster empty do nothing
        if self.Parameters.run_mode == RunMode.REPLICATION.value and self.Context.slave_cluster:
            cluster_replication = ClusterReplication(
                yt_token=str(self.Parameters.yt_token.data()[self.Parameters.yt_token.default_key]))
            tm_tasks = cluster_replication.tm_copy_tables(master_cluster=self.Context.master_cluster,
                                                          slave_cluster=self.Context.slave_cluster,
                                                          tables=list(self.Parameters.output_tables.values()))

            self.set_info('Start next tm_tasks: {tm_tasks}'.format(tm_tasks=tm_tasks))

            if self.Parameters.is_waiting_for_tm_tasks:
                summary_sleeping_time = self._waiting_tm_tasks(cluster_replication)
                self.set_info('Waited TM tasks for {time} seconds'.format(time=summary_sleeping_time))
                self.set_info('Their statuses is: {statuses}'.format(
                    statuses=' '.join(str(s) for s in cluster_replication.get_final_task_statuses())))

        self.set_info('Finished on_execute successfully')

    def _get_session(self):
        """Build a ``requests.Session`` carrying the YQL OAuth token and JSON headers."""
        token = str(self.Parameters.yql_token.data()[self.Parameters.yql_token.default_key])
        session = requests.Session()
        session.headers.update({
            'User-Agent': 'YQL Sandbox ({name} task)'.format(name=self.__class__.__name__),
            'Authorization': 'OAuth {token}'.format(token=token),
            'Content-Type': 'application/json',
        })
        return session

    def _access_api(self, method, url, raise_on_error=True, retries_count=RETRIES_COUNT,
                    retries_interval=RETRIES_INTERVAL, **kwargs):
        """Call the YQL HTTP API and return the decoded JSON body.

        Responses with an HTTP error status are retried: at most
        ``retries_count`` attempts, ``retries_interval`` seconds apart. Any other
        exception (connection error, invalid JSON, ...) is not retried.

        :param method: HTTP method, e.g. ``'GET'`` / ``'POST'``
        :param url: full request URL
        :param raise_on_error: when False, failures are logged and ``None`` is
            returned instead of raising
        :param kwargs: forwarded to ``requests.Session.request`` (e.g. ``json=``)
        :raises errors.TaskFailure: when every attempt got an HTTP error and
            ``raise_on_error`` is True
        """
        # The session is closed on exit so its connection pool is not leaked per call.
        with self._get_session() as session:
            for attempt in range(1, retries_count + 1):
                try:
                    response = session.request(
                        method=method,
                        url=url,
                        **kwargs
                    )
                    response.raise_for_status()
                    return response.json()
                except requests.HTTPError as e:
                    logging.warning('YQL API request %s %s failed (attempt %d of %d): %s',
                                    method, url, attempt, retries_count, e)
                    # No point in sleeping after the last attempt.
                    if attempt < retries_count:
                        time.sleep(retries_interval)
                except Exception:
                    if not raise_on_error:
                        logging.exception('Failed to access api')
                        return None
                    raise

        if not raise_on_error:
            logging.warning('Failed to access api')
            return None
        raise errors.TaskFailure(
            'Access api failed: method -- {method}, url -- {url}'.format(
                method=method,
                url=url
            )
        )

    def _get_query_from_arcanum(self, path, revision):
        """Read a query file from an Arcadia checkout mounted at ``revision``.

        :param path: file path relative to the Arcadia root; leading ``/`` are
            ignored so the file is always looked up inside the checkout
        :param revision: arc revision to mount, e.g. ``'#trunk'``
        :return: the file contents decoded as UTF-8
        """
        logging.info('Reading config from {}'.format(path))

        # os.path.join discards the root when given an absolute component, which
        # would make '/some/path' read from the host filesystem instead.
        relative_path = path.strip().lstrip('/')

        with arcadiasdk.mount_arc_path("arcadia-arc:/{}".format(revision)) as arcadia_root:
            arcadia_path = sdk2.Path(arcadia_root)
            with open(os.path.join(arcadia_path, relative_path), encoding='utf-8') as file_handler:
                return file_handler.read()

    def _form_query(self):
        """Substitute placeholders into the ``query`` parameter and return the result.

        Built-in placeholders are filled from the host's local time
        (``%YEAR%``, ``%YESTERDAY_DAY%``, ...) and from UTC (``%UTC_*%``).
        Custom placeholders, then ``%<OUTPUT_TABLE_NAME>%`` entries derived from
        ``output_tables``, are applied on top and override built-ins with the same key.
        """
        query = self.Parameters.query
        logging.debug('query template is:\n%s', query)

        now = datetime.datetime.now()
        utc_now = datetime.datetime.utcnow()
        yesterday = now - datetime.timedelta(days=1)
        two_days_ago = now - datetime.timedelta(days=2)
        week_ago = now - datetime.timedelta(days=7)
        two_weeks_ago = now - datetime.timedelta(days=14)
        yql_query_placeholders = {
            '%ENVIRONMENT_TYPE%': self.Parameters.environment_type,
            '%YEAR%': now.year,
            '%MONTH%': '%02d' % now.month,
            '%DAY%': '%02d' % now.day,
            '%HOUR%': '%02d' % now.hour,
            '%MINUTE%': '%02d' % now.minute,
            '%SECOND%': '%02d' % now.second,
            '%UTC_YEAR%': utc_now.year,
            '%UTC_MONTH%': '%02d' % utc_now.month,
            '%UTC_DAY%': '%02d' % utc_now.day,
            '%UTC_HOUR%': '%02d' % utc_now.hour,
            '%UTC_MINUTE%': '%02d' % utc_now.minute,
            '%UTC_SECOND%': '%02d' % utc_now.second,
            '%YESTERDAY_YEAR%': yesterday.year,
            '%YESTERDAY_MONTH%': '%02d' % yesterday.month,
            '%YESTERDAY_DAY%': '%02d' % yesterday.day,
            '%TWO_DAYS_AGO_YEAR%': two_days_ago.year,
            '%TWO_DAYS_AGO_MONTH%': '%02d' % two_days_ago.month,
            '%TWO_DAYS_AGO_DAY%': '%02d' % two_days_ago.day,
            '%WEEK_AGO_YEAR%': week_ago.year,
            '%WEEK_AGO_MONTH%': '%02d' % week_ago.month,
            '%WEEK_AGO_DAY%': '%02d' % week_ago.day,
            '%TWO_WEEKS_AGO_YEAR%': two_weeks_ago.year,
            '%TWO_WEEKS_AGO_MONTH%': '%02d' % two_weeks_ago.month,
            '%TWO_WEEKS_AGO_DAY%': '%02d' % two_weeks_ago.day,
            '%TIMESTAMP%': int(time.mktime(now.timetuple())),
            '%USER%': self.owner,
            '%OWNER%': self.owner
        }

        custom_placeholders = self.Parameters.custom_placeholders
        if custom_placeholders:
            yql_query_placeholders.update(custom_placeholders)

        output_tables = self.Parameters.output_tables
        if output_tables:
            yql_query_placeholders.update({'%{}%'.format(k.upper()): v for k, v in output_tables.items()})

        for key, value in yql_query_placeholders.items():
            query = query.replace(key, str(value))

        logging.debug('result query is:\n%s', query)
        return query

    def _get_share_id(self, operation_id):
        """Return the public share id of an operation, or ``None`` if it cannot be fetched."""
        share_id = self._access_api(
            method='GET',
            url='{yql_api_base_url}/operations/{operation_id}/share_id'.format(
                yql_api_base_url=self.YQL_API_BASE_URL,
                operation_id=operation_id
            ),
            raise_on_error=False,
        )
        return share_id

    def _publish_query(self, operation_id, detail):
        """Add a public-link line to the task info, or a note if the share id is unavailable."""
        share_id = self._get_share_id(operation_id)
        if share_id is None:
            self.set_info(
                'Error getting operation public link'
            )
        else:
            self._add_link('Public link{}'.format(detail), share_id)

    def _add_link(self, title, operation_id):
        """Add an HTML link to the YQL web UI page of ``operation_id`` to the task info."""
        operation_url = '{yql_webui_base_url}/Operations/{operation_id}'.format(
            yql_webui_base_url=self.YQL_WEBUI_BASE_URL,
            operation_id=operation_id
        )
        self.set_info(
            '{title}: <a href="{operation_url}">{operation_url}</a>'.format(
                title=title,
                operation_url=operation_url
            ),
            do_escape=False
        )

    def _save_operation_status(self, operation_id, status):
        """Persist the final Solomon status of an operation in Context.

        :type operation_id: int
        :type status: int
        :param status: -1, 0 or 1
        """
        statuses = self.Context.yql_operation_statuses or {}
        statuses[str(operation_id)] = status
        self.Context.yql_operation_statuses = statuses
        self.Context.save()

    def _get_operation_status(self, operation_id):
        """Return the saved status of an operation, ``None`` if none was saved yet.

        :type operation_id: int
        :rtype: int
        :return: status -1, 0 or 1
        """
        return (self.Context.yql_operation_statuses or {}).get(str(operation_id), None)

    def _report_status(self, status, yt_cluster):
        """Report a status for ``yt_cluster`` through ``report_status``.

        :param status: 0 - start, -1 - fail, 1 - successfully finished
        :type status: int
        :type yt_cluster: str (hahn, arnold)
        """
        report_status(self, self.Context.exec_timestamp, yt_cluster, status)

    def on_prepare(self):
        """Record the execution timestamp once; later preparations keep the first value."""
        logging.info('On prepare')
        if not self.Context.exec_timestamp:
            self.Context.exec_timestamp = int(datetime.datetime.now().timestamp())
        super().on_prepare()

    def on_break(self, prev_status, status):
        """Abort running YQL operations when the task is interrupted."""
        logging.info('On break, prev_status=%s, status=%s', prev_status, status)
        self._kill_yql_tasks()
        super().on_break(prev_status, status)

    def on_failure(self, prev_status):
        """Abort running YQL operations when the task fails."""
        logging.info('in method on_failure prev_status=%s', prev_status)
        self._kill_yql_tasks()
        super().on_failure(prev_status)

    @property
    def current_tags(self):
        """Task tags as a list of plain strings."""
        return [str(tag) for tag in self.Parameters.tags]

    def append_tag(self, tag):
        """Add ``tag`` in upper case unless it is already present."""
        tag_formatted = tag.upper()
        if tag_formatted not in self.current_tags:
            self.Parameters.tags.append(tag_formatted)

    def on_save(self):
        """Run the binary-release ``on_save`` hook and tag the task with ``task_name``."""
        binary_task.LastBinaryTaskRelease.on_save(self)
        if self.Parameters.task_name:
            self.append_tag(str(self.Parameters.task_name))

    def _kill_yql_tasks(self):
        """Report unsuccessful operations as failed and abort those still running.

        Called from ``on_break`` / ``on_failure``, so it relies only on what was
        persisted in Context. Errors from the status check and from the abort
        requests are logged instead of raised.

        :return: ids returned by successful abort requests
        """
        logging.info('Trying to kill running yql tasks..')
        operation_statuses = []

        # Nothing was submitted yet (e.g. the task failed before sending
        # requests), so there is nothing to report or abort.
        if not self.Context.result_operation_id:
            logging.info('No YQL operations were started, nothing to kill')
            return operation_statuses

        # ensure check operation statuses; a failing API must not stop the abort below
        try:
            self._check_for_yql_status(list(self._get_clusters_to_operation_id().items()))
        except Exception as e:
            logging.exception('Fail while checking YQL statuses: %s', e)

        if self._get_operation_status(self.Context.result_operation_id) != 1:
            self._report_status(-1, self.Context.master_cluster)
        if (
            self.Context.slave_cluster
            and self._get_operation_status(self.Context.result_additional_operation_id) != 1
        ):
            self._report_status(-1, self.Context.slave_cluster)

        # now we will get operation_ids from Context because we use method in on_break
        for operation_id in self._get_operation_ids_from_context():
            if self._get_operation_status(operation_id) is not None:
                self.set_info('Already stopped: {id}'.format(id=operation_id))
                continue
            try:
                response = self._access_api(
                    method='POST',
                    url='{yql_api_base_url}/operations/{operation_id}'.format(
                        yql_api_base_url=self.YQL_API_BASE_URL,
                        operation_id=operation_id,
                    ),
                    json={
                        'action': 'ABORT',
                    }
                )

                if 'status' not in response:
                    logging.info('Operation %s not found', operation_id)

                operation_statuses.append(response['id'])
                self.set_info('Successfully stopped: {id}'.format(id=response['id']))
            except Exception as e:
                logging.exception('Fail while killing YQL: %s', e)
                self.set_info('Fail while killing YQL')

        return operation_statuses

    def _check_for_yql_status(self, cluster_operation_ids):
        """Poll YQL for operations whose final status is not saved yet.

        Operations that reached a final YQL status get their Solomon status
        (1 or -1) saved in Context and reported.

        :param cluster_operation_ids: iterable of ``(yt_cluster, operation_id)`` pairs
        :return: ``{yt_cluster: yql_status}`` for the operations polled in this call
        """
        result_statuses = {}
        for yt_cluster, operation_id in cluster_operation_ids:
            if self._get_operation_status(operation_id) is not None:
                continue
            url = '{yql_api_base_url}/operations/{operation_id}/meta'.format(
                yql_api_base_url=self.YQL_API_BASE_URL,
                operation_id=operation_id,
            )
            response = self._access_api(method='GET', url=url)
            status = response['status']
            result_statuses[yt_cluster] = status
            if status in self.FINAL_YQL_STATUSES:
                logging.info('%s YQL is in %s', yt_cluster.capitalize(), status)
                # Named to avoid shadowing the module-level report_status helper.
                solomon_status = self._get_solomon_status_by_yql_final_status(status)

                self._save_operation_status(operation_id, solomon_status)
                self._report_status(solomon_status, yt_cluster)
                self.set_info('Solomon report {} status is {}'.format(yt_cluster.upper(), solomon_status))
        return result_statuses

    def _wait_for_yql_tasks(self):
        """Blocking wait for YQL queries to finish.

        Polls every ``SLEEP_PERIOD_SECONDS`` until every operation has a final
        status.

        :raises errors.TaskFailure: if any operation finished with a failure status
        """
        result_statuses = {}

        while True:
            logging.info('In loop of waiting YQL tasks')
            time.sleep(SLEEP_PERIOD_SECONDS)

            result_statuses.update(
                self._check_for_yql_status(list(self._get_clusters_to_operation_id().items()))
            )

            if any(status not in self.FINAL_YQL_STATUSES
                   for status in result_statuses.values()):
                continue
            elif any(status in self.FAIL_YQL_STATUSES
                     for status in result_statuses.values()):
                raise errors.TaskFailure('YQL query evaluation failed')
            return

    def _get_solomon_status_by_yql_final_status(self, yql_status):
        """Map a final YQL status to a Solomon status: 1 for success, -1 for failure.

        :raises errors.TaskFailure: for a status that is neither success nor failure
        """
        if yql_status in self.SUCCESS_YQL_STATUSES:
            return 1
        elif yql_status in self.FAIL_YQL_STATUSES:
            return -1

        raise errors.TaskFailure('UNKNOWN YQL status')

    def _waiting_tm_tasks(self, cluster_replication, seconds_to_sleep=7):
        """Block while ``cluster_replication`` reports running copy tasks.

        :return: total seconds slept (a multiple of ``seconds_to_sleep``)
        """
        self.set_info('Now I am waiting for TM tasks')
        summary_sleeping_time = 0

        while cluster_replication.is_copy_tasks_still_running():
            summary_sleeping_time += seconds_to_sleep
            time.sleep(seconds_to_sleep)

        return summary_sleeping_time

    def _raise_use_cluster_in_query(self, query):
        """Raise ``ValueError`` if the query already has a ``USE <name>;`` statement.

        The match is case-insensitive; a warning is also added to the task info.
        """
        m = re.search(r'(^|[\s])use\s+(\w+)\s*;', query, re.IGNORECASE)
        if m:
            msg = 'Query already has USE statement: "{}"'.format(m.group().strip())
            self.set_info('WARN: {msg}'.format(msg=msg))
            raise ValueError(msg)