package pricecalendar

import (
	"context"
	"fmt"
	"sync"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/units"
	"a.yandex-team.ru/travel/library/go/httputil"
	commonModels "a.yandex-team.ru/travel/trains/library/go/httputil/clients/common/models"

	pcpb "a.yandex-team.ru/travel/trains/search_api/api/price_calendar"
	"a.yandex-team.ru/travel/trains/search_api/internal/direction/segments"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/date"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/dict/registry"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/express"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/feeservice"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/helpers"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/points"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/railway"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/schedule"
	tariffcache "a.yandex-team.ru/travel/trains/search_api/internal/pkg/tariffs/cache"
	"a.yandex-team.ru/travel/trains/search_api/internal/searcher/models"
)

// Service returns per-day train price calendars for a pair of points.
//
// The response carries two directions: Forward (pointFrom -> pointTo) and
// Backward (pointTo -> pointFrom). The implementation in this file is
// TariffCacheBasedService.
type Service interface {
	// PriceCalendar returns the calendar for the default date window
	// (TariffCacheBasedService uses today plus MaxSaleDepthDays days, in the
	// departure point's time zone).
	PriceCalendar(
		ctx context.Context, pointFrom, pointTo string, userArgs models.UserArgs,
	) (*pcpb.PriceCalendarResponse, error)
	// PriceCalendarRange returns the calendar for the window dateFrom..dateTo.
	// In TariffCacheBasedService an empty dateFrom or dateTo falls back to the
	// default window used by PriceCalendar.
	PriceCalendarRange(
		ctx context.Context, pointFrom, pointTo, dateFrom, dateTo string, userArgs models.UserArgs,
	) (*pcpb.PriceCalendarResponse, error)
}

// MaxSaleDepthDays is the length, in days, of the default calendar window used
// when the caller does not request an explicit date range.
const MaxSaleDepthDays = 120

// emptyPriceReasonImportance ranks empty-price reasons, lowest weight first.
// When several segments depart on the same day, updateDayPrice keeps the
// reason with the strictly higher weight. Reasons missing from this table
// weigh 0, the same as NO_DIRECT_TRAINS.
var emptyPriceReasonImportance = map[pcpb.EmptyPriceReason]int32{
	pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_NO_DIRECT_TRAINS: 0,
	pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_OTHER:            30,
	pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_SOLD_OUT:         60,
	pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_INVALID:          100,
}

// TariffCacheBasedService builds price calendars from tariffs stored in a
// tariff cache merged with schedule segments. Its PriceCalendar and
// PriceCalendarRange methods match the Service interface. Construct it with
// NewService.
type TariffCacheBasedService struct {
	// pointParser resolves the pointFrom/pointTo keys passed by callers.
	pointParser *points.Parser
	// tariffCache is queried by express IDs and a departure time window.
	tariffCache *tariffcache.TariffCache
	// repoRegistry provides dictionaries (time zones, railway location lookup).
	repoRegistry *registry.RepositoryRegistry
	// expressRepository maps points to the express IDs used by tariffCache.
	expressRepository *express.Repository
	// segmentsMapper merges schedule segments with tariffs; built in NewService.
	segmentsMapper *segments.Mapper
	// scheduleRepository looks up scheduled segments between two points.
	scheduleRepository *schedule.Repository
	// feeService applies fees to train segment prices (best effort).
	feeService *feeservice.FeeService
	logger     log.Logger
}

// NewService wires a TariffCacheBasedService from its dependencies.
//
// The segments mapper is not injected: it is created here from logger and
// repoRegistry.
func NewService(
	pointParser *points.Parser,
	tariffCache *tariffcache.TariffCache,
	repoRegistry *registry.RepositoryRegistry,
	expressRepository *express.Repository,
	scheduleRepository *schedule.Repository,
	feeService *feeservice.FeeService,
	logger log.Logger,
) *TariffCacheBasedService {
	mapper := segments.NewMapper(logger, repoRegistry)
	return &TariffCacheBasedService{
		logger:             logger,
		pointParser:        pointParser,
		repoRegistry:       repoRegistry,
		segmentsMapper:     mapper,
		tariffCache:        tariffCache,
		expressRepository:  expressRepository,
		scheduleRepository: scheduleRepository,
		feeService:         feeService,
	}
}

