package encryption

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"testing"

	"github.com/stretchr/testify/assert"
)

// MockOLE is a test double for the OLE encryption backend. Rather than really
// encrypting or decrypting, its encryptor surrounds each written byte slice with
// three "!" characters and its decryptor strips three bytes from each end of the
// stream it reads.
type MockOLE struct{}

// NewEncryptor returns a writer that forwards data to w, wrapping every Write in
// "!!!" markers. The encryption context encCtx is ignored by the mock.
func (m *MockOLE) NewEncryptor(encCtx map[string]string, w io.Writer) io.Writer {
	enc := &mockOLEEncrypter{w: w}
	return enc
}

// NewDecryptor returns a reader over r that removes the leading and trailing
// three-byte markers added by the mock encryptor.
func (m *MockOLE) NewDecryptor(r io.Reader) io.Reader {
	dec := &mockOLEDecrypter{r: r}
	return dec
}

// mockOLEEncrypter is the io.Writer returned by MockOLE.NewEncryptor. Each call
// to Write emits "!!!", the caller's bytes, and then "!!!" to the wrapped writer.
type mockOLEEncrypter struct {
	w io.Writer
}

// Write wraps b in "!!!" markers and writes the result to the underlying writer.
//
// Per the io.Writer contract the returned count is the number of bytes of b
// consumed (0 <= n <= len(b)), not the number of bytes emitted downstream, and
// it is never negative. Reporting the downstream total (len(b)+6) would make
// helpers such as io.Copy or bytes.Reader.WriteTo reject or panic on the write.
func (e *mockOLEEncrypter) Write(b []byte) (int, error) {
	marker := []byte("!!!")

	// Nothing from b has been consumed if the leading marker fails.
	if _, err := e.w.Write(marker); err != nil {
		return 0, err
	}

	n, err := e.w.Write(b)
	if err != nil {
		return n, err
	}

	// The payload was consumed; report it alongside the trailing-marker error.
	if _, err := e.w.Write(marker); err != nil {
		return n, err
	}

	return len(b), nil
}

// mockOLEDecrypter is the io.Reader returned by MockOLE.NewDecryptor. It undoes
// the mock encryption by removing the three-byte "!!!" marker from each end of
// the wrapped stream.
type mockOLEDecrypter struct {
	// r is the wrapped ciphertext stream until the first Read; after that it is
	// replaced by a reader over the trimmed plaintext.
	r io.Reader

	// read reports whether r has already been drained and replaced.
	read bool
}

// Read implements io.Reader.
//
// The first call drains the wrapped reader and strips three bytes from each end
// of what it read. The resulting plaintext is then served across as many calls
// as the caller needs, so a small p no longer truncates the payload. io.EOF is
// returned once the plaintext has been fully consumed.
//
// A payload shorter than six bytes cannot hold both markers and yields an error
// instead of a slice-bounds panic. The returned count is never negative.
func (d *mockOLEDecrypter) Read(p []byte) (int, error) {
	if !d.read {
		b, err := ioutil.ReadAll(d.r)
		if err != nil {
			return 0, err
		}
		if len(b) < 6 {
			return 0, fmt.Errorf("mock OLE payload too short: got %d bytes, need at least 6", len(b))
		}
		// Swap the drained source for a reader over the trimmed plaintext so
		// later calls pick up exactly where this one stops.
		d.r = bytes.NewReader(b[3 : len(b)-3])
		d.read = true
	}
	return d.r.Read(p)
}

// TestClient round-trips each supported field type through an
// AuthorizedFieldClient backed by MockOLE, and checks that decrypting a nil
// payload yields the zero value without an error.
func TestClient(t *testing.T) {
	t.Run("End-to-end client encryption and decryption", func(t *testing.T) {
		c := &AuthorizedFieldClient{ole: &MockOLE{}}

		t.Run("String encrypt/decrypt", func(t *testing.T) {
			var want string = "foobar"

			ciphertext, err := c.EncryptString(dummyAuthContext(), want)
			assert.NoError(t, err)
			assert.NotNil(t, ciphertext)
			// MockOLE wraps the plaintext in "!!!" markers instead of encrypting it.
			assert.Equal(t, fmt.Sprintf("!!!%s!!!", want), string(ciphertext))

			got, err := c.DecryptString(ciphertext)
			assert.NoError(t, err)
			assert.Equal(t, want, got)
		})

		t.Run("Decrypt nil OLE payload", func(t *testing.T) {
			str, err := c.DecryptString(nil)
			assert.NoError(t, err)
			assert.Empty(t, str)

			asFloat32, err := c.DecryptFloat32(nil)
			assert.NoError(t, err)
			assert.Empty(t, asFloat32)

			asFloat64, err := c.DecryptFloat64(nil)
			assert.NoError(t, err)
			assert.Empty(t, asFloat64)

			asInt32, err := c.DecryptInt32(nil)
			assert.NoError(t, err)
			assert.Empty(t, asInt32)

			asInt64, err := c.DecryptInt64(nil)
			assert.NoError(t, err)
			assert.Empty(t, asInt64)
		})

		t.Run("Float32 encrypt/decrypt", func(t *testing.T) {
			var want float32 = 12.7

			ciphertext, err := c.EncryptFloat32(dummyAuthContext(), want)
			assert.NoError(t, err)
			assert.NotNil(t, ciphertext)

			got, err := c.DecryptFloat32(ciphertext)
			assert.NoError(t, err)
			assert.Equal(t, want, got)
		})

		t.Run("Float64 encrypt/decrypt", func(t *testing.T) {
			var want float64 = -1.234456e+78

			ciphertext, err := c.EncryptFloat64(dummyAuthContext(), want)
			assert.NoError(t, err)
			assert.NotNil(t, ciphertext)

			got, err := c.DecryptFloat64(ciphertext)
			assert.NoError(t, err)
			assert.Equal(t, want, got)
		})

		t.Run("Int32 encrypt/decrypt", func(t *testing.T) {
			var want int32 = 42

			ciphertext, err := c.EncryptInt32(dummyAuthContext(), want)
			assert.NoError(t, err)
			assert.NotNil(t, ciphertext)

			got, err := c.DecryptInt32(ciphertext)
			assert.NoError(t, err)
			assert.Equal(t, want, got)
		})

		t.Run("Int64 encrypt/decrypt", func(t *testing.T) {
			var want int64 = 1233454666666666485

			ciphertext, err := c.EncryptInt64(dummyAuthContext(), want)
			assert.NoError(t, err)
			assert.NotNil(t, ciphertext)

			got, err := c.DecryptInt64(ciphertext)
			assert.NoError(t, err)
			assert.Equal(t, want, got)
		})
	})
}

// dummyAuthContext returns a newly allocated encryption-context map with fixed
// test values. A fresh map is built on every call, so callers may mutate it.
func dummyAuthContext() map[string]string {
	authCtx := make(map[string]string, 4)
	authCtx["EventType"] = "FoobarUpdate"
	authCtx["Environment"] = "development"
	authCtx["MessageName"] = "MyGreatAuthorizedField"
	authCtx["FieldName"] = "MyGreatFieldName"
	return authCtx
}
