package dns

import (
	"fmt"
	"sort"
	"strings"

	"github.com/kentik/patricia"
	"github.com/kentik/patricia/string_tree"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/waffles/pkg/logger"
)

const (
	// tagSeparator joins the host names of one IP address into the single
	// tag string stored in the tree, and is used to split that tag back into
	// host names on lookup.
	tagSeparator = ";"
)

// Tree maps IP addresses to the host names registered for them.
//
// Host names of an address are stored as one tag string, joined with
// tagSeparator.
type Tree struct {
	// v4 holds the entries whose key parsed as an IPv4 address.
	v4 *string_tree.TreeV4
	// v6 holds the entries whose key parsed as an IPv6 address.
	v6 *string_tree.TreeV6
}

var (
	// ErrUnknownIPType is reported (wrapped) by Tree.Hosts when a parsed
	// address is neither IPv4 nor IPv6.
	ErrUnknownIPType = xerrors.NewSentinel("unknown IP address type")
)

// NewTree builds a Tree from virtualHosts, a map from an IP address string to
// the host names registered for that address.
//
// For every entry the host names are de-duplicated, joined with tagSeparator
// and stored as a single tag under the parsed address. Construction is
// best-effort: an entry that cannot be parsed or stored is logged and
// skipped, and an entry without host names is skipped silently, so NewTree
// always returns a usable (possibly partially filled) Tree.
func NewTree(virtualHosts map[string][]string) *Tree {
	tree := &Tree{
		v4: string_tree.NewTreeV4(),
		v6: string_tree.NewTreeV6(),
	}

	for ip, hosts := range virtualHosts {
		ipv4, ipv6, err := patricia.ParseIPFromString(ip)
		if err != nil {
			logger.L.Error("failed to parse IP address", log.String("ip", ip), log.Error(err))
			continue
		}

		if len(hosts) == 0 {
			// An address without host names has nothing to resolve to; skip
			// it rather than store an empty tag, which a lookup would split
			// into a single empty host name.
			continue
		}

		tag := strings.Join(uniqStringSlice(hosts), tagSeparator)
		switch {
		case ipv4 != nil:
			_, _, err = tree.v4.Set(*ipv4, tag)
		case ipv6 != nil:
			_, _, err = tree.v6.Set(*ipv6, tag)
		default:
			// Mirror Hosts: an address that is neither v4 nor v6 is reported
			// instead of being dropped without a trace.
			err = ErrUnknownIPType.WithFrame()
		}

		if err != nil {
			logger.L.Error("failed to add IP address", log.String("ip", ip), log.Error(err))
		}
	}

	return tree
}

// Hosts returns the host names stored for ip.
//
// The address is parsed and looked up in the IPv4 or IPv6 tree, whichever
// matches its type. The result is nil (with a nil error) when no tag is found
// for the address. An error is returned when ip cannot be parsed, when the
// tree lookup fails, or, wrapping ErrUnknownIPType, when the parsed address
// is neither IPv4 nor IPv6.
func (t *Tree) Hosts(ip string) ([]string, error) {
	v4Addr, v6Addr, parseErr := patricia.ParseIPFromString(ip)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse ip %q: %w", ip, parseErr)
	}

	var (
		tags      []string
		lookupErr error
	)
	switch {
	case v4Addr != nil:
		tags, lookupErr = t.v4.FindTags(*v4Addr)
	case v6Addr != nil:
		tags, lookupErr = t.v6.FindTags(*v6Addr)
	default:
		lookupErr = ErrUnknownIPType.WithFrame()
	}
	if lookupErr != nil {
		return nil, xerrors.Errorf("failed to get hosts for ip %q: %w", ip, lookupErr)
	}

	// NOTE(review): only the first tag returned by FindTags is used; if
	// overlapping prefixes are ever stored, confirm that the first tag is the
	// intended match.
	if len(tags) > 0 {
		return strings.Split(tags[0], tagSeparator), nil
	}
	return nil, nil
}

// uniqStringSlice returns the distinct values of in, sorted in ascending
// order.
//
// The input is copied before sorting, so the caller's slice is never
// modified. A nil or empty input yields a nil result.
func uniqStringSlice(in []string) []string {
	if len(in) == 0 {
		// The compaction below always keeps the first element, so it must
		// not run on an empty slice: re-slicing to length 1 would panic or
		// expose a stale element from the backing array.
		return nil
	}

	out := make([]string, len(in))
	copy(out, in)
	sort.Strings(out)

	// Compact in place: out[:last+1] always holds the distinct values seen
	// so far, and equal values are adjacent because out is sorted.
	last := 0
	for i := 1; i < len(out); i++ {
		if out[i] == out[last] {
			continue
		}
		last++
		out[last] = out[i]
	}
	return out[:last+1]
}
