package main

import (
	"bytes"
	"crypto/md5"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"net"
	"os"
	"runtime"
	"strconv"
	"strings"
	"syscall"

	"github.com/vishvananda/netlink"

	"github.com/containernetworking/cni/pkg/skel"
	"github.com/containernetworking/cni/pkg/types"
	current "github.com/containernetworking/cni/pkg/types/100"
	"github.com/containernetworking/cni/pkg/version"

	"github.com/containernetworking/plugins/pkg/ip"
	"github.com/containernetworking/plugins/pkg/ns"
	bv "github.com/containernetworking/plugins/pkg/utils/buildversion"
)

// =============================================================================================================

// Package-level defaults shared by ADD, CHECK and DEL.
var (
	defaultMtnIface = "vlan688"              // used when the config has no "mtn_iface"
	ip6Mask128      = net.CIDRMask(128, 128) // /128: a single IPv6 host address
	ip6Mask000      = net.CIDRMask(0, 128)   // /0: matches every IPv6 address (default route)
)

// NetConf is the plugin configuration. Fields with a JSON tag other than "-"
// are decoded from the CNI config on stdin by loadConf; the `json:"-"` fields
// are never read from JSON and are filled at runtime by loadConf/loadConfLocal.
type NetConf struct {
	// Standard CNI configuration fields (e.g. CNIVersion and DNS, used by cmdAdd).
	types.NetConf
	MTU       int              `json:"mtu"`        // veth MTU; 0 means "copy the MTU of MtnIface"
	PrjIdStr  string           `json:"project_id"` // project id as hex, optional "0x" prefix
	ProjectId []byte           `json:"-"`          // 4-byte id from PrjIdStr or from the IPv6 default route source
	Dad       bool             `json:"dad"`        // when false, DAD is disabled inside the container
	MtnIface  string           `json:"mtn_iface"`  // host uplink iface; defaultMtnIface when empty
	MtnIPv6   net.IP           `json:"-"`          // global IPv6 on MtnIface with zero project id; container gateway
	ContIface string           `json:"-"`          // interface name inside the container (args.IfName)
	ContIPv6  net.IP           `json:"-"`          // address assigned to the container
	ContMask  net.IPMask       `json:"-"`          // mask for ContIPv6 (set to /128 by loadConfLocal)
	HostIface string           `json:"-"`          // host-side veth name chosen by makeHostIfaceName
	MAC       net.HardwareAddr `json:"-"`          // container MAC derived from the container ID
}

// init locks the goroutine running package initialisation to its current OS
// thread for the lifetime of the process. Per the runtime.LockOSThread
// documentation, locking from an init function makes main run on that thread.
//
// Namespace operations (unshare, setns) affect only the calling OS thread, so
// the goroutine performing them must not migrate between threads.
func init() {
	// this ensures that main runs only on main thread (thread group leader).
	// since namespace ops (unshare, setns) are done for a single thread, we
	// must ensure that the goroutine does not jump from OS thread to thread
	runtime.LockOSThread()
}

// =============================================================================================================

// ipv6FromIPv6AndMAC returns a 16-byte copy of ipx whose low bytes are taken
// from mac. ipx itself is not modified; mac must have at least 6 bytes.
//
// If bytes 11-12 of ipx are 0xff 0xfe (the EUI-64 marker), bytes 8-10 become
// mac[0..2] with bit 0x02 of mac[0] flipped and bytes 13-15 become mac[3..5],
// i.e. a modified EUI-64 interface identifier. Otherwise bytes 8-11 are kept
// (setProjectIdToIPv6 stores the project id there) and bytes 12-15 become
// mac[2..5].
func ipv6FromIPv6AndMAC(ipx net.IP, mac net.HardwareAddr) net.IP {
	out := make(net.IP, net.IPv6len)
	copy(out, ipx)

	if hasEUI64Marker := out[11] == 0xff && out[12] == 0xfe; hasEUI64Marker {
		out[8], out[9], out[10] = mac[0]^0x02, mac[1], mac[2]
	} else {
		out[12] = mac[2]
	}
	out[13], out[14], out[15] = mac[3], mac[4], mac[5]
	return out
}

