#!/usr/bin/env python
# -*- coding: utf-8 -*-

# S3Auth is adapted from a BSD-licensed library by Paul Tax:
# https://github.com/tax/python-requests-aws
# Unlike upstream, this copy supports Python 2 only.

from __future__ import print_function

import argparse
import functools
import hmac
import json
import sys
import xml.etree.ElementTree as ET
from base64 import encodestring
from email.utils import formatdate
from hashlib import sha1 as sha
from urlparse import urlparse, unquote

import requests
import time
from requests.auth import AuthBase


class S3Auth(AuthBase):
    """Attach an AWS (signature version 2) ``Authorization`` header to a request.

    Instances are meant to be passed as ``auth=`` to ``requests`` calls. The
    signature is an HMAC-SHA1 over a canonical string built from the HTTP
    method, a few selected headers, the bucket/object path and a whitelist of
    sub-resource query parameters (``special_params``).
    """

    # Query-string arguments that are part of the signed resource; any other
    # query argument is left out of the canonical string.
    special_params = [
        'acl', 'location', 'logging', 'partNumber', 'policy', 'requestPayment',
        'torrent', 'versioning', 'versionId', 'versions', 'website', 'uploads',
        'uploadId', 'response-content-type', 'response-content-language',
        'response-expires', 'response-cache-control', 'delete', 'lifecycle',
        'response-content-disposition', 'response-content-encoding', 'tagging',
        'notification', 'cors'
    ]

    def __init__(self, access_key, secret_key, service_url):
        """
        :param access_key: access key id; sent in clear in the header.
        :param secret_key: secret used as the HMAC key.
        :param service_url: service host name (e.g. ``s3.mds.yandex.net``).
            Whatever precedes it in the request host, minus the joining dot,
            is treated as the bucket name.
        """
        self.access_key = str(access_key)
        self.secret_key = str(secret_key)
        self.service_base_url = service_url

    def __call__(self, r):
        """Sign request *r* in place and return it (``requests`` auth hook)."""
        # The signature covers the date, so make sure one is present.
        if 'date' not in r.headers and 'x-amz-date' not in r.headers:
            r.headers['date'] = formatdate(
                timeval=None,
                localtime=False,
                usegmt=True)
        signature = self.get_signature(r)
        r.headers['Authorization'] = 'AWS %s:%s' % (self.access_key, signature)
        return r

    def get_signature(self, r):
        """Return the base64-encoded HMAC-SHA1 of *r*'s canonical string."""
        canonical_string = self.get_canonical_string(
            r.url, r.headers, r.method)
        key = self.secret_key
        msg = canonical_string
        h = hmac.new(key, msg, digestmod=sha)
        return encodestring(h.digest()).strip()

    def get_canonical_string(self, url, headers, method):
        """Build the string-to-sign for a request.

        Layout: method, then content-md5 / content-type / date values (empty
        when absent), then ``x-amz-*`` headers as ``name:value`` lines, all in
        sorted key order, followed by ``/bucket/key`` and any whitelisted
        query arguments.
        """
        parsedurl = urlparse(url)
        objectkey = parsedurl.path[1:]
        query_args = sorted(parsedurl.query.split('&'))

        bucket = parsedurl.netloc[:-len(self.service_base_url)]
        if len(bucket) > 1:
            # remove the dot that joins bucket and service host
            bucket = bucket[:-1]

        interesting_headers = {
            'content-md5': '',
            'content-type': '',
            'date': ''}
        for key in headers:
            lk = key.lower()
            try:
                lk = lk.decode('utf-8')
            except (AttributeError, UnicodeError):
                # Best effort: already-text names (no .decode) or names that
                # are not valid UTF-8 are used as they are.
                pass
            if headers[key] and (lk in interesting_headers
                                 or lk.startswith('x-amz-')):
                interesting_headers[lk] = headers[key].strip()

        # If x-amz-date is used it supersedes the date header.
        if 'x-amz-date' in interesting_headers:
            interesting_headers['date'] = ''

        buf = '%s\n' % method
        for key in sorted(interesting_headers.keys()):
            val = interesting_headers[key]
            if key.startswith('x-amz-'):
                buf += '%s:%s\n' % (key, val)
            else:
                buf += '%s\n' % val

        # append the bucket if it exists
        if bucket != '':
            buf += '/%s' % bucket

        # add the objectkey. even if it doesn't exist, add the slash
        buf += '/%s' % objectkey

        params_found = False

        # handle special query string arguments
        for q in query_args:
            k = q.split('=')[0]
            if k in self.special_params:
                buf += '&' if params_found else '?'
                params_found = True

                try:
                    k, v = q.split('=', 1)

                except ValueError:
                    buf += q

                else:
                    # Riak CS multipart upload ids look like this, `TFDSheOgTxC2Tsh1qVK73A==`,
                    # and have to be escaped to be included in a query string.
                    #
                    # A requests mp upload part request may look like
                    # resp = requests.put(
                    #     'https://url_here',
                    #     params={
                    #         'partNumber': 1,
                    #         'uploadId': 'TFDSheOgTxC2Tsh1qVK73A=='
                    #     },
                    #     data='some data',
                    #     auth=S3Auth('access_key', 'secret_key')
                    # )
                    #
                    # Requests automatically escapes the values in the `params` dict, so now
                    # our uploadId is `TFDSheOgTxC2Tsh1qVK73A%3D%3D`,
                    # if we sign the request with the encoded value the signature will
                    # not be valid, we'll get 403 Access Denied.
                    # So we unquote, this is no-op if the value isn't encoded.
                    buf += '{key}={value}'.format(key=k, value=unquote(v))

        return buf


class S3Client(object):
    """Thin path-style client for the S3-MDS service.

    Requests go to ``https://<base_url>/<bucket>/<name>`` and are signed with
    an :class:`S3Auth` built from the given credentials.
    """

    PROD_URL = 's3.mds.yandex.net'
    TEST_URL = 's3.mdst.yandex.net'

    def __init__(self, access_key, secret_key, base_url=TEST_URL):
        self.base_url = base_url
        self.auth = S3Auth(access_key, secret_key, base_url)

    @staticmethod
    def base_request(method, url, **kwargs):
        """Send an HTTP request, retrying 5xx responses with exponential backoff.

        Up to 8 retries are made, sleeping 1, 2, 4, ... seconds between them.
        Any final status other than 200/204 is passed to ``raise_for_status()``,
        which raises ``requests.HTTPError`` for 4xx/5xx codes.

        :param method: HTTP method name.
        :param url: absolute URL.
        :param kwargs: forwarded to ``requests.request``.
        :return: the final ``requests.Response``.
        """
        retries = 8
        backoff = 2
        delay = 1  # seconds to sleep before the next retry

        while True:
            r = requests.request(method, url, **kwargs)
            if not (500 <= r.status_code <= 599 and retries > 0):
                break
            retries -= 1
            time.sleep(delay)
            delay *= backoff

        if r.status_code not in (200, 204):
            r.raise_for_status()

        return r

    def request(self, method, bucket=None, name=None, data=None):
        """Send a signed request for ``bucket/name``; either part may be omitted."""
        url = "https://" + '/'.join(filter(None, [self.base_url, bucket, name]))
        kwargs = {}
        if self.base_url == self.TEST_URL:
            # TLS certificate verification is disabled for the testing host.
            kwargs["verify"] = False

        return self.base_request(method, url, data=data, auth=self.auth, **kwargs)

    def _fetch_list(self, element, path):
        """Collect the nodes reached from *element* by following the tag list *path*.

        A matched node becomes a ``{child_tag: child_text}`` dict when it has
        children, or its text otherwise.
        """
        if not path:
            if len(element):
                return [{child.tag: child.text for child in element}]
            return [element.text]

        ret = []
        for child in element:
            if child.tag == path[0]:
                ret.extend(self._fetch_list(child, path[1:]))
        return ret

    def dump_list(self, document, path, formatter):
        """Parse an S3 XML listing and print the lines *formatter* produces.

        :param document: XML response body (text or bytes).
        :param path: ``/``-separated tag path below the root element,
            e.g. ``"Buckets/Bucket"``.
        :param formatter: callable that takes the list built by
            :meth:`_fetch_list` and yields text lines; each line is printed
            UTF-8 encoded.
        """
        element = ET.fromstring(document)
        # Strip "{namespace}" prefixes from every tag after parsing, so *path*
        # can use bare names however the namespace was declared.
        for node in element.iter():
            node.tag = node.tag.rpartition('}')[2]

        for line in formatter(self._fetch_list(element, path.split("/"))):
            print(line.encode("utf-8"))


def longest(key, items):
    """Return the length of the longest ``item[key]`` in *items* (0 when empty)."""
    if not items:
        return 0
    lengths = [len(entry[key]) for entry in items]
    return max(lengths)


def format_bucket_item(items):
    """Yield one ``name | creation date`` line per bucket dict, names left-aligned."""
    width = longest("Name", items) + 3
    for bucket in items:
        bucket_name = bucket["Name"]
        created = bucket["CreationDate"]
        yield u"{name:{length}} | {ctime}".format(name=bucket_name, length=width, ctime=created)


def format_file_item(items):
    """Yield one ``key | size unit | mtime`` line per object dict, keys left-aligned.

    Sizes are divided by 1024 while at least one whole next-larger unit fits,
    so 1024 bytes is shown as ``1 kB``. The units used are b, kB, mB, gB and tB.
    """
    length = longest("Key", items) + 3
    units = ("b", "kB", "mB", "gB", "tB")

    for i in items:
        size = float(i["Size"])
        unit_index = 0
        while size >= 1024 and unit_index < len(units) - 1:
            size /= 1024
            unit_index += 1

        # '.4g' keeps every scaled value (< 1024) in plain notation; the former
        # '.3' spec rendered values in the hundreds/thousands as e.g. '1.02e+03'.
        yield u"{key:{length}} | {size:>8.4g} {unit:2} | {mtime}".format(
            key=i["Key"],
            length=length,
            size=size,
            unit=units[unit_index],
            mtime=i["LastModified"]
        )


def accept_args(*argnames):
    """Decorator factory: adapt ``func(**kw)`` into a handler taking a parsed namespace.

    The returned handler receives an ``argparse.Namespace`` and passes along,
    as keyword arguments, only those attributes whose names are in *argnames*.
    """
    wanted = frozenset(argnames)

    def _decorator(func):
        @functools.wraps(func)
        def _decorated(args):
            namespace = vars(args)
            selected = dict((attr, namespace[attr]) for attr in namespace if attr in wanted)
            return func(**selected)
        return _decorated

    return _decorator


@accept_args("client", "bucket", "name", "filename")
def upload(client, bucket, name, filename=None):
    """Upload *filename* (stdin when omitted) to ``bucket/name``.

    :return: the response returned by ``client.request``.
    """
    # Binary mode: the payload must be sent byte-for-byte, without newline
    # translation or decoding.
    data = sys.stdin if filename is None else open(filename, "rb")

    with data:
        return client.request("put", bucket, name, data)


@accept_args("client", "bucket", "name", "filename")
def fetch(client, bucket, name, filename=None):
    """Download ``bucket/name`` into *filename* (stdout when omitted)."""
    # Fetch first: if the request fails, an existing output file is left
    # untouched instead of being truncated to zero bytes.
    content = client.request("get", bucket, name).content

    # Binary mode: the object body is raw bytes.
    out = sys.stdout if filename is None else open(filename, "wb")

    with out:
        out.write(content)


@accept_args("client", "bucket", "name")
def remove(client, bucket, name):
    """Delete the object *name* from *bucket*."""
    client.request("delete", bucket=bucket, name=name)


@accept_args("client")
def list_buckets(client):
    """Print the buckets returned by a GET on the service root, one per line."""
    response = client.request("get")
    client.dump_list(response.content, path="Buckets/Bucket", formatter=format_bucket_item)


