//go:build linux || freebsd
// +build linux freebsd

package pcsc

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"time"
)

// pcscSocketPathEnv is the environment variable that, when set to a non-empty
// value, overrides the default pcscd socket path used by connect.
const pcscSocketPathEnv = "PCSCD_SOCK_PATH"

const (
	// connTimeout bounds how long connect waits while dialing the pcscd socket.
	connTimeout = 100 * time.Millisecond

	// rcSuccess is the RV (return code) value that signals a successful command.
	rcSuccess = 0

	// Command codes written into the request header by writeRequest.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/winscard_msg.h#L76
	commandEstablishContext = 0x01
	commandReleaseContext   = 0x02
	commandConnect          = 0x04
	commandReconnect        = 0x05
	commandDisconnect       = 0x06
	commandBeginTransaction = 0x07
	commandEndTransaction   = 0x08
	commandTransmit         = 0x09
	commandVersion          = 0x11
	commandGetReadersState  = 0x12

	// Protocol information: the wire-protocol version this client implements
	// and advertises in commandVersion.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/winscard_msg.h#L46-L49
	protocolVersionMajor = int32(4)
	protocolVersionMinor = int32(4)

	// Card communication protocols (SCARD_PROTOCOL_* bit flags).
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L239-L246
	protocolUndefined = 0x0000
	protocolT0        = 0x0001
	protocolT1        = 0x0002
	protocolRaw       = 0x0004
	protocolT15       = 0x0008
	protocolAny       = protocolT0 | protocolT1

	// Internal reader states: bit flags reported in readerState.State.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L257-L263
	readerStateUnknown    = 0x0001
	readerStateAbsent     = 0x0002
	readerStatePresent    = 0x0004
	readerStateSwallowed  = 0x0008
	readerStatePowered    = 0x0010
	readerStateNegotiable = 0x0020
	readerStateSpecific   = 0x0040

	// Context scope requested by establishContext (SCARD_SCOPE_SYSTEM).
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L236
	scopeSystem = 0x0002

	// Share modes for commandConnect (SCARD_SHARE_*).
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L248
	shareModeExclusive = 0x0001
	shareModeShared    = 0x0002
	shareModeDirect    = 0x0003

	// Size of the fixed reader-name field, including the terminating NUL.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L286
	maxReaderNameSize = 128
	// Largest response buffer this client requests in commandTransmit.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L298
	maxBufferSizeExtended = 4 + 3 + (1 << 16) + 3 + 2
	// Size of the ATR buffer inside readerState.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L59
	maxAttributeSize = 33
	// Number of readerState entries returned for commandGetReadersState.
	// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/PCSC/pcsclite.h.in#L284
	maxReaders = 16
)

// readerState mirrors one entry of pcsc-lite's reader-state table
// (struct pubReaderStatesList), as returned for commandGetReadersState.
//
// The struct is decoded with encoding/binary, which never inserts alignment
// padding. The C compiler, however, pads the 33-byte ATR array so that the
// following uint32 (AttrSize) starts on a 4-byte boundary. That padding must
// therefore be declared explicitly between Attr and AttrSize, not at the end
// of the struct. Otherwise AttrSize and Protocol are decoded 3 bytes too early.
// The resulting wire size is 184 bytes per entry.
// https://github.com/LudovicRousseau/PCSC/blob/1.9.0/src/eventhandler.h
type readerState struct {
	Name         [maxReaderNameSize]byte
	EventCounter uint32
	State        uint32
	Sharing      int32
	Attr         [maxAttributeSize]byte
	_            [3]uint8 // C alignment padding before AttrSize
	AttrSize     uint32
	Protocol     uint32
}

// name returns the reader name stored in r.Name, interpreted as a C string.
// The result is everything before the first NUL byte. If the buffer contains
// no NUL at all, the whole buffer is returned. Cutting at the first NUL (not
// just trimming trailing NULs) keeps any bytes that happen to follow the
// terminator out of the returned name.
func (r readerState) name() string {
	raw := r.Name[:]
	if end := bytes.IndexByte(raw, 0x00); end >= 0 {
		raw = raw[:end]
	}
	return string(raw)
}

// scConn is a connection to the pcscd Unix socket.
//
// mu serializes exchanges on the stream, so concurrent users of one
// connection do not interleave their bytes. sendMessage holds it for a single
// request/response pair. scTx.transmit holds it across its multi-part exchange
// and calls writeRequest/readResp directly, which do not lock.
type scConn struct {
	net.Conn
	mu sync.Mutex
}

