package websocket

import (
	"crypto/tls"
	"fmt"
	"net"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/gdaas-ingest/libs/logging"
	"code.justin.tv/devhub/gdaas-ingest/libs/session"
	"code.justin.tv/extensions/shutdown"
	"github.com/gobwas/ws"
)

// Service is the top-level handle for a websocket server: Start begins
// serving connections and Shutdown tears the server down.
type Service interface {
	// Shutdown stops the service. A non-nil error indicates that some part of
	// the teardown failed.
	Shutdown() error
	// Start begins accepting websocket connections.
	Start()
}

// errHandler groups the callbacks used when a connection must be reported on
// and closed because of an error.
//
// NOTE(review): no implementer or caller is visible in this file; the method
// descriptions below are inferred from the signatures and should be confirmed.
type errHandler interface {
	// log records its arguments (presumably through the configured logger).
	log(...interface{})
	// timeout returns an optional duration; presumably nil means "no timeout".
	timeout() *time.Duration
	// closeForError presumably produces the payload to send when closing
	// because of err, using msg as context.
	closeForError(err error, msg string) []byte
}

// NewService returns a websocket Service that listens on the given TCP port
// and uses factory to create a session binding for every upgraded connection.
//
// settings must be non-nil: its Lifecycle field is read directly. The service
// works on the result of fixup(settings) (presumably the settings with
// defaults applied — confirm) and installs a fresh lifecycle manager there for
// its own listener and connections. TLS is enabled when settings.Certs is
// non-nil. If settings.Lifecycle is non-nil, the service's Shutdown is
// registered on it as a hook so the caller's lifecycle can stop the service.
func NewService(port int, factory session.BindingFactory, settings *Settings) Service {
	internal := fixup(settings)
	internal.Lifecycle = shutdown.NewManager()
	svc := &service{
		port:     port,
		factory:  factory,
		settings: internal,
	}
	if parent := settings.Lifecycle; parent != nil {
		parent.RegisterHook(svc, svc.Shutdown)
	}
	return svc
}

// service is the concrete Service implementation returned by NewService.
type service struct {
	// port is the TCP port passed to net.ListenTCP by listen.
	port int
	// factory creates a session binding for each upgraded connection.
	factory session.BindingFactory
	// settings is the output of fixup with a service-owned Lifecycle manager
	// installed by NewService.
	settings *Settings
	// hasShutdown becomes non-zero once Shutdown has been called; it is only
	// accessed through sync/atomic.
	hasShutdown int32
}

// Start hands the accept loop (run) to the service's internal lifecycle
// manager. Whether this call blocks depends on RunUntilComplete; the rest of
// this file relies on it running work concurrently (accept dispatches every
// connection through it), so presumably it returns immediately.
func (s *service) Start() {
	lifecycle := s.settings.Lifecycle
	lifecycle.RunUntilComplete(s.run)
}

// HasShutdown reports whether Shutdown has been called at least once. It is
// safe to call from any goroutine.
func (s *service) HasShutdown() bool {
	flag := atomic.LoadInt32(&s.hasShutdown)
	return flag != 0
}

// Shutdown marks the service as shut down and then shuts down its internal
// lifecycle manager, which runs the registered hooks (listener close,
// per-connection read shutdown, ...). Every error reported by the manager is
// logged at Error level, and a single summary error carrying the count is
// returned; nil means the teardown completed without errors.
func (s *service) Shutdown() error {
	atomic.AddInt32(&s.hasShutdown, 1)
	failures := s.settings.Lifecycle.Shutdown()
	if len(failures) > 0 {
		for _, failure := range failures {
			s.settings.Logger(logging.Error, "Shutdown error:", failure)
		}
		return fmt.Errorf("Shutdown caused %d errors; see report for details", len(failures))
	}
	return nil
}

