# coding: utf-8

import base64
import hashlib
import json
import os
import re
import time

import six

from .base import VaultCLICommand


# Option minimums (TLSSecret clamps smaller values up to these and records a
# warning) and the defaults used when an option is not supplied.
MIN_TLS_TICKET_LENGTH = 12       # decoded ticket size, bytes
DEFAULT_TLS_TICKET_LENGTH = 48   # decoded ticket size, bytes
MIN_TTL = 1                      # hours between non-forced rotations
DEFAULT_TTL = 28                 # hours between non-forced rotations
MIN_TICKETS_COUNT = 3
DEFAULT_TICKETS_COUNT = 3
SECONDS_IN_HOUR = 3600           # 60 * 60


class TLSTicketValueError(ValueError):
    """Raised when a single TLS ticket fails validation."""


class TLSTicket(object):
    """A single TLS ticket: base64 text plus the SHA-256 hexdigest of that text.

    When ``value`` is omitted, a fresh random ticket of ``length`` bytes is
    generated and its hexdigest computed (any ``hexdigest`` argument is then
    ignored). Every instance is validated on construction.
    """
    def __init__(self, value=None, hexdigest=None, length=DEFAULT_TLS_TICKET_LENGTH, suffix=None):
        """
        :param value: base64-encoded ticket (text or bytes); ``None`` generates a new one.
        :param hexdigest: SHA-256 hexdigest of ``value`` as stored; checked by validate().
        :param length: expected length of the decoded ticket, in bytes.
        :param suffix: optional suffix inserted into the keys produced by dump().
        :raises TLSTicketValueError: if the ticket fails validate().
        """
        self.suffix = suffix or ''
        self.length = length

        self.value = value
        self.hexdigest = hexdigest
        if self.value is None:
            self.value = self.generate_ticket()
            self.hexdigest = self.make_hash(self.value)
        self.validate()

    def validate(self):
        """Check that the ticket decodes to ``length`` bytes and matches ``hexdigest``.

        Error messages deliberately omit the ticket value itself: it is secret
        key material, and exception text tends to end up in terminals and logs.

        :raises TLSTicketValueError: if ``value`` is not decodable base64, has the
            wrong decoded length, or does not hash to ``hexdigest``.
        """
        value = self.value
        if isinstance(value, six.text_type):
            value = value.encode('utf-8')
        try:
            bin_value = base64.b64decode(value)
        except (TypeError, ValueError):
            # binascii.Error subclasses ValueError on Python 3; Python 2 raises
            # TypeError for bad padding, and non-string values raise TypeError.
            raise TLSTicketValueError(u'An invalid ticket content: the value is not a base64 string')
        if len(bin_value) != self.length:
            raise TLSTicketValueError(u'An invalid ticket content length. ({} != {})'.format(
                len(bin_value),
                self.length,
            ))

        if self.make_hash(self.value) != self.hexdigest:
            raise TLSTicketValueError(
                u'Ticket error: "{}" is an invalid sha256 hexdigest for the ticket value'.format(self.hexdigest)
            )

    def generate_ticket(self):
        """Return ``length`` random bytes from os.urandom, base64-encoded as text."""
        # On a UNIX-like system this will query /dev/urandom,
        # and on Windows it will use CryptGenRandom()
        return base64.b64encode(os.urandom(self.length)).decode('utf-8')

    def make_hash(self, value):
        """Return the SHA-256 hexdigest of ``value`` (text is UTF-8 encoded first)."""
        if isinstance(value, six.text_type):
            value = value.encode('utf-8')
        return hashlib.sha256(value).hexdigest()

    def dump(self, num=0):
        """Serialize the ticket as ``{key: value, key + '.sha256': hexdigest}``.

        ``key`` is ``str(num)``, followed by ``'.' + suffix`` when a suffix is set.
        """
        key = u'{}{}'.format(
            str(num),
            '.{}'.format(self.suffix) if self.suffix else '',
        )
        return {
            key: self.value,
            u'{}.sha256'.format(key): self.hexdigest,
        }


class TLSSecretValueError(ValueError):
    """Raised when a stored tls-tickets secret version cannot be loaded."""


