package ussh

import (
	"a.yandex-team.ru/solomon/libs/go/color"
	"a.yandex-team.ru/solomon/libs/go/hosts"
	"a.yandex-team.ru/solomon/libs/go/utils"
	"fmt"
	"log"
	"math/rand"
	"os"
	"path"
	"sync/atomic"
	"time"
)

type (
	// ClusterClient runs a command on a fixed list of hosts through a shared
	// Client and appends each host's result to a file inside logsDir.
	ClusterClient struct {
		addresses []hosts.Address
		client    Client
		logsDir   string
	}

	// CmdResult is the outcome of running a command on one host: the command
	// output on success, or the error that made the attempt fail.
	CmdResult struct {
		hostname string
		output   []byte
		err      error
	}

	// ClientRunResult summarizes a run over the whole cluster.
	ClientRunResult struct {
		// FailedCount is the number of hosts on which the command failed.
		FailedCount int
	}

	// Call is a per-host hook that receives the host name; a non-nil error
	// reports that the hook failed.
	Call func(string) error
)

// NewClusterClient builds a ClusterClient for the given addresses.
//
// useBastion is forwarded to NewClient. When logsDir is non-empty the
// directory (including missing parents) is created first; if that fails the
// process is terminated via log.Fatalf. An empty logsDir skips the mkdir.
func NewClusterClient(addresses []hosts.Address, useBastion bool, logsDir string) *ClusterClient {
	if dir := logsDir; dir != "" {
		if mkErr := os.MkdirAll(dir, os.ModePerm); mkErr != nil {
			log.Fatalf("cannot create logs dir %s: %v", dir, mkErr)
		}
	}

	cc := new(ClusterClient)
	cc.addresses = addresses
	cc.client = NewClient(useBastion)
	cc.logsDir = logsDir
	return cc
}

// runRemoteCmdWithRetries runs cmd on address through client, making up to
// maxRetries attempts and stopping at the first one that succeeds.
//
// Before every attempt preCall (if non-nil) is invoked with the host name; if
// it returns an error the attempt fails with that error and cmd is not run.
// After a failed attempt that will be retried, the function sleeps for
// nextDelay(attempt, maxRetries); there is no sleep after the final attempt.
// When logRetries is true, a failed command run that will be retried is
// logged before backing off.
//
// postCall (if non-nil) is invoked once after all attempts, whatever the
// outcome; its error is only logged. The returned result is the result of the
// last attempt.
func runRemoteCmdWithRetries(client Client, address *hosts.Address, cmd string, logRetries bool, preCall Call, postCall Call) *CmdResult {
	const maxRetries = 5
	var result *CmdResult

	for attempt := 0; attempt < maxRetries; attempt++ {
		lastAttempt := attempt == maxRetries-1

		var preErr error
		if preCall != nil {
			preErr = preCall(address.Name)
		}

		if preErr != nil {
			result = &CmdResult{hostname: address.Name, err: preErr}
		} else {
			result = runRemoteCmd(client, address, cmd)
			if result.err == nil {
				break
			}
			// only announce a retry when one will actually happen
			if logRetries && !lastAttempt {
				log.Printf("%s: %s", color.Red("Failed, retry"), result.err)
			}
		}

		if !lastAttempt {
			// backoff timeout before the next attempt
			time.Sleep(nextDelay(attempt, maxRetries))
		}
	}

	if postCall != nil {
		if err := postCall(address.Name); err != nil {
			log.Printf("%s: FAILED to execute post action %s, do it manually", address.Name, err.Error())
			log.Printf("see more details in logs/%s.log", address.Name)
		}
	}
	return result
}

// runRemoteCmd makes one attempt to run cmd on address. It connects through
// client, runs the command, and the connection is closed when the function
// returns.
//
// The result always carries address.Name as hostname. On success it holds the
// command output. If connecting or running fails it holds only the error;
// output returned alongside a run error is not kept.
func runRemoteCmd(client Client, address *hosts.Address, cmd string) *CmdResult {
	res := &CmdResult{hostname: address.Name}

	conn, connectErr := client.Connect(address)
	if connectErr != nil {
		res.err = connectErr
		return res
	}
	defer conn.Close()

	if out, runErr := conn.Run(cmd); runErr != nil {
		res.err = runErr
	} else {
		res.output = out
	}
	return res
}

// logCmdResult appends a record of running cmd on result.hostname to the file
// <logsDir>/<hostname>.log, creating it with mode 0644 if needed.
//
// A record is a "Run: <cmd>" line followed by the command output. If there is
// no output and the run failed, the output is replaced by a "Failed: <error>"
// line. Output that does not end in a newline gets one, so the next record
// starts on its own line.
//
// Logging is best-effort: open, write and close errors are reported with
// log.Printf and otherwise ignored.
func (c *ClusterClient) logCmdResult(cmd string, result *CmdResult) {
	filename := path.Join(c.logsDir, result.hostname+".log")
	f, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.Printf("cannot open log file %s: %v", filename, err)
		return
	}

	record := []byte("Run: " + cmd + "\n")
	switch {
	case result.output != nil:
		record = append(record, result.output...)
		if n := len(result.output); n > 0 && result.output[n-1] != '\n' {
			// keep the next "Run:" record on its own line
			record = append(record, '\n')
		}
	case result.err != nil:
		record = append(record, "Failed: "+result.err.Error()+"\n"...)
	}

	_, err = f.Write(record)
	// For a file opened for writing, a Close error can mean the data did not
	// make it to disk, so report it instead of discarding it.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		log.Printf("cannot write log to %s: %v", filename, err)
	}
}

