package main

import (
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"golang.org/x/net/dns/dnsmessage"
	"golang.org/x/sys/unix"

	"a.yandex-team.ru/solomon/tools/discovery/internal/metrics"
)

// Message size limits used by the servers below.
//
// udpMaxSize is the receive buffer size for a UDP request and the cap on a UDP
// response; 512 bytes is the classic RFC 1035 limit for DNS over UDP.
// tcpMaxSize bounds the TCP read buffer and the TCP response body.
// Responses over the limit are sent truncated with the TC bit set (see serveData).
var (
	udpMaxSize = 512
	tcpMaxSize = 2048
)

// hexDigit maps a nibble value (0-15) to its lowercase hexadecimal character;
// arpa uses it to build ip6.arpa reverse names.
const hexDigit = "0123456789abcdef"

// ==========================================================================================

// UDPServer6 serves request/response exchanges on a single UDP/IPv6 socket.
// Open the socket with Start, serve with Run and tear down with Stop.
type UDPServer6 struct {
	// Served counts requests passed to the serve function; updated with
	// sync/atomic. It must stay the first field: only the first word of an
	// allocated struct is guaranteed 64-bit aligned on 386, ARM and 32-bit
	// MIPS, where a misaligned 64-bit atomic operation panics.
	Served int64
	// InFlight is the number of requests currently being handled (atomic).
	InFlight int32
	s        *net.UDPConn
	wg       sync.WaitGroup
}

// Start opens the UDP/IPv6 socket on addr.
//
// The socket options are applied in the net.ListenConfig Control hook, which
// runs before bind(2). SO_REUSEADDR/SO_REUSEPORT only influence whether the
// port can be shared when they are set before bind. IPV6_RECVPKTINFO makes the
// kernel deliver each request's destination address as ancillary data, which
// Run replays so the reply leaves from the address the request was sent to.
//
// Any setsockopt failure is returned and no socket is left open.
func (u *UDPServer6) Start(addr string) error {
	lc := net.ListenConfig{
		Control: func(_, _ string, c syscall.RawConn) error {
			var optErr error
			ctlErr := c.Control(func(fd uintptr) {
				opts := [...]struct{ level, name int }{
					{unix.SOL_SOCKET, unix.SO_REUSEADDR},
					{unix.SOL_SOCKET, unix.SO_REUSEPORT},
					{unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO},
				}
				for _, o := range opts {
					if optErr = unix.SetsockoptInt(int(fd), o.level, o.name, 1); optErr != nil {
						return
					}
				}
			})
			// Report the setsockopt error too, not only the error from
			// obtaining the descriptor.
			if ctlErr != nil {
				return ctlErr
			}
			return optErr
		},
	}

	pc, err := lc.ListenPacket(context.Background(), "udp6", addr)
	if err != nil {
		return err
	}
	conn, ok := pc.(*net.UDPConn)
	if !ok {
		_ = pc.Close()
		return fmt.Errorf("unexpected connection type %T for udp6", pc)
	}
	u.s = conn
	return nil
}

