diff --git a/src/diffusers/cache_manager.py b/src/diffusers/cache_manager.py new file mode 100644 index 000000000000..3991fe7f0348 --- /dev/null +++ b/src/diffusers/cache_manager.py @@ -0,0 +1,90 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import torch +from loguru import logger +from diffusers.runtime_state import get_runtime_state + +class CacheEntry: + def __init__( + self, + cache_type: "str", + num_cache_tensors: int = 1, + tensors: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + ): + self.cache_type: str = cache_type + if tensors is None: + self.tensors: List[torch.Tensor] = [None,] * num_cache_tensors + elif isinstance(tensors, torch.Tensor): + self.tensors = [tensors, ] + elif isinstance(tensors, List): + self.tensors = [tensors, ] + +class CacheManager: + def __init__(self): + self.cache: Dict[Tuple[str, Any], CacheEntry] = {} + + def register_cache_entry(self, layer, layer_type: str, cache_type: str = "naive_cache"): + self.cache[layer_type, layer] = CacheEntry(cache_type) + + def cache_update( + self, + new_kv: Union[torch.Tensor, List[torch.Tensor]], + layer, + slice_dim: int = 1, + layer_type: str = "attn", + ): + return_list = False + if isinstance(new_kv, List): + return_list = True + new_kv = torch.cat(new_kv, dim=-1) + if get_runtime_state().num_pipeline_patch == 1 or not get_runtime_state().patch_mode: + kv_cache = new_kv + self.cache[layer_type, layer].tensors[0] = kv_cache + else: + start_token_idx = get_runtime_state().pp_patches_token_start_idx_local[ + get_runtime_state().pipeline_patch_idx + ] + end_token_idx = get_runtime_state().pp_patches_token_start_idx_local[ + get_runtime_state().pipeline_patch_idx + 1 + ] + kv_cache = self.cache[layer_type, layer].tensors[0] + kv_cache = self._update_kv_in_dim( + kv_cache=kv_cache, + new_kv=new_kv, + dim=slice_dim, + start_idx=start_token_idx, + end_idx=end_token_idx, + ) + self.cache[layer_type, layer].tensors[0] = kv_cache + if return_list: + return torch.chunk(kv_cache, 2, dim=-1) + else: + return kv_cache + + def _update_kv_in_dim( + self, + kv_cache: torch.Tensor, + new_kv: torch.Tensor, + dim: int, + start_idx: int, + end_idx: int, + ): + if dim < 0: + dim += kv_cache.dim() + + if dim == 0: + kv_cache[start_idx:end_idx, ...] = new_kv + elif dim == 1: + kv_cache[:, start_idx:end_idx:, ...] = new_kv + elif dim == 2: + kv_cache[:, :, start_idx:end_idx, ...] = new_kv + elif dim == 3: + kv_cache[:, :, :, start_idx:end_idx, ...] = new_kv + return kv_cache + +_CACHE_MANAGER = CacheManager() + +def get_cache_manager(): + global _CACHE_MANAGER + if _CACHE_MANAGER is None: + _CACHE_MANAGER = CacheManager() + return _CACHE_MANAGER \ No newline at end of file diff --git a/src/diffusers/conv.py b/src/diffusers/conv.py new file mode 100644 index 000000000000..c9734ad2f8c6 --- /dev/null +++ b/src/diffusers/conv.py @@ -0,0 +1,91 @@ +import torch +from torch import nn +from torch.nn import functional as F +from diffusers.runtime_state import get_runtime_state +from diffusers.parallel_state import get_pipeline_parallel_world_size, get_sequence_parallel_world_size +from loguru import logger + +class CustomConv2d(nn.Module): + def __init__( + self, conv2d: nn.Conv2d + ): + super().__init__() + self.module = conv2d + self.module_type = type(self.module) + self.activation_cache = None + + def naive_forward(self, x: torch.Tensor) -> torch.Tensor: + output = self.module(x) + return output + + def sliced_forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + stride = self.module.stride[0] + padding = self.module.padding[0] + + idx = get_runtime_state().pipeline_patch_idx + pp_patches_start_idx_local = get_runtime_state().pp_patches_start_idx_local + h_begin = pp_patches_start_idx_local[idx] - padding + h_end = pp_patches_start_idx_local[idx + 1] + padding + final_padding = [padding, padding, 0, 0] + if h_begin < 0: + h_begin = 0 + final_padding[2] = padding + if h_end > h: + h_end = h + final_padding[3] = padding + sliced_input = x[:, :, h_begin:h_end, :] + padded_input = F.pad(sliced_input, final_padding, mode="constant") + result = F.conv2d( + padded_input, + self.module.weight, + self.module.bias, + stride=stride, + padding="valid", + ) + return result + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if ( + ( + get_pipeline_parallel_world_size() == 1 + and get_sequence_parallel_world_size() == 1 + ) + or self.module.kernel_size == (1, 1) + or self.module.kernel_size == 1 + ): + output = self.naive_forward(x) + else: + if ( + not get_runtime_state().patch_mode + or get_runtime_state().num_pipeline_patch == 1 + ): + self.activation_cache = x + output = self.naive_forward(self.activation_cache) + else: + if self.activation_cache is None: + self.activation_cache = torch.zeros( + [ + x.shape[0], + x.shape[1], + get_runtime_state().pp_patches_start_idx_local[-1], + x.shape[3], + ], + dtype=x.dtype, + device=x.device, + ) + + self.activation_cache[ + :, + :, + get_runtime_state() + .pp_patches_start_idx_local[ + get_runtime_state().pipeline_patch_idx + ] : get_runtime_state() + .pp_patches_start_idx_local[ + get_runtime_state().pipeline_patch_idx + 1 + ], + :, + ] = x + output = self.sliced_forward(self.activation_cache) + return output \ No newline at end of file diff --git a/src/diffusers/embedding.py b/src/diffusers/embedding.py new file mode 100644 index 000000000000..723681e16b13 --- /dev/null +++ b/src/diffusers/embedding.py @@ -0,0 +1,48 @@ +import torch +from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed +from diffusers.runtime_state import get_runtime_state +from torch import nn +from loguru import logger + +class CustomPatchEmbed(nn.Module): # xinze: the difference is that, we do the positional embedding as if the patch is the full picture. After embedding process, we crop the result according to the patch index and use it as the final embedding. + def __init__( + self, patch_embedding: PatchEmbed, + ): + super().__init__() + self.module = patch_embedding + self.module_type = type(self.module) # self.module.pos_embed is injected in the from_pretrained step. + self.pos_embed = None + self.activation_cache = None + + + def forward(self, latent): + + height = ( + get_runtime_state().config.height + // get_runtime_state().vae_scale_factor + ) + width = latent.shape[-1] + + latent = self.module.proj(latent) + if self.module.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + + + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + + if getattr(self.module, "pos_embed_max_size", None): + pos_embed = self.module.cropped_pos_embed(height, width) + + + if get_runtime_state().patch_mode: + start, end = get_runtime_state().pp_patches_token_start_end_idx_global[ + get_runtime_state().pipeline_patch_idx + ] + pos_embed = pos_embed[ + :, + start:end, + :, + ] + + return (latent + pos_embed).to(latent.dtype) \ No newline at end of file diff --git a/src/diffusers/group_coordinator.py b/src/diffusers/group_coordinator.py new file mode 100644 index 000000000000..111e8e36a904 --- /dev/null +++ b/src/diffusers/group_coordinator.py @@ -0,0 +1,515 @@ +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup +from typing import Any, Dict, List, Optional, Tuple, Union + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 2 | 0 + # 3 | 1 | 3 | 3 | 1 + # local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + def all_gather(# xinze: this is just adapted torch.all_gather_into_tensor and then reshape + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + if world_size == 1: + return input_ + if dim <0: + dim += input_.dim() + input_size = input_.size() + output_tensor = torch.empty( + (world_size,) + input_size, dtype=input_.dtype, device=input_.device #xinze: (world_size, ) is a tuple. + ) + # All-gather + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + @property + def prev_rank(self): + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + + ulysses_group = kwargs.get("ulysses_group", None) + ring_group = kwargs.get("ring_group", None) + if ulysses_group is None: + raise RuntimeError( + f"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator" + ) + if ring_group is None: + raise RuntimeError( + f"Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator" + ) + self.ulysses_group = ulysses_group + self.ring_group = ring_group + + self.ulysses_world_size = torch.distributed.get_world_size( + self.ulysses_group + ) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) + +class PipelineParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + self.cpu_groups = [] + self.device_groups = [] + if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1: + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + # when pipeline parallelism is 2, we need to create two groups to avoid + # communication stall. + # *_group_0_1 represents the group for communication from device 0 to + # device 1. + # *_group_1_0 represents the group for communication from device 1 to + # device 0. + elif len(group_ranks[0]) == 2: + for ranks in group_ranks: + device_group_0_1 = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + device_group_1_0 = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo") + cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_groups = [device_group_0_1, device_group_1_0] + self.cpu_groups = [cpu_group_0_1, cpu_group_1_0] + self.device_group = device_group_0_1 + self.cpu_group = cpu_group_0_1 + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.recv_buffer_set: bool = False + self.recv_tasks_queue: List[Tuple[str, int]] = [] + self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.dtype: Optional[torch.dtype] = None + self.num_pipefusion_patches: Optional[int] = None + + self.recv_shape: Dict[str, Dict[int, torch.Size]] = {} + self.send_shape: Dict[str, Dict[int, torch.Size]] = {} + self.recv_buffer: Dict[str, Dict[int, torch.Size]] = {} + + self.skip_tensor_recv_buffer_set: bool = False + self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = [] + self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.skip_tensor_recv_buffer: Optional[ + Union[List[torch.Tensor], torch.Tensor] + ] = None + self.skip_device_group = None + for ranks in group_ranks: + skip_device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + if self.rank in ranks: + self.skip_device_group = skip_device_group + assert self.skip_device_group is not None + + def reset_buffer(self): + self.recv_tasks_queue = [] + self.receiving_tasks = [] + self.recv_shape = {} + self.send_shape = {} + self.recv_buffer = {} + + self.recv_skip_tasks_queue = [] + self.receiving_skip_tasks = [] + self.skip_tensor_recv_buffer = {} + + def recv_next(self): + if len(self.recv_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_tasks_queue) > 0: + name, idx = self.recv_tasks_queue.pop(0) + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self.receiving_tasks.append( + (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx) + ) + + def _pipeline_irecv(self, tensor: torch.tensor): + return torch.distributed.irecv( + tensor, + src=self.prev_rank, + group=( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def get_pipeline_recv_data( + self, idx: int = -1, name: str = "latent" + ) -> torch.Tensor: + receiving_task = self.receiving_tasks.pop(0) + receiving_task[0].wait() + return self.recv_buffer[name][idx] + + def _check_shape_and_buffer( + self, + tensor_send_to_next=None, + recv_prev=False, + name: Optional[str] = None, + segment_idx: int = 0, + ): + pass + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + local_rank satisfy the following equation: + global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1) + tp_rank \in [0, tp_size) + dp_rank \in [0, dp_size) + pp_rank \in [0, pp_size) + + If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each. + For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the + dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].) + The tp_rank and pp_rank will be combined to form the `dp_group_index`. + dp_group_index = tp_rank + pp_rank * tp_size (2) + + So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in + range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the + equation (1). + + This function solve this math problem. + + For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4], + and the mask = [False, True, False]. Then, + dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2 + dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2 + ... + dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2 + + dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4] + dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5] + ... + dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23] + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + assert ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + # tp: int, + sp: int, + pp: int, + cfg: int, + # dp: int, + order: str, + rank_offset: int = 0, + ) -> None: + # self.tp = tp + self.sp = sp + self.pp = pp + self.cfg = cfg + # self.dp = dp + self.rank_offset = rank_offset + self.world_size = sp * pp * cfg + + self.name_to_size = { + # "tp": self.tp, + "sp": self.sp, + "pp": self.pp, + "cfg": self.cfg, + # "dp": self.dp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..602a9ea4b2c5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -151,6 +151,10 @@ def __init__( qk_norm=qk_norm, eps=1e-6, ) + # add kv cache here! + from diffusers.cache_manager import get_cache_manager + + get_cache_manager().register_cache_entry(layer=self.attn, layer_type="attn") if use_dual_attention: self.attn2 = Attention( diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index faacc431c386..2f9deee80ea3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1224,45 +1224,86 @@ def __call__( key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads + from diffusers.cache_manager import get_cache_manager + key, value = get_cache_manager().cache_update( + new_kv=[key, value], + layer=attn, + slice_dim=1, + layer_type="attn", + ) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.module.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) - # `context` projections. - if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # if attn.norm_q is not None: + # query = attn.norm_q(query) + # if attn.norm_k is not None: + # key = attn.norm_k(key) + + # # `context` projections. + # if encoder_hidden_states is not None: + # encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + # encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + # encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + # inner_dim = key.shape[-1] + # head_dim = inner_dim // attn.heads + + # encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + # batch_size, -1, attn.heads, head_dim + # ).transpose(1, 2) + # encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + # batch_size, -1, attn.heads, head_dim + # ).transpose(1, 2) + # encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + # batch_size, -1, attn.heads, head_dim + # ).transpose(1, 2) + + # if attn.norm_added_q is not None: + # encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + # if attn.norm_added_k is not None: + # encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + # key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + # value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: @@ -1278,11 +1319,22 @@ def __call__( hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - - if encoder_hidden_states is not None: - return hidden_states, encoder_hidden_states - else: - return hidden_states + # input_ndim = hidden_states.ndim + # context_input_ndim = encoder_hidden_states.ndim + # if input_ndim == 4: + # hidden_states = hidden_states.transpose(-1, -2).reshape( + # batch_size, channel, height, width + # ) + # if context_input_ndim == 4: + # encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( + # batch_size, channel, height, width + # ) + + # if encoder_hidden_states is not None: + # return hidden_states, encoder_hidden_states + # else: + # return hidden_states + return hidden_states, encoder_hidden_states class PAGJointAttnProcessor2_0: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8f8f1073da74..64be8249f298 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -283,7 +283,6 @@ def __init__( grid_size = pos_embed_max_size else: grid_size = int(num_patches**0.5) - if pos_embed_type is None: self.pos_embed = None elif pos_embed_type == "sincos": diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 79452bb85176..9010045ead00 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -32,6 +32,8 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput +from diffusers.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage, get_pipeline_parallel_world_size +from diffusers.runtime_state import get_runtime_state logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -184,6 +186,9 @@ def __init__( self.gradient_checkpointing = False + # xinze: here is a cache!!! + self.encoder_hidden_states_cache = [None for _ in range(len(self.transformer_blocks))] + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ @@ -383,12 +388,23 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - - height, width = hidden_states.shape[-2:] - - hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + patch_size = self.config.patch_size + # ------------------MODIFIED HERE---------------- + # height, width = hidden_states.shape[-2:] + height, width = get_runtime_state().config.height, get_runtime_state().config.width + height = height // (patch_size * 8) # 8 is vae scale factor + width = width // (patch_size * 8) + if get_runtime_state().patch_mode: + height = height // get_pipeline_parallel_world_size() + + # hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + # temb = self.time_text_embed(timestep, pooled_projections) + # encoder_hidden_states = self.context_embedder(encoder_hidden_states) temb = self.time_text_embed(timestep, pooled_projections) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if is_pipeline_first_stage(): # only pp rank 0 needs patchify + hidden_states = self.pos_embed(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers @@ -414,34 +430,71 @@ def custom_forward(*inputs): **ckpt_kwargs, ) elif not is_skip: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) + # encoder_hidden_states, hidden_states = block( + # hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + # ) + if get_runtime_state().patch_mode and get_runtime_state().pipeline_patch_idx == 0: + self.encoder_hidden_states_cache[index_block] = encoder_hidden_states + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + elif get_runtime_state().patch_mode: + _, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=self.encoder_hidden_states_cache[index_block], temb=temb + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) # controlnet residual if block_controlnet_hidden_states is not None and block.context_pre_only is False: interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states) hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)] - hidden_states = self.norm_out(hidden_states, temb) - hidden_states = self.proj_out(hidden_states) - - # unpatchify - patch_size = self.config.patch_size - height = height // patch_size - width = width // patch_size - - hidden_states = hidden_states.reshape( - shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) - ) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) + # hidden_states = self.norm_out(hidden_states, temb) + # hidden_states = self.proj_out(hidden_states) + + # # unpatchify + # patch_size = self.config.patch_size + # height = height // patch_size + # width = width // patch_size + + # hidden_states = hidden_states.reshape( + # shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + # ) + # hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + # output = hidden_states.reshape( + # shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + # ) + + # if USE_PEFT_BACKEND: + # # remove `lora_scale` from each PEFT layer + # unscale_lora_layers(self, lora_scale) + if is_pipeline_last_stage(): + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + # height = height // patch_size + # width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = (hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ), None,) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + else: + # if not last stage, then we need intermediate encoder hidden states. + output = hidden_states, encoder_hidden_states if not return_dict: return (output,) diff --git a/src/diffusers/nccl_parallel_context.py b/src/diffusers/nccl_parallel_context.py new file mode 100644 index 000000000000..f7b91c791b6c --- /dev/null +++ b/src/diffusers/nccl_parallel_context.py @@ -0,0 +1,503 @@ +import os +from typing import Optional, Union, Tuple, Sequence +from datetime import timedelta +from loguru import logger +from collections.abc import Coroutine + +import asyncio +import torch + +import crossing_accelerator_cuda_extension +from parallel_context import ( + ParallelContext, + AsyncModeContext, + AsyncMode, +) + +# types +Shape = Sequence[int] +Stream = torch.cuda.Stream +ProcessGroup = torch.distributed.ProcessGroup +Communicator = "torch.cuda.nccl.Communicator" + +if int(os.getenv("CROSSING_ENABLE_MSCCLPP_ALLRECUDE", "0")) == 1: + try: + from crossing.accelerators.mscclpp_allreduce import MscclppAllReduce + + MSCCLPP_ALLRECUDE_ENABLED = True + except: + MSCCLPP_ALLRECUDE_ENABLED = False +else: + MSCCLPP_ALLRECUDE_ENABLED = False + +from contextlib import closing +import socket + + +def _find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +MSCCLPP_MAX_SIZE_BYTES = 128 * 1024 + + +class OverlapAsyncModeContext(AsyncModeContext): + def __init__(self, parallel_ctx: "NcclParallelContext"): + self.parallel_ctx = parallel_ctx + + def __enter__(self): + self._coroutine_switch_condition = asyncio.Condition() + self._comm_stream = self.parallel_ctx._tensor_parallel_comm_stream + if self._comm_stream is None: + self._comm_stream = torch.cuda.Stream() + self.parallel_ctx._tensor_parallel_comm_stream = self._comm_stream + self._old_async_mode = self.parallel_ctx._async_mode + self.parallel_ctx._async_mode = self + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.parallel_ctx._async_mode = self._old_async_mode + if exc_type is not None: + logger.error( + f"An exception occurred in OverlapAsyncModeContext context scope: {exc_value}" + ) + return False + + async def gather(self, task1, task2) -> Tuple: + if isinstance(task1, Coroutine): + task1 = asyncio.create_task(task1) + if isinstance(task2, Coroutine): + task2 = asyncio.create_task(task2) + + assert isinstance( + task1, asyncio.Task + ), f"task1 should be a asyncio.Task, but got {task1}" + assert isinstance( + task2, asyncio.Task + ), f"task2 should be a asyncio.Task, but got {task2}" + + result1 = None + result2 = None + task1_completed = False + task2_completed = False + + while not task1_completed or not task2_completed: + done, pending = await asyncio.wait( + [task1, task2], return_when=asyncio.FIRST_COMPLETED + ) + for task in done: + if task is task1: + # notify task2 + async with self._coroutine_switch_condition: + self._coroutine_switch_condition.notify() + result1 = task.result() + task1_completed = True + else: + assert task1_completed + # task_2 done + result2 = task.result() + task2_completed = True + + return result1, result2 + + async def async_allreduce(self, tensor): + comp_stream = torch.cuda.current_stream() + # sync from comp to comm + cuda_event_comp = torch.cuda.Event() + cuda_event_comp.record(stream=comp_stream) + self._comm_stream.wait_event(cuda_event_comp) + + # allreduce on comm stream + with torch.cuda.stream(self._comm_stream): + self.parallel_ctx.allreduce_in_tensor_parallel_group(tensor) + + # sync from comm to comp + cuda_event_comm = torch.cuda.Event() + cuda_event_comm.record(stream=self._comm_stream) + # switch to another coroutine + async with self._coroutine_switch_condition: + self._coroutine_switch_condition.notify() + await self._coroutine_switch_condition.wait() + comp_stream.wait_event(cuda_event_comm) + + +def _init_torch_distributed(timeout=timedelta(hours=24)): + assert torch.cuda.is_available() + assert torch.distributed.is_available() + + env_rank = int(os.getenv("RANK", "0")) + env_world_size = int(os.getenv("WORLD_SIZE", "1")) + if env_world_size == 1: + return + + if torch.distributed.is_initialized(): + assert env_world_size == torch.distributed.get_world_size() + assert env_rank == torch.distributed.get_rank() + return + + backend = "nccl" + options = torch.distributed.ProcessGroupNCCL.Options() + options.is_high_priority_stream = True + + torch.distributed.init_process_group( + backend=backend, + world_size=env_world_size, + rank=env_rank, + timeout=timeout, + pg_options=options, + ) + + +def _check_send_recv_tensors(tensors: Union[torch.Tensor, Sequence[torch.Tensor]]): + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + else: + if not isinstance(tensors, (list, tuple)): + raise ValueError( + f"send_tensors must be a list of tensors, but got type {type(tensors)}" + ) + + for i, tensor in enumerate(tensors): + if not isinstance(tensor, torch.Tensor): + raise ValueError( + "send_tensors must be a list of tensors, " + f"but got {i} th element type {type(tensor)}" + ) + + return tensors + + +class NcclParallelContext(ParallelContext): + def __init__( + self, + *, + ranks: Sequence[int] = None, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + device_index: int = 0, + ): + _init_torch_distributed() + + self._ranks = ranks + self._tensor_parallel_size = tensor_parallel_size + self._pipeline_parallel_size = pipeline_parallel_size + self._device_index = device_index + self._torch_device = torch.device(f"cuda:{self._device_index}") + self._parallel_size = tensor_parallel_size * pipeline_parallel_size + + if not isinstance(ranks, (list, tuple)) or not all( + isinstance(rank, int) for rank in ranks + ): + raise ValueError(f"`ranks` should be a list of int") + if len(ranks) != self._parallel_size: + raise ValueError(f"`ranks` should has the same length as parallel_ctx.size") + self._ranks = ranks + + if self._parallel_size > 1: + self._main_group = torch.distributed.new_group(self._ranks, backend="nccl") + else: + self._main_group = None + + # for grouping + self._process_mesh = torch.tensor(self._ranks) + + self._rank = ( + torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + ) + + # TODO: remove when tests don't rely on set_device in parallel_ctx + if self._rank in self._ranks: + torch.cuda.set_device(self._device_index) + + # tensor parallel property + self._tensor_parallel_group: Optional[ProcessGroup] = None + self._tensor_parallel_group_root: int = 0 + self._tensor_parallel_main_comm: Optional[Communicator] = None + self._tensor_parallel_comm_stream: Optional[Stream] = None + # async mode + self._async_mode: Optional[AsyncModeContext] = None + + # pipeline parallel property + self._pipeline_parallel_group: Optional[ProcessGroup] = None + self._pipeline_stage_id: int = 0 + self._pipeline_parallel_group_prev_rank = 0 + self._pipeline_parallel_group_next_rank = 0 + + with torch.cuda.device(self._device_index): + self._init_pipeline_parallel_group() + self._init_tensor_parallel_group() + + if self.tensor_parallel_size > 1 and MSCCLPP_ALLRECUDE_ENABLED: + if self.tensor_parallel_rank == 0: + port = _find_free_port() + port_tensor = torch.tensor( + [port], dtype=torch.int32, device=self._torch_device + ) + else: + port_tensor = torch.empty( + (1,), dtype=torch.int32, device=self._torch_device + ) + self.broadcast_in_tensor_parallel_group(port_tensor) + port = port_tensor.item() + self._mscclpp_allreduce = MscclppAllReduce( + rank=self.tensor_parallel_rank, + parallel_size=self.tensor_parallel_size, + max_size_bytes=MSCCLPP_MAX_SIZE_BYTES, + port=port, + device=self._torch_device, + ) + else: + self._mscclpp_allreduce = None + + def _init_pipeline_parallel_group(self): + if self.size > 1: + self._pipeline_parallel_group = self._new_group( + self._pipeline_parallel_size + ) + if self._pipeline_parallel_group is not None: + # set pipeline property + group_ranks = torch.distributed.get_process_group_ranks( + self._pipeline_parallel_group + ) + self._pipeline_stage_id = group_ranks.index(self._rank) + self._pipeline_parallel_group_prev_rank = group_ranks[ + (self._pipeline_stage_id - 1) % self._pipeline_parallel_size + ] + self._pipeline_parallel_group_next_rank = group_ranks[ + (self._pipeline_stage_id + 1) % self._pipeline_parallel_size + ] + + def _init_tensor_parallel_group(self): + if self.size > 1: + self._tensor_parallel_group = self._new_group(self._tensor_parallel_size) + if self._tensor_parallel_group is not None: + group_ranks = torch.distributed.get_process_group_ranks( + self._tensor_parallel_group + ) + self._tensor_parallel_group_root = group_ranks[0] + self._tensor_parallel_main_comm = self._create_communicator( + self._tensor_parallel_group + ) + + def _new_group( + self, parallel_size: int + ) -> Optional[torch.distributed.ProcessGroup]: + remain_num_groups = self._process_mesh.size(-1) + if remain_num_groups % parallel_size != 0: + raise ValueError( + f"The process mesh {self._process_mesh} cannot be divided further by parallel_size {parallel_size}" + ) + + num_groups = remain_num_groups // parallel_size + self._process_mesh = self._process_mesh.view( + *self._process_mesh.shape[:-1], parallel_size, num_groups + ) + + ret_group = None + global_rank = torch.distributed.get_rank() + # (past_num_groups, parallel_size, new_num_groups) + process_mesh = self._process_mesh.view(-1, parallel_size, num_groups) + for group_i in range(process_mesh.size(0)): + for group_j in range(process_mesh.size(-1)): + global_ranks = process_mesh[group_i, :, group_j].tolist() + group = torch.distributed.new_group(global_ranks, backend="nccl") + if global_rank in global_ranks: + ret_group = group + + return ret_group + + def _create_communicator(self, group: torch.distributed.ProcessGroup): + backend = torch.distributed.get_backend(group) + assert backend == "nccl" + group_ranks = torch.distributed.get_process_group_ranks(group) + if self._rank in group_ranks: + group_rank = group_ranks.index(self._rank) + unique_id = torch.cuda.nccl.unique_id() + unique_id_tensor = torch.ByteTensor(list(unique_id)).cuda() + torch.distributed.broadcast( + unique_id_tensor, src=group_ranks[0], group=group + ) + unique_id = unique_id_tensor.cpu().numpy().tobytes() + comm = crossing_accelerator_cuda_extension._nccl_comm_init_rank( + nranks=len(group_ranks), commId=unique_id, rank=group_rank + ) + return comm + return None + + def device_index(self) -> int: + return self._device_index + + def torch_device(self) -> torch.device: + return self._torch_device + + def allreduce(self, tensor: torch.Tensor, red_op: torch.distributed.ReduceOp): + if self.size > 1: + torch.distributed.all_reduce(tensor, op=red_op, group=self._main_group) + + @property + def tensor_parallel_group(self): + return self._tensor_parallel_group + + @property + def size(self) -> int: + return self._parallel_size + + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + + @property + def tensor_parallel_rank(self) -> int: + if self._tensor_parallel_group is None: + return 0 + return self._tensor_parallel_group.rank() + + def broadcast_in_tensor_parallel_group(self, tensor: torch.Tensor): + if self.tensor_parallel_size > 1: + torch.distributed.broadcast( + tensor, + src=self._tensor_parallel_group_root, + group=self._tensor_parallel_group, + ) + + def allreduce_in_tensor_parallel_group(self, tensor: torch.Tensor): + if self.tensor_parallel_size > 1: + tensor_size = tensor.numel() * tensor.element_size() + if ( + self._mscclpp_allreduce + and tensor_size <= MSCCLPP_MAX_SIZE_BYTES + and tensor.dtype == torch.float16 + and tensor.is_contiguous() + ): + self._mscclpp_allreduce(tensor) + else: + crossing_accelerator_cuda_extension._nccl_all_reduce( + input=tensor, + output=tensor, + op=0, # ncclRedOp_t::ncclSum + comm=self._tensor_parallel_main_comm, + ) + + def allgather_in_tensor_parallel_group( + self, input: torch.Tensor, output: torch.Tensor + ): + if self.tensor_parallel_size > 1: + crossing_accelerator_cuda_extension._nccl_all_gather( + input=input, + output=output, + comm=self._tensor_parallel_main_comm, + ) + + def async_mode(self, async_mode_type: AsyncMode) -> AsyncModeContext: + if async_mode_type == AsyncMode.OVERLAPPING: + return OverlapAsyncModeContext(self) + + raise NotImplementedError + + async def async_allreduce_in_tensor_parallel_group(self, tensor: torch.Tensor): + if self.tensor_parallel_size == 1: + return + + if self._async_mode is None: + return self.allreduce_in_tensor_parallel_group(tensor) + + return await self._async_mode.async_allreduce(tensor) + + @property + def pipeline_parallel_group(self): + return self._pipeline_parallel_group + + @property + def pipeline_parallel_size(self) -> int: + return self._pipeline_parallel_size + + @property + def pipeline_stage_id(self) -> int: + return self._pipeline_stage_id + + def recv_from_prev_stage( + self, recv_tensors: Union[torch.Tensor, Sequence[torch.Tensor]] + ): + """Receive tensor from previous stage in pipeline (forward receive).""" + assert ( + self.pipeline_parallel_group is not None + ), "pipeline_parallel_group is not initialized, pipeline_parallel_size need to be set greater than 1" + + recv_tensors = _check_send_recv_tensors(recv_tensors) + p2p_ops = [] + for recv_tensor in recv_tensors: + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_tensor, + self._pipeline_parallel_group_prev_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(recv_op) + + if len(p2p_ops) > 0: + reqs = torch.distributed.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + def send_to_next_stage( + self, send_tensors: Union[torch.Tensor, Sequence[torch.Tensor]] + ): + """Send tensor to next stage in pipeline (forward send).""" + assert ( + self.pipeline_parallel_group is not None + ), "pipeline_parallel_group is not initialized, pipeline_parallel_size need to be set greater than 1" + + send_tensors = _check_send_recv_tensors(send_tensors) + p2p_ops = [] + for tensor in send_tensors: + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + self._pipeline_parallel_group_next_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(send_op) + + if len(p2p_ops) > 0: + reqs = torch.distributed.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + def send_and_recv_between_neighborhoods( + self, + recv_tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + send_tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + ): + assert ( + self.pipeline_parallel_group is not None + ), "pipeline_parallel_group is not initialized, pipeline_parallel_size need to be set greater than 1" + + recv_tensors = _check_send_recv_tensors(recv_tensors) + send_tensors = _check_send_recv_tensors(send_tensors) + p2p_ops = [] + + for recv_tensor in recv_tensors: + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_tensor, + self._pipeline_parallel_group_prev_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(recv_op) + + for send_tensor in send_tensors: + send_op = torch.distributed.P2POp( + torch.distributed.isend, + send_tensor, + self._pipeline_parallel_group_next_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(send_op) + + if len(p2p_ops) > 0: + reqs = torch.distributed.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() \ No newline at end of file diff --git a/src/diffusers/parallel_context.py b/src/diffusers/parallel_context.py new file mode 100644 index 000000000000..22a959cee4b7 --- /dev/null +++ b/src/diffusers/parallel_context.py @@ -0,0 +1,124 @@ +import torch + +from abc import ABC, abstractmethod +from typing import Optional, Union, Sequence, Tuple, ContextManager +from contextlib import nullcontext +from enum import Enum + + + +class AsyncMode(Enum): + OVERLAPPING = 1 + + +class AsyncModeContext(ABC): + @abstractmethod + def __enter__(self): + pass + + @abstractmethod + def __exit__(self, exc_type, exc_value, traceback): + pass + + @abstractmethod + def gather(self, *tasks) -> Tuple: + pass + + +class ParallelContext(ABC): + @abstractmethod + def backend(self) -> str: + pass + + @abstractmethod + def device_index(self) -> int: + pass + + @abstractmethod + def torch_device(self) -> torch.device: + pass + + @abstractmethod + def send_recv_comm_device(self) -> Union[str, torch.device]: + pass + + @property + @abstractmethod + def size(self) -> int: + pass + + @abstractmethod + def allreduce(self, tensor: torch.Tensor, red_op: torch.distributed.ReduceOp): + pass + + @property + @abstractmethod + def tensor_parallel_size(self) -> int: + pass + + @property + @abstractmethod + def tensor_parallel_rank(self) -> int: + pass + + @abstractmethod + def broadcast_in_tensor_parallel_group(self, tensor: torch.Tensor): + pass + + @abstractmethod + def allreduce_in_tensor_parallel_group(self, tensor: torch.Tensor): + pass + + @abstractmethod + def allgather_in_tensor_parallel_group( + self, input: torch.Tensor, output: torch.Tensor + ): + pass + + @abstractmethod + def async_mode(self, async_mode: AsyncMode) -> AsyncModeContext: + pass + + @abstractmethod + async def async_allreduce_in_tensor_parallel_group(self, tensor: torch.Tensor): + pass + + @property + @abstractmethod + def pipeline_parallel_size(self) -> int: + pass + + @property + @abstractmethod + def pipeline_stage_id(self) -> int: + pass + + def in_pipeline_first_stage(self) -> bool: + return self.pipeline_stage_id == 0 + + def in_pipeline_last_stage(self) -> bool: + return self.pipeline_stage_id == self.pipeline_parallel_size - 1 + + @abstractmethod + def recv_from_prev_stage( + self, recv_tensors: Union[torch.Tensor, Sequence[torch.Tensor]] + ): + pass + + @abstractmethod + def send_to_next_stage( + self, send_tensors: Union[torch.Tensor, Sequence[torch.Tensor]] + ): + pass + + @abstractmethod + def send_and_recv_between_neighborhoods( + self, + recv_tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + send_tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + ): + pass + + def tensor_parallel_reduce_context(self) -> ContextManager[None]: + return nullcontext() + diff --git a/src/diffusers/parallel_state.py b/src/diffusers/parallel_state.py new file mode 100644 index 000000000000..591bad5bfc71 --- /dev/null +++ b/src/diffusers/parallel_state.py @@ -0,0 +1,209 @@ +from typing import List, Optional +import torch +import torch.distributed +from diffusers.group_coordinator import ( + GroupCoordinator, + PipelineParallelGroupCoordinator, + SequenceParallelGroupCoordinator, + RankGenerator, + generate_masked_orthogonal_rank_groups +) +import os +from yunchang import set_seq_parallel_pg +from yunchang.globals import PROCESS_GROUP + +_WORLD: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_PP: Optional[PipelineParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None + +def get_world_group() -> GroupCoordinator: + return _WORLD + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + return _SP + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + return get_sp_group().rank_in_group + + +def get_ulysses_parallel_world_size(): + return get_sp_group().ulysses_world_size + + +def get_ulysses_parallel_rank(): + return get_sp_group().ulysses_rank + + +def get_ring_parallel_world_size(): + return get_sp_group().ring_world_size + +def get_ring_parallel_rank(): + return get_sp_group().ring_rank + +# CFG +def get_cfg_group() -> GroupCoordinator: + return _CFG + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + return get_cfg_group().rank_in_group + + + +# PP +def get_pp_group() -> GroupCoordinator: + return _PP + + +def get_pipeline_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + return get_pp_group().world_size + + +def get_pipeline_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + return get_pp_group().rank_in_group + + +def is_pipeline_first_stage(): + """Return True if in the first pipeline model parallel stage, False otherwise.""" + return get_pipeline_parallel_rank() == 0 + + +def is_pipeline_last_stage(): + """Return True if in the last pipeline model parallel stage, False otherwise.""" + return get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1) + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: "nccl", + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + assert parallel_mode in [ + "data", + "pipeline", + "tensor", + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == "pipeline": + return PipelineParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + elif parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + # sequence_parallel_degree: int = 1, + ulysses_degree: int = 1, + ring_degree: int = 1, + # tensor_parallel_degree: int = 1, + pipeline_parallel_degree: int = 1, +) -> None: + sequence_parallel_degree = ulysses_degree * ring_degree + rank_generator: RankGenerator = RankGenerator( + # tensor_parallel_degree, + sequence_parallel_degree, + pipeline_parallel_degree, + classifier_free_guidance_degree, + # data_parallel_degree, + "sp-pp-cfg", + ) + global _CFG + global _PP + global _SP + backend = "nccl" + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + grouprankcfg = rank_generator.get_ranks("cfg") + grouprankpp = rank_generator.get_ranks("pp") + groupranksp = rank_generator.get_ranks("sp") + print(f"cfg group ranks is {grouprankcfg}, pp group ranks is {grouprankpp}, sp group rank is {groupranksp}") + + _PP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("pp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="pipeline", + ) + + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=get_world_group().world_size, + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + print(f"ulysses group is {PROCESS_GROUP.ULYSSES_PG}, ring group is {PROCESS_GROUP.RING_PG}") + print(f"ulysses world size is {_SP.ulysses_world_size}, ring group is {_SP.ring_world_size}") + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + local_rank= int(os.environ.get("LOCAL_RANK", "0")) + torch.distributed.init_process_group( + backend="nccl", + init_method="env://", + world_size=-1, + rank=-1, + ) + torch.cuda.set_device(local_rank) + global _WORLD + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend="nccl") \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 513f86441c3a..da7e6e5f3cfd 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -23,6 +23,9 @@ T5TokenizerFast, ) +from loguru import logger +import torch.distributed as dist + from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL @@ -40,6 +43,34 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import StableDiffusion3PipelineOutput +from diffusers.torch_parallel_context import TorchBasedParallelContext +from diffusers.postal_service import ( + Shipment, + PostalService, +) +from diffusers.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, + get_pipeline_parallel_rank, + get_pipeline_parallel_world_size, + get_pp_group, + get_ring_parallel_rank, + get_ring_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, + get_ulysses_parallel_rank, + get_ulysses_parallel_world_size, + get_world_group, + generate_masked_orthogonal_rank_groups, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from diffusers.runtime_state import( + get_runtime_state, +) + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -179,6 +210,21 @@ def __init__( tokenizer_3: T5TokenizerFast, ): super().__init__() + # ---------------ADD BELOW--------------- + + self.pp_degree = get_pp_group().world_size + self.local_rank = get_pp_group().local_rank + import os + self.parallel_ctx = TorchBasedParallelContext( + ranks=get_world_group().ranks, + pipeline_parallel_size=self.pp_degree, + tensor_parallel_size=get_cfg_group().world_size, + device_index=int(os.environ.get("LOCAL_RANK", "0")) + ) + self.post_office = PostalService(self.parallel_ctx) + self.shipment = Shipment(self.parallel_ctx.torch_device()) + + #----------------ADD ABOVE------------------ self.register_modules( vae=vae, @@ -668,6 +714,31 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt + + def _process_cfg_split_batch( + self, + concat_group_0_negative: torch.Tensor, + concat_group_0: torch.Tensor, + concat_group_1_negative: torch.Tensor, + concat_group_1: torch.Tensor, + ): + r""" + if not using cfg parallel, then the function return as previously commented version. + if use cfg parallel, this function assign rank 0 device with negative prompt (namely null prompty “”), + and assign rank1 device with positive prompt. + """ + if get_classifier_free_guidance_world_size() == 1: + concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0) + concat_group_1 = torch.cat([concat_group_1_negative, concat_group_1], dim=0) + elif get_classifier_free_guidance_rank() == 0: + concat_group_0 = concat_group_0_negative + concat_group_1 = concat_group_1_negative + elif get_classifier_free_guidance_rank() == 1: + concat_group_0 = concat_group_0 + concat_group_1 = concat_group_1 + else: + raise ValueError("Invalid classifier free guidance rank") + return concat_group_0, concat_group_1 @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -814,6 +885,8 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -875,18 +948,29 @@ def __call__( lora_scale=lora_scale, ) + # if self.do_classifier_free_guidance: + # if skip_guidance_layers is not None: + # original_prompt_embeds = prompt_embeds + # original_pooled_prompt_embeds = pooled_prompt_embeds + # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + # pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + # --------------- MODIFIED HERE -------------- if self.do_classifier_free_guidance: - if skip_guidance_layers is not None: - original_prompt_embeds = prompt_embeds - original_pooled_prompt_embeds = pooled_prompt_embeds - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + ( + prompt_embeds, + pooled_prompt_embeds, + ) = self._process_cfg_split_batch( + negative_prompt_embeds, + prompt_embeds, + negative_pooled_prompt_embeds, + pooled_prompt_embeds, + ) + # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -899,56 +983,127 @@ def __call__( generator, latents, ) + # torch.cuda.synchronize() + # 6. Denoising loop + num_pipeline_warmup_steps = 1 with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - should_skip_layers = ( - True - if i > num_inference_steps * skip_layer_guidance_start - and i < num_inference_steps * skip_layer_guidance_stop - else False - ) - if skip_guidance_layers is not None and should_skip_layers: - timestep = t.expand(latents.shape[0]) - latent_model_input = latents - noise_pred_skip_layers = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=original_prompt_embeds, - pooled_projections=original_pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - skip_layers=skip_guidance_layers, - )[0] - noise_pred = ( - noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale - ) + if ( + get_pipeline_parallel_world_size() > 1 and len(timesteps) > num_pipeline_warmup_steps + ): + latents = self._sync_pipeline( + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + timesteps=timesteps[:num_pipeline_warmup_steps], + num_warmup_steps=num_pipeline_warmup_steps, + progress_bar=progress_bar, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + # * pipefusion stage + latents = self._async_pipeline( + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + timesteps=timesteps[num_pipeline_warmup_steps:], + num_warmup_steps=num_warmup_steps, + progress_bar=progress_bar, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + num_pipeline_warmup_steps=num_pipeline_warmup_steps, + ) + else: + latents = self._sync_pipeline( + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + timesteps=timesteps, + num_warmup_steps=num_warmup_steps, + progress_bar=progress_bar, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + sync_only=True, + ) + + + # if output_type == "latent": + # image = latents + + # else: + # latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + # image = self.vae.decode(latents, return_dict=False)[0] + # image = self.image_processor.postprocess(image, output_type=output_type) + # ------------------MODIFIED HERE--------------- + if not output_type == "latent": + if is_pipeline_last_stage(): + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + if is_pipeline_last_stage(): + if output_type == "latent": + image = latents + else: + image = self.image_processor.postprocess(image, output_type=output_type) + + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + else: + return None - # compute the previous noisy sample x_t -> x_t-1 + def _sync_pipeline( + self, + latents: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + timesteps: List[int], + num_warmup_steps: int, + progress_bar, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + sync_only: bool = False, + ): + get_runtime_state().set_patch_mode(patch_mode=False) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + if is_pipeline_last_stage(): + last_timestep_latents = latents + + if get_pipeline_parallel_world_size() > 1: + if not is_pipeline_first_stage(): + package = self.post_office.recv_shipment() + latents = package.content["hidden_states"] + encoder_hidden_states = package.content["encoder_hidden_states"] + elif not i == 0: + latents = self.post_office.recv_shipment().content["hidden_states"] + # for the first timestep of the first stage, need not recv anything. + + + latents, encoder_hidden_states = self._backbone_forward( + latents=latents, + encoder_hidden_states=( + prompt_embeds + if is_pipeline_first_stage() + else encoder_hidden_states + ), + pooled_prompt_embeds=pooled_prompt_embeds, + t=t, + ) + + + # compute the previous noisy sample x_t -> x_t-1 + if is_pipeline_last_stage(): latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(latents, t, last_timestep_latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -968,26 +1123,179 @@ def __call__( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if get_pipeline_parallel_world_size() > 1: + if is_pipeline_last_stage(): + if i == self._num_timesteps-1: + pass + else: + self.shipment.update({"hidden_states": latents}) + self.post_office.send_shipment(self.shipment) + else: + self.shipment.update({"hidden_states": latents}) + self.shipment.update({"encoder_hidden_states": encoder_hidden_states}) + self.post_office.send_shipment(self.shipment) + self.shipment.clear() - if XLA_AVAILABLE: - xm.mark_step() + return latents + + def _init_async_pipeline( + self, + latents: torch.Tensor, + ): + # split the latents into patches. + get_runtime_state().set_patch_mode(patch_mode=True) + if is_pipeline_first_stage(): + latents = self.post_office.recv_shipment().content["hidden_states"] + patch_latents = list(latents.split(get_runtime_state().pp_patches_height, dim=2)) + elif is_pipeline_last_stage(): + patch_latents = list(latents.split(get_runtime_state().pp_patches_height, dim=2)) + else: + patch_latents = [None for _ in range(get_runtime_state().num_pipeline_patch)] - if output_type == "latent": - image = latents + return patch_latents - else: - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + def _async_pipeline( + self, + latents: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + timesteps: List[int], + num_warmup_steps: int, + progress_bar, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + num_pipeline_warmup_steps=1, + ): + patch_latents = self._init_async_pipeline(latents=latents) + num_pipeline_patch = get_pp_group().world_size + last_patch_latents = [None for _ in range(num_pipeline_patch)] # xinze: we need it for scheduler to do stepping. + for i, t in enumerate(timesteps): + for patch_idx in range(num_pipeline_patch): + if is_pipeline_last_stage(): + last_patch_latents[patch_idx] = patch_latents[patch_idx] + + + if is_pipeline_first_stage() and i == 0: + pass + else: + package = self.post_office.recv_shipment() + if not is_pipeline_first_stage() and patch_idx == 0: + last_encoder_hidden_states = package.content["encoder_hidden_states"] + patch_latents[patch_idx] = package.content["hidden_states"] + + + patch_latents[patch_idx], next_encoder_hidden_states = ( + self._backbone_forward( + latents=patch_latents[patch_idx], + encoder_hidden_states=( + prompt_embeds + if is_pipeline_first_stage() + else last_encoder_hidden_states + ), + pooled_prompt_embeds=pooled_prompt_embeds, + t=t, + ) + ) - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) + if is_pipeline_last_stage(): + latents_dtype = patch_latents[patch_idx].dtype + patch_latents[patch_idx] = self.scheduler.step(patch_latents[patch_idx], \ + t, last_patch_latents[patch_idx], return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) - # Offload all models - self.maybe_free_model_hooks() + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", + negative_pooled_prompt_embeds, + ) + + if not is_pipeline_last_stage() and patch_idx == 0: + self.shipment.update({"encoder_hidden_states": next_encoder_hidden_states}) + if not is_pipeline_last_stage() or i != len(timesteps) - 1: + self.shipment.update({"hidden_states": patch_latents[patch_idx]}) + if is_pipeline_last_stage() and i == len(timesteps) - 1: + pass + else: + self.post_office.send_shipment(self.shipment) + self.shipment.clear() + + get_runtime_state().next_patch() + + if i == len(timesteps) - 1 or ( + (i + num_pipeline_warmup_steps + 1) > num_warmup_steps + and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + + if XLA_AVAILABLE and is_pipeline_last_stage(): + xm.mark_step() + + latents = torch.cat(patch_latents, dim=2) + return latents - if not return_dict: - return (image,) + def _backbone_forward( + self, + latents: torch.Tensor, + encoder_hidden_states: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + t: Union[float, torch.Tensor], + ): + if is_pipeline_first_stage(): + latents = torch.cat( + [latents] * (2 // get_classifier_free_guidance_world_size()) + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + noise_pred, encoder_hidden_states = self.transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # classifier free guidance + if is_pipeline_last_stage(): + # xinze: pp last stage means a time step is about to end + if get_classifier_free_guidance_world_size() == 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + elif get_classifier_free_guidance_world_size() == 2: + noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather( + noise_pred, separate_tensors=True + ) + latents = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + latents = noise_pred - return StableDiffusion3PipelineOutput(images=image) + return latents, encoder_hidden_states \ No newline at end of file diff --git a/src/diffusers/postal_service.py b/src/diffusers/postal_service.py new file mode 100644 index 000000000000..eb8038f97874 --- /dev/null +++ b/src/diffusers/postal_service.py @@ -0,0 +1,524 @@ +import torch + +from typing import List, Union, Optional +from diffusers.parallel_context import ParallelContext +from loguru import logger + +support_type = [ + torch.float32, + torch.float64, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.bool, + torch.complex64, + torch.complex128, + torch.bfloat16, + torch.cdouble, + torch.quint8, + torch.qint8, + torch.qint32, +] + +type2int = {} +for idx, dtype in enumerate(support_type): + type2int[dtype] = idx + +int64_type = torch.int64 +int64_byte = torch.tensor([], dtype=int64_type).element_size() +byte_type = torch.uint8 +min_copy_size = 4096 + + +def is_cpu(device_type): + return str(device_type) == "cpu" + + +def string2list(s: str): + return list(s.encode()) + + +def byte_list2string(t: List[int]): + return bytes(t).decode() + + +def patch2long_byte(p): + return (-p) % int64_byte + + +def jump2patch(p): + return p + patch2long_byte(p) + + +def patching(p, tensor_list: List[torch.Tensor], device="cpu"): + patch_size = patch2long_byte(p) + if patch_size > 0: + tensor_list.append(torch.empty(patch_size, dtype=byte_type, device=device)) + + +def sum_tensor_len(tensor_list: List[torch.Tensor]): + sum_len = 0 + for tensor in tensor_list: + sum_len += len(tensor) + return sum_len + + +class Shipment: + def __init__(self, device: torch.device, buffer_device: Optional[torch.device]=None): + # Record the device for accelerator + # `device` is the device for accelerating computing and `buffer_device` is device for buffer to be sent/received + self._device = device + self._buffer_device = buffer_device or device + # Buffer and shipment on cpu should be tensor + # Use [] to speed up + # We do not use None since len(None) would cause error + # The structure of buffer, which is related to pack and unpack. + # buffer = torch.tensor([the number of package n| package 0 | package 1 | ... | package n]) + # Each package has specific format: + # package = [len(meta) | len(cpu tensors) | len(gpu tensors) | meta | cpu tensors | gpu tensors] + # meta is the information of tensor, which has structure: + # meta = [len(name) | name | dtype | ndim | shape | device | len(tensor)] + self._buffer = [] + # Items packed in the buffer + self._packed_items = {} + # The store position on the buffer + self._pack_item2position = {} + # Not packed items + self._warehouse = {} + # Items packed but replaced by new items + self._replaced_items = {} + + # Local variable for pack + self.meta_int_list: List[int] = [] + # Local variable for look through items, or say, unpack + # The part of buffer sent from gpu to cpu + self.shipment_on_cpu = [] + # The point on gpu shipment marking the starting position of shipment on cpu + self.cpu_shipment_start = 0 + # Convert the shipment on cpu from bytes to int64 + self.shipment_in_list = [] + # The pointer in shipment in list, which is used to mark the reading position of meta information + # The data is int64, adding 1 to pointer_in_list is equal to move 8 bytes with a cpu or gpu pointer (with type bytes) + self.pointer_in_list = 0 + + # _bundling_single_int, _bundling, _pack are 3 functions related to pack + def _bundling_single_int(self, single_int: int): + self.meta_int_list.append(single_int) + + def _bundling(self, list_1d: Union[List[int], torch.Size]): + self._bundling_single_int(len(list_1d)) + self.meta_int_list.extend(list_1d) + + def _pack(self): + if len(self._replaced_items) > 0: + for key, value in self._replaced_items.items(): + data_start_end = self._pack_item2position[key] + self._buffer[data_start_end[0] : data_start_end[1]] = ( + value.view(-1).view(byte_type).to(self._buffer_device) + ) + + self._replaced_items.clear() + + # No new items + if len(self._warehouse) == 0: + return + + # Update total pacakge number + if len(self._buffer) > 0: + self._buffer[0:int64_byte].view(int64_type)[0] += 1 + + # Convert self._warehouse to byte tensors + self.meta_int_list = [] + should_on_cpu2data_tensor_list = {True: [], False: []} + self._bundling_single_int(len(self._warehouse)) + new_item2position = {} + storage_start_position = 0 + # Go through all the cpu tensors before dealing with any gpu tensors + for should_on_cpu in [True, False]: + for key, value in self._warehouse.items(): + if is_cpu(value.device.type) == should_on_cpu: + # name + self._bundling(string2list(key)) + # data type + self._bundling_single_int(type2int[value.dtype]) + # shape + self._bundling(value.size()) + # device + self._bundling_single_int(int(should_on_cpu)) + # the length of value viewed as byte tensor + value_view_as_1d_byte_data = ( + value.contiguous().view(-1).view(byte_type) + ) + value_view_as_1d_byte_data_size = len(value_view_as_1d_byte_data) + self._bundling_single_int(value_view_as_1d_byte_data_size) + # data + should_on_cpu2data_tensor_list[should_on_cpu].append( + value_view_as_1d_byte_data + ) + # storage pointers of start and end + storage_end_position = ( + storage_start_position + value_view_as_1d_byte_data_size + ) + new_item2position[key] = [ + storage_start_position, + storage_end_position, + ] + storage_start_position = jump2patch(storage_end_position) + # An empty tensor patch, to make the offset divisible by int64_byte + patching( + value_view_as_1d_byte_data_size, + should_on_cpu2data_tensor_list[should_on_cpu], + value.device, + ) + + # Calculate tensor length + total_package_len = [ + len(self.meta_int_list) * int64_byte, + sum_tensor_len(should_on_cpu2data_tensor_list[True]), + sum_tensor_len(should_on_cpu2data_tensor_list[False]), + ] + self.meta_int_list = total_package_len + self.meta_int_list + # Concatenate cpu tensors first + if len(self._buffer) == 0: + self.meta_int_list.insert(0, 1) + cpu_data_tensor_list = should_on_cpu2data_tensor_list[True] + if len(cpu_data_tensor_list) == 0: + concat_cpu_tensor = torch.tensor( + self.meta_int_list, dtype=int64_type, device=self._buffer_device + ).view(byte_type) + # Place the meta information before cpu and gpu tensors, move the offset correspondingly + offset = len(concat_cpu_tensor) + else: + cpu_data_tensor_list.insert( + 0, torch.tensor(self.meta_int_list, dtype=int64_type).view(byte_type) + ) + offset = len(cpu_data_tensor_list[0]) + concat_cpu_tensor = torch.cat(cpu_data_tensor_list).to(self._buffer_device) + # Concatenate gpu tensors + gpu_data_tensor_list = should_on_cpu2data_tensor_list[False] + if is_cpu(self._buffer_device) and len(gpu_data_tensor_list) > 0: + gpu_data_tensor_list = [torch.cat(gpu_data_tensor_list).to(self._buffer_device)] + gpu_data_tensor_list.insert(0, concat_cpu_tensor) + if len(self._buffer) > 0: + gpu_data_tensor_list.insert(0, self._buffer) + # Place the last package before this package, move the offset correspondingly + offset += len(self._buffer) + self._buffer = torch.cat(gpu_data_tensor_list) + self.meta_int_list.clear() + # Move the items from warehouse to the shipment + self._packed_items.update(self._warehouse) + self._warehouse.clear() + # Maintain the packed item storage pointers + for key, pointers in new_item2position.items(): + self._pack_item2position[key] = [pointers[0] + offset, pointers[1] + offset] + + # _unbundling_single_int, _unbunding, _deliver_to_host, _unpack are 4 functions related to unpack + def _unbundling_single_int(self) -> int: + single_int = self.shipment_in_list[self.pointer_in_list] + self.pointer_in_list += 1 + return single_int + + def _unbundling(self): + tensor_len = self._unbundling_single_int() + p_start = self.pointer_in_list + # Not using p_end would cause error: + self.pointer_in_list += tensor_len + item = self.shipment_in_list[p_start : self.pointer_in_list] + return item + + def _deliver_to_host(self, size=min_copy_size): + if size < min_copy_size: + size = min_copy_size + self.cpu_shipment_start += self.pointer_in_list * int64_byte + p_end = min(self.cpu_shipment_start + size, len(self._buffer)) + self.shipment_on_cpu = self._buffer[self.cpu_shipment_start : p_end].to("cpu") + self.shipment_in_list = self.shipment_on_cpu.view(int64_type).tolist() + self.pointer_in_list = 0 + + def _look_through_items(self, do_something): + self.pointer_in_list = 0 + self.cpu_shipment_start = 0 + self._deliver_to_host() + # Total pacakge number + total_package_num = self._unbundling_single_int() + for package_id in range(total_package_num): + # Get length of meta, cpu tensors, gpu tensors + total_meta_len = self._unbundling_single_int() + total_cpu_tensor_len = self._unbundling_single_int() + total_gpu_tensor_len = self._unbundling_single_int() + # At least deliver this whole package to host + if ( + len(self.shipment_on_cpu) - self.pointer_in_list * int64_byte + < total_meta_len + total_cpu_tensor_len + ): + self._deliver_to_host(total_meta_len + total_cpu_tensor_len) + # Assign pointers, which is related to the position of shipment copy + # The pointer on shipment_on_cpu, marking the reading position of cpu tensors + p_cpu = self.pointer_in_list * int64_byte + total_meta_len + # The pointer on shipment on gpu, marking the reading position of gpu tensors + p_gpu = self.cpu_shipment_start + p_cpu + total_cpu_tensor_len + # The end of this package + # total_gpu_tensor_len % 8 == 0, we do not need to use jump2patch + p_end = p_gpu + total_gpu_tensor_len + # The end position of meta information, which is also the start position of the cpu tensors + p_meta_end = p_cpu + # look into the shipment + total_item_num = self._unbundling_single_int() + for item_id in range(total_item_num): + # name + name_list = self._unbundling() + name = byte_list2string(name_list) + # data type + dtype_id = self._unbundling_single_int() + value_dtype = support_type[dtype_id] + # shape + # It supports scalar. + # For example: torch.tensor([42]).view([]) == torch.tensor(42) + shape = self._unbundling() + # device + on_cpu = self._unbundling_single_int() + # data length + data_len = self._unbundling_single_int() + # p_gpu_end = p_gpu + data_len + # data + if on_cpu: + # on cpu + p_cpu_end = p_cpu + data_len + data = ( + self.shipment_on_cpu[p_cpu:p_cpu_end] + .view(value_dtype) + .view(shape) + ) + data_memory_start = self.cpu_shipment_start + p_cpu + data_memory_end = self.cpu_shipment_start + p_cpu_end + # Move the cpu pointer to the next cpu tensor + p_cpu = jump2patch(p_cpu_end) + else: + # on cuda + p_gpu_end = p_gpu + data_len + data = self._buffer[p_gpu:p_gpu_end].view(value_dtype).view(shape).to(self._device) + data_memory_start = p_gpu + data_memory_end = p_gpu_end + # Move the gpu pointer to the next gpu tensor + p_gpu = jump2patch(p_gpu_end) + # do something + # Return False means keep going through the rest of the shipment + # Return True means stopping looking for the item + if do_something(name, data, data_memory_start, data_memory_end): + return + + assert ( + self.pointer_in_list * int64_byte == p_meta_end + ), "expect self.pointer_in_list({}) * int64_byte({}) == p_meta_end({})".format( + self.pointer_in_list, int64_byte, p_meta_end + ) + assert p_gpu == p_end, "p_gpu ({}) must be equal to p_end ({})".format( + p_gpu, p_end + ) + # Move the meta pointer to the start of the next package + self.pointer_in_list = (p_end - self.cpu_shipment_start) // int64_byte + # Have not finish the package + # and still need to take 3 integers from the cpu shipment + if package_id < total_package_num - 1 and self.pointer_in_list + 3 > len( + self.shipment_in_list + ): + self._deliver_to_host() + + def _unpack(self): + def take_item_out(name, data, p_gpu, p_gpu_end): + self._packed_items[name] = data + self._pack_item2position[name] = [p_gpu, p_gpu_end] + # Return False means keep going through the rest of the shipment + return False + + self._look_through_items(take_item_out) + + # Make sure the item is packed before using this function + def _remove_packed_item(self, item_name): + self._warehouse.update(self._packed_items) + self._warehouse.update(self._replaced_items) + del self._warehouse[item_name] + self._buffer = [] + self._packed_items.clear() + self._replaced_items.clear() + self._pack_item2position.clear() + + @staticmethod + def from_buffer(buffer: torch.ByteTensor, device: torch.device=None): + device = device or buffer.device + shipment = Shipment(device, buffer_device=buffer.device) + shipment._buffer = buffer + shipment._unpack() + return shipment + + def remove(self, item_name): + if item_name in self._warehouse: + del self._warehouse[item_name] + else: + assert ( + item_name in self._packed_items or item_name in self._replaced_items + ), f"{item_name} not in shipment" + self._remove_packed_item(item_name) + + def update(self, new_items: dict): + if len(new_items) == 0: + return + for key, value in new_items.items(): + assert isinstance( + value, torch.Tensor + ), "parameter of {} must be a Tensor, but got a {}".format(key, type(value)) + assert ( + is_cpu(value.device.type) or value.device == self._device + ), f"This shipment only accept tensors on cpu or {self._device} but gets {value.device}" + + not_found_item_with_same_meta = True + # Packed or replaced + for package in [self._packed_items, self._replaced_items]: + if key in package: + old_value = package[key] + if ( + old_value.dtype == value.dtype + and old_value.size() == value.size() + and old_value.device == value.device + ): + # Replace the item no matter value changed or not + # Different item with the same dtype, shape, device + if package == self._replaced_items: + # Found in the replaced items, directly replace + package[key] = value + else: + # Found in the packed items, move and replace it + del package[key] + self._replaced_items[key] = value + + not_found_item_with_same_meta = False + else: + # Different item with same name, remove the old item first, add the new item later + self._remove_packed_item(key) + + break + + if not_found_item_with_same_meta: + # Add the new item + self._warehouse[key] = value + + @property + def buffer(self): + self._pack() + return self._buffer + + @property + def content(self): + return dict(self._packed_items, **self._warehouse, **self._replaced_items) + + def clear(self): + self._warehouse.clear() + self._buffer = [] + self._packed_items.clear() + self._pack_item2position.clear() + + def is_empty(self) -> bool: + return ( + len(self._warehouse) == 0 + and len(self._packed_items) == 0 + and len(self._replaced_items) == 0 + ) + + +class PostalService: + def __init__(self, parallel_ctx: ParallelContext): + self.parallel_ctx = parallel_ctx + self.send_recv_comm_device = self.parallel_ctx.send_recv_comm_device() + + def send_shipment(self, shipment: Shipment): + # Pack items before shipping + + # logger.info(f"{self.send_recv_comm_device=} is sending shape {len(shipment.buffer)}") + self.parallel_ctx.send_to_next_stage( + torch.tensor( + [len(shipment.buffer)], device=self.send_recv_comm_device + ) + ) + # logger.info(f"{len(shipment.buffer)=} is sending in {self.send_recv_comm_device}") + self.parallel_ctx.send_to_next_stage(shipment.buffer) + # logger.info(f"{self.send_recv_comm_device} has sended") + + def recv_shipment(self) -> Shipment: + + # logger.info(f"{self.send_recv_comm_device=} is receiving shape") + shipment_volume = torch.empty( + [1], + dtype=torch.int64, + device=self.send_recv_comm_device, + requires_grad=False, + ) + self.parallel_ctx.recv_from_prev_stage(shipment_volume) + # logger.info(f"{shipment_volume=} is receiving in {self.send_recv_comm_device=}") + + buffer = torch.empty( + shipment_volume.tolist(), + dtype=byte_type, + device=self.send_recv_comm_device, + requires_grad=False, + ) + self.parallel_ctx.recv_from_prev_stage(buffer) + #logger.info(f"Receive: {buffer.shape=} is buffer") + # logger.info(f"{self.send_recv_comm_device} has received!") + return Shipment.from_buffer(buffer, self.parallel_ctx.torch_device()) + + def exchange_shipment(self, shipment: Shipment) -> Shipment: + # Send and recv the length of shipment simultaneously + recv_shipment_volume = torch.empty( + [1], + dtype=torch.int64, + device=self.send_recv_comm_device, + requires_grad=False, + ) + send_shipment_volume = torch.tensor( + [len(shipment.buffer)], + dtype=torch.int64, + device=self.send_recv_comm_device, + requires_grad=False, + ) + self.parallel_ctx.send_and_recv_between_neighborhoods( + recv_shipment_volume, send_shipment_volume + ) + # Send and recv shipment simultaneously + recv_buffer = torch.empty( + recv_shipment_volume.tolist(), + dtype=byte_type, + device=self.send_recv_comm_device, + requires_grad=False, + ) + send_buffer = shipment.buffer + self.parallel_ctx.send_and_recv_between_neighborhoods(recv_buffer, send_buffer) + return Shipment.from_buffer(recv_buffer, self.parallel_ctx.torch_device()) + + def broadcast_shipment(self, shipment: Shipment) -> Shipment: + if self.parallel_ctx.tensor_parallel_size == 1: + return shipment + + # broadcast shipment volume + shipment_volume = torch.tensor(len(shipment.buffer)).to( + self.parallel_ctx.torch_device() + ) + self.parallel_ctx.broadcast_in_tensor_parallel_group(shipment_volume) + # broadcast shipment buffer + if self.parallel_ctx.tensor_parallel_rank == 0: + shipment_buffer = shipment.buffer + else: + shipment_buffer = torch.empty( + shipment_volume, + device=self.parallel_ctx.torch_device(), + dtype=byte_type, + requires_grad=False, + ) + self.parallel_ctx.broadcast_in_tensor_parallel_group(shipment_buffer) + + if self.parallel_ctx.tensor_parallel_rank == 0: + return shipment + else: + return Shipment.from_buffer(shipment_buffer) diff --git a/src/diffusers/runtime_state.py b/src/diffusers/runtime_state.py new file mode 100644 index 000000000000..279b2747986e --- /dev/null +++ b/src/diffusers/runtime_state.py @@ -0,0 +1,207 @@ +import torch.distributed +import random +from typing import List, Optional, Tuple +from abc import ABCMeta +from argparse import Namespace + +import numpy as np +import torch +# from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from diffusers.parallel_state import ( + get_pp_group, + get_sp_group, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) + +class RuntimeState(metaclass=ABCMeta): + config: Namespace + num_pipeline_patch: int + + def __init__(self, config): + self.config = config + self.num_pipeline_patch = config.pipefusion_parallel_degree + self.warmup_steps = 1 + +class DiTRuntimeState(RuntimeState): + patch_mode: bool + pipeline_patch_idx: int + vae_scale_factor: int + backbone_patch_size: int + pp_patches_height: Optional[List[int]] + pp_patches_start_idx_local: Optional[List[int]] + pp_patches_start_end_idx_global: Optional[List[List[int]]] + pp_patches_token_start_idx_local: Optional[List[int]] + pp_patches_token_start_end_idx_global: Optional[List[List[int]]] + pp_patches_token_num: Optional[List[int]] + max_condition_sequence_length: int + + # def __init__(self, config): + # super().__init__(config) + # self.patch_mode = False + # self.pipeline_patch_idx = 0 + # self._set_model_parameters( + # # vae_scale_factor=pipeline.vae_scale_factor, + # # backbone_patch_size=pipeline.transformer.config.patch_size, + # # backbone_in_channel=pipeline.transformer.config.in_channels, + # # backbone_inner_dim=pipeline.transformer.config.num_attention_heads + # # * pipeline.transformer.config.attention_head_dim, + + + # vae_scale_factor=8, + # backbone_patch_size=2, + # backbone_in_channel=16, + # backbone_inner_dim=24 * 64 + # ) + # self._calc_patches_metadata() + + def __init__(self, pipeline, config): + super().__init__(config) + self.patch_mode = False + self.pipeline_patch_idx = 0 + self._set_model_parameters( + vae_scale_factor=pipeline.vae_scale_factor, + backbone_patch_size=pipeline.transformer.config.patch_size, + backbone_in_channel=pipeline.transformer.config.in_channels, + backbone_inner_dim=pipeline.transformer.config.num_attention_heads + * pipeline.transformer.config.attention_head_dim, + + + # vae_scale_factor=8, + # backbone_patch_size=2, + # backbone_in_channel=16, + # backbone_inner_dim=24 * 64 + ) + self._calc_patches_metadata() + + def next_patch(self): + if self.patch_mode: + self.pipeline_patch_idx += 1 + if self.pipeline_patch_idx == self.num_pipeline_patch: + self.pipeline_patch_idx = 0 + else: + self.pipeline_patch_idx = 0 + + def set_patch_mode(self, patch_mode: bool): + self.patch_mode = patch_mode + self.pipeline_patch_idx = 0 + + def _set_model_parameters( + self, + vae_scale_factor:int, + backbone_patch_size:int, + backbone_inner_dim:int, + backbone_in_channel:int, + ): + self.vae_scale_factor = vae_scale_factor + self.backbone_patch_size = backbone_patch_size + self.backbone_inner_dim = backbone_inner_dim + self.backbone_in_channel = backbone_in_channel + + def _input_size_change( + self, + height: Optional[int] = None, + width: Optional[int] = None, + ): + self.config.height = height or self.config.height + self.config.width = width or self.config.width + self._calc_patches_metadata() + self._reset_recv_buffer() + + def _calc_patches_metadata(self): + num_sp_patches = get_sequence_parallel_world_size() + sp_patch_idx = get_sequence_parallel_rank() + patch_size = self.backbone_patch_size + vae_scale_factor = self.vae_scale_factor + latents_height = self.config.height // vae_scale_factor + latents_width = self.config.width // vae_scale_factor # xinze: for 1024 width and 8 vae factor, latents width is 128 + + pipeline_patches_height = ( + latents_height + self.num_pipeline_patch - 1 + ) // self.num_pipeline_patch + + num_pipeline_patch = ( + latents_height + pipeline_patches_height - 1 + ) // pipeline_patches_height + + pipeline_patches_height_list = [ + pipeline_patches_height for _ in range(num_pipeline_patch-1) + ] + + the_last_pp_patch_height = latents_height - pipeline_patches_height * ( + num_pipeline_patch - 1 + ) + pipeline_patches_height_list.append(the_last_pp_patch_height) + + flatten_patches_height = [ + pp_patch_height // num_sp_patches + for _ in range(num_sp_patches) + for pp_patch_height in pipeline_patches_height_list + ]# xinze: if using sp, then patch will be deeper patched. + flatten_patches_start_idx = [0] + [ + sum(flatten_patches_height[:i]) + for i in range(1, len(flatten_patches_height) + 1) + ]# xinze: for height 1024, pp=sp=2, it is [0, 32 ,64, 96, 128] + pp_sp_patches_height = [ + flatten_patches_height[ + pp_patch_idx * num_sp_patches : (pp_patch_idx + 1) * num_sp_patches + ] + for pp_patch_idx in range(num_pipeline_patch) + ]# xinze: for height 1024, pp=sp=2, it is [[32, 32], [32, 32]] + pp_sp_patches_start_idx = [ + flatten_patches_start_idx[ + pp_patch_idx * num_sp_patches : (pp_patch_idx + 1) * num_sp_patches + 1 + ] + for pp_patch_idx in range(num_pipeline_patch) + ]# xinze: for height 1024, pp=sp=2, it is [[0, 32, 64], [64, 96, 128]] + pp_patches_height = [ + sp_patches_height[sp_patch_idx] + for sp_patches_height in pp_sp_patches_height + ]# xinze: for height 1024, pp=sp=2, it is [32, 32] + pp_patches_start_idx_local = [0] + [ + sum(pp_patches_height[:i]) for i in range(1, len(pp_patches_height) + 1) + ]# xinze: for height 1024, pp=sp=2, it is [0, 32, 64] + pp_patches_start_end_idx_global = [ + sp_patches_start_idx[sp_patch_idx : sp_patch_idx + 2] + for sp_patches_start_idx in pp_sp_patches_start_idx + ]# xinze: for height 1024, pp=sp=2, it is [[0, 32], [64, 96]] for rank 0 and [[32, 64], [96, 128]] for rank 1 + + pp_patches_token_start_end_idx_global = [ + [ + (latents_width // patch_size) * (start_idx // patch_size), + (latents_width // patch_size) * (end_idx // patch_size), + ] + for start_idx, end_idx in pp_patches_start_end_idx_global + ]# xinze: it is [[0, 1024], [2048, 3072]] for rank 0 and [[1024, 2048], [3072, 4096]] for rank 1 + + pp_patches_token_num = [ + end - start for start, end in pp_patches_token_start_end_idx_global + ]# xinze: it is [1024, 1024] + pp_patches_token_start_idx_local = [ + sum(pp_patches_token_num[:i]) for i in range(len(pp_patches_token_num) + 1) + ]# xinze: it is [0, 1024, 2048] + self.num_pipeline_patch = num_pipeline_patch + self.pp_patches_height = pp_patches_height + self.pp_patches_start_idx_local = pp_patches_start_idx_local + self.pp_patches_start_end_idx_global = pp_patches_start_end_idx_global + self.pp_patches_token_start_idx_local = pp_patches_token_start_idx_local + self.pp_patches_token_start_end_idx_global = ( + pp_patches_token_start_end_idx_global + ) + + self.pp_patches_token_num = pp_patches_token_num + + + def _reset_recv_buffer(self): + get_pp_group().reset_buffer() + get_pp_group.set_config(self, dtype=torch.float16) + +def initialize_runtime_state(pipeline, engine_config): + global _RUNTIME + + _RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config) + +def get_runtime_state(): + assert _RUNTIME is not None, "Runtime state has not been initialized." + return _RUNTIME \ No newline at end of file diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 91264e805a0f..c7f33deddeba 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -15,7 +15,7 @@ import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union - +from diffusers.runtime_state import get_runtime_state import numpy as np import torch @@ -330,7 +330,14 @@ def step( prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one - self._step_index += 1 + # self._step_index += 1 + if ( + not get_runtime_state().patch_mode + or get_runtime_state().pipeline_patch_idx + == get_runtime_state().num_pipeline_patch - 1 + ): + self._step_index += 1 + if not return_dict: return (prev_sample,) diff --git a/src/diffusers/torch_parallel_context.py b/src/diffusers/torch_parallel_context.py new file mode 100644 index 000000000000..27dd95b9de88 --- /dev/null +++ b/src/diffusers/torch_parallel_context.py @@ -0,0 +1,354 @@ +from typing import Optional, Union, Sequence, ContextManager +from datetime import timedelta +from contextlib import nullcontext + +import torch +import os + +from diffusers.parallel_context import ( + AsyncMode, + AsyncModeContext, + ParallelContext, +) +from loguru import logger +# types +Shape = Sequence[int] +Stream = torch.cuda.Stream +ProcessGroup = torch.distributed.ProcessGroup + + +def _check_send_recv_tensors(tensors: Union[torch.Tensor, Sequence[torch.Tensor]]): + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + else: + if not isinstance(tensors, (list, tuple)): + raise ValueError( + f"send_tensors must be a list of tensors, but got type {type(tensors)}" + ) + + for i, tensor in enumerate(tensors): + if not isinstance(tensor, torch.Tensor): + raise ValueError( + "send_tensors must be a list of tensors, " + f"but got {i} th element type {type(tensor)}" + ) + + return tensors + + +class TorchBasedParallelContext(ParallelContext): + def __init__( + self, + *, + ranks: Sequence[int] = None, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + device_index: int = 0, + backend: str = "nccl", + ): + self._init_torch_distributed() + self._backend = backend + + self._ranks = ranks + self._tensor_parallel_size = tensor_parallel_size + self._pipeline_parallel_size = pipeline_parallel_size + self._device_index = device_index + self._torch_device = torch.device(f"cuda:{self._device_index}") + self._parallel_size = tensor_parallel_size * pipeline_parallel_size + + if not isinstance(ranks, (list, tuple)) or not all( + isinstance(rank, int) for rank in ranks + ): + raise ValueError(f"`ranks` should be a list of int") + if len(ranks) != self._parallel_size: + raise ValueError(f"`ranks` should has the same length as parallel_ctx.size") + self._ranks = ranks + + if self._parallel_size > 1: + self._main_group = torch.distributed.new_group( + self._ranks, backend=self._backend + ) + else: + self._main_group = None + + # for grouping + self._process_mesh = torch.tensor(self._ranks) + + self._rank = ( + torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + ) + + # TODO: remove when tests don't rely on set_device in parallel_ctx + if self._rank in self._ranks: + torch.cuda.set_device(self._device_index) + + # tensor parallel property + self._tensor_parallel_group: Optional[ProcessGroup] = None + self._tensor_parallel_group_root: int = 0 + self._tensor_parallel_main_comm: Optional[Communicator] = None + self._tensor_parallel_comm_stream: Optional[Stream] = None + # async mode + self._async_mode: Optional[AsyncModeContext] = None + + # pipeline parallel property + self._pipeline_parallel_group: Optional[ProcessGroup] = None + self._pipeline_stage_id: int = 0 + self._pipeline_parallel_group_prev_rank = 0 + self._pipeline_parallel_group_next_rank = 0 + + with torch.cuda.device(self._device_index): + self._init_tensor_parallel_group() + self._init_pipeline_parallel_group() + + + def _init_torch_distributed(self, timeout=timedelta(hours=24)): + assert torch.cuda.is_available() + assert torch.distributed.is_available() + + env_rank = int(os.getenv("RANK", "0")) + env_world_size = int(os.getenv("WORLD_SIZE", "1")) + if env_world_size == 1: + return + + if torch.distributed.is_initialized(): + assert env_world_size == torch.distributed.get_world_size() + assert env_rank == torch.distributed.get_rank() + return + + backend = "nccl" + options = torch.distributed.ProcessGroupNCCL.Options() + options.is_high_priority_stream = True + + torch.distributed.init_process_group( + backend=backend, + world_size=env_world_size, + rank=env_rank, + timeout=timeout, + pg_options=options, + ) + + def backend(self) -> str: + return self._backend + + def send_recv_comm_device(self) -> Union[str, torch.device]: + # communication device for send/recv + return self._torch_device + + def _init_pipeline_parallel_group(self): + if self.size > 1: + self._pipeline_parallel_group = self._new_group( + self._pipeline_parallel_size + ) + if self._pipeline_parallel_group is not None: + # set pipeline property + group_ranks = torch.distributed.get_process_group_ranks( + self._pipeline_parallel_group + ) + self._pipeline_stage_id = group_ranks.index(self._rank) + self._pipeline_parallel_group_prev_rank = group_ranks[ + (self._pipeline_stage_id - 1 + self._pipeline_parallel_size) % self._pipeline_parallel_size + ] + self._pipeline_parallel_group_next_rank = group_ranks[ + (self._pipeline_stage_id + 1) % self._pipeline_parallel_size + ] + + def _init_tensor_parallel_group(self): + if self.size > 1: + self._tensor_parallel_group = self._new_group(self._tensor_parallel_size) + if self._tensor_parallel_group is not None: + group_ranks = torch.distributed.get_process_group_ranks( + self._tensor_parallel_group + ) + self._tensor_parallel_group_root = group_ranks[0] + + def _new_group( + self, parallel_size: int + ) -> Optional[torch.distributed.ProcessGroup]: + remain_num_groups = self._process_mesh.size(-1) + if remain_num_groups % parallel_size != 0: + raise ValueError( + f"The process mesh {self._process_mesh} cannot be divided further by parallel_size {parallel_size}" + ) + + num_groups = remain_num_groups // parallel_size + self._process_mesh = self._process_mesh.view( + *self._process_mesh.shape[:-1], parallel_size, num_groups + ) + + ret_group = None + global_rank = torch.distributed.get_rank() + # (past_num_groups, parallel_size, new_num_groups) + process_mesh = self._process_mesh.view(-1, parallel_size, num_groups) + for group_i in range(process_mesh.size(0)): + for group_j in range(process_mesh.size(-1)): + global_ranks = process_mesh[group_i, :, group_j].tolist() + group = torch.distributed.new_group(global_ranks, backend=self._backend) + if global_rank in global_ranks: + ret_group = group + + return ret_group + + def device_index(self) -> int: + return self._device_index + + def torch_device(self) -> torch.device: + return self._torch_device + + def allreduce(self, tensor: torch.Tensor, red_op: torch.distributed.ReduceOp): + if self.size > 1: + torch.distributed.all_reduce(tensor, op=red_op, group=self._main_group) + + @property + def tensor_parallel_group(self): + return self._tensor_parallel_group + + @property + def size(self) -> int: + return self._parallel_size + + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + + @property + def tensor_parallel_rank(self) -> int: + if self._tensor_parallel_group is None: + return 0 + return self._tensor_parallel_group.rank() + + def broadcast_in_tensor_parallel_group(self, tensor: torch.Tensor): + if self.tensor_parallel_size > 1: + torch.distributed.broadcast( + tensor, + src=self._tensor_parallel_group_root, + group=self._tensor_parallel_group, + ) + + def allreduce_in_tensor_parallel_group(self, tensor: torch.Tensor): + if self.tensor_parallel_size > 1: + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.SUM, + group=self._tensor_parallel_group, + ) + + def allgather_in_tensor_parallel_group( + self, input: torch.Tensor, output: torch.Tensor + ): + if self.tensor_parallel_size > 1: + output1d = output.view(-1) + local_elems = output1d.size(0) // self._tensor_parallel_size + tensor_list = [ + output1d[i * local_elems : (i + 1) * local_elems] + for i in range(self._tensor_parallel_size) + ] + torch.distributed.all_gather( + tensor_list, + input.view(-1), + group=self._tensor_parallel_group, + ) + + @property + def pipeline_parallel_group(self): + return self._pipeline_parallel_group + + @property + def pipeline_parallel_size(self) -> int: + return self._pipeline_parallel_size + + @property + def pipeline_stage_id(self) -> int: + return self._pipeline_stage_id + + def recv_from_prev_stage( + self, recv_tensors: Union[torch.Tensor, Sequence[torch.Tensor]] + ): + """Receive tensor from previous stage in pipeline (forward receive).""" + assert ( + self.pipeline_parallel_group is not None + ), "pipeline_parallel_group is not initialized, pipeline_parallel_size need to be set greater than 1" + + recv_tensors = _check_send_recv_tensors(recv_tensors) + p2p_ops = [] + for recv_tensor in recv_tensors: + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_tensor, + self._pipeline_parallel_group_prev_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(recv_op) + # logger.info(f"the current divice is {self._device_index=}, recving from {self._pipeline_parallel_group_prev_rank=}, in group {self.pipeline_parallel_group}") + if len(p2p_ops) > 0: + reqs = torch.distributed.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + def send_to_next_stage( + self, send_tensors: Union[torch.Tensor, Sequence[torch.Tensor]] + ): + """Send tensor to next stage in pipeline (forward send).""" + assert ( + self.pipeline_parallel_group is not None + ), "pipeline_parallel_group is not initialized, pipeline_parallel_size need to be set greater than 1" + + send_tensors = _check_send_recv_tensors(send_tensors) + p2p_ops = [] + for tensor in send_tensors: + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + self._pipeline_parallel_group_next_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(send_op) + # logger.info(f"the current divice is {self._device_index=}, sending to {self._pipeline_parallel_group_next_rank=}, in group {self.pipeline_parallel_group=}") + if len(p2p_ops) > 0: + reqs = torch.distributed.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + def send_and_recv_between_neighborhoods( + self, + recv_tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + send_tensors: Union[torch.Tensor, Sequence[torch.Tensor]], + ): + assert ( + self.pipeline_parallel_group is not None + ), "pipeline_parallel_group is not initialized, pipeline_parallel_size need to be set greater than 1" + + recv_tensors = _check_send_recv_tensors(recv_tensors) + send_tensors = _check_send_recv_tensors(send_tensors) + p2p_ops = [] + + for recv_tensor in recv_tensors: + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_tensor, + self._pipeline_parallel_group_prev_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(recv_op) + + for send_tensor in send_tensors: + send_op = torch.distributed.P2POp( + torch.distributed.isend, + send_tensor, + self._pipeline_parallel_group_next_rank, + group=self.pipeline_parallel_group, + ) + p2p_ops.append(send_op) + + if len(p2p_ops) > 0: + reqs = torch.distributed.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + def tensor_parallel_reduce_context(self) -> ContextManager[None]: + return nullcontext() + + def async_mode(self, async_mode_type: AsyncMode) -> AsyncModeContext: + raise NotImplementedError + + async def async_allreduce_in_tensor_parallel_group(self, tensor: torch.Tensor): + raise NotImplementedError diff --git a/src/examples/run.sh b/src/examples/run.sh new file mode 100644 index 000000000000..c9576cad0ee5 --- /dev/null +++ b/src/examples/run.sh @@ -0,0 +1,68 @@ +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=1,4,6,7 +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +# # Select the model type +export MODEL_TYPE="Sd3" +# # Configuration for different model types +# # script, model_id, inference_step +declare -A MODEL_CONFIGS=( + ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20" + ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20" + ["Sd3"]="sd3_example.py /maasjfs/hf_models/stable-diffusion-3.5-fp8/stable-diffusion-3-medium-diffusers 20" + ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28" + ["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50" +) + +if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then + IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}" + export SCRIPT MODEL_ID INFERENCE_STEP +else + echo "Invalid MODEL_TYPE: $MODEL_TYPE" + exit 1 +fi + +mkdir -p ./results + +# task args +TASK_ARGS="--height 1024 --width 1024" + + +# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch) +N_GPUS=4 +PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1" + +CFG_ARGS="--use_cfg_parallel" + +# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance. +# xinze: patch number is not necessarily equal to pp degree. +# PIPEFUSION_ARGS="--num_pipeline_patch 8 " + +# For high-resolution images, we use the latent output type to avoid runing the vae module. Used for measuring speed. +# OUTPUT_ARGS="--output_type latent" + +# PARALLLEL_VAE="--use_parallel_vae" + +# Another compile option is `--use_onediff` which will use onediff's compiler. +# COMPILE_FLAG="--use_torch_compile" + +# $CFG_ARGS \" + +# export CUDA_VISIBLE_DEVICES=4,5,6,7 + +torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--prompt "a female character looks like Asuka Langley with long, red, flowing hair that appears to be made of ethereal, swirling patterns resembling the Northern Lights or Aurora Borealis. The background is dominated by deep blues and purples, creating a mysterious and dramatic atmosphere. The character's face is serene, with pale skin and striking features. She wears a dark-colored outfit with subtle patterns. The overall style of the artwork is reminiscent of fantasy or supernatural genres." \ +$PARALLLEL_VAE \ +$CFG_ARGS + +# --warmup_steps 1 \ + +# $COMPILE_FLAG diff --git a/src/examples/sd3_example.py b/src/examples/sd3_example.py new file mode 100644 index 000000000000..ff34ca86db8d --- /dev/null +++ b/src/examples/sd3_example.py @@ -0,0 +1,271 @@ +import time +import os +import torch +import torch.distributed as dist +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline +import argparse +from diffusers.group_coordinator import ( + GroupCoordinator, + RankGenerator, + SequenceParallelGroupCoordinator, + PipelineParallelGroupCoordinator, +) +from diffusers.runtime_state import initialize_runtime_state, get_runtime_state +from typing import Any, Dict, List, Optional, Tuple, Union +from torch import nn +from diffusers.models.embeddings import PatchEmbed +from diffusers.conv import CustomConv2d +from diffusers.embedding import CustomPatchEmbed +from diffusers.parallel_state import ( + init_model_parallel_group, + init_world_group, + initialize_model_parallel, + init_distributed_environment, + get_world_group, + get_pipeline_parallel_rank, + get_pipeline_parallel_world_size, + is_pipeline_last_stage, +) +import torch.nn as nn + + + + +def convert_transformer( + transformer: nn.Module, + blocks_name: List[str] = ['transformer_blocks'], + ) -> nn.Module: + pp_rank = get_pipeline_parallel_rank() + pp_world_size = get_pipeline_parallel_world_size() + blocks_list = { + block_name: getattr(transformer, block_name) for block_name in blocks_name + } + num_blocks_list = [len(blocks) for blocks in blocks_list.values()] + blocks_idx = { + name: [sum(num_blocks_list[:i]), sum(num_blocks_list[: i + 1])] + for i, name in enumerate(blocks_name) + } + + num_blocks_per_stage = ( + sum(num_blocks_list) + pp_world_size - 1 + ) // pp_world_size + stage_block_start_idx = pp_rank * num_blocks_per_stage + stage_block_end_idx = min( + (pp_rank + 1) * num_blocks_per_stage, + sum(num_blocks_list), + ) + for name, [blocks_start, blocks_end] in zip( + blocks_idx.keys(), blocks_idx.values() + ): + if ( + blocks_end <= stage_block_start_idx + or stage_block_end_idx <= blocks_start + ): + setattr(transformer, name, nn.ModuleList([])) + elif stage_block_start_idx <= blocks_start: + if blocks_end <= stage_block_end_idx: + pass + else: + setattr( + transformer, + name, + blocks_list[name][: -(blocks_end - stage_block_end_idx)], + ) + elif blocks_start < stage_block_start_idx: + if blocks_end <= stage_block_end_idx: + setattr( + transformer, + name, + blocks_list[name][stage_block_start_idx - blocks_start :], + ) + else: # blocks_end > stage_layer_end_idx + setattr( + transformer, + name, + blocks_list[name][ + stage_block_start_idx + - blocks_start : stage_block_end_idx + - blocks_end + ], + ) + + return transformer + +def _change_layer(submodule): + if isinstance(submodule, nn.Conv2d): + return CustomConv2d(conv2d=submodule) + elif isinstance(submodule, PatchEmbed): + return CustomPatchEmbed(patch_embedding=submodule) + else: + return submodule + + +def change_conv_embed(model: nn.Module, submodule_classes_to_change=[nn.Conv2d, PatchEmbed]): + if model is None: + return None + for name, module in model.named_children(): + need_change = False + for class_to_change in submodule_classes_to_change: + if isinstance(module, class_to_change): + need_change = True + break + if need_change: + new_layer = _change_layer(module) + setattr(model, name, new_layer) + + for subname, submodule in module.named_children(): + need_change = False + for class_to_change in submodule_classes_to_change: + if isinstance(submodule, class_to_change): + need_change = True + break + if need_change: + new_layer = _change_layer(submodule) + setattr(module, subname, new_layer) + + + return model + + + + +def main(): + parser = argparse.ArgumentParser(description="parallel diffuser arguments") + model_group = parser.add_argument_group("Model Options") + model_group.add_argument( + "--model", + type=str, + default="/maasjfs/hf_models/stable-diffusion-3.5-fp8/stable-diffusion-3-medium-diffusers", + help="Name or path of the huggingface model to use." + ) + model_group.add_argument( + "--ulysses_degree", + type=int, + default=None, + help="Ulysses SP degree. Used in attention layer" + ) + model_group.add_argument( + "--ring_degree", + type=int, + default=None, + help="Ring SP degree. Used in attention layer" + ) + model_group.add_argument( + "--pipefusion_parallel_degree", + type=int, + default=1, + help="Pipefusion parallel degree." + ) + model_group.add_argument( + "--prompt", + type=str, + nargs="*", + default="", + help="prompt." + ) + model_group.add_argument( + "--use_parallel_vae", + action="store_true", + help="use distvae to parallel vae" + ) + model_group.add_argument( + "--num_inference_steps", + type=int, + default=20, + help="number of inference steps." + ) + model_group.add_argument( + "--use_cfg_parallel", + action="store_true", + help="Use split batch in classifier_free_guidance. cfg_degree will be 2 if set", + ) + model_group.add_argument( + "--height", + type=int, + default=1024, + ) + model_group.add_argument( + "--width", + type=int, + default=1024, + ) + args = parser.parse_args() + use_parallel_vae = args.use_parallel_vae + ulysses_degree = args.ulysses_degree + ring_degree = args.ring_degree + PP_degree = args.pipefusion_parallel_degree + if args.use_cfg_parallel: + cfg_degree = 2 + else: + cfg_degree = 1 + + init_distributed_environment() + + + + + initialize_model_parallel( + classifier_free_guidance_degree=cfg_degree, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + pipeline_parallel_degree=PP_degree, + ) + + + + local_rank= int(os.environ.get("LOCAL_RANK", "0")) + pipe = StableDiffusion3Pipeline.from_pretrained( + pretrained_model_name_or_path=args.model, + torch_dtype=torch.float16, + use_parallel_vae=use_parallel_vae, + PP_degree=PP_degree, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + ).to(f"cuda:{local_rank}") + initialize_runtime_state(pipeline=pipe, engine_config=args) + + + + if pipe.transformer is not None and PP_degree > 1: + pipe.transformer = convert_transformer(pipe.transformer) + pipe.transformer = change_conv_embed(model=pipe.transformer) + + args.num_inference_steps = 20 + output = pipe( + height=1024, + width=1024, + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed(42), + ) + + start_time = time.time() + for i in range(3): + output = pipe( + height=1024, + width=1024, + prompt=args.prompt, + num_inference_steps=args.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed(42), + ) + end_time = time.time() + parallel_info = ( + f"ulysses{args.ulysses_degree}_ring{args.ring_degree}_" + f"pp{args.pipefusion_parallel_degree}" + ) + if is_pipeline_last_stage(): + image=output.images[0] + if not os.path.exists("results"): + os.mkdir("results") + image.save(f"./results/SD3_result_{parallel_info}_rank{local_rank}.png") + print(f"image saved to ./results/SD3_result_{parallel_info}_rank{local_rank}.png") + if get_world_group().rank == get_world_group().world_size - 1: + print(f"used time: {(end_time-start_time)/3:.2f} seconds") + dist.barrier() + dist.destroy_process_group() + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/results/SD3_result_ulysses1_ring1_pp1_rank0.png b/src/results/SD3_result_ulysses1_ring1_pp1_rank0.png new file mode 100644 index 000000000000..f7ce431b8b8a Binary files /dev/null and b/src/results/SD3_result_ulysses1_ring1_pp1_rank0.png differ diff --git a/src/results/SD3_result_ulysses1_ring1_pp1_rank1.png b/src/results/SD3_result_ulysses1_ring1_pp1_rank1.png new file mode 100644 index 000000000000..d1792ed49caa Binary files /dev/null and b/src/results/SD3_result_ulysses1_ring1_pp1_rank1.png differ diff --git a/src/results/SD3_result_ulysses1_ring1_pp2_rank1.png b/src/results/SD3_result_ulysses1_ring1_pp2_rank1.png new file mode 100644 index 000000000000..a7485ee10043 Binary files /dev/null and b/src/results/SD3_result_ulysses1_ring1_pp2_rank1.png differ diff --git a/src/results/SD3_result_ulysses1_ring1_pp2_rank3.png b/src/results/SD3_result_ulysses1_ring1_pp2_rank3.png new file mode 100644 index 000000000000..a7485ee10043 Binary files /dev/null and b/src/results/SD3_result_ulysses1_ring1_pp2_rank3.png differ diff --git a/src/results/SD3_result_ulysses1_ring1_pp4_rank3.png b/src/results/SD3_result_ulysses1_ring1_pp4_rank3.png new file mode 100644 index 000000000000..39772f68a327 Binary files /dev/null and b/src/results/SD3_result_ulysses1_ring1_pp4_rank3.png differ