package main

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"time"

	"a.yandex-team.ru/infra/azure/lb-conf-porto-helper/internal/config"
	"a.yandex-team.ru/infra/azure/lb-conf-porto-helper/internal/controller"
	"a.yandex-team.ru/infra/azure/lb-conf-porto-helper/pkg/isshelper"
	"a.yandex-team.ru/infra/azure/lb-conf-porto-helper/pkg/nshelper"
	"a.yandex-team.ru/infra/azure/lb-conf-porto-helper/pkg/portohelper"
	"a.yandex-team.ru/library/go/core/log/zap"
	"github.com/Azure/azure-sdk-for-go/profiles/latest/network/mgmt/network"
	"github.com/Azure/go-autorest/autorest"
	"github.com/Azure/go-autorest/autorest/adal"
)

const (
	// containerIfaceName is the interface that is inspected and configured
	// while running inside the container network namespace.
	containerIfaceName = "veth"
	// ipCmd is the absolute path of the ip(8) binary used for all address,
	// route and rule manipulation.
	ipCmd              = "/usr/sbin/ip"
	// defaultConfigPath is the configuration file loaded by main.
	defaultConfigPath  = "/etc/porto-azure/lb-helper.yaml"
)

// logger is the process-wide logger; main assigns it before any mode runs.
var logger *zap.Logger

// maskCGroupForPorto appends the PID of the current process to the
// cgroup.procs file at the root of the unified cgroup hierarchy.
//
// NOTE(review): presumably this moves the helper out of the container's
// cgroup so porto does not account it to the container — confirm.
//
// It returns the error from opening the file or from the write, if any.
func maskCGroupForPorto() error {
	const procsPath = "/sys/fs/cgroup/unified/cgroup.procs"

	procs, err := os.OpenFile(procsPath, os.O_RDWR|os.O_APPEND, os.ModeAppend)
	if err != nil {
		return err
	}
	defer procs.Close()

	pid := strconv.Itoa(os.Getpid())
	if _, err := procs.WriteString(pid); err != nil {
		return err
	}
	return nil
}

// runCmd runs the external command name with the given arguments and waits
// for it to finish.
//
// The child's stdin is /dev/null. Its stdout and stderr are captured in full
// and written to the error log only when the command exits unsuccessfully.
// Log records are emitted through l scoped with the command name.
//
// It returns an error when /dev/null or one of the output pipes cannot be
// opened, when the command cannot be started, or the error reported by
// cmd.Wait (for example a non-zero exit status).
func runCmd(l *zap.Logger, name string, arg ...string) error {
	l = l.WithName(name).(*zap.Logger)
	l.Debugf("Running command %s with args %v", name, arg)

	devnull, err := os.Open(os.DevNull)
	if err != nil {
		return fmt.Errorf("failed to open /dev/null for '%s' stdin: %w", name, err)
	}
	// The child gets its own duplicate of the descriptor on Start, so the
	// parent's copy must be released on every return path.
	defer devnull.Close()

	cmd := exec.Command(name, arg...)
	cmd.Stdin = devnull
	epipe, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("failed to open stderr pipe for '%s': %w", name, err)
	}
	opipe, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("failed to open stdout pipe for '%s': %w", name, err)
	}
	if err := cmd.Start(); err != nil {
		return err
	}

	// Drain stderr concurrently with stdout. Reading one pipe to EOF before
	// touching the other deadlocks as soon as the child fills the pipe that
	// is not being read.
	type pipeData struct {
		data []byte
		err  error
	}
	stderrCh := make(chan pipeData, 1)
	go func() {
		data, readErr := io.ReadAll(epipe)
		stderrCh <- pipeData{data: data, err: readErr}
	}()

	stdout, err := io.ReadAll(opipe)
	if err != nil {
		l.Errorf("Failed to read stdout: %v", err)
	}
	stderr := <-stderrCh
	if stderr.err != nil {
		l.Errorf("Failed to read stderr: %v", stderr.err)
	}

	// Wait must only be called after all reads from the pipes are done.
	err = cmd.Wait()
	if err != nil {
		l.Errorf("Stdout: %s", stdout)
		l.Errorf("Stderr: %s", stderr.data)
	}
	return err
}

