package sshchecker

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"
	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/responder"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/tvmcore"
	"a.yandex-team.ru/passport/shared/golibs/logger"
)

// Config configures the SSH-signature checker created by NewSSH.
type Config struct {
	// StaffURL is the base URL of the Staff API that public keys are fetched from.
	StaffURL string `json:"staff_url"`

	// StaffCacheFile is the path of the on-disk copy of a Staff response; it is
	// read at startup and written after keys are fetched from Staff.
	StaffCacheFile string `json:"staff_cache"`

	// AllowedTSDiff is the maximum difference, in seconds and in either
	// direction, between the Authorization header timestamp and the local clock.
	AllowedTSDiff int64 `json:"allowed_ts_diff"`

	// AllowedLogins lists the Staff logins whose keys are requested.
	AllowedLogins []string `json:"allowed_logins"`
}

// State is an SSH-signature checker holding the current key chain, which a
// background goroutine started by NewSSH refreshes from Staff.
type State struct {
	mutex sync.RWMutex   // guards keys against concurrent refresh and lookup
	stop  chan bool      // closed by Stop to end the refresh goroutine
	tvm   *tvmcore.State // used to obtain the Staff service ticket

	cfg  Config   // settings passed to NewSSH
	keys keyChain // login -> parsed public keys; read/written under mutex

	responder *responder.Resp // writes 401 responses for rejected requests
}

// keyChain maps a Staff login to the SSH public keys parsed for it.
type keyChain map[string][]ssh.PublicKey

// NewSSH creates a State that authenticates requests by SSH signatures.
//
// Keys are loaded from cfg.StaffCacheFile; if that fails they are fetched from
// Staff using a TVM ticket obtained through t. After construction a background
// goroutine re-fetches keys from Staff once an hour until Stop is called; a
// failed refresh is logged and the previously loaded keys stay in use.
//
// An error is returned only when keys can be obtained neither from disk nor
// from Staff; it wraps the Staff error.
func NewSSH(cfg Config, t *tvmcore.State, resp *responder.Resp) (*State, error) {
	keys, err := readFromDisk(cfg.StaffCacheFile)
	if err == nil {
		logger.Log().Infof("Ssh: keys were read from disk")
	} else {
		logger.Log().Warnf("Ssh: failed to read from disk: %s", err)

		keys, err = fetchFromHTTP(cfg, t)
		if err != nil {
			logger.Log().Warnf("Ssh: failed to fetch from staff: %s", err)
			// Keep the Staff error in the chain so callers can see why startup failed.
			return nil, xerrors.Errorf("failed to get keys from http and disk: %w", err)
		}
		logger.Log().Infof("Ssh: keys were fetched from staff")
	}

	res := &State{
		stop:      make(chan bool),
		tvm:       t,
		cfg:       cfg,
		keys:      keys,
		responder: resp,
	}

	go func() {
		heartbeat := time.NewTicker(1 * time.Hour)
		// Release the ticker once the goroutine exits after Stop.
		defer heartbeat.Stop()

		for {
			select {
			case <-res.stop:
				logger.Log().Info("Ssh: quitting goroutine")
				return

			case <-heartbeat.C:
				// routine keeps the old keys on failure; only report the error.
				if err := res.routine(); err != nil {
					logger.Log().Warnf("Ssh: error: %s", err)
				}
			}
		}
	}()

	return res, nil
}

// Stop ends the background key-refresh goroutine by closing the stop channel.
// Call it at most once: closing an already-closed channel panics.
func (s *State) Stop() {
	stopCh := s.stop
	close(stopCh)
}

// Middleware wraps next with SSH-signature authentication.
//
// Every request gets responder.AuthType = "ssh" in its context, plus
// responder.AuthID set to the login taken from the Authorization header (empty
// when the header could not be split into fields). Requests that fail
// verification are answered via Return401 with the error text and the received
// header; all others are passed on to next.
func (s *State) Middleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), responder.AuthType, "ssh")

		authHeader := r.Header.Get("Authorization")
		login, authErr := s.middleware(authHeader)
		ctx = context.WithValue(ctx, responder.AuthID, login)
		req := r.WithContext(ctx)

		if authErr == nil {
			next(w, req)
			return
		}

		// NOTE(review): the raw Authorization header is echoed back; if Return401
		// logs its payload, the (replayable) signature ends up in logs.
		s.responder.Return401(w, req, map[string]string{
			"error":  authErr.Error(),
			"header": authHeader,
			"type":   "Authorization",
		})
	}
}

// middleware validates an Authorization header of the form
// "<login> <timestamp> <base64(sign)>" and returns the login it names.
//
// Checks run in this order: header shape, timestamp freshness, signature
// decoding, presence of the login in the key chain, and finally that one of the
// login's keys verifies the signature over the raw timestamp string. Once the
// header has been split, the login is returned even on failure so the caller
// can record who attempted the request.
//
// NOTE(review): only the timestamp is signed (not the login, method or path),
// so a captured header can be replayed until it falls outside AllowedTSDiff.
func (s *State) middleware(header string) (string, error) {
	fields, err := getHeaderFields(header)
	if err != nil {
		return "", err
	}
	login, ts, rawSign := fields[0], fields[1], fields[2]

	if err = checkHeaderTimestamp(ts, s.cfg.AllowedTSDiff); err != nil {
		return login, err
	}

	sign, err := getHeaderSignature(rawSign)
	if err != nil {
		return login, err
	}

	keys, allowed := s.getKeys(login)
	if !allowed {
		return login, xerrors.Errorf("login is not allowed")
	}

	signed := []byte(ts)
	for i := range keys {
		if keys[i].Verify(signed, sign) == nil {
			return login, nil
		}
	}

	return login, xerrors.Errorf("bad ssh sign")
}

// getKeys looks up the public keys for a login under the read lock; the
// boolean reports whether the login is present in the current key chain.
func (s *State) getKeys(field string) ([]ssh.PublicKey, bool) {
	s.mutex.RLock()
	keys, found := s.keys[field]
	s.mutex.RUnlock()

	return keys, found
}

// getHeaderFields splits the Authorization header on single spaces and requires
// exactly three parts: login, timestamp and base64-encoded signature.
func getHeaderFields(header string) ([]string, error) {
	if len(header) == 0 {
		return nil, errors.New("missing header")
	}

	if parts := strings.Split(header, " "); len(parts) == 3 {
		return parts, nil
	}
	return nil, xerrors.Errorf("format: <login> <timestamp> <base64(sign)>")
}

func checkHeaderTimestamp(field string, allowedDiff int64) error {
	timestamp, err := strconv.Atoi(field)
	if err != nil {
		return xerrors.Errorf("failed to parse timestamp")
	}

	timeDiff := time.Now().Unix() - int64(timestamp)
	if timeDiff < 0 {
		timeDiff *= -1
	}
	if timeDiff > allowedDiff {
		return xerrors.Errorf("timestamp is bad")
	}

	return nil
}

func getHeaderSignature(field string) (*ssh.Signature, error) {
	signBytes, err := base64.StdEncoding.DecodeString(field)
	if err != nil {
		return nil, xerrors.Errorf("base64 in sign is invalid: %s", err)
	}

	var sign ssh.Signature
	if err := ssh.Unmarshal(signBytes, &sign); err != nil {
		return nil, xerrors.Errorf("failed to parse ssh sign: %s", err)
	}

	return &sign, nil
}

// routine performs one refresh: it fetches keys from Staff and, on success,
// replaces the current key chain under the write lock. On failure the existing
// keys are left untouched and the error is returned to the caller.
func (s *State) routine() error {
	fresh, err := fetchFromHTTP(s.cfg, s.tvm)
	if err != nil {
		return err
	}

	s.mutex.Lock()
	s.keys = fresh
	s.mutex.Unlock()

	return nil
}

