Skip to content

Conversation

@copybara-service
Copy link

@copybara-service copybara-service bot commented Jan 15, 2026

Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch due to split_rngs being replicated.

@copybara-service copybara-service bot changed the title Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch due to split_rngs being replicated. Jan 15, 2026
…d to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated.

PiperOrigin-RevId: 857821732
@copybara-service copybara-service bot merged commit 7551bfc into main Jan 18, 2026
5 of 6 checks passed
@copybara-service copybara-service bot deleted the test_856253490 branch January 18, 2026 13:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant