package signer

import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"time"

	"github.com/go-resty/resty/v2"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/httputil/headers"
	"a.yandex-team.ru/security/libs/go/hashreader"
)

// Signer service endpoints and the defaults applied by NewYaSigner.
//
// DefaultEndpoint (currently the production endpoint) is the base URL the
// HTTP client starts with, and DefaultRetries is its initial retry count;
// both can be overridden through the options passed to NewYaSigner.
const (
	ProductionEndpoint = "https://signer.yandex-team.ru"
	TestingEndpoint    = "https://signer-test.sec.yandex-team.ru"
	DefaultEndpoint    = ProductionEndpoint
	DefaultRetries     = 5
)

// Signer is implemented by types that can sign a stream of data.
type Signer interface {
	// Sign consumes the payload from src and writes the signed result to dst.
	// opts customise a single signing request; ctx bounds the whole operation.
	Sign(ctx context.Context, src io.Reader, dst io.Writer, opts ...SignOption) error
}

// Compile-time check that *YaSigner satisfies the Signer interface.
var _ Signer = (*YaSigner)(nil)

// YaSigner is a Signer that talks to the signer service over HTTP.
// Construct it with NewYaSigner.
type YaSigner struct {
	httpc *resty.Client // HTTP client carrying the base URL, retry count and default headers
	app   string        // application name sent with a new sign task unless a sign option overrides it
	log   log.Logger    // receives progress messages; a no-op logger unless an option replaces it
}

// NewYaSigner builds a YaSigner that targets DefaultEndpoint with
// DefaultRetries retries and a no-op logger, then applies opts in order.
//
// It returns an error when opts contains an option type it does not recognise.
func NewYaSigner(opts ...Option) (*YaSigner, error) {
	transport := &http.Transport{
		Proxy:               http.ProxyFromEnvironment,
		TLSHandshakeTimeout: 2 * time.Second,
	}

	client := resty.NewWithClient(&http.Client{Transport: transport})
	client.SetRetryCount(DefaultRetries)
	client.SetBaseURL(DefaultEndpoint)
	client.SetHeader(headers.UserAgentKey, "a.yandex-team.ru/security/libs/go/signer")

	// Best effort: when the certifi pool cannot be built, the TLS
	// configuration of the client is simply left untouched.
	if pool, err := certifi.NewCertPool(); err == nil {
		client.SetTLSClientConfig(&tls.Config{RootCAs: pool})
	}

	signer := &YaSigner{
		httpc: client,
		log:   &nop.Logger{},
	}

	for _, o := range opts {
		switch opt := o.(type) {
		case optionLog:
			signer.log = opt.l
		case optionRetries:
			signer.httpc.SetRetryCount(opt.retries)
		case optionEndpoint:
			signer.httpc.SetBaseURL(opt.endpoint)
		case optionApp:
			signer.app = opt.app
		case optionAuthToken:
			signer.httpc.SetHeader(headers.AuthorizationKey, "OAuth "+opt.token)
		default:
			return nil, fmt.Errorf("unknown option: %T", opt)
		}
	}

	return signer, nil
}

// Sign runs a complete signing round trip: it creates a sign task, uploads
// the payload read from src, polls until the task finishes and finally writes
// the downloaded result to dst.
//
// Every failure is wrapped with the name of the step that produced it and,
// once the task exists, with its ID.
func (s *YaSigner) Sign(ctx context.Context, src io.Reader, dst io.Writer, opts ...SignOption) error {
	id, err := s.newSign(ctx, opts...)
	if err != nil {
		return fmt.Errorf("start sign: %w", err)
	}
	s.log.Info("sign started", log.Int("id", id))

	if err := s.upload(ctx, id, src); err != nil {
		return fmt.Errorf("upload [%d]: %w", id, err)
	}
	s.log.Info("uploaded", log.Int("id", id))

	info, err := s.waitSign(ctx, id)
	if err != nil {
		return fmt.Errorf("sign wait [%d]: %w", id, err)
	}

	if err := s.download(ctx, info, dst); err != nil {
		return fmt.Errorf("download [%d]: %w", id, err)
	}
	s.log.Info("complete", log.Int("id", id))

	return nil
}

// newSign creates a sign task via POST /api/v2/sign/ and returns its ID.
//
// The application defaults to the one configured on the signer; the
// application and comment sign options override the request fields. Other
// option types are ignored. An error is returned on a transport failure, a
// non-successful HTTP status, or a response without a positive ID.
func (s *YaSigner) newSign(ctx context.Context, opts ...SignOption) (int, error) {
	payload := struct {
		Application string `json:"application"`
		Comment     string `json:"comment"`
	}{
		Application: s.app,
	}

	for _, o := range opts {
		if appOpt, ok := o.(signOptionApp); ok {
			payload.Application = appOpt.app
			continue
		}
		if commentOpt, ok := o.(signOptionComment); ok {
			payload.Comment = commentOpt.comment
		}
	}

	var created struct {
		ID int `json:"id"`
	}
	req := s.httpc.R().
		SetContext(ctx).
		ForceContentType("application/json").
		SetBody(payload).
		SetResult(&created)

	rsp, err := req.Post("/api/v2/sign/")
	switch {
	case err != nil:
		return 0, err
	case !rsp.IsSuccess():
		return 0, fmt.Errorf("non-200 (%s): %s", rsp.Status(), rspBodyStr(rsp))
	case created.ID <= 0:
		return 0, fmt.Errorf("no ID in response: %s", rspBodyStr(rsp))
	}

	return created.ID, nil
}