// run is the accept loop. It opens a listener on s.port and hands every
// accepted connection to accept until the listener reports that it has been
// closed, which is how the lifecycle manager stops the loop during shutdown.
//
// Any other Accept failure is logged at Debug level, the current listener is
// released, and, unless the service is shutting down, a fresh listener is
// opened. listen panics if the new socket cannot be created.
func (s *service) run() {
	defer s.settings.Logger(logging.Debug, "Leaving accept loop")
	for {
		ln := s.listen()
		for {
			err := s.accept(ln)
			if err == nil {
				continue
			}
			if isClosedError(err) {
				return
			}
			s.settings.Logger(logging.Debug, "Unable to accept connection:", err)
			break
		}

		// The failed listener still owns the port. Run its lifecycle hook
		// (ln.Close, registered in listen) before binding a new one; otherwise
		// ListenTCP fails with "address already in use" and listen panics, and
		// the old socket leaks.
		s.settings.Lifecycle.ExecuteHook(ln)

		// Don't open a new listener if shutdown began while we were failing.
		if s.HasShutdown() {
			return
		}
	}
}

// listen opens a TCP listener on s.port (all interfaces) and registers its
// Close method as a lifecycle hook, so shutting down the lifecycle closes it.
// It panics if the socket cannot be created.
func (s *service) listen() *net.TCPListener {
	addr := &net.TCPAddr{Port: s.port}
	listener, err := net.ListenTCP("tcp", addr)
	if err != nil {
		panic(fmt.Sprintf("Unable to create socket: %v", err))
	}
	lifecycle := s.settings.Lifecycle
	lifecycle.RegisterHook(listener, listener.Close)
	s.settings.Logger(logging.Info, "Listening on port", s.port, "...")
	return listener
}

// accept waits for the next TCP connection on ln and dispatches it to a
// lifecycle-managed worker built by handle. It returns the Accept error
// unchanged; nil means a connection was accepted and dispatched.
//
// When settings.Certs is set, the connection is wrapped with tls.Server. The
// TLS handshake happens lazily on first read, inside the worker. A lifecycle
// hook is registered that shuts down the read side of the raw TCP connection.
// The link's onClose callback hands the raw connection to
// settings.gracefulClose.
func (s *service) accept(ln *net.TCPListener) error {
	conn, err := ln.AcceptTCP()
	if err != nil {
		return err
	}
	s.settings.Logger(logging.Debug, "Accepted connection", conn.RemoteAddr())

	wrapped := net.Conn(conn)
	if s.settings.Certs != nil {
		wrapped = tls.Server(conn, s.settings.Certs)
	}

	// handle releases this hook with ExecuteHook(link.conn), so it must be
	// keyed by exactly the value stored in link.conn. Keying by the raw
	// *net.TCPConn missed under TLS (link.conn is then a *tls.Conn) and leaked
	// the hook until shutdown. Without TLS the key is unchanged: wrapped holds
	// the same *net.TCPConn.
	s.settings.Lifecycle.RegisterHook(wrapped, closeRead(conn, s.settings.Logger))

	lnk := &link{
		conn:   wrapped,
		mask:   func(f ws.Frame) ws.Frame { return f },
		logger: s.settings.Logger,
		onClose: func(error) {
			s.settings.gracefulClose(conn)
		},
	}
	s.settings.Lifecycle.RunUntilComplete(s.handle(lnk))
	return nil
}

