package point

import (
	"fmt"

	geobaselib "a.yandex-team.ru/library/go/yandex/geobase"
	"a.yandex-team.ru/travel/komod/trips/internal/models"
	"a.yandex-team.ru/travel/komod/trips/internal/references"
	"a.yandex-team.ru/travel/library/go/geobase"
)

// pointTypeToRegionTypeMap translates the point kinds that resolveByGeoID
// can target into the geobase region type used to round a geo ID.
// Kinds missing from this map cannot be resolved by geo ID.
var pointTypeToRegionTypeMap = map[models.PointKind]geobaselib.RegionType{
	models.RegionPointKind:  geobaselib.RegionTypeRegion,
	models.CountryPointKind: geobaselib.RegionTypeCountry,
}

// ErrUnexpectedPointKind is returned by GetCountry and GetRegion when the
// point has no geo ID and is neither a station nor a settlement.
var ErrUnexpectedPointKind = fmt.Errorf("unexpected point kind")

// Resolver maps a point to the country or region point that contains it.
// Points with a geo ID are resolved through the geobase; stations and
// settlements without one are resolved through the reference data.
type Resolver struct {
	// geoBase is queried via RoundToRegion to round geo IDs to a region type.
	geoBase geobase.Geobase
	// reference supplies station and settlement lookups by point ID.
	reference references.References
	// pointFactory builds the resulting points from country/region/geo IDs.
	pointFactory *Factory
}

// NewResolver returns a Resolver that uses the given geobase, reference data
// and point factory. The dependencies are stored as-is; no validation is done.
func NewResolver(
	geoBase geobase.Geobase,
	reference references.References,
	pointFactory *Factory,
) *Resolver {
	resolver := Resolver{
		geoBase:      geoBase,
		reference:    reference,
		pointFactory: pointFactory,
	}
	return &resolver
}

// GetCountry returns the country-level point for point.
//
// A point with a non-zero geo ID is rounded to a country through the
// geobase. Otherwise the country ID is read from the station or settlement
// reference, selected by point.GetKind(), and passed to
// pointFactory.MakeByCountryID. Any other kind yields ErrUnexpectedPointKind.
func (r Resolver) GetCountry(point models.Point) (models.Point, error) {
	if geoID := point.GetGeoID(); geoID != 0 {
		return r.resolveByGeoID(geoID, models.CountryPointKind)
	}

	// NOTE(review): an ID that is absent from the reference leaves countryID
	// at 0, which is still handed to the factory — confirm this is intended.
	countryID := 0
	switch point.GetKind() {
	case models.StationPointKind:
		station, found := r.reference.Stations().Get(point.GetID())
		if found {
			countryID = int(station.CountryId)
		}
	case models.SettlementPointKind:
		settlement, found := r.reference.Settlements().Get(point.GetID())
		if found {
			countryID = int(settlement.CountryId)
		}
	default:
		return nil, ErrUnexpectedPointKind
	}

	return r.pointFactory.MakeByCountryID(countryID)
}

// GetRegion returns the region-level point for point.
//
// A point with a non-zero geo ID is rounded to a region through the
// geobase. Otherwise the region ID is read from the station or settlement
// reference, selected by point.GetKind(), and passed to
// pointFactory.MakeByRegionID. Any other kind yields ErrUnexpectedPointKind.
func (r Resolver) GetRegion(point models.Point) (models.Point, error) {
	if geoID := point.GetGeoID(); geoID != 0 {
		return r.resolveByGeoID(geoID, models.RegionPointKind)
	}

	// NOTE(review): an ID that is absent from the reference leaves regionID
	// at 0, which is still handed to the factory — confirm this is intended.
	regionID := 0
	switch point.GetKind() {
	case models.StationPointKind:
		station, found := r.reference.Stations().Get(point.GetID())
		if found {
			regionID = int(station.RegionId)
		}
	case models.SettlementPointKind:
		settlement, found := r.reference.Settlements().Get(point.GetID())
		if found {
			regionID = int(settlement.RegionId)
		}
	default:
		return nil, ErrUnexpectedPointKind
	}

	return r.pointFactory.MakeByRegionID(regionID)
}

// resolveByGeoID rounds geoID to the geobase region type that corresponds to
// targetKind (per pointTypeToRegionTypeMap) and builds a point for the
// resulting region through pointFactory.MakeByGeoID.
//
// It returns an error when targetKind has no mapped region type, when the
// geobase lookup fails (the geobase error is wrapped and available via
// errors.Is/As), or when the geobase returns no region.
func (r Resolver) resolveByGeoID(geoID int, targetKind models.PointKind) (models.Point, error) {
	geoTargetKind, ok := pointTypeToRegionTypeMap[targetKind]
	if !ok {
		return nil, fmt.Errorf("undefined pointType to regionType relation for point kind %v", targetKind)
	}

	// Query the geobase directly instead of via tryCast: tryCast collapses
	// the failure to a bool, which hides why the lookup failed.
	geoRegion, err := r.geoBase.RoundToRegion(geoID, geoTargetKind)
	if err != nil {
		return nil, fmt.Errorf("unable to cast point to region (geoID=%d, regionType=%v): %w", geoID, geoTargetKind, err)
	}
	if geoRegion == nil {
		// Without this guard a nil region with a nil error would panic on
		// geoRegion.ID below.
		return nil, fmt.Errorf("unable to cast point to region (geoID=%d, regionType=%v): no region returned", geoID, geoTargetKind)
	}
	return r.pointFactory.MakeByGeoID(int(geoRegion.ID))
}

// tryCast asks the geobase to round geoID to a region of targetType.
// The boolean reports whether RoundToRegion succeeded. The geobase error
// itself is discarded, and the region is nil whenever the boolean is false.
func (r Resolver) tryCast(geoID int, targetType geobaselib.RegionType) (*geobaselib.Region, bool) {
	if region, err := r.geoBase.RoundToRegion(geoID, targetType); err == nil {
		return region, true
	}
	return nil, false
}
