package socket

import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"path/filepath"
)

// Compile-time check that *UnixSocket satisfies the Socket interface.
var _ Socket = (*UnixSocket)(nil)

// UnixSocket is a Socket backed by a Unix-domain stream listener at a
// filesystem path. Listen starts a background goroutine that accepts
// connections and forwards them (or accept errors) to Accept over channels.
//
// NOTE(review): closed is read by Accept and written by Close without any
// synchronization; calling them from different goroutines is a data race
// under -race. Confirm callers serialize these calls, or guard with a mutex.
type UnixSocket struct {
	// path is the socket file location; Addr returns it verbatim.
	path string
	// closed is set once by Close; Accept and Close short-circuit on it.
	closed bool
	// conns carries accepted connections from the accept goroutine to Accept
	// (created unbuffered by NewUnixSocket).
	conns chan net.Conn
	// errs carries accept errors from the accept goroutine to Accept
	// (created with capacity 1 by NewUnixSocket).
	errs chan error
	// done is closed when the accept goroutine exits; nil until Listen succeeds.
	done chan struct{}
	// listener is the bound Unix listener; nil until Listen succeeds.
	listener net.Listener
}

// NewUnixSocket returns a UnixSocket configured for path. Nothing touches the
// filesystem until Listen is called. The error result is always nil today.
func NewUnixSocket(path string) (Socket, error) {
	sock := new(UnixSocket)
	sock.path = path
	// Unbuffered: the accept goroutine hands each connection directly to Accept.
	sock.conns = make(chan net.Conn)
	// One slot so a single accept error can be parked without a waiting reader.
	sock.errs = make(chan error, 1)
	return sock, nil
}

// Addr reports the filesystem path this socket was configured with.
func (s *UnixSocket) Addr() string {
	addr := s.path
	return addr
}

// Listen prepares the socket path and starts accepting connections.
//
// It removes any existing file at the socket path, creates the parent
// directory (mode 0700) and binds a Unix-domain listener there. It then
// starts a goroutine that forwards accepted connections and accept errors to
// Accept; the goroutine exits once the listener is closed.
//
// Listen returns net.ErrClosed if the socket has already been closed, and an
// error if it is already listening. Either case would otherwise start a
// second accept goroutine that could leak or send on channels that Close
// closes.
func (s *UnixSocket) Listen() error {
	if s.closed {
		return net.ErrClosed
	}
	if s.listener != nil {
		return fmt.Errorf("unix socket %q is already listening", s.path)
	}

	// Best-effort removal of a stale socket file from a previous run. A missing
	// file is the normal case; any real problem surfaces from net.Listen below.
	_ = os.Remove(s.path)
	if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
		return fmt.Errorf("failed to create unix socket folder: %w", err)
	}

	l, err := net.Listen("unix", s.path)
	if err != nil {
		return err
	}

	done := make(chan struct{})
	s.listener = l
	s.done = done

	// The goroutine uses local copies so it never reads fields that a later
	// call could reassign.
	go func() {
		defer close(done)

		for {
			conn, err := l.Accept()
			if err != nil {
				// Blocks until Accept (or Close, which drains) receives it.
				s.errs <- err

				if errors.Is(err, net.ErrClosed) {
					return
				}
				continue
			}

			// Blocks until Accept (or Close, which drains) receives it.
			s.conns <- conn
		}
	}()
	return nil
}

// Accept waits for the next connection produced by the goroutine started in
// Listen. It wraps that connection together with the peer credentials
// returned by unixCreds.
//
// It returns:
//   - net.ErrClosed if the socket has been closed, including when Close
//     closes the internal channels while Accept is waiting;
//   - the listener's error if accepting failed;
//   - ctx.Err() if ctx is done first.
//
// If the peer credentials cannot be read, the connection is closed and a
// wrapped error is returned.
func (s *UnixSocket) Accept(ctx context.Context) (Conn, error) {
	if s.closed {
		return nil, net.ErrClosed
	}

	select {
	case err, ok := <-s.errs:
		if !ok {
			// Channel closed by Close: a plain receive would yield a nil error
			// and make Accept return (nil, nil).
			return nil, net.ErrClosed
		}
		return nil, err
	case conn, ok := <-s.conns:
		if !ok {
			// Channel closed by Close: conn is nil and must not reach unixCreds.
			return nil, net.ErrClosed
		}

		creds, err := unixCreds(conn)
		if err != nil {
			_ = conn.Close() // the credential error is the one worth reporting
			return nil, fmt.Errorf("unable to get unix creds: %w", err)
		}

		return &connWrapper{
			ReadWriteCloser: conn,
			peer:            creds,
		}, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// Close stops the listener, closes any connection still waiting to be handed
// to Accept, waits for the accept goroutine to exit and then closes the
// internal channels. Calling it again is a no-op that returns nil.
//
// The result is the listener's Close error, or nil if Listen never succeeded.
func (s *UnixSocket) Close() error {
	if s.closed {
		return nil
	}
	s.closed = true

	// Best-effort: the socket file may already be gone.
	_ = os.Remove(s.path)

	var err error
	if s.listener != nil {
		err = s.listener.Close()
	}

	// The accept goroutine may be blocked sending a connection or an error
	// (errs has a single slot). Keep draining both channels until it exits.
	// This means it can never block forever, and no send can race with the
	// close of these channels below (which would panic).
	if s.done != nil {
	drain:
		for {
			select {
			case c := <-s.conns:
				_ = c.Close()
			case <-s.errs:
				// Pending accept errors are irrelevant once closing.
			case <-s.done:
				break drain
			}
		}
	}

	// No sender remains. Closing the channels wakes any Accept still waiting.
	close(s.conns)
	close(s.errs)
	return err
}
