package dict

import (
	"fmt"
	"io"
	"path"
	"strconv"
	"sync"

	"github.com/golang/protobuf/proto"

	"a.yandex-team.ru/library/go/core/log/zap"
	"a.yandex-team.ru/travel/library/go/dicts/base"
	"a.yandex-team.ru/travel/library/go/dicts/repository"
	"a.yandex-team.ru/travel/proto/dicts/rasp"
)

// Config describes where the rasp dictionary dump files are located.
//
// ResourceDir is the base directory. Every *FN field is a file name that
// DictRepo.load joins onto ResourceDir with path.Join, so the names are
// expected to be relative to ResourceDir. All fields carry a
// `config:"...,required"` tag.
type Config struct {
	// ResourceDir is the directory that the file names below are joined onto.
	ResourceDir           string `config:"rasp-resourcedir,required"`
	// The remaining fields name one dump file per dictionary.
	TimeZoneFN            string `config:"rasp-timezonefn,required"`
	SettlementFN          string `config:"rasp-settlementfn,required"`
	StationFN             string `config:"rasp-stationfn,required"`
	StationCodeFN         string `config:"rasp-stationcodefn,required"`
	ThreadFN              string `config:"rasp-threadfn,required"`
	StationExpressAliasFN string `config:"rasp-stationexpressaliasfn,required"`
	ThreadStationFN       string `config:"rasp-threadstationfn,required"`
	CountryFN             string `config:"rasp-countryfn,required"`
	CarrierFN             string `config:"rasp-carrierfn,required"`
}

// DefaultConfig is a ready-made Config: resources live under /app/rasp_data
// and each dictionary is stored as "<name>_data/<name>.data" inside it.
var DefaultConfig = Config{
	ResourceDir:           "/app/rasp_data",
	TimeZoneFN:            "timezone_data/timezone.data",
	SettlementFN:          "settlement_data/settlement.data",
	StationFN:             "station_data/station.data",
	StationCodeFN:         "station_code_data/station_code.data",
	ThreadFN:              "thread_data/thread.data",
	StationExpressAliasFN: "station_express_alias_data/station_express_alias.data",
	ThreadStationFN:       "thread_station_data/thread_station.data",
	CountryFN:             "country_data/country.data",
	CarrierFN:             "carrier_data/carrier.data",
}

// DictRepo holds the rasp dictionaries together with the secondary indexes
// that are built while the dictionary files are being loaded.
//
// Load builds a complete new set of repositories and indexes and then
// replaces the fields below while holding the write lock of mutex; accessors
// are expected to take the read lock before touching them.
type DictRepo struct {
	cfg                      *Config
	logger                   *zap.Logger
	// mutex guards every repository and index field below.
	mutex                    sync.RWMutex
	timeZoneRepo             *repository.TimeZoneRepository
	settlementRepo           *repository.SettlementRepository
	stationRepo              *repository.StationRepository
	stationCodeRepo          *repository.StationCodeRepository
	threadRepo               *repository.ThreadRepository
	stationExpressAliasRepo  *repository.StationExpressAliasRepository
	threadStationRepo        *repository.ThreadStationRepository
	carrierRepo              *repository.CarrierRepository
	// expressIDToID maps a numeric Express station code to the station ID
	// (filled by stationCodeWriter).
	expressIDToID            *map[int32]int32
	// geoIDIndex maps a settlement GeoId to the settlement Id
	// (filled by settlementWriter).
	geoIDIndex               *map[int32]int32
	// trainNumberToThreadID maps a train thread's Number to its Id
	// (filled by threadWriter, train threads only).
	trainNumberToThreadID    *map[string]int32
	// threads lists the train threads collected by threadWriter.
	threads                  []*rasp.TThread
	// threadIDToThreadStations groups thread stations by their ThreadId
	// (filled by threadStationWriter).
	threadIDToThreadStations *map[int32][]*rasp.TThreadStation
	countryRepo              *repository.CountryRepository
	// threadsByStationID maps a station ID to the set of thread IDs that have
	// a thread station at it (filled by threadStationWriter).
	threadsByStationID       map[int32]map[int32]struct{}
}

