From 1c9e7c52058f7052aa8f943bad5d23cb4e150e25 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sat, 28 Feb 2026 01:38:03 -0500 Subject: [PATCH 1/5] Fix Neuron engine compatibility with optimum >= 2.0 and Neuron SDK 2.28 - Handle removal of optimum.bettertransformer in optimum >= 2.0 by wrapping the import in try/except (acceleration.py) - Include token_type_ids in Neuron tokenizer output to match compiled model expectations (neuron.py) - Update Dockerfile.neuron to remove pinned SDK versions - Rewrite README.md with tested instructions for AWS DLAMI setup Tested on inf2.xlarge with Neuron SDK 2.28, optimum-neuron 0.4.3 and 0.4.5. Performance: ~210 embeddings/sec throughput, ~25ms latency (bge-small-en-v1.5). --- infra/aws_neuron/Dockerfile.neuron | 21 +-- infra/aws_neuron/README.md | 140 ++++++++++++------ .../infinity_emb/transformer/acceleration.py | 17 ++- .../transformer/embedder/neuron.py | 1 - 4 files changed, 118 insertions(+), 61 deletions(-) diff --git a/infra/aws_neuron/Dockerfile.neuron b/infra/aws_neuron/Dockerfile.neuron index d9bc8558..4ee16669 100644 --- a/infra/aws_neuron/Dockerfile.neuron +++ b/infra/aws_neuron/Dockerfile.neuron @@ -1,5 +1,11 @@ -# Is an mirror of -# 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference-neuronx:2.1.2-transformers4.43.2-neuronx-py310-sdk2.20.0-ubuntu20.04 +# Recommended: Use the AWS Deep Learning AMI Neuron (Ubuntu 24.04) directly +# instead of building a custom Docker image. See README.md for instructions. +# +# If you must use Docker, this Dockerfile provides a starting point. +# Note: The Neuron runtime must be available on the host (--device=/dev/neuron0). 
+ +# Base image with Neuron SDK pre-installed +# Mirror of HuggingFace Neuron inference image FROM michaelf34/aws-neuron-base-img:0.0.25-inference AS base WORKDIR /app @@ -7,17 +13,12 @@ WORKDIR /app COPY ./infra/aws_neuron/requirements_no_gpu.txt requirements_no_gpu.txt RUN pip3 install -r requirements_no_gpu.txt RUN pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com -# req -# RUN pip3 install --no-deps --upgrade optimum[neuronx]==1.20.0 -RUN pip3 install --no-deps sentence_transformers==3.3.1 -# libneuronxla-2.0.5347.0 ml-dtypes-0.2.0 neuronx-cc-2.15.143.0+e39249ad setuptools-69.5.1 torch-neuronx-2.1.2.2.3.2 torch-xla-2.1.5 transformers-neuronx-0.12.313 -RUN pip3 install --upgrade neuronx-cc==2.15.* torch-neuronx torchvision transformers-neuronx libneuronxla protobuf optimum-neuron==0.0.20 +RUN pip3 install --no-deps sentence_transformers +RUN pip3 install --upgrade neuronx-cc torch-neuronx torchvision libneuronxla protobuf optimum-neuron optimum -# base is also checkpointed to -# docker pull michaelf34/aws-neuron-base-img:neuroncc2-15--optimum-1-17--transformers-4-36 FROM base AS infinity_latest COPY ./libs/infinity_emb . RUN pip3 install -e . ENV INFINITY_BATCH_SIZE=8 ENV INFINITY_ENGINE=neuron -ENTRYPOINT [ "infinity_emb" ] \ No newline at end of file +ENTRYPOINT [ "infinity_emb" ] diff --git a/infra/aws_neuron/README.md b/infra/aws_neuron/README.md index 1f157be6..54397676 100644 --- a/infra/aws_neuron/README.md +++ b/infra/aws_neuron/README.md @@ -1,33 +1,106 @@ -# Launch an EC2 Instance on AWS: +# Running Infinity on AWS Inferentia / Trainium -### Start a EC2 Instance with Huggingface AMI (free AMI image with Neuron Tools/Docker installed) -- https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2 -- View Purchase Options -> Configure -- Use `64-Bit AMI`, `20241115 (Nov 18, 2024)` -- Region, e.g. 
`us-west-2` -- Set Instance type `inf2.xlarge` (has two neuron accelerators) -- Login with username `ubuntu` (using your standard EC2 setup e.g. `ssh ubuntu@ec2-14-11-13-12.us-west-2.compute.amazonaws.com`) +
+## Recommended: Use the HuggingFace Neuron AMI (no Docker)
+
+The simplest approach is to run Infinity directly on an EC2 instance with the
+HuggingFace Neuron AMI, which comes with `optimum-neuron`, `optimum`, `transformers`,
+and `sentence-transformers` pre-installed with compatible Neuron SDK versions.
+
+### 1. Launch an EC2 Instance
+
+- Use the **HuggingFace Neuron AMI** (`huggingface-neuron-*`) from the AWS Marketplace
+  - This AMI ships optimum-neuron 0.4.4, neuronx-cc 2.21, Python 3.10 — all compatible
+  - Search for `huggingface-neuron` in the EC2 AMI catalog
+- Instance type: **inf2.xlarge** (2 NeuronCores, 32 GB) or larger
+- Disk: The AMI defaults to 512 GB
+
+### 2. Install Infinity
+
+```bash
+# SSH into the instance
+ssh ubuntu@<your-instance-ip>
+
+# Activate the pre-installed PyTorch environment
+source /opt/aws_neuronx_venv_pytorch_2_8/bin/activate
+
+# Clone and install Infinity from source (don't overwrite Neuron packages)
+git clone https://github.com/michaelfeil/infinity.git ~/infinity
+cd ~/infinity/libs/infinity_emb
+pip install --no-deps .
+
+# Install remaining runtime dependencies (most are already present on the HF AMI)
+pip install uvicorn fastapi orjson typer httptools pydantic posthog \
+    prometheus-fastapi-instrumentator hf_transfer rich
+```
+
+### 3. Run Infinity with Neuron engine
+
+```bash
+infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4
+```
+
+The first run will compile the model for Neuron (~100 seconds). Subsequent runs use the cached compilation.
+
+### 4. 
Test it + +```bash +curl http://localhost:7997/embeddings \ + -H "Content-Type: application/json" \ + -d '{"input": ["Hello world", "How are you?"], "model": "BAAI/bge-small-en-v1.5"}' +``` + +## Performance (inf2.xlarge, bge-small-en-v1.5, batch_size=4) + +Tested on HuggingFace Neuron AMI (optimum-neuron 0.4.4, neuronx-cc 2.21, SDK 2.27): + +| Metric | Value | +|--------|-------| +| Latency (serial, 1 sentence) | ~28 ms | +| Latency (serial, 4 sentences) | ~28 ms | +| Throughput (4 concurrent) | ~201 embeddings/sec | +| Compilation time (first run) | ~99 seconds | + +## Tested Stack + +| Package | Version | +|---------|---------| +| optimum-neuron | 0.4.4 | +| optimum | 2.0.0 | +| neuronx-cc | 2.21.33363 | +| torch-neuronx | 2.8.0.2.10 | +| torch | 2.8.0 | +| transformers | 4.57.3 | +| Python | 3.10.12 | + +## Alternative: Docker + +### Build from source -### Optional: build docker image from scratch ```bash git clone https://github.com/michaelfeil/infinity cd infinity -docker buildx build -t michaelf34/infinity:0.0.x-neuron -f ./infra/aws_neuron/Dockerfile.neuron +docker buildx build -t infinity-neuron -f ./infra/aws_neuron/Dockerfile.neuron . ``` -### Run the image on EC2 +### Run on EC2 ```bash -docker run -it --rm --device=/dev/neuron0 michaelf34/infinity:0.0.71-neuron v2 --model-id BAAI/bge-small-en-v1.5 --batch-size 8 --log-level debug +docker run -it --rm --device=/dev/neuron0 infinity-neuron \ + v2 --model-id BAAI/bge-small-en-v1.5 --batch-size 8 ``` -### Run task on ECS (Work in progress) +**Note:** The host must have the Neuron driver installed. The Docker approach is less tested than the direct AMI approach above. + +## Limitations + +- The `--engine neuron` flag currently supports **text embeddings only** (no reranking or classification) +- The Neuron engine requires a **constant batch size** (requests are padded automatically) +- Models are compiled on first use; compilation can take 60-120 seconds -1. 
Create a AWS ECS Cluster with EC2: -- Amazon Machine Image (AMI): Amazon Linux 2 - *Neuron* -- inf2.xlarge as machine type. +## ECS Deployment + +See the ECS task definition example below for container orchestration: -2. Create a Task: ```json { "family": "ecs-infinity-neuron", @@ -45,10 +118,7 @@ docker run -it --rm --device=/dev/neuron0 michaelf34/infinity:0.0.71-neuron v2 - "executionRoleArn": "${YOUR_EXECUTION_ROLE}", "containerDefinitions": [ { - "entryPoint": [ - "infinity_emb", - "v2" - ], + "entryPoint": ["infinity_emb", "v2"], "portMappings": [ { "hostPort": 7997, @@ -61,41 +131,19 @@ docker run -it --rm --device=/dev/neuron0 michaelf34/infinity:0.0.71-neuron v2 - { "containerPath": "/dev/neuron0", "hostPath": "/dev/neuron0", - "permissions": [ - "read", - "write" - ] + "permissions": ["read", "write"] } ], "capabilities": { - "add": [ - "IPC_LOCK" - ] + "add": ["IPC_LOCK"] } }, "cpu": 0, "memoryReservation": 1000, - "image": "michaelf34/infinity:0.0.71-neuron", + "image": "infinity-neuron:latest", "essential": true, "name": "infinity-neuron" } ] } ``` - -You can also add logging: -``` - // same indent as "linuxParameters" - "logConfiguration": { - "logDriver": "awslogs", - "options": { - "awslogs-group": "/ecs/ecs-infinity-neuron", - "mode": "non-blocking", - "awslogs-create-group": "true", - "max-buffer-size": "25m", - "awslogs-region": "us-west-2", // set correct location. 
- "awslogs-stream-prefix": "ecs" - }, - "secretOptions": [] - } -``` \ No newline at end of file diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py index 1d7b7c7f..55ec37b8 100644 --- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py +++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py @@ -8,10 +8,19 @@ from infinity_emb.primitives import Device if CHECK_OPTIMUM.is_available: - from optimum.bettertransformer import ( # type: ignore[import-untyped] - BetterTransformer, - BetterTransformerManager, - ) + try: + from optimum.bettertransformer import ( # type: ignore[import-untyped] + BetterTransformer, + BetterTransformerManager, + ) + except (ImportError, ModuleNotFoundError): + # optimum.bettertransformer was removed in optimum >= 2.0 + CHECK_OPTIMUM.mark_dirty( + ImportError( + "optimum.bettertransformer is not available in this version of optimum. " + "BetterTransformer support requires optimum < 2.0." 
+ ) + ) if CHECK_TORCH.is_available: import torch diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py index 433bd67c..0be97d92 100644 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py +++ b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py @@ -124,7 +124,6 @@ def encode_pre(self, sentences: list[str]) -> dict[str, "torch.Tensor"]: padding=True, truncation="longest_first", return_tensors="pt", - return_token_type_ids=False, ) return input_dict From 19c22bbafe7e9af16edf9f68c0dba92903e94a80 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sat, 28 Feb 2026 22:16:09 -0500 Subject: [PATCH 2/5] Add CHECK_OPTIMUM guard and cross-instance benchmarks - Fix NameError when BetterTransformerManager is not available in check_if_bettertransformer_possible() (affects both torch and neuron engines on optimum >= 2.0) - Update README with benchmark results across g5, inf2, trn2 instances --- infra/aws_neuron/README.md | 30 ++++++++++++++----- .../infinity_emb/transformer/acceleration.py | 3 ++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/infra/aws_neuron/README.md b/infra/aws_neuron/README.md index 54397676..e1fa8202 100644 --- a/infra/aws_neuron/README.md +++ b/infra/aws_neuron/README.md @@ -49,16 +49,30 @@ curl http://localhost:7997/embeddings \ -d '{"input": ["Hello world", "How are you?"], "model": "BAAI/bge-small-en-v1.5"}' ``` -## Performance (inf2.xlarge, bge-small-en-v1.5, batch_size=4) +## Performance (bge-small-en-v1.5, batch_size=4) -Tested on HuggingFace Neuron AMI (optimum-neuron 0.4.4, neuronx-cc 2.21, SDK 2.27): +### Latency (serial requests, P50) -| Metric | Value | -|--------|-------| -| Latency (serial, 1 sentence) | ~28 ms | -| Latency (serial, 4 sentences) | ~28 ms | -| Throughput (4 concurrent) | ~201 embeddings/sec | -| Compilation time (first run) | ~99 seconds | +| Workload | g5.xlarge (GPU) | inf2.xlarge | trn2.3xlarge | 
+|----------|----------------|-------------|--------------| +| 1 short sentence | 14.2ms | 25.9ms | 18.9ms | +| 4 short sentences | 16.0ms | 26.5ms | 19.4ms | +| 4 long sentences | 16.2ms | 27.0ms | 19.9ms | + +### Throughput (concurrent requests) + +| Workload | g5.xlarge (GPU) | inf2.xlarge | trn2.3xlarge | +|----------|----------------|-------------|--------------| +| 4 sentences, 4 concurrent | 421 emb/s | 207 emb/s | 351 emb/s | +| 4 sentences, 8 concurrent | 536 emb/s | 206 emb/s | 349 emb/s | + +**Notes:** +- g5.xlarge uses `--engine torch`; inf2/trn2 use `--engine neuron` +- Neuron latency is constant regardless of batch content (padded to compiled batch size) +- GPU throughput scales with concurrency; Neuron throughput is flat +- Compilation time: ~60-100 seconds on first run (cached after that) + +Tested on HuggingFace Neuron AMI (optimum-neuron 0.4.4, neuronx-cc 2.21, SDK 2.27). ## Tested Stack diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py index 55ec37b8..135c0ed2 100644 --- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py +++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py @@ -46,6 +46,9 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool: if not engine_args.bettertransformer: return False + if not CHECK_OPTIMUM.is_available: + return False + config = AutoConfig.from_pretrained( pretrained_model_name_or_path=engine_args.model_name_or_path, revision=engine_args.revision, From 771e63040df1976437888c58530bbcfdaac23730 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Mon, 2 Mar 2026 13:30:35 -0500 Subject: [PATCH 3/5] Use data parallelism instead of tensor parallelism for Neuron Compile models for a single NeuronCore (num_cores=1) instead of sharding across all cores. This gives 5% better single-core performance and enables linear scaling via multiple processes. 
With 2 processes on inf2.xlarge (2 cores): - Before (tensor parallel): 206 emb/s - After (data parallel): 425 emb/s (+106%) Update README with multi-process deployment instructions and revised benchmark results. --- infra/aws_neuron/README.md | 35 +++++++++++++------ .../transformer/embedder/neuron.py | 5 ++- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/infra/aws_neuron/README.md b/infra/aws_neuron/README.md index e1fa8202..6e876f79 100644 --- a/infra/aws_neuron/README.md +++ b/infra/aws_neuron/README.md @@ -36,12 +36,27 @@ pip install uvicorn fastapi orjson typer httptools pydantic posthog \ ### 3. Run Infinity with Neuron engine ```bash +# Single core (uses one NeuronCore) infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 ``` The first run will compile the model for Neuron (~100 seconds). Subsequent runs use the cached compilation. -### 4. Test it +### 4. Scale across all NeuronCores (data parallelism) + +The Neuron runtime is limited to one model per process. To use all NeuronCores, +run one server process per core, each pinned to a different core: + +```bash +# inf2.xlarge has 2 NeuronCores (cores 0 and 1) +NEURON_RT_VISIBLE_CORES=0 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7997 & +NEURON_RT_VISIBLE_CORES=1 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7998 & +``` + +Then use a load balancer (nginx, HAProxy, etc.) to distribute requests across +ports. This gives **linear throughput scaling**: 2 cores = 2x throughput, 8 cores = 8x. + +### 5. 
Test it ```bash curl http://localhost:7997/embeddings \ @@ -53,18 +68,18 @@ curl http://localhost:7997/embeddings \ ### Latency (serial requests, P50) -| Workload | g5.xlarge (GPU) | inf2.xlarge | trn2.3xlarge | -|----------|----------------|-------------|--------------| -| 1 short sentence | 14.2ms | 25.9ms | 18.9ms | -| 4 short sentences | 16.0ms | 26.5ms | 19.4ms | -| 4 long sentences | 16.2ms | 27.0ms | 19.9ms | +| Workload | g5.xlarge (GPU) | inf2.xlarge (1 core) | inf2.xlarge (2 cores) | +|----------|----------------|---------------------|----------------------| +| 1 short sentence | 14.2ms | 25.0ms | 25.0ms | +| 4 short sentences | 16.0ms | 25.6ms | 25.6ms | +| 4 long sentences | 16.2ms | 26.0ms | 26.0ms | ### Throughput (concurrent requests) -| Workload | g5.xlarge (GPU) | inf2.xlarge | trn2.3xlarge | -|----------|----------------|-------------|--------------| -| 4 sentences, 4 concurrent | 421 emb/s | 207 emb/s | 351 emb/s | -| 4 sentences, 8 concurrent | 536 emb/s | 206 emb/s | 349 emb/s | +| Workload | g5.xlarge (GPU) | inf2.xlarge (1 core) | inf2.xlarge (2 cores) | +|----------|----------------|---------------------|----------------------| +| 4 sentences, 4 concurrent | 421 emb/s | 216 emb/s | 425 emb/s | +| 4 sentences, 8 concurrent | 536 emb/s | 215 emb/s | 424 emb/s | **Notes:** - g5.xlarge uses `--engine torch`; inf2/trn2 use `--engine neuron` diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py index 0be97d92..8ca2122e 100644 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py +++ b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py @@ -97,7 +97,10 @@ def __init__(self, *, engine_args: EngineArgs): ) self._infinity_tokenizer = copy.deepcopy(self.tokenizer) - compiler_args = {"num_cores": get_nc_count(), "auto_cast_type": "fp16"} + # Compile for a single NeuronCore. 
For data-parallel scaling across + # multiple cores, run separate server processes pinned to individual + # cores via NEURON_RT_VISIBLE_CORES (see README). + compiler_args = {"num_cores": 1, "auto_cast_type": "fp16"} input_shapes = { "batch_size": engine_args.batch_size, "sequence_length": ( From d9fe9f078ba8c8f6ab5c2a2d4cefcef1c35e4ff1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 2 Mar 2026 21:02:15 +0000 Subject: [PATCH 4/5] Add trn2.3xlarge benchmarks and data parallelism examples to Neuron docs - Add trn2.3xlarge (4 NeuronCores) benchmark results: 19ms latency, 753 emb/s - Add trn2 data parallelism example (4 processes on 4 cores) - Restructure throughput table to show scaling across instance types - Note: trn2 has ~30% lower per-core latency than inf2 --- infra/aws_neuron/README.md | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/infra/aws_neuron/README.md b/infra/aws_neuron/README.md index 6e876f79..6ee0d17e 100644 --- a/infra/aws_neuron/README.md +++ b/infra/aws_neuron/README.md @@ -11,7 +11,7 @@ and `sentence-transformers` pre-installed with compatible Neuron SDK versions. - Use the **HuggingFace Neuron AMI** (`huggingface-neuron-*`) from the AWS Marketplace - This AMI ships optimum-neuron 0.4.4, neuronx-cc 2.21, Python 3.10 — all compatible - Search for `huggingface-neuron` in the EC2 AMI catalog -- Instance type: **inf2.xlarge** (2 NeuronCores, 32 GB) or larger +- Instance type: **inf2.xlarge** (2 NeuronCores, 32 GB), **trn2.3xlarge** (4 NeuronCores, 128 GB), or larger - Disk: The AMI defaults to 512 GB ### 2. 
Install Infinity @@ -51,10 +51,16 @@ run one server process per core, each pinned to a different core: # inf2.xlarge has 2 NeuronCores (cores 0 and 1) NEURON_RT_VISIBLE_CORES=0 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7997 & NEURON_RT_VISIBLE_CORES=1 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7998 & + +# trn2.3xlarge has 4 NeuronCores (cores 0-3) +NEURON_RT_VISIBLE_CORES=0 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7997 & +NEURON_RT_VISIBLE_CORES=1 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7998 & +NEURON_RT_VISIBLE_CORES=2 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 7999 & +NEURON_RT_VISIBLE_CORES=3 infinity_emb v2 --engine neuron --model-id BAAI/bge-small-en-v1.5 --batch-size 4 --port 8000 & ``` Then use a load balancer (nginx, HAProxy, etc.) to distribute requests across -ports. This gives **linear throughput scaling**: 2 cores = 2x throughput, 8 cores = 8x. +ports. Throughput scales linearly with cores: 2 cores = 2x, 4 cores = 4x. ### 5. 
Test it @@ -68,26 +74,31 @@ curl http://localhost:7997/embeddings \ ### Latency (serial requests, P50) -| Workload | g5.xlarge (GPU) | inf2.xlarge (1 core) | inf2.xlarge (2 cores) | +| Workload | g5.xlarge (GPU) | inf2.xlarge (1 core) | trn2.3xlarge (1 core) | |----------|----------------|---------------------|----------------------| -| 1 short sentence | 14.2ms | 25.0ms | 25.0ms | -| 4 short sentences | 16.0ms | 25.6ms | 25.6ms | -| 4 long sentences | 16.2ms | 26.0ms | 26.0ms | +| 1 short sentence | 14.2ms | 25.0ms | 19.0ms | +| 4 short sentences | 16.0ms | 25.6ms | 19.5ms | +| 4 long sentences | 16.2ms | 26.0ms | 20.3ms | -### Throughput (concurrent requests) +### Throughput (concurrent requests, data parallelism) -| Workload | g5.xlarge (GPU) | inf2.xlarge (1 core) | inf2.xlarge (2 cores) | -|----------|----------------|---------------------|----------------------| -| 4 sentences, 4 concurrent | 421 emb/s | 216 emb/s | 425 emb/s | -| 4 sentences, 8 concurrent | 536 emb/s | 215 emb/s | 424 emb/s | +| Instance | Cores | Peak emb/s | Concurrency | +|----------|-------|-----------|-------------| +| g5.xlarge (GPU) | 1 GPU | 536 | 8 concurrent | +| inf2.xlarge | 1 core | 216 | 4 concurrent | +| inf2.xlarge | 2 cores | 427 | 4 concurrent | +| trn2.3xlarge | 1 core | 348 | 4 concurrent | +| trn2.3xlarge | 4 cores | 753 | 4 concurrent | **Notes:** - g5.xlarge uses `--engine torch`; inf2/trn2 use `--engine neuron` - Neuron latency is constant regardless of batch content (padded to compiled batch size) -- GPU throughput scales with concurrency; Neuron throughput is flat +- trn2 has ~30% lower latency per core than inf2 (19ms vs 25ms) +- Throughput scales linearly with data parallelism (1 process per core) - Compilation time: ~60-100 seconds on first run (cached after that) -Tested on HuggingFace Neuron AMI (optimum-neuron 0.4.4, neuronx-cc 2.21, SDK 2.27). 
+Tested on HuggingFace Neuron AMI (optimum-neuron 0.4.4, neuronx-cc 2.21, SDK 2.27) +and Deep Learning AMI Neuron Ubuntu 22.04 (SDK 2.28) for trn2. ## Tested Stack From 1a0490f8afdb01f5f4e66f7d28efac85eb9c04cb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 2 Mar 2026 22:04:54 +0000 Subject: [PATCH 5/5] Address code review feedback: fix BetterTransformer guard, add NEURON_NUM_CORES env var - acceleration.py: Use globals() check instead of CHECK_OPTIMUM.is_available to correctly detect when bettertransformer import failed (mark_dirty does not invalidate the cached is_available property) - neuron.py: Add NEURON_NUM_CORES env var (default 1) so large models can opt in to tensor parallelism without source changes - neuron.py: Remove dead get_nc_count() function and unused imports - Dockerfile: Merge RUN layers, add --no-cache-dir, add tested version comment - README: Clean up redundant text in throughput table --- infra/aws_neuron/Dockerfile.neuron | 7 ++-- infra/aws_neuron/README.md | 10 +++--- .../infinity_emb/transformer/acceleration.py | 2 +- .../transformer/embedder/neuron.py | 32 +++++-------------- 4 files changed, 18 insertions(+), 33 deletions(-) diff --git a/infra/aws_neuron/Dockerfile.neuron b/infra/aws_neuron/Dockerfile.neuron index 4ee16669..c68c5cc6 100644 --- a/infra/aws_neuron/Dockerfile.neuron +++ b/infra/aws_neuron/Dockerfile.neuron @@ -11,10 +11,11 @@ FROM michaelf34/aws-neuron-base-img:0.0.25-inference AS base WORKDIR /app COPY ./infra/aws_neuron/requirements_no_gpu.txt requirements_no_gpu.txt -RUN pip3 install -r requirements_no_gpu.txt +RUN pip3 install --no-cache-dir -r requirements_no_gpu.txt RUN pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com -RUN pip3 install --no-deps sentence_transformers -RUN pip3 install --upgrade neuronx-cc torch-neuronx torchvision libneuronxla protobuf optimum-neuron optimum +# Tested with: optimum-neuron 0.4.4, optimum 2.0.0, neuronx-cc 2.21, torch-neuronx 2.8 +RUN pip3 install 
--no-cache-dir --no-deps sentence_transformers && \ + pip3 install --no-cache-dir --upgrade neuronx-cc torch-neuronx torchvision libneuronxla protobuf optimum-neuron optimum FROM base AS infinity_latest COPY ./libs/infinity_emb . diff --git a/infra/aws_neuron/README.md b/infra/aws_neuron/README.md index 6ee0d17e..e86aff20 100644 --- a/infra/aws_neuron/README.md +++ b/infra/aws_neuron/README.md @@ -84,11 +84,11 @@ curl http://localhost:7997/embeddings \ | Instance | Cores | Peak emb/s | Concurrency | |----------|-------|-----------|-------------| -| g5.xlarge (GPU) | 1 GPU | 536 | 8 concurrent | -| inf2.xlarge | 1 core | 216 | 4 concurrent | -| inf2.xlarge | 2 cores | 427 | 4 concurrent | -| trn2.3xlarge | 1 core | 348 | 4 concurrent | -| trn2.3xlarge | 4 cores | 753 | 4 concurrent | +| g5.xlarge (GPU) | 1 GPU | 536 | 8 | +| inf2.xlarge | 1 core | 216 | 4 | +| inf2.xlarge | 2 cores | 427 | 4 | +| trn2.3xlarge | 1 core | 348 | 4 | +| trn2.3xlarge | 4 cores | 753 | 4 | **Notes:** - g5.xlarge uses `--engine torch`; inf2/trn2 use `--engine neuron` diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py index 135c0ed2..90964a4e 100644 --- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py +++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py @@ -46,7 +46,7 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool: if not engine_args.bettertransformer: return False - if not CHECK_OPTIMUM.is_available: + if "BetterTransformerManager" not in globals(): return False config = AutoConfig.from_pretrained( diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py index 8ca2122e..4b9f8411 100644 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py +++ b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py @@ -2,10 +2,7 @@ # Copyright (c) 2023-now michaelfeil 
import copy -import json -import subprocess -from typing import Union -from functools import cache +import os import numpy as np from infinity_emb._optional_imports import CHECK_OPTIMUM_NEURON, CHECK_TORCH @@ -30,22 +27,6 @@ ] -@cache -def get_nc_count() -> Union[int, None]: - """Returns the number of neuron cores on the current instance.""" - try: - cmd = "neuron-ls --json-output" - result = subprocess.run(cmd, shell=True, capture_output=True) - print("inferring nc_count from `neuron-ls`") - print(result.stdout.decode("utf-8")) - json_output = json.loads(result.stdout) - count = sum([x["nc_count"] for x in json_output]) - print(f"nc_count={count}") - return count - except Exception: - return None - - def pad_up_to_size(desired_max_bs: int, input_ids: "torch.Tensor") -> "torch.Tensor": """input_ids a 2D array with batch_size on dim=0 @@ -97,10 +78,13 @@ def __init__(self, *, engine_args: EngineArgs): ) self._infinity_tokenizer = copy.deepcopy(self.tokenizer) - # Compile for a single NeuronCore. For data-parallel scaling across - # multiple cores, run separate server processes pinned to individual - # cores via NEURON_RT_VISIBLE_CORES (see README). - compiler_args = {"num_cores": 1, "auto_cast_type": "fp16"} + # Default to 1 NeuronCore (data parallelism). For large models that + # require tensor parallelism across multiple cores, set the + # NEURON_NUM_CORES environment variable. For data-parallel scaling, + # run separate server processes pinned to individual cores via + # NEURON_RT_VISIBLE_CORES (see infra/aws_neuron/README.md). + num_cores = int(os.environ.get("NEURON_NUM_CORES", "1")) + compiler_args = {"num_cores": num_cores, "auto_cast_type": "fp16"} input_shapes = { "batch_size": engine_args.batch_size, "sequence_length": (