diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go index 1341a87acb..d53be4d304 100644 --- a/httpserver/httpserver.go +++ b/httpserver/httpserver.go @@ -29,12 +29,22 @@ func Register(r *route.Router, reloadCh chan<- chan error) { r.Get("/metrics", promhttp.Handler().ServeHTTP) r.Post("/-/reload", func(w http.ResponseWriter, req *http.Request) { - errc := make(chan error) - defer close(errc) + errc := make(chan error, 1) - reloadCh <- errc - if err := <-errc; err != nil { - http.Error(w, fmt.Sprintf("failed to reload config: %s", err), http.StatusInternalServerError) + select { + case reloadCh <- errc: + case <-req.Context().Done(): + http.Error(w, req.Context().Err().Error(), http.StatusUnprocessableEntity) + return + } + + select { + case err := <-errc: + if err != nil { + http.Error(w, fmt.Sprintf("failed to reload config: %s", err), http.StatusInternalServerError) + } + case <-req.Context().Done(): + http.Error(w, req.Context().Err().Error(), http.StatusUnprocessableEntity) } }) diff --git a/httpserver/httpserver_test.go b/httpserver/httpserver_test.go index 24cfb7dbf9..73c6e4d531 100644 --- a/httpserver/httpserver_test.go +++ b/httpserver/httpserver_test.go @@ -14,14 +14,136 @@ package httpserver import ( + "context" + "errors" "net/http" "net/http/httptest" "testing" + "time" "github.com/prometheus/common/route" "github.com/stretchr/testify/require" ) +func TestReloadSuccess(t *testing.T) { + reloadCh := make(chan chan error) + router := route.New() + Register(router, reloadCh) + + done := make(chan struct{}) + go func() { + defer close(done) + req := httptest.NewRequest("POST", "/-/reload", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + }() + + errc := <-reloadCh + errc <- nil + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handler did not return") + } +} + +func TestReloadError(t *testing.T) { + reloadCh := make(chan chan error) + router := route.New() + Register(router, reloadCh) + + done := make(chan struct{}) + go func() { + defer close(done) + req := httptest.NewRequest("POST", "/-/reload", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, http.StatusInternalServerError, w.Code) + require.Contains(t, w.Body.String(), "bad config") + }() + + errc := <-reloadCh + errc <- errors.New("bad config") + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handler did not return") + } +} + +func TestReloadClientDisconnectBeforeEnqueue(t *testing.T) { + // reloadCh is never consumed, so the handler blocks on enqueue. + // Cancelling the context should unblock it. + reloadCh := make(chan chan error) + router := route.New() + Register(router, reloadCh) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer close(done) + req := httptest.NewRequest("POST", "/-/reload", nil).WithContext(ctx) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, http.StatusUnprocessableEntity, w.Code) + }() + + // Give the handler time to block on reloadCh send. + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handler did not unblock after context cancellation") + } +} + +func TestReloadClientDisconnectDuringReload(t *testing.T) { + // The handler enqueues successfully but the client disconnects + // before the reload result arrives. The buffered channel ensures + // the main goroutine (sender) does not block. + reloadCh := make(chan chan error) + router := route.New() + Register(router, reloadCh) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer close(done) + req := httptest.NewRequest("POST", "/-/reload", nil).WithContext(ctx) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, http.StatusUnprocessableEntity, w.Code) + }() + + errc := <-reloadCh + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("handler did not unblock after context cancellation") + } + + // Simulate the main goroutine sending the result after the handler + // has already returned. This must not block thanks to the buffered channel. + sendDone := make(chan struct{}) + go func() { + defer close(sendDone) + errc <- nil + }() + + select { + case <-sendDone: + case <-time.After(5 * time.Second): + t.Fatal("sender blocked on errc — channel must be buffered") + } +} + func TestDebugHandlersWithRoutePrefix(t *testing.T) { reloadCh := make(chan chan error)