# -*- encoding: utf-8 -*-

import logging
import os
import pickle
import re
import socket
import sys
from datetime import datetime

logger = logging.getLogger(__name__)

try:
    import pandas
    import sklearn  # noqa: F401
    from catboost import CatBoostClassifier, Pool, cv  # noqa: F401
except ImportError as e:
    logger.error(f"Can't import ML-related libraries {str(e)}")

# `requests` honours the REQUESTS_CA_BUNDLE environment variable as the CA
# bundle used for TLS verification; point it at this certificate file.
os.environ["REQUESTS_CA_BUNDLE"] = "/etc/ssl/certs/ca-certificates.crt"

import requests.packages.urllib3.util.connection as urllib3_cn

# Extracts the rule/task id from a Puncher task link embedded in free text.
# The dots of the host name are escaped so that they only match a literal ".".
PUNCHER_ID = re.compile(r"https://puncher\.yandex-team\.ru/tasks\?id=(\w+)")


# prefer IPv6 for outgoing connections, falling back to IPv4
def allowed_gai_family():
    """Return the socket address family to use for name resolution.

    Returns:
        ``socket.AF_INET6`` when ``urllib3_cn.HAS_IPV6`` is truthy,
        otherwise ``socket.AF_INET``.
    """
    return socket.AF_INET6 if urllib3_cn.HAS_IPV6 else socket.AF_INET


# Replace urllib3's address-family hook with the function defined above
# (module-level monkey-patch, applied at import time).
urllib3_cn.allowed_gai_family = allowed_gai_family

# TODO: move to config
# Logins treated as "SIB people" by the classifier (rules they author or are
# responsible for are skipped). Kept as an explicit list literal, one login per
# line, so that a missing separator cannot silently fuse two logins into one.
SIB_PEOPLE = [
    "ezaitov",
    "buglloc",
    "tokza",
    "a-abakumov",
    "ybuchnev",
    "aleksei-m",
    "shikari",
    "energen",
    "horus",
    "irmin",
    "anton-k",
    "gots",
    "kaleda",
    "axlodin",
    "limetime",
    "melkikh",
    "naumov-al",
    "procenkoeg",
]

# TODO: move to config
# Authors whose rules are excluded from feature extraction.
BAD_AUTHORS = ["smikler"] + SIB_PEOPLE

# TODO: move to config
# Only tickets assigned to one of these logins (or unassigned) are processed.
CORE_APPROVERS = ["ezaitov", "horus", "a-abakumov", "anton-k"]


