#!/usr/bin/env python
# -*- coding: utf-8 -*-


import StringIO as StringIO
import collections
import datetime
import errno
import fcntl
import grp
import hashlib
import logging
import os
import pwd
import re
import signal
import smtplib
import socket
import subprocess
import sys
import threading
import time
import traceback
from Queue import Queue
from contextlib import contextmanager
from email.mime.text import MIMEText

import kazoo.client as kzclient
import requests

from infra.dist.cacus.lib import constants
from infra.dist.cacus.lib.dbal import package_repository
from infra.dist.cacus.lib.dbal import ubuntu_upstream
from kazoo.exceptions import LockTimeout, KazooException
from kazoo.protocol.states import KazooState
from tornado import gen, locks

# Module-level shared state.
# NOTE(review): ctx_cache is not read in this part of the module; presumably
# used by importers -- confirm.
ctx_cache = dict()
# Application configuration. Functions below read the 'distributed_locks',
# 'notifications' and 'daemon_params' sections; presumably it is filled in
# by the entry point at startup -- confirm.
config = dict()
log = logging.getLogger(__name__)


class GlobalLock(object):
    """Lock used to serialize the script (gpg invocations).

    Takes an exclusive ``fcntl.lockf`` lock on ``path``; the lock file and
    its directory are created on construction if missing.

    NOTE(review): POSIX record locks belong to the process, so two threads
    of the same process are not serialized by this lock -- confirm callers
    only rely on cross-process exclusion.
    """

    path = "/var/run/cacus/gpg.lock"

    def __init__(self):
        dirname = os.path.dirname(self.path)
        # os.makedirs(..., exist_ok=True) is Python 3 only; emulate it so a
        # directory created concurrently by another process is not an error.
        if not os.path.isdir(dirname):
            try:
                os.makedirs(dirname)
            except OSError:
                if not os.path.isdir(dirname):
                    raise
        self.lock = open(self.path, "w")

    def __enter__(self):
        """Block until the exclusive lock is held; return this object."""
        fcntl.lockf(self.lock, fcntl.LOCK_EX)
        return self

    def __exit__(self, t, v, tb):
        """Release the lock; exceptions from the body propagate."""
        fcntl.lockf(self.lock, fcntl.LOCK_UN)


def get_hashes(file):
    """Compute MD5, SHA-1, SHA-256 and SHA-512 digests of a file object.

    The whole content is hashed from offset 0 in 4096-byte chunks. The
    file's original position is restored afterwards.

    :param file: seekable binary file-like object.
    :returns: dict mapping 'md5', 'sha1', 'sha256' and 'sha512' to the raw
        (binary) digests.
    """
    hashers = {
        'md5': hashlib.md5(),
        'sha1': hashlib.sha1(),
        'sha256': hashlib.sha256(),
        'sha512': hashlib.sha512(),
    }

    saved_position = file.tell()
    file.seek(0)
    while True:
        block = file.read(4096)
        if block == b'':
            break
        for hasher in hashers.values():
            hasher.update(block)
    file.seek(saved_position)

    return dict((name, hasher.digest()) for name, hasher in hashers.items())


class ComplexCacahe(object):
    """Mixin for caches that keep a separate timestamp per entry.

    The subclass must provide ``self.caching_time`` (seconds).
    """

    def _cache_expired(self, cached_at):
        """Return True once more than ``caching_time`` seconds passed since ``cached_at``."""
        age = datetime.datetime.now() - cached_at
        return age > datetime.timedelta(seconds=self.caching_time)


class RepoEnvsGetter(ComplexCacahe):
    """Process-wide singleton caching ``package_repository.list_envs`` per repo.

    Each repo's environment list is reused for ``caching_time`` seconds.
    """

    _instance = None
    __singleton_lock = threading.Lock()  # used for thread-safeness

    def __new__(cls, *args, **kwargs):
        with cls.__singleton_lock:
            # Look in the class's own namespace so a subclass (e.g. the async
            # getter) never receives an instance of its parent class.
            instance = cls.__dict__.get('_instance')
            if instance is None:
                # object.__new__ takes no extra arguments; __init__ gets them.
                instance = super(RepoEnvsGetter, cls).__new__(cls)
                cls._instance = instance
            return instance

    def __init__(self):
        # Python runs __init__ on every RepoEnvsGetter() call even though
        # __new__ returns the shared instance; initialise only once so the
        # cache survives between calls.
        if getattr(self, '_initialized', False):
            return
        self.repos = {}
        self.caching_time = 180  # seconds
        self._initialized = True

    @staticmethod
    def _actually_get_envs(repo):
        return package_repository.list_envs(repo)

    def get_envs(self, repo):
        """Return the environments of ``repo``, refreshing a missing or stale entry."""
        entry = self.repos.setdefault(repo, {})
        # .get(): an earlier fetch may have failed after the entry was created.
        cached_at = entry.get('cached_at')
        if cached_at is not None and not self._cache_expired(cached_at):
            return entry['envs']
        envs = self._actually_get_envs(repo)
        entry['cached_at'] = datetime.datetime.now()
        entry['envs'] = envs
        return envs


