import datetime
from sqlalchemy import literal, and_, case, text

from crm.agency_cabinet.agencies.server.src.db.models import AgencyAnalytics
from crm.agency_cabinet.agencies.server.src.db import db
from crm.agency_cabinet.common.consts import AverageCheckBucket


def make_base_avg_check_aggregation(month_start: datetime.date, month_end: datetime.date):
    """Build a per-(agency, client) average monthly check query.

    For every ``(agency_id, client_id)`` pair that has rows with
    ``month_start <= month < month_end``, ``avg_check`` is the summed
    ``all_money`` divided by the number of distinct months the pair appears
    in, then divided by 1000 and truncated.

    :param month_start: inclusive lower bound on ``AgencyAnalytics.month``.
    :param month_end: exclusive upper bound on ``AgencyAnalytics.month``.
    :return: select with columns ``client_id``, ``agency_id``, ``avg_check``,
        one row per (agency, client) pair.
    """
    in_period = and_(
        AgencyAnalytics.month >= month_start,
        AgencyAnalytics.month < month_end
    )

    # multiplying by literal(1.0) presumably guards against integer division
    # if all_money is an integer column
    total_money = db.func.sum(AgencyAnalytics.all_money) * literal(1.0)
    active_months = db.func.count(db.func.distinct(AgencyAnalytics.month))
    avg_check = db.func.trunc(total_money / active_months / literal(1000.0))

    columns = [
        AgencyAnalytics.client_id.label('client_id'),
        AgencyAnalytics.agency_id.label('agency_id'),
        avg_check.label('avg_check'),
    ]
    query = db.select(columns).select_from(AgencyAnalytics).where(in_period)
    return query.group_by(AgencyAnalytics.agency_id, AgencyAnalytics.client_id)


def make_current_vs_other_median_check_aggregation(agency_id: int, month_start: datetime.date, month_end: datetime.date):
    """Compare the median average check of one agency against all other agencies.

    The per-(agency, client) rows from ``make_base_avg_check_aggregation`` are
    split by ``t`` into ``'current'`` (rows of ``agency_id``) and ``'other'``
    (everything else). For each group, ``median_check`` is
    ``percentile_cont(0.5)`` over ``avg_check``.

    :param agency_id: agency whose clients form the ``'current'`` group.
    :param month_start: inclusive period start passed to the base aggregation.
    :param month_end: exclusive period end passed to the base aggregation.
    :return: select with columns ``t`` and ``median_check``, grouped by ``t``.
    """
    base_agg = make_base_avg_check_aggregation(month_start, month_end).cte('base_agg')

    is_current = base_agg.c.agency_id == agency_id
    group_kind = case([(is_current, literal('current'))], else_=literal('other'))
    median = db.func.percentile_cont(0.5).within_group(base_agg.c.avg_check)

    selected = [group_kind.label('t'), median.label('median_check')]
    return db.select(selected).select_from(base_agg).group_by('t')


def make_avg_check_agency_distribution_aggregation(
    agency_id: int,
    month_start: datetime.date,
    month_end: datetime.date,
    count_other: bool = False
):
    """Count clients per average-check bucket for one agency or for all others.

    Each (agency, client) row of the base aggregation is assigned a bucket:
    ``level_1`` if ``avg_check < level_1.border``, else ``level_2`` if
    ``avg_check < level_2.border``, else ``level_3``.

    :param agency_id: agency to include (or, with ``count_other``, to exclude).
    :param month_start: inclusive period start passed to the base aggregation.
    :param month_end: exclusive period end passed to the base aggregation.
    :param count_other: if true, count rows of every agency except
        ``agency_id``; otherwise count only rows of ``agency_id``.
    :return: select with columns ``check_ctg`` and ``count``, grouped by bucket.
    """
    base_agg = make_base_avg_check_aggregation(month_start, month_end).cte('base_agg')
    avg_check = base_agg.c.avg_check

    # CASE branches are evaluated in order, so the first border that matches wins
    bucket = case(
        [
            (avg_check < AverageCheckBucket.level_1.border, literal(AverageCheckBucket.level_1.value)),
            (avg_check < AverageCheckBucket.level_2.border, literal(AverageCheckBucket.level_2.value)),
        ],
        else_=literal(AverageCheckBucket.level_3.value)
    )

    if count_other:
        agency_filter = base_agg.c.agency_id != agency_id
    else:
        agency_filter = base_agg.c.agency_id == agency_id

    selected = [
        bucket.label('check_ctg'),
        db.func.count(base_agg.c.client_id).label('count'),
    ]
    return (
        db.select(selected)
        .select_from(base_agg)
        .where(agency_filter)
        .group_by('check_ctg')
    )


