# -*- coding: utf-8 -*-

import json
from collections import defaultdict

import bot_api
import gencfg_api
import walle_api_wrapper

from infra.yp.account_estimation import (
    Account,
    bytes_to_terabytes,
    ceil_with_precision,
    add_service_account_to_total_account
)

# Bot location segment (``loc_segment2``) -> short DC name used in the quota output.
dcs_mapping = dict(
    MANTSALA='man',
    VLADIMIR='vla',
    SAS='sas',
    IVA='iva',
    MYT='myt',
)

# CPU models whose hosts are refused (unless ignore_banned_lists is set), in their base spelling.
PREPROCESSED_BANNED_CPUS = [
    "epyc 7351",
    "epyc-7351",
    "xeon5530",
    "xeon e5530",
    "xeon5620",
    "xeon e5620",
    "xeon5440",
    "xeon e5440",
    "xeon5645",
    "xeon e5645",
    "xeon5670",
    "xeon e5670",
    "xeon5675",
    "xeon e5675",
    "opteron 6172",
    "opteron 6176",
    "opteron 6274",
]

# Every banned model in four spellings: as written, upper-cased, with spaces removed,
# and with spaces removed + upper-cased. The detected CPU model is checked with a plain
# set lookup, so only these exact spellings match.
BANNED_CPUS = {
    spelling
    for base in PREPROCESSED_BANNED_CPUS
    for spelling in (
        base,
        base.upper(),
        base.replace(' ', ''),
        base.replace(' ', '').upper(),
    )
}

# Wall-E tags that make a host non-transferable (its Wall-E tags are checked in
# QuotaPerHost.__init__ unless ignore_banned_lists is set).
BANNED_WALLE_TAGS = [
    'rtc',
    'runtime',
]

# GPU model name as reported by bot (attribute20, or attribute12 as a fallback)
# -> GPU model key used in the ABC quota account.
# NOTE(review): 'NVIDIA - GTX1080' is mapped to the 1080ti key — confirm this is intentional.
GPU_ABC_MODEL_MAPPING = {
    'TESLA M2090': 'gpu_tesla_m2090',
    'NVIDIA - TESLA M2070': 'gpu_tesla_m2070',
    'NVIDIA - GEFORCE GTX1080TI': 'gpu_geforce_1080ti',
    'NVIDIA - GTX1080': 'gpu_geforce_1080ti',
    'TESLA K40': 'gpu_tesla_k40',
    'TESLA M40': 'gpu_tesla_m40',
    'NVIDIA - TESLA P40': 'gpu_tesla_p40',
    'NVIDIA - TESLA V100': 'gpu_tesla_v100',
    'V100-SXM2-32GB': 'gpu_tesla_v100',
    'NVIDIA - A100': 'gpu_tesla_a100_40',
    'A100-SXM4-80GB': 'gpu_tesla_a100_80',
}

# User-facing error messages (Russian). English meaning and format arguments:
# "Bot has no information about host {name} (identified by {hosttype})."
HOST_NOT_FOUND_ERROR = 'В bot нет информации о хосте {} (опознанном по {}).'
# "Bot returned an unknown DC for host {name}: {loc_segment2}."
DC_NOT_FOUND_ERROR = 'При запросе в бот хоста {} получен неизвестный ДЦ: {}.'
# "Host {name} has no FQDN because it has not been taken out of pre-order yet;
#  take it out of pre-order first."
FQDN_DOES_NOT_EXIST_ERROR = 'Хост {} не имеет FQDN, так как его ещё не забрали из предзаказа. Вам потребуется сначала забрать данный хост из предзаказа.'
# "These hosts cannot be handed over because their CPU is blacklisted".
# Used as the group key of (message, cpu_model, host) error tuples.
CPU_BANNED_ERROR = 'Эти хосты нельзя сдать потому что их CPU находится в чёрном списке'
# "These hosts cannot be handed over because they are in Wall-E projects owned by RTC
#  and tagged with RTC tags". Group key of (message, tags, host) error tuples.
WALLE_TAGS_BANNED_ERROR = 'Эти хосты нельзя сдать, потому что они находятся в wall-e проектах принадлежащих RTC и размеченных тэгами RTC'
# "Hosts without SSD/NVME are not accepted into the dev segment.\nHosts we did not submit
#  for this reason". Group key of (message, host) error tuples.
NO_SSD_ERROR = 'В dev сегмент не допускается ввод хостов без SSD/NVME.\nСписок хостов, которые мы не отправили на ввод по этой причине'
# "No database entry for gpu model: "{gpu_model}"".
GPU_MODEL_NOT_FOUND = 'В базе нет информации для gpu model: "{}"'


