#!/usr/bin/python

import json
import time
import sys
import subprocess
import logging

from sandbox.common.errors import TemporaryError, TaskError
from sandbox.common import rest

from sandbox.projects.yabs.qa.performance.memory_utils.smaps import parse_smaps

# Sandbox resource on which warm_task_up records per-host THP warm-up scores
# as attributes: name = short hostname, value = '<int percent>|<unix time>'.
THP_QUALITY_RES_ID = 162816911


def warmup(warmer_executable, size=80, threads=30):
    """Run the THP warmer binary once and summarise the smaps text it prints.

    :param warmer_executable: path to the warmer binary.
    :param size: value passed to the warmer's ``-s`` option.
    :param threads: value passed to the warmer's ``-t`` option.
    :return: the ``(hg_mappings, anonymous_total, anon_huge_pages_total)``
        tuple computed by ``get_smaps_stats`` from the warmer's stdout.
    :raises subprocess.CalledProcessError: if the warmer exits non-zero.
    """
    command = [warmer_executable, '-s', str(size), '-t', str(threads)]
    stdout = subprocess.check_output(command)
    return get_smaps_stats(stdout.splitlines())


def get_smaps_stats(smaps_lines):
    """Aggregate huge-page statistics from smaps-formatted text.

    The lines are parsed with ``parse_smaps``. Each non-empty section is
    indexed by smaps field name. Sections lacking ``VmFlags`` are logged and
    skipped. Only sections whose VmFlags include ``hg`` are counted; per
    proc(5), ``hg`` marks MADV_HUGEPAGE-advised mappings.

    :param smaps_lines: lines of smaps output (e.g. as printed by the warmer).
    :return: tuple ``(hg_mappings, anonymous_total, anon_huge_pages_total)``.
        The two totals are floats summed from the ``Anonymous`` and
        ``AnonHugePages`` fields. Units are presumably kB, as in
        /proc/<pid>/smaps; TODO confirm against parse_smaps.
    """
    hg_sections = 0
    anon_sum = 0.0
    huge_sum = 0.0
    for section in parse_smaps(smaps_lines):
        if not section:
            continue
        try:
            vm_flags = set(section['VmFlags'].split())
        except KeyError:
            logging.warning("Skipped section with no VmFlags:\n%s", json.dumps(section))
            continue
        if 'hg' in vm_flags:
            hg_sections += 1
            anon_sum += int(section['Anonymous'])
            huge_sum += int(section['AnonHugePages'])
    return hg_sections, anon_sum, huge_sum


# task.ctx key under which check_required_ram accumulates the short hostnames
# of hosts where the required-RAM allocation check failed.
_REQUIRED_RAM_CHECK_FAILS_KEY = '__required_ram_check_failed'


def check_required_ram(task, warmer_res_id=162057058, threads=24):  # FIXME unhardcode res id
    """Check that the host can actually allocate most of the task's required RAM.

    Syncs the warmer binary from Sandbox resource ``warmer_res_id``. It then
    asks the binary to allocate ``int(0.93 * task.required_ram / 1024)`` units
    via ``-s``. The value is logged as GiB, so ``required_ram`` is presumably
    in MiB; TODO confirm.

    On failure:
    - the warmer's exit code and output are logged and shown via ``task.set_info``;
    - the host's short name is appended to the list stored in
      ``task.ctx[_REQUIRED_RAM_CHECK_FAILS_KEY]``.

    :param task: Sandbox task (uses ``sync_resource``, ``required_ram``,
        ``ctx``, ``set_info`` and ``client_info['fqdn']``).
    :param warmer_res_id: Sandbox resource id of the warmer binary.
    :param threads: number of warmer threads (``-t``). Defaults to the
        previously hard-coded 24.
    :return: True if the allocation succeeded, False if the warmer exited non-zero.
    """
    thp_warmer_executable = task.sync_resource(warmer_res_id)
    # Ask for slightly less than the full requirement to leave headroom.
    required_ram_GiB = int(0.93 * task.required_ram / 1024)
    try:
        subprocess.check_output([thp_warmer_executable, '-s', str(required_ram_GiB), '-t', str(threads)])
    except subprocess.CalledProcessError as err:
        msg = "Required RAM check failed. Exitcode %s, process output:\n%s" % (err.returncode, err.output)
        logging.exception(msg)
        task.set_info(msg)

        hostname, _, _ = task.client_info['fqdn'].partition('.')
        # Copy so we never mutate the stored value in place and tolerate any
        # non-list sequence that may have been stored in ctx.
        fails = list(task.ctx.get(_REQUIRED_RAM_CHECK_FAILS_KEY, []))
        fails.append(hostname)
        task.ctx[_REQUIRED_RAM_CHECK_FAILS_KEY] = fails
        return False

    logging.info("Required RAM check complete: successfully allocated %s GiB of memory", required_ram_GiB)
    return True


