diff --git a/optimized/mlx/scripts/sa3_mlx.py b/optimized/mlx/scripts/sa3_mlx.py index c8893b6..ae22b0e 100644 --- a/optimized/mlx/scripts/sa3_mlx.py +++ b/optimized/mlx/scripts/sa3_mlx.py @@ -391,9 +391,10 @@ def main(): # ── Sampling ────────────────────────────────────────────────────────────── ap.add_argument("--seconds", type=float, default=30.0, help="Output audio length in seconds. T_lat (latent positions) is derived as " - "ceil(seconds * 44100 / 4096), then bumped to even when --decoder=same-s " - "(encoder modulo-32 padding requirement). Final WAV is trimmed to exactly " - "--seconds.") + "ceil(seconds * 44100 / 4096) — a natural ceil that depends ONLY on " + "--seconds (and --seed via the DiT), never on --decoder. This keeps the " + "sampled latent (and hence the music) identical across decoders for a " + "given prompt/seed. Final WAV is trimmed to exactly --seconds.") ap.add_argument("--steps", type=int, default=8, help="Number of pingpong sampling steps. Minimum 1 (single forward pass — fastest, " "lowest quality). rf_denoiser is distilled for 8 (default — sweet spot). " @@ -465,10 +466,15 @@ def main(): # Empty prompt is allowed — T5Gemma will produce padding-only embeddings, # which (with the learned padding_embedding) is the unconditional case. dtype = mx.float32 if args.dit_dtype == "fp32" else mx.float16 + # T_lat is a DECODER-INDEPENDENT function of --seconds only (natural ceil, no + # even-bump). This matches the TensorRT pipeline (sa3_trt.py::resolve_T_lat) so + # the same prompt/seed/seconds yields the SAME latent — and hence the same music — + # regardless of --decoder. The old code bumped T_lat to even for same-s, which made + # make_noise(T_lat, seed) draw a different noise tensor and thus sample a completely + # different latent just by switching decoder. SAME-S's internal-length constraint + # (T_lat*17 must align to 34, i.e. even) is satisfied by the decode dispatch below: + # odd T_lat is routed through decode_chunked, whose every window is an even kernel. T_lat = max(1, math.ceil(args.seconds * SAMPLE_RATE / SAMPLES_PER_LATENT)) - # SAME-S requires T_audio_patches divisible by 32 → T_lat must be even (T_aud=T_lat*16). - if args.decoder == "same-s" and T_lat % 2 != 0: - T_lat += 1 target_dur = T_lat * SAMPLES_PER_LATENT / SAMPLE_RATE # Inpaint validation + parameter mapping @@ -596,8 +602,15 @@ def main(): sub(f"encoder load {(time.time()-t0)*1000:.0f} ms") t0 = time.time() + # The ENCODER (SAME-S) needs T_audio_patches divisible by `pad_mod` (32) → an + # even latent length. T_lat itself stays the natural-ceil (decoder-independent) + # value; we only round the encode grid UP to the modulo, then trim the latent + # back to T_lat. This keeps the DiT/noise path identical across decoders. + enc_T_lat = T_lat + if (T_lat * 16) % pad_mod != 0: + enc_T_lat = math.ceil((T_lat * 16) / pad_mod) * pad_mod // 16 + target_samples = enc_T_lat * SAMPLES_PER_LATENT audio_np = read_wav(args.init_audio) # (2, T_audio_in) - target_samples = T_lat * SAMPLES_PER_LATENT if audio_np.shape[-1] >= target_samples: audio_np = audio_np[:, :target_samples] init_action = f"trimmed to {target_samples} samples" @@ -610,13 +623,13 @@ def main(): # Patch + encode (encoder always runs FP32 — softnorm-bottleneck-sensitive) t0 = time.time() - patches_np = patch_audio(audio_np, patch_size=256) # (1, 512, T_lat*16) + patches_np = patch_audio(audio_np, patch_size=256) # (1, 512, enc_T_lat*16) # Sanity: T_audio_patches must be divisible by encoder's required modulo assert patches_np.shape[-1] % pad_mod == 0, ( f"T_audio_patches={patches_np.shape[-1]} not divisible by {pad_mod} " f"(decoder={args.decoder})" ) - init_latents = enc_model(mx.array(patches_np)) + init_latents = enc_model(mx.array(patches_np))[..., :T_lat] # trim to natural T_lat mx.eval(init_latents) _stage_peak_b('Init audio encode') sub(f"encode {(time.time()-t0)*1000:.0f} ms latents {init_latents.shape}") @@ -742,14 +755,25 @@ def _on_step(i: int, total: int): t0 = time.time() kernel = chunk + 2 * ovl if T_lat > kernel: + # Natural-ceil T_lat may be odd; decode_chunked windows are all even kernels, + # so odd T_lat decodes correctly here (this is the common path, incl. 30s→323). patches = chunk_fn(decoder, latents_fp32, chunk, ovl) decode_mode = f"chunked (chunk={chunk}, ovl={ovl})" elif T_lat % 2 == 0: patches = decoder(latents_fp32) decode_mode = "un-chunked" - else: + elif T_lat > 6: + # Odd T_lat in (6, kernel]: a chunk=2,ovl=2 kernel (=6 < T_lat) keeps every + # window even for SAME-S while never zero-padding real latents. patches = chunk_fn(decoder, latents_fp32, 2, 2) decode_mode = "chunked (chunk=2, ovl=2)" + else: + # Tiny odd T_lat ≤ 6 (sub-0.5s): no even chunk kernel fits. Reflect-pad one + # latent to make T even, un-chunked decode, then trim the extra 16 patches. + # SAME-L has no even constraint and takes this path only for symmetry. + latents_even = mx.concatenate([latents_fp32, latents_fp32[..., -1:]], axis=-1) + patches = decoder(latents_even)[..., : T_lat * 16] + decode_mode = "un-chunked (odd-pad→trim)" mx.eval(patches) _stage_peak_b('Decode') sub(f"decode {decode_mode} → {(time.time()-t0)*1000:.0f} ms patches {patches.shape}")