// upload sends the payload from src to the sign task signID via
// PUT /api/v2/sign/{id}.
//
// src is read through a hashreader configured with the SHA-1 option, and the
// resulting sum is compared with the "sha1" field of the response. An error
// is returned on a transport failure, a non-successful HTTP status, a
// response with "success" unset, or a checksum mismatch.
//
// NOTE(review): the body is passed to the client as a plain reader; whether
// the client's retries can replay it is not visible here — confirm.
func (s *YaSigner) upload(ctx context.Context, signID int, src io.Reader) error {
	hashed, err := hashreader.NewHashReader(src, hashreader.WithSha1Hash())
	if err != nil {
		return fmt.Errorf("create hash reader: %w", err)
	}

	var reply struct {
		Success bool   `json:"success"`
		SHA1    string `json:"sha1"`
	}
	req := s.httpc.R().
		SetContext(ctx).
		SetHeader(headers.ContentTypeKey, "application/octet-stream").
		ForceContentType("application/json").
		SetBody(hashed).
		SetResult(&reply).
		SetPathParam("id", strconv.Itoa(signID))

	rsp, err := req.Put("/api/v2/sign/{id}")
	switch {
	case err != nil:
		return err
	case !rsp.IsSuccess():
		return fmt.Errorf("non-200 (%s): %s", rsp.Status(), rspBodyStr(rsp))
	case !reply.Success:
		return fmt.Errorf("non-success response: %s", rspBodyStr(rsp))
	case hashed.Sum() != reply.SHA1:
		return fmt.Errorf("checksum mismatch: %s (expected) != %s (actual)", hashed.Sum(), reply.SHA1)
	}

	return nil
}

// download streams the signed artifact from si.URL into dst and verifies it.
//
// The response is requested unparsed so it can be streamed; the data is read
// through a hashreader configured with the SHA-1 option and the resulting sum
// must equal si.SHA1. An error is returned on a transport failure, a
// non-successful HTTP status, a copy failure, or a hash mismatch. Data may
// already have been written to dst when a copy or hash error is reported.
func (s *YaSigner) download(ctx context.Context, si *signInfo, dst io.Writer) error {
	rsp, err := s.httpc.R().
		SetContext(ctx).
		SetDoNotParseResponse(true).
		Get(si.URL)
	if err != nil {
		return err
	}

	// With response parsing disabled the raw body is ours to close. Register
	// the cleanup before any early return so the non-2xx path cannot leak the
	// connection; draining a bounded tail lets the connection be reused.
	body := rsp.RawBody()
	defer func() {
		_, _ = io.CopyN(io.Discard, body, 128<<10)
		_ = body.Close()
	}()

	if !rsp.IsSuccess() {
		// rsp.Body() stays empty for unparsed responses, so take a bounded
		// prefix of the raw body for the error message. A read error here is
		// deliberately ignored: the status alone is still reported.
		msg, _ := io.ReadAll(io.LimitReader(body, 4<<10))
		return fmt.Errorf("non-200 (%s): %s", rsp.Status(), bytes.TrimSpace(msg))
	}

	r, err := hashreader.NewHashReader(body, hashreader.WithSha1Hash())
	if err != nil {
		return fmt.Errorf("create hash reader: %w", err)
	}

	if _, err := io.Copy(dst, r); err != nil {
		return err
	}

	if r.Sum() != si.SHA1 {
		return fmt.Errorf("hash mismatch for url %q: %s (expected) != %s (actual)", si.URL, si.SHA1, r.Sum())
	}

	return nil
}

// waitSign polls the sign task signID until it reports completion and
// returns the resulting sign info.
//
// The task is queried immediately and then every five seconds. Polling stops
// with the query's error as soon as a query fails, and with ctx.Err() as soon
// as ctx is cancelled or expires while waiting between polls.
func (s *YaSigner) waitSign(ctx context.Context, signID int) (*signInfo, error) {
	const pollInterval = 5 * time.Second

	for {
		s.log.Info("wait sign task", log.Int("id", signID))
		done, out, err := s.signInfo(ctx, signID)
		if err != nil {
			return nil, err
		}

		if done {
			return out, nil
		}

		// Wait for the next poll, but give up right away on cancellation
		// instead of sleeping through it.
		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
	}
}

// signInfo queries the sign task signID via GET /api/v2/sign/{id}.
//
// It returns (false, nil, nil) while the task is in progress and
// (true, info, nil) once it is complete with both a URL and a SHA1 set.
// A transport failure, a non-successful HTTP status, a complete task missing
// its URL or SHA1, an error status, or an unrecognised status yields an error.
func (s *YaSigner) signInfo(ctx context.Context, signID int) (bool, *signInfo, error) {
	var task struct {
		signInfo
		Status signStatus `json:"status"`
	}
	req := s.httpc.R().
		SetContext(ctx).
		ForceContentType("application/json").
		SetResult(&task).
		SetPathParam("id", strconv.Itoa(signID))

	rsp, err := req.Get("/api/v2/sign/{id}")
	if err != nil {
		return false, nil, err
	}
	if !rsp.IsSuccess() {
		return false, nil, fmt.Errorf("non-200 (%s): %s", rsp.Status(), rspBodyStr(rsp))
	}

	switch task.Status {
	case signStatusInProgress:
		return false, nil, nil
	case signStatusComplete:
		info := &task.signInfo
		if info.URL == "" || info.SHA1 == "" {
			return false, nil, fmt.Errorf("invalid signer response: %s", rspBodyStr(rsp))
		}
		return true, info, nil
	case signStatusError:
		return false, nil, fmt.Errorf("sign fail: %s", rspBodyStr(rsp))
	}

	return false, nil, fmt.Errorf("unsupported status: %s", task.Status)
}

// rspBodyStr returns the response body as a string with leading and
// trailing whitespace removed, for use in error messages.
func rspBodyStr(rsp *resty.Response) string {
	trimmed := bytes.TrimSpace(rsp.Body())
	return string(trimmed)
}
