package samples

import (
	"io/ioutil"
	"net"
	"path/filepath"
	"strings"

	"github.com/vishvananda/netlink"
	"go.uber.org/zap"
	"golang.org/x/sys/unix"

	"a.yandex-team.ru/infra/tcp-sampler/pkg/iputil"
	"a.yandex-team.ru/infra/tcp-sampler/pkg/nsenter"
	"a.yandex-team.ru/infra/tcp-sampler/pkg/sockdiag"
	"a.yandex-team.ru/infra/tcp-sampler/pkg/sysctl"

	pb "a.yandex-team.ru/mds/valve/proto"
)

const (
	// procNetTCP6 is the procfs path of the IPv6 TCP socket table as seen by
	// the calling thread.
	// NOTE(review): not referenced in this part of the file — confirm it is still used.
	procNetTCP6 = "/proc/thread-self/net/tcp6"
	// ephPort is the placeholder that Parse substitutes for any source port
	// inside the ephemeral range, so such connections deduplicate into one sample.
	ephPort     = 50000
)

// TCPSet is a set of samples; the empty-struct values carry no data.
type TCPSet map[sample]struct{}

// limTCPInfo is the subset of TCP socket info kept in a sample: segment
// counters and the socket state. It contains only comparable fields so that it
// can be part of a sample used as a map key (see TCPSet).
type limTCPInfo struct {
	SegsOut uint32
	SegsIn  uint32
	State   pb.TCPState
}

// sample is the comparable (map-key) form of one observed TCP connection.
// Addresses are stored as strings, which — unlike net.IP slices — are
// comparable and can therefore be used in a map key.
type sample struct {
	Chain   pb.Chain
	SrcAddr string
	DstAddr string
	SrcPort uint32 // uint32 because protobuf has no uint16 type
	DstPort uint32
	Iface   string
	Tcpi    limTCPInfo
}

// getLocalListenPorts collects the local (source) ports of every socket in the
// LISTEN state found in a sock_diag dump.
//
// The state is taken from the inet_diag message header — the same field Parse
// uses — rather than from TCPInfo, which may be nil; listeners reported without
// a TCPInfo attribute are therefore still counted.
func getLocalListenPorts(sockDiag []sockdiag.InetDiagTCPInfoResp) map[uint32]struct{} {
	localListenPorts := make(map[uint32]struct{})
	for _, s := range sockDiag {
		if s.InetDiagMsg.State != sockdiag.TCPListen {
			continue
		}
		localListenPorts[uint32(s.InetDiagMsg.ID.SourcePort)] = struct{}{}
	}
	zap.S().Debugf("LocalListenPorts: %v\n", localListenPorts)
	return localListenPorts
}

// isPortEph reports whether port p lies in the ephemeral (local) port range
// [ephMin, ephMax].
//
// Both bounds are inclusive. The bounds come from sysctl.GetLocalPortRange
// (presumably net.ipv4.ip_local_port_range), whose two values are the first
// and the last usable local port. The previous exclusive comparison wrongly
// treated the two boundary ports as non-ephemeral.
func isPortEph(p, ephMin, ephMax uint32) bool {
	return p >= ephMin && p <= ephMax
}

// tcpInfoChanged reports whether s is missing from oldTCP, i.e. whether this
// exact sample (TCP info included) was not seen before. A hit is logged at
// debug level.
func tcpInfoChanged(s sample, oldTCP TCPSet) bool {
	_, seen := oldTCP[s]
	if seen {
		zap.S().Debug("Found!")
	}
	return !seen
}

// PrefaceTCP bundles the inputs Parse needs to turn the TCP sockets of one
// network namespace into samples. FormNetSampleNSMap builds one per namespace.
type PrefaceTCP struct {
	// SockDiag is the sock_diag TCP dump of the namespace (FormNetSampleNSMap
	// requests it for AF_INET6).
	SockDiag []sockdiag.InetDiagTCPInfoResp
	// EphMin and EphMax bound the local ephemeral port range as returned by
	// sysctl.GetLocalPortRange; isPortEph decides how the bounds are applied.
	EphMin   uint32
	EphMax   uint32
	// Iface is the interface name assigned to samples by default.
	Iface    string
	// HostIPs holds host addresses as strings (from iputil.GetHostIPs); Parse
	// looks destination addresses up in it when classifying container traffic.
	HostIPs  map[string]struct{}
	// RootPid is the PID whose namespace was sampled; 1 selects the dom0
	// rules in Parse, any other value the container rules.
	RootPid  int
	// Neighs maps a destination address string to the interface name to report
	// for it; Parse consults it only when RootPid == 1.
	Neighs   map[string]string
}

// Parse turns the sock_diag dump in pt.SockDiag into a deduplicated list of
// TCP samples.
//
// Sockets in SYN_SENT state and connections whose destination is loopback or
// unspecified are skipped. For non-listening sockets, a source port inside the
// ephemeral range (see isPortEph) is replaced by ephPort, so connections that
// differ only in that port merge into one sample. Segment counters are kept
// only for ESTABLISHED and CLOSE_WAIT sockets that carry TCPInfo. Every sample
// gets a chain (classifyChain) and an interface name (resolveIface). The order
// of the returned slice is unspecified because samples are deduplicated
// through a map.
func (pt PrefaceTCP) Parse() []*pb.TCPSample {
	tcpSamples := make(map[sample]struct{})
	localListenPorts := getLocalListenPorts(pt.SockDiag)

	for _, sd := range pt.SockDiag {
		state := sd.InetDiagMsg.State
		if state == sockdiag.TCPSynSent {
			continue
		}

		src := sd.InetDiagMsg.ID.Source
		dst := sd.InetDiagMsg.ID.Destination
		// Filter out connections to localhost and sockets without a peer.
		if dst.IsLoopback() || dst.IsUnspecified() {
			continue
		}

		s := sample{
			SrcAddr: src.String(),
			DstAddr: dst.String(),
			SrcPort: uint32(sd.InetDiagMsg.ID.SourcePort),
			DstPort: uint32(sd.InetDiagMsg.ID.DestinationPort),
			Tcpi:    limTCPInfo{State: pb.TCPState(state)},
		}

		// Must be evaluated on the raw ports, before SrcPort is collapsed below.
		isDstEph := isPortEph(s.DstPort, pt.EphMin, pt.EphMax)
		if state != sockdiag.TCPListen && isPortEph(s.SrcPort, pt.EphMin, pt.EphMax) {
			s.SrcPort = ephPort
		}

		if (state == sockdiag.TCPEstablished || state == sockdiag.TCPCloseWait) && sd.TCPInfo != nil {
			s.Tcpi.SegsOut = sd.TCPInfo.SegsOut
			s.Tcpi.SegsIn = sd.TCPInfo.SegsIn
		}

		// The lookup uses the (possibly collapsed) SrcPort, as before.
		_, isLocalListen := localListenPorts[s.SrcPort]
		s.Chain = pt.classifyChain(s, isLocalListen, isDstEph)
		s.Iface = pt.resolveIface(src, dst, s.DstAddr)

		tcpSamples[s] = struct{}{}
	}

	resultTCPSamples := make([]*pb.TCPSample, 0, len(tcpSamples))
	for k := range tcpSamples {
		resultTCPSamples = append(resultTCPSamples, k.toProto())
	}
	return resultTCPSamples
}

// classifyChain picks the pb.Chain for sample s.
//
// isLocalListen tells whether s.SrcPort is a listening port of the namespace;
// isDstEph tells whether the remote port lies in the ephemeral range. A sample
// whose source port equals ephPort while the remote port is not ephemeral is
// the "client side" case (presumably the local end initiated the connection).
//
// dom0 (RootPid == 1): listener side -> INPUT, client side -> OUTPUT.
// Container: listener side -> OUTPUT, client side -> INPUT, in both cases only
// when the peer address is in HostIPs; otherwise FORWARD.
// Anything else is UNKNOWN.
func (pt PrefaceTCP) classifyChain(s sample, isLocalListen, isDstEph bool) pb.Chain {
	localIsClient := s.SrcPort == ephPort && !isDstEph

	if pt.RootPid == 1 { // dom0
		switch {
		case isLocalListen:
			return pb.Chain_INPUT
		case localIsClient:
			return pb.Chain_OUTPUT
		default:
			return pb.Chain_UNKNOWN
		}
	}

	// Container.
	_, peerIsHost := pt.HostIPs[s.DstAddr]
	switch {
	case isLocalListen && peerIsHost:
		return pb.Chain_OUTPUT
	case localIsClient && peerIsHost:
		return pb.Chain_INPUT
	case isLocalListen || localIsClient:
		return pb.Chain_FORWARD
	default:
		return pb.Chain_UNKNOWN
	}
}

// resolveIface picks the interface name reported for a connection from src to
// dst (dstAddr is dst in string form).
//
// Containers always report pt.Iface. In dom0, a connection whose source and
// destination addresses are equal reports pt.Iface; otherwise the destination
// is looked up in pt.Neighs, falling back to pt.Iface when it is not listed,
// or to "unknown" when no neighbours are known at all.
func (pt PrefaceTCP) resolveIface(src, dst net.IP, dstAddr string) string {
	if pt.RootPid != 1 || src.Equal(dst) {
		return pt.Iface
	}
	zap.S().Debug("Source and destination differ, search neigh")
	if len(pt.Neighs) == 0 {
		zap.S().Debug("No neighbours, set unknown")
		return "unknown"
	}
	if nIface, ok := pt.Neighs[dstAddr]; ok {
		return nIface
	}
	return pt.Iface
}

// toProto converts the map-key form of a sample into its protobuf message.
// Addresses are re-parsed with net.ParseIP, giving their 16-byte form.
func (s sample) toProto() *pb.TCPSample {
	return &pb.TCPSample{
		Iface:   s.Iface,
		Chain:   s.Chain,
		SrcAddr: net.ParseIP(s.SrcAddr),
		DstAddr: net.ParseIP(s.DstAddr),
		SrcPort: s.SrcPort,
		DstPort: s.DstPort,
		Tcpi: &pb.TCPInfo{
			SegsOut: s.Tcpi.SegsOut,
			SegsIn:  s.Tcpi.SegsIn,
			State:   s.Tcpi.State,
		},
	}
}

// FormInterfaceSamples lists the network interfaces visible to the calling
// thread together with their addresses.
//
// An error from net.Interfaces is returned as is. Interfaces whose addresses
// cannot be read are logged at debug level and omitted. Addresses are stored
// in 16-byte form (IPv4 as IPv4-mapped IPv6), matching what net.ParseIP
// produces.
func FormInterfaceSamples() ([]*pb.Interface, error) {
	var r []*pb.Interface
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, i := range ifaces {
		addresses, err := i.Addrs()
		if err != nil {
			zap.S().Debug(err)
			continue
		}
		var ips [][]byte
		for _, a := range addresses {
			var ip net.IP
			// Take the IP straight from the concrete type instead of re-parsing
			// its text form; a zoned address ("fe80::1%eth0") would not parse.
			switch v := a.(type) {
			case *net.IPNet:
				ip = v.IP
			case *net.IPAddr:
				ip = v.IP
			default:
				ip = net.ParseIP(strings.Split(a.String(), "/")[0])
			}
			ips = append(ips, ip.To16())
		}
		r = append(r, &pb.Interface{Name: i.Name, Ips: ips})
	}
	return r, nil
}

// getPeerIndex returns the peer interface index of the first veth link in the
// current network namespace.
//
// It returns -1 when links cannot be listed, when no veth link exists, or when
// the peer index cannot be determined. Previously, "no veth" returned 0, which
// only failed later inside netlink.LinkByIndex. Only the first veth link is
// considered. FormNetSampleNSMap resolves the returned index with
// netlink.LinkByIndex after leaving the namespace.
func getPeerIndex() int {
	links, err := netlink.LinkList()
	if err != nil {
		zap.S().Debugf("could not list links: %v", err)
		return -1
	}
	for _, l := range links {
		veth, ok := l.(*netlink.Veth)
		if !ok {
			continue
		}
		idx, err := netlink.VethPeerIndex(veth)
		if err != nil {
			zap.S().Debugf("could not get peer index of %s: %v", veth.Attrs().Name, err)
			return -1
		}
		return idx
	}
	return -1
}

// getActiveIface returns the name of the first eth* interface, in the order
// returned by filepath.Glob, whose /sys/class/net/<name>/operstate reads "up".
// It returns "unknown" if none does.
//
// Errors (bad glob, unreadable operstate file) are logged at debug level.
// Unreadable files are skipped.
func getActiveIface() string {
	matches, err := filepath.Glob("/sys/class/net/eth*/operstate")
	if err != nil {
		zap.S().Debug(err)
	}
	zap.S().Debugf("Matches: %#v\n", matches)
	for _, m := range matches {
		d, err := ioutil.ReadFile(m)
		if err != nil {
			zap.S().Debug(err)
			continue
		}
		if strings.TrimSpace(string(d)) == "up" {
			// m is /sys/class/net/<iface>/operstate; its parent dir is the name.
			return filepath.Base(filepath.Dir(m))
		}
	}
	return "unknown"
}

