diff --git a/source/neuropod/backends/torchscript/torch_backend.cc b/source/neuropod/backends/torchscript/torch_backend.cc index 70c0fe6f..d06a70b9 100644 --- a/source/neuropod/backends/torchscript/torch_backend.cc +++ b/source/neuropod/backends/torchscript/torch_backend.cc @@ -7,6 +7,10 @@ #include "neuropod/backends/torchscript/type_utils.hh" #include "neuropod/internal/tensor_types.hh" +#ifndef __APPLE__ +#include +#endif + #include #include @@ -291,6 +295,16 @@ std::unique_ptr TorchNeuropodBackend::infer_internal(const Neu { torch::NoGradGuard guard; +#ifndef __APPLE__ + // Make sure we're running on the correct device + std::unique_ptr device_guard; + const auto model_device = get_torch_device(DeviceType::GPU); + if (model_device.is_cuda()) + { + device_guard = stdx::make_unique(model_device); + } +#endif + // Get inference schema const auto &method = model_->get_method("forward"); const auto &schema = SCHEMA(method);