From 0457d0b7994066c4629b38471ae14d39b8a7b8d4 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Tue, 2 Dec 2025 12:30:32 -0800 Subject: [PATCH 01/13] Adding support for CUDA graph capture. Signed-off-by: Josh Romero --- docs/api/config.rst | 36 +++- src/csrc/include/internal/cuda_graphs.h | 246 ++++++++++++++++++++++++ src/csrc/include/internal/model_pack.h | 8 + src/csrc/include/internal/model_state.h | 3 + src/csrc/setup.cpp | 9 +- src/csrc/torchfort.cpp | 8 + src/csrc/training.cpp | 198 +++++++++++++++++-- 7 files changed, 485 insertions(+), 23 deletions(-) create mode 100644 src/csrc/include/internal/cuda_graphs.h diff --git a/docs/api/config.rst b/docs/api/config.rst index 93edf7d..8a06f6a 100644 --- a/docs/api/config.rst +++ b/docs/api/config.rst @@ -31,18 +31,34 @@ The block in the configuration file defining general properties takes the follow The following table lists the available options: -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ -| Option | Data Type | Description | -+=======================+===========+================================================================================================+ -| ``report_frequency`` | integer | frequency of reported TorchFort training/validation output lines to terminal (default = ``0``) | -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ -| ``enable_wandb_hook`` | boolean | flag to control whether wandb hook is active (default = ``false``) | -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ -| ``verbose`` | boolean | flag to control verbose output from TorchFort (default = ``false``) | -+-----------------------+-----------+------------------------------------------------------------------------------------------------+ 
++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| Option | Data Type | Description | ++========================+===========+=================================================================================================+ +| ``report_frequency`` | integer | frequency of reported TorchFort training/validation output lines to terminal (default = ``0``) | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| ``enable_wandb_hook`` | boolean | flag to control whether wandb hook is active (default = ``false``) | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| ``verbose`` | boolean | flag to control verbose output from TorchFort (default = ``false``) | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ +| ``enable_cuda_graphs`` | boolean | flag to enable CUDA graph capture for training and inference (default = ``false``). See below. | ++------------------------+-----------+-------------------------------------------------------------------------------------------------+ For more information about the wandb hook, see :ref:`wandb_support-ref`. +CUDA Graphs +^^^^^^^^^^^ +When ``enable_cuda_graphs`` is set to ``true``, TorchFort will capture CUDA graphs for the forward pass (inference) +and the forward + loss + backward pass (training). CUDA graphs can significantly reduce kernel launch overhead +and improve performance for models with many small operations. + +**Requirements and limitations:** + +- Input tensors must be on GPU and must have consistent data pointers, shapes, and dtypes across all training/inference calls with the captured model. + If inputs change after graph capture, an error will be thrown. 
+- For training, CUDA graph capture is automatically disabled when gradient accumulation (``grad_accumulation_steps > 1``) is active. +- The optimizer step and learning rate scheduler updates are not captured in the graph. +- A warmup period of 3 iterations is performed before graph capture to ensure stable execution. + .. _optimizer_properties-ref: Optimizer Properties @@ -613,4 +629,4 @@ Refer to the :ref:`lr_schedule_properties-ref` for available scheduler types and General Remarks ~~~~~~~~~~~~~~~ -Example YAML files for training the different algorithms are available in the `tests/rl/configs <<../../tests/rl/configs/>>`_ directory. \ No newline at end of file +Example YAML files for training the different algorithms are available in the `tests/rl/configs <<../../tests/rl/configs/>>`_ directory. diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h new file mode 100644 index 0000000..9bb9ccd --- /dev/null +++ b/src/csrc/include/internal/cuda_graphs.h @@ -0,0 +1,246 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#ifdef ENABLE_GPU + +#include + +#include +#include + +#include +#include + +#include "internal/defines.h" +#include "internal/exceptions.h" + +namespace torchfort { + +// RAII wrapper for cudaGraph_t +class CudaGraph { +public: + CudaGraph() : graph_(nullptr) {} + ~CudaGraph() { + if (graph_) { + cudaGraphDestroy(graph_); + } + } + + // Non-copyable + CudaGraph(const CudaGraph&) = delete; + CudaGraph& operator=(const CudaGraph&) = delete; + + // Movable + CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { + other.graph_ = nullptr; + } + CudaGraph& operator=(CudaGraph&& other) noexcept { + if (this != &other) { + if (graph_) cudaGraphDestroy(graph_); + graph_ = other.graph_; + other.graph_ = nullptr; + } + return *this; + } + + cudaGraph_t& get() { return graph_; } + cudaGraph_t get() const { return graph_; } + bool valid() const { return graph_ != nullptr; } + +private: + cudaGraph_t graph_; +}; + +// RAII wrapper for cudaGraphExec_t +class CudaGraphExec { +public: + CudaGraphExec() : exec_(nullptr) {} + ~CudaGraphExec() { + if (exec_) { + cudaGraphExecDestroy(exec_); + } + } + + // Non-copyable + CudaGraphExec(const CudaGraphExec&) = delete; + CudaGraphExec& operator=(const CudaGraphExec&) = delete; + + // Movable + CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { + other.exec_ = nullptr; + } + CudaGraphExec& operator=(CudaGraphExec&& other) noexcept { + if (this != &other) { + if (exec_) cudaGraphExecDestroy(exec_); + exec_ = other.exec_; + other.exec_ = nullptr; + } + return *this; + } + + cudaGraphExec_t& get() { return exec_; } + cudaGraphExec_t get() const { return exec_; } + bool valid() const { return exec_ != nullptr; } + + // Launch the graph on a stream + void launch(cudaStream_t stream) { + if (exec_) { + CHECK_CUDA(cudaGraphLaunch(exec_, stream)); + } + } + +private: + cudaGraphExec_t exec_; +}; + +// Input signature for validating consistent inputs +struct InputSignature { + std::vector ptrs; + 
std::vector> shapes; + std::vector dtypes; + + bool operator==(const InputSignature& other) const { + return ptrs == other.ptrs && shapes == other.shapes && dtypes == other.dtypes; + } + + bool operator!=(const InputSignature& other) const { + return !(*this == other); + } + + bool empty() const { return ptrs.empty(); } +}; + +// Helper to create input signature from tensor list +inline InputSignature make_input_signature(const std::vector& tensors) { + InputSignature sig; + sig.ptrs.reserve(tensors.size()); + sig.shapes.reserve(tensors.size()); + sig.dtypes.reserve(tensors.size()); + for (const auto& t : tensors) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; +} + +// Helper to create input signature from multiple tensor lists (for training) +inline InputSignature make_input_signature(const std::vector& inputs, + const std::vector& labels, + const std::vector& extra_args) { + InputSignature sig; + size_t total = inputs.size() + labels.size() + extra_args.size(); + sig.ptrs.reserve(total); + sig.shapes.reserve(total); + sig.dtypes.reserve(total); + + for (const auto& t : inputs) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : labels) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : extra_args) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; +} + +// Validate that current inputs match the captured signature +inline void validate_input_signature(const InputSignature& expected, + const InputSignature& actual, + const char* context) { + if (expected != actual) { + std::stringstream ss; + ss << "CUDA graph input mismatch in " << context << ". 
" + << "When cuda_graphs is enabled, input tensors must have consistent " + << "data pointers, shapes, and dtypes across all calls. " + << "If you need to change inputs, disable cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } +} + +// Graph state for inference +struct InferenceGraphState { + int warmup_count = 0; + bool captured = false; + + InputSignature input_signature; + CudaGraph graph; + CudaGraphExec graph_exec; + std::vector static_outputs; +}; + +// Graph state for training (single graph for forward + loss + backward) +struct TrainingGraphState { + int warmup_count = 0; + bool captured = false; + + InputSignature input_signature; + CudaGraph graph; + CudaGraphExec graph_exec; + torch::Tensor static_loss; +}; + +// Graph state for a model, including the capture stream +class ModelGraphState { +public: + InferenceGraphState inference; + TrainingGraphState training; + + ModelGraphState(int device_index = 0) + : capture_stream_(nullptr), device_index_(device_index) { + // Create a non-blocking stream for graph capture + CHECK_CUDA(cudaSetDevice(device_index_)); + CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); + } + + ~ModelGraphState() { + if (capture_stream_) { + cudaStreamDestroy(capture_stream_); + } + } + + // Non-copyable + ModelGraphState(const ModelGraphState&) = delete; + ModelGraphState& operator=(const ModelGraphState&) = delete; + + cudaStream_t capture_stream() const { return capture_stream_; } + int device_index() const { return device_index_; } + + // Get c10 stream wrapper for the capture stream (for PyTorch integration) + c10::cuda::CUDAStream get_capture_cuda_stream() const { + return c10::cuda::getStreamFromExternal(capture_stream_, device_index_); + } + +private: + cudaStream_t capture_stream_; + int device_index_; +}; + +} // namespace torchfort + +#endif // ENABLE_GPU + diff --git a/src/csrc/include/internal/model_pack.h b/src/csrc/include/internal/model_pack.h index 351e96d..f09b331 100644 --- 
a/src/csrc/include/internal/model_pack.h +++ b/src/csrc/include/internal/model_pack.h @@ -25,6 +25,9 @@ #include "internal/distributed.h" #include "internal/model_state.h" #include "internal/model_wrapper.h" +#ifdef ENABLE_GPU +#include "internal/cuda_graphs.h" +#endif namespace torchfort { @@ -38,6 +41,11 @@ struct ModelPack { std::shared_ptr state; int grad_accumulation_steps = 1; float max_grad_norm = 0.0; + +#ifdef ENABLE_GPU + // CUDA graph state (initialized if enable_cuda_graphs is true) + std::shared_ptr graph_state; +#endif }; void save_model_pack(const ModelPack& model_pack, const std::string& fname, bool save_optimizer = true); diff --git a/src/csrc/include/internal/model_state.h b/src/csrc/include/internal/model_state.h index c4c2111..ff63aba 100644 --- a/src/csrc/include/internal/model_state.h +++ b/src/csrc/include/internal/model_state.h @@ -36,6 +36,9 @@ struct ModelState { bool verbose; std::filesystem::path report_file; + // CUDA graph settings + bool enable_cuda_graphs = false; + void save(const std::string& fname); void load(const std::string& fname); }; diff --git a/src/csrc/setup.cpp b/src/csrc/setup.cpp index 8b083c0..80a4406 100644 --- a/src/csrc/setup.cpp +++ b/src/csrc/setup.cpp @@ -262,7 +262,8 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ if (state_node["general"]) { auto params = get_params(state_node["general"]); - std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose"}; + std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose", + "enable_cuda_graphs"}; check_params(supported_params, params.keys()); try { @@ -311,6 +312,12 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ } catch (std::out_of_range) { state->verbose = false; } + + try { + state->enable_cuda_graphs = params.get_param("enable_cuda_graphs")[0]; + } catch (std::out_of_range) { + state->enable_cuda_graphs = false; + } } return state; diff --git a/src/csrc/torchfort.cpp 
b/src/csrc/torchfort.cpp index b973ff9..cff1e77 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -125,6 +125,14 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f // Setting up general options models[name].state = get_state(name, config); + +#ifdef ENABLE_GPU + // Initialize graph state if CUDA graphs are enabled + if (models[name].state->enable_cuda_graphs && models[name].model->device().is_cuda()) { + models[name].graph_state = std::make_shared( + models[name].model->device().index()); + } +#endif } catch (const BaseException& e) { std::cerr << e.what(); return e.getResult(); diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 77c7095..49bf817 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -28,14 +28,38 @@ #include "internal/defines.h" #include "internal/logging.h" +#include "internal/model_pack.h" #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" +#ifdef ENABLE_GPU +#include "internal/cuda_graphs.h" +#endif namespace torchfort { // Declaration of external global variables extern std::unordered_map models; +#ifdef ENABLE_GPU +// Number of warmup iterations before CUDA graph capture +constexpr int kCudaGraphWarmupIters = 3; + +// Helper to instantiate a CUDA graph from a captured graph +void instantiate_graph(CudaGraph& graph, CudaGraphExec& exec) { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = cudaGraphInstantiate(&exec.get(), graph.get(), &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed: " << cudaGetErrorString(result); + if (strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } +} +#endif + void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in, cudaStream_t ext_stream = 0) { 
torchfort::nvtx::rangePush("torchfort_inference"); @@ -55,7 +79,7 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor auto model = models[name].model; -#if ENABLE_GPU +#ifdef ENABLE_GPU c10::cuda::OptionalCUDAStreamGuard stream_guard; c10::cuda::OptionalCUDAGuard cuda_guard; set_device_and_stream(stream_guard, cuda_guard, model->device(), ext_stream); @@ -64,9 +88,79 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor inputs->to(model->device()); model->eval(); - auto results = model->forward(inputs->tensors); - for (int i = 0; i < results.size(); ++i) { + std::vector results; + +#ifdef ENABLE_GPU + // CUDA graph handling + bool capturing = false; + InferenceGraphState* graph_state = nullptr; + cudaStream_t capture_stream = nullptr; + + if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { + + graph_state = &models[name].graph_state->inference; + capture_stream = models[name].graph_state->capture_stream(); + + // Create input signature for validation + InputSignature current_sig = make_input_signature(inputs->tensors); + + if (graph_state->captured) { + + validate_input_signature(graph_state->input_signature, current_sig, "inference"); + + } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { + + // Store input signature used during capture + graph_state->input_signature = current_sig; + + // Synchronize user stream before capture + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + + // Switch PyTorch to use our capture stream + auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); + guard.reset_stream(capture_c10_stream); + + // Begin capture + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + capturing = true; + } + } +#endif + + // Forward pass +#ifdef ENABLE_GPU + if (!graph_state || !graph_state->captured) { +#endif + results = model->forward(inputs->tensors); +#ifdef ENABLE_GPU + 
if (graph_state) graph_state->warmup_count++; + } + + if (graph_state) { + + if (capturing) { + // End capture and instantiate the graph + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); + instantiate_graph(graph_state->graph, graph_state->graph_exec); + graph_state->static_outputs = results; + graph_state->captured = true; + + // Switch back to user stream for replay and subsequent operations + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); + guard.reset_stream(user_c10_stream); + } + + // Replay graph + if (graph_state->captured) { + graph_state->graph_exec.launch(user_stream); + results = graph_state->static_outputs; + } + } +#endif + + // Copy results to output tensors + for (size_t i = 0; i < results.size(); ++i) { outputs->tensors[i].copy_(results[i].reshape(outputs->tensors[i].sizes())); } models[name].state->step_inference++; @@ -116,21 +210,101 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo auto opt = models[name].optimizer; auto state = models[name].state; - // fwd pass - auto results = model->forward(inputs->tensors); - auto loss = models[name].loss->forward(results, labels->tensors, - (extra_loss_args) ? extra_loss_args->tensors : std::vector()); + torch::Tensor loss; - // extract loss - *loss_val = loss.item(); +#ifdef ENABLE_GPU + // CUDA graph handling + bool capturing = false; + TrainingGraphState* graph_state = nullptr; + cudaStream_t capture_stream = nullptr; + + if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state && + models[name].grad_accumulation_steps == 1) { + + // Note: CUDA graph capture for training is disabled if gradient accumulation is active + + graph_state = &models[name].graph_state->training; + capture_stream = models[name].graph_state->capture_stream(); + + // Create input signature for validation + std::vector extra_args_vec = extra_loss_args ? 
extra_loss_args->tensors : std::vector(); + InputSignature current_sig = make_input_signature(inputs->tensors, labels->tensors, extra_args_vec); + + if (graph_state->captured) { + + validate_input_signature(graph_state->input_signature, current_sig, "training"); + + } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { + + // Store input signature used during capture + graph_state->input_signature = current_sig; + capturing = true; + } + } +#endif - // bwd pass if (state->step_train_current % models[name].grad_accumulation_steps == 0) { - opt->zero_grad(); +#ifdef ENABLE_GPU + // Only explicitly call zero_grad for non-replay steps + if (!graph_state || !graph_state->captured) { +#endif + opt->zero_grad(/*set_to_none=*/true); +#ifdef ENABLE_GPU + } +#endif } - loss.backward(); +#ifdef ENABLE_GPU + if (capturing) { + // Synchronize user stream before capture + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + + // Switch PyTorch to use our capture stream + auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); + guard.reset_stream(capture_c10_stream); + + // Begin capture on our non-blocking stream + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + } +#endif + + // Forward + loss + backward +#ifdef ENABLE_GPU + if (!graph_state || !graph_state->captured) { +#endif + auto fwd_results = model->forward(inputs->tensors); + loss = models[name].loss->forward(fwd_results, labels->tensors, + (extra_loss_args) ? 
extra_loss_args->tensors : std::vector()); + loss.backward(); +#ifdef ENABLE_GPU + if (graph_state) graph_state->warmup_count++; + } + + if (graph_state) { + if (capturing) { + // End graph capture and instantiate + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); + instantiate_graph(graph_state->graph, graph_state->graph_exec); + graph_state->static_loss = loss; + graph_state->captured = true; + + // Switch back to user stream for replay and subsequent operations + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); + guard.reset_stream(user_c10_stream); + } + + // Replay graph + if (graph_state->captured) { + graph_state->graph_exec.launch(user_stream); + loss = graph_state->static_loss; + } + } +#endif + + // Extract loss value + *loss_val = loss.item(); + // Optimizer step and related operations if ((state->step_train_current + 1) % models[name].grad_accumulation_steps == 0) { // allreduce (average) gradients (if running distributed) if (models[name].comm) { From 6ac8cdc1480e841eb4f5f9a8bd70e115c3c69795 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 12:56:46 -0800 Subject: [PATCH 02/13] Simplifying and cleaning up implementation. 
Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 377 ++++++++++++++++++------ src/csrc/training.cpp | 136 ++------- 2 files changed, 301 insertions(+), 212 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 9bb9ccd..2eb8838 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -21,9 +21,11 @@ #include +#include #include #include +#include #include #include @@ -32,6 +34,16 @@ namespace torchfort { +// Number of warmup iterations before CUDA graph capture +constexpr int kCudaGraphWarmupIters = 3; + +// Action to take for current iteration +enum class GraphAction { + WARMUP, // Run eager execution, increment warmup count + CAPTURE, // Run eager execution with graph capture + REPLAY // Skip eager execution, replay captured graph +}; + // RAII wrapper for cudaGraph_t class CudaGraph { public: @@ -98,117 +110,293 @@ class CudaGraphExec { cudaGraphExec_t get() const { return exec_; } bool valid() const { return exec_ != nullptr; } - // Launch the graph on a stream - void launch(cudaStream_t stream) { - if (exec_) { - CHECK_CUDA(cudaGraphLaunch(exec_, stream)); - } - } - private: cudaGraphExec_t exec_; }; -// Input signature for validating consistent inputs -struct InputSignature { - std::vector ptrs; - std::vector> shapes; - std::vector dtypes; +// Graph state for inference +class InferenceGraphState { +public: + InferenceGraphState(const char* context = "inference") : context_(context) {} + + // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture + // Returns the action to take. Call begin_capture() after this if action == CAPTURE. 
+ GraphAction prepare(const std::vector& inputs) { + InputSignature current_sig = make_input_signature(inputs); - bool operator==(const InputSignature& other) const { - return ptrs == other.ptrs && shapes == other.shapes && dtypes == other.dtypes; + if (captured_) { + validate_inputs(current_sig); + return GraphAction::REPLAY; + } + + if (warmup_count_ == kCudaGraphWarmupIters) { + input_signature_ = std::move(current_sig); + return GraphAction::CAPTURE; + } + + return GraphAction::WARMUP; } - bool operator!=(const InputSignature& other) const { - return !(*this == other); + // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work + void begin_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); + guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); } - bool empty() const { return ptrs.empty(); } -}; + // Finalize after forward pass - handles capture end or warmup increment + void finalize(GraphAction action, + cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index, + const std::vector& outputs) { + if (action == GraphAction::CAPTURE) { + end_capture(capture_stream, user_stream, guard, device_index); + static_outputs_ = outputs; + } else if (action == GraphAction::WARMUP) { + warmup_count_++; + } + } -// Helper to create input signature from tensor list -inline InputSignature make_input_signature(const std::vector& tensors) { - InputSignature sig; - sig.ptrs.reserve(tensors.size()); - sig.shapes.reserve(tensors.size()); - sig.dtypes.reserve(tensors.size()); - for (const auto& t : tensors) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - 
sig.dtypes.push_back(t.scalar_type()); - } - return sig; -} - -// Helper to create input signature from multiple tensor lists (for training) -inline InputSignature make_input_signature(const std::vector& inputs, - const std::vector& labels, - const std::vector& extra_args) { - InputSignature sig; - size_t total = inputs.size() + labels.size() + extra_args.size(); - sig.ptrs.reserve(total); - sig.shapes.reserve(total); - sig.dtypes.reserve(total); - - for (const auto& t : inputs) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - for (const auto& t : labels) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - for (const auto& t : extra_args) { - sig.ptrs.push_back(t.data_ptr()); - sig.shapes.push_back(t.sizes().vec()); - sig.dtypes.push_back(t.scalar_type()); - } - return sig; -} - -// Validate that current inputs match the captured signature -inline void validate_input_signature(const InputSignature& expected, - const InputSignature& actual, - const char* context) { - if (expected != actual) { - std::stringstream ss; - ss << "CUDA graph input mismatch in " << context << ". " - << "When cuda_graphs is enabled, input tensors must have consistent " - << "data pointers, shapes, and dtypes across all calls. 
" - << "If you need to change inputs, disable cuda_graphs."; - THROW_INVALID_USAGE(ss.str()); - } -} + // Launch captured graph on the given stream + void launch(cudaStream_t stream) { + CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); + } -// Graph state for inference -struct InferenceGraphState { - int warmup_count = 0; - bool captured = false; - - InputSignature input_signature; - CudaGraph graph; - CudaGraphExec graph_exec; - std::vector static_outputs; + // Get static outputs (valid after CAPTURE or REPLAY) + const std::vector& get_outputs() const { return static_outputs_; } + + bool is_captured() const { return captured_; } + +private: + // Input signature for validating consistent inputs + struct InputSignature { + std::vector ptrs; + std::vector> shapes; + std::vector dtypes; + + bool operator!=(const InputSignature& other) const { + return ptrs != other.ptrs || shapes != other.shapes || dtypes != other.dtypes; + } + }; + + static InputSignature make_input_signature(const std::vector& tensors) { + InputSignature sig; + sig.ptrs.reserve(tensors.size()); + sig.shapes.reserve(tensors.size()); + sig.dtypes.reserve(tensors.size()); + for (const auto& t : tensors) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; + } + + void validate_inputs(const InputSignature& current_sig) const { + if (input_signature_ != current_sig) { + std::stringstream ss; + ss << "CUDA graph input mismatch in " << context_ << ". " + << "When enable_cuda_graphs is set, input tensors must have consistent " + << "data pointers, shapes, and dtypes across all calls. 
" + << "If you need to change inputs, disable enable_cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } + } + + void end_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + instantiate_graph(); + captured_ = true; + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); + guard.reset_stream(user_c10_stream); + } + + void instantiate_graph() { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), + &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed in " << context_ << ": " + << cudaGetErrorString(result); + if (std::strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } + } + + const char* context_; + int warmup_count_ = 0; + bool captured_ = false; + InputSignature input_signature_; + CudaGraph graph_; + CudaGraphExec graph_exec_; + std::vector static_outputs_; }; // Graph state for training (single graph for forward + loss + backward) -struct TrainingGraphState { - int warmup_count = 0; - bool captured = false; - - InputSignature input_signature; - CudaGraph graph; - CudaGraphExec graph_exec; - torch::Tensor static_loss; +class TrainingGraphState { +public: + TrainingGraphState(const char* context = "training") : context_(context) {} + + // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture + // Returns the action to take. Call begin_capture() after this if action == CAPTURE. 
+ GraphAction prepare(const std::vector& inputs, + const std::vector& labels, + const std::vector& extra_args) { + InputSignature current_sig = make_input_signature(inputs, labels, extra_args); + + if (captured_) { + validate_inputs(current_sig); + return GraphAction::REPLAY; + } + + if (warmup_count_ == kCudaGraphWarmupIters) { + input_signature_ = std::move(current_sig); + return GraphAction::CAPTURE; + } + + return GraphAction::WARMUP; + } + + // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work + void begin_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamSynchronize(user_stream)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); + guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + } + + // Finalize after forward+loss+backward pass - handles capture end or warmup increment + void finalize(GraphAction action, + cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index, + const torch::Tensor& loss) { + if (action == GraphAction::CAPTURE) { + end_capture(capture_stream, user_stream, guard, device_index); + static_loss_ = loss; + } else if (action == GraphAction::WARMUP) { + warmup_count_++; + } + } + + // Launch captured graph on the given stream + void launch(cudaStream_t stream) { + CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); + } + + // Get static loss (valid after CAPTURE or REPLAY) + const torch::Tensor& get_loss() const { return static_loss_; } + + bool is_captured() const { return captured_; } + +private: + // Input signature for validating consistent inputs + struct InputSignature { + std::vector ptrs; + std::vector> shapes; + std::vector dtypes; + + bool operator!=(const InputSignature& other) const { + return ptrs != 
other.ptrs || shapes != other.shapes || dtypes != other.dtypes; + } + }; + + static InputSignature make_input_signature(const std::vector& inputs, + const std::vector& labels, + const std::vector& extra_args) { + InputSignature sig; + size_t total = inputs.size() + labels.size() + extra_args.size(); + sig.ptrs.reserve(total); + sig.shapes.reserve(total); + sig.dtypes.reserve(total); + + for (const auto& t : inputs) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : labels) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + for (const auto& t : extra_args) { + sig.ptrs.push_back(t.data_ptr()); + sig.shapes.push_back(t.sizes().vec()); + sig.dtypes.push_back(t.scalar_type()); + } + return sig; + } + + void validate_inputs(const InputSignature& current_sig) const { + if (input_signature_ != current_sig) { + std::stringstream ss; + ss << "CUDA graph input mismatch in " << context_ << ". " + << "When enable_cuda_graphs is set, input tensors must have consistent " + << "data pointers, shapes, and dtypes across all calls. 
" + << "If you need to change inputs, disable enable_cuda_graphs."; + THROW_INVALID_USAGE(ss.str()); + } + } + + void end_capture(cudaStream_t capture_stream, + cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, + int device_index) { + CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + instantiate_graph(); + captured_ = true; + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); + guard.reset_stream(user_c10_stream); + } + + void instantiate_graph() { + cudaGraphNode_t error_node; + char log_buffer[1024]; + cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), + &error_node, log_buffer, sizeof(log_buffer)); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "CUDA graph instantiation failed in " << context_ << ": " + << cudaGetErrorString(result); + if (std::strlen(log_buffer) > 0) { + ss << " Log: " << log_buffer; + } + THROW_INTERNAL_ERROR(ss.str()); + } + } + + const char* context_; + int warmup_count_ = 0; + bool captured_ = false; + InputSignature input_signature_; + CudaGraph graph_; + CudaGraphExec graph_exec_; + torch::Tensor static_loss_; }; // Graph state for a model, including the capture stream class ModelGraphState { public: - InferenceGraphState inference; - TrainingGraphState training; + InferenceGraphState inference{"inference"}; + TrainingGraphState training{"training"}; ModelGraphState(int device_index = 0) : capture_stream_(nullptr), device_index_(device_index) { @@ -230,11 +418,6 @@ class ModelGraphState { cudaStream_t capture_stream() const { return capture_stream_; } int device_index() const { return device_index_; } - // Get c10 stream wrapper for the capture stream (for PyTorch integration) - c10::cuda::CUDAStream get_capture_cuda_stream() const { - return c10::cuda::getStreamFromExternal(capture_stream_, device_index_); - } - private: cudaStream_t capture_stream_; int device_index_; diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp 
index 49bf817..7e909e5 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -40,26 +40,6 @@ namespace torchfort { // Declaration of external global variables extern std::unordered_map models; -#ifdef ENABLE_GPU -// Number of warmup iterations before CUDA graph capture -constexpr int kCudaGraphWarmupIters = 3; - -// Helper to instantiate a CUDA graph from a captured graph -void instantiate_graph(CudaGraph& graph, CudaGraphExec& exec) { - cudaGraphNode_t error_node; - char log_buffer[1024]; - cudaError_t result = cudaGraphInstantiate(&exec.get(), graph.get(), &error_node, log_buffer, sizeof(log_buffer)); - if (result != cudaSuccess) { - std::stringstream ss; - ss << "CUDA graph instantiation failed: " << cudaGetErrorString(result); - if (strlen(log_buffer) > 0) { - ss << " Log: " << log_buffer; - } - THROW_INTERNAL_ERROR(ss.str()); - } -} -#endif - void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfort_tensor_list_t outputs_in, cudaStream_t ext_stream = 0) { torchfort::nvtx::rangePush("torchfort_inference"); @@ -92,69 +72,33 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor std::vector results; #ifdef ENABLE_GPU - // CUDA graph handling - bool capturing = false; + GraphAction action = GraphAction::WARMUP; InferenceGraphState* graph_state = nullptr; - cudaStream_t capture_stream = nullptr; if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { - graph_state = &models[name].graph_state->inference; - capture_stream = models[name].graph_state->capture_stream(); - - // Create input signature for validation - InputSignature current_sig = make_input_signature(inputs->tensors); - - if (graph_state->captured) { - - validate_input_signature(graph_state->input_signature, current_sig, "inference"); - - } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { - - // Store input signature used during capture - graph_state->input_signature = 
current_sig; - - // Synchronize user stream before capture - CHECK_CUDA(cudaStreamSynchronize(user_stream)); - - // Switch PyTorch to use our capture stream - auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); - guard.reset_stream(capture_c10_stream); + action = graph_state->prepare(inputs->tensors); - // Begin capture - CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); - capturing = true; + if (action == GraphAction::CAPTURE) { + graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); } } #endif // Forward pass #ifdef ENABLE_GPU - if (!graph_state || !graph_state->captured) { + if (action != GraphAction::REPLAY) { #endif results = model->forward(inputs->tensors); #ifdef ENABLE_GPU - if (graph_state) graph_state->warmup_count++; } if (graph_state) { + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), results); - if (capturing) { - // End capture and instantiate the graph - CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); - instantiate_graph(graph_state->graph, graph_state->graph_exec); - graph_state->static_outputs = results; - graph_state->captured = true; - - // Switch back to user stream for replay and subsequent operations - auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); - guard.reset_stream(user_c10_stream); - } - - // Replay graph - if (graph_state->captured) { - graph_state->graph_exec.launch(user_stream); - results = graph_state->static_outputs; + if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { + graph_state->launch(ext_stream); + results = graph_state->get_outputs(); } } #endif @@ -213,40 +157,22 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo torch::Tensor loss; #ifdef ENABLE_GPU - // CUDA graph handling - bool capturing = false; + 
GraphAction action = GraphAction::WARMUP; TrainingGraphState* graph_state = nullptr; - cudaStream_t capture_stream = nullptr; + // Note: CUDA graph capture for training is disabled if gradient accumulation is active if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state && models[name].grad_accumulation_steps == 1) { - - // Note: CUDA graph capture for training is disabled if gradient accumulation is active - graph_state = &models[name].graph_state->training; - capture_stream = models[name].graph_state->capture_stream(); - - // Create input signature for validation std::vector extra_args_vec = extra_loss_args ? extra_loss_args->tensors : std::vector(); - InputSignature current_sig = make_input_signature(inputs->tensors, labels->tensors, extra_args_vec); - - if (graph_state->captured) { - - validate_input_signature(graph_state->input_signature, current_sig, "training"); - - } else if (graph_state->warmup_count == kCudaGraphWarmupIters) { - - // Store input signature used during capture - graph_state->input_signature = current_sig; - capturing = true; - } + action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); } #endif if (state->step_train_current % models[name].grad_accumulation_steps == 0) { #ifdef ENABLE_GPU - // Only explicitly call zero_grad for non-replay steps - if (!graph_state || !graph_state->captured) { + // zero_grad is only needed for non-replay steps + if (action != GraphAction::REPLAY) { #endif opt->zero_grad(/*set_to_none=*/true); #ifdef ENABLE_GPU @@ -255,48 +181,28 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo } #ifdef ENABLE_GPU - if (capturing) { - // Synchronize user stream before capture - CHECK_CUDA(cudaStreamSynchronize(user_stream)); - - // Switch PyTorch to use our capture stream - auto capture_c10_stream = models[name].graph_state->get_capture_cuda_stream(); - guard.reset_stream(capture_c10_stream); - - // Begin capture on our non-blocking 
stream - CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + if (action == GraphAction::CAPTURE) { + graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); } #endif // Forward + loss + backward #ifdef ENABLE_GPU - if (!graph_state || !graph_state->captured) { + if (action != GraphAction::REPLAY) { #endif auto fwd_results = model->forward(inputs->tensors); loss = models[name].loss->forward(fwd_results, labels->tensors, (extra_loss_args) ? extra_loss_args->tensors : std::vector()); loss.backward(); #ifdef ENABLE_GPU - if (graph_state) graph_state->warmup_count++; } if (graph_state) { - if (capturing) { - // End graph capture and instantiate - CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_state->graph.get())); - instantiate_graph(graph_state->graph, graph_state->graph_exec); - graph_state->static_loss = loss; - graph_state->captured = true; - - // Switch back to user stream for replay and subsequent operations - auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, model->device().index()); - guard.reset_stream(user_c10_stream); - } + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), loss); - // Replay graph - if (graph_state->captured) { - graph_state->graph_exec.launch(user_stream); - loss = graph_state->static_loss; + if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { + graph_state->launch(ext_stream); + loss = graph_state->get_loss(); } } #endif From cd8899be9c652679406bd4ba9d5fa4996ae1833e Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 15:34:54 -0800 Subject: [PATCH 03/13] Adding graph support for grad accumulation. Cleaning up some ifdefs. 
Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 17 +++++++----- src/csrc/training.cpp | 35 ++++++++++++------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 2eb8838..611d374 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -18,15 +18,15 @@ #pragma once #ifdef ENABLE_GPU - #include +#include +#include +#endif #include #include #include -#include -#include #include #include "internal/defines.h" @@ -34,9 +34,6 @@ namespace torchfort { -// Number of warmup iterations before CUDA graph capture -constexpr int kCudaGraphWarmupIters = 3; - // Action to take for current iteration enum class GraphAction { WARMUP, // Run eager execution, increment warmup count @@ -44,6 +41,11 @@ enum class GraphAction { REPLAY // Skip eager execution, replay captured graph }; +#ifdef ENABLE_GPU + +// Number of warmup iterations before CUDA graph capture +constexpr int kCudaGraphWarmupIters = 3; + // RAII wrapper for cudaGraph_t class CudaGraph { public: @@ -423,7 +425,8 @@ class ModelGraphState { int device_index_; }; +#endif + } // namespace torchfort -#endif // ENABLE_GPU diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 7e909e5..58dec3b 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -32,9 +32,7 @@ #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" -#ifdef ENABLE_GPU #include "internal/cuda_graphs.h" -#endif namespace torchfort { // Declaration of external global variables @@ -71,8 +69,9 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor std::vector results; -#ifdef ENABLE_GPU GraphAction action = GraphAction::WARMUP; + +#ifdef ENABLE_GPU InferenceGraphState* graph_state = nullptr; if (models[name].state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { @@ -86,13 +85,11 
@@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor #endif // Forward pass -#ifdef ENABLE_GPU if (action != GraphAction::REPLAY) { -#endif results = model->forward(inputs->tensors); -#ifdef ENABLE_GPU } +#ifdef ENABLE_GPU if (graph_state) { graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), results); @@ -156,13 +153,12 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo torch::Tensor loss; -#ifdef ENABLE_GPU GraphAction action = GraphAction::WARMUP; + +#ifdef ENABLE_GPU TrainingGraphState* graph_state = nullptr; - // Note: CUDA graph capture for training is disabled if gradient accumulation is active - if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state && - models[name].grad_accumulation_steps == 1) { + if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { graph_state = &models[name].graph_state->training; std::vector extra_args_vec = extra_loss_args ? 
extra_loss_args->tensors : std::vector(); action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); @@ -170,14 +166,19 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #endif if (state->step_train_current % models[name].grad_accumulation_steps == 0) { + // Only run zero_grad on non-replay steps or if gradient accumulation is active + if (action != GraphAction::REPLAY || models[name].grad_accumulation_steps > 1) { + if (models[name].grad_accumulation_steps > 1) { #ifdef ENABLE_GPU - // zero_grad is only needed for non-replay steps - if (action != GraphAction::REPLAY) { + // With graphs and grad accumulation active, gradients must be persistent (set_to_none = false) + opt->zero_grad(/*set_to_none=*/(graph_state == nullptr)); +#else + opt->zero_grad(/*set_to_none=*/true); #endif - opt->zero_grad(/*set_to_none=*/true); -#ifdef ENABLE_GPU + } else { + opt->zero_grad(/*set_to_none=*/true); + } } -#endif } #ifdef ENABLE_GPU @@ -187,16 +188,14 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #endif // Forward + loss + backward -#ifdef ENABLE_GPU if (action != GraphAction::REPLAY) { -#endif auto fwd_results = model->forward(inputs->tensors); loss = models[name].loss->forward(fwd_results, labels->tensors, (extra_loss_args) ? extra_loss_args->tensors : std::vector()); loss.backward(); -#ifdef ENABLE_GPU } +#ifdef ENABLE_GPU if (graph_state) { graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), loss); From 35f34d34b6a08c96faafe414d6789facacd9c178 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 15:36:08 -0800 Subject: [PATCH 04/13] Update docs. 
Signed-off-by: Josh Romero --- docs/api/config.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api/config.rst b/docs/api/config.rst index 8a06f6a..72539d1 100644 --- a/docs/api/config.rst +++ b/docs/api/config.rst @@ -55,7 +55,6 @@ and improve performance for models with many small operations. - Input tensors must be on GPU and must have consistent data pointers, shapes, and dtypes across all training/inference calls with the captured model. If inputs change after graph capture, an error will be thrown. -- For training, CUDA graph capture is automatically disabled when gradient accumulation (``grad_accumulation_steps > 1``) is active. - The optimizer step and learning rate scheduler updates are not captured in the graph. - A warmup period of 3 iterations is performed before graph capture to ensure stable execution. From 2aa33a0bfbb13d0968d003c046ac9d2d99c3c88d Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 3 Dec 2025 15:51:48 -0800 Subject: [PATCH 05/13] Formatting. Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 83 +++++++++---------------- src/csrc/setup.cpp | 3 +- src/csrc/torchfort.cpp | 3 +- src/csrc/training.cpp | 14 +++-- 4 files changed, 39 insertions(+), 64 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 611d374..307b395 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -18,9 +18,9 @@ #pragma once #ifdef ENABLE_GPU -#include #include #include +#include #endif #include @@ -36,9 +36,9 @@ namespace torchfort { // Action to take for current iteration enum class GraphAction { - WARMUP, // Run eager execution, increment warmup count - CAPTURE, // Run eager execution with graph capture - REPLAY // Skip eager execution, replay captured graph + WARMUP, // Run eager execution, increment warmup count + CAPTURE, // Run eager execution with graph capture + REPLAY // Skip eager execution, replay captured graph }; 
#ifdef ENABLE_GPU @@ -61,12 +61,11 @@ class CudaGraph { CudaGraph& operator=(const CudaGraph&) = delete; // Movable - CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { - other.graph_ = nullptr; - } + CudaGraph(CudaGraph&& other) noexcept : graph_(other.graph_) { other.graph_ = nullptr; } CudaGraph& operator=(CudaGraph&& other) noexcept { if (this != &other) { - if (graph_) cudaGraphDestroy(graph_); + if (graph_) + cudaGraphDestroy(graph_); graph_ = other.graph_; other.graph_ = nullptr; } @@ -96,12 +95,11 @@ class CudaGraphExec { CudaGraphExec& operator=(const CudaGraphExec&) = delete; // Movable - CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { - other.exec_ = nullptr; - } + CudaGraphExec(CudaGraphExec&& other) noexcept : exec_(other.exec_) { other.exec_ = nullptr; } CudaGraphExec& operator=(CudaGraphExec&& other) noexcept { if (this != &other) { - if (exec_) cudaGraphExecDestroy(exec_); + if (exec_) + cudaGraphExecDestroy(exec_); exec_ = other.exec_; other.exec_ = nullptr; } @@ -140,9 +138,7 @@ class InferenceGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamSynchronize(user_stream)); auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); @@ -151,11 +147,8 @@ class InferenceGraphState { } // Finalize after forward pass - handles capture end or warmup increment - void finalize(GraphAction action, - cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index, + void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, int 
device_index, const std::vector& outputs) { if (action == GraphAction::CAPTURE) { end_capture(capture_stream, user_stream, guard, device_index); @@ -166,9 +159,7 @@ class InferenceGraphState { } // Launch captured graph on the given stream - void launch(cudaStream_t stream) { - CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); - } + void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } // Get static outputs (valid after CAPTURE or REPLAY) const std::vector& get_outputs() const { return static_outputs_; } @@ -211,9 +202,7 @@ class InferenceGraphState { } } - void end_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); instantiate_graph(); @@ -225,12 +214,11 @@ class InferenceGraphState { void instantiate_graph() { cudaGraphNode_t error_node; char log_buffer[1024]; - cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), - &error_node, log_buffer, sizeof(log_buffer)); + cudaError_t result = + cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), &error_node, log_buffer, sizeof(log_buffer)); if (result != cudaSuccess) { std::stringstream ss; - ss << "CUDA graph instantiation failed in " << context_ << ": " - << cudaGetErrorString(result); + ss << "CUDA graph instantiation failed in " << context_ << ": " << cudaGetErrorString(result); if (std::strlen(log_buffer) > 0) { ss << " Log: " << log_buffer; } @@ -254,8 +242,7 @@ class TrainingGraphState { // Determine action for this iteration - validates inputs if captured, stores signature if ready to capture // Returns the action to take. Call begin_capture() after this if action == CAPTURE. 
- GraphAction prepare(const std::vector& inputs, - const std::vector& labels, + GraphAction prepare(const std::vector& inputs, const std::vector& labels, const std::vector& extra_args) { InputSignature current_sig = make_input_signature(inputs, labels, extra_args); @@ -273,9 +260,7 @@ class TrainingGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, + void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamSynchronize(user_stream)); auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); @@ -284,12 +269,8 @@ class TrainingGraphState { } // Finalize after forward+loss+backward pass - handles capture end or warmup increment - void finalize(GraphAction action, - cudaStream_t capture_stream, - cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index, - const torch::Tensor& loss) { + void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& guard, int device_index, const torch::Tensor& loss) { if (action == GraphAction::CAPTURE) { end_capture(capture_stream, user_stream, guard, device_index); static_loss_ = loss; @@ -299,9 +280,7 @@ class TrainingGraphState { } // Launch captured graph on the given stream - void launch(cudaStream_t stream) { - CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); - } + void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } // Get static loss (valid after CAPTURE or REPLAY) const torch::Tensor& get_loss() const { return static_loss_; } @@ -358,9 +337,7 @@ class TrainingGraphState { } } - void end_capture(cudaStream_t capture_stream, - cudaStream_t user_stream, - 
c10::cuda::OptionalCUDAStreamGuard& guard, + void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, int device_index) { CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); instantiate_graph(); @@ -372,12 +349,11 @@ class TrainingGraphState { void instantiate_graph() { cudaGraphNode_t error_node; char log_buffer[1024]; - cudaError_t result = cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), - &error_node, log_buffer, sizeof(log_buffer)); + cudaError_t result = + cudaGraphInstantiate(&graph_exec_.get(), graph_.get(), &error_node, log_buffer, sizeof(log_buffer)); if (result != cudaSuccess) { std::stringstream ss; - ss << "CUDA graph instantiation failed in " << context_ << ": " - << cudaGetErrorString(result); + ss << "CUDA graph instantiation failed in " << context_ << ": " << cudaGetErrorString(result); if (std::strlen(log_buffer) > 0) { ss << " Log: " << log_buffer; } @@ -400,8 +376,7 @@ class ModelGraphState { InferenceGraphState inference{"inference"}; TrainingGraphState training{"training"}; - ModelGraphState(int device_index = 0) - : capture_stream_(nullptr), device_index_(device_index) { + ModelGraphState(int device_index = 0) : capture_stream_(nullptr), device_index_(device_index) { // Create a non-blocking stream for graph capture CHECK_CUDA(cudaSetDevice(device_index_)); CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); @@ -428,5 +403,3 @@ class ModelGraphState { #endif } // namespace torchfort - - diff --git a/src/csrc/setup.cpp b/src/csrc/setup.cpp index 80a4406..fe1f5ff 100644 --- a/src/csrc/setup.cpp +++ b/src/csrc/setup.cpp @@ -262,8 +262,7 @@ std::shared_ptr get_state(const char* name, const YAML::Node& state_ if (state_node["general"]) { auto params = get_params(state_node["general"]); - std::set supported_params{"report_frequency", "enable_wandb_hook", "verbose", - "enable_cuda_graphs"}; + std::set supported_params{"report_frequency", 
"enable_wandb_hook", "verbose", "enable_cuda_graphs"}; check_params(supported_params, params.keys()); try { diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index cff1e77..0d6acba 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -129,8 +129,7 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f #ifdef ENABLE_GPU // Initialize graph state if CUDA graphs are enabled if (models[name].state->enable_cuda_graphs && models[name].model->device().is_cuda()) { - models[name].graph_state = std::make_shared( - models[name].model->device().index()); + models[name].graph_state = std::make_shared(models[name].model->device().index()); } #endif } catch (const BaseException& e) { diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 58dec3b..746477e 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -26,13 +26,13 @@ #endif #include +#include "internal/cuda_graphs.h" #include "internal/defines.h" #include "internal/logging.h" #include "internal/model_pack.h" #include "internal/nvtx.h" #include "internal/tensor_list.h" #include "internal/utils.h" -#include "internal/cuda_graphs.h" namespace torchfort { // Declaration of external global variables @@ -79,7 +79,8 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor action = graph_state->prepare(inputs->tensors); if (action == GraphAction::CAPTURE) { - graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); + graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, + model->device().index()); } } #endif @@ -91,7 +92,8 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor #ifdef ENABLE_GPU if (graph_state) { - graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), results); + graph_state->finalize(action, 
models[name].graph_state->capture_stream(), ext_stream, guard, + model->device().index(), results); if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { graph_state->launch(ext_stream); @@ -160,7 +162,8 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo if (state->enable_cuda_graphs && model->device().is_cuda() && models[name].graph_state) { graph_state = &models[name].graph_state->training; - std::vector extra_args_vec = extra_loss_args ? extra_loss_args->tensors : std::vector(); + std::vector extra_args_vec = + extra_loss_args ? extra_loss_args->tensors : std::vector(); action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); } #endif @@ -197,7 +200,8 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #ifdef ENABLE_GPU if (graph_state) { - graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index(), loss); + graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, + model->device().index(), loss); if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { graph_state->launch(ext_stream); From 5d9bf6632a536cc1abe97ce71662f22977cd6abb Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 4 Dec 2025 11:00:13 -0800 Subject: [PATCH 06/13] Move loss D2H copy and allreduce after optimizer step call. 
Signed-off-by: Josh Romero --- src/csrc/training.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 746477e..560ed9a 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -210,9 +210,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo } #endif - // Extract loss value - *loss_val = loss.item(); - // Optimizer step and related operations if ((state->step_train_current + 1) % models[name].grad_accumulation_steps == 0) { // allreduce (average) gradients (if running distributed) @@ -223,9 +220,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo grads.push_back(p.grad()); } models[name].comm->allreduce(grads, true); - - // average returned loss value - models[name].comm->allreduce(*loss_val, true); } if (models[name].max_grad_norm > 0.0) { @@ -238,6 +232,13 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo } } + // Extract loss value + *loss_val = loss.item(); + if (models[name].comm) { + // average returned loss value (if running distributed) + models[name].comm->allreduce(*loss_val, true); + } + state->step_train++; state->step_train_current++; if (state->report_frequency > 0 && state->step_train % state->report_frequency == 0) { From 8cc1cf7664e1d76eb435ebe5742fe5808f99e324 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 4 Feb 2026 14:33:35 -0800 Subject: [PATCH 07/13] Move capture stream handling into graph state instances. 
Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 91 ++++++++++++------------- src/csrc/torchfort.cpp | 2 +- src/csrc/training.cpp | 11 ++- 3 files changed, 50 insertions(+), 54 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 307b395..5f3a897 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -138,20 +138,26 @@ class InferenceGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index) { + void begin_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, + torch::Device device) { + + c10::cuda::CUDAGuard cuda_guard(device); + + // Create a non-blocking stream for graph capture + CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); + CHECK_CUDA(cudaStreamSynchronize(user_stream)); - auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, device_index); - guard.reset_stream(capture_c10_stream); - CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream_, device.index()); + stream_guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); } // Finalize after forward pass - handles capture end or warmup increment - void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, int device_index, + void finalize(GraphAction action, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device, const std::vector& outputs) { if (action == GraphAction::CAPTURE) { - end_capture(capture_stream, user_stream, 
guard, device_index); + end_capture(user_stream, stream_guard, device); static_outputs_ = outputs; } else if (action == GraphAction::WARMUP) { warmup_count_++; @@ -202,13 +208,16 @@ class InferenceGraphState { } } - void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index) { - CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + void end_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, + torch::Device device) { + c10::cuda::CUDAGuard cuda_guard(device); + CHECK_CUDA(cudaStreamEndCapture(capture_stream_, &graph_.get())); instantiate_graph(); captured_ = true; - auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); - guard.reset_stream(user_c10_stream); + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device.index()); + stream_guard.reset_stream(user_c10_stream); + + CHECK_CUDA(cudaStreamDestroy(capture_stream_)); } void instantiate_graph() { @@ -233,6 +242,7 @@ class InferenceGraphState { CudaGraph graph_; CudaGraphExec graph_exec_; std::vector static_outputs_; + cudaStream_t capture_stream_ = nullptr; }; // Graph state for training (single graph for forward + loss + backward) @@ -260,19 +270,24 @@ class TrainingGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index) { + void begin_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, + torch::Device device) { + c10::cuda::CUDAGuard cuda_guard(device); + + // Create a non-blocking stream for graph capture + CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); + CHECK_CUDA(cudaStreamSynchronize(user_stream)); - auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream, 
device_index); - guard.reset_stream(capture_c10_stream); - CHECK_CUDA(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + auto capture_c10_stream = c10::cuda::getStreamFromExternal(capture_stream_, device.index()); + stream_guard.reset_stream(capture_c10_stream); + CHECK_CUDA(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); } // Finalize after forward+loss+backward pass - handles capture end or warmup increment - void finalize(GraphAction action, cudaStream_t capture_stream, cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& guard, int device_index, const torch::Tensor& loss) { + void finalize(GraphAction action, cudaStream_t user_stream, + c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device, const torch::Tensor& loss) { if (action == GraphAction::CAPTURE) { - end_capture(capture_stream, user_stream, guard, device_index); + end_capture(user_stream, stream_guard, device); static_loss_ = loss; } else if (action == GraphAction::WARMUP) { warmup_count_++; @@ -337,13 +352,16 @@ class TrainingGraphState { } } - void end_capture(cudaStream_t capture_stream, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& guard, - int device_index) { - CHECK_CUDA(cudaStreamEndCapture(capture_stream, &graph_.get())); + void end_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, + torch::Device device) { + c10::cuda::CUDAGuard cuda_guard(device); + CHECK_CUDA(cudaStreamEndCapture(capture_stream_, &graph_.get())); instantiate_graph(); captured_ = true; - auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device_index); - guard.reset_stream(user_c10_stream); + auto user_c10_stream = c10::cuda::getStreamFromExternal(user_stream, device.index()); + stream_guard.reset_stream(user_c10_stream); + + CHECK_CUDA(cudaStreamDestroy(capture_stream_)); } void instantiate_graph() { @@ -368,6 +386,7 @@ class TrainingGraphState { CudaGraph graph_; CudaGraphExec 
graph_exec_; torch::Tensor static_loss_; + cudaStream_t capture_stream_ = nullptr; }; // Graph state for a model, including the capture stream @@ -376,28 +395,8 @@ class ModelGraphState { InferenceGraphState inference{"inference"}; TrainingGraphState training{"training"}; - ModelGraphState(int device_index = 0) : capture_stream_(nullptr), device_index_(device_index) { - // Create a non-blocking stream for graph capture - CHECK_CUDA(cudaSetDevice(device_index_)); - CHECK_CUDA(cudaStreamCreateWithFlags(&capture_stream_, cudaStreamNonBlocking)); - } - - ~ModelGraphState() { - if (capture_stream_) { - cudaStreamDestroy(capture_stream_); - } - } - // Non-copyable - ModelGraphState(const ModelGraphState&) = delete; ModelGraphState& operator=(const ModelGraphState&) = delete; - - cudaStream_t capture_stream() const { return capture_stream_; } - int device_index() const { return device_index_; } - -private: - cudaStream_t capture_stream_; - int device_index_; }; #endif diff --git a/src/csrc/torchfort.cpp b/src/csrc/torchfort.cpp index 0d6acba..324738a 100644 --- a/src/csrc/torchfort.cpp +++ b/src/csrc/torchfort.cpp @@ -129,7 +129,7 @@ torchfort_result_t torchfort_create_model(const char* name, const char* config_f #ifdef ENABLE_GPU // Initialize graph state if CUDA graphs are enabled if (models[name].state->enable_cuda_graphs && models[name].model->device().is_cuda()) { - models[name].graph_state = std::make_shared(models[name].model->device().index()); + models[name].graph_state = std::make_shared(); } #endif } catch (const BaseException& e) { diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 560ed9a..b4b4b12 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -79,8 +79,7 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor action = graph_state->prepare(inputs->tensors); if (action == GraphAction::CAPTURE) { - graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, - 
model->device().index()); + graph_state->begin_capture(ext_stream, stream_guard, model->device()); } } #endif @@ -92,8 +91,7 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor #ifdef ENABLE_GPU if (graph_state) { - graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, - model->device().index(), results); + graph_state->finalize(action, ext_stream, stream_guard, model->device(), results); if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { graph_state->launch(ext_stream); @@ -186,7 +184,7 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #ifdef ENABLE_GPU if (action == GraphAction::CAPTURE) { - graph_state->begin_capture(models[name].graph_state->capture_stream(), ext_stream, guard, model->device().index()); + graph_state->begin_capture(ext_stream, stream_guard, model->device()); } #endif @@ -200,8 +198,7 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo #ifdef ENABLE_GPU if (graph_state) { - graph_state->finalize(action, models[name].graph_state->capture_stream(), ext_stream, guard, - model->device().index(), loss); + graph_state->finalize(action, ext_stream, stream_guard, model->device(), loss); if (action == GraphAction::CAPTURE || action == GraphAction::REPLAY) { graph_state->launch(ext_stream); From b7dbad6d2b09f3ff2a49a7cf8c281352fba7880e Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Wed, 4 Feb 2026 14:40:04 -0800 Subject: [PATCH 08/13] Formatting fixes. 
Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 5f3a897..a8a428a 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -138,8 +138,7 @@ class InferenceGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, - torch::Device device) { + void begin_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device) { c10::cuda::CUDAGuard cuda_guard(device); @@ -153,9 +152,8 @@ class InferenceGraphState { } // Finalize after forward pass - handles capture end or warmup increment - void finalize(GraphAction action, cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device, - const std::vector& outputs) { + void finalize(GraphAction action, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, + torch::Device device, const std::vector& outputs) { if (action == GraphAction::CAPTURE) { end_capture(user_stream, stream_guard, device); static_outputs_ = outputs; @@ -208,8 +206,7 @@ class InferenceGraphState { } } - void end_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, - torch::Device device) { + void end_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device) { c10::cuda::CUDAGuard cuda_guard(device); CHECK_CUDA(cudaStreamEndCapture(capture_stream_, &graph_.get())); instantiate_graph(); @@ -270,8 +267,7 @@ class TrainingGraphState { } // Begin graph capture - call this after prepare() returns CAPTURE and after any pre-capture work - void begin_capture(cudaStream_t user_stream, 
c10::cuda::OptionalCUDAStreamGuard& stream_guard, - torch::Device device) { + void begin_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device) { c10::cuda::CUDAGuard cuda_guard(device); // Create a non-blocking stream for graph capture @@ -284,8 +280,8 @@ class TrainingGraphState { } // Finalize after forward+loss+backward pass - handles capture end or warmup increment - void finalize(GraphAction action, cudaStream_t user_stream, - c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device, const torch::Tensor& loss) { + void finalize(GraphAction action, cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, + torch::Device device, const torch::Tensor& loss) { if (action == GraphAction::CAPTURE) { end_capture(user_stream, stream_guard, device); static_loss_ = loss; @@ -352,8 +348,7 @@ class TrainingGraphState { } } - void end_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, - torch::Device device) { + void end_capture(cudaStream_t user_stream, c10::cuda::OptionalCUDAStreamGuard& stream_guard, torch::Device device) { c10::cuda::CUDAGuard cuda_guard(device); CHECK_CUDA(cudaStreamEndCapture(capture_stream_, &graph_.get())); instantiate_graph(); From 20e4c586b095e4528814b1b8bc6f9436f7bdadc3 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 5 Feb 2026 08:38:48 -0800 Subject: [PATCH 09/13] Add some graphs tests. 
Signed-off-by: Josh Romero --- tests/supervised/CMakeLists.txt | 3 + tests/supervised/test_training.cpp | 93 ++++++++++++++++++++++++++---- 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/tests/supervised/CMakeLists.txt b/tests/supervised/CMakeLists.txt index ceb7902..abfffd2 100644 --- a/tests/supervised/CMakeLists.txt +++ b/tests/supervised/CMakeLists.txt @@ -71,6 +71,9 @@ install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/mlp2_gradacc.yaml DESTINATION install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_opt.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/missing_loss.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_graphs.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_graphs.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/configs/torchscript_multiarg_extra_graphs.yaml DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/configs) install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup_tests.py DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/tests/supervised/scripts) diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index 0986752..ac2eabd 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -33,7 +33,7 @@ void training_test(const std::string& 
model_config, int dev_model, int dev_input, std::vector shape, bool should_fail_create, bool should_fail_train, bool should_fail_inference, bool check_result, - int dev_stream = -1) { + int n_train_steps = 1, int n_inference_steps = 1, int dev_stream = -1) { std::string model_name = generate_random_name(10); @@ -93,8 +93,20 @@ void training_test(const std::string& model_config, int dev_model, int dev_input #endif try { - CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, shape.size(), - shape.data(), &loss_val, TORCHFORT_FLOAT, stream)); + for (int i = 0; i < n_train_steps; ++i) { + auto tmp_input = generate_random(shape); + auto tmp_label = generate_random(shape); + std::copy(tmp_input.begin(), tmp_input.end(), input.begin()); + std::copy(tmp_label.begin(), tmp_label.end(), label.begin()); +#ifdef ENABLE_GPU + if (dev_input != TORCHFORT_DEVICE_CPU) { + copy_from_host_vector(input_ptr, input); + copy_from_host_vector(label_ptr, label); + } +#endif + CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, shape.size(), + shape.data(), &loss_val, TORCHFORT_FLOAT, stream)); + } if (should_fail_train) { FAIL() << "This test should fail train call, but did not."; } @@ -123,8 +135,17 @@ void training_test(const std::string& model_config, int dev_model, int dev_input #endif try { - CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, - shape.size(), shape.data(), TORCHFORT_FLOAT, stream)); + for (int i = 0; i < n_inference_steps; ++i) { + auto tmp_input = generate_random(shape); + std::copy(tmp_input.begin(), tmp_input.end(), input.begin()); +#ifdef ENABLE_GPU + if (dev_input != TORCHFORT_DEVICE_CPU) { + copy_from_host_vector(input_ptr, input); + } +#endif + CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, + shape.size(), shape.data(), TORCHFORT_FLOAT, stream)); + } 
if (should_fail_inference) { FAIL() << "This test should fail inference call, but did not."; } @@ -136,7 +157,7 @@ void training_test(const std::string& model_config, int dev_model, int dev_input } } catch (const c10::Error& e) { std::cout << e.what() << std::endl; - if (should_fail_train) { + if (should_fail_inference) { // pass } else { FAIL(); @@ -175,7 +196,7 @@ void training_test(const std::string& model_config, int dev_model, int dev_input void training_test_multiarg(const std::string& model_config, int dev_model, int dev_input, bool use_extra_args, bool should_fail_create, bool should_fail_train, bool should_fail_inference, - bool check_result) { + bool check_result, int n_train_steps = 1, int n_inference_steps = 1) { #ifdef ENABLE_GPU if (dev_model == 1 || dev_input == 1) { int ngpu; @@ -250,8 +271,29 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } try { - CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, - (use_extra_args) ? extra_args_tl : nullptr, 0)); + for (int i = 0; i < n_train_steps; ++i) { + for (int i = 0; i < 2; ++i) { + auto tmp_input = generate_random(shape); + std::copy(tmp_input.begin(), tmp_input.end(), inputs[i].begin()); + auto tmp_label = generate_random(shape); + std::copy(tmp_label.begin(), tmp_label.end(), labels[i].begin()); + if (use_extra_args) { + auto tmp_extra_args = generate_random(shape); + std::copy(tmp_extra_args.begin(), tmp_extra_args.end(), extra_args[i].begin()); + } +#ifdef ENABLE_GPU + if (dev_input != TORCHFORT_DEVICE_CPU) { + copy_from_host_vector(input_ptrs[i], inputs[i]); + copy_from_host_vector(label_ptrs[i], labels[i]); + if (use_extra_args) { + copy_from_host_vector(extra_args_ptrs[i], extra_args[i]); + } + } +#endif + } + CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, + (use_extra_args) ? 
extra_args_tl : nullptr, 0)); + } if (should_fail_train) { FAIL() << "This test should fail train call, but did not."; } @@ -267,7 +309,18 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int FAIL() << "GPU device switched by torchfort_train_multiarg."; try { - CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs_tl, outputs_tl, 0)); + for (int i = 0; i < n_inference_steps; ++i) { + for (int i = 0; i < 2; ++i) { + auto tmp_input = generate_random(shape); + std::copy(tmp_input.begin(), tmp_input.end(), inputs[i].begin()); +#ifdef ENABLE_GPU + if (dev_input != TORCHFORT_DEVICE_CPU) { + copy_from_host_vector(input_ptrs[i], inputs[i]); + } +#endif + } + CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs_tl, outputs_tl, 0)); + } if (should_fail_inference) { FAIL() << "This test should fail inference call, but did not."; } @@ -494,6 +547,12 @@ TEST(TorchFort, TrainTestTorchScriptGPUCPU) { TEST(TorchFort, TrainTestTorchScriptGPUGPU) { training_test("configs/torchscript.yaml", 0, 0, {10, 2, 10}, false, false, false, true); } +TEST(TorchFort, TrainTestTorchScriptCPUGPUGraphs) { + training_test("configs/torchscript_graphs.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 10}, false, false, false, true, 5, 5); +} +TEST(TorchFort, TrainTestTorchScriptGPUGPUGraphs) { + training_test("configs/torchscript_graphs.yaml", 0, 0, {10, 2, 10}, false, false, false, true, 5, 5); +} TEST(TorchFort, TrainTestTorchScriptMultiArgCPUGPU) { training_test_multiarg("configs/torchscript_multiarg.yaml", TORCHFORT_DEVICE_CPU, 0, false, false, false, false, true); @@ -505,6 +564,12 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgGPUCPU) { TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPU) { training_test_multiarg("configs/torchscript_multiarg.yaml", 0, 0, false, false, false, false, true); } +TEST(TorchFort, TrainTestTorchScriptMultiArgCPUGPUGraphs) { + training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", TORCHFORT_DEVICE_CPU, 
0, false, false, false, false, true, 5, 5); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPUGraphs) { + training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", 0, 0, false, false, false, false, true, 5, 5); +} TEST(TorchFort, TrainTestTorchScriptMultiArgCPUGPU1) { training_test_multiarg("configs/torchscript_multiarg.yaml", TORCHFORT_DEVICE_CPU, 1, false, false, false, false, true); @@ -530,6 +595,12 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUCPU) { TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUGPU) { training_test_multiarg("configs/torchscript_multiarg_extra.yaml", 0, 0, true, false, false, false, true); } +TEST(TorchFort, TrainTestTorchScriptMultiArgExtraCPUGPUGraphs) { + training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", TORCHFORT_DEVICE_CPU, 0, true, false, false, false, true, 5, 5); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUGPUGraphs) { + training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", 0, 0, true, false, false, false, true, 5, 5); +} #endif // Testing expected error cases @@ -558,7 +629,7 @@ TEST(TorchFort, TrainTestMLPCPUCPU1DDimError) { #ifdef ENABLE_GPU TEST(TorchFort, TrainTestMLPGPUGPUStreamWrongDeviceError) { - training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1); + training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1, 1, 1); } #endif From 24c9eefa1dff63ffee380716eaa4ded19923beb9 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 5 Feb 2026 16:04:11 -0800 Subject: [PATCH 10/13] Add more graphs tests. Disable CPU inputs in graph capture mode. 
Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 12 + src/csrc/training.cpp | 26 +- .../configs/torchscript_graphs.yaml | 13 + .../torchscript_multiarg_extra_graphs.yaml | 15 ++ .../configs/torchscript_multiarg_graphs.yaml | 15 ++ tests/supervised/test_training.cpp | 237 +++++++++++++++++- 6 files changed, 305 insertions(+), 13 deletions(-) create mode 100644 tests/supervised/configs/torchscript_graphs.yaml create mode 100644 tests/supervised/configs/torchscript_multiarg_extra_graphs.yaml create mode 100644 tests/supervised/configs/torchscript_multiarg_graphs.yaml diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index a8a428a..951738f 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -191,6 +191,9 @@ class InferenceGraphState { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); + if (!t.device().is_cuda()) { + THROW_INVALID_USAGE("Model inputs must be on GPU when enable_cuda_graphs is true."); + } } return sig; } @@ -323,16 +326,25 @@ class TrainingGraphState { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); + if (!t.device().is_cuda()) { + THROW_INVALID_USAGE("Model inputs must be on GPU when enable_cuda_graphs is true."); + } } for (const auto& t : labels) { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); + if (!t.device().is_cuda()) { + THROW_INVALID_USAGE("Model labels must be on GPU when enable_cuda_graphs is true."); + } } for (const auto& t : extra_args) { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); + if (!t.device().is_cuda()) { + THROW_INVALID_USAGE("Model extra args must be on GPU when enable_cuda_graphs is true."); + } } return sig; } diff --git a/src/csrc/training.cpp 
b/src/csrc/training.cpp index b4b4b12..5fef883 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -63,8 +63,6 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor set_device_and_stream(stream_guard, cuda_guard, model->device(), ext_stream); #endif - inputs->to(model->device()); - model->eval(); std::vector results; @@ -78,9 +76,16 @@ void inference_multiarg(const char* name, torchfort_tensor_list_t inputs_in, tor graph_state = &models[name].graph_state->inference; action = graph_state->prepare(inputs->tensors); + // Move inputs to model device after prepare() to catch CPU inputs + inputs->to(model->device()); + if (action == GraphAction::CAPTURE) { graph_state->begin_capture(ext_stream, stream_guard, model->device()); } + } else { +#endif + inputs->to(model->device()); +#ifdef ENABLE_GPU } #endif @@ -142,11 +147,6 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo set_device_and_stream(stream_guard, cuda_guard, model->device(), ext_stream); #endif - inputs->to(model->device()); - labels->to(model->device()); - if (extra_loss_args) - extra_loss_args->to(model->device()); - model->train(); auto opt = models[name].optimizer; auto state = models[name].state; @@ -163,6 +163,18 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo std::vector extra_args_vec = extra_loss_args ? 
extra_loss_args->tensors : std::vector(); action = graph_state->prepare(inputs->tensors, labels->tensors, extra_args_vec); + // Move inputs to model device after prepare() to catch CPU inputs + inputs->to(model->device()); + labels->to(model->device()); + if (extra_loss_args) + extra_loss_args->to(model->device()); + } else { +#endif + inputs->to(model->device()); + labels->to(model->device()); + if (extra_loss_args) + extra_loss_args->to(model->device()); +#ifdef ENABLE_GPU } #endif diff --git a/tests/supervised/configs/torchscript_graphs.yaml b/tests/supervised/configs/torchscript_graphs.yaml new file mode 100644 index 0000000..f4285e7 --- /dev/null +++ b/tests/supervised/configs/torchscript_graphs.yaml @@ -0,0 +1,13 @@ +general: + enable_cuda_graphs: 1 + +model: + type: torchscript + parameters: + filename: "model.pt" + +loss: + type: MSE + +optimizer: + type: adam diff --git a/tests/supervised/configs/torchscript_multiarg_extra_graphs.yaml b/tests/supervised/configs/torchscript_multiarg_extra_graphs.yaml new file mode 100644 index 0000000..f49a9fe --- /dev/null +++ b/tests/supervised/configs/torchscript_multiarg_extra_graphs.yaml @@ -0,0 +1,15 @@ +general: + enable_cuda_graphs: 1 + +model: + type: torchscript + parameters: + filename: "model_multiarg.pt" + +loss: + type: torchscript + parameters: + filename: "loss_multiarg_extra.pt" + +optimizer: + type: adam diff --git a/tests/supervised/configs/torchscript_multiarg_graphs.yaml b/tests/supervised/configs/torchscript_multiarg_graphs.yaml new file mode 100644 index 0000000..1c8f64b --- /dev/null +++ b/tests/supervised/configs/torchscript_multiarg_graphs.yaml @@ -0,0 +1,15 @@ +general: + enable_cuda_graphs: 1 + +model: + type: torchscript + parameters: + filename: "model_multiarg.pt" + +loss: + type: torchscript + parameters: + filename: "loss_multiarg.pt" + +optimizer: + type: adam diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index ac2eabd..a5449e7 100644 --- 
a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -486,6 +486,222 @@ void training_test_grad_accumulation(const std::string& model_config, int dev_mo free_data_ptr(output2_ptr, dev_input); } +#ifdef ENABLE_GPU +void training_test_graphs_errors(int dev_input) { + + if (dev_input != TORCHFORT_DEVICE_CPU) { + CHECK_CUDA(cudaSetDevice(dev_input)); + } + + std::string model_name = generate_random_name(10); + CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), "configs/torchscript_graphs.yaml", 0)); + + std::vector shape = {10, 2, 10}; + std::vector shape2 = {10, 3, 10}; + auto input = generate_random(shape); + auto label = generate_random(shape); + auto output = generate_random(shape); + float loss_val; + + float* input_ptr = get_data_ptr(input, dev_input); + float* label_ptr = get_data_ptr(label, dev_input); + float* output_ptr = get_data_ptr(output, dev_input); + + // Train and run inference for 4 iterations to trigger graph capture + for (int i = 0; i < 4; ++i) { + CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, shape.size(), + shape.data(), &loss_val, TORCHFORT_FLOAT, 0)); + CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input_ptr, shape.size(), shape.data(), output_ptr, + shape.size(), shape.data(), TORCHFORT_FLOAT, 0)); + } + + // Change input buffer + auto input2 = generate_random(shape); + float* input2_ptr = get_data_ptr(input2, dev_input); + try { + CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input2_ptr, shape.size(), shape.data(), label_ptr, shape.size(), + shape.data(), &loss_val, TORCHFORT_FLOAT, 0)); + FAIL() << "This test should fail train call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + + // Change label buffer + auto label2 = generate_random(shape); + float* label2_ptr = get_data_ptr(label2, dev_input); + try { + CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), 
label2_ptr, shape.size(), + shape.data(), &loss_val, TORCHFORT_FLOAT, 0)); + FAIL() << "This test should fail train call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + + // Change input buffer for inference + try { + CHECK_TORCHFORT(torchfort_inference(model_name.c_str(), input2_ptr, shape.size(), shape.data(), output_ptr, + shape.size(), shape.data(), TORCHFORT_FLOAT, 0)); + FAIL() << "This test should fail inference call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + + free_data_ptr(input_ptr, dev_input); + free_data_ptr(label_ptr, dev_input); + free_data_ptr(output_ptr, dev_input); + free_data_ptr(input2_ptr, dev_input); + free_data_ptr(label2_ptr, dev_input); +} + +void training_test_multiarg_graphs_errors(int dev_input, bool use_extra_args) { + + if (dev_input != TORCHFORT_DEVICE_CPU) { + CHECK_CUDA(cudaSetDevice(dev_input)); + } + + std::string model_name = generate_random_name(10); + if (use_extra_args) { + CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), "configs/torchscript_multiarg_extra_graphs.yaml", 0)); + } else { + CHECK_TORCHFORT(torchfort_create_model(model_name.c_str(), "configs/torchscript_multiarg_graphs.yaml", 0)); + } + + std::vector shape = {10, 10}; + std::vector> inputs(2), labels(2), outputs(2); + std::vector> inputs2(2), labels2(2); + for (int i = 0; i < 2; ++i) { + inputs[i] = generate_random(shape); + labels[i] = generate_random(shape); + outputs[i] = generate_random(shape); + inputs2[i] = generate_random(shape); + labels2[i] = generate_random(shape); + } + + float loss_val; + + std::vector> extra_args; + std::vector> extra_args2; + if (use_extra_args) { + for (int i = 0; i < 2; ++i) { + extra_args.push_back(generate_random(shape)); + extra_args2.push_back(generate_random(shape)); + } + } + + + torchfort_tensor_list_t inputs_tl, labels_tl, outputs_tl; + torchfort_tensor_list_t inputs2_tl, labels2_tl, outputs2_tl; + 
CHECK_TORCHFORT(torchfort_tensor_list_create(&inputs_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_create(&labels_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_create(&outputs_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_create(&inputs2_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_create(&labels2_tl)); + + std::vector input_ptrs(2), label_ptrs(2), output_ptrs(2); + std::vector input2_ptrs(2), label2_ptrs(2); + + for (int i = 0; i < 2; ++i) { + input_ptrs[i] = get_data_ptr(inputs[i], dev_input); + label_ptrs[i] = get_data_ptr(labels[i], dev_input); + output_ptrs[i] = get_data_ptr(outputs[i], dev_input); + CHECK_TORCHFORT( + torchfort_tensor_list_add_tensor(inputs_tl, input_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); + CHECK_TORCHFORT( + torchfort_tensor_list_add_tensor(labels_tl, label_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); + CHECK_TORCHFORT( + torchfort_tensor_list_add_tensor(outputs_tl, output_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); + + input2_ptrs[i] = get_data_ptr(inputs2[i], dev_input); + label2_ptrs[i] = get_data_ptr(labels2[i], dev_input); + CHECK_TORCHFORT( + torchfort_tensor_list_add_tensor(inputs2_tl, input2_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); + CHECK_TORCHFORT( + torchfort_tensor_list_add_tensor(labels2_tl, label2_ptrs[i], shape.size(), shape.data(), TORCHFORT_FLOAT)); + } + + torchfort_tensor_list_t extra_args_tl; + torchfort_tensor_list_t extra_args2_tl; + std::vector extra_args_ptrs(2); + std::vector extra_args2_ptrs(2); + if (use_extra_args) { + torchfort_tensor_list_create(&extra_args_tl); + torchfort_tensor_list_create(&extra_args2_tl); + for (int i = 0; i < 2; ++i) { + extra_args_ptrs[i] = get_data_ptr(extra_args[i], dev_input); + CHECK_TORCHFORT(torchfort_tensor_list_add_tensor(extra_args_tl, extra_args_ptrs[i], shape.size(), shape.data(), + TORCHFORT_FLOAT)); + extra_args2_ptrs[i] = get_data_ptr(extra_args2[i], dev_input); + 
CHECK_TORCHFORT(torchfort_tensor_list_add_tensor(extra_args2_tl, extra_args2_ptrs[i], shape.size(), shape.data(), + TORCHFORT_FLOAT)); + } + } + + // Train and run inference for 4 iterations to trigger graph capture + for (int i = 0; i < 4; ++i) { + CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, + (use_extra_args) ? extra_args_tl : nullptr, 0)); + CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs_tl, outputs_tl, 0)); + } + + // Change input buffer + try { + CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs2_tl, labels_tl, &loss_val, + (use_extra_args) ? extra_args_tl : nullptr, 0)); + FAIL() << "This test should fail train call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + + // Change label buffer + try { + CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels2_tl, &loss_val, + (use_extra_args) ? extra_args_tl : nullptr, 0)); + FAIL() << "This test should fail train call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + + // Change extra args buffer + if (use_extra_args) { + try { + CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, + (use_extra_args) ? 
extra_args2_tl : nullptr, 0)); + FAIL() << "This test should fail train call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + } + + // Change input buffer for inference + try { + CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs2_tl, outputs_tl, 0)); + FAIL() << "This test should fail inference call, but did not."; + } catch (const torchfort::BaseException& e) { + // pass + } + + for (int i = 0; i < 2; ++i) { + free_data_ptr(input_ptrs[i], dev_input); + free_data_ptr(label_ptrs[i], dev_input); + free_data_ptr(input2_ptrs[i], dev_input); + free_data_ptr(label2_ptrs[i], dev_input); + if (use_extra_args) { + free_data_ptr(extra_args_ptrs[i], dev_input); + free_data_ptr(extra_args2_ptrs[i], dev_input); + } + } + CHECK_TORCHFORT(torchfort_tensor_list_destroy(inputs_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_destroy(labels_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_destroy(outputs_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_destroy(inputs2_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_destroy(labels2_tl)); + if (use_extra_args) { + CHECK_TORCHFORT(torchfort_tensor_list_destroy(extra_args_tl)); + CHECK_TORCHFORT(torchfort_tensor_list_destroy(extra_args2_tl)); + } +} +#endif + TEST(TorchFort, TrainTestMLPCPUCPU) { training_test("configs/mlp2.yaml", TORCHFORT_DEVICE_CPU, TORCHFORT_DEVICE_CPU, {10, 2, 5}, false, false, false, false); @@ -547,8 +763,8 @@ TEST(TorchFort, TrainTestTorchScriptGPUCPU) { TEST(TorchFort, TrainTestTorchScriptGPUGPU) { training_test("configs/torchscript.yaml", 0, 0, {10, 2, 10}, false, false, false, true); } -TEST(TorchFort, TrainTestTorchScriptCPUGPUGraphs) { - training_test("configs/torchscript_graphs.yaml", TORCHFORT_DEVICE_CPU, 0, {10, 2, 10}, false, false, false, true, 5, 5); +TEST(TorchFort, TrainTestTorchScriptGPUCPUGraphs) { + training_test("configs/torchscript_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, {10, 2, 10}, false, true, true, false, 5, 5); } TEST(TorchFort, 
TrainTestTorchScriptGPUGPUGraphs) { training_test("configs/torchscript_graphs.yaml", 0, 0, {10, 2, 10}, false, false, false, true, 5, 5); @@ -564,8 +780,8 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgGPUCPU) { TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPU) { training_test_multiarg("configs/torchscript_multiarg.yaml", 0, 0, false, false, false, false, true); } -TEST(TorchFort, TrainTestTorchScriptMultiArgCPUGPUGraphs) { - training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", TORCHFORT_DEVICE_CPU, 0, false, false, false, false, true, 5, 5); +TEST(TorchFort, TrainTestTorchScriptMultiArgGPUCPUGraphs) { + training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, false, false, true, true, false, 5, 5); } TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPUGraphs) { training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", 0, 0, false, false, false, false, true, 5, 5); @@ -595,8 +811,8 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUCPU) { TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUGPU) { training_test_multiarg("configs/torchscript_multiarg_extra.yaml", 0, 0, true, false, false, false, true); } -TEST(TorchFort, TrainTestTorchScriptMultiArgExtraCPUGPUGraphs) { - training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", TORCHFORT_DEVICE_CPU, 0, true, false, false, false, true, 5, 5); +TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUCPUGraphs) { + training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, true, false, true, true, false, 5, 5); } TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUGPUGraphs) { training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", 0, 0, true, false, false, false, true, 5, 5); @@ -631,6 +847,15 @@ TEST(TorchFort, TrainTestMLPCPUCPU1DDimError) { TEST(TorchFort, TrainTestMLPGPUGPUStreamWrongDeviceError) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1, 1, 1); } 
+TEST(TorchFort, TrainTestTorchScriptGraphsErrors) { + training_test_graphs_errors(0); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgGraphsErrors) { + training_test_multiarg_graphs_errors(0, false); +} +TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGraphsErrors) { + training_test_multiarg_graphs_errors(0, true); +} #endif int main(int argc, char* argv[]) { From 915b9b45dbd1b48d500e07001523fd88d888df33 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 5 Feb 2026 16:07:28 -0800 Subject: [PATCH 11/13] Formatting. Signed-off-by: Josh Romero --- tests/supervised/test_training.cpp | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index a5449e7..1d4e7a2 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -104,8 +104,8 @@ void training_test(const std::string& model_config, int dev_model, int dev_input copy_from_host_vector(label_ptr, label); } #endif - CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, shape.size(), - shape.data(), &loss_val, TORCHFORT_FLOAT, stream)); + CHECK_TORCHFORT(torchfort_train(model_name.c_str(), input_ptr, shape.size(), shape.data(), label_ptr, + shape.size(), shape.data(), &loss_val, TORCHFORT_FLOAT, stream)); } if (should_fail_train) { FAIL() << "This test should fail train call, but did not."; @@ -588,7 +588,6 @@ void training_test_multiarg_graphs_errors(int dev_input, bool use_extra_args) { } } - torchfort_tensor_list_t inputs_tl, labels_tl, outputs_tl; torchfort_tensor_list_t inputs2_tl, labels2_tl, outputs2_tl; CHECK_TORCHFORT(torchfort_tensor_list_create(&inputs_tl)); @@ -764,7 +763,8 @@ TEST(TorchFort, TrainTestTorchScriptGPUGPU) { training_test("configs/torchscript.yaml", 0, 0, {10, 2, 10}, false, false, false, true); } TEST(TorchFort, TrainTestTorchScriptGPUCPUGraphs) { - training_test("configs/torchscript_graphs.yaml", 0, 
TORCHFORT_DEVICE_CPU, {10, 2, 10}, false, true, true, false, 5, 5); + training_test("configs/torchscript_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, {10, 2, 10}, false, true, true, false, 5, + 5); } TEST(TorchFort, TrainTestTorchScriptGPUGPUGraphs) { training_test("configs/torchscript_graphs.yaml", 0, 0, {10, 2, 10}, false, false, false, true, 5, 5); @@ -781,7 +781,8 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPU) { training_test_multiarg("configs/torchscript_multiarg.yaml", 0, 0, false, false, false, false, true); } TEST(TorchFort, TrainTestTorchScriptMultiArgGPUCPUGraphs) { - training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, false, false, true, true, false, 5, 5); + training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, false, false, true, true, + false, 5, 5); } TEST(TorchFort, TrainTestTorchScriptMultiArgGPUGPUGraphs) { training_test_multiarg("configs/torchscript_multiarg_graphs.yaml", 0, 0, false, false, false, false, true, 5, 5); @@ -812,7 +813,8 @@ TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUGPU) { training_test_multiarg("configs/torchscript_multiarg_extra.yaml", 0, 0, true, false, false, false, true); } TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUCPUGraphs) { - training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, true, false, true, true, false, 5, 5); + training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", 0, TORCHFORT_DEVICE_CPU, true, false, true, + true, false, 5, 5); } TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGPUGPUGraphs) { training_test_multiarg("configs/torchscript_multiarg_extra_graphs.yaml", 0, 0, true, false, false, false, true, 5, 5); @@ -847,15 +849,9 @@ TEST(TorchFort, TrainTestMLPCPUCPU1DDimError) { TEST(TorchFort, TrainTestMLPGPUGPUStreamWrongDeviceError) { training_test("configs/mlp2.yaml", 0, 0, {10, 10}, false, true, true, false, 1, 1, 1); } -TEST(TorchFort, 
TrainTestTorchScriptGraphsErrors) { - training_test_graphs_errors(0); -} -TEST(TorchFort, TrainTestTorchScriptMultiArgGraphsErrors) { - training_test_multiarg_graphs_errors(0, false); -} -TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGraphsErrors) { - training_test_multiarg_graphs_errors(0, true); -} +TEST(TorchFort, TrainTestTorchScriptGraphsErrors) { training_test_graphs_errors(0); } +TEST(TorchFort, TrainTestTorchScriptMultiArgGraphsErrors) { training_test_multiarg_graphs_errors(0, false); } +TEST(TorchFort, TrainTestTorchScriptMultiArgExtraGraphsErrors) { training_test_multiarg_graphs_errors(0, true); } #endif int main(int argc, char* argv[]) { From 3648abf68ba3a1bc5c296eff4ddf35a2535013bf Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 5 Feb 2026 16:25:01 -0800 Subject: [PATCH 12/13] Add error checking to get_loss and get_outputs methods. Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index 951738f..d6b462d 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -166,7 +166,12 @@ class InferenceGraphState { void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } // Get static outputs (valid after CAPTURE or REPLAY) - const std::vector<torch::Tensor>& get_outputs() const { return static_outputs_; } + const std::vector<torch::Tensor>& get_outputs() const { + if (!captured_) { + THROW_INTERNAL_ERROR("Attempting to get static outputs before graph has been captured."); + } + return static_outputs_; + } bool is_captured() const { return captured_; } @@ -297,7 +302,12 @@ class TrainingGraphState { void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); } // Get static loss (valid after CAPTURE or REPLAY) - const torch::Tensor& get_loss() const { return static_loss_; } + const 
torch::Tensor& get_loss() const { + if (!captured_) { + THROW_INTERNAL_ERROR("Attempting to get static loss before graph has been captured."); + } + return static_loss_; + } bool is_captured() const { return captured_; } From 4a9b294eb1187d40555f6a9750a3d72d080075d4 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Fri, 6 Feb 2026 10:14:07 -0800 Subject: [PATCH 13/13] Handle empty tensors in cuda graphs input device checks. Signed-off-by: Josh Romero --- src/csrc/include/internal/cuda_graphs.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/csrc/include/internal/cuda_graphs.h b/src/csrc/include/internal/cuda_graphs.h index d6b462d..0b946a5 100644 --- a/src/csrc/include/internal/cuda_graphs.h +++ b/src/csrc/include/internal/cuda_graphs.h @@ -196,7 +196,9 @@ class InferenceGraphState { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); - if (!t.device().is_cuda()) { + // A user can pass an "empty" GPU tensor as an argument, which will have a nullptr address and + // not be associated with device. Skip those inputs in the device check. + if (t.data_ptr() && !t.device().is_cuda()) { THROW_INVALID_USAGE("Model inputs must be on GPU when enable_cuda_graphs is true."); } } @@ -336,7 +338,9 @@ class TrainingGraphState { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); - if (!t.device().is_cuda()) { + // A user can pass an "empty" GPU tensor as an argument, which will have a nullptr address and + // not be associated with device. Skip those inputs in the device check. 
+ if (t.data_ptr() && !t.device().is_cuda()) { THROW_INVALID_USAGE("Model inputs must be on GPU when enable_cuda_graphs is true."); } } @@ -344,7 +348,7 @@ class TrainingGraphState { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); - if (!t.device().is_cuda()) { + if (t.data_ptr() && !t.device().is_cuda()) { THROW_INVALID_USAGE("Model labels must be on GPU when enable_cuda_graphs is true."); } } @@ -352,7 +356,7 @@ class TrainingGraphState { sig.ptrs.push_back(t.data_ptr()); sig.shapes.push_back(t.sizes().vec()); sig.dtypes.push_back(t.scalar_type()); - if (!t.device().is_cuda()) { + if (t.data_ptr() && !t.device().is_cuda()) { THROW_INVALID_USAGE("Model extra args must be on GPU when enable_cuda_graphs is true."); } }