class AsyncRepoEnvsGetter(RepoEnvsGetter):
    """Tornado-coroutine variant reading ``db[repo].distinct('environment')``.

    Results are cached per repo for ``caching_time`` seconds. ``__new__`` is
    inherited from RepoEnvsGetter.
    """

    def __init__(self, db):
        self.db = db
        self.repos = {}
        self.caching_time = 180  # seconds
        self.update_lock = locks.Lock()

    def _cached_envs(self, repo):
        """Return the cached envs of ``repo``, or None if missing or expired."""
        entry = self.repos.get(repo)
        if entry and 'cached_at' in entry and not self._cache_expired(entry['cached_at']):
            return entry['envs']
        return None

    @gen.coroutine
    def get_envs(self, repo):
        envs = self._cached_envs(repo)
        if envs is not None:
            raise gen.Return(envs)

        with (yield self.update_lock.acquire()):
            # Another coroutine may have refreshed the entry while we waited
            # for the lock; only query the database if it is still stale.
            envs = self._cached_envs(repo)
            if envs is None:
                envs = yield self.db[repo].distinct('environment')
                entry = self.repos.setdefault(repo, {})
                entry['cached_at'] = datetime.datetime.now()
                entry['envs'] = envs
        raise gen.Return(envs)


class RepoEnvArchesGetter(ComplexCacahe):
    """Process-wide singleton caching architectures per (repo, env).

    Results of ``package_repository.list_architectures`` are reused for
    ``caching_time`` seconds.
    """

    _instance = None
    __singleton_lock = threading.Lock()  # used for thread-safeness

    def __new__(cls, *args, **kwargs):
        with cls.__singleton_lock:
            # Per-class lookup so subclasses get their own singleton.
            instance = cls.__dict__.get('_instance')
            if instance is None:
                # object.__new__ takes no extra arguments.
                instance = super(RepoEnvArchesGetter, cls).__new__(cls)
                cls._instance = instance
            return instance

    def __init__(self):
        # __init__ runs on every RepoEnvArchesGetter() call; keep the cache.
        if getattr(self, '_initialized', False):
            return
        self.repos = {}
        self.caching_time = 180  # seconds
        self._initialized = True

    def _actually_get_arches(self, repo, env):
        return package_repository.list_architectures(repo, env)

    def get_arches(self, repo, env, skip_source=False):
        """Return the architectures of ``repo``/``env``.

        :param skip_source: drop the 'source' pseudo-architecture from the
            returned list.
        """
        entry = self.repos.setdefault(repo, {}).setdefault(env, {})
        cached_at = entry.get('cached_at')
        if cached_at is None or self._cache_expired(cached_at):
            entry['arches'] = self._actually_get_arches(repo, env)
            entry['cached_at'] = datetime.datetime.now()
        arches = entry['arches']
        # Filter on the way out: the cache must hold the complete list so a
        # later call without skip_source still sees 'source'.
        if skip_source:
            return [x for x in arches if x != 'source']
        return arches


# mixin class
class SimpleCache(object):
    """Mixin for caches with a single timestamp.

    The subclass must provide ``self.cached_at`` (datetime) and
    ``self.caching_time`` (seconds).
    """

    def _cache_expired(self):
        """Return True once more than ``caching_time`` seconds passed since ``cached_at``."""
        age = datetime.datetime.now() - self.cached_at
        return age > datetime.timedelta(seconds=self.caching_time)


class UpstreamDistsGetter(SimpleCache):
    """Process-wide singleton caching the upstream distribution list for an hour."""

    _instance = None
    __singleton_lock = threading.Lock()  # used for thread-safeness

    def __new__(cls, *args, **kwargs):
        with cls.__singleton_lock:
            instance = cls.__dict__.get('_instance')
            if instance is None:
                # object.__new__ takes no extra arguments.
                instance = super(UpstreamDistsGetter, cls).__new__(cls)
                cls._instance = instance
            return instance

    def __init__(self):
        # __init__ runs on every UpstreamDistsGetter() call; without this
        # guard cached_at would be reset to None and nothing would be cached.
        if getattr(self, '_initialized', False):
            return
        self.dists = []
        self.caching_time = 3600  # seconds
        self.cached_at = None
        self._initialized = True

    def _actually_get_dists(self, store=None):
        # Resolve the default store at call time rather than at class
        # definition time, so a store assigned later in
        # ubuntu_upstream.default_store is honoured.
        if store is None:
            store = ubuntu_upstream.default_store
        return store.list_dists()

    def _is_fresh(self):
        return bool(self.cached_at) and not self._cache_expired()

    def get_dists(self):
        """Return the cached dists, refreshing them when stale."""
        if self._is_fresh():
            return self.dists
        with UpstreamDistsGetter.__singleton_lock:
            # Another thread may have refreshed the cache while we waited.
            if not self._is_fresh():
                self.dists = self._actually_get_dists()
                self.cached_at = datetime.datetime.now()
            return self.dists


class AsyncRepoListGetter(SimpleCache):
    """Tornado-coroutine cache of ``db.list_collection_names()`` (180 s)."""

    def __init__(self, db):
        self.db = db
        self.repos = []
        self.caching_time = 180  # seconds
        self.cached_at = None
        self.update_lock = locks.Lock()

    def _is_fresh(self):
        return bool(self.cached_at) and not self._cache_expired()

    @gen.coroutine
    def get_repos(self):
        if self._is_fresh():
            raise gen.Return(self.repos)
        with (yield self.update_lock.acquire()):
            # Coroutines queued on the lock must not all re-query the
            # database once the first one has refreshed the list.
            if not self._is_fresh():
                self.repos = yield self.db.list_collection_names()
                self.cached_at = datetime.datetime.now()
        raise gen.Return(self.repos)


class UploadHelper(object):
    """Compare an already stored package with a set of newly uploaded files.

    ``old_package`` must expose ``debs`` (list of dicts with a
    ``'storage_key'``) and ``sources`` (same shape, or a falsy value).
    ``new_files`` is a list of paths.
    """

    def __init__(self, old_package, new_files):
        self.new_files = new_files
        self.old_pkg_names = [
            os.path.basename(x['storage_key']) for x in old_package.debs]
        # A list, not a lazy map(): several methods iterate over it.
        self.new_files_names = [os.path.basename(x) for x in new_files]
        self.old_src_names = [
            os.path.basename(x['storage_key']) for x in old_package.sources
        ] if old_package.sources else []

    def is_extra_upload(self):
        """Return True if the new .deb files are the old ones modulo architecture."""
        deb_extensions = constants.deb_extensions
        old_archless_pkg_names = [
            self.remove_arch_from_pkg_name(name) for name in self.old_pkg_names]
        new_archless_pkg_names = [
            self.remove_arch_from_pkg_name(name)
            for name in self.new_files_names
            if name.endswith(deb_extensions)
        ]
        return sorted(old_archless_pkg_names) == sorted(new_archless_pkg_names)

    def generate_files_diff(self):
        """Return the paths from ``new_files`` whose basename is not already stored."""
        old_files_names = set(self.old_pkg_names) | set(self.old_src_names)
        extra_files_names = set(self.new_files_names) - old_files_names
        # Compare exact basenames: a suffix match would wrongly pick up e.g.
        # "libfoo.deb" when only "foo.deb" is new.
        return [
            path for path in self.new_files
            if os.path.basename(path) in extra_files_names
        ]

    def overridden_files(self):
        """Return the set of new basenames that already exist in the stored package."""
        return (set(self.old_pkg_names) | set(self.old_src_names)) & set(self.new_files_names)

    def remove_arch_from_pkg_name(self, pkg_name):
        """Return the part of ``pkg_name`` before the first underscore."""
        return pkg_name.split("_", 1)[0]


# MongoDB does not accept the dot symbol in document keys, so we replace all
# dots in filenames (which are used as keys) with something else.


def sanitize_filename(file):
    """Return ``file`` with every "." replaced by "___" so it can be a MongoDB key."""
    return "___".join(file.split("."))


def desanitize_filename(file):
    """Reverse sanitize_filename(): turn every "___" back into ".".

    Not a perfect inverse: a filename that already contained "___" before
    sanitizing comes back with a "." there.
    """
    return ".".join(file.split("___"))


