import datetime
import logging

import grpc
import retry
from yt import yson

from crypta.graph.export.proto.graph_pb2 import TGraph
from crypta.lib.proto.identifiers import id_pb2
from crypta.lib.python.worker_utils import worker
from crypta.lib.python.yt import yt_helpers
from crypta.ltp.viewer.lib import ltp_logs
from crypta.ltp.viewer.lib.chyt.client import ChytClient
from crypta.ltp.viewer.lib.structs import status
from crypta.ltp.viewer.lib.structs.id import Id
from crypta.ltp.viewer.lib.structs.record import Record
from crypta.ltp.viewer.lib.structs.task import Task
from crypta.ltp.viewer.lib.ydb import client
from crypta.ltp.viewer.proto import command_pb2
from crypta.ltp.viewer.services.api.proto import api_pb2_grpc
from crypta.ltp.viewer.services.worker.lib import stats
from crypta.ltp.viewer.services.worker.lib.context import ContextView


logger = logging.getLogger(__name__)


class Worker(worker.Worker):
    """Background worker that executes LTP viewer commands.

    Each task carries one command in its ``Command`` oneof; ``execute`` dispatches
    it to the handler registered for that command in ``self.callbacks``.
    """

    def __init__(self, worker_config):
        """Create the YT / YDB / CHYT / gRPC clients used by the command handlers.

        :param worker_config: worker configuration whose ``context`` attribute is a
            ``(config, context)`` pair: the service config proto and the raw context
            that gets wrapped into a ``ContextView``.
        """
        super(Worker, self).__init__(worker_config)
        self.config, context = worker_config.context
        self.yt = yt_helpers.get_yt_client(self.config.Yt.Proxy, self.config.Yt.Pool)
        # Separate client for the dynamic tables looked up in get_ids.
        self.dynamic_yt = yt_helpers.get_yt_client(self.config.DynamicYt.Proxy, self.config.DynamicYt.Pool)
        self.ltp_viewer_api = api_pb2_grpc.LtpViewerStub(grpc.insecure_channel(self.config.LtpViewerApiEndpoint))
        self.ydb = client.Client(self.config.Ydb.Endpoint, self.config.Ydb.Database, self.config.Ydb.Token)
        self.chyt = ChytClient(self.yt, self.config.ChytAlias)
        self.context = ContextView(context)

        # Name of the field set in the task's "Command" oneof -> handler taking that field.
        self.callbacks = {
            "PreloadHistoryCommand": self.preload_history,
            "PreloadHistoryChunkCommand": self.preload_history_chunk,
            "DropHistoryCommand": self.drop_history,
            "ExpireCommand": self.expire,
        }
        logger.info("Worker started")

    def execute(self, task, labels):
        """Dispatch ``task`` to the handler for its command.

        :param task: proto message with a ``Command`` oneof.
        :param labels: mutable mapping; the command type is recorded under ``stats.CMD_TYPE``.
        :raises KeyError: if the set command (or ``None`` when nothing is set) has no handler.
        """
        cmd_type = task.WhichOneof("Command")
        labels[stats.CMD_TYPE] = cmd_type
        self.callbacks[cmd_type](getattr(task, cmd_type))

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def preload_history(self, command):
        """Register and schedule per-(id, log, date) chunks for the requested id's graph.

        The requested id is expanded to all ids of its graph (see ``get_ids``).
        For each id, the logs and dates found in the CHYT index become tasks, except:

        - chunks already present in YDB progress with a non-FAILED status;
        - dates outside ``[FromDate, ToDate]`` (an empty bound means unbounded).

        The tasks are passed to ``ydb.add_graph``, and the tasks it returns are sent
        to the API via ``schedule_chunks``. The whole method is retried up to 5 times.
        """
        requested_id = Id.from_proto(command.Id)
        logger.info("Preloading: %s", requested_id)

        ids, history_id = self.get_ids(requested_id)
        logger.info("Graph: %s %s", history_id, ids)

        from_date = command.FromDate
        to_date = command.ToDate

        # Chunks that are already registered and have not failed must not be scheduled again;
        # failed chunks are eligible for rescheduling.
        running_tasks = {
            (row.id, row.id_type, row.log, row.date)
            for row in self.ydb.get_progress(requested_id, from_date, to_date) or []
            if row.status != status.FAILED
        }

        logs_per_id = self.chyt.get_logs_from_index(ids, self.config.Paths.IndexPath)

        tasks = []
        # Use a distinct loop name so the requested id is not shadowed (it is logged below).
        for graph_id, logs in logs_per_id.items():
            logger.info("Logs for %s: %s", graph_id, logs)
            tasks.extend(
                Task(graph_id.id_type, graph_id.id, log, date)
                for log, dates in logs.items()
                for date in dates
                if (graph_id.id, graph_id.id_type, log, date) not in running_tasks
                and (not from_date or date >= from_date)
                and (not to_date or date <= to_date)
            )

        logger.info("Scheduling tasks: %s", tasks)

        scheduled_tasks = self.ydb.add_graph(ids, str(history_id), tasks, self.config.TasksPerHistory)
        self.schedule_chunks(scheduled_tasks)

        logger.info("Preloading finished: %s", requested_id)

    def schedule_chunks(self, tasks):
        """Send one ``TPreloadHistoryChunkCommand`` per task to the API's ``PreloadHistoryChunk``.

        The commands are passed to the stub as a lazily evaluated request iterator.
        """
        self.ltp_viewer_api.PreloadHistoryChunk(
            command_pb2.TPreloadHistoryChunkCommand(
                Id=id_pb2.TId(
                    Type=task.id_type,
                    Value=task.id,
                ),
                Log=task.log,
                Date=task.date,
            )
            for task in tasks
        )

    def preload_history_chunk(self, command):
        """Load one (id, log, date) chunk; on any error, mark the chunk as failed in YDB.

        :raises KeyError: if ``command.Log`` is not a key of ``ltp_logs.LOGS_DICT``.
        """
        id_ = Id.from_proto(command.Id)
        log = ltp_logs.LOGS_DICT[command.Log]

        try:
            self.try_preload_history_chunk(id_, log, command.Date)
        except Exception:
            # Log the traceback once here; fail_preload_history_chunk is retried and
            # would otherwise re-log it on every attempt.
            logger.exception("Failed to process chunk: %s %s %s", id_, log.name, command.Date)
            self.fail_preload_history_chunk(id_, log, command.Date)

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def try_preload_history_chunk(self, id_, log, date):
        """Read entries for ``id_`` from ``log`` on ``date`` and store them as records.

        Records are built lazily from the CHYT entries and passed to
        ``ydb.insert_chunk``; the tasks it returns are scheduled via
        ``schedule_chunks``. The whole method is retried up to 5 times.
        """
        logger.info("Preloading chunk: %s %s %s", id_, log.name, date)

        entries = self.chyt.get_entries_from_logs(id_, log.path, date, log.yt_columns, log.chyt_columns)
        records = (
            Record(
                entry["ActionTimestamp"],
                log.format_description(entry, self.context),
                log.format_additional_description(entry, self.context),
            )
            for entry in entries
        )
        scheduled_tasks = self.ydb.insert_chunk(date, log.name, records, id_)
        self.schedule_chunks(scheduled_tasks)

        logger.info("Preloading chunk finished: %s %s %s", id_, log.name, date)

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def fail_preload_history_chunk(self, id_, log, date):
        """Mark the (id, log, date) chunk as failed in YDB; retried up to 5 times."""
        self.ydb.fail_chunk(date, log.name, id_)

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def drop_history(self, command):
        """Delete the history ``command.HistoryId`` from YDB; retried up to 5 times."""
        self.ydb.drop_history(command.HistoryId)

    @retry.retry(tries=5, delay=1, backoff=1.5)
    def expire(self, command):
        """Expire queries older than ``command.TTLSeconds`` and request drops of expired histories.

        Expired queries are removed in YDB first; then one ``TDropHistoryCommand`` per
        expired history id is sent to the API's ``DropHistory``.
        """
        ttl = datetime.timedelta(seconds=command.TTLSeconds)
        self.ydb.expire_queries(ttl)
        history_ids = self.ydb.get_expired_history_ids(ttl)
        self.ltp_viewer_api.DropHistory(
            command_pb2.TDropHistoryCommand(
                HistoryId=history_id,
            )
            for history_id in history_ids
        )

    def get_ids(self, id_):
        """Resolve ``id_`` to all identifiers of its Crypta graph.

        Two dynamic tables are consulted: ``IdToCryptaIdTable`` (skipped when
        ``id_`` is already a ``crypta_id``) and then ``CryptaIdToGraphTable``.

        :returns: ``(ids, history_id)``.

            - If a graph is found, ``ids`` is every graph node followed by the
              ``crypta_id`` itself, and ``history_id`` is the crypta id as a string.
            - Otherwise ``ids`` is ``[id_]`` and ``history_id`` is ``"<id_type>-<id>"``.
        """
        default_history_id = "{}-{}".format(id_.id_type, id_.id)
        if id_.id_type == "crypta_id":
            crypta_id = int(id_.id)
        else:
            crypta_id_rows = list(self.dynamic_yt.lookup_rows(
                self.config.Paths.IdToCryptaIdTable,
                [{"IdType": id_.id_type, "IdValue": id_.id}],
                column_names=["CryptaId"],
            ))
            if not crypta_id_rows:
                return [id_], default_history_id
            crypta_id = crypta_id_rows[0]["CryptaId"]

        graph_rows = list(self.dynamic_yt.lookup_rows(
            self.config.Paths.CryptaIdToGraphTable,
            [{"CryptaId": crypta_id}],
            column_names=["Graph"],
        ))
        if not graph_rows:
            return [id_], default_history_id

        graph = TGraph()
        graph.ParseFromString(yson.get_bytes(graph_rows[0]["Graph"]))
        ids = [Id(node.Type, node.Id) for node in graph.Nodes]
        ids.append(Id("crypta_id", str(crypta_id)))
        return ids, str(crypta_id)
