Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1184,11 +1184,24 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co
return err
}

// proxyWebSocketCopy copies messages from src to dst
// proxyWebSocketCopy copies messages from src to dst, forwarding close frames
// to the destination so both peers receive a proper WebSocket close handshake.
func proxyWebSocketCopy(src, dst *websocket.Conn) error {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
if closeErr, ok := err.(*websocket.CloseError); ok {
code := closeErr.Code
// RFC 6455: 1005, 1006, 1015 must not be sent on the wire.
switch code {
case websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
websocket.CloseTLSHandshake:
code = websocket.CloseGoingAway
}
_ = dst.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(code, closeErr.Text))
}
return err
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if err := dst.WriteMessage(msgType, msg); err != nil {
Expand Down Expand Up @@ -1294,8 +1307,6 @@ func (s *Server) defaultProxyWebSocket(w http.ResponseWriter, r *http.Request, b
backendConn.Close()
return err
}
defer clientConn.Close()

// Proxy messages in both directions
errc := make(chan error, 2)
go func() {
Expand All @@ -1306,13 +1317,21 @@ func (s *Server) defaultProxyWebSocket(w http.ResponseWriter, r *http.Request, b
err := proxyWebSocketCopy(backendConn, clientConn)
errc <- err
}()
// Wait for one direction to fail/close
// Wait for one direction to fail/close, then immediately close both
// connections so the other goroutine unblocks and finishes cleanly.
err = <-errc
clientConn.Close()
backendConn.Close()
<-errc // wait for the second goroutine to finish

// Mark endpoint as unhealthy for WS if error is not a normal closure
if err != nil {
if isExpectedWSClose(err) {
log.Debug().Err(err).Str("endpoint", helpers.RedactAPIKey(backendURL)).Msg("WebSocket connection closed normally")
if closeErr, ok := err.(*websocket.CloseError); ok && closeErr.Code == websocket.CloseAbnormalClosure {
log.Debug().Err(err).Str("endpoint", helpers.RedactAPIKey(backendURL)).Msg("WebSocket connection closed abnormally (1006), not counting as failure")
} else {
log.Debug().Err(err).Str("endpoint", helpers.RedactAPIKey(backendURL)).Msg("WebSocket connection closed normally")
}
return nil
}
if chain, endpointID, found := s.findChainAndEndpointByURL(backendURL); found {
Expand Down
Loading