from collections import defaultdict
from typing import List, Optional

import sqlalchemy
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.sql import cast, func, join, literal_column, select, text, union_all

from maps_adv.adv_store.lib.db import db
from maps_adv.adv_store.v2.lib.db import tables
from maps_adv.common.helpers.enums import CampaignTypeEnum

from . import PlatformEnum, PublicationEnvEnum


def _query_select_for_action(action_table, action: str, campaign_ids: List[int]):
    """Build a select of ``(campaign_id, action)`` rows for one action table.

    ``action`` in the result is JSONB shaped like ``{<action>: {...}}``, where
    the inner object is the full ``action_table`` row with its ``campaign_id``
    key removed (and the outer object likewise has no ``campaign_id`` key).

    :param action_table: table holding one kind of campaign action.
    :param action: used as the JSON key and as the prefix of subquery aliases.
    :param campaign_ids: only actions of these campaigns are selected.
    """
    # Raw rows of the action table, restricted to the requested campaigns.
    source_q = select([action_table]).where(
        action_table.c.campaign_id.in_(campaign_ids)
    ).alias(f"{action}_row_q")

    # Whole source row as JSONB, minus campaign_id; exposed under "<action>".
    row_json = cast(
        func.row_to_json(literal_column(source_q.name)), JSONB
    ) - text("'campaign_id'::text")
    keyed_q = (
        select([source_q.c.campaign_id, row_json.label(action)])
        .select_from(source_q)
        .alias(f"{action}_labeled_q")
    )

    # Serialize the keyed row again so the value becomes {<action>: {...}}.
    wrapped_json = cast(
        func.row_to_json(literal_column(keyed_q.name)), JSONB
    ) - text("'campaign_id'::text")
    return select(
        [keyed_q.c.campaign_id, wrapped_json.label("action")]
    ).select_from(keyed_q)


def _query_select_for_action_list(campaign_ids: List[int]):
    """Build the ``actions_q`` subquery: one row per campaign with its actions.

    Columns are ``campaign_id`` and ``actions``, a JSONB array aggregating the
    ``{<action>: {...}}`` objects from every action table. Campaigns with no
    actions do not appear at all.
    """
    # (table, label) pairs in the order they are UNION ALL-ed together.
    action_sources = (
        (tables.open_site, "open_site"),
        (tables.search, "search"),
        (tables.phone_call, "phone_call"),
        (tables.download_app, "download_app"),
    )
    united_q = union_all(
        *(
            _query_select_for_action(table, label, campaign_ids)
            for table, label in action_sources
        )
    ).alias("actions_united_q")

    # Collapse all per-action rows of a campaign into a single JSONB array.
    aggregated_actions = cast(func.jsonb_agg(united_q.c.action), JSONB)
    grouped_q = select(
        [united_q.c.campaign_id, aggregated_actions.label("actions")]
    ).select_from(united_q)
    return grouped_q.group_by(united_q.c.campaign_id).alias("actions_q")


def _query_select_for_creative(creative_table, creative: str, campaign_ids: List[int]):
    """Build a select of ``(campaign_id, creative)`` rows for one creative table.

    ``creative`` in the result is JSONB shaped like ``{<creative>: {...}}``,
    where the inner object is the full ``creative_table`` row without its
    ``campaign_id`` key (the outer object has no ``campaign_id`` key either).

    :param creative_table: table holding one kind of campaign creative.
    :param creative: used as the JSON key and as the prefix of subquery aliases.
    :param campaign_ids: only creatives of these campaigns are selected.
    """
    # Raw creative rows of the requested campaigns.
    source_q = select([creative_table]).where(
        creative_table.c.campaign_id.in_(campaign_ids)
    ).alias(f"{creative}_row_q")

    # Whole source row as JSONB, minus campaign_id; exposed under "<creative>".
    row_json = cast(
        func.row_to_json(literal_column(source_q.name)), JSONB
    ) - text("'campaign_id'::text")
    keyed_q = (
        select([source_q.c.campaign_id, row_json.label(creative)])
        .select_from(source_q)
        .alias(f"{creative}_labeled_q")
    )

    # Serialize the keyed row again so the value becomes {<creative>: {...}}.
    wrapped_json = cast(
        func.row_to_json(literal_column(keyed_q.name)), JSONB
    ) - text("'campaign_id'::text")
    return select(
        [keyed_q.c.campaign_id, wrapped_json.label("creative")]
    ).select_from(keyed_q)


