package worker

import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/gofrs/uuid"

	"a.yandex-team.ru/drive/runner/config"
	"a.yandex-team.ru/security/libs/go/porto"
)

// cachedPortoLayer is the cache record for a porto layer imported from a
// file URL; records are stored in portoProcessor.cachedLayers.
//
// name is the porto layer name and also the key of the matching entry in
// portoProcessor.layers. expires is filled from the loaded file's ExpireTime;
// prepareLayer stops reusing the layer once that moment has passed. loaded is
// closed by loadCachedLayer when the import attempt finishes. err holds the
// outcome of that attempt and must only be read after loaded is closed.
type cachedPortoLayer struct {
	name    string
	expires time.Time
	loaded  chan struct{}
	err     error
}

// portoProcessor creates and tracks porto containers together with the
// layers they are built from.
type portoProcessor struct {
	// conn is the porto connection used for every porto call.
	conn *porto.Connection
	// layers maps a porto layer name to its usage record.
	layers map[string]*portoLayer
	// containers maps a porto container name to the container created by
	// this processor.
	containers map[string]*portoContainer
	// dir is the directory in which container root directories are created.
	dir string
	// files is the file store that layers and task files are read from.
	files *FileStore
	// cachedLayers maps a layer file URL to the layer imported for it.
	cachedLayers map[FileURL]*cachedPortoLayer
	// expiredLayers lists names of layers that dropped out of the cache and
	// are waiting to be removed by Cleanup.
	expiredLayers []string
	// mutex is held by the methods of this file while they read or modify
	// the maps and the expiredLayers slice above.
	mutex sync.Mutex
}

// portoLayer records which containers use a porto layer.
type portoLayer struct {
	// proc is the processor that owns this layer.
	proc *portoProcessor
	// name is the porto layer name.
	name string
	// containers is the set of names of containers the layer is attached to;
	// Cleanup keeps an expired layer while this set is not empty.
	containers map[string]struct{}
}

// portoContainer is the porto-backed implementation of Process.
//
// proc is the owning processor and name is the porto container name. root is
// the directory created for the container volume; it is removed by Destroy.
// layers lists the names of the layers attached to the container. stderrR and
// stdoutR are the read ends of the pipes set up for the container's output,
// stderrW and stdoutW are the matching write ends.
type portoContainer struct {
	proc             *portoProcessor
	name             string
	root             string
	layers           []string
	stderrR, stdoutR io.ReadCloser
	stderrW, stdoutW io.WriteCloser
}

// Start asks porto to start the container and returns porto's error, if any.
func (c *portoContainer) Start() error {
	conn := c.proc.conn
	return conn.Start(c.name)
}

// Stats reads the "cpu_usage" and "memory_usage" porto properties of the
// container and returns them parsed as base-10 integers.
//
// Both properties are fetched before either one is parsed. Any failure yields
// a zero ProcessStats and an error that wraps the underlying cause.
func (c *portoContainer) Stats() (ProcessStats, error) {
	conn := c.proc.conn
	rawCPU, err := conn.GetProperty(c.name, "cpu_usage")
	if err != nil {
		return ProcessStats{}, fmt.Errorf("unable to fetch cpu_usage: %w", err)
	}
	rawMemory, err := conn.GetProperty(c.name, "memory_usage")
	if err != nil {
		return ProcessStats{}, fmt.Errorf("unable to fetch memory_usage: %w", err)
	}
	cpu, err := strconv.ParseInt(rawCPU, 10, 64)
	if err != nil {
		return ProcessStats{}, fmt.Errorf("unable to parse cpu_usage: %w", err)
	}
	memory, err := strconv.ParseInt(rawMemory, 10, 64)
	if err != nil {
		return ProcessStats{}, fmt.Errorf("unable to parse memory_usage: %w", err)
	}
	return ProcessStats{MemoryUsage: memory, CPUUsage: cpu}, nil
}

// Signal forwards the given signal to the container through porto's Kill
// call and returns its error.
func (c *portoContainer) Signal(signal syscall.Signal) error {
	conn := c.proc.conn
	return conn.Kill(c.name, signal)
}

// ExitError reports that a process finished with a non-zero exit code.
type ExitError struct {
	exitCode int
}

// ExitCode returns the exit code carried by the error.
func (e ExitError) ExitCode() int {
	code := e.exitCode
	return code
}

// Error implements the error interface.
func (e ExitError) Error() string {
	return fmt.Sprintf("exit code: %d", e.ExitCode())
}

// Wait waits on the container via porto's Wait call (timeout argument -1)
// and then translates its "exit_code" property into a Go error.
//
// The write ends of the output pipes are closed as soon as the wait call
// returns, whether it succeeded or not. A non-zero exit code is reported as
// *ExitError; a zero exit code yields nil.
func (c *portoContainer) Wait() error {
	conn := c.proc.conn
	_, waitErr := conn.Wait(c.name, -1)
	c.closeWriteStreams()
	if waitErr != nil {
		return waitErr
	}
	raw, err := conn.GetProperty(c.name, "exit_code")
	if err != nil {
		return err
	}
	code, err := strconv.Atoi(raw)
	if err != nil {
		return err
	}
	if code == 0 {
		return nil
	}
	return &ExitError{exitCode: code}
}

// closeWriteStreams closes the write ends of the stderr and stdout pipes and
// clears the fields, so calling it again is a no-op. Close errors are
// ignored.
func (c *portoContainer) closeWriteStreams() {
	if w := c.stderrW; w != nil {
		_ = w.Close()
		c.stderrW = nil
	}
	if w := c.stdoutW; w != nil {
		_ = w.Close()
		c.stdoutW = nil
	}
}

// closeReadStreams closes the read ends of the stderr and stdout pipes and
// clears the fields, so calling it again is a no-op. Close errors are
// ignored.
func (c *portoContainer) closeReadStreams() {
	if r := c.stderrR; r != nil {
		_ = r.Close()
		c.stderrR = nil
	}
	if r := c.stdoutR; r != nil {
		_ = r.Close()
		c.stdoutR = nil
	}
}

// Destroy tears the container down: it destroys the porto container, closes
// all pipe ends, removes the root directory and finally unregisters the
// container from its processor and from every layer it was attached to.
//
// Failures of the porto call and of the directory removal are only logged;
// the method always returns nil.
func (c *portoContainer) Destroy() error {
	proc := c.proc
	if err := proc.conn.Destroy(c.name); err != nil {
		log.Println("Error:", err)
	}
	c.closeWriteStreams()
	c.closeReadStreams()
	if err := os.RemoveAll(c.root); err != nil {
		log.Println("Error:", err)
	}
	proc.mutex.Lock()
	defer proc.mutex.Unlock()
	for _, layerName := range c.layers {
		layer, known := proc.layers[layerName]
		if !known {
			continue
		}
		delete(layer.containers, c.name)
	}
	delete(proc.containers, c.name)
	return nil
}

// Stderr returns the read end of the container's stderr pipe.
func (c *portoContainer) Stderr() io.Reader {
	var reader io.Reader = c.stderrR
	return reader
}

// Stdout returns the read end of the container's stdout pipe.
func (c *portoContainer) Stdout() io.Reader {
	var reader io.Reader = c.stdoutR
	return reader
}

// TaskOptions holds the raw string options of a task; Create passes them to
// parseOptions before a container is configured.
type TaskOptions map[string]string

// ProcessStats is a snapshot of a process's resource usage. For porto
// containers MemoryUsage and CPUUsage are the integer values of the
// "memory_usage" and "cpu_usage" properties.
type ProcessStats struct {
	MemoryUsage int64
	CPUUsage    int64
}

// Process is the handle returned by Create for a prepared task process.
type Process interface {
	// Start launches the process.
	Start() error
	// Stats returns the current resource usage of the process.
	Stats() (ProcessStats, error)
	// Wait waits for the process to finish. The porto implementation
	// reports a non-zero exit code as *ExitError.
	Wait() error
	// Signal sends the given signal to the process.
	Signal(syscall.Signal) error
	// Destroy releases the resources held by the process.
	Destroy() error
	// Stderr returns a reader for the process's standard error.
	Stderr() io.Reader
	// Stdout returns a reader for the process's standard output.
	Stdout() io.Reader
}

// newContainer reserves a unique container name of the form
// "self/dajr-<uuid>" and registers an empty container under it in
// m.containers. The processor mutex is taken for the duration of the call.
//
// It panics if no free name could be generated in 1000 attempts.
func (m *portoProcessor) newContainer() *portoContainer {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	for i := 0; i < 1000; i++ {
		id, err := uuid.NewV4()
		if err != nil {
			continue
		}
		// Check the full name: it is the key the container is stored under
		// below. Looking up the bare UUID would never detect a collision.
		name := fmt.Sprintf("self/dajr-%s", id)
		if _, ok := m.containers[name]; ok {
			continue
		}
		container := &portoContainer{
			proc: m,
			name: name,
		}
		m.containers[name] = container
		return container
	}
	panic("unable to create new container")
}

