# coding: utf-8

from typing import Iterable
from typing import TypeVar
from typing import Tuple

from sandbox import sdk2
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import DictionariesBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import HitLogBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import LogsBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import PageDictionaryBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import SAMPLERS_DEFAULT_TEXT
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import StageStatsLogBase
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.inputs import StageTimingsLogBase
from sandbox.projects.yabs.yt_logfeller_sampler import YtLogFellerSamplers


class InputFullInterface(object):
    """Base interface of an input that contributes fragments to a generated query.

    An instance is bound to a task; subclasses render the input's source/join
    fragment and optional filter conditions relative to the query's main log.
    """

    # Names of other inputs this input refers to; empty by default.
    # NOTE(review): how callers consume this tuple is not visible here - confirm.
    # Annotated as a variable-length tuple: ``Tuple[str]`` would mean "exactly
    # one element", which the empty default does not satisfy.
    OtherInputs: Tuple[str, ...] = ()

    def __init__(self, task: sdk2.Task):
        """Store the task this input is bound to."""
        self.task = task

    def get_join(self, main_log: str) -> str:
        """Return this input's source fragment relative to ``main_log``.

        Raises:
            NotImplementedError: always; subclasses are expected to override.
        """
        raise NotImplementedError()

    def iter_filtrations(self, main_log: str) -> Iterable[str]:
        """Yield filter conditions for this input; the base yields nothing."""
        yield from ()


class LogFull(InputFullInterface, LogsBase):
    """Log input rendered as a ``RANGE(...)`` over sampler tables."""

    def get_prefix(self) -> str:
        """Return the task parameter named by ``get_prefix_field()``."""
        parameters = self.task.Parameters
        prefix_field = self.get_prefix_field()
        return getattr(parameters, prefix_field)

    def get_range(self) -> str:
        """Build the ``RANGE(folder, from_table, to_table)`` expression.

        The sampler is looked up on ``YtLogFellerSamplers`` by the name stored
        in the task parameters; the placeholder ``SAMPLERS_DEFAULT_TEXT``
        selects ``DefaultSamplerName`` instead. Table names are derived from
        the task's parsed begin and end times.
        """
        parameters = self.task.Parameters
        requested_name = getattr(parameters, self.get_sampler_name())
        sampler_attr = (
            self.DefaultSamplerName
            if requested_name == SAMPLERS_DEFAULT_TEXT
            else requested_name
        )
        sampler = getattr(YtLogFellerSamplers, sampler_attr)

        first_table = sampler.get_table_name(self.task.get_parsed_begin_time())
        last_table = sampler.get_table_name(self.task.get_parsed_end_time())
        folder = sampler.get_folder_name(self.get_prefix())

        return f'RANGE(`{folder}`, "{first_table}", "{last_table}")'

    def get_on(self, main_log: str) -> str:
        """Return the join condition against ``main_log``.

        The base implementation only knows the case where this log is itself
        the main log (no condition needed).

        Raises:
            NotImplementedError: if ``main_log`` is any other input.
        """
        if main_log != self.Name:
            raise NotImplementedError()
        return ''

    def get_join(self, main_log: str) -> str:
        """Return ``<range> AS <Name> <on-clause>`` for this log.

        The on-clause is left empty when this log is the main log, so
        ``get_on`` is only consulted for a real join.
        """
        source = self.get_range()
        alias = self.Name
        if main_log != self.Name:
            on_clause = self.get_on(main_log)
        else:
            on_clause = ""
        return f'{source} AS {alias} {on_clause}'


class DictionariesFull(InputFullInterface, DictionariesBase):
    """Dictionary-table input rendered as a backtick-quoted table path."""

    def get_on(self):
        """Return the join condition for this dictionary.

        Raises:
            NotImplementedError: always; concrete dictionaries must override.
        """
        raise NotImplementedError()

    def get_prefix(self) -> str:
        """Return the task parameter holding this dictionary's table path.

        ``get_table_path_field`` is resolved to a parameter name first: if it
        is a method (as the analogous ``get_prefix_field`` is for logs) it is
        called, otherwise its value is used directly. Passing a bound method
        straight to ``getattr`` would raise ``TypeError``.
        """
        parameters = self.task.Parameters
        field_name = self.get_table_path_field
        if callable(field_name):
            field_name = field_name()
        return getattr(parameters, field_name)

    def get_join(self, main_log: str) -> str:
        """Return ``<table> AS <Name> <on-clause>`` for this dictionary.

        The on-clause is omitted when this input is itself the main log.
        """
        table = self.get_prefix()
        alias = self.Name
        on_clause = self.get_on() if main_log != self.Name else ""
        return f'`{table}` AS {alias} {on_clause}'


