// Patch summary (leading text before the first "diff --git" header; the patch payload below is unchanged).
//
// cext/tile_kernel.cpp
//   * classify_arg: __cuda_array_interface__ is now tested before __dlpack__, so an object exposing
//     both is classified as PythonArgKind::CudaArray rather than PythonArgKind::DlpackArray.
//   * arrayrepr_cuda_array_iface / arrayrepr_dlpack_common / arrayrepr_torch_tensor_pymethod /
//     arrayrepr_torch_tensor_dlpack and the ArrayReprFunc typedef gain a trailing LaunchHelper&
//     parameter; extract_array now passes `helper` in addition to `helper.arena`. Only
//     arrayrepr_torch_tensor_dlpack names the parameter (it forwards it); the others leave it unnamed.
//   * arrayrepr_dlpack_common: tensor->deleter is called only when it is non-null.
//   * arrayrepr_dlpack (the generic obj.__dlpack__(stream=-1) path) is removed. DlpackArray arguments
//     now raise RuntimeError("__dlpack__ array arguments aren't supported yet") in both
//     get_array_repr and extract_cuda_args.
//   * get_array_repr, extract_py_list and extract_cuda_args take a new `bool stream_is_capturing`.
//     In get_array_repr, a TorchTensorDlpack argument uses arrayrepr_torch_tensor_pymethod when the
//     flag is true and arrayrepr_torch_tensor_dlpack otherwise.
//     NOTE(review): in extract_cuda_args the template argument of extract_array is not visible in
//     this copy, so which function each branch selects cannot be confirmed from here; presumably it
//     mirrors get_array_repr -- verify against the real patch.
//   * New helper needs_stream_capture_status: returns true when any argument kind is
//     TorchTensorDlpack or PyList.
//   * New helper get_stream_capture_status: calls driver->cuStreamIsCapturing, raises RuntimeError
//     on a non-CUDA_SUCCESS result, otherwise returns status != CU_STREAM_CAPTURE_STATUS_NONE.
//   * prepare_launch: computes stream_is_capturing once (only when needs_stream_capture_status says
//     so), passes it to extract_cuda_args, and reuses it for the existing
//     "List argument in CUDAGraph isn't supported yet" check instead of calling cuStreamIsCapturing
//     inline.
//   * The last hunk drops one trailing line after the closing brace of tile_kernel_init.
//
// test/test_cudagraph.py
//   * Adds DlpackProxy, which forwards __dlpack__(stream=...) and __dlpack_device__() to a wrapped
//     tensor.
//   * Adds test_proxy_dlpack_unsupported and test_proxy_dlpack_cudagraph_unsupported, which expect
//     the RuntimeError above when launching add_one with a DlpackProxy argument, outside and inside
//     a torch.cuda.graph capture respectively.
//
// NOTE(review): this copy of the patch has its line breaks collapsed, and angle-bracket template
// arguments appear to be missing (e.g. `static_cast(ptr)`, `extract_array(`). Re-export the patch
// from version control before attempting to apply it.
diff --git a/cext/tile_kernel.cpp b/cext/tile_kernel.cpp index c9578fe..500d599 100644 --- a/cext/tile_kernel.cpp +++ b/cext/tile_kernel.cpp @@ -463,12 +463,14 @@ static Result classify_arg(PyObject* arg) { return PythonArgKind::TorchTensorDlpack; } - if (PyObject_HasAttr(arg, g___dlpack___pyunicode)) - return PythonArgKind::DlpackArray; - + // Prefer __cuda_array_interface__ so objects like CuPy arrays can continue to use the + // supported direct pointer path even if they also expose __dlpack__. if (PyObject_HasAttr(arg, g___cuda_array_interface___pyunicode)) return PythonArgKind::CudaArray; + if (PyObject_HasAttr(arg, g___dlpack___pyunicode)) + return PythonArgKind::DlpackArray; + return raise(PyExc_TypeError, "Unsupported argument type %s", Py_TYPE(arg)->tp_name); } @@ -846,7 +848,7 @@ static PyPtr parse_array_constraint(ConstantCursor& cursor) { static Result arrayrepr_cuda_array_iface(PyObject* pyobj, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper&) { PyPtr dict = steal(PyObject_GetAttr(pyobj, g___cuda_array_interface___pyunicode)); if (!PyDict_Check(dict.get())) { PyErr_SetString(PyExc_TypeError, @@ -924,7 +926,7 @@ static Result arrayrepr_cuda_array_iface(PyObject* pyobj, unsigned in } static Result arrayrepr_dlpack_common(PyObject* dlpack_capsule, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper&) { void* ptr = PyCapsule_GetPointer(dlpack_capsule, "dltensor"); if (!ptr) return ErrorRaised; DLManagedTensor* tensor = static_cast(ptr); @@ -972,15 +974,10 @@ static Result arrayrepr_dlpack_common(PyObject* dlpack_capsule, unsig PyCapsule_SetName(dlpack_capsule, "used_dltensor"); - // We assume that __dlpack__ returns a view of the tensor, - // so we release the capsule immediately. This should be OK for using with PyTorch - // since it always returns a view. - // - // This is technically an incorrect implementation. 
To do it correctly, we would - // need to implement a mechanism similar to the one found in Torch's CUDACachingAllocator: - // instead of calling the deleter immediately, we would push a cudaEvent to the stream - // after we launch the kernel, and only call the deleter once the event is ready. - tensor->deleter(tensor); + // Today this path is only used for torch._C._to_dlpack(), which returns a view. + // Generic __dlpack__ objects are rejected until we support their lifetime contract. + if (tensor->deleter) + tensor->deleter(tensor); return ret; } @@ -996,7 +993,7 @@ static Result dtype_from_torch_dtype(PyObject* torch_dtype) { } static Result arrayrepr_torch_tensor_pymethod(PyObject* tensor, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper&) { PyPtr data_ptr = steal(PyObject_CallMethod(tensor, "data_ptr", nullptr)); if (!data_ptr) return ErrorRaised; @@ -1076,50 +1073,27 @@ static Result arrayrepr_torch_tensor_pymethod(PyObject* tensor, unsig } static Result arrayrepr_torch_tensor_dlpack(PyObject* pyobj, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper& helper) { PyPtr dlpack_capsule = steal(PyObject_CallFunctionObjArgs( g_torch_to_dlpack_func, pyobj, nullptr)); if (!dlpack_capsule) { SavedException exc = save_raised_exception(); LOG_PYTHON_ERROR("debug", exc, "Fail to convert to dlpack, use fallback path"); - return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, arena); + return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, arena, helper); } - return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena); + return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena, helper); } -static Result arrayrepr_dlpack(PyObject* pyobj, unsigned index_bitwidth, - Arena& arena) { - PyPtr dlpack_method = steal(PyObject_GetAttr(pyobj, g___dlpack___pyunicode)); - if (!dlpack_method) return ErrorRaised; - - PyPtr empty_args = steal(PyTuple_New(0)); - if (!empty_args) return ErrorRaised; 
- PyPtr kwargs = steal(PyDict_New()); - if (!kwargs) return ErrorRaised; - - // stream -1 signals "producer must not perform any synchronization" - PyPtr stream_value = steal(PyLong_FromLong(-1)); - if (!stream_value) return ErrorRaised; - PyDict_SetItemString(kwargs.get(), "stream", stream_value.get()); - - PyPtr dlpack_capsule = steal(PyObject_Call( - dlpack_method.get(), empty_args.get(), kwargs.get())); - if (!dlpack_capsule) return ErrorRaised; - - return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena); -} - - -typedef Result (*ArrayReprFunc)(PyObject*, unsigned, Arena&); +typedef Result (*ArrayReprFunc)(PyObject*, unsigned, Arena&, LaunchHelper&); template static Status extract_array(const DriverApi* driver, PyObject* pyobj, unsigned index_bitwidth, LaunchHelper& helper) { - Result ar = F(pyobj, index_bitwidth, helper.arena); + Result ar = F(pyobj, index_bitwidth, helper.arena, helper); if (!ar.is_ok()) return ErrorRaised; size_t num_words = 1 + 2 * ar->arrty.ndim; @@ -1249,14 +1223,17 @@ static PyPtr parse_pyfloat_constraint(ConstantCursor& cursor, bool is_constant) } static Result get_array_repr(PythonArgKind kind, PyObject* pyobj, - unsigned index_bitwidth, Arena& arena) { + unsigned index_bitwidth, bool stream_is_capturing, + LaunchHelper& helper) { switch (kind) { case PythonArgKind::TorchTensorDlpack: - return arrayrepr_torch_tensor_dlpack(pyobj, index_bitwidth, arena); + if (stream_is_capturing) + return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, helper.arena, helper); + return arrayrepr_torch_tensor_dlpack(pyobj, index_bitwidth, helper.arena, helper); case PythonArgKind::DlpackArray: - return arrayrepr_dlpack(pyobj, index_bitwidth, arena); + return raise(PyExc_RuntimeError, "__dlpack__ array arguments aren't supported yet"); case PythonArgKind::CudaArray: - return arrayrepr_cuda_array_iface(pyobj, index_bitwidth, arena); + return arrayrepr_cuda_array_iface(pyobj, index_bitwidth, helper.arena, helper); default: return 
raise(PyExc_AssertionError, "Unexpected argument kind for array: %d", static_cast(kind)); @@ -1264,7 +1241,7 @@ static Result get_array_repr(PythonArgKind kind, PyObject* pyobj, } static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned index_bitwidth, - LaunchHelper& helper) { + bool stream_is_capturing, LaunchHelper& helper) { size_t len = PyList_GET_SIZE(pyobj); if (len > INT32_MAX) return raise(PyExc_TypeError, "List is too long"); @@ -1288,7 +1265,7 @@ static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned PyTypeObject* first_item_type = first_item->ob_type; Result first_repr_res = get_array_repr(first_arg_kind, first_item, index_bitwidth, - helper.arena); + stream_is_capturing, helper); if (!first_repr_res.is_ok()) return ErrorRaised; Word* item_pointers = helper.arena.alloc(len); @@ -1322,7 +1299,8 @@ static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned kind = *res; } - Result repr_res = get_array_repr(kind, item, index_bitwidth, helper.arena); + Result repr_res = get_array_repr(kind, item, index_bitwidth, + stream_is_capturing, helper); if (!repr_res.is_ok()) return ErrorRaised; item_pointers[i].arena_ptr = repr_res->repr; @@ -1372,6 +1350,7 @@ static Status extract_cuda_args(const DriverApi* driver, const Vec& constant_arg_flags, const Vec& int64_index_flags, const Vec& int64_param_flags, + bool stream_is_capturing, LaunchHelper& helper) { CHECK(num_pyargs == arg_kinds.size()); helper.arena.clear(); @@ -1388,14 +1367,18 @@ static Status extract_cuda_args(const DriverApi* driver, switch (arg_kinds[i]) { case PythonArgKind::TorchTensorDlpack: - if (!extract_array( - driver, pyobj, index_bitwidth, helper)) + if (stream_is_capturing) { + if (!extract_array( + driver, pyobj, index_bitwidth, helper)) { + return ErrorRaised; + } + } else if (!extract_array( + driver, pyobj, index_bitwidth, helper)) { return ErrorRaised; + } break; case PythonArgKind::DlpackArray: - if 
(!extract_array(driver, pyobj, index_bitwidth, helper)) - return ErrorRaised; - break; + return raise(PyExc_RuntimeError, "__dlpack__ array arguments aren't supported yet"); case PythonArgKind::CudaArray: if (!extract_array(driver, pyobj, index_bitwidth, helper)) return ErrorRaised; @@ -1410,7 +1393,10 @@ static Status extract_cuda_args(const DriverApi* driver, if (!extract_py_bool(pyobj, is_constant, helper)) return ErrorRaised; break; case PythonArgKind::PyList: - if (!extract_py_list(driver, pyobj, index_bitwidth, helper)) return ErrorRaised; + if (!extract_py_list(driver, pyobj, index_bitwidth, + stream_is_capturing, helper)) { + return ErrorRaised; + } break; } } @@ -2040,6 +2026,16 @@ struct PreparedLaunch { unsigned dynamic_smem_bytes; }; +static bool needs_stream_capture_status(const Vec& arg_kinds) { + for (PythonArgKind kind : arg_kinds) { + if (kind == PythonArgKind::TorchTensorDlpack + || kind == PythonArgKind::PyList) { + return true; + } + } + return false; +} + static Result get_stream_context(const DriverApi* driver, CUstream stream) { CUcontext ctx = nullptr; CUresult res = driver->cuStreamGetCtx(stream, &ctx); @@ -2053,6 +2049,16 @@ static Result get_stream_context(const DriverApi* driver, CUstream st return ctx; } +static Result get_stream_capture_status(const DriverApi* driver, CUstream stream) { + CUstreamCaptureStatus status = CU_STREAM_CAPTURE_STATUS_NONE; + CUresult res = driver->cuStreamIsCapturing(stream, &status); + if (res != CUDA_SUCCESS) { + return raise(PyExc_RuntimeError, "Failed to check stream capturing status: %s", + get_cuda_error(driver, res)); + } + return status != CU_STREAM_CAPTURE_STATUS_NONE; +} + static Result prepare_launch( const DriverApi* driver, PyObject* dispatcher_pyobj, @@ -2107,9 +2113,16 @@ static Result prepare_launch( PythonArgProfile{family_item->value, std::move(arg_kinds)}); } + bool stream_is_capturing = false; + if (needs_stream_capture_status(profile_item->value.arg_kinds)) { + Result capture_status = 
get_stream_capture_status(driver, launch_stream); + if (!capture_status.is_ok()) return ErrorRaised; + stream_is_capturing = *capture_status; + } + if (!extract_cuda_args(driver, pyargs, num_pyargs, profile_item->value.arg_kinds, dispatcher.constant_arg_flags, dispatcher.int64_index_flags, - dispatcher.int64_param_flags, *helper)) { + dispatcher.int64_param_flags, stream_is_capturing, *helper)) { return ErrorRaised; } @@ -2143,12 +2156,7 @@ static Result prepare_launch( // Handle list arguments if (!helper->list_args.empty()) { if (!tx) { - CUstreamCaptureStatus status; - CUresult res = driver->cuStreamIsCapturing(launch_stream, &status); - if (res != CUDA_SUCCESS) - return raise(PyExc_RuntimeError, "Failed to check stream capturing status: %s", - get_cuda_error(driver, res)); - if (status != CU_STREAM_CAPTURE_STATUS_NONE) + if (stream_is_capturing) return raise(PyExc_RuntimeError, "List argument in CUDAGraph isn't supported yet"); Result pool_res = get_stream_buffer_pool(driver, @@ -2971,4 +2979,3 @@ Status tile_kernel_init(PyObject* m) { return OK; } - diff --git a/test/test_cudagraph.py b/test/test_cudagraph.py index b39830d..079bfb0 100644 --- a/test/test_cudagraph.py +++ b/test/test_cudagraph.py @@ -7,6 +7,17 @@ import pytest +class DlpackProxy: + def __init__(self, tensor): + self.tensor = tensor + + def __dlpack__(self, stream=None): + return self.tensor.__dlpack__(stream=stream) + + def __dlpack_device__(self): + return self.tensor.__dlpack_device__() + + @ct.kernel def add_one(x): xi = ct.load(x, 0, ()) @@ -27,6 +38,20 @@ def test_simple(): assert x.item() == 10 +def test_proxy_dlpack_unsupported(): + x = torch.zeros(1, device='cuda') + with pytest.raises(RuntimeError, match=r"__dlpack__ array arguments aren't supported yet"): + ct.launch(torch.cuda.current_stream(), (1,), add_one, (DlpackProxy(x),)) + + +def test_proxy_dlpack_cudagraph_unsupported(): + x = torch.zeros(1, device='cuda') + graph = torch.cuda.CUDAGraph() + with pytest.raises(RuntimeError, 
match=r"__dlpack__ array arguments aren't supported yet"): + with torch.cuda.graph(graph): + ct.launch(torch.cuda.current_stream(), (1,), add_one, (DlpackProxy(x),)) + + @ct.kernel def matmul_accumulate(x, y, z): acc = ct.load(z, (0, 0), (16, 16))