# coding: utf-8

from typing import Any
from typing import Iterable
from typing import List
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar

from sandbox import sdk2
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.definitions import AggregationsDefinition
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.definitions import DeleteUnfinitedColumnsDefinition
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.definitions import GetExperimentsDefinition
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.filtrations import Filtrations
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.lib.inputs import HitLogInput
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.lib.inputs import InputFullInterface
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.lib.inputs import get_input_by_name
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.sources import PMatchFailsBitsByExpsSourceBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.sources import SourceBaseInterface
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.sources import StageStatsSourceBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.sources import StageTimingsSourceBase


class SourceFullInterface(SourceBaseInterface):
    """Base for sources that assemble a query from a main log plus auxiliary inputs.

    On construction it resolves the main log input and every auxiliary input that an
    enabled filtration lists in ``LogsNames``, following each input's ``OtherInputs``
    transitively. Subclasses must override ``iter_selects``; they may also override
    ``get_definitions_names`` and the ``get_postfix_*`` hooks.

    NOTE(review): ``MainLogName`` is read from the class but not defined here —
    presumably provided by the concrete ``*SourceBase`` mixin; confirm.
    """

    def __init__(self, task: sdk2.Task):
        self.task: sdk2.Task = task
        self.main_log: InputFullInterface = get_input_by_name(self.MainLogName, self.task)

        # Names of every auxiliary input resolved so far (never contains the main log).
        self.inputs_names: Set[str] = set()
        # Auxiliary inputs in resolution order; drives iter_joins and iter_filtrations.
        self.inputs: List[InputFullInterface] = []

        self._resolve_inputs(self._collect_required_log_names())

    def _iter_enabled_filtrations(self) -> Iterable[Any]:
        """Yield the entries of ``Filtrations`` that are enabled for this task."""
        for filtration in Filtrations:
            if filtration.is_enabled(self.task):
                yield filtration

    def _collect_required_log_names(self) -> Set[str]:
        """Return the log names of all enabled filtrations, excluding the main log."""
        return {
            log
            for filtration in self._iter_enabled_filtrations()
            for log in filtration.LogsNames
            if log != self.MainLogName
        }

    def _resolve_inputs(self, seed_names: Set[str]) -> None:
        """Load *seed_names* and, breadth-first, every input they reach via ``OtherInputs``.

        Each input is loaded and appended at most once. Names within one level are
        processed in sorted order so the resulting join order is deterministic.
        """
        frontier: Set[str] = set(seed_names)
        while frontier:
            next_frontier: Set[str] = set()
            for input_name in sorted(frontier):
                # A name can be queued for the next level by an input processed earlier
                # in this level and then processed later in this level; skip the repeat.
                if input_name in self.inputs_names:
                    continue
                self.inputs_names.add(input_name)
                resolved: InputFullInterface = get_input_by_name(input_name, self.task)
                self.inputs.append(resolved)
                for dependency in resolved.OtherInputs:
                    if dependency != self.MainLogName and dependency not in self.inputs_names:
                        next_frontier.add(dependency)
            frontier = next_frontier

    def get_definitions_names(self) -> Tuple[str, ...]:
        """Return the names of definitions this source needs; none by default."""
        return ()

    def iter_selects(self) -> Iterable[str]:
        """Yield the select expressions of the query; must be overridden."""
        raise NotImplementedError()

    def get_main_log_join(self) -> str:
        """Return the main log input's ``get_join`` fragment for ``MainLogName``."""
        return self.main_log.get_join(self.MainLogName)

    def get_postfix_before_where(self) -> str:
        """Hook for text placed before the WHERE clause; empty by default."""
        return ''

    def get_postfix_after_where(self) -> str:
        """Hook for text placed after the WHERE clause; empty by default."""
        return ''

    def iter_joins(self) -> Iterable[str]:
        """Yield each auxiliary input's ``get_join(MainLogName)`` in resolution order."""
        for aux_input in self.inputs:
            yield f'{aux_input.get_join(self.MainLogName)}'

    def iter_filtrations(self) -> Iterable[str]:
        """Yield conditions from enabled filtrations, then from each auxiliary input."""
        for filtration in self._iter_enabled_filtrations():
            yield from filtration.iter_filtrations(self.task)

        for aux_input in self.inputs:
            yield from aux_input.iter_filtrations(self.MainLogName)