class QuotaPerHost:
    """Hardware specs and resulting YP quota of a single host.

    Host properties come either from the caller-supplied ``hostdata`` dict or,
    when it is absent, from a bot lookup by FQDN / inventory number. Problems are
    not raised: they are collected in ``error_messages`` as plain strings or as
    ``(group_message, detail, ...)`` tuples that ``flatten_messages`` groups later.
    """

    def __init__(self, cpumodels, name, hostdata=None, abc_segment='', ignore_banned_lists=False):
        '''
        :param cpumodels: CPU model tables (as returned by gencfg_api.get_cpumodels());
            its 'botmodel_to_model' entry is used to recognise bot CPU names
        :param name: FQDN or inventory number of host
        :param hostdata: optional pre-collected host description; when given, bot is not queried.
            Keys read: 'inv', 'dc', 'cpu_model', 'gpu_model', 'gpu_count', 'ssd'
            (plus 'memory' and 'disk' in calculate_specs)
        :param abc_segment: YP segment name; the 'dev' segment requires SSD/NVME
        :param ignore_banned_lists: skip the banned CPU model and banned Wall-E tag checks
        '''
        self.name = name
        self.hostdata = hostdata
        self.inv = hostdata['inv'] if hostdata else ''
        self.fqdn = name if hostdata else ''
        self.location = hostdata['dc'] if hostdata else ''
        self.abc_segment = abc_segment
        self.error_messages = []
        self.cpu_model = hostdata['cpu_model'] if hostdata else ''
        self.gpu_model = hostdata['gpu_model'] if hostdata else None
        self.gpu_count = int(hostdata['gpu_count']) if hostdata else 0
        self.has_ssd = True if (hostdata and hostdata['ssd']) else False
        # Hosts without caller-provided data are described by bot records.
        self.is_bot_model = False if hostdata else True
        self.hosttype = 'inv' if name.isdigit() else 'fqdn'
        self.account = Account.empty()
        self.bot_data = dict()
        self.jbog_bot_data = dict()
        self.cpumodels = cpumodels

        if not hostdata:
            response = bot_api.get_info_about_host(self.name, self.hosttype)
            if response['res'] == 1:
                self.bot_data = response['data']
                self.inv = self.bot_data['instance_number']
                # JBOG data must be loaded before GPU detection, which falls back to it.
                self.detect_jbog_data()
                self.detect_cpu_model()
                self.detect_gpu_model_and_count()
                self.detect_ssd_presence()
                try:
                    self.location = dcs_mapping[self.bot_data['loc_segment2']]
                except KeyError:
                    self.append_error(DC_NOT_FOUND_ERROR.format(self.name, self.bot_data['loc_segment2']))
                try:
                    self.fqdn = self.bot_data['XXCSI_FQDN']
                except KeyError:
                    self.append_error(FQDN_DOES_NOT_EXIST_ERROR.format(self.name))
                if not ignore_banned_lists and self.cpu_model in BANNED_CPUS:
                    self.append_error((CPU_BANNED_ERROR, self.cpu_model, self.get_host_desc()))
            else:
                self.append_error(HOST_NOT_FOUND_ERROR.format(self.name, self.hosttype))

        # Translate the reported GPU model name into the ABC quota key.
        if self.gpu_model:
            if self.gpu_model not in GPU_ABC_MODEL_MAPPING:
                self.append_error(GPU_MODEL_NOT_FOUND.format(self.gpu_model))
            else:
                self.gpu_model = GPU_ABC_MODEL_MAPPING[self.gpu_model]

        if self.abc_segment == 'dev' and not self.has_ssd:
            self.append_error((NO_SSD_ERROR, self.get_host_desc()))

        if not ignore_banned_lists and self.inv:
            host_walle_info = walle_api_wrapper.get_hosts(invs=[int(self.inv)], fields=['inv', 'tags'])
            if host_walle_info['result']:
                host_walle_tags = set(host_walle_info['result'][0].get('tags', []) or [])
                host_banned_tags = [tag for tag in BANNED_WALLE_TAGS if tag in host_walle_tags]
                if host_banned_tags:
                    self.append_error((WALLE_TAGS_BANNED_ERROR, ','.join(host_banned_tags), self.get_host_desc()))

    def detect_jbog_data(self):
        """Load bot data of an attached GPU node (JBOG), if any.

        Scans the host's 'Connected' items for NODE-GPU/SRV records and stores the bot
        data of the last one that bot resolves into ``jbog_bot_data``.
        """
        # ``or []`` also covers a 'Connected' key that is present but None.
        for connected in self.bot_data.get('Connected') or []:
            if connected['item_segment3'] == 'NODE-GPU' and connected['item_segment4'] == 'SRV':
                jbog_inv = connected['instance_number']
                response = bot_api.get_info_about_host(jbog_inv, hosttype='inv')
                if response['res'] == 1:
                    self.jbog_bot_data = response['data']

    def get_host_desc(self):
        """Return the best available host identifier: inventory number, FQDN, or the given name."""
        if self.inv:
            return self.inv
        elif self.fqdn:
            return self.fqdn
        else:
            return self.name

    def get_error(self):
        """Return the list of collected error messages (strings or tuples)."""
        return self.error_messages

    def append_error(self, error):
        """Record an error: a plain string, or a tuple grouped later by flatten_messages."""
        self.error_messages.append(error)

    def calculate_specs(self):
        """Fill ``self.account`` with raw host specs (no tax) unless errors were recorded.

        :return: the dict from get_host_info()
        """
        if not self.get_error():
            if self.is_bot_model:
                self.calculate_cpu_quota()
                self.calculate_ram_quota()
                self.calculate_disk_quota()
                self.translate_units()
            else:
                self.calculate_cpu_quota()
                self.account.memory = float(self.hostdata['memory'])
                self.account.hdd = float(self.hostdata['disk'])
                self.account.ssd = float(self.hostdata['ssd'])

            if self.gpu_count > 0:
                self.account.gpu_models[self.gpu_model] = self.gpu_count

        return self.get_host_info()

    def _get_cpu_cores(self, use_power=True):
        """Ask gencfg_api for this host's CPU core count."""
        return gencfg_api.get_cpu_cores(self.cpumodels, self.cpu_model, self.is_bot_model, use_power=use_power)

    def _get_cpu_model_variations(self, cpu_name):
        """Return spelling variants of ``cpu_name``: as-is/upper/lower, with and without spaces."""
        trimmed_cpu_name = cpu_name.replace(' ', '')
        return [
            cpu_name,
            cpu_name.upper(),
            cpu_name.lower(),
            trimmed_cpu_name,
            trimmed_cpu_name.upper(),
            trimmed_cpu_name.lower(),
        ]

    def detect_cpu_model(self):
        """Pick the CPU model name from bot data.

        Candidates come, in order, from the first CPU/SRV component (attribute12, then
        "attribute11 attribute12"), from attribute17 of a SERVERS/SRV record, and from
        item_segment2 up to the first '/'. The first candidate known to
        ``cpumodels['botmodel_to_model']`` wins; otherwise the item_segment2 prefix is used.
        """
        if self.is_bot_model:
            cpu_model_candidates = []
            for component in self.bot_data['Components']:
                if component['item_segment3'] == 'CPU' and component['item_segment4'] == 'SRV':
                    components_cpu = component['attribute12']
                    cpu_model_candidates.extend(self._get_cpu_model_variations(components_cpu))
                    alt_components_cpu = '{} {}'.format(component['attribute11'], component['attribute12'])
                    cpu_model_candidates.extend(self._get_cpu_model_variations(alt_components_cpu))
                    break
            if self.bot_data['item_segment3'] == 'SERVERS' and self.bot_data['item_segment4'] == 'SRV':
                segments_cpu = self.bot_data['attribute17']
                cpu_model_candidates.extend(self._get_cpu_model_variations(segments_cpu))
            final_segment_cpu = self.bot_data['item_segment2'].partition('/')[0]
            cpu_model_candidates.extend(self._get_cpu_model_variations(final_segment_cpu))
            for cpu_candidate in cpu_model_candidates:
                if cpu_candidate in self.cpumodels['botmodel_to_model']:
                    self.cpu_model = cpu_candidate
                    break
            if self.cpu_model not in self.cpumodels["botmodel_to_model"]:
                self.cpu_model = final_segment_cpu

    @staticmethod
    def _scan_gpu_components(components):
        """Count GPU/SRV components and return ``(count, model)``.

        The model comes from the last GPU component found: 'attribute20' when it is
        present and not 'N/A', otherwise 'attribute12'. ``model`` is None when no GPU
        is found; ``components`` may be None.
        """
        count = 0
        model = None
        for component in components or []:
            if component['item_segment2'] == 'GPU' and component['item_segment4'] == 'SRV':
                count += 1
                reported = component.get('attribute20')
                model = component['attribute12'] if (not reported or reported == 'N/A') else reported
        return count, model

    def detect_gpu_model_and_count(self):
        """Set gpu_count/gpu_model from the host's bot components, falling back to its JBOG."""
        self.gpu_count, host_model = self._scan_gpu_components(self.bot_data['Components'])
        if self.gpu_count:
            self.gpu_model = host_model
        # No GPU found in the host itself - look in the attached JBOG.
        if self.jbog_bot_data and not self.gpu_model and not self.gpu_count:
            self.gpu_count, jbog_model = self._scan_gpu_components(self.jbog_bot_data.get('Components'))
            if self.gpu_count:
                self.gpu_model = jbog_model

    def detect_ssd_presence(self):
        """Set ``has_ssd`` if any DISKDRIVES/SRV component has attribute16 == 'SSD'."""
        self.has_ssd = False
        for component in self.bot_data['Components']:
            if (component['item_segment3'] == 'DISKDRIVES' and component['item_segment4'] == 'SRV' and
                    component['attribute16'] == 'SSD'):
                self.has_ssd = True
                break

    def calculate_cpu_quota(self):
        """Set account cpu to the gencfg core count (power applied)."""
        self.account.cpu = self._get_cpu_cores()

    def calculate_ram_quota(self):
        """Add attribute13 of every RAM/SRV component to account memory."""
        for component in self.bot_data['Components']:
            if component['item_segment3'] == 'RAM' and component['item_segment4'] == 'SRV':
                self.account.memory += int(component['attribute13'])

    def calculate_disk_quota(self):
        """Sum DISKDRIVES/SRV sizes (attribute14) into account hdd/ssd by attribute16 type."""
        for component in self.bot_data['Components']:
            if component['item_segment3'] == 'DISKDRIVES' and component['item_segment4'] == 'SRV':
                # NOTE: bot reports these sizes in powers of 10
                if component['attribute16'] == 'HDD':
                    self.account.hdd += int(component['attribute14'])
                elif component['attribute16'] == 'SSD':
                    self.account.ssd += int(component['attribute14'])

    def translate_units(self):
        """Convert bot disk sizes to terabytes, rounded up to 2 decimal places."""
        # NOTE: sizes are treated as decimal gigabytes (x 1000**3 -> bytes); bytes_to_terabytes
        # then converts to powers-of-2 terabytes.
        self.account.ssd = ceil_with_precision(bytes_to_terabytes(self.account.ssd * 1000 * 1000 * 1000), 2)
        self.account.hdd = ceil_with_precision(bytes_to_terabytes(self.account.hdd * 1000 * 1000 * 1000), 2)

    def apply_tax(self, discount_mode=None):
        """Turn the host's raw specs into its quota, in place on ``self.account``.

        discount_mode:
          * "full" (default, also for None): subtract a fixed overhead (8 cores / 32G for
            hosts with >= 200 cores, otherwise 2 cores / 7G), then balance cpu against
            memory (4G per core) and cap each disk at 2T scaled by the kept cpu share;
          * "gencfg": balancing and disk capping only, no overhead subtraction;
          * "power_only": gencfg cores with power applied, no overhead, no balancing;
          * "raw": gencfg cores without power, no overhead, no balancing.
        Any other value gets the overhead subtraction only.
        IO and network bandwidth are always derived from the final cpu value.
        """
        # NOTE: before we apply tax, we have just specs of host, after apply tax we have actual quota for this host
        # NOTE: memory is in powers of 2, so no need to recalculate it
        if discount_mode is None:
            discount_mode = "full"

        cpu_original = self._get_cpu_cores(use_power=discount_mode != "raw")
        if discount_mode not in ["raw", "power_only", "gencfg"]:
            if cpu_original >= 200:
                cpu_original -= 8
                self.account.memory -= 32.0
            else:
                cpu_original -= 2
                self.account.memory -= 7.0

        self.account.cpu = cpu_original

        if discount_mode in ["full", "gencfg"]:
            self.account.cpu = ceil_with_precision(min(self.account.cpu, int((self.account.memory) / 4)), 1)
            self.account.memory = 4 * self.account.cpu
            MAX_DISK_SPACE = 2  # 2T
            # float(): under Python 2 int/int division would truncate the kept share to 0 and
            # wipe the disks; a non-positive core count would otherwise raise ZeroDivisionError.
            discount = float(self.account.cpu) / cpu_original if cpu_original > 0 else 0.0
            self.account.hdd = min(self.account.hdd, MAX_DISK_SPACE) * discount
            self.account.ssd = min(self.account.ssd, MAX_DISK_SPACE) * discount

        self.account.io_ssd = self.account.cpu * 5 if self.account.ssd > 0 else 0
        self.account.io_hdd = self.account.cpu * 2 if self.account.hdd > 0 else 0
        self.account.net_bandwidth = self.account.cpu * 7

    def get_host_info(self):
        """Return a summary dict; ``account`` is the live Account object, not a copy.

        cpu_model is blanked when errors were recorded.
        """
        return {'inv': self.inv,
                'fqdn': self.fqdn,
                'location': self.location,
                'cpu_model': self.cpu_model if not self.get_error() else '',
                'account': self.account,
                'gpu_model': self.gpu_model,
                'gpu_count': self.gpu_count}


class HostsQuota:
    """Per-host specs and per-DC quota for a batch of hosts."""

    def __init__(self, hosts, hostdata=None, abc_segment='', ignore_banned_lists=False):
        '''
        First compute each host's individual specs so the user can see what they brought.
        Then compute the quota and group it by DC (see get_total_quota).

        :param hosts: list of FQDNs / inventory numbers
        :param hostdata: optional mapping host -> pre-collected host description
            (passed to QuotaPerHost; when empty/None, bot is queried per host)
        :param abc_segment: YP segment name, forwarded to QuotaPerHost
        :param ignore_banned_lists: forwarded to QuotaPerHost
        '''
        self.hosts = hosts
        self.host_specs = []
        self.errors = []
        self.by_dc = dict()
        cpumodels = gencfg_api.get_cpumodels()
        processed = set()
        for host in self.hosts:
            host_description = hostdata[host] if hostdata else None
            host_spec = QuotaPerHost(cpumodels, host, host_description, abc_segment=abc_segment,
                                     ignore_banned_lists=ignore_banned_lists)
            # Skip a host already seen under the same inventory number (e.g. listed by inv and by FQDN);
            # hosts without an inventory number are never deduplicated.
            if host_spec.inv == '' or host_spec.inv not in processed:
                processed.add(host_spec.inv)
                host_spec.calculate_specs()
                if host_spec.get_error():
                    self.errors.extend(host_spec.get_error())
                else:
                    self.host_specs.append(host_spec)
        self.errors = flatten_messages(self.errors)

    def get_individual_specs(self):
        """Return get_host_info() dicts of all accepted hosts."""
        return [host_spec.get_host_info() for host_spec in self.host_specs]

    def get_total_quota(self, discount_mode=None):
        """Apply tax to every host and return the quota grouped by DC plus a 'total' entry.

        The result is computed once and cached in ``self.by_dc``; later calls ignore
        ``discount_mode`` and only re-round ssd/hdd to 3 decimal places.
        """
        if not self.by_dc:
            # NOTE: first we calculate specs of individual hosts, show it to the user,
            # then we apply tax for every host and then we calculate quota
            for host_spec in self.host_specs:
                host_spec.apply_tax(discount_mode=discount_mode)
            for host_spec in self.host_specs:
                host_info = host_spec.get_host_info()
                add_service_account_to_total_account({host_info['location']: host_info['account']}, self.by_dc)
            self.by_dc['total'] = Account.empty()
            for host_spec in self.host_specs:
                self.by_dc['total'].add(host_spec.get_host_info()['account'])

        for dc in self.by_dc:
            self.by_dc[dc].ssd = ceil_with_precision(self.by_dc[dc].ssd, 3)
            self.by_dc[dc].hdd = ceil_with_precision(self.by_dc[dc].hdd, 3)

        return self.by_dc

    def get_errors(self):
        """Return the flattened error messages collected in __init__."""
        return self.errors


