package app

import (
	"fmt"
	"io/ioutil"
	"os"

	"a.yandex-team.ru/library/go/core/log/zap"
	tpb "a.yandex-team.ru/travel/proto"
	"github.com/golang/protobuf/proto"
	"gopkg.in/yaml.v2"

	"a.yandex-team.ru/travel/buses/backend/internal/common/dict"
	pb "a.yandex-team.ru/travel/buses/backend/proto"
	wpb "a.yandex-team.ru/travel/buses/backend/proto/worker"
)

// CommonData holds billing settings shared by all partners.
type CommonData struct {
	ServiceID int32 `yaml:"service-id"`
}

// RatesData holds a partner's percentage rates. Both values are multiplied
// by 0.01 before being applied to a price elsewhere in this file, i.e. they
// are expressed in percent.
type RatesData struct {
	YandexFee float32 `yaml:"yandex-fee"`
	Revenue   float32 `yaml:"revenue"`
}

// ProductIdsData holds the product identifiers configured for a partner.
type ProductIdsData struct {
	Ticket     string `yaml:"ticket"`
	PartnerFee string `yaml:"partner-fee"`
	YandexFee  string `yaml:"yandex-fee"`
}

// PartnerData is the billing configuration of a single partner.
// Rates has no explicit tag, so yaml.v2 decodes it from the lowercased
// field name ("rates").
type PartnerData struct {
	Rates      RatesData
	ProductIds ProductIdsData `yaml:"product-ids"`
}

// BillingData is the root of the billing configuration. Partners is keyed
// by supplier name (see getRideRates / WithRevenue). Untagged fields are
// decoded by yaml.v2 from their lowercased names ("common", "partners").
type BillingData struct {
	Common   CommonData
	Partners map[string]PartnerData
}

// LoadBillingData reads the YAML billing configuration at configPath and
// decodes it into a BillingData.
//
// The returned pointer is never nil: on error it points to an empty or
// partially decoded BillingData, so callers must check err before using it.
func LoadBillingData(configPath string) (*BillingData, error) {
	d := &BillingData{}
	f, err := os.Open(configPath)
	if err != nil {
		return d, fmt.Errorf("LoadBillingData error: %w", err)
	}
	// The descriptor was previously leaked on every call. The file is only
	// read, so a Close error carries no data-loss risk and is deliberately
	// ignored.
	defer f.Close()

	buf, err := ioutil.ReadAll(f)
	if err != nil {
		return d, fmt.Errorf("LoadBillingData error: %w", err)
	}
	if err := yaml.Unmarshal(buf, d); err != nil {
		return d, fmt.Errorf("LoadBillingData error: %w", err)
	}

	return d, nil
}

// calcYandexFee returns rates.YandexFee percent of price, truncated toward
// zero, in the same currency and precision as price.
//
// NOTE(review): the arithmetic is done in float32, so amounts above 2^24
// minor units lose precision. It is kept as-is because switching to float64
// changes the truncated result for float32 rates such as 3.3.
func (billingData *BillingData) calcYandexFee(rates *RatesData, price *tpb.TPrice) *tpb.TPrice {
	scaled := float32(price.Amount) * rates.YandexFee
	fee := &tpb.TPrice{
		Currency:  price.Currency,
		Precision: price.Precision,
	}
	fee.Amount = int64(scaled * 0.01)
	return fee
}

// getRideRates resolves the ride's supplier through dict.GetSupplier and
// returns that supplier's configured rates. It fails when the supplier
// cannot be resolved or has no entry in Partners.
func (billingData *BillingData) getRideRates(ride *pb.TRide) (*RatesData, error) {
	supplierID := ride.Supplier.GetID()
	supplier, err := dict.GetSupplier(supplierID)
	if err != nil {
		return nil, fmt.Errorf("getRideRates: %w", err)
	}
	if partner, found := billingData.Partners[supplier.Name]; found {
		// partner is a copy of the map value, so the pointer does not alias the map.
		return &partner.Rates, nil
	}
	return nil, fmt.Errorf("getRideRates: no billing data for supplier %s", supplier.Name)
}

