# -*- coding: utf-8 -*-

import base64
from collections import namedtuple
from datetime import datetime
import json

from passport.backend.core.crypto.signing import (
    sign,
    simple_is_correct_signature,
    simple_sign,
)
from passport.backend.core.models.support_code import SupportCode
from passport.backend.core.serializers.ydb.base import BaseYdbSerializer
from passport.backend.core.serializers.ydb.exceptions import DataCorruptedYdbSerializerError
from passport.backend.utils.string import (
    smart_bytes,
    smart_text,
)
from passport.backend.utils.time import datetime_to_integer_unixtime


def hash_support_code_value(code_value, secret):
    """
    Build the stored hash for a support code value.

    The value is converted to bytes and passed to ``sign`` together with
    ``secret.secret`` and ``secret.algorithm``; the bytes of ``secret.id``
    are appended to the result and the whole thing is base64-encoded.

    :param code_value: support code value to hash.
    :param secret: object exposing ``secret``, ``algorithm`` and ``id``.
    :return: standard base64 encoding of the digest followed by the secret id.
    """
    digest = sign(smart_bytes(code_value), secret.secret, secret.algorithm)
    secret_tag = smart_bytes(secret.id)
    tagged_digest = b''.join((digest, secret_tag))
    return base64.standard_b64encode(tagged_digest)


class SupportCodeSerializer(BaseYdbSerializer):
    """
    Converts SupportCode models to YDB rows and back.

    Every model is written as two rows: one built with the old secret and
    one built with the current secret. Reading expects both rows back and
    checks that they are signed correctly and agree with each other.
    """

    def __init__(self, signing_registry, cur_secret, old_secret):
        """
        :param signing_registry: passed to the signature check when reading rows.
        :param cur_secret: secret used to build the second row of a model.
        :param old_secret: secret used to build the first row of a model.
        """
        self.signing_registry = signing_registry
        self.cur_secret = cur_secret
        self.old_secret = old_secret

    @classmethod
    def from_config(cls, config):
        """
        Create a serializer from a configuration object.

        :param config: object exposing ``signing_registry``, ``cur_secret``
            and ``old_secret`` attributes.
        :return: instance of the class the method was called on.
        """
        # Use cls rather than the hard-coded class name so that subclasses
        # get an instance of themselves.
        return cls(config.signing_registry, config.cur_secret, config.old_secret)

    def to_ydb_rows(self, model):
        """
        Serialize a support code into a list of two YDB row dicts.

        :param model: support code exposing ``uid``, ``expires_at`` and ``value``.
        :return: list with the old-secret row first and the current-secret row second.
        """
        secrets = (self.old_secret, self.cur_secret)
        return [
            SupportCodeRow.from_support_code(model, secret).to_ydb_row()
            for secret in secrets
        ]

    def from_ydb_rows(self, ydb_rows):
        """
        Restore a SupportCode from the rows stored for it.

        Only the first two rows are consumed; any further rows are ignored.

        :param ydb_rows: iterable (or iterator) of YDB rows.
        :return: SupportCode with ``value=None``, or None when there are no rows.
        :raises DataCorruptedYdbSerializerError: when only one row is present,
            a row signature is invalid, or the two rows disagree on ``uid``
            or ``expires_at``.
        """
        # iter() lets callers pass any iterable, not only an iterator;
        # for an iterator it returns the same object.
        rows_iter = iter(ydb_rows)
        no_row = object()

        first_raw = next(rows_iter, no_row)
        if first_raw is no_row:
            return None

        second_raw = next(rows_iter, no_row)
        if second_raw is no_row:
            raise DataCorruptedYdbSerializerError('Data corrupted: not enough rows')

        first, second = [SupportCodeRow.from_ydb_row(raw) for raw in (first_raw, second_raw)]

        for sc_row in (first, second):
            if not sc_row.is_correct_signature(self.signing_registry):
                raise DataCorruptedYdbSerializerError('Data corrupted: invalid signature')

        if first.uid != second.uid or first.expires_at != second.expires_at:
            raise DataCorruptedYdbSerializerError('Data corrupted: rows are different')

        # The rows carry only a hash of the code, so the value cannot be restored.
        return SupportCode(
            uid=first.uid,
            expires_at=first.expires_at,
            value=None,
        )