// connect dials pcscd over its Unix domain socket.
//
// The path comes from the pcscSocketPathEnv environment variable when that is
// non-empty; otherwise pcscSocketPath is used. The dial is abandoned after
// connTimeout. Any dial failure is reported as ErrNoService, and the
// underlying dial error is not returned.
func connect() (*scConn, error) {
	path := os.Getenv(pcscSocketPathEnv)
	if path == "" {
		path = pcscSocketPath
	}

	dialCtx, stop := context.WithTimeout(context.Background(), connTimeout)
	defer stop()

	dialer := net.Dialer{}
	sock, dialErr := dialer.DialContext(dialCtx, "unix", path)
	if dialErr != nil {
		return nil, ErrNoService
	}
	return &scConn{Conn: sock}, nil
}

// writeRequest frames one pcscd request and sends it with a single Write.
//
// The frame is an 8-byte header followed by the body. The header holds the
// body length and then the command code, both as uint32 in nativeByteOrder.
// The body is req encoded with encoding/binary. A nil req produces a
// header-only frame whose length field is 0.
func (c *scConn) writeRequest(command uint32, req interface{}) error {
	var body []byte
	if req != nil {
		var enc bytes.Buffer
		if err := binary.Write(&enc, nativeByteOrder, req); err != nil {
			return fmt.Errorf("failed to marshal message body: %w", err)
		}
		body = enc.Bytes()
	}

	frame := make([]byte, 8, 8+len(body))
	nativeByteOrder.PutUint32(frame[:4], uint32(len(body)))
	nativeByteOrder.PutUint32(frame[4:], command)
	frame = append(frame, body...)

	if _, err := c.Conn.Write(frame); err != nil {
		return fmt.Errorf("failed to write request bytes: %w", err)
	}
	return nil
}

// readResp decodes one fixed-size response from the connection into resp
// using encoding/binary in nativeByteOrder. resp must be a pointer to a
// binary-decodable value.
func (c *scConn) readResp(resp interface{}) error {
	err := binary.Read(c.Conn, nativeByteOrder, resp)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to read response: %w", err)
}

// sendMessage performs one request/response exchange while holding the
// connection lock, so it cannot interleave with other exchanges on c.
// Callers in this file pass a struct value as req and a pointer to that same
// variable as resp. The reply overwrites the request fields in place.
func (c *scConn) sendMessage(command uint32, req, resp interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	err := c.writeRequest(command, req)
	if err == nil {
		err = c.readResp(resp)
	}
	return err
}

// Compile-time check that *scClient satisfies Client.
var _ Client = (*scClient)(nil)

// scClient implements Client by speaking the pcscd wire protocol directly
// over a Unix socket. ctxID is the context handle obtained from
// establishContext in NewClient. Close releases it.
type scClient struct {
	conn  *scConn
	ctxID uint32
}

// NewClient connects to pcscd and establishes a system-scoped PC/SC context.
//
// If the daemon cannot be reached, the error from connect is returned as is.
// If the context cannot be established, the socket is closed before the
// wrapped error is returned. Callers should Close the returned Client to
// release the context and the socket.
func NewClient() (Client, error) {
	conn, err := connect()
	if err != nil {
		return nil, err
	}

	out := &scClient{
		conn: conn,
	}

	out.ctxID, err = out.establishContext()
	if err != nil {
		// Best-effort cleanup: the context error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to establish pcsc context: %w", err)
	}

	return out, nil
}

// CheckCompatibility asks the daemon for its protocol version. It returns an
// error when the daemon's major version differs from protocolVersionMajor,
// the version this client implements. Minor version differences are accepted.
func (c *scClient) CheckCompatibility() error {
	major, _, err := c.Version()
	if err != nil {
		return fmt.Errorf("failed to get pcsc version: %w", err)
	}

	if major != protocolVersionMajor {
		return fmt.Errorf("major version mismatch: %d (pcscd) != %d (client)", major, protocolVersionMajor)
	}

	return nil
}

// Version exchanges protocol versions with pcscd.
//
// The request carries this client's protocolVersionMajor and
// protocolVersionMinor. The reply uses the same layout and holds the daemon's
// version and a return code. A non-success return code is converted with
// pcscErrorFromCode.
func (c *scClient) Version() (major, minor int32, err error) {
	type versionMsg struct {
		Major int32
		Minor int32
		RV    uint32
	}

	req := versionMsg{
		Major: protocolVersionMajor,
		Minor: protocolVersionMinor,
		RV:    rcSuccess,
	}
	var resp versionMsg

	if sendErr := c.conn.sendMessage(commandVersion, req, &resp); sendErr != nil {
		return 0, 0, sendErr
	}
	if resp.RV != rcSuccess {
		return 0, 0, pcscErrorFromCode(int64(resp.RV))
	}
	return resp.Major, resp.Minor, nil
}

