package shooter

import (
	"io/ioutil"
	"net/http"
	"strconv"
	"strings"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/ammo"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/locker"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/responder"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/shooting"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/sshchecker"
	"a.yandex-team.ru/passport/infra/daemons/shooting_gallery/shooter/internal/tvmcore"
)

// Config is the configuration accepted by NewShooter. Each field is decoded
// from the JSON key named in its struct tag:
//
//   - Ammo ("ammo") and Shooting ("shooting") are not referenced individually
//     by NewShooter; the whole Config is passed to NewState.
//   - Tvm ("tvm") is handed to tvmcore.NewTvmCore.
//   - SSH ("ssh") is handed to sshchecker.NewSSH.
//   - Lock ("lock") is handed to locker.NewLocker; its MaxDuration is also the
//     default and the upper bound of the duration accepted by HandleCliLock.
//   - AccessLog ("access_log") is handed to responder.NewResponder.
type Config struct {
	Ammo      ammo.Config       `json:"ammo"`
	Shooting  shooting.Config   `json:"shooting"`
	Tvm       tvmcore.Config    `json:"tvm"`
	SSH       sshchecker.Config `json:"ssh"`
	Lock      locker.Config     `json:"lock"`
	AccessLog string            `json:"access_log"`
}

// Shooter holds the configuration and the components created by NewShooter
// and provides the HTTP handlers that CreateMux registers.
//
//   - cfg is the Config NewShooter was called with.
//   - state is the *State the handlers delegate their work to.
//   - tvm and ssh supply the Middleware wrappers applied in CreateMux.
//   - responder writes every HTTP response (Return200/Return400/Return500).
//   - lock backs middlewareLock and the lock/unlock handlers.
type Shooter struct {
	cfg       Config
	state     *State
	tvm       *tvmcore.State
	ssh       *sshchecker.State
	responder *responder.Resp
	lock      *locker.State
}

// NewShooter builds a Shooter from cfg. Components are constructed in
// dependency order: locker, state (which receives the locker), responder,
// TVM core (which receives the responder) and SSH checker (which receives the
// TVM core and the responder).
//
// If any constructor fails, its error is returned unchanged and every
// already-built component that Shooter.Stop would stop is stopped here, most
// recent first, so a failed construction does not leave them running.
// The responder is not stopped on failure, mirroring Shooter.Stop.
func NewShooter(cfg Config) (*Shooter, error) {
	// started collects the shutdown action of each successfully built
	// component; fail replays them in reverse order before reporting err.
	var started []func()
	fail := func(err error) (*Shooter, error) {
		for i := len(started) - 1; i >= 0; i-- {
			started[i]()
		}
		return nil, err
	}

	lock, err := locker.NewLocker(cfg.Lock)
	if err != nil {
		return fail(err)
	}
	started = append(started, func() { lock.Stop() })

	state, err := NewState(cfg, lock)
	if err != nil {
		return fail(err)
	}
	started = append(started, func() { state.Stop() })

	resp, err := responder.NewResponder(cfg.AccessLog)
	if err != nil {
		return fail(err)
	}

	t, err := tvmcore.NewTvmCore(cfg.Tvm, resp)
	if err != nil {
		return fail(err)
	}
	started = append(started, func() { t.Stop() })

	ssh, err := sshchecker.NewSSH(cfg.SSH, t, resp)
	if err != nil {
		return fail(err)
	}

	res := &Shooter{
		cfg:       cfg,
		state:     state,
		tvm:       t,
		ssh:       ssh,
		responder: resp,
		lock:      lock,
	}

	return res, nil
}

// Stop shuts the components down in a fixed order: SSH checker, TVM core,
// state, locker. Each component is stopped before the one it was constructed
// from in NewShooter (ssh before tvm, state before lock).
// The responder is not stopped here.
func (s *Shooter) Stop() {
	s.ssh.Stop()
	s.tvm.Stop()
	s.state.Stop()
	s.lock.Stop()
}