// newLayerUnlocked reserves a unique layer name of the form "dajr-<uuid>"
// and registers an empty usage record under it in m.layers. It does not
// lock: the caller must already hold m.mutex.
//
// It panics if no free name could be generated in 1000 attempts.
func (m *portoProcessor) newLayerUnlocked() *portoLayer {
	for i := 0; i < 1000; i++ {
		id, err := uuid.NewV4()
		if err != nil {
			continue
		}
		// Check the full name: it is the key the layer is stored under
		// below. Looking up the bare UUID would never detect a collision.
		name := fmt.Sprintf("dajr-%s", id)
		if _, ok := m.layers[name]; ok {
			continue
		}
		layer := &portoLayer{
			proc:       m,
			name:       name,
			containers: map[string]struct{}{},
		}
		m.layers[name] = layer
		return layer
	}
	panic("unable to create new layer")
}

// Create parses opts and prepares a porto container for the task: it creates
// the container, builds its volume from the requested layers, writes the
// task files into the volume, sets the container properties and wires the
// stderr and stdout pipes. The returned Process has not been started.
//
// If any step after the container is reserved fails, the container is
// destroyed and the error of the failed step is returned.
func (m *portoProcessor) Create(opts TaskOptions) (Process, error) {
	// Parse the options before reserving a container: a reserved container
	// is only removed from m.containers by Destroy, so bailing out after
	// newContainer without destroying it would leak the registration.
	options, err := parseOptions(opts)
	if err != nil {
		return nil, err
	}
	container := m.newContainer()
	if err := func() error {
		if err := m.conn.Create(container.name); err != nil {
			return err
		}
		if err := m.prepareVolume(container, options); err != nil {
			return err
		}
		if err := m.prepareFiles(container, options); err != nil {
			return err
		}
		// The table is built after prepareVolume because "root" depends on
		// container.root, which prepareVolume fills in.
		properties := []struct{ key, value string }{
			{"command", options.Command},
			{"root", "//" + container.root},
			{"cwd", options.WorkDir},
			{"env", options.environ()},
			{"user", "root"},
			{"group", "root"},
			{"enable_porto", "isolate"},
		}
		for _, property := range properties {
			if err := m.conn.SetProperty(
				container.name, property.key, property.value,
			); err != nil {
				return err
			}
		}
		if options.MemoryLimit > 0 {
			if err := m.conn.SetProperty(
				container.name, "memory_limit", fmt.Sprint(options.MemoryLimit),
			); err != nil {
				return err
			}
		}
		stderrR, stderrW, err := os.Pipe()
		if err != nil {
			return err
		}
		if err = m.conn.SetProperty(
			container.name, "stderr_path",
			fmt.Sprintf("/dev/fd/%d", stderrW.Fd()),
		); err != nil {
			_ = stderrR.Close()
			_ = stderrW.Close()
			return err
		}
		// Once stored on the container the pipe ends are closed by Destroy.
		container.stderrR = stderrR
		container.stderrW = stderrW
		stdoutR, stdoutW, err := os.Pipe()
		if err != nil {
			return err
		}
		if err = m.conn.SetProperty(
			container.name, "stdout_path",
			fmt.Sprintf("/dev/fd/%d", stdoutW.Fd()),
		); err != nil {
			_ = stdoutR.Close()
			_ = stdoutW.Close()
			return err
		}
		container.stdoutR = stdoutR
		container.stdoutW = stdoutW
		return nil
	}(); err != nil {
		if err := container.Destroy(); err != nil {
			log.Println("Error:", err)
		}
		return nil, err
	}
	return container, nil
}

