package websocket

import (
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
	"code.justin.tv/karlpatr/message-prototype/libs/logging"
	"code.justin.tv/karlpatr/message-prototype/libs/session"
	"github.com/gobwas/ws"
)

// healthReturn is the sentinel error returned from the upgrader's OnRequest
// hook when the request URI equals Settings.HealthCheck (see
// NewServiceFactory). It is built with ws.RejectionStatus(http.StatusOK), so
// the upgrade is rejected carrying an HTTP 200 status rather than an error
// status — presumably the response a health-check probe expects; confirm
// against the gobwas/ws version in use. handle compares against it so these
// rejections are not logged as failed upgrades.
var healthReturn = ws.RejectConnectionError(ws.RejectionStatus(http.StatusOK))

// errHandler appears to bundle the hooks used when a connection has to be
// closed because of an error. Its implementations and callers are not part of
// this section, so the method notes below are hedged.
type errHandler interface {
	// log forwards its arguments to the handler's logger.
	log(args ...interface{})
	// timeout returns an optional duration; presumably nil means "no
	// timeout" — confirm against implementations.
	timeout() *time.Duration
	// closeForError receives the terminating error plus a description and
	// returns a byte payload — presumably an encoded close frame; confirm.
	closeForError(cause error, message string) []byte
}

// NewServiceFactory returns a session.ServiceFactory that, given a
// session.BindingFactory, builds a websocket service listening on port.
//
// Each service runs on the settings returned by fixup(settings), with a
// lifecycle manager created just for that service. When HealthCheck is set,
// an upgrade request whose URI equals that path is rejected with healthReturn
// instead of being upgraded. When the caller's settings carry a Lifecycle,
// the service's Shutdown is registered on it as a hook keyed by the service.
func NewServiceFactory(port int, settings *Settings) session.ServiceFactory {
	return func(factory session.BindingFactory) (session.Service, error) {
		internal := fixup(settings)
		internal.Lifecycle = lifecycle.NewManager()

		upgrader := ws.Upgrader{}
		if internal.HealthCheck != nil {
			upgrader.OnRequest = func(uri []byte) error {
				if string(uri) == *internal.HealthCheck {
					return healthReturn
				}
				return nil
			}
		}

		// Use the fixed-up logger: every other log call in the service goes
		// through s.settings.Logger (== internal.Logger), so links must use
		// the same one rather than the caller's possibly-unset value.
		links := newLinkSettings(
			func(f ws.Frame) ws.Frame { return f }, // frames pass through unchanged
			defaultDuration,
			internal.Logger,
		)

		svc := &service{
			port:         port,
			factory:      factory,
			settings:     internal,
			upgrader:     upgrader,
			linkSettings: links,
		}
		// settings may be nil if fixup supplies defaults; only hook into an
		// external lifecycle when the caller actually provided one.
		if settings != nil && settings.Lifecycle != nil {
			settings.Lifecycle.RegisterHook(svc, svc.Shutdown)
		}
		return svc, nil
	}
}

// service is the websocket session.Service built by NewServiceFactory. It
// owns a TCP accept loop on port and hands each accepted connection to a
// binding produced by factory.
type service struct {
	// port is the TCP port passed to net.ListenTCP in listen.
	port         int
	// factory builds the session binding for each upgraded link (see handle).
	factory      session.BindingFactory
	// settings is the value returned by fixup; its Lifecycle is a manager
	// created for this service in NewServiceFactory.
	settings     *Settings
	// upgrader performs the websocket handshake; when a health-check path is
	// configured its OnRequest answers that path with healthReturn.
	upgrader     ws.Upgrader
	// linkSettings is shared by every link created in accept.
	linkSettings *linkSettings
	// isRunning is 1 after Start and 0 after Stop; accessed atomically.
	isRunning    int32
	// hasShutdown becomes 1 once Shutdown is called; accessed atomically.
	hasShutdown  int32
}

