From 4b61bc94d6e273e3117a8f942cffd503ad0de82c Mon Sep 17 00:00:00 2001 From: Aaron3S Date: Thu, 18 Jun 2026 17:54:07 +0800 Subject: [PATCH] feat: check origin --- .../chen/web/config/WebSocketConfig.java | 62 ++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/backend/web/src/main/java/org/jumpserver/chen/web/config/WebSocketConfig.java b/backend/web/src/main/java/org/jumpserver/chen/web/config/WebSocketConfig.java index 4eba094..73d5417 100644 --- a/backend/web/src/main/java/org/jumpserver/chen/web/config/WebSocketConfig.java +++ b/backend/web/src/main/java/org/jumpserver/chen/web/config/WebSocketConfig.java @@ -5,6 +5,7 @@ import org.jumpserver.chen.framework.ws.SessionWebSocketHandler; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.web.socket.WebSocketHandler; @@ -14,8 +15,9 @@ import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; -import java.util.Map; -import java.util.Objects; +import java.net.URI; +import java.util.*; +import java.util.stream.Collectors; @Configuration @EnableWebSocket @@ -43,10 +45,66 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { .setAllowedOrigins("*"); } + private static final Set TRUSTED_DOMAINS = + Arrays.stream( + Optional.ofNullable(System.getenv("DOMAINS")) + .orElse("") + .split(",") + ) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .collect(Collectors.toSet()); + + + private static boolean checkOrigin(String origin) { + if (origin == null || origin.isBlank()) { + return true; + } + + if (TRUSTED_DOMAINS.contains("*")) { + return true; + } + + try { + URI uri = URI.create(origin); + + String host = uri.getHost(); + int port = uri.getPort(); + + // 本机访问直接放行 + if ("localhost".equalsIgnoreCase(host) + || "127.0.0.1".equals(host) + || "::1".equals(host) + || "0:0:0:0:0:0:0:1".equals(host)) { + return true; + } + + + String hostPort = + port > 0 + ? host + ":" + port + : host; + + return TRUSTED_DOMAINS.contains(hostPort) + || TRUSTED_DOMAINS.contains(host); + + } catch (Exception e) { + return false; + } + } + public static class ServletWebSocketHandshakeInterceptor implements HandshakeInterceptor { @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { + + String origin = request.getHeaders().getOrigin(); + + if (!checkOrigin(origin)) { + response.setStatusCode(HttpStatus.FORBIDDEN); + return false; + } + var token = request.getHeaders().get("Sec-WebSocket-Protocol").get(0); attributes.put("token", token); response.getHeaders().put("Sec-WebSocket-Protocol", Objects.requireNonNull(request.getHeaders().get("Sec-WebSocket-Protocol")));