def _query_select_for_creative_list(campaign_ids: List[int]):
    """Build the ``creatives_q`` subquery: one row per campaign with its creatives.

    Columns are ``campaign_id`` and ``creatives``, a JSONB array aggregating
    the ``{<creative>: {...}}`` objects from every creative table. Campaigns
    with no creatives do not appear at all.

    Example shape::

        |               creatives                |  campaign_id  |
        | [{'billboard': ...}, {'banner': ...}]  |       1       |
        | [{'text': ...}]                        |       2       |
    """
    # (table, label) pairs in the order their rows are UNION ALL-ed.
    # No local is named ``text``: that would shadow sqlalchemy's ``text``
    # imported at module level for the whole function body.
    creative_sources = (
        (tables.billboard, "billboard"),
        (tables.logo_and_text, "logo_and_text"),
        (tables.banner, "banner"),
        (tables.pin, "pin"),
        (tables.text, "text"),
        (tables.icon, "icon"),
        (tables.pin_search, "pin_search"),
        (tables.via_point, "via_point"),
    )
    per_creative_queries = [
        _query_select_for_creative(table, label, campaign_ids)
        for table, label in creative_sources
    ]

    # All creatives of all kinds, one row per creative.
    united_q = union_all(*per_creative_queries).alias("creatives_united_q")

    # Collapse each campaign's creatives into a single JSONB array.
    return (
        select(
            [
                united_q.c.campaign_id,
                cast(func.jsonb_agg(united_q.c.creative), JSONB).label("creatives"),
            ]
        )
        .select_from(united_q)
        .group_by(united_q.c.campaign_id)
        .alias("creatives_q")
    )


def _query_select_for_organizations(campaign_ids: List[int]):
    """Build a select of ``(campaign_id, placing)`` from the organizations table.

    ``placing`` is JSONB shaped like ``{"organizations": {...}}``; the inner
    object is the organizations row with its ``campaign_id`` key removed.
    """
    label = "organizations"
    organizations = tables.organizations

    # Organization rows belonging to the requested campaigns.
    source_q = (
        select([organizations])
        .where(organizations.c.campaign_id.in_(campaign_ids))
        .alias(f"{label}_row_q")
    )

    # (campaign_id, organizations): the row as JSONB without campaign_id.
    row_json = cast(
        func.row_to_json(literal_column(source_q.name)), JSONB
    ) - text("'campaign_id'::text")
    keyed_q = (
        select([source_q.c.campaign_id, row_json.label(label)])
        .select_from(source_q)
        .alias(f"{label}_labeled_q")
    )

    # (campaign_id, placing): {"organizations": {...}}.
    placing_json = cast(
        func.row_to_json(literal_column(keyed_q.name)), JSONB
    ) - text("'campaign_id'::text")
    return select(
        [keyed_q.c.campaign_id, placing_json.label("placing")]
    ).select_from(keyed_q)


def _query_select_for_areas(campaign_ids: List[int]):
    """Build a select of ``(campaign_id, placing)`` from the area table.

    ``placing`` is JSONB shaped like ``{"area": {...}}``; the inner object is
    the area row with its ``campaign_id`` key removed.
    """
    label = "area"
    area = tables.area

    # Area rows belonging to the requested campaigns.
    source_q = (
        select([area])
        .where(area.c.campaign_id.in_(campaign_ids))
        .alias(f"{label}_row_q")
    )

    # (campaign_id, area): the row as JSONB without campaign_id.
    row_json = cast(
        func.row_to_json(literal_column(source_q.name)), JSONB
    ) - text("'campaign_id'::text")
    keyed_q = (
        select([source_q.c.campaign_id, row_json.label(label)])
        .select_from(source_q)
        .alias(f"{label}_labeled_q")
    )

    # (campaign_id, placing): {"area": {...}}.
    placing_json = cast(
        func.row_to_json(literal_column(keyed_q.name)), JSONB
    ) - text("'campaign_id'::text")
    return select(
        [keyed_q.c.campaign_id, placing_json.label("placing")]
    ).select_from(keyed_q)


