package sectools

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net/http"
	"runtime"

	"github.com/blang/semver/v4"
	"github.com/go-resty/resty/v2"
	"github.com/klauspost/compress/zstd"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/httputil/headers"
	"a.yandex-team.ru/security/libs/go/hashreader"
)

// Release endpoints. slowUpstream is installed as the resty base URL by
// NewClient, so manifest requests are resolved against it. fastUpstream is
// not referenced in the code shown here (presumably a direct object-storage
// mirror of the same artifacts — verify before relying on it).
const (
	slowUpstream = "https://tools.sec.yandex-team.ru/api/v2/proxy"
	fastUpstream = "https://sectools-releases.s3.mds.yandex.net"
)

// DefaultRetries is the retry count NewClient configures on its HTTP client.
const DefaultRetries = 1

// uaFormat renders the User-Agent header as SecTools/<tool>/<version>.
const uaFormat = "SecTools/%s/%s"

// ErrDownloadSameVersion is returned by DownloadVersion when the resolved
// manifest version equals the client's current version.
var ErrDownloadSameVersion = errors.New("download same version is prohibited")

// Client fetches release manifests for a single tool and downloads the
// matching zstd-compressed binary. Construct it with NewClient: the zero value
// has a nil httpc and cannot issue requests.
//
// Field notes:
//   - toolName: the {tool} path segment of manifest URLs; also part of the
//     User-Agent set by NewClient.
//   - currentVersion: DownloadVersion returns ErrDownloadSameVersion when the
//     manifest version equals it. It is not set by NewClient itself
//     (presumably an Option sets it — verify).
//   - os / arch: keys into manifest.Binaries; NewClient defaults them to
//     runtime.GOOS / runtime.GOARCH.
//   - preferFaster: makes DownloadVersion use an artifact's FastURL when the
//     manifest provides one.
//   - channel: the version/channel name used by LatestManifest and
//     DownloadLatestVersion; NewClient defaults it to ChannelStable.
//   - httpc: the resty client configured by NewClient.
type Client struct {
	toolName       string
	currentVersion string
	os             string
	arch           string
	preferFaster   bool
	channel        Channel
	httpc          *resty.Client
}

// NewClient builds a Client for toolName.
//
// Defaults are channel ChannelStable and the running binary's GOOS/GOARCH.
// Each opt is applied afterwards, in order, and can override them. The
// underlying resty client:
//   - targets slowUpstream;
//   - retries DefaultRetries times;
//   - logs through restyLogger;
//   - sends a "SecTools/<toolName>/0.0.0" User-Agent.
//
// When certifi can build a certificate pool, that pool becomes the TLS root
// CA set.
func NewClient(toolName string, opts ...Option) *Client {
	userAgent := fmt.Sprintf(uaFormat, toolName, "0.0.0")

	httpc := resty.NewWithClient(&http.Client{Transport: newHTTPTransport()})
	httpc.SetRetryCount(DefaultRetries)
	httpc.SetLogger(&restyLogger{})
	httpc.SetHeader(headers.UserAgentKey, userAgent)
	httpc.SetBaseURL(slowUpstream)

	// Best effort: if the bundled roots cannot be loaded, leave the
	// transport's TLS configuration as it is instead of failing.
	if certPool, err := certifi.NewCertPool(); err == nil {
		httpc.SetTLSClientConfig(&tls.Config{RootCAs: certPool})
	}

	c := &Client{
		toolName: toolName,
		channel:  ChannelStable,
		os:       runtime.GOOS,
		arch:     runtime.GOARCH,
		httpc:    httpc,
	}
	for _, apply := range opts {
		apply(c)
	}
	return c
}

// IsLatestVersion fetches the latest version for the client's channel and
// compares version against it using semantic-version ordering.
//
// It reports true when version is equal to or newer than the latest
// version, and also returns the latest version string. On any error it
// returns false and an empty string.
func (c *Client) IsLatestVersion(ctx context.Context, version string) (bool, string, error) {
	latest, err := c.LatestVersion(ctx)
	if err != nil {
		return false, "", err
	}

	upToDate, cmpErr := isLatestVersion(version, latest)
	if cmpErr != nil {
		return false, "", cmpErr
	}
	return upToDate, latest, nil
}

// LatestVersion returns the Version field of the manifest published for the
// client's configured channel.
func (c *Client) LatestVersion(ctx context.Context) (string, error) {
	m, err := c.LatestManifest(ctx)
	if err != nil {
		return "", err
	}
	return m.Version, nil
}

// LatestManifest fetches the manifest published under the client's channel
// name. NewClient defaults the channel to ChannelStable.
func (c *Client) LatestManifest(ctx context.Context) (*Manifest, error) {
	channel := string(c.channel)
	return c.Manifest(ctx, channel)
}

// DownloadLatestVersion downloads the build published under the client's
// channel name and writes the decompressed binary to out. It has the same
// error semantics as DownloadVersion.
func (c *Client) DownloadLatestVersion(ctx context.Context, out io.Writer) error {
	channel := string(c.channel)
	return c.DownloadVersion(ctx, channel, out)
}