class TLSSecret(object):
    """A fixed-size, rotating list of TLS tickets plus the options that produced it.

    ``tickets[0]`` is the newest ticket: rotate() prepends a freshly generated
    ticket and drops the last one, so the list length does not change.
    Options below their minimums are clamped, and a message describing each
    clamp is appended to ``warnings``.
    """
    def __init__(self, count=DEFAULT_TICKETS_COUNT, length=DEFAULT_TLS_TICKET_LENGTH,
                 ttl=DEFAULT_TTL, suffix=None, tickets=None, ts=None):
        """
        :param count: number of tickets to keep (clamped to at least MIN_TICKETS_COUNT).
        :param length: decoded ticket length in bytes (clamped to at least MIN_TLS_TICKET_LENGTH).
        :param ttl: hours that must pass since ``ts`` before a non-forced rotate()
            replaces a ticket (clamped to at least MIN_TTL).
        :param suffix: optional suffix for the ticket keys written by dump().
        :param tickets: existing TLSTicket objects, newest first; when empty or
            ``None``, ``self.count`` new tickets are generated.
        :param ts: time of the last rotation in ``time.time()`` seconds; when
            ``None`` it is set to the current time on first access of ``ts``.
        """
        self.warnings = []
        self.count = count
        if self.count < MIN_TICKETS_COUNT:
            self.count = MIN_TICKETS_COUNT
            self._warn('Tickets count less then {mtc} ({count}), set to {mtc}'.format(
                count=count,
                mtc=MIN_TICKETS_COUNT,
            ))

        self.length = length
        if self.length < MIN_TLS_TICKET_LENGTH:
            self.length = MIN_TLS_TICKET_LENGTH
            self._warn('Ticket length less then {mtl} ({length}), set to {mtl}'.format(
                length=length,
                mtl=MIN_TLS_TICKET_LENGTH,
            ))

        self.ttl = ttl
        if self.ttl < MIN_TTL:
            self.ttl = MIN_TTL
            self._warn('TTL less then {mttl} ({ttl}), set to {mttl}'.format(
                ttl=ttl,
                mttl=MIN_TTL,
            ))

        self.suffix = suffix or ''
        if tickets:
            self.tickets = tickets
        else:
            # Generate the *clamped* count: dump() records self.count, and
            # from_version() rejects a version whose ticket count differs from it.
            self.tickets = [self.create_new_ticket() for _ in range(self.count)]
        self._ts = ts

    def _warn(self, line):
        """Record a human-readable warning about a clamped option."""
        self.warnings.append(line)

    @property
    def ts(self):
        """Timestamp of the last rotation; initialized to time.time() on first read."""
        if self._ts is None:
            self._ts = time.time()
        return self._ts

    @classmethod
    def from_version(cls, version):
        """Rebuild a secret from a stored version.

        :param version: mapping whose ``'value'`` item holds the output of dump():
            an ``'options'`` JSON string, a ``'ts'`` timestamp and numbered ticket keys.
        :raises TLSSecretValueError: if the options, timestamp, tickets or ticket
            count are invalid.
        :raises TLSTicketValueError: if a stored ticket fails validation.
        """
        value = version['value']
        raw_options = value.get('options', u'{}')
        try:
            options = json.loads(raw_options)
        except (TypeError, ValueError):
            options = None
        if not isinstance(options, dict):
            raise TLSSecretValueError(u'"{}" is an invalid options value'.format(raw_options))

        suffix = options.get('suffix', '')
        length = options.get('length', DEFAULT_TLS_TICKET_LENGTH)
        # Validate stored tickets against the length they were created with,
        # not against the default length.
        tickets = cls._get_tickets_from_version(version, suffix=suffix, length=length)
        if not tickets:
            raise TLSSecretValueError(u'Can\'t find tickets in a version value (suffix: "{}")'.format(suffix))

        ts = value.get('ts', '')
        try:
            ts = float(ts)
        except (TypeError, ValueError, OverflowError):
            raise TLSSecretValueError(u'"{}" is an invalid timestamp (ts) value'.format(ts))
        # The chained comparison also rejects NaN, for which every comparison is False.
        if not (0 < ts < float('inf')):
            raise TLSSecretValueError(u'"{}" is an invalid timestamp (ts) value'.format(ts))

        tls_secret = cls(
            count=options.get('count', DEFAULT_TICKETS_COUNT),
            length=length,
            ttl=options.get('ttl', DEFAULT_TTL),
            suffix=suffix,
            ts=ts,
            tickets=tickets,
        )
        if tls_secret.count != len(tls_secret.tickets):
            raise TLSSecretValueError(
                u'The version tickets count mismatch tickets counts in options or less minimal tickets count. '
                u'({} != {}) '
                u'Minimal tickets count is {}'.format(
                    tls_secret.count, len(tls_secret.tickets), MIN_TICKETS_COUNT,
                )
            )

        return tls_secret

    @classmethod
    def _get_tickets_from_version(cls, version, suffix=None, length=DEFAULT_TLS_TICKET_LENGTH):
        """Collect the TLSTickets stored in ``version['value']``, ordered by number.

        Keys look like ``N`` / ``N.sha256``, or ``N.<suffix>`` / ``N.<suffix>.sha256``
        when ``suffix`` is set. Keys matching neither pattern are ignored. A ticket
        missing its value or hexdigest is built with ``''`` and fails validation.

        :param length: expected decoded length of every ticket, in bytes.
        :raises TLSTicketValueError: if any collected ticket is invalid.
        """
        if suffix:
            escaped = re.escape(suffix)
            value_re = re.compile(r'^(\d+)\.{}$'.format(escaped))
            hexdigest_re = re.compile(r'^(\d+)\.{}\.sha256$'.format(escaped))
        else:
            value_re = re.compile(r'^(\d+)$')
            hexdigest_re = re.compile(r'^(\d+)\.sha256$')

        tickets = dict()
        for k, v in version['value'].items():
            vm = value_re.match(k)
            if vm:
                tickets.setdefault(vm.group(1), {})['value'] = v
                continue
            hm = hexdigest_re.match(k)
            if hm:
                tickets.setdefault(hm.group(1), {})['hexdigest'] = v

        return [
            TLSTicket(
                value=t.get('value', ''),
                hexdigest=t.get('hexdigest', ''),
                length=length,
                suffix=suffix,
            )
            for _, t in sorted(tickets.items(), key=lambda x: int(x[0]))
        ]

    def create_new_ticket(self):
        """Generate a new random ticket with this secret's length and suffix."""
        return TLSTicket(length=self.length, suffix=self.suffix)

    def dump(self):
        """Serialize to a flat dict: ``ts``, an ``options`` JSON string and every ticket."""
        result = {
            'ts': self.ts,
            'options': json.dumps({
                'count': self.count,
                'length': self.length,
                'ttl': self.ttl,
                'suffix': self.suffix,
            }, sort_keys=True),
        }
        for num, ticket in enumerate(self.tickets):
            result.update(ticket.dump(num=num))
        return result

    def rotate(self, force=False):
        """Replace the oldest ticket with a new one if ``ttl`` hours passed or ``force``.

        :returns: True if a rotation happened (``ts`` is then reset to now), else False.
        """
        current_ts = time.time()
        if force or current_ts >= self.ts + (self.ttl * SECONDS_IN_HOUR):
            self.tickets.insert(0, self.create_new_ticket())
            self.tickets.pop()
            self._ts = current_ts
            return True
        return False


