Fix Neuron engine for SDK 2.28+ and add Trainium (trn2) support #659

Open
jimburtoft wants to merge 5 commits into michaelfeil:main from jimburtoft:fix/neuron-sdk-2.28-compat

Conversation

@jimburtoft

@jimburtoft jimburtoft commented Mar 2, 2026

Related Issue

closes #408

Summary

Fixes three bugs that prevented --engine neuron from working with Neuron SDK 2.28 / optimum >= 2.0, switches from tensor parallelism to data parallelism for a 2-4x throughput improvement, and adds Trainium (trn2) support with benchmarks.

Bug Fixes

  1. acceleration.py: optimum.bettertransformer was removed in optimum >= 2.0. The unconditional import crashes at module load time, breaking both --engine torch and --engine neuron. Fixed with a try/except wrapper that marks the optional import as dirty.
  2. acceleration.py: check_if_bettertransformer_possible() calls BetterTransformerManager without checking if the import succeeded. Added a CHECK_OPTIMUM.is_available guard.
  3. neuron.py: Tokenizer was called with return_token_type_ids=False, but the compiled Neuron model expects 3 inputs (including token_type_ids). Removing that flag fixes the input mismatch.

Performance Improvement

Changed num_cores from get_nc_count() (tensor parallelism — shards model across all cores) to 1 (data parallelism — one model per core, multiple server processes). For small/medium embedding models, tensor parallelism wastes cores.
To scale, users run one infinity_emb process per NeuronCore, each pinned via NEURON_RT_VISIBLE_CORES. Throughput scales linearly.
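The per-core deployment described above looks roughly like this on a two-core inf2.xlarge. Ports, model id, and exact CLI flags are illustrative, not taken from the PR; a load balancer in front of the ports distributes requests.

```shell
# Hypothetical data-parallel launch: one infinity_emb server per NeuronCore,
# each process pinned to a single core via NEURON_RT_VISIBLE_CORES.
NEURON_RT_VISIBLE_CORES=0 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --port 7997 &
NEURON_RT_VISIBLE_CORES=1 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --port 7998 &
wait  # keep both background servers running
```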

Benchmarks (bge-small-en-v1.5, batch_size=4)

Latency (serial requests, P50)

| Workload | g5.xlarge (GPU) | inf2.xlarge (1 core) | trn2.3xlarge (1 core) |
|----------|-----------------|----------------------|-----------------------|
| 1 short sentence | 14.2ms | 25.0ms | 19.0ms |
| 4 short sentences | 16.0ms | 25.6ms | 19.5ms |
| 4 long sentences | 16.2ms | 26.0ms | 20.3ms |

Throughput (data parallelism)

| Instance | Cores | Peak emb/s |
|----------|-------|------------|
| g5.xlarge (GPU) | 1 GPU | 536 |
| inf2.xlarge | 1 core | 216 |
| inf2.xlarge | 2 cores | 427 |
| trn2.3xlarge | 1 core | 348 |
| trn2.3xlarge | 4 cores | 753 |

trn2 has ~30% lower per-core latency than inf2.
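As a quick sanity check, the multi-core speedups implied by the table can be recomputed from its own emb/s figures:

```python
# Multi-core speedup over a single core, using the throughput numbers above.
rows = {
    "inf2.xlarge": {1: 216, 2: 427},
    "trn2.3xlarge": {1: 348, 4: 753},
}
for inst, by_cores in rows.items():
    base = by_cores[1]
    for cores, emb_s in sorted(by_cores.items()):
        if cores > 1:
            # inf2 scales near-linearly (~1.98x on 2 cores)
            print(f"{inst}: {cores} cores -> {emb_s / base:.2f}x one core")
```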

Files Changed

  • libs/infinity_emb/infinity_emb/transformer/acceleration.py — bettertransformer import fix + CHECK_OPTIMUM guard
  • libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py — remove return_token_type_ids=False, change num_cores to 1
  • infra/aws_neuron/Dockerfile.neuron — remove pinned SDK versions
  • infra/aws_neuron/README.md — rewritten with setup guide, data parallelism instructions, trn2 examples, and benchmarks
Tested on inf2.xlarge (SDK 2.27), trn2.3xlarge (SDK 2.28), and g5.xlarge (GPU baseline).

Checklist

  • [x] I have read the CONTRIBUTING guidelines.
  • [ ] I have added tests to cover my changes.
  • [x] I have updated the documentation (docs folder) accordingly.

Additional Notes

Tested independently, but I'm not sure whether you have a build path for testing on Inferentia.

jimburtoft and others added 4 commits February 28, 2026 17:52
- Handle removal of optimum.bettertransformer in optimum >= 2.0 by wrapping
  the import in try/except (acceleration.py)
- Include token_type_ids in Neuron tokenizer output to match compiled model
  expectations (neuron.py)
- Update Dockerfile.neuron to remove pinned SDK versions
- Rewrite README.md with tested instructions for AWS DLAMI setup

Tested on inf2.xlarge with Neuron SDK 2.28, optimum-neuron 0.4.3 and 0.4.5.
Performance: ~210 embeddings/sec throughput, ~25ms latency (bge-small-en-v1.5).
- Fix NameError when BetterTransformerManager is not available in
  check_if_bettertransformer_possible() (affects both torch and neuron
  engines on optimum >= 2.0)
- Update README with benchmark results across g5, inf2, trn2 instances
Compile models for a single NeuronCore (num_cores=1) instead of
sharding across all cores. This gives 5% better single-core
performance and enables linear scaling via multiple processes.

With 2 processes on inf2.xlarge (2 cores):
- Before (tensor parallel): 206 emb/s
- After (data parallel):    425 emb/s  (+106%)

Update README with multi-process deployment instructions and
revised benchmark results.
- Add trn2.3xlarge (4 NeuronCores) benchmark results: 19ms latency, 753 emb/s
- Add trn2 data parallelism example (4 processes on 4 cores)
- Restructure throughput table to show scaling across instance types
- Note: trn2 has ~30% lower per-core latency than inf2

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Neuron engine's functionality and performance within Infinity, ensuring compatibility with the latest Neuron SDKs and Optimum versions. It addresses critical bugs that previously prevented the engine from operating correctly and introduces a more efficient data parallelism strategy for scaling inference across AWS Inferentia and Trainium hardware. The changes also expand hardware support to include Trainium (trn2), providing comprehensive documentation and benchmarks for optimal deployment.

Highlights

  • Neuron Engine Compatibility Fixes: Fixed compatibility issues with Neuron SDK 2.28+ and Optimum >= 2.0 by implementing a try/except block for optimum.bettertransformer imports and adding a CHECK_OPTIMUM.is_available guard.
  • Tokenizer Input Mismatch Resolution: Resolved a tokenizer input mismatch for Neuron models by removing return_token_type_ids=False, which was causing issues with compiled Neuron models expecting three inputs.
  • Performance Improvement via Data Parallelism: Improved Neuron engine performance by switching the model compilation strategy from tensor parallelism (get_nc_count()) to data parallelism (single core), enabling linear throughput scaling across multiple NeuronCores.
  • Trainium (trn2) Support and Documentation: Added official support for AWS Trainium (trn2) instances, including comprehensive setup guides, data parallelism instructions, and detailed performance benchmarks in the README.md.
Changelog
  • infra/aws_neuron/Dockerfile.neuron
    • Removed specific version pinning for sentence_transformers, neuronx-cc, optimum-neuron, and optimum to allow for newer versions.
    • Added comments recommending the use of AWS Deep Learning AMI Neuron directly.
  • infra/aws_neuron/README.md
    • Completely rewrote the guide for running Infinity on AWS Inferentia/Trainium.
    • Added detailed instructions for using the HuggingFace Neuron AMI without Docker.
    • Provided explicit steps for installing Infinity and its dependencies.
    • Included guidance on scaling across all NeuronCores using data parallelism with NEURON_RT_VISIBLE_CORES.
    • Added comprehensive performance benchmarks for latency and throughput on GPU, Inferentia, and Trainium instances.
    • Listed the tested software stack versions.
    • Updated the Docker build and run instructions.
    • Added sections on limitations and ECS deployment.
    • Simplified the entryPoint and permissions arrays in the ECS task definition example.
    • Removed logging configuration example from ECS task definition.
  • libs/infinity_emb/infinity_emb/transformer/acceleration.py
    • Wrapped the optimum.bettertransformer import in a try/except block to prevent crashes when optimum >= 2.0 is used, marking CHECK_OPTIMUM as dirty on failure.
    • Added a CHECK_OPTIMUM.is_available guard before calling BetterTransformerManager in check_if_bettertransformer_possible() to prevent errors if BetterTransformer is not available.
  • libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py
    • Changed num_cores for Neuron model compilation from get_nc_count() to 1, enabling data parallelism for better throughput.
    • Removed return_token_type_ids=False from the tokenizer call, resolving an input mismatch issue for compiled Neuron models.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces important fixes for running on AWS Neuron with SDK 2.28+ and adds support for Trainium instances. The changes correctly handle the removal of bettertransformer in recent optimum versions, fix an input mismatch for Neuron models, and improve throughput by switching to data parallelism. The documentation and Dockerfile updates are also valuable additions. My review includes a few suggestions: one to improve the Dockerfile build process, a minor correction in the new documentation, and a more robust fix for checking an optional dependency to prevent a potential runtime error.

Comment on lines +49 to +50
    if not CHECK_OPTIMUM.is_available:
        return False


high

The added check if not CHECK_OPTIMUM.is_available: might not be sufficient to prevent a NameError. The is_available property on OptionalImports is a cached_property. It will return True if optimum was available at first check, even if the subsequent import of bettertransformer fails and mark_dirty is called. This means check_if_bettertransformer_possible will proceed and fail when accessing BetterTransformerManager.

A more robust way to check if BetterTransformerManager was imported successfully is to check for its presence in the module's scope.

Suggested change — replace:

    if not CHECK_OPTIMUM.is_available:
        return False

with:

    if "BetterTransformerManager" not in globals():
        return False
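The cached_property behavior described above can be reproduced in isolation. The class below is a simplified stand-in for OptionalImports, not the real implementation:

```python
from functools import cached_property

class OptionalImports:
    """Simplified stand-in showing why mark_dirty cannot flip is_available."""

    def __init__(self):
        self._marked_as_dirty = False

    @cached_property
    def is_available(self) -> bool:
        # Computed once on first access, then cached on the instance.
        return True

    def mark_dirty(self):
        # Sets a flag but does NOT invalidate the cached property.
        self._marked_as_dirty = True

check = OptionalImports()
assert check.is_available       # first access caches True
check.mark_dirty()
assert check.is_available       # still True: the guard alone misses the failure
```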

Comment on lines +16 to +17
RUN pip3 install --no-deps sentence_transformers
RUN pip3 install --upgrade neuronx-cc torch-neuronx torchvision libneuronxla protobuf optimum-neuron optimum


medium

To reduce the number of layers in the Docker image and improve build efficiency, it's a good practice to combine consecutive RUN commands. You can merge the two pip3 install commands into a single RUN instruction. Also, consider using the --no-cache-dir option with pip to prevent caching and reduce the final image size.

RUN pip3 install --no-cache-dir --no-deps sentence_transformers && \
    pip3 install --no-cache-dir --upgrade neuronx-cc torch-neuronx torchvision libneuronxla protobuf optimum-neuron optimum

Comment on lines +85 to +91
| Instance | Cores | Peak emb/s | Concurrency |
|----------|-------|-----------|-------------|
| g5.xlarge (GPU) | 1 GPU | 536 | 8 concurrent |
| inf2.xlarge | 1 core | 216 | 4 concurrent |
| inf2.xlarge | 2 cores | 427 | 4 concurrent |
| trn2.3xlarge | 1 core | 348 | 4 concurrent |
| trn2.3xlarge | 4 cores | 753 | 4 concurrent |


medium

The Concurrency column in the throughput table contains redundant text (concurrent). For better clarity and consistency, it would be better to have just the number, similar to other columns. This makes the table easier to parse for humans and potentially for automated tools.

Suggested change (numbers only in the Concurrency column):

| Instance | Cores | Peak emb/s | Concurrency |
|----------|-------|-----------|-------------|
| g5.xlarge (GPU) | 1 GPU | 536 | 8 |
| inf2.xlarge | 1 core | 216 | 4 |
| inf2.xlarge | 2 cores | 427 | 4 |
| trn2.3xlarge | 1 core | 348 | 4 |
| trn2.3xlarge | 4 cores | 753 | 4 |

@greptile-apps

greptile-apps bot commented Mar 2, 2026

Greptile Summary

This PR fixes three bugs that prevented --engine neuron from working with Neuron SDK 2.28 / optimum >= 2.0, switches the NeuronCore allocation strategy from tensor parallelism to data parallelism, and adds Trainium (trn2) support with benchmarks and updated documentation.

Key changes:

  • acceleration.py: The unconditional from optimum.bettertransformer import is now wrapped in a try/except for optimum >= 2.0 compatibility. However, the new availability guard in check_if_bettertransformer_possible only checks whether the top-level optimum package exists — it does not check whether the import actually succeeded. When optimum >= 2.0 is installed, the function still reaches BetterTransformerManager.MODEL_MAPPING and raises a NameError.
  • neuron.py: The tokenizer call no longer suppresses token_type_ids, which is the correct fix for models compiled with 3 inputs. The number of NeuronCores is now hard-coded to 1 and get_nc_count is now dead code. The core-count change is a silent breaking change for large models that require tensor parallelism across multiple NeuronCores.
  • Dockerfile.neuron: All SDK version pins have been removed, which improves forward compatibility at the cost of build reproducibility.
  • README.md: Thoroughly rewritten with data-parallelism setup guide, trn2 benchmarks, and a tested stack table.

Confidence Score: 3/5

  • The PR is mostly safe to merge but contains one logic bug in acceleration.py that can cause a NameError for --engine torch users on optimum >= 2.0 with bettertransformer=True.
  • The tokenizer fix and data-parallelism switch are sound. The acceleration.py guard is incomplete: CHECK_OPTIMUM.is_available returns True for any optimum installation, so marking the import dirty does not prevent BetterTransformerManager.MODEL_MAPPING from being accessed when the submodule is missing. This is a concrete runtime error for a realistic configuration. The num_cores=1 hard-code is a documented behavioral change but could silently break large-model deployments without a clear escape hatch. Dockerfile unpinning is a minor concern.
  • libs/infinity_emb/infinity_emb/transformer/acceleration.py — the check_if_bettertransformer_possible guard must also test CHECK_OPTIMUM._marked_as_dirty.

Important Files Changed

Filename Overview
libs/infinity_emb/infinity_emb/transformer/acceleration.py Adds try/except around the optimum.bettertransformer import for optimum >= 2.0 compatibility, but the new CHECK_OPTIMUM.is_available guard in check_if_bettertransformer_possible is insufficient — it does not check _marked_as_dirty, so a NameError on BetterTransformerManager is still possible when optimum >= 2.0 is installed.
infra/aws_neuron/Dockerfile.neuron Removes all pinned SDK version numbers from pip install commands, improving forward compatibility but sacrificing build reproducibility; adds a missing trailing newline and clarifying comments.
infra/aws_neuron/README.md Documentation rewrite; adds data-parallelism setup guide, trn2 benchmarks, and tested stack table. No code issues.
libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py Two changes: tokenizer no longer suppresses token_type_ids (correct for models compiled expecting 3 inputs), and the number of NeuronCores is hard-coded to 1; get_nc_count is now dead code, and the core-count change may silently break large models that require cross-core tensor parallelism.

Sequence Diagram

sequenceDiagram
    participant Client
    participant Server as infinity_emb (one process per core)
    participant Embedder as NeuronOptimumEmbedder
    participant Tok as AutoTokenizer
    participant Model as NeuronModelForFeatureExtraction

    Client->>Server: POST /embeddings
    Server->>Embedder: encode_pre(sentences)
    Embedder->>Tok: tokenize — returns input_ids, attention_mask, token_type_ids
    Tok-->>Embedder: input_dict
    Embedder-->>Server: input_dict

    Server->>Embedder: encode_core(input_dict)
    Embedder->>Embedder: pad_up_to_size(batch_size, each tensor)
    Embedder->>Model: forward(**input_dict)
    Model-->>Embedder: last_hidden_state
    Embedder-->>Server: token_embeddings + attention_mask

    Server->>Embedder: encode_post(embedding)
    Embedder->>Embedder: pooling + normalize
    Embedder-->>Server: float32 embeddings
    Server-->>Client: response

    Note over Client,Server: One server process per NeuronCore (data parallelism)

Last reviewed commit: d9fe9f0


@greptile-apps bot left a comment


4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

# Compile for a single NeuronCore. For data-parallel scaling across
# multiple cores, run separate server processes pinned to individual
# cores via NEURON_RT_VISIBLE_CORES (see README).
compiler_args = {"num_cores": 1, "auto_cast_type": "fp16"}


Hard-coded num_cores=1 silently breaks models requiring tensor parallelism

Setting num_cores=1 works well for small/medium embedding models (as documented in the README), but it will silently produce a worse result — or fail to compile entirely — for large models that don't fit within a single NeuronCore's on-chip memory. The previous get_nc_count() behavior at least used all available cores for a single model instance.

Consider making this configurable (e.g. via an EngineArgs field or an environment variable) so that users with large models can opt back in to tensor parallelism without changing source code:

compiler_args = {
    "num_cores": int(os.environ.get("NEURON_NUM_CORES", "1")),
    "auto_cast_type": "fp16",
}

At a minimum, it would be worth adding a log warning when num_cores=1 is used on an instance with more than one NeuronCore, so operators are aware they are under-utilising the chip with a single process.

Comment on lines +16 to +17
RUN pip3 install --no-deps sentence_transformers
RUN pip3 install --upgrade neuronx-cc torch-neuronx torchvision libneuronxla protobuf optimum-neuron optimum


Unpinned dependency versions may cause non-reproducible builds

Removing the version pins for neuronx-cc, torch-neuronx, libneuronxla, optimum-neuron, and optimum means every docker build will pull the latest available versions. While this avoids the old hard-coded versions going stale, it can cause silent breakage whenever a new SDK release ships an incompatible change.

Consider at minimum pinning major versions (e.g. optimum-neuron>=0.4,<1.0) or adding a comment documenting the tested version matrix from the README so that the file serves as a reproducible baseline.

@greptile-apps

greptile-apps bot commented Mar 2, 2026

Additional Comments (2)

libs/infinity_emb/infinity_emb/transformer/acceleration.py
Incomplete guard allows NameError when optimum >= 2.0 is installed

The new guard on lines 49–50 only checks whether the top-level optimum package is importable. When optimum >= 2.0 is installed, CHECK_OPTIMUM.is_available evaluates to True (because importlib.util.find_spec("optimum") succeeds), but the from optimum.bettertransformer import … block above already failed and called CHECK_OPTIMUM.mark_dirty(…). Critically, mark_dirty does not flip is_available — it only sets _marked_as_dirty.

So the guard passes, execution reaches line 58, and a NameError: name 'BetterTransformerManager' is not defined is raised, crashing any caller that has bettertransformer=True in their EngineArgs when using optimum >= 2.0.

The guard needs to also test the dirty flag:

    if not CHECK_OPTIMUM.is_available or CHECK_OPTIMUM._marked_as_dirty:
        return False

The same issue affects the to_bettertransformer function further down, but that function calls CHECK_OPTIMUM.mark_required() (which does check _marked_as_dirty) before reaching BetterTransformer.transform, so it raises a clean ImportError rather than a NameError — only check_if_bettertransformer_possible is left unprotected.


libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py
get_nc_count is now dead code

After changing num_cores from get_nc_count() to the hard-coded 1, the get_nc_count function is no longer called anywhere in this file. It should be removed to avoid confusion, since it also runs a subprocess (neuron-ls) and produces print output when called.

(Remove the entire get_nc_count function, lines 33–46.)

…_NUM_CORES env var

- acceleration.py: Use globals() check instead of CHECK_OPTIMUM.is_available
  to correctly detect when bettertransformer import failed (mark_dirty does
  not invalidate the cached is_available property)
- neuron.py: Add NEURON_NUM_CORES env var (default 1) so large models can
  opt in to tensor parallelism without source changes
- neuron.py: Remove dead get_nc_count() function and unused imports
- Dockerfile: Merge RUN layers, add --no-cache-dir, add tested version comment
- README: Clean up redundant text in throughput table

@michaelfeil left a comment


Did you test this out? Do you need help publishing an image or release version?

@jimburtoft

Yes, tested on inf2.xlarge (SDK 2.27), trn2.3xlarge (SDK 2.28), and g5.xlarge (GPU). Fresh clone → install → run → embedding requests all working. Data parallelism gives 427 emb/s on inf2 (vs 206 old) and 750 emb/s on trn2 with 4 processes.
All README benchmarks are from actual runs.
A new release would be good. We probably don't need a new image at this point.


Development

Successfully merging this pull request may close these issues.

[Doc] Documentation on how to run infinity on AWS Inf2

2 participants