From 0d87ef08e398bf3e621e31e65e5fd37e8c6e2264 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:19:36 -0800 Subject: [PATCH 1/8] Add triton version detection logic, and NVTE_USE_PYTORCH_TRITON knob for jax Signed-off-by: tdophung --- .../jax/triton_extensions/__init__.py | 33 +++- .../jax/triton_extensions/utils.py | 173 +++++++++++++++++- 2 files changed, 204 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index c98254d7d3..ae98aaacc0 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -9,7 +9,33 @@ IMPORTANT: This module requires Triton to be installed. If you don't have Triton, use transformer_engine.jax.cpp_extensions instead (CUDA/FFI based primitives). -Install Triton: pip install triton + +Triton Package Options: +----------------------- +There are two compatible Triton packages: + +1. Standard 'triton' from OpenAI (recommended for JAX-only environments): + pip install triton + +2. 'pytorch-triton' from PyTorch's index (for mixed JAX+PyTorch environments): + pip install torch --index-url https://download.pytorch.org/whl/cu121 + # pytorch-triton is automatically installed as a dependency + + Both packages work with JAX Triton kernels. The pytorch-triton package + has version format "X.Y.Z+" (e.g., "3.0.0+45fff310c8"). + +WARNING: Do NOT run 'pip install pytorch-triton' directly! The package on PyPI +is a placeholder that will fail with "RuntimeError: Should never be installed". +The real pytorch-triton only comes bundled with PyTorch from PyTorch's index. + + +Environment Variables: + NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton + for JAX Triton kernels (suppresses compatibility warnings). Set this + when both JAX and PyTorch are installed in the same environment. + + Example: + export NVTE_USE_PYTORCH_TRITON=1 Usage: @@ -23,6 +49,11 @@ def lowering(ctx, x, **kwargs): # Use permutation functions from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map + + # Check Triton package info + from transformer_engine.jax.triton_extensions import get_triton_info + info = get_triton_info() + print(f"Using Triton {info['version']} from {info['source']}") """ from .utils import * diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 064b2843c6..d157740fcc 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -6,9 +6,33 @@ This module provides utility functions for integrating Triton kernels into JAX primitives. Triton is only imported when this module is used. + +Triton Package Compatibility: + There are two Triton packages that can be used: + + 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. + Install with: pip install triton + + 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes + PyTorch-specific patches. Version format: "3.0.0+" + + IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a + placeholder that will NOT work. The real pytorch-triton is only available + from PyTorch's package index and is auto-installed with PyTorch: + pip install torch --index-url https://download.pytorch.org/whl/cu121 + + pytorch-triton has been tested to work with JAX Triton kernels. + +Environment Variables: + NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using + pytorch-triton for JAX Triton kernels (suppresses warnings). This is + useful when both JAX and PyTorch are installed in the same environment. + Default is "0". """ import hashlib +import os +import warnings from typing import Any, Callable, Mapping import zlib @@ -17,6 +41,115 @@ import jax.numpy as jnp +# Placeholder package version on PyPI that should never be used +_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1" + + +def _detect_triton_package(): + """Detect which Triton package is installed and validate compatibility. + + Returns: + tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool) + + The function detects: + - None: Triton not installed + - Standard triton from OpenAI (versions like "3.1.0") + - Real pytorch-triton from PyTorch's index (versions like "3.0.0+45fff310c8") + - Placeholder pytorch-triton from PyPI (version "0.0.1" - broken, raises RuntimeError) + """ + try: + import triton + triton_version = getattr(triton, "__version__", "unknown") + except ImportError: + return None, False, False + except RuntimeError as e: + # The placeholder pytorch-triton package from PyPI raises: + # RuntimeError: "Should never be installed" + if "Should never be installed" in str(e): + return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True + raise + + # Check for placeholder package (version 0.0.1 from PyPI) + is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION + + # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8" + is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8 + + return triton_version, is_pytorch_triton, is_placeholder + + +def _check_triton_compatibility(): + """Check Triton package compatibility and emit warnings if necessary. + + This function handles the case where both JAX and PyTorch may be installed, + each expecting different Triton packages: + - JAX typically uses the standard 'triton' package from OpenAI + - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs + + The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly + acknowledge using pytorch-triton with JAX (suppresses warnings). + + Raises: + ImportError: If triton is not installed or the placeholder package is detected. + """ + triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() + + # Handle placeholder package from PyPI + if is_placeholder: + raise ImportError( + "Detected the placeholder 'pytorch-triton' package (version 0.0.1) from PyPI.\n" + "This is NOT a functional Triton installation.\n\n" + "The placeholder package exists to prevent namespace conflicts. To fix this:\n\n" + "Option 1 - Use standard Triton (recommended for JAX-only environments):\n" + " pip uninstall pytorch-triton triton\n" + " pip install triton\n\n" + "Option 2 - Use real pytorch-triton (for mixed JAX+PyTorch environments):\n" + " pip uninstall pytorch-triton triton\n" + " pip install torch --index-url https://download.pytorch.org/whl/cu121\n" + " # pytorch-triton is automatically installed as a torch dependency\n\n" + "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n" + "the broken placeholder. The real pytorch-triton only comes from PyTorch's index." + ) + + if triton_version is None: + raise ImportError( + "Triton is required for transformer_engine.jax.triton_extensions.\n\n" + "Option 1 - Install standard Triton (recommended for JAX-only):\n" + " pip install triton\n\n" + "Option 2 - Install PyTorch with pytorch-triton (for mixed environments):\n" + " pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n" + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() + use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes") + + if is_pytorch_triton: + if use_pytorch_triton_explicit: + # User explicitly opted in - just log info (no warning) + pass # Silent acknowledgment, no warning needed + else: + # pytorch-triton detected but user didn't explicitly opt in + warnings.warn( + f"Detected pytorch-triton package (version {triton_version}) instead of " + f"the standard 'triton' package from OpenAI. This typically happens when " + f"PyTorch is installed alongside JAX.\n\n" + f"pytorch-triton is compatible with JAX Triton kernels. To suppress this " + f"warning, set:\n" + f" export NVTE_USE_PYTORCH_TRITON=1\n\n" + f"Alternatively, for a JAX-only environment:\n" + f" - Use separate virtual environments for JAX and PyTorch, or\n" + f" - Use transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", + category=UserWarning, + stacklevel=3, + ) + + return triton_version, is_pytorch_triton + + +# Perform compatibility check and get triton info +_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility() + try: from jax._src.lib import gpu_triton from triton.compiler import compiler as tc @@ -30,12 +163,42 @@ ) from e -__all__ = ["triton_call_lowering"] +__all__ = ["triton_call_lowering", "get_triton_info"] # Triton kernel cache (module-level, shared across all kernels) _TRITON_KERNEL_CACHE = {} +def get_triton_info(): + """Get information about the installed Triton package. + + Returns: + dict: Dictionary containing: + - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8") + - is_pytorch_triton (bool): True if using real pytorch-triton from PyTorch's index + - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI + - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set + - source (str): "pytorch" or "openai" indicating the package source + + Example: + >>> from transformer_engine.jax.triton_extensions import get_triton_info + >>> info = get_triton_info() + >>> print(f"Triton version: {info['version']} (from {info['source']})") + >>> if info['is_pytorch_triton']: + ... print("Using pytorch-triton - compatible with both PyTorch and JAX") + """ + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() + env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") + + return { + "version": _TRITON_VERSION, + "is_pytorch_triton": _IS_PYTORCH_TRITON, + "is_openai_triton": not _IS_PYTORCH_TRITON, + "env_acknowledged": env_acknowledged and _IS_PYTORCH_TRITON, + "source": "pytorch" if _IS_PYTORCH_TRITON else "openai", + } + + def get_triton_dtype(aval): """Convert JAX dtype to Triton type string. @@ -142,7 +305,11 @@ def compile_triton( ) # Create kernel object for JAX +<<<<<<< HEAD # From jax/jaxlib/gpu/triton_kernels.cc: +======= + # From jax/jaxlib/gpu/triton_kernels.cc: +>>>>>>> de40a714 (Add triton version detection logic, and NVTE_USE_PYTORCH_TRITON knob for jax) from packaging import version if version.parse(jax.__version__) >= version.parse("0.8.2"): @@ -157,7 +324,11 @@ def compile_triton( ) else: kernel = gpu_triton.TritonKernel( +<<<<<<< HEAD compiled.name, +======= + compile.name, +>>>>>>> de40a714 (Add triton version detection logic, and NVTE_USE_PYTORCH_TRITON knob for jax) num_warps, compiled.metadata.shared, compiled.asm["ptx"], From 4e18ee9bf5a037d1e099c71fd329b508d7549546 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:48:31 -0800 Subject: [PATCH 2/8] change build requirements and installation to reflect new option Signed-off-by: tdophung --- build_tools/jax.py | 25 +++++++++++++++++++++++-- build_tools/pytorch.py | 14 ++++++++++++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 276c9943d6..ec410060a4 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -19,8 +19,29 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: - """Test dependencies for TE/JAX extensions.""" - return ["numpy", "triton"] + """Test dependencies for TE/JAX extensions. + + Triton Package Selection: + The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable: + + Default (NVTE_USE_PYTORCH_TRITON unset or "0"): + Returns 'triton' - OpenAI's standard package from PyPI. + Install with: pip install triton + + NVTE_USE_PYTORCH_TRITON=1: + Returns 'pytorch-triton' - for mixed JAX+PyTorch environments. + Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121 + + Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder. + """ + use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ("1", "true", "yes") + + triton_package = "pytorch-triton" if use_pytorch_triton else "triton" + + return [ + "numpy", + triton_package, + ] def xla_path() -> str: diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index b4815a0942..7b8522b052 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -13,7 +13,17 @@ def install_requirements() -> List[str]: - """Install dependencies for TE/PyTorch extensions.""" + """Install dependencies for TE/PyTorch extensions. + + IMPORTANT - PyTorch Index Required for pytorch-triton: + These dependencies MUST be installed using PyTorch's package index: + + pip install pytorch-triton --index-url https://download.pytorch.org/whl/ + + - pytorch-triton is only available from PyTorch's index (not PyPI) + - The 'pytorch-triton' package on PyPI is a placeholder that will fail + - torch.compile() requires pytorch-triton, not OpenAI's 'triton' package + """ return [ "torch>=2.1", "einops", @@ -22,7 +32,7 @@ def install_requirements() -> List[str]: "packaging", "pydantic", "nvdlfw-inspect", - "triton", + "pytorch-triton", ] From d736431b1afa332b3fd9ba931ac70ea56fc53361 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:54:53 -0800 Subject: [PATCH 3/8] reduce boilerplate comments Signed-off-by: tdophung --- .../jax/triton_extensions/utils.py | 43 +------------------ 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index d157740fcc..86922ff1e0 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -7,27 +7,7 @@ This module provides utility functions for integrating Triton kernels into JAX primitives. Triton is only imported when this module is used. -Triton Package Compatibility: - There are two Triton packages that can be used: - - 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. - Install with: pip install triton - - 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes - PyTorch-specific patches. Version format: "3.0.0+" - - IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a - placeholder that will NOT work. The real pytorch-triton is only available - from PyTorch's package index and is auto-installed with PyTorch: - pip install torch --index-url https://download.pytorch.org/whl/cu121 - - pytorch-triton has been tested to work with JAX Triton kernels. - -Environment Variables: - NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using - pytorch-triton for JAX Triton kernels (suppresses warnings). This is - useful when both JAX and PyTorch are installed in the same environment. - Default is "0". +Triton Package Compatibility --> see __init__.py """ import hashlib @@ -79,19 +59,7 @@ def _detect_triton_package(): def _check_triton_compatibility(): - """Check Triton package compatibility and emit warnings if necessary. - - This function handles the case where both JAX and PyTorch may be installed, - each expecting different Triton packages: - - JAX typically uses the standard 'triton' package from OpenAI - - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs - - The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly - acknowledge using pytorch-triton with JAX (suppresses warnings). - - Raises: - ImportError: If triton is not installed or the placeholder package is detected. - """ + """Check Triton package compatibility and emit warnings if necessary.""" triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() # Handle placeholder package from PyPI @@ -179,13 +147,6 @@ def get_triton_info(): - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set - source (str): "pytorch" or "openai" indicating the package source - - Example: - >>> from transformer_engine.jax.triton_extensions import get_triton_info - >>> info = get_triton_info() - >>> print(f"Triton version: {info['version']} (from {info['source']})") - >>> if info['is_pytorch_triton']: - ... print("Using pytorch-triton - compatible with both PyTorch and JAX") """ use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") From 38d55f3ce2100d76023099136c905a0ddc159d1b Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 16:56:35 -0800 Subject: [PATCH 4/8] format code Signed-off-by: tdophung --- build_tools/pytorch.py | 6 +++--- transformer_engine/jax/triton_extensions/__init__.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 7b8522b052..98511e45cb 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,12 +14,12 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions. - + IMPORTANT - PyTorch Index Required for pytorch-triton: These dependencies MUST be installed using PyTorch's package index: - + pip install pytorch-triton --index-url https://download.pytorch.org/whl/ - + - pytorch-triton is only available from PyTorch's index (not PyPI) - The 'pytorch-triton' package on PyPI is a placeholder that will fail - torch.compile() requires pytorch-triton, not OpenAI's 'triton' package diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index ae98aaacc0..d9708fde9f 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -33,7 +33,7 @@ NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton for JAX Triton kernels (suppresses compatibility warnings). Set this when both JAX and PyTorch are installed in the same environment. - + Example: export NVTE_USE_PYTORCH_TRITON=1 @@ -49,7 +49,7 @@ def lowering(ctx, x, **kwargs): # Use permutation functions from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map - + # Check Triton package info from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() From 3986ab6e72a0edc290b48c6c1f2174a1fa047817 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 18:37:46 -0800 Subject: [PATCH 5/8] fix typo Signed-off-by: tdophung --- build_tools/jax.py | 18 ++++--- .../jax/triton_extensions/utils.py | 51 +++++++++++++++---- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index ec410060a4..999f4ede9b 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -20,24 +20,28 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions. - + Triton Package Selection: The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable: - + Default (NVTE_USE_PYTORCH_TRITON unset or "0"): Returns 'triton' - OpenAI's standard package from PyPI. Install with: pip install triton - + NVTE_USE_PYTORCH_TRITON=1: Returns 'pytorch-triton' - for mixed JAX+PyTorch environments. Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121 - + Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder. """ - use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ("1", "true", "yes") - + use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( + "1", + "true", + "yes", + ) + triton_package = "pytorch-triton" if use_pytorch_triton else "triton" - + return [ "numpy", triton_package, diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 86922ff1e0..31391b2c33 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -7,7 +7,27 @@ This module provides utility functions for integrating Triton kernels into JAX primitives. Triton is only imported when this module is used. -Triton Package Compatibility --> see __init__.py +Triton Package Compatibility: + There are two Triton packages that can be used: + + 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. + Install with: pip install triton + + 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes + PyTorch-specific patches. Version format: "3.0.0+" + + IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a + placeholder that will NOT work. The real pytorch-triton is only available + from PyTorch's package index and is auto-installed with PyTorch: + pip install torch --index-url https://download.pytorch.org/whl/cu121 + + pytorch-triton has been tested to work with JAX Triton kernels. + +Environment Variables: + NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using + pytorch-triton for JAX Triton kernels (suppresses warnings). This is + useful when both JAX and PyTorch are installed in the same environment. + Default is "0". """ import hashlib @@ -59,7 +79,19 @@ def _detect_triton_package(): def _check_triton_compatibility(): - """Check Triton package compatibility and emit warnings if necessary.""" + """Check Triton package compatibility and emit warnings if necessary. + + This function handles the case where both JAX and PyTorch may be installed, + each expecting different Triton packages: + - JAX typically uses the standard 'triton' package from OpenAI + - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs + + The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly + acknowledge using pytorch-triton with JAX (suppresses warnings). + + Raises: + ImportError: If triton is not installed or the placeholder package is detected. + """ triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() # Handle placeholder package from PyPI @@ -147,6 +179,13 @@ def get_triton_info(): - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set - source (str): "pytorch" or "openai" indicating the package source + + Example: + from transformer_engine.jax.triton_extensions import get_triton_info + info = get_triton_info() + print(f"Triton version: {info['version']} (from {info['source']})") + if info['is_pytorch_triton']: + print("Using pytorch-triton - compatible with both PyTorch and JAX") """ use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") @@ -266,11 +305,7 @@ def compile_triton( ) # Create kernel object for JAX -<<<<<<< HEAD # From jax/jaxlib/gpu/triton_kernels.cc: -======= - # From jax/jaxlib/gpu/triton_kernels.cc: ->>>>>>> de40a714 (Add triton version detection logic, and NVTE_USE_PYTORCH_TRITON knob for jax) from packaging import version if version.parse(jax.__version__) >= version.parse("0.8.2"): @@ -285,11 +320,7 @@ def compile_triton( ) else: kernel = gpu_triton.TritonKernel( -<<<<<<< HEAD compiled.name, -======= - compile.name, ->>>>>>> de40a714 (Add triton version detection logic, and NVTE_USE_PYTORCH_TRITON knob for jax) num_warps, compiled.metadata.shared, compiled.asm["ptx"], From 69f88c8869431dfaa7ff0c527b53e18cca059bdd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 01:18:39 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/triton_extensions/utils.py | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 31391b2c33..3477847e48 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -9,23 +9,23 @@ Triton Package Compatibility: There are two Triton packages that can be used: - + 1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box. Install with: pip install triton - + 2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes PyTorch-specific patches. Version format: "3.0.0+" - - IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a + + IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a placeholder that will NOT work. The real pytorch-triton is only available from PyTorch's package index and is auto-installed with PyTorch: pip install torch --index-url https://download.pytorch.org/whl/cu121 - + pytorch-triton has been tested to work with JAX Triton kernels. Environment Variables: - NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using - pytorch-triton for JAX Triton kernels (suppresses warnings). This is + NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using + pytorch-triton for JAX Triton kernels (suppresses warnings). This is useful when both JAX and PyTorch are installed in the same environment. Default is "0". """ @@ -47,10 +47,10 @@ def _detect_triton_package(): """Detect which Triton package is installed and validate compatibility. - + Returns: tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool) - + The function detects: - None: Triton not installed - Standard triton from OpenAI (versions like "3.1.0") @@ -59,6 +59,7 @@ def _detect_triton_package(): """ try: import triton + triton_version = getattr(triton, "__version__", "unknown") except ImportError: return None, False, False @@ -68,32 +69,32 @@ def _detect_triton_package(): if "Should never be installed" in str(e): return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True raise - + # Check for placeholder package (version 0.0.1 from PyPI) is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION - + # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8" is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8 - + return triton_version, is_pytorch_triton, is_placeholder def _check_triton_compatibility(): """Check Triton package compatibility and emit warnings if necessary. - + This function handles the case where both JAX and PyTorch may be installed, each expecting different Triton packages: - JAX typically uses the standard 'triton' package from OpenAI - PyTorch uses 'pytorch-triton' which is versioned with commit SHAs - + The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly acknowledge using pytorch-triton with JAX (suppresses warnings). - + Raises: ImportError: If triton is not installed or the placeholder package is detected. """ triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package() - + # Handle placeholder package from PyPI if is_placeholder: raise ImportError( @@ -110,7 +111,7 @@ def _check_triton_compatibility(): "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n" "the broken placeholder. The real pytorch-triton only comes from PyTorch's index." ) - + if triton_version is None: raise ImportError( "Triton is required for transformer_engine.jax.triton_extensions.\n\n" @@ -120,10 +121,10 @@ def _check_triton_compatibility(): " pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n" "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) - + use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes") - + if is_pytorch_triton: if use_pytorch_triton_explicit: # User explicitly opted in - just log info (no warning) @@ -131,19 +132,17 @@ def _check_triton_compatibility(): else: # pytorch-triton detected but user didn't explicitly opt in warnings.warn( - f"Detected pytorch-triton package (version {triton_version}) instead of " - f"the standard 'triton' package from OpenAI. This typically happens when " - f"PyTorch is installed alongside JAX.\n\n" - f"pytorch-triton is compatible with JAX Triton kernels. To suppress this " - f"warning, set:\n" - f" export NVTE_USE_PYTORCH_TRITON=1\n\n" - f"Alternatively, for a JAX-only environment:\n" - f" - Use separate virtual environments for JAX and PyTorch, or\n" - f" - Use transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", + f"Detected pytorch-triton package (version {triton_version}) instead of the" + " standard 'triton' package from OpenAI. This typically happens when PyTorch is" + " installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton" + " kernels. To suppress this warning, set:\n export" + " NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n - Use" + " separate virtual environments for JAX and PyTorch, or\n - Use" + " transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)", category=UserWarning, stacklevel=3, ) - + return triton_version, is_pytorch_triton @@ -171,7 +170,7 @@ def _check_triton_compatibility(): def get_triton_info(): """Get information about the installed Triton package. - + Returns: dict: Dictionary containing: - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8") @@ -179,7 +178,7 @@ def get_triton_info(): - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set - source (str): "pytorch" or "openai" indicating the package source - + Example: from transformer_engine.jax.triton_extensions import get_triton_info info = get_triton_info() @@ -189,7 +188,7 @@ def get_triton_info(): """ use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") - + return { "version": _TRITON_VERSION, "is_pytorch_triton": _IS_PYTORCH_TRITON, From 469ed6896469232c9e1792914986a8a9e180c16e Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 2 Jan 2026 09:55:28 -0800 Subject: [PATCH 7/8] make env var more precise Signed-off-by: tdophung --- build_tools/jax.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 999f4ede9b..f07c0a202f 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -34,11 +34,7 @@ def test_requirements() -> List[str]: Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder. """ - use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( - "1", - "true", - "yes", - ) + use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0"))) triton_package = "pytorch-triton" if use_pytorch_triton else "triton" From 27485fa80ded1a8b1fd78c874b668b82862d0e91 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 5 Jan 2026 11:44:41 -0800 Subject: [PATCH 8/8] make env variables checking consitent Signed-off-by: tdophung --- transformer_engine/jax/triton_extensions/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 3477847e48..59fc5c60af 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -122,8 +122,7 @@ def _check_triton_compatibility(): "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) - use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() - use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes") + use_pytorch_triton_explicit = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0"))) if is_pytorch_triton: if use_pytorch_triton_explicit: @@ -186,8 +185,7 @@ def get_triton_info(): if info['is_pytorch_triton']: print("Using pytorch-triton - compatible with both PyTorch and JAX") """ - use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() - env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes") + env_acknowledged = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0"))) return { "version": _TRITON_VERSION,