Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion tunix/generate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,16 +949,42 @@ def intersect_trees(
candidate_b.pop(match_index)

found_candidate = None
scan_idx = layer_idx

for cand in [tuple(candidate_a), tuple(candidate_b)]:
if cand in src_flat:
found_candidate = cand
break

# Candidate C: Inhomogeneous Scan (GPT-OSS 120B / Interleaved Layers)
if not found_candidate:
# Check if layers_0 exists, which implies interleaved/inhomogeneous scanning
candidate_c_probe = list(key_tuple)
candidate_c_probe[match_index] = 'layers_0'
if tuple(candidate_c_probe) in src_flat:
# Determine cycle length by probing layers_0, layers_1, ...
cycle = 0
while True:
candidate_c_probe[match_index] = f'layers_{cycle}'
if tuple(candidate_c_probe) in src_flat:
cycle += 1
else:
break

if cycle > 0:
src_layer_idx = layer_idx % cycle
scan_idx = layer_idx // cycle

candidate_c = list(key_tuple)
candidate_c[match_index] = f'layers_{src_layer_idx}'
if tuple(candidate_c) in src_flat:
found_candidate = tuple(candidate_c)

if found_candidate:
src_val = src_flat[found_candidate]
# Slice the scanned parameter
sliced_val = _slice_scanned_param(
src_val, tgt_val, layer_idx, str(key_tuple)
src_val, tgt_val, scan_idx, str(key_tuple)
)
sliced_val = _apply_dtype_cast(
sliced_val, tgt_val.dtype, str(key_tuple)
Expand Down
3 changes: 2 additions & 1 deletion tunix/rl/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import jaxtyping
from flax import nnx
from tunix.rl import utils
from jax._src.sharding_impls import GSPMDSharding as GSPMDSharding

# TODO(tsbao): move this to util
def callback_on_ready(
Expand Down Expand Up @@ -449,7 +450,7 @@ def reshard_pytree(

def _get_dst_sharding(x):
if isinstance(
x, jax.sharding.NamedSharding | jax.sharding.SingleDeviceSharding
x, jax.sharding.NamedSharding | jax.sharding.SingleDeviceSharding | GSPMDSharding
):
Comment on lines 452 to 454
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using GSPMDSharding requires importing from a private JAX _src module (line 31), which is risky as these internal APIs can change without warning. This violates the maintainability principle from the style guide (line 12).

A more robust and future-proof approach is to check against the public abstract base class jax.sharding.Sharding. This will cover all sharding types (NamedSharding, SingleDeviceSharding, GSPMDSharding, etc.) and allows you to remove the import from jax._src.sharding_impls.

    if isinstance(x, jax.sharding.Sharding):
References
  1. The style guide states that code should be maintainable and easy to modify and extend. Relying on private, unstable APIs from dependencies makes the code harder to maintain as it can break with dependency updates. (link)

return x
else:
Expand Down