From 824457a576a9272e3b1bbaa9cb690a0167a2bfd7 Mon Sep 17 00:00:00 2001
From: Orbax Authors
Date: Fri, 6 Mar 2026 09:53:41 -0800
Subject: [PATCH] Support the XLA GPU compilation flags in Orbax

PiperOrigin-RevId: 879684156
---
 .../orbax/export/modules/obm_module_test.py   |  30 ++++
 .../model/core/python/compile_options_util.py | 111 ++++++++++++++-
 .../core/python/compile_options_util_test.py  | 131 ++++++++++++++++++
 3 files changed, 265 insertions(+), 7 deletions(-)

diff --git a/export/orbax/export/modules/obm_module_test.py b/export/orbax/export/modules/obm_module_test.py
index 2d39efd93..9a096432e 100644
--- a/export/orbax/export/modules/obm_module_test.py
+++ b/export/orbax/export/modules/obm_module_test.py
@@ -384,6 +384,36 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
     with self.subTest('test_weights_b_dtype'):
       self.assertEqual(module.model_params['b'].dtype, expected_dtype)
 
+  def test_obm_module_gpu_xla_flags_integration_stable(self):
+    param_shape = (2, 5)
+    param_dtype = jnp.dtype(jnp.float32)
+    param_spec = jax.ShapeDtypeStruct(shape=param_shape, dtype=param_dtype)
+    model_function_name = 'simple_add'
+
+    jax2obm_options = obm_configs.Jax2ObmOptions(
+        checkpoint_path='checkpoint_path',
+        native_serialization_platforms=('cuda',),
+        xla_flags_per_platform={
+            'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
+        },
+    )
+
+    orbax_model_module = obm_module.ObmModule(
+        params=param_spec,
+        apply_fn={model_function_name: simple_add},
+        jax2obm_options=jax2obm_options,
+    )
+
+    xla_compile_options_map = (
+        orbax_model_module.xla_compile_options_per_platform
+    )
+    self.assertIsNotNone(xla_compile_options_map)
+    build_options_cuda = xla_compile_options_map.map['cuda']
+    self.assertIn(
+        'xla_gpu_enable_latency_hiding_scheduler',
+        build_options_cuda.env_option_overrides,
+    )
+
 
 class GetSharedValueTest(parameterized.TestCase):
 
diff --git a/model/orbax/experimental/model/core/python/compile_options_util.py b/model/orbax/experimental/model/core/python/compile_options_util.py
index 7c572b249..fc734185e 100644
--- a/model/orbax/experimental/model/core/python/compile_options_util.py
+++ b/model/orbax/experimental/model/core/python/compile_options_util.py
@@ -27,6 +27,7 @@
 from .google.protobuf import any_pb2
 from .platforms.xla.service.jellyfish import tpu_compilation_environment_pb2 as tpu_comp_env_pb2
 from .platforms.xla.service.jellyfish.python import tpu_compilation_environment as tpu_comp_env
+from .third_party.neptune.model._src.core import xla_gpu_flag_validation
 from tensorflow.compiler.xla import xla_data_pb2  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.compiler.xla import xla_pb2  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.compiler.xla.pjrt.proto import compile_options_pb2  # pylint: disable=g-direct-tensorflow-import
