package references

import (
	"fmt"
	"strings"

	"github.com/golang/protobuf/proto"

	"a.yandex-team.ru/travel/library/go/dicts/repository"
	"a.yandex-team.ru/travel/proto/dicts/rasp"
)

// StationCodesRepository keeps every decoded rasp.TStationCode in a base
// repository and, in addition, indexes station IDs by code for the IATA,
// Sirena, ICAO and Express code systems. Index keys are the upper-cased codes
// produced by Write.
type StationCodesRepository struct {
	// baseRepository receives every decoded record, whatever its code system.
	baseRepository *repository.StationCodeRepository

	// Per-system lookup tables mapping an upper-cased code to a station ID.
	stationIDByIATA, stationIDBySirena, stationIDByICAO, stationIDByExpress map[string]int32
}

// NewStationCodesRepository returns an empty repository with a fresh base
// repository and all per-system lookup tables allocated, so Write can insert
// into them immediately.
func NewStationCodesRepository() *StationCodesRepository {
	sr := new(StationCodesRepository)
	sr.baseRepository = repository.NewStationCodeRepository()
	sr.stationIDByIATA = map[string]int32{}
	sr.stationIDBySirena = map[string]int32{}
	sr.stationIDByICAO = map[string]int32{}
	sr.stationIDByExpress = map[string]int32{}
	return sr
}

func (sr *StationCodesRepository) Write(b []byte) (int, error) {
	stationCodeProto := &rasp.TStationCode{}
	if err := proto.Unmarshal(b, stationCodeProto); err != nil {
		return 0, fmt.Errorf("StationRepository:Write: %w", err)
	}
	sr.baseRepository.Add(stationCodeProto)
	stationCode := strings.ToUpper(stationCodeProto.Code)
	if stationCodeProto.SystemId == rasp.ECodeSystem_CODE_SYSTEM_IATA {
		sr.stationIDByIATA[stationCode] = stationCodeProto.StationId
	} else if stationCodeProto.SystemId == rasp.ECodeSystem_CODE_SYSTEM_SIRENA {
		sr.stationIDBySirena[stationCode] = stationCodeProto.StationId
	} else if stationCodeProto.SystemId == rasp.ECodeSystem_CODE_SYSTEM_ICAO {
		sr.stationIDByICAO[stationCode] = stationCodeProto.StationId
	} else if stationCodeProto.SystemId == rasp.ECodeSystem_CODE_SYSTEM_EXPRESS {
		sr.stationIDByExpress[stationCode] = stationCodeProto.StationId
	}
	return len(b), nil
}

// GetStationIDByCode resolves a station code case-insensitively by searching
// the IATA, Sirena and ICAO indexes in that order; the first match wins.
// Express codes are not consulted here (see GetStationIDByExpressCode).
// An empty code never matches.
func (sr *StationCodesRepository) GetStationIDByCode(stationCode string) (int32, bool) {
	if len(stationCode) == 0 {
		return 0, false
	}
	key := strings.ToUpper(stationCode)
	// Priority order matters when the same code exists in several systems.
	byPriority := [...]map[string]int32{sr.stationIDByIATA, sr.stationIDBySirena, sr.stationIDByICAO}
	for _, index := range byPriority {
		if id, found := index[key]; found {
			return id, true
		}
	}
	return 0, false
}

// GetStationIDByExpressCode resolves an Express code to a station ID.
//
// Write stores every code upper-cased, so the input is normalized the same
// way before lookup. This keeps the lookup case-insensitive, consistent with
// GetStationIDByCode. An empty code never matches.
func (sr *StationCodesRepository) GetStationIDByExpressCode(stationCode string) (int32, bool) {
	if stationCode == "" {
		return 0, false
	}
	// A missing key yields (0, false), the same "not found" result as before.
	stationID, ok := sr.stationIDByExpress[strings.ToUpper(stationCode)]
	return stationID, ok
}

// SeekStationExpressCodeByID performs a reverse lookup: it returns an Express
// code that maps to stationID, or ("", false) if there is none.
//
// Slow: it scans the entire Express index on every call, so it is meant only
// for the sanity check tool. Several codes may map to the same station, and Go
// randomizes map iteration order. To keep results reproducible across runs,
// the lexicographically smallest matching code is returned.
func (sr *StationCodesRepository) SeekStationExpressCodeByID(stationID int32) (string, bool) {
	best, found := "", false
	for code, id := range sr.stationIDByExpress {
		if id != stationID {
			continue
		}
		if !found || code < best {
			best, found = code, true
		}
	}
	return best, found
}

// UpdateFromSource rebuilds the repository from iterator.
//
// A brand-new repository is filled via iterator.Populate. Only if that
// returns no error is the receiver overwritten with it, so a failed update
// leaves sr's previous contents untouched and returns Populate's error as-is.
// NOTE(review): the swap is a plain struct copy with no locking; confirm that
// callers do not read concurrently with updates.
func (sr *StationCodesRepository) UpdateFromSource(iterator RepositoryUpdater) error {
	fresh := NewStationCodesRepository()
	err := iterator.Populate(fresh)
	if err == nil {
		*sr = *fresh
	}
	return err
}
