package handler

import (
	"encoding/json"
	"fmt"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/library/go/metrics"
	"github.com/opentracing/opentracing-go"
	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"a.yandex-team.ru/travel/rasp/train_bandit_api/internal/api/model"
	"a.yandex-team.ru/travel/rasp/train_bandit_api/internal/bandit"
	"a.yandex-team.ru/travel/rasp/train_bandit_api/pkg/utils"
	pb "a.yandex-team.ru/travel/rasp/train_bandit_api/proto"
)

// GRPCHandler serves the charge gRPC methods of the bandit API:
// GetCharge, GetChargeStringContext and GetChargeByToken.
//
// Fields:
//   - ChargerGetter resolves a charger from the bandit type named in a
//     request. It also returns the bandit type and version that are then
//     stamped onto every produced charge.
//   - Logger receives debug dumps of requests/responses and error reports.
//   - Pricer provides the partner fees valid at a given time and computes
//     the fees of a charge.
//   - Metrics hosts the "get-charge" counters, labelled by bandit type and
//     permille.
type GRPCHandler struct {
	ChargerGetter bandit.ChargerGetter
	Logger        log.Logger
	Pricer        model.Pricer
	Metrics       *metrics.AppMetrics
}

// GetCharge is the gRPC entry point for charging a batch of bandit contexts.
//
// It wraps the call in an "api.handler.GetCharge" tracing span and hands the
// request, together with the span-carrying context, to getChargeInner. The
// span is finished once the result is ready. Result and error are returned
// unchanged.
func (h *GRPCHandler) GetCharge(ctx context.Context, request *pb.TGetChargeRequest) (
	response *pb.TGetChargeResponse, err error) {
	span, tracedCtx := opentracing.StartSpanFromContext(ctx, "api.handler.GetCharge")
	defer span.Finish()

	response, err = h.getChargeInner(tracedCtx, request)
	return response, err
}

// getChargeInner does the actual work behind GetCharge and its string-context
// and token variants.
//
// It first resolves the charger for request.BanditType and fetches the
// partner fees valid at the current time. Then, for every entry of
// request.ContextsWithPrices (in input order), it:
//   - builds a TCharge from the entry's context, internal id and ticket prices;
//   - asks the charger for the permille (and optional minimal tariff permille);
//   - stamps the resolved bandit type/version onto the charge;
//   - bumps the "get-charge" counter;
//   - lets the pricer compute the fees.
//
// The response carries the partner fees and one fee-enriched charge per
// input entry, in the same order.
//
// Errors are gRPC statuses:
//   - codes.InvalidArgument when the bandit type is unknown;
//   - codes.Internal when fee lookup, the bandit or fee calculation fails.
//
// ctx is not used by this function at present.
//
// NOTE(review): the entries' AdditionalLogInfo is not consumed here. Confirm
// whether it was meant to be logged.
func (h *GRPCHandler) getChargeInner(ctx context.Context, request *pb.TGetChargeRequest) (
	response *pb.TGetChargeResponse, err error) {
	now := time.Now()
	// ts only tags the "Called"/"Finished" debug lines so they can be paired.
	ts := now.UnixNano()
	// The marshal error is ignored on purpose: the dump is debug output only.
	reqStr, _ := json.Marshal(request)
	h.Logger.Debugf("Called getChargeInner(%v) with request: %s", ts, reqStr)

	charger, banditType, banditVersion, err := h.ChargerGetter.GetCharger(request.BanditType)
	if err != nil {
		h.Logger.Error("Cannot get charger", log.Error(err))
		return nil, status.Error(codes.InvalidArgument, "unknown bandit type "+request.BanditType)
	}

	// The same instant is used for the partner fees and for every fee calculation below.
	partnerFee, partnerRefundFee, err := h.Pricer.GetPartnerFees(now)
	if err != nil {
		h.Logger.Error("Cannot get partner fee", log.Error(err))
		return nil, status.Error(codes.Internal, "get fee error")
	}

	chargesWithFees := make([]*pb.TCharge, 0, len(request.ContextsWithPrices))
	for _, contextWithPrice := range request.ContextsWithPrices {
		ticketFees := make(map[uint32]*pb.TTicketFee, len(contextWithPrice.TicketPrices))
		for i, ticketPrice := range contextWithPrice.TicketPrices {
			ticketFees[i] = &pb.TTicketFee{TicketPrice: ticketPrice}
		}

		charge := &pb.TCharge{
			Context:    contextWithPrice.Context,
			InternalId: contextWithPrice.InternalId,
			TicketFees: ticketFees,
		}

		var minTariffPermille *uint32
		charge.Permille, minTariffPermille, err = charger.GetCharge(charge.Context)
		if err != nil {
			h.Logger.Error("Cannot get permille from bandit", log.Error(err))
			return nil, status.Error(codes.Internal, "bandit error")
		}
		if minTariffPermille != nil {
			charge.OptionalMinTariffPermille = &pb.TCharge_MinTariffPermille{MinTariffPermille: *minTariffPermille}
		}
		charge.BanditVersion = banditVersion
		charge.BanditType = banditType

		h.Metrics.GetOrCreateCounter("get-charge",
			map[string]string{"bandit-type": banditType, "permille": fmt.Sprint(charge.Permille)},
			"permille-count").Inc()
		// Marshalled only for the debug line; the error is ignored on purpose.
		bc, _ := json.Marshal(charge.Context)
		h.Logger.Debugf("GetCharge call. banditType: %s, permille: %v, context: %s", banditType, charge.Permille, bc)

		chargeWithFee, calcErr := h.Pricer.CalculateFees(charge, true, now)
		if calcErr != nil {
			h.Logger.Error("Cannot calculate fees", log.Error(calcErr))
			return nil, status.Error(codes.Internal, "calculate price error")
		}
		chargesWithFees = append(chargesWithFees, chargeWithFee)
	}

	response = &pb.TGetChargeResponse{
		PartnerFee:        partnerFee,
		PartnerRefundFee:  partnerRefundFee,
		ChargesByContexts: chargesWithFees,
	}
	// The marshal error is ignored on purpose: the dump is debug output only.
	respStr, _ := json.Marshal(response)
	h.Logger.Debugf("Finished getChargeInner(%v) with response: %s", ts, respStr)
	return response, nil
}

// GetChargeStringContext is the string-context flavour of GetCharge.
//
// The call is traced under an "api.handler.GetChargeStringContext" span. It
// works in three steps:
//  1. Parse every incoming context from its string form with
//     utils.BanditContextFromStr.
//  2. Charge the whole batch through getChargeInner.
//  3. Serialize each resulting charge's context back to a string with
//     utils.MarshalToStr.
//
// Charges keep the input order.
//
// Errors:
//   - codes.InvalidArgument when any context fails to parse;
//   - codes.Internal when any context cannot be serialized back;
//   - errors from getChargeInner are passed through unchanged.
func (h *GRPCHandler) GetChargeStringContext(ctx context.Context, request *pb.TGetChargeStringCtxRequest) (
	response *pb.TGetChargeStringCtxResponse, err error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "api.handler.GetChargeStringContext")
	defer span.Finish()

	parsed := make([]*pb.TTicketPrices, 0, len(request.ContextsWithPrices))
	for _, item := range request.ContextsWithPrices {
		banditCtx, parseErr := utils.BanditContextFromStr(item.Context)
		if parseErr != nil {
			h.Logger.Error("Can not parse bandit context", log.Error(parseErr))
			return nil, status.Error(codes.InvalidArgument, "invalid context")
		}
		parsed = append(parsed, &pb.TTicketPrices{
			InternalId:        item.InternalId,
			TicketPrices:      item.TicketPrices,
			AdditionalLogInfo: item.AdditionalLogInfo,
			Context:           banditCtx,
		})
	}

	inner, err := h.getChargeInner(ctx, &pb.TGetChargeRequest{
		BanditType:         request.BanditType,
		ContextsWithPrices: parsed,
	})
	if err != nil {
		return nil, err
	}

	stringCharges := make([]*pb.TChargeStringCtx, 0, len(inner.ChargesByContexts))
	for _, c := range inner.ChargesByContexts {
		contextStr, marshalErr := utils.MarshalToStr(c.Context)
		if marshalErr != nil {
			h.Logger.Error("Can not get string context", log.Error(marshalErr))
			return nil, status.Error(codes.Internal, "string context error")
		}
		stringCharges = append(stringCharges, &pb.TChargeStringCtx{
			BanditType:    c.BanditType,
			BanditVersion: c.BanditVersion,
			Permille:      c.Permille,
			InternalId:    c.InternalId,
			TicketFees:    c.TicketFees,
			Context:       contextStr,
		})
	}

	return &pb.TGetChargeStringCtxResponse{
		PartnerFee:        inner.PartnerFee,
		PartnerRefundFee:  inner.PartnerRefundFee,
		ChargesByContexts: stringCharges,
	}, nil
}