// Run reads datagrams until reading fails (for example after Stop closes the
// socket) and serves each one in its own goroutine. servFunc turns a request
// into a reply; a nil reply means nothing is sent. logFunc receives
// diagnostics together with a verbosity level.
//
// Run and every handler goroutine are registered in u.wg, so Stop also waits
// for requests that are still being processed.
func (u *UDPServer6) Run(servFunc func([]byte) ([]byte, error), logFunc func(lvl int, format string, v ...interface{})) {
	// Room for one IPV6_PKTINFO control message:
	// CMSG_SPACE(sizeof(struct in6_pktinfo)) = 16-byte cmsghdr + 20-byte
	// payload padded to 24 = 40 bytes on 64-bit Linux (less on 32-bit).
	const oobSize = 40

	u.wg.Add(1)
	defer u.wg.Done()
	for {
		buf := make([]byte, udpMaxSize)
		oob := make([]byte, oobSize)

		n, oobn, _, srcAddr, err := u.s.ReadMsgUDP(buf, oob)
		if err != nil {
			logFunc(0, "cannot read udp: %v", err)
			return
		}
		// Only the first oobn bytes hold control data. A zeroed, unused tail
		// is not a valid control message and must not be replayed to sendmsg.
		req, ctl := buf[:n], oob[:oobn]

		u.wg.Add(1)
		go func() {
			defer u.wg.Done()
			atomic.AddInt32(&u.InFlight, 1)
			defer atomic.AddInt32(&u.InFlight, -1)

			msg, err := servFunc(req)
			atomic.AddInt64(&u.Served, 1)
			if err != nil {
				logFunc(0, "request error, %v", err)
			}
			if msg == nil {
				return
			}
			// https://www.ietf.org/rfc/rfc3542.txt Paragraph 6, Note, Page 28
			//
			// Some UDP servers want to respond to client
			// requests by sending their reply out the same interface on which the
			// request was received and with the source IPv6 address of the reply
			// equal to the destination IPv6 address of the request. To do this the
			// application can enable just the IPV6_RECVPKTINFO socket option and
			// then use the received control information from recvmsg() as the
			// outgoing control information for sendmsg(). The application need not
			// examine or modify the in6_pktinfo structure at all.
			written, _, err := u.s.WriteMsgUDP(msg, ctl, srcAddr)
			if err != nil {
				logFunc(0, "failed to write udp, %v", err)
			} else if written != len(msg) {
				logFunc(0, "failed to write udp, sent %d out of %d bytes", written, len(msg))
			}
		}()
	}
}

// Stop closes the socket, which makes the pending read in Run fail, and then
// blocks until everything Run registered in u.wg has finished. It returns the
// error from closing the socket.
//
// NOTE(review): Run adds itself to u.wg only once its goroutine is running, so
// a Stop that races with a just-launched Run may not wait for it.
func (u *UDPServer6) Stop() error {
	// The deferred Wait runs after Close has produced the return value.
	defer u.wg.Wait()
	return u.s.Close()
}

// ==========================================================================================

// TCPServer6 serves one request/response exchange per TCP/IPv6 connection.
// Open the listener with Start, serve with Run and tear down with Stop.
type TCPServer6 struct {
	// Served counts requests passed to the serve function; updated with
	// sync/atomic. It must stay the first field: only the first word of an
	// allocated struct is guaranteed 64-bit aligned on 386, ARM and 32-bit
	// MIPS, where a misaligned 64-bit atomic operation panics.
	Served int64
	// Timeout is the per-connection read/write deadline set by Start.
	Timeout time.Duration
	// InFlight is the number of connections currently being handled (atomic).
	InFlight int32
	s        *net.TCPListener
	wg       sync.WaitGroup
}

// Start records timeout as the per-connection deadline and opens the
// TCP/IPv6 listener on addr.
//
// The socket options are applied in the net.ListenConfig Control hook, which
// runs before bind(2). SO_REUSEADDR/SO_REUSEPORT only influence whether the
// port can be shared when they are set before bind.
//
// Any setsockopt failure is returned and no listener is left open.
func (t *TCPServer6) Start(addr string, timeout time.Duration) error {
	t.Timeout = timeout

	lc := net.ListenConfig{
		Control: func(_, _ string, c syscall.RawConn) error {
			var optErr error
			ctlErr := c.Control(func(fd uintptr) {
				// NOTE(review): IPV6_RECVPKTINFO mirrors the UDP server; Run does
				// not consume ancillary data on TCP connections.
				opts := [...]struct{ level, name int }{
					{unix.SOL_SOCKET, unix.SO_REUSEADDR},
					{unix.SOL_SOCKET, unix.SO_REUSEPORT},
					{unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO},
				}
				for _, o := range opts {
					if optErr = unix.SetsockoptInt(int(fd), o.level, o.name, 1); optErr != nil {
						return
					}
				}
			})
			// Report the setsockopt error too, not only the error from
			// obtaining the descriptor.
			if ctlErr != nil {
				return ctlErr
			}
			return optErr
		},
	}

	ln, err := lc.Listen(context.Background(), "tcp6", addr)
	if err != nil {
		return err
	}
	tl, ok := ln.(*net.TCPListener)
	if !ok {
		_ = ln.Close()
		return fmt.Errorf("unexpected listener type %T for tcp6", ln)
	}
	t.s = tl
	return nil
}

