Solve pytorch-triton and triton package contention #2540
| num_ctas, # arg2: num_ctas (int) | ||
| compiled.metadata.shared, # arg3: shared_mem_bytes (int) | ||
| compiled.asm["ptx"], # arg4: ptx (str) | ||
| "", # arg5: ttir (str) - empty |
This will soon match main, since the same change is made in #1921, which is about to be merged. It is included in this PR only so that I can test Triton calls locally with the nightly JAX container without hitting errors caused by JAX 0.8.2+.
Greptile Summary
This PR resolves the package contention between `pytorch-triton` and `triton`.
Key changes:
Impact: Users running mixed JAX+PyTorch environments can now properly configure triton package selection. The detection logic warns users if they accidentally install the broken placeholder package.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant SetupPy as setup.py
participant BuildTools as build_tools/*.py
participant TritonUtils as triton_extensions/utils.py
participant Triton as triton package
User->>SetupPy: pip install (PyTorch)
SetupPy->>BuildTools: install_requirements()
BuildTools-->>SetupPy: ["pytorch-triton", ...]
Note over SetupPy: Requires PyTorch index
User->>SetupPy: pip install (JAX)
SetupPy->>BuildTools: test_requirements()
alt NVTE_USE_PYTORCH_TRITON=1
BuildTools-->>SetupPy: ["pytorch-triton"]
else Default
BuildTools-->>SetupPy: ["triton"]
end
User->>TritonUtils: import triton_extensions
TritonUtils->>Triton: import triton
TritonUtils->>TritonUtils: _detect_triton_package()
alt Placeholder package (0.0.1)
TritonUtils-->>User: ImportError with fix instructions
else pytorch-triton detected
alt NVTE_USE_PYTORCH_TRITON=1
TritonUtils-->>User: Silent (acknowledged)
else Not acknowledged
TritonUtils-->>User: UserWarning
end
else Standard triton
TritonUtils-->>User: Normal operation
end
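The branching in the diagram can be sketched in Python. This is a minimal sketch under stated assumptions, not the PR's actual `_detect_triton_package()`: the names `installed_triton_versions` and `classify_triton`, the `installed` mapping argument, the message wording, and the returned labels are all hypothetical. Only the placeholder version `0.0.1`, the `NVTE_USE_PYTORCH_TRITON` acknowledgement, and the three outcomes are taken from the diagram.

```python
import warnings
from importlib import metadata

PLACEHOLDER_VERSION = "0.0.1"  # placeholder version named in the diagram


def installed_triton_versions():
    """Return {distribution name: version} for the triton distributions found."""
    found = {}
    for name in ("triton", "pytorch-triton"):
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # this distribution is not installed
    return found


def classify_triton(installed, acknowledged):
    """Hypothetical helper mirroring the diagram's three branches."""
    version = installed.get("pytorch-triton")
    if version == PLACEHOLDER_VERSION:
        # Placeholder package: fail loudly with fix instructions.
        raise ImportError(
            "Placeholder pytorch-triton package detected; install the real "
            "package from the PyTorch index, or install triton instead."
        )
    if version is not None:
        # Real pytorch-triton: warn unless the user has acknowledged it.
        if not acknowledged:
            warnings.warn(
                "pytorch-triton detected; set NVTE_USE_PYTORCH_TRITON=1 "
                "to acknowledge.",
                UserWarning,
            )
        return "pytorch-triton"
    # Standard triton: normal operation.
    return "triton"
```

A caller would pass `installed_triton_versions()` as the first argument and the parsed value of `NVTE_USE_PYTORCH_TRITON` as the second. Taking the mapping as a parameter keeps the branching testable without installing either package.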
Additional Comments (1)
- transformer_engine/jax/triton_extensions/utils.py, line 322
  syntax: Typo: `compile.name` should be `compiled.name`. Only the local variable `compiled` (from line 300) exists in this scope; `compile` resolves to Python's builtin function, which has no `name` attribute, so this line will fail at runtime for JAX versions < 0.8.2.
4 files reviewed, 1 comment
Additional Comments (1)
- build_tools/pytorch.py, line 21
  style: Placeholder text `<version??>` should be replaced with an actual version (e.g., `cu121` or `cu124`) or made generic.
4 files reviewed, 1 comment
build_tools/jax.py
Outdated
| use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( | ||
| "1", | ||
| "true", | ||
| "yes", | ||
| ) |
| use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( | |
| "1", | |
| "true", | |
| "yes", | |
| ) | |
| use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0"))) |
It's better to be specific with envvars and follow convention from rest of the codebase.
| "pydantic", | ||
| "nvdlfw-inspect", | ||
| "triton", | ||
| "pytorch-triton", |
If pytorch-triton from PyPI is actually a placeholder, then we shouldn't list it here as a dependency
pytorch-triton should be the default for transformerengine-pytorch, not just a placeholder. It should always be used when the PyTorch framework is used.
triton is the default for JAX, except in two scenarios:
- Both JAX and PyTorch are installed, and TE PyTorch is used to call the Triton kernels.
- The user specifies NVTE_USE_PYTORCH_TRITON=1 while using TE JAX, to make sure there is no performance difference between the Triton versions used by PyTorch and JAX.
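The selection rule described in this comment could be sketched as follows. This is only an illustration of the stated policy, not the code in `build_tools/`: the function name `triton_requirement` and its parameters are hypothetical.

```python
def triton_requirement(framework, use_pytorch_triton=False):
    """Pick the triton distribution for a framework build (hypothetical helper).

    framework: "pytorch" or "jax".
    use_pytorch_triton: parsed value of NVTE_USE_PYTORCH_TRITON; it only
        matters for JAX builds.
    """
    if framework == "pytorch":
        # TE PyTorch always uses pytorch-triton, including in mixed
        # JAX+PyTorch environments that call the kernels through TE PyTorch.
        return "pytorch-triton"
    if framework == "jax":
        # TE JAX defaults to triton unless the user opts in explicitly.
        return "pytorch-triton" if use_pytorch_triton else "triton"
    raise ValueError(f"unknown framework: {framework!r}")
```

Keeping the rule in one small function would let the build scripts for both frameworks share the same policy instead of encoding it twice.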
…for jax Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
Additional Comments (1)
- build_tools/jax.py, line 37
  logic: Inconsistent env var parsing: `utils.py` accepts `"1"`, `"true"`, `"yes"` (via a `.lower()` check), but this line only accepts `"0"` or `"1"` (via an `int()` cast). If the user sets `NVTE_USE_PYTORCH_TRITON=true`, this will raise `ValueError: invalid literal for int() with base 10: 'true'`.
4 files reviewed, 1 comment
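The mismatch between the two parsing conventions discussed in this thread can be reproduced in isolation. The sketch below is an assumption-laden illustration: the helper names `parse_flag_lenient` and `parse_flag_int` are hypothetical, and each wraps only the expression quoted in the review, applied to a plain string rather than to `os.environ`.

```python
def parse_flag_lenient(value):
    """Membership-style check: only "1", "true", "yes" (any case) are truthy."""
    return value.lower() in ("1", "true", "yes")


def parse_flag_int(value):
    """Integer-cast convention: the value must be an integer string."""
    return bool(int(value))


if __name__ == "__main__":
    for raw in ("0", "1", "true"):
        try:
            strict = parse_flag_int(raw)
        except ValueError as exc:
            strict = f"ValueError: {exc}"
        print(f"{raw!r}: lenient={parse_flag_lenient(raw)}, int-cast={strict}")
```

Both helpers agree on `"0"` and `"1"`, so settling on one convention for build time and import time would avoid the case where a value accepted by one path crashes the other.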
e2773e1 to
469ed68
Compare
Additional Comments (1)
- build_tools/pytorch.py, line 35
  style: Listing `pytorch-triton` here won't work with a standard `pip install` from PyPI, since the real package is only available from PyTorch's package index. Users running `pip install .` without specifying `--index-url` will get the placeholder package that raises "Should never be installed".
  Consider adding an installation note to the README or setup.py, or documenting that this package requires:
  pip install --index-url https://download.pytorch.org/whl/cu121 pytorch-triton
4 files reviewed, 1 comment
Signed-off-by: tdophung <tdophung@nvidia.com>
zhongbozhu left a comment
Can we merge this fix soon?
Description
The `pytorch-triton` and `triton` packages install to the same location, site-packages/triton. `triton` does not work with PyTorch's torch.compile() call, because PyTorch has added a few things to its own version of Triton (published as `pytorch-triton`, which is validated with each torch release). However, `pytorch-triton` should in theory (and in experiments) still be compatible with how JAX uses it*.
Fixes # (issue)
Type of change
Changes
Checklist: