# -*- coding: utf-8 -*-
import typing
import itertools
from gino import Gino
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.result import RowProxy
from crm.agency_cabinet.common.utils.chunk import chuncker
from .base import db, get_or_create_from_config

# ``.base`` may export ``db`` as ``None``; in that case obtain the instance
# from ``get_or_create_from_config()`` so that ``db.Model`` / ``db.Column``
# below are usable while this module is being imported.
if db is None:
    db: Gino = get_or_create_from_config()


class BaseModel(db.Model):
    """Base class for models declared on ``db``.

    Provides ``created_at`` / ``updated_at`` timestamp columns and class-level
    helpers for chunked bulk inserts and get/update-or-create lookups.
    """

    # Both timestamps are populated by the database via ``now()``;
    # ``updated_at`` is additionally refreshed through ``onupdate``.
    created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), nullable=False, onupdate=func.now())

    @classmethod
    async def bulk_insert(cls, items: typing.List[dict], step=5000, return_list=True) -> typing.Iterable[RowProxy]:
        """Insert ``items`` into the model's table in chunks of ``step``.

        Every chunk is sent as one ``INSERT ... ON CONFLICT DO NOTHING
        RETURNING`` statement built from ``cls.__table__``.

        :param items: row values, one dict per row.
        :param step: chunk size handed to ``chuncker``.
        :param return_list: when true a ``list`` is returned, otherwise a
            one-shot iterator over the same rows.
        :return: the rows produced by the ``RETURNING`` clauses of all chunks,
            in chunk order.
        """
        # TODO: try to fix inconsistent tuples https://groups.google.com/g/sqlalchemy/c/RTE4eHFK09w
        batches = []
        for chunk in chuncker(items, step):
            rows = tuple(chunk)
            if not rows:
                # An empty chunk has nothing to insert; skip the statement.
                continue
            statement = insert(cls.__table__).values(rows).on_conflict_do_nothing().returning(cls.__table__)
            batches.append(await statement.gino.all(read_only=False, reuse=False) or [])
        # Flatten once at the end instead of wrapping a new ``chain`` around
        # the previous one per chunk, which nests one level deeper each time.
        returned = itertools.chain.from_iterable(batches)
        return list(returned) if return_list else returned

    @classmethod
    async def _get(cls, filter_columns: typing.Dict[str, typing.Any]) -> typing.Optional['BaseModel']:
        """Return the first row matching every ``column == value`` pair.

        :param filter_columns: mapping of column name to the value it must equal;
            all conditions are combined with ``AND``.
        :return: the first matching model instance, or ``None`` if none matches.
        :raises KeyError: if a key is not a column of the model's table.
        """
        query = cls.query
        for key, value in filter_columns.items():
            col = cls.__table__.columns.get(key)
            if col is None:
                # Interpolate explicitly: exceptions do not apply %-formatting
                # to their arguments the way logging calls do.
                raise KeyError('Unknown column %s for table %s' % (key, cls.__tablename__))
            query = query.where(col == value)
        return await query.gino.first()

    @classmethod
    async def get_or_create(cls, filter_columns: typing.Dict[str, typing.Any], create_columns: typing.Dict[str, typing.Any]):
        """Return the row matching ``filter_columns``, creating one if absent.

        :param filter_columns: column/value pairs used for the lookup (see ``_get``).
        :param create_columns: values passed to ``cls.create`` when no row is
            found; ``filter_columns`` are NOT merged in automatically.
        :return: the existing or newly created model instance.
        :raises KeyError: if ``filter_columns`` names an unknown column.
        """
        model = await cls._get(filter_columns)
        if model is None:
            return await cls.create(**create_columns)
        return model

    @classmethod
    async def update_or_create(cls, filter_columns: typing.Dict[str, typing.Union[str, int]], update_columns: typing.Dict[str, typing.Union[str, int]]):
        """Update the row matching ``filter_columns`` or create a new one.

        :param filter_columns: column/value pairs used for the lookup (see ``_get``).
        :param update_columns: values applied to the found row; when no row is
            found they are the only values passed to ``cls.create``, so they
            must include everything the new row needs.
        :return: the updated or newly created model instance.
        :raises KeyError: if ``filter_columns`` names an unknown column.
        """
        model = await cls._get(filter_columns)
        if model is None:
            return await cls.create(**update_columns)
        await model.update(**update_columns).apply()
        return model


class YtSync(BaseModel):
    """Model mapped to the ``tools_yt_sync`` table.

    Rows are identified by ``id``; the ``(path, tag)`` pair is covered by a
    unique index.
    """

    __tablename__ = 'tools_yt_sync'

    id = db.Column(
        db.BigInteger,
        primary_key=True,
    )
    path = db.Column(
        db.Text,
    )
    tag = db.Column(
        db.Text,
        nullable=True,
    )
    # Timezone-aware timestamp; no default is declared for it.
    last_used_update = db.Column(
        db.DateTime(timezone=True),
    )
    # Defaults to true both client-side (``default``) and in the database
    # (``server_default``).
    enabled = db.Column(
        db.Boolean,
        default=True,
        server_default='t',
        nullable=True,
    )

    # NOTE(review): ``tag`` is nullable and PostgreSQL treats NULLs as
    # distinct in unique indexes, so rows with the same ``path`` and a NULL
    # ``tag`` are presumably not deduplicated by this index - confirm intent.
    __table_args__ = (
        db.Index(
            'tools_yt_sync__path_tag__idx',
            'path',
            'tag',
            unique=True,
        ),
    )


# Fixed-point numeric type: 18 significant digits in total, 6 of them after
# the decimal point. Presumably meant to be reused as a column type by other
# model modules - verify against importers.
NumericType = db.Numeric(precision=18, scale=6)
