import os
import sys
import time
import logging
from contextlib import contextmanager

import six
import ujson

from api.copier import Copier

from library.auth.verify import VerifyManager
from infra.skylib import http_tools
from infra.skylib.sysutils.time import sleep
from infra.skylib.intervals import IntervalController


if sys.version_info[0] >= 3:
    def b(s):
        """Coerce *s* to ``bytes`` (UTF-8 for text); ``None`` passes through.

        Bytes objects are returned unchanged; anything else is converted with
        ``str()`` first and then UTF-8 encoded.
        """
        if s is None:
            return None
        if isinstance(s, bytes):
            return s
        return str(s).encode('utf-8')
else:
    def b(s):
        """Coerce *s* to a byte string (``str`` on Python 2); ``None`` passes through.

        ``unicode`` values are UTF-8 encoded; anything else goes through ``str()``.
        """
        if s is None:
            return None
        if isinstance(s, unicode):  # noqa
            return s.encode('utf-8')
        return str(s)


class KeysStorage(object):
    """Local cache of users' public keys distributed as an rbtorrent bundle.

    The URL of the current bundle is fetched from a local endpoint
    (``fetch_rbtorrent_url``). The bundle is downloaded into
    ``<storage_dir>/data`` with ``Copier``, and ``ssh_keys.json`` from it is
    parsed into ``self.keys`` (user id -> list of parsed key objects).

    The last successfully loaded URL is written to ``rbtorrent_file`` so the
    cache can be restored at startup. That file's mtime then serves as the
    age of the cached keys. Keys older than ``KEYS_MAX_AGE`` seconds are
    considered expired, and ``get_keys`` refuses to serve them.
    """

    class KeysNotReady(Exception):
        """Raised by ``get_keys`` when no non-expired keys are loaded."""

    # we store keys for at most two days (seconds)
    KEYS_MAX_AGE = 60 * 60 * 24 * 2

    def __init__(self, storage_dir, log=None, privileges_lock=None):
        """Create the storage and try to restore a previously saved cache.

        :param storage_dir: directory holding the saved rbtorrent link and the
            downloaded ``data`` directory.
        :param log: logger to use; defaults to ``infra.skylib.keys_storage``.
        :param privileges_lock: optional context manager that is held while
            cache files are read or written and while the bundle is
            downloaded (see ``lock``); ``None`` disables it.
        """
        self.log = log or logging.getLogger('infra.skylib.keys_storage')
        self.storage_dir = storage_dir
        self.privileges_lock = privileges_lock
        self.rbtorrent_id = None
        self.mtime = None  # time.time() at which the loaded keys were fetched
        self.keys = {}
        self.keys_total = 0  # number of distinct key fingerprints in self.keys
        # retry schedule after a failed update: 60s growing x1.2 up to 600s
        self.fail_interval = IntervalController(
            initial=60.,
            multiplier=1.2,
            variance=.1,
            maximum=600.,
        )
        # schedule after a successful update: a fixed ~300s with jitter
        self.success_interval = IntervalController(
            initial=300.,
            multiplier=1.,
            variance=.2,
            maximum=300.,
        )

        self.load_keys_cache()

    @property
    def keys_expired(self):
        """True if no keys were loaded or they are older than ``KEYS_MAX_AGE``."""
        return self.mtime is None or time.time() - self.KEYS_MAX_AGE > self.mtime

    @property
    def rbtorrent_file(self):
        """Path of the file that persists the last successfully loaded URL."""
        # NOTE: the on-disk name is misspelled ("rbotorrent"). Do not change it:
        # renaming would make caches written by earlier versions invisible.
        return os.path.join(self.storage_dir, 'rbotorrent.link')

    @property
    def data_dir(self):
        """Directory into which the rbtorrent bundle is downloaded."""
        return os.path.join(self.storage_dir, 'data')

    @property
    def data_file(self):
        """Path of the JSON key dump inside the downloaded bundle."""
        # filename is hardcoded in YP builder
        return os.path.join(self.data_dir, 'ssh_keys.json')

    @contextmanager
    def lock(self):
        """Hold ``privileges_lock`` for the duration of the block, if one was given."""
        if self.privileges_lock is not None:
            with self.privileges_lock:
                yield
        else:
            yield

    def load_keys_cache(self):
        """Restore the URL, timestamp and keys persisted by a previous run.

        Any failure is logged and leaves the cache marked as expired
        (``mtime`` and ``rbtorrent_id`` reset to ``None``).
        """
        try:
            if os.path.exists(self.rbtorrent_file):
                with self.lock(), open(self.rbtorrent_file) as f:
                    stat = os.fstat(f.fileno())
                    self.mtime = stat.st_mtime
                    self.rbtorrent_id = f.read().strip()

                self.load_keys_data()
        except Exception as e:
            self.log.warning("Cache is broken, load failed: %s", e)
            self.mtime = None
            self.rbtorrent_id = None

    def load_keys_data(self):
        """(Re)load ``self.keys`` from ``data_file``.

        If the cache is expired, the keys and their counter are cleared
        instead. If the data file does not exist, the current keys are left
        untouched. ``self.keys`` is replaced only after parsing succeeds, so
        errors from ``ujson.load`` / ``load_keys`` propagate without
        clobbering the previous keys.
        """
        if self.keys_expired:
            self.keys = {}
            self.keys_total = 0
            return

        if os.path.exists(self.data_file):
            with self.lock(), open(self.data_file) as f:
                keys_data = ujson.load(f)

            keys, keys_total = self.load_keys(keys_data)
            assert isinstance(keys, dict)
            self.keys = keys
            self.keys_total = keys_total

    def load_keys(self, data):
        """Parse the decoded ``ssh_keys.json`` payload.

        :param data: mapping of user id -> iterable of key records. Each
            record is handed to ``VerifyManager.loadsKeys``; presumably they
            are authorized_keys-style lines — TODO confirm.
        :returns: ``(keys, total)``. ``keys`` maps ``b(user_id)`` to that
            user's parsed keys. ``total`` is the number of distinct key
            fingerprints. Unparsable records and ``cert-authority`` keys are
            skipped.
        """
        keys_by_fp = {}
        keys = {}

        for user_id, records in six.iteritems(data):
            user_id = b(user_id)
            for record in records:
                try:
                    # next() builtin rather than the Python 2-only ``.next()``
                    # method, so parsing also works on Python 3.
                    key = next(VerifyManager.loadsKeys(b(record)))
                except Exception as e:
                    self.log.warning("cannot parse key for user %r %s: %r", user_id, e, record)
                    continue

                if 'cert-authority' in key.options:
                    # CAs are processed in one common way, so we do not allow user to add
                    # custom CA.
                    continue
                # the same key listed for several users collapses into one
                # object whose userNames accumulates every owner
                key = keys_by_fp.setdefault(key.fingerprint(), key)
                key.userNames.add(user_id)
                keys.setdefault(user_id, []).append(key)

        return keys, len(keys_by_fp)

    def fetch_rbtorrent_url(self):
        """Ask the local endpoint for the URL of the current key bundle.

        :returns: the non-empty ``rbtorrent`` string from the response, or
            ``None`` if the response is malformed or lacks it.
        """
        data = http_tools.fetch_json("http://localhost:25536/pods/keys", "Keys", "keys rbtorrent", log=self.log)
        if not isinstance(data, dict):
            self.log.warning("ISS response is malformed")
            return None
        url = data.get('rbtorrent')
        if not url or not isinstance(url, six.string_types):
            self.log.warning('ISS response has no valid rbtorrent')
            return None

        return url

    def download_rbtorrent(self, rbtorrent):
        """Download *rbtorrent* into ``data_dir``.

        ``data_dir`` is created first if missing. If a non-directory entry
        occupies its path, that entry is removed.
        """
        data_dir = self.data_dir
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, mode=0o755)
        elif not os.path.isdir(data_dir):
            os.unlink(data_dir)
            os.makedirs(data_dir, mode=0o755)

        client = Copier()
        client.get(rbtorrent, dest=data_dir)

    def update_loop(self):
        """Refresh the keys forever; never returns.

        On success, the new keys are loaded, the URL is persisted to
        ``rbtorrent_file`` and the next attempt waits ``success_interval``.
        A missing URL or any exception waits ``fail_interval`` instead.
        """
        while True:
            try:
                url = self.fetch_rbtorrent_url()
                if not url:
                    sleep_time = self.fail_interval.schedule_next()
                    self.log.debug("no keys url available, will sleep %.2fs before next attempt", sleep_time)
                    sleep(sleep_time)
                    continue

                mtime = time.time()
                with self.lock():
                    self.download_rbtorrent(url)

                # load_keys_data() judges freshness by self.mtime, so it must be
                # set first. Roll it back if parsing fails, so the previously
                # loaded keys are not treated as freshly fetched.
                prev_mtime = self.mtime
                self.mtime = mtime
                try:
                    self.load_keys_data()
                except Exception:
                    self.mtime = prev_mtime
                    raise

                with self.lock(), open(self.rbtorrent_file, 'w') as f:
                    f.write(url)
                self.rbtorrent_id = url
            except Exception as e:
                sleep_time = self.fail_interval.schedule_next()
                self.log.exception("keys update failed, will sleep %.2fs before next attempt: %s", sleep_time, e)
                sleep(sleep_time)
            else:
                self.fail_interval.reset()
                sleep_time = self.success_interval.schedule_next()
                self.log.info("%s keys updated from url %r, will sleep %.2fs before next attempt", self.keys_total, url, sleep_time)
                sleep(sleep_time)

    def get_keys(self, username):
        """Return a copy of the key list stored for *username* (empty if unknown).

        Keys are stored under ``b(user_id)``, so presumably *username* must be
        bytes on Python 3 — TODO confirm with callers.

        :raises KeysNotReady: if the loaded keys are missing or expired.
        """
        if self.keys_expired:
            raise self.KeysNotReady()

        return list(self.keys.get(username, []))