// Start sets the running flag and gives the accept loop (run) to the
// service's lifecycle manager through RunUntilComplete.
func (s *service) Start() {
	manager := s.settings.Lifecycle
	atomic.StoreInt32(&s.isRunning, 1)
	manager.RunUntilComplete(s.run)
}

// Stop clears the running flag and executes the lifecycle hook keyed by s.
// listen registers that hook to close the current listener, which makes the
// accept loop in run return. The hook's error is deliberately discarded.
func (s *service) Stop() {
	manager := s.settings.Lifecycle
	atomic.StoreInt32(&s.isRunning, 0)
	_ = manager.ExecuteHook(s)
}

// WaitForDrainingConnections delegates to the lifecycle manager's
// WaitForCompletion with the deadline until. Presumably this waits for the
// tasks started via RunUntilComplete (the accept loop and connection
// handlers) — confirm against the lifecycle package.
func (s *service) WaitForDrainingConnections(until time.Time) {
	manager := s.settings.Lifecycle
	manager.WaitForCompletion(until)
}

// IsRunning reports whether the running flag (set by Start, cleared by Stop)
// is currently set.
func (s *service) IsRunning() bool {
	state := atomic.LoadInt32(&s.isRunning)
	return state != 0
}

// HasShutdown reports whether Shutdown has been called on this service.
func (s *service) HasShutdown() bool {
	state := atomic.LoadInt32(&s.hasShutdown)
	return state != 0
}

// Shutdown stops the service permanently. It calls Stop, which closes the
// listener, and then records the shutdown so HasShutdown reports true. Finally
// it asks the lifecycle manager to execute all of its hooks (including the
// per-connection graceful-close hooks registered in accept) and returns
// ExecuteAll's result.
func (s *service) Shutdown() error {
	s.Stop()
	atomic.StoreInt32(&s.hasShutdown, 1)
	manager := s.settings.Lifecycle
	return manager.ExecuteAll()
}

// run is the service's accept loop. It creates a listener with listen and
// passes every incoming connection to accept.
//
// The loop returns once Accept reports a closed-connection error, which is
// how Stop and Shutdown end it. Any other Accept error is logged at Debug
// level. The broken listener is then closed and a new one is created on the
// same port.
func (s *service) run() {
	defer s.settings.Logger(logging.Info, "Leaving accept loop on port", s.port)
	s.settings.Logger(logging.Info, "Listening on port", s.port, "...")
	for {
		ln := s.listen()

		var err error
		for err == nil {
			err = s.accept(ln)
		}
		if isClosedError(err) {
			return
		}
		s.settings.Logger(logging.Debug, "Unable to accept connection:", err)

		// Release the port before re-listening. While the old listener is
		// still open, ListenTCP on the same port fails with "address already
		// in use" (and listen panics), and the descriptor would leak. A
		// concurrent Stop is still honoured: listen re-checks the running
		// flag after registering the new listener.
		_ = ln.Close()
	}
}

// listen binds a TCP listener on s.port and registers its Close as the
// lifecycle hook keyed by s, so Stop can close it. It panics if the socket
// cannot be created.
func (s *service) listen() *net.TCPListener {
	addr := &net.TCPAddr{Port: s.port}
	ln, err := net.ListenTCP("tcp", addr)
	if err != nil {
		panic(fmt.Sprintf("Unable to create socket: %v", err))
	}

	manager := s.settings.Lifecycle
	manager.RegisterHook(s, ln.Close)

	// A Stop or Shutdown may have arrived while the socket was being set up,
	// before the hook above existed to close it. Re-issue the stop so the new
	// listener is closed and the accept loop exits.
	stopRequested := s.HasShutdown() || !s.IsRunning()
	if stopRequested {
		s.Stop()
	}
	return ln
}

