package keys

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/url"
	"path"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"
	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/passport/infra/daemons/tvmcert/internal/cache"
	"a.yandex-team.ru/passport/infra/daemons/tvmcert/internal/task"
	"a.yandex-team.ru/passport/infra/daemons/tvmcert/internal/utils"
	"a.yandex-team.ru/passport/shared/golibs/juggler"
	"a.yandex-team.ru/passport/shared/golibs/logger"
	shared_utils "a.yandex-team.ru/passport/shared/golibs/utils"
)

// publicKeysFile is the file name, inside SkottyFetcherConfig.CacheDir, where
// fetched CA public keys are persisted and from which they are reloaded when
// the Skotty API cannot be reached.
const publicKeysFile = "public_keys.json"

// Fetcher supplies the set of trusted SSH CA public keys.
//
// It embeds task.Runner (presumably driven periodically by the task
// scheduler to refresh the keys — confirm against the task package).
type Fetcher interface {
	task.Runner

	// CheckPublicKey returns nil when publicKey is one of the trusted CA keys
	// and a non-nil error otherwise.
	CheckPublicKey(publicKey ssh.PublicKey) error

	// GetJugglerStatus reports the fetcher's health for Juggler monitoring.
	GetJugglerStatus() *juggler.Status
}

// Holder keeps the most recently loaded CA public keys together with the
// moment they were obtained.
//
// The methods below synchronize through rwLock; reading or writing Keys and
// Timestamp directly bypasses that lock. A Holder contains a mutex and must
// not be copied once in use.
type Holder struct {
	Keys      []ssh.PublicKey
	Timestamp time.Time
	rwLock    sync.RWMutex
}

// UpdateWithTimestamp replaces the stored keys and records ts as the moment
// they were obtained, both under the write lock.
func (h *Holder) UpdateWithTimestamp(keys []ssh.PublicKey, ts time.Time) {
	h.rwLock.Lock()
	h.Keys, h.Timestamp = keys, ts
	h.rwLock.Unlock()
}

// Update replaces the stored keys, stamping them with the current time.
func (h *Holder) Update(keys []ssh.PublicKey) {
	now := time.Now()
	h.UpdateWithTimestamp(keys, now)
}

// GetKeys returns the stored keys and their timestamp, read together under
// the read lock. The returned slice is shared with the Holder rather than
// copied, so callers must treat it as read-only.
func (h *Holder) GetKeys() (keys []ssh.PublicKey, ts time.Time) {
	h.rwLock.RLock()
	keys, ts = h.Keys, h.Timestamp
	h.rwLock.RUnlock()
	return keys, ts
}

// PublicKeyCache is the JSON document written to and read from the on-disk
// key cache.
type PublicKeyCache struct {
	// Keys holds each CA public key in SSH wire format (ssh.PublicKey.Marshal);
	// encoding/json stores every entry as a base64 string.
	Keys [][]byte `json:"keys"`

	// Timestamp is when the keys were obtained. It is restored together with
	// the keys, so their age is preserved across cache reloads.
	Timestamp time.Time `json:"timestamp"`
}

// SkottyFetcherConfig configures NewSkottyFetcher. Unset fields receive
// defaults inside NewSkottyFetcher.
type SkottyFetcherConfig struct {
	// SkottyURL is the Skotty API base URL
	// (default "https://skotty.sec.yandex-team.ru").
	SkottyURL string `json:"skotty_url"`

	// CacheDir is the directory holding the public key cache file.
	CacheDir string `json:"cache_dir"`

	// KeysUpdatePeriod is the key age after which Run refetches keys
	// (default 24h).
	KeysUpdatePeriod *shared_utils.Duration `json:"keys_update_period"`

	// JugglerKeyAgeWarn is the key age above which Juggler reports Warning
	// (default 24h).
	JugglerKeyAgeWarn *shared_utils.Duration `json:"juggler_key_age_warn"`

	// JugglerKeyAgeCritical is the key age above which Juggler reports
	// Critical (default 48h).
	JugglerKeyAgeCritical *shared_utils.Duration `json:"juggler_key_age_critical"`
}

// ParsePublicKeys reads SSH public keys in authorized_keys format from raw,
// one key per line, and returns them in input order.
//
// Lines that are blank or start with '#' (after trimming surrounding
// whitespace) are skipped, as in an authorized_keys file. Any other line that
// fails to parse aborts the call with an error. A nil slice is returned when
// raw contains no keys. raw is only read, never closed; closing it remains the
// caller's responsibility.
func ParsePublicKeys(raw io.Reader) ([]ssh.PublicKey, error) {
	var keys []ssh.PublicKey
	scanner := bufio.NewScanner(raw)
	for scanner.Scan() {
		line := bytes.TrimSpace(scanner.Bytes())
		// ssh.ParseAuthorizedKey reports "ssh: no key found" for an empty or
		// comment-only line, so without this a single stray blank line in the
		// response would reject the entire key set.
		if len(line) == 0 || line[0] == '#' {
			continue
		}
		key, _, _, _, err := ssh.ParseAuthorizedKey(line)
		if err != nil {
			return nil, fmt.Errorf("failed to parse authorized key %q: %w", line, err)
		}
		keys = append(keys, key)
	}
	if err := scanner.Err(); err != nil {
		// Don't hand back a partially read key set alongside the error.
		return nil, fmt.Errorf("reading public keys: %w", err)
	}
	return keys, nil
}

// SkottyFetcher implements Fetcher by downloading SSH CA public keys from the
// Skotty API and persisting them to an on-disk cache.
type SkottyFetcher struct {
	// keysHolder holds the current keys and the time they were obtained.
	keysHolder Holder

	// keysTTL is the key age below which Run skips refetching.
	keysTTL time.Duration

	// cache reads and writes the serialized PublicKeyCache document.
	cache *cache.Manager

	// client is a resty client whose base URL is the Skotty API.
	client *resty.Client

	// jugglerWarnTimeout and jugglerCriticalTimeout are the key ages above
	// which GetJugglerStatus reports Warning and Critical respectively.
	jugglerWarnTimeout     time.Duration
	jugglerCriticalTimeout time.Duration

	// lastError is the most recent failure recorded by UpdateKeys.
	lastError utils.ThreadSafeError
}

// NewSkottyFetcher creates a SkottyFetcher from config and immediately runs
// Init to load the initial key set.
//
// Unset config fields get defaults:
//   - SkottyURL: "https://skotty.sec.yandex-team.ru"
//   - KeysUpdatePeriod: 24h
//   - JugglerKeyAgeWarn: 24h
//   - JugglerKeyAgeCritical: 48h
//
// Keys are cached in publicKeysFile under config.CacheDir.
//
// URL validation and cache setup errors are returned with a nil fetcher. If
// Init fails, the constructed fetcher is still returned together with Init's
// error.
// NOTE(review): presumably some callers keep the fetcher after a failed Init
// so that Run can retry later — confirm before changing this.
func NewSkottyFetcher(config SkottyFetcherConfig) (*SkottyFetcher, error) {
	if config.SkottyURL == "" {
		config.SkottyURL = "https://skotty.sec.yandex-team.ru"
	}

	// url.Parse is lenient; this only rejects syntactically broken URLs.
	if _, err := url.Parse(config.SkottyURL); err != nil {
		return nil, fmt.Errorf("invalid skotty_url: %w", err)
	}

	if config.KeysUpdatePeriod == nil {
		config.KeysUpdatePeriod = &shared_utils.Duration{Duration: 24 * time.Hour}
	}
	if config.JugglerKeyAgeWarn == nil {
		config.JugglerKeyAgeWarn = &shared_utils.Duration{Duration: 24 * time.Hour}
	}
	if config.JugglerKeyAgeCritical == nil {
		config.JugglerKeyAgeCritical = &shared_utils.Duration{Duration: 48 * time.Hour}
	}

	cacheManager, err := cache.NewManager(path.Join(config.CacheDir, publicKeysFile))
	if err != nil {
		return nil, fmt.Errorf("creating public keys cache in %q: %w", config.CacheDir, err)
	}

	fetcher := &SkottyFetcher{
		keysHolder:             Holder{},
		keysTTL:                config.KeysUpdatePeriod.Duration,
		cache:                  cacheManager,
		client:                 resty.New().SetBaseURL(config.SkottyURL),
		jugglerWarnTimeout:     config.JugglerKeyAgeWarn.Duration,
		jugglerCriticalTimeout: config.JugglerKeyAgeCritical.Duration,
	}

	return fetcher, fetcher.Init()
}

// Init loads the initial key set and persists it.
//
// It tries the Skotty API first and falls back to the on-disk cache when the
// API fetch fails. Whatever was loaded is then written back to the cache.
// Init returns an error (after logging it) if neither source produced keys or
// if the cache write fails.
func (fetcher *SkottyFetcher) Init() error {
	apiErr := fetcher.UpdateKeys(fetcher.FetchKeysFromAPI)
	if apiErr != nil {
		logger.Log().Warnf("cannot get keys from skotty: %s", apiErr)

		cacheErr := fetcher.UpdateKeys(fetcher.FetchKeysFromCache)
		if cacheErr != nil {
			logger.Log().Errorf("cannot get keys from cache or skotty: %s", cacheErr)
			return cacheErr
		}
	}

	writeErr := fetcher.WriteCache()
	if writeErr == nil {
		return nil
	}
	logger.Log().Errorf("cannot write cache: %s", writeErr)
	return writeErr
}

// CheckPublicKey returns nil when publicKey is one of the currently loaded CA
// keys, comparing keys by their SSH wire encoding (Marshal). Otherwise, or if
// publicKey is nil, it returns an error identifying the rejected key.
func (fetcher *SkottyFetcher) CheckPublicKey(publicKey ssh.PublicKey) error {
	if publicKey == nil {
		return fmt.Errorf("invalid certificate key: <nil>")
	}

	userKey := publicKey.Marshal()
	keys, _ := fetcher.keysHolder.GetKeys()
	for _, caKey := range keys {
		if bytes.Equal(caKey.Marshal(), userKey) {
			return nil
		}
	}

	// ssh.PublicKey implementations don't implement fmt.Stringer, so %s on
	// the key dumps internal struct fields. Key type plus SHA256 fingerprint
	// is readable and comparable with `ssh-keygen -l` output.
	return fmt.Errorf("invalid certificate key: %s %s", publicKey.Type(), ssh.FingerprintSHA256(publicKey))
}

// GetJugglerStatus reports fetcher health based on the age of the loaded keys:
//   - OK: the age is at most jugglerWarnTimeout.
//   - Warning: the age exceeds jugglerWarnTimeout.
//   - Critical: the age exceeds jugglerCriticalTimeout.
//
// Non-OK messages include the age in seconds and the last recorded fetch
// error ("none" if there is none).
func (fetcher *SkottyFetcher) GetJugglerStatus() *juggler.Status {
	status := juggler.NewStatusOk()
	_, ts := fetcher.keysHolder.GetKeys()

	// Keys can age out without any failed update (e.g. keys_update_period
	// longer than the warn threshold), leaving no error recorded. Passing that
	// nil error to %s would render as "%!s(<nil>)".
	lastErr := "none"
	if err := fetcher.lastError.GetError(); err != nil {
		lastErr = err.Error()
	}

	// Measure the age once so both thresholds are compared against one value.
	elapsed := time.Since(ts)
	switch {
	case elapsed > fetcher.jugglerCriticalTimeout:
		status.Update(juggler.NewStatus(juggler.Critical, "public keys are %.2f seconds old; last error: %s", elapsed.Seconds(), lastErr))
	case elapsed > fetcher.jugglerWarnTimeout:
		status.Update(juggler.NewStatus(juggler.Warning, "public keys are %.2f seconds old; last error: %s", elapsed.Seconds(), lastErr))
	}

	return status
}

// Run refreshes the keys from the Skotty API and persists them to the cache.
// It does nothing while the current keys are younger than keysTTL.
//
// Fetch and cache-write errors are returned. If only the cache write fails,
// the in-memory keys have already been replaced by the fresh ones.
func (fetcher *SkottyFetcher) Run() error {
	if _, updatedAt := fetcher.keysHolder.GetKeys(); time.Since(updatedAt) < fetcher.keysTTL {
		return nil
	}

	logger.Log().Debugf("trying to update public keys...")
	if err := fetcher.UpdateKeys(fetcher.FetchKeysFromAPI); err != nil {
		return err
	}
	if err := fetcher.WriteCache(); err != nil {
		return err
	}
	logger.Log().Debugf("successfully updated public keys")
	return nil
}

