import os
import py
import re
import json
import socket
import tempfile
import subprocess as sp

from sandbox import common


# Regexes that pull a revision number out of `svn info` output.
# Index 0 matches the "Revision:" field, index 1 the "Last Changed Rev:" field.
# get_path_revision() picks an entry based on its `last_changed_rev` flag
# (False -> index 0, True -> index 1), so the order of the entries matters.
REVISION_PATTERNS = [
    re.compile(ur'Revision:\s(?P<revision>\d+)'),
    re.compile(ur'Last Changed Rev:\s(?P<revision>\d+)')
]


def print_task_type_hosts(task_type, clients_list):
    print '=' * 20, task_type, '=' * 20
    for c in clients_list:
        print c['hostname'], c['platform'], c['model']
    print '=' * 40


def make_pytest_opts(args, unparsed_args=None):
    """
    Translate parsed command-line arguments into a pytest argument list.

    Every public attribute of `args` except ``func`` and ``clean`` whose value
    is truthy becomes a long option: a value of exactly ``True`` yields a bare
    flag (``--name``), any other value yields ``--name=value``. Underscores in
    attribute names are replaced with dashes. Falsy values (``None``,
    ``False``, ``0``, empty strings) are omitted.

    :param args: namespace-like object (e.g. result of ``ArgumentParser.parse_args``)
    :param unparsed_args: extra arguments appended verbatim at the end
    :return: list of option strings, in alphabetical order of attribute name,
        followed by the fixed options and `unparsed_args`
    """
    excluded = ("func", "clean")
    options = []
    # dir() returns names sorted, so the generated command line is deterministic
    for name in dir(args):
        if name.startswith("_") or name in excluded:
            continue
        value = getattr(args, name, None)
        if not value:
            continue
        suffix = '' if value is True else '=' + str(value)
        options.append('--{}{}'.format(name.replace('_', '-'), suffix))

    # getattr with a default: a namespace without `verbose` must not raise
    if getattr(args, "verbose", False):
        # presumably pytest-xdist's `-n`: run tests in-process when verbose — confirm
        options.append('-n0')
    options.append('--local-arcadia-user')
    if unparsed_args:
        options.extend(unparsed_args)
    return options


def is_port_free(port):
    """
    Checks whether port is free on the local IPv4 loopback interface.

    A successful TCP connection to ``127.0.0.1:port`` means something is
    listening there, so the port is considered busy. Any socket error
    (including connection refused) is treated as "free".

    :param port: TCP port number to probe
    :return: `True` if port is free, else `False`
    """
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('127.0.0.1', port))
    except socket.error:
        return True
    finally:
        # close on both paths; previously the socket leaked when connect() failed
        if sock is not None:
            sock.close()
    return False


def get_path_revision(path, last_changed_rev=False):
    """
    Return the SVN revision of `path` as a string, or '' if it is unavailable.

    Runs ``svn info <path>`` under the C locale so the field labels are the
    English ones the regexes in ``REVISION_PATTERNS`` expect, then extracts
    either the "Revision:" field or, when `last_changed_rev` is true, the
    "Last Changed Rev:" field.

    :param path: path passed to ``svn info``
    :param last_changed_rev: select "Last Changed Rev" instead of "Revision"
    :return: the revision digits, or '' if svn is missing, exits non-zero,
        or its output has no matching field
    """
    # Inherit the caller's environment so PATH (and therefore the `svn` binary)
    # is still resolvable; only override the locale. LC_ALL outranks LANG, so
    # both are forced to keep the output labels parseable.
    env = dict(os.environ, LANG='C', LC_ALL='C')
    try:
        with open(os.devnull, 'wb') as devnull:
            svn_info = sp.check_output(['svn', 'info', path], stderr=devnull, env=env)
    except (sp.CalledProcessError, OSError):
        return ''
    pattern = REVISION_PATTERNS[1 if last_changed_rev else 0]
    match = pattern.search(svn_info)
    return match.group('revision') if match else ''


class PortRange(object):
    """Inclusive range of port numbers; a single port when `range_end` is omitted.

    Instances order by their lower bound (so a list of them can be sorted)
    and support ``port in rng`` membership tests.
    """

    def __init__(self, range_start, range_end=None):
        self.start = range_start
        self.end = range_start if range_end is None else range_end

    def __lt__(self, other):
        # only the lower bound participates in ordering
        return self.start < other.start

    def __contains__(self, item):
        # both bounds are inclusive
        return item >= self.start and item <= self.end

    def __repr__(self):
        if self.start == self.end:
            return "{}".format(self.start)
        return "{}-{}".format(self.start, self.end)