// PriceCalendar returns the price calendar for the default window: the
// current date in the departure point's time zone plus MaxSaleDepthDays days.
// It is PriceCalendarRange called with an empty date range.
func (s *TariffCacheBasedService) PriceCalendar(
	ctx context.Context, pointFrom, pointTo string, userArgs models.UserArgs,
) (*pcpb.PriceCalendarResponse, error) {
	const defaultWindow = "" // an empty date selects the default window
	return s.PriceCalendarRange(ctx, pointFrom, pointTo, defaultWindow, defaultWindow, userArgs)
}

// PriceCalendarRange builds the price calendar for pointFrom -> pointTo
// (Forward) and pointTo -> pointFrom (Backward) over dateFrom..dateTo.
// When either date is empty, oneWayPrices uses its default window instead.
//
// Both directions are computed concurrently. If both fail, the forward error
// is the one returned.
//
// Errors:
//   - *PointNotFoundError when pointFrom or pointTo cannot be parsed;
//   - otherwise an error wrapping the failing direction's oneWayPrices error.
//
// A failure to parse userArgs.UaasExperiments is only logged, and the value
// returned by ParseExperiments is still used.
// NOTE(review): confirm that value is safe to read when err != nil.
func (s *TariffCacheBasedService) PriceCalendarRange(
	ctx context.Context, pointFrom, pointTo, dateFrom, dateTo string, userArgs models.UserArgs,
) (*pcpb.PriceCalendarResponse, error) {
	funcName := "TariffCacheBasedService.PriceCalendar"

	from, err := s.pointParser.ParseByPointKey(pointFrom)
	if err != nil {
		s.logger.Errorf("%s: not found pointFrom %s: %s", funcName, pointFrom, err.Error())
		return nil, &PointNotFoundError{paramName: "pointFrom", paramValue: pointFrom}
	}
	to, err := s.pointParser.ParseByPointKey(pointTo)
	if err != nil {
		s.logger.Errorf("%s: not found pointTo %s: %s", funcName, pointTo, err.Error())
		return nil, &PointNotFoundError{paramName: "pointTo", paramValue: pointTo}
	}
	experiments, err := httputil.ParseExperiments(userArgs.UaasExperiments)
	if err != nil {
		// Best effort: experiments only affect fee calculation.
		s.logger.Errorf("%s: error parse experiments='%s': %s", funcName, userArgs.UaasExperiments, err.Error())
	}
	banditType := experiments.TrainsBanditType

	var wg sync.WaitGroup
	var forwardPrices, backwardPrices *pcpb.DirectionPrices
	var forwardErr, backwardErr error

	wg.Add(2)
	go func() {
		defer wg.Done()
		forwardPrices, forwardErr = s.oneWayPrices(ctx, from, to, dateFrom, dateTo, banditType, userArgs.Icookie)
	}()
	go func() {
		defer wg.Done()
		backwardPrices, backwardErr = s.oneWayPrices(ctx, to, from, dateFrom, dateTo, banditType, userArgs.Icookie)
	}()
	wg.Wait()

	// Wrap the per-direction error. The outer `err` only holds the
	// (non-fatal) experiments parse result and must not be used here.
	if forwardErr != nil {
		return nil, fmt.Errorf("%s: error getting forward prices (%s-->%s): %w", funcName, from.Slug(), to.Slug(), forwardErr)
	}
	if backwardErr != nil {
		return nil, fmt.Errorf("%s: error getting backward prices (%s<--%s): %w", funcName, from.Slug(), to.Slug(), backwardErr)
	}
	return &pcpb.PriceCalendarResponse{
		Forward: forwardPrices, Backward: backwardPrices,
	}, nil
}

