Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions whisper/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@
}


def _get_model_variant(name_or_path: str) -> str | None:
"""Extract model variant for alignment heads lookup."""
if name_or_path in _ALIGNMENT_HEADS:
return name_or_path

# Extract from repo name like "openai/whisper-large-v3"
name = name_or_path.split("/")[-1]
if name.startswith("whisper-"):
return name[8:] # Remove "whisper-" prefix

return None


def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True)

Expand Down Expand Up @@ -156,10 +169,11 @@ def load_torch_weights_and_config(
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

# todo: accept alignment_heads of local Pytorch checkpoint
alignment_heads = None
# Look up alignment heads using normalized variant name
variant = _get_model_variant(name_or_path)
alignment_heads = _ALIGNMENT_HEADS.get(variant) if variant else None

if name_or_path in _MODELS:
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).exists():
# Try downloading from HF
Expand Down