class Ports(object):
    """Allocates the TCP ports used by one sandbox installation.

    ``initialize`` assigns the service API to the given base port and every
    other service to a fixed offset from it. If a port at its offset is
    outside the valid range or is forbidden, the next allowed port above it
    is taken instead. Forbidden ports are the Skynet ports, ports registered
    by other installations, and ports this instance has already handed out.
    """

    # Port ranges reserved for Skynet; never handed out.
    SKYNET_PORTS = [PortRange(2399), PortRange(6881, 6999), PortRange(10000, 10199)]
    # Whole 16-bit port number space, 0..65535 inclusive.
    VALID_RANGE = PortRange(0, 2 ** 16 - 1)

    def __init__(self):
        # ranges that must not be allocated; grows as ports are handed out
        self.forbidden_ranges = []

        self.serviceapi_port = self.client_port = self.serviceq_port = None
        self.web_server_port = self.taskbox_port = self.fileserver_port = None

    @staticmethod
    def check_port(port, ranges):
        """Return the first range containing `port`, or None.

        `ranges` may be a list of ranges or a single range object.
        """
        candidates = ranges if isinstance(ranges, list) else [ranges]
        return next((rng for rng in candidates if port in rng), None)

    def _find_port_consecutively(self, start_port):
        """Return the first allowed port strictly above `start_port`, or None if none remains."""
        self.forbidden_ranges.sort()
        candidate = start_port + 1
        while candidate and candidate in self.VALID_RANGE:
            blocking = self.check_port(candidate, self.forbidden_ranges)
            if blocking is None:
                return candidate
            # skip past the whole blocking range in one step
            candidate = blocking.end + 1
        return None

    def _find_port(self, base_port):
        """Pick `base_port` or the next allowed port above it, and reserve the result."""
        usable = (self.check_port(base_port, self.VALID_RANGE)
                  and not self.check_port(base_port, self.forbidden_ranges))
        if not usable:
            base_port = self._find_port_consecutively(base_port)
        if base_port:
            self.forbidden_ranges.append(PortRange(base_port))
        return base_port

    def initialize(self, base_port, registered_ports):
        """Assign all service ports starting from `base_port`.

        :param base_port: port for the service API; must itself be valid and
            not reserved by Skynet or another installation
        :param registered_ports: list of ranges used by other installations
        :raises ValueError: if `base_port` is unusable, or no free port can be
            found for the client or the fileserver
        """
        for reserved in (self.SKYNET_PORTS, registered_ports):
            self.forbidden_ranges.extend(reserved)

        if base_port not in self.VALID_RANGE:
            raise ValueError("Base port {} is not valid value in the range {}".format(base_port, self.VALID_RANGE))
        if self.check_port(base_port, self.SKYNET_PORTS):
            raise ValueError("Port {} busy for Skynet: {}. Try the other one.".format(base_port, self.SKYNET_PORTS))
        if self.check_port(base_port, registered_ports):
            raise ValueError("Port {} used by another sandbox installation: {}.".format(base_port, registered_ports))

        self.serviceapi_port = base_port
        # order matters: each allocation reserves its port before the next lookup
        offsets = (
            ("client_port", 1),
            ("serviceq_port", 2),
            ("web_server_port", 3),
            ("taskbox_port", 4),
            ("fileserver_port", 10),
        )
        for attr_name, offset in offsets:
            setattr(self, attr_name, self._find_port(base_port + offset))

        if not (self.client_port and self.fileserver_port):
            raise ValueError("Can't find free ports for client and fileserver. Try another base port.")


def register_sandbox_install(path, ports):
    if not common.os.User.can_root:
        print "Cannot register sandbox installation at {}, ports {} because user cannot root".format(path, ports)
        return

    metadata_dir = py.path.local("/var/lib/sandbox")
    if not metadata_dir.exists():
        sp.call(["sudo", "mkdir", str(metadata_dir)])

    metadata_path = metadata_dir.join("metadata")
    if metadata_path.exists():
        with metadata_path.open() as f:
            metadata = json.load(f)
    else:
        metadata = {}

    if ports:
        metadata[path] = {"ports": ports}
    else:
        metadata.pop(path, None)

    with tempfile.NamedTemporaryFile("w", delete=False) as tmp:
        json.dump(metadata, tmp)

    sp.call(["sudo", "cp", tmp.name, str(metadata_path)])
    sp.call(["sudo", "chmod", "0644", str(metadata_path)])


def get_registered_ports(own_path):
    """
    Collect ports claimed by other sandbox installations on this host.

    Reads the JSON registry at /var/lib/sandbox/metadata. Entries whose key
    equals `own_path`, or whose path no longer exists on disk, are skipped.

    :param own_path: path of the calling installation (its own ports are ignored)
    :return: list of single-port `PortRange` objects; empty if the registry is absent
    """
    registry = py.path.local("/var/lib/sandbox/metadata")
    if not registry.exists():
        return []
    with registry.open() as fh:
        installs = json.load(fh)
    return [
        PortRange(port)
        for install_path, info in installs.items()
        if install_path != own_path and py.path.local(install_path).exists()
        for port in info["ports"]
    ]
