from abc import ABC, abstractmethod
from enum import Enum
from sandbox.projects.yabs.AnomalyAuditReport.lib.config_reader import Database, Totals, Threshold


class ColumnType(Enum):
    """Kinds of table columns that the query builders tell apart.

    Every member's value is its own name as a string.
    """

    # Bitmask column: the builders expand it into one row per set bit position.
    Bits = 'Bits'
    # Bitmask decoded by the builders in three 8-byte chunks (bit offsets 0, 64 and 128).
    WideBits = 'WideBits'
    # Comma-separated string: the builders split it into one row per item.
    Array = 'Array'
    # The remaining kinds get no special handling in this module: the builders
    # select such columns as-is.
    Group = 'Group'
    Dictionary = 'Dictionary'
    Other = 'Other'


class QueryBuilderBase(ABC):
    """Skeleton shared by all anomaly-audit query builders.

    The base class stores the report settings, picks the threshold for a
    column and fills a backend-specific query template. Subclasses supply the
    templates and the backend-specific SELECT / WHERE fragments.
    """

    def __init__(self, current_date, previous_date, report_config: dict) -> None:
        """Remember the two compared dates and the settings from ``report_config``.

        Raises:
            KeyError: if ``report_config`` lacks one of the keys read below.
        """
        super().__init__()

        # The two days whose row counts are compared.
        self._current_date, self._previous_date = current_date, previous_date

        self._table_name = report_config['tableName']
        # Stored for completeness; nothing in this module reads it.
        self._select_level = report_config['selectLevel']

        # Optional extra filter applied to every query (may be empty).
        self._where_clause = report_config['whereClause']

        # When truthy, templates additionally count rows matching the
        # "good events" condition.
        self._select_good_events = report_config['selectGoodEvents']
        self._where_clause_good_events = report_config['whereClauseGoodEvents']

        # Mapping of column name (or 'default') to a threshold object.
        self._thresholds = report_config['thresholds']

    def get_query(self, column_name: str, column_type: str, totals: Totals, db_column_name) -> str:
        """Build the per-column anomaly query.

        The threshold is looked up by the stripped, lower-cased ``column_name``
        first, then by ``db_column_name``, and finally under ``'default'``.

        NOTE(review): ``column_type`` is annotated as ``str`` here but is passed
        unchanged to helpers that compare it with ``ColumnType`` members -
        confirm what callers actually pass.

        Args:
            column_name: column to group by.
            column_type: kind of the column, forwarded to the select helpers.
            totals: object exposing ``current`` and ``previous`` total counts.
            db_column_name: alternative key for the threshold lookup.

        Returns:
            The query template filled with all fragments.
        """
        thresholds = self._thresholds
        for lookup_key in (column_name.strip().lower(), db_column_name):
            if lookup_key in thresholds:
                threshold = thresholds.get(lookup_key)
                break
        else:
            threshold = thresholds.get('default')

        template = self._get_query_template()
        substitutions = {
            'currentDate': self._current_date,
            'previousDate': self._previous_date,
            'selectFields': self._get_select_clause(column_name, column_type),
            'tableName': self._table_name,
            'subSelectFields': self._get_sub_select_clause(column_name, column_type),
            'whereClauseGoodEvents': self._where_clause_good_events,
            'whereClause': self._get_where_clause(),
            'groupByFields': self._get_group_by_clause(column_name),
            'having': self._get_having_clause(totals.current, totals.previous, threshold),
        }
        return template.format(**substitutions)

    def get_total_query(self) -> str:
        """Build the query that counts all rows for both dates."""
        template = self._get_total_query_template()
        substitutions = {
            'currentDate': self._current_date,
            'previousDate': self._previous_date,
            'tableName': self._table_name,
            'whereClauseGoodEvents': self._where_clause_good_events,
            'whereClause': self._get_where_clause(),
        }
        return template.format(**substitutions)

    @abstractmethod
    def get_table_columns_query(self):
        """Return the backend-specific request that lists the table's columns."""

    @abstractmethod
    def _get_query_template(self) -> str:
        """Return the ``str.format`` template of the per-column query."""

    @abstractmethod
    def _get_total_query_template(self) -> str:
        """Return the ``str.format`` template of the totals query."""

    @abstractmethod
    def _get_select_clause(self, column_name: str, column_type: ColumnType) -> str:
        """Return the SELECT expression for ``column_name``."""

    def _get_sub_select_clause(self, column_name: str, column_type: ColumnType) -> str:
        """Return the SELECT expression used inside sub-queries.

        The base implementation returns ``None``; builders whose templates use
        ``{subSelectFields}`` override it.
        """
        return None

    @abstractmethod
    def _get_where_clause(self) -> str:
        """Return the backend-specific rendering of the extra filter."""

    def _get_group_by_clause(self, column_name: str) -> str:
        """Return the GROUP BY expression, which is the column name itself."""
        return column_name

    def _get_having_clause(self, current_total_count: int, previous_total_count: int, threshold: Threshold):
        """Render the condition that marks a column value as anomalous.

        A value matches when any of these holds:

        * it appeared (previous count is 0) and its current count exceeds
          ``current_total_count * LOWER_SIGNIFICANT_BOUND``;
        * it disappeared (current count is 0) and its previous count exceeds
          ``previous_total_count * LOWER_SIGNIFICANT_BOUND``;
        * it is present on both days, exceeds the bound on at least one of
          them, and its shares of the totals differ by more than a factor of
          ``SMALL_EXPAND_VALUE``;
        * it is present on both days, is below the bound on at least one of
          them, and its shares differ by more than a factor of ``EXPAND_VALUE``.

        Args:
            current_total_count: total row count for the current date.
            previous_total_count: total row count for the previous date.
            threshold: object exposing ``LOWER_SIGNIFICANT_BOUND``,
                ``SMALL_EXPAND_VALUE`` and ``EXPAND_VALUE``.
        """
        condition_template = """
        (       current_result > 0 AND previous_result = 0
            AND current_result > {current_total_count} * {LOWER_SIGNIFICANT_BOUND}
        )
        OR
        (       current_result = 0 and previous_result > 0
            AND previous_result > {previous_total_count} * {LOWER_SIGNIFICANT_BOUND}
        )
        OR
        (
            (
                    current_result > {current_total_count} * {LOWER_SIGNIFICANT_BOUND} AND previous_result > 0
                OR  previous_result > {previous_total_count} * {LOWER_SIGNIFICANT_BOUND} AND current_result > 0
            )
            AND
            (
                    current_result / {current_total_count} > previous_result / {previous_total_count} * {SMALL_EXPAND_VALUE}
                OR  current_result / {current_total_count} * {SMALL_EXPAND_VALUE} < previous_result / {previous_total_count}
            )
        )
        OR
        (
            (
                (   current_result < {current_total_count} * {LOWER_SIGNIFICANT_BOUND}
                OR
                    previous_result < {previous_total_count} * {LOWER_SIGNIFICANT_BOUND}
                )
                AND current_result > 0 AND previous_result > 0
            )
            AND
            (
                    current_result / {current_total_count} * {EXPAND_VALUE} < previous_result / {previous_total_count}
                OR  current_result / {current_total_count} > previous_result / {previous_total_count} * {EXPAND_VALUE}
            )
        )
        """
        return condition_template.format(
            LOWER_SIGNIFICANT_BOUND=threshold.LOWER_SIGNIFICANT_BOUND,
            SMALL_EXPAND_VALUE=threshold.SMALL_EXPAND_VALUE,
            EXPAND_VALUE=threshold.EXPAND_VALUE,
            current_total_count=current_total_count,
            previous_total_count=previous_total_count,
        )