// getLLAddr returns the first link-scope IPv6 address configured on iface,
// without its prefix length.
//
// It runs `ip -6 addr show dev <iface> scope link` and looks for the first
// output line whose first field is "inet6"; the second field of that line
// (for example "fe80::1/64") is returned with everything from the "/" on
// removed.
//
// An error is returned when the command fails, when its output cannot be
// scanned, or when no inet6 line is present.
func getLLAddr(iface string) (string, error) {
	cmd := exec.Command(ipCmd, "-6", "addr", "show", "dev", iface, "scope", "link")
	out, err := cmd.CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to get %s ll address %w, output:\n====\n%s\n====", iface, err, out)
	}
	scanner := bufio.NewScanner(bytes.NewReader(out))
	for scanner.Scan() {
		// Splitting into fields first makes blank or short lines harmless;
		// slicing the raw line would panic on them.
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 || fields[0] != "inet6" {
			continue
		}
		addr := fields[1]
		// Strip the prefix length whatever its width ("/64", "/128", ...).
		if slash := strings.IndexByte(addr, '/'); slash >= 0 {
			addr = addr[:slash]
		}
		return addr, nil
	}
	if scanErr := scanner.Err(); scanErr != nil {
		return "", fmt.Errorf("failed to scan ip output for %s: %w", iface, scanErr)
	}
	return "", fmt.Errorf("failed to extract link-local address from ip output:\n====\n%s\n====", out)
}

// containerAddrInfo is the result of containerNSFunc: the addresses the
// host side needs to install a route towards the container.
type containerAddrInfo struct {
	// ipv4Addr is the IPv4 address added to the container interface, in
	// "<address>/32" form.
	ipv4Addr string
	// llAddr is the address getLLAddr returned for the container interface.
	llAddr   string
}

// containerNSFunc configures IPv4 connectivity on containerIfaceName. Both
// callers in this file invoke it through nshelper.RunInNetNSAtPath.
//
// It looks up the link-local address of the interface, adds an IPv4 default
// route via the host link-local address hostLLAddr, and adds ipv4Addr as a
// /32 address on the interface.
//
// On success it returns the interface's link-local address together with
// the "<ipv4Addr>/32" string that was added. Any failing step is returned as
// a wrapped error instead of terminating the process, so the caller decides
// how to handle it.
func containerNSFunc(ipv4Addr string, hostLLAddr string) (containerAddrInfo, error) {
	containerLLAddr, err := getLLAddr(containerIfaceName)
	if err != nil {
		return containerAddrInfo{}, fmt.Errorf("failed to get container ll addr: %w", err)
	}
	logger.Debugf("found container ll addr %s on %s", containerLLAddr, containerIfaceName)

	if err := runCmd(logger, ipCmd, "-4", "route", "add", "default", "via", "inet6", hostLLAddr, "dev", containerIfaceName); err != nil {
		return containerAddrInfo{}, fmt.Errorf("failed to add default route inside container ns: %w", err)
	}

	container4Addr := fmt.Sprintf("%s/32", ipv4Addr)
	if err := runCmd(logger, ipCmd, "-4", "addr", "add", container4Addr, "dev", containerIfaceName); err != nil {
		return containerAddrInfo{}, fmt.Errorf("failed to add address %s to %s inside container ns: %w", container4Addr, containerIfaceName, err)
	}
	return containerAddrInfo{llAddr: containerLLAddr, ipv4Addr: container4Addr}, nil
}

