Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 71 additions & 64 deletions cext/tile_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,14 @@ static Result<PythonArgKind> classify_arg(PyObject* arg) {
return PythonArgKind::TorchTensorDlpack;
}

if (PyObject_HasAttr(arg, g___dlpack___pyunicode))
return PythonArgKind::DlpackArray;

// Prefer __cuda_array_interface__ so objects like CuPy arrays can continue to use the
// supported direct pointer path even if they also expose __dlpack__.
if (PyObject_HasAttr(arg, g___cuda_array_interface___pyunicode))
return PythonArgKind::CudaArray;

if (PyObject_HasAttr(arg, g___dlpack___pyunicode))
return PythonArgKind::DlpackArray;

return raise(PyExc_TypeError, "Unsupported argument type %s", Py_TYPE(arg)->tp_name);
}

Expand Down Expand Up @@ -846,7 +848,7 @@ static PyPtr parse_array_constraint(ConstantCursor& cursor) {


static Result<ArrayRepr> arrayrepr_cuda_array_iface(PyObject* pyobj, unsigned index_bitwidth,
Arena<Word>& arena) {
Arena<Word>& arena, LaunchHelper&) {
PyPtr dict = steal(PyObject_GetAttr(pyobj, g___cuda_array_interface___pyunicode));
if (!PyDict_Check(dict.get())) {
PyErr_SetString(PyExc_TypeError,
Expand Down Expand Up @@ -924,7 +926,7 @@ static Result<ArrayRepr> arrayrepr_cuda_array_iface(PyObject* pyobj, unsigned in
}

static Result<ArrayRepr> arrayrepr_dlpack_common(PyObject* dlpack_capsule, unsigned index_bitwidth,
Arena<Word>& arena) {
Arena<Word>& arena, LaunchHelper&) {
void* ptr = PyCapsule_GetPointer(dlpack_capsule, "dltensor");
if (!ptr) return ErrorRaised;
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr);
Expand Down Expand Up @@ -972,15 +974,10 @@ static Result<ArrayRepr> arrayrepr_dlpack_common(PyObject* dlpack_capsule, unsig

PyCapsule_SetName(dlpack_capsule, "used_dltensor");

// We assume that __dlpack__ returns a view of the tensor,
// so we release the capsule immediately. This should be OK for using with PyTorch
// since it always returns a view.
//
// This is technically an incorrect implementation. To do it correctly, we would
// need to implement a mechanism similar to the one found in Torch's CUDACachingAllocator:
// instead of calling the deleter immediately, we would push a cudaEvent to the stream
// after we launch the kernel, and only call the deleter once the event is ready.
tensor->deleter(tensor);
// Today this path is only used for torch._C._to_dlpack(), which returns a view.
// Generic __dlpack__ objects are rejected until we support their lifetime contract.
if (tensor->deleter)
tensor->deleter(tensor);
return ret;
}

Expand All @@ -996,7 +993,7 @@ static Result<DLDataType> dtype_from_torch_dtype(PyObject* torch_dtype) {
}

static Result<ArrayRepr> arrayrepr_torch_tensor_pymethod(PyObject* tensor, unsigned index_bitwidth,
Arena<Word>& arena) {
Arena<Word>& arena, LaunchHelper&) {
PyPtr data_ptr = steal(PyObject_CallMethod(tensor, "data_ptr", nullptr));
if (!data_ptr) return ErrorRaised;

Expand Down Expand Up @@ -1076,50 +1073,27 @@ static Result<ArrayRepr> arrayrepr_torch_tensor_pymethod(PyObject* tensor, unsig
}

static Result<ArrayRepr> arrayrepr_torch_tensor_dlpack(PyObject* pyobj, unsigned index_bitwidth,
Arena<Word>& arena) {
Arena<Word>& arena, LaunchHelper& helper) {
PyPtr dlpack_capsule = steal(PyObject_CallFunctionObjArgs(
g_torch_to_dlpack_func, pyobj, nullptr));

if (!dlpack_capsule) {
SavedException exc = save_raised_exception();
LOG_PYTHON_ERROR("debug", exc, "Fail to convert to dlpack, use fallback path");
return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, arena);
return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, arena, helper);
}

return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena);
return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena, helper);
}

// Builds an ArrayRepr for an arbitrary object that exposes `__dlpack__`.
//
// Calls `pyobj.__dlpack__(stream=-1)` and hands the resulting capsule to
// arrayrepr_dlpack_common() for decoding.
//
// Returns ErrorRaised (with a Python exception pending) if the attribute
// lookup, any temporary allocation, the keyword-dict insertion, or the
// `__dlpack__` call itself fails.
static Result<ArrayRepr> arrayrepr_dlpack(PyObject* pyobj, unsigned index_bitwidth,
                                          Arena<Word>& arena) {
    PyPtr dlpack_method = steal(PyObject_GetAttr(pyobj, g___dlpack___pyunicode));
    if (!dlpack_method) return ErrorRaised;

    PyPtr empty_args = steal(PyTuple_New(0));
    if (!empty_args) return ErrorRaised;

    PyPtr kwargs = steal(PyDict_New());
    if (!kwargs) return ErrorRaised;

    // stream -1 signals "producer must not perform any synchronization"
    PyPtr stream_value = steal(PyLong_FromLong(-1));
    if (!stream_value) return ErrorRaised;

    // PyDict_SetItemString() returns -1 with an exception set on failure.
    // Ignoring it would invoke __dlpack__ without the `stream` keyword while
    // an exception is still pending, so bail out instead.
    if (PyDict_SetItemString(kwargs.get(), "stream", stream_value.get()) < 0)
        return ErrorRaised;

    PyPtr dlpack_capsule = steal(PyObject_Call(
            dlpack_method.get(), empty_args.get(), kwargs.get()));
    if (!dlpack_capsule) return ErrorRaised;

    return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena);
}