// Run accepts connections until accepting fails (for example after Stop closes
// the listener) and serves each connection in its own goroutine.
//
// Each connection carries one DNS-over-TCP frame: a two-byte big-endian length
// followed by that many bytes. Run reads the complete frame, because a single
// Read on a stream may return only part of it, and passes it to servFunc with
// the length prefix included. A non-nil reply is written back and the
// connection is closed. logFunc receives diagnostics together with a
// verbosity level.
//
// Run and every handler goroutine are registered in t.wg, so Stop also waits
// for connections that are still being served.
func (t *TCPServer6) Run(servFunc func([]byte) ([]byte, error), logFunc func(lvl int, format string, v ...interface{})) {
	t.wg.Add(1)
	defer t.wg.Done()
	for {
		conn, err := t.s.AcceptTCP()
		if err != nil {
			logFunc(0, "cannot accept tcp: %v", err)
			return
		}

		t.wg.Add(1)
		go func() {
			defer t.wg.Done()
			atomic.AddInt32(&t.InFlight, 1)
			defer func() {
				if err := conn.Close(); err != nil {
					logFunc(1, "%v", err)
				}
				atomic.AddInt32(&t.InFlight, -1)
			}()

			logFunc(2, "tcp from %v", conn.RemoteAddr())
			// A zero Timeout would set an already-expired deadline and fail
			// every read, so treat it as "no deadline".
			if t.Timeout > 0 {
				if err := conn.SetDeadline(time.Now().Add(t.Timeout)); err != nil {
					logFunc(0, "failed to set tcp deadline, %v", err)
				}
			}

			buf := make([]byte, tcpMaxSize)
			n, err := io.ReadFull(conn, buf[:2])
			if err != nil {
				// mask balancer pings
				lvl := 0
				if n <= 0 {
					lvl = 2
				}
				logFunc(lvl, "failed to read (bytes=%d): %v", n, err)
				return
			}
			size := int(buf[0])<<8 | int(buf[1])
			if 2+size > len(buf) {
				logFunc(0, "tcp message too large, size=%d max=%d", size, len(buf)-2)
				return
			}
			m, err := io.ReadFull(conn, buf[2:2+size])
			n += m
			if err != nil {
				logFunc(0, "failed to read (bytes=%d): %v", n, err)
				return
			}

			msg, err := servFunc(buf[:n])
			atomic.AddInt64(&t.Served, 1)
			if err != nil {
				logFunc(0, "request error, %v", err)
			}
			if msg == nil {
				return
			}
			if written, err := conn.Write(msg); err != nil {
				logFunc(0, "failed to write tcp, %v", err)
			} else if written != len(msg) {
				logFunc(0, "failed to write tcp, sent %d out of %d bytes", written, len(msg))
			}
		}()
	}
}

// Stop closes the listener, which makes the pending accept in Run fail, and
// then blocks until everything Run registered in t.wg has finished. It returns
// the error from closing the listener.
//
// NOTE(review): Run adds itself to t.wg only once its goroutine is running, so
// a Stop that races with a just-launched Run may not wait for it.
func (t *TCPServer6) Stop() error {
	// The deferred Wait runs after Close has produced the return value.
	defer t.wg.Wait()
	return t.s.Close()
}

// ==========================================================================================

// from https://cs.opensource.google/go/go/+/master:src/internal/itoa/itoa.go;drc=master;l=18
// uitoa returns the base-10 representation of val without using fmt or strconv.
// Adapted from https://cs.opensource.google/go/go/+/master:src/internal/itoa/itoa.go
func uitoa(val byte) string {
	// A byte has at most three decimal digits; fill them right to left.
	var digits [3]byte
	pos := len(digits)
	for {
		pos--
		digits[pos] = '0' + val%10
		val /= 10
		if val == 0 {
			break
		}
	}
	return string(digits[pos:])
}

