diff --git a/internal/server/server.go b/internal/server/server.go index d2a29f5..9262a6c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1184,11 +1184,24 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co return err } -// proxyWebSocketCopy copies messages from src to dst +// proxyWebSocketCopy copies messages from src to dst, forwarding close frames +// to the destination so both peers receive a proper WebSocket close handshake. func proxyWebSocketCopy(src, dst *websocket.Conn) error { for { msgType, msg, err := src.ReadMessage() if err != nil { + if closeErr, ok := err.(*websocket.CloseError); ok { + code := closeErr.Code + // RFC 6455: 1005, 1006, 1015 must not be sent on the wire. + switch code { + case websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure, + websocket.CloseTLSHandshake: + code = websocket.CloseGoingAway + } + _ = dst.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(code, closeErr.Text)) + } return err } if err := dst.WriteMessage(msgType, msg); err != nil { @@ -1294,8 +1307,6 @@ func (s *Server) defaultProxyWebSocket(w http.ResponseWriter, r *http.Request, b backendConn.Close() return err } - defer clientConn.Close() - // Proxy messages in both directions errc := make(chan error, 2) go func() { @@ -1306,13 +1317,21 @@ func (s *Server) defaultProxyWebSocket(w http.ResponseWriter, r *http.Request, b err := proxyWebSocketCopy(backendConn, clientConn) errc <- err }() - // Wait for one direction to fail/close + // Wait for one direction to fail/close, then immediately close both + // connections so the other goroutine unblocks and finishes cleanly. err = <-errc + clientConn.Close() + backendConn.Close() + <-errc // wait for the second goroutine to finish // Mark endpoint as unhealthy for WS if error is not a normal closure if err != nil { if isExpectedWSClose(err) { - log.Debug().Err(err).Str("endpoint", helpers.RedactAPIKey(backendURL)).Msg("WebSocket connection closed normally") + if closeErr, ok := err.(*websocket.CloseError); ok && closeErr.Code == websocket.CloseAbnormalClosure { + log.Debug().Err(err).Str("endpoint", helpers.RedactAPIKey(backendURL)).Msg("WebSocket connection closed abnormally (1006), not counting as failure") + } else { + log.Debug().Err(err).Str("endpoint", helpers.RedactAPIKey(backendURL)).Msg("WebSocket connection closed normally") + } return nil } if chain, endpointID, found := s.findChainAndEndpointByURL(backendURL); found {