#!/skynet/python/bin/python

#
# skynet apt method implementation inspired by apt-boto-s3 which may be found at
# https://github.com/lucidsoftware/apt-boto-s3
#

import re
import os
import os.path
import shutil
import stat
import requests
import sys
from api.copier import Copier
import signal
import hashlib
import threading
from collections import namedtuple
from tempfile import mkdtemp
from urlparse import ParseResult, urlparse

# When True, acquisition details and the outgoing URI Done fields are
# appended to /log.txt.
DEBUG = False

# Destination directory handed to the sky copier for package downloads.
SKY_ARCHIVES_DIR = '/var/cache/apt/archives/sky'


class Interrupt():
    """Lock-protected one-shot flag used to stop the message loop.

    ``interrupted`` starts False and becomes True on the first interrupt();
    the object is truthy (Python 2 ``__nonzero__``) once it is set.
    """

    def __init__(self):
        self.interrupted = False
        self.lock = threading.Lock()

    def __nonzero__(self):
        return self.interrupted

    def interrupt(self):
        """Set the flag; return True only for the call that flipped it."""
        with self.lock:
            was_clear = not self.interrupted
            # Setting an already-set flag is harmless, so assign unconditionally.
            self.interrupted = True
        return was_clear


class MessageHeader(namedtuple("MessageHeader_", ["status_code", "status_info"])):
    """First line of an apt method message: an integer code and its text."""

    def __str__(self):
        return "{} {}".format(self.status_code, self.status_info)

    @staticmethod
    def parse(line):
        """Split ``line`` at its first space into (int code, description).

        Raises ValueError when there is no space or the code is not an int.
        """
        code_text, info = line.split(' ', 1)
        return MessageHeader(status_code=int(code_text), status_info=info)


class MessageHeaders:
    """Well-known apt method protocol headers, grouped by code range."""

    # 1xx: method announcements and progress notes.
    CAPABILITIES = MessageHeader(100, 'Capabilities')
    STATUS = MessageHeader(102, 'Status')

    # 2xx: per-URI progress reports.
    URI_START = MessageHeader(200, 'URI Start')
    URI_DONE = MessageHeader(201, 'URI Done')

    # 4xx: failures (one URI, or the whole method).
    URI_FAILURE = MessageHeader(400, 'URI Failure')
    GENERAL_FAILURE = MessageHeader(401, 'General Failure')

    # 6xx: requests and configuration read from the input pipe.
    URI_ACQUIRE = MessageHeader(600, 'URI Acquire')
    CONFIGURATION = MessageHeader(601, 'Configuration')


class Message(namedtuple("Message_", ["header", "fields"])):
    """An apt method message: a MessageHeader plus ordered (name, value) fields."""

    @staticmethod
    def parse_lines(lines):
        """Build a Message from its raw lines (without the terminating blank line).

        lines[0] is parsed with MessageHeader.parse(). Each following line is
        split at its first colon, dropping spaces that follow the colon. A line
        with no colon becomes a field with an empty value, so that one malformed
        line cannot make every get_field() lookup on the message fail.
        """
        header = MessageHeader.parse(lines[0])
        fields = []
        for line in lines[1:]:
            # maxsplit as a keyword: positional use is deprecated in Python 3.13.
            parts = re.split(": *", line, maxsplit=1)
            if len(parts) < 2:
                parts.append("")
            fields.append(tuple(parts))
        return Message(header, tuple(fields))

    def get_field(self, field_name):
        """Return the first value of ``field_name`` (case-insensitive), or None."""
        return next(self.get_fields(field_name), None)

    def get_fields(self, field_name):
        """Lazily yield every value whose name matches ``field_name``, ignoring case."""
        return (value for name, value in self.fields if name.lower() == field_name.lower())

    def __str__(self):
        """Serialize as the header line, one "Name: value" line per field, then a blank line."""
        lines = [str(self.header)]
        lines.extend("{}: {}".format(name, value) for name, value in self.fields)
        # Joining with this trailing "\n" element ends the text with "\n\n",
        # i.e. the blank line that terminates a message.
        lines.append("\n")
        return "\n".join(lines)


# Stream pair for an AptMethod: messages are read from ``input`` and replies
# are written (then flushed) to ``output``.
Pipes = namedtuple("Pipes", "input output")


