package main

import (
	"context"
	"crypto/sha1"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	_ "net/http/pprof"
	"net/url"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"code.justin.tv/growth/secretconf"
	s2sLog "code.justin.tv/sse/malachai/pkg/log"
	s2sCaller "code.justin.tv/sse/malachai/pkg/s2s/caller"
	asiimov "code.justin.tv/systems/guardian/middleware"
	gorillaContext "github.com/gorilla/context"
	"github.com/gorilla/sessions"
	"github.com/pkg/errors"
	uuid "github.com/satori/go.uuid"
	"golang.org/x/oauth2"
	validator "gopkg.in/validator.v2"
)

// Listen ports, public dashboard URLs and the Sandstorm role prefix that the
// per-environment ServerConfig values below are built from.
const (
	// Ports are formatted into ":<port>" listen addresses (and, for the
	// dashboard port, into the local development redirect URL).
	dashboard_addr                = 8000
	dashboard_debug_addr          = 6000
	// Public base URLs; "/auth" is appended to form the OAuth redirect URL.
	dashboard_staging_url         = "https://eventbus-staging.xarth.tv"
	dashboard_prod_url            = "https://eventbus.xarth.tv"
	// Role ARN prefix; each config appends its environment name.
	dashboard_sandstorm_role_base = "arn:aws:iam::734326455073:role/sandstorm/production/templated/role/eventbus-dashboard-"
	// Port of the local control plane that the development config proxies to.
	controlplane_addr             = 8888
)

// Names of the request headers that embedLDAPHeaders sets on every proxied
// API request.
const (
	// TwitchLDAPUserHeader carries the LDAP user id.
	TwitchLDAPUserHeader   = "Twitch-Ldap-User"
	// TwitchLDAPGroupsHeader carries the user's LDAP groups as a base64
	// encoded JSON array (see encodeLDAPGroups).
	TwitchLDAPGroupsHeader = "Twitch-Ldap-Groups"
)

// ServerConfig holds the per-environment settings that main and setupServer
// use to build the dashboard server.
type ServerConfig struct {
	ProxyURL  string // proxy for all backend services to hit
	// Addr is the listen address of the main dashboard HTTP server.
	Addr      string
	// DebugAddr is the listen address of the second HTTP server started in
	// main, which serves the default mux (net/http/pprof is imported).
	DebugAddr string

	// S2SName is the service name passed to the s2s round tripper.
	S2SName          string
	// S2SEnabled switches the proxy transport to the s2s round tripper.
	S2SEnabled       bool
	// GuardianEnabled turns on login enforcement and secret loading.
	GuardianEnabled  bool
	// GuardianHostname is the host used for the OAuth token and
	// check_token endpoints.
	GuardianHostname string
	// SandstormRoleARN is handed to the secret manager when loading secrets.
	SandstormRoleARN string
	RedirectURL      string // url to redirect to for auth (format: http://url/auth)
	// LoginTTL is how long a login stays valid, in seconds (it is compared
	// against Unix timestamps).
	LoginTTL         int64
	// NonceTTL is how long a login nonce stays valid, in seconds.
	NonceTTL         int64
}

// DevServerConfig is the configuration resolveServerConfig returns for the
// "development" environment: it proxies to a control plane on localhost and
// leaves both S2S and Guardian disabled.
var DevServerConfig = ServerConfig{
	ProxyURL:  fmt.Sprintf("http://localhost:%d/", controlplane_addr),
	Addr:      fmt.Sprintf(":%d", dashboard_addr),
	DebugAddr: fmt.Sprintf(":%d", dashboard_debug_addr),

	S2SName:          "eventbus-dashboard-staging", // use 'staging' if we have s2s turned on at all
	S2SEnabled:       false,
	GuardianEnabled:  false,
	GuardianHostname: "guardian.internal.justin.tv",
	SandstormRoleARN: dashboard_sandstorm_role_base + "development",
	RedirectURL:      fmt.Sprintf("http://localhost:%d/auth", dashboard_addr),
	// TTLs are in seconds: 24 hours for a login, 5 minutes for a nonce.
	LoginTTL:         60 * 60 * 24,
	NonceTTL:         60 * 5,
}

