Skip to content

Avoid MPS bfloat16 cross kernel in rot6d conversion#23

Open
lyonsno wants to merge 1 commit into
warmshao:mainfrom
lyonsno:codex/bf16-manual-cross-clean-0530
Open

Avoid MPS bfloat16 cross kernel in rot6d conversion#23
lyonsno wants to merge 1 commit into
warmshao:mainfrom
lyonsno:codex/bf16-manual-cross-clean-0530

Conversation

@lyonsno
Copy link
Copy Markdown

@lyonsno lyonsno commented May 30, 2026

Summary

  • replace torch.linalg.cross in rot6d_to_rotmat with the equivalent elementwise 3D cross product
  • add focused rot6d dtype/device coverage, including an MPS bfloat16 regression test when MPS is available
  • add a small CPU autograd smoke for the rotation conversion path

Why

On MPS with bfloat16 and fallback disabled, torch.linalg.cross can fail with RuntimeError: Failed to create function state object for: cross_bfloat. The explicit cross-product formula uses basic tensor operations, preserves dtype/device, and keeps the operation on-device for MPS bfloat16.

This is separate from #22 and only addresses the cross_bfloat compatibility failure. It is not intended as a performance optimization.

Tests

  • PYTORCH_ENABLE_MPS_FALLBACK=0 uv run --with pytest --python .venv/bin/python pytest tests/test_rot6d_mps_bf16.py -q
  • PYTORCH_ENABLE_MPS_FALLBACK=0 .venv/bin/python - <<'PY' ... rot6d_to_rotmat(..., device='mps', dtype=torch.bfloat16) ... PY
  • .venv/bin/python -m py_compile wilor_mini/models/vit.py tests/test_rot6d_mps_bf16.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant