from concurrent import futures
import signal

import grpc
from grpc_reflection.v1alpha import reflection
from library.python.protobuf.json import proto2json

from crypta.lib.python.grpc import keepalive
from crypta.ltp.viewer.proto import command_pb2
from crypta.ltp.viewer.lib.structs import status
from crypta.ltp.viewer.lib.structs.id import Id
from crypta.ltp.viewer.lib.structs.filter import Filter
from crypta.ltp.viewer.lib.structs.page import Page
from crypta.ltp.viewer.lib.ydb.client import Client
from crypta.ltp.viewer.services.api.proto import (
    api_pb2,
    api_pb2_grpc,
)
from crypta.lib.python.lb_pusher import logbroker


class LtpViewer(api_pb2_grpc.LtpViewerServicer):
    """gRPC servicer implementing the LtpViewer API.

    Read-side RPCs are answered from YDB through ``ydb_client``. Command-style
    RPCs wrap their request into a ``command_pb2.TCommand``, serialize it to
    JSON and hand it to ``pq_writer`` (see ``write_command_to_lb``).
    """

    def __init__(self, pq_writer, ydb_client, logger):
        """
        Args:
            pq_writer: writer object exposing ``write(payload)``; receives the
                JSON-serialized commands.
            ydb_client: YDB ``Client`` used for queries, history and progress.
            logger: logger used for request/command tracing.
        """
        self.pq_writer = pq_writer
        self.ydb_client = ydb_client
        self.logger = logger

    def GetHistory(self, request, context):
        """Return one page of history for ``request.Id``.

        Side effects: records the query for ``request.Owner`` via
        ``save_query`` and writes a ``PreloadHistoryCommand`` for the same
        id and date range.
        """
        self.logger.info("GetHistory: %s", request)
        id_ = Id.from_proto(request.Id)
        self.ydb_client.save_query(request.Owner, id_, request.FromDate, request.ToDate)
        history_filter = Filter.from_proto(request.Filter, request.FromDate, request.ToDate)

        rows, total = self.ydb_client.get_history(id_, Page.from_proto(request.Page), history_filter)
        response = api_pb2.TGetHistoryResponse()
        response.Total = total
        for row in rows:
            item = response.Items.add()
            item.EventTime = row.timestamp
            item.Id = row.id
            item.IdType = row.id_type
            item.SourceType = row.log
            item.Description = row.description
            item.AdditionalDescription = row.additional_description

        # The request type differs from PreloadHistoryCommand, so the
        # relevant fields are copied one by one instead of via CopyFrom.
        cmd = command_pb2.TCommand()
        cmd.PreloadHistoryCommand.Owner = request.Owner
        cmd.PreloadHistoryCommand.FromDate = request.FromDate
        cmd.PreloadHistoryCommand.ToDate = request.ToDate
        cmd.PreloadHistoryCommand.Id.Type = request.Id.Type
        cmd.PreloadHistoryCommand.Id.Value = request.Id.Value
        self.write_command_to_lb(cmd)

        return response

    def PreloadHistory(self, request, context):
        """Record the query and write it as a ``PreloadHistoryCommand``."""
        self.logger.info("PreloadHistory: %s", request)
        self.ydb_client.save_query(request.Owner, Id.from_proto(request.Id), request.FromDate, request.ToDate)
        cmd = command_pb2.TCommand()
        cmd.PreloadHistoryCommand.CopyFrom(request)
        self.write_command_to_lb(cmd)

        return api_pb2.TPreloadHistoryResponse(Message="Ok")

    def PreloadHistoryChunk(self, request_iterator, context):
        """Client-streaming RPC: write one ``PreloadHistoryChunkCommand`` per request."""
        cmd = command_pb2.TCommand()

        for request in request_iterator:
            self.logger.info("PreloadHistoryChunk: %s", request)
            # CopyFrom replaces the sub-message, so reusing ``cmd`` is safe.
            cmd.PreloadHistoryChunkCommand.CopyFrom(request)
            self.write_command_to_lb(cmd)

        return api_pb2.TPreloadHistoryChunkResponse(Message="Ok")

    def DropHistory(self, request_iterator, context):
        """Client-streaming RPC: write one ``DropHistoryCommand`` per request."""
        cmd = command_pb2.TCommand()

        for request in request_iterator:
            self.logger.info("DropHistory: %s", request)
            cmd.DropHistoryCommand.CopyFrom(request)
            self.write_command_to_lb(cmd)

        return api_pb2.TDropHistoryResponse(Message="Ok")

    def Expire(self, request, context):
        """Write the request as an ``ExpireCommand``."""
        cmd = command_pb2.TCommand()
        self.logger.info("Expire: %s", request)
        cmd.ExpireCommand.CopyFrom(request)
        self.write_command_to_lb(cmd)

        return api_pb2.TExpireResponse(Message="Ok")

    def Ping(self, request, context):
        """Health check; always answers ``Message="Ok"``."""
        return api_pb2.TPingResponse(Message="Ok")

    def GetUserQueries(self, request, context):
        """Return the saved queries of ``request.Owner`` with a readiness flag.

        A query is ``Ready`` when progress exists for its id and date range
        and every progress chunk has status FAILED or COMPLETED.
        """
        self.logger.info("GetUserQueries: %s", request)
        finished_statuses = {status.FAILED, status.COMPLETED}
        queries = self.ydb_client.get_user_queries(request.Owner)
        response = api_pb2.TGetUserQueriesResponse()
        for query in queries:
            # Look progress up over the same [from_date, to_date] range the
            # query was saved with (mirrors GetProgress' FromDate/ToDate).
            progress = self.ydb_client.get_progress(
                Id(query.id_type, query.id), query.from_date, query.to_date,
            )

            proto_query = response.Queries.add()
            proto_query.Id.Value = query.id
            proto_query.Id.Type = query.id_type
            # NOTE(review): an empty (non-None) progress list counts as ready
            # because all() of an empty iterable is True — confirm intended.
            proto_query.Ready = progress is not None and all(
                item.status in finished_statuses for item in progress
            )
            proto_query.FromDate = query.from_date
            proto_query.ToDate = query.to_date

        return response

    def GetProgress(self, request, context):
        """Return aggregate and per-date chunk statistics for a query.

        ``Scheduled`` is False when ``get_progress`` returns None; otherwise
        totals, failed and completed counters are filled overall and per
        chunk date.
        """
        self.logger.info("GetProgress: %s", request)
        progress = self.ydb_client.get_progress(Id.from_proto(request.Id), request.FromDate, request.ToDate)

        response = api_pb2.TGetProgressResponse()
        if progress is None:
            response.Scheduled = False
        else:
            response.Scheduled = True
            response.Stats.Total = len(progress)

            for chunk in progress:
                # Indexing a protobuf message map creates the entry on demand.
                by_date = response.StatsByDate[chunk.date]
                by_date.Total += 1

                if chunk.status == status.FAILED:
                    response.Stats.Failed += 1
                    by_date.Failed += 1
                elif chunk.status == status.COMPLETED:
                    response.Stats.Completed += 1
                    by_date.Completed += 1

        return response

    def write_command_to_lb(self, cmd):
        """Serialize ``cmd`` to JSON, write it with ``pq_writer`` and log it."""
        payload = proto2json.proto2json(cmd)
        self.pq_writer.write(payload)
        self.logger.info("Command: %s", payload)