class AptMethod(namedtuple("AptMethod_", ["pipes"])):
    """Base class for an apt acquire method speaking over ``pipes``.

    Messages are read from ``pipes.input``; each complete message is handed to
    handle_message() on its own thread, and replies go to ``pipes.output``.
    Subclasses must implement send_capabilities() and handle_message().
    """

    # Handler threads reply concurrently; without serialising write+flush,
    # two messages could interleave on the output pipe and corrupt the stream.
    _send_lock = threading.Lock()

    def send(self, message):
        """Write ``str(message)`` to the output pipe and flush, as one unit."""
        with self._send_lock:
            self.pipes.output.write(str(message))
            self.pipes.output.flush()

    def _send_error(self, message):
        """Send a "401 General Failure" whose Message field is ``message``."""
        self.send(Message(MessageHeaders.GENERAL_FAILURE, (("Message", message),)))

    def send_capabilities(self):
        raise NotImplementedError

    def handle_message(self, message):
        raise NotImplementedError

    def run(self):
        """Announce capabilities, then dispatch incoming messages until EOF.

        Non-empty lines are accumulated until a blank line ends a message,
        which is then handled on a new thread. At EOF every handler thread is
        joined before returning. If a handler raises, only the first such
        exception is reported via _send_error(), and the read loop stops at
        its next check.
        """
        self.send_capabilities()

        tasks = []
        interrupt = Interrupt()

        def handle_in_thread(message):
            try:
                self.handle_message(message)
            except Exception as ex:
                # interrupt() is True only once, so later failures stay silent.
                if interrupt.interrupt():
                    self._send_error(ex)

        lines = []
        while not interrupt.interrupted:
            line = self.pipes.input.readline()
            if not line:
                for task in tasks:
                    task.join()
                break
            line = line.rstrip("\n")
            if line:
                lines.append(line)
            elif lines:
                message = Message.parse_lines(lines)
                lines = []
                task = threading.Thread(target=handle_in_thread, args=(message,))
                tasks.append(task)
                task.start()


