from textwrap import dedent

from Config import ALL_ACTIONS, Reducer


class SQLGenerator:
    """Generate a YQL script that aggregates AppMetrica events described by a config.

    The script is built eagerly in ``__init__`` and consists of:

    * ``USE hahn;`` and an ``$events`` named expression that selects, from the
      30-minute event log starting at ``start_table``, every ``ru.yandex.music``
      event whose name is referenced by ``config.events``;
    * for each distinct event mapper of the *projected* events (those whose
      extract mapper is not ``['']``), a ``$projection_<name>`` expression that
      reads their JSON fields from ``value``, followed by a ``SELECT`` that
      aggregates them per ``data_table_name``;
    * if any *non-projected* events exist, one ``SELECT`` counting them per
      ``data_table_name``.

    ``tables`` lists, for every aggregating ``SELECT`` in the order it appears
    in ``sql``, the names of its output columns (``data_table_name`` excluded).
    Event and section names are emitted in sorted order, so the same config
    always yields the same script.
    """

    def __init__(self, config, start_table):
        """
        :param config: object exposing ``events`` (event descriptions to
            aggregate) and ``percentile`` (fractions used by PERCENTILE reducers).
        :param start_table: name of the first log table passed to ``RANGE``.
        """
        self._config = config
        self._result_tables = []
        self._sql = self._generate_sql(start_table)

    @property
    def tables(self):
        """Output column names of each aggregating query, in query order."""
        return self._result_tables

    @property
    def sql(self):
        """The generated script."""
        return self._sql

    # ===== global

    def _generate_sql(self, start_table):
        """Concatenate the initial selection, the projected and the non-projected queries."""
        result = self._generate_initial_selection(start_table)
        result += self._generate_all_projections()
        result += self._generate_non_projected()
        return result

    def _generate_initial_selection(self, start_table):
        """Return ``USE hahn;`` plus the ``$events`` named expression.

        Event names are sorted: iterating a ``set`` of strings is not stable
        between interpreter runs (hash randomization), which made the emitted
        SQL differ from run to run.
        """
        event_names = sorted({
            field
            for event in self._config.events
            for field in event.event_mapper.fields
        })

        # .format() runs BEFORE dedent(), so the join separator must carry the
        # same indentation as the {event_sources} placeholder line.
        return dedent("""
            USE hahn;

            $events = (
                SELECT
                    TableName()                 AS data_table_name,
                    EventName                   AS key,
                    CAST (EventValue as Json)   AS value
                FROM RANGE(`//logs/appmetrica-events-log/appmetrica-yandex-events/30min`, "{start_table}")
                WHERE AppID = "ru.yandex.music" AND EventName IN (
                    {event_sources}
                )
            );
        """.format(
            start_table=start_table,
            event_sources=",\n                    ".join(['"%s"' % x for x in event_names]),
        ))

    # ===== helpers

    @staticmethod
    def _is_projected(event):
        """True when ``event`` extracts fields from the payload (extract mapper is not ``['']``)."""
        return event.extract_mapper != ['']

    def _projected_events(self, section_name):
        """Projected events whose event mapper equals ``[section_name]``.

        Non-projected events are excluded even when they share the event name:
        they have nothing to extract, and their reducer is rejected by
        ``_reduced_columns``, so including them made generation raise.
        """
        return [
            e for e in self._config.events
            if self._is_projected(e) and e.event_mapper == [section_name]
        ]

    @staticmethod
    def _percentile_suffix(p):
        """Column-name suffix for percentile ``p`` given as a fraction.

        ``int(p * 100)`` truncated float error (``0.29 * 100`` is
        ``28.999999999999996`` -> ``28``), which could also make two columns
        share one alias. ``%g`` rounds to 6 significant digits instead:
        0.29 -> "29", 0.5 -> "50", 0.995 -> "99_5" (the dot is replaced so the
        suffix stays usable inside an identifier).
        """
        return ('%g' % (p * 100)).replace('.', '_')

    @staticmethod
    def _render_expression(tokens, render_operand, target_name):
        """Join ``tokens`` with spaces into ``<expr> AS <target_name>``.

        Tokens found in ``ALL_ACTIONS`` are copied verbatim; every other token
        is rendered with ``render_operand``.
        """
        parts = [token if token in ALL_ACTIONS else render_operand(token) for token in tokens]
        return ' '.join(parts) + ' AS ' + target_name

    # ===== projections

    def _generate_all_projections(self):
        """Emit a projection and its aggregating SELECT for each projected section."""
        section_names = sorted({
            str(e.event_mapper) for e in self._config.events if self._is_projected(e)
        })
        # Projection first, then its SELECT: the SELECT appends to ``tables``.
        return ''.join(
            self._generate_projection(name) + self._generate_final_projection_select(name)
            for name in section_names
        )

    def _generate_projection(self, section_name):
        """Return ``$projection_<section_name>`` extracting each event's JSON fields."""
        fields = [self._extract_json_field_sql(e) for e in self._projected_events(section_name)]

        return dedent("""
            $projection_{name} = (
                SELECT
                    data_table_name,
                    {fields}
                FROM $events
                WHERE key = "{name}"
            );
        """.format(
            name=section_name,
            fields=",\n                    ".join(fields),
        ))

    def _generate_final_projection_select(self, section_name):
        """Return the SELECT aggregating ``$projection_<section_name>`` and record its columns."""
        columns = [
            column
            for event in self._projected_events(section_name)
            for column in self._reduced_columns(event)
        ]
        self._result_tables.append([name for name, _ in columns])
        sql_fields = ['%s as %s' % (expression, name) for name, expression in columns]

        return dedent("""
            SELECT
                data_table_name,
                {fields}
            FROM $projection_{name}
            GROUP BY data_table_name;
        """.format(
            name=section_name,
            fields=",\n                ".join(sql_fields),
        ))

    def _reduced_columns(self, e):
        """Return ``(column_name, sql_expression)`` pairs aggregating event ``e``.

        :raises RuntimeError: if ``e.reducer`` is not ``Reducer.PERCENTILE``.
        """
        if e.reducer != Reducer.PERCENTILE:
            raise RuntimeError('Unhandled reducer: ' + e.reducer.name)
        return [
            ('%s_%s' % (e.target_name, self._percentile_suffix(p)),
             'PERCENTILE(%s, %s)' % (e.target_name, p))
            for p in self._config.percentile
        ]

    @staticmethod
    def _extract_json_field_sql(event):
        """Projection column for ``event``: payload fields read via ``JSON_VALUE`` as Float."""
        return SQLGenerator._render_expression(
            event.extract_mapper.all_fields,
            lambda token: 'JSON_VALUE(value, "$.%s" RETURNING Float)' % token,
            event.target_name,
        )

    # ===== non-projections

    def _generate_non_projected(self):
        """Return the SELECT counting non-projected events, or "" if there are none."""
        events = [e for e in self._config.events if not self._is_projected(e)]
        if not events:
            # The template would render "data_table_name," with no field after
            # it (invalid SQL). Skip both the query and its ``tables`` entry so
            # the two stay aligned.
            return ''
        sql_fields = [self._extract_count_sql(e) for e in events]
        self._result_tables.append([e.target_name for e in events])
        return dedent("""
            SELECT
                data_table_name,
                {fields}
            FROM $events
            GROUP BY data_table_name;
        """.format(
            fields=",\n                ".join(sql_fields),
        ))

    @staticmethod
    def _extract_count_sql(event):
        """Counting column for ``event``: ``COUNT_IF`` over each referenced event name.

        :raises RuntimeError: if ``event.reducer`` is not ``Reducer.COUNT``.
        """
        if event.reducer != Reducer.COUNT:
            raise RuntimeError('Bad reducer: ' + event.reducer.name)

        return SQLGenerator._render_expression(
            event.event_mapper.all_fields,
            lambda token: 'COUNT_IF(key == "' + token + '")',
            event.target_name,
        )