// stringToMAC derives a deterministic unicast MAC address from str.
//
// MAC byte k combines the high nibble of MD5(str)[2k] (as its high nibble) with
// the high nibble of MD5(str)[2k+1] (as its low nibble). The first byte then
// has the locally-administered bit (0x02) set and the multicast bit (0x01)
// cleared.
func stringToMAC(str string) net.HardwareAddr {
	digest := md5.Sum([]byte(str))
	mac := make(net.HardwareAddr, 6)
	for k := range mac {
		hi := digest[2*k] & 0xf0
		lo := digest[2*k+1] >> 4
		mac[k] = hi | lo
	}
	// https://en.wikipedia.org/wiki/MAC_address#Unicast_vs._multicast_(I/G_bit)
	// https://en.wikipedia.org/wiki/MAC_address#Universal_vs._local_(U/L_bit)
	mac[0] |= 0x02
	mac[0] &^= 0x01
	return mac
}

// setProjectIdToIPv6 returns a 16-byte copy of ipx with bytes 8-11 replaced by
// the first four bytes of projectId. ipx is not modified. projectId must hold
// at least four bytes; a shorter slice causes an index-out-of-range panic.
func setProjectIdToIPv6(ipx net.IP, projectId []byte) net.IP {
	out := make(net.IP, net.IPv6len)
	copy(out, ipx)
	_ = projectId[3] // same length requirement as indexing bytes 0..3
	copy(out[8:12], projectId)
	return out
}

// isZeroProjectId reports whether id is exactly four zero bytes. A nil, empty
// or differently-sized slice is not considered zero.
func isZeroProjectId(id []byte) bool {
	if len(id) != 4 {
		return false
	}
	return bytes.Equal(id, make([]byte, 4))
}

// getProjectIdFromIPv6 returns a fresh 4-byte copy of bytes 8-11 of ip, where
// this plugin stores the project id.
//
// ip is normalised with To16, so both 4- and 16-byte representations are
// accepted. A nil or malformed ip (for example the nil source address of a
// default route without a preferred source) yields the zero project id
// instead of panicking on the slice expression.
func getProjectIdFromIPv6(ip net.IP) []byte {
	pi := make([]byte, 4)
	if ip16 := ip.To16(); ip16 != nil {
		copy(pi, ip16[8:12])
	}
	return pi
}

// getProjectIdFromStr parses a hexadecimal project id, optionally prefixed with
// "0x" or "0X", into its 4-byte big-endian form.
//
// Only the last 8 hex digits are used, so longer strings are truncated from
// the left. If the string is not valid hex, the zero project id is returned;
// loadConfLocal rejects a zero id as "bad project id".
func getProjectIdFromStr(id string) []byte {
	const maxDigits = 8 // 4 bytes == 8 hex digits

	// strconv.ParseUint with base 16 rejects any prefix, so strip it here.
	if strings.HasPrefix(id, "0x") || strings.HasPrefix(id, "0X") {
		id = id[2:]
	}
	if len(id) > maxDigits {
		id = id[len(id)-maxDigits:]
	}

	pi := make([]byte, 4)
	if u, err := strconv.ParseUint(id, 16, 32); err == nil {
		binary.BigEndian.PutUint32(pi, uint32(u))
	}
	return pi
}

