package pypi

import (
	"encoding/json"
	"io"
	"io/ioutil"
	"os"
	"sort"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/libs/go/archer"
	"a.yandex-team.ru/security/libs/go/pypi"
	"a.yandex-team.ru/security/libs/go/simplelog"
	"a.yandex-team.ru/security/libs/go/yahttp"
	"a.yandex-team.ru/security/yadi/libs/pypi/pkgparser"
)

const (
	// MaxPkgSize is the upper bound, in bytes (600 MiB), on a release archive
	// that download will fetch.
	MaxPkgSize = 600 << 20

	// UnknownLicense is the license placeholder that PackageVersion.License
	// reports as an empty string.
	UnknownLicense = "UNKNOWN"
)

type (
	// Package wraps a pypi.Package. After Resolve, versions maps each version
	// string to the single release selected for download and parsing.
	Package struct {
		pypi.Package
		versions map[string]pypi.Release
	}

	// PackageVersion holds the metadata parsed from one downloaded release
	// (pkgInfo) together with the release it was parsed from (pkgRelease).
	PackageVersion struct {
		pkgInfo    *pkgparser.PkgInfo
		pkgRelease pypi.Release
	}

	// PackageType is an alias for pypi.PackageType.
	PackageType = pypi.PackageType

	// PkgRequire is one requirement as serialized by RequirementsJSON: the
	// required package Name, the Extras requested from it, and its Versions
	// constraint (presumably the raw specifier string — confirm in pkgparser),
	// which is omitted from JSON when empty.
	PkgRequire struct {
		Name     string   `json:"name"`
		Extras   []string `json:"extras"`
		Versions string   `json:"versions,omitempty"`
	}

	// PkgRequirements groups a package's requirements: Requires mirrors the
	// parsed metadata's Requires list, and Extras maps each extra name to the
	// requirements listed for it. Both are omitted from JSON when empty.
	PkgRequirements struct {
		Requires []PkgRequire            `json:"requires,omitempty"`
		Extras   map[string][]PkgRequire `json:"extras,omitempty"`
	}
)

// pkgPriority ranks release formats when Resolve picks a single release per
// version: a higher value wins. Types absent from the map rank as 0.
var pkgPriority = map[pypi.PackageType]uint8{
	// ugliest & with lot of legacy sub-formats :(
	pypi.PackageTypeSrc: 1,
	// easy to parse
	pypi.PackageTypeSrcEgg: 2,
	// easy to parse & possibly lower size
	pypi.PackageTypeWheel: 3,
}

// Resolve resolves the embedded pypi.Package and then selects, for every
// version, the single release to use later.
//
// Candidates are visited in the order returned by Releases. A candidate is
// skipped if an already-selected release for the same version has an equal
// or higher pkgPriority, so among equally ranked releases the first one wins.
// Candidates whose extension has no archer walker are logged and skipped.
// Any error from the embedded Resolve is returned unchanged.
func (p *Package) Resolve() error {
	if err := p.Package.Resolve(); err != nil {
		return err
	}

	releases := p.Package.Releases()
	p.versions = make(map[string]pypi.Release, len(releases))
	for version, candidates := range releases {
		for _, candidate := range candidates {
			if current, seen := p.versions[version]; seen && pkgPriority[current.Type] >= pkgPriority[candidate.Type] {
				continue
			}

			if _, supported := archer.Walkers[candidate.Ext]; !supported {
				simplelog.Error("failed to find archer", "url", candidate.DownloadURL, "ext", candidate.Ext)
				continue
			}

			p.versions[version] = candidate
		}
	}
	return nil
}

// Name returns the package name as reported by the embedded pypi.Package.
func (p *Package) Name() string {
	base := &p.Package
	return base.Name()
}

// NormName returns the normalized package name as reported by the embedded
// pypi.Package.
func (p *Package) NormName() string {
	base := &p.Package
	return base.NormName()
}

// Versions returns every version for which Resolve selected a release.
//
// The result is sorted in ascending lexical order so that it is
// deterministic across calls; note this is plain string order (e.g. "10.0"
// sorts before "9.0"), not PEP 440 version order. Before Resolve has been
// called, it returns an empty, non-nil slice.
func (p *Package) Versions() []string {
	result := make([]string, 0, len(p.versions))
	for version := range p.versions {
		result = append(result, version)
	}
	// Map iteration order is randomized; sort for stable output.
	sort.Strings(result)
	return result
}

// Version downloads and parses the release that Resolve selected for version.
//
// The release type is validated before downloading, so an unsupported type
// does not cost a (potentially large) download. The archive is parsed with the
// pkgparser function matching its type, using the archer walker registered
// for its extension. The downloaded file is removed before returning, and a
// failed removal is logged.
//
// It returns an error in these cases:
//   - version has no selected release, which includes Resolve not having been called;
//   - the release type is unsupported;
//   - downloading or parsing fails. These failures are also logged.
func (p *Package) Version(version string) (*PackageVersion, error) {
	pkgVer, ok := p.versions[version]
	if !ok {
		return nil, xerrors.Errorf("version %s not found", version)
	}

	// pkgPriority lists exactly the types handled by the switch below. Resolve
	// can still select an unlisted type (priority 0) when a version has no
	// release of a listed type with a supported extension, so reject it here
	// instead of after a wasted download.
	if _, known := pkgPriority[pkgVer.Type]; !known {
		return nil, xerrors.Errorf("unknown pkg type: %d", pkgVer.Type)
	}

	pkgPath, err := download(pkgVer)
	if err != nil {
		simplelog.Error("failed to download version",
			"pkg_name", p.Name(),
			"url", pkgVer.DownloadURL,
			"err", err,
		)
		return nil, err
	}
	defer func() {
		if rmErr := os.Remove(pkgPath); rmErr != nil {
			simplelog.Error("failed to cleanup pkg", "path", pkgPath, "err", rmErr)
		}
	}()

	// Resolve only keeps releases whose extension has a registered walker.
	walker := archer.Walkers[pkgVer.Ext]

	var parsedPkg *pkgparser.PkgInfo
	switch pkgVer.Type {
	case pypi.PackageTypeWheel:
		parsedPkg, err = pkgparser.ParseWheelPackage(pkgPath, walker)
	case pypi.PackageTypeSrc:
		parsedPkg, err = pkgparser.ParseSrcPackage(pkgPath, walker)
	case pypi.PackageTypeSrcEgg:
		parsedPkg, err = pkgparser.ParseEggPackage(pkgPath, walker)
	default:
		// Defensive: unreachable while pkgPriority and this switch list the same types.
		return nil, xerrors.Errorf("unknown pkg type: %d", pkgVer.Type)
	}

	if err != nil {
		simplelog.Error("failed to parse version",
			"pkg_name", p.Name(),
			"url", pkgVer.DownloadURL,
			"err", err,
		)
		return nil, err
	}

	return &PackageVersion{
		pkgInfo:    parsedPkg,
		pkgRelease: pkgVer,
	}, nil
}

