package keys

import (
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"io/ioutil"
	"strconv"

	"golang.org/x/xerrors"
)

// Supported values for Config.KeysEncoding; they select how key bodies in the
// keys file are decoded by InitKeyMap.
const (
	// HexKey: key bodies are hex strings, decoded with hex.DecodeString.
	HexKey    = "hex"
	// Base64Key: key bodies are standard (padded) base64, decoded with
	// base64.StdEncoding.
	Base64Key = "base64"
)

// Config tells InitKeyMap where to load keys from and how they are encoded.
// The json tags give the field names used when a Config is decoded from JSON.
type Config struct {
	// Filename is the path of the JSON keys file; InitKeyMap rejects "".
	Filename     string `json:"keys_file"`
	// DefaultKey is the id of the key to make the default. Optional for
	// InitKeyMap (ignored when ""), required by InitKeyMapWithDefaultKey.
	DefaultKey   string `json:"default_key"`
	// KeysEncoding is HexKey or Base64Key; any other value is rejected.
	KeysEncoding string `json:"keys_encoding"`
}

// KeyMap holds decoded key bytes indexed by string id, plus an optional
// default key id.
//
// The zero value is usable: AddKey allocates the map on first insert, and
// lookups on a nil map simply return nil. KeyMap has no internal locking, so
// callers that mutate it while other goroutines read it must synchronize.
type KeyMap struct {
	// keys maps id -> raw (already decoded) key bytes.
	keys         map[string][]byte
	// defaultKeyID is the id used by GetDefaultKey; set via SetDefaultKeyID.
	defaultKeyID string
}

// InitKeyMapWithDefaultKey behaves like InitKeyMap but additionally requires
// cfg.DefaultKey to be set; a missing default is reported before the keys
// file is read.
func InitKeyMapWithDefaultKey(cfg Config) (*KeyMap, error) {
	if len(cfg.DefaultKey) > 0 {
		return InitKeyMap(cfg)
	}
	return nil, xerrors.Errorf("Default key is not configured")
}

// InitKeyMap builds a KeyMap from the JSON keys file described by cfg.
//
// The file must decode into a map[string]string of key id -> encoded key
// body, e.g. {"1": "00ff..."}. Each body is decoded according to
// cfg.KeysEncoding (HexKey or Base64Key). Empty ids and empty bodies are
// rejected. If cfg.DefaultKey is non-empty it must name one of the loaded
// keys and becomes the map's default.
//
// Messages built in this function identify keys by id only and never echo a
// key body, because bodies are secret material that would otherwise end up
// in logs.
func InitKeyMap(cfg Config) (*KeyMap, error) {
	if cfg.Filename == "" {
		return nil, xerrors.Errorf("Keys file is not configured")
	}

	// Resolve the decoder before reading anything so that a bad encoding is
	// reported without touching the (secret) keys file.
	keyMap := &KeyMap{}
	var addKey func(id string, key string) error
	switch cfg.KeysEncoding {
	case HexKey:
		addKey = keyMap.AddHexKey
	case Base64Key:
		addKey = keyMap.AddBase64Key
	default:
		return nil, xerrors.Errorf("Unsupported keys encoding type: '%s'", cfg.KeysEncoding)
	}

	data, err := ioutil.ReadFile(cfg.Filename)
	if err != nil {
		return nil, xerrors.Errorf("Failed to read keys file '%s': %w", cfg.Filename, err)
	}

	var keys map[string]string
	if err := json.Unmarshal(data, &keys); err != nil {
		return nil, xerrors.Errorf("Failed to parse keys file '%s': %w", cfg.Filename, err)
	}

	for id, body := range keys {
		if id == "" {
			return nil, xerrors.Errorf("Empty key id in keys file '%s'", cfg.Filename)
		}
		if body == "" {
			return nil, xerrors.Errorf("Empty body for key id '%s' in keys file '%s'", id, cfg.Filename)
		}
		if err := addKey(id, body); err != nil {
			// Add file and id context; the decoder error says what was wrong.
			return nil, xerrors.Errorf("Invalid key id '%s' in keys file '%s': %w", id, cfg.Filename, err)
		}
	}

	if cfg.DefaultKey != "" {
		if err := keyMap.SetDefaultKeyID(cfg.DefaultKey); err != nil {
			return nil, err
		}
	}

	return keyMap, nil
}

