from typing import List

from staff.headcounts.headcounts_summary.query_builder.aliases import Aliases
from staff.headcounts.headcounts_summary.query_builder.assignment_part_query_builder import AssignmentPartQueryBuilder
from staff.headcounts.headcounts_summary.query_builder.hierarchy_part_query_builder import HierarchyPartQueryBuilder
from staff.headcounts.headcounts_summary.query_builder.query_params import QueryParams, RelatedEntity


class SummaryQueryBuilder:
    """Build the raw SQL text for the headcounts summary.

    Every statement starts from an ``assignment`` CTE produced by the injected
    assignment query builder. For each non-empty hierarchy filter list in
    ``query_params`` one more CTE is added and joined back on
    ``assignment.<entity pk field> = <cte alias>.id``:

    * user filters (department / value stream / geography) are attached with an
      inner ``JOIN``, so assignment rows without a match are dropped by the join;
    * permission filters are attached with a ``LEFT JOIN`` and checked in the
      ``WHERE`` clause via ``<cte alias>.id IS NOT NULL`` conditions combined
      with ``OR`` -- a row is kept when at least one active permission CTE
      matched it.
    """

    assignment_alias = 'assignment'

    def __init__(
        self,
        query_params: QueryParams,
        aliases: Aliases,
        assignement_query_builder: AssignmentPartQueryBuilder,
    ) -> None:
        self.query_params = query_params
        self._aliases = aliases
        # The parameter keeps its existing spelling ("assignement"): renaming it
        # would change the constructor's keyword interface.
        self._assignement_query_builder = assignement_query_builder

        params = self.query_params
        # One hierarchy sub-query builder per filter list. All six are created up
        # front; whether a builder's SQL is emitted is decided later from the
        # truthiness of its filter list.
        self._department_permission_filter_builder = HierarchyPartQueryBuilder(params.department_permission_filters)
        self._value_stream_permission_filter_builder = HierarchyPartQueryBuilder(
            params.value_stream_permission_filters
        )
        self._geography_permission_filter_builder = HierarchyPartQueryBuilder(params.geography_permission_filters)
        self._department_user_filter_builder = HierarchyPartQueryBuilder(params.department_filters)
        self._value_stream_user_filter_builder = HierarchyPartQueryBuilder(params.value_stream_filters)
        self._geography_user_filter_builder = HierarchyPartQueryBuilder(params.geography_filters)

    def _column(self, name: str) -> str:
        """Qualify ``name`` with the assignment CTE alias."""
        return f'{self.assignment_alias}.{name}'

    def _sum(self, name: str) -> str:
        """Render a SUM over an assignment column where NULLs count as 0."""
        return f'SUM(COALESCE({self._column(name)}, 0))'

    def _select_grouping_fields(self) -> List[str]:
        """Select each grouping's pk column, re-exposed under its bare field name."""
        names = [grouping.pk_field_name for grouping in self.query_params.groupings]
        return [f'{self._column(name)} as {name}' for name in names]

    def _format_cte(self, alias: str, expression: str) -> str:
        """Render one ``<alias> AS (<expression>)`` entry of a WITH clause."""
        return '{} AS ({})'.format(alias, expression)

    def _permission_specs(self) -> tuple:
        """Return ``(filters, builder, alias, related_entity)`` per permission dimension.

        The tuple order (department, value stream, geography) fixes the order in
        which permission CTEs, LEFT JOINs and WHERE conditions are emitted.
        """
        params, aliases = self.query_params, self._aliases
        return (
            (
                params.department_permission_filters,
                self._department_permission_filter_builder,
                aliases.department_permissions_join,
                RelatedEntity.department,
            ),
            (
                params.value_stream_permission_filters,
                self._value_stream_permission_filter_builder,
                aliases.value_stream_permissions_join,
                RelatedEntity.value_stream,
            ),
            (
                params.geography_permission_filters,
                self._geography_permission_filter_builder,
                aliases.geography_permissions_join,
                RelatedEntity.geography,
            ),
        )

    def _user_filter_specs(self) -> tuple:
        """Return ``(filters, builder, alias, related_entity)`` per user-filter dimension.

        The tuple order (department, value stream, geography) fixes the order in
        which user-filter CTEs and inner JOINs are emitted.
        """
        params, aliases = self.query_params, self._aliases
        return (
            (
                params.department_filters,
                self._department_user_filter_builder,
                aliases.department_user_filter_join,
                RelatedEntity.department,
            ),
            (
                params.value_stream_filters,
                self._value_stream_user_filter_builder,
                aliases.value_stream_user_filter_join,
                RelatedEntity.value_stream,
            ),
            (
                params.geography_filters,
                self._geography_user_filter_builder,
                aliases.geography_user_filter_join,
                RelatedEntity.geography,
            ),
        )

    def _build_ctes(self) -> str:
        """Return the comma-separated CTE list that follows ``WITH``.

        The assignment CTE always comes first. It is followed by the permission
        CTEs and then the user-filter CTEs, each only when its filter list is
        non-empty. A hierarchy builder's ``build()`` is called only in that case.
        """
        entries = [self._format_cte(self.assignment_alias, self._assignement_query_builder.build())]
        for filters, builder, alias, _entity in self._permission_specs() + self._user_filter_specs():
            if filters:
                entries.append(self._format_cte(alias=alias, expression=builder.build()))
        return ', '.join(entries)

    def _aggregate_fields(self) -> List[str]:
        """Return the aggregate SELECT expressions.

        Every column is summed NULL-safely. Two derived columns are also emitted:
        ``balance`` (headcount - working - offer) and ``crossing`` (the sum of the
        four ``*_crossing`` totals).
        """
        total = self._sum
        crossing_columns = ('working_crossing', 'offer_crossing', 'vacancies_plan_crossing', 'vacancies_crossing')
        plain_columns = ('vacancies', 'vacancies_plan_new', 'vacancies_plan_replacement', 'credit', *crossing_columns)

        # The doubled spaces before each "-" are part of the emitted SQL text.
        balance = f"{total('headcount')}  - {total('working')}  - {total('offer')} as balance"
        crossing = ' + '.join(total(column) for column in crossing_columns) + ' as crossing'

        fields = [f'{total(column)} as {column}' for column in ('headcount', 'working', 'offer')]
        fields.append(balance)
        fields.extend(f'{total(column)} as {column}' for column in plain_columns)
        fields.append(crossing)
        return fields

    def _permissions_join(self, related_entity: RelatedEntity, alias: str) -> str:
        """Render the LEFT variant of the join built by ``_filters_join``."""
        return 'LEFT ' + self._filters_join(related_entity, alias)

    def _filters_join(self, related_entity: RelatedEntity, alias: str) -> str:
        """Join ``alias`` to the assignment CTE on the entity's pk column = ``alias.id``."""
        assignment_key = self._column(related_entity.pk_field_name)
        return f'JOIN {alias} ON ({assignment_key} = {alias}.id)'

    def _joins(self) -> str:
        """Return inner JOINs for active user filters, then LEFT JOINs for active permissions."""
        clauses = [
            self._filters_join(entity, alias)
            for filters, _builder, alias, entity in self._user_filter_specs()
            if filters
        ]
        clauses.extend(
            self._permissions_join(entity, alias)
            for filters, _builder, alias, entity in self._permission_specs()
            if filters
        )
        return ' '.join(clauses)

    def _where(self) -> str:
        """Return ``WHERE`` OR-ing one match check per active permission CTE, or ''."""
        conditions = [
            f'{alias}.id IS NOT NULL'
            for filters, _builder, alias, _entity in self._permission_specs()
            if filters
        ]
        if conditions:
            return 'WHERE ' + ' OR '.join(conditions)
        return ''

    def _group_by(self) -> str:
        """Return ``GROUP BY (...)`` over the groupings' pk columns, or '' without groupings."""
        groupings = self.query_params.groupings
        if not groupings:
            return ''
        columns = ', '.join(self._column(entity.pk_field_name) for entity in groupings)
        return 'GROUP BY (' + columns + ')'

    def _assemble(self, ctes: str, select_list: List[str], *tail_clauses: str) -> str:
        """Join the statement parts with single spaces.

        Empty parts (no joins, no WHERE, no GROUP BY) are kept, so they leave a
        doubled space instead of being dropped.
        """
        pieces = [
            'WITH', ctes,
            'SELECT', ', '.join(select_list),
            'FROM', self.assignment_alias,
            self._joins(),
            self._where(),
            *tail_clauses,
        ]
        return ' '.join(pieces)

    def build(self) -> str:
        """Return the aggregated query: grouping columns plus aggregates, grouped."""
        ctes = self._build_ctes()
        select_list = self._select_grouping_fields() + self._aggregate_fields()
        return self._assemble(ctes, select_list, self._group_by())

    def build_query_for_rows_without_grouping(self) -> str:
        """Return the row-level query: grouping columns plus every assignment column, no GROUP BY."""
        ctes = self._build_ctes()
        select_list = self._select_grouping_fields() + [self._column('*')]
        return self._assemble(ctes, select_list)