// makeHostIfaceName picks a host-side interface name of the form "L3-XXXXXXXX"
// that is not a key of occupiedIdSet.
//
// The 8-character suffix is seeded from the first eight ASCII letters/digits
// of id (letters lowercased) and padded with '0'. If that name is taken, the
// suffix positions are walked from last to first; at each position every
// character of "0-9a-z" is tried in order, and the position is left at 'z'
// when moving on. At most 1 + 8*36 names are probed (not every combination);
// if all of them are taken an error is returned.
func makeHostIfaceName(id string, occupiedIdSet map[string]struct{}) (string, error) {
	const (
		suffixLen = 8
		alphabet  = "0123456789abcdefghijklmnopqrstuvwxyz"
	)

	suffix := []byte("00000000")
	filled := 0
	for _, r := range id {
		if filled == suffixLen {
			break
		}
		switch {
		case r >= 'A' && r <= 'Z':
			r += 'a' - 'A'
		case r >= '0' && r <= '9', r >= 'a' && r <= 'z':
			// already acceptable as-is
		default:
			continue
		}
		suffix[filled] = byte(r)
		filled++
	}

	// tryName returns the candidate for the current suffix and whether it is free.
	tryName := func() (string, bool) {
		name := "L3-" + string(suffix)
		_, taken := occupiedIdSet[name]
		return name, !taken
	}

	if name, ok := tryName(); ok {
		return name, nil
	}
	for pos := suffixLen - 1; pos >= 0; pos-- {
		for i := 0; i < len(alphabet); i++ {
			suffix[pos] = alphabet[i]
			if name, ok := tryName(); ok {
				return name, nil
			}
		}
	}
	return "", fmt.Errorf("failed to select host iface name for %s", id)
}

// getPodNameOrId returns the value of the first "K8S_POD_NAME=<name>" entry in
// the semicolon-separated args.Args string, or args.ContainerID when there is
// no such entry. An entry with an empty value yields "".
func getPodNameOrId(args *skel.CmdArgs) string {
	const podNameKey = "K8S_POD_NAME="
	for _, kv := range strings.Split(args.Args, ";") {
		if strings.HasPrefix(kv, podNameKey) {
			return kv[len(podNameKey):]
		}
	}
	return args.ContainerID
}

// loadConf decodes the network configuration from args.StdinData and fills in
// the fields that can be derived without inspecting the host:
//   - ContIface from args.IfName,
//   - MAC deterministically from args.ContainerID (via stringToMAC),
//   - MtnIface, defaulting to defaultMtnIface,
//   - ProjectId, parsed from "project_id" when it is present.
//
// Values that require querying host interfaces and routes are filled by
// loadConfLocal.
func loadConf(args *skel.CmdArgs) (*NetConf, error) {
	var conf NetConf
	if err := json.Unmarshal(args.StdinData, &conf); err != nil {
		return nil, fmt.Errorf("failed to parse network config: %v", err)
	}

	conf.ContIface = args.IfName
	conf.MAC = stringToMAC(args.ContainerID)
	if len(conf.MtnIface) == 0 {
		conf.MtnIface = defaultMtnIface
	}
	if len(conf.PrjIdStr) > 0 {
		conf.ProjectId = getProjectIdFromStr(conf.PrjIdStr)
	}
	return &conf, nil
}

