import contextlib
import os
import re
import sys
import random
import requests
import threading
import traceback

try:
    from api.srvmngr import getRoot
except ImportError as ex:
    from srvmngr.utils import getRoot

from kernel.util.sys.user import userPrivileges as user_privileges  # noqa


# TODO: fixme in cygwin python
def _chown(path, uid, gid):
    os.chown(path.replace('/', '\\') if sys.platform == 'cygwin' else path, uid, gid)


class CauthKeyUpdater(object):
    """Background updater that mirrors CAuth key lists into a directory tree.

    A daemon thread periodically downloads the user and admin key lists from
    the CAuth HTTP API, writes every key into
    ``<root>/etc/auth/<user>/<user>.cauth.NN.public`` and removes stale
    ``*.cauth.*`` files that are no longer served.

    The instance is callable and usable as a context manager::

        with updater():
            ...  # updater thread is running here
    """

    CAUTH_API_URL_USERS = 'https://cauth.yandex.net:4443/userkeys/'
    CAUTH_API_URL_ADMINS = 'https://cauth.yandex.net:4443/adminkeys/'

    # Local configuration file scanned for the ``sources = "..."`` line.
    CAUTH_CONFIG_PATH = '/etc/cauth/cauth.conf'

    # Timeout in seconds passed to ``requests.get``.
    CAUTH_REQUEST_TIMEOUT = 10
    # Arguments for ``random.randrange``: pause between update cycles, seconds.
    UPDATE_EVERY = (600, 1200)

    SOURCES_LINE = re.compile(r'^\s*sources\s*=\s*"(.*&?)"', re.MULTILINE)

    def __init__(self, log):
        """Create an idle updater logging to the ``cauth`` child of ``log``."""
        self.log = log.getChild('cauth')

        self.thr = None
        self.stop_ev = None
        self.allowed_sources = None

        # Per endpoint: [md5 of the last response, keys parsed from it].
        self.cache = {
            'userkeys': [None, None],
            'adminkeys': [None, None],
        }

    def start(self):
        """Start the daemon thread running the update loop."""
        self.stop_ev = threading.Event()

        self.thr = threading.Thread(target=self._updater)
        self.thr.daemon = True
        self.thr.start()

    def stop(self):
        """Signal the update loop to stop and wait for the thread to exit.

        Does nothing when the updater was never started.
        """
        if self.thr:
            self.stop_ev.set()
            self.thr.join()

    @contextlib.contextmanager
    def __call__(self):
        """Context manager: run the updater for the duration of the block."""
        try:
            self.start()
            yield
        finally:
            self.stop()

    @staticmethod
    def _parse_keys(text, user=None):
        """Parse a key-list response body into ``{user: set(keys)}``.

        With ``user`` unset every line is expected to look like
        ``<user> : <key>``; lines without the `` : `` separator are skipped.
        With ``user`` given every non-empty, non-comment line is a key that
        belongs to that user.
        """
        keys = {}
        per_line_user = user is None

        for line in text.split('\n'):
            if per_line_user:
                try:
                    owner, key = line.split(' : ', 1)
                except ValueError:
                    continue
            else:
                owner = user
                # Strip first so whitespace-only lines and indented comments
                # are not stored as (empty or bogus) keys.
                key = line.strip()
                if not key or key.startswith('#'):
                    continue

            keys.setdefault(owner, set()).add(key)

        return keys

    @staticmethod
    def _extract_md5(text):
        """Return the checksum found on the last complete line, or ``None``.

        The line is only inspected when the body has at least two newlines;
        the checksum is the 32 characters following the ``MD5:`` marker.
        """
        if text.count('\n') < 2:
            return None

        last_line = text.rsplit('\n', 2)[1]
        if 'MD5:' not in last_line:
            return None

        md5_idx = last_line.index('MD5:') + 4
        return last_line[md5_idx:md5_idx + 32]

    def fetch_keys(self, uri, cache, user=None):
        """Download and parse one key list.

        :param uri: endpoint URL; ``md5`` and ``sources`` query parameters are
            appended when a cached checksum / allowed sources are known.
        :param cache: two-item list ``[md5, keys]``, updated in place when the
            response carries a checksum.
        :param user: when given, all keys in the response belong to this user;
            otherwise each line names its own user.
        :return: dict mapping user name to a set of keys.  On a non-200
            response the previously cached keys are returned when available,
            otherwise an empty dict.
        """
        query = []
        if cache[0]:
            query.append('md5=%s' % cache[0])
        if self.allowed_sources is not None:
            query.append('sources=%s' % self.allowed_sources)
        if query:
            uri = uri + '?' + '&'.join(query)

        r = requests.get(uri, timeout=self.CAUTH_REQUEST_TIMEOUT)
        if r.status_code != 200:
            self.log.error('Server response: %r - %s', r.status_code, r.reason)
            # Prefer the last known good answer: the update loop deletes every
            # stored key that is missing from the fetched result, so an empty
            # dict here would wipe all keys on a transient server error.
            # NOTE(review): with no cached answer yet this still returns {}.
            if cache[1] is not None:
                return cache[1]
            return {}

        # The server answers a bare "OK" when the md5 we sent still matches.
        if cache[0] is not None and r.text == 'OK':
            self.log.debug('  cached')
            return cache[1]

        keys = self._parse_keys(r.text, user)
        md5 = self._extract_md5(r.text)

        if md5:
            cache[:] = (md5, keys)

        return keys

    def fetch_user_keys(self):
        """Fetch the per-user key list (each line names its user)."""
        return self.fetch_keys(self.CAUTH_API_URL_USERS, self.cache['userkeys'])

    def fetch_admin_keys(self):
        """Fetch the admin key list; every key is attributed to ``root``."""
        return self.fetch_keys(self.CAUTH_API_URL_ADMINS, self.cache['adminkeys'], user='root')

    def update_allowed_sources(self):
        """Refresh ``allowed_sources`` from the local CAuth config file.

        The value is the quoted string of the first ``sources = "..."`` line,
        or ``None`` when the file is missing or has no such line.
        """
        sources = None
        cfgfile = self.CAUTH_CONFIG_PATH

        if os.path.isfile(cfgfile):
            # Context manager so the descriptor is released on every cycle.
            with open(cfgfile, 'r') as fp:
                match = self.SOURCES_LINE.search(fp.read())
            if match is not None:
                sources = match.group(1)

        self.allowed_sources = sources

    def store_keys(self, path, user, keys):
        """Write ``keys`` of ``user`` into ``<path>/<user>/``.

        Missing directories are created with mode 0755 and chowned to 0:0.
        Each key goes to ``<user>.cauth.NN.public`` (mode 0644, owner 0:0);
        a file is rewritten only when its content differs from the key.

        :return: set of paths of all key files that belong to ``keys``.
        """
        # NOTE(review): ``user`` comes from the server response and is used
        # as a path component without validation.
        npath = os.path.join(path, user)
        for directory in (path, npath):
            if not os.path.exists(directory):
                os.makedirs(directory, mode=0o755)
                _chown(directory, 0, 0)
                self.log.info('Keys directory created: %s', directory)

        good_keys = set()

        # Sorted so that the index -> key mapping is stable between cycles
        # and unchanged keys are not shuffled across files.
        for idx, key in enumerate(sorted(keys)):
            key_file = os.path.join(
                npath,
                '%s.cauth.%02d.public' % (user, idx)
            )
            # The file is opened in binary mode: compare and write bytes.
            payload = key if isinstance(key, bytes) else key.encode('utf-8')

            # 'ab+' creates the file when missing without truncating it.
            with open(key_file, 'ab+') as fp:
                # Append mode may start at end-of-file; rewind before reading.
                fp.seek(0)
                if fp.read() != payload:
                    self.log.info('Update key %s', key_file)
                    fp.seek(0)
                    fp.truncate(0)
                    fp.write(payload)
            good_keys.add(key_file)

            _chown(key_file, 0, 0)
            os.chmod(key_file, 0o644)

        return good_keys

    def cleanup_keys(self, path, exclude):
        """Remove stale ``*.cauth.*`` files from the sub-directories of ``path``.

        Files whose full path is listed in ``exclude`` are kept, as are files
        whose name does not contain ``.cauth.``.  Sub-directories left empty
        are removed.
        """
        if not os.path.exists(path):
            return

        for item in os.listdir(path):
            item_path = os.path.join(path, item)
            if not os.path.isdir(item_path):
                continue

            for item2 in os.listdir(item_path):
                item2_path = os.path.join(item_path, item2)
                # Match on the file name only: a '.cauth.' fragment in a
                # directory name must not mark unrelated files for removal.
                if '.cauth.' in item2 and item2_path not in exclude:
                    self.log.info('Removing key: %s', item2_path)
                    os.unlink(item2_path)

            if not os.listdir(item_path):
                self.log.info('Removing directory: %s', item_path)
                os.rmdir(item_path)

    def _updater(self):
        """Thread body: fetch, store and clean up keys until stopped.

        Any error inside a cycle is logged and the loop goes on; between
        cycles it sleeps for a random interval taken from ``UPDATE_EVERY``
        (the sleep ends early when ``stop()`` is called).
        """
        keys_path = os.path.realpath(os.path.join(getRoot(), 'etc', 'auth'))

        while not self.stop_ev.is_set():
            try:
                self.update_allowed_sources()

                # Merge both lists per user before storing: file names are
                # numbered per user, so storing the same user twice would
                # overwrite the files written by the first call.
                merged = {}
                for fetched in (self.fetch_user_keys(), self.fetch_admin_keys()):
                    for user, keys in fetched.items():
                        merged.setdefault(user, set()).update(keys)

                good_keys = set()

                # presumably grants the privileges needed for chown to 0:0;
                # semantics of user_privileges() are defined elsewhere.
                with user_privileges():
                    for user, keys in merged.items():
                        good_keys.update(self.store_keys(keys_path, user, keys))

                    self.cleanup_keys(keys_path, exclude=good_keys)
            except Exception as ex:
                self.log.warning('Unhandled error: %s', ex)
                self.log.warning(traceback.format_exc())
            finally:
                tout = random.randrange(*self.UPDATE_EVERY)
                self.stop_ev.wait(tout)
