Describe the feature
We want a StableHLO/PJRT-compatible mechanism to express graph-level memory placement and explicit data movement between Neuron memory spaces.
The Neuron PJRT plugin already exposes multiple memory kinds. For example, on inf2.8xlarge, each visible NeuronCore reports both device and pinned_host memory:
platform: neuron
version: 0.104
extensions:
  profiler
  memory_descriptions
  layouts
devices:
  NC_v2 (NeuronCore(id=0, process_index=0, local_id=0))
    memories:
      NeuronMemory(id=0, kind=device)
      NeuronMemory(id=1, kind=pinned_host)
  NC_v2 (NeuronCore(id=1, process_index=0, local_id=1))
    memories:
      NeuronMemory(id=2, kind=device)
      NeuronMemory(id=3, kind=pinned_host)
What is missing is a clear frontend contract connecting those PJRT memory descriptions to StableHLO graph construction and runtime buffer creation. A custom StableHLO frontend needs to be able to say:
- This function argument is expected in memory kind X.
- This result should be materialized in memory kind X.
- This intermediate tensor should be moved/materialized in memory kind X.
Runtime uploads via PJRT_Client_BufferFromHostBuffer should be able to target the same memory kind that the compiled graph expects.
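To make the requested contract concrete, here is a minimal sketch of the consistency check a frontend would like to be able to run before uploading buffers. It is illustrative only: `ArgPlacement` and `check_uploads` are hypothetical names, not part of PJRT or the Neuron plugin, and the memory kinds are plain strings taken from the plugin output above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ArgPlacement:
    """Memory kind the compiled graph expects for one argument (hypothetical type)."""
    name: str
    memory_kind: str  # e.g. "device" or "pinned_host"


def check_uploads(expected, advertised_kinds, upload_kinds):
    """Return a list of mismatch descriptions; an empty list means consistent.

    expected:         list of ArgPlacement taken from the graph annotations
    advertised_kinds: set of memory kinds the target device reports
    upload_kinds:     dict mapping argument name to the kind used at upload time
    """
    problems = []
    for arg in expected:
        if arg.memory_kind not in advertised_kinds:
            problems.append(
                f"{arg.name}: kind {arg.memory_kind!r} is not advertised by the device"
            )
        actual = upload_kinds.get(arg.name)
        if actual != arg.memory_kind:
            problems.append(
                f"{arg.name}: graph expects {arg.memory_kind!r}, upload targets {actual!r}"
            )
    return problems


expected = [ArgPlacement("metadata", "pinned_host"), ArgPlacement("x", "device")]
print(check_uploads(expected, {"device", "pinned_host"},
                    {"metadata": "pinned_host", "x": "device"}))
```

Without a documented contract, the `expected` list has no authoritative source, which is the gap this request describes.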
Existing XLA/PJRT conventions appear relevant, but their Neuron support is not documented:
- function argument/result attributes such as mhlo.memory_kind = "pinned_host";
- stablehlo.custom_call with call_target_name = "annotate_device_placement";
- mhlo.frontend_attributes containing _xla_buffer_placement = "pinned_host".
The requested feature is either support and documentation for these conventions on Neuron, or an equivalent Neuron-specific public contract.
Use Case
LLM decoding is sensitive to small host/device synchronization points. Even when the main compute kernels are fast, token throughput can be limited by:
- metadata movement
- synchronous host/device transfers
- implicit copies inserted by the compiler or runtime
Small inference metadata tensors are common in decode workloads, for example block/page tables, seqlens, and token offsets.
For these tensors, a custom frontend may want to use a pinned host memory kind when the PJRT plugin advertises one. Without a graph-level placement contract, the frontend cannot know whether a compiled StableHLO argument, result, or intermediate is expected in device memory or in pinned host memory.
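As a rough illustration of the policy a frontend might apply, the sketch below picks a memory kind per tensor. The function name and the 4096-byte threshold are assumptions made for the example, not values documented for Neuron.

```python
def choose_memory_kind(nbytes, advertised_kinds, small_threshold_bytes=4096):
    """Pick a memory kind for a tensor of the given size.

    Small metadata tensors go to "pinned_host" when the plugin advertises it;
    everything else stays in "device" memory. The threshold is a hypothetical
    tuning knob, not a Neuron-specified value.
    """
    if nbytes <= small_threshold_bytes and "pinned_host" in advertised_kinds:
        return "pinned_host"
    return "device"


kinds = {"device", "pinned_host"}
print(choose_memory_kind(256, kinds))        # small block table
print(choose_memory_kind(64 * 1024 * 1024, kinds))  # large activation tensor
```

A policy like this is only useful if the chosen kind can also be expressed in the compiled graph, which is what the proposal below asks for.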
Proposed Solution
Support existing XLA memory-kind conventions on Neuron:
Function argument/result placement:
func.func public @main(
  %metadata: tensor<...> {mhlo.memory_kind = "pinned_host"},
  %device_tensor: tensor<...> {mhlo.memory_kind = "device"}
) -> (
  tensor<...> {mhlo.memory_kind = "device"}
)
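A frontend that builds StableHLO text could emit such a signature from its own placement table. The helper below is a hypothetical sketch (`format_main_signature` is not an existing API) that produces only the annotated signature, mirroring the fragment above; whether Neuron honors the attribute is exactly what this issue asks to have documented.

```python
def format_main_signature(args, results):
    """Render a func.func signature with mhlo.memory_kind attributes.

    args:    list of (name, tensor_type, memory_kind) tuples
    results: list of (tensor_type, memory_kind) tuples
    Returns the signature text only; the function body is not generated.
    """
    arg_lines = [
        f'  %{name}: {ty} {{mhlo.memory_kind = "{kind}"}}'
        for name, ty, kind in args
    ]
    res_lines = [
        f'  {ty} {{mhlo.memory_kind = "{kind}"}}'
        for ty, kind in results
    ]
    return (
        "func.func public @main(\n"
        + ",\n".join(arg_lines)
        + "\n) -> (\n"
        + ",\n".join(res_lines)
        + "\n)"
    )


print(format_main_signature(
    [("metadata", "tensor<8xi32>", "pinned_host"),
     ("device_tensor", "tensor<4x4xf32>", "device")],
    [("tensor<4x4xf32>", "device")],
))
```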
Intra-graph placement or movement:
%placed = stablehlo.custom_call @annotate_device_placement(%x)
{
  call_target_name = "annotate_device_placement",
  has_side_effect = true,
  backend_config = "",
  mhlo.frontend_attributes = {
    _xla_buffer_placement = "pinned_host"
  }
}
If the generic XLA convention is not supported or not intended for Neuron, a Neuron-specific documented custom call would also work, for example:
%placed = stablehlo.custom_call @AwsNeuronMemoryPlacement(%x)
{
  call_target_name = "AwsNeuronMemoryPlacement",
  backend_config = "pinned_host"
}
The exact spelling is less important than having a public and stable contract usable by StableHLO/PJRT frontends.