package socket

import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"syscall"
)

// Compile-time check that *CygwinSocket satisfies the Socket interface.
var _ Socket = (*CygwinSocket)(nil)

// CygwinSocket serves a Cygwin-style AF_UNIX socket on Windows. It listens on
// an ephemeral loopback TCP port and advertises that port, together with a
// random 16-byte secret, in a socket file at path. Every accepted connection
// must pass cygwinHandshake (presenting the secret) before Accept returns it.
//
// NOTE(review): closed is a plain bool with no synchronization, so calling
// Accept and Close from different goroutines races on it — confirm the
// intended usage.
type CygwinSocket struct {
	// path is where Listen writes the socket file and where Close removes it.
	path     string
	// uuid is the random secret stored in the socket file and expected from
	// each client during the handshake.
	uuid     []byte
	// closed is set once Close has run; Accept then returns net.ErrClosed.
	closed   bool
	// conns carries accepted TCP connections from the accept goroutine to Accept.
	conns    chan net.Conn
	// errs carries errors from the listener's Accept (buffered, capacity 1).
	errs     chan error
	// done is closed when the accept goroutine started by Listen exits.
	done     chan struct{}
	// listener is the loopback TCP listener opened by Listen.
	listener net.Listener
}

// NewCygwinSocket prepares a CygwinSocket that will publish its socket file at
// path. Nothing touches the filesystem or network until Listen is called, and
// the returned error is always nil.
func NewCygwinSocket(path string) (Socket, error) {
	sock := &CygwinSocket{path: path}
	// Connections are handed off synchronously; one accept error may be queued.
	sock.conns = make(chan net.Conn)
	sock.errs = make(chan error, 1)
	return sock, nil
}

// Addr returns the filesystem path at which the socket file is (or will be)
// created.
func (s *CygwinSocket) Addr() string {
	socketPath := s.path
	return socketPath
}

// Listen publishes the socket. It removes any stale socket file, ensures the
// parent directory exists, and opens a loopback TCP listener on an ephemeral
// port. It then writes the Cygwin socket file pointing at that port and starts
// a goroutine that accepts TCP connections and hands them (or accept errors)
// to Accept. The goroutine exits once Close closes the listener.
func (s *CygwinSocket) Listen() error {
	// A missing stale file is expected; any real problem with the path will
	// surface when createCygwinSocket writes the new file.
	_ = os.Remove(s.path)
	if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
		return fmt.Errorf("failed to create cygwin socket folder: %w", err)
	}

	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return err
	}

	// The socket file advertises the listener's port plus a fresh secret.
	port := l.Addr().(*net.TCPAddr).Port
	uuid, err := createCygwinSocket(s.path, port)
	if err != nil {
		_ = l.Close()
		return err
	}

	done := make(chan struct{})
	s.listener = l
	s.uuid = uuid
	s.done = done

	// The goroutine captures l and done directly instead of re-reading
	// s.listener / s.done.
	go func() {
		defer close(done)

		for {
			conn, err := l.Accept()
			if err != nil {
				// These sends may block until Accept consumes them. Close drains
				// errs and conns until done is closed, so they cannot block
				// forever once the socket is being closed.
				s.errs <- err
				if errors.Is(err, net.ErrClosed) {
					return
				}
				continue
			}

			s.conns <- conn
		}
	}()

	return nil
}

