Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def load_weights(self, request: bytes) -> None:
"""
import pickle

from vllm.config import set_current_vllm_config

# Unpickle request to restore the original object type
assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}"
request = pickle.loads(request)
Expand All @@ -90,7 +92,8 @@ def load_weights(self, request: bytes) -> None:
for name, tensor in self._weight_receiver.receive_weights(request):
weight_list.append((name, tensor))

self.model_runner.model.load_weights(weights=weight_list)
with torch.device(self.device), set_current_vllm_config(self.vllm_config):
self.model_runner.reload_weights(weights_iterator=iter(weight_list))
Comment on lines 92 to +96
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of collecting all weights into an intermediate list, you can pass the generator from receive_weights directly to reload_weights. This reduces memory overhead by avoiding storing all tensors in a list simultaneously, which is particularly important for large models. It also allows vLLM to pipeline the weight loading process as tensors are received.

Note that this change makes the subsequent weight_list cleanup loop (lines 98-99) redundant as the list will remain empty, but since those lines are context, they can be left as-is or removed in a separate cleanup.

Suggested change
for name, tensor in self._weight_receiver.receive_weights(request):
weight_list.append((name, tensor))
self.model_runner.model.load_weights(weights=weight_list)
with torch.device(self.device), set_current_vllm_config(self.vllm_config):
self.model_runner.reload_weights(weights_iterator=iter(weight_list))
with torch.device(self.device), set_current_vllm_config(self.vllm_config):
self.model_runner.reload_weights(
weights_iterator=self._weight_receiver.receive_weights(request))


for weight in weight_list:
del weight
Expand Down
Loading