package handler

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/opentracing/opentracing-go"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"a.yandex-team.ru/library/go/core/log"
	pb "a.yandex-team.ru/travel/avia/price_prediction/api/v1"
	"a.yandex-team.ru/travel/avia/price_prediction/internal/checkprice"
)

// GRPCCheckPriceHandler implements the CheckPriceService gRPC API by
// delegating the classification of each requested price to a
// checkprice.PriceChecker.
type GRPCCheckPriceHandler struct {
	// logger is set by the constructor.
	// NOTE(review): no method visible in this file uses it.
	logger       log.Logger
	// priceChecker is invoked once for every valid entry in CheckPrices.
	priceChecker checkprice.PriceChecker
}

// NewGRPCCheckPriceHandler builds a handler that answers CheckPrices calls
// using checker. The logger l is stored on the handler unchanged.
func NewGRPCCheckPriceHandler(l log.Logger, checker checkprice.PriceChecker) *GRPCCheckPriceHandler {
	handler := new(GRPCCheckPriceHandler)
	handler.logger = l
	handler.priceChecker = checker
	return handler
}

// CheckPrices implements the CheckPrices RPC. It classifies every entry of
// request.CheckPricesReq with the configured price checker and returns the
// resulting categories under the same keys.
//
// An entry that checkprice.NewCheckPriceRequest rejects is reported as
// CATEGORY_UNKNOWN. In that case the remaining entries are still checked, and
// the call returns a codes.FailedPrecondition status that lists every
// rejected key together with its conversion error.
func (h *GRPCCheckPriceHandler) CheckPrices(ctx context.Context, request *pb.TCheckPricesReq) (
	response *pb.TCheckPricesRsp,
	err error,
) {
	// Keep the span-carrying context so that spans started by the price
	// checker are recorded as children of this handler's span.
	span, ctx := opentracing.StartSpanFromContext(ctx, "internal.handler.GRPCCheckPriceHandler:CheckPrices")
	defer span.Finish()

	// The generated getter is nil-safe, unlike direct field access.
	reqs := request.GetCheckPricesReq()
	response = &pb.TCheckPricesRsp{
		PriceCategories: make(map[uint32]pb.TPricePrediction_ECategory, len(reqs)),
	}
	failedPreconditions := make(map[uint32]error)
	for key, r := range reqs {
		req, reqErr := checkprice.NewCheckPriceRequest(r)
		if reqErr != nil {
			response.PriceCategories[key] = pb.TPricePrediction_CATEGORY_UNKNOWN
			failedPreconditions[key] = reqErr
			continue
		}
		response.PriceCategories[key] = mapPriceCategory(h.priceChecker(ctx, *req))
	}

	if len(failedPreconditions) != 0 {
		// NOTE(review): grpc-go does not send the reply message when the
		// handler returns a non-nil error, so remote clients never see the
		// partial results collected above. codes.InvalidArgument may also fit
		// malformed input better than FailedPrecondition. Both are
		// client-visible changes, so they are left as-is here.
		return response, status.Error(codes.FailedPrecondition, makeErrorMsg(failedPreconditions))
	}
	return response, nil
}

// GetServiceRegisterer returns a callback that registers h as the
// CheckPriceService implementation on the gRPC server passed to it.
func (h *GRPCCheckPriceHandler) GetServiceRegisterer() func(*grpc.Server) {
	register := func(srv *grpc.Server) {
		pb.RegisterCheckPriceServiceServer(srv, h)
	}
	return register
}

// mapPriceCategory converts an internal checkprice category into its protobuf
// counterpart. Good and Bad map to CATEGORY_GOOD and CATEGORY_BAD. Any other
// value maps to CATEGORY_UNKNOWN.
func mapPriceCategory(category checkprice.PriceCategory) pb.TPricePrediction_ECategory {
	result := pb.TPricePrediction_CATEGORY_UNKNOWN
	switch category {
	case checkprice.PriceCategoryGood:
		result = pb.TPricePrediction_CATEGORY_GOOD
	case checkprice.PriceCategoryBad:
		result = pb.TPricePrediction_CATEGORY_BAD
	}
	return result
}

// makeErrorMsg renders a "key: error" summary for every entry of errors,
// joined with "; ". Keys appear in ascending numeric order, so the message
// is deterministic regardless of map iteration order. An empty or nil map
// yields "".
func makeErrorMsg(errors map[uint32]error) string {
	keys := make([]uint32, 0, len(errors))
	for key := range errors {
		keys = append(keys, key)
	}
	// Go map iteration order is randomized. Sorting keeps identical inputs
	// producing identical status messages, which logs and tests rely on.
	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })

	parts := make([]string, 0, len(keys))
	for _, key := range keys {
		parts = append(parts, fmt.Sprintf("%v: %v", key, errors[key]))
	}
	return strings.Join(parts, "; ")
}
