package pricing

import (
	"fmt"
	"sort"
	"time"

	common "a.yandex-team.ru/travel/proto"

	"a.yandex-team.ru/travel/rasp/train_bandit_api/internal/model"
	"a.yandex-team.ru/travel/rasp/train_bandit_api/pkg/utils"
	pb "a.yandex-team.ru/travel/rasp/train_bandit_api/proto"
)

// PricingConfig holds the settings used to build a Pricing via New.
type PricingConfig struct {
	// Currency is kept as int32 and converted to common.ECurrency when proto
	// prices are produced (DefaultPricingConfig uses common.ECurrency_C_RUB).
	Currency          int32              `config:"pricing-currency" yaml:"currency"`
	// Precision is passed through unchanged to utils.ToProtoPrice.
	Precision         int32              `config:"pricing-precision" yaml:"precision"`
	// PartnerFeeHistory lists partner fee amounts together with the instant
	// each one takes effect. New rejects an empty list and sorts a copy of it.
	PartnerFeeHistory []PartnerFeeChange `yaml:"partner-fee-history"`
}

// PartnerFeeChange is one entry of the partner fee schedule: Amount is the
// partner fee that applies starting at From, until a later entry supersedes it.
type PartnerFeeChange struct {
	// Amount is the partner fee value (converted to a proto price by Pricing).
	Amount float64   `yaml:"amount"`
	// From is the instant this Amount starts to apply; a zero time makes the
	// entry cover everything before the first dated change.
	From   time.Time `yaml:"from"`
}

// MskLocation is the Europe/Moscow time zone used to anchor the default
// partner fee schedule. If the tz database cannot be loaded (for example in a
// minimal container without tzdata), a fixed UTC+3 zone is used instead of a
// nil location, which would make time.Date panic during package initialization.
var MskLocation = loadMskLocation()

// loadMskLocation loads Europe/Moscow, falling back to a fixed UTC+3 zone.
func loadMskLocation() *time.Location {
	loc, err := time.LoadLocation("Europe/Moscow")
	if err != nil {
		// Moscow has used UTC+3 year-round since October 2014, which covers the
		// dated entries of the default fee schedule.
		return time.FixedZone("MSK", 3*60*60)
	}
	return loc
}
// DefaultPricingConfig uses the RUB currency with precision 2 and the built-in
// partner fee schedule: 31.70 initially, 34.70 from 2020-08-01 and 36.20 from
// 2022-01-01, both dates at midnight in MskLocation.
var DefaultPricingConfig = PricingConfig{
	PartnerFeeHistory: []PartnerFeeChange{
		// The zero From makes this the entry for instants before any dated change.
		{From: time.Time{}, Amount: 31.70},
		{From: time.Date(2020, time.August, 1, 0, 0, 0, 0, MskLocation), Amount: 34.70},
		{From: time.Date(2022, time.January, 1, 0, 0, 0, 0, MskLocation), Amount: 36.20},
	},
	Precision: 2,
	Currency:  int32(common.ECurrency_C_RUB),
}

// Pricing computes ticket fees and partner fees. Construct it with New, which
// copies the configured partner fee history and sorts it newest-first.
type Pricing struct {
	// config supplies the currency and precision used for produced prices.
	config            *PricingConfig
	// PartnerFeeHistory is expected to be sorted newest-first by From:
	// getPartnerFees returns the first matching entry, so an unsorted slice
	// (e.g. when Pricing is built without New) yields wrong fees.
	PartnerFeeHistory []PartnerFeeChange
}

// New builds a Pricing from config.
//
// The partner fee history is copied and sorted newest-first by From, so later
// modifications of config.PartnerFeeHistory do not affect the result. The
// config pointer itself is retained and read for currency and precision.
// An error is returned when config is nil or has no partner fees.
func New(config *PricingConfig) (pricing *Pricing, err error) {
	if config == nil {
		return nil, fmt.Errorf("pricing config not set")
	}
	if len(config.PartnerFeeHistory) == 0 {
		return nil, fmt.Errorf("partner fees not set")
	}
	partnerFeeHistory := make([]PartnerFeeChange, len(config.PartnerFeeHistory))
	copy(partnerFeeHistory, config.PartnerFeeHistory)
	// A stable sort keeps the configured order for entries sharing the same
	// From, so fee lookups are deterministic even with duplicate dates.
	sort.SliceStable(partnerFeeHistory, func(i, j int) bool {
		return partnerFeeHistory[i].From.After(partnerFeeHistory[j].From)
	})
	return &Pricing{
		config:            config,
		PartnerFeeHistory: partnerFeeHistory,
	}, nil
}

// toProtoPrice converts a float amount into a proto price using the
// configured currency and precision.
func (p *Pricing) toProtoPrice(value float64) *common.TPrice {
	currency := common.ECurrency(p.config.Currency)
	precision := p.config.Precision
	return utils.ToProtoPrice(value, currency, precision)
}

