import argparse
import functools
import hmac
import json
import sys
import xml.etree.ElementTree as ET
from base64 import encodebytes
from email.utils import formatdate
from hashlib import sha1 as sha
from urllib.parse import urlparse, unquote

import requests
import time
from requests.auth import AuthBase


class S3Auth(AuthBase):
    """Attaches AWS Signature Version 2 authentication to a ``requests`` request.

    Usage::

        requests.get(url, auth=S3Auth(access_key, secret_key, 's3.mds.yandex.net'))

    The ``Authorization`` header is set to ``AWS <access_key>:<signature>``,
    where the signature is the base64-encoded HMAC-SHA1 of the canonical
    string built by :meth:`get_canonical_string`.
    """

    # Query-string arguments that identify a sub-resource and are therefore
    # included in the signed canonical resource; other query arguments are
    # not signed.
    special_params = [
        'acl', 'location', 'logging', 'partNumber', 'policy', 'requestPayment',
        'torrent', 'versioning', 'versionId', 'versions', 'website', 'uploads',
        'uploadId', 'response-content-type', 'response-content-language',
        'response-expires', 'response-cache-control', 'delete', 'lifecycle',
        'response-content-disposition', 'response-content-encoding', 'tagging',
        'notification', 'cors'
    ]

    def __init__(self, access_key, secret_key, service_url):
        """Store the credentials and the service host.

        :param access_key: access key id placed in the ``Authorization`` header.
        :param secret_key: secret used as the HMAC key.
        :param service_url: service host name without scheme. When a request
            host is ``<bucket>.<service_url>``, the prefix is signed as the bucket.
        """
        self.access_key = str(access_key)
        self.secret_key = str(secret_key)
        self.service_base_url = service_url

    def __call__(self, r):
        """Sign request *r* in place and return it.

        A ``date`` header is added first when neither ``date`` nor
        ``x-amz-date`` is present, because the timestamp is part of the
        signed string.
        """
        if 'date' not in r.headers and 'x-amz-date' not in r.headers:
            r.headers['date'] = formatdate(
                timeval=None,
                localtime=False,
                usegmt=True)
        signature = self.get_signature(r)
        r.headers['Authorization'] = 'AWS %s:%s' % (self.access_key, signature.decode())
        return r

    def get_signature(self, r):
        """Return the base64-encoded HMAC-SHA1 signature of request *r* as bytes."""
        canonical_string = self.get_canonical_string(r.url, r.headers, r.method)
        key = str(self.secret_key).encode('utf-8')
        digest = hmac.new(key, canonical_string.encode('utf-8'), digestmod=sha).digest()
        # encodebytes() appends a trailing newline; it is not part of the signature.
        return encodebytes(digest).strip()

    def _bucket_from_host(self, netloc):
        """Return the bucket of a ``<bucket>.<service_url>`` host, or '' if there is none."""
        base = self.service_base_url
        # Path-style host (exactly the service host), an unrelated host, or no
        # configured service host: no bucket is encoded in the host name.
        if not base or netloc == base or not netloc.endswith(base):
            return ''
        bucket = netloc[:-len(base)]
        # Drop the dot separating the bucket from the service host.
        if bucket.endswith('.'):
            bucket = bucket[:-1]
        return bucket

    def get_canonical_string(self, url, headers, method):
        """Build the string that is signed for a request.

        The string consists of newline-terminated lines:

        * the HTTP method;
        * the Content-MD5, Content-Type and Date header values (empty when
          absent; Date is blanked when an x-amz-date header is present);
        * one ``name:value`` line per ``x-amz-*`` header (lower-cased name),
          sorted by name;

        followed, without a trailing newline, by the resource
        ``[/<bucket>]/<key>`` and, when query arguments listed in
        ``special_params`` are present, ``?k=v&...`` in sorted order.

        :param url: full request URL as sent (already percent-encoded).
        :param headers: mapping of request header names to values.
        :param method: HTTP method, copied verbatim into the first line.
        :returns: the canonical string.
        """
        parsedurl = urlparse(url)
        objectkey = parsedurl.path[1:]
        query_args = sorted(parsedurl.query.split('&'))
        bucket = self._bucket_from_host(parsedurl.netloc)

        interesting_headers = {
            'content-md5': '',
            'content-type': '',
            'date': ''}
        for key in headers:
            lk = key.lower()
            # Header names may arrive as bytes; normalize them to str so they
            # compare equal to the str keys above.
            if isinstance(lk, bytes):
                lk = lk.decode('utf-8')
            value = headers[key]
            if value and (lk in interesting_headers or lk.startswith('x-amz-')):
                interesting_headers[lk] = value.strip()

        # If x-amz-date is used it supersedes the date header.
        if 'x-amz-date' in interesting_headers:
            interesting_headers['date'] = ''

        parts = ['%s\n' % method]
        for key in sorted(interesting_headers):
            val = interesting_headers[key]
            if key.startswith('x-amz-'):
                parts.append('%s:%s\n' % (key, val))
            else:
                parts.append('%s\n' % val)

        # The bucket is only present for virtual-hosted style hosts; the object
        # key is always preceded by a slash, even when it is empty.
        if bucket:
            parts.append('/%s' % bucket)
        parts.append('/%s' % objectkey)

        sub_resources = []
        for q in query_args:
            k, has_value, v = q.partition('=')
            if k not in self.special_params:
                continue
            if has_value:
                # Values such as Riak CS multipart upload ids
                # (`TFDSheOgTxC2Tsh1qVK73A==`) are percent-encoded by requests
                # when passed via `params=` (`TFDSheOgTxC2Tsh1qVK73A%3D%3D`).
                # Signing the encoded value yields 403 Access Denied, so sign
                # the decoded value; unquote() is a no-op for plain values.
                sub_resources.append('%s=%s' % (k, unquote(v)))
            else:
                sub_resources.append(q)
        if sub_resources:
            parts.append('?' + '&'.join(sub_resources))

        return ''.join(parts)