class ClickHouseQueryBuilder(QueryBuilderBase):
    """Builder for a ClickHouse table that holds both compared days.

    Rows of the two days are separated with conditions on ``EventDate``.
    """

    def get_table_columns_query(self):
        """Return the ``DESCRIBE TABLE`` statement for the configured table."""
        return 'DESCRIBE TABLE {};'.format(self._table_name)

    def _get_select_clause(self, column_name: str, column_type: ColumnType) -> str:
        """Return the SELECT expression that unpacks ``column_name`` by its kind.

        Bit masks become one row per set bit position, comma-separated strings
        become one row per item; any other kind is selected unchanged.
        """
        if column_type == ColumnType.Bits:
            bits_expression = 'toInt16(arrayJoin(arrayMap(x->log2(x),bitmaskToArray({0})))) as {0}'
            return bits_expression.format(column_name)

        if column_type == ColumnType.WideBits:
            # The value is read as three 8-byte chunks; bit positions of the
            # second and third chunk are shifted by 64 and 128.
            wide_bits_expression = """
            toInt16
            (
                arrayJoin
                (
                    arrayConcat
                    (
                        arrayMap(x->log2(x),bitmaskToArray(reinterpretAsUInt64(substring({column_name},1,8)))),
                        arrayMap(x->64+log2(x),bitmaskToArray(reinterpretAsUInt64(substring({column_name},9,8)))),
                        arrayMap(x->128+log2(x),bitmaskToArray(reinterpretAsUInt64(substring({column_name},17,8))))
                    )
                )
            )
            as {column_name}"""
            return wide_bits_expression.format(column_name=column_name)

        if column_type == ColumnType.Array:
            array_expression = "if( notEmpty({0}), arrayJoin(splitByChar(',', {0})), '') as {0}"
            return array_expression.format(column_name)

        return column_name

    def _get_total_query_template(self) -> str:
        """Return the totals template, with or without good-event counters."""
        if not self._select_good_events:
            # Good-event counters are reported as constant zeros.
            return """
        SELECT
            countIf( EventDate = '{currentDate}') as current_result,
            countIf( EventDate = '{previousDate}') as previous_result,
            0 as current_good_result,
            0 as previous_good_result
        FROM
            {tableName}
        WHERE
                (EventDate = '{currentDate}' or EventDate = '{previousDate}')
            {whereClause}
        ;
        """

        return """
        SELECT
            countIf( EventDate = '{currentDate}') as current_result,
            countIf( EventDate = '{previousDate}') as previous_result,
            countIf( EventDate = '{currentDate}' AND {whereClauseGoodEvents}) as current_good_result,
            countIf( EventDate = '{previousDate}' AND {whereClauseGoodEvents}) as previous_good_result
        FROM
            {tableName}
        WHERE
                (EventDate = '{currentDate}' or EventDate = '{previousDate}')
            {whereClause}
        ;
        """

    def _get_query_template(self) -> str:
        """Return the per-column template, with or without good-event counters."""
        if not self._select_good_events:
            # Good-event counters are reported as constant zeros.
            return """
        SELECT
            {selectFields},
            countIf( EventDate = '{currentDate}') as current_result,
            countIf( EventDate = '{previousDate}') as previous_result,
            0 as current_good_result,
            0 as previous_good_result
        FROM
            {tableName}
        WHERE
                (EventDate = '{currentDate}' or EventDate = '{previousDate}')
            {whereClause}
        GROUP BY
            {groupByFields}
        HAVING
            {having}

        """

        return """
        SELECT
            {selectFields},
            countIf( EventDate = '{currentDate}') as current_result,
            countIf( EventDate = '{previousDate}') as previous_result,
            countIf( EventDate = '{currentDate}' AND {whereClauseGoodEvents}) as current_good_result,
            countIf( EventDate = '{previousDate}' AND {whereClauseGoodEvents}) as previous_good_result
        FROM
            {tableName}
        WHERE
                (EventDate = '{currentDate}' or EventDate = '{previousDate}')
            {whereClause}
        GROUP BY
            {groupByFields}
        HAVING
            {having}

        """

    def _get_where_clause(self) -> str:
        """Return the extra filter as an ``AND ...`` suffix, or '' when unset.

        The templates already contain a WHERE on ``EventDate``, so the extra
        filter is appended to it.
        """
        extra_filter = self._where_clause
        if not extra_filter:
            return ''
        return 'AND {}'.format(extra_filter)