// ListReaders fetches the reader-state table from pcscd (commandGetReadersState).
// It returns the names of entries whose State has both readerStatePresent and
// readerStatePowered set. Entries with an empty name are skipped. A nil slice
// is returned when no entry qualifies.
func (c *scClient) ListReaders() ([]string, error) {
	var resp [maxReaders]readerState

	if err := c.conn.sendMessage(commandGetReadersState, nil, &resp); err != nil {
		// Wrap with %w (as elsewhere in this file) so callers can still
		// inspect the underlying I/O error with errors.Is / errors.As.
		return nil, fmt.Errorf("can't send message: %w", err)
	}

	const wantState = uint32(readerStatePowered | readerStatePresent)

	var names []string
	for i := range resp {
		// Index rather than range-by-value to avoid copying each entry.
		r := &resp[i]
		if r.State&wantState != wantState {
			continue
		}
		if name := r.name(); name != "" {
			names = append(names, name)
		}
	}

	return names, nil
}

// Connect opens a shared-mode connection to the named reader, asking for the
// T=1 protocol, and returns a Handle for the card. The handle reuses this
// client's socket.
//
// reader must be shorter than maxReaderNameSize bytes so that it fits,
// NUL-terminated, in the fixed-size request field.
func (c *scClient) Connect(reader string) (Handle, error) {
	if len(reader) >= maxReaderNameSize {
		return nil, fmt.Errorf("reader name is too long: %s", reader)
	}

	var msg struct {
		Context            uint32
		ReaderName         [maxReaderNameSize]byte
		ShareMode          uint32
		PreferredProtocols uint32
		Card               int32
		ActiveProtocol     uint32
		RV                 uint32
	}
	msg.Context = c.ctxID
	msg.ShareMode = shareModeShared
	msg.PreferredProtocols = protocolT1
	copy(msg.ReaderName[:], reader)

	err := c.conn.sendMessage(commandConnect, msg, &msg)
	if err != nil {
		return nil, err
	}
	if msg.RV != rcSuccess {
		return nil, pcscErrorFromCode(int64(msg.RV))
	}

	handle := &scHandle{
		conn:     c.conn,
		card:     msg.Card,
		protocol: msg.ActiveProtocol,
	}
	return handle, nil
}

// establishContext asks pcscd for a new context in system scope
// (scopeSystem) and returns the context handle the daemon assigned.
func (c *scClient) establishContext() (uint32, error) {
	req := struct {
		Scope   uint32
		Context uint32
		RV      uint32
	}{Scope: scopeSystem}
	resp := req

	if err := c.conn.sendMessage(commandEstablishContext, req, &resp); err != nil {
		return 0, err
	}
	if resp.RV == rcSuccess {
		return resp.Context, nil
	}
	return 0, pcscErrorFromCode(int64(resp.RV))
}

// releaseContext asks pcscd to release the context identified by ctxID.
// A non-success return code is converted with pcscErrorFromCode.
func (c *scClient) releaseContext(ctxID uint32) error {
	var msg struct {
		Context uint32
		RV      uint32
	}
	msg.Context = ctxID

	if err := c.conn.sendMessage(commandReleaseContext, msg, &msg); err != nil {
		return err
	}
	if msg.RV == rcSuccess {
		return nil
	}
	return pcscErrorFromCode(int64(msg.RV))
}

// Close releases the client's PC/SC context and then closes the socket.
// The socket is closed even when the release fails. Only the release error
// is returned; any error from closing the socket is discarded.
func (c *scClient) Close() (err error) {
	defer func() {
		// Socket close errors are deliberately ignored.
		_ = c.conn.Close()
	}()

	err = c.releaseContext(c.ctxID)
	return err
}

// Compile-time check that *scHandle satisfies Handle.
var _ Handle = (*scHandle)(nil)

// scHandle is a card connection created by scClient.Connect. card is the
// card handle and protocol the active protocol returned by pcscd for
// commandConnect. conn is the socket shared with the creating scClient.
type scHandle struct {
	conn     *scConn
	card     int32
	protocol uint32
}

// Begin starts a transaction on the card (commandBeginTransaction). It returns
// a Tx that shares this handle's connection, card handle and protocol.
// Closing the Tx ends the transaction.
func (h *scHandle) Begin() (Tx, error) {
	var msg struct {
		Card int32
		RV   uint32
	}
	msg.Card = h.card

	err := h.conn.sendMessage(commandBeginTransaction, msg, &msg)
	switch {
	case err != nil:
		return nil, err
	case msg.RV != rcSuccess:
		return nil, pcscErrorFromCode(int64(msg.RV))
	}

	tx := &scTx{conn: h.conn, card: h.card, protocol: h.protocol}
	return tx, nil
}

// Close disconnects from the card (commandDisconnect) with the
// SCARD_LEAVE_CARD disposition. The underlying socket stays open because it
// is shared with the client that created this handle.
func (h *scHandle) Close() error {
	// SCARD_LEAVE_CARD		0x0000	/**< Do nothing on close */
	const leaveCard = 0x0000

	req := struct {
		Card        int32
		Disposition uint32
		RV          uint32
	}{h.card, leaveCard, 0}
	resp := req

	if err := h.conn.sendMessage(commandDisconnect, req, &resp); err != nil {
		return err
	}
	if rv := resp.RV; rv != rcSuccess {
		return pcscErrorFromCode(int64(rv))
	}
	return nil
}

// Compile-time check that *scTx satisfies Tx.
var _ Tx = (*scTx)(nil)

// scTx is an open transaction started by scHandle.Begin. It carries the same
// connection, card handle and protocol as the handle that created it.
type scTx struct {
	conn     *scConn
	card     int32
	protocol uint32
}

// Close ends the transaction (commandEndTransaction) with the
// SCARD_LEAVE_CARD disposition. The card connection and the socket stay open.
func (t *scTx) Close() error {
	// SCARD_LEAVE_CARD		0x0000	/**< Do nothing on close */
	const leaveCard = 0x0000

	req := struct {
		Card        int32
		Disposition uint32
		RV          uint32
	}{t.card, leaveCard, 0}
	resp := req

	if err := t.conn.sendMessage(commandEndTransaction, req, &resp); err != nil {
		return err
	}
	if rv := resp.RV; rv != rcSuccess {
		return pcscErrorFromCode(int64(rv))
	}
	return nil
}

// transmit sends one APDU (req) to the card and reads the card's reply.
//
// The connection lock is held for the whole exchange because it is not a
// single request/response pair. The transmit header, the raw APDU bytes, the
// response header and the raw response bytes all travel separately on the
// same stream.
//
// Results by status word:
//   - 90 00: the response data (status word stripped) and more == false.
//   - SW1 0x61: the data and more == true. Per ISO 7816-4, 0x61 means further
//     response bytes are available.
//   - Any other status word: an error from newError.
func (t *scTx) transmit(req []byte) (more bool, b []byte, err error) {
	t.conn.mu.Lock()
	defer t.conn.mu.Unlock()

	msg := struct {
		Card            int32
		SendPciProtocol uint32
		SendPciLength   uint32
		SendLength      uint32
		RecvPciProtocol uint32
		RecvPciLength   uint32
		RecvLength      uint32
		RV              uint32
	}{
		Card:            t.card,
		SendPciProtocol: t.protocol,
		SendPciLength:   8,
		SendLength:      uint32(len(req)),
		RecvPciProtocol: protocolT1,
		RecvPciLength:   8,
		RecvLength:      maxBufferSizeExtended,
	}

	if err := t.conn.writeRequest(commandTransmit, msg); err != nil {
		return false, nil, err
	}

	if _, err := t.conn.Write(req); err != nil {
		return false, nil, fmt.Errorf("failed to write apdu request: %w", err)
	}

	if err := t.conn.readResp(&msg); err != nil {
		return false, nil, err
	}

	if msg.RV != rcSuccess {
		return false, nil, pcscErrorFromCode(int64(msg.RV))
	}

	respN := msg.RecvLength
	// We asked for at most maxBufferSizeExtended bytes. Refuse anything larger
	// rather than allocating an arbitrary daemon-supplied size. The unread
	// payload leaves the stream out of sync, so the connection should not be
	// reused after this error.
	if respN > maxBufferSizeExtended {
		return false, nil, fmt.Errorf("scard response too long: %d", respN)
	}

	// Consume the payload before validating it. A short (1-byte) reply must
	// not be left on the socket, where it would corrupt the next exchange.
	resp := make([]byte, respN)
	if _, err := io.ReadFull(t.conn, resp); err != nil {
		return false, nil, fmt.Errorf("failed to read scard response: %w", err)
	}

	if respN < 2 {
		return false, nil, fmt.Errorf("scard response too short: %d", respN)
	}

	sw1, sw2 := resp[respN-2], resp[respN-1]
	data := resp[:respN-2]

	switch {
	case sw1 == 0x90 && sw2 == 0x00:
		return false, data, nil
	case sw1 == 0x61:
		return true, data, nil
	default:
		return false, nil, newError(sw1, sw2)
	}
}
