-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathWebSocketTransport.java
More file actions
159 lines (145 loc) · 5.51 KB
/
WebSocketTransport.java
File metadata and controls
159 lines (145 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package dev.arcp.client;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.arcp.core.transport.Transport;
import dev.arcp.core.wire.ArcpMapper;
import dev.arcp.core.wire.Envelope;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executors;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.TimeUnit;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Client-side ARCP {@link Transport} backed by the JDK {@link java.net.http.WebSocket}.
 * JSON envelopes ride as text frames; multi-part text frames are reassembled per {@code last}
 * delivery.
 *
 * <p>Inbound frames are decoded with the supplied {@link ObjectMapper} and published through a
 * {@link SubmissionPublisher}; frames that cannot be decoded are logged and skipped. Outbound
 * envelopes are serialized and written synchronously by {@link #send(Envelope)}.
 */
public final class WebSocketTransport implements Transport {
  private static final Logger log = LoggerFactory.getLogger(WebSocketTransport.class);

  private final WebSocket socket;
  private final ObjectMapper mapper;
  private final SubmissionPublisher<Envelope> inbound;

  private WebSocketTransport(
      WebSocket socket, ObjectMapper mapper, SubmissionPublisher<Envelope> inbound) {
    this.socket = socket;
    this.mapper = mapper;
    this.inbound = inbound;
  }

  /**
   * Open a WebSocket connection to {@code uri} and return a connected transport. Uses no extra
   * handshake headers, the shared ARCP mapper and a 10 second connect timeout.
   *
   * @param uri WebSocket endpoint to connect to
   * @return a transport whose socket has completed the opening handshake
   * @throws IllegalStateException if the handshake fails or does not complete in time
   * @throws InterruptedException if the calling thread is interrupted while waiting
   */
  public static WebSocketTransport connect(URI uri) throws InterruptedException {
    return connect(uri, Map.of(), ArcpMapper.shared(), Duration.ofSeconds(10));
  }

  /**
   * Open a WebSocket connection to {@code uri} and return a connected transport.
   *
   * @param uri WebSocket endpoint to connect to
   * @param headers extra headers added to the opening handshake request
   * @param mapper mapper used to encode outbound and decode inbound envelopes
   * @param timeout maximum time to wait for the opening handshake
   * @return a transport whose socket has completed the opening handshake
   * @throws NullPointerException if any argument is {@code null}
   * @throws IllegalStateException if the handshake fails or does not complete within
   *     {@code timeout}
   * @throws InterruptedException if the calling thread is interrupted while waiting
   */
  public static WebSocketTransport connect(
      URI uri, Map<String, String> headers, ObjectMapper mapper, Duration timeout)
      throws InterruptedException {
    java.util.Objects.requireNonNull(uri, "uri");
    java.util.Objects.requireNonNull(headers, "headers");
    java.util.Objects.requireNonNull(mapper, "mapper");
    java.util.Objects.requireNonNull(timeout, "timeout");

    HttpClient httpClient = HttpClient.newHttpClient();
    WebSocket.Builder builder = httpClient.newWebSocketBuilder();
    for (var entry : headers.entrySet()) {
      builder.header(entry.getKey(), entry.getValue());
    }

    // The listener owns the publisher and the reassembly buffer and is fully built before the
    // handshake starts, so no frame, close or error callback can arrive before there is
    // somewhere to route it.
    InboundListener listener = new InboundListener(mapper);
    CompletableFuture<WebSocket> stage = builder.buildAsync(uri, listener);

    WebSocket ws;
    try {
      ws = stage.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
    } catch (java.util.concurrent.ExecutionException e) {
      Throwable cause =
          e.getCause() instanceof CompletionException ce ? ce.getCause() : e.getCause();
      throw new IllegalStateException("WebSocket connect failed: " + cause, cause);
    } catch (java.util.concurrent.TimeoutException e) {
      abortWhenOpened(stage);
      throw new IllegalStateException("WebSocket connect timed out", e);
    } catch (InterruptedException e) {
      abortWhenOpened(stage);
      throw e;
    }
    return new WebSocketTransport(ws, mapper, listener.publisher);
  }

  /**
   * Arrange for a socket nobody is waiting for any more to be torn down if its handshake still
   * succeeds later, so an abandoned connect attempt does not leave a connection open.
   */
  private static void abortWhenOpened(CompletableFuture<WebSocket> stage) {
    stage.thenAccept(WebSocket::abort);
  }

  /**
   * Serialize {@code envelope} to JSON and send it as a single text frame, waiting up to five
   * seconds for the write to complete.
   *
   * @param envelope envelope to send
   * @throws UncheckedIOException if the envelope cannot be serialized
   * @throws IllegalStateException if the send fails, times out, or the calling thread is
   *     interrupted (the interrupt flag is restored)
   */
  @Override
  public void send(Envelope envelope) {
    try {
      String json = mapper.writeValueAsString(envelope);
      socket.sendText(json, true).get(5, TimeUnit.SECONDS);
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException("interrupted while sending", e);
    } catch (java.util.concurrent.ExecutionException | java.util.concurrent.TimeoutException e) {
      throw new IllegalStateException("send failed", e);
    }
  }

  /**
   * Returns the publisher of decoded inbound envelopes. It completes normally when the socket is
   * closed and exceptionally when the socket reports an error.
   *
   * <p>NOTE(review): {@link SubmissionPublisher} only delivers to subscribers present when an
   * item is submitted, so envelopes received before the first subscription are presumably lost;
   * subscribe promptly after connecting.
   */
  @Override
  public Flow.Publisher<Envelope> incoming() {
    return inbound;
  }

  /**
   * Send a normal-closure frame, waiting up to two seconds for it to be written, then complete
   * the inbound publisher. Failure to deliver the close frame is ignored.
   */
  @Override
  public void close() {
    try {
      socket.sendClose(WebSocket.NORMAL_CLOSURE, "bye").get(2, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    } catch (java.util.concurrent.ExecutionException
        | java.util.concurrent.TimeoutException ignored) {
      // best-effort close
    }
    inbound.close();
  }

  /**
   * Receiving half of the transport: reassembles text frames, decodes them into envelopes and
   * feeds the publisher. {@code partial} is unsynchronized; this relies on the JDK delivering
   * listener callbacks one at a time.
   */
  private static final class InboundListener implements WebSocket.Listener {
    private final ObjectMapper mapper;
    private final SubmissionPublisher<Envelope> publisher =
        new SubmissionPublisher<>(Executors.newVirtualThreadPerTaskExecutor(), 1024);
    private final StringBuilder partial = new StringBuilder();

    InboundListener(ObjectMapper mapper) {
      this.mapper = mapper;
    }

    @Override
    public void onOpen(WebSocket webSocket) {
      webSocket.request(1);
    }

    @Override
    public @Nullable CompletionStage<?> onText(
        WebSocket webSocket, CharSequence data, boolean last) {
      handleText(data, last);
      webSocket.request(1);
      return null;
    }

    @Override
    public @Nullable CompletionStage<?> onClose(
        WebSocket webSocket, int statusCode, String reason) {
      publisher.close();
      return null;
    }

    @Override
    public void onError(WebSocket webSocket, Throwable error) {
      publisher.closeExceptionally(error);
    }

    /** Buffer one fragment and, once the message is complete, decode and publish it. */
    private void handleText(CharSequence data, boolean last) {
      partial.append(data);
      if (!last) {
        return;
      }
      String frame = partial.toString();
      partial.setLength(0);

      Envelope env;
      try {
        env = mapper.readValue(frame, Envelope.class);
      } catch (IOException e) {
        log.warn("malformed envelope frame: {}", e.getMessage());
        return;
      }
      if (env == null) {
        // A literal JSON null decodes to a null reference; submit(null) would throw and the
        // WebSocket would escalate that to onError, killing the stream over one bad frame.
        log.warn("malformed envelope frame: {}", "JSON null is not an envelope");
        return;
      }
      try {
        publisher.submit(env);
      } catch (IllegalStateException e) {
        // The publisher was already closed (local close() or an earlier error); a frame that
        // arrives afterwards has nowhere to go and must not surface as a listener failure.
        log.debug("dropping envelope received after transport close");
      }
    }
  }
}