import logging
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Tuple

import ldap
from django.conf import settings


# Module-level logger shared by every helper in this file. It is registered
# under the fixed name 'staff.ldap_helper' rather than being derived from
# __name__.
logger = logging.getLogger('staff.ldap_helper')


class LDAPException(Exception):
    """Error raised by the helpers in this module when an LDAP lookup cannot be completed."""


class LDAPContext(object):
    """Context manager around an authenticated LDAP connection using StartTLS.

    Typical use::

        with LDAPContext(host) as ad:
            ad.search_s(...)

    Entering the context connects to the domain controller, starts TLS,
    performs a simple bind and yields the python-ldap connection object.
    Leaving the context unbinds the connection; exceptions raised inside
    the ``with`` body are never suppressed.
    """

    # Active python-ldap connection object, or None while not connected.
    ad = None

    def __init__(self, dc, user=settings.LDAP_USER, password=settings.LDAP_PASSWORD):
        """Remember the connection parameters; no network access happens here.

        :param dc: host name of the domain controller (used to build the URI).
        :param user: bind DN / user name passed to ``simple_bind_s``.
        :param password: password passed to ``simple_bind_s``.

        Note that the default credentials are read from the Django settings
        once, when this class is defined, not on every instantiation.
        """
        self.dc = dc
        self.user = user
        self.password = password

    @property
    def uri(self):
        """Return the ``ldap://`` URI of the configured domain controller."""
        return 'ldap://%s' % self.dc

    def bind(self):
        """Open the connection, upgrade it with StartTLS and authenticate.

        :returns: the bound python-ldap connection object (also stored in
            ``self.ad``).
        :raises ldap.LDAPError: if configuring, securing or binding the
            connection fails. In that case the half-open connection is
            released before the error is re-raised.
        """
        self.ad = ldap.initialize(self.uri)

        try:
            # NOTE(review): 255 looks like maximum library debug output --
            # confirm this is wanted outside of troubleshooting.
            self.ad.set_option(ldap.OPT_DEBUG_LEVEL, 255)
            self.ad.set_option(ldap.OPT_NETWORK_TIMEOUT, 1)
            self.ad.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
            self.ad.set_option(ldap.OPT_REFERRALS, 0)
            self.ad.set_option(ldap.OPT_TIMEOUT, 1)
            self.ad.set_option(ldap.OPT_X_TLS, ldap.OPT_X_TLS_DEMAND)
            self.ad.set_option(ldap.OPT_X_TLS_CACERTFILE, settings.CA_BUNDLE)
            self.ad.set_option(ldap.OPT_X_TLS_DEMAND, True)
            # self.ad.set_option(ldap.OPT_X_TLS_NEWCTX, 0)
            # NOTE(review): OPT_X_TLS_NEVER presumably turns off verification
            # of the server certificate even though a CA bundle is configured
            # above -- confirm this is intended.
            self.ad.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)

            self.ad.start_tls_s()

            self.ad.simple_bind_s(self.user, self.password)
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so the connection
            # has to be released here or it would be leaked.
            self.unbind()
            raise
        return self.ad

    def unbind(self):
        """Close the connection if there is one.

        LDAP errors during the unbind are logged as warnings and not
        propagated. Calling this without an open connection is a no-op.
        """
        if self.ad is None:
            return
        try:
            self.ad.unbind_s()
        except ldap.LDAPError:
            logger.warning(
                'Can\'t disconect from domain controller %s', self.dc
            )
        finally:
            # Drop the handle so a closed connection is never reused.
            self.ad = None

    def __enter__(self):
        """Bind and return the connection object for use in the ``with`` body."""
        return self.bind()

    def __exit__(self, type, value, traceback):
        """Unbind on exit; an exception from the ``with`` body keeps propagating."""
        self.unbind()
        # A falsy return value tells the interpreter to re-raise the in-flight
        # exception itself. A bare ``raise`` here only works while an exception
        # is actively being handled and fails with RuntimeError otherwise.
        return False


