# encoding: UTF-8

import typing as t

import sqlalchemy.orm as orm

from appcore.data.model import Page
from appcore.data.model import Pageable
from appcore.struct import maybe_list

# Type variable for the model class a Repository manages. The string name
# must match the variable it is bound to; type checkers reject a mismatch.
_T = t.TypeVar('_T')
_SessionFactory = t.Callable[[], orm.Session]  # zero-argument callable that yields a Session
_TType = t.Type[_T]  # the model class itself, not an instance of it


class Repository(t.Generic[_T]):
    """Generic SQLAlchemy-backed repository for a single mapped model.

    All queries are built against :attr:`model` on a session obtained from
    the ``session_factory`` given at construction.
    """

    def __init__(self, model, session_factory):
        # type: (_TType, _SessionFactory) -> None
        """Create a repository.

        :param model: the mapped model class this repository queries.
        :param session_factory: zero-argument callable returning the
            ``orm.Session`` to use; it is invoked on every access of
            :attr:`session`.
        """
        self.__model = model  # type: _TType
        self.__session_factory = session_factory  # type: _SessionFactory

    @property
    def model(self):
        # type: () -> _TType
        """The model class this repository queries."""
        return self.__model

    @property
    def session(self):
        # type: () -> orm.Session
        """Return the session produced by calling the session factory.

        The factory is called on every access. Whether that yields the same
        session each time depends on the factory (a ``scoped_session``
        registry does, a plain ``sessionmaker`` does not), so code that runs
        several operations on one session should read this property once.
        """
        return self.__session_factory()

    @property
    def query(self):
        # type: () -> orm.Query
        """Return a new ``Query`` for :attr:`model` on :attr:`session`."""
        return self.session.query(self.model)

    def _apply_filters(self, query, filters):
        # type: (orm.Query, t.Any) -> orm.Query
        """Return ``query`` narrowed by ``filters``; unchanged if ``None``.

        ``filters`` may be one criterion or several. ``maybe_list``
        presumably normalizes a single criterion into a list (verify in
        ``appcore.struct``); every element is passed to ``Query.filter``,
        which ANDs them together.
        """
        if filters is None:
            return query
        return query.filter(*maybe_list(filters))

    def _apply_sort(self, query, sort):
        # type: (orm.Query, t.Any) -> orm.Query
        """Return ``query`` ordered by ``sort``; unchanged if ``None``.

        ``sort`` may be one ORDER BY expression or several (normalized via
        ``maybe_list``, as for filters).
        """
        if sort is None:
            return query
        return query.order_by(*maybe_list(sort))

    def _apply_pageable(self, query, pageable):
        # type: (orm.Query, t.Optional[Pageable]) -> orm.Query
        """Return ``query`` restricted to the ``pageable`` window.

        Uses ``Query.slice(pageable.offset, pageable.end_offset)``; the
        query is returned unchanged when ``pageable`` is ``None``.
        """
        if pageable is None:
            return query
        return query.slice(pageable.offset, pageable.end_offset)

    def find_iter(self, filter=None, sort=None, yield_per=10000):
        # type: (t.Any, t.Any, int) -> t.Iterator[_T]
        """Iterate over matching entities, fetching rows in batches.

        :param filter: optional criterion or criteria (name kept for
            backward compatibility even though it shadows the builtin).
        :param sort: optional ORDER BY expression(s).
        :param yield_per: batch size passed to ``Query.yield_per``.
        :returns: an iterator over the query results.
        """
        query = self._apply_filters(self.query, filter)
        query = self._apply_sort(query, sort)  # type: orm.Query
        return iter(query.yield_per(yield_per))

    def find_paged(self, pageable, filter=None, sort=None):
        # type: (Pageable, t.Any, t.Any) -> Page[_T]
        """Return one page of matching entities plus the total match count.

        :param pageable: page window; its ``offset`` and ``size`` are copied
            into the returned ``Page``.
        :param filter: optional criterion or criteria.
        :param sort: optional ORDER BY expression(s); applied to the page
            query only.
        :returns: ``Page`` whose ``total`` counts every row matching
            ``filter``, independent of pagination.
        """
        # The count is taken from the filtered query before sorting and
        # slicing, so it reflects all matches rather than just this page.
        count_query = self._apply_filters(self.query, filter)
        page_query = self._apply_pageable(
            self._apply_sort(count_query, sort), pageable)

        return Page(
            items=page_query.all(),
            offset=pageable.offset,
            size=pageable.size,
            total=count_query.count(),
        )

    def find_one(self, filter=None, strongly=True):
        # type: (t.Any, bool) -> t.Optional[_T]
        """Return the single entity matching ``filter``.

        :param strongly: when true, use ``Query.one()``, which raises if
            there is no match or more than one. When false, use
            ``Query.one_or_none()``, which returns ``None`` for no match
            (and still raises for multiple matches).
        """
        query = self._apply_filters(self.query, filter)
        return query.one() if strongly else query.one_or_none()

    def find_first(self, filter=None, sort=None):
        # type: (t.Any, t.Any) -> t.Optional[_T]
        """Return the first matching entity under ``sort``, or ``None``."""
        query = self._apply_filters(self.query, filter)
        query = self._apply_sort(query, sort)
        return query.first()

    def count(self, filter=None):
        # type: (t.Any) -> int
        """Return the number of entities matching ``filter``."""
        return self._apply_filters(self.query, filter).count()

    def get(self, id):
        # type: (t.Any) -> t.Optional[_T]
        """Return the entity with primary key ``id``, or ``None``.

        NOTE(review): ``Query.get`` is deprecated since SQLAlchemy 1.4 in
        favour of ``Session.get``; kept here for compatibility.
        """
        return self.query.get(id)

    def save(self, entity, flush=False):
        # type: (_T, bool) -> _T
        """Merge ``entity`` into the session and commit.

        :param flush: when true, flush explicitly before committing.
        :returns: the instance returned by ``Session.merge`` (the
            session-attached copy, which may differ from ``entity``).
        """
        # Resolve the session once: each read of ``self.session`` calls the
        # factory again, and with a non-scoped factory the merge, flush and
        # commit would otherwise land on different sessions and the merged
        # state would never be committed.
        session = self.session
        merged = session.merge(entity)

        if flush:
            session.flush()

        session.commit()
        return merged

    def delete(self, entity):
        # type: (_T) -> None
        """Mark ``entity`` for deletion and commit."""
        # Same session for delete and commit (see ``save``).
        session = self.session
        session.delete(entity)
        session.commit()