def warm_task_up(task, ahp_required_gib=70, timeout=240, warmer_res_id=162057058):  # FIXME unhardcode res id
    """Repeatedly run the THP warmer until enough AnonHugePages are acquired.

    Each iteration runs ``warmup`` with the current request size, starting at
    1 (passed to the warmer's ``-s``; presumably GiB). Whenever more than 80%
    of anonymous memory in ``hg`` mappings is huge-page backed, the request
    size doubles, capped at ``ahp_required_gib``. The loop stops once
    AnonHugePages (smaps totals divided by 2**20, logged as GiB) reach
    ``ahp_required_gib``, or once ``timeout`` seconds have passed. The deadline
    is only checked between warmer runs, so the total time may exceed it.

    After every run ``task.descr`` is set to ``'THP=<x> GiB <original descr>'``.

    The final score ``100 * ahp / ahp_required_gib`` is stored as attribute
    ``<short hostname>`` = ``'<int percent>|<unix time>'`` on resource
    ``THP_QUALITY_RES_ID``. AttributeSettingError is logged and ignored; other
    errors from ``_set_resource_attr`` (e.g. TemporaryError) propagate.

    :return: the score as a float percentage (may exceed 100).
    """
    executable = task.sync_resource(warmer_res_id)
    give_up_at = time.time() + timeout
    initial_descr = task.descr

    request_size = 1.0
    while True:
        hg_maps, anon, ahp = warmup(executable, int(request_size))
        anon = anon / 2**20
        ahp = ahp / 2**20
        if anon:
            huge_ratio = 1.0 * ahp / anon
        else:
            huge_ratio = 0
        logging.info(
            "Acquired %d HP mappings, AnonHugePages %.2f GiB, Anonymous %.2f GiB, AHP/Anon=%.4f",
            hg_maps, ahp, anon, huge_ratio,
        )
        task.descr = 'THP={:.2f} GiB {}'.format(ahp, initial_descr)
        enough = ahp >= ahp_required_gib
        if enough or time.time() > give_up_at:
            break
        # Only grow the request when the kernel is keeping up with huge pages.
        if huge_ratio > 0.8:
            request_size = min(request_size * 2, ahp_required_gib)

    logging.info("Finally acquired %s GiB of AnonHugePages", ahp)
    ahp_percent = 100.0 * ahp / ahp_required_gib
    short_host = task.client_info['fqdn'].split('.', 1)[0]
    score = '{}|{}'.format(int(ahp_percent), int(time.time()))
    try:
        _set_resource_attr(THP_QUALITY_RES_ID, short_host, score)
    except AttributeSettingError:
        logging.exception("Failed to update attribute of score resource")

    return ahp_percent


class AttributeSettingError(TaskError):
    """Raised by _set_resource_attr when a resource attribute could be neither created nor updated."""


def _set_resource_attr(res_id, name, value):
    """Create or overwrite attribute ``name`` = ``str(value)`` on Sandbox resource ``res_id``.

    Creation is tried first. If it fails with anything other than
    TemporaryError, the existing attribute is updated instead.

    :raises TemporaryError: re-raised unchanged from either REST call.
    :raises AttributeSettingError: if both the create and the update failed.
    """
    client = rest.Client()

    payload = dict(name=name, value=str(value))
    try:
        client.resource[res_id].attribute.create(payload)
    except Exception as create_error:
        if isinstance(create_error, TemporaryError):
            raise
        # Creation failed for a non-transient reason (e.g. it already exists): try updating.
        try:
            client.resource[res_id].attribute[name].update(payload)
        except Exception as update_error:
            if isinstance(update_error, TemporaryError):
                raise
            raise AttributeSettingError(
                "Failed to create resource attr: %s and failed to update it too: %s" % (create_error, update_error)
            )


def _main():
    """Standalone mode: run the warmer forever, printing one stats line per run.

    Usage: ``<this script> <warmer_executable>``.

    Each printed line contains, space-separated:
    - the unix time;
    - the number of ``hg`` mappings;
    - AnonHugePages / 2**20;
    - Anonymous / 2**20;
    - the AHP/Anon ratio.
    """
    if len(sys.argv) < 2:
        # Explicit usage error instead of an IndexError traceback.
        sys.exit("Usage: %s <warmer_executable>" % sys.argv[0])
    warmer_executable = sys.argv[1]

    logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(message)s")
    while True:
        count, anon, ahp = warmup(warmer_executable)
        print("%.2f %d %.2f %.2f %.4f" % (time.time(), count, ahp / 2**20, anon / 2**20, 1.0 * ahp / anon if anon else 0))
        # Flush each sample so it shows up immediately when stdout is a pipe/file
        # (otherwise output is block-buffered in this never-ending loop).
        sys.stdout.flush()


# Standalone mode: `<this script> <warmer_executable>` runs _main's stats loop.
if __name__ == '__main__':
    _main()
