# -*- coding: utf-8 -*-

import json
import time

from passport.backend.vault.api.db import get_db
from passport.backend.vault.api.errors import MaximumFieldLengthExceededError
from passport.backend.vault.api.models.base import (
    BaseModel,
    MagicBigInteger,
    MagicBLOB,
    MagicInteger,
    State,
    Timestamp,
    UUIDType,
)
from passport.backend.vault.api.models.user_info import CreatorMixin
from passport.backend.vault.api.value_manager import (
    get_value_manager,
    ValueType,
)
from sqlalchemy import (
    and_,
    ForeignKeyConstraint,
    Index,
    or_,
)

from .secret import (
    Secret,
    SecretUUIDType,
)


db = get_db()

# Upper bound on the length of one version's comma-joined key names; also
# used as the declared size of the ``keys`` text column.
max_total_key_names_length = 32 * 1024
# Upper bound on the length of one version's JSON-serialized value.
max_value_length = 64 * 1024


class SecretVersionUUIDType(UUIDType):
    """UUID column type for secret version identifiers (prefix ``'ver'``; see UUIDType for how it is applied)."""

    prefix = 'ver'


class SecretVersion(BaseModel, CreatorMixin):
    """ORM model for one version of a secret (table ``secret_versions``).

    The value is stored encoded in ``_value`` (column ``value``) together
    with the ``cipher_key_id`` and ``value_type`` needed to decode it. The
    sorted key names are denormalized into ``_keys`` (column ``keys``).
    """
    __tablename__ = 'secret_versions'
    __repr_attrs__ = ['secret_uuid']

    default_serialization_columns = [
        'version', 'created_at', 'created_by', 'comment', 'value', 'expired_at',
    ]
    default_serialization_pycolumns = [
        'secret_name', 'creator_login', 'keys', 'expired',
    ]
    max_serialization_depth = 3

    version = db.Column(SecretVersionUUIDType, primary_key=True, default=lambda: SecretVersionUUIDType.create_ulid())
    secret_uuid = db.Column(SecretUUIDType, nullable=False)
    secret = db.relationship(
        'Secret',
        lazy='joined',
        uselist=False,
    )

    state = db.Column(db.Integer, nullable=False, default=State.normal.value, server_default='0')

    expired_at = db.Column(Timestamp, nullable=True)

    @staticmethod
    def check_expired(version):
        """Return True if ``version`` has an ``expired_at`` earlier than now.

        :param version: a SecretVersion-like object with an ``expired_at``
            attribute, or a dict that may contain an ``'expired_at'`` key.
        :returns: False when no expiry timestamp is present.
        """
        expired_at = None
        if hasattr(version, 'expired_at'):
            expired_at = version.expired_at
        elif isinstance(version, dict):
            expired_at = version.get('expired_at')

        if expired_at is None:
            return False
        return expired_at < time.time()

    @property
    def hidden(self):
        """True when this version's own state is ``State.hidden``."""
        return self.state == State.hidden.value

    @property
    def expired(self):
        """None if the version has no expiry, otherwise whether it has expired."""
        return SecretVersion.check_expired(self) if self.expired_at is not None else None

    def set_ttl(self, ttl=None):
        """Set the expiry to ``ttl`` seconds from now.

        A non-positive ``ttl`` clears the expiry, and ``None`` leaves it unchanged.
        """
        if ttl is not None:
            if ttl > 0:
                self.expired_at = time.time() + ttl
            else:
                self.expired_at = None

    @property
    def transitive_state(self):
        """The stricter (numerically larger) of this version's and its secret's state."""
        return max(self.state, self.secret.state)

    def state_name(self):
        """Name of the ``State`` member matching ``transitive_state``."""
        return State(self.transitive_state).name

    parent_version_uuid = db.Column(SecretVersionUUIDType, nullable=True)
    parent_version = db.relationship(
        'SecretVersion',
        primaryjoin='SecretVersion.parent_version_uuid == foreign(SecretVersion.version)',
        lazy='select',
        uselist=False,
        viewonly=True,
    )

    def parent_diff_keys(self):
        """Describe how this version's keys differ from its parent version's.

        :returns: None when there is no parent version. Otherwise a dict with
            sorted lists ``added`` and ``removed`` (key names present in only
            one version) and ``changed`` (shared key names whose values differ).
        """
        parent = self.parent_version
        if not parent:
            return None

        def unpack_value(value, keys):
            # Map key name -> value for the entries whose key is in ``keys``.
            return {v['key']: v['value'] for v in value if v['key'] in keys}

        new_keys = set(self.keys())
        old_keys = set(parent.keys())
        shared_keys = new_keys & old_keys

        changed = []
        if shared_keys:
            old_value = unpack_value(parent.value, shared_keys)
            new_value = unpack_value(self.value, shared_keys)
            # ``_keys`` and the decoded value could disagree. Use ``.get`` so a
            # key missing from the new value counts as changed instead of
            # raising KeyError.
            changed = sorted(k for k in old_value if old_value[k] != new_value.get(k))

        return dict(
            added=sorted(new_keys - old_keys),
            removed=sorted(old_keys - new_keys),
            changed=changed,
        )

    cipher_key_id = db.Column(db.String(32), nullable=False)
    _value = db.Column('value', MagicBLOB, nullable=False)
    value_type = db.Column(
        MagicInteger,
        nullable=False,
        default=ValueType.COMPRESSED_AES.value,
        server_default=str(ValueType.COMPRESSED_AES.value),
    )

    def secret_name(self):
        """Name of the owning secret."""
        return self.secret.name

    @property
    def value(self):
        """Decoded value of this version, cached on the instance after first use.

        The stored blob is decoded by the value manager using this row's
        ``cipher_key_id`` and ``value_type``, then parsed as JSON.
        """
        result = self.__dict__.get('_cached_value')
        # Compare against None rather than truthiness. An empty value ({} or
        # []) is a valid cached result and should not trigger a fresh decode
        # on every access.
        if result is None:
            result = self._cached_value = json.loads(
                get_value_manager(self.config).decode(
                    index=self.cipher_key_id,
                    algorithm=self.value_type,
                    value=self._value,
                ),
            )
        return result

    @value.setter
    def value(self, value):
        """Encode and store ``value``; a falsy value is stored as ``{}``.

        :raises MaximumFieldLengthExceededError: if the JSON-serialized value
            is longer than ``max_value_length`` characters.
        """
        normalized = value or {}
        value_str = json.dumps(normalized)

        if len(value_str) > max_value_length:
            raise MaximumFieldLengthExceededError(
                'values',
                len(value_str),
                max_value_length,
            )

        self._cached_value = normalized
        self.cipher_key_id, self.value_type, self._value = get_value_manager(self.config).encode(
            value=value_str,
        )

    creator_user_info = CreatorMixin.creator_relationship('SecretVersion')

    def keys(self):
        """List of key names stored in ``_keys``; empty when none are stored."""
        return self._keys.split(',') if self._keys else []

    comment = db.Column(db.String(1023), nullable=True)

    updated_at = db.Column(Timestamp, nullable=True)
    updated_by = db.Column(MagicBigInteger, nullable=True)

    def touch(self, updated_by, updated_at=None):
        """Record who updated the version and when; a falsy ``updated_at`` means now."""
        self.updated_at = updated_at if updated_at else time.time()
        self.updated_by = updated_by

    def set_keys_from_value(self, value):
        """Store the sorted, comma-joined key names of ``value`` in ``_keys``.

        :param value: iterable of dicts each having a ``'key'`` entry. A falsy
            value (e.g. None) is treated as empty, matching the ``value`` setter.
        :raises MaximumFieldLengthExceededError: if the joined names are longer
            than ``max_total_key_names_length`` characters.
        """
        keys = ','.join(sorted(el['key'] for el in value or []))

        if len(keys) > max_total_key_names_length:
            raise MaximumFieldLengthExceededError(
                'keys',
                len(keys),
                max_total_key_names_length,
            )

        self._keys = keys

    _keys = db.Column('keys', db.Text(max_total_key_names_length), nullable=True)

    __table_args__ = (
        ForeignKeyConstraint(['secret_uuid'], [Secret.uuid], name='secret_versions_ibfk_1'),
        Index('idx_secret_versions_secret_uuid', 'secret_uuid'),
        Index('idx_secret_versions_created_at', 'created_at'),
        Index('idx_secret_versions_created_by', 'created_by'),
    )

    @staticmethod
    def create_secret_version(created_by, secret, value, comment=None, parent_version_uuid=None, ttl=None):
        """Build a new (unsaved) version of ``secret`` holding ``value``.

        It sets the encoded value, the key-name index and the optional TTL,
        and touches ``secret`` with the same creation timestamp.
        """
        current_time = time.time()
        secret_version = SecretVersion(
            version=SecretVersionUUIDType.create_ulid(),
            secret_uuid=secret.uuid,
            created_at=current_time,
            created_by=created_by,
            comment=comment,
            parent_version_uuid=parent_version_uuid,
        )
        secret_version.value = value
        secret_version.set_keys_from_value(value)
        secret_version.set_ttl(ttl)
        secret.touch(updated_by=created_by, updated_at=current_time)
        return secret_version

    @staticmethod
    def unexpired_version_filter():
        """SQL filter matching versions with no expiry or an expiry not yet passed."""
        return or_(
            SecretVersion.expired_at.is_(None),
            SecretVersion.expired_at >= time.time(),
        )

    @staticmethod
    def only_visible_versions_filter():
        """SQL filter matching unexpired versions in the ``State.normal`` state."""
        return and_(
            SecretVersion.unexpired_version_filter(),
            SecretVersion.state == State.normal.value,
        )
