# -*- coding: utf-8 -*-

from __future__ import unicode_literals

from collections import namedtuple

from passport.backend.core.db.query import join_queries
from passport.backend.core.serializers.base import serialize
from passport.backend.social.common.db.execute import execute
from passport.backend.social.common.db.utils import get_master_engine
from passport.backend.social.common.serialize import TYPE_NAME_TO_DATABASE_SERIALIZER


class ModelState(namedtuple('ModelState', ['snapshot', 'model'])):
    """Persistence state of a single tracked model.

    Three shapes are produced by the alternate constructors below:

    * created: ``snapshot`` is ``None``, ``model`` is the live object;
    * changed: ``snapshot`` is ``model.snapshot()``, ``model`` is the live
      object;
    * deleted: ``snapshot`` is ``model.snapshot()``, ``model`` is ``None``.

    The ``is_*`` predicates compare against ``None`` explicitly, so a falsy
    snapshot (e.g. an empty container) or a model that defines
    ``__len__``/``__bool__`` is still classified correctly. They always
    return a real ``bool``.
    """

    @classmethod
    def created(cls, model):
        """Return the state of a model that has not been stored yet."""
        return cls(snapshot=None, model=model)

    @property
    def is_created(self):
        """True when there is a live model but no stored snapshot."""
        return self.snapshot is None and self.model is not None

    @classmethod
    def deleted(cls, model):
        """Return the state of a model scheduled for deletion.

        Only the model's snapshot is kept; the model reference is dropped.
        """
        return cls(snapshot=model.snapshot(), model=None)

    @property
    def is_deleted(self):
        """True when there is a snapshot but no live model."""
        return self.snapshot is not None and self.model is None

    @classmethod
    def changed(cls, model):
        """Return the state of a stored model, snapshotting it right now."""
        return cls(snapshot=model.snapshot(), model=model)

    @property
    def is_changed(self):
        """True when both a snapshot and a live model are present."""
        return self.snapshot is not None and self.model is not None


class _BaseSession(object):
    """Track model objects and flush their changes through ``_serialize``.

    Tracked models are keyed by ``id(model)`` and mapped to a ``ModelState``
    (created, changed or deleted). ``commit()`` hands each pending change to
    ``_serialize(old, new)``, which subclasses must implement.
    """

    def __init__(self):
        # id(model) -> ModelState
        self._model_id_to_state = dict()

    def add(self, model):
        """Register a new object that should be inserted into the DB.

        Raises:
            AssertionError: if the model is already tracked by this session.
        """
        model_id = id(model)
        if model_id in self._model_id_to_state:
            # Explicit raise (not ``assert``) so the check survives ``-O``.
            raise AssertionError('Model is already tracked by the session')
        self._model_id_to_state[model_id] = ModelState.created(model)

    def add_committed(self, model):
        """Register an object that is assumed to already be in the DB.

        Only the changes made after this call are serialized on commit.

        Raises:
            AssertionError: if the model is already tracked by this session.
        """
        model_id = id(model)
        if model_id in self._model_id_to_state:
            raise AssertionError('Model is already tracked by the session')
        self._model_id_to_state[model_id] = ModelState.changed(model)

    def remove(self, model):
        """Schedule a stored (committed) object for deletion from the DB.

        Raises:
            AssertionError: if the model is not tracked, or is tracked but is
                not in the "changed" (stored) state.
        """
        model_id = id(model)
        state = self._model_id_to_state.get(model_id)
        if state is None or not state.is_changed:
            raise AssertionError(
                'Only a stored model tracked by the session can be removed'
            )
        # NOTE(review): ModelState.deleted() keeps only the snapshot, not the
        # model. If the caller drops the model before commit(), its id() may
        # be reused by another object. Confirm callers keep it alive.
        self._model_id_to_state[model_id] = ModelState.deleted(model)

    def commit(self):
        """Flush every tracked model through ``_serialize``.

        * created: serialized as ``(None, model)``, then tracked as changed
          with a fresh snapshot;
        * changed: serialized as ``(snapshot, model)``; if ``_serialize``
          returns a truthy value the snapshot is refreshed;
        * deleted: serialized as ``(snapshot, None)`` and no longer tracked.

        Raises:
            AssertionError: if a tracked state matches none of the above.
        """
        # Iterate over a copy of the keys: deleted entries are removed from
        # the dict inside the loop, which raises RuntimeError on a live view.
        for model_id in list(self._model_id_to_state):
            state = self._model_id_to_state[model_id]
            if state.is_created:
                self._serialize(None, state.model)
                self._model_id_to_state[model_id] = ModelState.changed(state.model)
            elif state.is_changed:
                if self._serialize(state.snapshot, state.model):
                    self._model_id_to_state[model_id] = ModelState.changed(state.model)
            elif state.is_deleted:
                self._serialize(state.snapshot, None)
                del self._model_id_to_state[model_id]
            else:
                raise AssertionError('Unexpected model state: %r' % (state,))

    def _serialize(self, old, new):
        """Persist the difference between ``old`` and ``new``.

        ``old`` is ``None`` for newly created models; ``new`` is ``None`` for
        deleted ones. For changed models, ``commit()`` refreshes the snapshot
        only when this returns a truthy value. Must be overridden.
        """
        raise NotImplementedError()  # pragma: no cover


class Session(_BaseSession):
    """Session that writes tracked model changes to the database.

    Queries are executed on ``write_conn``; when it is omitted, the engine
    returned by ``get_master_engine()`` is used instead.
    """

    def __init__(self, write_conn=None):
        super(Session, self).__init__()
        self._write_conn = get_master_engine() if write_conn is None else write_conn

    def _serialize(self, old, new):
        return self._serialize_to_database(old, new)

    def _serialize_to_database(self, old, new):
        """Execute the queries describing the change from ``old`` to ``new``.

        Each ``(query, callback)`` pair yielded by ``join_queries`` is executed
        on the write connection. When a callback is present, it receives that
        query's execution result.

        Returns:
            True if at least one query was executed, otherwise False.
        """
        query_bits = serialize(old, new, None, TYPE_NAME_TO_DATABASE_SERIALIZER)
        executed_any = False
        for query, on_result in join_queries(query_bits):
            query_result = execute(self._write_conn, query.to_query())
            executed_any = True
            if on_result:
                on_result(query_result)
        return executed_any
