package decryptor

import (
	"context"
	"encoding/hex"
	"encoding/json"
	"errors"
	"io/ioutil"
	"os"
	"sync"
	"time"

	"a.yandex-team.ru/passport/infra/daemons/yasmsd/internal/logs"
	"a.yandex-team.ru/passport/shared/golibs/utils"
)

// KeyringConfig configures a Keyring.
//
// KeysFile is the path to the JSON keys file (an array of KeyData objects);
// NewKeyring rejects an empty value. KeysReloadPeriod is the interval at which
// Monitor checks the keys file for changes; NewKeyring rejects a non-positive
// value.
// NOTE(review): the accepted JSON format of keys_reload_period is defined by
// utils.Duration — confirm it there.
type KeyringConfig struct {
	KeysFile         string         `json:"keys_file"`
	KeysReloadPeriod utils.Duration `json:"keys_reload_period"`
}

// Keyring holds the keys used to decrypt messages. Keys are loaded from the
// JSON keys file named in KeyringConfig and are reloaded by Monitor when that
// file changes.
//
// The embedded RWMutex guards keys: update swaps in a freshly built map under
// the write lock, and getKey looks keys up under the read lock. In this file,
// lastModified is only read and written by update.
// NOTE(review): embedding sync.RWMutex exports Lock/Unlock/RLock/RUnlock on
// Keyring; an unexported named field (e.g. mu sync.RWMutex) would hide them.
type Keyring struct {
	sync.RWMutex
	config       *KeyringConfig // Configuration
	logs         *logs.Logs     // Loggers
	keys         map[int][]byte // Raw (hex-decoded) keys by key ID
	lastModified time.Time      // Modification time of the keys file at the last successful load
}

// KeyData is one entry of the keys file, which update decodes as a JSON array
// of these objects.
//
// ID is the identifier the key is stored under (and looked up by in getKey).
// Body is the key itself, hex-encoded; update rejects a body that is not valid
// hex or that decodes to zero bytes.
type KeyData struct {
	ID   int    `json:"id"`
	Body string `json:"body"`
}

// NewKeyring validates config, performs the initial load of the keys file and
// returns a ready-to-use Keyring.
//
// It returns an error if config or loggers is nil, if KeysFile is empty, if
// KeysReloadPeriod is not positive, or if the initial load fails (the keys file
// cannot be stat'ed, read or parsed, is empty, or contains a bad key).
// Periodic reloading is not started here; run Monitor for that.
func NewKeyring(config *KeyringConfig, loggers *logs.Logs) (*Keyring, error) {
	if config == nil {
		return nil, errors.New("nil keyring config")
	}
	if loggers == nil {
		return nil, errors.New("nil loggers")
	}

	if config.KeysFile == "" {
		return nil, errors.New("empty keys file path")
	}

	if config.KeysReloadPeriod.Duration <= 0 {
		return nil, errors.New("invalid keys reload period")
	}

	keyring := &Keyring{
		config: config,
		logs:   loggers,
	}

	if err := keyring.update(); err != nil {
		return nil, err
	}

	// Log the time.Duration itself with %s ("5s"): formatting the utils.Duration
	// wrapper with %d printed the raw nanosecond count inside braces.
	keyring.logs.General.WriteDebug(logs.ComponentKeyring, "initialized, keys file %s, reload period %s",
		keyring.config.KeysFile, keyring.config.KeysReloadPeriod.Duration)

	return keyring, nil
}

// getKey returns the raw key stored under id, or nil if the keyring holds no
// key with that id. The lookup happens under the read lock, so it is safe to
// call while update swaps in a new key map.
func (keyring *Keyring) getKey(id int) []byte {
	keyring.RLock()
	// A missing id yields the zero value of []byte, i.e. nil.
	key := keyring.keys[id]
	keyring.RUnlock()

	return key
}

// update (re)loads the keys file if it changed since the last successful load.
//
// The file must contain a non-empty JSON array of KeyData objects; each Body is
// hex-decoded into raw key bytes and stored under its ID. The reload is skipped
// when keys are already loaded and the file's modification time equals the one
// recorded by the previous load. On any error the keyring is left untouched, so
// a broken keys file never replaces a working set of keys.
//
// NOTE(review): entries sharing an ID silently overwrite each other (the last
// one wins) — confirm duplicates cannot occur in real keys files.
//
// update reads keys and lastModified without taking the lock. It relies on
// being the only writer (called from NewKeyring, then from the Monitor
// goroutine) and must not run concurrently with itself.
func (keyring *Keyring) update() error {
	path := keyring.config.KeysFile

	stat, err := os.Stat(path)
	if err != nil {
		return errors.New("failed to stat keys file " + path + ": " + err.Error())
	}

	// Equal (rather than ==) compares only the instant, ignoring the Location
	// and monotonic-clock parts of time.Time.
	if keyring.keys != nil && keyring.lastModified.Equal(stat.ModTime()) {
		return nil
	}

	data, err := ioutil.ReadFile(path)
	if err != nil {
		return err // the *os.PathError already names the file
	}

	var entries []KeyData
	if err := json.Unmarshal(data, &entries); err != nil {
		return errors.New("failed to parse keys file " + path + ": " + err.Error())
	}

	if len(entries) == 0 {
		return errors.New("Keys file is empty: " + path)
	}

	keys := make(map[int][]byte, len(entries))
	minKeyID, maxKeyID := entries[0].ID, entries[0].ID
	for _, entry := range entries {
		key, decodeErr := hex.DecodeString(entry.Body)
		if decodeErr != nil || len(key) == 0 {
			// Never echo entry.Body: it is decryption key material, and this
			// error is written to the logs by Monitor.
			return errors.New("bad hex key in keys file " + path)
		}
		keys[entry.ID] = key
		if entry.ID < minKeyID {
			minKeyID = entry.ID
		}
		if entry.ID > maxKeyID {
			maxKeyID = entry.ID
		}
	}

	// Hold the write lock only for the swap; readers (getKey) see either the
	// old map or the new one, never a partially built one.
	keyring.Lock()
	keyring.keys = keys
	keyring.lastModified = stat.ModTime()
	keyring.Unlock()

	keyring.logs.General.WriteDebug(logs.ComponentKeyring, "loaded %d new keys, with ids [%d,%d]",
		len(keys), minKeyID, maxKeyID)

	return nil
}

// Monitor reloads the keys file every KeysReloadPeriod until ctx is cancelled.
//
// Each tick calls update, which re-reads the file only if its modification
// time changed. A failed reload is logged as a warning and the previously
// loaded keys stay in use. Monitor blocks until ctx is done, so it is meant to
// run in its own goroutine. It calls wg.Done exactly once on return (deferred,
// so also on panic), so the caller must wg.Add(1) before starting it.
func (keyring *Keyring) Monitor(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	keyring.logs.General.WriteDebug(logs.ComponentKeyring, "monitor started: %s", keyring.config.KeysFile)

	ticker := time.NewTicker(keyring.config.KeysReloadPeriod.Duration)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			keyring.logs.General.WriteDebug(logs.ComponentKeyring, "monitor stopped: %s", keyring.config.KeysFile)
			return
		case <-ticker.C:
			if err := keyring.update(); err != nil {
				// %v, not %w: %w is only understood by fmt.Errorf. Presumably the
				// log writers format printf-style, where %w renders as "%!w(...)".
				keyring.logs.General.WriteWarning(logs.ComponentKeyring, "Failed to update keyring: %v", err)
			}
		}
	}
}
