-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTest.py
More file actions
50 lines (40 loc) · 1.51 KB
/
Test.py
File metadata and controls
50 lines (40 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import slangpy as spy
import numpy as np
from pathlib import Path
def testWaveOps(funcName):
    """Run one entry point of ``TestWaveOps.slang`` and print its output buffer.

    Prints ``funcName``, creates a new device, loads ``funcName`` from
    ``TestWaveOps.slang`` (looked up next to this script), dispatches it with
    a size of ``[num_threads, 1, 1]`` and a buffer bound to the shader global
    ``gOutputBuffer``, waits for the device to go idle, then reads the buffer
    back and prints one value per 4-byte element.

    Args:
        funcName: Name of the entry point in ``TestWaveOps.slang`` to run.
    """
    print(funcName)

    example_dir = Path(__file__).parent

    # Create device; the shader is resolved relative to this script's folder.
    device = spy.Device(
        enable_debug_layers=True,
        compiler_options={"include_paths": [example_dir]},
    )

    # Load test program
    test_program = device.load_program("TestWaveOps.slang", [funcName])
    test_kernel = device.create_compute_kernel(test_program)

    # Create output buffer: one 4-byte element per thread, 64 threads.
    # NOTE(review): how many waves 64 threads span depends on the device's
    # wave size — confirm against the target hardware.
    num_threads = 64
    element_size = 4
    output_buffer = device.create_buffer(
        struct_size=element_size,
        usage=(
            spy.BufferUsage.unordered_access
            | spy.BufferUsage.copy_source
            | spy.BufferUsage.shader_resource
        ),
        size=element_size * num_threads,
    )

    # Run the test
    command_encoder = device.create_command_encoder()
    with command_encoder.begin_compute_pass() as pass_encoder:
        shader_object = pass_encoder.bind_pipeline(test_kernel.pipeline)
        cursor = spy.ShaderCursor(shader_object)
        cursor.gOutputBuffer = output_buffer
        pass_encoder.dispatch([num_threads, 1, 1])
    device.submit_command_buffer(command_encoder.finish())
    device.wait_for_idle()

    # Read back results. The buffer comes back as raw bytes; reinterpret each
    # 4-byte element as a little-endian unsigned 32-bit value. Taking every
    # fourth byte instead would keep only the low byte and silently truncate
    # any value above 255.
    # NOTE(review): assumes the shader writes 32-bit unsigned integers to
    # gOutputBuffer — confirm against TestWaveOps.slang.
    raw_bytes = output_buffer.to_numpy()
    results = raw_bytes.view(np.dtype("<u4"))
    print(results)
def main():
    """Run every wave-op test entry point, one after another."""
    entry_points = (
        'testWaveGetLaneIndex',
        'testWaveActiveCountBits',
        'testWavePrefixCountBits',
    )
    for entry_point in entry_points:
        testWaveOps(entry_point)
if __name__ == "__main__":
    # Only run the tests when executed as a script, not when imported.
    main()