// Cleanup removes expired layers that are no longer attached to any
// container. Layers that are still in use, or whose removal fails, stay in
// m.expiredLayers and are retried on the next call.
//
// The processor mutex is held for the whole call, including the porto
// RemoveLayer requests.
func (m *portoProcessor) Cleanup() {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	// Filter in place: pending reuses the backing array of m.expiredLayers
	// and never gets ahead of the range index.
	pending := m.expiredLayers[:0]
	for _, name := range m.expiredLayers {
		layer, ok := m.layers[name]
		if !ok {
			log.Printf("Layer %q already removed", name)
			continue
		}
		if len(layer.containers) > 0 {
			pending = append(pending, name)
			continue
		}
		if err := m.conn.RemoveLayer(layer.name); err != nil {
			log.Printf("Unable to remove layer %q: %v", layer.name, err)
			pending = append(pending, name)
			continue
		}
		// The layer is gone from porto; drop its usage record as well so
		// m.layers does not grow without bound.
		delete(m.layers, name)
	}
	m.expiredLayers = pending
}

// prepareVolume attaches every requested layer to the container, waiting for
// each one to finish loading, then creates a volume from those layers in a
// fresh directory under m.dir and links it to the container.
//
// container.root is set as soon as the directory exists. The volume is
// unlinked with an empty container name when the function returns, whether
// or not linking it to the container succeeded.
func (m *portoProcessor) prepareVolume(
	container *portoContainer, options parsedOptions,
) error {
	for _, url := range options.Layers {
		cached, err := m.prepareLayer(container, url)
		if err != nil {
			return err
		}
		// Block until the import attempt finishes; err is valid only then.
		<-cached.loaded
		if loadErr := cached.err; loadErr != nil {
			return loadErr
		}
	}
	root, err := ioutil.TempDir(m.dir, "container")
	if err != nil {
		return err
	}
	container.root = root
	volumeProps := map[string]string{
		"layers": strings.Join(container.layers, ";"),
	}
	if _, err := m.conn.CreateVolume(root, volumeProps); err != nil {
		return err
	}
	defer func() {
		if unlinkErr := m.conn.UnlinkVolume(root, ""); unlinkErr != nil {
			log.Println("Error:", unlinkErr)
		}
	}()
	return m.conn.LinkVolume(root, container.name)
}

// optionSchema is the file URL provider for files whose content is taken
// from a task option value instead of the file store (see prepareFiles).
const optionSchema = "option"

// prepareFiles writes every file requested in options.Files into the
// container's root directory.
//
// Relative target paths are resolved against options.WorkDir. Any target
// path containing ".." is rejected. A file whose provider is optionSchema is
// filled with the value of the named task option; every other file is copied
// from the file store. The first failure aborts the loop and is returned.
func (m *portoProcessor) prepareFiles(
	container *portoContainer, options parsedOptions,
) error {
	for _, url := range options.Files {
		if err := func() error {
			filePath := url.Path
			if !filepath.IsAbs(filePath) {
				filePath = filepath.Join(options.WorkDir, filePath)
			}
			// Must run before the join with container.root: that join cleans
			// the path and could resolve ".." to a location outside the root.
			if strings.Contains(filePath, "..") {
				return fmt.Errorf("illegal file path: %q", filePath)
			}
			filePath = filepath.Join(container.root, filePath)
			if url.FileURL.Provider == optionSchema {
				value, ok := options.Options[url.FileURL.Path]
				if !ok {
					return fmt.Errorf(
						"option %q does not exists", url.FileURL.Path,
					)
				}
				return ioutil.WriteFile(filePath, []byte(value), url.Mode)
			}
			src, err := m.files.ReadFile(url.FileURL)
			if err != nil {
				return err
			}
			defer func() {
				_ = src.Close()
			}()
			dst, err := os.OpenFile(
				filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, url.Mode,
			)
			if err != nil {
				return err
			}
			if _, err := io.Copy(dst, src); err != nil {
				_ = dst.Close()
				return err
			}
			// A failed Close on a written file can mean lost data, so it
			// must be reported instead of being dropped.
			return dst.Close()
		}(); err != nil {
			return err
		}
	}
	return nil
}