def download_file(url, filename, timeout=None):
    """Stream ``url`` into ``filename`` with HTTP GET.

    :param url: URL to fetch.
    :param filename: destination path; written only on HTTP 200.
    :param timeout: optional ``requests`` timeout (seconds or a
        ``(connect, read)`` tuple). ``None`` waits indefinitely, as before.
    :returns: dict with ``'result'`` (an attribute of the module-level
        ``status`` object: OK, NOT_FOUND, ERROR or TIMEOUT) and ``'msg'``.
    """
    log = logging.getLogger("infra.dist.cacus.lib.common.downloader")
    # NOTE(review): ``status`` is looked up in module globals at call time
    # and is not defined in this part of the module; presumably injected
    # elsewhere -- confirm.
    status = globals()['status']
    try:
        r = requests.get(url, stream=True, timeout=timeout)
        try:
            if r.status_code != 200:
                return {
                    'result': status.NOT_FOUND,
                    # Report the real status code, not a hard-coded 404.
                    'msg': 'GET {}: {}'.format(url, r.status_code),
                }
            total_bytes = 0
            # Binary mode: the payload is raw bytes from iter_content().
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(64 * 1024):
                    total_bytes += len(chunk)
                    f.write(chunk)
        finally:
            r.close()
    except requests.Timeout as e:
        # Checked first: requests.ConnectTimeout is also a ConnectionError.
        return {'result': status.TIMEOUT, 'msg': str(e)}
    except (requests.ConnectionError, requests.HTTPError) as e:
        return {'result': status.ERROR, 'msg': str(e)}

    log.debug('GET {} {} {} bytes {} sec'.format(
        url, r.status_code, total_bytes, r.elapsed.total_seconds()))
    return {'result': status.OK, 'msg': 'OK'}


class GPGSignError(Exception):
    """Raised when gpg fails to produce a signature."""


