package server

import (
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"strconv"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/golang/protobuf/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/grpclog"

	"a.yandex-team.ru/infra/hmserver/pkg/httpserver/httprpc"
	"a.yandex-team.ru/infra/hmserver/pkg/reporter/types"
	pb "a.yandex-team.ru/infra/hmserver/proto"
	"a.yandex-team.ru/infra/hostctl/pkg/pbutil"
	hostpb "a.yandex-team.ru/infra/hostctl/proto"
	"a.yandex-team.ru/library/go/httputil/headers"
	"a.yandex-team.ru/library/go/httputil/resource"
)

// Deadlines applied on top of the caller's context by the RPC handlers.
const (
	// saveTimeout bounds a single HandleReport call.
	saveTimeout = 10 * time.Second
	// readTimeout bounds unary read handlers such as GetReports and GetHost.
	readTimeout = 5 * time.Minute
	// cursorTimeout bounds the streaming GetAllHosts call. It should be at
	// least one hour to render results for large clusters.
	cursorTimeout = time.Hour
)

// server implements the Reporter gRPC service on top of a Manager and also
// exposes the same operations over HTTP (see Register and RegisterUI).
//
// NOTE: NewServer builds this struct with an unkeyed literal, so the field
// order below must not be changed without updating it.
type server struct {
	m        *Manager     // backend that every handler delegates to
	l        *log.Logger  // logger used by the handlers and the HTTP routes
	s        *grpc.Server // gRPC server this instance is registered on
	grpcAddr string       // TCP address RunGrpc listens on
}

// dir is the resource directory rooted at "/" from which the UI is served:
// spaHandler opens public/index.html from it and RegisterUI hands it to
// http.FileServer.
var dir = resource.Dir("/")

// HandleReport passes the units, host info and last status change from req to
// the manager. The call is bounded by saveTimeout. An empty response is
// returned together with whatever error the manager reported.
func (s *server) HandleReport(ctx context.Context, req *pb.HandleReportRequest) (*pb.HandleReportResponse, error) {
	saveCtx, cancel := context.WithTimeout(ctx, saveTimeout)
	defer cancel()

	saveErr := s.m.HandleReport(saveCtx, req.Units, req.HostInfo, req.LastStatusChange)
	return new(pb.HandleReportResponse), saveErr
}

// GetReports returns the host reports selected by the filters in req.
//
// Limit must be within [0, 1000] and Offset must not be negative; a violation
// is reported as an error before the manager is called. The lookup itself is
// bounded by readTimeout.
func (s *server) GetReports(ctx context.Context, req *pb.GetReportsRequest) (*pb.GetReportsResponse, error) {
	// maxLimit is the largest page size accepted (inclusive).
	const maxLimit = 1000

	// Validate first so that a bad request does not allocate a timeout context.
	switch {
	case req.Limit > maxLimit:
		return nil, fmt.Errorf("limit is too large, should not exceed %d", maxLimit)
	case req.Limit < 0:
		return nil, fmt.Errorf("limit should not be negative")
	case req.Offset < 0:
		return nil, fmt.Errorf("offset should not be negative")
	}

	ctx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()
	reports, err := s.m.GetHostReports(ctx, types.Node(req.Node), types.Unit(req.Unit), types.Stage(req.Stage), types.Version(req.Version), req.Ready, req.Pending, req.Limit, req.Offset)
	if err != nil {
		return nil, err
	}
	return &pb.GetReportsResponse{Reports: reports}, nil
}

// GetUnits fetches the units known to the manager and wraps each map entry
// into a pb.Units message under the same key. The call is bounded by
// readTimeout.
func (s *server) GetUnits(ctx context.Context, req *pb.GetUnitsRequest) (*pb.GetUnitsResponse, error) {
	readCtx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	byName, err := s.m.GetUnits(readCtx)
	if err != nil {
		return nil, err
	}

	wrapped := map[string]*pb.Units{}
	for key, list := range byName {
		wrapped[key] = &pb.Units{Units: list}
	}
	return &pb.GetUnitsResponse{Units: wrapped}, nil
}

// GetHosts asks the manager for the hosts matching req.HostInfo. The call is
// bounded by readTimeout.
func (s *server) GetHosts(ctx context.Context, req *pb.GetHostsRequest) (*pb.GetHostsResponse, error) {
	readCtx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	found, err := s.m.GetHosts(readCtx, req.HostInfo)
	if err != nil {
		return nil, err
	}
	resp := &pb.GetHostsResponse{Hosts: found}
	return resp, nil
}

// GetHostsCount reports the number of hosts returned by the hosts storage,
// converted to int32. The call is bounded by readTimeout.
func (s *server) GetHostsCount(ctx context.Context, req *pb.GetHostsCountRequest) (*pb.GetHostsCountResponse, error) {
	readCtx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	count, err := s.m.hostsStorage.HostsCount(readCtx)
	if err != nil {
		return nil, err
	}
	resp := &pb.GetHostsCountResponse{Total: int32(count)}
	return resp, nil
}

// GetKernelVersions returns the kernel versions reported by the hosts
// storage. The call is bounded by readTimeout.
func (s *server) GetKernelVersions(ctx context.Context, req *pb.GetKernelVersionsRequest) (*pb.GetKernelVersionsResponse, error) {
	readCtx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	kernels, err := s.m.hostsStorage.KernelVersions(readCtx)
	if err != nil {
		return nil, err
	}
	resp := &pb.GetKernelVersionsResponse{Versions: kernels}
	return resp, nil
}

// GetAllHosts streams host records to the client. Records are pulled from a
// manager cursor with a batch size of 500 and each non-empty batch is sent as
// one GetAllHostsResponse. The stream ends successfully when the cursor
// reports io.EOF or yields an empty batch; the whole call is bounded by
// cursorTimeout.
func (s *server) GetAllHosts(req *pb.GetAllHostsRequest, srv pb.Reporter_GetAllHostsServer) error {
	const batchSize int = 500

	ctx, cancel := context.WithTimeout(srv.Context(), cursorTimeout)
	defer cancel()

	cur, err := s.m.GetHostsCursor(ctx)
	if err != nil {
		return err
	}

	for {
		// Stop early once the stream context is cancelled or timed out.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}

		batch, nextErr := cur.Next(ctx, batchSize)
		switch {
		case nextErr == io.EOF:
			// Returning closes the stream from the server side.
			return nil
		case nextErr != nil:
			s.l.Println(nextErr)
			return nextErr
		case len(batch) == 0:
			return nil
		}

		if sendErr := srv.Send(&pb.GetAllHostsResponse{Infos: batch}); sendErr != nil {
			return sendErr
		}
	}
}

// GetHost looks up the host info for the node named in req.Node. The call is
// bounded by readTimeout.
func (s *server) GetHost(ctx context.Context, req *pb.GetHostRequest) (*pb.GetHostResponse, error) {
	readCtx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	info, err := s.m.GetHost(readCtx, types.Node(req.Node))
	if err != nil {
		return nil, err
	}
	resp := &pb.GetHostResponse{HostInfo: info}
	return resp, nil
}

// GetUnit collects per-unit counters for req.UnitName: counts per version,
// per stage, and per pending/ready status. The four manager lookups share one
// context bounded by readTimeout; the first failing lookup aborts the call.
//
// Ready and Pending are always non-nil slices, Versions and Stages always
// non-nil maps, even when the manager returned nothing.
func (s *server) GetUnit(ctx context.Context, req *pb.GetUnitRequest) (*pb.GetUnitResponse, error) {
	readCtx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	unit := types.Unit(req.UnitName)

	versionCounts, err := s.m.GetUnitVersions(readCtx, unit)
	if err != nil {
		return nil, err
	}
	pendingCounts, err := s.m.GetUnitPending(readCtx, unit)
	if err != nil {
		return nil, err
	}
	stageCounts, err := s.m.GetUnitStages(readCtx, unit)
	if err != nil {
		return nil, err
	}
	readyCounts, err := s.m.GetUnitReady(readCtx, unit)
	if err != nil {
		return nil, err
	}

	out := &pb.GetUnitResponse{
		Ready:    []*pb.StatusCount{},
		Pending:  []*pb.StatusCount{},
		Versions: map[string]int64{},
		Stages:   map[string]int64{},
	}
	for version, n := range versionCounts {
		out.Versions[string(version)] = int64(n)
	}
	for stage, n := range stageCounts {
		out.Stages[string(stage)] = int64(n)
	}
	for status, n := range pendingCounts {
		out.Pending = append(out.Pending, &pb.StatusCount{Status: status, Count: int64(n)})
	}
	for status, n := range readyCounts {
		out.Ready = append(out.Ready, &pb.StatusCount{Status: status, Count: int64(n)})
	}
	return out, nil
}

// GetHeartbeats returns the heartbeats the manager has for request.Hosts.
//
// The response always carries an Ok condition: it is set through
// pbutil.TrueCond on success, and through pbutil.FalseCond with the error
// text on failure. On failure both the response and the error are returned.
//
// Like the other unary read handlers, the lookup is bounded by readTimeout so
// that a caller without a deadline cannot block on the backend indefinitely.
func (s *server) GetHeartbeats(ctx context.Context, request *pb.GetHeartbeatsRequest) (*pb.GetHeartbeatsResponse, error) {
	ctx, cancel := context.WithTimeout(ctx, readTimeout)
	defer cancel()

	resp := &pb.GetHeartbeatsResponse{Ok: &hostpb.Condition{}}
	heartbeats, err := s.m.GetHeartbeats(ctx, request.Hosts)
	if err != nil {
		pbutil.FalseCond(resp.Ok, err.Error())
		return resp, err
	}
	resp.Heartbeats = heartbeats
	pbutil.TrueCond(resp.Ok, "Ok")
	return resp, nil
}

// NewServer creates a gRPC server, registers a new Reporter service backed by
// m on it and returns that service. grpcAddr is only stored here; listening
// starts in RunGrpc.
//
// As a side effect the process-wide grpclog logger is replaced with one that
// writes info, warning and error output to os.Stdout.
func NewServer(l *log.Logger, m *Manager, grpcAddr string) *server {
	srv := &server{
		m:        m,
		l:        l,
		s:        grpc.NewServer(),
		grpcAddr: grpcAddr,
	}
	grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stdout, os.Stdout, os.Stdout))
	pb.RegisterReporterServer(srv.s, srv)
	return srv
}

// RunGrpc listens on s.grpcAddr and serves gRPC requests until Serve returns.
// When ctx is cancelled the server is stopped with GracefulStop.
//
// Listen and serve failures are returned wrapped with %w so callers can
// inspect the underlying error with errors.Is / errors.As.
func (s *server) RunGrpc(ctx context.Context) error {
	lis, err := net.Listen("tcp", s.grpcAddr)
	if err != nil {
		return fmt.Errorf("failed to listen: %w", err)
	}
	s.l.Println("grpc server listening")

	// served is closed when this function returns, so the watcher goroutine
	// below exits even if Serve fails before ctx is cancelled.
	served := make(chan struct{})
	defer close(served)
	go func() {
		select {
		case <-ctx.Done():
			s.s.GracefulStop()
		case <-served:
		}
	}()

	if err := s.s.Serve(lis); err != nil {
		return fmt.Errorf("failed to serve: %w", err)
	}
	return nil
}

// spaHandler is the single page application handler for the UI: every request
// is answered with public/index.html and the JS machinery handles routes
// itself.
//
// The file is read completely before anything is written, so a read failure
// still yields a clean 500 response instead of a half-written page.
func spaHandler(w http.ResponseWriter, _ *http.Request) {
	f, err := dir.Open("public/index.html")
	if err != nil {
		er(err, w, http.StatusInternalServerError)
		return
	}
	// Closing a read-only resource; there is nothing useful to do on failure.
	defer func() { _ = f.Close() }()

	page, err := io.ReadAll(f)
	if err != nil {
		er(err, w, http.StatusInternalServerError)
		return
	}

	// Headers must be set before the first Write, otherwise they are ignored.
	w.Header().Set(headers.ContentTypeKey, headers.TypeTextHTML.String())
	// A failed write means the client went away; the response cannot be
	// amended at that point, so the error is deliberately dropped.
	_, _ = w.Write(page)
}

// er writes err as a plain-text response with the given HTTP status code.
//
// Content-Type and Content-Length are set before WriteHeader is called:
// net/http ignores header changes made after the status line has been
// written.
func er(err error, w http.ResponseWriter, statusCode int) {
	resp := []byte(err.Error())
	h := w.Header()
	h.Set(headers.ContentTypeKey, headers.TypeTextPlain.String())
	h.Set(headers.ContentLength, strconv.Itoa(len(resp)))
	w.WriteHeader(statusCode)
	// Best effort: if the client is gone there is nobody left to tell.
	_, _ = w.Write(resp)
}

