# coding: utf-8
from __future__ import print_function

import os
import logging
import logging.handlers
import signal
import errno
import contextlib
import collections
import socket
import itertools
import datetime
import pwd
import grp
import random
import sys
import traceback
import urlparse

import concurrent.futures
import ipaddress
import tornado.ioloop
import tornado.web
import tornado.httpserver
import tornado.netutil
import tornado.locks

import _netmon

from . import backend_maintainer
from . import interfaces
from . import controllers
from . import ticker
from . import encoding
from . import settings
from . import utils
from . import _unistat_callback

from .logger import Logger


class Service(object):
    """Base class for a component that can be fused into an ``Application``.

    Subclasses override ``create_context`` to hold resources while the
    application runs, and ``cancel`` to stop background work on shutdown.
    The defaults here do nothing.
    """

    @contextlib.contextmanager
    def create_context(self, app):
        """Context active while *app* runs; the base implementation is a no-op."""
        yield None

    def cancel(self):
        """Stop background work; nothing to stop in the base implementation."""
        return None


class AppMixin(object):
    """Mixin that exposes the owning application as ``self._app``.

    ``_app`` refers to the application only while ``create_context`` is
    active; it is reset to ``None`` on exit, including exit by exception.
    """

    _app = None

    @contextlib.contextmanager
    def create_context(self, app):
        """Bind *app* to ``self._app`` for the lifetime of the context."""
        self._app = app
        try:
            yield None
        finally:
            # Always unbind, even if the body raised.
            self._app = None


class StatHandler(tornado.web.RequestHandler):
    """Serve ``_netmon.get_metrics()`` with a JSON content type."""

    def get(self):
        self.set_header("content-type", "application/json")
        # finish(chunk) writes the chunk and then finishes the request.
        self.finish(_netmon.get_metrics())


class ShutdownHandler(tornado.web.RequestHandler):
    """Shut the service down on an authorized ``GET`` from loopback.

    ``GET`` replies ``OK``. The actual ``service.shutdown()`` call happens in
    ``on_finish``, i.e. after the response has been finished.

    ``on_finish`` is invoked for every finished request, including error
    responses such as 401 and the 405 tornado sends for unsupported methods.
    The shutdown is therefore gated on a flag that only a successful,
    authorized ``GET`` sets, not merely on the caller's address.
    """

    _service = None
    # True only after an authorized GET has been answered.
    _shutdown_requested = False

    def initialize(self, service):
        """Store the object whose ``shutdown()`` is triggered (a ``StatService``)."""
        self._service = service

    def _has_permission(self):
        # Only loopback clients (IPv4 or IPv6) may request a shutdown.
        return self.request.remote_ip in ("127.0.0.1", "::1")

    def get(self):
        if not self._has_permission():
            raise tornado.web.HTTPError(401, "Access denied")
        self.write("OK")
        # Must be set before finish(): finish() is what calls on_finish().
        self._shutdown_requested = True
        self.finish()

    def on_finish(self):
        if self._shutdown_requested:
            self._service.shutdown()


class StatService(Service):
    """HTTP endpoint serving metrics (``/stats``) and a local ``/shutdown`` hook.

    While its context is active the server listens on ``port`` and keeps a
    reference to the application so ``/shutdown`` can stop it.
    """

    def __init__(self, port):
        self._port = port
        shutdown_kwargs = dict(service=self)
        routes = [
            (r"/stats", StatHandler),
            (r"/stats/", StatHandler),
            (r"/shutdown", ShutdownHandler, shutdown_kwargs),
            (r"/shutdown/", ShutdownHandler, shutdown_kwargs),
        ]
        self._tornado_app = tornado.web.Application(routes)
        self._http_server = tornado.httpserver.HTTPServer(self._tornado_app)
        self._app = None

    @contextlib.contextmanager
    def create_context(self, application):
        """Listen on the configured port and remember *application* until exit."""
        logging.info("Listening on port %d", self._port)
        self._http_server.listen(self._port)
        self._app = application
        try:
            yield
        finally:
            self._app = None
            self._http_server.stop()

    @tornado.gen.coroutine
    def cancel(self):
        closing = self._http_server.close_all_connections()
        yield closing

    def shutdown(self):
        """Ask the bound application to shut down on the next loop iteration."""
        app = self._app
        if app is None:
            return
        # Deferred so the HTTP connection that asked for it can be closed first.
        tornado.ioloop.IOLoop.current().add_callback(app.shutdown)