// loadConfLocal completes conf with values that must be discovered on the host:
//   - ProjectId: when not configured, taken from bytes 8-11 of the preferred
//     source address of the first IPv6 default route that has one. A missing,
//     malformed or all-zero project id is an error.
//   - MTU: when 0, copied from MtnIface.
//   - HostIface: a host veth name derived from id that does not clash with any
//     existing host link (see makeHostIfaceName).
//   - ContMask: always /128.
//   - MtnIPv6 / ContIPv6: MtnIPv6 is the first IPv6 global-unicast address on
//     MtnIface whose project-id bytes are zero; ContIPv6 is that address with
//     the project id and the container MAC folded in.
//
// id is the pod name or container ID used to seed the host interface name.
func loadConfLocal(conf *NetConf, id string) error {
	// ProjectId
	if conf.ProjectId == nil {
		routes, err := netlink.RouteList(nil, netlink.FAMILY_V6)
		if err != nil {
			return fmt.Errorf("failed to list IPv6 routes: %w", err)
		}
		for _, r := range routes {
			// Default routes are reported with a nil Dst by some netlink versions
			// and as ::/0 by others; accept both.
			isDefault := r.Dst == nil
			if !isDefault {
				ones, _ := r.Dst.Mask.Size()
				isDefault = ones == 0 && r.Dst.IP.IsUnspecified()
			}
			// A default route without a preferred source carries no project id;
			// skip it rather than slicing a nil address.
			if isDefault && len(r.Src) == net.IPv6len {
				conf.ProjectId = getProjectIdFromIPv6(r.Src)
				break
			}
		}
	}
	if conf.ProjectId == nil {
		return fmt.Errorf("no project id configured and no IPv6 default route with a source address found")
	}
	// The zero id identifies the MTN address itself (see the address selection
	// below), so it can never be a container's project id.
	if len(conf.ProjectId) != 4 || isZeroProjectId(conf.ProjectId) {
		return fmt.Errorf("bad project id: %v", conf.ProjectId)
	}

	// MTU
	if conf.MTU == 0 {
		link, err := netlink.LinkByName(conf.MtnIface)
		if err != nil {
			return fmt.Errorf("failed to lookup %q: %w", conf.MtnIface, err)
		}
		conf.MTU = link.Attrs().MTU
	}

	// HostIface: must not collide with any link that already exists on the host.
	links, err := netlink.LinkList()
	if err != nil {
		return fmt.Errorf("failed to list host links: %w", err)
	}
	linkNamesSet := make(map[string]struct{}, len(links))
	for _, link := range links {
		linkNamesSet[link.Attrs().Name] = struct{}{}
	}
	if conf.HostIface, err = makeHostIfaceName(id, linkNamesSet); err != nil {
		return err
	}

	// ContMask
	conf.ContMask = ip6Mask128

	// MtnIPv6 ContIPv6
	mtnIface, err := net.InterfaceByName(conf.MtnIface)
	if err != nil {
		return fmt.Errorf("failed to get mtn iface: %w", err)
	}
	mtnAddrs, err := mtnIface.Addrs()
	if err != nil {
		return fmt.Errorf("failed to get addresses on mtn iface: %w", err)
	}
	for _, mtnAddr := range mtnAddrs {
		mtnIp, _, err := net.ParseCIDR(mtnAddr.String())
		if err != nil {
			return fmt.Errorf("bad address %s on mtn iface: %w", mtnAddr.String(), err)
		}
		// Only IPv6 global-unicast addresses with a zero project id qualify.
		// If all connections from the host to the container must originate
		// from an IP carrying some project id, that IP/96 must be added to the
		// MTN iface; otherwise they originate from this default MTN IP (id 0).
		if mtnIp.To4() != nil || !mtnIp.IsGlobalUnicast() || !isZeroProjectId(getProjectIdFromIPv6(mtnIp)) {
			continue
		}
		conf.MtnIPv6 = mtnIp
		conf.ContIPv6 = ipv6FromIPv6AndMAC(setProjectIdToIPv6(mtnIp, conf.ProjectId), conf.MAC)
		return nil
	}
	return fmt.Errorf("no global unicast ip without project id found on mtn iface")
}

// setupVeth creates and configures the veth pair between the container and the
// host, recording the interfaces, the container address and the default route
// in result.
//
// Inside the container netns it:
//   - disables DAD (all/default accept_dad = 0) unless conf.Dad is set,
//   - creates the pair (container side conf.ContIface with MAC conf.MAC, host
//     side conf.HostIface, MTU conf.MTU),
//   - brings the container side up and assigns ContIPv6/ContMask (with
//     IFA_F_NODAD unless conf.Dad),
//   - adds MtnIPv6 as a permanent neighbour resolving to the host veth MAC, a
//     /128 on-link route to it, and a default route via it.
//
// Back in the host netns it:
//   - adds a route to the container address via the host veth,
//   - adds the container address as a permanent neighbour on the host veth,
//   - adds a proxy neighbour entry for the container address on MtnIface.
//
// The first failure aborts the setup; state already created is not rolled
// back here.
func setupVeth(conf *NetConf, netns ns.NetNS, result *current.Result) error {
	var contAddr *net.IPNet
	var hostIface *current.Interface

	err := netns.Do(func(hostNS ns.NetNS) error {
		if !conf.Dad {
			for _, sysFile := range []string{
				"/proc/sys/net/ipv6/conf/all/accept_dad",
				"/proc/sys/net/ipv6/conf/default/accept_dad",
			} {
				if err := os.WriteFile(sysFile, []byte{'0'}, 0o644); err != nil {
					return fmt.Errorf("failed to disable DAD via %s: %w", sysFile, err)
				}
			}
		}
		hostVeth, contVeth, err := ip.SetupVethWithName(
			conf.ContIface,
			conf.HostIface,
			conf.MTU,
			conf.MAC.String(),
			hostNS,
		)
		if err != nil {
			return err
		}
		hostIface = &current.Interface{
			Name: hostVeth.Name,
			Mac:  hostVeth.HardwareAddr.String(),
		}
		contIface := &current.Interface{
			Name:    contVeth.Name,
			Mac:     contVeth.HardwareAddr.String(),
			Sandbox: netns.Path(),
		}
		result.Interfaces = []*current.Interface{hostIface, contIface}

		contAddr = &net.IPNet{
			IP:   conf.ContIPv6,
			Mask: conf.ContMask,
		}
		contIP := &current.IPConfig{
			Interface: current.Int(1), // index of contIface in result.Interfaces
			Address:   *contAddr,
		}
		result.IPs = []*current.IPConfig{contIP}
		result.Routes = []*types.Route{}

		contLink, err := netlink.LinkByName(conf.ContIface)
		if err != nil {
			return fmt.Errorf("failed to lookup %q: %w", conf.ContIface, err)
		}

		// up container interface
		if err := netlink.LinkSetUp(contLink); err != nil {
			return fmt.Errorf("failed to set %q UP: %w", conf.ContIface, err)
		}

		// add ip address to container
		addrFlags := 0
		if !conf.Dad {
			addrFlags |= syscall.IFA_F_NODAD
		}
		contLinkAddr := &netlink.Addr{
			IPNet: contAddr,
			Flags: addrFlags,
		}
		if err = netlink.AddrAdd(contLink, contLinkAddr); err != nil {
			return fmt.Errorf("failed to add %v to %q: %w", conf.ContIPv6, conf.ContIface, err)
		}

		// add host MTN ip as a permanent neighbour resolving to the host veth MAC
		hostNeigh := netlink.Neigh{
			LinkIndex:    contVeth.Index,
			State:        netlink.NUD_PERMANENT,
			IP:           conf.MtnIPv6,
			HardwareAddr: hostVeth.HardwareAddr,
		}
		if err = netlink.NeighAdd(&hostNeigh); err != nil {
			return fmt.Errorf("cannot add ip from %q as permanent neighbour for container: %w", conf.MtnIface, err)
		}

		// add host MTN ip direct route
		hostRoute := netlink.Route{
			LinkIndex: contVeth.Index,
			Dst: &net.IPNet{
				IP:   conf.MtnIPv6,
				Mask: ip6Mask128,
			},
			Scope: netlink.SCOPE_UNIVERSE,
		}
		if err := netlink.RouteAdd(&hostRoute); err != nil {
			return fmt.Errorf("failed to add to container direct host route %v: %w", hostRoute, err)
		}

		// add default route
		hostDefaultRoute := netlink.Route{
			LinkIndex: contVeth.Index,
			Dst: &net.IPNet{
				IP:   net.IPv6zero,
				Mask: ip6Mask000,
			},
			Gw: conf.MtnIPv6,
			// MTU: conf.MTU - 90,	// we should consider smaller magic MTU
			Scope: netlink.SCOPE_UNIVERSE,
		}
		if err := netlink.RouteAdd(&hostDefaultRoute); err != nil {
			return fmt.Errorf("failed to set ip from %q as container default route %v: %w", conf.MtnIface, hostDefaultRoute, err)
		}
		result.Routes = append(result.Routes, &types.Route{
			Dst: *hostDefaultRoute.Dst,
			GW:  hostDefaultRoute.Gw,
		})

		return nil
	})
	if err != nil {
		return err
	}

	// host iface moved namespaces and may have a new index
	hostLink, err := netlink.LinkByName(hostIface.Name)
	if err != nil {
		return fmt.Errorf("failed to lookup %q: %w", hostIface.Name, err)
	}

	// add route to container from host veth iface
	contRoute := netlink.Route{
		LinkIndex: hostLink.Attrs().Index,
		Dst:       contAddr,
		Scope:     netlink.SCOPE_UNIVERSE,
	}
	if err := netlink.RouteAdd(&contRoute); err != nil {
		return fmt.Errorf("failed to add to host direct container route %v: %w", contRoute, err)
	}

	// add container to veth iface neighbours
	contNeigh := netlink.Neigh{
		LinkIndex:    hostLink.Attrs().Index,
		State:        netlink.NUD_PERMANENT,
		IP:           contAddr.IP,
		HardwareAddr: conf.MAC,
	}
	if err = netlink.NeighAdd(&contNeigh); err != nil {
		return fmt.Errorf("cannot add container ip as permanent neighbour %v: %w", contNeigh, err)
	}

	// proxy the container ip on MTN (NDP proxy entry)
	mtnLink, err := netlink.LinkByName(conf.MtnIface)
	if err != nil {
		return fmt.Errorf("failed to lookup %q: %w", conf.MtnIface, err)
	}
	err = netlink.NeighAdd(&netlink.Neigh{
		LinkIndex: mtnLink.Attrs().Index,
		Flags:     netlink.NTF_PROXY,
		IP:        contAddr.IP,
	})
	if err != nil {
		return fmt.Errorf("cannot add proxy on %q for container ip: %w", conf.MtnIface, err)
	}
	return nil
}