@@ -165,6 +166,10 @@ def generate_xla_compile_options(
   tpu_platform_name = manifest_pb2.Platform.Name(
       manifest_pb2.Platform.TPU
   ).lower()
+  cuda_platform_name = manifest_pb2.Platform.Name(
+      manifest_pb2.Platform.CUDA
+  ).lower()
+
   compile_options_map = manifest_pb2.CompileOptionsProtoMap()
   if native_serialization_platforms is None:
     # If no native serialization platforms are specified, we will set the
@@ -195,24 +200,116 @@ def generate_xla_compile_options(
       )
 
   for platform in platforms:
-    if platform.lower() == tpu_platform_name:
-      if xla_flags_per_platform:
-        xla_flags_overrides = xla_flags_per_platform.get(platform, None)
+    if xla_flags_per_platform:
+      xla_flags_overrides = xla_flags_per_platform.get(platform, None)
+      if xla_flags_overrides:
         _validate_xla_flags_setting(xla_flags_overrides, persist_xla_flags)
-      else:
-        xla_flags_overrides = None
+    else:
+      xla_flags_overrides = None
+
+    platform_lower = platform.lower()
+    if platform_lower == tpu_platform_name:
       compile_environment = _generate_tpu_compilation_env(xla_flags_overrides)
+    elif platform_lower == cuda_platform_name:
+      # GPU Trick: Empty proto to bypass 'None' check and enable jax_mesh
+      # serialization.
+      compile_environment = xla_pb2.CompilationEnvironmentsProto()
     else:
+      # CPU Path: Leave as None to preserve legacy portable execution behavior.
       compile_environment = None
-    compile_options_map.map[platform.lower()].CopyFrom(
-        _generate_compilation_options(compile_environment, jax_mesh)
+
+    compile_options = _generate_compilation_options(
+        compile_environment, jax_mesh
     )
+
+    # Inject env_option_overrides natively for GPU using a dedicated helper.
+    if platform_lower == cuda_platform_name and xla_flags_overrides:
+      _apply_gpu_compilation_env_options(compile_options, xla_flags_overrides)
+
+    compile_options_map.map[platform_lower].CopyFrom(compile_options)
+
   if not persist_xla_flags:
     for compile_options in compile_options_map.map.values():
       compile_options.executable_build_options.comp_envs.Clear()
   return compile_options_map
 
 
+def _apply_gpu_compilation_env_options(
+    compile_options: compile_options_pb2.CompileOptionsProto,
+    xla_flags_overrides: Sequence[str],
+) -> None:
+  """Applies XLA flag overrides generically for GPU platforms.
+
+  Args:
+    compile_options: The compilation options proto to be modified.
+    xla_flags_overrides: A sequence of XLA flags to apply as option overrides.
+  """
+  overrides_map = _parse_env_option_overrides(xla_flags_overrides)
+  for k, v in overrides_map.items():
+    compile_options.env_option_overrides[k].CopyFrom(v)
+
+
+def _parse_env_option_overrides(
+    xla_flags: Sequence[str],
+) -> dict[str, compile_options_pb2.OptionOverrideProto]:
+  """Parses a list of XLA flags into a dictionary of OptionOverrideProto.
+
+  The value type is inferred in this order: bool (`true`/`false`, matched
+  case-insensitively), int (decimal digits with an optional leading `-`),
+  float (anything `float()` accepts) and, as the fallback, string.
+
+  Args:
+    xla_flags: XLA flags, each of the form `--name=value`.
+
+  Returns:
+    A dict mapping each flag name (without the leading `--`) to its override
+    proto. If a name is repeated, the last occurrence wins.
+
+  Raises:
+    ValueError: If a flag does not start with `--`, is rejected by
+      `xla_gpu_flag_validation.validate_xla_gpu_flag`, or is not of the form
+      `--name=value`.
+  """
+  overrides = {}
+  for flag in xla_flags:
+    if not flag.startswith('--'):
+      raise ValueError(f"Flag {flag} must start with '--'")
+
+    try:
+      # Use the C++ ValidateXlaGPUFlag logic to ensure consistent policy
+      # enforcement across Python and C++ layers.
+      # The C++ function expects the flag with the '--' prefix.
+      xla_gpu_flag_validation.validate_xla_gpu_flag(flag, strict=True)
+    except Exception as e:
+      # pybind11_abseil appends the status code name to the exception string.
+      # Remove it to match exactly what users would see from the C++ binaries.
+      # `removesuffix` leaves the message untouched when the suffix is absent.
+      raise ValueError(str(e).removesuffix(' [INVALID_ARGUMENT]')) from e
+
+    # `partition` never raises, so a flag without `=` gets a descriptive
+    # error instead of an opaque tuple-unpacking failure.
+    key, separator, value = flag[2:].partition('=')
+    if not key or not separator:
+      raise ValueError(f"Flag {flag} must be of the form '--name=value'")
+    override_proto = compile_options_pb2.OptionOverrideProto()
+
+    # Infer type (True/False/Int/Float/String)
+    if value.lower() == 'true':
+      override_proto.bool_field = True
+    elif value.lower() == 'false':
+      override_proto.bool_field = False
+    elif value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
+      override_proto.int_field = int(value)
+    else:
+      try:
+        override_proto.double_field = float(value)
+      except ValueError:
+        override_proto.string_field = value
+
+    overrides[key] = override_proto
+  return overrides
+
+
 def _validate_xla_flags_setting(
     xla_flags_overrides: Sequence[str] | None, persist_xla_flags: bool
 ) -> None:
diff --git a/model/orbax/experimental/model/core/python/compile_options_util_test.py b/model/orbax/experimental/model/core/python/compile_options_util_test.py
index a49358632..22248e17d 100644
--- a/model/orbax/experimental/model/core/python/compile_options_util_test.py
+++ b/model/orbax/experimental/model/core/python/compile_options_util_test.py
@@ -328,6 +328,137 @@ def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self):
           persist_xla_flags=False,
       )
 
