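// Package netlimit provides a sysconf plugin that polices ingress Fastbone
// traffic on selected VLAN interfaces using tc u32 filters attached to the
// ingress qdisc.
//
// More info about tc u32 filters:
// http://linux-tc-notes.sourceforge.net/tc/doc/cls_u32.txt
// https://man7.org/linux/man-pages/man8/tc-u32.8.html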
package netlimit

import (
	"fmt"

	"a.yandex-team.ru/infra/rsm/dnsmanager/pkg/serverinfo"
	"a.yandex-team.ru/infra/rsm/sysconf/internal/plugin"

	"github.com/vishvananda/netlink"
	"github.com/vishvananda/netlink/nl"
)

const (
	// MTUMax is the MTU passed to the police action of the limiting filter (64 KiB).
	MTUMax uint32 = 1 << 16

	// Mbit is one megabit; rates in this package are expressed as multiples of it.
	Mbit uint32 = 1e6
)

var (
	// netsFB is the whole Fastbone IPv6 range; ingress traffic from it is policed.
	netsFB = "2a02:6b8:f000::/36"

	// netsFBMan lists the Fastbone networks allocated to the MAN location;
	// ingress traffic from them bypasses the limit.
	// Source: https://racktables.yandex-team.ru/index.php?andor=and&cft%5B%5D=141&cft%5B%5D=1056&cfe=&page=ipv6space&tab=default
	// NOTE: this list is hardcoded and must be updated by hand whenever the
	// MAN allocations change.
	netsFBMan = []string{
		"2a02:6b8:fc01::/48",
		"2a02:6b8:fc0d::/48",
		"2a02:6b8:fc11::/48",
		"2a02:6b8:fc12::/48",
		"2a02:6b8:fc13::/48",
		"2a02:6b8:fc1b::/48",
	}

	// rateDefault is the police rate applied to Fastbone ingress traffic.
	rateDefault uint32 = 225 * Mbit

	// prateDefault is the peak rate passed to SetU32Police; presumably 0 means
	// "no peak rate" — confirm against SetU32Police.
	prateDefault uint32 = 0

	// burstPercDefault is the police burst, as a percentage of the rate.
	burstPercDefault uint32 = 10

	// rateYABS replaces rateDefault on hosts of the rtc-yabs and rtc-yabs-mtn projects.
	rateYABS uint32 = 200 * Mbit
)

// init builds the plugin for this host and registers it. Any error from New is
// fatal at program start-up.
func init() {
	plg, err := New()
	if err == nil {
		plugin.Register(plg)
		return
	}
	panic(err)
}

// RuleSet is the desired tc state for one network link.
//
// Fields:
//   - l: the link the rules belong to (its name is used as the state name).
//   - q: the qdisc the filters are attached under; NewRuleSet sets it to the
//     link's ingress qdisc (handle ffff:).
//   - fls: the expected filters, in order; checkLimitFeature compares them
//     position by position with the filters currently on the link, and
//     enableLimitFeature adds them in this order.
type RuleSet struct {
	l   netlink.Link
	q   netlink.Qdisc
	fls []netlink.Filter
}

// NewRuleSet returns a RuleSet for link l with no filters yet. Its qdisc is the
// ingress qdisc of l (parent HANDLE_INGRESS, handle ffff:); filters are added
// afterwards with FilterAdd.
func NewRuleSet(l netlink.Link) *RuleSet {
	ingress := &netlink.Ingress{}
	ingress.QdiscAttrs.LinkIndex = l.Attrs().Index
	ingress.QdiscAttrs.Parent = netlink.HANDLE_INGRESS
	ingress.QdiscAttrs.Handle = netlink.MakeHandle(0xffff, 0)

	return &RuleSet{l: l, q: ingress}
}

// FilterAdd appends the u32 filter f to the filters expected on the link.
// Order matters: filters are checked and installed in the order they were added.
func (rs *RuleSet) FilterAdd(f *U32) {
	var filter netlink.Filter = f
	rs.fls = append(rs.fls, filter)
}

// String returns the name of the link the rule set belongs to.
func (rs *RuleSet) String() string {
	attrs := rs.l.Attrs()
	return attrs.Name
}

// pl is the netlimit plugin: one RuleSet per link to be limited. An empty pl
// means the plugin is not applicable on this host (see IsApplicable).
type pl []*RuleSet

// New builds the netlimit plugin for the current host.
//
// The plugin only targets hosts in the "man" location. If the server info
// cannot be read, or the host is elsewhere, New returns an empty plugin and a
// nil error. The read error is deliberately not propagated: init panics on any
// error from New, and presumably a host without server info should just skip
// the plugin (an empty plugin reports StatusSkip from IsApplicable).
//
// Otherwise, New prepares an ingress RuleSet for every link named "vlan700" or
// "vlan761". Each RuleSet holds, in order:
//   - prio 10 u32 filters that pass (TC_ACT_OK) traffic whose IPv6 source is
//     one of netsFBMan, so traffic from inside the location is not limited;
//   - a prio 99 u32 filter for the whole Fastbone range netsFB. It is policed
//     to the selected rate with a burst of burstPercDefault percent of that
//     rate (converted to bytes); excess traffic is dropped (TC_ACT_SHOT).
//
// The rate is rateDefault, or rateYABS for the rtc-yabs and rtc-yabs-mtn
// Wall-E projects.
func New() (pl, error) {
	si, err := serverinfo.NewDefault().Read()
	if err != nil {
		// Deliberately swallowed; see the function comment.
		return nil, nil
	}
	if si.Location != "man" {
		return nil, nil
	}

	rate := rateDefault
	switch si.WallePrj {
	case "rtc-yabs", "rtc-yabs-mtn":
		rate = rateYABS
	}
	// burstPercDefault percent of the rate, converted from bits to bytes.
	// Dividing by 100 first keeps the intermediate product within uint32.
	burst := (rate / 100 * burstPercDefault) / 8

	// tc evaluates filters with a lower prio value first, so the bypass
	// filters must use a smaller value than the limiting filter.
	const (
		passPrio  uint16 = 10
		limitPrio uint16 = 99
	)

	ls, err := netlink.LinkList()
	if err != nil {
		return nil, fmt.Errorf("listing links: %w", err)
	}
	p := pl{}
	for _, l := range ls {
		name := l.Attrs().Name
		if name != "vlan700" && name != "vlan761" {
			continue
		}
		rs := NewRuleSet(l)
		p = append(p, rs)
		handle := rs.q.Attrs().Handle

		// tc analogue, one filter per network in netsFBMan:
		//   tc filter add dev $IFACE parent ffff: prio 10 u32 match ip6 src 2a02:6b8:fc01::/48 action pass
		for _, cidr := range netsFBMan {
			f, err := NewU32Filter(l, handle, passPrio, cidr, netlink.TC_ACT_OK)
			if err != nil {
				return nil, fmt.Errorf("%s: building pass filter for %s: %w", name, cidr, err)
			}
			rs.FilterAdd(f)
		}

		// tc analogue:
		//   tc filter add dev $IFACE parent ffff: prio 99 u32 match ip6 src 2a02:6b8:f000::/36 \
		//     police rate <rate> burst <burst> mtu 65536 drop
		// With the constants above: rate 225Mbit burst 2812500, or rate 200Mbit
		// burst 2500000 on YABS hosts.
		f, err := NewU32Filter(l, handle, limitPrio, netsFB, netlink.TC_ACT_UNSPEC)
		if err != nil {
			return nil, fmt.Errorf("%s: building limit filter for %s: %w", name, netsFB, err)
		}
		if err := SetU32Police(f, rate, prateDefault, burst, MTUMax, netlink.TC_ACT_SHOT); err != nil {
			return nil, fmt.Errorf("%s: setting police on limit filter: %w", name, err)
		}
		rs.FilterAdd(f)
	}
	return p, nil
}

// Name returns the plugin name, "netlimit".
func (p pl) Name() string {
	const name = "netlimit"
	return name
}

// Doc returns the documentation link for this plugin.
func (p pl) Doc() string {
	const link = "https://st.yandex-team.ru/RTCNETWORK-654"
	return link
}

// IsApplicable reports StatusEnable when at least one link has a rule set,
// and StatusSkip otherwise (e.g. on hosts outside the "man" location).
func (p pl) IsApplicable() (st plugin.State) {
	if len(p) == 0 {
		st.Status = plugin.StatusSkip
		return st
	}
	st.Status = plugin.StatusEnable
	return st
}

// Check compares the expected filters of every rule set with the filters
// currently installed on its link, producing one state per link.
func (p pl) Check() (sts plugin.States) {
	for i := 0; i < len(p); i++ {
		st := checkLimitFeature(p[i])
		sts.Add(st)
	}
	return sts
}

// Enable (re)installs the rule set on every link whose current state differs
// from the expected one (StatusDiff). Links that are already correct, or whose
// check failed, are reported with their check state unchanged. force is unused.
func (p pl) Enable(force bool) (sts plugin.States) {
	for _, rs := range p {
		st := checkLimitFeature(rs)
		if st.Status == plugin.StatusDiff {
			st = enableLimitFeature(rs)
		}
		sts.Add(st)
	}
	return sts
}

// Disable is not implemented. It leaves every link untouched and reports
// StatusSkip with plugin.ErrNotImpl for each one. force is unused.
func (p pl) Disable(force bool) (sts plugin.States) {
	for i := range p {
		sts.Add(plugin.State{
			Name:   p[i].String(),
			Status: plugin.StatusSkip,
			Err:    plugin.ErrNotImpl,
		})
	}
	return sts
}

// enableLimitFeature installs rs on its link. It deletes rs.q, re-adds it, and
// then adds every expected filter in order. It returns StatusOk on success. On
// the first failure it returns StatusFail with Err set; anything already added
// stays in place.
func enableLimitFeature(rs *RuleSet) (st plugin.State) {
	st.Name = rs.String()
	st.Status = plugin.StatusFail

	// Best effort: the delete error is intentionally ignored, presumably
	// because the qdisc may simply not exist yet. Any real problem surfaces
	// from QdiscAdd below.
	_ = netlink.QdiscDel(rs.q)

	if err := netlink.QdiscAdd(rs.q); err != nil {
		st.Err = err
		return st
	}
	for _, flt := range rs.fls {
		if err := FilterAdd(flt); err != nil {
			st.Err = err
			return st
		}
	}

	st.Status = plugin.StatusOk
	return st
}

// checkLimitFeature lists the filters attached under rs.q on rs.l and compares
// them with the expected ones. It returns:
//   - StatusFail with the listing error if the filters cannot be read;
//   - StatusDiff with an error describing the first difference if they differ;
//   - StatusOk otherwise.
func checkLimitFeature(rs *RuleSet) plugin.State {
	st := plugin.State{Name: rs.String()}

	current, err := FilterList(rs.l, rs.q.Attrs().Handle)
	if err != nil {
		st.Status = plugin.StatusFail
		st.Err = err
		return st
	}

	if diff := filtersEqual(rs.fls, current); diff != nil {
		st.Status = plugin.StatusDiff
		st.Err = diff
		return st
	}

	st.Status = plugin.StatusOk
	return st
}

// filtersEqual compares the expected filters efls with the filters fls read
// from a link. It returns nil if they match, otherwise an error describing the
// first difference found. Filters are compared pairwise by position, so order
// matters. Only u32 filters are supported. For u32 filters the actions
// (actionEqual), police parameters (policeEqual) and selector keys (selEqual)
// are compared.
func filtersEqual(efls, fls []netlink.Filter) error {
	if len(efls) != len(fls) {
		return fmt.Errorf("number of filters is different \n\tmust: %d \n\tgot: %d", len(efls), len(fls))
	}
	for fi, f := range fls {
		fhStr := netlink.HandleStr(f.Attrs().Handle)
		ef := efls[fi]

		if ef.Type() != f.Type() {
			return fmt.Errorf("type of filters is different: %s and %s", ef.Type(), f.Type())
		}
		switch f.Type() {
		case "u32":
			// Bind the asserted values to new names. Shadowing f/ef here
			// would make %T report the (nil) target type instead of the
			// dynamic type that actually failed the assertion.
			got, ok := f.(*U32)
			if !ok {
				return fmt.Errorf("interface conversion \n\tmust:u32 \n\tgot: %T", f)
			}
			want, ok := ef.(*U32)
			if !ok {
				return fmt.Errorf("interface conversion \n\tmust:u32 \n\tgot: %T", ef)
			}
			if !actionEqual(got.Actions, want.Actions) {
				return fmt.Errorf("filter %s actions are different \n\tmust:%+v \n\tgot: %+v", fhStr, want.Actions, got.Actions)
			}
			if !policeEqual(&got.Police, &want.Police) {
				return fmt.Errorf("filter %s police are different \n\tmust: %+v \n\tgot: %+v", fhStr, want.Police, got.Police)
			}
			if !selEqual(got.Sel, want.Sel) {
				return fmt.Errorf("filter %s sel are different \n\tmust: %#v \n\tgot: %#v", fhStr, want.Sel, got.Sel)
			}
		default:
			return fmt.Errorf("only u32 type applicable")
		}
	}
	return nil
}

// actionEqual reports whether a and ea have the same length and, position by
// position, the same ActionAttrs.Action value. Other action attributes and
// the concrete action kinds are not compared.
func actionEqual(a, ea []netlink.Action) bool {
	if len(a) != len(ea) {
		return false
	}
	for i := range a {
		if a[i].Attrs().Action != ea[i].Attrs().Action {
			return false
		}
	}
	return true
}

// policeEqual reports whether two police configurations agree on action,
// rate, burst and MTU. All other TcPolice fields are ignored.
func policeEqual(p, ep *nl.TcPolice) bool {
	if p.Action != ep.Action || p.Rate.Rate != ep.Rate.Rate {
		return false
	}
	return p.Burst == ep.Burst && p.Mtu == ep.Mtu
}

// selEqual reports whether two u32 selectors match. Two nil selectors are
// equal; a nil and a non-nil selector are not. Otherwise the key lists must
// have the same length, with the same Mask and Val at each position.
// NOTE(review): only Mask and Val of each key are compared; any other key or
// selector fields (e.g. offsets) are ignored — confirm this is intentional.
func selEqual(s, es *netlink.TcU32Sel) bool {
	if s == nil || es == nil {
		// Equal only when both are nil.
		return s == es
	}
	if len(s.Keys) != len(es.Keys) {
		return false
	}
	for i := range s.Keys {
		got, want := s.Keys[i], es.Keys[i]
		if got.Mask != want.Mask || got.Val != want.Val {
			return false
		}
	}
	return true
}