class YtQueryBuilderBase(QueryBuilderBase):
    """Shared parts of the builders that read per-day tables.

    Each day lives in its own table at ``{tableName}/{date}``. The templates
    therefore query both tables separately, tag the rows with a literal
    ``EventDate`` and combine them with ``UNION ALL``.
    """

    def _get_select_clause(self, column_name: str, column_type: ColumnType) -> str:
        """Return the outer SELECT expression: the plain column name.

        Unpacking by column kind happens in the sub-select instead.
        """
        return column_name

    def _get_sub_select_clause(self, column_name: str, column_type: ColumnType) -> str:
        """Return the inner SELECT expression that unpacks ``column_name`` by its kind.

        Bit masks become one row per set bit position, comma-separated strings
        become one row per item; any other kind is selected unchanged.
        """
        # Difference between CH and CHYT is Nullable columns. Due to this reason assumeNotNull operator is used

        if column_type == ColumnType.Bits:
            return f'toInt16(arrayJoin(arrayMap(x->log2(x),bitmaskToArray(assumeNotNull({column_name}))))) as {column_name}'
        elif column_type == ColumnType.WideBits:
            # The value is read as three 8-byte chunks; bit positions of the
            # second and third chunk are shifted by 64 and 128.
            return """
            toInt16
            (
                arrayJoin
                (
                    arrayConcat
                    (
                        arrayMap(x->log2(x),bitmaskToArray(reinterpretAsUInt64(substring(assumeNotNull({column_name}), 1, 8)))),
                        arrayMap(x->64+log2(x),bitmaskToArray(reinterpretAsUInt64(substring(assumeNotNull({column_name}), 9, 8)))),
                        arrayMap(x->128+log2(x),bitmaskToArray(reinterpretAsUInt64(substring(assumeNotNull({column_name}), 17, 8))))
                    )
                )
            )
            as {column_name}""".format(column_name=column_name)

        elif column_type == ColumnType.Array:
            return f"if(notEmpty({column_name}), arrayJoin(splitByChar(',', assumeNotNull({column_name}))), '') as {column_name}"
        else:
            return column_name

    def _get_where_clause(self) -> str:
        """Return the extra filter as a complete ``WHERE`` clause, or '' when unset.

        Unlike the native ClickHouse templates, these templates have no WHERE
        of their own, so the keyword is part of the fragment.
        """
        return """
            WHERE
                {}
        """.format(self._where_clause) if self._where_clause else ''

    def _get_total_query_template(self) -> str:
        """Return the totals template, with or without good-event counters.

        Each counter sums only the rows tagged with its own date:
        ``previous_good_result`` is taken from the ``'{previousDate}'`` rows.
        """
        if self._select_good_events:
            return """
        SELECT
            sum(if( EventDate = '{currentDate}', result, 0 )) as current_result,
            sum(if( EventDate = '{previousDate}', result, 0 )) as previous_result,
            sum(if( EventDate = '{currentDate}', result_good, 0 )) as current_good_result,
            sum(if( EventDate = '{previousDate}', result_good, 0 )) as previous_good_result
        FROM
        (
            SELECT
                '{currentDate}' as EventDate,
                count(*) as result,
                countIf({whereClauseGoodEvents}) as result_good
            FROM
                `{tableName}/{currentDate}`
            {whereClause}

            UNION ALL

            SELECT
                '{previousDate}' as EventDate,
                count(*) as result,
                countIf({whereClauseGoodEvents}) as result_good
            FROM
                `{tableName}/{previousDate}`
            {whereClause}
        )
        ;
        """
        else:
            # Good-event counters are reported as constant zeros.
            return """
        SELECT
            sum(if( EventDate = '{currentDate}', result, 0 )) as current_result,
            sum(if( EventDate = '{previousDate}', result, 0 )) as previous_result,
            0 as current_good_result,
            0 as previous_good_result
        FROM
        (
            SELECT
                '{currentDate}' as EventDate,
                count(*) as result
            FROM
                `{tableName}/{currentDate}`
            {whereClause}

            UNION ALL

            SELECT
                '{previousDate}' as EventDate,
                count(*) as result
            FROM
                `{tableName}/{previousDate}`
            {whereClause}
        )
        ;
        """