// WriteCache serializes the current keys (SSH wire format via Marshal) and
// their timestamp as a JSON PublicKeyCache document and hands it to
// cache.TryWrite. Marshaling and write errors are returned unchanged.
func (fetcher *SkottyFetcher) WriteCache() error {
	keys, ts := fetcher.keysHolder.GetKeys()

	var encoded [][]byte
	for _, key := range keys {
		encoded = append(encoded, key.Marshal())
	}

	data, err := json.Marshal(PublicKeyCache{Keys: encoded, Timestamp: ts})
	if err != nil {
		return err
	}
	return fetcher.cache.TryWrite(data)
}

// UpdateKeys calls fetch and, on success, installs the returned keys and
// timestamp into the holder.
//
// A fetch error, or an empty key set, is recorded in lastError and returned;
// the previously held keys are left untouched in that case. A successful
// update does not clear lastError, so an earlier failure (such as a failed API
// fetch followed by a successful cache fallback) stays visible to
// GetJugglerStatus.
func (fetcher *SkottyFetcher) UpdateKeys(fetch func() ([]ssh.PublicKey, time.Time, error)) error {
	keys, ts, err := fetch()
	if err != nil {
		fetcher.lastError.SetError(err)
		return err
	}
	// Reject empty as well as nil: installing an empty key set would make
	// CheckPublicKey fail for every certificate.
	if len(keys) == 0 {
		err = fmt.Errorf("no keys found")
		fetcher.lastError.SetError(err)
		return err
	}
	fetcher.keysHolder.UpdateWithTimestamp(keys, ts)
	return nil
}

// FetchKeysFromCache loads keys from the on-disk cache. The cache is a JSON
// PublicKeyCache document with keys in SSH wire format. It returns the keys
// together with the timestamp stored alongside them, or an error describing
// which step (read, decode, key parse) failed.
//
// The key slice is nil when the cache holds no keys.
func (fetcher *SkottyFetcher) FetchKeysFromCache() ([]ssh.PublicKey, time.Time, error) {
	data, err := fetcher.cache.TryRead()
	if err != nil {
		return nil, time.Time{}, fmt.Errorf("reading public keys cache: %w", err)
	}

	var pkCache PublicKeyCache
	if err := json.Unmarshal(data, &pkCache); err != nil {
		return nil, time.Time{}, fmt.Errorf("decoding public keys cache: %w", err)
	}

	// Deliberately not pre-sized: an empty cache must yield a nil slice.
	var keys []ssh.PublicKey
	for i, raw := range pkCache.Keys {
		publicKey, err := ssh.ParsePublicKey(raw)
		if err != nil {
			return nil, time.Time{}, fmt.Errorf("parsing cached public key #%d: %w", i, err)
		}
		keys = append(keys, publicKey)
	}

	return keys, pkCache.Timestamp, nil
}

// FetchKeysFromAPI requests GET /api/v1/ca/pub/ssh from the configured Skotty
// base URL with a 2-second timeout. It parses the body with ParsePublicKeys
// and returns the keys stamped with the current time.
//
// A non-2xx response becomes an error carrying the status code and a bounded
// prefix of the response body.
func (fetcher *SkottyFetcher) FetchKeysFromAPI() ([]ssh.PublicKey, time.Time, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	response, err := fetcher.client.R().
		SetContext(ctx).
		SetDoNotParseResponse(true).
		Get("/api/v1/ca/pub/ssh")
	if err != nil {
		return nil, time.Time{}, err
	}

	// With SetDoNotParseResponse(true) the raw body must be closed by us. A
	// dedicated variable makes it explicit that the close error is only
	// logged and never affects the returned error.
	defer func() {
		if closeErr := response.RawBody().Close(); closeErr != nil {
			logger.Log().Errorf("error closing body: %s", closeErr)
		}
	}()

	if !response.IsSuccess() {
		// The body is only used to enrich the error message, so read a bounded
		// prefix and ignore read errors (best effort). Without the limit, a
		// misbehaving server could make us buffer an arbitrarily large page.
		const maxErrorBodySize = 4 << 10
		body, _ := ioutil.ReadAll(io.LimitReader(response.RawBody(), maxErrorBodySize))
		return nil, time.Time{}, fmt.Errorf("failed to get keys from api: [%d] %s", response.StatusCode(), body)
	}

	keys, err := ParsePublicKeys(response.RawBody())
	if err != nil {
		return nil, time.Time{}, err
	}

	return keys, time.Now(), nil
}
