package ipdc

import (
	"bytes"
	"compress/gzip"
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"sort"
	"strings"
	"time"

	"a.yandex-team.ru/solomon/libs/go/cache"
	"a.yandex-team.ru/solomon/libs/go/uhttp"
)

// ==========================================================================================

// NetRecord is one network from the networklist-perdc table, stored as an
// inclusive address range in 16-byte form. The single-letter JSON keys
// presumably keep cache dumps small — NOTE(review): confirm against the
// cache's encoding before renaming them.
type NetRecord struct {
	// IP6Min is the first address of the network (built by minIP6).
	IP6Min   net.IP       `json:"m"`
	// IP6Max is the last address of the network (built by maxIP6).
	IP6Max   net.IP       `json:"M"`
	// Pref6Len is the prefix length counted over the 16-byte address;
	// IPv4 masks are offset by 96 when the table is parsed.
	Pref6Len int          `json:"p"`
	// Children are the networks nested inside this one, appended in
	// lessNet6 order while the tree is built.
	Children []*NetRecord `json:"c,omitempty"`
	// Location is the DC code, one of the dcNames values.
	Location string       `json:"l"`
}

// dcNames maps a substring of a table row's DC description to the short DC
// code stored in NetRecord.Location. A row whose description contains none
// of these keys is skipped. The keys are Russian place names: Sasovo,
// Vladimir, Mäntsälä, Ivanteevka and Mytishchi.
var dcNames = map[string]string{
	"Сасово":     "sas",
	"Владимир":   "vla",
	"Мянтсяля":   "man",
	"Ивантеевка": "iva",
	"Мытищи":     "myt",
}

// ==========================================================================================

// strToMask parses the decimal prefix length at the start of s.
// At most three leading digits are consumed and anything after them is
// ignored. The boolean is false when s does not begin with a digit.
func strToMask(s string) (int, bool) {
	mask, digits := 0, 0
	for digits < 3 && digits < len(s) {
		c := s[digits]
		if c < '0' || c > '9' {
			break
		}
		mask = 10*mask + int(c-'0')
		digits++
	}
	return mask, digits > 0
}

// minIP6 returns the first address of the network that contains ip and has
// a prefix of prefLen bits: the leading prefLen bits of ip are kept and all
// remaining bits are cleared. ip is expected in 16-byte form; the result is
// always a fresh net.IPv6len-byte slice.
func minIP6(ip net.IP, prefLen int) net.IP {
	out := make(net.IP, net.IPv6len)
	for pos, bits := 0, prefLen; pos < net.IPv6len && bits > 0; pos, bits = pos+1, bits-8 {
		if bits >= 8 {
			out[pos] = ip[pos]
			continue
		}
		// Partial byte: keep its top `bits` bits; later bytes stay zero.
		drop := uint(8 - bits)
		out[pos] = ip[pos] >> drop << drop
	}
	return out
}

// maxIP6 returns the last address of the network that contains ip and has a
// prefix of prefLen bits: the leading prefLen bits of ip are kept and all
// remaining bits are set. ip is expected in 16-byte form; the result is
// always a fresh net.IPv6len-byte slice.
func maxIP6(ip net.IP, prefLen int) net.IP {
	out := make(net.IP, net.IPv6len)
	bits := prefLen
	for pos := range out {
		switch {
		case bits >= 8:
			out[pos] = ip[pos]
		case bits <= 0:
			out[pos] = 0xff
		default:
			// Partial byte: keep the top `bits` bits, set the rest.
			out[pos] = ip[pos] | byte(0xff)>>uint(bits)
		}
		bits -= 8
	}
	return out
}

// lessIP6 reports whether a sorts strictly before b when their first
// net.IPv6len bytes are compared lexicographically. Both arguments are
// expected in 16-byte form.
func lessIP6(a, b net.IP) bool {
	pos := 0
	for pos < net.IPv6len && a[pos] == b[pos] {
		pos++
	}
	return pos < net.IPv6len && a[pos] < b[pos]
}

