from datetime import datetime
from sqlalchemy.dialects.postgresql import JSONB

from crm.agency_cabinet.common.db.models import BaseModel, db


class CalculatorData(BaseModel):
    """One row of calculator data for a contract.

    ``(contract_id, service, version)`` is declared unique and ``contract_id``
    references ``contract.id``.
    """
    __tablename__ = 'calculator_data'

    id = db.Column(db.BigInteger, primary_key=True)
    # Owning contract; the foreign key to contract.id is declared in __table_args__.
    contract_id = db.Column(db.BigInteger, nullable=False)
    # Service name the calculation belongs to (plain string column).
    service = db.Column(db.String, nullable=False)
    # Nullable JSONB payload. CalculatorDataUpdater in this module reads and writes it as
    # {'months': [...]} — presumably the only shape in use; verify against other writers.
    data = db.Column(JSONB)
    # Optional version label; part of the unique key below.
    version = db.Column(db.String, nullable=True)

    __table_args__ = (
        db.ForeignKeyConstraint(['contract_id'], ['contract.id']),
        # NOTE(review): `version` is nullable and PostgreSQL treats NULLs as distinct in unique
        # constraints by default, so several rows with the same contract_id/service and
        # version IS NULL are not rejected by this constraint — confirm that is intended.
        db.UniqueConstraint('contract_id', 'service', 'version', name='calculator_data__contract_id_service_version__uc')
    )


