import datetime
import logging

import yt.wrapper as yt

from crypta.lab.proto import constructor_pb2
from crypta.lib.python.yt import (
    schema_utils,
    yt_helpers,
)
from crypta.profile.runners.segments.lib.constructor_segments import build_constructor_segments
from crypta.profile.runners.segments.lib.constructor_segments.common import utils


logger = logging.getLogger(__name__)

# Rule-condition source types that calc_rule is able to evaluate. The set holds
# enum *names* (strings such as "METRICA_SITES"), not numeric enum values; a
# rule condition whose ``source`` is not in this set is skipped with a warning.
WHITELIST = set(map(
    constructor_pb2.RuleCondition.Source.Name,
    (
        constructor_pb2.RuleCondition.Source.METRICA_SITES,
        constructor_pb2.RuleCondition.Source.BROWSER_SITES,
        constructor_pb2.RuleCondition.Source.YANDEX_REFERRER,
        constructor_pb2.RuleCondition.Source.SITES,
        constructor_pb2.RuleCondition.Source.PUBLIC_SITES,

        constructor_pb2.RuleCondition.Source.METRICA_TITLES,
        constructor_pb2.RuleCondition.Source.BROWSER_TITLES,
        constructor_pb2.RuleCondition.Source.SEARCH_REQUESTS,
        constructor_pb2.RuleCondition.Source.WORDS,
        constructor_pb2.RuleCondition.Source.PUBLIC_WORDS,

        constructor_pb2.RuleCondition.Source.SEARCH_RESULTS_HOSTS,

        constructor_pb2.RuleCondition.Source.APPS,
        constructor_pb2.RuleCondition.Source.PRECALCULATED_TABLES,
    ),
))

# Name of the boolean attribute written on every output table; it records
# whether the table uses the aggregated schema (True) or the daily one (False).
AGGREGATED = "aggregated"
# How long to wait for a YT lock: one hour, expressed in milliseconds
# (passed as ``wait_for`` to ``yt_client.lock``).
TIMEOUT = 60 * 60 * 1000


def process_task(yt_client, rule_table, compute):
    """Materialize ``rule_table`` with ``compute`` unless it already exists.

    The parent directory of ``rule_table`` is created first. Inside a
    transaction, a waitable shared lock is taken on that directory with
    ``child_key`` set to the table name, waiting up to ``TIMEOUT``. This is
    presumably meant to serialize concurrent builders of the same table (it
    relies on YT child-key lock semantics). Existence is then checked again
    under the lock: an existing table is left untouched; otherwise
    ``compute(rule_table, tx)`` is called and the table gets a one-day TTL.

    :param yt_client: YT client used for mkdir, locking, existence check and TTL.
    :param rule_table: YPath of the table to build.
    :param compute: callable ``(table_path, transaction)`` that writes the table.
    :return: ``rule_table``, whether it was computed now or found existing.
    """
    parent_dir, child_name = yt.ypath_split(rule_table)
    yt_client.mkdir(parent_dir, recursive=True)

    with yt_client.Transaction() as transaction:
        yt_client.lock(
            parent_dir,
            mode="shared",
            child_key=child_name,
            waitable=True,
            wait_for=TIMEOUT,
        )
        already_built = yt_client.exists(rule_table)
        if not already_built:
            compute(rule_table, transaction)
            yt_helpers.set_ttl(
                table=rule_table,
                ttl_timedelta=datetime.timedelta(days=1),
                yt_client=yt_client,
            )

    return rule_table