// ensureLBRoutingRules makes sure an IPv4 policy rule exists that sends
// traffic from cfg.LBSubnet to routing table cfg.LBTableID.
//
// It lists the rules for the subnet with `ip -4 rule list from <subnet>`.
// When nothing is listed the rule is added through runCmd. Otherwise the
// output must consist of exactly five whitespace-separated fields, the last
// of which is the expected table id; anything else is reported as an error.
func ensureLBRoutingRules(l *zap.Logger, cfg *config.Config) error {
	out, err := exec.Command(ipCmd, "-4", "rule", "list", "from", cfg.LBSubnet).CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to list rules for %s: %w", cfg.LBSubnet, err)
	}

	rule := strings.TrimSpace(string(out))
	if rule == "" {
		// No rule yet: create it.
		return runCmd(l, ipCmd, "-4", "rule", "add", "from", cfg.LBSubnet, "lookup", fmt.Sprintf("%d", cfg.LBTableID))
	}

	fields := strings.Fields(rule)
	if n := len(fields); n != 5 {
		return fmt.Errorf("unexpected number of fields in ip rule output '%s', expecting exactly 5 fields, got %d", rule, n)
	}
	tableID, err := strconv.ParseInt(fields[4], 10, 32)
	if err != nil {
		return err
	}
	if tableID != int64(cfg.LBTableID) {
		return fmt.Errorf("unexpected table id in rule '%s', expected table %d", rule, cfg.LBTableID)
	}
	return nil
}

// ensureSNATRoutingRules makes sure an IPv4 policy rule exists that sends
// traffic from cfg.SNATSubnet to routing table cfg.SNATTableID.
//
// It lists the rules for the subnet with `ip -4 rule list from <subnet>`.
// When nothing is listed the rule is added through runCmd. Otherwise the
// output must consist of exactly five whitespace-separated fields, the last
// of which is the expected table id; anything else is reported as an error.
func ensureSNATRoutingRules(l *zap.Logger, cfg *config.Config) error {
	cmd := exec.Command(ipCmd, "-4", "rule", "list", "from", cfg.SNATSubnet)
	out, err := cmd.CombinedOutput()
	if err != nil {
		// Report the subnet that was actually queried (the SNAT one).
		return fmt.Errorf("failed to list rules for %s: %w", cfg.SNATSubnet, err)
	}
	line := strings.TrimSpace(string(out))
	if line == "" {
		// No rule yet: create it.
		return runCmd(l, ipCmd, "-4", "rule", "add", "from", cfg.SNATSubnet, "lookup", fmt.Sprintf("%d", cfg.SNATTableID))
	}
	parts := strings.Fields(line)
	if len(parts) != 5 {
		return fmt.Errorf("unexpected number of fields in ip rule output '%s', expecting exactly 5 fields, got %d", line, len(parts))
	}
	table, err := strconv.ParseInt(parts[4], 10, 32)
	if err != nil {
		return err
	}
	if table != int64(cfg.SNATTableID) {
		return fmt.Errorf("unexpected table id in rule '%s', expected table %d", line, cfg.SNATTableID)
	}
	return nil
}

// ensureLBRoutes installs the routes used for load-balanced traffic.
//
// Two `ip -4 route replace` commands are run in order: a /32 route to
// cfg.LBGateway on cfg.LBIface, then a default route via that gateway in
// table cfg.LBTableID. The first failure is returned wrapped with context.
func ensureLBRoutes(l *zap.Logger, cfg *config.Config) error {
	gatewayHost := fmt.Sprintf("%s/32", cfg.LBGateway)
	err := runCmd(l, ipCmd, "-4", "route", "replace", gatewayHost, "dev", cfg.LBIface)
	if err != nil {
		return fmt.Errorf("failed to set up lb gateway route %s via %s: %w", cfg.LBGateway, cfg.LBIface, err)
	}

	table := fmt.Sprintf("%d", cfg.LBTableID)
	err = runCmd(l, ipCmd, "-4", "route", "replace", "default", "via", cfg.LBGateway, "dev", cfg.LBIface, "table", table)
	if err != nil {
		return fmt.Errorf("failed to set up default route for balanced traffic via %s, dev %s, table %d: %w", cfg.LBGateway, cfg.LBIface, cfg.LBTableID, err)
	}
	return nil
}

