diff --git a/optimized/tflite/.gitignore b/optimized/tflite/.gitignore new file mode 100644 index 0000000..14e4dcf --- /dev/null +++ b/optimized/tflite/.gitignore @@ -0,0 +1,14 @@ +# Model weights — pulled from HuggingFace, never committed here. +# (The SentencePiece tokenizer at models/tokenizer.model IS committed — it's +# small and the .tflite T5Gemma is encoder-only.) +*.tflite + +# Generated audio — anywhere on the tree, including output/. +*.wav + +# Python virtualenv created by install.sh +.venv/ + +# Python bytecode caches +__pycache__/ +*.pyc diff --git a/optimized/tflite/README.md b/optimized/tflite/README.md new file mode 100644 index 0000000..122bf7d --- /dev/null +++ b/optimized/tflite/README.md @@ -0,0 +1,240 @@ +# sa3_tflite — Stable Audio 3 on CPU via LiteRT / TFLite + +Portable CPU inference for **Stable Audio 3** — the LiteRT/TFLite sibling of the +[MLX](../mlx) (Apple Silicon) and [TensorRT](../tensorRT) (NVIDIA GPU) releases. +No PyTorch, transformers, or stable-audio-tools at runtime — just `ai_edge_litert` +(LiteRT) driving fully self-contained `.tflite` graphs through the XNNPACK CPU +delegate. Runs anywhere LiteRT runs: **macOS / Linux, x86 / ARM**. + +## Quick Install + +One line on a fresh machine — installs everything and plays back ~30 seconds of +"Impending tribal, epic orchestral buildup": + +```bash +curl -LsSf https://raw.githubusercontent.com/Stability-AI/stable-audio-3/main/optimized/tflite/bootstrap.sh | bash +``` + +Already cloned the repo? Run from inside `optimized/tflite/`: + +```bash +./install.sh # one-time setup +./sa3 --prompt "Impending tribal, epic orchestral buildup" --play # generates + plays +``` + +## Three models, four modes + +| `--dit` | model | best for | +|------------|--------------------|--------------------------------| +| `sm-music` | sa3-sm-music (50 M block) | fast music generation | +| `sm-sfx` | sa3-sm-sfx (50 M block) | sound effects | +| `medium` | sa3-medium-ARC (1.4 B) | higher-quality music, slower | + +| mode | flags | example | +|------------------|-----------------------------------------------|----------------------------------| +| text-to-audio | `--prompt P` | new clip from a description | +| audio-to-audio | `--prompt P --init-audio IN.wav --init-noise-level σ` | variation of an existing clip | +| inpainting | `--prompt P --init-audio IN.wav --inpaint-range "S,E"` | regenerate one section, keep rest | +| CFG + negative | `--cfg 3.0 --negative-prompt P_NEG` | steer toward / away from prompts | + +``` +prompt ─▶ T5Gemma encoder ─▶ DiT pingpong sampler ─▶ SAME-S/L decoder ─▶ WAV + ▲ + optional: encoder + init audio (audio-to-audio / inpaint) +``` + +## Install + +```bash +./install.sh +``` + +`install.sh` is uv-based. On a fresh machine it will: + +1. Install [uv](https://github.com/astral-sh/uv) via the official curl + installer if it's missing (prompts y/N; `-y` skips the prompt). +2. Create a project-local `.venv/` with managed Python 3.11. +3. `uv pip install` the runtime deps into the venv (much faster than pip). +4. Ask which DiT bundles to download from HuggingFace + (`stabilityai/stable-audio-3-optimized`). Each pick pulls its matching + audio codec; T5Gemma (the shared text encoder) is downloaded once. + Already-present weights are skipped. + +End-to-end on a fresh machine: **~10 seconds** + weight downloads. + +> Don't want to pre-pick bundles? Skip install entirely and just run +> `./sa3 --prompt …` — any missing model file is downloaded from HF on +> first use and symlinked into `models/tflite/` from the HuggingFace cache. + +Portable CPU (no GPU required). Python 3.9+. `./install.sh --python 3.12` to +pin a different Python. + +## Run + +`./sa3` is a thin shell wrapper around `.venv/bin/python scripts/sa3_tflite.py +"$@"` that prompts to run `./install.sh` if uv or `.venv/` isn't set up. + +```bash +# Text-to-audio +./sa3 --prompt "lofi house loop" --dit sm-music --decoder same-s --out lofi.wav + +# Sound effects +./sa3 --prompt "footsteps on gravel" --dit sm-sfx --decoder same-s --out steps.wav + +# Higher-quality music (medium DiT, chunked SAME-L decode) +./sa3 --prompt "A beautiful piano arpeggio grows into a cinematic climax" \ + --dit medium --decoder same-l --seconds 30 --out piano.wav + +# Audio-to-audio variation (σmax 0.4-0.8 typical) +./sa3 --prompt "jazz fusion with electric piano" --dit sm-music --decoder same-s \ + --init-audio funk.wav --init-noise-level 0.7 --out funk_jazz.wav + +# Inpaint seconds 4-7 +./sa3 --prompt "explosive drum break" --dit sm-music --decoder same-s \ + --init-audio funk.wav --inpaint-range "4,7" --out funk_drums.wav + +# CFG + negative prompt +./sa3 --prompt "ambient drone" --cfg 3.0 --negative-prompt "drums, vocals" \ + --dit sm-music --decoder same-s --out drone.wav + +# Generate + play immediately (afplay; Ctrl-C stops both) +./sa3 --prompt "rainforest" --dit sm-sfx --decoder same-s --play + +# All options + categorised examples +./sa3 --help +``` + +Omit `--dit` / `--decoder` for an interactive arrow-key picker. Omit +`--prompt` for a stdin prompt. Relative `--out` paths land in `output/` +(auto-created); absolute paths are honoured as-is. The output path is +printed prominently as a `▸ saved` line at the end of each run. + +Use `--threads` to control the XNNPACK CPU thread count (default 8). + +### Without the wrapper + +```bash +.venv/bin/python scripts/sa3_tflite.py --prompt "..." --dit medium --decoder same-l +# or, after `source .venv/bin/activate`: +python scripts/sa3_tflite.py --prompt "..." --dit medium --decoder same-l +``` + +## Speed & memory + +This is a **CPU** path — it trades the GPU releases' speed for portability. The +small models comfortably beat realtime on a modern laptop CPU; `medium` is slower +(its DiT is ~5.8 GB fp32 and it chunk-decodes SAME-L). Use ≥ ~20 s clips: very +short clips have too few latent tokens for the sampler to settle into a coherent +loop. Throughput scales with `--threads` up to your physical core count (4–8 is +the usual sweet spot; more threads on a short model adds overhead). + +For sub-realtime latency on a supported device, prefer the GPU siblings: +[MLX](../mlx) on Apple Silicon, [TensorRT](../tensorRT) on NVIDIA. + +## Flag reference + +| Flag | Default | Notes | +|-----------------------|----------|-----------------------------------------------------------------------| +| `--prompt` | (asks) | Text prompt; empty string = unconditional | +| `--negative-prompt` | — | CFG uncond branch; only used when `--cfg ≠ 1.0` | +| `--dit` | (asks) | `sm-music`, `sm-sfx`, or `medium` | +| `--decoder` | (asks) | `same-s` (pairs with sm-*) or `same-l` (pairs with medium) | +| `--seconds` | 30 | Output length (use ≥ ~20 s) | +| `--steps` | 8 | Pingpong sampler steps; 1 = single forward (fastest), 8 = sweet spot | +| `--seed` | random | Set for reproducibility; the chosen seed is printed at the end | +| `--cfg` | 1.0 | Guidance scale; 1.0 = off, >1 toward prompt, <1 toward uncond. ≠1 runs cond+uncond each step | +| `--apg` | 1.0 | Adaptive Projected Guidance; only matters when `--cfg ≠ 1` | +| `--cfg-batched` | on | When `--cfg ≠ 1`, run cond+uncond as one batch=2 invoke on the variable-batch DiT (~7–29% faster on Apple-Silicon AMX). `--no-cfg-batched` → sequential batch=1 dual-pass. Bit-identical | +| `--init-audio` | — | WAV (any format via ffmpeg) input for audio-to-audio / inpaint | +| `--init-noise-level` | 1.0 | σmax; 0.4–0.8 typical for variation, 1.0 = full regen, >1 = overshoot | +| `--inpaint-range` | — | `START,END` seconds; regenerate that span, keep the rest | +| `--threads` | 8 | XNNPACK CPU threads (all TFLite models run on CPU) | +| `--free-models` | on | Free each model after its last use; `--no-free-models` keeps them resident | +| `--out` / `-o` | (auto) | Relative → `output/`; absolute → as-is. 16-bit PCM stereo @ 44.1 kHz, trimmed to exactly `--seconds` | +| `--play` | off | After writing, play via `afplay` (macOS); Ctrl-C stops both | + +All `.tflite` models are **fp32** except T5Gemma, which is **fp16** (numerically +lossless there). There is no dtype knob: on CPU, int8/fp16 weights buy size, not +speed (XNNPACK dequantizes to fp32 to matmul), and int8 costs quality on the DiT +— so this release ships the fp32 graphs directly. (See "Notes on the design".) + +## Files + +``` +sa3_tflite/ +├── sa3 ← shell wrapper (use this) +├── install.sh ← uv bootstrap (run once) +├── bootstrap.sh ← one-line curl installer +├── README.md +├── requirements.txt ← ai_edge_litert, numpy, sentencepiece, soundfile, huggingface_hub +├── output/ ← default landing zone for generated WAVs +├── scripts/ +│ ├── sa3_tflite.py ← orchestrator CLI (invoked by ./sa3) +│ ├── weights.py ← weights manifest + HF auto-download +│ ├── examples.py ← shared examples block (--help + post-install) +│ └── install.py ← install.sh's Python half (bundle picker) +└── models/ + ├── tokenizer.model ← SentencePiece model, BUNDLED (~4 MB; T5Gemma tflite is encoder-only) + ├── defs/ + │ └── tflite_pipeline.py ← Tokenizer + T5Gemma front-end + pingpong schedule + sampler + WAV + └── tflite/ ← .tflite models (auto-downloaded; ~2.3 GB small, ~9.5 GB medium) + ├── t5gemma/encoder_fp16.tflite 564 MB text encoder (fp16) + ├── sa3-sm-music/dit_fp32.tflite 1.8 GB small music DiT (conditioner baked in) + ├── sa3-sm-sfx/dit_fp32.tflite 1.8 GB small sfx DiT (conditioner baked in) + ├── sa3-m/dit_fp32.tflite 5.8 GB medium DiT (conditioner baked in) + ├── same-s/{enc,dec}_fp32.tflite ~220 MB each shared sm-* codec + └── same-l/{enc,dec}_fp32.tflite ~1.8 GB each medium codec +``` + +The DiT graphs are **baked-I/O**: the conditioner (prompt-padding + seconds +embedder) and the patch/unpatch are compiled into the graph, so the DiT takes the +raw T5Gemma output directly and the decoder emits audio directly. The two small +DiTs share the SAME-S codec (bit-exact between checkpoints), so only one set of +small-codec files is shipped. + +## Auto-download from HuggingFace + +Model files aren't bundled — they're pulled from +`stabilityai/stable-audio-3-optimized` (under `tflite/…`) on first use and +symlinked into `models/tflite/` from the HF cache. No duplication. Anonymous +downloads work but are rate-limited; `huggingface-cli login` with a free read-only +token lifts the cap. The SentencePiece tokenizer (`models/tokenizer.model`, ~4 MB) +is the one weight that IS committed, since the `.tflite` T5Gemma is encoder-only. + +## Notes on the design + +- **Baked-I/O varlen graphs.** Each `.tflite` is a single self-contained graph + with the conditioner and patch/unpatch in-graph, accepting a variable sequence + length — so one file serves any `--seconds`. The DiT is a 6-input graph + (`x, t, t5_hidden, t5_mask, seconds, local_add_cond`); feed raw T5 outputs and + the in-graph conditioner handles prompt-padding + seconds-embedding. +- **fp32 everywhere (except fp16 T5Gemma).** On CPU, quantizing buys size, not + speed — XNNPACK dequantizes int8/fp16 weights to fp32 to matmul, so fp16 is + actually *slower* and int8 gives no speedup. And the DiT will not go int8 at + quality: per-step error compounds over the 8 chaotic sampling steps into a + *different* (still plausible) sample, not a degraded one. So this release ships + the fp32 graphs directly. T5Gemma fp16 is the sole exception — it's numerically + lossless there and halves that file. +- **Monotonic audio-to-audio schedule.** The pingpong schedule applies the LogSNR + shift to the normalized `[1→0]` grid, then scales by σmax, so audio-to-audio + (σmax < 1) stays monotonically decreasing while keeping all N distilled steps. + σmax = 1.0 (text-to-audio) is bit-identical to the classic schedule. +- **SAME-L chunked decode.** The SAME-L decoder's dense sliding-window-attention + mask is O(T²), so long clips are decoded in overlap-8 windows of 64 latent + tokens (the throughput optimum) and stitched. SAME-S has a narrow receptive + field and decodes whole. +- **CFG (`--cfg ≠ 1`)** combines a cond and an uncond velocity in denoised space + (optional APG). The canonical DiT is **variable-batch**, so by default cond+uncond + run as **one batch=2 invoke per step** (`--cfg-batched`) — ~7–29% faster on + Apple-Silicon (the AMX matrix unit amortizes the weight loads across both rows). + `--no-cfg-batched` falls back to a sequential batch=1 dual-pass (like the TensorRT + release, whose engine is static-batch=1); the two are bit-identical. + +## License & attribution + +Model weights derived from Stability AI's Stable Audio 3 checkpoints. +T5Gemma text encoder from Google. + +Use of the Stable Audio 3 weights is governed by the **Stability AI +Community License**. Please refer to the full terms at +. diff --git a/optimized/tflite/bootstrap.sh b/optimized/tflite/bootstrap.sh new file mode 100755 index 0000000..a7bdeeb --- /dev/null +++ b/optimized/tflite/bootstrap.sh @@ -0,0 +1,156 @@ +#!/usr/bin/env bash +# +# sa3_tflite bootstrap — Stable Audio 3 inference on CPU (LiteRT/TFLite) in one command. +# +# Hosted at: +# https://raw.githubusercontent.com/Stability-AI/stable-audio-3/main/optimized/tflite/bootstrap.sh +# +# Usage: +# curl -LsSf https://raw.githubusercontent.com/Stability-AI/stable-audio-3/main/optimized/tflite/bootstrap.sh | bash +# curl -LsSf https://raw.githubusercontent.com/Stability-AI/stable-audio-3/main/optimized/tflite/bootstrap.sh | bash -s -- --prompt "Death Metal" --dit medium --decoder same-l +# +# Default demo prompt is "Impending tribal, epic orchestral buildup". +# +# What it does: +# 1. Verifies curl + tar are present (portable CPU stack — runs on macOS/Linux, x86/ARM). +# 2. Fetches the project: +# - If git is installed → `git clone --depth=1` into ./stable-audio-3/, +# then cd into optimized/tflite/ (real repo; pullable, modifiable). +# - If not → tarball pull via curl + tar, extracting only optimized/tflite/ +# into ./sa3_tflite/ (no git needed). +# 3. Runs ./install.sh -y inside it (uv + Python 3.11 + venv + weight downloads). +# 4. Runs ./sa3 with whatever args you passed (default: "Impending tribal, epic orchestral buildup" demo + --play). +# +set -euo pipefail + +# uv's curl installer drops the binary at $XDG_BIN_HOME (~/.local/bin by default) +# and updates the user's shell profile — but that profile only takes effect in +# *new* shells. We pre-emptively put both locations on PATH so the just-installed +# uv (and anything else from this run) is findable in the current process tree. +export PATH="${XDG_BIN_HOME:-$HOME/.local/bin}:$HOME/.local/bin:$PATH" + +REPO_OWNER="Stability-AI" +REPO_NAME="stable-audio-3" +BRANCH="main" +SUBDIR_IN_REPO="optimized/tflite" +LOCAL_DIR="sa3_tflite" +DEFAULT_ARGS=(--prompt "Impending tribal, epic orchestral buildup" --dit sm-music --decoder same-s --seconds 30 --play) + +TAR_URL="https://github.com/$REPO_OWNER/$REPO_NAME/archive/refs/heads/$BRANCH.tar.gz" +TAR_INNER="$REPO_NAME-$BRANCH/$SUBDIR_IN_REPO" + +# ── colours ───────────────────────────────────────────────────────────────── +if [[ -t 1 ]]; then + BOLD=$'\033[1m'; CYAN=$'\033[1;36m'; RED=$'\033[1;31m' + YELLOW=$'\033[1;33m'; GREEN=$'\033[1;32m'; DIM=$'\033[2m'; RESET=$'\033[0m' +else + BOLD=""; CYAN=""; RED=""; YELLOW=""; GREEN=""; DIM=""; RESET="" +fi +step() { printf '\n%s→ %s%s\n' "$CYAN" "$1" "$RESET"; } +fail() { printf '\n%serror%s: %s\n' "$RED" "$RESET" "$1" >&2; exit 1; } +ok() { printf ' %s✓%s %s\n' "$GREEN" "$RESET" "$1"; } +warn() { printf '%swarning%s: %s\n' "$YELLOW" "$RESET" "$1" >&2; } + +# ── 1. platform note (portable — no gate) ─────────────────────────────────── +OS="$(uname -s)"; ARCH="$(uname -m)" +ok "platform: $OS/$ARCH (LiteRT/TFLite is CPU-portable)" + +# ── 2. preflight: curl + tar (preinstalled on macOS/most Linux) ───────────── +for tool in curl tar; do + command -v "$tool" >/dev/null 2>&1 || \ + fail "$tool not found on PATH. Install it and re-run." +done +ok "curl + tar present" + +# ── 3. fetch the project ──────────────────────────────────────────────────── +# Prefer `git clone` if git is on the machine — the user gets a real repo +# they can pull updates from / navigate sibling subdirs in (mlx/, tensorRT/). +# Falls back to a tarball pull (curl + tar) if git is missing. + +if command -v git >/dev/null 2>&1; then + GIT_DIR="$REPO_NAME" + WORK_DIR="$GIT_DIR/$SUBDIR_IN_REPO" + + if [[ -d "$GIT_DIR/.git" ]]; then + step "Reusing existing ./$GIT_DIR (git pull --ff-only)" + git -C "$GIT_DIR" pull --ff-only + elif [[ -e "$GIT_DIR" ]]; then + fail "./$GIT_DIR exists but isn't a git repo — remove or rename it." + else + step "git clone https://github.com/$REPO_OWNER/$REPO_NAME → ./$GIT_DIR" + git clone --depth=1 "https://github.com/$REPO_OWNER/$REPO_NAME" "$GIT_DIR" + fi + + [[ -d "$WORK_DIR" ]] || \ + fail "Expected '$SUBDIR_IN_REPO' inside the repo but didn't find it." + ok "ready at ./$WORK_DIR" +else + # No git — pull a tarball and extract only optimized/tflite/. + WORK_DIR="$LOCAL_DIR" + + if [[ -d "$LOCAL_DIR" && -x "$LOCAL_DIR/install.sh" ]]; then + step "Reusing existing $LOCAL_DIR/ (delete it to re-download)" + else + if [[ -e "$LOCAL_DIR" ]]; then + fail "./$LOCAL_DIR exists but doesn't look like a sa3_tflite checkout — remove or rename it." + fi + step "git not installed — downloading $REPO_OWNER/$REPO_NAME ($BRANCH) tarball → ./$LOCAL_DIR" + + TMP_TAR="$(mktemp -t sa3_repo.XXXXXX).tar.gz" + TMP_EXTRACT="$(mktemp -d -t sa3_extract.XXXXXX)" + trap 'rm -rf "$TMP_TAR" "$TMP_EXTRACT"' EXIT + + # --progress-bar writes to stderr; -f makes 404/5xx a real curl error + curl -fL --progress-bar "$TAR_URL" -o "$TMP_TAR" + + # BSD tar (macOS) / GNU tar both extract only paths matching the pattern. + tar -xz -f "$TMP_TAR" -C "$TMP_EXTRACT" "$TAR_INNER" + + SRC="$TMP_EXTRACT/$TAR_INNER" + [[ -d "$SRC" ]] || fail "Expected '$TAR_INNER' inside the tarball but didn't find it." + mv "$SRC" "$LOCAL_DIR" + ok "extracted $(find "$LOCAL_DIR" -type f | wc -l | tr -d ' ') files to ./$LOCAL_DIR" + fi +fi + +# ── 4. install ────────────────────────────────────────────────────────────── +cd "$WORK_DIR" +[[ -x ./install.sh ]] || fail "install.sh missing or not executable in ./$WORK_DIR." +step "Running ./install.sh -y" +./install.sh -y + +# ── 5. inference ──────────────────────────────────────────────────────────── +# Run as a subprocess (not `exec`) so we can drop the user into an +# interactive shell here when it finishes. +if [[ $# -gt 0 ]]; then + step "Running ./sa3 $*" + ./sa3 "$@" || true +else + step "Running demo: ./sa3 ${DEFAULT_ARGS[*]}" + printf ' %s(pass your own args via: curl -LsSf https://raw.githubusercontent.com/Stability-AI/stable-audio-3/main/optimized/tflite/bootstrap.sh | bash -s -- --prompt "..." ...)%s\n' "$DIM" "$RESET" + ./sa3 "${DEFAULT_ARGS[@]}" || true +fi + +# ── 6. drop user into an interactive shell sitting in the project dir ────── +# A subprocess can't change its parent shell's cwd — but we CAN replace +# ourselves with a fresh interactive shell, leaving the user at a prompt +# inside ./$WORK_DIR. `exit` (or Ctrl-D) returns them to their original +# shell, at their original cwd, just like a normal subshell. +# +# `< /dev/tty` is essential when bootstrap.sh was invoked via curl|bash: +# stdin at this point is the (closed) pipe; an interactive shell needs a +# real terminal. /dev/tty always refers to the user's controlling TTY. + +if [[ ! -e /dev/tty ]]; then + # Headless / scripted invocation — skip the shell drop. + exit 0 +fi + +USER_SHELL="${SHELL:-/bin/bash}" +printf '\n%s━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━%s\n' "$BOLD" "$RESET" +printf ' %s✓ you are now in%s %s%s%s\n' "$GREEN" "$RESET" "$BOLD" "$(pwd)" "$RESET" +printf ' %stype %s./sa3 --help%s for options, or %sexit%s to return to your previous shell%s\n' \ + "$DIM" "$RESET$BOLD" "$RESET$DIM" "$RESET$BOLD" "$RESET$DIM" "$RESET" +printf '%s━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━%s\n\n' "$BOLD" "$RESET" + +exec "$USER_SHELL" -i < /dev/tty diff --git a/optimized/tflite/install.sh b/optimized/tflite/install.sh new file mode 100755 index 0000000..710bfb2 --- /dev/null +++ b/optimized/tflite/install.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +# +# SA3 TFLite installer — uv-based. +# +# Creates a project-local .venv/ with the right Python and runtime deps, +# then hands off to install.py for the interactive weight-download prompt. +# +# Portable CPU stack (LiteRT / TFLite via ai_edge_litert) — runs on +# macOS/Linux, x86/ARM. No Apple-Silicon requirement. +# +# Usage: +# ./install.sh # auto-detect uv, prompt to install if missing +# ./install.sh -y # assume yes to "install uv?" prompt +# ./install.sh --python VER # pin a specific Python (default: 3.11) +# +# After install: +# source .venv/bin/activate +# python scripts/sa3_tflite.py --prompt "lofi house" --dit sm-music --decoder same-s +# # or, without activating: +# .venv/bin/python scripts/sa3_tflite.py --prompt "lofi house" --dit sm-music +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VENV_DIR="$SCRIPT_DIR/.venv" +PY_VERSION_DEFAULT="3.11" + +# ── colours ───────────────────────────────────────────────────────────────── +if [[ -t 1 ]]; then + BOLD=$'\033[1m'; CYAN=$'\033[1;36m'; RED=$'\033[1;31m' + YELLOW=$'\033[1;33m'; GREEN=$'\033[1;32m'; DIM=$'\033[2m'; RESET=$'\033[0m' +else + BOLD=""; CYAN=""; RED=""; YELLOW=""; GREEN=""; DIM=""; RESET="" +fi +step() { printf '\n%s→ %s%s\n' "$CYAN" "$1" "$RESET"; } +fail() { printf '%serror%s: %s\n' "$RED" "$RESET" "$1" >&2; } +warn() { printf '%swarning%s: %s\n' "$YELLOW" "$RESET" "$1" >&2; } +ok() { printf ' %s✓%s %s\n' "$GREEN" "$RESET" "$1"; } + +# ── arg parsing ───────────────────────────────────────────────────────────── +ASSUME_YES=0 +PY_VERSION="$PY_VERSION_DEFAULT" +EXTRA_ARGS=() +while [[ $# -gt 0 ]]; do + case "$1" in + -y|--yes) ASSUME_YES=1; shift ;; + --python) PY_VERSION="$2"; shift 2 ;; + --python=*) PY_VERSION="${1#--python=}"; shift ;; + -h|--help) + sed -n '2,/^set -euo/p' "$0" | sed -e '$d' -e 's/^# \{0,1\}//' + exit 0 ;; + *) EXTRA_ARGS+=("$1"); shift ;; + esac +done + +# ── platform note (portable — CPU only, no gate) ──────────────────────────── +OS="$(uname -s)"; ARCH="$(uname -m)" +ok "platform: $OS/$ARCH (LiteRT/TFLite is CPU-portable; XNNPACK delegate)" + +# ── ensure uv is installed ────────────────────────────────────────────────── +ensure_uv() { + if command -v uv >/dev/null 2>&1; then + ok "uv $(uv --version 2>/dev/null | awk '{print $2}') already installed at $(command -v uv)" + return 0 + fi + + step "uv not found — uv is required (much faster than pip, also manages Python versions)" + if [[ "$ASSUME_YES" -ne 1 ]]; then + printf ' Install uv now via the official installer? (curl + sh) %s[Y/n]%s ' "$DIM" "$RESET" + read -r REPLY < /dev/tty + case "$REPLY" in + ""|y|Y|yes|YES) ;; + *) + fail "install aborted — install uv manually then re-run:" + printf ' curl -LsSf https://astral.sh/uv/install.sh | sh\n' >&2 + printf ' or: brew install uv\n' >&2 + exit 1 ;; + esac + fi + + step "Installing uv (curl -LsSf https://astral.sh/uv/install.sh | sh)" + if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then + fail "uv installer failed. Try a manual install:" + printf ' brew install uv\n' >&2 + exit 1 + fi + # The installer drops uv at ~/.local/bin/uv (or $XDG_BIN_HOME) + export PATH="$HOME/.local/bin:${XDG_BIN_HOME:-}:$PATH" + if ! command -v uv >/dev/null 2>&1; then + fail "uv was installed but isn't on PATH. Add ~/.local/bin to PATH, restart your shell, and re-run install.sh." + exit 1 + fi + ok "uv $(uv --version | awk '{print $2}') installed" +} + +ensure_uv + +# ── create venv (uv auto-installs the requested Python if missing) ────────── +step "Creating virtual environment at .venv/ with Python $PY_VERSION" +if [[ -d "$VENV_DIR" ]]; then + EXISTING_PY=$("$VENV_DIR/bin/python" -c 'import sys; print(".".join(map(str, sys.version_info[:2])))' 2>/dev/null || echo "unknown") + if [[ "$EXISTING_PY" == "$PY_VERSION"* ]]; then + ok "reusing existing .venv (Python $EXISTING_PY)" + else + warn "existing .venv uses Python $EXISTING_PY (wanted $PY_VERSION) — recreating" + rm -rf "$VENV_DIR" + uv venv --seed --python "$PY_VERSION" "$VENV_DIR" + fi +else + uv venv --seed --python "$PY_VERSION" "$VENV_DIR" +fi + +# ── install runtime deps ──────────────────────────────────────────────────── +step "Installing dependencies (uv pip install -r requirements.txt)" +VIRTUAL_ENV="$VENV_DIR" uv pip install -r "$SCRIPT_DIR/requirements.txt" + +# ── hand off to install.py ────────────────────────────────────────────────── +# Any unrecognized args we collected (e.g. --download medium,sm-music) get +# forwarded to install.py via EXTRA_ARGS. +step "Bundle picker" +INSTALL_SKIP_PIP=1 exec "$VENV_DIR/bin/python" "$SCRIPT_DIR/scripts/install.py" \ + "${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}" diff --git a/optimized/tflite/models/__init__.py b/optimized/tflite/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/optimized/tflite/models/defs/__init__.py b/optimized/tflite/models/defs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/optimized/tflite/models/defs/tflite_pipeline.py b/optimized/tflite/models/defs/tflite_pipeline.py new file mode 100644 index 0000000..4b117f5 --- /dev/null +++ b/optimized/tflite/models/defs/tflite_pipeline.py @@ -0,0 +1,143 @@ +"""SA3 text-to-audio inference pipeline for TFLite / LiteRT (CPU) — trimmed to just +what the baked-I/O CLI (scripts/sa3_tflite.py) needs. + +Pipeline: prompt → (SentencePiece) → T5Gemma TFLite → conditioning (baked in-graph) + → DiT pingpong (8 steps, rectified-flow, host-side numpy) → decoder → WAV + +The baked-I/O DiT bakes the conditioner (seconds embedder + prompt padding) and +patch/unpatch into the graph, so this module only needs: the tokenizer, the T5Gemma +front-end, the pingpong schedule, the noise maker, the host-side sampler, and a WAV +writer. The research-only backends (MLX DiT A/B, per-precision model dicts, the numpy +Conditioner / unpatch, encode_prompt / decode_with helpers) live in the speed-metal +repo's tflite_pipeline.py and are intentionally dropped here. +""" +from __future__ import annotations +import wave +from pathlib import Path +from typing import Callable +import numpy as np + +# This file lives in /models/defs/. The bundled SentencePiece model sits at +# /models/tokenizer.model — resolve it relative to this file so the tokenizer +# works regardless of the caller's cwd. +DEFS_DIR = Path(__file__).resolve().parent +MODELS_DIR = DEFS_DIR.parent +TOKENIZER_MODEL = MODELS_DIR / "tokenizer.model" + +SAMPLE_RATE = 44100 +SAMPLES_PER_LATENT = 4096 # decoder upsample (256 patch × 16) +COND_TOKENS = 256 # T5Gemma seq len + + +# ───────────────────────── WAV ───────────────────────── +def save_wav(path, audio): # audio: (2, T) float32 in [-1,1] + audio = np.clip(np.asarray(audio, np.float32), -1, 1) + pcm = (audio * 32767.0).astype(np.int16).T # (T, 2) interleaved + with wave.open(str(path), "wb") as w: + w.setnchannels(audio.shape[0]); w.setsampwidth(2); w.setframerate(SAMPLE_RATE) + w.writeframes(pcm.tobytes()) + + +# ───────────────────────── Tokenizer (SentencePiece, bundled) ───────────────────────── +class Tokenizer: + """SentencePiece front-end. Loads the model from the BUNDLED models/tokenizer.model + (the .tflite T5Gemma is encoder-only and carries no tokenizer). Encode(prompt)[:256], + pad=0 — reproduces the SA3 groundtruth input_ids exactly.""" + def __init__(self, model_path=TOKENIZER_MODEL): + import sentencepiece as spm + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError( + f"tokenizer model not found at {model_path}. It should ship with this " + f"release under models/tokenizer.model." + ) + self.sp = spm.SentencePieceProcessor() + self.sp.LoadFromFile(str(model_path)) + self.pad = 0 + + def __call__(self, prompt: str, max_len: int = COND_TOKENS): + ids = np.full((1, max_len), self.pad, np.int32) + mask = np.zeros((1, max_len), np.int32) + toks = self.sp.Encode(prompt)[:max_len] + ids[0, :len(toks)] = toks + mask[0, :len(toks)] = 1 + return ids, mask + + +# ───────────────────────── Pingpong schedule (numpy port) ───────────────────────── +def _logsnr_shift(t, anchor=-6.2, end=2.0): + t = t.astype(np.float32) + logsnr = end - t * (end - anchor) + out = 1.0 / (1.0 + np.exp(logsnr)) # sigmoid(-logsnr) + out = np.where(t <= 0, 0.0, out) + out = np.where(t >= 1, 1.0, out) + return out.astype(np.float32) + + +def build_pingpong_schedule(steps, sigma_max=1.0): + # LogSNR pingpong schedule of (steps+1) points from sigma_max down to 0: warp the normalized + # [1→0] grid through the logSNR shift, then scale by sigma_max. Decreases monotonically + # sigma_max→0 across all steps, with the first step exactly at sigma_max to match the init + # mix. sigma_max=1.0 is plain text-to-audio; sigma_max<1.0 is the audio-to-audio start. + t = _logsnr_shift(np.linspace(1.0, 0.0, steps + 1).astype(np.float32)) * np.float32(sigma_max) + t[0] = np.float32(sigma_max) + return t + + +# ───────────────────────── TFLite helper ───────────────────────── +def _interp(path, threads=8): + from ai_edge_litert import interpreter as tfl + it = tfl.Interpreter(model_path=str(path), num_threads=threads) + it.allocate_tensors() + return it + + +class T5GemmaTFLite: + """T5Gemma encoder (fixed 256 text tokens). ids/mask int32 [1,256] → last_hidden [1,256,768] fp32.""" + def __init__(self, path, threads=8): + self.it = _interp(path, threads) + det = sorted(self.it.get_input_details(), key=lambda d: d["name"]) # args_0=ids, args_1=mask + self.i_ids, self.i_mask = det[0]["index"], det[1]["index"] + self.out = self.it.get_output_details()[0]["index"] + + def __call__(self, ids, mask): + self.it.set_tensor(self.i_ids, ids.astype(np.int32)) + self.it.set_tensor(self.i_mask, mask.astype(np.int32)) + self.it.invoke() + return self.it.get_tensor(self.out).copy() # (1,256,768) fp32 + + +# ───────────────────────── Sampler (shared, numpy) ───────────────────────── +def make_noise(T_lat, steps, seed): + rng = np.random.default_rng(seed) + x0 = rng.standard_normal((1, 256, T_lat)).astype(np.float32) + step_noise = [rng.standard_normal((1, 256, T_lat)).astype(np.float32) for _ in range(steps)] + return x0, step_noise + + +def sample(dit_forward: Callable, x0, step_noise, sigmas, cross, gcond, on_step=None, + paste_back=None): + """Rectified-flow pingpong. dit_forward(x,t,cross,gcond)->v (the velocity; CFG, + if any, is folded into dit_forward's return so this stays cfg-agnostic — matches + sa3_trt_core.sample_flow_pingpong, where model_fn returns cfg_v). cross/gcond are + passed through to dit_forward (the baked DiT ignores them — conditioning is in-graph). + + paste_back=(init_lat, keep_mask): after every step, restore the preserved region + (keep_mask 1=keep init, 0=regenerate) so inpainting leaves untouched regions + bit-exact. Mirrors sample_flow_pingpong's paste_back (applied post-renoise).""" + steps = len(sigmas) - 1 + x = x0.copy() + for i in range(steps): + tc, tn = float(sigmas[i]), float(sigmas[i + 1]) + v = dit_forward(x, tc, cross, gcond) + denoised = x - tc * v + if i < steps - 1 and tn > 0: + x = (1 - tn) * denoised + tn * step_noise[i] + else: + x = denoised + if paste_back is not None: + init_lat, keep_mask = paste_back + x = init_lat * keep_mask + x * (1.0 - keep_mask) + if on_step: + on_step(i + 1, steps) + return x diff --git a/optimized/tflite/models/tflite/.gitkeep b/optimized/tflite/models/tflite/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/optimized/tflite/models/tokenizer.model b/optimized/tflite/models/tokenizer.model new file mode 100644 index 0000000..14a2422 Binary files /dev/null and b/optimized/tflite/models/tokenizer.model differ diff --git a/optimized/tflite/output/.gitkeep b/optimized/tflite/output/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/optimized/tflite/requirements.txt b/optimized/tflite/requirements.txt new file mode 100644 index 0000000..c02db48 --- /dev/null +++ b/optimized/tflite/requirements.txt @@ -0,0 +1,5 @@ +ai_edge_litert>=1.0 +numpy>=1.24 +sentencepiece>=0.2 +soundfile>=0.12 +huggingface_hub>=0.20 diff --git a/optimized/tflite/sa3 b/optimized/tflite/sa3 new file mode 100755 index 0000000..c57c3cc --- /dev/null +++ b/optimized/tflite/sa3 @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# +# Wrapper around scripts/sa3_tflite.py — runs it via the project's .venv so no +# activation is needed. +# +# If setup is incomplete (no uv, no .venv) you'll be prompted to run +# install.sh; saying yes runs the installer and then resumes the command. +# +# Usage: ./sa3 --prompt "lofi house" --dit sm-music --decoder same-s --out a.wav +# +set -euo pipefail + +# Pre-emptively add the uv installer's default bin dir to PATH. The user's +# shell profile only updates new shells; this lets us find uv in the current +# process tree (matters right after a fresh `curl|sh` install). +export PATH="${XDG_BIN_HOME:-$HOME/.local/bin}:$HOME/.local/bin:$PATH" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ -t 1 ]]; then + RED=$'\033[1;31m'; YELLOW=$'\033[1;33m'; CYAN=$'\033[1;36m' + BOLD=$'\033[1m'; DIM=$'\033[2m'; RESET=$'\033[0m' +else + RED=""; YELLOW=""; CYAN=""; BOLD=""; DIM=""; RESET="" +fi + +fail() { printf '\n%serror%s: %s\n' "$RED" "$RESET" "$1" >&2; exit 1; } + +# Prompts the user to run install.sh. Honors `-y` / non-interactive stdin. +# Resumes (returns 0) on success so the caller can re-attempt the run; +# exits 1 if the user declines or install fails. +prompt_install() { + local reason="$1" + printf '\n%s%s%s\n' "$YELLOW" "$reason" "$RESET" >&2 + + if [[ ! -t 0 ]]; then + printf '%serror%s: stdin is not a TTY — re-run %s./install.sh%s manually.\n' \ + "$RED" "$RESET" "$BOLD" "$RESET" >&2 + exit 1 + fi + + printf ' %sRun ./install.sh now to set up?%s %s[Y/n]%s ' \ + "$BOLD" "$RESET" "$DIM" "$RESET" + read -r REPLY < /dev/tty + case "$REPLY" in + ""|y|Y|yes|YES) ;; + *) fail "install declined — re-run ./install.sh manually when ready." ;; + esac + + printf '\n%s→ running ./install.sh%s\n' "$CYAN" "$RESET" >&2 + if ! "$SCRIPT_DIR/install.sh"; then + fail "installer failed — see output above." + fi +} + +# ── infrastructure checks ─────────────────────────────────────────────────── +if [[ ! -f "$SCRIPT_DIR/scripts/sa3_tflite.py" ]]; then + fail "scripts/sa3_tflite.py not found — repo files are missing or moved." +fi + +needs_install=0 +missing=() +if ! command -v uv >/dev/null 2>&1; then + needs_install=1; missing+=("uv") +fi +if [[ ! -d "$SCRIPT_DIR/.venv" ]]; then + needs_install=1; missing+=(".venv/") +fi + +if [[ $needs_install -eq 1 ]]; then + prompt_install "Setup incomplete — missing: ${missing[*]}" +fi + +# ── hand off to scripts/sa3_tflite.py ─────────────────────────────────────── +# Invoke .venv/bin/python directly rather than via `uv run`. `uv run` walks up +# looking for a pyproject.toml and (in newer uv versions) with --no-project +# creates an ephemeral env with no deps, ignoring our .venv. Direct invocation +# is shorter, faster, and side-steps both bugs. +cd "$SCRIPT_DIR" +exec "$SCRIPT_DIR/.venv/bin/python" scripts/sa3_tflite.py "$@" diff --git a/optimized/tflite/scripts/examples.py b/optimized/tflite/scripts/examples.py new file mode 100644 index 0000000..d6a24a2 --- /dev/null +++ b/optimized/tflite/scripts/examples.py @@ -0,0 +1,137 @@ +"""Shared, colored "Try these commands" block. + +Used by: +- scripts/install.py — printed at the end of install.sh +- scripts/sa3_tflite.py — appended to `./sa3 --help` + +Examples render with the user's `./sa3` wrapper when present (falling back to +`.venv/bin/python` or bare `python`), and only show entries for DiT bundles that +the user actually has installed. Bundles not yet on disk get a "re-run install.sh, +or just use them — weights auto-download" note at the bottom. +""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +# This file lives in /scripts/. SCRIPT_DIR points at the project +# root (where ./sa3, ./install.sh, models/, .venv/ live). +SCRIPT_DIR = Path(__file__).resolve().parent.parent + + +# ── ANSI colours (safe no-ops when stdout isn't a TTY) ────────────────────── +def _c(code: str) -> str: + return code if sys.stdout.isatty() else "" + +BOLD = _c("\033[1m") +CYAN = _c("\033[1;36m") +GREEN = _c("\033[1;32m") +YELLOW = _c("\033[1;33m") +DIM = _c("\033[2m") +RESET = _c("\033[0m") + + +def _have(binary: str) -> bool: + from shutil import which + return which(binary) is not None + + +def _py_invocation() -> tuple[str, str, bool]: + """Return (command-to-print, tip-or-empty, is_wrapper). + + Prefer the ./sa3 wrapper if present, else `.venv/bin/python scripts/sa3_tflite.py`, + else fall back to a bare `python scripts/sa3_tflite.py`. + """ + wrapper = SCRIPT_DIR / "sa3" + if wrapper.exists() and os.access(wrapper, os.X_OK): + return "./sa3", "the ./sa3 wrapper uses .venv automatically", True + venv_py = SCRIPT_DIR / ".venv" / "bin" / "python" + if venv_py.exists(): + return str(venv_py.relative_to(SCRIPT_DIR)), "", False + if Path(sys.prefix).resolve() == (SCRIPT_DIR / ".venv").resolve(): + return ".venv/bin/python", "source .venv/bin/activate # to run `python` directly", False + return "python", "", False + + +def print_example_commands(header: str | None = None) -> None: + """Print the categorized example-commands block. + + `header` is the title line inside the rule block. Defaults to a neutral + "Examples:" for --help; install.sh passes "✓ Install complete. …". + """ + from weights import bundle_status + + py, tip, is_wrapper = _py_invocation() + prefix = py if is_wrapper else f"{py} scripts/sa3_tflite.py" + + def hdr(text: str) -> None: + print(f"\n {CYAN}{text}{RESET}") + def cmd(args: str, comment: str = "") -> None: + line = f"{prefix} {args}" + if comment: + print(f" {GREEN}$ {line}{RESET} {DIM}# {comment}{RESET}") + else: + print(f" {GREEN}$ {line}{RESET}") + + if header is None: + header = f"{BOLD}Examples:{RESET}" + + print(f"\n{BOLD}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━{RESET}") + print(f" {header}") + print(f"{BOLD}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━{RESET}") + + if tip: + print(f"\n {DIM}tip: {tip}{RESET}") + + have = {name: bundle_status(name) == (4, 4) for name in ("sm-music", "sm-sfx", "medium")} + + # ── Basic generation ───────────────────────────────────────────── + hdr("🎵 Generate audio from a prompt") + if have["sm-music"]: + cmd('--prompt "lofi house loop, mellow piano" \\\n' + f' --dit sm-music --decoder same-s --seconds 30 --out lofi.wav', + "fast music generation") + if have["sm-sfx"]: + cmd('--prompt "footsteps on gravel, then a door slamming" \\\n' + f' --dit sm-sfx --decoder same-s --seconds 20 --out sfx.wav', + "sound-effect generation") + if have["medium"]: + cmd('--prompt "A beautiful piano arpeggio grows into a cinematic climax" \\\n' + f' --dit medium --decoder same-l --seconds 30 --out piano.wav', + "highest quality (chunked SAME-L decode)") + + # ── Playback ───────────────────────────────────────────────────── + hdr("▶ Play immediately after generation") + one_dit = "sm-music" if have["sm-music"] else ("sm-sfx" if have["sm-sfx"] else "medium") + one_dec = "same-l" if one_dit == "medium" else "same-s" + cmd(f'--prompt "ambient drone" --dit {one_dit} --decoder {one_dec} \\\n' + f' --seconds 20 --out drone.wav --play', + "writes WAV + plays via afplay (Ctrl-C stops both)") + + # ── Audio-to-audio + inpaint ───────────────────────────────────── + hdr("🎚️ Audio-to-audio & inpainting (requires an input WAV)") + cmd(f'--prompt "jazz fusion with electric piano" --dit {one_dit} --decoder {one_dec} \\\n' + f' --init-audio funk.wav --init-noise-level 0.7 --out funk_jazz.wav', + "variation: 0.4-0.8 typical, higher = more change") + cmd(f'--prompt "explosive drum break" --dit {one_dit} --decoder {one_dec} \\\n' + f' --init-audio funk.wav --inpaint-range "4,7" --out funk_drums.wav', + "regenerate seconds 4-7, keep rest") + + # ── CFG & negative prompts ─────────────────────────────────────── + hdr("🎯 Steer with CFG + negative prompts") + cmd(f'--prompt "ambient drone" --cfg 3.0 \\\n' + f' --negative-prompt "drums, vocals, distortion" \\\n' + f' --dit {one_dit} --decoder {one_dec} --out clean_drone.wav', + "cfg > 1.0 toward prompt, neg pushes away (sequential dual-pass)") + + # Bundles not installed (offer the on-demand path) + missing = [name for name, ok in have.items() if not ok] + if missing: + print(f"\n {YELLOW}note:{RESET} bundles not installed: {', '.join(missing)}.") + print(f" Re-run {BOLD}./install.sh{RESET} to pick them up, or just use them in" + f" {BOLD}./sa3{RESET} —") + print(f" weights auto-download from HuggingFace on first use.") + + print(f"\n{BOLD}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━{RESET}\n") diff --git a/optimized/tflite/scripts/install.py b/optimized/tflite/scripts/install.py new file mode 100644 index 0000000..d87ab7e --- /dev/null +++ b/optimized/tflite/scripts/install.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +"""Internal install plumbing — do NOT invoke directly. Use ../install.sh from the project root. + +Called by install.sh after it has set up uv + .venv. Responsibilities: +1. (Optionally) pip-install requirements.txt — skipped when INSTALL_SKIP_PIP=1 + is set (install.sh sets this since uv already handled deps). +2. Ask which DiT bundles to download, then fetch them from HuggingFace. +3. Print example commands the user can copy-paste. + +Unlike the MLX release, this TFLite/LiteRT stack is portable CPU — it runs on +macOS/Linux, x86/ARM. No Apple-Silicon requirement. +""" + +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + +# This file lives in /scripts/. SCRIPT_DIR points at the project +# root (where models/, requirements.txt, .venv/, sa3 wrapper live). +SCRIPT_DIR = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(Path(__file__).resolve().parent)) # `from weights import …` + +MIN_PY = (3, 9) +SKIP_PIP = os.environ.get("INSTALL_SKIP_PIP") == "1" + + +def step(msg: str) -> None: + print(f"\n\033[1;36m→ {msg}\033[0m") + + +def check_environment() -> None: + """Bail out early with a clear message if the interpreter is too old.""" + py = sys.version_info + if (py.major, py.minor) < MIN_PY: + print( + f"\n\033[1;31merror\033[0m: this script is running under Python " + f"{py.major}.{py.minor}, but this stack requires Python " + f"{MIN_PY[0]}.{MIN_PY[1]}+.\n\n" + f"Re-run install.py with a Python {MIN_PY[0]}.{MIN_PY[1]}+ interpreter, e.g.:\n" + f" /path/to/python3.11 install.py\n\n" + f"On macOS: brew install python@3.11\n" + f"On Debian/Ubuntu: apt install python3.11\n", + file=sys.stderr, + ) + sys.exit(1) + # TFLite / LiteRT (ai_edge_litert) is portable across macOS/Linux, x86/ARM — + # no platform gate here (unlike the Metal-backed MLX release). + + +def pip_install_requirements() -> None: + if SKIP_PIP: + step("Dependencies already installed by install.sh (uv) — skipping pip step") + return + step("Installing Python dependencies") + req = SCRIPT_DIR / "requirements.txt" + + # Prefer pip; fall back to uv if pip isn't available in this interpreter + # (e.g. user is running install.py directly inside a uv-created venv + # that wasn't seeded with pip). + pip_available = subprocess.run( + [sys.executable, "-m", "pip", "--version"], + capture_output=True, + ).returncode == 0 + + if pip_available: + cmd = [sys.executable, "-m", "pip", "install", "-r", str(req)] + elif _have("uv"): + cmd = ["uv", "pip", "install", "--python", sys.executable, "-r", str(req)] + else: + print( + f"\n\033[1;31merror\033[0m: neither pip nor uv is available in this " + f"interpreter ({sys.executable}).\n" + f"Install pip with `python -m ensurepip --upgrade`, or use ./install.sh " + f"which sets everything up via uv.", + file=sys.stderr, + ) + sys.exit(1) + print(f" $ {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def _have(binary: str) -> bool: + """True if the given binary is on PATH.""" + from shutil import which + return which(binary) is not None + + +def main() -> None: + import argparse + ap = argparse.ArgumentParser( + description="SA3 TFLite install — non-interactive bundle downloader + post-install help.", + epilog="Normally invoked via ../install.sh, not directly.", + ) + ap.add_argument("--download", default="", + metavar="BUNDLES", + help="Comma-separated list of bundles to pre-download " + "(sm-music, sm-sfx, medium). Without this flag, " + "nothing is downloaded — sa3_tflite.py will fetch any " + "missing weights from HuggingFace on first use.") + cli = ap.parse_args() + + check_environment() + pip_install_requirements() + + # Import after pip install so a fresh checkout works. + from weights import DIT_BUNDLES, SHARED, BUNDLE_SIZES, bundle_status, ensure_local + + step("Current weights status") + for name in DIT_BUNDLES: + present, total = bundle_status(name) + mark = "✓" if present == total else " " + print(f" [{mark}] {name:9s} {present}/{total} files present ({BUNDLE_SIZES[name]})") + + chosen: list[str] = [] + if cli.download.strip(): + chosen = [b.strip() for b in cli.download.split(",") if b.strip()] + unknown = [b for b in chosen if b not in DIT_BUNDLES] + if unknown: + print(f"\n\033[1;31merror\033[0m: unknown bundle(s): {', '.join(unknown)}. " + f"Choices: {', '.join(DIT_BUNDLES)}", file=sys.stderr) + sys.exit(1) + + if chosen: + step(f"Downloading {len(chosen)} bundle(s): {', '.join(chosen)}") + seen: set[str] = set() + for name in chosen: + print(f"\n[{name}]") + items = DIT_BUNDLES[name] + SHARED + for rel, _hf in items: + if rel in seen: + continue + seen.add(rel) + ensure_local(rel) + else: + step("No --download set — weights will auto-download on first ./sa3 use") + print(f" To pre-download instead, pass: {sys.executable.split('/')[-1]} install.py --download sm-music") + print(f" or: ./install.sh --download sm-music,medium") + + from examples import print_example_commands, BOLD, GREEN, RESET + print_example_commands(header=f"{BOLD}{GREEN}✓ Install complete.{RESET} Try these commands:") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nAborted.") + sys.exit(130) diff --git a/optimized/tflite/scripts/sa3_tflite.py b/optimized/tflite/scripts/sa3_tflite.py new file mode 100644 index 0000000..804e645 --- /dev/null +++ b/optimized/tflite/scripts/sa3_tflite.py @@ -0,0 +1,755 @@ +"""SA3 text-to-audio (+ audio-to-audio + inpainting + CFG) on CPU via the BAKED-I/O +varlen TFLite / LiteRT models — the portable-CPU sibling of the MLX and TensorRT +releases. CLI at feature parity with sa3_mlx.py / sa3_trt.py, sharing their flag names. + +Baked I/O (ONNX / TensorRT convention — conditioner + patch/unpatch are IN-GRAPH): + DiT (6-input): x[1,256,L], t[1], t5_hidden[1,256,768], t5_mask[1,256], + seconds_total[1], local_add_cond[1,257,L] -> velocity[1,256,L] + (feed RAW T5 outputs; prompt-padding + seconds-embed happen in-graph) + Encoder (audio-in): audio[1,2,N] -> latents[1,256,N/4096] (patch-encode baked) + Decoder (audio-out): latents[1,256,L] -> audio[1,2,L*4096] (unpatch baked) + T5Gemma: t5gemma/encoder_fp16.tflite (fixed 256 text tokens; tokenizer bundled) + +Modes (identical flags to sa3_mlx / sa3_trt): + text-to-audio --prompt P + audio-to-audio --prompt P --init-audio IN.wav [--init-noise-level σ] + inpainting --prompt P --init-audio IN.wav --inpaint-range START,END + negative CFG --prompt P --cfg N [--negative-prompt P_NEG] [--apg A] + +CFG uses the TRT/ONNX baked-conditioner convention: the uncond branch feeds the negative +prompt's T5 output (or an all-zero hidden+mask, which the in-graph conditioner turns into +learned padding embeddings) — it does NOT zero cross_attn like MLX (our DiT bakes the +conditioner, so there's no raw cross_attn). CFG runs cond+uncond as ONE batch=2 invoke on the +variable-batch DiT (--cfg-batched, default; ~7-29%% faster on Apple-Silicon AMX) or as a +sequential batch=1 dual-pass (--no-cfg-batched). Both are bit-identical. + +All model files auto-download from HuggingFace (stabilityai/stable-audio-3-optimized) +on first use and symlink into models/tflite/ from the HF cache. See scripts/weights.py. +""" +from __future__ import annotations +import argparse, math, os, random, re, subprocess, sys, termios, time, tty, wave +from pathlib import Path +import numpy as np + +REPO = Path(__file__).resolve().parent.parent # project root (scripts/ is one level down) +sys.path.insert(0, str(REPO)) # so `from models.defs.* import` resolves +sys.path.insert(0, str(REPO / "scripts")) # so `from weights import *` resolves + +from models.defs import tflite_pipeline as P # Tokenizer, T5GemmaTFLite, build_pingpong_schedule, make_noise, sample, save_wav +from weights import ensure_local, is_present + +SAMPLE_RATE = 44100 +SAMPLES_PER_LATENT = 4096 +COND_TOKENS = 256 +COND_DIM = 768 +MIN_SIGMA = 0.01 # rf_denoiser is undefined at t≈0 → NaN below this + +# ─── Model manifest (local rel paths; resolved via weights.ensure_local at load) ─── +DIT_REL = { + "sm-music": "models/tflite/sa3-sm-music/dit_fp32.tflite", + "sm-sfx": "models/tflite/sa3-sm-sfx/dit_fp32.tflite", + "medium": "models/tflite/sa3-m/dit_fp32.tflite", +} +DEC_REL = { + "same-s": "models/tflite/same-s/dec_fp32.tflite", + "same-l": "models/tflite/same-l/dec_fp32.tflite", +} +ENC_REL = { + "same-s": "models/tflite/same-s/enc_fp32.tflite", + "same-l": "models/tflite/same-l/enc_fp32.tflite", +} +T5_REL = "models/tflite/t5gemma/encoder_fp16.tflite" +DEFAULT_DECODER = {"sm-music": "same-s", "sm-sfx": "same-s", "medium": "same-l"} + +# SAME-L dense SWA mask is O(S^2): chunk long decodes. SAME-S has a narrow field: whole. +SAMEL_CHUNK = 64 # latent tokens/window — throughput optimum (6.5x RT) vs the O(S^2) dense SWA mask +SAMEL_OVERLAP = 8 # symmetric latent-token context each interior side (SAME-L needs >=8) + + +def _free(): + import gc; gc.collect() + + +# ─── ANSI display (match sa3_mlx.py style) ───────────────────────────────── +_USE_COLOR = sys.stdout.isatty() +_RULE_W = 64 +def _c(code, s): return f"\x1b[{code}m{s}\x1b[0m" if _USE_COLOR else s +def bold(s): return _c("1", s) +def dim(s): return _c("2", s) +def cyan(s): return _c("36", s) +def yellow(s): return _c("33", s) +def green(s): return _c("32", s) +def magenta(s): return _c("35", s) +def red(s): return _c("31", s) +def rule(ch="━"): print(cyan(ch * _RULE_W)) +def banner(t): + rule(); print(f" {bold(t)}"); rule() +def stage(idx, label, ms=None): + head = f" {cyan(idx)} {bold(label)}" + if ms is None: + print(head, flush=True); return + visible = len(f" {idx} {label}") + dots = dim("·" * max(2, _RULE_W - visible - 9)) + print(f"{head} {dots} {yellow(f'{ms:>5.0f} ms')}", flush=True) +def sub(t): print(f" {dim(t)}", flush=True) + + +# ─── Interactive arrow-key picker (from sa3_mlx.py) ──────────────────────── +def _arrow_pick(prompt: str, options: list[str], default: str | None = None) -> str: + """Tiny arrow-key picker — no external deps, posix termios only. + + Up/Down to move, Enter to select, Ctrl-C to abort. Falls back to a + numeric prompt when stdin isn't a TTY (piped input, CI, etc.). + """ + if not sys.stdin.isatty(): + print(prompt) + for i, o in enumerate(options): + mark = "*" if o == default else " " + print(f" {mark} [{i}] {o}") + s = input(f"Choose [0-{len(options)-1}] (Enter for default): ").strip() + if s == "": + return default or options[0] + if s.isdigit() and 0 <= int(s) < len(options): + return options[int(s)] + return s if s in options else (default or options[0]) + + idx = options.index(default) if default in options else 0 + fd = sys.stdin.fileno() + old = termios.tcgetattr(fd) + print(prompt) + for _ in options: + print() + try: + tty.setcbreak(fd) + while True: + sys.stdout.write(f"\x1b[{len(options)}A") + for i, o in enumerate(options): + if i == idx: + sys.stdout.write(f"\x1b[2K\x1b[36m▶ {o}\x1b[0m\n") + else: + sys.stdout.write(f"\x1b[2K {o}\n") + sys.stdout.flush() + ch = sys.stdin.read(1) + if ch == "\x1b": + seq = sys.stdin.read(2) + if seq == "[A": + idx = (idx - 1) % len(options) + elif seq == "[B": + idx = (idx + 1) % len(options) + elif ch in ("\n", "\r"): + return options[idx] + elif ch == "\x03": + raise KeyboardInterrupt + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old) + + +def prompt_user_if_missing(args): + """Interactive arrow-key selection when --dit / --decoder aren't supplied.""" + if args.dit is None: + args.dit = _arrow_pick("Choose DiT model:", list(DIT_REL.keys()), default="sm-music") + print(f" → {args.dit}") + if args.decoder is None: + suggested = DEFAULT_DECODER[args.dit] + args.decoder = _arrow_pick("Choose audio decoder:", list(DEC_REL.keys()), default=suggested) + print(f" → {args.decoder}") + if args.seed is None: + args.seed = random.randint(0, 2**31 - 1) + return args + + +def _preflight_download(args, dec) -> None: + """Resolve every model file this run needs and download any that are missing — + BEFORE the banner prints and BEFORE the wall-clock starts. Network time then + isn't charged against '×realtime' and the user sees download progress as a + clearly separate setup step.""" + needed = [T5_REL, DIT_REL[args.dit], DEC_REL[dec]] + if args.init_audio: + needed.append(ENC_REL[dec]) + missing = [p for p in needed if not is_present(p)] + if not missing: + return + print(f" Fetching {len(missing)} missing model file(s) before starting:") + for rel in missing: + ensure_local(rel) + print() + + +# ─── WAV read (16-bit PCM fast path + ffmpeg fallback) ───────────────────── +def read_wav(path: str) -> np.ndarray: + """Return (2, T) float32 in [-1, 1]. 16-bit/44.1k native; else via ffmpeg. Mono→stereo.""" + try: + with wave.open(path, "rb") as w: + nch, sw, sr, n = w.getnchannels(), w.getsampwidth(), w.getframerate(), w.getnframes() + if sr == SAMPLE_RATE and sw == 2: + raw = np.frombuffer(w.readframes(n), np.int16).astype(np.float32) / 32767.0 + if nch == 1: + return np.stack([raw, raw], 0) + return raw.reshape(-1, nch).T[:2] + except wave.Error: + pass + try: + r = subprocess.run(["ffmpeg", "-v", "error", "-i", path, "-f", "s16le", + "-ar", str(SAMPLE_RATE), "-ac", "2", "-"], + capture_output=True, check=True) + except FileNotFoundError: + raise RuntimeError(f"{path}: unsupported WAV format. Install ffmpeg for 24/32-bit or " + f"non-44.1kHz input: brew install ffmpeg") + except subprocess.CalledProcessError as e: + raise RuntimeError(f"{path}: ffmpeg failed — {e.stderr.decode().strip()}") + raw = np.frombuffer(r.stdout, np.int16).astype(np.float32) / 32767.0 + return raw.reshape(-1, 2).T + + +# ─── Baked DiT backend (6-input; batched OR sequential CFG + inpaint local_add_cond) ─── +def _apply_cfg(x, t, v_cond, v_uncond, cfg, apg): + """Combine cond/uncond velocities in denoised space (RF), optional APG. fp32. + Matches sa3_trt_core.model_fn. Returns cfg_v so `denoised = x - t*v` is the guided one.""" + sigma = np.float32(t) + xf = x.astype(np.float32) + cond_d = xf - v_cond.astype(np.float32) * sigma + uncond_d = xf - v_uncond.astype(np.float32) * sigma + diff = cond_d - uncond_d + if apg <= 0.0: + cfg_diff = diff + else: + norm = np.sqrt((cond_d * cond_d).sum(axis=(-2, -1), keepdims=True)) + unit = cond_d / np.maximum(norm, 1e-8) + parallel = (diff * unit).sum(axis=(-2, -1), keepdims=True) * unit + diff_orth = diff - parallel + cfg_diff = diff_orth if apg >= 1.0 else (apg * diff_orth + (1.0 - apg) * diff) + cfg_d = cond_d + (cfg - 1.0) * cfg_diff + return ((xf - cfg_d) / sigma).astype(np.float32) + + +class BakedDiT: + """model_fn(x,t,cross,gcond)->v compatible with P.sample. cross/gcond are IGNORED + (conditioning is baked in-graph, driven by the raw T5 outputs held here). + + cfg==1.0 -> one batch=1 forward with the (cond) T5 output. + cfg!=1.0 -> CFG combining a cond and an uncond (null_hidden/null_mask) velocity in + denoised space with optional APG. Two interchangeable backends (bit-identical + output; auto-falls-back to batch=1 for cfg==1.0): + batched=True (default): ONE batch=2 invoke/step — cond=row0, uncond=row1 — on the + canonical variable-batch DiT (its batch axis is dynamic). ~7-29%% faster + on Apple-Silicon (the AMX matrix unit amortizes weight loads across the 2 rows). + batched=False: SEQUENTIAL dual-pass — two batch=1 invokes/step, like TensorRT (whose + engine is static-batch=1). Use if a backend/delegate dislikes batch=2. + Returns cfg_v so P.sample's `denoised = x - t*v` yields the guided denoised.""" + def __init__(self, path, L, t5_hidden, t5_mask, seconds, threads=8, + cfg=1.0, apg=1.0, null_hidden=None, null_mask=None, local_add_cond=None, + batched=True): + from ai_edge_litert import interpreter as tfl + self.L = L + self.cfg = float(cfg); self.apg = float(apg) + self.n_fwd = 0 + # Batched CFG only applies when there's an uncond branch to co-batch (cfg != 1.0). + self.batched = bool(batched) and self.cfg != 1.0 + B = 2 if self.batched else 1 + self.B = B + + self.it = tfl.Interpreter(model_path=str(path), num_threads=threads) + det = self.it.get_input_details() + def pick(pred): return [d for d in det if pred(d)] + self.i_x = pick(lambda d: len(d["shape"]) == 3 and d["shape"][1] == 256 and d["shape"][2] != COND_DIM)[0] + self.i_lac = pick(lambda d: len(d["shape"]) == 3 and d["shape"][1] == 257)[0] + self.i_t5h = pick(lambda d: len(d["shape"]) == 3 and d["shape"][2] == COND_DIM)[0] + self.i_t5m = pick(lambda d: len(d["shape"]) == 2)[0] + scalars = sorted(pick(lambda d: len(d["shape"]) == 1), key=lambda d: d["name"]) + self.i_t, self.i_sec = scalars[0], scalars[1] # args_1=t < args_4=seconds by name + self.out = self.it.get_output_details()[0]["index"] + # Resize batch axis to B (the canonical DiT is variable-batch) + length axis to L. + self.it.resize_tensor_input(self.i_x["index"], [B, 256, L], strict=False) + self.it.resize_tensor_input(self.i_lac["index"], [B, 257, L], strict=False) + self.it.resize_tensor_input(self.i_t5h["index"], [B, COND_TOKENS, COND_DIM], strict=False) + self.it.resize_tensor_input(self.i_t5m["index"], [B, COND_TOKENS], strict=False) + self.it.resize_tensor_input(self.i_t["index"], [B], strict=False) + self.it.resize_tensor_input(self.i_sec["index"], [B], strict=False) + self.it.allocate_tensors() + + t5h = t5_hidden.astype(np.float32) + t5m = t5_mask.astype(np.float32) + sec = np.array([np.float32(seconds)], np.float32) + lac = (np.zeros((1, 257, L), np.float32) if local_add_cond is None + else local_add_cond.astype(np.float32)) + self.null_h = None if null_hidden is None else null_hidden.astype(np.float32) + self.null_m = None if null_mask is None else null_mask.astype(np.float32) + if self.batched: + # row0 = cond, row1 = uncond. t5_hidden/mask differ per row; seconds + lac are + # shared → tiled. All four are constant across steps → set resident once. + self.it.set_tensor(self.i_t5h["index"], np.concatenate([t5h, self.null_h], axis=0)) + self.it.set_tensor(self.i_t5m["index"], np.concatenate([t5m, self.null_m], axis=0)) + self.it.set_tensor(self.i_sec["index"], np.concatenate([sec, sec], axis=0)) + self.it.set_tensor(self.i_lac["index"], np.concatenate([lac, lac], axis=0)) + else: + self.t5h, self.t5m = t5h, t5m + self.it.set_tensor(self.i_sec["index"], sec) # seconds + lac constant → resident + self.it.set_tensor(self.i_lac["index"], lac) + + def _fwd(self, x, t, t5h, t5m): + """One batch=1 invoke (cfg==1.0 or sequential CFG).""" + self.it.set_tensor(self.i_x["index"], x.astype(np.float32)) + self.it.set_tensor(self.i_t["index"], np.array([np.float32(t)], np.float32)) + self.it.set_tensor(self.i_t5h["index"], t5h) + self.it.set_tensor(self.i_t5m["index"], t5m) + self.it.invoke() + self.n_fwd += 1 + return self.it.get_tensor(self.out).copy() + + def _fwd_batched(self, x, t): + """One batch=2 invoke -> (v_cond, v_uncond). x is the shared batch=1 state (both CFG + branches denoise the same latent), tiled to [2,256,L]; t5h/t5m/sec/lac already resident.""" + x2 = np.concatenate([x, x], axis=0).astype(np.float32) + self.it.set_tensor(self.i_x["index"], x2) + self.it.set_tensor(self.i_t["index"], np.array([np.float32(t), np.float32(t)], np.float32)) + self.it.invoke() + self.n_fwd += 2 + v2 = self.it.get_tensor(self.out) + return v2[0:1].copy(), v2[1:2].copy() + + def __call__(self, x, t, cross=None, gcond=None): + if self.cfg == 1.0: + return self._fwd(x, t, self.t5h, self.t5m) + if self.batched: + v_cond, v_uncond = self._fwd_batched(x, t) + else: + v_cond = self._fwd(x, t, self.t5h, self.t5m) + v_uncond = self._fwd(x, t, self.null_h, self.null_m) + return _apply_cfg(x, t, v_cond, v_uncond, self.cfg, self.apg) + + +# ─── Baked audio-in encoder (audio -> latents; SAME-S needs even L) ──────── +class BakedEncoder: + def __init__(self, path, threads=8, needs_even=False): + from ai_edge_litert import interpreter as tfl + self.it = tfl.Interpreter(model_path=str(path), num_threads=threads) + self.i = self.it.get_input_details()[0]["index"] + self.o = self.it.get_output_details()[0]["index"] + self.needs_even = needs_even + self._cur_N = None + + def _resize(self, N): + if self._cur_N != N: + self.it.resize_tensor_input(self.i, [1, 2, N], strict=False) + self.it.allocate_tensors() + self._cur_N = N + + def encode(self, audio, T_lat): + """audio: (1,2,M), M a multiple of 4096 (caller pads to even L for SAME-S). + Returns latents (1,256,T_lat) — trimmed back to the natural (decoder-independent) T_lat.""" + self._resize(audio.shape[-1]) + self.it.set_tensor(self.i, audio.astype(np.float32)) + self.it.invoke() + return self.it.get_tensor(self.o)[:, :, :T_lat].copy() + + +# ─── Baked audio-out decoder (whole for SAME-S; chunked for SAME-L) ───────── +class BakedDecoder: + def __init__(self, path, threads=8, needs_even=False): + from ai_edge_litert import interpreter as tfl + self.tfl = tfl + self.path = str(path) + self.threads = threads + self.needs_even = needs_even # SAME-S: T_aud=L*16 must be %32 -> pad odd L->even + self.it = tfl.Interpreter(model_path=self.path, num_threads=threads) + self.i = self.it.get_input_details()[0]["index"] + self.o = self.it.get_output_details()[0]["index"] + self.i_det = self.it.get_input_details()[0] + self._cur_L = None + + def _resize(self, L): + if self._cur_L != L: + self.it.resize_tensor_input(self.i, [1, 256, L], strict=False) + self.it.allocate_tensors() + self._cur_L = L + + def decode_whole(self, latents): + L = latents.shape[2] + if self.needs_even and L % 2 != 0: + # SAME-S needs even L. Pad one edge-replicated latent token, decode at L+1, + # trim the extra token's audio. Narrow receptive field => negligible boundary + # effect on the kept L*4096 samples. Keeps odd-L requests working without + # changing the DiT/noise path (which stays natural-ceil, MLX/TRT-matched). + latents = np.concatenate([latents, latents[:, :, -1:]], axis=2) + self._resize(L + 1) + self.it.set_tensor(self.i, latents.astype(np.float32)) + self.it.invoke() + return self.it.get_tensor(self.o)[:, :, :L * SAMPLES_PER_LATENT].copy() + self._resize(L) + self.it.set_tensor(self.i, latents.astype(np.float32)) + self.it.invoke() + return self.it.get_tensor(self.o).copy() # [1,2,L*4096] + + def decode_chunked(self, latents, chunk, overlap, on_chunk=None): + """Stitch audio directly (stride = 4096 samples per latent token).""" + B, C, L = latents.shape + if L <= chunk: + return self.decode_whole(latents) + S = SAMPLES_PER_LATENT + out = np.zeros((B, 2, L * S), np.float32) + K = chunk - 2 * overlap + assert K > 0, (chunk, overlap) + core = 0 + n_windows = (L + K - 1) // K + wi = 0 + while core < L: + core_end = min(core + K, L) + win_start = core - overlap + win_end = win_start + chunk + if win_start < 0: + win_start, win_end = 0, chunk + if win_end > L: + win_end, win_start = L, L - chunk + win = latents[:, :, win_start:win_end] + y = self.decode_whole(win) # [1,2,chunk*4096] + ks = (core - win_start) * S + ke = (core_end - win_start) * S + out[:, :, core * S: core_end * S] = y[:, :, ks:ke] + wi += 1 + if on_chunk: + on_chunk(wi, n_windows) + core = core_end + return out + + +def valid_T_lat(seconds): + """seconds -> T_lat via natural ceil, DECODER-INDEPENDENT and identical to MLX/TRT + (sa3_mlx / sa3_trt.resolve_T_lat = max(1, ceil(seconds*44100/4096))). So MLX == TRT == + TFLite pick the SAME length -> no length-driven divergence, and true ODD-length + requests are honored. The DiT handles odd L natively; SAME-L takes any L; SAME-S + (even T_aud=L*16 % 32) pads odd->even at encode/decode and trims.""" + return max(1, int(np.ceil(seconds * SAMPLE_RATE / SAMPLES_PER_LATENT))) + + +class _HelpfulParser(argparse.ArgumentParser): + """argparse that prints full help (not just usage) when a flag is unknown / invalid, + and tacks the shared example-commands block onto the end of -h / --help.""" + def error(self, message): + sys.stderr.write(f"\nerror: {message}\n\n") + self.print_help(sys.stderr) + sys.exit(2) + def print_help(self, file=None): + super().print_help(file) + try: + from examples import print_example_commands + print_example_commands() + except Exception: + pass # never let an examples-block failure mask the actual --help + + +def main(): + ap = _HelpfulParser( + description="SA3 text-to-audio (+ audio-to-audio + inpainting + CFG) — baked-I/O varlen TFLite / CPU", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=("modes\n" + " text-to-audio --prompt P\n" + " audio-to-audio --prompt P --init-audio IN.wav [--init-noise-level σ]\n" + " inpainting --prompt P --init-audio IN.wav --inpaint-range START,END\n" + " negative CFG --prompt P --cfg N [--negative-prompt P_NEG] [--apg A]\n")) + # Inputs + ap.add_argument("--prompt", default=None, + help="Text prompt. Empty string = unconditional. If omitted, asked via stdin.") + ap.add_argument("--negative-prompt", default=None, + help="Negative prompt for CFG's uncond branch. Ignored when --cfg=1.0. " + "When unset and --cfg≠1.0, the uncond branch uses all-zero T5 hidden+mask " + "(→ learned padding embeddings in-graph).") + ap.add_argument("--init-audio", default=None, + help="WAV to start from. With --init-noise-level → audio-to-audio; with " + "--inpaint-range → inpainting. Encoder loaded automatically; audio is " + "trimmed/zero-padded to --seconds. Any format via ffmpeg fallback.") + ap.add_argument("--inpaint-range", default=None, + help="Inpaint time range 'START,END' in seconds (e.g. '5,10'). Requires " + "--init-audio. Regenerates the masked span, preserves the rest (paste-back).") + # Models + ap.add_argument("--dit", choices=["sm-music", "sm-sfx", "medium"], default=None, + help="DiT model (names match sa3_mlx / sa3_trt). If omitted, prompts " + "interactively with an arrow-key picker.") + ap.add_argument("--decoder", choices=["same-s", "same-l"], default=None, + help="Audio decoder. Default: same-s for sm-music/sm-sfx, same-l for medium. " + "If omitted, prompts interactively with an arrow-key picker.") + # Sampling + ap.add_argument("--seconds", type=float, default=30.0, + help="Output length. T_lat = ceil(seconds*44100/4096) (natural ceil, decoder-" + "independent, matches MLX/TRT). Final WAV trimmed to exactly --seconds.") + ap.add_argument("--steps", type=int, default=8, + help="Pingpong sampling steps (≥1). rf_denoiser is distilled for 8 (default).") + ap.add_argument("--seed", type=int, default=None, + help="Random seed. If omitted, a random seed is chosen and printed at the end.") + ap.add_argument("--init-noise-level", type=float, default=1.0, + help="σmax — schedule's starting noise level. With --init-audio: 0.4-0.8 varies, " + "1.0 = full regen (init ignored). Min %.2f (model NaNs at t≈0)." % MIN_SIGMA) + ap.add_argument("--cfg", type=float, default=1.0, + help="Classifier-Free Guidance scale. 1.0 = off (single pass). >1 toward prompt, " + "<1 toward uncond. Any value ≠1.0 runs cond+uncond each step (see --cfg-batched).") + ap.add_argument("--apg", type=float, default=1.0, + help="Adaptive Projected Guidance [0..1], only when --cfg≠1.0. 1.0 = full APG " + "(orthogonal projection), 0.0 = vanilla CFG. rf_denoiser default 1.0.") + ap.add_argument("--cfg-batched", action=argparse.BooleanOptionalAction, default=True, + help="When --cfg≠1.0, run cond+uncond as ONE batch=2 invoke on the variable-batch " + "DiT (default; ~7-29%% faster on Apple-Silicon AMX). --no-cfg-batched forces the " + "sequential batch=1 dual-pass (like TensorRT). Bit-identical either way.") + # Runtime / output + ap.add_argument("--threads", type=int, default=8, help="XNNPACK CPU threads.") + ap.add_argument("--free-models", action=argparse.BooleanOptionalAction, default=True, + help="Free each model after its last use to lower peak RAM (default on).") + ap.add_argument("--out", "-o", default=None, + help="Output WAV path. Relative → output/; absolute → as-is. " + "If omitted, auto-named from the prompt + seed.") + ap.add_argument("--play", action="store_true", + help="Play the WAV via macOS `afplay` after writing (blocking).") + args = ap.parse_args() + if args.steps < 1: + ap.error(f"--steps must be ≥ 1 (got {args.steps})") + + # Interactive fills (match MLX/TRT). + args = prompt_user_if_missing(args) + if args.prompt is None: + args.prompt = input("Prompt: ").strip() + if args.seed is None: + args.seed = random.randint(0, 2**31 - 1) + + dec = args.decoder or DEFAULT_DECODER[args.dit] + T_lat = valid_T_lat(args.seconds) + target_dur = T_lat * SAMPLES_PER_LATENT / SAMPLE_RATE + + # Resolve output path — auto-name from prompt+seed when --out is not given. + # Relative paths land in output/ (auto-created); absolute paths honored as-is. + # A relative path that already starts with "output/" is taken as-is (relative to + # cwd) so `--out output/foo.wav` doesn't become output/output/foo.wav. + if args.out is None: + slug = re.sub(r'[^a-z0-9]+', '_', args.prompt.lower()).strip('_')[:48] + args.out = f"{slug}_{args.seed}.wav" if slug else f"out_{args.seed}.wav" + out_path = Path(args.out) + if not out_path.is_absolute() and out_path.parts[:1] != ("output",): + out_path = REPO / "output" / out_path + out_path.parent.mkdir(parents=True, exist_ok=True) + args.out = str(out_path) + + # Inpaint validation → latent range. + inpaint_range = None + inp_start_sec = inp_end_sec = None + if args.inpaint_range is not None: + if args.init_audio is None: + sys.exit("error: --inpaint-range requires --init-audio (the audio to inpaint into)") + try: + s_str, e_str = args.inpaint_range.split(",") + inp_start_sec = float(s_str.strip()); inp_end_sec = float(e_str.strip()) + except ValueError: + sys.exit(f"error: --inpaint-range must be 'START,END' in seconds; got {args.inpaint_range!r}") + if not (0 <= inp_start_sec < inp_end_sec <= args.seconds): + sys.exit(f"error: invalid inpaint range {inp_start_sec}-{inp_end_sec}s " + f"(must satisfy 0 <= start < end <= {args.seconds}s)") + inp_start_lat = max(0, int(round(inp_start_sec * SAMPLE_RATE / SAMPLES_PER_LATENT))) + inp_end_lat = min(T_lat, int(round(inp_end_sec * SAMPLE_RATE / SAMPLES_PER_LATENT))) + inpaint_range = (inp_start_lat, inp_end_lat) + + sigma_max = float(args.init_noise_level) + if sigma_max < MIN_SIGMA: + sys.exit(f"error: --init-noise-level={sigma_max} too low (min {MIN_SIGMA}; model NaNs at t≈0)") + mode = ("inpaint" if inpaint_range else + "audio-to-audio" if args.init_audio else "text-to-audio") + + # ── Preflight: download any missing model files BEFORE the banner/wall-clock. ── + _preflight_download(args, dec) + + # ── Banner ── + print() + banner(f"SA3 → TFLite/CPU {mode}") + k = lambda s: dim(f"{s:>11}") + print(f" {k('prompt')} {bold(repr(args.prompt))}") + if args.negative_prompt: + suffix = "" if args.cfg != 1.0 else dim(" (ignored: --cfg=1.0)") + print(f" {k('neg prompt')} {bold(repr(args.negative_prompt))}{suffix}") + line = f" {k('dit')} {magenta(args.dit)} {k('decoder')} {magenta(dec)}" + if args.init_audio: + line += f" {k('encoder')} {magenta(dec)}" + line += f" {k('threads')} {args.threads}" + print(line) + if args.init_audio: + print(f" {k('init audio')} {bold(args.init_audio)}") + if inpaint_range: + s0, s1 = inpaint_range + print(f" {k('inpaint')} {bold(f'{inp_start_sec:.2f}s..{inp_end_sec:.2f}s')} " + f"{dim(f'(latent {s0}..{s1} of {T_lat})')}") + cfg_line = f" {k('σmax')} {bold(f'{sigma_max:.2f}')} {k('cfg')} {args.cfg}" + if args.cfg != 1.0: + cfg_line += dim(f" (apg={args.apg}, {'batched' if args.cfg_batched else 'sequential'} CFG)") + print(cfg_line) + print(f" {k('seconds')} {args.seconds}s {k('steps')} {args.steps} {k('seed')} {args.seed}") + print(f" {k('T_lat')} {T_lat} {dim(f'({target_dur:.2f}s → trimmed to {args.seconds:.2f}s)')}") + print() + + # Stage numbering (extra stage when encoding init audio). + N = 3 + (1 if args.init_audio else 0) + TAG = {"t5": f"[1/{N}]"} + if args.init_audio: + TAG.update(enc=f"[2/{N}]", dit=f"[3/{N}]", dec=f"[4/{N}]") + else: + TAG.update(dit=f"[2/{N}]", dec=f"[3/{N}]") + + t_wall = time.perf_counter() + + # ── T5Gemma (tokenize + encode; + negative prompt for CFG) ── + stage(TAG["t5"], "T5Gemma (tokenize + encode)") + t0 = time.perf_counter() + tok = P.Tokenizer() + ids, mask = tok(args.prompt) # (1,256) int32 each + t5 = P.T5GemmaTFLite(ensure_local(T5_REL), args.threads) + t5_hidden = t5(ids, mask) # (1,256,768) fp32 + null_h = null_m = None + if args.cfg != 1.0: + if args.negative_prompt: + n_ids, n_mask = tok(args.negative_prompt) + null_h = t5(n_ids, n_mask) + null_m = n_mask.astype(np.float32) + else: + # All-zero hidden+mask → in-graph conditioner emits learned padding embeds + # for every position (the standard unconditional branch). No extra T5 pass. + null_h = np.zeros((1, COND_TOKENS, COND_DIM), np.float32) + null_m = np.zeros((1, COND_TOKENS), np.float32) + t5_ms = (time.perf_counter() - t0) * 1000 + stage(TAG["t5"], "T5Gemma (tokenize + encode)", t5_ms) + sub(f"t5_hidden {t5_hidden.shape} mask sum={int(mask.sum())}" + + (f" neg ({'prompt' if args.negative_prompt else 'zeros'})" if null_h is not None else "")) + if args.free_models: + del t5; _free() + + # ── (audio-to-audio / inpaint) Encode init audio → init_latents ── + init_latents = None + if args.init_audio: + stage(TAG["enc"], f"Encode init audio → latents ({dec})") + t0 = time.perf_counter() + # SAME-S encoder needs even L; round the encode grid up, trim latents to T_lat. + enc_L = T_lat + 1 if (dec == "same-s" and T_lat % 2 != 0) else T_lat + target_samples = enc_L * SAMPLES_PER_LATENT + audio_in = read_wav(args.init_audio) # (2, T_in) + if audio_in.shape[-1] >= target_samples: + audio_in = audio_in[:, :target_samples] + init_action = f"trimmed to {target_samples} samples" + else: + pad = target_samples - audio_in.shape[-1] + audio_in = np.pad(audio_in, ((0, 0), (0, pad))) + init_action = f"zero-padded by {pad} samples" + audio_in = audio_in[None] # (1,2,target_samples) + enc = BakedEncoder(ensure_local(ENC_REL[dec]), args.threads, needs_even=(dec == "same-s")) + init_latents = enc.encode(audio_in, T_lat) # (1,256,T_lat) + enc_ms = (time.perf_counter() - t0) * 1000 + stage(TAG["enc"], f"Encode init audio → latents ({dec})", enc_ms) + sub(f"{init_action} latents {init_latents.shape}") + if args.free_models: + del enc; _free() + + # ── Build inpaint local_add_cond + paste-back, and initial noise ── + local_add_cond = None + paste_back = None + if inpaint_range is not None: + s0, s1 = inpaint_range + keep = np.ones((1, 1, T_lat), np.float32); keep[:, :, s0:s1] = 0.0 # 1=keep, 0=regen + masked = init_latents.astype(np.float32) * keep + local_add_cond = np.concatenate([keep, masked], axis=1) # (1,257,T_lat), TRT layout + paste_back = (init_latents.astype(np.float32), keep) + + x0, step_noise = P.make_noise(T_lat, args.steps, args.seed) # x0 = pure noise + if init_latents is not None and inpaint_range is None: + # rf_denoiser init mix (linear interp): noise = init*(1-σmax) + pure*σmax + x0 = init_latents.astype(np.float32) * (1.0 - sigma_max) + x0 * sigma_max + + # ── DiT load + pingpong sample ── + stage(TAG["dit"], f"DiT — load + sample ({args.steps} steps, σmax={sigma_max:.2f})") + t0 = time.perf_counter() + cfg_note = (("CFG batched (1× batch=2 invoke/step)" if args.cfg_batched + else "CFG sequential (2× batch=1 invokes/step)") if args.cfg != 1.0 else "") + print(f" {dim('loading baked DiT ' + args.dit + ' ...')}", flush=True) + backend = BakedDiT(ensure_local(DIT_REL[args.dit]), T_lat, t5_hidden, mask.astype(np.float32), + args.seconds, args.threads, cfg=args.cfg, apg=args.apg, + null_hidden=null_h, null_mask=null_m, local_add_cond=local_add_cond, + batched=args.cfg_batched) + load_ms = (time.perf_counter() - t0) * 1000 + sub(f"load {load_ms/1000:.1f}s" + (f" {cfg_note}" if cfg_note else "")) + + sig = P.build_pingpong_schedule(args.steps, sigma_max) + sched_str = " · ".join(f"{float(x):.3f}" for x in sig) + sub(f"schedule {sched_str}") + + step_prev = [time.perf_counter()] + def on_step(i, total): + now = time.perf_counter(); el = (now - step_prev[0]) * 1000; step_prev[0] = now + if _USE_COLOR: + bar_w = 20; filled = int(round(bar_w * i / total)) + bar = cyan("█" * filled) + dim("·" * (bar_w - filled)) + sys.stdout.write(f"\r\x1b[K {dim('sampling')} {bar} " + f"{bold(f'step {i}/{total}')} {yellow(f'{el:.0f} ms')}") + sys.stdout.flush() + else: + print(f" sampling step {i}/{total} {el:.0f} ms", flush=True) + + t0 = time.perf_counter() + latents = P.sample(backend, x0, step_noise, sig, None, None, + on_step=on_step, paste_back=paste_back) + samp_ms = (time.perf_counter() - t0) * 1000 + if _USE_COLOR: + sys.stdout.write("\r\x1b[K") + stage(TAG["dit"], f"DiT — load + sample ({args.steps} steps, σmax={sigma_max:.2f})", + load_ms + samp_ms) + sub(f"sample {samp_ms:.0f} ms ({samp_ms/max(args.steps,1):.0f} ms/step, " + f"{backend.n_fwd} forwards) latents {latents.shape}") + if args.free_models: + del backend; _free() + + # ── Decode (audio-out) + WAV ── + stage(TAG["dec"], f"Decoder ({dec}, audio-out) + WAV") + t0 = time.perf_counter() + print(f" {dim('loading baked decoder ' + dec + ' ...')}", flush=True) + decoder = BakedDecoder(ensure_local(DEC_REL[dec]), args.threads, needs_even=(dec == "same-s")) + load2_ms = (time.perf_counter() - t0) * 1000 + sub(f"load {load2_ms:.0f} ms") + + t0 = time.perf_counter() + if dec == "same-l" and T_lat > SAMEL_CHUNK: + print(f" {dim(f'SAME-L chunked decode (chunk={SAMEL_CHUNK}, ovl={SAMEL_OVERLAP}) ...')}", flush=True) + def on_chunk(i, n): + print(f" {dim(f'decode chunk {i}/{n}')}", flush=True) + audio = decoder.decode_chunked(latents, SAMEL_CHUNK, SAMEL_OVERLAP, on_chunk=on_chunk) + dmode = f"chunked (chunk={SAMEL_CHUNK}, ovl={SAMEL_OVERLAP})" + else: + print(f" {dim('whole decode ...')}", flush=True) + audio = decoder.decode_whole(latents) + dmode = "whole" + dec_ms = (time.perf_counter() - t0) * 1000 + + audio_np = audio[0] # (2, L*4096) + req = int(round(args.seconds * SAMPLE_RATE)) + if audio_np.shape[-1] > req: + audio_np = audio_np[:, :req] + P.save_wav(args.out, audio_np) + stage(TAG["dec"], f"Decoder ({dec}, audio-out) + WAV", load2_ms + dec_ms) + peak = float(np.abs(audio_np).max()); rms = float(np.sqrt((audio_np**2).mean())) + sub(f"decode {dmode} {dec_ms:.0f} ms audio {audio_np.shape} peak {peak:.3f} rms {rms:.3f}") + if args.free_models: + del decoder; _free() + + total = time.perf_counter() - t_wall + dur = audio_np.shape[-1] / SAMPLE_RATE + print() + rule() + print(f" {bold(green('done'))} {bold(f'{total:.2f}s')} wall → {dur:.1f}s audio → " + f"{bold(yellow(f'{dur/total:.2f}× realtime'))} {dim(f'seed {args.seed}')}") + abs_out = os.path.abspath(args.out) + try: + rel_out = os.path.relpath(abs_out) + except ValueError: + rel_out = abs_out + shown = rel_out if len(rel_out) <= len(abs_out) and not rel_out.startswith("..") else abs_out + print(f" {bold(green('▸ saved'))} {bold(shown)} {dim(f'({abs_out})' if shown != abs_out else '')}".rstrip()) + rule() + + if args.play: + try: + print(f" {bold('▶ playing')} {args.out} {dim('(Ctrl-C to stop)')}") + subprocess.run(["afplay", args.out], check=False) + except KeyboardInterrupt: + print() + + +if __name__ == "__main__": + main() diff --git a/optimized/tflite/scripts/weights.py b/optimized/tflite/scripts/weights.py new file mode 100644 index 0000000..3fb2121 --- /dev/null +++ b/optimized/tflite/scripts/weights.py @@ -0,0 +1,177 @@ +"""Shared weights manifest + downloader. + +Maps every TFLite / LiteRT model file the runtime needs to its position in the +`stabilityai/stable-audio-3-optimized` HuggingFace repo (under `tflite/…`). + +`install.py` calls `ensure_local` upfront for the bundles the user picks. +`sa3_tflite.py` calls `ensure_local` lazily, just before each model loads — so +a fresh checkout with no weights still works if the user is willing to wait for +the first run to download them. + +The SentencePiece tokenizer is BUNDLED at models/tokenizer.model (the .tflite +T5Gemma is encoder-only), so it's not in this manifest. +""" + +from __future__ import annotations + +from pathlib import Path + +REPO_ID = "stabilityai/stable-audio-3-optimized" +# weights.py lives in /scripts/; SCRIPT_DIR points at the project +# root so the local rel paths in the manifest ("models/tflite/foo.tflite") +# resolve against the actual project layout. +SCRIPT_DIR = Path(__file__).resolve().parent.parent + + +# Bundles the install script offers to the user. Each maps to a list of +# model files (local relative path on the left, HF repo path on the right). +# T5Gemma is in SHARED because every bundle needs it. The two small DiTs +# share the SAME-S codec; medium uses the SAME-L codec. + +DIT_BUNDLES: dict[str, list[tuple[str, str]]] = { + "sm-music": [ + ("models/tflite/sa3-sm-music/dit_fp32.tflite", "tflite/sa3-sm-music/dit_fp32.tflite"), + ("models/tflite/same-s/enc_fp32.tflite", "tflite/same-s/enc_fp32.tflite"), + ("models/tflite/same-s/dec_fp32.tflite", "tflite/same-s/dec_fp32.tflite"), + ], + "sm-sfx": [ + ("models/tflite/sa3-sm-sfx/dit_fp32.tflite", "tflite/sa3-sm-sfx/dit_fp32.tflite"), + ("models/tflite/same-s/enc_fp32.tflite", "tflite/same-s/enc_fp32.tflite"), + ("models/tflite/same-s/dec_fp32.tflite", "tflite/same-s/dec_fp32.tflite"), + ], + "medium": [ + ("models/tflite/sa3-m/dit_fp32.tflite", "tflite/sa3-m/dit_fp32.tflite"), + ("models/tflite/same-l/enc_fp32.tflite", "tflite/same-l/enc_fp32.tflite"), + ("models/tflite/same-l/dec_fp32.tflite", "tflite/same-l/dec_fp32.tflite"), + ], +} + +SHARED: list[tuple[str, str]] = [ + ("models/tflite/t5gemma/encoder_fp16.tflite", "tflite/t5gemma/encoder_fp16.tflite"), +] + +# Human-friendly bundle sizes (for the install prompt). Exact, from HF metadata. +BUNDLE_SIZES = { + "sm-music": "2.3 GB (small music DiT + SAME-S codec, all fp32)", + "sm-sfx": "2.3 GB (small sfx DiT + SAME-S codec, all fp32)", + "medium": "9.5 GB (medium DiT + SAME-L codec, all fp32)", +} +# T5Gemma (shared, fp16) adds ~0.6 GB the first time any bundle is fetched. + +# Flat (local_rel_path → hf_path) lookup — used by sa3_tflite.py for lazy +# auto-download at load time. +FLAT_MANIFEST: dict[str, str] = {} +for _items in DIT_BUNDLES.values(): + for _rel, _hf in _items: + FLAT_MANIFEST[_rel] = _hf +for _rel, _hf in SHARED: + FLAT_MANIFEST[_rel] = _hf + + +def _hf_token_configured() -> bool: + """True if any HF token is set — env var or cached login on disk.""" + import os + if os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"): + return True + try: + from huggingface_hub import get_token # huggingface_hub ≥ 0.19 + return bool(get_token()) + except ImportError: + try: + from huggingface_hub import HfFolder + return bool(HfFolder.get_token()) + except Exception: + return False + except Exception: + return False + + +_LOGIN_TIP_SHOWN = False + +def _show_hf_login_tip_once() -> None: + """Print a one-time login suggestion if no HF token is configured. + + Anonymous downloads work but have a ~50 GB/day soft cap on HF's LFS CDN + and lower aggregate bandwidth — a free token effectively removes both. + Stays silent if a token is already in place. + """ + global _LOGIN_TIP_SHOWN + if _LOGIN_TIP_SHOWN: + return + _LOGIN_TIP_SHOWN = True + if _hf_token_configured(): + return + import sys + YEL = "\033[1;33m" if sys.stdout.isatty() else "" + BOLD = "\033[1m" if sys.stdout.isatty() else "" + DIM = "\033[2m" if sys.stdout.isatty() else "" + RST = "\033[0m" if sys.stdout.isatty() else "" + print() + print(f" {YEL}⚠ not logged in to HuggingFace{RST} — anonymous downloads work but are") + print(f" rate-limited (~50 GB/day cap on the LFS CDN). For faster, higher-limit") + print(f" downloads, log in once with a free read-only token:") + print() + print(f" 1. create an account at {BOLD}https://huggingface.co/join{RST}") + print(f" 2. generate a token at {BOLD}https://huggingface.co/settings/tokens{RST}") + print(f" {DIM}('Read' scope is enough){RST}") + print(f" 3. save it on this machine — pick one:") + print(f" {BOLD}hf auth login{RST} {DIM}# modern (huggingface_hub ≥ 1.0){RST}") + print(f" {BOLD}huggingface-cli login{RST} {DIM}# classic; still works{RST}") + print(f" {BOLD}export HF_TOKEN=hf_xxx{RST} {DIM}# one-off / scripts{RST}") + print() + + +def ensure_local(local_rel_path: str, verbose: bool = True) -> Path: + """Resolve a model file to an absolute local path, downloading if missing. + + Files are streamed into the HuggingFace cache (~/.cache/huggingface/hub/) + and symlinked into the project at `local_rel_path` so the on-disk layout + looks the same whether the file was bundled or downloaded. + """ + target = SCRIPT_DIR / local_rel_path + if target.exists() or target.is_symlink(): + return target + + if local_rel_path not in FLAT_MANIFEST: + raise FileNotFoundError( + f"{local_rel_path} is not in the weights manifest — can't auto-download." + ) + + # First-download tip: nudge users toward logging in to HF for better limits. + # No-op if a token is already configured. + _show_hf_login_tip_once() + + hf_filename = FLAT_MANIFEST[local_rel_path] + if verbose: + print(f" ↓ downloading {hf_filename} (from {REPO_ID})") + + try: + from huggingface_hub import hf_hub_download + except ImportError as e: + raise RuntimeError( + "huggingface_hub is required to auto-download weights.\n" + "Run: pip install huggingface_hub\n" + "Or run the install.py script in this directory." + ) from e + + cached = hf_hub_download(repo_id=REPO_ID, filename=hf_filename) + target.parent.mkdir(parents=True, exist_ok=True) + # Symlink keeps the HF cache canonical (one copy on disk) while exposing + # the file at the project-relative path the runtime expects. + if target.is_symlink(): + target.unlink() + target.symlink_to(cached) + return target + + +def is_present(local_rel_path: str) -> bool: + """True if the file exists locally (does not trigger a download).""" + p = SCRIPT_DIR / local_rel_path + return p.exists() or p.is_symlink() + + +def bundle_status(bundle: str) -> tuple[int, int]: + """Returns (present_count, total_count) for the bundle (including SHARED).""" + items = DIT_BUNDLES[bundle] + SHARED + present = sum(1 for rel, _ in items if is_present(rel)) + return present, len(items)