package skottyca

import (
	"bufio"
	"bytes"
	"context"
	"crypto/md5"
	"crypto/tls"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"strconv"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"
	"github.com/klauspost/compress/zstd"
	"github.com/stripe/krl"
	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
)

// SkottyCA keeps an in-memory copy of the SSH CA public keys and the KRL
// (key revocation list) served by the CA HTTP endpoint, and refreshes them
// periodically from a background goroutine started by NewCA.
type SkottyCA struct {
	// httpc is the HTTP client used for the key list and KRL requests.
	httpc      *resty.Client
	// syncPeriod is the interval between background syncs in loop.
	syncPeriod time.Duration
	// krl is the last successfully parsed KRL; nil until one is loaded.
	// Guarded by mu.
	krl        *krl.KRL
	// caKeys are the currently trusted CA public keys. Guarded by mu.
	caKeys     []ssh.PublicKey
	// mu guards krl and caKeys.
	mu         sync.RWMutex
	// caKinds lists the CA names fetched from /keylists/main/{ca}.
	caKinds    []string
	// zstd decompresses the KRL payload.
	zstd       *zstd.Decoder
	// l receives errors from background syncs.
	l          log.Logger
	// ctx bounds the lifetime of the background loop; cancelCtx cancels it
	// (called from Shutdown).
	ctx        context.Context
	cancelCtx  context.CancelFunc
	// closed is closed when the background loop goroutine exits.
	closed     chan struct{}
}

// NewCA creates a SkottyCA, applies opts, and performs an initial synchronous
// Sync. On success it starts a background goroutine that re-syncs every
// syncPeriod until Shutdown is called.
//
// The initial Sync must succeed: NewCA returns an error rather than a CA with
// no CA keys or KRL loaded. It also rejects a non-positive sync period.
func NewCA(opts ...Option) (*SkottyCA, error) {
	certPool, err := certifi.NewCertPool()
	if err != nil {
		return nil, fmt.Errorf("can't create ca pool: %w", err)
	}

	zstdDec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
	if err != nil {
		return nil, fmt.Errorf("can't create zstd decoder: %w", err)
	}

	httpc := resty.New().
		SetTLSClientConfig(&tls.Config{RootCAs: certPool}).
		SetJSONEscapeHTML(false).
		SetBaseURL("https://cauth.yandex.net:4443").
		SetRetryCount(3).
		SetDoNotParseResponse(true).
		SetRetryWaitTime(1 * time.Second).
		SetRetryMaxWaitTime(10 * time.Second).
		// Per-attempt ceiling: without it a stalled server would block NewCA
		// and, later, the background sync loop indefinitely.
		// NOTE(review): this also bounds reading the response body; raise it
		// if the KRL download grows large.
		SetTimeout(1 * time.Minute)

	ctx, cancelCtx := context.WithCancel(context.Background())

	ca := &SkottyCA{
		httpc:      httpc,
		zstd:       zstdDec,
		syncPeriod: DefaultSyncPeriod,
		ctx:        ctx,
		cancelCtx:  cancelCtx,
		closed:     make(chan struct{}),
		l:          &nop.Logger{},
		caKinds:    []string{CAKindSecure},
	}

	for _, opt := range opts {
		opt(ca)
	}

	// fail releases what was allocated above. A CA that is never returned
	// can never be Shutdown, so its context and decoder would otherwise leak.
	fail := func(err error) (*SkottyCA, error) {
		cancelCtx()
		zstdDec.Close()
		return nil, err
	}

	// time.NewTicker panics on a non-positive period. That panic would fire
	// in the background goroutine after NewCA had already reported success.
	if ca.syncPeriod <= 0 {
		return fail(fmt.Errorf("invalid sync period: %s", ca.syncPeriod))
	}

	if err := ca.Sync(); err != nil {
		return fail(fmt.Errorf("sync fail: %w", err))
	}

	go ca.loop()
	return ca, nil
}

// loop is the background refresher started by NewCA. It calls Sync once per
// syncPeriod tick and logs failures. It returns when c.ctx is canceled and
// signals its exit by closing c.closed.
func (c *SkottyCA) loop() {
	defer close(c.closed)

	ticker := time.NewTicker(c.syncPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
			// A failed sync keeps the previously loaded data; just report it.
			if err := c.Sync(); err != nil {
				c.l.Error("sync fail", log.Error(err))
			}
		}
	}
}

// IsRevoked reports whether cert is revoked according to the most recently
// loaded KRL. It reports false when no KRL has been loaded yet.
func (c *SkottyCA) IsRevoked(cert *ssh.Certificate) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	current := c.krl
	return current != nil && current.IsRevoked(cert)
}

// IsUserAuthority reports whether auth is one of the currently loaded CA
// public keys. Keys are compared by their wire encoding (Marshal).
func (c *SkottyCA) IsUserAuthority(auth ssh.PublicKey) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	// Nothing loaded: nothing can match.
	if c.caKeys == nil {
		return false
	}

	want := auth.Marshal()
	for i := range c.caKeys {
		if bytes.Equal(c.caKeys[i].Marshal(), want) {
			return true
		}
	}
	return false
}

// Sync refreshes the CA keys and then the KRL, stopping at the first failure.
// CA keys that were already refreshed are kept even if the KRL step fails.
func (c *SkottyCA) Sync() error {
	err := c.syncCAKeys()
	if err != nil {
		return fmt.Errorf("unable to sync CA keys: %w", err)
	}

	err = c.syncKRL()
	if err != nil {
		return fmt.Errorf("unable to sync KRL: %w", err)
	}

	return nil
}

