from __future__ import unicode_literals
import os
import datetime

from infra.dproxy.proto import dproxy_pb2
import ydb


class TablesGeneratorBase(object):
    """Iterator yielding one built query per (table, start, end) work item.

    The work list (``self.cache``) is computed once, at construction time, by
    ``fill_cache`` from the entries listed under ``table_path``. Subclasses
    choose which scheme entry types count as tables (``get_table_types``) and
    how the found tables are turned into work items (``generate_cache``).

    Args:
        log: logger used for debug messages.
        scheme_client: client whose ``list_directory(path)`` result exposes
            ``children`` with ``name`` and ``type`` attributes.
        min_date, max_date: overall bounds of the requested range.
        dt: cursor inside the range; for ascending order the range
            ``[dt, max_date]`` is used, otherwise ``[min_date, dt]``.
        table_path: directory whose children are inspected.
        query_builder: builder with ``copy``/``with_table``/``with_timestamp``/
            ``build`` used to produce each query.
        order: sort order, compared against ``dproxy_pb2.ASC``.
    """

    def __init__(
        self,
        log,
        scheme_client,
        min_date,
        max_date,
        dt,
        table_path,
        query_builder,
        order,
    ):
        self.log = log
        self.scheme_client = scheme_client
        self.min_date = min_date
        self.max_date = max_date
        self.dt = dt
        self.table_path = table_path
        self.query_builder = query_builder
        self.order = order

        # List of (table_name, start, end) tuples; __next__ consumes it from the end.
        self.cache = None
        self.fill_cache()

    def __iter__(self):
        return self

    def __next__(self):
        """Build and return the query for the next work item.

        Raises:
            StopIteration: when no work items are left (or the cache is None).
        """
        if not self.cache:
            raise StopIteration()
        table_name, start, end = self.cache.pop()
        return (self.query_builder
                .copy()
                .with_table(table_name)
                .with_timestamp((start, end), self.order)
                .build()
                )

    # Python 2 iterator protocol.
    next = __next__

    def get_table_types(self):
        """Return the scheme entry types treated as tables (none in the base class)."""
        return []

    def fill_cache(self):
        """Populate ``self.cache`` from the tables found under ``table_path``.

        A ``SchemeError`` from listing the directory is treated as a missing
        path and results in an empty cache.
        """
        try:
            listing = self.scheme_client.list_directory(self.table_path)
        except ydb.issues.SchemeError:
            # table_path doesn't exist
            self.cache = []
            self.log.debug("loaded 0 tables, as path doesn't exist")
            return

        # Evaluate once instead of once per directory child.
        table_types = self.get_table_types()
        tables = [
            child.name
            for child in listing.children
            if child.type in table_types
        ]

        asc = self.order == dproxy_pb2.ASC
        if asc:
            min_date = self.dt
            max_date = self.max_date
        else:
            min_date = self.min_date
            max_date = self.dt

        self.cache = self.generate_cache(tables, min_date, max_date, asc)

    def generate_cache(
        self,
        tables,
        min_date,
        max_date,
        asc,
    ):
        """Turn table names into a list of (table_name, start, end) work items.

        Subclasses override this; the base implementation returns None, which
        ``__next__`` treats as an exhausted iterator.
        """
        return None


