package updater

import (
	"encoding/json"
	"testing"

	"code.justin.tv/web/jax/common/config"
	"code.justin.tv/web/jax/common/stats"
	"code.justin.tv/web/jax/db"

	"github.com/stretchr/testify/mock"
)

// mockPartnershipsBackend is a testify-backed stand-in that tests install as
// the updater's Backend; responses are programmed with
// On("IsPartner", id).Return(isPartner, err).
type mockPartnershipsBackend struct {
	mock.Mock
}

// IsPartner records the call with the embedded mock and returns the
// (bool, error) pair registered for id.
func (m *mockPartnershipsBackend) IsPartner(id int) (bool, error) {
	ret := m.Called(id)
	partner := ret.Get(0).(bool)
	err := ret.Error(1)
	return partner, err
}

// TestPartnerships verifies that Fetch marks a channel as a partner
// (partner_program == true) when the backend says so, and as a non-partner
// otherwise. The test expects each channel's json.Number channel_id under
// Properties["rails"] to reach the backend as an int.
func TestPartnerships(t *testing.T) {
	subject, backend := setupPartnershipsTest(t)
	input := []db.ChannelResult{
		{
			Channel:    "wickd",
			Properties: map[string]interface{}{"rails": map[string]interface{}{"channel_id": json.Number("1")}},
		},
		{
			Channel:    "telepresence",
			Properties: map[string]interface{}{"rails": map[string]interface{}{"channel_id": json.Number("2")}},
		},
	}
	backend.On("IsPartner", 1).Return(true, nil)
	backend.On("IsPartner", 2).Return(false, nil)

	out, err := subject.Fetch(input)
	if err != nil {
		// Every check below inspects out; continuing would only add noise.
		t.Fatalf("Fetch failed: %v", err)
	}
	if len(out) != 2 {
		t.Errorf("output should be length 2, got %d", len(out))
	}
	checkChannel(t, out, "wickd", true)
	checkChannel(t, out, "telepresence", false)
	// Confirm both registered backend lookups were actually performed.
	backend.AssertExpectations(t)
}

// setupPartnershipsTest builds a partnershipsUpdater initialized with a
// development config and a statsd client. It then sets the updater's Backend
// to a fresh mock so the test controls IsPartner results. It returns the
// updater and the mock.
func setupPartnershipsTest(t *testing.T) (*partnershipsUpdater, *mockPartnershipsBackend) {
	t.Helper()
	conf := &config.Config{
		Environment:    "development",
		StatsHostPort:  "graphite.internal.justin.tv:8125",
		MoneypennyHost: "https://moneypenny.dev.us-west2.internal.justin.tv",
	}
	// Named statsClient so the local does not shadow the imported stats package.
	statsClient := stats.InitStatsd(conf)

	subject := &partnershipsUpdater{}
	subject.Init(conf, statsClient)

	// Assigned after Init so the mock takes precedence over any backend Init
	// may have configured.
	backend := &mockPartnershipsBackend{}
	subject.Backend = backend
	return subject, backend
}

// checkChannel asserts that output[channel]["partnerships"]["partner_program"]
// exists, is a bool, and equals expectation. It reports at most one error per
// call, describing the first lookup that failed.
func checkChannel(t *testing.T, output map[string]map[string]interface{}, channel string, expectation bool) {
	t.Helper()
	if msg := partnerProgramProblem(output, channel, expectation); msg != "" {
		t.Error(msg)
	}
}

// partnerProgramProblem describes the first way output fails to record
// partner_program == expectation for channel, or returns "" when it matches.
// It is separated from checkChannel so the lookup chain stays linear, with
// one early return per failure mode.
func partnerProgramProblem(output map[string]map[string]interface{}, channel string, expectation bool) string {
	props, ok := output[channel]
	if !ok {
		return "channel get failed for " + channel
	}
	fields, ok := props["partnerships"].(map[string]interface{})
	if !ok {
		return "partnerships get failed for " + channel
	}
	partnerProgram, ok := fields["partner_program"].(bool)
	if !ok {
		return "partner_program get failed for " + channel
	}
	if partnerProgram != expectation {
		return "partner_program is incorrect for " + channel
	}
	return ""
}