def calculate_and_format_quota_from_hostsdata(hosts, hostsdata, abc_info, abc_segment, discount_mode=None,
                                              use_format="text", ignore_banned_lists=False):
    """Compute and format quota for hosts described by pre-collected ``hostsdata``.

    The per-host table and the ticket wrapper are omitted (console-style output).

    :return: the ``(report, HostsQuota)`` pair produced by _calculate_and_format_quota
    """
    hosts_quota = HostsQuota(
        hosts,
        hostsdata,
        abc_segment=abc_segment,
        ignore_banned_lists=ignore_banned_lists,
    )
    return _calculate_and_format_quota(
        hosts_quota,
        hosts_quota.get_individual_specs(),
        abc_info,
        abc_segment,
        hosts_stat=False,
        console_flag=True,
        discount_mode=discount_mode,
        use_format=use_format,
    )


def calculate_and_format_quota_from_hosts_file(hosts, abc_info, abc_segment, console_flag=False, discount_mode=None,
                                               use_format="text", ignore_banned_lists=False):
    """Look every host up in bot, then compute and format its quota.

    The per-host table is always included; ``console_flag`` controls the ticket wrapper.

    :return: the ``(report, HostsQuota)`` pair produced by _calculate_and_format_quota
    """
    hosts_quota = HostsQuota(
        hosts,
        abc_segment=abc_segment,
        ignore_banned_lists=ignore_banned_lists,
    )
    return _calculate_and_format_quota(
        hosts_quota,
        hosts_quota.get_individual_specs(),
        abc_info,
        abc_segment,
        hosts_stat=True,
        console_flag=console_flag,
        discount_mode=discount_mode,
        use_format=use_format,
    )