class S3Client(object):
    """Path-style S3 client whose requests are signed with :class:`S3Auth`."""

    PROD_URL = 's3.mds.yandex.net'
    TEST_URL = 's3.mdst.yandex.net'

    # Default namespace declaration of S3 XML listings. It is removed before
    # parsing so that element tags can be matched by their bare names.
    _S3_XMLNS = b' xmlns="http://s3.amazonaws.com/doc/2006-03-01/"'

    def __init__(self, access_key, secret_key, base_url=TEST_URL):
        """Create a client for the service host *base_url* (no scheme)."""
        self.base_url = base_url
        self.auth = S3Auth(access_key, secret_key, base_url)

    @staticmethod
    def base_request(method, url, **kwargs):
        """Perform an HTTP request with retries on server errors.

        A 5xx response is retried up to 8 times, sleeping 1 s before the first
        retry and doubling the delay each time. If the final status is neither
        200 nor 204, ``raise_for_status()`` is called on the response.

        :param method: HTTP method name.
        :param url: absolute URL.
        :param kwargs: forwarded unchanged to ``requests.request``.
        :returns: the final ``requests.Response``.
        """
        retries = 8
        backoff = 2
        timeout = 1
        while True:
            r = requests.request(method, url, **kwargs)
            if 500 <= r.status_code <= 599 and retries > 0:
                retries -= 1
                time.sleep(timeout)
                timeout *= backoff
                continue
            break

        if r.status_code not in (200, 204):
            r.raise_for_status()
        return r

    def request(self, method, bucket=None, name=None, data=None):
        """Send a signed request to ``https://<base_url>/<bucket>/<name>``.

        Missing (falsy) bucket/name parts are omitted from the URL.
        """
        url = "https://" + '/'.join(filter(None, [self.base_url, bucket, name]))
        kwargs = {}
        # TLS certificate verification is disabled for the testing host.
        if self.base_url == self.TEST_URL:
            kwargs["verify"] = False
        return self.base_request(method, url, data=data, auth=self.auth, **kwargs)

    def _fetch_list(self, element, path):
        """Collect the elements reached from *element* by the tag names in *path*.

        Every child whose tag matches ``path[0]`` is followed; results from all
        matching branches are concatenated. At the end of the path, an element
        with children becomes a ``{child tag: child text}`` dict, and a leaf
        element becomes its text.
        """
        if not path:
            if len(element):
                return [{child.tag: child.text for child in element}]
            return [element.text]
        ret = []
        for child in element:
            if child.tag == path[0]:
                ret.extend(self._fetch_list(child, path[1:]))
        return ret

    def _iter_list(self, document, path, formatter):
        """Parse the XML bytes *document* and return *formatter*'s output for *path*."""
        element = ET.fromstring(document.replace(self._S3_XMLNS, b''))
        return formatter(self._fetch_list(element, path.split("/")))

    def dump_list(self, document, path, formatter):
        """Print every line produced by *formatter* for the entries at *path*.

        :param document: XML listing as bytes.
        :param path: ``/``-separated tag path below the root element.
        :param formatter: callable taking the collected entries and returning
            an iterable of str lines.
        """
        for line in self._iter_list(document, path, formatter):
            print(line)

    def return_list(self, document, path, formatter):
        """Like :meth:`dump_list`, but return the lines as a list instead of printing."""
        return list(self._iter_list(document, path, formatter))


