from __future__ import print_function

import Queue
import argparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
import errno
from importlib import import_module
import json
import logging
import logging.config
import threading
from operator import itemgetter
import os
from os import path
from random import randint
import signal
from SocketServer import ThreadingMixIn
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import _base
from concurrent.futures.thread import _WorkItem
import sys
import time
import traceback
import urlparse
import requests
import urllib
import re
import collections
from requests_toolbelt.adapters import host_header_ssl

import io
import gzip

import ipaddr
import netaddr

from config import Config, LoggingConfig, overrides_from_args
from iptables import IPTables, IPTablesError
from util import enumerate_lines
from util import get_project_id
from util import Timer
from library.python import svn_version

import subprocess32
from mds.valve.proto import valve_pb2

from . import metrics
from . import jugglerutil

import prctl

# Module-level logger named after the last component of the dotted module
# name. A NullHandler is attached so the logger always has a handler, even
# before main() configures logging via logging.config.dictConfig().
log = logging.getLogger(__name__.split(".")[-1])
log.addHandler(logging.NullHandler())


class ForceApply(object):
    """Switch that is "on" via an in-memory value or an on-disk flag file.

    :param orly_flag_file: path checked with disabling_flag_exists().
    :param value: initial in-memory value.
    """

    def __init__(self, orly_flag_file, value=False):
        self._value = value
        self._flag = orly_flag_file

    def set(self):
        """Turn the in-memory value on."""
        self._value = True

    def unset(self):
        """Turn the in-memory value off; the flag file is not touched."""
        self._value = False

    def is_set(self):
        """Return the in-memory value when truthy, else the flag-file check."""
        if self._value:
            return self._value
        return disabling_flag_exists(self._flag)


class AgentError(Exception):
    """Agent failure with a short message and optional extra text.

    :param message: value returned by ``str()`` on the exception.
    :param text: optional additional payload kept on the instance.
    """

    def __init__(self, message, text=None):
        self.text = text
        self.message = message

    def __str__(self):
        # Only the message is rendered; ``text`` is kept for callers.
        return self.message


class AgentShutdown(SystemExit):
    """SystemExit subclass used to signal a requested agent shutdown."""


class ApiData(object):
    """Container for the values served by the HTTP API handler."""

    def __init__(self):
        self.last_url = ""
        self.ready_ips = set()
        self.dropcount = {}
        self.unistat = []

    @property
    def status(self):
        """Status dict of the module-level ``status`` object."""
        # Reading a module global does not require a ``global`` statement.
        return status.status


class Status(object):
    """Aggregated agent health status with an accumulated description.

    The class attributes are severity ranks looked up by status name;
    DISABLED shares the rank of CRIT.
    """

    OK = 0
    WARN = 1
    CRIT = 2
    DISABLED = 2

    def __init__(self):
        # 'last_update' starts at (now - uptime).
        started = int(time.time() - get_uptime())
        self.status = {
            "status": "OK",
            "desc": "",
            "last_update": started,
        }
        self.msgs = set()

    def ok(self, msg=""):
        """Reset to OK, replacing the description with ``msg``."""
        self._set_status("OK", msg)

    def warn(self, msg):
        """Record a warning message."""
        self._set_status("WARN", msg)

    def crit(self, msg):
        """Record a critical message."""
        self._set_status("CRIT", msg)

    def disabled(self):
        """Reset to DISABLED, naming the configured disabling flag."""
        flag = Config()["main"]["disabling_flag"]
        self._set_status("DISABLED", "HBF disabled by {}".format(flag))

    def _set_status(self, status, msg, timestamp=None):
        """Apply a status change and stamp ``last_update``.

        OK and DISABLED overwrite status and description and forget all
        remembered messages. Any other status only raises the severity and
        appends ``msg`` to the description unless it was already recorded.
        """
        stamp = int(time.time()) if timestamp is None else timestamp
        state = self.status

        if status in ("OK", "DISABLED"):
            state["status"] = status
            state["desc"] = msg
            self.msgs = set()
        else:
            # Escalate only; an equal or lower rank keeps the current status.
            if getattr(self, state["status"]) < getattr(self, status):
                state["status"] = status

            previous = state["desc"]
            if not previous:
                state["desc"] = msg
            elif msg not in self.msgs:
                state["desc"] = previous + "; " + msg
            self.msgs.add(msg)

        state["last_update"] = stamp


class SignalHandler(object):
    """Installs process signal handlers and records reload/dump requests.

    SIGTERM and SIGINT stop the agent, SIGUSR1 requests a state dump and
    SIGUSR2 requests a forced reload. The ``reload`` and ``dump`` flags are
    only set here; whoever reads them is expected to clear them.
    """

    def __init__(self):
        self.reload = False
        self.dump = False
        registrations = (
            (signal.SIGTERM, self._stop),
            (signal.SIGINT, self._stop),
            (signal.SIGUSR1, self._dump),
            (signal.SIGUSR2, self._reload),
        )
        for sig_num, handler in registrations:
            signal.signal(sig_num, handler)

    def _stop(self, sig_num, frame):
        """Ignore further stop signals and raise AgentShutdown."""
        log.info("Stop triggered by signal {}.".format(sig_num))
        for stop_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(stop_signal, signal.SIG_IGN)
        raise AgentShutdown

    def _reload(self, sig_num, frame):
        """Mark that a forced reload was requested."""
        log.info("Force reload triggered by signal {}.".format(sig_num))
        self.reload = True

    def _dump(self, sig_num, frame):
        """Mark that a state dump was requested."""
        log.info("State dump triggered by signal {}.".format(sig_num))
        self.dump = True


class ServersideConfig(object):
    """Slot-only record describing one server-side configuration result.

    Only the attributes listed in ``__slots__`` can be assigned.
    """

    __slots__ = [
        # Rule tables, one per IP family.
        'v4_tables', 'v6_tables',
        # Where the data came from and its modification marker.
        'full_url', 'last_modified',
        # Download timing and the undecoded payload.
        'download_time', 'raw_data',
        # Training-mode indicator.
        'is_training_mode',
    ]


def get_uptime():
    """Return the first field of /proc/uptime as a float."""
    with open('/proc/uptime', 'r') as uptime_file:
        first_line = uptime_file.readline()
    seconds_field = first_line.split()[0]
    return float(seconds_field)


def close_logs(sig_num, frame):
    """Signal handler that closes every file handler of the root logger.

    Handlers that are not ``logging.FileHandler`` instances are skipped
    without touching their attributes: they may have no ``stream`` at all,
    or a stream without a ``name``.

    :param sig_num: number of the received signal (logged only).
    :param frame: current stack frame passed by the signal module (unused).
    """
    log.info("Caught signal {}, closing log files.".format(sig_num))
    # Iterate over a copy so the loop never depends on the live handler list.
    for h in list(logging.getLogger().handlers):
        if not isinstance(h, logging.FileHandler):
            continue
        # baseFilename is always set on a FileHandler, whereas ``stream`` may
        # be None (delayed open, or already closed).
        log.info("Closing file '{}'.".format(h.baseFilename))
        try:
            h.close()
        except Exception:
            log.exception("Exception while closing logs:")


class PoolMixIn(ThreadingMixIn):
    """Server mix-in that hands each request to ``self.pool``.

    The class using this mix-in must provide ``self.pool`` with a
    ``submit(fn, *args)`` method.
    """

    def process_request(self, request, client_address):
        """Submit the request to the pool; on failure report and close it.

        Any ``Exception`` raised by ``submit`` is reported through
        ``handle_error`` and the request is shut down. ``KeyboardInterrupt``
        and ``SystemExit`` are deliberately not swallowed here.
        """
        try:
            self.pool.submit(self.process_request_thread, request, client_address)
        except Exception:
            self.handle_error(request, client_address)
            self.shutdown_request(request)