class ClickHouseOverYtQueryBuilder(YtQueryBuilderBase):
    """Builder for per-day tables queried through ClickHouse SQL.

    The anomaly condition is applied with ``HAVING`` on the outer aggregation.
    """

    def get_table_columns_query(self):
        """Return the ``DESCRIBE TABLE`` statement for the current day's table."""
        return f'DESCRIBE TABLE `{self._table_name}/{self._current_date}`;'

    def _get_query_template(self) -> str:
        """Return the per-column template, with or without good-event counters.

        Each counter sums only the rows tagged with its own date:
        ``previous_good_result`` is taken from the ``'{previousDate}'`` rows.
        """
        if self._select_good_events:
            return """
        SELECT
            {selectFields},
            sum(if( EventDate = '{currentDate}', result, 0 )) as current_result,
            sum(if( EventDate = '{previousDate}', result, 0 )) as previous_result,
            sum(if( EventDate = '{currentDate}', result_good, 0 )) as current_good_result,
            sum(if( EventDate = '{previousDate}', result_good, 0 )) as previous_good_result
        FROM
        (
            SELECT
                {subSelectFields},
                '{currentDate}' as EventDate,
                count(*) as result,
                countIf({whereClauseGoodEvents}) as result_good
            FROM
                `{tableName}/{currentDate}`
            {whereClause}
            GROUP BY
                {groupByFields}

            UNION ALL

            SELECT
                {subSelectFields},
                '{previousDate}' as EventDate,
                count(*) as result,
                countIf({whereClauseGoodEvents}) as result_good
            FROM
                `{tableName}/{previousDate}`
            {whereClause}
            GROUP BY
                {groupByFields}
        )
        GROUP BY
            {groupByFields}
        HAVING
            {having}
        """
        else:
            # Good-event counters are reported as constant zeros.
            return """
        SELECT
            {selectFields},
            sum(if( EventDate = '{currentDate}', result, 0 )) as current_result,
            sum(if( EventDate = '{previousDate}', result, 0 )) as previous_result,
            0 as current_good_result,
            0 as previous_good_result
        FROM
        (
            SELECT
                {subSelectFields},
                '{currentDate}' as EventDate,
                count(*) as result
            FROM
                `{tableName}/{currentDate}`
            {whereClause}
            GROUP BY
                {groupByFields}

            UNION ALL

            SELECT
                {subSelectFields},
                '{previousDate}' as EventDate,
                count(*) as result
            FROM
                `{tableName}/{previousDate}`
            {whereClause}
            GROUP BY
                {groupByFields}
        )
        GROUP BY
            {groupByFields}
        HAVING
            {having}
        """


