Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions comfy/multigpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations
import queue
import threading
import torch
import logging

Expand All @@ -11,6 +13,67 @@
import comfy.model_management


class MultiGPUThreadPool:
    """Persistent thread pool for multi-GPU work distribution.

    Maintains one worker thread per extra GPU device. Each thread calls
    torch.cuda.set_device() once at startup so that compiled kernel caches
    (inductor/triton) stay warm across diffusion steps.
    """

    def __init__(self, devices: list[torch.device]):
        """Spawn one daemon worker thread per device in *devices*.

        Each device gets its own unbounded work queue and result queue so
        submissions and results never cross between devices.
        """
        self._workers: list[threading.Thread] = []
        self._work_queues: dict[torch.device, queue.Queue] = {}
        self._result_queues: dict[torch.device, queue.Queue] = {}

        for device in devices:
            wq = queue.Queue()
            rq = queue.Queue()
            self._work_queues[device] = wq
            self._result_queues[device] = rq
            # Daemon threads: they never block interpreter exit even if
            # shutdown() is skipped.
            t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
            t.start()
            self._workers.append(t)

    def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
        """Worker body: pin the CUDA device once, then process work items.

        Work items are ``(fn, args, kwargs)`` tuples; a ``None`` sentinel
        terminates the loop. Every processed item produces exactly one
        ``(result, error)`` tuple on *result_q*, so get_result() never
        deadlocks as long as one submit() preceded it.
        """
        try:
            # Bind this thread to its GPU once so per-thread compiled-kernel
            # caches stay warm for the lifetime of the pool.
            torch.cuda.set_device(device)
        except Exception as e:
            logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
            # Surface the failure to the consumer instead of silently
            # swallowing work, then retire this worker.
            result_q.put((None, e))
            return
        while True:
            item = work_q.get()
            if item is None:  # shutdown sentinel
                break
            fn, args, kwargs = item
            try:
                result = fn(*args, **kwargs)
                result_q.put((result, None))
            except Exception as e:
                # Report the exception; the consumer decides whether to re-raise.
                result_q.put((None, e))

    def submit(self, device: torch.device, fn, *args, **kwargs):
        """Queue ``fn(*args, **kwargs)`` for execution on *device*'s worker."""
        self._work_queues[device].put((fn, args, kwargs))

    def get_result(self, device: torch.device):
        """Block until *device*'s worker produces a ``(result, error)`` tuple."""
        return self._result_queues[device].get()

    @property
    def devices(self) -> list[torch.device]:
        """Devices this pool was built with, in insertion order."""
        return list(self._work_queues.keys())

    def shutdown(self):
        """Signal all workers to exit and join them (bounded wait each)."""
        for wq in self._work_queues.values():
            wq.put(None)  # sentinel
        for t in self._workers:
            # Bounded join: a wedged worker must not hang sampling teardown.
            t.join(timeout=5.0)


class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
Expand Down
1 change: 1 addition & 0 deletions comfy/sampler_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase

Expand Down
53 changes: 42 additions & 11 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
import threading


def add_area_dims(area, num_dims):
Expand Down Expand Up @@ -509,15 +509,38 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup
raise


def _handle_batch_pooled(device, batch_tuple):
worker_results = []
_handle_batch(device, batch_tuple, worker_results)
return worker_results

results: list[thread_result] = []
threads: list[threading.Thread] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
main_device = output_device
main_batch_tuple = None

# Submit extra GPU work to pool first, then run main device on this thread
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()
if device == main_device and thread_pool is not None:
main_batch_tuple = batch_tuple
elif thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)

for thread in threads:
thread.join()
# Run main device batch on this thread (parallel with pool workers)
if main_batch_tuple is not None:
_handle_batch(main_device, main_batch_tuple, results)

# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
if error is not None:
raise error
results.extend(worker_results)

for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
Expand Down Expand Up @@ -1187,17 +1210,25 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,

multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)

noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
# Create persistent thread pool for extra GPU devices
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)

try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())

self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
Expand Down
Loading