package registry

import (
	"fmt"

	"github.com/golang/protobuf/proto"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/library/go/resourcestorage"
	"a.yandex-team.ru/travel/trains/search_api/internal/pkg/dict/rasp"
)

// dictUpdateObserver is implemented by components that want to react when
// RepositoryRegistry stores a freshly loaded repository for a dictionary.
type dictUpdateObserver interface {
	// OnDictUpdate receives the type of the dictionary whose repository was
	// just replaced.
	OnDictUpdate(updated DictType)
}

// RepositoryRegistry keeps one repoWrapper per dictionary type together with
// the observers to notify when that dictionary is reloaded.
//
// Fields:
//   - logger: used to report successful repository updates.
//   - repositories: dictionary type -> wrapper; the wrapper builds new
//     repositories (newRepo), and holds the one returned by loadRepo and
//     replaced by storeRepo.
//   - repoObservers: dictionary type -> observers called via OnDictUpdate, in
//     registration order, after a successful reload.
//
// NOTE(review): no locking is visible here; presumably concurrent safety of
// repository swaps is provided by repoWrapper — confirm before adding
// observers after updates have started.
type RepositoryRegistry struct {
	logger        log.Logger
	repositories  map[DictType]*repoWrapper
	repoObservers map[DictType][]dictUpdateObserver
}

// NewRepositoryRegistry returns a registry with a wrapper registered for every
// supported rasp dictionary type. Each wrapper is given a constructor that it
// uses to build new repository instances; the logger is used for update
// reporting and is also handed to the time zone repository.
func NewRepositoryRegistry(logger log.Logger) *RepositoryRegistry {
	r := &RepositoryRegistry{
		logger:        logger,
		repositories:  make(map[DictType]*repoWrapper),
		repoObservers: make(map[DictType][]dictUpdateObserver),
	}

	// Registration table; order matches the historical registration order.
	constructors := []struct {
		code  DictType
		build func() repo
	}{
		{Country, func() repo { return rasp.NewCountryRepository() }},
		{Currency, func() repo { return rasp.NewCurrencyRepository() }},
		{NamedTrain, func() repo { return rasp.NewNamedTrainRepository() }},
		{Region, func() repo { return rasp.NewRegionRepository() }},
		{Route, func() repo { return rasp.NewRouteRepository() }},
		{Settlement, func() repo { return rasp.NewSettlementRepository() }},
		{StationToSettlement, func() repo { return rasp.NewStationToSettlementRepository() }},
		{Station, func() repo { return rasp.NewStationRepository() }},
		{StationCode, func() repo { return rasp.NewStationCodeRepository() }},
		{Thread, func() repo { return rasp.NewThreadRepository() }},
		{ThreadStation, func() repo { return rasp.NewThreadStationRepository() }},
		{TimeZone, func() repo { return rasp.NewTimeZoneRepository(logger) }},
	}
	for _, c := range constructors {
		r.addRepo(c.code, c.build)
	}
	return r
}

// addRepo registers a wrapper for the dictionary type code that builds its
// repositories with newFn. Registering the same code again replaces the
// previous wrapper.
func (r *RepositoryRegistry) addRepo(code DictType, newFn func() repo) {
	wrapper := newRepoWrapper(code, newFn)
	r.repositories[code] = wrapper
}

// GetCountryRepo returns the country repository currently held by the registry.
// Like every Get*Repo accessor, it panics if the stored repository has an
// unexpected concrete type.
func (r *RepositoryRegistry) GetCountryRepo() *rasp.CountryRepository {
	current := r.repositories[Country].loadRepo()
	return current.(*rasp.CountryRepository)
}

// GetCurrencyRepo returns the currency repository currently held by the registry.
func (r *RepositoryRegistry) GetCurrencyRepo() *rasp.CurrencyRepository {
	current := r.repositories[Currency].loadRepo()
	return current.(*rasp.CurrencyRepository)
}

// GetNamedTrainRepo returns the named-train repository currently held by the registry.
func (r *RepositoryRegistry) GetNamedTrainRepo() *rasp.NamedTrainRepository {
	current := r.repositories[NamedTrain].loadRepo()
	return current.(*rasp.NamedTrainRepository)
}

// GetRegionRepo returns the region repository currently held by the registry.
func (r *RepositoryRegistry) GetRegionRepo() *rasp.RegionRepository {
	current := r.repositories[Region].loadRepo()
	return current.(*rasp.RegionRepository)
}

// GetRouteRepo returns the route repository currently held by the registry.
func (r *RepositoryRegistry) GetRouteRepo() *rasp.RouteRepository {
	current := r.repositories[Route].loadRepo()
	return current.(*rasp.RouteRepository)
}

