package schedule

import (
	"context"
	"math"
	"sync"
	"sync/atomic"
	"time"

	"github.com/opentracing/opentracing-go"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/proto/dicts/rasp"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/dict/registry"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/limitcondition"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/points"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/updater"
)

const (
	// maxMajorityID is the fallback majority that makeSegment assigns to a
	// departure or arrival station it cannot find in the station repository.
	// Being the largest int32 it presumably ranks such stations after every
	// known majority — NOTE(review): confirm against filterSegmentsByMajority.
	maxMajorityID = math.MaxInt32
)

// Repository answers train-segment queries from an in-memory *Cache built out
// of the Station, Thread and ThreadStation dictionaries of a
// registry.RepositoryRegistry.
type Repository struct {
	logger       log.Structured
	registry     *registry.RepositoryRegistry
	segmentMaker SegmentProvider

	// updateMutex serializes cache rebuilds in UpdateCache.
	updateMutex sync.Mutex

	// cacheValue holds the current *Cache; readers load it without locking.
	cacheValue atomic.Value

	// validCache is 1 while cacheValue reflects the latest dictionaries and 0
	// once a dictionary update has invalidated it. It is accessed only through
	// sync/atomic because OnDictUpdate does not take updateMutex and may run
	// concurrently with UpdateCache.
	validCache int32
}

// NewRepository creates a Repository backed by repoRegistry and subscribes it
// to Station, Thread and ThreadStation dictionary updates. No cache is stored
// until the first UpdateCache call.
func NewRepository(logger log.Logger, repoRegistry *registry.RepositoryRegistry) *Repository {
	repo := &Repository{
		logger:   logger.Structured(),
		registry: repoRegistry,
		segmentMaker: SegmentProvider{
			logger:   logger,
			registry: repoRegistry,
		},
	}

	repoRegistry.AddRepositoryObserver(registry.Station, repo)
	repoRegistry.AddRepositoryObserver(registry.Thread, repo)
	repoRegistry.AddRepositoryObserver(registry.ThreadStation, repo)
	return repo
}

// OnDictUpdate is the observer callback registered in NewRepository. It marks
// the cache stale so that the next UpdateCache call rebuilds it. It does not
// take updateMutex, so it never waits for a rebuild in progress.
func (r *Repository) OnDictUpdate(_ registry.DictType) {
	atomic.StoreInt32(&r.validCache, 0)
}

// UpdateCache rebuilds the cache from the registry when it has been
// invalidated since the last build. It returns updater.AlreadyUpdated when the
// cache is still valid, and nil after a rebuild.
func (r *Repository) UpdateCache() error {
	r.updateMutex.Lock()
	defer r.updateMutex.Unlock()

	// Mark the cache valid *before* reading the dictionaries. If OnDictUpdate
	// fires while Build is running, it resets the flag and the next call
	// rebuilds again. Setting the flag after Build would overwrite that
	// invalidation and keep stale data until the next dictionary update.
	if !atomic.CompareAndSwapInt32(&r.validCache, 0, 1) {
		return updater.AlreadyUpdated
	}

	rebuilt := false
	defer func() {
		if !rebuilt {
			// Build panicked: keep the cache marked stale so a later call retries.
			atomic.StoreInt32(&r.validCache, 0)
		}
	}()

	cache := NewCache(r.logger, r.registry).Build()
	r.cacheValue.Store(cache)
	rebuilt = true
	return nil
}

// FindSegments returns train segments from departure to arrival, expanded into
// concrete runs by SegmentsByThreadRuns starting from minDepartureDate.
// It returns nil in these cases:
//   - the train-only limit condition has no transport types;
//   - no cache has been built yet;
//   - no thread yields a segment between the two points.
func (r *Repository) FindSegments(ctx context.Context, departure, arrival points.Point, minDepartureDate time.Time) Segments {
	span, _ := opentracing.StartSpanFromContext(ctx, repositoryFindSegmentsFnCaller.String())
	defer span.Finish()

	condition := limitcondition.New(
		departure, arrival,
		limitcondition.WithTransportTypes(rasp.TTransport_TYPE_TRAIN),
	)
	if condition.IsEmptyTransportTypes() {
		r.logger.Debug("empty transport types list")
		return nil
	}

	// Load the cache exactly once per request. UpdateCache may swap it at any
	// moment. Looking up departure stops, arrival stops and threads in
	// different snapshots could pair a thread ID with a snapshot that lacks it.
	cache, ok := r.loadCache()
	if !ok {
		r.logger.Warn("schedule cache is not built yet")
		return nil
	}

	segments := r.findSegmentsInCache(
		cache,
		threadStopsForPoint(cache, departure),
		threadStopsForPoint(cache, arrival),
		condition,
	)
	segments = excludeThroughTrains(segments)

	var result Segments
	for _, segment := range segments {
		result = append(result,
			r.segmentMaker.SegmentsByThreadRuns(segment, minDepartureDate)...,
		)
	}
	return result
}

// makeSegment builds a rawSegment for thread between the departure and
// arrival thread stations. It fills in both stations' MajorityId from the
// registry's station repository, falling back to maxMajorityID when a station
// is not found there.
func (r *Repository) makeSegment(thread *rasp.TThread, departure *rasp.TThreadStation, arrival *rasp.TThreadStation) *rawSegment {
	s := &rawSegment{
		thread:    thread,
		departure: departure,
		arrival:   arrival,
	}

	stationRepo := r.registry.GetStationRepo()
	s.departureMajorityID = maxMajorityID
	if station, found := stationRepo.Get(s.departure.StationId); found {
		s.departureMajorityID = station.MajorityId
	}
	s.arrivalMajorityID = maxMajorityID
	if station, found := stationRepo.Get(s.arrival.StationId); found {
		s.arrivalMajorityID = station.MajorityId
	}
	return s
}

// findSegments is findSegmentsInCache evaluated against the current cache
// snapshot. It is kept for existing callers. FindSegments uses a single
// snapshot for the whole request instead.
func (r *Repository) findSegments(departureThreadStations threadToStations, arrivalThreadStations threadToStations, condition *limitcondition.LimitCondition) rawSegments {
	if len(departureThreadStations) == 0 || len(arrivalThreadStations) == 0 {
		return nil
	}
	return r.findSegmentsInCache(r.getCache(), departureThreadStations, arrivalThreadStations, condition)
}

// findSegmentsInCache collects the best raw segments of every thread that
// appears in both departureThreadStations and arrivalThreadStations. It
// resolves thread IDs through cache, which must be the same snapshot the two
// stop maps were taken from.
func (r *Repository) findSegmentsInCache(cache *Cache, departureThreadStations, arrivalThreadStations threadToStations, condition *limitcondition.LimitCondition) rawSegments {
	if len(departureThreadStations) == 0 || len(arrivalThreadStations) == 0 {
		return nil
	}

	var segments rawSegments
	for threadID, departures := range departureThreadStations {
		arrivals, found := arrivalThreadStations[threadID]
		if !found {
			continue
		}
		thread := cache.threads[threadID]
		segments = append(segments, r.findThreadSegments(thread, departures, arrivals, condition)...)
	}
	return segments
}

// findThreadSegments pairs every departure stop with every arrival stop of
// thread whose thread-station Id is larger. It then filters the pairs by
// majority under condition and keeps the ones findBestSegments selects.
// NOTE(review): relies on thread-station Ids increasing along the route —
// confirm.
func (r *Repository) findThreadSegments(thread *rasp.TThread, departureStations []*rasp.TThreadStation, arrivalStations []*rasp.TThreadStation, condition *limitcondition.LimitCondition) rawSegments {
	if len(departureStations) == 0 || len(arrivalStations) == 0 {
		return nil
	}

	var segments rawSegments
	for _, departure := range departureStations {
		for _, arrival := range arrivalStations {
			if departure.Id < arrival.Id {
				segments = append(segments, r.makeSegment(thread, departure, arrival))
			}
		}
	}

	segments = filterSegmentsByMajority(r.logger, segments, condition)
	return findBestSegments(segments)
}

// getPointThreadStops returns the thread stops of point from the current cache
// snapshot. It is kept for existing callers.
func (r *Repository) getPointThreadStops(point points.Point) threadToStations {
	return threadStopsForPoint(r.getCache(), point)
}

// threadStopsForPoint looks point up in cache. Station points are looked up in
// stationStops and every other point type in settlementStops, both keyed by
// the point ID.
func threadStopsForPoint(cache *Cache, point points.Point) threadToStations {
	if point.Type() == points.StationType {
		return cache.stationStops[int(point.ID())]
	}
	return cache.settlementStops[int(point.ID())]
}

// loadCache returns the current cache snapshot. ok is false if UpdateCache has
// never stored one.
func (r *Repository) loadCache() (cache *Cache, ok bool) {
	cache, ok = r.cacheValue.Load().(*Cache)
	return cache, ok && cache != nil
}

// getCache returns the current cache snapshot. It panics if no cache has been
// stored yet, because the type assertion is applied to the nil interface that
// atomic.Value.Load returns before the first Store.
func (r *Repository) getCache() *Cache {
	return r.cacheValue.Load().(*Cache)
}
