package krl

import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"path"
	"strings"
	"time"

	"github.com/go-resty/resty/v2"
	"github.com/klauspost/compress/zstd"
	"github.com/stripe/krl"
	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/passport/infra/daemons/tvmcert/internal/cache"
	"a.yandex-team.ru/passport/infra/daemons/tvmcert/internal/task"
	"a.yandex-team.ru/passport/infra/daemons/tvmcert/internal/utils"
	"a.yandex-team.ru/passport/shared/golibs/juggler"
	"a.yandex-team.ru/passport/shared/golibs/logger"
	shared_utils "a.yandex-team.ru/passport/shared/golibs/utils"
)

// Fetcher holds a key revocation list (KRL) and checks SSH certificates
// against it.
//
// It embeds task.Runner; for SkottyFetcher, Run refreshes the KRL once it is
// older than the configured update period.
type Fetcher interface {
	task.Runner

	// GetJugglerStatus reports the health of the held KRL for monitoring.
	GetJugglerStatus() *juggler.Status

	// CheckCertificate returns a non-nil error when certificate must not be
	// accepted, e.g. because the KRL lists it as revoked.
	CheckCertificate(certificate *ssh.Certificate) error
}

// CacheFile is the name of the file, inside the configured cache directory,
// where the last successfully loaded KRL is persisted as JSON.
const CacheFile = "krl.json"

// SkottyFetcherConfig configures a SkottyFetcher (JSON tags below).
// Fields left unset are defaulted by NewSkottyFetcher:
//   - SkottyURL: "https://skotty.sec.yandex-team.ru"; base URL of the KRL API.
//   - CacheDir: no default; the cache is written to CacheDir/krl.json, so an
//     empty value means a path relative to the working directory.
//   - UpdatePeriod: 24h; minimum KRL age before Run fetches a new one.
//   - JugglerAgeWarn: 24h; KRL age above which the juggler status is Warning.
//   - JugglerAgeCritical: 48h; KRL age above which the juggler status is Critical.
type SkottyFetcherConfig struct {
	SkottyURL          string                 `json:"skotty_url"`
	CacheDir           string                 `json:"cache_dir"`
	UpdatePeriod       *shared_utils.Duration `json:"update_period"`
	JugglerAgeWarn     *shared_utils.Duration `json:"juggler_age_warn"`
	JugglerAgeCritical *shared_utils.Duration `json:"juggler_age_critical"`
}

// DecodeZSTD decompresses a complete in-memory zstd stream.
//
// A new decoder with concurrency limited to 1 is created for every call and
// closed before returning. On failure, nil is returned instead of any
// partially decoded output.
func DecodeZSTD(raw []byte) ([]byte, error) {
	dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
	if err != nil {
		return nil, err
	}
	defer dec.Close()

	decoded, decodeErr := dec.DecodeAll(raw, nil)
	if decodeErr != nil {
		return nil, decodeErr
	}
	return decoded, nil
}

// ParseKRL zstd-decompresses raw (for example the body of the Skotty
// all.zst endpoint) and parses the result with krl.ParseKRL.
func ParseKRL(raw []byte) (*krl.KRL, error) {
	plain, err := DecodeZSTD(raw)
	if err != nil {
		return nil, err
	}

	return krl.ParseKRL(plain)
}

// SkottyFetcher fetches the KRL from the Skotty API, keeps it in memory and
// mirrors it to an on-disk JSON cache used as a fallback.
//
// Fields:
//   - holder: the current KRL together with its fetch timestamp and etag.
//   - ttl: minimum KRL age before Run attempts a new fetch.
//   - cache: reads/writes the serialized KRL (CacheFile in the cache dir).
//   - client: HTTP client whose base URL is the configured Skotty URL.
//   - jugglerWarnTimeout, jugglerCriticalTimeout: KRL age thresholds used by
//     GetJugglerStatus.
//   - lastError: most recent error recorded by Update; it is not cleared when
//     a later update succeeds.
type SkottyFetcher struct {
	holder                 Holder
	ttl                    time.Duration
	cache                  *cache.Manager
	client                 *resty.Client
	jugglerWarnTimeout     time.Duration
	jugglerCriticalTimeout time.Duration
	lastError              utils.ThreadSafeError
}

// NewSkottyFetcher builds a SkottyFetcher from config and performs the initial
// KRL load via Init.
//
// Unset config fields are defaulted: SkottyURL to
// "https://skotty.sec.yandex-team.ru", UpdatePeriod and JugglerAgeWarn to 24h,
// JugglerAgeCritical to 48h. SkottyURL must be an absolute URL (scheme and
// host). The KRL cache is stored as CacheFile inside config.CacheDir.
//
// Configuration errors return a nil fetcher. If only Init fails, the
// constructed fetcher is returned together with Init's error.
func NewSkottyFetcher(config SkottyFetcherConfig) (*SkottyFetcher, error) {
	if config.SkottyURL == "" {
		config.SkottyURL = "https://skotty.sec.yandex-team.ru"
	}

	// url.Parse accepts nearly any string (e.g. "skotty.example.com" parses as
	// a relative path), so additionally require a scheme and a host; otherwise
	// a config typo would only show up as failing requests at runtime.
	baseURL, err := url.Parse(config.SkottyURL)
	if err != nil {
		return nil, fmt.Errorf("invalid skotty_url: %w", err)
	}
	if baseURL.Scheme == "" || baseURL.Host == "" {
		return nil, fmt.Errorf("invalid skotty_url %q: scheme and host are required", config.SkottyURL)
	}

	if config.UpdatePeriod == nil {
		config.UpdatePeriod = &shared_utils.Duration{Duration: 24 * time.Hour}
	}

	if config.JugglerAgeWarn == nil {
		config.JugglerAgeWarn = &shared_utils.Duration{Duration: 24 * time.Hour}
	}

	if config.JugglerAgeCritical == nil {
		config.JugglerAgeCritical = &shared_utils.Duration{Duration: 48 * time.Hour}
	}

	cacheManager, err := cache.NewManager(path.Join(config.CacheDir, CacheFile))
	if err != nil {
		return nil, err
	}

	fetcher := &SkottyFetcher{
		holder:                 Holder{},
		ttl:                    config.UpdatePeriod.Duration,
		cache:                  cacheManager,
		client:                 resty.New().SetBaseURL(config.SkottyURL),
		jugglerWarnTimeout:     config.JugglerAgeWarn.Duration,
		jugglerCriticalTimeout: config.JugglerAgeCritical.Duration,
	}

	return fetcher, fetcher.Init()
}

// Init performs the initial KRL load.
//
// It tries the Skotty API first and falls back to the on-disk cache when that
// fails. Whatever was loaded is then written back to the cache. An error is
// returned when neither source yields a KRL, or when the cache cannot be
// written.
func (fetcher *SkottyFetcher) Init() error {
	apiErr := fetcher.Update(fetcher.FetchFromAPI)
	if apiErr != nil {
		logger.Log().Warnf("cannot get krl from skotty: %s", apiErr)

		cacheErr := fetcher.Update(fetcher.FetchFromCache)
		if cacheErr != nil {
			logger.Log().Errorf("cannot get krl from cache or skotty: %s", cacheErr)
			return cacheErr
		}
	}

	writeErr := fetcher.WriteCache()
	if writeErr == nil {
		return nil
	}
	logger.Log().Errorf("cannot write cache: %s", writeErr)
	return writeErr
}

// CheckCertificate rejects certificate when no KRL has been loaded yet or when
// the loaded KRL reports it as revoked; otherwise it returns nil.
func (fetcher *SkottyFetcher) CheckCertificate(certificate *ssh.Certificate) error {
	current, _, _ := fetcher.holder.GetKRL()

	switch {
	case current == nil:
		return fmt.Errorf("krl not initialized")
	case current.IsRevoked(certificate):
		return fmt.Errorf("certificate (%s) has been revoked", certificate.KeyId)
	default:
		return nil
	}
}

