Skip to content

Commit f6d03bb

Browse files
committed
Make doctest extension work across python versions
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent c0eba43 commit f6d03bb

2 files changed

Lines changed: 34 additions & 27 deletions

File tree

docs/source/_ext/doctest_ext.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,15 @@ class FdCaptureDocTestRunner(SphinxDocTestRunner):
125125
(failure reports) reaches the terminal instead of the temp file.
126126
"""
127127

128+
def __init__(self, *args, **kwargs):
129+
super().__init__(*args, **kwargs)
130+
self._test_stats = {}
131+
128132
def summarize(self, out, verbose=None):
129-
groups = self._name2ft
130133
totalf = totalt = 0
131134
lines = []
132-
only_default = set(groups.keys()) <= {'default'}
133-
for name, (f, t) in sorted(groups.items()):
135+
only_default = set(self._test_stats.keys()) <= {'default'}
136+
for name, (f, t) in sorted(self._test_stats.items()):
134137
totalf += f
135138
totalt += t
136139
if not only_default:
@@ -156,7 +159,10 @@ def run(self, test, compileflags=None, out=None, clear_globs=True):
156159
orig_logger = _mod.logger
157160
_mod.logger = _TerminalLogger(orig_logger, saved_fd)
158161
try:
159-
return super().run(test, compileflags, out, clear_globs)
162+
result = super().run(test, compileflags, out, clear_globs)
163+
f, t = self._test_stats.get(test.name, (0, 0))
164+
self._test_stats[test.name] = (f + result.failed, t + result.attempted)
165+
return result
160166
finally:
161167
_mod.logger = orig_logger
162168
self._fakeout = save_fakeout

src/cuda/tile/_stub.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,22 +1193,22 @@ def _doc_atomic_rmw_op(f):
11931193
11941194
Examples:
11951195
1196-
.. testcode::
1197-
:template: setup_only.py
1196+
.. testcode::
1197+
:template: setup_only.py
11981198
1199-
@ct.kernel
1200-
def kernel(x):
1201-
indices = ct.arange(4, dtype=ct.int32)
1202-
update = ct.full((4,), 10, dtype=ct.int32)
1203-
old = ct.{op_name}(x, indices, update)
1204-
print(old)
1199+
@ct.kernel
1200+
def kernel(x):
1201+
indices = ct.arange(4, dtype=ct.int32)
1202+
update = ct.full((4,), 10, dtype=ct.int32)
1203+
old = ct.{op_name}(x, indices, update)
1204+
print(old)
12051205
1206-
x = torch.ones(4, dtype=torch.int32, device='cuda')
1207-
ct.launch(stream, (1,), kernel, (x,))
1206+
x = torch.ones(4, dtype=torch.int32, device='cuda')
1207+
ct.launch(stream, (1,), kernel, (x,))
12081208
1209-
.. testoutput::
1209+
.. testoutput::
12101210
1211-
[1, 1, 1, 1]
1211+
[1, 1, 1, 1]
12121212
"""
12131213

12141214
return f
@@ -1930,20 +1930,21 @@ def wrapped(*args, **kwargs):
19301930
orig_doc = f.__doc__ or ""
19311931
extra_block = _math_op_extra_block(f, indent=" ")
19321932

1933-
wrapped.__doc__ = f"""Performs {op_name} reduction on tile along the `axis`.
1933+
wrapped.__doc__ = f"""\
1934+
Performs {op_name} reduction on tile along the `axis`.
19341935
1935-
Args:
1936-
x (Tile): input tile.
1937-
axis (None | const int | tuple[const int,...]): the axis for reduction.
1938-
The default, `axis=None`, will reduce all of the elements.
1939-
For `argmin` and `argmax`, tuple of axis is not supported.
1940-
keepdims (const bool): If true, preserves the number of dimension
1941-
from the input tile.{extra_block}
1936+
Args:
1937+
x (Tile): input tile.
1938+
axis (None | const int | tuple[const int,...]): the axis for reduction.
1939+
The default, `axis=None`, will reduce all of the elements.
1940+
For `argmin` and `argmax`, tuple of axis is not supported.
1941+
keepdims (const bool): If true, preserves the number of dimension
1942+
from the input tile.{extra_block}
19421943
1943-
Returns:
1944-
Tile:
1944+
Returns:
1945+
Tile:
19451946
1946-
""" + orig_doc
1947+
""" + orig_doc
19471948

19481949
return wrapped
19491950

0 commit comments

Comments
 (0)