feat: NixlConnector support for Neuron disaggregated inference#27

Open
dmvevents wants to merge 1 commit into vllm-project:release-0.5.0 from dmvevents:feat/nixl-connector-neuron

Conversation

@dmvevents

Summary

Adds NixlConnector (NIXL KV cache transfer) support to vllm-neuron, enabling disaggregated prefill/decode serving where a Trainium worker handles prefill and an NVIDIA GPU worker handles decode. KV cache is transferred from Neuron device memory through CPU DRAM via NIXL LIBFABRIC over EFA RDMA.

This follows the same host-buffer pattern already established by TPU and XPU in vLLM core: accelerator device memory that cannot be directly registered with NIXL is staged through pinned CPU DRAM.

Validated in a POC: Trainium prefill (24 ms) → NIXL RDMA transfer → H100 decode (2139 ms total, Qwen3-0.6B), orchestrated by Dynamo with KV-aware routing.

Changes

1. vllm_neuron/attention/neuron_attn.py — Implement get_kv_cache_shape()

The NotImplementedError stub is replaced with the actual KV cache shape matching CUDA FlashAttention paged block format:

return (2, num_blocks, block_size, num_kv_heads, head_size)

Why: NixlConnector calls get_kv_cache_shape() during handshake to compute block offsets. Without this, the connector crashes at startup.
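A minimal sketch of the implemented method, assuming the vLLM attention-backend convention of a static `get_kv_cache_shape()`; the surrounding class body here is illustrative, not the actual `neuron_attn.py` contents:

```python
class NeuronAttentionBackend:
    """Illustrative stand-in for the Neuron attention backend."""

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int, head_size: int) -> tuple[int, ...]:
        # Match the CUDA FlashAttention paged-block layout so NixlConnector
        # computes identical per-block byte offsets on both P and D workers.
        return (2, num_blocks, block_size, num_kv_heads, head_size)
```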

2. vllm_neuron/worker/neuronx_distributed_model_runner.py — KV reshape + block alignment

Moves KV tensors to CPU and reshapes NxDI's BHSD layout (batch, kv_heads, seq_len, head_dim) into CUDA paged block format (2, num_blocks, block_size, kv_heads, head_dim). Also aligns the block count to kv_cache_config.num_blocks (the scheduler's pool size) to prevent "makeXferReq: remote index out of range".

Why: NIXL cannot register Neuron device memory (no CUDA, no Level Zero). This is the same pattern TPU and XPU use — stage KV through CPU DRAM. The reshape ensures NixlConnector can index into the paged blocks correctly.
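The layout conversion can be sketched with NumPy (standing in for torch on CPU). The dimensions and the helper name are hypothetical; only the BHSD-to-paged mapping mirrors what the PR describes:

```python
import numpy as np

# Hypothetical toy dimensions for illustration.
num_kv_heads, seq_len, head_dim = 2, 8, 4
block_size = 4
num_blocks = seq_len // block_size  # real code pads up to the scheduler pool size

# NxDI returns K and V in BHSD layout: (batch, kv_heads, seq_len, head_dim)
k = np.arange(num_kv_heads * seq_len * head_dim,
              dtype=np.float32).reshape(1, num_kv_heads, seq_len, head_dim)
v = k + 1000.0

def to_paged(k, v, block_size):
    # BHSD -> (seq_len, kv_heads, head_dim): bring the token axis to the front.
    k_t = k[0].transpose(1, 0, 2)
    v_t = v[0].transpose(1, 0, 2)
    # Split the token axis into paged blocks:
    # (num_blocks, block_size, kv_heads, head_dim)
    nb = k_t.shape[0] // block_size
    k_p = k_t.reshape(nb, block_size, *k_t.shape[1:])
    v_p = v_t.reshape(nb, block_size, *v_t.shape[1:])
    # Stack K and V along a leading axis of size 2, matching
    # (2, num_blocks, block_size, kv_heads, head_dim).
    return np.stack([k_p, v_p])

paged = to_paged(k, v, block_size)
# Token t of head h now lives at paged[kv, t // block_size, t % block_size, h].
```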

3. vllm_neuron/worker/neuron_worker.py — Add get_kv_connector_handshake_metadata()

def get_kv_connector_handshake_metadata(self):
    from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                              has_kv_transfer_group)
    from vllm.distributed.parallel_state import get_tp_group
    # No KV connector configured: nothing to hand back to the engine.
    if not has_kv_transfer_group():
        return None
    connector = get_kv_transfer_group()
    metadata = connector.get_handshake_metadata()
    if metadata is None:
        return None
    # Keyed by TP rank so the engine can route metadata per worker.
    return {get_tp_group().rank_in_group: metadata}

Why: The V1 engine calls this method during startup to exchange NIXL agent metadata between workers. Without it, the prefill and decode workers cannot discover each other.

4. vllm_neuron/platform.py — Accept **kwargs in get_attn_backend_cls()

Adds **kwargs and default values for has_sink and use_sparse parameters.

Why: vLLM 0.16+ added new keyword arguments to this method. Without **kwargs, the Neuron platform crashes on the new arguments.
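A hedged sketch of the forward-compatible signature; the leading positional parameters approximate the vLLM platform interface and may not match it exactly, and the returned class path is the one this plugin registers:

```python
class NeuronPlatform:
    """Illustrative stand-in for the Neuron platform class."""

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink=False, use_sparse=False, **kwargs):
        # has_sink/use_sparse get defaults for older callers; **kwargs
        # absorbs any keywords added by newer vLLM versions so the Neuron
        # platform no longer crashes on unknown arguments.
        return "vllm_neuron.attention.neuron_attn.NeuronAttentionBackend"
```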

5. vllm_neuron/worker/neuronx_distributed_model_runner.py — Skip NIXL context when no transfers pending

Guards the KV connector context manager entry with a check for actual pending transfers (reqs_to_recv or reqs_to_send).

Why: The NIXL context manager calls start_load_kv() which blocks waiting for incoming data. On non-transfer iterations (the majority of decode iterations), this hangs the worker indefinitely.
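The guard can be sketched as below. The `reqs_to_recv`/`reqs_to_send` fields follow the PR description, but `connector.kv_context(...)` is a hypothetical stand-in for the real context-manager entry point:

```python
from contextlib import nullcontext

def kv_transfer_ctx(connector, scheduler_output):
    """Return the connector's blocking context only when KV transfers are
    actually pending; otherwise a no-op context."""
    meta = scheduler_output.kv_connector_metadata
    pending = bool(getattr(meta, "reqs_to_recv", None)
                   or getattr(meta, "reqs_to_send", None))
    if not pending:
        # start_load_kv() would block forever on ordinary decode
        # iterations, so skip entering the NIXL context entirely.
        return nullcontext()
    return connector.kv_context(meta)  # hypothetical entry point
```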

6. vllm_neuron/worker/neuronx_distributed_model_runner.py — Fix max_model_len attribute

# Before:
max_len = self.scheduler_config.max_model_len
# After:
max_len = self.model_config.max_model_len

Why: In the disaggregated inference path, max_model_len is on ModelConfig, not SchedulerConfig.

7. vllm_neuron/platform.py — Fix get_nixl_supported_devices() + add get_nixl_memory_type()

  • Returns ("cpu",) (tuple) instead of {"cpu"} (set) to match the expected type
  • Adds get_nixl_memory_type() returning "DRAM" for host-buffered transfers
  • Improved docstrings
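A sketch of the two platform hooks; the return values come from the PR description, while the enclosing class body is illustrative:

```python
class NeuronPlatform:
    """Illustrative stand-in for the Neuron platform class."""

    @classmethod
    def get_nixl_supported_devices(cls) -> tuple[str, ...]:
        # A tuple, not a set, to match the type NixlConnector expects;
        # Neuron device memory is staged through the host, hence "cpu".
        return ("cpu",)

    @classmethod
    def get_nixl_memory_type(cls) -> str:
        # Host-buffered transfers register pinned CPU DRAM with NIXL.
        return "DRAM"
```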

Runtime Workarounds (not in this PR)

Two patches target code outside vllm-neuron and must be applied at runtime:

  • P5 (vllm/v1/executor/uniproc_executor.py): Force max_concurrent_batches=1. Neuron compilation and execution are synchronous — the default async batch overlap causes race conditions.
  • P8 (NxD Inference application_base.py): Add --skip-pass=SimplifyNeuronTensor compiler flag. Workaround for NCC_ITEN404 compiler crash with paged attention HLO.

Testing

Hardware: trn1.32xlarge (prefill) + p5.48xlarge/H100 (decode), same VPC/AZ, EFA RDMA

End-to-end test via Dynamo:

# Start Dynamo frontend + prefill (Trainium) + decode (H100)
kubectl apply -f disagg_trainium_cuda_efa.yaml

# Send request
curl -s http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Qwen/Qwen3-0.6B","prompt":"What is AI?","max_tokens":30}'

Result: 24.2ms prefill on Trainium, 2139ms total with 30 tokens decoded on H100.

Configuration:

# Prefill (Trainium)
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu","kv_connector_extra_config":{"backends":["LIBFABRIC"]}}'

# Decode (NVIDIA GPU)
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cuda","kv_connector_extra_config":{"backends":["LIBFABRIC"]}}'

Known Limitations

  • Prefix caching must be disabled on Trainium due to neuronx-cc NCC_ITEN404 bug (filed as aws-neuron/aws-neuron-sdk#1304)
  • max-num-seqs must be 1 on Trainium to avoid the same compiler bug

Commit Message

Enable cross-accelerator KV cache transfer via NIXL LIBFABRIC over
EFA RDMA, allowing Trainium prefill + NVIDIA GPU decode serving.

Changes (9 patches across 4 files in vllm-neuron):

P1: neuron_attn.py — implement get_kv_cache_shape() returning the
    CUDA FlashAttention paged block dimension order
    (2, num_blocks, block_size, num_kv_heads, head_size).

P2/P9/P10: neuronx_distributed_model_runner.py — move KV tensors to
    CPU, reshape NxDI BHSD layout into CUDA paged block format, and
    align block count to scheduler pool size. Prevents
    "makeXferReq: remote index out of range".

P3: neuron_worker.py — add get_kv_connector_handshake_metadata() so
    V1 engine can exchange NIXL agent metadata at startup.

P4: platform.py — add **kwargs and defaults to get_attn_backend_cls()
    for forward compatibility with new vLLM keyword arguments.

P6: neuronx_distributed_model_runner.py — skip NIXL context manager
    when no transfers are pending, avoiding blocking on start_load_kv().

P7: neuronx_distributed_model_runner.py — fix max_model_len reference
    (scheduler_config -> model_config) for disaggregated inference path.

Also: fix get_nixl_supported_devices() return type (set -> tuple),
add get_nixl_memory_type() returning "DRAM", improve docstrings.

Note: P5 (force sync execution) and P8 (--skip-pass=SimplifyNeuronTensor)
target vllm core and NxD Inference respectively — documented as runtime
workarounds in the PR description.

Tested: Dynamo cross-accel P/D (Trainium prefill + H100 decode)
- Multi-request RDMA at 7.6 GB/s
- 3P1D load balancing
- Role swap (GPU prefill, Neuron decode)
- 4 concurrent requests with KVBM offload

References: vllm-project#26