class DateTablesGenerator(TablesGeneratorBase):
    """Tables generator for a directory holding one table per day.

    Tables are expected to be named with ``TABLE_FMT`` (``YYYY-MM-DD``). Every
    day table overlapping the requested range is split into sub-intervals
    taken from ``query_builder.IN_TABLE_INTERVALS``, and one query is
    produced per sub-interval.
    """

    TABLE_FMT = "%Y-%m-%d"

    @staticmethod
    def make_table_name_from_dt(dt):
        """Return the day-table name for ``dt`` formatted with ``TABLE_FMT``."""
        return dt.strftime(DateTablesGenerator.TABLE_FMT)

    @staticmethod
    def try_parse_table_name(name):
        """Parse a day-table name into a midnight ``datetime``, or return None.

        ``strptime`` also accepts non-zero-padded input such as ``2020-1-5``.
        Queries are addressed to the name rebuilt from the parsed date
        (``2020-01-05``), so names that do not round-trip are rejected rather
        than silently mapped to a different (or duplicate) table.
        """
        try:
            parsed = datetime.datetime.strptime(name, DateTablesGenerator.TABLE_FMT)
            # Inside the try as well: Python 2's strftime rejects years < 1900.
            canonical = DateTablesGenerator.make_table_name_from_dt(parsed)
        except ValueError:
            return None
        return parsed if canonical == name else None

    def get_table_types(self):
        """Only regular (row) tables are considered."""
        return [ydb.SchemeEntryType.TABLE]

    def generate_cache(
        self,
        tables,
        min_date,
        max_date,
        asc,
    ):
        """Build the (table_name, start, end) work list for the day tables.

        Args:
            tables: names found under ``table_path``; names that are not
                canonical ``TABLE_FMT`` dates are ignored.
            min_date, max_date: datetime bounds of the range to cover.
            asc: True to query from earliest to latest, False for the reverse.

        Returns:
            A list consumed from its end by ``__next__`` (``pop()``), so the
            first query to run is the last element.
        """
        # well, we have to rely on useful date format here
        table_dates = [
            parsed
            for parsed in (DateTablesGenerator.try_parse_table_name(t) for t in tables)
            if parsed is not None
        ]
        # Sorted in query order (earliest first when ascending); the finished
        # list is reversed at the end so that pop() returns items in this order.
        days = sorted(
            (t for t in table_dates if min_date.date() <= t.date() <= max_date.date()),
            reverse=not asc,
        )

        table_names_for_log = [DateTablesGenerator.make_table_name_from_dt(day) for day in days[:3]]
        self.log.debug(
            "loaded %d tables: %s%s",
            len(days),
            table_names_for_log,
            '...' if len(days) > 3 else '',
        )

        # NOTE(review): sub-intervals are produced only while IN_TABLE_INTERVALS
        # lasts; if those intervals sum to less than the day span, the rest of
        # the day is not queried. Presumably the list is built to cover a full
        # day — confirm.
        cache = []
        for day in days:
            # Clip the day [00:00, 23:59:59.999999] to the requested range.
            min_table_dt = max(day, min_date)
            max_table_dt = min(day + datetime.timedelta(days=1) - datetime.timedelta(microseconds=1), max_date)

            table_name = DateTablesGenerator.make_table_name_from_dt(day)

            self.log.debug("for date %s making interval %s .. %s", day, min_table_dt, max_table_dt)
            if asc:
                start = min_table_dt
                for interval in self.query_builder.IN_TABLE_INTERVALS:
                    if start > max_table_dt:
                        break
                    cache.append((table_name, start, min(start + interval, max_table_dt)))
                    start += interval
            else:
                end = max_table_dt
                for interval in self.query_builder.IN_TABLE_INTERVALS:
                    if end < min_table_dt:
                        break
                    cache.append((table_name, max(end - interval, min_table_dt), end))
                    end -= interval

        return list(reversed(cache))


class OlapTablesGenerator(TablesGeneratorBase):
    """Tables generator for a directory holding a single column table.

    The whole requested range is issued as one work item against
    ``TABLE_NAME``.
    """

    TABLE_NAME = 'log'

    def get_table_types(self):
        """Only column tables are considered."""
        return [ydb.SchemeEntryType.COLUMN_TABLE]

    def generate_cache(self, tables, min_date, max_date, asc):
        """Return a single (TABLE_NAME, min_date, max_date) work item.

        NOTE(review): ``tables`` and ``asc`` are not consulted, so the item is
        produced even if no column table named TABLE_NAME was listed — confirm
        that this is intended.
        """
        whole_range = (OlapTablesGenerator.TABLE_NAME, min_date, max_date)
        return [whole_range]


def make_table_path_prefix(db, fragments, table_prefix=None):
    """Join ``db``, an optional ``table_prefix`` and ``fragments`` into a path.

    An empty or None ``table_prefix`` is skipped entirely.
    """
    parts = [db]
    if table_prefix:
        parts.append(table_prefix)
    parts.extend(fragments)
    return os.path.join(*parts)


def escape_dots(key):
    """Return ``key`` with every ``.`` preceded by a backslash."""
    return "\\.".join(key.split("."))


def keys_of_dict(
    d,
    prefix='',
    escaped_prefix='',
    depth=0,
    query_prefix='',
):
    """Yield dot-joined key paths of nested dict ``d`` that start with ``query_prefix``.

    Dots inside individual keys are escaped in the yielded paths (via
    ``escape_dots``), while matching against ``query_prefix`` is done on the
    unescaped path. Sub-dicts whose path cannot lead to a match are skipped,
    and recursion stops beyond depth 10.

    Args:
        d: dict to walk.
        prefix: unescaped path of ``d`` itself (``''`` or ending with ``.``).
        escaped_prefix: escaped counterpart of ``prefix``.
        depth: current recursion depth.
        query_prefix: only paths starting with this string are yielded.
    """
    if depth > 10:  # protection from excessive recursion
        return

    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for key, val in d.items():
        key_full = prefix + key
        escaped_key_full = escaped_prefix + escape_dots(key)
        if key_full.startswith(query_prefix):
            yield escaped_key_full

        if isinstance(val, dict):
            new_prefix = key_full + '.'
            new_escaped_prefix = escaped_key_full + '.'

            # Descend only if query_prefix and new_prefix agree on their common
            # length, i.e. some key below this branch could still match.
            common = min(len(query_prefix), len(new_prefix))
            if query_prefix[:common] == new_prefix[:common]:
                # query_prefix must be passed down, otherwise every key below
                # this branch would be yielded unfiltered.
                for k in keys_of_dict(
                    val,
                    prefix=new_prefix,
                    escaped_prefix=new_escaped_prefix,
                    depth=depth + 1,
                    query_prefix=query_prefix,
                ):
                    yield k