// printNet6 writes every record of base to standard output, one per line,
// each immediately followed by its own Children (depth first). offt is the
// indentation of the base level; each nesting level adds four spaces.
func printNet6(base []*NetRecord, offt string) {
	childIndent := offt + "    "
	for i := range base {
		rec := base[i]
		fmt.Printf("%s%v/%d %v children=%d %s\n", offt, rec.IP6Min, rec.Pref6Len, rec.IP6Max, len(rec.Children), rec.Location)
		printNet6(rec.Children, childIndent)
	}
}

// findNarrowSubnetIP6 returns the narrowest record in the forest base that
// contains the address test, together with the index in base of the
// top-level record it was found under. It returns (nil, -1) when no
// top-level record contains test. Each level is searched with a binary
// search on IP6Max, so every level must be sorted and free of overlaps.
func findNarrowSubnetIP6(base []*NetRecord, test net.IP) (*NetRecord, int) {
	var best *NetRecord
	topIdx := -1
	for level := base; len(level) > 0; {
		// First record at this level that ends at or after test.
		pos := sort.Search(len(level), func(i int) bool {
			return !lessIP6(level[i].IP6Max, test)
		})
		// It must also start at or before test, otherwise nothing here holds it.
		if pos == len(level) || lessIP6(test, level[pos].IP6Min) {
			break
		}
		if best == nil {
			topIdx = pos
		}
		best = level[pos]
		level = best.Children
	}
	return best, topIdx
}

// findNarrowSubnetNet6 returns the narrowest record in the forest base that
// contains the whole network test, or nil if none does. Each level is
// searched with a binary search on IP6Max, so every level must be sorted and
// free of overlaps.
func findNarrowSubnetNet6(base []*NetRecord, test *NetRecord) *NetRecord {
	var found *NetRecord
	for level := base; len(level) > 0; level = found.Children {
		// First record at this level that ends at or after test ends.
		pos := sort.Search(len(level), func(i int) bool {
			return !lessIP6(level[i].IP6Max, test.IP6Max)
		})
		// It must also start at or before test starts, otherwise nothing
		// at this level holds test.
		if pos == len(level) || lessIP6(test.IP6Min, level[pos].IP6Min) {
			break
		}
		found = level[pos]
	}
	return found
}

// isSubnet6 reports whether network na lies entirely within network nb,
// i.e. nb.IP6Min <= na.IP6Min and na.IP6Max <= nb.IP6Max. Identical
// networks count as subnets of each other.
func isSubnet6(na, nb *NetRecord) bool {
	if lessIP6(na.IP6Min, nb.IP6Min) {
		return false // na starts before nb
	}
	return !lessIP6(nb.IP6Max, na.IP6Max) // na must not end after nb
}

// lessNet6 orders networks by their first address (IP6Min); when the first
// addresses are equal, the shorter prefix (the wider network) comes first.
func lessNet6(na, nb *NetRecord) bool {
	a, b := na.IP6Min, nb.IP6Min
	pos := 0
	for pos < net.IPv6len && a[pos] == b[pos] {
		pos++
	}
	if pos < net.IPv6len {
		return a[pos] < b[pos]
	}
	return na.Pref6Len < nb.Pref6Len
}

// ==========================================================================================

// IPDC converts IP addresses to DC codes (the dcNames values) using the
// networklist-perdc table, which is kept in a cache and (re)loaded by
// cacheFunc.
//
// Public methods:
//
//	GetDc(net.IP) (string, error)
//	GetDcMany([]net.IP) ([]string, error)
//	Print() error
//	Purge()
//	Destroy()
//	Dump(onlyFresh bool) ([]byte, error)
//	Restore([]byte) error
type IPDC struct {
	// LogPrefix is prepended to every message written by log.
	LogPrefix    string
	// VerboseLevel: a message of level lvl is logged only if VerboseLevel >= lvl.
	VerboseLevel int
	// netListFile, if non-empty and netListData is nil, names the gzip file
	// the table is read from.
	netListFile  string
	// netListData, if non-nil, holds the gzip-compressed table and takes
	// precedence over every other source.
	netListData  []byte
	// rtClient downloads the table from the racktables export when neither
	// source above is set.
	rtClient     *uhttp.Client
	// cache holds the parsed []*NetRecord forest produced by cacheFunc.
	cache        *cache.Cache
}