// CreateKeyMap returns a new, empty KeyMap with no keys and no default id.
func CreateKeyMap() *KeyMap {
	return new(KeyMap)
}

// GetDefaultKeyID reports the id of the default key, or "" if none was set.
func (m *KeyMap) GetDefaultKeyID() string {
	current := m.defaultKeyID
	return current
}

// SetDefaultKeyID makes id the default key. If no key with that id has been
// added it returns an error and leaves the current default unchanged.
func (m *KeyMap) SetDefaultKeyID(id string) error {
	if _, present := m.keys[id]; present {
		m.defaultKeyID = id
		return nil
	}
	return xerrors.Errorf("No key for default id '%s'", id)
}

// SetDefaultKeyIDNum is SetDefaultKeyID for numeric ids: id is rendered in
// base 10 and then looked up as a string id.
func (m *KeyMap) SetDefaultKeyIDNum(id uint64) error {
	decimalID := strconv.FormatUint(id, 10)
	return m.SetDefaultKeyID(decimalID)
}

// GetDefaultKey returns the bytes of the default key, or nil when no key is
// stored under the current default id.
func (m *KeyMap) GetDefaultKey() []byte {
	defaultID := m.defaultKeyID
	return m.GetKey(defaultID)
}

// GetKey returns the key stored under id, or nil if there is none. The
// returned slice is the stored one, not a copy.
func (m *KeyMap) GetKey(id string) []byte {
	// A missing id (or a nil map) yields the []byte zero value, which is nil.
	return m.keys[id]
}

// GetKeyNum is GetKey for numeric ids: id is rendered in base 10 and looked
// up as a string id.
func (m *KeyMap) GetKeyNum(id uint64) []byte {
	decimalID := strconv.FormatUint(id, 10)
	return m.GetKey(decimalID)
}

// AddKey stores key under id, replacing any key already stored there. The
// slice is kept as-is (not copied). The backing map is created on the first
// insert, so AddKey works on a zero KeyMap.
func (m *KeyMap) AddKey(id string, key []byte) {
	if m.keys != nil {
		m.keys[id] = key
		return
	}
	m.keys = map[string][]byte{id: key}
}

// AddHexKey hex-decodes key and stores the resulting bytes under id.
//
// It returns an error if key is not valid hex or decodes to zero bytes. The
// error names id and wraps the decoder's cause. It never includes key itself,
// since key is secret material that would otherwise end up in logs.
func (m *KeyMap) AddHexKey(id string, key string) error {
	decoded, err := hex.DecodeString(key)
	if err != nil {
		return xerrors.Errorf("Bad hex key for id '%s': %w", id, err)
	}
	if len(decoded) == 0 {
		return xerrors.Errorf("Bad hex key for id '%s': key is empty", id)
	}
	m.AddKey(id, decoded)
	return nil
}

// AddBase64Key decodes key as standard (padded) base64 and stores the
// resulting bytes under id.
//
// It returns an error if key is not valid base64 or decodes to zero bytes.
// The error names id and wraps the decoder's cause, which reports only an
// input offset. It never includes key itself, since key is secret material
// that would otherwise end up in logs.
func (m *KeyMap) AddBase64Key(id string, key string) error {
	decoded, err := base64.StdEncoding.DecodeString(key)
	if err != nil {
		return xerrors.Errorf("Bad base64 key for id '%s': %w", id, err)
	}
	if len(decoded) == 0 {
		return xerrors.Errorf("Bad base64 key for id '%s': key is empty", id)
	}
	m.AddKey(id, decoded)
	return nil
}
