Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions export/orbax/export/modules/obm_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,36 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
with self.subTest('test_weights_b_dtype'):
self.assertEqual(module.model_params['b'].dtype, expected_dtype)

def test_obm_module_gpu_xla_flags_integration_stable(self):
  """Checks that CUDA XLA flags configured on Jax2ObmOptions surface in the module's per-platform compile options."""
  weights_spec = jax.ShapeDtypeStruct(
      shape=(2, 5), dtype=jnp.dtype(jnp.float32)
  )
  function_name = 'simple_add'

  options = obm_configs.Jax2ObmOptions(
      checkpoint_path='checkpoint_path',
      native_serialization_platforms=('cuda',),
      xla_flags_per_platform={
          'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
      },
  )

  module = obm_module.ObmModule(
      params=weights_spec,
      apply_fn={function_name: simple_add},
      jax2obm_options=options,
  )

  options_map = module.xla_compile_options_per_platform
  self.assertIsNotNone(options_map)
  cuda_build_options = options_map.map['cuda']
  self.assertIn(
      'xla_gpu_enable_latency_hiding_scheduler',
      cuda_build_options.env_option_overrides,
  )


class GetSharedValueTest(parameterized.TestCase):

Expand Down
92 changes: 85 additions & 7 deletions model/orbax/experimental/model/core/python/compile_options_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .google.protobuf import any_pb2
from .platforms.xla.service.jellyfish import tpu_compilation_environment_pb2 as tpu_comp_env_pb2
from .platforms.xla.service.jellyfish.python import tpu_compilation_environment as tpu_comp_env
from .third_party.neptune.model._src.core import xla_gpu_flag_validation
from tensorflow.compiler.xla import xla_data_pb2 # pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.xla import xla_pb2 # pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.xla.pjrt.proto import compile_options_pb2 # pylint: disable=g-direct-tensorflow-import
Expand Down Expand Up @@ -165,6 +166,10 @@ def generate_xla_compile_options(
tpu_platform_name = manifest_pb2.Platform.Name(
manifest_pb2.Platform.TPU
).lower()
cuda_platform_name = manifest_pb2.Platform.Name(
manifest_pb2.Platform.CUDA
).lower()

compile_options_map = manifest_pb2.CompileOptionsProtoMap()
if native_serialization_platforms is None:
# If no native serialization platforms are specified, we will set the
Expand Down Expand Up @@ -195,24 +200,97 @@ def generate_xla_compile_options(
)

for platform in platforms:
if platform.lower() == tpu_platform_name:
if xla_flags_per_platform:
xla_flags_overrides = xla_flags_per_platform.get(platform, None)
if xla_flags_per_platform:
xla_flags_overrides = xla_flags_per_platform.get(platform, None)
if xla_flags_overrides:
_validate_xla_flags_setting(xla_flags_overrides, persist_xla_flags)
else:
xla_flags_overrides = None
else:
xla_flags_overrides = None

platform_lower = platform.lower()
if platform_lower == tpu_platform_name:
compile_environment = _generate_tpu_compilation_env(xla_flags_overrides)
elif platform_lower == cuda_platform_name:
# GPU Trick: Empty proto to bypass 'None' check and enable jax_mesh
# serialization.
compile_environment = xla_pb2.CompilationEnvironmentsProto()
else:
# CPU Path: Leave as None to preserve legacy portable execution behavior.
compile_environment = None
compile_options_map.map[platform.lower()].CopyFrom(
_generate_compilation_options(compile_environment, jax_mesh)

compile_options = _generate_compilation_options(
compile_environment, jax_mesh
)

# Inject env_option_overrides natively for GPU using a dedicated helper.
if platform_lower == cuda_platform_name and xla_flags_overrides:
_apply_gpu_compilation_env_options(compile_options, xla_flags_overrides)

compile_options_map.map[platform_lower].CopyFrom(compile_options)

if not persist_xla_flags:
for compile_options in compile_options_map.map.values():
compile_options.executable_build_options.comp_envs.Clear()
return compile_options_map


def _apply_gpu_compilation_env_options(
    compile_options: compile_options_pb2.CompileOptionsProto,
    xla_flags_overrides: Sequence[str],
) -> None:
  """Applies XLA flag overrides generically for GPU platforms.

  Args:
    compile_options: The compilation options proto to be modified.
    xla_flags_overrides: A sequence of XLA flags to apply as option overrides.
  """
  # Parse once, then copy each resulting override proto into the map field.
  parsed = _parse_env_option_overrides(xla_flags_overrides)
  for name, override in parsed.items():
    compile_options.env_option_overrides[name].CopyFrom(override)


def _parse_env_option_overrides(
    xla_flags: Sequence[str],
) -> dict[str, compile_options_pb2.OptionOverrideProto]:
  """Parses a list of XLA flags into a dictionary of OptionOverrideProto.

  Each flag must have the form '--key=value'. The value's type is inferred in
  order: bool ('true'/'false', case-insensitive), int (optionally negative),
  float, and finally a plain string.

  Args:
    xla_flags: XLA flags of the form '--key=value'.

  Returns:
    A dict mapping flag names (without the '--' prefix) to populated
    OptionOverrideProto messages.

  Raises:
    ValueError: If a flag does not start with '--', lacks an '=' separator,
      or is rejected by the GPU flag validator.
  """
  overrides = {}
  for flag in xla_flags:
    if not flag.startswith('--'):
      raise ValueError(f"Flag {flag} must start with '--'")

    try:
      # Use the C++ ValidateXlaGPUFlag logic to ensure consistent policy
      # enforcement across Python and C++ layers.
      # The C++ function expects the flag with the '--' prefix.
      xla_gpu_flag_validation.validate_xla_gpu_flag(flag, strict=True)
    except Exception as e:
      # pybind11_abseil appends the status code name to the exception string.
      # Remove it to match exactly what users would see from the C++ binaries.
      # (removesuffix is a no-op when the suffix is absent.)
      err_msg = str(e).removesuffix(' [INVALID_ARGUMENT]')
      raise ValueError(err_msg) from e

    key, sep, value = flag[2:].partition('=')
    if not sep:
      # Fail with a clear message instead of the opaque "not enough values to
      # unpack" error a bare split-and-unpack would produce.
      raise ValueError(f"Flag {flag} must be of the form '--key=value'")
    override_proto = compile_options_pb2.OptionOverrideProto()

    # Infer type (True/False/Int/Float/String)
    if value.lower() == 'true':
      override_proto.bool_field = True
    elif value.lower() == 'false':
      override_proto.bool_field = False
    elif value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
      override_proto.int_field = int(value)
    else:
      try:
        override_proto.double_field = float(value)
      except ValueError:
        override_proto.string_field = value

    overrides[key] = override_proto
  return overrides


def _validate_xla_flags_setting(
xla_flags_overrides: Sequence[str] | None, persist_xla_flags: bool
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,121 @@ def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self):
persist_xla_flags=False,
)

def test_generate_xla_compile_options_env_overrides(self):
  """Verifies CUDA flags land in env_option_overrides with inferred types."""
  result = compile_options_util.generate_xla_compile_options(
      native_serialization_platforms=['cuda'],
      xla_flags_per_platform={
          'cuda': [
              '--xla_gpu_enable_latency_hiding_scheduler=true',
              '--xla_gpu_autotune_level=0',
          ]
      },
      persist_xla_flags=True,
  )
  self.assertIn('cuda', result.map)
  overrides = result.map['cuda'].env_option_overrides

  self.assertIn('xla_gpu_enable_latency_hiding_scheduler', overrides)
  self.assertTrue(
      overrides['xla_gpu_enable_latency_hiding_scheduler'].bool_field
  )

  self.assertIn('xla_gpu_autotune_level', overrides)
  self.assertEqual(overrides['xla_gpu_autotune_level'].int_field, 0)

def test_generate_xla_compile_options_gpu_flags_experimental_rejection(self):
  """An unsupported experimental GPU flag must raise an explanatory error."""
  expected_error = (
      r'XLA GPU compilation flag --xla_gpu_experimental_flag=true is not'
      r' supported. Please check field description at'
      r' CompilationConfig::xla_gpu_flags'
  )
  with self.assertRaisesRegex(ValueError, expected_error):
    compile_options_util.generate_xla_compile_options(
        native_serialization_platforms=['cuda'],
        xla_flags_per_platform={'cuda': ['--xla_gpu_experimental_flag=true']},
        persist_xla_flags=True,
    )

# Covers the value-type inference of _parse_env_option_overrides: booleans
# (any case), positive/negative ints, positive/negative floats, and plain
# strings. Each case names the oneof field expected to be populated.
@parameterized.named_parameters(
    dict(
        testcase_name='bool_true',
        flag='--xla_gpu_enable_latency_hiding_scheduler=true',
        expected_key='xla_gpu_enable_latency_hiding_scheduler',
        expected_field='bool_field',
        expected_value=True,
    ),
    dict(
        testcase_name='bool_false',
        flag='--xla_gpu_enable_latency_hiding_scheduler=false',
        expected_key='xla_gpu_enable_latency_hiding_scheduler',
        expected_field='bool_field',
        expected_value=False,
    ),
    dict(
        testcase_name='bool_uppercase_true',
        flag='--xla_gpu_enable_latency_hiding_scheduler=TRUE',
        expected_key='xla_gpu_enable_latency_hiding_scheduler',
        expected_field='bool_field',
        expected_value=True,
    ),
    dict(
        testcase_name='int_positive',
        flag='--xla_gpu_autotune_level=4',
        expected_key='xla_gpu_autotune_level',
        expected_field='int_field',
        expected_value=4,
    ),
    dict(
        testcase_name='int_negative',
        flag='--xla_gpu_nccl_termination_timeout_seconds=-1',
        expected_key='xla_gpu_nccl_termination_timeout_seconds',
        expected_field='int_field',
        expected_value=-1,
    ),
    dict(
        testcase_name='float_positive',
        flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=1.5',
        expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
        expected_field='double_field',
        expected_value=1.5,
    ),
    dict(
        testcase_name='float_negative',
        flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=-0.5',
        expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
        expected_field='double_field',
        expected_value=-0.5,
    ),
    dict(
        testcase_name='string_value',
        flag='--xla_gpu_cuda_data_dir=/usr/local/cuda',
        expected_key='xla_gpu_cuda_data_dir',
        expected_field='string_field',
        expected_value='/usr/local/cuda',
    ),
)
# Patch out the pybind GPU-flag validator so type inference is exercised even
# for flags the strict policy might reject.
@mock.patch.object(
    compile_options_util.xla_gpu_flag_validation, 'validate_xla_gpu_flag'
)
def test_generate_xla_compile_options_gpu_flags_type_inference(
    self, mock_validate, flag, expected_key, expected_field, expected_value
):
  """Checks that each flag value populates the expected oneof field."""
  del mock_validate  # Unused, just patching for bypass
  compile_options_map = compile_options_util.generate_xla_compile_options(
      native_serialization_platforms=['cuda'],
      xla_flags_per_platform={'cuda': [flag]},
      persist_xla_flags=True,
  )
  self.assertIsNotNone(compile_options_map.map)
  build_options_cuda = compile_options_map.map['cuda']
  self.assertIn(expected_key, build_options_cuda.env_option_overrides)
  override_proto = build_options_cuda.env_option_overrides[expected_key]
  with self.subTest('test_oneof_field'):
    # WhichOneof confirms exactly one value field was set.
    self.assertEqual(override_proto.WhichOneof('value'), expected_field)
  with self.subTest('test_value'):
    self.assertEqual(getattr(override_proto, expected_field), expected_value)

@parameterized.named_parameters(
dict(
testcase_name='1d_mesh',
Expand Down
Loading