// NewIPDCWithDefaults calls NewIPDC with fixed settings:
//
//   - goodCacheTime 4h
//   - badCacheTime 60s
//   - prefetchTime 40s
//   - cleanUpInterval 10s
//   - httpTimeout 10s
//   - cacheSize 10
//   - serveStale true
//
// verboseLevel is passed through.
func NewIPDCWithDefaults(verboseLevel int) *IPDC {
	const (
		goodTTL      = 4 * time.Hour
		badTTL       = 60 * time.Second
		prefetch     = 40 * time.Second
		cleanupEvery = 10 * time.Second
		httpTimeout  = 10 * time.Second
		cacheSize    = 10
		serveStale   = true
	)
	return NewIPDC(goodTTL, badTTL, prefetch, cleanupEvery, httpTimeout, cacheSize, serveStale, verboseLevel)
}

// NewIPDC builds an IPDC that downloads the table from the racktables export
// and keeps the parsed result in a cache named "ipdc".
//
// httpTimeout is handed to uhttp.NewClient. All other arguments, including
// verboseLevel, are forwarded to cache.NewCache. verboseLevel also becomes
// the IPDC's own VerboseLevel.
func NewIPDC(goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval, httpTimeout time.Duration,
	cacheSize int,
	serveStale bool,
	verboseLevel int) *IPDC {
	client := uhttp.NewClient("https://ro.racktables.yandex.net/export", "ipdc-agent", nil, httpTimeout, true)
	d := &IPDC{
		LogPrefix:    "[ipdc] ",
		VerboseLevel: verboseLevel,
		rtClient:     client,
	}

	// There is only one table, so every cache request uses the empty struct
	// as its key; the loader just unwraps it for cacheFunc.
	loader := func(req interface{}) (interface{}, error) {
		return d.cacheFunc(req.(struct{}))
	}
	d.cache = cache.NewCache("ipdc", loader,
		goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval,
		serveStale, verboseLevel, cacheSize)
	return d
}

// log writes a message through the standard logger when d.VerboseLevel is
// at least lvl. The message is d.LogPrefix followed by format expanded with
// v. When ts is non-nil, ", <time elapsed since *ts>" is appended.
func (d *IPDC) log(lvl int, ts *time.Time, format string, v ...interface{}) {
	if d.VerboseLevel < lvl {
		return
	}
	// Only format is expanded: LogPrefix is an exported, caller-settable
	// field, so a '%' in it must be printed literally rather than parsed as
	// a verb that would consume (and shift) the arguments.
	msg := d.LogPrefix + fmt.Sprintf(format, v...)
	if ts != nil {
		msg += ", " + time.Since(*ts).String()
	}
	log.Print(msg)
}

// cacheFunc is the loader behind d.cache. It reads the gzip-compressed
// networklist-perdc table (see loadNetListTable) and keeps the rows whose DC
// description names a known DC (see dcNames).
//
// The result is a forest. The top-level networks are sorted by lessNet6, and
// each one holds the networks nested inside it in Children. The argument is
// the (always empty) cache key.
func (d *IPDC) cacheFunc(_ struct{}) ([]*NetRecord, error) {
	const recordEstimate = 32768

	startTime := time.Now()
	table, err := d.loadNetListTable()
	if err != nil {
		return nil, err
	}

	isTab := func(r rune) bool { return r == '\t' }
	recs := make([]*NetRecord, 0, recordEstimate)
	for _, line := range bytes.Split(table, []byte{'\n'}) {
		// Runs of tabs act as a single separator, so fields are never empty.
		// Only rows with exactly 4 fields are records: the 1st field is
		// "addr/mask", the 3rd is the DC description.
		fields := bytes.FieldsFunc(line, isTab)
		if len(fields) != 4 {
			continue
		}
		if nr := d.parseNetRecord(string(fields[0]), string(fields[2])); nr != nil {
			recs = append(recs, nr)
		}
	}
	if len(recs) == 0 {
		return nil, fmt.Errorf("got 0 net records from source")
	}

	sort.Slice(recs, func(i, j int) bool {
		return lessNet6(recs[i], recs[j])
	})
	roots := pileupNetRecords(recs)

	d.log(1, &startTime, "got %d nets, %d nonoverlapping nets", len(recs), len(roots))
	if d.VerboseLevel >= 3 {
		printNet6(roots, "")
	}
	return roots, nil
}

