package sirena

import (
	"bytes"
	"context"
	"encoding/binary"
	"io"
	"net"
	"time"

	"go.uber.org/atomic"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
)

// HeaderLength is the size in bytes of the fixed-length header that is
// prepended to every outgoing message and read first from every response.
const HeaderLength = 100

// messageID is a package-wide counter that ServeSirenaRequest increments once
// per request; the resulting value is written (truncated to uint32) into the
// outgoing header.
var messageID atomic.Int64

// timeout bounds the context that ServeSirenaRequest passes to getSocket,
// i.e. how long establishing the TCP connection may take.
var timeout = 10 * time.Second

// request groups a destination (host and port), a client identifier and a
// message payload.
// NOTE(review): not referenced anywhere in the visible part of this file —
// confirm whether it is still needed.
type request struct {
	Host     string
	Port     string
	clientID uint16
	Message  []byte
}

// ConnectionConfig describes the endpoint that ServeSirenaRequest talks to.
type ConnectionConfig struct {
	// Host and Port are joined with net.JoinHostPort to form the TCP dial address.
	Host     string
	Port     string
	// ClientID is copied into the ClientID field of the outgoing message header.
	ClientID uint16
}

// ServeSirenaRequest sends requestMsg to the endpoint described by
// sirenaConfig and returns the body of the response.
//
// The request is framed by makeFullMessage (a HeaderLength-byte header followed
// by requestMsg). The response is read as a HeaderLength-byte header whose
// MessageLength field gives the number of body bytes that follow.
//
// A fresh TCP connection is opened for every call and closed before returning.
// Establishing the connection is bounded by the package-level timeout; all
// subsequent socket I/O (write and both reads) shares a single deadline of
// 30 seconds, shortened to the deadline of ctx_ when that one is earlier.
//
// On any failure a nil slice and a wrapped error are returned.
func ServeSirenaRequest(logger log.Logger, ctx_ context.Context, sirenaConfig ConnectionConfig, requestMsg []byte) ([]byte, error) {
	timestamp := time.Now().Unix()
	// Named msgID so that the package-level messageID counter is not shadowed.
	msgID := messageID.Inc()
	logger.Infof("Request to Sirena # %v-%v", timestamp, msgID)
	logger.Infof("Request body: %v", string(requestMsg))
	logger.Infof("Request host: %v:%v", sirenaConfig.Host, sirenaConfig.Port)

	fullMsg := makeFullMessage(logger, requestMsg, timestamp, msgID, sirenaConfig.ClientID)
	if fullMsg == nil {
		// makeFullMessage reports an encoding failure by returning nil (it has
		// already logged the cause). Writing nothing and then waiting for a
		// reply would only end in a read timeout, so fail fast instead.
		return nil, xerrors.Errorf("tryGet(): cannot build request message")
	}

	ctx, cancel := context.WithTimeout(ctx_, timeout)
	defer cancel()
	sock, err := getSocket(ctx, sirenaConfig.Host, sirenaConfig.Port)
	if err != nil {
		return nil, xerrors.Errorf("tryGet(): %w", err)
	}
	defer func() {
		if closeErr := sock.Close(); closeErr != nil {
			logger.Error("Error closing Sirena connection", log.Error(closeErr))
		}
	}()

	// The deadline must be installed before the first Write: without it a peer
	// that stops reading could block the write indefinitely.
	ioDeadline := time.Now().Add(30 * time.Second)
	if callerDeadline, ok := ctx_.Deadline(); ok && callerDeadline.Before(ioDeadline) {
		// Do not keep working past the point the caller has stopped waiting.
		ioDeadline = callerDeadline
	}
	if err := sock.SetDeadline(ioDeadline); err != nil {
		return nil, xerrors.Errorf("tryGet(): set read deadline: %w", err)
	}

	wroteToSirena, err := sock.Write(fullMsg)
	if err != nil {
		return nil, xerrors.Errorf("tryGet(): sock.Write(fullMsg): %w", err)
	}
	logger.Infof("Sent request to Sirena of %d bytes", wroteToSirena)

	headerBytes := make([]byte, HeaderLength)
	headerBytesRead, err := io.ReadFull(sock, headerBytes)
	if err != nil {
		return nil, xerrors.Errorf("tryGet(): ReadFull(): %w", err)
	}
	logger.Infof("Got response header of %d bytes", headerBytesRead)

	var headerStruct SirenaHeader
	if err := binary.Read(bytes.NewReader(headerBytes), binary.BigEndian, &headerStruct); err != nil {
		return nil, xerrors.Errorf("tryGet(): %w", err)
	}
	logger.Infof("Expected response size %d bytes", headerStruct.MessageLength)

	// NOTE(review): MessageLength comes straight from the peer and is used as
	// the allocation size (up to 4 GiB). Consider enforcing an upper bound once
	// the protocol's maximum message size is confirmed.
	data := make([]byte, headerStruct.MessageLength)
	factReadData, err := io.ReadFull(sock, data)
	if err != nil {
		return nil, xerrors.Errorf("tryGet(): %w", err)
	}
	logger.Infof("Actual response size %d bytes", factReadData)
	logger.Info("Sirena response", log.ByteString("data", data))

	return data, nil
}

