import contextlib
import numbers
import time

import porto


# Socket-level failures while talking to the porto daemon; repeat_on_dead_porto
# treats these as transient and retries on them.
PortoBroken = (
    porto.exceptions.SocketError,
    porto.exceptions.SocketTimeout,
)
# Any porto failure: porto.exceptions.EError plus the socket-level failures above.
PortoBaseError = (porto.exceptions.EError,) + PortoBroken
ContainerDoesNotExist = porto.exceptions.ContainerDoesNotExist
ContainerAlreadyExists = porto.exceptions.ContainerAlreadyExists
# Errors porto may raise when a limit value is rejected
# (presumably a malformed value or an unavailable resource -- inferred from names only).
LimitProblems = (
    porto.exceptions.InvalidValue,
    porto.exceptions.ResourceNotAvailable,
)


def null_on_raise(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return its result, or ``None`` if it fails.

    Only ordinary errors (subclasses of ``Exception``) are turned into ``None``.
    ``KeyboardInterrupt``, ``SystemExit`` and ``GeneratorExit`` propagate, so that
    Ctrl-C and interpreter shutdown are not silently ignored. The previous bare
    ``except:`` swallowed them.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        return None


def repeat_on_dead_porto(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` until it finishes without a PortoBroken error.

    A PortoBroken error (porto socket failure or timeout) triggers a one-second
    sleep and another attempt. There is no retry cap, so this blocks for as long
    as porto stays unreachable. Any other exception propagates immediately with
    its original traceback. The old ``except Exception as ex: raise ex``
    re-raise added nothing and, on Python 2, replaced the traceback.

    :return: whatever ``func`` returns on the first successful call.
    """
    while True:
        try:
            return func(*args, **kwargs)
        except PortoBroken:
            time.sleep(1)


def get_connection(timeout=120):
    """Create a porto.Connection, connect it, and hand it back.

    :param timeout: forwarded as ``timeout`` to porto.Connection (default 120).
    :return: the already-connected connection object.
    """
    connection = porto.Connection(timeout=timeout)
    connection.connect()
    return connection


@contextlib.contextmanager
def connect():
    """Yield a porto.Connection and disconnect it when the ``with`` block exits.

    The connection is built with porto.Connection defaults and ``connect()`` is
    not called here. NOTE(review): this presumably relies on porto connecting
    lazily on first use -- confirm. ``disconnect()`` runs whether the body exits
    normally or by an exception. If creating the connection itself fails, the
    error propagates and nothing is disconnected.
    """
    conn = porto.Connection()
    try:
        yield conn
    finally:
        if conn:
            conn.disconnect()


def cores_to_percent(cpu):
    """Turn a porto core string such as ``'2c'`` into percent of one core.

    The final character (the unit suffix) is dropped and the remaining number is
    multiplied by 100, e.g. ``'1.5c'`` -> ``150.0``.
    """
    cores = float(cpu[:-1])
    return cores * 100


def percent_to_cores(number):
    """Convert a CPU amount in percent of one core into a porto core string.

    ``50`` -> ``b'0.5c'``; ``200`` -> ``b'2.0c'``.

    The number is formatted with ``str`` and ASCII-encoded instead of being
    passed to ``bytes()``. On Python 2 the result is byte-identical to the old
    code. On Python 3, ``bytes(<float>)`` raises TypeError.
    """
    cores = number / 100.0
    return str(cores).encode('ascii') + b'c'


def get_cpu_limit(cont_name, connection=None):
    """Return the ``cpu_limit`` of *cont_name* in percent of one core.

    The raw porto value is converted with cores_to_percent. A new connection is
    opened via get_connection() when *connection* is not given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_limit = connection.GetProperty(bytes(cont_name), b'cpu_limit')
    return cores_to_percent(raw_limit)


def get_mem_limit(cont_name, connection=None):
    """Return the ``memory_limit`` property of *cont_name* converted to ``long``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_limit = connection.GetProperty(bytes(cont_name), b'memory_limit')
    return long(raw_limit)


def set_cpu_limit(cont_name, percents, connection=None):
    """Set the ``cpu_limit`` of *cont_name* from a value in percent of one core.

    *percents* is converted with percent_to_cores before it is sent. A new
    connection is opened via get_connection() when *connection* is not given
    (or is falsy).

    :return: whatever connection.SetProperty returns.
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    cores = percent_to_cores(percents)
    return connection.SetProperty(name, b'cpu_limit', cores)


def set_mem_limit(cont_name, value, connection=None):
    """Set the porto ``memory_limit`` property of *cont_name*.

    :param value: a number (any ``numbers.Integral`` -- int, Python 2 ``long``
        such as the value get_mem_limit returns, bool -- or a float), which is
        formatted as its decimal text; anything else (e.g. an already-formatted
        string) is passed to porto unchanged.
    :param connection: an open connection; a new one is opened via
        get_connection() when not given (or falsy).
    :return: whatever connection.SetProperty returns.
    """
    connection = connection or get_connection()
    if isinstance(value, (numbers.Integral, float)):
        # The old ``(float, int)`` check let Python 2 ``long`` through unconverted.
        # str().encode() is used because bytes(n) builds n NUL bytes on Python 3.
        value = str(value).encode('ascii')
    return connection.SetProperty(bytes(cont_name), b'memory_limit', value)


def get_cpu_guarantee(cont_name, connection=None):
    """Return the ``cpu_guarantee`` of *cont_name* in percent of one core.

    The raw porto value is converted with cores_to_percent. A new connection is
    opened via get_connection() when *connection* is not given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_guarantee = connection.GetProperty(bytes(cont_name), b'cpu_guarantee')
    return cores_to_percent(raw_guarantee)


def get_cpu_policy(cont_name, connection=None):
    """Return porto's ``cpu_policy`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'cpu_policy')


def get_mem_guarantee(cont_name, connection=None):
    """Return the ``memory_guarantee`` property of *cont_name* converted to ``long``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_guarantee = connection.GetProperty(bytes(cont_name), b'memory_guarantee')
    return long(raw_guarantee)


def set_cpu_guarantee(cont_name, percents, connection=None):
    """Set the ``cpu_guarantee`` of *cont_name* from a value in percent of one core.

    *percents* is converted with percent_to_cores before it is sent. A new
    connection is opened via get_connection() when *connection* is not given
    (or is falsy).

    :return: whatever connection.SetProperty returns.
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    cores = percent_to_cores(percents)
    return connection.SetProperty(name, b'cpu_guarantee', cores)


def set_mem_guarantee(cont_name, value, connection=None):
    """Set the porto ``memory_guarantee`` property of *cont_name*.

    :param value: a number (any ``numbers.Integral`` -- int, Python 2 ``long``
        such as the value get_mem_guarantee returns, bool -- or a float), which
        is formatted as its decimal text; anything else (e.g. an
        already-formatted string) is passed to porto unchanged.
    :param connection: an open connection; a new one is opened via
        get_connection() when not given (or falsy).
    :return: whatever connection.SetProperty returns.
    """
    connection = connection or get_connection()
    if isinstance(value, (numbers.Integral, float)):
        # The old ``(float, int)`` check let Python 2 ``long`` through unconverted.
        # str().encode() is used because bytes(n) builds n NUL bytes on Python 3.
        value = str(value).encode('ascii')
    return connection.SetProperty(bytes(cont_name), b'memory_guarantee', value)


def find_container(cont_name, connection=None):
    """Look up *cont_name* with connection.Find and return its result.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.Find(name)


def get_state(cont_name, connection=None):
    """Return porto's ``state`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'state')


def get_mem_usage(cont_name, connection=None):
    """Return the ``memory_usage`` property of *cont_name* converted to ``long``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_usage = connection.GetProperty(bytes(cont_name), b'memory_usage')
    return long(raw_usage)


def get_cpu_usage(cont_name, connection=None):
    """Return the ``cpu_usage`` property of *cont_name* converted to ``long``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_usage = connection.GetProperty(bytes(cont_name), b'cpu_usage')
    return long(raw_usage)


def _get_total(cont_name, property_name, connection, str_in_key=''):
    res = connection.GetProperty(bytes(cont_name), bytes(property_name)).split(';')
    total = 0
    for kv in res:
        kv = kv.strip().split(':')
        if len(kv) == 2 and str_in_key in kv[0]:
            total += int(kv[1].strip())
    return total


def get_total_net_limit(cont_name, connection=None):
    """Return the sum of every entry of *cont_name*'s ``net_limit`` property.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    return _get_total(cont_name, b'net_limit', connection)


def get_total_net_tx_bytes(cont_name, connection=None):
    """Sum the ``net_tx_bytes`` entries of *cont_name* whose key contains ``'group'``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    return _get_total(cont_name, b'net_tx_bytes', connection, str_in_key='group')


def get_total_net_rx_bytes(cont_name, connection=None):
    """Sum the ``net_rx_bytes`` entries of *cont_name* whose key contains ``'group'``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    return _get_total(cont_name, b'net_rx_bytes', connection, str_in_key='group')


def get_total_io_write(cont_name, connection=None):
    """Return the sum of every entry of *cont_name*'s ``io_write`` property.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    return _get_total(cont_name, b'io_write', connection)


def get_total_io_read(cont_name, connection=None):
    """Return the sum of every entry of *cont_name*'s ``io_read`` property.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    return _get_total(cont_name, b'io_read', connection)


def get_virt_mode(cont_name, connection=None):
    """Return porto's ``virt_mode`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'virt_mode')


def get_root(cont_name, connection=None):
    """Return porto's ``root`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'root')


def get_hostname(cont_name, connection=None):
    """Return porto's ``hostname`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'hostname')


def get_env(cont_name, connection=None):
    """Return porto's ``env`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'env')


def get_command(cont_name, connection=None):
    """Return porto's ``command`` property for *cont_name*, unconverted.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    name = bytes(cont_name)
    return connection.GetProperty(name, b'command')


def get_time(cont_name, connection=None):
    """Return porto's ``time`` property for *cont_name* converted to ``int``.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    raw_time = connection.GetProperty(bytes(cont_name), b'time')
    return int(raw_time)


def list_containers(mask=None, connection=None):
    """Return connection.List(mask), i.e. the containers porto lists for *mask*.

    *mask* is forwarded as-is (``None`` by default). A new connection is opened
    via get_connection() when *connection* is not given (or is falsy).
    """
    if not connection:
        connection = get_connection()
    return connection.List(mask)


def destroy_container(cont_name, connection=None):
    """Destroy container *cont_name*, retrying while porto is unreachable.

    The Destroy call goes through repeat_on_dead_porto. When *connection* is not
    given (or is falsy), one is opened with get_connection(10), i.e. a shorter
    timeout than the default 120. Returns ``None``.
    """
    if not connection:
        connection = get_connection(10)
    name = bytes(cont_name)
    repeat_on_dead_porto(connection.Destroy, name)


def ensure_container_path(path, connection=None):
    """Create every ancestor container of *path*, skipping ones that already exist.

    For ``'a/b/c'`` this creates ``'a'`` and then ``'a/b'``. The last path
    component itself is not created. Each Create call goes through
    repeat_on_dead_porto, and ContainerAlreadyExists is ignored. A new
    connection is opened via get_connection() when *connection* is not given
    (or is falsy).
    """
    if not connection:
        connection = get_connection()

    parents = path.split('/')[:-1]
    prefix = None
    for part in parents:
        # Grow the prefix one component at a time: 'a', 'a/b', ...
        prefix = part if prefix is None else prefix + '/' + part
        try:
            repeat_on_dead_porto(connection.Create, bytes(prefix))
        except ContainerAlreadyExists:
            pass


def unlink_volume(path, connection=None):
    """Call connection.DestroyVolume for *path*, retrying while porto is unreachable.

    A new connection is opened via get_connection() when *connection* is not
    given (or is falsy).

    :return: whatever DestroyVolume returns.
    """
    if not connection:
        connection = get_connection()
    volume_path = bytes(path)
    return repeat_on_dead_porto(connection.DestroyVolume, volume_path)


def ensure_destroyed(porto_name):
    """Destroy container *porto_name*; a container that is already gone is not an error.

    ContainerDoesNotExist is swallowed. Any other error from destroy_container
    propagates. Returns ``None``.
    """
    try:
        destroy_container(porto_name)
    except ContainerDoesNotExist:
        pass