@accept_args("client", "bucket")
def list_names(client, bucket):
    """Print the objects listed by a single GET on *bucket*, one per line.

    Only the one response is printed; no continuation/marker request is made
    for listings the server truncates.
    """
    response = client.request("get", bucket)
    client.dump_list(response.content, path="Contents", formatter=format_file_item)


@accept_args("client", "bucket")
def new_bucket(client, bucket):
    """Create *bucket* by sending a PUT for it."""
    client.request("put", bucket=bucket)


class _HelpTokenUrl(argparse.Action):
    type = str

    def __call__(self, parser, namespace, values, option_string=None):
        self.print_token_url(system=values, testing=namespace.testing)
        parser.exit()

    def print_token_url(self, system, testing=False):
        if system == "abc":
            oauth_client_id = "9e9702c0b7f54152ac339989d9039ccd"
        elif system == "s3-mds":
            oauth_client_id = 'b43e4f9172184d8d95f59bd91f697d7a' if testing else '6797456f343042aabba07f49b478c49b'
        else:
            raise argparse.ArgumentError(self, "Unknown system {}".format(system))

        url = "https://oauth.yandex-team.ru/authorize?response_type=token&client_id={0}".format(oauth_client_id)
        print("Get your token for {} here:".format(system))
        print(url)


@accept_args("client", "token", "slug")
def get_service_id(client, token, slug):
    """Look up an ABC service by *slug* and print its name, slug and id as JSON.

    :param client: object providing ``base_request`` (an ``S3Client``).
    :param token: OAuth token for the ABC API.
    :param slug: ABC service slug.

    Exits with status 2 when the answer is not JSON (the message suggests
    an invalid token as the likely cause).
    """
    url = "https://abc-api.yandex-team.ru/v2/services"
    params = {
        "_one": 1,
        "_fields": "name,slug,id",
        "slug": slug
    }
    headers = {"Authorization": "OAuth " + token}

    r = client.base_request("get", url, params=params, headers=headers)
    # A reply without Content-Type is handled like any other non-JSON reply
    # rather than crashing with KeyError.
    content_type = r.headers.get("Content-Type", "").split(";")[0]
    if content_type != "application/json":
        print("Some error occurred, probably token is invalid", file=sys.stderr)
        print(content_type, file=sys.stderr)
        sys.exit(2)
    print(json.dumps(r.json(), indent=2, ensure_ascii=False).encode("utf-8"))


@accept_args("client", "token", "service_id", "role", "testing")
def get_access_key(client, token, service_id, role, testing=False):
    """Request a new S3-MDS access key for an ABC service and print it.

    :param client: object providing ``base_request`` (an ``S3Client``).
    :param token: OAuth token for the S3-MDS IDM service.
    :param service_id: ABC service id the key is created for.
    :param role: role requested for the key.
    :param testing: use the testing IDM host instead of production.

    On an unexpected answer the response headers and body are printed to
    stderr and the process exits with status 2.
    """
    kwargs = {}
    if testing:
        base_url = 's3-idm.mdst.yandex.net'
        # TLS certificate verification is disabled for the testing host.
        kwargs["verify"] = False
    else:
        base_url = 's3-idm.mds.yandex.net'
    url = "https://{0}/credentials/create-access-key".format(base_url)
    data = {
        "service_id": service_id,
        "role": role,
    }
    headers = {"Authorization": "OAuth " + token}

    r = client.base_request("post", url, data=data, headers=headers, **kwargs)
    # A reply without Content-Type takes the error branch instead of raising KeyError.
    content_type = r.headers.get("Content-Type", "")
    if r.status_code == 200 and content_type.startswith("application/json"):
        result = r.json()
        print(json.dumps(result, indent=2, ensure_ascii=False))
        print("You can now run this script like this:")
        print(u"{prog} -k '{key_id}' -s '{secret_key}' [command]".format(
            prog=sys.argv[0],
            key_id=result["AccessKeyId"],
            secret_key=result["AccessSecretKey"]
        ))
        print("Don't forget to store this key in a safe place!")
    else:
        print("Failed to fetch access key, answer was:", file=sys.stderr)
        print(r.headers, file=sys.stderr)
        print(r.content, file=sys.stderr)
        sys.exit(2)


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Every sub-command stores its handler function in ``args.handle``, except
    ``get-token-url``, whose action prints a URL and exits during parsing.
    """
    parser = argparse.ArgumentParser(description="Wall-E uploader for S3-MDS.")
    parser.add_argument("-c", "--config", help="configuration file path")

    parser.add_argument("-T", "--testing", action="store_true", help="Use testing S3-MDS service")
    parser.add_argument("-u", "--url", help="url for s3-mds service")
    parser.add_argument("-k", "--access-key", help="access key for s3-mds service")
    parser.add_argument("-s", "--secret-key", help="secret key for s3-mds service")

    action = parser.add_subparsers()

    upload_action = action.add_parser("upload", help="Upload file to a specified bucket")
    upload_action.set_defaults(handle=upload)
    upload_action.add_argument("bucket", help="bucket to upload file to")
    upload_action.add_argument("name", help="key/name to use for data on s3")
    upload_action.add_argument("filename", nargs='?', help="file to upload")

    fetch_action = action.add_parser("fetch", help="Fetch specified file from a bucket")
    fetch_action.set_defaults(handle=fetch)
    fetch_action.add_argument("bucket", help="bucket to fetch file from")
    fetch_action.add_argument("name", help="key/name to fetch from s3")
    fetch_action.add_argument("filename", nargs='?', help="save data to filename")

    remove_action = action.add_parser("remove", help="Remove specified file from a bucket")
    remove_action.set_defaults(handle=remove)
    remove_action.add_argument("bucket", help="bucket to remove file from")
    remove_action.add_argument("name", help="key/name to remove from s3")

    list_names_action = action.add_parser("list-names", help="List files in a bucket")
    list_names_action.set_defaults(handle=list_names)
    list_names_action.add_argument("bucket", help="bucket name")

    list_buckets_action = action.add_parser("list-buckets", help="List buckets")
    list_buckets_action.set_defaults(handle=list_buckets)

    new_bucket_action = action.add_parser("new-bucket", help="Create new bucket")
    new_bucket_action.set_defaults(handle=new_bucket)
    new_bucket_action.add_argument("bucket", help="bucket name")

    get_key_action = action.add_parser("get-key", help="Get access key to access your service in S3-MDS")
    get_key_action.set_defaults(handle=get_access_key)
    get_key_action.add_argument("token", help="oauth token for S3-MDS service")
    get_key_action.add_argument("service_id", help="Service ID in ABC")
    get_key_action.add_argument("role", choices=("reader", "admin", "owner"), help="User role in S3 service")

    get_service_id_action = action.add_parser(
        "get-service-id",
        help="Fetch your service's id from ABC by slug",
        description=(
            "Fetch abc service id from ABC for you. Get your oauth token for staff/abc at "
            "https://oauth.yandex-team.ru/authorize?response_type=token&client_id=9e9702c0b7f54152ac339989d9039ccd. "
            "Slug of your abc service is usually the last component of the url, e.g. for "
            "https://abc.yandex-team.ru/services/Wall-E/ the slug is 'Wall-E'"
        )
    )
    get_service_id_action.set_defaults(handle=get_service_id)
    get_service_id_action.add_argument("token", help="oauth token for abc/staff service (try get-token-url --help)")
    get_service_id_action.add_argument("slug", help="slug of your abc service")

    get_token_url_action = action.add_parser("get-token-url",
                                             help="Print url to obtain the oauth token for the system",
                                             )
    # _HelpTokenUrl prints the URL and exits as soon as "system" is parsed.
    get_token_url_action.add_argument("system", choices=("s3-mds", "abc"), action=_HelpTokenUrl,
                                      help="Print url to obtain the oauth token for the system")

    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()

    # --testing forces the testing service; otherwise use --url, falling back
    # to the production service. After this, args.url is always set.
    if args.testing:
        args.url = S3Client.TEST_URL
    elif not args.url:
        args.url = S3Client.PROD_URL

    args.client = S3Client(args.access_key, args.secret_key, args.url)
    args.handle(args)