def make_median_check_distribution_aggregation(month_start: datetime.date, month_end: datetime.date, agency_id):
    """Compute the median average check and client count for every agency.

    Produces one row per agency from the base (agency, client) aggregation:
    ``agency_id``; ``t``, which is ``'current'`` for ``agency_id`` and
    ``'other'`` for the rest; ``median_check``, which is
    ``percentile_cont(0.5)`` over ``avg_check``; and ``cnt_clients``, the
    number of base rows (clients) of that agency.

    :param month_start: inclusive period start passed to the base aggregation.
    :param month_end: exclusive period end passed to the base aggregation.
    :param agency_id: agency tagged as ``'current'``.
    :return: select grouped by ``agency_id`` and ``t``.
    """
    base_agg = make_base_avg_check_aggregation(month_start, month_end).cte('base_agg')
    cols = base_agg.c

    kind = case([(cols.agency_id == agency_id, literal('current'))], else_=literal('other'))
    selected = [
        cols.agency_id,
        kind.label('t'),
        db.func.percentile_cont(0.5).within_group(cols.avg_check).label('median_check'),
        db.func.count(cols.client_id).label('cnt_clients'),
    ]

    query = db.select(selected).select_from(base_agg)
    return query.group_by(cols.agency_id, 't')


def make_clients_info_aggregation(
    agency_id: int,
    month_start: datetime.date,
    month_end: datetime.date,
    get_only_current: bool = False,
    sort_by_new: bool = False
):
    """Count distinct clients (all and new) per agency within ``[month_start, month_end)``.

    Each result row contains:

    * ``agency_id``;
    * ``t``: ``'current'`` for ``agency_id``, ``'other'`` otherwise;
    * ``cnt_clients``: number of distinct ``client_id`` values;
    * ``cnt_new_clients``: number of distinct ``client_id`` values among rows
      with ``epoch = 0``.

    :param agency_id: agency tagged as ``'current'``.
    :param month_start: inclusive lower bound on ``AgencyAnalytics.month``.
    :param month_end: exclusive lower bound on ``AgencyAnalytics.month``.
    :param get_only_current: restrict the result to rows of ``agency_id``.
    :param sort_by_new: order by ``cnt_new_clients`` instead of
        ``cnt_clients``; ``t`` is always the secondary sort key.
    :return: select grouped by ``agency_id`` and ``t``.
    """
    case_stm = case(
        [(AgencyAnalytics.agency_id == agency_id, literal('current'))],
        else_=literal('other')
    )
    q = db.select(
        [
            AgencyAnalytics.agency_id.label('agency_id'),
            case_stm.label('t'),
            db.func.count(db.func.distinct(AgencyAnalytics.client_id)).label('cnt_clients'),
            # FunctionElement.filter() renders the PostgreSQL aggregate
            # "FILTER (WHERE ...)" clause natively, so no .op("FILTER") hack is needed.
            # NOTE(review): `epoch` is referenced as raw SQL, as before; use a model attribute if one exists.
            db.func.count(
                db.func.distinct(AgencyAnalytics.client_id)
            ).filter(text('epoch = 0')).label('cnt_new_clients'),
        ]
    ).select_from(AgencyAnalytics).where(
        and_(
            AgencyAnalytics.month >= month_start,
            AgencyAnalytics.month < month_end
        )
    )
    if get_only_current:
        q = q.where(AgencyAnalytics.agency_id == agency_id)
    return q.group_by(AgencyAnalytics.agency_id, 't').order_by('cnt_new_clients' if sort_by_new else 'cnt_clients', 't')


