package feeservice

import (
	"context"
	"fmt"
	"strings"

	"a.yandex-team.ru/library/go/core/log"
	"google.golang.org/protobuf/types/known/timestamppb"

	banditpb "a.yandex-team.ru/travel/rasp/train_bandit_api/proto"

	"a.yandex-team.ru/travel/trains/search_api/api/tariffs"
	"a.yandex-team.ru/travel/trains/search_api/internal/direction/segments"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/helpers"
)

// FeeService applies fees obtained from the bandit API (via banditClient's
// GetCharge) to train places.
type FeeService struct {
	banditClient banditpb.BanditApiServiceV1Client
	logger       log.Logger
}

// SegmentContextProvider builds the bandit context for one place of a train
// segment. FeeService groups places whose contexts have an equal String() form
// into a single entry of the bandit request.
type SegmentContextProvider func(segment *segments.TrainSegment, place *tariffs.TrainPlace, icookie string) *banditpb.TContext

// NewService returns a FeeService that requests charges through banditClient
// and keeps logger for diagnostics.
func NewService(banditClient banditpb.BanditApiServiceV1Client, logger log.Logger) *FeeService {
	svc := new(FeeService)
	svc.banditClient = banditClient
	svc.logger = logger
	return svc
}

// ApplyTrainSegmentsFee requests charges from the bandit for every place of
// trainSegments and writes the resulting fee and total price back into the
// places (the places are mutated in place).
//
// banditType is passed verbatim as TGetChargeRequest.BanditType. icookie is
// forwarded to contextProvider, which decides what each place's bandit context
// contains (e.g. SegmentShortContextProvider or SegmentFullContextProvider).
//
// If trainSegments contain no places there is nothing to price: the bandit is
// not called and nil is returned. Errors from GetCharge and from applying the
// response are wrapped with the function name.
func (s *FeeService) ApplyTrainSegmentsFee(
	ctx context.Context,
	trainSegments segments.TrainSegments,
	banditType string,
	icookie string,
	contextProvider SegmentContextProvider,
) error {
	const funcName string = "FeeService.ApplyTrainSegmentsFee"
	contextsWithPrices, placeByTicketPriceIndex := s.buildContextsWithPrices(trainSegments, icookie, contextProvider)
	if len(contextsWithPrices) == 0 {
		// No places to price: skip the RPC instead of sending an empty request.
		return nil
	}

	chargeRsp, err := s.banditClient.GetCharge(ctx, &banditpb.TGetChargeRequest{
		BanditType:         banditType,
		ContextsWithPrices: contextsWithPrices,
	})
	if err != nil {
		return fmt.Errorf("%s: cannot GetCharge from bandit: %w", funcName, err)
	}

	err = s.updatePlacesFromBandit(placeByTicketPriceIndex, chargeRsp)
	if err != nil {
		return fmt.Errorf("%s: cannot update places from bandit: %w", funcName, err)
	}
	return nil
}

// buildContextsWithPrices turns trainSegments into the list of
// banditpb.TTicketPrices sent to the bandit, plus an index used to write the
// bandit's fees back into the segments' places.
//
// Places whose bandit context (from contextProvider) has the same String()
// form share one TTicketPrices entry; the entry's InternalId is its position
// in the returned slice. Within an entry, places with the same ticket/service
// price share one TTicketPrice. Ticket-price ids come from a single sequence
// shared by all entries, so every id is unique across the whole request; the
// returned map lists, for each id, the places that should receive that fee.
func (s *FeeService) buildContextsWithPrices(
	trainSegments segments.TrainSegments, icookie string, contextProvider SegmentContextProvider,
) ([]*banditpb.TTicketPrices, map[uint32][]*tariffs.TrainPlace) {
	var entries []*banditpb.TTicketPrices
	placesByPriceID := make(map[uint32][]*tariffs.TrainPlace)
	entryPosByContext := make(map[string]int)
	priceIDByKey := make(map[string]uint32)

	for _, segment := range trainSegments {
		for _, place := range segment.Places {
			bc := contextProvider(segment, place, icookie)
			contextKey := bc.String()

			pos, seen := entryPosByContext[contextKey]
			if !seen {
				// A new context gets the next slice position, which doubles as its id.
				pos = len(entries)
				entryPosByContext[contextKey] = pos
				entries = append(entries, &banditpb.TTicketPrices{
					InternalId:   uint32(pos),
					Context:      bc,
					TicketPrices: make(map[uint32]*banditpb.TTicketPrice),
				})
			}
			entry := entries[pos]

			price := &banditpb.TTicketPrice{
				Amount:        place.PriceDetails.TicketPrice,
				ServiceAmount: place.PriceDetails.ServicePrice,
			}
			// Deduplicate prices per context: the key combines entry position and price.
			priceKey := fmt.Sprintf("%v:%s", pos, price.String())
			priceID, known := priceIDByKey[priceKey]
			if !known {
				// The number of keys registered so far is the next free global id.
				priceID = uint32(len(priceIDByKey))
				priceIDByKey[priceKey] = priceID
				entry.TicketPrices[priceID] = price
			}
			placesByPriceID[priceID] = append(placesByPriceID[priceID], place)
		}
	}
	return entries, placesByPriceID
}

