// The original U32 struct doesn't support a Police field.
// Unfortunately, parts of the upstream code had to be copied as is.
// To make the changes easier to follow, new code is marked with the comment 'New piece of code'.
package netlimit

import (
	"encoding/binary"
	"fmt"
	"net"
	"syscall"

	"github.com/vishvananda/netlink"
	"github.com/vishvananda/netlink/nl"
	"golang.org/x/sys/unix"
)

var (
	// networkOrder is the big-endian byte order. It is compared against the
	// host's native byte order to decide whether u32 selector fields need to
	// be byte-swapped before being sent or after being received.
	networkOrder = binary.BigEndian
)

// U32 describes a "u32" traffic-control filter. It mirrors the upstream
// netlink U32 filter and adds policing support (Police, Rtab, Ptab).
type U32 struct {
	// FilterAttrs holds the attributes common to all filters (link index,
	// handle, parent, priority, protocol).
	netlink.FilterAttrs
	// ClassID is sent as TCA_U32_CLASSID when non-zero.
	ClassID uint32
	// Divisor is sent as TCA_U32_DIVISOR when non-zero; FilterAdd rejects
	// values that are not a power of two.
	Divisor uint32
	// Hash is sent as TCA_U32_HASH when non-zero.
	Hash uint32
	// RedirIndex is kept for backwards compatibility: when non-zero, FilterAdd
	// puts a mirred action built from it in front of Actions.
	RedirIndex int
	// Sel is the match selector; when nil, FilterAdd installs a single
	// zero key that matches everything.
	Sel *netlink.TcU32Sel
	// New piece of code
	// Extending original U32 struct
	// Actions are encoded under TCA_U32_ACT.
	Actions []netlink.Action
	// Police is sent as TCA_POLICE_TBF when it differs from the zero value.
	Police nl.TcPolice
	// Rtab is the rate table sent as TCA_POLICE_RATE when Police.Rate is set.
	Rtab [256]uint32
	// Ptab is the peak-rate table sent as TCA_POLICE_PEAKRATE when
	// Police.PeakRate is set.
	Ptab [256]uint32
	// end New
}

// Attrs returns a pointer to the embedded netlink.FilterAttrs, so the common
// filter attributes can be read or modified in place.
func (filter *U32) Attrs() *netlink.FilterAttrs {
	return &filter.FilterAttrs
}

// Type returns the filter kind name, "u32". FilterAdd sends this value as the
// TCA_KIND attribute.
func (filter *U32) Type() string {
	return "u32"
}

// New piece of code

// NewU32Filter builds a U32 filter for link l, attached to parent handle ph
// with priority pri and protocol ETH_P_ALL. The selector is derived from the
// CIDR string ip; a parse failure is returned as the error. When act is not
// TC_ACT_UNSPEC, a single generic action carrying act is attached.
// The filter is only constructed here; it is not installed.
func NewU32Filter(l netlink.Link, ph uint32, pri uint16, ip string, act netlink.TcAct) (*U32, error) {
	selector, err := parseTcU32Sel(ip)
	if err != nil {
		return nil, err
	}

	u := new(U32)
	u.Sel = selector
	u.FilterAttrs = netlink.FilterAttrs{
		LinkIndex: l.Attrs().Index,
		Parent:    ph,
		Priority:  pri,
		Protocol:  unix.ETH_P_ALL,
	}

	// No explicit action requested: return the bare filter.
	if act == netlink.TC_ACT_UNSPEC {
		return u, nil
	}

	generic := &netlink.GenericAction{}
	generic.ActionAttrs.Action = act
	u.Actions = []netlink.Action{generic}
	return u, nil
}

// SetU32Police fills in the policing parameters of f.
//
// rate and prate are divided by 8 before being stored (presumably converting
// bits per second to bytes per second; confirm against callers). act becomes
// the police action and mtu the police MTU. For each non-zero rate the
// matching rate table (f.Rtab or f.Ptab) is computed for an Ethernet link
// layer; burst is converted to a transmission time at the configured rate.
// An error is returned when a rate table cannot be calculated; f may already
// have been partially updated in that case.
func SetU32Police(f *U32, rate, prate, burst, mtu uint32, act netlink.TcAct) error {
	police := &f.Police
	police.Rate.Rate = rate / 8
	police.PeakRate.Rate = prate / 8
	police.Action = int32(act)
	police.Mtu = mtu

	if committed := &police.Rate; committed.Rate != 0 {
		if netlink.CalcRtable(committed, f.Rtab[:], -1, mtu, nl.LINKLAYER_ETHERNET) < 0 {
			return fmt.Errorf("failed to calculate rate table")
		}
		police.Burst = uint32(netlink.Xmittime(uint64(committed.Rate), burst))
	}

	if peak := &police.PeakRate; peak.Rate != 0 {
		if netlink.CalcRtable(peak, f.Ptab[:], -1, mtu, nl.LINKLAYER_ETHERNET) < 0 {
			return fmt.Errorf("failed to calculate peak rate table")
		}
	}
	return nil
}

// end New

// FilterList dumps the traffic-control filters of link, passing parent as the
// parent handle of the request, and returns them decoded.
//
// "u32" filters are decoded into *U32; every other kind is returned as a
// *netlink.GenericFilter. Only filters for which a detailed description was
// found are returned: for u32 that means a selector attribute was present,
// for other kinds that an options attribute was present. Messages that carry
// options without a kind are skipped instead of causing a nil dereference.
func FilterList(link netlink.Link, parent uint32) ([]netlink.Filter, error) {
	req := nl.NewNetlinkRequest(unix.RTM_GETTFILTER, unix.NLM_F_DUMP)
	msg := &nl.TcMsg{
		Family: nl.FAMILY_ALL,
		Parent: parent,
	}
	msg.Ifindex = int32(link.Attrs().Index)
	req.AddData(msg)

	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWTFILTER)
	if err != nil {
		return nil, err
	}

	var res []netlink.Filter
	for _, m := range msgs {
		tcMsg := nl.DeserializeTcMsg(m)

		attrs, err := nl.ParseRouteAttr(m[tcMsg.Len():])
		if err != nil {
			return nil, err
		}

		base := netlink.FilterAttrs{
			LinkIndex: int(tcMsg.Ifindex),
			Handle:    tcMsg.Handle,
			Parent:    tcMsg.Parent,
		}
		// Info packs priority and protocol; the protocol half is byte-swapped
		// the same way FilterAdd swaps it when building the request.
		base.Priority, base.Protocol = netlink.MajorMinor(tcMsg.Info)
		base.Protocol = nl.Swap16(base.Protocol)

		var filter netlink.Filter
		filterType := ""
		detailed := false
		for _, attr := range attrs {
			switch attr.Attr.Type {
			case nl.TCA_KIND:
				// Drop the trailing terminator byte; tolerate an empty payload
				// rather than slicing with a negative bound.
				kind := attr.Value
				if n := len(kind); n > 0 {
					kind = kind[:n-1]
				}
				filterType = string(kind)
				switch filterType {
				case "u32":
					filter = &U32{}
				default:
					filter = &netlink.GenericFilter{FilterType: filterType}
				}
			case nl.TCA_OPTIONS:
				data, err := nl.ParseRouteAttr(attr.Value)
				if err != nil {
					return nil, err
				}
				switch filterType {
				case "u32":
					detailed, err = parseU32Data(filter, data)
					if err != nil {
						return nil, err
					}
				default:
					detailed = true
				}
			}
		}
		// only return the detailed version of the filter; filter is nil when
		// options were seen without a kind attribute.
		if detailed && filter != nil {
			*filter.Attrs() = base
			res = append(res, filter)
		}
	}

	return res, nil
}

