# -*- coding: utf-8 -*-

from __future__ import unicode_literals

from passport.backend.core.db.utils import insert_with_on_duplicate_key_update
from passport.backend.social.common.db.execute import execute
from passport.backend.social.common.db.schemas import (
    refresh_token_table,
    token_table,
)
from passport.backend.social.common.serialize import (
    deserialize_scopes,
    serialize_datetime,
)

from .domain import RefreshToken


class RefreshTokenRecord(object):
    """
    Persistence-layer representation of a RefreshToken.

    Converts to and from the domain model (RefreshToken) and reads/writes
    rows of refresh_token_table.
    """

    _FIELDS = {
        'refresh_token_id',
        'value',
        'expired',
        'token_id',
        # Scopes are not stored in the refresh-token table; they are read
        # from the scope column of the associated token (see
        # find_by_token_ids).
        'scopes',
    }

    def __init__(self, **kwargs):
        """
        Build a record from keyword arguments.

        Every name in _FIELDS must be supplied (a missing one raises
        KeyError); keyword arguments outside _FIELDS are ignored.
        """
        for field_name in self._FIELDS:
            setattr(self, field_name, kwargs[field_name])
        # Domain model this record was built from, if any. save() writes the
        # generated refresh_token_id back into it. Initialised here so that
        # records not created by from_model() (e.g. ones loaded from the DB)
        # can still be saved.
        self._model = None

    def to_model(self):
        """Return a new RefreshToken built from this record's fields."""
        constructor_dict = {name: getattr(self, name) for name in self._FIELDS}
        return RefreshToken(**constructor_dict)

    @classmethod
    def from_model(cls, model):
        """
        Build a record from a RefreshToken-like model.

        The model is remembered so that save() can propagate the
        refresh_token_id it obtains back to it.
        """
        constructor_dict = {name: getattr(model, name) for name in cls._FIELDS}
        record = cls(**constructor_dict)
        record._model = model
        return record

    def save(self, db):
        """
        Write this record to refresh_token_table.

        The statement is built with insert_with_on_duplicate_key_update,
        passing 'expired' and 'value' as the columns for the duplicate-key
        case. 'expired' is serialised with serialize_datetime. Afterwards,
        self.refresh_token_id is set from the result's lastrowid. If the
        record came from from_model(), the model's refresh_token_id is set
        as well.

        :param db: database handle, passed through to execute().
        :raises AssertionError: if token_id is None.
        """
        assert self.token_id is not None
        query = (
            insert_with_on_duplicate_key_update(
                refresh_token_table,
                ['expired', 'value'],
            )
            .values(
                token_id=self.token_id,
                value=self.value,
                expired=serialize_datetime(self.expired),
            )
        )
        query_result = execute(db, query)
        # NOTE(review): when the upsert takes the UPDATE path, whether
        # lastrowid is the existing row's id depends on how
        # insert_with_on_duplicate_key_update builds the statement
        # (e.g. `id = LAST_INSERT_ID(id)`) — confirm.
        self.refresh_token_id = query_result.lastrowid
        # getattr keeps this safe for instances built without __init__.
        model = getattr(self, '_model', None)
        if model is not None:
            # Propagate the new row id (not token_id) to the domain model.
            model.refresh_token_id = self.refresh_token_id

    @classmethod
    def find_by_token_ids(cls, token_ids, db):
        """
        Load refresh-token records for the given token ids.

        refresh_token_table is joined with token_table on token_id so that
        each record's scopes can be taken from the token's scope column.

        :param token_ids: collection of token ids. An empty (falsy) value
            returns [] without querying the DB.
        :param db: database handle, passed through to execute().
        :returns: list of records, one per joined row (no ORDER BY is
            applied).
        """
        if not token_ids:
            return []

        rtt = refresh_token_table
        tt = token_table
        query = (
            rtt.join(
                tt,
                tt.c.token_id == rtt.c.token_id,
            )
            # use_labels=True labels result columns as <table>_<column>,
            # which is what the 'refresh_token_*' / 'token_*' row keys below
            # rely on.
            .select(use_labels=True)
            .where(rtt.c.token_id.in_(token_ids))
        )

        return [
            cls(
                refresh_token_id=row['refresh_token_refresh_token_id'],
                value=row['refresh_token_value'],
                expired=row['refresh_token_expired'],
                token_id=row['refresh_token_token_id'],
                scopes=deserialize_scopes(row['token_scope']),
            )
            for row in execute(db, query).fetchall()
        ]
