package schedule

import (
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/proto/dicts/rasp"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/date/runmask"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/dict/registry"
)

// Cache holds precomputed schedule lookups for train threads. It is created
// empty by NewCache and filled from the repositories in registry by Build.
type Cache struct {
	// logger receives debug output emitted while building the cache.
	logger   log.Structured
	// registry is the source of threads, routes, stations, thread stations
	// and station-to-settlement links read by Build.
	registry *registry.RepositoryRegistry

	// threads maps a thread ID to each thread selected by getEnabledThreads.
	threads          map[int]*rasp.TThread
	// stationStops indexes cached thread stops by station ID.
	stationStops     pointToStops
	// settlementStops indexes cached thread stops by the ID of every
	// settlement linked to the stop's station.
	settlementStops  pointToStops
	// threadsArrival maps a thread ID to its stop with HasDeparture == false
	// (the last such stop if there are several).
	threadsArrival   map[int]*rasp.TThreadStation
	// threadsDeparture maps a thread ID to its stop with HasArrival == false
	// (the last such stop if there are several).
	threadsDeparture map[int]*rasp.TThreadStation
}

// NewCache returns an empty Cache bound to logger and registry, with all of
// its lookup maps allocated. Call Build on the result to populate it.
func NewCache(logger log.Structured, registry *registry.RepositoryRegistry) *Cache {
	c := &Cache{
		logger:   logger,
		registry: registry,
	}

	// Allocate every map up front so Build can write to them without
	// nil-map checks.
	c.threads = make(map[int]*rasp.TThread)
	c.threadsArrival = make(map[int]*rasp.TThreadStation)
	c.threadsDeparture = make(map[int]*rasp.TThreadStation)
	c.stationStops = make(pointToStops)
	c.settlementStops = make(pointToStops)

	return c
}

// Build populates the cache from the repositories in the registry and returns
// the receiver so the call can be chained after NewCache.
//
// For every thread returned by getEnabledThreads it records:
//   - the thread itself in threads;
//   - its stop with HasArrival == false in threadsDeparture and its stop with
//     HasDeparture == false in threadsArrival (if several stops qualify, the
//     last one in getThreadStations order wins);
//   - each stop in stationStops under the stop's station ID, and in
//     settlementStops under every settlement ID linked to that station.
//
// A stop whose HasArrival and HasDeparture flags are equal and whose
// ArrivalTz equals its DepartureTz is still considered for threadsArrival and
// threadsDeparture, but is left out of stationStops and settlementStops.
func (c *Cache) Build() *Cache {
	settlementsByStation := c.getStationSettlementIDs()

	for _, thread := range c.getEnabledThreads() {
		id := int(thread.Id)
		c.threads[id] = thread

		for _, stop := range c.getThreadStations(thread.Id) {
			if !stop.HasArrival {
				c.threadsDeparture[id] = stop
			}
			if !stop.HasDeparture {
				c.threadsArrival[id] = stop
			}

			// NOTE(review): presumably this identifies stops where the train
			// does not actually stand (equal arrival and departure values);
			// confirm against the rasp data model.
			sameFlags := stop.HasDeparture == stop.HasArrival
			sameTz := stop.DepartureTz == stop.ArrivalTz
			if sameFlags && sameTz {
				c.logger.Debug("skip caching threadStation because DepartureTz == ArrivalTz")
				continue
			}

			stationID := int(stop.StationId)
			c.stationStops.addThreadStation(stationID, id, stop)

			// Ranging over a missing (nil) set is a no-op, so stations without
			// settlements are simply skipped here.
			for settlementID := range settlementsByStation[stationID] {
				c.settlementStops.addThreadStation(settlementID, id, stop)
			}
		}
	}

	return c
}

// getStationSettlementIDs returns, for each station ID, the set of settlement
// IDs associated with it. Two sources are merged:
//   - a station's own SettlementId, when it is non-zero;
//   - every entry of the station-to-settlement repository, unconditionally.
//
// Stations with no associated settlement are absent from the result.
func (c *Cache) getStationSettlementIDs() map[int]map[int]struct{} {
	index := make(map[int]map[int]struct{})

	// link adds settlementID to stationID's set, creating the set on first use.
	link := func(stationID, settlementID int) {
		set := index[stationID]
		if set == nil {
			set = make(map[int]struct{})
			index[stationID] = set
		}
		set[settlementID] = struct{}{}
	}

	for _, station := range c.registry.GetStationRepo().All() {
		if station.SettlementId != 0 {
			link(int(station.Id), int(station.SettlementId))
		}
	}

	for _, rel := range c.registry.GetStationToSettlementRepo().All() {
		link(int(rel.StationId), int(rel.SettlementId))
	}

	return index
}

// getEnabledThreads returns the threads from the thread repository that should
// be cached. A thread is kept only if all of the following hold:
//   - its transport type is TTransport_TYPE_TRAIN;
//   - its type is not TThread_TYPE_CANCEL;
//   - runmask.IsEmptyYearDays reports its YearDays as non-empty;
//   - it is not hidden;
//   - its route is not hidden (a thread whose route is missing from the
//     route repository is kept).
func (c *Cache) getEnabledThreads() []*rasp.TThread {
	routes := c.registry.GetRouteRepo()

	var enabled []*rasp.TThread
	for _, t := range c.registry.GetThreadRepo().All() {
		// Case expressions are evaluated left to right and stop at the first
		// true one, matching a chain of early-continue checks.
		switch {
		case t.TransportType != rasp.TTransport_TYPE_TRAIN,
			t.Type == rasp.TThread_TYPE_CANCEL,
			runmask.IsEmptyYearDays(t.YearDays),
			t.IsHidden:
			continue
		}

		if route, ok := routes.Get(t.RouteId); ok && route.IsHidden {
			continue
		}

		enabled = append(enabled, t)
	}
	return enabled
}

// getThreadStations returns the stops of the thread with the given ID, in the
// order provided by the thread-station repository, after dropping:
//   - technical stops (IsTechnicalStop);
//   - stops whose station is not found in the station repository;
//   - stops whose station's MajorityId is at or above
//     TStation_MAJORITY_NOT_IN_SEARCH.
func (c *Cache) getThreadStations(threadID int32) []*rasp.TThreadStation {
	// The station repository does not change while iterating, so resolve it
	// once instead of per stop (mirrors GetRouteRepo in getEnabledThreads).
	stationRepo := c.registry.GetStationRepo()

	var result []*rasp.TThreadStation
	for _, threadStation := range c.registry.GetThreadStationRepo().GetByThreadStationID(threadID) {
		if threadStation.IsTechnicalStop {
			continue
		}

		// The found flag is not consulted: the nil check below rejects a
		// missing station. NOTE(review): assumes Get never returns a non-nil
		// station together with found == false.
		station, _ := stationRepo.Get(threadStation.StationId)
		if station == nil || station.MajorityId >= int32(rasp.TStation_MAJORITY_NOT_IN_SEARCH) {
			continue
		}
		result = append(result, threadStation)
	}
	return result
}