// from https://cs.opensource.google/go/go/+/master:src/net/dnsclient.go;drc=master;l=32
// arpa returns the fully qualified reverse-lookup name for ip:
// "d.c.b.a.in-addr.arpa." when ip is an IPv4 address (4-byte or v4-mapped
// 16-byte form), otherwise the nibble-reversed "x.x. ... .ip6.arpa." form
// built from the bytes of ip.
//
// Adapted from https://cs.opensource.google/go/go/+/master:src/net/dnsclient.go;drc=master;l=32
func arpa(ip net.IP) string {
	// Index the 4-byte result of To4, not ip[12..15]: To4 also accepts a
	// 4-byte net.IP, for which ip[15] would be out of range.
	if v4 := ip.To4(); v4 != nil {
		return uitoa(v4[3]) + "." + uitoa(v4[2]) + "." + uitoa(v4[1]) + "." + uitoa(v4[0]) + ".in-addr.arpa."
	}
	// Must be IPv6
	buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
	// Add it, in reverse, to the buffer: low nibble first, then high nibble.
	for i := len(ip) - 1; i >= 0; i-- {
		v := ip[i]
		buf = append(buf, hexDigit[v&0xF], '.', hexDigit[v>>4], '.')
	}
	// Append "ip6.arpa." and return (buf already has the final .)
	buf = append(buf, "ip6.arpa."...)
	return string(buf)
}

// ==========================================================================================

// DNSServer answers AAAA and PTR queries, over both UDP and TCP, from an
// in-memory record table that SetData replaces wholesale. Construct it with
// NewDNSServer, which starts serving immediately, and stop it with Shutdown.
//
// Fields:
//   - LogPrefix is prepended to every message written by the log method.
//   - Addr is the listen address used by both the UDP and the TCP server.
//   - VerboseLevel: a message is logged only if its level is <= VerboseLevel.
//   - cronWg tracks the cron goroutine so Shutdown can wait for it.
//   - dataStore holds the current record table; readers Load it without locking.
//   - metrics holds the latest snapshot produced by updateMetrics.
//   - stopChan is closed by Shutdown to stop the cron goroutine.
//   - metricsUpdateInterval is the period at which cron refreshes metrics.
//   - timeout is passed to the TCP server as its per-connection deadline.
//   - mutex serializes Shutdown.
type DNSServer struct {
	LogPrefix             string
	Addr                  string
	VerboseLevel          int
	cronWg                sync.WaitGroup
	dataStore             atomic.Value // map[dnsmessage.Name]*dnsmessage.Resource
	metrics               atomic.Value // *metrics.Metrics
	stopChan              chan struct{}
	metricsUpdateInterval time.Duration
	timeout               time.Duration
	mutex                 sync.Mutex
	tcpServer             *TCPServer6
	udpServer             *UDPServer6
}

// NewDNSServer creates a DNSServer bound to dnsAddr and starts serving.
//
// It starts with an empty record table, opens and runs the UDP and TCP servers
// (see prepare), and launches the periodic metrics task. verboseLevel sets the
// log verbosity. If the servers cannot be started, the error is returned and
// no server is returned.
func NewDNSServer(dnsAddr string, verboseLevel int) (*DNSServer, error) {
	const (
		defaultMetricsInterval = 5 * time.Second
		defaultTCPTimeout      = 2 * time.Second
	)

	srv := &DNSServer{
		LogPrefix:             "[dns] ",
		Addr:                  dnsAddr,
		VerboseLevel:          verboseLevel,
		stopChan:              make(chan struct{}),
		metricsUpdateInterval: defaultMetricsInterval,
		timeout:               defaultTCPTimeout,
		tcpServer:             new(TCPServer6),
		udpServer:             new(UDPServer6),
	}
	// Requests may arrive as soon as prepare returns, so the store must
	// already hold a map.
	srv.dataStore.Store(map[dnsmessage.Name]*dnsmessage.Resource{})

	if err := srv.prepare(); err != nil {
		return nil, err
	}

	srv.cronWg.Add(1)
	go srv.cron()

	srv.log(0, nil, "started")
	return srv, nil
}

