package main

import (
	"context"
	"fmt"
	"net"
	"net/http/httptest"
	"strings"
	"testing"

	ilog "a.yandex-team.ru/infra/infractl/internal/log"
	"a.yandex-team.ru/library/go/core/log"
)

// blackboxMock is a canned stand-in for the handler's blackbox client.
// Every call to GetLoginFromOauth yields the same preset triple, so each
// table-driven case can script what the upstream lookup reports.
type blackboxMock struct {
	Login string // login handed back to the caller
	UID   uint64 // uid handed back to the caller
	Error error  // error handed back to the caller; nil on the success path
}

// GetLoginFromOauth ignores ctx, the OAuth token and the caller IP and
// simply reports the mock's preset Login, UID and Error.
func (m *blackboxMock) GetLoginFromOauth(ctx context.Context, oauth, userip string) (string, uint64, error) {
	login, uid, err := m.Login, m.UID, m.Error
	return login, uid, err
}

// Fixtures shared by the Test_getLogin table.
var (
	// invalidRemote is a RemoteAddr value with no port, which
	// net.SplitHostPort rejects.
	invalidRemote = "12345"

	// errUnknownUser is the error the blackbox mock is configured to
	// return in the "invalid user info" case.
	errUnknownUser = fmt.Errorf("unknown user")

	// errInvalidRemote is whatever net.SplitHostPort reports for
	// invalidRemote; its message is the expected body of the
	// "invalid remote" case.
	_, _, errInvalidRemote = net.SplitHostPort(invalidRemote)
)

// Test_getLogin drives handler.ServeHTTP with a table of requests and checks
// the response status code and the whitespace-trimmed body of each one. The
// blackbox dependency is replaced by blackboxMock, whose canned result comes
// from the case's login, uid and err fields.
func Test_getLogin(t *testing.T) {
	svcLog := ilog.ConfigureLogger(log.DebugLevel, true)

	tests := []struct {
		name          string
		authorization string // Authorization header value; not sent when empty
		remote        string // overrides request.RemoteAddr when non-empty
		response      string // expected trimmed body; not checked when empty
		statusCode    int    // expected HTTP status code
		login         string // login returned by the blackbox mock
		uid           uint64 // uid returned by the blackbox mock
		err           error  // error returned by the blackbox mock
	}{
		{
			name:       "empty request",
			statusCode: 401,
			response:   ErrorNoAuthorization.Error(),
		},
		{
			name:          "invalid auth format",
			authorization: "foobar",
			response:      ErrorInvalidMethod.Error(),
			statusCode:    401,
		},
		{
			name:          "invalid user info",
			authorization: "OAuth foobar",
			response:      ErrorUserInfoFailed(errUnknownUser).Error(),
			statusCode:    401,
			err:           errUnknownUser,
		},
		{
			name:          "invalid remote",
			authorization: "OAuth foobar",
			remote:        invalidRemote,
			response:      errInvalidRemote.Error(),
			statusCode:    400,
		},
		{
			name:          "valid request",
			authorization: "OAuth AAAA",
			response:      `{"login":"foo","uid":12345}`,
			statusCode:    200,
			login:         "foo",
			uid:           12345,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			request := httptest.NewRequest("GET", "/get-login", nil)
			request.Header.Set("Accept", "application/json")
			if tt.authorization != "" {
				request.Header.Set("Authorization", tt.authorization)
			}
			if tt.remote != "" {
				request.RemoteAddr = tt.remote
			}
			responseRecorder := httptest.NewRecorder()

			// Keyed fields keep this literal correct even if blackboxMock's
			// fields are reordered or new ones are added.
			h := handler{
				log:            svcLog,
				blackboxClient: &blackboxMock{Login: tt.login, UID: tt.uid, Error: tt.err},
			}
			h.ServeHTTP(responseRecorder, request)

			if responseRecorder.Code != tt.statusCode {
				t.Errorf("Want status: %v, got status: %v", tt.statusCode, responseRecorder.Code)
			}

			// Surrounding whitespace (e.g. a trailing newline) is not part of
			// the expected payload, so compare the trimmed body.
			body := strings.TrimSpace(responseRecorder.Body.String())
			if tt.response != "" && body != tt.response {
				t.Errorf("Want body: %q, got body: %q", tt.response, body)
			}
		})
	}
}
