diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index fff357b759..26ea154321 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -1,20 +1,22 @@ # Streams -Streams allow concurrent execution of GPU operations. By default, all Quadrants kernels launch on the default stream, which serializes everything. By creating explicit streams, you can run independent kernels concurrently and control synchronization with events. +Streams allow concurrent execution of GPU operations. By default, all Quadrants kernels launch on the default stream, which serializes everything. With streams, you can run multiple top-level for loops in parallel. ## Supported platforms -| Backend | Streams | Events | Notes | -|---------|---------|--------|-------| -| CUDA | Yes | Yes | Full concurrent execution | -| AMDGPU | Yes | Yes | Full concurrent execution (requires ROCm >= 5.4) | -| CPU | No-op | No-op | `qd_stream` is silently ignored, kernels run serially | -| Metal | No-op | No-op | `qd_stream` is silently ignored, kernels run serially | -| Vulkan | No-op | No-op | `qd_stream` is silently ignored, kernels run serially | +| Backend | Supported | +|---------|-----------| +| CUDA | Yes | +| AMDGPU | Yes | +| CPU | No-op | +| Metal | No-op | +| Vulkan | No-op | -On backends without native stream support, `create_stream()` and `create_event()` return objects with handle `0`. All stream/event operations become no-ops and kernels run serially. Code written with streams is portable across all backends in the sense that it will run without modifications, but serially. +On backends without native stream support, stream operations are no-ops and for loops run serially. Code using streams is portable across all backends — it will run without modifications, but serially. -## Creating and using streams +## Stream parallelism + +Inside a `@qd.kernel`, each `with qd.stream_parallel():` block runs on its own GPU stream. ```python import quadrants as qd @@ -24,17 +26,40 @@ qd.init(arch=qd.cuda) N = 1024 a = qd.field(qd.f32, shape=(N,)) b = qd.field(qd.f32, shape=(N,)) +c = qd.field(qd.f32, shape=(N,)) @qd.kernel -def fill_a(): - for i in range(N): - a[i] = 1.0 +def compute_ab(): + with qd.stream_parallel(): + for i in range(N): + a[i] = compute_a(i) + with qd.stream_parallel(): + for j in range(N): + b[j] = compute_b(j) @qd.kernel -def fill_b(): +def combine(): for i in range(N): - b[i] = 2.0 + c[i] = a[i] + b[i] + +compute_ab() # the two stream_parallel blocks run concurrently +combine() # runs after compute_ab() returns — a[] and b[] are ready +``` + +Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. + +### Restrictions +- All top-level statements in a kernel must be either all `stream_parallel` blocks or all regular statements. Mixing the two at the top level is a compile-time error. +- Nesting `stream_parallel` blocks is not supported. + +## Explicit streams + +For cases that require manual control — such as launching separate kernels on different streams or interoperating with PyTorch — you can create and manage streams directly. + +### Creating and using streams + +```python s1 = qd.create_stream() s2 = qd.create_stream() @@ -50,7 +75,7 @@ s2.destroy() Pass `qd_stream=` to any kernel call to launch it on that stream. Kernels on different streams may execute concurrently. Call `synchronize()` to block until all work on a stream completes. -## Events +### Events Events let you express dependencies between streams without full synchronization. @@ -84,7 +109,7 @@ s2.destroy() `e.record(stream)` captures the point in `stream`'s execution. `e.wait(qd_stream=stream)` makes `stream` wait until the recorded point is reached. If `qd_stream` is omitted, the default stream waits. -## Context managers +### Context managers Streams and events support `with` blocks for automatic cleanup: @@ -95,11 +120,11 @@ with qd.create_stream() as s: # s.destroy() called automatically ``` -## PyTorch interop (CUDA) +### PyTorch interop (CUDA) When mixing Quadrants kernels with PyTorch operations on CUDA, both frameworks must use the same stream to avoid race conditions. Without explicit stream management, Quadrants and PyTorch may launch work on different streams with no ordering guarantees, leading to intermittent data corruption. -### Running Quadrants kernels on PyTorch's stream +#### Running Quadrants kernels on PyTorch's stream ```python import torch @@ -115,7 +140,7 @@ apply_actions_kernel(qd_stream=stream) Wrap PyTorch's raw `CUstream` pointer in a Quadrants `Stream` object. Do **not** call `destroy()` on this wrapper — PyTorch owns the underlying stream. -### Running PyTorch operations on a Quadrants stream +#### Running PyTorch operations on a Quadrants stream ```python qd_stream = qd.create_stream() @@ -136,4 +161,4 @@ qd_stream.destroy() - **Not compatible with graphs.** Do not pass `qd_stream` to a kernel decorated with `graph=True`. - **Not compatible with autodiff.** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context. - **`qd.sync()` only waits on the default stream.** It does not drain explicit streams. Call `stream.synchronize()` on each stream you need to wait for. -- **No automatic synchronization.** You are responsible for inserting events or `synchronize()` calls when one stream's output is another stream's input. +- **No automatic synchronization with explicit streams.** When using explicit streams, you are responsible for inserting events or `synchronize()` calls when one stream's output is another stream's input. `stream_parallel` handles this automatically. diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 8694a4e94e..263a4a11a3 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -119,7 +119,11 @@ def build_AnnAssign(ctx: ASTTransformerFuncContext, node: ast.AnnAssign): @staticmethod def build_assign_annotated( - ctx: ASTTransformerFuncContext, target: ast.Name, value, is_static_assign: bool, annotation: Type + ctx: ASTTransformerFuncContext, + target: ast.Name, + value, + is_static_assign: bool, + annotation: Type, ): """Build an annotated assignment like this: target: annotation = value. @@ -165,7 +169,10 @@ def build_Assign(ctx: ASTTransformerFuncContext, node: ast.Assign) -> None: @staticmethod def build_assign_unpack( - ctx: ASTTransformerFuncContext, node_target: list | ast.Tuple, values, is_static_assign: bool + ctx: ASTTransformerFuncContext, + node_target: list | ast.Tuple, + values, + is_static_assign: bool, ): """Build the unpack assignments like this: (target1, target2) = (value1, value2). The function should be called only if the node target is a tuple. @@ -591,7 +598,8 @@ def build_Return(ctx: ASTTransformerFuncContext, node: ast.Return) -> None: else: raise QuadrantsSyntaxError("The return type is not supported now!") ctx.ast_builder.create_kernel_exprgroup_return( - expr.make_expr_group(return_exprs), _qd_core.DebugInfo(ctx.get_pos_info(node)) + expr.make_expr_group(return_exprs), + _qd_core.DebugInfo(ctx.get_pos_info(node)), ) else: ctx.return_data = node.value.ptr @@ -1520,6 +1528,24 @@ def build_Continue(ctx: ASTTransformerFuncContext, node: ast.Continue) -> None: ctx.ast_builder.insert_continue_stmt(_qd_core.DebugInfo(ctx.get_pos_info(node))) return None + @staticmethod + def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: + if len(node.items) != 1: + raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports a single context manager") + item = node.items[0] + if item.optional_vars is not None: + raise QuadrantsSyntaxError("'with ... as ...' is not supported in Quadrants kernels") + if not isinstance(item.context_expr, ast.Call): + raise QuadrantsSyntaxError("'with' in Quadrants kernels requires a call expression") + if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): + raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports qd.stream_parallel()") + if not ctx.is_kernel: + raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func") + ctx.ast_builder.begin_stream_parallel() + build_stmts(ctx, node.body) + ctx.ast_builder.end_stream_parallel() + return None + @staticmethod def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None: return None diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4c9bd5115b..142694091f 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -26,11 +26,13 @@ from quadrants.lang.ast.ast_transformer_utils import ( ASTTransformerFuncContext, ) +from quadrants.lang.ast.symbol_resolver import ASTResolver from quadrants.lang.buffer_view import BufferView from quadrants.lang.exception import ( QuadrantsSyntaxError, ) from quadrants.lang.matrix import MatrixType +from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import StructType from quadrants.lang.util import to_quadrants_type from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types @@ -54,7 +56,9 @@ def _decl_and_create_variable( assert this_arg_features is not None marker = this_arg_features[0] if marker == _TENSOR_T_NDARRAY_MARKER: - raw_element_type, ndim, needs_grad, boundary, layout = this_arg_features[1:] + raw_element_type, ndim, needs_grad, boundary, layout = ( + this_arg_features[1:] + ) return False, ( kernel_arguments.decl_ndarray_arg, ( @@ -73,7 +77,9 @@ def _decl_and_create_variable( assert ctx.global_vars is not None return True, ctx.global_vars.get(name) raise AssertionError(f"unknown qd.Tensor marker: {marker!r}") - if annotation == annotations.template or isinstance(annotation, annotations.template): + if annotation == annotations.template or isinstance( + annotation, annotations.template + ): if name in ctx.template_vars: return True, ctx.template_vars[name] assert ctx.global_vars is not None @@ -96,8 +102,12 @@ def _decl_and_create_variable( needs_grad, BoundaryMode(boundary), ) - offset = kernel_arguments.decl_scalar_arg(primitive_types.i32, full_name + "_offset") - size = kernel_arguments.decl_scalar_arg(primitive_types.i32, full_name + "_size") + offset = kernel_arguments.decl_scalar_arg( + primitive_types.i32, full_name + "_offset" + ) + size = kernel_arguments.decl_scalar_arg( + primitive_types.i32, full_name + "_size" + ) return True, BufferView(arr, offset, size) if isinstance(annotation, ndarray_type.NdarrayType): assert this_arg_features is not None @@ -137,7 +147,10 @@ def _transform_kernel_arg( ctx.create_variable(argument_name, argument_type) for field_idx, field in enumerate(dataclasses.fields(argument_type)): flat_name = create_flat_name(argument_name, field.name) - if pruning.enforcing and flat_name not in pruning.used_vars_by_func_id[func_id]: + if ( + pruning.enforcing + and flat_name not in pruning.used_vars_by_func_id[func_id] + ): continue # if a field is a dataclass, then feed back into process_kernel_arg recursively if dataclasses.is_dataclass(field.type): @@ -175,7 +188,9 @@ def _transform_kernel_arg( ctx.create_variable(argument_name, obj) @staticmethod - def _transform_as_kernel(ctx: ASTTransformerFuncContext, node: ast.FunctionDef, args: ast.arguments) -> None: + def _transform_as_kernel( + ctx: ASTTransformerFuncContext, node: ast.FunctionDef, args: ast.arguments + ) -> None: assert ctx.func is not None assert ctx.arg_features is not None if node.returns is not None: @@ -224,7 +239,9 @@ def _walk_obj(obj, arg_idx, path): child = child._unwrap() if isinstance(child, _ndarray.Ndarray): _register_ndarray(child, arg_idx, (*path, field.name)) - elif dataclasses.is_dataclass(child) and not isinstance(child, type): + elif dataclasses.is_dataclass(child) and not isinstance( + child, type + ): _walk_obj(child, arg_idx, (*path, field.name)) else: for attr_name, attr_val in vars(obj).items(): @@ -248,7 +265,9 @@ def _register_ndarray(nd, arg_idx, attr_chain): element_type, ndim, name, needs_grad ) arr = any_array.AnyArray( - _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), + _qd_core.make_external_tensor_expr( + element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE + ), _qd_layout=layout, ) cache[key] = arr @@ -257,7 +276,9 @@ def _register_ndarray(nd, arg_idx, attr_chain): assert ctx.py_args is not None for i, arg_meta in enumerate(ctx.func.arg_metas): anno = arg_meta.annotation - is_template = anno is annotations.template or isinstance(anno, annotations.template) + is_template = anno is annotations.template or isinstance( + anno, annotations.template + ) is_tensor_anno = anno is _TensorClass if not (is_template or is_tensor_anno): continue @@ -295,15 +316,21 @@ def _transform_func_arg( # directly — ndarray and field impls are both valid pass-by-reference arguments. if argument_type is _TensorClass: data = FunctionDefTransformer._unwrap_tensor(data) - _cache = getattr(getattr(ctx, "global_context", None), "ndarray_to_any_array", None) + _cache = getattr( + getattr(ctx, "global_context", None), "ndarray_to_any_array", None + ) promoted = _cache.get(id(data)) if _cache else None - ctx.create_variable(argument_name, promoted if promoted is not None else data) + ctx.create_variable( + argument_name, promoted if promoted is not None else data + ) return None if dataclasses.is_dataclass(argument_type): for field in dataclasses.fields(argument_type): flat_name = create_flat_name(argument_name, field.name) - data_child = FunctionDefTransformer._unwrap_tensor(getattr(data, field.name)) + data_child = FunctionDefTransformer._unwrap_tensor( + getattr(data, field.name) + ) if isinstance( data_child, ( @@ -315,11 +342,19 @@ def _transform_func_arg( ): # qd.Tensor struct fields skip check_matched (the Tensor class has no such method — it is # polymorphic). - if field.type is not _TensorClass and hasattr(field.type, "check_matched"): + if field.type is not _TensorClass and hasattr( + field.type, "check_matched" + ): field.type.check_matched(data_child.get_type(), field.name) - _cache = getattr(getattr(ctx, "global_context", None), "ndarray_to_any_array", None) + _cache = getattr( + getattr(ctx, "global_context", None), + "ndarray_to_any_array", + None, + ) promoted = _cache.get(id(data_child)) if _cache else None - ctx.create_variable(flat_name, promoted if promoted is not None else data_child) + ctx.create_variable( + flat_name, promoted if promoted is not None else data_child + ) elif dataclasses.is_dataclass(data_child): FunctionDefTransformer._transform_func_arg( ctx, @@ -336,9 +371,17 @@ def _transform_func_arg( # Ndarray arguments are passed by reference. if isinstance(argument_type, (ndarray_type.NdarrayType)): if not isinstance( - data, (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray) + data, + ( + _ndarray.ScalarNdarray, + matrix.VectorNdarray, + matrix.MatrixNdarray, + any_array.AnyArray, + ), ): - raise QuadrantsSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.") + raise QuadrantsSyntaxError( + f"Argument {argument_name} of type {argument_type} is not recognized." + ) argument_type.check_matched(data.get_type(), argument_name) ctx.create_variable(argument_name, data) return None @@ -348,7 +391,9 @@ def _transform_func_arg( # not here — data.arr is an Expr node during func compilation, not a real Ndarray. if isinstance(argument_type, buffer_view_type.BufferViewType): if not isinstance(data, BufferView): - raise QuadrantsSyntaxError(f"Argument {argument_name} expects a BufferView, got {type(data).__name__}") + raise QuadrantsSyntaxError( + f"Argument {argument_name} expects a BufferView, got {type(data).__name__}" + ) ctx.create_variable(argument_name, data) return None @@ -387,7 +432,9 @@ def _transform_func_arg( return None if id(argument_type) in primitive_types.type_ids: - ctx.create_variable(argument_name, impl.expr_init_func(qd_ops.cast(data, argument_type))) + ctx.create_variable( + argument_name, impl.expr_init_func(qd_ops.cast(data, argument_type)) + ) return None # Create a copy for non-template arguments, # so that they are passed by value. @@ -396,7 +443,9 @@ def _transform_func_arg( return None @staticmethod - def _transform_as_func(ctx: ASTTransformerFuncContext, node: ast.FunctionDef, args: ast.arguments) -> None: + def _transform_as_func( + ctx: ASTTransformerFuncContext, node: ast.FunctionDef, args: ast.arguments + ) -> None: # pylint: disable=import-outside-toplevel from quadrants.lang.kernel_impl import Func @@ -404,7 +453,9 @@ def _transform_as_func(ctx: ASTTransformerFuncContext, node: ast.FunctionDef, ar assert ctx.py_args is not None for py_arg_i, py_arg in enumerate(ctx.py_args): argument = ctx.func.arg_metas_expanded[py_arg_i] - FunctionDefTransformer._transform_func_arg(ctx, argument.name, argument.annotation, py_arg) + FunctionDefTransformer._transform_func_arg( + ctx, argument.name, argument.annotation, py_arg + ) # deal with dataclasses for v in ctx.func.orig_arguments: @@ -443,7 +494,83 @@ def build_FunctionDef( else: FunctionDefTransformer._transform_as_func(ctx, node, args) + if ctx.is_kernel: + FunctionDefTransformer._validate_stream_parallel_exclusivity( + node.body, ctx.global_vars + ) + with ctx.variable_scope_guard(): build_stmts(ctx, node.body) return None + + @staticmethod + def _is_stream_parallel_with(stmt: ast.stmt, global_vars: dict[str, Any]) -> bool: + if not isinstance(stmt, ast.With): + return False + if len(stmt.items) != 1: + return False + item = stmt.items[0] + if not isinstance(item.context_expr, ast.Call): + return False + func_node = item.context_expr.func + if ASTResolver.resolve_to(func_node, stream_parallel, global_vars): + return True + resolved = ASTResolver.resolve_value(func_node, global_vars) + if resolved is not None: + return getattr(resolved, "__name__", None) == "stream_parallel" and getattr( + resolved, "__module__", "" + ).startswith("quadrants") + if isinstance(func_node, ast.Attribute) and func_node.attr == "stream_parallel": + return True + if isinstance(func_node, ast.Name) and func_node.id == "stream_parallel": + return True + return False + + @staticmethod + def _is_docstring(stmt: ast.stmt, index: int) -> bool: + return ( + index == 0 + and isinstance(stmt, ast.Expr) + and isinstance(stmt.value, (ast.Constant, ast.Str)) + ) + + @staticmethod + def _is_coverage_probe(stmt: ast.stmt) -> bool: + if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1: + return False + target = stmt.targets[0] + return ( + isinstance(target, ast.Subscript) + and isinstance(target.value, ast.Name) + and target.value.id.startswith("_qd_cov") + ) + + @staticmethod + def _validate_stream_parallel_exclusivity( + body: list[ast.stmt], global_vars: dict[str, Any] + ) -> None: + if not any( + FunctionDefTransformer._is_stream_parallel_with(s, global_vars) + for s in body + ): + return + for i, stmt in enumerate(body): + if FunctionDefTransformer._is_docstring(stmt, i): + continue + if FunctionDefTransformer._is_coverage_probe(stmt): + continue + if not FunctionDefTransformer._is_stream_parallel_with(stmt, global_vars): + stmt_desc = f"{type(stmt).__name__}" + if isinstance(stmt, ast.With) and stmt.items: + ctx_expr = stmt.items[0].context_expr + if isinstance(ctx_expr, ast.Call) and isinstance( + ctx_expr.func, ast.Attribute + ): + stmt_desc += f"(with {ast.dump(ctx_expr.func)})" + raise QuadrantsSyntaxError( + "When using qd.stream_parallel(), all top-level statements " + "in the kernel must be 'with qd.stream_parallel():' blocks. " + f"Move non-parallel code to a separate kernel. " + f"[stmt {i}: {stmt_desc}, body_len={len(body)}]" + ) diff --git a/python/quadrants/lang/ast/symbol_resolver.py b/python/quadrants/lang/ast/symbol_resolver.py index 81296fcefb..c2b4fcaffe 100644 --- a/python/quadrants/lang/ast/symbol_resolver.py +++ b/python/quadrants/lang/ast/symbol_resolver.py @@ -55,3 +55,35 @@ def resolve_to(node, wanted, scope): return False # The name ``scope`` here could be a bit confusing return scope is wanted + + @staticmethod + def resolve_value(node, scope): + """Resolve an AST Name/Attribute node to a Python object. + + Same traversal as resolve_to but returns the resolved object (or None) instead of comparing against a wanted + value. + """ + if isinstance(node, ast.Name): + return scope.get(node.id) if isinstance(scope, dict) else None + + if not isinstance(node, ast.Attribute): + return None + + v = node.value + chain = [node.attr] + while isinstance(v, ast.Attribute): + chain.append(v.attr) + v = v.value + if not isinstance(v, ast.Name): + return None + chain.append(v.id) + + for attr in reversed(chain): + try: + if isinstance(scope, dict): + scope = scope[attr] + else: + scope = getattr(scope, attr) + except (KeyError, AttributeError): + return None + return scope diff --git a/python/quadrants/lang/stream.py b/python/quadrants/lang/stream.py index e87816568c..3f734587b3 100644 --- a/python/quadrants/lang/stream.py +++ b/python/quadrants/lang/stream.py @@ -1,4 +1,5 @@ import weakref +from contextlib import contextmanager from quadrants.lang import impl @@ -174,4 +175,14 @@ def create_event() -> Event: return Event(handle, _get_prog_weakref()) -__all__ = ["Stream", "Event", "create_stream", "create_event"] +@contextmanager +def stream_parallel(): + """Run top-level for loops in this block on separate GPU streams. + + Used inside @qd.kernel. At Python runtime (outside kernels), this is a no-op. During kernel compilation, the AST + transformer calls into the C++ ASTBuilder to tag loops with a stream-parallel group ID. + """ + yield + + +__all__ = ["Stream", "Event", "create_stream", "create_event", "stream_parallel"] diff --git a/quadrants/analysis/gen_offline_cache_key.cpp b/quadrants/analysis/gen_offline_cache_key.cpp index 66f03aab20..24fa3ce435 100644 --- a/quadrants/analysis/gen_offline_cache_key.cpp +++ b/quadrants/analysis/gen_offline_cache_key.cpp @@ -377,6 +377,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(stmt->strictly_serialized); emit(stmt->mem_access_opt); emit(stmt->block_dim); + emit(stmt->stream_parallel_group_id); emit(stmt->body.get()); } diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 7eb23a7a2e..df8a9aaeae 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -351,6 +351,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { current_task->grid_dim = num_SMs * query_max_block_per_sm; } current_task->block_dim = stmt->block_dim; + current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); // Host-side adstack sizing, same scheme as codegen_cuda: tight `grid_dim * block_dim` for diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 64d5b0f283..87fbb9abc9 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -638,6 +638,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } current_task->block_dim = stmt->block_dim; current_task->dynamic_shared_array_bytes = dynamic_shared_array_bytes; + current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); // Host-side adstack sizing. For non-range_for and for const-bound range_for the launcher uses diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h index 1606dd31f8..8e1cd6466d 100644 --- a/quadrants/codegen/llvm/llvm_compiled_data.h +++ b/quadrants/codegen/llvm/llvm_compiled_data.h @@ -93,14 +93,20 @@ class OffloadedTask { int block_dim{0}; int grid_dim{0}; int dynamic_shared_array_bytes{0}; + int stream_parallel_group_id{0}; AdStackSizingInfo ad_stack{}; explicit OffloadedTask(const std::string &name = "", int block_dim = 0, int grid_dim = 0, - int dynamic_shared_array_bytes = 0) - : name(name), block_dim(block_dim), grid_dim(grid_dim), dynamic_shared_array_bytes(dynamic_shared_array_bytes) {}; - QD_IO_DEF(name, block_dim, grid_dim, dynamic_shared_array_bytes, ad_stack); + int dynamic_shared_array_bytes = 0, + int stream_parallel_group_id = 0) + : name(name), + block_dim(block_dim), + grid_dim(grid_dim), + dynamic_shared_array_bytes(dynamic_shared_array_bytes), + stream_parallel_group_id(stream_parallel_group_id) {}; + QD_IO_DEF(name, block_dim, grid_dim, dynamic_shared_array_bytes, stream_parallel_group_id, ad_stack); }; struct LLVMCompiledTask { diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index 4e118753ee..3c750d4ff9 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -110,6 +110,7 @@ FrontendForStmt::FrontendForStmt(const FrontendForStmt &o) strictly_serialized(o.strictly_serialized), mem_access_opt(o.mem_access_opt), block_dim(o.block_dim), + stream_parallel_group_id(o.stream_parallel_group_id), loop_name(o.loop_name) { } @@ -118,6 +119,7 @@ void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { strictly_serialized = config.strictly_serialized; mem_access_opt = config.mem_access_opt; block_dim = config.block_dim; + stream_parallel_group_id = config.stream_parallel_group_id; loop_name = config.loop_name; if (arch == Arch::cuda || arch == Arch::amdgpu) { num_cpu_threads = 1; @@ -1390,6 +1392,7 @@ void ASTBuilder::create_assert_stmt(const Expr &cond, } void ASTBuilder::begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e, const DebugInfo &dbg_info) { + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(i, s, e, arch_, for_loop_dec_.config, dbg_info); auto stmt = stmt_unique.get(); this->insert(std::move(stmt_unique)); @@ -1403,6 +1406,7 @@ void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars, QD_WARN_IF(for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(loop_vars, snode, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); auto stmt = stmt_unique.get(); @@ -1416,6 +1420,7 @@ void ASTBuilder::begin_frontend_struct_for_on_external_tensor(const ExprGroup &l QD_WARN_IF(for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(loop_vars, external_tensor, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); @@ -1431,6 +1436,7 @@ void ASTBuilder::begin_frontend_mesh_for(const Expr &i, QD_WARN_IF(for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the mesh for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(ExprGroup(i), mesh_ptr, element_type, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 7d2c7bd9df..b4ad04a9b5 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -23,6 +23,7 @@ struct ForLoopConfig { MemoryAccessOptions mem_access_opt; int block_dim{0}; bool uniform{false}; + int stream_parallel_group_id{0}; std::string loop_name{""}; }; @@ -198,6 +199,7 @@ class FrontendForStmt : public Stmt { bool strictly_serialized; MemoryAccessOptions mem_access_opt; int block_dim; + int stream_parallel_group_id{0}; std::string loop_name; FrontendForStmt(const ExprGroup &loop_vars, @@ -887,6 +889,7 @@ class ASTBuilder { config.mem_access_opt.clear(); config.block_dim = 0; config.strictly_serialized = false; + config.stream_parallel_group_id = 0; config.loop_name.clear(); } }; @@ -897,6 +900,8 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; + int stream_parallel_group_counter_{0}; + int current_stream_parallel_group_id_{0}; public: ASTBuilder(Block *initial, Arch arch, bool is_kernel) : is_kernel_(is_kernel), arch_(arch) { @@ -1022,6 +1027,15 @@ class ASTBuilder { for_loop_dec_.reset(); } + void begin_stream_parallel() { + QD_ERROR_IF(current_stream_parallel_group_id_ != 0, "Nested stream_parallel blocks are not supported"); + current_stream_parallel_group_id_ = ++stream_parallel_group_counter_; + } + + void end_stream_parallel() { + current_stream_parallel_group_id_ = 0; + } + Identifier get_next_id(const std::string &name = "") { return Identifier(id_counter_++, name); } diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index 75f66a7475..9adebe8e87 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -222,6 +222,7 @@ std::unique_ptr RangeForStmt::clone() const { auto new_stmt = std::make_unique(begin, end, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim, strictly_serialized); new_stmt->reversed = reversed; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; new_stmt->loop_name = loop_name; return new_stmt; } @@ -243,6 +244,7 @@ StructForStmt::StructForStmt(SNode *snode, std::unique_ptr StructForStmt::clone() const { auto new_stmt = std::make_unique(snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; new_stmt->loop_name = loop_name; return new_stmt; } @@ -402,6 +404,7 @@ std::unique_ptr OffloadedStmt::clone() const { new_stmt->tls_size = tls_size; new_stmt->bls_size = bls_size; new_stmt->mem_access_opt = mem_access_opt; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; new_stmt->loop_name = loop_name; // Shared-pointer copy: the captured trip-count `SizeExpr` is read-only after `determine_ad_stack_size` // populates it in `compile_to_offloads`, and LLVM codegen clones each offload at `codegen.cpp:68` diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 3768f52bf1..2426396dab 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -955,6 +955,7 @@ class RangeForStmt : public Stmt { int block_dim; bool strictly_serialized; std::string range_hint; + int stream_parallel_group_id{0}; std::string loop_name; RangeForStmt(Stmt *begin, @@ -977,7 +978,14 @@ class RangeForStmt : public Stmt { std::unique_ptr clone() const override; - QD_STMT_DEF_FIELDS(begin, end, reversed, is_bit_vectorized, num_cpu_threads, block_dim, strictly_serialized); + QD_STMT_DEF_FIELDS(begin, + end, + reversed, + is_bit_vectorized, + num_cpu_threads, + block_dim, + strictly_serialized, + stream_parallel_group_id); QD_DEFINE_ACCEPT }; @@ -996,6 +1004,7 @@ class StructForStmt : public Stmt { int num_cpu_threads; int block_dim; MemoryAccessOptions mem_access_opt; + int stream_parallel_group_id{0}; std::string loop_name; StructForStmt(SNode *snode, @@ -1010,7 +1019,13 @@ class StructForStmt : public Stmt { std::unique_ptr clone() const override; - QD_STMT_DEF_FIELDS(snode, index_offsets, is_bit_vectorized, num_cpu_threads, block_dim, mem_access_opt); + QD_STMT_DEF_FIELDS(snode, + index_offsets, + is_bit_vectorized, + num_cpu_threads, + block_dim, + mem_access_opt, + stream_parallel_group_id); QD_DEFINE_ACCEPT }; @@ -1352,6 +1367,7 @@ class OffloadedStmt : public Stmt { std::size_t tls_size{1}; // avoid allocating dynamic memory with 0 byte std::size_t bls_size{0}; MemoryAccessOptions mem_access_opt; + int stream_parallel_group_id{0}; // Pre-chunking loop trip-count `SizeExpr` captured by `determine_ad_stack_size`. Set on adstack-bearing // range-for tasks before `make_cpu_multithreaded_range_for` rewrites the loop into per-thread chunks, so the @@ -1399,7 +1415,8 @@ class OffloadedStmt : public Stmt { reversed, num_cpu_threads, index_offsets, - mem_access_opt); + mem_access_opt, + stream_parallel_group_id); QD_DEFINE_ACCEPT }; diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index b3dc79bef5..9a9eca1d3d 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -306,7 +306,9 @@ void export_lang(py::module &m) { .def("strictly_serialize", &ASTBuilder::strictly_serialize) .def("block_dim", &ASTBuilder::block_dim) .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag) - .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag); + .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag) + .def("begin_stream_parallel", &ASTBuilder::begin_stream_parallel) + .def("end_stream_parallel", &ASTBuilder::end_stream_parallel); auto device_capability_config = py::class_(m, "DeviceCapabilityConfig").def("get", &DeviceCapabilityConfig::get); diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index 30c75a49ee..3bd2fa7fec 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -1,3 +1,5 @@ +#include + #include "quadrants/runtime/amdgpu/kernel_launcher.h" #include "quadrants/rhi/amdgpu/amdgpu_context.h" #include "quadrants/program/launch_context_builder.h" @@ -55,20 +57,15 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, void *context_pointer, int arg_size) { auto *executor = get_runtime_executor(); - // Two gates govern the per-launch adstack publish work, both opt-in by the kernel's IR shape. Forward-only kernels - // skip both gates and pay zero adstack overhead; reverse-mode kernels without a captured `bound_expr` skip the - // lazy-claim block, paying the per-task `publish_adstack_metadata` only. See the matching comment in - // `runtime/cuda/kernel_launcher.cpp` for the role of each gate. + // See the matching comment in `runtime/cuda/kernel_launcher.cpp` for the role of each gate. const bool any_lazy_task = std::any_of(offloaded_tasks.begin(), offloaded_tasks.end(), [](const OffloadedTask &t) { return t.ad_stack.bound_expr.has_value(); }); if (any_lazy_task) { - // Allocate / reset the per-kernel lazy-claim arrays once before the first task. See the matching CPU launcher - // block for rationale; on AMDGPU the same memcpy_host_to_device path through the cached field pointers publishes - // the cleared counter and UINT32_MAX-defaulted capacity arrays. executor->publish_adstack_lazy_claim_buffers(offloaded_tasks.size()); } - std::size_t task_index = 0; - for (const auto &task : offloaded_tasks) { + + // Per-task adstack setup + grid-dim capping. Shared by serial and stream-parallel paths. + auto prepare_task = [&](std::size_t task_index, const OffloadedTask &task) -> int { int effective_grid_dim = task.grid_dim; if (!task.ad_stack.allocas.empty()) { // Pass the device-side `RuntimeContext` pointer through to the adstack sizer kernel. Without this the sizer @@ -106,7 +103,6 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, executor->ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n_threads_amdgpu, &ctx); } } - ++task_index; // Match the heap-row count resolved above: adstack-bearing tasks dispatch at most `kAdStackMaxConcurrentThreads`. // The runtime grid-strided loop walks the full element list / range with `i += grid_dim()` so a smaller grid // completes the same workload sequentially per slot. @@ -119,9 +115,52 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, effective_grid_dim = 1; } } - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim); - amdgpu_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes, - {(void *)&context_pointer}, {arg_size}); + return effective_grid_dim; + }; + + auto *active_stream = AMDGPUContext::get_instance().get_stream(); + for (size_t i = 0; i < offloaded_tasks.size();) { + const auto &task = offloaded_tasks[i]; + if (task.stream_parallel_group_id == 0) { + int effective_grid_dim = prepare_task(i, task); + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim); + amdgpu_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes, + {(void *)&context_pointer}, {arg_size}); + i++; + } else { + size_t group_start = i; + while (i < offloaded_tasks.size() && offloaded_tasks[i].stream_parallel_group_id != 0) { + i++; + } + + std::map stream_by_id; + for (size_t j = group_start; j < i; j++) { + int sid = offloaded_tasks[j].stream_parallel_group_id; + if (stream_by_id.find(sid) == stream_by_id.end()) { + void *s = nullptr; + AMDGPUDriver::get_instance().stream_create(&s, 0x1 /*HIP_STREAM_NON_BLOCKING*/); + stream_by_id[sid] = s; + } + } + + for (size_t j = group_start; j < i; j++) { + const auto &t = offloaded_tasks[j]; + int effective_grid_dim = prepare_task(j, t); + AMDGPUContext::get_instance().set_stream(stream_by_id[t.stream_parallel_group_id]); + QD_TRACE("Launching kernel {}<<<{}, {}>>>", t.name, effective_grid_dim, t.block_dim); + amdgpu_module->launch(t.name, effective_grid_dim, t.block_dim, t.dynamic_shared_array_bytes, + {(void *)&context_pointer}, {arg_size}); + } + + for (auto &[sid, s] : stream_by_id) { + AMDGPUDriver::get_instance().stream_synchronize(s); + } + for (auto &[sid, s] : stream_by_id) { + AMDGPUDriver::get_instance().stream_destroy(s); + } + + AMDGPUContext::get_instance().set_stream(active_stream); + } } } diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index a52acf2ebf..e632aed4da 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -1,3 +1,5 @@ +#include + #include "quadrants/runtime/cuda/kernel_launcher.h" #include "quadrants/runtime/cuda/cuda_utils.h" #include "quadrants/rhi/cuda/cuda_context.h" @@ -68,22 +70,14 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, // Two gates govern the per-launch adstack publish work, both opt-in by the kernel's IR shape. Forward-only kernels // skip both gates and pay zero adstack overhead; reverse-mode kernels without a captured `bound_expr` skip the // lazy-claim block, paying the per-task `publish_adstack_metadata` only. - // - `any_adstack`: at least one task has an `AdStackAllocaStmt`. Gates the per-task `publish_adstack_metadata` - // call (sets per-thread stride for the codegen heap-base addressing). - // - `any_lazy_task`: at least one task has a captured `bound_expr` (the codegen routes such tasks through the - // lazy LCA-block atomic-rmw row claim, which reads `runtime->adstack_row_counters[task_id]` and - // `runtime->adstack_bound_row_capacities[task_id]`). Gates `publish_adstack_lazy_claim_buffers` and the - // per-task reducer dispatch + DtoH heap sizing. const bool any_lazy_task = std::any_of(offloaded_tasks.begin(), offloaded_tasks.end(), [](const OffloadedTask &t) { return t.ad_stack.bound_expr.has_value(); }); if (any_lazy_task) { - // Allocate / reset the per-kernel lazy-claim arrays once before the first task. See the matching CPU launcher - // block for rationale; on CUDA the same memcpy_host_to_device path through the cached field pointers publishes - // the cleared counter and UINT32_MAX-defaulted capacity arrays. executor->publish_adstack_lazy_claim_buffers(offloaded_tasks.size()); } - std::size_t task_index = 0; - for (const auto &task : offloaded_tasks) { + + // Per-task adstack setup + grid-dim capping. Shared by serial and stream-parallel paths. + auto prepare_task = [&](std::size_t task_index, const OffloadedTask &task) -> int { int effective_grid_dim = task.grid_dim; if (!task.ad_stack.allocas.empty()) { std::size_t n = resolve_num_threads(task.ad_stack, executor); @@ -94,12 +88,6 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, // `publish_adstack_metadata`'s host-pointer fast path. executor->publish_adstack_metadata(task.ad_stack, n, &ctx, device_context_ptr); if (task.ad_stack.bound_expr.has_value()) { - // Device-side reducer for tasks with a captured ndarray-backed `bound_expr`: a single-thread CUDA kernel - // walks the gating ndarray, counts gate-passing threads, writes the count into - // `runtime->adstack_bound_row_capacities[task_index]`. The codegen-emitted clamp at the float LCA-block - // claim site reads it back. Tasks without a captured gate keep the UINT32_MAX default and the clamp stays - // inert. - // // Reducer length is the gating ndarray's full flat element count, not `n`: the lazy row-claim atomic-rmw // fires once per LCA execution, and `gpu_parallel_struct_for` / `gpu_parallel_range_for` grid-stride (`i += // grid_dim()`) so a single dispatched thread can hit the LCA many times across one launch when the logical @@ -110,9 +98,6 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, if (task.ad_stack.bound_expr->field_source_kind == StaticAdStackBoundExpr::FieldSourceKind::NdArray && !task.ad_stack.bound_expr->ndarray_arg_id.empty() && task.ad_stack.bound_expr->ndarray_ndim > 0 && ctx.args_type != nullptr) { - // Length = product of shape entries via `args_type`. See `runtime/cpu/kernel_launcher.cpp` for the - // unit-stability rationale; `array_runtime_sizes` carries different units depending on the dispatch entry - // point and would undercount by `sizeof(elem)`x for `qd.ndarray` arguments. int64_t flat_len = 1; for (int axis = 0; axis < task.ad_stack.bound_expr->ndarray_ndim; ++axis) { std::vector indices = task.ad_stack.bound_expr->ndarray_arg_id; @@ -134,29 +119,67 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, // dispatched-threads worst case on sparse-grid workloads. executor->ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n, &ctx); } - } - ++task_index; - // For adstack-bearing tasks, dispatch at most `kAdStackMaxConcurrentThreads` (matching the heap row count resolved - // above). The runtime's grid-strided loop (`gpu_parallel_struct_for` / `gpu_parallel_range_for`, - // `quadrants/runtime/llvm/runtime_module/runtime.cpp`) walks the full element list / range with `i += grid_dim()`, - // so a smaller grid completes the same workload sequentially per slot. Tasks without an adstack keep the - // codegen-emitted `task.grid_dim` (saturating_grid_dim) for max throughput. - if (!task.ad_stack.allocas.empty() && task.block_dim > 0) { // Floor division (not ceiling): the heap-row count `n` resolved by `resolve_num_threads` floors at // `kAdStackMaxConcurrentThreads`, so dispatching `cap_blocks * block_dim` threads must not exceed that count. // Ceiling division would over-dispatch by `block_dim - 1` threads when `block_dim` does not divide - // `kAdStackMaxConcurrentThreads` evenly (e.g. `block_dim=192`: `ceil(65536/192)*192 = 65664`), and threads with - // `linear_thread_idx >= 65536` would index past the heap end. - const std::size_t cap_blocks = - std::max(1u, kAdStackMaxConcurrentThreads / static_cast(task.block_dim)); - effective_grid_dim = static_cast(std::min(static_cast(task.grid_dim), cap_blocks)); - if (effective_grid_dim < 1) { - effective_grid_dim = 1; + // `kAdStackMaxConcurrentThreads` evenly (e.g. `block_dim=192`: `ceil(65536/192)*192 = 65664`), and threads + // with `linear_thread_idx >= 65536` would index past the heap end. + if (task.block_dim > 0) { + const std::size_t cap_blocks = + std::max(1u, kAdStackMaxConcurrentThreads / static_cast(task.block_dim)); + effective_grid_dim = + static_cast(std::min(static_cast(task.grid_dim), cap_blocks)); + if (effective_grid_dim < 1) { + effective_grid_dim = 1; + } + } + } + return effective_grid_dim; + }; + + auto *active_stream = CUDAContext::get_instance().get_stream(); + for (size_t i = 0; i < offloaded_tasks.size();) { + const auto &task = offloaded_tasks[i]; + if (task.stream_parallel_group_id == 0) { + int effective_grid_dim = prepare_task(i, task); + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim); + cuda_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes, + {&ctx.get_context()}, {}); + i++; + } else { + size_t group_start = i; + while (i < offloaded_tasks.size() && offloaded_tasks[i].stream_parallel_group_id != 0) { + i++; } + + std::map stream_by_id; + for (size_t j = group_start; j < i; j++) { + int sid = offloaded_tasks[j].stream_parallel_group_id; + if (stream_by_id.find(sid) == stream_by_id.end()) { + void *s = nullptr; + CUDADriver::get_instance().stream_create(&s, 0x1 /*CU_STREAM_NON_BLOCKING*/); + stream_by_id[sid] = s; + } + } + + for (size_t j = group_start; j < i; j++) { + const auto &t = offloaded_tasks[j]; + int effective_grid_dim = prepare_task(j, t); + CUDAContext::get_instance().set_stream(stream_by_id[t.stream_parallel_group_id]); + QD_TRACE("Launching kernel {}<<<{}, {}>>>", t.name, effective_grid_dim, t.block_dim); + cuda_module->launch(t.name, effective_grid_dim, t.block_dim, t.dynamic_shared_array_bytes, {&ctx.get_context()}, + {}); + } + + for (auto &[sid, s] : stream_by_id) { + CUDADriver::get_instance().stream_synchronize(s); + } + for (auto &[sid, s] : stream_by_id) { + CUDADriver::get_instance().stream_destroy(s); + } + + CUDAContext::get_instance().set_stream(active_stream); } - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim); - cuda_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes, - {&ctx.get_context()}, {}); } } diff --git a/quadrants/transforms/lower_ast.cpp b/quadrants/transforms/lower_ast.cpp index 72b45cb976..6818ad2f90 100644 --- a/quadrants/transforms/lower_ast.cpp +++ b/quadrants/transforms/lower_ast.cpp @@ -222,6 +222,7 @@ class LowerAST : public IRVisitor { stmt->num_cpu_threads, stmt->block_dim); new_for->loop_name = stmt->loop_name; new_for->index_offsets = offsets; + new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; VecStatement new_statements; for (int i = 0; i < (int)stmt->loop_var_ids.size(); i++) { Stmt *loop_index = new_statements.push_back(new_for.get(), snode->physical_index_position[i]); @@ -256,6 +257,7 @@ class LowerAST : public IRVisitor { stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized, /*range_hint=*/fmt::format("arg ({})", fmt::join(arg_id, ", ")), /*loop_name=*/stmt->loop_name); + new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; VecStatement new_statements; Stmt *loop_index = new_statements.push_back(new_for.get(), 0); for (int i = (int)shape.size() - 1; i >= 0; i--) { @@ -289,6 +291,7 @@ class LowerAST : public IRVisitor { stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized, /*range_hint=*/"", /*loop_name=*/stmt->loop_name); + new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; new_for->body->insert(std::make_unique(new_for.get(), 0), 0); new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] = new_for->body->statements[0].get(); fctx.push_back(std::move(new_for)); diff --git a/quadrants/transforms/offload.cpp b/quadrants/transforms/offload.cpp index b9027e95b3..c8a62a1409 100644 --- a/quadrants/transforms/offload.cpp +++ b/quadrants/transforms/offload.cpp @@ -126,6 +126,7 @@ class Offloader { offloaded->body->insert(std::move(s->body->statements[j])); } offloaded->range_hint = s->range_hint; + offloaded->stream_parallel_group_id = s->stream_parallel_group_id; offloaded->loop_name = s->loop_name; root_block->insert(std::move(offloaded)); } else if (auto st = stmt->cast()) { @@ -237,6 +238,7 @@ class Offloader { offloaded_struct_for->is_bit_vectorized = for_stmt->is_bit_vectorized; offloaded_struct_for->num_cpu_threads = std::min(for_stmt->num_cpu_threads, config.cpu_max_num_threads); offloaded_struct_for->mem_access_opt = mem_access_opt; + offloaded_struct_for->stream_parallel_group_id = for_stmt->stream_parallel_group_id; offloaded_struct_for->loop_name = for_stmt->loop_name; root_block->insert(std::move(offloaded_struct_for)); diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 0427d2f320..5bc1045cfb 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -226,6 +226,7 @@ def _get_expected_matrix_apis(): "static_assert", "static_print", "stop_grad", + "stream_parallel", "svd", "sym_eig", "sync", diff --git a/tests/python/test_streams.py b/tests/python/test_streams.py index f5a4d8f6d0..f4408f31f5 100644 --- a/tests/python/test_streams.py +++ b/tests/python/test_streams.py @@ -181,23 +181,6 @@ def fill(): e.destroy() -@test_utils.test() -def test_stream_with_ndarray(): - N = 1024 - - @qd.kernel - def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)): - for i in range(N): - arr[i] = 99.0 - - arr = qd.ndarray(qd.f32, shape=(N,)) - s = qd.create_stream() - fill(arr, qd_stream=s) - s.synchronize() - assert np.allclose(arr.to_numpy(), 99.0) - s.destroy() - - @test_utils.test() def test_concurrent_streams_with_events(): """Two slow kernels on separate streams run concurrently (~1s on GPU), serial fallback on CPU/Metal.""" @@ -366,3 +349,164 @@ def fill(): with pytest.raises(RuntimeError, match="not compatible with graph=True"): fill(qd_stream=s) s.destroy() + + +@test_utils.test() +def test_stream_parallel_basic(): + """Each with qd.stream_parallel() block runs on its own stream (serial fallback on CPU/Metal).""" + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_parallel(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + with qd.stream_parallel(): + for j in range(N): + b[j] = 2.0 + + fill_parallel() + qd.sync() + assert np.allclose(a.to_numpy(), 1.0) + assert np.allclose(b.to_numpy(), 2.0) + + +@test_utils.test() +def test_stream_parallel_multiple_loops_per_stream(): + """Multiple for loops inside one stream_parallel block share a stream (serial fallback on CPU/Metal).""" + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + c = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def parallel_phase(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + for i in range(N): + a[i] = a[i] + 1.0 + with qd.stream_parallel(): + for j in range(N): + b[j] = 10.0 + + @qd.kernel + def combine(): + for i in range(N): + c[i] = a[i] + b[i] + + parallel_phase() + combine() + qd.sync() + assert np.allclose(a.to_numpy(), 2.0) + assert np.allclose(b.to_numpy(), 10.0) + assert np.allclose(c.to_numpy(), 12.0) + + +@test_utils.test() +def test_stream_parallel_timing(): + """stream_parallel achieves speedup on GPU, serial fallback elsewhere.""" + SPIN_ITERS = 5_000_000 + + a = qd.field(qd.i32, shape=(2,)) + b = qd.field(qd.i32, shape=(2,)) + + @qd.kernel + def serial_spin(): + for _ in range(1): + x = a[0] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + a[0] = x + for _ in range(1): + x = a[1] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + a[1] = x + + @qd.kernel + def parallel_spin(): + with qd.stream_parallel(): + for _ in range(1): + x = b[0] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + b[0] = x + with qd.stream_parallel(): + for _ in range(1): + x = b[1] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + b[1] = x + + import time + + # Warm up + serial_spin() + parallel_spin() + qd.sync() + + qd.sync() + t0 = time.perf_counter() + serial_spin() + qd.sync() + serial_time = time.perf_counter() - t0 + + qd.sync() + t0 = time.perf_counter() + parallel_spin() + qd.sync() + stream_time = time.perf_counter() - t0 + + speedup = serial_time / stream_time + if qd.lang.impl.current_cfg().arch in (qd.cuda, qd.amdgpu): + assert speedup > 1.5, ( + f"Expected >1.5x speedup, got {speedup:.2f}x " f"(serial={serial_time:.3f}s, stream={stream_time:.3f}s)" + ) + else: + assert speedup > 0.75, ( + f"Expected >=0.75x (serial fallback), got {speedup:.2f}x " + f"(serial={serial_time:.3f}s, stream={stream_time:.3f}s)" + ) + + +@test_utils.test() +def test_stream_parallel_rejects_mixed_top_level(): + """Mixing stream_parallel and non-stream_parallel at top level is an error.""" + import pytest # noqa: I001 + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + with pytest.raises(QuadrantsSyntaxError, match="all top-level statements"): + + @qd.kernel + def bad_kernel(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + for i in range(N): + a[i] = 2.0 + + bad_kernel() + + +@test_utils.test() +def test_stream_with_ndarray(): + N = 1024 + + @qd.kernel + def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)): + for i in range(N): + arr[i] = 99.0 + + arr = qd.ndarray(qd.f32, shape=(N,)) + s = qd.create_stream() + fill(arr, qd_stream=s) + s.synchronize() + assert np.allclose(arr.to_numpy(), 99.0) + s.destroy()