package factories

import (
	"time"

	"a.yandex-team.ru/travel/proto/dicts/rasp"
	dictfactories "a.yandex-team.ru/travel/trains/search_api/internal/pkg/dict/factories"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/dict/registry"
)

// Timing defaults used when generating a schedule.
const (
	// defaultStartTime is the start time a factory receives from NewScheduleFactory.
	defaultStartTime = time.Duration(0)
	// defaultStayTime is the stay time given to stops generated by fillDefaultStops.
	defaultStayTime = time.Hour
	// timeBetweenStops is the gap between departing one stop and arriving at the next.
	timeBetweenStops = 2 * time.Hour
)

// scheduleStop describes one stop of the schedule being built: the station
// and how long the thread stays there.
type scheduleStop struct {
	// station is the station passed to the thread-station factory for this stop.
	station *rasp.TStation
	// stayTime is added to the arrival time to obtain the departure time.
	// createStops reads it only for intermediate stops; the first and last
	// stops ignore it.
	stayTime time.Duration
}

// ScheduleFactory builds a thread together with its ordered stops by
// delegating to the dict factories. Configure it with WithThread and AddStop,
// then call Create.
type ScheduleFactory struct {
	// thread is the thread the stops are attached to; Create generates one
	// when it is nil.
	thread *rasp.TThread
	// startTime is the departure value used for the first stop.
	startTime time.Duration
	// stops holds the stops in the order they were added.
	stops []scheduleStop
	// repoRegistry is handed to every dict factory this factory invokes.
	repoRegistry *registry.RepositoryRegistry
}

// NewScheduleFactory returns a ScheduleFactory bound to repoRegistry whose
// start time is defaultStartTime. No thread or stops are set yet.
func NewScheduleFactory(repoRegistry *registry.RepositoryRegistry) *ScheduleFactory {
	factory := new(ScheduleFactory)
	factory.repoRegistry = repoRegistry
	factory.startTime = defaultStartTime
	return factory
}

// Create materializes the schedule and returns its thread.
//
// A thread is generated through the thread factory when none was supplied via
// WithThread. When no stops were added, two default stops are generated. It
// panics if exactly one stop was added, since a schedule needs at least two.
func (f *ScheduleFactory) Create() *rasp.TThread {
	if f.thread == nil {
		f.thread = dictfactories.NewThreadFactory(f.repoRegistry).Create()
	}

	switch len(f.stops) {
	case 0:
		f.fillDefaultStops()
	case 1:
		// A lone stop cannot form a route.
		panic("too few stops in schedule")
	}

	f.createStops()
	return f.thread
}

// AddStop appends a stop at station with the given stayTime to the end of the
// schedule and returns the factory so calls can be chained.
func (f *ScheduleFactory) AddStop(station *rasp.TStation, stayTime time.Duration) *ScheduleFactory {
	stop := scheduleStop{station: station, stayTime: stayTime}
	f.stops = append(f.stops, stop)
	return f
}

// WithThread sets the thread the stops will be attached to, replacing any
// previously set one, and returns the factory so calls can be chained.
func (f *ScheduleFactory) WithThread(thread *rasp.TThread) *ScheduleFactory {
	f.thread = thread
	return f
}

// fillDefaultStops replaces f.stops with two stops, each at a station freshly
// produced by the station factory and each with defaultStayTime.
func (f *ScheduleFactory) fillDefaultStops() {
	const defaultStopCount = 2

	defaults := make([]scheduleStop, defaultStopCount)
	for idx := range defaults {
		defaults[idx] = scheduleStop{
			station:  dictfactories.NewStationFactory(f.repoRegistry).Create(),
			stayTime: defaultStayTime,
		}
	}
	f.stops = defaults
}

// createStops creates one thread-station record per stop, in order, through
// the thread-station factory.
//
// The first stop gets only a departure (f.startTime), the last stop only an
// arrival, and every intermediate stop gets both. Each arrival is
// timeBetweenStops after the previous departure; each intermediate departure
// is the arrival plus that stop's stayTime.
func (f *ScheduleFactory) createStops() {
	lastIdx := len(f.stops) - 1

	// First stop: departure only.
	dictfactories.NewThreadStationFactory(f.repoRegistry).
		WithThread(f.thread).
		WithStation(f.stops[0].station).
		WithDeparture(f.startTime).
		Create()

	// clock always holds the arrival time at the next stop to be created.
	clock := f.startTime + timeBetweenStops
	for idx := 1; idx < lastIdx; idx++ {
		stop := f.stops[idx]
		arrival := clock
		departure := arrival + stop.stayTime

		dictfactories.NewThreadStationFactory(f.repoRegistry).
			WithThread(f.thread).
			WithStation(stop.station).
			WithArrival(arrival).
			WithDeparture(departure).
			Create()

		clock = departure + timeBetweenStops
	}

	// Last stop: arrival only.
	dictfactories.NewThreadStationFactory(f.repoRegistry).
		WithThread(f.thread).
		WithStation(f.stops[lastIdx].station).
		WithArrival(clock).
		Create()
}
