# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import hashlib

from passport.backend.core.db.utils import insert_with_on_duplicate_key_update
from passport.backend.social.common.db.execute import execute
from passport.backend.social.common.db.schemas import (
    refresh_token_table,
    token_table,
)
from passport.backend.social.common.limits import get_qlimits
from passport.backend.social.common.serialize import (
    deserialize_scopes,
    serialize_datetime,
    serialize_scopes,
)
from passport.backend.social.common.social_config import social_config
import sqlalchemy.sql as sql

from .domain import Token


# Hash algorithm name used as the prefix of stored value_hash strings
# ("md5:<hexdigest>") by TokenRecord.eval_value_hash. NOTE: the digest itself
# is computed with hashlib.md5 regardless of this value, so changing it alone
# would only change the prefix.
TOKEN_HASH_METHOD = 'md5'


def serialize_secret(secret):
    """Return the DB form of an optional secret: any falsy value becomes ''."""
    if not secret:
        return ''
    return secret


def deserialize_secret(value):
    """Return a secret read from the DB, mapping any falsy value to None."""
    return value if value else None


class TokenRecord(object):
    """DB-layer record for a row of ``token_table``.

    Converts to and from the domain ``Token`` model (whose ``scopes``
    attribute corresponds to the ``scope`` column here) and wraps the SQL
    used to load, save and delete tokens. Assigning ``value`` also
    recomputes ``value_hash``.
    """

    # Attributes copied between DB rows, records and Token models.
    _FIELDS = {
        'token_id',
        'uid',
        'profile_id',
        'application_id',
        'value',
        'secret',
        'scope',
        'expired',
        'created',
        'verified',
        'confirmed',
    }

    def __init__(self, **kwargs):
        """Create a record from keyword arguments.

        Every name in ``_FIELDS`` must be supplied (a missing one raises
        ``KeyError``). An optional ``value_hash`` keyword overrides the hash
        computed from ``value``.
        """
        self._value = None
        self.value_hash = None
        # The Token model this record was built from, if any (see from_model).
        # save() copies a newly generated token_id back into it. Initialising
        # it here lets save() work on records that were not built from a model.
        self._model = None

        for field_name in self._FIELDS:
            setattr(self, field_name, kwargs[field_name])

        # An explicitly supplied value_hash takes precedence over the computed one.
        if 'value_hash' in kwargs:
            self.value_hash = kwargs['value_hash']

    @property
    def value(self):
        """Raw token value."""
        return self._value

    @value.setter
    def value(self, value):
        # Keep value_hash in sync with the value it is derived from.
        self._value = value
        self.value_hash = self.eval_value_hash(value)

    @staticmethod
    def eval_value_hash(value):
        """Return ``'<TOKEN_HASH_METHOD>:<md5 hexdigest>'`` of ``value``.

        Returns None when ``value`` is None. Text values are encoded as
        UTF-8 before hashing; byte strings are hashed as-is.
        """
        if value is None:
            return None
        if not isinstance(value, bytes):
            # hashlib only accepts bytes; ASCII tokens hash exactly as before.
            value = value.encode('utf-8')
        return str(TOKEN_HASH_METHOD) + ':' + hashlib.md5(value).hexdigest()

    def to_model(self):
        """Convert this record into a domain ``Token`` (``scope`` -> ``scopes``)."""
        constructor_dict = {name: getattr(self, name) for name in self._FIELDS}
        constructor_dict['scopes'] = constructor_dict.pop('scope')
        return Token(**constructor_dict)

    @classmethod
    def from_model(cls, model):
        """Build a record from a domain ``Token`` and remember that model.

        The model's ``scopes`` attribute is read into the record's ``scope``.
        """
        constructor_dict = {
            name: getattr(model, 'scopes' if name == 'scope' else name)
            for name in cls._FIELDS
        }
        record = cls(**constructor_dict)
        record._model = model
        return record

    def _serialized_columns(self):
        """Return the serialized column values written by both INSERT and UPDATE."""
        return dict(
            value=self.value,
            value_hash=self.value_hash,
            secret=serialize_secret(self.secret),
            scope=serialize_scopes(self.scope),
            expired=serialize_datetime(self.expired),
            profile_id=self.profile_id,
            # Hack required by get_token_newest: "created" is rewritten on every
            # save. The proper fix would be a separate "updated" token
            # property, but this is good enough for now.
            created=serialize_datetime(self.created),
            verified=serialize_datetime(self.verified),
            confirmed=serialize_datetime(self.confirmed),
        )

    def save(self, db):
        """Persist the record through connection ``db``.

        When ``token_id`` is None, the row is written via
        ``insert_with_on_duplicate_key_update`` (presumably updating the
        listed columns on a key conflict — TODO confirm against that helper).
        ``token_id`` and the originating model, if any, then receive the
        statement's ``lastrowid``. Otherwise the row with this ``token_id`` is
        updated.

        Raises AssertionError if ``uid`` is None.
        """
        assert self.uid is not None

        columns = self._serialized_columns()
        if self.token_id is None:
            insert_values = dict(columns, uid=self.uid, application_id=self.application_id)
            query = (
                insert_with_on_duplicate_key_update(
                    token_table,
                    {
                        'value',
                        'value_hash',
                        'secret',
                        'scope',
                        'expired',
                        'verified',
                        'confirmed',
                        # A stale token with another profile_id but the same
                        # uid, app and value may remain; to survive races the
                        # profile_id is updated too, so the write does not fail.
                        'profile_id',
                        # Hack required by get_token_newest; the proper fix
                        # would be a separate "updated" token property, but
                        # this is good enough for now.
                        'created',
                    },
                )
                .values(**insert_values)
            )
            query_result = execute(db, query)
            self.token_id = query_result.lastrowid
            if self._model is not None:
                self._model.token_id = self.token_id
        else:
            query = (
                token_table.update()
                .values(**columns)
                .where(token_table.c.token_id == self.token_id)
            )
            execute(db, query)

    @classmethod
    def find_by_value_for_account(cls, uid, application_id, value, db):
        """Return records of ``uid``/``application_id`` whose token equals ``value``.

        Matches on ``value_hash`` when ``social_config.find_token_by_value_hash``
        is set, otherwise on the raw ``value`` column.
        """
        if social_config.find_token_by_value_hash:
            value_criterion = token_table.c.value_hash == cls.eval_value_hash(value)
        else:
            value_criterion = token_table.c.value == value
        query = token_table.select().where(
            sql.and_(
                token_table.c.uid == uid,
                token_table.c.application_id == application_id,
                value_criterion,
            ),
        )
        return cls._select_query_to_token_records(query, db)

    @classmethod
    def find_by_token_id(cls, token_id, db):
        """Return a list with the record having ``token_id`` (empty if absent)."""
        query = token_table.select().where(token_table.c.token_id == token_id)
        return cls._select_query_to_token_records(query, db)

    @staticmethod
    def _and_application_ids(criteria, application_ids):
        """Restrict ``criteria`` to ``application_ids`` unless it is empty/None."""
        if application_ids:
            criteria = sql.and_(criteria, token_table.c.application_id.in_(application_ids))
        return criteria

    @classmethod
    def find_all_for_profile(cls, profile_id, application_ids, db):
        """Return tokens of ``profile_id``, newest ``token_id`` first.

        An empty/None ``application_ids`` means "all applications". At most
        ``get_qlimits()['tokens']`` records are returned.
        """
        criteria = cls._and_application_ids(
            token_table.c.profile_id == profile_id,
            application_ids,
        )
        query = (
            token_table
            .select()
            .where(criteria)
            .order_by(sql.desc(token_table.c.token_id))
            .limit(get_qlimits()['tokens'])
        )
        return cls._select_query_to_token_records(query, db)

    @classmethod
    def find_all_for_account(cls, uid, application_ids, db):
        """Return tokens of account ``uid``, newest ``token_id`` first.

        An empty/None ``application_ids`` means "all applications". At most
        ``get_qlimits()['tokens']`` records are returned.
        """
        criteria = cls._and_application_ids(token_table.c.uid == uid, application_ids)
        query = (
            token_table
            .select().with_only_columns(token_table.c)
            .where(criteria)
            .order_by(sql.desc(token_table.c.token_id))
            .limit(get_qlimits()['tokens'])
        )
        return cls._select_query_to_token_records(query, db)

    @classmethod
    def find_all_for_profile_ids(cls, profile_ids, db):
        """Return tokens belonging to any of ``profile_ids``, newest first.

        Returns an empty list without querying when ``profile_ids`` is empty.
        At most ``get_qlimits()['tokens']`` records are returned.
        """
        if not profile_ids:
            return []
        query = (
            token_table
            .select()
            .where(token_table.c.profile_id.in_(profile_ids))
            .order_by(sql.desc(token_table.c.token_id))
            .limit(get_qlimits()['tokens'])
        )
        return cls._select_query_to_token_records(query, db)

    @classmethod
    def delete_by_token_ids(cls, token_ids, db):
        """Delete the given tokens and their refresh tokens.

        Returns the number of deleted ``token_table`` rows (0 without touching
        the DB when ``token_ids`` is empty).
        """
        # Materialise once: the ids are used by two queries, and a generator
        # would otherwise be exhausted by the first one.
        token_ids = list(token_ids)
        if not token_ids:
            return 0

        query = refresh_token_table.delete().where(refresh_token_table.c.token_id.in_(token_ids))
        execute(db, query)

        query = token_table.delete().where(token_table.c.token_id.in_(token_ids))
        deletion_result = execute(db, query)
        return deletion_result.rowcount

    @classmethod
    def delete_all_for_account(cls, uid, application_ids, read_conn, write_conn):
        """Delete tokens of ``uid`` for ``application_ids``; return the count.

        ``application_ids=None`` deletes tokens of all applications (limited by
        the query limit of find_all_for_account). An empty collection deletes
        nothing.
        """
        if application_ids is not None and not application_ids:
            # find_all_for_account returns tokens of every application when
            # application_ids is empty, so guard against accidentally wiping
            # all of them when the caller passed an empty list/tuple/set.
            return 0
        token_records = cls.find_all_for_account(uid, application_ids, db=read_conn)
        token_ids = [t.token_id for t in token_records]
        return cls.delete_by_token_ids(token_ids, db=write_conn)

    @classmethod
    def delete_all_for_profile_ids(cls, profile_ids, read_conn, write_conn):
        """Delete tokens belonging to ``profile_ids``; return the deleted count."""
        token_records = cls.find_all_for_profile_ids(profile_ids, read_conn)
        token_ids = [t.token_id for t in token_records]
        return cls.delete_by_token_ids(token_ids, db=write_conn)

    @classmethod
    def _select_query_to_token_records(cls, query, db):
        """Execute ``query`` and convert every resulting row into a record."""
        return [
            cls(
                token_id=row['token_id'],
                uid=row['uid'],
                profile_id=row['profile_id'],
                application_id=row['application_id'],
                value=row['value'],
                secret=deserialize_secret(row['secret']),
                scope=deserialize_scopes(row['scope']),
                expired=row['expired'],
                created=row['created'],
                verified=row['verified'],
                confirmed=row['confirmed'],
            )
            for row in execute(db, query).fetchall()
        ]