// ensureSNATRoutes installs the default route for SNAT traffic: an
// `ip -4 route replace default` via cfg.SNATGateway on cfg.SNATInterface in
// table cfg.SNATTableID. A failure is returned wrapped with context.
func ensureSNATRoutes(l *zap.Logger, cfg *config.Config) error {
	table := fmt.Sprintf("%d", cfg.SNATTableID)
	err := runCmd(l, ipCmd, "-4", "route", "replace", "default", "via", cfg.SNATGateway, "dev", cfg.SNATInterface, "table", table)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to set up default route for snat traffic via %s, dev %s, table %d: %w", cfg.SNATGateway, cfg.SNATInterface, cfg.SNATTableID, err)
}

// lbMode configures load-balancer routing for the container described by
// the porto context taken from the environment.
//
// Steps, in order:
//  1. Fetch pod information from ISS (one second timeout) and read the
//     azure_lb_name, azure_lb_rg and awacs_namespace_id labels of every pod
//     whose container name equals the first "/"-separated component of the
//     porto container name.
//  2. Exit with status 0 when neither an LB name nor an awacs namespace is
//     set; fall back to the awacs namespace when only the LB name is missing.
//  3. Load, actualize and update the controller state, apply the mutations
//     through the Azure backend address pools client, and persist the result.
//  4. Configure the container network namespace via containerNSFunc, add
//     the host route to the container, then ensure LB routes and rules.
//
// Every failure is fatal: it is logged with logger.Fatalf.
func lbMode(conf *config.Config) {
	pContext := portohelper.ContextFromEnv()
	issClient := isshelper.NewDefaultClient()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	podsInfo, err := issClient.GetPodsInfo(ctx)
	if err != nil {
		logger.Fatalf("failed to get pods info: %v", err)
	}

	containerName := strings.Split(pContext.ContainerName, "/")[0]
	var lbName, rgName, awacsNamespace string
	for _, pod := range podsInfo.Pods {
		if pod.ContainerName != containerName {
			continue
		}
		logger.Debugf("found container: %s", pod.ContainerName)
		for _, attr := range pod.DynamicAttributes.Labels.Attributes {
			switch attr.Key {
			case "azure_lb_name":
				buf, err := attr.GetValue()
				if err != nil {
					logger.Fatalf("failed to get azure_lb_name pod label value: %v", err)
				}
				// The first two bytes of the raw value are skipped.
				// NOTE(review): presumably an encoding header — confirm.
				lbName = strings.TrimSpace(string(buf[2:]))
				logger.Debugf("found container %s, lb name: %s", containerName, lbName)
			case "azure_lb_rg":
				buf, err := attr.GetValue()
				if err != nil {
					logger.Fatalf("failed to get azure_lb_rg pod label value: %v", err)
				}
				rgName = strings.TrimSpace(string(buf[2:]))
				logger.Debugf("found container %s, rg name: %s", containerName, rgName)
			case "awacs_namespace_id":
				buf, err := attr.GetValue()
				if err != nil {
					logger.Fatalf("failed to get awacs_namespace_id pod label value: %v", err)
				}
				awacsNamespace = strings.TrimSpace(string(buf[2:]))
				logger.Debugf("found container %s, awacs_namespace_id: %s", containerName, awacsNamespace)
			}
		}
	}

	if lbName == "" && awacsNamespace == "" {
		logger.Info("no lbs found")
		os.Exit(0)
	}
	if lbName == "" {
		// Reaching this point with an empty LB name implies the awacs
		// namespace is set; use it as the LB name.
		lbName = awacsNamespace
	}

	state, err := controller.LoadState(conf)
	if err != nil {
		logger.Fatalf("failed to load state: %v", err)
	}
	aState, err := state.ActualizeState()
	if err != nil {
		logger.Fatalf("failed to actualize state: %v", err)
	}
	newState, err := aState.UpdateState(lbName, containerName, rgName)
	if err != nil {
		logger.Fatalf("failed to update state: %v", err)
	}

	tokenProvider, err := adal.NewServicePrincipalTokenFromManagedIdentity("https://management.azure.com/", &adal.ManagedIdentityOptions{})
	if err != nil {
		logger.Fatalf("failed to obtain tokenProvider: %v", err)
	}
	client := network.NewLoadBalancerBackendAddressPoolsClient(conf.SubscriptionID)
	client.Authorizer = autorest.NewBearerAuthorizer(tokenProvider)

	appliedState, err := newState.ApplyMutations(context.Background(), logger, client)
	if err != nil {
		logger.Fatalf("failed to apply new state to azure: %v", err)
	}
	if err := appliedState.PersistState(); err != nil {
		// TODO: revert azure ops in case of failure?
		// maybe use azure deployments???
		logger.Fatalf("failed to persist updated state: %v", err)
	}

	hostLLAddr, err := getLLAddr(pContext.L3Interface)
	if err != nil {
		logger.Fatalf("failed to get host ll addr: %v", err)
	}
	logger.Debugf("found host ll addr %s on %s", hostLLAddr, pContext.L3Interface)

	var addrInfo containerAddrInfo
	setupContainerNS := func() error {
		info, nsErr := containerNSFunc(newState.GetNewSlot().Address, hostLLAddr)
		if nsErr == nil {
			addrInfo = info
		}
		return nsErr
	}
	if err := nshelper.RunInNetNSAtPath(pContext.NetNsFDPath, setupContainerNS); err != nil {
		logger.Fatalf("failed to set up container NS: %v", err)
	}

	if err := runCmd(logger, ipCmd, "-4", "route", "add", addrInfo.ipv4Addr, "via", "inet6", addrInfo.llAddr, "dev", pContext.L3Interface); err != nil {
		logger.Fatalf("failed to add container %s route %s: %v", containerName, newState.GetNewSlot().Address, err)
	}
	logger.Infof("l3 lb configured for %s on %s", pContext.ContainerName, newState.GetNewSlot().Address)

	if err := ensureLBRoutes(logger, conf); err != nil {
		logger.Fatalf("failed to ensure LB routes: %v", err)
	}
	if err := ensureLBRoutingRules(logger, conf); err != nil {
		logger.Fatalf("failed to ensure LB routing rules: %v", err)
	}
}

