[bug] Switch to reload_weights API for loading weights in legacy inference codepath#1685
[bug] Switch to reload_weights API for loading weights in legacy inference codepath#1685SumanthRH wants to merge 1 commit into
reload_weights API for loading weights in legacy inference codepath#1685Conversation
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request updates the weight loading logic in vllm_worker.py by introducing the set_current_vllm_config context manager and switching to model_runner.reload_weights. Feedback suggests optimizing memory efficiency by passing the weight generator directly to the loading function instead of accumulating tensors in an intermediate list.
| for name, tensor in self._weight_receiver.receive_weights(request): | ||
| weight_list.append((name, tensor)) | ||
|
|
||
| self.model_runner.model.load_weights(weights=weight_list) | ||
| with torch.device(self.device), set_current_vllm_config(self.vllm_config): | ||
| self.model_runner.reload_weights(weights_iterator=iter(weight_list)) |
There was a problem hiding this comment.
Instead of collecting all weights into an intermediate list, you can pass the generator from receive_weights directly to reload_weights. This reduces memory overhead by avoiding storing all tensors in a list simultaneously, which is particularly important for large models. It also allows vLLM to pipeline the weight loading process as tensors are received.
Note that this change makes the subsequent weight_list cleanup loop (lines 98-99) redundant as the list will remain empty, but since those lines are context, they can be left as-is or removed in a separate cleanup.
| for name, tensor in self._weight_receiver.receive_weights(request): | |
| weight_list.append((name, tensor)) | |
| self.model_runner.model.load_weights(weights=weight_list) | |
| with torch.device(self.device), set_current_vllm_config(self.vllm_config): | |
| self.model_runner.reload_weights(weights_iterator=iter(weight_list)) | |
| with torch.device(self.device), set_current_vllm_config(self.vllm_config): | |
| self.model_runner.reload_weights( | |
| weights_iterator=self._weight_receiver.receive_weights(request)) |
What does this PR do?
Fixes #1680