// addYandexFeeToRide returns a deep copy of ride with YandexFee set and the
// fee added to Price. The input ride is not modified.
func (billingData *BillingData) addYandexFeeToRide(rates *RatesData, ride *pb.TRide) *pb.TRide {
	result := proto.Clone(ride).(*pb.TRide)
	fee := billingData.calcYandexFee(rates, result.Price)
	result.YandexFee = fee
	result.Price.Amount = result.Price.Amount + fee.Amount
	return result
}

// addYandexFeeToBookParams returns a deep copy of bookParams in which every
// ticket type has YandexFee set and the fee added to its Price. The input is
// not modified.
func (billingData *BillingData) addYandexFeeToBookParams(rates *RatesData, bookParams *pb.TBookParams) *pb.TBookParams {
	cloned := proto.Clone(bookParams).(*pb.TBookParams)
	for i := range cloned.TicketTypes {
		tt := cloned.TicketTypes[i]
		fee := billingData.calcYandexFee(rates, tt.Price)
		tt.YandexFee = fee
		tt.Price.Amount += fee.Amount
	}
	return cloned
}

// WithYandexFee returns copies of ride and bookParams with the supplier's
// Yandex fee applied to each price, using the rates of ride's supplier.
// Inputs are not modified. On error both returned messages are nil.
func (billingData *BillingData) WithYandexFee(ride *pb.TRide, bookParams *pb.TBookParams) (*pb.TRide, *pb.TBookParams, error) {
	rates, err := billingData.getRideRates(ride)
	if err != nil {
		// Wrapped for consistency with RideWithYandexFee / WithRevenue;
		// %w keeps errors.Is / errors.As working for callers.
		return nil, nil, fmt.Errorf("WithYandexFee: %w", err)
	}
	return billingData.addYandexFeeToRide(rates, ride), billingData.addYandexFeeToBookParams(rates, bookParams), nil
}

// RidesWithYandexFee returns a new slice, parallel to rides, where each ride
// has the Yandex fee applied. Rides the fee cannot be computed for are logged
// and passed through unchanged (the original pointer is kept).
func (billingData *BillingData) RidesWithYandexFee(rides []*pb.TRide, logger *zap.Logger) []*pb.TRide {
	withFee := make([]*pb.TRide, len(rides))
	for idx := range rides {
		original := rides[idx]
		updated, err := billingData.RideWithYandexFee(original)
		if err != nil {
			logger.Errorf("cannot add yandex fee: %s", err.Error())
			withFee[idx] = original
			continue
		}
		withFee[idx] = updated
	}
	return withFee
}

// RideWithYandexFee returns a copy of ride with its supplier's Yandex fee
// applied. The input ride is not modified.
func (billingData *BillingData) RideWithYandexFee(ride *pb.TRide) (*pb.TRide, error) {
	if rates, err := billingData.getRideRates(ride); err != nil {
		return nil, fmt.Errorf("RideWithYandexFee: %w", err)
	} else {
		return billingData.addYandexFeeToRide(rates, ride), nil
	}
}

// GetRevenue returns the revenue rate (in percent) configured for
// supplierName, or 0 when the supplier has no billing entry.
func (billingData *BillingData) GetRevenue(supplierName string) float32 {
	// A missing key (or a nil map) yields a zero PartnerData, whose Revenue is 0.
	return billingData.Partners[supplierName].Rates.Revenue
}

// WithRevenue returns a deep copy of order in which every ticket lacking a
// Revenue gets one computed as the supplier's revenue percentage of the
// ticket price (truncated toward zero), in the price's currency and
// precision. Tickets that already carry a Revenue are left untouched.
// Suppliers without billing data yield a zero rate (see GetRevenue).
func (billingData *BillingData) WithRevenue(order *wpb.TOrder, supplierID uint32) (*wpb.TOrder, error) {
	supplier, err := dict.GetSupplier(supplierID)
	if err != nil {
		return nil, fmt.Errorf("WithRevenue: %w", err)
	}
	revenueRate := float64(billingData.GetRevenue(supplier.Name)) * 0.01

	result := proto.Clone(order).(*wpb.TOrder)
	for _, ticket := range result.Tickets {
		if ticket.Revenue != nil {
			continue
		}
		price := ticket.Price
		ticket.Revenue = &tpb.TPrice{
			Amount:    int64(revenueRate * float64(price.GetAmount())),
			Currency:  price.GetCurrency(),
			Precision: price.GetPrecision(),
		}
	}
	return result, nil
}