// Version returns the version string recorded in the parsed package metadata.
func (r *PackageVersion) Version() string {
	info := r.pkgInfo
	return info.Version
}

// License returns the license from the parsed package metadata, or an empty
// string when the metadata carries the UnknownLicense placeholder.
func (r *PackageVersion) License() string {
	switch license := r.pkgInfo.License; license {
	case UnknownLicense:
		return ""
	default:
		return license
	}
}

// DownloadURL returns the URL of the release this version was parsed from.
func (r *PackageVersion) DownloadURL() string {
	release := &r.pkgRelease
	return release.DownloadURL
}

// Requirements converts the parsed metadata's requirement lists into
// PkgRequirements.
//
// Requires always holds a non-nil slice with one entry per parsed
// requirement. Extras contains only extra names that list at least one
// requirement. The returned error is currently always nil.
func (r *PackageVersion) Requirements() (*PkgRequirements, error) {
	info := r.pkgInfo

	requires := make([]PkgRequire, 0, len(info.Requires))
	for _, dep := range info.Requires {
		requires = append(requires, PkgRequire{Name: dep.Name, Extras: dep.Extras, Versions: dep.Versions})
	}

	extras := make(map[string][]PkgRequire, len(info.Extras))
	for extraName, deps := range info.Extras {
		if len(deps) == 0 {
			// An extra with no requirements gets no map entry.
			continue
		}
		converted := make([]PkgRequire, 0, len(deps))
		for _, dep := range deps {
			converted = append(converted, PkgRequire{Name: dep.Name, Extras: dep.Extras, Versions: dep.Versions})
		}
		extras[extraName] = converted
	}

	return &PkgRequirements{Requires: requires, Extras: extras}, nil
}

// RequirementsJSON returns the JSON encoding of Requirements, or the error
// from Requirements or json.Marshal.
func (r *PackageVersion) RequirementsJSON() ([]byte, error) {
	reqs, err := r.Requirements()
	if err == nil {
		return json.Marshal(reqs)
	}
	return nil, err
}

// download fetches the archive at pkgVer.DownloadURL (via doGet) into a newly
// created temporary file and returns that file's path. The caller owns the
// returned file and must remove it.
//
// A release is rejected if any of the following exceeds MaxPkgSize:
//   - its advertised size (pkgVer.Size), checked before any network traffic;
//   - the response Content-Length;
//   - the bytes actually received, which covers servers that omit or
//     misreport Content-Length.
//
// Non-200 responses are also rejected. On any error the temporary file is
// removed and pkgPath is empty.
func download(pkgVer pypi.Release) (pkgPath string, resultErr error) {
	if pkgVer.Size > MaxPkgSize {
		return "", xerrors.Errorf("failed to download pkg: too big %d > %d", pkgVer.Size, MaxPkgSize)
	}

	tempFile, err := ioutil.TempFile("", "yadi-index")
	if err != nil {
		return "", xerrors.Errorf("failed to create tmp file for pkg: %w", err)
	}
	defer func() {
		// Close first: a failed flush means the saved archive is incomplete,
		// and some platforms refuse to delete files that are still open.
		if closeErr := tempFile.Close(); closeErr != nil && resultErr == nil {
			resultErr = xerrors.Errorf("failed to save pkg: %w", closeErr)
		}
		if resultErr != nil {
			// Never leave partial or rejected downloads behind in the temp dir.
			_ = os.Remove(tempFile.Name())
			pkgPath = ""
		}
	}()

	response, err := doGet(pkgVer.DownloadURL)
	if err != nil {
		return "", xerrors.Errorf("failed to download pkg: %w", err)
	}
	defer yahttp.GracefulClose(response.Body)

	if response.StatusCode != 200 {
		return "", xerrors.Errorf("failed to download pkg: invalid reponse status %d", response.StatusCode)
	}

	if response.ContentLength > MaxPkgSize {
		return "", xerrors.Errorf("failed to download pkg: too big %d > %d", response.ContentLength, MaxPkgSize)
	}

	// Content-Length may be absent (-1) or wrong, so enforce the limit on the
	// bytes actually read; one extra byte is allowed through to detect overflow.
	written, err := io.Copy(tempFile, io.LimitReader(response.Body, MaxPkgSize+1))
	if err != nil {
		return "", xerrors.Errorf("failed to save pkg: %w", err)
	}
	if written > MaxPkgSize {
		return "", xerrors.Errorf("failed to download pkg: too big, more than %d bytes", MaxPkgSize)
	}

	return tempFile.Name(), nil
}
