Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 26 additions & 29 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,51 +217,48 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
}
}

// prepareRead sets the readTimeout context and returns a done function
// to be called after the read is done. It also returns an error if the
// connection is closed. The reference to the error is used to assign
// an error depending on if the connection closed or the context timed
// out during use. Typically, the referenced error is a named return
// variable of the function calling this method.
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
// prepareRead sets the read timeout and checks whether the connection is closed.
func (c *Conn) prepareRead(ctx context.Context) error {
select {
case <-c.closed:
return nil, net.ErrClosed
return net.ErrClosed
default:
}
c.setupReadTimeout(ctx)

done := func() {
c.clearReadTimeout()
select {
case <-c.closed:
if *err != nil {
*err = net.ErrClosed
}
default:
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}

c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
defer done()
return nil, closeReceivedErr
c.clearReadTimeout()
return closeReceivedErr
}

return done, nil
return nil
}

// finishRead clears the read timeout and reports whether the connection or
// operation context ended while the read was in progress.
func (c *Conn) finishRead(ctx context.Context, err *error) {
c.clearReadTimeout()
select {
case <-c.closed:
if *err != nil {
*err = net.ErrClosed
}
default:
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}

func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
readDone, err := c.prepareRead(ctx, &err)
err = c.prepareRead(ctx)
if err != nil {
return header{}, err
}
defer readDone()
defer c.finishRead(ctx, &err)

h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
Expand All @@ -272,11 +269,11 @@ func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
}

func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
readDone, err := c.prepareRead(ctx, &err)
err = c.prepareRead(ctx)
if err != nil {
return 0, err
}
defer readDone()
defer c.finishRead(ctx, &err)

n, err := io.ReadFull(c.br, p)
if err != nil {
Expand Down
Loading