package client

import (
	"context"
	"encoding/csv"
	"errors"
	"io"
	"net/http"
	"time"

	netutil "a.yandex-team.ru/infra/walle/server/go/internal/lib/net"
	"a.yandex-team.ru/infra/walle/server/go/internal/lib/tool"
)

// Racktables answers questions about switch ports known to Racktables.
type Racktables interface {
	// HasSwitchPort reports whether portName is considered present on
	// switchName. A non-nil error means the port data could not be obtained,
	// in which case the boolean result is false.
	HasSwitchPort(switchName netutil.SwitchName, portName netutil.SwitchPort) (bool, error)
}

// RacktablesConfig holds the settings used by NewRacktables to build a
// Racktables client.
type RacktablesConfig struct {
	// Host is the Racktables host without a scheme; NewRacktables prepends
	// "https://" to it.
	Host                 string               `yaml:"host"`
	// Token is sent in the Authorization header as "OAuth <token>".
	Token                string               `yaml:"token"`
	// ExperimentalSwitches lists switches for which any port is accepted
	// while no ports are known for them.
	ExperimentalSwitches []netutil.SwitchName `yaml:"experimental_switches"`
}

// racktables is the HTTP-backed implementation of Racktables.
type racktables struct {
	// client performs the requests to Racktables.
	client               *http.Client
	// host is the base URL, including the "https://" scheme.
	host                 string
	// token is the OAuth token sent with every request.
	token                string
	// cacheSwitchPorts stores port lists keyed by switch name; it is
	// initialised with getSwitchPorts as its loader.
	cacheSwitchPorts     *tool.CacheStore[netutil.SwitchName, []netutil.SwitchPort]
	// experimentalSwitches are switches whose ports are all accepted while
	// the cache holds no ports for them.
	experimentalSwitches []netutil.SwitchName
}

// NewRacktables builds a Racktables client from cnf.
//
// The client talks to "https://" + cnf.Host and authenticates with cnf.Token.
// Its switch port cache is initialised with getSwitchPorts as the loader and a
// 30 second duration (presumably the refresh interval; confirm against
// tool.CacheStore).
func NewRacktables(cnf *RacktablesConfig) Racktables {
	portCache := &tool.CacheStore[netutil.SwitchName, []netutil.SwitchPort]{}

	rt := &racktables{
		client:               &http.Client{},
		host:                 "https://" + cnf.Host,
		token:                cnf.Token,
		cacheSwitchPorts:     portCache,
		experimentalSwitches: cnf.ExperimentalSwitches,
	}

	// Init receives a method value bound to rt, so rt must already carry the
	// client, host and token at this point.
	portCache.Init(rt.getSwitchPorts, 30*time.Second)

	return rt
}

// HasSwitchPort reports whether portName is considered present on switchName.
//
// When the cache holds ports for the switch, the answer is whether portName is
// among them. When it holds none, the answer is true only for switches listed
// as experimental. A cache lookup failure is returned as the error.
func (r *racktables) HasSwitchPort(switchName netutil.SwitchName, portName netutil.SwitchPort) (bool, error) {
	known, err := r.cacheSwitchPorts.Get(switchName)
	if err != nil {
		return false, err
	}

	if len(known) > 0 {
		for _, candidate := range known {
			if candidate == portName {
				return true, nil
			}
		}
		return false, nil
	}

	// No ports are known for this switch: accept any port on experimental ones.
	for _, name := range r.experimentalSwitches {
		if name == switchName {
			return true, nil
		}
	}
	return false, nil
}

// getSwitchPorts downloads the switch port export from Racktables and groups
// the ports by switch name.
//
// The request is limited to 10 seconds and authenticated with the OAuth token.
// The response body is read as tab-separated records of three fields; the
// first field is taken as the switch name and the second as the port name.
// Records with a different number of fields are skipped.
//
// It returns an error when the request fails, when the server answers with a
// status other than 200 OK, or when the body cannot be parsed.
func (r *racktables) getSwitchPorts() (map[netutil.SwitchName][]netutil.SwitchPort, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.host+"/export/switchports.txt", nil)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Authorization", "OAuth "+r.token)

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// An error page would not match the three-field layout, so every line of
	// it would be skipped below and an empty result returned without error.
	// Reject such responses explicitly instead.
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("racktables: unexpected response status " + resp.Status)
	}

	reader := csv.NewReader(resp.Body)
	reader.Comma = '\t'
	reader.FieldsPerRecord = 3

	result := make(map[netutil.SwitchName][]netutil.SwitchPort)
	for {
		fields, err := reader.Read()
		switch {
		case errors.Is(err, csv.ErrFieldCount):
			// Malformed record: ignore it and keep reading.
			continue
		case err == io.EOF:
			return result, nil
		case err != nil:
			return nil, err
		}
		sw, port := netutil.SwitchName(fields[0]), netutil.SwitchPort(fields[1])
		result[sw] = append(result[sw], port)
	}
}