// middlewareMethods wraps next so that it only runs for requests whose HTTP
// method equals one of allowedMethods (exact string comparison). Requests
// with any other method are answered through responder.Return400 with an
// error listing the allowed methods and the one received.
func (s *Shooter) middlewareMethods(allowedMethods []string, next http.HandlerFunc) http.HandlerFunc {
	isAllowed := func(method string) bool {
		for i := range allowedMethods {
			if allowedMethods[i] == method {
				return true
			}
		}
		return false
	}

	return func(w http.ResponseWriter, r *http.Request) {
		if !isAllowed(r.Method) {
			s.responder.Return400(w, r, xerrors.Errorf("invalid HTTP method. Allowed: %s. Got: %s", allowedMethods, r.Method))
			return
		}

		next(w, r)
	}
}

// middlewareLock wraps next so that it only runs when lock.CheckAllowed
// accepts the caller. The caller is identified by the responder.AuthID string
// stored in the request context. A CheckAllowed error is answered through
// responder.Return400 and next is not invoked.
func (s *Shooter) middlewareLock(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		caller := responder.GetStringFromCtx(r, responder.AuthID)

		err := s.lock.CheckAllowed(caller)
		if err == nil {
			next(w, r)
			return
		}

		s.responder.Return400(w, r, err)
	}
}

// CreateMux builds the HTTP routing table of the shooter.
//
// Every route is restricted to a fixed set of HTTP methods by
// middlewareMethods. On top of that:
//   - /ping and /prospector/task have no further wrapper;
//   - /cli/* routes are wrapped in s.ssh.Middleware, and all of them except
//     status, ammo/list, state/top, lock and unlock additionally pass through
//     middlewareLock;
//   - /stateviewer/* routes are wrapped in s.tvm.Middleware.
//
// The returned error is always nil.
func (s *Shooter) CreateMux() (*http.ServeMux, error) {
	get := []string{"GET"}
	post := []string{"POST"}
	getPost := []string{"GET", "POST"}

	// plain only enforces the allowed HTTP methods.
	plain := s.middlewareMethods
	// viaSSH also wraps the handler in the SSH checker middleware.
	viaSSH := func(methods []string, h http.HandlerFunc) http.HandlerFunc {
		return s.middlewareMethods(methods, s.ssh.Middleware(h))
	}
	// viaSSHLocked also applies middlewareLock inside the SSH middleware.
	viaSSHLocked := func(methods []string, h http.HandlerFunc) http.HandlerFunc {
		return viaSSH(methods, s.middlewareLock(h))
	}
	// viaTVM also wraps the handler in the TVM middleware.
	viaTVM := func(methods []string, h http.HandlerFunc) http.HandlerFunc {
		return s.middlewareMethods(methods, s.tvm.Middleware(h))
	}

	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/ping", plain(get, s.HandlePing)},
		{"/cli/status", viaSSH(get, s.HandleCliStatus)},
		{"/cli/shooting/start", viaSSHLocked(post, s.HandleCliShootingStart)},
		{"/cli/shooting/stop", viaSSHLocked(post, s.HandleCliShootingStop)},
		{"/cli/ammo/list", viaSSH(get, s.HandleCliAmmoList)},
		{"/cli/ammo/create", viaSSHLocked(post, s.HandleCliAmmoCreate)},
		{"/cli/ammo/delete", viaSSHLocked(post, s.HandleCliAmmoDelete)},
		{"/cli/state/top", viaSSH(get, s.HandleCliStateTop)},
		{"/cli/state/perf", viaSSHLocked(getPost, s.HandleCliStatePerf)},
		{"/cli/state/perf_cancel", viaSSHLocked(post, s.HandleCliStatePerfCancel)},
		{"/cli/version", viaSSHLocked(getPost, s.HandleCliVersion)},
		{"/cli/lock", viaSSH(post, s.HandleCliLock)},
		{"/cli/unlock", viaSSH(post, s.HandleCliUnlock)},
		{"/prospector/task", plain(get, s.HandleProspectorTask)},
		{"/stateviewer/top", viaTVM(post, s.HandleStateviewerTop)},
		{"/stateviewer/perf", viaTVM(getPost, s.HandleStateviewerPerf)},
		{"/stateviewer/version", viaTVM(getPost, s.HandleStateviewerVersion)},
	}

	mux := http.NewServeMux()
	for _, route := range routes {
		mux.Handle(route.pattern, route.handler)
	}

	return mux, nil
}