// GetJugglerStatus reports how stale the held KRL is.
//
// The status is Critical when the KRL is older than jugglerCriticalTimeout,
// Warning when it is older than jugglerWarnTimeout, and OK otherwise. Non-OK
// messages include the last error recorded by Update ("<nil>" if none).
// NOTE(review): a never-loaded holder presumably reports a zero timestamp,
// which makes the age huge and the status Critical — confirm in Holder.
func (fetcher *SkottyFetcher) GetJugglerStatus() *juggler.Status {
	status := juggler.NewStatusOk()
	_, ts, _ := fetcher.holder.GetKRL()
	lastErr := fetcher.lastError.GetError()

	// Measure the age once so both thresholds are judged at the same instant.
	elapsed := time.Since(ts)

	// %v (not %s) so that a nil error renders as "<nil>" instead of "%!s(<nil>)".
	switch {
	case elapsed > fetcher.jugglerCriticalTimeout:
		status.Update(juggler.NewStatus(juggler.Critical, "krl is %.2f seconds old; last error: %v", elapsed.Seconds(), lastErr))
	case elapsed > fetcher.jugglerWarnTimeout:
		status.Update(juggler.NewStatus(juggler.Warning, "krl is %.2f seconds old; last error: %v", elapsed.Seconds(), lastErr))
	}

	return status
}

// Run refreshes the KRL from the Skotty API when the held copy is at least
// ttl old, and then persists it to the cache. It does nothing while the KRL
// is still younger than ttl.
func (fetcher *SkottyFetcher) Run() error {
	_, updatedAt, _ := fetcher.holder.GetKRL()
	if fresh := time.Since(updatedAt) < fetcher.ttl; fresh {
		return nil
	}

	logger.Log().Debugf("trying to update krl...")

	if err := fetcher.Update(fetcher.FetchFromAPI); err != nil {
		return err
	}
	if err := fetcher.WriteCache(); err != nil {
		return err
	}

	logger.Log().Debugf("successfully updated krl")
	return nil
}

// WriteCache serializes the held KRL, together with its timestamp and etag,
// as JSON and writes it through the cache manager. It fails if no KRL has
// been loaded yet.
func (fetcher *SkottyFetcher) WriteCache() error {
	current, updatedAt, etag := fetcher.holder.GetKRL()
	if current == nil {
		return fmt.Errorf("krl not initialized")
	}

	serialized, err := current.Marshal(rand.Reader)
	if err != nil {
		return err
	}

	payload, err := json.Marshal(Cache{
		Timestamp: updatedAt,
		Etag:      etag,
		KRL:       serialized,
	})
	if err != nil {
		return err
	}

	return fetcher.cache.TryWrite(payload)
}

// Update calls fetch and stores its result in the holder.
//
// A non-nil KRL replaces the held one along with its timestamp and etag.
// A nil KRL with a non-empty etag means "not modified": only the timestamp is
// refreshed, which requires a KRL to already be held. A nil KRL with an empty
// etag is treated as invalid. Every failure is also recorded in lastError.
func (fetcher *SkottyFetcher) Update(fetch FetchKRLFunc) error {
	fail := func(err error) error {
		fetcher.lastError.SetError(err)
		return err
	}

	fresh, ts, etag, err := fetch()
	if err != nil {
		return fail(err)
	}

	if fresh != nil {
		fetcher.holder.UpdateWithTimestamp(fresh, ts, etag)
		return nil
	}

	// No KRL in the result: only valid as a "not modified" answer.
	if etag == "" {
		return fail(fmt.Errorf("no valid krl"))
	}
	if current, _, _ := fetcher.holder.GetKRL(); current == nil {
		return fail(fmt.Errorf("broken krl in cache"))
	}

	logger.Log().Debugf("Not modified. Update timestamp only. Etag=%s", etag)
	fetcher.holder.UpdateTimestampOnly(ts)
	return nil
}