def flatten_messages(messages):
    """
    Group tuple messages by their first element and flatten everything into strings.

    Plain (non-tuple) messages are passed through first, in their original order.
    Then one string per group follows, with groups in order of first appearance.
    A ``(group, detail)`` tuple contributes ``detail``. A longer tuple
    ``(group, a, b, ...)`` contributes ``(a, b, ...)``, which is grouped again
    recursively by ``a``::

       turns:
          [
             'Problem 1',
             ('Problem 2', '444'),
             ('Problem 3', '123'),
             ('Problem 2', '555'),
             ('Problem 3', '456'),
             'Problem 4',
          ]
       to:
          [
             'Problem 1',
             'Problem 4',
             'Problem 2: 444, 555',
             'Problem 3: 123, 456',
          ]

    Nested groups are rendered one per line under a "group:" header line.

    :param messages: list of messages (strings and/or tuples)
    :return: list of flattened messages
    """
    try:
        string_types = (str, unicode)  # noqa: F821 -- Python 2
    except NameError:  # Python 3 has no ``unicode``
        string_types = (str,)

    out = []
    grouped = defaultdict(list)
    # Python 2 dicts iterate in arbitrary order; remember first-seen group order explicitly.
    group_order = []
    for row in messages:
        if isinstance(row, tuple):
            if row[0] not in grouped:
                group_order.append(row[0])
            grouped[row[0]].append(row[1:] if len(row) > 2 else row[1])
        else:
            out.append(row)

    for group in group_order:
        rows = grouped[group]
        # Leaf level: join details on one line; otherwise put each sub-group on its own line.
        its_final_level = all(isinstance(f, string_types) for f in rows)
        level_separator = ', ' if its_final_level else '\n'
        group_separator = ': ' if its_final_level else ':\n'
        out.append('{}{}{}'.format(group, group_separator, level_separator.join(flatten_messages(rows))))
    return out