// RegisterUI wires the UI routes through the supplied registration callbacks:
// "/public/*" is passed to handle with a file server over dir, and GET "/*"
// is passed to bind with spaHandler.
func (s *server) RegisterUI(handle func(pattern string, h http.Handler), bind func(method, pattern string, h http.HandlerFunc)) {
	static := http.FileServer(dir)
	handle("/public/*", static)

	var spa http.HandlerFunc = spaHandler
	bind("GET", "/*", spa)
}

// Register mounts the HTTP API on mux. Every route is built with httprpc,
// uses the server logger, calls CorsAllowAll, writes its response with the
// JSON protobuf writer and delegates to the RPC method of the same purpose:
//
//	POST /api/reports      -> GetReports
//	GET  /api/units        -> GetUnits
//	GET  /api/units/{name} -> GetUnit
//	POST /api/hosts/       -> GetHosts
//	GET  /api/kernels      -> GetKernelVersions
//	GET  /api/hosts/{id}   -> GetHost
//	POST /api/heartbeats/  -> GetHeartbeats
//
// POST routes decode the request body with the JSON protobuf reader; GET
// routes build the request message from the URL (or use an empty message).
func (s *server) Register(mux *chi.Mux) {
	// POST /api/reports
	reports := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetReports(ctx, msg.(*pb.GetReportsRequest))
	}
	httprpc.New("POST", "/api/reports").
		WithLogger(s.l).
		CorsAllowAll().
		WithJSONPbReader(&pb.GetReportsRequest{}).
		WithHandler(reports).
		WithJSONPbWriter().
		Mount(mux)

	// GET /api/units
	readUnits := func(_ *http.Request) (proto.Message, error) {
		return &pb.GetUnitsRequest{}, nil
	}
	units := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetUnits(ctx, msg.(*pb.GetUnitsRequest))
	}
	httprpc.New("GET", "/api/units").
		WithLogger(s.l).
		CorsAllowAll().
		WithRequestReader(readUnits).
		WithHandler(units).
		WithJSONPbWriter().
		Mount(mux)

	// GET /api/units/{name}
	readUnit := func(hr *http.Request) (proto.Message, error) {
		return &pb.GetUnitRequest{UnitName: chi.URLParam(hr, "name")}, nil
	}
	unit := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetUnit(ctx, msg.(*pb.GetUnitRequest))
	}
	httprpc.New("GET", "/api/units/{name}").
		WithLogger(s.l).
		CorsAllowAll().
		WithRequestReader(readUnit).
		WithHandler(unit).
		WithJSONPbWriter().
		Mount(mux)

	// POST /api/hosts/
	hosts := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetHosts(ctx, msg.(*pb.GetHostsRequest))
	}
	httprpc.New("POST", "/api/hosts/").
		WithLogger(s.l).
		CorsAllowAll().
		WithJSONPbReader(&pb.GetHostsRequest{}).
		WithHandler(hosts).
		WithJSONPbWriter().
		Mount(mux)

	// GET /api/kernels
	readKernels := func(_ *http.Request) (proto.Message, error) {
		return &pb.GetKernelVersionsRequest{}, nil
	}
	kernels := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetKernelVersions(ctx, msg.(*pb.GetKernelVersionsRequest))
	}
	httprpc.New("GET", "/api/kernels").
		WithLogger(s.l).
		CorsAllowAll().
		WithRequestReader(readKernels).
		WithHandler(kernels).
		WithJSONPbWriter().
		Mount(mux)

	// GET /api/hosts/{id}
	readHost := func(hr *http.Request) (proto.Message, error) {
		return &pb.GetHostRequest{Node: chi.URLParam(hr, "id")}, nil
	}
	host := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetHost(ctx, msg.(*pb.GetHostRequest))
	}
	httprpc.New("GET", "/api/hosts/{id}").
		WithLogger(s.l).
		CorsAllowAll().
		WithRequestReader(readHost).
		WithHandler(host).
		WithJSONPbWriter().
		Mount(mux)

	// POST /api/heartbeats/
	heartbeats := func(ctx context.Context, msg proto.Message) (proto.Message, error) {
		return s.GetHeartbeats(ctx, msg.(*pb.GetHeartbeatsRequest))
	}
	httprpc.New("POST", "/api/heartbeats/").
		WithLogger(s.l).
		CorsAllowAll().
		WithJSONPbReader(&pb.GetHeartbeatsRequest{}).
		WithHandler(heartbeats).
		WithJSONPbWriter().
		Mount(mux)
}