class Resolver(tornado.netutil.Resolver):
    """Tornado resolver backed by the native resolver controller plus a DNS cache.

    Hostnames of the netmon URL and the NOC SLA URLs are cached with a fixed
    TTL plus random jitter. Every other host uses the cache's default
    ``set``.
    """

    _controller = None
    _cache = None
    _netmon_urls = None
    _netmon_url_cache_ttl_secs = 300
    _netmon_url_cache_jitter_secs = 60

    def initialize(self, controller, cache):
        self._controller = controller
        self._cache = cache
        current = settings.current()
        backend_urls = (current.netmon_url,) + current.noc_sla_urls
        # Despite the name, this holds the backend *hostnames*, not full URLs.
        self._netmon_urls = tuple(
            urlparse.urlparse(url).hostname for url in backend_urls if url
        )

    def close(self):
        self._controller = None

    @tornado.gen.coroutine
    def resolve(self, host, port, family=socket.AF_UNSPEC, raise_on_nxdomain=False):
        """Resolve *host*:*port*, consulting and filling the DNS cache.

        Raises ``socket.gaierror`` when the controller returns ``None``, or
        returns an empty result while *raise_on_nxdomain* is set.
        """
        if family == socket.AF_UNSPEC and not settings.current().dns_resolve_ip4:
            # IPv4 resolution disabled: restrict unspecified lookups to IPv6.
            family = socket.AF_INET6

        cached = self._cache.get(host, port, family)
        if cached is not None:
            raise tornado.gen.Return(cached)

        resolved = yield controllers.defer(self._controller.resolve, host, port, family)
        if resolved is None:
            raise socket.gaierror()
        if not resolved and raise_on_nxdomain:
            raise socket.gaierror()

        addresses = [
            _netmon.Address(addr_family, sockaddr[0], sockaddr[1])
            for addr_family, sockaddr in resolved
        ]
        if host in self._netmon_urls:
            jitter = random.randint(0, self._netmon_url_cache_jitter_secs)
            self._cache.set_with_ttl(host, port, family, addresses,
                                     self._netmon_url_cache_ttl_secs + jitter)
        else:
            self._cache.set(host, port, family, addresses)
        raise tornado.gen.Return(resolved)


class FallbackResolver(tornado.netutil.BlockingResolver):
    """Blocking resolver that defers hostnames to the async ``Resolver`` when set.

    Literal IP addresses always go through the blocking implementation.
    Names are delegated to ``_async_resolver``, with NXDOMAIN treated as an
    error, while ``ResolverService`` has one installed.
    """

    # Installed and cleared by ResolverService.create_context.
    _async_resolver = None

    def resolve(self, host, port, family=socket.AF_UNSPEC):
        try:
            ipaddress.ip_address(encoding.safe_unicode(host))
        except ValueError:
            is_ip_literal = False
        else:
            is_ip_literal = True

        delegate = self._async_resolver
        if not is_ip_literal and delegate is not None:
            return delegate.resolve(host, port, family, raise_on_nxdomain=True)
        return super(FallbackResolver, self).resolve(host, port, family)


class ResolverService(Service):
    """Owns the DNS cache and, while active, the async ``Resolver``.

    A ``dns_cleaner_loop`` calls ``DnsCache.cleanup`` every 60 seconds.
    """

    def __init__(self):
        self._resolver = None
        self._cache = _netmon.DnsCache()
        self._loop = ticker.LoopingCall("dns_cleaner_loop", self._cache.cleanup, 60)

    @tornado.gen.coroutine
    def try_resolve(self, host, ignore_errors=True):
        """Return ``[(family, address), ...]`` for *host*.

        On ``socket.gaierror`` returns ``[]`` when *ignore_errors* is true,
        otherwise ``None``.
        """
        try:
            resolved = yield self._resolver.resolve(host, 0)
            pairs = [(family, sockaddr[0]) for family, sockaddr in resolved]
        except socket.gaierror:
            if not ignore_errors:
                raise tornado.gen.Return(None)
            pairs = []
        raise tornado.gen.Return(pairs)

    @contextlib.contextmanager
    def create_context(self, application):
        """Create the native resolver controller and publish the async resolver."""
        native = _netmon.ResolverController(logging.getLogger("resolver"))
        with controllers.context(native) as controller:
            resolver = Resolver(controller=controller, cache=self._cache)
            self._resolver = resolver
            FallbackResolver._async_resolver = resolver
            try:
                yield
            finally:
                FallbackResolver._async_resolver = None
                self._resolver = None

    @tornado.gen.coroutine
    def cancel(self):
        yield self._loop.cancel()


class BackendMaintainerService(AppMixin, Service):
    """Tracks the backends (netmon + NOC SLA) the agent talks to.

    When ``noc_sla_urls`` is configured, a ``backend_update_loop`` runs
    ``NocSlaBackendMaintainer.update_backends`` with the application's
    ``ResolverService``. Otherwise only ``netmon_url`` is known.
    """

    # Lower bound for the backend refresh period, in seconds.
    _MIN_UPDATE_INTERVAL_SECS = 60

    def __init__(self):
        current = settings.current()
        if current.noc_sla_urls:
            sd_enabled = bool(
                current.yp_sd_request_interval and current.yp_sd_url
            )
            self._noc_sla_backend_maintainer = backend_maintainer.NocSlaBackendMaintainer(sd_enabled)

            # yp_sd_request_interval may be unset/0 when YP SD is disabled
            # (see sd_enabled above); normalize it so max()/randrange() work.
            request_interval = current.yp_sd_request_interval or 0
            loop_interval = max(self._MIN_UPDATE_INTERVAL_SECS, request_interval)
            # Random start delay spreads agents' first update. randrange() rejects
            # an empty range, so fall back to the loop period when SD is off.
            jitter_range = request_interval if request_interval > 0 else loop_interval

            self._loop = ticker.LoopingCall(
                "backend_update_loop",
                lambda: self._noc_sla_backend_maintainer.update_backends(self._app[ResolverService]),
                loop_interval,
                start_delay=random.randrange(jitter_range),
                round_by_interval=True
            )
        else:
            self._noc_sla_backend_maintainer = None
            self._loop = None

    def get_all_backends(self):
        """Return ``netmon_url`` followed by the current NOC SLA backends."""
        return (settings.current().netmon_url,) + self.get_noc_sla_backends()

    def get_prod_backends(self):
        return (settings.current().netmon_url,)

    def get_noc_sla_backends(self):
        return (self._noc_sla_backend_maintainer.get_backends()
                if self._noc_sla_backend_maintainer is not None
                else ())

    def get_noc_sla_master(self):
        return (self._noc_sla_backend_maintainer.get_master()
                if self._noc_sla_backend_maintainer is not None
                else None)

    @tornado.gen.coroutine
    def cancel(self):
        if self._loop:
            yield self._loop.cancel()


class IfaceService(AppMixin, Service):
    """Discovers host interfaces and propagates addresses to poller services.

    Every 60 seconds ``_sync_addresses`` re-reads host interfaces, notifies
    the registered Echo/Udp/Icmp/Tcp/Link services, and rebuilds the list of
    interfaces annotated with an fqdn.
    """

    DiscoveredInterface = collections.namedtuple("DiscoveredInterface", (
        "address",
        "mask",
        "family",
        "fqdn",
        "mac",
        "vlan",
        "backbone6",
        "fastbone6"
    ))

    def __init__(self, allow_virtual=False, networks=None, allow_mtn_vlan=False):
        self._addresses = None
        self._discovered_interfaces = None
        # Successful reverse lookups keyed by address; failures are retried.
        self._cached_lookups = {}
        self._loop = ticker.LoopingCall("sync_addresses_loop", self._sync_addresses, 60)
        # Single worker: blocking gethostbyaddr calls run one at a time.
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._allow_virtual = allow_virtual
        self._networks = networks
        self._allow_mtn_vlan = allow_mtn_vlan

    @tornado.gen.coroutine
    def _resolve_fqdn_by_address(self, address):
        """Reverse-resolve *address*; ``None`` on error or if it maps to itself."""
        lookups = self._cached_lookups
        if address in lookups:
            raise tornado.gen.Return(lookups[address])
        try:
            result = yield self._thread_pool.submit(socket.gethostbyaddr, address)
        except socket.error:
            raise tornado.gen.Return(None)
        hostname = result[0]
        fqdn = None if hostname == address else hostname
        lookups[address] = fqdn
        raise tornado.gen.Return(fqdn)

    @tornado.gen.coroutine
    def get_addresses(self, wait=True):
        """Return the filtered, sorted ``(family, address)`` list (or ``()``)."""
        not_synced_yet = self._addresses is None
        if wait and not_synced_yet:
            yield self._loop.wait()
        raise tornado.gen.Return(self._addresses or ())

    @tornado.gen.coroutine
    def get_interfaces(self, wait=True):
        """Return the fqdn-annotated ``DiscoveredInterface`` list (or ``()``)."""
        not_synced_yet = self._discovered_interfaces is None
        if wait and not_synced_yet:
            yield self._loop.wait()
        raise tornado.gen.Return(self._discovered_interfaces or ())

    @tornado.gen.coroutine
    def _filter_addresses(self, addresses):
        # Hook for subclasses; the base implementation keeps every address.
        raise tornado.gen.Return(addresses)

    @tornado.gen.coroutine
    def _notify_service(self, service_cls, payload, error_format):
        """Pass *payload* to ``service_cls.on_address_sync`` if it is registered.

        Failures are logged with *error_format* and otherwise ignored, so one
        broken service does not prevent the others from syncing.
        """
        if service_cls not in self._app:
            return
        try:
            yield self._app[service_cls].on_address_sync(payload)
        except Exception as exc:
            logging.warning(error_format, exc)

    @tornado.gen.coroutine
    def _sync_addresses(self):
        discovered_interfaces = interfaces.get_host_interfaces(
            self._allow_virtual, self._networks, self._allow_mtn_vlan)
        logging.debug("Discovered interfaces are: %r", discovered_interfaces)

        all_addresses = {(iface.family, iface.address) for iface in discovered_interfaces}
        filtered = yield self._filter_addresses(all_addresses)
        filtered = sorted(filtered)
        if self._addresses is None or filtered != self._addresses:
            self._addresses = filtered
            logging.info("Local addresses are: %r", sorted(address for _, address in self._addresses))

        # Echo gets every discovered address, the pollers the filtered set,
        # and the link service the full interface records.
        yield self._notify_service(EchoService, all_addresses, "Can't bind echo service: %s")
        yield self._notify_service(UdpService, self._addresses, "Can't bind udp service: %s")
        yield self._notify_service(IcmpService, self._addresses, "Can't bind icmp service: %s")
        yield self._notify_service(TcpService, self._addresses, "Can't bind tcp service: %s")
        yield self._notify_service(LinkService, discovered_interfaces, "Can't bind link service: %s")

        fqdn_map = yield {
            iface.address: self._resolve_fqdn_by_address(iface.address)
            for iface in discovered_interfaces
            # don't expect mtn vlan interface to have a fqdn
            if not interfaces.is_mtn_vlan(iface.vlan)
        }

        self._discovered_interfaces = annotated = []
        for iface in discovered_interfaces:
            fqdn = fqdn_map.get(iface.address)
            if fqdn is None:
                if interfaces.is_mtn_vlan(iface.vlan):
                    # No DNS record expected: make up a placeholder fqdn.
                    fqdn = 'vlan{}@{}'.format(iface.vlan, settings.current().hostname)
                elif iface.family == socket.AF_INET and settings.current().ignore_ipv4_dns_fails:
                    fqdn = settings.current().hostname
            if fqdn is None:
                # Interfaces without any usable name are not reported.
                continue
            annotated.append(self.DiscoveredInterface(fqdn=fqdn, **iface._asdict()))
        # FIXME: we return all discovered interfaces here, disregarding _filter_addresses.
        # So until topology is updated with our telemetry info, addresses we pass
        # to controllers may not correspond to the interfaces we use for outgoing probes.
        # Also, an agent may update topology 1h later than the server, i.e. 1h after
        # probe schedule starts using agent's new addresses.

    @tornado.gen.coroutine
    def cancel(self):
        yield self._loop.cancel()
        self._thread_pool.shutdown(wait=False)


class EchoService(Service):
    """Runs the native echo controller on every ``(address, listen_port)`` pair."""

    def __init__(self, listen_ports):
        self.listen_ports = listen_ports
        self.controller = None

    @contextlib.contextmanager
    def create_context(self, application):
        native = _netmon.EchoController(logging.getLogger("echo"))
        with controllers.context(native) as controller:
            self.controller = controller
            try:
                yield
            finally:
                self.controller = None

    @tornado.gen.coroutine
    def on_address_sync(self, addresses):
        """Sync the controller to one endpoint per address and listen port."""
        sync = self.controller.sync_addresses
        endpoints = []
        for family, address in addresses:
            for port in self.listen_ports:
                endpoints.append(_netmon.Address(family, address, port))
        yield controllers.defer(sync, endpoints)


class ControllerService(Service):
    """Base for services wrapping a native poller controller.

    Subclasses provide ``_create_controller``. ``schedule_checks`` waits until
    the first ``on_address_sync`` has completed, whether or not it succeeded.
    """

    def __init__(self):
        self._controller = None
        # Set once the first address sync finishes (successfully or not).
        self._configured = tornado.locks.Event()

    def _create_controller(self):
        raise NotImplementedError()

    @contextlib.contextmanager
    def create_context(self, application):
        native = self._create_controller()
        with controllers.context(native) as controller:
            self._controller = controller
            try:
                yield
            finally:
                self._controller = None

    @tornado.gen.coroutine
    def on_address_sync(self, addresses):
        try:
            sync = self._controller.sync_addresses
            endpoints = [_netmon.Address(family, address, 0) for family, address in addresses]
            yield controllers.defer(sync, endpoints)
        finally:
            self._configured.set()

    @tornado.gen.coroutine
    def schedule_checks(self, configs):
        yield self._configured.wait()
        scheduled = yield controllers.defer(self._controller.schedule_checks, configs)
        raise tornado.gen.Return(scheduled)


class UdpService(ControllerService):
    """Poller service backed by ``_netmon.UdpPollerController``."""

    def _create_controller(self):
        logger = logging.getLogger("udp")
        return _netmon.UdpPollerController(logger)


class IcmpService(ControllerService):
    """Poller service backed by ``_netmon.IcmpPollerController``."""

    def _create_controller(self):
        logger = logging.getLogger("icmp")
        return _netmon.IcmpPollerController(logger)


class TcpService(ControllerService):
    """Poller service backed by ``_netmon.TcpPollerController``.

    Each synced address is expanded into one endpoint per entry of
    ``listen_ports``.
    """

    def __init__(self, listen_ports):
        super(TcpService, self).__init__()
        self.listen_ports = listen_ports

    def _create_controller(self):
        logger = logging.getLogger("tcp")
        return _netmon.TcpPollerController(logger)

    @tornado.gen.coroutine
    def on_address_sync(self, addresses):
        try:
            sync = self._controller.sync_addresses
            endpoints = [
                _netmon.Address(family, address, port)
                for family, address in addresses
                for port in self.listen_ports
            ]
            yield controllers.defer(sync, endpoints)
        finally:
            self._configured.set()


class LinkService(ControllerService):
    """Link-level poller configured from the default IPv6 gateway.

    On every address sync it derives backbone (required) and fastbone
    (optional) interface/MAC/IP parameters and pushes them to the native
    ``LinkPollerController``.
    """

    def _create_controller(self):
        return _netmon.LinkPollerController(
            logging.getLogger("link")
        )

    def __init__(self):
        super(LinkService, self).__init__()

    def _get_macs(self, addresses, gw, match_vlan):
        """Return ``(oif, src_mac, dst_mac, dst_ip)`` derived from ``addresses[0]``.

        ``dst_ip`` and ``src_mac`` are the first address's ``address`` and
        ``mac``. ``oif`` comes from ``interfaces.get_ifindex``, which also
        receives the address's vlan when *match_vlan* is true. ``dst_mac`` is
        ``interfaces.get_neigh6_lladdr(gw, oif)``. If any piece is missing,
        a warning is logged and ``(0, '', '', '')`` is returned.
        """
        dst_ip = addresses[0].address if addresses else None
        src_mac = addresses[0].mac if addresses else None
        if not dst_ip or not src_mac:
            logging.warning("Can't get dst ip or src ll addr for link service")
            return 0, '', '', ''

        vlan = addresses[0].vlan if addresses and match_vlan else None
        oif = interfaces.get_ifindex(src_mac, vlan)
        dst_mac = interfaces.get_neigh6_lladdr(gw, oif)
        if not dst_mac:
            logging.warning("Can't get dst ll addr for link service")
            return 0, '', '', ''
        return oif, src_mac, dst_mac, dst_ip

    @tornado.gen.coroutine
    def on_address_sync(self, addresses):
        """Reconfigure the link poller from the discovered interfaces.

        A missing default gateway or backbone parameters abort the sync with
        an error log. Missing fastbone parameters only log a warning, and the
        empty values are still passed on. ``schedule_checks``/``report_status``
        are unblocked in every case.
        """
        try:
            # Materialize as a list: it is scanned twice below (backbone and
            # fastbone). A lazy Python 3 filter() would be empty on the 2nd scan.
            link_addresses = [a for a in addresses if not interfaces.is_mtn_vlan(a.vlan)]

            gw = interfaces.get_default_gw6()
            if not gw:
                logging.error("Can't get default gw addr for link service")
                raise tornado.gen.Return(None)

            bb_oif, bb_src_mac, bb_dst_mac, bb_dst_ip = self._get_macs(
                [x for x in link_addresses if x.backbone6],
                gw, match_vlan=False
            )
            if not bb_src_mac or not bb_dst_mac or not bb_dst_ip:
                logging.error('Failed to configure link service for backbone')
                raise tornado.gen.Return(None)

            fb_oif, fb_src_mac, fb_dst_mac, fb_dst_ip = self._get_macs(
                [x for x in link_addresses if x.fastbone6],
                gw, match_vlan=True
            )
            if not fb_src_mac or not fb_dst_mac or not fb_dst_ip:
                logging.warning('Failed to configure link service for fastbone')

            # Configured source IPs override the interface's own address.
            bb_src_ip = settings.current().link_poller_src_ip or bb_dst_ip
            fb_src_ip = settings.current().link_poller_fb_src_ip or fb_dst_ip
            port = settings.current().link_poller_port
            yield controllers.defer(self._controller.sync_addresses, port, port,
                                    bb_oif, bb_src_mac, bb_dst_mac, bb_src_ip, bb_dst_ip,
                                    fb_oif, fb_src_mac, fb_dst_mac, fb_src_ip, fb_dst_ip)
        finally:
            self._configured.set()

    @tornado.gen.coroutine
    def report_status(self):
        yield self._configured.wait()
        reports = yield controllers.defer(self._controller.report_status)
        raise tornado.gen.Return(reports)


class Application(object):
    """Registry of services plus the tornado IOLoop that runs them.

    Services are registered under their own type and under every class in
    their optional ``provides`` attribute, and are looked up with
    ``app[ServiceClass]``. While the loop runs, each distinct service's
    ``create_context`` is active. On shutdown, services are cancelled in
    reverse registration order.
    """

    # The watchdog SIGKILLs the whole process group if one IOLoop callback
    # blocks for this many seconds.
    _WATCHDOG_TIMEOUT_SECS = 30

    def __init__(self):
        self._prepared = False
        self._running = False
        self._services = collections.OrderedDict()

        self._ioloop = tornado.ioloop.IOLoop.current()

        # Every application gets a ResolverService.
        self.register(ResolverService())

    @classmethod
    def _setup_watchdog(cls):
        ioloop = tornado.ioloop.IOLoop.current()
        ioloop.set_blocking_signal_threshold(cls._WATCHDOG_TIMEOUT_SECS, cls._on_blocked_loop)

    def _set_signal_handler(self, signum, callback):
        # Handlers run in signal context; add_callback_from_signal hands the
        # callback over to the IOLoop.
        signal.signal(signum, lambda sig, frame: self._ioloop.add_callback_from_signal(callback))

    @property
    def loop(self):
        return self._ioloop

    @staticmethod
    def _set_process_group():
        # Become a process-group leader (the watchdog kills the whole group).
        # EPERM is tolerated; any other error is re-raised.
        try:
            os.setpgid(0, 0)
        except OSError as exc:
            if exc.errno != errno.EPERM:
                raise

    @staticmethod
    def _on_blocked_loop(sig, frame):
        Logger.direct_logger.warning(
            'IOLoop blocked for %d seconds in\n%s',
            Application._WATCHDOG_TIMEOUT_SECS,
            ''.join(traceback.format_stack(frame)))
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

    def _unique_services(self):
        """Return registered services in registration order, each object once.

        ``register`` maps the service's own type and every class in its
        ``provides`` to the same object, so ``self._services.values()`` may
        repeat it. Its context must still be entered, and it must still be
        cancelled, only once.
        """
        seen = set()
        unique = []
        for service in self._services.values():
            if id(service) not in seen:
                seen.add(id(service))
                unique.append(service)
        return unique

    @tornado.gen.coroutine
    def _try_to_cancel_service(self, service):
        """Cancel *service* within 10 seconds; return whether it succeeded."""
        try:
            yield tornado.gen.with_timeout(
                datetime.timedelta(seconds=10),
                tornado.gen.maybe_future(service.cancel())
            )
        except Exception as exc:
            logging.exception("Can't cancel service %r: %s", service, exc)
            raise tornado.gen.Return(False)
        raise tornado.gen.Return(True)

    @tornado.gen.coroutine
    def _cancel(self):
        for service in reversed(self._unique_services()):
            cancelled = yield self._try_to_cancel_service(service)
            if not cancelled:
                # because we can't stop agent in normal order, kill it
                os._exit(1)

    @tornado.gen.coroutine
    def stop_loop(self):
        Logger.direct_logger.info("Agent shut down, stopping IOLoop")
        self._ioloop.stop()

    @tornado.gen.coroutine
    def shutdown(self):
        """Stop the agent once.

        When an RPC client and a backend maintainer are registered, the
        probing/sending/task services are cancelled first. Each backend is
        then told this host is terminated. Finally all services are cancelled
        and the loop is stopped.
        """
        if not self._running:
            logging.debug("Agent already shutting down!")
        else:
            logging.info("Agent is shutting down!")
            self._running = False

            from .rpc import RpcClient
            if RpcClient in self._services and BackendMaintainerService in self._services:
                # before calling terminated_host we need to stop requests
                # to scheduled_probes, send_reports and enqueued_tasks

                from .agent import ScheduledProbesService
                from .sender import SenderService
                from .tasks import TaskDispatcher
                for service_type in (ScheduledProbesService, SenderService, TaskDispatcher):
                    if service_type in self._services:
                        yield self._try_to_cancel_service(self._services[service_type])

                for url in self[BackendMaintainerService].get_all_backends():
                    if url:
                        try:
                            yield self._services[RpcClient].terminated_host(url)
                        except Exception as exc:
                            logging.error('Failed to mark self as terminated host at %s: %r', url, exc)

            self._ioloop.add_callback(lambda: self._cancel(callback=lambda _: self.stop_loop()))

    def reload(self):
        # simply exit, skycore will restart us
        self.shutdown()

    @contextlib.contextmanager
    def _context(self):
        Logger.direct_logger.info("Agent is starting!")

        # Failing to stop unistat pusher before python interpreter exits causes a segfault.
        # Start it here to guard against unexpected errors aborting the program.
        # The flag is read once so the stop below always mirrors the start.
        unistat_pusher = settings.current().unistat_pusher
        if unistat_pusher:
            _netmon.start_unistat_pusher()

        self._running = True
        try:
            # Inside the try so a failure here still stops the pusher.
            if unistat_pusher:
                _unistat_callback.push_signal(_unistat_callback.AgentStarts, 1.0)
            with utils.ExitStack() as stack:
                for service in self._unique_services():
                    stack.enter_context(service.create_context(self))
                yield
        finally:
            self._running = False

            if unistat_pusher:
                _netmon.stop_unistat_pusher()

        logging.info("Agent exited!")

    def register(self, service):
        """Register *service* under its type and each class in ``provides``.

        Raises ``RuntimeError`` if any of those classes is already taken.
        """
        for cls in itertools.chain(getattr(service, "provides", ()), (type(service),)):
            if cls in self._services:
                raise RuntimeError("Service {!r} already registered!".format(cls))
            self._services[cls] = service

    def __getitem__(self, cls):
        return self._services[cls]

    def __contains__(self, cls):
        return cls in self._services

    def _safe_chown(self, path, uid, gid):
        # Best effort: a failed chown is logged, not fatal.
        try:
            os.chown(path, uid, gid)
        except OSError as e:
            logging.warning("%s", e)

    def _drop_privileges(self):
        """When running as root, hand runtime paths to the configured user and drop to it."""
        if os.geteuid() == 0:
            uid = pwd.getpwnam(settings.current().user).pw_uid
            gid = grp.getgrnam(settings.current().group).gr_gid
            self._safe_chown(str(settings.current().var_dir), uid, gid)
            self._safe_chown(str(settings.current().pid_path), uid, gid)
            if settings.current().log_path is not None:
                self._safe_chown(str(settings.current().log_path), uid, gid)
            _netmon.drop_privileges(settings.current().user, uid, gid)

    def setup_signals(self):
        self._set_signal_handler(signal.SIGINT, self.shutdown)
        self._set_signal_handler(signal.SIGTERM, self.shutdown)
        self._set_signal_handler(signal.SIGUSR1, self.reload)

        if not sys.stdout.isatty():
            self._set_process_group()

    def prepare(self):
        """One-time setup: watchdog, logging and (usually) privilege drop."""
        if not self._prepared:
            self._prepared = True
            self._setup_watchdog()
            Logger.initialize()
            _netmon.init_global_log(logging.getLogger("global"))

            # Timestamping needs elevated privileges
            # for manipulating device ts implementation.
            # Either we want to enforce it disabled or enable it
            # Otherwise, drop privileges after creating log file
            if not settings.current().link_poller:
                self._drop_privileges()

    def start_loop(self):
        with self._context():
            self._ioloop.start()

    def start(self):
        self.prepare()
        self.setup_signals()
        self.start_loop()

    def run_sync(self, func, timeout=None):
        with self._context():
            return self._ioloop.run_sync(func, timeout=timeout)