// DownloadVersion resolves the manifest for version, selects the artifact
// for the client's OS/arch and streams its zstd-decompressed contents into
// out. The version argument is a concrete version or a channel name;
// DownloadLatestVersion passes the channel.
//
// Artifact selection:
//   - The exact c.os/c.arch entry is preferred. On darwin/arm64, the amd64
//     build is the fallback, run through Rosetta.
//   - The FastURL is used when preferFaster is set and the manifest has one.
//     Otherwise the regular URL is used.
//
// It returns:
//   - ErrDownloadSameVersion if the manifest version equals the client's
//     current version;
//   - an error if the OS or arch is not in the manifest, or no URL is present;
//   - an error if the download responds with anything other than 200 OK;
//   - an error if decompression or writing fails, or if the hash of the
//     downloaded (compressed) stream differs from the manifest hash. The hash
//     is only checked when the manifest provides one.
//
// Data is streamed into out before the hash can be verified. On any error,
// out may therefore already contain partial or unverified bytes, and the
// caller must discard it.
func (c *Client) DownloadVersion(ctx context.Context, version string, out io.Writer) error {
	manifest, err := c.Manifest(ctx, version)
	if err != nil {
		return err
	}

	if manifest.Version == c.currentVersion {
		return ErrDownloadSameVersion
	}

	platform, ok := manifest.Binaries[c.os]
	if !ok {
		return fmt.Errorf("unsupported tool platform: %s", c.os)
	}

	downloadInfo, ok := platform[c.arch]
	if !ok && c.arch == "arm64" && c.os == "darwin" {
		// try Rosetta: Apple Silicon can run the amd64 build.
		downloadInfo, ok = platform["amd64"]
	}
	if !ok {
		return fmt.Errorf("unsupported tool platform %q arch: %s", c.os, c.arch)
	}

	var downloadURL string
	switch {
	case c.preferFaster && downloadInfo.FastURL != "":
		downloadURL = downloadInfo.FastURL
	case downloadInfo.URL != "":
		downloadURL = downloadInfo.URL
	default:
		return errors.New("invalid manifest: no download url available")
	}

	// TODO(buglloc): try all urls

	rsp, err := c.httpc.R().
		SetContext(ctx).
		SetDoNotParseResponse(true).
		Get(downloadURL)
	if err != nil {
		return err
	}

	body := rsp.RawBody()
	defer func() {
		// Drain a bounded amount of leftover data before closing, so the
		// transport can reuse the connection.
		_, _ = io.CopyN(io.Discard, body, 128<<10)
		_ = body.Close()
	}()

	// Without this check, an error page (404, 5xx, ...) would be fed into
	// the zstd decoder and surface as a misleading decompression error.
	if rsp.StatusCode() != http.StatusOK {
		return fmt.Errorf("non-200 status code response for url %q: %d", downloadURL, rsp.StatusCode())
	}

	// The hash is computed over the compressed stream as it is read.
	hashedR, err := hashreader.NewHashReader(body)
	if err != nil {
		return fmt.Errorf("failed to create hash reader: %w", err)
	}

	zstdR, err := zstd.NewReader(hashedR)
	if err != nil {
		return fmt.Errorf("failed to create zstd reader: %w", err)
	}
	defer zstdR.Close()

	if _, err := io.Copy(out, zstdR); err != nil {
		return fmt.Errorf("failed to unpack %q: %w", downloadURL, err)
	}

	if downloadInfo.Hash != "" {
		if actual := hashedR.Hash(); actual != downloadInfo.Hash {
			return fmt.Errorf("hash mismatch for url %q: %s != %s", downloadURL, actual, downloadInfo.Hash)
		}
	}

	return nil
}

// Manifest fetches and decodes /<tool>/<version>/manifest.json. The path is
// resolved against the resty client's base URL, which NewClient sets to
// slowUpstream.
//
// The body is decoded as JSON regardless of the response Content-Type.
// Request errors from resty are returned unchanged. It also fails on any
// status other than 200 OK, and when the decoded manifest has an empty
// Version.
func (c *Client) Manifest(ctx context.Context, version string) (*Manifest, error) {
	pathParams := map[string]string{
		"tool":    c.toolName,
		"version": version,
	}

	manifest := new(Manifest)
	req := c.httpc.R().
		SetContext(ctx).
		SetResult(manifest).
		ForceContentType(headers.TypeApplicationJSON.String()).
		SetPathParams(pathParams).
		SetDoNotParseResponse(false)

	rsp, err := req.Get("/{tool}/{version}/manifest.json")
	switch {
	case err != nil:
		return nil, err
	case rsp.StatusCode() != http.StatusOK:
		return nil, fmt.Errorf("non-200 status code response: %d", rsp.StatusCode())
	case manifest.Version == "":
		// Dereference so the value, not the pointer, is printed.
		return nil, fmt.Errorf("malformed manifest: %v", *manifest)
	}

	return manifest, nil
}

// isLatestVersion reports whether cur is equal to or newer than latest,
// using strict semantic-version parsing. A parse failure on either side is
// returned as an error that names the offending version.
func isLatestVersion(cur, latest string) (bool, error) {
	curVer, curErr := semver.Parse(cur)
	if curErr != nil {
		return false, fmt.Errorf("invalid current version %q: %w", cur, curErr)
	}

	latestVer, latestErr := semver.Parse(latest)
	if latestErr != nil {
		return false, fmt.Errorf("invalid latest version %q: %w", latest, latestErr)
	}

	// GTE is Compare(...) >= 0: an equal or newer local build counts as latest.
	return curVer.GTE(latestVer), nil
}
