package module

import (
	"errors"
	"fmt"
	"path"
	"strings"

	"golang.org/x/mod/module"
)

var (
	// ErrInvalidModulePath is returned when a requested module path is
	// malformed or too short for its known hosting domain.
	ErrInvalidModulePath = errors.New("invalid module path")

	// ErrDisallowedOrigin is returned when the first element of a requested
	// module path matches one of privateOrigins.
	ErrDisallowedOrigin = errors.New("disallowed origin")

	// privateOrigins holds path.Match patterns for hosts whose modules must
	// not be served; each is matched against the first module path element.
	// NOTE(review): "*.yandex-team.ru" matches subdomains only, not the bare
	// "yandex-team.ru" host — confirm that is intended.
	privateOrigins = []string{"*.yandex-team.ru"}

	// pathValidators pairs well-known hosting domains with the minimum number
	// of "/" separators a module path on that domain must contain once any
	// major-version suffix is stripped (e.g. 2 for "github.com/owner/repo").
	pathValidators = []struct {
		domain     string
		minSlashes int
	}{
		{domain: "github.com", minSlashes: 2},
		{domain: "golang.org", minSlashes: 2},
		{domain: "golang.yandex", minSlashes: 1},
		{domain: "google.golang.org", minSlashes: 1},
		{domain: "gopkg.in", minSlashes: 1},
		{domain: "go.uber.org", minSlashes: 1},
		{domain: "rsc.io", minSlashes: 1},
	}

	// moduleValidators lists, per request kind, the validators that
	// ValidateModule runs in order. Kinds that carry a version also get
	// validateModuleVersion.
	moduleValidators = map[What][]validator{
		WhatList:   {validateModulePath},
		WhatLatest: {validateModulePath},
		WhatInfo:   {validateModulePath, validateModuleVersion},
		WhatMod:    {validateModulePath, validateModuleVersion},
		WhatZip:    {validateModulePath, validateModuleVersion},
	}
)

// validator checks one aspect of a module request and returns a non-nil
// error when the request must be rejected.
type validator func(*Module) error

// ValidateModule runs the validators registered in moduleValidators for
// m.What, in order, and returns the first error any of them reports.
// Requests whose What has no registered validators are accepted unchecked.
func ValidateModule(m *Module) error {
	// Indexing a missing key yields a nil slice, so unknown kinds fall
	// straight through to the final nil.
	for _, check := range moduleValidators[m.What] {
		if err := check(m); err != nil {
			return err
		}
	}
	return nil
}

// validateModulePath checks m.Path before a request is served.
//
// It returns ErrInvalidModulePath when the path has a malformed major-version
// suffix (as reported by module.SplitPathVersion), when its first element
// (the origin host) is empty, or when the path has fewer "/" separators than
// pathValidators requires for its host. It returns ErrDisallowedOrigin when
// the host matches one of privateOrigins. Paths on hosts not listed in
// pathValidators are otherwise accepted.
func validateModulePath(m *Module) error {
	modulePath, _, ok := module.SplitPathVersion(m.Path)
	if !ok {
		return ErrInvalidModulePath
	}

	// The first path element names the origin host. Host names are
	// case-insensitive and may carry a trailing root dot, so normalize before
	// comparing; otherwise e.g. "X.yandex-team.ru" or "x.yandex-team.ru."
	// would slip past the private-origin check below.
	host, _, _ := strings.Cut(m.Path, "/")
	host = strings.TrimRight(strings.ToLower(host), ".")
	if host == "" {
		// Empty path or leading "/": there is no origin to fetch from.
		return ErrInvalidModulePath
	}

	for _, pattern := range privateOrigins {
		// path.Match only errors on malformed patterns; privateOrigins holds
		// fixed patterns, so the error is deliberately ignored.
		if matched, _ := path.Match(pattern, host); matched {
			return ErrDisallowedOrigin
		}
	}

	for _, v := range pathValidators {
		if v.domain != host {
			continue
		}
		// Count on the path without its major-version suffix so that e.g.
		// "gopkg.in/yaml.v2" is judged as "gopkg.in/yaml".
		if strings.Count(modulePath, "/") < v.minSlashes {
			return ErrInvalidModulePath
		}
		return nil
	}
	return nil
}

// validateModuleVersion checks m.Version for requests that carry a version.
//
// The literal "latest" is always rejected. A version that is not in
// canonical form is rejected, except for WhatInfo requests, which are let
// through because they may name a remote commit. A canonical version must
// also agree with the major-version suffix of m.Path, as enforced by
// module.CheckPathMajor.
func validateModuleVersion(m *Module) error {
	version := m.Version

	switch {
	case version == "latest":
		return errors.New("version latest is disallowed")
	case version != module.CanonicalVersion(version):
		// Info requests may ask about a remote commit, so a non-canonical
		// version is acceptable there; other kinds need an exact version.
		if m.What != WhatInfo {
			return fmt.Errorf("version '%s' is not in canonical form", version)
		}
		return nil
	}

	_, pathMajor, _ := module.SplitPathVersion(m.Path)
	return module.CheckPathMajor(version, pathMajor)
}
