-
Notifications
You must be signed in to change notification settings - Fork 99
Optimize aten::min/max.dim with TopK op #2780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@microsoft-github-policy-service agree |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2780 +/- ##
==========================================
+ Coverage 70.45% 70.47% +0.02%
==========================================
Files 228 228
Lines 27258 27264 +6
Branches 2761 2763 +2
==========================================
+ Hits 19204 19214 +10
+ Misses 7102 7100 -2
+ Partials 952 950 -2 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.
Thanks so much for the review! That is a great point, I took some time to dig into the ONNX Runtime implementations to see how they handle this.
So to the best of my knowledge, TopK might brings more instruction overhead but with less IO. I would appreciate your thoughts here—which approach aligns more with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation. |
|
I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass) |
Fix pytorch/pytorch#76344
Context
As mentioned in the issue,
torch.max(dim=...)can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., ONNX Runtime CUDA EP microsoft/onnxruntime#11348).In additional, given the
torch.min(dim=...)has the similar pattern with max, I also apply this optimization to it.Verification
Successfully passed existing OpInfo consistency tests: