package job

import (
	"testing"

	"github.com/golang/protobuf/ptypes"
	"github.com/stretchr/testify/assert"

	"a.yandex-team.ru/infra/maxwell/go/internal/storages"
	"a.yandex-team.ru/infra/maxwell/go/internal/workingset"
	"a.yandex-team.ru/infra/maxwell/go/pkg/walle"
	pb "a.yandex-team.ru/infra/maxwell/go/proto"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/zap"
)

// TestController_RemoveFinished checks Controller.RemoveFinished against a
// working set restored from in-memory host and task storages.
//
// Each case seeds the storages, restores a working set (asserting it holds one
// record per seeded task), runs RemoveFinished and then expects exactly
// lenAfterRemoval records to remain in the working set.
func TestController_RemoveFinished(t *testing.T) {
	tests := []struct {
		name            string
		wantErr         bool
		hosts           []*pb.Host
		tasks           []*pb.Task
		lenAfterRemoval int
	}{{
		name:    "removes single record",
		wantErr: false,
		hosts: []*pb.Host{{
			Health: map[string]string{},
			Location: &pb.Host_Location{
				Country: "RU",
				City:    "SAS",
				Rack:    "1",
			},
			Restrictions:     make([]string, 0),
			FirmwareProblems: make([]string, 0),
			Project:          "rtc",
			Hostname:         "sas1-0001",
			Status:           "ready",
		}},
		tasks: []*pb.Task{{
			Meta: &pb.Task_Meta{
				CreationTime: ptypes.TimestampNow(),
				Mtime:        ptypes.TimestampNow(),
			},
			State:    pb.Task_FINISHED,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		lenAfterRemoval: 0,
	}}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			hostsStorage := storages.NewHostsInMemory(tt.hosts...)
			jobStorage := &storages.JobInMemory{
				TasksInMemory:     storages.NewTasksInMemory(tt.tasks...),
				ProcessedInMemory: storages.NewProcessedInMemory(),
			}
			j := &pb.Job{Spec: &pb.Job_Spec{PostCondition: &pb.PostCondition{}}}
			ws := workingset.New("test", jobStorage, hostsStorage, j.Spec)
			// Setup failures make every later assertion meaningless, so abort
			// the subtest instead of continuing with a broken working set.
			if err := ws.Restore(); err != nil {
				t.Fatalf("failed to restore working set: %s", err)
			}
			if !assert.Len(t, ws.Records(), len(tt.tasks)) {
				t.FailNow()
			}
			// NOTE(review): dryRun is true, yet the case expects the finished
			// record to leave the working set — presumably dry-run only skips
			// external side effects; confirm against Controller.RemoveFinished.
			c := &Controller{
				j:          j,
				dryRun:     true,
				ws:         ws,
				jobStorage: jobStorage,
			}
			l, err := zap.New(zap.CLIConfig(log.DebugLevel))
			if err != nil {
				t.Fatalf("failed to create logger: %s", err)
			}
			if err := c.RemoveFinished(l); (err != nil) != tt.wantErr {
				t.Errorf("RemoveFinished() error = %v, wantErr %v", err, tt.wantErr)
			}
			assert.Len(t, ws.Records(), tt.lenAfterRemoval)
		})
	}
}

