package channels_test

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strconv"
	"sync"
	"testing"

	"code.justin.tv/foundation/twitchclient"
	"code.justin.tv/web/users-service/client/channels"
	"code.justin.tv/web/users-service/models"

	. "github.com/smartystreets/goconvey/convey"
)

// requestsMade is the stat-map key under which initTestServer counts every
// request it handles.
const requestsMade = "requests-made"

// requestErrors is the stat-map key under which initTestServer counts requests
// it answered with a 500 because they carried an empty identifier.
const requestErrors = "request-errors"

// TestGetAll exercises Client.GetAll against an in-process test server seeded
// with 250 channels keyed by their decimal ID. The server records how many
// requests it received and how many it rejected, which lets each scenario
// assert how GetAll split the requested IDs into batches.
func TestGetAll(t *testing.T) {
	Convey("when fetching channels by id", t, func() {
		statMap := map[string]int{}
		channelMap := map[string]*models.ChannelProperties{}
		for i := 0; i < 250; i++ {
			channelMap[strconv.Itoa(i)] = &models.ChannelProperties{
				ID: uint64(i),
			}
		}
		ts := initTestServer(channelMap, "id", statMap)
		defer ts.Close()

		c, err := channels.NewClient(twitchclient.ClientConf{
			Host: ts.URL,
		})
		So(err, ShouldBeNil)

		ctx := context.Background()

		Convey("and the server is available", func() {
			Convey("and there are no batches", func() {
				_, err := c.GetAll(ctx, []string{}, nil)

				So(err, ShouldBeNil)
				So(accessStatMap(statMap, requestsMade), ShouldEqual, 0)
			})

			Convey("and there's a single batch", func() {
				result, err := c.GetAll(ctx, []string{"4", "5"}, nil)

				So(err, ShouldBeNil)
				So(result, ShouldNotBeNil)
				So(len(result.Results), ShouldEqual, 2)
				byID := result.ToMapByID()
				So(len(byID), ShouldEqual, len(result.Results))
				for _, channel := range result.Results {
					// ID is a uint64, so it must be formatted with FormatUint;
					// strconv.Itoa only accepts int.
					mapped, ok := byID[strconv.FormatUint(channel.ID, 10)]
					// Assert presence first so a missing entry fails the
					// assertion instead of panicking on a nil dereference.
					So(ok, ShouldBeTrue)
					So(reflect.DeepEqual(*mapped, channel), ShouldBeTrue)
				}
				So(accessStatMap(statMap, requestsMade), ShouldEqual, 1)
			})

			// 250 IDs should produce 3 requests. This assumes the client
			// batches at 100, the maximum initTestServer accepts.
			Convey("and there are three batches", func() {
				ids := make([]string, 0, len(channelMap))
				for _, channel := range channelMap {
					ids = append(ids, strconv.FormatUint(channel.ID, 10))
				}

				result, err := c.GetAll(ctx, ids, nil)

				So(err, ShouldBeNil)
				So(result, ShouldNotBeNil)
				So(len(result.Results), ShouldEqual, len(ids))
				returned := make(map[string]bool, len(result.Results))
				for _, channel := range result.Results {
					returned[strconv.FormatUint(channel.ID, 10)] = true
				}
				for _, id := range ids {
					So(returned[id], ShouldBeTrue)
				}
				So(accessStatMap(statMap, requestsMade), ShouldEqual, 3)
			})

			Convey("and there are multiple batches but some fail", func() {
				ids := make([]string, 0, len(channelMap)+1)
				for _, channel := range channelMap {
					ids = append(ids, strconv.FormatUint(channel.ID, 10))
				}
				// The test server answers 500 for any batch containing "".
				ids = append(ids, "")

				_, err := c.GetAll(ctx, ids, nil)

				So(err, ShouldNotBeNil)
				So(accessStatMap(statMap, requestErrors), ShouldEqual, 1)
			})
		})
	})
}

// TestGetAllByLogin exercises Client.GetAllByLogin against an in-process test
// server seeded with 250 channels named "0" through "249". Each scenario checks
// the returned channels and, via the server's stat map, how many requests the
// client issued.
func TestGetAllByLogin(t *testing.T) {
	Convey("when fetching channels by login", t, func() {
		statMap := map[string]int{}
		channelMap := map[string]*models.ChannelProperties{}
		for i := 0; i < 250; i++ {
			channelMap[strconv.Itoa(i)] = &models.ChannelProperties{
				Name: strconv.Itoa(i),
			}
		}
		ts := initTestServer(channelMap, "name", statMap)
		defer ts.Close()

		c, err := channels.NewClient(twitchclient.ClientConf{
			Host: ts.URL,
		})
		So(err, ShouldBeNil)

		ctx := context.Background()

		Convey("and the server is available", func() {
			Convey("and there are no batches", func() {
				_, err := c.GetAllByLogin(ctx, []string{}, nil)

				So(err, ShouldBeNil)
				So(accessStatMap(statMap, requestsMade), ShouldEqual, 0)
			})

			Convey("and there's a single batch", func() {
				logins := []string{"4", "5"}
				result, err := c.GetAllByLogin(ctx, logins, nil)

				So(err, ShouldBeNil)
				So(result, ShouldNotBeNil)
				So(len(result.Results), ShouldEqual, len(logins))
				// Verify contents, not just the count, mirroring the by-ID
				// single-batch scenario.
				returned := make(map[string]bool, len(result.Results))
				for _, channel := range result.Results {
					returned[channel.Name] = true
				}
				for _, login := range logins {
					So(returned[login], ShouldBeTrue)
				}
				So(accessStatMap(statMap, requestsMade), ShouldEqual, 1)
			})

			// 250 logins should produce 3 requests. This assumes the client
			// batches at 100, the maximum initTestServer accepts.
			Convey("and there are three batches", func() {
				logins := make([]string, 0, len(channelMap))
				for _, channel := range channelMap {
					logins = append(logins, channel.Name)
				}

				result, err := c.GetAllByLogin(ctx, logins, nil)

				So(err, ShouldBeNil)
				So(result, ShouldNotBeNil)
				So(len(result.Results), ShouldEqual, len(logins))
				returned := make(map[string]bool, len(result.Results))
				for _, channel := range result.Results {
					returned[channel.Name] = true
				}
				for _, login := range logins {
					So(returned[login], ShouldBeTrue)
				}
				So(accessStatMap(statMap, requestsMade), ShouldEqual, 3)
			})

			Convey("and there are multiple batches but some fail", func() {
				logins := make([]string, 0, len(channelMap)+1)
				for _, channel := range channelMap {
					logins = append(logins, channel.Name)
				}
				// The test server answers 500 for any batch containing "".
				logins = append(logins, "")

				_, err := c.GetAllByLogin(ctx, logins, nil)

				So(err, ShouldNotBeNil)
				So(accessStatMap(statMap, requestErrors), ShouldEqual, 1)
			})
		})
	})
}

// mutex guards every stat map. The test server's handler locks it while
// updating counts, and tests lock it (through accessStatMap) while reading them.
var mutex sync.Mutex

// accessStatMap returns statMap[key] read under mutex, so that assertions do
// not race with the test server's handler writing the same map. A key that has
// never been incremented reads as 0.
func accessStatMap(statMap map[string]int, key string) int {
	mutex.Lock()
	count := statMap[key]
	mutex.Unlock()
	return count
}

// initTestServer starts an httptest.Server that fakes the channel lookup
// endpoint. For every request it:
//   - increments statMap[requestsMade];
//   - collects all values of the query parameter named key (e.g. "id" or "name");
//   - panics if more than maxBatchSize values arrive, flagging that the client
//     did not batch. net/http recovers handler panics and aborts the response,
//     so the client observes a failed request;
//   - responds 500 and increments statMap[requestErrors] if any value is "";
//   - otherwise responds with the JSON-encoded channels found in channelMap,
//     skipping identifiers that are not present.
//
// statMap is only touched while holding the package-level mutex; read it with
// accessStatMap. The caller must Close the returned server.
func initTestServer(channelMap map[string]*models.ChannelProperties, key string, statMap map[string]int) *httptest.Server {
	// Most identifiers the client may send in a single request.
	const maxBatchSize = 100

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mutex.Lock()
		defer mutex.Unlock()

		statMap[requestsMade]++
		identifiers := r.URL.Query()[key]
		if len(identifiers) > maxBatchSize {
			panic(fmt.Errorf("batching did not occur: expected less than or equal to %d elements, received %d", maxBatchSize, len(identifiers)))
		}

		// Non-nil so that a batch with no matches encodes as [] rather than null.
		results := []models.ChannelProperties{}
		for _, identifier := range identifiers {
			if identifier == "" {
				statMap[requestErrors]++
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			if channel, ok := channelMap[identifier]; ok {
				results = append(results, *channel)
			}
		}

		// Encode fully before writing anything. Once the body starts streaming,
		// the status is committed as 200 and a later WriteHeader(500) is ignored.
		body, err := json.Marshal(&models.ChannelPropertiesResult{Results: results})
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		// A write error here means the client went away; there is nothing to recover.
		_, _ = w.Write(body)
	}))
}
