Assume I have multiple graphs, each with nodes, senders, receivers, and a graph label. How would one create batches that work with `jax.vmap`?
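A minimal sketch of one common approach. `vmap` needs every input to share the same shape, so each graph is padded to a common node and edge count, validity masks are recorded, and everything is stacked along a new leading axis. It assumes each graph is a dict with hypothetical keys `nodes` (shape `[n_nodes, n_features]`), `senders`, `receivers` and `label`. `pad_graph`, `batch_graphs` and `per_graph` are illustrative names, and the one-step sum message passing is only a stand-in model.

```python
import numpy as np
import jax
import jax.numpy as jnp


def pad_graph(graph, max_nodes, max_edges):
    """Pad one graph (dict with hypothetical keys) to fixed sizes and add masks."""
    nodes = np.asarray(graph["nodes"], dtype=np.float32)      # [n_nodes, n_features]
    senders = np.asarray(graph["senders"], dtype=np.int32)    # [n_edges]
    receivers = np.asarray(graph["receivers"], dtype=np.int32)
    n_nodes, n_feat = nodes.shape
    n_edges = senders.shape[0]
    padded_nodes = np.zeros((max_nodes, n_feat), dtype=np.float32)
    padded_nodes[:n_nodes] = nodes
    # Padding edges point at node 0; the edge mask zeroes out their messages.
    padded_senders = np.zeros(max_edges, dtype=np.int32)
    padded_senders[:n_edges] = senders
    padded_receivers = np.zeros(max_edges, dtype=np.int32)
    padded_receivers[:n_edges] = receivers
    return dict(
        nodes=padded_nodes,
        senders=padded_senders,
        receivers=padded_receivers,
        node_mask=np.arange(max_nodes) < n_nodes,   # True for real nodes
        edge_mask=np.arange(max_edges) < n_edges,   # True for real edges
        label=np.asarray(graph["label"]),
    )


def batch_graphs(graphs):
    """Pad every graph to the largest size in the list and stack along axis 0."""
    max_nodes = max(len(g["nodes"]) for g in graphs)
    max_edges = max(len(g["senders"]) for g in graphs)
    padded = [pad_graph(g, max_nodes, max_edges) for g in graphs]
    return {k: jnp.stack([p[k] for p in padded]) for k in padded[0]}


def per_graph(g):
    """Stand-in model for ONE graph: one round of sum message passing + masked readout."""
    messages = g["nodes"][g["senders"]] * g["edge_mask"][:, None]
    aggregated = jax.ops.segment_sum(
        messages, g["receivers"], num_segments=g["nodes"].shape[0]
    )
    updated = g["nodes"] + aggregated
    return jnp.sum(updated * g["node_mask"][:, None], axis=0)  # [n_features]


# vmap maps per_graph over the leading (batch) axis of every array in the dict.
batched_fn = jax.vmap(per_graph)

graphs = [
    {"nodes": [[1.0], [2.0]], "senders": [0], "receivers": [1], "label": 0},
    {"nodes": [[1.0], [1.0], [1.0]], "senders": [0, 1], "receivers": [1, 2], "label": 1},
]
batch = batch_graphs(graphs)
out = batched_fn(batch)  # [[4.], [5.]]
```

Padding to the largest graph in each batch means a new maximum size produces a new shape, which triggers recompilation under `jax.jit`; padding to a few fixed bucket sizes avoids that. An alternative that avoids `vmap` altogether is to concatenate all graphs into one large disconnected graph (the approach taken by `jraph.batch` and `jraph.pad_with_graphs`).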