def serve(config, logger):
    """Start the LtpViewer gRPC server and block until it terminates.

    The function builds the following pieces:

    * a Logbroker PQ client and a writer for ``config.Topic``;
    * a YDB client;
    * a thread-pool gRPC server with keepalive options and server reflection.

    The server listens without TLS on ``config.Port``. SIGTERM triggers a
    stop with a 10 second grace period, followed by a wait of up to 10
    seconds. The server runs while the PQ client and writer contexts are
    entered.
    """
    logbroker_cfg = config.Logbroker
    pq_client = logbroker.PQClient(
        logbroker_cfg.Url,
        logbroker_cfg.Port,
        tvm_id=config.Tvm.SourceTvmId,
        tvm_secret=config.Tvm.Secret,
    )
    pq_writer = pq_client.get_writer(config.Topic)

    ydb_cfg = config.Ydb
    ydb_client = Client(ydb_cfg.Endpoint, ydb_cfg.Database, ydb_cfg.Token)

    executor = futures.ThreadPoolExecutor(max_workers=config.Workers)
    server = grpc.server(executor, options=keepalive.get_keepalive_options())

    servicer = LtpViewer(pq_writer, ydb_client, logger)
    api_pb2_grpc.add_LtpViewerServicer_to_server(servicer, server)

    reflected_services = (
        api_pb2.DESCRIPTOR.services_by_name["LtpViewer"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(reflected_services, server)

    listen_address = "[::]:{}".format(config.Port)
    server.add_insecure_port(listen_address)

    def _on_sigterm(*_args):
        # stop(10) returns an event; wait on it for at most 10 seconds.
        return server.stop(10).wait(10)

    signal.signal(signal.SIGTERM, _on_sigterm)

    with pq_client:
        with pq_writer:
            server.start()
            server.wait_for_termination()
