diff --git a/read.go b/read.go index 64822511..871b1105 100644 --- a/read.go +++ b/read.go @@ -217,51 +217,48 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } -// prepareRead sets the readTimeout context and returns a done function -// to be called after the read is done. It also returns an error if the -// connection is closed. The reference to the error is used to assign -// an error depending on if the connection closed or the context timed -// out during use. Typically, the referenced error is a named return -// variable of the function calling this method. -func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { +// prepareRead sets the read timeout and checks whether the connection is closed. +func (c *Conn) prepareRead(ctx context.Context) error { select { case <-c.closed: - return nil, net.ErrClosed + return net.ErrClosed default: } c.setupReadTimeout(ctx) - done := func() { - c.clearReadTimeout() - select { - case <-c.closed: - if *err != nil { - *err = net.ErrClosed - } - default: - } - if *err != nil && ctx.Err() != nil { - *err = ctx.Err() - } - } - c.closeStateMu.Lock() closeReceivedErr := c.closeReceivedErr c.closeStateMu.Unlock() if closeReceivedErr != nil { - defer done() - return nil, closeReceivedErr + c.clearReadTimeout() + return closeReceivedErr } - return done, nil + return nil +} + +// finishRead clears the read timeout and reports whether the connection or +// operation context ended while the read was in progress. +func (c *Conn) finishRead(ctx context.Context, err *error) { + c.clearReadTimeout() + select { + case <-c.closed: + if *err != nil { + *err = net.ErrClosed + } + default: + } + if *err != nil && ctx.Err() != nil { + *err = ctx.Err() + } } func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { - readDone, err := c.prepareRead(ctx, &err) + err = c.prepareRead(ctx) if err != nil { return header{}, err } - defer readDone() + defer c.finishRead(ctx, &err) h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { @@ -272,11 +269,11 @@ func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { - readDone, err := c.prepareRead(ctx, &err) + err = c.prepareRead(ctx) if err != nil { return 0, err } - defer readDone() + defer c.finishRead(ctx, &err) n, err := io.ReadFull(c.br, p) if err != nil {