import logging
import numbers

from bravado import exception
import grpc
import retry
import tvmauth
import yt.clickhouse as chyt
import yt.wrapper as yt

from crypta.lab.rule_estimator.proto import (
    rule_estimate_stats_pb2,
    update_pb2,
)
from crypta.lab.rule_estimator.services.api.proto import (
    api_pb2,
    api_pb2_grpc,
)
from crypta.lab.rule_estimator.services.worker.lib import (
    calc_rule,
    stats,
)
from crypta.lib.proto.identifiers import id_pb2
from crypta.lib.python import (
    templater,
    time_utils,
)
from crypta.lib.python.identifiers.identifiers import GenericID
from crypta.lib.python.worker_utils import worker
from crypta.lib.python.yt import yt_helpers
from crypta.siberia.bin.common.describing.experiment.proto import describing_experiment_pb2
from crypta.siberia.bin.common.describing.mode.python import describing_mode
from crypta.siberia.bin.common.siberia_client import SiberiaClient


# Module-level logger used by run_worker and the Worker methods below.
logger = logging.getLogger(__name__)


def get_sample(yt_client, alias, yuid_tables, aggregate_tables, sample_size):
    """Sample identifiers from rule tables via CHYT and compute their coverage.

    Both queries are rendered from the ``/queries/get_sample.sql`` and
    ``/queries/get_coverage.sql`` templates with the same template variables
    and executed on the CHYT clique ``alias``.

    Args:
        yt_client: YT client used to run the CHYT queries.
        alias: CHYT clique alias.
        yuid_tables: non-aggregated input tables.
        aggregate_tables: aggregated input tables.
        sample_size: passed to the templates as ``sample_size``.

    Returns:
        ``(coverage, ids)``. ``ids`` is a list of ``id_pb2.TId`` built from
        sampled rows whose ``(id_type, id)`` passes ``GenericID.is_valid()``.
        Invalid rows are dropped. ``coverage`` is the ``coverage`` column of
        the first row of the coverage query. When both table lists are
        empty, ``(0, [])`` is returned without running any query.

    Raises:
        IndexError: if the coverage query returns no rows.
    """
    if not (yuid_tables or aggregate_tables):
        return 0, []

    query_vars = dict(
        aggregate_tables=aggregate_tables,
        yuid_tables=yuid_tables,
        sample_size=sample_size,
    )

    def run_query(resource):
        # Render one templated query and execute it on the clique.
        return chyt.execute(
            templater.render_resource(resource, vars=query_vars),
            alias=alias,
            client=yt_client,
        )

    sampled_ids = []
    for row in run_query("/queries/get_sample.sql"):
        id_type, id_value = row["id_type"], row["id"]
        if not GenericID(id_type, id_value).is_valid():
            continue
        sampled_ids.append(id_pb2.TId(Type=id_type, Value=id_value))

    # Materialize the whole result before taking the first row.
    coverage_rows = list(run_query("/queries/get_coverage.sql"))
    return coverage_rows[0]["coverage"], sampled_ids


def run_worker(*args, **kwargs):
    """Construct a ``Worker`` from the given arguments and run it.

    Any ``Exception`` raised while constructing or running the worker is
    logged with its traceback and then swallowed, so the caller never sees
    it. ``KeyboardInterrupt`` and ``SystemExit`` are intentionally not caught,
    so interrupt/exit requests still propagate.
    """
    try:
        # Named ``instance`` so it does not shadow the imported ``worker`` module.
        instance = Worker(*args, **kwargs)
        instance.run()
    except Exception:
        logger.exception("worker failed")