typedef Result<ArrayRepr> (*ArrayReprFunc)(PyObject*, unsigned, Arena<Word>&);
typedef Result<ArrayRepr> (*ArrayReprFunc)(PyObject*, unsigned, Arena<Word>&, LaunchHelper&);


template <ArrayReprFunc F>
static Status extract_array(const DriverApi* driver, PyObject* pyobj, unsigned index_bitwidth,
LaunchHelper& helper) {
Result<ArrayRepr> ar = F(pyobj, index_bitwidth, helper.arena);
Result<ArrayRepr> ar = F(pyobj, index_bitwidth, helper.arena, helper);
if (!ar.is_ok()) return ErrorRaised;

size_t num_words = 1 + 2 * ar->arrty.ndim;
Expand Down Expand Up @@ -1249,22 +1223,25 @@ static PyPtr parse_pyfloat_constraint(ConstantCursor& cursor, bool is_constant)
}

static Result<ArrayRepr> get_array_repr(PythonArgKind kind, PyObject* pyobj,
unsigned index_bitwidth, Arena<Word>& arena) {
unsigned index_bitwidth, bool stream_is_capturing,
LaunchHelper& helper) {
switch (kind) {
case PythonArgKind::TorchTensorDlpack:
return arrayrepr_torch_tensor_dlpack(pyobj, index_bitwidth, arena);
if (stream_is_capturing)
return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, helper.arena, helper);
return arrayrepr_torch_tensor_dlpack(pyobj, index_bitwidth, helper.arena, helper);
case PythonArgKind::DlpackArray:
return arrayrepr_dlpack(pyobj, index_bitwidth, arena);
return raise(PyExc_RuntimeError, "__dlpack__ array arguments aren't supported yet");
case PythonArgKind::CudaArray:
return arrayrepr_cuda_array_iface(pyobj, index_bitwidth, arena);
return arrayrepr_cuda_array_iface(pyobj, index_bitwidth, helper.arena, helper);
default:
return raise(PyExc_AssertionError, "Unexpected argument kind for array: %d",
static_cast<int>(kind));
}
}

static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned index_bitwidth,
LaunchHelper& helper) {
bool stream_is_capturing, LaunchHelper& helper) {
size_t len = PyList_GET_SIZE(pyobj);
if (len > INT32_MAX)
return raise(PyExc_TypeError, "List is too long");
Expand All @@ -1288,7 +1265,7 @@ static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned
PyTypeObject* first_item_type = first_item->ob_type;

Result<ArrayRepr> first_repr_res = get_array_repr(first_arg_kind, first_item, index_bitwidth,
helper.arena);
stream_is_capturing, helper);
if (!first_repr_res.is_ok()) return ErrorRaised;

Word* item_pointers = helper.arena.alloc(len);
Expand Down Expand Up @@ -1322,7 +1299,8 @@ static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned
kind = *res;
}

Result<ArrayRepr> repr_res = get_array_repr(kind, item, index_bitwidth, helper.arena);
Result<ArrayRepr> repr_res = get_array_repr(kind, item, index_bitwidth,
stream_is_capturing, helper);
if (!repr_res.is_ok()) return ErrorRaised;
item_pointers[i].arena_ptr = repr_res->repr;

Expand Down Expand Up @@ -1372,6 +1350,7 @@ static Status extract_cuda_args(const DriverApi* driver,
const Vec<bool>& constant_arg_flags,
const Vec<bool>& int64_index_flags,
const Vec<bool>& int64_param_flags,
bool stream_is_capturing,
LaunchHelper& helper) {
CHECK(num_pyargs == arg_kinds.size());
helper.arena.clear();
Expand All @@ -1388,14 +1367,18 @@ static Status extract_cuda_args(const DriverApi* driver,

switch (arg_kinds[i]) {
case PythonArgKind::TorchTensorDlpack:
if (!extract_array<arrayrepr_torch_tensor_dlpack>(
driver, pyobj, index_bitwidth, helper))
if (stream_is_capturing) {
if (!extract_array<arrayrepr_torch_tensor_pymethod>(
driver, pyobj, index_bitwidth, helper)) {
return ErrorRaised;
}
} else if (!extract_array<arrayrepr_torch_tensor_dlpack>(
driver, pyobj, index_bitwidth, helper)) {
return ErrorRaised;
}
break;
case PythonArgKind::DlpackArray:
if (!extract_array<arrayrepr_dlpack>(driver, pyobj, index_bitwidth, helper))
return ErrorRaised;
break;
return raise(PyExc_RuntimeError, "__dlpack__ array arguments aren't supported yet");
case PythonArgKind::CudaArray:
if (!extract_array<arrayrepr_cuda_array_iface>(driver, pyobj, index_bitwidth, helper))
return ErrorRaised;
Expand All @@ -1410,7 +1393,10 @@ static Status extract_cuda_args(const DriverApi* driver,
if (!extract_py_bool(pyobj, is_constant, helper)) return ErrorRaised;
break;
case PythonArgKind::PyList:
if (!extract_py_list(driver, pyobj, index_bitwidth, helper)) return ErrorRaised;
if (!extract_py_list(driver, pyobj, index_bitwidth,
stream_is_capturing, helper)) {
return ErrorRaised;
}
break;
}
}
Expand Down Expand Up @@ -2040,6 +2026,16 @@ struct PreparedLaunch {
unsigned dynamic_smem_bytes;
};

static bool needs_stream_capture_status(const Vec<PythonArgKind>& arg_kinds) {
for (PythonArgKind kind : arg_kinds) {
if (kind == PythonArgKind::TorchTensorDlpack
|| kind == PythonArgKind::PyList) {
return true;
}
}
return false;
}

static Result<CUcontext> get_stream_context(const DriverApi* driver, CUstream stream) {
CUcontext ctx = nullptr;
CUresult res = driver->cuStreamGetCtx(stream, &ctx);
Expand All @@ -2053,6 +2049,16 @@ static Result<CUcontext> get_stream_context(const DriverApi* driver, CUstream st
return ctx;
}

// Queries the driver for the capture state of `stream`.
//
// Returns true when cuStreamIsCapturing() reports any status other than
// CU_STREAM_CAPTURE_STATUS_NONE, false otherwise. If the driver call itself
// fails, a Python RuntimeError carrying the driver's error text is raised.
static Result<bool> get_stream_capture_status(const DriverApi* driver, CUstream stream) {
    CUstreamCaptureStatus capture_state = CU_STREAM_CAPTURE_STATUS_NONE;
    const CUresult rc = driver->cuStreamIsCapturing(stream, &capture_state);
    if (rc == CUDA_SUCCESS)
        return capture_state != CU_STREAM_CAPTURE_STATUS_NONE;

    return raise(PyExc_RuntimeError, "Failed to check stream capturing status: %s",
                 get_cuda_error(driver, rc));
}

static Result<PreparedLaunch> prepare_launch(
const DriverApi* driver,
PyObject* dispatcher_pyobj,
Expand Down Expand Up @@ -2107,9 +2113,16 @@ static Result<PreparedLaunch> prepare_launch(
PythonArgProfile{family_item->value, std::move(arg_kinds)});
}

bool stream_is_capturing = false;
if (needs_stream_capture_status(profile_item->value.arg_kinds)) {
Result<bool> capture_status = get_stream_capture_status(driver, launch_stream);
if (!capture_status.is_ok()) return ErrorRaised;
stream_is_capturing = *capture_status;
}

if (!extract_cuda_args(driver, pyargs, num_pyargs, profile_item->value.arg_kinds,
dispatcher.constant_arg_flags, dispatcher.int64_index_flags,
dispatcher.int64_param_flags, *helper)) {
dispatcher.int64_param_flags, stream_is_capturing, *helper)) {
return ErrorRaised;
}

Expand Down Expand Up @@ -2143,12 +2156,7 @@ static Result<PreparedLaunch> prepare_launch(
// Handle list arguments
if (!helper->list_args.empty()) {
if (!tx) {
CUstreamCaptureStatus status;
CUresult res = driver->cuStreamIsCapturing(launch_stream, &status);
if (res != CUDA_SUCCESS)
return raise(PyExc_RuntimeError, "Failed to check stream capturing status: %s",
get_cuda_error(driver, res));
if (status != CU_STREAM_CAPTURE_STATUS_NONE)
if (stream_is_capturing)
return raise(PyExc_RuntimeError, "List argument in CUDAGraph isn't supported yet");

Result<StreamBufferPool*> pool_res = get_stream_buffer_pool(driver,
Expand Down Expand Up @@ -2971,4 +2979,3 @@ Status tile_kernel_init(PyObject* m) {

return OK;
}

25 changes: 25 additions & 0 deletions test/test_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
import pytest


class DlpackProxy:
    """Thin wrapper that exposes only the DLPack protocol of a wrapped object.

    Both protocol methods delegate straight to the object stored in
    ``self.tensor``; the proxy adds no behavior of its own.
    """

    def __init__(self, tensor):
        # Keep a reference to the wrapped object; all calls are forwarded to it.
        self.tensor = tensor

    def __dlpack__(self, stream=None):
        """Forward to the wrapped object's ``__dlpack__``, passing ``stream`` through."""
        wrapped = self.tensor
        return wrapped.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        """Forward to the wrapped object's ``__dlpack_device__``."""
        wrapped = self.tensor
        return wrapped.__dlpack_device__()


@ct.kernel
def add_one(x):
xi = ct.load(x, 0, ())
Expand All @@ -27,6 +38,20 @@ def test_simple():
assert x.item() == 10


def test_proxy_dlpack_unsupported():
    """Launching with a ``__dlpack__``-only argument must raise RuntimeError."""
    proxy = DlpackProxy(torch.zeros(1, device='cuda'))
    expected = r"__dlpack__ array arguments aren't supported yet"
    with pytest.raises(RuntimeError, match=expected):
        ct.launch(torch.cuda.current_stream(), (1,), add_one, (proxy,))


def test_proxy_dlpack_cudagraph_unsupported():
    """Same rejection as the eager case, but issued inside a ``torch.cuda.graph`` block."""
    proxy = DlpackProxy(torch.zeros(1, device='cuda'))
    cuda_graph = torch.cuda.CUDAGraph()
    expected = r"__dlpack__ array arguments aren't supported yet"
    # The stream is looked up inside the graph context on purpose, so the
    # launch uses whatever stream is current while the graph block is active.
    with pytest.raises(RuntimeError, match=expected), torch.cuda.graph(cuda_graph):
        ct.launch(torch.cuda.current_stream(), (1,), add_one, (proxy,))


@ct.kernel
def matmul_accumulate(x, y, z):
acc = ct.load(z, (0, 0), (16, 16))
Expand Down