def longest(key, items):
    """Return the greatest ``len(item[key])`` over *items*, or 0 when *items* is empty/falsy."""
    if not items:
        return 0
    return max(map(len, (item[key] for item in items)))


def format_bucket_item(items):
    """Yield one ``"<name> | <creation date>"`` line per bucket entry.

    Names are left-aligned and padded to the longest name plus three
    characters, so the separators line up.
    """
    width = 3 + longest("Name", items)
    template = u"{name:{length}} | {ctime}"
    for bucket in items:
        yield template.format(length=width, name=bucket["Name"], ctime=bucket["CreationDate"])


def format_file_item(items):
    """Yield the ``Key`` (object name) of every listing entry, in order."""
    yield from (entry["Key"] for entry in items)


def accept_args(*argnames):
    """Adapt a keyword-argument function into an argparse command handler.

    The decorated function is called with a single parsed-arguments
    namespace. Only the attributes named in *argnames* are passed to the
    wrapped function, as keyword arguments; names missing from the namespace
    are simply not passed.
    """
    wanted = set(argnames)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(args):
            selected = {}
            for attr_name, attr_value in vars(args).items():
                if attr_name in wanted:
                    selected[attr_name] = attr_value
            return func(**selected)

        return wrapper

    return decorator


@accept_args("client", "bucket", "name", "filename")
def upload(client, bucket, name, filename=None):
    """Upload *filename* (or standard input when it is None) as *bucket*/*name*.

    The payload is read in binary mode so arbitrary bytes are sent unchanged
    instead of being decoded and re-encoded as text. Standard input is not
    closed, because this function did not open it.

    :returns: the response returned by ``client.request``.
    """
    if filename is None:
        # sys.stdin is a text stream; .buffer is its underlying byte stream.
        # Fall back to sys.stdin itself if it was replaced by an object
        # without a buffer.
        return client.request("put", bucket, name, getattr(sys.stdin, "buffer", sys.stdin))
    with open(filename, "rb") as data:
        return client.request("put", bucket, name, data)


@accept_args("client", "bucket", "name", "filename")
def fetch(client, bucket, name, filename=None):
    """Download *bucket*/*name* into *filename* (or standard output when it is None).

    The response body (``.content``) is bytes, so it is written to a binary
    stream. The output file is created only after the download succeeded, and
    standard output is left open.
    """
    content = client.request("get", bucket, name).content
    if filename is None:
        out = getattr(sys.stdout, "buffer", sys.stdout)
        out.write(content)
        # Flush so the bytes are not reordered with later text output.
        out.flush()
        return
    with open(filename, "wb") as out:
        out.write(content)


@accept_args("client", "bucket", "name")
def remove(client, bucket, name):
    """Delete the object *name* from *bucket*."""
    method = "delete"
    client.request(method, bucket=bucket, name=name)


@accept_args("client")
def list_buckets(client):
    """Print the buckets visible to *client* as aligned "name | creation date" lines."""
    response = client.request("get")
    client.dump_list(response.content, path="Buckets/Bucket", formatter=format_bucket_item)


@accept_args("client", "bucket")
def list_names(client, bucket):
    """Print the key of every object in *bucket*, one per line.

    NOTE(review): only the single listing response is read; if the server
    paginates large listings, the output may be incomplete -- confirm.
    """
    response = client.request("get", bucket)
    client.dump_list(response.content, path="Contents", formatter=format_file_item)


@accept_args("client", "bucket")
def new_bucket(client, bucket):
    """Create *bucket* with a PUT on the bucket URL."""
    method = "put"
    client.request(method, bucket=bucket)