// fetchCAKeys downloads the public keys of the CA named ca from
// /keylists/main/{ca} and parses them as authorized_keys lines.
//
// The request is bound to c.ctx, so Shutdown aborts it (and any pending
// retries) instead of waiting for it to finish.
func (c *SkottyCA) fetchCAKeys(ca string) ([]ssh.PublicKey, error) {
	rsp, err := c.httpc.R().
		SetContext(c.ctx).
		SetPathParam("ca", ca).
		Get("/keylists/main/{ca}")
	if err != nil {
		return nil, err
	}

	// The client is created with SetDoNotParseResponse(true), so the raw body
	// is ours to close. Draining up to 128 KiB first lets the transport reuse
	// the connection.
	body := rsp.RawBody()
	defer func() {
		_, _ = io.CopyN(io.Discard, body, 128<<10)
		_ = body.Close()
	}()

	if !rsp.IsSuccess() {
		return nil, fmt.Errorf("non-200 status code: %d", rsp.StatusCode())
	}

	caKeys, err := parsePublicKeys(body)
	if err != nil {
		return nil, fmt.Errorf("unable to parse ca keys: %w", err)
	}

	return caKeys, nil
}

// syncCAKeys fetches the keys of every configured CA kind and replaces
// c.caKeys with their concatenation under c.mu. If any fetch fails, the
// current keys are left untouched.
func (c *SkottyCA) syncCAKeys() error {
	var collected []ssh.PublicKey
	for _, kind := range c.caKinds {
		keys, err := c.fetchCAKeys(kind)
		if err != nil {
			return fmt.Errorf("unable to sync ca %q keys: %w", kind, err)
		}
		collected = append(collected, keys...)
	}

	c.mu.Lock()
	c.caKeys = collected
	c.mu.Unlock()
	return nil
}

// syncKRL downloads the zstd-compressed KRL from /keylists/main/all.zst and
// checks the MD5 of the compressed bytes against the ETag header. It then
// decompresses and parses the KRL, and only then replaces c.krl. On any
// error the previously loaded KRL stays in effect.
//
// The request is bound to c.ctx, so Shutdown aborts it (and any pending
// retries) instead of waiting for it to finish.
func (c *SkottyCA) syncKRL() error {
	rsp, err := c.httpc.R().
		SetContext(c.ctx).
		Get("/keylists/main/all.zst")
	if err != nil {
		return err
	}

	// The client is created with SetDoNotParseResponse(true), so the raw body
	// is ours to close. Draining up to 128 KiB first lets the transport reuse
	// the connection.
	body := rsp.RawBody()
	defer func() {
		_, _ = io.CopyN(io.Discard, body, 128<<10)
		_ = body.Close()
	}()

	if !rsp.IsSuccess() {
		return fmt.Errorf("non-200 status code: %d", rsp.StatusCode())
	}

	if rsp.Header().Get("Content-Type") != "application/krl+zst" {
		return fmt.Errorf("unexpected content-type header: %s", rsp.Header().Get("Content-Type"))
	}

	rawETag := rsp.Header().Get("Etag")
	if rawETag == "" {
		return errors.New("no ETag header returned")
	}

	// A weak validator ("W/" prefix) is accepted. The MD5 check below
	// verifies the body bytes themselves, so the weak/strong distinction
	// does not matter here.
	expectedHash := rawETag
	if len(expectedHash) >= 2 && expectedHash[:2] == "W/" {
		expectedHash = expectedHash[2:]
	}
	if len(expectedHash) > 0 && expectedHash[0] == '"' {
		expectedHash, err = strconv.Unquote(expectedHash)
		if err != nil {
			return fmt.Errorf("invalid ETag header: %s: %w", rawETag, err)
		}
	}

	krlBytes, err := io.ReadAll(body)
	if err != nil {
		return fmt.Errorf("read: %w", err)
	}

	// Compare decoded digests rather than hex strings, so an upper-case hex
	// ETag also matches. Anything that is not valid hex is a mismatch.
	md5Hash := md5.Sum(krlBytes)
	if expected, decErr := hex.DecodeString(expectedHash); decErr != nil || !bytes.Equal(expected, md5Hash[:]) {
		return fmt.Errorf("hash mismatch: %s (expected) != %s (actual)", expectedHash, hex.EncodeToString(md5Hash[:]))
	}

	krlBytes, err = c.zstd.DecodeAll(krlBytes, nil)
	if err != nil {
		return fmt.Errorf("decode: %w", err)
	}

	parsedKRL, err := krl.ParseKRL(krlBytes)
	if err != nil {
		return fmt.Errorf("parse: %w", err)
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	c.krl = parsedKRL
	return nil
}

// Shutdown cancels the background sync loop. It then waits until the loop
// has exited or ctx is done, whichever happens first. It does not report
// which of the two occurred, so it may return while the loop is still
// running.
func (c *SkottyCA) Shutdown(ctx context.Context) {
	c.cancelCtx()

	select {
	case <-ctx.Done():
	case <-c.closed:
	}
}

// parsePublicKeys parses raw as an authorized_keys-style list with one public
// key per line. Blank lines and lines starting with '#' are skipped, because
// ssh.ParseAuthorizedKey rejects a line that contains no key, and one such
// line would otherwise fail the whole CA sync.
//
// On any error (a malformed key line or a read failure) no keys are returned.
func parsePublicKeys(raw io.Reader) ([]ssh.PublicKey, error) {
	var keys []ssh.PublicKey
	scanner := bufio.NewScanner(raw)
	for scanner.Scan() {
		line := bytes.TrimSpace(scanner.Bytes())
		if len(line) == 0 || line[0] == '#' {
			continue
		}

		key, _, _, _, err := ssh.ParseAuthorizedKey(line)
		if err != nil {
			return nil, fmt.Errorf("failed to parse authorized key %q: %w", scanner.Text(), err)
		}
		keys = append(keys, key)
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return keys, nil
}