// HandlePing serves /ping: it always answers through Return200 with the
// payload {"status": "ok"}.
func (s *Shooter) HandlePing(w http.ResponseWriter, r *http.Request) {
	s.responder.Return200(w, r, map[string]string{"status": "ok"})
}

// HandleCliStatus serves /cli/status with the result of state.GetStatus.
// An error from GetStatus is reported through Return500.
func (s *Shooter) HandleCliStatus(w http.ResponseWriter, r *http.Request) {
	status, err := s.state.GetStatus()
	// No request-level error exists here, so only the 500 and 200 branches
	// of respond can be taken.
	s.respond(w, r, status, nil, err)
}

// HandleCliShootingStart serves /cli/shooting/start.
//
// Query parameters:
//   - ammo_id, schema: required strings;
//   - rate, instances: required unsigned integers;
//   - duration, workers: optional unsigned integers, 0 when absent;
//   - connection_close: treated as true only when it equals "yes".
//
// The first invalid parameter ends the request (the helper has already
// written the response). Otherwise the parameters are passed to
// state.ShootingStart together with a fresh shooting.CmdCreatorImpl.
func (s *Shooter) HandleCliShootingStart(w http.ResponseWriter, r *http.Request) {
	var (
		params shooting.Params
		ok     bool
	)

	if params.AmmoID, ok = s.getRequiredCgiParamString(w, r, "ammo_id"); !ok {
		return
	}
	if params.Schema, ok = s.getRequiredCgiParamString(w, r, "schema"); !ok {
		return
	}
	if params.Rate, ok = s.getRequiredCgiParamInt(w, r, "rate"); !ok {
		return
	}
	if params.Instances, ok = s.getRequiredCgiParamInt(w, r, "instances"); !ok {
		return
	}
	if params.Duration, ok = s.getCgiParamInt(w, r, "duration"); !ok {
		return
	}
	if params.Workers, ok = s.getCgiParamInt(w, r, "workers"); !ok {
		return
	}
	params.ConnectionClose = r.URL.Query().Get("connection_close") == "yes"

	res, rErr, err := s.state.ShootingStart(params, &shooting.CmdCreatorImpl{})
	s.respond(w, r, res, rErr, err)
}

// HandleCliShootingStop serves /cli/shooting/stop by forwarding the outcome
// of state.ShootingStop to respond.
func (s *Shooter) HandleCliShootingStop(w http.ResponseWriter, r *http.Request) {
	result, requestErr, internalErr := s.state.ShootingStop()
	s.respond(w, r, result, requestErr, internalErr)
}

// HandleCliAmmoList serves /cli/ammo/list with the result of state.AmmoList.
// An error from AmmoList is reported through Return500.
func (s *Shooter) HandleCliAmmoList(w http.ResponseWriter, r *http.Request) {
	list, err := s.state.AmmoList()
	// No request-level error exists here, so only the 500 and 200 branches
	// of respond can be taken.
	s.respond(w, r, list, nil, err)
}

// HandleCliAmmoCreate serves /cli/ammo/create.
//
// Query parameters:
//   - hosts: required, comma-separated list;
//   - duration: required unsigned integer.
//
// hosts is split on "," as-is: entries are not trimmed and empty entries
// (e.g. from a trailing comma) are passed on to state.AmmoCreate unchanged.
func (s *Shooter) HandleCliAmmoCreate(w http.ResponseWriter, r *http.Request) {
	rawHosts, ok := s.getRequiredCgiParamString(w, r, "hosts")
	if !ok {
		return
	}

	duration, ok := s.getRequiredCgiParamInt(w, r, "duration")
	if !ok {
		return
	}

	hostList := strings.Split(rawHosts, ",")
	res, rErr, err := s.state.AmmoCreate(hostList, duration)
	s.respond(w, r, res, rErr, err)
}

// HandleCliAmmoDelete serves /cli/ammo/delete: it takes the required "id"
// query parameter and forwards the outcome of state.AmmoDelete to respond.
func (s *Shooter) HandleCliAmmoDelete(w http.ResponseWriter, r *http.Request) {
	ammoID, present := s.getRequiredCgiParamString(w, r, "id")
	if !present {
		// The missing-parameter response has already been written.
		return
	}

	result, requestErr, internalErr := s.state.AmmoDelete(ammoID)
	s.respond(w, r, result, requestErr, internalErr)
}

// HandleCliStateTop serves /cli/state/top with the result of
// state.GetStateTop. The error returned by GetStateTop is passed to respond
// as the request-level error, so it is reported through Return400 rather
// than Return500.
func (s *Shooter) HandleCliStateTop(w http.ResponseWriter, r *http.Request) {
	top, requestErr := s.state.GetStateTop()
	s.respond(w, r, top, requestErr, nil)
}

// HandleCliStatePerf serves /cli/state/perf.
//
// GET answers with the result of state.GetStatePerf.
// Any other method (CreateMux only lets POST through) reads the optional
// unsigned "frequency" and "sleep" query parameters, 0 when absent, and
// forwards them to state.PostStatePerfCmd.
func (s *Shooter) HandleCliStatePerf(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		res, rErr, err := s.state.GetStatePerf()
		s.respond(w, r, res, rErr, err)
		return
	}

	frequency, ok := s.getCgiParamInt(w, r, "frequency")
	if !ok {
		return
	}

	sleep, ok := s.getCgiParamInt(w, r, "sleep")
	if !ok {
		return
	}

	res, rErr, err := s.state.PostStatePerfCmd(frequency, sleep)
	s.respond(w, r, res, rErr, err)
}

// HandleProspectorTask serves /prospector/task: it takes the required "host"
// query parameter and answers with the result of state.ProspectorGetTask.
// An error from ProspectorGetTask is reported through Return500.
func (s *Shooter) HandleProspectorTask(w http.ResponseWriter, r *http.Request) {
	hostname, present := s.getRequiredCgiParamString(w, r, "host")
	if !present {
		// The missing-parameter response has already been written.
		return
	}

	task, internalErr := s.state.ProspectorGetTask(hostname)
	s.respond(w, r, task, nil, internalErr)
}

// HandleStateviewerTop serves /stateviewer/top.
//
// It requires the unsigned "timestamp" query parameter, reads the whole
// request body and passes both (body as string, timestamp as int64) to
// state.PostStateviewerTop. A body read failure and an error from
// PostStateviewerTop are both reported through Return500.
func (s *Shooter) HandleStateviewerTop(w http.ResponseWriter, r *http.Request) {
	ts, ok := s.getRequiredCgiParamInt(w, r, "timestamp")
	if !ok {
		return
	}

	body, readErr := ioutil.ReadAll(r.Body)
	if readErr != nil {
		s.responder.Return500(w, r, xerrors.Errorf("failed to read request body in HandleStateviewerTop: %s", readErr))
		return
	}

	result, internalErr := s.state.PostStateviewerTop(string(body), int64(ts))
	s.respond(w, r, result, nil, internalErr)
}