class _HelpTokenUrl(argparse.Action):
    """argparse action: print the OAuth token URL for the given system, then exit."""

    # NOTE(review): argparse.Action.__init__ assigns self.type, so this class
    # attribute is probably shadowed on instances -- confirm before relying on it.
    type = str

    def __call__(self, parser, namespace, values, option_string=None):
        # namespace.testing reflects only options parsed before this one.
        testing = namespace.testing
        self.print_token_url(system=values, testing=testing)
        parser.exit()

    def print_token_url(self, system, testing=False):
        """Print where to obtain an OAuth token for *system* ("abc" or "s3-mds").

        :raises argparse.ArgumentError: for any other system name.
        """
        # (system, production client id, testing client id)
        known_systems = (
            ("abc", "9e9702c0b7f54152ac339989d9039ccd", "9e9702c0b7f54152ac339989d9039ccd"),
            ("s3-mds", "6797456f343042aabba07f49b478c49b", "b43e4f9172184d8d95f59bd91f697d7a"),
        )
        for known_name, prod_id, test_id in known_systems:
            if system == known_name:
                break
        else:
            raise argparse.ArgumentError(self, "Unknown system {}".format(system))
        client_id = test_id if testing else prod_id

        url = "https://oauth.yandex-team.ru/authorize?response_type=token&client_id={0}".format(client_id)
        print("Get your token for {} here:".format(system))
        print(url)


@accept_args("client", "token", "slug")
def get_service_id(client, token, slug):
    """Look up the ABC service *slug* and print its name, slug and id as JSON.

    :param client: object providing ``base_request`` (retrying HTTP helper).
    :param token: OAuth token sent in the ``Authorization`` header.
    :param slug: service slug to look up.

    Exits with status 2, after printing a diagnostic to stderr, when the
    answer is not JSON.
    """
    url = "https://abc-api.yandex-team.ru/v2/services"
    params = {
        "_one": 1,
        "_fields": "name,slug,id",
        "slug": slug
    }
    headers = {"Authorization": "OAuth " + token}

    r = client.base_request("get", url, params=params, headers=headers)
    # A missing Content-Type is treated like a non-JSON answer instead of
    # crashing with KeyError; MIME types compare case-insensitively.
    content_type = r.headers.get("Content-Type", "").split(";")[0].strip()
    if content_type.lower() != "application/json":
        print("Some error occurred, probably token is invalid", file=sys.stderr)
        print(content_type, file=sys.stderr)
        sys.exit(2)
    # Print text, not UTF-8 bytes: print(bytes) would emit a b'...' repr.
    print(json.dumps(r.json(), indent=2, ensure_ascii=False))


@accept_args("client", "token", "service_id", "role", "testing")
def get_access_key(client, token, service_id, role, testing=False):
    """Create an S3 access key for an ABC service via the S3 IDM API and print it.

    :param client: object providing ``base_request`` (retrying HTTP helper;
        its error handling applies before the checks below).
    :param token: OAuth token sent in the ``Authorization`` header.
    :param service_id: ABC service id the key is created for.
    :param role: role requested for the key.
    :param testing: use the testing IDM host (TLS verification disabled).

    On a 200 JSON answer, prints the result plus an example command line
    containing the new credentials. Otherwise, prints the answer to stderr
    and exits with status 2.
    """
    kwargs = {}
    if testing:
        base_url = 's3-idm.mdst.yandex.net'
        kwargs["verify"] = False
    else:
        base_url = 's3-idm.mds.yandex.net'
    url = "https://{0}/credentials/create-access-key".format(base_url)
    data = {
        "service_id": service_id,
        "role": role,
    }
    headers = {"Authorization": "OAuth " + token}

    r = client.base_request("post", url, data=data, headers=headers, **kwargs)
    # A missing Content-Type counts as a failed answer rather than a KeyError.
    content_type = r.headers.get("Content-Type", "")
    if r.status_code == 200 and content_type.startswith("application/json"):
        result = r.json()
        print(json.dumps(result, indent=2, ensure_ascii=False))
        print("You can now run this script like this:")
        print(u"{prog} -k '{key_id}' -s '{secret_key}' [command]".format(
            prog=sys.argv[0],
            key_id=result["AccessKeyId"],
            secret_key=result["AccessSecretKey"]
        ))
        print("Don't forget to store this key in a safe place!")
    else:
        print("Failed to fetch access key, answer was:", file=sys.stderr)
        print(r.headers, file=sys.stderr)
        # Decoded body text instead of a b'...' bytes repr.
        print(r.text, file=sys.stderr)
        sys.exit(2)