// StageServerConfig is the configuration resolveServerConfig returns for the
// "staging" environment, with both S2S and Guardian enabled.
var StageServerConfig = ServerConfig{
	ProxyURL:  "https://controlplane.staging.eventbus.twitch.a2z.com",
	Addr:      fmt.Sprintf(":%d", dashboard_addr),
	DebugAddr: fmt.Sprintf(":%d", dashboard_debug_addr),

	S2SName:          "eventbus-dashboard-staging",
	S2SEnabled:       true,
	GuardianEnabled:  true,
	GuardianHostname: "twitch-eventbus-dev.us-west-2.prod.guardian.services.twitch.a2z.com",
	SandstormRoleARN: dashboard_sandstorm_role_base + "staging",
	RedirectURL:      dashboard_staging_url + "/auth",
	// TTLs are in seconds: 24 hours for a login, 5 minutes for a nonce.
	LoginTTL:         60 * 60 * 24,
	NonceTTL:         60 * 5,
}

// ProdServerConfig is the configuration resolveServerConfig returns for the
// "prod" environment, with both S2S and Guardian enabled.
var ProdServerConfig = ServerConfig{
	ProxyURL:  "https://controlplane.prod.eventbus.twitch.a2z.com",
	Addr:      fmt.Sprintf(":%d", dashboard_addr),
	DebugAddr: fmt.Sprintf(":%d", dashboard_debug_addr),

	S2SName:          "eventbus-dashboard-production",
	S2SEnabled:       true,
	GuardianEnabled:  true,
	GuardianHostname: "twitch-eventbus-aws.us-west-2.prod.guardian.services.twitch.a2z.com",
	SandstormRoleARN: dashboard_sandstorm_role_base + "production",
	RedirectURL:      dashboard_prod_url + "/auth",
	// TTLs are in seconds: 24 hours for a login, 5 minutes for a nonce.
	LoginTTL:         60 * 60 * 24,
	NonceTTL:         60 * 5,
}

// Secrets is the set of secret values that mustLoadSecrets fills in through
// secretconf (using the `secret` tag names) and then checks with the
// validator package (every field is tagged `nonzero`).
type Secrets struct {
	// OAuth client credentials used to build the Guardian oauth2.Config.
	GuardianClientID     string `validate:"nonzero" secret:"guardian_client_id"`
	GuardianClientSecret string `validate:"nonzero" secret:"guardian_client_secret"`
	// Key material handed to sessions.NewCookieStore.
	CookieSecret         string `validate:"nonzero" secret:"cookie_secret"`
}

// Server carries the runtime state shared by the dashboard's HTTP handlers.
// It is built once by setupServer.
type Server struct {
	// ProxyURL is the backend URL the proxy was created for.
	ProxyURL string
	// Proxy is the handler that authenticated API requests are forwarded to.
	Proxy    http.Handler

	// GuardianEnabled controls whether requests must carry a valid login
	// session; when false every request is treated as authenticated.
	GuardianEnabled bool
	// OauthConfig is only populated when Guardian is enabled (nil otherwise).
	OauthConfig     *oauth2.Config
	// CookieStore holds the "auth-session" cookie session.
	CookieStore     *sessions.CookieStore
	// LoginTTL and NonceTTL are lifetimes in seconds.
	LoginTTL        int64
	NonceTTL        int64
	// CheckTokenURL is the Guardian endpoint used to validate a freshly
	// exchanged token.
	CheckTokenURL   string
}

// Vars is the payload served as JSON by handleVarsFunc.
type Vars struct {
	// Env is the environment name main resolved from ENVIRONMENT.
	Env string `json:"env"`
}

