import logging
import json
import time
import requests
from collections import namedtuple, defaultdict
from datetime import datetime, timedelta

from enum import Enum
import pytz

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.common.rest import Client
from sandbox.common.utils import get_task_link
from sandbox.common.types.task import ReleaseStatus, Status as TaskStatus
from sandbox.common.proxy import NoAuth
from sandbox.projects.common.nanny.client import NannyClient
from sandbox.projects.yabs.qa.solomon.mixin import SolomonTaskMixin, SolomonTaskMixinParameters
from sandbox.projects.release_machine.helpers.startrek_helper import STHelper
from sandbox.projects.release_machine.core import const as rm_const
from sandbox.projects.release_machine.components.all import get_component
from sandbox.projects.yabs.audit.runtime.monitoring.constants import (
    COMPONENT_NAME, RESOURCE_TYPE, AUDIT_REQUIREMENTS, AUDITABLE_BINARIES, BUILD_TASK_TYPES
)
from sandbox.projects.yabs.audit.runtime.monitoring.requirements import check_requirements
from sandbox.projects.yabs.audit.runtime.monitoring.startrek import get_release_issue
from sandbox.projects.yabs.release.duty.schedule import get_current_engine_responsibles
from sandbox.projects.yabs.YabsServerNewRuntime.utils.check_requirements import check_new_runtime_requirements

from .report import create_st_table_report, create_sb_table_report, create_approved_files_report, create_fails_report
from .issue import create_issue_report, create_or_update_issue
from .resolving import (
    get_services,
    get_enabled_hosts_by_service,
    BALANCERS_BY_PRJ,
    EXCLUDE_SERVICES,
)


logger = logging.getLogger(__name__)


# Verdict for one auditable file on one host of one service.
RawMonitoringResult = namedtuple(
    "RawMonitoringResult",
    ["service", "host", "engine_port", "file", "status"],
)
# A host to probe; engine_port is None when the host source does not provide one.
HostPort = namedtuple("HostPort", ["host", "engine_port"])
# Port handed to iter_binary_hashes when hosts are queried.
PORT = 8090


class MonitoringStatus(Enum):
    """Per-file outcome of checking a host's binaries against approved releases."""
    OK = "ok"  # md5 matches an approved (or group-specific approved) build
    FAIL = "fail"  # md5 is not approved, or md5 was unavailable and retries are exhausted
    UNKNOWN = "unknown"  # no hosts for the service, no result, or the engine is not running
    DAEMON_TIMEOUT = "daemon_timeout"  # md5 was unavailable because of requests.ReadTimeout

    # md5 was unavailable, but the host is going to be checked again;
    # on_execute never records this status in stats/fails.
    TEMPORARY_ERROR = "temporary_error"


# Colour per status value, handed to the stats table report builders.
# TEMPORARY_ERROR has no colour assigned here.
COLOR_MAP = {
    status.value: colour
    for status, colour in (
        (MonitoringStatus.OK, "green"),
        (MonitoringStatus.FAIL, "red"),
        (MonitoringStatus.DAEMON_TIMEOUT, "blue"),
        (MonitoringStatus.UNKNOWN, "grey"),
    )
}


class HostTypes(Enum):
    """Sources of hosts supported by YabsServerRuntimeMonitoring.iter_hosts_chunks."""
    hosts = "hosts"  # explicit list from the ``hosts`` parameter
    nanny_services = "nanny_services"  # hosts resolved per nanny service
    balancer_backends = "balancer_backends"  # hosts enabled on the balancers from BALANCERS_BY_PRJ


class ReleaseAuditResult(sdk2.Resource):
    """Contains JSON with release audit fails
    """
    # Payload is the ``{service: {host: {file: status name}}}`` mapping written by
    # YabsServerRuntimeMonitoring.create_result_resource.  That method also passes a
    # ``status`` attribute (fail / daemon_timeout) which is not declared below.
    auto_backup = True
    ttl = 550  # 1.5 year

    # Attributes used to tag the resource with the audit date and audited component.
    date = sdk2.parameters.String("Date in %Y-%m-%d format")
    component_name = sdk2.parameters.String("Component name")