// RunSequentially runs cmd on every host of the cluster one after another,
// pausing for delay between hosts, with no pre/post hooks. It is
// RunSequentiallyWithWrapCall with nil hooks.
func (c *ClusterClient) RunSequentially(cmd string, delay time.Duration) *ClientRunResult {
	var noHook Call // nil: no pre/post action
	return c.RunSequentiallyWithWrapCall(cmd, delay, noHook, noHook)
}

// RunSequentiallyWithWrapCall runs cmd on the cluster hosts one by one, in
// address order, and reports how many hosts failed.
//
// Each host is handled by runRemoteCmdWithRetries with retry logging enabled
// and the optional preCall/postCall hooks. The host's result is appended to
// its log file, and an OK or FAILED line is printed. Between hosts the call
// waits for delay via utils.Sleep, but not after the last host.
func (c *ClusterClient) RunSequentiallyWithWrapCall(cmd string, delay time.Duration, preCall Call, postCall Call) *ClientRunResult {
	count := len(c.addresses)
	failed := 0

	for i, addr := range c.addresses {
		prefix := fmt.Sprintf("[%d/%d] %s", i+1, count, color.Purple(addr.Name))

		log.Printf("%s: RUN %s", prefix, color.Yellow(cmd))
		result := runRemoteCmdWithRetries(c.client, &addr, cmd, true, preCall, postCall)
		c.logCmdResult(cmd, result)

		if result.err != nil {
			log.Printf("%s: FAILED %s", prefix, color.Red(result.err.Error()))
			// Point at the file logCmdResult actually writes (it lives in
			// c.logsDir, which is not necessarily "logs").
			log.Printf("see more details in %s", path.Join(c.logsDir, addr.Name+".log"))
			failed++
		} else {
			log.Printf("%s: %s", prefix, color.Green("OK"))
		}

		if i != count-1 {
			// do not sleep after running command on the last address
			utils.Sleep(delay)
		}
	}
	return &ClientRunResult{FailedCount: failed}
}

// RunParallel runs cmd on the cluster hosts using up to parallelism
// concurrent workers, with no pre/post hooks and no extra delay between
// hosts. It is RunParallelWithWrapCall with nil hooks and a zero delay.
func (c *ClusterClient) RunParallel(parallelism int, cmd string) *ClientRunResult {
	var noHook Call // nil: no pre/post action
	return c.RunParallelWithWrapCall(parallelism, cmd, noHook, noHook, 0)
}

// RunParallelWithWrapCall runs cmd on every cluster host using up to
// parallelism concurrent workers and reports how many hosts failed.
//
// For each host a worker does the following:
//   - sleeps a random 0-999ms, so workers don't start in lockstep;
//   - runs runRemoteCmdWithRetries with the optional preCall/postCall hooks;
//   - waits for delay before reporting the result and taking the next host.
//
// Results are appended to the per-host log files. A progress line on stderr is
// redrawn at least once per second, and its final state is printed before
// returning.
//
// A single-host cluster is delegated to RunSequentiallyWithWrapCall (keeping
// the hooks). A non-positive parallelism is treated as 1, because with no
// workers the result loop would wait forever. Parallelism is also capped at
// the number of hosts.
func (c *ClusterClient) RunParallelWithWrapCall(parallelism int, cmd string, preCall Call, postCall Call, delay time.Duration) *ClientRunResult {
	count := len(c.addresses)
	switch count {
	case 0:
		return &ClientRunResult{}
	case 1:
		// no need to run in parallel mode, but the wrap calls must still run
		return c.RunSequentiallyWithWrapCall(cmd, time.Duration(0), preCall, postCall)
	}

	if parallelism < 1 {
		parallelism = 1
	}
	if parallelism > count {
		parallelism = count
	}

	// Both channels are buffered to the host count: all addresses are queued
	// up front and workers never block when publishing a result.
	addresses := make(chan hosts.Address, count)
	for _, address := range c.addresses {
		addresses <- address
	}
	close(addresses)

	results := make(chan *CmdResult, count)
	var active int32
	for i := 0; i < parallelism; i++ {
		atomic.AddInt32(&active, 1)
		go func() {
			defer atomic.AddInt32(&active, -1)
			for address := range addresses {
				time.Sleep(time.Duration(rand.Int31n(1000)) * time.Millisecond)
				res := runRemoteCmdWithRetries(c.client, &address, cmd, false, preCall, postCall)
				utils.SleepWithoutPrint(delay)
				results <- res
			}
		}()
	}

	// "\033[100D" moves the cursor back to the start of the line so each update
	// overwrites the previous one. The trailing "\033[K" erases what is left of
	// a longer earlier line (e.g. when "10 Pending" turns into "9 Pending").
	format := "\033[100D -- " +
		"%d Active" + " / " +
		color.Cyan("%d Pending") + " / " +
		color.Red("%d Failed") + " / " +
		color.Green("%d Done") + " --\033[K"

	failed, done := 0, 0
	printProgress := func() {
		fmt.Fprintf(os.Stderr, format, atomic.LoadInt32(&active), count-done, failed, done)
	}

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	for done < count {
		printProgress()

		select {
		case result := <-results:
			c.logCmdResult(cmd, result)
			done++
			if result.err != nil {
				failed++
			}
		case <-ticker.C:
		}
	}
	// The loop exits right after consuming the last result, so draw once more
	// to show the final counts.
	printProgress()
	fmt.Fprintln(os.Stderr)

	return &ClientRunResult{FailedCount: failed}
}

// Close closes the underlying Client shared by all commands run through this
// ClusterClient. Any error or result from Client.Close is not surfaced here.
func (c *ClusterClient) Close() {
	c.client.Close()
}