// URL paths and mux patterns registered by main. The first four are typed
// string constants; the last two are untyped, exactly as before.
const (
	// API_PREFIX is prepended to every API route and stripped again before
	// a request is handed to the reverse proxy.
	API_PREFIX string = "/api"
	// VARS_PATH and USER_LOOKUP_PATH are registered under API_PREFIX.
	VARS_PATH        string = "/vars"
	USER_LOOKUP_PATH string = "/current-user"
	// DEFAULT_PATTERN is the catch-all pattern.
	DEFAULT_PATTERN string = "/"

	// AUTH_PATH is the login endpoint unauthenticated requests are sent to.
	AUTH_PATH = "/auth"
	// HEALTH_PATH is the health check endpoint.
	HEALTH_PATH = "/health"
)

// main resolves the environment from ENVIRONMENT (defaulting to
// "development"), builds the Server, registers the routes and runs two HTTP
// servers: the dashboard itself and a debug server on the default mux (which
// exposes net/http/pprof). On SIGINT/SIGTERM both servers are shut down
// gracefully with a 5 second deadline.
func main() {
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)

	env := os.Getenv("ENVIRONMENT")
	if env == "" {
		env = "development"
	}
	log.Printf("Running with ENVIRONMENT: %v", env)

	serverConfig := resolveServerConfig(env)
	server := setupServer(serverConfig, env)
	vars := &Vars{Env: env}

	mux := http.NewServeMux()
	mux.HandleFunc(API_PREFIX+VARS_PATH, handleVarsFunc(vars))
	mux.HandleFunc(API_PREFIX+USER_LOOKUP_PATH, server.handleUserLookupFnc())
	mux.HandleFunc(API_PREFIX+DEFAULT_PATTERN, server.handleProxyRequests)
	mux.HandleFunc(AUTH_PATH, server.handleAuth)
	mux.HandleFunc(HEALTH_PATH, server.handleHealthCheck)
	mux.HandleFunc(DEFAULT_PATTERN, server.handleServingHTML)

	h := &http.Server{Addr: serverConfig.Addr, Handler: gorillaContext.ClearHandler(mux)}
	p := &http.Server{Addr: serverConfig.DebugAddr}
	go func() {
		// ListenAndServe returns http.ErrServerClosed as soon as Shutdown is
		// called; that is the normal exit path and must not abort the
		// process while in-flight requests are still draining.
		if err := h.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("Listen and serve error: %v", err)
		}
	}()
	go func() {
		if err := p.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("Pprof listen and serve error: %v", err)
		}
	}()

	<-stop
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := h.Shutdown(ctx); err != nil {
		log.Printf("dashboard server shutdown error: %v", err)
	}
	if err := p.Shutdown(ctx); err != nil {
		log.Printf("pprof server shutdown error: %v", err)
	}
}