# Shared base class for the tls-tickets CLI commands.
class TLSBaseCommand(VaultCLICommand):
    def log(self, line):
        """Echo a progress ``line`` with ``err=True``; suppressed in JSON output mode."""
        if self.as_json:
            return
        self.echo(line, err=True)


class TLSCreateCommand(TLSBaseCommand):
    """
    Create a tls-tickets secret
    """
    def __init__(self, args_parser, *args, **kwargs):
        super(TLSCreateCommand, self).__init__(args_parser, *args, **kwargs)
        # Only long options are registered below, so the usage line lists those;
        # short flags such as -c/-l/-s/-t are not defined by this command.
        # NOTE(review): "-C <comment>" depends on the 'secret_comment' base
        # argument defined outside this file — confirm the flag spelling.
        self.usage = (
            '%(prog)s name [--count <count>] [--length <length>] '
            '[--key-suffix <suffix>] [--ttl <TTL in hours>] [-C <comment>]'
        )
        self.add_argument('name', help='Non unique secret name')
        self.add_base_argument('secret_comment')
        self.add_base_argument('tags')
        self.add_argument(
            '--count', dest='count', type=int, default=DEFAULT_TICKETS_COUNT,
            help='A tls-tickets count (min: {})'.format(MIN_TICKETS_COUNT),
        )
        self.add_argument(
            '--length', dest='length', type=int, default=DEFAULT_TLS_TICKET_LENGTH,
            help='A tls-ticket bytes length (min: {})'.format(MIN_TLS_TICKET_LENGTH),
        )
        self.add_argument('--key-suffix', dest='suffix', help='Suffix for keys with tickets')
        self.add_argument(
            '--ttl', dest='ttl', type=int, default=DEFAULT_TTL,
            help='After how many hours you can rotate the ticket (min: {})'.format(MIN_TTL),
            metavar='HOURS',
        )

    def serialize_success_response(self, response):
        """Render a successful create response using the shared secret serializer."""
        return self.serialize_secret_response(response)

    def process(self, cli_args, client, debug=False, *args, **kwargs):
        """Generate a new tickets secret from the CLI options and store it.

        Clamping warnings from TLSSecret are logged before the secret is created.
        """
        tls_secret = TLSSecret(count=cli_args.count, length=cli_args.length, ttl=cli_args.ttl, suffix=cli_args.suffix)
        for w in tls_secret.warnings:
            self.log(w)

        response = client.create_secret(
            cli_args.name,
            value=tls_secret.dump(),
            comment=cli_args.comment,
            tags=cli_args.tags,
        )
        self.print_response_and_exit(response)


