package main

import (
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"strconv"

	"ting/util"
)

// Config holds every setting for the ting extension backend service: the port
// to serve on, the database URL, and the TLS, JWT, and logging settings.
// ParseArgv builds it from DefaultConfig, environment variables (WithEnv), and
// command-line flags, then checks it with Validate.
type Config struct {
	Port    int
	DbUrl   string
	TLS     TLSConfig
	JWT     JWTConfig
	Logging LoggingConfig
}

// TLSConfig holds the paths of the TLS certificate and key files.
// TLS is in use only when both paths are non-empty; Validate rejects
// configurations where exactly one of them is set.
type TLSConfig struct {
	CertFile, KeyFile string
}

// tlsArgErr is returned by TLSConfig.Validate when only one of the cert/key file paths is set.
var tlsArgErr = errors.New(`must provide both "tls_cert_file" and "tls_key_file" (or neither)`)

// Validate checks that CertFile and KeyFile are either both empty (TLS off)
// or both set. When both are set, each path must be stat-able.
//
// It returns tlsArgErr if exactly one path is set. A stat failure is returned
// as "<path>: <cause>" with the cause wrapped (%w), so callers can still test
// it with errors.Is, e.g. errors.Is(err, os.ErrNotExist).
func (c TLSConfig) Validate() error {
	switch {
	case util.Xor(c.CertFile == "", c.KeyFile == ""):
		// Only one of the pair was provided.
		return tlsArgErr
	case c.CertFile == "":
		// Neither was provided: TLS is disabled, nothing to check.
		return nil
	}
	if _, err := os.Stat(c.CertFile); err != nil {
		return fmt.Errorf("%s: %w", c.CertFile, err)
	}
	if _, err := os.Stat(c.KeyFile); err != nil {
		return fmt.Errorf("%s: %w", c.KeyFile, err)
	}
	return nil
}

// Enabled reports whether TLS is configured, which requires both CertFile
// and KeyFile to be non-empty.
func (c TLSConfig) Enabled() bool {
	return !(c.CertFile == "" || c.KeyFile == "")
}

// JWTConfig controls JWT handling. Unless Disabled is set, exactly one of
// ExtensionID, Secret, or SecretFile must be non-empty (see Validate), and
// GetSecret resolves the shared secret from whichever one is set. Resolving
// via ExtensionID is not implemented yet; GetSecret returns an error for it.
type JWTConfig struct {
	// "Disabled" instead of "Enabled" so that zero value with `false` does the safer thing.
	Disabled bool

	// Prod should set `ExtensionID` so that we can fetch the shared secret from Twitch API;
	// `Secret` and `SecretFile` are for dev... or if we don't get around to properly fetching based on `ExtensionID`. >_>
	// If JWT is enabled, exactly 1 of these must be set (i.e. not the empty string).
	ExtensionID string
	Secret      string
	SecretFile  string

	// The actual secret obtained by one of the above 3 methods, cached here by GetSecret.
	// Empty until GetSecret has resolved it.
	secret string
}

// jwtArgErr reports a JWT configuration that does not name exactly one secret source.
var jwtArgErr = errors.New(`expected exactly ONE "jwt_*" CLI arg or env var`)

// Validate returns the first problem found in c. It checks, in order, that
// Port is within 0-65535, then the TLS settings (TLSConfig.Validate), then
// the JWT settings (JWTConfig.Validate). It returns nil if all checks pass.
func (c Config) Validate() error {
	const maxPort = 65535
	if c.Port < 0 || c.Port > maxPort {
		return fmt.Errorf("invalid port: %d", c.Port)
	}
	if err := c.TLS.Validate(); err != nil {
		return err
	}
	return c.JWT.Validate()
}

// Validate accepts a disabled JWT config unconditionally. Otherwise it
// requires exactly one of ExtensionID, Secret, and SecretFile to be
// non-empty, returning jwtArgErr if none or several are set.
func (c JWTConfig) Validate() error {
	if c.Disabled {
		return nil
	}
	sources := 0
	for _, v := range [...]string{c.ExtensionID, c.Secret, c.SecretFile} {
		if v != "" {
			sources++
		}
	}
	if sources != 1 {
		return jwtArgErr
	}
	return nil
}

// GetSecret returns the JWT shared secret. On first use it resolves the
// secret from whichever of Secret or SecretFile is set, then caches it in
// c.secret so later calls return it without another lookup.
//
// It returns an error in these cases:
//   - JWT handling is disabled.
//   - ExtensionID is set (fetching by Extension ID is not supported yet).
//   - No secret source is configured (jwtArgErr).
//   - SecretFile cannot be read, or is empty.
func (c *JWTConfig) GetSecret() (string, error) {
	switch {
	case c.secret != "":
		return c.secret, nil
	case c.Disabled:
		return "", errors.New("JWT middleware is disabled")
	case c.ExtensionID != "":
		return "", errors.New("fetching shared secret by Extension ID is not yet supported")
	case c.Secret != "":
		c.secret = c.Secret
		return c.secret, nil
	case c.SecretFile == "":
		return "", jwtArgErr
	}

	buf, err := ioutil.ReadFile(c.SecretFile)
	if err != nil {
		return "", err
	}
	if len(buf) == 0 {
		return "", errors.New("shared secret file is empty")
	}
	// NOTE(review): the file contents are used verbatim, so a trailing newline
	// (as most editors write) becomes part of the secret — confirm the consumer
	// tolerates that.
	// Cache it like the Secret branch does, so the file is read only once.
	c.secret = string(buf)
	return c.secret, nil
}

// LoggingConfig selects the log output format and the minimum log level.
// The -logging_format flag documents the formats "text" (color), "plain"
// (no color), and "json"; DefaultConfig uses "auto".
type LoggingConfig struct {
	Format, Level string
}

// DefaultConfig returns the built-in settings, which are meant for dev & test
// environments:
//   - Port 3000.
//   - A local "ting_test" Postgres database.
//   - No TLS files.
//   - JWT not disabled, with no secret source chosen.
//   - "auto"-format logging at "debug" level.
//
// ParseArgv layers environment variables and command-line flags on top of it.
func DefaultConfig() Config {
	var cfg Config
	cfg.Port = 3000
	cfg.DbUrl = "postgres://ting@localhost:5432/ting_test?sslmode=disable"
	cfg.Logging.Format = "auto"
	cfg.Logging.Level = "debug"
	// TLS and JWT deliberately keep their zero values: no cert/key files,
	// JWT enabled, and no ExtensionID/Secret/SecretFile selected yet.
	return cfg
}

// WithEnv returns a copy of c with fields overridden by whichever of these
// environment variables are set. A variable set to the empty string still
// overrides.
//
//	PORT             -> Port (must parse as an int)
//	DATABASE_URL     -> DbUrl
//	JWT_DISABLED     -> JWT.Disabled (must parse with strconv.ParseBool)
//	JWT_EXTENSION_ID -> JWT.ExtensionID
//	JWT_SECRET       -> JWT.Secret
//	JWT_SECRET_FILE  -> JWT.SecretFile
//	TLS_CERT_FILE    -> TLS.CertFile
//	TLS_KEY_FILE     -> TLS.KeyFile
//	LOGGING_FORMAT   -> Logging.Format
//	LOGGING_LEVEL    -> Logging.Level
//
// The result is not validated; call Validate for that. If $PORT or
// $JWT_DISABLED fails to parse, an error is returned and the accompanying
// Config is only partially updated, so it should not be used.
func (c Config) WithEnv() (Config, error) {
	if s, found := os.LookupEnv("PORT"); found {
		port, err := strconv.Atoi(s)
		if err != nil {
			return c, fmt.Errorf("error parsing $PORT as int: %q", s)
		}
		c.Port = port
	}

	setFromEnv("DATABASE_URL", &c.DbUrl)

	if s, found := os.LookupEnv("JWT_DISABLED"); found {
		disabled, err := strconv.ParseBool(s)
		if err != nil {
			return c, fmt.Errorf("error parsing $JWT_DISABLED as bool: %q", s)
		}
		c.JWT.Disabled = disabled
	}

	setFromEnv("JWT_EXTENSION_ID", &c.JWT.ExtensionID)
	setFromEnv("JWT_SECRET", &c.JWT.Secret)
	setFromEnv("JWT_SECRET_FILE", &c.JWT.SecretFile)
	setFromEnv("TLS_CERT_FILE", &c.TLS.CertFile)
	setFromEnv("TLS_KEY_FILE", &c.TLS.KeyFile)
	setFromEnv("LOGGING_FORMAT", &c.Logging.Format)
	setFromEnv("LOGGING_LEVEL", &c.Logging.Level)

	return c, nil
}

// setFromEnv overwrites *dst with the value of environment variable name if it
// is set (even to the empty string); otherwise *dst is left unchanged.
func setFromEnv(name string, dst *string) {
	if v, found := os.LookupEnv(name); found {
		*dst = v
	}
}

// usageFmt is the header that ParseArgv's flag-set Usage function prints
// before the flag defaults; %s is replaced with the program's base name.
const usageFmt = `usage: %s [-port PORT] [-db DB_URL]
Starts the extension backend service for ting.

Options:`

// usage is usageFmt filled in with the base name of os.Args[0]. It is built
// during package variable initialization rather than in an init function.
// NOTE(review): path.Base only splits on '/', so on Windows the full
// executable path may be shown; filepath.Base would be OS-aware.
var usage = fmt.Sprintf(usageFmt, path.Base(os.Args[0]))