// handleAuth is the OAuth entry point and callback.
//
//   - An "error=access_denied" form value renders the forbidden page.
//   - An already authenticated request is redirected to "/".
//   - A request whose "state" matches a live nonce in the session completes
//     the login via setupAuth and is redirected to "/".
//   - Anything else gets a new nonce and is redirected to the OAuth provider.
func (s *Server) handleAuth(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.FormValue("error") == "access_denied":
		s.handleForbidden(w, r)
		return
	case s.isAuthenticated(r):
		http.Redirect(w, r, "/", http.StatusFound)
		return
	}

	state := r.FormValue("state")
	session, err := s.CookieStore.Get(r, "auth-session")
	if err != nil {
		// Logged only; the flow continues with the session that was returned.
		log.Printf("failed to get cookies: %s\n", err)
	}

	if validateNonce(state, session, s.NonceTTL) {
		if authErr := s.setupAuth(w, r, session); authErr != nil {
			log.Printf("failed to setup auth: %s\n", authErr)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		http.Redirect(w, r, "/", http.StatusFound)
		return
	}

	// No usable nonce: start a new round trip through the OAuth provider.
	freshState, err := generateNonce(w, r, session)
	if err != nil {
		log.Printf("failed to generate nonce: %s\n", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	authURL := s.OauthConfig.AuthCodeURL(freshState, oauth2.AccessTypeOnline)
	http.Redirect(w, r, authURL, http.StatusFound)
}

// setupAuth completes a login: it exchanges the "code" form value for a
// token, validates the token against CheckTokenURL, and stores the login
// time, uid, email and encoded LDAP groups in the session before saving it.
//
// The exchange is bound to the request context so it is abandoned when the
// client goes away. The nonce that authorised this callback is removed from
// the session so the saved cookie no longer carries it.
//
// It returns an error if the exchange, the token check, the group encoding
// or the session save fails, or if the token check yields no user.
func (s *Server) setupAuth(w http.ResponseWriter, r *http.Request, session *sessions.Session) error {
	authCode := r.FormValue("code")
	token, err := s.OauthConfig.Exchange(r.Context(), authCode)
	if err != nil {
		return errors.Wrap(err, "could not exchange auth code for token")
	}
	authHandler := asiimov.New(nil, s.OauthConfig)
	authHandler.CheckTokenURL = s.CheckTokenURL
	tc, err := authHandler.CheckToken(token)
	if err != nil {
		return errors.Wrap(err, "could not check token")
	}

	if tc == nil || tc.User == nil {
		return errors.New("using token didn't yield a valid User")
	}

	encodedLDAPGroups, err := encodeLDAPGroups(tc.User.Groups)
	if err != nil {
		return errors.Wrap(err, "could not encode ldap groups for session")
	}

	// The nonce has served its purpose; drop it so it is single-use.
	delete(session.Values, "nonce")
	session.Values["login-time"] = time.Now().Unix()
	session.Values["login-uid"] = tc.User.UID
	session.Values["login-email"] = tc.User.Email
	session.Values["login-groups"] = encodedLDAPGroups

	if err := session.Save(r, w); err != nil {
		return errors.Wrap(err, "failed to save session")
	}

	return nil
}

// isAuthenticated reports whether the request carries a login that has not
// yet expired. With Guardian disabled every request counts as authenticated.
// A missing, wrongly typed or expired "login-time" value, or a session that
// cannot be read, yields false.
func (s *Server) isAuthenticated(r *http.Request) bool {
	if !s.GuardianEnabled {
		return true
	}

	session, err := s.CookieStore.Get(r, "auth-session")
	if err != nil {
		log.Printf("failed to get cookies: %s", err)
		return false
	}

	// A missing key yields a nil interface, which fails the assertion too.
	loggedInAt, ok := session.Values["login-time"].(int64)
	if !ok {
		return false
	}
	return time.Now().Unix() < loggedInAt+s.LoginTTL
}

// handleProxyRequests forwards an API request to the backend proxy after
// stamping it with the caller's LDAP user and groups headers.
//
// Unauthenticated requests are redirected to AUTH_PATH and are NOT proxied.
// A failure to resolve the LDAP information results in a 500 response.
func (s *Server) handleProxyRequests(w http.ResponseWriter, r *http.Request) {
	// API requests must be authenticated
	if !s.isAuthenticated(r) {
		http.Redirect(w, r, AUTH_PATH, http.StatusFound)
		// http.Redirect only writes the response; without this return the
		// request would still be forwarded to the backend.
		return
	}

	if err := s.embedLDAPHeaders(r); err != nil {
		internalServerError(w, r, err)
		return
	}

	log.Printf("handling proxy request %v", r.URL.Path)
	s.Proxy.ServeHTTP(w, r)
}

// embedLDAPHeaders overwrites the Twitch LDAP user and groups headers on r.
// The values come from the login session when Guardian is enabled and from
// the local ".ldap.json" file otherwise. Any lookup failure is returned
// wrapped, and the headers are left untouched in that case.
func (s *Server) embedLDAPHeaders(r *http.Request) error {
	lookup := s.ldapInfoFile
	if s.GuardianEnabled {
		lookup = s.ldapInfoGuardian
	}

	ldapUser, ldapGroups, err := lookup(r)
	if err != nil {
		return errors.Wrap(err, "could not get ldap info")
	}

	// Set (not Add) so any client-supplied values are replaced.
	headers := r.Header
	headers.Set(TwitchLDAPUserHeader, ldapUser)
	headers.Set(TwitchLDAPGroupsHeader, ldapGroups)
	return nil
}

// ldapInfoGuardian reads the LDAP user id and the encoded LDAP groups that
// setupAuth stored in the "auth-session" session. It returns an error when
// the session cannot be fetched or when either value is absent or not a
// string.
func (s *Server) ldapInfoGuardian(r *http.Request) (string, string, error) {
	session, err := s.CookieStore.Get(r, "auth-session")
	if err != nil {
		return "", "", errors.Wrap(err, "could not fetch auth cookie")
	}

	uid, isString := session.Values["login-uid"].(string)
	if !isString {
		return "", "", errors.New("could not get ldap user from session")
	}

	groups, isString := session.Values["login-groups"].(string)
	if !isString {
		return "", "", errors.New("could not get ldap groups from session")
	}

	return uid, groups, nil
}

// ldapInfoFile loads the LDAP user and groups from the ".ldap.json" file in
// the working directory (a JSON object with "user" and "groups" keys) and
// returns the user together with the encoded groups. The request argument is
// not used; the file is read on every call.
func (s *Server) ldapInfoFile(r *http.Request) (string, string, error) {
	raw, err := ioutil.ReadFile(".ldap.json")
	if err != nil {
		return "", "", err
	}

	var info struct {
		User   string   `json:"user"`
		Groups []string `json:"groups"`
	}
	if err := json.Unmarshal(raw, &info); err != nil {
		return "", "", err
	}

	encoded, err := encodeLDAPGroups(info.Groups)
	if err != nil {
		return "", "", errors.Wrap(err, "could not encode ldap groups from file")
	}
	return info.User, encoded, nil
}

// nonceDelimiter separates the random part of a nonce from the Unix
// timestamp appended by generateNonce; validateNonce splits on it.
const nonceDelimiter = ":"

// generateNonce creates a nonce of the form "<uuid>:<unix seconds>", stores
// it under "nonce" in the session, saves the session, and returns the hash
// of the nonce (see hashString) for use as the OAuth state parameter.
func generateNonce(w http.ResponseWriter, r *http.Request, session *sessions.Session) (string, error) {
	random := uuid.NewV4().String()
	issuedAt := strconv.FormatInt(time.Now().Unix(), 10)
	nonce := strings.Join([]string{random, issuedAt}, nonceDelimiter)

	// Warning: Storing nonces as secure cookies is nonstandard and exposes us to a replay attack.
	// In an actual high-security application, these should be stored in Redis on the server side.
	session.Values["nonce"] = nonce

	saveErr := session.Save(r, w)
	if saveErr != nil {
		return "", fmt.Errorf("failed to save session: %s", saveErr.Error())
	}
	return hashString(nonce), nil
}

// validateNonce reports whether requestNonce is the hash of the nonce held
// in the session and that nonce was issued less than ttl seconds ago.
// A missing, non-string or malformed nonce is logged and rejected.
func validateNonce(requestNonce string, session *sessions.Session, ttl int64) bool {
	stored, present := session.Values["nonce"]
	if !present {
		log.Println("missing nonce cookie")
		return false
	}

	nonce, isString := stored.(string)
	if !isString {
		log.Printf("got wrong nonce cookie type for %v", stored)
		return false
	}

	// Exactly one delimiter means exactly two parts: "<random>:<timestamp>".
	if strings.Count(nonce, nonceDelimiter) != 1 {
		log.Printf("bad nonce format for %s", nonce)
		return false
	}
	rawTimestamp := nonce[strings.Index(nonce, nonceDelimiter)+len(nonceDelimiter):]

	issuedAt, err := strconv.ParseInt(rawTimestamp, 10, 64)
	if err != nil {
		log.Printf("failed to convert timestamp: %s", err.Error())
		return false
	}

	if time.Now().Unix() >= issuedAt+ttl {
		return false
	}
	return requestNonce == hashString(nonce)
}

// hashString returns the SHA-1 digest of input, encoded with the URL-safe
// base64 alphabet (padding included).
func hashString(input string) string {
	digest := sha1.Sum([]byte(input))
	return base64.URLEncoding.EncodeToString(digest[:])
}

// setupServer builds the Server for the given configuration. When Guardian
// is enabled it first loads and validates the secrets (terminating the
// process on failure) and derives the OAuth configuration from them; when it
// is disabled OauthConfig stays nil and the cookie store is created from an
// empty secret.
func setupServer(serverConfig ServerConfig, env string) *Server {
	var (
		secrets     Secrets
		oauthConfig *oauth2.Config
	)
	guardianBase := "https://" + serverConfig.GuardianHostname + "/oauth2/"

	if serverConfig.GuardianEnabled {
		mustLoadSecrets(env, serverConfig.SandstormRoleARN, &secrets)

		endpoint := oauth2.Endpoint{
			AuthURL:  asiimov.AuthURL,
			TokenURL: guardianBase + "token",
		}
		oauthConfig = &oauth2.Config{
			ClientID:     secrets.GuardianClientID,
			ClientSecret: secrets.GuardianClientSecret,
			Endpoint:     endpoint,
			RedirectURL:  serverConfig.RedirectURL,
		}
	}

	proxy := createProxy(serverConfig.ProxyURL, serverConfig.S2SEnabled, serverConfig.S2SName)
	cookieStore := sessions.NewCookieStore([]byte(secrets.CookieSecret))

	return &Server{
		ProxyURL: serverConfig.ProxyURL,
		Proxy:    *proxy,

		GuardianEnabled: serverConfig.GuardianEnabled,
		OauthConfig:     oauthConfig,
		CookieStore:     cookieStore,
		LoginTTL:        serverConfig.LoginTTL,
		NonceTTL:        serverConfig.NonceTTL,
		CheckTokenURL:   guardianBase + "check_token",
	}
}

// resolveServerConfig returns the ServerConfig for env, which must be one of
// "development", "staging" or "prod". Any other value panics.
func resolveServerConfig(env string) ServerConfig {
	if env == "development" {
		return DevServerConfig
	}
	if env == "staging" {
		return StageServerConfig
	}
	if env == "prod" {
		return ProdServerConfig
	}
	panic("could not resolve App config for env " + env)
}

// mustLoadSecrets fills conf with the dashboard's secrets for env and then
// validates it. An unknown env panics; a load or validation failure
// terminates the process via log.Fatalf.
func mustLoadSecrets(env string, sandstormRoleARN string, conf interface{}) {
	// Dashboard environment name -> Sandstorm environment name.
	sandstormEnvs := map[string]string{
		"development": "development",
		"staging":     "staging",
		"prod":        "production",
	}
	sandstormEnv, known := sandstormEnvs[env]
	if !known {
		panic("could not resolve sandstorm env given " + env)
	}

	manager := secretconf.NewManager(sandstormRoleARN)
	if err := secretconf.Load(conf, manager, "eventbus", "dashboard", sandstormEnv); err != nil {
		log.Fatalf("failed to initialize secrets: %s", err.Error())
	}

	if err := validator.Validate(conf); err != nil {
		log.Fatalf("failed to validate secrets: %s", err.Error())
	}
}

// createProxy builds the handler that forwards API requests to inputURL.
// The API_PREFIX is stripped from the request path before proxying. When
// s2sEnabled is true the proxy's transport is an s2s round tripper for
// s2sName, wrapped so that forwarding headers are removed first. An
// unparsable URL or a failure to create the s2s round tripper terminates the
// process.
func createProxy(inputURL string, s2sEnabled bool, s2sName string) *http.Handler {
	target, parseErr := url.Parse(inputURL)
	if parseErr != nil {
		log.Fatalf("failed to parse url from request: %v", parseErr)
	}

	reverseProxy := httputil.NewSingleHostReverseProxy(target)
	if s2sEnabled {
		s2sTransport, err := s2sCaller.NewRoundTripper(
			s2sName,
			&s2sCaller.Config{DisableStatsClient: true},
			&s2sLog.NoOpLogger{}, // TODO: real logger
		)
		if err != nil {
			log.Fatalf("failed to initialize s2s roundtripper: %s\n", err.Error())
		}
		reverseProxy.Transport = newStripXRoundTripper(s2sTransport)
	}

	var handler http.Handler = http.StripPrefix(API_PREFIX, reverseProxy)
	return &handler
}

// handleVarsFunc returns a handler that serves vars as a JSON document.
//
// If vars cannot be marshalled the handler logs the error and replies with a
// 500 instead of sending an empty 200 response. A failed write is logged.
func handleVarsFunc(vars *Vars) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		b, err := json.Marshal(vars)
		if err != nil {
			log.Printf("failed to marshal vars: %v", err)
			http.Error(w, "failed to marshal vars", http.StatusInternalServerError)
			return
		}

		// Set the content type only once we know we have a JSON body to send.
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write(b); err != nil {
			log.Printf("vars write failed: %v", err)
		}
	}
}

// handleUserLookupFnc returns a handler that responds with the LDAP user id
// from the caller's login session, encoded as a JSON string.
//
// When the user cannot be read from the session the handler replies 404 and
// stops, so no JSON body is appended to the error response. A marshalling
// failure yields a 500. A failed write can only be logged, because the
// response has already been started at that point.
func (s *Server) handleUserLookupFnc() func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		ldapUser, _, err := s.ldapInfoGuardian(r)
		if err != nil {
			http.Error(w, errors.Wrap(err, "could not get ldap info").Error(), http.StatusNotFound)
			return
		}

		b, err := json.Marshal(ldapUser)
		if err != nil {
			internalServerError(w, r, errors.Wrap(err, "could not marshal ldapUser"))
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write(b); err != nil {
			log.Printf("could not write current user response: %v", err)
		}
	}
}

// handleHealthCheck always answers 200 with the body "OK". It performs no
// authentication and inspects nothing on the request.
func (s *Server) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
	const body = "OK"

	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(body)) // the write result is ignored
}

// handleForbidden renders the "Access Denied" HTML page with a 403 status.
// The content type is set explicitly rather than left to sniffing, the final
// paragraph is properly closed, and a failed write is logged.
func (s *Server) handleForbidden(w http.ResponseWriter, r *http.Request) {
	const page = "<h1>Access Denied</h1>" +
		"<p>Unfortunately, we have to deny you access aboard the EventBus. Perhaps you forgot your bus pass? Being a member of the 'infra' group in LDAP is required before riding!</p>" +
		`<p>For more information, check out the <a href="https://git.xarth.tv/pages/eventbus/docs/troubleshooting/">troubleshooting documentation</a>.</p>`

	// Headers must be set before WriteHeader for them to be sent.
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusForbidden)
	if _, err := w.Write([]byte(page)); err != nil {
		log.Printf("failed to write forbidden page: %v", err)
	}
}