class AptMethodSky(AptMethod):
    """apt method that serves ``sky://`` URIs over HTTP.

    Package files (.deb/.udeb/.ddeb) are requested with ``?rbtorrent=1``. A
    204 reply is resolved through the sky Copier using its X-Resource-Id
    header, and a 200 reply's body is written straight to the target file.
    """
    deb_extensions = ('.deb', '.udeb', '.ddeb')

    def __init__(self, *args, **kwargs):
        AptMethod.__init__(self, *args, **kwargs)
        self.copier = Copier()

    def send_capabilities(self):
        self.send(Message(MessageHeaders.CAPABILITIES, (
            ('Send-Config', 'true'),
            ('Pipeline', 'true'),
            ('Single-Instance', 'yes')
        )))

    def handle_message(self, message):
        """Serve a "600 URI Acquire" message; all other messages are ignored.

        On success a "201 URI Done" with hashes and size is sent. A
        "400 URI Failure" is sent when the URI is not sky://, when the server
        replies with a status other than 200/204, or when any step raises.
        """
        if message.header.status_code != MessageHeaders.URI_ACQUIRE.status_code:
            return
        uri = message.get_field("URI")
        if not uri.startswith("sky://"):
            # Only sky:// is served; report that instead of failing obscurely.
            self.send(Message(MessageHeaders.URI_FAILURE, (
                ('URI', uri),
                ('Message', 'Unsupported URI scheme: {}'.format(uri)),
            )))
            return
        http_uri = self._http_uri(uri)
        self.send(Message(MessageHeaders.STATUS, (
            ('Message', 'Will fetch:'),
            ('URI', uri)
        )))
        filename = message.get_field("Filename")
        try:
            response = requests.get(http_uri)
            if response.status_code == 200:
                self._store_response(uri, filename, response)
            elif response.status_code == 204:
                self._acquire_via_copier(uri, filename, response)
            else:
                self.send(Message(MessageHeaders.URI_FAILURE, (
                    ('URI', uri),
                    ('Message', "%d  %s" % (response.status_code, response.reason)),
                    ('FailReason', "HttpError%d" % response.status_code)
                )))
        except Exception as e:
            # str(e) instead of the deprecated e.message, which is empty for
            # exceptions constructed with several args (IOError, OSError, ...).
            self.send(Message(MessageHeaders.URI_FAILURE, (
                ('Message', str(e)),
                ('URI', uri)
            )))

    def _http_uri(self, uri):
        """Translate a sky:// URI into the http:// URI to request.

        For package files the first path segment is dropped and
        ``?rbtorrent=1`` appended; other URIs only change scheme.
        """
        is_package = uri.endswith(self.deb_extensions)
        if is_package:
            parsed = urlparse(uri)
            first_slash_pos = parsed.path[1:].find('/') + 1
            uri = parsed._replace(path=parsed.path[first_slash_pos:]).geturl()
        # "sky://..."[3:] is "://...", so prefixing "http" swaps the scheme.
        return "http{}{}".format(uri[3:], "?rbtorrent=1" if is_package else "")

    def _store_response(self, uri, filename, response):
        """Write a 200 reply's body to ``filename`` and send URI Done."""
        content = response.content
        digests = self._digest_fields([content])
        # Binary mode: the body is raw bytes, not text.
        with open(filename, "wb") as output_file:
            output_file.write(content)
        self.send(Message(MessageHeaders.URI_DONE,
                          (('Filename', filename), ('URI', uri)) + digests))

    def _acquire_via_copier(self, uri, filename, response):
        """Resolve a 204 reply through the sky Copier and send URI Done.

        The resource named by the X-Resource-Id header is fetched with
        SKY_ARCHIVES_DIR as destination, ``filename`` is made a symlink to
        the reported file, and the linked file is hashed.
        """
        self.send(Message(MessageHeaders.STATUS, (
            ('Message', 'Acquiring rbtorrent resource id'),
            ('URI', uri)
        )))
        resid = response.headers["X-Resource-Id"]
        self.send(Message(MessageHeaders.STATUS, (
            ('Message', 'Acquiring file by sky copier'),
            ('URI', uri)
        )))
        handle = self.copier.handle(resid)
        self._ensure_archives_dir()
        result = handle.get(SKY_ARCHIVES_DIR).wait()
        # NOTE(review): the reported path is used verbatim as the symlink
        # target, so it presumably is absolute — confirm against Copier.
        acquired_file = result.files()[0]['path']
        # A file left over from an earlier attempt would make symlink() fail.
        if os.path.lexists(filename):
            os.remove(filename)
        os.symlink(acquired_file, filename)
        self.send(Message(MessageHeaders.STATUS, (
            ('Message', 'File acquired'),
            ('URI', uri)
        )))
        if DEBUG:
            with open('/log.txt', 'a') as f:
                f.write('acquired file: {}\n'.format(acquired_file))
                f.write('requested filename: {}\n'.format(filename))

        with open(filename, 'rb') as f:
            digests = self._digest_fields(iter(lambda: f.read(4096), b''))
        fields = (('Filename', filename), ('URI', uri)) + digests
        if DEBUG:
            with open('/log.txt', 'a') as f:
                f.write('apt message: {}\n'.format(str(fields)))
        self.send(Message(MessageHeaders.URI_DONE, fields))

    @staticmethod
    def _ensure_archives_dir():
        """Create SKY_ARCHIVES_DIR if missing, tolerating a concurrent creator."""
        try:
            os.mkdir(SKY_ARCHIVES_DIR, 0o775)
        except OSError:
            # Handlers run on parallel threads; another may have created it.
            if not os.path.isdir(SKY_ARCHIVES_DIR):
                raise

    @staticmethod
    def _digest_fields(chunks):
        """Hash the byte ``chunks``; return the apt hash fields followed by Size."""
        md5 = hashlib.md5()
        sha1 = hashlib.sha1()
        sha256 = hashlib.sha256()
        sha512 = hashlib.sha512()
        total = 0
        for chunk in chunks:
            total += len(chunk)
            for digest in (md5, sha1, sha256, sha512):
                digest.update(chunk)
        return (
            ('MD5-Hash', md5.hexdigest()),
            ('MD5Sum-Hash', md5.hexdigest()),
            ('SHA1-Hash', sha1.hexdigest()),
            ('SHA256-Hash', sha256.hexdigest()),
            ('SHA512-Hash', sha512.hexdigest()),
            ('Size', total),
        )


if __name__ == '__main__':
    def _on_sigint(signum, frame):
        """Exit with status 0 when SIGINT is received."""
        sys.exit(0)

    signal.signal(signal.SIGINT, _on_sigint)
    # The method reads its messages from stdin and replies on stdout.
    AptMethodSky(Pipes(input=sys.stdin, output=sys.stdout)).run()