def _calculate_and_format_quota(quotas, individuals, abc_info, abc_segment, hosts_stat=True, console_flag=False,
                                discount_mode=None, use_format="text"):
    # create description for ticket
    description = []
    if not console_flag:
        description.append('**ABC сервис:** {}'.format(abc_info))
        description.append('**Сегмент YP:** {}'.format(abc_segment))
        description.append('**Инвентарные номера / FQDN**')
        description.append('%%\n{}\n%%'.format('\n'.join([desc['inv'] + '/' + desc['fqdn'] for desc in individuals])))
        description.append('**Расчет квоты**')
        description.append('%%')

    if hosts_stat:  # quota by host
        description.append('LOC|INV|FQDN|CPUMODEL|RAM|SSD|HDD|IOSSD|IOHDD|NET_BANDWIDTH|GPU_MODEL|GPU_COUNT')
        for desc in individuals:
            row = '\t'.join(map(str,
                                [
                                    desc['location'].upper(),
                                    desc['inv'],
                                    desc['fqdn'],
                                    desc['cpu_model'],
                                    desc['account'].memory,
                                    desc['account'].ssd,
                                    desc['account'].hdd,
                                    desc['account'].io_ssd,
                                    desc['account'].io_hdd,
                                    desc['account'].net_bandwidth,
                                    desc['gpu_model'],
                                    desc['gpu_count'],
                                ])
                            )
            description.append(row)
        description.append('')

    # get_total_quota change individuals it applys taxes
    by_dc = quotas.get_total_quota(discount_mode=discount_mode)
    if use_format == "json":
        hosts_data = []
        for host_info in individuals:
            hosts_data.append({
                'inv': host_info["inv"],
                'fqdn': host_info["fqdn"],
                'dc': host_info["location"],
                'cpu_model': host_info["cpu_model"],
                'quota': {
                    'cpu': host_info["account"].cpu,
                    'memory': host_info["account"].memory,
                    'ssd': host_info["account"].ssd,
                    'hdd': host_info["account"].hdd,
                    'io_ssd': host_info["account"].io_ssd,
                    'io_hdd': host_info["account"].io_hdd,
                    'net_bandwidth': host_info["account"].net_bandwidth,
                    'gpu_model': host_info["gpu_model"],
                    'gpu_count': host_info["gpu_count"]
                }
            })

        quota_data = dict([(dc, quota.__dict__) for dc, quota in by_dc.iteritems()])

        res = {
            "hosts": hosts_data,
            "quota": quota_data
        }

        return json.dumps(res, indent=4, sort_keys=True), quotas

    # append info by dc
    for dc, quota in by_dc.iteritems():
        if dc == 'total':
            continue
        description.append('{}: {} cores, {}G memory, {}T ssd, {}T hdd, {}MB/s io ssd, {}MB/s io hdd, {}MB/s net bandwidth, GPU models: {}'.format(
            dc.upper(),
            quota.cpu,
            quota.memory,
            quota.ssd,
            quota.hdd,
            quota.io_ssd,
            quota.io_hdd,
            quota.net_bandwidth,
            dict(quota.gpu_models),
        ))
    # add total
    description.append('')
    description.append('{}:\t{} cores, {}G memory, {}T ssd, {}T hdd, {}MB/s io ssd, {}MB/s io hdd, {}MB/s net bandwidth, GPU models: {}'.format(
        'TOTAL',
        by_dc['total'].cpu,
        by_dc['total'].memory,
        by_dc['total'].ssd,
        by_dc['total'].hdd,
        by_dc['total'].io_ssd,
        by_dc['total'].io_hdd,
        by_dc['total'].net_bandwidth,
        dict(by_dc['total'].gpu_models),
    ))
    if not console_flag:
        description.append('%%')

    if not console_flag:
        description.append('<{inv')
        description.append('\n'.join([desc['inv'] for desc in individuals]))
        description.append('}>')

        description.append('<{form')
        description.append('\n'.join(quotas.hosts))
        description.append('}>')

    return '\n'.join([line.decode('utf-8') for line in description]), quotas
