diff --git a/optimized/mlx/README.md b/optimized/mlx/README.md index 1d7b615..2a20374 100644 --- a/optimized/mlx/README.md +++ b/optimized/mlx/README.md @@ -94,6 +94,10 @@ Apple Silicon only (MLX is Metal-backed). Python 3.10+. `./install.sh ./sa3 --prompt "ambient drone" --cfg 3.0 --negative-prompt "drums, vocals" \ --dit sm-music --decoder same-s --out drone.wav +# Apply a LoRA finetune (merged into the DiT at load; base must match --dit) +./sa3 --prompt "arabic maqam oud taqsim" --dit medium --decoder same-l \ + --lora ./my_lora.safetensors --lora-strength 1.0 --out maqam.wav + # Generate + play immediately (afplay; Ctrl-C stops both) ./sa3 --prompt "rainforest" --dit sm-sfx --decoder same-s --play @@ -178,6 +182,8 @@ Sample run on **M4 Pro / 48 GB**: | `--init-noise-level` | 1.0 | σmax; 0.4–0.8 typical for variation, 1.0 = full regen, >1 = overshoot | | `--inpaint-range` | — | `START,END` seconds; regenerate that span, keep the rest | | `--dit-dtype` | fp16 | DiT compute dtype (decoder always FP32; T5Gemma always fp16) | +| `--lora` | — | One or more `.safetensors` LoRA adapters merged into the DiT at load (SA3-native or PEFT). Pickle `.ckpt/.pt` is refused. Base must match `--dit` | +| `--lora-strength` | 1.0 | Application weight per `--lora` delta; 0 = bit-exact bypass, >1 amplifies | | `--free-models` | on | Progressive model freeing; `--no-free-models` keeps them resident | | `--out` | out.wav | Relative → `output/`; absolute → as-is. 16-bit PCM stereo @ 44.1 kHz, trimmed to exactly `--seconds` | | `--play` | off | After writing, play via `afplay`; Ctrl-C stops both processes | diff --git a/optimized/mlx/models/defs/dit_mlx.py b/optimized/mlx/models/defs/dit_mlx.py index 3487c65..95bfdda 100644 --- a/optimized/mlx/models/defs/dit_mlx.py +++ b/optimized/mlx/models/defs/dit_mlx.py @@ -304,12 +304,17 @@ def convert_weights_from_torch_ckpt(ckpt_path): return out -def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False): +def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False, + lora_paths=None, lora_strength=1.0, lora_log=print): """Build MLX DiT and load weights. weights_path can be either: - the sa3-sm-music torch ckpt (slow; converts at load time), OR - a pre-converted MLX file (.npz or .safetensors — fast path). + + lora_paths: optional list of LoRA adapters (.safetensors / PEFT dir) to merge + into the weights at load time. lora_strength scales every adapter's delta. + See models/defs/lora_merge.py. """ p = str(weights_path) if p.endswith(".npz") or p.endswith(".safetensors"): @@ -317,6 +322,11 @@ def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False): else: wd = convert_weights_from_torch_ckpt(p) + if lora_paths: + from .lora_merge import merge_loras_into_weights + stats = merge_loras_into_weights(wd, lora_paths, strength=lora_strength, log=lora_log) + lora_log(f"lora: merged {stats['merged']} layer(s) from {stats['adapters']} adapter(s)") + model = DiT(T_lat=T_lat) wd_list = [(k, v.astype(dtype)) for k, v in wd.items()] model.load_weights(wd_list, strict=False) diff --git a/optimized/mlx/models/defs/dit_mlx_medium.py b/optimized/mlx/models/defs/dit_mlx_medium.py index 2efa4a3..73ebef4 100644 --- a/optimized/mlx/models/defs/dit_mlx_medium.py +++ b/optimized/mlx/models/defs/dit_mlx_medium.py @@ -405,11 +405,16 @@ def convert_weights(safetensors_path, out_path=None): return out -def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False): +def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False, + lora_paths=None, lora_strength=1.0, lora_log=print): """Build MLX DiT and load weights. weights_path: either the .safetensors (we'll convert in-memory) or a pre-converted .safetensors-mlx file. + + lora_paths: optional list of LoRA adapters (.safetensors / PEFT dir) to merge + into the weights at load time. lora_strength scales every adapter's delta. + See models/defs/lora_merge.py. """ weights_path = str(weights_path) if weights_path.endswith(".safetensors") and ("medium-ARC" in weights_path): @@ -418,6 +423,11 @@ def load_dit(weights_path, T_lat=320, dtype=mx.float16, compile_=False): else: wd = dict(mx.load(weights_path)) + if lora_paths: + from .lora_merge import merge_loras_into_weights + stats = merge_loras_into_weights(wd, lora_paths, strength=lora_strength, log=lora_log) + lora_log(f"lora: merged {stats['merged']} layer(s) from {stats['adapters']} adapter(s)") + model = DiT(T_lat=T_lat) # Cast to target dtype (no-op when already at `dtype`). diff --git a/optimized/mlx/models/defs/lora_merge.py b/optimized/mlx/models/defs/lora_merge.py new file mode 100644 index 0000000..b312382 --- /dev/null +++ b/optimized/mlx/models/defs/lora_merge.py @@ -0,0 +1,375 @@ +"""LoRA merge-at-load for the MLX SA3 DiT. + +Adds LoRA inference support to the MLX path, which the `sa3_mlx.py` CLI otherwise +lacks. The LoRA delta is **merged into the DiT weight dict at load time**, before +the model is built — no runtime parametrization and no extra per-step forward +cost. A strength of 0 is a bit-exact bypass. + +Trust boundary: only `.safetensors` adapters are accepted. The legacy pickle +`.ckpt`/`.pt`/`.bin` path — which `torch.load` would execute arbitrary code from — +is refused outright; this module never calls `torch.load`. + +Two on-disk conventions are supported: + + * **SA3-native** (`scripts/train_lora.py` output): tensor keys + ``.parametrizations.weight.0.{lora_A,lora_B,M_xs,magnitude, + magnitude_r,magnitude_c}`` with the adapter config + (``adapter_type``/``rank``/``alpha``/``include``/``exclude``) JSON-encoded in + the safetensors **metadata** under ``"lora_config"``. Covers all nine adapter + types (lora, dora-rows/cols, bora, and the four -xs variants). + * **PEFT** (huggingface `peft`): keys ``base_model.model..lora_{A,B}.weight`` + with ``r``/``lora_alpha`` in a sibling ``adapter_config.json``. Standard LoRA, + plus DoRA when ``use_dora`` is set. + +The per-adapter-type math mirrors ``LoRAParametrization.*_forward`` in +``stable_audio_3/models/lora/model.py`` (and the accumulate-deltas-against-the- +original-weight semantics of ``merge_loras_into_base_model``), computed in +float32 and cast back to the DiT dtype. `-xs` adapters do not store their frozen +SVD bases, so they are recomputed from the base weight here, matching the +reference (`torch.linalg.svd` + a deterministic sign convention). +""" + +from __future__ import annotations + +import json +import os + +import mlx.core as mx +import numpy as np + +# Pickle-backed extensions we refuse to load (the trust boundary). +_PICKLE_EXTS = (".ckpt", ".pt", ".pth", ".bin") + +# Adapter param names per type (mirrors utils._get_adapter_param_names). +_PARAMS_FOR = { + "lora": ("lora_A", "lora_B"), + "dora-rows": ("lora_A", "lora_B", "magnitude"), + "dora-cols": ("lora_A", "lora_B", "magnitude"), + "bora": ("lora_A", "lora_B", "magnitude_r", "magnitude_c"), + "lora-xs": ("M_xs",), + "dora-rows-xs": ("M_xs", "magnitude"), + "dora-cols-xs": ("M_xs", "magnitude"), + "bora-xs": ("M_xs", "magnitude_r", "magnitude_c"), +} + + +class LoraError(Exception): + """An adapter could not be loaded or applied.""" + + +# ── safetensors reading (no torch, no safetensors pkg — MLX reads it) ────────── + +def _np(arr) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32) + + +def _load_safetensors(path: str): + """Return ``(tensors: dict[str, np.ndarray], metadata: dict)``. Refuses pickle.""" + lower = path.lower() + if lower.endswith(_PICKLE_EXTS): + raise LoraError( + f"refusing to load pickle-format adapter {os.path.basename(path)!r} — " + f"only .safetensors adapters are accepted (a .ckpt/.pt is unpickled by " + f"torch.load and can execute arbitrary code)" + ) + if not lower.endswith(".safetensors"): + raise LoraError(f"not a .safetensors adapter: {path!r}") + arrs, meta = mx.load(path, return_metadata=True) + return {k: _np(v) for k, v in arrs.items()}, (meta or {}) + + +# ── SVD bases for -xs adapters (recomputed; mirrors model.py) ────────────────── + +def _canonicalize_svd_signs(U: np.ndarray, Vh: np.ndarray): + """Deterministic sign convention: largest-magnitude element of each U column + is positive (mirrors model._canonicalize_svd_signs).""" + max_abs_idx = np.argmax(np.abs(U), axis=0) + signs = np.sign(U[max_abs_idx, np.arange(U.shape[1])]) + signs[signs == 0] = 1.0 + return U * signs[None, :], Vh * signs[:, None] + + +def _svd_bases(W0: np.ndarray, rank: int): + """Return ``(U[:, :rank], V[:, :rank])`` from the SVD of ``W0`` (fan_out, fan_in), + with V such that ``U @ diag(S) @ V.T`` reconstructs W0 (mirrors model.py).""" + U_full, _S, Vh_full = np.linalg.svd(W0, full_matrices=False) + U_full, Vh_full = _canonicalize_svd_signs(U_full, Vh_full) + U = U_full[:, :rank] + V = Vh_full[:rank, :].T + return U, V + + +# ── per-type merge math (numpy, float32) ─────────────────────────────────────── + +def _merged_weight(W0: np.ndarray, p: dict, adapter_type: str, scaling: float) -> np.ndarray: + """Return the LoRA-merged weight for one layer at full strength (lora_strength=1). + + ``W0`` is (fan_out, fan_in) float32; ``p`` holds the adapter tensors for this + layer. Mirrors the matching ``*_forward`` in model.py. + """ + if adapter_type == "lora": + delta = p["lora_B"] @ p["lora_A"] + return W0 + scaling * delta + + if adapter_type in ("dora-rows", "dora-cols"): + norm_dim = 1 if adapter_type == "dora-rows" else 0 + delta = p["lora_B"] @ p["lora_A"] + V = W0 + scaling * delta + V_hat = V / (np.linalg.norm(V, axis=norm_dim, keepdims=True) + 1e-12) + mag = _mag_2d(p["magnitude"], norm_dim) + return V_hat * mag + + if adapter_type == "bora": + delta = p["lora_B"] @ p["lora_A"] + V = W0 + scaling * delta + V_r = V / (np.linalg.norm(V, axis=1, keepdims=True) + 1e-12) + inter = p["magnitude_r"].reshape(-1, 1) * V_r + H_c = inter / (np.linalg.norm(inter, axis=0, keepdims=True) + 1e-12) + return H_c * p["magnitude_c"].reshape(1, -1) + + if adapter_type.endswith("-xs"): + rank = p["M_xs"].shape[0] + U, V = _svd_bases(W0, rank) + delta = U @ p["M_xs"] @ V.T + Vfull = W0 + scaling * delta + if adapter_type == "lora-xs": + return Vfull + if adapter_type in ("dora-rows-xs", "dora-cols-xs"): + norm_dim = 1 if adapter_type == "dora-rows-xs" else 0 + V_hat = Vfull / (np.linalg.norm(Vfull, axis=norm_dim, keepdims=True) + 1e-12) + mag = _mag_2d(p["magnitude"], norm_dim) + return V_hat * mag + if adapter_type == "bora-xs": + V_r = Vfull / (np.linalg.norm(Vfull, axis=1, keepdims=True) + 1e-12) + inter = p["magnitude_r"].reshape(-1, 1) * V_r + H_c = inter / (np.linalg.norm(inter, axis=0, keepdims=True) + 1e-12) + return H_c * p["magnitude_c"].reshape(1, -1) + + raise LoraError(f"unknown adapter_type {adapter_type!r}") + + +def _check_shapes(layer: str, W0: np.ndarray, p: dict, adapter_type: str) -> None: + """Fail with a clear message when the adapter doesn't fit the base weight — + almost always because the adapter was trained for a different base than + ``--dit`` (e.g. a medium adapter on sm-music). Without this the mismatch + surfaces as a raw numpy broadcasting error deep in the merge.""" + fan_out, fan_in = W0.shape + if adapter_type.endswith("-xs"): + rank = p["M_xs"].shape[0] + if rank > min(fan_out, fan_in): + raise LoraError( + f"{layer}: LoRA-XS rank {rank} exceeds base min-dim " + f"{min(fan_out, fan_in)} for weight {W0.shape} — wrong base for --dit?" + ) + return + b_out, b_rank = p["lora_B"].shape + a_rank, a_in = p["lora_A"].shape + if b_out != fan_out or a_in != fan_in or a_rank != b_rank: + raise LoraError( + f"{layer}: adapter lora_B{p['lora_B'].shape}·lora_A{p['lora_A'].shape} " + f"does not fit base weight {W0.shape} — wrong base for --dit?" + ) + + +def _mag_2d(mag: np.ndarray, norm_dim: int) -> np.ndarray: + """Reshape a (possibly 2D) magnitude vector to broadcast against the weight on + ``norm_dim`` (mirrors `magnitude.unsqueeze(norm_dim)` after a squeeze). + ``atleast_1d`` guards the degenerate (1, 1) case where squeeze yields a + 0-d array (no real DiT layer has a single output, but keep it total).""" + mag = np.atleast_1d(np.squeeze(mag)) + return mag.reshape(-1, 1) if norm_dim == 1 else mag.reshape(1, -1) + + +# ── checkpoint parsing → normalized per-layer adapter ────────────────────────── + +def _layer_to_npz_key(layer: str) -> str: + """Map a checkpoint layer name to its DiT npz weight key. The MLX converter + renames ``to_local_embed.{0,2}`` → ``to_local_embed.seq.{0,2}`` (dit_mlx.py); + every other Linear/Conv1d name passes through unchanged.""" + layer = layer.replace(".to_local_embed.0", ".to_local_embed.seq.0") + layer = layer.replace(".to_local_embed.2", ".to_local_embed.seq.2") + return f"{layer}.weight" + + +def _resolve_path(path: str) -> str: + """Accept a .safetensors file or a PEFT adapter directory (resolve to the + adapter_model.safetensors inside it).""" + if os.path.isdir(path): + cand = os.path.join(path, "adapter_model.safetensors") + if os.path.isfile(cand): + return cand + hits = [f for f in os.listdir(path) if f.lower().endswith(".safetensors")] + if len(hits) == 1: + return os.path.join(path, hits[0]) + raise LoraError( + f"{path!r}: expected one .safetensors adapter in the directory, found {hits}" + ) + return path + + +def _parse_adapter(path: str): + """Load one adapter and return ``(adapter_type, scaling, layers)`` where + ``layers`` maps a checkpoint layer name → its param dict (numpy float32).""" + tensors, meta = _load_safetensors(path) + + native_marker = ".parametrizations.weight.0." + is_native = any(native_marker in k for k in tensors) + + if is_native: + cfg = json.loads(meta.get("lora_config", "{}")) if meta else {} + layers = _group_native(tensors) + rank = int(cfg.get("rank") or _infer_rank(layers)) + alpha = float(cfg.get("alpha", rank)) + adapter_type = _resolve_native_type(cfg.get("adapter_type", "lora")) + scaling = alpha / rank + return adapter_type, scaling, layers + + # PEFT — config lives in a sibling adapter_config.json + peft_marker = ".lora_A.weight" + if any(k.endswith(peft_marker) for k in tensors): + cfg = _read_peft_config(path) + rank = int(cfg["r"]) + alpha = float(cfg.get("lora_alpha", rank)) + use_dora = bool(cfg.get("use_dora", False)) + use_rslora = bool(cfg.get("use_rslora", False)) + adapter_type = "dora-rows" if use_dora else "lora" + scaling = alpha / (np.sqrt(rank) if use_rslora else rank) + layers = _group_peft(tensors) + return adapter_type, scaling, layers + + raise LoraError( + f"{os.path.basename(path)!r}: not a recognised LoRA (no SA3-native " + f"parametrization keys and no PEFT lora_A/lora_B keys)" + ) + + +def _group_native(tensors: dict) -> dict: + marker = ".parametrizations.weight.0." + layers: dict[str, dict] = {} + for k, v in tensors.items(): + if marker not in k: + continue + layer, _, param = k.partition(marker) + layers.setdefault(layer, {})[param] = v + return layers + + +def _group_peft(tensors: dict) -> dict: + prefix = "base_model.model." + layers: dict[str, dict] = {} + for k, v in tensors.items(): + name = k[len(prefix):] if k.startswith(prefix) else k + for suffix, param in ((".lora_A.weight", "lora_A"), + (".lora_B.weight", "lora_B"), + (".lora_magnitude_vector.weight", "magnitude")): + if name.endswith(suffix): + layers.setdefault(name[: -len(suffix)], {})[param] = v + break + return layers + + +def _read_peft_config(path: str) -> dict: + base = os.path.dirname(path) + cfg_path = os.path.join(base, "adapter_config.json") + if not os.path.isfile(cfg_path): + raise LoraError( + f"PEFT adapter at {path!r} is missing its adapter_config.json sibling" + ) + with open(cfg_path) as fh: + return json.load(fh) + + +def _infer_rank(layers: dict) -> int: + for params in layers.values(): + if "lora_A" in params: + return params["lora_A"].shape[0] + if "M_xs" in params: + return params["M_xs"].shape[0] + raise LoraError("cannot infer LoRA rank (no lora_A / M_xs tensors)") + + +def _resolve_native_type(adapter_type: str) -> str: + """Legacy 'dora' → 'dora-rows' (the paper-correct default; mirrors + utils.resolve_adapter_type, minus the 2D-magnitude shape sniff we don't need + because saved magnitudes are 1D).""" + return "dora-rows" if adapter_type == "dora" else adapter_type + + +# ── public entry point ───────────────────────────────────────────────────────── + +def merge_loras_into_weights(weights: dict, lora_paths, strength: float = 1.0, + log=lambda _m: None) -> dict: + """Merge one or more LoRA adapters into ``weights`` in place. + + ``weights`` is the DiT weight dict as loaded from the npz (str → mx.array). + ``strength`` is the application weight applied to every adapter's delta (the + `--lora-strength` knob; matches ``application_weight`` in + ``merge_loras_into_base_model``). Deltas are accumulated against the original + weight, then applied once, so stacking is order-independent for linear LoRA. + + Returns a stats dict ``{"merged": int, "skipped": list[str], "adapters": int}``. + """ + if not lora_paths: + return {"merged": 0, "skipped": [], "adapters": 0} + + parsed = [] + for raw in lora_paths: + path = _resolve_path(raw) + adapter_type, scaling, layers = _parse_adapter(path) + parsed.append((path, adapter_type, scaling, layers)) + log(f"lora: {os.path.basename(path)} — {adapter_type}, " + f"scaling={scaling:.3f}, {len(layers)} target layers") + + # Accumulate deltas per npz key against the *original* weight. Each entry is + # [summed_delta, restore] — the layout restorer is the same across repeats. + accum: dict[str, list] = {} + skipped: list[str] = [] + for path, adapter_type, scaling, layers in parsed: + need = _PARAMS_FOR.get(adapter_type, ()) + for layer, params in layers.items(): + key = _layer_to_npz_key(layer) + if key not in weights: + skipped.append(layer) + continue + missing = [n for n in need if n not in params] + if missing: + raise LoraError(f"{layer}: adapter is {adapter_type} but missing {missing}") + W0, restore = _weight_as_2d(weights[key]) + _check_shapes(layer, W0, params, adapter_type) + merged = _merged_weight(W0, params, adapter_type, scaling) + delta = strength * (merged - W0) + if key in accum: + accum[key][0] += delta + else: + accum[key] = [delta, restore] + + for key, (delta, restore) in accum.items(): + W0, _ = _weight_as_2d(weights[key]) + weights[key] = mx.array(restore(W0 + delta)) + + if skipped: + log(f"lora: skipped {len(skipped)} layer(s) not in this DiT " + f"(e.g. {skipped[0]})") + if not accum: + log("lora: WARNING — merged 0 layers; the adapter targets nothing in this " + "DiT (wrong base for --dit, or unsupported target modules)") + return {"merged": len(accum), "skipped": skipped, "adapters": len(parsed)} + + +def _weight_as_2d(arr): + """Return ``(W2d, restore)`` where ``W2d`` is the PyTorch-layout 2D weight + (fan_out, fan_in) as numpy float32, and ``restore(W2d)`` rebuilds the MLX + layout. Linear weights are already 2D == PyTorch layout; Conv1d weights are + stored MLX-style (out, k, in) and round-trip through PyTorch (out, in, k).""" + np_arr = _np(arr) + if np_arr.ndim == 2: + return np_arr, lambda w: w.astype(np.float32) + if np_arr.ndim == 3: + out, k, cin = np_arr.shape + w2d = np_arr.transpose(0, 2, 1).reshape(out, cin * k) # (out, in*k), PyTorch order + + def restore(w): + return w.reshape(out, cin, k).transpose(0, 2, 1).astype(np.float32) + + return w2d, restore + raise LoraError(f"unexpected weight rank {np_arr.ndim} for a LoRA target") diff --git a/optimized/mlx/scripts/sa3_mlx.py b/optimized/mlx/scripts/sa3_mlx.py index c8893b6..69db172 100644 --- a/optimized/mlx/scripts/sa3_mlx.py +++ b/optimized/mlx/scripts/sa3_mlx.py @@ -170,13 +170,18 @@ def prompt_user_if_missing(args): return args -def load_dit(dit_name: str, T_lat: int, dtype): +def load_dit(dit_name: str, T_lat: int, dtype, lora_paths=None, lora_strength=1.0): cfg = DIT_CHOICES[dit_name] ckpt = ensure_local(cfg["ckpt"]) import importlib, io, contextlib mod = importlib.import_module(cfg["loader"]) + # The loader's own chatter is swallowed, but the LoRA merge summary is routed + # to stderr (not redirected) so it stays visible to callers that capture it. + lora_log = lambda m: print(m, file=sys.stderr) with contextlib.redirect_stdout(io.StringIO()): - model = mod.load_dit(str(ckpt), T_lat=T_lat, dtype=dtype, compile_=False) + model = mod.load_dit(str(ckpt), T_lat=T_lat, dtype=dtype, compile_=False, + lora_paths=lora_paths, lora_strength=lora_strength, + lora_log=lora_log) return model, str(ckpt) @@ -387,6 +392,16 @@ def main(): help="Path to the bundled T5Gemma FP16 .npz (weights + tokenizer). " "Default points at models/mlx/t5gemma_f16.npz next to this script; " "auto-downloaded from HuggingFace if not present.") + ap.add_argument("--lora", nargs="+", default=None, metavar="ADAPTER", + help="One or more LoRA adapters to merge into the DiT at load time. " + "Each is a .safetensors file (SA3-native train_lora.py output) or a " + "PEFT adapter directory/.safetensors (with its adapter_config.json). " + "ONLY .safetensors is accepted — a pickle .ckpt/.pt is refused (it " + "would execute code on load). The adapter's base must match --dit " + "(e.g. a medium adapter with --dit medium).") + ap.add_argument("--lora-strength", type=float, default=1.0, + help="Application weight for every --lora delta (default 1.0). 0 disables " + "the adapter (bit-identical to no LoRA); >1 amplifies it.") # ── Sampling ────────────────────────────────────────────────────────────── ap.add_argument("--seconds", type=float, default=30.0, @@ -625,7 +640,11 @@ def main(): # ── 3b. DiT pingpong sampling ── stage("[3/5]", f"DiT — load + sample ({args.steps} steps, σmax={sigma_max:.2f})") - t0 = time.time(); dit_model, _ = load_dit(args.dit, T_lat=T_lat, dtype=dtype) + if args.lora: + sub(f"lora {', '.join(os.path.basename(p.rstrip('/')) for p in args.lora)} " + f"(strength {args.lora_strength:g})") + t0 = time.time(); dit_model, _ = load_dit(args.dit, T_lat=T_lat, dtype=dtype, + lora_paths=args.lora, lora_strength=args.lora_strength) _stage_peak_b('DiT load') sub(f"load {time.time()-t0:.1f}s") diff --git a/optimized/mlx/scripts/test_lora_merge.py b/optimized/mlx/scripts/test_lora_merge.py new file mode 100644 index 0000000..de2135c --- /dev/null +++ b/optimized/mlx/scripts/test_lora_merge.py @@ -0,0 +1,217 @@ +"""Unit tests for models/defs/lora_merge.py — the LoRA merge-at-load math. + +Pure-numpy/MLX, no model weights or torch needed. Run either way: + + python scripts/test_lora_merge.py # standalone (no pytest needed) + pytest scripts/test_lora_merge.py + +Correctness is established by: + * the zero-init -> identity invariant for every adapter type (a strong guard on + axis / reshape / transpose bugs: with B/M_xs zeroed and magnitudes at their + init values, every forward must return W0); + * an exact independent reconstruction for standard LoRA and PEFT; + * the Conv1d (out,k,in) <-> PyTorch (out,in,k) layout round-trip; + * the to_local_embed name remap; + * --lora-strength scaling and the bit-exact strength-0 bypass; + * the trust boundary (pickle refusal) and base-mismatch / 0-merge handling. +""" +import json +import os +import sys +import tempfile + +import numpy as np +import mlx.core as mx + +# Import lora_merge the same way the CLI does (REPO = optimized/mlx on sys.path). +REPO = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, REPO) +from models.defs import lora_merge as lm # noqa: E402 + +rng = np.random.default_rng(0) + + +def _save_native(path, layer, params, cfg): + d = {f"{layer}.parametrizations.weight.0.{k}": mx.array(v.astype(np.float32)) + for k, v in params.items()} + mx.save_safetensors(path, d, metadata={"lora_config": json.dumps(cfg)}) + + +def _init_params(adapter_type, W0, rank, nonzero): + """Init so a zeroed core -> identity; if nonzero, give a real core.""" + fo, fi = W0.shape + p = {} + if not adapter_type.endswith("-xs"): + p["lora_A"] = rng.standard_normal((rank, fi)).astype(np.float32) + p["lora_B"] = (rng.standard_normal((fo, rank)) if nonzero + else np.zeros((fo, rank))).astype(np.float32) + else: + p["M_xs"] = (rng.standard_normal((rank, rank)) if nonzero + else np.zeros((rank, rank))).astype(np.float32) + if "magnitude" in lm._PARAMS_FOR[adapter_type]: + nd = 1 if "rows" in adapter_type else 0 + p["magnitude"] = np.linalg.norm(W0, axis=nd).astype(np.float32) + if "magnitude_r" in lm._PARAMS_FOR[adapter_type]: + p["magnitude_r"] = np.linalg.norm(W0, axis=1).astype(np.float32) + p["magnitude_c"] = np.linalg.norm(W0, axis=0).astype(np.float32) + return p + + +def _np(arr): + return np.array(arr.astype(mx.float32)) + + +def test_trust_boundary_refuses_pickle(): + for bad in ("evil.ckpt", "evil.pt", "evil.bin", "weights.npz"): + try: + lm._load_safetensors(bad) + assert False, f"should have refused {bad}" + except lm.LoraError: + pass + + +def test_all_types_zero_init_identity_and_nonzero_delta(): + W0 = rng.standard_normal((6, 8)).astype(np.float32) + with tempfile.TemporaryDirectory() as tmp: + for atype in lm._PARAMS_FOR: + path = os.path.join(tmp, f"{atype}.safetensors") + # zero-init -> identity + _save_native(path, "foo", _init_params(atype, W0, 4, nonzero=False), + {"adapter_type": atype, "rank": 4, "alpha": 4}) + w = {"foo.weight": mx.array(W0)} + lm.merge_loras_into_weights(w, [path]) + assert np.allclose(_np(w["foo.weight"]), W0, atol=1e-4), f"{atype} identity" + # nonzero core -> finite, changed + _save_native(path, "foo", _init_params(atype, W0, 4, nonzero=True), + {"adapter_type": atype, "rank": 4, "alpha": 4}) + w = {"foo.weight": mx.array(W0)} + lm.merge_loras_into_weights(w, [path]) + got = _np(w["foo.weight"]) + assert np.isfinite(got).all() and not np.allclose(got, W0), f"{atype} delta" + + +def test_standard_lora_exact_and_strength(): + W0 = rng.standard_normal((6, 8)).astype(np.float32) + A = rng.standard_normal((4, 8)).astype(np.float32) + B = rng.standard_normal((6, 4)).astype(np.float32) + expect = W0 + (8.0 / 4.0) * (B @ A) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "std.safetensors") + _save_native(path, "foo", {"lora_A": A, "lora_B": B}, + {"adapter_type": "lora", "rank": 4, "alpha": 8}) + w = {"foo.weight": mx.array(W0)} + lm.merge_loras_into_weights(w, [path]) + assert np.allclose(_np(w["foo.weight"]), expect, atol=1e-4) + # strength 0.5 halves the delta; strength 0 is the original + w = {"foo.weight": mx.array(W0)} + lm.merge_loras_into_weights(w, [path], strength=0.5) + assert np.allclose(_np(w["foo.weight"]), W0 + 0.5 * (8.0 / 4.0) * (B @ A), atol=1e-4) + w = {"foo.weight": mx.array(W0)} + lm.merge_loras_into_weights(w, [path], strength=0.0) + assert np.allclose(_np(w["foo.weight"]), W0, atol=1e-6) + + +def test_conv1d_layout_round_trip(): + out_c, k_c, in_c = 4, 1, 3 + Wc = rng.standard_normal((out_c, k_c, in_c)).astype(np.float32) # MLX layout + Ac = rng.standard_normal((2, in_c * k_c)).astype(np.float32) + Bc = rng.standard_normal((out_c, 2)).astype(np.float32) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "conv.safetensors") + _save_native(path, "conv", {"lora_A": Ac, "lora_B": Bc}, + {"adapter_type": "lora", "rank": 2, "alpha": 2}) + w = {"conv.weight": mx.array(Wc)} + lm.merge_loras_into_weights(w, [path]) + got = _np(w["conv.weight"]) + expect = Wc + (Bc @ Ac).reshape(out_c, in_c, k_c).transpose(0, 2, 1) + assert got.shape == Wc.shape and np.allclose(got, expect, atol=1e-4) + + +def test_to_local_embed_remap(): + Wt = rng.standard_normal((5, 5)).astype(np.float32) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "tle.safetensors") + _save_native(path, "transformer.layers.0.to_local_embed.0", + {"lora_A": rng.standard_normal((2, 5)).astype(np.float32), + "lora_B": rng.standard_normal((5, 2)).astype(np.float32)}, + {"adapter_type": "lora", "rank": 2, "alpha": 2}) + w = {"transformer.layers.0.to_local_embed.seq.0.weight": mx.array(Wt)} + stats = lm.merge_loras_into_weights(w, [path]) + assert stats["merged"] == 1 and not stats["skipped"] + + +def test_peft_format_exact(): + W0 = rng.standard_normal((6, 8)).astype(np.float32) + A = rng.standard_normal((4, 8)).astype(np.float32) + B = rng.standard_normal((6, 4)).astype(np.float32) + with tempfile.TemporaryDirectory() as tmp: + with open(os.path.join(tmp, "adapter_config.json"), "w") as fh: + json.dump({"peft_type": "LORA", "r": 4, "lora_alpha": 8, "use_dora": False}, fh) + mx.save_safetensors( + os.path.join(tmp, "adapter_model.safetensors"), + {"base_model.model.foo.lora_A.weight": mx.array(A), + "base_model.model.foo.lora_B.weight": mx.array(B)}, + metadata={"format": "pt"}) + w = {"foo.weight": mx.array(W0)} + stats = lm.merge_loras_into_weights(w, [tmp]) # directory input + assert stats["merged"] == 1 + assert np.allclose(_np(w["foo.weight"]), W0 + (8.0 / 4.0) * (B @ A), atol=1e-4) + + +def test_unknown_layers_skipped_base_untouched(): + W0 = rng.standard_normal((6, 8)).astype(np.float32) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "std.safetensors") + _save_native(path, "foo", + {"lora_A": rng.standard_normal((4, 8)).astype(np.float32), + "lora_B": rng.standard_normal((6, 4)).astype(np.float32)}, + {"adapter_type": "lora", "rank": 4, "alpha": 4}) + w = {"other.weight": mx.array(W0)} + msgs = [] + stats = lm.merge_loras_into_weights(w, [path], log=msgs.append) + assert stats["merged"] == 0 and stats["skipped"] == ["foo"] + assert np.allclose(_np(w["other.weight"]), W0) + # a 0-merge run must warn, not look like a successful no-op + assert any("WARNING" in m and "0 layers" in m for m in msgs) + + +def test_base_mismatch_raises_clear_error(): + # base is (6, 8) but the adapter was trained for an (10, 8) layer + W0 = rng.standard_normal((6, 8)).astype(np.float32) + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "mismatch.safetensors") + _save_native(path, "foo", + {"lora_A": rng.standard_normal((4, 8)).astype(np.float32), + "lora_B": rng.standard_normal((10, 4)).astype(np.float32)}, + {"adapter_type": "lora", "rank": 4, "alpha": 4}) + w = {"foo.weight": mx.array(W0)} + try: + lm.merge_loras_into_weights(w, [path]) + assert False, "should have raised on base mismatch" + except lm.LoraError as e: + assert "foo" in str(e) and "base" in str(e).lower() + + +def test_svd_bases_orthonormal(): + Wsvd = rng.standard_normal((6, 8)).astype(np.float32) + U, V = lm._svd_bases(Wsvd, rank=6) + assert np.allclose(U.T @ U, np.eye(6), atol=1e-4) + assert np.allclose(V.T @ V, np.eye(6), atol=1e-4) + + +def _main(): + tests = [v for k, v in sorted(globals().items()) if k.startswith("test_")] + failed = [] + for t in tests: + try: + t() + print(f" ok {t.__name__}") + except Exception as e: # noqa: BLE001 + print(f" FAIL {t.__name__}: {e}") + failed.append(t.__name__) + print("\n" + ("ALL PASS" if not failed else f"FAILURES: {failed}")) + return 1 if failed else 0 + + +if __name__ == "__main__": + sys.exit(_main())