class ThreadPoolExecutorLimitedQueue(ThreadPoolExecutor):
    """ThreadPoolExecutor whose pending-work queue has a size limit.

    :param max_workers: passed through to ThreadPoolExecutor.
    :param queue_limit: maximum size of the replacement work queue.
    """

    def __init__(self, max_workers, queue_limit):
        super(ThreadPoolExecutorLimitedQueue, self).__init__(max_workers)
        # Swap the executor's work queue for a bounded one.
        self._work_queue = Queue.Queue(queue_limit)

    def submit(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` without ever blocking.

        :raises RuntimeError: if the executor has been shut down.
        :raises Queue.Full: if the bounded queue has no free slot.
        :returns: the future bound to the queued work item.
        """
        with self._shutdown_lock:
            if self._shutdown:
                raise RuntimeError('cannot schedule new futures after shutdown')

            future = _base.Future()
            work_item = _WorkItem(future, fn, args, kwargs)

            # Non-blocking put: a full queue raises instead of waiting.
            self._work_queue.put_nowait(work_item)
            self._adjust_thread_count()
        return future


class PoolHTTPServer(PoolMixIn, HTTPServer):
    """HTTPServer whose requests are dispatched through a bounded thread pool.

    :param server_address: address passed to HTTPServer.
    :param api_handler: request handler class passed to HTTPServer.
    :param max_workers: number of pool worker threads.
    :param queue_length: maximum number of queued requests in the pool.
    """

    def __init__(self, server_address, api_handler, max_workers, queue_length):
        # Explicit base-class call rather than super().
        HTTPServer.__init__(self, server_address, api_handler)
        # PoolMixIn.process_request() submits work to this pool.
        self.pool = ThreadPoolExecutorLimitedQueue(max_workers, queue_length)


class APIHandler(BaseHTTPRequestHandler):
    """Read-only HTTP API exposing agent state.

    The shared state object is expected on the class attribute ``data``;
    run_api() assigns it before the server starts. Routes are matched by
    path prefix, and any unknown path returns a plain-text index.
    """

    paths = [
        '/status',
        '/ip_address_ready?ip=<ip_address>',
        '/unistat',
        '/last-url',
        '/dropcount4',
        '/dropcount6',
    ]

    def do_GET(self):
        """Dispatch a GET request by path prefix."""
        parsed_path = urlparse.urlparse(self.path)
        route = parsed_path.path
        if route.startswith('/status'):
            self._send_json(self.data.status)
        elif route.startswith('/ip_address_ready'):
            self.send_response(self._ip_ready_status(parsed_path.query))
            # The header section must be terminated even for a body-less
            # reply, otherwise the client receives a malformed response.
            self.end_headers()
        elif route.startswith('/unistat'):
            self._send_json(self.data.unistat)
        elif route.startswith('/last-url'):
            self._send_json(self.data.last_url)
        elif route.startswith('/dropcount4'):
            self._send_json(self._dropcount('v4'))
        elif route.startswith('/dropcount6'):
            self._send_json(self._dropcount('v6'))
        else:
            self._send_index()

    def _send_json(self, payload):
        """Send ``payload`` as a 200 application/json response."""
        # Serialize first so a serialization error cannot follow a 200 header.
        body = json.dumps(payload, indent=4)
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(body)

    def _dropcount(self, ip_version):
        """Return drop counters for ``ip_version`` or [] when not present."""
        try:
            return self.data.dropcount[ip_version]
        except KeyError:
            return []

    def _ip_ready_status(self, query_string):
        """Map an /ip_address_ready query string to an HTTP status code.

        400: the ``ip`` parameter is missing, repeated or not an IP address.
        200: the address is in ``data.ready_ips``.
        503: the address is not ready and the agent status is DISABLED.
        404: the address is not ready.
        """
        values = urlparse.parse_qs(query_string).get("ip", [])
        if len(values) != 1:
            return 400
        try:
            ip = ipaddr.IPAddress(values[0])
        except ValueError:
            return 400
        if ip in self.data.ready_ips:
            return 200
        if self.data.status["status"] == "DISABLED":
            return 503
        return 404

    def _send_index(self):
        """Send the plain-text list of supported paths."""
        self.send_response(200)
        self.send_header("Content-type", "text/plain")
        self.end_headers()
        self.wfile.write("Yandex HBF Agent\n\n")
        self.wfile.write("List of paths:\n")
        self.wfile.write("\n".join("  - " + x for x in
                                   self.paths) + "\n")

    def log_message(self, format, *args):
        """Route the handler's access log lines to the module logger."""
        log.info(self.client_address[0] + " " + format % args)


def run_api(data):
    """Run the HTTP API server until it stops; thread entry point.

    Binds ``data`` to APIHandler, then serves forever. On any exception the
    traceback is logged and SIGTERM is sent to the current process. The
    server socket is closed on exit if the server was created.

    :param data: shared state object exposed to the API handler.
    """
    log.info("HTTP API: Starting.")
    config = Config()
    server_address = (config["http_api"]["host"], config["http_api"]["port"])
    max_workers = config["http_api"]["max_workers"]
    queue_length = config["http_api"]["queue_length"]
    log.info("max_workers: {}, queue_length: {}".format(max_workers, queue_length))
    APIHandler.data = data
    # Stays None when construction fails (e.g. the address cannot be bound),
    # so the cleanup below does not hit an unbound local.
    api = None
    try:
        api = PoolHTTPServer(server_address, APIHandler, max_workers, queue_length)
        api.serve_forever()
    except Exception:
        log.critical("HTTP API: Exiting on exception:\n" +
                     traceback.format_exc())
        os.kill(os.getpid(), signal.SIGTERM)
    else:
        log.info("HTTP API: Normal exit.")
    finally:
        if api is not None:
            api.server_close()


def main():
    """Entry point: parse arguments, configure logging and run the agent.

    Handles the one-shot modes first (--default-config, --current-config,
    --disable-hbf), each of which exits the process. Otherwise starts the
    HTTP API thread, installs signal handlers, sleeps a random start delay
    and enters the agent loop. AgentShutdown is a normal exit; any other
    exception is logged and exits with code 1.
    """
    global args, status
    args = parse_args()

    if args.default_config:
        errors = Config.dump_default(args.configspec)
        sys.exit(1 if errors else 0)

    config = Config(args.config, args.configspec, overrides=overrides_from_args(args.override))

    if args.current_config:
        print(config)
        sys.exit(0)

    # Logging settings live next to the main config and its configspec.
    logging_config = LoggingConfig(
        path.join(path.dirname(args.config), "logging.conf"),
        path.join(path.dirname(args.configspec), "logging.configspec"),
        unrepr=True,
        check_extra=False,
    )
    if config["main"]["log_level"]:
        logging_config.obj["root"]["level"] = config["main"]["log_level"]
    logging.config.dictConfig(logging_config.obj)

    if args.disable_hbf:
        sys.exit(disable_hbf())

    try:
        log.info("Starting.")

        status = Status()

        # Best-effort removal of the socket path; failures are ignored.
        sync_manager_socket = config["main"]["sync_manager_socket"]
        try:
            os.remove(sync_manager_socket)
        except Exception:
            pass

        signal.signal(signal.SIGHUP, close_logs)

        api_data = ApiData()
        api_server = threading.Thread(name='api-server', target=run_api, args=(api_data,))
        api_server.daemon = True
        api_server.start()

        signal_handler = SignalHandler()
        # Random delay to spread request burst on start
        start_delay = randint(0, config["main"]["start_period_random"])
        log.info("Sleeping for {} s.".format(start_delay))
        time.sleep(start_delay)

        agent = Agent(signal_handler)
        agent.loop()
    except AgentShutdown:
        log.info("Normal exit.")
    except:
        # Top-level boundary: everything else is logged and exits non-zero.
        log.critical("Exiting on exception:\n" + traceback.format_exc())
        sys.exit(1)


def parse_args():
    """Parse the agent's command-line arguments from sys.argv.

    :returns: the argparse namespace with ``config``, ``configspec``,
        ``current_config``, ``default_config``, ``disable_hbf`` and
        ``override`` attributes.
    """
    # (option strings, add_argument keyword arguments), in help order.
    options = (
        (("--config",), dict(
            default="/etc/yandex-hbf-agent/yandex-hbf-agent.conf",
            help="path to config, default: %(default)s",
        )),
        (("--configspec",), dict(
            default="/usr/share/yandex-hbf-agent/yandex-hbf-agent.configspec",
            help="path to configspec, default: %(default)s",
        )),
        (("--current-config",), dict(
            action="store_true",
            help="dump current configuration (considering defaults) and exit",
        )),
        (("--default-config",), dict(
            action="store_true",
            help="dump default configuration file and exit",
        )),
        (("--disable-hbf",), dict(
            action="store_true",
            help="disable jumps from builtin chains to HBF",
        )),
        (('-V', '--override'), dict(
            action='append',
            help='override config element, e.g. http_api.port=8083 (main can be omitted)',
        )),
    )

    parser = argparse.ArgumentParser(description="Yandex HBF agent")
    for option_strings, settings in options:
        parser.add_argument(*option_strings, **settings)
    return parser.parse_args()


def disable_hbf():
    """Remove the jumps to HBF for both IP versions.

    For each of "v4" and "v6", finds the file ``50-hbf.<version>`` among the
    rule files listed by list_overlay_dirs() for the ``rules.d`` directories
    next to the config and the configspec, loads it into IPTables and calls
    ``apply_delete()``. Reads the module-level ``args`` set by main().

    :returns: 0 when both versions succeeded, 1 if a file was missing or an
        exception occurred for at least one version.
    """
    current = path.join(path.dirname(args.config), "rules.d")
    default = path.join(path.dirname(args.configspec), "rules.d")
    ec = 0
    config = Config()
    use_yandex_iptables = config.use_yandex_iptables()
    for ip_version in ("v4", "v6"):
        log.info("Disabling jump to IP{} HBF.".format(ip_version))
        hbf_file_name = "50-hbf." + ip_version
        # First listed path with the expected basename, or None.
        hbf_file_path = next(
            (p for p in list_overlay_dirs([current, default], ip_version)
             if path.basename(p) == hbf_file_name),
            None
        )
        if hbf_file_path is None:
            log.critical("File '{}' not found in configuration.".format(
                hbf_file_name
            ))
            ec = 1
            continue
        try:
            # Close the rules file deterministically instead of relying on
            # garbage collection.
            with open(hbf_file_path) as rules_file:
                rules_text = rules_file.read()
            hbf_jumps = IPTables(ip_version, rules_text, use_yandex_iptables=use_yandex_iptables)
            hbf_jumps.apply_delete()
        except Exception:
            log.critical("Exception:\n" + traceback.format_exc())
            ec = 1
        else:
            log.info("Disabling jump to IP{} HBF: OK".format(ip_version))
    return ec


def disabling_flag_exists(flag_path):
    """Return True when ``flag_path`` exists on the filesystem."""
    flag_present = os.path.exists(flag_path)
    return flag_present


class IterLimiter(object):
    """Token-based limiter configured from the ``main`` config section.

    Starts with ``token_burst_limit`` tokens; one token is added for every
    whole ``token_period`` that has elapsed, never exceeding the burst limit.
    """

    def __init__(self):
        config = Config()
        self.token_burst_limit = config["main"]["token_burst_limit"]
        self.token_period = config["main"]["token_period"]
        self.last_token_time = time.time()
        self.token_count = self.token_burst_limit

    def _populate(self):
        """Add tokens for the whole periods elapsed since the last refill."""
        now = time.time()
        if now < self.last_token_time:
            # Clock is behind the last refill point: nothing to add.
            return

        elapsed = now - self.last_token_time
        whole_periods = int(elapsed / self.token_period)
        refilled = self.token_count + whole_periods
        self.token_count = min(refilled, self.token_burst_limit)
        # Advance by whole periods only, keeping the fractional remainder.
        self.last_token_time += whole_periods * self.token_period
        message = 'TokenLimiter _populate: token_count: %d, ts: %d, now: %d, ceil: %d' % (
            self.token_count,
            self.last_token_time,
            now,
            whole_periods
        )
        log.debug(message)

    def consume(self):
        """Take one token; return True on success, False when none is left."""
        self._populate()
        if self.token_count <= 0:
            return False

        self.token_count -= 1
        return True


class Agent(object):

    def __init__(self, signal_handler):
        """Set up agent state from the current configuration.

        :param signal_handler: object exposing ``reload`` and ``dump`` flags.
        """
        self.signal_handler = signal_handler
        self.config = Config()
        self.use_yandex_iptables = self.config.use_yandex_iptables()

        # Per IP-version rule sets and the last observed system rules.
        self.rules = {ipv: {"current": None} for ipv in ("v4", "v6")}
        self.current_system_rules = dict.fromkeys(("v4", "v6"))
        self.ugrams = {ipv: {} for ipv in ('v4', 'v6')}

        # Targets the rules are requested for.
        self.host_ips = set()
        self.guest_ips = set()
        self.static_targets = set()

        # Remote (server-side) rules bookkeeping.
        self.last_server_response = None
        self.tmp_last_server_response = None
        self.should_update_last_resp = False
        self.is_training = False
        self.last_remote_update = 0
        self.server_request_duration = 0.0
        self.server_download_duration = 0.0

        # Error counters.
        self.agent_error_count = 0
        self.server_error_count = 0
        self.modules_error_count = 0
        self.iptables_error_count = 0

        # Scheduling parameters from the 'main' section.
        main_section = self.config["main"]
        self.remote_rule_check_deadline = None
        self.update_period = main_section["update_period"]
        self.update_period_seed = main_section["update_period_random"]
        self.update_guests_period = main_section["update_guest_ips_period"]
        self.hbf_disable_flag = main_section["disabling_flag"]

        self.limiter = IterLimiter()
        self.geo = jugglerutil.get_geo()
        self.use_fastpath_accept = main_section['fastpath_accept']
        self.force_apply = ForceApply(self.config['orly']['disabling_flag'])

    def set_remote_update_deadline(self, period=None):
        """Schedule the next remote rule check ``period`` seconds from now.

        A falsy ``period`` selects the default: ``update_period`` plus a
        random addition between 0 and ``update_period_seed``.
        """
        if not period:
            jitter = randint(0, self.update_period_seed)
            period = self.update_period + jitter
        now = time.time()
        self.remote_rule_check_deadline = now + period
        log.info('Set rule check deadline: %d s after current moment' % period)

    def iteration_needed(self, moment):
        """Tell whether ``moment`` is past the remote rule check deadline.

        When no deadline is set yet, one is scheduled with the default
        period and False is returned.
        """
        if self.remote_rule_check_deadline:
            return moment > self.remote_rule_check_deadline

        self.set_remote_update_deadline()
        return False

    def handle_dump(self):
        """Dump the rules if a dump was requested, then clear the request."""
        if not self.signal_handler.dump:
            return
        dump(self.rules)
        self.signal_handler.dump = False

    def loop(self):
        """Run agent iterations forever.

        Each iteration honours the disabling flag, collects targets and
        rules, applies them, publishes statistics and then waits.
        """
        while True:
            self.handle_hbf_disabling_flag()

            log.info("New iteration.")
            iteration_timer = Timer()

            # Collection and generation steps, in their required order.
            collect_steps = (
                self.handle_reload,
                self.get_host_ips,
                self.get_guest_ips,
                self.get_static_targets,
                self.load_local_rules,
                self.load_remote_rules,
                self.generate_rules,
            )
            for step in collect_steps:
                step()

            if self.use_fastpath_accept:
                self.fastpath_rules()

            self.concat_rules()
            self.apply_rules()

            self.update_unistat(iteration_timer.interval)
            self.update_dropcount()

            self.wait()

    def handle_hbf_disabling_flag(self):
        """Disable HBF and idle while the disabling flag file exists.

        While idling, the status is kept DISABLED, dump requests are served
        and statistics are updated once per second. The wait ends when the
        flag disappears or a reload is requested; afterwards the guest IP
        set is cleared and the status reset to OK.
        """
        if not disabling_flag_exists(self.hbf_disable_flag):
            return

        log.info("Disabled by '{}'.".format(
            self.hbf_disable_flag
        ))
        disable_hbf()
        while disabling_flag_exists(self.hbf_disable_flag):
            tick_timer = Timer()
            status.disabled()
            self.handle_dump()
            if self.signal_handler.reload:
                # A reload request interrupts the wait.
                break
            self.update_unistat(tick_timer.interval, True)
            time.sleep(1)
        self.guest_ips.clear()  # Update 'guest_ips'.
        status.ok()  # Drop disabled status.

    def is_valid_rules(self, raw_data):
        """Validate downloaded rules against a traffic sample.

        Runs the configured sampler binary, reads the sample it dumps, sends
        the sample together with ``raw_data`` to the validator URL and
        inspects the response for drops or rejects.

        Validation is skipped (treated as valid) when the validator is
        disabled in config, disabled by its flag file, or the agent is in
        training mode.

        :param raw_data: rules payload assigned to the request's ``HBFRules``.
        :returns: ``(valid, crit_msg)``; ``crit_msg`` is "" when valid,
            otherwise a short reason.
        """
        crit_msg = ""
        validator_disable_flag = self.config['validator']['disabling_flag']
        if not self.config['validator']['enabled']:
            log.debug('Validator disabled in config')
            return True, crit_msg

        if disabling_flag_exists(validator_disable_flag):
            log.warning('Validator disabled via {}'.format(validator_disable_flag))
            return True, crit_msg

        if self.is_training:
            log.warning('Validator disabled due to training mode')
            return True, crit_msg

        sampler_path = self.config['validator']['sampler_path']
        dump_path = self.config['validator']['dump_path']

        val_req = valve_pb2.ProcessSampleRequest()
        val_req.HBFRules = raw_data

        run = [sampler_path, '-nopush', '-dump-path', dump_path]
        # Without all four inheritable capabilities the sampler is run
        # through sudo.
        if not (
            prctl.cap_inheritable.sys_admin and
            prctl.cap_inheritable.sys_ptrace and
            prctl.cap_inheritable.dac_override and
            prctl.cap_inheritable.sys_module
        ):
            run.insert(0, 'sudo')
        log.debug("Executing: '{}'.".format(" ".join(run)))
        ret = subprocess32.Popen(run, stdout=subprocess32.PIPE, stderr=subprocess32.PIPE)
        try:
            out, err = ret.communicate(timeout=30)
        except subprocess32.TimeoutExpired:
            # Kill the sampler and collect whatever output it produced; the
            # non-zero return code is handled below.
            ret.kill()
            out, err = ret.communicate()

        log.debug("Return code: '{}'.".format(ret.returncode))
        log.debug("Out: '{}'.".format(out))
        log.debug("Err: '{}'.".format(err))

        if ret.returncode != 0:
            log.warning('sampler exit code: {}'.format(ret.returncode))
            log.warning('sampler stderr: {}'.format(err))
            return False, "Problems with tcp-sampler run"

        # Load the sample written by the sampler into the request.
        try:
            with open(dump_path, 'rb') as f:
                val_req.sample.ParseFromString(f.read())
        except Exception as e:
            crit_msg = "Exception on parse netSample from file"
            log.warning("{}: {}".format(crit_msg, e))
            return False, crit_msg

        # Gzip the serialized request in memory.
        data = io.BytesIO()
        with gzip.GzipFile(fileobj=data, mode='wb') as f:
            f.write(val_req.SerializeToString())
        ziped_data = data.getvalue()
        data.close()

        req = requests.Request(
            'PUT',
            self.config['validator']['url'],
            headers={
                'Content-Type': 'application/protobuf',
                'Content-Encoding': 'gzip',
                'User-Agent': 'yandex-hbf-agent: {}'.format(svn_version.svn_revision()),
            },
            data=ziped_data
        )

        # A backoff factor to apply between attempts AFTER THE SECOND TRY
        # (most errors are resolved immediately by a second try without a delay).
        #
        # formula for sleeps between attempts:
        # backoff_factor * (2 ** (current_try - 1 ))
        retry_strat = requests.packages.urllib3.util.retry.Retry(
            total=self.config['validator']['retry_count'],
            status_forcelist=[429, 500, 502, 503, 504],
            backoff_factor=self.config['validator']['retry_backoff_factor'],
        )

        retry_adapter = requests.adapters.HTTPAdapter(max_retries=retry_strat)

        with requests.Session() as s:
            # NOTE(review): the retry adapter is mounted for 'http://' only;
            # an 'https://' validator URL would presumably use the session's
            # default adapter without these retries - confirm.
            s.mount('http://', retry_adapter)
            try:
                prep_req = s.prepare_request(req)
                resp = s.send(prep_req, timeout=self.config['validator']['timeout'])
                log.debug("Response headers: {}".format(resp.headers))
                if resp.status_code != 200:
                    crit_msg = "Bad response status code: {}".format(resp.status_code)
                    log.warning(crit_msg)
                    return False, crit_msg
                resp_data = valve_pb2.ProcessSampleResponse()
                resp_data.ParseFromString(resp.content)
                # Any reported drop or reject makes the rules invalid.
                len_of_drops = len(resp_data.dropsOrRejects)
                if len_of_drops != 0:
                    crit_msg = "There are {} drops or rejects in rules".format(len_of_drops)
                    log.warning(crit_msg)
                    log.debug(resp_data)
                    return False, crit_msg

            except Exception as e:
                crit_msg = "Got exception on validation request: {}".format(e)
                log.warning(crit_msg)
                return False, crit_msg

        return True, crit_msg

    def handle_reload(self):
        """Serve a pending reload request, if any.

        Reloads the configuration, clears the guest IP set, forgets the last
        server response and clears the request flag.
        """
        if not self.signal_handler.reload:
            return

        self.config.reload()
        # Force the guest IP set to be rebuilt.
        self.guest_ips.clear()
        # Forget 'Last-Modified'.
        self.last_server_response = None
        self.signal_handler.reload = False

    def get_host_ips(self):
        """Refresh ``host_ips`` from the 'host_ips' hook.

        On any exception the previous set is kept and the module error
        counter is incremented.
        """
        log.info("Running hook 'host_ips' to collect host IP addresses.")
        try:
            hook_name = self.config["hooks"]["host_ips"]
            collected = run_ip_hook(hook_name)
        except Exception:
            self.modules_error_count += 1
            log.exception("Got exception while running hook 'host_ips':")
        else:
            self.host_ips = collected

    def fetch_guest_ips(self, old):
        """Run the 'guest_ips' hook and compare its result with ``old``.

        :param old: previously known set of guest IPs.
        :returns: ``(new_ips, changed)``. When no hook is configured or the
            hook raised, returns ``(set(), False)``.
        """
        hook_guest_ips = self.config["hooks"]["guest_ips"]
        if not hook_guest_ips:
            return set(), False

        try:
            new_ips = run_ip_hook(hook_guest_ips)
            changed = (new_ips != old)
            if changed:
                added = ips_to_string(new_ips - old)
                deleted = ips_to_string(old - new_ips)
                log.info(
                    "Guest IP addresses changed, added: %s, deleted: %s" % (
                        added,
                        deleted
                    )
                )
            return new_ips, changed
        except Exception:
            self.modules_error_count += 1
            log.exception("Got exception while running hook 'guest_ips':")

        # Hook failure: report "nothing, unchanged".
        return set(), False

    def get_guest_ips(self):
        """Refresh ``guest_ips`` from the 'guest_ips' hook."""
        log.info("Running hook 'guest_ips' to collect guest IP addresses.")
        fetched, changed = self.fetch_guest_ips(self.guest_ips)
        if not (fetched or changed):
            # An empty, unchanged result is also what a failed hook yields,
            # so the previous set is kept as is.
            return
        self.guest_ips = fetched

    def get_static_targets(self):
        """Refresh ``static_targets`` from ``targets.list`` next to the config.

        A missing file empties the set; a parse failure keeps the previous
        set, bumps the agent error counter and sets a CRIT status.
        ``_FASTBONE_`` is added when ``add_fastbone_target`` is enabled.
        """
        config_dir = path.dirname(self.config.config)
        targets_path = path.join(config_dir, "targets.list")

        if not os.path.exists(targets_path):
            self.static_targets.clear()
        else:
            log.info("Loading static targets from '{}'.".format(targets_path))
            try:
                self.static_targets = parse_targets_list(targets_path)
            except Exception as e:
                self.agent_error_count += 1
                log.exception("Exception:")
                status.crit("Exception: " + str(e))

        if self.config["main"]["add_fastbone_target"]:
            self.static_targets.add("_FASTBONE_")

    def load_local_rules(self):
        """Load the "local" rules for every IP version.

        Rules are looked up relative to the directories containing the
        config file and the configspec file.
        """
        for ip_version in self.rules:
            config_dir = path.dirname(self.config.config)
            configspec_dir = path.dirname(self.config.configspec)
            local_rules = get_local_rules(ip_version, config_dir, configspec_dir)
            self.rules[ip_version]["local"] = local_rules

    def load_remote_rules(self):
        """Download, validate and stage the server-side ("hbf") rules.

        Targets are the union of host IPs, guest IPs and static targets. On
        success the rules are validated with is_valid_rules(); they are
        staged into ``self.rules[...]["hbf"]`` unless validation failed in
        'hard' mode. The validation outcome is pushed via jugglerutil.

        The next remote check deadline is set with the default period,
        except after HTTP 429, where the shorter ``update_guests_period``
        is used and the method returns early.
        """
        all_targets = map(str, self.host_ips.union(self.guest_ips, self.static_targets))
        if not all_targets:
            msg = "No targets found, cannot get remote rules."
            log.error(msg)
            status.crit(msg)
        else:
            t = Timer()
            try:

                # The result is an object with v4_tables, v6_tables,
                # full_url, last_modified, download_time, raw_data and
                # is_training_mode attributes (read below).
                serverside = get_serverside_config(
                    all_targets,
                    self.last_server_response,
                )
                # seems like self.last_server_response should be updated only after applying rules
                # see https://st.yandex-team.ru/RTCNETWORK-538#5f996bea3316147cd4e3a22e
                self.tmp_last_server_response = (serverside.full_url,  serverside.last_modified)
                self.is_training = serverside.is_training_mode
                self.last_remote_update = time.time()
                self.server_download_duration = serverside.download_time
                # NOTE(review): request duration is set to the download time
                # here, while the HTTPError path below uses t.interval -
                # confirm this is intended.
                self.server_request_duration = serverside.download_time
                APIHandler.data.last_url = self.tmp_last_server_response[0]

                valid, crit_msg = self.is_valid_rules(serverside.raw_data)
                log.info('Rules are valid: {}'.format(valid))
                mode = self.config['validator']['mode']

                if not valid and mode == 'hard':
                    # Hard mode: rejected rules are not staged and the last
                    # server response is not to be advanced.
                    jugglerutil.push_validation_crit(crit_msg)
                    self.should_update_last_resp = False
                else:
                    self.rules["v4"]["hbf"] = serverside.v4_tables
                    self.rules["v6"]["hbf"] = serverside.v6_tables
                    self.should_update_last_resp = True
                    if not valid:
                        # Any other mode: stage the rules but report a warning.
                        m = 'Validation failed, soft mode'
                        log.info(m)
                        jugglerutil.push_validation_warn(m)
                    else:
                        jugglerutil.push_validation_ok()

            except AgentError as e:
                self.agent_error_count += 1
                log.exception("AgentError:")
                if e.text:
                    log.debug("Dumping remote rules:\n" +
                              enumerate_lines(e.text))
            except requests.exceptions.HTTPError as e:
                self.server_request_duration = t.interval
                if e.response.status_code == 304:
                    # Not modified: counts as a successful remote update.
                    log.info('Server indicated no changes in rules for targets')
                    self.last_remote_update = time.time()
                elif e.response.status_code == 429:
                    log.info('Server asked us to retry later')
                    # Reduced timeout, at least one try to refresh guest targets
                    self.set_remote_update_deadline(self.update_guests_period)
                    self.server_error_count += 1
                    # Early return keeps the reduced deadline set above.
                    return
                else:
                    log.info('Server error: %s' % e.response.status_code)
                    self.server_error_count += 1
            except Exception as e:
                self.agent_error_count += 1
                log.exception("Exception:")
                status.crit(
                    "Exception while loading remote rules: {}".format(e)
                )

        self.set_remote_update_deadline()  # default period

    def generate_rules(self):
        """Build the "generated" rules for each IP version via the rules hook.

        The hook receives only the host and guest addresses whose version
        matches the rule set being generated.
        """
        hook_rules = self.config["hooks"]["rules"]
        for ip_version in self.rules:
            wanted_version = 4 if ip_version == "v4" else 6
            host_subset = set(
                ip for ip in self.host_ips if ip.version == wanted_version
            )
            if not self.guest_ips:
                guest_subset = set()
            else:
                guest_subset = set(
                    ip for ip in self.guest_ips if ip.version == wanted_version
                )
            generated = run_rule_hook(hook_rules, ip_version,
                                      host_subset, guest_subset)
            self.rules[ip_version]["generated"] = generated

    def fastpath_rules(self):
        """Build the "fast_path" rule set for each IP version.

        In training mode the set is left empty. Otherwise, for each of the
        INPUT, OUTPUT and FORWARD chains of the filter table, two ACCEPT
        rules are inserted at positions 1 and 2: one for RELATED/ESTABLISHED
        state and one for TCP packets that are not a bare SYN.
        """
        m = "Training mode: {}".format(self.is_training)
        log.info(m)

        established = '-I {} 1 -m state --state RELATED,ESTABLISHED -j ACCEPT'
        non_syn = '-I {} 2 -p tcp -m tcp ! --tcp-flags FIN,SYN,RST,ACK SYN -j ACCEPT'

        for ip_version in self.rules:
            table = IPTables(ip_version, use_yandex_iptables=self.use_yandex_iptables)
            if not self.is_training:
                for chain in ('INPUT', 'OUTPUT', 'FORWARD'):
                    table.append_rule('filter', chain, established.format(chain))
                    table.append_rule('filter', chain, non_syn.format(chain))
            self.rules[ip_version]["fast_path"] = table

    def concat_rules(self):
        """Combine all rule sets of each IP version into the "new" rule set.

        Every entry except "current" and "new" itself is added to a fresh
        IPTables object stored under "new".
        """
        for ip_version in self.rules:
            per_version = self.rules[ip_version]
            # The key is assigned before iterating, so the dict is not
            # resized while it is being walked.
            per_version["new"] = IPTables(ip_version, use_yandex_iptables=self.use_yandex_iptables)
            for rule_type in per_version:
                if rule_type not in ("current", "new"):
                    per_version["new"] += per_version[rule_type]

    def apply_rules(self):
        """Apply the "new" rule sets and update status accordingly.

        Per IP version:
        - if "new" differs from "current", apply it; on failure fall back to
          re-applying "current" (when there is one);
        - otherwise, if the system rules differ from the snapshot taken
          after the last successful apply, re-apply "current";
        - otherwise only refresh the ready IP set.

        Afterwards the status is set to OK (and the last server response is
        advanced if requested) only when is_applied() reports success;
        the one-shot flags are reset in any case.
        """
        applied = {}
        for ipv in self.rules:
            applied[ipv] = True
            if self.rules[ipv]["current"] != self.rules[ipv]["new"]:
                result = apply(self.rules[ipv]["new"], force=self.force_apply.is_set())
                if result:
                    self.rules[ipv]["current"] = self.rules[ipv]["new"]
                    # Snapshot the system state right after a successful apply.
                    self.current_system_rules[ipv] = self.load_system_rules(ipv)
                    APIHandler.data.ready_ips = self.host_ips.union(self.guest_ips)
                    # Touch 'last_update'.
                    # status.ok()
                else:
                    self.iptables_error_count += 1
                    applied[ipv] = False

                    if self.rules[ipv]["current"] is not None:
                        log.info("Applying last valid ruleset.")
                        apply(self.rules[ipv]["current"], force=True)
                    else:
                        log.warning("No valid rulesets obtained yet.")
            elif (self.current_system_rules[ipv]
                  != self.load_system_rules(ipv)):
                log.warning("System rules changed by external process,"
                            " restoring HBF rules.")
                # NOTE(review): a failed restore increments the error counter
                # but leaves applied[ipv] True, so status may still become OK
                # - confirm this is intended.
                if not apply(self.rules[ipv]["current"], force=True):
                    self.iptables_error_count += 1
            else:
                # No changes in rules possible regardless of ip set, e.g. for SLB
                # Let's reassure ready ips and touch status
                APIHandler.data.ready_ips = self.host_ips.union(self.guest_ips)
                # status.ok()

        # check if both ipv4 and ipv6 rules are applied
        applied_res, msg = is_applied(applied)
        if applied_res:
            status.ok()
            if self.should_update_last_resp:
                log.info("Updating last_server_response: {}".format(self.tmp_last_server_response))
                self.last_server_response = self.tmp_last_server_response
        else:
            status.crit(msg)

        # unset flags for next iterations
        self.force_apply.unset()
        self.should_update_last_resp = False

    def load_system_rules(self, ipv):
        """Return the live rules of family ``ipv`` without the protected chains.

        The live ruleset is loaded into a fresh ``IPTables`` object and every
        chain listed in ``protected_chains`` of the "current" ruleset is
        removed from it; chains or tables that are absent are ignored.
        """
        snapshot = IPTables(ipv, use_yandex_iptables=self.use_yandex_iptables)
        snapshot.load_current()
        managed = self.rules[ipv]["current"]
        for table in managed.protected_chains:
            for chain in managed.protected_chains[table]:
                try:
                    del snapshot[table][chain]
                except KeyError:
                    # Not present in the live rules, nothing to strip.
                    continue
        return snapshot

    @staticmethod
    def get_pkts_bytes_for_drops(list_of_matches):
        if not list_of_matches:
            return 0, 0, False
        # for targets LOG, DROP, REJECT should be only one result
        return list_of_matches[0].group('pkts'), list_of_matches[0].group('bytes'), True

    @staticmethod
    def fastbone_drops(iptables):
        fb_chains = {
            "input_pre": "Y_END_IN_PRE",
            "output_pre": "Y_END_OUT_PRE",
        }
        values = {
            "input_pre": 0,
            "output_pre": 0,
        }
        target = "RETURN"
        fb_comment = "_FASTBONE_"
        for in_out, chain in fb_chains.items():
            dump = iptables.get_vxL_chain_dump(chain)
            matches = iptables.parse_chain(dump, target)
            for m in matches:
                log.debug("fastbone, comment: '{}'".format(m.group("comment")))
                if fb_comment in m.group("comment"):
                    log.debug("fastbone, pkts: {}".format(m.group("pkts")))
                    values[in_out] += int(m.group("pkts"))
        return values

    def calc_drops(self, family, iptables):
        """Collect LOG/DROP/REJECT packet counters of the terminal chains.

        For each of the chains Y_END_IN, Y_END_OUT and Y_END_OUT_INET the
        chain dump is parsed for its two targets and the packet counter of
        the first match is turned into ``*_summ`` and ``*_dhhh`` signals plus
        a per-family ``*_ahhh`` ugram.  For family 'v6' the fastbone counters
        from ``fastbone_drops()`` are subtracted and reported as separate
        ``*_fastbone_dhhh`` signals.

        :param family: 'v4' or 'v6'; also the key into ``self.ugrams``.
        :param iptables: ruleset object providing ``get_vxL_chain_dump`` and
            ``parse_chain``.
        :return: list of ``[signal_name, value]`` pairs: the scalar signals
            followed by the ugram values accumulated for this family.
        """
        chains = {
            "input": "Y_END_IN",
            "output": "Y_END_OUT",
            "output_inet": "Y_END_OUT_INET"
        }
        rules = {
            "input": ("LOG", "DROP"),
            "output": ("LOG", "REJECT"),
            "output_inet": ("LOG", "REJECT")
        }

        result = {}
        fb_drops = collections.defaultdict(int)
        # fastbone is v6 only
        if family == 'v6':
            fb_drops.update(self.fastbone_drops(iptables))

        for in_out, chain in chains.items():
            dump = iptables.get_vxL_chain_dump(chain)
            for target in rules[in_out]:
                matches = iptables.parse_chain(dump, target)
                log.debug('parse_chain result: {}'.format(matches))
                pkts, byts, found = self.get_pkts_bytes_for_drops(matches)
                if not found:
                    log.debug("family: {}; NOT FOUND {} in {}; setting 0".format(family, target, chain))
                else:
                    log.debug("family: {}; FOUND {} in {}, pkts: {}, bytes: {}".format(family, target, chain, pkts, byts))

                # The counter arrives as regex text when a rule was found.
                # Convert it once so that the *_summ signal is an integer for
                # every family (previously v4 leaked the raw string).
                pkts = int(pkts)

                target_name = target.lower()
                if in_out == 'input':
                    # Input signals carry no direction in their name.
                    prefix = "packets_{}".format(target_name)
                    fastbone_signal = "packets_{}_fastbone_dhhh".format(in_out)
                    fastbone_key = "input_pre"
                else:
                    prefix = "packets_{}_{}".format(in_out, target_name)
                    fastbone_signal = "packets_{}_{}_fastbone_dhhh".format(in_out, target_name)
                    fastbone_key = "output_pre"

                if family == 'v6':
                    fb_dropped = fb_drops[fastbone_key]
                    pkts -= fb_dropped
                    result[fastbone_signal] = fb_dropped

                result[prefix + "_summ"] = pkts
                result[prefix + "_dhhh"] = float(pkts)

                hgram_name = prefix + "_ahhh"
                ugram = self.ugrams[family].get(hgram_name)
                if not ugram:
                    self.ugrams[family][hgram_name] = metrics.Ugram(pkts)
                else:
                    ugram.update(pkts)

        r = [[k, v] for k, v in result.items()]
        r.extend([k, v.value] for k, v in self.ugrams[family].items())
        return r

    def update_dropcount(self):
        """Recompute drop counters for both families and publish them.

        Families without a "current" ruleset get an empty list.  For 'v6' the
        juggler drop check is pushed as OK or CRIT according to
        ``determine_drop_juggler_mode_msg``.
        """
        dropcount = {}
        for family in ('v4', 'v6'):
            dropcount[family] = []
            current = self.rules[family]['current']
            if not current:
                continue

            dropcount[family].extend(self.calc_drops(family, current))

            if family != 'v6':
                continue
            mode, msg, tags = self.determine_drop_juggler_mode_msg(family)
            if mode == 'OK':
                jugglerutil.push_drops_ok()
            else:
                jugglerutil.push_drops_crit(msg, tags)
        APIHandler.data.dropcount = dropcount

    def determine_drop_juggler_mode_msg(self, family):
        """Derive the juggler state for drop counters of ``family``.

        Every ugram of the family whose name does not contain 'log' is
        inspected: the first bucket whose value equals 1 is taken, and it
        counts as a problem when that bucket is not 0.

        :return: ``(mode, msg, tags)``.  ``('CRIT', 'ugrams are empty',
            ['empty_ugrams'])`` when there is nothing to inspect; otherwise
            mode is 'CRIT' if any ugram has a problem bucket and 'OK' if not,
            msg lists ``name:bucket`` pairs, and tags holds ``self.geo`` (if
            set) followed by ``name_bucket`` entries.
        """
        hit_buckets = {}
        for hgram_name, ugram in self.ugrams.get(family, {}).items():
            if 'log' in hgram_name:
                continue

            name = hgram_name.replace('packets_', '').replace('_ahhh', '')
            # First bucket flagged with value 1, or None if there is none.
            first_hit = next(
                (bucket for bucket, value in ugram.value if value == 1), None
            )
            hit_buckets[name] = first_hit if first_hit != 0 else None

        if not hit_buckets:
            return 'CRIT', 'ugrams are empty', ['empty_ugrams']

        tags = [self.geo] if self.geo else []
        parts = []
        for name, bucket in hit_buckets.items():
            if bucket is None:
                continue
            parts.append('{}:{}'.format(name, bucket))
            tags.append('{}_{}'.format(name, bucket))

        mode = 'CRIT' if parts else 'OK'
        return mode, ', '.join(parts), tags

    def update_unistat(self, loop_time, hbf_disabled=False):
        """Rebuild the unistat signal list and publish it via ``APIHandler.data``.

        :param loop_time: value reported as ``loop_time_ahhh``.
        :param hbf_disabled: when truthy only ``agent_disabled_ammm`` (set to
            1) is reported; otherwise all statistics are added and the
            disabled signal is 0.
        """
        unistat = []
        # add stats only if hbf enabled
        if not hbf_disabled:
            unistat.extend([
                ["loop_time_ahhh", loop_time],
                ["host_ips_count_ahhh", len(self.host_ips) if self.host_ips else 0],
                ["guest_ips_count_ahhh", len(self.guest_ips) if self.guest_ips else 0],
            ])

            for family in ('v4', 'v6'):
                current = self.rules[family]['current']
                if not current:
                    continue

                per_family = (
                    ("{}_rule_count_ahhh", float(current.rule_count)),
                    ("{}_chain_count_ahhh", current.chain_count),
                    ("{}_test_time_ahhh", current.test_time),
                    ("{}_apply_time_ahhh", current.apply_time),
                    ("{}_apply_time_protected_ahhh", current.apply_time_protected),
                    ("{}_gc_time_ahhh", current.gc_time),
                )
                for template, value in per_family:
                    unistat.append([template.format(family), value])

            unistat.extend([
                ["server_request_duration_ahhh", self.server_request_duration],
                ["server_download_duration_ahhh", self.server_download_duration],
                ["error_modules_dmmm", self.modules_error_count],
                ["error_iptables_dmmm", self.iptables_error_count],
                ["error_server_dmmm", self.server_error_count],
                ["error_agent_dmmm", self.agent_error_count],
            ])

        # always append disabled counter
        unistat.append(["agent_disabled_ammm", 1 if hbf_disabled else 0])

        APIHandler.data.unistat = unistat

    def wait(self):
        """Block until the next iteration of the main loop should start.

        Polls once per second and returns when any of the following holds:

        * the disable flag file appears (loop condition),
        * a reload was requested through the signal handler,
        * ``iteration_needed(now)`` is true and the limiter grants a token,
        * the guest IP hook reported a change with new IDs and the limiter
          grants a token; in this case ``force_apply`` is set first.

        ``handle_dump()`` is called on every polling cycle.
        """
        if self.config["hooks"]["guest_ips"]:
            log.info("Polling hook 'guest_ips' to collect guest IP addresses.")

        now = time.time()
        changed = False
        # The first guest check happens on the very first cycle.
        next_guests_check = now

        while not disabling_flag_exists(self.hbf_disable_flag):
            self.handle_dump()

            if self.signal_handler.reload:
                return

            if self.iteration_needed(now):
                if self.limiter.consume():
                    return

            # Once a change has been seen the hook is not polled again; the
            # loop only waits for the limiter to allow the out-of-band run.
            if not changed and now >= next_guests_check:
                # In fact, we fetch everything in the main loop again, just trigger on changes
                new_ids, changed = self.fetch_guest_ips(self.guest_ips)
                next_guests_check = now + self.update_guests_period

            # new_ids is only read when changed is true, i.e. after the fetch
            # above has assigned it.
            if changed and new_ids:
                if self.limiter.consume():
                    log.info('Performing out-of-band request, changes detected')
                    self.force_apply.set()
                    return
                else:
                    log.debug('Out-of-band request throttled')

            time.sleep(1)
            now = time.time()


def run_ip_hook(hook_value):
    """Run every module named in a comma-separated hook value.

    Returns the union of the IP sets produced by ``run_ip_mod`` for each
    module, or an empty set when ``hook_value`` is empty.
    """
    collected = set()
    if not hook_value:
        return collected
    for raw_name in hook_value.split(","):
        collected |= run_ip_mod(raw_name.strip())
    return collected


def run_ip_mod(name):
    """Run module ``hbfagent.mod.<name>`` and return the IPs it reports.

    String results are parsed with ``ipaddr.IPAddress``; invalid strings are
    logged and skipped, and link-local, loopback, private, reserved and
    unspecified addresses are dropped.  Non-string results are added as is.
    """
    log.debug("Running module '{}'.".format(name))
    module = import_module("hbfagent.mod." + name)
    found = set()
    for item in module.run():
        if not isinstance(item, (str, unicode)):
            found.add(item)
            continue

        try:
            address = ipaddr.IPAddress(item)
        except ValueError:
            msg = "Module '{}' returned invalid IP address '{}'."
            log.error(msg.format(name, item))
            continue

        unusable = (address.is_link_local or address.is_loopback or
                    address.is_private or address.is_reserved or
                    address.is_unspecified)
        if not unusable:
            found.add(address)
    log.debug("Found: {}.".format(ips_to_string(found)))
    return found


def parse_targets_list(targets_path):
    """Read a targets file into a set of stripped, non-comment lines.

    Blank lines and lines starting with "#" are ignored.  A missing file
    yields an empty set; any other ``IOError`` is re-raised.
    """
    targets = set()
    try:
        with open(targets_path) as f:
            stripped = (raw.strip() for raw in f)
            targets.update(
                entry for entry in stripped
                if entry and not entry.startswith("#")
            )
    except IOError as e:
        # Only "file does not exist" is tolerated.
        if e.errno != errno.ENOENT:
            raise
    return targets


def get_local_rules(ip_version, current, default):
    """Load every local rule file and merge them into one IPTables object.

    Rule files are looked up in the ``rules.d`` sub-directories of
    ``current`` and ``default`` (in that priority order).  A file that
    cannot be read or parsed is logged, reported through ``status.crit``
    and skipped.
    """
    rule_dirs = [path.join(current, "rules.d"), path.join(default, "rules.d")]
    use_yandex_iptables = Config().use_yandex_iptables()
    merged = IPTables(ip_version, use_yandex_iptables=use_yandex_iptables)
    for file_name in list_overlay_dirs(rule_dirs, ip_version):
        log.info("Reading rules from '{}'.".format(file_name))
        try:
            with open(file_name, "r") as f:
                loaded = IPTables(ip_version, dump=f.read(), use_yandex_iptables=use_yandex_iptables)
        except Exception as e:
            log.exception("Exception:")
            status.crit(
                "Exception while reading '{}': {}".format(file_name, e)
            )
            continue

        t, c, r = loaded.count()
        log.debug("Loaded {} tables, {} chains, {} rules.".format(t, c, r))
        merged += loaded
    return merged


def list_overlay_dirs(dirs, ip_version):
    """List rule files from overlaid directories, ordered by file name.

    Only regular files whose name ends with ``ip_version`` are considered.
    When the same file name exists in several directories, the one from the
    directory listed first in ``dirs`` wins.  Missing directories are
    skipped.  Returns full paths sorted by file name.
    """
    owner_dir = {}
    for directory in dirs:
        if not path.isdir(directory):
            continue
        for name in os.listdir(directory):
            if name in owner_dir or not name.endswith(ip_version):
                continue
            if path.isfile(path.join(directory, name)):
                owner_dir[name] = directory
    return [path.join(owner_dir[name], name) for name in sorted(owner_dir)]


def form_requests(method, url, try_ips, **kwargs):
    """Build the list of requests to try for ``url``.

    Without ``try_ips`` a single request for ``url`` is returned.  With
    ``try_ips`` one request is built per valid address in the configured
    ``server_ips``, with the configured ``server_url`` prefix of ``url``
    replaced by ``https://<ip>`` (IPv6 addresses are bracketed).  Invalid
    entries are logged and skipped, so the result may be empty.
    """
    if not try_ips:
        return [requests.Request(method, url, **kwargs)]

    config = Config()
    server_url = config["main"]["server_url"]
    candidates = (raw.strip() for raw in config["main"]["server_ips"].split(','))

    reqs = []
    for si in candidates:
        if not si:
            # Empty or whitespace-only entry.
            continue

        log.debug("Form url for server_ip: {}".format(si))
        if netaddr.valid_ipv4(si):
            log.debug("Valid ipv4")
            replacement = 'https://{}'.format(si)
        elif netaddr.valid_ipv6(si):
            log.debug("Valid ipv6")
            replacement = 'https://[{}]'.format(si)
        else:
            log.debug("Config: not valid IP in server_ips field: {}".format(si))
            continue

        url_f = url.replace(server_url, replacement)
        log.debug("Formed url: {}".format(url_f))
        reqs.append(requests.Request(method, url_f, **kwargs))

    return reqs


def do_reqs(reqs, try_ips, timeout):
    """Send ``reqs`` one after another until one gets an acceptable answer.

    The ``Host`` header of every request is set to the host name taken from
    the configured ``server_url``.  A response with status 200, 304 or 429 is
    returned immediately; any other outcome is logged and the next request
    is tried.  The exception of the last request is re-raised.  With an
    empty ``reqs`` nothing is sent and ``None`` is returned.
    """
    server_url = Config()['main']['server_url']

    # oversimplified way to get hostname for 'Host' header
    host_header = re.match(r'^https?://([a-zA-Z0-9.-]+)/?', server_url).group(1)

    known_good_status_codes = (200, 304, 429)
    final_index = len(reqs) - 1
    with requests.Session() as session:
        if try_ips:
            # Requests go to bare IPs; the adapter validates TLS against the
            # Host header instead.
            session.mount('https://', host_header_ssl.HostHeaderSSLAdapter())
        for index, req in enumerate(reqs):
            prepared = session.prepare_request(req)
            prepared.headers['Host'] = host_header
            if try_ips:
                log.info("Trying {}".format(prepared.url))
            try:
                response = session.send(prepared, timeout=timeout)
                if response.status_code not in known_good_status_codes:
                    raise requests.exceptions.HTTPError(response=response)
                return response
            except Exception as e:
                log.warning("Got exception: {}".format(e))
                if index == final_index:
                    log.debug("This is the last try with try_ips == {}".format(try_ips))
                    raise


def perform_requests(method, url, try_ips, **kwargs):
    """Form and send the request(s) for ``url`` and return a 200 response.

    ``timeout`` is taken out of ``kwargs`` (default 15) and the rest is
    passed to the request constructor.  Raises ``Exception`` when no request
    could be formed and ``requests.exceptions.HTTPError`` for any response
    whose status is not 200.
    """
    timeout = kwargs.pop('timeout', 15)
    reqs = form_requests(method, url, try_ips, **kwargs)
    if not reqs:
        raise Exception("No requests are present!")

    response = do_reqs(reqs, try_ips, timeout)
    if response.status_code == 200:
        return response
    raise requests.exceptions.HTTPError(response=response)


def get_serverside_config(targets, previous_response=None):
    """Download and parse the HBF rules for ``targets``.

    The targets are sent in the URL of a GET request; when that URL would
    exceed ``get_targets_uri_limit`` they are sent in a POST body instead.
    If ``previous_response`` (a ``(url, last_modified)`` pair) refers to the
    same full URL, an ``If-Modified-Since`` header is added.  When the first
    attempt fails with a connection error, the request is retried through
    the configured server IPs.

    :return: a ``ServersideConfig`` with ``full_url``, ``download_time``,
        ``raw_data``, ``last_modified``, ``is_training_mode``, ``v4_tables``
        and ``v6_tables`` filled in.
    """
    log.info("Loading HBF rules.")
    config = Config()
    main_cfg = config["main"]

    serverside = ServersideConfig()

    if previous_response is None:
        previous_url = previous_last_modified = None
    else:
        previous_url, previous_last_modified = previous_response

    server_options = [o.strip() for o in main_cfg["server_options"].split(',')]
    options = "&".join(server_options)
    quoted_targets = ",".join([urllib.quote(target, ":") for target in targets])
    joined_targets = ",".join(targets)

    server_url = main_cfg["server_url"]
    serverside.full_url = "%s/get/%s?%s" % (server_url, joined_targets, options)  # this used in valve
    url = "%s/get/%s?%s" % (server_url, quoted_targets, options)  # this used to do proper GET to hbf-server

    post_targets = {}
    if len(url) > main_cfg["get_targets_uri_limit"]:
        url = "%s/get/?%s" % (server_url, options)
        post_targets['targets'] = joined_targets  # this used in POST to hbf-server
    method = 'POST' if post_targets else 'GET'

    headers = {
        'User-Agent': 'yandex-hbf-agent: {}'.format(svn_version.svn_revision()),
        'Accept-Encoding': 'gzip',
    }

    if serverside.full_url == previous_url and previous_last_modified:
        log.debug("Will use 'If-Modified-Since: {}'".format(previous_last_modified))
        headers["If-Modified-Since"] = previous_last_modified
    else:
        log.info("URL: <{}>".format(url))
        if post_targets:
            log.info("POST targets: %s" % post_targets)

    t = Timer()

    req_kwargs = {
        'headers': headers,
        'timeout': main_cfg['server_timeout'],
        'data': post_targets,
    }
    try:
        response = perform_requests(method, url, False, **req_kwargs)
    except requests.exceptions.ConnectionError:
        log.info("Trying to get rules with server_ips")
        response = perform_requests(method, url, True, **req_kwargs)

    serverside.download_time = t.interval

    h = response.headers
    log.debug("Response headers: {}".format(h))
    serverside.raw_data = response.content
    serverside.last_modified = h.get("Last-Modified", None)

    real_server = h.get("Real-server", None)
    if real_server:
        real_server = real_server.split(',')[-1]  # https://a.yandex-team.ru/review/1064540/details#comment-1576311

    x_request_id = h.get("X-Request-Id", None)
    serverside.is_training_mode = h.get("X-HBF-Training", None) == 'true'

    msg = "Status code: {}, response length: {} bytes, Real-server: {}, X-Request-Id: {}"
    log.info(msg.format(response.status_code, len(serverside.raw_data), real_server, x_request_id))

    # Y_FW_OUT presence is only verified when output rules were requested.
    check_output = "output" in server_options
    serverside.v4_tables, serverside.v6_tables = parse_remote_rules(serverside.raw_data, check_output)
    return serverside


def parse_remote_rules(data, check_output=False):
    """Split the server answer into IPv4 and IPv6 rulesets and validate them.

    The answer is expected to contain one ``#BEGIN IPTABLES`` /
    ``#END IPTABLES`` section and one ``#BEGIN IP6TABLES`` /
    ``#END IP6TABLES`` section.  Lines starting with "#" are never copied
    into a ruleset: markers drive the section state, lines containing
    "ERROR:" are reported through ``status.crit``, and any other comment is
    ignored.  Non-blank data outside of a section produces a warning.

    :param data: raw text of the server answer.
    :param check_output: also require a ``Y_FW_OUT`` chain in the filter
        table of both rulesets.
    :return: ``(v4_tables, v6_tables)`` as ``IPTables`` objects.
    :raises AgentError: on misplaced or repeated section markers, when a
        section is missing or unterminated, or when a required chain is
        absent.
    """

    def unexpected_error(what):
        # ``n`` is the current line number of the loop below.
        msg = "Unexpected '{}' at line {}.".format(what, n)
        raise AgentError(msg, data)

    v4_lines = []
    v6_lines = []
    current_section = None
    v4_end = False
    v6_end = False
    config = Config()
    use_yandex_iptables = config.use_yandex_iptables()
    for n, line in enumerate(data.splitlines(), 1):
        if line.startswith("#"):
            if "ERROR:" in line:
                msg = "Error in server answer at line {}".format(n)
                log.error(msg + ":\n" + line)
                status.crit(msg)
            elif line == "#BEGIN IPTABLES":
                if current_section is not None or v4_end:
                    unexpected_error("#BEGIN IPTABLES")
                current_section = "v4"
            elif line == "#BEGIN IP6TABLES":
                if current_section is not None or v6_end:
                    unexpected_error("#BEGIN IP6TABLES")
                current_section = "v6"
            elif line == '#END IPTABLES':
                if current_section != "v4" or v4_end:
                    unexpected_error("#END IPTABLES")
                current_section = None
                v4_end = True
            elif line == '#END IP6TABLES':
                # The "already ended" check belongs here, mirroring the v4
                # branch.  It used to sit on the elif itself, which made every
                # comment line after the v6 section raise a bogus error.
                if current_section != "v6" or v6_end:
                    unexpected_error("#END IP6TABLES")
                current_section = None
                v6_end = True
        elif current_section == "v4":
            v4_lines.append(line)
        elif current_section == "v6":
            v6_lines.append(line)
        elif line.strip() == "":
            continue
        else:
            msg = "Unexpected data in remote rules at line {}".format(n)
            log.warning(msg + ".")
            status.warn(msg)

    if not (v4_end and v6_end):
        raise AgentError("Received incomplete data from HBF server.", data)

    # Every collected line is terminated with a newline, as before.
    v4 = "".join(rule + "\n" for rule in v4_lines)
    v6 = "".join(rule + "\n" for rule in v6_lines)

    v4_tables = IPTables("v4", dump=v4, use_yandex_iptables=use_yandex_iptables)
    v6_tables = IPTables("v6", dump=v6, use_yandex_iptables=use_yandex_iptables)

    for tables in (v4_tables, v6_tables):
        t, c, r = tables.count()
        log.info("Remote IP{} rules contain {} tables, {} chains,"
                 " {} rules.".format(tables.ip_version, t, c, r))

    for tables in (v4_tables, v6_tables):
        if "Y_FW" not in tables["filter"]:
            msg = "Remote IP{} tables has no Y_FW rules.".format(
                tables.ip_version
            )
            raise AgentError(msg, str(tables))
        if check_output:
            if "Y_FW_OUT" not in tables["filter"]:
                msg = "Remote IP{} tables has no Y_FW_OUT rules.".format(
                    tables.ip_version
                )
                raise AgentError(msg, str(tables))

    return v4_tables, v6_tables


def run_rule_hook(hook_value, ip_version, host_ips, guest_ips):
    """Run every rule-generating module named in ``hook_value``.

    ``hook_value`` is a comma-separated list of module names under
    ``hbfagent.mod``.  The rules produced by each module are merged into one
    ``IPTables`` object, which is returned.  A module that fails to import or
    run is logged and skipped.
    """
    use_yandex_iptables = Config().use_yandex_iptables()
    rules = IPTables(ip_version, use_yandex_iptables=use_yandex_iptables)
    if not hook_value:
        return rules

    for mod_name in [x.strip() for x in hook_value.split(",")]:
        log.info("Running module '{}' to generate IP{} rules.".format(mod_name, ip_version))
        try:
            mod = import_module("hbfagent.mod." + mod_name)
            mod_rules = mod.run(ip_version, host_ips, guest_ips)
        except Exception:
            log.exception("Exception while running module '{}':".format(mod_name))
            continue

        rules += mod_rules
        t, c, r = mod_rules.count()
        log.debug("Generated {} tables, {} chains, {} rules.".format(t, c, r))
    return rules


def apply(tables, force=False):
    """Apply ``tables`` and report whether it worked.

    Returns True on success.  An ``IPTablesError`` is logged with its
    traceback, reported through ``status.crit`` and turned into False.
    """
    try:
        tables.apply(force=force)
    except IPTablesError as e:
        msg = "Unable to apply IP{} rules: {}".format(tables.ip_version, e)
        log.exception(msg)
        status.crit(msg)
        return False
    return True


def project_id_in_set(address, ips):
    """Tell whether any address in ``ips`` shares the project id of ``address``.

    Returns False when ``address`` has no project id.
    """
    wanted = get_project_id(address)
    if wanted is None:
        return False
    return any(get_project_id(ip) == wanted for ip in ips)


def ips_to_string(ips):
    """Render ``ips`` as a quoted, comma-separated list, or "none" if empty."""
    if not ips:
        return "none"
    return quoted_strings(ips)


def quoted_strings(xs):
    """Join ``str()`` of every item, each wrapped in single quotes, with ", "."""
    quoted = ("'{}'".format(str(x)) for x in xs)
    return ", ".join(quoted)


def dump(rules):
    """Write the configuration and all rulesets to the debug log.

    The root logger is switched to DEBUG for the duration of the dump and
    restored to its previous effective level afterwards, even on errors.
    """
    config = Config()
    root_logger = logging.getLogger()
    saved_level = root_logger.getEffectiveLevel()
    root_logger.setLevel(logging.DEBUG)
    try:
        log.debug("Dumping configuration:\n" + str(config))
        for ip_version in rules:
            family_rules = rules[ip_version]
            for rules_type in family_rules:
                log.debug("Dumping IP{} '{}' rules:".format(
                    ip_version, rules_type
                ))
                tables = family_rules[rules_type]
                if not tables:
                    # Empty or missing ruleset: log its value as is.
                    log.debug(tables)
                    continue
                tables.dump()
    finally:
        root_logger.setLevel(saved_level)


def is_applied(applied):
    """Summarise per-family apply results.

    :param applied: mapping of family name to the result of applying it.
    :return: ``(all_applied, msg)`` where ``msg`` lists every family whose
        result is ``False`` as 'could not apply <family> rules', joined
        with ';' (empty when there are none).
    """
    applied_res = True
    failures = []
    for family, ok in applied.items():
        applied_res &= ok
        if ok is False:
            failures.append('could not apply {} rules'.format(family))
    return applied_res, ';'.join(failures)