// loadNetListTable returns the decompressed networklist-perdc table. The
// gzip source is, in order of preference: d.netListData, the file named by
// d.netListFile, or the racktables export fetched through d.rtClient.
func (d *IPDC) loadNetListTable() ([]byte, error) {
	var reader io.Reader
	switch {
	case d.netListData != nil:
		reader = bytes.NewReader(d.netListData)
	case d.netListFile != "":
		file, err := os.Open(d.netListFile)
		if err != nil {
			return nil, fmt.Errorf("opening net list file: %w", err)
		}
		// The whole file is consumed below, before this function returns.
		defer file.Close()
		reader = file
	default:
		req, err := d.rtClient.NewGetRequest(context.Background(), "/networklist-perdc.txt.gz")
		if err != nil {
			return nil, err
		}
		data, err := d.rtClient.SendRequest(req)
		if err != nil {
			return nil, err
		}
		reader = bytes.NewReader(data)
	}

	zr, err := gzip.NewReader(reader)
	if err != nil {
		return nil, err
	}
	var buf bytes.Buffer
	n, err := buf.ReadFrom(zr)
	d.log(1, nil, "unpacked %d bytes for networklist-perdc table", n)
	// The gzip reader is always closed; its error is reported before the
	// read error.
	if cerr := zr.Close(); cerr != nil {
		return nil, cerr
	}
	if err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// parseNetRecord converts one table row into a NetRecord. netString is the
// row's "addr/mask" field and dcString its DC description.
//
// It returns nil in two cases:
//   - the row is malformed (logged at level 1);
//   - dcString contains none of the dcNames keys (skipped silently).
func (d *IPDC) parseNetRecord(netString, dcString string) *NetRecord {
	slash := strings.IndexByte(netString, '/')
	if slash < 0 {
		d.log(1, nil, "bad record in networklist-perdc table, no net mask found, %s", netString)
		return nil
	}
	addr := netString[:slash]
	// use always 16 byte IP
	ip := net.ParseIP(addr).To16()
	if ip == nil {
		// minIP6/maxIP6 index all 16 bytes, so a nil ip must stop here.
		d.log(1, nil, "bad record in networklist-perdc table, bad ip, %s", netString)
		return nil
	}

	// The family is taken from the written form, so an IPv4-mapped address
	// written in IPv6 notation keeps its IPv6 mask. An IPv4 mask counts bits
	// of the last 4 bytes of the 16-byte form, hence the offset below.
	familyBits := 8 * net.IPv6len
	if !strings.Contains(addr, ":") {
		familyBits = 8 * net.IPv4len
	}
	prefLen, ok := strToMask(netString[slash+1:])
	if !ok || prefLen > familyBits {
		d.log(1, nil, "bad record in networklist-perdc table, bad net mask, %s", netString)
		return nil
	}
	prefLen += 8*net.IPv6len - familyBits

	// NOTE: map iteration order is random; if a description contained more
	// than one dcNames key, which code wins would be unspecified.
	for name, code := range dcNames {
		if strings.Contains(dcString, name) {
			return &NetRecord{
				IP6Min:   minIP6(ip, prefLen),
				IP6Max:   maxIP6(ip, prefLen),
				Pref6Len: prefLen,
				Location: code,
			}
		}
	}
	return nil
}

// pileupNetRecords turns recs, already sorted with lessNet6, into a forest
// and returns its top-level records in their sorted order.
//
// recs is walked in order. A record inside the current top-level record
// (base) is appended to the Children of the narrowest network in base's
// subtree that contains it, or of base itself. Any other record becomes the
// new base and a top-level record.
func pileupNetRecords(recs []*NetRecord) []*NetRecord {
	roots := make([]*NetRecord, 0, len(recs))
	var base *NetRecord
	for _, rec := range recs {
		if base != nil && isSubnet6(rec, base) {
			target := findNarrowSubnetNet6(base.Children, rec)
			if target == nil {
				target = base
			}
			target.Children = append(target.Children, rec)
			continue
		}
		base = rec
		roots = append(roots, rec)
	}
	return roots
}

// GetDc returns the Location (DC code) of the narrowest cached network that
// contains ip.
//
// It returns an error when:
//   - ip is not a valid 4- or 16-byte address;
//   - the net table cannot be obtained from the cache, or is empty;
//   - no network contains ip.
func (d *IPDC) GetDc(ip net.IP) (string, error) {
	reqTime := time.Now()

	// To16 yields nil for malformed input; the lookup indexes all 16 bytes
	// and would panic on it.
	ip16 := ip.To16()
	if ip16 == nil {
		return "", fmt.Errorf("invalid ip addr %v", ip)
	}

	rs, err := d.cache.Get(struct{}{})
	if err != nil {
		return "", err
	}
	// An unexpected cached value type is treated like an empty list.
	recs, _ := rs.([]*NetRecord)
	if len(recs) == 0 {
		return "", fmt.Errorf("cannot search, got zero length net records list")
	}

	if xNet, _ := findNarrowSubnetIP6(recs, ip16); xNet != nil {
		d.log(2, &reqTime, "got %v/%d (%s) for %v", xNet.IP6Min, xNet.Pref6Len, xNet.Location, ip16)
		return xNet.Location, nil
	}
	d.log(2, &reqTime, "got nothing for %v within %d nets", ip16, len(recs))
	return "", fmt.Errorf("ip addr is not within %d nets", len(recs))
}

// GetDcMany looks up every address in ips and returns their DC codes in the
// same order. Addresses that fall into no cached network get "". ips itself
// is not modified.
//
// It fails when:
//   - the request is empty;
//   - an address is not a valid 4- or 16-byte IP;
//   - the net table cannot be obtained, or is empty.
func (d *IPDC) GetDcMany(ips []net.IP) ([]string, error) {
	reqTime := time.Now()
	ipsLen := len(ips)
	if ipsLen == 0 {
		return nil, fmt.Errorf("zero length request")
	}

	// Normalize into a private slice instead of overwriting the caller's
	// elements. Addresses To16 cannot convert are rejected, because the
	// search indexes all 16 bytes and would panic on them.
	ip16s := make([]net.IP, ipsLen)
	for i, ip := range ips {
		ip16 := ip.To16()
		if ip16 == nil {
			return nil, fmt.Errorf("invalid ip addr %v at position %d", ip, i)
		}
		ip16s[i] = ip16
	}

	rs, err := d.cache.Get(struct{}{})
	if err != nil {
		return nil, err
	}
	// An unexpected cached value type is treated like an empty list.
	recs, _ := rs.([]*NetRecord)
	if len(recs) == 0 {
		return nil, fmt.Errorf("cannot search, got zero length net records list")
	}

	// Resolve the addresses in ascending order so netBinCut can narrow its
	// record window; idxs maps sorted rank -> position in ips.
	idxs := make([]int, ipsLen)
	for i := range idxs {
		idxs[i] = i
	}
	sort.Slice(idxs, func(i, j int) bool {
		return lessIP6(ip16s[idxs[i]], ip16s[idxs[j]])
	})

	result := make([]string, ipsLen)
	netBinCut(0, len(recs), 0, ipsLen-1, recs, ip16s, idxs, result)

	d.log(2, &reqTime, "got reply for %d ip addrs", ipsLen)
	return result, nil
}

// netBinCut resolves the ips selected by idxs[minIdx..maxIdx] against the
// top-level records recs[minRec:maxRec]. Each match's Location is stored in
// result at the ip's original position; unmatched positions are left
// untouched.
//
// Requirements:
//   - idxs lists ip positions in ascending address order (lessIP6);
//   - minIdx <= maxIdx;
//   - every selected ip is a non-nil 16-byte address.
//
// It assumes the top-level records are sorted and non-overlapping, so a
// larger ip can only match a record at the same or a later index.
//
// Each call resolves the lowest, highest and middle ip of its range. It uses
// the indices of the records they fall into to shrink the record window, then
// recurses into the ips strictly between them.
func netBinCut(minRec, maxRec, minIdx, maxIdx int, recs []*NetRecord, ips []net.IP, idxs []int, result []string) {
	// Lowest ip: its top-level record index is a lower bound for every
	// remaining ip in this range.
	if xNet, idx := findNarrowSubnetIP6(recs[minRec:maxRec], ips[idxs[minIdx]]); xNet != nil {
		minRec += idx
		result[idxs[minIdx]] = xNet.Location
	}
	if minIdx == maxIdx {
		return
	}
	// Highest ip: its top-level record index (inclusive) is an upper bound.
	if xNet, idx := findNarrowSubnetIP6(recs[minRec:maxRec], ips[idxs[maxIdx]]); xNet != nil {
		maxRec = minRec + idx + 1
		result[idxs[maxIdx]] = xNet.Location
	}
	if maxIdx-minIdx == 1 {
		return
	}
	midIdx := (minIdx + maxIdx) / 2
	// If the middle ip matches nothing, both halves keep the whole window
	// [minRec, maxRec).
	midRecLow, midRecHigh := maxRec, minRec
	if xNet, idx := findNarrowSubnetIP6(recs[minRec:maxRec], ips[idxs[midIdx]]); xNet != nil {
		// The middle ip's record closes the lower window and opens the
		// upper one; both windows include it.
		midRecHigh = minRec + idx
		midRecLow = midRecHigh + 1
		result[idxs[midIdx]] = xNet.Location
	}
	// Recurse only into ips strictly between the already-resolved ones.
	if midIdx-minIdx > 1 {
		netBinCut(minRec, midRecLow, minIdx+1, midIdx-1, recs, ips, idxs, result)

	}
	if maxIdx-midIdx > 1 {
		netBinCut(midRecHigh, maxRec, midIdx+1, maxIdx-1, recs, ips, idxs, result)
	}
}

// Print writes the cached network forest to standard output via printNet6.
// It returns the cache's error if the table cannot be obtained, or an error
// when the cached list is empty.
func (d *IPDC) Print() error {
	rs, err := d.cache.Get(struct{}{})
	if err != nil {
		return err
	}
	if recs := rs.([]*NetRecord); len(recs) > 0 {
		printNet6(recs, "")
		return nil
	}
	return fmt.Errorf("got zero length net records list")
}

// Purge logs at verbosity 1 and then calls Purge on the underlying cache.
func (d *IPDC) Purge() {
	c := d.cache
	d.log(1, nil, "purging")
	c.Purge()
}

// Destroy logs at verbosity 1 and then calls Destroy on the underlying cache.
func (d *IPDC) Destroy() {
	c := d.cache
	d.log(1, nil, "destroying")
	c.Destroy()
}

// Dump logs the current cache size at verbosity 1 and returns the result of
// the underlying cache's Dump, passing onlyFresh through unchanged.
func (d *IPDC) Dump(onlyFresh bool) ([]byte, error) {
	c := d.cache
	d.log(1, nil, "dumping %d records (only fresh = %v)", c.Len(), onlyFresh)
	return c.Dump(onlyFresh)
}

// Restore loads inData into the underlying cache and logs the resulting
// record count at verbosity 1. Any error from the cache is returned as is.
func (d *IPDC) Restore(inData []byte) error {
	d.log(1, nil, "restoring")
	// NOTE(review): bumpEOL's meaning is defined by cache.Restore. The empty
	// key and []*NetRecord{} presumably tell it the key/value types to decode
	// into — confirm there.
	const bumpEOL = false
	if err := d.cache.Restore(inData, bumpEOL, struct{}{}, []*NetRecord{}); err != nil {
		return err
	}
	d.log(1, nil, "restored %d records", d.cache.Len())
	return nil
}