class PuncherClassifier(object):
    """Build feature dicts for Puncher rules and score them with a pickled model.

    Feature extraction pulls data from the injected service clients
    (``puncher``, ``startrek``, ``staff``, ``abc``, ``waffles``).  Every
    ``get_*_features`` helper returns a plain dict; a dict containing
    ``{"need_skip": True}`` means "this rule must not be classified" and makes
    ``get_rule_features`` return ``{}``.

    NOTE(review): the column order handed to the model is the insertion order
    of the feature dict (see ``prepare_rule_data``), so keys must keep being
    inserted in the same order the model was trained with — presumably the
    order used here; verify against the training pipeline before reordering.
    """

    # TODO: move to config
    BOOTCAMP_BOSSES = ["zzhanka"]
    # Keys carried in the feature dict that are never passed to the model.
    NOT_FEATURES = ["status", "id"]

    # Ports that do not count as "unspecified" (HTTP, HTTPS, SSH).
    _COMMON_PORTS = (80, 443, 22)
    # Something that looks like a tracker ticket key, e.g. "ABC-123".
    _TICKET_KEY = re.compile(r"[A-Z]+-[0-9]+")

    def __init__(
        self, root_path, puncher, tvm, startrek, staff, abc, waffles, bad_ports
    ):
        """Store the service clients and try to load the model.

        The model is read from ``<root_path>/ml/catboost_model.pickle``.
        Loading is best-effort: when pandas was not imported or the file
        cannot be loaded, the error is logged and ``catboost_model`` stays
        ``None`` (``predict``/``predict_proba`` then return ``None``).

        Args:
            root_path: directory that contains the ``ml`` sub-directory.
            puncher: client providing ``get_rule`` and ``port_parse``.
            tvm: stored as-is; not used by this class.
            startrek: client exposing ``issues[key]``.
            staff: client used for people/group lookups.
            abc: client providing ``get_people_by_id``.
            waffles: client providing ``do_resolve`` / ``get_virtual_hosts``.
            bad_ports: iterable of ports flagged by ``has_bad_ports``.
        """
        self.puncher = puncher
        self.startrek = startrek
        self.staff = staff
        self.abc = abc
        self.tvm = tvm
        self.waffles = waffles
        self.bad_ports = bad_ports
        self.catboost_model = None
        if "pandas" not in sys.modules:
            logger.error("Pandas not found")
            return
        model_path = os.path.join(root_path, "ml", "catboost_model.pickle")
        try:
            with open(model_path, "rb") as fd:
                # NOTE(security): unpickling executes arbitrary code; the model
                # file must only ever come from a trusted location.
                self.catboost_model = pickle.load(fd)
        except Exception as e:
            # Deliberately broad: a missing/broken model must not prevent
            # the object from being constructed.
            logger.error("Classifier model loading error: %s", e)

    def get_rule_features_by_ticket(self, issue_key, rule_id=""):
        """Return rule features for the rule referenced by a tracker ticket.

        Args:
            issue_key: key of the ticket to look up in ``startrek``.
            rule_id: rule id to use; when empty, the first Puncher task link
                found in the ticket description is used.

        Returns:
            The dict from ``get_rule_features``, or ``{}`` when the ticket is
            tagged ``noml``, is assigned to someone outside ``CORE_APPROVERS``,
            or no rule id could be determined.
        """
        issue = self.startrek.issues[issue_key]
        if "noml" in issue.tags:
            return {}
        # temporary
        if issue.assignee and issue.assignee.login not in CORE_APPROVERS:
            return {}
        if not rule_id:
            # A ticket may have no description at all.
            found = PUNCHER_ID.search(issue.description or "")
            if found is None:
                return {}
            rule_id = found.group(1)
        rule = self.puncher.get_rule(rule_id)
        return self.get_rule_features(rule)

    def _has_bootcamp_boss(self, login):
        """Return True when any chief of *login* is listed in BOOTCAMP_BOSSES."""
        chiefs = self.staff.get_person_chief_list(login) or []
        # NOTE(review): presumably a list of logins; a bare string is
        # tolerated so that a single login is not iterated char by char.
        if isinstance(chiefs, str):
            chiefs = [chiefs]
        return any(chief in self.BOOTCAMP_BOSSES for chief in chiefs)

    def get_src_features(self, src):
        """Describe the sources of a rule.

        Args:
            src: iterable of source dicts (keys read here: ``type``,
                ``machine_name``, ``title``, ``url``, ``external``).

        Returns:
            Feature dict, or ``{"need_skip": True}`` when a source is the
            security department or has type ``any``.
        """
        features = {
            "has_zombies": 0,
            "has_personal": 0,
            "has_service": 0,
            "has_servicerole": 0,
            "has_department": 0,
            "has_allstaff": 0,
            "has_wikigroup": 0,
            "size": len(src),
            "has_external": 0,
            "actual_people": 0,
            "personal_people": 0,
            "has_unusual_bosses": 0,
        }
        for item in src:
            item_type = item.get("type")
            # we skip rules from SIB for now b/c they are bad for ML
            if item.get("machine_name", "") == "@dpt_yandex_mnt_security@":
                return {"need_skip": True}
            # we skip rules from any for now b/c they are bad for ML
            if item_type == "any":
                return {"need_skip": True}
            if item_type == "robot":
                features["has_zombies"] = 1
            if item_type == "user":
                features["has_personal"] = 1
                features["actual_people"] += 1
                features["personal_people"] += 1
                # [1:-1] drops the first and last character of machine_name
                # (presumably "@login@" wrappers — verify). The staff lookup
                # is skipped once the flag is already set.
                if not features["has_unusual_bosses"] and self._has_bootcamp_boss(
                    item.get("machine_name")[1:-1]
                ):
                    features["has_unusual_bosses"] = 1
            title = item.get("title") or {}
            if title.get("en", "") == "@allstaff@":
                features["has_allstaff"] = 1
                # Fixed stand-in head count for "everyone".
                features["actual_people"] = 10000
            if item_type == "servicerole":
                features["has_servicerole"] = 1
                # [5:-1] strips a 5-char prefix and the trailing character
                # (presumably "@svc_" ... "@" — verify).
                features["actual_people"] += len(
                    self.staff.get_group_members(
                        group_url=item["machine_name"][5:-1],
                        group_type="servicerole",
                        with_subgroups=False,
                        limit=1000,
                    )
                )
            if item_type == "service":
                features["has_service"] = 1
                # [36:] drops a fixed-length URL prefix, leaving the service
                # id (assumes a stable URL layout — TODO confirm).
                features["actual_people"] += len(
                    self.abc.get_people_by_id(service_id=item["url"][36:])
                )
            if item_type == "department":
                features["has_department"] = 1
                # "@dpt_<url>@" -> "<url>"
                features["actual_people"] += len(
                    self.staff.get_group_members(
                        group_type="department",
                        group_url=item["machine_name"][5:-1],
                        limit=1000,
                    )
                )
            if item_type == "wiki":
                features["has_wikigroup"] = 1
                # [6:-1] strips a 6-char prefix and the trailing character.
                features["actual_people"] += len(
                    self.staff.get_group_members(
                        group_url=item["machine_name"][6:-1],
                        group_type="wiki",
                        with_subgroups=False,
                        limit=1000,
                    )
                )
            if item.get("external"):
                features["has_external"] = 1
        return features

    def get_approver_features(self, issue_key):
        """Return approver-related features for the ticket *issue_key*.

        No approver features are produced yet: the result is always an empty
        dict.  The ticket is still fetched from ``startrek``, so a failing
        lookup propagates to the caller.
        """
        features = dict()
        issue = self.startrek.issues[issue_key]
        if not issue.assignee:
            return features
        return features

    def get_dst_features(self, dst):
        """Describe the destinations of a rule.

        Args:
            dst: iterable of destination dicts (keys read here: ``type``,
                ``machine_name``).

        Returns:
            Feature dict, or ``{"need_skip": True}`` when any destination has
            type ``inet``.
        """
        features = {
            "has_vhost": 0,
            "size": len(dst),
            "has_hostmacro": 0,
            "has_macro": 0,
            "vhost_count": 0,
        }
        for item in dst:
            item_type = item.get("type")
            # we skip rules to internet for now b/c they are bad for ML
            if item_type == "inet":
                return {"need_skip": True}
            if item_type == "hostmacro":
                # Must be the same key that is initialised above; a differently
                # spelled key would add an extra column to the feature vector.
                features["has_hostmacro"] = 1
            if item_type == "macro":
                features["has_macro"] = 1
            if item_type == "virtualservice":
                features["has_vhost"] = 1
                ip = self.waffles.do_resolve(item.get("machine_name"))
                if not ip:
                    # Unresolvable host contributes no virtual hosts.
                    continue
                features["vhost_count"] += len(self.waffles.get_virtual_hosts(ip))
        return features

    @staticmethod
    def _experience_category(days_since_join):
        """Bucket tenure in days: <90 -> 0, 90..365 -> 1, 366..729 -> 2, else 3."""
        if days_since_join < 90:
            return 0
        if days_since_join <= 365:
            return 1
        if days_since_join < 730:
            return 2
        return 3

    def get_author_features(self, author, issue=None):
        """Describe the author of a rule.

        Args:
            author: dict with at least ``login``; ``external`` is optional.
            issue: unused.

        Returns:
            Feature dict, or ``{"need_skip": True}`` when the author is in
            ``BAD_AUTHORS`` or belongs to the security group.

        Raises:
            TypeError / ValueError: from ``datetime.strptime`` when staff
                returns no join date or one not formatted as ``%Y-%m-%d``.
        """
        username = author.get("login")
        features = {"external": 0, "is_homeworker": 0, "office_id": 0}
        # SIB generates strange rules which are bad for ML
        if username in BAD_AUTHORS or self.staff.is_person_from_security_group(
            username
        ):
            return {"need_skip": True}
        if author.get("external"):
            features["external"] = 1
        if self.staff.is_homeworker(username):
            features["is_homeworker"] = 1
        office_id = self.staff.get_person_info(username, field="location.office.id")
        if office_id:
            features["office_id"] = office_id
        features["bosses_count"] = len(self.staff.get_person_chief_list(username))
        join_at = self.staff.get_person_info(username, field="official.join_at")
        join_datetime = datetime.strptime(join_at, "%Y-%m-%d")
        days_since_join = (datetime.now() - join_datetime).days
        features["exp_category"] = self._experience_category(days_since_join)
        return features

    def get_resp_features(self, resp):
        """Describe the responsibles of a rule.

        Args:
            resp: iterable of dicts with optional ``login`` / ``external``.

        Returns:
            ``{"has_external": 0|1}``, or ``{"need_skip": True}`` when any
            responsible is in ``SIB_PEOPLE``.
        """
        features = {"has_external": 0}
        # SIB generates strange rules which are bad for ML
        for item in resp:
            if item.get("login") in SIB_PEOPLE:
                return {"need_skip": True}
            if item.get("external"):
                features["has_external"] = 1
        return features

    def get_rule_features(self, rule):
        """Build the complete feature dict for *rule*.

        Args:
            rule: rule dict as returned by ``puncher.get_rule``.

        Returns:
            Feature dict (including the bookkeeping keys ``id`` and
            ``status``), or ``{}`` when the rule is static or any feature
            group asks to skip it.
        """
        # skip static rules for now
        if rule.get("system") == "static":
            return {}

        # Missing/None collections are treated as empty instead of crashing.
        comment = rule.get("comment") or ""
        ports = []
        for port in rule.get("ports") or []:
            ports.extend(self.puncher.port_parse(port))

        features = dict()
        features["id"] = rule.get("id")
        features["rule_type"] = 1 if rule.get("system") == "puncher" else 0
        features["status"] = 1 if rule.get("status") == "approved" else 0
        features["proto"] = 1 if rule.get("protocol") == "tcp" else 0
        features["ports_count"] = len(ports)
        # Membership test, not truthiness of the port value itself.
        features["has_unspecified_ports"] = (
            1 if any(p not in self._COMMON_PORTS for p in ports) else 0
        )
        features["has_https"] = 1 if 443 in ports else 0
        features["has_http"] = 1 if 80 in ports else 0
        features["has_ssh"] = 1 if 22 in ports else 0
        features["has_bad_ports"] = (
            1 if any(p in ports for p in self.bad_ports) else 0
        )
        features["temporary"] = 1 if rule.get("until") else 0
        features["comment_len"] = len(comment)
        features["comment_has_ticket"] = 1 if self._TICKET_KEY.search(comment) else 0

        # Evaluated lazily and in this order, so a skip in an early group
        # avoids the service calls made by the later ones.
        feature_groups = (
            ("approver", lambda: self.get_approver_features(rule.get("task"))),
            ("author", lambda: self.get_author_features(rule.get("author"))),
            ("src", lambda: self.get_src_features(rule.get("sources") or [])),
            ("resp", lambda: self.get_resp_features(rule.get("responsibles") or [])),
            ("dst", lambda: self.get_dst_features(rule.get("destinations") or [])),
        )
        for prefix, compute in feature_groups:
            group = compute()
            if group.get("need_skip"):
                return {}
            for key, value in group.items():
                features["{}_{}".format(prefix, key)] = value
        return features

    def prepare_rule_data(self, rule_features):
        """Turn a feature dict into a single-row 2-D array for the model.

        Keys listed in ``NOT_FEATURES`` are dropped, missing values become 0
        and the column order follows the dict's insertion order.
        """
        feature_keys = [k for k in rule_features if k not in self.NOT_FEATURES]
        rule_data = pandas.DataFrame([rule_features])
        return rule_data[feature_keys].fillna(0).values

    def _apply_model(self, method_name, rule_features):
        """Call ``catboost_model.<method_name>`` on the prepared features.

        Returns ``None`` when there is no model, no features, or the model
        call raises (the error is logged).
        """
        if not self.catboost_model or not rule_features:
            return None
        rule_data = self.prepare_rule_data(rule_features)
        try:
            return getattr(self.catboost_model, method_name)(rule_data)
        except Exception as e:
            logger.error("Rule classification error: %s", e)
            return None

    def predict(self, rule_features):
        """Return the model's ``predict`` output for *rule_features*, or None."""
        return self._apply_model("predict", rule_features)

    def predict_proba(self, rule_features):
        """Return the model's ``predict_proba`` output for *rule_features*, or None."""
        return self._apply_model("predict_proba", rule_features)
