diff --git a/tests/tx/models/test_qwen3.py b/tests/tx/models/test_qwen3.py index 7028fc2357..2b401c11b9 100644 --- a/tests/tx/models/test_qwen3.py +++ b/tests/tx/models/test_qwen3.py @@ -65,6 +65,32 @@ def load_moe_base_weights(jax_moe_layer: Qwen3MoeSparseMoeBlock, hf_moe_layer: H ) +def assert_allclose_mixed_scale( + actual, + expected, + *, + rtol: float = 1e-3, + base_atol: float = 1e-3, + scale_atol: float = 1e-6, + err_msg: str = "", +) -> None: + actual_arr = np.asarray(actual) + expected_arr = np.asarray(expected) + if not np.isfinite(actual_arr).all() or not np.isfinite(expected_arr).all(): + raise AssertionError(f"{err_msg}\nNon-finite values found in MoE LoRA comparison.") + + max_abs = max(float(np.max(np.abs(actual_arr))), float(np.max(np.abs(expected_arr))), 1.0) + atol = max(base_atol, scale_atol * max_abs) + np.testing.assert_allclose( + actual_arr, + expected_arr, + rtol=rtol, + atol=atol, + equal_nan=False, + err_msg=f"{err_msg}\nmax_abs={max_abs:.6g}, atol={atol:.6g}", + ) + + @pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)]) def test_qwen3_moe_layer(ep: int, tp: int): model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM" @@ -130,7 +156,8 @@ def test_qwen3_moe_layer_lora(ep: int, tp: int): config = Qwen3Config(base_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True) hf_moe_layer = hf_model.model.layers[0].mlp - x = torch.randn(3, 4, config.hidden_size) + torch_rng = torch.Generator().manual_seed(0) + x = torch.randn(3, 4, config.hidden_size, generator=torch_rng) mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) with jax.set_mesh(mesh): @@ -177,7 +204,14 @@ def test_qwen3_moe_layer_lora(ep: int, tp: int): x_sample = x[sample_idx : sample_idx + 1].numpy() output_merged, _ = moe_layer_merged(x_sample, return_router_logits=True) - assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3) + assert_allclose_mixed_scale( + output_with_lora[sample_idx : sample_idx + 1], + output_merged, + rtol=1e-3, + base_atol=1e-3, + scale_atol=1e-6, + err_msg=f"MoE LoRA merged-weight mismatch for sample_idx={sample_idx}, adapter_idx={adapter_idx}", + ) def test_qwen3_lora():