// oneWayPrices builds per-day minimal prices for trains from -> to.
//
// The date window is interpreted in the time zone of `from`:
//   - if dateFrom or dateTo is empty: from the current date (date.DateFromTimeLocal
//     of now) to MaxSaleDepthDays calendar days later;
//   - otherwise: from dateFrom to one calendar day after dateTo, so dateTo is
//     included.
//
// Fees are applied with a one-second timeout. A fee error is only logged and
// processing continues with trainSegments as ApplyTrainSegmentsFee left them.
//
// The returned Dates list starts at the window start and runs up to the day of
// the latest segment departure, and always contains at least one day. Days
// without segments keep EMPTY_PRICE_REASON_NO_DIRECT_TRAINS.
func (s *TariffCacheBasedService) oneWayPrices(
	ctx context.Context, from, to points.Point, dateFrom, dateTo, banditType, icookie string,
) (*pcpb.DirectionPrices, error) {
	funcName := "TariffCacheBasedService.oneWayPrices"

	protoZone, ok := s.repoRegistry.GetTimeZoneRepo().Get(from.TimeZoneID())
	if !ok {
		return nil, fmt.Errorf("%s: not found time zone for %v", funcName, from.Slug())
	}
	loc, err := time.LoadLocation(protoZone.Code)
	if err != nil {
		return nil, fmt.Errorf("%s: can not load time zone %s, %s: %w", funcName, protoZone.Code, from.Slug(), err)
	}

	// Days are added with AddDate rather than fixed 24h multiples so the window
	// bounds keep their wall-clock time across DST transitions in `loc`.
	var timeFrom, timeTo time.Time
	if dateFrom == "" || dateTo == "" {
		timeFrom = date.DateFromTimeLocal(time.Now().In(loc))
		timeTo = timeFrom.AddDate(0, 0, MaxSaleDepthDays)
	} else {
		timeFrom, err = date.DateFromStringInLocation(dateFrom, loc)
		if err != nil {
			return nil, fmt.Errorf("%s: can not parse date: %w", funcName, err)
		}
		timeTo, err = date.DateFromStringInLocation(dateTo, loc)
		if err != nil {
			return nil, fmt.Errorf("%s: can not parse date: %w", funcName, err)
		}
		// dateTo is inclusive: push the upper bound one calendar day later.
		timeTo = timeTo.AddDate(0, 0, 1)
	}

	trainSegments, err := s.getTrainSegments(ctx, from, to, timeFrom, timeTo)
	if err != nil {
		return nil, fmt.Errorf("%s: error getting segments: %w", funcName, err)
	}

	feeCtx, cancel := context.WithTimeout(ctx, time.Second)
	err = s.feeService.ApplyTrainSegmentsFee(feeCtx, trainSegments, banditType, icookie, feeservice.SegmentShortContextProvider)
	cancel() // release the timer now rather than at function exit
	if err != nil {
		s.logger.Errorf("%s: error apply fee: %s", funcName, err.Error())
	}

	// Starting one minute past timeFrom guarantees at least one day in the list
	// even when there are no segments.
	maxDepartureDt := timeFrom.Add(time.Minute)
	for _, segment := range trainSegments {
		if maxDepartureDt.Before(segment.DepartureLocalDt) {
			maxDepartureDt = segment.DepartureLocalDt
		}
	}

	dayPriceList, dayPriceByDateMap := s.generateEmptyDayPriceList(timeFrom, maxDepartureDt)
	for _, segment := range trainSegments {
		departureDateStr := date.DateFromTimeLocal(segment.DepartureLocalDt).Format(date.DateISOFormat)
		if dayPrice, ok := dayPriceByDateMap[departureDateStr]; ok {
			updateDayPrice(dayPrice, segment)
		} else {
			s.logger.Errorf("%s: not found dayPrice for date=%s, from=%s, to=%s", funcName, departureDateStr, from.Slug(), to.Slug())
		}
	}

	return &pcpb.DirectionPrices{Dates: dayPriceList}, nil
}

// getTrainSegments loads schedule segments and cached tariffs for from -> to
// and merges them into train segments via the segments mapper.
//
// Tariffs are selected by the express IDs of both points for departures
// between departureFrom and departureTo. The schedule lookup receives only
// departureFrom, not departureTo. The railway location used for mapping is
// derived from `from`. Only a tariff cache failure is reported as an error.
func (s *TariffCacheBasedService) getTrainSegments(
	ctx context.Context,
	from points.Point, to points.Point,
	departureFrom time.Time, departureTo time.Time,
) (segments.TrainSegments, error) {
	const funcName = "TariffCacheBasedService.getTrainSegments"

	location := railway.GetLocationByPoint(from, s.repoRegistry)
	scheduled := s.scheduleRepository.FindSegments(ctx, from, to, departureFrom)

	departureIDs := []int32{int32(s.expressRepository.FindExpressID(from))}
	arrivalIDs := []int32{int32(s.expressRepository.FindExpressID(to))}

	tariffs, err := s.tariffCache.Select(ctx, departureIDs, arrivalIDs, departureFrom, departureTo)
	if err != nil {
		return nil, fmt.Errorf("%s: can not load tariffs: %w", funcName, err)
	}
	return s.segmentsMapper.MapTrainSegments(ctx, location, scheduled, tariffs), nil
}