def _query_select_for_placing(campaign_ids):
    """Build the ``placings_q`` subquery: at most one placing per campaign.

    Organizations and area placings are UNION ALL-ed, and only the row with
    ``row_number() == 1`` per campaign is kept. Columns: ``placing``,
    ``campaign_id``.

    NOTE(review): the window has no ORDER BY, so if a campaign ever has both an
    organizations row and an area row, which one is returned is not determined
    by this query.
    """
    unioned_q = union_all(
        _query_select_for_organizations(campaign_ids),
        _query_select_for_areas(campaign_ids),
    ).alias("unioned_placing_q")

    position = (
        func.row_number()
        .over(partition_by=unioned_q.c.campaign_id)
        .label("position")
    )
    numbered_q = (
        select([unioned_q.c.placing, unioned_q.c.campaign_id, position])
        .select_from(unioned_q)
        .alias("first_placings_q")
    )

    first_only = numbered_q.c.position == 1
    return (
        select([numbered_q.c.placing, numbered_q.c.campaign_id])
        .select_from(numbered_q)
        .where(first_only)
        .alias("placings_q")
    )


def _query_select_for_schedule(campaign_ids):
    """Build the ``week_schedule_q`` subquery: each campaign's schedule entries.

    Columns are ``campaign_id`` and ``week_schedule``, a JSONB array of the
    week_schedule rows (each without its ``campaign_id`` key). Campaigns with
    no schedule rows do not appear.
    """
    label = "week_schedule"
    schedule = tables.week_schedule

    # Schedule rows of the requested campaigns; a campaign may have several.
    source_q = (
        select([schedule])
        .where(schedule.c.campaign_id.in_(campaign_ids))
        .alias(f"{label}_row_q")
    )

    # (campaign_id, week_schedule): each row as JSONB without campaign_id.
    row_json = cast(
        func.row_to_json(literal_column(source_q.name)), JSONB
    ) - text("'campaign_id'::text")
    keyed_q = (
        select([source_q.c.campaign_id, row_json.label(label)])
        .select_from(source_q)
        .alias(f"{label}_labeled_q")
    )

    # Aggregate all entries of a campaign into one JSONB array.
    entries = cast(func.jsonb_agg(keyed_q.c.week_schedule), JSONB).label(label)
    return (
        select([keyed_q.c.campaign_id, entries])
        .select_from(keyed_q)
        .group_by(keyed_q.c.campaign_id)
        .alias("week_schedule_q")
    )


def _query_select_billing():
    """Build the ``billing_q`` subquery: every campaign with its billing details.

    Each row has ``campaign_id``, the fix/cpm/cpa columns flattened with a
    ``<kind>__<column>`` prefix, and the raw ``fix_id``/``cpm_id``/``cpa_id``
    references from ``billing``. Columns of a kind the billing row does not
    reference come back NULL.

    The fix/cpm/cpa tables are LEFT OUTER joined. A FULL join would also emit
    one extra row (with NULL ``campaign_id``) for every fix/cpm/cpa record not
    referenced by any campaign's billing. Callers in this module inner-join on
    ``campaign_id`` and would discard those rows anyway.
    """
    return (
        select(
            [
                tables.campaign.c.id.label("campaign_id"),
                tables.fix.c.time_interval.label("fix__time_interval"),
                tables.fix.c.cost.label("fix__cost"),
                tables.cpm.c.cost.label("cpm__cost"),
                tables.cpm.c.daily_budget.label("cpm__daily_budget"),
                tables.cpm.c.budget.label("cpm__budget"),
                tables.cpa.c.cost.label("cpa__cost"),
                tables.cpa.c.budget.label("cpa__budget"),
                tables.cpa.c.daily_budget.label("cpa__daily_budget"),
                tables.billing.c.fix_id.label("fix_id"),
                tables.billing.c.cpm_id.label("cpm_id"),
                tables.billing.c.cpa_id.label("cpa_id"),
            ]
        )
        .select_from(
            join(
                tables.campaign,
                tables.billing,
                tables.billing.c.id == tables.campaign.c.billing_id,
            )
            # Only the billing kind actually referenced matches; the others
            # stay NULL instead of pulling in unrelated records.
            .join(tables.fix, tables.billing.c.fix_id == tables.fix.c.id, isouter=True)
            .join(tables.cpm, tables.billing.c.cpm_id == tables.cpm.c.id, isouter=True)
            .join(tables.cpa, tables.billing.c.cpa_id == tables.cpa.c.id, isouter=True)
        )
        .alias("billing_q")
    )