+  def test_generate_xla_compile_options_env_overrides(self):
+    compile_options_map = compile_options_util.generate_xla_compile_options(
+        native_serialization_platforms=['cuda'],
+        xla_flags_per_platform={
+            'cuda': [
+                '--xla_gpu_enable_latency_hiding_scheduler=true',
+                '--xla_gpu_autotune_level=0',
+            ]
+        },
+        persist_xla_flags=True,
+    )
+    self.assertIn('cuda', compile_options_map.map)
+    compile_options = compile_options_map.map['cuda']
+
+    overrides = compile_options.env_option_overrides
+    self.assertIn('xla_gpu_enable_latency_hiding_scheduler', overrides)
+    self.assertTrue(
+        overrides['xla_gpu_enable_latency_hiding_scheduler'].bool_field
+    )
+
+    self.assertIn('xla_gpu_autotune_level', overrides)
+    self.assertEqual(overrides['xla_gpu_autotune_level'].int_field, 0)
+
+  def test_generate_xla_compile_options_gpu_flags_experimental_rejection(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        r'XLA GPU compilation flag --xla_gpu_experimental_flag=true is not'
+        r' supported. Please check field description at'
+        r' CompilationConfig::xla_gpu_flags',
+    ):
+      compile_options_util.generate_xla_compile_options(
+          native_serialization_platforms=['cuda'],
+          xla_flags_per_platform={'cuda': ['--xla_gpu_experimental_flag=true']},
+          persist_xla_flags=True,
+      )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='bool_true',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=true',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=True,
+      ),
+      dict(
+          testcase_name='bool_false',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=false',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=False,
+      ),
+      dict(
+          testcase_name='bool_uppercase_true',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=TRUE',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=True,
+      ),
+      dict(
+          testcase_name='int_positive',
+          flag='--xla_gpu_autotune_level=4',
+          expected_key='xla_gpu_autotune_level',
+          expected_field='int_field',
+          expected_value=4,
+      ),
+      dict(
+          testcase_name='int_negative',
+          flag='--xla_gpu_nccl_termination_timeout_seconds=-1',
+          expected_key='xla_gpu_nccl_termination_timeout_seconds',
+          expected_field='int_field',
+          expected_value=-1,
+      ),
+      dict(
+          testcase_name='float_positive',
+          flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=1.5',
+          expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
+          expected_field='double_field',
+          expected_value=1.5,
+      ),
+      dict(
+          testcase_name='float_negative',
+          flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=-0.5',
+          expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
+          expected_field='double_field',
+          expected_value=-0.5,
+      ),
+      dict(
+          testcase_name='string_value',
+          flag='--xla_gpu_cuda_data_dir=/usr/local/cuda',
+          expected_key='xla_gpu_cuda_data_dir',
+          expected_field='string_field',
+          expected_value='/usr/local/cuda',
+      ),
+  )
+  @mock.patch.object(
+      compile_options_util.xla_gpu_flag_validation, 'validate_xla_gpu_flag'
+  )
+  def test_generate_xla_compile_options_gpu_flags_type_inference(
+      self, mock_validate, flag, expected_key, expected_field, expected_value
+  ):
+    del mock_validate  # Unused, just patching for bypass
+    compile_options_map = compile_options_util.generate_xla_compile_options(
+        native_serialization_platforms=['cuda'],
+        xla_flags_per_platform={'cuda': [flag]},
+        persist_xla_flags=True,
+    )
+    self.assertIsNotNone(compile_options_map.map)
+    build_options_cuda = compile_options_map.map['cuda']
+    self.assertIn(expected_key, build_options_cuda.env_option_overrides)
+    override_proto = build_options_cuda.env_option_overrides[expected_key]
+    with self.subTest('test_oneof_field'):
+      self.assertEqual(override_proto.WhichOneof('value'), expected_field)
+    with self.subTest('test_value'):
+      self.assertEqual(getattr(override_proto, expected_field), expected_value)
+
+  @mock.patch.object(
+      compile_options_util.xla_gpu_flag_validation, 'validate_xla_gpu_flag'
+  )
+  def test_generate_xla_compile_options_gpu_flags_missing_value_raises(
+      self, mock_validate
+  ):
+    del mock_validate  # Unused, just patching for bypass
+    with self.assertRaisesRegex(
+        ValueError, "must be of the form '--name=value'"
+    ):
+      compile_options_util.generate_xla_compile_options(
+          native_serialization_platforms=['cuda'],
+          xla_flags_per_platform={'cuda': ['--xla_gpu_autotune_level']},
+          persist_xla_flags=True,
+      )
+
   @parameterized.named_parameters(
       dict(
           testcase_name='1d_mesh',