// handle builds the per-connection worker that accept passes to
// Lifecycle.RunUntilComplete. The Trace log below fires when the worker is
// built, not when it runs.
//
// The returned closure upgrades link.conn to a websocket and creates a session
// binding with s.factory. If the binding implements session.IdleFrameChecker,
// the closure starts its idle checker. It then calls execute with the binding,
// ws.StateServerSide and s.HasShutdown. If the upgrade fails, it logs at Debug
// level and returns without creating a binding.
//
// Deferred cleanup runs in reverse registration order. First comes
// binding.OnClosed with the error returned by execute (only after a successful
// upgrade). Next, Lifecycle.ExecuteHook(link.conn) presumably runs and
// releases the hook registered for this connection. Last, link.onClose(nil)
// runs.
//
// NOTE(review): ExecuteHook only finds a hook registered under the exact value
// stored in link.conn; confirm the registering side uses that same key.
//
// See: https://groups.google.com/forum/#!topic/golang-nuts/I7a_3B8_9Gw
func (s *service) handle(link *link) func() {
	s.settings.Logger(logging.Trace, "in handle")
	return func() {
		defer link.onClose(nil)
		defer s.settings.Lifecycle.ExecuteHook(link.conn)
		_, err := ws.Upgrade(link.conn)
		if err != nil {
			s.settings.Logger(logging.Debug, "Unable to upgrade connection", err)
			return
		}

		binding := s.factory(link)
		// The closure captures err by reference, so OnClosed sees the value
		// assigned by execute below rather than the (nil) upgrade error.
		defer func() { binding.OnClosed(err) }()
		// Run the binding's idle checker under the lifecycle manager, alongside
		// execute, so it can keep checking whether the connection has gone idle.
		if idleFrameChecker, ok := binding.(session.IdleFrameChecker); ok {
			s.settings.Logger(logging.Trace, "Prepare to run idle checker")
			s.settings.Lifecycle.RunUntilComplete(checkIfIdle(idleFrameChecker.CheckIfIdle, s.settings))
		}
		err = execute(binding, link, ws.StateServerSide, s.settings, s.HasShutdown)
	}
}

// checkIfIdle logs a Trace message and returns idleChecker unchanged, ready to
// be scheduled. The message is emitted when the checker is wrapped, not when
// it actually runs.
func checkIfIdle(idleChecker func(), s *Settings) func() {
	logger := s.Logger
	logger(logging.Trace, "Run idle checker")
	return idleChecker
}

// closeRead returns a lifecycle hook that shuts down the read side of c.
// When the shutdown succeeds, the hook logs the peer address at Debug level
// and returns nil. When it fails, the error is passed through filterClosed,
// so "already closed" errors are dropped and any other error is returned.
func closeRead(c *net.TCPConn, logger logging.Function) func() error {
	hook := func() error {
		if err := c.CloseRead(); err != nil {
			return filterClosed(err)
		}
		logger(logging.Debug, "Closed read:", c.RemoteAddr())
		return nil
	}
	return hook
}

// closeForError picks the websocket close status for a session that ended with
// err.
//
// A closed-connection error, or an error whose text is exactly "EOF", counts as
// an expected close. It maps to StatusGoingAway when shutdown is true and to
// StatusNormalClosure otherwise. Any other error is logged at Debug level
// together with msg and maps to StatusProtocolError. err must be non-nil,
// because err.Error() is called on it.
func closeForError(shutdown bool, err error, msg string, logger logging.Function) ws.StatusCode {
	expected := isClosedError(err) || err.Error() == "EOF"
	switch {
	case expected && shutdown:
		return ws.StatusGoingAway
	case expected:
		return ws.StatusNormalClosure
	}
	logger(logging.Debug, msg, err)
	return ws.StatusProtocolError
}

// filterClosed drops errors that only report an already-closed connection or
// listener and returns every other error (including nil) unchanged.
func filterClosed(err error) error {
	if !isClosedError(err) {
		return err
	}
	return nil
}

// isClosedError reports whether err is the *net.OpError returned when an
// operation is attempted on a connection or listener that has already been
// closed.
//
// It matches on the text of the wrapped error rather than on a sentinel
// because net.ErrClosed only exists since Go 1.16. Its message is the same
// string, so this also works on newer releases.
//
// A typed-nil *net.OpError or an OpError with a nil Err is reported as not
// closed instead of panicking on the nil dereference.
func isClosedError(err error) bool {
	opErr, ok := err.(*net.OpError)
	if !ok || opErr == nil || opErr.Err == nil {
		return false
	}
	return opErr.Err.Error() == "use of closed network connection"
}