class Worker(worker.Worker):
    """Worker that estimates rules and rule conditions and stores the results.

    ``execute`` routes each task by type:

    * An integral task is a rule condition revision id.
    * An ``update_pb2.Update`` is a rule, given as a set of rule condition ids.

    Estimates are built from a CHYT sample, described by Siberia, and written
    to the rule estimator gRPC API.
    """

    def __init__(self, worker_config):
        """Create the TVM, Siberia, YT and rule estimator API clients.

        Args:
            worker_config: config for the base worker. Its ``context`` is
                unpacked as ``(config, api)``. ``api.lab`` serves rule
                conditions. It is presumably a bravado client, given the
                ``HTTPNotFound`` handling in ``describe_rule_condition``;
                confirm.
        """
        super(Worker, self).__init__(worker_config)
        self.config, self.api = worker_config.context
        # TVM is optional. Without a secret no client is created, and
        # get_tvm_ticket() falls back to an empty ticket.
        self.tvm_client = tvmauth.TvmClient(
            tvmauth.TvmApiClientSettings(
                self_tvm_id=self.config.Tvm.SourceTvmId,
                self_secret=self.config.Tvm.Secret,
                dsts={'siberia': self.config.Siberia.Tvm.DestinationTvmId},
            )
        ) if self.config.Tvm.Secret else None
        self.siberia = SiberiaClient(self.config.Siberia.Host, self.config.Siberia.Port)
        self.yt = yt_helpers.get_yt_client(self.config.Yt.Proxy, self.config.Yt.Pool)
        self.rule_estimator_api = api_pb2_grpc.RuleEstimatorStub(grpc.insecure_channel(self.config.RuleEstimatorApiEndpoint))

    def execute(self, task, labels):
        """Dispatch a single task by its type.

        Args:
            task: an integral rule condition id, or an ``update_pb2.Update``
                whose ``RuleConditionIds`` define a rule.
            labels: dict-like labels, updated in place with the task type
                (and, for rule conditions, the condition source).

        Raises:
            Exception: if ``task`` is of any other type.
        """
        # numbers.Integral covers ``int`` and Python 2 ``long`` without naming
        # ``long``, which does not exist on Python 3.
        if isinstance(task, numbers.Integral):
            labels[stats.TASK_TYPE] = stats.RULE_CONDITION
            self.describe_rule_condition(task, labels)
        elif isinstance(task, update_pb2.Update):
            labels[stats.TASK_TYPE] = stats.RULE
            self.describe_rule(set(task.RuleConditionIds))
        else:
            raise Exception("Unknown task class: {}".format(task.__class__.__name__))

    def get_tvm_ticket(self):
        """Return a TVM service ticket for Siberia, or ``""`` when TVM is disabled."""
        return self.tvm_client.get_service_ticket_for('siberia') if self.tvm_client else ""

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def describe_rule_condition(self, rule_condition_id, labels):
        """Estimate one rule condition and store its stats.

        No new stats are written in two cases:

        * a fresh estimate already exists;
        * ``calc_rule.calc_rule`` returns no table, i.e. the condition is
          unsupported.

        A missing rule condition (``HTTPNotFound`` from the lab API) is
        stored as empty stats. Any other exception is logged and re-raised,
        and the whole call is retried by ``retry.retry`` (up to 5 attempts,
        1 s initial delay, 1.5x backoff).

        Args:
            rule_condition_id: rule condition revision id.
            labels: dict-like labels. ``stats.RULE_CONDITION_TYPE`` is set to
                the condition's ``source``.
        """
        try:
            logger.info("Describing rule condition id: %s", rule_condition_id)
            logger.info("Getting rule condition")
            try:
                rule_condition = self.api.lab.getRuleConditionByRevision(revision=rule_condition_id).result()

                logger.info("Calculating rule")
                labels[stats.RULE_CONDITION_TYPE] = rule_condition.source

                if self.is_rule_condition_estimate_fresh(rule_condition_id):
                    logger.info("Rule condition %s is fresh", rule_condition_id)
                    return

                table = calc_rule.calc_rule(self.yt, rule_condition, self.config.RuleDir, self.config.OutputDir)
                if table is None:
                    logger.info("Rule condition %s is not supported", rule_condition_id)
                    return

                # The table's aggregated flag decides which query input it feeds.
                is_aggregated = self.yt.get_attribute(table, calc_rule.AGGREGATED)
                siberia_stats = self.get_stats(
                    yuid_tables=[table] if not is_aggregated else [],
                    aggregate_tables=[table] if is_aggregated else [],
                )

            except exception.HTTPNotFound:
                # The condition no longer exists: record empty stats instead
                # of letting the error propagate (and be retried).
                logger.exception("Rule condition is missing")
                siberia_stats = self.get_empty_stats()

            self.set_rule_condition_stats(rule_condition_id, siberia_stats)

            logger.info("Rule condition %s is done", rule_condition_id)
        except Exception:
            logger.exception("Failed to describe rule condition: %s", rule_condition_id)
            raise

    def set_rule_condition_stats(self, rule_condition_id, siberia_stats):
        """Store ``siberia_stats`` for ``rule_condition_id`` via the rule estimator API."""
        logger.info("Writing to rule estimator API")
        self.rule_estimator_api.SetRuleConditionStats(api_pb2.SetRuleConditionStatsRequest(
            RuleConditionId=rule_condition_id,
            Stats=siberia_stats,
        ))

    def _is_stats_fresh(self, siberia_stats):
        """Return whether ``siberia_stats`` is ready and younger than ``FreshnessThresholdSec``.

        Age is ``time_utils.get_current_time() - siberia_stats.Timestamp``.
        This is presumably in seconds, matching the threshold's name.
        """
        return siberia_stats.IsReady and (time_utils.get_current_time() - siberia_stats.Timestamp < self.config.FreshnessThresholdSec)

    def is_rule_condition_estimate_fresh(self, rule_condition_id):
        """Return whether stored stats for ``rule_condition_id`` are ready and fresh.

        A ``NOT_FOUND`` gRPC status means nothing is stored yet and yields
        False. Other gRPC errors propagate.
        """
        try:
            siberia_stats = self.rule_estimator_api.GetRuleConditionStats(api_pb2.GetRuleConditionStatsRequest(
                RuleConditionId=rule_condition_id,
            ))
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                return False
            raise
        logger.debug("got stats: %s", rule_condition_id)
        return self._is_stats_fresh(siberia_stats)

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def describe_rule(self, rule):
        """Estimate a rule (a combination of rule conditions) and store its stats.

        Uses the per-condition tables already present under ``OutputDir``.
        Does nothing when a fresh estimate exists. Exceptions are logged,
        re-raised and retried by ``retry.retry`` (up to 5 attempts, 1 s
        initial delay, 1.5x backoff).

        Args:
            rule: collection of rule condition ids (``execute`` passes a
                ``set``). It is sent as-is as ``RuleConditionIds``.
        """
        try:
            logger.info("Describing rule: %s", rule)
            if self.is_rule_estimate_fresh(rule):
                logger.info("Rule %s is fresh", rule)
                return

            logger.info("Calculating rule")

            siberia_stats = self.get_stats(**self.get_rule_tables(rule))

            logger.info("Writing to rule estimator API")
            self.rule_estimator_api.SetRuleStats(api_pb2.SetRuleStatsRequest(
                RuleConditionIds=rule,
                Stats=siberia_stats,
            ))
            logger.info("Rule %s is done", rule)
        except Exception:
            logger.exception("Failed to describe rule: %s", rule)
            raise

    def is_rule_estimate_fresh(self, rule):
        """Return whether stored stats for ``rule`` are ready and fresh.

        A ``NOT_FOUND`` gRPC status means nothing is stored yet and yields
        False. Other gRPC errors propagate.
        """
        try:
            siberia_stats = self.rule_estimator_api.GetRuleStats(api_pb2.GetRuleStatsRequest(
                RuleConditionIds=rule,
            ))
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                return False
            raise
        logger.debug("got stats: %s", rule)
        return self._is_stats_fresh(siberia_stats)

    def get_rule_tables(self, rule):
        """Split the existing per-condition tables of ``rule`` by aggregation.

        Each condition's table is ``OutputDir/<rule_condition_id>``. Ids with
        no such table are skipped.

        Returns:
            dict with ``yuid_tables`` and ``aggregate_tables`` lists, usable
            as keyword arguments for ``get_stats``.
        """
        yuid_tables = []
        aggregate_tables = []

        for rule_condition_id in rule:
            src_table = yt.ypath_join(self.config.OutputDir, str(rule_condition_id))
            if self.yt.exists(src_table):
                if self.yt.get_attribute(src_table, calc_rule.AGGREGATED):
                    aggregate_tables.append(src_table)
                else:
                    yuid_tables.append(src_table)

        return {
            "yuid_tables": yuid_tables,
            "aggregate_tables": aggregate_tables,
        }

    def get_stats(self, yuid_tables, aggregate_tables):
        """Sample ids from the tables, describe them in Siberia and build stats.

        Returns:
            ``RuleEstimateStats`` with coverage, the Siberia user set id and
            the current timestamp. When the sample contains no ids, empty
            stats are returned and Siberia is not called.
        """
        logger.info("Getting coverage and ids")
        coverage, ids = get_sample(self.yt, self.config.ChytAlias, yuid_tables, aggregate_tables, self.config.SampleSize)

        if not ids:
            logger.warning("Empty ids")
            return self.get_empty_stats()

        ids = id_pb2.TIds(Ids=ids)

        logger.info("Sending to siberia")
        user_set_id = int(self.siberia.user_sets_describe_ids(
            ids,
            mode=describing_mode.FAST,
            tvm_ticket=self.get_tvm_ticket(),
            experiment=describing_experiment_pb2.TDescribingExperiment(CryptaIdUserDataVersion="by_crypta_id"),
        ).UserSetId)

        return rule_estimate_stats_pb2.RuleEstimateStats(
            Coverage=coverage,
            UserSetId=user_set_id,
            Timestamp=time_utils.get_current_time(),
        )

    @staticmethod
    def get_empty_stats():
        """Return ``RuleEstimateStats`` carrying only the current timestamp."""
        return rule_estimate_stats_pb2.RuleEstimateStats(
            Timestamp=time_utils.get_current_time(),
        )