def gpg_sign(data, signer_email, popen_class=subprocess.Popen, lock=GlobalLock):
    """Return an ASCII-armored detached gpg signature of ``data``.

    :param data: payload written to gpg's stdin.
    :param signer_email: key selector passed as ``--default-key``.
    :param popen_class: Popen-compatible factory used to start gpg.
    :param lock: factory of a context manager held for the whole gpg run.
    :returns: gpg's stdout (the armored signature).
    :raises GPGSignError: if stdout lacks "BEGIN PGP" or stderr mentions
        "error".
    """
    gpg_cmd = [
        'gpg', '--no-tty', '--armor',
        '--default-key', signer_email,
        '--detach-sign', '-o-',
    ]
    with lock():
        child = popen_class(
            gpg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        out, err = child.communicate(input=data)
        if 'BEGIN PGP' in out and 'error' not in err:
            return out
        raise GPGSignError("cannot sign:\n{}\nusing: {}\ngpg stdout:\n{}\ngpg stderr:\n{}\nretval: {}\n".format(
            data, signer_email, out, err, child.returncode))


def gpg_sign_in_place(data, signer_email, popen_class=subprocess.Popen, lock=GlobalLock):
    """Return ``data`` clear-signed by gpg (inline armored signature).

    :param data: payload written to gpg's stdin.
    :param signer_email: key selector passed as ``--default-key``.
    :param popen_class: Popen-compatible factory used to start gpg.
    :param lock: factory of a context manager held for the whole gpg run.
    :returns: gpg's stdout (the clear-signed document).
    :raises GPGSignError: if stdout lacks "BEGIN PGP" or stderr mentions
        "error".
    """
    gpg_cmd = [
        'gpg', '--no-tty', '--armor',
        '--default-key', signer_email,
        '--clearsign', '-o-',
    ]
    with lock():
        child = popen_class(
            gpg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        out, err = child.communicate(input=data)
        if 'BEGIN PGP' in out and 'error' not in err:
            return out
        raise GPGSignError("cannot sign:\n{}\nusing: {}\ngpg stdout:\n{}\ngpg stderr:\n{}\nretval: {}\n".format(
            data, signer_email, out, err, child.returncode))


class myStringIO(StringIO.StringIO, object):
    """Python 2 StringIO with a few conveniences.

    * ``with`` rewinds the buffer on entry and closes it (freeing memory)
      on exit, so the ``with`` statement is destructive.
    * ``getvalue()`` leaves the position at 0.
    * ``str()`` returns the contents, or "" when the buffer is not yet
      initialised or already closed.
    """

    def __init__(self, *args, **kwargs):
        # Name the class explicitly: super(self.__class__, ...) recurses
        # forever in subclasses. ``self`` must not be passed again, otherwise
        # it becomes the initial buffer and every caller argument shifts.
        super(myStringIO, self).__init__(*args, **kwargs)

    def __enter__(self):
        self.seek(0)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # close() on StringIO will free memory buffer,
        # so 'with' statement is destructive
        self.close()

    def getvalue(self):
        self.seek(0)
        result = super(myStringIO, self).getvalue()
        self.seek(0)
        return result

    def __str__(self):
        # StringIO.__init__ sets ``closed`` to False and close() sets it to
        # True; missing or True both mean there is no buffer to read.
        if self.__dict__.get("closed", True):
            return ""
        return self.getvalue()


class RepoLockTimeout(Exception):
    """Raised by RepoLock when its blocking acquire times out."""


class DeferedIndexerLockTimeout(Exception):
    """Raised by DeferedIndexerLock when its lock cannot be taken in time."""


class MetadataIsBusy(Exception):
    """Raised when a non-blocking repo/src lock is already held elsewhere."""


class RepoLock(object):
    """Context manager holding the ZooKeeper lock ``repo/<repo>/<env>``.

    :param timeout: seconds to wait in blocking mode before raising
        RepoLockTimeout.
    :param nonblocking: if true, raise MetadataIsBusy instead of waiting.
    """

    def __init__(self, repo, env, timeout=1800, nonblocking=False):
        self.repo = repo
        self.env = env
        self.timeout = timeout
        self.nonblocking = nonblocking
        self.log = logging.getLogger("infra.dist.cacus.lib.common.Locks")
        factory = ZKLockFactory(logger=self.log)
        self.lock = factory.get_lock('repo/{}/{}'.format(repo, env), raw=True)

    def __enter__(self):
        self.log.debug("Trying to lock repo {}/{}".format(self.repo, self.env))
        if self.nonblocking:
            if not self.lock.acquire(blocking=False):
                raise MetadataIsBusy(
                    "Lock for repo {}/{} is busy".format(self.repo, self.env))
            return
        try:
            self.lock.acquire(timeout=self.timeout)
        except LockTimeout:
            raise RepoLockTimeout(
                "Timeout while trying to lock repo {0}/{1}".format(self.repo, self.env))
        self.log.debug("{}/{} locked".format(self.repo, self.env))

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock.release()
        self.log.debug("Repo {}/{} unlocked".format(self.repo, self.env))


class SrcLock(object):
    """Non-blocking context manager for the ZooKeeper lock ``src/<repo>/<env>``.

    Raises MetadataIsBusy on entry if the lock is already held.
    """

    def __init__(self, repo, env):
        self.repo = repo
        self.env = env
        self.log = logging.getLogger("infra.dist.cacus.lib.common.Locks")
        factory = ZKLockFactory(logger=self.log)
        self.lock = factory.get_lock('src/{}/{}'.format(repo, env), raw=True)

    def __enter__(self):
        self.log.debug("Trying to lock src {}/{}".format(self.repo, self.env))
        if self.lock.acquire(blocking=False):
            return
        raise MetadataIsBusy("Lock for src {}/{} is busy".format(self.repo, self.env))

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock.release()
        self.log.debug("Src {}/{} unlocked".format(self.repo, self.env))


class DeferedIndexerLock(object):
    """Context manager for the ZooKeeper lock ``defered_indexer/<repo>``.

    Entry waits up to 30 seconds, then raises DeferedIndexerLockTimeout.
    """

    _lock_prefix = 'defered_indexer/'

    def __init__(self, repo):
        self.repo = repo
        self.log = logging.getLogger("infra.dist.cacus.lib.common.Locks")
        factory = ZKLockFactory(logger=self.log)
        self.lock = factory.get_lock(self._lock_prefix + '{}'.format(repo), raw=True)

    def __enter__(self):
        self.log.debug("Defered indexer is waiting for {} lock".format(self.repo))
        try:
            self.lock.acquire(timeout=30)
        except LockTimeout:
            self.log.debug(
                "Defered indexer timed out aquring {} lock."
                " There is another indexer already indexing this repo,"
                " or something wrong with ZK cluster".format(self.repo))
            raise DeferedIndexerLockTimeout
        self.log.debug("Defered indexer aquired {} lock".format(self.repo))

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock.release()
        self.log.debug("Defered indexer has been released {} lock".format(self.repo))

    def contenders(self):
        """Return the contenders reported by the underlying kazoo lock."""
        return self.lock.contenders()

    @staticmethod
    def extract_clean_lock_id(small_lock_id):
        """Strip every occurrence of the ``defered_indexer/`` prefix from ``small_lock_id``."""
        return small_lock_id.replace(DeferedIndexerLock._lock_prefix, '')


class ZKLockFactory(object):
    """Process-wide singleton handing out ZooKeeper (kazoo) locks.

    The first instantiation builds and starts a KazooClient from
    ``config['distributed_locks']`` (``zk_nodes``, ``zk_port``,
    ``supress_zk_log``); later instantiations reuse that connection.
    """

    _cls_instance = None
    __singleton_lock = threading.Lock()  # used for thread-safeness

    def __new__(cls, logger=None):
        with cls.__singleton_lock:
            if not cls._cls_instance:
                zk_config = config['distributed_locks']
                port = str(zk_config['zk_port'])
                hosts = ','.join(
                    '{}:{}'.format(host, port) for host in zk_config['zk_nodes'])

                if zk_config['supress_zk_log']:
                    zk = kzclient.KazooClient(hosts=hosts)
                else:
                    zk = kzclient.KazooClient(hosts=hosts, logger=logger)
                zk.start()
                # Publish the singleton only once the client has started, so
                # a failure above does not leave a cached instance that has
                # no ``_connection``.
                instance = super(ZKLockFactory, cls).__new__(cls)
                instance._connection = zk
                cls._cls_instance = instance
            elif logger and cls._cls_instance._connection.logger != logger:
                log.debug('Logger for ZooKeeper has been already set!'
                          ' New logger will be ignored.')
            return cls._cls_instance

    def __init__(self, logger=None):
        # NOTE(review): __init__ runs on every ZKLockFactory() call even
        # though __new__ returns the shared instance, so this per-id cache
        # is reset each time. Kept as-is because caching across calls would
        # share one kazoo Lock object between callers -- confirm intent.
        self._instances = {}

    def ensure_zk(self, max_attempts=5):
        """Make sure the ZooKeeper client is connected, restarting it if needed.

        Retries up to ``max_attempts`` times with exponential back-off
        (1, 2, 4, ... seconds) and terminates the process if the client is
        still not connected afterwards.
        """
        for attempt in range(max_attempts):
            if self._connection.state == KazooState.CONNECTED:
                return
            try:
                self._connection.restart()
            except KazooException as error:
                log.critical('cannot ensure zk: %s', error)
                log.critical(traceback.format_exc())
            time.sleep(2 ** attempt)
        # Check once more: the last restart may have succeeded.
        if self._connection.state != KazooState.CONNECTED:
            log.critical('cannot restart KazooClient instance. giving up.')
            sys.exit(-1)

    @contextmanager
    def lock_cm(self, lock, timeout=20):
        """Hold ``lock`` for the duration of the ``with`` body."""
        self.ensure_zk()
        lock.acquire(timeout=timeout)
        try:
            yield
        finally:
            lock.release()

    def get_lock(self, lock_id, timeout=20, raw=False):
        """Return the kazoo Lock for ``lock_id``.

        The lock lives under ``cacus/<zk_prefix>/<lock_id>``. With ``raw``
        the Lock object itself is returned; otherwise a context manager
        that acquires it with ``timeout``.
        """
        self.ensure_zk()
        if lock_id not in self._instances:
            lock = self._connection.Lock(
                'cacus/{}/{}'.format(
                    config['distributed_locks']['zk_prefix'],
                    lock_id
                ),
                self.format_full_lock_id(lock_id)
            )
            self._instances[lock_id] = lock

        if raw:
            return self._instances[lock_id]

        return self.lock_cm(
            self._instances[lock_id],
            timeout=timeout
        )

    @staticmethod
    def format_full_lock_id(small_lock_id):
        """Build the contender identifier: prefix, lock id, host name and pid."""
        return '{}_cacus_{}_lock_id:{}_{}'.format(
            config['distributed_locks']['zk_prefix'],
            small_lock_id,
            socket.gethostname(),
            os.getpid()
        )

    @staticmethod
    def extract_small_lock_id(full_lock_id):
        """Inverse of format_full_lock_id(): return the embedded lock id."""
        # Escape the prefix: it is literal text, not a regular expression.
        pattern = re.escape(config['distributed_locks']['zk_prefix']) + '_cacus_(.*)_lock_id'
        match = re.match(pattern, full_lock_id)
        return match.groups()[0]


class TimeoutError(Exception):
    """Raised by the ``timeout`` context manager when its SIGALRM fires.

    Note: inside this module it shadows Python 3's builtin TimeoutError.
    """


class timeout(object):
    """Context manager raising TimeoutError if its body runs too long.

    Uses SIGALRM, so it only works in the main thread and has whole-second
    resolution (``signal.alarm`` takes an int).

    :param seconds: deadline in seconds; 0 disables the alarm.
    :param error_message: message of the raised TimeoutError.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # Remember the handler we replace so __exit__ can put it back.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        signal.alarm(0)
        previous = self._previous_handler
        # signal.signal() returns None for handlers not installed from
        # Python; fall back to the default action in that case.
        if previous is None:
            previous = signal.SIG_DFL
        signal.signal(signal.SIGALRM, previous)
        return False


try:
    from collections.abc import MutableSet as _MutableSet
except ImportError:  # Python 2
    from collections import MutableSet as _MutableSet


class OrderedSet(_MutableSet):
    """Set that remembers insertion order.

    Backed by a dict for O(1) membership plus a circular doubly linked list
    (``[key, prev, next]`` nodes around a sentinel) for ordering.
    """

    def __init__(self, iterable=None):
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append ``key`` at the end unless it is already present."""
        if key not in self.map:
            end = self.end
            curr = end[1]
            curr[2] = end[1] = self.map[key] = [key, curr, end]

    def discard(self, key):
        """Remove ``key`` if present; do nothing otherwise."""
        if key in self.map:
            key, prev_node, next_node = self.map.pop(key)
            prev_node[2] = next_node
            next_node[1] = prev_node

    def __iter__(self):
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def pop(self, last=True):
        """Remove and return the newest item (``last=True``) or the oldest.

        :raises KeyError: if the set is empty.
        """
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """Order-sensitive against OrderedSet, plain set equality otherwise."""
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        try:
            return set(self) == set(other)
        except TypeError:
            # ``other`` is not an iterable of hashables; let Python fall back
            # instead of raising from ==.
            return NotImplemented


class OrderedSetQueue(Queue):
    """Queue whose pending items are de-duplicated via an OrderedSet.

    Putting an item that is already pending does not add a second copy.
    _get() pops from the end of the set (OrderedSet.pop defaults to
    last=True), i.e. the most recently added pending item comes out first.
    NOTE(review): confirm LIFO order is intended. Also Queue.put() counts
    every call toward unfinished_tasks, even for duplicates, so
    task_done()/join() accounting may drift -- verify.
    """

    def _init(self, maxsize):
        self.queue = OrderedSet()

    def _put(self, item):
        self.queue.add(item)

    def _get(self):
        return self.queue.pop(last=True)


class ReadWriteLock:
    """Lock allowing many simultaneous readers or a single writer.

    A writer holds the underlying condition lock from acquire_write() until
    release_write(), which blocks new readers. There is no writer priority:
    while a writer waits for readers to drain, new readers can still enter.
    """

    def __init__(self):
        self._read_ready = threading.Condition(threading.Lock())
        self._readers = 0

    def acquire_read(self):
        """Register a reader; blocks only while a writer holds the lock."""
        with self._read_ready:
            self._readers += 1

    def release_read(self):
        """Unregister a reader and wake waiting writers when none remain."""
        with self._read_ready:
            self._readers -= 1
            if self._readers == 0:
                self._read_ready.notify_all()

    def acquire_write(self):
        """Take exclusive access; blocks until no reader or writer is active."""
        self._read_ready.acquire()
        while self._readers > 0:
            self._read_ready.wait()

    def release_write(self):
        """Give up exclusive access taken by acquire_write()."""
        self._read_ready.release()


def send_mail(sender, to, subject, body):
    """Send a plain-text notification through ``config['notifications']['smtp_host']``.

    Only recipients ending in ``@yandex-team.ru`` are accepted. ``to`` may be
    a bare address or ``"Name <address>"``. Refused recipients and SMTP
    errors are logged as warnings, not raised.

    :param sender: envelope sender and ``From`` header.
    :param to: recipient, used for the ``To`` header and the envelope.
    :param subject: message subject.
    :param body: plain-text body.
    """
    if '<' in to and '>' in to:
        match = re.match('.*<(?P<email>.*)>.*', to)
        if match is None:
            # Without a parsed address the domain check cannot run, so do
            # not send anything.
            log.warning('cannot parse notification destination: {}'.format(to))
            return
        dst_email = match.group('email')
    else:
        dst_email = to
    if not dst_email.endswith('@yandex-team.ru'):
        log.warning('Sending mail to external domains restricted. Refusing to send notification to: {}'.format(to))
        return

    message = MIMEText(body)
    message['Subject'] = subject
    message['From'] = sender
    message['To'] = to
    try:
        s = smtplib.SMTP(config['notifications']['smtp_host'])
        try:
            s.sendmail(sender, [to], message.as_string())
            s.quit()
        finally:
            # Release the socket even when sendmail() fails; close() is a
            # no-op after a successful quit().
            s.close()
    except smtplib.SMTPException as e:
        msg = 'Error ocured, during sending e-mail, error: {}'.format(e)
        log.warning(msg)
        return
    log.info('Sent notification to: {}'.format(to))


def sigusr1_handler(signum, frame):
    """SIGUSR1 handler: close every logging FileHandler (for log rotation).

    Visits the root logger plus every logger registered with the logging
    manager and passes each attached handler to close_log_file().
    """
    registered_names = list(logging.root.manager.loggerDict)
    all_loggers = [logging.getLogger()] + [
        logging.getLogger(name) for name in registered_names]
    log.info('SIGUSR1 received, reopening logs...')
    for each_logger in all_loggers:
        for each_handler in each_logger.handlers:
            close_log_file(each_handler)
    log.info('SIGUSR1 received, logs reopened')


def close_log_file(handler):
    """Close the file behind a logging.FileHandler so it gets reopened lazily.

    FileHandler.emit() opens the file again when ``stream`` is None (for
    append-mode handlers at least), so the next record goes to a fresh file,
    e.g. after logrotate moved the old one. Other handler types are left
    untouched.
    """
    if not isinstance(handler, logging.FileHandler):
        return
    # Hold the handler's (reentrant) lock across close + reset so a
    # concurrent emit() cannot reopen the file in between and have its new
    # stream dropped, unclosed, by the reset below.
    handler.acquire()
    try:
        handler.close()
        handler.stream = None
    finally:
        handler.release()


def sigterm_handler(signum, frame):
    """SIGTERM handler: close stdio, send SIGTERM to our process group, exit 0.

    NOTE(review): SIGTERM is assumed to be the signal meant for the group
    (the original call passed none) -- confirm.
    """
    log.info('SIGTERM received. Cleaning up...')
    sys.stderr.close()
    sys.stdout.close()
    # killpg() below also delivers SIGTERM to this process; ignore it first
    # so this handler is not re-entered by the signal we send ourselves.
    signal.signal(signal.SIGTERM, signal.SIG_IGN)
    try:
        # os.killpg() needs an explicit signal number; without it the call
        # raised TypeError and the group was never signalled.
        os.killpg(os.getpgid(0), signal.SIGTERM)
    except OSError:
        # stdout/stderr are already closed, so there is nowhere safe to
        # report this, and we are exiting anyway.
        pass
    sys.exit(0)


def sigint_handler(signum, frame):
    """SIGINT handler: log and exit the process with status 13."""
    log.info('SIGINT received. Cleaning up...')
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive shell and is unavailable under ``python -S``.
    sys.exit(13)


def sigchld_handler(signum, frame):
    """SIGCHLD handler: reap every exited child without blocking.

    SIGCHLD deliveries coalesce, so one call may stand for several exited
    children; keep reaping until none is left.
    NOTE(review): this also reaps children started via subprocess, whose
    Popen.wait() then cannot see the real exit status.
    """
    while True:
        try:
            pid, _ = os.waitpid(-1, os.WNOHANG)
        except OSError as error:
            # ECHILD simply means there are no children left to reap.
            if error.errno != errno.ECHILD:
                log.error('error in sigchld_handler: %s', error)
            return
        if pid == 0:
            # Children exist, but none of them has exited yet.
            return


def setup_handlers():
    """Install the daemon's SIGUSR1, SIGTERM and SIGCHLD handlers.

    For SIGUSR1 and SIGTERM, siginterrupt() is set from
    ``config['daemon_params']['siginterrupt_flag']``. It is not adjusted for
    SIGCHLD.
    """
    log.info('setting up signal handlers')
    for signum, handler in ((signal.SIGUSR1, sigusr1_handler),
                            (signal.SIGTERM, sigterm_handler)):
        signal.signal(signum, handler)
        signal.siginterrupt(signum, config['daemon_params']['siginterrupt_flag'])
    signal.signal(signal.SIGCHLD, sigchld_handler)


def daemonize(user, group, func, func_args=()):
    """Drop privileges, install signal handlers, run ``func`` and exit.

    No fork happens: ``func(*func_args)`` runs in the current process, which
    exits with status 0 when it returns. Privileges are dropped only when
    running as root.

    :param user: account name to switch to.
    :param group: group name to switch to.
    :param func: callable implementing the daemon's work.
    :param func_args: positional arguments for ``func``.
    """
    log.info('Trying to daemonize action: {}'.format(action_name))
    if os.getuid() != 0:
        log.warning('launched as non-root user, user and group settings omitted')
    else:
        log.info('dropping privileges')
        uid = pwd.getpwnam(user).pw_uid
        gid = grp.getgrnam(group).gr_gid
        # Groups must be changed while still root: after setuid() the
        # process may no longer call setgroups()/setgid() (EPERM).
        # Replace root's supplementary groups with the target user's, as
        # initgroups(3) would (os.initgroups is Python 3 only).
        supplementary = set(g.gr_gid for g in grp.getgrall() if user in g.gr_mem)
        supplementary.add(gid)
        os.setgroups(sorted(supplementary))
        os.setgid(gid)
        os.setuid(uid)
    setup_handlers()
    log.info('done setup_handlers')
    func(*func_args)
    sys.exit(0)


def get_daemon_params(config, daemon=None):
    """Return ``(daemonize, user, group)`` from ``config['daemon_params']``.

    ``daemon`` is accepted but not used.
    """
    params = config['daemon_params']
    return params['daemonize'], params['user'], params['group']


# Name of the action being run; daemonize() reads it for its log message.
# NOTE(review): presumably assigned by the entry point before daemonize() is
# called -- confirm.
action_name = None