// HandleStateviewerPerf serves /stateviewer/perf.
//
// GET answers with the result of state.GetStatePerfCmd.
// Any other method (CreateMux only lets POST through) requires the unsigned
// "timestamp" query parameter, reads the whole request body and passes both
// (body as string, timestamp as int64) to state.PostStatePerf.
//
// A body read failure is reported through Return500; the error message names
// this handler so the failure can be traced to the right endpoint.
func (s *Shooter) HandleStateviewerPerf(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		res, rErr, err := s.state.GetStatePerfCmd()
		s.respond(w, r, res, rErr, err)
		return
	}

	timestamp, ok := s.getRequiredCgiParamInt(w, r, "timestamp")
	if !ok {
		return
	}

	output, err := ioutil.ReadAll(r.Body)
	if err != nil {
		s.responder.Return500(w, r, xerrors.Errorf("failed to read request body in HandleStateviewerPerf: %s", err))
		return
	}

	res, rErr, err := s.state.PostStatePerf(string(output), int64(timestamp))
	s.respond(w, r, res, rErr, err)
}

// HandleStateviewerVersion serves /stateviewer/version.
//
// GET answers with the result of state.GetInstallVersionCmd.
// Any other method (CreateMux only lets POST through) reads the whole request
// body and passes the raw bytes to state.PostVersion. A body read failure is
// reported through Return500.
func (s *Shooter) HandleStateviewerVersion(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		res, rErr, err := s.state.GetInstallVersionCmd()
		s.respond(w, r, res, rErr, err)
		return
	}

	body, readErr := ioutil.ReadAll(r.Body)
	if readErr != nil {
		s.responder.Return500(w, r, xerrors.Errorf("failed to read request body in HandleStateviewerVersion: %s", readErr))
		return
	}

	res, rErr, err := s.state.PostVersion(body)
	s.respond(w, r, res, rErr, err)
}

// HandleCliStatePerfCancel serves /cli/state/perf_cancel by forwarding the
// outcome of state.CancelStatePerf to respond.
func (s *Shooter) HandleCliStatePerfCancel(w http.ResponseWriter, r *http.Request) {
	result, requestErr, internalErr := s.state.CancelStatePerf()
	s.respond(w, r, result, requestErr, internalErr)
}

// HandleCliVersion serves /cli/version.
//
// GET answers with the result of state.GetVersion.
// Any other method (CreateMux only lets POST through) requires the "package"
// and "version" query parameters and forwards them to
// state.PostInstallVersionCmd.
func (s *Shooter) HandleCliVersion(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		res, rErr, err := s.state.GetVersion()
		s.respond(w, r, res, rErr, err)
		return
	}

	pkgName, ok := s.getRequiredCgiParamString(w, r, "package")
	if !ok {
		return
	}

	pkgVersion, ok := s.getRequiredCgiParamString(w, r, "version")
	if !ok {
		return
	}

	res, rErr, err := s.state.PostInstallVersionCmd(pkgName, pkgVersion)
	s.respond(w, r, res, rErr, err)
}

// HandleCliLock serves /cli/lock: it takes the lock on behalf of the caller
// identified by the responder.AuthID string in the request context.
//
// The optional unsigned "duration" query parameter defaults to
// cfg.Lock.MaxDuration when absent or 0; a value above MaxDuration is
// rejected. Both that rejection and a failure of lock.Lock are reported as
// request-level errors (Return400). On success the payload is
// {"status": "ok"}.
func (s *Shooter) HandleCliLock(w http.ResponseWriter, r *http.Request) {
	duration, ok := s.getCgiParamInt(w, r, "duration")
	if !ok {
		return
	}

	login := responder.GetStringFromCtx(r, responder.AuthID)

	maxDuration := s.cfg.Lock.MaxDuration
	switch {
	case duration == 0:
		duration = maxDuration
	case duration > maxDuration:
		s.respond(w, r, nil, xerrors.Errorf("duration is too long: %d. Max: %d", duration, maxDuration), nil)
		return
	}

	if lockErr := s.lock.Lock(duration, login); lockErr != nil {
		s.respond(w, r, nil, xerrors.Errorf("failed to lock shooting gallery: %s", lockErr), nil)
		return
	}

	s.respond(w, r, map[string]string{"status": "ok"}, nil, nil)
}