// GetSettlementRepo returns the settlement repository currently held by the registry.
func (r *RepositoryRegistry) GetSettlementRepo() *rasp.SettlementRepository {
	current := r.repositories[Settlement].loadRepo()
	return current.(*rasp.SettlementRepository)
}

// GetStationToSettlementRepo returns the station-to-settlement repository
// currently held by the registry.
func (r *RepositoryRegistry) GetStationToSettlementRepo() *rasp.StationToSettlementRepository {
	current := r.repositories[StationToSettlement].loadRepo()
	return current.(*rasp.StationToSettlementRepository)
}

// GetStationRepo returns the station repository currently held by the registry.
func (r *RepositoryRegistry) GetStationRepo() *rasp.StationRepository {
	current := r.repositories[Station].loadRepo()
	return current.(*rasp.StationRepository)
}

// GetStationCodeRepo returns the station-code repository currently held by the registry.
func (r *RepositoryRegistry) GetStationCodeRepo() *rasp.StationCodeRepository {
	current := r.repositories[StationCode].loadRepo()
	return current.(*rasp.StationCodeRepository)
}

// GetThreadRepo returns the thread repository currently held by the registry.
func (r *RepositoryRegistry) GetThreadRepo() *rasp.ThreadRepository {
	current := r.repositories[Thread].loadRepo()
	return current.(*rasp.ThreadRepository)
}

// GetThreadStationRepo returns the thread-station repository currently held by the registry.
func (r *RepositoryRegistry) GetThreadStationRepo() *rasp.ThreadStationRepository {
	current := r.repositories[ThreadStation].loadRepo()
	return current.(*rasp.ThreadStationRepository)
}

// GetTimeZoneRepo returns the time zone repository currently held by the registry.
func (r *RepositoryRegistry) GetTimeZoneRepo() *rasp.TimeZoneRepository {
	current := r.repositories[TimeZone].loadRepo()
	return current.(*rasp.TimeZoneRepository)
}

// AddRepositoryObserver subscribes observer to reloads of the dictionary type
// code. Observers are called in the order they were added. The code is not
// checked against the registered repositories, and no locking is done here.
func (r *RepositoryRegistry) AddRepositoryObserver(code DictType, observer dictUpdateObserver) {
	subscribers := r.repoObservers[code]
	subscribers = append(subscribers, observer)
	r.repoObservers[code] = subscribers
}

// getUpdateFn returns a callback that reloads the dictionary of type code from
// a resourcestorage loader. It panics right away when code has no registered
// repository, so a misconfigured code fails at wiring time rather than on the
// first load.
func (r *RepositoryRegistry) getUpdateFn(code DictType) DictLoadFn {
	if _, registered := r.repositories[code]; !registered {
		panic(fmt.Sprintf("unexpected dict code: %s", code))
	}

	update := func(loader *resourcestorage.Loader) error {
		return r.updateRepo(code, loader)
	}
	return update
}

// getSample returns the sample proto message reported by the repository
// currently held for the dictionary type code.
func (r *RepositoryRegistry) getSample(code DictType) proto.Message {
	wrapper := r.repositories[code]
	current := wrapper.loadRepo()
	return current.GetSample()
}

// updateRepo builds a new repository for the dictionary type code, fills it
// from l and, only when loading succeeds, stores it in the wrapper and notifies
// the observers registered for code (in registration order).
//
// If loading fails, the previously stored repository is left in place and the
// loader's error is returned wrapped with %w, so callers can inspect the cause
// with errors.Is / errors.As. The caller (getUpdateFn) guarantees that code is
// registered.
func (r *RepositoryRegistry) updateRepo(code DictType, l *resourcestorage.Loader) error {
	wrapper := r.repositories[code]
	// Named "fresh" rather than "repo" to avoid shadowing the package type repo.
	fresh := wrapper.newRepo()

	count, err := l.Load(fresh)
	if err != nil {
		// %w keeps the loader error in the chain; the message text is unchanged.
		return fmt.Errorf("RepositoryRegistry.update %s: %w", wrapper.code, err)
	}
	wrapper.storeRepo(fresh)
	for _, observer := range r.repoObservers[code] {
		observer.OnDictUpdate(code)
	}
	// Include the dictionary type so log lines from different dictionaries
	// can be told apart.
	r.logger.Structured().Info(
		"repository updated",
		log.String("dict_type", fmt.Sprint(code)),
		log.Int("message_count", count),
	)
	return nil
}
