diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index dbbfeec..484d9e4 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -24,12 +24,10 @@ import ( "strconv" "strings" "testing" - "testing/synctest" "time" "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" "github.com/mccutchen/go-httpbin/v2/internal/testing/must" - "github.com/mccutchen/go-httpbin/v2/internal/testing/netpipetestserver" ) // appTestInfo carries the setup necessary for each unit test below, forming @@ -80,20 +78,6 @@ func setupTestApp(t *testing.T, opts ...OptionFunc) *appTestInfo { } } -func setupSynctestApp(t *testing.T, opts ...OptionFunc) *appTestInfo { - app := createApp(opts...) - srv, client := netpipetestserver.New(t, app) - client.Timeout = 5 * time.Second - client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { - return http.ErrUseLastResponse - } - return &appTestInfo{ - App: app, - Srv: srv, - Client: client, - } -} - // createApp creates an [HTTPBin] instance with default configuration, which // can be overridden by the given opts. func createApp(opts ...OptionFunc) *HTTPBin { @@ -2074,6 +2058,7 @@ func TestTrailers(t *testing.T) { func TestDelay(t *testing.T) { t.Parallel() + app := setupTestApp(t) okTests := []struct { url string @@ -2081,69 +2066,59 @@ func TestDelay(t *testing.T) { }{ // go-style durations are supported {"/delay/0ms", 0}, - {"/delay/500ms", 500 * time.Millisecond}, + {"/delay/100ms", 100 * time.Millisecond}, // as are floating point seconds {"/delay/0", 0}, - {"/delay/0.5", 500 * time.Millisecond}, - {"/delay/1", time.Second}, + {"/delay/0.1", 100 * time.Millisecond}, } for _, test := range okTests { t.Run("ok"+test.url, func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - start := time.Now() - req := newTestRequest(t, "GET", app.URL(test.url), nil) - resp := mustDoRequest(t, app, req) - elapsed := time.Since(start) + start := time.Now() + req := newTestRequest(t, "GET", app.URL(test.url), nil) + resp := mustDoRequest(t, app, req) + elapsed := time.Since(start) - _ = mustParseResponse[bodyResponse](t, resp) + _ = mustParseResponse[bodyResponse](t, resp) - if elapsed < test.expectedDelay { - t.Fatalf("expected delay of %s, got %s", test.expectedDelay, elapsed) - } + if elapsed < test.expectedDelay { + t.Fatalf("expected delay of %s, got %s", test.expectedDelay, elapsed) + } - timings := decodeServerTimings(resp.Header.Get("Server-Timing")) - assert.DeepEqual(t, timings, map[string]serverTiming{ - "initial_delay": {"initial_delay", test.expectedDelay, "initial delay"}, - }, "incorrect Server-Timing header value") - }) + timings := decodeServerTimings(resp.Header.Get("Server-Timing")) + assert.DeepEqual(t, timings, map[string]serverTiming{ + "initial_delay": {"initial_delay", test.expectedDelay, "initial delay"}, + }, "incorrect Server-Timing header value") }) } t.Run("handle cancelation", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - defer cancel() - req := newTestRequest(t, "GET", app.URL("/delay/1"), nil).WithContext(ctx) - _, err := app.Client.Do(req) - if !os.IsTimeout(err) { - t.Errorf("expected timeout error, got %v", err) - } - }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + req := newTestRequest(t, "GET", app.URL("/delay/1"), nil).WithContext(ctx) + _, err := app.Client.Do(req) + if !os.IsTimeout(err) { + t.Errorf("expected timeout error, got %v", err) + } }) t.Run("cancelation causes 499", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() - // use httptest.NewRecorder rather than a live httptest.NewServer - // because only the former will let us inspect the status code. - w := httptest.NewRecorder() - req, _ := http.NewRequestWithContext(ctx, "GET", "/delay/1s", nil) - app.App.ServeHTTP(w, req) - assert.Equal(t, w.Code, 499, "incorrect status code") - }) + // use httptest.NewRecorder rather than a live httptest.NewServer + // because only the former will let us inspect the status code. + w := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(ctx, "GET", "/delay/1s", nil) + app.App.ServeHTTP(w, req) + assert.Equal(t, w.Code, 499, "incorrect status code") }) badTests := []struct { @@ -2164,7 +2139,6 @@ func TestDelay(t *testing.T) { for _, test := range badTests { t.Run("bad"+test.url, func(t *testing.T) { t.Parallel() - app := setupTestApp(t) req := newTestRequest(t, "GET", app.URL(test.url), nil) resp := mustDoRequest(t, app, req) assert.StatusCode(t, resp, test.code) @@ -2180,6 +2154,7 @@ func TestDrip(t *testing.T) { opts = []OptionFunc{ WithMaxBodySize(int64(maxBodySize)), } + app = setupTestApp(t, opts...) ) okTests := []struct { @@ -2198,10 +2173,9 @@ func TestDrip(t *testing.T) { {url.Values{"delay": {"0h"}}, 0, 10, http.StatusOK}, // or floating point seconds - {url.Values{"duration": {"0.25"}}, 250 * time.Millisecond, 10, http.StatusOK}, + {url.Values{"duration": {"0.1"}}, 100 * time.Millisecond, 10, http.StatusOK}, {url.Values{"duration": {"0"}}, 0, 10, http.StatusOK}, - {url.Values{"duration": {"1"}}, 1 * time.Second, 10, http.StatusOK}, - {url.Values{"delay": {"0.25"}}, 250 * time.Millisecond, 10, http.StatusOK}, + {url.Values{"delay": {"0.1"}}, 100 * time.Millisecond, 10, http.StatusOK}, {url.Values{"delay": {"0"}}, 0, 10, http.StatusOK}, {url.Values{"numbytes": {"1"}}, 0, 1, http.StatusOK}, @@ -2212,35 +2186,32 @@ func TestDrip(t *testing.T) { {url.Values{"code": {"599"}}, 0, 10, 599}, {url.Values{"code": {"567"}}, 0, 10, 567}, - {url.Values{"duration": {"250ms"}, "delay": {"250ms"}}, 500 * time.Millisecond, 10, http.StatusOK}, - {url.Values{"duration": {"250ms"}, "delay": {"0.25s"}}, 500 * time.Millisecond, 10, http.StatusOK}, + {url.Values{"duration": {"100ms"}, "delay": {"100ms"}}, 200 * time.Millisecond, 10, http.StatusOK}, + {url.Values{"duration": {"100ms"}, "delay": {"0.1"}}, 200 * time.Millisecond, 10, http.StatusOK}, } for _, test := range okTests { t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - start := time.Now() - req := newTestRequest(t, "GET", app.URL("/drip", test.params), nil) - resp := mustDoRequest(t, app, req) - assert.BodySize(t, resp, test.numbytes) // must read body before measuring elapsed time - elapsed := time.Since(start) - - assert.StatusCode(t, resp, test.code) - assert.ContentType(t, resp, textContentType) - assert.Header(t, resp, "Content-Length", strconv.Itoa(test.numbytes)) - if elapsed < test.duration { - t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed) - } + start := time.Now() + req := newTestRequest(t, "GET", app.URL("/drip", test.params), nil) + resp := mustDoRequest(t, app, req) + assert.BodySize(t, resp, test.numbytes) // must read body before measuring elapsed time + elapsed := time.Since(start) - // Note: while the /drip endpoint seems like an ideal use case for - // using chunked transfer encoding to stream data to the client, it - // is actually intended to simulate a slow connection between - // server and client, so it is important to ensure that it writes a - // "regular," un-chunked response. - assert.DeepEqual(t, resp.TransferEncoding, nil, "unexpected Transfer-Encoding header") - }) + assert.StatusCode(t, resp, test.code) + assert.ContentType(t, resp, textContentType) + assert.Header(t, resp, "Content-Length", strconv.Itoa(test.numbytes)) + if elapsed < test.duration { + t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed) + } + + // Note: while the /drip endpoint seems like an ideal use case for + // using chunked transfer encoding to stream data to the client, it + // is actually intended to simulate a slow connection between + // server and client, so it is important to ensure that it writes a + // "regular," un-chunked response. + assert.DeepEqual(t, resp.TransferEncoding, nil, "unexpected Transfer-Encoding header") }) } @@ -2255,7 +2226,6 @@ func TestDrip(t *testing.T) { // indication we need. t.Parallel() - app := setupTestApp(t) req := newTestRequest(t, "GET", app.URL("/drip?code=100"), nil) reqBytes, err := httputil.DumpRequestOut(req, false) assert.NilError(t, err) @@ -2275,103 +2245,85 @@ func TestDrip(t *testing.T) { t.Run("writes are actually incremmental", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - - var ( - duration = 1 * time.Second - numBytes = 3 - endpoint = fmt.Sprintf("/drip?duration=%s&numbytes=%d", duration, numBytes) - // Match server logic for calculating the delay between writes - wantPauseBetweenWrites = computePausePerWrite(duration, int64(numBytes)) - ) - req := newTestRequest(t, "GET", app.URL(endpoint), nil) - resp := mustDoRequest(t, app, req) - - // Here we read from the response one byte at a time, and ensure that - // at least the expected delay occurs for each read. - // - // The request above includes an initial delay equal to the expected - // wait between writes so that even the first iteration of this loop - // expects to wait the same amount of time for a read. - buf := make([]byte, 1024) - gotBody := make([]byte, 0, numBytes) - for i := 0; ; i++ { - start := time.Now() - n, err := resp.Body.Read(buf) - gotPause := time.Since(start) - - // We expect to read exactly one byte on each iteration. On the - // last iteration, we expct to hit EOF after reading the final - // byte, because the server does not pause after the last write. - assert.Equal(t, n, 1, "incorrect number of bytes read") - assert.DeepEqual(t, buf[:n], []byte{'*'}, "unexpected bytes read") - gotBody = append(gotBody, buf[:n]...) - - if err == io.EOF { - break - } + var ( + duration = 500 * time.Millisecond + numBytes = 3 + endpoint = fmt.Sprintf("/drip?duration=%s&numbytes=%d", duration, numBytes) + ) - assert.NilError(t, err) + // start timer before sending the request to ensure the client + // duration measurement is at least as long as the server's duration, + // to avoid flakiness + start := time.Now() + req := newTestRequest(t, "GET", app.URL(endpoint), nil) + resp := mustDoRequest(t, app, req) - // only ensure that we pause for the expected time between writes - // after the first byte. - if i > 0 { - assert.Equal(t, gotPause, wantPauseBetweenWrites, "incorrect pause between writes") - } + // read incremental writes in a loop. should read one byte at a time, + // despite the larger read buffer. + buf := make([]byte, 1024) + gotBody := make([]byte, 0, numBytes) + numReads := 0 + for { + n, err := resp.Body.Read(buf) + assert.Equal(t, n, 1, "incorrect number of bytes read") + assert.DeepEqual(t, buf[:n], []byte{'*'}, "unexpected bytes read") + gotBody = append(gotBody, buf[:n]...) + numReads++ + if err == io.EOF { + break } + assert.NilError(t, err) + } - wantBody := bytes.Repeat([]byte{'*'}, numBytes) - assert.DeepEqual(t, gotBody, wantBody, "incorrect body") - }) + // writes were incrmemental if a) we did one read per byte and b) + // reading the whole response took (at least) the expected duration + assert.Equal(t, numReads, numBytes, "incorrect read count") + assert.MinDuration(t, time.Since(start), duration) + + wantBody := bytes.Repeat([]byte{'*'}, numBytes) + assert.DeepEqual(t, gotBody, wantBody, "incorrect body") }) t.Run("handle cancelation during initial delay", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - // For this test, we expect the client to time out and cancel the - // request after 10ms. The handler should still be in its intitial - // delay period, so this will result in a request error since no status - // code will be written before the cancelation. - ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) - defer cancel() + // For this test, we expect the client to time out and cancel the + // request after 10ms. The handler should still be in its intitial + // delay period, so this will result in a request error since no status + // code will be written before the cancelation. + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() - req := newTestRequest(t, "GET", app.URL("/drip?duration=500ms&delay=500ms"), nil).WithContext(ctx) - if _, err := app.Client.Do(req); !os.IsTimeout(err) { - t.Fatalf("expected timeout error, got %s", err) - } - }) + req := newTestRequest(t, "GET", app.URL("/drip?duration=500ms&delay=500ms"), nil).WithContext(ctx) + if _, err := app.Client.Do(req); !os.IsTimeout(err) { + t.Fatalf("expected timeout error, got %s", err) + } }) t.Run("handle cancelation during drip", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() - req := newTestRequest(t, "GET", app.URL("/drip?duration=900ms&delay=100ms"), nil).WithContext(ctx) - resp := mustDoRequest(t, app, req) + req := newTestRequest(t, "GET", app.URL("/drip?duration=900ms&delay=100ms"), nil).WithContext(ctx) + resp := mustDoRequest(t, app, req) - // In this test, the server should have started an OK response before - // our client timeout cancels the request, so we should get an OK here. - assert.StatusCode(t, resp, http.StatusOK) + // In this test, the server should have started an OK response before + // our client timeout cancels the request, so we should get an OK here. + assert.StatusCode(t, resp, http.StatusOK) - // But, we should time out while trying to read the whole response - // body. - body, err := io.ReadAll(resp.Body) - if !os.IsTimeout(err) { - t.Fatalf("expected timeout reading body, got %s", err) - } + // But, we should time out while trying to read the whole response + // body. + body, err := io.ReadAll(resp.Body) + if !os.IsTimeout(err) { + t.Fatalf("expected timeout reading body, got %s", err) + } - // And even though the request timed out, we should get a partial - // response. - assert.DeepEqual(t, body, []byte("**"), "incorrect partial body") - }) + // And even though the request timed out, we should get a partial + // response. + assert.DeepEqual(t, body, []byte("**"), "incorrect partial body") }) badTests := []struct { @@ -2407,7 +2359,6 @@ func TestDrip(t *testing.T) { for _, test := range badTests { t.Run(fmt.Sprintf("bad/%s", test.params.Encode()), func(t *testing.T) { t.Parallel() - app := setupTestApp(t, opts...) req := newTestRequest(t, "GET", app.URL("/drip", test.params), nil) resp := mustDoRequest(t, app, req) assert.StatusCode(t, resp, test.code) @@ -2416,7 +2367,6 @@ func TestDrip(t *testing.T) { t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) { t.Parallel() - app := setupTestApp(t, opts...) req := newTestRequest(t, "HEAD", app.URL("/drip?duration=900ms&delay=100ms"), nil) resp := mustDoRequest(t, app, req) assert.StatusCode(t, resp, http.StatusOK) @@ -2425,36 +2375,33 @@ func TestDrip(t *testing.T) { t.Run("Server-Timings header", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - var ( - duration = 100 * time.Millisecond - delay = 50 * time.Millisecond - numBytes = 10 - ) + var ( + duration = 100 * time.Millisecond + delay = 50 * time.Millisecond + numBytes = 10 + ) - url := fmt.Sprintf("/drip?duration=%s&delay=%s&numbytes=%d", duration, delay, numBytes) - req := newTestRequest(t, "GET", app.URL(url), nil) - resp := mustDoRequest(t, app, req) + url := fmt.Sprintf("/drip?duration=%s&delay=%s&numbytes=%d", duration, delay, numBytes) + req := newTestRequest(t, "GET", app.URL(url), nil) + resp := mustDoRequest(t, app, req) - assert.StatusCode(t, resp, http.StatusOK) + assert.StatusCode(t, resp, http.StatusOK) - timings := decodeServerTimings(resp.Header.Get("Server-Timing")) + timings := decodeServerTimings(resp.Header.Get("Server-Timing")) - // compute expected pause between writes to match server logic and - // handle lossy floating point truncation in the serialized header - // value - computedPause := duration / time.Duration(numBytes-1) - wantPause, _ := time.ParseDuration(fmt.Sprintf("%.2fms", computedPause.Seconds()*1e3)) + // compute expected pause between writes to match server logic and + // handle lossy floating point truncation in the serialized header + // value + computedPause := duration / time.Duration(numBytes-1) + wantPause, _ := time.ParseDuration(fmt.Sprintf("%.2fms", computedPause.Seconds()*1e3)) - assert.DeepEqual(t, timings, map[string]serverTiming{ - "total_duration": {"total_duration", delay + duration, "total request duration"}, - "initial_delay": {"initial_delay", delay, "initial delay"}, - "pause_per_write": {"pause_per_write", wantPause, "computed pause between writes"}, - "write_duration": {"write_duration", duration, "duration of writes after initial delay"}, - }, "incorrect Server-Timing header value") - }) + assert.DeepEqual(t, timings, map[string]serverTiming{ + "total_duration": {"total_duration", delay + duration, "total request duration"}, + "initial_delay": {"initial_delay", delay, "initial delay"}, + "pause_per_write": {"pause_per_write", wantPause, "computed pause between writes"}, + "write_duration": {"write_duration", duration, "duration of writes after initial delay"}, + }, "incorrect Server-Timing header value") }) } @@ -2544,25 +2491,23 @@ func TestRange(t *testing.T) { t.Run("ok_range_with_duration", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - url := "/range/100?duration=100ms" - req := newTestRequest(t, "GET", app.URL(url), nil) - req.Header.Add("Range", "bytes=10-24") + app := setupTestApp(t) + url := "/range/100?duration=100ms" + req := newTestRequest(t, "GET", app.URL(url), nil) + req.Header.Add("Range", "bytes=10-24") - start := time.Now() - resp := mustDoRequest(t, app, req) - elapsed := time.Since(start) + start := time.Now() + resp := mustDoRequest(t, app, req) + elapsed := time.Since(start) - assert.StatusCode(t, resp, http.StatusPartialContent) - assert.Header(t, resp, "ETag", "range100") - assert.Header(t, resp, "Accept-Ranges", "bytes") - assert.Header(t, resp, "Content-Length", "15") - assert.Header(t, resp, "Content-Range", "bytes 10-24/100") - assert.Header(t, resp, "Content-Type", textContentType) - assert.BodyEquals(t, resp, "klmnopqrstuvwxy") - assert.Equal(t, elapsed, 15*time.Millisecond, "incorrect duration") - }) + assert.StatusCode(t, resp, http.StatusPartialContent) + assert.Header(t, resp, "ETag", "range100") + assert.Header(t, resp, "Accept-Ranges", "bytes") + assert.Header(t, resp, "Content-Length", "15") + assert.Header(t, resp, "Content-Range", "bytes 10-24/100") + assert.Header(t, resp, "Content-Type", textContentType) + assert.BodyEquals(t, resp, "klmnopqrstuvwxy") + assert.MinDuration(t, elapsed, 15*time.Millisecond) }) t.Run("ok_multiple_ranges", func(t *testing.T) { @@ -3286,18 +3231,18 @@ func TestJSONL(t *testing.T) { url string expectedLines int }{ - {"/jsonl", 10}, // default count - {"/jsonl?count=1", 1}, // minimum - {"/jsonl?count=5", 5}, // custom count - {"/jsonl?count=0", 1}, // clamped to min - {"/jsonl?count=-5", 1}, // clamped to min - {"/jsonl?count=3&duration=1s", 3}, // with duration - {"/jsonl?count=1&duration=1s", 1}, // single line with duration - {"/jsonl?count=3&delay=0s", 3}, // with zero delay - {"/jsonl?count=2&duration=1s&delay=0s", 2}, // with both - {"/jsonl?count=3&duration=1s&jitter=0", 3}, // jitter=0 (no effect) - {"/jsonl?count=3&duration=1s&jitter=0.5", 3}, // jitter=0.5 - {"/jsonl?count=3&duration=1s&jitter=1", 3}, // jitter=1 (max) + {"/jsonl", 10}, // default count + {"/jsonl?count=1", 1}, // minimum + {"/jsonl?count=5", 5}, // custom count + {"/jsonl?count=0", 1}, // clamped to min + {"/jsonl?count=-5", 1}, // clamped to min + {"/jsonl?count=3&duration=100ms", 3}, // with duration + {"/jsonl?count=1&duration=100ms", 1}, // single line with duration + {"/jsonl?count=3&delay=0s", 3}, // with zero delay + {"/jsonl?count=2&duration=100ms&delay=0s", 2}, // with both + {"/jsonl?count=3&duration=100ms&jitter=0", 3}, // jitter=0 (no effect) + {"/jsonl?count=3&duration=100ms&jitter=0.5", 3}, // jitter=0.5 + {"/jsonl?count=3&duration=100ms&jitter=1", 3}, // jitter=1 (max) } for _, test := range okTests { t.Run("ok"+test.url, func(t *testing.T) { @@ -3371,42 +3316,40 @@ func TestJSONL(t *testing.T) { t.Run("writes are actually incremental", func(t *testing.T) { t.Parallel() + // request params var ( duration = 100 * time.Millisecond count = 3 endpoint = fmt.Sprintf("/jsonl?duration=%s&count=%d", duration, count) - - wantPauseBetweenWrites = duration / time.Duration(count-1) ) - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - req := newTestRequest(t, "GET", app.URL(endpoint), nil) - resp := mustDoRequest(t, app, req) - scanner := bufio.NewScanner(resp.Body) - lineCount := 0 - - for i := 0; ; i++ { - start := time.Now() - if !scanner.Scan() { - break - } - gotPause := time.Since(start) - - var sr streamResponse - err := json.Unmarshal(scanner.Bytes(), &sr) - assert.NilError(t, err) - assert.Equal(t, sr.ID, i, "unexpected JSONL line ID") - - if i > 0 { - assert.Equal(t, gotPause, wantPauseBetweenWrites, "incorrect pause between writes") - } + // start timer before sending the request to ensure the client + // duration measurement is at least as long as the server's duration, + // to avoid flakiness + start := time.Now() + req := newTestRequest(t, "GET", app.URL(endpoint), nil) + resp := mustDoRequest(t, app, req) - lineCount++ + // read incremental writes in a loop. should read one line at a time, + // despite the larger read buffer. + // + scanner := bufio.NewScanner(resp.Body) + numReads := 0 + var sr streamResponse + for i := 0; ; i++ { + if !scanner.Scan() { + assert.NilError(t, scanner.Err()) // Err returns nil on io.EOF + break } - assert.NilError(t, scanner.Err()) - assert.Equal(t, lineCount, count, "unexpected number of lines") - }) + assert.NilError(t, json.Unmarshal(scanner.Bytes(), &sr)) + assert.Equal(t, sr.ID, i, "unexpected JSONL line ID") + numReads++ + } + + // writes were incrmemental if a) we did one read per line and b) + // reading the whole response took (at least) the expected duration + assert.Equal(t, numReads, count, "unexpected number of lines") + assert.MinDuration(t, time.Since(start), duration) }) t.Run("handle cancelation during initial delay", func(t *testing.T) { @@ -3424,32 +3367,30 @@ func TestJSONL(t *testing.T) { t.Run("handle cancelation during stream", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) + app := setupTestApp(t) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() - req := newTestRequest(t, "GET", app.URL("/jsonl?duration=900ms&delay=0&count=2"), nil).WithContext(ctx) - resp := mustDoRequest(t, app, req) + req := newTestRequest(t, "GET", app.URL("/jsonl?duration=900ms&delay=0&count=2"), nil).WithContext(ctx) + resp := mustDoRequest(t, app, req) - assert.StatusCode(t, resp, http.StatusOK) + assert.StatusCode(t, resp, http.StatusOK) - // Should time out while trying to read the whole body - body, err := io.ReadAll(resp.Body) - if !os.IsTimeout(err) { - t.Fatalf("expected timeout reading body, got %s", err) - } + // Should time out while trying to read the whole body + body, err := io.ReadAll(resp.Body) + if !os.IsTimeout(err) { + t.Fatalf("expected timeout reading body, got %s", err) + } - // Partial read should include the first line - var sr streamResponse - scanner := bufio.NewScanner(bytes.NewReader(body)) - if !scanner.Scan() { - t.Fatal("expected at least one JSONL line in partial body") - } - assert.NilError(t, json.Unmarshal(scanner.Bytes(), &sr)) - assert.Equal(t, sr.ID, 0, "unexpected JSONL line ID") - }) + // Partial read should include the first line + var sr streamResponse + scanner := bufio.NewScanner(bytes.NewReader(body)) + if !scanner.Scan() { + t.Fatal("expected at least one JSONL line in partial body") + } + assert.NilError(t, json.Unmarshal(scanner.Bytes(), &sr)) + assert.Equal(t, sr.ID, 0, "unexpected JSONL line ID") }) t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) { @@ -3660,40 +3601,37 @@ func TestSSE(t *testing.T) { {url.Values{"delay": {"0h"}}, 0, 10}, // or floating point seconds - {url.Values{"duration": {"0.25"}}, 250 * time.Millisecond, 10}, - {url.Values{"duration": {"1"}}, 1 * time.Second, 10}, - {url.Values{"delay": {"0.25"}}, 250 * time.Millisecond, 10}, + {url.Values{"duration": {"0.1"}}, 100 * time.Millisecond, 10}, + {url.Values{"delay": {"0.1"}}, 100 * time.Millisecond, 10}, {url.Values{"delay": {"0"}}, 0, 10}, {url.Values{"count": {"1"}}, 0, 1}, {url.Values{"count": {"011"}}, 0, 11}, {url.Values{"count": {fmt.Sprintf("%d", app.App.maxSSECount)}}, 0, int(app.App.maxSSECount)}, - {url.Values{"duration": {"250ms"}, "delay": {"250ms"}}, 500 * time.Millisecond, 10}, - {url.Values{"duration": {"250ms"}, "delay": {"0.25s"}}, 500 * time.Millisecond, 10}, + {url.Values{"duration": {"100ms"}, "delay": {"100ms"}}, 200 * time.Millisecond, 10}, + {url.Values{"duration": {"100ms"}, "delay": {"0.1"}}, 200 * time.Millisecond, 10}, - {url.Values{"duration": {"250ms"}, "jitter": {"0"}}, 250 * time.Millisecond, 10}, - {url.Values{"duration": {"250ms"}, "jitter": {"0.5"}}, 0, 10}, - {url.Values{"duration": {"250ms"}, "jitter": {"1"}}, 0, 10}, + {url.Values{"duration": {"100ms"}, "jitter": {"0"}}, 100 * time.Millisecond, 10}, + {url.Values{"duration": {"100ms"}, "jitter": {"0.5"}}, 0, 10}, + {url.Values{"duration": {"100ms"}, "jitter": {"1"}}, 0, 10}, } for _, test := range okTests { t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - req := newTestRequest(t, "GET", app.URL("/sse", test.params), nil) - start := time.Now() - resp := mustDoRequest(t, app, req) - assert.StatusCode(t, resp, http.StatusOK) - events := parseServerSentEventStream(t, resp) + app := setupTestApp(t) + req := newTestRequest(t, "GET", app.URL("/sse", test.params), nil) + start := time.Now() + resp := mustDoRequest(t, app, req) + assert.StatusCode(t, resp, http.StatusOK) + events := parseServerSentEventStream(t, resp) - if elapsed := time.Since(start); elapsed < test.duration { - t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed) - } - assert.ContentType(t, resp, sseContentType) - assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "unexpected Transfer-Encoding header") - assert.Equal(t, len(events), test.count, "unexpected number of events") - }) + if elapsed := time.Since(start); elapsed < test.duration { + t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed) + } + assert.ContentType(t, resp, sseContentType) + assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "unexpected Transfer-Encoding header") + assert.Equal(t, len(events), test.count, "unexpected number of events") }) } @@ -3744,49 +3682,34 @@ func TestSSE(t *testing.T) { duration = 100 * time.Millisecond count = 3 endpoint = fmt.Sprintf("/sse?duration=%s&count=%d", duration, count) - - // Match server logic for calculating the delay between writes - wantPauseBetweenWrites = duration / time.Duration(count-1) ) - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - req := newTestRequest(t, "GET", app.URL(endpoint), nil) - resp := mustDoRequest(t, app, req) - buf := bufio.NewReader(resp.Body) - eventCount := 0 - - // Here we read from the response one byte at a time, and ensure that - // at least the expected delay occurs for each read. - // - // The request above includes an initial delay equal to the expected - // wait between writes so that even the first iteration of this loop - // expects to wait the same amount of time for a read. - for i := 0; ; i++ { - start := time.Now() - event, err := parseServerSentEvent(t, buf) - if err == io.EOF { - break - } - assert.NilError(t, err) - gotPause := time.Since(start) - - // We expect to read exactly one byte on each iteration. On the - // last iteration, we expct to hit EOF after reading the final - // byte, because the server does not pause after the last write. - assert.Equal(t, event.ID, i, "unexpected SSE event ID") - - // only ensure that we pause for the expected time between writes - // after the first byte. - if i > 0 { - assert.Equal(t, gotPause, wantPauseBetweenWrites, "incorrect pause betwen writes") - } + // start timer before sending the request to ensure the client + // duration measurement is at least as long as the server's duration, + // to avoid flakiness + start := time.Now() + req := newTestRequest(t, "GET", app.URL(endpoint), nil) + resp := mustDoRequest(t, app, req) - eventCount++ + // read incremental writes in a loop. should read one byte at a time, + // despite the larger read buffer. + buf := bufio.NewReader(resp.Body) + eventCount := 0 + for i := 0; ; i++ { + event, err := parseServerSentEvent(t, buf) + if err == io.EOF { + break } + assert.NilError(t, err) + assert.Equal(t, event.ID, i, "unexpected SSE event ID") + eventCount++ + } - assert.Equal(t, eventCount, count, "unexpected number of events") - }) + // writes were incrmemental if a) we read the correct number of events + // and b) reading the whole response took (at least) the expected + // duration + assert.Equal(t, eventCount, count, "unexpected number of events") + assert.MinDuration(t, time.Since(start), duration) }) t.Run("handle cancelation during initial delay", func(t *testing.T) { @@ -3808,31 +3731,29 @@ func TestSSE(t *testing.T) { t.Run("handle cancelation during stream", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) + app := setupTestApp(t) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() - req := newTestRequest(t, "GET", app.URL("/sse?duration=900ms&delay=0&count=2"), nil).WithContext(ctx) - resp := mustDoRequest(t, app, req) + req := newTestRequest(t, "GET", app.URL("/sse?duration=900ms&delay=0&count=2"), nil).WithContext(ctx) + resp := mustDoRequest(t, app, req) - // In this test, the server should have started an OK response before - // our client timeout cancels the request, so we should get an OK here. - assert.StatusCode(t, resp, http.StatusOK) + // In this test, the server should have started an OK response before + // our client timeout cancels the request, so we should get an OK here. + assert.StatusCode(t, resp, http.StatusOK) - // But, we should time out while trying to read the whole response - // body. - body, err := io.ReadAll(resp.Body) - if !os.IsTimeout(err) { - t.Fatalf("expected timeout reading body, got %s", err) - } + // But, we should time out while trying to read the whole response + // body. + body, err := io.ReadAll(resp.Body) + if !os.IsTimeout(err) { + t.Fatalf("expected timeout reading body, got %s", err) + } - // partial read should include the first whole event - event, err := parseServerSentEvent(t, bufio.NewReader(bytes.NewReader(body))) - assert.NilError(t, err) - assert.Equal(t, event.ID, 0, "unexpected SSE event ID") - }) + // partial read should include the first whole event + event, err := parseServerSentEvent(t, bufio.NewReader(bytes.NewReader(body))) + assert.NilError(t, err) + assert.Equal(t, event.ID, 0, "unexpected SSE event ID") }) t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) { @@ -3847,8 +3768,8 @@ func TestSSE(t *testing.T) { t.Parallel() var ( - duration = 250 * time.Millisecond - delay = 100 * time.Millisecond + duration = 100 * time.Millisecond + delay = 50 * time.Millisecond count = 11 // keep numbers round by ensuring (count-1) evenly divides duration (see computePausePerWrite) params = url.Values{ "duration": {duration.String()}, @@ -3856,40 +3777,38 @@ func TestSSE(t *testing.T) { "count": {strconv.Itoa(count)}, } ) - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - req := newTestRequest(t, "GET", app.URL("/sse", params), nil) - resp := mustDoRequest(t, app, req) + app := setupTestApp(t) + req := newTestRequest(t, "GET", app.URL("/sse", params), nil) + resp := mustDoRequest(t, app, req) - // need to fully consume body for Server-Timing trailers to arrive - must.ReadAll(t, resp.Body) + // need to fully consume body for Server-Timing trailers to arrive + must.ReadAll(t, resp.Body) - rawTimings := resp.Trailer.Get("Server-Timing") - t.Logf("raw Server-Timing header value: %q", rawTimings) + rawTimings := resp.Trailer.Get("Server-Timing") + t.Logf("raw Server-Timing header value: %q", rawTimings) - timings := decodeServerTimings(rawTimings) + timings := decodeServerTimings(rawTimings) - // Ensure total server time makes sense based on duration and delay - total := timings["total_duration"] - assert.Equal(t, total.dur, duration+delay, "incorrect total_duration") + // Ensure total server time makes sense based on duration and delay + total := timings["total_duration"] + assert.MinDuration(t, total.dur, duration+delay) - // Ensure computed pause time makes sense based on duration, delay, and - // numbytes (should be exact, but we're re-parsing a truncated float in - // the header value) - pause := timings["pause_per_write"] - assert.Equal(t, pause.dur, computePausePerWrite(duration, int64(count)), "incorrect pause_per_write") + // Ensure computed pause time makes sense based on duration, delay, and + // numbytes (should be exact, but we're re-parsing a truncated float in + // the header value) + pause := timings["pause_per_write"] + assert.MinDuration(t, pause.dur, computePausePerWrite(duration, int64(count))) - // remaining timings should exactly match request parameters, no need - // to adjust for per-run variations - wantTimings := map[string]serverTiming{ - "write_duration": {"write_duration", duration, "duration of writes after initial delay"}, - "initial_delay": {"initial_delay", delay, "initial delay"}, - } - for k, want := range wantTimings { - got := timings[k] - assert.DeepEqual(t, got, want, "incorrect timing for key %q", k) - } - }) + // remaining timings should exactly match request parameters, no need + // to adjust for per-run variations + wantTimings := map[string]serverTiming{ + "write_duration": {"write_duration", duration, "duration of writes after initial delay"}, + "initial_delay": {"initial_delay", delay, "initial delay"}, + } + for k, want := range wantTimings { + got := timings[k] + assert.DeepEqual(t, got, want, "incorrect timing for key %q", k) + } }) } @@ -3941,25 +3860,21 @@ func TestWebSocketEcho(t *testing.T) { t.Run("handshake ok", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - req := newTestRequest(t, http.MethodGet, app.URL("/websocket/echo"), nil) - for k, v := range handshakeHeaders { - req.Header.Set(k, v) - } - resp := mustDoRequest(t, app, req) - assert.StatusCode(t, resp, http.StatusSwitchingProtocols) - }) + app := setupTestApp(t) + req := newTestRequest(t, http.MethodGet, app.URL("/websocket/echo"), nil) + for k, v := range handshakeHeaders { + req.Header.Set(k, v) + } + resp := mustDoRequest(t, app, req) + assert.StatusCode(t, resp, http.StatusSwitchingProtocols) }) t.Run("handshake failed", func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t) - req := newTestRequest(t, http.MethodGet, app.URL("/websocket/echo"), nil) - resp := mustDoRequest(t, app, req) - assert.StatusCode(t, resp, http.StatusBadRequest) - }) + app := setupTestApp(t) + req := newTestRequest(t, http.MethodGet, app.URL("/websocket/echo"), nil) + resp := mustDoRequest(t, app, req) + assert.StatusCode(t, resp, http.StatusBadRequest) }) maxBodySize := 1024 @@ -3987,15 +3902,13 @@ func TestWebSocketEcho(t *testing.T) { for _, tc := range paramTests { t.Run(tc.query, func(t *testing.T) { t.Parallel() - synctest.Test(t, func(t *testing.T) { - app := setupSynctestApp(t, WithMaxBodySize(int64(maxBodySize))) - req := newTestRequest(t, http.MethodGet, app.URL("/websocket/echo?"+tc.query), nil) - for k, v := range handshakeHeaders { - req.Header.Set(k, v) - } - resp := mustDoRequest(t, app, req) - assert.StatusCode(t, resp, tc.wantStatus) - }) + app := setupTestApp(t, WithMaxBodySize(int64(maxBodySize))) + req := newTestRequest(t, http.MethodGet, app.URL("/websocket/echo?"+tc.query), nil) + for k, v := range handshakeHeaders { + req.Header.Set(k, v) + } + resp := mustDoRequest(t, app, req) + assert.StatusCode(t, resp, tc.wantStatus) }) } } @@ -4011,10 +3924,7 @@ func mustDoRequest(t *testing.T, app *appTestInfo, req *http.Request) *http.Resp t.Helper() resp, err := app.Client.Do(req) assert.NilError(t, err) - t.Cleanup(func() { - _, _ = io.Copy(io.Discard, resp.Body) - resp.Body.Close() - }) + t.Cleanup(func() { resp.Body.Close() }) return resp } diff --git a/httpbin/websocket/websocket_test.go b/httpbin/websocket/websocket_test.go index 19a3ea6..be814f6 100644 --- a/httpbin/websocket/websocket_test.go +++ b/httpbin/websocket/websocket_test.go @@ -11,12 +11,10 @@ import ( "strings" "sync" "testing" - "testing/synctest" "time" "github.com/mccutchen/go-httpbin/v2/httpbin/websocket" "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" - "github.com/mccutchen/go-httpbin/v2/internal/testing/netpipetestserver" ) func TestHandshake(t *testing.T) { @@ -230,72 +228,75 @@ func TestConnectionLimits(t *testing.T) { t.Run("maximum request duration is enforced", func(t *testing.T) { t.Parallel() - maxDuration := 500 * time.Millisecond - - synctest.Test(t, func(t *testing.T) { - _, client := netpipetestserver.New(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ws := websocket.New(w, r, websocket.Limits{ - MaxDuration: maxDuration, - // TODO: test these limits as well - MaxFragmentSize: 128, - MaxMessageSize: 256, - }) - if err := ws.Handshake(); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - ws.Serve(websocket.EchoHandler) - })) + maxDuration := 200 * time.Millisecond + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws := websocket.New(w, r, websocket.Limits{ + MaxDuration: maxDuration, + // TODO: test these limits as well + MaxFragmentSize: 128, + MaxMessageSize: 256, + }) + if err := ws.Handshake(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ws.Serve(websocket.EchoHandler) + })) + defer srv.Close() + + conn, err := net.Dial("tcp", srv.Listener.Addr().String()) + assert.NilError(t, err) + defer conn.Close() + + reqParts := []string{ + "GET /websocket/echo HTTP/1.1", + "Host: test", + "Connection: upgrade", + "Upgrade: websocket", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13", + } + reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n") + t.Logf("raw request:\n%q", reqBytes) + + // start timer before sending the request to ensure the client + // duration measurement is at least as long as the server's duration, + // to avoid flakiness + start := time.Now() + + // first, we write the request line and headers, which should cause the + // server to respond with a 101 Switching Protocols response. + { + n, err := conn.Write(reqBytes) + assert.NilError(t, err) + assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") - conn, err := netpipetestserver.Dial(t, client) + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) assert.NilError(t, err) - defer conn.Close() - - reqParts := []string{ - "GET /websocket/echo HTTP/1.1", - "Host: test", - "Connection: upgrade", - "Upgrade: websocket", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13", - } - reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n") - t.Logf("raw request:\n%q", reqBytes) - - // first, we write the request line and headers, which should cause the - // server to respond with a 101 Switching Protocols response. - { - n, err := conn.Write(reqBytes) - assert.NilError(t, err) - assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") - - resp, err := http.ReadResponse(bufio.NewReader(conn), nil) - assert.NilError(t, err) - assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + } + + // next, we try to read from the connection, expecting the connection + // to be closed after roughly maxDuration seconds + { + resp, err := io.ReadAll(conn) + elapsed := time.Since(start) + // we sometimes get a non-nil error and some garbage in the + // (partial?) resp read from the server, like + // + // \x88\x18\x03\xf3read pipe: i/o timeout + // + // So for now we make sure the test took the expected amount + // of time and only validate the error if we got one. + if err != nil { + assert.Error(t, err, io.EOF) } - - // next, we try to read from the connection, expecting the connection - // to be closed after roughly maxDuration seconds - { - start := time.Now() - resp, err := io.ReadAll(conn) - elapsed := time.Since(start) - // we sometimes get a non-nil error and some garbage in the - // (partial?) resp read from the server, like - // - // \x88\x18\x03\xf3read pipe: i/o timeout - // - // So for now we make sure the test took the expected amount - // of time and only validate the error if we got one. - if err != nil { - assert.Error(t, err, io.EOF) - } - if len(resp) > 0 { - t.Logf("unexpected response data: %q", resp) - } - assert.Equal(t, elapsed, maxDuration, "incorrect elapsed time") + if len(resp) > 0 { + t.Logf("unexpected response data: %q", resp) } - }) + assert.MinDuration(t, elapsed, maxDuration) + } }) t.Run("client closing connection", func(t *testing.T) { @@ -312,78 +313,80 @@ func TestConnectionLimits(t *testing.T) { wg sync.WaitGroup ) - synctest.Test(t, func(t *testing.T) { - wg.Add(1) - - _, client := netpipetestserver.New(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer wg.Done() - start := time.Now() - ws := websocket.New(w, r, websocket.Limits{ - MaxDuration: serverTimeout, - MaxFragmentSize: 128, - MaxMessageSize: 256, - }) - if err := ws.Handshake(); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - ws.Serve(websocket.EchoHandler) - elapsedServerTime = time.Since(start) - })) - - conn, err := netpipetestserver.Dial(t, client) - assert.NilError(t, err) - defer conn.Close() - - // should cause the client end of the connection to close well before - // the max request time configured above - conn.SetDeadline(time.Now().Add(clientTimeout)) - - reqParts := []string{ - "GET /websocket/echo HTTP/1.1", - "Host: test", - "Connection: upgrade", - "Upgrade: websocket", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13", - } - reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n") - t.Logf("raw request:\n%q", reqBytes) - - // first, we write the request line and headers, which should cause the - // server to respond with a 101 Switching Protocols response. - { - n, err := conn.Write(reqBytes) - assert.NilError(t, err) - assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") - - resp, err := http.ReadResponse(bufio.NewReader(conn), nil) - assert.NilError(t, err) - assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + wg.Add(1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() + start := time.Now() + ws := websocket.New(w, r, websocket.Limits{ + MaxDuration: serverTimeout, + MaxFragmentSize: 128, + MaxMessageSize: 256, + }) + if err := ws.Handshake(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return } + ws.Serve(websocket.EchoHandler) + elapsedServerTime = time.Since(start) + })) + defer srv.Close() + + conn, err := net.Dial("tcp", srv.Listener.Addr().String()) + assert.NilError(t, err) + defer conn.Close() + + reqParts := []string{ + "GET /websocket/echo HTTP/1.1", + "Host: test", + "Connection: upgrade", + "Upgrade: websocket", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13", + } + reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n") + t.Logf("raw request:\n%q", reqBytes) + + // start client timer before setting conn deadline or writing the + // request to ensure the client duration measurement is at least as + // long as the server's duration to avoid flakiness + start := time.Now() + + // deadline should cause the client end of the connection to close + // well before the max request time configured above + conn.SetDeadline(time.Now().Add(clientTimeout)) + + // first, we write the request line and headers, which should cause the + // server to respond with a 101 Switching Protocols response. + { + n, err := conn.Write(reqBytes) + assert.NilError(t, err) + assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") - // next, we try to read from the connection, expecting the connection - // to be closed after roughly clientTimeout seconds. - // - // the server should detect the closed connection and abort the - // handler, also after roughly clientTimeout seconds. - { - start := time.Now() - _, err := conn.Read(make([]byte, 1)) - elapsedClientTime = time.Since(start) - - // close client connection, which should interrupt the server's - // blocking read call on the connection - conn.Close() - - assert.Equal(t, os.IsTimeout(err), true, "expected timeout error") - assert.Equal(t, elapsedClientTime, clientTimeout, "incorrect elapsed client time") - - // wait for the server to finish - wg.Wait() - assert.Equal(t, elapsedServerTime, clientTimeout, "incorrect elapsed server time") - } - }) + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + assert.NilError(t, err) + assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + } + + // next, we try to read from the connection, expecting the connection + // to be closed after roughly clientTimeout seconds. + // + // the server should detect the closed connection and abort the + // handler, also after roughly clientTimeout seconds. + { + _, err := conn.Read(make([]byte, 1)) + elapsedClientTime = time.Since(start) + + // close client connection, which should interrupt the server's + // blocking read call on the connection + conn.Close() + + assert.Equal(t, os.IsTimeout(err), true, "expected timeout error") + assert.MinDuration(t, elapsedClientTime, clientTimeout) + + // wait for the server to finish + wg.Wait() + assert.MinDuration(t, elapsedServerTime, clientTimeout) + } }) } diff --git a/internal/testing/assert/assert.go b/internal/testing/assert/assert.go index 03bd8b7..a7bd451 100644 --- a/internal/testing/assert/assert.go +++ b/internal/testing/assert/assert.go @@ -7,6 +7,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/mccutchen/go-httpbin/v2/internal/testing/must" ) @@ -56,6 +57,14 @@ func Error(t *testing.T, got, expected error) { } } +// MinDuration asserts that got >= min. +func MinDuration(t *testing.T, got time.Duration, min time.Duration) { + t.Helper() + if got < min { + t.Errorf("expected duration %s >= %s", got, min) + } +} + // StatusCode asserts that a response has a specific status code. func StatusCode(t *testing.T, resp *http.Response, code int) { t.Helper() diff --git a/internal/testing/netpipetestserver/netpipetestserver.go b/internal/testing/netpipetestserver/netpipetestserver.go deleted file mode 100644 index 71512eb..0000000 --- a/internal/testing/netpipetestserver/netpipetestserver.go +++ /dev/null @@ -1,110 +0,0 @@ -// Package netpipetestserver provides [httptest.Server] and [http.Client] -// pairs configured to work within a synctest bubble by swapping the network -// for a pair of in-memory [net.Pipe] connections. -package netpipetestserver - -import ( - "context" - "net" - "net/http" - "net/http/httptest" - "testing" -) - -// New creates a new httptest.Server and http.Client pair suitable for use -// within a synctest bubble, which can't span real network connections. The -// server and client communicate over a pair of in-memory [net.Pipe] -// connections. -func New(t *testing.T, handler http.Handler) (*httptest.Server, *http.Client) { - t.Helper() - - ln := newNetPipeListener() - srv := httptest.NewUnstartedServer(handler) - srv.Listener = ln - srv.Start() - t.Cleanup(srv.Close) - - client := srv.Client() - client.Transport.(*http.Transport).DialContext = ln.DialContext - - return srv, client -} - -// netPipeListener is used to enable a server to accept connections from -// clients via [net.Pipe]. Client transports must use the server listener's -// DialContext method to make connections. -type netPipeListener struct { - connCh chan net.Conn - done chan struct{} - addr netPipeAddr -} - -func newNetPipeListener() *netPipeListener { - return &netPipeListener{ - connCh: make(chan net.Conn), - done: make(chan struct{}), - } -} - -// Accept accepts connections via [net.Pipe] from clients using the listener's -// own [DialContext] method. -func (ln *netPipeListener) Accept() (net.Conn, error) { - select { - case conn := <-ln.connCh: - return conn, nil - case <-ln.done: - return nil, net.ErrClosed - } -} - -// Close closes the listener. -func (ln *netPipeListener) Close() error { - select { - case <-ln.done: - default: - close(ln.done) - } - return nil -} - -// Dial allows tests using netPipeListener to directly access the underlying -// client connection via the [http.Client]'s transport. The client must be -// one created by [New]. -func Dial(t testing.TB, client *http.Client) (net.Conn, error) { - t.Helper() - addr := netPipeAddr{} - return client.Transport.(*http.Transport).DialContext(t.Context(), addr.Network(), addr.String()) -} - -// Addr returns a dummy [net.Addr] implementation. To actually connect to a -// listener, the listener's own [DialContext] method must be used (e.g. in -// the client's [http.Transport]). -func (ln *netPipeListener) Addr() net.Addr { - return ln.addr -} - -// DialContext creates both client and server conns via [net.Pipe] and -// returns the client conn. The server conn is enqueued for the listener to -// pick up in its [Accept] method. -func (ln *netPipeListener) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - select { - case ln.connCh <- serverConn: - return clientConn, nil - case <-ln.done: - clientConn.Close() - serverConn.Close() - return nil, net.ErrClosed - case <-ctx.Done(): - clientConn.Close() - serverConn.Close() - return nil, ctx.Err() - } -} - -type netPipeAddr struct{} - -func (netPipeAddr) Network() string { return "tcp" } -func (netPipeAddr) String() string { return "netpipetestserver:0" } - -var _ net.Addr = netPipeAddr{}