// accept waits for one inbound TCP connection on ln and schedules it to be
// served. It returns AcceptTCP's error unchanged, or nil once the connection
// has been handed off to the lifecycle manager.
//
// The raw connection is used as a lifecycle hook key, and its hook calls
// settings.gracefulClose. The link's onClose callback executes that hook when
// the handler finishes. When settings.Certs is set, the connection is wrapped
// with tls.Server before it reaches handle.
func (s *service) accept(ln *net.TCPListener) error {
	tcpConn, acceptErr := ln.AcceptTCP()
	if acceptErr != nil {
		return acceptErr
	}

	cfg := s.settings
	cfg.Logger(logging.Trace, "Accepted connection", tcpConn.RemoteAddr())
	cfg.Lifecycle.RegisterHook(tcpConn, func() error { return s.settings.gracefulClose(tcpConn) })

	var stream net.Conn = tcpConn
	if cfg.Certs != nil {
		stream = tls.Server(tcpConn, cfg.Certs)
	}
	finish := func(error) { _ = s.settings.Lifecycle.ExecuteHook(tcpConn) }

	cfg.Lifecycle.RunUntilComplete(s.handle(&link{s.linkSettings, finish, stream}))
	return nil
}

// handle returns a task that performs the websocket upgrade on link.conn.
// On success it runs the binding produced by s.factory until execute returns.
//
// The final error is reported to binding.OnClosed (only when the upgrade
// succeeded) and then to link.onClose. Deferred calls run in LIFO order, so
// the binding is always notified before the link's close callback. Upgrade
// failures caused by the health-check rejection or io.EOF are not logged.
//
// See also (reference kept from the original source):
// https://groups.google.com/forum/#!topic/golang-nuts/I7a_3B8_9Gw
func (s *service) handle(link *link) func() {
	task := func() {
		var failure error
		defer func() { link.onClose(failure) }()

		if _, failure = s.upgrader.Upgrade(link.conn); failure != nil {
			quiet := failure == healthReturn || failure == io.EOF
			if !quiet {
				s.settings.Logger(logging.Debug, "Unable to upgrade connection", failure)
			}
			return
		}

		binding := s.factory(link)
		defer func() { binding.OnClosed(failure) }()

		failure = execute(binding, link, ws.StateServerSide, s.settings, s.HasShutdown)
	}
	return task
}

// closeForError picks the websocket close status to send after err ended a
// connection.
//
// A nil err, a locally closed socket (isClosedError), or an error whose text
// is exactly "EOF" counts as a clean end of stream. It yields
// StatusNormalClosure, or StatusGoingAway when shutdown is true. Any other
// error is passed to logger at Debug level together with msg, and yields
// StatusProtocolError.
func closeForError(shutdown bool, err error, msg string, logger logging.Function) ws.StatusCode {
	switch {
	// The nil check keeps err.Error() from panicking; isClosedError(nil) is false.
	case err != nil && !isClosedError(err) && err.Error() != "EOF":
		logger(logging.Debug, msg, err)
		return ws.StatusProtocolError
	case shutdown:
		return ws.StatusGoingAway
	default:
		return ws.StatusNormalClosure
	}
}

// filterClosed maps a closed-network-connection error (see isClosedError) to
// nil, so a deliberate local close is not reported as a failure. Every other
// error, including nil, is returned unchanged.
func filterClosed(err error) error {
	if !isClosedError(err) {
		return err
	}
	return nil
}

// isClosedError reports whether err is a *net.OpError whose underlying error
// has the standard library's "use of closed network connection" text. That is
// the error I/O on a socket returns after the socket has been closed (e.g.
// Accept on a closed listener).
//
// A nil error, a typed-nil *net.OpError, or an OpError with a nil Err reports
// false instead of panicking.
func isClosedError(err error) bool {
	opErr, ok := err.(*net.OpError)
	if !ok || opErr == nil || opErr.Err == nil {
		return false
	}
	return opErr.Err.Error() == "use of closed network connection"
}