// Accept waits for the next client connection, runs cygwinHandshake on it and
// returns it wrapped as a Conn. It returns net.ErrClosed once the socket has
// been closed, the listener's error if accepting failed, the handshake error
// (after closing the connection), or ctx.Err() if ctx is done first.
//
// NOTE(review): the handshake itself does not observe ctx; a client that
// connects but never sends its secret keeps Accept blocked in the handshake.
func (s *CygwinSocket) Accept(ctx context.Context) (Conn, error) {
	if s.closed {
		return nil, net.ErrClosed
	}

	select {
	case err := <-s.errs:
		return nil, err
	case conn := <-s.conns:
		if err := cygwinHandshake(conn, s.uuid); err != nil {
			_ = conn.Close()
			return nil, err
		}

		return &connWrapper{
			ReadWriteCloser: conn,
			peer:            dummyCreds(),
		}, nil
	case <-s.done:
		// The accept goroutine has exited because Close shut the listener.
		// (s.done is nil before Listen, so this case never fires early.)
		return nil, net.ErrClosed
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// Close removes the socket file, closes the listener and waits for the accept
// goroutine to exit. Any connection the goroutine accepted but had not yet
// handed to Accept is closed. It returns the listener's Close error. Later
// calls return nil. Calling Close before Listen, or after Listen failed, only
// removes the socket path.
func (s *CygwinSocket) Close() error {
	if s.closed {
		return nil
	}

	s.closed = true
	_ = os.Remove(s.path)

	if s.listener == nil {
		// Listen never succeeded: there is no listener or goroutine to stop.
		return nil
	}
	err := s.listener.Close()

	// conns and errs are deliberately left open: the accept goroutine is their
	// only sender, and closing them here could make a pending send panic (and
	// would hand nil values to a concurrent Accept). Instead, drain them until
	// the goroutine signals it has exited.
	for {
		select {
		case c := <-s.conns:
			_ = c.Close()
		case <-s.errs:
		case <-s.done:
			return err
		}
	}
}

// createCygwinSocket writes the socket file at filename. The file advertises
// the loopback TCP port and a freshly generated 16-byte secret, and is then
// marked SYSTEM|READONLY. It returns the secret, which clients must present
// during cygwinHandshake.
func createCygwinSocket(filename string, port int) ([]byte, error) {
	// Generate the secret before touching the filesystem so that a failure
	// here leaves no file (and no open handle) behind.
	var uuid [16]byte
	if _, err := rand.Read(uuid[:]); err != nil {
		return nil, fmt.Errorf("failed to generate uuid: %w", err)
	}

	// os.WriteFile opens with O_WRONLY|O_CREATE|O_TRUNC and reports open,
	// write and close errors alike.
	content := fmt.Sprintf("!<socket >%d s %s", port, uuidToString(uuid))
	if err := os.WriteFile(filename, []byte(content), 0666); err != nil {
		return nil, fmt.Errorf("failed to create socket file: %w", err)
	}

	// NOTE(review): the SYSTEM attribute is presumably what lets Cygwin treat
	// this file as a socket — confirm against Cygwin's AF_LOCAL emulation.
	if err := setFileAttributes(filename, syscall.FILE_ATTRIBUTE_SYSTEM|syscall.FILE_ATTRIBUTE_READONLY); err != nil {
		// The file was written above, so removing it is safe. Removing it
		// avoids leaving a socket file that lacks the attributes.
		_ = os.Remove(filename)
		return nil, fmt.Errorf("failed to set socket file attrs: %w", err)
	}

	return uuid[:], nil
}

// cygwinHandshake runs the server side of the connection handshake on conn:
//
//  1. read the client's 16-byte secret and reject it unless it equals uuid;
//  2. send the secret back;
//  3. read the client's 12-byte block of three little-endian uint32 words
//     (presumably pid/uid/gid — confirm), and reply with the same block after
//     overwriting the first and last words with this process's pid. The
//     middle word is echoed unchanged.
//
// It returns the first I/O error, or an error if the secret does not match.
func cygwinHandshake(conn net.Conn, uuid []byte) error {
	// A TCP stream may deliver a message across several Reads, so read each
	// fixed-length message with io.ReadFull rather than a single Read.
	var cuuid [16]byte
	if _, err := io.ReadFull(conn, cuuid[:]); err != nil {
		return err
	}
	if !bytes.Equal(uuid, cuuid[:]) {
		return errors.New("invalid uuid")
	}

	if _, err := conn.Write(uuid); err != nil {
		return err
	}

	pidsUids := make([]byte, 12)
	if _, err := io.ReadFull(conn, pidsUids); err != nil {
		return err
	}

	pid := os.Getpid()
	gid := pid // for cygwin's AF_UNIX -> AF_INET, pid = gid
	binary.LittleEndian.PutUint32(pidsUids, uint32(pid))
	binary.LittleEndian.PutUint32(pidsUids[8:], uint32(gid))
	_, err := conn.Write(pidsUids)
	return err
}

// uuidToString formats the 16-byte secret as four groups of eight lowercase
// hex digits separated by '-'. Each group is one 4-byte little-endian word
// printed most-significant byte first, so the bytes within a group are
// reversed.
func uuidToString(uuid [16]byte) string {
	out := make([]byte, 0, 35)
	for word := 0; word < len(uuid); word += 4 {
		if word > 0 {
			out = append(out, '-')
		}
		reversed := [4]byte{uuid[word+3], uuid[word+2], uuid[word+1], uuid[word]}
		var digits [8]byte
		hex.Encode(digits[:], reversed[:])
		out = append(out, digits[:]...)
	}
	return string(out)
}

// setFileAttributes applies the Win32 attribute bitmask attr to the file at
// path. It returns the error from converting path to UTF-16, or otherwise the
// error from SetFileAttributes.
func setFileAttributes(path string, attr uint32) error {
	widePath, err := syscall.UTF16PtrFromString(path)
	if err == nil {
		err = syscall.SetFileAttributes(widePath, attr)
	}
	return err
}
