# -*- coding: utf-8 -*-

from base64 import (
    b64decode,
    b64encode,
)
import json
import logging

from passport.backend.core.crypto.signing import (
    simple_is_correct_signature,
    simple_sign,
)
from passport.backend.core.models.drive import DriveSession
from passport.backend.core.ydb.ydb import get_ydb_drive_session
import six


# Module-scoped logger, named after this module.
log = logging.getLogger(name=__name__)


def find_drive_session(drive_device_id):
    """
    Load the drive session stored for ``drive_device_id`` and verify its signature.

    :param drive_device_id: value used as the ``drive_device_id`` lookup key in
        the drive-session YDB storage.
    :return: a parsed ``DriveSession`` when a record exists and its stored
        ``signature`` verifies against ``drive_session.to_bytes()``; ``None``
        when there is no record, or when the signature is missing, cannot be
        base64-decoded, or does not match.
    """
    ydb = get_ydb_drive_session()
    drive_session_dict = ydb.first(dict(drive_device_id=drive_device_id))
    if drive_session_dict is None:
        return None
    # The stored value does not carry its own key (save_drive_session removes
    # it before writing), so put it back before parsing.
    drive_session_dict.update(drive_device_id=drive_device_id)

    drive_session = DriveSession()
    drive_session.parse(drive_session_dict)

    # A missing or undecodable signature is handled like a mismatching one:
    # the record cannot be verified, so it is ignored rather than raising.
    signature = None
    encoded_signature = drive_session_dict.get('signature')
    if encoded_signature:
        try:
            signature = b64decode(encoded_signature.encode())
        except (AttributeError, TypeError, ValueError):
            # py3: binascii.Error (a ValueError); py2: TypeError;
            # non-text value: AttributeError on .encode().
            signature = None

    if signature is None or not simple_is_correct_signature(signature, drive_session.to_bytes()):
        # Lazy %-args: the message is only formatted if the record is emitted.
        log.warning('Ignore invalid drive session (invalid signature): %s', drive_session_dict)
        return None

    return drive_session


def save_drive_session(drive_session):
    """
    Sign ``drive_session`` and write it to the drive-session YDB storage.

    The signature is ``simple_sign`` over ``drive_session.to_bytes()``,
    base64-encoded into text. The stored JSON value holds every truthy
    field of ``dict(drive_session)`` plus ``signature``; ``drive_device_id``
    is left out of the value because it is the storage key.

    :param drive_session: session model to persist.
    """
    raw_signature = simple_sign(drive_session.to_bytes())
    encoded_signature = b64encode(raw_signature).decode()

    # Falsy fields are not written to storage.
    payload = {
        field: field_value
        for field, field_value in six.iteritems(dict(drive_session))
        if field_value
    }
    payload['signature'] = str(encoded_signature)
    # The device id is the key, not part of the value.
    payload.pop('drive_device_id', None)

    storage = get_ydb_drive_session()
    storage_key = {'drive_device_id': drive_session.drive_device_id}
    storage.set(storage_key, json.dumps(payload))


def delete_drive_session(drive_device_id):
    """
    Delete the stored drive session keyed by ``drive_device_id``.

    :return: whatever the YDB storage ``delete`` call returns.
    """
    storage_key = {'drive_device_id': drive_device_id}
    return get_ydb_drive_session().delete(storage_key)