// log writes a Printf-style message through the standard logger, prefixed
// with d.LogPrefix, when lvl does not exceed d.VerboseLevel. If ts is non-nil,
// the time elapsed since *ts is appended as ", <duration>".
func (d *DNSServer) log(lvl int, ts *time.Time, format string, v ...interface{}) {
	if lvl > d.VerboseLevel {
		return
	}
	var elapsed string
	if ts != nil {
		elapsed = ", " + time.Since(*ts).String()
	}
	log.Printf(d.LogPrefix+format+elapsed, v...)
}

// ==========================================================================================

// cron refreshes the metrics snapshot every d.metricsUpdateInterval until
// d.stopChan is closed, then marks itself done in d.cronWg.
func (d *DNSServer) cron() {
	defer d.cronWg.Done()

	ticker := time.NewTicker(d.metricsUpdateInterval)
	defer ticker.Stop()

	for {
		select {
		case <-d.stopChan:
			d.log(1, nil, "exiting cron task")
			return
		case <-ticker.C:
			started := time.Now()
			d.updateMetrics()
			d.log(1, &started, "metrics updated")
		}
	}
}

// prepare opens the UDP and TCP servers on d.Addr and starts serving on both
// in background goroutines, routing requests to serveData. It then publishes
// an initial metrics snapshot. If the TCP server cannot be started, the
// already-open UDP server is stopped again and the TCP error is returned.
//
// NOTE(review): each Run adds itself to its server's WaitGroup from inside its
// own goroutine, so a Shutdown that races with startup may not wait for it.
func (d *DNSServer) prepare() error {
	if err := d.udpServer.Start(d.Addr); err != nil {
		return err
	}
	if tcpErr := d.tcpServer.Start(d.Addr, d.timeout); tcpErr != nil {
		if udpErr := d.udpServer.Stop(); udpErr != nil {
			d.log(0, nil, "failure on udp server stop, %v", udpErr)
		}
		return tcpErr
	}

	logFn := func(lvl int, format string, v ...interface{}) {
		d.log(lvl, nil, format, v...)
	}
	serveUDP := func(data []byte) ([]byte, error) { return d.serveData(data, true) }
	serveTCP := func(data []byte) ([]byte, error) { return d.serveData(data, false) }

	go d.udpServer.Run(serveUDP, logFn)
	go d.tcpServer.Run(serveTCP, logFn)

	d.updateMetrics()
	d.log(1, nil, "metrics updated")

	return nil
}

// ==========================================================================================

// Shutdown stops the TCP and UDP servers and the metrics cron task and waits
// for them to finish. Errors from closing either server are combined into one
// error.
//
// Shutdown is idempotent. Calls after the first return nil without doing
// anything, because closing stopChan a second time would panic.
func (d *DNSServer) Shutdown() error {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	// stopChan is only ever closed here, under the mutex, so a closed channel
	// means a previous Shutdown already ran.
	select {
	case <-d.stopChan:
		return nil
	default:
	}

	d.log(1, nil, "begin shutdown")
	defer d.log(1, nil, "stopped")

	close(d.stopChan)
	errTCP := d.tcpServer.Stop()
	errUDP := d.udpServer.Stop()
	d.cronWg.Wait()

	if errTCP != nil || errUDP != nil {
		return fmt.Errorf("tcp=%v, udp=%v", errTCP, errUDP)
	}
	return nil
}

// ==========================================================================================

// updateMetrics builds a new metrics snapshot from the servers' atomic
// counters and publishes it for GetMetrics. "dns.inflight" is the sum of
// requests in progress on both servers, added as a gauge. "dns.served" is the
// total of requests handled by both servers, added as a rate.
func (d *DNSServer) updateMetrics() {
	snapshot := metrics.NewMetrics()

	tcp, udp := d.tcpServer, d.udpServer
	inFlight := atomic.LoadInt32(&tcp.InFlight) + atomic.LoadInt32(&udp.InFlight)
	served := atomic.LoadInt64(&tcp.Served) + atomic.LoadInt64(&udp.Served)

	snapshot.AddIGauge(nil, "dns.inflight", uint64(inFlight))
	snapshot.AddRate(nil, "dns.served", served)

	d.metrics.Store(snapshot)
}

