# -*- coding: utf-8 -*-
import datetime as dt
import logging
import re

import requests
import six

import dateutil.parser
from dateutil.tz import tzutc

import sandbox.projects.common.search.bugbanner2 as bb2
import sandbox.projects.common.testenv_client as testenv_helper
import sandbox.projects.common.time_utils as tu
import sandbox.projects.common.decorators as sb_decorators
import sandbox.projects.release_machine.components.all as rm_components
import sandbox.projects.release_machine.core.const as rm_const
import sandbox.projects.release_machine.core.task_env as task_env
import sandbox.projects.release_machine.helpers.events_helper as events_helper
import sandbox.projects.release_machine.rm_notify as rm_notify
import sandbox.projects.release_machine.security as rm_sec
import sandbox.sdk2 as sdk2
from sandbox.common import rest
from sandbox.common import utils as common_utils
from sandbox.projects.common import binary_task
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import link_builder as lb
from sandbox.projects.release_machine import input_params2 as rm_params
from sandbox.projects.release_machine.client import RMClient
from sandbox.projects.release_machine.tasks import base_task as rm_bt


# Fields requested from the Sandbox REST API for every task the crawler inspects
# (used both for the per-database "updated" query and for the by-id chunked query).
TASK_FIELDS = list((
    "context.__GSID",
    "input_parameters.component_name",
    "input_parameters.tag_or_branch_mode",
    "context.rm_event_type",
    events_helper.RM_PROTO_EVENT_CTX_KEY,
    events_helper.RM_PROTO_EVENT_READY_CTX_KEY,
    "context.scope_number",
    "time.created",
    "time.updated",
    "status",
    "type",
))

# Type of the Sandbox task whose info (branching mode and scope number) is used
# to build the state key under which an aux database name is posted.
BRANCH_OR_TAG_TASK_TYPE = "CREATE_BRANCH_OR_TAG"

SCOPE_NUMBER_KEY = 'scope_number'

# RM state key holding, per component, the time of the last successful crawl.
LAST_CRAWLED_TIME = "last_crawled_time"

# Pairs of component names that share one TestEnv database. Because of the shared
# database, a task's component_name may be the "other" member of the pair; such a
# mismatch is skipped silently instead of being reported as an error.
COMPONENT_NAME_ERROR_WHITELIST = set(
    frozenset(name_pair) for name_pair in (
        ('jupiter', 'jupiter_viewers'),  # jupiter-trunk
        ('userfeat', 'userfeat_viewers'),  # userdata_full_trunk_yt
        ('exp_formulas_base', 'exp_formulas_mmeta'),  # formulas
        ('zephyr', 'zcli'),
    )
)


