Skip to content

Commit 5e576a5

Browse files
authored
feat(displayhook): add register_hook/unregister_hook to ZMQShellDisplayHook (#1522)
2 parents 33a980a + d56c7a3 commit 5e576a5

2 files changed

Lines changed: 252 additions & 3 deletions

File tree

ipykernel/displayhook.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
import builtins
88
import sys
9+
import threading
910
import typing as t
1011
from contextvars import ContextVar
1112

1213
from IPython.core.displayhook import DisplayHook
1314
from jupyter_client.session import Session, extract_header
14-
from traitlets import Any, Instance
15+
from traitlets import Any, Instance, default
1516

1617
from ipykernel.jsonutil import encode_images, json_clean
1718

@@ -80,13 +81,41 @@ class ZMQShellDisplayHook(DisplayHook):
8081
session = Instance(Session, allow_none=True)
8182
pub_socket = Any(allow_none=True)
8283
_parent_header: ContextVar[dict[str, Any]]
84+
_thread_local = Any()
8385
msg: dict[str, t.Any] | None
8486

8587
def __init__(self, *args, **kwargs):
8688
super().__init__(*args, **kwargs)
8789
self._parent_header = ContextVar("parent_header")
8890
self._parent_header.set({})
8991

92+
@default("_thread_local")
93+
def _default_thread_local(self):
94+
return threading.local()
95+
96+
@property
97+
def _hooks(self):
98+
if not hasattr(self._thread_local, "hooks"):
99+
self._thread_local.hooks = []
100+
return self._thread_local.hooks
101+
102+
def register_hook(self, hook):
103+
"""Register a transform hook on the execute_result message.
104+
105+
Mirrors ``ZMQDisplayPublisher.register_hook``. Each hook receives the
106+
outbound message dict and must return either a (possibly mutated)
107+
message dict to continue the chain, or ``None`` to suppress the send.
108+
"""
109+
self._hooks.append(hook)
110+
111+
def unregister_hook(self, hook):
112+
"""Remove a previously registered hook. Returns True on success."""
113+
try:
114+
self._hooks.remove(hook)
115+
return True
116+
except ValueError:
117+
return False
118+
90119
@property
91120
def parent_header(self):
92121
try:
@@ -124,9 +153,22 @@ def write_format_data(self, format_dict, md_dict=None):
124153
self.msg["content"]["metadata"] = md_dict
125154

126155
def finish_displayhook(self):
127-
"""Finish up all displayhook activities."""
156+
"""Finish up all displayhook activities.
157+
158+
Runs the registered hook chain before ``session.send``. Each hook
159+
either returns a message (to continue) or ``None`` (to suppress the
160+
send). This mirrors the transform pipeline on
161+
``ZMQDisplayPublisher.publish`` so a single hook implementation can
162+
attach to both the ``display_data`` and ``execute_result`` paths.
163+
"""
128164
sys.stdout.flush()
129165
sys.stderr.flush()
130166
if self.msg and self.msg["content"]["data"] and self.session:
131-
self.session.send(self.pub_socket, self.msg, ident=self.topic)
167+
msg = self.msg
168+
for hook in self._hooks:
169+
msg = hook(msg)
170+
if msg is None:
171+
self.msg = None
172+
return
173+
self.session.send(self.pub_socket, msg, ident=self.topic)
132174
self.msg = None

tests/test_displayhook.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""Tests for the ZMQ execute_result displayhook."""
2+
3+
# Copyright (c) IPython Development Team.
4+
# Distributed under the terms of the Modified BSD License.
5+
6+
import unittest
7+
from queue import Queue
8+
from threading import Thread
9+
10+
import zmq
11+
from IPython.core.interactiveshell import InteractiveShell
12+
from jupyter_client.session import Session
13+
from traitlets import Int
14+
15+
from ipykernel.displayhook import ZMQShellDisplayHook
16+
17+
18+
class NoReturnHook:
19+
call_count = 0
20+
21+
def __call__(self, msg):
22+
self.call_count += 1
23+
24+
25+
class ReturnHook(NoReturnHook):
26+
def __call__(self, msg):
27+
super().__call__(msg)
28+
return msg
29+
30+
31+
class MutatingHook(NoReturnHook):
32+
"""Attaches a buffer to the message and returns it."""
33+
34+
def __call__(self, msg):
35+
super().__call__(msg)
36+
msg.setdefault("buffers", []).append(b"arrow-bytes")
37+
return msg
38+
39+
40+
class CounterSession(Session):
41+
send_count = Int(0)
42+
last_msg = None
43+
44+
def send(self, *args, **kwargs):
45+
self.send_count += 1
46+
# args: (stream, msg_or_type, ...)
47+
if len(args) >= 2:
48+
self.last_msg = args[1]
49+
return super().send(*args, **kwargs)
50+
51+
52+
def _drive(hook, data=None):
53+
"""Run a single execute_result emission through the hook."""
54+
if data is None:
55+
data = {"text/plain": "1"}
56+
hook.start_displayhook()
57+
hook.write_format_data(data, {})
58+
hook.finish_displayhook()
59+
60+
61+
class ZMQShellDisplayHookTests(unittest.TestCase):
62+
def setUp(self):
63+
self.context = zmq.Context()
64+
self.socket = self.context.socket(zmq.PUB)
65+
self.session = CounterSession()
66+
self.shell = InteractiveShell()
67+
self.disp = ZMQShellDisplayHook(shell=self.shell)
68+
self.disp.session = self.session
69+
self.disp.pub_socket = self.socket
70+
71+
def tearDown(self):
72+
self.socket.close()
73+
self.context.term()
74+
75+
def test_no_hooks_sends_message(self):
76+
"""With no hooks registered, finish_displayhook still calls send."""
77+
assert self.disp._hooks == []
78+
_drive(self.disp)
79+
assert self.session.send_count == 1
80+
81+
def test_thread_local_hooks(self):
82+
"""_hooks is thread-local: registering on one thread doesn't leak."""
83+
assert self.disp._hooks == []
84+
85+
def hook(msg):
86+
return msg
87+
88+
self.disp.register_hook(hook)
89+
assert self.disp._hooks == [hook]
90+
91+
q: Queue = Queue()
92+
93+
def read_other_thread():
94+
q.put(self.disp._hooks)
95+
96+
t = Thread(target=read_other_thread)
97+
t.start()
98+
other = q.get(timeout=10)
99+
t.join()
100+
assert other == []
101+
102+
def test_hook_returning_none_halts_send(self):
103+
"""A hook that returns None suppresses session.send."""
104+
hook = NoReturnHook()
105+
self.disp.register_hook(hook)
106+
107+
_drive(self.disp)
108+
109+
assert hook.call_count == 1
110+
assert self.session.send_count == 0
111+
assert self.disp.msg is None
112+
113+
def test_hook_returning_msg_calls_send(self):
114+
"""A hook that returns the message lets it through to send."""
115+
hook = ReturnHook()
116+
self.disp.register_hook(hook)
117+
118+
_drive(self.disp)
119+
120+
assert hook.call_count == 1
121+
assert self.session.send_count == 1
122+
123+
def test_hook_can_mutate_message(self):
124+
"""A hook can attach buffers (the original motivation)."""
125+
hook = MutatingHook()
126+
self.disp.register_hook(hook)
127+
128+
_drive(self.disp)
129+
130+
assert hook.call_count == 1
131+
assert self.session.send_count == 1
132+
sent = self.session.last_msg
133+
assert sent is not None
134+
assert sent.get("buffers") == [b"arrow-bytes"]
135+
136+
def test_hook_chain_short_circuits(self):
137+
"""If an early hook returns None, later hooks are not called."""
138+
first = NoReturnHook()
139+
second = NoReturnHook()
140+
self.disp.register_hook(first)
141+
self.disp.register_hook(second)
142+
143+
_drive(self.disp)
144+
145+
assert first.call_count == 1
146+
assert second.call_count == 0
147+
assert self.session.send_count == 0
148+
149+
def test_hook_chain_threads_message(self):
150+
"""Each hook receives the message returned by the previous hook."""
151+
observed: list[dict] = []
152+
153+
def first(msg):
154+
msg["content"]["metadata"]["seen_by_first"] = True
155+
return msg
156+
157+
def second(msg):
158+
observed.append(msg)
159+
return msg
160+
161+
self.disp.register_hook(first)
162+
self.disp.register_hook(second)
163+
164+
_drive(self.disp)
165+
166+
assert len(observed) == 1
167+
assert observed[0]["content"]["metadata"].get("seen_by_first") is True
168+
assert self.session.send_count == 1
169+
170+
def test_unregister_hook(self):
171+
"""Unregistered hooks no longer run; double-unregister returns False."""
172+
hook = NoReturnHook()
173+
self.disp.register_hook(hook)
174+
175+
_drive(self.disp)
176+
assert hook.call_count == 1
177+
assert self.session.send_count == 0
178+
179+
first = self.disp.unregister_hook(hook)
180+
assert bool(first)
181+
182+
_drive(self.disp)
183+
# Hook didn't run again, but the message went out via session.send.
184+
assert hook.call_count == 1
185+
assert self.session.send_count == 1
186+
187+
# Unregistering an unknown hook returns False.
188+
assert not bool(self.disp.unregister_hook(hook))
189+
190+
def test_empty_data_skips_send_and_hooks(self):
191+
"""The existing guard: if content.data is empty, don't send or hook."""
192+
hook = ReturnHook()
193+
self.disp.register_hook(hook)
194+
195+
# start_displayhook initializes self.msg with empty data; if we never
196+
# call write_format_data, the data dict stays empty and finish should
197+
# short-circuit before calling either hooks or send.
198+
self.disp.start_displayhook()
199+
self.disp.finish_displayhook()
200+
201+
assert hook.call_count == 0
202+
assert self.session.send_count == 0
203+
assert self.disp.msg is None
204+
205+
206+
if __name__ == "__main__":
207+
unittest.main()

0 commit comments

Comments
 (0)