import logging
import re
from collections import defaultdict
from datetime import datetime
from hashlib import sha256

import six
from mongoengine import StringField, ListField, LongField

from sepelib.mongo.util import register_model
from walle.clients import hbf, juggler
from walle.expert.constants import HW_WATCHER_CHECK_MAX_POSSIBLE_DELAY
from walle.models import Document, timestamp
from walle.trypo_radix import TRYPOCompatibleRadix

# Module-level logger named after this module; used below to record drill additions, updates and deletions.
log = logging.getLogger(__name__)


@register_model
class HbfDrill(Document):
    """
    Persistent record of one HBF drill: a project macro (with its IPs and exclusions) in a datacenter,
    bounded by start and end timestamps.
    """

    id = StringField(primary_key=True, help_text="Drill object synthetic primary key")
    project_macro = StringField(required=True, help_text="Macro of project undergoing the drills")
    project_ips = ListField(StringField(), required=True, help_text="IPs of macro")
    location = StringField(required=True, help_text="Name of datacenter")
    start_ts = LongField(required=True, help_text="Timestamp of drill start")
    end_ts = LongField(required=True, help_text="Timestamp of drill end")
    exclude_ips = ListField(StringField(), help_text="IPs excluded from the drill")
    obj_hash = StringField(required=True, help_text="Hash of an object, used for updating and caching")

    # fields that take part in obj_hash calculation (see set_hash)
    api_fields = ("id", "project_macro", "project_ips", "location", "start_ts", "end_ts", "exclude_ips")

    def mk_id(self):
        """
        :returns synthetic id (str) composed of location, project macro, start and end timestamps
        """
        key_parts = (self.location, self.project_macro, self.start_ts, self.end_ts)
        return "{}|{}|{}|{}".format(*key_parts)

    def __repr__(self):
        started = _ts_repr(self.start_ts)
        ended = _ts_repr(self.end_ts)
        obsolete = _ts_repr(self.obsoletion_ts)
        return "{}|{}|{}|{} (effective until {})".format(
            self.location, self.project_macro, started, ended, obsolete
        )

    def set_hash(self):
        """
        Recalculates obj_hash as sha256 over the "|"-joined string forms of all api_fields
        (list values are rendered as comma-separated items)
        """

        def stringify(value):
            return ",".join(str(item) for item in value) if isinstance(value, list) else str(value)

        rendered = [stringify(getattr(self, name)) for name in self.api_fields]
        payload = "|".join(rendered)
        self.obj_hash = sha256(six.ensure_binary(payload, "utf-8")).hexdigest()

    @property
    def obsoletion_ts(self):
        # the drill stays applicable for a while after its official end: host health data needs time
        # to come back once the network is open again
        grace_period = HW_WATCHER_CHECK_MAX_POSSIBLE_DELAY
        return self.end_ts + grace_period


def _process_hbf_drills_wrapper():
    """
    Runs _process_hbf_drills() inside juggler.exception_monitor for the "wall-e-hbf-drills-update" check
    (the monitor is created with reraise=False)
    """
    monitor = juggler.exception_monitor(
        "wall-e-hbf-drills-update",
        err_msg_tmpl="HBF drills update failed: {exc}",
        reraise=False,
    )
    with monitor:
        _process_hbf_drills()


def _process_hbf_drills():
    """
    Synchronizes stored HbfDrill documents with the drills currently returned by _get_incoming_drills():
    saves new ones, replaces those whose hash has changed and handles the ones that have vanished.
    """
    incoming_drills = _get_incoming_drills()

    stored_drills = {}
    for stored in HbfDrill.objects.only("id", "obj_hash", "start_ts", "end_ts"):
        stored_drills[stored.id] = stored

    incoming_ids = set(incoming_drills)
    stored_ids = set(stored_drills)

    # drills we have not seen before
    for new_id in incoming_ids - stored_ids:
        log.info("Adding new HBF drill %s", new_id)
        incoming_drills[new_id].save()

    # drills known on both sides: replace the stored document only when the hash differs
    for shared_id in incoming_ids.intersection(stored_ids):
        fresh = incoming_drills[shared_id]
        stale = stored_drills[shared_id]
        if fresh.obj_hash == stale.obj_hash:
            continue
        log.info("Updating HBF drill %s because its hash has changed", shared_id)
        stale.delete()
        fresh.save()

    # drills that are stored but are no longer reported:
    # * vanished before its start or after obsoletion_ts (complete end) -- the drill is simply deleted
    # * vanished while in progress -- hosts' health data must arrive before the drill object may go away.
    #   This is a cron job, so deletion can not be scheduled for a later moment; instead end_ts is set to the
    #   current timestamp and the drill gets deleted by the rule above once the health data update period passes.
    # * vanished between end_ts and obsoletion_ts -- the drill is already waiting for hosts' health info,
    #   moving end_ts to the current timestamp would only make us wait longer, so it is left untouched.
    for vanished_id in stored_ids - incoming_ids:
        drill = stored_drills[vanished_id]
        now = timestamp()

        if not drill.start_ts <= now <= drill.obsoletion_ts:
            log.info("Deleting vanished HBF drill %s", vanished_id)
            drill.delete()
            continue

        if now >= drill.end_ts:
            continue

        # stored_drills holds partially loaded documents, so fetch the full one before changing and saving it
        drill = HbfDrill.objects(id=vanished_id).get()
        log.info("Setting end_ts of %s to %s", repr(drill), now)
        drill.end_ts = now
        drill.set_hash()
        drill.save()


