diff --git a/java/src/main/java/com/microsoft/tunnels/websocket/WebSocketConnector.java b/java/src/main/java/com/microsoft/tunnels/websocket/WebSocketConnector.java index 85ddfc88..9f459912 100644 --- a/java/src/main/java/com/microsoft/tunnels/websocket/WebSocketConnector.java +++ b/java/src/main/java/com/microsoft/tunnels/websocket/WebSocketConnector.java @@ -19,6 +19,9 @@ import java.net.SocketAddress; import java.net.URI; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; + import org.apache.sshd.common.AttributeRepository; import org.apache.sshd.common.io.IoConnectFuture; import org.apache.sshd.common.io.IoHandler; @@ -69,9 +72,27 @@ protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline p = ch.pipeline(); if (factory.webSocketUri.getScheme().equals("wss")) { - SslContext sslContext = SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE).build(); - p.addLast("ssl", new SslHandler(sslContext.newEngine(ch.alloc()))); + String host = factory.webSocketUri.getHost(); + boolean isLocalDev = "localhost".equals(host) + || "tunnels.local.api.visualstudio.com".equals(host); + + SslContextBuilder builder = SslContextBuilder.forClient(); + if (isLocalDev) { + builder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } + SslContext sslContext = builder.build(); + + var relayPort = factory.webSocketUri.getPort(); + if (relayPort == -1) { + relayPort = 443; + } + SSLEngine engine = sslContext.newEngine(ch.alloc(), host, relayPort); + if (!isLocalDev) { + SSLParameters params = engine.getSSLParameters(); + params.setEndpointIdentificationAlgorithm("HTTPS"); + engine.setSSLParameters(params); + } + p.addLast("ssl", new SslHandler(engine)); } p.addLast(new HttpClientCodec()); p.addLast(new HttpObjectAggregator(8192));