class CalculatorDataUpdater:
    """Merge month-level updates into the ``data`` JSONB of a ``CalculatorData`` row.

    The stored payload is read and written as ``{'months': [month, ...]}``. Each month
    is a dict identified and ordered by its ``period_from`` value. When two versions
    of the same month are merged (see ``_update_month_data``), the keys ``predict``,
    ``indexes`` (dicts keyed by ``index_id``) and ``grades`` (dicts keyed by
    ``grade_id``) are also used.

    Each write replaces the whole ``data`` column with ``{'months': ...}`` and sets
    ``updated_at``. After a successful write, the wrapped row's ``data`` attribute is
    set to the written value, so later calls on the same updater see it. The loaded
    row is not modified before the write succeeds.
    """

    def __init__(self, data: 'CalculatorData'):
        # String annotation: resolved lazily, so the model need not exist at class-definition time.
        self.data = data

    async def update_data(self, new_data: dict, prefer_actual=True, merge_data=True):
        """Merge ``new_data['months']`` into the stored months and persist the result.

        Months whose ``period_from`` is already stored are combined with
        ``_update_month_data``. Unknown periods are added. The result is sorted by
        ``period_from``.

        :param new_data: payload containing a ``'months'`` list of month dicts.
        :param prefer_actual: when True, an actual month (``predict`` is False) replaces
            a predicted one and is never replaced by a predicted one.
        :param merge_data: when False, a clashing month is replaced wholesale instead
            of having its ``indexes``/``grades`` merged.
        :raises KeyError: if ``new_data`` has no ``'months'`` or a month lacks a
            required key (``period_from``, or ``predict`` when ``prefer_actual``).
        """
        months = self._merge_months(self._current_months(), new_data['months'],
                                    prefer_actual=prefer_actual, merge_data=merge_data)
        await self._save_months(months)

    async def add_or_override_month_data(self, new_month_data: dict):
        """Replace the stored month with the same ``period_from``, or add it, then persist.

        A replaced month keeps its list position. A newly added month triggers a
        re-sort of the list by ``period_from``.
        """
        months = self._upsert_month(self._current_months(), new_month_data)
        await self._save_months(months)

    def _current_months(self) -> list:
        """Return a shallow copy of the stored months.

        Returns an empty list when ``data`` is NULL or has no ``'months'``. Working on
        a copy keeps the loaded row untouched until the write succeeds.
        """
        return list((self.data.data or {}).get('months') or [])

    async def _save_months(self, months: list) -> None:
        """Overwrite the row's ``data`` with ``{'months': months}``, set ``updated_at``, and mirror it in memory."""
        payload = {'months': months}
        await CalculatorData.update.values(
            data=payload,
            updated_at=datetime.now()
        ).where(CalculatorData.id == self.data.id).gino.status(read_only=False, reuse=False)
        # Only reached when the UPDATE did not raise, so memory never runs ahead of the DB.
        self.data.data = payload

    @staticmethod
    def _merge_months(months: list, incoming_months: list, prefer_actual=True, merge_data=True) -> list:
        """Return a new list of ``months`` with ``incoming_months`` merged in, sorted by ``period_from``.

        If ``incoming_months`` repeats a ``period_from``, its last occurrence wins.
        Neither input list nor any month dict in it is modified.
        """
        merged = list(months)
        # If the stored list repeats a period, the last occurrence is the one merged into.
        position_by_period = {month['period_from']: position for position, month in enumerate(merged)}
        incoming_by_period = {month['period_from']: month for month in incoming_months}

        for period_from, new_month in incoming_by_period.items():
            position = position_by_period.get(period_from)
            if position is None:
                merged.append(new_month)
            else:
                merged[position] = CalculatorDataUpdater._update_month_data(
                    merged[position], new_month, prefer_actual=prefer_actual, merge_data=merge_data)

        merged.sort(key=lambda month: month['period_from'])
        return merged

    @staticmethod
    def _upsert_month(months: list, new_month_data: dict) -> list:
        """Return a copy of ``months`` with ``new_month_data`` replacing its period's entry, or added.

        A new entry is appended and the list is re-sorted by ``period_from``.
        A replaced entry keeps its position.
        """
        result = list(months)
        period_from = new_month_data['period_from']
        position = next((i for i, month in enumerate(result) if month['period_from'] == period_from), None)

        # Compare with None explicitly: a match at index 0 is falsy but still a match.
        if position is not None:
            result[position] = new_month_data
        else:
            result.append(new_month_data)
            result.sort(key=lambda month: month['period_from'])
        return result

    @staticmethod
    def _update_month_data(old_month_data, new_month_data, prefer_actual=True, merge_data=True):
        """Combine two versions of the same month and return the result.

        With ``prefer_actual``, an actual month (``predict`` is False) beats a predicted
        one (``predict`` is True), whichever side it is on; that month is returned as-is.

        Otherwise:
        - With ``merge_data`` False, ``new_month_data`` is returned.
        - With ``merge_data`` True, a copy of ``old_month_data`` is returned whose
          ``indexes`` and ``grades`` are merged with the new ones, new entries winning
          per id. A missing or None ``indexes``/``grades`` counts as empty. Other keys
          of ``new_month_data`` are not carried over.

        The input dicts are never modified.
        """
        if prefer_actual is True and new_month_data['predict'] is False and old_month_data['predict'] is True:
            return new_month_data
        elif prefer_actual is True and new_month_data['predict'] is True and old_month_data['predict'] is False:
            return old_month_data

        if merge_data is False:
            return new_month_data

        # Copy so the caller's (stored) month dict is not mutated in place.
        merged_month_data = dict(old_month_data)
        merged_month_data['indexes'] = CalculatorDataUpdater._update_indexes(
            old_month_data.get('indexes'), new_month_data.get('indexes'))
        merged_month_data['grades'] = CalculatorDataUpdater._update_grades(
            old_month_data.get('grades'), new_month_data.get('grades'))
        return merged_month_data

    @staticmethod
    def _update_indexes(old_month_data, new_month_data):
        """Merge two lists of index dicts by ``index_id``; new entries win; sorted by ``index_id``.

        Despite their names, both parameters are lists of index dicts (or None), not months.
        """
        return CalculatorDataUpdater._merge_by_key(old_month_data, new_month_data, 'index_id')

    @staticmethod
    def _update_grades(old_grades, new_grades):
        """Merge two lists of grade dicts by ``grade_id``; new entries win; sorted by ``grade_id``."""
        return CalculatorDataUpdater._merge_by_key(old_grades, new_grades, 'grade_id')

    @staticmethod
    def _merge_by_key(old_items, new_items, key):
        """Return the union of two lists of dicts keyed by ``key``, sorted by ``key``.

        For equal keys the entry from ``new_items`` wins. A None list is treated as empty.
        """
        merged = {item[key]: item for item in old_items or []}
        merged.update((item[key], item) for item in new_items or [])
        return sorted(merged.values(), key=lambda item: item[key])