// GetMetrics returns the most recently published metrics snapshot. It panics
// if no snapshot has been stored yet; prepare stores one before NewDNSServer
// returns.
func (d *DNSServer) GetMetrics() *metrics.Metrics {
	current := d.metrics.Load()
	return current.(*metrics.Metrics)
}

// ==========================================================================================

// SetDataClean is SetData for callers that do not handle errors: it logs the
// update and logs any failure instead of returning it.
func (d *DNSServer) SetDataClean(data map[string]net.IP, ttl uint32) {
	d.log(0, nil, "setting new data, size=%d TTL=%d", len(data), ttl)
	err := d.SetData(data, ttl)
	if err == nil {
		return
	}
	d.log(0, nil, "failed to set data: %v", err)
}

// SetData replaces the served records with ones built from data, which maps
// host names to addresses. For each entry it creates an AAAA record for the
// name and a PTR record for the address's reverse name (see arpa), both of
// class IN with the given TTL. A name is made fully qualified by ensuring it
// ends with exactly one dot.
//
// The new table is published in a single atomic store. On error nothing is
// published and the current records stay in effect. An error is returned for
// an address that is not a valid IP (nil or of a bad length), or for a name or
// reverse name that dnsmessage.NewName rejects.
//
// NOTE(review): if several names share one address, only one PTR record is
// kept, and which one depends on map iteration order; confirm this is intended.
func (d *DNSServer) SetData(data map[string]net.IP, ttl uint32) error {
	nm := make(map[dnsmessage.Name]*dnsmessage.Resource, len(data)*2)

	for name, ip := range data {
		// To16 turns the 4-byte IPv4 form into its 16-byte v4-mapped form and
		// returns nil for malformed input. Copying a 4-byte slice straight
		// into the 16-byte AAAA field would silently publish a wrong address.
		ip16 := ip.To16()
		if ip16 == nil {
			return fmt.Errorf("invalid IP address %v for %q", ip, name)
		}

		name = strings.TrimRight(name, ".") + "."

		nName, err := dnsmessage.NewName(name)
		if err != nil {
			return err
		}
		ipName, err := dnsmessage.NewName(arpa(ip16))
		if err != nil {
			return err
		}

		aaaa := new(dnsmessage.AAAAResource)
		copy(aaaa.AAAA[:], ip16)
		nm[nName] = &dnsmessage.Resource{
			Header: dnsmessage.ResourceHeader{
				Name:  nName,
				Type:  dnsmessage.TypeAAAA,
				Class: dnsmessage.ClassINET,
				TTL:   ttl,
			},
			Body: aaaa,
		}

		nm[ipName] = &dnsmessage.Resource{
			Header: dnsmessage.ResourceHeader{
				Name:  ipName,
				Type:  dnsmessage.TypePTR,
				Class: dnsmessage.ClassINET,
				TTL:   ttl,
			},
			Body: &dnsmessage.PTRResource{
				PTR: nName,
			},
		}
	}
	d.dataStore.Store(nm)
	return nil
}

// ==========================================================================================