def _query_select_for_current_status():
    """Build the ``status_q`` subquery: the latest status_history entry per campaign.

    Columns: ``campaign_id``, ``status``, ``metadata``. "Latest" means
    ``rank() == 1`` when ordered by ``changed_datetime`` descending.

    NOTE(review): ``rank()`` gives tied rows the same rank, so two entries with
    the same newest ``changed_datetime`` would both be returned for a campaign.
    """
    history = tables.status_history

    newest_first_rank = func.rank().over(
        partition_by=history.c.campaign_id,
        order_by=history.c.changed_datetime.desc(),
    )
    ranked_q = select(
        [
            history.c.campaign_id,
            history.c.status,
            history.c.metadata,
            newest_first_rank.label("status_rank"),
        ]
    ).alias("ranked_status_q")

    latest_columns = [ranked_q.c.campaign_id, ranked_q.c.status, ranked_q.c.metadata]
    return (
        select(latest_columns)
        .select_from(ranked_q)
        .where(ranked_q.c.status_rank == 1)
        .alias("status_q")
    )


def _query_select_campaign_with_relations(campaign_ids):
    """Build the full select of the given campaigns together with related data.

    Selected columns, in order:

    * all of ``billing_q``,
    * ``creatives``, ``actions``, ``placing``, ``week_schedule``, each
      coalesced to ``"[]"`` when the campaign has no related rows,
    * all of ``status_q`` (the current status),
    * all ``campaign`` columns.

    Campaigns without billing or without any status row are dropped by the
    inner joins. The related subqueries are FULL joined, but rows present only
    on their side have a NULL ``campaign.id`` and are filtered out by the
    ``WHERE campaign.id IN (...)`` clause.
    """
    campaign = tables.campaign

    billing_q = _query_select_billing()
    creatives_q = _query_select_for_creative_list(campaign_ids)
    actions_q = _query_select_for_action_list(campaign_ids)
    placing_q = _query_select_for_placing(campaign_ids)
    schedules_q = _query_select_for_schedule(campaign_ids)
    status_q = _query_select_for_current_status()

    # Mandatory relations: billing and current status (no status no campaign).
    sources = join(campaign, billing_q, campaign.c.id == billing_q.c.campaign_id)
    sources = sources.join(status_q, status_q.c.campaign_id == campaign.c.id)

    # Optional relations, in the same join order as before.
    for related_q in (creatives_q, placing_q, actions_q, schedules_q):
        sources = sources.join(
            related_q, related_q.c.campaign_id == campaign.c.id, full=True
        )

    json_columns = [
        func.coalesce(creatives_q.c.creatives, "[]").label("creatives"),
        func.coalesce(actions_q.c.actions, "[]").label("actions"),
        func.coalesce(placing_q.c.placing, "[]").label("placing"),
        func.coalesce(schedules_q.c.week_schedule, "[]").label("week_schedule"),
    ]

    return (
        select([billing_q, *json_columns, status_q, campaign])
        .select_from(sources)
        .where(campaign.c.id.in_(campaign_ids))
    )


