Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,34 @@ func (c *Conn) close() error {
return err
}

func (c *Conn) setupWriteTimeout(ctx context.Context) {
func (c *Conn) setupWriteTimeout(ctx context.Context) bool {
if ctx.Done() == nil {
return false
}

stop := context.AfterFunc(ctx, func() {
c.clearWriteTimeout()
c.close()
})
swapTimeoutStop(&c.writeTimeoutStop, &stop)
return true
}

func (c *Conn) clearWriteTimeout() {
swapTimeoutStop(&c.writeTimeoutStop, nil)
}

func (c *Conn) setupReadTimeout(ctx context.Context) {
func (c *Conn) setupReadTimeout(ctx context.Context) bool {
if ctx.Done() == nil {
return false
}

stop := context.AfterFunc(ctx, func() {
c.clearReadTimeout()
c.close()
})
swapTimeoutStop(&c.readTimeoutStop, &stop)
return true
}

func (c *Conn) clearReadTimeout() {
Expand Down
10 changes: 9 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,11 @@ func BenchmarkConn(b *testing.B) {
msg := []byte(strings.Repeat("1234", 128))
readBuf := make([]byte, len(msg))
writes := make(chan struct{})
defer close(writes)
werrs := make(chan error)
writerDone := make(chan struct{})

go func() {
defer close(writerDone)
for range writes {
select {
case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
Expand Down Expand Up @@ -650,6 +651,13 @@ func BenchmarkConn(b *testing.B) {
}
b.StopTimer()

close(writes)
select {
case <-writerDone:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}

b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
b.ReportMetric(float64(*bytesRead/b.N), "read/op")

Expand Down
28 changes: 16 additions & 12 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,29 +218,33 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
}

// prepareRead sets the read timeout and checks whether the connection is closed.
func (c *Conn) prepareRead(ctx context.Context) error {
func (c *Conn) prepareRead(ctx context.Context) (bool, error) {
select {
case <-c.closed:
return net.ErrClosed
return false, net.ErrClosed
default:
}
c.setupReadTimeout(ctx)
timeoutSet := c.setupReadTimeout(ctx)

c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
c.clearReadTimeout()
return closeReceivedErr
if timeoutSet {
c.clearReadTimeout()
}
return false, closeReceivedErr
}

return nil
return timeoutSet, nil
}

// finishRead clears the read timeout and reports whether the connection or
// operation context ended while the read was in progress.
func (c *Conn) finishRead(ctx context.Context, err *error) {
c.clearReadTimeout()
func (c *Conn) finishRead(ctx context.Context, err *error, timeoutSet bool) {
if timeoutSet {
c.clearReadTimeout()
}
select {
case <-c.closed:
if *err != nil {
Expand All @@ -254,11 +258,11 @@ func (c *Conn) finishRead(ctx context.Context, err *error) {
}

func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
err = c.prepareRead(ctx)
timeoutSet, err := c.prepareRead(ctx)
if err != nil {
return header{}, err
}
defer c.finishRead(ctx, &err)
defer c.finishRead(ctx, &err, timeoutSet)

h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
Expand All @@ -269,11 +273,11 @@ func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
}

func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
err = c.prepareRead(ctx)
timeoutSet, err := c.prepareRead(ctx)
if err != nil {
return 0, err
}
defer c.finishRead(ctx, &err)
defer c.finishRead(ctx, &err, timeoutSet)

n, err := io.ReadFull(c.br, p)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,9 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
return 0, net.ErrClosed
default:
}
c.setupWriteTimeout(ctx)
defer c.clearWriteTimeout()
if c.setupWriteTimeout(ctx) {
defer c.clearWriteTimeout()
}

c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
Expand Down
Loading