// settlementWriter is the writer DictRepo.Load passes to load for the
// settlement file: each record goes to settlementRepo and is additionally
// recorded in geoIDIndex (settlement GeoId -> settlement Id).
type settlementWriter struct {
	geoIDIndex     *map[int32]int32
	settlementRepo *repository.SettlementRepository
}

// Write decodes b as a single rasp.TSettlement, forwards the raw bytes to the
// wrapped settlement repository and, once the repository has accepted them,
// records the settlement in the GeoId -> Id index.
//
// It returns the byte count reported by the repository. A decode failure is
// wrapped and returned; a repository failure is returned unchanged. In both
// failure cases the count is 0 and the index is left untouched.
func (sw *settlementWriter) Write(b []byte) (int, error) {
	decoded := new(rasp.TSettlement)
	if decodeErr := proto.Unmarshal(b, decoded); decodeErr != nil {
		return 0, fmt.Errorf("settlementWriter:Write: %w", decodeErr)
	}

	written, err := sw.settlementRepo.Write(b)
	if err != nil {
		return 0, err
	}

	// Index only after the repository has stored the record.
	byGeoID := *sw.geoIDIndex
	byGeoID[decoded.GeoId] = decoded.Id
	return written, nil
}

// stationCodeWriter is the writer DictRepo.Load passes to load for the
// station code file: each record goes to stationCodeRepo, and records of the
// Express code system are additionally indexed in expressIDtoID
// (numeric Express code -> station ID).
type stationCodeWriter struct {
	expressIDtoID   *map[int32]int32
	stationCodeRepo *repository.StationCodeRepository
}

// Write decodes b as a single rasp.TStationCode, forwards the raw bytes to
// the wrapped station code repository and, for codes of the Express code
// system, records the numeric code in the Express code -> station ID index.
//
// It returns the byte count reported by the repository. A decode failure is
// wrapped and returned; a repository failure is returned unchanged.
//
// An Express code that is not a decimal integer fitting into int32 is stored
// in the repository but is NOT indexed. Previously the parse error was
// discarded, so such a code ended up under key 0 (or under a truncated value)
// and could make a lookup return an unrelated station.
func (scw *stationCodeWriter) Write(b []byte) (int, error) {
	stationCode := &rasp.TStationCode{}
	if err := proto.Unmarshal(b, stationCode); err != nil {
		return 0, fmt.Errorf("stationCodeWriter:Write: %w", err)
	}
	n, err := scw.stationCodeRepo.Write(b)
	if err != nil {
		return 0, err
	}
	if stationCode.SystemId != rasp.ECodeSystem_CODE_SYSTEM_EXPRESS {
		return n, nil
	}

	// bitSize 32 makes ParseInt reject values that would not survive the
	// conversion to int32 below.
	expressID, parseErr := strconv.ParseInt(stationCode.Code, 10, 32)
	if parseErr != nil {
		// Best effort: a malformed code must not abort loading the whole
		// dictionary, it just cannot be looked up by Express code.
		return n, nil
	}
	(*scw.expressIDtoID)[int32(expressID)] = stationCode.StationId
	return n, nil
}

// threadWriter is the writer DictRepo.Load passes to load for the thread
// file: each record goes to threadRepo, and train threads are additionally
// indexed by number in trainNumberToThreadID and appended to threads.
type threadWriter struct {
	trainNumberToThreadID *map[string]int32
	threads               *[]*rasp.TThread
	threadRepo            *repository.ThreadRepository
}

// Write decodes b as a single rasp.TThread and forwards the raw bytes to the
// wrapped thread repository. Threads whose transport type is TYPE_TRAIN are
// also remembered: their Number is mapped to their Id (a later thread with
// the same Number overwrites the earlier entry) and the decoded message is
// appended to the shared thread list.
//
// It returns the byte count reported by the repository. A decode failure is
// wrapped and returned; a repository failure is returned unchanged.
func (tw *threadWriter) Write(b []byte) (int, error) {
	decoded := new(rasp.TThread)
	if decodeErr := proto.Unmarshal(b, decoded); decodeErr != nil {
		return 0, fmt.Errorf("threadWriter:Write: %w", decodeErr)
	}

	written, err := tw.threadRepo.Write(b)
	if err != nil {
		return 0, err
	}

	// Only train threads take part in the secondary indexes.
	if decoded.TransportType != rasp.TTransport_TYPE_TRAIN {
		return written, nil
	}
	byNumber := *tw.trainNumberToThreadID
	byNumber[decoded.Number] = decoded.Id
	*tw.threads = append(*tw.threads, decoded)
	return written, nil
}

