diff --git a/conn.go b/conn.go index 09234871..bfc78aaa 100644 --- a/conn.go +++ b/conn.go @@ -168,24 +168,34 @@ func (c *Conn) close() error { return err } -func (c *Conn) setupWriteTimeout(ctx context.Context) { +func (c *Conn) setupWriteTimeout(ctx context.Context) bool { + if ctx.Done() == nil { + return false + } + stop := context.AfterFunc(ctx, func() { c.clearWriteTimeout() c.close() }) swapTimeoutStop(&c.writeTimeoutStop, &stop) + return true } func (c *Conn) clearWriteTimeout() { swapTimeoutStop(&c.writeTimeoutStop, nil) } -func (c *Conn) setupReadTimeout(ctx context.Context) { +func (c *Conn) setupReadTimeout(ctx context.Context) bool { + if ctx.Done() == nil { + return false + } + stop := context.AfterFunc(ctx, func() { c.clearReadTimeout() c.close() }) swapTimeoutStop(&c.readTimeoutStop, &stop) + return true } func (c *Conn) clearReadTimeout() { diff --git a/conn_test.go b/conn_test.go index c3ccc886..01af8cd4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -592,10 +592,11 @@ func BenchmarkConn(b *testing.B) { msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) writes := make(chan struct{}) - defer close(writes) werrs := make(chan error) + writerDone := make(chan struct{}) go func() { + defer close(writerDone) for range writes { select { case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg): @@ -650,6 +651,13 @@ func BenchmarkConn(b *testing.B) { } b.StopTimer() + close(writes) + select { + case <-writerDone: + case <-bb.ctx.Done(): + b.Fatal(bb.ctx.Err()) + } + b.ReportMetric(float64(*bytesWritten/b.N), "written/op") b.ReportMetric(float64(*bytesRead/b.N), "read/op") diff --git a/read.go b/read.go index 871b1105..c19f41de 100644 --- a/read.go +++ b/read.go @@ -218,29 +218,33 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } // prepareRead sets the read timeout and checks whether the connection is closed. -func (c *Conn) prepareRead(ctx context.Context) error { +func (c *Conn) prepareRead(ctx context.Context) (bool, error) { select { case <-c.closed: - return net.ErrClosed + return false, net.ErrClosed default: } - c.setupReadTimeout(ctx) + timeoutSet := c.setupReadTimeout(ctx) c.closeStateMu.Lock() closeReceivedErr := c.closeReceivedErr c.closeStateMu.Unlock() if closeReceivedErr != nil { - c.clearReadTimeout() - return closeReceivedErr + if timeoutSet { + c.clearReadTimeout() + } + return false, closeReceivedErr } - return nil + return timeoutSet, nil } // finishRead clears the read timeout and reports whether the connection or // operation context ended while the read was in progress. -func (c *Conn) finishRead(ctx context.Context, err *error) { - c.clearReadTimeout() +func (c *Conn) finishRead(ctx context.Context, err *error, timeoutSet bool) { + if timeoutSet { + c.clearReadTimeout() + } select { case <-c.closed: if *err != nil { @@ -254,11 +258,11 @@ func (c *Conn) finishRead(ctx context.Context, err *error) { } func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { - err = c.prepareRead(ctx) + timeoutSet, err := c.prepareRead(ctx) if err != nil { return header{}, err } - defer c.finishRead(ctx, &err) + defer c.finishRead(ctx, &err, timeoutSet) h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { @@ -269,11 +273,11 @@ func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { - err = c.prepareRead(ctx) + timeoutSet, err := c.prepareRead(ctx) if err != nil { return 0, err } - defer c.finishRead(ctx, &err) + defer c.finishRead(ctx, &err, timeoutSet) n, err := io.ReadFull(c.br, p) if err != nil { diff --git a/write.go b/write.go index 2e12153b..81bab4fa 100644 --- a/write.go +++ b/write.go @@ -318,8 +318,9 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return 0, net.ErrClosed default: } - c.setupWriteTimeout(ctx) - defer c.clearWriteTimeout() + if c.setupWriteTimeout(ctx) { + defer c.clearWriteTimeout() + } c.writeHeader.fin = fin c.writeHeader.opcode = opcode