def _previous_period_start(period_start: datetime.date, period_end: datetime.date) -> datetime.date:
    """Return the start of the equally long period that ends at ``period_start``.

    If both bounds are first-of-month dates, the length is measured in calendar
    months. For example, ``[2021-04-01, 2021-05-01)`` maps to ``2021-03-01``;
    subtracting the raw 30-day difference would instead give ``2021-03-02``
    and exclude a row dated ``2021-03-01``. Any other bounds fall back to
    plain day arithmetic.
    """
    if period_start.day == 1 and period_end.day == 1:
        span_months = (period_end.year - period_start.year) * 12 + (period_end.month - period_start.month)
        # absolute month index of period_start, shifted back by the period length
        target = period_start.year * 12 + (period_start.month - 1) - span_months
        return period_start.replace(year=target // 12, month=target % 12 + 1)
    return period_start - (period_end - period_start)


def _distinct_clients_between(period_start: datetime.date, period_end: datetime.date, only_new: bool = False):
    """Build ``count(DISTINCT client_id) FILTER (WHERE ...)`` restricted to ``[period_start, period_end)``.

    With ``only_new``, the filter additionally requires ``epoch = 0``. Dates are
    passed as bound parameters rather than formatted into SQL text.
    """
    criteria = [AgencyAnalytics.month >= period_start, AgencyAnalytics.month < period_end]
    if only_new:
        # NOTE(review): `epoch` is referenced as raw SQL, as before; use a model attribute if one exists.
        criteria.insert(0, text('epoch = 0'))
    return db.func.count(db.func.distinct(AgencyAnalytics.client_id)).filter(and_(*criteria))


def make_clients_info_aggregation_with_prev_period(
    agency_id: int,
    month_start: datetime.date,
    month_end: datetime.date,
    get_only_current: bool = False,
    sort_by_new: bool = False
):
    """Count distinct clients per agency for a period and for the period just before it.

    The previous period is ``[prev_start, month_start)``, with the same length
    as ``[month_start, month_end)``. The length is measured in whole months
    when both bounds are first-of-month dates. Each result row contains:

    * ``agency_id``;
    * ``t``: ``'current'`` for ``agency_id``, ``'other'`` otherwise;
    * ``cnt_clients`` / ``cnt_new_clients``: distinct clients (all / with
      ``epoch = 0``) in the current period;
    * ``cnt_clients_prev`` / ``cnt_new_clients_prev``: the same counts for the
      previous period.

    :param agency_id: agency tagged as ``'current'``.
    :param month_start: inclusive start of the current period.
    :param month_end: exclusive end of the current period.
    :param get_only_current: restrict the result to rows of ``agency_id``.
    :param sort_by_new: order by ``cnt_new_clients`` instead of
        ``cnt_clients``; ``t`` is always the secondary sort key.
    :return: select grouped by ``agency_id`` and ``t``.
    """
    prev_month_start = _previous_period_start(month_start, month_end)

    case_stm = case(
        [(AgencyAnalytics.agency_id == agency_id, literal('current'))],
        else_=literal('other')
    )
    q = db.select(
        [
            AgencyAnalytics.agency_id.label('agency_id'),
            case_stm.label('t'),
            _distinct_clients_between(month_start, month_end).label('cnt_clients'),
            _distinct_clients_between(month_start, month_end, only_new=True).label('cnt_new_clients'),
            _distinct_clients_between(prev_month_start, month_start).label('cnt_clients_prev'),
            _distinct_clients_between(prev_month_start, month_start, only_new=True).label('cnt_new_clients_prev'),
        ]
    ).select_from(AgencyAnalytics).where(
        and_(
            # a single scan covers both periods; each aggregate narrows it via FILTER
            AgencyAnalytics.month >= prev_month_start,
            AgencyAnalytics.month < month_end
        )
    )
    if get_only_current:
        q = q.where(AgencyAnalytics.agency_id == agency_id)
    return q.group_by(AgencyAnalytics.agency_id, 't').order_by('cnt_new_clients' if sort_by_new else 'cnt_clients', 't')
