package websocket

import (
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/session"
	"github.com/gobwas/ws"
)

// healthReturn is the error the upgrader's OnRequest hook returns for the
// configured health-check URI. The rejection is built with HTTP status 200, so
// a health probe gets a successful HTTP answer instead of a websocket upgrade.
// The connection handler compares the Upgrade error against this value to
// avoid logging health checks as failed upgrades.
var (
	healthReturn = ws.RejectConnectionError(ws.RejectionStatus(http.StatusOK))
)

// passthrough is the identity frame transform: it hands back the frame it was
// given, untouched. NewServiceFactory passes it to newLinkSettings.
// NOTE(review): presumably used there as a per-frame hook; confirm in newLinkSettings.
func passthrough(f ws.Frame) ws.Frame {
	return f
}

// dummyBinding is a session.Binding whose callbacks all do nothing. The
// connection handler starts with it as the current binding, so the deferred
// OnClosed call always has a non-nil receiver even when the websocket
// handshake fails before a real binding exists.
type dummyBinding struct{}

// noop is the shared do-nothing binding value.
var noop session.Binding = &dummyBinding{}

// OnTextMessage ignores its argument.
func (b *dummyBinding) OnTextMessage(text string) {}

// OnBinaryMessage ignores its argument.
func (b *dummyBinding) OnBinaryMessage(data []byte) {}

// OnClosed ignores the close cause.
func (b *dummyBinding) OnClosed(cause error) {}

// NewServiceFactory returns a session.ServiceFactory that builds websocket
// services listening on port and configured from settings. Each service it
// builds creates the per-connection bindings with the BindingFactory handed
// to the returned callback.
//
// When settings.HealthCheck is non-nil, a request whose URI equals it exactly
// is rejected with healthReturn (an HTTP 200 response) instead of upgraded.
// When settings.Lifecycle is non-nil, the new service's Shutdown is registered
// there as a hook keyed by the service.
func NewServiceFactory(port int, settings *Settings) session.ServiceFactory {
	return func(factory session.BindingFactory) (session.Service, error) {
		var upgrader ws.Upgrader
		if settings.HealthCheck != nil {
			upgrader.OnRequest = func(uri []byte) error {
				if string(uri) != *settings.HealthCheck {
					return nil
				}
				return healthReturn
			}
		}
		// Computed before the manager is created, matching the original
		// construction order.
		linkSettings := newLinkSettings(passthrough, settings)
		svc := &service{
			port:      port,
			factory:   factory,
			certs:     settings.Certs,
			lifecycle: lifecycle.NewManager(),
			settings:  linkSettings,
			upgrader:  upgrader,
		}
		if settings.Lifecycle != nil {
			settings.Lifecycle.RegisterHook(svc, svc.Shutdown)
		}
		return svc, nil
	}
}

// service is the websocket session.Service: it owns one TCP listening port,
// upgrades every accepted connection to a websocket, and drives it through a
// binding built by factory.
//
// Check how NewServiceFactory constructs this struct before reordering or
// adding fields; a positional composite literal would depend on this order.
type service struct {
	// port is the TCP port passed to net.ListenTCP (all local addresses).
	port        int
	// factory builds the session.Binding for each successfully upgraded link.
	factory     session.BindingFactory
	// certs, when non-nil, makes accept wrap each connection with tls.Server.
	certs       *tls.Config
	// lifecycle is this service's own manager: it holds the listener's close
	// hook (keyed by the service) and the accept-loop / per-connection tasks.
	lifecycle   lifecycle.Manager
	// settings is shared by every link; it also provides the logger.
	settings    *linkSettings
	// upgrader performs the websocket handshake; OnRequest may be set for
	// the health check.
	upgrader    ws.Upgrader
	// isRunning is 1 after Start and 0 after Stop; accessed only atomically.
	isRunning   int32
	// hasShutdown is set to 1 by Shutdown and never cleared in this file;
	// accessed only atomically.
	hasShutdown int32
}