// fetchFromHTTP downloads the SSH public keys of cfg.AllowedLogins (excluding
// dismissed persons) from the Staff API, authenticating with a TVM service
// ticket for tvmcore.TvmAliasStaff, and parses them with parseResponse.
//
// On success the raw response is also written to cfg.StaffCacheFile so that
// readFromDisk can use it later. Writing the cache is best-effort: a failure is
// logged and the freshly fetched keys are still returned, so a disk problem
// does not block startup or the periodic refresh.
func fetchFromHTTP(cfg Config, t *tvmcore.State) (keyChain, error) {
	ticket, err := t.GetTicket(tvmcore.TvmAliasStaff)
	if err != nil {
		return nil, xerrors.Errorf("failed to get ticket: %w", err)
	}

	client := resty.New().
		SetHostURL(cfg.StaffURL).
		SetTimeout(5 * time.Second).
		SetRedirectPolicy(resty.NoRedirectPolicy())

	path := fmt.Sprintf(
		"/v3/persons?login=%s&_fields=keys.key,login&official.is_dismissed=false",
		strings.Join(cfg.AllowedLogins, ","),
	)

	resp, err := client.R().
		SetHeader("X-Ya-Service-Ticket", ticket).
		Get(path)
	if err != nil {
		return nil, xerrors.Errorf("failed to perform http request: %w", err)
	}

	body := resp.Body()
	if resp.StatusCode() != http.StatusOK {
		return nil, xerrors.Errorf("failed to perform http request: %d.\n%s", resp.StatusCode(), string(body))
	}

	keys, err := parseResponse(body)
	if err != nil {
		// Report the parse error itself; the status is always 200 at this point.
		return nil, xerrors.Errorf("failed to parse response: %w.\n%s", err, string(body))
	}

	// The cache decides which keys are trusted on the next start, so it must not
	// be writable by other users (0666 allowed that under a permissive umask).
	// The mode only applies when the file is created.
	if err := ioutil.WriteFile(cfg.StaffCacheFile, body, os.FileMode(0644)); err != nil {
		logger.Log().Warnf("Ssh: failed to write keys to disk: %s", err)
	}

	return keys, nil
}

// readFromDisk loads a cached Staff response from staffCacheFile and parses it
// with parseResponse. Both read and parse failures are returned as errors.
func readFromDisk(staffCacheFile string) (keyChain, error) {
	data, readErr := ioutil.ReadFile(staffCacheFile)
	if readErr != nil {
		return nil, xerrors.Errorf("failed to read file: %s. %w", staffCacheFile, readErr)
	}

	keys, parseErr := parseResponse(data)
	if parseErr != nil {
		return nil, xerrors.Errorf("failed to parse file: %w.\n%s", parseErr, string(data))
	}
	return keys, nil
}

// parseResponse decodes a Staff /v3/persons JSON body into a keyChain.
//
// Each key string is parsed with ssh.ParseAuthorizedKey. Keys that fail to parse
// are logged and skipped, and logins left with no usable key are omitted. An
// error is returned when the JSON is invalid or no login has any usable key.
func parseResponse(resp []byte) (keyChain, error) {
	var staff staffResponse
	if err := json.Unmarshal(resp, &staff); err != nil {
		return nil, err
	}

	chain := make(keyChain, len(staff.Result))
	for _, person := range staff.Result {
		parsed := make([]ssh.PublicKey, 0, len(person.Keys))
		for _, entry := range person.Keys {
			pub, _, _, _, err := ssh.ParseAuthorizedKey([]byte(entry.Key))
			if err != nil {
				logger.Log().Warnf("failed to parse ssh key for %s: %s", person.Login, err)
				continue
			}
			parsed = append(parsed, pub)
		}

		if len(parsed) != 0 {
			chain[person.Login] = parsed
		}
	}

	if len(chain) == 0 {
		return nil, xerrors.Errorf("no one ssh key in memory")
	}
	return chain, nil
}

// JSON shape of a Staff /v3/persons response, limited to the fields requested
// via _fields=keys.key,login.
type (
	// staffResponse is the top-level response envelope.
	staffResponse struct {
		Result []staffResponseResult `json:"result"`
	}

	// staffResponseResult describes one person: their login and SSH keys.
	staffResponseResult struct {
		Login string                   `json:"login"`
		Keys  []staffResponseResultKey `json:"keys"`
	}

	// staffResponseResultKey holds a single public key string.
	staffResponseResultKey struct {
		Key string `json:"key"`
	}
)
