Conversation

@tdophung
Collaborator

@tdophung tdophung commented Dec 23, 2025

Description

The pytorch-triton and triton packages both install to the same location, site-packages/triton. Standard triton does not work with PyTorch's torch.compile(), because PyTorch has added a few things to its own version of triton (published as pytorch-triton and validated against each torch release). However, pytorch-triton should in theory still be compatible with how JAX uses it, and our experiments confirm this*.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new env var to control when to use pytorch-triton in jax
  • switch pytorch back to using/checking for pytorch-triton by default
  • Add documentation (comments) on this contention of packages

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

num_ctas, # arg2: num_ctas (int)
compiled.metadata.shared, # arg3: shared_mem_bytes (int)
compiled.asm["ptx"], # arg4: ptx (str)
"", # arg5: ttir (str) - empty
Collaborator Author

This will soon be the same as main, since the change is being made in #1921, which is yet to be merged. It is only in this PR so that I can test triton calls locally with the nightly JAX container without running into errors caused by JAX 0.8.2+.

@greptile-apps
Contributor

greptile-apps bot commented Dec 23, 2025

Greptile Summary

This PR resolves the package contention between pytorch-triton and triton packages which both install to site-packages/triton.

Key changes:

  • PyTorch now defaults to pytorch-triton dependency (required for torch.compile())
  • JAX uses standard triton by default but supports pytorch-triton via NVTE_USE_PYTORCH_TRITON=1 env var for mixed environments
  • Added triton package detection and validation in utils.py with helpful error messages for the PyPI placeholder package
  • Comprehensive documentation added explaining installation requirements and package compatibility

Impact: Users running mixed JAX+PyTorch environments can now properly configure triton package selection. The detection logic warns users if they accidentally install the broken placeholder pytorch-triton from PyPI instead of from PyTorch's index.
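
The detection flow summarized above can be sketched roughly as follows. This is a minimal illustration, not the code in utils.py: the function name classify_triton_install, its parameters, the returned labels, and the final "no triton package found" branch are all hypothetical. Only the placeholder version 0.0.1, the NVTE_USE_PYTORCH_TRITON variable, and the error/warning/silent outcomes come from this PR's description. Installed distributions are passed in as a plain dict so the logic can be exercised without touching the real environment.

```python
import warnings

# Version reported for the PyPI placeholder package in this PR's diagram.
PLACEHOLDER_VERSION = "0.0.1"


def classify_triton_install(installed, acknowledged):
    """Classify a triton setup (hypothetical helper, not the PR's actual code).

    installed: dict mapping distribution name -> version string.
    acknowledged: True when the user has set NVTE_USE_PYTORCH_TRITON=1.
    """
    version = installed.get("pytorch-triton")
    if version == PLACEHOLDER_VERSION:
        # The PyPI placeholder is unusable, so fail early with a hint.
        raise ImportError(
            "pytorch-triton looks like the PyPI placeholder; "
            "install the real package from PyTorch's package index."
        )
    if version is not None:
        if not acknowledged:
            warnings.warn(
                "pytorch-triton detected; set NVTE_USE_PYTORCH_TRITON=1 "
                "to acknowledge.",
                UserWarning,
            )
        return "pytorch-triton"
    if "triton" in installed:
        return "triton"
    # Assumption: treat a missing package as an import failure.
    raise ImportError("no triton package found")
```

In a real module the dict would presumably be filled from package metadata (for example via importlib.metadata.version), but that wiring is left out here.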

Confidence Score: 4/5

  • This PR is safe to merge - it adds documentation, configuration options, and helpful validation without changing core functionality.
  • Changes are well-documented, follow existing codebase patterns for environment variable handling (bool(int(os.environ.get(...)))), and add defensive validation. The detection logic handles edge cases like the placeholder package. No breaking changes to existing behavior for users who don't set the new env var.
  • build_tools/pytorch.py - changing from 'triton' to 'pytorch-triton' in install_requirements means users must use PyTorch's package index, which is documented but could cause installation issues if users miss the documentation.

Important Files Changed

Filename | Overview
build_tools/pytorch.py | Changed triton package dependency from 'triton' to 'pytorch-triton' with updated documentation explaining PyTorch index installation requirements.
build_tools/jax.py | Added NVTE_USE_PYTORCH_TRITON environment variable to dynamically select between 'triton' and 'pytorch-triton' packages for JAX test requirements.
transformer_engine/jax/triton_extensions/__init__.py | Documentation-only changes adding detailed guidance on triton package options and the new NVTE_USE_PYTORCH_TRITON environment variable.
transformer_engine/jax/triton_extensions/utils.py | Added triton package detection, validation logic, and get_triton_info() utility to handle pytorch-triton vs standard triton compatibility.

Sequence Diagram

sequenceDiagram
    participant User
    participant SetupPy as setup.py
    participant BuildTools as build_tools/*.py
    participant TritonUtils as triton_extensions/utils.py
    participant Triton as triton package

    User->>SetupPy: pip install (PyTorch)
    SetupPy->>BuildTools: install_requirements()
    BuildTools-->>SetupPy: ["pytorch-triton", ...]
    Note over SetupPy: Requires PyTorch index

    User->>SetupPy: pip install (JAX)
    SetupPy->>BuildTools: test_requirements()
    alt NVTE_USE_PYTORCH_TRITON=1
        BuildTools-->>SetupPy: ["pytorch-triton"]
    else Default
        BuildTools-->>SetupPy: ["triton"]
    end

    User->>TritonUtils: import triton_extensions
    TritonUtils->>Triton: import triton
    TritonUtils->>TritonUtils: _detect_triton_package()
    alt Placeholder package (0.0.1)
        TritonUtils-->>User: ImportError with fix instructions
    else pytorch-triton detected
        alt NVTE_USE_PYTORCH_TRITON=1
            TritonUtils-->>User: Silent (acknowledged)
        else Not acknowledged
            TritonUtils-->>User: UserWarning
        end
    else Standard triton
        TritonUtils-->>User: Normal operation
    end

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/triton_extensions/utils.py, line 322 (link)

    syntax: Typo: compile.name should be compiled.name. The variable compile is not defined in this scope - only compiled exists from line 300. This will cause a NameError at runtime for JAX versions < 0.8.2.

4 files reviewed, 1 comment

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. build_tools/pytorch.py, line 21 (link)

    style: Placeholder text <version??> should be replaced with an actual version (e.g., cu121 or cu124) or made generic.

4 files reviewed, 1 comment

Comment on lines 37 to 41
use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in (
"1",
"true",
"yes",
)
Member

Suggested change
- use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in (
-     "1",
-     "true",
-     "yes",
- )
+ use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0")))

Member

It's better to be specific with envvars and follow convention from rest of the codebase.

"pydantic",
"nvdlfw-inspect",
"triton",
"pytorch-triton",
Member

If pytorch-triton from PyPI is actually a placeholder, then we shouldn't list it here as a dependency

Collaborator Author

pytorch-triton should be the default for transformerengine-pytorch, not just a placeholder. It should be used whenever the PyTorch framework is used.

triton is the default for JAX, unless one of two scenarios applies:

  • Both JAX and PyTorch are installed, and TE PyTorch is used to call the triton kernels.
  • The user specifies NVTE_USE_PYTORCH_TRITON=1 while using TE JAX, to make sure there is no performance difference between PyTorch and JAX caused by different versions of triton.
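
As a rough illustration of the defaults described above, the selection could look like the sketch below. Both function names and the environ parameter are hypothetical and are not taken from build_tools; the only details drawn from this PR are the package names, the NVTE_USE_PYTORCH_TRITON variable, and the bool(int(...)) parsing convention.

```python
import os


def jax_triton_requirement(environ=None):
    """Return the triton distribution name for TE JAX (hypothetical helper)."""
    # Allow a dict to be injected so the function can be tested in isolation.
    env = os.environ if environ is None else environ
    use_pytorch_triton = bool(int(env.get("NVTE_USE_PYTORCH_TRITON", "0")))
    return "pytorch-triton" if use_pytorch_triton else "triton"


def pytorch_triton_requirement():
    """TE PyTorch always depends on pytorch-triton, per the comment above."""
    return "pytorch-triton"
```

The first scenario in the list (both frameworks installed) is not something this function decides; in that case whichever package occupies site-packages/triton is the one both frameworks import.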

tdophung and others added 7 commits January 2, 2026 09:59
…for jax

Signed-off-by: tdophung <tdophung@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. build_tools/jax.py, line 37 (link)

    logic: Inconsistent env var parsing: utils.py accepts "1", "true", "yes" (via .lower() check), but this line only accepts "0" or "1" (via int() cast). If user sets NVTE_USE_PYTORCH_TRITON=true, this will raise ValueError: invalid literal for int() with base 10: 'true'.
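
    The mismatch can be reproduced with a small standalone sketch. The helper names parse_strict and parse_lenient are hypothetical; they only mirror the two expressions quoted in this review, not the actual functions in the codebase.

```python
def parse_strict(value):
    """Mirror bool(int(...)): accepts integer strings only."""
    return bool(int(value))


def parse_lenient(value):
    """Mirror the .lower() membership check: accepts 1/true/yes."""
    return value.lower() in ("1", "true", "yes")


for raw in ("1", "0", "true", "2"):
    try:
        strict = parse_strict(raw)
    except ValueError as exc:
        # int() rejects non-numeric strings such as "true".
        strict = f"ValueError: {exc}"
    print(f"{raw!r}: strict={strict}, lenient={parse_lenient(raw)}")
```

    Besides the ValueError on "true", the two parsers also disagree on values such as "2", which the int() cast treats as enabled and the membership check treats as disabled.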

4 files reviewed, 1 comment

@tdophung tdophung force-pushed the teddy/pytorch-triton-contention branch from e2773e1 to 469ed68 Compare January 2, 2026 18:02
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. build_tools/pytorch.py, line 35 (link)

    style: Listing pytorch-triton here won't work with standard pip install from PyPI since the real package is only available from PyTorch's package index. Users running pip install . without specifying --index-url will get the placeholder package that raises "Should never be installed".

    Consider adding an installation note to the README or setup.py, or documenting that this package requires:

    pip install --index-url https://download.pytorch.org/whl/cu121 pytorch-triton
    

4 files reviewed, 1 comment

Signed-off-by: tdophung <tdophung@nvidia.com>
Collaborator

@zhongbozhu zhongbozhu left a comment

Can we merge this fix soon?

@tdophung tdophung merged commit 5f828c2 into NVIDIA:main Jan 8, 2026
13 checks passed
3 participants