// TestController_enforceTasks checks Controller.enforceTasks against a fake
// Wall-E client.
//
// Each case seeds in-memory host and task storages, restores a working set,
// registers the Wall-E calls it expects via makeWalle, runs enforceTasks and
// then compares the working set's tasks with after. The fake's expectations
// are asserted at the end, so a case with an empty makeWalle requires that no
// registered Wall-E call was expected.
func TestController_enforceTasks(t *testing.T) {
	now := ptypes.TimestampNow()
	tests := []struct {
		name      string
		hosts     []*pb.Host
		tasks     []*pb.Task
		after     []*pb.Task
		makeWalle func(*walle.Fake)
	}{{
		// A NEED_ENFORCE task triggers one RebootHost call and becomes ENFORCED.
		name: "enforce",
		hosts: []*pb.Host{{
			Health: map[string]string{},
			Location: &pb.Host_Location{
				Country: "RU",
				City:    "SAS",
				Rack:    "1",
			},
			Restrictions:     make([]string, 0),
			FirmwareProblems: make([]string, 0),
			Project:          "rtc",
			Hostname:         "sas1-0001",
			Status:           "rebooting",
		}},
		tasks: []*pb.Task{{
			Meta: &pb.Task_Meta{
				Action:       "reboot",
				CreationTime: now,
				Mtime:        now,
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_NEED_ENFORCE,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		after: []*pb.Task{{
			Meta: &pb.Task_Meta{
				CreationTime: now,
				Mtime:        now,
				Action:       "reboot",
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_ENFORCED,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		makeWalle: func(fake *walle.Fake) {
			fake.On("RebootHost", &walle.RebootHostRequest{
				HostID: "sas1-0001",
				Params: walle.RebootHostParams{
					IgnoreMaintenance: false,
				},
				Body: walle.RebootHostBody{
					IgnoreCms:       true,
					Reason:          "Maxwell: reboot ''",
					WithAutoHealing: true,
				},
			}).Return(make([]byte, 0), nil).Once()
		},
	}, {
		// A GRACEFUL task is left unchanged and no Wall-E call is expected.
		name: "no enforce",
		hosts: []*pb.Host{{
			Health: map[string]string{},
			Location: &pb.Host_Location{
				Country: "RU",
				City:    "SAS",
				Rack:    "1",
			},
			Restrictions:     make([]string, 0),
			FirmwareProblems: make([]string, 0),
			Project:          "rtc",
			Hostname:         "sas1-0001",
			Status:           "rebooting",
		}},
		tasks: []*pb.Task{{
			Meta: &pb.Task_Meta{
				Action:       "reboot",
				CreationTime: now,
				Mtime:        now,
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_GRACEFUL,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		after: []*pb.Task{{
			Meta: &pb.Task_Meta{
				CreationTime: now,
				Mtime:        now,
				Action:       "reboot",
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_GRACEFUL,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		makeWalle: func(fake *walle.Fake) {},
	}}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			hostsStorage := storages.NewHostsInMemory(tt.hosts...)
			jobStorage := &storages.JobInMemory{
				TasksInMemory:     storages.NewTasksInMemory(tt.tasks...),
				ProcessedInMemory: storages.NewProcessedInMemory(),
			}
			j := &pb.Job{
				Spec: &pb.Job_Spec{
					PostCondition: &pb.PostCondition{},
				}}
			ws := workingset.New("test", jobStorage, hostsStorage, j.Spec)
			// Setup failures make every later assertion meaningless, so abort
			// the subtest instead of continuing with a broken working set.
			if err := ws.Restore(); err != nil {
				t.Fatalf("failed to restore working set: %s", err)
			}
			if !assert.Len(t, ws.Records(), len(tt.tasks)) {
				t.FailNow()
			}
			c := &Controller{
				j:          j,
				dryRun:     false,
				ws:         ws,
				jobStorage: jobStorage,
			}
			w := &walle.Fake{}
			tt.makeWalle(w)
			l, err := zap.New(zap.CLIConfig(log.DebugLevel))
			if err != nil {
				t.Fatalf("failed to create logger: %s", err)
			}
			if err := c.enforceTasks(w, l); err != nil {
				t.Error(err)
			}
			assert.Equal(t, tt.after, c.ws.Records().Tasks())
			w.AssertExpectations(t)
		})
	}
}

// TestController_cancelTasks checks Controller.cancelTasks against a fake
// Wall-E client.
//
// Each case seeds in-memory host and task storages, restores a working set,
// registers the Wall-E calls it expects via makeWalle, runs cancelTasks and
// then compares the working set's tasks with after. The fake's expectations
// are asserted at the end, so a case with an empty makeWalle requires that no
// registered Wall-E call was expected.
func TestController_cancelTasks(t *testing.T) {
	now := ptypes.TimestampNow()
	tests := []struct {
		name      string
		hosts     []*pb.Host
		tasks     []*pb.Task
		after     []*pb.Task
		makeWalle func(*walle.Fake)
	}{{
		// A NEED_CANCEL task triggers one CancelTask call and becomes CANCEL_WAITING.
		name: "cancel",
		hosts: []*pb.Host{{
			Health: map[string]string{},
			Location: &pb.Host_Location{
				Country: "RU",
				City:    "SAS",
				Rack:    "1",
			},
			Restrictions:     make([]string, 0),
			FirmwareProblems: make([]string, 0),
			Project:          "rtc",
			Hostname:         "sas1-0001",
			Status:           "rebooting",
		}},
		tasks: []*pb.Task{{
			Meta: &pb.Task_Meta{
				Action:       "reboot",
				CreationTime: now,
				Mtime:        now,
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_NEED_CANCEL,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		after: []*pb.Task{{
			Meta: &pb.Task_Meta{
				CreationTime: now,
				Mtime:        now,
				Action:       "reboot",
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_CANCEL_WAITING,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		makeWalle: func(fake *walle.Fake) {
			fake.On("CancelTask", "sas1-0001").Return(nil).Once()
		},
	}, {
		// A GRACEFUL task is left unchanged and no Wall-E call is expected.
		name: "no cancel",
		hosts: []*pb.Host{{
			Health: map[string]string{},
			Location: &pb.Host_Location{
				Country: "RU",
				City:    "SAS",
				Rack:    "1",
			},
			Restrictions:     make([]string, 0),
			FirmwareProblems: make([]string, 0),
			Project:          "rtc",
			Hostname:         "sas1-0001",
			Status:           "rebooting",
		}},
		tasks: []*pb.Task{{
			Meta: &pb.Task_Meta{
				Action:       "reboot",
				CreationTime: now,
				Mtime:        now,
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_GRACEFUL,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		after: []*pb.Task{{
			Meta: &pb.Task_Meta{
				CreationTime: now,
				Mtime:        now,
				Action:       "reboot",
			},
			State:    pb.Task_RUNNING,
			Enforce:  pb.Task_GRACEFUL,
			Hostname: "sas1-0001",
			Group:    "1",
		}},
		makeWalle: func(fake *walle.Fake) {},
	}}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			hostsStorage := storages.NewHostsInMemory(tt.hosts...)
			jobStorage := &storages.JobInMemory{
				TasksInMemory:     storages.NewTasksInMemory(tt.tasks...),
				ProcessedInMemory: storages.NewProcessedInMemory(),
			}
			j := &pb.Job{
				Spec: &pb.Job_Spec{
					PostCondition: &pb.PostCondition{},
				}}
			ws := workingset.New("test", jobStorage, hostsStorage, j.Spec)
			// Setup failures make every later assertion meaningless, so abort
			// the subtest instead of continuing with a broken working set.
			if err := ws.Restore(); err != nil {
				t.Fatalf("failed to restore working set: %s", err)
			}
			if !assert.Len(t, ws.Records(), len(tt.tasks)) {
				t.FailNow()
			}
			c := &Controller{
				j:          j,
				dryRun:     false,
				ws:         ws,
				jobStorage: jobStorage,
			}
			w := &walle.Fake{}
			tt.makeWalle(w)
			l, err := zap.New(zap.CLIConfig(log.DebugLevel))
			if err != nil {
				t.Fatalf("failed to create logger: %s", err)
			}
			if err := c.cancelTasks(w, l); err != nil {
				t.Error(err)
			}
			assert.Equal(t, tt.after, c.ws.Records().Tasks())
			w.AssertExpectations(t)
		})
	}
}