class StageStatsSource(StageStatsSourceBase, SourceFullInterface):
    """Source selecting per-stage statistics packed into one ``DataColumns`` struct.

    Column expressions come from ``stage_stats.AllColumns`` and are wrapped as
    ``StaticMap(AsStruct(<columns>), $<DeleteUnfinitedColumnsDefinition.Name>)``.
    """

    def get_definitions_names(self) -> Tuple[str, ...]:
        # Definitions referenced by this source's query.
        aggregations_name = AggregationsDefinition.Name
        delete_unfinited_name = DeleteUnfinitedColumnsDefinition.Name
        return aggregations_name, delete_unfinited_name

    def iter_data_columns(self) -> Iterable[str]:
        # Imported lazily inside the method — presumably to avoid an import cycle; verify.
        from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.lib.stage_stats import AllColumns
        yield from (column.get_select(self.MainLogName) for column in AllColumns)

    def iter_selects(self) -> Iterable[str]:
        joined_columns = ', '.join(self.iter_data_columns())
        mapper_name = DeleteUnfinitedColumnsDefinition.Name
        yield f'StaticMap(AsStruct({joined_columns}), ${mapper_name}) AS DataColumns'
        yield from ('StageName', 'PlaceSelect')
        yield f'{self.MainLogName}.iso_eventtime AS iso_eventtime'


class StageTimingsSource(StageTimingsSourceBase, SourceFullInterface):
    """Source summing stage durations grouped by hit, request, stat group, stage and HardHit."""

    def iter_selects(self) -> Iterable[str]:
        # These bare names are the aliases introduced by the GROUP BY clause below.
        yield from ('HitLogID', 'RequestID', 'StatGroup', 'Stage', 'HardHit')
        duration = f'{self.MainLogName}.DurationMicroseconds'
        # Durations of 4e9 microseconds or more contribute 0 to the sum
        # (presumably to drop bogus outliers — confirm).
        yield f'SUM(IF({duration} < 4e9, {duration}, 0)) as DurationMicroseconds'

    def get_postfix_after_where(self) -> str:
        log = self.MainLogName
        # HardHit is true for rows whose SamplingRate is at least 7500.
        grouping_keys = (
            f'{log}.HitLogID AS HitLogID',
            f'{log}.RequestID AS RequestID',
            f'{log}.StatGroup AS StatGroup',
            f'{log}.Stage AS Stage',
            f'{log}.SamplingRate >= 7500 AS `HardHit`',
        )
        return 'GROUP BY ' + ',\n    '.join(grouping_keys)


class PMatchFailsBitsByExpsSource(PMatchFailsBitsByExpsSourceBase, SourceFullInterface):
    """Source emitting PMatchFailBits per hit, page label and experiment ID.

    Rows are flattened by the comma-split ``pagelabels`` and by the result of the
    ``GetExperimentsDefinition`` function applied to ``activetestids``.
    """

    def get_definitions_names(self) -> Tuple[str, ...]:
        return (GetExperimentsDefinition.Name,)

    def iter_selects(self) -> Iterable[str]:
        log = self.MainLogName
        yield f'CAST({log}.hitlogid AS UINT64) AS HitLogID'
        yield from ('PageLabel', 'ExpID')
        yield f'CAST({log}.pmatchfailbits AS UINT64) AS PMatchFailBits'

    def get_postfix_before_where(self) -> str:
        # NOTE(review): this clause uses HitLogInput.Name while iter_selects uses
        # MainLogName — presumably they are equal for this source; confirm.
        log = HitLogInput.Name
        experiments_fn = GetExperimentsDefinition.Name
        return (
            'FLATTEN BY (\n'
            f"    String::SplitToList({log}.`pagelabels`, ',') AS PageLabel,\n"
            f'    ${experiments_fn}({log}.`activetestids`) AS ExpID)'
        )


# Type variable bounded by SourceFullInterface.
SourceFullInterfaceSubclassType = TypeVar('SourceFullInterfaceSubclassType', bound=SourceFullInterface)

# Every concrete source class that can be looked up by name.
Sources: Tuple[Type[SourceFullInterface], ...] = (
    StageStatsSource,
    StageTimingsSource,
    PMatchFailsBitsByExpsSource,
)

# Lookup table from each source class's ``Name`` to the class itself. If two classes
# shared a ``Name``, the later one in ``Sources`` would silently replace the earlier.
SourcesDict = {
    source.Name: source
    for source in Sources
}


def get_source_by_name(name: str) -> Type[SourceFullInterface]:
    """Return the source class registered under *name* in ``SourcesDict``.

    :param name: the ``Name`` of a source class listed in ``Sources``.
    :returns: the matching source class.
    :raises KeyError: if no source with that name is registered; the message lists the
        registered names.
    """
    try:
        return SourcesDict[name]
    except KeyError:
        known_names = ', '.join(sorted(str(key) for key in SourcesDict))
        raise KeyError(f'Unknown source {name!r}; known sources: {known_names}') from None
