Skip to content

Commit b0ccf71

Browse files
committed
[FLINK-27286] Add communication infra to support training high dimension models
1 parent ba327b0 commit b0ccf71

28 files changed

Lines changed: 3922 additions & 0 deletions

flink-ml-lib/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ under the License.
138138
<scope>test</scope>
139139
<type>test-jar</type>
140140
</dependency>
141+
<!-- fastutil, pinned to version 8.5.12 and declared without an explicit scope. NOTE(review): presumably pulled in for the parameter-server code added in this commit; confirm against its usages. -->
<dependency>
142+
<groupId>it.unimi.dsi</groupId>
143+
<artifactId>fastutil</artifactId>
144+
<version>8.5.12</version>
145+
</dependency>
141146
</dependencies>
142147

143148
<build>
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.common.ps;
20+
21+
import org.apache.flink.api.common.typeutils.TypeSerializer;
22+
import org.apache.flink.api.java.tuple.Tuple2;
23+
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
24+
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
25+
import org.apache.flink.ml.util.Bits;
26+
import org.apache.flink.util.Preconditions;
27+
28+
import java.io.ByteArrayInputStream;
29+
import java.io.ByteArrayOutputStream;
30+
import java.io.IOException;
31+
import java.lang.reflect.Array;
32+
import java.util.ArrayList;
33+
import java.util.Comparator;
34+
import java.util.Iterator;
35+
import java.util.List;
36+
37+
/**
38+
* {@link Message} is responsible for encoding all messages exchanged between {@link
39+
* org.apache.flink.ml.common.ps.WorkerOperator} and {@link
40+
* org.apache.flink.ml.common.ps.ServerOperator}. The message format follows this structure:
41+
*
42+
* <p>`workerId serverId stageId messageType keyLength keys valuesLength values`
43+
*
44+
* <p>where the message fields include the worker ID, server ID, stage ID, message type, length of
45+
* the keys, keys themselves, length of the values, and the values.
46+
*/
47+
public class Message {
48+
private static final int WORKER_ID_OFFSET = 0;
49+
private static final int SERVER_ID_OFFSET = Integer.BYTES;
50+
private static final int STAGE_ID_OFFSET = Integer.BYTES + SERVER_ID_OFFSET;
51+
private static final int KVS_OFFSET = Integer.BYTES + STAGE_ID_OFFSET;
52+
53+
/** The storage of message in bytes. */
54+
public final byte[] bytes;
55+
56+
/** Constructs a message instance from the bytes. */
57+
public Message(byte[] bytes) {
58+
this.bytes = bytes;
59+
}
60+
61+
/** Constructs a message instance from long keys and double values. */
62+
public Message(int workerId, int serverId, int stageId, long[] keys, double[] values) {
63+
int sizeInBytes = KVS_OFFSET + Bits.getLongDoubleArraySizeInBytes(Tuple2.of(keys, values));
64+
bytes = new byte[sizeInBytes];
65+
Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
66+
Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
67+
Bits.putInt(bytes, STAGE_ID_OFFSET, stageId);
68+
Bits.putLongDoubleArray(Tuple2.of(keys, values), bytes, KVS_OFFSET);
69+
}
70+
71+
/** Constructs a message instance from long keys and generics values. */
72+
public <V> Message(
73+
int workerId,
74+
int serverId,
75+
int stageId,
76+
long[] keys,
77+
V[] values,
78+
TypeSerializer<V> serializer)
79+
throws IOException {
80+
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
81+
DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
82+
new DataOutputViewStreamWrapper(byteArrayOutputStream);
83+
dataOutputViewStreamWrapper.writeInt(workerId);
84+
dataOutputViewStreamWrapper.writeInt(serverId);
85+
dataOutputViewStreamWrapper.writeInt(stageId);
86+
87+
dataOutputViewStreamWrapper.writeInt(keys.length);
88+
for (long key : keys) {
89+
dataOutputViewStreamWrapper.writeLong(key);
90+
}
91+
dataOutputViewStreamWrapper.writeInt(values.length);
92+
for (V value : values) {
93+
serializer.serialize(value, dataOutputViewStreamWrapper);
94+
}
95+
bytes = byteArrayOutputStream.toByteArray();
96+
}
97+
98+
/** Retrieves the keys. */
99+
public long[] getKeys() {
100+
return Bits.getLongArray(bytes, KVS_OFFSET);
101+
}
102+
103+
/** Retrieves the values using the given serializer. */
104+
public <V> V[] getValues(TypeSerializer<V> serializer) throws IOException {
105+
int numIndices = Bits.getInt(bytes, KVS_OFFSET);
106+
int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES;
107+
int numValues = Bits.getInt(bytes, offset);
108+
offset += Integer.BYTES;
109+
110+
// Since the generics got erased, we use reflections to create the array.
111+
V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues);
112+
ByteArrayInputStream byteArrayInputStream =
113+
new ByteArrayInputStream(bytes, offset, bytes.length - offset);
114+
DataInputViewStreamWrapper dataInputViewStreamWrapper =
115+
new DataInputViewStreamWrapper(byteArrayInputStream);
116+
for (int i = 0; i < numValues; i++) {
117+
result[i] = serializer.deserialize(dataInputViewStreamWrapper);
118+
}
119+
return result;
120+
}
121+
122+
/**
123+
* Retrieves the values in double array format.
124+
*
125+
* <p>Note that getting double array in this function using {@link Bits#getDoubleArray(byte[],
126+
* int)} is faster than {@link Message#getValues} by up to 2.3X.
127+
*/
128+
public double[] getValuesInDoubleArray() {
129+
int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES;
130+
return Bits.getDoubleArray(bytes, offset);
131+
}
132+
133+
/** Retrieves the worker id. */
134+
public int getWorkerId() {
135+
return Bits.getInt(bytes, WORKER_ID_OFFSET);
136+
}
137+
138+
/** Sets the worker id. */
139+
public void setWorkerId(int workerId) {
140+
Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
141+
}
142+
143+
/** Retrieves the server id. */
144+
public int getServerId() {
145+
return Bits.getInt(bytes, SERVER_ID_OFFSET);
146+
}
147+
148+
/** Sets the server id. */
149+
public void setServerId(int serverId) {
150+
Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
151+
}
152+
153+
public int getStageId() {
154+
return Bits.getInt(bytes, STAGE_ID_OFFSET);
155+
}
156+
157+
/**
158+
* Assembles the received messages from servers according to the server id. Note that these
159+
* messages should come from the same request.
160+
*/
161+
public static Message assembleMessages(Iterator<byte[]> messageIterator) {
162+
List<Message> messages = new ArrayList<>();
163+
while (messageIterator.hasNext()) {
164+
messages.add(new Message(messageIterator.next()));
165+
}
166+
messages.sort(Comparator.comparingInt(Message::getServerId));
167+
168+
int numMessages = messages.size();
169+
int numKeys = 0, numValues = 0;
170+
int numAssembledBytes = 0;
171+
int workerId = -1;
172+
int stageId = -1;
173+
for (Message message : messages) {
174+
Preconditions.checkState(workerId == -1 || workerId == message.getWorkerId());
175+
Preconditions.checkState(stageId == -1 || stageId == message.getStageId());
176+
workerId = message.getWorkerId();
177+
stageId = message.getStageId();
178+
numKeys += message.getNumKeys();
179+
numValues += message.getNumValues();
180+
numAssembledBytes += message.bytes.length;
181+
}
182+
numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2);
183+
byte[] assembledBytes = new byte[numAssembledBytes];
184+
Bits.putInt(assembledBytes, WORKER_ID_OFFSET, workerId);
185+
Bits.putInt(assembledBytes, STAGE_ID_OFFSET, stageId);
186+
int keysOffset = KVS_OFFSET;
187+
Bits.putInt(assembledBytes, keysOffset, numKeys);
188+
keysOffset += Integer.BYTES;
189+
int valuesOffset = keysOffset + numKeys * Long.BYTES;
190+
Bits.putInt(assembledBytes, valuesOffset, numValues);
191+
valuesOffset += Integer.BYTES;
192+
193+
for (Message message : messages) {
194+
Tuple2<Integer, Integer> keyOoffsetAndLength = message.getKeysOffsetAndLength();
195+
System.arraycopy(
196+
message.bytes,
197+
keyOoffsetAndLength.f0,
198+
assembledBytes,
199+
keysOffset,
200+
keyOoffsetAndLength.f1);
201+
keysOffset += keyOoffsetAndLength.f1;
202+
Tuple2<Integer, Integer> valuesOffsetAndLength = message.getValuesOffSetAndLength();
203+
System.arraycopy(
204+
message.bytes,
205+
valuesOffsetAndLength.f0,
206+
assembledBytes,
207+
valuesOffset,
208+
valuesOffsetAndLength.f1);
209+
valuesOffset += valuesOffsetAndLength.f1;
210+
}
211+
212+
Message message = new Message(assembledBytes);
213+
message.setServerId(-1);
214+
return message;
215+
}
216+
217+
private Tuple2<Integer, Integer> getKeysOffsetAndLength() {
218+
int start = KVS_OFFSET + Integer.BYTES;
219+
int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES;
220+
return Tuple2.of(start, numBytes);
221+
}
222+
223+
private Tuple2<Integer, Integer> getValuesOffSetAndLength() {
224+
int start =
225+
Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES
226+
+ KVS_OFFSET
227+
+ Integer.BYTES
228+
+ Integer.BYTES;
229+
return Tuple2.of(start, bytes.length - start);
230+
}
231+
232+
private int getNumKeys() {
233+
return Bits.getInt(bytes, KVS_OFFSET);
234+
}
235+
236+
private int getNumValues() {
237+
return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys());
238+
}
239+
}

0 commit comments

Comments
 (0)