def _ldap_search(search_dn: Tuple[str, ...], query: str, fields: Iterable[str], scope: int = None) -> List[Tuple]:
    """Run an LDAP search against ``settings.LDAP_HOST`` and return the first non-empty result.

    :param search_dn: one base DN, or a list/tuple of base DNs that are tried
        in order.
    :param query: LDAP filter; converted with ``str()`` before use.
    :param fields: names of the attributes to fetch; each is converted with
        ``str()``.
    :param scope: python-ldap search scope; ``ldap.SCOPE_SUBTREE`` when None.
    :returns: whatever ``search_s`` returned for the first base DN that
        produced a non-empty result.
    :raises LDAPException: after five attempts that each either failed with
        ``ldap.LDAPError`` / ``TypeError`` or found nothing under every base
        DN. The most recent caught error, if any, is attached as the cause.
    """
    if not isinstance(search_dn, (list, tuple)):
        search_dn = [search_dn]

    if scope is None:
        scope = ldap.SCOPE_SUBTREE

    # Convert the arguments once. ``fields`` may be a one-shot iterator, and
    # re-reading it for every base DN / retry would silently request an empty
    # attribute list after the first search.
    try:
        filter_str = str(query)
        attributes = [str(name) for name in fields]
    except TypeError as error:
        # Bad arguments ended in LDAPException before as well (TypeError was
        # swallowed by the retry loop); keep that contract but fail at once.
        raise LDAPException('Connection problem') from error

    retries = 5
    last_error = None
    for _ in range(retries):
        try:
            # A fresh connection is opened for every attempt.
            with LDAPContext(settings.LDAP_HOST) as ad:
                for basedn in search_dn:
                    result = ad.search_s(basedn, scope, filter_str, attributes)
                    if result:
                        return result
        except (ldap.LDAPError, TypeError) as error:
            last_error = error
            # exc_info keeps the actual failure reason in the log.
            logger.warning(
                'Can\'t search in domain controller %s', settings.LDAP_HOST,
                exc_info=True,
            )

    raise LDAPException('Connection problem') from last_error


def _decode_bytes(maybe_bytes):
    """Recursively convert UTF-8 ``bytes`` to ``str``.

    * ``bytes`` are decoded as UTF-8; if that fails the original bytes are
      returned unchanged.
    * lists and tuples are converted element by element and always come
      back as a ``list``.
    * every other value is returned as is.
    """
    if isinstance(maybe_bytes, (list, tuple)):
        return list(map(_decode_bytes, maybe_bytes))

    if not isinstance(maybe_bytes, bytes):
        return maybe_bytes

    try:
        decoded = maybe_bytes.decode('utf-8')
    except UnicodeDecodeError:
        # Not valid UTF-8: hand the raw bytes back untouched.
        return maybe_bytes
    return decoded


def encode_str(maybe_str):
    """Recursively convert ``str`` to UTF-8 ``bytes``.

    * ``str`` values are encoded as UTF-8; if that fails the original string
      is returned unchanged.
    * lists and tuples are converted element by element and always come
      back as a ``list``.
    * every other value is returned as is.
    """
    if isinstance(maybe_str, (list, tuple)):
        return list(map(encode_str, maybe_str))

    if not isinstance(maybe_str, str):
        return maybe_str

    try:
        encoded = maybe_str.encode('utf-8')
    except UnicodeEncodeError:
        # Cannot be represented in UTF-8: hand the string back untouched.
        return maybe_str
    return encoded


def ldap_search(
    search_dn: Tuple[str, ...],
    query: str,
    fields_map: Dict[str, callable],
    scope: int = None,
) -> Dict[str, Any]:
    """Search LDAP and build a dict from the first entry found.

    The keys of ``fields_map`` are the attribute names requested from the
    directory. For every attribute, the value found in the first result
    entry (``None`` when the attribute is absent) is passed through
    ``_decode_bytes`` and then handed to the callable mapped to that
    attribute. Each callable has to return a ``(key, value)`` pair; the
    pairs are collected into the returned dict.

    ``search_dn``, ``query`` and ``scope`` are forwarded to ``_ldap_search``
    unchanged, and any exception it raises propagates to the caller.
    """
    entries = _ldap_search(search_dn, query, fields_map.keys(), scope)
    # Only the first entry is used; index 1 holds its attribute mapping.
    attributes = entries[0][1]
    pairs = (
        convert(_decode_bytes(attributes.get(attr_name)))
        for attr_name, convert in fields_map.items()
    )
    return dict(pairs)


# Largest Unix timestamp that from_ad_time() will still convert: one day
# short of datetime.max.
_MAX_VALUE = (
    datetime.max - timedelta(days=1)
).timestamp()


def from_ad_time(ad_time):
    """Convert an Active Directory timestamp into a ``datetime``.

    Active Directory attributes such as ``pwdLastSet`` store time as a large
    integer counting 100-nanosecond intervals since 1 January 1601, whereas
    Unix epoch time counts seconds since 1 January 1970.

    The conversion is done in two steps:
        * divide the value by 10,000,000 to get whole seconds
          (any fraction of a second is discarded);
        * subtract 11,644,473,600, the number of seconds between
          1 January 1601 and 1 January 1970.

    ``ad_time`` may be anything ``int()`` accepts. The result is the naive
    ``datetime`` produced by ``datetime.fromtimestamp``, or ``None`` when the
    timestamp is larger than ``_MAX_VALUE``.
    """
    seconds_since_1601 = int(ad_time) // 10000000
    unix_seconds = seconds_since_1601 - 11644473600

    if unix_seconds <= _MAX_VALUE:
        return datetime.fromtimestamp(unix_seconds)
    return None