def _get_incoming_drills():
    """
    :returns dict mapping drill id to an unsaved HbfDrill built from every item of hbf.get_hbf_drills()
             (for repeated ids the last item wins)
    """
    built_drills = (_make_drill_obj(raw_drill) for raw_drill in hbf.get_hbf_drills())
    return {built.id: built for built in built_drills}


def _make_drill_obj(raw):
    """
    Builds an unsaved HbfDrill from one raw drill description.

    :param raw: mapping; the keys read here are "project_ips", "exclude_ips", "project", "location",
                "begin" and "duration"
    :returns HbfDrill with id and obj_hash already filled in
    :raises ValueError: when raw["location"] does not match hbf.RT_DATACENTER_RE
    """
    drill = HbfDrill(project_ips=raw["project_ips"], exclude_ips=raw["exclude_ips"])
    drill.project_macro = raw["project"]

    # re.search returns None when nothing matches; fail with a message that names the offending location
    # instead of an opaque AttributeError on None
    location_match = re.search(hbf.RT_DATACENTER_RE, raw["location"])
    if location_match is None:
        raise ValueError(
            "Can not extract datacenter name from HBF drill location {!r}".format(raw["location"])
        )
    drill.location = location_match.group(1).lower()

    drill.start_ts = raw["begin"]
    drill.end_ts = drill.start_ts + raw["duration"]

    # id and hash are derived from the fields above, so they have to be computed last
    drill.id = drill.mk_id()
    drill.set_hash()

    return drill


def _ts_repr(ts):
    return datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M:%S")


class DrillHostMatcher:
    """
    Wraps HbfDrill data with the ability to check if host's ip addrs match the drill (checks both project IPs
    and exclusions)
    """

    def __init__(self, drill):
        """
        :type drill: HbfDrill
        """
        self.drill = drill

        # one radix structure for networks involved in the drill, another one for the exclusions
        self.project_ips_matcher = TRYPOCompatibleRadix()
        for ip in drill.project_ips:
            self.project_ips_matcher.add(ip)

        self.excluded_ips_matcher = TRYPOCompatibleRadix()
        for ip in drill.exclude_ips:
            self.excluded_ips_matcher.add(ip)

    def match_host(self, host):
        """
        :type host: Host
        :returns match_reason (str or None)
        """
        # we have to think of a host as involved in a drill even after drill's end
        # because otherwise host healing will start right after drill is finished (because there won't be actual
        # health data for the host)
        if not self.drill.start_ts <= timestamp() <= self.drill.obsoletion_ts:
            return None

        for ip in host.ips:
            match_node = self.project_ips_matcher.search_best(ip)
            if match_node:
                exclude_node = self.excluded_ips_matcher.search_best(ip)
                if exclude_node:
                    # NOTE(review): an excluded IP stops the scan, so the remaining IPs of the host are not
                    # examined -- confirm that `break` (rather than `continue`) is intended here
                    break
                reason = "host's IP addr {} belongs to network {} involved in HBF drill {!r}".format(
                    ip, TRYPOCompatibleRadix.node_repr(match_node), self.drill
                )
                return reason

        return None

    def __repr__(self):
        # use the bare class name (not the "<class '...'>" repr of the class object) and the drill's repr,
        # the same conversion match_host uses when it mentions the drill
        return "<{}: {!r}>".format(type(self).__name__, self.drill)


def get_hbf_drills():
    """
    :returns list of all HbfDrill documents, ordered by id
    """
    ordered_drills = HbfDrill.objects().order_by("id")
    return list(ordered_drills)


class HbfDrillsCollection:
    """
    Holds a DrillHostMatcher for every given HbfDrill, grouped by the drill's location, and answers whether
    a host takes part in any of these drills
    """

    def __init__(self, drills):
        """
        :param drills: iterable of HbfDrill
        """
        grouped = defaultdict(list)
        for drill in drills:
            wrapped = DrillHostMatcher(drill)
            grouped[drill.location].append(wrapped)
        self.dc_to_drills = grouped

    def get_host_inclusion_reason(self, host):
        """
        :type host: Host
        :returns reason (str) given by the first matcher of the host's datacenter that matches the host,
                 or None if the host has no IPs, its datacenter has no drills or nothing matches
        """
        host_dc = host.location.short_datacenter_name
        if not host.ips:
            return None

        # .get() does not trigger defaultdict's factory, so lookups never add empty entries
        dc_matchers = self.dc_to_drills.get(host_dc)
        if dc_matchers is None:
            return None

        reasons = (dc_matcher.match_host(host) for dc_matcher in dc_matchers)
        return next((reason for reason in reasons if reason), None)


class HbfDrillsCollectionCache:
    """
    Keeps the last built HbfDrillsCollection and rebuilds it only when the set of obj_hash values found in DB
    differs from the set seen at the time of the previous build
    """

    def __init__(self):
        self.drills_hashes = None
        self.collection = None

    def get(self):
        """
        :returns HbfDrillsCollection matching the drills' hashes currently found in DB
        """
        current_hashes = self._get_hashes_in_db()
        if self.drills_hashes == current_hashes:
            return self.collection

        # the hash set is remembered only after the collection has been rebuilt successfully
        self.collection = HbfDrillsCollection(get_hbf_drills())
        self.drills_hashes = current_hashes
        return self.collection

    def _get_hashes_in_db(self):
        """
        :returns set of distinct obj_hash values of stored HbfDrill documents
        """
        drills_collection = HbfDrill.get_collection()
        distinct_hashes = drills_collection.distinct(HbfDrill.obj_hash.db_field)
        return set(distinct_hashes)


# Module-level HbfDrillsCollectionCache instance, created once when this module is imported.
drills_cache = HbfDrillsCollectionCache()