// ParseArgv builds the service configuration by layering three sources, each
// overriding the previous one:
//  1. DefaultConfig.
//  2. Environment variables (see Config.WithEnv).
//  3. The command-line flags in args.
//
// args should not include the program name, because flag.FlagSet.Parse
// treats every element as a flag or argument.
//
// JWT flags: -jwt_disabled (when true), -jwt_extension_id, -jwt_secret, and
// -jwt_secret_file are mutually exclusive. The one given replaces all JWT
// settings that came from the environment, and giving more than one yields
// jwtArgErr. -jwt_disabled=false is not counted among them; it only
// re-enables JWT handling (e.g. overriding $JWT_DISABLED) and keeps the
// other JWT settings.
//
// TLS flag: -tls_disabled (when true) clears both TLS file paths, whether
// they came from the environment or from -tls_cert_file / -tls_key_file.
//
// The final Config is checked with Config.Validate. On any error, the
// returned Config should not be used.
func ParseArgv(args []string) (cfg Config, err error) {
	if cfg, err = DefaultConfig().WithEnv(); err != nil {
		return cfg, err
	}

	fs := flag.NewFlagSet("main", flag.ContinueOnError)
	fs.Usage = func() {
		fmt.Fprintln(os.Stderr, usage)
		fs.PrintDefaults()
	}

	// Every method of parsing flags besides `FlagSet.Visit()` can't tell you
	// whether each flag was explicitly given vs whether the default is being used. x_x
	// So the flags are registered here only for parsing and help text; their
	// values are read back in the Visit callback below.

	fs.Int("port", cfg.Port, "port to serve; overrides $PORT")
	fs.String("db", cfg.DbUrl, "database URL; overrides $DATABASE_URL")

	fs.Bool("tls_disabled", false, "disable TLS (i.e. ignore $TLS_* vars)")
	fs.String("tls_cert_file", cfg.TLS.CertFile, `"tls_key_file" must also be provided`)
	fs.String("tls_key_file", cfg.TLS.KeyFile, `"tls_cert_file" must also be provided`)

	fs.Bool("jwt_disabled", cfg.JWT.Disabled, "disable everything to do with JWTs")
	fs.String("jwt_extension_id", cfg.JWT.ExtensionID, "Extension ID to use for JWTs")
	fs.String("jwt_secret", cfg.JWT.Secret, "Base64-encoded shared secret for validating/signing JWTs")
	fs.String("jwt_secret_file", cfg.JWT.SecretFile, "file containing base64-encoded JWT shared secret.")

	fs.String("logging_format", cfg.Logging.Format, "text (color), plain (no color), or json")
	fs.String("logging_level", cfg.Logging.Level, "minimum log level")

	if err = fs.Parse(args); err != nil {
		return cfg, err
	}

	tlsDisabled := false
	jwtArgFound := false

	// setJWT applies one of the mutually exclusive JWT flags. The first one
	// found replaces all JWT settings from the environment, so CLI args
	// supersede env vars; any further one is an error.
	setJWT := func(jwt JWTConfig) {
		if jwtArgFound {
			err = jwtArgErr
			return
		}
		jwtArgFound = true
		cfg.JWT = jwt
	}

	// Visit calls this only for flags actually set on the command line, in
	// lexicographical order of their names.
	fs.Visit(func(f *flag.Flag) {
		if err != nil {
			return
		}
		switch f.Name {
		case "port":
			cfg.Port = f.Value.(flag.Getter).Get().(int)
		case "db":
			cfg.DbUrl = f.Value.String()

		case "tls_disabled":
			// Applied after Visit so it wins regardless of flag ordering,
			// and only when true (-tls_disabled=false is a no-op).
			tlsDisabled = f.Value.(flag.Getter).Get().(bool)
		case "tls_cert_file":
			cfg.TLS.CertFile = f.Value.String()
		case "tls_key_file":
			cfg.TLS.KeyFile = f.Value.String()

		case "jwt_disabled":
			if f.Value.(flag.Getter).Get().(bool) {
				setJWT(JWTConfig{Disabled: true})
			} else {
				// Explicitly re-enabling JWT does not pick a secret source,
				// so it neither clears nor conflicts with the other JWT args.
				cfg.JWT.Disabled = false
			}
		case "jwt_extension_id":
			setJWT(JWTConfig{ExtensionID: f.Value.String()})
		case "jwt_secret":
			setJWT(JWTConfig{Secret: f.Value.String()})
		case "jwt_secret_file":
			setJWT(JWTConfig{SecretFile: f.Value.String()})

		case "logging_format":
			cfg.Logging.Format = f.Value.String()
		case "logging_level":
			cfg.Logging.Level = f.Value.String()
		}
	})
	if err != nil {
		return cfg, err
	}

	if tlsDisabled {
		cfg.TLS = TLSConfig{}
	}

	return cfg, cfg.Validate()
}