class SupportCodeRow(object):
    """
    A single stored row of a support code.

    Holds the owner ``uid``, the expiration time, the hash of the code value
    and a base64-encoded signature computed over the other three fields.
    """

    def __init__(
        self,
        uid=None,
        expires_at=None,
        code_hash=None,
        signature=None,
    ):
        self.uid = uid
        self.expires_at = expires_at
        self.code_hash = code_hash
        self.signature = signature

    def is_correct_signature(self, signing_registry):
        """
        Check the stored signature against the row contents.

        :param signing_registry: passed through to ``simple_is_correct_signature``.
        :return: False when the signature is missing or is not valid base64,
            otherwise the result of ``simple_is_correct_signature``.
        """
        try:
            signature = base64.standard_b64decode(self.signature)
        except (TypeError, ValueError):
            # A missing or malformed signature (binascii.Error is a ValueError
            # on Python 3) means the row cannot be trusted; report it as an
            # incorrect signature instead of leaking a decoding error.
            return False
        return simple_is_correct_signature(
            signature,
            self.to_bytes(),
            signing_registry,
        )

    def to_bytes(self):
        """
        Return the byte string the signature is computed over.

        NOTE: uid, integer unixtime of ``expires_at`` and ``code_hash`` are
        concatenated without any separator. The same layout is used for both
        signing and verification, so changing it would make previously
        stored signatures fail verification.
        """
        expires_at = datetime_to_integer_unixtime(self.expires_at)
        return b''.join([smart_bytes(self.uid), smart_bytes(expires_at), smart_bytes(self.code_hash)])

    def sign(self, secret):
        """
        Compute the signature of the row and store it, base64-encoded as
        text, in ``self.signature``.

        :param secret: passed to ``simple_sign`` as its ``version`` argument.
        """
        signature = simple_sign(self.to_bytes(), version=secret)
        self.signature = smart_text(base64.standard_b64encode(signature))

    @classmethod
    def from_support_code(cls, support_code, secret):
        """
        Build a signed row for a support code using the given secret.

        :param support_code: object exposing ``uid``, ``expires_at`` and ``value``.
        :param secret: secret used both for hashing the value and for signing the row.
        :return: signed row instance.
        """
        row = cls(
            uid=support_code.uid,
            expires_at=support_code.expires_at,
            code_hash=hash_support_code_value(support_code.value, secret),
            signature=None,
        )
        row.sign(secret)
        return row

    def to_ydb_row(self):
        """
        Convert the row to a dict with ``code_hash``, ``value`` and ``expires_at`` keys.

        ``value`` is a JSON object holding ``uid`` and ``signature``;
        ``expires_at`` is converted to an integer unixtime.
        """
        value = json.dumps(dict(uid=self.uid, signature=self.signature))
        return dict(
            code_hash=self.code_hash,
            value=value,
            expires_at=datetime_to_integer_unixtime(self.expires_at),
        )

    @classmethod
    def from_ydb_row(cls, ydb_row):
        """
        Build a row from a YDB row mapping.

        :param ydb_row: mapping with ``expires_at`` (timestamp), ``value``
            (JSON with ``uid`` and ``signature``) and ``code_hash`` keys.
        :return: row instance; the signature is not verified here.
        """
        value = json.loads(ydb_row['value'])
        return cls(
            uid=value['uid'],
            expires_at=datetime.fromtimestamp(ydb_row['expires_at']),
            code_hash=ydb_row['code_hash'],
            signature=value['signature'],
        )


# Immutable bundle of the three settings read by SupportCodeSerializer.from_config:
# the signing registry, the current secret and the old secret.
SupportCodeSerializerConfiguration = namedtuple(
    typename='SupportCodeSerializerConfiguration',
    field_names='signing_registry cur_secret old_secret',
)