// getSocket opens a TCP connection to host:port using a zero-value net.Dialer.
// The dial attempt is governed by ctx; the connection or the dial error is
// returned unchanged.
func getSocket(ctx context.Context, host, port string) (net.Conn, error) {
	address := net.JoinHostPort(host, port)
	dialer := &net.Dialer{}
	return dialer.DialContext(ctx, "tcp", address)
}

// SirenaHeader is the fixed-size header placed in front of every message body.
// This file encodes and decodes it with encoding/binary in big-endian order,
// so the field order and sizes below define the wire layout. The fields add up
// to 100 bytes (4+4+4+32+2+1+1+4+48), which matches HeaderLength.
type SirenaHeader struct {
	// MessageLength is the number of body bytes that follow the header.
	MessageLength uint32
	// Timestamp is set by makeFullMessage to the request time in Unix seconds,
	// truncated to 32 bits.
	Timestamp     uint32
	// MessageID is set by makeFullMessage to the per-request counter value,
	// truncated to 32 bits.
	MessageID     uint32
	// STUB1 is sent zero-filled by this file.
	STUB1         [32]byte
	// ClientID is taken from ConnectionConfig.ClientID for outgoing messages.
	ClientID      uint16
	// Flag1 and Flag2 are always sent as 0 by this file.
	Flag1         byte
	Flag2         byte
	KeyID         uint32 // unused, should be 0
	// STUB2 is sent zero-filled by this file.
	STUB2         [48]byte
}

// makeFullMessage builds the byte sequence that goes on the wire: a big-endian
// encoded SirenaHeader immediately followed by msg.
//
// The header carries len(msg), timestamp and messageID (each converted to
// uint32) and clientID; all remaining header fields are left at their zero
// values.
//
// If the header cannot be encoded the failure is logged and nil is returned.
// If the encoded header is not HeaderLength bytes long, that is logged as an
// error but the message is still returned.
func makeFullMessage(logger log.Logger, msg []byte, timestamp int64, messageID int64, clientID uint16) []byte {
	// Fields not listed here (STUB1, Flag1, Flag2, KeyID, STUB2) stay zeroed.
	hdr := SirenaHeader{
		MessageLength: uint32(len(msg)),
		Timestamp:     uint32(timestamp),
		MessageID:     uint32(messageID),
		ClientID:      clientID,
	}

	var encoded bytes.Buffer
	if err := binary.Write(&encoded, binary.BigEndian, hdr); err != nil {
		logger.Error("Cannot encode header to binary format", log.Error(err))
		return nil
	}

	if size := encoded.Len(); size != HeaderLength {
		logger.Errorf("Invalid header size: %d expected, but generated %d", HeaderLength, size)
	}

	return append(encoded.Bytes(), msg...)
}