class ReleaseAuditReport(sdk2.Resource):
    """Contains HTML report with release audit fails
    """
    # Written by YabsServerRuntimeMonitoring.create_report_resource as ``report.html``.
    auto_backup = True
    ttl = 550  # 1.5 year

    # Attributes used to tag the resource with the audit date and audited component.
    date = sdk2.parameters.String("Date in %Y-%m-%d format")
    component_name = sdk2.parameters.String("Component name")


class YabsServerRuntimeMonitoring(SolomonTaskMixin, sdk2.Task):
    """Checks that running binaries were properly released
    """

    # Execution requirements of the task: resources and pip packages.
    class Requirements(sdk2.Requirements):
        cores = 1  # exactly 1 core
        ram = 4096  # 4GiB or less

        environments = [
            # startrek_client is imported lazily inside on_execute
            PipEnvironment('startrek_client', use_wheel=True),  # PYPI-101
            # NOTE(review): 'retrying' is not imported in this module; presumably needed by helpers
            PipEnvironment('retrying'),
        ]

        # Empty Caches subclass: no caches are declared for this task.
        class Caches(sdk2.Requirements.Caches):
            pass

    # Task parameters: where to take hosts from, where to report, and retry/escalation tuning.
    class Parameters(SolomonTaskMixinParameters, sdk2.Parameters):
        expires = timedelta(hours=1)

        with sdk2.parameters.Group("Host settings") as host_settings:
            # Names listed here are matched against service names in iter_raw_monitoring_results:
            # such services may additionally run builds approved via the new-runtime flow.
            new_runtime_groups = sdk2.parameters.List("Groups allowed to have specific builds (BSNEWRUNTIME)")
            # Selects the branch taken by iter_hosts_chunks; sub-parameters below are tied to a choice.
            host_type = sdk2.parameters.String("Host types", choices=[(host_type.name, host_type.value) for host_type in HostTypes], default=HostTypes.hosts.value)
            with host_type.value[HostTypes.hosts.value]:
                hosts = sdk2.parameters.List("Hosts to check")

            with host_type.value[HostTypes.nanny_services.value]:
                # True: services come from a nanny dashboard; False: from the explicit include_services list.
                use_nanny_dashboard = sdk2.parameters.Bool("Use services from dashboard", default=True)
                with use_nanny_dashboard.value[True]:
                    nanny_dashboard = sdk2.parameters.String("Use services from dashboard", default="bsfront_production")
                    nanny_token = sdk2.parameters.String("Nanny token vault name")
                with use_nanny_dashboard.value[False]:
                    include_services = sdk2.parameters.List("Include hosts from following nanny services")

            with host_type.value[HostTypes.balancer_backends.value]:
                balancer_nanny_dashboard = sdk2.parameters.String("Use services from dashboard", default="bsfront_production")
                balancer_nanny_token = sdk2.parameters.String("Nanny token vault name")

        with sdk2.parameters.Group("Issues settings") as issues_settings:
            # Queue for the issue created/updated by do_report when there are fails.
            st_queue = sdk2.parameters.String("Startrek queue", default_value="BSAUDIT")
            # Both tokens are vault item names, resolved through sdk2.Vault.data.
            st_token = sdk2.parameters.String("Startrek token vault name", default_value="audit-release-startrek-token")
            staff_token = sdk2.parameters.String("Staff token vault name", default_value="audit-release-staff-token")

        with sdk2.parameters.Group("Debug options") as debug_options:
            # In debug mode get_assignee skips the duty lookup and returns the fallback login.
            debug_mode = sdk2.parameters.Bool("Debug mode", default=False)
            with debug_mode.value[True]:
                # When set (together with debug_mode), on_execute talks to Client.DEFAULT_BASE_URL without auth.
                use_production_sandbox = sdk2.parameters.Bool("Use production sandbox", default=True)

        # Compared with the retry counter in on_execute to decide when a missing md5 becomes final.
        max_attempts = sdk2.parameters.Integer('Max number of attempts for each host', default=3)
        # Number of previous successful scheduler runs inspected by get_host_daemon_timeout_history.
        history_limit = sdk2.parameters.Integer('Amount of previous results to take into consideration when deciding daemon_timeout severity', default=2)
        # A DAEMON_TIMEOUT is escalated to fails once the host's historical timeout count reaches this value.
        timeout_limit = sdk2.parameters.Integer('Allow daemon_timeout following amount of times', default=1)

        # Solomon destination for the sensors pushed by send_statistics.
        solomon_project = SolomonTaskMixinParameters.solomon_project(default="yabs_debug")
        solomon_service = SolomonTaskMixinParameters.solomon_service(default="audit_release")
        solomon_cluster = SolomonTaskMixinParameters.solomon_cluster(default="frontend")

    def iter_hosts_chunks(self):
        """Yield ``(service, hosts)`` pairs to audit, according to ``Parameters.host_type``.

        ``hosts`` is a list of ``HostPort``.  For the explicit host list the service is
        ``None``; for nanny services that fail to resolve an empty list is yielded so
        the caller can still report the service.

        Raises:
            ValueError: if ``host_type`` is not a supported ``HostTypes`` value.
        """
        selected_type = HostTypes(self.Parameters.host_type)

        if selected_type == HostTypes.hosts:
            explicit_hosts = [HostPort(hostname, None) for hostname in self.Parameters.hosts]
            yield None, explicit_hosts
            return

        if selected_type == HostTypes.nanny_services:
            if self.Parameters.use_nanny_dashboard:
                token = sdk2.Vault.data(self.Parameters.nanny_token)
                dashboard_client = NannyClient(rm_const.Urls.NANNY_BASE_URL, token)
                service_names = set(dashboard_client.get_dashboard_services(self.Parameters.nanny_dashboard))
            else:
                service_names = set(self.Parameters.include_services)
            service_names = service_names.difference(EXCLUDE_SERVICES)

            from api.hq import HQResolver
            hq_resolver = HQResolver(use_service_resolve_policy=True)
            for service_name in service_names:
                try:
                    resolved_hosts = [HostPort(hostname, None) for hostname in hq_resolver.get_mtns(service_name)]
                except Exception as exc:
                    # Best effort: an unresolvable service is still reported, with no hosts.
                    logger.exception("Unable to resolve hosts for %s: %s", service_name, exc)
                    resolved_hosts = []
                yield service_name, resolved_hosts
            return

        if selected_type == HostTypes.balancer_backends:
            token = sdk2.Vault.data(self.Parameters.balancer_nanny_token)
            balancer_client = NannyClient(rm_const.Urls.NANNY_BASE_URL, token)
            for prj, balancer_hostnames in BALANCERS_BY_PRJ.items():
                # Production services first, then experimental ones, exactly in that order.
                prj_services = []
                for ctype in ("prod", "experiment"):
                    prj_services.extend(get_services(
                        balancer_client,
                        self.Parameters.balancer_nanny_dashboard,
                        filter_by_labels=[("prj", prj), ("ctype", ctype)],
                        exclude_services=EXCLUDE_SERVICES,
                    ))
                enabled_by_service = get_enabled_hosts_by_service(balancer_hostnames, prj_services)
                for service_name in prj_services:
                    yield service_name, [
                        HostPort(hostname, port)
                        for hostname, port in enabled_by_service[service_name]
                    ]
            return

        raise ValueError("Unknown host_type specified in parameters: {}".format(self.Parameters.host_type))

    @staticmethod
    def get_approved_files(task_types, resource_type, sandbox_client, startrek_helper, component_info, overall_limit=20):
        """Collect md5 hashes of auditable binaries built by properly released tasks.

        Stable-released tasks of ``task_types`` are read in pages of 10 until
        ``overall_limit`` offsets are covered.  A task contributes only if
        ``check_requirements`` accepts it; otherwise the reasons are logged.
        ``resource_type`` is accepted for interface compatibility and not used.

        Returns:
            ``defaultdict(dict)``: ``{filename: {(task type, server version, ''): md5}}``.
        """
        from sandbox.projects.yabs.release.version.sandbox_helpers import SandboxHelper

        page_size = 10
        version_helper = SandboxHelper(sandbox_client=sandbox_client)
        approved = defaultdict(dict)

        for page_offset in range(0, overall_limit, page_size):
            page = sandbox_client.task.read(type=task_types, release=ReleaseStatus.STABLE, limit=page_size, offset=page_offset)
            for released_task in page['items']:
                task_id = released_task['id']
                issue = get_release_issue(task_id, sandbox_client, startrek_helper, component_info)
                server_version = version_helper.get_basic_version_from_task(task_id)
                is_legit, errors = check_requirements(AUDIT_REQUIREMENTS, issue, task_id, sandbox_client)
                if not is_legit:
                    logger.info("Release built by %s #%d is incorrect due to:\n%s", released_task['type'], task_id, "\n".join(errors))
                    continue

                md5_by_file = released_task["output_parameters"]["md5"]
                for filename in AUDITABLE_BINARIES:
                    file_md5 = md5_by_file.get(filename)
                    if file_md5 is None:
                        continue
                    approved[filename][(released_task['type'], server_version, '')] = file_md5

        return approved

    @staticmethod
    def get_new_runtime_approved_files(sandbox_client, startrek_client, task_type='YABS_NEW_RUNTIME_BUILD', overall_limit=20):
        """Collect md5 hashes of auditable binaries built through the new-runtime flow.

        Successful ``task_type`` tasks are read in pages of 10.  For each one whose issue
        passes ``check_new_runtime_requirements``, the md5 output of its successful child
        build tasks is recorded.

        Returns:
            ``defaultdict(dict)``: ``{filename: {(build task type, server version, patch): md5}}``.
        """
        from sandbox.projects.yabs.release.version.version import get_version_from_arcadia_url

        page_size = 10
        approved = defaultdict(dict)

        for page_offset in range(0, overall_limit, page_size):
            page = sandbox_client.task.read(type=task_type, status=TaskStatus.SUCCESS, limit=page_size, offset=page_offset, hidden=True)
            logger.info('get_new_runtime_approved_files: tasks found {}'.format(len(page['items'])))
            for parent_task in page['items']:
                # The issue output parameter is a link; its last path component is the issue key.
                issue = parent_task['output_parameters']['issue'].split('/')[-1]
                logger.info('issue {} parsed from {} ({})'.format(issue, parent_task['output_parameters']['issue'], parent_task['id']))
                server_version = get_version_from_arcadia_url(parent_task['input_parameters']['arcadia_url'])
                patch = parent_task['input_parameters']['patch']
                is_legit, errors = check_new_runtime_requirements(startrek_client, issue)
                logger.info('{} is {}'.format(issue, 'legit' if is_legit else 'not_legit'))

                if not is_legit:
                    error_message = '\n'.join('* {}'.format(e) for e in errors)
                    logger.info('Not all requirements for NEW_RUNTIME are met:\n\n{}'.format(error_message))
                    continue

                children = sandbox_client.task.read(type=BUILD_TASK_TYPES, status=TaskStatus.SUCCESS, parent=parent_task['id'], limit=page_size)
                for build_task in children['items']:
                    md5_by_file = build_task["output_parameters"]["md5"]
                    logger.info('Build task {} has output_parameters->md5: {}'.format(build_task['id'], md5_by_file))
                    for filename in AUDITABLE_BINARIES:
                        file_md5 = md5_by_file.get(filename)
                        if file_md5 is None:
                            continue
                        approved[filename][(build_task['type'], server_version, patch)] = file_md5

        return approved

    def generate_group_specific_approved_files(self, sandbox_client, startrek_client):
        """Map every group from ``Parameters.new_runtime_groups`` to the new-runtime approved files.

        All groups share the very same mapping object.
        """
        shared_approved = self.get_new_runtime_approved_files(sandbox_client, startrek_client)
        return dict.fromkeys(self.Parameters.new_runtime_groups, shared_approved)

    def send_statistics(self, stats):
        """Push per-file, per-status counters to Solomon.

        ``stats`` maps ``(service, filename, status)`` to a count; counts are summed
        over services.  Every (auditable file, status) pair gets a sensor, zero included.
        """
        counters = {
            (filename, status): 0
            for filename in AUDITABLE_BINARIES
            for status in MonitoringStatus
        }
        for stat_key, count in stats.items():
            _, filename, status = stat_key
            counters[(filename, status)] += count

        sensors = [
            {
                "labels": {
                    "status": status.value,
                    "filename": filename,
                },
                "value": count,
                "ts": int(time.time()),
            }
            for (filename, status), count in counters.items()
        ]
        self.solomon_push_client.add(sensors)

    @staticmethod
    def iter_raw_monitoring_results(hosts_chunks_iterator, approved_files, group_specific_approved_files, temporary_is_fail=True, timeout=120):
        """Check hosts and yield one ``RawMonitoringResult`` per audited file.

        Args:
            hosts_chunks_iterator: iterable of ``(service, hosts)`` pairs.
            approved_files: ``{filename: {build key: md5}}`` valid for every service.
            group_specific_approved_files: ``{service: {filename: {build key: md5}}}`` with
                extra builds allowed for particular services.
            temporary_is_fail: when false, a missing md5 is reported as ``TEMPORARY_ERROR``
                so that the caller can retry; otherwise as ``DAEMON_TIMEOUT`` (read timeout)
                or ``FAIL``.
            timeout: forwarded to ``iter_binary_hashes``.
        """
        from sandbox.projects.yabs.audit.runtime.monitoring.runner import iter_binary_hashes

        for service, hosts_chunk in hosts_chunks_iterator:
            logger.info('Checking service {} with hosts_chunk.len={}'.format(service, len(hosts_chunk)))
            if not hosts_chunk:
                # Nothing to probe: report every auditable file as unknown for the service.
                for filename in AUDITABLE_BINARIES:
                    yield RawMonitoringResult(service, None, None, filename, MonitoringStatus.UNKNOWN)
                continue

            for host, engine_port, result in iter_binary_hashes(hosts_chunk, PORT, timeout=timeout):
                # Cases where the same status applies to every auditable file of the host.
                host_wide_status = None
                if not result or not result.engine_is_running:
                    host_wide_status = MonitoringStatus.UNKNOWN
                elif result.md5 is None:
                    if not temporary_is_fail:
                        host_wide_status = MonitoringStatus.TEMPORARY_ERROR
                    elif isinstance(result.exception, requests.ReadTimeout):
                        host_wide_status = MonitoringStatus.DAEMON_TIMEOUT
                    else:
                        host_wide_status = MonitoringStatus.FAIL

                if host_wide_status is not None:
                    for filename in AUDITABLE_BINARIES:
                        yield RawMonitoringResult(service, host, engine_port, filename, host_wide_status)
                    continue

                for filename, file_hash in result.md5.items():
                    if filename not in AUDITABLE_BINARIES:
                        continue

                    approved_for_all = file_hash in approved_files[filename].values()
                    # The group lookup is only attempted when the common approval failed.
                    approved_for_group = (
                        not approved_for_all
                        and service in group_specific_approved_files
                        and file_hash in group_specific_approved_files[service][filename].values()
                    )
                    if approved_for_all:
                        status = MonitoringStatus.OK
                    elif approved_for_group:
                        logger.info("Make group specific release check for {}: filename={}".format(service, filename))
                        status = MonitoringStatus.OK
                    else:
                        status = MonitoringStatus.FAIL
                        logger.error("File %s at %s has md5 %s that was not properly approved", filename, host, file_hash)
                    yield RawMonitoringResult(service, host, engine_port, filename, status)

    def get_assignee(self):
        """Return the login to assign audit issues to.

        Outside debug mode this is the first current engine responsible; in debug mode,
        or when nobody is on duty, the task maintainer's login is returned.
        """
        maintainer = "igorock"
        if self.Parameters.debug_mode:
            return maintainer

        on_duty = get_current_engine_responsibles()
        if len(on_duty) == 0:
            logger.error("Not found current Engine@OnDuty, falling back to task's maintainer")
            return maintainer
        return on_duty[0]

    def create_report_resource(self, fails, auditable_files, date_tag, component_name):
        """Render the fails report and store it as a ``ReleaseAuditReport`` resource (``report.html``).

        Returns the created resource.
        """
        html_report = create_fails_report(fails, auditable_files, get_task_link(self.id), str(self.type))
        resource = ReleaseAuditReport(
            self,
            "AuditRelease report",
            "report.html",
            date=date_tag,
            component_name=component_name,
        )
        sdk2.ResourceData(resource).path.write_bytes(html_report.encode("utf-8"))
        return resource

    def create_result_resource(self, fails, date_tag, component_name, status=MonitoringStatus.FAIL.value):
        """Dump ``fails`` as JSON into a ``ReleaseAuditResult`` resource named ``result_<status>.json``.

        Args:
            fails: JSON-serializable mapping ``{service: {host: {file: status name}}}``.
            date_tag: value for the resource ``date`` attribute.
            component_name: value for the resource ``component_name`` attribute.
            status: stored as the resource ``status`` attribute and used in the file name.

        Returns the created resource.
        """
        result_resource = ReleaseAuditResult(self, "AuditRelease result", "result_{}.json".format(status), date=date_tag, component_name=component_name, status=status)
        result_resource_data = sdk2.ResourceData(result_resource)
        # write_bytes needs bytes: json.dumps returns text on Python 3.  Its output is
        # ASCII-only (ensure_ascii default), so encoding is also safe on Python 2.
        result_resource_data.path.write_bytes(json.dumps(fails, indent=2).encode("utf-8"))
        return result_resource

    @property
    def startrek_token(self):
        """Startrek token read from the vault item named by ``Parameters.st_token``.

        The value is fetched on first access and reused afterwards.
        """
        # Inside the class body ``self.__startrek_token`` is name-mangled, so probing
        # hasattr(self, "__startrek_token") with the literal name never matches and the
        # vault would be queried on every access.  Use the attribute itself instead.
        try:
            return self.__startrek_token
        except AttributeError:
            self.__startrek_token = sdk2.Vault.data(self.Parameters.st_token)
            return self.__startrek_token

    @property
    def staff_token(self):
        """Staff token read from the vault item named by ``Parameters.staff_token``, stripped.

        The raw value is fetched on first access and reused afterwards.
        """
        # Inside the class body ``self.__staff_token`` is name-mangled, so probing
        # hasattr(self, "__staff_token") with the literal name never matches and the
        # vault would be queried on every access.  Use the attribute itself instead.
        try:
            raw_token = self.__staff_token
        except AttributeError:
            raw_token = self.__staff_token = sdk2.Vault.data(self.Parameters.staff_token)
        return raw_token.strip()

    def get_head(self, whom, timeout=30):
        """Return the login of the chief of ``whom`` according to the Staff API.

        Args:
            whom: login to look up.
            timeout: seconds to wait for the Staff API; without it a stuck connection
                would block the task indefinitely.

        Raises:
            requests.RequestException: on network errors, timeout or a non-2xx response.
        """
        response = requests.get(
            'https://staff-api.yandex-team.ru/v3/persons',
            headers={'Authorization': 'OAuth {}'.format(self.staff_token)},
            params={
                'login': whom,
                '_fields': 'chief.login'
            },
            timeout=timeout,
        )
        # Fail with the HTTP error itself rather than a KeyError on an error payload.
        response.raise_for_status()
        return response.json()['result'][0]['chief']['login']

    def get_bs_head(self, timeout=30):
        """Return the login of the first department head of Staff group ``83``.

        Args:
            timeout: seconds to wait for the Staff API; without it a stuck connection
                would block the task indefinitely.

        Raises:
            requests.RequestException: on network errors, timeout or a non-2xx response.
        """
        response = requests.get(
            'https://staff-api.yandex-team.ru/v3/groups',
            headers={'Authorization': 'OAuth {}'.format(self.staff_token)},
            params={
                'id': '83',
                '_fields': 'department.heads.person.login'
            },
            timeout=timeout,
        )
        # Fail with the HTTP error itself rather than a KeyError on an error payload.
        response.raise_for_status()
        return response.json()['result'][0]['department']['heads'][0]['person']['login']

    def do_report(self, approved_files, group_specific_approved_files, fails, timeouts, stats, startrek_client):
        """Publish the audit outcome: Solomon sensors, task info, resources and a Startrek issue.

        Args:
            approved_files: approved builds shown in the task info.
            group_specific_approved_files: ``{group: approved builds}`` shown per group.
            fails: ``{service: {host: {file: status name}}}``; when non-empty, report and
                result resources are created and an issue is created or updated.
            timeouts: same shape as ``fails``; when non-empty, stored as a separate
                result resource with the ``daemon_timeout`` status.
            stats: ``{(service, file, status): count}``.
            startrek_client: client passed to ``create_or_update_issue``.
        """
        self.send_statistics(stats)

        approved_files_report = create_approved_files_report(approved_files)
        self.set_info(approved_files_report, do_escape=False)

        for group, approved in group_specific_approved_files.items():
            approved_files_report = create_approved_files_report(approved)
            approved_files_report = approved_files_report.replace('Approved releases', 'Approved releases for {}'.format(group))
            self.set_info(approved_files_report, do_escape=False)

        stats_report_sb = create_sb_table_report(stats, AUDITABLE_BINARIES, MonitoringStatus, COLOR_MAP)
        self.set_info(stats_report_sb, do_escape=False)

        today = datetime.now(tz=pytz.timezone("Europe/Moscow")).date()
        date_tag = today.strftime("%Y-%m-%d")

        if timeouts:
            self.create_result_resource(timeouts, date_tag, COMPONENT_NAME, status=MonitoringStatus.DAEMON_TIMEOUT.value)

        if not fails:
            return

        report_resource = self.create_report_resource(fails, AUDITABLE_BINARIES, date_tag, COMPONENT_NAME)
        self.create_result_resource(fails, date_tag, COMPONENT_NAME, status=MonitoringStatus.FAIL.value)
        self.set_info('<a href="{resource_link}" target="_blank">Full report</a>'.format(resource_link=report_resource.http_proxy), do_escape=False)

        stats_report_st = create_st_table_report(stats, AUDITABLE_BINARIES, MonitoringStatus, COLOR_MAP)
        st_report_text = create_issue_report(report_resource.http_proxy, stats_report_st, get_task_link(self.id), str(self.type))
        # Resolve the assignee once, so the follower really is the chief of the person
        # the issue is assigned to and the duty lookup is not repeated.
        assignee = self.get_assignee()
        issue = create_or_update_issue(
            startrek_client,
            st_report_text,
            fails,
            date_tag=date_tag,
            component_name=COMPONENT_NAME,
            queue=self.Parameters.st_queue,
            assignee=assignee,
            followers=[self.get_head(assignee), self.get_bs_head()]
        )
        self.set_info('Issue: <a href="https://st.yandex-team.ru/{issue.key}" target="_blank">{issue.key}</a>'.format(issue=issue), do_escape=False)

    def get_host_daemon_timeout_history(self, history_limit):
        """Count, per host, how many recent runs recorded a daemon timeout.

        Looks at up to ``history_limit`` latest successful tasks of the same scheduler and
        reads their ready ``ReleaseAuditResult`` resources with the ``daemon_timeout``
        status.  A host is counted once per (resource, service) entry where any of its
        files has the ``DAEMON_TIMEOUT`` status name.  Without a scheduler the result is empty.

        Returns:
            dict: ``{host: number of timeouts}``.
        """
        timeout_counts = defaultdict(int)
        timeout_marker = MonitoringStatus.DAEMON_TIMEOUT.name

        if self.scheduler:
            previous_runs = list(
                sdk2.Task.find(scheduler=self.scheduler, status=TaskStatus.SUCCESS, order="-id").limit(history_limit)
            )
            logger.debug('Found %s prev_tasks: %s', len(previous_runs), list(previous_runs))

            found_resources = (
                ReleaseAuditResult.find(task_id=run.id, state='READY', status=MonitoringStatus.DAEMON_TIMEOUT.value, limit=1).first()
                for run in previous_runs
            )
            timeout_resources = [resource for resource in found_resources if resource]
            logger.debug('Found %s history resources: %s', MonitoringStatus.DAEMON_TIMEOUT.value, timeout_resources)

            for resource in timeout_resources:
                with open(str(sdk2.ResourceData(resource).path), 'r') as stream:
                    statuses_by_service = json.load(stream)
                for service_name, host_files in statuses_by_service.items():
                    for host, file_statuses in host_files.items():
                        if timeout_marker not in file_statuses.values():
                            continue
                        logger.debug('Found %s at %s in service %s', timeout_marker, host, service_name)
                        timeout_counts[host] += 1

        scores = dict(timeout_counts)
        logger.debug('Got host daemon_timeout scores: %s', scores)
        return scores

    def _collect_results(self, results, fails, timeouts, stats, host_daemon_timeout_history):
        """Fold raw monitoring results into ``fails``, ``timeouts`` and ``stats`` (mutated in place).

        ``TEMPORARY_ERROR`` results are not counted; their hosts are returned instead as a
        ``defaultdict(set)`` of ``{service: {HostPort, ...}}`` to be checked again.
        A ``DAEMON_TIMEOUT`` is additionally recorded as a fail when the host already has
        at least ``Parameters.timeout_limit`` timeouts in ``host_daemon_timeout_history``.
        """
        temporary = defaultdict(set)
        timeout_limit = self.Parameters.timeout_limit
        for result in results:
            if result.status == MonitoringStatus.TEMPORARY_ERROR:
                temporary[result.service].add(HostPort(result.host, result.engine_port))
                continue
            if result.status == MonitoringStatus.FAIL:
                fails.setdefault(result.service, {}).setdefault(result.host, {})[result.file] = result.status.name
            elif result.status == MonitoringStatus.DAEMON_TIMEOUT:
                timeouts.setdefault(result.service, {}).setdefault(result.host, {})[result.file] = result.status.name
                if host_daemon_timeout_history.get(result.host, 0) >= timeout_limit:
                    fails.setdefault(result.service, {}).setdefault(result.host, {})[result.file] = result.status.name
            stats[(result.service, result.file, result.status)] += 1
        return temporary

    def on_execute(self):
        """Run the audit: collect approved builds, check all hosts with retries, then report."""
        if self.Parameters.debug_mode and self.Parameters.use_production_sandbox:
            sandbox_client = Client(base_url=Client.DEFAULT_BASE_URL, auth=NoAuth())
        else:
            sandbox_client = self.server
        startrek_helper = STHelper(token=self.startrek_token)
        from startrek_client import Startrek
        startrek_client = Startrek(useragent=self.__class__.__name__, token=self.startrek_token)

        component_info = get_component(COMPONENT_NAME)
        approved_files = self.get_approved_files(BUILD_TASK_TYPES, RESOURCE_TYPE, sandbox_client, startrek_helper, component_info)
        logger.info("Approved releases are:\n%s", approved_files)
        group_specific_approved_files = self.generate_group_specific_approved_files(sandbox_client, startrek_client)
        logger.info("Approved group specific releases are:\n%s", group_specific_approved_files)
        fails = {}
        timeouts = {}
        stats = defaultdict(int)

        host_daemon_timeout_history = self.get_host_daemon_timeout_history(self.Parameters.history_limit)
        # First pass over every host; hosts with a temporarily unavailable md5 are retried below.
        temporary = self._collect_results(
            self.iter_raw_monitoring_results(
                self.iter_hosts_chunks(),
                approved_files,
                group_specific_approved_files,
                temporary_is_fail=self.Parameters.max_attempts <= 1,
            ),
            fails, timeouts, stats, host_daemon_timeout_history,
        )

        # NOTE(review): the retry counter starts at 1 after the first pass, so with
        # max_attempts >= 2 a host may be probed max_attempts + 1 times in total - confirm intent.
        attempt = 1
        while temporary:
            logger.warning("Got %d services with temporary unavailable md5 daemon at %d attempt", len(temporary), attempt)
            time.sleep(180)  # If it is a temporary failure we need to give Nanny a little bit of time to restart runtime daemon or shut down yabs-server
            # TODO: igorock@ make iter_raw_monitoring_results in one cycle
            temporary_is_fail = attempt >= self.Parameters.max_attempts
            timeout = attempt * 120  # each retry waits longer for the md5 daemon
            temporary = self._collect_results(
                self.iter_raw_monitoring_results(
                    temporary.items(),
                    approved_files,
                    group_specific_approved_files,
                    temporary_is_fail=temporary_is_fail,
                    timeout=timeout,
                ),
                fails, timeouts, stats, host_daemon_timeout_history,
            )
            attempt += 1

        # Three placeholders for three arguments: with only two, logging fails to format
        # the record and the message is lost.
        logger.debug("Got results:\nFails:\n%s\nTimeouts:\n%s\nStats:\n%s", fails, timeouts, stats)
        logger.info('Head: {}'.format(self.get_head(self.get_assignee())))
        logger.info('Head bs: {}'.format(self.get_bs_head()))
        self.do_report(approved_files, group_specific_approved_files, fails, timeouts, stats, startrek_client)
