package firewall

import (
	"context"
	"fmt"
	"net"
	"strconv"

	"a.yandex-team.ru/intranet/auth-checker/internal/cache"
	"a.yandex-team.ru/intranet/auth-checker/internal/database"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
)

// IPChecker decides whether an IP address is internal by matching it against
// the racktables_netrule table, memoizing verdicts in an optional cache.
type IPChecker struct {
	// cache memoizes IsInternal verdicts keyed by the raw ip string;
	// IsInternal nil-checks it before calling Get.
	cache    cache.Cache
	// dbClient provides the read connections used to query racktables_netrule.
	dbClient database.Client
	// logger receives the debug trace emitted by IsInternal.
	logger   log.Logger
}

// NewIPChecker wires an IPChecker to its logger, verdict cache and database
// client. Every argument is stored as-is; no validation or nil check happens
// here.
func NewIPChecker(logger log.Logger, cache cache.Cache, dbClient database.Client) *IPChecker {
	checker := new(IPChecker)
	checker.logger = logger
	checker.cache = cache
	checker.dbClient = dbClient
	return checker
}

// IsInternal reports whether ip is permitted by the racktables_netrule table.
//
// When a cache is configured and c.cache.Get(ip) succeeds, the cached verdict
// is returned without touching the database. Otherwise ip is parsed, split
// into two 64-bit halves by calculateLeftAndRightValue, and matched against
// every rule whose [left_begin, left_end] and [right_begin, right_end] ranges
// both contain it. The result is true only if at least one rule matches and
// none of the matching rules has is_allowed = false. A verdict obtained from
// the database is stored in the cache (when one is configured), keyed by the
// raw ip string.
//
// An error is returned, with a false result and nothing cached, when ip is not
// a valid textual IP address, when a read connection cannot be acquired, or
// when the query, a row scan, or row iteration fails.
func (c *IPChecker) IsInternal(ctx context.Context, ip string) (bool, error) {
	ctxlog.Debugf(ctx, c.logger, "ipchecker: trying to check ip: %v", ip)
	if c.cache != nil {
		// Any Get error (presumably including a plain miss — TODO confirm
		// against the cache implementation) falls through to the database.
		value, err := c.cache.Get(ip)
		if err == nil {
			ctxlog.Debugf(ctx, c.logger, "ipchecker: using cache value: %v", value)
			return value, nil
		}
	}

	netIP := net.ParseIP(ip)
	if netIP == nil {
		ctxlog.Debugf(ctx, c.logger, "ipchecker: unable to parse ip: %v", ip)
		return false, fmt.Errorf("unable to parse ip: %v", ip)
	}

	leftValue, rightValue := calculateLeftAndRightValue(netIP)
	ctxlog.Debugf(ctx, c.logger, "ipchecker: ip %v splits to leftValue: %v and rightValue: %v", ip, leftValue, rightValue)

	conn, err := c.dbClient.GetReadConnection(ctx)
	if err != nil {
		return false, fmt.Errorf("unable to get connection from pool: %w", err)
	}
	defer conn.Release()

	sql := `SELECT racktables_netrule.is_allowed
			FROM racktables_netrule WHERE
			racktables_netrule.left_begin <= $1 AND
			racktables_netrule.left_end >= $1 AND
			racktables_netrule.right_begin <= $2 AND
			racktables_netrule.right_end >= $2`

	// The halves are bound as decimal strings, presumably because uint64
	// values can exceed the signed bigint range — confirm the column type.
	rows, err := conn.Query(ctx, sql, strconv.FormatUint(leftValue, 10), strconv.FormatUint(rightValue, 10))
	if err != nil {
		return false, fmt.Errorf("unable to execute query: %w", err)
	}
	defer rows.Close()

	// No ORDER BY is needed: a single deny rule decides the outcome, so stop
	// at the first one; otherwise every matching rule allowed the address.
	var isAllowed bool
	for rows.Next() {
		if err := rows.Scan(&isAllowed); err != nil {
			return false, fmt.Errorf("unable to parse row: %w", err)
		}
		if !isAllowed {
			break
		}
	}
	// rows.Next returns false both on exhaustion and on failure. Without this
	// check, a failed query or read would be reported (and cached) as
	// "not internal".
	if err := rows.Err(); err != nil {
		return false, fmt.Errorf("unable to read rows: %w", err)
	}
	ctxlog.Debugf(ctx, c.logger, "ipchecker: using database value: %v", isAllowed)

	// The cache is optional (see the nil check above); calling Set on a nil
	// cache would panic.
	if c.cache != nil {
		c.cache.Set(ip, isAllowed)
	}
	return isAllowed, nil
}

// calculateLeftAndRightValue packs ip into two big-endian 64-bit halves.
//
// If ip.To4 succeeds (a plain or IPv4-mapped IPv4 address), leftValue is 0 and
// rightValue holds the 32-bit address. Otherwise bytes 0-7 form leftValue and
// bytes 8-15 form rightValue, each with the most significant byte first.
func calculateLeftAndRightValue(ip net.IP) (leftValue, rightValue uint64) {
	if v4 := ip.To4(); v4 != nil {
		for i, b := range v4 {
			rightValue |= uint64(b) << (8 * (3 - i))
		}
		return leftValue, rightValue
	}

	for i, b := range ip {
		if i < 8 {
			leftValue |= uint64(b) << (8 * (7 - i))
			continue
		}
		rightValue |= uint64(b) << (8 * (15 - i))
	}
	return leftValue, rightValue
}
