import copy
import json
import logging
import operator
import os
import requests
import subprocess
import time
from collections import Counter

from sandbox import sdk2
from sandbox.projects.websearch.begemot import resources as br
from sandbox.projects.websearch.begemot.common import Begemots
from sandbox.projects.WizardRuntimeBuild.ya_make import YaMake
from sandbox.projects.common import solomon
from sandbox.projects.common.arcadia import sdk
from sandbox.common.errors import TaskFailure
from sandbox.common.utils import get_task_link
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.svn import Arcadia, Svn


class BegemotInfo:
    """Per-rule and per-shard lookups used by the quota report.

    Most per-shard data comes from the ``Begemots`` shard registry. The
    ``rewrite_*`` tables override it for shards the registry does not describe
    directly: ``Advq`` (ABC override only) and the ``Market*`` shards.
    """

    # Suffixes of the Market* shards configured by the override tables below.
    MARKET_SUFFIXES = ('Wizard', 'Merger', 'LingBoost', 'Service')

    def __init__(self, rule_infos):
        """
        :param rule_infos: mapping rule name -> rule description dict; this class
            reads only its 'Abc' key.
        """
        self.rule_infos = rule_infos

        # Override tables (shard name -> value). They are per-instance attributes
        # so that instances never share, and mutate, the same dict objects.
        self.rewrite_abc_by_shard = {'Advq': 'advq'}
        self.rewrite_service_template = {}
        self.rewrite_hamster_name = {}
        self.rewrite_prj = {}
        self.rewrite_resource_name = {}
        self.rewrite_shard_name = {}

        for suffix in self.MARKET_SUFFIXES:
            shard_name = 'Market{}'.format(suffix)
            self.rewrite_abc_by_shard[shard_name] = 'search-wizard/market'
            self.rewrite_service_template[shard_name] = 'bg_market_{}_yp'.format(suffix.lower())
            # No hamster service name for Market shards.
            self.rewrite_hamster_name[shard_name] = None
            if suffix == 'Service':
                # MarketService uses the ServiceWizard shard name and config resource.
                self.rewrite_prj[shard_name] = 'market'
                self.rewrite_resource_name[shard_name] = Begemots['ServiceWizard'].fast_build_config_resource_name
                self.rewrite_shard_name[shard_name] = 'ServiceWizard'
            else:
                self.rewrite_prj[shard_name] = 'market-{}'.format(suffix.lower())
                self.rewrite_resource_name[shard_name] = Begemots[suffix].fast_build_config_resource_name
                if suffix != 'Wizard':
                    # MarketWizard keeps its own name; the others use the plain shard name.
                    self.rewrite_shard_name[shard_name] = suffix

    def get_rule_infos(self):
        """Return the rule name -> rule description mapping given to the constructor."""
        return self.rule_infos

    def get_special_shards(self):
        """Return the names of shards whose ABC service is overridden for all rules."""
        return self.rewrite_abc_by_shard.keys()

    def get_abc(self, rule, shard):
        """Return the ABC service owning `rule` in `shard` ('??' if unknown).

        A shard-level override wins over the rule's own 'Abc' field.
        """
        return self.rewrite_abc_by_shard.get(str(shard), self.rule_infos.get(rule, {}).get('Abc', '??'))

    def get_prod_template(self, shard):
        """Return the prefix of the shard's production Nanny service names."""
        if str(shard) in self.rewrite_service_template:
            return self.rewrite_service_template[str(shard)]
        return Begemots[shard].nanny_name

    def get_hamster(self, shard):
        """Return the shard's hamster service name (None for Market shards)."""
        if str(shard) in self.rewrite_hamster_name:
            return self.rewrite_hamster_name[str(shard)]
        return Begemots[shard].beta_parent

    def get_prj(self, shard):
        """Return the Golovan 'prj' tag of the shard."""
        if str(shard) in self.rewrite_prj:
            return self.rewrite_prj[str(shard)]
        return Begemots[shard].prj

    def get_resource_name(self, shard):
        """Return the fast-build config resource type of the shard."""
        if str(shard) in self.rewrite_resource_name:
            return self.rewrite_resource_name[str(shard)]
        return Begemots[shard].fast_build_config_resource_name

    def get_shard_name(self, shard):
        """Return the shard's directory name (used for its ya.make path)."""
        if str(shard) in self.rewrite_shard_name:
            return self.rewrite_shard_name[str(shard)]
        return shard

    def get_betas_count(self, shard):
        """Return the number of beta instances counted for the shard.

        The counts are hard-coded: 6 for a shard with a beta and 1 for a perf beta.
        Shards absent from the registry count 0.
        NOTE(review): confirm the constants still match the real beta sizes.
        """
        if shard not in Begemots.keys():
            return 0
        result = 0
        if Begemots[shard].beta:
            result += 6
        if Begemots[shard].perf_beta:
            result += 1
        return result


class BegemotQuotaUsageByRule(sdk2.Task):
    """Attribute Begemot CPU, RAM and disk quota usage to ABC services.

    CPU: each shard's 90th-percentile CPU usage (from Golovan) is split between
    its rules proportionally to request count times trimmed mean of the rule's
    median processing time. Rules are then grouped by ABC service.

    RAM/disk: rule data sizes from the released fast-build config resources are
    multiplied by the number of running instances. Disk is estimated as RAM
    times a multiplier.

    Results are printed to the task info. They can optionally be pushed to
    Solomon and mailed to the responsible people of every ABC service.
    """

    BEGEMOT_DATA_PATH = 'search/begemot/data'
    # GolovanRequest cannot return more than 200 signals per request.
    GOLOVAN_MAX_SIGNALS = 200
    # Attempts and pause between them (seconds) for Golovan requests.
    REQUEST_RETRIES = 5
    RETRY_DELAY = 5
    # Timeout for HTTP requests to Nanny and ABC, seconds.
    HTTP_TIMEOUT = 60
    # Golovan aggregation period, seconds.
    period = 3600
    # Default analysed window: one week ending two periods ago. These values
    # are computed at import time only; on_execute recomputes them through
    # _refresh_time_window so that the window is relative to the actual run.
    end_time = time.time() - period * 2
    start_time = end_time - period * 24 * 7  # 1 week
    begemot_info = None
    geos = ['sas', 'man', 'vla']

    class Parameters(sdk2.Parameters):
        begemot_binary = sdk2.parameters.Resource(
            'Begemot binary',
            resource_type=br.BEGEMOT_EXECUTABLE,
            required=False
        )

        with sdk2.parameters.CheckGroup('Begemot shards to check') as shards:
            for shard_name in sorted(Begemots.Services):
                shard = Begemots[shard_name]
                setattr(
                    shards.values, shard_name,
                    shards.Value(shard_name, checked=shard.test_with_merger),
                )

            for suffix in ['Merger', 'Service', 'LingBoost']:
                shard_name = 'Market{}'.format(suffix)
                setattr(
                    shards.values, shard_name,
                    shards.Value(shard_name, checked=False),
                )

        checkout_arcadia_from_url = sdk2.parameters.ArcadiaUrl("Svn url for arcadia", required=True)

        count_cpu = sdk2.parameters.Bool(
            'Count cpu usage',
            required=True,
            default=True,
        )

        count_disk = sdk2.parameters.Bool(
            'Count disk/RAM usage',
            required=True,
            default=False,
        )

        with count_disk.value[True]:
            nanny_token = sdk2.parameters.String(
                'Nanny token',
                required=True,
                default='Begemot Nanny token',
            )

        report_to_solomon = sdk2.parameters.Bool(
            'Report to Solomon',
            required=True,
            default=False,
        )

        with report_to_solomon.value[True]:
            solomon_token = sdk2.parameters.String(
                'Solomon token vault name for the owner',
                required=True,
                default='Begemot Solomon token'
            )

            solomon_cluster = sdk2.parameters.String(
                'Solomon cluster',
                required=True,
                default='quota_usage_by_abc'
            )

        with count_cpu.value[True]:
            create_emails = sdk2.parameters.Bool(
                'Create emails',
                required=True,
                default=False,
            )

            with create_emails.value[True]:
                abc_token = sdk2.parameters.String(
                    'Abc token vault name for the owner',
                    required=True,
                    default='Begemot Abc token'
                )

                subject = sdk2.parameters.String(
                    'Emails subject',
                    required=True
                )

                print_emails = sdk2.parameters.Bool(
                    'Print emails texts to task info',
                    default=False
                )

                send_emails = sdk2.parameters.Bool(
                    'Send emails',
                    default=False
                )

                testing_receiver = sdk2.parameters.String(
                    'Testing receiver login',
                    description='If not empty, all emails will be sent to this person'
                )

    class Requirements(sdk2.Requirements):
        environments = (
            environments.PipEnvironment('yasmapi'),
        )

    def _refresh_time_window(self):
        """Set [start_time, end_time] to one week ending two periods before now."""
        self.end_time = time.time() - self.period * 2
        self.start_time = self.end_time - self.period * 24 * 7

    def _golovan_request(self, period, start_time, end_time, signals):
        """Fetch `signals` from Golovan and merge the points by timestamp.

        Signals are requested in chunks of at most GOLOVAN_MAX_SIGNALS. Each
        chunk is tried up to REQUEST_RETRIES times, RETRY_DELAY seconds apart.

        :param period: aggregation period, seconds
        :param start_time: window start, unix time
        :param end_time: window end, unix time
        :param signals: iterable of fully qualified signal names (a dict view is fine)
        :return: {timestamp: {signal: value}}
        :raises TaskFailure: if a chunk could not be fetched
        """
        from yasmapi import GolovanRequest

        signals = list(signals)
        results = {}
        for start in range(0, len(signals), self.GOLOVAN_MAX_SIGNALS):
            chunk = signals[start:start + self.GOLOVAN_MAX_SIGNALS]
            for attempt in range(1, self.REQUEST_RETRIES + 1):
                try:
                    # Read the whole result inside the try as well, so errors while
                    # iterating it are retried like errors in the request itself.
                    points = list(GolovanRequest(
                        'ASEARCH', period, start_time, end_time, chunk,
                        load_delay=0.1, max_retry=15, retry_delay=5,
                    ))
                except Exception:
                    logging.exception('Golovan request failed (attempt %d/%d)', attempt, self.REQUEST_RETRIES)
                    if attempt == self.REQUEST_RETRIES:
                        raise TaskFailure('Task failed to get data from golovan. Try later')
                    time.sleep(self.RETRY_DELAY)
                else:
                    break

            # Only complete chunks are merged, so a failed attempt leaves no partial data.
            for ts, values in points:
                results.setdefault(ts, {}).update(values)

        return results

    def _prj_signal(self, shard, signal):
        """Qualify a Golovan signal with the begemot itype and the shard's prj tag."""
        return 'itype=begemot;prj={}:{}'.format(self.begemot_info.get_prj(shard), signal)

    def _get_median_time_signal_tags(self, shard, rule):
        """Signal of the median processing time of `rule` in `shard`."""
        return self._prj_signal(shard, 'quant(begemot-WORKER-{}-TIME_dhhh, 50)'.format(rule))

    def _get_requests_count_signal_tags(self, shard, rule):
        """Signal of the number of requests processed by `rule` in `shard`."""
        return self._prj_signal(shard, 'begemot-WORKER-{}-REQUESTS_dmmm'.format(rule))

    def _get_cpu_usage_signal(self, shard):
        """Signal of the CPU cores used by `shard` instances."""
        return self._prj_signal(shard, 'portoinst-cpu_usage_cores_tmmv')

    def get_cpu_usage_by_shard(self, shards):
        """Return {shard: 90th percentile of CPU cores used} over the analysed window.

        Points where Golovan has no value for a shard are ignored.

        :raises TaskFailure: if there is not a single data point for a shard
        """
        signals = {shard: self._get_cpu_usage_signal(shard) for shard in shards}
        # Shards sharing a prj produce the same signal; request it only once.
        golovan_result = self._golovan_request(self.period, self.start_time, self.end_time, set(signals.values()))

        cpu_usage_by_shard = {}
        for shard in shards:
            signal = signals[shard]
            points = sorted(
                values[signal] for values in golovan_result.values()
                if values.get(signal) is not None
            )
            if not points:
                raise TaskFailure('No CPU usage data in Golovan for shard {}'.format(shard))
            cpu_usage_by_shard[shard] = points[int(len(points) * 0.9)]

        return cpu_usage_by_shard

    def get_rule_weights(self, shard, rules):
        """Return each rule's share (summing to 1) of `shard`'s processing time.

        A rule's raw weight is its total request count over the window, multiplied
        by the mean of its median processing time after dropping the lowest and
        highest 10% of points. Missing points count as 0. If every raw weight is
        zero (e.g. the shard had no traffic), all shares are 0.0.
        """
        time_tags = {rule: self._get_median_time_signal_tags(shard, rule) for rule in rules}
        requests_tags = {rule: self._get_requests_count_signal_tags(shard, rule) for rule in rules}
        golovan_result = self._golovan_request(
            self.period, self.start_time, self.end_time,
            list(time_tags.values()) + list(requests_tags.values()),
        )
        points = list(golovan_result.values())

        rule_weights = {}
        for rule in rules:
            rule_times = sorted(values.get(time_tags[rule]) or 0.0 for values in points)
            requests_count = sum(int(values.get(requests_tags[rule]) or 0) for values in points)

            left_idx, right_idx = int(len(rule_times) * 0.1), int(len(rule_times) * 0.9)
            if right_idx > left_idx:
                # float() keeps Python 2 from doing integer division.
                rule_weights[rule] = requests_count * sum(rule_times[left_idx:right_idx]) / float(right_idx - left_idx)
            else:
                rule_weights[rule] = 0

        total = sum(rule_weights.values())
        if not total:
            logging.warning('All rule weights are zero for shard %s', shard)
            return {rule: 0.0 for rule in rule_weights}
        return {rule: value / float(total) for rule, value in rule_weights.items()}

    def get_abc_weights(self, rule_weights, shard):
        """Group rule weights of `shard` by ABC service.

        :return: ({abc: summed weight}, {abc: {rule: weight}})
        """
        abc_weights = Counter()
        abc_stats = {}
        for rule, weight in rule_weights.items():
            abc = self.begemot_info.get_abc(rule, shard)
            abc_weights[abc] += weight
            abc_stats.setdefault(abc, {})[rule] = weight

        return dict(abc_weights), abc_stats

    def get_rule_time_weights(self, shard):
        """Return (abc weights, per-abc rule weights) for the rules of `shard`.

        The shard's rules are the PEERDIRs of its ya.make under BEGEMOT_DATA_PATH
        that are known to the rule schema, plus their required dependencies. For
        special shards (see BegemotInfo.get_special_shards) every PEERDIR counts.
        """
        rule_infos = self.begemot_info.get_rule_infos()
        is_special_shard = shard in self.begemot_info.get_special_shards()
        rules_set = set()

        with sdk.mount_arc_path(self.Parameters.checkout_arcadia_from_url) as arcadia:
            shard_yamake_path = os.path.join(arcadia, self.BEGEMOT_DATA_PATH, self.begemot_info.get_shard_name(shard), 'ya.make')
            yamake = YaMake.YaMake(shard_yamake_path)
            for peerdir in yamake.peerdir:
                rulename = os.path.basename(peerdir)
                if rulename in rule_infos:
                    rules_set.add(rulename)
                    for dep, info in rule_infos[rulename].get('Dependencies', {}).items():
                        if info.get('Required', False):
                            rules_set.add(dep)
                elif is_special_shard:
                    rules_set.add(rulename)

        rule_weights = self.get_rule_weights(shard, list(rules_set))
        return self.get_abc_weights(rule_weights, shard)

    def get_cpu_by_abc(self, abc_weights, abc_stats, cpu_usage_by_shard):
        """Convert per-shard weights into CPU cores.

        :param abc_weights: {shard: {abc: weight}}
        :param abc_stats: {shard: {abc: {rule: weight}}}
        :param cpu_usage_by_shard: {shard: cores}
        :return: ({abc: cores}, {abc: {rule: cores}}, {abc: {shard: cores}})
        """
        cpu_by_abc = Counter()
        for shard, shard_usage in cpu_usage_by_shard.items():
            for abc, weight in abc_weights[shard].items():
                cpu_by_abc[abc] += weight * shard_usage

        stats_by_rule = {}
        stats_by_shard = {}
        for abc in cpu_by_abc:
            by_rule = Counter()
            by_shard = {}
            for shard in abc_weights:
                for rule, weight in abc_stats[shard].get(abc, {}).items():
                    if self.begemot_info.get_abc(rule, shard) == abc:
                        by_rule[rule] += weight * cpu_usage_by_shard[shard]
                if abc in abc_weights[shard]:
                    by_shard[shard] = abc_weights[shard][abc] * cpu_usage_by_shard[shard]
            stats_by_rule[abc] = dict(by_rule)
            stats_by_shard[abc] = by_shard

        return dict(cpu_by_abc), stats_by_rule, stats_by_shard

    @staticmethod
    def _read_resource_sizes_gb(resource):
        """Return [(name, size in GB)] for every entry of a fast-build config resource."""
        with open(str(sdk2.ResourceData(resource).path), 'r') as f:
            config = json.load(f)
        return [
            (res['name'], float(res['resource_size_kb']) / (1024 * 1024))
            for res in config['resources']
        ]

    def get_memory_usage_by_fresh(self):
        """Return {rule: size in GB} of fresh data from the stable fresh config.

        :raises TaskFailure: if there is no stable BEGEMOT_FAST_BUILD_FRESH_CONFIG
        """
        released_fresh = sdk2.Resource['BEGEMOT_FAST_BUILD_FRESH_CONFIG'].find(state='READY', attrs={'released': 'stable'}).first()
        if released_fresh is None:
            raise TaskFailure('No stable BEGEMOT_FAST_BUILD_FRESH_CONFIG resource found')
        return dict(self._read_resource_sizes_gb(released_fresh))

    def get_instance_number_by_service(self, session, service, geo, retries=5):
        """Return the number of pods of Nanny YP service `service` in `geo`.

        Returns 0 without querying Nanny when the service name mentions a geo
        other than `geo`.

        :param session: requests session carrying the Nanny auth headers
        :param retries: number of attempts before giving up
        :raises TaskFailure: if Nanny could not be queried
        """
        for other_geo in BegemotQuotaUsageByRule.geos:
            if other_geo in service and other_geo != geo:
                return 0

        for attempt in range(retries):
            try:
                response = session.request('POST', 'http://yp-lite-ui.nanny.yandex-team.ru/api/yplite/pod-sets/ListPods/', json={
                    'serviceId': service,
                    'cluster': geo.upper()
                }, timeout=self.HTTP_TIMEOUT).json()
                return int(response['total'])
            except Exception:
                logging.exception('Failed to list pods of %s in %s (attempt %d/%d)', service, geo, attempt + 1, retries)
                if attempt + 1 < retries:
                    time.sleep(5)

        raise TaskFailure('Failed to get instance number for service {} (geo={})'.format(service, geo))

    @staticmethod
    def _prod_suffix(shard):
        """Return the suffix of `shard`'s production Nanny service names."""
        if 'Spellchecker' not in shard:
            return '_prod'
        if 'Exp' in shard or shard == 'Spellchecker':
            return ''
        return '_production'

    def _get_hamster_service(self, shard):
        """Return `shard`'s hamster Nanny service, or None if none should be counted."""
        try:
            hamster = self.begemot_info.get_hamster(shard)
        except Exception:
            # Best effort: missing hamster metadata means no hamster service is counted.
            logging.warning('Failed to get hamster service for shard %s', shard, exc_info=True)
            return None
        if not hamster or 'hamster' not in hamster or 'SpellcheckerExp' in shard:
            return None
        return hamster

    def count_instances_by_shard(self, shards):
        """Return {shard: number of instances}.

        The count covers the production services ('<template><suffix>_<geo>' for
        every geo), the hamster service when there is one, and the betas
        (BegemotInfo.get_betas_count).
        """
        nanny_token = sdk2.Vault.data(self.Parameters.nanny_token)
        instance_count = {}

        with requests.Session() as session:
            session.headers['Authorization'] = 'OAuth {}'.format(nanny_token)
            session.headers['Content-Type'] = 'application/json'

            for shard in shards:
                services = [
                    '{}{}_{}'.format(self.begemot_info.get_prod_template(shard), self._prod_suffix(shard), geo)
                    for geo in BegemotQuotaUsageByRule.geos
                ]
                hamster = self._get_hamster_service(shard)
                if hamster is not None:
                    services.append(hamster)

                count = self.begemot_info.get_betas_count(shard)
                for service in services:
                    for geo in BegemotQuotaUsageByRule.geos:
                        count += self.get_instance_number_by_service(session, service, geo)
                instance_count[shard] = count

        return instance_count

    def get_memory_usage_by_shard(self, release_task, instance_count):
        """Return the RAM used by rule data in GB, grouped by ABC service.

        For every selected shard, the fast-build config resource of `release_task`
        lists rule data sizes. Each size is multiplied by the shard's instance
        count. Fresh data sizes (from the stable fresh config) are added once per
        rule occurrence under '<rule> (fresh data)'.
        NOTE(review): fresh sizes are not multiplied by the instance count;
        confirm this is intended.

        :return: ({abc: GB}, {abc: {rule: GB}}, {abc: {shard: GB}})
        :raises TaskFailure: if a shard's config resource is missing
        """
        usage_by_abc, usage_by_rule, usage_by_shard = {}, {}, {}
        fresh_stats = self.get_memory_usage_by_fresh()

        for shard in self.Parameters.shards:
            resource_type = self.begemot_info.get_resource_name(shard)
            resource = sdk2.Resource.find(type=resource_type, task_id=release_task).first()
            if resource is None:
                raise TaskFailure('No {} resource found in task {}'.format(resource_type, release_task))

            for rule, size_gb in self._read_resource_sizes_gb(resource):
                abc = self.begemot_info.get_abc(rule, shard)
                size_gb *= instance_count[shard]

                abc_rules = usage_by_rule.setdefault(abc, {})
                abc_shards = usage_by_shard.setdefault(abc, {})
                usage_by_abc[abc] = usage_by_abc.get(abc, 0) + size_gb
                abc_rules[rule] = abc_rules.get(rule, 0) + size_gb
                abc_shards[shard] = abc_shards.get(shard, 0) + size_gb

                if rule in fresh_stats:
                    fresh_gb = fresh_stats[rule]
                    rule_fresh = '{} (fresh data)'.format(rule)
                    usage_by_abc[abc] += fresh_gb
                    abc_rules[rule_fresh] = abc_rules.get(rule_fresh, 0) + fresh_gb
                    abc_shards[shard] += fresh_gb

        return usage_by_abc, usage_by_rule, usage_by_shard

    def calculate_disk_usage(self, ram_by_abc, ram_by_rule, ram_by_shard, multiplier=4):
        """Estimate disk usage as RAM usage times `multiplier`.

        Disk usage ~ ram_usage * number of snapshots. The inputs are deep-copied
        and left untouched.

        :return: (disk_by_abc, disk_by_rule, disk_by_shard), shaped like the inputs
        """
        disk_by_abc, disk_by_rule, disk_by_shard = copy.deepcopy(ram_by_abc), copy.deepcopy(ram_by_rule), copy.deepcopy(ram_by_shard)
        for abc in disk_by_abc:
            disk_by_abc[abc] *= multiplier
            for rule in disk_by_rule[abc]:
                disk_by_rule[abc][rule] *= multiplier
            for shard in disk_by_shard[abc]:
                disk_by_shard[abc][shard] *= multiplier

        return disk_by_abc, disk_by_rule, disk_by_shard

    def push_to_solomon(self, cpu_by_abc, token, service):
        """Push one sensor per ABC service ({abc: value}) to Solomon.

        The sensors go to project 'begemot', the configured cluster, and `service`.
        All sensors share one timestamp.
        """
        labels = {
            'project': 'begemot',
            'cluster': self.Parameters.solomon_cluster,
            'service': service,
        }
        now = time.time()
        sensors = [{
            'ts': now,
            'labels': {
                'sensor': abc
            },
            'value': usage
        } for abc, usage in cpu_by_abc.items()]
        solomon.push_to_solomon_v2(token, labels, sensors, common_labels=())

    def print_stats(self, abc_weights, stats, label, suffix, multiplier=1):
        """Add a ranked report to the task info.

        Each ABC service is listed with its total, followed by its items (rules or
        shards). Both levels are sorted in descending order. Values are multiplied
        by `multiplier` and followed by `suffix`.
        """
        info = [label, '']
        for abc, weight in sorted(abc_weights.items(), key=operator.itemgetter(1), reverse=True):
            info.append('{}: {:.3f}{}'.format(abc, weight * multiplier, suffix))
            for item, w in sorted(stats.get(abc, {}).items(), key=operator.itemgetter(1), reverse=True):
                info.append('  - {}: {:.3f}{}'.format(item, w * multiplier, suffix))
            info.append('')

        self.set_info('\n'.join(info))

    def create_emails(self, cpu_by_abc, stats_by_rule, cpu_service):
        """Return {abc: email text} built from the email template stored in Arcadia.

        The placeholder '<username>' is left for the caller to fill in.
        """
        arcadia_path = Arcadia.checkout(self.Parameters.checkout_arcadia_from_url, 'arcadia', depth=Svn.Depth.IMMEDIATES)
        email_template_path = os.path.join(arcadia_path, 'sandbox/projects/websearch/begemot/tasks/BegemotQuotaUsageByRule/email_template.txt')
        Arcadia.update(email_template_path, depth=Svn.Depth.IMMEDIATES, parents=True)
        with open(email_template_path, 'r') as f:
            email_template = f.read()

        emails = {}
        for abc, cpu_usage in cpu_by_abc.items():
            email_text = email_template
            email_text = email_text.replace('<used_cores>', '{:.3f}'.format(cpu_usage))
            email_text = email_text.replace('<cores_to_reserve>', '3*{:.3f} = {:.3f}'.format(cpu_usage, 3 * cpu_usage))
            email_text = email_text.replace('<task_url>', get_task_link(self.id))
            email_text = email_text.replace('<chart_url>', 'https://solomon.yandex-team.ru/?project=begemot&cluster={}&service={}&graph=auto&b=31d'.format(self.Parameters.solomon_cluster, cpu_service))
            email_text = email_text.replace('<rules_list>', '\n'.join(sorted(stats_by_rule[abc].keys())))
            emails[abc] = email_text

        return emails

    def get_receivers(self, abc, token):
        """Return the logins to mail about ABC service `abc`.

        The first non-empty group wins: hardware resources managers, quota
        managers, members with administration-scope roles. Product heads are the
        fallback.

        :raises requests.HTTPError: if the ABC API answers with an error status
        """
        url = 'https://abc-back.yandex-team.ru/api/v4/services/members/'
        # The slug goes through `params` so that it is URL-encoded.
        params = {'service__slug': abc}
        members = []
        while url:
            response = requests.get(
                url=url,
                params=params,
                headers={'Authorization': 'OAuth {}'.format(token)},
                timeout=self.HTTP_TIMEOUT,
            )
            response.raise_for_status()
            page = response.json()
            members.extend(page['results'])
            # Follow the 'next' link if the response is paginated; that URL
            # already carries the query.
            url, params = page.get('next'), None

        hardware_managers, quota_managers, admins, heads = [], [], [], []
        for member in members:
            role = member['role']
            login = member['person']['login']
            if role['code'] == 'hardware_resources_manager':
                hardware_managers.append(login)
            if role['code'] == 'quotas_manager':
                quota_managers.append(login)
            if role['scope']['slug'] == 'administration':
                admins.append(login)
            if role['code'] == 'product_head':
                heads.append(login)

        for candidates in (hardware_managers, quota_managers, admins):
            if candidates:
                return candidates
        return heads

    def send_or_print_email(self, text, login):
        """Print the email to the task info and/or send it, depending on the parameters."""
        #  'testing_receiver' parameter rewrites all receivers' logins
        receiver = self.Parameters.testing_receiver if self.Parameters.testing_receiver else login
        if self.Parameters.print_emails:
            self.set_info('Email for {}:\n\nSubject: {}\n\n{}'.format(receiver, self.Parameters.subject, text), do_escape=False)
        if self.Parameters.testing_receiver or self.Parameters.send_emails:
            headers = ['Reply-to: req-wizard@yandex-team.ru']
            channel.sandbox.send_email([receiver], None, self.Parameters.subject, text.replace('\n', '<br>'), 'text/html', 'utf-8', extra_headers=headers)
            self.set_info('Sent email to {}'.format(receiver))

    def _report_cpu_usage(self):
        """Compute CPU usage per ABC service, print it, and optionally push or mail it."""
        shards_to_check = []
        for shard in self.Parameters.shards:
            # Prod and exp spellcheckers use common prj, do not count them twice
            if 'Spellchecker' in shard and 'Exp' in shard and shard[:-3] in self.Parameters.shards:
                continue
            shards_to_check.append(shard)

        cpu_usage_by_shard = self.get_cpu_usage_by_shard(shards_to_check)

        abc_weights, abc_stats = {}, {}
        for shard in shards_to_check:
            abc_weights[shard], abc_stats[shard] = self.get_rule_time_weights(shard)

        cpu_by_abc, stats_by_rule, stats_by_shard = self.get_cpu_by_abc(abc_weights, abc_stats, cpu_usage_by_shard)
        self.print_stats(cpu_by_abc, stats_by_rule, 'Cores used by rule', ' cores')
        self.print_stats(cpu_by_abc, stats_by_shard, 'Cores used by shard', ' cores')

        for shard in abc_weights:
            self.print_stats(abc_weights[shard], abc_stats[shard], 'Time proportions for rules in shard {}'.format(shard), '%', 100)

        if self.Parameters.report_to_solomon:
            self.push_to_solomon(cpu_by_abc, sdk2.Vault.data(self.Parameters.solomon_token), 'cpu_usage')

        if self.Parameters.create_emails:
            abc_token = sdk2.Vault.data(self.Parameters.abc_token)
            emails_by_abc = self.create_emails(cpu_by_abc, stats_by_rule, 'cpu_usage')
            # Keyed by (receiver, abc), so a person responsible for several ABC
            # services gets one email per service.
            emails = {}
            for abc, text in emails_by_abc.items():
                for receiver in self.get_receivers(abc, abc_token):
                    emails[(receiver, abc)] = text.replace('<username>', receiver)

            for (login, _), email in emails.items():
                self.send_or_print_email(email, login)

    def _report_ram_and_disk_usage(self, release_task):
        """Compute RAM and disk usage per ABC service, print it, and optionally push it."""
        instances_by_shard = self.count_instances_by_shard(self.Parameters.shards)
        ram_by_abc, ram_by_rule, ram_by_shard = self.get_memory_usage_by_shard(release_task, instances_by_shard)
        self.print_stats(ram_by_abc, ram_by_rule, 'RAM quota used by rule (disk usage is RAM usage * 4)', ' GB')
        self.print_stats(ram_by_abc, ram_by_shard, 'RAM quota used by shard', ' GB')

        disk_by_abc, disk_by_rule, disk_by_shard = self.calculate_disk_usage(ram_by_abc, ram_by_rule, ram_by_shard)
        self.print_stats(disk_by_abc, disk_by_rule, 'Disk quota used by rule', ' GB')
        self.print_stats(disk_by_abc, disk_by_shard, 'Disk quota used by shard', ' GB')

        if self.Parameters.report_to_solomon:
            solomon_token = sdk2.Vault.data(self.Parameters.solomon_token)
            self.push_to_solomon(disk_by_abc, solomon_token, 'disk_usage')
            self.push_to_solomon(ram_by_abc, solomon_token, 'ram_usage')

    def on_execute(self):
        """Load the rule schema from the begemot binary and run the selected reports.

        :raises TaskFailure: if a required released resource cannot be found
        """
        self._refresh_time_window()

        released_binary = sdk2.Resource["BEGEMOT_EXECUTABLE"].find(state='READY', attrs={'released': 'stable'}).first()
        begemot_binary = self.Parameters.begemot_binary or released_binary
        if not begemot_binary:
            raise TaskFailure('No begemot binary given and no stable BEGEMOT_EXECUTABLE release found')

        bgschema = json.loads(subprocess.check_output([
            str(sdk2.ResourceData(begemot_binary).path),
            '--print-bgschema'
        ]))

        self.begemot_info = BegemotInfo(bgschema['RuleInfos'])

        if self.Parameters.count_cpu:
            self._report_cpu_usage()

        if self.Parameters.count_disk:
            # Data sizes are read from the resources of the released build.
            if released_binary is None:
                raise TaskFailure('Cannot count disk/RAM usage: no stable BEGEMOT_EXECUTABLE release found')
            self._report_ram_and_disk_usage(released_binary.task_id)