// FilterAdd installs filter in the kernel with an RTM_NEWTFILTER request
// (create, exclusive, acknowledged).
//
// Only *U32 filters get kind-specific options; for any other filter type an
// empty options attribute is sent. For *U32:
//   - a nil Sel is replaced by a single zero key with the terminal flag
//     (match all);
//   - selector masks and key values are converted to network byte order on
//     hosts whose native order differs;
//   - ClassID, Divisor and Hash are sent only when non-zero, and a Divisor
//     that is not a power of two is rejected with an error;
//   - Police and its rate tables are sent when Police is not the zero value;
//   - Actions are encoded under TCA_U32_ACT, preceded by a mirred action when
//     RedirIndex is non-zero.
//
// The filter passed in is not modified: the selector is copied before it is
// byte-swapped and the mirred action is added to a local action list only.
func FilterAdd(filter netlink.Filter) error {
	native := nl.NativeEndian()
	req := nl.NewNetlinkRequest(unix.RTM_NEWTFILTER, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
	base := filter.Attrs()
	msg := &nl.TcMsg{
		Family:  nl.FAMILY_ALL,
		Ifindex: int32(base.LinkIndex),
		Handle:  base.Handle,
		Parent:  base.Parent,
		Info:    netlink.MakeHandle(base.Priority, nl.Swap16(base.Protocol)),
	}
	req.AddData(msg)
	req.AddData(nl.NewRtAttr(nl.TCA_KIND, nl.ZeroTerminated(filter.Type())))

	options := nl.NewRtAttr(nl.TCA_OPTIONS, nil)

	switch filter := filter.(type) {
	case *U32:
		sel := filter.Sel
		if sel == nil {
			// match all
			sel = &nl.TcU32Sel{
				Nkeys: 1,
				Flags: nl.TC_U32_TERMINAL,
			}
			sel.Keys = append(sel.Keys, nl.TcU32Key{})
		}

		// Work on a private copy so neither the byte swapping nor the Nkeys
		// update below touches the caller's selector. The copy is sized by
		// len, not cap, so spare capacity never turns into extra zero keys.
		cSel := *sel
		cSel.Keys = make([]nl.TcU32Key, len(sel.Keys))
		copy(cSel.Keys, sel.Keys)
		sel = &cSel

		if native != networkOrder {
			// Handle the endianness of attributes
			sel.Offmask = native.Uint16(htons(sel.Offmask))
			sel.Hmask = native.Uint32(htonl(sel.Hmask))
			for i, key := range sel.Keys {
				sel.Keys[i].Mask = native.Uint32(htonl(key.Mask))
				sel.Keys[i].Val = native.Uint32(htonl(key.Val))
			}
		}
		sel.Nkeys = uint8(len(sel.Keys))
		options.AddRtAttr(nl.TCA_U32_SEL, sel.Serialize())
		if filter.ClassID != 0 {
			options.AddRtAttr(nl.TCA_U32_CLASSID, nl.Uint32Attr(filter.ClassID))
		}
		if filter.Divisor != 0 {
			if (filter.Divisor-1)&filter.Divisor != 0 {
				return fmt.Errorf("illegal divisor %d. Must be a power of 2", filter.Divisor)
			}
			options.AddRtAttr(nl.TCA_U32_DIVISOR, nl.Uint32Attr(filter.Divisor))
		}
		if filter.Hash != 0 {
			options.AddRtAttr(nl.TCA_U32_HASH, nl.Uint32Attr(filter.Hash))
		}
		// New piece of code
		if (filter.Police != nl.TcPolice{}) {
			police := options.AddRtAttr(nl.TCA_U32_POLICE, nil)
			police.AddRtAttr(nl.TCA_POLICE_TBF, filter.Police.Serialize())
			if (filter.Police.Rate != nl.TcRateSpec{}) {
				payload := netlink.SerializeRtab(filter.Rtab)
				police.AddRtAttr(nl.TCA_POLICE_RATE, payload)
			}
			if (filter.Police.PeakRate != nl.TcRateSpec{}) {
				payload := netlink.SerializeRtab(filter.Ptab)
				police.AddRtAttr(nl.TCA_POLICE_PEAKRATE, payload)
			}
		}
		// end New
		actionsAttr := options.AddRtAttr(nl.TCA_U32_ACT, nil)
		// backwards compatibility: RedirIndex is expressed as a leading mirred
		// action. It is added to a local list so that calling FilterAdd again
		// with the same filter does not accumulate duplicates.
		actions := filter.Actions
		if filter.RedirIndex != 0 {
			actions = append([]netlink.Action{netlink.NewMirredAction(filter.RedirIndex)}, filter.Actions...)
		}
		if err := netlink.EncodeActions(actionsAttr, actions); err != nil {
			return err
		}
	}

	req.AddData(options)
	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
	return err
}

// New piece of code

// parseTcU32Sel translates an IPv6 CIDR string into a terminal u32 selector.
//
// The address is split into 32-bit big-endian words; each word becomes one
// key holding the word's value and mask at packet offset 8 + word offset
// (8 is where the source address starts in an IPv6 header; which header the
// offset is applied to depends on how the filter is attached). Words whose
// mask and value are both zero are skipped.
//
// An error is returned when s is not a valid CIDR or is not an IPv6 network.
// IPv4 input is rejected explicitly: net.ParseCIDR yields a 4-byte mask for
// it, which the 16-byte word walk below cannot index.
func parseTcU32Sel(s string) (*netlink.TcU32Sel, error) {
	addr, ipnet, err := net.ParseCIDR(s)
	if err != nil {
		return nil, err
	}
	if len(addr) != net.IPv6len || len(ipnet.Mask) != net.IPv6len {
		return nil, fmt.Errorf("unsupported CIDR %q: only IPv6 networks can be translated to a u32 selector", s)
	}

	sel := &netlink.TcU32Sel{
		Flags: netlink.TC_U32_TERMINAL,
	}
	for i := 0; i < len(addr); i += 4 {
		mask := binary.BigEndian.Uint32(ipnet.Mask[i : i+4])
		val := binary.BigEndian.Uint32(addr[i : i+4])
		if mask == 0 && val == 0 {
			continue
		}
		sel.Keys = append(sel.Keys, netlink.TcU32Key{
			Mask:    mask,
			Val:     val,
			Off:     int32(i + 8),
			OffMask: 0,
		})
	}
	sel.Nkeys = uint8(len(sel.Keys))
	return sel, nil
}

// end New

// parseU32Data decodes the u32 option attributes in data into filter, which
// must be a *U32.
//
// It returns true when a selector attribute (TCA_U32_SEL) was present, which
// FilterList uses to decide whether the filter is reported. Selector masks
// and key values are converted from network byte order on hosts whose native
// order differs. An error is returned when filter is not a *U32 or when a
// nested attribute list (actions or police) cannot be parsed.
func parseU32Data(filter netlink.Filter, data []syscall.NetlinkRouteAttr) (bool, error) {
	native := nl.NativeEndian()
	u32, ok := filter.(*U32)
	if !ok {
		return false, fmt.Errorf("unexpected filter type %T, want *U32", filter)
	}
	detailed := false
	for _, datum := range data {
		switch datum.Attr.Type {
		case nl.TCA_U32_SEL:
			detailed = true
			sel := nl.DeserializeTcU32Sel(datum.Value)
			u32.Sel = sel
			if native != networkOrder {
				u32.Sel.Offmask = native.Uint16(htons(sel.Offmask))
				u32.Sel.Hmask = native.Uint32(htonl(sel.Hmask))
				for i, key := range u32.Sel.Keys {
					u32.Sel.Keys[i].Mask = native.Uint32(htonl(key.Mask))
					u32.Sel.Keys[i].Val = native.Uint32(htonl(key.Val))
				}
			}
		case nl.TCA_U32_ACT:
			tables, err := nl.ParseRouteAttr(datum.Value)
			if err != nil {
				return detailed, err
			}
			u32.Actions, err = parseActions(tables)
			if err != nil {
				return detailed, err
			}
		case nl.TCA_U32_CLASSID:
			u32.ClassID = native.Uint32(datum.Value)
		case nl.TCA_U32_DIVISOR:
			u32.Divisor = native.Uint32(datum.Value)
		case nl.TCA_U32_HASH:
			u32.Hash = native.Uint32(datum.Value)
		// New piece of code
		case nl.TCA_U32_POLICE:
			// A malformed police attribute is reported instead of being
			// silently treated as "no police configured".
			adata, err := nl.ParseRouteAttr(datum.Value)
			if err != nil {
				return detailed, err
			}
			for _, aattr := range adata {
				switch aattr.Attr.Type {
				case nl.TCA_POLICE_TBF:
					u32.Police = *nl.DeserializeTcPolice(aattr.Value)
				case nl.TCA_POLICE_RATE:
					u32.Rtab = netlink.DeserializeRtab(aattr.Value)
				case nl.TCA_POLICE_PEAKRATE:
					u32.Ptab = netlink.DeserializeRtab(aattr.Value)
				}
			}
			// end New
		}
	}
	return detailed, nil
}

// parseActions decodes a list of nested action attributes.
//
// Only "gact" actions are understood; they are returned as
// *netlink.GenericAction with their attributes filled from TCA_GACT_PARMS.
// Actions of any other kind are skipped, so the returned slice never
// contains nil entries. An error is returned when a nested attribute list
// cannot be parsed.
func parseActions(tables []syscall.NetlinkRouteAttr) ([]netlink.Action, error) {
	var actions []netlink.Action
	for _, table := range tables {
		var action netlink.Action
		var actionType string
		aattrs, err := nl.ParseRouteAttr(table.Value)
		if err != nil {
			return nil, err
		}
	nextattr:
		for _, aattr := range aattrs {
			switch aattr.Attr.Type {
			case nl.TCA_KIND:
				// Drop the trailing terminator byte; tolerate an empty payload
				// rather than slicing with a negative bound.
				kind := aattr.Value
				if n := len(kind); n > 0 {
					kind = kind[:n-1]
				}
				actionType = string(kind)
				switch actionType {
				case "gact":
					action = &netlink.GenericAction{}
				default:
					// Unsupported kind: stop looking at this action.
					break nextattr
				}
			case nl.TCA_OPTIONS:
				adata, err := nl.ParseRouteAttr(aattr.Value)
				if err != nil {
					return nil, err
				}
				for _, adatum := range adata {
					switch actionType {
					case "gact":
						switch adatum.Attr.Type {
						case nl.TCA_GACT_PARMS:
							gen := *nl.DeserializeTcGen(adatum.Value)
							toAttrs(&gen, action.Attrs())
						}
					}
				}
			}
		}
		// action stays nil for kinds this function does not decode; appending
		// it would hand callers a nil interface that panics on first use.
		if action != nil {
			actions = append(actions, action)
		}
	}
	return actions, nil
}

// toAttrs copies the generic action header fields of tcgen (index,
// capabilities, action, reference count and bind count) into attrs. Other
// fields of attrs are left untouched.
func toAttrs(tcgen *nl.TcGen, attrs *netlink.ActionAttrs) {
	attrs.Action = netlink.TcAct(tcgen.Action)
	attrs.Index, attrs.Capab = int(tcgen.Index), int(tcgen.Capab)
	attrs.Refcnt, attrs.Bindcnt = int(tcgen.Refcnt), int(tcgen.Bindcnt)
}

// htonl returns the four bytes of val in network (big-endian) byte order.
func htonl(val uint32) []byte {
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], val)
	return out[:]
}

// htons returns the two bytes of val in network (big-endian) byte order.
func htons(val uint16) []byte {
	var out [2]byte
	binary.BigEndian.PutUint16(out[:], val)
	return out[:]
}
