Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions optimized/mlx/scripts/sa3_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ def main():
# ── Sampling ──────────────────────────────────────────────────────────────
ap.add_argument("--seconds", type=float, default=30.0,
help="Output audio length in seconds. T_lat (latent positions) is derived as "
"ceil(seconds * 44100 / 4096), then bumped to even when --decoder=same-s "
"(encoder modulo-32 padding requirement). Final WAV is trimmed to exactly "
"--seconds.")
"ceil(seconds * 44100 / 4096) — a natural ceil that depends ONLY on "
"--seconds (and --seed via the DiT), never on --decoder. This keeps the "
"sampled latent (and hence the music) identical across decoders for a "
"given prompt/seed. Final WAV is trimmed to exactly --seconds.")
ap.add_argument("--steps", type=int, default=8,
help="Number of pingpong sampling steps. Minimum 1 (single forward pass — fastest, "
"lowest quality). rf_denoiser is distilled for 8 (default — sweet spot). "
Expand Down Expand Up @@ -465,10 +466,15 @@ def main():
# Empty prompt is allowed — T5Gemma will produce padding-only embeddings,
# which (with the learned padding_embedding) is the unconditional case.
dtype = mx.float32 if args.dit_dtype == "fp32" else mx.float16
# T_lat is a DECODER-INDEPENDENT function of --seconds only (natural ceil, no
# even-bump). This matches the TensorRT pipeline (sa3_trt.py::resolve_T_lat) so
# the same prompt/seed/seconds yields the SAME latent — and hence the same music —
# regardless of --decoder. The old code bumped T_lat to even for same-s, which made
# make_noise(T_lat, seed) draw a different noise tensor and thus sample a completely
# different latent just by switching decoder. SAME-S's internal-length constraint
# (T_lat*17 must align to 34, i.e. even) is satisfied by the decode dispatch below:
# odd T_lat is routed through decode_chunked, whose every window is an even kernel.
T_lat = max(1, math.ceil(args.seconds * SAMPLE_RATE / SAMPLES_PER_LATENT))
# SAME-S requires T_audio_patches divisible by 32 → T_lat must be even (T_aud=T_lat*16).
if args.decoder == "same-s" and T_lat % 2 != 0:
T_lat += 1
target_dur = T_lat * SAMPLES_PER_LATENT / SAMPLE_RATE

# Inpaint validation + parameter mapping
Expand Down Expand Up @@ -596,8 +602,15 @@ def main():
sub(f"encoder load {(time.time()-t0)*1000:.0f} ms")

t0 = time.time()
# The ENCODER (SAME-S) needs T_audio_patches divisible by `pad_mod` (32) → an
# even latent length. T_lat itself stays the natural-ceil (decoder-independent)
# value; we only round the encode grid UP to the modulo, then trim the latent
# back to T_lat. This keeps the DiT/noise path identical across decoders.
enc_T_lat = T_lat
if (T_lat * 16) % pad_mod != 0:
enc_T_lat = math.ceil((T_lat * 16) / pad_mod) * pad_mod // 16
target_samples = enc_T_lat * SAMPLES_PER_LATENT
audio_np = read_wav(args.init_audio) # (2, T_audio_in)
target_samples = T_lat * SAMPLES_PER_LATENT
if audio_np.shape[-1] >= target_samples:
audio_np = audio_np[:, :target_samples]
init_action = f"trimmed to {target_samples} samples"
Expand All @@ -610,13 +623,13 @@ def main():

# Patch + encode (encoder always runs FP32 — softnorm-bottleneck-sensitive)
t0 = time.time()
patches_np = patch_audio(audio_np, patch_size=256) # (1, 512, T_lat*16)
patches_np = patch_audio(audio_np, patch_size=256) # (1, 512, enc_T_lat*16)
# Sanity: T_audio_patches must be divisible by encoder's required modulo
assert patches_np.shape[-1] % pad_mod == 0, (
f"T_audio_patches={patches_np.shape[-1]} not divisible by {pad_mod} "
f"(decoder={args.decoder})"
)
init_latents = enc_model(mx.array(patches_np))
init_latents = enc_model(mx.array(patches_np))[..., :T_lat] # trim to natural T_lat
mx.eval(init_latents)
_stage_peak_b('Init audio encode')
sub(f"encode {(time.time()-t0)*1000:.0f} ms latents {init_latents.shape}")
Expand Down Expand Up @@ -742,14 +755,25 @@ def _on_step(i: int, total: int):
t0 = time.time()
kernel = chunk + 2 * ovl
if T_lat > kernel:
# Natural-ceil T_lat may be odd; decode_chunked windows are all even kernels,
# so odd T_lat decodes correctly here (this is the common path, incl. 30s→323).
patches = chunk_fn(decoder, latents_fp32, chunk, ovl)
decode_mode = f"chunked (chunk={chunk}, ovl={ovl})"
elif T_lat % 2 == 0:
patches = decoder(latents_fp32)
decode_mode = "un-chunked"
else:
elif T_lat > 6:
# Odd T_lat in (6, kernel]: a chunk=2,ovl=2 kernel (=6 < T_lat) keeps every
# window even for SAME-S while never zero-padding real latents.
patches = chunk_fn(decoder, latents_fp32, 2, 2)
decode_mode = "chunked (chunk=2, ovl=2)"
else:
# Tiny odd T_lat ≤ 6 (sub-0.5s): no even chunk kernel fits. Reflect-pad one
# latent to make T even, un-chunked decode, then trim the extra 16 patches.
# SAME-L has no even constraint and takes this path only for symmetry.
latents_even = mx.concatenate([latents_fp32, latents_fp32[..., -1:]], axis=-1)
patches = decoder(latents_even)[..., : T_lat * 16]
decode_mode = "un-chunked (odd-pad→trim)"
mx.eval(patches)
_stage_peak_b('Decode')
sub(f"decode {decode_mode} → {(time.time()-t0)*1000:.0f} ms patches {patches.shape}")
Expand Down
Loading