def _query_select_short_format_campaigns() -> sqlalchemy.sql.Select:
    """Build a select of campaigns with billing, current status and a few columns.

    No campaign-id filter is applied here, and creatives/actions/placing/
    schedule are not attached. Campaigns without billing or status are dropped
    by the inner joins.
    """
    campaign = tables.campaign
    billing_q = _query_select_billing()
    status_q = _query_select_for_current_status()

    campaign_columns = [
        campaign.c.name,
        campaign.c.comment,
        campaign.c.timezone,
        campaign.c.start_datetime,
        campaign.c.end_datetime,
        campaign.c.campaign_type,
        campaign.c.platforms,
    ]
    sources = join(
        campaign, billing_q, campaign.c.id == billing_q.c.campaign_id
    ).join(status_q, status_q.c.campaign_id == campaign.c.id)

    return select([billing_q, status_q, *campaign_columns]).select_from(sources)


def _clean_billing(campaigns: list):
    fields = [
        "fix__time_interval",
        "fix__cost",
        "cpm__daily_budget",
        "cpm__budget",
        "cpm__cost",
        "cpa__daily_budget",
        "cpa__budget",
        "cpa__cost",
        "fix_id",
        "cpm_id",
        "cpa_id",
    ]
    for campaign in campaigns:
        billing = defaultdict(dict)
        if campaign["fix_id"]:
            billing["fix"] = {
                "cost": campaign["fix__cost"],
                "time_interval": campaign["fix__time_interval"],
            }
        elif campaign["cpm_id"]:
            billing["cpm"] = {
                "cost": campaign["cpm__cost"],
                "budget": campaign["cpm__budget"],
                "daily_budget": campaign["cpm__daily_budget"],
            }
        elif campaign["cpa_id"]:
            billing["cpa"] = {
                "cost": campaign["cpa__cost"],
                "budget": campaign["cpa__budget"],
                "daily_budget": campaign["cpa__daily_budget"],
            }
        else:
            raise ValueError("Billing is neither cpm nor cpa nor fix.")

        for field in fields:
            del campaign[field]

        campaign["billing"] = dict(billing)


def _clean_fields(campaigns):
    for campaign in campaigns:
        campaign["id"] = campaign.pop("campaign_id")
        campaign["platforms"] = list(PlatformEnum[p] for p in campaign["platforms"])


def _clean_campaign(campaigns: list):
    for campaign in campaigns:
        campaign["publication_envs"] = list(
            PublicationEnvEnum[e] for e in campaign["publication_envs"]
        )
        campaign["platforms"] = list(PlatformEnum[e] for e in campaign["platforms"])
        del campaign["changed_datetime"]
        del campaign["billing_id"]
        del campaign["campaign_id"]


def _clean_status(campaigns: list):
    for campaign in campaigns:
        campaign.pop("metadata")


def _patch_campaign_category_search(campaigns: list):
    """Give CATEGORY_SEARCH campaigns with an empty placing a default one, in place.

    The default is an organizations placing with no permalinks.
    """
    for row in campaigns:
        is_category_search = row["campaign_type"] == CampaignTypeEnum.CATEGORY_SEARCH
        if is_category_search and not row["placing"]:
            row["placing"] = {"organizations": {"permalinks": []}}


async def get_campaign(campaign_id: int, use_rw_conn=False) -> Optional[dict]:
    """Fetch one campaign via ``get_campaigns``.

    Returns the campaign dict without its ``id`` key, or None if nothing was
    found.
    """
    found = await get_campaigns([campaign_id], use_rw_conn=use_rw_conn)
    if not found:
        return None

    first = found[0]
    del first["id"]
    return first


async def get_campaigns(campaign_ids, use_rw_conn=False):
    """Fetch the given campaigns with all relations as cleaned-up dicts.

    Uses the read-write connection when ``use_rw_conn`` is true, otherwise the
    read-only one. Campaigns that are not found are simply absent from the
    returned list.
    """
    # TODO: @pconstant use query builder for retrieving data after other operations
    query = _query_select_campaign_with_relations(campaign_ids)
    connection = db.rw if use_rw_conn else db.ro
    rows = await connection.fetch_all(query)
    campaigns = list(map(dict, rows))

    # Post-processing steps; each mutates the campaign dicts in place.
    for post_process in (
        _clean_billing,
        _clean_campaign,
        _clean_status,
        _patch_campaign_category_search,
    ):
        post_process(campaigns)

    return campaigns
