Skip to content

Commit dd9e377

Browse files
committed
feat: add multi-stream and operator fusion conflict detection
- Add operator_fusion_enabled flag to ggml_backend_cann_context
- Implement conflict detection in constructor:
  * ACL graph mode disables multi-stream (higher performance)
  * Multi-stream mode disables operator fusion (low benefit)
- Remove multi-stream fusion code (fusion disabled in multi-stream)
- Keep fusion functionality in single-stream mode
- Remove redundant multi_stream_enabled check in graph_compute
- Fix unused variable warning (sync_all_to_stream)
1 parent 4951a4f commit dd9e377

2 files changed

Lines changed: 38 additions & 55 deletions

File tree

ggml/src/ggml-cann/common.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,9 @@ struct ggml_backend_cann_context {
571571
aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */
572572
std::vector<const ggml_tensor *> unsynced_nodes; /**< Nodes that have been executed but not synced. */
573573

574+
// Operator fusion support
575+
bool operator_fusion_enabled = false; /**< Whether operator fusion is enabled. */
576+
574577
/**
575578
* @brief Constructor for initializing the context with a given device.
576579
* @param device Device ID.
@@ -584,6 +587,38 @@ struct ggml_backend_cann_context {
584587
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
585588
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
586589
#endif
590+
591+
// Read environment variables for multi-stream and operator fusion
592+
bool env_multi_stream = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
593+
bool env_operator_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
594+
595+
// Handle conflicts and set final values
596+
#ifdef USE_ACL_GRAPH
597+
if (acl_graph_mode && env_multi_stream) {
598+
// ACL graph has higher performance, disable multi-stream
599+
multi_stream_enabled = false;
600+
operator_fusion_enabled = env_operator_fusion;
601+
GGML_LOG_INFO("%s: device %d multi-stream disabled (ACL graph mode has higher performance)\n",
602+
__func__, device);
603+
} else
604+
#endif
605+
if (env_multi_stream) {
606+
// Multi-stream enabled, disable operator fusion (fusion has low benefit with multi-stream)
607+
multi_stream_enabled = true;
608+
operator_fusion_enabled = false;
609+
if (env_operator_fusion) {
610+
GGML_LOG_INFO("%s: device %d operator fusion disabled (low benefit with multi-stream enabled)\n",
611+
__func__, device);
612+
}
613+
GGML_LOG_INFO("%s: device %d multi-stream execution enabled\n", __func__, device);
614+
} else {
615+
// Default single-stream mode
616+
multi_stream_enabled = false;
617+
operator_fusion_enabled = env_operator_fusion;
618+
if (env_operator_fusion) {
619+
GGML_LOG_INFO("%s: device %d operator fusion enabled\n", __func__, device);
620+
}
621+
}
587622
}
588623

589624
/**

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2226,13 +2226,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
22262226
#endif // USE_ACL_GRAPH
22272227
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
22282228
// With the use of CANN graphs, the execution will be performed by the graph launch.
2229-
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
2230-
2231-
// Check if multi-stream execution is enabled
2232-
static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
22332229

22342230
if (!use_cann_graph || cann_graph_capture_required) {
2235-
if (multi_stream_enabled) {
2231+
if (cann_ctx->multi_stream_enabled) {
22362232
// Multi-stream execution mode using memory-based dependency tracking
22372233
// Note: multi_stream_enabled implies !use_cann_graph (set in graph_compute)
22382234
// Track data pointers that have pending writes on each stream
@@ -2247,23 +2243,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
22472243
}
22482244
}
22492245

2250-
// Helper lambda to synchronize all active streams to the target stream
2251-
auto sync_all_to_stream = [&](int target_stream) {
2252-
if (active_streams.empty()) return;
2253-
2254-
// Record events on all active streams
2255-
for (int s : active_streams) {
2256-
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s)));
2257-
}
2258-
// Wait for all events on the target stream
2259-
for (int s : active_streams) {
2260-
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s]));
2261-
}
2262-
// Clear tracking
2263-
pending_writes.clear();
2264-
active_streams.clear();
2265-
};
2266-
22672246
// Helper lambda to wait for a specific stream on the target stream
22682247
auto wait_for_stream = [&](int src_stream, int target_stream) {
22692248
if (src_stream == target_stream) return;
@@ -2273,25 +2252,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
22732252

22742253
for (int i = 0; i < cgraph->n_nodes; i++) {
22752254
ggml_tensor * node = cgraph->nodes[i];
2276-
if (opt_fusion) {
2277-
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
2278-
// Fusion ops need synchronization - execute on stream 0
2279-
sync_all_to_stream(0);
2280-
2281-
// Execute fused op on stream 0
2282-
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
2283-
2284-
// Track the output
2285-
void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]);
2286-
if (out_ptr) {
2287-
pending_writes[out_ptr] = 0;
2288-
active_streams.insert(0);
2289-
}
2290-
i++;
2291-
current_stream = 1; // Next node goes to stream 1
2292-
continue;
2293-
}
2294-
}
22952255

22962256
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
22972257
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -2373,7 +2333,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
23732333
// Single-stream execution mode (original behavior)
23742334
for (int i = 0; i < cgraph->n_nodes; i++) {
23752335
ggml_tensor * node = cgraph->nodes[i];
2376-
if (opt_fusion) {
2336+
if (cann_ctx->operator_fusion_enabled) {
23772337
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
23782338
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
23792339
i++;
@@ -2438,14 +2398,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
24382398
#ifdef USE_ACL_GRAPH
24392399
bool use_cann_graph = true;
24402400

2441-
// Check if multi-stream execution is enabled (must check before using use_cann_graph)
2442-
static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
2443-
2444-
// Multi-stream mode is incompatible with ACL graph capture/execution
2445-
if (multi_stream_enabled) {
2446-
use_cann_graph = false;
2447-
}
2448-
24492401
if (use_cann_graph) {
24502402
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
24512403
if (!prefill_use_graph) {
@@ -2839,11 +2791,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_ev
28392791
*/
28402792
static void ggml_backend_cann_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) {
28412793
// Check if graph optimization is disabled via environment variable
2842-
static bool disable_graph_optimize = [] {
2843-
const char * env = getenv("GGML_CANN_DISABLE_GRAPH_OPTIMIZE");
2844-
return env != nullptr;
2845-
}();
2846-
2794+
static bool disable_graph_optimize = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_GRAPH_OPTIMIZE").value_or(""));
28472795
if (disable_graph_optimize) {
28482796
return;
28492797
}

0 commit comments

Comments
 (0)