class YtQueryBuilder(YtQueryBuilderBase):
    """Builder for per-day YT tables.

    The anomaly condition is applied as a ``WHERE`` on a wrapping
    ``SELECT *`` instead of a ``HAVING`` clause.
    """

    def get_table_columns_query(self):
        """Return the path of the ``@schema`` attribute of the current day's table.

        The query templates of this builder read from
        ``{tableName}/{currentDate}``, so the schema is requested for that same
        table rather than for a fixed, hard-coded log path.
        """
        return f'{self._table_name}/{self._current_date}/@schema'

    def _get_query_template(self) -> str:
        """Return the per-column template, with or without good-event counters.

        Each counter sums only the rows tagged with its own date:
        ``previous_good_result`` is taken from the ``'{previousDate}'`` rows.
        """
        if self._select_good_events:
            return """
        SELECT *
        FROM
        (
            SELECT
                {selectFields},
                sum(if( EventDate = '{currentDate}', result, 0 )) as current_result,
                sum(if( EventDate = '{previousDate}', result, 0 )) as previous_result,
                sum(if( EventDate = '{currentDate}', result_good, 0 )) as current_good_result,
                sum(if( EventDate = '{previousDate}', result_good, 0 )) as previous_good_result
            FROM
            (
                SELECT
                    {subSelectFields},
                    '{currentDate}' as EventDate,
                    count(*) as result,
                    countIf({whereClauseGoodEvents}) as result_good
                FROM
                    `{tableName}/{currentDate}`
                {whereClause}
                GROUP BY
                    {groupByFields}

                UNION ALL

                SELECT
                    {subSelectFields},
                    '{previousDate}' as EventDate,
                    count(*) as result,
                    countIf({whereClauseGoodEvents}) as result_good
                FROM
                    `{tableName}/{previousDate}`
                {whereClause}
                GROUP BY
                    {groupByFields}
            )
            GROUP BY
                {groupByFields}
        )
        WHERE
            {having}
        """
        else:
            # Good-event counters are reported as constant zeros.
            return """
        SELECT *
        FROM
        (
            SELECT
                {selectFields},
                sum(if( EventDate = '{currentDate}', result, 0 )) as current_result,
                sum(if( EventDate = '{previousDate}', result, 0 )) as previous_result,
                0 as current_good_result,
                0 as previous_good_result
            FROM
            (
                SELECT
                    {subSelectFields},
                    '{currentDate}' as EventDate,
                    count(*) as result
                FROM
                    `{tableName}/{currentDate}`
                {whereClause}
                GROUP BY
                    {groupByFields}

                UNION ALL

                SELECT
                    {subSelectFields},
                    '{previousDate}' as EventDate,
                    count(*) as result
                FROM
                    `{tableName}/{previousDate}`
                {whereClause}
                GROUP BY
                    {groupByFields}
            )
            GROUP BY
                {groupByFields}
        )
        WHERE
            {having}
        """


class QueryBuilderFactory:
    """Chooses the query builder that matches a report's database."""

    @staticmethod
    def get_query_builder(report_config: dict, current_date, previous_date) -> QueryBuilderBase:
        """Create the builder for ``report_config['database']``.

        Args:
            report_config: report settings; ``'database'`` selects the builder
                and the whole dict is handed to its constructor.
            current_date: current date of the comparison.
            previous_date: previous date of the comparison.

        Raises:
            ValueError: if the database is none of the supported ones.
        """
        requested_database = report_config['database']
        builders_by_database = (
            (Database.ClickHouse, ClickHouseQueryBuilder),
            (Database.ClickHouseOverYT, ClickHouseOverYtQueryBuilder),
            (Database.YT, YtQueryBuilder),
        )
        for supported_database, builder_class in builders_by_database:
            if requested_database == supported_database:
                return builder_class(current_date, previous_date, report_config)

        raise ValueError(
            f"Anomaly Audit Report {report_config['reportName']} with DB {report_config['database']} is not supported")