def calc_rule(yt_client, rule_condition, rule_dir, output_dir):
    """Build (or reuse) the result table for a single constructor rule condition.

    Results are keyed by ``rule_condition.revision``. Every task produced for the
    condition writes ``<rule_dir>/<TaskClassName>/<revision>``, and those tables
    are merged into ``<output_dir>/<revision>``. Tables that already exist are
    reused instead of being recomputed. The output table is written under a
    waitable child-key lock on ``output_dir``. It gets a one-day TTL and an
    ``AGGREGATED`` boolean attribute that tells which schema it uses.

    :param yt_client: YT client used for existence checks, locking and merging.
    :param rule_condition: the rule condition; ``revision``, ``source`` and
        ``ruleId`` are read. ``source`` is looked up in ``WHITELIST``, which
        holds enum *names*, so it is presumably a string here -- verify against
        callers.
    :param rule_dir: directory for the per-task intermediate tables.
    :param output_dir: directory for the merged per-revision output tables.
    :return: path of the output table, or ``None`` when the source is not
        whitelisted.
    :raises IndexError: if a daily task has to be computed but its
        ``index_dir`` contains no tables.
    """
    revision = str(rule_condition.revision)
    output_table = yt.ypath_join(output_dir, revision)

    # Cheap lock-free check; it is repeated under the lock before writing.
    if yt_client.exists(output_table):
        return output_table

    if rule_condition.source not in WHITELIST:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning("Source %s is not supported", rule_condition.source)
        return None

    config = build_constructor_segments.ConstructorSegmentsConfig(logger, yt_client, None)
    config.add_rule_conditions([rule_condition], rule_condition.ruleId)
    daily_tasks, aggregate_tasks = config.prepare_rules(update_rejected=False)

    rule_tables = []
    # Only one family of tasks is materialized (daily wins if both are present).
    # The output schema must therefore follow the branch actually taken.
    # Deriving it from ``aggregate_tasks`` alone would pick the aggregated schema
    # for a merge of daily tables.
    is_aggregated = False

    if daily_tasks:
        for task_cls, kwargs in daily_tasks.items():
            rule_table = yt.ypath_join(rule_dir, task_cls.__name__, revision)
            rule_tables.append(rule_table)

            if yt_client.exists(rule_table):
                continue

            # Input is the lexicographically greatest table in the task's index
            # directory (presumably the most recent one -- confirm naming scheme).
            index_tables = yt_client.list(task_cls.index_dir, absolute=True)
            if not index_tables:
                raise IndexError("No index tables found in %s" % task_cls.index_dir)
            input_table = max(index_tables)

            # NOTE(review): tasks are built with a dummy date; presumably the
            # date is unused on this code path -- confirm.
            task = task_cls(date="placeholder", **kwargs)
            task.compute_pre_tx(input_table, rule_table)
            # Default arguments bind task/input_table now, not at call time.
            process_task(
                task.yt,
                rule_table,
                lambda table, tx, task=task, input_table=input_table: task.compute(input_table, table, tx),
            )

    elif aggregate_tasks:
        is_aggregated = True
        for task_cls, kwargs in aggregate_tasks.items():
            rule_table = yt.ypath_join(rule_dir, task_cls.__name__, revision)
            rule_tables.append(rule_table)

            task = task_cls(date="placeholder", **kwargs)
            # process_task invokes compute(table, tx), which maps directly onto task.compute.
            process_task(task.yt, rule_table, task.compute)

    if is_aggregated:
        schema = utils.aggregated_schema
    else:
        schema = schema_utils.yt_schema_from_dict(utils.daily_schema)

    # The lock below is taken on output_dir itself, so output_dir must exist
    # first (same approach as process_task).
    yt_client.mkdir(output_dir, recursive=True)

    with yt_client.Transaction():
        yt_client.lock(output_dir, mode="shared", child_key=revision, waitable=True, wait_for=TIMEOUT)
        # Another worker may have published the table while we were computing
        # or waiting for the lock; don't redo the merge.
        if yt_client.exists(output_table):
            return output_table

        yt_client.create(
            "table",
            output_table,
            recursive=True,
            ignore_existing=True,
            attributes={"schema": schema},
        )
        if rule_tables:
            yt_client.run_merge(rule_tables, output_table)
        yt_client.set_attribute(output_table, AGGREGATED, is_aggregated)
        yt_helpers.set_ttl(
            table=output_table,
            ttl_timedelta=datetime.timedelta(days=1),
            yt_client=yt_client,
        )

    return output_table
