from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, List, Optional, Tuple, Type, TypeVar

from sqlalchemy import func
from sqlalchemy.orm import Session

import travel.avia.subscriptions.app.model.db as db_models
from travel.avia.subscriptions.app.model.db import Base, DBInstanceMixin

# Type variable for the model class a Storage works with; restricted to subclasses of ``Base``.
TBase = TypeVar("TBase", bound=Base)


class UpsertAction(Enum):
    """Which branch an ``upsert`` call took.

    ``UPDATE``: an existing object matched and its fields were overwritten.
    ``INSERT``: nothing matched, so a new object was created.
    """

    UPDATE = 1  # matched an existing object and modified it
    INSERT = 2  # no match; a fresh object was created


class Storage(ABC, Generic[TBase]):
    """Abstract storage interface over a single model type ``TBase``.

    Lookup criteria are passed as keyword arguments naming model attributes.
    Concrete implementations must override every method below.
    """

    @abstractmethod
    def get(self, **kwargs) -> Optional[TBase]:
        """Return one stored object matching ``kwargs``, or ``None`` if none matches."""

    @abstractmethod
    def find(self, **kwargs) -> List[TBase]:
        """Return all stored objects matching ``kwargs`` (possibly an empty list)."""

    @abstractmethod
    def create(self, **kwargs) -> TBase:
        """Create and store a new object built from ``kwargs``, and return it."""

    @abstractmethod
    def get_or_create(self, **kwargs) -> TBase:
        """Return the object matching ``kwargs``, creating it from ``kwargs`` if absent."""

    @abstractmethod
    def upsert(
        self, where: dict, values: Optional[dict] = None
    ) -> Tuple[UpsertAction, TBase]:
        """Update the object matching ``where`` with ``values``, or insert a new one.

        :param where: lookup criteria identifying the object.
        :param values: attributes to set; ``None`` means no extra attributes.
        :return: ``(action, obj)`` where ``action`` reports whether an existing
            object was updated or a new one was inserted.
        """


def DatabaseStorage(storable_class: Type[DBInstanceMixin]):
    """Build a session-backed :class:`Storage` class for ``storable_class``.

    The returned value is a class, not an instance; instantiate it with a
    SQLAlchemy ``Session`` and every query goes through that session. When
    ``storable_class`` is ``db_models.Email`` or a subclass of it, the returned
    class's ``get`` matches the ``email`` column case-insensitively.

    :param storable_class: mapped model class the storage reads and writes.
    :return: a ``Storage`` subclass bound to ``storable_class``.
    """

    class DBS(Storage):
        """Generic storage for ``storable_class`` backed by a single session."""

        def __init__(self, sess: Session):
            self.session: Session = sess

        def get(self, **kwargs) -> Optional[storable_class]:
            """Return the first row whose attributes equal ``kwargs``, or ``None``."""
            return self.session.query(storable_class).filter_by(**kwargs).first()

        def find(self, **kwargs) -> List[storable_class]:
            """Return every row whose attributes equal ``kwargs``."""
            return self.session.query(storable_class).filter_by(**kwargs).all()

        def create(self, **kwargs) -> storable_class:
            """Instantiate ``storable_class(**kwargs)``, add it to the session and flush it.

            Only this object is flushed, so its INSERT is emitted right away
            (presumably so DB-generated fields such as primary keys are
            available to the caller -- confirm) without committing.
            """
            value = storable_class(**kwargs)
            self.session.add(value)
            self.session.flush((value,))
            return value

        def get_or_create(self, **kwargs) -> storable_class:
            """Return the row matching ``kwargs``; create it from ``kwargs`` if absent."""
            existing = self.get(**kwargs)
            # Explicit None check: a bare ``or`` would also call ``create`` when a
            # found object happens to be falsy (e.g. a model defining __bool__/__len__).
            if existing is not None:
                return existing
            return self.create(**kwargs)

        def upsert(
            self, where: dict, values: Optional[dict] = None
        ) -> Tuple[UpsertAction, storable_class]:
            """Update the row matching ``where`` with ``values``, or insert a new row.

            :param where: lookup criteria passed to :meth:`get`.
            :param values: attributes to set; ``None`` is treated as ``{}``.
            :return: ``(UpsertAction.INSERT, new_row)`` when nothing matched,
                otherwise ``(UpsertAction.UPDATE, existing_row)``.
            """
            values = values or {}
            where_object = self.get(**where)
            if where_object is None:
                # On insert, keys in ``values`` override the same keys in ``where``.
                return UpsertAction.INSERT, self.create(**{**where, **values})

            self.session.add(where_object)
            for k, v in values.items():
                setattr(where_object, k, v)
            # NOTE: unlike the insert branch, no explicit flush happens here; the
            # changes are pending in the session until its next flush/commit.
            return UpsertAction.UPDATE, where_object

    class EmailDBS(DBS):
        """Storage variant whose ``get`` compares ``email`` case-insensitively."""

        def get(self, **kwargs) -> Optional[storable_class]:
            """Return the first matching row, comparing ``email`` via SQL ``lower()``.

            Other keyword criteria are matched exactly, as in ``DBS.get``.
            NOTE(review): the inherited ``find`` still matches ``email`` exactly
            (case-sensitively) -- confirm that asymmetry is intended.
            """
            query = self.session.query(storable_class)
            if 'email' in kwargs:
                email = kwargs.pop('email')
                query = query.filter(func.lower(db_models.Email.email) == func.lower(email))
            if kwargs:
                query = query.filter_by(**kwargs)
            return query.first()

    if issubclass(storable_class, db_models.Email):
        return EmailDBS
    return DBS