// FormNetSampleNSMap builds one NetNSSample per entry of notInherited, under
// the same key.
//
// For each entry it switches into the network namespace of the given PID and
// collects four things there: the interface list, the AF_INET6 sock_diag TCP
// dump, the local port range, and the veth peer index. Back outside the
// namespace, it picks the interface name for the TCP samples:
//   - for PID 1, the active eth* interface;
//   - otherwise, the name of the veth peer;
//   - "unknown" when the peer cannot be resolved.
//
// The dump is then parsed with PrefaceTCP.Parse. neighs is passed through to
// Parse unchanged. Namespaces that cannot be entered or queried are logged at
// debug level and left out of the result.
func FormNetSampleNSMap(notInherited map[string]int, neighs map[string]string) map[string]*pb.NetNSSample {
	result := make(map[string]*pb.NetNSSample)

	hostIPs, err := iputil.GetHostIPs()
	if err != nil {
		// Best effort: carry on with whatever GetHostIPs returned; Parse only
		// performs lookups in this map.
		zap.S().Debugf("Could not get host IPs: %v", err)
	}

	for name, pid := range notInherited {
		var (
			ifaces         []*pb.Interface
			diag           []sockdiag.InetDiagTCPInfoResp
			portLo, portHi uint32
			peerIdx        int
		)
		// Everything here runs inside the target network namespace.
		collect := func() (err error) {
			if ifaces, err = FormInterfaceSamples(); err != nil {
				return err
			}
			if diag, err = sockdiag.SocketDiagTCPInfo(unix.AF_INET6); err != nil {
				return err
			}
			portLo, portHi = sysctl.GetLocalPortRange()
			peerIdx = getPeerIndex()
			return nil
		}
		if err := nsenter.WithSetNetNS(pid, collect); err != nil {
			zap.S().Debugf("unable to to job with net namespace of PID %d: %s", pid, err)
			continue
		}

		var iface string
		switch {
		case pid == 1:
			iface = getActiveIface()
		case peerIdx < 0:
			iface = "unknown"
		default:
			if link, err := netlink.LinkByIndex(peerIdx); err == nil {
				iface = link.Attrs().Name
			} else {
				iface = "unknown"
			}
		}

		pt := PrefaceTCP{
			SockDiag: diag,
			EphMin:   portLo,
			EphMax:   portHi,
			Iface:    iface,
			HostIPs:  hostIPs,
			RootPid:  pid,
			Neighs:   neighs,
		}
		result[name] = &pb.NetNSSample{
			Tcp:   pt.Parse(),
			Iface: ifaces,
		}
	}
	return result
}

// Filter removes, per network namespace, the TCP samples that already appear
// in oldFullSample, so that only new or changed samples remain.
//
// Samples are compared through convertToSampleStruct. Namespaces that are
// absent from oldFullSample keep all of their samples. A nil oldFullSample is
// treated as having no previous samples; previously this panicked on the nil
// pointer. Interface lists are passed through unchanged.
func Filter(ns map[string]*pb.NetNSSample, oldFullSample *pb.NetSample) map[string]*pb.NetNSSample {
	filtered := make(map[string]*pb.NetNSSample, len(ns))

	for newName, newNetNSSample := range ns {
		zap.S().Debugf("Filtering %s", newName)
		filteredTCP := newNetNSSample.Tcp
		if oldFullSample != nil {
			if oldNetNSSample, ok := oldFullSample.Ns[newName]; ok {
				setOfOldTCPSamples := formSetOfSamples(oldNetNSSample.Tcp)
				filteredTCP = []*pb.TCPSample{}
				for _, newTCPSample := range newNetNSSample.Tcp {
					if setOfOldTCPSamples.HasProto(newTCPSample) {
						zap.S().Debugf("Found: %s", newTCPSample)
						continue
					}
					filteredTCP = append(filteredTCP, newTCPSample)
				}
			}
		}
		filtered[newName] = &pb.NetNSSample{Iface: newNetNSSample.Iface, Tcp: filteredTCP}
	}
	return filtered
}

// formSetOfSamples converts proto TCP samples into a TCPSet for constant-time
// membership checks. Samples that convert to the same key (see
// convertToSampleStruct) collapse into a single entry.
func formSetOfSamples(samples []*pb.TCPSample) TCPSet {
	// The loop variable is named s: it used to be named `sample`, which
	// shadowed the package type of the same name.
	set := make(TCPSet, len(samples))
	for _, s := range samples {
		set[convertToSampleStruct(s)] = struct{}{}
	}
	return set
}

// convertToSampleStruct maps a proto TCP sample back to the comparable sample
// key used by TCPSet.
//
// Only addresses (as net.IP strings), ports and TCP info are copied; a nil
// Tcpi yields a zero limTCPInfo. Chain and Iface stay at their zero values, so
// samples that differ only in those fields compare equal.
// NOTE(review): Parse does fill Chain and Iface, so a sample built by Parse
// will never equal one built here — confirm this asymmetry is intended.
func convertToSampleStruct(s *pb.TCPSample) sample {
	var info limTCPInfo
	if t := s.Tcpi; t != nil {
		info.SegsIn = t.SegsIn
		info.SegsOut = t.SegsOut
		info.State = t.State
	}
	return sample{
		SrcAddr: net.IP(s.SrcAddr).String(),
		DstAddr: net.IP(s.DstAddr).String(),
		SrcPort: s.SrcPort,
		DstPort: s.DstPort,
		Tcpi:    info,
	}
}

// HasProto reports whether the proto sample s, after conversion with
// convertToSampleStruct, is a member of the set.
func (t TCPSet) HasProto(s *pb.TCPSample) bool {
	_, ok := t[convertToSampleStruct(s)]
	return ok
}