// CalculateFees fills in ServiceFee, Fee and IsBanditFeeApplied for every
// ticket of charge in place and returns the same charge.
//
// Each fee is charge.Permille per mille of the corresponding price part. For
// platzkarte cars the ticket's ServiceAmount is charged separately as
// ServiceFee. For other car types, or when the service amount is not smaller
// than the total, the whole price is treated as tariff. Tickets with a zero
// amount get zero fees.
//
// When the charge carries an optional minimal tariff permille, the tariff fee
// is raised to at least tariff*minPermille plus the partner fee effective at
// now. In that case IsBanditFeeApplied is cleared for the ticket.
//
// An error is returned, before any ticket is modified, when no partner fee
// covers now or when charge, a ticket price amount, or (for a non-free ticket)
// the charge context is missing.
func (p *Pricing) CalculateFees(charge *pb.TCharge, banditFeeApplied bool, now time.Time) (*pb.TCharge, error) {
	if err := validateChargeForFees(charge); err != nil {
		return nil, err
	}

	pf, _, err := p.getPartnerFees(now)
	if err != nil {
		return nil, err
	}

	// Loop-invariant values, computed once per charge.
	ratio := PermilleToRatio(charge.Permille)
	hasMinTariff := charge.OptionalMinTariffPermille != nil
	minTariffRatio := PermilleToRatio(charge.GetMinTariffPermille())

	for _, ticketFee := range charge.TicketFees {
		ticketFee.IsBanditFeeApplied = banditFeeApplied
		if ticketFee.TicketPrice.Amount.Amount == 0 {
			ticketFee.ServiceFee = p.toProtoPrice(0.)
			ticketFee.Fee = p.toProtoPrice(0.)
			continue
		}

		serviceAmount := utils.FromProtoPrice(ticketFee.TicketPrice.ServiceAmount)
		amount := utils.FromProtoPrice(ticketFee.TicketPrice.Amount)
		// The service part is split out only for platzkarte, and only when it
		// is strictly smaller than the total price.
		if amount <= serviceAmount || charge.Context.CarType != model.CarTypePlatzkarte {
			serviceAmount = 0.
		}
		tariffAmount := amount - serviceAmount
		fee := tariffAmount * ratio

		ticketFee.ServiceFee = p.toProtoPrice(serviceAmount * ratio)
		ticketFee.Fee = p.toProtoPrice(fee)

		if hasMinTariff {
			// Never charge less than the minimal tariff share plus the partner
			// fee. When this floor wins, the Permille-based fee is overridden.
			if minimalFee := tariffAmount*minTariffRatio + pf; minimalFee > fee {
				ticketFee.Fee = p.toProtoPrice(minimalFee)
				ticketFee.IsBanditFeeApplied = false
			}
		}
	}

	return charge, nil
}

// validateChargeForFees checks the fields CalculateFees dereferences, so that
// malformed input yields an error instead of a nil-pointer panic.
func validateChargeForFees(charge *pb.TCharge) error {
	if charge == nil {
		return fmt.Errorf("charge is not set")
	}
	for i, ticketFee := range charge.TicketFees {
		amount := ticketFee.GetTicketPrice().GetAmount()
		if amount == nil {
			return fmt.Errorf("ticket fee %d: ticket price amount is not set", i)
		}
		if amount.Amount != 0 && charge.Context == nil {
			return fmt.Errorf("charge context is not set")
		}
	}
	return nil
}

// getPartnerFees returns the partner fee effective at now. The same amount is
// returned as both the partner fee and the partner refund fee.
//
// p.PartnerFeeHistory must be sorted newest-first (New does this). The first
// entry whose From is at or before now wins, so a fee change takes effect
// exactly at its From instant. An error is returned when now precedes every entry.
func (p *Pricing) getPartnerFees(now time.Time) (partnerFee float64, partnerRefundFee float64, err error) {
	for _, fee := range p.PartnerFeeHistory {
		// "Not before" means "at or after": the boundary instant belongs to
		// the new fee, not the previous one.
		if !now.Before(fee.From) {
			return fee.Amount, fee.Amount, nil
		}
	}
	return 0, 0, fmt.Errorf("not found partner fee")
}

// GetPartnerFees returns the partner fee and the partner refund fee effective
// at now as proto prices in the configured currency and precision. It fails
// when the partner fee history has no entry covering now.
func (p *Pricing) GetPartnerFees(now time.Time) (partnerFee *common.TPrice, partnerRefundFee *common.TPrice, err error) {
	fee, refundFee, lookupErr := p.getPartnerFees(now)
	if lookupErr != nil {
		return nil, nil, lookupErr
	}
	partnerFee = p.toProtoPrice(fee)
	partnerRefundFee = p.toProtoPrice(refundFee)
	return partnerFee, partnerRefundFee, nil
}

// PermilleToRatio converts a per-mille value (thousandths) into a plain
// fraction, e.g. 25 becomes 0.025 and 1000 becomes 1.
func PermilleToRatio(permille uint32) float64 {
	const perMille = 1000.0
	ratio := float64(permille) / perMille
	return ratio
}