// serveData parses one raw DNS request and builds the raw response.
//
// For TCP (isUDP == false), data must begin with the two-byte big-endian
// length prefix of DNS over TCP, and the returned message carries one as well.
// For UDP, neither does. Questions are answered from the record table in
// d.dataStore, with one snapshot used for the whole request:
//   - an IN-class AAAA or PTR question for a stored name gets the record if
//     its type matches the question; otherwise the answer is empty (NOERROR);
//   - an AAAA or PTR question for an unknown name gets REFUSED;
//   - any other type gets REFUSED and a non-nil error;
//   - a non-IN class gets NOTIMP and a non-nil error;
//   - a response, or a non-zero opcode, gets FORMERR and a non-nil error.
//
// Processing stops at the first question that is refused or rejected. A
// response can be returned together with a non-nil error; the servers in this
// file send it anyway.
//
// A response larger than udpMaxSize (UDP) or tcpMaxSize (TCP) is replaced by
// a well-formed message with the TC bit set and no answers.
func (d *DNSServer) serveData(data []byte, isUDP bool) ([]byte, error) {
	defer d.log(2, nil, "request processed")

	if !isUDP {
		// Check before indexing: a short read must not panic the process.
		if len(data) < 2 {
			return nil, fmt.Errorf("tcp message too short, got=%d bytes", len(data))
		}
		l := (int(data[0]) << 8) + int(data[1]) + 2
		if l != len(data) {
			return nil, fmt.Errorf("bad tcp message size bytes, want=%d got=%d max=%d", l, len(data), tcpMaxSize)
		}
		data = data[2:]
	}

	var req dnsmessage.Message
	if err := req.Unpack(data); err != nil {
		return nil, err
	}

	// The question section is echoed back; per the original measurement,
	// dropping it would save about 500 ns per request. RD is copied from the
	// query as RFC 1035 section 4.1.1 requires.
	resp := dnsmessage.Message{
		Header: dnsmessage.Header{
			ID:               req.Header.ID,
			Response:         true,
			RecursionDesired: req.Header.RecursionDesired,
			RCode:            dnsmessage.RCodeSuccess,
		},
		Questions: req.Questions,
	}

	var err error
	if req.Header.Response || req.Header.OpCode != 0 {
		resp.Header.RCode = dnsmessage.RCodeFormatError
		err = fmt.Errorf("bad request response=%v opcode=%s", req.Header.Response, req.Header.OpCode.GoString())
	} else {
		// One snapshot for the whole request, so every question sees the same data.
		records := d.dataStore.Load().(map[dnsmessage.Name]*dnsmessage.Resource)
		for _, q := range req.Questions {
			if q.Class != dnsmessage.ClassINET {
				resp.Header.RCode = dnsmessage.RCodeNotImplemented
				err = fmt.Errorf("bad class=%s in request for %s", q.Class.String(), q.Name.String())
				break
			}
			if q.Type != dnsmessage.TypeAAAA && q.Type != dnsmessage.TypePTR {
				resp.Header.RCode = dnsmessage.RCodeRefused
				err = fmt.Errorf("bad type=%s in request for %s", q.Type.String(), q.Name.String())
				break
			}
			r, ok := records[q.Name]
			if !ok {
				// NameError == NXDOMAIN, really should be Refused
				resp.Header.RCode = dnsmessage.RCodeRefused
				break
			}
			// Never answer with a record of another type, such as a PTR for an
			// AAAA question. The name exists, so this is NOERROR with no
			// answer (NODATA).
			if resourceBodyType(r) != q.Type {
				continue
			}
			resp.Answers = append(resp.Answers, *r)
		}
	}

	limit := tcpMaxSize
	if isUDP {
		limit = udpMaxSize
	}
	// pack serializes resp after two reserved bytes for the TCP length prefix.
	pack := func() ([]byte, error) {
		out, errPack := resp.AppendPack(make([]byte, 2, udpMaxSize+2))
		if errPack != nil {
			return nil, fmt.Errorf("failed to pack self created message, %s, %v", resp.GoString(), errPack)
		}
		return out, nil
	}

	buf, errPack := pack()
	if errPack != nil {
		return nil, errPack
	}
	if len(buf)-2 > limit {
		// Cutting the packed bytes at the limit would split a record and
		// produce a malformed message. Send a well-formed truncated reply
		// instead: TC set and no answers.
		resp.Header.Truncated = true
		resp.Answers = nil
		if buf, errPack = pack(); errPack != nil {
			return nil, errPack
		}
		if len(buf)-2 > limit {
			// Even the echoed questions do not fit: send the bare header.
			resp.Questions = nil
			if buf, errPack = pack(); errPack != nil {
				return nil, errPack
			}
		}
	}

	if isUDP {
		return buf[2:], err
	}
	size := len(buf) - 2
	buf[0], buf[1] = byte(size>>8), byte(size)
	return buf, err
}

// resourceBodyType reports the DNS type of r, derived from its body so that it
// does not depend on the stored header carrying a Type.
func resourceBodyType(r *dnsmessage.Resource) dnsmessage.Type {
	switch r.Body.(type) {
	case *dnsmessage.AAAAResource:
		return dnsmessage.TypeAAAA
	case *dnsmessage.PTRResource:
		return dnsmessage.TypePTR
	}
	return r.Header.Type
}
