package sshutil

import (
	"errors"
	"fmt"
	"time"
)

// errLeadingInt is returned by leadingInt when a run of digits does not fit in
// an int64. ParseDuration replaces it with its own error, so this text is
// never shown.
var errLeadingInt = errors.New("bad [0-9]*")

// leadingInt splits s into the value of its longest prefix of ASCII digits and
// the text that follows that prefix.
//
// When s does not start with a digit it returns 0 and s unchanged. When the
// digits do not fit in an int64 it returns 0, an empty remainder and
// errLeadingInt.
func leadingInt(s string) (x int64, rem string, err error) {
	// Largest accumulator value that can still be multiplied by ten safely.
	const maxBeforeShift = (1<<63 - 1) / 10

	pos := 0
	for ; pos < len(s) && '0' <= s[pos] && s[pos] <= '9'; pos++ {
		if x > maxBeforeShift {
			return 0, "", errLeadingInt
		}
		x = x*10 + int64(s[pos]) - '0'
		// Signed arithmetic wraps around, so a negative result means the
		// value no longer fits.
		if x < 0 {
			return 0, "", errLeadingInt
		}
	}
	return x, s[pos:], nil
}

// unitMap gives the length, in nanoseconds, of every unit suffix that
// ParseDuration accepts. Each letter is listed in both lower and upper case.
var unitMap = map[string]int64{
	// A number with no suffix counts as seconds.
	"": time.Second.Nanoseconds(),

	"s": time.Second.Nanoseconds(),
	"S": time.Second.Nanoseconds(),

	"m": time.Minute.Nanoseconds(),
	"M": time.Minute.Nanoseconds(),

	"h": time.Hour.Nanoseconds(),
	"H": time.Hour.Nanoseconds(),

	// Days and weeks are fixed multiples of an hour.
	"d": (24 * time.Hour).Nanoseconds(),
	"D": (24 * time.Hour).Nanoseconds(),

	"w": (7 * 24 * time.Hour).Nanoseconds(),
	"W": (7 * 24 * time.Hour).Nanoseconds(),
}

// ParseDuration parses a duration string written in the TIME FORMATS syntax of
// sshd_config(5): https://man7.org/linux/man-pages/man5/sshd_config.5.html#TIME_FORMATS
//
// s is a sequence of terms, each a run of decimal digits followed by a unit
// that must be a key of unitMap; the terms are added together. A number on its
// own, with no unit, is read as seconds, but only when it is the sole term: a
// unit-less number after any other term is rejected, whatever that term's
// value. The empty string yields a zero duration.
//
// A non-nil error is returned when a term does not begin with a digit, when a
// unit is missing or unknown, or when the result overflows time.Duration.
func ParseDuration(s string) (time.Duration, error) {
	const maxInt64 = 1<<63 - 1

	orig := s
	var d int64

	for first := true; s != ""; first = false {
		// Each term must open with a digit. This also guarantees that
		// leadingInt consumes at least one character, so the loop always
		// makes progress.
		if s[0] < '0' || s[0] > '9' {
			return 0, fmt.Errorf("invalid duration: %s", orig)
		}
		v, rest, err := leadingInt(s)
		if err != nil {
			// The number does not fit in an int64.
			return 0, fmt.Errorf("invalid duration: %s", orig)
		}

		// The unit is everything up to the next digit or the end of input.
		i := 0
		for i < len(rest) && (rest[i] < '0' || rest[i] > '9') {
			i++
		}
		u := rest[:i]
		s = rest[i:]

		// Decide by position, not by the running total: a total of zero
		// (as in "0h30") must not let a later unit-less number through.
		if u == "" && !first {
			return 0, fmt.Errorf("expected unit in duration: %s", orig)
		}

		unit, ok := unitMap[u]
		if !ok {
			return 0, fmt.Errorf("unknown unit %q in duration: %s", u, orig)
		}
		if v > maxInt64/unit {
			return 0, fmt.Errorf("overflow in duration: %s", orig)
		}

		d += v * unit
		if d < 0 {
			// The running sum wrapped around.
			return 0, fmt.Errorf("overflow in duration: %s", orig)
		}
	}

	return time.Duration(d), nil
}