// FetchFromCache loads the KRL previously persisted by WriteCache.
//
// The cache file holds JSON whose KRL field is the uncompressed KRL (as
// produced by krl.KRL.Marshal), so it is parsed directly with krl.ParseKRL,
// without zstd decoding. The stored timestamp and etag are returned as-is.
func (fetcher *SkottyFetcher) FetchFromCache() (*krl.KRL, time.Time, string, error) {
	var stored Cache

	raw, err := fetcher.cache.TryRead()
	if err == nil {
		err = json.Unmarshal(raw, &stored)
	}
	if err != nil {
		return nil, time.Time{}, "", err
	}

	parsed, err := krl.ParseKRL(stored.KRL)
	if err != nil {
		return nil, time.Time{}, "", err
	}

	return parsed, stored.Timestamp, stored.Etag, nil
}

// FetchFromAPI downloads the zstd-compressed KRL from Skotty with a 2 second
// timeout.
//
// When an etag is already held, it is sent as If-None-Match. A 304 response
// yields a nil KRL together with that etag and the current time, which Update
// interprets as "not modified". On a 2xx response, the body is checked against
// the hex MD5 carried in the Etag header (surrounding quotes stripped) before
// being decompressed and parsed. The returned etag is that unquoted value.
func (fetcher *SkottyFetcher) FetchFromAPI() (*krl.KRL, time.Time, string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	httpRequest := fetcher.client.R().SetContext(ctx)

	// Take one snapshot through the GetEtag accessor and use it for both the
	// request header and the 304 result, so the two always agree; reading
	// holder.Etag directly would bypass the accessor (presumably lock-guarded —
	// confirm in Holder).
	currentEtag := fetcher.holder.GetEtag()
	if currentEtag != "" {
		// NOTE(review): etags are stored with quotes stripped and are sent back
		// unquoted, while HTTP entity-tags are quoted strings. Confirm the server
		// matches the unquoted form; otherwise it will never answer 304.
		httpRequest.SetHeader("If-None-Match", currentEtag)
	}

	response, err := httpRequest.Get("/api/v1/ca/krl/all.zst")
	if err != nil {
		return nil, time.Time{}, "", err
	}

	if response.StatusCode() == http.StatusNotModified {
		if currentEtag == "" {
			return nil, time.Time{}, "", fmt.Errorf("got 'not modified' from api with empty Etag")
		}
		// KRL unchanged: nothing to parse, only the timestamp moves forward.
		return nil, time.Now(), currentEtag, nil
	}

	if !response.IsSuccess() {
		return nil, time.Time{}, "", fmt.Errorf("failed to get krl from api: [%d] %s",
			response.StatusCode(), base64.StdEncoding.EncodeToString(response.Body()))
	}

	body := response.Body()

	// The Etag is the hex-encoded MD5 of the compressed body; verify it before
	// trusting the payload.
	rawEtag := response.Header().Get("Etag")
	newEtag := strings.Trim(rawEtag, "\"")
	checkSum, err := hex.DecodeString(newEtag)
	if err != nil {
		return nil, time.Time{}, "", fmt.Errorf("cannot parse checksum: %s (%w)", rawEtag, err)
	}

	hashSum := md5.Sum(body)
	if !bytes.Equal(checkSum, hashSum[:]) {
		return nil, time.Time{}, "", fmt.Errorf("checksum does not match: %s != %s",
			hex.EncodeToString(checkSum), hex.EncodeToString(hashSum[:]))
	}

	k, err := ParseKRL(body)
	if err != nil {
		return nil, time.Time{}, "", err
	}

	return k, time.Now(), newEtag, nil
}
