Skip to content

Commit 480fffc

Browse files
authored
Merge pull request #15 from OpenBioSim/fix_14
Fix issue #14
2 parents 90f0696 + ff787c1 commit 480fffc

2 files changed

Lines changed: 103 additions & 5 deletions

File tree

src/loch/_platforms/_cuda.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def __init__(
114114
# Use the primary context (shared with OpenMM and other CUDA users).
115115
self._pycuda_context = self._cuda_device.retain_primary_context()
116116
self._pycuda_context.push()
117+
self._push_count = 1
117118

118119
self._device = self._pycuda_context.get_device()
119120

@@ -256,22 +257,26 @@ def push_context(self):
256257
Push the primary context onto the calling thread's context stack.
257258
"""
258259
self._pycuda_context.push()
260+
self._push_count += 1
259261

260262
def pop_context(self):
261263
"""
262264
Pop the primary context from the calling thread's context stack.
263265
"""
264266
self._pycuda_context.pop()
267+
self._push_count -= 1
265268

266269
def cleanup(self):
267270
"""
268-
Clean up CUDA resources and pop the context pushed during __init__.
271+
Clean up CUDA resources and pop all outstanding context pushes.
269272
"""
270273
if self._pycuda_context is not None:
271-
try:
272-
self._pycuda_context.pop()
273-
except Exception:
274-
pass
274+
for _ in range(self._push_count):
275+
try:
276+
self._pycuda_context.pop()
277+
except Exception:
278+
pass
279+
self._push_count = 0
275280
self._pycuda_context = None
276281

277282
@property

tests/test_platform.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pytest
4+
5+
# Skip the entire module if PyCUDA is not installed.
6+
pytest.importorskip("pycuda")
7+
8+
9+
def _make_backend(mock_driver):
10+
"""Instantiate CUDAPlatform with a mocked PyCUDA driver."""
11+
from loch._platforms._cuda import CUDAPlatform
12+
13+
mock_driver.Device.count.return_value = 1
14+
mock_context = MagicMock()
15+
mock_driver.Device.return_value.retain_primary_context.return_value = mock_context
16+
17+
backend = CUDAPlatform(
18+
device=0,
19+
num_points=3,
20+
num_batch=10,
21+
num_waters=5,
22+
num_atoms=100,
23+
num_threads=32,
24+
)
25+
return backend, mock_context
26+
27+
28+
class TestCUDAPushCount:
29+
"""Tests for CUDAPlatform context push-count tracking."""
30+
31+
def test_initial_push_count(self):
32+
"""Push count starts at 1 after __init__ (one push for the lifetime context)."""
33+
with patch("loch._platforms._cuda._cuda") as mock_driver:
34+
backend, mock_context = _make_backend(mock_driver)
35+
assert backend._push_count == 1
36+
mock_context.push.assert_called_once()
37+
38+
def test_push_increments_count(self):
39+
"""push_context() increments _push_count."""
40+
with patch("loch._platforms._cuda._cuda") as mock_driver:
41+
backend, _ = _make_backend(mock_driver)
42+
backend.push_context()
43+
assert backend._push_count == 2
44+
backend.push_context()
45+
assert backend._push_count == 3
46+
47+
def test_pop_decrements_count(self):
48+
"""pop_context() decrements _push_count."""
49+
with patch("loch._platforms._cuda._cuda") as mock_driver:
50+
backend, _ = _make_backend(mock_driver)
51+
backend.push_context()
52+
backend.pop_context()
53+
assert backend._push_count == 1
54+
55+
def test_cleanup_pops_once_normally(self):
56+
"""cleanup() pops exactly once when no extra pushes are outstanding."""
57+
with patch("loch._platforms._cuda._cuda") as mock_driver:
58+
backend, mock_context = _make_backend(mock_driver)
59+
backend.cleanup()
60+
assert mock_context.pop.call_count == 1
61+
assert backend._push_count == 0
62+
assert backend._pycuda_context is None
63+
64+
def test_cleanup_pops_all_outstanding(self):
65+
"""cleanup() pops all outstanding pushes, simulating a crash mid-move."""
66+
with patch("loch._platforms._cuda._cuda") as mock_driver:
67+
backend, mock_context = _make_backend(mock_driver)
68+
# Simulate two push_context() calls that were never popped (e.g.
69+
# two GCMC moves crashed before their paired pop_context()).
70+
backend.push_context()
71+
backend.push_context()
72+
assert backend._push_count == 3
73+
backend.cleanup()
74+
assert mock_context.pop.call_count == 3
75+
assert backend._push_count == 0
76+
assert backend._pycuda_context is None
77+
78+
def test_cleanup_handles_pop_exception(self):
79+
"""cleanup() continues safely if pop() raises (e.g. stack already empty)."""
80+
with patch("loch._platforms._cuda._cuda") as mock_driver:
81+
backend, mock_context = _make_backend(mock_driver)
82+
mock_context.pop.side_effect = Exception("context stack is empty")
83+
backend.cleanup()
84+
assert backend._pycuda_context is None
85+
86+
def test_cleanup_idempotent(self):
87+
"""Calling cleanup() a second time is a no-op."""
88+
with patch("loch._platforms._cuda._cuda") as mock_driver:
89+
backend, mock_context = _make_backend(mock_driver)
90+
backend.cleanup()
91+
pop_count_after_first = mock_context.pop.call_count
92+
backend.cleanup()
93+
assert mock_context.pop.call_count == pop_count_after_first

0 commit comments

Comments
 (0)