// prepareLayer attaches the layer for url to the container and returns its
// cache record. The caller must wait on the record's loaded channel and then
// check its err field before using the layer.
//
// A cached layer is reused while it is still loading, or when it loaded
// successfully and has not expired. An expired layer is queued for Cleanup
// and a failed one is forgotten; in both cases a new layer is registered and
// its import is started in a separate goroutine. The returned error is
// always nil.
func (m *portoProcessor) prepareLayer(
	container *portoContainer, url FileURL,
) (*cachedPortoLayer, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	if cache, ok := m.cachedLayers[url]; ok {
		select {
		case <-cache.loaded:
			if cache.err == nil {
				if cache.expires.After(time.Now()) {
					attachLayer(container, m.layers[cache.name])
					return cache, nil
				}
				m.expiredLayers = append(m.expiredLayers, cache.name)
			} else {
				// A failed import is never queued for Cleanup, so drop its
				// usage record here; otherwise every failed load would leave
				// an entry in m.layers forever. Destroy tolerates the
				// missing record for containers that were attached to it.
				delete(m.layers, cache.name)
			}
			delete(m.cachedLayers, url)
		default:
			// Still loading: share the in-flight import.
			attachLayer(container, m.layers[cache.name])
			return cache, nil
		}
	}
	layer := m.newLayerUnlocked()
	attachLayer(container, layer)
	cache := &cachedPortoLayer{
		name:   layer.name,
		loaded: make(chan struct{}),
	}
	m.cachedLayers[url] = cache
	go m.loadCachedLayer(cache, url)
	return cache, nil
}

// loadCachedLayer loads the file behind url from the file store and imports
// it into porto as the layer named cache.name.
//
// The outcome is stored in cache.err and cache.expires is set from the
// loaded file; cache.loaded is closed afterwards so that waiters can read
// both fields.
func (m *portoProcessor) loadCachedLayer(
	cache *cachedPortoLayer, url FileURL,
) {
	defer close(cache.loaded)
	importLayer := func() error {
		tarball, err := m.files.LoadFile(url)
		if err != nil {
			return err
		}
		defer tarball.Close()
		cache.expires = tarball.ExpireTime()
		opts := porto.ImportLayerOpts{
			Layer:   cache.name,
			Tarball: tarball.Name(),
		}
		return m.conn.ImportLayer(opts)
	}
	cache.err = importLayer()
}

// cleanupPorto destroys every container and then removes every layer that
// matches the "dajr-*" mask. Containers go first, before any layer is
// removed.
//
// The cleanup is best-effort: failures are logged and never abort it.
//
// NOTE(review): containers are created as "self/dajr-<uuid>"; presumably the
// mask is resolved relative to the current porto namespace — confirm.
func cleanupPorto(api *porto.API) {
	conn := api.Connection()
	if names, err := conn.ListByMask("dajr-*"); err != nil {
		log.Println("Error:", err)
	} else {
		for _, name := range names {
			if err := conn.Destroy(name); err != nil {
				log.Println("Error:", err)
			}
		}
	}
	if layers, err := conn.ListLayers2("", "dajr-*"); err != nil {
		log.Println("Error:", err)
	} else {
		for _, layer := range layers {
			if err := conn.RemoveLayer(layer.Name); err != nil {
				log.Println("Error:", err)
			}
		}
	}
}

// newPortoProcessor connects to porto, removes containers and layers
// matching the "dajr-*" mask, recreates the "roots" directory under
// cfg.SystemDir and returns a processor with empty registries.
//
// A non-positive cfg.Worker.PortoMaxConnections is replaced by 10.
func newPortoProcessor(
	cfg *config.Config, files *FileStore,
) (*portoProcessor, error) {
	connLimit := cfg.Worker.PortoMaxConnections
	if connLimit <= 0 {
		connLimit = 10
	}
	api, err := porto.NewAPI(&porto.APIOpts{MaxConnections: connLimit})
	if err != nil {
		return nil, err
	}
	cleanupPorto(api)
	rootsDir := filepath.Join(cfg.SystemDir, "roots")
	// Start from an empty directory; a failed removal is ignored and the
	// directory is (re)created right after.
	_ = os.RemoveAll(rootsDir)
	if mkErr := os.MkdirAll(rootsDir, os.ModePerm); mkErr != nil {
		return nil, mkErr
	}
	proc := &portoProcessor{
		conn:         api.Connection(),
		layers:       map[string]*portoLayer{},
		containers:   map[string]*portoContainer{},
		cachedLayers: map[FileURL]*cachedPortoLayer{},
		dir:          rootsDir,
		files:        files,
	}
	return proc, nil
}

// attachLayer records that the container uses the layer: the container name
// is added to the layer's usage set and the layer name is appended to the
// container's layer list.
func attachLayer(container *portoContainer, layer *portoLayer) {
	var inUse struct{}
	layer.containers[container.name] = inUse
	names := append(container.layers, layer.name)
	container.layers = names
}