class PageInput(DictionariesFull, PageDictionaryBase):
    """Page dictionary input, joined to the hit log by page id."""

    # The join condition below references the hit log by name.
    OtherInputs = (HitLogBase.Name,)

    def get_on(self):
        """Return the ON clause matching ``PageID`` to the hit log's ``pageid``."""
        page_alias = self.Name
        hit_log_alias = HitLogBase.Name
        return f'ON {page_alias}.PageID == CAST({hit_log_alias}.pageid AS INT64)'


class HitLogInput(LogFull, HitLogBase):
    """Hit log input; joinable only to the stage-stats or stage-timings log."""

    def get_on(self, main_log: str) -> str:
        """Return the ON clause joining the hit log to ``main_log``.

        Rows are matched on hit log id (cast to UINT64) and request id.

        Raises:
            RuntimeError: if ``main_log`` is neither the stage-stats nor the
                stage-timings log.
        """
        if main_log not in (StageStatsLogBase.Name, StageTimingsLogBase.Name):
            raise RuntimeError(f"HitLog can't been joined with {main_log}")

        alias = self.Name
        return (
            f'ON {main_log}.HitLogID = CAST({alias}.hitlogid AS UINT64)'
            f' AND {main_log}.RequestID = {alias}.requestid'
        )

    def iter_filtrations(self, main_log: str) -> Iterable[str]:
        """Yield the ``options LIKE ...`` filter that goes with ``main_log``.

        This is a generator, so the error below surfaces on first iteration
        rather than at call time.

        Raises:
            RuntimeError: if ``main_log`` is neither the stage-stats nor the
                stage-timings log.
        """
        if main_log == StageStatsLogBase.Name:
            pattern = '"%stage-stats-log-written%"'
        elif main_log == StageTimingsLogBase.Name:
            pattern = '"%stage-timings-log-written%"'
        else:
            raise RuntimeError(f"HitLog can't been joined with {main_log}")

        yield f'{self.Name}.options LIKE {pattern}'


class StageStatsInput(LogFull, StageStatsLogBase):
    """Stage-stats log input; it defines no join condition to any other log."""

    def get_on(self, main_log: str) -> str:
        """Reject any attempt to join this log to ``main_log``.

        Raises:
            RuntimeError: always. The message names this input (the previous
                text was copied from the hit log and blamed "HitLog").
        """
        raise RuntimeError(f"{self.Name} can't be joined with {main_log}")


class StageTimingsInput(LogFull, StageTimingsLogBase):
    """Stage-timings log input; it defines no join condition to any other log."""

    def get_on(self, main_log: str) -> str:
        """Reject any attempt to join this log to ``main_log``.

        Raises:
            RuntimeError: always. The message names this input (the previous
                text was copied from the hit log and blamed "HitLog").
        """
        raise RuntimeError(f"{self.Name} can't be joined with {main_log}")


# Type variable bounded by InputFullInterface (i.e. it or any subclass).
InputsFullInterfaceSubclassType = TypeVar(
    'InputsFullInterfaceSubclassType',
    bound=InputFullInterface,
)

# All concrete input classes registered in this module.
Inputs = (PageInput, StageStatsInput, StageTimingsInput, HitLogInput)

# Lookup table: input name -> input class.
InputsDict = {input_cls.Name: input_cls for input_cls in Inputs}


def get_input_by_name(name: str, task: sdk2.Task) -> InputFullInterface:
    """Instantiate the input class registered under ``name`` for ``task``.

    Raises:
        KeyError: if no input with that name is present in ``InputsDict``.
    """
    input_cls = InputsDict[name]
    return input_cls(task)