// handleBundle serves files from "./dist" with response headers that tell
// clients not to cache the result.
func (s *Server) handleBundle(w http.ResponseWriter, r *http.Request) {
	noCache := [][2]string{
		{"Cache-Control", "no-cache, no-store, must-revalidate"},
		{"Pragma", "no-cache"},
		{"Expires", "0"},
	}
	for _, kv := range noCache {
		w.Header().Add(kv[0], kv[1])
	}

	static := http.FileServer(http.Dir("./dist"))
	static.ServeHTTP(w, r)
}

// handleServingHTML serves the dashboard front end.
//
// Unauthenticated requests are redirected to AUTH_PATH and receive nothing
// else. For paths accepted by isAllowedPath the handler responds with
// "./dist/index.html" and no-cache headers (presumably so client-side routes
// resolve to the app shell — confirm with the front end); every other path
// is served as a static file from "./dist".
func (s *Server) handleServingHTML(w http.ResponseWriter, r *http.Request) {
	if !s.isAuthenticated(r) {
		http.Redirect(w, r, AUTH_PATH, http.StatusFound)
		// http.Redirect only writes the response; without this return the
		// page content would be appended to the redirect body.
		return
	}

	log.Printf("handle request for path %v", r.URL.Path)
	if !isAllowedPath(r.URL.Path) {
		http.FileServer(http.Dir("./dist")).ServeHTTP(w, r)
		return
	}

	b, err := ioutil.ReadFile("./dist/index.html")
	if err != nil {
		log.Printf("failed to read index.html: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Add("Content-Type", "text/html")
	w.Header().Add("Cache-Control", "no-cache, no-store, must-revalidate")
	w.Header().Add("Pragma", "no-cache")
	w.Header().Add("Expires", "0")

	if _, err := w.Write(b); err != nil {
		log.Printf("failed to write contents: %v", err)
	}
}

// isAllowedPath reports whether url is exactly "/" or begins with one of
// "/services", "/events" or "/docs". The match is a plain string prefix
// test, so "/servicesfoo" is accepted as well.
func isAllowedPath(url string) bool {
	if url == "/" {
		return true
	}
	for _, prefix := range []string{"/services", "/events", "/docs"} {
		if strings.HasPrefix(url, prefix) {
			return true
		}
	}
	return false
}

// stripXRoundTripper is an http.RoundTripper that removes forwarding and
// tracing headers from a request before delegating to the wrapped transport.
type stripXRoundTripper struct {
	next http.RoundTripper
}

// newStripXRoundTripper wraps next in a stripXRoundTripper.
func newStripXRoundTripper(next http.RoundTripper) http.RoundTripper {
	wrapped := new(stripXRoundTripper)
	wrapped.next = next
	return wrapped
}

// RoundTrip deletes the X-Forwarded-For, X-Forwarded-Proto, X-Forwarded-Port
// and X-Amzn-Trace-Id headers from r in place and then hands r to the
// wrapped transport, returning its result unchanged.
func (s *stripXRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	for _, name := range [...]string{
		"X-Forwarded-For",
		"X-Forwarded-Proto",
		"X-Forwarded-Port",
		"X-Amzn-Trace-Id",
	} {
		r.Header.Del(name)
	}
	return s.next.RoundTrip(r)
}

// internalServerError logs err and replies to the request with a 500 whose
// body is "internal server error: <err>". err must be non-nil.
func internalServerError(w http.ResponseWriter, r *http.Request, err error) {
	// this is a sad path but unexpected path, just kill the request here
	errStr := fmt.Sprintf("internal server error: %s", err.Error())
	// Print, not Printf: the message is data, not a format string, so any
	// '%' in the error text must be logged verbatim.
	log.Print(errStr)
	http.Error(w, errStr, http.StatusInternalServerError)
}

// encodeLDAPGroups serialises groups as a JSON array and returns that JSON
// in standard (padded) base64. A nil slice is encoded as JSON null.
func encodeLDAPGroups(groups []string) (string, error) {
	asJSON, marshalErr := json.Marshal(groups)
	if marshalErr == nil {
		return base64.StdEncoding.EncodeToString(asJSON), nil
	}
	return "", marshalErr
}
