Skip to content

Commit 1eaf8c3

Browse files
authored
fix: live_avatar silent audio flow (#36)
Add live_avatar internal audio stream and play_audio support
1 parent 47cc7ee commit 1eaf8c3

3 files changed

Lines changed: 185 additions & 7 deletions

File tree

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""
2+
Audio stream manager for live_avatar mode.
3+
4+
Mirrors the JS SDK's AudioStreamManager — ensures WebRTC always has
5+
audio frames to send even when no user mic/audio is provided.
6+
"""
7+
8+
import asyncio
9+
import fractions
10+
import io
11+
import logging
12+
from collections import deque
13+
from pathlib import Path
14+
from typing import Optional, Union
15+
16+
import av
17+
from aiortc import MediaStreamTrack
18+
19+
logger = logging.getLogger(__name__)
20+
21+
SAMPLE_RATE = 48000
22+
SAMPLES_PER_FRAME = 960 # 20ms at 48kHz
23+
BYTES_PER_SAMPLE = 2 # s16 format
24+
BYTES_PER_FRAME = SAMPLES_PER_FRAME * BYTES_PER_SAMPLE
25+
26+
27+
def _make_silence_frame() -> av.AudioFrame:
    """Build a single 20 ms mono s16 audio frame containing pure silence."""
    silence = av.AudioFrame(samples=SAMPLES_PER_FRAME, layout="mono", format="s16")
    zeros = bytes(BYTES_PER_FRAME)
    for plane in silence.planes:
        plane.update(zeros)
    return silence
32+
33+
34+
class _AudioTrack(MediaStreamTrack):
    """Real-time paced audio track that plays queued frames or silence.

    Delivers one 20 ms frame per 20 ms of wall clock so WebRTC always has
    audio to send, synthesizing silence whenever the queue is empty.
    """

    kind = "audio"

    def __init__(self) -> None:
        super().__init__()
        # Frames waiting to be sent; silence is generated when empty.
        self._queue: deque[av.AudioFrame] = deque()
        # Presentation timestamp in samples, monotonically increasing.
        self._pts = 0
        # Wall-clock time of the first recv(); anchors the pacing schedule.
        self._start: Optional[float] = None
        # Signalled once the currently queued batch has fully drained.
        self._done_event: Optional[asyncio.Event] = None

    async def recv(self) -> av.AudioFrame:
        """Return the next frame, sleeping as needed to keep real-time pace.

        Pops a queued frame when available, otherwise emits silence; sets the
        pending done-event when the last queued frame is consumed.
        """
        # recv() runs inside the event loop, so get_running_loop() is safe;
        # asyncio.get_event_loop() is deprecated in this context (3.10+).
        loop = asyncio.get_running_loop()
        if self._start is None:
            self._start = loop.time()

        # Sleep until this frame's scheduled wall-clock time.
        target = self._start + (self._pts / SAMPLE_RATE)
        delay = target - loop.time()
        if delay > 0:
            await asyncio.sleep(delay)

        if self._queue:
            frame = self._queue.popleft()
            if not self._queue and self._done_event:
                # Last queued frame consumed: wake any play_audio() waiter.
                self._done_event.set()
                self._done_event = None
        else:
            frame = _make_silence_frame()

        frame.pts = self._pts
        frame.sample_rate = SAMPLE_RATE
        frame.time_base = fractions.Fraction(1, SAMPLE_RATE)
        self._pts += SAMPLES_PER_FRAME

        return frame

    def enqueue(self, frames: list[av.AudioFrame], done: asyncio.Event) -> None:
        """Queue frames for playback; *done* is set after they have all played.

        An empty *frames* list resolves *done* immediately so that a waiter
        can never hang on a batch that recv() will never drain.
        """
        if not frames:
            done.set()
            return
        self._queue.extend(frames)
        self._done_event = done

    def clear(self) -> None:
        """Drop all pending frames and release any pending waiter."""
        self._queue.clear()
        if self._done_event:
            self._done_event.set()
            self._done_event = None
77+
78+
79+
class AudioStreamManager:
    """Manages audio for live_avatar mode.

    Provides a continuous audio track that outputs silence by default
    and allows playing audio data through it via play_audio().
    """

    def __init__(self) -> None:
        self._track = _AudioTrack()
        self._playing = False

    def get_track(self) -> MediaStreamTrack:
        """Return the WebRTC audio track (emits silence when idle)."""
        return self._track

    @property
    def is_playing(self) -> bool:
        """True while a play_audio() call is in progress."""
        return self._playing

    async def play_audio(self, audio: Union[bytes, str, Path]) -> None:
        """Play audio through the stream. Resolves when audio finishes playing.

        Any playback already in progress is stopped first.

        Args:
            audio: Audio data as bytes, file path string, or Path object.
        """
        if self._playing:
            self.stop_audio()

        if isinstance(audio, bytes):
            container: av.InputContainer = av.open(io.BytesIO(audio))  # type: ignore[assignment]
        else:
            container: av.InputContainer = av.open(str(audio))  # type: ignore[assignment]

        try:
            # Normalize all input to the track's native format: 48 kHz mono s16.
            resampler = av.AudioResampler(format="s16", layout="mono", rate=SAMPLE_RATE)
            raw = bytearray()

            for frame in container.decode(audio=0):
                for resampled in resampler.resample(frame):
                    raw.extend(bytes(resampled.planes[0]))

            # Flush the resampler's internal buffer (resample(None) drains it).
            for resampled in resampler.resample(None):
                raw.extend(bytes(resampled.planes[0]))
        finally:
            container.close()

        if not raw:
            return

        frames = []
        for i in range(0, len(raw), BYTES_PER_FRAME):
            chunk = raw[i : i + BYTES_PER_FRAME]
            if len(chunk) < BYTES_PER_FRAME:
                # Zero-pad the trailing partial frame with silence.
                chunk.extend(bytes(BYTES_PER_FRAME - len(chunk)))

            frame = av.AudioFrame(samples=SAMPLES_PER_FRAME, layout="mono", format="s16")
            frame.planes[0].update(bytes(chunk))
            frames.append(frame)

        done = asyncio.Event()
        self._playing = True
        self._track.enqueue(frames, done)

        try:
            await done.wait()
        finally:
            # Reset the flag even if the awaiting task is cancelled
            # mid-playback; otherwise is_playing would stay True forever.
            self._playing = False

    def stop_audio(self) -> None:
        """Stop playback immediately, discarding any queued frames."""
        self._track.clear()
        self._playing = False

    def cleanup(self) -> None:
        """Stop playback and permanently end the underlying track."""
        self.stop_audio()
        self._track.stop()

decart/realtime/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiortc import MediaStreamTrack
99
from pydantic import BaseModel
1010

11+
from .audio_stream_manager import AudioStreamManager
1112
from .webrtc_manager import WebRTCManager, WebRTCConfiguration
1213
from .messages import PromptMessage, SessionIdMessage, GenerationTickMessage
1314
from .subscribe import (
@@ -75,6 +76,7 @@ def __init__(
7576
self._manager = manager
7677
self._http_session = http_session
7778
self._model_name = model_name
79+
self._audio_stream_manager: Optional[AudioStreamManager] = None
7880
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
7981
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
8082
self._generation_tick_callbacks: list[Callable[[GenerationTickMessage], None]] = []
@@ -111,6 +113,13 @@ async def connect(
111113

112114
model_name: RealTimeModels = options.model.name # type: ignore[assignment]
113115

116+
is_avatar_live = model_name == "live_avatar"
117+
audio_stream_manager: Optional[AudioStreamManager] = None
118+
119+
if is_avatar_live and local_track is None:
120+
audio_stream_manager = AudioStreamManager()
121+
local_track = audio_stream_manager.get_track()
122+
114123
config = WebRTCConfiguration(
115124
webrtc_url=ws_url,
116125
api_key=api_key,
@@ -126,7 +135,6 @@ async def connect(
126135
model_name=model_name,
127136
)
128137

129-
# Create HTTP session for file conversions
130138
http_session = aiohttp.ClientSession()
131139

132140
manager = WebRTCManager(config)
@@ -135,6 +143,7 @@ async def connect(
135143
http_session=http_session,
136144
model_name=model_name,
137145
)
146+
client._audio_stream_manager = audio_stream_manager
138147

139148
config.on_connection_state_change = client._emit_connection_change
140149
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
@@ -159,6 +168,8 @@ async def connect(
159168
initial_prompt=initial_prompt,
160169
)
161170
except Exception as e:
171+
if audio_stream_manager:
172+
audio_stream_manager.cleanup()
162173
await manager.cleanup()
163174
await http_session.close()
164175
raise WebRTCError(str(e), cause=e)
@@ -319,6 +330,17 @@ async def set_prompt(
319330
finally:
320331
self._manager.unregister_prompt_wait(prompt)
321332

333+
async def play_audio(self, audio: Union[bytes, str, Path]) -> None:
334+
"""Play audio through the avatar stream. Resolves when audio finishes.
335+
336+
Only available for live_avatar connections without a user-provided audio track.
337+
"""
338+
if self._audio_stream_manager is None:
339+
raise InvalidInputError(
340+
"play_audio() is only available for live_avatar without a user-provided audio track"
341+
)
342+
await self._audio_stream_manager.play_audio(audio)
343+
322344
async def set_image(
323345
self,
324346
image: Optional[FileInput],
@@ -349,6 +371,9 @@ def get_connection_state(self) -> ConnectionState:
349371
async def disconnect(self) -> None:
350372
self._buffering = False
351373
self._buffer.clear()
374+
if self._audio_stream_manager:
375+
self._audio_stream_manager.cleanup()
376+
self._audio_stream_manager = None
352377
await self._manager.cleanup()
353378
if self._http_session and not self._http_session.closed:
354379
await self._http_session.close()

examples/avatar_live.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ async def main():
6363
if audio_file:
6464
print(f"🔊 Audio file: {audio_file}")
6565

66-
# Load audio if provided
6766
audio_track = None
67+
6868
if audio_file:
6969
print("Loading audio file...")
7070
player = MediaPlayer(audio_file)
@@ -131,12 +131,15 @@ def on_error(error):
131131
print("✓ Connected!")
132132
print(f"Session ID: {realtime_client.session_id}")
133133

134-
if audio_file:
135-
print("\nPlaying audio through avatar...")
136-
print("(The avatar will animate based on the audio)")
134+
if audio_file and not audio_track:
135+
print("\nPlaying audio via play_audio()...")
136+
await realtime_client.play_audio(audio_file)
137+
print("✓ Audio playback complete")
138+
elif audio_file:
139+
print("\nStreaming audio through avatar via MediaStreamTrack...")
137140
else:
138-
print("\nNo audio provided - avatar will be static")
139-
print("You can update the avatar image dynamically using set_image()")
141+
print("\nNo audio provided - avatar will be idle")
142+
print("You can play audio dynamically using play_audio()")
140143

141144
print("\nPress Ctrl+C to stop and save the recording...")
142145

0 commit comments

Comments
 (0)