// Start marks the service as running, opens the listening socket and hands
// the accept loop to the service's lifecycle manager.
// NOTE(review): presumably RunUntilComplete runs the loop in the background
// and tracks it for WaitForDrainingConnections — confirm in lifecycle.
//
// If the socket cannot be opened, the error from net.ListenTCP is returned
// unchanged and the running flag is cleared again, so IsRunning does not
// report true for a service that never started listening.
func (s *service) Start() error {
	// Set before listening so listen's stop/shutdown race check sees this
	// Start as the current state.
	atomic.StoreInt32(&s.isRunning, 1)
	ln, err := s.listen()
	if err != nil {
		atomic.StoreInt32(&s.isRunning, 0)
		return err
	}
	s.lifecycle.RunUntilComplete(func() { s.run(ln) })
	return nil
}

// Stop clears the running flag and then executes the hook registered under
// this service on its own lifecycle manager. listen registers the listener's
// Close under that key, so this closes the listener and the accept loop in
// run exits once Accept reports the closed listener.
//
// The flag is cleared first so a concurrent listen that registers a new
// listener after this call still notices the stop through its own check.
func (s *service) Stop() {
	atomic.StoreInt32(&s.isRunning, 0)
	// Deliberately best-effort: Stop can run more than once (from Shutdown
	// and from listen's race check), and presumably a repeated close of the
	// listener reports an error that is not actionable here.
	_ = s.lifecycle.ExecuteHook(s)
}

// WaitForDrainingConnections delegates to the service's lifecycle manager.
// NOTE(review): presumably it blocks until the tasks started through
// RunUntilComplete (the accept loop and per-connection handlers) finish, or
// until the deadline passes — confirm in lifecycle.Manager.WaitForCompletion.
func (s *service) WaitForDrainingConnections(until time.Time) {
	s.lifecycle.WaitForCompletion(until)
}

// IsRunning reports whether the running flag is set: Start sets it and Stop
// clears it. The flag is read atomically, so this may be called from any
// goroutine.
func (s *service) IsRunning() bool {
	flag := atomic.LoadInt32(&s.isRunning)
	return flag != 0
}

// HasShutdown reports whether Shutdown has been called on this service. The
// flag is read atomically; nothing in this file clears it once set.
func (s *service) HasShutdown() bool {
	flag := atomic.LoadInt32(&s.hasShutdown)
	return flag != 0
}

// Shutdown stops the service (see Stop), marks it as shut down so HasShutdown
// reports true, and then executes all hooks of the service's lifecycle
// manager, returning that result. NewServiceFactory registers it as a hook on
// settings.Lifecycle when one is configured.
//
// A listen racing with Shutdown is still caught even though hasShutdown is
// set after Stop, because Stop clears isRunning first and listen checks both.
// NOTE(review): presumably ExecuteAll runs every registered hook — confirm.
func (s *service) Shutdown() error {
	s.Stop()
	atomic.StoreInt32(&s.hasShutdown, 1)
	return s.lifecycle.ExecuteAll()
}

// run is the accept loop for ln. Every accepted connection is handed to
// accept. The loop returns when Accept reports that the listener has been
// closed (isClosedError), which is how Stop ends it. It also returns when the
// service is no longer marked as running.
//
// Any other Accept failure (typically resource exhaustion such as running out
// of file descriptors) is logged and retried on the same listener. Between
// retries the loop pauses, starting at 5ms and doubling up to one second,
// the same policy net/http.Server uses.
//
// The listener is deliberately kept rather than replaced. While it is open it
// still owns the port, so a second ListenTCP on the same port would fail with
// "address already in use". The previous code did exactly that: it leaked the
// listener and replaced it with nil, and the service never accepted again.
func (s *service) run(ln *net.TCPListener) {
	const (
		minRetryDelay = 5 * time.Millisecond
		maxRetryDelay = time.Second
	)
	defer s.settings.logger(logging.Info, "Leaving accept loop on port", s.port)
	s.settings.logger(logging.Info, "Listening on port", s.port, "...")
	var delay time.Duration
	for {
		err := s.accept(ln)
		if err == nil {
			delay = 0
			continue
		}
		// Also bail out once stopped, in case the close error takes a form
		// isClosedError does not recognise; otherwise we would retry forever.
		if isClosedError(err) || !s.IsRunning() {
			return
		}
		delay *= 2
		if delay == 0 {
			delay = minRetryDelay
		}
		if delay > maxRetryDelay {
			delay = maxRetryDelay
		}
		s.settings.logger(logging.Debug, "Unable to accept connection:", err)
		time.Sleep(delay)
	}
}

// listen opens a TCP listener on the service port, on all local addresses.
// It registers the listener's Close as this service's hook on the lifecycle
// manager, so Stop, which executes that hook, closes the listener.
//
// If Stop or Shutdown raced with the socket setup, the new listener is stopped
// immediately. It is still returned, so the caller's accept loop observes it
// as closed.
func (s *service) listen() (*net.TCPListener, error) {
	addr := &net.TCPAddr{Port: s.port}
	listener, err := net.ListenTCP("tcp", addr)
	if err != nil {
		return nil, err
	}
	// Register before checking the flags. A Stop that lands after this point
	// closes the listener through the hook; one that landed earlier is caught
	// by the check below.
	s.lifecycle.RegisterHook(s, listener.Close)
	if stopRequested := s.HasShutdown() || !s.IsRunning(); stopRequested {
		s.Stop()
	}
	return listener, nil
}

// accept waits for the next TCP connection on ln and hands it off. When the
// service has a TLS config, the raw connection is wrapped with tls.Server. The
// TLS handshake is not run here; crypto/tls performs it on first read or write.
// The connection is then wrapped in a link, and the task built by handle runs
// under the service's lifecycle manager. An AcceptTCP error is returned
// unchanged so the caller can decide how to react.
func (s *service) accept(ln *net.TCPListener) error {
	tcpConn, err := ln.AcceptTCP()
	if err != nil {
		return err
	}
	s.settings.logger(logging.Trace, "Accepted connection", tcpConn.RemoteAddr())

	var conn net.Conn = tcpConn
	if s.certs != nil {
		conn = tls.Server(tcpConn, s.certs)
	}
	task := s.handle(newLink(s.settings, conn, s.lifecycle))
	s.lifecycle.RunUntilComplete(task)
	return nil
}

// handle returns the per-connection task that accept schedules on the
// service's lifecycle manager. The task works in three steps:
//   - it performs the websocket handshake on link.conn;
//   - on success, it asks the BindingFactory for a binding;
//   - it passes the binding and link to execute.
//
// execute is defined elsewhere in this package.
// NOTE(review): presumably execute runs the frame loop until the link closes
// or s.HasShutdown reports true — confirm.
//
// The deferred calls run in reverse order when the task returns:
//  1. link.Close(err), with the final value of err (handshake or execute error).
//  2. link.gracefulClose().
//  3. bind.OnClosed(link.cause()), on whichever binding is current at that
//     point: the noop placeholder if the handshake failed, otherwise the
//     factory's binding.
//
// Reference kept from the original author (its relevance is not recorded
// here):
// https://groups.google.com/forum/#!topic/golang-nuts/I7a_3B8_9Gw
func (s *service) handle(link *link) func() {
	return func() {
		var err error
		// Start with the no-op binding so the deferred OnClosed below always
		// has a receiver, even if the handshake fails.
		bind := noop
		defer func() { bind.OnClosed(link.cause()) }()
		defer link.gracefulClose()
		// Closure (not a direct defer) so Close sees err as assigned later.
		defer func() { link.Close(err) }()
		_, err = s.upgrader.Upgrade(link.conn)
		if err != nil {
			// healthReturn is what the OnRequest hook returns for the health
			// check URI, so it is expected rather than a failure.
			// NOTE(review): io.EOF presumably means the peer closed before
			// sending a request (e.g. a TCP-only probe); confirm.
			if err != healthReturn && err != io.EOF {
				s.settings.logger(logging.Debug, fmt.Errorf("Unable to upgrade connection: %w", err))
			}
			return
		}
		bind = s.factory(link)
		err = execute(bind, link, ws.StateServerSide, s.HasShutdown)
	}
}