// GetChargeByToken charges the single context carried by a fee calculation
// token.
//
// The call is traced under an "api.handler.GetChargeByToken" span.
// request.FeeCalculationToken is decoded with utils.BanditTokenFromStr. The
// token's ActualBanditType selects the bandit, and the token's Context is
// charged together with request.TicketPrices via getChargeInner. The single
// resulting charge is returned with its context serialized back to a string.
//
// Errors:
//   - codes.InvalidArgument when the token cannot be parsed;
//   - codes.Internal when no charge comes back or the context cannot be
//     serialized;
//   - errors from getChargeInner are passed through unchanged.
func (h *GRPCHandler) GetChargeByToken(ctx context.Context, request *pb.TGetChargeByTokenRequest) (
	response *pb.TGetChargeByTokenResponse, err error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "api.handler.GetChargeByToken")
	defer span.Finish()

	token, err := utils.BanditTokenFromStr(request.FeeCalculationToken)
	if err != nil {
		h.Logger.Error("Can not parse bandit token", log.Error(err))
		// The rejected payload is the fee calculation token, not a bandit context.
		return nil, status.Error(codes.InvalidArgument, "invalid fee calculation token")
	}

	rsp, err := h.getChargeInner(ctx, &pb.TGetChargeRequest{
		BanditType: token.ActualBanditType,
		ContextsWithPrices: []*pb.TTicketPrices{{
			InternalId:   0,
			TicketPrices: request.TicketPrices,
			Context:      token.Context,
		}},
	})
	if err != nil {
		return nil, err
	}

	// Exactly one context was submitted, so one charge is expected back.
	// Guard anyway so a contract violation surfaces as an error, not a panic.
	if len(rsp.ChargesByContexts) == 0 {
		h.Logger.Error("No charge returned for token context",
			log.Error(fmt.Errorf("got %d charges, expected 1", len(rsp.ChargesByContexts))))
		return nil, status.Error(codes.Internal, "empty charge result")
	}
	charge := rsp.ChargesByContexts[0]

	contextStr, err := utils.MarshalToStr(charge.Context)
	if err != nil {
		h.Logger.Error("Can not get string context", log.Error(err))
		return nil, status.Error(codes.Internal, "string context error")
	}

	return &pb.TGetChargeByTokenResponse{
		PartnerFee:       rsp.PartnerFee,
		PartnerRefundFee: rsp.PartnerRefundFee,
		ChargesByContext: &pb.TChargeStringCtx{
			BanditType:    charge.BanditType,
			BanditVersion: charge.BanditVersion,
			Permille:      charge.Permille,
			InternalId:    charge.InternalId,
			TicketFees:    charge.TicketFees,
			Context:       contextStr,
		},
	}, nil
}