class TLSRotateCommandError(Exception):
    """Raised when a tls-tickets secret cannot be rotated by the current user."""


class TLSRotateCommand(TLSBaseCommand):
    """
    Rotate a tls-tickets secret
    """
    def __init__(self, args_parser, *args, **kwargs):
        super(TLSRotateCommand, self).__init__(args_parser, *args, **kwargs)
        self.usage = '%(prog)s [--force]'
        self.add_base_argument('secret_uuid')
        self.add_argument('--force', '-f', dest='force', action='store_true', help='Force rotate tickets')
        # Set by process(); read by serialize_success_response() to choose the status line.
        self.rotated = False

    def serialize_success_response(self, response):
        """Prefix the serialized version with a status line saying whether rotation happened."""
        if self.rotated:
            status_line = u'\nstatus: Secret was rotated'
        else:
            status_line = u'\nstatus: Rotation is not required. Use the -f option to force rotation'
        return u'\n'.join([status_line, self.serialize_version_response(response)])

    def _validate_client_response(self, response):
        """Re-raise the error carried by an unsuccessful client response."""
        if response.success:
            return
        raise response.e

    def _store_rotated_version(self, client, secret_uuid, old_version, tls_secret):
        # Save the rotated tickets as a new version, then hide the previous one.
        self.log('Store the new version')
        created = client.create_version(
            secret_uuid,
            value=tls_secret.dump(),
            old_version=old_version,
        )
        self._validate_client_response(created)

        self.log('Hide an old version')
        hidden = client.update_version(old_version['version'], state='hidden')
        self._validate_client_response(hidden)
        return created

    def process(self, cli_args, client, debug=False, *args, **kwargs):
        """Load the secret's latest version, rotate its tickets if due (or forced), and store it.

        :raises TLSRotateCommandError: if no ACL entry of the secret has the OWNER role.
        """
        self.log('Get a secret ({})'.format(cli_args.secret_uuid))
        secret_response = client.get_secret(cli_args.secret_uuid)
        self._validate_client_response(secret_response)
        secret = secret_response.result
        secret_uuid = secret['uuid']

        is_owner = any(entry['role_slug'].upper() == 'OWNER' for entry in secret['acl'])
        if not is_owner:
            raise TLSRotateCommandError(
                'You don\'t have an owner permission to the secret ({})'.format(secret_uuid)
            )

        self.log('Get the last secret version')
        version_response = client.get_version(secret_uuid)
        self._validate_client_response(version_response)
        last_version = version_response.result

        self.log('Validate the last version')
        tls_secret = TLSSecret.from_version(last_version)
        for warning in tls_secret.warnings:
            self.log(warning)

        self.log('Force rotate a version' if cli_args.force else 'Rotate a version if needed')
        self.rotated = tls_secret.rotate(force=cli_args.force)

        if self.rotated:
            response = self._store_rotated_version(client, secret_uuid, last_version, tls_secret)
        else:
            response = version_response
        self.print_response_and_exit(response)