@rm_notify.notify2()
class ReleaseMachineCrawler(binary_task.LastBinaryTaskRelease, bb2.BugBannerTask, rm_bt.ComponentErrorReporter):
    """
    Task for crawling Release Machine events from TestEnv databases.

    For every component matching ``component_name_filter`` (CI-based components are skipped) it:

    * collects Sandbox tasks of the component's trunk TestEnv database and of its auxiliary runs;
    * collects Sandbox tasks of the last three branch databases (branched components only);
    * turns those tasks into protobuf events and posts them;
    * on success, stores the crawl time in the component's ``last_crawled_time`` state.

    In debug mode nothing is sent through ``events_helper.post_proto_events``.
    """

    _task_link_html = None

    class Requirements(task_env.TinyRequirements):
        disk_space = 1024  # 1 Gb

    class Parameters(rm_params.BaseReleaseMachineParameters):
        _lbrp = binary_task.binary_release_parameters(stable=True)
        kill_timeout = 3 * 60 * 60  # 3 hours
        debug = sdk2.parameters.Bool("Debug", default=False)
        component_name_filter = sdk2.parameters.String("Component name filter")

    class Context(sdk2.Task.Context):
        # (component name, task id) pairs listed in the task info at the end of on_execute
        independent_task_ids = []

    @property
    def task_link_html(self):
        """HTML link to this task, built lazily and cached on the instance."""
        if not self._task_link_html:
            self._task_link_html = lb.task_link(self.id)
        return self._task_link_html

    @property
    def is_binary_run(self):
        """True unless the binary executor release type is ``'none'``."""
        return self.Parameters.binary_executor_release_type != 'none'

    @sb_decorators.memoized_property
    def rm_model_client(self):
        """Memoized client of the Release Machine Model service."""
        from release_machine.release_machine.services.release_engine.services.Model import ModelClient
        return ModelClient.from_address(rm_const.Urls.RM_HOST)

    def on_execute(self):
        """
        Crawl every component whose name matches ``component_name_filter``.

        Components with a CI release cycle are skipped. Components whose processing fails
        (either an exception or a not-OK result) are reported to the RM maintainers.
        """
        binary_task.LastBinaryTaskRelease.on_execute(self)
        self.add_bugbanner(bb2.Banners.ReleaseMachine)

        oauth_token = rm_sec.get_rm_token(self)
        te_helper = testenv_helper.TEClient(oauth_token=oauth_token)
        rm_client = RMClient()
        # Extracts the database name from an aux run's "auxiliary_check" string
        self.db_re = re.compile(r'database=([\d\w]+)')
        self.rest_client = rest.Client()
        not_ok = []

        # An empty filter matches every component; ``match`` anchors at the start of the name
        name_filter_re = re.compile(self.Parameters.component_name_filter or ".*", re.IGNORECASE)
        all_component_names = rm_components.get_component_names()

        logging.debug("All components: %s", all_component_names)

        for c_name in filter(name_filter_re.match, all_component_names):
            try:
                logging.info("Start processing %s", c_name)

                c_info = rm_components.get_component(c_name)
                if c_info.release_cycle_type == rm_const.ReleaseCycleType.CI:
                    logging.info("Skip %s. It is CI-based!", c_name)
                    continue

                # Fix the crawl window [last crawled time, now] only for components we actually crawl
                self.update_current_time(rm_client, c_name)

                with common_utils.Timer("Crawl Timer") as timer:
                    result_ok = self.process_testenv(c_info, te_helper, timer)

                if result_ok:
                    # Advance the stored crawl time only when everything was processed
                    self.build_and_send_last_crawled_time(c_name)
                else:
                    not_ok.append(c_name)

                if float(timer) > 15.0:
                    self.set_info('[slower.15s] Component `{}` handle time: {}'.format(c_info.name, timer))
            except Exception as e:
                eh.log_exception("Component {} processing FAILED".format(c_name), e)
                not_ok.append(c_name)

        if not_ok:
            rm_notify.send_tm_message(
                self,
                "WARNING! NOT OK crawling result for components: [{component}]\nSB: {task_link}".format(
                    component=", ".join(not_ok),
                    task_link=self.task_link_html,
                ),
                [rm_const.RM_USERS["rm_maintainers"].telegram],
            )

        if self.Context.independent_task_ids:
            self.set_info("PrepareRmComponentEnvironment tasks:\n" + "".join(
                "Component name: {c_name}, task: {task}\n".format(
                    c_name=c_name,
                    task=lb.task_link(task_id, plain=True),
                )
                for c_name, task_id in self.Context.independent_task_ids
            ))

    def update_current_time(self, rm_client, c_name):
        """
        Fix the crawl window for component ``c_name``.

        Sets ``current_time_utc`` (naive UTC), ``current_time_utc_tz`` (tz-aware UTC),
        ``current_time_utc_short`` (naive UTC without microseconds) and ``last_crawled_time``
        (read from the component's RM state).
        """
        self.current_time_utc = dt.datetime.utcnow()
        self.current_time_utc_tz = self.current_time_utc.replace(tzinfo=tzutc())
        self.current_time_utc_short = self.current_time_utc.replace(microsecond=0)
        # Must run after current_time_utc_tz is set: it is used for the fallback value
        self.last_crawled_time = self._get_last_crawled_time(rm_client, c_name)

    def build_and_send_last_crawled_time(self, component_name):
        """
        Post a state-update event storing the current crawl time under ``LAST_CRAWLED_TIME``.

        Does nothing in debug mode.
        """
        if self.Parameters.debug:
            return

        from release_machine.common_proto import events_pb2 as rm_proto_events

        value = self.current_time_utc_tz.isoformat()

        events_helper.post_proto_events([rm_proto_events.EventData(
            general_data=rm_proto_events.EventGeneralData(
                hash=six.text_type(hash((
                    component_name,
                    LAST_CRAWLED_TIME,
                    value,
                    self.status,
                ))),
                component_name=component_name,
                referrer=u"sandbox_task:{}".format(self.id)
            ),
            task_data=rm_proto_events.EventSandboxTaskData(
                task_id=self.id,
                status=self.status,
                created_at=self.created.isoformat(),
                updated_at=self.updated.isoformat(),
            ),
            update_state_data=rm_proto_events.UpdateStateData(
                key=LAST_CRAWLED_TIME,
                value=value,
            ),
        )])

    def _get_last_crawled_time(self, rm_client, c_name):
        """
        Return the last crawled time of ``c_name`` as stored in RM state.

        This task writes the value with ``isoformat()``. If the state is empty, the value
        defaults to two hours before ``self.current_time_utc_tz`` (also ISO formatted).
        """
        last_crawled_time = rm_client.get_state(c_name, LAST_CRAWLED_TIME)
        if not last_crawled_time:
            logging.warning("No last crawled time detected from state. Use 2 hours gap")
            last_crawled_time = (self.current_time_utc_tz - dt.timedelta(hours=2)).isoformat()
        else:
            last_crawled_time = last_crawled_time[0]["value"]
        logging.info("Last crawled time for %s: %s", c_name, last_crawled_time)
        return last_crawled_time

    def process_testenv(self, c_info, te_helper, timer):
        """
        Crawl trunk (and, for branched components, branch) TestEnv databases of a component.

        :return: True only if both the trunk and the branch parts succeeded
        """
        with timer["testenv"] as testenv_timer:

            logging.info("Processing testenv for component: %s", c_info.name)
            trunk_result_ok = self.handle_trunk_dbs(te_helper, c_info, testenv_timer)

            if c_info.is_branched:
                branch_result_ok = self.handle_branch_dbs(c_info, testenv_timer)
            else:
                branch_result_ok = True

            result_ok = trunk_result_ok and branch_result_ok

        logging.info("Result ok: %s", result_ok)
        logging.info("Testenv time: %s", testenv_timer)

        return result_ok

    def handle_branch_dbs(self, c_info, testenv_timer):
        """
        Crawl the databases of the last three branches, newest first.

        Non-positive branch numbers are skipped; after the first failing database the
        remaining ones are not processed.

        :return: False if any processed branch database failed
        """
        result_ok = True
        with testenv_timer["branch_dbs"] as testenv_branch_timer:
            last_branch_num = c_info.last_branch_num
            for i in range(3):  # last 3 branch numbers
                curr_branch_num = last_branch_num - i
                if result_ok and curr_branch_num > 0:
                    curr_branch_te_db = c_info.testenv_cfg__db_template.format(testenv_db_num=curr_branch_num)
                    result_ok = self.handle_branch_db(curr_branch_te_db, c_info, testenv_branch_timer)
        logging.info("Testenv branch db time: %s", testenv_branch_timer)
        return result_ok

    def handle_trunk_dbs(self, te_helper, c_info, testenv_timer):
        """
        Crawl the trunk database and, if that succeeded, its auxiliary runs.

        :return: result of the trunk database only; aux database failures are logged, not returned
        """
        with testenv_timer["trunk_dbs"] as testenv_trunk_timer:
            result_ok = self.handle_trunk_db(te_helper, c_info, testenv_trunk_timer)
            if result_ok:
                self.handle_aux_trunk_dbs(te_helper, c_info, testenv_trunk_timer)
        logging.info("Testenv trunk time: %s", testenv_trunk_timer)
        return result_ok

    def handle_aux_trunk_dbs(self, te_helper, c_info, timer):
        """
        Crawl the RM auxiliary runs of the component's trunk database.

        Only aux runs accepted by ``filter_rm_aux_dbs`` are processed. Runs whose
        "auxiliary_check" has no ``database=...`` part are skipped; errors are logged per database.
        """
        with timer["handling_aux_trunk_dbs"]:
            try:
                aux_runs_dbs = te_helper.get_aux_runs(c_info.testenv_cfg__trunk_db)
            except Exception as exc:
                # TODO: process bad databases and inform users
                eh.log_exception(
                    "Failed to process '{}' database, cannot get aux runs".format(c_info.testenv_cfg__trunk_db),
                    exc, task=self
                )
                return
            logging.debug("Collect aux databases: %s", aux_runs_dbs)
            aux_runs_dbs = self.filter_rm_aux_dbs(aux_runs_dbs)
            logging.debug("Filter aux databases: %s", aux_runs_dbs)
            skipped_dbs = []
            for db in aux_runs_dbs:
                db_name = self.db_re.findall(db["auxiliary_check"])
                if not db_name:
                    skipped_dbs.append(db["auxiliary_check"])
                    continue
                db_name = db_name[0]
                try:
                    self.handle_aux_db(te_helper, db_name, c_info)
                except Exception as e:
                    eh.log_exception("Failed to process aux database: {}".format(db_name), e, task=self)
            logging.info("Skip not existed or not loaded databases: %s", skipped_dbs)

    def handle_aux_db(self, te_helper, db, c_info):
        """Process the tasks referenced by the last results of aux database ``db``."""
        logging.info("Handle aux db %s", db)
        last_results = te_helper.get_last_results(db)
        last_tasks_ids = {row_data[1]["task_id"] for row_data in last_results if "task_id" in row_data[1]}
        self.process_trunk_tasks(last_tasks_ids, c_info, db, aux=True)

    @staticmethod
    def build_aux_run_key(task_info=None):
        """
        Build the RM state key under which an aux database name is stored.

        :param task_info: info of a CREATE_BRANCH_OR_TAG task, or None
        :return: ``"/br{N}/$aux_run"`` for trunk-to-branch, ``"/tr{N}/$aux_run"`` for trunk-to-tag
            (N is the task's scope number), otherwise ``"/$aux_run"``
        """
        logging.debug("Building a key to store aux db name in state")
        template = "{prefix}/$aux_run"
        prefix = ''
        if not task_info:
            logging.debug("No task info provided. Using a blank prefix")
        elif task_info['input_parameters.tag_or_branch_mode'] == rm_const.BranchingMode.TRUNK_TO_BRANCH:
            prefix = '/br{branch_num}'.format(branch_num=task_info['context.scope_number'])
        elif task_info['input_parameters.tag_or_branch_mode'] == rm_const.BranchingMode.TRUNK_TO_TAG:
            prefix = '/tr{tag_num}'.format(tag_num=task_info['context.scope_number'])
        key = template.format(prefix=prefix)
        logging.debug("Result key: %s", key)
        return key

    def add_aux_to_state(self, c_info, task_info_list, db_name):
        """
        Post a state-update event storing ``db_name`` under the aux-run key.

        The key is derived from the first CREATE_BRANCH_OR_TAG task in ``task_info_list`` (if any).
        Does nothing in debug mode.
        """
        if self.Parameters.debug:
            # Keep debug runs read-only, consistently with the other post_proto_events call sites
            return

        logging.debug("Adding aux db to state")
        branch_or_tag_task_info = next(
            (task_info for task_info in task_info_list if task_info.get('type', '') == BRANCH_OR_TAG_TASK_TYPE),
            None
        )
        logging.debug("Got %s task info: %s", BRANCH_OR_TAG_TASK_TYPE, branch_or_tag_task_info)

        from release_machine.common_proto import events_pb2 as rm_proto_events

        key = self.build_aux_run_key(task_info=branch_or_tag_task_info)
        component_name = c_info.name

        events_helper.post_proto_events([rm_proto_events.EventData(
            general_data=rm_proto_events.EventGeneralData(
                hash=six.text_type(hash((
                    component_name,
                    key,
                    db_name,
                    "SUCCESS",
                ))),
                component_name=component_name,
                referrer=u"old_event_crawler:{}".format(self.id)
            ),
            task_data=rm_proto_events.EventSandboxTaskData(
                task_id=self.id,
                status="SUCCESS",
                created_at=self.created.isoformat(),
                updated_at=self.updated.isoformat(),
            ),
            update_state_data=rm_proto_events.UpdateStateData(
                key=key,
                value=db_name,
            ),
        )])

    def handle_trunk_db(self, te_helper, c_info, timer):
        """
        Crawl the component's trunk TestEnv database.

        If the database reports ``is_started == False``, the component status is set to DISABLED
        (best effort: failures are only logged).

        :return: False if fetching task ids timed out, otherwise the result of ``process_trunk_tasks``
        """
        from release_machine.release_machine.proto.structures import message_pb2

        with timer["handling_trunk_db"]:
            db = c_info.testenv_cfg__trunk_db

            logging.info("Handle %s", db)

            try:
                sys_info_response = te_helper.get_sys_info(db)

                if not sys_info_response.json().get("is_started", True):
                    self.rm_model_client.set_component_status(
                        message_pb2.SetComponentStatusRequest(
                            component_name=c_info.name,
                            status="DISABLED",
                        ),
                    )
            except Exception:
                # Best effort: a failed status check must not stop crawling the database
                logging.exception("Failed to process %s db status", db)

            try:
                # NOTE(review): the window starts at the *current* time (a naive UTC datetime)
                # rather than at ``self.last_crawled_time`` as in the branch crawl; confirm this
                # is intended and how ``tu.datetime_to_timestamp`` treats naive datetimes.
                launched_tasks = te_helper.get_all_sandbox_task_ids(
                    db,
                    timestamp_since=int(tu.datetime_to_timestamp(self.current_time_utc_short)),
                )
            except requests.ReadTimeout as rt:
                eh.log_exception("Unable to process database {}".format(db), rt)
                return False

            return self.process_trunk_tasks({item["task_id"] for item in launched_tasks}, c_info, db)

    def handle_branch_db(self, db, c_info, timer):
        """
        Crawl one branch database: send events for its Sandbox tasks updated in the crawl window.

        :return: False if the Sandbox query timed out, otherwise the result of ``send_events``
        """
        with timer["handling_branch_db"]:
            logging.info("Handle %s", db)

            try:
                tasks_info = self.get_last_updated_sb_task(db)
            except requests.ReadTimeout as rt:
                eh.log_exception("Unable to process database {}".format(db), rt)
                return False

            rm_event_data_generator = events_helper.updated_rm_event_data_by_task_info_generator(c_info, tasks_info)
            result_ok = self.send_events(c_info, rm_event_data_generator)
        return result_ok

    def get_last_updated_sb_task(self, db, limit=2000):
        """
        Return Sandbox tasks tagged with TestEnv database ``db`` and updated in the crawl window.

        The window is ``[last_crawled_time, current_time_utc_short]`` in UTC. Only one page of
        at most ``limit`` tasks is requested (no pagination).
        """
        from_time = dateutil.parser.parse(self.last_crawled_time)
        if from_time.tzinfo is not None:
            # The query below uses "Z" (UTC) suffixes, so convert before dropping the offset
            from_time = from_time.astimezone(tzutc())
        from_time = from_time.replace(microsecond=0, tzinfo=None)
        tasks = self.rest_client.task[{
            "limit": limit,
            "tags": "TESTENV-DATABASE-{}".format(db.upper()),
            "updated": "{}Z..{}Z".format(from_time.isoformat(), self.current_time_utc_short.isoformat()),
            "fields": TASK_FIELDS,
            "hidden": True,
        }]["items"]
        logging.info("Got %s tasks from sandbox for db: %s", len(tasks), db)
        logging.debug("Got task infos: %s", tasks)
        return tasks

    def get_task_infos(self, task_ids, c_info):
        """
        Fetch Sandbox info for ``task_ids`` (None ids are dropped).

        :return: list of task infos belonging to ``c_info``, or None if there are no ids to fetch
        """
        # Sometimes there can be None element in task_ids list
        task_ids = [t_id for t_id in task_ids if t_id is not None]
        if not task_ids:
            logging.info("No tasks to process")
            return
        logging.info("Process %s tasks", len(task_ids))
        task_info_list = self.get_all_useful_info(task_ids, c_info)
        logging.debug("All info for task_ids: %s\n%s", task_ids, task_info_list)
        return task_info_list

    def process_trunk_tasks(self, task_ids, c_info, db, aux=False):
        """
        Send events for the given trunk (or aux, if ``aux``) task ids.

        For aux databases the database name is additionally stored in state.

        :return: True if there was nothing to send, otherwise the result of ``send_events``
        """
        task_info_list = self.get_task_infos(task_ids, c_info)
        if not task_info_list:
            return True
        if aux:
            rm_event_data_generator = events_helper.updated_rm_event_data_by_task_info_generator(c_info, task_info_list)
        else:
            rm_event_data_generator = events_helper.rm_event_data_by_task_info_generator(c_info, task_info_list)
        result_ok = self.send_events(c_info, rm_event_data_generator)
        if aux:
            self.add_aux_to_state(c_info, task_info_list, db)
        return result_ok

    def _get_proto_event(self, task_info):
        """
        Builds a new-style protobuf event

        Note: RMINCIDENTS-555 - one cannot simply turn old crawler off

        Returns None when this is not a binary run, when the task's event is explicitly marked
        as not ready, when the task has neither an encoded nor an inferred event, or when
        building the event fails (the error is logged).
        """
        if not self.is_binary_run:
            return

        rm_proto_event_str = task_info.get(events_helper.RM_PROTO_EVENT_CTX_KEY)
        event_ready = task_info.get(events_helper.RM_PROTO_EVENT_READY_CTX_KEY)
        inferred_rm_proto = task_info.get(events_helper.INFERRED_RM_PROTO_EVENT_KEY)

        # Only an explicit False means "not ready"; a missing flag is accepted
        if event_ready is False:
            return

        try:
            if rm_proto_event_str:
                event_message = events_helper.parse_rm_proto_event_from_b64_encoded_str(rm_proto_event_str)
            elif inferred_rm_proto:
                event_message = inferred_rm_proto
            else:
                return
            event_message.task_data.task_id = task_info['id']
            event_message.task_data.created_at = task_info.get('started_at', task_info['time.created'])
            event_message.task_data.updated_at = task_info.get('finished_at', task_info['time.updated'])
            event_message.task_data.status = events_helper.convert_sandbox_task_status_into_rm_event_enum_status(
                task_info['status'],
            )
            event_message.general_data.referrer = "old_event_crawler:{}".format(self.id)
            return event_message
        except Exception as e:
            eh.log_exception("Failed to create proto event from {}".format(task_info), e, task=self)

    def send_events(self, c_info, rm_event_data_generator):
        """
        Build protobuf events for the tasks yielded by ``rm_event_data_generator`` and post them.

        :param c_info: component info (not used here)
        :param rm_event_data_generator: iterable of ``(task_info, event_type)`` pairs
        :return: result of ``events_helper.post_proto_events``; True in debug mode (nothing is posted)
        """
        proto_events = []
        for task_info, _event_type in rm_event_data_generator:
            event = self._get_proto_event(task_info)
            if event:
                proto_events.append(event)
        if self.Parameters.debug:
            return True
        return events_helper.post_proto_events(proto_events, prod=True, dev=True)

    @staticmethod
    def filter_rm_aux_dbs(dbs):
        """Keep aux runs whose test names start with '_' and that are not in compare mode."""
        return [db for db in dbs if db["test_names"].startswith('_') and not db["compare_mode"]]

    def get_all_useful_info(self, task_ids, c_info):
        """
        Fetch Sandbox info for ``task_ids`` in chunks of 100 and keep tasks of component ``c_info``.

        A task whose ``component_name`` input differs from ``c_info.name`` is dropped. Unless the
        pair of names is in ``COMPONENT_NAME_ERROR_WHITELIST``, the mismatch is reported as a
        component error and to the RM maintainers (once per chunk).
        """
        if not task_ids:
            return []
        # Process task_ids by chunks
        chunk_size = 100
        all_info = []
        for chunk in common_utils.chunker(task_ids, chunk_size):
            task_info_list = self.rest_client.task[{
                "limit": len(chunk),
                "id": chunk,
                "fields": TASK_FIELDS,
                "hidden": True,
            }]["items"]

            component_name_match_errors = []

            for task_info in task_info_list:

                component_name = task_info.get('input_parameters.component_name') or c_info.name
                task_id = task_info["id"]

                if component_name == c_info.name:  # RMDEV-1622
                    all_info.append(task_info)
                    continue  # everything's OK, continue with next task info

                if frozenset((component_name, c_info.name)) in COMPONENT_NAME_ERROR_WHITELIST:
                    logging.info(
                        '#%s: %s != %s -- skipping task info (no report due to whitelist setting)',
                        task_id,
                        component_name,
                        c_info.name,
                    )
                    continue

                component_name_match_errors.append(
                    "- #{task_id}: {input_parameter_component_name} != {c_info_name}".format(
                        task_id=lb.task_link(task_id),
                        input_parameter_component_name=component_name,
                        c_info_name=c_info.name,
                    )
                )

                logging.debug("Skipping task info %s", task_info)

            if component_name_match_errors:

                msg = (
                    "Input parameter 'component_name' of the following tasks does not match current component's name:\n"
                    "{}".format(
                        "\n".join(component_name_match_errors)
                    )
                )
                self._report_error(c_info.name, msg)

                rm_notify.send_tm_message(
                    self,
                    "ERROR! {task_type}: {msg}\nSB: {task_link}".format(
                        task_type=str(self.type),
                        msg=msg,
                        task_link=self.task_link_html,
                    ),
                    [rm_const.RM_USERS["rm_maintainers"].telegram],
                )

        return all_info
