
Conversation

@danielhumanmod

Fix pytorch/pytorch#76344

Context

As mentioned in the issue, the export of torch.max(dim=...) can be optimized by replacing the current ReduceMax and ArgMax implementation with a single TopK. This reduces redundant scans of the input and avoids potential performance overhead in certain execution providers (e.g., the ONNX Runtime CUDA EP, microsoft/onnxruntime#11348).

In addition, since torch.min(dim=...) follows the same pattern as max, I applied the same optimization to it.
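
To illustrate the equivalence the change relies on, here is a minimal numpy sketch (not the PR's code; the argsort-based function is only a stand-in for ONNX TopK semantics): a single TopK with k = 1 along dim yields, in one pass, both results that ReduceMax and ArgMax compute in two separate scans.

```python
import numpy as np

def max_dim_split(x: np.ndarray, dim: int):
    # ReduceMax + ArgMax: two independent scans over the input.
    return x.max(axis=dim, keepdims=True), x.argmax(axis=dim, keepdims=True)

def max_dim_topk(x: np.ndarray, dim: int):
    # TopK(k=1) semantics: one pass yields values and indices together.
    # A stable sort makes ties resolve to the first occurrence, like ArgMax.
    idx = np.argsort(-x, axis=dim, kind="stable").take([0], axis=dim)
    return np.take_along_axis(x, idx, axis=dim), idx

x = np.random.rand(4, 7).astype(np.float32)
assert all(np.array_equal(a, b)
           for a, b in zip(max_dim_split(x, 1), max_dim_topk(x, 1)))
```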

Verification

The change passes the existing OpInfo consistency tests:

  • pytest tests/function_libs/torch_lib/ops_test.py
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py

@danielhumanmod
Author

@microsoft-github-policy-service agree

@codecov

codecov bot commented Jan 25, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.47%. Comparing base (e06dd92) to head (39baa38).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2780      +/-   ##
==========================================
+ Coverage   70.45%   70.47%   +0.02%     
==========================================
  Files         228      228              
  Lines       27258    27264       +6     
  Branches     2761     2763       +2     
==========================================
+ Hits        19204    19214      +10     
+ Misses       7102     7100       -2     
+ Partials      952      950       -2     


@justinchuby
Collaborator

Thanks for creating the PR. Reading it again, it seems like TopK is more general than ReduceMax and ArgMax. From a node-count perspective this may be fewer nodes, but I wonder if the original form is easier to optimize.

@github-project-automation github-project-automation bot moved this from Todo to In Progress in ONNX Script Review Board Jan 25, 2026
@danielhumanmod
Author

Thanks for creating the PR. Reading it again, it seems like TopK is more general than ReduceMax and ArgMax. From a node-count perspective this may be fewer nodes, but I wonder if the original form is easier to optimize.

Thanks so much for the review! That is a great point. I took some time to dig into the ONNX Runtime implementations to see how they handle this.

  1. From the ONNX Runtime perspective:

    1. The CPU EP provides a fast path when k = 1 that performs a simple linear scan, so on CPU it should behave identically to a fused max+argmax.
    2. The CUDA EP goes through the full bitonic/radix sort process, which involves more complex instructions, but the upside is that those operations happen primarily in shared memory.
  2. PyTorch Inductor (as a reference): it takes a similar approach, splitting into reduce_max/arg_max in its IR, but leaves it to the runtime (Scheduler) to fuse them. However, when I checked ONNX Runtime, it did not seem to have an optimization rule that automatically fuses ReduceMax and ArgMax, which implies the split approach effectively incurs one extra I/O pass compared to TopK.

So to the best of my knowledge, TopK may bring more instruction overhead but less I/O. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am also happy to pivot to other tasks if we want to keep the original implementation.
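
For anyone who wants to check the trade-off concretely, here is a rough microbenchmark sketch (not part of the PR) that builds both graph shapes with the onnx helper API and times them with ONNX Runtime on CPU; the shape, opset, and iteration count are arbitrary, and it assumes compatible onnx/onnxruntime versions.

```python
import time

import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper

SHAPE = [1024, 4096]

def make_split_model() -> onnx.ModelProto:
    # ReduceMax + ArgMax over axis 1: two scans of the input.
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, SHAPE)
    vals = helper.make_tensor_value_info("vals", TensorProto.FLOAT, [SHAPE[0], 1])
    idx = helper.make_tensor_value_info("idx", TensorProto.INT64, [SHAPE[0], 1])
    axes = helper.make_tensor("axes", TensorProto.INT64, [1], [1])
    nodes = [
        helper.make_node("ReduceMax", ["x", "axes"], ["vals"], keepdims=1),
        helper.make_node("ArgMax", ["x"], ["idx"], axis=1, keepdims=1),
    ]
    graph = helper.make_graph(nodes, "split", [x], [vals, idx], initializer=[axes])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

def make_topk_model() -> onnx.ModelProto:
    # A single TopK(k=1) pass producing both outputs.
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, SHAPE)
    vals = helper.make_tensor_value_info("vals", TensorProto.FLOAT, [SHAPE[0], 1])
    idx = helper.make_tensor_value_info("idx", TensorProto.INT64, [SHAPE[0], 1])
    k = helper.make_tensor("k", TensorProto.INT64, [1], [1])
    node = helper.make_node("TopK", ["x", "k"], ["vals", "idx"], axis=1, largest=1)
    graph = helper.make_graph([node], "topk", [x], [vals, idx], initializer=[k])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

x = np.random.rand(*SHAPE).astype(np.float32)
for make in (make_split_model, make_topk_model):
    sess = ort.InferenceSession(make().SerializeToString(),
                                providers=["CPUExecutionProvider"])
    start = time.perf_counter()
    for _ in range(100):
        sess.run(None, {"x": x})
    print(make.__name__, f"{time.perf_counter() - start:.3f}s")
```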

@justinchuby
Collaborator

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused, computing it would be wasted effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules? This way we can do the fusion only when both outputs are used (if not, the second output will be removed by the dead code elimination pass).
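
For reference, a rough sketch of what such a rule could look like with the pattern-based rewriter. This is a minimal sketch only: the multi-output spelling (`_outputs=2`), the constant-k construction, and the elided attribute matching are assumptions to verify against the current rewriter API, not a finished rule.

```python
# Illustrative sketch; verify spellings against the current rewriter API.
from onnxscript.rewriter import pattern

def split_max(op, x, axes):
    # Target: a ReduceMax and an ArgMax reducing the same input.
    # (Checking that `axes` and the ArgMax axis agree is elided here.)
    values = op.ReduceMax(x, axes, keepdims=1)
    indices = op.ArgMax(x, keepdims=1)
    return values, indices

def fused_topk(op, x, axes):
    # Replacement: one TopK(k=1) pass produces both outputs.
    k = op.Constant(value_ints=[1])
    values, indices = op.TopK(x, k, largest=1, _outputs=2)  # `_outputs` assumed
    return values, indices

rule = pattern.RewriteRule(split_max, fused_topk)
```

Because the rule only matches when both reductions are present, a graph where the indices are never consumed keeps its plain ReduceMax (the unused ArgMax having been removed by dead code elimination), which addresses the wasted-effort concern above.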