// updatePlacesFromBandit writes the bandit's charges back into the places.
//
// targetPlaces maps a ticket-price id (as assigned by buildContextsWithPrices)
// to the places sharing that price. sourceBanditRsp.ChargesByContexts carries
// fees keyed by the same ids. For every fee that has matching places, each
// place gets PriceDetails.Fee = Fee + ServiceFee and
// Price = TicketPrice.Amount + that fee. Fee ids with no matching places are
// ignored.
//
// An error is returned for a nil response, for a fee entry without a ticket
// price, or when helpers.Sum fails. Both sums are computed before a place is
// modified, so a failure never leaves a place with a new Fee but an old Price.
func (s *FeeService) updatePlacesFromBandit(targetPlaces map[uint32][]*tariffs.TrainPlace, sourceBanditRsp *banditpb.TGetChargeResponse) error {
	const funcName string = "FeeService.updatePlacesFromBandit"
	if sourceBanditRsp == nil {
		return fmt.Errorf("%s: nil bandit response", funcName)
	}
	for _, charge := range sourceBanditRsp.ChargesByContexts {
		for ticketFeeIndex, ticketFee := range charge.TicketFees {
			places := targetPlaces[ticketFeeIndex]
			if len(places) == 0 {
				continue
			}
			// Guard against malformed entries instead of panicking on the
			// TicketPrice.Amount dereference below.
			if ticketFee == nil || ticketFee.TicketPrice == nil {
				return fmt.Errorf("%s: ticket fee %d has no ticket price", funcName, ticketFeeIndex)
			}
			for _, place := range places {
				// Sums are computed per place so places do not share result objects.
				fee, err := helpers.Sum(ticketFee.Fee, ticketFee.ServiceFee)
				if err != nil {
					return fmt.Errorf("%s: fee sum error: %w", funcName, err)
				}
				totalPrice, err := helpers.Sum(ticketFee.TicketPrice.Amount, fee)
				if err != nil {
					return fmt.Errorf("%s: price sum error: %w", funcName, err)
				}
				place.PriceDetails.Fee = fee
				place.Price = totalPrice
			}
		}
	}
	return nil
}

// SegmentFullContextProvider extends the context built by
// SegmentShortContextProvider with the ICookie and with departure and arrival
// timestamps. Only Seconds is filled, from DepartureLocalDt.Unix() and
// ArrivalLocalDt.Unix(); Nanos is left unset.
func SegmentFullContextProvider(segment *segments.TrainSegment, place *tariffs.TrainPlace, icookie string) *banditpb.TContext {
	departure := &timestamppb.Timestamp{Seconds: segment.DepartureLocalDt.Unix()}
	arrival := &timestamppb.Timestamp{Seconds: segment.ArrivalLocalDt.Unix()}

	full := SegmentShortContextProvider(segment, place, icookie)
	full.ICookie, full.Departure, full.Arrival = icookie, departure, arrival
	return full
}

// SegmentShortContextProvider builds a bandit context from route and coach data
// only. It sets PointFrom and PointTo to "s" followed by the departure and
// arrival StationId, TrainType to the upper-cased TitleDefault of the train
// brand (empty when the segment has no brand), and CarType to the place's
// CoachType. icookie is accepted to match SegmentContextProvider but is unused.
func SegmentShortContextProvider(segment *segments.TrainSegment, place *tariffs.TrainPlace, icookie string) *banditpb.TContext {
	short := &banditpb.TContext{
		PointFrom: fmt.Sprintf("s%v", segment.DepartureStation.StationId),
		PointTo:   fmt.Sprintf("s%v", segment.ArrivalStation.StationId),
		CarType:   place.CoachType,
	}
	if brand := segment.TrainBrand; brand != nil {
		short.TrainType = strings.ToUpper(brand.TitleDefault)
	}
	return short
}