// cmdAdd implements CNI ADD. It loads the configuration (loadConf +
// loadConfLocal), wires the container into the host with setupVeth, enables IP
// forwarding for the assigned addresses and prints the resulting
// current.Result.
func cmdAdd(args *skel.CmdArgs) error {
	conf, err := loadConf(args)
	if err != nil {
		return err
	}
	if err := loadConfLocal(conf, getPodNameOrId(args)); err != nil {
		return err
	}

	result := &current.Result{
		CNIVersion: conf.CNIVersion,
		DNS:        conf.DNS,
	}

	netns, err := ns.GetNS(args.Netns)
	if err != nil {
		// An unopenable netns must fail ADD: reporting success with an empty
		// result would leave the container without networking.
		return fmt.Errorf("failed to open netns %q: %w", args.Netns, err)
	}
	defer netns.Close()

	if err := setupVeth(conf, netns, result); err != nil {
		return err
	}
	// Enable forwarding
	if err := ip.EnableForward(result.IPs); err != nil {
		return fmt.Errorf("could not enable IP forwarding: %w", err)
	}

	return types.PrintResult(result, conf.CNIVersion)
}

// =============================================================================================================

// cmdDel implements CNI DEL. While the container interface still exists, it
// removes the NDP proxy entries on MtnIface for every permanent neighbour of
// the host-side veth, then deletes the container interface. DEL is idempotent:
// an already-removed interface or an already-gone netns is not an error.
//
// NOTE(review): when the netns (or the container interface) is already gone,
// the proxy entries on MtnIface for this container are not removed.
func cmdDel(args *skel.CmdArgs) error {
	conf, err := loadConf(args)
	if err != nil {
		return err
	}
	if args.Netns == "" {
		return nil
	}

	// The container veth's ParentIndex is used as the ifindex of its host peer.
	hostLinkIndex := -1
	err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
		link, err := netlink.LinkByName(args.IfName)
		if err != nil {
			return err
		}
		hostLinkIndex = link.Attrs().ParentIndex
		return nil
	})
	if err == nil {
		mtnLink, err := netlink.LinkByName(conf.MtnIface)
		if err != nil {
			return fmt.Errorf("failed to lookup %q: %w", conf.MtnIface, err)
		}

		neighs, err := netlink.NeighList(hostLinkIndex, netlink.FAMILY_V6)
		if err != nil {
			return fmt.Errorf("failed to get iface idx %v neighbours: %w", hostLinkIndex, err)
		}
		for _, neigh := range neighs {
			if neigh.State&netlink.NUD_PERMANENT != 0 {
				// Best effort: the proxy entry may already have been removed
				// by an earlier DEL, so a failure here is deliberately ignored.
				_ = netlink.NeighDel(&netlink.Neigh{
					LinkIndex: mtnLink.Attrs().Index,
					Flags:     netlink.NTF_PROXY,
					IP:        neigh.IP,
				})
			}
		}
	}

	// There is a netns so try to clean up. Delete can be called multiple times
	// so don't return an error if the device is already removed.
	err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error {
		if err := ip.DelLinkByName(args.IfName); err != nil && err != ip.ErrLinkNotFound {
			return err
		}
		return nil
	})
	if err != nil {
		// The netns itself may already be gone (sandbox torn down, or DEL
		// repeated); the CNI spec requires DEL to succeed in that case.
		if _, ok := err.(ns.NSPathNotExistErr); ok {
			return nil
		}
		return err
	}
	return nil
}