// generateEmptyDayPriceList creates one placeholder DayPrice per calendar day,
// starting at timeFrom, for every day start that is before maxDepartureDt.
// Each entry has a nil Price and EMPTY_PRICE_REASON_NO_DIRECT_TRAINS.
//
// It returns the entries in chronological order, plus the same pointers keyed
// by their date.DateISOFormat string so callers can update them in place.
// The list is nil when no day fits.
//
// Day k is computed as timeFrom.AddDate(0, 0, k) rather than by repeatedly
// adding 24h. With a fixed 24h step, a DST fall-back in timeFrom's location
// moves a local midnight to 23:00 of the same date. That yields a duplicate
// date key and shifts every following day.
func (s *TariffCacheBasedService) generateEmptyDayPriceList(timeFrom time.Time, maxDepartureDt time.Time) ([]*pcpb.DayPrice, map[string]*pcpb.DayPrice) {
	// Size hint only; it may be off by one around DST transitions.
	days := 0
	if span := maxDepartureDt.Sub(timeFrom); span > 0 {
		days = int(span/units.Day) + 1
	}

	var dayPriceList []*pcpb.DayPrice
	if days > 0 {
		dayPriceList = make([]*pcpb.DayPrice, 0, days)
	}
	pricesMap := make(map[string]*pcpb.DayPrice, days)

	for offset := 0; ; offset++ {
		currentDt := timeFrom.AddDate(0, 0, offset)
		if !currentDt.Before(maxDepartureDt) {
			break
		}
		departureDateStr := currentDt.Format(date.DateISOFormat)
		dayPrice := &pcpb.DayPrice{
			Date:             departureDateStr,
			Price:            nil,
			EmptyPriceReason: pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_NO_DIRECT_TRAINS,
		}
		dayPriceList = append(dayPriceList, dayPrice)
		pricesMap[departureDateStr] = dayPrice
	}
	return dayPriceList, pricesMap
}

// getEmptyPriceReason explains why a segment has no price.
//
// It returns SOLD_OUT when any class list in segment.BrokenClasses (Common,
// Sitting, Platzkarte, Compartment, Soft, Suite) contains
// commonModels.TariffErrorSoldOut. It returns OTHER in every other case,
// including a nil BrokenClasses.
func getEmptyPriceReason(segment *segments.TrainSegment) pcpb.EmptyPriceReason {
	broken := segment.BrokenClasses
	if broken == nil {
		return pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_OTHER
	}

	perClass := [][]uint32{
		broken.Common, broken.Sitting, broken.Platzkarte,
		broken.Compartment, broken.Soft, broken.Suite,
	}
	for _, codes := range perClass {
		for _, code := range codes {
			if code == commonModels.TariffErrorSoldOut {
				return pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_SOLD_OUT
			}
		}
	}
	return pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_OTHER
}

// updateDayPrice merges one segment into the aggregate for its departure day.
//
// Segment with a price:
//   - dayPrice.Price is set to that price if the day had none, otherwise it is
//     combined with it via helpers.GetMinPrice;
//   - the segment contributes EMPTY_PRICE_REASON_INVALID as its reason.
//
// Segment without a price: it contributes getEmptyPriceReason(segment).
//
// The day's EmptyPriceReason is replaced only by a reason with a strictly
// higher weight in emptyPriceReasonImportance. INVALID carries the top weight,
// so any priced segment sets the day's reason to INVALID.
// NOTE(review): presumably INVALID means "no empty-price reason" here; confirm
// against the proto definition.
func updateDayPrice(dayPrice *pcpb.DayPrice, segment *segments.TrainSegment) {
	reason := pcpb.EmptyPriceReason_EMPTY_PRICE_REASON_INVALID
	switch price := segment.GetMinPrice(); {
	case price == nil:
		reason = getEmptyPriceReason(segment)
	case dayPrice.Price == nil:
		dayPrice.Price = price
	default:
		dayPrice.Price = helpers.GetMinPrice(dayPrice.Price, price)
	}

	if emptyPriceReasonImportance[reason] > emptyPriceReasonImportance[dayPrice.EmptyPriceReason] {
		dayPrice.EmptyPriceReason = reason
	}
}