// snatMode configures SNAT routing for the container described by the
// porto context taken from the environment.
//
// Steps, in order:
//  1. Parse conf.SNATSubnet.
//  2. Fetch pod information from ISS (one second timeout) and find the pod
//     whose container name equals the first "/"-separated component of the
//     porto container name and which carries the label
//     ipv4_snat_enable=true. Exit with status 0 when there is none.
//  3. Take the pod's IPv6 allocation in the "backbone" vlan, parse the
//     second-to-last ":"-separated group of its textual form as a 16-bit
//     hex number, and store it in the last two octets of the subnet base
//     address. The result must lie inside the SNAT subnet.
//  4. Configure the container network namespace via containerNSFunc, add
//     the host route to the container, then ensure SNAT routes and rules.
//
// Every failure is fatal: it is logged with logger.Fatalf.
func snatMode(conf *config.Config) {
	ipv4Base, snatSubnet, err := net.ParseCIDR(conf.SNATSubnet)
	if err != nil {
		logger.Fatalf("failed to parse snat_subnet %s: %v", conf.SNATSubnet, err)
	}
	pContext := portohelper.ContextFromEnv()
	iss := isshelper.NewDefaultClient()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	podsInfo, err := iss.GetPodsInfo(ctx)
	if err != nil {
		logger.Fatalf("failed to get pods info: %v", err)
	}

	// Kept in its own variable so later log messages still name the
	// container after the IPv6 address has been split.
	containerName := strings.Split(pContext.ContainerName, "/")[0]
	var podInfo *isshelper.PodInfo
	for _, pod := range podsInfo.Pods {
		if pod.ContainerName != containerName {
			continue
		}
		logger.Debugf("found container: %s", pod.ContainerName)
		for _, attr := range pod.DynamicAttributes.Labels.Attributes {
			if attr.Key != "ipv4_snat_enable" {
				continue
			}
			buf, err := attr.GetValue()
			if err != nil {
				logger.Fatalf("failed to get ipv4_snat_enable pod label value: %v", err)
			}
			// The first two bytes of the raw value are skipped.
			// NOTE(review): presumably an encoding header — confirm.
			if strings.TrimSpace(string(buf[2:])) == "true" {
				podInfo = pod
				logger.Debugf("found container for snat %s", containerName)
				break
			}
		}
		if podInfo != nil {
			break
		}
	}
	if podInfo == nil {
		logger.Infof("no ipv4_snat_enable=true label found on container %s", containerName)
		os.Exit(0)
	}

	var ip6Addr string
	for _, alloc := range podInfo.IP6AddressAllocations {
		if alloc.VlanID == "backbone" {
			ip6Addr = alloc.Address
			break
		}
	}
	if ip6Addr == "" {
		logger.Fatalf("failed to get ipv6 address of pod %s", podInfo.ID)
	}

	groups := strings.Split(ip6Addr, ":")
	if len(groups) < 2 {
		// Without at least one ":" there is no second-to-last group.
		logger.Fatalf("unexpected ipv6 address format %s of pod %s", ip6Addr, podInfo.ID)
	}
	ipNonce := groups[len(groups)-2]
	ipOctets, err := strconv.ParseUint(ipNonce, 16, 16)
	if err != nil {
		logger.Fatalf("failed to parse ipv6 nonce part %s: %v", ipNonce, err)
	}
	ipv4Addr := ipv4Base.To4()
	if ipv4Addr == nil {
		logger.Fatalf("got non ipv4 snat_subnet %s from config", conf.SNATSubnet)
	}
	// The 16-bit nonce becomes the two low octets of the IPv4 address.
	ipv4Addr[2] = byte((ipOctets & 0xff00) >> 8)
	ipv4Addr[3] = byte(ipOctets & 0xff)
	if !snatSubnet.Contains(ipv4Addr) {
		logger.Fatalf("snat_subnet %s does not contain computed ipv4 address %s", conf.SNATSubnet, ipv4Addr)
	}

	hostLLAddr, err := getLLAddr(pContext.L3Interface)
	if err != nil {
		logger.Fatalf("failed to get host ll addr: %v", err)
	}
	logger.Debugf("found host ll addr %s on %s", hostLLAddr, pContext.L3Interface)

	var addrInfo containerAddrInfo
	if err := nshelper.RunInNetNSAtPath(pContext.NetNsFDPath, func() error {
		info, err := containerNSFunc(ipv4Addr.String(), hostLLAddr)
		if err != nil {
			return err
		}
		addrInfo = info
		return nil
	}); err != nil {
		logger.Fatalf("failed to set up container NS: %v", err)
	}
	if err := runCmd(logger, ipCmd, "-4", "route", "add", addrInfo.ipv4Addr, "via", "inet6", addrInfo.llAddr, "dev", pContext.L3Interface); err != nil {
		logger.Fatalf("failed to add container %s route %s: %v", containerName, addrInfo.ipv4Addr, err)
	}
	logger.Infof("snat routing configured for %s on %s", pContext.ContainerName, ipv4Addr.String())

	if err := ensureSNATRoutes(logger, conf); err != nil {
		logger.Fatalf("failed to ensure snat routes: %v", err)
	}
	if err := ensureSNATRoutingRules(logger, conf); err != nil {
		logger.Fatalf("failed to ensure snat routing rules: %v", err)
	}
}

// main loads the configuration from defaultConfigPath, initializes the
// package-level logger, writes the process PID to the root unified cgroup,
// and then runs the mode selected by conf.Mode: "snat" runs snatMode; "lb"
// and every other value run lbMode.
func main() {
	conf, err := config.FromFile(defaultConfigPath)
	if err != nil {
		log.Fatalf("failed to load default config from %s: %v", defaultConfigPath, err)
	}

	zl, err := zap.New(conf.Log.ZapConfig())
	if err != nil {
		log.Fatalf("failed to initialize logger: %v", err)
	}
	logger = zl

	if err := maskCGroupForPorto(); err != nil {
		logger.Fatalf("failed to mask cgroup for porto: %v", err)
	}

	if conf.Mode == "snat" {
		snatMode(conf)
		return
	}
	// "lb" and any unrecognized mode.
	lbMode(conf)
}