// threadStationWriter is the writer DictRepo.Load passes to load for the
// thread station file: each record goes to threadStationRepo, is grouped by
// ThreadId in threadIDToThreadStations, and its ThreadId is added to the set
// kept for its StationId in threadsByStationID.
type threadStationWriter struct {
	threadIDToThreadStations *map[int32][]*rasp.TThreadStation
	threadsByStationID       map[int32]map[int32]struct{}
	threadStationRepo        *repository.ThreadStationRepository
}

// Write decodes b as a single rasp.TThreadStation, forwards the raw bytes to
// the wrapped thread station repository and then updates both indexes: the
// decoded message is appended to the list kept for its ThreadId, and its
// ThreadId is added to the set kept for its StationId.
//
// It returns the byte count reported by the repository. A decode failure is
// wrapped and returned; a repository failure is returned unchanged.
//
// The decode error is prefixed with this writer's own name. It used to carry
// the "threadWriter:Write" prefix, which pointed at the wrong dictionary.
func (ts *threadStationWriter) Write(b []byte) (int, error) {
	threadStation := &rasp.TThreadStation{}
	if err := proto.Unmarshal(b, threadStation); err != nil {
		return 0, fmt.Errorf("threadStationWriter:Write: %w", err)
	}
	n, err := ts.threadStationRepo.Write(b)
	if err != nil {
		return 0, err
	}

	stationsByThread := *ts.threadIDToThreadStations
	stationsByThread[threadStation.ThreadId] = append(stationsByThread[threadStation.ThreadId], threadStation)

	// Create the per-station set on first use, then reuse the same lookup.
	threadIDs := ts.threadsByStationID[threadStation.StationId]
	if threadIDs == nil {
		threadIDs = map[int32]struct{}{}
		ts.threadsByStationID[threadStation.StationId] = threadIDs
	}
	threadIDs[threadStation.ThreadId] = struct{}{}
	return n, nil
}

// NewRepo returns a DictRepo bound to cfg and logger with empty repositories
// and empty geo ID, Express code, train number and thread station indexes.
// The threads list and the threads-by-station index are left at their zero
// values. No file is read here; call Load to fill the repository.
func NewRepo(cfg *Config, logger *zap.Logger) *DictRepo {
	repo := &DictRepo{cfg: cfg, logger: logger}

	// Empty repositories, so accessors work before the first Load.
	repo.timeZoneRepo = repository.NewTimeZoneRepository()
	repo.settlementRepo = repository.NewSettlementRepository()
	repo.stationRepo = repository.NewStationRepository()
	repo.stationCodeRepo = repository.NewStationCodeRepository()
	repo.threadRepo = repository.NewThreadRepository()
	repo.stationExpressAliasRepo = repository.NewStationExpressAliasRepository()
	repo.threadStationRepo = repository.NewThreadStationRepository()
	repo.countryRepo = repository.NewCountryRepository()
	repo.carrierRepo = repository.NewCarrierRepository()

	// Empty secondary indexes; they are stored by pointer in DictRepo.
	repo.geoIDIndex = &map[int32]int32{}
	repo.expressIDToID = &map[int32]int32{}
	repo.trainNumberToThreadID = &map[string]int32{}
	repo.threadIDToThreadStations = &map[int32][]*rasp.TThreadStation{}

	return repo
}