// =============================================================================================================

// checkVeth verifies that the state created by setupVeth is still in place.
//
// Inside the container netns it checks the DAD sysctls (when conf.Dad is
// false), the container interface's MTU, MAC and UP state, its address
// ContIPv6 with the ContMask prefix length, the permanent neighbour entry for
// MtnIPv6, the /128 route to MtnIPv6 and the default route via MtnIPv6. In
// the host netns it checks the route to the container address via the host
// veth and the permanent neighbour entry for ContIPv6 with conf.MAC.
func checkVeth(conf *NetConf, netns ns.NetNS) error {
	hostLinkIndex := -1
	err := netns.Do(func(_ ns.NetNS) error {
		if !conf.Dad {
			for _, sysFile := range []string{
				"/proc/sys/net/ipv6/conf/all/accept_dad",
				"/proc/sys/net/ipv6/conf/default/accept_dad",
			} {
				dad, err := os.ReadFile(sysFile)
				if err != nil {
					return fmt.Errorf("failed to read %s: %w", sysFile, err)
				}
				if len(dad) == 0 || dad[0] != '0' {
					return fmt.Errorf("DAD is enabled in %s, but not enabled in config", sysFile)
				}
			}
		}

		contLink, err := netlink.LinkByName(conf.ContIface)
		if err != nil {
			return fmt.Errorf("failed to lookup %q: %w", conf.ContIface, err)
		}
		attrs := contLink.Attrs()
		// The container veth's ParentIndex is used as the ifindex of its host peer.
		hostLinkIndex = attrs.ParentIndex

		// check container interface
		if attrs.MTU != conf.MTU {
			return fmt.Errorf("container interface %q has different MTU (%v!=%v)", conf.ContIface, attrs.MTU, conf.MTU)
		}
		if !bytes.Equal(attrs.HardwareAddr, conf.MAC) {
			return fmt.Errorf("container interface %q has different MAC address (%v!=%v)", conf.ContIface, attrs.HardwareAddr, conf.MAC)
		}
		if attrs.OperState != netlink.OperUp {
			return fmt.Errorf("container interface %q is not UP", conf.ContIface)
		}

		// check ip address of the container
		addrList, err := netlink.AddrList(contLink, netlink.FAMILY_V6)
		if err != nil {
			return fmt.Errorf("cannot obtain list of IP addresses for container interface %q: %w", conf.ContIface, err)
		}
		wantOnes, _ := conf.ContMask.Size()
		found := false
		for _, addr := range addrList {
			if ones, _ := addr.Mask.Size(); addr.IP.Equal(conf.ContIPv6) && ones == wantOnes {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("cannot find addr %v on interface %v", conf.ContIPv6, conf.ContIface)
		}

		// check host MTN ip as neighbour
		neighs, err := netlink.NeighList(attrs.Index, netlink.FAMILY_V6)
		if err != nil {
			return fmt.Errorf("failed to get container iface %q neighbours: %w", conf.ContIface, err)
		}
		found = false
		for _, neigh := range neighs {
			if neigh.State&netlink.NUD_PERMANENT != 0 && neigh.IP.Equal(conf.MtnIPv6) {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("cannot find addr %v among container interface %v neighbours", conf.MtnIPv6, conf.ContIface)
		}

		// check host MTN ip direct route
		filter := &netlink.Route{
			LinkIndex: attrs.Index,
			Dst: &net.IPNet{
				IP:   conf.MtnIPv6,
				Mask: ip6Mask128,
			},
		}
		filterMask := netlink.RT_FILTER_OIF | netlink.RT_FILTER_DST
		routes, err := netlink.RouteListFiltered(netlink.FAMILY_V6, filter, filterMask)
		if err != nil {
			return fmt.Errorf("table lookup error for direct route %v: %w", filter, err)
		}
		if len(routes) == 0 {
			return fmt.Errorf("direct route %v not found in routing table", filter)
		}

		// check default route
		filter = &netlink.Route{
			LinkIndex: attrs.Index,
			Dst:       nil,
			Gw:        conf.MtnIPv6,
		}
		filterMask = netlink.RT_FILTER_OIF | netlink.RT_FILTER_DST | netlink.RT_FILTER_GW
		routes, err = netlink.RouteListFiltered(netlink.FAMILY_V6, filter, filterMask)
		if err != nil {
			return fmt.Errorf("table lookup error for default route %v: %w", filter, err)
		}
		if len(routes) == 0 {
			return fmt.Errorf("default route %v not found in routing table", filter)
		}
		return nil
	})
	if err != nil {
		return err
	}

	contAddr := &net.IPNet{
		IP:   conf.ContIPv6,
		Mask: conf.ContMask,
	}

	// check route to container from host veth iface
	filter := &netlink.Route{
		LinkIndex: hostLinkIndex,
		Dst:       contAddr,
	}
	filterMask := netlink.RT_FILTER_OIF | netlink.RT_FILTER_DST
	routes, err := netlink.RouteListFiltered(netlink.FAMILY_V6, filter, filterMask)
	if err != nil {
		return fmt.Errorf("table lookup error for expected route %v: %w", filter, err)
	}
	if len(routes) == 0 {
		return fmt.Errorf("expected route %v not found in routing table", filter)
	}

	// check container to veth iface neighbours
	neighs, err := netlink.NeighList(hostLinkIndex, netlink.FAMILY_V6)
	if err != nil {
		return fmt.Errorf("failed to get host veth iface (idx=%d) neighbours: %w", hostLinkIndex, err)
	}
	for _, neigh := range neighs {
		if neigh.State&netlink.NUD_PERMANENT != 0 && neigh.IP.Equal(conf.ContIPv6) && bytes.Equal(neigh.HardwareAddr, conf.MAC) {
			return nil
		}
	}
	return fmt.Errorf("cannot find addr %v among host interface (idx=%d) neighbours", conf.ContIPv6, hostLinkIndex)
}

// cmdCheck implements CNI CHECK. It recomputes the expected configuration the
// same way cmdAdd does (loadConf + loadConfLocal), compares it against the live
// container and host state via checkVeth, and finally requires IPv6
// forwarding to be enabled on the host.
func cmdCheck(args *skel.CmdArgs) error {
	conf, err := loadConf(args)
	if err != nil {
		return err
	}
	if err = loadConfLocal(conf, getPodNameOrId(args)); err != nil {
		return err
	}

	netns, err := ns.GetNS(args.Netns)
	if err != nil {
		return fmt.Errorf("failed to open netns %q: %v", args.Netns, err)
	}
	defer netns.Close()

	if err = checkVeth(conf, netns); err != nil {
		return err
	}

	const forwardingSysctl = "/proc/sys/net/ipv6/conf/all/forwarding"
	fwd, err := os.ReadFile(forwardingSysctl)
	switch {
	case err != nil:
		return fmt.Errorf("could not check IP forwarding: %v", err)
	case len(fwd) == 0 || fwd[0] != '1':
		return fmt.Errorf("IP forwarding is not enabled")
	}
	return nil
}

// =============================================================================================================

func main() {
	skel.PluginMain(cmdAdd, cmdCheck, cmdDel, version.All, bv.BuildString("mtn"))
}