// HandleCliUnlock serves /cli/unlock: it calls lock.Unlock for the caller
// identified by the responder.AuthID string in the request context.
// A failure of lock.Unlock is reported as a request-level error (Return400).
// On success the payload is {"status": "ok", "details": "will be unlocked soon"}.
func (s *Shooter) HandleCliUnlock(w http.ResponseWriter, r *http.Request) {
	login := responder.GetStringFromCtx(r, responder.AuthID)

	if unlockErr := s.lock.Unlock(login); unlockErr != nil {
		s.respond(w, r, nil, xerrors.Errorf("failed to unlock shooting gallery: %s", unlockErr), nil)
		return
	}

	s.respond(w, r, map[string]string{
		"status":  "ok",
		"details": "will be unlocked soon",
	}, nil, nil)
}

// respond writes exactly one response for the request, chosen in this order:
//  1. err != nil  -> responder.Return500 with err;
//  2. rErr != nil -> responder.Return400 with rErr;
//  3. otherwise   -> responder.Return200 with printable.
//
// err takes precedence over rErr when both are set.
func (s *Shooter) respond(w http.ResponseWriter, r *http.Request, printable interface{}, rErr, err error) {
	switch {
	case err != nil:
		s.responder.Return500(w, r, err)
	case rErr != nil:
		s.responder.Return400(w, r, rErr)
	default:
		s.responder.Return200(w, r, printable)
	}
}

// getRequiredCgiParamString returns the first value of the query parameter
// key and true. If the parameter is absent or empty, it writes the response
// itself through responder.Return400 and returns ("", false); the caller must
// then return without writing anything else.
func (s *Shooter) getRequiredCgiParamString(w http.ResponseWriter, r *http.Request, key string) (string, bool) {
	if value := r.URL.Query().Get(key); value != "" {
		return value, true
	}

	s.responder.Return400(w, r, xerrors.Errorf("param is required: %s", key))
	return "", false
}

// getRequiredCgiParamInt returns the query parameter key parsed as a base-10
// unsigned integer that fits in 32 bits, and true.
//
// On any failure it returns (0, false) and the response has already been
// written through responder.Return400, so the caller must simply return:
//   - parameter absent or empty: reported by getRequiredCgiParamString;
//   - value not a valid uint32 (syntax error or out of range): reported here.
func (s *Shooter) getRequiredCgiParamInt(w http.ResponseWriter, r *http.Request, key string) (uint32, bool) {
	val, ok := s.getRequiredCgiParamString(w, r, key)
	if !ok {
		// Nothing to parse; the missing-parameter response is already sent.
		return 0, false
	}

	valInt, err := strconv.ParseUint(val, 10, 32)
	if err != nil {
		s.responder.Return400(w, r, xerrors.Errorf("param must be uint: %s", key))
		// ParseUint yields the maximum uint32 on overflow; return 0 instead
		// so a failed lookup never carries a meaningful-looking value.
		return 0, false
	}

	return uint32(valInt), true
}

// getCgiParamInt returns the optional query parameter key parsed as a base-10
// unsigned integer that fits in 32 bits.
//
//   - Absent or empty parameter: (0, true), nothing is written.
//   - Valid value: (value, true).
//   - Invalid value: the response is written through responder.Return400 and
//     (0, false) is returned; the caller must then simply return.
func (s *Shooter) getCgiParamInt(w http.ResponseWriter, r *http.Request, key string) (uint32, bool) {
	raw := r.URL.Query().Get(key)
	if len(raw) == 0 {
		return 0, true
	}

	parsed, err := strconv.ParseUint(raw, 10, 32)
	if err == nil {
		return uint32(parsed), true
	}

	s.responder.Return400(w, r, xerrors.Errorf("param must be uint: %s", key))
	return 0, false
}