// Load reads every dictionary file named in the config and, only when all of
// them have been read successfully, replaces the repository's current
// contents with the freshly built ones.
//
// Everything is accumulated in local repositories and indexes first. The
// write lock is taken only for the final field swap, so a failing file makes
// Load return that error (the first one encountered) while the previously
// published data stays in place.
func (r *DictRepo) Load() error {
	const logMessage = "DictRepo.Load"

	var (
		trainThreads       []*rasp.TThread
		settlementsByGeoID = map[int32]int32{}
		stationsByExpress  = map[int32]int32{}
		threadsByNumber    = map[string]int32{}
		stationsByThread   = map[int32][]*rasp.TThreadStation{}
		threadsByStation   = map[int32]map[int32]struct{}{}

		timeZones      = repository.NewTimeZoneRepository()
		settlements    = repository.NewSettlementRepository()
		stations       = repository.NewStationRepository()
		stationCodes   = repository.NewStationCodeRepository()
		threadsRepo    = repository.NewThreadRepository()
		expressAliases = repository.NewStationExpressAliasRepository()
		threadStations = repository.NewThreadStationRepository()
		countries      = repository.NewCountryRepository()
		carriers       = repository.NewCarrierRepository()
	)

	// One entry per dictionary file, in loading order. Dictionaries that feed
	// a secondary index are wrapped in the matching writer.
	steps := []struct {
		sink io.Writer
		file string
	}{
		{timeZones, r.cfg.TimeZoneFN},
		{&settlementWriter{geoIDIndex: &settlementsByGeoID, settlementRepo: settlements}, r.cfg.SettlementFN},
		{stations, r.cfg.StationFN},
		{&stationCodeWriter{expressIDtoID: &stationsByExpress, stationCodeRepo: stationCodes}, r.cfg.StationCodeFN},
		{&threadWriter{trainNumberToThreadID: &threadsByNumber, threadRepo: threadsRepo, threads: &trainThreads}, r.cfg.ThreadFN},
		{expressAliases, r.cfg.StationExpressAliasFN},
		{&threadStationWriter{
			threadIDToThreadStations: &stationsByThread,
			threadsByStationID:       threadsByStation,
			threadStationRepo:        threadStations,
		}, r.cfg.ThreadStationFN},
		{countries, r.cfg.CountryFN},
		{carriers, r.cfg.CarrierFN},
	}
	for _, step := range steps {
		if err := r.load(step.sink, step.file); err != nil {
			return err
		}
	}

	// Publish the complete new data set in one critical section.
	r.mutex.Lock()
	r.timeZoneRepo = timeZones
	r.settlementRepo = settlements
	r.stationRepo = stations
	r.stationCodeRepo = stationCodes
	r.threadRepo = threadsRepo
	r.stationExpressAliasRepo = expressAliases
	r.threadStationRepo = threadStations
	r.countryRepo = countries
	r.carrierRepo = carriers
	r.geoIDIndex = &settlementsByGeoID
	r.expressIDToID = &stationsByExpress
	r.trainNumberToThreadID = &threadsByNumber
	r.threadIDToThreadStations = &stationsByThread
	r.threadsByStationID = threadsByStation
	r.threads = trainThreads
	r.mutex.Unlock()

	r.logger.Infof("%s: all rasp dictionaries loaded", logMessage)
	return nil
}

// load fills repo from the dictionary file fn, which is resolved against the
// configured resource directory. It builds an iterator over the file and lets
// the iterator populate repo. A failure of either step is returned wrapped
// with the "DictRepo.load" prefix; on success the loaded file is logged.
func (r *DictRepo) load(repo io.Writer, fn string) error {
	const logMessage = "DictRepo.load"

	records, err := base.BuildIteratorFromFile(path.Join(r.cfg.ResourceDir, fn))
	if err == nil {
		// Only populate when the iterator could be built.
		err = records.Populate(repo)
	}
	if err != nil {
		return fmt.Errorf("%s: %w", logMessage, err)
	}

	r.logger.Infof("%s: rasp dictionary %s loaded", logMessage, fn)
	return nil
}

// GetSettlement looks id up in the settlement repository under the read lock
// and returns the repository's result unchanged.
func (r *DictRepo) GetSettlement(id int32) (*rasp.TSettlement, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	settlement, found := r.settlementRepo.Get(int(id))
	return settlement, found
}

// GetStation looks id up in the station repository under the read lock and
// returns the repository's result unchanged.
func (r *DictRepo) GetStation(id int32) (*rasp.TStation, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	station, found := r.stationRepo.Get(int(id))
	return station, found
}

// GetTimeZone looks id up in the time zone repository under the read lock and
// returns the repository's result unchanged.
func (r *DictRepo) GetTimeZone(id int32) (*rasp.TTimeZone, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	timeZone, found := r.timeZoneRepo.Get(int(id))
	return timeZone, found
}

// GetStationByExpressCode resolves the numeric Express code id to a station
// ID through the Express code index and returns that station from the
// station repository. It returns (nil, false) when the code is not indexed.
//
// The station repository is queried directly instead of through GetStation:
// GetStation takes the read lock again, and sync.RWMutex prohibits recursive
// read locking because the second RLock blocks forever once a writer (Load)
// has started waiting for the lock in between.
func (r *DictRepo) GetStationByExpressCode(id int32) (*rasp.TStation, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	stationID, ok := (*r.expressIDToID)[id]
	if !ok {
		return nil, false
	}
	return r.stationRepo.Get(int(stationID))
}

// GetThread looks id up in the thread repository under the read lock and
// returns the repository's result unchanged.
func (r *DictRepo) GetThread(id int32) (*rasp.TThread, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	thread, found := r.threadRepo.Get(int(id))
	return thread, found
}

// GetThreadByTrainNumber resolves number to a thread ID through the train
// number index and returns that thread from the thread repository. It returns
// (nil, false) when the number is not indexed.
func (r *DictRepo) GetThreadByTrainNumber(number string) (*rasp.TThread, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	if threadID, known := (*r.trainNumberToThreadID)[number]; known {
		return r.threadRepo.Get(int(threadID))
	}
	return nil, false
}

// GetStationExpressAlias looks alias up in the station Express alias
// repository under the read lock and returns the repository's result
// unchanged.
func (r *DictRepo) GetStationExpressAlias(alias string) (*rasp.TStationExpressAlias, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	entry, found := r.stationExpressAliasRepo.Get(alias)
	return entry, found
}

// GetThreads returns the list of train threads collected by the last
// successful Load (nil before the first one).
//
// The field is read under the read lock because Load replaces it while
// holding the write lock; an unguarded read would be a data race. The slice
// itself is returned without copying and must be treated as read-only.
func (r *DictRepo) GetThreads() []*rasp.TThread {
	r.mutex.RLock()
	defer r.mutex.RUnlock()
	return r.threads
}

// GetThreadStation looks id up in the thread station repository under the
// read lock and returns the repository's result unchanged.
func (r *DictRepo) GetThreadStation(id int32) (*rasp.TThreadStation, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	threadStation, found := r.threadStationRepo.Get(int(id))
	return threadStation, found
}

// GetThreadStationsByThreadID returns the thread stations indexed under the
// thread ID id. The slice is nil when nothing is indexed for that thread.
//
// NOTE(review): the boolean result is always true, even for an unknown
// thread ID, so callers cannot use it to detect a miss; confirm whether that
// is intended before relying on it.
func (r *DictRepo) GetThreadStationsByThreadID(id int32) ([]*rasp.TThreadStation, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	stationsByThread := *r.threadIDToThreadStations
	stations := stationsByThread[id]
	return stations, true
}

// GetCountryByCountryID looks id up in the country repository under the read
// lock and returns the repository's result unchanged.
func (r *DictRepo) GetCountryByCountryID(id int32) (*rasp.TCountry, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	country, found := r.countryRepo.Get(id)
	return country, found
}

// GetThreadsNumbersByStationIDs returns the IDs of the threads that are
// indexed under both station fromID and station toID. Despite the name, the
// values are thread IDs taken from the threads-by-station index.
//
// The result is nil when toID has no entry in the index or when the two
// stations share no thread. The order of the returned IDs follows map
// iteration and is therefore unspecified.
func (r *DictRepo) GetThreadsNumbersByStationIDs(fromID int32, toID int32) []int32 {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	arriving, known := r.threadsByStationID[toID]
	if !known {
		return nil
	}

	var shared []int32
	for candidate := range r.threadsByStationID[fromID] {
		if _, alsoArrives := arriving[candidate]; !alsoArrives {
			continue
		}
		shared = append(shared, candidate)
	}
	return shared
}

// GetCarrierByCarrierID looks id up in the carrier repository under the read
// lock and returns the repository's result unchanged.
func (r *DictRepo) GetCarrierByCarrierID(id int32) (*rasp.TCarrier, bool) {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	carrier, found := r.carrierRepo.Get(int(id))
	return carrier, found
}
