From 6756c65dc9f59554c74618b4122e236cb8010440 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Wed, 10 Jun 2026 22:22:07 +0530 Subject: [PATCH 1/3] Add Qwen3.6-27B hybrid DeltaNet/GQA contrib model and vLLM serving path Clean-room cut of codex/nki-deltanet-multihead-cte (43ca740) onto upstream main: contrib model + NKI kernels, hybrid APC/GDN cache support, core substrate, unit tests, validation scripts, and the verified nativechunk baseline evidence. Co-Authored-By: Claude Fable 5 --- contrib/models/Qwen3.6-27B/README.md | 344 + .../docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md | 193 + .../docs/FULL_FP8_ISSUES_AND_FIXES.md | 405 + .../docs/HYBRID_APC_PRODUCTION_PLAN.md | 364 + .../QWEN36_FP8_TIERFIX_VALIDATION_20260526.md | 2364 +++++ .../docs/patches/mtp_batched_accept.patch | 154 + .../docs/patches/mtp_batched_accept_README.md | 118 + .../scripts/openai_compat_server.py | 501 + .../scripts/probe_qkvgate_kernel_layout.py | 165 + .../validate_deltanet_chunk_step_nki.py | 340 + .../scripts/validate_deltanet_fused_nki.py | 2638 ++++++ .../validate_deltanet_recurrent_step_nki.py | 247 + .../scripts/validate_qwen_segcte_attention.py | 361 + contrib/models/Qwen3.6-27B/src/__init__.py | 41 + contrib/models/Qwen3.6-27B/src/hybrid_apc.py | 1798 ++++ .../models/Qwen3.6-27B/src/modeling_qwen35.py | 8040 +++++++++++++++++ .../Qwen3.6-27B/src/modeling_qwen35_vision.py | 819 ++ .../Qwen3.6-27B/src/modeling_qwen35_vl.py | 662 ++ .../Qwen3.6-27B/src/nki_kernels/__init__.py | 10 + .../src/nki_kernels/nki_deltanet.py | 607 ++ .../src/nki_kernels/nki_deltanet_chunked.py | 431 + .../src/nki_kernels/nki_deltanet_fused.py | 2991 ++++++ .../nki_kernels/nki_deltanet_fused_legacy.py | 613 ++ .../src/nki_kernels/qwen_qk_norm_rope.py | 230 + contrib/models/Qwen3.6-27B/test/__init__.py | 0 .../Qwen3.6-27B/test/integration/__init__.py | 0 .../integration/qwen36_27b_compile_fp8.py | 1705 ++++ .../test/integration/test_model.py | 605 ++ .../models/Qwen3.6-27B/test/unit/__init__.py | 0 .../Qwen3.6-27B/test/unit/test_config.py | 303 + .../test/unit/test_deltanet_decay.py | 523 ++ .../test/unit/test_hybrid_apc_manager.py | 1790 ++++ .../test/unit/test_hybrid_apc_validation.py | 405 + .../test/unit/test_hybrid_cache_manager.py | 464 + .../unit/test_qwen36_artifact_config_audit.py | 102 + .../test/unit/test_qwen36_chat_proxy.py | 166 + .../unit/test_qwen36_compile_fp8_config.py | 1307 +++ .../test/unit/test_qwen36_model_aliases.py | 1941 ++++ .../test/unit/test_vllm_scheduler_patch.py | 2465 +++++ .../test/unit/test_vllm_serving_config.py | 368 + .../test/unit/test_weight_conversion.py | 535 ++ contrib/models/Qwen3.6-27B/vllm/README.md | 483 + .../Qwen3.6-27B/vllm/hf_qwen35_config.py | 68 + .../Qwen3.6-27B/vllm/install_qwen36_vllm.sh | 61 + .../Qwen3.6-27B/vllm/patch_nxdi_registry.py | 71 + .../vllm_neuron_qwen36_kv_cache_spec.patch | 47 + .../Qwen3.6-27B/vllm/qwen36_chat_proxy.py | 446 + .../vllm/qwen36_hybrid_apc_scheduler_patch.py | 2774 ++++++ .../Qwen3.6-27B/vllm/run_offline_inference.py | 586 ++ .../models/Qwen3.6-27B/vllm/serve_qwen36.py | 25 + .../models/Qwen3.6-27B/vllm/sitecustomize.py | 23 + .../Qwen3.6-27B/vllm/start_vllm_server.sh | 697 ++ .../models/config.py | 79 + .../models/model_base.py | 335 +- .../models/model_wrapper.py | 795 +- .../modules/async_execution.py | 2571 +++++- .../modules/attention/attention_base.py | 492 +- .../modules/attention/gqa.py | 439 +- .../modules/attention/nki_kernels/__init__.py | 0 .../qwen_gated_output_projection.py | 371 + .../nki_kernels/qwen_segcte256/__init__.py | 0 .../attention_segmented_cte_256.py | 2113 +++++ .../fused_segmented_attention_256.py | 1527 ++++ .../modules/attention/utils.py | 8 +- .../modules/autobucketing.py | 6 + .../modules/generation/sampling.py | 79 +- .../modules/kvcache/block_kv_cache_manager.py | 115 +- .../modules/kvcache/hybrid_prefix_cache.py | 238 + .../modules/kvcache/utils.py | 68 + .../modules/sliding_window/attention.py | 6 +- .../utils/constants.py | 15 + .../utils/hf_adapter.py | 37 +- test/unit/models/test_model_wrapper.py | 201 + .../test_prefix_caching_bucket_selection.py | 573 ++ .../modules/attention/test_attention_base.py | 69 + test/unit/modules/attention/test_gqa.py | 286 + test/unit/modules/generation/test_sampling.py | 62 + .../kvcache/test_block_kv_cache_manager.py | 276 + .../kvcache/test_hybrid_prefix_cache.py | 141 + test/unit/modules/test_async_execution.py | 2347 ++++- test/unit/modules/test_autobucketing.py | 22 + .../test_qwen36_openai_compat_server.py | 46 + .../scripts/test_qwen36_validation_gates.py | 53 + .../baseline_258k_chunk_timing.json | 7 + .../baseline_summary.json | 66 + .../log_scan_empty.txt | 0 ...ossguard_258k_probe_20260609T000000Z.jsonl | 1 + ...uard_clean16k_probe_20260609T000000Z.jsonl | 1 + validation_scripts/neuron_memory_sampler.py | 188 + .../qwen36_artifact_config_audit.py | 270 + .../qwen36_bf16_length_sweep.py | 120 + .../qwen36_chat_completion_context_bench.py | 562 ++ .../qwen36_hf_first_mismatch_logits.py | 195 + .../qwen36_hf_neuron_greedy_match.py | 212 + .../qwen36_hybrid_apc_context_sweep.py | 466 + .../qwen36_hybrid_apc_validation.py | 1475 +++ .../qwen36_offline_decode_bench.py | 446 + .../qwen36_openai_boundary_apc_probe.py | 266 + .../qwen36_openai_chat_apc_validation.py | 641 ++ .../qwen36_split_qkv_tkg_probe.py | 197 + .../qwen36_steady_cold_prefill_bench.py | 216 + 101 files changed, 60795 insertions(+), 300 deletions(-) create mode 100644 contrib/models/Qwen3.6-27B/README.md create mode 100644 contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md create mode 100644 contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md create mode 100644 contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md create mode 100644 contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md create mode 100644 contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch create mode 100644 contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md create mode 100644 contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py create mode 100644 contrib/models/Qwen3.6-27B/scripts/probe_qkvgate_kernel_layout.py create mode 100644 contrib/models/Qwen3.6-27B/scripts/validate_deltanet_chunk_step_nki.py create mode 100644 contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py create mode 100644 contrib/models/Qwen3.6-27B/scripts/validate_deltanet_recurrent_step_nki.py create mode 100644 contrib/models/Qwen3.6-27B/scripts/validate_qwen_segcte_attention.py create mode 100644 contrib/models/Qwen3.6-27B/src/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/src/hybrid_apc.py create mode 100644 contrib/models/Qwen3.6-27B/src/modeling_qwen35.py create mode 100644 contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py create mode 100644 contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused_legacy.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/qwen_qk_norm_rope.py create mode 100644 contrib/models/Qwen3.6-27B/test/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/test/integration/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py create mode 100644 contrib/models/Qwen3.6-27B/test/integration/test_model.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_config.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_manager.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_validation.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_qwen36_artifact_config_audit.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_qwen36_chat_proxy.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_qwen36_model_aliases.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_vllm_scheduler_patch.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_vllm_serving_config.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/README.md create mode 100644 contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py create mode 100755 contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh create mode 100644 contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/patches/vllm_neuron_qwen36_kv_cache_spec.patch create mode 100644 contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/qwen36_hybrid_apc_scheduler_patch.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/sitecustomize.py create mode 100755 contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh create mode 100644 src/neuronx_distributed_inference/modules/attention/nki_kernels/__init__.py create mode 100644 src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_gated_output_projection.py create mode 100644 src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/__init__.py create mode 100644 src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py create mode 100644 src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py create mode 100644 src/neuronx_distributed_inference/modules/kvcache/hybrid_prefix_cache.py create mode 100644 test/unit/modules/kvcache/test_hybrid_prefix_cache.py create mode 100644 test/unit/scripts/test_qwen36_openai_compat_server.py create mode 100644 test/unit/scripts/test_qwen36_validation_gates.py create mode 100644 validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_258k_chunk_timing.json create mode 100644 validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_summary.json create mode 100644 validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/log_scan_empty.txt create mode 100644 validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_probe_20260609T000000Z.jsonl create mode 100644 validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl create mode 100644 validation_scripts/neuron_memory_sampler.py create mode 100644 validation_scripts/qwen36_artifact_config_audit.py create mode 100644 validation_scripts/qwen36_bf16_length_sweep.py create mode 100644 validation_scripts/qwen36_chat_completion_context_bench.py create mode 100644 validation_scripts/qwen36_hf_first_mismatch_logits.py create mode 100644 validation_scripts/qwen36_hf_neuron_greedy_match.py create mode 100644 validation_scripts/qwen36_hybrid_apc_context_sweep.py create mode 100644 validation_scripts/qwen36_hybrid_apc_validation.py create mode 100644 validation_scripts/qwen36_offline_decode_bench.py create mode 100644 validation_scripts/qwen36_openai_boundary_apc_probe.py create mode 100644 validation_scripts/qwen36_openai_chat_apc_validation.py create mode 100644 validation_scripts/qwen36_split_qkv_tkg_probe.py create mode 100644 validation_scripts/qwen36_steady_cold_prefill_bench.py diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md new file mode 100644 index 00000000..a6327f07 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/README.md @@ -0,0 +1,344 @@ +# Contrib Model: Qwen3.6-27B + +NeuronX Distributed Inference implementation of Qwen3.6-27B, a 27B parameter dense model from Alibaba Cloud with a hybrid DeltaNet + GQA attention architecture. + +## Relationship to Qwen3.5-27B + +Qwen3.6-27B is a **post-training update** of Qwen3.5-27B with improved agentic coding and thinking preservation. The models share **identical architecture** (`qwen3_5` model_type, `Qwen3_5ForConditionalGeneration`) -- only weights differ. This contrib reuses the same NxDI implementation as [Qwen3.5-27B](../Qwen3.5-27B/) (PR #128). Any code updates to Qwen3.5-27B should be propagated to this contrib and vice versa. + +### Config differences from Qwen3.5-27B + +| Field | Value | Impact | +|-------|-------|--------| +| `output_gate_type` | `"swish"` | **Ignored** -- not used by HF transformers or NxDI (gate uses sigmoid) | +| `language_model_only` | `false` | Informational, not used by model code | +| `bos_token_id` | `248044` | New but not architecture-relevant | +| `pad_token_id` | `null` | New at text_config level (already handled) | +| `partial_rotary_factor` | `0.25` | Already in rope_parameters, redundant copy | +| `transformers_version` | `4.57.1` | Updated from `4.57.0.dev0` | + +No architecture changes are required relative to the Qwen3.5-27B hybrid +implementation. This contrib packages the NxDI Qwen3.6-27B model code, +DeltaNet NKI kernels, FP8/vLLM serving helpers, and validation coverage for the +Qwen3.6 weights. + +## Model Family + +| Model | HuggingFace ID | Params | Instance | +|-------|----------------|--------|----------| +| **Qwen3.6-27B** | [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) | 27B | trn2.3xlarge (TP=4) | + +**License:** Apache 2.0 + +## Architecture Details + +| Feature | Value | +|---------|-------| +| Layers | 64 (48 DeltaNet + 16 GQA) | +| Layer Pattern | [3 DeltaNet + 1 GQA] x 16 | +| Hidden Size | 5120 | +| GQA Attention | 24 heads, 4 KV heads, head_dim=256 | +| DeltaNet Attention | 48 value heads, 16 key heads, k_dim=v_dim=128 | +| Dense MLP | SwiGLU (gate_proj + up_proj: 5120 -> 17408, down_proj: 17408 -> 5120) | +| Position Encoding | Partial RoPE (25% of head_dim = 64 dims), mRoPE for VL | +| Vocabulary | 248,320 | +| Normalization | RMSNorm with +1 weight convention | +| Activation | SiLU gated MLP | + +### Unique Architecture Features + +- **Hybrid DeltaNet + GQA:** 48 of 64 layers use Gated DeltaNet (linear recurrent attention), 16 layers use standard GQA with KV cache. The pattern repeats every 4 layers: 3 DeltaNet + 1 GQA. +- **DeltaNet Linear Attention:** Uses the delta rule for recurrent state updates with gated decay. Per-step: `state *= exp(g); delta = (v - state^T @ k) * beta; state += outer(k, delta); output = state^T @ q`. Runs as a chunked algorithm for context encoding, per-token recurrence for token generation. +- **Custom NKI Kernels:** Three NKI kernels implement the DeltaNet forward pass on Neuron: a per-token recurrent kernel (TKG), a per-chunk kernel (legacy), and a fused single-kernel chunked forward (CTE). The fused kernel uses a Neumann series for intra-chunk correction with state persistence in SBUF across chunks. +- **GQA Output Gate:** Attention layers use a sigmoid output gate. `q_proj` is 2x sized and interleaved: `[head0_query | head0_gate | head1_query | ...]`. The gate is split during weight conversion and applied after attention. +- **Partial RoPE:** Only 25% of head_dim (64 of 256 dimensions) receives rotary embeddings. The remaining 192 dimensions are identity (no rotation). +- **+1 RMSNorm Convention:** HF weights use `output = norm(x) * (1 + weight)` where weight is initialized to zeros. Converted to standard `output = norm(x) * weight` during loading by adding 1.0 to all RMSNorm weights (except DeltaNet internal norms, which use standard convention). +- **Vision-Language Support:** Optional ViT encoder runs on CPU (HBM fully consumed by 27B text decoder). Vision embeddings are injected via a scatter mask at traced input positions. + +## Test Results + +### Unit Tests (CPU) + +| Test Module | Tests | Status | +|-------------|-------|--------| +| test_config.py | 26 | 26/26 PASS | +| test_weight_conversion.py | 16 | 16/16 PASS | +| test_hybrid_cache_manager.py | 13 | 13/13 PASS | +| test_deltanet_decay.py | 2 | 2/2 PASS | +| **Total** | **57** | **57/57 PASS** | + +Unit tests are architecture-level and do not depend on weights. Coverage includes config parsing, weight conversion, hybrid cache allocation/update behavior, and DeltaNet decay handling. + +### Quality Validation (Qwen3.6-27B, trn2.3xlarge, TP=4, SDK 2.29) + +7/7 text-only quality tests passed with `enable_thinking=False`: + +| Test | Expected | Result | +|------|----------|--------| +| Speed of light | 299,792,458 m/s | PASS | +| 17 * 23 | 391 | PASS | +| 60mph * 2.5h | 150 miles | PASS | +| is_prime function | Correct Python | PASS | +| French translation | Bonjour, comment allez-vous ? | PASS | +| Capital of Japan | Tokyo | PASS | +| sqrt(144) | 12 | PASS | + +## Performance Benchmarks + +### Qwen3.6-27B on trn2.3xlarge (TP=4, LNC=2, SDK 2.29, BF16) + +**TTFT (Time To First Token)** + +| Input Length | P50 (ms) | P95 (ms) | +|-------------|----------|----------| +| 16 tokens | 305.3 | 305.6 | +| 64 tokens | 305.4 | 305.9 | +| 128 tokens | 306.6 | 306.8 | +| 256 tokens | 306.2 | 306.3 | + +**TPOT / Throughput** + +| Output Length | TPOT P50 (ms) | tok/s P50 | E2E P50 (ms) | +|--------------|---------------|-----------|---------------| +| 16 | 54.3 | 18.4 | 1,121 | +| 32 | 54.4 | 18.4 | 1,993 | +| 64 | 54.2 | 18.5 | 3,720 | +| 128 | 54.2 | 18.5 | 4,912 | + +### Comparison with Qwen3.5-27B + +| Metric | Qwen3.5-27B | Qwen3.6-27B | Delta | +|--------|------------|------------|-------| +| TPOT P50 | 53 ms | 54.2 ms | +2.3% | +| Throughput | 18.9 tok/s | 18.5 tok/s | -2.1% | +| TTFT (128 tok) | 576 ms | 306.6 ms | -47% * | + +\* TTFT improvement is due to compilation config differences (256-token bucket vs 128-token bucket), not model differences. Architectural performance is equivalent. + +### Long-Context vLLM Baseline + +A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) +with the vLLM Neuron plugin, Qwen chunked prefill, and native vLLM APC enabled. + +| Metric | Result | +|--------|--------| +| Max model length | 131,072 tokens | +| Context encoding bucket | 512 | +| Prefill throughput | 404-428 tok/s from 512 through 64K prompt tokens | +| Decode throughput | 26.3-26.6 tok/s | +| 64K quality | needle retrieval prompts returned all expected codes | +| State reset | repeated short-after-long validation passed after 32K and 64K requests | +| Peak Neuron device memory | ~53.25 GB decimal during the 64K eval | + +TTFT/TPOT details for the same 128K FP8/vLLM artifact: + +| Metric | Result | Notes | +|--------|--------|-------| +| Decode TPOT | ~37.6-38.0 ms/token | Derived from 26.3-26.6 tok/s decode | +| Cold 512-token TTFT | ~1.2-1.3s | Derived from measured prefill throughput plus one decode step | +| Cold 32K-token TTFT | ~76.6-81.1s | Derived from measured prefill throughput plus one decode step | +| Cold 64K-token TTFT | ~153-162s | Derived from measured prefill throughput plus one decode step | +| Warm APC latency, ~10.8K prompt | 1.36-2.38s | Exact-repeat, partial-prefix, and cross-prefix validation runs | +| Cold APC baseline, ~10.8K prompt | 25.17-26.68s | Same prompts with prefix cache disabled or cold | + +Native vLLM prefix caching/APC was also validated with exact greedy output +matches: + +| APC Scenario | Cold | Warm | Speedup | Result | +|--------------|------|------|---------|--------| +| Server exact-repeat, ~10.8K prompt tokens | 26.68s | 1.67s | 16.0x | exact text match | +| Offline exact-repeat | 26.19s | 2.38s | 11.0x | exact token-ID match | +| Offline partial-prefix reuse | 25.52s | 1.70s | 15.0x | exact token-ID match | +| Server cross-prefix reuse | 25.17s | 1.36s | 18.5x | exact text match | + +### Key Observations + +- **BF16 TP=4 is HBM-limited:** The pure BF16 path is limited to short contexts on trn2.3xlarge. The validated 128K baseline uses MLP-only FP8 weights plus the hybrid cache manager. +- **DeltaNet enables efficient TKG:** Token generation uses O(1) per-token recurrence instead of O(n) KV cache attention for 48/64 layers. +- **vLLM APC is high leverage:** Repeated-prefix requests avoid replaying long chunked prefill and are the largest observed latency win for chat/RAG-style workloads. +- **Performance equivalent to Qwen3.5-27B:** The BF16 TPOT difference is within measurement noise. Expected since architectures are identical. + +## Usage + +### Text-Only (trn2.3xlarge, TP=4) + +```python +import json +import torch +from transformers import AutoTokenizer, GenerationConfig +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + +model_path = "/path/to/Qwen3.6-27B" +compiled_path = "/scratch/qwen36_traced/" + +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + enable_bucketing=False, + flash_decoding_enabled=False, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + save_sharded_checkpoint=True, +) + +# Read config.json directly (model_type 'qwen3_5' may not be +# registered in all transformers versions) +import os +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) +text_config = hf_config.get("text_config", hf_config) +config_dict = dict(text_config) +config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) +config_dict.setdefault("tie_word_embeddings", False) + +config = Qwen35InferenceConfig( + neuron_config=neuron_config, + **config_dict, +) + +model = NeuronQwen35ForCausalLM(model_path, config) +model.compile(compiled_path) + +# Reload from compiled artifacts +model = NeuronQwen35ForCausalLM(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") +gen_config = GenerationConfig( + do_sample=True, top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, +) + +inputs = tokenizer("The capital of France is", return_tensors="pt") +gen_model = HuggingFaceGenerationAdapter(model) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=50, +) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +### Vision-Language (trn2.3xlarge, TP=4) + +The VL pipeline uses the text decoder on Neuron and the vision encoder on CPU: + +```python +from src.modeling_qwen35_vl import NeuronQwen35VLForCausalLM, Qwen35VLInferenceConfig + +vl_model = NeuronQwen35VLForCausalLM( + model_path="/path/to/Qwen3.6-27B", + config=vl_config, +) +vl_model.compile(compiled_path) +vl_model.load(compiled_path) + +# See test/integration/test_model.py for full VL usage example +``` + +### DeltaNet Kernel Selection + +The DeltaNet forward path can be controlled via environment variables: + +| Env Var | Forward Path | Use Case | +|---------|-------------|----------| +| `USE_NKI_FUSED=1` | Fused chunked NKI kernel | Best CTE performance (default for SDK 2.29) | +| `USE_NKI_CHUNKED=1` | Per-chunk NKI kernel | Legacy, superseded by fused | +| `USE_NKI=1` | Per-token NKI kernel | TKG (always used for token generation) | +| `DELTANET_SEQUENTIAL=1` | Sequential PyTorch | Debugging/reference | +| *(none)* | PyTorch chunked | Default fallback for CTE | + +## Caveats + +1. **BF16 HBM pressure at TP=4:** The pure BF16 model consumes nearly all HBM on trn2.3xlarge. Use the FP8/vLLM path for the validated 128K artifact, or a larger instance for additional batching/headroom. + +2. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29). No library modifications needed -- runs on stock SDK 2.29 DLAMI (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). + +3. **No mini model test:** Unlike DeepSeek-V3, a mini model cannot be provided because DeltaNet layers require NKI kernels that only execute on Neuron devices. Integration tests require a trn2 instance with the full 27B weights. + +4. **Vision encoder runs on CPU:** The ViT cannot be placed on Neuron because HBM is fully consumed by the text decoder. This adds ~918ms latency per image. Future optimization: quantize text decoder to free HBM, or use larger instance. + +5. **Compilation time:** The short-context BF16 path compiles in roughly 13 minutes. The validated 128K FP8/vLLM artifact takes longer because it includes long-context cache shapes and presharded checkpoints. + +6. **+1 RMSNorm convention:** Qwen3.5/3.6 uses `output = norm(x) * (1 + weight)` for most RMSNorm layers, but DeltaNet internal norms use standard `output = norm(x) * weight`. The weight conversion handles this automatically, but custom weight loading must be aware of both conventions. + +7. **DeltaNet numerical stability:** DeltaNet kernels rely on normalized Q/K inputs and bounded decay handling. The chunked path includes regression coverage for decay handling; changes to the fused kernel should be validated against the CPU reference and long-context stress prompts. + +8. **Shared codebase with Qwen3.5-27B:** This contrib uses the same `Qwen35*` class names and `modeling_qwen35*.py` filenames as the [Qwen3.5-27B contrib](../Qwen3.5-27B/). This is intentional -- both models share the `qwen3_5` model_type. The code is identical; only the HuggingFace model ID and weights differ. + +## Maximum Sequence Length + +| seq_len | Path | Status | Notes | +|---------|------|--------|-------| +| 128 | BF16 NxDI | **PASS** | BF16 baseline/quality checks | +| 256 | BF16 NxDI | **PASS** | BF16 benchmark bucket | +| 512 | BF16 NxDI | **PASS** | 4 DeltaNet chunks | +| 65,536 | FP8/vLLM | **PASS** | chunked prefill, quality, and state-reset validation | +| 131,072 | FP8/vLLM | **PASS** | compiled and served with 512-token CTE bucket | + +For production long-context serving on trn2.3xlarge, use the FP8/vLLM artifact +and 512-token context encoding bucket. Larger instances are recommended for +larger batches or additional serving headroom. + +## Compatibility Matrix + +| Instance | TP | LNC | Status | Notes | +|----------|-----|-----|--------|-------| +| trn2.3xlarge | 4 | 2 | **PASS** | BF16 short-context and FP8 128K vLLM/APC validated | +| trn2.12xlarge | 16 | 2 | Expected PASS | Untested, recommended for batching/headroom | + +### SDK Configuration + +| Component | Version | +|-----------|---------| +| NxDI | 0.9.17334 | +| neuronx-cc | 2.24.5133 | +| torch | 2.9.1 | +| transformers | 4.57.6 | +| NKI | 0.3.0 | +| NXDI venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | + +## Testing + +### Unit Tests (CPU only, no device needed) + +```bash +cd contrib/models/Qwen3.6-27B/ +# On DLAMI: source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pytest test/unit/ -v +``` + +Tests: config parsing (26), weight conversion (16), hybrid cache manager (13), and DeltaNet decay handling (2) = **57 tests**. + +### Integration Tests (needs trn2.3xlarge with 4 NeuronCores) + +```bash +cd contrib/models/Qwen3.6-27B/ + +QWEN35_MODEL_PATH=/mnt/models/Qwen3.6-27B \ +QWEN35_COMPILED_PATH=/mnt/models/qwen36_traced \ +pytest test/integration/test_model.py --capture=tee-sys +``` + +Tests: model loads, generates, coherence, top-token valid, capital test, TTFT, throughput, multi-prompt = **8 tests**. + +Note: The env var is `QWEN35_MODEL_PATH` (not `QWEN36`) because the code uses the `qwen3_5` model_type internally. + +## Example Checkpoints + +- [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) (BF16, ~52 GB) + +## Maintainer + +AWS Neuron + +**Last Updated:** 2026-04-23 diff --git a/contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md b/contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md new file mode 100644 index 00000000..7e758a85 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md @@ -0,0 +1,193 @@ +# Codex Prompt — Enable Continuous Batching for Qwen3.6-27B + +## Context + +vLLM v1 already does continuous batching. The bottleneck is on the Neuron +side: the current MTP artifact (and baseline v3) was compiled with +`tkg_batch_size=1`, meaning the device can only execute one decode stream +per forward call regardless of how many vLLM tries to schedule. + +To enable real continuous batching: recompile with `tkg_batch_size > 1` so +the device-side decode graph processes multiple sequences in parallel. + +Current compile harness has it hardcoded: +- `contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_mtp.py:79-81` + - `batch_size=1, ctx_batch_size=1, tkg_batch_size=1` + +vLLM start script ALREADY wires `tkg_batch_size = MAX_NUM_SEQS` via +`--override-neuron-config`. This override only takes effect if the +underlying NEFF was compiled with the matching batch size. **No runtime +override of compile-time batch dimension is possible.** + +## Goal + +Compile and validate a continuous-batching artifact with `tkg_batch_size=8` +(and matching `batch_size=8`). Demonstrate aggregate throughput scaling +with `max-num-seqs=8` on real workloads. Document HBM peak. + +## Phase A: Compile harness update (target 0.5 day) + +A.1 Modify `contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_mtp.py` +to accept a CLI `--tkg-batch-size` argument (default 1 for backward compat). +Apply to NeuronConfig: +```python +batch_size = args.tkg_batch_size +ctx_batch_size = 1 # prefill stays single-stream per CTE call +tkg_batch_size = args.tkg_batch_size +``` + +A.2 Add `--max-num-seqs` and `--max-model-len` mirrors if not already present. + +A.3 Reduce `seq_len` for the first batched run to keep HBM in budget: +- batch=8, seq_len=16384 → KV cache ~8 GB, total HBM ~63 GB (fits) +- batch=8, seq_len=32768 → KV cache ~16 GB, total HBM ~71 GB (fits) +- batch=8, seq_len=65536 → KV cache ~34 GB, total HBM ~89 GB (tight) + +Start with batch=8, seq_len=32768. Validate. Push to 65536 only if HBM +budget allows. + +A.4 Document in the compile harness which compile flags need to match the +vLLM `--max-num-seqs` value at serve time. + +## Phase B: Compile + load (target 0.5 day) + +B.1 Compile artifact: +```bash +python contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_mtp.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-path /opt/dlami/nvme/qwen_artifacts/qwen36_27b_32k_fp8_batch8_run1 \ + --seq-len 32768 \ + --cte-bucket 512 \ + --tp-degree 4 \ + --logical-nc-config 2 \ + --tkg-batch-size 8 \ + --load-after-compile +``` + +Expected compile time: ~22 min. Slightly longer than batch=1 due to bigger +TKG graph. + +B.2 Load artifact on hardware. Verify load succeeds and HBM peak after +load is below 80 GB (leaves headroom for activations during inference). + +B.3 If NRT_RESOURCE (HBM blew up): drop seq_len to 16384 or batch to 4. +Report which tensor exceeded budget. + +## Phase C: Single-stream regression check (target 0.25 day) + +C.1 Bring up vLLM server with `--max-num-seqs 8` pointing to the new +artifact: +```bash +bash contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_32k_fp8_batch8_run1 \ + --max-num-seqs 8 \ + --seq-len 32768 \ + --enable-chunked-prefill \ + --enable-prefix-caching +``` + +C.2 Single-stream smoke (one request at a time): +- Math: 17 × 23 should return 391 +- 762-token MGS prompt: coherent output + +C.3 Single-stream perf (one request at a time): +- 4K prompt, 128-token decode: measure prefill tok/s + decode tok/s +- Compare against baseline v3 (~418 prefill, ~27 decode) +- Expect: equivalent prefill, slightly lower decode (batch=8 graph has + some per-step overhead even at active batch=1). Acceptable: within 10%. + +If decode regresses more than 20% from baseline v3: there's a configuration +issue. Investigate before continuing. + +## Phase D: Aggregate throughput measurement (target 0.5 day) + +D.1 Use existing `validation_scripts/qwen36_27b_vllm_concurrency_eval.py` +or write a small async harness if needed. Test at: +- concurrency=1 (baseline) +- concurrency=2 +- concurrency=4 +- concurrency=8 (full batch) +- concurrency=16 (tests queueing behavior) + +D.2 Two prompt distributions: +- "Short": 1K prompts, 128 token decode (chat-like workload) +- "Medium": 8K prompts, 256 token decode (RAG-like workload) + +D.3 Capture for each (concurrency, prompt_len) point: +- Aggregate input tok/s +- Aggregate output tok/s +- Per-stream input tok/s +- Per-stream output tok/s +- P50 / P95 TTFT +- P50 / P95 inter-token latency (TPOT) +- HBM peak (neuron-monitor during run) + +D.4 Expected scaling pattern: +- concurrency=1: ~baseline single-stream +- concurrency=4: ~2-3× aggregate (sub-linear because per-stream slows) +- concurrency=8: ~3-5× aggregate at the batch ceiling +- concurrency=16: queued, aggregate same as 8 but TTFT spikes + +If concurrency=8 aggregate is NOT 3× the concurrency=1 number: batching +isn't actually happening at the device. Verify by checking +neuron-monitor: should see batch=8 graph activity, not batch=1. + +## Phase E: APC interaction (target 0.25 day) + +E.1 Add a shared prefix to the prompts (system message + variable user +turn). Repeat the concurrency=8 measurement. + +E.2 Expected: APC hit rate ≥ 50% across the 8 concurrent streams (they +share the system prompt). Aggregate prefill should jump significantly on +warm streams. + +E.3 If APC hit rate stays low when streams share a prefix: APC cache is +being evicted across concurrent streams. Investigate cache size limits. + +## Phase F: Documentation (target 0.25 day) + +F.1 Update `OPTIMIZATION_ARC.md` with the continuous-batching results: +- Add a "Continuous batching" row to the "What worked" table +- Update the "Hardware utilization" section with the aggregate numbers +- Update the "How this compares to NVIDIA" table with aggregate Trainium + numbers (the Millstone H100 page shows aggregate at 5 concurrent) + +F.2 Create `vllm/CONTINUOUS_BATCHING.md` with: +- Compile flags required +- vLLM serve flags required +- Measured throughput curve (concurrency 1-16) +- HBM budget table by (batch, seq_len) +- APC interaction notes + +## Hard constraints + +1. Do not modify baseline v3 artifact. Tag the new artifact as + `qwen36-27b-continuous-batching-v1` if all gates pass. +2. Commit + push after each phase. +3. Maximum compile attempts: 3. Each ~22 min. +4. If HBM exceeds 92 GB at batch=8: drop seq_len to 16384 and retry. +5. Do not enable MTP speculation in this artifact (defer to PR #4). Spec + decoding + continuous batching together is harder; tackle one at a time. + +## Expected outcomes + +| Outcome | Probability | Meaning | +|---|---|---| +| batch=8 compiles, loads, scales to 3-5× aggregate | 60% | Best case; ship as PR #2 (continuous batching baseline) | +| batch=8 compiles but scales less than 2× | 20% | Diagnose: probably KV cache contention or scheduler overhead | +| HBM blowup, must drop to batch=4 or seq_len=16K | 15% | Acceptable fallback; still 2-3× aggregate | +| Quality regression at batch>1 | 5% | Bug in hybrid cache at batch>1; investigate | + +Begin with Phase A. Report after Phase B. Do not chain phases. + +## Why this is high priority + +Currently single-stream decode is 27 tok/s. After continuous batching +with batch=8: +- Aggregate decode probably 150-200 tok/s (4-7× single) +- This is the metric that maps to "production serving capacity" +- Without it, you cannot answer "how many users can one instance serve?" +- Without it, the cost-per-token comparison vs H100 cannot be made + +This is the prerequisite for the multi-instance scaling discussion and +for any honest production-deployment claims. diff --git a/contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md b/contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md new file mode 100644 index 00000000..ec7a8b4a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md @@ -0,0 +1,405 @@ +# Qwen3.6 27B Full FP8 — Issues Encountered and Fixes + +Consolidated catalog of every issue hit during the full-FP8 / 256K hybrid-APC +work and how each one was resolved. Branch: `codex/full-fp8-qwen36`. + +Source-of-truth detail (with exact log lines, PIDs, artifact paths) lives in: + +- [QWEN36_FP8_TIERFIX_VALIDATION_20260526.md](./QWEN36_FP8_TIERFIX_VALIDATION_20260526.md) — full chronological log +- [HYBRID_APC_PRODUCTION_PLAN.md](./HYBRID_APC_PRODUCTION_PLAN.md) — production bucket strategy +- [profile_artifacts/qwen36_256k_fp8_sparse_runtime_20260525/ERROR_LOG.md](../../../../profile_artifacts/qwen36_256k_fp8_sparse_runtime_20260525/ERROR_LOG.md) — runtime load failures +- [AGENTS.md](../../../../AGENTS.md) — error-logging contract and measurement rules + +This document is the index. Each entry: **what broke → why → what we changed → verification**. + +--- + +## Table of Contents + +1. [Quantization & Checkpoint Conversion](#1-quantization--checkpoint-conversion) +2. [Neuron Compiler Failures](#2-neuron-compiler-failures-neuronx-cc) +3. [vLLM / Hybrid APC / Scheduler](#3-vllm--hybrid-apc--scheduler) +4. [Runtime Load & Memory (NRT_RESOURCE / scratchpad / HBM)](#4-runtime-load--memory-nrt_resource--scratchpad--hbm) +5. [Custom NKI Kernel (`qwen_segcte256`)](#5-custom-nki-kernel-qwen_segcte256) +6. [Validation Harness & Measurement Bugs](#6-validation-harness--measurement-bugs) +7. [Tooling, Sync, Shell, SSH](#7-tooling-sync-shell-ssh) +8. [Lessons Codified in `AGENTS.md`](#8-lessons-codified-in-agentsmd) + +--- + +## 1. Quantization & Checkpoint Conversion + +### 1.1 MLP-only FP8 scope was insufficient for "full FP8" + +- **Symptom:** Original path only quantized MLP layers; attention, DeltaNet projections, and fused QKV stayed BF16. +- **Cause:** `_mlp_only_modules_to_not_convert` excluded entire `self_attn` and `linear_attn` modules; checkpoint rewrite only handled MLP scale tensors. +- **Fix:** Added `fp8_full` mode in [qwen36_27b_compile_fp8.py](../test/integration/qwen36_27b_compile_fp8.py); broadened module selector to all Linear matmuls (MLP + attention + DeltaNet `in_proj_*` / `out_proj`); kept embeddings, norms, rotary, `lm_head`, DeltaNet `conv1d`/`A_log`/`dt_bias` in BF16. +- **Verification:** Unit tests in [test_qwen36_compile_fp8_config.py](../test/unit/test_qwen36_compile_fp8_config.py). + +### 1.2 Scale tensors not transformed alongside weights + +- **Symptom:** Loading FP8 artifact failed because scale tensors didn't match the transformed weights (Q/gate split, fused QKV concat, DeltaNet QKV TP reorder). +- **Cause:** Checkpoint converter in [modeling_qwen35.py](../src/modeling_qwen35.py) only transformed `.weight`; FP8 needs the matching `.scale` to follow the same reorder/split/concat. +- **Fix:** Added scale-aware transforms in the converter for: q_proj weight/scale split → `output_gate_proj`, fused `Wqkv` weight+scale creation, and DeltaNet `in_proj_qkv.weight/scale` TP reorder using identical permutation. FP8 concat uses `view(torch.int8)` round-trip because PyTorch rejects direct `torch.float8_e4m3fn` concat. +- **Verification:** [test_weight_conversion.py](../test/unit/test_weight_conversion.py). + +--- + +## 2. Neuron Compiler Failures (`neuronx-cc`) + +### 2.1 `NCC_ITIN902 TensorInitialization` / `AffineIV doesn't appear in params or loopnest` + +- **Symptom:** Compiler internal error during NEFF tensorization on specific 2D prefix-caching bucket pairs: + - `cte=256, prefix=16384` + - `cte=1024, prefix=1024` + - `cte=2048, prefix=2048` + - `cte=4096, prefix=4096` +- **Cause:** Compiler bug in `neuronx-cc` lowering on power-of-two square shapes and the small-active/large-prefix corner. AWS log itself says "open a Neuron SDK issue." +- **Fix:** Avoid those exact shapes. Use safe CTE ladder `512, 768, 1536, 3072` (256-aligned, non-square) and limit prefix-bucket granularity. +- **Verification:** `cte512_768_1536_3072_pfx16k` compile completed with `COMPILE_DONE` and 0 `NCC_ITIN902`. + +### 2.2 Combined dense + long-prefix artifact — `[F137] neuronx-cc forcibly killed (-9)` + +- **Symptom:** Compiling all five long-prefix pairs (`3072:0,32k,64k,128k,256k`) in one run was killed by OOM on the largest buckets (`bk3`, `bk4`). +- **Cause:** Compile-host RAM pressure when multiple HLOs compile in parallel for very large shapes. +- **Fix:** Split tiers into separate compile runs. Implemented sequential orchestrator in [tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh](../../../../tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh). +- **Verification:** Three tiered artifacts (`pfx32k_64k`, `pfx128k`, `pfx256k`) all reached `COMPILE_DONE` on the same host once compiled sequentially. + +### 2.3 Bash script wrote artifact paths with spaces (`tkg32768 131072 262144`) + +- **Symptom:** Orchestrator created malformed paths because the launcher used `${ARR[*]}` instead of joining with `_`. +- **Cause:** Bash array word-splitting in label construction. +- **Fix:** Use explicit `IFS=_` join in the helper script. +- **Verification:** Relaunch produced underscore-only paths. + +### 2.4 `head_dim must be <= 128 (got 256)` — `NCC_INKI016` + +- **Symptom:** AWS Neuron 2.30 `attention_segmented_cte` kernel hard-asserts `head_dim <= 128`; Qwen3.6 uses `head_dim=256`. +- **Cause:** Kernel was not designed for 256-wide head dim. +- **Fix:** Wrote custom `qwen_segcte256` kernel that splits Q/K into two 128-wide D tiles and accumulates `Q_lo@K_lo + Q_hi@K_hi` into one PSUM before softmax. See [Section 5](#5-custom-nki-kernel-qwen_segcte256). +- **Verification:** Custom kernel BIR-compiled cleanly for production shape `q=(2,3072,256)`, `k/v=(1024,1,256,256)`, `prior_seg_size=32768`. + +### 2.5 `NCC_INLA001 Allocated memory out of bound (128x402724)` — SBUF scratch too large + +- **Symptom:** First version of the custom segmented CTE kernel compiled HLO but exceeded SBUF in the backend. +- **Cause:** Each Q group held its own K/V segment buffers + scratch live simultaneously; per-group block-dim allocation × 24 Q groups blew SBUF. +- **Fix:** Two-stage kernel rewrite in [fused_segmented_attention_256.py](../../../../src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py): + 1. Allocate one reusable Q-group window instead of `block_dim=[num_grps]`. + 2. Stream active CTE in 512-token chunks through the same online-softmax accumulator. + 3. Cap packed Q loads to 4 groups (instead of 8) for `head_dim=256`. +- **Verification:** Production-shape BIR scratch dropped from `402724` to `31360`, under the 32767 SBUF free-dim limit. Full compile produced `COMPILE_DONE`. + +--- + +## 3. vLLM / Hybrid APC / Scheduler + +### 3.1 64 GiB KV cache estimate on 96 GiB Trn2 (model wouldn't even start) + +- **Symptom:** vLLM rejected 256K context with `64.0 GiB KV cache needed, 39.12 GiB available`. +- **Cause:** vLLM-Neuron runner created `FullAttentionSpec` for all 64 layers. Qwen3.6 is hybrid — only 16 of 64 layers are full-attention; the other 48 are DeltaNet (no token-long KV). +- **Fix:** Patched `get_kv_cache_spec` in [qwen36_hybrid_apc_scheduler_patch.py](../vllm/qwen36_hybrid_apc_scheduler_patch.py:129) to report KV only for the 16 full-attention layers, with local KV heads per TP rank. +- **Verification:** Server log: `Using Qwen hybrid KV-cache spec for 16/64 attention layers`, `GPU KV cache size: 262,400 tokens`. + +### 3.2 Warm prefix-cache continuation crashed with "no `hybrid_full_input_ids`" + +- **Symptom:** Second request reusing a cached prefix died because runner received suffix-only tokens without the full prompt context the GDN path needs. +- **Cause:** Scheduler metadata didn't carry full `all_token_ids` through suffix-prefill requests; runner's strict guard rightly rejected. +- **Fix:** Scheduler patch now attaches `full_input_ids` only when `num_computed_tokens > 0` (cached continuation), not for the first cold chunk; async prep bridge unpacks it back to `hybrid_full_input_ids` and slices to active suffix length. +- **Verification:** [test_hybrid_apc_manager.py](../test/unit/test_hybrid_apc_manager.py) + working 8k→16k→18432 cold/warm exactness on TRN2. + +### 3.3 `request_prefix_len` polluted by generated tokens + +- **Symptom:** During decode, vLLM `request.num_tokens` grows past the original prompt length; that leaked into APC metadata and made cold vs warm runs schedule differently. +- **Cause:** Metadata used `request.num_tokens` instead of the original prompt length. +- **Fix:** Cap `request_prefix_len` to the prompt-only token count. +- **Verification:** Cold and warm 8k runs now schedule identically (`prompt_len=8192 restore_len=6144 suffix_len=2048`). + +### 3.4 Dummy token `0` (`!!!!`) leaked into output during chunked prefill + +- **Symptom:** Cold output started with two `0` tokens before real decoding; warm output started correctly. +- **Cause:** vLLM-Neuron host-logits sampling appended `sampled_token_ids` from incomplete chunked-prefill rows. Earlier mask attempt was on the wrong path (`_sample_on_device` instead of `_generate_model_runner_output` for `hostlogits` artifacts). +- **Fix:** Added `_generate_model_runner_output` wrapper that masks incomplete-prefill rows before vLLM appends sampled IDs to request state. Used scalar `.item()` to work on Neuron/XLA tensors. +- **Verification:** 8k cold and warm now both emit `[3817, 7840, 9197, 4590]`. Test coverage added. + +### 3.5 Chunked prefill above 16k prefix exceeded compiled prefix bucket + +- **Symptom:** Cold 32k prompt failed with `Prefill len 512 with prefix len 16896 exceeds compiled 2D buckets... largest prefix bucket 16384`. +- **Cause:** vLLM's chunked-prefill continuation presents (active_chunk, computed_prefix) shapes to NxDI's 2D bucket selector; with `pfx16k`, a 32k prompt eventually reaches prefix `16896`. +- **Fix:** Two approaches: + 1. Runtime cap on backed prefix reads (`QWEN36_HYBRID_APC_MAX_BACKED_PREFIX_READ_LEN`). + 2. Split production artifacts by prefix tier; route long contexts to the long artifact. +- **Verification:** Tiered split (`pfx16k` for short, `pfx32k_64k`, `pfx128k`, `pfx256k` for long) compiled and validated for context up to 128K. 256K required the custom kernel (see §5). + +### 3.6 Sparse 2D bucket support needed — runtime assumed full Cartesian grid + +- **Symptom:** Wanting sparse pairs like `cte=3072 × prefix=262144` only (without the failing `cte=512 × prefix=262144`) wasn't possible — runtime did `bucket_idx = prefill_index * len(prefix_buckets) + prefix_index`. +- **Cause:** NxDI runtime hard-assumed a rectangular bucket grid. +- **Fix:** Added `context_encoding_bucket_pairs` config + sparse-pair-aware runtime selection in [model_wrapper.py](../../../../src/neuronx_distributed_inference/models/model_wrapper.py:1126) and [autobucketing.py](../../../../src/neuronx_distributed_inference/modules/autobucketing.py:162). Wired through compile script and vLLM serving config. +- **Verification:** Unit tests in [test_autobucketing.py](../../../../test/unit/modules/test_autobucketing.py) and [test_prefix_caching_bucket_selection.py](../../../../test/unit/models/test_prefix_caching_bucket_selection.py). + +### 3.7 Async sample called before any `execute_model()` (V1 scheduler) + +- **Symptom:** With async scheduling, `sample_tokens()` was invoked once before any cached logits existed; Neuron runner raised. +- **Cause:** vLLM-Neuron runner had no "no output yet" guard like the GPU path. +- **Fix:** Added no-output guard in the runner wrapper. + +### 3.8 Contract mismatch: `expected 24/29 tensors, got 15` + +- **Symptom:** With prefix caching disabled at vLLM level but Hybrid APC enabled, the model wrapper got only 15 mandatory tensors while artifact expected 29. +- **Cause:** Compiled artifact's input contract is fixed at trace time. Runtime config flipping `is_prefix_caching` off without recompiling broke the contract. +- **Fix:** + 1. Server script preserves the compiled `is_prefix_caching` contract from `neuron_config.json` even when vLLM-level prefix caching is off. + 2. Qwen wrapper expands 15-tensor runtime input to 24/29-tensor traced form by padding with inert MRoPE/vision/tile tensors. +- **Verification:** Unit coverage in [test_qwen36_model_aliases.py](../test/unit/test_qwen36_model_aliases.py). + +--- + +## 4. Runtime Load & Memory (`NRT_RESOURCE` / scratchpad / HBM) + +### 4.1 Combined sparse artifact failed to load on `trn2.3xlarge` + +- **Symptom:** Artifact compiled fine, but TRN2 load failed with `Failed to allocate 1.000GB ... usage: shared scratchpad` at `_tp0_bk36` (long-prefix NEFF). +- **Cause:** Trn2.3xlarge has 96 GiB total but in four 24 GiB HBM banks under LNC=2. Per-bank usage hit `~22-24 GiB` (tensors + scratchpad) before runtime needed another aligned 1 GiB allocation. The "combined" artifact loaded **every** compiled CTE×prefix NEFF at once. +- **Fix:** Physical split into multiple artifacts; route requests to the smallest artifact that covers the prefix tier. +- **Verification:** `pfx32k_64k` and `pfx128k` loaded and ran end-to-end after the split. + +### 4.2 Runtime bucket-override JSON didn't reduce loaded NEFFs + +- **Symptom:** Setting `--context-encoding-bucket-pairs 512:0 512:512` at runtime still failed at `_tp0_bk36` load. +- **Cause:** Saved `model.pt` references all compiled workdir NEFFs; runtime overrides control routing, not which NEFFs get staged. +- **Fix:** Per §4.1 — split artifacts physically; runtime overrides alone are insufficient. + +### 4.3 `NEURON_SCRATCHPAD_PAGE_SIZE=2048` did not help + +- **Symptom:** Tried larger scratchpad page size to relieve alignment pressure; still failed with `Failed to allocate 2.000GB`. +- **Cause:** Total scratchpad footprint, not just alignment fragmentation. +- **Fix:** Abandon page-size-only mitigation for over-broad artifacts; compile narrower artifacts. + +### 4.4 Initial three-tier artifacts compiled but failed to load + +- **Symptom:** `pfx32k_64k`, `pfx128k`, `pfx256k` all compiled with `seq_len=262144`, `pa_num_blocks=1024`, `tkg=[32768,131072,262144]` — and all failed `NRT_RESOURCE` at load. +- **Cause:** "Tiered" by prefix only; every tier still paid the full 256K cache and 3 TKG buckets. +- **Fix:** True tier-specific budgets in [tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh](../../../../tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh): + - `pfx32k_64k`: `seq_len=65536`, `pa_num_blocks=256`, `tkg=[32768,65536]`, keep dense `3072:0`. + - `pfx128k`: `seq_len=131072`, `pa_num_blocks=512`, `tkg=[131072]`, omit dense `3072:0`. + - `pfx256k`: `seq_len=262144`, `pa_num_blocks=1024`, `tkg=[262144]`, omit dense `3072:0`. +- **Verification:** All three tierfix artifacts loaded; `pfx32k_64k` and `pfx128k` passed prefill + chat. + +### 4.5 Device profiling caused `NRT_RESOURCE` on `pfx256k` load + +- **Symptom:** First 256K runtime validation died because `NEURON_RT_INSPECT_DEVICE_PROFILE=1` reserved `2.348 GB HBM per NC`, pushing per-bank load over the edge. +- **Cause:** Device profiler adds non-trivial HBM tax. +- **Fix:** Run validation without `NEURON_RT_INSPECT_DEVICE_PROFILE`. Profile separately on smaller artifacts or with reduced sampling. + +### 4.6 Null block (vLLM adds 1 reserved block) — `pa_num_blocks=1024` was too small + +- **Symptom:** vLLM logs showed `num_gpu_blocks` becoming `1025` after the runtime adds a reserved null block, but compiled artifact only had 1024 physical blocks. +- **Cause:** Off-by-one between compile-time `pa_num_blocks` and runtime "user-usable + null" convention. +- **Fix:** Compile with `pa_num_blocks=1025`. Validation runners now treat compiled count as physical (includes null) and set `num_gpu_blocks_override` to `compiled - 1`. +- **Verification:** Compile config logged `pa_num_blocks=1025, pa_min_blocks=1024, pa_headroom_blocks=1`. Updated [qwen36_hybrid_apc_context_sweep.py](../../../../validation_scripts/qwen36_hybrid_apc_context_sweep.py) + [qwen36_offline_decode_bench.py](../../../../validation_scripts/qwen36_offline_decode_bench.py). + +--- + +## 5. Custom NKI Kernel (`qwen_segcte256`) + +Required because AWS Neuron 2.30 `attention_segmented_cte` rejects `head_dim > 128`. + +### 5.1 `dma_copy dst partition dimension 256 exceeds maximum 128` + +- **Symptom:** BIR compile failed when loading K cache: K SBUF tile shape `(256, 512)` violated the 128-partition rule. +- **Cause:** Tried to keep `head_dim=256` on the partition axis. +- **Fix:** Load each 256-token KV block as two 128-token halves: temp `(128, 128)`, transpose each, write into 128-token offset inside K tile. + +### 5.2 `unsupported expression` — list comprehensions + +- **Symptom:** `[(k_lo[i], k_hi[i]) for i in range(...)]` rejected by NKI specialization. +- **Cause:** NKI front-end doesn't accept Python list comprehensions inside kernel helpers. +- **Fix:** Build the list with explicit `for ... append`. + +### 5.3 `failed to resolve name 'x::0.shape'` + +- **Symptom:** After splitting K into `(lo, hi)` pair, old metadata lookup `k_sbuf[0].shape[1]` returned `.shape` from the pair tuple. +- **Fix:** Branch the metadata lookup to use `k_sbuf[0][0].shape[1]` on the split-K path. + +### 5.4 `dma_transpose dst.shape must match transposed src.shape` + +- **Symptom:** Q load pattern used `ac.d=256` as D extent while destination was 128. +- **Fix:** Use 128-wide D extent in source pattern: `[[ac.d, num_f], [1,1], [1,1], [1,128]]`. + +### 5.5 `reduce_one_batch` signature mismatch + +- **Symptom:** Compile failed with `batch_idx * sb_p * num_grps` where `batch_idx` was an object. +- **Cause:** Copied call signature didn't match installed Neuron 2.30 helper's argument order. +- **Fix:** Call with explicit keyword arguments matching the installed helper. + +### 5.6 `NCC_INLA001 Allocated memory out of bound (128x402724)` + +- See [§2.5](#25-ncc_inla001-allocated-memory-out-of-bound-128x402724--sbuf-scratch-too-large). The fix (group-window aliasing + active streaming + Q-pack cap) reduced production-shape SBUF scratch from `402724` to `31360`. + +### 5.7 `_exp_impl` partial-sum slot index out of range + +- **Symptom:** Active-streaming variant tried to index exp partial-sum slot 1 when each 512-token chunk only allocated slot 0. +- **Cause:** Active attention config still referenced full active KV length per chunk instead of per-chunk view. +- **Fix:** Specialize the active attention config per chunk with that chunk's `global KV end`, retain global K start via `kv_section_idx`. + +### 5.8 Runtime `scalar DGE out-of-bound access` at 256K prefill (PA-blocks) + +- **Symptom:** Compile passed, model loaded, KV initialized, then context-encoding NEFF crashed mid-execution with repeated scalar-DGE OOB. +- **First hypothesis tested:** vLLM adds a null block (`1025` physical), but artifact had `pa_num_blocks=1024`. +- **Fix attempted:** Recompiled with `pa_num_blocks=1025`. **Did not fix it** — runtime still hit DGE OOB on the new artifact. + +### 5.9 Runtime DGE OOB — root cause: final partial active chunk reads past block table + +- **Symptom:** Even with `pa1025`, the 261,888-token prefill failed in `context_encoding_model/_tp0_bk0` with scalar-DGE OOB. +- **Cause:** Active stream loop always processed 6 full sections per CTE bucket, even when the final real active chunk was only 768 tokens. At the end of the 256K prompt, the kernel read block-table offsets beyond the 1024-entry table. +- **Fix:** Pad the kernel's internal block table by the CTE active block count (1024 → 1036 entries). Padded active stream loads resolve to block 0 instead of reading past the table. +- **Verification:** Bound-fix artifact compiled (`COMPILE_DONE`), and the no-device-profile 256K runtime validation passed: + - Cold 261,888 prefill: `551.97s` + - Warm refill (16-token suffix on shared 261,872-token prefix): `10.76s` + - Cold throughput: `474.46 tok/s`; warm refill throughput: `24,342.68 tok/s` + - Real-token + token-range checks: passed + - Host RSS peak: `35.31 GiB`; Neuron active allocation peak: `~28 GiB`; high-water counter: `58 GiB` + +### 5.10 Block-table active-block-fill (necessary but not sufficient) + +- **Symptom:** Earlier hypothesis was that `block_table` had `-1` entries for the active suffix. +- **Fix attempted:** Fill active block ids from `slot_mapping // pa_block_size` before NKI dispatch. Aligned with AWS docs on `nisa.dma_copy` dynamic addressing. +- **Result:** Helped the 8K smoke test but did NOT fix the 256K case. The real bug was §5.9. + +### 5.11 Production envelope and fail-closed hardening + +- **Finding:** The bound-fix 256K artifact has strong validation evidence, but only for the exact serving envelope: 256K context, `pa_num_blocks=1025`, one `cte3072:pfx262144` bucket, `qwen_segcte256` segment size 512, batch/concurrency 1, backed prefix reads, non-KVP, and non-transposed K cache. +- **Risk:** Enabling Hybrid APC outside the blessed vLLM launcher could previously fall back to local prompt hashing or synthetic attention block refs. +- **Fix:** When `use_hybrid_apc_manager=True`, `Qwen35InferenceConfig` now defaults to requiring vLLM metadata and attention block refs, with local hash fallback disabled. Validation-only flows can still opt back into local fallback explicitly. +- **Risk:** The generic ModelWrapper used absolute Hybrid APC control positions (`args[25]`) for restore-active detection. +- **Fix:** Restore-active detection now reads from the final five Hybrid APC control args, so future pre-control extras do not silently misbucket CTE. +- **Risk:** `qwen_segcte256` still exposed KVP and transposed-K branches that were not validated for production and contained NKI 0.3-sensitive HBM output/intermediate patterns. +- **Fix:** `qwen_segcte256` now raises immediately for `kvp_offset`/KVP or `k_pre_transposed=True`. The validated production path remains the non-KVP, non-transposed K path used by `attention_base.py`. + +--- + +## 6. Validation Harness & Measurement Bugs + +### 6.1 TPOT measured from streamed content chunks, not tokens + +- **Symptom:** Reported TPOT was `~109 ms/chunk` at 16K context (with 16 generated tokens → only 8 streamed content chunks), masking real decode speed. +- **Fix:** [qwen36_chat_completion_context_bench.py](../../../../validation_scripts/qwen36_chat_completion_context_bench.py) now requests `stream_options: {"include_usage": true}` and computes `token_tpot_seconds` from `usage.completion_tokens`. Old chunk metric preserved as `content_chunk_tpot_seconds`. +- **Verification:** Corrected 16k pfx16k measurement: `~50-52 ms/token`, `~19-20 decode tok/s`. + +### 6.2 "Warm prefill" was actually full-prompt cache replay + +- **Symptom:** Sub-second warm runs were misinterpreted as refill speed. +- **Cause:** [qwen36_hybrid_apc_context_sweep.py](../../../../validation_scripts/qwen36_hybrid_apc_context_sweep.py) generated the exact same prompt twice — that's an exact cache hit, not a refill. +- **Fix:** Default warm mode now: warm shared prefix + suffix A, then measure shared prefix + suffix B. +- **Verification:** Corrected 16k partial refill: `0.91s` for a 16,368-token shared prefix → ~`18k tok/s` reuse rate. + +### 6.3 Sweep accepted dummy token `0` (`!!!!`) as "real" output + +- **Symptom:** Validator's `vocab_size=248044` check passed because token `0` was within range, masking the chunked-prefill output leak. +- **Fix:** Tighter validation: explicitly fail if all generated tokens equal the configured dummy id, regardless of vocab bounds. Then use `usage.completion_tokens` + tokenizer/AutoConfig vocab fallback for true range check. + +### 6.4 Hardcoded `seq_len=262144, pa_num_blocks=1024` for every tier + +- **Symptom:** Three-tier validation runner forced 256K cache shape on the 64K and 128K tiers, causing `NRT_RESOURCE`. +- **Fix:** [tmp_run_qwen256k_fp8_tierfix_validation.sh](../../../../tmp_run_qwen256k_fp8_tierfix_validation.sh) now uses per-tier `(seq_len, pa_num_blocks, tkg buckets)`. + +### 6.5 Chat wrapper passed `--pa-num-blocks` to server script (unknown arg) + +- **Symptom:** `start_vllm_server.sh` rejected `--pa-num-blocks`. +- **Fix:** Pass `--pa-num-blocks` only to offline benchmarks; server gets `--num-gpu-blocks-override` via the appropriate path. + +### 6.6 Memory sampler could hang the wrapper if benchmark never started + +- **Symptom:** With `--stop-when-no-match`, sampler waited forever if vLLM died during startup. +- **Fix:** Sampler ignores its own PID, handles SIGTERM/SIGINT to write summary JSON; wrapper explicitly stops sampler per phase rather than relying on regex disappearance. Sampler regex broadened to match vLLM server processes during startup. + +### 6.7 `start_vllm_server.sh` forced `ENABLE_PREFIX_CACHING=1` when `--enable-hybrid-apc` + +- **Symptom:** Couldn't test "Hybrid APC on, prefix-cache reads off" because flags were coupled. +- **Fix:** Split controls: `ENABLE_PREFIX_CACHING`, `ENABLE_HYBRID_APC`, `HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS`, `HYBRID_APC_ENABLE_BACKED_PREFIX_READS`, and `QWEN36_HYBRID_APC_INSTALL_PATCH` are now independent. + +### 6.8 Validator's `vocab_size` check rejected legitimate model tokens + +- **Symptom:** Model emitted token `248068`, valid for the loaded model (`vocab_size=248320`) but above the tokenizer's base `vocab_size=248044`. +- **Fix:** Use `max(tokenizer.vocab_size, len(tokenizer), AutoConfig.vocab_size)` as the upper bound. + +--- + +## 7. Tooling, Sync, Shell, SSH + +### 7.1 Stale remote code (no `--context-encoding-bucket-pairs`) + +- **Symptom:** Remote compile script lacked sparse-pair CLI flag even though local repo had it. +- **Fix:** Sync the compile entrypoint along with runtime code; `bash -n` + `py_compile` checks before launching. + +### 7.2 `scp` multi-file → wrong directory + +- **Symptom:** Multi-file `scp` of mixed sources landed extra copies in the last directory. +- **Fix:** Use explicit per-file destinations or `rsync -R`. Cleaned the misplaced copies and removed them. + +### 7.3 Local zsh expanded `*` in remote command + +- **Symptom:** `ssh host "find /path/*"` failed locally because zsh tried to glob the path on the Mac. +- **Fix:** Quote remote command bodies; use single quotes around the SSH command argument. + +### 7.4 `rsync --info=stats2` rejected by macOS BSD rsync + +- **Fix:** Use portable `--stats`. + +### 7.5 SSH `Permission denied (publickey)` for EC2-to-EC2 transfers + +- **Symptom:** Source EC2 had no key for destination. +- **Fix:** Three options used at various times: + 1. SSH agent forwarding from local `trainium.pem`. + 2. Temporary ed25519 key created on source, authorized on destination, removed after transfer. + 3. `scp -3` relay through local (slow — avoid for large artifacts). + +### 7.6 `pkill -f` matched its own SSH command, killed the shell + +- **Symptom:** Cleanup SSH exited 255 with no output because broad `pgrep -f` pattern matched the SSH command line itself. +- **Fix:** Use explicit PIDs from prior status or narrower patterns; never use `pkill -f` patterns that could match the controlling shell. + +### 7.7 Remote `python` not on PATH + +- **Symptom:** Status/parsing commands failed with `python: command not found`. +- **Fix:** Use `python3` for remote helpers; activate Neuron venv for actual runtime work. + +### 7.8 Overlay venv missing PyTorch / `libneuronpjrt-path` + +- **Symptom:** Neuron 2.30 overlay venv had `nki 0.4` but no PyTorch; later, `torch_xla` failed to find the base venv's `libneuronpjrt-path` helper. +- **Fix:** Compile launcher adds base venv `site-packages` behind overlay, and base venv `bin` to `PATH`. + +### 7.9 Backgrounded shell ate `$ROOT` (`tee /run.log`) + +- **Symptom:** Wrapper backgrounded too broadly; nested var expansion broke; ended up writing to `/run.log` and python wasn't on PATH. +- **Fix:** Cleaner wrapper structure: start sampler separately as a tracked nohup PID; run benchmark in main subshell with explicit env activation. + +### 7.10 TRN2 SSH banner timeout / port unreachable during heavy compile + +- **Symptom:** SSH banner exchange timed out, then later TCP itself stopped. Local AWS CLI had stale credentials so couldn't inspect instance state. +- **Mitigation:** Use light periodic probes (not long live-tails) during heavy compiles; keep heartbeat automation as the resume signal. + +--- + +## 8. Lessons Codified in `AGENTS.md` + +Two operational rules added to the repo's [AGENTS.md](../../../../AGENTS.md): + +1. **Error-logging contract:** every error gets logged with what failed, exact error text, how we got there, hypothesis, fix, and verification — enough detail that another agent can reconstruct it. +2. **Measurement discipline:** + - TPOT must come from `usage.completion_tokens` (request `stream_options.include_usage`), not streamed content chunks. + - "Warm refill" requires a shared prefix + different suffix; identical prompts only measure exact cache hits. + - Record artifact, CTE buckets, prefix buckets, and whether backed prefix reads were enabled with every reported number. + +--- + +## Final State Summary + +| Tier | Artifact | Status | +|---|---|---| +| 16K | `cte512_768_1536_3072_pfx16k` | **Production-validated** — chat + multi-turn smoke pass | +| 64K | `pfx32k_64k_pa256` | **Loads + runs**, prefill + chat pass | +| 128K | `pfx128k_pa512` | **Loads + runs**, prefill + chat pass | +| 256K | `pfx256k_segcte512stream_qpack4_boundfix_pa1025` | **Validated only for the exact gated config** — cold `551.97s`, warm refill `10.76s`, real tokens validated | + +**Open work before general "production-ready":** repeat 256K runs (×3-5), full OpenAI server path test on `pfx256k`, multi-turn chat at long context, soak/load test, and fresh validation for any other bucket, KVP, transposed K cache, sliding-window, or multi-seq serving configuration. diff --git a/contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md b/contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md new file mode 100644 index 00000000..571b0266 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md @@ -0,0 +1,364 @@ +# Qwen3.6 Hybrid APC Production Plan + +## Build Order + +```text +1. Production hybrid APC correctness +2. Dynamic CTE bucket serving +3. Block-size, bucket, and HBM tuning +4. GDN state dtype and memory optimization +5. Decode-side improvements +6. Kernel fusion and speculative decode +``` + +Do not start with FP8 recurrent cache, MTP, EAGLE, Medusa, flash decode, KV +tiling, or deeper GDN kernel fusion. Those add scheduler and rollback +complexity before the cache contract is correct. + +## Target Cache Object + +```text +HybridPrefixCheckpoint + cumulative_prefix_hash + token_ids_hash + cache_salt / tenant key + prefix_length_at_boundary + + attention: + per-attention-layer KV block refs + + gdn: + per-GDN-layer recurrent_state checkpoint + per-GDN-layer conv_state checkpoint + + metadata: + dtype + layout_version + model_revision + ref_count + last_access_time + valid_state_mask +``` + +The usable hit is the deepest cumulative-prefix boundary where all required +state exists: + +```text +usable_hit_len = + intersection( + attention_KV_full_block_hit, + all_GDN_recurrent_prefix_checkpoint_hits, + all_GDN_conv_prefix_checkpoint_hits + ) +``` + +If attention KV hits 16K but GDN state only hits 12K, suffix prefill must resume +from 12K. + +## Qwen3.6 GDN State + +At every reusable cumulative-prefix boundary, cache: + +```text +recurrent_state: [num_local_value_heads, key_dim, value_dim] +conv_state: [conv_dim, conv_kernel_size - 1] +``` + +Initial dtype policy: + +```text +attention KV: bfloat16 +GDN conv_state: bfloat16 +GDN recurrent_state: float32 +``` + +Conv state is small but correctness-critical. Recurrent state dominates GDN +cache memory and should remain FP32 until BF16 exactness is proven. + +## Restore Flow + +For prompt length `P` and hybrid hit length `H`: + +```text +cached prefix: tokens [0, H) +suffix prefill: tokens [H, P) +decode: tokens [P, ...) +``` + +Serving path: + +```text +1. vLLM hashes prompt blocks. +2. Hybrid APC computes usable H. +3. Restore attention block table for [0, H). +4. Restore GDN recurrent_state at H. +5. Restore GDN conv_state at H. +6. Send only suffix tokens [H, P) to Neuron CTE. +7. Position IDs start at H. +8. Attention suffix attends to cached KV plus new suffix KV. +9. GDN recurrence starts from restored recurrent_state. +10. GDN conv starts from restored conv_state. +11. Store new boundary checkpoints for newly completed blocks. +12. Decode uses final restored and updated state. +``` + +## Sprint Plan + +### Sprint 1: Correctness Foundation + +Build: + +```text +HybridAPCManager +GDN recurrent/conv prefix-boundary checkpoint cache +hybrid hit intersection +partial-prefix restore path +FP32 recurrent cache option +correctness tests +``` + +Success criteria: + +```text +warm full-prefix output == cold output +partial-prefix output == cold output +attention-only false hit cannot happen +concurrent requests do not leak state +``` + +Current v0 branch status: + +```text +implemented: + HybridAPCMetadataStore for cumulative-prefix checkpoint metadata + bounded model-side HybridGDNCheckpointCache tensor bank + model restore/commit slot inputs + use_hybrid_apc_manager initialization without the old guard + v0 launcher validation requiring checkpoint interval == block size + async prefix-caching bridge for scheduler-supplied restore/commit tensors + request finish/cancel lifecycle callbacks for checkpoint refcounts + Trainium exactness and HBM validation harness + +still required before production: + vLLM scheduler integration that computes cumulative-prefix hashes and slots + Trainium execution of cold/warm exactness harness on compiled artifacts + production cancellation/eviction callback wiring from vLLM events + long-context HBM sweep to choose checkpoint slot count and commit policy + larger production prefix buckets for 32K+ warm reuse +``` + +Production prefix-bucket plan: + +```text +Previous 256K FP8 artifact was correct only up to its compiled prefix bucket +coverage: + prefix_buckets = [256, 512, 1024, 2048, 4096, 8192, 16384] + +32K/64K/128K contexts can still run on the 256K artifact, but warm APC reuse +above 16K must replay the remainder. This is correct but slower. + +Production strategy is one sparse 2D CTE/prefix artifact, not two separate +models: + + dense fast path: + CTE buckets = [512, 768, 1536, 3072] + prefix buckets = [0, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] + + long-prefix fallback: + [CTE 3072, prefix 65536] + [CTE 3072, prefix 131072] + [CTE 3072, prefix 262144] + +The dense fast path is for common short/normal cached prefixes and preserves +prefill speed by avoiding unnecessary padding to 3072. The sparse long-prefix +fallback enables 64K/128K/256K prefix reuse without compiling the full CTE x +prefix Cartesian grid that triggers Neuron compiler tensorization failures. +``` + +Implementation notes: + +```text +compile flag: + --context-encoding-bucket-pairs ACTIVE:PREFIX ... + +runtime behavior: + Prefix-caching CTE bucket selection now chooses the smallest actual compiled + [active_tokens, prefix_tokens] pair that can serve the request, instead of + assuming every CTE bucket exists for every prefix bucket. + +serving behavior: + vLLM override config forwards context_encoding_bucket_pairs so loaded + artifacts use the same sparse matrix they were compiled with. +``` + +## Fixed Bug Record: Neuron Tensorization Failure on Full 2D Prefix Grid + +```text +What failed: + 256K FP8 full Hybrid APC compile with pfx256k and multiple CTE buckets. + +How we got there: + Host: ubuntu@16.26.202.235 + Script: + tmp_compile_qwen256k_fp8_full_cte512_768_1536_3072_pfx256k_hostlogits.sh + Key args: + --seq-len 262144 + --max-context-length 262144 + --cte-buckets 512 768 1536 3072 + --prefix-buckets 256 512 1024 2048 4096 8192 16384 32768 65536 131072 262144 + --weight-dtype fp8_full + --enable-prefix-caching + --enable-hybrid-apc + --enable-vllm-chunked-prefill + +Exact error: + NCC_ITIN902 TensorInitialization error: + AffineIV doesn't appear in params or loopnest + +Failed generated buckets: + bk9 = [CTE 512, prefix 65536] + bk10 = [CTE 512, prefix 131072] + bk21 = [CTE 768, prefix 65536] + bk22 = [CTE 768, prefix 131072] + +Root cause hypothesis: + HLO generation succeeds, then neuronx-cc fails inside internal tensorization + for some small-active-token / large-prefix-token 2D prefix-cache shapes. This + is a Neuron compiler lowering bug, not disk pressure and not an invalid model + config. + +Fix: + Stop compiling the full Cartesian product. Add explicit sparse + context_encoding_bucket_pairs and route runtime selection over the actual + compiled pair list. + +Mitigation shape set: + Dense fast path only up to 32K prefix for all production CTE buckets: + [512/768/1536/3072] x [0..32768] + Long-prefix fallback only on largest CTE bucket: + [3072, 65536], [3072, 131072], [3072, 262144] + +Verification: + Unit/config tests passed: + 38 local contrib tests passed + 86 remote Neuron-env focused tests passed + Sparse high-prefix probe compile started with 7 CTE HLOs and no NCC_ITIN902 + observed at HLO generation time; final NEFF compile result must still be + checked before treating the sparse artifact as production-ready. +``` + +## Fixed Bug Record: Invalid Fast Warm Prefill + +This bug is useful to showcase because the first symptom looked like excellent +performance, but the warm path was not executing the same model semantics as +cold prefill. + +```text +Symptom: + Warm prefill appeared sub-second, but cold/warm generated token IDs diverged. + Cold also leaked placeholder token IDs: + cold = [0, 0, 3817, 7840] + warm = [3817, 7840, 9197, 4590] + +Root causes: + 1. vLLM attention prefix hits could exceed the deepest GDN checkpoint that + was actually available. + 2. Scheduler metadata used request token counts that could include generated + tokens instead of prompt-only tokens. + 3. Incomplete chunked-prefill rows in the host-logits path could append + placeholder sampled IDs as real generated tokens. + +Fix: + 1. Cap vLLM prefix-cache reads to the largest GDN-backed checkpoint. + 2. Build Hybrid APC metadata from prompt-only length/token IDs. + 3. Mask incomplete chunked-prefill sampled IDs to -1 before vLLM appends + them to request state. + +Evidence after fix: + 8K cold/warm exactness passed: + cold = [3817, 7840, 9197, 4590] + warm = [3817, 7840, 9197, 4590] + repeat_exact = true + + Warm prefill became slower than the invalid shortcut, but correct: + cold ~= 15.26s + warm ~= 4.95s +``` + +### Sprint 2: Dynamic CTE Buckets + +Build: + +```text +multi-bucket CTE artifact path +runtime suffix bucket selection +262K TP=4 [256] artifact +block_size 128/256 comparison +``` + +Success criteria: + +```text +short prompts retain 1.5x-2.3x latency gain +262K TP=4 [256] loads +TP=4 beats TP=8 unless TP=4 cannot load +``` + +### Sprint 3: Memory and HBM Tuning + +Build: + +```text +GDN recurrent state slot accounting +eviction/ref-count policy +FP32 vs BF16 recurrent experiment +attention KV memory report +hybrid cache memory dashboard +``` + +### Sprint 4: Decode Optimization + +Build: + +```text +lower-overhead GDN state gather/scatter +decode microbenchmarks +batch-slot reuse optimization +possibly fused recurrent step +``` + +## Test Matrix + +Correctness: + +```text +cold vs warm exact token IDs +partial-prefix exact match +non-block-aligned shared prefix floors to full block +attention hit with missing GDN state falls back +conv-state restore failure test by zeroing conv state +multi-hit chat simulation +mixed cold/warm continuous batching +long-context warm hit at 128K and 262K +``` + +Performance: + +```text +Context length: 256, 512, 2K, 8K, 32K, 128K, 262K +Block size: 64, 128, 256 +CTE buckets: [256], [512], [256,512], [256,512,1024] +TP: 4, and 8 only if HBM/load requires it +Cache mode: no APC, attention APC only, hybrid APC +GDN dtype: recurrent FP32, recurrent BF16 experiment +Workloads: single request, repeated system prompt, chat, long-doc QA +``` + +Immediate Trainium experiments: + +```text +262K TP=4, block_size=256, CTE buckets [256] +262K TP=4, block_size=128, CTE buckets [256] +128K TP=4, block_size=128, CTE buckets [256,512] +128K TP=4, block_size=256, CTE buckets [256,512] +``` diff --git a/contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md b/contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md new file mode 100644 index 00000000..6c61706f --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md @@ -0,0 +1,2364 @@ +# Qwen3.6 27B FP8 Tierfix Validation - 2026-05-26 + +This note records the 2026-05-26 TRN2 validation of the three prefix-tier FP8 +artifacts and the current blocker for 256K prefix serving. + +Raw result JSON is stored at: + +```text +profile_artifacts/qwen36_fp8_tierfix_validation_20260526/summary_partial_with_pfx256_failure.json +``` + +Remote validation root: + +```text +/home/ubuntu/validation_logs/fp8_256k/tierfix_validation_20260526T152617Z +``` + +Test host: + +```text +ubuntu@16.50.61.215 +instance: trn2.3xlarge +logical-neuroncore-config: 2 +``` + +## Artifact Results + +### 32K/64K Prefix Tier + +Artifact: + +```text +/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_64k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx32k_64k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx32k_64k_pa256_slots64_tkg32768_65536_async_20260526T132620Z_tierfix_pfx32k_64k +``` + +Compiled limits: + +```text +seq_len=65536 +max_context_length=65536 +pa_num_blocks=256 +pa_block_size=256 +prefix_buckets=[32768, 65536] +context_encoding_bucket_pairs=[[3072, 0], [3072, 32768], [3072, 65536]] +token_generation_buckets=[32768, 65536] +``` + +Prefill: + +| target tokens | cold prefill | cold TPS | warm refill | warm refill TPS | real tokens | +| --- | ---: | ---: | ---: | ---: | --- | +| 32768 | 60.596s | 540.77 | 5.480s | 5979.19 | pass | +| 65280 | 121.736s | 536.24 | 5.667s | 11519.06 | pass | + +Chat/decode: + +| target tokens | run | TTFT | TPOT | decode TPS | completion tokens | +| --- | --- | ---: | ---: | ---: | ---: | +| 32768 | cold | 59.913s | 79.67ms | 12.55 | 64 | +| 32768 | repeat | 5.549s | 78.36ms | 12.76 | 64 | +| 65280 | cold | 67.959s | 83.66ms | 11.95 | 64 | +| 65280 | repeat | 5.785s | 83.50ms | 11.98 | 64 | + +Runtime evidence: + +```text +vLLM reported GPU KV cache size: 65,792 tokens +max concurrency for 65,536 tokens: 1.00x +peak host RSS during chat: 34.04 GiB +``` + +### 128K Prefix Tier + +Artifact: + +```text +/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_128k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx128k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx128k_pa512_slots64_tkg131072_async_20260526T132620Z_tierfix_pfx128k +``` + +Compiled limits: + +```text +seq_len=131072 +max_context_length=131072 +pa_num_blocks=512 +pa_block_size=256 +prefix_buckets=[131072] +context_encoding_bucket_pairs=[[3072, 131072]] +token_generation_buckets=[131072] +``` + +Prefill: + +| target tokens | cold prefill | cold TPS | warm refill | warm refill TPS | real tokens | +| --- | ---: | ---: | ---: | ---: | --- | +| 130816 | 298.355s | 438.46 | 6.972s | 18762.74 | pass | + +Chat/decode: + +| target tokens | run | TTFT | TPOT | decode TPS | completion tokens | +| --- | --- | ---: | ---: | ---: | ---: | +| 130816 | cold | 298.550s | 173.88ms | 5.75 | 64 | +| 130816 | repeat | 7.250s | 173.41ms | 5.77 | 64 | + +Runtime evidence: + +```text +vLLM reported GPU KV cache size: 131,328 tokens +max concurrency for 131,072 tokens: 1.00x +peak host RSS during chat: 33.75 GiB +``` + +### 256K Prefix Tier + +Artifact: + +```text +/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T132620Z_tierfix_pfx256k +``` + +Compiled limits: + +```text +seq_len=262144 +max_context_length=262144 +pa_num_blocks=1024 +pa_block_size=256 +prefix_buckets=[262144] +context_encoding_bucket_pairs=[[3072, 262144]] +token_generation_buckets=[262144] +``` + +Result: failed during Neuron Runtime load before validation could run. + +Exact failure: + +```text +NRT_RESOURCE in nrt_load_util +Failed to allocate 1.000GB (alignment: 4.000MB, usage: shared scratchpad) +Failed to load NN: + .../context_encoding_model/_tp0_bk0/model.MODULE_c3eddc16a94d9c7dfe80+5c498585.neff +err: 4 +Failed to create logical core info for subgraph 0 to 1 +Failed to stage graph to NeuronCore +Failed to load collectives for model +``` + +TDRV memory table at failure: + +```text +per-HBM TOTAL: 22.056GB +Model Tensors: 12.052GB +Shared Scratchpad: 10.000GB +Failed next alloc: 1.000GB shared scratchpad +``` + +Retry with debug scratchpad placement disabled: + +```text +NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 +root=/home/ubuntu/validation_logs/fp8_256k/tierfix_validation_pfx256_dbg0_20260526T160225Z +``` + +Result: still failed with the same `NRT_RESOURCE` class. + +```text +per-HBM TOTAL: 22.056GB +Model Tensors: 12.052GB +Shared Scratchpad: 7.000GB on one logical core + 3.000GB on sibling +Failed next alloc: 1.000GB shared scratchpad +``` + +Probe with smaller runtime scratchpad page: + +```text +NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 +NEURON_SCRATCHPAD_PAGE_SIZE=512 +root=/home/ubuntu/validation_logs/fp8_256k/pfx256_pagesize512_probe_20260526T160430Z +``` + +Result: still failed. + +```text +NRT_RESOURCE in nrt_load_util +Failed to allocate 512.000MB (alignment: 4.000MB, usage: shared scratchpad) +per-HBM TOTAL: 23.056GB +Model Tensors: 12.052GB +Shared Scratchpad: 11.000GB +``` + +## Why 256K Prefix Fails + +The current 256K prefix artifact is not failing because GDN attention KV cache +needs a full-attention 256K cache. It fails earlier: Neuron Runtime cannot load +the 256K context-encoding NEFF because the compiled NEFF's model tensors plus +shared scratchpad exceed the usable HBM slice for that logical placement. + +AWS Neuron's device-memory documentation describes HBM usage categories such as +model tensors, shared scratchpad, non-shared scratchpad, DMA rings, and runtime +allocations. It also documents that scratchpad page size must be coordinated +between compile-time `NEURON_CC_FLAGS=--hbm-scratchpad-page-size=...` and +runtime `NEURON_SCRATCHPAD_PAGE_SIZE=...`; changing only runtime placement/page +size is not guaranteed to repair a NEFF whose compiled scratchpad layout is too +large. See: + +```text +https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-runtime/explore/device-memory.html +``` + +AWS Trainium2 documentation lists 96 GiB of device memory per Trainium2 chip, +but this validation shows the failing pfx256 CTE NEFF is constrained by the +per-HBM/logical-placement allocation shown in the TDRV table, not by aggregate +host RAM or the headline chip memory number. See: + +```text +https://awsdocs-neuron.readthedocs-hosted.com/en/latest/about-neuron/arch/neuron-hardware/trainium2.html +``` + +Current hypothesis: + +```text +The pfx256 context-encoding NEFF at [CTE 3072, prefix 262144] has too much +compiled tensor + shared scratchpad footprint for trn2.3xlarge LNC2 placement. +Runtime-only tweaks did not reduce that footprint enough. Fixing it requires a +new compile with lower pfx256 CTE scratchpad/tensor footprint, a different +tiling/page-size compile, or avoiding the pfx262144 CTE NEFF. +``` + +## Can We Use 128K Prefix and Infer 256K Context? + +Not with the current 128K artifact. + +The current 128K artifact is a 128K-total artifact: + +```text +max_context_length=131072 +seq_len=131072 +pa_num_blocks=512 +token_generation_buckets=[131072] +``` + +It cannot serve or decode a 256K context because the compiled position range, +KV capacity, and token-generation bucket stop at 131072. + +A separate 256K-total artifact with only a 128K prefix bucket is a valid next +mitigation to test: + +```text +seq_len=262144 +max_context_length=262144 +pa_num_blocks=1024 +prefix_buckets=[131072] +context_encoding_bucket_pairs=[[3072, 131072]] +token_generation_buckets=[262144] +``` + +That would be semantically valid for 256K context if it loads, but it changes +the caching behavior: + +```text +cached reusable prefix: up to 128K +remaining prompt suffix: replay/refill up to the requested context length +decode positions: up to 256K, because max_context_length and tkg are 256K +``` + +So 128K prefix is not a replacement for 256K context. It is a cache boundary +inside a 256K-capable artifact. It should be correct, but slower than true +pfx256 reuse for prompts where the reusable shared prefix is above 128K. + +Risk: + +```text +This still needs a 256K token-generation/KV-capable artifact. The pfx128 CTE +NEFF may avoid the pfx256 shared-scratchpad load failure, but the 256K decode +and PA footprint still must be compiled and load-tested before we call it +production-ready. +``` + +## Robust 256K Prefix Fix Under Test + +The robust fix is to keep the 256K prefix bucket but stop compiling the +prefix-attention CTE as one monolithic `[active_tokens, prefix_tokens]` score +tensor. + +Implementation: + +```text +src/neuronx_distributed_inference/models/config.py + NeuronConfig.prefix_cte_attention_chunk_size + +src/neuronx_distributed_inference/modules/attention/attention_base.py + NeuronAttentionBase.perform_prefix_prefill_chunked_prior() + +contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py + --prefix-cte-attention-chunk-size +``` + +Behavior: + +```text +If prefix_cte_attention_chunk_size is set and prior_len exceeds it, prefix CTE +attention streams cached-prefix K/V in fixed chunks and combines the chunks with +online softmax. This avoids materializing the full [Q, prefix] score tensor. +The compiled bucket can still be [CTE 3072, prefix 262144]. +``` + +Why this is the robust path: + +```text +The failed pfx256 compile produced an 11GB page-aligned scratchpad requirement +for the pfx256 context-encoding NEFF. 32K/64K prefix-tier artifacts already +compiled and loaded. Streaming pfx256 as eight 32K chunks should bound live +attention-score memory near the proven smaller prefix shapes while preserving +correct full-256K prefix semantics. +``` + +Compile probe started: + +```text +host: ubuntu@16.51.90.254 +pid: 59247 +artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_stream32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_stream32k_pa1024_slots64_tkg262144_async_20260526T164116Z_pfx256_stream32k +workdir: + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_stream32k_cte3072_pfx256k_stream32k_pa1024_tkg262144_20260526T164116Z_pfx256_stream32k +log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_stream32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_stream32k_pa1024_slots64_tkg262144_async_20260526T164116Z_pfx256_stream32k_compile.log +key args: + --seq-len 262144 + --max-context-length 262144 + --prefix-buckets 262144 + --context-encoding-bucket-pairs 3072:262144 + --token-generation-buckets 262144 + --pa-num-blocks 1024 + --prefix-cte-attention-chunk-size 32768 +``` + +Validation so far: + +```text +Local config test: + python3 -m pytest contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py \ + -k 'prefix_cte_attention_chunk_size or sparse_context_encoding_bucket_pairs_are_forwarded' + result: 2 passed + +Remote attention test: + NEURON_PLATFORM_TARGET_OVERRIDE=trn2 python -m pytest \ + test/unit/modules/attention/test_attention_base.py \ + -k 'prefix_prefill_chunked_prior or prefix_prefill_sharded_flash_attn or prefix_prefill_unsharded_flash_attn' + result: 6 passed + +Compile status at start: + HLO generation completed for context_encoding_model and token_generation_model. + neuronx-cc priority compilation started with no NCC_ITIN902/NRT_RESOURCE in + the main log at the time this note was updated. +``` + +## Error Log + +### Remote `rg` Missing + +```text +What failed: + Remote repo inspection command using `rg`. + +How it failed: + bash: line 1: rg: command not found + +How we got there: + Host ubuntu@16.50.61.215 did not have ripgrep installed. + +Hypothesis: + The TRN2 instance image lacks the local developer tooling installed on the + Mac workspace. + +Fix: + Switched remote inspection to `find`, `grep`, `sed`, and `python3`. + +Verification: + Remote config inspection completed and printed the 128K/256K neuron_config + limits recorded in this note. +``` + +### Launcher PID Redirection + +```text +What failed: + Initial background validation launcher. + +How it failed: + bash: line 1: ${PID}: ambiguous redirect + +How we got there: + A shell grouping/variable expansion issue while starting the long validation + command and writing the PID file. + +Hypothesis: + The PID variable was expanded in the wrong shell context. + +Fix: + Manually wrote the detected validation PID to: + /home/ubuntu/validation_logs/fp8_256k/tierfix_validation_20260526T152617Z/run.pid + +Verification: + The validation continued and produced the raw summary JSON stored in this + repo. +``` + +### Remote `python` Missing + +```text +What failed: + Remote JSON/config parsing helper invoked as `python`. + +How it failed: + bash: line 1: python: command not found + +How we got there: + The remote instance exposes Python as `python3`, not `python`. + +Hypothesis: + No `python` compatibility symlink on the remote image. + +Fix: + Reran the helper with `python3`. + +Verification: + Parsed artifact config fields successfully. +``` + +### Local Attention Unit Test Missing `torch_xla` + +```text +What failed: + Local focused attention unit test. + +Command: + python3 -m pytest test/unit/modules/attention/test_attention_base.py \ + -k 'prefix_prefill_chunked_prior or prefix_prefill_sharded_flash_attn or prefix_prefill_unsharded_flash_attn' + +How it failed: + ModuleNotFoundError: No module named 'torch_xla' + +How we got there: + The Mac workspace Python environment does not include torch_xla. + +Hypothesis: + Local environment is not the Neuron inference venv. + +Fix: + Synced the changed files to ubuntu@16.51.90.254 and reran in the Neuron venv. + +Verification: + Remote test passed with 6 selected tests after setting + NEURON_PLATFORM_TARGET_OVERRIDE=trn2. +``` + +### Remote Attention Unit Test Platform Override + +```text +What failed: + First remote focused attention unit test on ubuntu@16.51.90.254. + +How it failed: + RuntimeError: Unsupported Platform - r7i.24xlarge + If you want to compile on CPU, please supply a compiler target argument. + +How we got there: + The compile host is a CPU/cross-compile instance. Importing Neuron/NxD modules + without a platform override caused torch_neuronx to infer the host platform + instead of the target Trainium platform. + +Hypothesis: + Neuron unit tests that import NxD need NEURON_PLATFORM_TARGET_OVERRIDE when + running on non-Trainium compile hosts. + +Fix: + Reran with: + NEURON_PLATFORM_TARGET_OVERRIDE=trn2 + +Verification: + 6 selected attention prefix-prefill tests passed. +``` + +### 256K Prefix Runtime Load Failure + +```text +What failed: + 256K pfx256 artifact prefill/runtime load. + +How it failed: + NRT_RESOURCE in nrt_load_util: + Failed to allocate 1.000GB shared scratchpad + Failed to load context_encoding_model/_tp0_bk0/model...neff, err: 4 + +How we got there: + Artifact: + qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_... + Inputs: + seq_len=262144 + pa_num_blocks=1024 + length=261888 + CTE/prefix pair=[3072, 262144] + token_generation_buckets=[262144] + +Hypothesis: + The pfx256 context-encoding NEFF compiled tensor + shared scratchpad footprint + exceeds the usable per-HBM allocation for the logical NeuronCore placement. + +Fix attempted: + Retried with: + NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 + Then probed: + NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 + NEURON_SCRATCHPAD_PAGE_SIZE=512 + +Verification: + Both retries still failed with `NRT_RESOURCE`, so the remaining blocker is a + compiled NEFF footprint issue, not just a runtime placement knob. +``` + +### Python-Level 256K Prefix Chunking Did Not Reduce NEFF Memory + +```text +What failed: + Robust pfx256 mitigation probe using Python-level prefix attention chunking. + +How it failed: + The compile itself completed, but the context-encoding NEFF memory footprint + did not improve: + COMPILE_DONE + context HBM: 24.101GB + total page-aligned scratchpad: 11.000000GB + +How we got there: + Host: + ubuntu@16.51.90.254 + Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_stream32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_stream32k_pa1024_slots64_tkg262144_async_20260526T164116Z_pfx256_stream32k + Workdir: + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_stream32k_cte3072_pfx256k_stream32k_pa1024_tkg262144_20260526T164116Z_pfx256_stream32k + Inputs: + --seq-len 262144 + --max-context-length 262144 + --cte-buckets 3072 + --prefix-buckets 262144 + --context-encoding-bucket-pairs 3072:262144 + --token-generation-buckets 262144 + --pa-num-blocks 1024 + --prefix-cte-attention-chunk-size 32768 + +Hypothesis: + XLA/Neuron static tracing still lowered the Python chunk loop into a graph + with the same large flat prefix attention footprint. This confirms that + chunking must happen inside an NKI kernel or via a newer segmented CTE kernel, + not in regular PyTorch graph code. + +Fix or mitigation applied: + Do not treat the pfx256_stream32k artifact as the production fix. Web/docs + research points to Neuron 2.30 NKI Library `Attention Segmented CTE` and + `KV-Parallel Segmented Prefill` as the next production-grade path because + they process block KV/prefix cache in configurable segments inside the kernel. + +Verification: + Pending. Need either: + 1. Upgrade/overlay a Neuron 2.30 NKI Library containing segmented CTE and + wire prefix CTE to that kernel; or + 2. Write a custom NKI segmented prefix attention kernel if the library + kernel is unavailable in our runtime. +``` + +### Neuron 2.30 Segmented CTE Overlay Inspection + +```text +What failed: + First SSH inspection command after creating the Neuron 2.30 segmented CTE + overlay on ubuntu@16.51.90.254. + +How it failed: + The command exited 1 because it used `set -o pipefail` with: + find "$NKILIB_DIR" -maxdepth 5 -type f | grep -E "attention.*(seg|prefill|cte).*\.py$|kv.*prefill.*\.py$" + The `find` maxdepth/pattern missed the files under + src/nkilib_src/nkilib/core/attention, so `grep` returned no matches. + +How we got there: + Host: + ubuntu@16.51.90.254 + Overlay venv: + /home/ubuntu/venvs/neuron_230_segmented_cte + Source checkout: + /home/ubuntu/nki-library-2.30 + Branch: + 2.30_release + Installed Python packages: + nki==0.4.0+25940409122.gd30719f9 + neuronx-cc==2.25.3371.0+f524f7f8 + +Root cause: + Inspection-command bug, not an overlay setup failure. + +Fix: + Reran inspection with the overlay activated and direct Python imports. + +Verification: + Confirmed: + IMPORT_OK nkilib.core.attention.attention_segmented_cte + IMPORT_OK nkilib.core.attention.kv_parallel_segmented_prefill + attention_segmented_cte signature accepts block KV cache, block_tables, + prior_tokens, block_size, and prior_seg_size. +``` + +### Local Segmented CTE Search/Syntax Checks + +```text +What failed: + Local source search for k-cache transposition references. + +How it failed: + Command exited 2: + rg: src/modeling_qwen35.py: No such file or directory (os error 2) + +How we got there: + I searched `src/modeling_qwen35.py`, but this repository stores the Qwen model + file at: + contrib/models/Qwen3.6-27B/src/modeling_qwen35.py + +Root cause: + Wrong local path in the search command. + +Fix: + Reran with `rg --files` and then searched existing paths under `src` and + `contrib/models/Qwen3.6-27B`. + +Verification: + Found the relevant `k_cache_transposed` references and confirmed block KV + cache disables transposed K cache. + +What failed: + First local Python syntax command: + python -m py_compile ... + +How it failed: + zsh:1: command not found: python + +How we got there: + The local Mac shell exposes `python3` but not `python`. + +Fix: + Reran: + python3 -m py_compile ... + +Verification: + Syntax compile passed for: + src/neuronx_distributed_inference/modules/attention/attention_base.py + src/neuronx_distributed_inference/models/config.py + src/neuronx_distributed_inference/models/model_wrapper.py + contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py + contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py + Focused config tests passed: + 23 passed, 3 subtests passed +``` + +### Segmented CTE Overlay Wiring and Compile Launch + +```text +What failed: + Remote diff preview after syncing segmented CTE files to ubuntu@16.51.90.254. + +How it failed: + The command exited 141 because it ran `git diff ... | head -240` under + `set -o pipefail`; `head` closed the pipe and `git diff` received SIGPIPE. + +How we got there: + Files had already been installed into: + /home/ubuntu/inferentia-gdn-fused-noclamp-4340808 + The failing command was only a preview step after install. + +Root cause: + Shell preview mistake, not a sync failure. + +Fix: + Reran remote status/syntax checks without piping through `head`. + +Verification: + Remote `py_compile` passed for the synced files. + +What failed: + First remote focused tests in the Neuron 2.30 overlay venv. + +How it failed: + /home/ubuntu/venvs/neuron_230_segmented_cte/bin/python: + No module named pytest + +How we got there: + The overlay venv was intentionally minimal and only installed newer + nki/neuronx-cc. + +Root cause: + Missing test dependency in the overlay venv. + +Fix: + Ran focused unit tests in the base Neuron venv instead: + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 + +Verification: + Remote tests passed: + 42 passed, 3 subtests passed + +What failed: + First overlay import check for the synced attention module. + +How it failed: + ModuleNotFoundError: No module named 'torch' + +How we got there: + The overlay venv had nki 0.4 / neuronx-cc 2.25 but did not inherit the base + Neuron venv's PyTorch/NxD packages. + +Root cause: + `python -m venv --system-site-packages` does not inherit packages installed + inside another venv. + +Fix: + Added a `.pth` file in the overlay site-packages pointing to: + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/lib/python3.12/site-packages + +Verification: + Overlay then imported PyTorch, but exposed the next PATH issue below. + +What failed: + Overlay import after adding base site-packages. + +How it failed: + FileNotFoundError: + [Errno 2] No such file or directory: 'libneuronpjrt-path' + +How we got there: + torch_xla imported from the base venv site-packages, but the overlay + activation did not include the base venv `bin` directory on PATH. + +Root cause: + Base Neuron helper executables were not visible when using overlay Python. + +Fix: + Updated the compile launcher to export: + PATH="${NEURON_VENV}/bin:${BASE_NEURON_VENV}/bin:${PATH}" + +Verification: + Overlay import check passed: + TORCH 2.9.1+cu128 + NKI 0.4.0+25940409122.gd30719f9 + NEURONXCC 2.25.3371.0+f524f7f8 + SEGMENTED_KERNEL True + +What failed: + Potential segmented CTE compile-sample invalidity for + [active=3072, prefix=262144] with `pa_num_blocks=1024`. + +How it would fail: + The generated sample `slot_mapping` would write active KV at positions + 262144..265215, past the 256K cache capacity, before segmented CTE reads + active KV from block cache. + +How we got there: + Existing prefix CTE sampled `computed_context_lens=prefix_bucket`; this was + fine for flat `attention_cte` because active KV was passed separately, but + segmented CTE reads active KV from the updated block cache. + +Root cause: + The sample value for `computed_context_lens` was not constrained to + `max_context_length - active_bucket` for segmented CTE. + +Fix: + In `model_wrapper.py`, for context-encoding segmented CTE samples, keep the + bucket shape at 262144 but set the sample prior to: + min(prefix_bucket, max_context_length - n_active_tokens) + For the pfx256/cte3072 trace this is: + computed_context_lens=259072 + +Verification: + The segmented CTE compile got through both context HLOs and the + token-generation HLO without sample OOB or import errors. + +What was cleaned up: + Removed the two known-bad pfx256 probes before launching the new compile: + qwen36_27b_256k_..._prod_pfx256k_..._20260526T132620Z_tierfix_pfx256k + qwen36_27b_256k_..._prod_pfx256k_stream32k_..._20260526T164116Z_pfx256_stream32k + plus their `_nxd_model_workdir_*` directories. + +Why: + They were already proven not production-ready: + pfx256 tierfix hit runtime load NRT_RESOURCE. + pfx256_stream32k compiled but kept the same large HBM/scratch footprint. + +Verification: + Free disk on /mnt/trainium_artifacts increased from 35GB to 93GB. + +Current compile: + Host: + ubuntu@16.51.90.254 + PID: + 65885 + Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_segcte32k_pa1024_slots64_tkg262144_async_20260526T174252Z + Log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_segcte32k_pa1024_slots64_tkg262144_async_20260526T174252Z_compile.log + Key flags: + --context-encoding-bucket-pairs 3072:262144 + --prefix-cte-attention-backend segmented_cte + --prefix-cte-attention-segment-size 32768 + --pa-num-blocks 1024 + Status: + HLO generation completed for both context_encoding_model traces and the + token_generation_model trace. neuronx-cc compilation is running. +``` + +### Segmented CTE Compile Completed but pfx256 Footprint Still Has Flat Gather + +```text +What failed: + The pfx256 segmented CTE compile completed, but it did not eliminate the + high-footprint 256K-prefix context NEFF. + +How it failed: + Compile status: + COMPILE_DONE + Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_segcte32k_pa1024_slots64_tkg262144_async_20260526T174252Z + Context bucket summaries: + context_encoding_model/_tp0_bk0: + Total estimated HBM usage: 13.65GB + Total page-aligned scratchpad: 1.500000GB + context_encoding_model/_tp0_bk1: + Total estimated HBM usage: 24.10GB + Total page-aligned scratchpad: 11.000000GB + Token-generation summary: + token_generation_model/_tp0_bk0: + Total estimated HBM usage: 12.42GB + Total page-aligned scratchpad: 0.500000GB + +How we got there: + Host: + ubuntu@16.51.90.254 + Key compile args: + --context-encoding-bucket-pairs 3072:262144 + --prefix-cte-attention-backend segmented_cte + --prefix-cte-attention-segment-size 32768 + --pa-num-blocks 1024 + Overlay: + nki==0.4.0+25940409122.gd30719f9 + neuronx-cc==2.25.3371.0+f524f7f8 + +Evidence: + `neuron_config.json` in the artifact records: + prefix_cte_attention_backend=segmented_cte + prefix_cte_attention_segment_size=32768 + But `context_encoding_model/_tp0_bk1/log-neuron-cc.txt` still contains + large indirect loads from: + get_kv_by_layer_id/_get_block_cache_and_reshape_bhsd/aten.index_select + with cache-shaped tensors such as: + bfloat16 (1025, 65536) + That means the long-prefix trace still materialized the flattened block-cache + gather before/alongside the segmented CTE path. + +Root cause / hypothesis: + The integration still calls `kv_mgr.get_kv_by_layer_id(**kwargs)` before the + segmented CTE pre-update path. For prefix caching, that method gathers block + KV into flat BHSD prior tensors through `_get_block_cache_and_reshape_bhsd`. + Those flattened gathers remain in the HLO and dominate the 256K-prefix NEFF, + so using `attention_segmented_cte` later is not enough. + +Fix or next mitigation: + The robust fix is to add a true raw-block-cache prefix path for segmented + CTE: + 1. In context encoding when `prefix_cte_attention_backend=segmented_cte`, + do not call `kv_mgr.get_kv_by_layer_id` for prefix prior. + 2. Fetch raw per-layer block KV via `kv_mgr._fetch_cache(...)` or a clean + public wrapper. + 3. Pre-update active K/V into raw block KV. + 4. Call `attention_segmented_cte` with raw block KV, `active_block_table`, + and `computed_context_lens`. + 5. Return the updated raw block KV and skip the old flat prior path. + +Verification: + Not fixed yet. The completed artifact should not be treated as the pfx256 + production fix. It can be transferred only for confirmation, but based on the + compile footprint it is expected to have the same runtime-load risk as the + previous pfx256 artifact. +``` + +### Raw Block Segmented CTE Fix Applied + +```text +What failed: + Web/docs review found that the previous segmented CTE integration did not + match the official block-KV contract. The Qwen hybrid prefill path still + called `get_kv_by_layer_id`, which flattened prefix blocks before attention. + +How it failed: + The pfx256 segmented CTE artifact compiled, but `log-neuron-cc.txt` still + showed `_get_block_cache_and_reshape_bhsd/aten.index_select` in the pfx256 + context bucket and HBM reached 24.10GB per core with 11GB page-aligned + scratchpad. + +How we got there: + Branch: + codex/full-fp8-qwen36 + Backend: + prefix_cte_attention_backend=segmented_cte + Bucket: + context_encoding_bucket_pairs=3072:262144 + The base attention code had a segmented CTE call, but Qwen's hybrid path + pre-fetched `past_key_values` through `QwenHybridBlockKVCacheManager.get_cache` + and then used `perform_qwen_chunked_prefill` over flat selected prefix KV. + +Root cause / hypothesis: + Official Neuron docs say NxDI prefix caching uses block KV, but the default + prefix-caching flow gathers block KV into a flat layout before attention. + Neuron 2.30 adds `Attention Segmented CTE` and `KV-Parallel Segmented + Prefill` kernels specifically for block-based KV cache. Therefore the fix is + not another bucket shape; it is avoiding the flat gather entirely for the + segmented CTE path. + +Fix applied: + - Added `BlockKVCacheManager.get_raw_kv_by_layer_id()` to return block-layout + KV without `_get_block_cache_and_reshape_bhsd`. + - Changed `QwenHybridBlockKVCacheManager.get_cache()` so segmented context + prefix buckets return raw block KV for full-attention layers. + - Changed Qwen chunked prefill so `prefix_cte_attention_backend=segmented_cte` + pre-updates active K/V into raw block cache and calls + `attention_segmented_cte` with `active_block_table` and + `computed_context_lens`. + - Changed Qwen cache update to accept already-updated raw block KV and skip a + second block-cache update. + - Fixed the base attention segmented path so it no longer requires flat + `past_key_value` before dispatching to segmented CTE. + +Verification: + Pending local unit tests and a new pfx256 compile. Expected compile evidence + for success: + - no `_get_block_cache_and_reshape_bhsd/aten.index_select` in pfx256 + context HLO/logs; + - pfx256 context HBM below the per-core 24GB limit with materially smaller + scratchpad than the failed 24.10GB / 11GB artifact. +``` + +### Raw Block Segmented CTE Compile Blocked by head_dim=256 + +```text +What failed: + Fresh pfx256 raw-block segmented CTE compile failed during HLO generation. + +How it failed: + Host: + ubuntu@16.51.90.254 + PID: + 69740 + Log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_rawsegcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_rawsegcte32k_pa1024_slots64_tkg262144_async_20260526T183314Z_compile.log + Artifact target: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_rawsegcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_rawsegcte32k_pa1024_slots64_tkg262144_async_20260526T183314Z + Exact error: + AssertionError: error: failed to compile NKI kernel: + Collected 1 different diagnostics: + - [x2] error: assertion failed: [INTERNAL_ERROR] [NCC_INKI016] + Kernel validation exception: head_dim must be <= 128 (got 256). + Larger head_dim not yet supported. - Please check the validation + message and adjust kernel inputs accordingly + +How we got there: + Command launched `tmp_compile_qwen256k_fp8_full_prod_prefix_tier_hostlogits.sh` + with: + TIER_NAME=pfx256k_rawsegcte32k + PREFIX_BUCKETS_STR=262144 + PAIR_ARGS_STR=3072:262144 + CTE_BUCKETS_STR=3072 + TKG_BUCKETS_STR=262144 + PA_NUM_BLOCKS=1024 + PREFIX_CTE_ATTENTION_BACKEND=segmented_cte + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=32768 + NEURON_VENV=/home/ubuntu/venvs/neuron_230_segmented_cte + NKI_LIBRARY_SRC=/home/ubuntu/nki-library-2.30/src/nkilib_src + +Root cause / hypothesis: + Proven root cause for this failure: Neuron 2.30 `attention_segmented_cte` + hard-validates `head_dim <= 128`, while Qwen3.6 27B has attention + `head_dim=256`. + Evidence: + /home/ubuntu/nki-library-2.30/src/nkilib_src/nkilib/core/attention/attention_segmented_cte.py + contains: + kernel_assert(head_dim <= 128, ...) + The bundled NKI model test config for `qwen3_235b` uses `d_head=128`, so + the official segmented CTE Qwen coverage does not cover this 27B head_dim. + +Fix or mitigation: + The raw-block segmented CTE integration is correct structurally, but the + official kernel cannot support this model without a head_dim=256 variant. + Viable next options are: + 1. Build a Qwen-specific head_dim=256 segmented CTE kernel that accumulates + QK over two 128-wide D tiles before softmax, then computes PV over the + full 256-wide V. This is the robust fix if we require a true 256K prefix + bucket. + 2. Use the existing production-safe tier strategy with <=128K prefix buckets + and route 256K-context requests through a lower prefix bucket, accepting + extra refill work. + 3. Open/escalate an AWS Neuron issue requesting head_dim=256 support in + `attention_segmented_cte`. + +Verification: + Compile did not complete. Do not retry this exact raw-block segmented CTE + compile until the head_dim=256 kernel limitation is addressed. +``` + +### Qwen head_dim=256 Segmented CTE Kernel Bring-Up + +```text +What failed: + The first Qwen-specific segmented CTE prototype was a direct copy of the + Neuron 2.30 segmented CTE kernel with only the top-level head_dim validator + relaxed. + +How it failed: + Host: + ubuntu@16.51.90.254 + Remote repo: + /home/ubuntu/inferentia-gdn-fused-noclamp-4340808 + Probe: + Offline NKI compile_to_bir with q=(2,256,256), + k/v_cache=(8,1,256,256), block_size=256, prior_seg_size=512, + tp_q=True, tp_out=False, target=trn2. + Exact errors hit and fixed: + 1. dma_copy dst partition dimension 256 exceeds maximum 128 + at attention_segmented_cte_256.py load_kv_cache. + Cause: copied kernel still loaded K as one (head_dim, K_TILE) tile. + Fix: split K into low/high (128, K_TILE) tiles. + 2. unsupported expression on list comprehensions creating K tile pairs. + Cause: NKI specialization rejected Python list comprehensions. + Fix: build the list with explicit for/append meta-programming. + 3. failed to resolve name 'x::0.shape' from k_sbuf[0].shape. + Cause: split K tile entries are Python pairs, not NKI tensors. + Fix: read K_TILE_SIZE from k_sbuf[0][0].shape for head_dim=256. + 4. dma_copy dst partition dimension 256 exceeds maximum 128 on the + temporary non-transposed K block load. + Cause: temp was shaped (block_size, 128), and block_size=256 became + the partition dimension. + Fix: load each K block in 128-token by 128-dim chunks. + 5. dma_copy src/dst element mismatch src=32768 dst=16384. + Cause: source access pattern still selected full D=256 for a 128-wide + destination. + Fix: use HBM source pattern [[head_dim, 128], [1, 128]] for each + 128-token by 128-dim K chunk. + 6. dma_transpose Q shape mismatch: source D=256, destination D=128. + Cause: split Q source pattern used full D as the transposed extent. + Fix: keep token stride at ac.d but set the transposed D count to 128: + [[ac.d, num_f], [1, 1], [1, 1], [1, 128]]. + 7. reduce_one_batch batch_idx typed as object. + Cause: copied call used an older helper signature and passed output + tensors where Neuron 2.30 expects batch_idx/grp_start/grp_end. + Fix: call reduce_one_batch with batch_idx=0, grp_start=0, + grp_end=n_grps, d=head_dim, num_grps=n_grps, sb_p=sb_p. + +Fix implemented: + Added a Qwen-specific NKI package: + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/ + + The kernel keeps V and output on legal free dimensions and only splits the + QK contraction: + logits = Q_lo @ K_lo + Q_hi @ K_hi + + This follows the documented NKI matmul rule that contraction dimensions + larger than 128 must be accumulated through multiple nc_matmul writes to the + same PSUM tile. + +Verification: + Local syntax: + python3 -m py_compile attention_base.py attention_segmented_cte_256.py + fused_segmented_attention_256.py + + Remote syntax: + PYCOMPILE_OK on ubuntu@16.51.90.254 under + /home/ubuntu/venvs/neuron_230_segmented_cte with Neuron 2.30 NKI source. + + Remote NKI BIR probe: + BIR_OK for q=(2,256,256), k/v=(8,1,256,256), prior_seg_size=512. + + Remote production-shape NKI BIR probe: + BIR_Q3072_OK for q=(2,3072,256), k/v=(1024,1,256,256), + block_size=256, prior_seg_size=32768, pa_num_blocks=1024. + Reported scratch: + sb_scratch_sizes=[402724] + psum_scratch_sizes=[15360] + +Remaining work: + This validates NKI front-end/BIR legality for the target bucket geometry. + Full model compile and runtime numerical validation are still required before + calling the pfx256 artifact production-ready. +``` + +### pfx256 segcte256d32k Full Compile Failed on SBUF Scratch Allocation + +```text +What failed: + Full Qwen3.6 27B FP8 pfx256k compile with the Qwen head_dim=256 segmented + CTE kernel failed during neuronx-cc compilation of context_encoding_model. + +How it failed: + Host: + ubuntu@16.51.90.254 + PID: + 76788 + Log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte256d32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T191006Z_compile.log + Workdir: + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte256d32k_cte3072_pfx256k_pa1024_tkg262144_20260526T191006Z + Failing buckets: + context_encoding_model/_tp0_bk0 + context_encoding_model/_tp0_bk1 + Exact compiler error: + [INTERNAL_ERROR] [NCC_INLA001] Unhandled exception with message: + Allocated memory out of bound + {scratch_sb_for_inst__I-361}@SB<0,0>(128x402724) + #Internal DebugInfo: + Exit: + neuronx-cc returned non-zero exit status 70. + +How we got there: + Compile command used: + TIER_NAME=pfx256k_segcte256d32k + PREFIX_BUCKETS_STR=262144 + PAIR_ARGS_STR=3072:262144 + CTE_BUCKETS_STR=3072 + TKG_BUCKETS_STR=262144 + PA_NUM_BLOCKS=1024 + PREFIX_CTE_ATTENTION_BACKEND=segmented_cte + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=32768 + NEURON_VENV=/home/ubuntu/venvs/neuron_230_segmented_cte + NKI_LIBRARY_SRC=/home/ubuntu/nki-library-2.30/src/nkilib_src + +Root cause / hypothesis: + Proven: + The custom head_dim=256 NKI kernel is BIR-legal for the target shape, but + the backend rejects the generated context CTE kernel because its live SBUF + scratch allocation is too large: 128x402724. + Best current hypothesis: + The first head_dim=256 kernel keeps too many per-segment K/V and per-Q-group + attention buffers live in SBUF. Splitting D into two 128-wide K/Q tiles fixed + the head_dim validator, but doubled K-side live storage and still inherited + the upstream segmented-CTE allocation style that materializes too much segment + state at once. + +Additional probes: + Offline BIR probes after failure showed scratch is still high even with lower + segment sizes: + q=3072, segment=8192 -> sb_scratch_sizes=[206116] + q=3072, segment=4096 -> sb_scratch_sizes=[116064] + q=3072, segment=2048 -> sb_scratch_sizes=[107840] + q=3072, segment=512 -> sb_scratch_sizes=[107808] + q=512, segment=512 -> sb_scratch_sizes=[54208] + These are BIR-legal but still likely too high for backend SBUF placement. + +Fix or mitigation: + Do not retry the same pfx256 segcte256d32k compile. + The next robust kernel fix is to reduce live SBUF, not only segment length: + - stream K/V tiles through the QK and PV loops instead of holding an entire + prior segment in SBUF; + - allocate MM1/MM2 scratch per Q group or a small group window instead of + block_dim=[num_grps] for all 3072 active tokens; + - keep only the running softmax stats/output persistent across segments. + This is a second-stage kernel rewrite. The current kernel fixed the head_dim + problem but is not production-ready for pfx256 because of SBUF pressure. + +Verification: + The compile failed. No artifact was produced. The heartbeat monitor was + stopped after recording this failure. +``` + +### pfx256 Kernel Rewrite: Active CTE Streaming + Q-Pack Cap + +```text +What failed: + The first qwen_segcte256 kernel fixed head_dim=256 front-end legality, but + full pfx256 compile failed because the context CTE kernel needed an illegal + live SBUF allocation: + {scratch_sb_for_inst__I-361}@SB<0,0>(128x402724) + +How we got there: + Host: + ubuntu@16.51.90.254 + Branch: + codex/full-fp8-qwen36 + Files: + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py + Target compile shape: + CTE bucket 3072, prefix bucket 262144, PA blocks 1024, head_dim 256, + FP8 full model, hybrid APC, segmented CTE backend. + +Errors encountered while fixing: + 1. NKI front-end rejected list comprehensions inside the kernel. + Command: + Remote inline compile_kernel_to_nir probe with + q=(2,256,256), k/v=(8,1,256,256), prior_seg_size=512. + Exact pattern: + unsupported expression on list comprehensions for mm1_masked_row, + exp_sb_row, mm1_copy_row, mm1_affine_select_output_row, exp_tp_row, + and _repeat_ref. + Root cause: + NKI kernels do not accept those Python list-comprehension expressions. + Fix: + Replaced each list comprehension with an explicit for-loop and append. + Verification: + The same BIR probe passed: + BIR_SMALL_OK + sb_scratch_sizes=[30592, 30592] + psum_scratch_sizes=[9216, 9216] + + 2. First active-streaming BIR hit an out-of-bound tensor access. + Command: + Remote inline compile_kernel_to_nir probe with + q=(2,3072,256), k/v=(1024,1,256,256), block_tables=(1,1024), + prior_seg_size=512. + Exact error: + assertion failed: Out-of-bound access for tensor `unnamed` on dimension + 1: index 1 exceed dimension size of 1. + Called from fused_segmented_attention_256.py in _exp_impl(). + Root cause / hypothesis: + The active stream allocated exp/running partial-sum columns for one + 512-token chunk, but ac.seqlen_k_active_updated still described the full + 3072-token active range, so _exp_impl tried to address chunk index 1 in + a one-column buffer. + Fix: + Rebuild ac/atp per active stream chunk with + seqlen_k_active_updated=next_section_offset_active, while preserving the + global K position through SectionParams.kv_section_idx. + Verification: + The q=3072, segment=512 BIR probe advanced past _exp_impl and compiled. + + 3. A docs/inspection helper import failed while probing NKI internals. + Command: + Import nki.framework.torch_xla in + /home/ubuntu/venvs/neuron_230_segmented_cte. + Exact error: + FileNotFoundError: [Errno 2] No such file or directory: + 'libneuronpjrt-path' + Root cause / hypothesis: + Importing torch_xla through the overlay venv initialized torch_neuronx + without the base Neuron runtime path. + Fix: + Avoid that inspection path for BIR probes; import NkiTensor from + nki.language.tensor and shared_hbm from nki.language.buffers. + Verification: + BIR probes compiled with the direct NKI imports. + +Fix implemented: + The robust simple fix is not a larger prefix segment. It is a smaller live + working set: + - alias per-Q-group temporary SBUF buffers to one reusable group window; + - stream active CTE K/V through the same bounded K/V SBUF window used by + prior-prefix segments; + - keep only running max/sum/output persistent across active/prior segments; + - cap Q group packing to 4 groups for head_dim=256. + + The compile must use: + PREFIX_CTE_ATTENTION_BACKEND=segmented_cte + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 + CTE_BUCKETS_STR=3072 + PAIR_ARGS_STR=3072:262144 + PREFIX_BUCKETS_STR=262144 + PA_NUM_BLOCKS=1024 + +Verification: + Local syntax: + python3 -m py_compile \ + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py \ + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py + + Remote syntax: + REMOTE_PYCOMPILE_OK on ubuntu@16.51.90.254 under + /home/ubuntu/venvs/neuron_230_segmented_cte. + + Remote BIR scratch results after the rewrite: + q=3072, prior_seg_size=512: + BIR_STREAMACTIVE_Q3072_SEG512_QPACK4_OK + sb_scratch_sizes=[31360, 31360] + psum_scratch_sizes=[9216, 9216] + + q=3072, prior_seg_size=1024: + sb_scratch_sizes=[35488, 35488] + + q=3072, prior_seg_size=2048: + sb_scratch_sizes=[43680, 43680] + + q=3072, prior_seg_size=4096: + sb_scratch_sizes=[60096, 60096] + +Conclusion: + For Trn2 head_dim=256 with the current NKI layout, prior_seg_size=512 is the + only verified segment size under the documented SBUF free-dimension limit + of 32767. The previous 32k segment path and the 1024+ segment probes remain + unsafe. Full model compile and runtime validation are still required before + marking the pfx256 artifact production-ready. +``` + +### pfx256 segcte512stream Full Compile Launched + +```text +What changed: + Launched the full model compile using the verified active-streaming kernel + shape instead of the failed 32k-segment kernel. + +Host: + ubuntu@16.51.90.254 + +Compile PID: + 84525 + +Artifact target: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z + +Workdir: + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_cte3072_pfx256k_pa1024_tkg262144_20260526T195604Z + +Log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z_compile.log + +PID file: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z_compile.pid + +Inputs / flags: + TIER_NAME=pfx256k_segcte512stream_qpack4 + PREFIX_BUCKETS_STR=262144 + PAIR_ARGS_STR=3072:262144 + CTE_BUCKETS_STR=3072 + TKG_BUCKETS_STR=262144 + SEQ_LEN=262144 + MAX_CONTEXT_LENGTH=262144 + PA_NUM_BLOCKS=1024 + OMIT_ZERO_PREFIX_PAIR=1 + PREFIX_CTE_ATTENTION_BACKEND=segmented_cte + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 + NEURON_VENV=/home/ubuntu/venvs/neuron_230_segmented_cte + BASE_NEURON_VENV=/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 + NKI_LIBRARY_SRC=/home/ubuntu/nki-library-2.30/src/nkilib_src + +Initial status: + HLO generation completed for context_encoding_model and token_generation_model. + neuronx-cc started compiling token_generation_model. + +Profiling note: + NKI profiling cannot replace compile preflight because it requires an + executed NEFF/NTFF pair. The pre-compile gate for this issue is BIR scratch + sizing. After COMPILE_DONE, runtime validation should be run with: + NEURON_RT_INSPECT_ENABLE=1 + NEURON_RT_INSPECT_DEVICE_PROFILE=1 + NEURON_RT_INSPECT_OUTPUT_DIR=/inspect + Then identify the qwen_segcte256 NEFF, capture with profile-nth-exec=2, and + save neuron-profile/neuron-explorer summary-json beside the cold/warm + prefill, TTFT, TPOT, and memory results. +``` + +### Runtime Validation Attempt with Device Profiling Failed + +```text +What failed: + First pfx256 segcte512stream runtime validation on the TRN2 validation host + failed during context_encoding_model NEFF load. + +How it failed: + Host: + ubuntu@16.50.61.215 + Output root: + /home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_20260526T202502Z + Phase: + pfx256k_prefill + Status: + pfx256k_prefill.status = 1 + Exact runtime errors: + Failed to allocate 1.000GB (alignment: 4.000MB, usage: shared scratchpad) + on ND 0:NC 2 + device profiling is enabled, which requires 2.348GB of HBM per NC. Either + reduce the the model size to enable profiling, or disable profiling. + Failed to allocate virtual scratchpad! + Failed to load NN: + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_cte3072_pfx256k_pa1024_tkg262144_20260526T195604Z/context_encoding_model/_tp0_bk0/model.MODULE_dc595ea41a524c32e935+86f42f0e.neff + Failure: NRT_RESOURCE in nrt_load_util + RuntimeError: Engine core initialization failed. + +How we got there: + Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z + Validation flags: + seq_len=262144 + pa_num_blocks=1024 + cte_buckets=3072 + token_generation_buckets=262144 + context_encoding_bucket_pairs=3072:262144 + max_tokens=1 + length=261888 + Profiling environment enabled: + NEURON_RT_INSPECT_ENABLE=1 + NEURON_RT_INSPECT_DEVICE_PROFILE=1 + NEURON_RT_INSPECT_OUTPUT_DIR=/home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_20260526T202502Z/inspect + +Memory evidence: + Runtime table for the failing HBM group showed: + Model tensors: 12.052GB + Shared scratchpad: 6.000GB + Profiler buffers: 4.758GB total, 2.379GB per NC + Total shown on HBM group: 22.814GB + +Root cause / hypothesis: + Proven: + Device profiling itself adds enough HBM pressure to prevent the 256K context + NEFF from loading. The error explicitly names profiler buffers and says to + disable profiling or reduce model size. + Not proven: + This does not prove the artifact fails without profiling. The profiler + overhead is the immediate blocker for this attempt. + +Fix / mitigation: + Rerun the same pfx256 validation without NEURON_RT_INSPECT_DEVICE_PROFILE. + Keep memory sampling enabled via neuron_memory_sampler. If runtime validation + passes, profile a smaller context/shorter profile variant or capture profiling + from a reduced-shape NEFF because full 256K device profiling does not fit. + +Verification: + Pending rerun without device profiling. +``` + +### Runtime Validation Without Device Profiling Failed with DGE OOB + +```text +What failed: + The no-profile pfx256 segcte512stream runtime validation failed during the + 261888-token context prefill execution after the artifact loaded. + +How it failed: + Host: + ubuntu@16.50.61.215 + Output root: + /home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_noprofile_20260526T202721Z + Wrapper PID: + 21303 + Context sweep PID: + 21309 + Phase: + pfx256k_prefill + Log: + /home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_noprofile_20260526T202721Z/pfx256k_prefill.log + Exact runtime errors: + TDRV:exec_process_custom_notification nd0:nc2:h_model.id1006: + Received notification generated at runtime: failed to run scatter/gather + (indirect memory copy via scalar DGE), due to out-of-bound access. + model name = + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_cte3072_pfx256k_pa1024_tkg262144_20260526T195604Z/context_encoding_model/_tp0_bk0/model.MODULE_dc595ea41a524c32e935+86f42f0e.neff. + + NMGR:kmgr_exec_worker_do_work Async request 88 failed for model + .../context_encoding_model/_tp0_bk0/model.MODULE_dc595ea41a524c32e935+86f42f0e.neff + on vnc 1 with status 1006 + + NMGR:kmgr_async_exec_default_exec_status_callback Exec id 88 for model + 10006 on worker 1 failed with fatal status 1006... aborting. + + /opt/workspace/KaenaRuntime/kmgr/kmgr_async_exec.cc:34: + void kmgr_async_exec_default_exec_status_callback(...): + Assertion `0' failed. + + ERROR Engine core proc EngineCore_DP0 died unexpectedly, shutting down client. + +How we got there: + Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z + Runtime flags: + seq_len=262144 + max_model_len=262144 + pa_num_blocks=1024 + block_size=256 + gdn_checkpoint_interval=256 + max_gdn_checkpoint_slots=64 + cte_buckets=3072 + token_generation_buckets=262144 + context_encoding_bucket_pairs=3072:262144 + lengths=261888 + max_tokens=1 + suffix_tokens=16 + require_real_tokens=true + Runtime environment: + NEURON_RT_INSPECT_ENABLE=0 + NEURON_RT_INSPECT_DEVICE_PROFILE unset + Kernel/compile path: + PREFIX_CTE_ATTENTION_BACKEND=segmented_cte + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 + TIER_NAME=pfx256k_segcte512stream_qpack4 + +Memory evidence: + The artifact loaded before execution. The memory sampler showed the runtime + had dropped to present-only bookkeeping after the fatal DGE error, not a + NRT_RESOURCE allocation failure: + latest host RSS: about 1.08 GiB for qwen36_hybrid_apc_context_sweep.py + neuron present: about 6.46GB total + latest total bytes: 0 + This separates this failure from the earlier device-profiling HBM failure. + +Root cause / hypothesis: + Proven: + The 256K pfx artifact compiles and loads without device profiling, but the + context_encoding_model NEFF issues an out-of-bound scalar DGE access during + execution. + Best current hypothesis: + The qwen_segcte256 segmented CTE kernel has a runtime address calculation + bug for the actual long-prefix path. The likely fault is in the mapping of + block-table, prior segment, active segment, or kv_section_idx offsets for + the 261888-token request. BIR scratch sizing and compile legality did not + catch it because the address goes out of range only with real runtime block + tables and long-prefix execution. + +Fix / mitigation applied: + Stopped the failed validation wrapper and context sweep on ubuntu@16.50.61.215 + to free Neuron resources: + kill -TERM 21309 / children through wrapper PID 21303, then kill stale + sampler PID 21308. + +Next mitigation: + Do not retry the same pfx256 segcte512stream artifact as production. Build a + targeted runtime/addressing debug path for qwen_segcte256: + 1. Reproduce with a smaller debug prefix artifact or a reduced long-prefix + request that still uses segmented_cte address math. + 2. Instrument or assert the block table index, prior segment start, active + stream start, kv_section_idx, and max addressed block before DGE loads. + 3. Patch the segmented CTE offset mapping, then rerun BIR preflight and a + no-profile runtime prefill before enabling any profiling. + +Verification: + Validation did not complete. No prefill, TTFT, TPOT, or chat metrics were + produced for this artifact. +``` + +### Null-Block PA Count Mismatch Hypothesis for DGE OOB + +```text +Additional evidence: + The failing pfx256 artifact was compiled with: + pa_num_blocks=1024 + block_size=256 + max_context_length=262144 + Runtime vLLM logs showed: + Adding 1 to num_gpu_blocks_override (1024 -> 1025) to account for null + block allocation + User provided pa_num_blocks (1024) matching original + --num-gpu-blocks-override intent. Incrementing pa_num_blocks to 1025 to + match the increment for a null block in vllm. + +Why this matters: + For vLLM, the user-intended usable block count for 256K at block size 256 is + 1024. vLLM adds one reserved null block, so the physical block-KV cache needs + 1025 blocks. The current artifact was compiled as pa1024, so a block-table + value of 1024 can be legal to vLLM but out of bounds for the compiled NEFF's + raw block-KV cache. That matches the observed scalar-DGE OOB in + context_encoding_model. + +Root cause / hypothesis update: + Best current hypothesis is now a PA physical-block sizing mismatch, not + scratch/HBM pressure. The qwen_segcte256 kernel may still need address tests, + but the first robust/simple fix to try is compiling the artifact with 1025 + physical PA blocks while running vLLM with 1024 usable blocks. + +Fix applied to validation scripts: + Updated validation_scripts/qwen36_hybrid_apc_context_sweep.py and + validation_scripts/qwen36_offline_decode_bench.py so artifact pa_num_blocks + is treated as physical block count. When the artifact uses block KV or prefix + caching and has more blocks than the minimum usable request, validation passes + artifact_pa_num_blocks - 1 as vLLM's num_gpu_blocks_override. + +Next mitigation: + Compile a replacement pfx256 segcte512stream artifact with: + PA_NUM_BLOCKS=1025 + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 + CTE_BUCKETS_STR=3072 + PAIR_ARGS_STR=3072:262144 + PREFIX_BUCKETS_STR=262144 + Then validate it with user-usable pa override 1024 so vLLM adds the null + block back to 1025. +``` + +### PA1025 Relaunch Setup Errors and Correction + +```text +What failed: + First corrected relaunch attempt on ubuntu@16.50.61.215 used: + TIER_NAME=pfx256k_segcte512stream_qpack4_pafix + PA_NUM_BLOCKS=1025 + PREFIX_CTE_ATTENTION_BACKEND=segmented_cte + PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 + PAIR_ARGS_STR=3072:262144 + +How it failed: + The remote helper script was stale and hardcoded: + --pa-num-blocks 1024 + The resulting process PID 28426 was running a pa1024 compile even though the + environment requested PA_NUM_BLOCKS=1025. The log showed: + CONTEXT_TRACE_SHAPE ... "pa_num_blocks": 1024, "pa_min_blocks": 1024, + "pa_headroom_blocks": 0 + +How we got there: + The local helper had already been updated to include _pa${PA_NUM_BLOCKS} in + the artifact name and to pass --pa-num-blocks "${PA_NUM_BLOCKS}", but that + helper had not been synced to ubuntu@16.50.61.215. + +Fix / mitigation applied: + Stopped PID 28426 before useful compilation work continued, synced the local + helper to: + /home/ubuntu/inferentia-gdn-fused-noclamp-4340808/tmp_compile_qwen256k_fp8_full_prod_prefix_tier_hostlogits.sh + Verified the synced helper contains: + BASE=..._pa${PA_NUM_BLOCKS}_... + --pa-num-blocks "${PA_NUM_BLOCKS}" + +Verification: + Relaunch produced a _pa1025_ artifact name. +``` + +```text +What failed: + The next relaunch on ubuntu@16.50.61.215 failed before compilation started. + +Exact error: + qwen36_27b_compile_fp8.py: error: unrecognized arguments: + --omit-zero-prefix-pair + --prefix-cte-attention-backend segmented_cte + --prefix-cte-attention-segment-size 512 + +How we got there: + Remote repo branch was codex/full-fp8-qwen36 at 03e7e3a, but + contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py was + stale relative to the local full-FP8 branch work. The runtime modules already + had segmented_cte support, but the compile entrypoint did not expose the + required CLI flags. + +Fix / mitigation applied: + Synced the local compile entrypoint to: + /home/ubuntu/inferentia-gdn-fused-noclamp-4340808/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py + Verified with: + python3 -m py_compile contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py + grep for omit-zero-prefix, prefix-cte-attention, and segmented_cte. + +Verification: + Corrected relaunch started as PID 29224 with: + artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z + log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z_compile.log + The compile log now shows: + CONTEXT_TRACE_SHAPE ... "pa_num_blocks": 1025, + "pa_min_blocks": 1024, + "pa_headroom_blocks": 1, + "prefix_cte_attention_backend": "segmented_cte", + "prefix_cte_attention_segment_size": 512 + and then enters HLO generation for context_encoding_model. +``` + +```text +What failed: + The PA1025 relaunch at 20260526T205447Z reached HLO tracing but failed in + Python before neuronx-cc compilation. + +Exact error: + AttributeError: 'QwenHybridBlockKVCacheManager' object has no attribute + 'get_raw_kv_by_layer_id'. Did you mean: 'get_kv_by_layer_id'? + +Evidence: + PID/log: + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z_compile.pid + /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z_compile.log + Stack: + modeling_qwen35.py:get_cache -> self.get_raw_kv_by_layer_id(...) + torch.nn.Module.__getattr__ raised AttributeError. + +How we got there: + The Qwen model file and attention path expected the newer raw block-KV cache + accessor, but the remote + src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py + had not been synced with the matching full-FP8 branch changes. + +Fix / mitigation applied: + Synced the matching local cache/runtime files to ubuntu@16.50.61.215: + src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py + src/neuronx_distributed_inference/models/config.py + src/neuronx_distributed_inference/models/model_wrapper.py + src/neuronx_distributed_inference/modules/async_execution.py + src/neuronx_distributed_inference/modules/autobucketing.py + src/neuronx_distributed_inference/modules/attention/attention_base.py + src/neuronx_distributed_inference/modules/attention/nki_kernels/ + Verified with py_compile and confirmed: + def get_raw_kv_by_layer_id(self, idx, kvcache_buffer=None, **kwargs) + +Verification: + Relaunched as PID 30174 with artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z + Current log shows: + CONTEXT_TRACE_SHAPE ... "pa_num_blocks": 1025, + "pa_headroom_blocks": 1, + "prefix_cte_attention_backend": "segmented_cte", + "prefix_cte_attention_segment_size": 512 + Finished generating HLO for context_encoding_model + Started loading module token_generation_model +``` + +```text +Operator error: + During the remote sync fix, one multi-file scp command targeted the attention + directory for all source files. It created extra inert copies under: + src/neuronx_distributed_inference/modules/attention/config.py + src/neuronx_distributed_inference/modules/attention/model_wrapper.py + src/neuronx_distributed_inference/modules/attention/async_execution.py + src/neuronx_distributed_inference/modules/attention/autobucketing.py + +Impact / hypothesis: + These files are not imported by the current attention package path, but they + are remote workspace clutter and should be removed after explicit approval or + during the next cleanup pass. + +Fix / mitigation applied: + Re-copied each file to its correct destination. No compile path depends on + the accidental files. +``` + +### PA1025 pfx256 Runtime Validation Failed with DGE OOB + +```text +What failed: + No-device-profile runtime validation of the corrected PA1025 pfx256k + segmented CTE artifact on TRN2 ubuntu@16.50.61.215. + +Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z + +Validation output root: + /home/ubuntu/validation_logs/fp8_256k/pfx256_pa1025_runtime_noprofile_20260527T034323Z + +Inputs / flags: + validation_scripts/qwen36_hybrid_apc_context_sweep.py + --lengths 261888 + --max-tokens 16 + --suffix-tokens 16 + --seq-len 262144 + --max-model-len 262144 + --cte-buckets 3072 + --context-encoding-bucket-pairs 3072:262144 + --token-generation-buckets 262144 + --async-mode + --block-size 256 + --gdn-checkpoint-interval 256 + --max-gdn-checkpoint-slots 64 + --gdn-recurrent-cache-dtype float32 + --gdn-conv-cache-dtype bfloat16 + --require-real-tokens + Device profiling was explicitly disabled: + unset NEURON_RT_INSPECT_ENABLE + unset NEURON_RT_INSPECT_DEVICE_PROFILE + unset NEURON_RT_INSPECT_OUTPUT_DIR + +Observed runtime context: + Engine loaded the compiled artifact successfully. + vLLM reported: + GPU KV cache size: 262,400 tokens + Maximum concurrency for 262,144 tokens per request: 1.00x + Prompt execution started for the 261888-token request. + +Exact error: + At 2026-05-27T03:53:20Z, the context_encoding_model NEFF repeatedly emitted: + TDRV:exec_process_custom_notification ... failed to run scatter/gather + (indirect memory copy via scalar DGE), due to out-of-bound access. + model name = + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_pafix_cte3072_pfx256k_pa1025_tkg262144_20260526T205813Z/context_encoding_model/_tp0_bk0/model.MODULE_30b568c5d3faaeced212+b0ee5af3.neff + The first repeated failures were on Neuron cores nc4/nc6 and then expanded + across other cores. + +Memory evidence: + This was not an NRT_RESOURCE/OOM failure. The memory sampler captured: + peak_host_rss_gib: 35.310752868652344 + peak_neuron_by_category_gib.total: 157.06698608398438 + peak_neuron_by_category_gib.present: 7.32489013671875 + Note: peak_neuron_total_gib in the sampler summary sums peak/present/total + categories and should not be used as a real HBM footprint. + +Root cause / hypothesis update: + The null-block PA mismatch was not the root cause for the 256K runtime + failure. Compiling with 1025 physical PA blocks fixed the compile/load shape + and provided physical capacity for the null block, but the actual long-prefix + segmented CTE path still generates an out-of-range scalar-DGE address at + runtime. The best current hypothesis is now a qwen_segcte256 address-mapping + bug for the pfx256k context_encoding bucket, likely in block-table indexing, + prior-segment offset, active-stream offset, or kv_section_idx mapping inside + the custom segmented CTE kernel. + +Fix / mitigation applied: + Stopped the failed validation run and sampler after the DGE OOB: + wrapper PID: 31812 + sampler PID: 31814 + context sweep PID: 31815 + EngineCore PID: 31872 + Verified those PIDs were no longer present afterward. + +Remaining blocker: + The PA1025 pfx256k artifact is not runtime-valid and is not production-ready. + Do not run OpenAI/server TTFT/TPOT validation on this artifact until the + segmented CTE 256K address calculation is fixed or replaced. + +Next mitigation: + Build a targeted qwen_segcte256 debug/fix path: + 1. Reproduce with a small diagnostic harness that exercises the same + segmented CTE addressing with controlled block_table values. + 2. Add bounds checks or debug-side assertions for physical block id, + kv_head/block offset, prior segment start, active segment start, and + kv_section_idx before the DGE loads. + 3. Patch the qwen_segcte256 NKI address mapping, then recompile the + pfx256k bucket and rerun no-device-profile runtime validation. +``` + +```text +Operator/status-check errors encountered during this validation: + +1. A status check command exited 127 because it used `python` in a remote + non-login shell where only the activated venv process had `python` on PATH. + The validation process itself was unaffected. Mitigation: subsequent status + parsing used `python3` or an activated venv. + +2. The first cleanup command used a broad pgrep pattern: + qwen36_hybrid_apc_context_sweep|VLLM::EngineCore|neuron_memory_sampler + and matched its own SSH-side shell command, causing the SSH cleanup command + to exit 255 before printing post-cleanup status. Mitigation: reran cleanup + with explicit known PIDs 31812, 31814, 31815, and 31872, then verified they + were no longer running. +``` + +### Segmented CTE Active Block-Table Fill Fix + +```text +What failed: + Follow-up investigation of the PA1025 pfx256k runtime DGE OOB found that the + segmented CTE kernel reads the active suffix K/V from the raw paged KV cache + through active_block_table. If active suffix logical block-table entries are + still unset, the NKI kernel can consume an invalid block id for scalar DGE. + +Evidence / how we got there: + Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z + Compile trace shape: + context_encoding_bucket_pairs=[[3072,262144]] + pa_num_blocks=1025 + pa_min_blocks=1024 + pa_headroom_blocks=1 + prefix_cte_attention_backend=segmented_cte + prefix_cte_attention_segment_size=512 + Runtime loaded the artifact with pa_num_blocks=1025, then failed inside the + context_encoding_model NEFF with: + failed to run scatter/gather (indirect memory copy via scalar DGE), + due to out-of-bound access + There were no debug `pad-pre`, `pad-post`, or `qwen-cte-call` lines in the + failed validation log because QWEN36_HYBRID_APC_DEBUG was not enabled. + +Root cause / best current hypothesis: + BlockKVCacheManager writes the active suffix K/V into the raw block cache by + slot_mapping. The qwen_segcte256 path then reads active K/V from the raw block + cache by active_block_table. For segmented CTE, active_block_table must contain + physical block ids for logical active-suffix blocks as well as prefix blocks. + If those active entries remain -1 or otherwise unset, the NKI kernel casts the + block table to uint32 and can form a huge scalar-DGE HBM offset. That matches + the observed runtime-only scalar DGE OOB after successful load. + +Fix / mitigation applied locally: + Patched: + src/neuronx_distributed_inference/models/model_wrapper.py + Added segmented-CTE-only input preprocessing in `_pad_prefix_caching_inputs`: + - derive active logical block positions from computed_context_lens + token + index + - derive active physical block ids from slot_mapping // pa_block_size + - fill those active logical block-table entries before masking/padding + - include active tokens when sizing the segmented CTE block table + This leaves the non-segmented attention_cte path unchanged. + +Test added: + test/unit/models/test_prefix_caching_bucket_selection.py + test_segmented_cte_padding_fills_active_block_table_from_slots + The focused case starts with block_table [[0, 1, 2, -1]], prefix_len=768, + suffix_len=48, pa_block_size=256, and slot_mapping in physical block 4. The + expected padded block table is [[0, 1, 2, 4]]. + +Local verification: + Command: + python3 -m py_compile src/neuronx_distributed_inference/models/model_wrapper.py test/unit/models/test_prefix_caching_bucket_selection.py + Result: + pass + +Local test environment errors: + Command: + python3 -m pytest test/unit/models/test_prefix_caching_bucket_selection.py -q + Result: + exit 2 during collection + Exact error: + ModuleNotFoundError: No module named 'neuronx_distributed_inference' + Mitigation: + reran with PYTHONPATH=src + + Command: + PYTHONPATH=src python3 -m pytest test/unit/models/test_prefix_caching_bucket_selection.py -q + Result: + exit 2 during collection + Exact error: + ModuleNotFoundError: No module named 'neuronx_distributed' + Root cause / hypothesis: + The local Mac environment lacks the Neuron/NxD Python dependency needed for + this test module. This is an environment dependency issue, not a syntax + failure; py_compile passed locally. + +Next verification: + Sync the patch to TRN2 ubuntu@16.50.61.215, run py_compile and the focused + pytest in the Neuron venv, then rerun a no-device-profile pfx256 validation + with QWEN36_HYBRID_APC_DEBUG=1. If the shorter debug validation passes, rerun + the original 261888-token validation against the same compiled artifact. +``` + +### Active Block-Table Fill Validation Results + +```text +What passed: + Remote syntax/unit validation on TRN2 ubuntu@16.50.61.215 after syncing: + src/neuronx_distributed_inference/models/model_wrapper.py + test/unit/models/test_prefix_caching_bucket_selection.py + contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md + +Command: + cd /home/ubuntu/inferentia-gdn-fused-noclamp-4340808 + source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + PYTHONPATH=src python -m py_compile \ + src/neuronx_distributed_inference/models/model_wrapper.py \ + test/unit/models/test_prefix_caching_bucket_selection.py + PYTHONPATH=src python -m pytest \ + test/unit/models/test_prefix_caching_bucket_selection.py -q + +Result: + 35 passed, 46 warnings in 5.33s + +What passed at runtime: + Short debug validation with the PA1025 pfx256k artifact and the local + active-block-table fill patch completed without DGE OOB. + +Artifact: + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z + +Output root: + /home/ubuntu/validation_logs/fp8_256k/pfx256_pa1025_activeblockfix_short_20260527T_test + +Inputs: + --lengths 8192 + --max-tokens 1 + --suffix-tokens 16 + --seq-len 262144 + --max-model-len 262144 + --cte-buckets 3072 + --context-encoding-bucket-pairs 3072:262144 + --token-generation-buckets 262144 + --async-mode + --block-size 256 + --gdn-checkpoint-interval 256 + --max-gdn-checkpoint-slots 64 + --gdn-recurrent-cache-dtype float32 + --gdn-conv-cache-dtype bfloat16 + --require-real-tokens + QWEN36_HYBRID_APC_DEBUG=1 + +Short-run evidence: + The debug trace showed the active/prefix block table now includes the active + physical block range. Examples: + prefix_len=6144, slot_mapping max=8447, block_table max=32 + prefix_len=6144, slot_mapping max=18687, block_table max=72 + The 8192-token run completed: + cold elapsed: 14.63105383799848s + warm elapsed: 4.818916980999347s + real_tokens_passed: true + +What still failed: + Full 261888-token validation with the same artifact and patch still failed + inside the context_encoding_model NEFF with scalar DGE OOB. + +Output root: + /home/ubuntu/validation_logs/fp8_256k/pfx256_pa1025_activeblockfix_full_20260527T0524Z + +Inputs: + --lengths 261888 + --max-tokens 16 + --suffix-tokens 16 + --seq-len 262144 + --max-model-len 262144 + --cte-buckets 3072 + --context-encoding-bucket-pairs 3072:262144 + --token-generation-buckets 262144 + --async-mode + --block-size 256 + --gdn-checkpoint-interval 256 + --max-gdn-checkpoint-slots 64 + --gdn-recurrent-cache-dtype float32 + --gdn-conv-cache-dtype bfloat16 + --require-real-tokens + Device profiling and QWEN36_HYBRID_APC_DEBUG were disabled for the full run. + +Exact error: + First repeated failures at run.log lines 2592+: + 2026-May-27 05:21:35.021738 ... ERROR TDRV:exec_process_custom_notification + nd0:nc6:h_model.id1005: Received notification generated at runtime: + failed to run scatter/gather (indirect memory copy via scalar DGE), + due to out-of-bound access. model name = + /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_pafix_cte3072_pfx256k_pa1025_tkg262144_20260526T205813Z/context_encoding_model/_tp0_bk0/model.MODULE_30b568c5d3faaeced212+b0ee5af3.neff. + The same error appeared on nc4, nc5, nc6, nc7 and later nc0/nc1/nc2/nc3. + The runtime also reported: + TDRV:exec_request_process_errors [ND 0][NC 6] Out of bounds access on model ... + NMGR:dlr_exec_wait Execution completed with err: 1006. mode->h_nn=1008, lnc=2 + +Core dump evidence: + Neuron generated NRT_EXEC_OOB dumps: + /tmp/neuron-core-dump/dt-20260527-051233-cid-d99e36ea74c263ca + i-05d3f024966df11d5-nd0-nc4-pid-39738-tid-39861-lid-1 + i-05d3f024966df11d5-nd0-nc6-pid-39738-tid-39862-lid-2 + i-05d3f024966df11d5-nd0-nc2-pid-39738-tid-39863-lid-3 + +Memory evidence: + This was not a Neuron load OOM/NRT_RESOURCE failure. + Memory summary: + peak_host_rss_gib: 34.55003356933594 + peak_neuron_by_category_gib.present: 6.589611053466797 + peak_neuron_by_category_gib.total: 157.06698608398438 + As before, the sampler's peak_neuron_total_gib sums sysfs categories and is + not a single real HBM allocation. + +Root cause / hypothesis update: + The active-block-table fill is necessary and fixes a real input-prep hazard, + but it is not sufficient for the pfx256/261888 path. The remaining scalar DGE + OOB is likely inside qwen_segcte256 address generation for high prior segment + indices, for example: + - prior segment block-table offset when prefix_len approaches 256K + - the first/last partial-prior segment around a 512-token segment boundary + - segment index to block-table index arithmetic in the NKI kernel + - kv_section_idx or KV-head/block offset at high logical block ids + This is now confirmed as a kernel/addressing bug, not a PA1025 capacity issue + and not just missing active physical block ids. + +Fix / mitigation applied: + Stopped the failed full validation and sampler: + sampler PID: 39613 + wrapper bash PID: 39684 + context sweep PID: 39692 + EngineCore PID: 39738 + PID 39738 became a short-lived defunct EngineCore while neuron-dump wrote + NRT_EXEC_OOB dumps. No qwen36 sweep/sampler process remained afterward. + +Remaining blocker: + The PA1025 pfx256k segmented CTE artifact is still not runtime-valid for + 261888-token / 256K-context serving. It must not be called production-ready. + +Next mitigation: + Add high-prefix debug instrumentation or a CPU/NKI address simulator for + qwen_segcte256 and binary-search the failing prefix length with the pfx256 + artifact. The short 8K smoke is not enough; test lengths should bracket the + failure, e.g. 32768, 65536, 131072, 196608, 229376, and 261888, with debug + enabled only around the final failing CTE chunk. +``` + +```text +Operator errors during the active-block-table validation: + +1. The first full-run wrapper backgrounded too broad a shell command and lost + ROOT/PATH state. It printed: + tee: /run.log: Permission denied + bash: line 1: python: command not found + The validation did not start. An orphaned sampler PID 39303 was killed. + Mitigation: reran with explicit absolute output paths and separate sampler + launch. + +2. The first separate sampler launch quoted ROOT incorrectly inside nested + local/remote shell expansion. It printed: + mkdir: missing operand + bash: line 1: /sampler.pid: Permission denied + No validation ran from that command. Mitigation: relaunched sampler with + literal absolute paths. +``` + +### Root Cause Found: Final Partial Active Chunk Reads Past Block Table + +```text +What failed: + The PA1025 pfx256k artifact still emitted scalar DGE OOB at 261888 tokens even + after the Python active-block-table fill. The 8192-token smoke passed, which + meant the remaining bug was specific to high-prefix / end-of-context address + generation. + +Code path: + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/ + attention_segmented_cte_256.py + fused_segmented_attention_256.py + +Root cause: + In qwen_segcte256 active streaming, the compiled 3072-token CTE bucket is + split into six 512-token active stream sections: + active_stream_tokens = 512 + num_active_stream_sections = ceil(3072 / 512) = 6 + num_blocks_per_active_stream = 512 / 256 = 2 + + For the final real chunk of a 261888-token prompt: + prior_tokens = 261120 + real active_len = 768 + active_block_offset = prior_tokens // 256 = 1020 + + The compiled active-stream loop still loads all six bucket sections, so the + block-table offsets are: + 1020, 1022, 1024, 1026, 1028, 1030 + + A real pfx256 block_table has 1024 entries. The older internal padding only + padded to 1026 entries for the prior-segment one-past read: + padded_width = (1024 // 2 + 1) * 2 = 1026 + + Therefore active sections 4 and 5 can read block-table offsets 1028/1030, + outside the internally padded table. That exactly matches AWS Neuron's DGE + docs: scalar/vector DGE offsets must still resolve to valid tensor addresses, + otherwise runtime reports out-of-bound scatter/gather. + +Fix applied locally: + Patched `attention_segmented_cte_256.py` to pad the internal block table for + both hazards: + padded_width_for_prior = one extra prior segment + padded_width_for_active_stream = max_blocks_per_seq + seqlen_q // block_size + padded_width = rounded max of both + + For pfx256 cte3072 this pads from 1024 to 1036 entries, so out-of-range + compiled active-stream sections read zero block ids from the padded tail + instead of DGE-reading past the block table. Block id 0 is the existing null + block, so this matches the intended padding semantics. + +Why this aligns with docs: + The NKI/DGE docs allow dynamic/scalar-offset DMA patterns, but the program is + responsible for keeping the dynamic address inside the tensor. Padding the + source table before the scalar DGE access is the simple robust fix; relying on + masks after the DMA is too late because the OOB happens during the DMA + descriptor execution. + +Verification so far: + Local syntax: + python3 -m py_compile \ + src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py + result: pass + +Remaining work: + Sync to TRN2, run remote py_compile, recompile the pfx256 segmented CTE + artifact, then rerun the 261888-token no-device-profile validation. The old + PA1025 artifact cannot be fixed in place because this change is inside the + compiled NKI kernel. +``` + +### Bound-Fix PFX256 Runtime Validation Passed + +```text +Artifact: + /mnt/trainium_artifacts/qwen_artifacts/ + qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_boundfix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260527T052822Z + +Validation root: + /home/ubuntu/validation_logs/fp8_256k/pfx256_boundfix_runtime_20260527T0552Z + +Inputs: + length: 261888 prompt tokens + max_tokens: 16 + seq_len/max_model_len: 262144 + cte/prefix pair: 3072:262144 + token_generation_bucket: 262144 + pa_num_blocks: 1025 + backend: segmented_cte + segment_size: 512 + profiling: disabled + +Result: + passed: true + real_tokens_passed: true + token_range_passed: true + non_dummy_generated_token_count: 48 + unique_generated_token_count: 41 + +Timings: + cold prefill+decode: 551.9684613459976s + prefix warmup prefill+decode: 551.5538088760004s + measured warm/refill+decode: 10.758389775000978s + cold effective prompt throughput: 474.46189110402327 tokens/s + warm/refill effective prompt throughput: 24342.67631839693 tokens/s + +Memory summary: + peak_host_rss_gib: 35.314510345458984 + peak_neuron_by_category_gib.present: 6.291294097900391 + peak_neuron_by_category_gib.total: 159.61318969726562 + +Notes: + The sysfs Neuron memory sampler aggregates categories and logical cores; the + `present` category is the most useful live resident counter from this sampler. + The larger `total` and `peak` aggregates are not single-device HBM usage. + +Monitor-side error encountered: + Command: + ssh ... 'python - < 262144). Running this sequence through the model will result in indexing errors + It had already completed 32768, 65536, and 131072 rows successfully. + The stuck child was manually terminated, so the suite recorded: + [2026-05-27T08:33:26+00:00] END server_context_bench rc=143 + How we got there: + validation_scripts/qwen36_chat_completion_context_bench.py was run with: + --lengths 32768,65536,131072,261888 --turns 8 --repeats 1 + The old prompt builder doubled filler repetitions until it exceeded the + target, which created a transient 426209-token chat-template probe for the + 261888-token target. + Root cause: + Validation harness bug, not a Neuron runtime/model failure. The prompt + builder used exponential overshoot probes that are too large near the + 262144-token model limit. + Fix: + Updated _make_messages in validation_scripts/qwen36_chat_completion_context_bench.py + to estimate filler repetitions from one-repeat token delta and correct + downward instead of doubling past the target. Synced the fixed script to + TRN2 and reran only server startup + server_context_bench. + Verification: + python3 -m py_compile validation_scripts/qwen36_chat_completion_context_bench.py + passed locally. + Corrected context bench passed: + 32768 target: prompt 32764, status 200, TTFT 57.0513s, completion 16 + 65536 target: prompt 65524, status 200, TTFT 66.4657s, completion 16 + 131072 target: prompt 131070, status 200, TTFT 132.5198s, completion 9 + 261888 target: prompt 261876, status 200, TTFT 319.2527s, completion 16 + +Memory summaries: + Primary server peak host RSS: 35.3254 GiB + Corrected server peak host RSS: 35.3525 GiB + Corrected server live Neuron `present` peak: 10.7526 GiB from sampler + +Monitor/tooling errors encountered: + write_stdin failed when attempting to interrupt old tail sessions: + stdin is closed for this session; rerun exec_command with tty=true to keep stdin open + This was a local monitoring-tool state issue. It did not affect remote + validation. The old remote suite had already exited and the corrected rerun + used a new tail session. + + A local sandboxed ps probe failed: + zsh:1: operation not permitted: ps + This was local sandboxing, not a repo or remote failure. Remote process + checks were done through ssh instead. + + The command used to stop the completed remote live tail returned ssh exit + code 255 with no stderr: + ssh ... 'pkill -f "tail -n 80 -F .*prod_readiness_boundfix_contextbench_capped_20260527T083606Z" || true' + Hypothesis: + pkill matched and terminated the remote tail/ssh session while the command + was still attached, so ssh reported disconnect as 255. + Verification: + The tail session then reported `Process exited with code 255`; validation + had already completed and the server had already shut down cleanly. +``` diff --git a/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch b/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch new file mode 100644 index 00000000..a4ed7d11 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch @@ -0,0 +1,154 @@ +From: Deepankar Singh +Subject: [PATCH] Qwen3.6-27B OpenAI server: consume all fused-spec accepted tokens per host loop + +The decode loop currently calls _token_scalar(out.tokens) which flattens the +output and returns only the first token. For non-spec inference this is +correct (out.tokens has shape (1,1)). For fused-spec MTP it discards all +accepted tokens after the first, causing the host to advance only one token +per Python iteration even when the device accepted multiple. + +Root cause of the observed 1.6x MTP gain vs the expected 2-2.5x: + - spec length = 2 in the artifact + - device returns N accepted tokens per forward (N in 1..2 typically) + - server keeps only tokens[0], requeues the rest implicitly by feeding + tokens[0] back as the next input + - effective speedup = 1 + (P_accept_2)*0.5 ~= 1.6x at high acceptance + +Fix: + - Add _accepted_tokens() that returns the prefix of in-vocab non-pad tokens + - Rewrite the decode loop as a while-loop that runs one forward per + iteration and commits ALL accepted tokens (up to max_tokens, stopping at + first EOS). + - Pre-allocate decode_ids / decode_position_ids / decode_attention_mask + (already done in the current code; preserved). + - Position-id update uses the position of the most recently committed + token: pos_value = prompt_tokens + len(new_ids) - 1. + +Expected result: decode 1.6x -> 2.0-2.4x on the same artifact, no recompile. + +--- + contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py | 75 +++++++++++++-------- + 1 file changed, 47 insertions(+), 28 deletions(-) + +diff --git a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py +--- a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py ++++ b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py +@@ -50,6 +50,7 @@ + return str(prompt) + + ++# Legacy: keep for prefill path which always returns a single token. + def _token_scalar(tokens: Any) -> int: + if hasattr(tokens, "detach"): + tokens = tokens.detach().cpu() +@@ -58,6 +59,36 @@ def _token_scalar(tokens: Any) -> int: + return int(tokens.reshape(-1)[0].item()) + + ++def _accepted_tokens( ++ tokens: Any, ++ vocab_size: int, ++ pad_id: Any = None, ++) -> List[int]: ++ """Return the prefix of legitimately-accepted tokens from a fused-spec output. ++ ++ For non-spec inference, ``out.tokens`` is shape (1, 1) and this returns a ++ 1-element list. For fused-spec MTP with speculation length K, ``out.tokens`` ++ is shape (1, K) where unused slots are padded with -1 or pad_token_id. ++ Scan left-to-right; the first out-of-vocab or pad value marks the boundary. ++ """ ++ if hasattr(tokens, "detach"): ++ tokens = tokens.detach().cpu() ++ if hasattr(tokens, "ndim") and tokens.ndim == 0: ++ v = int(tokens.item()) ++ if 0 <= v < vocab_size and (pad_id is None or v != pad_id): ++ return [v] ++ return [] ++ flat = tokens.reshape(-1).tolist() ++ accepted: List[int] = [] ++ for raw in flat: ++ v = int(raw) ++ if v < 0 or v >= vocab_size: ++ break ++ if pad_id is not None and v == pad_id: ++ break ++ accepted.append(v) ++ return accepted ++ ++ + class QwenOpenAIServer: + def __init__(self, args: argparse.Namespace): + self.args = args +@@ -187,8 +218,7 @@ class QwenOpenAIServer: + if first_token is None: + raise RuntimeError("prefill produced no token") + +- new_ids = [] +- current_token = first_token ++ new_ids: List[int] = [] + vocab_size = len(self.tokenizer) + raw_eos_id = self.tokenizer.eos_token_id + eos_ids = ( +@@ -196,6 +226,7 @@ class QwenOpenAIServer: + if isinstance(raw_eos_id, (list, tuple, set)) + else {raw_eos_id} + ) ++ pad_id = self.tokenizer.pad_token_id + decode_ids = torch.empty((1, 1), dtype=torch.int32) + decode_position_ids = torch.empty((1, 1), dtype=torch.int32) + decode_attention_mask = torch.ones( +@@ -203,21 +234,24 @@ class QwenOpenAIServer: + dtype=torch.int32, + ) + finish_reason = "length" ++ ++ # Bootstrap: commit the prefill token at position prompt_tokens. ++ if first_token < 0 or first_token >= vocab_size: ++ raise RuntimeError(f"prefill generated invalid token id: {first_token}") ++ new_ids.append(first_token) ++ if first_token in eos_ids: ++ finish_reason = "stop" ++ ++ # Decode loop: one forward per iteration, consume ALL accepted tokens. ++ # For non-spec, accepted is length 1. For fused-spec MTP, accepted may ++ # be length up to (speculation_length + 1). + with torch.no_grad(): +- for step in range(max_tokens): +- if current_token in eos_ids: +- finish_reason = "stop" +- break +- if current_token < 0 or current_token >= vocab_size: +- raise RuntimeError(f"model generated invalid token id: {current_token}") +- new_ids.append(current_token) +- if step == max_tokens - 1: +- break +- +- pos_value = prompt_tokens + step +- decode_ids[0, 0] = current_token ++ while len(new_ids) < max_tokens and finish_reason == "length": ++ last_token = new_ids[-1] ++ pos_value = prompt_tokens + len(new_ids) - 1 ++ decode_ids[0, 0] = last_token + decode_position_ids[0, 0] = pos_value + active_attention_mask = decode_attention_mask[:, : pos_value + 1] + out = self.model( + input_ids=decode_ids, + attention_mask=active_attention_mask, +@@ -226,7 +260,17 @@ class QwenOpenAIServer: + sampling_params=sampling_params, + return_dict=True, + ) +- current_token = _token_scalar(out.tokens) ++ accepted = _accepted_tokens(out.tokens, vocab_size, pad_id=pad_id) ++ if not accepted: ++ raise RuntimeError("model produced no accepted tokens in decode step") ++ for tok in accepted: ++ if len(new_ids) >= max_tokens: ++ break ++ if tok < 0 or tok >= vocab_size: ++ raise RuntimeError(f"model generated invalid token id: {tok}") ++ new_ids.append(tok) ++ if tok in eos_ids: ++ finish_reason = "stop" ++ break + elapsed = time.perf_counter() - t0 diff --git a/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md b/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md new file mode 100644 index 00000000..d00b4d3c --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md @@ -0,0 +1,118 @@ +# MTP Batched Accept Fix + +## Problem + +The OpenAI-compatible server in `scripts/openai_compat_server.py` discards +fused-spec accepted tokens beyond the first. The decode loop calls +`_token_scalar(out.tokens)` which returns only `tokens[0]`, then feeds that +token back as the next input. Result: host advances 1 token per Python loop +iteration even when the device accepted multiple via MTP speculation. + +Observed effect on `qwen36_27b_128k_fp8_mtp_run2` artifact: +- Expected decode: 2.0-2.5x baseline (NVIDIA's published MTP gain for length=2) +- Actual decode: 1.6x baseline (44 tok/s vs 27 baseline) +- Gap is purely host-loop, not device compute + +## Fix + +Patch: `mtp_batched_accept.patch` + +Changes: +1. Add `_accepted_tokens(tokens, vocab_size, pad_id)` helper that scans the + fused-spec output tensor and returns the prefix of in-vocab non-pad tokens. +2. Rewrite the decode loop as a `while` loop with bootstrap + iterations: + - Bootstrap: commit `first_token` from prefill at position `prompt_tokens`. + - Each iteration: feed `new_ids[-1]` at position + `prompt_tokens + len(new_ids) - 1`, then commit all accepted tokens + returned by the device. + - Stop on EOS, max_tokens cap, or invalid token id. +3. No model recompile required. No NeuronConfig changes. + +## Apply + +From repo root on branch `codex/qwen36-mtp-vllm-apc`: + +```bash +git apply docs/patches/mtp_batched_accept.patch +# or, if line numbers shifted: +git apply --3way docs/patches/mtp_batched_accept.patch +``` + +Verify by inspection: + +```bash +grep -n "_accepted_tokens\|while len(new_ids)" \ + contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py +``` + +Should show the helper definition near the top and the new while-loop in the +decode path. + +## Validation gates (in order) + +Run against the existing `qwen36_27b_128k_fp8_mtp_run2` artifact. + +### Gate 1: Smoke +Math prompt returns 391 with coherent text. No invalid token errors. Same +behavior as before the patch. + +```bash +curl -s -X POST http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model":"qwen3.6-27b-128k-fp8-mtp","messages":[ + {"role":"user","content":"What is 17 * 23?"}], + "max_tokens":32}' +``` + +Expect output containing `391`. + +### Gate 2: Greedy parity +Same 5 fixed prompts before and after patch. Greedy decode (top_k=1). +Token-by-token output should be **identical** between pre-patch and post-patch +because the patch only changes how the host loop consumes the device output, +not the math. + +If mismatch: bug in `_accepted_tokens` (likely missing pad sentinel or +off-by-one). Investigate before measuring perf. + +### Gate 3: Decode tok/s +Same benchmarks as the MTP results doc: +- 32-token prompt, 128-token completion: expect decode tok/s ~50-60 (vs 41.6) +- 28-token prompt, 256-token completion: expect decode tok/s ~55-65 (vs 44.3) +- 3959-token prompt, 128-token completion: expect decode tok/s ~55-65 (vs 45.2) + +If decode is unchanged from previous MTP measurements: spec is not actually +accepting multiple tokens per forward. Verify by logging +`len(accepted)` distribution during a 200-token generation; expect mean ≥ 1.5. + +### Gate 4: Long-context coherence +16K-token prompt, 256-token completion. Output should be coherent and not +contain any invalid tokens. Same quality as pre-patch. + +## Expected speedup + +| Workload | Before patch | After patch | Mechanism | +|---|---:|---:|---| +| 32-tok / 128-out decode | 41.6 tok/s | **~55-65 tok/s** | Consume 2 accepted per forward | +| 28-tok / 256-out decode | 44.3 tok/s | **~55-65 tok/s** | Sustained spec acceptance | +| 4K / 128-out decode | 45.2 tok/s | **~55-65 tok/s** | Same | + +Combined with baseline v3 (27 tok/s) → MTP after patch (~55-65) is **2.0-2.4x +total decode speedup**, matching NVIDIA's published number for spec length=2. + +## What this does NOT do + +- Does not change prefill speed (still ~420 tok/s flat across contexts) +- Does not change model quality (same math, same tokens, same logits) +- Does not change vLLM bridge (custom OpenAI server only) +- Does not change cache management +- Does not require artifact recompile + +## Followups after this lands + +1. Tag artifact + branch as `qwen36-27b-mtp-v2` with the new tok/s numbers +2. Apply the same batched-accept logic to the vLLM-Neuron decode path (once + the v1 MTP registry gap is fixed) +3. Investigate speculation length=3 (currently length=2 in the artifact) +4. Measure acceptance rate distribution; if mean < 1.5, MTP head quality is + the limit, not the host loop diff --git a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py new file mode 100644 index 00000000..37bf10d9 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python3 +"""Minimal OpenAI-compatible HTTP server for the Qwen3.6-27B NxDI artifact. + +This intentionally avoids uvicorn/fastapi runtime dependencies so it can run in +the stock Neuron inference venv. It supports non-streaming: + - GET /health + - GET /v1/models + - POST /v1/completions + - POST /v1/chat/completions +""" + +import argparse +import json +import sys +import threading +import time +import traceback +import uuid +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any, Dict, List + +import torch + + +def _json_response(handler: BaseHTTPRequestHandler, status: int, payload: Dict[str, Any]): + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") + handler.send_response(status) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.send_header("Access-Control-Allow-Origin", "*") + handler.send_header("Access-Control-Allow-Headers", "authorization,content-type") + handler.send_header("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + handler.end_headers() + handler.wfile.write(body) + + +def _error(handler: BaseHTTPRequestHandler, status: int, message: str): + _json_response( + handler, + status, + {"error": {"message": message, "type": "server_error", "code": status}}, + ) + + +def _is_token_id(value: Any) -> bool: + return isinstance(value, int) and not isinstance(value, bool) + + +def _completion_prompt(prompt: Any) -> str | list[int]: + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list): + if not prompt: + raise ValueError("prompt list must not be empty") + first = prompt[0] + if _is_token_id(first): + if not all(_is_token_id(item) for item in prompt): + raise ValueError("token-id prompt lists must contain only integers") + return [int(item) for item in prompt] + if isinstance(first, list) and first and all(_is_token_id(item) for item in first): + return [int(item) for item in first] + if isinstance(first, str): + return first + raise ValueError( + "unsupported prompt list shape; use a string, list[int], list[str], " + "or list[list[int]]" + ) + return str(prompt) + + +def _token_scalar(tokens: Any) -> int: + if hasattr(tokens, "detach"): + tokens = tokens.detach().cpu() + if tokens.ndim == 0: + return int(tokens.item()) + return int(tokens.reshape(-1)[0].item()) + + +def _normalize_stop_sequences(stop: Any) -> List[str]: + if stop is None: + return [] + if isinstance(stop, str): + return [stop] + if isinstance(stop, list): + return [item for item in stop if isinstance(item, str)] + return [] + + +def _coerce_optional_bool(value: Any) -> bool | None: + if isinstance(value, bool): + return value + if isinstance(value, int) and value in (0, 1): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower().replace("-", "_").replace(" ", "_") + if normalized in {"1", "true", "yes", "y", "on", "enable", "enabled", "thinking"}: + return True + if normalized in { + "0", + "false", + "no", + "n", + "off", + "disable", + "disabled", + "none", + "non_thinking", + "no_thinking", + }: + return False + return None + + +def _resolve_enable_thinking(body: Dict[str, Any]) -> bool: + for key in ("enable_thinking", "thinking_enabled", "thinking"): + if key in body: + value = body.get(key) + if isinstance(value, dict): + for nested_key in ("enable_thinking", "enabled", "enable", "value"): + if nested_key in value: + coerced = _coerce_optional_bool(value.get(nested_key)) + if coerced is not None: + return coerced + budget = value.get("budget_tokens") + if isinstance(budget, int): + return budget > 0 + coerced = _coerce_optional_bool(value) + if coerced is not None: + return coerced + + template_kwargs = body.get("chat_template_kwargs") + if isinstance(template_kwargs, dict): + coerced = _coerce_optional_bool(template_kwargs.get("enable_thinking")) + if coerced is not None: + return coerced + + reasoning = body.get("reasoning") + if isinstance(reasoning, dict): + for nested_key in ("enable_thinking", "enabled", "enable", "value"): + if nested_key in reasoning: + coerced = _coerce_optional_bool(reasoning.get(nested_key)) + if coerced is not None: + return coerced + effort = reasoning.get("effort") or reasoning.get("reasoning_effort") + coerced = _coerce_optional_bool(effort) + if coerced is not None: + return coerced + if isinstance(effort, str) and effort.strip(): + return True + else: + coerced = _coerce_optional_bool(reasoning) + if coerced is not None: + return coerced + + if "reasoning_effort" in body: + effort = body.get("reasoning_effort") + coerced = _coerce_optional_bool(effort) + if coerced is not None: + return coerced + if isinstance(effort, str) and effort.strip(): + return True + + return False + + +class QwenOpenAIServer: + def __init__(self, args: argparse.Namespace): + self.args = args + self.model_id = args.model_id + self.lock = threading.Lock() + self._load_model() + + def _load_model(self): + if self.args.contrib_root not in sys.path: + sys.path.insert(0, self.args.contrib_root) + + from transformers import AutoTokenizer, GenerationConfig + from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params, + ) + from src.modeling_qwen35 import NeuronQwen35ForCausalLM + + print("Loading tokenizer from", self.args.model_path, flush=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.args.model_path, + padding_side="right", + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + print("Loading NxDI artifact from", self.args.compiled_path, flush=True) + t0 = time.perf_counter() + self.model = NeuronQwen35ForCausalLM(self.args.compiled_path) + self.model.load(self.args.compiled_path) + self.model.reset() + self.prepare_sampling_params = prepare_sampling_params + self.GenerationConfig = GenerationConfig + print(f"Model loaded in {time.perf_counter() - t0:.2f}s", flush=True) + + def _chat_prompt(self, messages: List[Dict[str, Any]], enable_thinking: bool = False) -> str: + try: + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + ) + except TypeError: + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + lines = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + lines.append(f"{role}: {content}") + lines.append("assistant:") + return "\n".join(lines) + + def _generate(self, prompt: str | list[int], body: Dict[str, Any]) -> Dict[str, Any]: + max_tokens = int(body.get("max_tokens", body.get("max_completion_tokens", 128)) or 128) + if max_tokens <= 0: + raise ValueError("max_tokens must be positive") + if max_tokens > self.args.max_new_tokens_limit: + raise ValueError( + f"max_tokens={max_tokens} exceeds server limit {self.args.max_new_tokens_limit}" + ) + + if isinstance(prompt, list): + input_ids = torch.tensor([prompt], dtype=torch.long) + else: + input_ids = torch.tensor( + [self.tokenizer(prompt, add_special_tokens=False).input_ids], + dtype=torch.long, + ) + prompt_tokens = int(input_ids.shape[1]) + if prompt_tokens <= 0: + raise ValueError("prompt must contain at least one token") + vocab_size = len(self.tokenizer) + invalid_prompt_ids = [ + int(tok) + for tok in input_ids.reshape(-1).tolist() + if int(tok) < 0 or int(tok) >= vocab_size + ] + if invalid_prompt_ids: + raise ValueError(f"prompt contains invalid token ids: {invalid_prompt_ids[:8]}") + if prompt_tokens + max_tokens > self.args.seq_len: + raise ValueError( + f"prompt_tokens + max_tokens = {prompt_tokens + max_tokens} exceeds " + f"seq_len={self.args.seq_len}" + ) + + temperature = float(body.get("temperature", 0.0) or 0.0) + top_p = float(body.get("top_p", 1.0) or 1.0) + top_k = int(body.get("top_k", 1) or 1) + # NxDI's traced on-device sampler for this artifact uses do_sample=True. + # OpenAI temperature=0 means greedy, but passing literal 0 into that + # sampler divides logits by zero. top_k=1 with temperature=1 is the + # deterministic greedy path used by the validated HF adapter tests. + sampler_temperature = temperature + if temperature <= 0.0: + sampler_temperature = 1.0 + top_p = 1.0 + top_k = 1 + sampling_params = self.prepare_sampling_params( + batch_size=1, + top_k=[top_k], + top_p=[top_p], + temperature=[sampler_temperature], + ) + seq_ids = torch.tensor([0], dtype=torch.int32) + + with self.lock: + if hasattr(self.model, "reset"): + self.model.reset() + t0 = time.perf_counter() + first_token = None + for start in range(0, prompt_tokens, self.args.chunk_size): + end = min(start + self.args.chunk_size, prompt_tokens) + valid = end - start + chunk_ids = input_ids[:, start:end] + attention_mask = torch.ones((1, valid), dtype=torch.long) + position_ids = torch.arange( + start, + end, + dtype=torch.long, + ).unsqueeze(0) + + with torch.no_grad(): + out = self.model( + input_ids=chunk_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + return_dict=True, + ) + first_token = _token_scalar(out.tokens) + + if first_token is None: + raise RuntimeError("prefill produced no token") + + new_ids = [] + current_token = first_token + raw_eos_id = self.tokenizer.eos_token_id + eos_ids = ( + set(raw_eos_id) + if isinstance(raw_eos_id, (list, tuple, set)) + else {raw_eos_id} + ) + decode_ids = torch.empty((1, 1), dtype=torch.int32) + decode_position_ids = torch.empty((1, 1), dtype=torch.int32) + decode_attention_mask = torch.ones( + (1, prompt_tokens + max_tokens), + dtype=torch.int32, + ) + finish_reason = "length" + with torch.no_grad(): + for step in range(max_tokens): + if current_token in eos_ids: + finish_reason = "stop" + break + if current_token < 0 or current_token >= vocab_size: + raise RuntimeError(f"model generated invalid token id: {current_token}") + new_ids.append(current_token) + if step == max_tokens - 1: + break + + pos_value = prompt_tokens + step + decode_ids[0, 0] = current_token + decode_position_ids[0, 0] = pos_value + active_attention_mask = decode_attention_mask[:, : pos_value + 1] + out = self.model( + input_ids=decode_ids, + attention_mask=active_attention_mask, + position_ids=decode_position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + return_dict=True, + ) + current_token = _token_scalar(out.tokens) + elapsed = time.perf_counter() - t0 + + invalid = [tok for tok in new_ids if tok < 0 or tok >= vocab_size] + if invalid: + raise RuntimeError(f"model generated invalid token ids: {invalid[:8]}") + + text = self.tokenizer.decode(new_ids, skip_special_tokens=True) + for stop in _normalize_stop_sequences(body.get("stop")): + if isinstance(stop, str) and stop in text: + text = text.split(stop, 1)[0] + + return { + "text": text, + "prompt_tokens": prompt_tokens, + "completion_tokens": len(new_ids), + "elapsed": elapsed, + "tokens": new_ids, + "finish_reason": finish_reason, + } + + +def make_handler(server_state: QwenOpenAIServer): + class Handler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def log_message(self, fmt, *args): + print(f"{self.address_string()} - {fmt % args}", flush=True) + + def do_OPTIONS(self): + _json_response(self, 200, {}) + + def do_GET(self): + if self.path == "/health": + _json_response(self, 200, {"status": "ok", "model": server_state.model_id}) + elif self.path == "/v1/models": + _json_response( + self, + 200, + { + "object": "list", + "data": [ + { + "id": server_state.model_id, + "object": "model", + "created": int(time.time()), + "owned_by": "local", + } + ], + }, + ) + else: + _error(self, 404, f"unknown route: {self.path}") + + def do_POST(self): + try: + length = int(self.headers.get("content-length", "0")) + body = json.loads(self.rfile.read(length).decode("utf-8") or "{}") + if body.get("stream"): + raise ValueError("stream=true is not supported by this minimal server yet") + + if self.path == "/v1/completions": + result = server_state._generate( + _completion_prompt(body.get("prompt", "")), + body, + ) + _json_response( + self, + 200, + { + "id": f"cmpl-{uuid.uuid4().hex}", + "object": "text_completion", + "created": int(time.time()), + "model": server_state.model_id, + "choices": [ + { + "index": 0, + "text": result["text"], + "finish_reason": result["finish_reason"], + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + + result["completion_tokens"], + }, + "x_latency_seconds": result["elapsed"], + }, + ) + elif self.path == "/v1/chat/completions": + messages = body.get("messages") or [] + if not isinstance(messages, list): + raise ValueError("messages must be a list") + result = server_state._generate( + server_state._chat_prompt( + messages, + enable_thinking=_resolve_enable_thinking(body), + ), + body, + ) + _json_response( + self, + 200, + { + "id": f"chatcmpl-{uuid.uuid4().hex}", + "object": "chat.completion", + "created": int(time.time()), + "model": server_state.model_id, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": result["text"], + }, + "finish_reason": result["finish_reason"], + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + + result["completion_tokens"], + }, + "x_latency_seconds": result["elapsed"], + }, + ) + else: + _error(self, 404, f"unknown route: {self.path}") + except Exception as exc: + traceback.print_exc() + _error(self, 500, str(exc)) + + return Handler + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model-id", default="qwen3.6-27b-neuron") + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-path", required=True) + parser.add_argument("--contrib-root", required=True) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument("--chunk-size", type=int, default=512) + parser.add_argument("--max-new-tokens-limit", type=int, default=512) + args = parser.parse_args() + + state = QwenOpenAIServer(args) + httpd = ThreadingHTTPServer((args.host, args.port), make_handler(state)) + print(f"Serving {args.model_id} on http://{args.host}:{args.port}", flush=True) + httpd.serve_forever() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/scripts/probe_qkvgate_kernel_layout.py b/contrib/models/Qwen3.6-27B/scripts/probe_qkvgate_kernel_layout.py new file mode 100644 index 00000000..1d54f33e --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/probe_qkvgate_kernel_layout.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +"""Isolation probe: does the nkilib `qkv` kernel emit BSD as a contiguous +[Q | gate | K | V] block in weight-row order when num_q_heads is doubled? + +WHY THIS EXISTS +--------------- +The qkvgate decode path folds the output-gate projection into the fused Wqkv +weight and projects Q/gate/K/V in ONE nkilib `qkv` kernel call with +`num_q_heads = 2 * num_heads` (gate masquerades as a second set of Q heads). +It then splits the result assuming the layout is contiguous: + + q_width = num_heads * head_dim + Q = packed[..., 0 : q_width] + gate = packed[..., q_width : 2*q_width] + K = packed[..., 2*q_width : 2*q_width + kv_width] + V = packed[..., 2*q_width + kv_width : ] + +Everything UPSTREAM (weight packing) and DOWNSTREAM (sigmoid gate, attention, +o_proj) is shared with the known-good `qknormrope` baseline and is proven. +The ONE never-isolated assumption is the kernel's output head ordering under a +doubled num_q_heads. This probe tests exactly that and PRINTS the kernel's real +permutation so the fix is unambiguous. + +HOW IT WORKS +------------ +We build an Wqkv whose every output column carries an identifiable "code": + Q head h -> all head_dim columns == 100 + h + gate head h -> 200 + h + K head h -> 300 + h + V head h -> 400 + h +With a ones() input and no bias, output[col] == code(col) (the column's +row-sum). We then run the SAME kernel call the model uses and decode the codes +straight off the output. The expected (contiguous) order is + [Q0..Q(n-1), gate0..gate(n-1), K0..K(kv-1), V0..V(kv-1)] +If the kernel reorders (e.g. GQA-interleaves Q with K/V, or splits the doubled +q-region differently), the decoded order reveals precisely how -> that IS the +corrected split. + +RUN ON A HOST WITH nkilib + neuronxcc (e.g. the compile host 16.51.94.87): + python probe_qkvgate_kernel_layout.py --num-heads 2 --num-kv-heads 1 --head-dim 256 +Two host-specific knobs are flagged with `HOST:` below — adjust if the local +nki API differs. +""" +import argparse + +import torch + +import nki as _nkilib_nki +from nkilib.core.qkv.qkv import qkv as _nkilib_qkv +from nkilib.core.utils.common_types import ( + NormType as NormType, + QKVOutputLayout as QKVOutputLayout, + QuantizationType as QuantizationType, +) + +KERNEL = _nkilib_nki.jit(_nkilib_qkv) + +Q_BASE, GATE_BASE, K_BASE, V_BASE = 100, 200, 300, 400 + + +def region_of(code: int) -> str: + base = (code // 100) * 100 + return {100: "Q", 200: "gate", 300: "K", 400: "V"}.get(base, "?") + + +def build_identifiable_weight(hidden, num_q, num_kv, head_dim, dtype): + """Wqkv with column codes. Layout matches the packed weight rows: + [Q(num_q) | gate(num_q) | K(num_kv) | V(num_kv)] (num_q == real num_heads). + Returns a Linear-style weight [out_features, hidden]; each row set so that + a ones() input yields output[col] == code(col).""" + regions = ( + [(Q_BASE, h) for h in range(num_q)] + + [(GATE_BASE, h) for h in range(num_q)] + + [(K_BASE, h) for h in range(num_kv)] + + [(V_BASE, h) for h in range(num_kv)] + ) + out_features = len(regions) * head_dim + w = torch.zeros(out_features, hidden, dtype=torch.float32) + col = 0 + for base, h in regions: + code = base + h + w[col : col + head_dim, :] = code / hidden # row-sum == code + col += head_dim + return w.to(dtype) + + +def decode_layout(out_row, head_dim): + """out_row: 1D tensor of length out_features. Returns list of decoded codes, + one per head block (head_dim columns), using the block's median value.""" + codes = [] + n_blocks = out_row.shape[0] // head_dim + for b in range(n_blocks): + block = out_row[b * head_dim : (b + 1) * head_dim].float() + codes.append(int(round(block.median().item()))) + return codes + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--num-heads", type=int, default=2, help="real num_heads (per rank)") + ap.add_argument("--num-kv-heads", type=int, default=1) + ap.add_argument("--head-dim", type=int, default=256) + ap.add_argument("--hidden", type=int, default=512) + ap.add_argument("--lnc", type=int, default=1) + ap.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float32"]) + args = ap.parse_args() + + dtype = getattr(torch, args.dtype) + n, kv, hd, hidden = args.num_heads, args.num_kv_heads, args.head_dim, args.hidden + + weight = build_identifiable_weight(hidden, n, kv, hd, dtype) + x = torch.ones(1, 1, hidden, dtype=dtype) + + # ----- reference: plain matmul, then the SAME split the model uses ----- + ref = (x.float() @ weight.float().t()).squeeze(0).squeeze(0) # [out_features] + ref_codes = decode_layout(ref, hd) + print("REFERENCE (contiguous weight-row order):") + print(" ", [f"{region_of(c)}{c % 100}" for c in ref_codes]) + + # ----- kernel under test: exact call from _qkv_gate_packed_projection_nki ----- + # HOST: weight orientation. The model applies transpose_parallel_linear_layer + # before handing the weight to the kernel. If the call below errors on shape, + # pass weight.t().contiguous() instead. + kernel_weight = weight + # HOST: device vs simulator. On a Trainium core the jitted call below runs + # directly. For CPU simulation use: out = _nkilib_nki.simulate_kernel( + # _nkilib_qkv, input=x, fused_qkv_weights=kernel_weight, ...same kwargs...) + packed = KERNEL[args.lnc]( + input=x, + fused_qkv_weights=kernel_weight, + output_layout=QKVOutputLayout.BSD, + bias=None, + fused_residual_add=False, + mlp_prev=None, + attention_prev=None, + fused_norm_type=NormType.NO_NORM, + gamma_norm_weights=None, + norm_eps=1e-6, + fused_rope=False, + cos_cache=None, + sin_cache=None, + quantization_type=QuantizationType.NONE, + qkv_w_scale=None, + qkv_in_scale=None, + d_head=hd, + num_q_heads=n * 2, # gate folded in as extra Q heads + num_kv_heads=kv, + ) + packed = torch.as_tensor(packed).reshape(-1) + out_codes = decode_layout(packed, hd) + print("KERNEL OUTPUT (actual order):") + print(" ", [f"{region_of(c)}{c % 100}" for c in out_codes]) + + if out_codes == ref_codes: + print("\nRESULT: layout MATCHES -> split is correct; bug is NOT head order.") + print(" Re-run with FP8 weights/scales to test the ROW-quant path.") + else: + print("\nRESULT: layout MISMATCH -> the tensor_split offsets are wrong.") + print(" Map each output block to its code above to derive the fix:") + for i, c in enumerate(out_codes): + print(f" out block {i:>2} -> {region_of(c)} head {c % 100}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_chunk_step_nki.py b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_chunk_step_nki.py new file mode 100644 index 00000000..51d0dcd4 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_chunk_step_nki.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +"""Validate and optionally profile the Qwen DeltaNet per-chunk NKI kernel. + +The reference path is CPU-only by design. Keeping reference math off the XLA +device avoids compiling extra NEFFs that obscure the NKI kernel profile. +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import math +import os +import sys +from pathlib import Path +from typing import Any + + +P_MAX = 128 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Validate/profile deltanet_chunk_step against a CPU reference." + ) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--target", default="trn2") + parser.add_argument("--lnc", type=int, default=1) + parser.add_argument("--visible-cores", default="0") + parser.add_argument("--inspect", action="store_true") + parser.add_argument("--dge", action="store_true") + parser.add_argument( + "--inspect-dir", + default="/mnt/trainium_artifacts/profiles/deltanet_chunk_step_isolated", + ) + parser.add_argument("--atol", type=float, default=2.0e-2) + parser.add_argument("--rtol", type=float, default=2.0e-2) + parser.add_argument("--value-scale", type=float, default=0.05) + parser.add_argument("--state-scale", type=float, default=0.01) + parser.add_argument("--gate-scale", type=float, default=0.01) + parser.add_argument("--fail-on-mismatch", action="store_true") + return parser.parse_args() + + +def configure_environment(args: argparse.Namespace) -> Path: + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", args.target) + os.environ.setdefault("NEURON_CC_FLAGS", f"--target {args.target} --lnc {args.lnc}") + os.environ.setdefault("NEURON_RT_VISIBLE_CORES", args.visible_cores) + + inspect_dir = Path(args.inspect_dir).expanduser().resolve() + if args.inspect: + inspect_dir.mkdir(parents=True, exist_ok=True) + os.environ["NEURON_RT_INSPECT_ENABLE"] = "1" + os.environ["NEURON_RT_INSPECT_DEVICE_PROFILE"] = "1" + os.environ["NEURON_RT_INSPECT_SYSTEM_PROFILE"] = "0" + os.environ["NEURON_RT_INSPECT_OUTPUT_DIR"] = str(inspect_dir) + os.environ["XLA_IR_DEBUG"] = "1" + os.environ["XLA_HLO_DEBUG"] = "1" + os.environ["NEURON_FRAMEWORK_DEBUG"] = "1" + if args.dge: + os.environ["NEURON_RT_ENABLE_DGE_NOTIFICATIONS"] = "1" + return inspect_dir + + +def add_qwen_to_path() -> None: + script_path = Path(__file__).resolve() + qwen_root = script_path.parents[1] + sys.path.insert(0, str(qwen_root)) + + +def load_chunked_kernel(): + kernel_path = ( + Path(__file__).resolve().parents[1] + / "src" + / "nki_kernels" + / "nki_deltanet_chunked.py" + ) + spec = importlib.util.spec_from_file_location( + "qwen36_nki_deltanet_chunked_under_test", + kernel_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module.deltanet_chunk_step + + +def make_inputs(torch: Any, args: argparse.Namespace) -> dict[str, Any]: + generator = torch.Generator(device="cpu") + generator.manual_seed(args.seed) + + def randn(scale: float) -> Any: + return torch.randn((P_MAX, P_MAX), generator=generator, dtype=torch.float32) * scale + + query = randn(args.value_scale) + key = randn(args.value_scale) + value = randn(args.value_scale) + state_in = randn(args.state_scale) + + query = torch.nn.functional.normalize(query, p=2, dim=-1) / math.sqrt(P_MAX) + key = torch.nn.functional.normalize(key, p=2, dim=-1) + + beta_col = torch.sigmoid( + torch.randn((P_MAX, 1), generator=generator, dtype=torch.float32) + ) + beta_broadcast = beta_col.expand(P_MAX, P_MAX).contiguous() + + # Qwen GDN decay gates are negative log-decays. Keep them small enough that + # reference comparisons isolate algorithmic errors rather than overflow. + g_raw = -torch.nn.functional.softplus( + torch.randn((P_MAX, 1), generator=generator, dtype=torch.float32) + ) + g_raw = g_raw * args.gate_scale + g_cumsum_col = torch.cumsum(g_raw, dim=0) + g_cumsum = g_cumsum_col.expand(P_MAX, P_MAX).contiguous() + g_last = g_cumsum_col[-1:].expand(P_MAX, P_MAX).contiguous() + + lower_mask = torch.tril(torch.ones((P_MAX, P_MAX), dtype=torch.float32), diagonal=-1) + lower_mask_diag = torch.tril(torch.ones((P_MAX, P_MAX), dtype=torch.float32)) + identity = torch.eye(P_MAX, dtype=torch.float32) + + return { + "query": query.contiguous(), + "key": key.contiguous(), + "value": value.contiguous(), + "beta_broadcast": beta_broadcast, + "g_cumsum": g_cumsum, + "g_last": g_last, + "state_in": state_in.contiguous(), + "lower_mask": lower_mask.contiguous(), + "identity": identity.contiguous(), + "lower_mask_diag": lower_mask_diag.contiguous(), + } + + +def reference_math(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any, Any]: + q = inputs["query"] + k = inputs["key"] + v = inputs["value"] + beta = inputs["beta_broadcast"] + gc = inputs["g_cumsum"][:, 0:1] + gl = inputs["g_last"][:, 0:1] + state = inputs["state_in"] + lower = inputs["lower_mask"] + lower_diag = inputs["lower_mask_diag"] + eye = inputs["identity"] + + k_beta = k * beta + v_beta = v * beta + + decay = torch.exp(gc - gc.T) + decay_strict = decay * lower + decay_diag = decay * lower_diag + + qk_beta = k_beta @ k.T + a_mat = -(qk_beta * decay_strict) + a_mat = a_mat * lower + + # Intended kernel math: A is strictly lower triangular, N = inv(I - A). + lhs = eye - a_mat + n_mat = torch.linalg.solve_triangular(lhs, eye, upper=False) + + exp_gc = torch.exp(gc) + value_corr = n_mat @ v_beta + k_cumdecay = n_mat @ (k_beta * exp_gc) + + qk_raw = q @ k.T + attn_intra = qk_raw * decay_diag + + v_prime = k_cumdecay @ state + v_new = value_corr - v_prime + + attn_inter = (q * exp_gc) @ state + intra_out = attn_intra @ v_new + output = attn_inter + intra_out + + exp_gl_minus_gc = torch.exp(gl - gc) + k_raw_decay = k * exp_gl_minus_gc + kv_outer = k_raw_decay.T @ v_new + state_out = state * torch.exp(gl) + kv_outer + + return output.contiguous(), state_out.contiguous(), n_mat.contiguous() + + +def reference_kernel_mirror(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any, Any]: + q = inputs["query"] + k = inputs["key"] + v = inputs["value"] + beta = inputs["beta_broadcast"] + gc = inputs["g_cumsum"][:, 0:1] + gl = inputs["g_last"][:, 0:1] + state = inputs["state_in"] + lower = inputs["lower_mask"] + lower_diag = inputs["lower_mask_diag"] + eye = inputs["identity"] + + k_beta = k * beta + v_beta = v * beta + + decay = torch.exp(gc - gc.T) + qk = k_beta @ k.T + a_mat = -(qk * decay * lower) * lower + + p_acc = eye + a_mat + a_pow = a_mat.clone() + for _ in range(6): + a_pow = (a_pow @ a_pow) * lower + p_acc = ((eye + a_pow) @ p_acc) * lower_diag + + exp_gc = torch.exp(gc) + value_corr = p_acc @ v_beta + k_cumdecay = p_acc @ (k_beta * exp_gc) + attn_intra = (q @ k.T) * (decay * lower_diag) + v_new = value_corr - (k_cumdecay @ state) + output = ((q * exp_gc) @ state) + (attn_intra @ v_new) + + k_raw_decay = k * torch.exp(gl - gc) + state_out = state * torch.exp(gl) + (k_raw_decay.T @ v_new) + + return output.contiguous(), state_out.contiguous(), p_acc.contiguous() + + +def tensor_metrics(torch: Any, actual: Any, expected: Any) -> dict[str, float | bool]: + diff = actual - expected + expected_norm = torch.linalg.vector_norm(expected).item() + diff_norm = torch.linalg.vector_norm(diff).item() + actual_flat = actual.reshape(-1).to(torch.float64) + expected_flat = expected.reshape(-1).to(torch.float64) + denom = torch.linalg.vector_norm(actual_flat) * torch.linalg.vector_norm(expected_flat) + cosine = ( + float(torch.dot(actual_flat, expected_flat) / denom) + if denom.item() != 0.0 + else float("nan") + ) + return { + "finite": bool(torch.isfinite(actual).all().item()), + "max_abs": float(torch.max(torch.abs(diff)).item()), + "mean_abs": float(torch.mean(torch.abs(diff)).item()), + "diff_norm": float(diff_norm), + "expected_norm": float(expected_norm), + "relative_norm": float(diff_norm / max(expected_norm, 1.0e-12)), + "cosine": cosine, + } + + +def main() -> int: + args = parse_args() + inspect_dir = configure_environment(args) + add_qwen_to_path() + + import torch + import torch_xla.core.xla_model as xm + + deltanet_chunk_step = load_chunked_kernel() + + inputs = make_inputs(torch, args) + math_out, math_state, math_n = reference_math(torch, inputs) + mirror_out, mirror_state, mirror_n = reference_kernel_mirror(torch, inputs) + + mirror_vs_math = { + "output": tensor_metrics(torch, mirror_out, math_out), + "state": tensor_metrics(torch, mirror_state, math_state), + "n_matrix": tensor_metrics(torch, mirror_n, math_n), + } + + device = xm.xla_device() + xla_inputs = {name: tensor.to(device=device) for name, tensor in inputs.items()} + + out_cpu = state_cpu = None + for _ in range(args.runs): + out_dev, state_dev = deltanet_chunk_step( + xla_inputs["query"], + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["beta_broadcast"], + xla_inputs["g_cumsum"], + xla_inputs["g_last"], + xla_inputs["state_in"], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + state_cpu = state_dev.detach().cpu().float() + + assert out_cpu is not None + assert state_cpu is not None + + nki_vs_math = { + "output": tensor_metrics(torch, out_cpu, math_out), + "state": tensor_metrics(torch, state_cpu, math_state), + } + nki_vs_mirror = { + "output": tensor_metrics(torch, out_cpu, mirror_out), + "state": tensor_metrics(torch, state_cpu, mirror_state), + } + + output_close = torch.allclose(out_cpu, math_out, atol=args.atol, rtol=args.rtol) + state_close = torch.allclose(state_cpu, math_state, atol=args.atol, rtol=args.rtol) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + state_finite = bool(torch.isfinite(state_cpu).all().item()) + passed = bool(output_close and state_close and output_finite and state_finite) + + result = { + "passed": passed, + "seed": args.seed, + "runs": args.runs, + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "output_finite": output_finite, + "state_finite": state_finite, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + ) + }, + "mirror_vs_math": mirror_vs_math, + "nki_vs_math": nki_vs_math, + "nki_vs_mirror": nki_vs_mirror, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py new file mode 100644 index 00000000..f8030613 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py @@ -0,0 +1,2638 @@ +#!/usr/bin/env python3 +"""Validate and optionally inspect/profile the fused Qwen DeltaNet NKI kernel. + +The CPU reference stays off the XLA device so the generated NEFFs are from the +NKI kernel under test, not from reference PyTorch ops. +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import math +import os +import sys +import time +from pathlib import Path +from typing import Any + + +P_MAX = 128 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Validate/profile deltanet_fused_chunked_fwd against CPU math." + ) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--seq-len", type=int, default=256) + parser.add_argument( + "--chunk-size", + type=int, + default=int(os.environ.get("QWEN36_DELTANET_CHUNK_SIZE", "128")), + choices=(64, 128), + help=( + "Active fused GDN CTE token chunk size. 128 is the existing path; " + "64 is the FlashQLA-inspired probe path." + ), + ) + parser.add_argument("--heads", type=int, default=2) + parser.add_argument( + "--multihead", + action="store_true", + help="Validate deltanet_fused_chunked_fwd_multihead with one grid program per head.", + ) + parser.add_argument( + "--validate-cpu-chunk-invariance", + action="store_true", + help=( + "Run a CPU-only check that exact 64-token and 128-token chunked " + "DeltaNet references produce equivalent outputs/states." + ), + ) + parser.add_argument( + "--validate-restored-suffix-carry", + action="store_true", + help=( + "Validate the serving-style GDN carry boundary: run one full padded " + "sequence and compare it with restored calls over split CTE buckets." + ), + ) + parser.add_argument( + "--restore-split-lens", + default="512,512,201", + help=( + "Comma-separated real token counts for " + "--validate-restored-suffix-carry. Each segment is padded to " + "--restore-bucket-size before the next restored call." + ), + ) + parser.add_argument( + "--restore-bucket-size", + type=int, + default=512, + help="Per-call padded CTE bucket size for --validate-restored-suffix-carry.", + ) + parser.add_argument( + "--validate-autocp-affine", + action="store_true", + help=( + "Validate the isolated one-chunk AutoCP affine-piece NKI probe " + "instead of the fused forward kernel." + ), + ) + parser.add_argument( + "--validate-autocp-prefix", + action="store_true", + help=( + "Validate the isolated AutoCP state-prefix NKI probe over all " + "128-token chunks in --seq-len." + ), + ) + parser.add_argument( + "--validate-autocp-chain", + action="store_true", + help=( + "Validate the isolated AutoCP prefix plus output-apply NKI chain " + "against the CPU AutoCP reference." + ), + ) + parser.add_argument( + "--validate-autocp-prefix-apply", + action="store_true", + help=( + "Validate the fused AutoCP state-prefix/output-apply NKI pass " + "against CPU affine stacks." + ), + ) + parser.add_argument( + "--validate-autocp-full", + action="store_true", + help=( + "Validate NKI chunk-parallel affine generation plus prefix/apply " + "against the CPU AutoCP reference." + ), + ) + parser.add_argument( + "--validate-autocp-state-summary", + action="store_true", + help=( + "Validate compact AutoCP NKI state-summary generation against " + "CPU segment state transforms." + ), + ) + parser.add_argument( + "--validate-autocp-compact-chain", + action="store_true", + help=( + "Validate compact AutoCP summary + state-prefix + recurrent segment " + "replay against the CPU sequential reference." + ), + ) + parser.add_argument( + "--validate-compact-autocp-reference", + action="store_true", + help=( + "Run a CPU-only check for compact AutoCP state summaries: compose " + "chunk transforms into per-segment state transforms, prefix segment " + "states, then replay each segment recurrently." + ), + ) + parser.add_argument( + "--autocp-cp-chunks", + type=int, + default=4, + help="Number of 128-token chunks per compact AutoCP segment.", + ) + parser.add_argument( + "--head-group-size", + type=int, + default=1, + help=( + "Number of flattened (batch, head) rows per multihead NKI launch. " + "Use larger values to test launch-grid batching." + ), + ) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--target", default="trn2") + parser.add_argument("--lnc", type=int, default=1) + parser.add_argument("--visible-cores", default="0") + parser.add_argument("--inspect", action="store_true") + parser.add_argument("--dge", action="store_true") + parser.add_argument( + "--inspect-dir", + default="/mnt/trainium_artifacts/profiles/deltanet_fused_isolated", + ) + parser.add_argument("--atol", type=float, default=3.0e-2) + parser.add_argument("--rtol", type=float, default=3.0e-2) + parser.add_argument("--value-scale", type=float, default=0.05) + parser.add_argument("--state-scale", type=float, default=0.01) + parser.add_argument("--gate-scale", type=float, default=0.01) + parser.add_argument("--fail-on-mismatch", action="store_true") + return parser.parse_args() + + +def configure_environment(args: argparse.Namespace) -> Path: + chunk_size = int(getattr(args, "chunk_size", P_MAX)) + args.chunk_size = chunk_size + if args.seq_len <= 0 or args.seq_len % chunk_size != 0: + raise ValueError( + "--seq-len must be a positive multiple of --chunk-size; " + f"got seq_len={args.seq_len}, chunk_size={chunk_size}" + ) + autocp_modes = ( + getattr(args, "validate_autocp_affine", False) + or getattr(args, "validate_autocp_prefix", False) + or getattr(args, "validate_autocp_chain", False) + or getattr(args, "validate_autocp_prefix_apply", False) + or getattr(args, "validate_autocp_full", False) + or getattr(args, "validate_autocp_state_summary", False) + or getattr(args, "validate_autocp_compact_chain", False) + or getattr(args, "validate_compact_autocp_reference", False) + ) + if autocp_modes and chunk_size != P_MAX: + raise ValueError( + "AutoCP validators are still 128-token chunk probes; " + f"got --chunk-size={chunk_size}" + ) + if args.multihead and args.heads <= 0: + raise ValueError("--heads must be positive when --multihead is set") + if args.head_group_size <= 0: + raise ValueError("--head-group-size must be positive") + if getattr(args, "validate_autocp_full", False) and (args.seq_len // P_MAX) % 2 != 0: + raise ValueError( + "--validate-autocp-full uses LNC2-striped affine generation and " + "requires an even 128-token chunk count; " + f"got seq_len={args.seq_len}, chunks={args.seq_len // P_MAX}" + ) + cp_chunks = int(getattr(args, "autocp_cp_chunks", 4)) + args.autocp_cp_chunks = cp_chunks + if cp_chunks <= 0: + raise ValueError("--autocp-cp-chunks must be positive") + compact_nki_modes = ( + getattr(args, "validate_autocp_state_summary", False) + or getattr(args, "validate_autocp_compact_chain", False) + ) + if compact_nki_modes: + num_chunks = args.seq_len // P_MAX + if num_chunks % cp_chunks != 0: + raise ValueError( + "Compact AutoCP NKI validators require the 128-token chunk " + "count to be divisible by --autocp-cp-chunks; " + f"chunks={num_chunks}, cp_chunks={cp_chunks}" + ) + num_segments = num_chunks // cp_chunks + if num_segments % args.lnc != 0: + raise ValueError( + "--validate-autocp-state-summary requires the segment count " + "to be divisible by --lnc; " + f"segments={num_segments}, lnc={args.lnc}" + ) + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", args.target) + os.environ.setdefault("NEURON_CC_FLAGS", f"--target {args.target} --lnc {args.lnc}") + os.environ.setdefault("NEURON_RT_VISIBLE_CORES", args.visible_cores) + os.environ["QWEN36_DELTANET_CHUNK_SIZE"] = str(chunk_size) + os.environ["QWEN36_DELTANET_AUTOCP_CP_CHUNKS"] = str(cp_chunks) + + inspect_dir = Path(args.inspect_dir).expanduser().resolve() + if args.inspect: + inspect_dir.mkdir(parents=True, exist_ok=True) + os.environ["NEURON_RT_INSPECT_ENABLE"] = "1" + os.environ["NEURON_RT_INSPECT_DEVICE_PROFILE"] = "1" + os.environ["NEURON_RT_INSPECT_SYSTEM_PROFILE"] = "0" + os.environ["NEURON_RT_INSPECT_OUTPUT_DIR"] = str(inspect_dir) + os.environ["XLA_IR_DEBUG"] = "1" + os.environ["XLA_HLO_DEBUG"] = "1" + os.environ["NEURON_FRAMEWORK_DEBUG"] = "1" + if args.dge: + os.environ["NEURON_RT_ENABLE_DGE_NOTIFICATIONS"] = "1" + return inspect_dir + + +def add_qwen_to_path() -> None: + script_path = Path(__file__).resolve() + qwen_root = script_path.parents[1] + sys.path.insert(0, str(qwen_root)) + + +def load_kernel_module(): + kernel_path = ( + Path(__file__).resolve().parents[1] + / "src" + / "nki_kernels" + / "nki_deltanet_fused.py" + ) + spec = importlib.util.spec_from_file_location( + "qwen36_nki_deltanet_fused_under_test", + kernel_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def load_fused_kernel(multihead: bool): + module = load_kernel_module() + if multihead: + return module.deltanet_fused_chunked_fwd_multihead + return module.deltanet_fused_chunked_fwd + + +def load_autocp_affine_kernel(): + return load_kernel_module().deltanet_autocp_affine_chunk + + +def load_autocp_affine_sequence_kernel(): + return load_kernel_module().deltanet_autocp_affine_sequence + + +def load_autocp_state_summary_kernel(): + return load_kernel_module().deltanet_autocp_state_summary_sequence + + +def load_autocp_prefix_kernel(): + return load_kernel_module().deltanet_autocp_state_prefix + + +def load_autocp_apply_kernel(): + return load_kernel_module().deltanet_autocp_apply_output + + +def load_autocp_prefix_apply_kernel(): + return load_kernel_module().deltanet_autocp_prefix_apply_output + + +def multihead_launch_spec(num_heads: int, lnc: int): + if num_heads <= lnc: + return num_heads + if os.environ.get("QWEN36_DELTANET_MULTIHEAD_SPMD", "1") == "0": + raise ValueError( + "--head-group-size exceeds --lnc but " + "QWEN36_DELTANET_MULTIHEAD_SPMD=0; " + f"group_size={num_heads}, lnc={lnc}" + ) + + import nki.language as nl + + if not hasattr(nl, "spmd_dim") or not hasattr(nl, "nc"): + if os.environ.get("QWEN36_DELTANET_MULTIHEAD_GRID_FALLBACK", "0") == "1": + return (num_heads, 1) + raise ValueError( + "--head-group-size exceeds --lnc, but this NKI runtime does not " + f"expose spmd_dim/nc; group_size={num_heads}, lnc={lnc}" + ) + return (nl.spmd_dim(num_heads, nl.nc(lnc)),) + + +def launch_spec_label(spec: Any) -> str: + if isinstance(spec, int): + return str(spec) + return repr(spec) + + +def make_inputs(torch: Any, args: argparse.Namespace) -> dict[str, Any]: + generator = torch.Generator(device="cpu") + generator.manual_seed(args.seed) + + def randn(shape: tuple[int, ...], scale: float) -> Any: + return torch.randn(shape, generator=generator, dtype=torch.float32) * scale + + multihead = bool(getattr(args, "multihead", False)) + heads = int(getattr(args, "heads", 1)) + chunk_size = int(getattr(args, "chunk_size", P_MAX)) + prefix_shape = (heads,) if multihead else () + query = randn((*prefix_shape, args.seq_len, P_MAX), args.value_scale) + key = randn((*prefix_shape, args.seq_len, P_MAX), args.value_scale) + value = randn((*prefix_shape, args.seq_len, P_MAX), args.value_scale) + state_in = randn((*prefix_shape, P_MAX, P_MAX), args.state_scale) + + beta = torch.sigmoid(randn((*prefix_shape, args.seq_len, 1), 1.0)) + g_raw = -torch.nn.functional.softplus(randn((*prefix_shape, args.seq_len, 1), 1.0)) + g_raw = g_raw * args.gate_scale + + lower_mask = torch.zeros((P_MAX, P_MAX), dtype=torch.float32) + lower_mask_diag = torch.zeros((P_MAX, P_MAX), dtype=torch.float32) + identity = torch.zeros((P_MAX, P_MAX), dtype=torch.float32) + lower_mask[:chunk_size, :chunk_size] = torch.tril( + torch.ones((chunk_size, chunk_size), dtype=torch.float32), + diagonal=-1, + ) + lower_mask_diag[:chunk_size, :chunk_size] = torch.tril( + torch.ones((chunk_size, chunk_size), dtype=torch.float32) + ) + identity[:chunk_size, :chunk_size] = torch.eye( + chunk_size, + dtype=torch.float32, + ) + + return { + "chunk_size": chunk_size, + "query": query.contiguous(), + "key": key.contiguous(), + "value": value.contiguous(), + "g_raw": g_raw.contiguous(), + "beta": beta.contiguous(), + "state_in": state_in.contiguous(), + "lower_mask": lower_mask.contiguous(), + "identity": identity.contiguous(), + "lower_mask_diag": lower_mask_diag.contiguous(), + } + + +def stable_causal_decay(torch: Any, gc: Any, mask: Any) -> Any: + """Compute exp(gc[i] - gc[j]) only where the causal mask is active.""" + diff = gc - gc.T + masked_diff = torch.where(mask.bool(), diff, torch.zeros_like(diff)) + return torch.exp(masked_diff) * mask + + +def move_tensor_inputs_to_device(inputs: dict[str, Any], device: Any) -> dict[str, Any]: + return { + name: tensor.to(device=device) + for name, tensor in inputs.items() + if hasattr(tensor, "to") + } + + +def normalize_reference_qk(torch: Any, query: Any, key: Any) -> tuple[Any, Any]: + """Match the fused kernel's in-kernel Q/K l2norm and Q scale.""" + query_norm = torch.nn.functional.normalize( + query, + p=2, + dim=-1, + eps=1.0e-6, + ) / math.sqrt(P_MAX) + key_norm = torch.nn.functional.normalize( + key, + p=2, + dim=-1, + eps=1.0e-6, + ) + return query_norm, key_norm + + +def blocked_lower_triangular_solve( + torch: Any, + lhs: Any, + rhs: Any, + block_size: int, +) -> Any: + """Solve lower-triangular lhs @ x = rhs by block forward substitution.""" + if block_size <= 0: + raise ValueError("block_size must be positive") + rows, cols = lhs.shape + if rows != cols: + raise ValueError(f"lhs must be square, got {lhs.shape}") + if rhs.shape[0] != rows: + raise ValueError( + f"rhs row count must match lhs, got lhs={lhs.shape} rhs={rhs.shape}" + ) + + solved_blocks = [] + for row_start in range(0, rows, block_size): + row_end = min(row_start + block_size, rows) + rhs_block = rhs[row_start:row_end].clone() + for col_start, solved in zip( + range(0, row_start, block_size), + solved_blocks, + ): + col_end = min(col_start + block_size, rows) + lhs_block = lhs[row_start:row_end, col_start:col_end] + rhs_block = rhs_block - lhs_block @ solved + + diag_block = lhs[row_start:row_end, row_start:row_end] + solved_block = torch.linalg.solve_triangular( + diag_block, + rhs_block, + upper=False, + ) + solved_blocks.append(solved_block) + + return torch.cat(solved_blocks, dim=0) + + +def block_prefix_lower_triangular_solve( + torch: Any, + lhs: Any, + rhs: Any, + block_size: int, +) -> Any: + """Solve lhs @ x = rhs with FLA-style block affine segment combines. + + Each diagonal block first produces an affine map from all preceding rows + to the block output: + + x_i = solve(L_ii, rhs_i) - solve(L_ii, L_i, 1: + merged = [] + for idx in range(0, len(segments), 2): + if idx + 1 == len(segments): + merged.append(segments[idx]) + continue + + left = segments[idx] + right = segments[idx + 1] + if left["end"] != right["start"]: + raise ValueError( + "segments must be contiguous, got " + f"{left['start']}:{left['end']} and " + f"{right['start']}:{right['end']}" + ) + + ext_width = left["start"] + left_width = left["end"] - left["start"] + right_transfer = right["transfer"] + right_ext = right_transfer[:, 0:ext_width] + right_left = right_transfer[:, ext_width : ext_width + left_width] + + merged_right_rhs = right["rhs"] + right_left @ left["rhs"] + if ext_width == 0: + merged_transfer = lhs.new_empty( + (left["end"] - left["start"] + right["end"] - right["start"], 0) + ) + else: + merged_right_transfer = right_ext + right_left @ left["transfer"] + merged_transfer = torch.cat( + [left["transfer"], merged_right_transfer], + dim=0, + ) + + merged.append( + { + "start": left["start"], + "end": right["end"], + "rhs": torch.cat([left["rhs"], merged_right_rhs], dim=0), + "transfer": merged_transfer, + } + ) + segments = merged + + return segments[0]["rhs"] + + +def _hierarchical_kkt_inverse_transpose( + torch: Any, + a_t: Any, + leaf_size: int, +) -> Any: + """Build ``inv(I - A).T`` using FlashQLA-style block combines.""" + rows, cols = a_t.shape + if rows != cols: + raise ValueError(f"a_t must be square, got {a_t.shape}") + if rows <= leaf_size: + steps = math.ceil(math.log2(rows)) + inv_t = torch.eye(rows, dtype=a_t.dtype, device=a_t.device) + power_t = a_t.clone() + power = power_t.T.contiguous() + for step_idx in range(steps): + inv_t = inv_t + inv_t @ power_t + if step_idx != steps - 1: + power = power @ power + power_t = power.T.contiguous() + return inv_t + + if rows % 2 != 0: + raise ValueError(f"rows must split evenly, got {rows}") + mid = rows // 2 + left_t = _hierarchical_kkt_inverse_transpose( + torch, + a_t[0:mid, 0:mid], + leaf_size, + ) + right_t = _hierarchical_kkt_inverse_transpose( + torch, + a_t[mid:rows, mid:rows], + leaf_size, + ) + a_cross_t = a_t[0:mid, mid:rows] + cross_t = left_t @ a_cross_t @ right_t + + inv_t = torch.zeros_like(a_t) + inv_t[0:mid, 0:mid] = left_t + inv_t[0:mid, mid:rows] = cross_t + inv_t[mid:rows, mid:rows] = right_t + return inv_t + + +def hierarchical_kkt_lower_triangular_solve( + torch: Any, + lhs: Any, + rhs: Any, + leaf_size: int = 16, +) -> Any: + """Solve ``lhs @ x = rhs`` with FlashQLA-style KKT hierarchy. + + FlashQLA's Hopper KKT kernel builds the intra-chunk triangular inverse + through small diagonal inversions and block combines. This CPU reference + mirrors that algebra: for ``lhs = I - A`` with strictly lower ``A``, build + ``N.T = inv(I - A).T`` recursively, then compute ``x = N @ rhs``. + """ + if leaf_size <= 0: + raise ValueError("leaf_size must be positive") + rows, cols = lhs.shape + if rows != cols: + raise ValueError(f"lhs must be square, got {lhs.shape}") + if rhs.shape[0] != rows: + raise ValueError( + f"rhs row count must match lhs, got lhs={lhs.shape} rhs={rhs.shape}" + ) + if rows % leaf_size != 0 or leaf_size & (leaf_size - 1) != 0: + raise ValueError( + "leaf_size must be a power of two that divides the row count; " + f"got rows={rows}, leaf_size={leaf_size}" + ) + + eye = torch.eye(rows, dtype=lhs.dtype, device=lhs.device) + a_t = (eye - lhs).T.contiguous() + inv_t = _hierarchical_kkt_inverse_transpose(torch, a_t, leaf_size) + return inv_t.T @ rhs + + +def scan_doubling_lower_triangular_solve( + torch: Any, + lhs: Any, + rhs: Any, + steps: int, +) -> Any: + """Solve/approximate lower-triangular lhs @ x = rhs by Neumann doubling.""" + if steps <= 0: + raise ValueError("steps must be positive") + rows, cols = lhs.shape + if rows != cols: + raise ValueError(f"lhs must be square, got {lhs.shape}") + if rhs.shape[0] != rows: + raise ValueError( + f"rhs row count must match lhs, got lhs={lhs.shape} rhs={rhs.shape}" + ) + + eye = torch.eye(rows, dtype=lhs.dtype, device=lhs.device) + power = eye - lhs + solved = rhs.clone() + for scan_idx in range(steps): + solved = solved + power @ solved + if scan_idx != steps - 1: + power = power @ power + return solved + + +def reference_math_one_head(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any]: + chunk_size = int(inputs.get("chunk_size", P_MAX)) + lower = inputs["lower_mask"][0:chunk_size, 0:chunk_size] + lower_diag = inputs["lower_mask_diag"][0:chunk_size, 0:chunk_size] + eye = inputs["identity"][0:chunk_size, 0:chunk_size] + state = inputs["state_in"].clone() + outputs = [] + + for start in range(0, inputs["query"].shape[0], chunk_size): + end = start + chunk_size + q, k = normalize_reference_qk( + torch, + inputs["query"][start:end], + inputs["key"][start:end], + ) + v = inputs["value"][start:end] + g = inputs["g_raw"][start:end] + beta = inputs["beta"][start:end] + + gc = torch.cumsum(g, dim=0) + gl = gc[-1:] + k_beta = k * beta + v_beta = v * beta + + decay_strict = stable_causal_decay(torch, gc, lower) + decay_diag = stable_causal_decay(torch, gc, lower_diag) + + qk_beta = k_beta @ k.T + a_mat = -(qk_beta * decay_strict) * lower + + lhs = eye - a_mat + + exp_gc = torch.exp(gc) + solve_rhs = v_beta - ((k_beta * exp_gc) @ state) + v_new = torch.linalg.solve_triangular(lhs, solve_rhs, upper=False) + attn_intra = (q @ k.T) * decay_diag + + chunk_out = ((q * exp_gc) @ state) + (attn_intra @ v_new) + outputs.append(chunk_out) + + k_raw_decay = k * torch.exp(gl - gc) + state = (state * torch.exp(gl)) + (k_raw_decay.T @ v_new) + + return torch.cat(outputs, dim=0).contiguous(), state.contiguous() + + +def deltanet_chunk_affine_parts( + torch: Any, + inputs: dict[str, Any], + start: int, +) -> dict[str, Any]: + """Return FlashQLA-style per-chunk affine pieces independent of state. + + For each chunk, DeltaNet can be represented as: + + output_i = output_base_i + output_state_i @ state_i + state_{i+1} = state_matrix_i @ state_i + state_bias_i + + The current fused NKI path computes this implicitly while carrying + ``state`` through chunks. AutoCP needs these pieces first, then performs a + prefix over ``(state_matrix, state_bias)`` to recover chunk initial states. + """ + end = start + P_MAX + lower = inputs["lower_mask"] + lower_diag = inputs["lower_mask_diag"] + eye = inputs["identity"] + + q, k = normalize_reference_qk( + torch, + inputs["query"][start:end], + inputs["key"][start:end], + ) + v = inputs["value"][start:end] + g = inputs["g_raw"][start:end] + beta = inputs["beta"][start:end] + + gc = torch.cumsum(g, dim=0) + gl = gc[-1:] + exp_gc = torch.exp(gc) + exp_gl = torch.exp(gl).reshape(()) + k_beta = k * beta + v_beta = v * beta + + decay_strict = stable_causal_decay(torch, gc, lower) + decay_diag = stable_causal_decay(torch, gc, lower_diag) + + qk_beta = k_beta @ k.T + a_mat = -(qk_beta * decay_strict) * lower + lhs = eye - a_mat + + value_u = torch.linalg.solve_triangular(lhs, v_beta, upper=False) + state_w = torch.linalg.solve_triangular(lhs, k_beta * exp_gc, upper=False) + attn_intra = (q @ k.T) * decay_diag + k_raw_decay = k * torch.exp(gl - gc) + + output_base = attn_intra @ value_u + output_state = (q * exp_gc) - (attn_intra @ state_w) + state_matrix = (exp_gl * eye) - (k_raw_decay.T @ state_w) + state_bias = k_raw_decay.T @ value_u + + return { + "output_base": output_base, + "output_state": output_state, + "state_matrix": state_matrix, + "state_bias": state_bias, + } + + +def apply_deltanet_chunk_affine( + torch: Any, + parts: dict[str, Any], + state: Any, +) -> tuple[Any, Any]: + output = parts["output_base"] + (parts["output_state"] @ state) + next_state = (parts["state_matrix"] @ state) + parts["state_bias"] + return output, next_state + + +def compose_deltanet_state_affine( + torch: Any, + first: dict[str, Any], + second: dict[str, Any], +) -> dict[str, Any]: + """Compose two state transforms, applying ``first`` then ``second``.""" + matrix = second["state_matrix"] @ first["state_matrix"] + bias = (second["state_matrix"] @ first["state_bias"]) + second["state_bias"] + return {"state_matrix": matrix, "state_bias": bias} + + +def autocp_reference_math_one_head( + torch: Any, + inputs: dict[str, Any], + cp_chunks: int = 4, +) -> tuple[Any, Any]: + """Reference FlashQLA-style grouped context-parallel state prepass. + + ``cp_chunks`` is the number of 128-token chunks in each local CP segment. + This CPU reference still executes serially, but its dataflow is the one we + need for a NKI port: build state-independent chunk transforms, combine them + per segment, then run each segment from a corrected initial state. + """ + if cp_chunks <= 0: + raise ValueError("cp_chunks must be positive") + seq_len = inputs["query"].shape[0] + if seq_len % P_MAX != 0: + raise ValueError(f"seq_len must be divisible by {P_MAX}, got {seq_len}") + + parts = [ + deltanet_chunk_affine_parts(torch, inputs, start) + for start in range(0, seq_len, P_MAX) + ] + + eye = inputs["identity"] + zero_state = torch.zeros_like(inputs["state_in"]) + group_transforms = [] + for group_start in range(0, len(parts), cp_chunks): + group_transform = {"state_matrix": eye, "state_bias": zero_state} + for chunk_parts in parts[group_start : group_start + cp_chunks]: + group_transform = compose_deltanet_state_affine( + torch, + group_transform, + chunk_parts, + ) + group_transforms.append(group_transform) + + group_initial_states = [] + state = inputs["state_in"].clone() + for transform in group_transforms: + group_initial_states.append(state) + state = (transform["state_matrix"] @ state) + transform["state_bias"] + + outputs = [] + for group_idx, group_start in enumerate(range(0, len(parts), cp_chunks)): + state = group_initial_states[group_idx] + for chunk_parts in parts[group_start : group_start + cp_chunks]: + output, state = apply_deltanet_chunk_affine(torch, chunk_parts, state) + outputs.append(output) + + final_state = state + return torch.cat(outputs, dim=0).contiguous(), final_state.contiguous() + + +def blocked_reference_math_one_head( + torch: Any, + inputs: dict[str, Any], + block_size: int = 16, +) -> tuple[Any, Any]: + lower = inputs["lower_mask"] + lower_diag = inputs["lower_mask_diag"] + eye = inputs["identity"] + state = inputs["state_in"].clone() + outputs = [] + + for start in range(0, inputs["query"].shape[0], P_MAX): + end = start + P_MAX + q, k = normalize_reference_qk( + torch, + inputs["query"][start:end], + inputs["key"][start:end], + ) + v = inputs["value"][start:end] + g = inputs["g_raw"][start:end] + beta = inputs["beta"][start:end] + + gc = torch.cumsum(g, dim=0) + gl = gc[-1:] + k_beta = k * beta + v_beta = v * beta + + decay_strict = stable_causal_decay(torch, gc, lower) + decay_diag = stable_causal_decay(torch, gc, lower_diag) + + qk_beta = k_beta @ k.T + a_mat = -(qk_beta * decay_strict) * lower + lhs = eye - a_mat + + exp_gc = torch.exp(gc) + solve_rhs = v_beta - ((k_beta * exp_gc) @ state) + v_new = blocked_lower_triangular_solve( + torch, + lhs, + solve_rhs, + block_size, + ) + attn_intra = (q @ k.T) * decay_diag + + chunk_out = ((q * exp_gc) @ state) + (attn_intra @ v_new) + outputs.append(chunk_out) + + k_raw_decay = k * torch.exp(gl - gc) + state = (state * torch.exp(gl)) + (k_raw_decay.T @ v_new) + + return torch.cat(outputs, dim=0).contiguous(), state.contiguous() + + +def reference_math(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any]: + if inputs["query"].dim() == 2: + return reference_math_one_head(torch, inputs) + + outputs = [] + states = [] + for head_idx in range(inputs["query"].shape[0]): + head_inputs = { + "chunk_size": inputs.get("chunk_size", P_MAX), + "query": inputs["query"][head_idx], + "key": inputs["key"][head_idx], + "value": inputs["value"][head_idx], + "g_raw": inputs["g_raw"][head_idx], + "beta": inputs["beta"][head_idx], + "state_in": inputs["state_in"][head_idx], + "lower_mask": inputs["lower_mask"], + "identity": inputs["identity"], + "lower_mask_diag": inputs["lower_mask_diag"], + } + out, state = reference_math_one_head(torch, head_inputs) + outputs.append(out) + states.append(state) + + return torch.stack(outputs, dim=0).contiguous(), torch.stack(states, dim=0).contiguous() + + +def blocked_reference_math( + torch: Any, + inputs: dict[str, Any], + block_size: int = 16, +) -> tuple[Any, Any]: + if inputs["query"].dim() == 2: + return blocked_reference_math_one_head(torch, inputs, block_size) + + outputs = [] + states = [] + for head_idx in range(inputs["query"].shape[0]): + head_inputs = { + "chunk_size": inputs.get("chunk_size", P_MAX), + "query": inputs["query"][head_idx], + "key": inputs["key"][head_idx], + "value": inputs["value"][head_idx], + "g_raw": inputs["g_raw"][head_idx], + "beta": inputs["beta"][head_idx], + "state_in": inputs["state_in"][head_idx], + "lower_mask": inputs["lower_mask"], + "identity": inputs["identity"], + "lower_mask_diag": inputs["lower_mask_diag"], + } + out, state = blocked_reference_math_one_head( + torch, + head_inputs, + block_size, + ) + outputs.append(out) + states.append(state) + + return torch.stack(outputs, dim=0).contiguous(), torch.stack(states, dim=0).contiguous() + + +def autocp_reference_math( + torch: Any, + inputs: dict[str, Any], + cp_chunks: int = 4, +) -> tuple[Any, Any]: + if inputs["query"].dim() == 2: + return autocp_reference_math_one_head(torch, inputs, cp_chunks) + + outputs = [] + states = [] + for head_idx in range(inputs["query"].shape[0]): + head_inputs = { + "chunk_size": inputs.get("chunk_size", P_MAX), + "query": inputs["query"][head_idx], + "key": inputs["key"][head_idx], + "value": inputs["value"][head_idx], + "g_raw": inputs["g_raw"][head_idx], + "beta": inputs["beta"][head_idx], + "state_in": inputs["state_in"][head_idx], + "lower_mask": inputs["lower_mask"], + "identity": inputs["identity"], + "lower_mask_diag": inputs["lower_mask_diag"], + } + out, state = autocp_reference_math_one_head(torch, head_inputs, cp_chunks) + outputs.append(out) + states.append(state) + + return torch.stack(outputs, dim=0).contiguous(), torch.stack(states, dim=0).contiguous() + + +def build_compact_autocp_segment_transforms( + torch: Any, + inputs: dict[str, Any], + cp_chunks: int = 4, +) -> dict[str, Any]: + """Compose chunk state transforms into compact per-segment transforms. + + Unlike ``build_autocp_affine_stacks``, this deliberately does not retain + output-affine pieces. The intended NKI path prefixes only segment-level + state transforms, then replays each segment recurrently from its corrected + initial state. + """ + if cp_chunks <= 0: + raise ValueError("cp_chunks must be positive") + seq_len = inputs["query"].shape[0] + if seq_len % P_MAX != 0: + raise ValueError(f"seq_len must be divisible by {P_MAX}, got {seq_len}") + + num_chunks = seq_len // P_MAX + identity = inputs["identity"] + zero_state = torch.zeros_like(inputs["state_in"]) + segment_matrices = [] + segment_biases = [] + segment_chunk_counts = [] + + for chunk_start in range(0, num_chunks, cp_chunks): + segment_transform = {"state_matrix": identity, "state_bias": zero_state} + chunk_end = min(chunk_start + cp_chunks, num_chunks) + for chunk_idx in range(chunk_start, chunk_end): + chunk_parts = deltanet_chunk_affine_parts( + torch, + inputs, + chunk_idx * P_MAX, + ) + segment_transform = compose_deltanet_state_affine( + torch, + segment_transform, + chunk_parts, + ) + segment_matrices.append(segment_transform["state_matrix"]) + segment_biases.append(segment_transform["state_bias"]) + segment_chunk_counts.append(chunk_end - chunk_start) + + return { + "state_matrix": torch.stack(segment_matrices, dim=0).contiguous(), + "state_bias": torch.stack(segment_biases, dim=0).contiguous(), + "chunk_counts": segment_chunk_counts, + } + + +def slice_single_head_inputs(inputs: dict[str, Any], start: int, end: int, state: Any) -> dict[str, Any]: + return { + "chunk_size": inputs.get("chunk_size", P_MAX), + "query": inputs["query"][start:end], + "key": inputs["key"][start:end], + "value": inputs["value"][start:end], + "g_raw": inputs["g_raw"][start:end], + "beta": inputs["beta"][start:end], + "state_in": state, + "lower_mask": inputs["lower_mask"], + "identity": inputs["identity"], + "lower_mask_diag": inputs["lower_mask_diag"], + } + + +def compact_autocp_reference_math_one_head( + torch: Any, + inputs: dict[str, Any], + cp_chunks: int = 4, +) -> tuple[Any, Any]: + """Reference compact AutoCP: segment state prefix plus recurrent replay.""" + segment_transforms = build_compact_autocp_segment_transforms( + torch, + inputs, + cp_chunks, + ) + segment_states, final_state = autocp_state_prefix_reference( + torch, + segment_transforms["state_matrix"], + segment_transforms["state_bias"], + inputs["state_in"], + ) + + outputs = [] + token_start = 0 + for segment_idx, chunk_count in enumerate(segment_transforms["chunk_counts"]): + token_end = token_start + (chunk_count * P_MAX) + segment_inputs = slice_single_head_inputs( + inputs, + token_start, + token_end, + segment_states[segment_idx], + ) + segment_output, _ = reference_math_one_head( + torch, + segment_inputs, + ) + outputs.append(segment_output) + token_start = token_end + + return torch.cat(outputs, dim=0).contiguous(), final_state.contiguous() + + +def compact_autocp_reference_math( + torch: Any, + inputs: dict[str, Any], + cp_chunks: int = 4, +) -> tuple[Any, Any]: + if inputs["query"].dim() == 2: + return compact_autocp_reference_math_one_head(torch, inputs, cp_chunks) + + outputs = [] + states = [] + for head_idx in range(inputs["query"].shape[0]): + head_inputs = { + "chunk_size": inputs.get("chunk_size", P_MAX), + "query": inputs["query"][head_idx], + "key": inputs["key"][head_idx], + "value": inputs["value"][head_idx], + "g_raw": inputs["g_raw"][head_idx], + "beta": inputs["beta"][head_idx], + "state_in": inputs["state_in"][head_idx], + "lower_mask": inputs["lower_mask"], + "identity": inputs["identity"], + "lower_mask_diag": inputs["lower_mask_diag"], + } + out, state = compact_autocp_reference_math_one_head( + torch, + head_inputs, + cp_chunks, + ) + outputs.append(out) + states.append(state) + + return torch.stack(outputs, dim=0).contiguous(), torch.stack(states, dim=0).contiguous() + + +def compact_autocp_materialization_counts( + seq_len: int, + cp_chunks: int, +) -> dict[str, int]: + if cp_chunks <= 0: + raise ValueError("cp_chunks must be positive") + if seq_len % P_MAX != 0: + raise ValueError(f"seq_len must be divisible by {P_MAX}, got {seq_len}") + num_chunks = seq_len // P_MAX + num_segments = math.ceil(num_chunks / cp_chunks) + return { + "num_chunks": num_chunks, + "num_segments": num_segments, + "existing_autocp_dense_128x128_tensors": 4 * num_chunks, + "compact_prefix_dense_128x128_tensors": 2 * num_segments, + "dense_tensor_reduction": (4 * num_chunks) - (2 * num_segments), + } + + +def tensor_metrics(torch: Any, actual: Any, expected: Any) -> dict[str, float | bool]: + diff = actual - expected + expected_norm = torch.linalg.vector_norm(expected).item() + diff_norm = torch.linalg.vector_norm(diff).item() + actual_flat = actual.reshape(-1).to(torch.float64) + expected_flat = expected.reshape(-1).to(torch.float64) + denom = torch.linalg.vector_norm(actual_flat) * torch.linalg.vector_norm(expected_flat) + cosine = ( + float(torch.dot(actual_flat, expected_flat) / denom) + if denom.item() != 0.0 + else float("nan") + ) + return { + "finite": bool(torch.isfinite(actual).all().item()), + "max_abs": float(torch.max(torch.abs(diff)).item()), + "mean_abs": float(torch.mean(torch.abs(diff)).item()), + "diff_norm": float(diff_norm), + "expected_norm": float(expected_norm), + "relative_norm": float(diff_norm / max(expected_norm, 1.0e-12)), + "cosine": cosine, + } + + +def multihead_tensor_metrics(torch: Any, actual: Any, expected: Any) -> list[dict[str, Any]]: + """Return per-head metrics plus pairwise relative norms for head-mix debugging.""" + metrics = [] + for actual_head in range(actual.shape[0]): + pairwise_relative_norm = [] + for expected_head in range(expected.shape[0]): + pairwise_relative_norm.append( + tensor_metrics( + torch, + actual[actual_head], + expected[expected_head], + )["relative_norm"] + ) + head_metrics = tensor_metrics( + torch, + actual[actual_head], + expected[actual_head], + ) + head_metrics["actual_head"] = actual_head + head_metrics["pairwise_relative_norm"] = pairwise_relative_norm + metrics.append(head_metrics) + return metrics + + +def build_autocp_affine_stacks( + torch: Any, + inputs: dict[str, Any], +) -> dict[str, Any]: + parts = [ + deltanet_chunk_affine_parts(torch, inputs, start) + for start in range(0, inputs["query"].shape[0], P_MAX) + ] + return { + "state_matrix": torch.stack( + [chunk_parts["state_matrix"] for chunk_parts in parts], + dim=0, + ).contiguous(), + "state_bias": torch.stack( + [chunk_parts["state_bias"] for chunk_parts in parts], + dim=0, + ).contiguous(), + "output_base": torch.stack( + [chunk_parts["output_base"] for chunk_parts in parts], + dim=0, + ).contiguous(), + "output_state": torch.stack( + [chunk_parts["output_state"] for chunk_parts in parts], + dim=0, + ).contiguous(), + } + + +def autocp_state_prefix_reference( + torch: Any, + state_matrix: Any, + state_bias: Any, + initial_state: Any, +) -> tuple[Any, Any]: + chunk_states = [] + state = initial_state.clone() + for chunk_idx in range(state_matrix.shape[0]): + chunk_states.append(state.clone()) + state = (state_matrix[chunk_idx] @ state) + state_bias[chunk_idx] + return torch.stack(chunk_states, dim=0).contiguous(), state.contiguous() + + +def validate_cpu_chunk_invariance(torch: Any, args: argparse.Namespace) -> int: + if args.seq_len % P_MAX != 0: + raise ValueError( + "--validate-cpu-chunk-invariance requires --seq-len to be a " + f"positive multiple of {P_MAX}; got {args.seq_len}" + ) + + args128 = argparse.Namespace(**{**vars(args), "chunk_size": P_MAX}) + args64 = argparse.Namespace(**{**vars(args), "chunk_size": 64}) + inputs128 = make_inputs(torch, args128) + inputs64 = make_inputs(torch, args64) + + ref128_out, ref128_state = reference_math(torch, inputs128) + ref64_out, ref64_state = reference_math(torch, inputs64) + output_close = bool( + torch.allclose(ref64_out, ref128_out, atol=args.atol, rtol=args.rtol) + ) + state_close = bool( + torch.allclose(ref64_state, ref128_state, atol=args.atol, rtol=args.rtol) + ) + output_finite = bool(torch.isfinite(ref64_out).all().item()) + state_finite = bool(torch.isfinite(ref64_state).all().item()) + passed = bool(output_close and state_close and output_finite and state_finite) + + result = { + "passed": passed, + "validate_cpu_chunk_invariance": True, + "seed": args.seed, + "seq_len": args.seq_len, + "heads": args.heads if args.multihead else 1, + "multihead": args.multihead, + "atol": args.atol, + "rtol": args.rtol, + "output_finite": output_finite, + "state_finite": state_finite, + "chunk64_vs_chunk128": { + "output": tensor_metrics(torch, ref64_out, ref128_out), + "state": tensor_metrics(torch, ref64_state, ref128_state), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def parse_restore_split_lens(spec: str) -> list[int]: + try: + values = [int(part.strip()) for part in spec.split(",") if part.strip()] + except ValueError as exc: + raise ValueError(f"Invalid --restore-split-lens {spec!r}") from exc + if not values or any(value <= 0 for value in values): + raise ValueError( + "--restore-split-lens must contain positive integer lengths; " + f"got {spec!r}" + ) + return values + + +def zero_sequence_tail(torch: Any, inputs: dict[str, Any], real_seq_len: int) -> None: + sequence_keys = ("query", "key", "value", "g_raw", "beta") + for key in sequence_keys: + tensor = inputs[key] + if tensor.dim() == 2: + tensor[real_seq_len:] = 0 + else: + tensor[:, real_seq_len:] = 0 + + +def slice_sequence_tensor(tensor: Any, start: int, end: int) -> Any: + if tensor.dim() == 2: + return tensor[start:end].contiguous() + return tensor[:, start:end].contiguous() + + +def slice_sequence_for_compare(tensor: Any, end: int) -> Any: + if tensor.dim() == 2: + return tensor[:end] + return tensor[:, :end] + + +def cat_sequence_outputs(torch: Any, tensors: list[Any]) -> Any: + if not tensors: + raise ValueError("No output tensors to concatenate") + if tensors[0].dim() == 2: + return torch.cat(tensors, dim=0) + return torch.cat(tensors, dim=1) + + +def copy_inputs_with_state_and_slice( + inputs: dict[str, Any], + *, + start: int, + end: int, + state: Any, +) -> dict[str, Any]: + return { + "chunk_size": inputs.get("chunk_size", P_MAX), + "query": slice_sequence_tensor(inputs["query"], start, end), + "key": slice_sequence_tensor(inputs["key"], start, end), + "value": slice_sequence_tensor(inputs["value"], start, end), + "g_raw": slice_sequence_tensor(inputs["g_raw"], start, end), + "beta": slice_sequence_tensor(inputs["beta"], start, end), + "state_in": state, + "lower_mask": inputs["lower_mask"], + "identity": inputs["identity"], + "lower_mask_diag": inputs["lower_mask_diag"], + } + + +def run_fused_kernel_once( + torch: Any, + deltanet_fused_chunked_fwd: Any, + inputs: dict[str, Any], + args: argparse.Namespace, +) -> tuple[Any, Any, list[str]]: + if args.multihead: + pair_outputs = [] + pair_states = [] + launch_spec_labels = [] + head_group_size = min(args.head_group_size, args.heads) + for head_start in range(0, args.heads, head_group_size): + head_end = min(head_start + head_group_size, args.heads) + launch_heads = head_end - head_start + launch_spec = multihead_launch_spec(launch_heads, args.lnc) + launch_spec_labels.append(launch_spec_label(launch_spec)) + out_pair, state_pair = deltanet_fused_chunked_fwd[launch_spec]( + inputs["query"][head_start:head_end], + inputs["key"][head_start:head_end], + inputs["value"][head_start:head_end], + inputs["g_raw"][head_start:head_end], + inputs["beta"][head_start:head_end], + inputs["state_in"][head_start:head_end], + inputs["lower_mask"], + inputs["identity"], + inputs["lower_mask_diag"], + ) + pair_outputs.append(out_pair) + pair_states.append(state_pair) + return torch.cat(pair_outputs, dim=0), torch.cat(pair_states, dim=0), launch_spec_labels + + out_dev, state_dev = deltanet_fused_chunked_fwd( + inputs["query"], + inputs["key"], + inputs["value"], + inputs["g_raw"], + inputs["beta"], + inputs["state_in"], + inputs["lower_mask"], + inputs["identity"], + inputs["lower_mask_diag"], + ) + return out_dev, state_dev, [] + + +def validate_restored_suffix_carry( + torch: Any, + xm: Any, + args: argparse.Namespace, + inspect_dir: Path, +) -> int: + split_lens = parse_restore_split_lens(args.restore_split_lens) + bucket_size = int(args.restore_bucket_size) + chunk_size = int(args.chunk_size) + if bucket_size <= 0: + raise ValueError("--restore-bucket-size must be positive") + if bucket_size % chunk_size != 0: + raise ValueError( + "--restore-bucket-size must be a multiple of --chunk-size; " + f"got bucket_size={bucket_size}, chunk_size={chunk_size}" + ) + if any(length > bucket_size for length in split_lens): + raise ValueError( + "Each --restore-split-lens value must fit in --restore-bucket-size; " + f"split_lens={split_lens}, bucket_size={bucket_size}" + ) + + real_seq_len = sum(split_lens) + padded_seq_len = len(split_lens) * bucket_size + input_args = argparse.Namespace(**{**vars(args), "seq_len": padded_seq_len}) + inputs = make_inputs(torch, input_args) + zero_sequence_tail(torch, inputs, real_seq_len) + ref_out, ref_state = reference_math(torch, inputs) + + deltanet_fused_chunked_fwd = load_fused_kernel(args.multihead) + device = xm.xla_device() + xla_inputs = move_tensor_inputs_to_device(inputs, device) + + full_out_cpu = full_state_cpu = None + split_out_cpu = split_state_cpu = None + run_elapsed_seconds = [] + launch_spec_labels = [] + for _ in range(args.runs): + run_start = time.perf_counter() + full_out_dev, full_state_dev, full_launch_specs = run_fused_kernel_once( + torch, + deltanet_fused_chunked_fwd, + xla_inputs, + args, + ) + if full_launch_specs and not launch_spec_labels: + launch_spec_labels = full_launch_specs + + state_dev = xla_inputs["state_in"] + split_outputs = [] + offset = 0 + for real_len in split_lens: + segment_inputs = copy_inputs_with_state_and_slice( + xla_inputs, + start=offset, + end=offset + bucket_size, + state=state_dev, + ) + out_dev, state_dev, segment_launch_specs = run_fused_kernel_once( + torch, + deltanet_fused_chunked_fwd, + segment_inputs, + args, + ) + if segment_launch_specs and not launch_spec_labels: + launch_spec_labels = segment_launch_specs + split_outputs.append(out_dev) + offset += bucket_size + + split_out_dev = cat_sequence_outputs(torch, split_outputs) + split_state_dev = state_dev + xm.mark_step() + full_out_cpu = full_out_dev.detach().cpu().float() + full_state_cpu = full_state_dev.detach().cpu().float() + split_out_cpu = split_out_dev.detach().cpu().float() + split_state_cpu = split_state_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert full_out_cpu is not None + assert full_state_cpu is not None + assert split_out_cpu is not None + assert split_state_cpu is not None + + ref_real = slice_sequence_for_compare(ref_out, real_seq_len) + full_real = slice_sequence_for_compare(full_out_cpu, real_seq_len) + split_real = slice_sequence_for_compare(split_out_cpu, real_seq_len) + + full_output_close = bool( + torch.allclose(full_real, ref_real, atol=args.atol, rtol=args.rtol) + ) + full_state_close = bool( + torch.allclose(full_state_cpu, ref_state, atol=args.atol, rtol=args.rtol) + ) + split_output_close = bool( + torch.allclose(split_real, ref_real, atol=args.atol, rtol=args.rtol) + ) + split_state_close = bool( + torch.allclose(split_state_cpu, ref_state, atol=args.atol, rtol=args.rtol) + ) + split_vs_full_output_close = bool( + torch.allclose(split_real, full_real, atol=args.atol, rtol=args.rtol) + ) + split_vs_full_state_close = bool( + torch.allclose(split_state_cpu, full_state_cpu, atol=args.atol, rtol=args.rtol) + ) + finite = { + "full_output": bool(torch.isfinite(full_real).all().item()), + "full_state": bool(torch.isfinite(full_state_cpu).all().item()), + "split_output": bool(torch.isfinite(split_real).all().item()), + "split_state": bool(torch.isfinite(split_state_cpu).all().item()), + } + passed = bool( + full_output_close + and full_state_close + and split_output_close + and split_state_close + and split_vs_full_output_close + and split_vs_full_state_close + and all(finite.values()) + ) + + result = { + "passed": passed, + "validate_restored_suffix_carry": True, + "seed": args.seed, + "split_lens": split_lens, + "restore_bucket_size": bucket_size, + "real_seq_len": real_seq_len, + "padded_seq_len": padded_seq_len, + "chunk_size": chunk_size, + "heads": args.heads if args.multihead else 1, + "head_group_size": args.head_group_size if args.multihead else 1, + "launch_specs": launch_spec_labels if args.multihead else [], + "multihead": args.multihead, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": finite, + "close": { + "full_output": full_output_close, + "full_state": full_state_close, + "split_output": split_output_close, + "split_state": split_state_close, + "split_vs_full_output": split_vs_full_output_close, + "split_vs_full_state": split_vs_full_state_close, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_CHUNK_SIZE", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "full_output_real": tensor_metrics(torch, full_real, ref_real), + "full_state": tensor_metrics(torch, full_state_cpu, ref_state), + "split_output_real": tensor_metrics(torch, split_real, ref_real), + "split_state": tensor_metrics(torch, split_state_cpu, ref_state), + "split_vs_full_output_real": tensor_metrics(torch, split_real, full_real), + "split_vs_full_state": tensor_metrics( + torch, + split_state_cpu, + full_state_cpu, + ), + }, + } + if args.multihead: + result["nki_vs_reference"]["split_output_per_head"] = multihead_tensor_metrics( + torch, + split_real, + ref_real, + ) + result["nki_vs_reference"]["split_state_per_head"] = multihead_tensor_metrics( + torch, + split_state_cpu, + ref_state, + ) + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_affine_chunk(torch: Any, xm: Any, args: argparse.Namespace, inspect_dir: Path) -> int: + if args.multihead: + raise ValueError("--validate-autocp-affine expects single-head inputs") + + deltanet_autocp_affine_chunk = load_autocp_affine_kernel() + inputs = make_inputs(torch, args) + ref_parts = deltanet_chunk_affine_parts(torch, inputs, 0) + + device = xm.xla_device() + xla_inputs = move_tensor_inputs_to_device(inputs, device) + + part_names = ("output_base", "output_state", "state_matrix", "state_bias") + actual_parts = None + run_elapsed_seconds = [] + for _ in range(args.runs): + run_start = time.perf_counter() + parts_dev = deltanet_autocp_affine_chunk( + xla_inputs["query"][0:P_MAX], + xla_inputs["key"][0:P_MAX], + xla_inputs["value"][0:P_MAX], + xla_inputs["g_raw"][0:P_MAX], + xla_inputs["beta"][0:P_MAX], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + xm.mark_step() + actual_parts = { + name: part.detach().cpu().float() + for name, part in zip(part_names, parts_dev, strict=True) + } + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert actual_parts is not None + + close = { + name: bool( + torch.allclose( + actual_parts[name], + ref_parts[name], + atol=args.atol, + rtol=args.rtol, + ) + ) + for name in part_names + } + finite = { + name: bool(torch.isfinite(actual_parts[name]).all().item()) + for name in part_names + } + passed = all(close.values()) and all(finite.values()) + + result = { + "passed": bool(passed), + "validate_autocp_affine": True, + "seed": args.seed, + "seq_len": args.seq_len, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": finite, + "close": close, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + name: tensor_metrics(torch, actual_parts[name], ref_parts[name]) + for name in part_names + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_state_prefix(torch: Any, xm: Any, args: argparse.Namespace, inspect_dir: Path) -> int: + if args.multihead: + raise ValueError("--validate-autocp-prefix expects single-head inputs") + + deltanet_autocp_state_prefix = load_autocp_prefix_kernel() + inputs = make_inputs(torch, args) + affine = build_autocp_affine_stacks(torch, inputs) + ref_chunk_states, ref_final_state = autocp_state_prefix_reference( + torch, + affine["state_matrix"], + affine["state_bias"], + inputs["state_in"], + ) + + device = xm.xla_device() + state_matrix_dev = affine["state_matrix"].to(device=device) + state_bias_dev = affine["state_bias"].to(device=device) + initial_state_dev = inputs["state_in"].to(device=device) + + chunk_states_cpu = final_state_cpu = None + run_elapsed_seconds = [] + for _ in range(args.runs): + run_start = time.perf_counter() + chunk_states_dev, final_state_dev = deltanet_autocp_state_prefix( + state_matrix_dev, + state_bias_dev, + initial_state_dev, + ) + xm.mark_step() + chunk_states_cpu = chunk_states_dev.detach().cpu().float() + final_state_cpu = final_state_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert chunk_states_cpu is not None + assert final_state_cpu is not None + + chunk_states_close = bool( + torch.allclose( + chunk_states_cpu, + ref_chunk_states, + atol=args.atol, + rtol=args.rtol, + ) + ) + final_state_close = bool( + torch.allclose( + final_state_cpu, + ref_final_state, + atol=args.atol, + rtol=args.rtol, + ) + ) + chunk_states_finite = bool(torch.isfinite(chunk_states_cpu).all().item()) + final_state_finite = bool(torch.isfinite(final_state_cpu).all().item()) + passed = bool( + chunk_states_close + and final_state_close + and chunk_states_finite + and final_state_finite + ) + + result = { + "passed": passed, + "validate_autocp_prefix": True, + "seed": args.seed, + "seq_len": args.seq_len, + "num_chunks": args.seq_len // P_MAX, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": { + "chunk_states": chunk_states_finite, + "final_state": final_state_finite, + }, + "close": { + "chunk_states": chunk_states_close, + "final_state": final_state_close, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "chunk_states": tensor_metrics( + torch, + chunk_states_cpu, + ref_chunk_states, + ), + "final_state": tensor_metrics( + torch, + final_state_cpu, + ref_final_state, + ), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_chain(torch: Any, xm: Any, args: argparse.Namespace, inspect_dir: Path) -> int: + if args.multihead: + raise ValueError("--validate-autocp-chain expects single-head inputs") + + deltanet_autocp_state_prefix = load_autocp_prefix_kernel() + deltanet_autocp_apply_output = load_autocp_apply_kernel() + + inputs = make_inputs(torch, args) + affine = build_autocp_affine_stacks(torch, inputs) + ref_out, ref_final_state = autocp_reference_math(torch, inputs) + + device = xm.xla_device() + state_matrix_dev = affine["state_matrix"].to(device=device) + state_bias_dev = affine["state_bias"].to(device=device) + output_base_dev = affine["output_base"].to(device=device) + output_state_dev = affine["output_state"].to(device=device) + initial_state_dev = inputs["state_in"].to(device=device) + + out_cpu = final_state_cpu = None + run_elapsed_seconds = [] + for _ in range(args.runs): + run_start = time.perf_counter() + chunk_states_dev, final_state_dev = deltanet_autocp_state_prefix( + state_matrix_dev, + state_bias_dev, + initial_state_dev, + ) + out_dev = deltanet_autocp_apply_output( + output_base_dev, + output_state_dev, + chunk_states_dev, + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + final_state_cpu = final_state_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert out_cpu is not None + assert final_state_cpu is not None + + output_close = bool(torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol)) + final_state_close = bool( + torch.allclose( + final_state_cpu, + ref_final_state, + atol=args.atol, + rtol=args.rtol, + ) + ) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + final_state_finite = bool(torch.isfinite(final_state_cpu).all().item()) + passed = bool(output_close and final_state_close and output_finite and final_state_finite) + + result = { + "passed": passed, + "validate_autocp_chain": True, + "seed": args.seed, + "seq_len": args.seq_len, + "num_chunks": args.seq_len // P_MAX, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": { + "output": output_finite, + "final_state": final_state_finite, + }, + "close": { + "output": output_close, + "final_state": final_state_close, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "output": tensor_metrics(torch, out_cpu, ref_out), + "final_state": tensor_metrics(torch, final_state_cpu, ref_final_state), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_prefix_apply(torch: Any, xm: Any, args: argparse.Namespace, inspect_dir: Path) -> int: + if args.multihead: + raise ValueError("--validate-autocp-prefix-apply expects single-head inputs") + + deltanet_autocp_prefix_apply = load_autocp_prefix_apply_kernel() + + inputs = make_inputs(torch, args) + affine = build_autocp_affine_stacks(torch, inputs) + ref_out, ref_final_state = autocp_reference_math(torch, inputs) + + device = xm.xla_device() + output_base_dev = affine["output_base"].to(device=device) + output_state_dev = affine["output_state"].to(device=device) + state_matrix_dev = affine["state_matrix"].to(device=device) + state_bias_dev = affine["state_bias"].to(device=device) + initial_state_dev = inputs["state_in"].to(device=device) + + out_cpu = final_state_cpu = None + run_elapsed_seconds = [] + for _ in range(args.runs): + run_start = time.perf_counter() + out_dev, final_state_dev = deltanet_autocp_prefix_apply( + output_base_dev, + output_state_dev, + state_matrix_dev, + state_bias_dev, + initial_state_dev, + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + final_state_cpu = final_state_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert out_cpu is not None + assert final_state_cpu is not None + + output_close = bool(torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol)) + final_state_close = bool( + torch.allclose( + final_state_cpu, + ref_final_state, + atol=args.atol, + rtol=args.rtol, + ) + ) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + final_state_finite = bool(torch.isfinite(final_state_cpu).all().item()) + passed = bool(output_close and final_state_close and output_finite and final_state_finite) + + result = { + "passed": passed, + "validate_autocp_prefix_apply": True, + "seed": args.seed, + "seq_len": args.seq_len, + "num_chunks": args.seq_len // P_MAX, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": { + "output": output_finite, + "final_state": final_state_finite, + }, + "close": { + "output": output_close, + "final_state": final_state_close, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "output": tensor_metrics(torch, out_cpu, ref_out), + "final_state": tensor_metrics(torch, final_state_cpu, ref_final_state), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_full(torch: Any, xm: Any, args: argparse.Namespace, inspect_dir: Path) -> int: + if args.multihead: + raise ValueError("--validate-autocp-full expects single-head inputs") + + import nki.language as nl + + deltanet_autocp_affine_sequence = load_autocp_affine_sequence_kernel() + deltanet_autocp_prefix_apply = load_autocp_prefix_apply_kernel() + + inputs = make_inputs(torch, args) + ref_out, ref_final_state = autocp_reference_math(torch, inputs) + ref_affine = build_autocp_affine_stacks(torch, inputs) + + device = xm.xla_device() + xla_inputs = move_tensor_inputs_to_device(inputs, device) + + out_cpu = final_state_cpu = None + affine_cpu = None + run_elapsed_seconds = [] + num_chunks = args.seq_len // P_MAX + if hasattr(nl, "spmd_dim") and hasattr(nl, "nc"): + affine_launch_spec = ( + nl.spmd_dim(num_chunks, nl.nc(args.lnc)), + 1, + ) + else: + affine_launch_spec = args.lnc + for _ in range(args.runs): + run_start = time.perf_counter() + output_base_dev, output_state_dev, state_matrix_dev, state_bias_dev = ( + deltanet_autocp_affine_sequence[affine_launch_spec]( + xla_inputs["query"], + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["g_raw"], + xla_inputs["beta"], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + ) + out_dev, final_state_dev = deltanet_autocp_prefix_apply( + output_base_dev, + output_state_dev, + state_matrix_dev, + state_bias_dev, + xla_inputs["state_in"], + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + final_state_cpu = final_state_dev.detach().cpu().float() + affine_cpu = { + "output_base": output_base_dev.detach().cpu().float(), + "output_state": output_state_dev.detach().cpu().float(), + "state_matrix": state_matrix_dev.detach().cpu().float(), + "state_bias": state_bias_dev.detach().cpu().float(), + } + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert out_cpu is not None + assert final_state_cpu is not None + assert affine_cpu is not None + + output_close = bool(torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol)) + final_state_close = bool( + torch.allclose( + final_state_cpu, + ref_final_state, + atol=args.atol, + rtol=args.rtol, + ) + ) + affine_close = { + name: bool( + torch.allclose( + affine_cpu[name], + ref_affine[name], + atol=args.atol, + rtol=args.rtol, + ) + ) + for name in affine_cpu + } + output_finite = bool(torch.isfinite(out_cpu).all().item()) + final_state_finite = bool(torch.isfinite(final_state_cpu).all().item()) + affine_finite = { + name: bool(torch.isfinite(affine_cpu[name]).all().item()) + for name in affine_cpu + } + passed = bool( + output_close + and final_state_close + and output_finite + and final_state_finite + and all(affine_close.values()) + and all(affine_finite.values()) + ) + + result = { + "passed": passed, + "validate_autocp_full": True, + "seed": args.seed, + "seq_len": args.seq_len, + "num_chunks": num_chunks, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": { + "output": output_finite, + "final_state": final_state_finite, + **{f"affine_{name}": value for name, value in affine_finite.items()}, + }, + "close": { + "output": output_close, + "final_state": final_state_close, + **{f"affine_{name}": value for name, value in affine_close.items()}, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "output": tensor_metrics(torch, out_cpu, ref_out), + "final_state": tensor_metrics(torch, final_state_cpu, ref_final_state), + **{ + f"affine_{name}": tensor_metrics( + torch, + affine_cpu[name], + ref_affine[name], + ) + for name in affine_cpu + }, + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_state_summary( + torch: Any, + xm: Any, + args: argparse.Namespace, + inspect_dir: Path, +) -> int: + if args.multihead: + raise ValueError("--validate-autocp-state-summary expects single-head inputs") + + import nki.language as nl + + deltanet_autocp_state_summary = load_autocp_state_summary_kernel() + + inputs = make_inputs(torch, args) + ref_segments = build_compact_autocp_segment_transforms( + torch, + inputs, + cp_chunks=args.autocp_cp_chunks, + ) + + device = xm.xla_device() + xla_inputs = move_tensor_inputs_to_device(inputs, device) + num_segments = ref_segments["state_matrix"].shape[0] + if hasattr(nl, "spmd_dim") and hasattr(nl, "nc"): + launch_spec = ( + nl.spmd_dim(num_segments, nl.nc(args.lnc)), + 1, + ) + else: + launch_spec = args.lnc + + state_matrix_cpu = state_bias_cpu = None + run_elapsed_seconds = [] + for _ in range(args.runs): + run_start = time.perf_counter() + state_matrix_dev, state_bias_dev = deltanet_autocp_state_summary[launch_spec]( + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["g_raw"], + xla_inputs["beta"], + xla_inputs["lower_mask"], + xla_inputs["identity"], + ) + xm.mark_step() + state_matrix_cpu = state_matrix_dev.detach().cpu().float() + state_bias_cpu = state_bias_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert state_matrix_cpu is not None + assert state_bias_cpu is not None + + matrix_close = bool( + torch.allclose( + state_matrix_cpu, + ref_segments["state_matrix"], + atol=args.atol, + rtol=args.rtol, + ) + ) + bias_close = bool( + torch.allclose( + state_bias_cpu, + ref_segments["state_bias"], + atol=args.atol, + rtol=args.rtol, + ) + ) + matrix_finite = bool(torch.isfinite(state_matrix_cpu).all().item()) + bias_finite = bool(torch.isfinite(state_bias_cpu).all().item()) + passed = bool(matrix_close and bias_close and matrix_finite and bias_finite) + + result = { + "passed": passed, + "validate_autocp_state_summary": True, + "seed": args.seed, + "seq_len": args.seq_len, + "num_chunks": args.seq_len // P_MAX, + "num_segments": num_segments, + "cp_chunks": args.autocp_cp_chunks, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": { + "state_matrix": matrix_finite, + "state_bias": bias_finite, + }, + "close": { + "state_matrix": matrix_close, + "state_bias": bias_close, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_CHUNK_SIZE", + "QWEN36_DELTANET_AUTOCP_CP_CHUNKS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "state_matrix": tensor_metrics( + torch, + state_matrix_cpu, + ref_segments["state_matrix"], + ), + "state_bias": tensor_metrics( + torch, + state_bias_cpu, + ref_segments["state_bias"], + ), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_autocp_compact_chain( + torch: Any, + xm: Any, + args: argparse.Namespace, + inspect_dir: Path, +) -> int: + if args.multihead: + raise ValueError("--validate-autocp-compact-chain expects single-head inputs") + + import nki.language as nl + + deltanet_autocp_state_summary = load_autocp_state_summary_kernel() + deltanet_autocp_state_prefix = load_autocp_prefix_kernel() + deltanet_fused_multihead = load_fused_kernel(True) + + inputs = make_inputs(torch, args) + ref_out, ref_final_state = reference_math(torch, inputs) + + device = xm.xla_device() + xla_inputs = move_tensor_inputs_to_device(inputs, device) + num_chunks = args.seq_len // P_MAX + num_segments = num_chunks // args.autocp_cp_chunks + if hasattr(nl, "spmd_dim") and hasattr(nl, "nc"): + summary_launch_spec = ( + nl.spmd_dim(num_segments, nl.nc(args.lnc)), + 1, + ) + else: + summary_launch_spec = args.lnc + + out_cpu = final_state_cpu = None + run_elapsed_seconds = [] + segment_len = args.autocp_cp_chunks * P_MAX + replay_group_size = min(num_segments, args.lnc) + for _ in range(args.runs): + run_start = time.perf_counter() + state_matrix_dev, state_bias_dev = deltanet_autocp_state_summary[ + summary_launch_spec + ]( + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["g_raw"], + xla_inputs["beta"], + xla_inputs["lower_mask"], + xla_inputs["identity"], + ) + segment_states_dev, final_state_dev = deltanet_autocp_state_prefix( + state_matrix_dev, + state_bias_dev, + xla_inputs["state_in"], + ) + + q_segments = xla_inputs["query"].reshape(num_segments, segment_len, P_MAX).contiguous() + k_segments = xla_inputs["key"].reshape(num_segments, segment_len, P_MAX).contiguous() + v_segments = xla_inputs["value"].reshape(num_segments, segment_len, P_MAX).contiguous() + g_segments = xla_inputs["g_raw"].reshape(num_segments, segment_len, 1).contiguous() + beta_segments = xla_inputs["beta"].reshape(num_segments, segment_len, 1).contiguous() + replay_outputs = [] + for segment_start in range(0, num_segments, replay_group_size): + segment_end = min(segment_start + replay_group_size, num_segments) + launch_segments = segment_end - segment_start + replay_launch_spec = multihead_launch_spec(launch_segments, args.lnc) + out_group, _ = deltanet_fused_multihead[replay_launch_spec]( + q_segments[segment_start:segment_end], + k_segments[segment_start:segment_end], + v_segments[segment_start:segment_end], + g_segments[segment_start:segment_end], + beta_segments[segment_start:segment_end], + segment_states_dev[segment_start:segment_end], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + replay_outputs.append(out_group) + out_segments_dev = torch.cat(replay_outputs, dim=0) + out_dev = out_segments_dev.reshape(args.seq_len, P_MAX) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + final_state_cpu = final_state_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert out_cpu is not None + assert final_state_cpu is not None + + output_close = bool(torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol)) + final_state_close = bool( + torch.allclose( + final_state_cpu, + ref_final_state, + atol=args.atol, + rtol=args.rtol, + ) + ) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + final_state_finite = bool(torch.isfinite(final_state_cpu).all().item()) + passed = bool( + output_close + and final_state_close + and output_finite + and final_state_finite + ) + + result = { + "passed": passed, + "validate_autocp_compact_chain": True, + "seed": args.seed, + "seq_len": args.seq_len, + "num_chunks": num_chunks, + "num_segments": num_segments, + "cp_chunks": args.autocp_cp_chunks, + "replay_group_size": replay_group_size, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "finite": { + "output": output_finite, + "final_state": final_state_finite, + }, + "close": { + "output": output_close, + "final_state": final_state_close, + }, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_CHUNK_SIZE", + "QWEN36_DELTANET_AUTOCP_CP_CHUNKS", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "output": tensor_metrics(torch, out_cpu, ref_out), + "final_state": tensor_metrics( + torch, + final_state_cpu, + ref_final_state, + ), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def validate_compact_autocp_reference(torch: Any, args: argparse.Namespace) -> int: + inputs = make_inputs(torch, args) + expected_out, expected_state = reference_math(torch, inputs) + actual_out, actual_state = compact_autocp_reference_math( + torch, + inputs, + cp_chunks=args.autocp_cp_chunks, + ) + + output_close = bool( + torch.allclose(actual_out, expected_out, atol=args.atol, rtol=args.rtol) + ) + state_close = bool( + torch.allclose(actual_state, expected_state, atol=args.atol, rtol=args.rtol) + ) + output_finite = bool(torch.isfinite(actual_out).all().item()) + state_finite = bool(torch.isfinite(actual_state).all().item()) + passed = bool(output_close and state_close and output_finite and state_finite) + + result = { + "passed": passed, + "validate_compact_autocp_reference": True, + "seed": args.seed, + "seq_len": args.seq_len, + "heads": args.heads if args.multihead else 1, + "multihead": args.multihead, + "cp_chunks": args.autocp_cp_chunks, + "atol": args.atol, + "rtol": args.rtol, + "output_finite": output_finite, + "state_finite": state_finite, + "materialization": compact_autocp_materialization_counts( + args.seq_len, + args.autocp_cp_chunks, + ), + "compact_vs_sequential": { + "output": tensor_metrics(torch, actual_out, expected_out), + "state": tensor_metrics(torch, actual_state, expected_state), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +def main() -> int: + args = parse_args() + inspect_dir = configure_environment(args) + add_qwen_to_path() + + import torch + + if args.validate_cpu_chunk_invariance: + return validate_cpu_chunk_invariance(torch, args) + if args.validate_compact_autocp_reference: + return validate_compact_autocp_reference(torch, args) + + import torch_xla.core.xla_model as xm + + if args.validate_autocp_affine: + return validate_autocp_affine_chunk(torch, xm, args, inspect_dir) + if args.validate_autocp_prefix: + return validate_autocp_state_prefix(torch, xm, args, inspect_dir) + if args.validate_autocp_chain: + return validate_autocp_chain(torch, xm, args, inspect_dir) + if args.validate_autocp_prefix_apply: + return validate_autocp_prefix_apply(torch, xm, args, inspect_dir) + if args.validate_autocp_full: + return validate_autocp_full(torch, xm, args, inspect_dir) + if args.validate_autocp_state_summary: + return validate_autocp_state_summary(torch, xm, args, inspect_dir) + if args.validate_autocp_compact_chain: + return validate_autocp_compact_chain(torch, xm, args, inspect_dir) + if args.validate_restored_suffix_carry: + return validate_restored_suffix_carry(torch, xm, args, inspect_dir) + + deltanet_fused_chunked_fwd = load_fused_kernel(args.multihead) + + inputs = make_inputs(torch, args) + ref_out, ref_state = reference_math(torch, inputs) + + device = xm.xla_device() + xla_inputs = move_tensor_inputs_to_device(inputs, device) + + out_cpu = state_cpu = None + run_elapsed_seconds = [] + launch_spec_labels = [] + for _ in range(args.runs): + run_start = time.perf_counter() + if args.multihead: + pair_outputs = [] + pair_states = [] + head_group_size = min(args.head_group_size, args.heads) + for head_start in range(0, args.heads, head_group_size): + head_end = min(head_start + head_group_size, args.heads) + launch_heads = head_end - head_start + launch_spec = multihead_launch_spec(launch_heads, args.lnc) + if len(launch_spec_labels) < math.ceil(args.heads / head_group_size): + launch_spec_labels.append(launch_spec_label(launch_spec)) + out_pair, state_pair = deltanet_fused_chunked_fwd[launch_spec]( + xla_inputs["query"][head_start:head_end], + xla_inputs["key"][head_start:head_end], + xla_inputs["value"][head_start:head_end], + xla_inputs["g_raw"][head_start:head_end], + xla_inputs["beta"][head_start:head_end], + xla_inputs["state_in"][head_start:head_end], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + pair_outputs.append(out_pair) + pair_states.append(state_pair) + out_dev = torch.cat(pair_outputs, dim=0) + state_dev = torch.cat(pair_states, dim=0) + else: + out_dev, state_dev = deltanet_fused_chunked_fwd( + xla_inputs["query"], + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["g_raw"], + xla_inputs["beta"], + xla_inputs["state_in"], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + state_cpu = state_dev.detach().cpu().float() + run_elapsed_seconds.append(time.perf_counter() - run_start) + + assert out_cpu is not None + assert state_cpu is not None + + output_close = torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol) + state_close = torch.allclose(state_cpu, ref_state, atol=args.atol, rtol=args.rtol) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + state_finite = bool(torch.isfinite(state_cpu).all().item()) + passed = bool(output_close and state_close and output_finite and state_finite) + + result = { + "passed": passed, + "seed": args.seed, + "seq_len": args.seq_len, + "chunk_size": args.chunk_size, + "heads": args.heads if args.multihead else 1, + "head_group_size": args.head_group_size if args.multihead else 1, + "launch_specs": launch_spec_labels if args.multihead else [], + "multihead": args.multihead, + "runs": args.runs, + "run_elapsed_seconds": run_elapsed_seconds, + "cached_run_elapsed_seconds": run_elapsed_seconds[1:], + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "output_finite": output_finite, + "state_finite": state_finite, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + "QWEN36_DELTANET_CHUNK_SIZE", + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE", + "QWEN36_DELTANET_SOLVE_SCAN_STEPS", + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "QWEN36_DELTANET_SOLVE_MODE", + ) + }, + "nki_vs_reference": { + "output": tensor_metrics(torch, out_cpu, ref_out), + "state": tensor_metrics(torch, state_cpu, ref_state), + }, + } + if args.multihead: + result["nki_vs_reference"]["output_per_head"] = multihead_tensor_metrics( + torch, + out_cpu, + ref_out, + ) + result["nki_vs_reference"]["state_per_head"] = multihead_tensor_metrics( + torch, + state_cpu, + ref_state, + ) + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_recurrent_step_nki.py b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_recurrent_step_nki.py new file mode 100644 index 00000000..31d428b0 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_recurrent_step_nki.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +"""Validate/profile the Qwen DeltaNet one-token decode NKI kernel. + +The reference path is CPU-only by design. Keeping reference math off the XLA +device avoids compiling extra NEFFs that obscure the NKI kernel profile. +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import math +import os +import sys +from pathlib import Path +from typing import Any + + +P_MAX = 128 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Validate/profile deltanet_recurrent_step against CPU math." + ) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--batch-heads", type=int, default=4) + parser.add_argument("--target", default="trn2") + parser.add_argument("--lnc", type=int, default=1) + parser.add_argument("--visible-cores", default="0") + parser.add_argument("--inspect", action="store_true") + parser.add_argument("--dge", action="store_true") + parser.add_argument( + "--inspect-dir", + default="/mnt/trainium_artifacts/profiles/deltanet_recurrent_step_isolated", + ) + parser.add_argument("--atol", type=float, default=2.0e-2) + parser.add_argument("--rtol", type=float, default=2.0e-2) + parser.add_argument("--value-scale", type=float, default=0.05) + parser.add_argument("--state-scale", type=float, default=0.01) + parser.add_argument("--gate-scale", type=float, default=0.01) + parser.add_argument("--fail-on-mismatch", action="store_true") + return parser.parse_args() + + +def configure_environment(args: argparse.Namespace) -> Path: + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", args.target) + os.environ.setdefault("NEURON_CC_FLAGS", f"--target {args.target} --lnc {args.lnc}") + os.environ.setdefault("NEURON_RT_VISIBLE_CORES", args.visible_cores) + + inspect_dir = Path(args.inspect_dir).expanduser().resolve() + if args.inspect: + inspect_dir.mkdir(parents=True, exist_ok=True) + os.environ["NEURON_RT_INSPECT_ENABLE"] = "1" + os.environ["NEURON_RT_INSPECT_DEVICE_PROFILE"] = "1" + os.environ["NEURON_RT_INSPECT_SYSTEM_PROFILE"] = "0" + os.environ["NEURON_RT_INSPECT_OUTPUT_DIR"] = str(inspect_dir) + os.environ["XLA_IR_DEBUG"] = "1" + os.environ["XLA_HLO_DEBUG"] = "1" + os.environ["NEURON_FRAMEWORK_DEBUG"] = "1" + if args.dge: + os.environ["NEURON_RT_ENABLE_DGE_NOTIFICATIONS"] = "1" + return inspect_dir + + +def add_qwen_to_path() -> None: + script_path = Path(__file__).resolve() + qwen_root = script_path.parents[1] + sys.path.insert(0, str(qwen_root)) + + +def load_step_kernel(): + kernel_path = ( + Path(__file__).resolve().parents[1] + / "src" + / "nki_kernels" + / "nki_deltanet.py" + ) + spec = importlib.util.spec_from_file_location( + "qwen36_nki_deltanet_step_under_test", + kernel_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module.deltanet_recurrent_step_batched + + +def make_inputs(torch: Any, args: argparse.Namespace) -> dict[str, Any]: + generator = torch.Generator(device="cpu") + generator.manual_seed(args.seed) + if args.batch_heads <= 0: + raise ValueError("--batch-heads must be positive") + + def randn(shape: tuple[int, ...], scale: float) -> Any: + return torch.randn(shape, generator=generator, dtype=torch.float32) * scale + + query = randn((args.batch_heads, P_MAX), args.value_scale) + key = randn((args.batch_heads, P_MAX), args.value_scale) + value = randn((args.batch_heads, P_MAX), args.value_scale) + state_in = randn((args.batch_heads * P_MAX, P_MAX), args.state_scale) + + query = torch.nn.functional.normalize(query, p=2, dim=-1) / math.sqrt(P_MAX) + key = torch.nn.functional.normalize(key, p=2, dim=-1) + + beta = torch.sigmoid(randn((args.batch_heads, 1), 1.0)).contiguous() + + g = ( + -torch.nn.functional.softplus(randn((args.batch_heads, 1), 1.0)) + * args.gate_scale + ).contiguous() + + return { + "query": query.contiguous(), + "key": key.contiguous(), + "value": value.contiguous(), + "g": g, + "beta": beta, + "state_in": state_in.contiguous(), + } + + +def reference_math(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any]: + outputs = [] + states = [] + batch_heads = inputs["query"].shape[0] + + for bh in range(batch_heads): + q = inputs["query"][bh] + k = inputs["key"][bh] + v = inputs["value"][bh] + g = inputs["g"][bh].reshape(1, 1) + beta = inputs["beta"][bh] + state = inputs["state_in"][bh * P_MAX : (bh + 1) * P_MAX] + + state_decayed = state * torch.exp(g) + kv_mem = (state_decayed * k.unsqueeze(-1)).sum(dim=0) + delta = (v - kv_mem) * beta + state_out = state_decayed + k.unsqueeze(-1) * delta.unsqueeze(0) + output = (state_out * q.unsqueeze(-1)).sum(dim=0) + outputs.append(output) + states.append(state_out) + + return torch.stack(outputs, dim=0).contiguous(), torch.cat(states, dim=0).contiguous() + + +def tensor_metrics(torch: Any, actual: Any, expected: Any) -> dict[str, float | bool]: + diff = actual - expected + expected_norm = torch.linalg.vector_norm(expected).item() + diff_norm = torch.linalg.vector_norm(diff).item() + actual_flat = actual.reshape(-1).to(torch.float64) + expected_flat = expected.reshape(-1).to(torch.float64) + denom = torch.linalg.vector_norm(actual_flat) * torch.linalg.vector_norm( + expected_flat + ) + cosine = ( + float(torch.dot(actual_flat, expected_flat) / denom) + if denom.item() != 0.0 + else float("nan") + ) + return { + "finite": bool(torch.isfinite(actual).all().item()), + "max_abs": float(torch.max(torch.abs(diff)).item()), + "mean_abs": float(torch.mean(torch.abs(diff)).item()), + "diff_norm": float(diff_norm), + "expected_norm": float(expected_norm), + "relative_norm": float(diff_norm / max(expected_norm, 1.0e-12)), + "cosine": cosine, + } + + +def main() -> int: + args = parse_args() + inspect_dir = configure_environment(args) + add_qwen_to_path() + + import torch + import torch_xla.core.xla_model as xm + + deltanet_recurrent_step_batched = load_step_kernel() + + inputs = make_inputs(torch, args) + ref_out, ref_state = reference_math(torch, inputs) + + device = xm.xla_device() + xla_inputs = {name: tensor.to(device=device) for name, tensor in inputs.items()} + + out_cpu = state_cpu = None + for _ in range(args.runs): + out_dev, state_dev = deltanet_recurrent_step_batched( + xla_inputs["query"], + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["g"], + xla_inputs["beta"], + xla_inputs["state_in"], + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + state_cpu = state_dev.detach().cpu().float() + + assert out_cpu is not None + assert state_cpu is not None + + output_close = torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol) + state_close = torch.allclose(state_cpu, ref_state, atol=args.atol, rtol=args.rtol) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + state_finite = bool(torch.isfinite(state_cpu).all().item()) + passed = bool(output_close and state_close and output_finite and state_finite) + + result = { + "passed": passed, + "seed": args.seed, + "runs": args.runs, + "batch_heads": args.batch_heads, + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "output_finite": output_finite, + "state_finite": state_finite, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_OUTPUT_DIR", + ) + }, + "metrics": { + "output": tensor_metrics(torch, out_cpu, ref_out), + "state": tensor_metrics(torch, state_cpu, ref_state), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_qwen_segcte_attention.py b/contrib/models/Qwen3.6-27B/scripts/validate_qwen_segcte_attention.py new file mode 100644 index 00000000..5a7dac90 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/validate_qwen_segcte_attention.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 +"""Validate Qwen head_dim=256 segmented CTE attention against CPU math. + +The CPU reference stays off XLA so the generated NEFF is the NKI kernel under +test, matching the NKI debugging workflow. +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class Case: + prior_len: int + active_real_len: int + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Validate qwen_segcte256 segmented CTE attention." + ) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--q-heads", type=int, default=6) + parser.add_argument("--kv-heads", type=int, default=1) + parser.add_argument("--head-dim", type=int, default=256) + parser.add_argument("--q-len", type=int, default=512) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--prior-seg-size", type=int, default=512) + parser.add_argument( + "--cases", + default="0:512,512:512,1024:512,1024:201", + help="Comma-separated prior:real-active cases. q_len stays padded.", + ) + parser.add_argument("--target", default="trn2") + parser.add_argument("--lnc", type=int, default=2) + parser.add_argument("--visible-cores", default="0") + parser.add_argument("--scale", type=float, default=0.12) + parser.add_argument( + "--reference-score-scale", + type=float, + default=1.0, + help="Extra multiplier applied only to CPU-reference attention scores.", + ) + parser.add_argument("--value-scale", type=float, default=1.0) + parser.add_argument( + "--value-pattern", + choices=("random", "ones", "token"), + default="random", + help="V pattern for diagnostics. ones should return ones for any valid softmax.", + ) + parser.add_argument( + "--identity-block-table", + action="store_true", + help="Use logical block i -> physical block i instead of random physical IDs.", + ) + parser.add_argument( + "--pad-block-table-to-q-len", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Size the block table for prior_len + padded q_len, matching the " + "segmented CTE serving wrapper. The CPU comparison still uses only " + "active_real_len tokens." + ), + ) + parser.add_argument( + "--normalize-qk", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Generate model-faithful Q/K: l2-normalized K and l2-normalized " + "Q divided by sqrt(head_dim), matching Qwen qk-norm attention." + ), + ) + parser.add_argument("--atol", type=float, default=5.0e-2) + parser.add_argument("--rtol", type=float, default=5.0e-2) + parser.add_argument("--fail-on-mismatch", action="store_true") + return parser.parse_args() + + +def configure_environment(args: argparse.Namespace) -> None: + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", args.target) + os.environ.setdefault("NEURON_CC_FLAGS", f"--target {args.target} --lnc {args.lnc}") + os.environ.setdefault("NEURON_RT_VISIBLE_CORES", args.visible_cores) + + +def parse_cases(raw: str, q_len: int) -> list[Case]: + cases: list[Case] = [] + for item in raw.split(","): + item = item.strip() + if not item: + continue + prior_raw, active_raw = item.split(":", 1) + case = Case(prior_len=int(prior_raw), active_real_len=int(active_raw)) + if case.prior_len < 0: + raise ValueError(f"prior_len must be non-negative: {case}") + if case.active_real_len <= 0 or case.active_real_len > q_len: + raise ValueError(f"active_real_len must be in 1..q_len: {case}") + cases.append(case) + if not cases: + raise ValueError("--cases produced no cases") + return cases + + +def make_case_tensors(torch: Any, args: argparse.Namespace, case: Case) -> dict[str, Any]: + generator = torch.Generator(device="cpu") + generator.manual_seed(args.seed + case.prior_len * 1009 + case.active_real_len) + + batch = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + q_len = args.q_len + block = args.block_size + head_dim = args.head_dim + total_padded_len = case.prior_len + q_len + real_total_len = case.prior_len + case.active_real_len + needed_blocks = (real_total_len + block - 1) // block + padded_blocks = (total_padded_len + block - 1) // block + block_table_blocks = padded_blocks if args.pad_block_table_to_q_len else needed_blocks + physical_blocks = max(block_table_blocks + 7, padded_blocks + 3, 8) + + q_raw = torch.randn( + (batch * q_heads, q_len, head_dim), + generator=generator, + dtype=torch.float32, + ) + k_raw = torch.randn( + (batch, kv_heads, total_padded_len, head_dim), + generator=generator, + dtype=torch.float32, + ) + if args.normalize_qk: + q = torch.nn.functional.normalize(q_raw, p=2, dim=-1, eps=1.0e-6) + q = q * (args.scale / math.sqrt(head_dim)) + logical_k = torch.nn.functional.normalize( + k_raw, + p=2, + dim=-1, + eps=1.0e-6, + ) + logical_k = logical_k * args.scale + else: + q = q_raw * (args.scale / math.sqrt(head_dim)) + logical_k = k_raw * args.scale + if args.value_pattern == "ones": + logical_v = torch.ones( + (batch, kv_heads, total_padded_len, head_dim), + dtype=torch.float32, + ) * args.value_scale + elif args.value_pattern == "token": + token_values = torch.arange(total_padded_len, dtype=torch.float32) + token_values = token_values.view(1, 1, total_padded_len, 1) + logical_v = token_values.expand(batch, kv_heads, -1, head_dim).contiguous() + logical_v = logical_v / max(1.0, float(total_padded_len)) * args.value_scale + else: + logical_v = torch.randn( + (batch, kv_heads, total_padded_len, head_dim), + generator=generator, + dtype=torch.float32, + ) * args.value_scale + + block_table = torch.empty((batch, block_table_blocks), dtype=torch.int32) + k_cache = torch.zeros( + (physical_blocks, kv_heads, block, head_dim), + dtype=torch.float32, + ) + v_cache = torch.zeros_like(k_cache) + + for b in range(batch): + if args.identity_block_table: + perm = torch.arange(block_table_blocks, dtype=torch.int64) + else: + # Exercise indirect block-table reads instead of the identity layout. + perm = torch.randperm(physical_blocks, generator=generator)[:block_table_blocks] + block_table[b] = perm.to(torch.int32) + for logical_block, physical_block_t in enumerate(perm.tolist()): + start = logical_block * block + end = min(start + block, total_padded_len) + width = end - start + if width <= 0: + continue + k_cache[physical_block_t, :, :width, :] = logical_k[b, :, start:end, :] + v_cache[physical_block_t, :, :width, :] = logical_v[b, :, start:end, :] + + return { + "q": q.contiguous(), + "logical_k": logical_k.contiguous(), + "logical_v": logical_v.contiguous(), + "k_cache": k_cache.contiguous(), + "v_cache": v_cache.contiguous(), + "block_table": block_table.contiguous(), + "prior_tokens": torch.full((batch, 1), case.prior_len, dtype=torch.int32), + } + + +def cpu_reference(torch: Any, tensors: dict[str, Any], args: argparse.Namespace, case: Case) -> Any: + batch = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + q_len = args.q_len + active_real = case.active_real_len + compare_len = active_real + output = torch.empty( + (batch * q_heads, compare_len, args.head_dim), + dtype=torch.float32, + ) + + key_positions = torch.arange( + case.prior_len + active_real, + dtype=torch.int64, + ).view(1, -1) + query_positions = ( + case.prior_len + torch.arange(active_real, dtype=torch.int64) + ).view(-1, 1) + causal = key_positions <= query_positions + + for b in range(batch): + for qh in range(q_heads): + flat_head = b * q_heads + qh + kvh = qh * kv_heads // q_heads + q = tensors["q"][flat_head, :active_real, :].to(torch.bfloat16).float() + k = tensors["logical_k"][ + b, + kvh, + : case.prior_len + active_real, + :, + ].to(torch.bfloat16).float() + v = tensors["logical_v"][ + b, + kvh, + : case.prior_len + active_real, + :, + ].to(torch.bfloat16).float() + scores = q @ k.T + scores = scores * args.reference_score_scale + scores = scores.masked_fill(~causal, -float("inf")) + probs = torch.softmax(scores, dim=-1) + output[flat_head] = probs @ v + return output + + +def metrics(torch: Any, actual: Any, expected: Any, args: argparse.Namespace) -> dict[str, Any]: + actual_f = actual.detach().float() + expected_f = expected.detach().float() + diff = actual_f - expected_f + diff_norm = torch.linalg.vector_norm(diff) + expected_norm = torch.linalg.vector_norm(expected_f) + rel_norm = diff_norm / expected_norm.clamp_min(1.0e-12) + actual_flat = actual_f.reshape(-1) + expected_flat = expected_f.reshape(-1) + cosine = torch.nn.functional.cosine_similarity( + actual_flat, + expected_flat, + dim=0, + eps=1.0e-12, + ) + allclose = torch.allclose(actual_f, expected_f, atol=args.atol, rtol=args.rtol) + return { + "allclose": bool(allclose), + "max_abs": float(diff.abs().max().item()), + "rel_norm": float(rel_norm.item()), + "cosine": float(cosine.item()), + "actual_min": float(actual_f.min().item()), + "actual_max": float(actual_f.max().item()), + "actual_mean": float(actual_f.mean().item()), + "expected_min": float(expected_f.min().item()), + "expected_max": float(expected_f.max().item()), + "expected_mean": float(expected_f.mean().item()), + } + + +def main() -> int: + args = parse_args() + configure_environment(args) + + import nki + import torch + from torch_xla.core import xla_model as xm + + from neuronx_distributed_inference.modules.attention.nki_kernels.qwen_segcte256.attention_segmented_cte_256 import ( + attention_segmented_cte, + ) + + if args.head_dim != 256: + raise ValueError("qwen_segcte256 validation must use --head-dim 256") + if args.q_heads % args.kv_heads != 0: + raise ValueError("--q-heads must be divisible by --kv-heads") + if args.q_len % args.block_size != 0 or args.q_len % 128 != 0: + raise ValueError("--q-len must be divisible by block size and 128") + if args.prior_seg_size % args.block_size != 0: + raise ValueError("--prior-seg-size must be divisible by --block-size") + + cases = parse_cases(args.cases, args.q_len) + device = xm.xla_device() + kernel = nki.jit(attention_segmented_cte) + failures: list[dict[str, Any]] = [] + rows: list[dict[str, Any]] = [] + + for case in cases: + tensors = make_case_tensors(torch, args, case) + expected = cpu_reference(torch, tensors, args, case) + q = tensors["q"].to(torch.bfloat16).to(device=device) + k_cache = tensors["k_cache"].to(torch.bfloat16).to(device=device) + v_cache = tensors["v_cache"].to(torch.bfloat16).to(device=device) + block_table = tensors["block_table"].to(torch.int32).to(device=device) + prior_tokens = tensors["prior_tokens"].to(torch.int32).to(device=device) + + launch = kernel[args.lnc] if args.lnc > 1 else kernel + actual_full = launch( + q, + k_cache, + v_cache, + block_table, + prior_tokens, + args.block_size, + args.prior_seg_size, + 1.0, + tp_q=True, + tp_out=False, + sliding_window=None, + sink=None, + num_q_heads=args.q_heads, + k_pre_transposed=False, + ).cpu() + actual = actual_full[:, : case.active_real_len, :] + row = { + "case": { + "prior_len": case.prior_len, + "active_real_len": case.active_real_len, + "q_len": args.q_len, + }, + "block_table_shape": list(tensors["block_table"].shape), + "k_cache_shape": list(tensors["k_cache"].shape), + "metrics": metrics(torch, actual, expected, args), + } + rows.append(row) + print(json.dumps(row, sort_keys=True), flush=True) + if not row["metrics"]["allclose"]: + failures.append(row) + + summary = { + "ok": not failures, + "num_cases": len(rows), + "num_failures": len(failures), + "args": vars(args), + } + print(json.dumps({"summary": summary}, sort_keys=True), flush=True) + if failures and args.fail_on_mismatch: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/src/__init__.py b/contrib/models/Qwen3.6-27B/src/__init__.py new file mode 100644 index 00000000..7e79aa03 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/__init__.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_qwen35 import ( + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35DecoderModelInstance, + Qwen35InferenceConfig, + Qwen35MLP, + Qwen35ModelWrapper, +) +from src.modeling_qwen35_vision import ( + NeuronQwen35VisionForImageEncoding, + NeuronQwen35VisionModel, +) +from src.modeling_qwen35_vl import ( + NeuronQwen35VLForCausalLM, + Qwen35VLInferenceConfig, +) + +__all__ = [ + # Text decoder + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "NeuronQwen35ForCausalLM", + "NeuronQwen35Model", + "Qwen35DecoderModelInstance", + "Qwen35InferenceConfig", + "Qwen35MLP", + "Qwen35ModelWrapper", + # Vision encoder + "NeuronQwen35VisionForImageEncoding", + "NeuronQwen35VisionModel", + # Vision-language + "NeuronQwen35VLForCausalLM", + "Qwen35VLInferenceConfig", +] diff --git a/contrib/models/Qwen3.6-27B/src/hybrid_apc.py b/contrib/models/Qwen3.6-27B/src/hybrid_apc.py new file mode 100644 index 00000000..f1304a73 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/hybrid_apc.py @@ -0,0 +1,1798 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Qwen hybrid APC metadata lifecycle. + +This module intentionally stores only control-plane metadata. GDN recurrent and +conv checkpoint tensors live in the model-side checkpoint bank; the metadata +store owns prefix identity, validity, refcounts, LRU state, and memory +accounting. +""" + +from __future__ import annotations + +import hashlib +import os +import struct +from collections import OrderedDict, deque +from dataclasses import dataclass +from typing import Callable, Hashable, Iterable, NamedTuple + +import torch + + +class HybridPrefixKey(NamedTuple): + cumulative_prefix_hash: Hashable + prefix_len: int + block_size: int + cache_salt: Hashable | None + model_revision: str + layout_version: int + tp_rank: int + recurrent_dtype: str + conv_dtype: str + + +class HybridAPCHitPlan(NamedTuple): + attention_hit_len: int + recurrent_hit_len: int + conv_hit_len: int + usable_hit_len: int + restore_checkpoint_prefix_len: int + residual_replay_len: int + suffix_len: int + checkpoint_slot: int | None + checkpoint_key: HybridPrefixKey | None + + +class HybridAPCPreparedRequest(NamedTuple): + request_id: Hashable + input_dict: dict[str, torch.Tensor] + plan: HybridAPCHitPlan + commit_prefix_len: int + commit_key: HybridPrefixKey | None + commit_slot: int | None + attention_block_refs: tuple[int, ...] + + +@dataclass +class HybridAPCStats: + checkpoints: int = 0 + bytes_used: int = 0 + evictions: int = 0 + hits: int = 0 + misses: int = 0 + + +@dataclass +class HybridAPCRequestRecord: + request_id: Hashable + state: str + restored_key: HybridPrefixKey | None = None + committed_keys: list[HybridPrefixKey] | None = None + reserved_slots: list[int] | None = None + + def __post_init__(self): + if self.committed_keys is None: + self.committed_keys = [] + if self.reserved_slots is None: + self.reserved_slots = [] + + +@dataclass +class HybridPrefixCheckpoint: + key: HybridPrefixKey + prefix_len: int + attention_block_refs: tuple[int, ...] + gdn_checkpoint_slot: int + valid_recurrent_layers: torch.Tensor + valid_conv_layers: torch.Tensor + refcount: int = 0 + last_access_step: int = 0 + bytes_used: int = 0 + evictable: bool = True + attention_valid: bool = True + + def has_valid_recurrent(self, required_layers: tuple[int, ...]) -> bool: + return _mask_has_layers(self.valid_recurrent_layers, required_layers) + + def has_valid_conv(self, required_layers: tuple[int, ...]) -> bool: + return _mask_has_layers(self.valid_conv_layers, required_layers) + + def has_valid_gdn(self, required_layers: tuple[int, ...]) -> bool: + return self.has_valid_recurrent(required_layers) and self.has_valid_conv( + required_layers + ) + + def has_valid_hybrid_state(self, required_layers: tuple[int, ...]) -> bool: + return self.attention_valid and self.has_valid_gdn(required_layers) + + +def _normalize_dtype(dtype: str | torch.dtype) -> str: + if dtype == torch.float32: + return "float32" + if dtype == torch.bfloat16: + return "bfloat16" + normalized = str(dtype).lower() + aliases = { + "fp32": "float32", + "float32": "float32", + "torch.float32": "float32", + "bf16": "bfloat16", + "bfloat16": "bfloat16", + "torch.bfloat16": "bfloat16", + } + if normalized not in aliases: + raise ValueError(f"unsupported hybrid APC dtype: {dtype}") + return aliases[normalized] + + +def _mask_has_layers(mask: torch.Tensor, required_layers: tuple[int, ...]) -> bool: + if mask.numel() == 0: + return False + for layer in required_layers: + if layer >= mask.numel() or not bool(mask[layer].item()): + return False + return True + + +def _env_flag(name: str) -> bool: + value = os.environ.get(name) + return value is not None and value.strip().lower() not in { + "", + "0", + "false", + "no", + "off", + } + + +def _publish_scheduler_gdn_checkpoint(key): + try: + from qwen36_hybrid_apc_scheduler_patch import ( # noqa: WPS433 + register_hybrid_apc_gdn_checkpoint, + ) + except Exception: + return + try: + register_hybrid_apc_gdn_checkpoint(key) + except Exception: + return + + +def _unpublish_scheduler_gdn_checkpoint(key): + try: + from qwen36_hybrid_apc_scheduler_patch import ( # noqa: WPS433 + unregister_hybrid_apc_gdn_checkpoint, + ) + except Exception: + return + try: + unregister_hybrid_apc_gdn_checkpoint(key) + except Exception: + return + + +def estimate_qwen_gdn_checkpoint_bytes_per_rank( + *, + num_gdn_layers: int = 48, + local_value_heads: int = 12, + local_key_heads: int = 4, + key_dim: int = 128, + value_dim: int = 128, + conv_kernel_size: int = 4, + recurrent_dtype: str | torch.dtype = "float32", + conv_dtype: str | torch.dtype = "bfloat16", +) -> int: + recurrent_dtype = _normalize_dtype(recurrent_dtype) + conv_dtype = _normalize_dtype(conv_dtype) + recurrent_bytes = 4 if recurrent_dtype == "float32" else 2 + conv_bytes = 4 if conv_dtype == "float32" else 2 + recurrent_numel = num_gdn_layers * local_value_heads * key_dim * value_dim + conv_dim = 2 * local_key_heads * key_dim + local_value_heads * value_dim + conv_numel = num_gdn_layers * conv_dim * (conv_kernel_size - 1) + return recurrent_numel * recurrent_bytes + conv_numel * conv_bytes + + +def estimate_qwen_hybrid_cache_bytes_per_rank( + *, + max_context_len: int, + checkpoint_interval: int, + num_attention_layers: int = 16, + local_kv_heads: int = 1, + attention_head_dim: int = 256, + attention_kv_dtype: str | torch.dtype = "bfloat16", + **gdn_kwargs, +) -> dict[str, int]: + attention_dtype = _normalize_dtype(attention_kv_dtype) + attention_bytes = 4 if attention_dtype == "float32" else 2 + attention_kv = ( + int(max_context_len) + * num_attention_layers + * 2 + * local_kv_heads + * attention_head_dim + * attention_bytes + ) + checkpoints = max(0, int(max_context_len)) // int(checkpoint_interval) + gdn_per_checkpoint = estimate_qwen_gdn_checkpoint_bytes_per_rank(**gdn_kwargs) + gdn_total = checkpoints * gdn_per_checkpoint + return { + "attention_kv_bytes": attention_kv, + "gdn_checkpoint_bytes": gdn_total, + "gdn_bytes_per_checkpoint": gdn_per_checkpoint, + "num_gdn_checkpoints": checkpoints, + "total_bytes": attention_kv + gdn_total, + } + + +def _flatten_single_request_tokens(token_ids: torch.Tensor | Iterable[int]) -> torch.Tensor: + if isinstance(token_ids, torch.Tensor): + tokens = token_ids.detach().cpu() + else: + tokens = torch.tensor(list(token_ids), dtype=torch.int64) + if tokens.ndim == 2 and tokens.shape[0] == 1: + tokens = tokens.reshape(-1) + elif tokens.ndim != 1: + raise ValueError( + "token_ids must be a single request tensor with shape [seq] or [1, seq], " + f"got {tuple(tokens.shape)}" + ) + return tokens.to(torch.int64).contiguous() + + +def build_cumulative_prefix_hashes( + token_ids: torch.Tensor | Iterable[int], + *, + block_size: int, + prefix_lens: Iterable[int] | None = None, +) -> dict[int, str]: + """Build deterministic cumulative prefix hashes at block boundaries. + + This is a local scheduler bridge helper, not a replacement for vLLM's + production block hash. It deliberately hashes the parent digest plus the + next block's token bytes so a reused final block with a different parent + prefix produces a different cumulative hash. + """ + + block_size = int(block_size) + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}") + + tokens = _flatten_single_request_tokens(token_ids) + seq_len = int(tokens.numel()) + if prefix_lens is None: + requested_lens = set(range(block_size, seq_len + 1, block_size)) + else: + requested_lens = {int(prefix_len) for prefix_len in prefix_lens} + requested_lens = {prefix_len for prefix_len in requested_lens if prefix_len > 0} + for prefix_len in requested_lens: + if prefix_len > seq_len: + raise ValueError(f"prefix_len {prefix_len} exceeds token length {seq_len}") + if prefix_len % block_size != 0: + raise ValueError( + f"prefix_len {prefix_len} must be a multiple of block_size {block_size}" + ) + + if not requested_lens: + return {} + + max_prefix_len = max(requested_lens) + parent_digest = b"" + hashes: dict[int, str] = {} + for block_start in range(0, max_prefix_len, block_size): + block_end = block_start + block_size + block = tokens[block_start:block_end] + digest = hashlib.blake2b(digest_size=16) + digest.update(parent_digest) + digest.update(struct.pack(" int: + checkpoint_interval = int(checkpoint_interval) + if checkpoint_interval <= 0: + raise ValueError( + f"checkpoint_interval must be positive, got {checkpoint_interval}" + ) + return max(0, int(prefix_len)) // checkpoint_interval * checkpoint_interval + + +def apply_hybrid_apc_prefill_plan( + input_dict: dict[str, torch.Tensor], + *, + plan: HybridAPCHitPlan, + commit_slot: int | None = None, + request_prefix_len: int | None = None, + gdn_active_carry: bool = False, + block_size: int | None = None, +) -> dict[str, torch.Tensor]: + """Materialize model inputs for a scheduler-selected hybrid APC hit plan. + + The serving scheduler owns prefix hashing, attention block-table selection, + checkpoint lookup, and checkpoint-slot reservation. This helper only applies + the chosen restore boundary to the token tensors and emits explicit + restore/commit control tensors. GDN state is restored only when the plan has + a checkpoint slot; slot ID presence alone is never treated as a cache hit. + """ + + if "input_ids" not in input_dict: + raise KeyError("input_ids is required to apply a hybrid APC prefill plan") + + input_ids = input_dict["input_ids"] + if input_ids.ndim != 2: + raise ValueError(f"input_ids must be [batch, seq], got {tuple(input_ids.shape)}") + + batch_size, available_len = input_ids.shape + prompt_len = available_len if request_prefix_len is None else int(request_prefix_len) + restore_len = int(plan.restore_checkpoint_prefix_len) + if prompt_len < 0: + raise ValueError(f"request_prefix_len must be non-negative, got {prompt_len}") + if restore_len < 0 or restore_len > prompt_len: + raise ValueError( + "restore_checkpoint_prefix_len must be in [0, request_prefix_len], " + f"got {restore_len} and {prompt_len}" + ) + if prompt_len > available_len: + raise ValueError( + f"request_prefix_len {prompt_len} exceeds input_ids length {available_len}" + ) + if plan.checkpoint_slot is None and restore_len != 0: + raise ValueError("restore checkpoint prefix length requires a checkpoint slot") + if plan.checkpoint_slot is not None and restore_len == 0: + raise ValueError("checkpoint slot restore requires a positive prefix length") + + output = dict(input_dict) + suffix_len = prompt_len - restore_len + device = input_ids.device + + output["input_ids"] = input_ids[:, restore_len:prompt_len] + + attention_mask = input_dict.get("attention_mask") + if ( + isinstance(attention_mask, torch.Tensor) + and attention_mask.ndim >= 2 + and attention_mask.shape[0] == batch_size + and attention_mask.shape[1] >= prompt_len + ): + output["attention_mask"] = attention_mask[:, restore_len:prompt_len] + + inputs_embeds = input_dict.get("inputs_embeds") + if ( + isinstance(inputs_embeds, torch.Tensor) + and inputs_embeds.ndim >= 3 + and inputs_embeds.shape[0] == batch_size + and inputs_embeds.shape[1] >= prompt_len + ): + output["inputs_embeds"] = inputs_embeds[:, restore_len:prompt_len] + + def _slot_mapping_covers_suffix(value: torch.Tensor) -> bool: + if value.ndim == 1: + if batch_size == 1: + return int(value.numel()) >= suffix_len + return int(value.numel()) >= batch_size * suffix_len + if value.ndim >= 2: + return value.shape[0] >= batch_size and value.shape[1] >= suffix_len + return False + + def _slot_mapping_needs_repair(value) -> bool: + if not isinstance(value, torch.Tensor) or value.numel() == 0: + return True + if not _slot_mapping_covers_suffix(value): + return True + return bool((value.to(torch.int64) < 0).any().item()) + + unbacked_attention_hit = ( + plan.checkpoint_slot is None + and int(plan.attention_hit_len) > 0 + and restore_len == 0 + ) + + def _synthesize_suffix_slot_mapping() -> torch.Tensor | None: + if block_size is None or int(block_size) <= 0 or suffix_len <= 0: + return None + block_table = input_dict.get("block_table") + if not isinstance(block_table, torch.Tensor) or block_table.numel() == 0: + return None + table = block_table + if table.ndim == 1: + table = table.unsqueeze(0) + if table.ndim != 2 or table.shape[0] < batch_size: + return None + block_size_int = int(block_size) + positions = torch.arange( + restore_len, + prompt_len, + dtype=torch.int64, + device=table.device, + ) + logical_blocks = torch.div(positions, block_size_int, rounding_mode="floor") + if logical_blocks.numel() == 0 or int(logical_blocks.max().item()) >= table.shape[1]: + return None + offsets = positions.remainder(block_size_int) + rows = [] + table_i64 = table.to(torch.int64) + for batch_idx in range(batch_size): + physical_blocks = torch.index_select( + table_i64[batch_idx], + 0, + logical_blocks, + ) + rows.append(physical_blocks * block_size_int + offsets) + return torch.stack(rows, dim=0) + + slot_mapping = input_dict.get("slot_mapping") + if ( + isinstance(slot_mapping, torch.Tensor) + and slot_mapping.ndim >= 2 + and slot_mapping.shape[0] == batch_size + and slot_mapping.shape[1] >= prompt_len + ): + output["slot_mapping"] = slot_mapping[:, restore_len:prompt_len] + elif isinstance(slot_mapping, torch.Tensor) and slot_mapping.ndim == 1: + if batch_size == 1 and slot_mapping.numel() >= prompt_len: + output["slot_mapping"] = slot_mapping[restore_len:prompt_len] + elif slot_mapping.numel() >= batch_size * prompt_len: + flattened = slot_mapping.reshape(batch_size, -1) + output["slot_mapping"] = flattened[:, restore_len:prompt_len] + if unbacked_attention_hit: + synthesized_slot_mapping = _synthesize_suffix_slot_mapping() + if synthesized_slot_mapping is not None: + dtype = ( + slot_mapping.dtype + if isinstance(slot_mapping, torch.Tensor) + else torch.int32 + ) + output["slot_mapping"] = synthesized_slot_mapping.to(dtype=dtype) + elif _slot_mapping_needs_repair(output.get("slot_mapping")): + synthesized_slot_mapping = _synthesize_suffix_slot_mapping() + if synthesized_slot_mapping is not None: + dtype = ( + slot_mapping.dtype + if isinstance(slot_mapping, torch.Tensor) + else torch.int32 + ) + repaired_slot_mapping = synthesized_slot_mapping.to(dtype=dtype) + current_slot_mapping = output.get("slot_mapping") + if ( + isinstance(current_slot_mapping, torch.Tensor) + and current_slot_mapping.numel() == repaired_slot_mapping.numel() + ): + repaired_slot_mapping = torch.where( + current_slot_mapping.to(torch.int64) < 0, + repaired_slot_mapping.reshape(current_slot_mapping.shape), + current_slot_mapping, + ) + output["slot_mapping"] = repaired_slot_mapping + + position_template = input_dict.get("position_ids") + position_dtype = ( + position_template.dtype + if isinstance(position_template, torch.Tensor) + else torch.int64 + ) + position_ids = torch.arange( + restore_len, + prompt_len, + dtype=position_dtype, + device=device, + ).unsqueeze(0) + output["position_ids"] = position_ids.expand(batch_size, suffix_len).contiguous() + default_rotary_positions = torch.arange( + restore_len, + prompt_len, + dtype=torch.int32, + device=device, + ) + output["rotary_position_ids"] = default_rotary_positions.view( + 1, + 1, + suffix_len, + ).expand(3, batch_size, suffix_len).contiguous() + + for key in ("rotary_position_id", "rotary_position_ids"): + value = input_dict.get(key) + if not isinstance(value, torch.Tensor): + continue + if ( + value.ndim == 2 + and value.shape[0] == batch_size + and value.shape[1] >= prompt_len + ): + output[key] = value[:, restore_len:prompt_len] + elif ( + value.ndim == 3 + and value.shape[1] == batch_size + and value.shape[2] >= prompt_len + ): + output[key] = value[:, :, restore_len:prompt_len] + + def _batch_i32(value: int) -> torch.Tensor: + return torch.full((batch_size,), int(value), dtype=torch.int32, device=device) + + def _batch_i32_col(value: int) -> torch.Tensor: + return torch.full((batch_size, 1), int(value), dtype=torch.int32, device=device) + + disable_restore = _env_flag("QWEN36_DISABLE_HYBRID_GDN_RESTORE") + disable_commit = _env_flag("QWEN36_DISABLE_HYBRID_GDN_COMMIT") + restore_available = plan.checkpoint_slot is not None and not disable_restore + restore_enabled = restore_available and not gdn_active_carry + commit_enabled = commit_slot is not None and not disable_commit + output["computed_context_lens"] = _batch_i32_col(restore_len) + output["full_context_lens"] = _batch_i32_col(prompt_len) + output["num_queries"] = _batch_i32_col(suffix_len) + output["hybrid_restore_slot_ids"] = _batch_i32( + 0 if not restore_available else int(plan.checkpoint_slot) + ) + output["hybrid_restore_mask"] = _batch_i32(1 if restore_enabled else 0) + output["hybrid_restore_prefix_lens"] = _batch_i32( + restore_len if restore_available else 0 + ) + output["hybrid_commit_slot_ids"] = _batch_i32( + 0 if not commit_enabled else commit_slot + ) + output["hybrid_commit_mask"] = _batch_i32(1 if commit_enabled else 0) + + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] apply " + f"prompt_len={prompt_len} restore_len={restore_len} " + f"suffix_len={suffix_len} restore_slot={plan.checkpoint_slot} " + f"commit_slot={commit_slot} gdn_active_carry={gdn_active_carry} " + f"input_shape={tuple(input_ids.shape)} " + f"output_shape={tuple(output['input_ids'].shape)}", + flush=True, + ) + + return output + + +def apply_hybrid_apc_suffix_prefill_plan( + input_dict: dict[str, torch.Tensor], + *, + plan: HybridAPCHitPlan, + request_prefix_len: int, + commit_slot: int | None = None, + attention_block_refs: Iterable[int] | None = None, + gdn_active_carry: bool = False, +) -> dict[str, torch.Tensor]: + """Materialize Hybrid APC controls when vLLM already sliced to suffix. + + This diagnostic path is used only when the caller explicitly allows an + unhashed single-checkpoint restore. The input tokens are already the active + suffix, so this helper must not slice token tensors by ``restore_len``. + """ + + if "input_ids" not in input_dict: + raise KeyError("input_ids is required to apply a hybrid APC suffix plan") + + input_ids = input_dict["input_ids"] + if input_ids.ndim != 2: + raise ValueError(f"input_ids must be [batch, seq], got {tuple(input_ids.shape)}") + + batch_size, suffix_len = input_ids.shape + prompt_len = int(request_prefix_len) + restore_len = int(plan.restore_checkpoint_prefix_len) + expected_suffix_len = prompt_len - restore_len + if plan.checkpoint_slot is None or restore_len <= 0: + raise ValueError("suffix-only Hybrid APC restore requires a checkpoint slot") + if expected_suffix_len != suffix_len: + raise ValueError( + "suffix-only Hybrid APC input length mismatch: " + f"expected {expected_suffix_len}, got {suffix_len}" + ) + + output = dict(input_dict) + device = input_ids.device + output["input_ids"] = input_ids + refs: tuple[int, ...] = () + if attention_block_refs is not None: + refs = tuple(int(ref) for ref in attention_block_refs) + if refs: + block_table_template = input_dict.get("block_table") + refs_table = torch.tensor( + [refs] * batch_size, + dtype=torch.int32, + device=device, + ) + has_block_table = ( + isinstance(block_table_template, torch.Tensor) + and block_table_template.numel() > 0 + and block_table_template.ndim >= 2 + and block_table_template.shape[0] >= batch_size + ) + if has_block_table: + active_table = block_table_template[:batch_size].to( + dtype=torch.int32, + device=device, + ) + if active_table.shape[1] > len(refs): + suffix_table = active_table[:, len(refs) :] + else: + suffix_table = active_table + output["block_table"] = torch.cat([refs_table, suffix_table], dim=1) + else: + output["block_table"] = refs_table + + position_template = input_dict.get("position_ids") + position_dtype = ( + position_template.dtype + if isinstance(position_template, torch.Tensor) + else torch.int64 + ) + position_ids = torch.arange( + restore_len, + prompt_len, + dtype=position_dtype, + device=device, + ).unsqueeze(0) + output["position_ids"] = position_ids.expand(batch_size, suffix_len).contiguous() + default_rotary_positions = torch.arange( + restore_len, + prompt_len, + dtype=torch.int32, + device=device, + ) + output["rotary_position_ids"] = default_rotary_positions.view( + 1, + 1, + suffix_len, + ).expand(3, batch_size, suffix_len).contiguous() + + for key in ("rotary_position_id", "rotary_position_ids"): + value = input_dict.get(key) + if not isinstance(value, torch.Tensor): + continue + rotary_positions = torch.arange( + restore_len, + prompt_len, + dtype=value.dtype, + device=device, + ) + if value.ndim == 2: + output[key] = rotary_positions.unsqueeze(0).expand( + batch_size, + suffix_len, + ).contiguous() + elif value.ndim == 3: + output[key] = rotary_positions.view(1, 1, suffix_len).expand( + value.shape[0], + batch_size, + suffix_len, + ).contiguous() + + def _batch_i32(value: int) -> torch.Tensor: + return torch.full((batch_size,), int(value), dtype=torch.int32, device=device) + + def _batch_i32_col(value: int) -> torch.Tensor: + return torch.full((batch_size, 1), int(value), dtype=torch.int32, device=device) + + output["computed_context_lens"] = _batch_i32_col(restore_len) + output["full_context_lens"] = _batch_i32_col(prompt_len) + output["num_queries"] = _batch_i32_col(suffix_len) + output["hybrid_restore_slot_ids"] = _batch_i32(int(plan.checkpoint_slot)) + output["hybrid_restore_mask"] = _batch_i32(0 if gdn_active_carry else 1) + output["hybrid_restore_prefix_lens"] = _batch_i32(restore_len) + output["hybrid_commit_slot_ids"] = _batch_i32(0 if commit_slot is None else commit_slot) + output["hybrid_commit_mask"] = _batch_i32(1 if commit_slot is not None else 0) + + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] apply-suffix " + f"prompt_len={prompt_len} restore_len={restore_len} " + f"suffix_len={suffix_len} restore_slot={plan.checkpoint_slot} " + f"commit_slot={commit_slot} " + f"gdn_active_carry={gdn_active_carry} " + f"attention_block_refs={refs} " + f"input_shape={tuple(input_ids.shape)}", + flush=True, + ) + + return output + + +class HybridAPCSlotAllocator: + """Small checkpoint-slot allocator for local scheduler integration tests.""" + + def __init__(self, num_slots: int): + num_slots = int(num_slots) + if num_slots <= 0: + raise ValueError(f"num_slots must be positive, got {num_slots}") + self.num_slots = num_slots + self._free = deque(range(num_slots)) + self._reserved: set[int] = set() + self._committed: set[int] = set() + + @property + def free_slots(self) -> tuple[int, ...]: + return tuple(self._free) + + @property + def reserved_slots(self) -> tuple[int, ...]: + return tuple(sorted(self._reserved)) + + @property + def committed_slots(self) -> tuple[int, ...]: + return tuple(sorted(self._committed)) + + def reserve(self) -> int: + if not self._free: + raise RuntimeError("no hybrid APC checkpoint slots available") + slot = int(self._free.popleft()) + self._reserved.add(slot) + return slot + + def mark_committed(self, slot: int): + slot = int(slot) + self.validate_slot_range(slot) + if slot not in self._reserved and slot not in self._committed: + raise ValueError(f"hybrid APC checkpoint slot {slot} is not reserved") + self._reserved.discard(slot) + self._committed.add(slot) + + def release(self, slot: int): + slot = int(slot) + self.validate_slot_range(slot) + was_known = slot in self._reserved or slot in self._committed + self._reserved.discard(slot) + self._committed.discard(slot) + if was_known and slot not in self._free: + self._free.append(slot) + + def release_committed(self, slot: int) -> bool: + slot = int(slot) + self.validate_slot_range(slot) + if slot in self._reserved: + return False + was_committed = slot in self._committed + self._committed.discard(slot) + if was_committed and slot not in self._free: + self._free.append(slot) + return was_committed + + def validate_slot_range(self, slot: int): + slot = int(slot) + if slot < 0 or slot >= self.num_slots: + raise ValueError( + f"hybrid APC checkpoint slot {slot} is outside " + f"[0, {self.num_slots})" + ) + + +class HybridAPCSchedulerBridge: + """Local request-prep bridge for production hybrid APC scheduler wiring. + + The real vLLM/NxDI scheduler must supply the attention APC hit length, + active attention block refs, and tenant/cache metadata. This bridge performs + the Qwen hybrid-specific part: intersect attention hits with GDN checkpoint + metadata, materialize suffix model inputs, reserve a GDN checkpoint slot, + and commit checkpoint metadata after a successful prefill. + """ + + def __init__( + self, + *, + store: "HybridAPCMetadataStore", + slot_allocator: HybridAPCSlotAllocator, + cache_salt: Hashable | None = None, + model_revision: str | None = None, + layout_version: int | None = None, + tp_rank: int | None = None, + recurrent_dtype: str | torch.dtype | None = None, + conv_dtype: str | torch.dtype | None = None, + allow_local_hash_fallback: bool = True, + require_attention_block_refs: bool = False, + reject_unbacked_attention_hits: bool = True, + ): + self.store = store + self.slot_allocator = slot_allocator + self.cache_salt = cache_salt + self.model_revision = model_revision + self.layout_version = layout_version + self.tp_rank = tp_rank + self.recurrent_dtype = recurrent_dtype + self.conv_dtype = conv_dtype + self.allow_local_hash_fallback = bool(allow_local_hash_fallback) + self.require_attention_block_refs = bool(require_attention_block_refs) + self.reject_unbacked_attention_hits = bool(reject_unbacked_attention_hits) + self._same_request_committed_keys: dict[Hashable, set[HybridPrefixKey]] = {} + self.store.set_checkpoint_slot_releaser( + self.slot_allocator.release_committed + ) + + @property + def requires_external_metadata(self) -> bool: + return ( + not self.allow_local_hash_fallback + or self.require_attention_block_refs + ) + + def prepare_request( + self, + *, + request_id: Hashable, + input_dict: dict[str, torch.Tensor], + attention_hit_len: int, + request_prefix_len: int | None = None, + cumulative_hashes_by_prefix_len: dict[int, Hashable] | None = None, + attention_block_refs_by_prefix_len: dict[int, Iterable[int]] | None = None, + ) -> HybridAPCPreparedRequest: + if "input_ids" not in input_dict: + raise KeyError("input_ids is required for hybrid APC request prep") + input_ids = input_dict["input_ids"] + prompt_len = ( + int(input_ids.shape[1]) + if request_prefix_len is None + else int(request_prefix_len) + ) + commit_prefix_len = floor_to_checkpoint_boundary( + prompt_len, + self.store.checkpoint_interval, + ) + + if cumulative_hashes_by_prefix_len is None: + if not self.allow_local_hash_fallback: + raise ValueError( + "hybrid APC production mode requires vLLM cumulative prefix " + "hashes; set hybrid_apc_allow_local_hash_fallback=True only " + "for controlled local validation" + ) + cumulative_hashes_by_prefix_len = build_cumulative_prefix_hashes( + input_ids, + block_size=self.store.block_size, + ) + + plan = self.store.compute_hit_plan( + cumulative_hashes_by_prefix_len=cumulative_hashes_by_prefix_len, + attention_hit_len=attention_hit_len, + request_prefix_len=prompt_len, + cache_salt=self.cache_salt, + model_revision=self.model_revision, + layout_version=self.layout_version, + tp_rank=self.tp_rank, + recurrent_dtype=self.recurrent_dtype, + conv_dtype=self.conv_dtype, + ) + disable_restore = _env_flag("QWEN36_DISABLE_HYBRID_GDN_RESTORE") + disable_commit = _env_flag("QWEN36_DISABLE_HYBRID_GDN_COMMIT") + if disable_restore and plan.checkpoint_slot is not None: + plan = HybridAPCHitPlan( + attention_hit_len=0, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=0, + suffix_len=prompt_len, + checkpoint_slot=None, + checkpoint_key=None, + ) + if ( + self.reject_unbacked_attention_hits + and not _env_flag("QWEN36_ALLOW_UNBACKED_HYBRID_APC_FALLBACK") + and not disable_restore + and int(attention_hit_len) > 0 + and plan.checkpoint_slot is None + ): + raise ValueError( + "hybrid APC received an attention prefix hit without a matching " + "GDN checkpoint; scheduler must intersect attention KV hits with " + "GDN checkpoint hits or disable prefix reuse for this request" + ) + if plan.checkpoint_slot is not None: + self.slot_allocator.validate_slot_range(plan.checkpoint_slot) + + commit_key = None + commit_slot = None + attention_block_refs: tuple[int, ...] = () + # The Neuron checkpoint bank can only commit the active GDN state at + # the end of this traced prefill call. Do not label that state as an + # earlier checkpoint boundary unless the current prefill ends exactly + # at that boundary; scheduler-level chunking must create those boundary + # calls. + can_commit_boundary = commit_prefix_len > 0 and commit_prefix_len == prompt_len + if can_commit_boundary and not disable_commit: + if commit_prefix_len not in cumulative_hashes_by_prefix_len: + raise ValueError( + f"missing cumulative prefix hash for commit boundary {commit_prefix_len}" + ) + commit_key = self.store.make_key( + cumulative_prefix_hash=cumulative_hashes_by_prefix_len[commit_prefix_len], + prefix_len=commit_prefix_len, + cache_salt=self.cache_salt, + model_revision=self.model_revision, + layout_version=self.layout_version, + tp_rank=self.tp_rank, + recurrent_dtype=self.recurrent_dtype, + conv_dtype=self.conv_dtype, + ) + if attention_block_refs_by_prefix_len is not None: + attention_block_refs = tuple( + int(ref) + for ref in attention_block_refs_by_prefix_len.get( + commit_prefix_len, + (), + ) + ) + if not attention_block_refs and plan.checkpoint_key is not None: + suffix_refs = tuple( + int(ref) + for ref in attention_block_refs_by_prefix_len.get( + plan.suffix_len, + (), + ) + ) + checkpoint = self.store.lookup(plan.checkpoint_key) + if ( + checkpoint is not None + and suffix_refs + and commit_prefix_len + == checkpoint.prefix_len + plan.suffix_len + ): + attention_block_refs = ( + tuple(int(ref) for ref in checkpoint.attention_block_refs) + + suffix_refs + ) + if not attention_block_refs and not self.require_attention_block_refs: + attention_block_refs = tuple( + range(commit_prefix_len // self.store.block_size) + ) + if self.store.lookup(commit_key) is None: + commit_slot = self._reserve_commit_slot() + + same_request_keys = self._same_request_committed_keys.get(request_id, set()) + existing_record = self.store._requests.get(request_id) + if existing_record is not None: + same_request_keys = same_request_keys | set(existing_record.committed_keys) + gdn_active_carry = ( + plan.checkpoint_key is not None + and plan.checkpoint_key in same_request_keys + ) + if ( + os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1" + and gdn_active_carry + ): + print( + "[hybrid_apc_debug] prefill-active-carry " + f"request_id={request_id!r} prefix_len={plan.restore_checkpoint_prefix_len} " + f"slot={plan.checkpoint_slot}", + flush=True, + ) + + model_inputs = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + commit_slot=commit_slot, + request_prefix_len=prompt_len, + block_size=self.store.block_size, + gdn_active_carry=gdn_active_carry, + ) + record = self.store.on_request_restore( + request_id=request_id, + checkpoint_key=plan.checkpoint_key, + ) + if commit_slot is not None: + record.reserved_slots.append(commit_slot) + self.store.on_prefill_running(request_id) + + return HybridAPCPreparedRequest( + request_id=request_id, + input_dict=model_inputs, + plan=plan, + commit_prefix_len=commit_prefix_len, + commit_key=commit_key, + commit_slot=commit_slot, + attention_block_refs=attention_block_refs, + ) + + def prepare_suffix_only_request( + self, + *, + request_id: Hashable, + input_dict: dict[str, torch.Tensor], + attention_hit_len: int, + request_prefix_len: int, + cumulative_hashes_by_prefix_len: dict[int, Hashable] | None = None, + attention_block_refs_by_prefix_len: dict[int, Iterable[int]] | None = None, + ) -> HybridAPCPreparedRequest | None: + """Prepare a suffix-only request using scheduler-approved restore metadata.""" + + if "input_ids" not in input_dict: + raise KeyError("input_ids is required for hybrid APC request prep") + input_ids = input_dict["input_ids"] + if input_ids.ndim != 2: + raise ValueError( + f"input_ids must be [batch, seq], got {tuple(input_ids.shape)}" + ) + request_prefix_len = int(request_prefix_len) + attention_hit_len = max(0, int(attention_hit_len)) + suffix_len = int(input_ids.shape[1]) + restore_len = min(attention_hit_len, request_prefix_len) + restore_len = floor_to_checkpoint_boundary( + restore_len, + self.store.checkpoint_interval, + ) + if restore_len <= 0 or request_prefix_len - restore_len != suffix_len: + return None + + checkpoint = None + checkpoint_key = None + try: + from qwen36_hybrid_apc_scheduler_patch import ( # noqa: WPS433 + pop_hybrid_apc_authorized_prefix_key, + ) + except Exception: + pop_hybrid_apc_authorized_prefix_key = None + + if pop_hybrid_apc_authorized_prefix_key is not None: + checkpoint_key = pop_hybrid_apc_authorized_prefix_key( + prefix_len=restore_len, + request_id=request_id, + cache_salt=self.cache_salt, + model_revision=self.model_revision or self.store.model_revision, + layout_version=( + self.layout_version + if self.layout_version is not None + else self.store.layout_version + ), + tp_rank=self.tp_rank if self.tp_rank is not None else self.store.tp_rank, + recurrent_dtype=( + self.recurrent_dtype + if self.recurrent_dtype is not None + else self.store.recurrent_dtype + ), + conv_dtype=( + self.conv_dtype + if self.conv_dtype is not None + else self.store.conv_dtype + ), + ) + if checkpoint_key is not None: + checkpoint = self.store.lookup(checkpoint_key) + + if checkpoint is None and checkpoint_key is not None: + raise ValueError( + "suffix-only hybrid APC received a scheduler-authorized " + "prefix key that is missing from the GDN checkpoint store" + ) + + if checkpoint is None and _env_flag( + "QWEN36_HYBRID_APC_ALLOW_UNHASHED_SINGLE_PREFIX_RESTORE" + ): + checkpoint = self.store.lookup_unique_prefix_len( + prefix_len=restore_len, + cache_salt=self.cache_salt, + model_revision=self.model_revision, + layout_version=self.layout_version, + tp_rank=self.tp_rank, + recurrent_dtype=self.recurrent_dtype, + conv_dtype=self.conv_dtype, + ) + if checkpoint is None: + if self.reject_unbacked_attention_hits: + raise ValueError( + "suffix-only hybrid APC received an attention prefix hit " + "without scheduler-authorized GDN checkpoint metadata" + ) + return None + + plan = HybridAPCHitPlan( + attention_hit_len=attention_hit_len, + recurrent_hit_len=checkpoint.prefix_len, + conv_hit_len=checkpoint.prefix_len, + usable_hit_len=checkpoint.prefix_len, + restore_checkpoint_prefix_len=checkpoint.prefix_len, + residual_replay_len=0, + suffix_len=suffix_len, + checkpoint_slot=checkpoint.gdn_checkpoint_slot, + checkpoint_key=checkpoint.key, + ) + disable_commit = _env_flag("QWEN36_DISABLE_HYBRID_GDN_COMMIT") + commit_prefix_len = floor_to_checkpoint_boundary( + request_prefix_len, + self.store.checkpoint_interval, + ) + commit_key = None + commit_slot = None + attention_block_refs: tuple[int, ...] = () + can_commit_boundary = ( + commit_prefix_len > 0 + and commit_prefix_len == request_prefix_len + and not disable_commit + ) + if can_commit_boundary: + if cumulative_hashes_by_prefix_len is None: + if not self.allow_local_hash_fallback: + raise ValueError( + "hybrid APC production mode requires vLLM cumulative prefix " + f"hashes to commit suffix-only boundary {commit_prefix_len}" + ) + elif commit_prefix_len not in cumulative_hashes_by_prefix_len: + raise ValueError( + f"missing cumulative prefix hash for commit boundary {commit_prefix_len}" + ) + if cumulative_hashes_by_prefix_len is None: + can_commit_boundary = False + if can_commit_boundary: + commit_key = self.store.make_key( + cumulative_prefix_hash=cumulative_hashes_by_prefix_len[commit_prefix_len], + prefix_len=commit_prefix_len, + cache_salt=self.cache_salt, + model_revision=self.model_revision, + layout_version=self.layout_version, + tp_rank=self.tp_rank, + recurrent_dtype=self.recurrent_dtype, + conv_dtype=self.conv_dtype, + ) + if attention_block_refs_by_prefix_len is not None: + attention_block_refs = tuple( + int(ref) + for ref in attention_block_refs_by_prefix_len.get( + commit_prefix_len, + (), + ) + ) + if not attention_block_refs: + suffix_refs = tuple( + int(ref) + for ref in attention_block_refs_by_prefix_len.get( + suffix_len, + (), + ) + ) + if ( + suffix_refs + and commit_prefix_len == checkpoint.prefix_len + suffix_len + ): + attention_block_refs = ( + tuple(int(ref) for ref in checkpoint.attention_block_refs) + + suffix_refs + ) + if not attention_block_refs and not self.require_attention_block_refs: + attention_block_refs = tuple( + range(commit_prefix_len // self.store.block_size) + ) + if self.store.lookup(commit_key) is None: + commit_slot = self._reserve_commit_slot() + + same_request_keys = self._same_request_committed_keys.get(request_id, set()) + existing_record = self.store._requests.get(request_id) + if existing_record is not None: + same_request_keys = same_request_keys | set(existing_record.committed_keys) + gdn_active_carry = checkpoint.key in same_request_keys + if ( + os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1" + and gdn_active_carry + ): + print( + "[hybrid_apc_debug] suffix-active-carry " + f"request_id={request_id!r} prefix_len={checkpoint.prefix_len} " + f"slot={checkpoint.gdn_checkpoint_slot}", + flush=True, + ) + + model_inputs = apply_hybrid_apc_suffix_prefill_plan( + input_dict, + plan=plan, + request_prefix_len=request_prefix_len, + commit_slot=commit_slot, + attention_block_refs=checkpoint.attention_block_refs, + gdn_active_carry=gdn_active_carry, + ) + record = self.store.on_request_restore( + request_id=request_id, + checkpoint_key=plan.checkpoint_key, + ) + if commit_slot is not None: + record.reserved_slots.append(commit_slot) + self.store.on_prefill_running(request_id) + + return HybridAPCPreparedRequest( + request_id=request_id, + input_dict=model_inputs, + plan=plan, + commit_prefix_len=commit_prefix_len, + commit_key=commit_key, + commit_slot=commit_slot, + attention_block_refs=attention_block_refs or checkpoint.attention_block_refs, + ) + + def commit_prefill( + self, + prepared: HybridAPCPreparedRequest, + *, + attention_block_refs: Iterable[int] | None = None, + bytes_used: int = 0, + ) -> HybridPrefixCheckpoint | None: + if prepared.commit_key is None or prepared.commit_slot is None: + return None + refs = ( + tuple(int(ref) for ref in attention_block_refs) + if attention_block_refs is not None + else prepared.attention_block_refs + ) + if self.require_attention_block_refs and not refs: + raise ValueError( + "hybrid APC checkpoint commit requires real attention block refs " + "from the vLLM/NxDI APC allocator" + ) + checkpoint = self.store.insert( + key=prepared.commit_key, + attention_block_refs=refs, + gdn_checkpoint_slot=prepared.commit_slot, + bytes_used=bytes_used, + ) + _publish_scheduler_gdn_checkpoint(prepared.commit_key) + self.slot_allocator.mark_committed(prepared.commit_slot) + record = self.store.on_checkpoint_committed( + request_id=prepared.request_id, + checkpoint_key=prepared.commit_key, + ) + self._same_request_committed_keys.setdefault( + prepared.request_id, + set(), + ).add(prepared.commit_key) + if len(self._same_request_committed_keys) > 4096: + self._same_request_committed_keys.clear() + if prepared.commit_slot in record.reserved_slots: + record.reserved_slots.remove(prepared.commit_slot) + return checkpoint + + def _reserve_commit_slot(self) -> int: + try: + return self.slot_allocator.reserve() + except RuntimeError: + target_checkpoints = self.slot_allocator.num_slots - 1 + if self.store.max_checkpoints is not None: + target_checkpoints = min( + target_checkpoints, + int(self.store.max_checkpoints) - 1, + ) + evicted = self.store.evict_lru( + target_checkpoints=max(0, target_checkpoints) + ) + if evicted: + return self.slot_allocator.reserve() + raise + + def finish_request(self, request_id: Hashable) -> HybridAPCRequestRecord | None: + record = self.store.on_request_finish(request_id) + if record is not None: + for slot in record.reserved_slots: + self.slot_allocator.release(slot) + return record + + def cancel_request( + self, + prepared: HybridAPCPreparedRequest, + ) -> HybridAPCRequestRecord | None: + record = self.store._requests.get(prepared.request_id) + if ( + prepared.commit_slot is not None + and record is not None + and prepared.commit_slot in record.reserved_slots + ): + self.slot_allocator.release(prepared.commit_slot) + return self.store.on_request_cancel(prepared.request_id) + + +class HybridAPCMetadataStore: + """CPU-side lifecycle store for hybrid prefix-boundary checkpoints.""" + + def __init__( + self, + *, + required_gdn_layers: Iterable[int], + block_size: int, + checkpoint_interval: int | None = None, + max_checkpoints: int | None = None, + max_bytes: int | None = None, + layout_version: int = 1, + model_revision: str = "unknown", + tp_rank: int = 0, + recurrent_dtype: str | torch.dtype = "float32", + conv_dtype: str | torch.dtype = "bfloat16", + allow_residual_replay: bool = False, + checkpoint_slot_releaser: Callable[[int], object] | None = None, + ): + self.required_gdn_layers = tuple(sorted({int(x) for x in required_gdn_layers})) + if not self.required_gdn_layers: + raise ValueError("required_gdn_layers must not be empty") + self.num_layer_mask_bits = max(self.required_gdn_layers) + 1 + self.block_size = int(block_size) + if self.block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}") + self.checkpoint_interval = ( + self.block_size + if checkpoint_interval is None + else int(checkpoint_interval) + ) + if self.checkpoint_interval <= 0: + raise ValueError( + f"checkpoint_interval must be positive, got {checkpoint_interval}" + ) + if self.checkpoint_interval % self.block_size != 0: + raise ValueError( + "checkpoint_interval must be a multiple of block_size for v0 " + f"hybrid APC, got {self.checkpoint_interval} and {self.block_size}" + ) + self.max_checkpoints = max_checkpoints + if self.max_checkpoints is not None and self.max_checkpoints <= 0: + raise ValueError(f"max_checkpoints must be positive, got {max_checkpoints}") + self.max_bytes = max_bytes + if self.max_bytes is not None and self.max_bytes <= 0: + raise ValueError(f"max_bytes must be positive, got {max_bytes}") + self.layout_version = int(layout_version) + self.model_revision = str(model_revision) + self.tp_rank = int(tp_rank) + self.recurrent_dtype = _normalize_dtype(recurrent_dtype) + self.conv_dtype = _normalize_dtype(conv_dtype) + self.allow_residual_replay = bool(allow_residual_replay) + self._checkpoint_slot_releaser = checkpoint_slot_releaser + + self._by_key: OrderedDict[HybridPrefixKey, HybridPrefixCheckpoint] = ( + OrderedDict() + ) + self._slot_to_key: dict[int, HybridPrefixKey] = {} + self._requests: dict[Hashable, HybridAPCRequestRecord] = {} + self._step = 0 + self.stats = HybridAPCStats() + + def set_checkpoint_slot_releaser( + self, + releaser: Callable[[int], object] | None, + ): + self._checkpoint_slot_releaser = releaser + + def __len__(self) -> int: + return len(self._by_key) + + @property + def bytes_used(self) -> int: + return sum(checkpoint.bytes_used for checkpoint in self._by_key.values()) + + def _next_step(self) -> int: + self._step += 1 + return self._step + + def make_key( + self, + *, + cumulative_prefix_hash: Hashable, + prefix_len: int, + cache_salt: Hashable | None = None, + model_revision: str | None = None, + layout_version: int | None = None, + tp_rank: int | None = None, + recurrent_dtype: str | torch.dtype | None = None, + conv_dtype: str | torch.dtype | None = None, + ) -> HybridPrefixKey: + prefix_len = int(prefix_len) + if prefix_len < 0: + raise ValueError(f"prefix_len must be non-negative, got {prefix_len}") + if prefix_len % self.checkpoint_interval != 0: + raise ValueError( + "prefix_len must align to checkpoint_interval " + f"{self.checkpoint_interval}, got {prefix_len}" + ) + return HybridPrefixKey( + cumulative_prefix_hash=cumulative_prefix_hash, + prefix_len=prefix_len, + block_size=self.block_size, + cache_salt=cache_salt, + model_revision=self.model_revision + if model_revision is None + else str(model_revision), + layout_version=self.layout_version + if layout_version is None + else int(layout_version), + tp_rank=self.tp_rank if tp_rank is None else int(tp_rank), + recurrent_dtype=self.recurrent_dtype + if recurrent_dtype is None + else _normalize_dtype(recurrent_dtype), + conv_dtype=self.conv_dtype if conv_dtype is None else _normalize_dtype(conv_dtype), + ) + + def _make_mask(self, valid_layers: torch.Tensor | int | Iterable[int] | None): + if valid_layers is None: + layers = self.required_gdn_layers + mask = torch.zeros(self.num_layer_mask_bits, dtype=torch.bool) + mask[list(layers)] = True + return mask + if isinstance(valid_layers, torch.Tensor): + mask = valid_layers.detach().cpu().to(torch.bool).flatten().clone() + if mask.numel() < self.num_layer_mask_bits: + padded = torch.zeros(self.num_layer_mask_bits, dtype=torch.bool) + padded[: mask.numel()] = mask + mask = padded + return mask + mask = torch.zeros(self.num_layer_mask_bits, dtype=torch.bool) + if isinstance(valid_layers, int): + bitmask = int(valid_layers) + for layer in range(self.num_layer_mask_bits): + mask[layer] = bool(bitmask & (1 << layer)) + return mask + for layer in valid_layers: + layer = int(layer) + if layer >= mask.numel(): + padded = torch.zeros(layer + 1, dtype=torch.bool) + padded[: mask.numel()] = mask + mask = padded + mask[layer] = True + return mask + + def insert( + self, + *, + key: HybridPrefixKey, + attention_block_refs: Iterable[int], + gdn_checkpoint_slot: int, + valid_recurrent_layers: torch.Tensor | int | Iterable[int] | None = None, + valid_conv_layers: torch.Tensor | int | Iterable[int] | None = None, + bytes_used: int = 0, + evictable: bool = True, + ) -> HybridPrefixCheckpoint: + if key.block_size != self.block_size: + raise ValueError( + f"key block_size {key.block_size} does not match store block_size {self.block_size}" + ) + if key.layout_version != self.layout_version: + raise ValueError( + f"key layout_version {key.layout_version} does not match store layout_version {self.layout_version}" + ) + recurrent_mask = self._make_mask(valid_recurrent_layers) + conv_mask = self._make_mask(valid_conv_layers) + checkpoint = HybridPrefixCheckpoint( + key=key, + prefix_len=key.prefix_len, + attention_block_refs=tuple(int(ref) for ref in attention_block_refs), + gdn_checkpoint_slot=int(gdn_checkpoint_slot), + valid_recurrent_layers=recurrent_mask, + valid_conv_layers=conv_mask, + last_access_step=self._next_step(), + bytes_used=int(bytes_used), + evictable=bool(evictable), + ) + if not checkpoint.has_valid_gdn(self.required_gdn_layers): + raise ValueError("checkpoint is missing recurrent or conv state") + + old = self._slot_to_key.get(checkpoint.gdn_checkpoint_slot) + if old is not None and old != key: + self.mark_invalid(old) + if key in self._by_key: + old_checkpoint = self._by_key[key] + self._slot_to_key.pop(old_checkpoint.gdn_checkpoint_slot, None) + if ( + old_checkpoint.gdn_checkpoint_slot != checkpoint.gdn_checkpoint_slot + and self._checkpoint_slot_releaser is not None + ): + self._checkpoint_slot_releaser(old_checkpoint.gdn_checkpoint_slot) + self._by_key[key] = checkpoint + self._by_key.move_to_end(key) + self._slot_to_key[checkpoint.gdn_checkpoint_slot] = key + self._evict_over_budget() + self._refresh_stats() + return checkpoint + + def lookup( + self, + key: HybridPrefixKey, + *, + require_attention: bool = True, + require_gdn: bool = True, + ) -> HybridPrefixCheckpoint | None: + checkpoint = self._by_key.get(key) + if checkpoint is None: + self.stats.misses += 1 + return None + if require_attention and not checkpoint.attention_valid: + self.stats.misses += 1 + return None + if require_gdn and not checkpoint.has_valid_gdn(self.required_gdn_layers): + self.stats.misses += 1 + return None + checkpoint.last_access_step = self._next_step() + self._by_key.move_to_end(key) + self.stats.hits += 1 + return checkpoint + + def lookup_unique_prefix_len( + self, + *, + prefix_len: int, + cache_salt: Hashable | None = None, + model_revision: str | None = None, + layout_version: int | None = None, + tp_rank: int | None = None, + recurrent_dtype: str | torch.dtype | None = None, + conv_dtype: str | torch.dtype | None = None, + ) -> HybridPrefixCheckpoint | None: + """Return the only valid checkpoint at a prefix length, if unambiguous.""" + + prefix_len = int(prefix_len) + model_revision = self.model_revision if model_revision is None else str(model_revision) + layout_version = self.layout_version if layout_version is None else int(layout_version) + tp_rank = self.tp_rank if tp_rank is None else int(tp_rank) + recurrent_dtype = ( + self.recurrent_dtype + if recurrent_dtype is None + else _normalize_dtype(recurrent_dtype) + ) + conv_dtype = self.conv_dtype if conv_dtype is None else _normalize_dtype(conv_dtype) + + candidates: list[HybridPrefixKey] = [] + for key, checkpoint in self._by_key.items(): + if key.prefix_len != prefix_len: + continue + if key.cache_salt != cache_salt: + continue + if key.model_revision != model_revision: + continue + if key.layout_version != layout_version: + continue + if key.tp_rank != tp_rank: + continue + if key.recurrent_dtype != recurrent_dtype or key.conv_dtype != conv_dtype: + continue + if not checkpoint.attention_valid: + continue + if not checkpoint.has_valid_gdn(self.required_gdn_layers): + continue + candidates.append(key) + + if not candidates: + self.stats.misses += 1 + return None + if len(candidates) > 1: + raise ValueError( + "ambiguous unhashed Hybrid APC restore: " + f"{len(candidates)} checkpoints match prefix_len={prefix_len}" + ) + return self.lookup(candidates[0]) + + def mark_invalid( + self, + key: HybridPrefixKey | None = None, + *, + checkpoint_slot: int | None = None, + state_kind: str | None = None, + layer_id: int | None = None, + ) -> bool: + if key is None: + if checkpoint_slot is None: + raise ValueError("key or checkpoint_slot is required") + key = self._slot_to_key.get(int(checkpoint_slot)) + if key is None: + return False + checkpoint = self._by_key.get(key) + if checkpoint is None: + return False + + if state_kind is None: + self._delete_checkpoint(key) + self._refresh_stats() + return True + if state_kind == "attention": + checkpoint.attention_valid = False + elif state_kind == "recurrent": + if layer_id is None: + checkpoint.valid_recurrent_layers.zero_() + elif int(layer_id) < checkpoint.valid_recurrent_layers.numel(): + checkpoint.valid_recurrent_layers[int(layer_id)] = False + elif state_kind == "conv": + if layer_id is None: + checkpoint.valid_conv_layers.zero_() + elif int(layer_id) < checkpoint.valid_conv_layers.numel(): + checkpoint.valid_conv_layers[int(layer_id)] = False + else: + raise ValueError(f"unknown state_kind: {state_kind}") + _unpublish_scheduler_gdn_checkpoint(key) + return True + + def inc_ref(self, key: HybridPrefixKey) -> int: + checkpoint = self.lookup(key, require_attention=False, require_gdn=False) + if checkpoint is None: + raise KeyError(key) + checkpoint.refcount += 1 + return checkpoint.refcount + + def dec_ref(self, key: HybridPrefixKey) -> int: + checkpoint = self.lookup(key, require_attention=False, require_gdn=False) + if checkpoint is None: + raise KeyError(key) + checkpoint.refcount = max(0, checkpoint.refcount - 1) + return checkpoint.refcount + + def on_request_restore( + self, + *, + request_id: Hashable, + checkpoint_key: HybridPrefixKey | None, + ) -> HybridAPCRequestRecord: + record = HybridAPCRequestRecord( + request_id=request_id, + state="NEW", + restored_key=checkpoint_key, + ) + if checkpoint_key is not None: + self.inc_ref(checkpoint_key) + record.state = "RESTORED_FROM_HYBRID_APC" + self._requests[request_id] = record + return record + + def on_prefill_running(self, request_id: Hashable) -> HybridAPCRequestRecord: + record = self._requests[request_id] + record.state = "PREFILL_RUNNING" + return record + + def on_checkpoint_committed( + self, + *, + request_id: Hashable, + checkpoint_key: HybridPrefixKey, + ) -> HybridAPCRequestRecord: + record = self._requests.setdefault( + request_id, + HybridAPCRequestRecord(request_id=request_id, state="PREFILL_RUNNING"), + ) + record.state = "PREFILL_COMMIT_PENDING" + record.committed_keys.append(checkpoint_key) + return record + + def on_decode_running(self, request_id: Hashable) -> HybridAPCRequestRecord: + record = self._requests[request_id] + record.state = "DECODE_RUNNING" + return record + + def on_request_finish(self, request_id: Hashable) -> HybridAPCRequestRecord | None: + record = self._requests.pop(request_id, None) + if record is None: + return None + if record.restored_key is not None and record.restored_key in self._by_key: + self.dec_ref(record.restored_key) + record.state = "FINISHED" + return record + + def on_request_cancel(self, request_id: Hashable) -> HybridAPCRequestRecord | None: + record = self._requests.pop(request_id, None) + if record is None: + return None + if record.restored_key is not None and record.restored_key in self._by_key: + self.dec_ref(record.restored_key) + for key in record.committed_keys: + if key in self._by_key: + self.mark_invalid(key) + record.state = "CANCELLED" + return record + + def evict_lru(self, *, target_checkpoints: int | None = None) -> list[HybridPrefixKey]: + target = self.max_checkpoints if target_checkpoints is None else target_checkpoints + if target is None: + return [] + evicted: list[HybridPrefixKey] = [] + for key, checkpoint in list(self._by_key.items()): + if len(self._by_key) <= target: + break + if checkpoint.refcount > 0 or not checkpoint.evictable: + continue + self._delete_checkpoint(key) + evicted.append(key) + self.stats.evictions += len(evicted) + self._refresh_stats() + return evicted + + def on_attention_block_evicted(self, block_ref: int) -> list[HybridPrefixKey]: + invalidated: list[HybridPrefixKey] = [] + for key, checkpoint in self._by_key.items(): + if int(block_ref) in checkpoint.attention_block_refs: + checkpoint.attention_valid = False + _unpublish_scheduler_gdn_checkpoint(key) + invalidated.append(key) + return invalidated + + def on_gdn_checkpoint_evicted(self, checkpoint_slot: int) -> bool: + return self.mark_invalid(checkpoint_slot=int(checkpoint_slot)) + + def compute_hit_plan( + self, + *, + cumulative_hashes_by_prefix_len: dict[int, Hashable], + attention_hit_len: int, + request_prefix_len: int, + cache_salt: Hashable | None = None, + model_revision: str | None = None, + layout_version: int | None = None, + tp_rank: int | None = None, + recurrent_dtype: str | torch.dtype | None = None, + conv_dtype: str | torch.dtype | None = None, + ) -> HybridAPCHitPlan: + attention_hit_len = max(0, int(attention_hit_len)) + request_prefix_len = max(0, int(request_prefix_len)) + target_hit_len = min(attention_hit_len, request_prefix_len) + candidate_lens = sorted( + ( + int(prefix_len) + for prefix_len in cumulative_hashes_by_prefix_len + if int(prefix_len) <= target_hit_len + and int(prefix_len) % self.checkpoint_interval == 0 + ), + reverse=True, + ) + + for prefix_len in candidate_lens: + key = self.make_key( + cumulative_prefix_hash=cumulative_hashes_by_prefix_len[prefix_len], + prefix_len=prefix_len, + cache_salt=cache_salt, + model_revision=model_revision, + layout_version=layout_version, + tp_rank=tp_rank, + recurrent_dtype=recurrent_dtype, + conv_dtype=conv_dtype, + ) + checkpoint = self.lookup(key) + if checkpoint is None: + continue + + if self.allow_residual_replay: + usable_hit_len = target_hit_len + residual_replay_len = target_hit_len - prefix_len + suffix_len = request_prefix_len - target_hit_len + else: + usable_hit_len = prefix_len + residual_replay_len = 0 + suffix_len = request_prefix_len - prefix_len + return HybridAPCHitPlan( + attention_hit_len=attention_hit_len, + recurrent_hit_len=prefix_len, + conv_hit_len=prefix_len, + usable_hit_len=usable_hit_len, + restore_checkpoint_prefix_len=prefix_len, + residual_replay_len=residual_replay_len, + suffix_len=suffix_len, + checkpoint_slot=checkpoint.gdn_checkpoint_slot, + checkpoint_key=key, + ) + + return HybridAPCHitPlan( + attention_hit_len=attention_hit_len, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=0, + suffix_len=request_prefix_len, + checkpoint_slot=None, + checkpoint_key=None, + ) + + def _evict_over_budget(self): + if self.max_checkpoints is not None: + self.evict_lru(target_checkpoints=self.max_checkpoints) + if self.max_bytes is None: + return + evicted = 0 + for key, checkpoint in list(self._by_key.items()): + if self.bytes_used <= self.max_bytes: + break + if checkpoint.refcount > 0 or not checkpoint.evictable: + continue + self._delete_checkpoint(key) + evicted += 1 + self.stats.evictions += evicted + + def _delete_checkpoint(self, key: HybridPrefixKey): + checkpoint = self._by_key.pop(key, None) + if checkpoint is not None: + self._slot_to_key.pop(checkpoint.gdn_checkpoint_slot, None) + _unpublish_scheduler_gdn_checkpoint(key) + if self._checkpoint_slot_releaser is not None: + self._checkpoint_slot_releaser(checkpoint.gdn_checkpoint_slot) + + def _refresh_stats(self): + self.stats.checkpoints = len(self._by_key) + self.stats.bytes_used = self.bytes_used diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py new file mode 100644 index 00000000..9cd06fc2 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py @@ -0,0 +1,8040 @@ +""" +NxDI contrib: Qwen3.5-27B / Qwen3.6-27B (qwen3_5 -- dense model) + +Supports both Qwen3.5-27B and Qwen3.6-27B. These models share identical +architecture (qwen3_5 model_type). Qwen3.6-27B is a post-training update +with improved agentic coding and thinking preservation -- no architecture +changes, only weight differences. + +Hybrid DeltaNet + Standard Attention + Dense MLP architecture. +Adapted from Qwen3.5-35B-A3B (MoE) -- MoE removed, dense MLP added. + +48 of 64 layers use Gated DeltaNet (linear recurrent attention) +16 of 64 layers use standard GQA with KV cache + output gate +All 64 layers use a dense SwiGLU MLP (intermediate_size=17408) + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples + +Config compatibility notes: +- Qwen3.6-27B adds output_gate_type="swish" to text_config. This field is + unused by both HF transformers and this NxDI code (gate uses sigmoid, as + confirmed across transformers v4.57.6, v5.6.0, and GitHub main). Safe to ignore. +""" + +import gc +import json +import math +import logging +import os +import re +import sys +import time +from typing import Any, Hashable, List, NamedTuple, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.async_execution import ( + cancel_hybrid_apc_request, + finish_hybrid_apc_request, + prepare_hybrid_apc_model_inputs, + prepare_hybrid_apc_request_for_execution, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import _gather_along_dim +from neuronx_distributed.utils import cpu_mode + +try: + from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29) +except ImportError: + from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28) +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from src.nki_kernels.nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel +from src.nki_kernels.nki_deltanet import ( + deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state, +) +from src.nki_kernels.nki_deltanet import ( + deltanet_recurrent_step_batched as _deltanet_nki_step_batched, +) +from src.nki_kernels.nki_deltanet_chunked import ( + deltanet_chunk_step as _deltanet_nki_chunk_step, +) +from src.nki_kernels.nki_deltanet_fused import ( + deltanet_autocp_affine_sequence as _deltanet_autocp_affine_sequence, + deltanet_autocp_apply_output as _deltanet_autocp_apply_output, + deltanet_autocp_prefix_apply_output as _deltanet_autocp_prefix_apply_output, + deltanet_autocp_state_summary_sequence as _deltanet_autocp_state_summary_sequence, + deltanet_autocp_state_prefix as _deltanet_autocp_state_prefix, + deltanet_fused_chunked_fwd as _deltanet_fused_kernel, + deltanet_fused_chunked_fwd_multihead as _deltanet_fused_multihead_kernel, +) +from src.nki_kernels.nki_deltanet_fused_legacy import ( + deltanet_fused_chunked_fwd as _deltanet_fused_legacy_direct_kernel, +) +from src.nki_kernels.nki_deltanet_fused import ( + _make_lower_mask, + _make_lower_mask_diag, + _make_identity, +) +try: + import nki as _nkilib_nki + from nkilib.core.qkv.qkv import qkv as _nkilib_qkv + from nkilib.core.utils.common_types import ( + NormType as _NkilibNormType, + QKVOutputLayout as _NkilibQKVOutputLayout, + QuantizationType as _NkilibQuantizationType, + ) + + _qwen_gate_projection_kernel = _nkilib_nki.jit(_nkilib_qkv) +except Exception: + _NkilibNormType = None + _NkilibQKVOutputLayout = None + _NkilibQuantizationType = None + _qwen_gate_projection_kernel = None + +try: + from src.nki_kernels.qwen_qk_norm_rope import ( + qwen_qk_norm_partial_rope_kernel as _qwen_qk_norm_partial_rope_kernel, + ) +except Exception: + _qwen_qk_norm_partial_rope_kernel = None +from src.hybrid_apc import ( + HybridAPCMetadataStore, + HybridAPCSchedulerBridge, + HybridAPCSlotAllocator, +) + + +def _infer_neuron_lnc(default: int = 1) -> int: + flags = os.environ.get("NEURON_CC_FLAGS", "") + match = re.search(r"(?:^|\s)--lnc(?:=|\s+)(\d+)", flags) + if match is None: + return default + return max(1, int(match.group(1))) + + +def _resolve_deltanet_multihead_group_size(total_heads: int) -> int: + lnc = _infer_neuron_lnc() + raw_group_size = os.environ.get("QWEN36_DELTANET_MULTIHEAD_GROUP_SIZE") + if raw_group_size is None: + requested_group_size = 2 if lnc >= 2 else 1 + else: + requested_group_size = max(1, int(raw_group_size)) + if requested_group_size > lnc: + raise ValueError( + f"QWEN36_DELTANET_MULTIHEAD_GROUP_SIZE={requested_group_size} " + f"requires NEURON_CC_FLAGS --lnc >= {requested_group_size}; " + f"inferred lnc={lnc}" + ) + return max(1, min(total_heads, requested_group_size)) + + +def _deltanet_multihead_launch_spec(num_heads: int): + """Return the launch spec for a grouped multihead DeltaNet CTE kernel. + + The legacy ``kernel[2]`` launch only covers two programs. For larger + grouped launches we need an SPMD axis distributed over the available NCs, + while each program still handles exactly one flattened (batch, head) row. + """ + lnc = _infer_neuron_lnc() + if num_heads <= lnc: + return num_heads + if os.environ.get("QWEN36_DELTANET_MULTIHEAD_SPMD", "1") == "0": + raise ValueError( + "QWEN36_DELTANET_MULTIHEAD_GROUP_SIZE exceeds inferred LNC but " + "QWEN36_DELTANET_MULTIHEAD_SPMD=0; " + f"group_size={num_heads}, inferred_lnc={lnc}" + ) + + import nki.language as _nl # Imported lazily so CPU-only unit stubs still load. + + if not hasattr(_nl, "spmd_dim") or not hasattr(_nl, "nc"): + if os.environ.get("QWEN36_DELTANET_MULTIHEAD_GRID_FALLBACK", "0") == "1": + return (num_heads, 1) + raise ValueError( + "QWEN36_DELTANET_MULTIHEAD_GROUP_SIZE exceeds inferred LNC, but " + "this NKI runtime does not expose spmd_dim/nc; " + f"group_size={num_heads}, inferred_lnc={lnc}" + ) + return (_nl.spmd_dim(num_heads, _nl.nc(lnc)),) + + +def _qwen35_grouped_prefix_attention( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + key_valid_mask=None, +): + """GQA-native prefix attention without materializing repeated KV heads.""" + B, q_heads, q_len, head_dim = Q.shape + kv_heads = K_cache.shape[1] + if q_heads % kv_heads != 0: + raise ValueError( + "Qwen grouped prefix attention requires q_heads to be divisible " + f"by kv_heads, got q_heads={q_heads}, kv_heads={kv_heads}." + ) + + q_per_kv = q_heads // kv_heads + if cache_positions.ndim == 4: + cache_positions = cache_positions.reshape(B, -1) + elif cache_positions.ndim != 2: + raise ValueError( + "cache_positions must have shape (B, K) or (B, 1, 1, K), " + f"got {tuple(cache_positions.shape)}." + ) + + if key_valid_mask is not None: + if key_valid_mask.ndim == 4: + key_valid_mask = key_valid_mask.reshape(B, -1) + elif key_valid_mask.ndim != 2: + raise ValueError( + "key_valid_mask must have shape (B, K) or (B, 1, 1, K), " + f"got {tuple(key_valid_mask.shape)}." + ) + + q_grouped = Q.reshape(B, kv_heads, q_per_kv, q_len, head_dim) + k_grouped = K_cache.transpose(-1, -2).unsqueeze(2) + attn_weights = torch.matmul(q_grouped, k_grouped) / math.sqrt(head_dim) + + causal_mask = cache_positions[:, None, None, None, :] <= query_positions[ + :, None, None, :, None + ] + if key_valid_mask is not None: + causal_mask = causal_mask & key_valid_mask[:, None, None, None, :] + attn_weights = attn_weights.masked_fill(~causal_mask, -65504.0) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(Q.dtype) + + attn_output = torch.matmul(attn_weights, V_cache.unsqueeze(2)) + return attn_output.reshape(B, q_heads, q_len, head_dim) + + +def _qwen35_expanded_prefix_attention( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + key_valid_mask=None, +): + B, q_heads, q_len, head_dim = Q.shape + kv_heads = K_cache.shape[1] + cache_len = K_cache.shape[2] + + if q_heads != kv_heads: + kv_rep = q_heads // kv_heads + K_full = ( + K_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, cache_len, head_dim) + ) + V_full = ( + V_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, cache_len, head_dim) + ) + else: + K_full = K_cache + V_full = V_cache + + attn_weights = torch.matmul(Q, K_full.transpose(-1, -2)) / math.sqrt(head_dim) + causal_mask = cache_positions <= query_positions[:, None, :, None] + if key_valid_mask is not None: + causal_mask = causal_mask & key_valid_mask + attn_weights = attn_weights.masked_fill(~causal_mask, -65504.0) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(Q.dtype) + return torch.matmul(attn_weights, V_full) + + +def _qwen36_prefix_attention_impl() -> str: + raw = os.environ.get("QWEN36_PREFIX_ATTENTION_IMPL", "grouped").strip().lower() + aliases = { + "grouped": "grouped", + "current": "grouped", + "expanded": "expanded", + "legacy": "expanded", + "legacy_expanded": "expanded", + } + if raw not in aliases: + raise ValueError( + "QWEN36_PREFIX_ATTENTION_IMPL must be grouped/current or " + f"expanded/legacy, got {raw!r}" + ) + return aliases[raw] + + +def _resolve_deltanet_autocp_lnc(num_chunks: int) -> int: + lnc = _infer_neuron_lnc() + raw_lnc = os.environ.get("QWEN36_DELTANET_AUTOCP_LNC") + if raw_lnc is None: + launch_lnc = 2 if lnc >= 2 and num_chunks % 2 == 0 else 1 + else: + launch_lnc = max(1, int(raw_lnc)) + if launch_lnc > lnc: + raise ValueError( + f"QWEN36_DELTANET_AUTOCP_LNC={launch_lnc} requires " + f"NEURON_CC_FLAGS --lnc >= {launch_lnc}; inferred lnc={lnc}" + ) + if launch_lnc not in (1, 2): + raise ValueError( + f"QWEN36_DELTANET_AUTOCP_LNC must be 1 or 2, got {launch_lnc}" + ) + if num_chunks % launch_lnc != 0: + raise ValueError( + "QWEN36_DELTANET_AUTOCP_CTE requires the number of 128-token " + f"chunks to be divisible by launch LNC; chunks={num_chunks}, " + f"launch_lnc={launch_lnc}" + ) + return launch_lnc + + +def _deltanet_autocp_affine_launch_spec(num_chunks: int, launch_lnc: int): + """Return a SPMD affine launch grid, falling back to legacy LNC split. + + Bare ``kernel[2]`` launches only two logical cores. For AutoCP affine + generation we need one independent program per 128-token chunk, sharded + across those logical cores. NKI represents that as a SPMD grid dimension + with an attached NC distribution. + """ + if os.environ.get("QWEN36_DELTANET_AUTOCP_SPMD_AFFINE", "1") == "0": + return launch_lnc + import nki.language as _nl # Imported lazily so CPU-only unit stubs still load. + + if not hasattr(_nl, "spmd_dim") or not hasattr(_nl, "nc"): + return launch_lnc + + if launch_lnc == 2: + return (_nl.spmd_dim(num_chunks, _nl.nc(2)), 1) + return (num_chunks, 1) + + +def _resolve_deltanet_autocp_cp_chunks(num_chunks: int) -> int: + cp_chunks = max(1, int(os.environ.get("QWEN36_DELTANET_AUTOCP_CP_CHUNKS", "4"))) + if num_chunks % cp_chunks != 0: + raise ValueError( + "QWEN36_DELTANET_COMPACT_AUTOCP_CTE requires the number of " + "128-token chunks to be divisible by QWEN36_DELTANET_AUTOCP_CP_CHUNKS; " + f"chunks={num_chunks}, cp_chunks={cp_chunks}" + ) + return cp_chunks + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.llama.modeling_llama import NeuronLlamaMLP +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + DecoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.utils import ( + RotaryEmbedding, + move_heads_front, + transpose_parallel_linear_layer, +) +try: + from neuronx_distributed_inference.modules.attention.utils import ( + preprocess_quantized_linear_layer, + ) +except (ImportError, AttributeError): + def preprocess_quantized_linear_layer(layer): + return layer + +from neuronx_distributed_inference.modules.kvcache.block_kv_cache_manager import ( + BlockKVCacheManager, +) +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) + +logger = logging.getLogger(__name__) + +try: + from neuronxcc.nki._pre_prod_kernels import ( + NormType as _QKVNormType, + QKVOutputLayout as _QKVOutputLayout, + QuantizationType as _QKVQuantizationType, + ) + from neuronxcc.nki._pre_prod_kernels.qkv_tkg_impl import ( + nki_qkv_projection_tkg_impl as _qkv_tkg_nki_kernel, + ) +except ImportError: + _QKVNormType = None + _QKVOutputLayout = None + _QKVQuantizationType = None + _qkv_tkg_nki_kernel = None + +try: + _flash_fwd_call = nki_jit()(attention_isa_kernel) +except TypeError: + from torch_neuronx.xla_impl.ops import nki_jit as _torch_xla_nki_jit + + _flash_fwd_call = _torch_xla_nki_jit()(attention_isa_kernel) + +# Option B: Direct nkilib flash attention for head_dim > 128 +USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1" + +_nkilib_flash_attn = None +if USE_NKILIB_KERNEL: + try: + import neuronxcc.nki as _nki + from neuronx_distributed_inference.modules.attention.attention_base import ( + peel_decorations as _peel_decorations, + get_platform_target as _get_platform_target, + ) + from neuronxcc.nki.compiler import ( + skip_middle_end_transformations as _skip_middle_end, + enable_stack_allocator as _enable_stack_allocator, + ) + + import importlib + + _fork_path = "/home/ubuntu/nki-library-fork/nkilib_src" + if os.path.isdir(_fork_path) and _fork_path not in sys.path: + sys.path.insert(0, _fork_path) + _to_remove = [k for k in sys.modules if k.startswith("nkilib")] + for k in _to_remove: + del sys.modules[k] + import nki.language as _stub_nl + import neuronxcc.nki.language as _real_nl + + for _attr in [ + "NKIObject", + "float8_e4m3fn", + "float8_e4m3fn_x4", + "float8_e5m2_x4", + "float4_e2m1fn_x4", + ]: + if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr): + setattr(_real_nl, _attr, getattr(_stub_nl, _attr)) + from nkilib.core.attention.attention_cte import ( + attention_cte as _attention_cte_raw, + _MAX_HEAD_DIM, + ) + + assert _MAX_HEAD_DIM == 256, ( + f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. " + f"System nkilib may have been loaded instead of fork." + ) + logger.info( + f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})" + ) + + _raw_fn = _peel_decorations(_attention_cte_raw) + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", _get_platform_target()) + _nkilib_flash_attn = _nki.jit( + _raw_fn, + show_compiler_tb=True, + debug_kernel=True, + ) + _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn) + _nkilib_flash_attn = _enable_stack_allocator( + _nkilib_flash_attn, log_level=logging.INFO + ) + logger.info("Option B: nkilib flash attention loaded for head_dim > 128") + except Exception as e: + logger.warning(f"Option B: Failed to load nkilib flash attention: {e}") + import traceback as _tb + + _tb.print_exc() + _nkilib_flash_attn = None + +# Option A: Detect if patch_attn_kernel was imported +NKILIB_PATCH_ACTIVE = False +try: + from importlib import import_module as _import_module + + _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd") + if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"): + NKILIB_PATCH_ACTIVE = True + logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel") +except Exception: + pass + + +# ============================================================ +# Newton-Raphson Refined RMSNorm +# ============================================================ +USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1" +USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1" + + +class NewtonRMSNorm(nn.Module): + """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy.""" + + def __init__(self, hidden_size=None, eps=1e-6): + super().__init__() + self.weight = None + if hidden_size is not None: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def forward(self, hidden_states): + original_dtype = hidden_states.dtype + x = hidden_states.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + y = torch.rsqrt(variance + self.variance_epsilon) + y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5 + result = x * y + if self.weight is not None: + result = result * self.weight.float() + return result.to(original_dtype) + + +def get_rmsnorm_cls(): + if cpu_mode() or USE_PYTHON_RMSNORM: + return Qwen3MoeRMSNorm + return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm + + +def l2norm(x, dim=-1, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +class GDNAPCReusePlan(NamedTuple): + """Exact hybrid-APC reuse plan for attention KV plus GDN checkpoints.""" + + attention_hit_len: int + recurrent_hit_len: int + conv_hit_len: int + reusable_prefix_len: int + restore_checkpoint_prefix_len: int + residual_replay_len: int + suffix_len: int + + +def _non_negative_len(name: str, value: int) -> int: + value = int(value) + if value < 0: + raise ValueError(f"{name} must be non-negative, got {value}") + return value + + +def _normalize_hybrid_cache_dtype(name: str, value, default: str) -> str: + if value is None: + value = default + if isinstance(value, torch.dtype): + if value == torch.float32: + return "float32" + if value == torch.bfloat16: + return "bfloat16" + normalized = str(value).lower() + aliases = { + "fp32": "float32", + "float32": "float32", + "torch.float32": "float32", + "bf16": "bfloat16", + "bfloat16": "bfloat16", + "torch.bfloat16": "bfloat16", + } + if normalized not in aliases: + raise ValueError( + f"{name} must be one of fp32/float32 or bf16/bfloat16, got {value}" + ) + return aliases[normalized] + + +def _torch_dtype_from_hybrid_cache_dtype(value: str) -> torch.dtype: + value = _normalize_hybrid_cache_dtype("hybrid cache dtype", value, "bfloat16") + if value == "float32": + return torch.float32 + if value == "bfloat16": + return torch.bfloat16 + raise AssertionError(f"unexpected hybrid cache dtype {value}") + + +def plan_gdn_apc_reuse( + *, + attention_hit_len: int, + recurrent_hit_len: int, + conv_hit_len: int, + request_prefix_len: int, + gdn_checkpoint_interval: int, +) -> GDNAPCReusePlan: + """Plan exact prefix reuse for Qwen hybrid APC. + + Attention KV can be reused up to the vLLM APC hit, but DeltaNet can only + resume exactly from a boundary with both recurrent and conv checkpoint + state. When the attention hit is inside a GDN interval, restore the nearest + earlier checkpoint and replay the residual tokens before running the suffix. + """ + attention_hit_len = _non_negative_len("attention_hit_len", attention_hit_len) + recurrent_hit_len = _non_negative_len("recurrent_hit_len", recurrent_hit_len) + conv_hit_len = _non_negative_len("conv_hit_len", conv_hit_len) + request_prefix_len = _non_negative_len("request_prefix_len", request_prefix_len) + gdn_checkpoint_interval = int(gdn_checkpoint_interval) + if gdn_checkpoint_interval <= 0: + raise ValueError( + f"gdn_checkpoint_interval must be positive, got {gdn_checkpoint_interval}" + ) + + reusable_prefix_len = min( + attention_hit_len, + recurrent_hit_len, + conv_hit_len, + request_prefix_len, + ) + restore_checkpoint_prefix_len = ( + reusable_prefix_len // gdn_checkpoint_interval + ) * gdn_checkpoint_interval + residual_replay_len = reusable_prefix_len - restore_checkpoint_prefix_len + suffix_len = request_prefix_len - reusable_prefix_len + + return GDNAPCReusePlan( + attention_hit_len=attention_hit_len, + recurrent_hit_len=recurrent_hit_len, + conv_hit_len=conv_hit_len, + reusable_prefix_len=reusable_prefix_len, + restore_checkpoint_prefix_len=restore_checkpoint_prefix_len, + residual_replay_len=residual_replay_len, + suffix_len=suffix_len, + ) + + +# ============================================================ +# Gated DeltaNet Module (Linear Recurrent Attention) +# ============================================================ + + +class NeuronGatedDeltaNet(nn.Module): + """ + Gated DeltaNet linear attention for Neuron. + + Replaces standard attention for 48 of 64 layers in Qwen3.5/3.6-27B. + Uses a chunk-based linear recurrence instead of KV cache. + + HF weight layout (27B dense -- scaled dimensions): + - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (10240, 5120) + - in_proj_z.weight: (value_dim, hidden_size) = (6144, 5120) + - in_proj_a.weight: (num_v_heads, hidden_size) = (48, 5120) + - in_proj_b.weight: (num_v_heads, hidden_size) = (48, 5120) + - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (10240, 1, 4) + - A_log: (num_v_heads,) = (48,) + - dt_bias: (num_v_heads,) = (48,) + - norm.weight: (head_v_dim,) = (128,) + - out_proj.weight: (hidden_size, value_dim) = (5120, 6144) + """ + + def __init__(self, config, layer_idx: int): + super().__init__() + tc = config + + self.hidden_size = tc.hidden_size # 5120 + self.tp_degree = tc.neuron_config.tp_degree + self.global_num_v_heads = tc.linear_num_value_heads # 48 + self.global_num_k_heads = tc.linear_num_key_heads # 16 + self.head_k_dim = tc.linear_key_head_dim # 128 + self.head_v_dim = tc.linear_value_head_dim # 128 + if self.global_num_v_heads % self.tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={self.global_num_v_heads} must be divisible " + f"by tp_degree={self.tp_degree}" + ) + if self.global_num_k_heads % self.tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={self.global_num_k_heads} must be divisible " + f"by tp_degree={self.tp_degree}" + ) + self.num_v_heads = self.global_num_v_heads // self.tp_degree + self.num_k_heads = self.global_num_k_heads // self.tp_degree + self.global_key_dim = self.head_k_dim * self.global_num_k_heads # 2048 + self.global_value_dim = self.head_v_dim * self.global_num_v_heads # 6144 + self.key_dim = self.head_k_dim * self.num_k_heads # 512 at TP=4 + self.value_dim = self.head_v_dim * self.num_v_heads # 1536 at TP=4 + self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 + self.layer_idx = layer_idx + self.rms_norm_eps = tc.rms_norm_eps + self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False) + self.use_hybrid_apc_manager = getattr(tc, "use_hybrid_apc_manager", False) + self.use_qwen_hybrid_chunked_prefill = getattr( + tc, "use_qwen_hybrid_chunked_prefill", False + ) + self.use_qwen_hybrid_chunked_prefill_nki = getattr( + tc, "use_qwen_hybrid_chunked_prefill_nki", False + ) + self.use_qwen_deltanet_decode_nki = getattr( + tc, "use_qwen_deltanet_decode_nki", False + ) + self.use_cold_zero_conv_fast_path = getattr( + tc, "use_cold_zero_conv_fast_path", False + ) + + # KV cache dummy shape info + self.head_dim = tc.head_dim # 256 + tp_degree = tc.neuron_config.tp_degree + raw_kv_heads = tc.num_key_value_heads + if raw_kv_heads < tp_degree: + replicated_kv_heads = tp_degree + else: + replicated_kv_heads = raw_kv_heads + self.kv_heads_per_rank = replicated_kv_heads // tp_degree + + # Conv1d on concatenated QKV (NOT Z). Store the depthwise kernel in a + # ColumnParallelLinear parameter container so NxD's checkpoint sharder + # can split it by output channel. Forward still uses it as Conv1d + # weight after unsqueezing the singleton input-channel dimension. + self.global_conv_dim = self.global_key_dim * 2 + self.global_value_dim # 10240 + self.conv_dim = self.key_dim * 2 + self.value_dim # 2560 at TP=4 + self.conv1d_weight = ColumnParallelLinear( + self.conv_kernel_size, + self.global_conv_dim, + bias=False, + gather_output=False, + ) + + # Input/output projections are the large DeltaNet tensors. Shard them + # with tensor parallelism; convert_qwen35_hf_to_neuron_state_dict() + # reorders in_proj_qkv into per-rank [Q_local | K_local | V_local] + # blocks before NxD slices the output dimension. + self.in_proj_qkv = ColumnParallelLinear( + self.hidden_size, + self.global_key_dim * 2 + self.global_value_dim, + bias=False, + gather_output=False, + ) + self.in_proj_z = ColumnParallelLinear( + self.hidden_size, + self.global_value_dim, + bias=False, + gather_output=False, + ) + self.in_proj_b = ColumnParallelLinear( + self.hidden_size, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + self.in_proj_a = ColumnParallelLinear( + self.hidden_size, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + + # Same parameter-container pattern for per-value-head decay vectors. + # These are used as vectors in forward but sharded by output dim during + # checkpoint conversion/loading. + self.dt_bias_weight = ColumnParallelLinear( + 1, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + self.A_log_weight = ColumnParallelLinear( + 1, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + + # Output norm and projection + self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps) + self.out_proj = RowParallelLinear( + self.global_value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + ) + + # State buffers for CTE -> TKG carry-over + alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1) + self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1) + recurrent_buffer_dtype = ( + _torch_dtype_from_hybrid_cache_dtype(config.hybrid_recurrent_cache_dtype) + if self.use_hybrid_apc_manager + else config.neuron_config.torch_dtype + ) + conv_buffer_dtype = ( + _torch_dtype_from_hybrid_cache_dtype(config.hybrid_conv_cache_dtype) + if self.use_hybrid_apc_manager + else config.neuron_config.torch_dtype + ) + self.recurrent_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=recurrent_buffer_dtype, + ), + requires_grad=False, + ) + self.conv_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=conv_buffer_dtype, + ), + requires_grad=False, + ) + + def _conv1d_weight(self): + return self.conv1d_weight.weight.unsqueeze(1) + + def _dt_bias(self): + return self.dt_bias_weight.weight.squeeze(1) + + def _A_log(self): + return self.A_log_weight.weight.squeeze(1) + + def _recurrent_step(self, query, key, value, g, beta, recurrent_state): + """Single-step recurrent update for token generation.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + q_t = query[:, :, 0] + k_t = key[:, :, 0] + v_t = value[:, :, 0] + g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, 0].unsqueeze(-1) + + new_state = recurrent_state * g_t + kv_mem = (new_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output.unsqueeze(2), new_state + + def _nki_recurrent_step(self, query, key, value, g, beta, recurrent_state): + """Single-step recurrent update using the stateful NKI decode kernel.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + BH = B * H + query_flat = query.reshape(BH, S, k_dim)[:, 0, :].contiguous() + key_flat = key.reshape(BH, S, k_dim)[:, 0, :].contiguous() + value_flat = value.reshape(BH, S, v_dim)[:, 0, :].contiguous() + g_flat = g.reshape(BH, S)[:, 0:1].contiguous() + beta_flat = beta.reshape(BH, S)[:, 0:1].contiguous() + state_flat = recurrent_state.reshape(BH * k_dim, v_dim).contiguous() + + output_flat, state_flat_out = _deltanet_nki_step_batched( + query_flat, + key_flat, + value_flat, + g_flat, + beta_flat, + state_flat, + ) + + output = output_flat.reshape(B, H, S, v_dim) + new_state = state_flat_out.reshape(B, H, k_dim, v_dim) + + return output, new_state + + def _nki_recurrent_forward(self, query, key, value, g, beta): + """Full-sequence recurrent forward using NKI kernel for context encoding.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + BH = B * H + query_flat = query.reshape(BH, S, k_dim).contiguous() + key_flat = key.reshape(BH, S, k_dim).contiguous() + value_flat = value.reshape(BH, S, v_dim).contiguous() + + g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + + outputs = [] + states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_nki_kernel_state( + query_flat[bh], + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + ) + outputs.append(out_bh) + states.append(state_bh) + + output = torch.stack(outputs, dim=0) + output = output.reshape(B, H, S, v_dim) + + final_state = torch.stack(states, dim=0) + final_state = final_state.reshape(B, H, k_dim, v_dim) + + return output, final_state + + def _nki_chunked_forward( + self, query, key, value, g, beta, output_final_state=False, initial_state=None + ): + """Chunked NKI kernel forward for context encoding (prefill).""" + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + num_chunks = total_seq_len // chunk_size + g_reshaped = g.reshape(B, H, num_chunks, chunk_size) + g_cs = g_reshaped.cumsum(dim=-1) + g_last_per_chunk = g_cs[:, :, :, -1:] + g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size) + + query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value_chunks = value.reshape(B, H, num_chunks, chunk_size, v_dim) + + beta_chunks = ( + beta.reshape(B, H, num_chunks, chunk_size) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, v_dim) + ) + gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + + BH = B * H + query_chunks = query_chunks.reshape( + BH, num_chunks, chunk_size, k_dim + ).contiguous() + key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous() + value_chunks = value_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + beta_chunks = beta_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + + device = query.device + lower_mask = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=-1, + ) + identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device) + lower_mask_diag = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=0, + ) + + initial_state_flat = None + if initial_state is not None: + initial_state_flat = initial_state.reshape(BH, k_dim, v_dim).float().contiguous() + + all_outputs = [] + all_states = [] + for bh in range(BH): + if initial_state_flat is None: + state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device) + else: + state = initial_state_flat[bh] + + head_chunks = [] + for c_idx in range(num_chunks): + q_chunk = query_chunks[bh, c_idx].contiguous() + k_chunk = key_chunks[bh, c_idx].contiguous() + v_chunk = value_chunks[bh, c_idx].contiguous() + beta_chunk = beta_chunks[bh, c_idx].contiguous() + gc_chunk = gc_chunks[bh, c_idx].contiguous() + gl_chunk = gl_chunks[bh, c_idx].contiguous() + + out_chunk, state = _deltanet_nki_chunk_step( + q_chunk, + k_chunk, + v_chunk, + beta_chunk, + gc_chunk, + gl_chunk, + state, + lower_mask, + identity_mat, + lower_mask_diag, + ) + head_chunks.append(out_chunk) + + head_output = torch.cat(head_chunks, dim=0) + all_outputs.append(head_output) + all_states.append(state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _fused_chunked_forward( + self, + query, + key, + value, + g, + beta, + output_final_state=False, + initial_state=None, + _segment_disabled=False, + ): + """Fused single-kernel chunked forward for CTE — SSD-style. + + Processes all chunks in a single NKI kernel call per (B,H) pair. + State persists in SBUF across chunks (no HBM round-trips). + Cumsum of g computed in-kernel via tensor_tensor_scan. + + This is the optimized version of _nki_chunked_forward with: + 1. Single kernel call per (B,H) instead of B*H*num_chunks + 2. State in SBUF across all chunks (biggest perf win) + 3. In-kernel cumsum (avoids PyTorch cumsum overhead) + 4. tensor_scalar for broadcasts (no explicit loops) + + initial_state is the restored GDN recurrent checkpoint for warm or + partial-prefix suffix prefill. Cold prefill passes zeros. + """ + chunk_size = int(os.environ.get("QWEN36_DELTANET_CHUNK_SIZE", "128")) + if chunk_size not in (64, 128): + raise ValueError( + "QWEN36_DELTANET_CHUNK_SIZE must be 64 or 128 for fused CTE; " + f"got {chunk_size}" + ) + + cte_impl = os.environ.get("QWEN36_DELTANET_CTE_IMPL", "current").lower() + if cte_impl in ("legacy", "legacy_direct", "direct"): + use_legacy_direct_cte = True + elif cte_impl in ("current", "optimized"): + use_legacy_direct_cte = False + else: + raise ValueError( + "QWEN36_DELTANET_CTE_IMPL must be current or legacy_direct; " + f"got {cte_impl!r}" + ) + if use_legacy_direct_cte and chunk_size != 128: + raise ValueError( + "QWEN36_DELTANET_CTE_IMPL=legacy_direct requires " + f"QWEN36_DELTANET_CHUNK_SIZE=128; got {chunk_size}" + ) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + + # Pad sequence to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + segment_tokens = int( + os.environ.get("QWEN36_DELTANET_FUSED_SEGMENT_TOKENS", "0") or "0" + ) + if ( + not _segment_disabled + and segment_tokens > 0 + and total_seq_len > segment_tokens + ): + if segment_tokens < chunk_size or segment_tokens % chunk_size != 0: + raise ValueError( + "QWEN36_DELTANET_FUSED_SEGMENT_TOKENS must be a positive " + "multiple of QWEN36_DELTANET_CHUNK_SIZE; " + f"got segment_tokens={segment_tokens}, chunk_size={chunk_size}" + ) + segment_outputs = [] + state = initial_state + for start in range(0, total_seq_len, segment_tokens): + end = min(start + segment_tokens, total_seq_len) + segment_output, state = self._fused_chunked_forward( + query[:, :, start:end, :], + key[:, :, start:end, :], + value[:, :, start:end, :], + g[:, :, start:end], + beta[:, :, start:end], + output_final_state=True, + initial_state=state, + _segment_disabled=True, + ) + segment_outputs.append(segment_output) + output = torch.cat(segment_outputs, dim=2)[:, :, :S, :] + return output, state if output_final_state else None + + if use_legacy_direct_cte: + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + query = query * (1.0 / (k_dim ** 0.5)) + + BH = B * H + # Flatten to (BH, S, dim). Grouped multihead launches are opt-in + # because isolated validation must pass before using them in artifacts. + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + + # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + if initial_state is None: + initial_state_flat = torch.zeros( + BH, k_dim, v_dim, dtype=torch.float32, device=query.device + ) + else: + initial_state_flat = initial_state.reshape(BH, k_dim, v_dim).float().contiguous() + + # Create constant mask tensors (shared across all B*H calls) + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + use_multihead_cte = ( + not use_legacy_direct_cte + and os.environ.get("QWEN36_DELTANET_MULTIHEAD_CTE", "1") != "0" + ) + if use_multihead_cte: + pair_outputs = [] + pair_states = [] + head_group_size = _resolve_deltanet_multihead_group_size(BH) + for bh_start in range(0, BH, head_group_size): + bh_end = min(bh_start + head_group_size, BH) + launch_heads = bh_end - bh_start + launch_spec = _deltanet_multihead_launch_spec(launch_heads) + out_pair, state_pair = _deltanet_fused_multihead_kernel[launch_spec]( + query_flat[bh_start:bh_end], # (G, S, 128) + key_flat[bh_start:bh_end], # (G, S, 128) + value_flat[bh_start:bh_end], # (G, S, 128) + g_flat[bh_start:bh_end], # (G, S, 1) — RAW g, not cumsum + beta_flat[bh_start:bh_end], # (G, S, 1) — sigmoid(b) + initial_state_flat[bh_start:bh_end], + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 128) + ) + pair_outputs.append(out_pair) + pair_states.append(state_pair) + + output = torch.cat(pair_outputs, dim=0) + final_state = torch.cat(pair_states, dim=0) + else: + fused_singlehead_kernel = ( + _deltanet_fused_legacy_direct_kernel + if use_legacy_direct_cte + else _deltanet_fused_kernel + ) + all_outputs = [] + all_states = [] + for bh in range(BH): + out_bh, state_bh = fused_singlehead_kernel( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 1) — RAW g, not cumsum + beta_flat[bh], # (S, 1) — sigmoid(b) + initial_state_flat[bh], # (128, 128) recurrent checkpoint + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 128) + ) + all_outputs.append(out_bh) + all_states.append(state_bh) + + output = torch.stack(all_outputs, dim=0) + final_state = torch.stack(all_states, dim=0) + + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _compact_autocp_chunked_forward( + self, query, key, value, g, beta, output_final_state=False, initial_state=None + ): + """Compact AutoCP CTE probe: prefix segment state summaries, replay segments. + + Compared with ``_autocp_chunked_forward``, this avoids materializing + per-chunk output-affine tensors. It is intentionally opt-in because the + first version reuses the existing recurrent fused kernel for segment + replay; a later NKI replay kernel can collapse the segment loop. + """ + chunk_size = 128 + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + if k_dim != 128 or v_dim != 128: + raise ValueError( + "QWEN36_DELTANET_COMPACT_AUTOCP_CTE requires 128-wide " + f"key/value heads; got k_dim={k_dim}, v_dim={v_dim}" + ) + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + num_chunks = total_seq_len // chunk_size + if num_chunks <= 0: + raise ValueError("QWEN36_DELTANET_COMPACT_AUTOCP_CTE requires chunks") + cp_chunks = _resolve_deltanet_autocp_cp_chunks(num_chunks) + num_segments = num_chunks // cp_chunks + launch_lnc = _resolve_deltanet_autocp_lnc(num_segments) + summary_launch_spec = _deltanet_autocp_affine_launch_spec( + num_segments, + launch_lnc, + ) + + BH = B * H + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + if initial_state is None: + initial_state_flat = torch.zeros( + BH, k_dim, v_dim, dtype=torch.float32, device=query.device + ) + else: + initial_state_flat = initial_state.reshape(BH, k_dim, v_dim).float().contiguous() + + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + segment_len = cp_chunks * chunk_size + all_outputs = [] + all_states = [] + for bh in range(BH): + segment_matrix, segment_bias = ( + _deltanet_autocp_state_summary_sequence[summary_launch_spec]( + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + lower_mask, + identity_mat, + ) + ) + segment_states, final_state = _deltanet_autocp_state_prefix( + segment_matrix, + segment_bias, + initial_state_flat[bh], + ) + + q_segments = query_flat[bh].reshape(num_segments, segment_len, k_dim).contiguous() + k_segments = key_flat[bh].reshape(num_segments, segment_len, k_dim).contiguous() + v_segments = value_flat[bh].reshape(num_segments, segment_len, v_dim).contiguous() + g_segments = g_flat[bh].reshape(num_segments, segment_len, 1).contiguous() + beta_segments = beta_flat[bh].reshape(num_segments, segment_len, 1).contiguous() + + replay_group_size = _resolve_deltanet_multihead_group_size(num_segments) + replay_outputs = [] + for segment_start in range(0, num_segments, replay_group_size): + segment_end = min(segment_start + replay_group_size, num_segments) + launch_segments = segment_end - segment_start + replay_launch_spec = _deltanet_multihead_launch_spec(launch_segments) + out_group, _ = _deltanet_fused_multihead_kernel[replay_launch_spec]( + q_segments[segment_start:segment_end], + k_segments[segment_start:segment_end], + v_segments[segment_start:segment_end], + g_segments[segment_start:segment_end], + beta_segments[segment_start:segment_end], + segment_states[segment_start:segment_end], + lower_mask, + identity_mat, + lower_mask_diag, + ) + replay_outputs.append(out_group) + out_segments = torch.cat(replay_outputs, dim=0) + + all_outputs.append(out_segments.reshape(total_seq_len, v_dim)) + all_states.append(final_state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _autocp_chunked_forward( + self, query, key, value, g, beta, output_final_state=False, initial_state=None + ): + """FlashQLA-style AutoCP CTE path for exact GDN prefill probes. + + This path decomposes each 128-token chunk into an affine state transform, + scans chunk states, then applies the per-chunk initial state to outputs. + It is gated by QWEN36_DELTANET_AUTOCP_CTE while we measure whether the + extra custom-call/HBM traffic beats the recurrent fused path. + """ + chunk_size = 128 + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + if k_dim != 128 or v_dim != 128: + raise ValueError( + "QWEN36_DELTANET_AUTOCP_CTE requires 128-wide key/value heads; " + f"got k_dim={k_dim}, v_dim={v_dim}" + ) + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + num_chunks = total_seq_len // chunk_size + if num_chunks <= 0: + raise ValueError("QWEN36_DELTANET_AUTOCP_CTE requires at least one chunk") + launch_lnc = _resolve_deltanet_autocp_lnc(num_chunks) + + BH = B * H + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + if initial_state is None: + initial_state_flat = torch.zeros( + BH, k_dim, v_dim, dtype=torch.float32, device=query.device + ) + else: + initial_state_flat = initial_state.reshape(BH, k_dim, v_dim).float().contiguous() + + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + affine_launch_spec = _deltanet_autocp_affine_launch_spec( + num_chunks, + launch_lnc, + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + output_base, output_state, state_matrix, state_bias = ( + _deltanet_autocp_affine_sequence[affine_launch_spec]( + query_flat[bh], + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + lower_mask, + identity_mat, + lower_mask_diag, + ) + ) + if os.environ.get("QWEN36_DELTANET_AUTOCP_SPLIT_APPLY") == "1": + chunk_states, final_state = _deltanet_autocp_state_prefix( + state_matrix, + state_bias, + initial_state_flat[bh], + ) + out_bh = _deltanet_autocp_apply_output( + output_base, + output_state, + chunk_states, + ) + else: + out_bh, final_state = _deltanet_autocp_prefix_apply_output( + output_base, + output_state, + state_matrix, + state_bias, + initial_state_flat[bh], + ) + all_outputs.append(out_bh) + all_states.append(final_state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _sequential_forward(self, query, key, value, g, beta, output_final_state=False): + """Sequential full-sequence gated delta rule for CTE. + + Uses the same per-step recurrence as _recurrent_step but loops over the + full sequence. Avoids the slice-assignment loop in _chunk_forward that + may compile incorrectly on Neuron/XLA. + """ + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + state = query.new_zeros(B, H, k_dim, v_dim) + all_outputs = [] + for t in range(S): + q_t = query[:, :, t] # (B, H, K) + k_t = key[:, :, t] # (B, H, K) + v_t = value[:, :, t] # (B, H, V) + beta_t = beta[:, :, t].unsqueeze(-1) # (B, H, 1) + g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1) + + # Gated delta rule + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + delta = (v_t - kv_mem) * beta_t # (B, H, V) + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) # (B, H, K, V) + + o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + all_outputs.append(o_t.unsqueeze(2)) + + output = torch.cat(all_outputs, dim=2) # (B, H, S, V) + final_state = state if output_final_state else None + return output, final_state + + def _chunk_forward( + self, query, key, value, g, beta, output_final_state=False, initial_state=None + ): + """Chunk-based forward for context encoding (prefill).""" + chunk_size = 64 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = v_beta.reshape(B, H, num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + if initial_state is None: + last_recurrent_state = torch.zeros( + B, H, k_dim, v_dim, dtype=query.dtype, device=query.device + ) + else: + last_recurrent_state = initial_state.to(dtype=query.dtype) + core_attn_out = torch.zeros_like(value) + mask2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask2, 0 + ) + + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + core_attn_out = core_attn_out.reshape(B, H, -1, v_dim) + core_attn_out = core_attn_out[:, :, :S] + + if not output_final_state: + last_recurrent_state = None + + return core_attn_out, last_recurrent_state + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + **kwargs, + ): + """Forward pass compatible with NxDI decoder layer interface.""" + batch_size, seq_len, _ = hidden_states.shape + + seq_ids = kwargs.get("seq_ids", None) + is_for_context_encoding = bool(kwargs.get("is_for_context_encoding", False)) + qwen_chunked_prefill_active = ( + self.use_qwen_hybrid_chunked_prefill + and past_key_value is not None + and seq_len > 1 + ) + is_decode = ( + past_key_value is not None + and not qwen_chunked_prefill_active + and not is_for_context_encoding + ) + + # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding. + # Passed from get_model_output where it's computed from input_ids != pad_token_id. + # Embeddings are already zeroed for padding tokens; this mask additionally + # zeros the decay gate so the recurrent state is preserved unchanged + # through padding positions (no spurious decay). + valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + static_hybrid_cache_active = self.use_hybrid_cache_manager + recurrent_state_cache = None + conv_state_cache = None + if static_hybrid_cache_active and past_key_value is not None: + recurrent_state_cache, conv_state_cache = past_key_value + elif ( + self.use_hybrid_apc_manager + and past_key_value is not None + and len(past_key_value) == 2 + and getattr(past_key_value[0], "dim", lambda: 0)() == 4 + and getattr(past_key_value[1], "dim", lambda: 0)() == 3 + and past_key_value[0].shape[1:] == self.recurrent_state_buffer.shape[1:] + and past_key_value[1].shape[1:] == self.conv_state_buffer.shape[1:] + ): + recurrent_state_cache, conv_state_cache = past_key_value + + # Project inputs + deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" + if deltanet_fp32 and isinstance(self.in_proj_qkv, nn.Linear): + hs_f32 = hidden_states.float() + qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to( + hidden_states.dtype + ) + z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype) + b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype) + a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype) + else: + qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Split QKV + query = qkv[..., : self.key_dim] + key = qkv[..., self.key_dim : self.key_dim * 2] + value = qkv[..., self.key_dim * 2 :] + + # Causal Conv1d on QKV + mixed = torch.cat([query, key, value], dim=-1) + mixed = mixed.transpose(1, 2) + + if is_decode: + if conv_state_cache is not None: + conv_state = conv_state_cache[:batch_size] + elif seq_ids is not None: + conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) + else: + conv_state = self.conv_state_buffer[:batch_size] + conv_input = torch.cat([conv_state, mixed], dim=-1) + + w = self._conv1d_weight().squeeze(1) + if seq_len == 1: + conv_out = ( + conv_input[:, :, : self.conv_kernel_size] * w.unsqueeze(0) + ).sum(dim=-1, keepdim=True) + else: + conv_out = torch.zeros_like(mixed) + for k in range(self.conv_kernel_size): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) + * conv_input[:, :, k : k + 1] + ) + mixed_post_conv = F.silu(conv_out) + + new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) + expected_state_len = self.conv_state_buffer.shape[-1] + if new_conv_state.shape[-1] != expected_state_len: + if new_conv_state.shape[-1] > expected_state_len: + new_conv_state = new_conv_state[:, :, -expected_state_len:] + else: + new_conv_state = F.pad( + new_conv_state, + (expected_state_len - new_conv_state.shape[-1], 0), + ) + alloc_bs = self.conv_state_buffer.shape[0] + if static_hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: + new_conv_state = _qwen36_update_state_rows_by_seq_ids( + self.conv_state_buffer, + new_conv_state.to(self.conv_state_buffer.dtype), + seq_ids, + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + self.conv_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + if ( + conv_state_cache is not None + and (qwen_chunked_prefill_active or is_for_context_encoding) + ): + cold_prefill_from_zero = self.use_cold_zero_conv_fast_path + if cold_prefill_from_zero: + mixed_post_conv = F.silu( + F.conv1d( + mixed, + self._conv1d_weight(), + bias=None, + padding=self.conv_kernel_size - 1, + groups=self.conv_dim, + )[:, :, :seq_len] + ) + state_source = mixed + else: + conv_state = conv_state_cache[:batch_size] + if position_ids is not None: + reset_mask = (position_ids[:, :1].long() == 0).to( + dtype=conv_state.dtype, device=conv_state.device + ) + conv_state = conv_state * ( + 1.0 - reset_mask[:, :, None] + ) + conv_input = torch.cat([conv_state, mixed], dim=-1) + w = self._conv1d_weight().squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(self.conv_kernel_size): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) + * conv_input[:, :, k : k + seq_len] + ) + mixed_post_conv = F.silu(conv_out) + state_source = conv_input + + state_len = self.conv_kernel_size - 1 + if valid_mask_1d is not None: + num_valid = valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + idx_base = (state_source.shape[-1] - seq_len + num_valid - state_len).clamp(min=0) + offsets = torch.arange(state_len, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(state_source, 2, gather_idx) + else: + new_conv_state = state_source[:, :, -state_len:].contiguous() + else: + mixed_post_conv = F.silu( + F.conv1d( + mixed, + self._conv1d_weight(), + bias=None, + padding=self.conv_kernel_size - 1, + groups=self.conv_dim, + )[:, :, :seq_len] + ) + + if valid_mask_1d is not None: + # valid_mask_1d is [B, S, 1]; count valid tokens per batch + state_len = self.conv_kernel_size - 1 + num_valid = ( + valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + ) # [B, 1] + idx_base = num_valid - state_len + idx_base = idx_base.clamp(min=0) + offsets = torch.arange(state_len, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets # [B, state_len] + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(mixed, 2, gather_idx) + else: + new_conv_state = mixed[:, :, -self.conv_kernel_size + 1 :].contiguous() + + alloc_bs = self.conv_state_buffer.shape[0] + if static_hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: + new_conv_state = _qwen36_update_state_rows_by_seq_ids( + self.conv_state_buffer, + new_conv_state.to(self.conv_state_buffer.dtype), + seq_ids, + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + torch.zeros( + pad_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=new_conv_state.dtype, + device=new_conv_state.device, + ), + ], + dim=0, + ) + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + + mixed_post_conv = mixed_post_conv.transpose(1, 2) + + # Zero out conv1d output for padding positions. + # Conv1d with kernel_size=4 leaks real token info into the first + # few padding positions. Zeroing here ensures Q, K, V are exactly + # zero for all padding positions so the recurrence is unaffected. + if valid_mask_1d is not None: + mixed_post_conv = ( + mixed_post_conv * valid_mask_1d + ) # [B, S, conv_dim] * [B, S, 1] + + query = mixed_post_conv[..., : self.key_dim] + key = mixed_post_conv[..., self.key_dim : self.key_dim * 2] + value = mixed_post_conv[..., self.key_dim * 2 :] + + # Reshape to heads + query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + + # Compute gating + beta = b.sigmoid() + g = -self._A_log().float().exp() * F.softplus(a.float() + self._dt_bias()) + + if valid_mask_1d is not None: + # Zero g for padding → alpha=exp(0)=1 → state preserved through padding + # Zero beta for padding → no state update from padding tokens + mask_2d = valid_mask_1d.squeeze(-1).float() # [B, S] + g = g * mask_2d.unsqueeze(-1) + beta = beta * mask_2d.unsqueeze(-1) + + # Expand K heads to match V heads (16 -> 48) using expand+reshape + if self.num_v_heads // self.num_k_heads > 1: + rep = self.num_v_heads // self.num_k_heads # 3 + query = ( + query.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + key = ( + key.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + + # Transpose to (B, H, S, dim) + query = query.transpose(1, 2).contiguous().float() + key = key.transpose(1, 2).contiguous().float() + value = value.transpose(1, 2).contiguous().float() + g = g.transpose(1, 2).contiguous().float() + beta = beta.transpose(1, 2).contiguous().float() + + if is_decode: + # TKG: single-step recurrent update + if recurrent_state_cache is not None: + recurrent_state = recurrent_state_cache[:batch_size] + elif seq_ids is not None: + recurrent_state = torch.index_select( + self.recurrent_state_buffer, 0, seq_ids + ) + else: + recurrent_state = self.recurrent_state_buffer[:batch_size] + + use_nki_decode = ( + self.use_qwen_deltanet_decode_nki + or os.environ.get("USE_NKI_DECODE") == "1" + ) + if use_nki_decode and seq_len == 1: + output, new_state = self._nki_recurrent_step( + query, key, value, g, beta, recurrent_state + ) + else: + output, new_state = self._recurrent_step( + query, key, value, g, beta, recurrent_state.float() + ) + new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if static_hybrid_cache_active: + new_rec_state = new_state_bf16 + elif seq_ids is not None: + new_rec_state = _qwen36_update_state_rows_by_seq_ids( + self.recurrent_state_buffer, + new_state_bf16, + seq_ids, + ) + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + new_state_bf16, + self.recurrent_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + else: + # CTE: fused NKI kernel by default (PyTorch _chunk_forward can hit + # neuronx-cc codegen ICE NCC_INLA001 with these DeltaNet dimensions). + # Override with env vars for debugging/benchmarking. + use_nki_fused = os.environ.get("USE_NKI_FUSED", "1") != "0" + use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1" + use_nki = os.environ.get("USE_NKI") == "1" + use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1" + use_pytorch_chunk = os.environ.get("USE_PYTORCH_CHUNK") == "1" + use_autocp_cte = os.environ.get("QWEN36_DELTANET_AUTOCP_CTE") == "1" + use_compact_autocp_cte = ( + os.environ.get("QWEN36_DELTANET_COMPACT_AUTOCP_CTE") == "1" + ) + + if recurrent_state_cache is not None and ( + qwen_chunked_prefill_active or is_for_context_encoding + ): + initial_state = recurrent_state_cache[:batch_size].float() + if position_ids is not None: + reset_mask = (position_ids[:, :1].long() == 0).to( + dtype=initial_state.dtype, device=initial_state.device + ) + initial_state = initial_state * (1.0 - reset_mask[:, :, None, None]) + if use_autocp_cte and use_compact_autocp_cte: + output, final_state = self._compact_autocp_chunked_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + elif use_autocp_cte: + output, final_state = self._autocp_chunked_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + elif use_nki_chunked or ( + self.use_qwen_hybrid_chunked_prefill_nki + and os.environ.get("USE_NKI_FUSED", "1") == "0" + ): + output, final_state = self._nki_chunked_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + elif use_pytorch_chunk: + output, final_state = self._chunk_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + else: + output, final_state = self._fused_chunked_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + elif use_pytorch_chunk: + output, final_state = self._chunk_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_autocp_cte and use_compact_autocp_cte: + output, final_state = self._compact_autocp_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_autocp_cte: + output, final_state = self._autocp_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_chunked: + output, final_state = self._nki_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki: + output, final_state = self._nki_recurrent_forward( + query, key, value, g, beta + ) + elif use_sequential: + output, final_state = self._sequential_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_fused: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + else: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + + if final_state is not None: + final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if static_hybrid_cache_active: + new_rec_state = final_state_bf16 + elif seq_ids is not None: + new_rec_state = _qwen36_update_state_rows_by_seq_ids( + self.recurrent_state_buffer, + final_state_bf16, + seq_ids, + ) + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + final_state_bf16, + torch.zeros( + alloc_bs - batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=final_state_bf16.dtype, + device=final_state_bf16.device, + ), + ], + dim=0, + ) + new_rec_state = new_rec_state + self.recurrent_state_buffer * 0 + else: + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + else: + new_rec_state = self.recurrent_state_buffer * 1 + + if ( + is_for_context_encoding + and not static_hybrid_cache_active + and valid_mask_1d is not None + and hasattr(valid_mask_1d, "numel") + and valid_mask_1d.numel() > 0 + ): + active_rows = _qwen36_active_state_rows(valid_mask_1d, seq_ids) + new_conv_state = _qwen36_preserve_inactive_state_rows( + new_conv_state, + self.conv_state_buffer, + active_rows, + ) + new_rec_state = _qwen36_preserve_inactive_state_rows( + new_rec_state, + self.recurrent_state_buffer, + active_rows, + ) + + # Output: norm, gate, project + output = output.to(hidden_states.dtype) + output = output.transpose(1, 2).contiguous() + output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = self.norm(output) + z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = output * F.silu(z_gate) + output = output.reshape(batch_size, seq_len, self.value_dim) + output = self.out_proj(output) + + if static_hybrid_cache_active: + return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state + + # Return dummy KV for KVCacheManager + dummy_k = torch.zeros( + batch_size, + self.kv_heads_per_rank, + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + dummy_v = torch.zeros_like(dummy_k) + + return output, (dummy_k, dummy_v), new_rec_state, new_conv_state + + +# ============================================================ +# InferenceConfig (Dense -- no MoE) +# ============================================================ + + +class Qwen35InferenceConfig(InferenceConfig): + """Config for Qwen3.5/3.6-27B (dense) with hybrid DeltaNet + Attention.""" + + @classmethod + def from_pretrained(cls, model_path: str, **kwargs) -> "Qwen35InferenceConfig": + """Load Qwen3.5/Qwen3.6 text config from a pretrained model directory. + + Qwen3.6 stores the decoder settings under the top-level multimodal + `text_config`. NxDI's text-only inference config expects those fields + flattened onto the inference config itself. + """ + neuron_config = kwargs.pop("neuron_config", None) + if neuron_config is None: + neuron_config = NeuronConfig( + tp_degree=1, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + save_sharded_checkpoint=True, + ) + + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found at {config_path}") + + with open(config_path, "r", encoding="utf-8") as handle: + config_dict = json.load(handle) + + text_config = config_dict.get("text_config", config_dict) + rope_parameters = text_config.get("rope_parameters") or {} + inference_config = dict(text_config) + inference_config.setdefault("_name_or_path", model_path) + inference_config.setdefault("model_type", "qwen3_5_text") + inference_config.setdefault("architectures", config_dict.get("architectures", [])) + inference_config.setdefault("tie_word_embeddings", config_dict.get("tie_word_embeddings", False)) + if "rope_theta" not in inference_config and "rope_theta" in rope_parameters: + inference_config["rope_theta"] = rope_parameters["rope_theta"] + if ( + "partial_rotary_factor" not in inference_config + and "partial_rotary_factor" in rope_parameters + ): + inference_config["partial_rotary_factor"] = rope_parameters[ + "partial_rotary_factor" + ] + inference_config.update(kwargs) + return cls(neuron_config=neuron_config, **inference_config) + + def __init__(self, *args, **kwargs): + # Set defaults BEFORE super().__init__() because it calls validate_config() + # which checks get_required_attributes(). These can be overridden by + # kwargs or load_config. + + # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] repeated. + if "layer_types" not in kwargs and not any( + hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__") + ): + num_layers = kwargs.get("num_hidden_layers", 64) + if num_layers % 4 != 0: + raise ValueError( + f"Qwen3.5 hybrid layer count must be divisible by 4, got {num_layers}" + ) + layer_types = [] + for _ in range(num_layers // 4): + layer_types.extend( + [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + ) + kwargs.setdefault("layer_types", layer_types) + + # DeltaNet-specific config defaults + kwargs.setdefault("linear_num_value_heads", 48) + kwargs.setdefault("linear_num_key_heads", 16) + kwargs.setdefault("linear_key_head_dim", 128) + kwargs.setdefault("linear_value_head_dim", 128) + kwargs.setdefault("linear_conv_kernel_dim", 4) + kwargs.setdefault("use_hybrid_cache_manager", False) + kwargs.setdefault("use_hybrid_apc_manager", False) + kwargs.setdefault("use_qwen_hybrid_chunked_prefill", False) + kwargs.setdefault("use_qwen_hybrid_chunked_prefill_nki", False) + kwargs.setdefault("use_qwen_deltanet_decode_nki", False) + kwargs.setdefault("gdn_checkpoint_interval", 256) + kwargs.setdefault("max_gdn_checkpoint_slots", 8) + kwargs.setdefault("hybrid_apc_layout_version", 1) + kwargs.setdefault("hybrid_apc_allow_residual_replay", False) + kwargs.setdefault("hybrid_apc_cache_salt", None) + use_hybrid_apc_manager = bool(kwargs.get("use_hybrid_apc_manager", False)) + kwargs.setdefault( + "hybrid_apc_require_vllm_metadata", use_hybrid_apc_manager + ) + kwargs.setdefault( + "hybrid_apc_allow_local_hash_fallback", not use_hybrid_apc_manager + ) + kwargs.setdefault( + "hybrid_apc_require_attention_block_refs", use_hybrid_apc_manager + ) + kwargs.setdefault("hybrid_apc_reject_unbacked_attention_hits", True) + kwargs.setdefault("hybrid_apc_disable_unbacked_prefix_reads", False) + kwargs.setdefault("hybrid_apc_enable_backed_prefix_reads", False) + kwargs.setdefault( + "hybrid_apc_model_revision", + kwargs.get("_name_or_path", kwargs.get("model_revision", "unknown")), + ) + kwargs.setdefault( + "hybrid_recurrent_cache_dtype", + kwargs.get("gdn_recurrent_cache_dtype", "float32"), + ) + kwargs.setdefault( + "hybrid_conv_cache_dtype", + kwargs.get("gdn_conv_cache_dtype", "bfloat16"), + ) + kwargs.setdefault( + "gdn_recurrent_cache_dtype", kwargs["hybrid_recurrent_cache_dtype"] + ) + kwargs.setdefault("gdn_conv_cache_dtype", kwargs["hybrid_conv_cache_dtype"]) + kwargs.setdefault("hybrid_cache_mode", "all") + kwargs.setdefault( + "hybrid_cache_prefix_boundary_only", + kwargs.get("hybrid_cache_block_boundary_only", True), + ) + kwargs.setdefault( + "hybrid_cache_block_boundary_only", + kwargs["hybrid_cache_prefix_boundary_only"], + ) + kwargs.setdefault("hybrid_cache_validate_exact", False) + kwargs.setdefault("use_text_only_cte_inputs", True) + kwargs.setdefault("use_compact_cte_attention_mask", True) + kwargs.setdefault("use_cold_zero_conv_fast_path", False) + kwargs.setdefault("disable_token_generation_wlo", False) + + super().__init__(*args, **kwargs) + + self.gdn_checkpoint_interval = int(self.gdn_checkpoint_interval) + if self.gdn_checkpoint_interval <= 0: + raise ValueError( + "gdn_checkpoint_interval must be positive, " + f"got {self.gdn_checkpoint_interval}" + ) + self.max_gdn_checkpoint_slots = int(self.max_gdn_checkpoint_slots) + if self.max_gdn_checkpoint_slots <= 0: + raise ValueError( + "max_gdn_checkpoint_slots must be positive, " + f"got {self.max_gdn_checkpoint_slots}" + ) + self.hybrid_apc_layout_version = int(self.hybrid_apc_layout_version) + self.hybrid_recurrent_cache_dtype = _normalize_hybrid_cache_dtype( + "hybrid_recurrent_cache_dtype", + self.hybrid_recurrent_cache_dtype, + "float32", + ) + self.hybrid_conv_cache_dtype = _normalize_hybrid_cache_dtype( + "hybrid_conv_cache_dtype", + self.hybrid_conv_cache_dtype, + "bfloat16", + ) + self.gdn_recurrent_cache_dtype = self.hybrid_recurrent_cache_dtype + self.gdn_conv_cache_dtype = self.hybrid_conv_cache_dtype + self.hybrid_cache_block_boundary_only = ( + self.hybrid_cache_prefix_boundary_only + ) + self.hybrid_apc_require_vllm_metadata = bool( + self.hybrid_apc_require_vllm_metadata + ) + self.hybrid_apc_allow_local_hash_fallback = bool( + self.hybrid_apc_allow_local_hash_fallback + ) + self.hybrid_apc_require_attention_block_refs = bool( + self.hybrid_apc_require_attention_block_refs + ) + self.hybrid_apc_reject_unbacked_attention_hits = bool( + self.hybrid_apc_reject_unbacked_attention_hits + ) + self.hybrid_apc_disable_unbacked_prefix_reads = bool( + self.hybrid_apc_disable_unbacked_prefix_reads + ) + if self.hybrid_apc_require_vllm_metadata: + self.hybrid_apc_allow_local_hash_fallback = False + self.hybrid_apc_require_attention_block_refs = True + self.hybrid_apc_reject_unbacked_attention_hits = True + if self.use_hybrid_cache_manager and self.use_hybrid_apc_manager: + raise ValueError( + "use_hybrid_cache_manager and use_hybrid_apc_manager are mutually exclusive" + ) + if self.use_hybrid_apc_manager and self.hybrid_cache_mode != "all": + raise ValueError("use_hybrid_apc_manager requires hybrid_cache_mode='all'") + if self.use_hybrid_apc_manager: + if self.hybrid_recurrent_cache_dtype != "float32": + raise ValueError( + "use_hybrid_apc_manager requires float32 recurrent GDN " + "checkpoint cache state; bf16 checkpoint roundtrips are not " + "coherent for all-mode prefix caching" + ) + pa_block_size = getattr(self.neuron_config, "pa_block_size", None) + if pa_block_size is not None and self.gdn_checkpoint_interval != int( + pa_block_size + ): + raise ValueError( + "use_hybrid_apc_manager v0 requires " + "gdn_checkpoint_interval == pa_block_size" + ) + if self.hybrid_apc_allow_residual_replay: + raise ValueError( + "hybrid_apc_allow_residual_replay is reserved for v1; " + "v0 restores only exact checkpoint boundaries" + ) + + # Attention output gate + self.attn_output_gate = getattr(self, "attn_output_gate", True) + + # Partial RoPE + self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25) + self.rope_dim = int(self.head_dim * self.partial_rotary_factor) # 64 + + # mRoPE (multimodal RoPE) for VL support + rope_params = getattr(self, "rope_parameters", {}) or {} + self.mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + self.mrope_interleaved = rope_params.get("mrope_interleaved", True) + + # Standard HF config attributes expected by NxDI + if not hasattr(self, "output_attentions"): + self.output_attentions = False + if not hasattr(self, "output_hidden_states"): + self.output_hidden_states = False + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "intermediate_size", + "max_position_embeddings", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "rms_norm_eps", + "rope_theta", + "vocab_size", + # DeltaNet-specific + "linear_num_value_heads", + "linear_num_key_heads", + "linear_key_head_dim", + "linear_value_head_dim", + "linear_conv_kernel_dim", + "layer_types", + ] + + @classmethod + def get_neuron_config_cls(cls): + return NeuronConfig + + +# ============================================================ +# Attention (standard GQA for 16 of 64 layers) +# With output gate: q_proj is 2x sized, split into (query, gate) +# With partial RoPE: only first rope_dim dimensions get rotary +# ============================================================ + + +class Qwen35MRoPEEmbedding(nn.Module): + """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5. + + Handles 3D position information (temporal, height, width) for VL models. + Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions. + For text-only (2D position_ids), broadcasts to 3D with identical positions. + """ + + def __init__(self, config): + super().__init__() + self.head_dim = config.head_dim # 256 + self.rope_dim = config.rope_dim # 64 + self.mrope_section = config.mrope_section # [11, 11, 10] + self.mrope_interleaved = getattr(config, "mrope_interleaved", True) + self.rope_theta = config.rope_theta + + # Validate mrope_section sums to rope_dim // 2 = 32 + assert sum(self.mrope_section) == self.rope_dim // 2, ( + f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, " + f"expected {self.rope_dim // 2}" + ) + + def forward(self, x, position_ids_3d): + """Compute cos/sin from 3D position IDs. + + Args: + x: hidden_states (for device/dtype inference) + position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions + + Returns: + cos: (batch_size, seq_len, rope_dim) + sin: (batch_size, seq_len, rope_dim) + """ + device = x.device + dtype = torch.float32 + + if position_ids_3d.ndim == 2: + position_ids_3d = position_ids_3d[None, ...].expand( + 3, position_ids_3d.shape[0], -1 + ) + + inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, self.rope_dim, 2, dtype=dtype, device=device) + / self.rope_dim + ) + ) + inv_freq = inv_freq[None, None, :, None].expand( + 3, position_ids_3d.shape[1], -1, 1 + ) + positions = position_ids_3d[:, :, None, :].float() + freqs = (inv_freq.float() @ positions).transpose(2, 3) + + # Match HF Qwen3.6 mRoPE layout exactly: start from the temporal + # frequencies, then splice H/W frequencies into interleaved positions. + freqs_t = freqs[0] + if self.mrope_interleaved: + for dim, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + cos = emb.cos().to(dtype=x.dtype) + sin = emb.sin().to(dtype=x.dtype) + + return cos, sin + + +class NeuronQwen35Attention(NeuronAttentionBase): + """Standard GQA attention for Qwen3.5 with output gate and partial RoPE. + + 24 Q heads, 4 KV heads (6:1 GQA), head_dim=256 for 27B dense. + q_proj is doubled (query + gate), split at load time. + Only first rope_dim=64 of head_dim=256 gets rotary encoding. + + Uses NeuronAttentionBase infrastructure for QKV projection, KV cache, + RoPE, and attention computation. Overrides forward() to insert the + sigmoid output gate between attention output and o_proj. + """ + + def __init__(self, config): + # Partial RoPE: create mRoPE embedding with rope_dim (64) + self.rope_dim = config.rope_dim # 64 = head_dim * partial_rotary_factor + + # Create QK norm modules (will be passed to base class) + rms_norm_eps = config.rms_norm_eps + q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + + # Partial RoPE: use standard RotaryEmbedding. + # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in + # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache. + rotary_emb = RotaryEmbedding( + self.rope_dim, # Only 64 dims get rotary embedding + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_ln, + k_layernorm=k_ln, + ) + + # Separate mRoPE module for VL 3D position_ids + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + # Output gate projection: hidden_size -> num_heads * head_dim + # Populated from the second half of q_proj during state dict conversion. + self.output_gate_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * config.head_dim, + bias=False, + gather_output=False, + ) + + self.qwen_output_gate_nki_kernel_enabled = bool( + getattr(config, "use_qwen_output_gate_nki", False) + or os.environ.get("QWEN36_OUTPUT_GATE_NKI", "0") == "1" + ) + self.qwen_qkv_gate_packed_enabled = bool( + getattr(config, "use_qwen_qkv_gate_packed", False) + or os.environ.get("QWEN36_QKV_GATE_PACKED", "0") == "1" + ) + self.qwen_gated_o_proj_nki_kernel_enabled = bool( + getattr(config, "use_qwen_gated_o_proj_nki", False) + or os.environ.get("QWEN36_GATED_OUT_PROJ_NKI", "0") == "1" + ) + if ( + self.qwen_output_gate_nki_kernel_enabled + and self.qwen_qkv_gate_packed_enabled + ): + raise ValueError( + "Qwen output-gate NKI and packed QKV+gate are mutually exclusive." + ) + if self.qwen_output_gate_nki_kernel_enabled: + if _qwen_gate_projection_kernel is None: + raise ImportError( + "QWEN36_OUTPUT_GATE_NKI requires nkilib.core.qkv.qkv" + ) + if getattr(config.neuron_config, "quantized", False): + setattr( + self.output_gate_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) + else: + self.output_gate_proj.weight = transpose_parallel_linear_layer( + self.output_gate_proj.weight + ) + + if self.qwen_qkv_gate_packed_enabled: + if _qwen_gate_projection_kernel is None: + raise ImportError( + "QWEN36_QKV_GATE_PACKED requires nkilib.core.qkv.qkv" + ) + if not self.fused_qkv: + raise ValueError("QWEN36_QKV_GATE_PACKED requires fused_qkv=True") + self._enable_qwen_qkv_gate_packed_projection(config) + + self.qwen_qk_norm_rope_nki_kernel_enabled = bool( + getattr(config, "use_qwen_qk_norm_rope_nki", False) + or os.environ.get("QWEN36_QK_NORM_ROPE_NKI", "0") == "1" + ) + if self.qwen_qk_norm_rope_nki_kernel_enabled: + if _qwen_qk_norm_partial_rope_kernel is None: + raise ImportError( + "QWEN36_QK_NORM_ROPE_NKI requires src.nki_kernels." + "qwen_qk_norm_rope" + ) + if self.head_dim != 256 or self.rope_dim != 64: + raise ValueError( + "Qwen Q/K norm+RoPE NKI kernel currently supports only " + f"head_dim=256 and rope_dim=64, got head_dim={self.head_dim}, " + f"rope_dim={self.rope_dim}" + ) + + self.qkv_tkg_nki_kernel_enabled = bool( + getattr(config.neuron_config, "qkv_tkg_nki_kernel_enabled", False) + ) and not bool(getattr(config.neuron_config, "is_prefill_stage", False)) + if self.qkv_tkg_nki_kernel_enabled: + if _qkv_tkg_nki_kernel is None: + raise ImportError( + "qkv_tkg_nki_kernel_enabled requires " + "neuronxcc.nki._pre_prod_kernels.qkv_tkg_impl" + ) + if self.fused_qkv: + raise ValueError( + "qkv_tkg_nki_kernel_enabled uses split q/k/v projections " + "and must not be combined with fused_qkv" + ) + if self.qkv_proj_sp_enabled: + raise ValueError( + "qkv_tkg_nki_kernel_enabled does not support sequence-parallel " + "QKV projection" + ) + qkv_proj = self.get_qkv_proj() + split_qkv_projections = ( + qkv_proj.q_proj, + qkv_proj.k_proj, + qkv_proj.v_proj, + ) + for projection in split_qkv_projections: + if not getattr(config.neuron_config, "quantized", False): + projection.weight = transpose_parallel_linear_layer(projection.weight) + + def _enable_qwen_qkv_gate_packed_projection(self, config): + for attr_name in ("qkv_proj", "cte_qkv_proj", "tkg_qkv_proj"): + qkv_proj = getattr(self, attr_name, None) + if qkv_proj is not None and getattr(qkv_proj, "fused_qkv", False): + self._replace_qkv_projection_with_qwen_qkvgate(qkv_proj, config) + + def _replace_qkv_projection_with_qwen_qkvgate(self, qkv_proj, config): + if not hasattr(qkv_proj, "Wqkv"): + raise ValueError("QWEN36_QKV_GATE_PACKED requires a fused Wqkv module") + if not isinstance(qkv_proj.Wqkv, ColumnParallelLinear): + raise ValueError( + "QWEN36_QKV_GATE_PACKED currently supports ColumnParallelLinear Wqkv" + ) + + packed_q_heads = qkv_proj.num_attention_heads * 2 + packed_output_size = ( + packed_q_heads + 2 * qkv_proj.num_key_value_heads + ) * qkv_proj.head_dim + packed_wqkv = ColumnParallelLinear( + qkv_proj.hidden_size, + packed_output_size, + bias=qkv_proj.bias, + gather_output=qkv_proj.gather_output, + dtype=qkv_proj.dtype, + sequence_parallel_enabled=False, + tensor_model_parallel_group=qkv_proj.tensor_model_parallel_group, + rank_ordering=qkv_proj.rank_ordering, + ) + if ( + (qkv_proj.qkv_kernel_enabled or qkv_proj.qkv_nki_kernel_enabled) + and getattr(config.neuron_config, "quantized", False) + ): + setattr( + packed_wqkv, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) + elif qkv_proj.qkv_kernel_enabled or qkv_proj.qkv_nki_kernel_enabled: + packed_wqkv.weight = transpose_parallel_linear_layer(packed_wqkv.weight) + + for param in ( + [packed_wqkv.weight, packed_wqkv.scale] + if hasattr(packed_wqkv, "scale") + else [packed_wqkv.weight] + ): + setattr(param, "fused_qkv", True) + setattr(param, "num_attention_heads", packed_q_heads) + setattr(param, "num_key_value_heads", qkv_proj.num_key_value_heads) + setattr(param, "head_dim", qkv_proj.head_dim) + if qkv_proj.bias: + setattr(packed_wqkv.bias, "fused_qkv", True) + setattr(packed_wqkv.bias, "num_attention_heads", packed_q_heads) + setattr(packed_wqkv.bias, "num_key_value_heads", qkv_proj.num_key_value_heads) + setattr(packed_wqkv.bias, "head_dim", qkv_proj.head_dim) + + qkv_proj.Wqkv = packed_wqkv + qkv_proj.qwen_qkv_gate_packed = True + qkv_proj.qwen_real_num_attention_heads = qkv_proj.num_attention_heads + qkv_proj.qwen_packed_num_attention_heads = packed_q_heads + + @staticmethod + def _apply_projection_scale(output, projection): + scale = getattr(projection, "scale", None) + if scale is None: + return output + scale_tensor = scale.data if hasattr(scale, "data") else scale + if ( + scale_tensor.ndim == 2 + and scale_tensor.shape[0] == 128 + and scale_tensor.shape[1] == output.shape[-1] + ): + scale_tensor = scale_tensor[0] + else: + scale_tensor = scale_tensor.reshape(-1) + if scale_tensor.numel() != output.shape[-1]: + raise ValueError( + "QKV TKG projection scale shape does not match output width: " + f"scale={tuple(scale.shape)}, output={tuple(output.shape)}" + ) + return output * scale_tensor.reshape(1, 1, output.shape[-1]).to(output.dtype) + + @staticmethod + def _prepare_qkv_tkg_scale(scale_tensor, output_width): + if ( + scale_tensor.ndim == 2 + and scale_tensor.shape[0] == 128 + and scale_tensor.shape[1] == output_width + ): + return scale_tensor.contiguous() + if ( + scale_tensor.ndim == 2 + and scale_tensor.shape[0] == output_width + and scale_tensor.shape[1] == 1 + ): + return torch.broadcast_to( + scale_tensor.transpose(0, 1), + (128, output_width), + ).contiguous() + if ( + scale_tensor.ndim == 2 + and scale_tensor.shape[0] == 1 + and scale_tensor.shape[1] == output_width + ): + return torch.broadcast_to(scale_tensor, (128, output_width)).contiguous() + if scale_tensor.numel() == output_width: + return torch.broadcast_to( + scale_tensor.reshape(1, output_width), + (128, output_width), + ).contiguous() + raise ValueError( + "QKV TKG projection scale shape does not match output width: " + f"scale={tuple(scale_tensor.shape)}, output_width={output_width}" + ) + + def _run_split_qkv_tkg_projection(self, hidden_states, projection, local_heads): + bias = ( + projection.bias.data.unsqueeze(0) + if getattr(projection, "bias", None) is not None + else None + ) + weight = projection.weight.data + if weight.shape[0] != self.hidden_size and weight.shape[1] == self.hidden_size: + weight = weight.transpose(0, 1).contiguous() + # The preprod QKV TKG kernel's LNC2 path reduces across pi0 and then + # stores both programs to the same shared-HBM slice, which the current + # NKI verifier rejects as an output dependency. Use the single-LNC + # variant for this split projection until that kernel store is fixed. + kernel = _qkv_tkg_nki_kernel[1] + scale = getattr(projection, "scale", None) + if scale is not None: + scale_tensor = scale.data if hasattr(scale, "data") else scale + qkv_w_scales = self._prepare_qkv_tkg_scale( + scale_tensor, + weight.shape[1], + ) + quantization_type = getattr(_QKVQuantizationType, "ROW", None) + if quantization_type is None: + raise ValueError( + "qkv_tkg_nki_kernel_enabled requires ROW quantization support " + "when running quantized split-QKV projections" + ) + else: + qkv_w_scales = None + quantization_type = _QKVQuantizationType.NONE + + output = kernel( + hidden=hidden_states, + qkv_w=weight, + norm_w=None, + fused_add=False, + mlp_prev=None, + attn_prev=None, + d_head=self.head_dim, + output_layout=_QKVOutputLayout.BSD, + eps=self.rms_norm_eps, + norm_type=_QKVNormType.NO_NORM, + qkvInSB=False, + qkv_bias=bias, + norm_bias=None, + hidden_actual=self.hidden_size, + B=hidden_states.shape[0], + S=hidden_states.shape[1], + H=self.hidden_size, + num_q_heads=local_heads, + num_kv_heads=local_heads, + quantization_type=quantization_type, + qkv_w_scales=qkv_w_scales, + qkv_in_scales=None, + ) + if qkv_w_scales is not None: + return output + return self._apply_projection_scale(output, projection) + + def _prep_split_qkv_tkg_tensors( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + use_polar_compatible_rope=False, + ): + # NxDI traces a placeholder adapter_ids tensor even when no LoRA + # adapters are active. Qwen3.6 serving here is non-LoRA, so the split + # projection path intentionally ignores the placeholder. + qkv_proj = self.get_qkv_proj() + Q = self._run_split_qkv_tkg_projection( + hidden_states, + qkv_proj.q_proj, + self.num_heads, + ) + K = self._run_split_qkv_tkg_projection( + hidden_states, + qkv_proj.k_proj, + self.num_key_value_heads, + ) + V = self._run_split_qkv_tkg_projection( + hidden_states, + qkv_proj.v_proj, + self.num_key_value_heads, + ) + + bsz, q_len, _ = hidden_states.size() + Q = move_heads_front( + Q, + bsz, + q_len, + self.num_heads, + self.head_dim, + layernorm=self.q_layernorm, + post_transpose_layernorm=self.post_transpose_layernorm, + ) + K = move_heads_front( + K, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=self.k_layernorm, + post_transpose_layernorm=self.post_transpose_layernorm, + ) + V = move_heads_front( + V, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=None, + ) + + Q, K, cos_cache, sin_cache = self.apply_rotary_embedding( + Q, + K, + V, + position_ids, + cos_cache, + sin_cache, + use_polar_compatible_rope, + ) + return Q, K, V, cos_cache, sin_cache, None + + def _should_use_qwen_output_gate_nki(self, q_len): + return self.qwen_output_gate_nki_kernel_enabled + + def _should_use_qwen_qkv_gate_packed(self, q_len): + return ( + self.qwen_qkv_gate_packed_enabled + and not self.qkv_proj_sp_enabled + and _qwen_gate_projection_kernel is not None + ) + + def _should_use_qwen_gated_o_proj_nki(self, q_len): + o_proj = self.get_o_proj() + return ( + self.qwen_gated_o_proj_nki_kernel_enabled + and q_len > 1 + and hasattr(o_proj, "forward_gated") + ) + + def _output_gate_proj_nki(self, hidden_states): + weight = self.output_gate_proj.weight.data + bias = ( + self.output_gate_proj.bias.data.unsqueeze(0) + if getattr(self.output_gate_proj, "bias", None) is not None + else None + ) + + qkv_w_scale = None + qkv_in_scale = None + quantization_type = _NkilibQuantizationType.NONE + gate_scale = getattr(self.output_gate_proj, "scale", None) + if gate_scale is not None: + qkv_w_scale = gate_scale.data + gate_input_scale = getattr(self.output_gate_proj, "input_scale", None) + qkv_in_scale = gate_input_scale.data if gate_input_scale is not None else None + quantization_type = _NkilibQuantizationType.ROW + elif getattr(self.config.neuron_config, "quantized", False): + raise RuntimeError( + "Qwen output-gate NKI path requires output_gate_proj.scale " + "when running a quantized artifact." + ) + + return _qwen_gate_projection_kernel[self.logical_nc_config]( + input=hidden_states, + fused_qkv_weights=weight, + output_layout=_NkilibQKVOutputLayout.BSD, + bias=bias, + quantization_type=quantization_type, + qkv_w_scale=qkv_w_scale, + qkv_in_scale=qkv_in_scale, + ) + + def _qkv_gate_packed_projection_nki(self, hidden_states): + qkv_proj = self.get_qkv_proj() + weight = qkv_proj.Wqkv.weight.data + bias = ( + qkv_proj.Wqkv.bias.data.unsqueeze(0) + if getattr(qkv_proj.Wqkv, "bias", None) is not None + else None + ) + + qkv_w_scale = None + qkv_in_scale = None + quantization_type = _NkilibQuantizationType.NONE + qkv_scale = getattr(qkv_proj.Wqkv, "scale", None) + if qkv_scale is not None: + qkv_w_scale = qkv_scale.data + qkv_input_scale = getattr(qkv_proj.Wqkv, "input_scale", None) + qkv_in_scale = qkv_input_scale.data if qkv_input_scale is not None else None + quantization_type = _NkilibQuantizationType.ROW + elif getattr(self.config.neuron_config, "quantized", False): + raise RuntimeError( + "Qwen packed QKV+gate path requires Wqkv.scale when running " + "a quantized artifact." + ) + + packed = _qwen_gate_projection_kernel[self.logical_nc_config]( + input=hidden_states, + fused_qkv_weights=weight, + output_layout=_NkilibQKVOutputLayout.BSD, + bias=bias, + fused_residual_add=False, + mlp_prev=None, + attention_prev=None, + fused_norm_type=_NkilibNormType.NO_NORM, + gamma_norm_weights=None, + norm_eps=self.rms_norm_eps, + fused_rope=False, + cos_cache=None, + sin_cache=None, + quantization_type=quantization_type, + qkv_w_scale=qkv_w_scale, + qkv_in_scale=qkv_in_scale, + d_head=self.head_dim, + num_q_heads=self.num_heads * 2, + num_kv_heads=self.num_key_value_heads, + ) + + q_width = self.num_heads * self.head_dim + gate_end = q_width * 2 + k_end = gate_end + self.num_key_value_heads * self.head_dim + Q, gate, K, V = torch.tensor_split( + packed, + (q_width, gate_end, k_end), + dim=2, + ) + return Q, gate, K, V + + def _prep_qkv_gate_packed_tensors( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + use_polar_compatible_rope=False, + ): + Q, gate, K, V = self._qkv_gate_packed_projection_nki(hidden_states) + + bsz, q_len, _ = hidden_states.size() + V = move_heads_front( + V, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=None, + ) + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + if ( + self._should_use_qwen_qk_norm_rope_nki(q_len) + and cos_cache is not None + and sin_cache is not None + ): + Q, K = _qwen_qk_norm_partial_rope_kernel[self.logical_nc_config]( + Q, + K, + self.q_layernorm.weight.data, + self.k_layernorm.weight.data, + cos_cache, + sin_cache, + self.rms_norm_eps, + ) + else: + Q = move_heads_front( + Q, + bsz, + q_len, + self.num_heads, + self.head_dim, + layernorm=self.q_layernorm, + post_transpose_layernorm=self.post_transpose_layernorm, + ) + K = move_heads_front( + K, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=self.k_layernorm, + post_transpose_layernorm=self.post_transpose_layernorm, + ) + Q, K, cos_cache, sin_cache = self.apply_rotary_embedding( + Q, + K, + V, + position_ids, + cos_cache, + sin_cache, + use_polar_compatible_rope, + ) + return Q, K, V, gate, cos_cache, sin_cache, None + + def _should_use_qwen_qk_norm_rope_nki(self, q_len): + return ( + self.qwen_qk_norm_rope_nki_kernel_enabled + and q_len > 1 + and self.q_layernorm is not None + and self.k_layernorm is not None + and not self.qkv_proj_sp_enabled + ) + + def _prep_qkv_tensors_qwen_qk_norm_rope_nki( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + ): + Q, K, V, residual = self.get_qkv_proj()( + hidden_states=hidden_states, + rmsnorm=rmsnorm, + adapter_ids=adapter_ids, + residual=None, + ) + + bsz, q_len, _ = hidden_states.size() + V = move_heads_front( + V, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=None, + ) + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + Q, K = _qwen_qk_norm_partial_rope_kernel[self.logical_nc_config]( + Q, + K, + self.q_layernorm.weight.data, + self.k_layernorm.weight.data, + cos_cache, + sin_cache, + self.rms_norm_eps, + ) + return Q, K, V, cos_cache, sin_cache, residual + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Partial RoPE: only apply rotary embedding to first rope_dim dimensions. + + Q shape: (B, H, S, head_dim) where head_dim=256 + cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64)) + + Split Q/K along last dim into: + q_rope (first 64 dims) -- apply RoPE + q_pass (remaining 192 dims) -- pass through unchanged + """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + # Split into rope and pass-through portions + Q_orig_dtype = Q.dtype + q_rope = Q[..., : self.rope_dim] # (B, H, S, 64) + q_pass = Q[..., self.rope_dim :] # (B, H, S, 192) + k_rope = K[..., : self.rope_dim] + k_pass = K[..., self.rope_dim :] + + # Apply RoPE only to the rope portion + q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache) + + # Concatenate back (ensure bf16 is maintained) + Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype) + K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype) + + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None): + """Prefill path with NKI flash attention for head_dim=256.""" + head_dim = Q.shape[-1] + + # Option B: nkilib flash attention for head_dim > 128 + if _nkilib_flash_attn is not None: + q_contig = Q.contiguous() + k_contig = K.contiguous() + v_contig = V.contiguous() + scale = 1.0 / math.sqrt(head_dim) + result = _nkilib_flash_attn( + q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True + ) + return result, None + + # Option A: kernel patched globally + if NKILIB_PATCH_ACTIVE: + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + # Fallback: softmax path (use 3D tensors to avoid compiler ICE with 4D patterns) + if head_dim > 128: + # GQA: expand K/V heads to match Q heads + num_q_heads = Q.shape[1] + num_kv_heads = K.shape[1] + if num_q_heads != num_kv_heads: + kv_rep = num_q_heads // num_kv_heads + K = ( + K.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + V = ( + V.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + # Reshape to 3D (B*H, S, d) to avoid neuronx-cc codegen ICE with 4D + # attention weight tensors (NCC_INLA001: Expected 2D tensor but got 4D AP) + Q_3d = Q.reshape(bsz * num_q_heads, q_len, head_dim) + K_3d = K.reshape(bsz * num_q_heads, q_len, head_dim) + V_3d = V.reshape(bsz * num_q_heads, q_len, head_dim) + attn_weights = torch.bmm(Q_3d, K_3d.transpose(-1, -2)) / math.sqrt(head_dim) + # Build causal mask for 3D: (1, S, S) broadcast over B*H + causal_mask = torch.triu( + torch.full( + (q_len, q_len), + -65504.0, + dtype=attn_weights.dtype, + device=attn_weights.device, + ), + diagonal=1, + ).unsqueeze(0) + attn_weights = attn_weights + causal_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + attn_output = torch.bmm(attn_weights, V_3d) + # Reshape back to 4D (B, H, S, d) + return attn_output.reshape(bsz, num_q_heads, q_len, head_dim), None + + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + def perform_qwen_chunked_prefill( + self, + Q, + K, + V, + past_key_value, + position_ids, + attention_mask=None, + kv_mgr=None, + idx=None, + active_block_table=None, + computed_context_lens=None, + scatter_index=None, + kvcache_buffer=None, + ): + """Exact chunked CTE over full-cache or selected-prefix KV. + + For model-local chunked prefill, the current chunk K/V tensors are + scattered into the full cache at absolute position_ids. For vLLM prefix + reuse, BlockKVCacheManager returns selected prefix blocks already + arranged as logical positions, so concatenate the current suffix K/V + after that logical prefix. + """ + k_cache, v_cache = past_key_value + B, q_heads, q_len, head_dim = Q.shape + kv_heads = K.shape[1] + use_segmented_prefix_cte = ( + getattr( + self.config.neuron_config, + "prefix_cte_attention_backend", + "attention_cte", + ) + == "segmented_cte" + and active_block_table is not None + and getattr(active_block_table, "ndim", 0) > 1 + ) + if use_segmented_prefix_cte: + if kv_mgr is None or idx is None or scatter_index is None: + raise ValueError( + "segmented_cte Qwen prefix prefill requires kv_mgr, idx, " + "and scatter_index so active KV can be written to block KV." + ) + updated_kv = kv_mgr.update_kv_by_layer_id( + idx=idx, + kv_per_layer=(K.to(self.torch_dtype), V.to(self.torch_dtype)), + scatter_index=scatter_index, + kvcache_buffer=kvcache_buffer, + ) + attn_output, _flash_strategy = self.perform_prefix_prefill_segmented_cte( + Q, + q_len, + B, + updated_kv, + active_block_table, + computed_context_lens, + ) + return attn_output.permute(0, 1, 3, 2).contiguous(), updated_kv + + if k_cache.shape[0] != B: + # The cache is allocated at kv_cache_batch_size, while CTE can trace a + # smaller active batch. Keep attention reshapes on the active batch. + k_cache = k_cache[:B] + v_cache = v_cache[:B] + cache_len = k_cache.shape[2] + + pos = position_ids.long() + selected_prefix_cache = cache_len < int( + getattr(self.config.neuron_config, "seq_len", cache_len) + ) + if selected_prefix_cache: + k_cache = torch.cat([k_cache, K.to(k_cache.dtype)], dim=2) + v_cache = torch.cat([v_cache, V.to(v_cache.dtype)], dim=2) + prefix_positions = torch.arange( + cache_len, + device=position_ids.device, + dtype=pos.dtype, + ).view(1, -1).expand(B, -1) + cache_positions = torch.cat([prefix_positions, pos], dim=1).view( + B, + 1, + 1, + -1, + ) + prefix_valid = torch.ones( + (B, cache_len), + device=position_ids.device, + dtype=torch.bool, + ) + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and attention_mask.shape[1] == q_len + ): + active_valid = attention_mask.to(torch.bool) + else: + active_valid = torch.ones( + (B, q_len), + device=position_ids.device, + dtype=torch.bool, + ) + key_valid_mask = torch.cat([prefix_valid, active_valid], dim=1).view( + B, + 1, + 1, + -1, + ) + cache_len = k_cache.shape[2] + else: + k_index = pos[:, None, :, None].expand(B, kv_heads, q_len, head_dim) + k_cache = torch.scatter( + k_cache, + dim=2, + index=k_index, + src=K.to(k_cache.dtype), + ) + v_cache = torch.scatter( + v_cache, + dim=2, + index=k_index, + src=V.to(v_cache.dtype), + ) + cache_positions = torch.arange( + cache_len, + device=position_ids.device, + dtype=pos.dtype, + ).view(1, 1, 1, -1) + key_valid_mask = None + + prefix_attention_impl = _qwen36_prefix_attention_impl() + if prefix_attention_impl == "grouped": + attn_output = _qwen35_grouped_prefix_attention( + Q, + k_cache, + v_cache, + pos, + cache_positions, + key_valid_mask, + ) + else: + attn_output = _qwen35_expanded_prefix_attention( + Q, + k_cache, + v_cache, + pos, + cache_positions, + key_valid_mask, + ) + return attn_output, None + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + adapter_ids=None, + active_mask=None, + **kwargs, + ): + """Forward with output gate applied BEFORE o_proj. + + Override NeuronAttentionBase.forward() to insert the sigmoid gate + between the attention output and o_proj, matching the HF reference: + gate = sigmoid(gate_proj(pre_attn_hidden)) + attn_output = attn_output * gate + attn_output = o_proj(attn_output) + """ + bsz, q_len, _ = hidden_states.shape + + # Use standard 2D position_ids for prep_qkv_tensors. + rope_pos_ids = position_ids + + use_split_qkv_tkg = ( + self.qkv_tkg_nki_kernel_enabled + and past_key_value is not None + and q_len == 1 + ) + if self._should_use_qwen_qkv_gate_packed(q_len): + Q, K, V, gate, cos_cache, sin_cache, _residual = ( + self._prep_qkv_gate_packed_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + ) + ) + elif use_split_qkv_tkg: + gate = ( + self._output_gate_proj_nki(hidden_states) + if self._should_use_qwen_output_gate_nki(q_len) + else self.output_gate_proj(hidden_states) + ) + Q, K, V, cos_cache, sin_cache, _residual = ( + self._prep_split_qkv_tkg_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + ) + ) + elif self.qkv_tkg_nki_kernel_enabled: + raise ValueError( + "qkv_tkg_nki_kernel_enabled is only valid for single-token " + f"decode, got past_key_value={past_key_value is not None}, " + f"q_len={q_len}" + ) + else: + # Compute gate from input hidden states (before QKV projection). + if self._should_use_qwen_output_gate_nki(q_len): + gate = self._output_gate_proj_nki(hidden_states) + else: + gate = self.output_gate_proj(hidden_states) + + # Standard QKV prep (projections, QK norm, RoPE) + if self._should_use_qwen_qk_norm_rope_nki(q_len): + Q, K, V, cos_cache, sin_cache, _residual = ( + self._prep_qkv_tensors_qwen_qk_norm_rope_nki( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + ) + else: + Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + + qwen_chunked_prefill_active = ( + past_key_value is not None + and q_len > 1 + and getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + ) + + if past_key_value is None: + # Context encoding (prefill) + attn_output, _flash_strategy = self.perform_prefill( + Q, K, V, q_len, bsz, attention_mask + ) + elif qwen_chunked_prefill_active: + attn_output, present_key_value = self.perform_qwen_chunked_prefill( + Q, + K, + V, + past_key_value, + position_ids, + attention_mask, + kv_mgr=kwargs.get("kv_mgr"), + idx=kwargs.get("idx"), + active_block_table=kwargs.get("active_block_table"), + computed_context_lens=kwargs.get("computed_context_lens"), + scatter_index=kwargs.get("scatter_index"), + kvcache_buffer=kwargs.get("kvcache_buffer"), + ) + else: + # Token generation (decode) + tkg_mask = attention_mask + if tkg_mask is not None and tkg_mask.ndim == 2: + tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S) + attn_output = self.compute_for_token_gen( + Q, K, V, position_ids, past_key_value, tkg_mask, active_mask + ) + + # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + o_proj = self.get_o_proj() + if self._should_use_qwen_gated_o_proj_nki(q_len): + attn_output = o_proj.forward_gated(attn_output, gate, adapter_ids=adapter_ids) + else: + # Apply sigmoid output gate BEFORE o_proj (matching HF reference) + attn_output = attn_output * torch.sigmoid(gate) + attn_output = o_proj(attn_output, adapter_ids=adapter_ids) + + # Ensure K, V are in model dtype (bf16) for KV cache update + # (prevents mixed-precision dynamic-update-slice in neuronx-cc) + K = K.to(self.torch_dtype) + V = V.to(self.torch_dtype) + if "present_key_value" not in locals() or present_key_value is None: + present_key_value = (K, V) + past_key_value = present_key_value + return attn_output, past_key_value, cos_cache, sin_cache + + +# ============================================================ +# Dense MLP (replaces MoE) +# ============================================================ + + +class Qwen35MLP(nn.Module): + """Dense SwiGLU MLP for Qwen3.5/3.6-27B. + + gate_proj: hidden_size -> intermediate_size (5120 -> 17408) + up_proj: hidden_size -> intermediate_size (5120 -> 17408) + down_proj: intermediate_size -> hidden_size (17408 -> 5120) + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + """ + + def __init__(self, config): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + hidden_states = F.silu(gate) * up + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +# ============================================================ +# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP) +# ============================================================ + + +class NeuronQwen35DecoderLayer(nn.Module): + """Hybrid decoder layer: dispatches to DeltaNet or standard attention. + Uses dense MLP for all layers (no MoE). + """ + + def __init__(self, config: Qwen35InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # Dense MLP (all layers). The reusable NxDI Llama MLP kernel supports + # both CTE and TKG; keep RMSNorm separate for CTE so normalization stays + # on the conservative high-precision path before FP8 GEMM quantization. + self.mlp_kernel_enabled = bool(config.neuron_config.mlp_kernel_enabled) + self.mlp_kernel_fused_rmsnorm = ( + self.mlp_kernel_enabled + and not config.neuron_config.sequence_parallel_enabled + ) + if self.mlp_kernel_enabled: + tensor_model_parallel_group = ( + parallel_state.get_tensor_model_parallel_group() + if parallel_state.model_parallel_is_initialized() + else None + ) + self.mlp = NeuronLlamaMLP(config, tensor_model_parallel_group) + else: + self.mlp = Qwen35MLP(config) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + padding_mask=None, + cos_cache=None, + sin_cache=None, + **kwargs, + ): + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + # DeltaNet path + attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = residual + attn_out + present_key_value = dummy_kv + deltanet_states = ( + None + if getattr(self.config, "use_hybrid_cache_manager", False) + else (new_rec_state, new_conv_state) + ) + else: + deltanet_states = None + # Standard attention path + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Dense MLP FFN + residual = hidden_states + if self.mlp_kernel_enabled: + use_fused_mlp_rmsnorm = ( + self.mlp_kernel_fused_rmsnorm + and not bool(kwargs.get("is_for_context_encoding", False)) + and hidden_states.shape[1] == 1 + ) + if use_fused_mlp_rmsnorm: + mlp_fused_rmsnorm = self.post_attention_layernorm + else: + hidden_states = self.post_attention_layernorm(hidden_states) + mlp_fused_rmsnorm = None + hidden_states, _ = self.mlp(hidden_states, rmsnorm=mlp_fused_rmsnorm) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = ( + hidden_states, + present_key_value, + cos_cache, + sin_cache, + None, + deltanet_states, + ) + return outputs + + +# ============================================================ +# Hybrid Cache Manager (opt-in) +# ============================================================ + + +class HybridDeltaNetCacheManager(KVCacheManager): + """Opt-in local/static cache manager for Qwen hybrid dense models. + + This manager stores DeltaNet recurrent/conv state by batch row and delegates + full-attention layers to the legacy KV manager. It is intentionally not a + production vLLM APC manager: block ownership, prefix hashes, refcounts, + eviction, continuous batching, and tenant isolation must remain in the + vLLM/NxDI block-cache lifecycle. + """ + + def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): + self.layer_types = list(config.layer_types) + self._validate_hybrid_config(config) + super().__init__(config, num_kv_head=num_kv_head, **kwargs) + + dtype = ( + config.neuron_config.attention_dtype + if config.neuron_config.attention_dtype is not None + else config.neuron_config.torch_dtype + ) + cache_dtype = getattr(self, "cache_dtype", dtype) + recurrent_cache_dtype = _torch_dtype_from_hybrid_cache_dtype( + config.hybrid_recurrent_cache_dtype + ) + conv_cache_dtype = _torch_dtype_from_hybrid_cache_dtype( + config.hybrid_conv_cache_dtype + ) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + tp_degree = config.neuron_config.tp_degree + if config.linear_num_value_heads % tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={config.linear_num_value_heads} must be divisible " + f"by tp_degree={tp_degree}" + ) + if config.linear_num_key_heads % tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={config.linear_num_key_heads} must be divisible " + f"by tp_degree={tp_degree}" + ) + local_num_value_heads = config.linear_num_value_heads // tp_degree + local_num_key_heads = config.linear_num_key_heads // tp_degree + recurrent_shape = [ + max_batch_size, + local_num_value_heads, + config.linear_key_head_dim, + config.linear_value_head_dim, + ] + conv_dim = ( + 2 * local_num_key_heads * config.linear_key_head_dim + + local_num_value_heads * config.linear_value_head_dim + ) + conv_shape = [ + max_batch_size, + conv_dim, + config.linear_conv_kernel_dim - 1, + ] + + params = [] + for layer_idx, layer_type in enumerate(self.layer_types): + if layer_type == "linear_attention": + params.append( + nn.Parameter( + torch.zeros(recurrent_shape, dtype=recurrent_cache_dtype), + requires_grad=False, + ) + ) + params.append( + nn.Parameter( + torch.zeros(conv_shape, dtype=conv_cache_dtype), + requires_grad=False, + ) + ) + else: + k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape + v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape + params.append( + nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) + ) + + self.past_key_values = nn.ParameterList(params) + + @staticmethod + def _validate_hybrid_config(config: Qwen35InferenceConfig): + nc = config.neuron_config + unsupported = [] + if nc.is_block_kv_layout: + unsupported.append("block KV layout") + if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): + unsupported.append("KV cache quantization") + if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: + unsupported.append("speculative decoding") + if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): + unsupported.append("EAGLE speculation") + if nc.flash_decoding_enabled: + unsupported.append("flash decoding") + if nc.attention_dp_degree > 1: + unsupported.append("attention data parallelism") + if nc.kv_cache_tiling: + unsupported.append("KV cache tiling") + if nc.padding_side != "right": + unsupported.append("left padding") + if nc.is_continuous_batching: + unsupported.append("continuous batching") + if unsupported: + raise ValueError( + "HybridDeltaNetCacheManager v1 does not support: " + + ", ".join(unsupported) + ) + + def _is_deltanet_layer(self, idx: int) -> bool: + return self.layer_types[idx] == "linear_attention" + + def get_seq_length(self, past_key_values=None): + for idx, layer_type in enumerate(self.layer_types): + if layer_type != "linear_attention": + if past_key_values is None: + _, v_cache = self._fetch_cache(idx) + elif len(past_key_values) == len(self.past_key_values): + v_cache = past_key_values[2 * idx + 1] + else: + v_cache = past_key_values[idx][1] + return v_cache.shape[2] + return 0 + + def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None): + recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer) + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx) + conv_state = torch.index_select(conv_state, dim=0, index=cache_idx) + elif self.kv_cache_padding_size > 0: + recurrent_state = recurrent_state[: -self.kv_cache_padding_size] + conv_state = conv_state[: -self.kv_cache_padding_size] + return recurrent_state, conv_state + + def get_cache( + self, + seq_len: int, + skip_slice=False, + kvcache_buffer=None, + seq_ids=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + past_key_values = [] + for idx in range(len(self.past_key_values) // 2): + if self._is_deltanet_layer(idx): + past_key_values.append( + list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids)) + ) + else: + past_key_values.append( + list( + self.get_kv_by_layer_id( + idx=idx, + skip_slice=skip_slice, + seq_len=seq_len, + kvcache_buffer=kvcache_buffer, + seq_ids=seq_ids, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + ) + ) + return past_key_values + + def update_cache( + self, + is_for_context_encoding: bool, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + new_key_values: List[torch.Tensor], + seq_len: int, + scatter_index=None, + kv_active_mask=None, + kvcache_buffer=None, + windowed_context_encoding_window_idx: int = -1, + **kwargs, + ): + updated_cache = [] + for idx, kv_per_layer in enumerate(new_key_values): + if self._is_deltanet_layer(idx): + recurrent_state, conv_state = self.update_deltanet_state_by_layer_id( + idx=idx, + seq_ids=seq_ids, + state_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + ) + elif kwargs.get("qwen_chunked_prefill_update", False): + recurrent_state, conv_state = self.update_qwen_chunked_kv_by_layer_id( + idx=idx, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + valid_mask=kwargs.get("qwen_chunked_valid_mask", None), + ) + else: + recurrent_state, conv_state = self.update_kv_by_layer_id( + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + seq_len=seq_len, + scatter_index=scatter_index, + kv_active_mask=kv_active_mask, + kvcache_buffer=kvcache_buffer, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + updated_cache.append(recurrent_state) + updated_cache.append(conv_state) + return updated_cache + + def update_qwen_chunked_kv_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + valid_mask=None, + ): + latest_k, latest_v = kv_per_layer + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer) + latest_k = latest_k.to(k_cache.dtype) + latest_v = latest_v.to(v_cache.dtype) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + selected_k = torch.index_select(k_cache, dim=0, index=cache_idx) + selected_v = torch.index_select(v_cache, dim=0, index=cache_idx) + else: + cache_idx = None + selected_k = k_cache[: latest_k.shape[0]] + selected_v = v_cache[: latest_v.shape[0]] + + pos = position_ids.long() + k_index = pos[:, None, :, None].expand_as(latest_k) + v_index = pos[:, None, :, None].expand_as(latest_v) + + if valid_mask is not None: + valid = valid_mask.to(torch.bool)[:, None, :, None] + old_k = torch.gather(selected_k, dim=2, index=k_index) + old_v = torch.gather(selected_v, dim=2, index=v_index) + latest_k = torch.where(valid, latest_k, old_k) + latest_v = torch.where(valid, latest_v, old_v) + + updated_k = torch.scatter(selected_k, dim=2, index=k_index, src=latest_k) + updated_v = torch.scatter(selected_v, dim=2, index=v_index, src=latest_v) + + if cache_idx is not None: + k_row_index = cache_idx.view(-1, 1, 1, 1).expand_as(updated_k) + v_row_index = cache_idx.view(-1, 1, 1, 1).expand_as(updated_v) + k_cache = torch.scatter(k_cache, dim=0, index=k_row_index, src=updated_k) + v_cache = torch.scatter(v_cache, dim=0, index=v_row_index, src=updated_v) + return k_cache, v_cache + + if updated_k.shape[0] == k_cache.shape[0]: + return updated_k + k_cache * 0, updated_v + v_cache * 0 + + pad_rows = k_cache.shape[0] - updated_k.shape[0] + if pad_rows > 0: + updated_k = torch.cat([updated_k, k_cache[updated_k.shape[0] :] * 0], dim=0) + updated_v = torch.cat([updated_v, v_cache[updated_v.shape[0] :] * 0], dim=0) + return updated_k + k_cache * 0, updated_v + v_cache * 0 + + def update_deltanet_state_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + state_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + ): + latest_recurrent, latest_conv = state_per_layer + recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer) + latest_recurrent = latest_recurrent.to(recurrent_cache.dtype) + latest_conv = latest_conv.to(conv_cache.dtype) + + if latest_recurrent.shape[0] == recurrent_cache.shape[0] and seq_ids is None: + return ( + latest_recurrent + recurrent_cache * 0, + latest_conv + conv_cache * 0, + ) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent) + conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv) + recurrent_cache = torch.scatter( + input=recurrent_cache, + dim=0, + index=recurrent_index, + src=latest_recurrent, + ) + conv_cache = torch.scatter( + input=conv_cache, + dim=0, + index=conv_index, + src=latest_conv, + ) + return recurrent_cache, conv_cache + + pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] + if pad_size > 0: + latest_recurrent = torch.cat( + [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], + dim=0, + ) + latest_conv = torch.cat( + [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], + dim=0, + ) + return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 + + +class QwenHybridBlockKVCacheManager(BlockKVCacheManager): + """Block KV manager that allocates real KV only for full-attention layers.""" + + _LINEAR_PLACEHOLDER_SHAPE = (1, 1, 1, 1) + + def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): + self.layer_types = list(config.layer_types) + super().__init__(config, num_kv_head=num_kv_head, **kwargs) + + params = [] + for layer_type in self.layer_types: + if layer_type == "full_attention": + params.append( + nn.Parameter( + torch.zeros(self.k_shape, dtype=self.cache_dtype), + requires_grad=False, + ) + ) + params.append( + nn.Parameter( + torch.zeros(self.v_shape, dtype=self.cache_dtype), + requires_grad=False, + ) + ) + else: + params.append( + nn.Parameter( + torch.zeros( + self._LINEAR_PLACEHOLDER_SHAPE, + dtype=self.cache_dtype, + ), + requires_grad=False, + ) + ) + params.append( + nn.Parameter( + torch.zeros( + self._LINEAR_PLACEHOLDER_SHAPE, + dtype=self.cache_dtype, + ), + requires_grad=False, + ) + ) + self.past_key_values = nn.ParameterList(params) + + def _is_attention_layer(self, idx: int) -> bool: + return self.layer_types[idx] == "full_attention" + + def get_seq_length(self, past_key_values=None): + for idx, layer_type in enumerate(self.layer_types): + if layer_type == "full_attention": + if past_key_values is None: + _, v_cache = self._fetch_cache(idx) + elif len(past_key_values) == len(self.past_key_values): + v_cache = past_key_values[2 * idx + 1] + else: + v_cache = past_key_values[idx][1] + if v_cache.ndim >= 4 and v_cache.shape[1] == self.pa_block_size: + return self.pa_num_blocks * self.pa_block_size + return v_cache.shape[2] + return 0 + + def get_cache(self, active_block_table=None, kvcache_buffer=None, **kwargs): + past_key_values = [] + use_segmented_prefix_cte = ( + kwargs.get("is_for_context_encoding", False) + and getattr( + self.neuron_config, + "prefix_cte_attention_backend", + "attention_cte", + ) + == "segmented_cte" + and active_block_table is not None + and getattr(active_block_table, "ndim", 0) > 1 + ) + for idx in range(len(self.past_key_values) // 2): + if self._is_attention_layer(idx): + if use_segmented_prefix_cte: + k_cache, v_cache = self.get_raw_kv_by_layer_id( + idx, + kvcache_buffer=kvcache_buffer, + ) + else: + k_cache, v_cache = self.get_kv_by_layer_id( + idx, + active_block_table, + kvcache_buffer=kvcache_buffer, + **kwargs, + ) + else: + k_cache, v_cache = self._fetch_cache( + idx, + kvcache_buffer=kvcache_buffer, + ) + past_key_values.append([k_cache, v_cache]) + return past_key_values + + def _is_raw_block_kv_pair(self, kv_per_layer: List[torch.Tensor]) -> bool: + if len(kv_per_layer) != 2: + return False + k_cache, v_cache = kv_per_layer + return ( + k_cache.ndim == 4 + and v_cache.ndim == 4 + and k_cache.shape[0] == self.pa_num_blocks + self._NUM_EXTRA_RESERVED_BLOCK + and v_cache.shape[0] == self.pa_num_blocks + self._NUM_EXTRA_RESERVED_BLOCK + and k_cache.shape[1] == self.pa_block_size + and v_cache.shape[1] == self.pa_block_size + ) + + def update_cache( + self, + new_key_values: List[torch.Tensor], + scatter_index=None, + kvcache_buffer=None, + **kwargs, + ): + updated_kv_cache = [] + for idx, kv_per_layer in enumerate(new_key_values): + if self._is_attention_layer(idx) and self._is_raw_block_kv_pair( + kv_per_layer + ): + k_cache, v_cache = kv_per_layer + elif self._is_attention_layer(idx): + k_cache, v_cache = self.update_kv_by_layer_id( + idx=idx, + kv_per_layer=kv_per_layer, + scatter_index=scatter_index, + kvcache_buffer=kvcache_buffer, + ) + else: + k_cache, v_cache = self._fetch_cache( + idx, + kvcache_buffer=kvcache_buffer, + ) + k_cache = k_cache * 1 + v_cache = v_cache * 1 + updated_kv_cache.append(k_cache) + updated_kv_cache.append(v_cache) + return updated_kv_cache + + +class HybridGDNCheckpointCache(nn.Module): + """Bounded device-side GDN prefix checkpoint bank. + + Metadata owns prefix hashes, refcounts, and eviction. This module only owns + recurrent/conv tensors addressed by checkpoint slot IDs supplied by the + scheduler/request-prep path. + """ + + def __init__(self, config: Qwen35InferenceConfig): + super().__init__() + self.gdn_layer_ids = tuple( + idx + for idx, layer_type in enumerate(config.layer_types) + if layer_type == "linear_attention" + ) + if not self.gdn_layer_ids: + raise ValueError("HybridGDNCheckpointCache requires GDN layers") + self.layer_to_bank_index = { + layer_id: bank_idx for bank_idx, layer_id in enumerate(self.gdn_layer_ids) + } + self.num_checkpoint_slots = int(config.max_gdn_checkpoint_slots) + if self.num_checkpoint_slots <= 0: + raise ValueError("max_gdn_checkpoint_slots must be positive") + + tp_degree = config.neuron_config.tp_degree + if config.linear_num_value_heads % tp_degree != 0: + raise ValueError("linear_num_value_heads must be divisible by tp_degree") + if config.linear_num_key_heads % tp_degree != 0: + raise ValueError("linear_num_key_heads must be divisible by tp_degree") + + self.local_num_value_heads = config.linear_num_value_heads // tp_degree + self.local_num_key_heads = config.linear_num_key_heads // tp_degree + self.key_dim = config.linear_key_head_dim + self.value_dim = config.linear_value_head_dim + self.conv_dim = ( + 2 * self.local_num_key_heads * config.linear_key_head_dim + + self.local_num_value_heads * config.linear_value_head_dim + ) + self.conv_state_len = config.linear_conv_kernel_dim - 1 + self.recurrent_dtype = _torch_dtype_from_hybrid_cache_dtype( + config.hybrid_recurrent_cache_dtype + ) + self.conv_dtype = _torch_dtype_from_hybrid_cache_dtype( + config.hybrid_conv_cache_dtype + ) + + recurrent_shape = ( + self.num_checkpoint_slots, + self.local_num_value_heads, + self.key_dim, + self.value_dim, + ) + conv_shape = ( + self.num_checkpoint_slots, + self.conv_dim, + self.conv_state_len, + ) + self.recurrent_slots = nn.ParameterList( + [ + nn.Parameter( + torch.zeros(recurrent_shape, dtype=self.recurrent_dtype), + requires_grad=False, + ) + for _ in self.gdn_layer_ids + ] + ) + self.conv_slots = nn.ParameterList( + [ + nn.Parameter( + torch.zeros(conv_shape, dtype=self.conv_dtype), + requires_grad=False, + ) + for _ in self.gdn_layer_ids + ] + ) + + @property + def checkpoint_params(self): + params = [] + for recurrent_slot, conv_slot in zip(self.recurrent_slots, self.conv_slots): + params.append(recurrent_slot) + params.append(conv_slot) + return params + + def bytes_per_checkpoint_per_rank(self) -> int: + recurrent_numel = ( + len(self.gdn_layer_ids) + * self.local_num_value_heads + * self.key_dim + * self.value_dim + ) + conv_numel = len(self.gdn_layer_ids) * self.conv_dim * self.conv_state_len + recurrent_bytes = 4 if self.recurrent_dtype == torch.float32 else 2 + conv_bytes = 4 if self.conv_dtype == torch.float32 else 2 + return recurrent_numel * recurrent_bytes + conv_numel * conv_bytes + + def _safe_slot_ids( + self, + slot_ids: torch.Tensor, + batch_size: int | None = None, + ) -> torch.Tensor: + slot_ids = slot_ids.reshape(-1).long().clamp( + min=0, + max=self.num_checkpoint_slots - 1, + ) + if batch_size is None: + return slot_ids + if slot_ids.shape[0] >= batch_size: + return slot_ids[:batch_size] + pad = torch.zeros( + (batch_size - slot_ids.shape[0],), + dtype=slot_ids.dtype, + device=slot_ids.device, + ) + return torch.cat([slot_ids, pad], dim=0) + + @staticmethod + def _safe_bool_vector( + mask: torch.Tensor, + batch_size: int, + device: torch.device, + ) -> torch.Tensor: + mask = mask.reshape(-1).to(device=device, dtype=torch.bool) + if mask.shape[0] >= batch_size: + return mask[:batch_size] + pad = torch.zeros( + (batch_size - mask.shape[0],), + dtype=torch.bool, + device=device, + ) + return torch.cat([mask, pad], dim=0) + + @staticmethod + def _active_rows( + state: torch.Tensor, + seq_ids: torch.Tensor | None, + batch_size: int, + ) -> torch.Tensor: + if seq_ids is not None and hasattr(seq_ids, "numel") and seq_ids.numel() > 0: + safe_seq_ids = seq_ids.reshape(-1)[:batch_size].to( + device=state.device, + dtype=torch.long, + ) + safe_seq_ids = safe_seq_ids.clamp(min=0, max=int(state.shape[0]) - 1) + return torch.index_select(state, 0, safe_seq_ids) + return state[:batch_size] + + def restore_to_active_rows( + self, + *, + layers: nn.ModuleList, + seq_ids: torch.Tensor | None, + checkpoint_slot_ids: torch.Tensor | None, + restore_mask: torch.Tensor | None, + zero_inactive: bool = False, + ) -> dict[int, tuple[torch.Tensor, torch.Tensor]] | None: + if checkpoint_slot_ids is None or restore_mask is None: + return None + batch_size = max( + int(checkpoint_slot_ids.reshape(-1).shape[0]), + int(restore_mask.reshape(-1).shape[0]), + ) + if batch_size <= 0: + return None + slot_ids = self._safe_slot_ids(checkpoint_slot_ids, batch_size) + restore_mask = self._safe_bool_vector( + restore_mask, + batch_size, + slot_ids.device, + ) + slot_ids = torch.where(restore_mask, slot_ids, torch.zeros_like(slot_ids)) + rec_mask = restore_mask.view(batch_size, 1, 1, 1) + conv_mask = restore_mask.view(batch_size, 1, 1) + + restored = {} + for bank_idx, layer_id in enumerate(self.gdn_layer_ids): + linear_attn = layers[layer_id].linear_attn + active_recurrent = self._active_rows( + linear_attn.recurrent_state_buffer, seq_ids, batch_size + ) + active_conv = self._active_rows( + linear_attn.conv_state_buffer, seq_ids, batch_size + ) + if zero_inactive: + inactive_recurrent = torch.zeros_like(active_recurrent) + inactive_conv = torch.zeros_like(active_conv) + else: + inactive_recurrent = active_recurrent + inactive_conv = active_conv + slot_recurrent = torch.index_select( + self.recurrent_slots[bank_idx], 0, slot_ids + ).to(active_recurrent.dtype) + slot_conv = torch.index_select(self.conv_slots[bank_idx], 0, slot_ids).to( + active_conv.dtype + ) + _debug_qwen36_hybrid_gdn_state( + "restore_slot_recurrent", + slot_recurrent, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=restore_mask, + seq_ids=seq_ids, + ) + _debug_qwen36_hybrid_gdn_state( + "restore_slot_conv", + slot_conv, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=restore_mask, + seq_ids=seq_ids, + ) + restored_recurrent = torch.where( + rec_mask, slot_recurrent, inactive_recurrent + ) + restored_conv = torch.where(conv_mask, slot_conv, inactive_conv) + _debug_qwen36_hybrid_gdn_state( + "restore_active_recurrent", + restored_recurrent, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=restore_mask, + seq_ids=seq_ids, + ) + _debug_qwen36_hybrid_gdn_state( + "restore_active_conv", + restored_conv, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=restore_mask, + seq_ids=seq_ids, + ) + restored[layer_id] = (restored_recurrent, restored_conv) + return restored + + def commit_from_active_rows( + self, + *, + layer_state_pairs: list[tuple[int, torch.Tensor, torch.Tensor]], + seq_ids: torch.Tensor | None, + checkpoint_slot_ids: torch.Tensor | None, + commit_mask: torch.Tensor | None, + ) -> list[torch.Tensor]: + if checkpoint_slot_ids is None or commit_mask is None: + return self.identity_outputs() + batch_size = max( + int(checkpoint_slot_ids.reshape(-1).shape[0]), + int(commit_mask.reshape(-1).shape[0]), + ) + if batch_size <= 0: + return self.identity_outputs() + slot_ids = self._safe_slot_ids(checkpoint_slot_ids, batch_size) + commit_mask = self._safe_bool_vector( + commit_mask, + batch_size, + slot_ids.device, + ) + slot_ids = torch.where(commit_mask, slot_ids, torch.zeros_like(slot_ids)) + rec_mask = commit_mask.view(batch_size, 1, 1, 1) + conv_mask = commit_mask.view(batch_size, 1, 1) + + state_by_layer = { + layer_id: (recurrent_state, conv_state) + for layer_id, recurrent_state, conv_state in layer_state_pairs + } + + def _commit_rows(slots, rows, row_mask): + output = slots * 1 + slot_axis = torch.arange( + slots.shape[0], dtype=slot_ids.dtype, device=slot_ids.device + ) + broadcast_shape = (slots.shape[0],) + (1,) * (slots.ndim - 1) + for row_idx in range(batch_size): + write_mask = torch.logical_and( + row_mask[row_idx], + slot_axis == slot_ids[row_idx], + ).view(broadcast_shape) + row_value = rows[row_idx : row_idx + 1].expand_as(output) + output = torch.where(write_mask, row_value, output) + return output + + outputs = [] + for bank_idx, layer_id in enumerate(self.gdn_layer_ids): + recurrent_slots = self.recurrent_slots[bank_idx] + conv_slots = self.conv_slots[bank_idx] + if layer_id not in state_by_layer: + outputs.append(recurrent_slots * 1) + outputs.append(conv_slots * 1) + continue + + recurrent_state, conv_state = state_by_layer[layer_id] + recurrent_rows = self._active_rows(recurrent_state, seq_ids, batch_size).to( + recurrent_slots.dtype + ) + conv_rows = self._active_rows(conv_state, seq_ids, batch_size).to( + conv_slots.dtype + ) + _debug_qwen36_hybrid_gdn_state( + "commit_input_recurrent", + recurrent_rows, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=commit_mask, + seq_ids=seq_ids, + ) + _debug_qwen36_hybrid_gdn_state( + "commit_input_conv", + conv_rows, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=commit_mask, + seq_ids=seq_ids, + ) + + committed_recurrent = _commit_rows( + recurrent_slots, recurrent_rows, commit_mask + ) + committed_conv = _commit_rows(conv_slots, conv_rows, commit_mask) + committed_recurrent_rows = torch.index_select( + committed_recurrent, 0, slot_ids + ) + committed_conv_rows = torch.index_select(committed_conv, 0, slot_ids) + _debug_qwen36_hybrid_gdn_state( + "commit_slot_recurrent", + committed_recurrent_rows, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=commit_mask, + seq_ids=seq_ids, + ) + _debug_qwen36_hybrid_gdn_state( + "commit_slot_conv", + committed_conv_rows, + layer_id=layer_id, + bank_idx=bank_idx, + slot_ids=slot_ids, + mask=commit_mask, + seq_ids=seq_ids, + ) + + outputs.append(committed_recurrent) + outputs.append(committed_conv) + return outputs + + def identity_outputs(self) -> list[torch.Tensor]: + return [param * 1 for param in self.checkpoint_params] + + +# ============================================================ +# Model +# ============================================================ + + +def _effective_lm_head_pad_size(lm_head, logits, config): + pad_size = getattr(lm_head, "pad_size", None) + if not pad_size: + return pad_size + + if getattr(lm_head, "gather_output", False): + vocab_size = getattr(config, "vocab_size", None) + if vocab_size is not None: + return max(int(logits.shape[-1]) - int(vocab_size), 0) + + return pad_size + + +def _debug_tensor_minmax(tensor): + if tensor is None or not hasattr(tensor, "numel") or tensor.numel() == 0: + return "empty" + flat = tensor.reshape(-1) + return f"{int(flat.min().item())}:{int(flat.max().item())}" + + +def _debug_tensor_values(tensor, limit=8): + if tensor is None or not hasattr(tensor, "numel") or tensor.numel() == 0: + return [] + return tensor.reshape(-1)[:limit].tolist() + + +def _debug_tensor_shape(tensor): + if tensor is None or not hasattr(tensor, "shape"): + return None + return tuple(tensor.shape) + + +def _normalize_qwen36_slot_mapping(slot_mapping, batch_size: int, active_tokens: int): + if ( + slot_mapping is None + or not hasattr(slot_mapping, "numel") + or slot_mapping.numel() == 0 + or not hasattr(slot_mapping, "ndim") + ): + return slot_mapping + if slot_mapping.ndim != 1: + return slot_mapping + + batch_size = int(batch_size) + active_tokens = int(active_tokens) + total_slots = int(slot_mapping.numel()) + if batch_size > 0 and active_tokens > 0 and total_slots == batch_size * active_tokens: + return slot_mapping.reshape(batch_size, active_tokens) + if batch_size == 1: + return slot_mapping.reshape(1, total_slots) + return slot_mapping + + +def _use_legacy_tkg_args() -> bool: + return os.environ.get("QWEN36_TKG_LEGACY_ARGS") == "1" + + +def _qwen36_config_flag(config, neuron_config, name: str, default: bool = False) -> bool: + for owner in (config, neuron_config, getattr(config, "neuron_config", None)): + value = getattr(owner, name, None) + if value is not None: + return bool(value) + return bool(default) + + +def _use_expanded_hybrid_args_for_tag(config, tag: str) -> bool: + if not _qwen36_config_flag(config, None, "use_hybrid_apc_manager"): + return False + # The legacy ABI experiment intentionally keeps both traced stages on the + # older prefix-cache contract. Neuron prunes the extra CTE hybrid metadata + # inputs from the serialized trace, so runtime must not send them either. + if _use_legacy_tkg_args(): + return False + if tag == CONTEXT_ENCODING_MODEL_TAG: + return True + if tag == TOKEN_GENERATION_MODEL_TAG: + return True + return False + + +def _qwen36_shape_entry_arg_count(entry) -> int | None: + if isinstance(entry, str): + try: + entry = json.loads(entry) + except Exception: + return None + if isinstance(entry, (list, tuple)): + return len(entry) + return None + + +def _qwen36_compiled_arg_count(model_wrapper) -> int | None: + counts = [] + for owner in ( + model_wrapper, + getattr(model_wrapper, "model", None), + getattr(getattr(model_wrapper, "model", None), "nxd_model", None), + ): + shape_map = getattr(owner, "input_shape_map", None) + keys = getattr(shape_map, "keys", None) + if not callable(keys): + continue + try: + iterable = keys() + except Exception: + continue + for entry in iterable: + count = _qwen36_shape_entry_arg_count(entry) + if count is not None: + counts.append(count) + return max(counts) if counts else None + + +def _use_expanded_hybrid_args_for_wrapper(model_wrapper, tag: str) -> bool: + compiled_arg_count = _qwen36_compiled_arg_count(model_wrapper) + if compiled_arg_count is not None: + return compiled_arg_count >= 29 + return _use_expanded_hybrid_args_for_tag(model_wrapper.config, tag) + + +def _qwen36_expected_arg_count(config, tag: str) -> int: + return 29 if _use_expanded_hybrid_args_for_tag(config, tag) else 24 + + +def _assert_qwen36_arg_count(stage: str, args, expected: int) -> None: + actual = len(args) + if actual != expected: + raise RuntimeError( + f"Qwen3.6 {stage} argument contract mismatch: " + f"expected {expected} tensors, got {actual}" + ) + + +_QWEN36_PREFIX_ARG_NAMES = ( + "input_ids", + "attention_mask", + "position_ids", + "seq_ids", + "sampling_params", + "prev_hidden", + "adapter_ids", + "accepted_indices", + "current_length", + "medusa_mask", + "scatter_index", + "slot_mapping", + "block_table", + "num_queries", + "computed_context_lens", + "tile_q_indices", + "tile_block_tables", + "tile_masks", + "inputs_embeds", + "kv_cache", + "active_mask", +) +_QWEN36_MROPE_VISION_ARG_NAMES = ( + "rotary_position_ids", + "vision_embeddings", + "vision_mask", +) +_QWEN36_HYBRID_APC_ARG_NAMES = ( + "hybrid_restore_slot_ids", + "hybrid_restore_mask", + "hybrid_restore_prefix_lens", + "hybrid_commit_slot_ids", + "hybrid_commit_mask", +) + + +def _empty_qwen36_arg(): + return torch.empty(0) + + +def _qwen36_arg_names(config, tag: str): + names = list(_QWEN36_PREFIX_ARG_NAMES + _QWEN36_MROPE_VISION_ARG_NAMES) + if _use_expanded_hybrid_args_for_tag(config, tag): + names.extend(_QWEN36_HYBRID_APC_ARG_NAMES) + return names + + +def _normalize_qwen36_prefix_args(prefix_args): + args = list(prefix_args) + if len(args) > len(_QWEN36_PREFIX_ARG_NAMES): + raise RuntimeError( + "Qwen3.6 prefix argument contract mismatch: " + f"expected at most {len(_QWEN36_PREFIX_ARG_NAMES)} base tensors, " + f"got {len(args)}" + ) + while len(args) < len(_QWEN36_PREFIX_ARG_NAMES): + args.append(_empty_qwen36_arg()) + return args + + +def _normalize_qwen36_hybrid_args(hybrid_args, batch_size): + args = list(hybrid_args or ()) + while len(args) < len(_QWEN36_HYBRID_APC_ARG_NAMES): + args.append(torch.zeros((batch_size,), dtype=torch.int32)) + if len(args) > len(_QWEN36_HYBRID_APC_ARG_NAMES): + raise RuntimeError( + "Qwen3.6 Hybrid APC argument contract mismatch: " + f"expected {len(_QWEN36_HYBRID_APC_ARG_NAMES)} tensors, got {len(args)}" + ) + return args + + +def _build_qwen36_stage_args( + config, + tag: str, + prefix_args, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=None, +): + args = _normalize_qwen36_prefix_args(prefix_args) + args.extend([mrope_position_ids, vision_embeddings, vision_mask]) + if _use_expanded_hybrid_args_for_tag(config, tag): + batch_size = args[0].shape[0] + args.extend(_normalize_qwen36_hybrid_args(hybrid_args, batch_size)) + _assert_qwen36_arg_count(tag, args, _qwen36_expected_arg_count(config, tag)) + return args + + +def build_cte_args( + config, + prefix_args, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=None, +): + return _build_qwen36_stage_args( + config, + CONTEXT_ENCODING_MODEL_TAG, + prefix_args, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=hybrid_args, + ) + + +def build_tkg_args( + config, + prefix_args, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=None, +): + return _build_qwen36_stage_args( + config, + TOKEN_GENERATION_MODEL_TAG, + prefix_args, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=hybrid_args, + ) + + +def _debug_qwen36_arg_contract(stage: str, tag: str, config, args) -> None: + if ( + os.environ.get("QWEN36_ARG_CONTRACT_DEBUG") != "1" + and os.environ.get("QWEN36_HYBRID_APC_DEBUG") != "1" + ): + return + + names = _qwen36_arg_names(config, tag) + print( + f"[qwen36_arg_contract] stage={stage} tag={tag} argc={len(args)}", + flush=True, + ) + for idx, (name, value) in enumerate(zip(names, args)): + shape = _debug_tensor_shape(value) + dtype = getattr(value, "dtype", None) + min_value = "empty" + max_value = "empty" + if value is not None and hasattr(value, "numel") and value.numel() > 0: + try: + flat = value.detach().reshape(-1) if hasattr(value, "detach") else value.reshape(-1) + min_value = flat.min().item() + max_value = flat.max().item() + except Exception as exc: + min_value = f"error:{type(exc).__name__}" + max_value = f"error:{type(exc).__name__}" + print( + "[qwen36_arg_contract] " + f"stage={stage} tag={tag} index={idx} name={name} " + f"shape={shape} dtype={dtype} min={min_value} max={max_value}", + flush=True, + ) + + +def _debug_qwen36_flat_values(value) -> str: + if value is None: + return "None" + if not hasattr(value, "reshape"): + return repr(value) + try: + flat = value.detach().reshape(-1) if hasattr(value, "detach") else value.reshape(-1) + return repr(flat.tolist()) + except Exception as exc: + return f"error:{type(exc).__name__}" + + +def _debug_qwen36_hybrid_gdn_state( + tag: str, + tensor: torch.Tensor, + *, + layer_id: int, + bank_idx: int, + slot_ids: torch.Tensor, + mask: torch.Tensor, + seq_ids: torch.Tensor | None, +) -> None: + if os.environ.get("QWEN36_HYBRID_GDN_STATE_DEBUG") != "1": + return + shape = _debug_tensor_shape(tensor) + dtype = getattr(tensor, "dtype", None) + total = 0 + finite_count = "error" + nan_count = "error" + posinf_count = "error" + neginf_count = "error" + max_abs = "error" + mean_abs = "error" + try: + flat = tensor.detach().float().reshape(-1) + total = int(flat.numel()) + if total > 0: + finite = torch.isfinite(flat) + finite_i = finite.to(torch.int32) + finite_count = int(finite_i.sum().item()) + nan_count = int(torch.isnan(flat).to(torch.int32).sum().item()) + posinf_count = int(torch.isposinf(flat).to(torch.int32).sum().item()) + neginf_count = int(torch.isneginf(flat).to(torch.int32).sum().item()) + safe = torch.where(finite, flat, torch.zeros_like(flat)).abs() + max_abs = float(safe.max().item()) + mean_abs = float((safe.sum() / max(finite_count, 1)).item()) + else: + finite_count = 0 + nan_count = 0 + posinf_count = 0 + neginf_count = 0 + max_abs = "empty" + mean_abs = "empty" + except Exception as exc: + finite_count = f"error:{type(exc).__name__}" + nan_count = f"error:{type(exc).__name__}" + posinf_count = f"error:{type(exc).__name__}" + neginf_count = f"error:{type(exc).__name__}" + max_abs = f"error:{type(exc).__name__}" + mean_abs = f"error:{type(exc).__name__}" + + print( + "[qwen36_hybrid_gdn_state] " + f"tag={tag} layer={layer_id} bank={bank_idx} " + f"slot_ids={_debug_qwen36_flat_values(slot_ids)} " + f"mask={_debug_qwen36_flat_values(mask)} " + f"seq_ids={_debug_qwen36_flat_values(seq_ids)} " + f"shape={shape} dtype={dtype} finite={finite_count}/{total} " + f"nan={nan_count} posinf={posinf_count} neginf={neginf_count} " + f"max_abs={max_abs} mean_abs={mean_abs}", + flush=True, + ) + + +def _validate_qwen36_tkg_input_ids(input_ids, vocab_size) -> None: + if input_ids is None or not hasattr(input_ids, "numel") or input_ids.numel() == 0: + raise ValueError("Qwen3.6 TKG input_ids must be a non-empty tensor") + if input_ids.dtype not in (torch.int32, torch.int64): + raise ValueError( + "Qwen3.6 TKG input_ids must be int32 or int64, " + f"got {input_ids.dtype}" + ) + min_id = int(input_ids.min().item()) + max_id = int(input_ids.max().item()) + if min_id < 0: + raise ValueError(f"Qwen3.6 TKG input_ids contains negative token id {min_id}") + if vocab_size is not None and max_id >= int(vocab_size): + raise ValueError( + "Qwen3.6 TKG input_ids contains out-of-vocab token id " + f"{max_id}; vocab_size={int(vocab_size)}" + ) + + +def _qwen36_query_lengths(full_context_lens, computed_context_lens) -> list[int] | None: + if ( + full_context_lens is None + or computed_context_lens is None + or not hasattr(full_context_lens, "numel") + or not hasattr(computed_context_lens, "numel") + or full_context_lens.numel() == 0 + or computed_context_lens.numel() == 0 + ): + return None + full_values = full_context_lens.reshape(-1).to(torch.int64) + computed_values = computed_context_lens.reshape(-1).to(torch.int64) + count = min(int(full_values.numel()), int(computed_values.numel())) + if count <= 0: + return None + return [ + max(0, int(full_values[idx].item()) - int(computed_values[idx].item())) + for idx in range(count) + ] + + +def _qwen36_prefill_has_incomplete_row(prefill_completion_state) -> bool: + if prefill_completion_state is None: + return False + if hasattr(prefill_completion_state, "numel"): + if prefill_completion_state.numel() == 0: + return False + return not bool(prefill_completion_state.reshape(-1).to(torch.bool).all().item()) + try: + values = list(prefill_completion_state) + except TypeError: + return not bool(prefill_completion_state) + return any(not bool(value) for value in values) + + +def _qwen36_hybrid_apc_mask_has_active_row(mask) -> bool: + if mask is None: + return False + if hasattr(mask, "numel"): + if mask.numel() == 0: + return False + try: + return bool(mask.reshape(-1).to(torch.bool).any().item()) + except (RuntimeError, TypeError, ValueError): + # If a non-empty control tensor cannot be inspected on the host, keep + # the existing controls and avoid preparing the request twice. + return True + try: + values = list(mask) + except TypeError: + return bool(mask) + return any(bool(value) for value in values) + + +def _qwen36_hybrid_apc_controls_need_prepare( + hybrid_restore_mask, + hybrid_commit_mask, +) -> bool: + return not ( + _qwen36_hybrid_apc_mask_has_active_row(hybrid_restore_mask) + or _qwen36_hybrid_apc_mask_has_active_row(hybrid_commit_mask) + ) + + +def _qwen36_hybrid_apc_controls_materialized( + hybrid_restore_mask, + hybrid_restore_prefix_lens, + hybrid_commit_mask, +) -> bool: + return not _qwen36_hybrid_apc_controls_need_prepare( + hybrid_restore_mask, + hybrid_commit_mask, + ) or _qwen36_hybrid_apc_mask_has_active_row(hybrid_restore_prefix_lens) + + +def _qwen36_is_prefill_request( + input_ids, + position_ids, + *, + full_context_lens=None, + computed_context_lens=None, + prefill_completion_state=None, +) -> bool: + if _qwen36_prefill_has_incomplete_row(prefill_completion_state): + return True + + query_lengths = _qwen36_query_lengths(full_context_lens, computed_context_lens) + if ( + query_lengths is not None + and len(query_lengths) > 1 + and input_ids.ndim >= 2 + and input_ids.shape[0] == 1 + and input_ids.shape[-1] == len(query_lengths) + ): + return any(query_len > 1 for query_len in query_lengths) + + # Warm prefix-cache suffixes may start at a nonzero position, but they are + # still multi-token CTE requests. TKG must remain a one-token decode path. + if input_ids.shape[-1] > 1: + return True + return position_ids.min().item() == 0 + + +def _qwen36_deltanet_padding_mask( + *, + input_ids, + inputs_embeds, + attention_mask, + padding_idx, + is_for_context_encoding, + hybrid_restore_mask=None, + num_queries=None, +): + if padding_idx is None: + token_padding_mask = torch.ones( + (*input_ids.shape, 1), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + else: + token_padding_mask = ( + (input_ids != padding_idx).unsqueeze(-1).to(inputs_embeds.dtype) + ) + + query_padding_mask = None + if ( + is_for_context_encoding + and num_queries is not None + and hasattr(num_queries, "numel") + and num_queries.numel() >= input_ids.shape[0] + ): + query_lens = num_queries.reshape(-1)[: input_ids.shape[0]].to( + device=inputs_embeds.device, + dtype=torch.long, + ) + positions = torch.arange( + input_ids.shape[1], + device=inputs_embeds.device, + dtype=torch.long, + ) + query_padding_mask = ( + positions.unsqueeze(0) < query_lens.unsqueeze(1) + ).unsqueeze(-1).to(inputs_embeds.dtype) + + if ( + is_for_context_encoding + and query_padding_mask is not None + ): + deltanet_padding_mask = query_padding_mask + elif ( + is_for_context_encoding + and attention_mask is not None + and attention_mask.ndim == 2 + ): + attention_padding_mask = attention_mask.unsqueeze(-1).to(inputs_embeds.dtype) + if attention_padding_mask.shape[1] == inputs_embeds.shape[1]: + deltanet_padding_mask = attention_padding_mask + else: + deltanet_padding_mask = token_padding_mask + else: + deltanet_padding_mask = token_padding_mask + + if ( + is_for_context_encoding + and hybrid_restore_mask is not None + and hasattr(hybrid_restore_mask, "numel") + and hybrid_restore_mask.numel() > 0 + ): + restore_active = hybrid_restore_mask.reshape(-1).to(torch.bool) + if restore_active.numel() < input_ids.shape[0]: + restore_active = torch.cat( + [ + restore_active, + torch.zeros( + input_ids.shape[0] - restore_active.numel(), + dtype=torch.bool, + device=restore_active.device, + ), + ], + dim=0, + ) + restore_active = restore_active[: input_ids.shape[0]].to( + device=inputs_embeds.device + ).view(-1, 1, 1) + deltanet_padding_mask = torch.where( + restore_active, + token_padding_mask, + deltanet_padding_mask, + ) + return deltanet_padding_mask + + +def _qwen36_unpack_packed_decode_batch( + *, + input_ids, + attention_mask, + position_ids, + seq_ids, + adapter_ids, + slot_mapping, + full_context_lens, + computed_context_lens, +): + query_lengths = _qwen36_query_lengths(full_context_lens, computed_context_lens) + if ( + query_lengths is None + or len(query_lengths) <= 1 + or any(query_len > 1 for query_len in query_lengths) + or input_ids.ndim < 2 + or input_ids.shape[0] != 1 + or input_ids.shape[-1] != len(query_lengths) + ): + return input_ids, attention_mask, position_ids, seq_ids, adapter_ids, slot_mapping + + batch_size = len(query_lengths) + + def _unpack_token_rows(value): + if ( + value is not None + and hasattr(value, "ndim") + and value.ndim >= 2 + and value.shape[0] == 1 + and value.shape[1] == batch_size + ): + return value.reshape(batch_size, 1, *value.shape[2:]).contiguous() + return value + + def _repair_batch_vector(value, *, fill_from_index: bool = False): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return value + flattened = value.reshape(-1) + if flattened.numel() == batch_size: + return flattened + if flattened.numel() == 1 and batch_size > 1: + if fill_from_index: + return torch.arange( + batch_size, + dtype=value.dtype, + device=value.device, + ) + return flattened[:1].expand(batch_size).contiguous() + return value + + input_ids = _unpack_token_rows(input_ids) + position_ids = _unpack_token_rows(position_ids) + slot_mapping = _unpack_token_rows(slot_mapping) + if ( + attention_mask is not None + and hasattr(attention_mask, "ndim") + and attention_mask.ndim >= 2 + and attention_mask.shape[0] == 1 + and attention_mask.shape[1] == batch_size + and computed_context_lens is not None + and hasattr(computed_context_lens, "numel") + and computed_context_lens.numel() >= batch_size + ): + context_lens = computed_context_lens.reshape(-1).to(torch.int64)[:batch_size] + max_context_len = max(1, int(context_lens.max().item())) + repaired_mask = torch.zeros( + (batch_size, max_context_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + for row_idx, context_len in enumerate(context_lens): + active_len = max(0, min(int(context_len.item()), max_context_len)) + if active_len: + repaired_mask[row_idx, :active_len] = 1 + attention_mask = repaired_mask + if ( + slot_mapping is not None + and hasattr(slot_mapping, "ndim") + and slot_mapping.ndim == 1 + and int(slot_mapping.numel()) == batch_size + ): + slot_mapping = slot_mapping.reshape(batch_size, 1).contiguous() + seq_ids = _repair_batch_vector(seq_ids, fill_from_index=True) + adapter_ids = _repair_batch_vector(adapter_ids) + return input_ids, attention_mask, position_ids, seq_ids, adapter_ids, slot_mapping + + +def _qwen36_hashable_request_id(request_id: Any) -> Hashable: + if isinstance(request_id, list): + return tuple(request_id) + try: + hash(request_id) + except TypeError: + return repr(request_id) + return request_id + + +def _qwen36_metadata_for_request( + metadata_by_request_id, + request_id, +) -> dict[str, Any] | None: + if not isinstance(metadata_by_request_id, dict): + return None + normalized = _qwen36_hashable_request_id(request_id) + metadata = metadata_by_request_id.get(normalized) + if metadata is None and request_id is not None: + metadata = metadata_by_request_id.get(str(request_id)) + return metadata if isinstance(metadata, dict) else None + + +def _qwen36_request_metadata_values( + metadata_by_request_id, + request_ids, + key: str, +): + if request_ids is None: + return None + if isinstance(request_ids, list): + request_ids = tuple(request_ids) + elif not isinstance(request_ids, tuple): + request_ids = (request_ids,) + + values = [] + found = False + for request_id in request_ids: + metadata = _qwen36_metadata_for_request(metadata_by_request_id, request_id) + value = metadata.get(key) if metadata is not None else None + values.append(value) + found = found or value is not None + if not found: + return None + return values[0] if len(values) == 1 else tuple(values) + + +def _qwen36_request_ids_have_metadata( + metadata_by_request_id, + request_ids, +) -> bool: + return any( + _qwen36_request_metadata_values( + metadata_by_request_id, + request_ids, + key, + ) + is not None + for key in ( + "cumulative_hashes_by_prefix_len", + "attention_block_refs_by_prefix_len", + "request_prefix_len", + "vllm_attention_hit_len", + ) + ) + + +def _qwen36_select_vllm_hybrid_apc_request_ids( + metadata_by_request_id, + *request_id_groups, +): + first_present = None + for request_ids in request_id_groups: + if request_ids is None: + continue + if first_present is None: + first_present = request_ids + if _qwen36_request_ids_have_metadata(metadata_by_request_id, request_ids): + return request_ids + return first_present + + +def _qwen36_flat_item_count(value: Any) -> int: + if value is None: + return 0 + if hasattr(value, "numel"): + try: + return int(value.reshape(-1).numel()) + except Exception: + return 0 + if isinstance(value, (list, tuple)): + return len(value) + return 1 + + +def _qwen36_pad_batch_repeat_first(value, target_batch): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return value + if value.ndim == 0 or value.shape[0] >= target_batch: + return value + pad_n = target_batch - value.shape[0] + return torch.cat([value, value[:1].expand(pad_n, *value.shape[1:])], dim=0) + + +def _qwen36_pad_batch_with_value(value, target_batch, fill_value): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return value + if value.ndim == 0 or value.shape[0] >= target_batch: + return value + pad_shape = (target_batch - value.shape[0],) + tuple(value.shape[1:]) + pad = torch.full(pad_shape, fill_value, dtype=value.dtype, device=value.device) + return torch.cat([value, pad], dim=0) + + +def _qwen36_pad_hybrid_restore_controls_for_dummy_cte_rows( + restore_slot_ids, + restore_mask, + restore_prefix_lens, + target_batch, +): + return ( + _qwen36_pad_batch_with_value(restore_slot_ids, target_batch, 0), + _qwen36_pad_batch_with_value(restore_mask, target_batch, 0), + _qwen36_pad_batch_with_value(restore_prefix_lens, target_batch, 0), + ) + + +def _qwen36_update_state_rows_by_seq_ids(previous_state, new_rows, seq_ids): + if ( + previous_state is None + or new_rows is None + or seq_ids is None + or not hasattr(previous_state, "shape") + or not hasattr(new_rows, "shape") + or not hasattr(seq_ids, "numel") + or previous_state.ndim != new_rows.ndim + or previous_state.shape[1:] != new_rows.shape[1:] + or previous_state.shape[0] <= 0 + or new_rows.shape[0] <= 0 + or seq_ids.numel() == 0 + ): + return new_rows + + row_count = min(int(new_rows.shape[0]), int(seq_ids.reshape(-1).shape[0])) + if row_count <= 0: + return previous_state * 1 + + output = previous_state * 1 + seq_ids_flat = seq_ids.reshape(-1)[:row_count].to( + device=previous_state.device, + dtype=torch.long, + ) + slot_axis = torch.arange( + int(previous_state.shape[0]), + dtype=torch.long, + device=previous_state.device, + ) + broadcast_shape = (int(previous_state.shape[0]),) + ( + 1, + ) * (previous_state.ndim - 1) + typed_rows = new_rows[:row_count].to(previous_state.dtype) + for row_idx in range(row_count): + seq_id = seq_ids_flat[row_idx] + valid_seq = torch.logical_and( + seq_id >= 0, + seq_id < int(previous_state.shape[0]), + ) + write_mask = torch.logical_and(valid_seq, slot_axis == seq_id).view( + broadcast_shape + ) + row_value = typed_rows[row_idx : row_idx + 1].expand_as(output) + output = torch.where(write_mask, row_value, output) + return output + + +def _qwen36_preserve_inactive_state_rows(new_state, previous_state, active_rows): + if ( + new_state is None + or previous_state is None + or active_rows is None + or not hasattr(new_state, "shape") + or not hasattr(previous_state, "shape") + or not hasattr(active_rows, "numel") + or new_state.shape != previous_state.shape + or active_rows.numel() == 0 + ): + return new_state + active_rows = active_rows.reshape(-1).to(device=new_state.device, dtype=torch.bool) + row_count = min(int(active_rows.numel()), int(new_state.shape[0])) + if row_count <= 0: + return new_state + if row_count < int(new_state.shape[0]): + active_rows = torch.cat( + [ + active_rows[:row_count], + torch.ones( + int(new_state.shape[0]) - row_count, + dtype=torch.bool, + device=new_state.device, + ), + ], + dim=0, + ) + else: + active_rows = active_rows[: int(new_state.shape[0])] + view_shape = (int(new_state.shape[0]),) + (1,) * (new_state.ndim - 1) + active_rows = active_rows.view(view_shape) + return torch.where(active_rows, new_state, previous_state) + + +def _qwen36_active_state_rows(valid_mask_1d, seq_ids): + if ( + valid_mask_1d is None + or not hasattr(valid_mask_1d, "numel") + or valid_mask_1d.numel() == 0 + ): + return None + active_rows = valid_mask_1d.squeeze(-1).to(torch.bool).any(dim=-1) + if seq_ids is not None and hasattr(seq_ids, "numel") and seq_ids.numel() > 0: + seq_active = seq_ids.reshape(-1).to( + device=active_rows.device, + dtype=torch.long, + )[: active_rows.numel()] >= 0 + active_rows = active_rows & seq_active + return active_rows + + +def _qwen36_request_ids_tuple(request_ids): + if request_ids is None: + return None + if isinstance(request_ids, list): + return tuple(request_ids) + if isinstance(request_ids, tuple): + return request_ids + return (request_ids,) + + +def _qwen36_request_ids_from_hybrid_apc_records(records): + if records is None: + return None + if isinstance(records, dict): + records = (records,) + elif isinstance(records, list): + records = tuple(records) + if not isinstance(records, tuple): + return None + request_ids = [] + for record in records: + if not isinstance(record, dict): + return None + request_id = record.get("request_id") + if request_id is None: + return None + request_ids.append(request_id) + return tuple(request_ids) if request_ids else None + + +def _qwen36_max_seq_slots_for_request_ids(model, seq_ids, request_count): + max_slots = int(request_count or 0) + for owner in ( + model, + getattr(model, "neuron_config", None), + getattr(getattr(model, "context_encoding_model", None), "neuron_config", None), + getattr(getattr(model, "token_generation_model", None), "neuron_config", None), + ): + for attr in ("batch_size", "max_batch_size", "max_num_seqs"): + value = getattr(owner, attr, None) + if value is None: + continue + try: + max_slots = max(max_slots, int(value)) + except (TypeError, ValueError): + pass + if seq_ids is not None and hasattr(seq_ids, "numel") and seq_ids.numel() > 0: + flat = seq_ids.reshape(-1) + try: + non_negative = flat[flat >= 0] + if non_negative.numel() > 0: + max_slots = max(max_slots, int(non_negative.max().item()) + 1) + except Exception: + pass + return max(1, max_slots) + + +def _qwen36_stable_seq_ids_for_request_ids(model, seq_ids, request_ids): + request_ids = _qwen36_request_ids_tuple(request_ids) + if not request_ids: + return seq_ids + + normalized_request_ids = tuple( + _qwen36_hashable_request_id(request_id) for request_id in request_ids + ) + if any(request_id is None for request_id in normalized_request_ids): + return seq_ids + + slot_by_request = getattr(model, "_qwen36_hybrid_seq_slot_by_request", None) + request_by_slot = getattr(model, "_qwen36_hybrid_request_by_seq_slot", None) + if not isinstance(slot_by_request, dict) or not isinstance(request_by_slot, dict): + slot_by_request = {} + request_by_slot = {} + setattr(model, "_qwen36_hybrid_seq_slot_by_request", slot_by_request) + setattr(model, "_qwen36_hybrid_request_by_seq_slot", request_by_slot) + + max_slots = _qwen36_max_seq_slots_for_request_ids( + model, + seq_ids, + len(normalized_request_ids), + ) + active_request_ids = set(normalized_request_ids) + for stale_slot, stale_owner in list(request_by_slot.items()): + if stale_owner in active_request_ids: + continue + request_by_slot.pop(stale_slot, None) + slot_by_request.pop(stale_owner, None) + + assigned_slots = [] + for request_id in normalized_request_ids: + slot = slot_by_request.get(request_id) + if slot is None or slot < 0 or slot >= max_slots: + free_slots = [ + candidate + for candidate in range(max_slots) + if candidate not in request_by_slot + ] + if not free_slots: + return seq_ids + slot = free_slots[0] + slot_by_request[request_id] = slot + request_by_slot[slot] = request_id + assigned_slots.append(slot) + + dtype = seq_ids.dtype if hasattr(seq_ids, "dtype") else torch.int32 + if seq_ids is not None and hasattr(seq_ids, "device"): + device = seq_ids.device + else: + device = None + kwargs = {"dtype": dtype} + if device is not None: + kwargs["device"] = device + return torch.tensor(assigned_slots, **kwargs) + + +def _qwen36_select_vllm_hybrid_apc_request_ids_for_input( + metadata_by_request_id, + *, + all_request_ids, + new_request_ids, + full_context_lens, + computed_context_lens, + prefill_completion_state, +): + all_request_ids_tuple = _qwen36_request_ids_tuple(all_request_ids) + logical_request_count = max( + _qwen36_flat_item_count(full_context_lens), + _qwen36_flat_item_count(computed_context_lens), + _qwen36_flat_item_count(prefill_completion_state), + ) + if ( + logical_request_count > 1 + and all_request_ids_tuple is not None + and len(all_request_ids_tuple) == logical_request_count + ): + # Keep request identity aligned with the model row order. In mixed + # cached/new prefill batches, scheduler "new" ids can be a strict + # subset, but the metadata vectors still describe every model row. + return all_request_ids_tuple + return _qwen36_select_vllm_hybrid_apc_request_ids( + metadata_by_request_id, + new_request_ids, + all_request_ids, + ) + + +def _qwen36_add_vllm_hybrid_apc_metadata( + hybrid_apc_request_dict: dict[str, Any], + *, + request_ids, + metadata_by_request_id, +) -> None: + for key in ( + "cumulative_hashes_by_prefix_len", + "attention_block_refs_by_prefix_len", + "request_prefix_len", + "vllm_attention_hit_len", + "active_suffix_len", + "full_input_ids", + ): + value = _qwen36_request_metadata_values( + metadata_by_request_id, + request_ids, + key, + ) + if value is not None: + if key == "full_input_ids" and not isinstance(value, torch.Tensor): + input_ids = hybrid_apc_request_dict.get("input_ids") + dtype = ( + input_ids.dtype + if isinstance(input_ids, torch.Tensor) + else torch.int64 + ) + device = ( + input_ids.device + if isinstance(input_ids, torch.Tensor) + else None + ) + value = torch.tensor([list(value)], dtype=dtype, device=device) + hybrid_apc_request_dict[key] = value + + +def _debug_logits_stage(stage: str, tensor) -> None: + if os.environ.get("QWEN36_LOGIT_STAGE_DEBUG") != "1": + return + if tensor is None or not hasattr(tensor, "numel"): + print( + f"[qwen36_logits_debug] stage={stage} tensor=none", + flush=True, + ) + return + if tensor.numel() == 0: + print( + f"[qwen36_logits_debug] stage={stage} " + f"shape={tuple(tensor.shape)} dtype={tensor.dtype} device={tensor.device} empty", + flush=True, + ) + return + + try: + with torch.no_grad(): + flat = tensor.detach().reshape(-1) + if torch.is_floating_point(flat): + finite_mask = torch.isfinite(flat) + finite_count = int(finite_mask.sum().item()) + nan_count = int(torch.isnan(flat).sum().item()) + posinf_count = int( + torch.logical_and(torch.isinf(flat), flat > 0).sum().item() + ) + neginf_count = int( + torch.logical_and(torch.isinf(flat), flat < 0).sum().item() + ) + if finite_count: + finite_flat = flat[finite_mask].float() + finite_min = float(finite_flat.min().item()) + finite_max = float(finite_flat.max().item()) + else: + finite_min = "none" + finite_max = "none" + print( + "[qwen36_logits_debug] " + f"stage={stage} shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"device={tensor.device} numel={tensor.numel()} finite={finite_count} " + f"nan={nan_count} posinf={posinf_count} neginf={neginf_count} " + f"finite_min={finite_min} finite_max={finite_max}", + flush=True, + ) + else: + print( + "[qwen36_logits_debug] " + f"stage={stage} shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"device={tensor.device} numel={tensor.numel()} " + f"minmax={_debug_tensor_minmax(tensor)}", + flush=True, + ) + except Exception as exc: + print( + "[qwen36_logits_debug] " + f"stage={stage} summary_error={type(exc).__name__}: {exc}", + flush=True, + ) + + +def _qwen36_output_logits_for_return(logits, lm_head, neuron_config): + if not ( + getattr(neuron_config, "output_logits", False) + and getattr(neuron_config, "on_device_sampling_config", None) is not None + and not getattr(lm_head, "gather_output", True) + ): + return logits + return _gather_along_dim( + logits, + partition_dim=2, + process_group=getattr(lm_head, "tensor_parallel_group", None), + ) + + +class NeuronQwen35Model(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen35InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen35InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronQwen35DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + # mRoPE embedding for VL + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + def init_inference_optimization(self, config: Qwen35InferenceConfig): + super().init_inference_optimization(config) + if getattr(config, "use_hybrid_apc_manager", False): + if getattr(config.neuron_config, "is_block_kv_layout", False): + self.kv_mgr = QwenHybridBlockKVCacheManager( + config, + num_kv_head=self.num_key_value_heads, + ) + self.hybrid_gdn_checkpoint_cache = HybridGDNCheckpointCache(config) + elif getattr(config, "use_hybrid_cache_manager", False): + self.kv_mgr = HybridDeltaNetCacheManager( + config, + num_kv_head=self.num_key_value_heads, + global_rank=self.rank_util, + attention_chunk_size=self.attention_chunk_size, + sliding_window=self.sliding_window, + windowed_context_encoding_size=self.windowed_context_encoding_size, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, + ) + + @property + def _deltanet_state_params(self): + """Return DeltaNet state nn.Parameters in alias order.""" + params = [] + for layer in self.layers: + if hasattr(layer, "linear_attn"): + params.append(layer.linear_attn.recurrent_state_buffer) + params.append(layer.linear_attn.conv_state_buffer) + return params + + @property + def _hybrid_gdn_checkpoint_params(self): + if not hasattr(self, "hybrid_gdn_checkpoint_cache"): + return [] + return self.hybrid_gdn_checkpoint_cache.checkpoint_params + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Scatter vision embeddings into text input embeddings at image token positions.""" + _, max_positions, embedding_dim = inputs_embeds.shape + h_new = inputs_embeds.clone() + vision_flat = vision_embeddings.view(-1, embedding_dim) + positions_flat = vision_mask.view(-1) + h_new.view(-1, embedding_dim).index_put_( + (positions_flat,), vision_flat, accumulate=False + ) + return h_new + + def get_model_output( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + active_mask=None, + inputs_embeds=None, + prev_hidden=None, + adapter_ids=None, + rotary_position_ids=None, + update_cache=False, + is_for_context_encoding=False, + vision_embeddings=None, + vision_mask=None, + hybrid_restore_slot_ids=None, + hybrid_restore_mask=None, + hybrid_restore_prefix_lens=None, + hybrid_commit_slot_ids=None, + hybrid_commit_mask=None, + local_attn_mask=None, + windowed_context_encoding_window_idx=-1, + padding_mask=None, + **kwargs, + ): + """Override to collect DeltaNet state tensors from decoder layers.""" + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + past_key_values_length = 0 + if past_key_values is not None: + if hasattr(self.kv_mgr, "get_seq_length"): + past_key_values_length = self.kv_mgr.get_seq_length(past_key_values) + else: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence + # is not polluted. DeltaNet has no attention mask -- it processes all + # sequence positions through a linear recurrence. Padding tokens have + # real embedding vectors which corrupt the recurrence state. + # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding. + deltanet_padding_mask = _qwen36_deltanet_padding_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + padding_idx=self.padding_idx, + is_for_context_encoding=is_for_context_encoding, + hybrid_restore_mask=hybrid_restore_mask, + num_queries=kwargs.get("num_queries"), + ) + if is_for_context_encoding: + inputs_embeds = inputs_embeds * deltanet_padding_mask + + # Vision embedding injection. Text-only calls still pass dummy vision + # tensors to keep the traced input signature stable; those tensors have + # one dummy entry per text token and must not overwrite text embeddings. + if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + has_real_vision_inputs = ( + vision_embeddings.ndim == 3 + and vision_mask.ndim == 3 + and vision_embeddings.shape[1] != seq_length + ) + if is_for_context_encoding and has_real_vision_inputs: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + elif is_for_context_encoding and vision_embeddings.numel() > 0: + inputs_embeds = inputs_embeds + vision_embeddings.sum() * 0 + inputs_embeds = ( + inputs_embeds + vision_mask.sum().to(inputs_embeds.dtype) * 0 + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG and for model-local chunked CTE. + use_qwen_chunked_prefill = ( + is_for_context_encoding + and getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + ) + active_block_table = kwargs.get("active_block_table", None) + cte_has_prefix_blocks = ( + is_for_context_encoding + and use_qwen_chunked_prefill + and active_block_table is not None + and getattr(active_block_table, "ndim", 0) > 1 + ) + cache_size = ( + self.config.neuron_config.seq_len + if use_qwen_chunked_prefill + else self.n_positions + ) + if (not is_for_context_encoding) or cte_has_prefix_blocks: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] + deltanet_layer_state_pairs = [] + cos_cache = None + sin_cache = None + restored_gdn_states = None + if getattr(self.config, "use_hybrid_apc_manager", False) and hasattr( + self, "hybrid_gdn_checkpoint_cache" + ): + if hybrid_restore_prefix_lens is not None and position_ids is not None: + # Host-side request prep must set suffix position_ids to the + # restored cumulative-prefix boundary. This is a no-op on + # default zero masks, but it keeps the contract explicit. + if ( + not torch.jit.is_tracing() + and hybrid_restore_mask is not None + and bool(hybrid_restore_mask.to(torch.bool).any().item()) + ): + expected = hybrid_restore_prefix_lens.long() + observed = position_ids[:, 0].long() + if not torch.equal(observed, expected): + raise ValueError( + "hybrid APC restore prefix lens must match " + "position_ids[:, 0]" + ) + restored_gdn_states = ( + self.hybrid_gdn_checkpoint_cache.restore_to_active_rows( + layers=self.layers, + seq_ids=seq_ids, + checkpoint_slot_ids=hybrid_restore_slot_ids, + restore_mask=hybrid_restore_mask, + zero_inactive=( + is_for_context_encoding + and not _qwen36_hybrid_apc_mask_has_active_row( + hybrid_restore_prefix_lens + ) + ), + ) + ) + + # Keep CTE masks compact on the Neuron paths. Qwen attention prefill + # applies causal masking inside the attention kernel/path, while DeltaNet + # consumes deltanet_padding_mask separately. Dense SxS masks are only a + # small fallback path and are not viable for long-context CTE. + use_compact_cte_attention_mask = getattr( + self.config, "use_compact_cte_attention_mask", True + ) + use_neuron_cte_attention = use_qwen_chunked_prefill or getattr( + self.config.neuron_config, "is_block_kv_layout", False + ) + # Convert 2D attention_mask to 4D causal mask for the small fallback path. + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + and not use_compact_cte_attention_mask + and not use_neuron_cte_attention + ): + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + padding_4d = attention_mask[:, None, None, :].to(torch.bool) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) + + # Pre-compute mRoPE cos/sin + if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + if restored_gdn_states is not None and idx in restored_gdn_states: + past_key_value = restored_gdn_states[idx] + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask, + deltanet_padding_mask=deltanet_padding_mask, + qwen_chunked_prefill_update=use_qwen_chunked_prefill, + qwen_chunked_valid_mask=deltanet_padding_mask.squeeze(-1) + if use_qwen_chunked_prefill + else None, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + deltanet_state_tensors.append(deltanet_states[0]) + deltanet_state_tensors.append(deltanet_states[1]) + deltanet_layer_state_pairs.append( + (idx, deltanet_states[0], deltanet_states[1]) + ) + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + qwen_chunked_prefill_update=use_qwen_chunked_prefill, + qwen_chunked_valid_mask=deltanet_padding_mask.squeeze(-1) + if use_qwen_chunked_prefill + else None, + **kwargs, + ) + + if getattr(self.config, "use_hybrid_apc_manager", False) and hasattr( + self, "hybrid_gdn_checkpoint_cache" + ): + commit_during_tkg = bool( + getattr(self.config, "hybrid_apc_commit_during_token_generation", False) + ) + if not is_for_context_encoding and not commit_during_tkg: + self._hybrid_gdn_checkpoint_updated_states = [] + else: + self._hybrid_gdn_checkpoint_updated_states = ( + self.hybrid_gdn_checkpoint_cache.commit_from_active_rows( + layer_state_pairs=deltanet_layer_state_pairs, + seq_ids=seq_ids, + checkpoint_slot_ids=hybrid_commit_slot_ids, + commit_mask=hybrid_commit_mask, + ) + ) + + _debug_logits_stage("before_final_norm", hidden_states) + hidden_states = self.norm(hidden_states) + _debug_logits_stage("after_final_norm_full", hidden_states) + + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, next_decoder_cache) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, + slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + inputs_embeds=None, + kv_cache=None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + hybrid_restore_slot_ids=None, + hybrid_restore_mask=None, + hybrid_restore_prefix_lens=None, + hybrid_commit_slot_ids=None, + hybrid_commit_mask=None, + ): + """Override base forward to append DeltaNet state tensors to output.""" + prev_hidden = self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + hybrid_restore_slot_ids = self.set_none_if_empty(hybrid_restore_slot_ids) + hybrid_restore_mask = self.set_none_if_empty(hybrid_restore_mask) + hybrid_restore_prefix_lens = self.set_none_if_empty(hybrid_restore_prefix_lens) + hybrid_commit_slot_ids = self.set_none_if_empty(hybrid_commit_slot_ids) + hybrid_commit_mask = self.set_none_if_empty(hybrid_commit_mask) + + is_for_context_encoding = position_ids.shape[-1] != 1 and not ( + hasattr(self.neuron_config, "speculation_length") + and position_ids.shape[-1] == self.neuron_config.speculation_length + ) + + seq_ids = seq_ids.to(torch.int32) + attn_mask = attention_mask + + hidden_states, updated_kv_cache = self.get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attn_mask, + position_ids=position_ids, + active_mask=active_mask, + inputs_embeds=inputs_embeds, + adapter_ids=adapter_ids, + rotary_position_ids=rotary_position_id, + update_cache=True, + is_for_context_encoding=is_for_context_encoding, + padding_mask=None, + active_block_table=active_block_table, + scatter_index=slot_mapping + if getattr(self, "is_block_kv_layout", False) + else scatter_index, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + hybrid_restore_slot_ids=hybrid_restore_slot_ids, + hybrid_restore_mask=hybrid_restore_mask, + hybrid_restore_prefix_lens=hybrid_restore_prefix_lens, + hybrid_commit_slot_ids=hybrid_commit_slot_ids, + hybrid_commit_mask=hybrid_commit_mask, + num_queries=num_queries, + computed_context_lens=computed_context_lens, + ) + + batch_size = input_ids.shape[0] + if not getattr(self, "sliced_hidden", False): + if not is_for_context_encoding: + pass + else: + if getattr(self.config, "use_qwen_hybrid_chunked_prefill", False): + query_index = None + if ( + num_queries is not None + and hasattr(num_queries, "numel") + and num_queries.numel() >= batch_size + ): + query_index = ( + num_queries.reshape(-1)[:batch_size] + .to(device=input_ids.device, dtype=torch.long) + .view(batch_size, 1) + - 1 + ).clamp(min=0) + token_index = None + if self.padding_idx is not None: + token_index = ( + (input_ids != self.padding_idx) + .sum(dim=1, keepdim=True) + .long() + - 1 + ).clamp(min=0) + if query_index is not None: + index = query_index + elif attention_mask is not None and attention_mask.ndim == 2: + attention_index = ( + attention_mask.to(torch.long).sum(dim=1, keepdim=True) + - 1 + ).clamp(min=0) + if ( + hybrid_restore_mask is not None + and hasattr(hybrid_restore_mask, "numel") + and hybrid_restore_mask.numel() > 0 + ): + restore_active = ( + hybrid_restore_mask.reshape(-1).to(torch.bool).any() + ) + index = torch.where( + restore_active, + token_index if token_index is not None else attention_index, + attention_index, + ) + else: + index = attention_index + else: + index = ( + token_index + if token_index is not None + else torch.full( + (batch_size, 1), + max(0, input_ids.shape[1] - 1), + dtype=torch.long, + device=input_ids.device, + ) + ) + else: + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + _debug_logits_stage("after_final_norm", hidden_states) + _debug_logits_stage("selected_hidden_before_lm_head", hidden_states) + _debug_logits_stage("lm_head_weight", getattr(self.lm_head, "weight", None)) + logits = self.lm_head(hidden_states) + _debug_logits_stage("after_lm_head_pre_float", logits) + logits = logits.float() + _debug_logits_stage("after_lm_head", logits) + + if hasattr(self.lm_head, "pad_size"): + if self.lm_head.gather_output: + rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32) + world_size = 1 + else: + from neuronx_distributed.parallel_layers import parallel_state + + rank_id = self.rank_util.get_rank() + world_size = torch.distributed.get_world_size( + group=self.lm_head.tensor_parallel_group + ) + from neuronx_distributed_inference.models.model_base import ( + mask_padded_logits, + ) + + logits = mask_padded_logits( + logits, + rank_id, + world_size, + pad_size=_effective_lm_head_pad_size( + self.lm_head, logits, self.config + ), + ) + _debug_logits_stage("after_mask_padded_logits", logits) + + if self.on_device_sampling: + res = self._sample_on_device( + logits, sampling_params, False, is_for_context_encoding + ) + else: + res = logits + + _debug_logits_stage("before_return_logits", logits) + outputs = [res] + if self.neuron_config.output_logits and self.on_device_sampling: + outputs += [ + _qwen36_output_logits_for_return( + logits, + self.lm_head, + self.neuron_config, + ) + ] + _qwen36_validate_alias_output_counts( + self, + updated_kv_cache, + is_for_context_encoding=is_for_context_encoding, + ) + outputs += updated_kv_cache + + # Append DeltaNet state tensors (for input_output_aliases) + if ( + not getattr(self.config, "use_hybrid_cache_manager", False) + and hasattr(self, "_deltanet_updated_states") + ): + outputs += self._deltanet_updated_states + if ( + getattr(self.config, "use_hybrid_apc_manager", False) + and hasattr(self, "_hybrid_gdn_checkpoint_updated_states") + ): + outputs += self._hybrid_gdn_checkpoint_updated_states + + return outputs + + +# ============================================================ +# State Dict Converter (Dense -- no MoE weight handling) +# ============================================================ + + +_QWEN36_FP8_DTYPES = tuple( + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e5m2", None), + ) + if dtype is not None +) + + +def _qwen36_cat(tensors, dim=0): + """Concatenate tensors, including FP8 tensors on builds without FP8 cat.""" + if tensors and tensors[0].dtype in _QWEN36_FP8_DTYPES: + return torch.cat( + [tensor.contiguous().view(torch.int8) for tensor in tensors], + dim=dim, + ).view(tensors[0].dtype) + return torch.cat(tensors, dim=dim) + + +def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config): + """Convert HF Qwen3.5/3.6-27B weights to NxDI format. + + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: projections keep names; conv1d/A_log/dt_bias are remapped into + ColumnParallelLinear parameter containers so NxD can shard them. + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (12288, 5120) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + Dense MLP (all layers): + HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names) + + FP8 quantized checkpoints carry one scale tensor next to each quantized + weight. NxDI normalizes saved ``.weight_scale`` keys to model ``.scale`` + keys before this converter runs, so any Qwen-specific weight split/reorder/ + fusion below must apply the same transformation to the matching scale. + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + def _reorder_deltanet_qkv_for_tp(qkv_weight: torch.Tensor) -> torch.Tensor: + """Pack [Q_all | K_all | V_all] into per-rank Q/K/V blocks. + + ColumnParallelLinear slices the first dimension contiguously. DeltaNet + needs each rank to receive its local query, key, and value heads + together, so the full HF tensor is repacked as: + [rank0 Q | rank0 K | rank0 V | rank1 Q | rank1 K | rank1 V | ...]. + """ + tp_degree = config.neuron_config.tp_degree + num_k_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + if num_k_heads % tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={num_k_heads} must be divisible by tp_degree={tp_degree}" + ) + if num_v_heads % tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={num_v_heads} must be divisible by tp_degree={tp_degree}" + ) + + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + q_weight = qkv_weight[:key_dim].reshape(num_k_heads, head_k_dim, -1) + k_weight = qkv_weight[key_dim : 2 * key_dim].reshape(num_k_heads, head_k_dim, -1) + v_weight = qkv_weight[2 * key_dim : 2 * key_dim + value_dim].reshape( + num_v_heads, head_v_dim, -1 + ) + local_k_heads = num_k_heads // tp_degree + local_v_heads = num_v_heads // tp_degree + blocks = [] + for rank in range(tp_degree): + blocks.append( + q_weight[ + rank * local_k_heads : (rank + 1) * local_k_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + blocks.append( + k_weight[ + rank * local_k_heads : (rank + 1) * local_k_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + blocks.append( + v_weight[ + rank * local_v_heads : (rank + 1) * local_v_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + return _qwen36_cat(blocks, dim=0).contiguous() + + def _reorder_deltanet_qkv_channels_for_tp(channel_tensor: torch.Tensor) -> torch.Tensor: + """Repack a first-dimension Q/K/V channel tensor into TP rank blocks.""" + tp_degree = config.neuron_config.tp_degree + num_k_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + q_tensor = channel_tensor[:key_dim] + k_tensor = channel_tensor[key_dim : 2 * key_dim] + v_tensor = channel_tensor[2 * key_dim : 2 * key_dim + value_dim] + local_key_dim = key_dim // tp_degree + local_value_dim = value_dim // tp_degree + blocks = [] + for rank in range(tp_degree): + blocks.append(q_tensor[rank * local_key_dim : (rank + 1) * local_key_dim]) + blocks.append(k_tensor[rank * local_key_dim : (rank + 1) * local_key_dim]) + blocks.append( + v_tensor[rank * local_value_dim : (rank + 1) * local_value_dim] + ) + return _qwen36_cat(blocks, dim=0).contiguous() + + def _split_interleaved_q_proj_tensor( + tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Split interleaved Qwen q_proj tensor into query and output gate.""" + num_heads = config.num_attention_heads + head_dim = config.head_dim + trailing_shape = tensor.shape[1:] + tensor = tensor.reshape(num_heads, head_dim * 2, *trailing_shape) + query_tensor = tensor[:, :head_dim, ...].reshape( + num_heads * head_dim, + *trailing_shape, + ) + gate_tensor = tensor[:, head_dim:, ...].reshape( + num_heads * head_dim, + *trailing_shape, + ) + return query_tensor.contiguous(), gate_tensor.contiguous() + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." in nk or nk == "norm.weight": + logger.debug( + f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." in nk or nk == "norm.weight": + logger.warning(f"[NORM FIX] key not found: {nk}") + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === DeltaNet layers === + if layer_type == "linear_attention": + qkv_key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + if qkv_key in neuron_state_dict and config.neuron_config.tp_degree > 1: + neuron_state_dict[qkv_key] = _reorder_deltanet_qkv_for_tp( + neuron_state_dict[qkv_key] + ) + qkv_scale_key = f"layers.{l}.linear_attn.in_proj_qkv.scale" + if qkv_scale_key in neuron_state_dict and config.neuron_config.tp_degree > 1: + neuron_state_dict[qkv_scale_key] = _reorder_deltanet_qkv_channels_for_tp( + neuron_state_dict[qkv_scale_key] + ) + + conv_key = f"layers.{l}.linear_attn.conv1d.weight" + conv_weight_key = f"layers.{l}.linear_attn.conv1d_weight.weight" + conv_scale_key = f"layers.{l}.linear_attn.conv1d.scale" + conv_weight_scale_key = f"layers.{l}.linear_attn.conv1d_weight.scale" + if conv_key in neuron_state_dict: + conv_weight = neuron_state_dict.pop(conv_key) + if config.neuron_config.tp_degree > 1: + conv_weight = _reorder_deltanet_qkv_channels_for_tp(conv_weight) + neuron_state_dict[conv_weight_key] = conv_weight.squeeze(1).contiguous() + if conv_scale_key in neuron_state_dict: + conv_scale = neuron_state_dict.pop(conv_scale_key) + if config.neuron_config.tp_degree > 1: + conv_scale = _reorder_deltanet_qkv_channels_for_tp(conv_scale) + neuron_state_dict[conv_weight_scale_key] = conv_scale.contiguous() + + for vector_name in ("A_log", "dt_bias"): + vector_key = f"layers.{l}.linear_attn.{vector_name}" + vector_weight_key = f"layers.{l}.linear_attn.{vector_name}_weight.weight" + if vector_key in neuron_state_dict: + neuron_state_dict[vector_weight_key] = ( + neuron_state_dict.pop(vector_key).reshape(-1, 1).contiguous() + ) + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (12288, 5120) = (num_heads * head_dim * 2, hidden) + # INTERLEAVED: [head0_query(256) | head0_gate(256) | head1_query(256) | ...] + q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + q_proj_scale_key = f"layers.{l}.self_attn.q_proj.scale" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + query_w, gate_w = _split_interleaved_q_proj_tensor(q_proj_w) + + neuron_state_dict[q_proj_key] = query_w + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + if q_proj_scale_key in neuron_state_dict: + q_proj_scale = neuron_state_dict.pop(q_proj_scale_key) + query_scale, gate_scale = _split_interleaved_q_proj_tensor( + q_proj_scale + ) + neuron_state_dict[q_proj_scale_key] = query_scale + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.scale"] = ( + gate_scale + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + gate_key = f"layers.{l}.self_attn.output_gate_proj.weight" + pack_gate_in_qkv = bool( + getattr(config, "use_qwen_qkv_gate_packed", False) + ) + if q_key in neuron_state_dict: + qkv_weight_parts = [neuron_state_dict[q_key]] + if pack_gate_in_qkv: + if gate_key not in neuron_state_dict: + raise ValueError( + f"Missing output-gate tensor for packed QKV: {gate_key}" + ) + qkv_weight_parts.append(neuron_state_dict[gate_key]) + qkv_weight_parts.extend( + [neuron_state_dict[k_key], neuron_state_dict[v_key]] + ) + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = _qwen36_cat( + qkv_weight_parts + ) + q_scale_key = f"layers.{l}.self_attn.q_proj.scale" + gate_scale_key = f"layers.{l}.self_attn.output_gate_proj.scale" + k_scale_key = f"layers.{l}.self_attn.k_proj.scale" + v_scale_key = f"layers.{l}.self_attn.v_proj.scale" + scale_keys = [q_scale_key] + if pack_gate_in_qkv: + scale_keys.append(gate_scale_key) + scale_keys.extend([k_scale_key, v_scale_key]) + scale_keys_present = [key in neuron_state_dict for key in scale_keys] + if any(scale_keys_present): + if not all(scale_keys_present): + missing = [ + key + for key, present in zip(scale_keys, scale_keys_present) + if not present + ] + raise ValueError( + f"Missing FP8 fused-QKV scale tensor(s): {missing}" + ) + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.scale"] = _qwen36_cat( + [neuron_state_dict[key] for key in scale_keys] + ) + del neuron_state_dict[q_scale_key] + del neuron_state_dict[k_scale_key] + del neuron_state_dict[v_scale_key] + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # Dense MLP: no weight conversion needed -- HF and NxDI use same names + # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +def _reassert_hybrid_gdn_checkpoint_param_dtypes(module): + config = getattr(module, "config", None) + if config is None: + return + + recurrent_dtype = _torch_dtype_from_hybrid_cache_dtype( + getattr(config, "hybrid_recurrent_cache_dtype", "float32") + ) + conv_dtype = _torch_dtype_from_hybrid_cache_dtype( + getattr(config, "hybrid_conv_cache_dtype", "bfloat16") + ) + + def _retarget(params, dtype): + for param in params: + if param.dtype != dtype: + param.data = param.data.to(dtype) + + for layer in getattr(module, "layers", []): + linear_attn = getattr(layer, "linear_attn", None) + if linear_attn is None: + continue + recurrent_buffer = getattr(linear_attn, "recurrent_state_buffer", None) + conv_buffer = getattr(linear_attn, "conv_state_buffer", None) + if recurrent_buffer is not None and recurrent_buffer.dtype != recurrent_dtype: + recurrent_buffer.data = recurrent_buffer.data.to(recurrent_dtype) + if conv_buffer is not None and conv_buffer.dtype != conv_dtype: + conv_buffer.data = conv_buffer.data.to(conv_dtype) + + cache = getattr(module, "hybrid_gdn_checkpoint_cache", None) + if cache is not None: + _retarget(cache.recurrent_slots, recurrent_dtype) + _retarget(cache.conv_slots, conv_dtype) + cache.recurrent_dtype = recurrent_dtype + cache.conv_dtype = conv_dtype + + +def _qwen36_is_context_encoding_trace( + n_active_tokens: int | None, + neuron_config, +) -> bool: + n_active_tokens = int(n_active_tokens or 0) + speculation_length = getattr(neuron_config, "speculation_length", None) + return n_active_tokens != 1 and not ( + speculation_length is not None and n_active_tokens == speculation_length + ) + + +def _qwen36_include_hybrid_gdn_checkpoint_outputs( + config, + *, + is_for_context_encoding: bool | None = None, + n_active_tokens: int | None = None, + neuron_config=None, +) -> bool: + if is_for_context_encoding is None: + is_for_context_encoding = _qwen36_is_context_encoding_trace( + n_active_tokens, + neuron_config, + ) + if not getattr(config, "use_hybrid_apc_manager", False): + return True + if is_for_context_encoding: + return True + return bool(getattr(config, "hybrid_apc_commit_during_token_generation", False)) + + +def _qwen36_validate_alias_output_counts( + module, + updated_kv_cache, + *, + is_for_context_encoding: bool, +): + kv_mgr = getattr(module, "kv_mgr", None) + if kv_mgr is not None: + expected_kv = len(kv_mgr.past_key_values) + else: + expected_kv = 0 + actual_kv = len(updated_kv_cache) + if actual_kv != expected_kv: + raise RuntimeError( + "Qwen3.6 output alias count mismatch: " + f"updated_kv_cache has {actual_kv} tensors but kv_mgr.past_key_values " + f"has {expected_kv}" + ) + + expected_states = 0 + if not getattr(module.config, "use_hybrid_cache_manager", False): + expected_states = len(getattr(module, "_deltanet_state_params", [])) + actual_states = len(getattr(module, "_deltanet_updated_states", [])) + if actual_states != expected_states: + raise RuntimeError( + "Qwen3.6 output alias count mismatch: " + f"_deltanet_updated_states has {actual_states} tensors but " + f"_deltanet_state_params has {expected_states}" + ) + + checkpoint_outputs_expected = _qwen36_include_hybrid_gdn_checkpoint_outputs( + module.config, + is_for_context_encoding=is_for_context_encoding, + ) + expected_checkpoints = ( + len(getattr(module, "_hybrid_gdn_checkpoint_params", [])) + if checkpoint_outputs_expected + else 0 + ) + actual_checkpoints = len( + getattr(module, "_hybrid_gdn_checkpoint_updated_states", []) + ) + if actual_checkpoints != expected_checkpoints: + raise RuntimeError( + "Qwen3.6 output alias count mismatch: " + f"_hybrid_gdn_checkpoint_updated_states has {actual_checkpoints} tensors " + f"but _hybrid_gdn_checkpoint_params expects {expected_checkpoints}" + ) + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases.""" + + def load_module(self): + super().load_module() + _reassert_hybrid_gdn_checkpoint_param_dtypes(self.module) + + @staticmethod + def _num_trace_outputs_before_aliases(neuron_config): + if ( + getattr(neuron_config, "output_logits", False) + and getattr(neuron_config, "on_device_sampling_config", None) is not None + ): + return 2 + return 1 + + def get(self, bucket_rank, **kwargs): + """Override to add DeltaNet state aliases after KV cache aliases.""" + module, input_output_aliases = super().get(bucket_rank, **kwargs) + + num_output_from_trace = self._num_trace_outputs_before_aliases( + self.neuron_config + ) + base_num_output_from_trace = 1 if not self.neuron_config.output_logits else 2 + if num_output_from_trace != base_num_output_from_trace: + alias_shift = base_num_output_from_trace - num_output_from_trace + for param in list(input_output_aliases.keys()): + input_output_aliases[param] -= alias_shift + + if module.kv_mgr is not None: + num_kv = len(module.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start_idx = num_output_from_trace + num_kv + + if ( + not getattr(module.config, "use_hybrid_cache_manager", False) + and hasattr(module, "_deltanet_state_params") + ): + for i, param in enumerate(module._deltanet_state_params): + input_output_aliases[param] = state_start_idx + i + + checkpoint_start_idx = state_start_idx + len(module._deltanet_state_params) + include_checkpoint_aliases = _qwen36_include_hybrid_gdn_checkpoint_outputs( + module.config, + n_active_tokens=getattr(module, "n_active_tokens", 0), + neuron_config=self.neuron_config, + ) + if include_checkpoint_aliases: + for i, param in enumerate( + getattr(module, "_hybrid_gdn_checkpoint_params", []) + ): + input_output_aliases[param] = checkpoint_start_idx + i + + return module, input_output_aliases + + +class Qwen35ModelWrapper(ModelWrapper): + """Custom ModelWrapper for VL support with mRoPE and vision inputs.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._qwen36_hybrid_apc_pending_input_dict = None + self.hybrid_apc_store = None + self.hybrid_apc_slot_allocator = None + self.hybrid_apc_bridge = None + self._init_hybrid_apc_scheduler_bridge() + + def _init_hybrid_apc_scheduler_bridge(self): + if not _qwen36_config_flag( + self.config, + self.neuron_config, + "use_hybrid_apc_manager", + ): + return + + required_gdn_layers = tuple( + idx + for idx, layer_type in enumerate(self.config.layer_types) + if layer_type == "linear_attention" + ) + if not required_gdn_layers: + raise ValueError("hybrid APC requires at least one GDN layer") + + tp_rank = 0 + try: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + tp_rank = int(parallel_state.get_tensor_model_parallel_rank()) + except Exception: + tp_rank = 0 + + block_size = int( + getattr( + self.neuron_config, + "pa_block_size", + self.config.gdn_checkpoint_interval, + ) + ) + self.hybrid_apc_store = HybridAPCMetadataStore( + required_gdn_layers=required_gdn_layers, + block_size=block_size, + checkpoint_interval=self.config.gdn_checkpoint_interval, + max_checkpoints=self.config.max_gdn_checkpoint_slots, + layout_version=self.config.hybrid_apc_layout_version, + model_revision=self.config.hybrid_apc_model_revision, + tp_rank=tp_rank, + recurrent_dtype=self.config.hybrid_recurrent_cache_dtype, + conv_dtype=self.config.hybrid_conv_cache_dtype, + allow_residual_replay=self.config.hybrid_apc_allow_residual_replay, + ) + self.hybrid_apc_slot_allocator = HybridAPCSlotAllocator( + self.config.max_gdn_checkpoint_slots + ) + self.hybrid_apc_bridge = HybridAPCSchedulerBridge( + store=self.hybrid_apc_store, + slot_allocator=self.hybrid_apc_slot_allocator, + cache_salt=self.config.hybrid_apc_cache_salt, + model_revision=self.config.hybrid_apc_model_revision, + layout_version=self.config.hybrid_apc_layout_version, + tp_rank=tp_rank, + recurrent_dtype=self.config.hybrid_recurrent_cache_dtype, + conv_dtype=self.config.hybrid_conv_cache_dtype, + allow_local_hash_fallback=self.config.hybrid_apc_allow_local_hash_fallback, + require_attention_block_refs=self.config.hybrid_apc_require_attention_block_refs, + reject_unbacked_attention_hits=( + self.config.hybrid_apc_reject_unbacked_attention_hits + ), + ) + + def ensure_hybrid_apc_scheduler_bridge(self): + if not _qwen36_config_flag( + self.config, + self.neuron_config, + "use_hybrid_apc_manager", + ): + return None + if getattr(self, "hybrid_apc_bridge", None) is None: + self._init_hybrid_apc_scheduler_bridge() + return self.hybrid_apc_bridge + + def get_model_instance(self): + return Qwen35DecoderModelInstance( + model_cls=self.model_cls, + config=self.config, + **self.model_init_kwargs, + ) + + def input_generator(self): + """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask.""" + base_inputs = super().input_generator() + extended_inputs = [] + + for bucket_inputs in base_inputs: + input_ids = bucket_inputs[0] + batch_size = input_ids.shape[0] + n_active_tokens = input_ids.shape[1] + + is_cte = self.tag == CONTEXT_ENCODING_MODEL_TAG + + if is_cte: + mrope_position_ids = ( + torch.arange(0, n_active_tokens, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if getattr(self.config, "use_text_only_cte_inputs", True): + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + else: + vision_embeddings = torch.zeros( + (batch_size, n_active_tokens, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, n_active_tokens, 1), + fill_value=n_active_tokens - 1, + dtype=torch.int32, + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + hybrid_args = None + if _use_expanded_hybrid_args_for_tag(self.config, self.tag): + hybrid_args = ( + torch.zeros((batch_size,), dtype=torch.int32), + torch.zeros((batch_size,), dtype=torch.int32), + torch.zeros((batch_size,), dtype=torch.int32), + torch.zeros((batch_size,), dtype=torch.int32), + torch.zeros((batch_size,), dtype=torch.int32), + ) + + if is_cte: + padded = build_cte_args( + self.config, + bucket_inputs, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=hybrid_args, + ) + else: + padded = build_tkg_args( + self.config, + bucket_inputs, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=hybrid_args, + ) + _debug_qwen36_arg_contract( + "compile", + self.tag, + self.config, + padded, + ) + extended_inputs.append(tuple(padded)) + + return extended_inputs + + def _prepare_hybrid_apc_pad_inputs(self, args): + if ( + self.tag != CONTEXT_ENCODING_MODEL_TAG + or len(args) < 29 + or not _qwen36_config_flag( + self.config, + self.neuron_config, + "use_hybrid_apc_manager", + ) + or _qwen36_hybrid_apc_controls_materialized( + args[25], + args[26], + args[28], + ) + ): + return args + + computed_context_lens = args[14] + num_queries = args[13] + full_context_lens = ( + computed_context_lens + num_queries + if hasattr(computed_context_lens, "shape") and hasattr(num_queries, "shape") + else None + ) + hybrid_apc_request_dict = { + "input_ids": args[0], + "attention_mask": args[1], + "position_ids": args[2], + "seq_ids": args[3], + "sampling_params": args[4], + "adapter_ids": args[6], + "slot_mapping": args[11], + "block_table": args[12], + "num_queries": num_queries, + "computed_context_lens": computed_context_lens, + } + if full_context_lens is not None: + hybrid_apc_request_dict["full_context_lens"] = full_context_lens + + request_records = getattr( + self, + "_qwen36_vllm_hybrid_apc_request_records", + None, + ) + request_ids = _qwen36_request_ids_from_hybrid_apc_records(request_records) + if request_records is not None: + hybrid_apc_request_dict["hybrid_request_records"] = request_records + if request_ids is None: + request_ids = getattr(self, "_qwen36_vllm_request_ids", None) + if request_ids is not None: + if isinstance(request_ids, list): + request_ids = tuple(request_ids) + if isinstance(request_ids, tuple) and len(request_ids) == 1: + hybrid_apc_request_dict["hybrid_request_id"] = request_ids[0] + else: + hybrid_apc_request_dict["hybrid_request_id"] = request_ids + cached_request_ids = getattr(self, "_qwen36_vllm_cached_request_ids", None) + if cached_request_ids is not None: + hybrid_apc_request_dict["hybrid_cached_request_ids"] = cached_request_ids + prefill_completion_state = getattr( + self, + "_qwen36_vllm_prefill_completion_state", + None, + ) + if prefill_completion_state is not None: + hybrid_apc_request_dict[ + "hybrid_prefill_completion_state" + ] = prefill_completion_state + _qwen36_add_vllm_hybrid_apc_metadata( + hybrid_apc_request_dict, + request_ids=request_ids, + metadata_by_request_id=getattr( + self, + "_qwen36_vllm_hybrid_apc_metadata_by_request_id", + None, + ), + ) + + prepared_inputs = prepare_hybrid_apc_request_for_execution( + self, + hybrid_apc_request_dict, + ) + hybrid_args = prepare_hybrid_apc_model_inputs(self, prepared_inputs) + if not hybrid_args: + return args + + updated_args = list(args) + for index, key in ( + (0, "input_ids"), + (1, "attention_mask"), + (2, "position_ids"), + (3, "seq_ids"), + (4, "sampling_params"), + (6, "adapter_ids"), + (11, "slot_mapping"), + (12, "block_table"), + (13, "num_queries"), + (14, "computed_context_lens"), + ): + if key in prepared_inputs: + updated_args[index] = prepared_inputs[key] + if len(hybrid_args) == 14 and len(updated_args) >= 29: + updated_args[15:29] = hybrid_args + else: + updated_args[24:29] = hybrid_args + self._qwen36_hybrid_apc_pending_input_dict = hybrid_apc_request_dict + return tuple(updated_args) + + def _forward_with_pad(self, *args): + self._qwen36_hybrid_apc_pending_input_dict = None + try: + outputs = super()._forward_with_pad(*args) + except Exception: + pending = self._qwen36_hybrid_apc_pending_input_dict + self._qwen36_hybrid_apc_pending_input_dict = None + if pending is not None: + cancel_hybrid_apc_request(pending) + raise + pending = self._qwen36_hybrid_apc_pending_input_dict + self._qwen36_hybrid_apc_pending_input_dict = None + if pending is not None: + finish_hybrid_apc_request(pending) + return outputs + + def pad_inputs(self, *args, pad_type="first_fit"): + """Override to pad mrope_position_ids and vision inputs to bucket size.""" + args = self._prepare_hybrid_apc_pad_inputs(args) + if ( + self.tag in (CONTEXT_ENCODING_MODEL_TAG, TOKEN_GENERATION_MODEL_TAG) + and len(args) == 15 + and self.is_prefix_caching + and not getattr( + getattr(self, "neuron_config", None), + "enable_fused_speculation", + False, + ) + and not getattr( + getattr(self, "neuron_config", None), + "enable_eagle_speculation", + False, + ) + ): + args = tuple( + _normalize_qwen36_prefix_args(args) + + [_empty_qwen36_arg(), _empty_qwen36_arg(), _empty_qwen36_arg()] + ) + orig_mrope = args[21] if len(args) >= 22 else None + orig_vis_emb = args[22] if len(args) >= 23 else None + orig_vis_mask = args[23] if len(args) >= 24 else None + if len(args) >= 29: + orig_restore_slots = args[24] + orig_restore_mask = args[25] + orig_restore_prefix = args[26] + orig_commit_slots = args[27] + orig_commit_mask = args[28] + elif ( + len(args) >= 20 + and _use_expanded_hybrid_args_for_tag(self.config, self.tag) + and self.is_prefix_caching + and not self.neuron_config.enable_fused_speculation + and not self.neuron_config.enable_eagle_speculation + ): + orig_restore_slots = args[15] + orig_restore_mask = args[16] + orig_restore_prefix = args[17] + orig_commit_slots = args[18] + orig_commit_mask = args[19] + else: + orig_restore_slots = None + orig_restore_mask = None + orig_restore_prefix = None + orig_commit_slots = None + orig_commit_mask = None + + padded_args = super().pad_inputs(*args, pad_type=pad_type) + + if len(padded_args) >= 24 and orig_mrope is not None: + padded_seq_len = padded_args[0].shape[1] + batch_size = padded_args[0].shape[0] + is_cte = self.tag == CONTEXT_ENCODING_MODEL_TAG + + if is_cte: + current_mrope = orig_mrope + current_vis_emb = orig_vis_emb + current_vis_mask = orig_vis_mask + + if ( + current_mrope.ndim == 3 + and current_mrope.shape[-1] != padded_seq_len + ): + pad_size = padded_seq_len - current_mrope.shape[-1] + last_pos = current_mrope[:, :, -1:] + # Padded tokens are masked out of the active CTE, so do not + # advance mRoPE into fake future positions. + mrope_pad = last_pos.expand(3, batch_size, pad_size) + mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1) + elif current_mrope.ndim == 3: + mrope_position_ids = current_mrope + else: + mrope_position_ids = ( + torch.arange(0, padded_seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if ( + current_vis_emb is not None + and current_vis_emb.ndim == 3 + and current_vis_emb.shape[1] < padded_seq_len + ): + pad_emb = torch.zeros( + ( + batch_size, + padded_seq_len - current_vis_emb.shape[1], + current_vis_emb.shape[2], + ), + dtype=current_vis_emb.dtype, + ) + vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1) + elif current_vis_emb is not None and current_vis_emb.ndim == 3: + vision_embeddings = current_vis_emb[:, :padded_seq_len] + elif getattr(self.config, "use_text_only_cte_inputs", True): + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + else: + vision_embeddings = torch.zeros( + (batch_size, padded_seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + + if ( + current_vis_mask is not None + and current_vis_mask.ndim == 3 + and current_vis_mask.shape[1] < padded_seq_len + ): + pad_mask = torch.full( + (batch_size, padded_seq_len - current_vis_mask.shape[1], 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1) + elif current_vis_mask is not None and current_vis_mask.ndim == 3: + vision_mask = current_vis_mask[:, :padded_seq_len] + elif getattr(self.config, "use_text_only_cte_inputs", True): + vision_mask = torch.zeros((0,), dtype=torch.int32) + else: + vision_mask = torch.full( + (batch_size, padded_seq_len, 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + + padded_args = ( + *padded_args[:21], + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + + if vision_mask.ndim == 3: + padded_args = list(padded_args) + padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1) + padded_args = tuple(padded_args) + + if ( + len(padded_args) >= 24 + and _use_expanded_hybrid_args_for_tag(self.config, self.tag) + ): + padded_batch_size = padded_args[0].shape[0] + + def _pad_vector(value, dtype=torch.int32): + if value is None or not hasattr(value, "ndim") or value.ndim == 0: + return torch.zeros((padded_batch_size,), dtype=dtype) + value = value.to(dtype) + if value.shape[0] == padded_batch_size: + return value + if value.shape[0] > padded_batch_size: + return value[:padded_batch_size] + pad = torch.zeros( + (padded_batch_size - value.shape[0],), + dtype=value.dtype, + ) + return torch.cat([value, pad], dim=0) + + hybrid_args = ( + _pad_vector(orig_restore_slots), + _pad_vector(orig_restore_mask), + _pad_vector(orig_restore_prefix), + _pad_vector(orig_commit_slots), + _pad_vector(orig_commit_mask), + ) + if len(padded_args) >= 29: + padded_args = (*padded_args[:24], *hybrid_args) + else: + padded_args = (*padded_args, *hybrid_args) + + _assert_qwen36_arg_count( + self.tag, + padded_args, + _qwen36_expected_arg_count(self.config, self.tag), + ) + _debug_qwen36_arg_contract("pad", self.tag, self.config, padded_args) + return padded_args + + +# ============================================================ +# Top-Level Model +# ============================================================ + + +class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM): + _model_cls = NeuronQwen35Model + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_hybrid_apc_scheduler_bridge() + + def _init_hybrid_apc_scheduler_bridge(self): + self.hybrid_apc_store = None + self.hybrid_apc_slot_allocator = None + self.hybrid_apc_bridge = None + if not _qwen36_config_flag( + self.config, + self.neuron_config, + "use_hybrid_apc_manager", + ): + return + + required_gdn_layers = tuple( + idx + for idx, layer_type in enumerate(self.config.layer_types) + if layer_type == "linear_attention" + ) + if not required_gdn_layers: + raise ValueError("hybrid APC requires at least one GDN layer") + + tp_rank = 0 + try: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + tp_rank = int(parallel_state.get_tensor_model_parallel_rank()) + except Exception: + tp_rank = 0 + + block_size = int( + getattr( + self.neuron_config, + "pa_block_size", + self.config.gdn_checkpoint_interval, + ) + ) + self.hybrid_apc_store = HybridAPCMetadataStore( + required_gdn_layers=required_gdn_layers, + block_size=block_size, + checkpoint_interval=self.config.gdn_checkpoint_interval, + max_checkpoints=self.config.max_gdn_checkpoint_slots, + layout_version=self.config.hybrid_apc_layout_version, + model_revision=self.config.hybrid_apc_model_revision, + tp_rank=tp_rank, + recurrent_dtype=self.config.hybrid_recurrent_cache_dtype, + conv_dtype=self.config.hybrid_conv_cache_dtype, + allow_residual_replay=self.config.hybrid_apc_allow_residual_replay, + ) + self.hybrid_apc_slot_allocator = HybridAPCSlotAllocator( + self.config.max_gdn_checkpoint_slots + ) + self.hybrid_apc_bridge = HybridAPCSchedulerBridge( + store=self.hybrid_apc_store, + slot_allocator=self.hybrid_apc_slot_allocator, + cache_salt=self.config.hybrid_apc_cache_salt, + model_revision=self.config.hybrid_apc_model_revision, + layout_version=self.config.hybrid_apc_layout_version, + tp_rank=tp_rank, + recurrent_dtype=self.config.hybrid_recurrent_cache_dtype, + conv_dtype=self.config.hybrid_conv_cache_dtype, + allow_local_hash_fallback=self.config.hybrid_apc_allow_local_hash_fallback, + require_attention_block_refs=self.config.hybrid_apc_require_attention_block_refs, + reject_unbacked_attention_hits=( + self.config.hybrid_apc_reject_unbacked_attention_hits + ), + ) + + def ensure_hybrid_apc_scheduler_bridge(self): + if not _qwen36_config_flag( + self.config, + self.neuron_config, + "use_hybrid_apc_manager", + ): + return None + if getattr(self, "hybrid_apc_bridge", None) is None: + self._init_hybrid_apc_scheduler_bridge() + return self.hybrid_apc_bridge + + def on_attention_block_evicted(self, block_ref: int): + if self.hybrid_apc_store is None: + return [] + return self.hybrid_apc_store.on_attention_block_evicted(block_ref) + + def on_attention_blocks_evicted(self, block_refs): + invalidated = [] + if self.hybrid_apc_store is None: + return invalidated + for block_ref in block_refs: + invalidated.extend( + self.hybrid_apc_store.on_attention_block_evicted(block_ref) + ) + return invalidated + + def get_model_wrapper_cls(self): + """Return custom ModelWrapper with DeltaNet state aliasing.""" + return Qwen35ModelWrapper + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HF model weights. + + The model is a VL model (Qwen3_5ForConditionalGeneration) but we + only need the text backbone. + """ + from transformers import AutoModelForCausalLM + + kwargs.setdefault("trust_remote_code", True) + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen35InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, config): + """Strip VL wrapper prefix and convert to NxDI format.""" + new_sd = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + new_k = k.replace("language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + disable_wlo = bool( + getattr(self.config, "disable_token_generation_wlo", False) + ) or os.environ.get("QWEN36_DISABLE_TOKEN_GENERATION_WLO") == "1" + super().enable_token_generation(enable_wlt_optimization=not disable_wlo) + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU.""" + super()._copy_past_key_values(outputs) + if getattr(self.config, "use_hybrid_cache_manager", False): + return + + num_output_from_trace = Qwen35DecoderModelInstance._num_trace_outputs_before_aliases( + self.neuron_config + ) + + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start = num_output_from_trace + num_kv + + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + checkpoint_start = state_start + len(tkg_params) + tkg_checkpoint_params = getattr(tkg_model, "_hybrid_gdn_checkpoint_params", []) + cte_checkpoint_params = getattr(cte_model, "_hybrid_gdn_checkpoint_params", []) + if ( + len(tkg_checkpoint_params) > 0 + and checkpoint_start + len(tkg_checkpoint_params) <= len(outputs) + ): + for i, (tkg_param, cte_param) in enumerate( + zip(tkg_checkpoint_params, cte_checkpoint_params) + ): + new_state = outputs[checkpoint_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs for HF generation loop.""" + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + hybrid_restore_slot_ids=None, + hybrid_restore_mask=None, + hybrid_restore_prefix_lens=None, + hybrid_commit_slot_ids=None, + hybrid_commit_mask=None, + ): + """Override to pass Qwen/vLLM positional args explicitly.""" + prefill_completion_state = getattr( + self, + "_qwen36_vllm_prefill_completion_state", + None, + ) + is_prefill = _qwen36_is_prefill_request( + input_ids, + position_ids, + full_context_lens=full_context_lens, + computed_context_lens=computed_context_lens, + prefill_completion_state=prefill_completion_state, + ) + metadata_by_request_id = getattr( + self, + "_qwen36_vllm_hybrid_apc_metadata_by_request_id", + None, + ) + request_records = getattr( + self, + "_qwen36_vllm_hybrid_apc_request_records", + None, + ) + request_ids = _qwen36_request_ids_from_hybrid_apc_records(request_records) + if request_ids is None: + request_ids = _qwen36_select_vllm_hybrid_apc_request_ids_for_input( + metadata_by_request_id, + all_request_ids=getattr(self, "_qwen36_vllm_request_ids", None), + new_request_ids=getattr(self, "_qwen36_vllm_new_request_ids", None), + full_context_lens=full_context_lens, + computed_context_lens=computed_context_lens, + prefill_completion_state=prefill_completion_state, + ) + if not is_prefill: + ( + input_ids, + attention_mask, + position_ids, + seq_ids, + adapter_ids, + slot_mapping, + ) = _qwen36_unpack_packed_decode_batch( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + adapter_ids=adapter_ids, + slot_mapping=slot_mapping, + full_context_lens=full_context_lens, + computed_context_lens=computed_context_lens, + ) + seq_ids = _qwen36_stable_seq_ids_for_request_ids( + self, + seq_ids, + request_ids, + ) + + hybrid_apc_request_dict = None + if ( + is_prefill + and _qwen36_config_flag( + self.config, + self.neuron_config, + "use_hybrid_apc_manager", + ) + and getattr(self.neuron_config, "is_prefix_caching", False) + and _qwen36_hybrid_apc_controls_need_prepare( + hybrid_restore_mask, + hybrid_commit_mask, + ) + ): + hybrid_apc_request_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "seq_ids": seq_ids, + "sampling_params": sampling_params, + "adapter_ids": adapter_ids, + "slot_mapping": slot_mapping, + "block_table": block_table, + "full_context_lens": full_context_lens, + "computed_context_lens": computed_context_lens, + } + if llava_args: + hybrid_apc_request_dict["llava_args"] = llava_args + if len(llava_args) >= 3: + hybrid_apc_request_dict["rotary_position_ids"] = llava_args[2] + if request_records is not None: + hybrid_apc_request_dict["hybrid_request_records"] = request_records + if request_ids is not None: + if isinstance(request_ids, list): + request_ids = tuple(request_ids) + if isinstance(request_ids, tuple) and len(request_ids) == 1: + hybrid_apc_request_dict["hybrid_request_id"] = request_ids[0] + else: + hybrid_apc_request_dict["hybrid_request_id"] = request_ids + cached_request_ids = getattr( + self, + "_qwen36_vllm_cached_request_ids", + None, + ) + if cached_request_ids is not None: + hybrid_apc_request_dict["hybrid_cached_request_ids"] = ( + cached_request_ids + ) + if prefill_completion_state is not None: + hybrid_apc_request_dict[ + "hybrid_prefill_completion_state" + ] = prefill_completion_state + _qwen36_add_vllm_hybrid_apc_metadata( + hybrid_apc_request_dict, + request_ids=request_ids, + metadata_by_request_id=metadata_by_request_id, + ) + prepared_inputs = prepare_hybrid_apc_request_for_execution( + self, + hybrid_apc_request_dict, + ) + input_ids = prepared_inputs.get("input_ids", input_ids) + attention_mask = prepared_inputs.get("attention_mask", attention_mask) + position_ids = prepared_inputs.get("position_ids", position_ids) + seq_ids = prepared_inputs.get("seq_ids", seq_ids) + sampling_params = prepared_inputs.get("sampling_params", sampling_params) + adapter_ids = prepared_inputs.get("adapter_ids", adapter_ids) + slot_mapping = prepared_inputs.get("slot_mapping", slot_mapping) + block_table = prepared_inputs.get("block_table", block_table) + full_context_lens = prepared_inputs.get("full_context_lens", full_context_lens) + computed_context_lens = prepared_inputs.get( + "computed_context_lens", + computed_context_lens, + ) + num_queries = prepared_inputs.get("num_queries", num_queries) + hybrid_restore_slot_ids = prepared_inputs.get("hybrid_restore_slot_ids") + hybrid_restore_mask = prepared_inputs.get("hybrid_restore_mask") + hybrid_restore_prefix_lens = prepared_inputs.get( + "hybrid_restore_prefix_lens" + ) + hybrid_commit_slot_ids = prepared_inputs.get("hybrid_commit_slot_ids") + hybrid_commit_mask = prepared_inputs.get("hybrid_commit_mask") + prepared_mrope_position_ids = prepared_inputs.get( + "rotary_position_ids", + prepared_inputs.get("rotary_position_id"), + ) + if prepared_mrope_position_ids is not None and llava_args: + llava_args = list(llava_args) + if len(llava_args) >= 3: + llava_args[2] = prepared_mrope_position_ids + elif len(llava_args) >= 2: + llava_args.append(prepared_mrope_position_ids) + elif prepared_mrope_position_ids is not None: + mrope_position_ids = prepared_mrope_position_ids + else: + prepared_mrope_position_ids = None + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + if getattr(self.config, "use_text_only_cte_inputs", True): + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + else: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = prepared_mrope_position_ids + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + def _empty(): + return torch.empty(0) + + def _optional_tensor(value): + return value if value is not None else _empty() + + def _length_matrix(value, default_value, batch=batch_size): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return torch.full((batch, 1), default_value, dtype=torch.int32) + value = value.to(torch.int32) + if value.ndim == 0: + return value.reshape(1, 1) + if value.ndim == 1: + return value.reshape(-1, 1) + return value + + def _slice_batch(value, start, end): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return _empty() + if value.ndim > 0 and value.shape[0] >= end: + return value[start:end] + return value + + def _pad_batch(value, target_batch, fill_value=0): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return value + if value.ndim == 0 or value.shape[0] >= target_batch: + return value + pad_shape = (target_batch - value.shape[0],) + tuple(value.shape[1:]) + pad = torch.full(pad_shape, fill_value, dtype=value.dtype) + return torch.cat([value, pad], dim=0) + + def _pad_batch_repeat_first(value, target_batch): + if value is None or not hasattr(value, "numel") or value.numel() == 0: + return value + if value.ndim == 0 or value.shape[0] >= target_batch: + return value + pad_n = target_batch - value.shape[0] + return torch.cat([value, value[:1].expand(pad_n, *value.shape[1:])], dim=0) + + if self.neuron_config.is_prefix_caching: + if is_prefill: + computed_context_lens_arg = _length_matrix(computed_context_lens, 0) + full_context_lens_arg = _length_matrix(full_context_lens, seq_len) + num_queries_arg = ( + full_context_lens_arg - computed_context_lens_arg + ).to(torch.int32) + else: + if seq_len != 1: + raise ValueError( + "Qwen3.6 TKG expects active decode length 1, " + f"got input_ids.shape[-1]={seq_len}" + ) + num_queries_arg = torch.full( + (batch_size, 1), seq_len, dtype=torch.int32 + ) + if ( + position_ids is not None + and hasattr(position_ids, "numel") + and position_ids.numel() > 0 + ): + computed_context_lens_arg = _length_matrix(position_ids, 0) + elif full_context_lens is not None: + computed_context_lens_arg = _length_matrix( + full_context_lens, seq_len + ) + else: + computed_context_lens_arg = _length_matrix( + computed_context_lens, + attention_mask.shape[-1] if attention_mask is not None else 0, + ) + slot_mapping_arg = _optional_tensor(slot_mapping) + slot_mapping_arg = _normalize_qwen36_slot_mapping( + slot_mapping_arg, + batch_size, + seq_len, + ) + block_table_arg = _optional_tensor(block_table) + else: + computed_context_lens_arg = _empty() + num_queries_arg = _empty() + slot_mapping_arg = _empty() + block_table_arg = _empty() + + if hybrid_restore_slot_ids is None: + hybrid_restore_slot_ids = torch.zeros((batch_size,), dtype=torch.int32) + if hybrid_restore_mask is None: + hybrid_restore_mask = torch.zeros((batch_size,), dtype=torch.int32) + if hybrid_restore_prefix_lens is None: + hybrid_restore_prefix_lens = torch.zeros((batch_size,), dtype=torch.int32) + if hybrid_commit_slot_ids is None: + hybrid_commit_slot_ids = torch.zeros((batch_size,), dtype=torch.int32) + if hybrid_commit_mask is None: + hybrid_commit_mask = torch.zeros((batch_size,), dtype=torch.int32) + + if is_prefill: + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_slot_mapping = _slice_batch(slot_mapping_arg, cb, cb_end) + chunk_block_table = _slice_batch(block_table_arg, cb, cb_end) + chunk_num_queries = _slice_batch(num_queries_arg, cb, cb_end) + chunk_computed_context_lens = _slice_batch( + computed_context_lens_arg, cb, cb_end + ) + chunk_restore_slots = hybrid_restore_slot_ids[cb:cb_end] + chunk_restore_mask = hybrid_restore_mask[cb:cb_end] + chunk_restore_prefix = hybrid_restore_prefix_lens[cb:cb_end] + chunk_commit_slots = hybrid_commit_slot_ids[cb:cb_end] + chunk_commit_mask = hybrid_commit_mask[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids + + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + pad_seq = torch.full( + (pad_n,), -1, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + chunk_slot_mapping = _pad_batch(chunk_slot_mapping, ctx_bs, -1) + chunk_block_table = _pad_batch_repeat_first( + chunk_block_table, ctx_bs + ) + chunk_num_queries = _pad_batch_repeat_first( + chunk_num_queries, ctx_bs + ) + chunk_computed_context_lens = _pad_batch( + chunk_computed_context_lens, ctx_bs, 0 + ) + # Dummy CTE rows repeat active token tensors to satisfy the + # compiled batch shape, but they must not advertise a + # prefix-cache restore. Their seq_ids are marked negative + # and the DeltaNet state update preserves negative rows, so + # recurrent state cannot leak into seq_ids later reused by + # real requests. + ( + chunk_restore_slots, + chunk_restore_mask, + chunk_restore_prefix, + ) = _qwen36_pad_hybrid_restore_controls_for_dummy_cte_rows( + chunk_restore_slots, + chunk_restore_mask, + chunk_restore_prefix, + ctx_bs, + ) + chunk_commit_slots = torch.cat( + [chunk_commit_slots, torch.zeros(pad_n, dtype=chunk_commit_slots.dtype)], + dim=0, + ) + chunk_commit_mask = torch.cat( + [chunk_commit_mask, torch.zeros(pad_n, dtype=chunk_commit_mask.dtype)], + dim=0, + ) + if ( + chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + def _dbg_minmax(tensor): + if not hasattr(tensor, "numel") or tensor.numel() == 0: + return "empty" + flat = tensor.reshape(-1) + return f"{int(flat.min().item())}:{int(flat.max().item())}" + + print( + "[hybrid_apc_debug] qwen-cte-call " + f"input_shape={tuple(chunk_input_ids.shape)} " + f"attention_shape={tuple(chunk_attn_mask.shape)} " + f"position_shape={tuple(chunk_pos_ids.shape)} " + f"position_minmax={_dbg_minmax(chunk_pos_ids)} " + f"seq_ids={chunk_seq_ids.reshape(-1).tolist() if hasattr(chunk_seq_ids, 'numel') and chunk_seq_ids.numel() else []} " + f"slot_shape={tuple(chunk_slot_mapping.shape)} " + f"slot_minmax={_dbg_minmax(chunk_slot_mapping)} " + f"block_shape={tuple(chunk_block_table.shape)} " + f"block_minmax={_dbg_minmax(chunk_block_table)} " + f"num_queries={chunk_num_queries.reshape(-1).tolist() if hasattr(chunk_num_queries, 'numel') and chunk_num_queries.numel() else []} " + f"computed={chunk_computed_context_lens.reshape(-1).tolist() if hasattr(chunk_computed_context_lens, 'numel') and chunk_computed_context_lens.numel() else []} " + f"restore_slots={chunk_restore_slots.reshape(-1).tolist()} " + f"restore_mask={chunk_restore_mask.reshape(-1).tolist()} " + f"restore_prefix={chunk_restore_prefix.reshape(-1).tolist()} " + f"commit_slots={chunk_commit_slots.reshape(-1).tolist()} " + f"commit_mask={chunk_commit_mask.reshape(-1).tolist()}", + flush=True, + ) + + cte_prefix_args = [ + chunk_input_ids, + chunk_attn_mask, + chunk_pos_ids, + chunk_seq_ids, + chunk_sampling, + chunk_prev_hidden, + chunk_adapter_ids, + _empty(), + _empty(), + _empty(), + _empty(), + chunk_slot_mapping, + chunk_block_table, + chunk_num_queries, + chunk_computed_context_lens, + _empty(), + _empty(), + _empty(), + _empty(), + _empty(), + _empty(), + ] + cte_args = build_cte_args( + self.config, + cte_prefix_args, + chunk_mrope, + chunk_vis_emb, + chunk_vis_mask, + hybrid_args=( + chunk_restore_slots, + chunk_restore_mask, + chunk_restore_prefix, + chunk_commit_slots, + chunk_commit_mask, + ), + ) + _debug_qwen36_arg_contract( + "runtime", + CONTEXT_ENCODING_MODEL_TAG, + self.config, + cte_args, + ) + _qwen36_prefill_timing = os.environ.get("QWEN36_PREFILL_TIMING") == "1" + _qwen36_cte_start = time.perf_counter() if _qwen36_prefill_timing else None + try: + chunk_out = self.context_encoding_model(*cte_args) + except Exception: + if hybrid_apc_request_dict is not None: + cancel_hybrid_apc_request(hybrid_apc_request_dict) + hybrid_apc_request_dict = None + raise + if _qwen36_prefill_timing and _qwen36_cte_start is not None: + print( + "[qwen36_perf] qwen_cte_call " + f"elapsed_ms={(time.perf_counter() - _qwen36_cte_start) * 1000.0:.3f} " + f"actual_chunk={actual_chunk} ctx_bs={ctx_bs} " + f"input_shape={tuple(chunk_input_ids.shape)} " + f"num_queries={chunk_num_queries.reshape(-1).tolist() if hasattr(chunk_num_queries, 'numel') and chunk_num_queries.numel() else []} " + f"computed={chunk_computed_context_lens.reshape(-1).tolist() if hasattr(chunk_computed_context_lens, 'numel') and chunk_computed_context_lens.numel() else []} " + f"restore_mask={chunk_restore_mask.reshape(-1).tolist()} " + f"commit_mask={chunk_commit_mask.reshape(-1).tolist()}", + flush=True, + ) + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + if hybrid_apc_request_dict is not None: + finish_hybrid_apc_request(hybrid_apc_request_dict) + else: + _validate_qwen36_tkg_input_ids( + input_ids, + getattr(self.config, "vocab_size", None), + ) + legacy_tkg_args = _use_legacy_tkg_args() + if ( + os.environ.get("QWEN36_TKG_INPUT_DEBUG") == "1" + or os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1" + ): + max_model_len = getattr( + self.neuron_config, + "max_length", + getattr(self.neuron_config, "seq_len", None), + ) + print( + "[hybrid_apc_debug] qwen-tkg-call " + f"arg_mode={'prefix24_legacy' if legacy_tkg_args else 'hybrid29'} " + f"input_shape={_debug_tensor_shape(input_ids)} " + f"input_values={_debug_tensor_values(input_ids)} " + f"attention_shape={_debug_tensor_shape(attention_mask)} " + f"position_shape={_debug_tensor_shape(position_ids)} " + f"position_minmax={_debug_tensor_minmax(position_ids)} " + f"slot_shape={_debug_tensor_shape(slot_mapping_arg)} " + f"slot_minmax={_debug_tensor_minmax(slot_mapping_arg)} " + f"block_shape={_debug_tensor_shape(block_table_arg)} " + f"block_minmax={_debug_tensor_minmax(block_table_arg)} " + f"num_queries={_debug_tensor_values(num_queries_arg)} " + "computed_context_lens=" + f"{_debug_tensor_values(computed_context_lens_arg)} " + f"pa_num_blocks={getattr(self.neuron_config, 'pa_num_blocks', None)} " + f"block_size={getattr(self.neuron_config, 'pa_block_size', None)} " + f"seq_len={seq_len} max_model_len={max_model_len}", + flush=True, + ) + tkg_prefix_args = [ + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + _empty(), + _empty(), + _empty(), + _empty(), + slot_mapping_arg, + block_table_arg, + num_queries_arg, + computed_context_lens_arg, + _empty(), + _empty(), + _empty(), + _empty(), + _empty(), + _empty(), + ] + tkg_args = build_tkg_args( + self.config, + tkg_prefix_args, + mrope_position_ids, + vision_embeddings, + vision_mask, + hybrid_args=( + hybrid_restore_slot_ids, + hybrid_restore_mask, + hybrid_restore_prefix_lens, + hybrid_commit_slot_ids, + hybrid_commit_mask, + ), + ) + _debug_qwen36_arg_contract( + "runtime", + TOKEN_GENERATION_MODEL_TAG, + self.config, + tkg_args, + ) + outputs = self.token_generation_model(*tkg_args) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type transformer {optimization_level} " + "--auto-cast=none " + ) + return compiler_args diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py new file mode 100644 index 00000000..761d7e95 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py @@ -0,0 +1,819 @@ +""" +Qwen3.5-27B / Qwen3.6-27B (Dense) Vision Encoder for NeuronX Distributed Inference. + +Ports the Qwen3.5/3.6 ViT encoder to run on Neuron. The vision encoder +architecture is identical across Qwen3.5-27B and Qwen3.6-27B (same patch +embed, same rotary, same merger) -- only out_hidden_size changes vs the MoE +variant (5120 vs 2048, read from config). + +The vision encoder runs as a separate compiled model from the text decoder, +compiled and loaded via NeuronBaseForImageToText. +""" + +import logging +import math +import os +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# CRITICAL: Use finite negative value instead of -inf for Neuron attention masks. +# The Neuron compiler's bfloat16 handling of -inf produces NaN that bleeds from +# padding positions into ALL positions through the transformer layers. +# -65504.0 is large enough for softmax masking but avoids NaN overflow. +_MASK_NEG_INF = -65504.0 + +logger = logging.getLogger(__name__) + +# -- NxDI imports (available on Neuron instances) -- +try: + from neuronx_distributed_inference.models.application_base import ( + NeuronApplicationBase, + ) + from neuronx_distributed_inference.models.model_wrapper import ModelWrapper + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, + ) + from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding + from neuronx_distributed.parallel_layers import layers as nxd_layers +except ImportError: + logger.warning( + "NxDI imports unavailable -- vision module can only be used on Neuron instances" + ) + +# -- HuggingFace imports for patch embed (runs on CPU) -- +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeVisionPatchEmbed, + Qwen3_5MoeVisionPatchMerger, + Qwen3_5MoeVisionRotaryEmbedding, + ) +except ImportError: + try: + # transformers 4.57+ uses Qwen3VL* class names + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed, + Qwen3VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger, + Qwen3VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding, + ) + except ImportError: + try: + # Older transformers uses Qwen2VL* class names + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed, + Qwen2VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger, + Qwen2VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding, + ) + except ImportError: + Qwen3_5MoeVisionPatchEmbed = None + Qwen3_5MoeVisionPatchMerger = None + Qwen3_5MoeVisionRotaryEmbedding = None + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + """Apply rotary position embeddings to vision Q and K tensors. + + Uses rotate_half style (matching HF reference): + q_embed = (q * cos) + (rotate_half(q) * sin) + + Args: + q: (seq_len, num_heads, head_dim) + k: (seq_len, num_heads, head_dim) + cos: (seq_len, head_dim) + sin: (seq_len, head_dim) + """ + cos = cos.unsqueeze(-2) # (seq_len, 1, head_dim) + sin = sin.unsqueeze(-2) + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class NeuronQwen35VisionAttention(nn.Module): + """Vision attention for Qwen3.5 MoE. + + Uses fused QKV linear (no bias in Neuron port for efficiency). + Non-causal attention with block-diagonal mask for variable-length images. + """ + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.hidden_size // self.num_heads + self.scaling = self.head_dim**-0.5 + + # Fused QKV: (hidden_size -> 3 * hidden_size) with bias + self.qkv = nxd_layers.ColumnParallelLinear( + self.hidden_size, + 3 * self.hidden_size, + bias=True, + gather_output=True, + ) + self.proj = nxd_layers.RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + input_is_parallel=False, + ) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + """ + Args: + hidden_states: (seq_len, hidden_size) + attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask + position_embeddings: (cos, sin) tuple + """ + seq_len = hidden_states.shape[0] + + # QKV projection + qkv = self.qkv(hidden_states) # (seq_len, 3 * hidden_size) + qkv = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(1, 0, 2, 3) # (3, seq_len, num_heads, head_dim) + q, k, v = qkv.unbind(0) # each (seq_len, num_heads, head_dim) + + # Apply rotary embeddings + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + # Reshape for batched attention: (1, num_heads, seq_len, head_dim) + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + # Scaled dot-product attention + attn_weights = torch.matmul(q, k.transpose(-1, -2)) * self.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + # Reshape back: (seq_len, hidden_size) + attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_len, -1) + + # Output projection + attn_output = self.proj(attn_output) + return attn_output + + +class NeuronQwen35VisionMLP(nn.Module): + """Vision MLP with GELU activation.""" + + def __init__(self, config): + super().__init__() + self.linear_fc1 = nxd_layers.ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + gather_output=True, + ) + self.linear_fc2 = nxd_layers.RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + input_is_parallel=False, + ) + self.act_fn = nn.GELU() + + def forward(self, hidden_states): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_states))) + + +class NeuronQwen35VisionBlock(nn.Module): + """Single vision transformer block: LayerNorm + Attention + LayerNorm + MLP.""" + + def __init__(self, config): + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = NeuronQwen35VisionAttention(config) + self.mlp = NeuronQwen35VisionMLP(config) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class NeuronQwen35VisionModel(nn.Module): + """Qwen3.5 MoE Vision Encoder for Neuron. + + This is the nn.Module that gets compiled and traced onto Neuron. + Patch embedding, positional embedding, and rotary embedding are computed + on CPU in the ModelWrapper and passed as inputs. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.blocks = nn.ModuleList( + [NeuronQwen35VisionBlock(config) for _ in range(config.depth)] + ) + # Merger: spatial_merge_size^2 * hidden_size -> out_hidden_size + self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + merger_hidden = config.hidden_size * (config.spatial_merge_size**2) + self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden) + self.merger_act = nn.GELU() + self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + """ + Args: + hidden_states: (seq_len, hidden_size) -- after patch_embed + pos_embed + attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask + position_embeddings: (cos, sin) tuple for rotary + + Returns: + vision_embeddings: (merged_seq_len, out_hidden_size) + """ + for block in self.blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + + # Apply merger: norm -> spatial merge -> fc1 -> gelu -> fc2 + hidden_states = self.merger_norm(hidden_states) + merge_size = self.config.spatial_merge_size + merged_hidden = self.config.hidden_size * (merge_size**2) + hidden_states = hidden_states.view(-1, merged_hidden) + hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states))) + + return hidden_states + + +class CPUVisionModel(nn.Module): + """CPU-only vision encoder (pure PyTorch, no Neuron dependencies). + + Used when HBM is insufficient to load the vision encoder on Neuron + alongside the text decoder (e.g., 27B dense model on trn2.3xlarge). + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.blocks = nn.ModuleList( + [self._make_block(config) for _ in range(config.depth)] + ) + self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + merger_hidden = config.hidden_size * (config.spatial_merge_size**2) + self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden) + self.merger_act = nn.GELU() + self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size) + + @staticmethod + def _make_block(config): + """Build a single vision block with standard nn.Linear (no TP).""" + block = nn.Module() + block.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + block.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + + # Attention + attn = nn.Module() + attn.hidden_size = config.hidden_size + attn.num_heads = config.num_heads + attn.head_dim = config.hidden_size // config.num_heads + attn.scaling = attn.head_dim**-0.5 + attn.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True) + attn.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) + block.attn = attn + + # MLP + mlp = nn.Module() + mlp.linear_fc1 = nn.Linear( + config.hidden_size, config.intermediate_size, bias=True + ) + mlp.linear_fc2 = nn.Linear( + config.intermediate_size, config.hidden_size, bias=True + ) + mlp.act_fn = nn.GELU() + block.mlp = mlp + + return block + + def _forward_attention(self, attn, hidden_states, attention_mask, cos, sin): + seq_len = hidden_states.shape[0] + qkv = attn.qkv(hidden_states).reshape(seq_len, 3, attn.num_heads, attn.head_dim) + qkv = qkv.permute(1, 0, 2, 3) + q, k, v = qkv.unbind(0) + + if cos is not None and sin is not None: + cos_u = cos.unsqueeze(-2) + sin_u = sin.unsqueeze(-2) + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + q = (q * cos_u) + (rotate_half(q) * sin_u) + k = (k * cos_u) + (rotate_half(k) * sin_u) + + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + attn_weights = torch.matmul(q, k.transpose(-1, -2)) * attn.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + out = torch.matmul(attn_weights, v) + out = out.squeeze(0).transpose(0, 1).reshape(seq_len, -1) + return attn.proj(out) + + def forward(self, hidden_states, attention_mask, cos, sin): + for block in self.blocks: + hidden_states = hidden_states + self._forward_attention( + block.attn, block.norm1(hidden_states), attention_mask, cos, sin + ) + hidden_states = hidden_states + block.mlp.linear_fc2( + block.mlp.act_fn(block.mlp.linear_fc1(block.norm2(hidden_states))) + ) + + hidden_states = self.merger_norm(hidden_states) + merge_size = self.config.spatial_merge_size + merged_hidden = self.config.hidden_size * (merge_size**2) + hidden_states = hidden_states.view(-1, merged_hidden) + hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states))) + return hidden_states + + +class NeuronQwen35VisionModelWrapper(ModelWrapper): + """Wraps the vision encoder for NxDI tracing. + + Handles CPU-side operations that cannot be traced: + - Patch embedding (Conv3d) + - Positional embedding (Embedding + bilinear interpolation) + - Rotary position embedding computation + - Vision attention mask construction (block-diagonal) + - Sequence length bucketing and padding/unpadding + + Supports three modes: + 1. NxDI traced model (parallel layers) -- standard NxDI compilation + 2. Pre-compiled standalone model -- loaded from torch_neuronx.trace() output + 3. CPU-only model -- for when HBM is full (e.g., 27B dense on trn2.3xlarge) + """ + + def __init__(self, config, model_cls=None, **kwargs): + if model_cls is not None: + super().__init__(config, model_cls, **kwargs) + else: + # Standalone mode: no NxDI model_cls + nn.Module.__init__(self) + self.vision_config = config + self._compiled_model = None # Set by load_compiled() -- single bucket + self._compiled_buckets = None # Set by load_compiled() -- multi-bucket dict + self._cpu_model = None # Set by load_cpu_model() + + # These HF modules run on CPU, outside the traced graph + if Qwen3_5MoeVisionPatchEmbed is not None: + self.patch_embed = Qwen3_5MoeVisionPatchEmbed(config) + self.pos_embed = nn.Embedding( + config.num_position_embeddings, config.hidden_size + ) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2) + else: + logger.warning("HF Qwen3.5 MoE vision classes not available") + + self.vision_seq_len_buckets = kwargs.get( + "vision_seq_len_buckets", [1024, 4096, 16384] + ) + + def load_compiled(self, compiled_model_path): + """Load pre-compiled standalone vision encoder(s). + + Supports two modes: + 1. Single .pt file: Legacy mode, loads one compiled model for one bucket size. + 2. Directory with multiple .pt files: Multi-bucket mode. Files must be named + 'vision_encoder_{bucket_size}.pt' (e.g., 'vision_encoder_256.pt'). + Falls back to single 'vision_encoder.pt' in the directory. + + Args: + compiled_model_path: Path to a .pt file or directory containing bucket .pt files. + """ + import glob as glob_module + + logger.info(f"Loading pre-compiled vision encoder from {compiled_model_path}") + + if os.path.isfile(compiled_model_path): + # Single file mode (legacy) + self._compiled_model = torch.jit.load(compiled_model_path) + self._compiled_buckets = None + logger.info("Vision encoder loaded successfully (single bucket)") + elif os.path.isdir(compiled_model_path): + # Directory mode: look for bucket-specific files + bucket_files = sorted( + glob_module.glob( + os.path.join(compiled_model_path, "vision_encoder_*.pt") + ) + ) + if bucket_files: + self._compiled_buckets = {} + for bf in bucket_files: + # Extract bucket size from filename: vision_encoder_256.pt -> 256 + basename = os.path.basename(bf) + try: + bucket_size = int( + basename.replace("vision_encoder_", "").replace(".pt", "") + ) + self._compiled_buckets[bucket_size] = torch.jit.load(bf) + logger.info(f" Loaded vision bucket {bucket_size} from {bf}") + except ValueError: + logger.warning(f" Skipping unrecognized file: {bf}") + self._compiled_model = None + # Update vision_seq_len_buckets to match compiled buckets + self.vision_seq_len_buckets = sorted(self._compiled_buckets.keys()) + logger.info( + f"Vision encoder loaded with {len(self._compiled_buckets)} buckets: " + f"{self.vision_seq_len_buckets}" + ) + else: + # Fall back to single vision_encoder.pt in directory + single_path = os.path.join(compiled_model_path, "vision_encoder.pt") + if os.path.exists(single_path): + self._compiled_model = torch.jit.load(single_path) + self._compiled_buckets = None + logger.info( + "Vision encoder loaded successfully (single file in dir)" + ) + else: + raise FileNotFoundError( + f"No vision encoder files found in {compiled_model_path}" + ) + else: + raise FileNotFoundError( + f"Vision encoder path not found: {compiled_model_path}" + ) + + def load_vision_weights_from_hf(self, model_path): + """Load patch_embed and pos_embed weights from HF safetensors. + + Args: + model_path: Path to HF model directory + """ + from pathlib import Path + from safetensors import safe_open + + st_files = sorted( + p + for p in Path(model_path).glob("*.safetensors") + if p.suffix == ".safetensors" + ) + loaded = 0 + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key == "model.visual.patch_embed.proj.weight": + self.patch_embed.proj.weight.data.copy_(f.get_tensor(key)) + loaded += 1 + elif key == "model.visual.patch_embed.proj.bias": + self.patch_embed.proj.bias.data.copy_(f.get_tensor(key)) + loaded += 1 + elif key == "model.visual.pos_embed.weight": + self.pos_embed.weight.data.copy_(f.get_tensor(key)) + loaded += 1 + logger.info(f"Loaded {loaded} CPU-side vision weight tensors from HF") + + def load_cpu_model(self, model_path): + """Load a CPU-only vision encoder from HF safetensors. + + Use this when HBM is insufficient for the Neuron-compiled vision encoder + (e.g., 27B dense model fills trn2.3xlarge HBM). + + Args: + model_path: Path to HF model directory with safetensors + """ + from pathlib import Path + from safetensors import safe_open + + config = self.vision_config + cpu_model = CPUVisionModel(config) + + # Build key mapping from HF safetensors to CPU model + key_map = {} + for i in range(config.depth): + hf_pre = f"model.visual.blocks.{i}" + loc_pre = f"blocks.{i}" + for suffix in [ + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "mlp.linear_fc1.weight", + "mlp.linear_fc1.bias", + "mlp.linear_fc2.weight", + "mlp.linear_fc2.bias", + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + ]: + key_map[f"{hf_pre}.{suffix}"] = f"{loc_pre}.{suffix}" + + key_map["model.visual.merger.norm.weight"] = "merger_norm.weight" + key_map["model.visual.merger.norm.bias"] = "merger_norm.bias" + key_map["model.visual.merger.linear_fc1.weight"] = "merger_fc1.weight" + key_map["model.visual.merger.linear_fc1.bias"] = "merger_fc1.bias" + key_map["model.visual.merger.linear_fc2.weight"] = "merger_fc2.weight" + key_map["model.visual.merger.linear_fc2.bias"] = "merger_fc2.bias" + + st_files = sorted(Path(model_path).glob("model*.safetensors")) + loaded = 0 + state_dict = cpu_model.state_dict() + + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key in key_map: + local_key = key_map[key] + if local_key in state_dict: + state_dict[local_key].copy_(f.get_tensor(key)) + loaded += 1 + + cpu_model.load_state_dict(state_dict) + cpu_model = cpu_model.to(torch.bfloat16).eval() + self._cpu_model = cpu_model + logger.info( + f"Loaded CPU vision encoder: {loaded} weights, " + f"{sum(p.numel() for p in cpu_model.parameters()) / 1e6:.1f}M params" + ) + + def _get_vision_bucket(self, seq_len): + """Find the smallest bucket that fits the sequence length.""" + for bucket in sorted(self.vision_seq_len_buckets): + if seq_len <= bucket: + return bucket + return self.vision_seq_len_buckets[-1] + + def rot_pos_emb(self, grid_thw): + """Compute rotary positional embeddings for vision tokens. + + Returns: (total_tokens, head_dim) tensor of rotary frequencies. + """ + merge_size = self.vision_config.spatial_merge_size + grid_thw_list = grid_thw.tolist() + + max_hw = max(max(h, w) for _, h, w in grid_thw_list) + freq_table = self.rotary_pos_emb(max_hw) + device = freq_table.device + + total_tokens = sum(t * h * w for t, h, w in grid_thw_list) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw_list: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(merge_size, device=device) + intra_col = torch.arange(merge_size, device=device) + + row_idx = ( + block_rows[:, None, None, None] * merge_size + + intra_row[None, None, :, None] + ) + col_idx = ( + block_cols[None, :, None, None] * merge_size + + intra_col[None, None, None, :] + ) + + row_idx = row_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + col_idx = col_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + """Bilinear interpolation of positional embeddings for variable resolution.""" + grid_thw_list = grid_thw.tolist() + grid_ts = [row[0] for row in grid_thw_list] + grid_hs = [row[1] for row in grid_thw_list] + grid_ws = [row[2] for row in grid_thw_list] + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + merge_size = self.vision_config.spatial_merge_size + patch_pos_embeds_permute = [] + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + + return torch.cat(patch_pos_embeds_permute) + + def _build_vision_attention_mask(self, grid_thw, seq_len, dtype): + """Build block-diagonal attention mask for variable-length images. + + Each image gets its own attention block (no cross-image attention). + """ + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Build block-diagonal mask + mask = torch.full((seq_len, seq_len), _MASK_NEG_INF, dtype=dtype) + for i in range(len(cu_seqlens) - 1): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + mask[start:end, start:end] = 0.0 + + return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) + + def forward(self, pixel_values, image_grid_thw): + """Run vision encoding (CPU preprocessing + Neuron traced model). + + Args: + pixel_values: Raw pixel values from HF processor + image_grid_thw: (num_images, 3) -- temporal, height, width in patches + + Returns: + vision_embeddings: (total_merged_tokens, out_hidden_size) + """ + # 1. Patch embedding (CPU, Conv3d) + hidden_states = self.patch_embed(pixel_values) + + # 2. Positional embedding (CPU, bilinear interpolation) + pos_embeds = self.fast_pos_embed_interpolate(image_grid_thw) + hidden_states = hidden_states + pos_embeds + + # 3. Rotary position embeddings (CPU) + rotary_pos_emb = self.rot_pos_emb(image_grid_thw) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + # 4. Vision attention mask (block-diagonal) + seq_len = hidden_states.shape[0] + attention_mask = self._build_vision_attention_mask( + image_grid_thw, seq_len, hidden_states.dtype + ) + + # 5. Bucket and pad for Neuron compilation + bucket_len = self._get_vision_bucket(seq_len) + cos, sin = position_embeddings + if seq_len < bucket_len: + pad_len = bucket_len - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len)) + cos = F.pad(cos, (0, 0, 0, pad_len)) + sin = F.pad(sin, (0, 0, 0, pad_len)) + # Extend mask with _MASK_NEG_INF for padded positions (NOT -inf, which causes NaN on Neuron) + mask = torch.full( + (1, 1, bucket_len, bucket_len), _MASK_NEG_INF, dtype=hidden_states.dtype + ) + mask[:, :, :seq_len, :seq_len] = attention_mask + attention_mask = mask + + # 6. Run vision model (Neuron compiled or CPU fallback) + if self._compiled_buckets is not None: + # Multi-bucket mode: select the compiled model for this bucket + if bucket_len not in self._compiled_buckets: + raise RuntimeError( + f"No compiled vision encoder for bucket size {bucket_len}. " + f"Available buckets: {sorted(self._compiled_buckets.keys())}. " + f"Input seq_len={seq_len} requires bucket {bucket_len}." + ) + compiled_model = self._compiled_buckets[bucket_len] + vision_output = compiled_model( + hidden_states.to(torch.bfloat16), + attention_mask.to(torch.bfloat16), + cos.to(torch.bfloat16), + sin.to(torch.bfloat16), + ) + elif self._compiled_model is not None: + # Single compiled model (legacy) + vision_output = self._compiled_model( + hidden_states.to(torch.bfloat16), + attention_mask.to(torch.bfloat16), + cos.to(torch.bfloat16), + sin.to(torch.bfloat16), + ) + elif self._cpu_model is not None: + # CPU-only mode: run vision encoder on CPU (no bucketing/padding needed + # but we pad anyway for consistency with the same merger math) + with torch.no_grad(): + vision_output = self._cpu_model( + hidden_states.to(torch.bfloat16), + attention_mask.to(torch.bfloat16), + cos.to(torch.bfloat16), + sin.to(torch.bfloat16), + ) + else: + # NxDI traced model: takes (hidden_states, attention_mask, position_embeddings) + vision_output = self.model(hidden_states, attention_mask, (cos, sin)) + + # 7. Unpad: only keep valid merged tokens + merge_area = self.vision_config.spatial_merge_size**2 + total_merged_tokens = sum( + t + * (h // self.vision_config.spatial_merge_size) + * (w // self.vision_config.spatial_merge_size) + for t, h, w in image_grid_thw.tolist() + ) + vision_output = vision_output[:total_merged_tokens] + + return vision_output + + +class NeuronQwen35VisionForImageEncoding(NeuronApplicationBase): + """Standalone application class for vision encoding (for testing).""" + + model_cls = NeuronQwen35VisionModel + model_wrapper_cls = NeuronQwen35VisionModelWrapper + + @staticmethod + def prepare_input_args(image_path, processor): + """Prepare vision inputs from an image path. + + Args: + image_path: Path to image file + processor: HF AutoProcessor + + Returns: + pixel_values, image_grid_thw + """ + from PIL import Image + + image = Image.open(image_path).convert("RGB") + inputs = processor(images=image, return_tensors="pt") + return inputs["pixel_values"], inputs["image_grid_thw"] diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py new file mode 100644 index 00000000..e3afbb1b --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py @@ -0,0 +1,662 @@ +""" +Qwen3.5-27B / Qwen3.6-27B Vision-Language Model Orchestrator for NeuronX Distributed Inference. + +This is the top-level VL model that wires together: +- The vision encoder (modeling_qwen35_vision.py) +- The text decoder (modeling_qwen35.py, dense model with vision injection) + +It handles: +- Multimodal RoPE (mRoPE) with interleaved layout +- Vision embedding injection via scatter_by_index_put +- Separate compilation and loading of vision and text models +- The CTE+TKG generation loop with vision inputs + +Architecture follows the NxDI NeuronBaseForImageToText pattern established +by Qwen3-VL in SDK 2.28, adapted for Qwen3.5/3.6 dense model's unique features: +- No deepstack (Qwen3.5/3.6 does not use intermediate vision feature injection) +- DeltaNet linear attention layers in the text decoder +- Dense SwiGLU MLP layers in the text decoder +- Interleaved mRoPE (THWTHW... layout) instead of Qwen3-VL's section-based layout +""" + +import logging +import os +from typing import Optional + +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + +# NxDI imports +try: + from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, + ) + from neuronx_distributed_inference.models.config import NeuronConfig + + HAS_NXDI_VL = True +except ImportError: + HAS_NXDI_VL = False + logger.warning("NxDI VL base classes not available -- VL model requires SDK 2.28+") + +# Local imports +try: + from src.modeling_qwen35 import ( + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35InferenceConfig, + Qwen35ModelWrapper, + ) + from src.modeling_qwen35_vision import ( + NeuronQwen35VisionModel, + NeuronQwen35VisionModelWrapper, + ) +except ImportError: + from modeling_qwen35 import ( + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35InferenceConfig, + Qwen35ModelWrapper, + ) + from modeling_qwen35_vision import ( + NeuronQwen35VisionModel, + NeuronQwen35VisionModelWrapper, + ) + + +def get_rope_index( + input_ids, + image_grid_thw=None, + video_grid_thw=None, + attention_mask=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + spatial_merge_size=2, +): + """Compute 3D multimodal RoPE position IDs for Qwen3.5. + + Returns position_ids of shape (3, batch_size, seq_len) where: + - Axis 0: temporal position + - Axis 1: height position + - Axis 2: width position + + For text tokens, all 3 axes have the same sequential position. + For vision tokens, each axis encodes the spatial/temporal grid position. + + Also returns rope_deltas for use during TKG decoding. + + Adapted from HuggingFace Qwen3_5Model.get_rope_index(). + """ + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave( + video_grid_thw, video_grid_thw[:, 0], dim=0 + ) + video_grid_thw[:, 0] = 1 + + image_grid_thw_list = ( + image_grid_thw.tolist() if image_grid_thw is not None else None + ) + video_grid_thw_list = ( + video_grid_thw.tolist() if video_grid_thw is not None else None + ) + + mrope_position_deltas = [] + total_input_ids = input_ids + + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, ids in enumerate(total_input_ids): + ids = ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + + vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1) + if len(vision_start_indices) > 0: + vision_tokens = ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = ids.tolist() + llm_pos_ids_list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = image_grid_thw_list[image_index] + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = video_grid_thw_list[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t = t + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + + text_len = ed - st + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + + +class Qwen35VLInferenceConfig: + """Configuration for the full VL model (text + vision). + + Wraps the existing Qwen35InferenceConfig for text and adds + vision-specific settings. + """ + + def __init__( + self, + text_config, + vision_config, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + spatial_merge_size=2, + vision_seq_len_buckets=None, + **kwargs, + ): + """ + Args: + text_config: Qwen35InferenceConfig instance for the text decoder + vision_config: dict with vision encoder hyperparams (depth, hidden_size, etc.) + image_token_id: Token ID for image placeholder tokens + video_token_id: Token ID for video placeholder tokens + vision_start_token_id: Token ID for <|vision_start|> + vision_end_token_id: Token ID for <|vision_end|> + spatial_merge_size: How many patches are merged (2 = 2x2 = 4 patches merged) + vision_seq_len_buckets: List of vision sequence length buckets for compilation + """ + self.text_config = text_config + self.vision_config = vision_config + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.spatial_merge_size = spatial_merge_size + self.vision_seq_len_buckets = vision_seq_len_buckets or [1024, 4096, 16384] + + +class NeuronQwen35VLForCausalLM: + """Top-level VL model for Qwen3.5/3.6-27B on Neuron. + + This class manages: + - Separate compilation/loading of vision encoder and text decoder + - CPU-side mRoPE computation + - Vision embedding injection into text decoder + - The CTE+TKG generation loop + + Note: This is NOT an NeuronBaseForImageToText subclass because the + text decoder (NeuronQwen35ForCausalLM) has extensive custom overrides + (DeltaNet state management, custom forward, custom ModelWrapper) that + don't fit the base class pattern. Instead, this class composes the two + models and handles the VL orchestration directly. + """ + + def __init__(self, model_path, text_config, vision_config=None, processor=None): + """ + Args: + model_path: Path to HF model directory + text_config: Qwen35InferenceConfig for text decoder + vision_config: Qwen35VLInferenceConfig (or None for text-only) + processor: HF AutoProcessor for image preprocessing + """ + self.model_path = model_path + self.text_config = text_config + self.vl_config = vision_config + self.processor = processor + + # Text decoder (existing implementation) + self.text_model = NeuronQwen35ForCausalLM( + model_path=model_path, config=text_config + ) + + # Vision encoder (lazy init -- only built if vl_config provided) + self.vision_model_wrapper = None + if vision_config is not None: + self._init_vision_model(vision_config) + + # mRoPE state + self.rope_deltas = None + + def _init_vision_model(self, vl_config): + """Initialize the vision encoder wrapper.""" + from types import SimpleNamespace + + vision_cfg = SimpleNamespace(**vl_config.vision_config) + self.vision_model_wrapper = NeuronQwen35VisionModelWrapper( + config=vision_cfg, + model_cls=None, # Standalone mode (no NxDI parallel layers) + vision_seq_len_buckets=vl_config.vision_seq_len_buckets, + ) + self._vl_config = vl_config + + def compile(self, compiled_model_path): + """Compile both text and vision models. + + For the vision encoder, use compile_vision_encoder.py separately + (standalone torch_neuronx.trace compilation). Then use load() to + load the pre-compiled vision encoder. + """ + # Compile text decoder + text_path = os.path.join(compiled_model_path, "text_model") + os.makedirs(text_path, exist_ok=True) + self.text_model.compile(text_path) + + # Vision encoder is compiled separately via compile_vision_encoder.py + if self.vision_model_wrapper is not None: + logger.info( + "Vision encoder must be compiled separately using " + "compile_vision_encoder.py. Use load() to load the " + "pre-compiled vision encoder." + ) + + def load(self, compiled_model_path, vision_compiled_path=None): + """Load both compiled models. + + Args: + compiled_model_path: Path to compiled text model (or parent dir) + vision_compiled_path: Path to compiled vision encoder .pt file. + If None, looks for 'vision_encoder.pt' in compiled_model_path. + """ + text_path = os.path.join(compiled_model_path, "text_model") + if os.path.exists(text_path): + self.text_model.load(text_path) + else: + # Backward compatibility: text model compiled at root + self.text_model.load(compiled_model_path) + + # Load vision encoder + if self.vision_model_wrapper is not None: + if vision_compiled_path is None: + vision_compiled_path = os.path.join( + compiled_model_path, "vision_encoder.pt" + ) + if os.path.exists(vision_compiled_path): + self.vision_model_wrapper.load_compiled(vision_compiled_path) + # Also load CPU-side weights (patch_embed, pos_embed) + self.vision_model_wrapper.load_vision_weights_from_hf(self.model_path) + logger.info("Vision encoder loaded from pre-compiled model") + else: + logger.warning( + f"No compiled vision encoder found at {vision_compiled_path}. " + "Vision encoding will not be available." + ) + + # Qwen3.5 stop token IDs (loaded from config/tokenizer) + _DEFAULT_EOS_TOKEN_IDS = { + 248044, # <|endoftext|> -- text config eos_token_id + 248046, # <|im_end|> -- tokenizer eos_token / end of assistant turn + } + + def generate( + self, + input_ids, + attention_mask=None, + pixel_values=None, + image_grid_thw=None, + video_grid_thw=None, + max_new_tokens=32, + temperature=0.0, + top_p=1.0, + top_k=0, + eos_token_ids=None, + **kwargs, + ): + """Generate text from text and/or vision inputs. + + Args: + input_ids: (batch_size, seq_len) token IDs + attention_mask: (batch_size, seq_len) attention mask + pixel_values: Vision pixel values from HF processor (or None for text-only) + image_grid_thw: (num_images, 3) grid dimensions + video_grid_thw: (num_videos, 3) grid dimensions + max_new_tokens: Maximum new tokens to generate + temperature: Sampling temperature (0.0 = greedy/argmax) + top_p: Nucleus sampling threshold (1.0 = disabled) + top_k: Top-k sampling (0 = disabled) + eos_token_ids: Set of token IDs to stop generation on + (default: {248044, 248046}) + + Returns: + generated_ids: (batch_size, seq_len + max_new_tokens) token IDs + """ + if eos_token_ids is None: + eos_token_ids = self._DEFAULT_EOS_TOKEN_IDS + + # Reset text model state for a fresh generation. + # This ensures CTE runs (not TKG) even if a prior generate() was called. + # DeltaNet recurrent states don't need explicit zeroing because the CTE + # NKI kernel always starts from zero state. + self.text_model.reset() + + has_vision = pixel_values is not None and pixel_values.numel() > 0 + + # Step 1: Compute 3D mRoPE position IDs + if has_vision and self._vl_config is not None: + position_ids, self.rope_deltas = get_rope_index( + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + image_token_id=self._vl_config.image_token_id, + video_token_id=self._vl_config.video_token_id, + vision_start_token_id=self._vl_config.vision_start_token_id, + spatial_merge_size=self._vl_config.spatial_merge_size, + ) + else: + # Text-only: use standard sequential position IDs + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0) + self.rope_deltas = None + + # Step 2: Run vision encoder and prepare injection args + llava_args = [] + batch_size = input_ids.shape[0] + if has_vision and self.vision_model_wrapper is not None: + # The vision encoder processes both image and video frames identically + # (they share the same ViT architecture). The HF processor outputs a + # single pixel_values tensor for images, and video frames are treated + # as multiple images with temporal grid > 1. + vision_embeddings = self.vision_model_wrapper(pixel_values, image_grid_thw) + # vision_embeddings: (total_merged_tokens, out_hidden_size) + + # Build vision_mask: boolean mask of ALL vision token positions + # (both image_token_id and video_token_id placeholders) + image_token_id = self._vl_config.image_token_id + video_token_id = self._vl_config.video_token_id + vision_bool_mask = (input_ids == image_token_id) | ( + input_ids == video_token_id + ) # (BS, seq_len) + + # For batch_size=1 (primary path): extract positions from batch element 0. + # For batch_size>1: each element may have different image token positions; + # we'd need per-element scatter. Currently only batch_size=1 is supported + # for VL (the compiled model uses batch_size=1 for CTE). + if batch_size > 1: + logger.warning( + "VL generation with batch_size > 1 is not fully supported. " + "Using batch element 0 for vision scatter positions." + ) + + positions = ( + vision_bool_mask[0].nonzero(as_tuple=False).squeeze(-1) + ) # (n_vision_tokens,) + + # Reshape vision_embeddings to (1, n_vision_tokens, hidden_size) + n_vis = positions.shape[0] + hidden_size = vision_embeddings.shape[-1] + vis_emb = vision_embeddings[:n_vis].unsqueeze(0) # (1, n_vis, hidden) + + # Pad to match input sequence length for compiled graph compatibility + seq_len = input_ids.shape[1] + pad_limit = seq_len # Must match the bucket size + + # Pad vision_embeddings to (1, pad_limit, hidden_size) + if n_vis < pad_limit: + pad_emb = torch.zeros( + (1, pad_limit - n_vis, hidden_size), + dtype=vis_emb.dtype, + ) + vis_emb_padded = torch.cat([vis_emb, pad_emb], dim=1) + else: + vis_emb_padded = vis_emb[:, :pad_limit] + + # Pad positions to (1, pad_limit, 1) with a SAFE fill value. + # CRITICAL: fill_value must be a valid index (within [0, pad_limit-1]). + # Using pad_limit-1 targets the last position (always a padding slot) + # so index_put_ scatters zero embeddings there harmlessly. + # NOTE: Do NOT use large sentinel values (e.g., 2**30) as they cause + # DGE out-of-bounds crashes in the Neuron runtime. + positions_padded = torch.full( + (1, pad_limit, 1), + fill_value=pad_limit - 1, + dtype=torch.int32, + ) + positions_padded[0, :n_vis, 0] = positions[:pad_limit].to(torch.int32) + + llava_args = [vis_emb_padded, positions_padded] + + # Append 3D mRoPE position IDs for the text model. + # position_ids shape: (3, batch_size, seq_len) from get_rope_index. + # _get_model_outputs receives this at slot 21 and pre-computes + # mRoPE cos/sin in get_model_output() for all decoder layers. + if position_ids.ndim == 3: + mrope_pos = position_ids[:, :, :seq_len].to(torch.int32).contiguous() + llava_args.append(mrope_pos) + else: + vision_embeddings = None + + # Step 3: Context encoding (prefill) + generated_ids = input_ids.clone() + + # CRITICAL: Always pass an explicit attention_mask for CTE. + # The base class _infer_attention_mask() assumes sequential position_ids + # (position_ids[i] >= i). When position_ids come from mRoPE temporal + # axis (non-sequential, e.g., all vision tokens share position 4), + # the inferred mask incorrectly masks out most of the sequence. + # Fix: provide a real all-ones mask for the actual token positions. + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # For slot 2 (position_ids): use SEQUENTIAL positions regardless of mRoPE. + # Slot 2 is only used for: (1) logit position selection via torch.max(), + # (2) attention mask inference (which we bypass with explicit mask above). + # The actual RoPE computation uses slot 21 (rotary_position_ids) from + # _get_model_outputs, NOT slot 2. Using sequential slot 2 ensures + # correct logit selection and avoids any position_ids-related issues. + seq_len = input_ids.shape[1] + cte_position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) + + with torch.no_grad(): + output = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=cte_position_ids, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + llava_args=llava_args, + ) + + logits = output[0] if isinstance(output, tuple) else output.logits + next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k) + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1) + + # Check EOS after first token + if next_token.item() in eos_token_ids: + return generated_ids + + # Step 4: Token generation (TKG) loop + for _ in range(max_new_tokens - 1): + pos_ids = torch.tensor([[generated_ids.shape[1] - 1]]) + if self.rope_deltas is not None: + pos_ids = pos_ids + self.rope_deltas + + last_token = generated_ids[:, -1:] + with torch.no_grad(): + output = self.text_model( + input_ids=last_token, + position_ids=pos_ids, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ) + logits = output[0] if isinstance(output, tuple) else output.logits + next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k) + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1) + + # Stop on EOS + if next_token.item() in eos_token_ids: + break + + return generated_ids + + @staticmethod + def _sample_token(logits, temperature=0.0, top_p=1.0, top_k=0): + """Sample a token from logits with optional temperature/top-p/top-k. + + Args: + logits: (batch_size, vocab_size) unnormalized logits + temperature: Sampling temperature. 0.0 = greedy (argmax). + top_p: Nucleus sampling threshold. 1.0 = disabled. + top_k: Top-k filtering. 0 = disabled. + + Returns: + token_id: (batch_size,) sampled token IDs + """ + if temperature <= 0.0: + return torch.argmax(logits, dim=-1) + + # Apply temperature + logits = logits / temperature + + # Top-k filtering + if top_k > 0: + top_k = min(top_k, logits.shape[-1]) + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = float("-inf") + + # Top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + torch.softmax(sorted_logits, dim=-1), dim=-1 + ) + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift right so the first token above threshold is kept + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = False + # Scatter back to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + -1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = float("-inf") + + # Sample from the filtered distribution + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) + + @staticmethod + def prepare_input_args(text_prompt, image_path, processor, role="user"): + """Prepare inputs for vision+text generation. + + Args: + text_prompt: Text prompt string + image_path: Path to image file (or None for text-only) + processor: HF AutoProcessor + role: Message role (default "user") + + Returns: + input_ids, attention_mask, vision_inputs dict + """ + content = [] + if image_path is not None: + import base64 + from pathlib import Path + + image_data = Path(image_path).read_bytes() + b64 = base64.b64encode(image_data).decode("utf-8") + content.append( + { + "type": "image", + "url": f"data:image/jpeg;base64,{b64}", + } + ) + content.append({"type": "text", "text": text_prompt}) + + messages = [{"role": role, "content": content}] + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ) + + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)) + + vision_inputs = {} + if "pixel_values" in inputs: + vision_inputs["pixel_values"] = inputs["pixel_values"] + if "image_grid_thw" in inputs: + vision_inputs["image_grid_thw"] = inputs["image_grid_thw"] + if "video_grid_thw" in inputs: + vision_inputs["video_grid_thw"] = inputs["video_grid_thw"] + + return input_ids, attention_mask, vision_inputs diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py new file mode 100644 index 00000000..7e78cdb9 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Custom NKI kernels for Qwen3.5-27B / Qwen3.6-27B DeltaNet layers. + +Contains three kernel implementations: +- nki_deltanet: Per-token recurrent kernel (used for token generation) +- nki_deltanet_chunked: Per-chunk kernel (legacy, superseded by fused) +- nki_deltanet_fused: Fused single-kernel chunked forward (used for context encoding) +""" diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py new file mode 100644 index 00000000..b2f653c2 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py @@ -0,0 +1,607 @@ +"""NKI kernels for DeltaNet gated delta rule recurrent forward. + +NKI v3 (SDK 2.29, NKI 0.3.0). Processes a SINGLE (batch, head) pair per kernel call. +The caller loops over (B, H) in PyTorch and calls this kernel for each pair. + +Input layout: All inputs are 2D contiguous tensors (S, 128). +Each call processes one (batch, head) element's full sequence. + +k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly. +g and beta are scalars per token, expanded to (S, 128) by the caller. + +Two kernel variants: + deltanet_recurrent_fwd -- returns output only (original) + deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over + deltanet_recurrent_step -- one-token state-in/state-out decode step + deltanet_recurrent_step_batched -- one-token batched-head decode step +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +# Partition dimension max (NeuronCore SBUF tile width) +P_MAX = 128 + +# Shuffle mask: broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def _deltanet_recurrent_step_batched_kernel( + query: nl.ndarray, # (BH, 128) float32 + key: nl.ndarray, # (BH, 128) float32 + value: nl.ndarray, # (BH, 128) float32 + g_in: nl.ndarray, # (BH, 1) float32, log-decay scalar per head + beta_in: nl.ndarray, # (BH, 1) float32, write-gate scalar per head + state_in: nl.ndarray, # (BH * 128, 128) float32/bfloat16 +): + """Single-launch batched-head one-token DeltaNet decode step. + + The installed NKI framework on the compile hosts uses ``kernel[...]`` for + LNC selection, not custom-op SPMD grids. Keep one framework custom call by + looping over flattened ``(batch, value_head)`` rows inside the kernel. + """ + batch_heads, dim = query.shape + + output = nl.ndarray(query.shape, dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray(state_in.shape, dtype=state_in.dtype, buffer=nl.shared_hbm) + + for bh in nl.sequential_range(batch_heads): + head_offset = bh * dim + state_offset = bh * P_MAX + + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=head_offset), + dge_mode=nisa.dge_mode.hwdge, + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=head_offset), + dge_mode=nisa.dge_mode.hwdge, + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=head_offset), + dge_mode=nisa.dge_mode.hwdge, + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + g_scalar = nl.ndarray((1, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=g_scalar, src=g_in.ap(pattern=[[1, 1]], offset=bh)) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=g_scalar[0:1, 0:1], + dst=g_t[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + beta_scalar = nl.ndarray((1, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_scalar, src=beta_in.ap(pattern=[[1, 1]], offset=bh)) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=beta_scalar[0:1, 0:1], + dst=beta_t[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=state, + src=state_in[state_offset : state_offset + P_MAX, 0:dim], + ) + + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state_decayed, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.scalar_tensor_tensor( + dst=state_new, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + op1=nl.add, + operand1=state_decayed, + ) + + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state_new, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=head_offset), + src=o_t, + dge_mode=nisa.dge_mode.hwdge, + ) + nisa.dma_copy( + dst=state_out[state_offset : state_offset + P_MAX, 0:dim], + src=state_new, + ) + + return output, state_out + + +def deltanet_recurrent_step_batched( + query, + key, + value, + g_in, + beta_in, + state_in, +): + """Launch the one-token DeltaNet decode step across flattened BH heads.""" + return _deltanet_recurrent_step_batched_kernel( + query, + key, + value, + g_in, + beta_in, + state_in, + ) + + +@nki.jit +def deltanet_recurrent_step( + query: nl.ndarray, # (1, 128) float32 + key: nl.ndarray, # (1, 128) float32 + value: nl.ndarray, # (1, 128) float32 + g_in: nl.ndarray, # (1, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (1, 128) float32, write gate broadcast to 128 + state_in: nl.ndarray, # (128, 128) float32 +): + """Stateful one-token DeltaNet decode step. + + This is the token-generation equivalent of one iteration from + ``deltanet_recurrent_fwd_state``. The caller supplies the recurrent + state restored from the decode cache and receives the updated state. + + Returns: + output: (1, 128) float32 + state_out: (128, 128) float32 + """ + _, dim = query.shape + + output = nl.ndarray((1, dim), dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_t, src=query.ap(pattern=[[1, P_MAX]], offset=0)) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_t, src=key.ap(pattern=[[1, P_MAX]], offset=0)) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_t, src=value.ap(pattern=[[1, P_MAX]], offset=0)) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=g_t, src=g_in.ap(pattern=[[1, P_MAX]], offset=0)) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_t, src=beta_in.ap(pattern=[[1, P_MAX]], offset=0)) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=state_in) + + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + nisa.dma_copy(dst=output.ap(pattern=[[1, dim]], offset=0), src=o_t) + nisa.dma_copy(dst=state_out, src=state) + + return output, state_out + + +@nki.jit +def deltanet_recurrent_fwd( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +) -> nl.ndarray: + """NKI kernel for DeltaNet recurrent forward -- single (batch, head). + + Iterates over sequence tokens with sequential_range. + State matrix (128 x 128) lives in SBUF. + + Args: + query: (S, 128) float32 + key: (S, 128) float32 + value: (S, 128) float32 + g_in: (S, 128) float32 + beta_in: (S, 128) float32 + + Returns: + output: (S, 128) float32 + """ + seq_len, dim = query.shape + + # Output tensor in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + # 1) Transpose delta (128,1) -> (1,128) in PSUM + # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast + # 3) Multiply by k_t (128,1) which broadcasts across free dim + # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes). + + # Transpose delta to get values along free dimension + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + # Each partition row gets the same delta values + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0] + # tensor_scalar broadcasts (P,1) k_t across all F columns + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + # Accumulate into state + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + return output + + +@nki.jit +def deltanet_recurrent_fwd_state( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +): + """NKI kernel for DeltaNet recurrent forward with final state output. + + Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final + recurrent state (128, 128) to an output HBM buffer. This enables + CTE -> TKG state carry-over. + + Returns: + output: (S, 128) float32 -- per-token output + final_state: (128, 128) float32 -- recurrent state after last token + """ + seq_len, dim = query.shape + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + # ---- Write final state to HBM ---- + # state is (128, 128) in SBUF, copy to final_state in HBM + # Use dma_copy with full tile: P_MAX rows, dim cols + nisa.dma_copy(dst=final_state, src=state) + + return output, final_state diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py new file mode 100644 index 00000000..281e8e14 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py @@ -0,0 +1,431 @@ +"""NKI per-chunk DeltaNet kernel for CTE (context encoding / prefill). + +Single-chunk kernel: processes one chunk (128 tokens) with masked Neumann +power-doubling for intra-chunk correction. The caller loops over chunks in +PyTorch, passing state between calls. + +Each kernel call: + - Takes one chunk of data: q, k, v, beta, g_cumsum, g_last (all 128x128) + - Takes recurrent state_in (128x128) + - Returns chunk output (128x128) and state_out (128x128) + +No sequence-indexed DMA inside the kernel -- all inputs/outputs are full tiles. +This avoids the DMA OOB issue seen with nl.sequential_range + slice indexing +in the NxDI model compilation context. + +NKI v3 (SDK 2.29, NKI 0.3.0). Uses nki.* namespace. +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group. +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def deltanet_chunk_step( + query, # (128, 128) float32 -- one chunk, l2-normed+scaled + key, # (128, 128) float32 -- one chunk, l2-normed + value, # (128, 128) float32 -- one chunk + beta_broadcast, # (128, 128) float32 -- write gate broadcast to 128 + g_cumsum, # (128, 128) float32 -- cumsum of g within chunk, broadcast + g_last, # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast + state_in, # (128, 128) float32 -- recurrent state from previous chunk + lower_mask, # (128, 128) float32 -- strict lower triangular + identity, # (128, 128) float32 -- identity matrix + lower_mask_diag, # (128, 128) float32 -- lower tri with diagonal +): + """Process one chunk of DeltaNet. + + Returns: + output: (128, 128) float32 -- chunk output + state_out: (128, 128) float32 -- updated recurrent state + """ + C, dim = query.shape # C = 128, dim = 128 + + # Output tensors in HBM + output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Load all inputs into SBUF + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_c, src=query) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_c, src=key) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_c, src=value) + + beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_c, src=beta_broadcast) + + gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gc_c, src=g_cumsum) + + gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gl_c, src=g_last) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=state_in) + + # Load masks + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply) + + # ============================================================ + # Stable decay factors from cumulative log-decay + # + # The caller passes g_cumsum and g_last broadcast to (128, 128). Extract + # one column and build pairwise decays as exp(gc[i] - gc[j]) so no + # individual exp(-gc[j]) term can overflow. + # ============================================================ + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gc_p[0:P_MAX, 0:1], src=gc_c[0:P_MAX, 0:1]) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gl_p[0:P_MAX, 0:1], src=gl_c[0:P_MAX, 0:1]) + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy(dst=gc_padded[0:P_MAX, 0:1], src=gc_p[0:P_MAX, 0:1]) + + gc_row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_row_psum, data=gc_padded) + + gc_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gc_row[0:1, 0:P_MAX], src=gc_row_psum[0:1, 0:P_MAX]) + + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_d, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # QK = k_beta @ k^T -- contract over features + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j. + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Masked Neumann power-doubling: + # N = (I + A)(I + A^2)(I + A^4)...(I + A^64) + # + # A_mat is strictly lower triangular, so A^128 = 0. Re-mask after every + # square/multiply so numerical residue cannot leak above the diagonal. + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A_mat) + + for _round in nl.sequential_range(6): + Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=Ap_T_psum, data=A_pow) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + nisa.tensor_tensor(dst=A_pow, data1=A_pow, data2=Lmask, op=nl.multiply) + + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=IpA_T_psum, data=IpA) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + nisa.tensor_tensor(dst=P_acc, data1=P_acc, data2=Lmask_d, op=nl.multiply) + + # ============================================================ + # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_T_psum, data=P_acc) + N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply) + + # ============================================================ + # v_prime = k_cumdecay @ state + # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(g_cumsum)) @ state + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + nisa.dma_copy(dst=output, src=chunk_out) + + # ============================================================ + # State update: state_new = state * exp(g_last) + # + (k * exp(g_last - gc))^T @ v_new + # ============================================================ + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p, + op=nl.exp, + data=gl_minus_gc_p, + bias=None, + scale=1.0, + ) + + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state_decayed, data2=kv_outer, op=nl.add) + + nisa.dma_copy(dst=state_out, src=state_new) + + return output, state_out diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py new file mode 100644 index 00000000..ed2cf80f --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py @@ -0,0 +1,2991 @@ +"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding). + +SSD-style architecture: processes ALL chunks for one (batch, head) pair in +a single NKI kernel call. State (128x128) persists in SBUF across chunks — +no HBM round-trips for inter-chunk state propagation. + +Key optimizations over nki_deltanet_chunked.py: + 1. Single kernel call per (B,H) instead of B*H*num_chunks calls + 2. State in SBUF across all chunks (no HBM state read/write per chunk) + 3. In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum) + 4. Masks and constants loaded once, reused across chunks + 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) + 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of + nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + +NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. +Chunk size = 128 = P_MAX (one tile per chunk). + +Mathematical framework: + Per-chunk direct triangular solve for intra-chunk correction: + QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j + A = -QK_decay * lower_mask + v_new = solve((I - A), v_beta - (k_beta * exp(gc)) @ state) + + Inter-chunk state propagation: + attn_inter = (q * exp(gc)) @ state + attn_intra = (q @ k^T) * (strict_decay + I) + output = attn_inter + attn_intra @ v_new + state = exp(g_last) * (state + k_raw_decay^T @ v_new) +""" + +import os + +import numpy as np + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 # Partition dim = k_dim = v_dim +CHUNK_SIZE = int(os.environ.get("QWEN36_DELTANET_CHUNK_SIZE", "128")) +L2_EPS_SQUARED = 1.0e-12 +QUERY_SCALE = P_MAX ** -0.5 +SOLVE_BLOCK_SIZE = int(os.environ.get("QWEN36_DELTANET_SOLVE_BLOCK_SIZE", "32")) +if ( + CHUNK_SIZE <= 0 + or P_MAX % CHUNK_SIZE != 0 + or CHUNK_SIZE % 32 != 0 + or SOLVE_BLOCK_SIZE <= 0 + or CHUNK_SIZE % SOLVE_BLOCK_SIZE != 0 + or SOLVE_BLOCK_SIZE % 32 != 0 + or SOLVE_BLOCK_SIZE & (SOLVE_BLOCK_SIZE - 1) != 0 +): + raise ValueError( + "QWEN36_DELTANET_CHUNK_SIZE must be a positive divisor of P_MAX " + "and a multiple of the 32-partition broadcast group, while " + "QWEN36_DELTANET_SOLVE_BLOCK_SIZE must be positive, divide " + "CHUNK_SIZE, be a power of two, and be a multiple of 32; " + f"P_MAX={P_MAX}, CHUNK_SIZE={CHUNK_SIZE}, got {SOLVE_BLOCK_SIZE}" + ) +MAX_SOLVE_SCAN_STEPS = SOLVE_BLOCK_SIZE.bit_length() - 1 +SOLVE_SCAN_STEPS = int( + os.environ.get("QWEN36_DELTANET_SOLVE_SCAN_STEPS", str(MAX_SOLVE_SCAN_STEPS)) +) +if SOLVE_SCAN_STEPS <= 0 or SOLVE_SCAN_STEPS > MAX_SOLVE_SCAN_STEPS: + raise ValueError( + "QWEN36_DELTANET_SOLVE_SCAN_STEPS must be in " + f"[1, {MAX_SOLVE_SCAN_STEPS}] for SOLVE_BLOCK_SIZE={SOLVE_BLOCK_SIZE}; " + f"got {SOLVE_SCAN_STEPS}" + ) +SOLVE_ACTIVE_PREFIX_K = os.environ.get( + "QWEN36_DELTANET_SOLVE_ACTIVE_PREFIX_K", + "0", +).lower() not in ("0", "false", "no", "off") +SOLVE_MODE = os.environ.get("QWEN36_DELTANET_SOLVE_MODE", "doubling").lower() +AUTOCP_CP_CHUNKS = int(os.environ.get("QWEN36_DELTANET_AUTOCP_CP_CHUNKS", "4")) +if SOLVE_MODE not in ("doubling", "kkt_hier"): + raise ValueError( + "QWEN36_DELTANET_SOLVE_MODE must be one of " + "('doubling', 'kkt_hier'); " + f"got {SOLVE_MODE!r}" + ) +SOLVE_KKT_HIER = SOLVE_MODE == "kkt_hier" +if SOLVE_KKT_HIER and (SOLVE_BLOCK_SIZE != P_MAX or CHUNK_SIZE != P_MAX): + raise ValueError( + "QWEN36_DELTANET_SOLVE_MODE=kkt_hier currently expects " + f"QWEN36_DELTANET_CHUNK_SIZE={P_MAX} and " + f"QWEN36_DELTANET_SOLVE_BLOCK_SIZE={P_MAX}; " + f"got CHUNK_SIZE={CHUNK_SIZE}, SOLVE_BLOCK_SIZE={SOLVE_BLOCK_SIZE}" + ) +if AUTOCP_CP_CHUNKS <= 0: + raise ValueError( + "QWEN36_DELTANET_AUTOCP_CP_CHUNKS must be positive; " + f"got {AUTOCP_CP_CHUNKS}" + ) + +# Broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +def _make_lower_mask(): + """Strict lower triangular active chunk block in a 128x128 constant.""" + mask = np.zeros((P_MAX, P_MAX), dtype=np.float32) + mask[:CHUNK_SIZE, :CHUNK_SIZE] = np.tril( + np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1 + ) + return mask + + +def _make_lower_mask_diag(): + """Lower triangular active chunk block with diagonal in a 128x128 constant.""" + mask = np.zeros((P_MAX, P_MAX), dtype=np.float32) + mask[:CHUNK_SIZE, :CHUNK_SIZE] = np.tril( + np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0 + ) + return mask + + +def _make_identity(): + """Identity active chunk block in a 128x128 constant.""" + identity = np.zeros((P_MAX, P_MAX), dtype=np.float32) + identity[:CHUNK_SIZE, :CHUNK_SIZE] = np.eye(CHUNK_SIZE, dtype=np.float32) + return identity + + +def _matmul_square(dst, left, right, size): + left_trans_psum = nl.ndarray((size, size), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=left_trans_psum, data=left) + left_trans = nl.ndarray((size, size), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=left_trans, src=left_trans_psum) + + out_psum = nl.ndarray((size, size), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=out_psum, stationary=left_trans, moving=right) + nisa.tensor_copy(dst=dst, src=out_psum) + + +def _offdiag_combine_t(dst, left_t, cross_t, right_t, size): + tmp = nl.ndarray((size, size), dtype=nl.float32, buffer=nl.sbuf) + _matmul_square(tmp, left_t, cross_t, size) + _matmul_square(dst, tmp, right_t, size) + + +def _leaf_inverse32_t(dst, A_T, Imat, start): + nisa.tensor_copy(dst=dst, src=Imat[0:32, 0:32]) + + power_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=power_t, src=A_T[start : start + 32, start : start + 32]) + + power_psum = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=power_psum, data=power_t) + power = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=power, src=power_psum) + + for _scan_i in nl.static_range(5): + correction = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + _matmul_square(correction, dst, power_t, 32) + + next_inv_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=next_inv_t, data1=dst, data2=correction, op=nl.add) + nisa.tensor_copy(dst=dst, src=next_inv_t) + + if _scan_i != 4: + power_next = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + _matmul_square(power_next, power, power, 32) + + power_next_t_psum = nl.ndarray( + (32, 32), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_transpose(dst=power_next_t_psum, data=power_next) + power_next_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=power_next_t, src=power_next_t_psum) + + nisa.tensor_copy(dst=power, src=power_next) + nisa.tensor_copy(dst=power_t, src=power_next_t) + + +def _inverse64_t(dst, A_T, Imat, start): + nisa.memset(dst=dst, value=0.0) + + for leaf_idx in nl.static_range(2): + leaf_offset = leaf_idx * 32 + leaf_start = start + leaf_offset + leaf_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + _leaf_inverse32_t(leaf_t, A_T, Imat, leaf_start) + nisa.tensor_copy( + dst=dst[leaf_offset : leaf_offset + 32, leaf_offset : leaf_offset + 32], + src=leaf_t, + ) + + left32_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + right32_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + cross32_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=left32_t, src=dst[0:32, 0:32]) + nisa.tensor_copy(dst=right32_t, src=dst[32:64, 32:64]) + nisa.tensor_copy(dst=cross32_t, src=A_T[start : start + 32, start + 32 : start + 64]) + + off32_t = nl.ndarray((32, 32), dtype=nl.float32, buffer=nl.sbuf) + _offdiag_combine_t(off32_t, left32_t, cross32_t, right32_t, 32) + nisa.tensor_copy(dst=dst[0:32, 32:64], src=off32_t) + + +def _hierarchical_kkt_solve128(v_new, A_T, Imat, solve_rhs, dim): + n_lo_t = nl.ndarray((64, 64), dtype=nl.float32, buffer=nl.sbuf) + n_hi_t = nl.ndarray((64, 64), dtype=nl.float32, buffer=nl.sbuf) + _inverse64_t(n_lo_t, A_T, Imat, 0) + _inverse64_t(n_hi_t, A_T, Imat, 64) + + n128_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=n128_t, value=0.0) + nisa.tensor_copy(dst=n128_t[0:64, 0:64], src=n_lo_t) + nisa.tensor_copy(dst=n128_t[64:128, 64:128], src=n_hi_t) + + cross64_t = nl.ndarray((64, 64), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=cross64_t, src=A_T[0:64, 64:128]) + + off64_t = nl.ndarray((64, 64), dtype=nl.float32, buffer=nl.sbuf) + _offdiag_combine_t(off64_t, n_lo_t, cross64_t, n_hi_t, 64) + nisa.tensor_copy(dst=n128_t[0:64, 64:128], src=off64_t) + + solved_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=solved_psum, stationary=n128_t, moving=solve_rhs) + nisa.tensor_copy(dst=v_new, src=solved_psum) + + +def _blocked_doubling_solve(v_new, A_T, solve_rhs, dim): + for solve_block in nl.static_range(CHUNK_SIZE // SOLVE_BLOCK_SIZE): + block_start = solve_block * SOLVE_BLOCK_SIZE + block_end = block_start + SOLVE_BLOCK_SIZE + + prev_contrib = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf + ) + if solve_block == 0: + nisa.memset(dst=prev_contrib, value=0.0) + else: + prev_psum = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.psum + ) + if SOLVE_ACTIVE_PREFIX_K: + nisa.nc_matmul( + dst=prev_psum, + stationary=A_T[0:block_start, block_start:block_end], + moving=v_new[0:block_start, 0:dim], + ) + else: + nisa.nc_matmul( + dst=prev_psum, + stationary=A_T[0:CHUNK_SIZE, block_start:block_end], + moving=v_new[0:CHUNK_SIZE, 0:dim], + ) + nisa.tensor_copy(dst=prev_contrib, src=prev_psum) + + solve_rhs_block = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy( + dst=solve_rhs_block, + src=solve_rhs[block_start:block_end, 0:dim], + ) + + residual_block = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=residual_block, + data1=solve_rhs_block, + data2=prev_contrib, + op=nl.add, + ) + + A_diag_T = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.sbuf, + ) + nisa.tensor_copy( + dst=A_diag_T, + src=A_T[block_start:block_end, block_start:block_end], + ) + + A_power_T = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.sbuf, + ) + nisa.tensor_copy(dst=A_power_T, src=A_diag_T) + + A_power_psum = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.psum, + ) + nisa.nc_transpose(dst=A_power_psum, data=A_power_T) + A_power = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.sbuf, + ) + nisa.tensor_copy(dst=A_power, src=A_power_psum) + + local_v = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy(dst=local_v, src=residual_block) + + for _scan_i in nl.static_range(SOLVE_SCAN_STEPS): + correction_psum = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul( + dst=correction_psum, + stationary=A_power_T, + moving=local_v, + ) + correction = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy(dst=correction, src=correction_psum) + + local_next = nl.ndarray( + (SOLVE_BLOCK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=local_next, data1=local_v, data2=correction, op=nl.add + ) + + nisa.tensor_copy(dst=local_v, src=local_next) + + if _scan_i == SOLVE_SCAN_STEPS - 2: + A_power_next_T_psum = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.psum, + ) + nisa.nc_matmul( + dst=A_power_next_T_psum, + stationary=A_power, + moving=A_power_T, + ) + nisa.tensor_copy(dst=A_power_T, src=A_power_next_T_psum) + elif _scan_i != SOLVE_SCAN_STEPS - 1: + A_power_next_psum = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.psum, + ) + nisa.nc_matmul( + dst=A_power_next_psum, + stationary=A_power_T, + moving=A_power, + ) + A_power_next = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.sbuf, + ) + nisa.tensor_copy(dst=A_power_next, src=A_power_next_psum) + + A_power_next_T_psum = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.psum, + ) + nisa.nc_transpose(dst=A_power_next_T_psum, data=A_power_next) + A_power_next_T = nl.ndarray( + (SOLVE_BLOCK_SIZE, SOLVE_BLOCK_SIZE), + dtype=nl.float32, + buffer=nl.sbuf, + ) + nisa.tensor_copy(dst=A_power_next_T, src=A_power_next_T_psum) + + nisa.tensor_copy(dst=A_power, src=A_power_next) + nisa.tensor_copy(dst=A_power_T, src=A_power_next_T) + + nisa.tensor_copy( + dst=v_new[block_start:block_end, 0:dim], + src=local_v[0:SOLVE_BLOCK_SIZE, 0:dim], + ) + + +@nki.jit +def deltanet_fused_chunked_fwd( + query: nl.ndarray, # (S, 128) float32 — raw Q; normalized in-kernel + key: nl.ndarray, # (S, 128) float32 — raw K; normalized in-kernel + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum) + beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate + initial_state: nl.ndarray, # (128, 128) float32 — recurrent checkpoint or zeros + lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri + identity: nl.ndarray, # (128, 128) float32 — identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag +): + """Fused chunked DeltaNet forward — single kernel call per (batch, head). + + Processes all chunks sequentially within the kernel, keeping the recurrent + state (128x128) in SBUF across chunks. Returns per-token output and + final state. + + Input requirements: + - S must be divisible by 128 (pad before calling) + - query/key are raw projected chunks; l2-norm and Q scale are in-kernel + - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan) + - beta_in is sigmoid(b) (write gate) + - initial_state is zero for cold prefill, or the restored GDN checkpoint + + Returns: + output: (S, 128) float32 + final_state: (128, 128) float32 + """ + seq_len = query.shape[0] + dim = query.shape[1] # 128 + num_chunks = seq_len // CHUNK_SIZE + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # ================================================================ + # Load constant masks into SBUF once (reused across all chunks) + # ================================================================ + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + UMask_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=UMask_psum, data=Lmask) + UMask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=UMask, src=UMask_psum) + + Imat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Imat, src=identity) + + # Ones vector for cumsum scan: (1, CHUNK_SIZE) + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + # Zero initial for cumsum scan + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + # ================================================================ + # Initialize recurrent state in SBUF — persists across ALL chunks + # ================================================================ + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=initial_state) + + # ================================================================ + # Sequential chunk processing + # ================================================================ + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + # ---- Load chunk data from HBM ---- + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + else: + nisa.memset(dst=q_c, value=0.0) + nisa.dma_copy( + dst=q_c[0:CHUNK_SIZE, 0:dim], + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + else: + nisa.memset(dst=k_c, value=0.0) + nisa.dma_copy( + dst=k_c[0:CHUNK_SIZE, 0:dim], + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + q_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_square, data1=q_c, data2=q_c, op=nl.multiply) + q_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=q_norm_sq, data=q_square, op=nl.add, axis=1) + q_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm_sq_clamped, + data=q_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + q_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_inv_norm, + data=q_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + q_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm, + data=q_c, + op0=nl.multiply, + operand0=q_inv_norm, + op1=nl.multiply, + operand1=QUERY_SCALE, + engine=nisa.vector_engine, + ) + + k_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_square, data1=k_c, data2=k_c, op=nl.multiply) + k_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=k_norm_sq, data=k_square, op=nl.add, axis=1) + k_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm_sq_clamped, + data=k_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + k_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_inv_norm, + data=k_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + k_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm, + data=k_c, + op0=nl.multiply, + operand0=k_inv_norm, + engine=nisa.vector_engine, + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + else: + nisa.memset(dst=v_c, value=0.0) + nisa.dma_copy( + dst=v_c[0:CHUNK_SIZE, 0:dim], + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + # g: (CHUNK_SIZE, 1) — raw log-decay per token + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=g_chunk_p, value=0.0) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # beta: (CHUNK_SIZE, 1) — write gate scalar per token + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=beta_p, value=0.0) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # ---- In-kernel cumsum of g via tensor_tensor_scan ---- + # Need g as (1, CHUNK_SIZE) for scan along free dim. Use direct + # vector transpose instead of padding through a full 128x128 tile. + g_tp_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_chunk_p) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t] + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + # Transpose gc back to (CHUNK_SIZE, 1) partition layout. + gc_tp_psum = nl.ndarray((CHUNK_SIZE, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_row) + + # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=gc_p, value=0.0) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + # g_last = gc[-1] (scalar) — needed for state decay + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + # ---- Compute exp(gc) and exp(g_last) as (P_MAX, 1) scalars ---- + # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast + # across the free dimension without explicit (P_MAX, dim) copies. + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # ============================================================ + # Stable pairwise decay factors from cumulative log-decay. + # + # The original fused path used split scaling: + # exp(gc[i]) * exp(-gc[j]) + # That can materialize huge unused intermediates. Build the same + # causal decay matrices as the per-chunk kernel using exp(gc[i]-gc[j]) + # and mask after the exp so upper-triangular values cannot leak into + # later matmuls. + # ============================================================ + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=gc_row_broadcast, value=0.0) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:CHUNK_SIZE], + dst=gc_row_broadcast[ + i_shuf * 32 : i_shuf * 32 + 32, 0:CHUNK_SIZE + ], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict_t = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=gc_col_strict_t, + data1=gc_row_broadcast, + data2=UMask, + op=nl.multiply, + ) + gc_row_strict_t = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=gc_row_strict_t, + data=UMask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + g_diff_strict_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict_t, + data1=gc_col_strict_t, + data2=gc_row_strict_t, + op=nl.subtract, + ) + decay_strict_t_raw = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.activation( + dst=decay_strict_t_raw, + op=nl.exp, + data=g_diff_strict_t, + bias=None, + scale=1.0, + ) + decay_strict_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict_t, + data1=decay_strict_t_raw, + data2=UMask, + op=nl.multiply, + ) + + decay_diag_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag_t, data1=decay_strict_t, data2=Imat, op=nl.add + ) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_norm, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # Transpose K and K_beta for matmul + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_norm) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + # QK_T[j, i] = k_norm[j] @ k_beta[i]. Build the transposed solve + # matrix directly, avoiding a full A -> A_T transpose per chunk. + QK_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_T_psum, stationary=k_T, moving=kb_T) + QK_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_T, src=QK_T_psum) + + # A_T[j, i] = -QK[i, j] * exp(gc[i] - gc[j]) for i > j. + QK_decay_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=QK_decay_t, data1=QK_T, data2=decay_strict_t, op=nl.multiply + ) + + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=A_T, + data=QK_decay_t, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + # ============================================================ + # Build the single RHS needed for v_new. + # + # Materializing N = inv(I - A) would compute: + # value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # v_new = value_corr - k_cumdecay @ state + # + # By associativity: + # v_new = N @ (v_beta - (k_beta * exp(gc)) @ state) + # + # Solve this RHS directly. This is equivalent to the nilpotent + # Neumann series, but avoids repeated matrix squaring, which is + # numerically unstable for realistic Qwen decay gates. + # ============================================================ + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kbe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kbe_T_psum, data=kb_exp_gc) + kbe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_T, src=kbe_T_psum) + + kbe_state_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kbe_state_psum, stationary=kbe_T, moving=state) + kbe_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_state, src=kbe_state_psum) + + solve_rhs = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=solve_rhs, data1=v_beta, data2=kbe_state, op=nl.subtract) + + # ============================================================ + # Blocked forward substitution for: + # v_new = solve((I - A), solve_rhs) + # + # A is strictly lower triangular. Compute each solve block's + # contribution from previously solved rows with one dense matmul, then + # solve the small diagonal block row-by-row. This keeps the algebra + # exact while moving the wide part of the triangular solve onto TE + # tiles, closer to the FlashQLA/FLA blocked chunked-prefill structure. + # ============================================================ + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=v_new, value=0.0) + + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(v_new, A_T, Imat, solve_rhs, dim) + else: + _blocked_doubling_solve(v_new, A_T, solve_rhs, dim) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * (strict_decay + identity) + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_norm) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + # ai_T[j, i] = (q[i] @ k[j]) * transpose(decay_diag)[j, i]. + qk_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_T_psum, stationary=k_T, moving=q_T) + qk_raw_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw_t, src=qk_T_psum) + + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=ai_T, data1=qk_raw_t, data2=decay_diag_t, op=nl.multiply + ) + + # ============================================================ + # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_norm, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + # Store output chunk to HBM + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + else: + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out[0:CHUNK_SIZE, 0:dim], + ) + + # ============================================================ + # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new) + # state is updated IN-PLACE in SBUF — no HBM round-trip! + # ============================================================ + + # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. + # Compute the equivalent stable form k * exp(g_last - gc) directly so + # no exp(-gc) intermediate can overflow. + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=gl_p[0:P_MAX, 0:1], + scale=-1.0, + ) + + # k_raw_decay = k * exp(g_last - gc) + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_norm, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + # k_raw_decay^T @ v_new → (dim, dim) outer product sum + # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N] + # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim) + # Result: sum over tokens -> (dim, dim) + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + # state = state * exp(g_last) + kv_outer + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) + + # ---- Write final state to HBM ---- + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out + + +@nki.jit +def deltanet_autocp_affine_chunk( + query: nl.ndarray, # (128, 128) float32 - raw Q; normalized in-kernel + key: nl.ndarray, # (128, 128) float32 - raw K; normalized in-kernel + value: nl.ndarray, # (128, 128) float32 + g_in: nl.ndarray, # (128, 1) float32 - per-token log-decay + beta_in: nl.ndarray, # (128, 1) float32 - per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 - strict lower tri + identity: nl.ndarray, # (128, 128) float32 - identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 - lower tri with diag +): + """Build one chunk's state-independent AutoCP affine pieces. + + For one 128-token DeltaNet chunk: + output = output_base + output_state @ state + next_state = state_matrix @ state + state_bias + + This probe deliberately mirrors the fused CTE chunk math and returns the + four intermediate tensors to HBM for isolated correctness validation before + wiring an AutoCP prepass into serving. + """ + dim = query.shape[1] + + output_base_out = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + output_state_out = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + state_matrix_out = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + state_bias_out = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Imat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Imat, src=identity) + + Lmask_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_diag, src=lower_mask_diag) + + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_c, src=query[0:CHUNK_SIZE, 0:dim]) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_c, src=key[0:CHUNK_SIZE, 0:dim]) + + q_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_square, data1=q_c, data2=q_c, op=nl.multiply) + q_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=q_norm_sq, data=q_square, op=nl.add, axis=1) + q_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm_sq_clamped, + data=q_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + q_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_inv_norm, + data=q_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + q_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm, + data=q_c, + op0=nl.multiply, + operand0=q_inv_norm, + op1=nl.multiply, + operand1=QUERY_SCALE, + engine=nisa.vector_engine, + ) + + k_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_square, data1=k_c, data2=k_c, op=nl.multiply) + k_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=k_norm_sq, data=k_square, op=nl.add, axis=1) + k_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm_sq_clamped, + data=k_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + k_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_inv_norm, + data=k_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + k_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm, + data=k_c, + op0=nl.multiply, + operand0=k_inv_norm, + engine=nisa.vector_engine, + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_c, src=value[0:CHUNK_SIZE, 0:dim]) + + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=g_chunk_p[0:CHUNK_SIZE, 0:1], src=g_in[0:CHUNK_SIZE, 0:1]) + + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_p[0:CHUNK_SIZE, 0:1], src=beta_in[0:CHUNK_SIZE, 0:1]) + + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy(dst=g_padded[0:CHUNK_SIZE, 0:1], src=g_chunk_p) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=g_row[0:1, 0:CHUNK_SIZE], src=g_tp_psum[0:1, 0:CHUNK_SIZE]) + + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy(dst=gc_padded[0:1, 0:CHUNK_SIZE], src=gc_row) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gc_p[0:CHUNK_SIZE, 0:1], src=gc_tp_psum[0:CHUNK_SIZE, 0:1]) + + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_diag, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, + data1=gc_row_broadcast, + data2=Lmask_diag, + op=nl.multiply, + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_diag, op=nl.multiply + ) + + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_norm, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_norm) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + + A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=A_T_psum, data=neg_QK_decay) + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_T, src=A_T_psum) + + value_u = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=value_u, value=0.0) + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(value_u, A_T, Imat, v_beta, dim) + else: + _blocked_doubling_solve(value_u, A_T, v_beta, dim) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + state_w = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state_w, value=0.0) + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(state_w, A_T, Imat, kb_exp_gc, dim) + else: + _blocked_doubling_solve(state_w, A_T, kb_exp_gc, dim) + + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_norm) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply + ) + + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + output_base_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=output_base_psum, stationary=ai_T, moving=value_u) + output_base = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=output_base, src=output_base_psum) + + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_norm, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + output_state_corr_psum = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul(dst=output_state_corr_psum, stationary=ai_T, moving=state_w) + output_state_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=output_state_corr, src=output_state_corr_psum) + + output_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=output_state, + data1=q_exp, + data2=output_state_corr, + op=nl.subtract, + ) + + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_minus_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_norm, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + state_bias_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=state_bias_psum, stationary=k_raw_decay, moving=value_u) + state_bias = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_bias, src=state_bias_psum) + + state_corr_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=state_corr_psum, stationary=k_raw_decay, moving=state_w) + state_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_corr, src=state_corr_psum) + + exp_gl_identity = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=exp_gl_identity, + data=Imat, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + state_matrix = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=state_matrix, + data1=exp_gl_identity, + data2=state_corr, + op=nl.subtract, + ) + + nisa.dma_copy(dst=output_base_out, src=output_base) + nisa.dma_copy(dst=output_state_out, src=output_state) + nisa.dma_copy(dst=state_matrix_out, src=state_matrix) + nisa.dma_copy(dst=state_bias_out, src=state_bias) + + return output_base_out, output_state_out, state_matrix_out, state_bias_out + + +@nki.jit +def deltanet_autocp_affine_sequence( + query: nl.ndarray, # (S, 128) float32 - raw Q; normalized in-kernel + key: nl.ndarray, # (S, 128) float32 - raw K; normalized in-kernel + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 - per-token log-decay + beta_in: nl.ndarray, # (S, 1) float32 - per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 - strict lower tri + identity: nl.ndarray, # (128, 128) float32 - identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 - kept for call compatibility +): + """Build AutoCP affine pieces for one sequence with LNC-striped chunks.""" + seq_len = query.shape[0] + dim = query.shape[1] + num_chunks = seq_len // CHUNK_SIZE + program_idx = nl.program_id(axis=0) + num_programs = nl.num_programs(axes=0) + chunks_per_program = num_chunks // num_programs + + output_base_out = nl.ndarray( + (num_chunks, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + output_state_out = nl.ndarray( + (num_chunks, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + state_matrix_out = nl.ndarray( + (num_chunks, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + state_bias_out = nl.ndarray( + (num_chunks, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Imat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Imat, src=identity) + + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + for chunk_loop in nl.sequential_range(chunks_per_program): + chunk_idx = program_idx * chunks_per_program + chunk_loop + chunk_start = chunk_idx * CHUNK_SIZE + + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + q_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_square, data1=q_c, data2=q_c, op=nl.multiply) + q_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=q_norm_sq, data=q_square, op=nl.add, axis=1) + q_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm_sq_clamped, + data=q_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + q_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_inv_norm, + data=q_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + q_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm, + data=q_c, + op0=nl.multiply, + operand0=q_inv_norm, + op1=nl.multiply, + operand1=QUERY_SCALE, + engine=nisa.vector_engine, + ) + + k_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_square, data1=k_c, data2=k_c, op=nl.multiply) + k_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=k_norm_sq, data=k_square, op=nl.add, axis=1) + k_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm_sq_clamped, + data=k_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + k_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_inv_norm, + data=k_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + k_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm, + data=k_c, + op0=nl.multiply, + operand0=k_inv_norm, + engine=nisa.vector_engine, + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy(dst=g_padded[0:CHUNK_SIZE, 0:1], src=g_chunk_p) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=g_row[0:1, 0:CHUNK_SIZE], src=g_tp_psum[0:1, 0:CHUNK_SIZE]) + + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy(dst=gc_padded[0:1, 0:CHUNK_SIZE], src=gc_row) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gc_p[0:CHUNK_SIZE, 0:1], src=gc_tp_psum[0:CHUNK_SIZE, 0:1]) + + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=decay_diag, data1=decay_strict, data2=Imat, op=nl.add) + + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_norm, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_norm) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + + A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=A_T_psum, data=neg_QK_decay) + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_T, src=A_T_psum) + + value_u = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=value_u, value=0.0) + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(value_u, A_T, Imat, v_beta, dim) + else: + _blocked_doubling_solve(value_u, A_T, v_beta, dim) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + state_w = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state_w, value=0.0) + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(state_w, A_T, Imat, kb_exp_gc, dim) + else: + _blocked_doubling_solve(state_w, A_T, kb_exp_gc, dim) + + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_norm) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply + ) + + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + output_base_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=output_base_psum, stationary=ai_T, moving=value_u) + output_base = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=output_base, src=output_base_psum) + + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_norm, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + output_state_corr_psum = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul(dst=output_state_corr_psum, stationary=ai_T, moving=state_w) + output_state_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=output_state_corr, src=output_state_corr_psum) + + output_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=output_state, + data1=q_exp, + data2=output_state_corr, + op=nl.subtract, + ) + + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_minus_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_norm, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + state_bias_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=state_bias_psum, stationary=k_raw_decay, moving=value_u) + state_bias = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_bias, src=state_bias_psum) + + state_corr_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=state_corr_psum, stationary=k_raw_decay, moving=state_w) + state_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_corr, src=state_corr_psum) + + exp_gl_identity = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=exp_gl_identity, + data=Imat, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + state_matrix = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=state_matrix, + data1=exp_gl_identity, + data2=state_corr, + op=nl.subtract, + ) + + nisa.dma_copy(dst=output_base_out[chunk_idx, 0:P_MAX, 0:dim], src=output_base) + nisa.dma_copy(dst=output_state_out[chunk_idx, 0:P_MAX, 0:dim], src=output_state) + nisa.dma_copy(dst=state_matrix_out[chunk_idx, 0:P_MAX, 0:dim], src=state_matrix) + nisa.dma_copy(dst=state_bias_out[chunk_idx, 0:P_MAX, 0:dim], src=state_bias) + + return output_base_out, output_state_out, state_matrix_out, state_bias_out + + +@nki.jit +def deltanet_autocp_state_summary_sequence( + key: nl.ndarray, # (S, 128) float32 - raw K; normalized in-kernel + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 - per-token log-decay + beta_in: nl.ndarray, # (S, 1) float32 - per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 - strict lower tri + identity: nl.ndarray, # (128, 128) float32 - identity +): + """Build compact AutoCP segment state summaries. + + This is the first production-shaped AutoCP prepass: it skips query/output + affine pieces and emits only per-segment state transforms: + + state_{seg+1} = segment_matrix_seg @ state_seg + segment_bias_seg + + Segment replay can then use the existing recurrent fused kernel from the + corrected segment initial states. + """ + seq_len = key.shape[0] + dim = key.shape[1] + num_chunks = seq_len // CHUNK_SIZE + num_segments = num_chunks // AUTOCP_CP_CHUNKS + program_idx = nl.program_id(axis=0) + num_programs = nl.num_programs(axes=0) + segments_per_program = num_segments // num_programs + + segment_matrix_out = nl.ndarray( + (num_segments, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + segment_bias_out = nl.ndarray( + (num_segments, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Imat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Imat, src=identity) + + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + for segment_loop in nl.sequential_range(segments_per_program): + segment_idx = program_idx * segments_per_program + segment_loop + first_chunk = segment_idx * AUTOCP_CP_CHUNKS + + segment_matrix = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=segment_matrix, src=Imat) + + segment_bias = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=segment_bias, value=0.0) + + for local_chunk in nl.sequential_range(AUTOCP_CP_CHUNKS): + chunk_idx = first_chunk + local_chunk + chunk_start = chunk_idx * CHUNK_SIZE + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_square, data1=k_c, data2=k_c, op=nl.multiply) + k_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=k_norm_sq, data=k_square, op=nl.add, axis=1) + k_norm_sq_clamped = nl.ndarray( + (P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=k_norm_sq_clamped, + data=k_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + k_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_inv_norm, + data=k_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + k_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm, + data=k_c, + op0=nl.multiply, + operand0=k_inv_norm, + engine=nisa.vector_engine, + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy(dst=g_padded[0:CHUNK_SIZE, 0:1], src=g_chunk_p) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy(dst=gc_padded[0:1, 0:CHUNK_SIZE], src=gc_row) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_row_broadcast = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=gc_row_strict, + data1=gc_row_broadcast, + data2=Lmask, + op=nl.multiply, + ) + g_diff_strict = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=decay_strict, + data1=decay_strict_raw, + data2=Lmask, + op=nl.multiply, + ) + + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_norm, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_norm) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=QK_decay, + data1=QK, + data2=decay_strict, + op=nl.multiply, + ) + neg_QK_decay = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + + A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=A_T_psum, data=neg_QK_decay) + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_T, src=A_T_psum) + + value_u = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=value_u, value=0.0) + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(value_u, A_T, Imat, v_beta, dim) + else: + _blocked_doubling_solve(value_u, A_T, v_beta, dim) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + state_w = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state_w, value=0.0) + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(state_w, A_T, Imat, kb_exp_gc, dim) + else: + _blocked_doubling_solve(state_w, A_T, kb_exp_gc, dim) + + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray( + (P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.activation( + dst=exp_gl_minus_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_minus_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_norm, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + state_bias = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + state_bias_psum = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul(dst=state_bias_psum, stationary=k_raw_decay, moving=value_u) + nisa.tensor_copy(dst=state_bias, src=state_bias_psum) + + state_corr_psum = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul(dst=state_corr_psum, stationary=k_raw_decay, moving=state_w) + state_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_corr, src=state_corr_psum) + + exp_gl_identity = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=exp_gl_identity, + data=Imat, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + state_matrix = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=state_matrix, + data1=exp_gl_identity, + data2=state_corr, + op=nl.subtract, + ) + + state_matrix_t_psum = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_transpose(dst=state_matrix_t_psum, data=state_matrix) + state_matrix_t = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy(dst=state_matrix_t, src=state_matrix_t_psum) + + composed_matrix_psum = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul( + dst=composed_matrix_psum, + stationary=state_matrix_t, + moving=segment_matrix, + ) + composed_matrix = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy(dst=composed_matrix, src=composed_matrix_psum) + + propagated_bias_psum = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul( + dst=propagated_bias_psum, + stationary=state_matrix_t, + moving=segment_bias, + ) + propagated_bias = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy(dst=propagated_bias, src=propagated_bias_psum) + + composed_bias = nl.ndarray( + (P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=composed_bias, + data1=propagated_bias, + data2=state_bias, + op=nl.add, + ) + + nisa.tensor_copy(dst=segment_matrix, src=composed_matrix) + nisa.tensor_copy(dst=segment_bias, src=composed_bias) + + nisa.dma_copy( + dst=segment_matrix_out[segment_idx, 0:P_MAX, 0:dim], + src=segment_matrix, + ) + nisa.dma_copy( + dst=segment_bias_out[segment_idx, 0:P_MAX, 0:dim], + src=segment_bias, + ) + + return segment_matrix_out, segment_bias_out + + +@nki.jit +def deltanet_autocp_state_prefix( + state_matrix: nl.ndarray, # (N, 128, 128) float32 + state_bias: nl.ndarray, # (N, 128, 128) float32 + initial_state: nl.ndarray, # (128, 128) float32 +): + """Apply per-chunk AutoCP state transforms and emit chunk initial states. + + Given per-chunk transforms: + state_{i+1} = state_matrix_i @ state_i + state_bias_i + + returns: + chunk_states[i] = state_i + final_state = state_N + + This is the isolated state-prefix correctness probe. A later production + path can replace the loop body with a tree/parallel prefix over the same + HBM interface. + """ + num_chunks = state_matrix.shape[0] + dim = initial_state.shape[1] + + chunk_states_out = nl.ndarray( + (num_chunks, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=initial_state[0:P_MAX, 0:dim]) + + for i_chunk in nl.sequential_range(num_chunks): + nisa.dma_copy( + dst=chunk_states_out[i_chunk, 0:P_MAX, 0:dim], + src=state, + ) + + matrix = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=matrix, + src=state_matrix[i_chunk, 0:P_MAX, 0:dim], + ) + + bias = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=bias, + src=state_bias[i_chunk, 0:P_MAX, 0:dim], + ) + + matrix_t_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=matrix_t_psum, data=matrix) + matrix_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=matrix_t, src=matrix_t_psum) + + propagated_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=propagated_psum, stationary=matrix_t, moving=state) + propagated = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=propagated, src=propagated_psum) + + next_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=next_state, + data1=propagated, + data2=bias, + op=nl.add, + ) + nisa.tensor_copy(dst=state, src=next_state) + + nisa.dma_copy(dst=final_state_out, src=state) + + return chunk_states_out, final_state_out + + +@nki.jit +def deltanet_autocp_apply_output( + output_base: nl.ndarray, # (N, 128, 128) float32 + output_state: nl.ndarray, # (N, 128, 128) float32 + chunk_states: nl.ndarray, # (N, 128, 128) float32 +): + """Apply AutoCP chunk initial states to state-dependent output terms.""" + num_chunks = output_base.shape[0] + dim = output_base.shape[2] + + output = nl.ndarray( + (num_chunks * CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + + for i_chunk in nl.sequential_range(num_chunks): + base = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=base, + src=output_base[i_chunk, 0:P_MAX, 0:dim], + ) + + state_coeff = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=state_coeff, + src=output_state[i_chunk, 0:P_MAX, 0:dim], + ) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=state, + src=chunk_states[i_chunk, 0:P_MAX, 0:dim], + ) + + coeff_t_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=coeff_t_psum, data=state_coeff) + coeff_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=coeff_t, src=coeff_t_psum) + + state_out_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=state_out_psum, stationary=coeff_t, moving=state) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_out, src=state_out_psum) + + chunk_output = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=chunk_output, + data1=base, + data2=state_out, + op=nl.add, + ) + + nisa.dma_copy( + dst=output[i_chunk * CHUNK_SIZE : i_chunk * CHUNK_SIZE + CHUNK_SIZE, 0:dim], + src=chunk_output, + ) + + return output + + +@nki.jit +def deltanet_autocp_prefix_apply_output( + output_base: nl.ndarray, # (N, 128, 128) float32 + output_state: nl.ndarray, # (N, 128, 128) float32 + state_matrix: nl.ndarray, # (N, 128, 128) float32 + state_bias: nl.ndarray, # (N, 128, 128) float32 + initial_state: nl.ndarray, # (128, 128) float32 +): + """Fused AutoCP state-prefix and output-apply pass. + + This removes the intermediate chunk_states HBM tensor and one custom-call + from the AutoCP probe path. It intentionally remains an exact sequential + prefix over dense 128x128 chunk transforms; a matrix-affine prefix cannot be + represented by tensor_tensor_scan's elementwise recurrence. + """ + num_chunks = output_base.shape[0] + dim = output_base.shape[2] + + output = nl.ndarray( + (num_chunks * CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=initial_state[0:P_MAX, 0:dim]) + + for i_chunk in nl.sequential_range(num_chunks): + base = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=base, + src=output_base[i_chunk, 0:P_MAX, 0:dim], + ) + + state_coeff = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=state_coeff, + src=output_state[i_chunk, 0:P_MAX, 0:dim], + ) + + coeff_t_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=coeff_t_psum, data=state_coeff) + coeff_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=coeff_t, src=coeff_t_psum) + + state_out_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=state_out_psum, stationary=coeff_t, moving=state) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=state_out, src=state_out_psum) + + chunk_output = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=chunk_output, + data1=base, + data2=state_out, + op=nl.add, + ) + nisa.dma_copy( + dst=output[i_chunk * CHUNK_SIZE : i_chunk * CHUNK_SIZE + CHUNK_SIZE, 0:dim], + src=chunk_output, + ) + + matrix = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=matrix, + src=state_matrix[i_chunk, 0:P_MAX, 0:dim], + ) + + bias = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=bias, + src=state_bias[i_chunk, 0:P_MAX, 0:dim], + ) + + matrix_t_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=matrix_t_psum, data=matrix) + matrix_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=matrix_t, src=matrix_t_psum) + + propagated_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=propagated_psum, stationary=matrix_t, moving=state) + propagated = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=propagated, src=propagated_psum) + + next_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=next_state, + data1=propagated, + data2=bias, + op=nl.add, + ) + nisa.tensor_copy(dst=state, src=next_state) + + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out + + +@nki.jit +def deltanet_fused_chunked_fwd_multihead( + query: nl.ndarray, # (BH, S, 128) float32 — raw Q; normalized in-kernel + key: nl.ndarray, # (BH, S, 128) float32 — raw K; normalized in-kernel + value: nl.ndarray, # (BH, S, 128) float32 + g_in: nl.ndarray, # (BH, S, 1) float32 + beta_in: nl.ndarray, # (BH, S, 1) float32 + initial_state: nl.ndarray, # (BH, 128, 128) float32 + lower_mask: nl.ndarray, # (128, 128) float32 + identity: nl.ndarray, # (128, 128) float32 + lower_mask_diag: nl.ndarray, # (128, 128) float32 +): + """Fused chunked DeltaNet forward for one or more heads with SPMD sharding.""" + num_heads = query.shape[0] + seq_len = query.shape[1] + dim = query.shape[2] + num_chunks = seq_len // CHUNK_SIZE + head_idx = nl.program_id(axis=0) + + output = nl.ndarray( + (num_heads, seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm + ) + final_state_out = nl.ndarray( + (num_heads, P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm + ) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + UMask_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=UMask_psum, data=Lmask) + UMask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=UMask, src=UMask_psum) + + Imat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Imat, src=identity) + + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=initial_state[head_idx, 0:P_MAX, 0:dim]) + + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=q_c, + src=query[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + else: + nisa.memset(dst=q_c, value=0.0) + nisa.dma_copy( + dst=q_c[0:CHUNK_SIZE, 0:dim], + src=query[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=k_c, + src=key[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + else: + nisa.memset(dst=k_c, value=0.0) + nisa.dma_copy( + dst=k_c[0:CHUNK_SIZE, 0:dim], + src=key[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + q_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_square, data1=q_c, data2=q_c, op=nl.multiply) + q_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=q_norm_sq, data=q_square, op=nl.add, axis=1) + q_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm_sq_clamped, + data=q_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + q_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_inv_norm, + data=q_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + q_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_norm, + data=q_c, + op0=nl.multiply, + operand0=q_inv_norm, + op1=nl.multiply, + operand1=QUERY_SCALE, + engine=nisa.vector_engine, + ) + + k_square = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_square, data1=k_c, data2=k_c, op=nl.multiply) + k_norm_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=k_norm_sq, data=k_square, op=nl.add, axis=1) + k_norm_sq_clamped = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm_sq_clamped, + data=k_norm_sq, + op0=nl.maximum, + operand0=L2_EPS_SQUARED, + engine=nisa.vector_engine, + ) + k_inv_norm = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_inv_norm, + data=k_norm_sq_clamped, + op0=nl.rsqrt, + operand0=0.0, + engine=nisa.gpsimd_engine, + ) + k_norm = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_norm, + data=k_c, + op0=nl.multiply, + operand0=k_inv_norm, + engine=nisa.vector_engine, + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + if CHUNK_SIZE == P_MAX: + nisa.dma_copy( + dst=v_c, + src=value[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + else: + nisa.memset(dst=v_c, value=0.0) + nisa.dma_copy( + dst=v_c[0:CHUNK_SIZE, 0:dim], + src=value[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=g_chunk_p, value=0.0) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=beta_p, value=0.0) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + g_tp_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_chunk_p) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + gc_tp_psum = nl.ndarray((CHUNK_SIZE, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_row) + + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=gc_p, value=0.0) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + if CHUNK_SIZE != P_MAX: + nisa.memset(dst=gc_row_broadcast, value=0.0) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:CHUNK_SIZE], + dst=gc_row_broadcast[ + i_shuf * 32 : i_shuf * 32 + 32, 0:CHUNK_SIZE + ], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict_t = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=gc_col_strict_t, + data1=gc_row_broadcast, + data2=UMask, + op=nl.multiply, + ) + gc_row_strict_t = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=gc_row_strict_t, + data=UMask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + g_diff_strict_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict_t, + data1=gc_col_strict_t, + data2=gc_row_strict_t, + op=nl.subtract, + ) + decay_strict_t_raw = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.activation( + dst=decay_strict_t_raw, + op=nl.exp, + data=g_diff_strict_t, + bias=None, + scale=1.0, + ) + decay_strict_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict_t, + data1=decay_strict_t_raw, + data2=UMask, + op=nl.multiply, + ) + + decay_diag_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag_t, data1=decay_strict_t, data2=Imat, op=nl.add + ) + + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_norm, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + kb_T_psum = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta[0:CHUNK_SIZE, 0:dim]) + kb_T = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_norm[0:CHUNK_SIZE, 0:dim]) + k_T = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_T_psum = nl.ndarray( + (CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul(dst=QK_T_psum, stationary=k_T, moving=kb_T) + QK_T = nl.ndarray((CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_T, src=QK_T_psum) + + QK_decay_t = nl.ndarray( + (CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_tensor( + dst=QK_decay_t, + data1=QK_T, + data2=decay_strict_t[0:CHUNK_SIZE, 0:CHUNK_SIZE], + op=nl.multiply, + ) + + A_T = nl.ndarray((CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=A_T, + data=QK_decay_t, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + kb_exp_gc = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta[0:CHUNK_SIZE, 0:dim], + op0=nl.multiply, + operand0=exp_gc_p[0:CHUNK_SIZE, 0:1], + engine=nisa.vector_engine, + ) + + kbe_T_psum = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kbe_T_psum, data=kb_exp_gc) + kbe_T = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_T, src=kbe_T_psum) + + kbe_state_psum = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kbe_state_psum, stationary=kbe_T, moving=state) + kbe_state = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_state, src=kbe_state_psum) + + solve_rhs = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=solve_rhs, + data1=v_beta[0:CHUNK_SIZE, 0:dim], + data2=kbe_state, + op=nl.subtract, + ) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=v_new, value=0.0) + + if SOLVE_KKT_HIER: + _hierarchical_kkt_solve128(v_new, A_T, Imat, solve_rhs, dim) + else: + _blocked_doubling_solve(v_new, A_T, solve_rhs, dim) + + q_T_psum = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_norm[0:CHUNK_SIZE, 0:dim]) + q_T = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_T_psum = nl.ndarray( + (CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum + ) + nisa.nc_matmul(dst=qk_T_psum, stationary=k_T, moving=q_T) + qk_raw_t = nl.ndarray( + (CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_copy(dst=qk_raw_t, src=qk_T_psum) + + ai_T = nl.ndarray((CHUNK_SIZE, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=ai_T, + data1=qk_raw_t, + data2=decay_diag_t[0:CHUNK_SIZE, 0:CHUNK_SIZE], + op=nl.multiply, + ) + + q_exp = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_norm[0:CHUNK_SIZE, 0:dim], + op0=nl.multiply, + operand0=exp_gc_p[0:CHUNK_SIZE, 0:1], + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + intra_psum = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul( + dst=intra_psum, + stationary=ai_T, + moving=v_new[0:CHUNK_SIZE, 0:dim], + ) + intra_out = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + chunk_out = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + nisa.dma_copy( + dst=output[head_idx, chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + + exp_gl_minus_gc_p = nl.ndarray( + (CHUNK_SIZE, 1), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.activation( + dst=exp_gl_minus_gc_p[0:CHUNK_SIZE, 0:1], + op=nl.exp, + data=gc_p[0:CHUNK_SIZE, 0:1], + bias=gl_p[0:CHUNK_SIZE, 0:1], + scale=-1.0, + ) + + k_raw_decay = nl.ndarray((CHUNK_SIZE, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_norm[0:CHUNK_SIZE, 0:dim], + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul( + dst=kv_psum, + stationary=k_raw_decay, + moving=v_new[0:CHUNK_SIZE, 0:dim], + ) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) + + nisa.dma_copy(dst=final_state_out[head_idx, 0:P_MAX, 0:dim], src=state) + + return output, final_state_out diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused_legacy.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused_legacy.py new file mode 100644 index 00000000..5d5562b5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused_legacy.py @@ -0,0 +1,613 @@ +"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding). + +SSD-style architecture: processes ALL chunks for one (batch, head) pair in +a single NKI kernel call. State (128x128) persists in SBUF across chunks — +no HBM round-trips for inter-chunk state propagation. + +Key optimizations over nki_deltanet_chunked.py: + 1. Single kernel call per (B,H) instead of B*H*num_chunks calls + 2. State in SBUF across all chunks (no HBM state read/write per chunk) + 3. In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum) + 4. Masks and constants loaded once, reused across chunks + 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) + 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of + nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + +NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. +Chunk size = 128 = P_MAX (one tile per chunk). + +Mathematical framework: + Per-chunk direct triangular solve for intra-chunk correction: + QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j + A = -QK_decay * lower_mask + v_new = solve((I - A), v_beta - (k_beta * exp(gc)) @ state) + + Inter-chunk state propagation: + attn_inter = (q * exp(gc)) @ state + attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + output = attn_inter + attn_intra @ v_new + state = exp(g_last) * (state + k_raw_decay^T @ v_new) +""" + +import numpy as np + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 # Partition dim = chunk_size = k_dim = v_dim +CHUNK_SIZE = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +def _make_lower_mask(): + """Strict lower triangular (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1) + + +def _make_lower_mask_diag(): + """Lower triangular with diagonal (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0) + + +def _make_identity(): + """Identity matrix (128x128) as numpy constant.""" + return np.eye(CHUNK_SIZE, dtype=np.float32) + + +@nki.jit +def deltanet_fused_chunked_fwd( + query: nl.ndarray, # (S, 128) float32 — l2-normed and scaled + key: nl.ndarray, # (S, 128) float32 — l2-normed + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum) + beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate + initial_state: nl.ndarray, # (128, 128) float32 — recurrent checkpoint or zeros + lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri + identity: nl.ndarray, # (128, 128) float32 — identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag +): + """Fused chunked DeltaNet forward — single kernel call per (batch, head). + + Processes all chunks sequentially within the kernel, keeping the recurrent + state (128x128) in SBUF across chunks. Returns per-token output and + final state. + + Input requirements: + - S must be divisible by 128 (pad before calling) + - query must be l2-normed and scaled by 1/sqrt(k_dim) + - key must be l2-normed + - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan) + - beta_in is sigmoid(b) (write gate) + - initial_state is zero for cold prefill, or the restored GDN checkpoint + + Returns: + output: (S, 128) float32 + final_state: (128, 128) float32 + """ + seq_len = query.shape[0] + dim = query.shape[1] # 128 + num_chunks = seq_len // CHUNK_SIZE + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # ================================================================ + # Load constant masks into SBUF once (reused across all chunks) + # ================================================================ + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # Ones vector for cumsum scan: (1, CHUNK_SIZE) + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + # Zero initial for cumsum scan + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + # ================================================================ + # Initialize recurrent state in SBUF — persists across ALL chunks + # ================================================================ + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=initial_state) + + # ================================================================ + # Sequential chunk processing + # ================================================================ + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + # ---- Load chunk data from HBM ---- + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + # g: (CHUNK_SIZE, 1) — raw log-decay per token + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # beta: (CHUNK_SIZE, 1) — write gate scalar per token + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # ---- In-kernel cumsum of g via tensor_tensor_scan ---- + # Need g as (1, CHUNK_SIZE) for scan along free dim. + # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy( + dst=g_padded[0:CHUNK_SIZE, 0:1], + src=g_chunk_p[0:CHUNK_SIZE, 0:1], + ) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t] + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + # Transpose gc back to (CHUNK_SIZE, 1) partition layout + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy( + dst=gc_padded[0:1, 0:CHUNK_SIZE], + src=gc_row[0:1, 0:CHUNK_SIZE], + ) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + # g_last = gc[-1] (scalar) — needed for state decay + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + # ---- Compute exp(gc) and exp(g_last) as (P_MAX, 1) scalars ---- + # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast + # across the free dimension without explicit (P_MAX, dim) copies. + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + # g_last: scalar, then broadcast to (P_MAX, 1) for direct + # exp(g_last - gc) in the state update. + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # exp(g_last): scalar, then broadcast to (P_MAX, 1) + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # ============================================================ + # Stable pairwise decay factors from cumulative log-decay. + # + # The original fused path used split scaling: + # exp(gc[i]) * exp(-gc[j]) + # That can materialize huge unused intermediates. Build the same + # causal decay matrices as the per-chunk kernel using exp(gc[i]-gc[j]) + # and mask after the exp so upper-triangular values cannot leak into + # later matmuls. + # ============================================================ + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_d, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # Transpose K and K_beta for matmul + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_c) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + # QK = k_beta^T @ k (contract over features) + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j. + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Build the single RHS needed for v_new. + # + # Materializing N = inv(I - A) would compute: + # value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # v_new = value_corr - k_cumdecay @ state + # + # By associativity: + # v_new = N @ (v_beta - (k_beta * exp(gc)) @ state) + # + # Solve this RHS directly. This is equivalent to the nilpotent + # Neumann series, but avoids repeated matrix squaring, which is + # numerically unstable for realistic Qwen decay gates. + # ============================================================ + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kbe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kbe_T_psum, data=kb_exp_gc) + kbe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_T, src=kbe_T_psum) + + kbe_state_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kbe_state_psum, stationary=kbe_T, moving=state) + kbe_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_state, src=kbe_state_psum) + + solve_rhs = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=solve_rhs, data1=v_beta, data2=kbe_state, op=nl.subtract) + + # ============================================================ + # Direct forward substitution for: + # v_new = solve((I - A_mat), solve_rhs) + # + # A_mat is strictly lower triangular, so row i only depends on rows + # < i. The full-matmul plus row-select form keeps the shape static + # and compiler-safe while updating exactly one solved row per step. + # ============================================================ + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=v_new, value=0.0) + + A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=A_T_psum, data=A_mat) + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_T, src=A_T_psum) + + for solve_i in nl.static_range(P_MAX): + row_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=row_psum, stationary=A_T, moving=v_new) + row_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=row_prod, src=row_psum) + + row_with_rhs = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=row_with_rhs, + data1=row_prod, + data2=solve_rhs, + op=nl.add, + ) + + row_mask = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=row_mask[0:P_MAX, 0:1], + src=eye[0:P_MAX, solve_i : solve_i + 1], + ) + + row_update = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=row_update, + data=row_with_rhs, + op0=nl.multiply, + operand0=row_mask, + engine=nisa.vector_engine, + ) + + v_next = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_next, data1=v_new, data2=row_update, op=nl.add) + nisa.tensor_copy(dst=v_new, src=v_next) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_c) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + # qk_decay[i,j] = (q @ k^T)[i,j] * exp(gc[i] - gc[j]) for i >= j. + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_decay, data1=qk_raw, data2=decay_diag, op=nl.multiply) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + # Store output chunk to HBM + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + + # ============================================================ + # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new) + # state is updated IN-PLACE in SBUF — no HBM round-trip! + # ============================================================ + + # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. + # Compute the equivalent stable form k * exp(g_last - gc) directly so + # no exp(-gc) intermediate can overflow. + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_minus_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + # k_raw_decay = k * exp(g_last - gc) + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + # k_raw_decay^T @ v_new → (dim, dim) outer product sum + # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N] + # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim) + # Result: sum over tokens -> (dim, dim) + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + # state = state * exp(g_last) + kv_outer + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim. + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) + + # ---- Write final state to HBM ---- + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/qwen_qk_norm_rope.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/qwen_qk_norm_rope.py new file mode 100644 index 00000000..e5535254 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/qwen_qk_norm_rope.py @@ -0,0 +1,230 @@ +"""Qwen3.6-specific Q/K RMSNorm + partial-RoPE NKI kernel. + +The model's full-attention layers use head_dim=256 and partial RoPE over the +first 64 dimensions. This kernel consumes projected Q/K tensors in BSD layout +and returns normalized/rotated Q/K tensors in BHSD layout, replacing the +separate move_heads_front + q/k RMSNorm + partial RoPE torch ops. +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 +D_HEAD = 256 +ROPE_DIM = 64 +ROPE_HALF = 32 +_BROADCAST_MASK = [0] * 32 + + +def _broadcast_row_to_tile(row, out): + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=row[0:1, 0:D_HEAD], + dst=out[i_shuf * 32 : i_shuf * 32 + 32, 0:D_HEAD], + shuffle_mask=_BROADCAST_MASK, + ) + + +def _normalize_rope_store( + proj, + gamma, + cos_cache, + sin_cache, + out, + eps, +): + batch_size, seq_len, width = proj.shape + num_heads = width // D_HEAD + gamma_2d = gamma.reshape((1, D_HEAD)) + + gamma_row = nl.ndarray((1, D_HEAD), dtype=gamma.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gamma_row, src=gamma_2d[0:1, 0:D_HEAD]) + gamma_tile = nl.ndarray((P_MAX, D_HEAD), dtype=gamma.dtype, buffer=nl.sbuf) + _broadcast_row_to_tile(gamma_row, gamma_tile) + + num_seq_tiles = (seq_len + P_MAX - 1) // P_MAX + for b_idx in nl.sequential_range(batch_size): + for h_idx in nl.sequential_range(num_heads): + col_start = h_idx * D_HEAD + for tile_idx in nl.affine_range(num_seq_tiles): + seq_start = tile_idx * P_MAX + p_size = min(P_MAX, seq_len - seq_start) + + x = nl.ndarray((P_MAX, D_HEAD), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=x[0:p_size, 0:D_HEAD], + src=proj[ + b_idx, + seq_start : seq_start + p_size, + col_start : col_start + D_HEAD, + ], + ) + + square = nl.ndarray((P_MAX, D_HEAD), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=square[0:p_size, 0:D_HEAD], + data1=x[0:p_size, 0:D_HEAD], + data2=x[0:p_size, 0:D_HEAD], + op=nl.multiply, + ) + + sumsq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce( + dst=sumsq[0:p_size, 0:1], + data=square[0:p_size, 0:D_HEAD], + op=nl.add, + axis=1, + ) + + variance = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=variance[0:p_size, 0:1], + data=sumsq[0:p_size, 0:1], + op0=nl.multiply, + operand0=(1.0 / D_HEAD), + ) + + variance_eps = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=variance_eps[0:p_size, 0:1], + data=variance[0:p_size, 0:1], + op0=nl.add, + operand0=eps, + ) + + inv_rms = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=inv_rms[0:p_size, 0:1], + data=variance_eps[0:p_size, 0:1], + op=nl.rsqrt, + ) + + normed = nl.ndarray((P_MAX, D_HEAD), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=normed[0:p_size, 0:D_HEAD], + data=x[0:p_size, 0:D_HEAD], + op0=nl.multiply, + operand0=inv_rms[0:p_size, 0:1], + engine=nisa.vector_engine, + ) + nisa.tensor_tensor( + dst=normed[0:p_size, 0:D_HEAD], + data1=normed[0:p_size, 0:D_HEAD], + data2=gamma_tile[0:p_size, 0:D_HEAD], + op=nl.multiply, + ) + + nisa.dma_copy( + dst=out[b_idx, h_idx, seq_start : seq_start + p_size, 0:D_HEAD], + src=normed[0:p_size, 0:D_HEAD], + ) + + cos_tile = nl.ndarray((P_MAX, ROPE_DIM), dtype=nl.float32, buffer=nl.sbuf) + sin_tile = nl.ndarray((P_MAX, ROPE_DIM), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=cos_tile[0:p_size, 0:ROPE_DIM], + src=cos_cache[b_idx, seq_start : seq_start + p_size, 0:ROPE_DIM], + ) + nisa.dma_copy( + dst=sin_tile[0:p_size, 0:ROPE_DIM], + src=sin_cache[b_idx, seq_start : seq_start + p_size, 0:ROPE_DIM], + ) + + neg_hi = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_hi[0:p_size, 0:ROPE_HALF], + data=normed[0:p_size, ROPE_HALF:ROPE_DIM], + op0=nl.multiply, + operand0=-1.0, + ) + + lo_cos = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + hi_sin_neg = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + rope_lo = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=lo_cos[0:p_size, 0:ROPE_HALF], + data1=normed[0:p_size, 0:ROPE_HALF], + data2=cos_tile[0:p_size, 0:ROPE_HALF], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=hi_sin_neg[0:p_size, 0:ROPE_HALF], + data1=neg_hi[0:p_size, 0:ROPE_HALF], + data2=sin_tile[0:p_size, 0:ROPE_HALF], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_lo[0:p_size, 0:ROPE_HALF], + data1=lo_cos[0:p_size, 0:ROPE_HALF], + data2=hi_sin_neg[0:p_size, 0:ROPE_HALF], + op=nl.add, + ) + + hi_cos = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + lo_sin = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + rope_hi = nl.ndarray((P_MAX, ROPE_HALF), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=hi_cos[0:p_size, 0:ROPE_HALF], + data1=normed[0:p_size, ROPE_HALF:ROPE_DIM], + data2=cos_tile[0:p_size, ROPE_HALF:ROPE_DIM], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=lo_sin[0:p_size, 0:ROPE_HALF], + data1=normed[0:p_size, 0:ROPE_HALF], + data2=sin_tile[0:p_size, ROPE_HALF:ROPE_DIM], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_hi[0:p_size, 0:ROPE_HALF], + data1=hi_cos[0:p_size, 0:ROPE_HALF], + data2=lo_sin[0:p_size, 0:ROPE_HALF], + op=nl.add, + ) + + nisa.dma_copy( + dst=out[b_idx, h_idx, seq_start : seq_start + p_size, 0:ROPE_HALF], + src=rope_lo[0:p_size, 0:ROPE_HALF], + ) + nisa.dma_copy( + dst=out[ + b_idx, + h_idx, + seq_start : seq_start + p_size, + ROPE_HALF:ROPE_DIM, + ], + src=rope_hi[0:p_size, 0:ROPE_HALF], + ) + + +@nki.jit +def qwen_qk_norm_partial_rope_kernel( + q_proj: nl.ndarray, + k_proj: nl.ndarray, + q_gamma: nl.ndarray, + k_gamma: nl.ndarray, + cos_cache: nl.ndarray, + sin_cache: nl.ndarray, + eps: float, +): + batch_size, seq_len, q_width = q_proj.shape + _, _, k_width = k_proj.shape + q_heads = q_width // D_HEAD + k_heads = k_width // D_HEAD + + q_out = nl.ndarray( + (batch_size, q_heads, seq_len, D_HEAD), + dtype=q_proj.dtype, + buffer=nl.shared_hbm, + ) + k_out = nl.ndarray( + (batch_size, k_heads, seq_len, D_HEAD), + dtype=k_proj.dtype, + buffer=nl.shared_hbm, + ) + + _normalize_rope_store(q_proj, q_gamma, cos_cache, sin_cache, q_out, eps) + _normalize_rope_store(k_proj, k_gamma, cos_cache, sin_cache, k_out, eps) + + return q_out, k_out diff --git a/contrib/models/Qwen3.6-27B/test/__init__.py b/contrib/models/Qwen3.6-27B/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/integration/__init__.py b/contrib/models/Qwen3.6-27B/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py b/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py new file mode 100644 index 00000000..651c313c --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py @@ -0,0 +1,1705 @@ +#!/usr/bin/env python3 +"""Compile Qwen3.6-27B 64K with scoped weight-mode ablations. + +This script intentionally starts from the validated 64K hybrid/chunked-prefill +baseline and changes only weight quantization. Supported modes: + +* ``fp8_mlp_only``: MLP linear weights are converted to FP8 while attention, + DeltaNet, normalization, embeddings, lm_head, KV cache, and recurrent state + remain BF16. +* ``fp8_full``: all supported linear/matmul weights are converted to FP8 while + embeddings, normalization, rotary state, DeltaNet recurrent/conv state, KV + cache, and lm_head remain BF16 by default. +* ``bf16_control``: no FP8 conversion; this is the real-token host-logits + control for separating FP8 conversion failures from serving/logits failures. +""" + +from __future__ import annotations + +import argparse +import gc +import json +import os +import sys +from pathlib import Path + +import torch + + +_FP8_ENV_DEFAULTS = { + "XLA_HANDLE_SPECIAL_SCALAR": "1", + "UNSAFE_FP8FNCAST": "1", +} + +_WEIGHT_DTYPE_FP8_MLP_ONLY = "fp8_mlp_only" +_WEIGHT_DTYPE_FP8_FULL = "fp8_full" +_FP8_EXCLUDE_GROUPS = { + "linear_attn", + "linear_attn_qkv", + "linear_attn_z", + "linear_attn_out_proj", + "mlp", + "self_attn", + "self_attn_qkv", + "self_attn_o_proj", +} +_WEIGHT_DTYPE_BF16_CONTROL = "bf16_control" +_FP8_WLO_SKIP_PATTERNS = [ + r".*\.scale$", + r".*\.weight_scale$", + r".*linear_attn\.conv1d_weight\.weight$", +] +_DISABLE_TOKEN_GENERATION_WLO_ENV = "QWEN36_DISABLE_TOKEN_GENERATION_WLO" +_DELTANET_CTE_BACKEND_ENV = { + "USE_NKI_FUSED", + "USE_NKI_CHUNKED", + "USE_NKI", + "DELTANET_SEQUENTIAL", + "USE_PYTORCH_CHUNK", +} + + +def _ensure_fp8_environment() -> None: + for name, value in _FP8_ENV_DEFAULTS.items(): + os.environ.setdefault(name, value) + + +def _repo_root(path: str | None) -> Path: + if path: + return Path(path).expanduser().resolve() + return Path(__file__).resolve().parents[5] + + +def _load_text_config(model_path: Path) -> dict: + with (model_path / "config.json").open() as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + return config_dict + + +def _sanitize_reloadable_neuron_config(compiled_path: Path) -> None: + """Keep direct-cast KV quant config reloadable after JSON serialization.""" + config_path = compiled_path / "neuron_config.json" + if not config_path.exists(): + return + + config = json.loads(config_path.read_text()) + neuron_config = config.get("neuron_config", config) + kv_quant_config = neuron_config.get("kv_quant_config") + if not isinstance(kv_quant_config, dict): + return + if not kv_quant_config.get("direct_cast", True): + return + + # Neuron serializes QuantizationType enum defaults as nested JSON objects. + # KVQuantizationConfig expects real enum values on reload, so omit those + # fields and let the constructor restore its per-tensor symmetric defaults. + neuron_config["kv_quant_config"] = {"direct_cast": True} + config_path.write_text(json.dumps(config, indent=2, sort_keys=True) + "\n") + + +def _compiled_parameter_dtype(inf_config) -> torch.dtype: + dtype = getattr(inf_config.neuron_config, "torch_dtype", torch.bfloat16) + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype_name = dtype.removeprefix("torch.") + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + }.get(dtype_name, torch.bfloat16) + return torch.bfloat16 + + +def _hybrid_cache_torch_dtype(value, default: torch.dtype) -> torch.dtype: + if value is None: + return default + if isinstance(value, torch.dtype): + return value + normalized = str(value).lower().removeprefix("torch.") + if normalized in {"fp32", "float32"}: + return torch.float32 + if normalized in {"bf16", "bfloat16"}: + return torch.bfloat16 + return default + + +def _ensure_hybrid_checkpoint_weights(compiled_path: Path, inf_config) -> None: + """Add reloadable zero checkpoint-bank tensors when NxD omits them.""" + gdn_layer_ids = [ + idx + for idx, layer_type in enumerate(getattr(inf_config, "layer_types", ())) + if layer_type == "linear_attention" + ] + weights_dir = compiled_path / "weights" + if not gdn_layer_ids or not weights_dir.exists(): + return + + from safetensors import safe_open # noqa: WPS433 + from safetensors.torch import load_file, save_file # noqa: WPS433 + + tp_degree = int(inf_config.neuron_config.tp_degree) + local_num_value_heads = int(inf_config.linear_num_value_heads) // tp_degree + local_num_key_heads = int(inf_config.linear_num_key_heads) // tp_degree + key_dim = int(inf_config.linear_key_head_dim) + value_dim = int(inf_config.linear_value_head_dim) + slots = int(inf_config.max_gdn_checkpoint_slots) + conv_dim = 2 * local_num_key_heads * key_dim + local_num_value_heads * value_dim + conv_state_len = int(inf_config.linear_conv_kernel_dim) - 1 + default_param_dtype = _compiled_parameter_dtype(inf_config) + recurrent_param_dtype = _hybrid_cache_torch_dtype( + getattr( + inf_config, + "hybrid_recurrent_cache_dtype", + getattr(inf_config, "gdn_recurrent_cache_dtype", None), + ), + torch.float32, + ) + conv_param_dtype = _hybrid_cache_torch_dtype( + getattr( + inf_config, + "hybrid_conv_cache_dtype", + getattr(inf_config, "gdn_conv_cache_dtype", None), + ), + default_param_dtype, + ) + + recurrent_shape = (slots, local_num_value_heads, key_dim, value_dim) + conv_shape = (slots, conv_dim, conv_state_len) + recurrent_keys = [ + f"hybrid_gdn_checkpoint_cache.recurrent_slots.{idx}" + for idx in range(len(gdn_layer_ids)) + ] + conv_keys = [ + f"hybrid_gdn_checkpoint_cache.conv_slots.{idx}" + for idx in range(len(gdn_layer_ids)) + ] + + for shard in sorted(weights_dir.glob("tp*_sharded_checkpoint.safetensors")): + with safe_open(shard, framework="pt", device="cpu") as handle: + existing = set(handle.keys()) + metadata = handle.metadata() + missing_recurrent = [key for key in recurrent_keys if key not in existing] + missing_conv = [key for key in conv_keys if key not in existing] + if not missing_recurrent and not missing_conv: + continue + + tensors = load_file(shard, device="cpu") + for key in missing_recurrent: + tensors[key] = torch.zeros(recurrent_shape, dtype=recurrent_param_dtype) + for key in missing_conv: + tensors[key] = torch.zeros(conv_shape, dtype=conv_param_dtype) + + tmp_path = shard.with_suffix(shard.suffix + ".tmp") + save_file(tensors, tmp_path, metadata=metadata) + os.replace(tmp_path, shard) + print( + "CHECKPOINT_BANK_WEIGHTS_ADDED", + shard.name, + len(missing_recurrent), + len(missing_conv), + str(recurrent_param_dtype), + str(conv_param_dtype), + flush=True, + ) + + +def _parse_int_list(values: list[str] | None) -> list[int] | None: + if values is None: + return None + tokens: list[str] = [] + for value in values: + tokens.extend(value.replace(",", " ").split()) + return [int(token) for token in tokens] + + +def _parse_bucket_pairs(values: list[str] | None) -> list[tuple[int, int]] | None: + if values is None: + return None + pairs: list[tuple[int, int]] = [] + for value in values: + for token in value.replace(",", " ").split(): + if ":" in token: + active, prefix = token.split(":", 1) + elif "x" in token: + active, prefix = token.split("x", 1) + else: + raise ValueError( + "--context-encoding-bucket-pairs entries must use " + f"ACTIVE:PREFIX syntax, got {token!r}" + ) + pairs.append((int(active), int(prefix))) + return pairs + + +def _cte_buckets(args: argparse.Namespace) -> list[int]: + buckets = _parse_int_list(args.cte_buckets) or [args.cte_bucket] + buckets = sorted(set(buckets)) + if not buckets: + raise ValueError("At least one CTE bucket is required") + for bucket in buckets: + if bucket <= 0: + raise ValueError(f"CTE buckets must be positive, got {bucket}") + if bucket % 128 != 0: + raise ValueError( + f"CTE bucket {bucket} is not 128-aligned; DeltaNet CTE uses 128-token chunks" + ) + if buckets[-1] > args.seq_len: + raise ValueError( + f"Largest CTE bucket {buckets[-1]} exceeds --seq-len {args.seq_len}" + ) + return buckets + + +def _prefix_buckets(args: argparse.Namespace, cte_buckets: list[int]) -> list[int]: + buckets = _parse_int_list(args.prefix_buckets) or cte_buckets + buckets = sorted(set(buckets)) + if not buckets: + raise ValueError("At least one prefix bucket is required") + for bucket in buckets: + if bucket <= 0: + raise ValueError(f"Prefix buckets must be positive, got {bucket}") + if bucket % args.block_size != 0: + raise ValueError( + f"Prefix bucket {bucket} must be divisible by block size {args.block_size}" + ) + if buckets[-1] > args.seq_len: + raise ValueError( + f"Largest prefix bucket {buckets[-1]} exceeds --seq-len {args.seq_len}" + ) + return buckets + + +def _context_encoding_bucket_pairs( + args: argparse.Namespace, + cte_buckets: list[int], + prefix_buckets: list[int], +) -> list[list[int]] | None: + raw_pairs = _parse_bucket_pairs(args.context_encoding_bucket_pairs) + if raw_pairs is None: + return None + + cte_bucket_set = set(cte_buckets) + prefix_bucket_set = set(prefix_buckets) + pairs = set() + if not getattr(args, "omit_zero_prefix_pair", False): + pairs.update((cte_bucket, 0) for cte_bucket in cte_buckets) + for active_tokens, prefix_tokens in raw_pairs: + if active_tokens not in cte_bucket_set: + raise ValueError( + "--context-encoding-bucket-pairs active bucket must be present " + f"in --cte-buckets, got {active_tokens} with {cte_buckets}" + ) + if prefix_tokens < 0: + raise ValueError( + "--context-encoding-bucket-pairs prefix bucket must be " + f"non-negative, got {prefix_tokens}" + ) + if prefix_tokens > 0 and prefix_tokens not in prefix_bucket_set: + raise ValueError( + "--context-encoding-bucket-pairs prefix bucket must be 0 or " + f"present in --prefix-buckets, got {prefix_tokens} with " + f"{prefix_buckets}" + ) + pairs.add((active_tokens, prefix_tokens)) + + prefix_order = {0: 0} + prefix_order.update( + {prefix_bucket: index + 1 for index, prefix_bucket in enumerate(prefix_buckets)} + ) + cte_order = {cte_bucket: index for index, cte_bucket in enumerate(cte_buckets)} + return [ + [active_tokens, prefix_tokens] + for active_tokens, prefix_tokens in sorted( + pairs, + key=lambda pair: (cte_order[pair[0]], prefix_order[pair[1]]), + ) + ] + + +def _token_generation_buckets(args: argparse.Namespace) -> list[int]: + buckets = _parse_int_list(args.token_generation_buckets) or [args.seq_len] + buckets = sorted(set(buckets)) + if not buckets: + raise ValueError("At least one token-generation bucket is required") + for bucket in buckets: + if bucket <= 0: + raise ValueError( + f"Token-generation buckets must be positive, got {bucket}" + ) + if bucket > args.seq_len: + raise ValueError( + f"Token-generation bucket {bucket} exceeds --seq-len {args.seq_len}" + ) + return buckets + + +def _token_generation_batches(args: argparse.Namespace) -> list[int] | None: + batches = _parse_int_list(args.token_generation_batches) + if batches is None: + return None + batches = sorted(set(batches)) + if not batches: + raise ValueError("Token-generation batches cannot be empty") + for batch in batches: + if batch <= 0: + raise ValueError( + f"Token-generation batches must be positive, got {batch}" + ) + if batch > args.max_num_seqs: + raise ValueError( + f"Token-generation batch {batch} exceeds --max-num-seqs " + f"{args.max_num_seqs}" + ) + return batches + + +def _weights_to_skip_layout_optimization(args: argparse.Namespace) -> list[str]: + patterns: list[str] = [] + if args.weight_dtype in (_WEIGHT_DTYPE_FP8_MLP_ONLY, _WEIGHT_DTYPE_FP8_FULL): + patterns.extend(_FP8_WLO_SKIP_PATTERNS) + patterns.extend(getattr(args, "weights_to_skip_layout_optimization", None) or []) + return list(dict.fromkeys(patterns)) + + +def _disable_token_generation_wlo(args: argparse.Namespace) -> bool: + return bool(getattr(args, "disable_token_generation_wlo", False)) or ( + os.environ.get(_DISABLE_TOKEN_GENERATION_WLO_ENV) == "1" + ) + + +def _validate_prefix_buckets_fit_context( + args: argparse.Namespace, + max_context_length: int, + prefix_buckets: list[int], +) -> None: + if not (args.enable_prefix_caching or args.enable_hybrid_apc): + return + if prefix_buckets[-1] > max_context_length: + raise ValueError( + f"Largest prefix bucket {prefix_buckets[-1]} exceeds " + f"--max-context-length {max_context_length}. Long-context APC needs " + "--max-context-length to cover the largest reusable prefix bucket." + ) + + +def _max_context_length(args: argparse.Namespace, cte_buckets: list[int]) -> int: + max_context_length = args.max_context_length or cte_buckets[-1] + if max_context_length < cte_buckets[-1]: + raise ValueError( + f"--max-context-length {max_context_length} is smaller than largest " + f"CTE bucket {cte_buckets[-1]}" + ) + if max_context_length > args.seq_len: + raise ValueError( + f"--max-context-length {max_context_length} exceeds --seq-len {args.seq_len}" + ) + return max_context_length + + +def _pa_min_blocks(args: argparse.Namespace) -> int: + return max( + 1, + ((args.seq_len + args.block_size - 1) // args.block_size) + * args.max_num_seqs, + ) + + +def _pa_requested_blocks(args: argparse.Namespace) -> int: + min_blocks = _pa_min_blocks(args) + if args.pa_num_blocks is None: + requested_blocks = min_blocks + max(0, int(args.pa_headroom_blocks)) + else: + requested_blocks = int(args.pa_num_blocks) + if requested_blocks < min_blocks: + raise ValueError( + f"--pa-num-blocks {requested_blocks} is too small for seq_len=" + f"{args.seq_len} and block_size={args.block_size}; need at least {min_blocks}" + ) + return requested_blocks + + +def _pa_num_blocks(args: argparse.Namespace) -> int: + # Keep this value identical to vLLM's --num-gpu-blocks-override / + # NeuronConfig.pa_num_blocks contract. NxDI's BlockKVCacheManager accounts + # for its own reserved internal block when prefix caching is enabled. + return _pa_requested_blocks(args) + + +def _configure_base_compile_work_dir( + compiled_path: Path, + requested_work_dir: str | None, +) -> Path: + if requested_work_dir: + work_dir = Path(requested_work_dir).expanduser().resolve() + else: + existing_work_dir = os.environ.get("BASE_COMPILE_WORK_DIR") + if existing_work_dir: + work_dir = Path(existing_work_dir).expanduser().resolve() + else: + work_dir = (compiled_path.parent / "_nxd_model_workdir").resolve() + + work_dir.mkdir(parents=True, exist_ok=True) + os.environ["BASE_COMPILE_WORK_DIR"] = str(work_dir) + return work_dir + + +def _configure_deltanet_cte_backend(backend: str) -> None: + """Select the DeltaNet CTE implementation used while tracing the artifact.""" + if backend == "env": + return + + for name in _DELTANET_CTE_BACKEND_ENV: + os.environ.pop(name, None) + + if backend == "fused": + os.environ["USE_NKI_FUSED"] = "1" + elif backend == "nki_chunked": + os.environ["USE_NKI_FUSED"] = "0" + os.environ["USE_NKI_CHUNKED"] = "1" + elif backend == "pytorch_chunk": + os.environ["USE_NKI_FUSED"] = "0" + os.environ["USE_PYTORCH_CHUNK"] = "1" + elif backend == "sequential": + os.environ["USE_NKI_FUSED"] = "0" + os.environ["DELTANET_SEQUENTIAL"] = "1" + elif backend == "nki_recurrent": + os.environ["USE_NKI_FUSED"] = "0" + os.environ["USE_NKI"] = "1" + else: + raise ValueError(f"Unsupported DeltaNet CTE backend: {backend}") + + +def _mlp_only_modules_to_not_convert(num_layers: int) -> list[str]: + """Exclude numerically sensitive or unsupported modules from FP8 conversion.""" + modules = [ + "embed_tokens", + "model.embed_tokens", + "lm_head", + "norm", + "model.norm", + "rotary_emb", + "model.rotary_emb", + ] + for layer_idx in range(num_layers): + for prefix in ("layers", "model.layers"): + modules.extend( + [ + f"{prefix}.{layer_idx}.self_attn", + f"{prefix}.{layer_idx}.linear_attn", + f"{prefix}.{layer_idx}.input_layernorm", + f"{prefix}.{layer_idx}.post_attention_layernorm", + ] + ) + return modules + + +def _full_fp8_modules_to_not_convert( + num_layers: int, + *, + quantize_lm_head: bool, + quantize_linear_attn_gates: bool = False, + fp8_exclude_groups: set[str] | None = None, +) -> list[str]: + """Exclude non-linear or sensitive modules from full FP8 conversion. + + This follows the common NVIDIA/vLLM policy: quantize eligible Linear + matmuls, keep lm_head in higher precision unless explicitly requested, and + keep normalization/cache/state tensors unquantized. + """ + fp8_exclude_groups = fp8_exclude_groups or set() + modules = [ + "embed_tokens", + "model.embed_tokens", + "norm", + "model.norm", + "rotary_emb", + "model.rotary_emb", + "mrope_emb", + "model.mrope_emb", + ] + if not quantize_lm_head: + modules.extend(["lm_head", "model.lm_head"]) + + modules.extend( + [ + "hybrid_gdn_checkpoint_cache.recurrent_slots", + "hybrid_gdn_checkpoint_cache.conv_slots", + "model.hybrid_gdn_checkpoint_cache.recurrent_slots", + "model.hybrid_gdn_checkpoint_cache.conv_slots", + ] + ) + + for layer_idx in range(num_layers): + for prefix in ("layers", "model.layers"): + layer_prefix = f"{prefix}.{layer_idx}" + modules.extend( + [ + f"{layer_prefix}.input_layernorm", + f"{layer_prefix}.post_attention_layernorm", + f"{layer_prefix}.self_attn.q_norm", + f"{layer_prefix}.self_attn.k_norm", + f"{layer_prefix}.self_attn.q_layernorm", + f"{layer_prefix}.self_attn.k_layernorm", + f"{layer_prefix}.self_attn.rotary_emb", + f"{layer_prefix}.self_attn.mrope_emb", + f"{layer_prefix}.linear_attn.conv1d", + f"{layer_prefix}.linear_attn.conv1d_weight", + f"{layer_prefix}.linear_attn.A_log", + f"{layer_prefix}.linear_attn.A_log_weight", + f"{layer_prefix}.linear_attn.dt_bias", + f"{layer_prefix}.linear_attn.dt_bias_weight", + f"{layer_prefix}.linear_attn.norm", + f"{layer_prefix}.linear_attn.recurrent_state_buffer", + f"{layer_prefix}.linear_attn.conv_state_buffer", + ] + ) + if not quantize_linear_attn_gates: + modules.extend( + [ + f"{layer_prefix}.linear_attn.in_proj_a", + f"{layer_prefix}.linear_attn.in_proj_b", + f"{layer_prefix}.linear_attn.in_proj_ba", + ] + ) + if "linear_attn" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.linear_attn") + else: + if "linear_attn_qkv" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.linear_attn.in_proj_qkv") + if "linear_attn_z" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.linear_attn.in_proj_z") + if "linear_attn_out_proj" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.linear_attn.out_proj") + if "mlp" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.mlp") + if "self_attn" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.self_attn") + else: + if "self_attn_qkv" in fp8_exclude_groups: + for proj_name in ("q_proj", "k_proj", "v_proj"): + modules.append(f"{layer_prefix}.self_attn.{proj_name}") + if "self_attn_o_proj" in fp8_exclude_groups: + modules.append(f"{layer_prefix}.self_attn.o_proj") + return modules + + +def _quantized_checkpoint_ready(path: Path) -> bool: + if path.is_file(): + return True + if path.is_dir(): + return any(path.iterdir()) + return False + + +def _mlp_layer_idx(name: str) -> int | None: + parts = name.split(".") + if len(parts) < 4: + return None + for idx, part in enumerate(parts[:-3]): + if part == "layers" and idx + 1 < len(parts): + try: + return int(parts[idx + 1]) + except ValueError: + return None + return None + + +def _is_mlp_weight( + name: str, + *, + num_layers: int, + quantize_edge_mlp_layers: bool, +) -> bool: + parts = name.split(".") + if not ( + len(parts) >= 4 + and parts[-3] == "mlp" + and parts[-2] in {"gate_proj", "up_proj", "down_proj"} + and parts[-1] == "weight" + ): + return False + if quantize_edge_mlp_layers: + return True + layer_idx = _mlp_layer_idx(name) + if layer_idx is None: + return True + return layer_idx not in {0, num_layers - 1} + + +def _is_full_fp8_weight( + name: str, + *, + quantize_lm_head: bool, + quantize_linear_attn_gates: bool = False, + fp8_exclude_groups: set[str] | None = None, +) -> bool: + fp8_exclude_groups = fp8_exclude_groups or set() + if not name.endswith(".weight"): + return False + parts = name.split(".") + if len(parts) >= 2 and parts[-2] == "lm_head": + return quantize_lm_head + if len(parts) < 4: + return False + + module_name = parts[-3] + projection_name = parts[-2] + if module_name == "mlp" and "mlp" in fp8_exclude_groups: + return False + if module_name == "self_attn": + if "self_attn" in fp8_exclude_groups: + return False + if projection_name in {"q_proj", "k_proj", "v_proj"} and ( + "self_attn_qkv" in fp8_exclude_groups + ): + return False + if projection_name == "o_proj" and "self_attn_o_proj" in fp8_exclude_groups: + return False + if module_name == "linear_attn": + if "linear_attn" in fp8_exclude_groups: + return False + if projection_name in {"in_proj_a", "in_proj_b"}: + return quantize_linear_attn_gates + if projection_name == "in_proj_qkv" and ( + "linear_attn_qkv" in fp8_exclude_groups + ): + return False + if projection_name == "in_proj_z" and "linear_attn_z" in fp8_exclude_groups: + return False + if ( + projection_name == "out_proj" + and "linear_attn_out_proj" in fp8_exclude_groups + ): + return False + supported_projection_names = { + "mlp": {"gate_proj", "up_proj", "down_proj"}, + "self_attn": {"q_proj", "k_proj", "v_proj", "o_proj"}, + "linear_attn": { + "in_proj_qkv", + "in_proj_z", + "in_proj_a", + "in_proj_b", + "out_proj", + }, + } + return projection_name in supported_projection_names.get(module_name, set()) + + +def _scale_name(weight_name: str) -> str: + return weight_name[: -len(".weight")] + ".weight_scale" + + +def _clear_quantized_checkpoint_dir(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + for child in path.iterdir(): + if child.name.endswith(".safetensors") or child.name.endswith(".json"): + child.unlink() + + +def _save_manual_fp8_state_dict( + model_path: Path, + output_path: Path, + *, + weight_dtype: str, + quantize_edge_mlp_layers: bool, + quantize_lm_head: bool, + quantize_linear_attn_gates: bool = False, + fp8_exclude_groups: set[str] | None = None, +) -> None: + """Create a sharded FP8 checkpoint directly from HF safetensors. + + Loading the HF architecture requires a newer Transformers than the Neuron + venv uses internally. For these FP8 ablations, we do not need model + execution: the checkpoint transform is a direct tensor rewrite. + """ + from safetensors.torch import load_file, save_file # noqa: WPS433 + from neuronx_distributed.quantization.quantization_utils import ( # noqa: WPS433 + quantize_fp8_per_channel, + ) + + num_layers = int(_load_text_config(model_path)["num_hidden_layers"]) + fp8_exclude_groups = fp8_exclude_groups or set() + index_path = model_path / "model.safetensors.index.json" + if index_path.exists(): + with index_path.open() as f: + source_index = json.load(f) + source_weight_map = source_index["weight_map"] + filenames = sorted(set(source_weight_map.values())) + elif (model_path / "model.safetensors").exists(): + source_weight_map = None + filenames = ["model.safetensors"] + else: + raise FileNotFoundError(f"No safetensors checkpoint found in {model_path}") + + _clear_quantized_checkpoint_dir(output_path) + output_weight_map: dict[str, str] = {} + total_size = 0 + quantized_count = 0 + + for filename in filenames: + shard = load_file(str(model_path / filename)) + output_shard = {} + for name, tensor in shard.items(): + if weight_dtype == _WEIGHT_DTYPE_FP8_MLP_ONLY: + should_quantize = _is_mlp_weight( + name, + num_layers=num_layers, + quantize_edge_mlp_layers=quantize_edge_mlp_layers, + ) + elif weight_dtype == _WEIGHT_DTYPE_FP8_FULL: + should_quantize = _is_full_fp8_weight( + name, + quantize_lm_head=quantize_lm_head, + quantize_linear_attn_gates=quantize_linear_attn_gates, + fp8_exclude_groups=fp8_exclude_groups, + ) + else: + raise ValueError(f"Unsupported FP8 weight dtype: {weight_dtype}") + + if should_quantize: + weight, scale = quantize_fp8_per_channel( + tensor, + torch.float8_e4m3fn, + channel_axis=0, + ) + output_shard[name] = weight + output_shard[_scale_name(name)] = scale + output_weight_map[_scale_name(name)] = filename + total_size += weight.numel() * weight.element_size() + total_size += scale.numel() * scale.element_size() + quantized_count += 1 + else: + output_shard[name] = tensor + total_size += tensor.numel() * tensor.element_size() + output_weight_map[name] = filename + + save_file(output_shard, str(output_path / filename), metadata={"format": "pt"}) + del shard + del output_shard + gc.collect() + + if source_weight_map is not None: + with (output_path / "model.safetensors.index.json").open("w") as f: + json.dump( + { + "metadata": {"total_size": total_size}, + "weight_map": output_weight_map, + }, + f, + indent=2, + sort_keys=True, + ) + + print("MANUAL_FP8_WEIGHT_COUNT", quantized_count, flush=True) + + +def _build_config(args: argparse.Namespace): + from neuronx_distributed_inference.models.config import ( # noqa: WPS433 + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig # noqa: WPS433 + + model_path = Path(args.model_path).expanduser().resolve() + config_dict = _load_text_config(model_path) + num_layers = int(config_dict["num_hidden_layers"]) + fp8_exclude_groups = set(getattr(args, "fp8_exclude_groups", []) or []) + if args.weight_dtype == _WEIGHT_DTYPE_FP8_FULL: + modules_to_not_convert = _full_fp8_modules_to_not_convert( + num_layers, + quantize_lm_head=args.quantize_lm_head, + quantize_linear_attn_gates=args.fp8_quantize_linear_attn_gates, + fp8_exclude_groups=fp8_exclude_groups, + ) + else: + modules_to_not_convert = _mlp_only_modules_to_not_convert(num_layers) + if ( + args.weight_dtype == _WEIGHT_DTYPE_FP8_MLP_ONLY + and not args.quantize_edge_mlp_layers + ): + for layer_idx in (0, num_layers - 1): + for prefix in ("layers", "model.layers"): + modules_to_not_convert.append(f"{prefix}.{layer_idx}.mlp") + cte_buckets = _cte_buckets(args) + max_context_length = _max_context_length(args, cte_buckets) + prefix_buckets = _prefix_buckets(args, cte_buckets) + context_encoding_bucket_pairs = _context_encoding_bucket_pairs( + args, + cte_buckets, + prefix_buckets, + ) + token_generation_buckets = _token_generation_buckets(args) + token_generation_batches = _token_generation_batches(args) + _validate_prefix_buckets_fit_context(args, max_context_length, prefix_buckets) + + neuron_config_kwargs = { + "tp_degree": args.tp_degree, + "batch_size": args.max_num_seqs, + "ctx_batch_size": args.ctx_batch_size, + "tkg_batch_size": args.max_num_seqs, + "seq_len": args.seq_len, + "max_context_length": max_context_length, + "max_length": args.seq_len, + "context_encoding_buckets": cte_buckets, + "token_generation_buckets": token_generation_buckets, + "torch_dtype": torch.bfloat16, + "enable_bucketing": len(cte_buckets) > 1 + or len(token_generation_buckets) > 1, + "logical_nc_config": args.logical_nc_config, + "save_sharded_checkpoint": True, + "skip_warmup": args.skip_warmup, + } + if args.async_mode: + neuron_config_kwargs["async_mode"] = True + if token_generation_batches is not None: + neuron_config_kwargs["token_generation_batches"] = token_generation_batches + if ( + args.enable_fused_qkv + or args.enable_qkv_nki_kernels + or args.enable_attn_block_tkg_nki_kernel + ): + neuron_config_kwargs["fused_qkv"] = True + if ( + args.enable_qkv_nki_kernels + or args.enable_attn_block_tkg_nki_kernel + ): + neuron_config_kwargs["qkv_kernel_enabled"] = True + neuron_config_kwargs["qkv_nki_kernel_enabled"] = True + if args.enable_qkv_cte_nki_kernel_fuse_rope: + rope_dim = config_dict.get("rope_dim") + head_dim = config_dict.get("head_dim") + if rope_dim is not None and head_dim is not None and int(rope_dim) != int(head_dim): + raise ValueError( + "--enable-qkv-cte-nki-kernel-fuse-rope is not valid for " + f"partial-RoPE Qwen3.6 configs: rope_dim={rope_dim}, " + f"head_dim={head_dim}. The stock fused-RoPE QKV kernel expects " + "cos/sin to cover the full head dimension." + ) + neuron_config_kwargs["qkv_cte_nki_kernel_fuse_rope"] = True + if args.enable_split_qkv_tkg_nki_kernel: + neuron_config_kwargs["qkv_tkg_nki_kernel_enabled"] = True + if args.enable_attn_block_tkg_nki_kernel: + neuron_config_kwargs["attn_block_tkg_nki_kernel_enabled"] = True + if args.enable_attn_block_tkg_cascaded_attention: + neuron_config_kwargs["attn_block_tkg_nki_kernel_cascaded_attention"] = True + if args.enable_attn_block_tkg_cache_update: + neuron_config_kwargs["attn_block_tkg_nki_kernel_cache_update"] = True + if args.enable_out_proj_nki_kernel: + neuron_config_kwargs["out_proj_kernel_enabled"] = True + if args.enable_mlp_cte_nki_kernel or args.enable_mlp_tkg_nki_kernel: + neuron_config_kwargs["mlp_kernel_enabled"] = True + if args.enable_mlp_tkg_nki_kernel: + neuron_config_kwargs["mlp_tkg_nki_kernel_enabled"] = True + if args.enable_quantized_mlp_kernel: + neuron_config_kwargs["quantized_mlp_kernel_enabled"] = True + if args.enable_k_cache_transposed: + neuron_config_kwargs["k_cache_transposed"] = True + if args.weight_dtype in (_WEIGHT_DTYPE_FP8_MLP_ONLY, _WEIGHT_DTYPE_FP8_FULL): + neuron_config_kwargs.update( + { + "quantized": True, + "quantized_checkpoints_path": str( + Path(args.quantized_checkpoints_path).expanduser().resolve() + ), + "quantization_type": "per_channel_symmetric", + "quantization_dtype": "f8e4m3", + "modules_to_not_convert": modules_to_not_convert, + "kv_cache_quant": False, + "quantized_mlp_kernel_enabled": bool( + args.enable_quantized_mlp_kernel + ), + "activation_quantization_type": None, + } + ) + else: + neuron_config_kwargs["quantized"] = False + wlo_skip_patterns = _weights_to_skip_layout_optimization(args) + if wlo_skip_patterns: + neuron_config_kwargs["weights_to_skip_layout_optimization"] = wlo_skip_patterns + if args.enable_kv_cache_quant: + neuron_config_kwargs["kv_cache_quant"] = True + neuron_config_kwargs["kv_quant_config"] = {"direct_cast": True} + if args.disable_on_device_sampling: + # vLLM/host-side sampling consumes logits from the Neuron trace. Without + # logits, the serving path can only surface placeholder token ids. + neuron_config_kwargs["output_logits"] = True + else: + neuron_config_kwargs["on_device_sampling_config"] = OnDeviceSamplingConfig( + do_sample=False, + top_k=1, + top_p=1.0, + temperature=1.0, + ) + # Qwen's LM head is vocab-sharded when on-device sampling is enabled + # (gather_output=False). The sampler must do distributed argmax/top-k + # across vocab shards instead of sampling only from rank 0's shard. + neuron_config_kwargs["vocab_parallel"] = True + if args.output_logits_with_on_device_sampling: + neuron_config_kwargs["output_logits"] = True + if args.disable_argmax_kernel: + neuron_config_kwargs["disable_argmax_kernel"] = True + if args.disable_context_encoding_argmax_kernel: + neuron_config_kwargs["disable_context_encoding_argmax_kernel"] = True + if args.enable_prefix_caching or args.enable_hybrid_apc or args.enable_vllm_chunked_prefill: + neuron_config_kwargs["is_block_kv_layout"] = True + neuron_config_kwargs["pa_block_size"] = args.block_size + neuron_config_kwargs["pa_num_blocks"] = _pa_num_blocks(args) + if args.enable_prefix_caching or args.enable_hybrid_apc: + neuron_config_kwargs["is_prefix_caching"] = True + neuron_config_kwargs["prefix_buckets"] = prefix_buckets + if context_encoding_bucket_pairs is not None: + neuron_config_kwargs["context_encoding_bucket_pairs"] = ( + context_encoding_bucket_pairs + ) + if args.prefix_cte_attention_chunk_size is not None: + neuron_config_kwargs["prefix_cte_attention_chunk_size"] = ( + args.prefix_cte_attention_chunk_size + ) + neuron_config_kwargs["prefix_cte_attention_backend"] = ( + args.prefix_cte_attention_backend + ) + if args.prefix_cte_attention_segment_size is not None: + neuron_config_kwargs["prefix_cte_attention_segment_size"] = ( + args.prefix_cte_attention_segment_size + ) + if args.enable_vllm_chunked_prefill: + # This flag selects Qwen's custom vLLM/Hybrid APC CTE prefix path. + # Do not set NeuronConfig.chunked_prefill_config here: NxDI's generic + # chunked-prefill feature is still rejected by NeuronBaseForCausalLM. + neuron_config_kwargs["is_block_kv_layout"] = True + + neuron_config = NeuronConfig(**neuron_config_kwargs) + + if args.disable_static_hybrid_cache or args.enable_prefix_caching or args.enable_hybrid_apc: + config_dict["use_hybrid_cache_manager"] = False + else: + config_dict.setdefault("use_hybrid_cache_manager", True) + config_dict["use_hybrid_apc_manager"] = args.enable_hybrid_apc + config_dict["gdn_checkpoint_interval"] = args.gdn_checkpoint_interval + config_dict["max_gdn_checkpoint_slots"] = args.max_gdn_checkpoint_slots + config_dict["gdn_recurrent_cache_dtype"] = args.gdn_recurrent_cache_dtype + config_dict["gdn_conv_cache_dtype"] = args.gdn_conv_cache_dtype + config_dict["hybrid_recurrent_cache_dtype"] = args.gdn_recurrent_cache_dtype + config_dict["hybrid_conv_cache_dtype"] = args.gdn_conv_cache_dtype + config_dict["hybrid_cache_mode"] = args.hybrid_cache_mode + config_dict["hybrid_apc_require_vllm_metadata"] = args.hybrid_apc_require_vllm_metadata + config_dict["hybrid_apc_allow_local_hash_fallback"] = ( + not args.hybrid_apc_require_vllm_metadata + ) + config_dict["hybrid_apc_require_attention_block_refs"] = ( + args.hybrid_apc_require_vllm_metadata + ) + config_dict["hybrid_apc_enable_backed_prefix_reads"] = getattr( + args, + "hybrid_apc_enable_backed_prefix_reads", + False, + ) + config_dict["hybrid_apc_commit_during_token_generation"] = ( + args.hybrid_apc_commit_during_token_generation + ) + config_dict["use_qwen_hybrid_chunked_prefill"] = args.enable_vllm_chunked_prefill + config_dict["use_qwen_hybrid_chunked_prefill_nki"] = args.enable_vllm_chunked_prefill + config_dict["use_qwen_deltanet_decode_nki"] = getattr( + args, "enable_deltanet_decode_nki", False + ) + config_dict["use_text_only_cte_inputs"] = args.text_only_cte + config_dict["use_compact_cte_attention_mask"] = args.compact_cte_attention_mask + config_dict["use_cold_zero_conv_fast_path"] = args.cold_zero_conv_fast_path + config_dict["use_qwen_qk_norm_rope_nki"] = args.enable_qwen_qk_norm_rope_nki_kernel + config_dict["use_qwen_output_gate_nki"] = args.enable_qwen_output_gate_nki_kernel + config_dict["use_qwen_qkv_gate_packed"] = args.enable_qwen_qkv_gate_packed_kernel + config_dict["use_qwen_gated_o_proj_nki"] = args.enable_qwen_gated_o_proj_nki_kernel + config_dict["disable_token_generation_wlo"] = _disable_token_generation_wlo(args) + inf_config = Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict) + return inf_config, modules_to_not_convert + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repo-root", default=None) + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-path", required=True) + parser.add_argument("--quantized-checkpoints-path") + parser.add_argument( + "--base-compile-work-dir", + default=None, + help=( + "NxDI compiler work directory. Defaults next to --compiled-path " + "instead of /tmp so large compiles do not fill the root volume." + ), + ) + parser.add_argument( + "--weight-dtype", + choices=[ + _WEIGHT_DTYPE_FP8_MLP_ONLY, + _WEIGHT_DTYPE_FP8_FULL, + _WEIGHT_DTYPE_BF16_CONTROL, + ], + default=_WEIGHT_DTYPE_FP8_MLP_ONLY, + help=( + "Weight mode to compile. Use bf16_control for the non-FP8 " + "host-logits real-token control." + ), + ) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument( + "--max-context-length", + type=int, + default=None, + help=( + "Maximum total context length in NeuronConfig. Defaults to the " + "largest CTE bucket; set higher when chunked prefill serves long " + "contexts with smaller active chunks." + ), + ) + parser.add_argument("--cte-bucket", type=int, default=512) + parser.add_argument("--cte-buckets", nargs="+", default=None) + parser.add_argument("--prefix-buckets", nargs="+", default=None) + parser.add_argument( + "--context-encoding-bucket-pairs", + nargs="+", + default=None, + help=( + "Optional sparse context-encoding 2D buckets as ACTIVE:PREFIX " + "pairs. Prefix 0 pairs for every CTE bucket are added " + "automatically unless --omit-zero-prefix-pair is set." + ), + ) + parser.add_argument( + "--omit-zero-prefix-pair", + action="store_true", + help=( + "Do not automatically add ACTIVE:0 dense context-encoding pairs. " + "Use for long-prefix fallback artifacts that should not load the " + "dense cold-prefill NEFF." + ), + ) + parser.add_argument("--token-generation-buckets", nargs="+", default=None) + parser.add_argument("--token-generation-batches", nargs="+", default=None) + parser.add_argument( + "--disable-token-generation-wlo", + action="store_true", + help=( + "Disable NxDI token-generation weight layout optimization. Use this " + "when the generated layout_opt graph fails runtime validation." + ), + ) + parser.add_argument( + "--weights-to-skip-layout-optimization", + nargs="+", + default=None, + help=( + "Regex patterns for checkpoint tensors that must not go through " + "weight layout optimization. FP8 modes always add Qwen3.6-safe " + "defaults for per-channel scale tensors and the tiny DeltaNet " + "conv1d weight." + ), + ) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--pa-num-blocks", type=int, default=None) + parser.add_argument( + "--pa-headroom-blocks", + type=int, + default=0, + help=( + "Extra usable PA blocks above the minimum seq_len/max_num_seqs " + "capacity. Ignored when --pa-num-blocks is set. The final value is " + "the NeuronConfig.pa_num_blocks value and should match vLLM " + "--num-gpu-blocks-override." + ), + ) + parser.add_argument("--tp-degree", type=int, default=4) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--max-num-seqs", type=int, default=1) + parser.add_argument("--ctx-batch-size", type=int, default=1) + parser.add_argument("--skip-warmup", action="store_true") + parser.add_argument("--async-mode", action="store_true") + parser.add_argument("--enable-prefix-caching", action="store_true") + parser.add_argument("--enable-hybrid-apc", action="store_true") + parser.add_argument("--enable-vllm-chunked-prefill", action="store_true") + parser.add_argument( + "--text-only-cte", + action=argparse.BooleanOptionalAction, + default=True, + ) + parser.add_argument( + "--compact-cte-attention-mask", + action=argparse.BooleanOptionalAction, + default=True, + ) + parser.add_argument( + "--cold-zero-conv-fast-path", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Trace the DeltaNet conv path for cold context encoding that always " + "starts at position 0. Do not use for APC or partial-prefix suffix CTE." + ), + ) + parser.add_argument( + "--enable-deltanet-decode-nki", + action="store_true", + help=( + "Trace token generation with the stateful one-token DeltaNet NKI " + "decode step instead of the default Torch/XLA recurrent step." + ), + ) + parser.add_argument( + "--deltanet-cte-backend", + choices=[ + "env", + "fused", + "nki_chunked", + "pytorch_chunk", + "sequential", + "nki_recurrent", + ], + default="env", + help=( + "DeltaNet CTE backend to force during tracing. The default preserves " + "the caller's USE_NKI_* environment. Use nki_chunked or " + "pytorch_chunk to compile controls for fused-CTE NaNs." + ), + ) + parser.add_argument("--disable-on-device-sampling", action="store_true") + parser.add_argument( + "--disable-argmax-kernel", + action="store_true", + help=( + "Use the non-custom distributed argmax path for on-device greedy " + "sampling. This is slower but avoids the NKI argmax output path " + "when validating sampled-token correctness." + ), + ) + parser.add_argument( + "--disable-context-encoding-argmax-kernel", + action="store_true", + help=( + "Use the non-custom distributed argmax path only for context-encoding " + "greedy sampling. Token generation keeps the configured argmax path, " + "which limits decode-performance impact while isolating prefill " + "sampled-token correctness." + ), + ) + parser.add_argument( + "--output-logits-with-on-device-sampling", + action="store_true", + help=( + "Debug mode: keep on-device greedy sampling enabled but also return " + "logits from the trace so sampled token IDs can be compared with " + "host argmax." + ), + ) + parser.add_argument("--kernel-q-tile-size", type=int, default=128) + parser.add_argument("--kernel-kv-tile-size", type=int, default=1024) + parser.add_argument( + "--enable-fused-qkv", + action="store_true", + help=( + "Fuse Q/K/V projection weights in the NxDI attention module. This " + "is required by the QKV NKI kernels and the block TKG decode kernel." + ), + ) + parser.add_argument( + "--enable-qkv-nki-kernels", + action="store_true", + help=( + "Enable NxDI QKV kernels required by the block token-generation " + "attention kernel." + ), + ) + parser.add_argument( + "--enable-qkv-cte-nki-kernel-fuse-rope", + action="store_true", + help=( + "Pass CTE RoPE cos/sin into the NxDI QKV NKI kernel so Q/K RoPE " + "is fused into the projection kernel. For Qwen3.6 this must be " + "validated with partial-RoPE coverage before compiling a perf artifact." + ), + ) + parser.add_argument( + "--enable-qwen-qk-norm-rope-nki-kernel", + action="store_true", + help=( + "Use the Qwen3.6-specific NKI kernel that fuses Q/K per-head " + "RMSNorm with partial RoPE during multi-token context encoding." + ), + ) + parser.add_argument( + "--enable-qwen-output-gate-nki-kernel", + action="store_true", + help=( + "Use the Qwen3.6-specific output-gate projection path that routes " + "the multi-token attention gate matmul through the NKI QKV CTE " + "projection kernel." + ), + ) + parser.add_argument( + "--enable-qwen-qkv-gate-packed-kernel", + action="store_true", + help=( + "Use the Qwen3.6-specific packed QKV+gate projection path. This " + "packs full-attention Wqkv as [Q | output_gate | K | V] and " + "splits the gate from the QKV NKI output instead of running a " + "separate output_gate_proj." + ), + ) + parser.add_argument( + "--enable-qwen-gated-o-proj-nki-kernel", + action="store_true", + help=( + "Use the Qwen3.6-specific ROW FP8 output-projection kernel that " + "applies sigmoid(output_gate) to attention output inside the " + "projection kernel for multi-token context encoding." + ), + ) + parser.add_argument( + "--enable-split-qkv-tkg-nki-kernel", + action="store_true", + help=( + "Enable Qwen's split Q/K/V token-generation NKI projection path. " + "This is TKG-only and intentionally does not enable fused_qkv or " + "the stock QKV CTE wrapper." + ), + ) + parser.add_argument( + "--enable-attn-block-tkg-nki-kernel", + action="store_true", + help=( + "Enable the NxDI token-generation attention NKI kernel for block " + "KV layout. This targets decode speed when prefix caching is used." + ), + ) + parser.add_argument( + "--enable-attn-block-tkg-cascaded-attention", + action="store_true", + help=( + "Enable cascaded attention for the block token-generation NKI " + "attention kernel." + ), + ) + parser.add_argument( + "--enable-attn-block-tkg-cache-update", + action="store_true", + help=( + "Update KV cache inside the block token-generation attention " + "kernel instead of through the separate cache update path." + ), + ) + parser.add_argument( + "--enable-out-proj-nki-kernel", + action="store_true", + help=( + "Enable NxDI's NKI output-projection kernel for attention output. " + "Block TKG enables this internally; this flag exposes it for " + "non-block-TKG decode experiments." + ), + ) + parser.add_argument( + "--enable-mlp-tkg-nki-kernel", + action="store_true", + help=( + "Use NxDI/NKILib's MLP kernel for token generation. The Qwen3.6 " + "custom decoder keeps this behind a flag because it changes the " + "dense FFN lowering path." + ), + ) + parser.add_argument( + "--enable-mlp-cte-nki-kernel", + action="store_true", + help=( + "Use NxDI/NKILib's MLP kernel for context encoding. This targets " + "cold-prefill dense SwiGLU cost and keeps Qwen CTE RMSNorm on the " + "separate high-precision path before FP8 GEMM quantization." + ), + ) + parser.add_argument( + "--enable-quantized-mlp-kernel", + action="store_true", + help=( + "Enable the quantized FP8 MLP kernel path. Pair this with " + "--enable-mlp-tkg-nki-kernel for FP8 full-weight decode " + "experiments." + ), + ) + parser.add_argument( + "--enable-k-cache-transposed", + action="store_true", + help=( + "Store the K cache in the transposed layout used by the Neuron " + "decode attention path. Best paired with block TKG cache update." + ), + ) + parser.add_argument( + "--enable-kv-cache-quant", + action="store_true", + help=( + "Use the NxDI FP8 direct-cast KV cache quantization path to reduce " + "decode KV-cache HBM traffic." + ), + ) + parser.add_argument( + "--prefix-cte-attention-chunk-size", + type=int, + default=None, + help=( + "When set, long prefix-cache CTE attention streams cached prefix KV " + "in chunks of this size using online softmax instead of compiling " + "one monolithic [active_tokens, prefix_tokens] attention score " + "tensor. This is intended for 256K prefix buckets that exceed " + "Neuron HBM scratchpad when compiled as a single prefix-attention " + "shape." + ), + ) + parser.add_argument( + "--prefix-cte-attention-backend", + choices=["attention_cte", "segmented_cte"], + default="attention_cte", + help=( + "Prefix-cache CTE attention backend. attention_cte is the existing " + "flat-prior kernel. segmented_cte uses the Neuron 2.30 block-KV " + "segmented CTE kernel to stream long cached prefixes by segment." + ), + ) + parser.add_argument( + "--prefix-cte-attention-segment-size", + type=int, + default=None, + help=( + "Prior segment size for --prefix-cte-attention-backend segmented_cte. " + "Must be positive and divisible by --block-size." + ), + ) + parser.add_argument("--disable-static-hybrid-cache", action="store_true") + parser.add_argument("--gdn-checkpoint-interval", type=int, default=256) + parser.add_argument("--max-gdn-checkpoint-slots", type=int, default=8) + parser.add_argument("--gdn-recurrent-cache-dtype", default="float32") + parser.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + parser.add_argument("--hybrid-cache-mode", default="all") + parser.add_argument("--hybrid-apc-require-vllm-metadata", action="store_true") + parser.add_argument( + "--hybrid-apc-enable-backed-prefix-reads", + action=argparse.BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--hybrid-apc-commit-during-token-generation", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Keep the legacy Hybrid APC checkpoint-bank commit outputs on " + "token generation traces. The default commits checkpoint banks only " + "during context encoding." + ), + ) + parser.add_argument( + "--quantize-edge-mlp-layers", + action="store_true", + help=( + "Quantize layer-0 and final-layer MLP weights too. By default they " + "stay BF16, matching the AWS Trn2 FP8 tutorial's conservative " + "edge-layer policy." + ), + ) + parser.add_argument( + "--quantize-lm-head", + action="store_true", + help=( + "Also quantize lm_head in fp8_full mode. Default keeps lm_head BF16, " + "matching common NVIDIA/vLLM FP8 policy." + ), + ) + parser.add_argument( + "--fp8-quantize-linear-attn-gates", + action="store_true", + help=( + "Use the older coherent FP8 policy for Qwen3.6 linear-attention " + "gate projections: leave in_proj_a/in_proj_b out of " + "modules_to_not_convert and manually quantize their weights to " + "FP8. This is an isolation flag because public Qwen FP8 configs " + "usually keep gate/control projections higher precision." + ), + ) + parser.add_argument( + "--fp8-exclude-groups", + nargs="*", + choices=sorted(_FP8_EXCLUDE_GROUPS), + default=[], + help=( + "Extra fp8_full module groups to leave BF16 for targeted coherence " + "isolation. Useful values are linear_attn, mlp, self_attn, and the " + "finer-grained linear_attn_qkv/linear_attn_z/linear_attn_out_proj/" + "self_attn_qkv/self_attn_o_proj groups." + ), + ) + parser.add_argument("--force-quantize", action="store_true") + parser.add_argument("--quantize-only", action="store_true") + parser.add_argument( + "--postprocess-only", + action="store_true", + help=( + "Run post-compile artifact fixes on an existing compiled-path " + "without regenerating FP8 checkpoints or invoking model.compile(). " + "Useful after an interrupted checkpoint-bank insertion." + ), + ) + parser.add_argument("--load-after-compile", action="store_true") + args = parser.parse_args() + if ( + args.weight_dtype in (_WEIGHT_DTYPE_FP8_MLP_ONLY, _WEIGHT_DTYPE_FP8_FULL) + and not args.quantized_checkpoints_path + ): + parser.error("--quantized-checkpoints-path is required for FP8 weight modes") + if args.quantize_lm_head and args.weight_dtype != _WEIGHT_DTYPE_FP8_FULL: + parser.error("--quantize-lm-head is only valid with --weight-dtype fp8_full") + if args.enable_split_qkv_tkg_nki_kernel and ( + args.enable_fused_qkv + or args.enable_qkv_nki_kernels + or args.enable_attn_block_tkg_nki_kernel + ): + parser.error( + "--enable-split-qkv-tkg-nki-kernel cannot be combined with " + "--enable-fused-qkv, --enable-qkv-nki-kernels, " + "or --enable-attn-block-tkg-nki-kernel" + ) + if ( + args.context_encoding_bucket_pairs is not None + and not (args.enable_prefix_caching or args.enable_hybrid_apc) + ): + parser.error( + "--context-encoding-bucket-pairs requires prefix caching or Hybrid APC" + ) + if args.max_num_seqs <= 0: + parser.error("--max-num-seqs must be positive") + if args.ctx_batch_size <= 0: + parser.error("--ctx-batch-size must be positive") + if args.pa_headroom_blocks < 0: + parser.error("--pa-headroom-blocks must be non-negative") + if args.pa_num_blocks is not None and args.pa_headroom_blocks: + parser.error("--pa-headroom-blocks cannot be combined with --pa-num-blocks") + if ( + args.prefix_cte_attention_chunk_size is not None + and args.prefix_cte_attention_chunk_size <= 0 + ): + parser.error("--prefix-cte-attention-chunk-size must be positive") + if ( + args.prefix_cte_attention_segment_size is not None + and args.prefix_cte_attention_segment_size <= 0 + ): + parser.error("--prefix-cte-attention-segment-size must be positive") + if ( + args.prefix_cte_attention_segment_size is not None + and args.prefix_cte_attention_segment_size % args.block_size != 0 + ): + parser.error( + "--prefix-cte-attention-segment-size must be divisible by --block-size" + ) + if args.enable_hybrid_apc and args.gdn_checkpoint_interval != args.block_size: + parser.error( + "--enable-hybrid-apc v0 requires --gdn-checkpoint-interval to " + "equal --block-size" + ) + + repo = _repo_root(args.repo_root) + contrib_model_dir = repo / "contrib" / "models" / "Qwen3.6-27B" + sys.path.insert(0, str(repo / "src")) + sys.path.insert(0, str(repo)) + sys.path.insert(0, str(contrib_model_dir)) + if args.weight_dtype in (_WEIGHT_DTYPE_FP8_MLP_ONLY, _WEIGHT_DTYPE_FP8_FULL): + _ensure_fp8_environment() + _configure_deltanet_cte_backend(args.deltanet_cte_backend) + + from src.modeling_qwen35 import NeuronQwen35ForCausalLM # noqa: WPS433 + + model_path = Path(args.model_path).expanduser().resolve() + compiled_path = Path(args.compiled_path).expanduser().resolve() + quantized_path = ( + Path(args.quantized_checkpoints_path).expanduser().resolve() + if args.quantized_checkpoints_path + else None + ) + base_compile_work_dir = _configure_base_compile_work_dir( + compiled_path, + args.base_compile_work_dir, + ) + + inf_config, modules_to_not_convert = _build_config(args) + + print("WEIGHT_DTYPE_MODE", args.weight_dtype, flush=True) + if args.weight_dtype == _WEIGHT_DTYPE_FP8_MLP_ONLY: + print("FP8_MODE mlp_only", flush=True) + elif args.weight_dtype == _WEIGHT_DTYPE_FP8_FULL: + print("FP8_MODE full", flush=True) + else: + print("FP8_MODE disabled_bf16_control", flush=True) + print("QUANTIZE_LM_HEAD", bool(args.quantize_lm_head), flush=True) + print( + "FP8_QUANTIZE_LINEAR_ATTN_GATES", + bool(args.fp8_quantize_linear_attn_gates), + flush=True, + ) + print( + "FP8_EXCLUDE_GROUPS", + ",".join(sorted(set(args.fp8_exclude_groups))) or "none", + flush=True, + ) + print("MODEL_PATH", str(model_path), flush=True) + print("COMPILED_PATH", str(compiled_path), flush=True) + print("BASE_COMPILE_WORK_DIR", str(base_compile_work_dir), flush=True) + print("DELTANET_CTE_BACKEND", args.deltanet_cte_backend, flush=True) + for env_name in sorted(_DELTANET_CTE_BACKEND_ENV): + print(env_name, os.environ.get(env_name), flush=True) + if quantized_path is not None: + print("QUANTIZED_CHECKPOINTS_PATH", str(quantized_path), flush=True) + for env_name in _FP8_ENV_DEFAULTS: + print(env_name, os.environ.get(env_name), flush=True) + print( + "WEIGHTS_TO_SKIP_LAYOUT_OPTIMIZATION", + json.dumps(inf_config.neuron_config.weights_to_skip_layout_optimization), + flush=True, + ) + print( + "DISABLE_TOKEN_GENERATION_WLO", + bool(inf_config.disable_token_generation_wlo), + flush=True, + ) + print("MODULES_TO_NOT_CONVERT_COUNT", len(modules_to_not_convert), flush=True) + print( + "CONTEXT_TRACE_SHAPE", + json.dumps( + { + "seq_len": args.seq_len, + "max_context_length": _max_context_length(args, _cte_buckets(args)), + "context_encoding_buckets": _cte_buckets(args), + "prefix_buckets": _prefix_buckets(args, _cte_buckets(args)), + "prefix_cte_attention_backend": args.prefix_cte_attention_backend, + "prefix_cte_attention_segment_size": ( + args.prefix_cte_attention_segment_size + ), + "prefix_cte_attention_chunk_size": args.prefix_cte_attention_chunk_size, + "context_encoding_bucket_pairs": _context_encoding_bucket_pairs( + args, + _cte_buckets(args), + _prefix_buckets(args, _cte_buckets(args)), + ), + "token_generation_buckets": _token_generation_buckets(args), + "token_generation_batches": _token_generation_batches(args), + "max_num_seqs": args.max_num_seqs, + "ctx_batch_size": args.ctx_batch_size, + "tkg_batch_size": args.max_num_seqs, + "async_mode": args.async_mode, + "skip_warmup": args.skip_warmup, + "enable_prefix_caching": args.enable_prefix_caching, + "enable_hybrid_apc": args.enable_hybrid_apc, + "enable_vllm_chunked_prefill": args.enable_vllm_chunked_prefill, + "enable_deltanet_decode_nki": args.enable_deltanet_decode_nki, + "enable_fused_qkv": args.enable_fused_qkv, + "enable_qkv_nki_kernels": args.enable_qkv_nki_kernels, + "enable_qwen_qk_norm_rope_nki_kernel": ( + args.enable_qwen_qk_norm_rope_nki_kernel + ), + "enable_qwen_qkv_gate_packed_kernel": ( + args.enable_qwen_qkv_gate_packed_kernel + ), + "enable_qwen_gated_o_proj_nki_kernel": ( + args.enable_qwen_gated_o_proj_nki_kernel + ), + "enable_split_qkv_tkg_nki_kernel": ( + args.enable_split_qkv_tkg_nki_kernel + ), + "enable_attn_block_tkg_nki_kernel": ( + args.enable_attn_block_tkg_nki_kernel + ), + "enable_attn_block_tkg_cascaded_attention": ( + args.enable_attn_block_tkg_cascaded_attention + ), + "enable_attn_block_tkg_cache_update": ( + args.enable_attn_block_tkg_cache_update + ), + "enable_out_proj_nki_kernel": args.enable_out_proj_nki_kernel, + "enable_mlp_cte_nki_kernel": args.enable_mlp_cte_nki_kernel, + "enable_mlp_tkg_nki_kernel": args.enable_mlp_tkg_nki_kernel, + "enable_quantized_mlp_kernel": args.enable_quantized_mlp_kernel, + "enable_k_cache_transposed": args.enable_k_cache_transposed, + "enable_kv_cache_quant": args.enable_kv_cache_quant, + "fp8_quantize_linear_attn_gates": bool( + args.fp8_quantize_linear_attn_gates + ), + "block_size": args.block_size, + "pa_min_blocks": _pa_min_blocks(args), + "pa_requested_blocks": _pa_requested_blocks(args), + "pa_usable_headroom_blocks": ( + _pa_requested_blocks(args) - _pa_min_blocks(args) + ), + "pa_headroom_blocks": ( + _pa_requested_blocks(args) - _pa_min_blocks(args) + ), + "pa_num_blocks": _pa_num_blocks(args), + "gdn_checkpoint_interval": args.gdn_checkpoint_interval, + "max_gdn_checkpoint_slots": args.max_gdn_checkpoint_slots, + "hybrid_apc_commit_during_token_generation": ( + args.hybrid_apc_commit_during_token_generation + ), + }, + sort_keys=True, + ), + flush=True, + ) + + if args.postprocess_only: + if not compiled_path.exists(): + raise FileNotFoundError(f"--postprocess-only missing artifact: {compiled_path}") + print("POSTPROCESS_ONLY_START", flush=True) + _ensure_hybrid_checkpoint_weights(compiled_path, inf_config) + _sanitize_reloadable_neuron_config(compiled_path) + print("COMPILE_DONE", flush=True) + return 0 + + if args.weight_dtype == _WEIGHT_DTYPE_BF16_CONTROL: + print("QUANTIZE_SKIP bf16_control", flush=True) + elif args.force_quantize or not _quantized_checkpoint_ready(quantized_path): + print("QUANTIZE_START manual_fp8", flush=True) + _save_manual_fp8_state_dict( + model_path, + quantized_path, + weight_dtype=args.weight_dtype, + quantize_edge_mlp_layers=args.quantize_edge_mlp_layers, + quantize_lm_head=args.quantize_lm_head, + quantize_linear_attn_gates=args.fp8_quantize_linear_attn_gates, + fp8_exclude_groups=set(args.fp8_exclude_groups), + ) + print("QUANTIZE_DONE", flush=True) + else: + print("QUANTIZE_SKIP existing checkpoint found", flush=True) + + if args.quantize_only: + return 0 + + print("COMPILE_START", flush=True) + model = NeuronQwen35ForCausalLM(str(model_path), inf_config) + model.compile(str(compiled_path)) + _ensure_hybrid_checkpoint_weights(compiled_path, inf_config) + _sanitize_reloadable_neuron_config(compiled_path) + del model + gc.collect() + print("COMPILE_DONE", flush=True) + + if args.load_after_compile: + model = NeuronQwen35ForCausalLM(str(compiled_path)) + model.load(str(compiled_path)) + print("LOAD_AFTER_COMPILE_OK", flush=True) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/test/integration/test_model.py b/contrib/models/Qwen3.6-27B/test/integration/test_model.py new file mode 100644 index 00000000..b1128c12 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/integration/test_model.py @@ -0,0 +1,605 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration tests for Qwen3.6-27B on Neuron. + +Tests compilation, loading, inference accuracy, and performance using +the full 27B model with pre-downloaded HuggingFace weights on a trn2 instance. + +Qwen3.6-27B shares identical architecture with Qwen3.5-27B (qwen3_5 model_type). +These tests use the same Qwen35* classes and QWEN35_* env vars because the +underlying code is shared. + +Note: A mini model option is not provided because DeltaNet layers require NKI +kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA +architecture needs at least TP=4 for the full model to fit in HBM. + +Environment variables: + QWEN35_MODEL_PATH Path to HF model weights (required) + QWEN35_COMPILED_PATH Path to compiled artifacts (default: /tmp/qwen35_27b_traced) + QWEN35_TP_DEGREE Tensor parallelism degree (default: 4) + QWEN35_SEQ_LEN Max sequence length (default: 128) + TTFT_THRESHOLD_MS Max TTFT in ms (default: 5000) + THROUGHPUT_THRESHOLD Min throughput in tok/s (default: 5.0) + +Prerequisites: + - trn2.3xlarge or larger with TP >= 4 NeuronCores available + - NXDI installed (neuronx_distributed_inference) + - HuggingFace weights downloaded to QWEN35_MODEL_PATH + - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels) + +Usage: + # Full model (trn2.3xlarge, TP=4): + QWEN35_MODEL_PATH=/mnt/models/Qwen3.6-27B \\ + QWEN35_COMPILED_PATH=/mnt/models/qwen36_traced \\ + pytest test/integration/test_model.py --capture=tee-sys +""" + +import gc +import json +import os +import shutil +import subprocess +import sys +import time + +import pytest +import torch + +# Ensure the contrib root (Qwen3.6-27B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +# ── Configuration from environment ────────────────────────────────────── + +MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "") +COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_27b_traced") +TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4")) +SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) +TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) +THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) +USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1" +RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1" + +requires_model_path = pytest.mark.skipif( + not MODEL_PATH, + reason=( + "QWEN35_MODEL_PATH not set. Integration tests require the full 27B model " + "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.6-27B to run these tests." + ), +) +requires_hbm_recording = pytest.mark.skipif( + not RECORD_HBM, + reason=( + "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM " + "usage for dummy-KV vs hybrid-cache comparisons." + ), +) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def model_path(): + """Return path to model weights.""" + return MODEL_PATH + + +@pytest.fixture(scope="module") +def compiled_model(model_path): + """Compile and load the model on Neuron.""" + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + # Read config.json directly (model_type 'qwen3_5' may not be in + # AutoConfig registry for all transformers versions) + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + + inf_config = Qwen35InferenceConfig( + neuron_config=neuron_config, + use_hybrid_cache_manager=USE_HYBRID_CACHE, + **config_dict, + ) + + # Compile if no existing artifacts + compiled_path = COMPILED_PATH + neff_path = os.path.join(compiled_path, "model.pt") + if not os.path.exists(neff_path): + print(f"Compiling to {compiled_path}...") + model = NeuronQwen35ForCausalLM(model_path, inf_config) + model.compile(compiled_path) + del model + gc.collect() + + # Load + print(f"Loading from {compiled_path}...") + model = NeuronQwen35ForCausalLM(compiled_path) + model.load(compiled_path) + return model + + +@pytest.fixture(scope="module") +def tokenizer(model_path): + """Load tokenizer.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_path, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +@pytest.fixture(scope="module") +def generation_config(tokenizer): + """Create generation config.""" + from transformers import GenerationConfig + + return GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + +def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20): + """Generate text using the NXDI model.""" + import transformers + + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + ) + + inputs = tokenizer(prompt, padding=True, return_tensors="pt") + gen_model = HuggingFaceGenerationAdapter(model) + gen_model.generation_config.transformers_version = transformers.__version__ + generation_config.transformers_version = transformers.__version__ + outputs = gen_model.generate( + inputs.input_ids, + generation_config=generation_config, + attention_mask=inputs.attention_mask, + max_new_tokens=max_new_tokens, + ) + return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True) + + +def _is_repetitive(text, max_repeat=5): + """Check for excessive word repetition.""" + words = text.split() + if len(words) < max_repeat: + return False + for i in range(len(words) - max_repeat + 1): + if len(set(words[i : i + max_repeat])) == 1: + return True + return False + + +def _parse_peak_neuron_memory(stdout): + peak_device = 0 + peak_tensors = 0 + samples = 0 + for line in stdout.splitlines(): + line = line.strip() + if not line: + continue + try: + report = json.loads(line) + except json.JSONDecodeError: + continue + for runtime in report.get("neuron_runtime_data", []): + memory_used = runtime.get("report", {}).get("memory_used", {}) + used = memory_used.get("neuron_runtime_used_bytes", {}) + peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0)) + nc_usage = ( + used.get("usage_breakdown", {}).get("neuroncore_memory_usage", {}) + ) + tensor_bytes = sum( + int(core.get("tensors", 0) or 0) for core in nc_usage.values() + ) + peak_tensors = max(peak_tensors, tensor_bytes) + samples += 1 + return peak_device, peak_tensors, samples + + +def _capture_neuron_hbm(tmp_path, fn): + if shutil.which("neuron-monitor") is None: + pytest.skip("neuron-monitor is not available") + + monitor_config = { + "period": "0.5s", + "neuron_runtimes": [ + { + "tag_filter": ".*", + "metrics": [{"type": "memory_used", "period": "0.5s"}], + } + ], + } + config_path = tmp_path / "neuron-monitor.json" + config_path.write_text(json.dumps(monitor_config)) + + proc = subprocess.Popen( + ["neuron-monitor", "--config-file", str(config_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + try: + time.sleep(1.0) + result = fn() + time.sleep(1.0) + finally: + proc.terminate() + try: + stdout, stderr = proc.communicate(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + stdout, stderr = proc.communicate(timeout=5) + + peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout) + assert samples > 0, f"neuron-monitor produced no runtime samples: {stderr}" + assert peak_device > 0, "Expected non-zero Neuron device HBM usage" + return result, peak_device, peak_tensors, samples + + +# ── Smoke Tests ───────────────────────────────────────────────────────── + + +@requires_model_path +def test_model_loads(compiled_model): + """Model compiles and loads successfully.""" + assert compiled_model is not None + assert hasattr(compiled_model, "neuron_config") + print(" Model loaded successfully") + + +@requires_model_path +def test_model_generates(compiled_model, tokenizer, generation_config): + """Model generates at least 5 tokens.""" + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello, I am a language model", + max_new_tokens=20, + ) + input_len = len(tokenizer.encode("Hello, I am a language model")) + new_tokens = len(tokens) - input_len + assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}" + print(f" Generated {new_tokens} tokens: {text[:100]}...") + + +# ── Accuracy Tests ────────────────────────────────────────────────────── + + +@requires_model_path +def test_output_coherence(compiled_model, tokenizer, generation_config): + """Output should contain multiple words and not be excessively repetitive.""" + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + words = generated.split() + assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{generated}'" + assert not _is_repetitive(generated), ( + f"Output is excessively repetitive: '{generated}'" + ) + print(f" Output coherent ({len(words)} words): {generated[:80]}...") + + +@requires_model_path +def test_top_token_valid(compiled_model, tokenizer, generation_config): + """First generated token should be a valid decodable token.""" + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello!", + max_new_tokens=1, + ) + input_len = len(tokenizer.encode("Hello!")) + first_new = tokens[input_len] + assert 0 <= first_new < len(tokenizer), ( + f"Token {first_new} out of vocab range" + ) + decoded = tokenizer.decode([first_new]) + assert len(decoded) > 0, f"Token {first_new} decoded to empty string" + print(f" First token: {first_new} -> '{decoded}'") + + +@requires_model_path +def test_olympics_prompt_no_invalid_tokens( + compiled_model, tokenizer, generation_config +): + """Regression test for NaN logits producing the int32-min token id.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=32, + ) + input_len = len(tokenizer.encode(prompt)) + generated = tokens[input_len:] + invalid = [token for token in generated if token < 0 or token >= len(tokenizer)] + + assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {generated}" + assert not invalid, f"Generated invalid token ids: {invalid}" + + +@requires_model_path +def test_capital_of_france(compiled_model, tokenizer, generation_config): + """'The capital of France is' should produce 'Paris' in the response.""" + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + assert "paris" in generated.lower(), ( + f"Expected 'Paris' in output, got: '{generated}'" + ) + print(f" Capital of France: {generated}") + + +# ── Performance Tests ─────────────────────────────────────────────────── + + +@requires_model_path +def test_performance_ttft(compiled_model, tokenizer, generation_config): + """Time to first token should be within threshold.""" + prompt = "Hello, I am a language model" + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1) + + # Measure + times = [] + for _ in range(3): + t0 = time.perf_counter() + _generate( + compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1 + ) + times.append((time.perf_counter() - t0) * 1000) + + avg_ms = sum(times) / len(times) + print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)") + assert avg_ms < TTFT_THRESHOLD_MS, ( + f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms" + ) + + +@requires_model_path +def test_performance_throughput(compiled_model, tokenizer, generation_config): + """Throughput should meet minimum threshold.""" + prompt = "Once upon a time" + num_new_tokens = 20 + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5) + + # Measure + t0 = time.perf_counter() + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=num_new_tokens, + ) + elapsed = time.perf_counter() - t0 + + input_len = len(tokenizer.encode(prompt)) + actual_new = len(tokens) - input_len + throughput = actual_new / elapsed if elapsed > 0 else 0 + + print( + f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)" + ) + print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s") + assert throughput > THROUGHPUT_THRESHOLD, ( + f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}" + ) + + +@requires_model_path +@requires_hbm_recording +def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): + """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." + max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) + + (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( + tmp_path, + lambda: _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=max_new_tokens, + ), + ) + + mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" + print( + " HBM " + f"mode={mode} peak_device_bytes={peak_device} " + f"peak_tensor_bytes={peak_tensors} samples={samples}" + ) + assert len(text) > len(prompt) + + +# ── Multi-Prompt Quality Test ────────────────────────────────────────── + + +@requires_model_path +def test_multi_prompt_generation(compiled_model, tokenizer, generation_config): + """Multiple prompts should produce coherent outputs.""" + prompts = [ + "The capital of France is", + "def fibonacci(n):", + "The largest ocean on Earth is", + "To make a chocolate cake, you need", + ] + + for prompt in prompts: + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=30, + ) + generated = text[len(prompt) :].strip() + words = generated.split() + assert len(words) >= 2, ( + f"Prompt '{prompt}' generated too few words: '{generated}'" + ) + assert not _is_repetitive(generated), ( + f"Prompt '{prompt}' produced repetitive output: '{generated}'" + ) + print(f" '{prompt[:30]}...' -> {generated[:60]}...") + + +# ── Standalone runner ─────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Qwen3.6-27B Integration Tests") + print("=" * 60) + + if not MODEL_PATH: + print("\nQWEN35_MODEL_PATH not set. Provide the model path to run tests:") + print(" QWEN35_MODEL_PATH=/path/to/Qwen3.6-27B \\") + print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_traced \\") + print(" python -m pytest test/integration/test_model.py --capture=tee-sys") + sys.exit(0) + + # Setup + from transformers import AutoTokenizer, GenerationConfig as GenConfig + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + gen_cfg = GenConfig( + do_sample=True, + top_k=1, + pad_token_id=tok.pad_token_id, + eos_token_id=tok.eos_token_id, + ) + + # Build model + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + nc = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + with open(os.path.join(MODEL_PATH, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict) + + cp = COMPILED_PATH + if not os.path.exists(os.path.join(cp, "model.pt")): + print(f"Compiling to {cp}...") + m = NeuronQwen35ForCausalLM(MODEL_PATH, ic) + m.compile(cp) + del m + gc.collect() + + print(f"Loading from {cp}...") + model = NeuronQwen35ForCausalLM(cp) + model.load(cp) + + tests = [ + ("model_loads", lambda: test_model_loads(model)), + ("model_generates", lambda: test_model_generates(model, tok, gen_cfg)), + ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)), + ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)), + ("capital_of_france", lambda: test_capital_of_france(model, tok, gen_cfg)), + ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)), + ( + "performance_throughput", + lambda: test_performance_throughput(model, tok, gen_cfg), + ), + ( + "multi_prompt_generation", + lambda: test_multi_prompt_generation(model, tok, gen_cfg), + ), + ] + + passed = 0 + for name, fn in tests: + print(f"\n--- {name} ---") + try: + fn() + print(f" PASS") + passed += 1 + except Exception as e: + print(f" FAIL: {e}") + + print(f"\n{'=' * 60}") + print(f"Results: {passed}/{len(tests)} passed") + print(f"{'=' * 60}") diff --git a/contrib/models/Qwen3.6-27B/test/unit/__init__.py b/contrib/models/Qwen3.6-27B/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_config.py b/contrib/models/Qwen3.6-27B/test/unit/test_config.py new file mode 100644 index 00000000..6919c590 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_config.py @@ -0,0 +1,303 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5/3.6-27B inference configuration. + +CPU-only tests that validate config parsing, layer type setup, +DeltaNet parameter defaults, RoPE configuration, and weight conversion logic. +These tests are architecture-level and apply to both Qwen3.5-27B and Qwen3.6-27B. +""" + +import os +import sys +import unittest +from unittest.mock import MagicMock + +import torch + +# Ensure the contrib root (Qwen3.6-27B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_config(**overrides): + """Create a Qwen35InferenceConfig with reasonable defaults.""" + neuron_config = NeuronConfig( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + defaults = dict( + hidden_size=5120, + num_hidden_layers=64, + num_attention_heads=24, + num_key_value_heads=4, + head_dim=256, + intermediate_size=17408, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=131072, + rope_theta=10000, + hidden_act="silu", + # DeltaNet-specific + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + ) + defaults.update(overrides) + config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + return config + + +class TestConfigParsing(unittest.TestCase): + """Test basic config attribute initialization.""" + + def test_hidden_size(self): + config = _make_config() + self.assertEqual(config.hidden_size, 5120) + + def test_num_hidden_layers(self): + config = _make_config() + self.assertEqual(config.num_hidden_layers, 64) + + def test_num_attention_heads(self): + config = _make_config() + self.assertEqual(config.num_attention_heads, 24) + + def test_num_key_value_heads(self): + config = _make_config() + self.assertEqual(config.num_key_value_heads, 4) + + def test_head_dim(self): + config = _make_config() + self.assertEqual(config.head_dim, 256) + + def test_intermediate_size(self): + config = _make_config() + self.assertEqual(config.intermediate_size, 17408) + + def test_vocab_size(self): + config = _make_config() + self.assertEqual(config.vocab_size, 248320) + + def test_hidden_act(self): + config = _make_config() + self.assertEqual(config.hidden_act, "silu") + + +class TestLayerTypes(unittest.TestCase): + """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 16.""" + + def test_layer_types_length(self): + config = _make_config() + self.assertEqual(len(config.layer_types), 64) + + def test_layer_types_pattern(self): + """Every 4th layer (3, 7, 11, ...) should be full_attention.""" + config = _make_config() + for i in range(64): + expected = "full_attention" if i % 4 == 3 else "linear_attention" + self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch") + + def test_deltanet_layer_count(self): + config = _make_config() + dn_count = sum(1 for t in config.layer_types if t == "linear_attention") + self.assertEqual(dn_count, 48) + + def test_gqa_layer_count(self): + config = _make_config() + gqa_count = sum(1 for t in config.layer_types if t == "full_attention") + self.assertEqual(gqa_count, 16) + + +class TestDeltaNetConfig(unittest.TestCase): + """Test DeltaNet-specific configuration defaults.""" + + def test_linear_num_value_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_value_heads, 48) + + def test_linear_num_key_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_key_heads, 16) + + def test_linear_key_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_key_head_dim, 128) + + def test_linear_value_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_value_head_dim, 128) + + def test_linear_conv_kernel_dim(self): + config = _make_config() + self.assertEqual(config.linear_conv_kernel_dim, 4) + + def test_gdn_apc_checkpoint_defaults(self): + config = _make_config() + self.assertFalse(config.use_hybrid_cache_manager) + self.assertFalse(config.use_hybrid_apc_manager) + self.assertEqual(config.gdn_checkpoint_interval, 256) + self.assertEqual(config.max_gdn_checkpoint_slots, 8) + self.assertEqual(config.hybrid_apc_layout_version, 1) + self.assertFalse(config.hybrid_apc_allow_residual_replay) + self.assertEqual(config.gdn_recurrent_cache_dtype, "float32") + self.assertEqual(config.gdn_conv_cache_dtype, "bfloat16") + self.assertEqual(config.hybrid_recurrent_cache_dtype, "float32") + self.assertEqual(config.hybrid_conv_cache_dtype, "bfloat16") + self.assertEqual(config.hybrid_cache_mode, "all") + self.assertTrue(config.hybrid_cache_prefix_boundary_only) + self.assertTrue(config.hybrid_cache_block_boundary_only) + self.assertFalse(config.hybrid_apc_require_vllm_metadata) + self.assertTrue(config.hybrid_apc_allow_local_hash_fallback) + self.assertFalse(config.hybrid_apc_require_attention_block_refs) + self.assertTrue(config.hybrid_apc_reject_unbacked_attention_hits) + self.assertFalse(config.hybrid_apc_disable_unbacked_prefix_reads) + self.assertTrue(config.use_text_only_cte_inputs) + self.assertTrue(config.use_compact_cte_attention_mask) + self.assertFalse(config.use_cold_zero_conv_fast_path) + + def test_hybrid_apc_manager_defaults_fail_closed(self): + config = _make_config( + use_hybrid_apc_manager=True, + gdn_checkpoint_interval=128, + ) + + self.assertTrue(config.hybrid_apc_require_vllm_metadata) + self.assertFalse(config.hybrid_apc_allow_local_hash_fallback) + self.assertTrue(config.hybrid_apc_require_attention_block_refs) + self.assertTrue(config.hybrid_apc_reject_unbacked_attention_hits) + + def test_hybrid_apc_validation_can_opt_into_local_fallback(self): + config = _make_config( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=False, + hybrid_apc_allow_local_hash_fallback=True, + hybrid_apc_require_attention_block_refs=False, + gdn_checkpoint_interval=128, + ) + + self.assertFalse(config.hybrid_apc_require_vllm_metadata) + self.assertTrue(config.hybrid_apc_allow_local_hash_fallback) + self.assertFalse(config.hybrid_apc_require_attention_block_refs) + + def test_hybrid_apc_require_vllm_metadata_enables_strict_metadata(self): + config = _make_config( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + gdn_checkpoint_interval=128, + ) + + self.assertFalse(config.hybrid_apc_allow_local_hash_fallback) + self.assertTrue(config.hybrid_apc_require_attention_block_refs) + self.assertTrue(config.hybrid_apc_reject_unbacked_attention_hits) + + def test_gdn_checkpoint_interval_must_be_positive(self): + with self.assertRaisesRegex(ValueError, "gdn_checkpoint_interval"): + _make_config(gdn_checkpoint_interval=0) + + def test_hybrid_cache_dtype_aliases_are_normalized(self): + config = _make_config( + hybrid_recurrent_cache_dtype="fp32", + hybrid_conv_cache_dtype="bf16", + ) + self.assertEqual(config.hybrid_recurrent_cache_dtype, "float32") + self.assertEqual(config.hybrid_conv_cache_dtype, "bfloat16") + + def test_hybrid_cache_dtype_rejects_fp8(self): + with self.assertRaisesRegex(ValueError, "hybrid_recurrent_cache_dtype"): + _make_config(hybrid_recurrent_cache_dtype="fp8") + + def test_hybrid_apc_requires_float32_recurrent_checkpoint_cache(self): + with self.assertRaisesRegex(ValueError, "requires float32 recurrent GDN"): + _make_config( + use_hybrid_apc_manager=True, + gdn_checkpoint_interval=128, + hybrid_recurrent_cache_dtype="bf16", + ) + + def test_hybrid_apc_rejects_non_all_mode(self): + with self.assertRaisesRegex(ValueError, "hybrid_cache_mode='all'"): + _make_config(use_hybrid_apc_manager=True, hybrid_cache_mode="align") + + def test_hybrid_apc_rejects_residual_replay_in_v0(self): + with self.assertRaisesRegex(ValueError, "reserved for v1"): + _make_config( + use_hybrid_apc_manager=True, + hybrid_apc_allow_residual_replay=True, + gdn_checkpoint_interval=128, + ) + + def test_static_and_apc_managers_are_mutually_exclusive(self): + with self.assertRaisesRegex(ValueError, "mutually exclusive"): + _make_config( + use_hybrid_cache_manager=True, + use_hybrid_apc_manager=True, + ) + + +class TestRoPEConfig(unittest.TestCase): + """Test partial RoPE configuration.""" + + def test_partial_rotary_factor(self): + config = _make_config() + self.assertAlmostEqual(config.partial_rotary_factor, 0.25) + + def test_rope_dim(self): + """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64.""" + config = _make_config() + self.assertEqual(config.rope_dim, 64) + + def test_attn_output_gate(self): + config = _make_config() + self.assertTrue(config.attn_output_gate) + + def test_mrope_section(self): + config = _make_config() + self.assertEqual(config.mrope_section, [11, 11, 10]) + + def test_mrope_interleaved(self): + config = _make_config() + self.assertTrue(config.mrope_interleaved) + + +class TestNeuronConfig(unittest.TestCase): + """Test Neuron-specific configuration settings.""" + + def test_neuron_config_cls(self): + """Qwen3.5/3.6-27B is dense -- uses NeuronConfig, NOT MoENeuronConfig.""" + self.assertEqual( + Qwen35InferenceConfig.get_neuron_config_cls(), + NeuronConfig, + ) + + def test_required_attributes(self): + config = _make_config() + required = config.get_required_attributes() + self.assertIn("hidden_size", required) + self.assertIn("num_hidden_layers", required) + self.assertIn("linear_num_value_heads", required) + self.assertIn("linear_key_head_dim", required) + self.assertIn("layer_types", required) + + def test_output_attentions_default(self): + config = _make_config() + self.assertFalse(config.output_attentions) + + def test_output_hidden_states_default(self): + config = _make_config() + self.assertFalse(config.output_hidden_states) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py new file mode 100644 index 00000000..7b3a894f --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py @@ -0,0 +1,523 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""CPU-only regressions for the fused DeltaNet decay reference math.""" + +import importlib.util +import os +import types +import unittest + +import torch + + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_VALIDATOR_PATH = os.path.join( + _CONTRIB_ROOT, + "scripts", + "validate_deltanet_fused_nki.py", +) + + +def _load_validator(): + spec = importlib.util.spec_from_file_location( + "qwen36_validate_deltanet_fused_nki", + _VALIDATOR_PATH, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +class TestFusedDeltaNetDecayMath(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.validator = _load_validator() + + def test_stable_causal_decay_masks_before_exp(self): + gc = torch.linspace(0.0, -300.0, 128, dtype=torch.float32).reshape(128, 1) + lower = torch.tril(torch.ones((128, 128), dtype=torch.float32), diagonal=-1) + lower_diag = torch.tril(torch.ones((128, 128), dtype=torch.float32)) + + strict_decay = self.validator.stable_causal_decay(torch, gc, lower) + diag_decay = self.validator.stable_causal_decay(torch, gc, lower_diag) + + self.assertTrue(torch.isfinite(strict_decay).all()) + self.assertTrue(torch.isfinite(diag_decay).all()) + self.assertTrue(torch.equal(strict_decay.triu(), torch.zeros_like(strict_decay.triu()))) + torch.testing.assert_close(torch.diagonal(diag_decay), torch.ones(128)) + + def test_reference_math_is_finite_for_realistic_gate_scale(self): + args = types.SimpleNamespace( + seed=1234, + seq_len=256, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + output, state = self.validator.reference_math(torch, inputs) + + self.assertTrue(torch.isfinite(output).all()) + self.assertTrue(torch.isfinite(state).all()) + + def test_autocp_affine_chunk_matches_current_reference(self): + args = types.SimpleNamespace( + seed=20260602, + seq_len=128, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + parts = self.validator.deltanet_chunk_affine_parts(torch, inputs, 0) + actual_output, actual_state = self.validator.apply_deltanet_chunk_affine( + torch, + parts, + inputs["state_in"], + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_autocp_reference_matches_current_reference(self): + args = types.SimpleNamespace( + seed=20260602, + seq_len=1024, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + for cp_chunks in (1, 2, 4, 8): + actual_output, actual_state = self.validator.autocp_reference_math( + torch, + inputs, + cp_chunks=cp_chunks, + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_autocp_reference_matches_current_reference_multihead(self): + args = types.SimpleNamespace( + seed=20260602, + seq_len=512, + heads=3, + multihead=True, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + actual_output, actual_state = self.validator.autocp_reference_math( + torch, + inputs, + cp_chunks=2, + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_compact_autocp_reference_matches_current_reference(self): + args = types.SimpleNamespace( + seed=20260602, + seq_len=1024, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + for cp_chunks in (1, 2, 4, 8): + actual_output, actual_state = self.validator.compact_autocp_reference_math( + torch, + inputs, + cp_chunks=cp_chunks, + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_compact_autocp_reference_matches_current_reference_multihead(self): + args = types.SimpleNamespace( + seed=20260602, + seq_len=512, + heads=3, + multihead=True, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + for cp_chunks in (1, 2, 4): + actual_output, actual_state = self.validator.compact_autocp_reference_math( + torch, + inputs, + cp_chunks=cp_chunks, + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_reference_qk_normalization_is_zero_safe(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(20260601) + query = torch.randn((4, 128), generator=generator) * 0.05 + key = torch.randn((4, 128), generator=generator) * 0.05 + query[0].zero_() + key[1].zero_() + + query_norm, key_norm = self.validator.normalize_reference_qk( + torch, + query, + key, + ) + + self.assertTrue(torch.isfinite(query_norm).all()) + self.assertTrue(torch.isfinite(key_norm).all()) + torch.testing.assert_close(query_norm[0], torch.zeros_like(query_norm[0])) + torch.testing.assert_close(key_norm[1], torch.zeros_like(key_norm[1])) + torch.testing.assert_close( + torch.linalg.vector_norm(query_norm[2]), + torch.tensor(self.validator.P_MAX ** -0.5), + atol=1.0e-6, + rtol=1.0e-6, + ) + torch.testing.assert_close( + torch.linalg.vector_norm(key_norm[2]), + torch.tensor(1.0), + atol=1.0e-6, + rtol=1.0e-6, + ) + + def test_multihead_launch_spec_rejects_head_group_size_above_lnc_when_spmd_disabled(self): + previous = os.environ.get("QWEN36_DELTANET_MULTIHEAD_SPMD") + os.environ["QWEN36_DELTANET_MULTIHEAD_SPMD"] = "0" + try: + with self.assertRaisesRegex(ValueError, "head-group-size exceeds --lnc"): + self.validator.multihead_launch_spec(num_heads=2, lnc=1) + finally: + if previous is None: + os.environ.pop("QWEN36_DELTANET_MULTIHEAD_SPMD", None) + else: + os.environ["QWEN36_DELTANET_MULTIHEAD_SPMD"] = previous + + def test_blocked_triangular_solve_matches_torch_solve(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(20260601) + strict_lower = torch.tril( + torch.randn((128, 128), generator=generator) * 0.01, + diagonal=-1, + ) + lhs = torch.eye(128) + strict_lower + rhs = torch.randn((128, 128), generator=generator) * 0.05 + + expected = torch.linalg.solve_triangular(lhs, rhs, upper=False) + for block_size in (8, 16, 32): + actual = self.validator.blocked_lower_triangular_solve( + torch, + lhs, + rhs, + block_size, + ) + torch.testing.assert_close(actual, expected, atol=2.0e-5, rtol=2.0e-5) + + def test_block_prefix_triangular_solve_matches_torch_solve(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(20260602) + strict_lower = torch.tril( + torch.randn((128, 128), generator=generator) * 0.05, + diagonal=-1, + ) + lhs = torch.eye(128) + strict_lower + rhs = torch.randn((128, 128), generator=generator) * 0.05 + + expected = torch.linalg.solve_triangular(lhs, rhs, upper=False) + for block_size in (16, 32, 64): + actual = self.validator.block_prefix_lower_triangular_solve( + torch, + lhs, + rhs, + block_size, + ) + torch.testing.assert_close(actual, expected, atol=2.0e-5, rtol=2.0e-5) + + def test_hierarchical_kkt_triangular_solve_matches_torch_solve(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(20260602) + strict_lower = torch.tril( + torch.randn((128, 128), generator=generator) * 0.01, + diagonal=-1, + ) + lhs = torch.eye(128) + strict_lower + rhs = torch.randn((128, 128), generator=generator) * 0.05 + + expected = torch.linalg.solve_triangular(lhs, rhs, upper=False) + for leaf_size in (8, 16, 32): + actual = self.validator.hierarchical_kkt_lower_triangular_solve( + torch, + lhs, + rhs, + leaf_size, + ) + torch.testing.assert_close(actual, expected, atol=2.0e-5, rtol=2.0e-5) + + def test_two_step_doubling_solve_matches_realistic_chunks(self): + args = types.SimpleNamespace( + seed=20260601, + seq_len=512, + heads=4, + multihead=True, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + lower = inputs["lower_mask"] + eye = inputs["identity"] + + max_relative_norm = 0.0 + max_absolute = 0.0 + for head_idx in range(args.heads): + state = inputs["state_in"][head_idx].clone() + for start in range(0, args.seq_len, self.validator.P_MAX): + end = start + self.validator.P_MAX + _, key = self.validator.normalize_reference_qk( + torch, + inputs["query"][head_idx, start:end], + inputs["key"][head_idx, start:end], + ) + value = inputs["value"][head_idx, start:end] + g = inputs["g_raw"][head_idx, start:end] + beta = inputs["beta"][head_idx, start:end] + + gc = torch.cumsum(g, dim=0) + k_beta = key * beta + v_beta = value * beta + decay = self.validator.stable_causal_decay(torch, gc, lower) + a_mat = -((k_beta @ key.T) * decay) * lower + lhs = eye - a_mat + rhs = v_beta - ((k_beta * torch.exp(gc)) @ state) + + expected = torch.linalg.solve_triangular(lhs, rhs, upper=False) + actual = self.validator.scan_doubling_lower_triangular_solve( + torch, + lhs, + rhs, + steps=2, + ) + + diff = actual - expected + max_relative_norm = max( + max_relative_norm, + torch.linalg.vector_norm(diff).item() + / torch.linalg.vector_norm(expected).item(), + ) + max_absolute = max(max_absolute, diff.abs().max().item()) + + gl = gc[-1:] + key_decay = key * torch.exp(gl - gc) + state = (state * torch.exp(gl)) + (key_decay.T @ expected) + + self.assertLess(max_relative_norm, 5.0e-6) + self.assertLess(max_absolute, 2.0e-6) + + def test_two_step_doubling_solve_truncates_weak_decay_chunks(self): + args = types.SimpleNamespace( + seed=1241, + seq_len=512, + heads=4, + multihead=True, + value_scale=0.05, + state_scale=0.01, + gate_scale=0.01, + ) + + inputs = self.validator.make_inputs(torch, args) + lower = inputs["lower_mask"] + eye = inputs["identity"] + head_idx = 3 + state = inputs["state_in"][head_idx].clone() + rel_scan2 = None + rel_scan7 = None + + for start in range(0, args.seq_len, self.validator.P_MAX): + end = start + self.validator.P_MAX + _, key = self.validator.normalize_reference_qk( + torch, + inputs["query"][head_idx, start:end], + inputs["key"][head_idx, start:end], + ) + value = inputs["value"][head_idx, start:end] + g = inputs["g_raw"][head_idx, start:end] + beta = inputs["beta"][head_idx, start:end] + + gc = torch.cumsum(g, dim=0) + k_beta = key * beta + v_beta = value * beta + decay = self.validator.stable_causal_decay(torch, gc, lower) + a_mat = -((k_beta @ key.T) * decay) * lower + lhs = eye - a_mat + rhs = v_beta - ((k_beta * torch.exp(gc)) @ state) + + expected = torch.linalg.solve_triangular(lhs, rhs, upper=False) + if start == 256: + scan2 = self.validator.scan_doubling_lower_triangular_solve( + torch, + lhs, + rhs, + steps=2, + ) + scan7 = self.validator.scan_doubling_lower_triangular_solve( + torch, + lhs, + rhs, + steps=7, + ) + rel_scan2 = ( + torch.linalg.vector_norm(scan2 - expected) + / torch.linalg.vector_norm(expected) + ).item() + rel_scan7 = ( + torch.linalg.vector_norm(scan7 - expected) + / torch.linalg.vector_norm(expected) + ).item() + break + + gl = gc[-1:] + key_decay = key * torch.exp(gl - gc) + state = (state * torch.exp(gl)) + (key_decay.T @ expected) + + self.assertIsNotNone(rel_scan2) + self.assertIsNotNone(rel_scan7) + self.assertGreater(rel_scan2, 5.0e-3) + self.assertLess(rel_scan7, 5.0e-6) + + def test_blocked_reference_matches_current_reference(self): + args = types.SimpleNamespace( + seed=1234, + seq_len=256, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + actual_output, actual_state = self.validator.blocked_reference_math( + torch, + inputs, + block_size=16, + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_blocked_reference_matches_current_reference_multihead(self): + args = types.SimpleNamespace( + seed=1234, + seq_len=256, + heads=4, + multihead=True, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, + ) + + inputs = self.validator.make_inputs(torch, args) + expected_output, expected_state = self.validator.reference_math(torch, inputs) + actual_output, actual_state = self.validator.blocked_reference_math( + torch, + inputs, + block_size=16, + ) + + torch.testing.assert_close( + actual_output, + expected_output, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + actual_state, + expected_state, + atol=2.0e-5, + rtol=2.0e-5, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_manager.py b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_manager.py new file mode 100644 index 00000000..29c8b5e3 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_manager.py @@ -0,0 +1,1790 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +import unittest +import importlib.util +import types +from unittest.mock import patch + +import torch + + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) +_VLLM_ROOT = os.path.join(_CONTRIB_ROOT, "vllm") +if _VLLM_ROOT not in sys.path: + sys.path.insert(0, _VLLM_ROOT) + +_HYBRID_APC_PATH = os.path.join(_CONTRIB_ROOT, "src", "hybrid_apc.py") +_SPEC = importlib.util.spec_from_file_location("qwen36_hybrid_apc", _HYBRID_APC_PATH) +_HYBRID_APC = importlib.util.module_from_spec(_SPEC) +sys.modules[_SPEC.name] = _HYBRID_APC +_SPEC.loader.exec_module(_HYBRID_APC) + +HybridAPCMetadataStore = _HYBRID_APC.HybridAPCMetadataStore +HybridAPCHitPlan = _HYBRID_APC.HybridAPCHitPlan +HybridAPCSchedulerBridge = _HYBRID_APC.HybridAPCSchedulerBridge +HybridAPCSlotAllocator = _HYBRID_APC.HybridAPCSlotAllocator +apply_hybrid_apc_prefill_plan = _HYBRID_APC.apply_hybrid_apc_prefill_plan +apply_hybrid_apc_suffix_prefill_plan = ( + _HYBRID_APC.apply_hybrid_apc_suffix_prefill_plan +) +build_cumulative_prefix_hashes = _HYBRID_APC.build_cumulative_prefix_hashes +estimate_qwen_gdn_checkpoint_bytes_per_rank = ( + _HYBRID_APC.estimate_qwen_gdn_checkpoint_bytes_per_rank +) +estimate_qwen_hybrid_cache_bytes_per_rank = ( + _HYBRID_APC.estimate_qwen_hybrid_cache_bytes_per_rank +) +import qwen36_hybrid_apc_scheduler_patch as _SCHEDULER_PATCH # noqa: E402 +from neuronx_distributed_inference.modules.async_execution import ( # noqa: E402 + _combine_vectorized_hybrid_apc_inputs, + finish_hybrid_apc_request, + prepare_hybrid_apc_request_for_execution, +) + + +def _store(**overrides): + defaults = dict( + required_gdn_layers=[0, 1, 2], + block_size=128, + checkpoint_interval=128, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + defaults.update(overrides) + return HybridAPCMetadataStore(**defaults) + + +def _insert(store, prefix_len, prefix_hash=None, **overrides): + key = store.make_key( + cumulative_prefix_hash=prefix_hash or f"h{prefix_len}", + prefix_len=prefix_len, + cache_salt=overrides.pop("cache_salt", "tenant-a"), + model_revision=overrides.pop("model_revision", "rev-a"), + layout_version=overrides.pop("layout_version", 1), + tp_rank=overrides.pop("tp_rank", 0), + recurrent_dtype=overrides.pop("recurrent_dtype", "float32"), + conv_dtype=overrides.pop("conv_dtype", "bfloat16"), + ) + checkpoint = store.insert( + key=key, + attention_block_refs=overrides.pop("attention_block_refs", range(prefix_len // 128)), + gdn_checkpoint_slot=overrides.pop("gdn_checkpoint_slot", prefix_len // 128), + **overrides, + ) + return key, checkpoint + + +class TestHybridAPCMetadataStore(unittest.TestCase): + def test_same_prefix_hash_and_salt_hits(self): + store = _store() + key, _checkpoint = _insert(store, 256) + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={128: "h128", 256: "h256"}, + attention_hit_len=256, + request_prefix_len=300, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertEqual(plan.checkpoint_key, key) + self.assertEqual(plan.restore_checkpoint_prefix_len, 256) + self.assertEqual(plan.usable_hit_len, 256) + self.assertEqual(plan.residual_replay_len, 0) + self.assertEqual(plan.suffix_len, 44) + + def test_same_tokens_with_different_salt_misses(self): + store = _store() + _insert(store, 128, prefix_hash="same") + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={128: "same"}, + attention_hit_len=128, + request_prefix_len=128, + cache_salt="tenant-b", + model_revision="rev-a", + ) + + self.assertIsNone(plan.checkpoint_key) + self.assertEqual(plan.usable_hit_len, 0) + + def test_same_last_block_with_different_cumulative_hash_misses(self): + store = _store() + _insert(store, 256, prefix_hash="parent-a+block-z") + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={256: "parent-b+block-z"}, + attention_hit_len=256, + request_prefix_len=256, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertIsNone(plan.checkpoint_key) + + def test_missing_recurrent_layer_invalidates_hit(self): + store = _store() + key, _checkpoint = _insert(store, 128) + store.mark_invalid(key, state_kind="recurrent", layer_id=1) + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={128: "h128"}, + attention_hit_len=128, + request_prefix_len=128, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertIsNone(plan.checkpoint_key) + + def test_missing_conv_layer_invalidates_hit(self): + store = _store() + key, _checkpoint = _insert(store, 128) + store.mark_invalid(key, state_kind="conv", layer_id=2) + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={128: "h128"}, + attention_hit_len=128, + request_prefix_len=128, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertIsNone(plan.checkpoint_key) + + def test_dtype_layout_and_revision_are_identity(self): + store = _store() + _insert(store, 128) + + for kwargs in ( + {"recurrent_dtype": "bfloat16"}, + {"conv_dtype": "float32"}, + {"layout_version": 2}, + {"model_revision": "rev-b"}, + ): + with self.subTest(kwargs=kwargs): + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={128: "h128"}, + attention_hit_len=128, + request_prefix_len=128, + cache_salt="tenant-a", + **kwargs, + ) + self.assertIsNone(plan.checkpoint_key) + + def test_refcount_blocks_lru_eviction(self): + store = _store(max_checkpoints=2) + key128, _ = _insert(store, 128) + key256, _ = _insert(store, 256) + store.inc_ref(key128) + key384, _ = _insert(store, 384) + + self.assertIsNotNone(store.lookup(key128)) + self.assertIsNone(store.lookup(key256)) + self.assertIsNotNone(store.lookup(key384)) + + def test_evicting_gdn_checkpoint_makes_hybrid_hit_fallback(self): + store = _store() + _key, checkpoint = _insert(store, 128) + store.on_gdn_checkpoint_evicted(checkpoint.gdn_checkpoint_slot) + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={128: "h128"}, + attention_hit_len=128, + request_prefix_len=128, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertIsNone(plan.checkpoint_key) + + def test_evicting_attention_block_makes_hybrid_hit_fallback(self): + store = _store() + key, _checkpoint = _insert(store, 256, attention_block_refs=(7, 8)) + invalidated = store.on_attention_block_evicted(8) + + self.assertEqual(invalidated, [key]) + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={256: "h256"}, + attention_hit_len=256, + request_prefix_len=256, + cache_salt="tenant-a", + model_revision="rev-a", + ) + self.assertIsNone(plan.checkpoint_key) + + def test_non_block_aligned_prompt_uses_checkpoint_boundary_in_v0(self): + store = _store(allow_residual_replay=False) + _insert(store, 256) + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={256: "h256"}, + attention_hit_len=300, + request_prefix_len=384, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertEqual(plan.usable_hit_len, 256) + self.assertEqual(plan.restore_checkpoint_prefix_len, 256) + self.assertEqual(plan.residual_replay_len, 0) + self.assertEqual(plan.suffix_len, 128) + + def test_residual_replay_requires_explicit_enablement(self): + store = _store(allow_residual_replay=True) + _insert(store, 256) + + plan = store.compute_hit_plan( + cumulative_hashes_by_prefix_len={256: "h256"}, + attention_hit_len=300, + request_prefix_len=384, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertEqual(plan.usable_hit_len, 300) + self.assertEqual(plan.restore_checkpoint_prefix_len, 256) + self.assertEqual(plan.residual_replay_len, 44) + self.assertEqual(plan.suffix_len, 84) + + def test_request_lifecycle_releases_restored_ref_on_finish(self): + store = _store() + key, _checkpoint = _insert(store, 128) + + record = store.on_request_restore(request_id="req-1", checkpoint_key=key) + self.assertEqual(record.state, "RESTORED_FROM_HYBRID_APC") + self.assertEqual(store.lookup(key).refcount, 1) + + store.on_prefill_running("req-1") + store.on_decode_running("req-1") + finished = store.on_request_finish("req-1") + + self.assertEqual(finished.state, "FINISHED") + self.assertEqual(store.lookup(key).refcount, 0) + + def test_request_cancel_releases_ref_and_drops_pending_commit(self): + store = _store() + restored_key, _checkpoint = _insert(store, 128) + committed_key, _checkpoint = _insert(store, 256) + + store.on_request_restore(request_id="req-1", checkpoint_key=restored_key) + store.on_checkpoint_committed( + request_id="req-1", + checkpoint_key=committed_key, + ) + cancelled = store.on_request_cancel("req-1") + + self.assertEqual(cancelled.state, "CANCELLED") + self.assertEqual(store.lookup(restored_key).refcount, 0) + self.assertIsNone(store.lookup(committed_key)) + + def test_qwen_hbm_estimator_uses_checkpoint_slots_not_token_slots(self): + per_checkpoint = estimate_qwen_gdn_checkpoint_bytes_per_rank() + totals = estimate_qwen_hybrid_cache_bytes_per_rank( + max_context_len=1024, + checkpoint_interval=256, + ) + + self.assertEqual(totals["num_gdn_checkpoints"], 4) + self.assertEqual(totals["gdn_checkpoint_bytes"], per_checkpoint * 4) + self.assertGreater(totals["gdn_checkpoint_bytes"], totals["attention_kv_bytes"]) + + +class TestHybridAPCPrefillPlanInputs(unittest.TestCase): + def test_prefill_plan_materializes_suffix_restore_and_commit(self): + plan = HybridAPCHitPlan( + attention_hit_len=2, + recurrent_hit_len=2, + conv_hit_len=2, + usable_hit_len=2, + restore_checkpoint_prefix_len=2, + residual_replay_len=0, + suffix_len=3, + checkpoint_slot=5, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14]], dtype=torch.int32), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32), + "position_ids": torch.arange(5, dtype=torch.int32).unsqueeze(0), + "slot_mapping": torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.int32), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + commit_slot=7, + ) + + self.assertTrue( + torch.equal(output["input_ids"], torch.tensor([[12, 13, 14]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["attention_mask"], torch.tensor([[1, 1, 1]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["position_ids"], torch.tensor([[2, 3, 4]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["slot_mapping"], torch.tensor([[2, 3, 4]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["computed_context_lens"], torch.tensor([[2]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["full_context_lens"], torch.tensor([[5]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["num_queries"], torch.tensor([[3]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_restore_slot_ids"], torch.tensor([5], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_restore_mask"], torch.tensor([1], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_restore_prefix_lens"], torch.tensor([2], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_commit_slot_ids"], torch.tensor([7], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_commit_mask"], torch.tensor([1], dtype=torch.int32)) + ) + + def test_prefill_plan_does_not_restore_without_checkpoint_slot(self): + plan = HybridAPCHitPlan( + attention_hit_len=0, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=0, + suffix_len=3, + checkpoint_slot=None, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12]], dtype=torch.int32), + "position_ids": torch.arange(3, dtype=torch.int32).unsqueeze(0), + } + + output = apply_hybrid_apc_prefill_plan(input_dict, plan=plan) + + self.assertTrue( + torch.equal(output["input_ids"], torch.tensor([[10, 11, 12]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_restore_slot_ids"], torch.tensor([0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_restore_mask"], torch.tensor([0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["hybrid_commit_mask"], torch.tensor([0], dtype=torch.int32)) + ) + + def test_prefill_plan_rejects_restore_boundary_without_checkpoint_slot(self): + plan = HybridAPCHitPlan( + attention_hit_len=2, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=2, + residual_replay_len=0, + suffix_len=1, + checkpoint_slot=None, + checkpoint_key=None, + ) + + with self.assertRaisesRegex(ValueError, "requires a checkpoint slot"): + apply_hybrid_apc_prefill_plan( + {"input_ids": torch.tensor([[10, 11, 12]], dtype=torch.int32)}, + plan=plan, + ) + + def test_prefill_plan_uses_restore_boundary_for_residual_replay(self): + plan = HybridAPCHitPlan( + attention_hit_len=5, + recurrent_hit_len=4, + conv_hit_len=4, + usable_hit_len=5, + restore_checkpoint_prefix_len=4, + residual_replay_len=1, + suffix_len=2, + checkpoint_slot=9, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14, 15, 16]], dtype=torch.int32), + "position_ids": torch.arange(7, dtype=torch.int32).unsqueeze(0), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + commit_slot=10, + ) + + self.assertTrue( + torch.equal(output["input_ids"], torch.tensor([[14, 15, 16]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["position_ids"], torch.tensor([[4, 5, 6]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["computed_context_lens"], torch.tensor([[4]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(output["num_queries"], torch.tensor([[3]], dtype=torch.int32)) + ) + + def test_prefill_plan_slices_flattened_slot_mapping_with_suffix(self): + plan = HybridAPCHitPlan( + attention_hit_len=5, + recurrent_hit_len=4, + conv_hit_len=4, + usable_hit_len=5, + restore_checkpoint_prefix_len=4, + residual_replay_len=1, + suffix_len=2, + checkpoint_slot=9, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14, 15, 16]], dtype=torch.int32), + "position_ids": torch.arange(7, dtype=torch.int32).unsqueeze(0), + "slot_mapping": torch.arange(100, 107, dtype=torch.int32), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + commit_slot=10, + ) + + self.assertTrue( + torch.equal(output["slot_mapping"], torch.tensor([104, 105, 106], dtype=torch.int32)) + ) + + def test_prefill_plan_synthesizes_padding_suffix_slots_from_block_table(self): + plan = HybridAPCHitPlan( + attention_hit_len=4, + recurrent_hit_len=4, + conv_hit_len=4, + usable_hit_len=4, + restore_checkpoint_prefix_len=4, + residual_replay_len=0, + suffix_len=3, + checkpoint_slot=9, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14, 15, 16]], dtype=torch.int32), + "position_ids": torch.arange(7, dtype=torch.int32).unsqueeze(0), + "slot_mapping": torch.full((1, 7), -1, dtype=torch.int32), + "block_table": torch.tensor([[1, 3, 4, 5]], dtype=torch.int32), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + commit_slot=10, + block_size=4, + ) + + self.assertTrue( + torch.equal(output["slot_mapping"], torch.tensor([[12, 13, 14]], dtype=torch.int32)) + ) + + def test_prefill_plan_repairs_negative_active_slots_from_block_table(self): + plan = HybridAPCHitPlan( + attention_hit_len=4, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=0, + suffix_len=6, + checkpoint_slot=None, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14, 15]], dtype=torch.int32), + "position_ids": torch.arange(6, dtype=torch.int32).unsqueeze(0), + "slot_mapping": torch.tensor([[-1, -1, 10, 11, 12, 13]], dtype=torch.int32), + "block_table": torch.tensor([[2, 3]], dtype=torch.int32), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + block_size=4, + ) + + self.assertTrue( + torch.equal( + output["slot_mapping"], + torch.tensor([[8, 9, 10, 11, 12, 13]], dtype=torch.int32), + ) + ) + + def test_prefill_plan_repairs_too_short_active_slots_from_block_table(self): + plan = HybridAPCHitPlan( + attention_hit_len=0, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=0, + suffix_len=6, + checkpoint_slot=None, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14, 15]], dtype=torch.int32), + "position_ids": torch.arange(6, dtype=torch.int32).unsqueeze(0), + "slot_mapping": torch.tensor([12], dtype=torch.int32), + "block_table": torch.tensor([[2, 3]], dtype=torch.int32), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + block_size=4, + ) + + self.assertTrue( + torch.equal( + output["slot_mapping"], + torch.tensor([[8, 9, 10, 11, 12, 13]], dtype=torch.int32), + ) + ) + + def test_prefill_plan_rebuilds_unbacked_attention_hit_slots(self): + plan = HybridAPCHitPlan( + attention_hit_len=4, + recurrent_hit_len=0, + conv_hit_len=0, + usable_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=0, + suffix_len=6, + checkpoint_slot=None, + checkpoint_key=None, + ) + input_dict = { + "input_ids": torch.tensor([[10, 11, 12, 13, 14, 15]], dtype=torch.int32), + "position_ids": torch.arange(6, dtype=torch.int32).unsqueeze(0), + "slot_mapping": torch.tensor([[12, 13, 14, 15, 16, 17]], dtype=torch.int32), + "block_table": torch.tensor([[2, 3]], dtype=torch.int32), + } + + output = apply_hybrid_apc_prefill_plan( + input_dict, + plan=plan, + block_size=4, + ) + + self.assertTrue( + torch.equal( + output["slot_mapping"], + torch.tensor([[8, 9, 10, 11, 12, 13]], dtype=torch.int32), + ) + ) + + def test_suffix_prefill_plan_preserves_active_block_table(self): + plan = HybridAPCHitPlan( + attention_hit_len=4, + recurrent_hit_len=4, + conv_hit_len=4, + usable_hit_len=4, + restore_checkpoint_prefix_len=4, + residual_replay_len=0, + suffix_len=2, + checkpoint_slot=1, + checkpoint_key=None, + ) + block_table = torch.tensor([[7, 8, 9, 10]], dtype=torch.int32) + + output = apply_hybrid_apc_suffix_prefill_plan( + { + "input_ids": torch.tensor([[14, 15]], dtype=torch.int32), + "block_table": block_table, + }, + plan=plan, + request_prefix_len=6, + attention_block_refs=(7,), + ) + + self.assertTrue(torch.equal(output["block_table"], block_table)) + + +class TestHybridAPCVectorizedInputCombiner(unittest.TestCase): + def test_backed_restore_uses_full_attention_mask_with_bucketed_suffix(self): + neuron_config = types.SimpleNamespace( + context_encoding_buckets=[256, 512, 1024, 2048, 4096], + pa_block_size=256, + seq_len=4096, + ) + model = types.SimpleNamespace( + neuron_config=neuron_config, + config=types.SimpleNamespace(neuron_config=neuron_config), + ) + rows = [] + for row_idx in range(2): + rows.append( + { + "input_ids": torch.arange(16, dtype=torch.int32).unsqueeze(0), + "attention_mask": torch.ones((1, 16), dtype=torch.int32), + "position_ids": torch.arange( + 256, + 272, + dtype=torch.int32, + ).unsqueeze(0), + "slot_mapping": torch.arange( + row_idx * 100, + row_idx * 100 + 16, + dtype=torch.int32, + ).unsqueeze(0), + "block_table": torch.tensor([[17 + row_idx]], dtype=torch.int32), + "full_context_lens": torch.tensor([[272]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[256]], dtype=torch.int32), + "num_queries": torch.tensor([[16]], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1], dtype=torch.int32), + } + ) + + combined = _combine_vectorized_hybrid_apc_inputs(model, {}, rows) + + self.assertEqual(tuple(combined["input_ids"].shape), (2, 256)) + self.assertEqual(tuple(combined["position_ids"].shape), (2, 256)) + self.assertEqual(tuple(combined["slot_mapping"].shape), (2, 256)) + self.assertEqual(tuple(combined["block_table"].shape), (2, 16)) + self.assertEqual(tuple(combined["attention_mask"].shape), (2, 4096)) + self.assertTrue( + torch.equal( + combined["attention_mask"][:, :272], + torch.ones((2, 272), dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + combined["attention_mask"][:, 272:], + torch.zeros((2, 3824), dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + combined["slot_mapping"][0, :16], + torch.arange(16, dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + combined["slot_mapping"][1, :16], + torch.arange(100, 116, dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + combined["slot_mapping"][:, 16:], + torch.full((2, 240), -1, dtype=torch.int32), + ) + ) + + +class TestHybridAPCSchedulerBridge(unittest.TestCase): + def tearDown(self): + _SCHEDULER_PATCH.clear_hybrid_apc_gdn_checkpoint_registry() + + def test_slot_allocator_validates_lifecycle(self): + allocator = HybridAPCSlotAllocator(num_slots=2) + + with self.assertRaisesRegex(ValueError, "outside"): + allocator.validate_slot_range(2) + with self.assertRaisesRegex(ValueError, "not reserved"): + allocator.mark_committed(1) + + slot = allocator.reserve() + allocator.mark_committed(slot) + + self.assertEqual(allocator.committed_slots, (slot,)) + + def test_cumulative_prefix_hash_includes_parent_prefix(self): + tokens_a = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32) + tokens_b = torch.tensor([[9, 8, 3, 4]], dtype=torch.int32) + + hashes_a = build_cumulative_prefix_hashes(tokens_a, block_size=2) + hashes_b = build_cumulative_prefix_hashes(tokens_b, block_size=2) + + self.assertNotEqual(hashes_a[4], hashes_b[4]) + + def test_bridge_prepares_warm_suffix_and_commits_checkpoint(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=4) + input_ids = torch.arange(256, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + restored_key, _checkpoint = _insert( + store, + 128, + prefix_hash=hashes[128], + gdn_checkpoint_slot=3, + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + reject_unbacked_attention_hits=False, + ) + + prepared = bridge.prepare_request( + request_id="req-warm", + input_dict={ + "input_ids": input_ids, + "attention_mask": torch.ones((1, 256), dtype=torch.int32), + "position_ids": torch.arange(256, dtype=torch.int32).unsqueeze(0), + }, + attention_hit_len=128, + cumulative_hashes_by_prefix_len=hashes, + attention_block_refs_by_prefix_len={256: (11, 12)}, + ) + + self.assertEqual(prepared.plan.restore_checkpoint_prefix_len, 128) + self.assertEqual(prepared.commit_prefix_len, 256) + self.assertEqual(prepared.commit_slot, 0) + self.assertTrue( + torch.equal( + prepared.input_dict["input_ids"], + torch.arange(128, 256, dtype=torch.int32).unsqueeze(0), + ) + ) + self.assertTrue( + torch.equal( + prepared.input_dict["position_ids"], + torch.arange(128, 256, dtype=torch.int32).unsqueeze(0), + ) + ) + self.assertTrue( + torch.equal( + prepared.input_dict["hybrid_restore_mask"], + torch.tensor([1], dtype=torch.int32), + ) + ) + self.assertEqual(store.lookup(restored_key).refcount, 1) + + committed = bridge.commit_prefill(prepared) + self.assertIsNotNone(committed) + self.assertEqual(committed.gdn_checkpoint_slot, 0) + self.assertEqual(committed.attention_block_refs, (11, 12)) + self.assertEqual(allocator.committed_slots, (0,)) + self.assertIsNotNone(store.lookup(prepared.commit_key)) + + fake_scheduler = types.SimpleNamespace( + cache_config=types.SimpleNamespace(block_size=128), + vllm_config=types.SimpleNamespace( + model_config=types.SimpleNamespace( + hf_config=types.SimpleNamespace( + hybrid_apc_model_revision="rev-a", + hybrid_apc_layout_version=1, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + tp_rank=0, + ) + ) + ), + ) + fake_request = types.SimpleNamespace( + prompt_token_ids=list(range(300)), + num_tokens=300, + cache_salt="tenant-a", + ) + self.assertEqual( + _SCHEDULER_PATCH.backed_gdn_prefix_hit_len(fake_scheduler, fake_request), + 256, + ) + + store.mark_invalid(prepared.commit_key, state_kind="conv") + self.assertEqual( + _SCHEDULER_PATCH.backed_gdn_prefix_hit_len(fake_scheduler, fake_request), + 0, + ) + + bridge.finish_request("req-warm") + self.assertEqual(store.lookup(restored_key).refcount, 0) + + def test_bridge_full_input_commit_combines_restored_and_suffix_refs(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=4) + input_ids = torch.arange(256, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + restored_key, _checkpoint = _insert( + store, + 128, + prefix_hash=hashes[128], + attention_block_refs=(7,), + gdn_checkpoint_slot=3, + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + require_attention_block_refs=True, + ) + + prepared = bridge.prepare_request( + request_id="req-full-offset-refs", + input_dict={ + "input_ids": input_ids, + "attention_mask": torch.ones((1, 256), dtype=torch.int32), + "position_ids": torch.arange(256, dtype=torch.int32).unsqueeze(0), + }, + attention_hit_len=128, + cumulative_hashes_by_prefix_len=hashes, + attention_block_refs_by_prefix_len={128: (9,)}, + ) + + self.assertEqual(prepared.plan.checkpoint_key, restored_key) + self.assertEqual(prepared.attention_block_refs, (7, 9)) + committed = bridge.commit_prefill(prepared) + self.assertIsNotNone(committed) + self.assertEqual(committed.attention_block_refs, (7, 9)) + + def test_bridge_misses_without_gdn_checkpoint_and_cancels_reserved_slot(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + input_ids = torch.arange(256, dtype=torch.int32).unsqueeze(0) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + reject_unbacked_attention_hits=False, + ) + + prepared = bridge.prepare_request( + request_id="req-cold", + input_dict={"input_ids": input_ids}, + attention_hit_len=128, + ) + + self.assertEqual(prepared.plan.restore_checkpoint_prefix_len, 0) + self.assertEqual(prepared.commit_slot, 0) + self.assertTrue( + torch.equal(prepared.input_dict["input_ids"], input_ids) + ) + self.assertTrue( + torch.equal( + prepared.input_dict["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + cancelled = bridge.cancel_request(prepared) + self.assertEqual(cancelled.state, "CANCELLED") + self.assertEqual(allocator.reserved_slots, ()) + self.assertEqual(allocator.free_slots, (1, 0)) + self.assertEqual(len(store), 0) + + def test_bridge_rejects_attention_hit_without_gdn_checkpoint_by_default(self): + bridge = HybridAPCSchedulerBridge( + store=_store(), + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + ) + + with self.assertRaisesRegex(ValueError, "without a matching GDN checkpoint"): + bridge.prepare_request( + request_id="req-unbacked-hit", + input_dict={ + "input_ids": torch.arange(256, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + ) + + def test_bridge_env_can_allow_unbacked_attention_fallback(self): + bridge = HybridAPCSchedulerBridge( + store=_store(), + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + ) + + with patch.dict( + os.environ, + {"QWEN36_ALLOW_UNBACKED_HYBRID_APC_FALLBACK": "1"}, + ): + prepared = bridge.prepare_request( + request_id="req-unbacked-fallback", + input_dict={ + "input_ids": torch.arange(256, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + ) + + self.assertIsNone(prepared.plan.checkpoint_key) + self.assertEqual(prepared.input_dict["hybrid_restore_mask"].item(), 0) + + def test_bridge_suffix_only_restore_is_explicit_and_unambiguous(self): + store = _store() + _insert(store, 128, gdn_checkpoint_slot=1) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + ) + suffix_ids = torch.arange(128, 256, dtype=torch.int32).unsqueeze(0) + + with self.assertRaisesRegex(ValueError, "without scheduler-authorized"): + bridge.prepare_suffix_only_request( + request_id="req-suffix-disabled", + input_dict={"input_ids": suffix_ids}, + attention_hit_len=128, + request_prefix_len=256, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_ALLOW_UNHASHED_SINGLE_PREFIX_RESTORE": "1"}, + ): + prepared = bridge.prepare_suffix_only_request( + request_id="req-suffix", + input_dict={"input_ids": suffix_ids}, + attention_hit_len=128, + request_prefix_len=256, + ) + + self.assertIsNotNone(prepared) + self.assertTrue(torch.equal(prepared.input_dict["input_ids"], suffix_ids)) + self.assertTrue( + torch.equal( + prepared.input_dict["position_ids"], + torch.arange(128, 256, dtype=torch.int64).unsqueeze(0), + ) + ) + self.assertEqual(prepared.input_dict["computed_context_lens"].item(), 128) + self.assertEqual(prepared.input_dict["num_queries"].item(), 128) + self.assertEqual(prepared.input_dict["hybrid_restore_mask"].item(), 1) + self.assertEqual(prepared.input_dict["hybrid_restore_slot_ids"].item(), 1) + self.assertTrue( + torch.equal( + prepared.input_dict["rotary_position_ids"], + torch.arange(128, 256, dtype=torch.int32) + .view(1, 1, 128) + .expand(3, 1, 128), + ) + ) + + def test_bridge_suffix_only_restore_uses_scheduler_authorized_key(self): + store = _store() + _insert(store, 128, prefix_hash="h128-a", gdn_checkpoint_slot=0) + key_b, _checkpoint_b = _insert( + store, + 128, + prefix_hash="h128-b", + gdn_checkpoint_slot=1, + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + ) + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read(key_b) + + prepared = bridge.prepare_suffix_only_request( + request_id="req-authorized", + input_dict={ + "input_ids": torch.arange(128, 256, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=256, + ) + + self.assertIsNotNone(prepared) + self.assertEqual(prepared.plan.checkpoint_key, key_b) + self.assertEqual(prepared.input_dict["hybrid_restore_slot_ids"].item(), 1) + + def test_same_request_suffix_uses_active_gdn_carry(self): + store = _store() + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=4), + cache_salt="tenant-a", + model_revision="rev-a", + ) + first = bridge.prepare_request( + request_id="req-live", + input_dict={ + "input_ids": torch.arange(128, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=0, + cumulative_hashes_by_prefix_len={128: "h128"}, + ) + bridge.commit_prefill(first) + bridge.finish_request("req-live") + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read( + first.commit_key, + request_id="req-live", + ) + + same_request = bridge.prepare_suffix_only_request( + request_id="req-live", + input_dict={ + "input_ids": torch.arange(128, 192, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=192, + ) + + self.assertIsNotNone(same_request) + self.assertEqual(same_request.input_dict["computed_context_lens"].item(), 128) + self.assertEqual( + same_request.input_dict["hybrid_restore_prefix_lens"].item(), + 128, + ) + self.assertEqual(same_request.input_dict["hybrid_restore_mask"].item(), 0) + self.assertEqual(same_request.input_dict["hybrid_restore_slot_ids"].item(), 0) + + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read( + first.commit_key, + request_id="req-other", + ) + other_request = bridge.prepare_suffix_only_request( + request_id="req-other", + input_dict={ + "input_ids": torch.arange(128, 192, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=192, + ) + + self.assertIsNotNone(other_request) + self.assertEqual(other_request.input_dict["hybrid_restore_mask"].item(), 1) + + def test_same_request_full_prompt_slice_uses_active_gdn_carry(self): + store = _store() + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=4), + cache_salt="tenant-a", + model_revision="rev-a", + ) + first = bridge.prepare_request( + request_id="req-live-full", + input_dict={ + "input_ids": torch.arange(128, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=0, + cumulative_hashes_by_prefix_len={128: "h128"}, + ) + bridge.commit_prefill(first) + bridge.finish_request("req-live-full") + + same_request = bridge.prepare_request( + request_id="req-live-full", + input_dict={ + "input_ids": torch.arange(192, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=192, + cumulative_hashes_by_prefix_len={128: "h128", 192: "h192"}, + ) + + self.assertEqual( + same_request.input_dict["input_ids"].shape, + torch.Size([1, 64]), + ) + self.assertEqual(same_request.input_dict["computed_context_lens"].item(), 128) + self.assertEqual( + same_request.input_dict["hybrid_restore_prefix_lens"].item(), + 128, + ) + self.assertEqual(same_request.input_dict["hybrid_restore_mask"].item(), 0) + + other_request = bridge.prepare_request( + request_id="req-other-full", + input_dict={ + "input_ids": torch.arange(192, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=192, + cumulative_hashes_by_prefix_len={128: "h128", 192: "h192"}, + ) + + self.assertEqual(other_request.input_dict["hybrid_restore_mask"].item(), 1) + + def test_bridge_suffix_only_restore_uses_checkpoint_attention_block_refs(self): + store = _store() + key, _checkpoint = _insert( + store, + 256, + prefix_hash="h256", + gdn_checkpoint_slot=1, + attention_block_refs=(4, 5), + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=3), + cache_salt="tenant-a", + model_revision="rev-a", + ) + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read( + key, + request_id="req-suffix-blocks", + ) + + prepared = bridge.prepare_suffix_only_request( + request_id="req-suffix-blocks", + input_dict={ + "input_ids": torch.arange(256, 272, dtype=torch.int32).unsqueeze(0), + "block_table": torch.tensor([[99]], dtype=torch.int32), + "rotary_position_id": torch.arange( + 16, + dtype=torch.int32, + ).unsqueeze(0), + "rotary_position_ids": torch.arange( + 16, + dtype=torch.int32, + ).view(1, 1, 16).expand(3, 1, 16), + }, + attention_hit_len=256, + request_prefix_len=272, + ) + + self.assertIsNotNone(prepared) + self.assertTrue( + torch.equal( + prepared.input_dict["block_table"], + torch.tensor([[4, 5, 99]], dtype=torch.int32), + ) + ) + expected_positions = torch.arange(256, 272, dtype=torch.int32) + self.assertTrue( + torch.equal( + prepared.input_dict["rotary_position_id"], + expected_positions.unsqueeze(0), + ) + ) + self.assertTrue( + torch.equal( + prepared.input_dict["rotary_position_ids"], + expected_positions.view(1, 1, 16).expand(3, 1, 16), + ) + ) + + def test_bridge_suffix_only_restore_replaces_prefix_block_refs(self): + store = _store() + key, _checkpoint = _insert( + store, + 256, + prefix_hash="h256", + gdn_checkpoint_slot=1, + attention_block_refs=(4,), + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=3), + cache_salt="tenant-a", + model_revision="rev-a", + ) + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read( + key, + request_id="req-suffix-blocks", + ) + + prepared = bridge.prepare_suffix_only_request( + request_id="req-suffix-blocks", + input_dict={ + "input_ids": torch.arange(256, 272, dtype=torch.int32).unsqueeze(0), + "block_table": torch.tensor([[99, 8]], dtype=torch.int32), + }, + attention_hit_len=256, + request_prefix_len=272, + ) + + self.assertIsNotNone(prepared) + self.assertTrue( + torch.equal( + prepared.input_dict["block_table"], + torch.tensor([[4, 8]], dtype=torch.int32), + ) + ) + + def test_bridge_suffix_only_restore_can_commit_boundary_checkpoint(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=3) + input_ids = torch.arange(256, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + restored_key, _checkpoint = _insert( + store, + 128, + prefix_hash=hashes[128], + gdn_checkpoint_slot=2, + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + ) + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read( + restored_key, + request_id="req-suffix-commit", + ) + + prepared = bridge.prepare_suffix_only_request( + request_id="req-suffix-commit", + input_dict={ + "input_ids": torch.arange(128, 256, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=256, + cumulative_hashes_by_prefix_len=hashes, + attention_block_refs_by_prefix_len={256: (8, 9)}, + ) + + self.assertIsNotNone(prepared) + self.assertEqual(prepared.plan.checkpoint_key, restored_key) + self.assertEqual(prepared.commit_prefix_len, 256) + self.assertEqual(prepared.commit_slot, 0) + self.assertEqual(prepared.input_dict["hybrid_commit_mask"].item(), 1) + self.assertEqual(prepared.input_dict["hybrid_commit_slot_ids"].item(), 0) + committed = bridge.commit_prefill(prepared) + self.assertIsNotNone(committed) + self.assertEqual(committed.attention_block_refs, (8, 9)) + self.assertIsNotNone(store.lookup(prepared.commit_key)) + + def test_bridge_suffix_only_commit_combines_restored_and_suffix_refs(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=3) + input_ids = torch.arange(256, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + restored_key, _checkpoint = _insert( + store, + 128, + prefix_hash=hashes[128], + attention_block_refs=(7,), + gdn_checkpoint_slot=2, + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + require_attention_block_refs=True, + ) + _SCHEDULER_PATCH.authorize_hybrid_apc_prefix_read( + restored_key, + request_id="req-suffix-offset-refs", + ) + + prepared = bridge.prepare_suffix_only_request( + request_id="req-suffix-offset-refs", + input_dict={ + "input_ids": torch.arange(128, 256, dtype=torch.int32).unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=256, + cumulative_hashes_by_prefix_len=hashes, + attention_block_refs_by_prefix_len={128: (9,)}, + ) + + self.assertIsNotNone(prepared) + self.assertEqual(prepared.attention_block_refs, (7, 9)) + committed = bridge.commit_prefill(prepared) + self.assertIsNotNone(committed) + self.assertEqual(committed.attention_block_refs, (7, 9)) + + def test_bridge_suffix_only_restore_rejects_ambiguous_prefix_len(self): + store = _store() + _insert(store, 128, prefix_hash="h128-a", gdn_checkpoint_slot=0) + _insert(store, 128, prefix_hash="h128-b", gdn_checkpoint_slot=1) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_ALLOW_UNHASHED_SINGLE_PREFIX_RESTORE": "1"}, + ): + with self.assertRaisesRegex(ValueError, "ambiguous unhashed"): + bridge.prepare_suffix_only_request( + request_id="req-ambiguous", + input_dict={ + "input_ids": torch.arange(128, 256, dtype=torch.int32) + .unsqueeze(0) + }, + attention_hit_len=128, + request_prefix_len=256, + ) + + def test_bridge_does_not_commit_mid_prompt_checkpoint_boundary(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + input_ids = torch.arange(192, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids[:, :128], block_size=128) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + prepared = bridge.prepare_request( + request_id="req-mid-boundary", + input_dict={"input_ids": input_ids}, + attention_hit_len=0, + cumulative_hashes_by_prefix_len=hashes, + ) + + self.assertEqual(prepared.commit_prefix_len, 128) + self.assertIsNone(prepared.commit_slot) + self.assertEqual(prepared.input_dict["hybrid_commit_mask"].item(), 0) + self.assertEqual(allocator.reserved_slots, ()) + self.assertEqual(allocator.free_slots, (0, 1)) + self.assertIsNone(bridge.commit_prefill(prepared)) + + def test_bridge_can_require_scheduler_prefix_hashes(self): + bridge = HybridAPCSchedulerBridge( + store=_store(), + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + allow_local_hash_fallback=False, + ) + + with self.assertRaisesRegex(ValueError, "requires vLLM cumulative prefix hashes"): + bridge.prepare_request( + request_id="req-strict", + input_dict={"input_ids": torch.arange(128, dtype=torch.int32).unsqueeze(0)}, + attention_hit_len=0, + ) + + def test_bridge_can_require_attention_refs_on_commit(self): + store = _store() + input_ids = torch.arange(128, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=HybridAPCSlotAllocator(num_slots=2), + cache_salt="tenant-a", + model_revision="rev-a", + allow_local_hash_fallback=False, + require_attention_block_refs=True, + ) + + prepared = bridge.prepare_request( + request_id="req-refs", + input_dict={"input_ids": input_ids}, + attention_hit_len=0, + cumulative_hashes_by_prefix_len=hashes, + ) + + with self.assertRaisesRegex(ValueError, "requires real attention block refs"): + bridge.commit_prefill(prepared) + + committed = bridge.commit_prefill(prepared, attention_block_refs=(31,)) + + self.assertEqual(committed.attention_block_refs, (31,)) + + def test_bridge_salt_mismatch_does_not_restore_slot_zero(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + input_ids = torch.arange(128, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + _insert(store, 128, prefix_hash=hashes[128], cache_salt="tenant-a") + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-b", + model_revision="rev-a", + reject_unbacked_attention_hits=False, + ) + + prepared = bridge.prepare_request( + request_id="req-salt", + input_dict={"input_ids": input_ids}, + attention_hit_len=128, + cumulative_hashes_by_prefix_len=hashes, + ) + + self.assertIsNone(prepared.plan.checkpoint_key) + self.assertEqual(prepared.plan.restore_checkpoint_prefix_len, 0) + self.assertTrue( + torch.equal( + prepared.input_dict["hybrid_restore_slot_ids"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared.input_dict["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_bridge_env_can_disable_restore_and_commit(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + input_ids = torch.arange(128, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + restored_key, _checkpoint = _insert( + store, + 128, + prefix_hash=hashes[128], + gdn_checkpoint_slot=1, + ) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + with patch.dict( + os.environ, + { + "QWEN36_DISABLE_HYBRID_GDN_RESTORE": "1", + "QWEN36_DISABLE_HYBRID_GDN_COMMIT": "1", + }, + ): + prepared = bridge.prepare_request( + request_id="req-disabled", + input_dict={"input_ids": input_ids}, + attention_hit_len=128, + cumulative_hashes_by_prefix_len=hashes, + ) + + self.assertIsNone(prepared.plan.checkpoint_key) + self.assertIsNone(prepared.commit_slot) + self.assertEqual(store.lookup(restored_key).refcount, 0) + self.assertEqual(allocator.reserved_slots, ()) + self.assertTrue(torch.equal(prepared.input_dict["input_ids"], input_ids)) + self.assertEqual(prepared.input_dict["hybrid_restore_mask"].item(), 0) + self.assertEqual(prepared.input_dict["hybrid_commit_mask"].item(), 0) + self.assertTrue( + torch.equal( + prepared.input_dict["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_bridge_skips_commit_when_checkpoint_already_exists(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + input_ids = torch.arange(128, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + _insert(store, 128, prefix_hash=hashes[128], gdn_checkpoint_slot=1) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + prepared = bridge.prepare_request( + request_id="req-existing", + input_dict={"input_ids": input_ids}, + attention_hit_len=128, + cumulative_hashes_by_prefix_len=hashes, + ) + + self.assertIsNone(prepared.commit_slot) + self.assertEqual(prepared.input_dict["hybrid_commit_mask"].item(), 0) + self.assertEqual(allocator.free_slots, (0, 1)) + self.assertIsNone(bridge.commit_prefill(prepared)) + + def test_bridge_finish_releases_uncommitted_reserved_slot(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + input_ids = torch.arange(128, dtype=torch.int32).unsqueeze(0) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + prepared = bridge.prepare_request( + request_id="req-no-commit", + input_dict={"input_ids": input_ids}, + attention_hit_len=0, + ) + + self.assertEqual(prepared.commit_slot, 0) + self.assertEqual(allocator.reserved_slots, (0,)) + finished = bridge.finish_request("req-no-commit") + + self.assertEqual(finished.state, "FINISHED") + self.assertEqual(allocator.reserved_slots, ()) + self.assertEqual(allocator.free_slots, (1, 0)) + + def test_prepare_with_request_record_keeps_lifecycle_on_original_dict(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + require_attention_block_refs=True, + ) + model = types.SimpleNamespace( + config=types.SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + pad_token_id=0, + ), + hybrid_apc_bridge=bridge, + ) + input_ids = torch.arange(128, dtype=torch.int32).unsqueeze(0) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + original_input = { + "input_ids": input_ids, + "hybrid_request_records": ( + { + "request_id": "req-record", + "vllm_attention_hit_len": 0, + "request_prefix_len": 128, + "cumulative_hashes_by_prefix_len": hashes, + "attention_block_refs_by_prefix_len": {128: (11,)}, + }, + ), + } + + prepared_inputs = prepare_hybrid_apc_request_for_execution( + model, + original_input, + ) + + self.assertIsNot(prepared_inputs, original_input) + self.assertIn("_hybrid_apc_prepared", original_input) + self.assertEqual(allocator.reserved_slots, (0,)) + + finish_hybrid_apc_request(original_input) + + self.assertEqual(allocator.reserved_slots, ()) + self.assertEqual(allocator.committed_slots, (0,)) + + def test_scheduler_metadata_carries_prompt_tokens_only(self): + scheduler = types.SimpleNamespace( + cache_config=types.SimpleNamespace(block_size=128), + ) + request = types.SimpleNamespace( + prompt_token_ids=[11, 12], + all_token_ids=[11, 12, 13], + num_tokens=3, + block_hashes=[], + ) + + metadata = _SCHEDULER_PATCH._scheduler_request_metadata( + scheduler, + request, + num_computed_tokens=2, + ) + + self.assertEqual(metadata["request_prefix_len"], 2) + self.assertEqual(metadata["full_input_ids"], (11, 12)) + + def test_scheduler_metadata_omits_full_tokens_for_cold_chunk(self): + scheduler = types.SimpleNamespace( + cache_config=types.SimpleNamespace(block_size=128), + ) + request = types.SimpleNamespace( + prompt_token_ids=[11, 12, 13], + all_token_ids=[11, 12, 13], + num_tokens=3, + block_hashes=[], + ) + + metadata = _SCHEDULER_PATCH._scheduler_request_metadata( + scheduler, + request, + num_computed_tokens=0, + ) + + self.assertNotIn("full_input_ids", metadata) + + def test_scheduler_request_records_preserve_full_input_ids(self): + scheduler_output = types.SimpleNamespace( + num_scheduled_tokens={"req-a": 16}, + ) + setattr( + scheduler_output, + _SCHEDULER_PATCH._SCHEDULER_OUTPUT_METADATA_ATTR, + { + "req-a": { + "request_prefix_len": 144, + "full_input_ids": tuple(range(144)), + "vllm_attention_hit_len": 128, + }, + }, + ) + model_input = types.SimpleNamespace(request_ids=("req-a",)) + + records = _SCHEDULER_PATCH._hybrid_apc_request_records_from_model_input( + model_input, + scheduler_output, + ) + + self.assertIsNotNone(records) + self.assertEqual(records[0]["full_input_ids"], tuple(range(144))) + self.assertEqual(records[0]["active_suffix_len"], 16) + + def test_prepare_with_request_record_full_input_ids_restores_suffix(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + require_attention_block_refs=True, + ) + model = types.SimpleNamespace( + config=types.SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + pad_token_id=0, + ), + hybrid_apc_bridge=bridge, + ) + full_input_ids = torch.arange(144, dtype=torch.int32).unsqueeze(0) + suffix_input_ids = full_input_ids[:, 128:144] + hashes = build_cumulative_prefix_hashes(full_input_ids, block_size=128) + _insert( + store, + 128, + prefix_hash=hashes[128], + attention_block_refs=(11,), + gdn_checkpoint_slot=0, + ) + original_input = { + "input_ids": suffix_input_ids, + "hybrid_request_records": ( + { + "request_id": "req-record-full-ids", + "vllm_attention_hit_len": 128, + "request_prefix_len": 144, + "full_input_ids": tuple(int(item) for item in full_input_ids[0]), + "cumulative_hashes_by_prefix_len": hashes, + "attention_block_refs_by_prefix_len": {128: (11,)}, + "active_suffix_len": 16, + }, + ), + } + + prepared_inputs = prepare_hybrid_apc_request_for_execution( + model, + original_input, + ) + + self.assertTrue( + torch.equal(prepared_inputs["input_ids"], suffix_input_ids), + ) + self.assertEqual(prepared_inputs["computed_context_lens"].item(), 128) + self.assertEqual(prepared_inputs["full_context_lens"].item(), 144) + self.assertEqual(prepared_inputs["num_queries"].item(), 16) + + def test_bridge_evicts_lru_checkpoint_before_reserving_when_slots_full(self): + store = _store(max_checkpoints=2) + allocator = HybridAPCSlotAllocator(num_slots=2) + bridge = HybridAPCSchedulerBridge( + store=store, + slot_allocator=allocator, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + committed_keys = [] + for request_index in range(3): + input_ids = ( + torch.arange(128, dtype=torch.int32).unsqueeze(0) + + request_index * 1000 + ) + hashes = build_cumulative_prefix_hashes(input_ids, block_size=128) + prepared = bridge.prepare_request( + request_id=f"req-{request_index}", + input_dict={"input_ids": input_ids}, + attention_hit_len=0, + cumulative_hashes_by_prefix_len=hashes, + ) + self.assertIsNotNone(prepared.commit_slot) + bridge.commit_prefill(prepared) + bridge.finish_request(prepared.request_id) + committed_keys.append(prepared.commit_key) + + self.assertIsNone(store.lookup(committed_keys[0])) + self.assertIsNotNone(store.lookup(committed_keys[1])) + self.assertIsNotNone(store.lookup(committed_keys[2])) + self.assertEqual(len(allocator.committed_slots), 2) + self.assertEqual(allocator.reserved_slots, ()) + self.assertEqual(allocator.free_slots, ()) + + def test_store_releases_old_slot_when_replacing_same_key(self): + store = _store() + allocator = HybridAPCSlotAllocator(num_slots=2) + store.set_checkpoint_slot_releaser(allocator.release_committed) + + key = store.make_key( + cumulative_prefix_hash="same-prefix", + prefix_len=128, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + first_slot = allocator.reserve() + first = store.insert( + key=key, + attention_block_refs=(1,), + gdn_checkpoint_slot=first_slot, + ) + allocator.mark_committed(first.gdn_checkpoint_slot) + + second_slot = allocator.reserve() + second = store.insert( + key=key, + attention_block_refs=(2,), + gdn_checkpoint_slot=second_slot, + ) + allocator.mark_committed(second.gdn_checkpoint_slot) + + self.assertEqual(store.lookup(key).gdn_checkpoint_slot, second_slot) + self.assertEqual(allocator.committed_slots, (second_slot,)) + self.assertEqual(allocator.free_slots, (first_slot,)) + + def test_attention_block_eviction_unpublishes_scheduler_checkpoint(self): + store = _store() + key, _checkpoint = _insert( + store, + 128, + attention_block_refs=(7,), + gdn_checkpoint_slot=0, + ) + _SCHEDULER_PATCH.register_hybrid_apc_gdn_checkpoint(key) + self.assertTrue(_SCHEDULER_PATCH.unregister_hybrid_apc_gdn_checkpoint(key)) + + _SCHEDULER_PATCH.register_hybrid_apc_gdn_checkpoint(key) + self.assertEqual(store.on_attention_block_evicted(7), [key]) + self.assertFalse(_SCHEDULER_PATCH.unregister_hybrid_apc_gdn_checkpoint(key)) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_validation.py b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_validation.py new file mode 100644 index 00000000..026140e3 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_apc_validation.py @@ -0,0 +1,405 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +import json +import sys +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + + +_REPO_ROOT = Path(__file__).resolve().parents[5] +_VALIDATION_PATH = _REPO_ROOT / "validation_scripts" / "qwen36_hybrid_apc_validation.py" +_SPEC = importlib.util.spec_from_file_location( + "qwen36_hybrid_apc_validation_under_test", + _VALIDATION_PATH, +) +_VALIDATION = importlib.util.module_from_spec(_SPEC) +sys.modules[_SPEC.name] = _VALIDATION +_SPEC.loader.exec_module(_VALIDATION) + + +def _args(**overrides): + defaults = { + "shared_prefix": "shared", + "suffix_a": " suffix a", + "suffix_b": " suffix b", + "shared_prefix_2": "shared two", + "suffix_c": " suffix c", + "suffix_d": " suffix d", + "max_num_seqs": 2, + "max_tokens": 8, + "compiled_artifacts": None, + "model_path": "/tmp/model", + "cte_buckets": ["256,512"], + "align_prompts_to_cte_buckets": False, + "require_real_tokens": True, + "dummy_token_ids": [0], + "output_json": None, + "block_size": 256, + "gdn_checkpoint_interval": 256, + "seq_len": 2048, + "compact_boundary_lens": None, + "compact_suffix_tokens": 16, + "compact_min_requests": 50, + "compact_min_grouped_speedup": 1.5, + "hybrid_apc_require_vllm_metadata": True, + "hybrid_apc_disable_unbacked_prefix_reads": True, + "hybrid_apc_enable_backed_prefix_reads": True, + "hybrid_apc_max_backed_prefix_read_len": 0, + "max_gdn_checkpoint_slots": 8, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +def _fake_generate_batch(tokens_by_label): + def fake_generate_batch(_args, *, enable_hybrid_apc, labeled_prompts): + return { + label: { + "tokens": list(tokens_by_label[label]), + "elapsed_seconds": 0.01, + } + for label, _prompt in labeled_prompts + } + + return fake_generate_batch + + +def _fake_generate_grouped_batch(tokens_by_label): + def fake_generate_grouped_batch( + _args, + *, + enable_hybrid_apc, + labeled_prompt_groups, + ): + results = {} + for group in labeled_prompt_groups: + for label, _prompt in group: + results[label] = { + "tokens": list(tokens_by_label[label]), + "elapsed_seconds": 0.01, + } + return results + + return fake_generate_grouped_batch + + +def _compact_reference_label(label): + if label.startswith("warm_full_"): + return "cold_full_" + label[len("warm_full_") :] + if label.startswith("warm_partial_"): + return "cold_partial_" + label[len("warm_partial_") :] + if label.startswith("mixed_warm_"): + return "cold_partial_" + label[len("mixed_warm_") :] + if label.startswith("eviction_probe_partial_"): + return "cold_partial_" + label[len("eviction_probe_partial_") :] + if label.startswith("mixed_cold__"): + return "cold_mixed__" + label[len("mixed_cold__") :] + if label.startswith("warmup"): + return label + return label + + +def _fake_compact_generate_batch(_args, *, enable_hybrid_apc, labeled_prompts): + del _args, enable_hybrid_apc + return { + label: { + "tokens": [sum(ord(ch) for ch in _compact_reference_label(label)) % 997 + 1], + "elapsed_seconds": 2.0, + } + for label, _prompt in labeled_prompts + } + + +def _fake_compact_generate_grouped_batch( + _args, + *, + enable_hybrid_apc, + labeled_prompt_groups, +): + del _args, enable_hybrid_apc + results = {} + for group in labeled_prompt_groups: + elapsed = 1.0 if len(group) > 1 else 0.5 + for label, _prompt in group: + results[label] = { + "tokens": [ + sum(ord(ch) for ch in _compact_reference_label(label)) % 997 + 1 + ], + "elapsed_seconds": elapsed, + } + return results + + +class TestHybridAPCValidationRealTokens(unittest.TestCase): + def test_bucket_alignment_pads_prompt_token_ids(self): + class FakeTokenizer: + pad_token_id = 99 + eos_token_id = None + + def encode(self, prompt, add_special_tokens=False): + del add_special_tokens + return list(range(len(prompt.split()))) + + class FakeAutoTokenizer: + @staticmethod + def from_pretrained(_model_path, trust_remote_code): + del trust_remote_code + return FakeTokenizer() + + fake_transformers = SimpleNamespace(AutoTokenizer=FakeAutoTokenizer) + with patch.dict(sys.modules, {"transformers": fake_transformers}): + aligned = _VALIDATION._maybe_bucket_align_labeled_prompts( + _args( + cte_buckets=["4,8"], + align_prompts_to_cte_buckets=True, + ), + [("prompt", "one two three")], + ) + + self.assertEqual( + aligned, + [("prompt", {"prompt_token_ids": [0, 1, 2, 99]})], + ) + + def test_bucket_alignment_rejects_too_long_prompt(self): + with self.assertRaisesRegex(ValueError, "exceeds compiled CTE buckets"): + _VALIDATION._next_bucket(9, [4, 8]) + + def test_real_token_checks_fail_all_dummy_tokens(self): + checks = _VALIDATION._real_token_checks( + { + "cold_full": {"tokens": [0, 0, 0]}, + "warm_full": {"tokens": [0, 0, 0]}, + }, + {0}, + ) + + self.assertFalse(checks["passed"]) + self.assertEqual( + checks["checks"]["cold_full"]["failure"], + "generated tokens are empty or all configured dummy tokens", + ) + + def test_exactness_can_require_non_dummy_generated_tokens(self): + tokens_by_label = { + "cold_full": [0, 0], + "cold_partial": [0, 0], + "warmup_full": [0, 0], + "warm_full": [0, 0], + "warmup_partial": [0, 0], + "warm_partial": [0, 0], + } + with tempfile.TemporaryDirectory() as tmpdir: + output_json = Path(tmpdir) / "report.json" + with patch.object( + _VALIDATION, + "_generate_batch", + side_effect=_fake_generate_batch(tokens_by_label), + ): + rc = _VALIDATION.run_exactness( + _args(output_json=output_json), + ) + + self.assertEqual(rc, 1) + report = output_json.read_text(encoding="utf-8") + self.assertIn('"full_prefix_exact": true', report) + self.assertIn('"partial_prefix_exact": true', report) + self.assertIn('"real_generated_tokens_passed": false', report) + + def test_exactness_passes_when_real_tokens_are_present(self): + tokens_by_label = { + "cold_full": [42, 0], + "cold_partial": [43, 0], + "warmup_full": [42, 0], + "warm_full": [42, 0], + "warmup_partial": [42, 0], + "warm_partial": [43, 0], + } + with patch.object( + _VALIDATION, + "_generate_batch", + side_effect=_fake_generate_batch(tokens_by_label), + ): + rc = _VALIDATION.run_exactness(_args()) + + self.assertEqual(rc, 0) + + def test_batched_exactness_checks_two_concurrent_partials(self): + tokens_by_label = { + "cold_partial_a": [42, 0], + "cold_partial_b": [43, 0], + "warmup_full_a": [44, 0], + "warmup_full_b": [45, 0], + "warm_partial_a": [42, 0], + "warm_partial_b": [43, 0], + } + with tempfile.TemporaryDirectory() as tmpdir: + output_json = Path(tmpdir) / "batched_report.json" + with patch.object( + _VALIDATION, + "_generate_batch", + side_effect=_fake_generate_batch(tokens_by_label), + ): + with patch.object( + _VALIDATION, + "_generate_grouped_batch", + side_effect=_fake_generate_grouped_batch(tokens_by_label), + ): + rc = _VALIDATION.run_batched_exactness( + _args(output_json=output_json) + ) + + self.assertEqual(rc, 0) + report = output_json.read_text(encoding="utf-8") + self.assertIn('"batched_partial_a_exact": true', report) + self.assertIn('"batched_partial_b_exact": true', report) + self.assertIn('"max_num_seqs": 2', report) + + def test_batched_exactness_requires_second_prefix(self): + with self.assertRaisesRegex(ValueError, "--shared-prefix-2 is required"): + _VALIDATION.run_batched_exactness(_args(shared_prefix_2="")) + + def test_batched_exactness_preflights_tkg_batch_size(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "neuron_config.json" + config_path.write_text( + json.dumps( + {"neuron_config": {"tkg_batch_size": 1, "ctx_batch_size": 2}} + ), + encoding="utf-8", + ) + + with self.assertRaisesRegex( + ValueError, + "tkg_batch_size=1 and max_num_seqs=2", + ): + _VALIDATION.run_batched_exactness( + _args(compiled_artifacts=tmpdir, max_num_seqs=2) + ) + + def test_batched_exactness_preflights_ctx_batch_size(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "neuron_config.json" + config_path.write_text( + json.dumps( + {"neuron_config": {"tkg_batch_size": 2, "ctx_batch_size": 1}} + ), + encoding="utf-8", + ) + + with self.assertRaisesRegex( + ValueError, + "ctx_batch_size=1 and max_num_seqs=2", + ): + _VALIDATION.run_batched_exactness( + _args(compiled_artifacts=tmpdir, max_num_seqs=2) + ) + + def test_runtime_additional_config_uses_compiled_max_prompt_length(self): + additional_config = { + "max_prompt_length": 512, + "override_neuron_config": { + "max_context_length": 512, + "context_encoding_buckets": [256, 512], + }, + } + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "neuron_config.json" + config_path.write_text( + json.dumps( + { + "neuron_config": { + "max_context_length": 131072, + "context_encoding_buckets": [256, 512], + } + } + ), + encoding="utf-8", + ) + + aligned = _VALIDATION._align_additional_config_to_compiled_artifact( + _args(compiled_artifacts=tmpdir), + additional_config, + ) + + self.assertEqual(aligned["max_prompt_length"], 131072) + self.assertEqual( + aligned["override_neuron_config"]["max_context_length"], + 131072, + ) + self.assertEqual( + aligned["override_neuron_config"]["context_encoding_buckets"], + [256, 512], + ) + self.assertEqual(additional_config["max_prompt_length"], 512) + + def test_compact_boundary_lengths_cover_checkpoint_edges(self): + self.assertEqual( + _VALIDATION._compact_boundary_lengths( + _args(block_size=4, seq_len=32, max_tokens=1, compact_suffix_tokens=2) + ), + [3, 4, 5, 7, 8, 9], + ) + + def test_compact_gate_requires_strict_metadata(self): + with self.assertRaisesRegex(ValueError, "requires --hybrid-apc-require"): + _VALIDATION.run_compact_gate( + _args(hybrid_apc_require_vllm_metadata=False) + ) + + def test_compact_gate_reports_targeted_exactness_and_speedup(self): + class FakeTokenizer: + pad_token_id = 0 + eos_token_id = 0 + + def encode(self, prompt, add_special_tokens=False): + del add_special_tokens + return list(range(len(prompt.split()))) + + class FakeAutoTokenizer: + @staticmethod + def from_pretrained(_model_path, trust_remote_code): + del trust_remote_code + return FakeTokenizer() + + fake_transformers = SimpleNamespace(AutoTokenizer=FakeAutoTokenizer) + with tempfile.TemporaryDirectory() as tmpdir: + output_json = Path(tmpdir) / "compact.json" + with patch.dict(sys.modules, {"transformers": fake_transformers}): + with patch.object( + _VALIDATION, + "_generate_batch", + side_effect=_fake_compact_generate_batch, + ): + with patch.object( + _VALIDATION, + "_generate_grouped_batch", + side_effect=_fake_compact_generate_grouped_batch, + ): + rc = _VALIDATION.run_compact_gate( + _args( + block_size=4, + gdn_checkpoint_interval=4, + compact_boundary_lens=["3,4"], + compact_suffix_tokens=2, + compact_min_requests=20, + output_json=output_json, + ) + ) + + self.assertEqual(rc, 0) + report = json.loads(output_json.read_text(encoding="utf-8")) + self.assertTrue(report["compact_gate_passed"]) + self.assertEqual(report["boundary_lengths"], [3, 4]) + self.assertGreaterEqual(report["acceptance"]["request_count"], 20) + self.assertTrue(report["acceptance"]["exactness_passed"]) + self.assertTrue(report["acceptance"]["speedup_passed"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py new file mode 100644 index 00000000..6e6893c7 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py @@ -0,0 +1,464 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +import unittest +from math import prod +from unittest.mock import patch + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from neuronx_distributed.quantization.quantization_config import KVQuantizationConfig +from src.modeling_qwen35 import ( + HybridDeltaNetCacheManager, + Qwen35InferenceConfig, + QwenHybridBlockKVCacheManager, + plan_gdn_apc_reuse, +) + + +def _make_config(**overrides): + neuron_overrides = overrides.pop("neuron_overrides", {}) + neuron_kwargs = dict( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + max_batch_size=2, + kv_cache_batch_size=2, + seq_len=16, + torch_dtype=torch.bfloat16, + ) + neuron_kwargs.update(neuron_overrides) + neuron_config = NeuronConfig(**neuron_kwargs) + defaults = dict( + hidden_size=5120, + num_hidden_layers=64, + num_attention_heads=24, + num_key_value_heads=4, + head_dim=256, + intermediate_size=17408, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=131072, + rope_theta=10000, + hidden_act="silu", + tie_word_embeddings=False, + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + use_hybrid_cache_manager=True, + ) + defaults.update(overrides) + return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + + +def _numel(shape): + return prod(int(dim) for dim in shape) + + +def _managed_cache_numel(mgr): + return sum(param.numel() for param in mgr.past_key_values) + + +def _deltanet_state_numel(config, max_batch_size): + tp_degree = config.neuron_config.tp_degree + local_num_value_heads = config.linear_num_value_heads // tp_degree + local_num_key_heads = config.linear_num_key_heads // tp_degree + recurrent = ( + max_batch_size + * local_num_value_heads + * config.linear_key_head_dim + * config.linear_value_head_dim + ) + conv_dim = ( + 2 * local_num_key_heads * config.linear_key_head_dim + + local_num_value_heads * config.linear_value_head_dim + ) + conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) + return recurrent + conv + + +def _local_value_heads(config): + return config.linear_num_value_heads // config.neuron_config.tp_degree + + +def _local_key_heads(config): + return config.linear_num_key_heads // config.neuron_config.tp_degree + + +def _conv_dim(config): + return ( + 2 * _local_key_heads(config) * config.linear_key_head_dim + + _local_value_heads(config) * config.linear_value_head_dim + ) + + +def _recurrent_shape(config, batch_size): + return [ + batch_size, + _local_value_heads(config), + config.linear_key_head_dim, + config.linear_value_head_dim, + ] + + +def _conv_shape(config, batch_size): + return [batch_size, _conv_dim(config), config.linear_conv_kernel_dim - 1] + + +class TestQwenHybridBlockKVCacheManager(unittest.TestCase): + def test_hybrid_apc_block_cache_dequantizes_selected_attention_blocks_only(self): + kv_quant_config = KVQuantizationConfig( + quant_dtype=torch.bfloat16, + direct_cast=True, + ) + config = _make_config( + tp_degree=1, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=4, + linear_num_value_heads=4, + linear_num_key_heads=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + gdn_checkpoint_interval=128, + use_hybrid_cache_manager=False, + use_hybrid_apc_manager=True, + neuron_overrides={ + "batch_size": 1, + "max_batch_size": 1, + "kv_cache_batch_size": 1, + "is_block_kv_layout": True, + "is_prefix_caching": True, + "max_length": 9600, + "pa_num_blocks": 16, + "pa_block_size": 128, + "torch_dtype": torch.float32, + "kv_quant_config": kv_quant_config, + }, + ) + mgr = QwenHybridBlockKVCacheManager( + config, + num_kv_head=config.num_key_value_heads, + ) + seen_shapes = [] + original_dequantize = mgr._dequantize_cache + + def record_dequantize(cache_tensor, layer_idx, is_key=True): + seen_shapes.append((layer_idx, is_key, tuple(cache_tensor.shape))) + return original_dequantize(cache_tensor, layer_idx, is_key=is_key) + + mgr._dequantize_cache = record_dequantize + + cache = mgr.get_cache(active_block_table=torch.tensor([[0, 2]], dtype=torch.int64)) + + self.assertEqual(cache[0][0].shape, mgr._LINEAR_PLACEHOLDER_SHAPE) + self.assertEqual(cache[1][0].shape, mgr._LINEAR_PLACEHOLDER_SHAPE) + self.assertEqual(cache[2][0].shape, mgr._LINEAR_PLACEHOLDER_SHAPE) + self.assertEqual(cache[3][0].dtype, torch.float32) + self.assertEqual(cache[3][1].dtype, torch.float32) + self.assertEqual( + seen_shapes, + [ + (3, True, (1, 4, 256, 4)), + (3, False, (1, 4, 256, 4)), + ], + ) + + +class TestHybridDeltaNetCacheManager(unittest.TestCase): + def test_allocates_per_layer_cache_shapes(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) + self.assertEqual(list(mgr.past_key_values[0].shape), _recurrent_shape(config, 2)) + self.assertEqual(list(mgr.past_key_values[1].shape), _conv_shape(config, 2)) + self.assertEqual(mgr.past_key_values[0].dtype, torch.float32) + self.assertEqual(mgr.past_key_values[1].dtype, torch.bfloat16) + self.assertEqual(mgr.layer_types[3], "full_attention") + self.assertEqual(mgr.past_key_values[6].dim(), 4) + self.assertEqual(mgr.past_key_values[7].shape[2], 16) + + def test_get_cache_slices_only_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) + recurrent_state, conv_state = cache[0] + full_k, full_v = cache[3] + + self.assertEqual(list(recurrent_state.shape), _recurrent_shape(config, 1)) + self.assertEqual(list(conv_state.shape), _conv_shape(config, 1)) + self.assertEqual(full_k.shape[0], 2) + self.assertEqual(full_v.shape[0], 2) + self.assertEqual(full_k.shape[2], 4) + self.assertEqual(full_v.shape[2], 4) + + def test_get_seq_length_uses_first_full_attention_layer(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) + flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] + + self.assertEqual(nested_cache[0][1].shape[2], 3) + self.assertEqual(mgr.get_seq_length(nested_cache), 5) + self.assertEqual(mgr.get_seq_length(flat_cache), 5) + + def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + with torch.no_grad(): + mgr.past_key_values[0][0, ...].fill_(7) + mgr.past_key_values[0][1, ...].fill_(13) + mgr.past_key_values[1][0, ...].fill_(17) + mgr.past_key_values[1][1, ...].fill_(19) + + recurrent_state, conv_state = mgr.get_cache( + seq_len=4, + seq_ids=torch.tensor([1, 0]), + )[0] + + self.assertTrue(torch.all(recurrent_state[0] == 13)) + self.assertTrue(torch.all(recurrent_state[1] == 7)) + self.assertTrue(torch.all(conv_state[0] == 19)) + self.assertTrue(torch.all(conv_state[1] == 17)) + + def test_deltanet_update_scatters_by_seq_id(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones(_recurrent_shape(config, 1), dtype=torch.bfloat16) + conv = torch.ones(_conv_shape(config, 1), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_conv[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 1)) + self.assertTrue(torch.all(updated_conv[1] == 1)) + + def test_deltanet_full_batch_update_replaces_state_cache(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones(_recurrent_shape(config, 2), dtype=torch.bfloat16) + conv = torch.ones(_conv_shape(config, 2), dtype=torch.bfloat16) + recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([0, 1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 3)) + self.assertTrue(torch.all(updated_recurrent[1] == 5)) + self.assertTrue(torch.all(updated_conv[0] == 11)) + self.assertTrue(torch.all(updated_conv[1] == 13)) + + def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones(_recurrent_shape(config, 1), dtype=torch.bfloat16) + conv = torch.ones(_conv_shape(config, 1), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([99]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 0)) + self.assertTrue(torch.all(updated_recurrent[2] == 1)) + self.assertTrue(torch.all(updated_conv[2] == 1)) + + def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): + short_config = _make_config(neuron_overrides={"seq_len": 128}) + long_config = _make_config(neuron_overrides={"seq_len": 2048}) + short_mgr = HybridDeltaNetCacheManager( + short_config, num_kv_head=short_config.num_key_value_heads + ) + long_mgr = HybridDeltaNetCacheManager( + long_config, num_kv_head=long_config.num_key_value_heads + ) + + self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) + self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) + self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) + + def test_get_cache_trims_padding_row_without_seq_ids(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] + + self.assertEqual(list(recurrent_state.shape), _recurrent_shape(config, 2)) + self.assertEqual(list(conv_state.shape), _conv_shape(config, 2)) + + def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + new_key_values = [] + for idx in range(4): + first = mgr.past_key_values[2 * idx] + second = mgr.past_key_values[2 * idx + 1] + new_key_values.append( + ( + torch.full_like(first, fill_value=idx + 1), + torch.full_like(second, fill_value=idx + 11), + ) + ) + + position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) + full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) + full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) + with patch.object( + mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) + ) as update_kv: + updated = mgr.update_cache( + is_for_context_encoding=True, + seq_ids=torch.tensor([0, 1], dtype=torch.int32), + position_ids=position_ids, + new_key_values=new_key_values, + seq_len=16, + ) + + self.assertEqual(update_kv.call_count, 1) + self.assertEqual(update_kv.call_args.kwargs["idx"], 3) + self.assertTrue(torch.all(updated[0] == 1)) + self.assertTrue(torch.all(updated[1] == 11)) + self.assertTrue(torch.all(updated[6] == 4)) + self.assertTrue(torch.all(updated[7] == 14)) + + def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): + config = _make_config(neuron_overrides={"seq_len": 1024}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) + deltanet_layers = config.layer_types.count("linear_attention") + legacy_total_numel = ( + full_kv_per_layer * config.num_hidden_layers + + _deltanet_state_numel(config, max_batch_size) * deltanet_layers + ) + expected_savings = full_kv_per_layer * deltanet_layers + + self.assertEqual( + legacy_total_numel - _managed_cache_numel(mgr), + expected_savings, + ) + self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) + + def test_rejects_unsupported_hybrid_modes(self): + unsupported_cases = [ + ({"is_block_kv_layout": True}, "block KV layout"), + ({"padding_side": "left"}, "left padding"), + ({"flash_decoding_enabled": True}, "flash decoding"), + ({"is_continuous_batching": True}, "continuous batching"), + ] + + for neuron_overrides, expected_error in unsupported_cases: + with self.subTest(expected_error=expected_error): + config = _make_config(neuron_overrides=neuron_overrides) + with self.assertRaisesRegex(ValueError, expected_error): + HybridDeltaNetCacheManager( + config, num_kv_head=config.num_key_value_heads + ) + + config = _make_config() + config.neuron_config.kv_cache_quant = True + with self.assertRaisesRegex(ValueError, "KV cache quantization"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config( + neuron_overrides={ + "attention_dp_degree": 2, + "batch_size": 2, + "ctx_batch_size": 2, + "tkg_batch_size": 2, + "max_batch_size": 2, + "kv_cache_batch_size": 2, + "is_continuous_batching": True, + } + ) + with self.assertRaisesRegex(ValueError, "attention data parallelism"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config() + config.neuron_config.kv_cache_tiling = True + with self.assertRaisesRegex(ValueError, "KV cache tiling"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + def test_legacy_config_default_is_disabled(self): + config = _make_config(use_hybrid_cache_manager=False) + self.assertFalse(config.use_hybrid_cache_manager) + + def test_gdn_apc_plan_uses_intersection_across_state_families(self): + plan = plan_gdn_apc_reuse( + attention_hit_len=2048, + recurrent_hit_len=1536, + conv_hit_len=1792, + request_prefix_len=2304, + gdn_checkpoint_interval=256, + ) + + self.assertEqual(plan.reusable_prefix_len, 1536) + self.assertEqual(plan.restore_checkpoint_prefix_len, 1536) + self.assertEqual(plan.residual_replay_len, 0) + self.assertEqual(plan.suffix_len, 768) + + def test_gdn_apc_plan_replays_residual_after_prior_checkpoint(self): + plan = plan_gdn_apc_reuse( + attention_hit_len=1152, + recurrent_hit_len=1152, + conv_hit_len=1152, + request_prefix_len=1408, + gdn_checkpoint_interval=256, + ) + + self.assertEqual(plan.reusable_prefix_len, 1152) + self.assertEqual(plan.restore_checkpoint_prefix_len, 1024) + self.assertEqual(plan.residual_replay_len, 128) + self.assertEqual(plan.suffix_len, 256) + + def test_gdn_apc_plan_validates_lengths(self): + with self.assertRaisesRegex(ValueError, "gdn_checkpoint_interval"): + plan_gdn_apc_reuse( + attention_hit_len=1, + recurrent_hit_len=1, + conv_hit_len=1, + request_prefix_len=1, + gdn_checkpoint_interval=0, + ) + + with self.assertRaisesRegex(ValueError, "attention_hit_len"): + plan_gdn_apc_reuse( + attention_hit_len=-1, + recurrent_hit_len=1, + conv_hit_len=1, + request_prefix_len=1, + gdn_checkpoint_interval=256, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_artifact_config_audit.py b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_artifact_config_audit.py new file mode 100644 index 00000000..75194eb2 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_artifact_config_audit.py @@ -0,0 +1,102 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +import json +import sys +import tempfile +import unittest +from pathlib import Path + + +_REPO_ROOT = Path(__file__).resolve().parents[5] +_AUDIT_PATH = _REPO_ROOT / "validation_scripts" / "qwen36_artifact_config_audit.py" +_SPEC = importlib.util.spec_from_file_location( + "qwen36_artifact_config_audit_under_test", + _AUDIT_PATH, +) +_AUDIT = importlib.util.module_from_spec(_SPEC) +sys.modules[_SPEC.name] = _AUDIT +_SPEC.loader.exec_module(_AUDIT) + + +class TestQwen36ArtifactConfigAudit(unittest.TestCase): + def test_audit_flags_current_low_headroom_nki_chunked_shape(self): + with tempfile.TemporaryDirectory() as tmpdir: + artifact = Path(tmpdir) / "qwen36_nki_chunked_artifact" + artifact.mkdir() + (artifact / "neuron_config.json").write_text( + json.dumps( + { + "seq_len": 4096, + "batch_size": 2, + "ctx_batch_size": 2, + "pa_block_size": 256, + "pa_num_blocks": 33, + "max_gdn_checkpoint_slots": 8, + "context_encoding_buckets": [256, 512, 1024, 2048, 4096], + "prefix_buckets": [4096], + "is_prefix_caching": True, + "use_hybrid_apc_manager": True, + } + ) + ) + + summary = _AUDIT.audit( + artifact=artifact, + compile_log=None, + recommended_block_size=32, + min_usable_headroom_blocks=8, + strict_hybrid_gate=True, + ) + + warning_codes = {warning["code"] for warning in summary["warnings"]} + self.assertEqual(summary["pa_min_blocks"], 32) + self.assertEqual(summary["pa_usable_headroom_blocks"], 1) + self.assertIn("non_recommended_block_size", warning_codes) + self.assertIn("low_pa_headroom", warning_codes) + self.assertIn("strict_gate_boundary_slots_exceed_gdn_slots", warning_codes) + self.assertIn("nki_chunked_deltanet_cte", warning_codes) + + def test_audit_reads_nested_neuron_config(self): + with tempfile.TemporaryDirectory() as tmpdir: + artifact = Path(tmpdir) / "qwen36_128k_fp8_artifact" + artifact.mkdir() + (artifact / "neuron_config.json").write_text( + json.dumps( + { + "ctx_batch_size": 1, + "max_gdn_checkpoint_slots": 64, + "use_hybrid_apc_manager": True, + "neuron_config": { + "seq_len": 131072, + "batch_size": 1, + "pa_block_size": 256, + "pa_num_blocks": 512, + "context_encoding_buckets": [256, 512], + "prefix_buckets": [256, 512, 1024, 2048, 4096, 8192, 16384], + "is_prefix_caching": True, + }, + } + ) + ) + + summary = _AUDIT.audit( + artifact=artifact, + compile_log=None, + recommended_block_size=256, + min_usable_headroom_blocks=0, + strict_hybrid_gate=False, + ) + + self.assertEqual(summary["seq_len"], 131072) + self.assertEqual(summary["pa_block_size"], 256) + self.assertEqual(summary["pa_num_blocks"], 512) + self.assertEqual(summary["pa_min_blocks"], 512) + self.assertEqual(summary["context_encoding_buckets"], [256, 512]) + self.assertEqual(summary["prefix_buckets"][-1], 16384) + self.assertTrue(summary["is_prefix_caching"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_chat_proxy.py b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_chat_proxy.py new file mode 100644 index 00000000..11246e54 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_chat_proxy.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +import unittest +from pathlib import Path + + +_REPO_ROOT = Path(__file__).resolve().parents[5] +_PROXY_PATH = ( + _REPO_ROOT / "contrib" / "models" / "Qwen3.6-27B" / "vllm" / "qwen36_chat_proxy.py" +) +_SPEC = importlib.util.spec_from_file_location("qwen36_chat_proxy_under_test", _PROXY_PATH) +_PROXY = importlib.util.module_from_spec(_SPEC) +_SPEC.loader.exec_module(_PROXY) + + +class TestQwen36ChatProxy(unittest.TestCase): + def test_default_policy_disables_thinking(self): + payload = {"messages": [{"role": "user", "content": "hello"}]} + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=True) + + self.assertFalse(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": False}) + + def test_default_thinking_enables_when_no_request_toggle(self): + payload = {"messages": [{"role": "user", "content": "hello"}]} + + enabled = _PROXY._apply_thinking_policy( + payload, + allow_thinking=True, + default_thinking=True, + ) + + self.assertTrue(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": True}) + + def test_default_thinking_allows_explicit_disable(self): + payload = {"enable_thinking": False} + + enabled = _PROXY._apply_thinking_policy( + payload, + allow_thinking=True, + default_thinking=True, + ) + + self.assertFalse(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": False}) + self.assertNotIn("enable_thinking", payload) + + def test_force_disabled_policy_overrides_request(self): + payload = { + "enable_thinking": True, + "chat_template_kwargs": {"enable_thinking": True, "foo": "bar"}, + } + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=False) + + self.assertFalse(enabled) + self.assertEqual( + payload["chat_template_kwargs"], + {"enable_thinking": False, "foo": "bar"}, + ) + self.assertNotIn("enable_thinking", payload) + + def test_allow_thinking_accepts_top_level_toggle(self): + payload = {"enable_thinking": True, "chat_template_kwargs": {"foo": "bar"}} + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=True) + + self.assertTrue(enabled) + self.assertEqual( + payload["chat_template_kwargs"], + {"enable_thinking": True, "foo": "bar"}, + ) + self.assertNotIn("enable_thinking", payload) + + def test_allow_thinking_accepts_native_chat_template_kwargs(self): + payload = {"chat_template_kwargs": {"enable_thinking": "true"}} + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=True) + + self.assertTrue(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": True}) + + def test_allow_thinking_accepts_reasoning_effort(self): + payload = {"reasoning_effort": "low"} + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=True) + + self.assertTrue(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": True}) + self.assertNotIn("reasoning_effort", payload) + + def test_reasoning_effort_none_disables_thinking(self): + payload = {"reasoning_effort": "none"} + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=True) + + self.assertFalse(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": False}) + self.assertNotIn("reasoning_effort", payload) + + def test_thinking_budget_tokens_can_toggle(self): + payload = {"thinking": {"budget_tokens": 128}} + + enabled = _PROXY._apply_thinking_policy(payload, allow_thinking=True) + + self.assertTrue(enabled) + self.assertEqual(payload["chat_template_kwargs"], {"enable_thinking": True}) + self.assertNotIn("thinking", payload) + + def test_system_and_developer_messages_are_hoisted(self): + messages = [ + {"role": "user", "content": "first"}, + {"role": "system", "content": "sys"}, + {"role": "developer", "content": [{"type": "text", "text": "dev"}]}, + {"role": "assistant", "content": "ok"}, + ] + + normalized = _PROXY._normalize_messages_for_qwen(messages) + + self.assertEqual(normalized[0], {"role": "system", "content": "sys\n\ndev"}) + self.assertEqual([message["role"] for message in normalized], ["system", "user", "assistant"]) + + def test_chat_path_allows_trailing_slash_and_query(self): + self.assertEqual(_PROXY._request_path("/v1/chat/completions"), "/v1/chat/completions") + self.assertEqual(_PROXY._request_path("/v1/chat/completions/"), "/v1/chat/completions") + self.assertEqual( + _PROXY._request_path("/v1/chat/completions?api-version=1"), + "/v1/chat/completions", + ) + + def test_streaming_thinking_start_is_injected_when_missing(self): + event = ( + b'data: {"choices":[{"delta":{"content":"Here is the thought"}}]}\n\n' + ) + + patched, decided, changed = _PROXY._prepend_think_start_to_sse_event(event) + + self.assertTrue(decided) + self.assertTrue(changed) + self.assertIn(b'"content":"\\nHere is the thought"', patched) + + def test_streaming_thinking_start_is_not_duplicated(self): + event = b'data: {"choices":[{"delta":{"content":"\\n\\n\\nThought"}}]}\n\n' + + patched, decided, changed = _PROXY._prepend_think_start_to_sse_event(event) + + self.assertTrue(decided) + self.assertFalse(changed) + self.assertEqual(patched, event) + + def test_streaming_usage_chunk_does_not_decide_thinking_start(self): + event = b'data: {"choices":[],"usage":{"completion_tokens":1}}\n\n' + + patched, decided, changed = _PROXY._prepend_think_start_to_sse_event(event) + + self.assertFalse(decided) + self.assertFalse(changed) + self.assertEqual(patched, event) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py new file mode 100644 index 00000000..0aacc219 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py @@ -0,0 +1,1307 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import importlib.util +import os +import sys +import tempfile +import types +import unittest +from pathlib import Path +from unittest.mock import patch + + +_REPO_ROOT = Path(__file__).resolve().parents[5] +_REPO_SRC = _REPO_ROOT / "src" +if str(_REPO_SRC) not in sys.path: + sys.path.insert(0, str(_REPO_SRC)) + +_COMPILE_PATH = ( + _REPO_ROOT + / "contrib" + / "models" + / "Qwen3.6-27B" + / "test" + / "integration" + / "qwen36_27b_compile_fp8.py" +) +_SPEC = importlib.util.spec_from_file_location( + "qwen36_compile_fp8_under_test", + _COMPILE_PATH, +) +_COMPILE = importlib.util.module_from_spec(_SPEC) +sys.modules[_SPEC.name] = _COMPILE +_SPEC.loader.exec_module(_COMPILE) + + +class _FakeQwen35InferenceConfig: + def __init__(self, *, neuron_config, **config_dict): + self.neuron_config = neuron_config + self.config_dict = config_dict + + +class _FakeNeuronConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.output_logits = kwargs.get("output_logits", False) + self.on_device_sampling_config = kwargs.get("on_device_sampling_config") + self.disable_argmax_kernel = kwargs.get("disable_argmax_kernel", False) + self.disable_context_encoding_argmax_kernel = kwargs.get( + "disable_context_encoding_argmax_kernel", False + ) + + +class _FakeOnDeviceSamplingConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class _FakeChunkedPrefillConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +def _fake_config_module(): + module = types.ModuleType("neuronx_distributed_inference.models.config") + module.NeuronConfig = _FakeNeuronConfig + module.OnDeviceSamplingConfig = _FakeOnDeviceSamplingConfig + module.ChunkedPrefillConfig = _FakeChunkedPrefillConfig + return module + + +def _fake_qwen_module(): + module = types.ModuleType("src.modeling_qwen35") + module.Qwen35InferenceConfig = _FakeQwen35InferenceConfig + return module + + +def _args(**overrides): + defaults = dict( + model_path="/tmp/qwen36", + quantized_checkpoints_path="/tmp/qwen36-fp8", + weight_dtype="fp8_mlp_only", + seq_len=2048, + max_context_length=None, + cte_bucket=512, + cte_buckets=["256,512"], + prefix_buckets=None, + context_encoding_bucket_pairs=None, + omit_zero_prefix_pair=False, + token_generation_buckets=None, + token_generation_batches=None, + disable_token_generation_wlo=False, + weights_to_skip_layout_optimization=None, + block_size=256, + pa_num_blocks=8, + pa_headroom_blocks=0, + tp_degree=4, + logical_nc_config=2, + max_num_seqs=1, + ctx_batch_size=1, + skip_warmup=False, + async_mode=False, + enable_prefix_caching=True, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=False, + text_only_cte=True, + compact_cte_attention_mask=True, + cold_zero_conv_fast_path=False, + enable_deltanet_decode_nki=False, + deltanet_cte_backend="env", + disable_on_device_sampling=True, + disable_argmax_kernel=False, + disable_context_encoding_argmax_kernel=False, + output_logits_with_on_device_sampling=False, + kernel_q_tile_size=128, + kernel_kv_tile_size=1024, + enable_fused_qkv=False, + enable_qkv_nki_kernels=False, + enable_qkv_cte_nki_kernel_fuse_rope=False, + enable_qwen_qk_norm_rope_nki_kernel=False, + enable_qwen_output_gate_nki_kernel=False, + enable_qwen_qkv_gate_packed_kernel=False, + enable_qwen_gated_o_proj_nki_kernel=False, + enable_split_qkv_tkg_nki_kernel=False, + enable_attn_block_tkg_nki_kernel=False, + enable_attn_block_tkg_cascaded_attention=False, + enable_attn_block_tkg_cache_update=False, + enable_out_proj_nki_kernel=False, + enable_mlp_cte_nki_kernel=False, + enable_mlp_tkg_nki_kernel=False, + enable_quantized_mlp_kernel=False, + enable_k_cache_transposed=False, + enable_kv_cache_quant=False, + prefix_cte_attention_chunk_size=None, + prefix_cte_attention_backend="attention_cte", + prefix_cte_attention_segment_size=None, + disable_static_hybrid_cache=False, + gdn_checkpoint_interval=256, + max_gdn_checkpoint_slots=8, + gdn_recurrent_cache_dtype="float32", + gdn_conv_cache_dtype="bfloat16", + hybrid_cache_mode="all", + hybrid_apc_require_vllm_metadata=False, + hybrid_apc_enable_backed_prefix_reads=False, + hybrid_apc_commit_during_token_generation=False, + quantize_edge_mlp_layers=False, + quantize_lm_head=False, + fp8_quantize_linear_attn_gates=False, + fp8_exclude_groups=[], + ) + defaults.update(overrides) + return argparse.Namespace(**defaults) + + +class TestQwen36CompileFp8Config(unittest.TestCase): + def test_fp8_environment_defaults_are_set_without_overriding_existing(self): + with patch.dict(os.environ, {}, clear=True): + _COMPILE._ensure_fp8_environment() + self.assertEqual(os.environ["XLA_HANDLE_SPECIAL_SCALAR"], "1") + self.assertEqual(os.environ["UNSAFE_FP8FNCAST"], "1") + + with patch.dict( + os.environ, + { + "XLA_HANDLE_SPECIAL_SCALAR": "custom", + "UNSAFE_FP8FNCAST": "custom", + }, + clear=True, + ): + _COMPILE._ensure_fp8_environment() + self.assertEqual(os.environ["XLA_HANDLE_SPECIAL_SCALAR"], "custom") + self.assertEqual(os.environ["UNSAFE_FP8FNCAST"], "custom") + + def test_host_sampling_compile_keeps_output_logits_enabled(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(disable_on_device_sampling=True), + ) + + self.assertTrue(config.neuron_config.output_logits) + self.assertIsNone(config.neuron_config.on_device_sampling_config) + self.assertEqual(config.neuron_config.pa_num_blocks, 8) + self.assertTrue(config.neuron_config.quantized) + + def test_full_fp8_keeps_hybrid_checkpoint_bank_out_of_conversion(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, modules = _COMPILE._build_config( + _args(weight_dtype="fp8_full", quantize_lm_head=True), + ) + + self.assertEqual(config.config_dict["gdn_recurrent_cache_dtype"], "float32") + self.assertIn("hybrid_gdn_checkpoint_cache.recurrent_slots", modules) + self.assertIn("hybrid_gdn_checkpoint_cache.conv_slots", modules) + self.assertIn( + "hybrid_gdn_checkpoint_cache.recurrent_slots", + config.neuron_config.modules_to_not_convert, + ) + + def test_compile_can_trace_batched_token_generation(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=True, + weight_dtype="bf16_control", + quantized_checkpoints_path=None, + max_num_seqs=2, + ctx_batch_size=1, + skip_warmup=True, + pa_num_blocks=16, + ), + ) + + self.assertEqual(config.neuron_config.batch_size, 2) + self.assertEqual(config.neuron_config.ctx_batch_size, 1) + self.assertEqual(config.neuron_config.tkg_batch_size, 2) + self.assertEqual(config.neuron_config.pa_num_blocks, 16) + self.assertTrue(config.neuron_config.skip_warmup) + + def test_compile_can_enable_block_tkg_attention_kernel_flags(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + enable_qkv_nki_kernels=True, + enable_attn_block_tkg_nki_kernel=True, + enable_attn_block_tkg_cascaded_attention=True, + enable_attn_block_tkg_cache_update=True, + ), + ) + + self.assertTrue(config.neuron_config.qkv_kernel_enabled) + self.assertTrue(config.neuron_config.qkv_nki_kernel_enabled) + self.assertTrue(config.neuron_config.fused_qkv) + self.assertTrue(config.neuron_config.attn_block_tkg_nki_kernel_enabled) + self.assertTrue( + config.neuron_config.attn_block_tkg_nki_kernel_cascaded_attention, + ) + self.assertTrue(config.neuron_config.attn_block_tkg_nki_kernel_cache_update) + + def test_compile_can_enable_fused_qkv_without_qkv_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_fused_qkv=True), + ) + + self.assertTrue(config.neuron_config.fused_qkv) + self.assertFalse( + getattr(config.neuron_config, "qkv_kernel_enabled", False), + ) + self.assertFalse( + getattr(config.neuron_config, "qkv_nki_kernel_enabled", False), + ) + + def test_compile_can_enable_qkv_cte_rope_fusion_flag(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={ + "num_hidden_layers": 2, + "head_dim": 256, + "rope_dim": 256, + }, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + enable_qkv_nki_kernels=True, + enable_qkv_cte_nki_kernel_fuse_rope=True, + ), + ) + + self.assertTrue(config.neuron_config.qkv_kernel_enabled) + self.assertTrue(config.neuron_config.qkv_nki_kernel_enabled) + self.assertTrue(config.neuron_config.qkv_cte_nki_kernel_fuse_rope) + + def test_compile_rejects_qkv_cte_rope_fusion_for_partial_rope(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={ + "num_hidden_layers": 2, + "head_dim": 256, + "rope_dim": 64, + }, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + with self.assertRaisesRegex(ValueError, "partial-RoPE Qwen3.6"): + _COMPILE._build_config( + _args( + enable_qkv_nki_kernels=True, + enable_qkv_cte_nki_kernel_fuse_rope=True, + ), + ) + + def test_compile_can_enable_qwen_qk_norm_rope_nki_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_qwen_qk_norm_rope_nki_kernel=True), + ) + + self.assertTrue(config.config_dict["use_qwen_qk_norm_rope_nki"]) + + def test_compile_can_enable_qwen_output_gate_nki_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_qwen_output_gate_nki_kernel=True), + ) + + self.assertTrue(config.config_dict["use_qwen_output_gate_nki"]) + + def test_compile_can_enable_qwen_qkv_gate_packed_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_qwen_qkv_gate_packed_kernel=True), + ) + + self.assertTrue(config.config_dict["use_qwen_qkv_gate_packed"]) + + def test_compile_can_enable_qwen_gated_o_proj_nki_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_qwen_gated_o_proj_nki_kernel=True), + ) + + self.assertTrue(config.config_dict["use_qwen_gated_o_proj_nki"]) + + def test_compile_can_enable_split_qkv_tkg_kernel_without_stock_qkv(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_split_qkv_tkg_nki_kernel=True), + ) + + self.assertTrue(config.neuron_config.qkv_tkg_nki_kernel_enabled) + self.assertFalse(getattr(config.neuron_config, "fused_qkv", False)) + self.assertFalse( + getattr(config.neuron_config, "qkv_kernel_enabled", False), + ) + self.assertFalse( + getattr(config.neuron_config, "qkv_nki_kernel_enabled", False), + ) + + def test_compile_can_enable_output_projection_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_out_proj_nki_kernel=True), + ) + + self.assertTrue(config.neuron_config.out_proj_kernel_enabled) + + def test_compile_can_enable_quantized_mlp_tkg_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + weight_dtype="fp8_full", + enable_mlp_tkg_nki_kernel=True, + enable_quantized_mlp_kernel=True, + ), + ) + + self.assertTrue(config.neuron_config.mlp_kernel_enabled) + self.assertTrue(config.neuron_config.mlp_tkg_nki_kernel_enabled) + self.assertTrue(config.neuron_config.quantized_mlp_kernel_enabled) + + def test_compile_can_enable_quantized_mlp_cte_kernel_without_tkg_flag(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + weight_dtype="fp8_full", + enable_mlp_cte_nki_kernel=True, + enable_quantized_mlp_kernel=True, + ), + ) + + self.assertTrue(config.neuron_config.mlp_kernel_enabled) + self.assertFalse( + getattr(config.neuron_config, "mlp_tkg_nki_kernel_enabled", False) + ) + self.assertTrue(config.neuron_config.quantized_mlp_kernel_enabled) + + def test_decode_memory_flags_are_forwarded(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + enable_k_cache_transposed=True, + enable_kv_cache_quant=True, + hybrid_apc_commit_during_token_generation=True, + ), + ) + + self.assertTrue(config.neuron_config.k_cache_transposed) + self.assertTrue(config.neuron_config.kv_cache_quant) + self.assertEqual(config.neuron_config.kv_quant_config, {"direct_cast": True}) + self.assertTrue( + config.config_dict["hybrid_apc_commit_during_token_generation"], + ) + + def test_on_device_sampling_compile_uses_sampler_config(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(disable_on_device_sampling=False), + ) + + self.assertIsNotNone(config.neuron_config.on_device_sampling_config) + self.assertFalse(config.neuron_config.output_logits) + self.assertTrue(config.neuron_config.vocab_parallel) + self.assertEqual(config.neuron_config.pa_num_blocks, 8) + + def test_on_device_sampling_can_disable_custom_argmax_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=False, + disable_argmax_kernel=True, + ), + ) + + self.assertIsNotNone(config.neuron_config.on_device_sampling_config) + self.assertTrue(config.neuron_config.vocab_parallel) + self.assertTrue(config.neuron_config.disable_argmax_kernel) + + def test_on_device_sampling_can_disable_context_encoding_argmax_kernel(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=False, + disable_context_encoding_argmax_kernel=True, + ), + ) + + self.assertIsNotNone(config.neuron_config.on_device_sampling_config) + self.assertTrue(config.neuron_config.vocab_parallel) + self.assertFalse(config.neuron_config.disable_argmax_kernel) + self.assertTrue(config.neuron_config.disable_context_encoding_argmax_kernel) + + def test_on_device_sampling_can_also_return_logits_for_debug(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=False, + output_logits_with_on_device_sampling=True, + ), + ) + + self.assertIsNotNone(config.neuron_config.on_device_sampling_config) + self.assertTrue(config.neuron_config.output_logits) + self.assertTrue(config.neuron_config.vocab_parallel) + + def test_bf16_control_compile_disables_quantization_and_keeps_host_logits(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=True, + weight_dtype="bf16_control", + quantized_checkpoints_path=None, + ), + ) + + self.assertTrue(config.neuron_config.output_logits) + self.assertFalse(config.neuron_config.quantized) + self.assertIsNone(config.neuron_config.on_device_sampling_config) + self.assertGreater(len(modules), 0) + + def test_fp8_mlp_only_keeps_edge_mlp_layers_in_bf16_by_default(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 4}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + _config, modules = _COMPILE._build_config( + _args(disable_on_device_sampling=True), + ) + + self.assertIn("layers.0.mlp", modules) + self.assertIn("layers.3.mlp", modules) + self.assertNotIn("layers.1.mlp", modules) + + def test_fp8_full_quantizes_attention_and_edge_mlp_by_default(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 4}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, modules = _COMPILE._build_config( + _args(disable_on_device_sampling=True, weight_dtype="fp8_full"), + ) + + self.assertTrue(config.neuron_config.quantized) + self.assertIn( + r".*\.scale$", + config.neuron_config.weights_to_skip_layout_optimization, + ) + self.assertIn( + r".*\.weight_scale$", + config.neuron_config.weights_to_skip_layout_optimization, + ) + self.assertIn( + r".*linear_attn\.conv1d_weight\.weight$", + config.neuron_config.weights_to_skip_layout_optimization, + ) + self.assertNotIn("layers.0.mlp", modules) + self.assertNotIn("layers.3.mlp", modules) + self.assertNotIn("layers.0.self_attn", modules) + self.assertNotIn("layers.0.linear_attn", modules) + self.assertIn("layers.0.linear_attn.conv1d_weight", modules) + self.assertIn("layers.0.linear_attn.A_log_weight", modules) + self.assertIn("layers.0.linear_attn.dt_bias_weight", modules) + self.assertIn("layers.0.linear_attn.in_proj_a", modules) + self.assertIn("layers.0.linear_attn.in_proj_b", modules) + self.assertIn("layers.0.linear_attn.in_proj_ba", modules) + self.assertIn("lm_head", modules) + + def test_fp8_full_can_quantize_lm_head_when_requested(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + _config, modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=True, + weight_dtype="fp8_full", + quantize_lm_head=True, + ), + ) + + self.assertNotIn("lm_head", modules) + self.assertNotIn("model.lm_head", modules) + + def test_fp8_full_keeps_linear_attention_gate_projections_bf16(self): + self.assertFalse( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_a.weight", + quantize_lm_head=True, + ), + ) + self.assertFalse( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_b.weight", + quantize_lm_head=True, + ), + ) + self.assertFalse( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_ba.weight", + quantize_lm_head=True, + ), + ) + self.assertTrue( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_qkv.weight", + quantize_lm_head=True, + ), + ) + + def test_fp8_full_can_use_legacy_fp8_linear_attention_gate_policy(self): + self.assertTrue( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_a.weight", + quantize_lm_head=True, + quantize_linear_attn_gates=True, + ), + ) + self.assertTrue( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_b.weight", + quantize_lm_head=True, + quantize_linear_attn_gates=True, + ), + ) + self.assertFalse( + _COMPILE._is_full_fp8_weight( + "layers.0.linear_attn.in_proj_ba.weight", + quantize_lm_head=True, + quantize_linear_attn_gates=True, + ), + ) + + def test_legacy_fp8_linear_attention_gate_policy_matches_old_config(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + _config, modules = _COMPILE._build_config( + _args( + weight_dtype="fp8_full", + quantize_lm_head=True, + fp8_quantize_linear_attn_gates=True, + ), + ) + + self.assertNotIn("layers.0.linear_attn.in_proj_a", modules) + self.assertNotIn("layers.0.linear_attn.in_proj_b", modules) + self.assertNotIn("layers.0.linear_attn.in_proj_ba", modules) + + def test_fp8_full_can_exclude_remaining_linear_attention_matmuls(self): + for weight_name in ( + "layers.0.linear_attn.in_proj_qkv.weight", + "layers.0.linear_attn.in_proj_z.weight", + "layers.0.linear_attn.out_proj.weight", + ): + self.assertFalse( + _COMPILE._is_full_fp8_weight( + weight_name, + quantize_lm_head=False, + fp8_exclude_groups={"linear_attn"}, + ), + ) + self.assertTrue( + _COMPILE._is_full_fp8_weight( + "layers.0.mlp.up_proj.weight", + quantize_lm_head=False, + fp8_exclude_groups={"linear_attn"}, + ), + ) + + def test_fp8_full_exclude_groups_are_reflected_in_config(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, modules = _COMPILE._build_config( + _args( + weight_dtype="fp8_full", + fp8_exclude_groups=["linear_attn", "mlp"], + ), + ) + + self.assertIn("layers.0.linear_attn", modules) + self.assertIn("layers.0.mlp", modules) + self.assertIn("model.layers.1.linear_attn", modules) + self.assertIn("layers.0.linear_attn", config.neuron_config.modules_to_not_convert) + + def test_user_wlo_skip_patterns_are_appended_and_deduplicated(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=True, + weight_dtype="fp8_full", + weights_to_skip_layout_optimization=[ + r".*\.scale$", + r".*custom_skip.*", + ], + ), + ) + + self.assertEqual( + config.neuron_config.weights_to_skip_layout_optimization, + [ + r".*\.scale$", + r".*\.weight_scale$", + r".*linear_attn\.conv1d_weight\.weight$", + r".*custom_skip.*", + ], + ) + + def test_bf16_control_does_not_add_fp8_wlo_skips_by_default(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + disable_on_device_sampling=True, + weight_dtype="bf16_control", + quantized_checkpoints_path=None, + ), + ) + + self.assertFalse( + hasattr(config.neuron_config, "weights_to_skip_layout_optimization"), + ) + + def test_compile_can_disable_token_generation_wlo(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(disable_token_generation_wlo=True), + ) + + self.assertTrue(config.config_dict["disable_token_generation_wlo"]) + + def test_compile_can_disable_token_generation_wlo_from_env(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + os.environ, + {"QWEN36_DISABLE_TOKEN_GENERATION_WLO": "1"}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config(_args()) + + self.assertTrue(config.config_dict["disable_token_generation_wlo"]) + + def test_compile_forwards_cold_cte_fast_path_flags(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + text_only_cte=False, + compact_cte_attention_mask=False, + cold_zero_conv_fast_path=True, + ), + ) + + self.assertFalse(config.config_dict["use_text_only_cte_inputs"]) + self.assertFalse(config.config_dict["use_compact_cte_attention_mask"]) + self.assertTrue(config.config_dict["use_cold_zero_conv_fast_path"]) + + def test_long_prefix_buckets_must_fit_max_context_length(self): + with self.assertRaisesRegex(ValueError, "Largest prefix bucket"): + _COMPILE._validate_prefix_buckets_fit_context( + _args(enable_prefix_caching=True), + max_context_length=512, + prefix_buckets=[512, 131072], + ) + + def test_sparse_context_encoding_bucket_pairs_are_forwarded(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + cte_buckets=["512,1536"], + prefix_buckets=["256,512,65536"], + max_context_length=65536, + seq_len=65536, + pa_num_blocks=256, + context_encoding_bucket_pairs=[ + "512:256,512:512", + "1536:256", + "1536:65536", + ], + ), + ) + + self.assertEqual( + config.neuron_config.context_encoding_bucket_pairs, + [ + [512, 0], + [512, 256], + [512, 512], + [1536, 0], + [1536, 256], + [1536, 65536], + ], + ) + + def test_prefix_cte_attention_chunk_size_is_forwarded(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(prefix_cte_attention_chunk_size=32768), + ) + + self.assertEqual(config.neuron_config.prefix_cte_attention_chunk_size, 32768) + + def test_segmented_prefix_cte_attention_config_is_forwarded(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args( + prefix_cte_attention_backend="segmented_cte", + prefix_cte_attention_segment_size=32768, + ), + ) + + self.assertEqual( + config.neuron_config.prefix_cte_attention_backend, + "segmented_cte", + ) + self.assertEqual( + config.neuron_config.prefix_cte_attention_segment_size, + 32768, + ) + + def test_sparse_context_encoding_bucket_pairs_can_omit_zero_pair(self): + pairs = _COMPILE._context_encoding_bucket_pairs( + _args( + cte_buckets=["3072"], + prefix_buckets=["131072"], + context_encoding_bucket_pairs=["3072:131072"], + omit_zero_prefix_pair=True, + ), + cte_buckets=[3072], + prefix_buckets=[131072], + ) + + self.assertEqual(pairs, [[3072, 131072]]) + + def test_sparse_context_encoding_bucket_pairs_validate_config_buckets(self): + with self.assertRaisesRegex(ValueError, "active bucket"): + _COMPILE._context_encoding_bucket_pairs( + _args(context_encoding_bucket_pairs=["768:256"]), + cte_buckets=[512], + prefix_buckets=[256], + ) + + with self.assertRaisesRegex(ValueError, "prefix bucket"): + _COMPILE._context_encoding_bucket_pairs( + _args(context_encoding_bucket_pairs=["512:1024"]), + cte_buckets=[512], + prefix_buckets=[256], + ) + + def test_pa_num_blocks_rejects_user_blocks_below_sequence_requirement(self): + with self.assertRaisesRegex(ValueError, "need at least 8"): + _COMPILE._pa_num_blocks(_args(pa_num_blocks=7)) + + def test_pa_headroom_blocks_extend_default_pa_capacity(self): + args = _args( + seq_len=4096, + block_size=32, + max_num_seqs=2, + pa_num_blocks=None, + pa_headroom_blocks=32, + ) + + self.assertEqual(_COMPILE._pa_min_blocks(args), 256) + self.assertEqual(_COMPILE._pa_requested_blocks(args), 288) + self.assertEqual(_COMPILE._pa_num_blocks(args), 288) + + def test_base_compile_work_dir_defaults_next_to_artifacts(self): + with self.subTest("default"), patch.dict(os.environ, {}, clear=True): + work_dir = _COMPILE._configure_base_compile_work_dir( + Path("/tmp/qwen_artifacts/model_a"), + None, + ) + + self.assertEqual( + work_dir, + Path("/tmp/qwen_artifacts/_nxd_model_workdir").resolve(), + ) + self.assertEqual(os.environ["BASE_COMPILE_WORK_DIR"], str(work_dir)) + + with self.subTest("existing env"), patch.dict( + os.environ, + {"BASE_COMPILE_WORK_DIR": "/tmp/existing_nxd_workdir"}, + clear=True, + ): + work_dir = _COMPILE._configure_base_compile_work_dir( + Path("/tmp/qwen_artifacts/model_a"), + None, + ) + + self.assertEqual(work_dir, Path("/tmp/existing_nxd_workdir").resolve()) + self.assertEqual(os.environ["BASE_COMPILE_WORK_DIR"], str(work_dir)) + + with self.subTest("explicit override"), patch.dict( + os.environ, + {"BASE_COMPILE_WORK_DIR": "/tmp/existing_nxd_workdir"}, + clear=True, + ): + work_dir = _COMPILE._configure_base_compile_work_dir( + Path("/tmp/qwen_artifacts/model_a"), + "/tmp/explicit_nxd_workdir", + ) + + self.assertEqual(work_dir, Path("/tmp/explicit_nxd_workdir").resolve()) + self.assertEqual(os.environ["BASE_COMPILE_WORK_DIR"], str(work_dir)) + + def test_deltanet_cte_backend_preserves_environment_by_default(self): + with patch.dict( + os.environ, + { + "USE_NKI_FUSED": "custom", + "USE_NKI_CHUNKED": "custom", + }, + clear=True, + ): + _COMPILE._configure_deltanet_cte_backend("env") + + self.assertEqual(os.environ["USE_NKI_FUSED"], "custom") + self.assertEqual(os.environ["USE_NKI_CHUNKED"], "custom") + + def test_deltanet_cte_backend_can_force_nki_chunked(self): + with patch.dict( + os.environ, + { + "USE_NKI_FUSED": "1", + "USE_PYTORCH_CHUNK": "1", + "DELTANET_SEQUENTIAL": "1", + }, + clear=True, + ): + _COMPILE._configure_deltanet_cte_backend("nki_chunked") + + self.assertEqual(os.environ["USE_NKI_FUSED"], "0") + self.assertEqual(os.environ["USE_NKI_CHUNKED"], "1") + self.assertNotIn("USE_PYTORCH_CHUNK", os.environ) + self.assertNotIn("DELTANET_SEQUENTIAL", os.environ) + + def test_deltanet_cte_backend_can_force_pytorch_chunk(self): + with patch.dict(os.environ, {"USE_NKI_CHUNKED": "1"}, clear=True): + _COMPILE._configure_deltanet_cte_backend("pytorch_chunk") + + self.assertEqual(os.environ["USE_NKI_FUSED"], "0") + self.assertEqual(os.environ["USE_PYTORCH_CHUNK"], "1") + self.assertNotIn("USE_NKI_CHUNKED", os.environ) + + def test_backed_prefix_read_compile_flag_is_forwarded(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(hybrid_apc_enable_backed_prefix_reads=True), + ) + + self.assertTrue(config.config_dict["hybrid_apc_enable_backed_prefix_reads"]) + + def test_vllm_chunked_prefill_uses_qwen_flags_not_nxdi_chunked_prefill(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_vllm_chunked_prefill=True), + ) + + self.assertTrue(config.neuron_config.is_block_kv_layout) + self.assertIsNone(getattr(config.neuron_config, "chunked_prefill_config", None)) + self.assertTrue(config.config_dict["use_qwen_hybrid_chunked_prefill"]) + self.assertTrue(config.config_dict["use_qwen_hybrid_chunked_prefill_nki"]) + + def test_deltanet_decode_nki_compile_flag_is_forwarded(self): + with patch.object( + _COMPILE, + "_load_text_config", + return_value={"num_hidden_layers": 2}, + ), patch.dict( + sys.modules, + { + "neuronx_distributed_inference.models.config": _fake_config_module(), + "src.modeling_qwen35": _fake_qwen_module(), + }, + ): + config, _modules = _COMPILE._build_config( + _args(enable_deltanet_decode_nki=True), + ) + + self.assertTrue(config.config_dict["use_qwen_deltanet_decode_nki"]) + + def test_checkpoint_bank_weights_are_added_for_reload(self): + from safetensors import safe_open + from safetensors.torch import save_file + + with tempfile.TemporaryDirectory() as tmpdir: + compiled_path = Path(tmpdir) + weights_dir = compiled_path / "weights" + weights_dir.mkdir() + shard_path = weights_dir / "tp0_sharded_checkpoint.safetensors" + save_file( + {"existing.weight": _COMPILE.torch.ones(1)}, + shard_path, + metadata={"format": "pt"}, + ) + + inf_config = types.SimpleNamespace( + layer_types=["linear_attention", "full_attention", "linear_attention"], + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + max_gdn_checkpoint_slots=64, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + neuron_config=types.SimpleNamespace( + tp_degree=4, + torch_dtype=_COMPILE.torch.bfloat16, + ), + ) + + _COMPILE._ensure_hybrid_checkpoint_weights(compiled_path, inf_config) + + with safe_open(shard_path, framework="pt", device="cpu") as handle: + keys = set(handle.keys()) + recurrent = handle.get_tensor( + "hybrid_gdn_checkpoint_cache.recurrent_slots.0", + ) + conv = handle.get_tensor("hybrid_gdn_checkpoint_cache.conv_slots.0") + + self.assertIn("existing.weight", keys) + self.assertIn("hybrid_gdn_checkpoint_cache.recurrent_slots.1", keys) + self.assertIn("hybrid_gdn_checkpoint_cache.conv_slots.1", keys) + self.assertEqual(recurrent.dtype, _COMPILE.torch.float32) + self.assertEqual(tuple(recurrent.shape), (64, 12, 128, 128)) + self.assertEqual(conv.dtype, _COMPILE.torch.bfloat16) + self.assertEqual(tuple(conv.shape), (64, 2560, 3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_model_aliases.py b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_model_aliases.py new file mode 100644 index 00000000..949d1701 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_qwen36_model_aliases.py @@ -0,0 +1,1941 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +import inspect +import math +import os +import sys +import types +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import torch +from torch import nn + + +_REPO_ROOT = Path(__file__).resolve().parents[5] +_QWEN_MODEL_PATH = ( + _REPO_ROOT / "contrib" / "models" / "Qwen3.6-27B" / "src" / "modeling_qwen35.py" +) + + +def _package(name): + module = types.ModuleType(name) + module.__path__ = [] + return module + + +def _module(name, **attrs): + module = types.ModuleType(name) + for key, value in attrs.items(): + setattr(module, key, value) + return module + + +def _jit(*_args, **_kwargs): + def decorator(fn): + return fn + + return decorator + + +class _FakeDecoderModelInstance: + def get(self, bucket_rank, **kwargs): + del bucket_rank, kwargs + num_outputs = 1 if not self.neuron_config.output_logits else 2 + kvs = self.module.kv_mgr.past_key_values + aliases = {param: num_outputs + i for i, param in enumerate(kvs)} + self.input_output_aliases = aliases + return self.module, aliases + + +class _FakeModelWrapper: + def input_generator(self): + return self._base_inputs + + def pad_inputs(self, *args, pad_type="first_fit"): + del pad_type + return args + + +def _fake_modules(): + return { + "nki": _module("nki", jit=_jit), + "neuronxcc": _package("neuronxcc"), + "neuronxcc.nki": _package("neuronxcc.nki"), + "neuronxcc.nki._private_kernels": _package("neuronxcc.nki._private_kernels"), + "neuronxcc.nki._private_kernels.attention": _module( + "neuronxcc.nki._private_kernels.attention", + attention_isa_kernel=lambda *args, **kwargs: None, + ), + "neuronx_distributed": _package("neuronx_distributed"), + "neuronx_distributed.parallel_layers": _package( + "neuronx_distributed.parallel_layers" + ), + "neuronx_distributed.parallel_layers.parallel_state": _module( + "neuronx_distributed.parallel_layers.parallel_state", + get_tensor_model_parallel_rank=lambda: 0, + ), + "neuronx_distributed.parallel_layers.layers": _module( + "neuronx_distributed.parallel_layers.layers", + ColumnParallelLinear=nn.Linear, + ParallelEmbedding=nn.Embedding, + RowParallelLinear=nn.Linear, + ), + "neuronx_distributed.parallel_layers.mappings": _module( + "neuronx_distributed.parallel_layers.mappings", + _gather_along_dim=lambda tensor, *_args, **_kwargs: tensor, + ), + "neuronx_distributed.utils": _module( + "neuronx_distributed.utils", + cpu_mode=lambda: True, + ), + "transformers": _package("transformers"), + "transformers.models": _package("transformers.models"), + "transformers.models.qwen3_moe": _package("transformers.models.qwen3_moe"), + "transformers.models.qwen3_moe.modeling_qwen3_moe": _module( + "transformers.models.qwen3_moe.modeling_qwen3_moe", + Qwen3MoeRMSNorm=nn.LayerNorm, + ), + "src": _package("src"), + "src.nki_kernels": _package("src.nki_kernels"), + "src.nki_kernels.nki_deltanet": _module( + "src.nki_kernels.nki_deltanet", + deltanet_recurrent_fwd=lambda *args, **kwargs: None, + deltanet_recurrent_fwd_state=lambda *args, **kwargs: None, + deltanet_recurrent_step_batched=lambda *args, **kwargs: None, + ), + "src.nki_kernels.nki_deltanet_chunked": _module( + "src.nki_kernels.nki_deltanet_chunked", + deltanet_chunk_step=lambda *args, **kwargs: None, + ), + "src.nki_kernels.nki_deltanet_fused": _module( + "src.nki_kernels.nki_deltanet_fused", + deltanet_autocp_affine_sequence=lambda *args, **kwargs: None, + deltanet_autocp_apply_output=lambda *args, **kwargs: None, + deltanet_autocp_prefix_apply_output=lambda *args, **kwargs: None, + deltanet_autocp_state_summary_sequence=lambda *args, **kwargs: None, + deltanet_autocp_state_prefix=lambda *args, **kwargs: None, + deltanet_fused_chunked_fwd=lambda *args, **kwargs: None, + deltanet_fused_chunked_fwd_multihead=lambda *args, **kwargs: None, + _make_lower_mask=lambda *args, **kwargs: None, + _make_lower_mask_diag=lambda *args, **kwargs: None, + _make_identity=lambda *args, **kwargs: None, + ), + "src.nki_kernels.nki_deltanet_fused_legacy": _module( + "src.nki_kernels.nki_deltanet_fused_legacy", + deltanet_fused_chunked_fwd=lambda *args, **kwargs: None, + ), + "src.hybrid_apc": _module( + "src.hybrid_apc", + HybridAPCMetadataStore=object, + HybridAPCSchedulerBridge=object, + HybridAPCSlotAllocator=object, + ), + "neuronx_distributed_inference": _package("neuronx_distributed_inference"), + "neuronx_distributed_inference.models": _package( + "neuronx_distributed_inference.models" + ), + "neuronx_distributed_inference.models.config": _module( + "neuronx_distributed_inference.models.config", + InferenceConfig=object, + NeuronConfig=object, + ), + "neuronx_distributed_inference.models.llama": _package( + "neuronx_distributed_inference.models.llama" + ), + "neuronx_distributed_inference.models.llama.modeling_llama": _module( + "neuronx_distributed_inference.models.llama.modeling_llama", + NeuronLlamaMLP=nn.Module, + ), + "neuronx_distributed_inference.models.model_base": _module( + "neuronx_distributed_inference.models.model_base", + NeuronBaseForCausalLM=object, + NeuronBaseModel=nn.Module, + mask_padded_logits=lambda logits, *_args, **_kwargs: logits, + ), + "neuronx_distributed_inference.models.model_wrapper": _module( + "neuronx_distributed_inference.models.model_wrapper", + CONTEXT_ENCODING_MODEL_TAG="context_encoding_model", + TOKEN_GENERATION_MODEL_TAG="token_generation_model", + DecoderModelInstance=_FakeDecoderModelInstance, + ModelWrapper=_FakeModelWrapper, + ), + "neuronx_distributed_inference.utils": _package( + "neuronx_distributed_inference.utils" + ), + "neuronx_distributed_inference.utils.distributed": _module( + "neuronx_distributed_inference.utils.distributed", + get_tp_group=lambda *_args, **_kwargs: None, + ), + "neuronx_distributed_inference.modules": _package( + "neuronx_distributed_inference.modules" + ), + "neuronx_distributed_inference.modules.async_execution": _module( + "neuronx_distributed_inference.modules.async_execution", + cancel_hybrid_apc_request=lambda *args, **kwargs: None, + finish_hybrid_apc_request=lambda *args, **kwargs: None, + prepare_hybrid_apc_model_inputs=lambda *args, **kwargs: (), + prepare_hybrid_apc_request_for_execution=lambda *args, **kwargs: None, + ), + "neuronx_distributed_inference.modules.custom_calls": _module( + "neuronx_distributed_inference.modules.custom_calls", + CustomRMSNorm=nn.LayerNorm, + ), + "neuronx_distributed_inference.modules.attention": _package( + "neuronx_distributed_inference.modules.attention" + ), + "neuronx_distributed_inference.modules.attention.attention_base": _module( + "neuronx_distributed_inference.modules.attention.attention_base", + NeuronAttentionBase=nn.Module, + ), + "neuronx_distributed_inference.modules.attention.utils": _module( + "neuronx_distributed_inference.modules.attention.utils", + RotaryEmbedding=object, + move_heads_front=lambda tensor, *_args, **_kwargs: tensor, + transpose_parallel_linear_layer=lambda weight: weight, + ), + "neuronx_distributed_inference.modules.kvcache": _package( + "neuronx_distributed_inference.modules.kvcache" + ), + "neuronx_distributed_inference.modules.kvcache.block_kv_cache_manager": _module( + "neuronx_distributed_inference.modules.kvcache.block_kv_cache_manager", + BlockKVCacheManager=object, + ), + "neuronx_distributed_inference.modules.kvcache.kv_cache_manager": _module( + "neuronx_distributed_inference.modules.kvcache.kv_cache_manager", + KVCacheManager=object, + ), + "neuronx_distributed_inference.models.layer_boundary_marker": _module( + "neuronx_distributed_inference.models.layer_boundary_marker", + ModuleMarkerEndWrapper=object, + ModuleMarkerStartWrapper=object, + ), + } + + +def _load_qwen_module(): + spec = importlib.util.spec_from_file_location( + "qwen36_model_aliases_under_test", + _QWEN_MODEL_PATH, + ) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + with patch.dict(sys.modules, _fake_modules()): + spec.loader.exec_module(module) + return module + + +def _make_instance(qwen_module, *, output_logits, on_device_sampling_config=None): + kv0, kv1, state, checkpoint = (torch.nn.Parameter(torch.zeros(1)) for _ in range(4)) + module = SimpleNamespace( + kv_mgr=SimpleNamespace(past_key_values=[kv0, kv1]), + config=SimpleNamespace(use_hybrid_cache_manager=False), + _deltanet_state_params=[state], + _hybrid_gdn_checkpoint_params=[checkpoint], + ) + instance = qwen_module.Qwen35DecoderModelInstance.__new__( + qwen_module.Qwen35DecoderModelInstance + ) + instance.neuron_config = SimpleNamespace( + output_logits=output_logits, + on_device_sampling_config=on_device_sampling_config, + ) + instance.module = module + return instance, (kv0, kv1, state, checkpoint) + + +def _make_wrapper(qwen_module, *, tag, use_hybrid_apc_manager=True): + wrapper = qwen_module.Qwen35ModelWrapper.__new__(qwen_module.Qwen35ModelWrapper) + wrapper.tag = tag + wrapper.config = SimpleNamespace( + hidden_size=8, + neuron_config=SimpleNamespace(torch_dtype=torch.bfloat16), + use_text_only_cte_inputs=True, + use_hybrid_apc_manager=use_hybrid_apc_manager, + ) + wrapper._base_inputs = [ + ( + torch.ones((1, 1), dtype=torch.int32), # input_ids + torch.ones((1, 1), dtype=torch.int32), # attention_mask + torch.ones((1, 1), dtype=torch.int32), # position_ids + torch.zeros((1,), dtype=torch.int32), # seq_ids + torch.ones((1, 3), dtype=torch.float32), # sampling_params + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), # adapter_ids + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.zeros((1, 1), dtype=torch.int32), # slot_mapping + torch.zeros((1, 1), dtype=torch.int32), # block_table + torch.ones((1, 1), dtype=torch.int32), # num_queries + torch.zeros((1, 1), dtype=torch.int32), # computed_context_lens + ) + ] + return wrapper + + +class _IdentityMarker: + def __call__(self, tensor): + return tensor + + +class _RecordingNorm(nn.Module): + def __init__(self): + super().__init__() + self.calls = 0 + + def forward(self, tensor): + self.calls += 1 + return tensor + 1 + + +class _RecordingMlp: + def __init__(self): + self.calls = [] + + def __call__(self, tensor, rmsnorm=None): + self.calls.append((tensor.clone(), rmsnorm)) + return tensor + 2, None + + +class _FakeDeltaNetAttention: + def __call__(self, hidden_states, **_kwargs): + return torch.zeros_like(hidden_states), ("k", "v"), None, None + + +def _make_decoder_layer_for_mlp_test(qwen_module): + qwen_module.ModuleMarkerStartWrapper = _IdentityMarker + qwen_module.ModuleMarkerEndWrapper = _IdentityMarker + layer = qwen_module.NeuronQwen35DecoderLayer.__new__( + qwen_module.NeuronQwen35DecoderLayer + ) + nn.Module.__init__(layer) + layer.layer_type = "linear_attention" + layer.config = SimpleNamespace(use_hybrid_cache_manager=False) + layer.linear_attn = _FakeDeltaNetAttention() + layer.input_layernorm = nn.Identity() + layer.post_attention_layernorm = _RecordingNorm() + layer.mlp = _RecordingMlp() + layer.mlp_kernel_enabled = True + layer.mlp_kernel_fused_rmsnorm = True + return layer + + +def _expanded_prefix_attention_reference( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + key_valid_mask=None, +): + B, q_heads, q_len, head_dim = Q.shape + kv_heads = K_cache.shape[1] + kv_rep = q_heads // kv_heads + K_full = ( + K_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, K_cache.shape[2], head_dim) + ) + V_full = ( + V_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, V_cache.shape[2], head_dim) + ) + if cache_positions.ndim == 4: + cache_positions = cache_positions.reshape(B, -1) + if key_valid_mask is not None and key_valid_mask.ndim == 4: + key_valid_mask = key_valid_mask.reshape(B, -1) + + attn_weights = torch.matmul(Q, K_full.transpose(-1, -2)) / math.sqrt(head_dim) + causal_mask = cache_positions[:, None, None, :] <= query_positions[ + :, None, :, None + ] + if key_valid_mask is not None: + causal_mask = causal_mask & key_valid_mask[:, None, None, :] + attn_weights = attn_weights.masked_fill(~causal_mask, -65504.0) + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + return torch.matmul(attn_weights, V_full) + + +class TestQwen36ModelAliases(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.qwen_module = _load_qwen_module() + + def test_deltanet_multihead_group_defaults_to_lnc2_when_available(self): + with patch.dict( + os.environ, + {"NEURON_CC_FLAGS": "--target trn2 --lnc 2"}, + clear=True, + ): + self.assertEqual( + self.qwen_module._resolve_deltanet_multihead_group_size(4), + 2, + ) + + def test_deltanet_multihead_group_clamps_to_lnc1_by_default(self): + with patch.dict( + os.environ, + {"NEURON_CC_FLAGS": "--target trn2 --lnc 1"}, + clear=True, + ): + self.assertEqual( + self.qwen_module._resolve_deltanet_multihead_group_size(4), + 1, + ) + + def test_deltanet_multihead_group_rejects_explicit_size_above_lnc(self): + with patch.dict( + os.environ, + { + "NEURON_CC_FLAGS": "--target trn2 --lnc 1", + "QWEN36_DELTANET_MULTIHEAD_GROUP_SIZE": "2", + }, + clear=True, + ): + with self.assertRaisesRegex(ValueError, "requires NEURON_CC_FLAGS --lnc"): + self.qwen_module._resolve_deltanet_multihead_group_size(4) + + def test_deltanet_autocp_lnc_defaults_to_lnc2_for_even_chunks(self): + with patch.dict( + os.environ, + {"NEURON_CC_FLAGS": "--target trn2 --lnc 2"}, + clear=True, + ): + self.assertEqual(self.qwen_module._resolve_deltanet_autocp_lnc(128), 2) + + def test_deltanet_autocp_lnc_falls_back_to_lnc1_for_odd_chunks(self): + with patch.dict( + os.environ, + {"NEURON_CC_FLAGS": "--target trn2 --lnc 2"}, + clear=True, + ): + self.assertEqual(self.qwen_module._resolve_deltanet_autocp_lnc(3), 1) + + def test_deltanet_autocp_lnc_rejects_explicit_uneven_chunks(self): + with patch.dict( + os.environ, + { + "NEURON_CC_FLAGS": "--target trn2 --lnc 2", + "QWEN36_DELTANET_AUTOCP_LNC": "2", + }, + clear=True, + ): + with self.assertRaisesRegex(ValueError, "chunks to be divisible"): + self.qwen_module._resolve_deltanet_autocp_lnc(3) + + def test_grouped_prefix_attention_matches_expanded_gqa_reference(self): + torch.manual_seed(123) + batch_size = 2 + q_heads = 6 + kv_heads = 2 + q_len = 5 + cache_len = 12 + head_dim = 8 + Q = torch.randn(batch_size, q_heads, q_len, head_dim) + K_cache = torch.randn(batch_size, kv_heads, cache_len, head_dim) + V_cache = torch.randn(batch_size, kv_heads, cache_len, head_dim) + query_positions = ( + torch.arange(cache_len - q_len, cache_len) + .view(1, q_len) + .expand(batch_size, -1) + ) + cache_positions = torch.arange(cache_len).view(1, cache_len).expand( + batch_size, + -1, + ) + key_valid_mask = torch.ones(batch_size, cache_len, dtype=torch.bool) + key_valid_mask[1, -2:] = False + + actual = self.qwen_module._qwen35_grouped_prefix_attention( + Q, + K_cache, + V_cache, + query_positions, + cache_positions.view(batch_size, 1, 1, cache_len), + key_valid_mask.view(batch_size, 1, 1, cache_len), + ) + expected = _expanded_prefix_attention_reference( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + key_valid_mask, + ) + + torch.testing.assert_close(actual, expected, atol=1e-6, rtol=1e-5) + + def test_grouped_prefix_attention_matches_mha_reference(self): + torch.manual_seed(456) + batch_size = 1 + q_heads = 4 + q_len = 4 + cache_len = 7 + head_dim = 8 + Q = torch.randn(batch_size, q_heads, q_len, head_dim) + K_cache = torch.randn(batch_size, q_heads, cache_len, head_dim) + V_cache = torch.randn(batch_size, q_heads, cache_len, head_dim) + query_positions = torch.arange(cache_len - q_len, cache_len).view(1, q_len) + cache_positions = torch.arange(cache_len).view(1, cache_len) + + actual = self.qwen_module._qwen35_grouped_prefix_attention( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + ) + expected = _expanded_prefix_attention_reference( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + ) + + torch.testing.assert_close(actual, expected, atol=1e-6, rtol=1e-5) + + def test_expanded_prefix_attention_matches_reference(self): + torch.manual_seed(789) + batch_size = 1 + q_heads = 8 + kv_heads = 2 + q_len = 6 + cache_len = 11 + head_dim = 8 + Q = torch.randn(batch_size, q_heads, q_len, head_dim) + K_cache = torch.randn(batch_size, kv_heads, cache_len, head_dim) + V_cache = torch.randn(batch_size, kv_heads, cache_len, head_dim) + query_positions = torch.arange(cache_len - q_len, cache_len).view(1, q_len) + cache_positions = torch.arange(cache_len).view(1, 1, 1, cache_len) + key_valid_mask = torch.ones(batch_size, 1, 1, cache_len, dtype=torch.bool) + key_valid_mask[:, :, :, -1] = False + + actual = self.qwen_module._qwen35_expanded_prefix_attention( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + key_valid_mask, + ) + expected = _expanded_prefix_attention_reference( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + key_valid_mask, + ) + + torch.testing.assert_close(actual, expected, atol=1e-6, rtol=1e-5) + + def test_prefix_attention_impl_env_selects_legacy_expanded(self): + with patch.dict( + os.environ, + {"QWEN36_PREFIX_ATTENTION_IMPL": "legacy_expanded"}, + clear=True, + ): + self.assertEqual( + self.qwen_module._qwen36_prefix_attention_impl(), + "expanded", + ) + + def test_prefix_attention_impl_rejects_unknown_value(self): + with patch.dict( + os.environ, + {"QWEN36_PREFIX_ATTENTION_IMPL": "bogus"}, + clear=True, + ): + with self.assertRaisesRegex(ValueError, "QWEN36_PREFIX_ATTENTION_IMPL"): + self.qwen_module._qwen36_prefix_attention_impl() + + def test_grouped_prefix_attention_rejects_invalid_gqa_shape(self): + Q = torch.zeros(1, 5, 2, 4) + K_cache = torch.zeros(1, 2, 3, 4) + V_cache = torch.zeros(1, 2, 3, 4) + query_positions = torch.arange(2).view(1, 2) + cache_positions = torch.arange(3).view(1, 3) + + with self.assertRaisesRegex(ValueError, "q_heads to be divisible"): + self.qwen_module._qwen35_grouped_prefix_attention( + Q, + K_cache, + V_cache, + query_positions, + cache_positions, + ) + + def test_host_logits_aliases_after_single_trace_output(self): + instance, (kv0, kv1, state, checkpoint) = _make_instance( + self.qwen_module, + output_logits=True, + on_device_sampling_config=None, + ) + + _module, aliases = instance.get(bucket_rank=0) + + self.assertEqual(aliases[kv0], 1) + self.assertEqual(aliases[kv1], 2) + self.assertEqual(aliases[state], 3) + self.assertEqual(aliases[checkpoint], 4) + + def test_on_device_logits_aliases_after_tokens_and_logits(self): + instance, (kv0, kv1, state, checkpoint) = _make_instance( + self.qwen_module, + output_logits=True, + on_device_sampling_config=object(), + ) + + _module, aliases = instance.get(bucket_rank=0) + + self.assertEqual(aliases[kv0], 2) + self.assertEqual(aliases[kv1], 3) + self.assertEqual(aliases[state], 4) + self.assertEqual(aliases[checkpoint], 5) + + def test_hybrid_checkpoint_aliases_skip_tkg_without_commit(self): + instance, (kv0, kv1, state, checkpoint) = _make_instance( + self.qwen_module, + output_logits=True, + on_device_sampling_config=object(), + ) + instance.module.config.use_hybrid_apc_manager = True + instance.module.config.hybrid_apc_commit_during_token_generation = False + instance.module.n_active_tokens = 1 + + _module, aliases = instance.get(bucket_rank=0) + + self.assertEqual(aliases[kv0], 2) + self.assertEqual(aliases[kv1], 3) + self.assertEqual(aliases[state], 4) + self.assertNotIn(checkpoint, aliases) + + def test_hybrid_checkpoint_aliases_include_tkg_with_commit(self): + instance, (kv0, kv1, state, checkpoint) = _make_instance( + self.qwen_module, + output_logits=True, + on_device_sampling_config=object(), + ) + instance.module.config.use_hybrid_apc_manager = True + instance.module.config.hybrid_apc_commit_during_token_generation = True + instance.module.n_active_tokens = 1 + + _module, aliases = instance.get(bucket_rank=0) + + self.assertEqual(aliases[kv0], 2) + self.assertEqual(aliases[kv1], 3) + self.assertEqual(aliases[state], 4) + self.assertEqual(aliases[checkpoint], 5) + + def test_alias_output_count_guard_rejects_shifted_deltanet_states(self): + module = SimpleNamespace( + kv_mgr=SimpleNamespace(past_key_values=[object(), object()]), + config=SimpleNamespace( + use_hybrid_cache_manager=False, + use_hybrid_apc_manager=True, + hybrid_apc_commit_during_token_generation=False, + ), + _deltanet_state_params=[object()], + _deltanet_updated_states=[torch.zeros(1), torch.zeros(1)], + _hybrid_gdn_checkpoint_params=[], + _hybrid_gdn_checkpoint_updated_states=[], + ) + + with self.assertRaisesRegex( + RuntimeError, + "_deltanet_updated_states has 2 tensors but _deltanet_state_params has 1", + ): + self.qwen_module._qwen36_validate_alias_output_counts( + module, + updated_kv_cache=[torch.zeros(1), torch.zeros(1)], + is_for_context_encoding=True, + ) + + def test_gathered_logits_mask_only_actual_vocab_padding(self): + lm_head = SimpleNamespace(pad_size=248320, gather_output=True) + config = SimpleNamespace(vocab_size=248320) + + self.assertEqual( + self.qwen_module._effective_lm_head_pad_size( + lm_head, torch.empty(1, 1, 248320), config + ), + 0, + ) + self.assertEqual( + self.qwen_module._effective_lm_head_pad_size( + lm_head, torch.empty(1, 1, 248336), config + ), + 16, + ) + + def test_sharded_logits_keep_lm_head_pad_size(self): + lm_head = SimpleNamespace(pad_size=128, gather_output=False) + config = SimpleNamespace(vocab_size=248320) + + self.assertEqual( + self.qwen_module._effective_lm_head_pad_size( + lm_head, torch.empty(1, 1, 62080), config + ), + 128, + ) + + def test_on_device_output_logits_are_gathered_before_return(self): + logits = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6) + gathered = torch.arange(24, dtype=torch.float32).reshape(1, 1, 24) + lm_head = SimpleNamespace( + gather_output=False, + tensor_parallel_group="tp_group", + ) + neuron_config = SimpleNamespace( + output_logits=True, + on_device_sampling_config=object(), + ) + + with patch.object( + self.qwen_module, + "_gather_along_dim", + return_value=gathered, + ) as gather: + actual = self.qwen_module._qwen36_output_logits_for_return( + logits, + lm_head, + neuron_config, + ) + + self.assertIs(actual, gathered) + gather.assert_called_once_with( + logits, + partition_dim=2, + process_group="tp_group", + ) + + def test_output_logits_skip_gather_when_not_vocab_sharded(self): + logits = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6) + lm_head = SimpleNamespace(gather_output=True) + neuron_config = SimpleNamespace( + output_logits=True, + on_device_sampling_config=object(), + ) + + actual = self.qwen_module._qwen36_output_logits_for_return( + logits, + lm_head, + neuron_config, + ) + + self.assertIs(actual, logits) + + def test_mlp_kernel_cte_keeps_rmsnorm_separate(self): + layer = _make_decoder_layer_for_mlp_test(self.qwen_module) + hidden = torch.zeros((1, 4, 4), dtype=torch.float32) + + outputs = layer.forward(hidden, is_for_context_encoding=True) + + self.assertEqual(layer.post_attention_layernorm.calls, 1) + self.assertEqual(len(layer.mlp.calls), 1) + mlp_input, fused_rmsnorm = layer.mlp.calls[0] + self.assertIsNone(fused_rmsnorm) + self.assertTrue(torch.allclose(mlp_input, torch.ones_like(mlp_input))) + self.assertEqual(outputs[0].shape, hidden.shape) + + def test_mlp_kernel_tkg_can_fuse_rmsnorm(self): + layer = _make_decoder_layer_for_mlp_test(self.qwen_module) + hidden = torch.zeros((1, 1, 4), dtype=torch.float32) + + layer.forward(hidden, is_for_context_encoding=False) + + self.assertEqual(layer.post_attention_layernorm.calls, 0) + self.assertEqual(len(layer.mlp.calls), 1) + _mlp_input, fused_rmsnorm = layer.mlp.calls[0] + self.assertIs(fused_rmsnorm, layer.post_attention_layernorm) + + def test_fused_deltanet_does_not_clamp_cumulative_decay(self): + self.assertFalse( + hasattr(self.qwen_module, "_bound_fused_deltanet_log_decay") + ) + + def test_split_qkv_tkg_keeps_output_gate_on_standard_projection(self): + init_source = inspect.getsource(self.qwen_module.NeuronQwen35Attention.__init__) + split_tuple_source = init_source.split("split_qkv_projections = (", 1)[1].split( + ")", + 1, + )[0] + self.assertNotIn("output_gate_proj", split_tuple_source) + + forward_source = inspect.getsource(self.qwen_module.NeuronQwen35Attention.forward) + self.assertIn("gate = self.output_gate_proj(hidden_states)", forward_source) + self.assertNotIn( + "self._run_split_qkv_tkg_projection(\n" + " hidden_states,\n" + " self.output_gate_proj,", + forward_source, + ) + + def test_hybrid_checkpoint_commit_ignores_inactive_duplicate_slot_rows(self): + config = SimpleNamespace( + layer_types=["linear_attention"], + max_gdn_checkpoint_slots=3, + linear_num_value_heads=1, + linear_num_key_heads=1, + linear_key_head_dim=2, + linear_value_head_dim=2, + linear_conv_kernel_dim=3, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + neuron_config=SimpleNamespace(tp_degree=1), + ) + cache = self.qwen_module.HybridGDNCheckpointCache(config) + with torch.no_grad(): + cache.recurrent_slots[0].copy_( + torch.arange(12, dtype=torch.float32).reshape(3, 1, 2, 2) + ) + cache.conv_slots[0].copy_( + torch.arange(36, dtype=torch.bfloat16).reshape(3, 6, 2) + ) + + old_recurrent = cache.recurrent_slots[0].detach().clone() + old_conv = cache.conv_slots[0].detach().clone() + recurrent_state = torch.stack( + [ + torch.full((1, 2, 2), 101.0), + torch.full((1, 2, 2), 999.0), + ] + ) + conv_state = torch.stack( + [ + torch.full((6, 2), 11.0, dtype=torch.bfloat16), + torch.full((6, 2), 99.0, dtype=torch.bfloat16), + ] + ) + + recurrent_out, conv_out = cache.commit_from_active_rows( + layer_state_pairs=[(0, recurrent_state, conv_state)], + seq_ids=torch.tensor([0, 1], dtype=torch.int32), + checkpoint_slot_ids=torch.tensor([0, 0], dtype=torch.int32), + commit_mask=torch.tensor([1, 0], dtype=torch.int32), + ) + + self.assertTrue(torch.equal(recurrent_out[0], recurrent_state[0])) + self.assertTrue(torch.equal(conv_out[0], conv_state[0])) + self.assertTrue(torch.equal(recurrent_out[1:], old_recurrent[1:])) + self.assertTrue(torch.equal(conv_out[1:], old_conv[1:])) + + def test_hybrid_checkpoint_bank_reasserts_configured_dtype_after_global_cast(self): + config = SimpleNamespace( + layer_types=["linear_attention"], + max_gdn_checkpoint_slots=3, + linear_num_value_heads=1, + linear_num_key_heads=1, + linear_key_head_dim=2, + linear_value_head_dim=2, + linear_conv_kernel_dim=3, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + neuron_config=SimpleNamespace(tp_degree=1), + ) + cache = self.qwen_module.HybridGDNCheckpointCache(config).to(torch.bfloat16) + linear_attn = SimpleNamespace( + recurrent_state_buffer=nn.Parameter( + torch.zeros((1, 1, 2, 2), dtype=torch.bfloat16), + requires_grad=False, + ), + conv_state_buffer=nn.Parameter( + torch.zeros((1, 6, 2), dtype=torch.float32), + requires_grad=False, + ), + ) + module = SimpleNamespace( + config=config, + layers=[SimpleNamespace(linear_attn=linear_attn)], + hybrid_gdn_checkpoint_cache=cache, + ) + + self.assertEqual(linear_attn.recurrent_state_buffer.dtype, torch.bfloat16) + self.assertEqual(linear_attn.conv_state_buffer.dtype, torch.float32) + self.assertEqual(cache.recurrent_slots[0].dtype, torch.bfloat16) + + self.qwen_module._reassert_hybrid_gdn_checkpoint_param_dtypes(module) + + self.assertEqual(linear_attn.recurrent_state_buffer.dtype, torch.float32) + self.assertEqual(linear_attn.conv_state_buffer.dtype, torch.bfloat16) + self.assertEqual(cache.recurrent_slots[0].dtype, torch.float32) + self.assertEqual(cache.conv_slots[0].dtype, torch.bfloat16) + self.assertEqual(cache.recurrent_dtype, torch.float32) + self.assertEqual(cache.conv_dtype, torch.bfloat16) + + def test_hybrid_checkpoint_restore_clamps_slots_and_ignores_inactive_rows(self): + config = SimpleNamespace( + layer_types=["linear_attention"], + max_gdn_checkpoint_slots=3, + linear_num_value_heads=1, + linear_num_key_heads=1, + linear_key_head_dim=2, + linear_value_head_dim=2, + linear_conv_kernel_dim=3, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + neuron_config=SimpleNamespace(tp_degree=1), + ) + cache = self.qwen_module.HybridGDNCheckpointCache(config) + with torch.no_grad(): + cache.recurrent_slots[0][0].fill_(10) + cache.recurrent_slots[0][1].fill_(20) + cache.recurrent_slots[0][2].fill_(30) + cache.conv_slots[0][0].fill_(1) + cache.conv_slots[0][1].fill_(2) + cache.conv_slots[0][2].fill_(3) + + recurrent_state_buffer = torch.stack( + [ + torch.full((1, 2, 2), 101.0), + torch.full((1, 2, 2), 202.0), + ] + ) + conv_state_buffer = torch.stack( + [ + torch.full((6, 2), 11.0, dtype=torch.bfloat16), + torch.full((6, 2), 22.0, dtype=torch.bfloat16), + ] + ) + layers = [ + SimpleNamespace( + linear_attn=SimpleNamespace( + recurrent_state_buffer=recurrent_state_buffer, + conv_state_buffer=conv_state_buffer, + ) + ) + ] + + restored = cache.restore_to_active_rows( + layers=layers, + seq_ids=torch.tensor([1, -1], dtype=torch.int32), + checkpoint_slot_ids=torch.tensor([999, 999], dtype=torch.int32), + restore_mask=torch.tensor([1, 0], dtype=torch.int32), + ) + recurrent_out, conv_out = restored[0] + + self.assertTrue(torch.equal(recurrent_out[0], cache.recurrent_slots[0][2])) + self.assertTrue(torch.equal(conv_out[0], cache.conv_slots[0][2])) + self.assertTrue(torch.equal(recurrent_out[1], recurrent_state_buffer[0])) + self.assertTrue(torch.equal(conv_out[1], conv_state_buffer[0])) + + def test_hybrid_checkpoint_restore_zeroes_inactive_rows_for_context_prefill(self): + config = SimpleNamespace( + layer_types=["linear_attention"], + max_gdn_checkpoint_slots=2, + linear_num_value_heads=1, + linear_num_key_heads=1, + linear_key_head_dim=2, + linear_value_head_dim=2, + linear_conv_kernel_dim=3, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + neuron_config=SimpleNamespace(tp_degree=1), + ) + cache = self.qwen_module.HybridGDNCheckpointCache(config) + recurrent_state_buffer = torch.full((1, 1, 2, 2), 101.0) + conv_state_buffer = torch.full((1, 6, 2), 11.0, dtype=torch.bfloat16) + layers = [ + SimpleNamespace( + linear_attn=SimpleNamespace( + recurrent_state_buffer=recurrent_state_buffer, + conv_state_buffer=conv_state_buffer, + ) + ) + ] + + restored = cache.restore_to_active_rows( + layers=layers, + seq_ids=torch.tensor([0], dtype=torch.int32), + checkpoint_slot_ids=torch.tensor([0], dtype=torch.int32), + restore_mask=torch.tensor([0], dtype=torch.int32), + zero_inactive=True, + ) + recurrent_out, conv_out = restored[0] + + self.assertTrue(torch.equal(recurrent_out, torch.zeros_like(recurrent_out))) + self.assertTrue(torch.equal(conv_out, torch.zeros_like(conv_out))) + + def test_legacy_tkg_args_are_env_gated(self): + with patch.dict(os.environ, {}, clear=True): + self.assertFalse(self.qwen_module._use_legacy_tkg_args()) + with patch.dict(os.environ, {"QWEN36_TKG_LEGACY_ARGS": "1"}, clear=True): + self.assertTrue(self.qwen_module._use_legacy_tkg_args()) + + def test_legacy_tkg_uses_prefix_contract_for_cte_trace_args(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.CONTEXT_ENCODING_MODEL_TAG, + ) + + with patch.dict(os.environ, {"QWEN36_TKG_LEGACY_ARGS": "1"}, clear=True): + generated = wrapper.input_generator()[0] + + self.assertEqual(len(generated), 24) + self.assertEqual(generated[11].shape, (1, 1)) + self.assertEqual(generated[12].shape, (1, 1)) + self.assertEqual(generated[13].shape, (1, 1)) + self.assertEqual(generated[14].shape, (1, 1)) + + def test_legacy_tkg_trace_args_keep_prefix_metadata(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.TOKEN_GENERATION_MODEL_TAG, + ) + + with patch.dict(os.environ, {"QWEN36_TKG_LEGACY_ARGS": "1"}, clear=True): + generated = wrapper.input_generator()[0] + + self.assertEqual(len(generated), 24) + self.assertEqual(generated[11].shape, (1, 1)) + self.assertEqual(generated[12].shape, (1, 1)) + self.assertEqual(generated[13].shape, (1, 1)) + self.assertEqual(generated[14].shape, (1, 1)) + + def test_prefix_cache_pad_inputs_expands_minimal_runtime_args(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.CONTEXT_ENCODING_MODEL_TAG, + use_hybrid_apc_manager=False, + ) + wrapper.is_prefix_caching = True + wrapper.neuron_config = SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ) + + padded = wrapper.pad_inputs(*wrapper._base_inputs[0]) + + self.assertEqual(len(padded), 24) + self.assertEqual(padded[15].numel(), 0) + self.assertEqual(padded[21].shape, (3, 1, 1)) + + def test_hybrid_prefix_cache_pad_inputs_expands_minimal_runtime_args(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.CONTEXT_ENCODING_MODEL_TAG, + ) + wrapper.is_prefix_caching = True + wrapper.neuron_config = SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ) + + padded = wrapper.pad_inputs(*wrapper._base_inputs[0]) + + self.assertEqual(len(padded), 29) + self.assertEqual(padded[15].numel(), 0) + self.assertEqual(padded[24].shape, (1,)) + + def test_nonlegacy_tkg_trace_args_keep_prefix_and_hybrid_metadata(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.TOKEN_GENERATION_MODEL_TAG, + ) + + with patch.dict(os.environ, {}, clear=True): + generated = wrapper.input_generator()[0] + + self.assertEqual(len(generated), 29) + self.assertEqual(generated[11].shape, (1, 1)) + self.assertEqual(generated[14].shape, (1, 1)) + self.assertEqual(generated[24].shape, (1,)) + + def test_tkg_token_guard_rejects_out_of_vocab_id(self): + with self.assertRaisesRegex(ValueError, "out-of-vocab token id"): + self.qwen_module._validate_qwen36_tkg_input_ids( + torch.tensor([[2143289344]], dtype=torch.int32), + 248320, + ) + + def test_tkg_token_guard_accepts_valid_vocab_id(self): + self.qwen_module._validate_qwen36_tkg_input_ids( + torch.tensor([[42]], dtype=torch.int32), + 248320, + ) + + def test_prefill_detection_keeps_nonzero_multi_token_suffix_on_cte(self): + self.assertTrue( + self.qwen_module._qwen36_is_prefill_request( + torch.ones((1, 207), dtype=torch.int32), + torch.arange(207, 414, dtype=torch.int32).reshape(1, -1), + ) + ) + + def test_prefill_detection_keeps_one_token_nonzero_decode_on_tkg(self): + self.assertFalse( + self.qwen_module._qwen36_is_prefill_request( + torch.ones((1, 1), dtype=torch.int32), + torch.tensor([[207]], dtype=torch.int32), + ) + ) + + def test_prefill_detection_routes_packed_batched_decode_to_tkg(self): + self.assertFalse( + self.qwen_module._qwen36_is_prefill_request( + torch.ones((1, 2), dtype=torch.int32), + torch.tensor([[272, 272]], dtype=torch.int32), + full_context_lens=torch.tensor([273, 273], dtype=torch.int32), + computed_context_lens=torch.tensor([272, 272], dtype=torch.int32), + prefill_completion_state=torch.tensor([True, True]), + ) + ) + + def test_prefill_detection_keeps_packed_suffix_prefill_on_cte(self): + self.assertTrue( + self.qwen_module._qwen36_is_prefill_request( + torch.ones((1, 32), dtype=torch.int32), + torch.arange(256, 288, dtype=torch.int32).reshape(1, -1), + full_context_lens=torch.tensor([272, 272], dtype=torch.int32), + computed_context_lens=torch.tensor([256, 256], dtype=torch.int32), + prefill_completion_state=torch.tensor([True, True]), + ) + ) + + def test_prefill_detection_keeps_incomplete_one_token_prefill_on_cte(self): + self.assertTrue( + self.qwen_module._qwen36_is_prefill_request( + torch.ones((1, 2), dtype=torch.int32), + torch.tensor([[272, 49]], dtype=torch.int32), + full_context_lens=torch.tensor([273, 50], dtype=torch.int32), + computed_context_lens=torch.tensor([272, 49], dtype=torch.int32), + prefill_completion_state=torch.tensor([True, False]), + ) + ) + + def test_hybrid_apc_controls_need_prepare_for_missing_or_inert_masks(self): + self.assertTrue( + self.qwen_module._qwen36_hybrid_apc_controls_need_prepare(None, None) + ) + self.assertTrue( + self.qwen_module._qwen36_hybrid_apc_controls_need_prepare( + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_hybrid_apc_controls_skip_prepare_for_active_masks(self): + self.assertFalse( + self.qwen_module._qwen36_hybrid_apc_controls_need_prepare( + torch.tensor([1], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertFalse( + self.qwen_module._qwen36_hybrid_apc_controls_need_prepare( + torch.tensor([0], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + ) + ) + + def test_hybrid_apc_controls_materialized_for_zeroed_restore_mask(self): + self.assertTrue( + self.qwen_module._qwen36_hybrid_apc_controls_materialized( + torch.tensor([0], dtype=torch.int32), + torch.tensor([2048], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertFalse( + self.qwen_module._qwen36_hybrid_apc_controls_materialized( + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_hybrid_apc_pad_prepare_preserves_full_prefix_tail_contract(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.CONTEXT_ENCODING_MODEL_TAG, + ) + wrapper.neuron_config = SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ) + wrapper.is_prefix_caching = True + + empty = torch.empty(0) + base_args = list(wrapper._base_inputs[0]) + original_tail = [ + empty, # tile_q_indices + empty, # tile_block_tables + empty, # tile_masks + empty, # inputs_embeds + empty, # kv_cache + empty, # active_mask + torch.empty(0, dtype=torch.int32), # rotary_position_id + torch.empty(0, dtype=torch.bfloat16), # vision_embeddings + torch.empty(0, dtype=torch.int32), # vision_mask + torch.zeros((1,), dtype=torch.int32), # restore slot + torch.zeros((1,), dtype=torch.int32), # restore mask + torch.zeros((1,), dtype=torch.int32), # restore prefix len + torch.zeros((1,), dtype=torch.int32), # commit slot + torch.zeros((1,), dtype=torch.int32), # commit mask + ] + prepared_tail = [ + empty, + empty, + empty, + empty, + empty, + empty, + torch.empty(0, dtype=torch.int32), + torch.empty(0, dtype=torch.bfloat16), + torch.empty(0, dtype=torch.int32), + torch.tensor([3], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([7], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + ] + + def _prepare_request(_wrapper, input_dict): + return input_dict + + with ( + patch.object( + self.qwen_module, + "prepare_hybrid_apc_request_for_execution", + side_effect=_prepare_request, + ), + patch.object( + self.qwen_module, + "prepare_hybrid_apc_model_inputs", + return_value=prepared_tail, + ), + ): + padded = wrapper.pad_inputs(*(base_args + original_tail)) + + self.assertEqual(len(padded), 29) + self.assertEqual(int(padded[24].item()), 3) + self.assertEqual(int(padded[27].item()), 7) + self.assertEqual(int(padded[28].item()), 1) + + def test_hybrid_apc_pad_prepare_skips_materialized_restore_prefix(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.CONTEXT_ENCODING_MODEL_TAG, + ) + wrapper.neuron_config = SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ) + wrapper.is_prefix_caching = True + + empty = torch.empty(0) + base_args = list(wrapper._base_inputs[0]) + materialized_tail = [ + empty, + empty, + empty, + empty, + empty, + empty, + torch.empty(0, dtype=torch.int32), + torch.empty(0, dtype=torch.bfloat16), + torch.empty(0, dtype=torch.int32), + torch.tensor([5], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([2048], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + ] + + with patch.object( + self.qwen_module, + "prepare_hybrid_apc_request_for_execution", + side_effect=AssertionError("materialized controls must not re-prepare"), + ): + padded = wrapper.pad_inputs(*(base_args + materialized_tail)) + + self.assertEqual(len(padded), 29) + self.assertEqual(int(padded[24].item()), 5) + self.assertEqual(int(padded[25].item()), 0) + self.assertEqual(int(padded[26].item()), 2048) + + def test_restored_suffix_deltanet_mask_uses_token_padding(self): + input_ids = torch.tensor([[11, 12, 0, 0]], dtype=torch.int64) + inputs_embeds = torch.ones((1, 4, 2), dtype=torch.float32) + attention_mask = torch.ones((1, 4), dtype=torch.int32) + + mask = self.qwen_module._qwen36_deltanet_padding_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + padding_idx=0, + is_for_context_encoding=True, + hybrid_restore_mask=torch.tensor([1], dtype=torch.int32), + ) + + self.assertEqual(mask.squeeze(-1).tolist(), [[1.0, 1.0, 0.0, 0.0]]) + + def test_non_restored_deltanet_mask_keeps_attention_mask(self): + input_ids = torch.tensor([[11, 12, 0, 0]], dtype=torch.int64) + inputs_embeds = torch.ones((1, 4, 2), dtype=torch.float32) + attention_mask = torch.ones((1, 4), dtype=torch.int32) + + mask = self.qwen_module._qwen36_deltanet_padding_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + padding_idx=0, + is_for_context_encoding=True, + hybrid_restore_mask=torch.tensor([0], dtype=torch.int32), + ) + + self.assertEqual(mask.squeeze(-1).tolist(), [[1.0, 1.0, 1.0, 1.0]]) + + def test_mixed_restore_deltanet_mask_uses_token_padding_per_restored_row(self): + input_ids = torch.tensor( + [[11, 12, 0, 0], [21, 22, 23, 0]], dtype=torch.int64 + ) + inputs_embeds = torch.ones((2, 4, 2), dtype=torch.float32) + attention_mask = torch.tensor( + [[1, 1, 1, 1], [0, 0, 0, 0]], dtype=torch.int32 + ) + + mask = self.qwen_module._qwen36_deltanet_padding_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + padding_idx=0, + is_for_context_encoding=True, + hybrid_restore_mask=torch.tensor([1, 0], dtype=torch.int32), + ) + + self.assertEqual( + mask.squeeze(-1).tolist(), + [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ) + + def test_deltanet_mask_uses_num_queries_when_attention_mask_is_full_context(self): + input_ids = torch.tensor( + [[11, 12, 13, 248044, 248044], [21, 248044, 248044, 248044, 248044]], + dtype=torch.int64, + ) + inputs_embeds = torch.ones((2, 5, 2), dtype=torch.float32) + attention_mask = torch.ones((2, 16), dtype=torch.int32) + + mask = self.qwen_module._qwen36_deltanet_padding_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + padding_idx=None, + is_for_context_encoding=True, + hybrid_restore_mask=torch.tensor([0, 0], dtype=torch.int32), + num_queries=torch.tensor([[3], [1]], dtype=torch.int32), + ) + + self.assertEqual( + mask.squeeze(-1).tolist(), + [[1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0]], + ) + + def test_negative_dummy_seq_ids_mark_inactive_state_rows(self): + active_rows = self.qwen_module._qwen36_active_state_rows( + torch.ones((2, 4, 1), dtype=torch.float32), + torch.tensor([0, -1], dtype=torch.int32), + ) + + self.assertTrue( + torch.equal( + active_rows, + torch.tensor([True, False]), + ) + ) + + def test_inactive_dummy_rows_preserve_previous_state(self): + previous = torch.tensor( + [[[1.0, 2.0]], [[3.0, 4.0]]], + dtype=torch.float32, + ) + updated = torch.tensor( + [[[10.0, 20.0]], [[30.0, 40.0]]], + dtype=torch.float32, + ) + + preserved = self.qwen_module._qwen36_preserve_inactive_state_rows( + updated, + previous, + torch.tensor([True, False]), + ) + + self.assertTrue( + torch.equal( + preserved, + torch.tensor( + [[[10.0, 20.0]], [[3.0, 4.0]]], + dtype=torch.float32, + ), + ) + ) + + def test_state_rows_update_by_seq_ids(self): + previous = torch.tensor( + [[[1.0]], [[2.0]], [[3.0]]], + dtype=torch.float32, + ) + updated_rows = torch.tensor( + [[[10.0]], [[20.0]]], + dtype=torch.float32, + ) + + updated = self.qwen_module._qwen36_update_state_rows_by_seq_ids( + previous, + updated_rows, + torch.tensor([2, 0], dtype=torch.int32), + ) + + self.assertTrue( + torch.equal( + updated, + torch.tensor( + [[[20.0]], [[2.0]], [[10.0]]], + dtype=torch.float32, + ), + ) + ) + + def test_negative_seq_id_state_row_is_noop(self): + previous = torch.tensor( + [[[1.0]], [[2.0]]], + dtype=torch.float32, + ) + updated_rows = torch.tensor( + [[[10.0]], [[20.0]]], + dtype=torch.float32, + ) + + updated = self.qwen_module._qwen36_update_state_rows_by_seq_ids( + previous, + updated_rows, + torch.tensor([1, -1], dtype=torch.int32), + ) + + self.assertTrue( + torch.equal( + updated, + torch.tensor( + [[[1.0]], [[10.0]]], + dtype=torch.float32, + ), + ) + ) + + def test_request_ids_keep_stable_seq_slots_when_rows_reorder(self): + model = SimpleNamespace( + context_encoding_model=SimpleNamespace( + neuron_config=SimpleNamespace(batch_size=2) + ), + token_generation_model=SimpleNamespace( + neuron_config=SimpleNamespace(batch_size=2) + ), + ) + + first = self.qwen_module._qwen36_stable_seq_ids_for_request_ids( + model, + torch.tensor([0], dtype=torch.int32), + ("req-a",), + ) + mixed = self.qwen_module._qwen36_stable_seq_ids_for_request_ids( + model, + torch.tensor([0], dtype=torch.int32), + ("req-b", "req-a"), + ) + next_mixed = self.qwen_module._qwen36_stable_seq_ids_for_request_ids( + model, + torch.tensor([0], dtype=torch.int32), + ("req-c", "req-b"), + ) + + self.assertTrue(torch.equal(first, torch.tensor([0], dtype=torch.int32))) + self.assertTrue(torch.equal(mixed, torch.tensor([1, 0], dtype=torch.int32))) + self.assertTrue( + torch.equal(next_mixed, torch.tensor([0, 1], dtype=torch.int32)) + ) + + def test_single_new_request_reuses_first_stale_seq_slot(self): + model = SimpleNamespace( + context_encoding_model=SimpleNamespace( + neuron_config=SimpleNamespace(batch_size=2) + ), + token_generation_model=SimpleNamespace( + neuron_config=SimpleNamespace(batch_size=2) + ), + ) + + first = self.qwen_module._qwen36_stable_seq_ids_for_request_ids( + model, + torch.tensor([0], dtype=torch.int32), + ("req-a",), + ) + second = self.qwen_module._qwen36_stable_seq_ids_for_request_ids( + model, + torch.tensor([0], dtype=torch.int32), + ("req-b",), + ) + + self.assertTrue(torch.equal(first, torch.tensor([0], dtype=torch.int32))) + self.assertTrue(torch.equal(second, torch.tensor([0], dtype=torch.int32))) + + def test_checkpoint_cache_active_rows_follow_seq_slots(self): + state = torch.tensor( + [[[1.0]], [[2.0]]], + dtype=torch.float32, + ) + + active = self.qwen_module.HybridGDNCheckpointCache._active_rows( + state, + torch.tensor([1, -1], dtype=torch.int32), + 2, + ) + + self.assertTrue( + torch.equal( + active, + torch.tensor( + [[[2.0]], [[1.0]]], + dtype=torch.float32, + ), + ) + ) + + def _make_tiny_deltanet_for_carry_test(self, recurrent_dtype=torch.float32): + layer = self.qwen_module.NeuronGatedDeltaNet.__new__( + self.qwen_module.NeuronGatedDeltaNet + ) + nn.Module.__init__(layer) + + hidden_size = 4 + key_dim = 2 + value_dim = 2 + conv_kernel_size = 3 + conv_dim = key_dim * 2 + value_dim + + layer.hidden_size = hidden_size + layer.tp_degree = 1 + layer.global_num_v_heads = 1 + layer.global_num_k_heads = 1 + layer.head_k_dim = key_dim + layer.head_v_dim = value_dim + layer.num_v_heads = 1 + layer.num_k_heads = 1 + layer.global_key_dim = key_dim + layer.global_value_dim = value_dim + layer.key_dim = key_dim + layer.value_dim = value_dim + layer.conv_kernel_size = conv_kernel_size + layer.conv_dim = conv_dim + layer.layer_idx = 0 + layer.use_hybrid_cache_manager = False + layer.use_hybrid_apc_manager = True + layer.use_qwen_hybrid_chunked_prefill = True + layer.use_qwen_hybrid_chunked_prefill_nki = False + layer.use_qwen_deltanet_decode_nki = False + layer.use_cold_zero_conv_fast_path = False + layer.head_dim = key_dim + layer.kv_heads_per_rank = 1 + + layer.conv1d_weight = nn.Linear(conv_kernel_size, conv_dim, bias=False) + layer.in_proj_qkv = nn.Linear(hidden_size, conv_dim, bias=False) + layer.in_proj_z = nn.Linear(hidden_size, value_dim, bias=False) + layer.in_proj_b = nn.Linear(hidden_size, 1, bias=False) + layer.in_proj_a = nn.Linear(hidden_size, 1, bias=False) + layer.dt_bias_weight = nn.Linear(1, 1, bias=False) + layer.A_log_weight = nn.Linear(1, 1, bias=False) + layer.norm = nn.Identity() + layer.out_proj = nn.Linear(value_dim, hidden_size, bias=False) + layer.recurrent_state_buffer = nn.Parameter( + torch.zeros((1, 1, key_dim, value_dim), dtype=recurrent_dtype), + requires_grad=False, + ) + layer.conv_state_buffer = nn.Parameter( + torch.zeros((1, conv_dim, conv_kernel_size - 1), dtype=torch.bfloat16), + requires_grad=False, + ) + + with torch.no_grad(): + for module in ( + layer.conv1d_weight, + layer.in_proj_qkv, + layer.in_proj_z, + layer.in_proj_b, + layer.in_proj_a, + layer.out_proj, + ): + module.weight.uniform_(-0.04, 0.04) + layer.dt_bias_weight.weight.fill_(-1.0) + layer.A_log_weight.weight.fill_(-2.0) + + return layer + + def _assert_hybrid_gdn_checkpoint_carry_matches_full_prefill_on_cpu( + self, seq_len + ): + torch.manual_seed(36 + seq_len) + chunk = 512 + full_chunks = seq_len // chunk + suffix = seq_len - full_chunks * chunk + self.assertGreater(full_chunks, 0) + self.assertGreater(suffix, 0) + hidden = torch.randn((1, seq_len, 4), dtype=torch.float32) * 0.05 + + for recurrent_dtype in (torch.float32, torch.bfloat16): + layer = self._make_tiny_deltanet_for_carry_test(recurrent_dtype) + cache_config = SimpleNamespace( + layer_types=["linear_attention"], + max_gdn_checkpoint_slots=3, + linear_num_value_heads=1, + linear_num_key_heads=1, + linear_key_head_dim=2, + linear_value_head_dim=2, + linear_conv_kernel_dim=3, + hybrid_recurrent_cache_dtype=( + "bfloat16" if recurrent_dtype is torch.bfloat16 else "float32" + ), + hybrid_conv_cache_dtype="bfloat16", + neuron_config=SimpleNamespace(tp_degree=1), + ) + cache = self.qwen_module.HybridGDNCheckpointCache(cache_config) + layers = [SimpleNamespace(linear_attn=layer)] + seq_ids = torch.tensor([0], dtype=torch.int32) + + def run_cte(tokens, start_pos, past): + positions = torch.arange( + start_pos, + start_pos + tokens.shape[1], + dtype=torch.int64, + ).unsqueeze(0) + mask = torch.ones((1, tokens.shape[1], 1), dtype=torch.float32) + with patch.dict( + os.environ, + {"USE_PYTORCH_CHUNK": "1", "USE_NKI_FUSED": "0"}, + clear=False, + ): + output, _kv, recurrent, conv = layer( + tokens, + position_ids=positions, + past_key_value=past, + seq_ids=seq_ids, + is_for_context_encoding=True, + deltanet_padding_mask=mask, + ) + return output, recurrent, conv + + def commit(slot, recurrent, conv): + recurrent_out, conv_out = cache.commit_from_active_rows( + layer_state_pairs=[(0, recurrent, conv)], + seq_ids=seq_ids, + checkpoint_slot_ids=torch.tensor([slot], dtype=torch.int32), + commit_mask=torch.tensor([1], dtype=torch.int32), + ) + with torch.no_grad(): + cache.recurrent_slots[0].copy_(recurrent_out) + cache.conv_slots[0].copy_(conv_out) + + def restore(slot): + return cache.restore_to_active_rows( + layers=layers, + seq_ids=seq_ids, + checkpoint_slot_ids=torch.tensor([slot], dtype=torch.int32), + restore_mask=torch.tensor([1], dtype=torch.int32), + )[0] + + zero_past = ( + torch.zeros_like(layer.recurrent_state_buffer), + torch.zeros_like(layer.conv_state_buffer), + ) + full_output, full_recurrent, full_conv = run_cte(hidden, 0, zero_past) + + split_outputs = [] + past = zero_past + for chunk_idx in range(full_chunks): + start = chunk_idx * chunk + output, recurrent, conv = run_cte( + hidden[:, start : start + chunk], start, past + ) + split_outputs.append(output) + commit(chunk_idx, recurrent, conv) + past = restore(chunk_idx) + + padded_tail = torch.zeros((1, chunk, 4), dtype=hidden.dtype) + tail_start = full_chunks * chunk + padded_tail[:, :suffix] = hidden[:, tail_start:] + positions = torch.cat( + [ + torch.arange(tail_start, seq_len, dtype=torch.int64), + torch.ones((chunk - suffix,), dtype=torch.int64), + ] + ).unsqueeze(0) + tail_mask = torch.zeros((1, chunk, 1), dtype=torch.float32) + tail_mask[:, :suffix] = 1 + with patch.dict( + os.environ, + {"USE_PYTORCH_CHUNK": "1", "USE_NKI_FUSED": "0"}, + clear=False, + ): + out2, _kv, rec2, conv2 = layer( + padded_tail, + position_ids=positions, + past_key_value=past, + seq_ids=seq_ids, + is_for_context_encoding=True, + deltanet_padding_mask=tail_mask, + ) + + split_outputs.append(out2[:, :suffix]) + split_output = torch.cat(split_outputs, dim=1) + max_diff = (split_output - full_output).abs().max().item() + rec_diff = (rec2.float() - full_recurrent.float()).abs().max().item() + conv_diff = (conv2.float() - full_conv.float()).abs().max().item() + + tolerance = 1e-5 if recurrent_dtype is torch.float32 else 2e-3 + msg = (seq_len, recurrent_dtype) + self.assertLessEqual(max_diff, tolerance, msg) + self.assertLessEqual(rec_diff, tolerance, msg) + self.assertLessEqual(conv_diff, tolerance, msg) + + def test_hybrid_gdn_checkpoint_carry_matches_full_prefill_on_cpu(self): + self._assert_hybrid_gdn_checkpoint_carry_matches_full_prefill_on_cpu(1225) + + def test_hybrid_gdn_checkpoint_carry_matches_full_prefill_on_cpu_at_cliff(self): + self._assert_hybrid_gdn_checkpoint_carry_matches_full_prefill_on_cpu(526) + + def test_dummy_cte_rows_zero_restore_controls(self): + restore_slots, restore_mask, restore_prefix = ( + self.qwen_module._qwen36_pad_hybrid_restore_controls_for_dummy_cte_rows( + torch.tensor([7], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + torch.tensor([256], dtype=torch.int32), + 2, + ) + ) + + self.assertTrue( + torch.equal(restore_slots, torch.tensor([7, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(restore_mask, torch.tensor([1, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(restore_prefix, torch.tensor([256, 0], dtype=torch.int32)) + ) + + def test_packed_decode_batch_is_unpacked_for_tkg(self): + input_ids, attention_mask, position_ids, seq_ids, adapter_ids, slot_mapping = ( + self.qwen_module._qwen36_unpack_packed_decode_batch( + input_ids=torch.tensor([[271, 198]], dtype=torch.int32), + attention_mask=torch.tensor([[1, 1]], dtype=torch.int32), + position_ids=torch.tensor([[272, 272]], dtype=torch.int32), + seq_ids=torch.tensor([0], dtype=torch.int32), + adapter_ids=torch.tensor([0], dtype=torch.int32), + slot_mapping=torch.tensor([1552, 1553], dtype=torch.int32), + full_context_lens=torch.tensor([273, 273], dtype=torch.int32), + computed_context_lens=torch.tensor([272, 272], dtype=torch.int32), + ) + ) + + self.assertTrue( + torch.equal(input_ids, torch.tensor([[271], [198]], dtype=torch.int32)) + ) + self.assertEqual(attention_mask.shape, (2, 272)) + self.assertTrue( + torch.equal( + attention_mask[:, :272], + torch.ones((2, 272), dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + position_ids, + torch.tensor([[272], [272]], dtype=torch.int32), + ) + ) + self.assertTrue(torch.equal(seq_ids, torch.tensor([0, 1], dtype=torch.int32))) + self.assertTrue( + torch.equal(adapter_ids, torch.tensor([0, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + slot_mapping, + torch.tensor([[1552], [1553]], dtype=torch.int32), + ) + ) + + def test_request_scoped_vllm_metadata_is_added_for_hybrid_apc(self): + request_dict = {} + + self.qwen_module._qwen36_add_vllm_hybrid_apc_metadata( + request_dict, + request_ids=("req-a", "req-b"), + metadata_by_request_id={ + "req-a": { + "cumulative_hashes_by_prefix_len": {256: b"a"}, + "attention_block_refs_by_prefix_len": {256: (1,)}, + "request_prefix_len": 256, + "vllm_attention_hit_len": 0, + "active_suffix_len": 256, + }, + "req-b": { + "cumulative_hashes_by_prefix_len": {256: b"b"}, + "attention_block_refs_by_prefix_len": {256: (2,)}, + "request_prefix_len": 272, + "vllm_attention_hit_len": 256, + "active_suffix_len": 16, + }, + }, + ) + + self.assertEqual( + request_dict["cumulative_hashes_by_prefix_len"], + ({256: b"a"}, {256: b"b"}), + ) + self.assertEqual( + request_dict["attention_block_refs_by_prefix_len"], + ({256: (1,)}, {256: (2,)}), + ) + self.assertEqual(request_dict["request_prefix_len"], (256, 272)) + self.assertEqual(request_dict["vllm_attention_hit_len"], (0, 256)) + self.assertEqual(request_dict["active_suffix_len"], (256, 16)) + + def test_request_scoped_vllm_metadata_tensorizes_full_input_ids(self): + request_dict = { + "input_ids": torch.empty((1, 0), dtype=torch.int32), + } + + self.qwen_module._qwen36_add_vllm_hybrid_apc_metadata( + request_dict, + request_ids=("req-2049",), + metadata_by_request_id={ + "req-2049": { + "request_prefix_len": 2049, + "full_input_ids": tuple(range(2049)), + "vllm_attention_hit_len": 2048, + "active_suffix_len": 1, + }, + }, + ) + + self.assertIsInstance(request_dict["full_input_ids"], torch.Tensor) + self.assertEqual(request_dict["full_input_ids"].dtype, torch.int32) + self.assertEqual(tuple(request_dict["full_input_ids"].shape), (1, 2049)) + self.assertEqual(int(request_dict["full_input_ids"][0, -1].item()), 2048) + + def test_vllm_metadata_request_ids_prefer_scheduler_new_request_ids(self): + selected = self.qwen_module._qwen36_select_vllm_hybrid_apc_request_ids( + { + "new-a": {"vllm_attention_hit_len": 256}, + "new-b": {"vllm_attention_hit_len": 256}, + }, + ("new-a", "new-b"), + ("model-a", "model-b"), + ) + + self.assertEqual(selected, ("new-a", "new-b")) + + def test_vllm_metadata_request_ids_use_model_order_for_packed_chunked_batch(self): + selected = ( + self.qwen_module._qwen36_select_vllm_hybrid_apc_request_ids_for_input( + { + "new-a": {"vllm_attention_hit_len": 0}, + "cached-a": {"vllm_attention_hit_len": 271}, + }, + all_request_ids=("new-a", "cached-a"), + new_request_ids=("new-a",), + full_context_lens=torch.tensor([271, 272], dtype=torch.int32), + computed_context_lens=torch.tensor([0, 271], dtype=torch.int32), + prefill_completion_state=torch.tensor([True, True]), + ) + ) + + self.assertEqual(selected, ("new-a", "cached-a")) + + def test_vllm_metadata_request_ids_keep_model_order_when_new_ids_are_subset(self): + selected = ( + self.qwen_module._qwen36_select_vllm_hybrid_apc_request_ids_for_input( + { + "new-a": {"vllm_attention_hit_len": 0}, + }, + all_request_ids=("cached-a", "new-a"), + new_request_ids=("new-a",), + full_context_lens=torch.tensor([1, 272], dtype=torch.int32), + computed_context_lens=torch.tensor([0, 0], dtype=torch.int32), + prefill_completion_state=torch.tensor([True, True]), + ) + ) + + self.assertEqual(selected, ("cached-a", "new-a")) + + def test_flattened_slot_mapping_is_normalized_before_batch_chunking(self): + flattened = torch.arange(256, 719, dtype=torch.int32) + + normalized = self.qwen_module._normalize_qwen36_slot_mapping( + flattened, + batch_size=1, + active_tokens=463, + ) + + self.assertEqual(normalized.shape, (1, 463)) + self.assertTrue(torch.equal(normalized[0], flattened)) + + def test_flattened_decode_slot_mapping_is_normalized_by_batch(self): + flattened = torch.tensor([1488, 1489], dtype=torch.int32) + + normalized = self.qwen_module._normalize_qwen36_slot_mapping( + flattened, + batch_size=2, + active_tokens=1, + ) + + self.assertTrue( + torch.equal( + normalized, + torch.tensor([[1488], [1489]], dtype=torch.int32), + ) + ) + + def test_stage_builders_keep_cte_and_tkg_contracts_explicit(self): + wrapper = _make_wrapper( + self.qwen_module, + tag=self.qwen_module.TOKEN_GENERATION_MODEL_TAG, + ) + prefix_args = wrapper._base_inputs[0] + mrope = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros((0,), dtype=torch.bfloat16) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + with patch.dict(os.environ, {"QWEN36_TKG_LEGACY_ARGS": "1"}, clear=True): + cte_args = self.qwen_module.build_cte_args( + wrapper.config, + prefix_args, + mrope, + vision_embeddings, + vision_mask, + ) + tkg_args = self.qwen_module.build_tkg_args( + wrapper.config, + prefix_args, + mrope, + vision_embeddings, + vision_mask, + ) + + self.assertEqual(len(cte_args), 24) + self.assertEqual(len(tkg_args), 24) + self.assertEqual(cte_args[13].shape, (1, 1)) + self.assertEqual(tkg_args[13].shape, (1, 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_vllm_scheduler_patch.py b/contrib/models/Qwen3.6-27B/test/unit/test_vllm_scheduler_patch.py new file mode 100644 index 00000000..a3aa0dab --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_vllm_scheduler_patch.py @@ -0,0 +1,2465 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +import os +import sys +import types +import unittest +from dataclasses import dataclass +from unittest.mock import patch + +import torch + + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_PATCH_PATH = os.path.join( + _CONTRIB_ROOT, + "vllm", + "qwen36_hybrid_apc_scheduler_patch.py", +) +_SCHEDULER_MODULE = "vllm.v1.core.sched.scheduler" +_VLLM_NEURON_RUNNER_MODULE = "vllm_neuron.worker.neuronx_distributed_model_runner" + + +@dataclass(frozen=True) +class FullAttentionSpec: + block_size: int + num_kv_heads: int + head_size: int + dtype: str + sliding_window: int | None = None + + +def _load_patch_module(): + spec = importlib.util.spec_from_file_location("qwen36_scheduler_patch", _PATCH_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _scheduler( + *, + use_hybrid_apc=True, + disable_unbacked_prefix_reads=False, + enable_backed_prefix_reads=False, + use_qwen_hybrid_chunked_prefill=False, + block_size=2, + model_revision="rev-a", + additional_config=None, + max_num_seqs=1, +): + hf_config_kwargs = dict( + use_hybrid_apc_manager=use_hybrid_apc, + hybrid_apc_disable_unbacked_prefix_reads=disable_unbacked_prefix_reads, + hybrid_apc_enable_backed_prefix_reads=enable_backed_prefix_reads, + use_qwen_hybrid_chunked_prefill=use_qwen_hybrid_chunked_prefill, + hybrid_apc_layout_version=1, + hybrid_recurrent_cache_dtype="float32", + hybrid_conv_cache_dtype="bfloat16", + tp_rank=0, + ) + if model_revision is not None: + hf_config_kwargs["hybrid_apc_model_revision"] = model_revision + hf_config = types.SimpleNamespace(**hf_config_kwargs) + model_config = types.SimpleNamespace(hf_config=hf_config) + vllm_config = types.SimpleNamespace( + model_config=model_config, + additional_config=additional_config or {}, + ) + cache_config = types.SimpleNamespace(block_size=block_size) + scheduler_config = types.SimpleNamespace(max_num_seqs=max_num_seqs) + return types.SimpleNamespace( + vllm_config=vllm_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + ) + + +class TestQwen36HybridAPCSchedulerPatch(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.patch = _load_patch_module() + + def tearDown(self): + self.patch.clear_hybrid_apc_gdn_checkpoint_registry() + sys.modules.pop(_SCHEDULER_MODULE, None) + sys.modules.pop(_VLLM_NEURON_RUNNER_MODULE, None) + sys.meta_path = [ + finder + for finder in sys.meta_path + if not getattr(finder, "_qwen36_hybrid_apc_import_hook", False) + ] + + def test_config_flag_disables_prefix_reads_for_hybrid_apc(self): + scheduler = _scheduler(disable_unbacked_prefix_reads=True) + + self.assertTrue(self.patch.should_disable_unbacked_prefix_reads(scheduler)) + + def test_env_flag_disables_prefix_reads_for_hybrid_apc(self): + scheduler = _scheduler() + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertTrue(self.patch.should_disable_unbacked_prefix_reads(scheduler)) + + def test_env_flag_wins_when_artifact_config_is_stale(self): + scheduler = _scheduler(use_hybrid_apc=False) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertTrue(self.patch.should_disable_unbacked_prefix_reads(scheduler)) + + def test_reject_unbacked_mode_disables_unbacked_prefix_reads(self): + scheduler = _scheduler( + additional_config={"hybrid_apc_reject_unbacked_attention_hits": True}, + ) + + self.assertTrue(self.patch.should_disable_unbacked_prefix_reads(scheduler)) + + def test_require_metadata_mode_disables_unbacked_prefix_reads(self): + scheduler = _scheduler( + additional_config={"hybrid_apc_require_vllm_metadata": True}, + ) + + self.assertTrue(self.patch.should_disable_unbacked_prefix_reads(scheduler)) + + def test_non_hybrid_apc_model_is_not_changed(self): + scheduler = _scheduler( + use_hybrid_apc=False, + disable_unbacked_prefix_reads=True, + ) + + self.assertFalse(self.patch.should_disable_unbacked_prefix_reads(scheduler)) + + def test_patch_marks_request_skip_reading_prefix_cache(self): + calls = [] + + class FakeScheduler: + def __init__(self): + self.vllm_config = _scheduler( + disable_unbacked_prefix_reads=True + ).vllm_config + + def add_request(self, request): + calls.append(request.skip_reading_prefix_cache) + + installed = self.patch.patch_scheduler_class(FakeScheduler) + request = types.SimpleNamespace(skip_reading_prefix_cache=False) + + FakeScheduler().add_request(request) + + self.assertTrue(installed) + self.assertEqual(calls, [True]) + self.assertTrue(request.skip_reading_prefix_cache) + + def test_registered_gdn_checkpoint_keeps_prefix_read_disabled_without_cte_support(self): + scheduler = _scheduler(block_size=2) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[2], + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertEqual( + self.patch.backed_gdn_prefix_hit_len(scheduler, request), + 4, + ) + self.assertTrue( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + def test_registered_gdn_checkpoint_allows_prefix_read_when_cte_supports_it(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[2], + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertEqual( + self.patch.backed_gdn_prefix_hit_len(scheduler, request), + 4, + ) + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + def test_largest_backed_prefix_read_does_not_require_lower_checkpoint(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + request_id="req-largest-backed", + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertEqual( + self.patch.backed_gdn_prefix_hit_len(scheduler, request), + 4, + ) + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + authorized = self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="req-largest-backed", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + self.assertIsNotNone(authorized) + self.assertEqual(authorized.cumulative_prefix_hash, hashes[4]) + self.assertIsNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=2, + request_id="req-largest-backed", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + def test_partial_gdn_coverage_caps_prefix_read_to_backed_checkpoint(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + token_ids = [10, 11, 12, 13, 14, 15, 16] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=6, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + request_id="req-partial", + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertEqual( + self.patch.backed_gdn_prefix_hit_len(scheduler, request), + 4, + ) + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + self.assertEqual( + getattr(request, self.patch._MAX_PREFIX_CACHE_HIT_LEN_ATTR), + 4, + ) + self.assertEqual( + getattr(request, self.patch._MAX_PREFIX_CACHE_BLOCKS_ATTR), + 2, + ) + self.assertIsNotNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="req-partial", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + def test_max_backed_prefix_cap_selects_largest_backed_prefix_under_cap(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + additional_config={"hybrid_apc_max_backed_prefix_read_len": 2}, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + for prefix_len in (2, 4): + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[prefix_len], + prefix_len=prefix_len, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + request_id="req-capped", + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + self.assertEqual( + getattr(request, self.patch._MAX_PREFIX_CACHE_HIT_LEN_ATTR), + 2, + ) + self.assertEqual( + getattr(request, self.patch._MAX_PREFIX_CACHE_BLOCKS_ATTR), + 1, + ) + self.assertIsNotNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=2, + request_id="req-capped", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + def test_kv_cache_manager_caps_prefix_hash_lookup_to_backed_len(self): + seen_hashes = [] + + class FakeKVCacheManager: + empty_kv_cache_blocks = "empty" + + def get_computed_blocks(self, request): + seen_hashes.append(tuple(request.block_hashes)) + return "blocks", len(request.block_hashes) * 2 + + fake_module = types.SimpleNamespace(KVCacheManager=FakeKVCacheManager) + self.assertTrue(self.patch._patch_kv_cache_manager_module(fake_module)) + + request = types.SimpleNamespace(block_hashes=[b"a", b"b", b"c", b"d"]) + setattr(request, self.patch._MAX_PREFIX_CACHE_BLOCKS_ATTR, 2) + + result = FakeKVCacheManager().get_computed_blocks(request) + + self.assertEqual(result, ("blocks", 4)) + self.assertEqual(seen_hashes, [(b"a", b"b")]) + self.assertEqual(request.block_hashes, [b"a", b"b", b"c", b"d"]) + + def test_additional_config_allows_prefix_read_when_hf_config_is_stale(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=False, + use_qwen_hybrid_chunked_prefill=False, + additional_config={ + "use_hybrid_apc_manager": True, + "hybrid_apc_disable_unbacked_prefix_reads": True, + "hybrid_apc_enable_backed_prefix_reads": True, + "use_qwen_hybrid_chunked_prefill": True, + }, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[2], + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + authorized = self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + self.assertIsNotNone(authorized) + self.assertEqual(authorized.cumulative_prefix_hash, hashes[4]) + + def test_authorized_prefix_read_can_be_request_scoped(self): + key = self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash="hash-a", + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + + self.patch.authorize_hybrid_apc_prefix_read(key, request_id="req-a") + + self.assertIsNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="req-b", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.assertEqual( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="req-a", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ), + key, + ) + + def test_scheduler_authorizes_backed_prefix_read_by_request_id(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + key = self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[2], + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.patch.register_hybrid_apc_gdn_checkpoint(key) + request = types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + self.assertIsNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="req-b", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.assertEqual( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="req-a", + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ), + key, + ) + + def test_backed_prefix_hit_uses_vllm_block_hashes_when_available(self): + scheduler = _scheduler(block_size=2) + key = self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=b"vllm-hash-4", + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + self.patch.register_hybrid_apc_gdn_checkpoint(key) + request = types.SimpleNamespace( + prompt_token_ids=[10, 11, 12, 13, 14], + block_hashes=[b"vllm-hash-2", b"vllm-hash-4"], + num_tokens=5, + cache_salt=None, + ) + + self.assertEqual(self.patch.backed_gdn_prefix_hit(scheduler, request), key) + + def test_scheduler_output_carries_vllm_hashes_and_block_refs(self): + class FakeScheduler: + def __init__(self): + base = _scheduler(block_size=2) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.requests = { + "req-a": types.SimpleNamespace( + prompt_token_ids=[10, 11, 12, 13], + block_hashes=[b"hash-2", b"hash-4"], + num_tokens=4, + cache_salt=None, + ) + } + + def add_request(self, request): + del request + + def schedule(self): + return types.SimpleNamespace( + scheduled_new_reqs=[ + types.SimpleNamespace( + req_id="req-a", + block_ids=([11, 12],), + num_computed_tokens=0, + ) + ], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=[], + new_block_ids=[], + num_computed_tokens=[], + ), + num_scheduled_tokens={"req-a": 4}, + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler_output = FakeScheduler().schedule() + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + ) + + self.assertEqual( + metadata["req-a"]["cumulative_hashes_by_prefix_len"], + {2: b"hash-2", 4: b"hash-4"}, + ) + self.assertEqual( + metadata["req-a"]["attention_block_refs_by_prefix_len"], + {2: (11,), 4: (11, 12)}, + ) + self.assertEqual(metadata["req-a"]["request_prefix_len"], 4) + self.assertEqual(metadata["req-a"]["vllm_attention_hit_len"], 0) + self.assertEqual(metadata["req-a"]["active_suffix_len"], 4) + + def test_scheduler_output_caps_cached_request_prefix_to_current_chunk(self): + class FakeScheduler: + def __init__(self): + base = _scheduler(block_size=256) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.requests = { + "req-a": types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=list(range(8193)), + all_token_ids=list(range(8193)), + block_hashes=[f"hash-{idx}".encode() for idx in range(32)], + num_tokens=8193, + cache_salt=None, + ) + } + + def add_request(self, request): + del request + + def schedule(self): + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["req-a"], + new_block_ids=[(list(range(13, 25)),)], + num_computed_tokens=[3072], + ), + num_scheduled_tokens={"req-a": 3072}, + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler_output = FakeScheduler().schedule() + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + )["req-a"] + + self.assertEqual(metadata["request_prefix_len"], 6144) + self.assertEqual(metadata["vllm_attention_hit_len"], 3072) + self.assertEqual(metadata["active_suffix_len"], 3072) + self.assertEqual(len(metadata["full_input_ids"]), 6144) + self.assertIn(6144, metadata["cumulative_hashes_by_prefix_len"]) + self.assertNotIn(8192, metadata["cumulative_hashes_by_prefix_len"]) + + def test_scheduler_output_excludes_generated_tokens_from_prompt_metadata(self): + class FakeScheduler: + def __init__(self): + base = _scheduler(block_size=256) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.requests = { + "req-a": types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=list(range(8192)), + all_token_ids=list(range(8194)), + block_hashes=[f"hash-{idx}".encode() for idx in range(32)], + num_prompt_tokens=8192, + num_tokens=8194, + cache_salt=None, + ) + } + + def add_request(self, request): + del request + + def schedule(self): + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["req-a"], + new_block_ids=[(list(range(25, 33)),)], + num_computed_tokens=[6144], + ), + num_scheduled_tokens={"req-a": 2050}, + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler_output = FakeScheduler().schedule() + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + )["req-a"] + + self.assertEqual(metadata["request_prefix_len"], 8192) + self.assertEqual(metadata["vllm_attention_hit_len"], 6144) + self.assertEqual(metadata["active_suffix_len"], 2050) + self.assertEqual(len(metadata["full_input_ids"]), 8192) + self.assertEqual(metadata["full_input_ids"][-1], 8191) + self.assertIn(8192, metadata["cumulative_hashes_by_prefix_len"]) + self.assertNotIn(8448, metadata["cumulative_hashes_by_prefix_len"]) + + def test_scheduler_output_authorizes_backed_cached_continuation(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.requests = { + "req-a": types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=[10, 11, 12, 13], + block_hashes=[b"hash-2", b"hash-4"], + num_tokens=4, + cache_salt=None, + ) + } + + def add_request(self, request): + del request + + def schedule(self): + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["req-a"], + new_block_ids=[([12],)], + num_computed_tokens=[2], + num_output_tokens=[0], + ), + ) + + key = self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=b"hash-2", + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + self.patch.register_hybrid_apc_gdn_checkpoint(key) + self.patch.patch_scheduler_class(FakeScheduler) + + FakeScheduler().schedule() + + self.assertEqual( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=2, + request_id="req-a", + model_revision="rev-a", + ), + key, + ) + + def test_scheduler_preserves_backed_and_cold_prefix_read_decisions_in_mixed_batch(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + block_size=4, + disable_unbacked_prefix_reads=True, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + max_num_seqs=2, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.warm = types.SimpleNamespace( + request_id="warm", + prompt_token_ids=[10, 11, 12, 13, 14, 15], + num_tokens=6, + cache_salt=None, + skip_reading_prefix_cache=False, + ) + self.cold = types.SimpleNamespace( + request_id="cold", + prompt_token_ids=[20, 21, 22, 23, 24, 25], + num_tokens=6, + cache_salt=None, + skip_reading_prefix_cache=False, + ) + self.waiting = [self.warm, self.cold] + self.requests = {"warm": self.warm, "cold": self.cold} + self.schedule_seen_skip_flags = None + + def add_request(self, request): + del request + + def schedule(self): + self.schedule_seen_skip_flags = [ + (request.request_id, request.skip_reading_prefix_cache) + for request in self.waiting + ] + computed_tokens = [ + 0 if request.skip_reading_prefix_cache else 4 + for request in self.waiting + ] + self.waiting = [] + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["warm", "cold"], + new_block_ids=[([1, 2],), ([3, 4],)], + num_computed_tokens=computed_tokens, + num_output_tokens=[0, 0], + ), + num_scheduled_tokens={ + "warm": 6 - computed_tokens[0], + "cold": 6 - computed_tokens[1], + }, + total_num_scheduled_tokens=12 - sum(computed_tokens), + ) + + warm_token_ids = [10, 11, 12, 13, 14, 15] + hashes = self.patch._local_cumulative_prefix_hashes( + warm_token_ids, + block_size=4, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler = FakeScheduler() + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads( + scheduler, + scheduler.warm, + ) + ) + if self.patch.should_disable_unbacked_prefix_reads( + scheduler, + scheduler.cold, + ): + scheduler.cold.skip_reading_prefix_cache = True + scheduler_output = scheduler.schedule() + + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + ) + + self.assertEqual( + scheduler.schedule_seen_skip_flags, + [("warm", False), ("cold", True)], + ) + self.assertEqual( + scheduler_output.scheduled_cached_reqs.num_computed_tokens, + [4, 0], + ) + self.assertEqual(scheduler_output.num_scheduled_tokens["warm"], 2) + self.assertEqual(scheduler_output.num_scheduled_tokens["cold"], 6) + self.assertEqual(metadata["warm"]["vllm_attention_hit_len"], 4) + self.assertEqual(metadata["cold"]["vllm_attention_hit_len"], 0) + self.assertIsNotNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + request_id="warm", + model_revision="rev-a", + ) + ) + + def test_scheduler_preserves_all_backed_context_batch(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + block_size=4, + disable_unbacked_prefix_reads=True, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + max_num_seqs=2, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.req_a = types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=[10, 11, 12, 13, 14, 15], + num_tokens=6, + cache_salt=None, + skip_reading_prefix_cache=False, + ) + self.req_b = types.SimpleNamespace( + request_id="req-b", + prompt_token_ids=[30, 31, 32, 33, 34, 35], + num_tokens=6, + cache_salt=None, + skip_reading_prefix_cache=False, + ) + self.waiting = [self.req_a, self.req_b] + self.requests = {"req-a": self.req_a, "req-b": self.req_b} + self.schedule_seen_skip_flags = None + + def add_request(self, request): + del request + + def schedule(self): + self.schedule_seen_skip_flags = [ + (request.request_id, request.skip_reading_prefix_cache) + for request in self.waiting + ] + self.waiting = [] + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["req-a", "req-b"], + new_block_ids=[([1, 2],), ([3, 4],)], + num_computed_tokens=[4, 4], + num_output_tokens=[0, 0], + ), + num_scheduled_tokens={"req-a": 2, "req-b": 2}, + total_num_scheduled_tokens=4, + ) + + for token_ids in ([10, 11, 12, 13, 14, 15], [30, 31, 32, 33, 34, 35]): + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=4, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler = FakeScheduler() + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads( + scheduler, + scheduler.req_a, + ) + ) + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads( + scheduler, + scheduler.req_b, + ) + ) + scheduler_output = scheduler.schedule() + + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + ) + + self.assertEqual( + scheduler.schedule_seen_skip_flags, + [("req-a", False), ("req-b", False)], + ) + self.assertEqual( + scheduler_output.scheduled_cached_reqs.num_computed_tokens, + [4, 4], + ) + self.assertEqual(metadata["req-a"]["vllm_attention_hit_len"], 4) + self.assertEqual(metadata["req-b"]["vllm_attention_hit_len"], 4) + + def test_scheduler_defers_waiting_prefills_while_decode_running(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + use_hybrid_apc=True, + use_qwen_hybrid_chunked_prefill=True, + max_num_seqs=2, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.running = [types.SimpleNamespace(request_id="decode")] + self.waiting = [types.SimpleNamespace(request_id="prefill")] + self.waiting_seen_by_schedule = None + + def add_request(self, request): + self.waiting.append(request) + + def schedule(self): + self.waiting_seen_by_schedule = [ + request.request_id for request in self.waiting + ] + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=None, + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler = FakeScheduler() + scheduler.schedule() + + self.assertEqual(scheduler.waiting_seen_by_schedule, []) + self.assertEqual( + [request.request_id for request in scheduler.waiting], + ["prefill"], + ) + + def test_scheduler_allows_mixed_prefill_decode_when_configured(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + use_hybrid_apc=True, + use_qwen_hybrid_chunked_prefill=True, + max_num_seqs=2, + additional_config={ + "hybrid_apc_allow_mixed_prefill_decode": True, + }, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.running = [types.SimpleNamespace(request_id="decode")] + self.waiting = [types.SimpleNamespace(request_id="prefill")] + self.waiting_seen_by_schedule = None + + def add_request(self, request): + self.waiting.append(request) + + def schedule(self): + self.waiting_seen_by_schedule = [ + request.request_id for request in self.waiting + ] + self.waiting = [] + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=None, + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler = FakeScheduler() + scheduler.schedule() + + self.assertEqual(scheduler.waiting_seen_by_schedule, ["prefill"]) + self.assertEqual(scheduler.waiting, []) + + def test_scheduler_keeps_waiting_prefills_when_no_decode_running(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + use_hybrid_apc=True, + use_qwen_hybrid_chunked_prefill=True, + max_num_seqs=2, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.running = [] + self.waiting = [types.SimpleNamespace(request_id="prefill")] + self.waiting_seen_by_schedule = None + + def add_request(self, request): + self.waiting.append(request) + + def schedule(self): + self.waiting_seen_by_schedule = [ + request.request_id for request in self.waiting + ] + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=None, + ) + + self.patch.patch_scheduler_class(FakeScheduler) + scheduler = FakeScheduler() + scheduler.schedule() + + self.assertEqual(scheduler.waiting_seen_by_schedule, ["prefill"]) + + def test_scheduler_output_metadata_does_not_rewrite_cached_context_hit(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + block_size=4, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.requests = { + "req-a": types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=[10, 11, 12, 13, 14, 15], + num_tokens=6, + cache_salt=None, + ) + } + + def add_request(self, request): + del request + + def schedule(self): + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["req-a"], + new_block_ids=[([11, 12],)], + num_computed_tokens=[6], + num_output_tokens=[0], + ), + num_scheduled_tokens={"req-a": 1}, + total_num_scheduled_tokens=1, + ) + + token_ids = [10, 11, 12, 13, 14, 15] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=4, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + self.patch.patch_scheduler_class(FakeScheduler) + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + scheduler_output = FakeScheduler().schedule() + + cached_reqs = scheduler_output.scheduled_cached_reqs + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + ) + + self.assertEqual(cached_reqs.num_computed_tokens, [6]) + self.assertEqual(scheduler_output.num_scheduled_tokens["req-a"], 1) + self.assertEqual(scheduler_output.total_num_scheduled_tokens, 1) + self.assertEqual(metadata["req-a"]["vllm_attention_hit_len"], 6) + self.assertEqual(metadata["req-a"]["request_prefix_len"], 6) + + def test_scheduler_output_does_not_cap_decode_rows(self): + class FakeScheduler: + def __init__(self): + base = _scheduler( + block_size=4, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + ) + self.vllm_config = base.vllm_config + self.cache_config = base.cache_config + self.scheduler_config = base.scheduler_config + self.requests = { + "req-a": types.SimpleNamespace( + request_id="req-a", + prompt_token_ids=[10, 11, 12, 13, 14, 15], + num_tokens=6, + cache_salt=None, + ) + } + + def add_request(self, request): + del request + + def schedule(self): + return types.SimpleNamespace( + scheduled_new_reqs=[], + scheduled_cached_reqs=types.SimpleNamespace( + req_ids=["req-a"], + new_block_ids=[([11, 12],)], + num_computed_tokens=[6], + num_output_tokens=[1], + ), + num_scheduled_tokens={"req-a": 1}, + total_num_scheduled_tokens=1, + ) + + token_ids = [10, 11, 12, 13, 14, 15] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=4, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + self.patch.patch_scheduler_class(FakeScheduler) + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + scheduler_output = FakeScheduler().schedule() + + metadata = getattr( + scheduler_output, + "_qwen36_hybrid_apc_metadata_by_request_id", + ) + + self.assertEqual(scheduler_output.scheduled_cached_reqs.num_computed_tokens, [6]) + self.assertEqual(scheduler_output.num_scheduled_tokens["req-a"], 1) + self.assertEqual(scheduler_output.total_num_scheduled_tokens, 1) + self.assertEqual(metadata["req-a"]["vllm_attention_hit_len"], 6) + + def test_backed_prefix_read_allows_batched_scheduler_when_configured(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=True, + use_qwen_hybrid_chunked_prefill=True, + max_num_seqs=2, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[2], + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + self.assertIsNotNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + def test_env_backed_prefix_override_allows_batched_scheduler(self): + scheduler = _scheduler( + block_size=2, + enable_backed_prefix_reads=False, + use_qwen_hybrid_chunked_prefill=False, + max_num_seqs=2, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[2], + prefix_len=2, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + with patch.dict( + os.environ, + { + "QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1", + "QWEN36_HYBRID_APC_ENABLE_BACKED_PREFIX_READS": "1", + }, + ): + self.assertFalse( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + self.assertIsNotNone( + self.patch.pop_hybrid_apc_authorized_prefix_key( + prefix_len=4, + cache_salt=None, + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + + def test_additional_config_overrides_scheduler_registry_key_metadata(self): + scheduler = _scheduler( + block_size=2, + model_revision="stale-rev", + additional_config={ + "hybrid_apc_model_revision": "runtime-rev", + "hybrid_apc_layout_version": 2, + "tp_rank": 3, + "hybrid_recurrent_cache_dtype": "bf16", + "hybrid_conv_cache_dtype": "float32", + }, + ) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + key = self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="runtime-rev", + layout_version=2, + tp_rank=3, + recurrent_dtype="bfloat16", + conv_dtype="float32", + ) + self.patch.register_hybrid_apc_gdn_checkpoint(key) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + self.assertEqual(self.patch.backed_gdn_prefix_hit(scheduler, request), key) + + def test_mismatched_gdn_checkpoint_keeps_prefix_read_disabled(self): + scheduler = _scheduler(block_size=2) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt="tenant-a", + model_revision="rev-a", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt="tenant-b", + ) + + with patch.dict( + os.environ, + {"QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS": "1"}, + ): + self.assertEqual( + self.patch.backed_gdn_prefix_hit_len(scheduler, request), + 0, + ) + self.assertTrue( + self.patch.should_disable_unbacked_prefix_reads(scheduler, request) + ) + + def test_missing_model_revision_defaults_to_unknown(self): + scheduler = _scheduler(block_size=2, model_revision=None) + token_ids = [10, 11, 12, 13, 14] + hashes = self.patch._local_cumulative_prefix_hashes( + token_ids, + block_size=2, + max_prefix_len=4, + ) + self.patch.register_hybrid_apc_gdn_checkpoint( + self.patch.HybridGDNPrefixKey( + cumulative_prefix_hash=hashes[4], + prefix_len=4, + block_size=2, + cache_salt=None, + model_revision="unknown", + layout_version=1, + tp_rank=0, + recurrent_dtype="float32", + conv_dtype="bfloat16", + ) + ) + request = types.SimpleNamespace( + prompt_token_ids=token_ids, + num_tokens=len(token_ids), + cache_salt=None, + ) + + self.assertEqual( + self.patch.backed_gdn_prefix_hit_len(scheduler, request), + 4, + ) + + def test_import_hook_does_not_import_scheduler_immediately(self): + installed = self.patch.install_import_hook() + + self.assertFalse(installed) + self.assertNotIn(_SCHEDULER_MODULE, sys.modules) + self.assertTrue( + any( + getattr(finder, "_qwen36_hybrid_apc_import_hook", False) + for finder in sys.meta_path + ) + ) + + def test_import_hook_patches_already_loaded_scheduler_module(self): + calls = [] + + class FakeScheduler: + def __init__(self): + self.vllm_config = _scheduler( + disable_unbacked_prefix_reads=True + ).vllm_config + + def add_request(self, request): + calls.append(request.skip_reading_prefix_cache) + + module = types.SimpleNamespace(Scheduler=FakeScheduler) + sys.modules[_SCHEDULER_MODULE] = module + + installed = self.patch.install_import_hook() + request = types.SimpleNamespace(skip_reading_prefix_cache=False) + FakeScheduler().add_request(request) + + self.assertTrue(installed) + self.assertEqual(calls, [True]) + self.assertTrue(request.skip_reading_prefix_cache) + + def test_runner_patch_exposes_request_ids_during_model_execution(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + self.seen_request_ids = getattr( + self.model.model, + "_qwen36_vllm_request_ids", + None, + ) + self.seen_cached_request_ids = getattr( + self.model.model, + "_qwen36_vllm_cached_request_ids", + None, + ) + self.seen_prefill_completion_state = getattr( + self.model.model, + "_qwen36_vllm_prefill_completion_state", + None, + ) + self.seen_metadata = getattr( + self.model.model, + "_qwen36_vllm_hybrid_apc_metadata_by_request_id", + None, + ) + self.seen_request_records = getattr( + self.model.model, + "_qwen36_vllm_hybrid_apc_request_records", + None, + ) + return self.seen_request_ids + + installed = self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + result = runner._execute_model_for_text( + types.SimpleNamespace( + request_ids=["req-a"], + _qwen36_cached_request_ids=("req-a",), + prefill_completion_state="done", + _qwen36_hybrid_apc_metadata_by_request_id={ + "req-a": {"cumulative_hashes_by_prefix_len": {4: b"h4"}} + }, + _qwen36_hybrid_apc_request_records=( + { + "request_id": "req-a", + "cumulative_hashes_by_prefix_len": {4: b"h4"}, + }, + ), + ) + ) + + self.assertTrue(installed) + self.assertEqual(result, ("req-a",)) + self.assertEqual(runner.seen_request_ids, ("req-a",)) + self.assertEqual(runner.seen_cached_request_ids, ("req-a",)) + self.assertEqual(runner.seen_prefill_completion_state, "done") + self.assertEqual( + runner.seen_metadata, + {"req-a": {"cumulative_hashes_by_prefix_len": {4: b"h4"}}}, + ) + self.assertEqual( + runner.seen_request_records, + ( + { + "request_id": "req-a", + "cumulative_hashes_by_prefix_len": {4: b"h4"}, + }, + ), + ) + self.assertFalse(hasattr(runner.model, "_qwen36_vllm_request_ids")) + self.assertFalse(hasattr(runner.model.model, "_qwen36_vllm_request_ids")) + self.assertFalse(hasattr(runner.model.model, "_qwen36_vllm_cached_request_ids")) + self.assertFalse( + hasattr(runner.model.model, "_qwen36_vllm_hybrid_apc_request_records") + ) + + def test_runner_patch_applies_runtime_hybrid_apc_config_during_execution(self): + class FakeRunner: + def __init__(self): + self.vllm_config = types.SimpleNamespace( + additional_config={ + "hybrid_apc_require_vllm_metadata": True, + "hybrid_apc_enable_backed_prefix_reads": True, + } + ) + self.model = types.SimpleNamespace( + model=types.SimpleNamespace( + config=types.SimpleNamespace( + hybrid_apc_require_vllm_metadata=False, + hybrid_apc_allow_local_hash_fallback=True, + hybrid_apc_require_attention_block_refs=False, + hybrid_apc_reject_unbacked_attention_hits=False, + hybrid_apc_enable_backed_prefix_reads=False, + ), + hybrid_apc_bridge=types.SimpleNamespace( + allow_local_hash_fallback=True, + require_attention_block_refs=False, + reject_unbacked_attention_hits=False, + ), + ) + ) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del model_input, intermediate_tensors + config = self.model.model.config + bridge = self.model.model.hybrid_apc_bridge + return ( + config.hybrid_apc_require_vllm_metadata, + config.hybrid_apc_allow_local_hash_fallback, + config.hybrid_apc_require_attention_block_refs, + config.hybrid_apc_reject_unbacked_attention_hits, + config.hybrid_apc_enable_backed_prefix_reads, + bridge.allow_local_hash_fallback, + bridge.require_attention_block_refs, + bridge.reject_unbacked_attention_hits, + ) + + installed = self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + result = runner._execute_model_for_text(types.SimpleNamespace()) + + self.assertTrue(installed) + self.assertEqual(result, (True, False, True, True, True, False, True, True)) + config = runner.model.model.config + self.assertFalse(config.hybrid_apc_require_vllm_metadata) + self.assertTrue(config.hybrid_apc_allow_local_hash_fallback) + self.assertFalse(config.hybrid_apc_require_attention_block_refs) + self.assertFalse(config.hybrid_apc_reject_unbacked_attention_hits) + self.assertFalse(config.hybrid_apc_enable_backed_prefix_reads) + bridge = runner.model.model.hybrid_apc_bridge + self.assertTrue(bridge.allow_local_hash_fallback) + self.assertFalse(bridge.require_attention_block_refs) + self.assertFalse(bridge.reject_unbacked_attention_hits) + + def test_runner_patch_attaches_scheduler_request_sources_to_model_input(self): + @dataclass(frozen=True) + class FrozenModelInput: + request_ids: list[str] + + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _prepare_model_input(self, scheduler_output): + del scheduler_output + return FrozenModelInput(request_ids=["cached-1", "new-1"]) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + installed = self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + scheduler_output = types.SimpleNamespace( + scheduled_cached_reqs=types.SimpleNamespace(req_ids=["cached-1"]), + scheduled_new_reqs=[types.SimpleNamespace(req_id="new-1")], + _qwen36_hybrid_apc_metadata_by_request_id={ + "new-1": {"attention_block_refs_by_prefix_len": {4: (3, 4)}} + }, + ) + model_input = runner._prepare_model_input(scheduler_output) + + self.assertTrue(installed) + self.assertEqual(model_input._qwen36_cached_request_ids, ("cached-1",)) + self.assertEqual(model_input._qwen36_new_request_ids, ("new-1",)) + self.assertEqual( + model_input._qwen36_hybrid_apc_metadata_by_request_id, + {"new-1": {"attention_block_refs_by_prefix_len": {4: (3, 4)}}}, + ) + self.assertEqual( + model_input._qwen36_hybrid_apc_request_records, + ( + {"request_id": "cached-1"}, + { + "request_id": "new-1", + "attention_block_refs_by_prefix_len": {4: (3, 4)}, + }, + ), + ) + + def test_runner_patch_builds_ordered_hybrid_apc_request_records(self): + @dataclass(frozen=True) + class FrozenModelInput: + request_ids: list[str] + + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _prepare_model_input(self, scheduler_output): + del scheduler_output + return FrozenModelInput(request_ids=["warm", "cold"]) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + installed = self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + scheduler_output = types.SimpleNamespace( + scheduled_cached_reqs=types.SimpleNamespace(req_ids=["warm"]), + scheduled_new_reqs=[types.SimpleNamespace(req_id="cold")], + num_scheduled_tokens={"warm": 2, "cold": 6}, + _qwen36_hybrid_apc_metadata_by_request_id={ + "warm": { + "vllm_attention_hit_len": 4, + "request_prefix_len": 6, + "cumulative_hashes_by_prefix_len": {4: b"warm-h4"}, + "attention_block_refs_by_prefix_len": {4: (1, 2)}, + }, + "cold": { + "vllm_attention_hit_len": 0, + "request_prefix_len": 6, + }, + }, + ) + model_input = runner._prepare_model_input(scheduler_output) + + self.assertTrue(installed) + self.assertEqual( + model_input._qwen36_hybrid_apc_request_records, + ( + { + "request_id": "warm", + "cumulative_hashes_by_prefix_len": {4: b"warm-h4"}, + "attention_block_refs_by_prefix_len": {4: (1, 2)}, + "request_prefix_len": 6, + "vllm_attention_hit_len": 4, + "active_suffix_len": 2, + }, + { + "request_id": "cold", + "request_prefix_len": 6, + "vllm_attention_hit_len": 0, + "active_suffix_len": 6, + }, + ), + ) + + def test_runner_patch_uses_scheduler_ids_when_model_input_has_no_ids(self): + @dataclass(frozen=True) + class FrozenModelInput: + pass + + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _prepare_model_input(self, scheduler_output): + del scheduler_output + return FrozenModelInput() + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + installed = self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + scheduler_output = types.SimpleNamespace( + scheduled_cached_reqs=types.SimpleNamespace(req_ids=["warm"]), + scheduled_new_reqs=[], + num_scheduled_tokens={"warm": 2}, + _qwen36_hybrid_apc_metadata_by_request_id={ + "warm": { + "vllm_attention_hit_len": 2, + "request_prefix_len": 4, + "active_suffix_len": 2, + }, + }, + ) + model_input = runner._prepare_model_input(scheduler_output) + + self.assertTrue(installed) + self.assertEqual(model_input._qwen36_cached_request_ids, ("warm",)) + self.assertEqual( + model_input._qwen36_hybrid_apc_request_records, + ( + { + "request_id": "warm", + "request_prefix_len": 4, + "vllm_attention_hit_len": 2, + "active_suffix_len": 2, + }, + ), + ) + + def test_runner_patch_expands_completed_only_prefill_logits(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _prepare_logits_for_sampling(self, hidden_states, model_input): + hidden_states = hidden_states.clone() + for idx, state in enumerate(model_input.prefill_completion_state): + if not state.item(): + hidden_states[idx] = float("-inf") + return hidden_states + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + logits = runner._prepare_logits_for_sampling( + torch.tensor([[1.0, 2.0, 3.0]]), + types.SimpleNamespace( + request_ids=["req-a", "req-b"], + prefill_completion_state=torch.tensor([True, False]), + ), + ) + + self.assertEqual(tuple(logits.shape), (2, 3)) + torch.testing.assert_close(logits[0], torch.tensor([1.0, 2.0, 3.0])) + self.assertTrue(torch.isneginf(logits[1]).all()) + + def test_runner_patch_leaves_full_prefill_logits_unchanged(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _prepare_logits_for_sampling(self, hidden_states, model_input): + return hidden_states + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + hidden_states = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + logits = runner._prepare_logits_for_sampling( + hidden_states, + types.SimpleNamespace( + request_ids=["req-a", "req-b"], + prefill_completion_state=torch.tensor([True, False]), + ), + ) + + torch.testing.assert_close(logits, hidden_states) + + def test_runner_patch_clones_inference_tensor_before_on_device_prefill_mask(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _sample_on_device(self, hidden_states, model_input): + for idx, state in enumerate(model_input.prefill_completion_state): + if not state.item(): + hidden_states[idx] = -1 + return hidden_states + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + with torch.inference_mode(): + hidden_states = torch.tensor([11, 22], dtype=torch.int32) + + sampled = runner._sample_on_device( + hidden_states, + types.SimpleNamespace( + prefill_completion_state=torch.tensor([True, False]), + ), + ) + + torch.testing.assert_close(sampled, torch.tensor([11, -1], dtype=torch.int32)) + torch.testing.assert_close(hidden_states, torch.tensor([11, 22], dtype=torch.int32)) + + def test_runner_patch_masks_sampled_tokens_for_incomplete_prefill_rows(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _sample_on_device(self, hidden_states, model_input): + del hidden_states, model_input + return types.SimpleNamespace( + sampled_token_ids=torch.tensor([[0], [33]], dtype=torch.int32), + logprobs_tensors=None, + ) + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + sampled = runner._sample_on_device( + torch.tensor([[11], [22]], dtype=torch.int32), + types.SimpleNamespace( + prefill_completion_state=torch.tensor([False, True]), + ), + ) + + torch.testing.assert_close( + sampled.sampled_token_ids, + torch.tensor([[-1], [33]], dtype=torch.int32), + ) + + def test_runner_patch_repairs_invalid_completed_prefill_sampled_token_from_logits(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace( + model=types.SimpleNamespace( + config=types.SimpleNamespace(vocab_size=3) + ) + ) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _sample_on_device(self, hidden_states, model_input): + del hidden_states, model_input + return types.SimpleNamespace( + sampled_token_ids=torch.tensor( + [[2147483647]], dtype=torch.int32 + ), + logprobs_tensors=None, + ) + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + + sampled = runner._sample_on_device( + [ + torch.tensor([[2147483647]], dtype=torch.int32), + torch.tensor([[[0.1, 3.0, 0.2]]], dtype=torch.float32), + ], + types.SimpleNamespace( + prefill_completion_state=torch.tensor([True]), + ), + ) + + torch.testing.assert_close( + sampled.sampled_token_ids, + torch.tensor([[1]], dtype=torch.int32), + ) + + def test_runner_patch_repairs_completed_only_output_row_from_logits(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace( + model=types.SimpleNamespace( + config=types.SimpleNamespace(vocab_size=4) + ) + ) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _sample_on_device(self, hidden_states, model_input): + del hidden_states, model_input + return types.SimpleNamespace( + sampled_token_ids=torch.tensor( + [[2147483647]], dtype=torch.int32 + ), + logprobs_tensors=None, + ) + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + + sampled = runner._sample_on_device( + [ + torch.tensor([[2147483647]], dtype=torch.int32), + torch.tensor([[[0.1, 0.2, 0.3, 4.0]]], dtype=torch.float32), + ], + types.SimpleNamespace( + prefill_completion_state=torch.tensor([False, True]), + ), + ) + + torch.testing.assert_close( + sampled.sampled_token_ids, + torch.tensor([[3]], dtype=torch.int32), + ) + + def test_runner_patch_rejects_invalid_completed_prefill_without_logits(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace( + model=types.SimpleNamespace( + config=types.SimpleNamespace(vocab_size=248320) + ) + ) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _sample_on_device(self, hidden_states, model_input): + del hidden_states, model_input + return types.SimpleNamespace( + sampled_token_ids=torch.tensor( + [[2147483647]], dtype=torch.int32 + ), + logprobs_tensors=None, + ) + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + + with self.assertRaisesRegex( + ValueError, + "--output-logits-with-on-device-sampling", + ): + runner._sample_on_device( + torch.tensor([[2147483647]], dtype=torch.int32), + types.SimpleNamespace( + prefill_completion_state=torch.tensor([True]), + ), + ) + + def test_runner_patch_rejects_invalid_completed_prefill_with_sharded_logits(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace( + model=types.SimpleNamespace( + config=types.SimpleNamespace(vocab_size=16) + ) + ) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _sample_on_device(self, hidden_states, model_input): + del hidden_states, model_input + return types.SimpleNamespace( + sampled_token_ids=torch.tensor( + [[2147483647]], dtype=torch.int32 + ), + logprobs_tensors=None, + ) + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + + with self.assertRaisesRegex( + ValueError, + "gathers vocab-parallel output logits", + ): + runner._sample_on_device( + [ + torch.tensor([[2147483647]], dtype=torch.int32), + torch.tensor([[[0.1, 0.2, 4.0, 0.3]]], dtype=torch.float32), + ], + types.SimpleNamespace( + prefill_completion_state=torch.tensor([True]), + ), + ) + + def test_runner_patch_masks_cpu_sampled_tokens_before_output_update(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def _generate_model_runner_output(self, sampler_output): + return sampler_output.sampled_token_ids + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + setattr( + runner, + self.patch._RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR, + torch.tensor([False, True]), + ) + sampled = runner._generate_model_runner_output( + types.SimpleNamespace( + sampled_token_ids=torch.tensor([[0], [44]], dtype=torch.int32), + logprobs_tensors=None, + ) + ) + + torch.testing.assert_close( + sampled, + torch.tensor([[-1], [44]], dtype=torch.int32), + ) + + def test_runner_patch_captures_prefill_state_during_sample_tokens(self): + seen = [] + + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + self._cached_logits = torch.tensor([[1.0]]) + self._cached_model_input = types.SimpleNamespace( + prefill_completion_state=torch.tensor([False, True]), + ) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def sample_tokens(self, grammar_output): + del grammar_output + seen.append( + getattr( + self, + self_patch._RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR, + ) + ) + return "sampled" + + self_patch = self.patch + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + result = runner.sample_tokens(None) + + self.assertEqual(result, "sampled") + torch.testing.assert_close(seen[0], torch.tensor([False, True])) + self.assertFalse( + hasattr(runner, self.patch._RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR) + ) + + def test_runner_patch_returns_no_output_for_initial_async_sample(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + self._cached_logits = None + self._cached_model_input = None + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def sample_tokens(self, grammar_output): + del grammar_output + raise RuntimeError( + "sample_tokens() called without prior execute_model(). " + "Logits must be cached first." + ) + + self.patch.patch_neuron_model_runner_class(FakeRunner) + + self.assertIsNone(FakeRunner().sample_tokens(None)) + + def test_runner_patch_uses_hybrid_attention_layers_for_kv_cache_spec(self): + class FakeModelConfig: + dtype = "bfloat16" + + def __init__(self): + self.hf_config = types.SimpleNamespace( + num_hidden_layers=8, + num_attention_heads=24, + num_key_value_heads=4, + layer_types=[ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + * 2, + ) + + def get_sliding_window(self): + return None + + class FakeRunner: + block_size = 256 + + def __init__(self): + self.model = types.SimpleNamespace(head_dim=256) + self.model_config = FakeModelConfig() + self.parallel_config = types.SimpleNamespace(tensor_parallel_size=4) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def get_kv_cache_spec(self): + return { + f"layers.{idx}.self_attn": FullAttentionSpec( + block_size=1, + num_kv_heads=4, + head_size=1, + dtype="original", + ) + for idx in range(8) + } + + installed = self.patch.patch_neuron_model_runner_class(FakeRunner) + spec = FakeRunner().get_kv_cache_spec() + + self.assertTrue(installed) + self.assertEqual(list(spec), ["layers.3.self_attn", "layers.7.self_attn"]) + self.assertEqual(spec["layers.3.self_attn"].block_size, 256) + self.assertEqual(spec["layers.3.self_attn"].num_kv_heads, 1) + self.assertEqual(spec["layers.3.self_attn"].head_size, 256) + self.assertEqual(spec["layers.3.self_attn"].dtype, "bfloat16") + + def test_runner_patch_uses_full_attention_interval_for_kv_cache_spec(self): + class FakeModelConfig: + dtype = "bfloat16" + + def __init__(self): + self.hf_config = types.SimpleNamespace( + num_hidden_layers=8, + num_attention_heads=24, + num_key_value_heads=4, + full_attention_interval=4, + ) + + def get_sliding_window(self): + return None + + class FakeRunner: + block_size = 128 + + def __init__(self): + self.model = types.SimpleNamespace(head_dim=256) + self.model_config = FakeModelConfig() + self.parallel_config = types.SimpleNamespace(tensor_parallel_size=4) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def get_kv_cache_spec(self): + return {"original": FullAttentionSpec(1, 4, 1, "original")} + + self.patch.patch_neuron_model_runner_class(FakeRunner) + spec = FakeRunner().get_kv_cache_spec() + + self.assertEqual(list(spec), ["layers.3.self_attn", "layers.7.self_attn"]) + self.assertEqual(spec["layers.7.self_attn"].block_size, 128) + self.assertEqual(spec["layers.7.self_attn"].num_kv_heads, 1) + + def test_runner_patch_keeps_original_kv_cache_spec_for_dense_attention(self): + class FakeModelConfig: + dtype = "bfloat16" + + def __init__(self): + self.hf_config = types.SimpleNamespace( + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + layer_types=["full_attention", "full_attention"], + ) + + def get_sliding_window(self): + return None + + class FakeRunner: + block_size = 128 + + def __init__(self): + self.original_kv_cache_spec_called = False + self.model = types.SimpleNamespace(head_dim=256) + self.model_config = FakeModelConfig() + self.parallel_config = types.SimpleNamespace(tensor_parallel_size=4) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + del intermediate_tensors + return model_input + + def get_kv_cache_spec(self): + self.original_kv_cache_spec_called = True + return {"original": FullAttentionSpec(1, 4, 1, "original")} + + self.patch.patch_neuron_model_runner_class(FakeRunner) + runner = FakeRunner() + spec = runner.get_kv_cache_spec() + + self.assertTrue(runner.original_kv_cache_spec_called) + self.assertEqual(list(spec), ["original"]) + + def test_import_hook_patches_already_loaded_neuron_runner_module(self): + class FakeRunner: + def __init__(self): + self.model = types.SimpleNamespace(model=types.SimpleNamespace()) + + def _execute_model_for_text(self, model_input, intermediate_tensors=None): + return getattr(self.model.model, "_qwen36_vllm_request_ids", None) + + module = types.SimpleNamespace(NeuronxDistributedModelRunner=FakeRunner) + sys.modules[_VLLM_NEURON_RUNNER_MODULE] = module + + installed = self.patch.install_import_hook() + runner = FakeRunner() + result = runner._execute_model_for_text( + types.SimpleNamespace(request_ids=("req-a",)) + ) + + self.assertTrue(installed) + self.assertEqual(result, ("req-a",)) + self.assertFalse(hasattr(runner.model, "_qwen36_vllm_request_ids")) + self.assertFalse(hasattr(runner.model.model, "_qwen36_vllm_request_ids")) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_vllm_serving_config.py b/contrib/models/Qwen3.6-27B/test/unit/test_vllm_serving_config.py new file mode 100644 index 00000000..16ad8483 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_vllm_serving_config.py @@ -0,0 +1,368 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import importlib.util +import os +import unittest + + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_RUNNER_PATH = os.path.join(_CONTRIB_ROOT, "vllm", "run_offline_inference.py") + + +def _load_runner(): + spec = importlib.util.spec_from_file_location("qwen36_run_offline_inference", _RUNNER_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _args(**overrides): + defaults = dict( + cte_bucket=512, + cte_buckets=None, + cte_bucket_profile="single", + context_encoding_bucket_pairs=None, + seq_len=2048, + tensor_parallel_size=4, + max_num_seqs=1, + ctx_batch_size=1, + logical_nc_config=2, + token_generation_buckets=None, + token_generation_batches=None, + async_mode=False, + block_size=128, + enable_prefix_caching=False, + enable_hybrid_apc=False, + enable_vllm_chunked_prefill=True, + kernel_q_tile_size=128, + kernel_kv_tile_size=1024, + hybrid_gdn_recurrent_cache_dtype=None, + gdn_recurrent_cache_dtype="float32", + hybrid_gdn_conv_cache_dtype=None, + gdn_conv_cache_dtype="bfloat16", + gdn_checkpoint_interval=256, + max_gdn_checkpoint_slots=8, + hybrid_cache_mode="all", + hybrid_cache_prefix_boundary_only=True, + hybrid_cache_validate_exact=False, + hybrid_apc_require_vllm_metadata=False, + hybrid_apc_reject_unbacked_attention_hits=True, + hybrid_apc_disable_unbacked_prefix_reads=False, + hybrid_apc_enable_backed_prefix_reads=False, + hybrid_apc_prefill_chunk_tokens=0, + text_only_cte=True, + compact_cte_attention_mask=True, + cold_zero_conv_fast_path=False, + ) + defaults.update(overrides) + return argparse.Namespace(**defaults) + + +class TestVllmServingConfig(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.runner = _load_runner() + + def test_cte_bucket_list_is_sorted_unique_and_128_aligned(self): + args = _args(cte_buckets=["512,128", "256", "256"]) + + self.assertEqual(self.runner._cte_buckets(args), [128, 256, 512]) + + def test_cte_bucket_rejects_non_128_aligned_value(self): + with self.assertRaisesRegex(ValueError, "128-aligned"): + self.runner._cte_buckets(_args(cte_buckets=["192"])) + + def test_short_profile_builds_dynamic_bucket_config(self): + config = self.runner._override_config(_args(cte_bucket_profile="short")) + neuron_config = config["override_neuron_config"] + + self.assertEqual(neuron_config["context_encoding_buckets"], [128, 256, 512, 1024]) + self.assertEqual(neuron_config["max_context_length"], 1024) + self.assertTrue(neuron_config["enable_bucketing"]) + self.assertEqual(config["max_prompt_length"], 1024) + + def test_text_only_and_compact_mask_flags_are_forwarded(self): + config = self.runner._override_config( + _args( + text_only_cte=False, + compact_cte_attention_mask=False, + cold_zero_conv_fast_path=True, + ) + ) + + self.assertFalse(config["use_text_only_cte_inputs"]) + self.assertFalse(config["use_compact_cte_attention_mask"]) + self.assertTrue(config["use_cold_zero_conv_fast_path"]) + + def test_sparse_context_encoding_bucket_pairs_are_forwarded(self): + config = self.runner._override_config( + _args( + enable_hybrid_apc=True, + block_size=256, + gdn_checkpoint_interval=256, + context_encoding_bucket_pairs=["512:0,512:32768", "3072:131072"], + ) + ) + + self.assertEqual( + config["override_neuron_config"]["context_encoding_bucket_pairs"], + [[512, 0], [512, 32768], [3072, 131072]], + ) + + def test_sparse_context_pairs_keep_prefix_cte_contract_without_vllm_prefix_cache(self): + config = self.runner._override_config( + _args( + enable_prefix_caching=False, + enable_hybrid_apc=False, + context_encoding_bucket_pairs=["512:0", "3072:16384"], + ) + ) + neuron_config = config["override_neuron_config"] + + self.assertFalse(config["use_hybrid_apc_manager"]) + self.assertTrue(neuron_config["is_prefix_caching"]) + self.assertEqual( + neuron_config["context_encoding_bucket_pairs"], + [[512, 0], [3072, 16384]], + ) + + def test_hybrid_apc_requires_checkpoint_interval_equal_block_size(self): + with self.assertRaisesRegex(ValueError, "gdn-checkpoint-interval"): + self.runner._override_config( + _args( + enable_hybrid_apc=True, + enable_prefix_caching=True, + block_size=128, + gdn_checkpoint_interval=256, + ) + ) + + def test_hybrid_apc_enables_prefix_caching_and_slots(self): + args = _args( + enable_hybrid_apc=True, + enable_prefix_caching=False, + block_size=256, + gdn_checkpoint_interval=256, + max_gdn_checkpoint_slots=3, + ) + + config = self.runner._override_config(args) + + self.assertTrue(args.enable_prefix_caching) + self.assertTrue(config["use_hybrid_apc_manager"]) + self.assertEqual(config["max_gdn_checkpoint_slots"], 3) + + def test_hybrid_apc_rejects_bfloat16_recurrent_checkpoint_cache(self): + with self.assertRaisesRegex(ValueError, "requires float32 recurrent GDN"): + self.runner._override_config( + _args( + enable_hybrid_apc=True, + block_size=256, + gdn_checkpoint_interval=256, + gdn_recurrent_cache_dtype="bfloat16", + ) + ) + + def test_hybrid_apc_can_require_vllm_metadata(self): + config = self.runner._override_config( + _args( + enable_hybrid_apc=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_require_vllm_metadata=True, + ) + ) + + self.assertTrue(config["hybrid_apc_require_vllm_metadata"]) + self.assertFalse(config["hybrid_apc_allow_local_hash_fallback"]) + self.assertTrue(config["hybrid_apc_require_attention_block_refs"]) + self.assertTrue(config["hybrid_apc_reject_unbacked_attention_hits"]) + + def test_hybrid_apc_can_disable_unbacked_prefix_reads(self): + config = self.runner._override_config( + _args( + enable_hybrid_apc=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_disable_unbacked_prefix_reads=True, + ) + ) + + self.assertTrue(config["hybrid_apc_disable_unbacked_prefix_reads"]) + + def test_hybrid_apc_can_enable_backed_prefix_reads(self): + config = self.runner._override_config( + _args( + enable_hybrid_apc=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_enable_backed_prefix_reads=True, + ) + ) + + self.assertTrue(config["hybrid_apc_enable_backed_prefix_reads"]) + + def test_chunked_prefill_runtime_flags_are_forwarded(self): + config = self.runner._override_config(_args(enable_vllm_chunked_prefill=True)) + + self.assertTrue(config["use_qwen_hybrid_chunked_prefill"]) + self.assertTrue(config["use_qwen_hybrid_chunked_prefill_nki"]) + + def test_grouped_prefill_defaults_to_largest_compiled_bucket(self): + config = self.runner._override_config( + _args( + cte_buckets=["512,1024"], + seq_len=2048, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + ) + ) + + self.assertEqual(config["qwen_prefill_group_size"], 1024) + self.assertEqual(config["hybrid_apc_prefill_chunk_tokens"], 1024) + + def test_grouped_prefill_records_explicit_four_chunk_group(self): + config = self.runner._override_config( + _args( + cte_buckets=["512,1024,2048"], + seq_len=4096, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_prefill_chunk_tokens=2048, + ) + ) + + self.assertEqual(config["qwen_prefill_group_size"], 2048) + self.assertEqual(config["hybrid_apc_prefill_chunk_tokens"], 2048) + + def test_hybrid_apc_chunked_prefill_defaults_to_largest_aligned_bucket(self): + args = _args( + cte_buckets=["256,512"], + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + ) + + self.assertEqual( + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ), + 512, + ) + + def test_hybrid_apc_chunked_prefill_uses_largest_checkpoint_aligned_bucket(self): + args = _args( + cte_buckets=["512,768,1536,3072"], + seq_len=3072, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + ) + + self.assertEqual( + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ), + 3072, + ) + + def test_hybrid_apc_chunked_prefill_requires_checkpoint_aligned_cte_bucket(self): + args = _args( + cte_buckets=["384"], + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + ) + + with self.assertRaisesRegex(ValueError, "multiple"): + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ) + + def test_hybrid_apc_can_use_safe_non_power_of_two_prefill_chunk(self): + args = _args( + cte_buckets=["512,768,1536,3072"], + seq_len=8192, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_prefill_chunk_tokens=3072, + ) + + self.assertEqual( + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ), + 3072, + ) + + def test_hybrid_apc_can_use_explicit_larger_prefill_chunk(self): + args = _args( + cte_buckets=["256,512,1024,2048,4096,8192"], + seq_len=8192, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_prefill_chunk_tokens=8192, + ) + + self.assertEqual( + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ), + 8192, + ) + + def test_hybrid_apc_larger_prefill_chunk_must_be_compiled_bucket(self): + args = _args( + cte_buckets=["256,512,1024,2048,4096"], + seq_len=8192, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_prefill_chunk_tokens=8192, + ) + + with self.assertRaisesRegex(ValueError, "compiled CTE bucket"): + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ) + + def test_hybrid_apc_larger_prefill_chunk_must_align_to_checkpoint(self): + args = _args( + cte_buckets=["256,384,512,1024"], + seq_len=1024, + enable_hybrid_apc=True, + enable_vllm_chunked_prefill=True, + block_size=256, + gdn_checkpoint_interval=256, + hybrid_apc_prefill_chunk_tokens=384, + ) + + with self.assertRaisesRegex(ValueError, "multiple"): + self.runner._max_num_batched_tokens( + args, + self.runner._cte_buckets(args), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py new file mode 100644 index 00000000..23adfdd5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py @@ -0,0 +1,535 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5/3.6-27B HF-to-NxDI weight conversion. + +CPU-only tests that validate: +- RMSNorm (+1 convention) weight conversion +- GQA q_proj interleaved split (query + gate) +- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm) +- Fused QKV concatenation +- DeltaNet layer weights pass through unchanged +- VL wrapper prefix stripping +- rank_util injection + +These tests are architecture-level and apply to both Qwen3.5-27B and Qwen3.6-27B. +""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + NeuronQwen35ForCausalLM, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True): + """Create a small Qwen35InferenceConfig for testing.""" + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + fused_qkv=fused_qkv, + ) + config = Qwen35InferenceConfig( + neuron_config=neuron_config, + hidden_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=64, + intermediate_size=512, + vocab_size=1000, + rms_norm_eps=1e-6, + max_position_embeddings=4096, + rope_theta=10000, + hidden_act="silu", + linear_num_value_heads=8, + linear_num_key_heads=4, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_conv_kernel_dim=4, + ) + return config + + +def _make_mini_state_dict(config): + """Create a minimal HF-style state dict for conversion testing.""" + sd = {} + H = config.hidden_size # 256 + I = config.intermediate_size # 512 + V = config.vocab_size # 1000 + num_heads = config.num_attention_heads # 4 + num_kv = config.num_key_value_heads # 2 + head_dim = config.head_dim # 64 + + sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16) # +1 convention: zeros + + for l in range(config.num_hidden_layers): + sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16) + sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros( + H, dtype=torch.bfloat16 + ) + + # Dense MLP (all layers) + sd[f"layers.{l}.mlp.gate_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.up_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.down_proj.weight"] = ( + torch.randn(H, I, dtype=torch.bfloat16) * 0.02 + ) + + if config.layer_types[l] == "full_attention": + # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...] + q_proj = ( + torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj + sd[f"layers.{l}.self_attn.k_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.v_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.o_proj.weight"] = ( + torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + else: + # DeltaNet layer: minimal required weights + key_dim = config.linear_num_key_heads * config.linear_key_head_dim # 128 + value_dim = ( + config.linear_num_value_heads * config.linear_value_head_dim + ) # 256 + conv_dim = key_dim * 2 + value_dim # 512 + sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = ( + torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_z.weight"] = ( + torch.randn(value_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_a.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_b.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.conv1d.weight"] = ( + torch.randn( + conv_dim, 1, config.linear_conv_kernel_dim, dtype=torch.bfloat16 + ) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.A_log"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.dt_bias"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.norm.weight"] = ( + torch.randn(value_dim, dtype=torch.bfloat16) * 0.5 + ) + sd[f"layers.{l}.linear_attn.out_proj.weight"] = ( + torch.randn(H, value_dim, dtype=torch.bfloat16) * 0.02 + ) + + return sd + + +class TestNormConversion(unittest.TestCase): + """Test (+1 convention) RMSNorm weight conversion.""" + + def test_norm_weight_adds_one(self): + """Weights initialized to zero should become 1.0 after conversion.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + # norm.weight was zeros -> should now be ones + torch.testing.assert_close( + result["norm.weight"], + torch.ones_like(result["norm.weight"]), + ) + + def test_input_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.input_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} input_layernorm not converted", + ) + + def test_post_attn_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.post_attention_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} post_attention_layernorm not converted", + ) + + def test_qk_norm_adds_one(self): + """Q/K norms on GQA layers should also get +1 applied.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + q_w = result[f"layers.{l}.self_attn.q_layernorm.weight"] + k_w = result[f"layers.{l}.self_attn.k_layernorm.weight"] + self.assertTrue( + torch.allclose(q_w, torch.ones_like(q_w)), + f"Layer {l} q_layernorm not converted", + ) + self.assertTrue( + torch.allclose(k_w, torch.ones_like(k_w)), + f"Layer {l} k_layernorm not converted", + ) + + +class TestQProjSplit(unittest.TestCase): + """Test q_proj interleaved split into query + gate.""" + + def test_q_proj_split_shapes(self): + """q_proj (num_heads * head_dim * 2, H) -> separate query and gate.""" + config = _make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + # After split: q_proj should be (num_heads * head_dim, H) = (256, 256) + q_w = result[f"layers.{l}.self_attn.q_proj.weight"] + gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"] + expected_shape = ( + config.num_attention_heads * config.head_dim, + config.hidden_size, + ) + self.assertEqual( + q_w.shape, expected_shape, f"Layer {l} q_proj shape wrong" + ) + self.assertEqual( + gate_w.shape, expected_shape, f"Layer {l} gate shape wrong" + ) + + def test_q_proj_deinterleave_correct(self): + """Verify the interleaved split correctly separates query and gate.""" + config = _make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + + # Create a known pattern: head0 query is 1s, head0 gate is 2s, etc. + l = 3 # First full_attention layer (layer 3) + num_heads = config.num_attention_heads + head_dim = config.head_dim + H = config.hidden_size + + interleaved = torch.zeros(num_heads * head_dim * 2, H, dtype=torch.bfloat16) + for h in range(num_heads): + interleaved[h * head_dim * 2 : h * head_dim * 2 + head_dim, :] = float( + h + 1 + ) # query + interleaved[h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, :] = ( + float(h + 100) + ) # gate + + sd[f"layers.{l}.self_attn.q_proj.weight"] = interleaved + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + q_w = result[f"layers.{l}.self_attn.q_proj.weight"] + gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"] + + for h in range(num_heads): + q_head = q_w[h * head_dim : (h + 1) * head_dim, :] + gate_head = gate_w[h * head_dim : (h + 1) * head_dim, :] + self.assertTrue( + torch.all(q_head == float(h + 1)), f"Head {h} query values wrong" + ) + self.assertTrue( + torch.all(gate_head == float(h + 100)), f"Head {h} gate values wrong" + ) + + def test_q_proj_scale_deinterleave_correct(self): + """FP8 q_proj scale should split the same way as q_proj weight.""" + config = _make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + + l = 3 + num_heads = config.num_attention_heads + head_dim = config.head_dim + + interleaved_scale = torch.zeros(num_heads * head_dim * 2, 1) + for h in range(num_heads): + interleaved_scale[ + h * head_dim * 2 : h * head_dim * 2 + head_dim, + :, + ] = float(h + 1) + interleaved_scale[ + h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, + :, + ] = float(h + 100) + + sd[f"layers.{l}.self_attn.q_proj.scale"] = interleaved_scale + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + q_scale = result[f"layers.{l}.self_attn.q_proj.scale"] + gate_scale = result[f"layers.{l}.self_attn.output_gate_proj.scale"] + + for h in range(num_heads): + q_head = q_scale[h * head_dim : (h + 1) * head_dim, :] + gate_head = gate_scale[h * head_dim : (h + 1) * head_dim, :] + self.assertTrue( + torch.all(q_head == float(h + 1)), f"Head {h} query scale wrong" + ) + self.assertTrue( + torch.all(gate_head == float(h + 100)), f"Head {h} gate scale wrong" + ) + + +class TestQKNormRename(unittest.TestCase): + """Test q_norm -> q_layernorm and k_norm -> k_layernorm renaming.""" + + def test_old_keys_removed(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_norm.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_norm.weight", result) + + def test_new_keys_present(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertIn(f"layers.{l}.self_attn.q_layernorm.weight", result) + self.assertIn(f"layers.{l}.self_attn.k_layernorm.weight", result) + + +class TestFusedQKV(unittest.TestCase): + """Test fused QKV concatenation for attention layers.""" + + def test_fused_qkv_shape(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + fused_key = f"layers.{l}.self_attn.Wqkv.weight" + self.assertIn(fused_key, result, f"Layer {l} missing Wqkv") + + q_dim = config.num_attention_heads * config.head_dim + k_dim = config.num_key_value_heads * config.head_dim + v_dim = config.num_key_value_heads * config.head_dim + expected_rows = q_dim + k_dim + v_dim + self.assertEqual(result[fused_key].shape[0], expected_rows) + + def test_fused_qkv_removes_individual_keys(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.v_proj.weight", result) + + def test_fused_qkv_scale_created_and_individual_scales_removed(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + l = 3 + q_dim = config.num_attention_heads * config.head_dim + kv_dim = config.num_key_value_heads * config.head_dim + + sd[f"layers.{l}.self_attn.q_proj.scale"] = torch.arange( + q_dim * 2, + dtype=torch.float32, + ).reshape(q_dim * 2, 1) + sd[f"layers.{l}.self_attn.k_proj.scale"] = torch.full((kv_dim, 1), 7.0) + sd[f"layers.{l}.self_attn.v_proj.scale"] = torch.full((kv_dim, 1), 9.0) + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + fused_scale_key = f"layers.{l}.self_attn.Wqkv.scale" + gate_scale_key = f"layers.{l}.self_attn.output_gate_proj.scale" + self.assertIn(fused_scale_key, result) + self.assertIn(gate_scale_key, result) + self.assertNotIn(f"layers.{l}.self_attn.q_proj.scale", result) + self.assertNotIn(f"layers.{l}.self_attn.k_proj.scale", result) + self.assertNotIn(f"layers.{l}.self_attn.v_proj.scale", result) + self.assertEqual(result[fused_scale_key].shape, (q_dim + 2 * kv_dim, 1)) + torch.testing.assert_close( + result[fused_scale_key][q_dim : q_dim + kv_dim], + torch.full((kv_dim, 1), 7.0), + ) + torch.testing.assert_close( + result[fused_scale_key][q_dim + kv_dim :], + torch.full((kv_dim, 1), 9.0), + ) + + +class TestDeltaNetPassthrough(unittest.TestCase): + """Test that DeltaNet layer weights pass through conversion unchanged.""" + + def test_deltanet_weights_unchanged(self): + config = _make_mini_config(tp_degree=1) + sd = _make_mini_state_dict(config) + + # Record original DeltaNet weights + originals = {} + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + originals[key] = sd[key].clone() + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for key, orig in originals.items(): + self.assertIn(key, result, f"Missing: {key}") + torch.testing.assert_close( + result[key], orig, msg=f"DeltaNet weight changed: {key}" + ) + + def test_deltanet_norm_not_converted(self): + """DeltaNet layers use standard RMSNorm (NOT +1 convention). + The norm weight should NOT be changed.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Set DeltaNet norm to a known non-zero value + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + sd[f"layers.{l}.linear_attn.norm.weight"] = torch.full( + (config.linear_num_value_heads * config.linear_value_head_dim,), + 0.87, + dtype=torch.bfloat16, + ) + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + w = result[f"layers.{l}.linear_attn.norm.weight"] + # Should still be ~0.87, NOT 1.87 + self.assertTrue( + torch.allclose(w, torch.full_like(w, 0.87), atol=0.01), + f"Layer {l} DeltaNet norm was incorrectly modified", + ) + + def test_deltanet_qkv_scale_reordered_for_tp(self): + config = _make_mini_config(tp_degree=2) + sd = _make_mini_state_dict(config) + l = 0 + key_dim = config.linear_num_key_heads * config.linear_key_head_dim + value_dim = config.linear_num_value_heads * config.linear_value_head_dim + conv_dim = key_dim * 2 + value_dim + scale = torch.arange(conv_dim, dtype=torch.float32).reshape(conv_dim, 1) + sd[f"layers.{l}.linear_attn.in_proj_qkv.scale"] = scale + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + local_key_dim = key_dim // config.neuron_config.tp_degree + local_value_dim = value_dim // config.neuron_config.tp_degree + expected = torch.cat( + [ + scale[0:local_key_dim], + scale[key_dim : key_dim + local_key_dim], + scale[2 * key_dim : 2 * key_dim + local_value_dim], + scale[local_key_dim:key_dim], + scale[key_dim + local_key_dim : 2 * key_dim], + scale[2 * key_dim + local_value_dim : 2 * key_dim + value_dim], + ], + dim=0, + ) + torch.testing.assert_close( + result[f"layers.{l}.linear_attn.in_proj_qkv.scale"], + expected, + ) + + +class TestRankUtil(unittest.TestCase): + """Test rank_util tensor injection.""" + + def test_rank_util_present(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + self.assertIn("rank_util.rank", result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result["rank_util.rank"], expected) + + def test_gqa_layer_rank_util(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + key = f"layers.{l}.self_attn.rank_util.rank" + self.assertIn(key, result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result[key], expected) + + +class TestVLPrefixStripping(unittest.TestCase): + """Test VL wrapper prefix stripping in convert_hf_to_neuron_state_dict.""" + + def test_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Wrap with VL prefix + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"language_model.{k}"] = v + vl_sd["visual.encoder.weight"] = torch.zeros(10) # should be skipped + vl_sd["mtp.something"] = torch.zeros(5) # should be skipped + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertNotIn("visual.encoder.weight", result) + self.assertNotIn("mtp.something", result) + self.assertIn("norm.weight", result) + + def test_model_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"model.language_model.{k}"] = v + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertIn("norm.weight", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/vllm/README.md b/contrib/models/Qwen3.6-27B/vllm/README.md new file mode 100644 index 00000000..4c921efa --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/README.md @@ -0,0 +1,483 @@ +# Qwen3.6-27B vLLM on Neuron + +This folder contains the first-pass vLLM integration helpers for the +Qwen3.6-27B contrib model. + +The current goal is **vLLM serving through the Neuron/NxDI plugin** for the +validated Qwen3.6 artifact, including long prompts through vLLM's native +chunked-prefill scheduler. + +## Which vLLM Neuron Package? + +Use the vLLM-on-Neuron environment that matches the installed Neuron SDK first. +For SDK 2.29, the AWS Neuron guide lists the NxDI/vLLM plugin stack as +`vLLM 0.16.0` with plugin version `0.5.0`. The +`vllm-project/vllm-neuron` repository is useful source/reference material, but +its README currently describes a beta plugin path tied to older `vLLM 0.11.0` +and SDK 2.26.1. Do not downgrade the working SDK 2.29 environment just to use +that repository. + +On a DLAMI, prefer the preinstalled vLLM/Neuron environment when available. If +the instance does not have one, install the Neuron-compatible vLLM plugin/fork +using the current AWS guide, then run the contrib registry patch below. + +## What Works First + +- Register the contrib `qwen3_5` text model with the NxDI model registry inside + the vLLM environment. +- Start vLLM with `VLLM_PLUGINS=neuron`. +- Load a small-context model or a precompiled artifact with + `NEURON_COMPILED_ARTIFACTS`. +- Run a short OpenAI-compatible smoke prompt. + +## Hybrid APC Production Boundary + +Qwen3.6-27B is a hybrid model, so attention prefix caching alone is not a +complete production APC contract. The current stack has strong serving +primitives for attention-only models: block KV, block tables, slot mapping, +prefix caching, continuous batching, chunked prefill, and decode. For hybrid +attention plus GDN recurrence, the attention KV path is production-shaped, but +the GDN state path is still model-specific glue. + +Current readiness: + +| Layer | Current support | Production readiness | +| --- | ---: | ---: | +| Attention KV cache | Good | High on the existing NxDI/vLLM block-KV path | +| vLLM APC for attention blocks | Working baseline | Medium/high | +| GDN recurrent state cache | Implemented locally | Low/medium | +| GDN conv state cache | Implemented locally | Low/medium | +| Hybrid APC across attention + GDN | Not fully implemented | Low | +| Continuous batching with exact hybrid prefix reuse | Not supported by the local manager | Low | +| Speculation, FP8 cache, tiling, flash decode with hybrid state | Explicitly rejected by the local manager | Low | + +The `HybridDeltaNetCacheManager` is therefore a contrib-local static/stateful +cache manager, not a production hybrid APC manager. It proves the model can +preserve recurrent and conv state, but it is batch-row based rather than +vLLM block-hash, refcount, eviction, and tenant-isolation based. + +Production hybrid APC must define the usable prefix as the intersection of: + +1. attention KV block hit; +2. GDN recurrent prefix-boundary checkpoint hit; +3. GDN conv prefix-boundary checkpoint hit. + +For each GDN layer, the reusable checkpoint object needs: + +```text +recurrent_state: [local_value_heads, key_dim, value_dim] +conv_state: [conv_dim, conv_kernel_size - 1] +``` + +The recurrent state should stay FP32 for exact cold-vs-warm agreement until +BF16 equivalence is proven. Conv state can follow the model-compatible dtype, +but exactness still needs token-level validation. If the attention APC hit lands +inside a GDN checkpoint interval, restore the nearest earlier full GDN +checkpoint, replay the residual tokens, then run the suffix. + +The launchers expose `--enable-hybrid-apc` and explicit hybrid cache dtype +knobs. In the current v0 implementation, `use_hybrid_apc_manager=True` creates +a bounded GDN checkpoint-slot bank and adds restore/commit tensors to the model +signature. The serving request-prep path must still fill those tensors from the +vLLM/NxDI cumulative-prefix hash lifecycle; otherwise the default zero masks run +as attention KV plus normal active-row GDN state with no GDN checkpoint reuse. +For v0, `gdn_checkpoint_interval` must equal the vLLM block size. + +The production server launcher enables strict hybrid APC metadata by default. +That means request prep must provide vLLM/NxDI cumulative prefix hashes and real +attention block refs; local token-hash fallback is reserved for controlled +validation via `--allow-hybrid-apc-local-hash-fallback`. The live scheduler +integration should pass the full prompt before suffix slicing using +`hybrid_full_input_ids`/`full_input_ids`, attach `vllm_attention_hit_len`, pass +`cumulative_hashes_by_prefix_len`, and pass actual attention block refs at +commit time through `actual_attention_block_refs` or +`hybrid_actual_attention_block_refs`. Attention KV eviction should call the +model/store `on_attention_block_evicted` callback so GDN checkpoints do not +outlive the KV blocks they depend on. + +## Chunked Prefill Note + +The Neuron plugin disables vLLM chunked prefill by default and installs a custom +continuous-batching scheduler. For this Qwen3.6 artifact we need vLLM's native +chunked-prefill scheduler so prompts longer than the 512-token context graph are +fed to the precompiled model in 512-token chunks. The launcher sets +`DISABLE_NEURON_CUSTOM_SCHEDULER=1` when `--enable-vllm-chunked-prefill` is +passed. It also launches with `--generation-config vllm` so model +`generation_config.json` does not silently override deterministic sampling +defaults. + +## Install The Contrib Registry Patch + +Activate the vLLM/Neuron environment on the instance, then run: + +```bash +cd /home/ubuntu/inferentia-gdn +contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh +``` + +If your vLLM environment is not in a standard location: + +```bash +contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh \ + /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference +``` + +The installer only patches the active environment. It does not modify core repo +files. + +## Start vLLM + +Small-context compile/load path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --max-model-len 512 \ + --port 8000 +``` + +Precompiled artifact path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-buckets 128,256,512 \ + --port 8000 +``` + +Cold-prefill bucket waste is the first performance target. CTE buckets must stay +128-aligned because the fused DeltaNet CTE path operates in 128-token chunks. +Use one of the explicit profiles when compiling artifacts: + +```bash +# Short-prompt latency +--cte-bucket-profile short # [128,256,512,1024] + +# General production +--cte-bucket-profile general # [256,512,1024,2048] + +# Long-context artifact +--cte-bucket-profile long # [4096,8192,16384,32768] + +# 262K load experiment +--cte-bucket-profile 262k # [256] +``` + +`--cold-zero-conv-fast-path` is only for a cold-only CTE artifact whose suffix +prefill always starts at position 0. Leave it disabled for APC or partial-prefix +serving because restored GDN conv state must be consumed exactly. + +Long-prompt precompiled artifact path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-buckets 256,512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --port 8000 +``` + +Native vLLM prefix-cache experiment: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-buckets 256,512 \ + --block-size 128 \ + --enable-vllm-chunked-prefill \ + --enable-prefix-caching \ + --gdn-checkpoint-interval 256 \ + --hybrid-gdn-recurrent-cache-dtype float32 \ + --hybrid-gdn-conv-cache-dtype bfloat16 \ + --mamba-cache-mode all \ + --mamba-ssm-cache-dtype float32 \ + --port 8000 +``` + +Treat this as an experiment, not a production mode, until validation passes. +Standard vLLM APC reuses attention KV blocks; Qwen3.6 also needs DeltaNet +recurrent state and conv state as prefix-boundary checkpoints keyed by the +cumulative prefix hash. If native APC does not produce exact greedy matches and +a clear warm-hit speedup, the next step is a hybrid APC path that restores those +GDN checkpoints alongside attention KV. + +For APC experiments, do not treat `256` as the only block size. It can be useful +for long-context amortization, but it is coarse for chat-style prefix reuse. +Run explicit sweeps at `64` and `128`; include `32` when hit granularity matters +enough to justify possible block-table/layout overhead. Keep the GDN checkpoint +interval separate from the attention block size. + +Immediate Trainium experiments: + +```text +262K TP=4, block_size=256, CTE buckets [256] +262K TP=4, block_size=128, CTE buckets [256] +128K TP=4, block_size=128, CTE buckets [256,512] +128K TP=4, block_size=256, CTE buckets [256,512] +``` + +Production chat proxy: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --port 8001 +``` + +Then expose the guarded OpenAI-compatible endpoint on port 8000: + +```bash +python contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py \ + --backend-url http://127.0.0.1:8001 \ + --port 8000 +``` + +The proxy forces `chat_template_kwargs={"enable_thinking": false}` for +`/v1/chat/completions` by default. It rejects raw `/v1/completions` because raw +prompts bypass the Qwen chat template and can pollute the hybrid model state. +It also hoists `system` and `developer` messages to a single leading `system` +message because the Qwen chat template rejects system messages that appear later +in the conversation. Start the proxy with `--allow-thinking` to allow a +request-level toggle while keeping the default non-thinking path. Supported +toggles include `enable_thinking=true`, `thinking=true`, +`thinking={"enabled": true}`, `reasoning_effort=low|medium|high`, and native +`chat_template_kwargs={"enable_thinking": true}`. Use `--allow-completions` only +for explicit debugging. + +Offline long-prompt smoke: + +```bash +python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --chat \ + --prompt "$(python - <<'PY' +print('Summarize this document in one paragraph. ' + 'Neuron inference ' * 700) +PY +)" +``` + +Offline token-exact prefix-cache validation: + +```bash +python validation_scripts/qwen36_vllm_prefix_cache_offline.py \ + --repo-root /home/ubuntu/inferentia-gdn \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 128 \ + --enable-vllm-chunked-prefill \ + --mamba-cache-mode all +``` + +Offline partial-prefix validation: + +```bash +python validation_scripts/qwen36_vllm_prefix_cache_partial_offline.py \ + --repo-root /home/ubuntu/inferentia-gdn \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 128 \ + --enable-vllm-chunked-prefill \ + --mamba-cache-mode all +``` + +Server-side prefix-cache validation through the guarded proxy: + +```bash +python validation_scripts/qwen36_prefix_cache_validation.py \ + --base-url http://127.0.0.1:8000 \ + --model qwen3.6-27b-neuron-128k-fp8-mlp +``` + +The acceptance gate is strict: repeated greedy calls must produce identical +output, and warm-hit latency should be materially lower than cold-fill latency. +For hybrid Qwen3.6, prefix-cache validation is not complete until the GDN +recurrent/conv state behavior is proven, not just attention KV cache hits. + +Hybrid APC exactness and HBM harness: + +```bash +python validation_scripts/qwen36_hybrid_apc_validation.py exactness \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_hybrid_apc \ + --seq-len 2048 \ + --cte-buckets 256,512 \ + --block-size 256 \ + --gdn-checkpoint-interval 256 \ + --enable-vllm-chunked-prefill + +python validation_scripts/qwen36_hybrid_apc_validation.py hbm \ + --context-lens 131072 262144 \ + --checkpoint-intervals 128 256 512 +``` + +Native APC validation run on Trn2 with the FP8 128K artifact: + +- server exact-repeat, `~10.8K` prompt tokens: `26.68s` cold to `1.67s` warm, + `16.0x` speedup, exact greedy text match; +- offline exact-repeat, token IDs exposed: `26.19s` cold to `2.38s` warm, + `11.0x` speedup, exact greedy token-ID match; +- offline partial-prefix reuse, token IDs exposed: `25.52s` no-cache target to + `1.70s` APC target after a different shared-prefix warmup request, `15.0x` + speedup, exact greedy token-ID match. +- server hardening, exact repeat: `25.38s` cold to `1.55s` warm, `16.35x` + speedup, exact text match; +- server hardening, cross-prefix reuse after unrelated prefix: `25.17s` cold to + `1.36s` warm, exact text match; +- shared-prefix concurrency at 1/2/4 requests returned all requested markers + exactly; the artifact still queues because it is compiled for `max_num_seqs=1`. + +Validation run on Trn2 with the FP8 128K artifact: + +- state-reset artifact: `/opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1`; +- OpenAI-compatible `/v1/chat/completions` behind the proxy passes focused + quality checks without callers passing `chat_template_kwargs`; +- repeated short-after-long validation passes after 32K and 64K requests, + confirming DeltaNet recurrent/conv state is reset for new requests; +- 32K and 64K needle retrieval prompts return all expected codes; +- measured prefill is `404-428 tok/s` from 512 through 64K prompt tokens; +- measured decode is `26.3-26.6 tok/s`; +- peak Neuron device memory is about `53.25 GB` decimal for the 64K eval. + +Raw `/v1/completions` prompts are not chat-templated and can pollute the hybrid +state if sent directly to the backend. Keep the backend private and expose the +proxy on the public port for production calls. + +4K BF16 Hybrid APC boundary/server probes: + +```bash +# Artifact/config audit before spending a Trn2 run. This flags oversized PA +# blocks, low block headroom, strict-gate boundary pressure, and nki_chunked CTE. +python validation_scripts/qwen36_artifact_config_audit.py \ + /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_4096_bf16_hybrid_apc_nki_chunked_prefix4096_ctx2_tkg2_r7i_20260520T082342Z \ + --compile-log /home/ubuntu/validation_logs/hybrid_apc_real_tokens/qwen36_4k_bf16_hybrid_apc_nki_chunked_prefix4096_20260520T082342Z_compile.log + +# Boundary-aligned APC proof. Run this directly against vLLM or a proxy started +# with --allow-completions because exact token-ID prompt lengths are required. +python validation_scripts/qwen36_openai_boundary_apc_probe.py \ + --base-url http://127.0.0.1:8000 \ + --model-path /home/ubuntu/models/Qwen3.6-27B \ + --lengths 256,512,1024,2048,4096 \ + --repeats 3 \ + --require-prefix-cache-query \ + --output-jsonl /home/ubuntu/validation_logs/hybrid_apc_real_tokens/boundary_apc_probe.jsonl + +# Cold prefill ctx-batch utilization check. Compare --concurrency 1 and 2 with +# --unique-per-request to avoid warm-cache reuse. +python validation_scripts/qwen36_chat_completion_context_bench.py \ + --base-url http://127.0.0.1:8000 \ + --model /home/ubuntu/models/Qwen3.6-27B \ + --model-path /home/ubuntu/models/Qwen3.6-27B \ + --lengths 4096 \ + --turns 8 \ + --repeats 3 \ + --concurrency 2 \ + --unique-per-request \ + --no-stream \ + --output-json /home/ubuntu/validation_logs/hybrid_apc_real_tokens/chat_4k_concurrency2.json +``` + +4K BF16 compile controls for the current investigation: + +```bash +# Single-request cold-prefill latency control: smaller PA blocks, usable block +# headroom, and fused DeltaNet CTE. Use a fresh compiled path and workdir. +python contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py \ + --repo-root /home/ubuntu/inferentia-gdn-experimental \ + --model-path /home/ubuntu/models/Qwen3.6-27B \ + --compiled-path /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_4096_bf16_hybrid_apc_fused_block32_ctx1 \ + --base-compile-work-dir /mnt/trainium_artifacts/qwen_artifacts/_work_qwen36_4k_fused_block32_ctx1 \ + --weight-dtype bf16_control \ + --seq-len 4096 \ + --max-context-length 4096 \ + --cte-buckets 256,512,1024,2048,4096 \ + --prefix-buckets 4096 \ + --block-size 32 \ + --pa-headroom-blocks 64 \ + --tp-degree 4 \ + --logical-nc-config 2 \ + --max-num-seqs 1 \ + --ctx-batch-size 1 \ + --skip-warmup \ + --enable-prefix-caching \ + --enable-hybrid-apc \ + --enable-vllm-chunked-prefill \ + --deltanet-cte-backend fused \ + --gdn-checkpoint-interval 32 \ + --max-gdn-checkpoint-slots 160 \ + --hybrid-apc-require-vllm-metadata \ + --hybrid-apc-enable-backed-prefix-reads +``` + +The `block_size=32` control follows Neuron's prefix-cache performance guidance, +but it also increases the number of prefix boundaries the strict Hybrid APC gate +must prove. Without boundary chunk commits, a full 4096-token prompt has 128 +possible attention-hit boundaries at block size 32, so `max_gdn_checkpoint_slots` +must be sized accordingly or the safe gate will keep skipping APC reads. + +## Offline Smoke + +```bash +python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-buckets 128,256,512 \ + --chat \ + --prompt "What is 17 * 23? Answer with the number only." +``` + +## Next Milestone + +For cold-prefill latency, fix bucket waste before speculative decode or cache +quantization. The serving entrypoints now support multi-bucket CTE artifacts, +text-only CTE inputs, compact CTE masks, context-batch profiles, and attention +tile overrides. + +For warm-prefix production APC, the required contract remains a unified +prefix-cache object whose attention KV, GDN recurrent state, and GDN conv state +are jointly addressable, evictable, restorable, and exact under continuous +batching. + +Recommended order: + +1. Dynamic CTE buckets: start with `[128,256,512]` for 2K short-prompt tests, + `[256,512]` for 128K, and `[256]` for the 262K TP=4 load experiment. +2. Fused GDN CTE path validation: qwen chunked-prefill should use fused + DeltaNet with restored initial state by default. +3. Text-only CTE and compact-mask validation: no full dummy vision reductions + and no dense 4D causal masks in normal text serving. +4. Hybrid APC exactness: cold vs warm greedy token IDs, partial-prefix reuse, + multi-hit chat history, continuous batching movement, and eviction pressure. +5. Attention block-size sweeps at `64` and `128`, with `32` included for + granularity-sensitive chat workloads. +6. FP8 KV/cache only after the BF16/FP32 baseline is exact. +7. MTP/spec decode after recurrent-state rollback semantics are explicit. diff --git a/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py b/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py new file mode 100644 index 00000000..f764048a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py @@ -0,0 +1,68 @@ +"""Minimal Hugging Face config registration for Qwen3.5/Qwen3.6 vLLM smoke. + +The Neuron vLLM environment can lag upstream Transformers. vLLM validates the +HF config before the NxDI model registry gets a chance to instantiate the +contrib model, so register a permissive config class for the new model_type. +""" + +from __future__ import annotations + +from transformers import AutoConfig, PretrainedConfig + + +class Qwen35TextConfig(PretrainedConfig): + model_type = "qwen3_5_text" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class Qwen35Config(PretrainedConfig): + model_type = "qwen3_5" + sub_configs = {"text_config": Qwen35TextConfig} + + def __init__(self, text_config=None, **kwargs): + if isinstance(text_config, dict): + text_config = Qwen35TextConfig(**text_config) + self.text_config = text_config + if text_config is not None: + for name, value in text_config.to_dict().items(): + if name not in {"architectures", "model_type"}: + kwargs.setdefault(name, value) + rope_parameters = getattr(text_config, "rope_parameters", None) + if isinstance(rope_parameters, dict): + kwargs.setdefault("rope_theta", rope_parameters.get("rope_theta")) + super().__init__(**kwargs) + + +def _is_registered(model_type: str) -> bool: + try: + AutoConfig.for_model(model_type) + except ValueError: + return False + return True + + +def register_qwen35_hf_config() -> None: + if not _is_registered(Qwen35TextConfig.model_type): + AutoConfig.register(Qwen35TextConfig.model_type, Qwen35TextConfig) + if not _is_registered(Qwen35Config.model_type): + AutoConfig.register(Qwen35Config.model_type, Qwen35Config) + + +def register_qwen35_vllm_architecture() -> None: + try: + from vllm.model_executor.models import ModelRegistry + except Exception: + return + + supported_archs = ModelRegistry.get_supported_archs() + qwen3_impl = "vllm.model_executor.models.qwen3:Qwen3ForCausalLM" + for arch in ("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM"): + if arch not in supported_archs: + ModelRegistry.register_model(arch, qwen3_impl) + + +def register_qwen35_config() -> None: + register_qwen35_hf_config() + register_qwen35_vllm_architecture() diff --git a/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh b/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh new file mode 100755 index 00000000..f21536eb --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CONTRIB_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +if [[ $# -gt 0 ]]; then + VENV="$1" +else + VENV="" + for candidate in \ + /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_12 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_11 + do + if [[ -x "${candidate}/bin/python" ]]; then + VENV="${candidate}" + break + fi + done +fi + +if [[ -z "${VENV}" || ! -x "${VENV}/bin/python" ]]; then + echo "ERROR: Could not find a vLLM/Neuron Python environment." >&2 + echo "Usage: $0 /path/to/venv" >&2 + exit 1 +fi + +PYTHON="${VENV}/bin/python" +export PATH="${VENV}/bin:${PATH}" +export PYTHONPATH="${CONTRIB_ROOT}:${PYTHONPATH:-}" + +echo "vLLM/Neuron env : ${VENV}" +echo "Contrib root : ${CONTRIB_ROOT}" + +"${PYTHON}" "${SCRIPT_DIR}/patch_nxdi_registry.py" --contrib-root "${CONTRIB_ROOT}" + +"${PYTHON}" - <<'PY' +import importlib.util +from neuronx_distributed_inference.utils.constants import MODEL_TYPES + +if importlib.util.find_spec("vllm") is None: + raise RuntimeError("vLLM is not installed in this environment") + +if importlib.util.find_spec("vllm_neuron") is None: + print( + "WARNING: vllm_neuron package was not found. If this environment uses " + "an AWS vLLM fork with built-in Neuron support this may be fine; " + "otherwise install the Neuron vLLM plugin that matches this SDK.", + ) + +for key in ("qwen3_5", "qwen3_5_text"): + assert key in MODEL_TYPES, f"{key} missing from MODEL_TYPES" + assert "causal-lm" in MODEL_TYPES[key], f"{key}/causal-lm missing" +print("Qwen3.6 vLLM registry verification OK") +PY + +echo "Installation complete." +echo "Remember to set PYTHONPATH=${CONTRIB_ROOT} when starting vLLM." diff --git a/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py b/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py new file mode 100644 index 00000000..91fe41c5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""Register Qwen3.6 contrib model in the installed NxDI registry. + +This patches the active Python environment, not the repository checkout. The +runtime still needs PYTHONPATH to include contrib/models/Qwen3.6-27B so that +`src.modeling_qwen35` can be imported by the vLLM process. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + + +MARKER_BEGIN = "# QWEN36_CONTRIB_VLLM_REGISTER_BEGIN" +MARKER_END = "# QWEN36_CONTRIB_VLLM_REGISTER_END" + +REGISTRATION_BLOCK = f""" + +{MARKER_BEGIN} +# Registered by contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh. +# Requires PYTHONPATH to include the Qwen3.6-27B contrib directory at runtime. +try: + from src.modeling_qwen35 import ( + NeuronQwen35ForCausalLM as _Qwen36ContribForCausalLM, + ) +except Exception: + _Qwen36ContribForCausalLM = None + +if _Qwen36ContribForCausalLM is not None: + MODEL_TYPES.setdefault("qwen3_5", {{}})["causal-lm"] = _Qwen36ContribForCausalLM + MODEL_TYPES.setdefault("qwen3_5_text", {{}})["causal-lm"] = _Qwen36ContribForCausalLM +{MARKER_END} +""" + + +def _constants_path() -> Path: + import neuronx_distributed_inference.utils.constants as constants # noqa: WPS433 + + return Path(constants.__file__).resolve() + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--contrib-root", required=True) + parser.add_argument("--dry-run", action="store_true") + args = parser.parse_args() + + contrib_root = Path(args.contrib_root).expanduser().resolve() + if not (contrib_root / "src" / "modeling_qwen35.py").exists(): + raise FileNotFoundError(f"Qwen3.6 contrib root looks invalid: {contrib_root}") + + path = _constants_path() + text = path.read_text() + if MARKER_BEGIN in text: + print(f"Registry already patched: {path}") + return 0 + + patched = text.rstrip() + REGISTRATION_BLOCK + "\n" + print(f"Patch target: {path}") + if args.dry_run: + print("Dry run; no files written") + return 0 + + path.write_text(patched) + print("Patched NxDI MODEL_TYPES with qwen3_5 and qwen3_5_text") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/patches/vllm_neuron_qwen36_kv_cache_spec.patch b/contrib/models/Qwen3.6-27B/vllm/patches/vllm_neuron_qwen36_kv_cache_spec.patch new file mode 100644 index 00000000..4ad9321f --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/patches/vllm_neuron_qwen36_kv_cache_spec.patch @@ -0,0 +1,47 @@ +--- a/vllm_neuron/worker/neuronx_distributed_model_runner.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_runner.py +@@ -898,17 +898,39 @@ + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ +- # Get number of layers from model config +- num_layers = get_num_layers_from_hf_config(self.model_config.hf_config) ++ hf_config = self.model_config.hf_config ++ num_layers = get_num_layers_from_hf_config(hf_config) ++ ++ layer_types = getattr(hf_config, "layer_types", None) ++ if layer_types is not None and len(layer_types) == num_layers: ++ attention_layer_indices = [ ++ idx for idx, layer_type in enumerate(layer_types) ++ if layer_type in ("full_attention", "attention", "self_attention") ++ ] ++ else: ++ full_attention_interval = getattr(hf_config, "full_attention_interval", None) ++ if full_attention_interval: ++ interval = int(full_attention_interval) ++ attention_layer_indices = [ ++ idx for idx in range(num_layers) if (idx + 1) % interval == 0 ++ ] ++ else: ++ attention_layer_indices = list(range(num_layers)) ++ ++ total_kv_heads = getattr( ++ hf_config, ++ "num_key_value_heads", ++ getattr(hf_config, "num_attention_heads", self.parallel_config.tensor_parallel_size), ++ ) ++ local_kv_heads = max(1, int(total_kv_heads) // self.parallel_config.tensor_parallel_size) + + kv_cache_spec: dict[str, KVCacheSpec] = {} + +- # Create a spec for each layer +- for layer_idx in range(num_layers): ++ for layer_idx in attention_layer_indices: + layer_name = f"layers.{layer_idx}.self_attn" # standard naming convention + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=self.block_size, +- num_kv_heads=self.parallel_config.tensor_parallel_size, ++ num_kv_heads=local_kv_heads, + head_size=self.model.head_dim, + dtype=self.model_config.dtype, + sliding_window=self.model_config.get_sliding_window(), diff --git a/contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py b/contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py new file mode 100644 index 00000000..fae83a67 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +"""Small OpenAI-compatible guard proxy for Qwen3.6 vLLM serving. + +The upstream Qwen3.6 chat template defaults to thinking mode. For this Neuron +artifact the production-safe chat path is non-thinking mode, so this proxy +injects ``chat_template_kwargs={"enable_thinking": false}`` for chat requests +unless thinking is explicitly enabled by request and the proxy was started with +``--allow-thinking``. It also blocks raw completions by default because they are +not chat-templated. +""" + +from __future__ import annotations + +import argparse +import json +import os +import urllib.error +import urllib.request +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any +from urllib.parse import urlsplit + +_HOP_BY_HOP_HEADERS = { + "connection", + "content-length", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", +} + +_TRUE_STRINGS = {"1", "true", "yes", "y", "on", "enable", "enabled", "thinking"} +_FALSE_STRINGS = { + "0", + "false", + "no", + "n", + "off", + "disable", + "disabled", + "none", + "non_thinking", + "no_thinking", +} +_THINKING_PROXY_FIELDS = { + "enable_thinking", + "thinking", + "thinking_enabled", + "reasoning", + "reasoning_effort", +} + + +def _json_response(handler: BaseHTTPRequestHandler, status: int, payload: dict[str, Any]): + body = json.dumps(payload).encode("utf-8") + handler.send_response(status) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.end_headers() + handler.wfile.write(body) + + +def _replace_sse_data_payload(event: bytes, payload: str) -> bytes: + lines = event.decode("utf-8").splitlines() + replaced = False + for index, line in enumerate(lines): + if line.startswith("data:") and not replaced: + lines[index] = "data: " + payload + replaced = True + break + return ("\n".join(lines) + "\n\n").encode("utf-8") + + +def _prepend_think_start_to_sse_event(event: bytes) -> tuple[bytes, bool, bool]: + """Return event, whether a decision was made, and whether it was changed.""" + try: + text = event.decode("utf-8") + except UnicodeDecodeError: + return event, False, False + + for line in text.splitlines(): + if not line.startswith("data:"): + continue + + payload = line[len("data:") :].strip() + if not payload or payload == "[DONE]": + return event, False, False + + try: + obj = json.loads(payload) + except json.JSONDecodeError: + return event, False, False + + choices = obj.get("choices") or [] + for choice in choices: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") or {} + if not isinstance(delta, dict): + continue + content = delta.get("content") + if not isinstance(content, str) or not content: + continue + if content.lstrip().startswith(""): + return event, True, False + + delta["content"] = "\n" + content + return ( + _replace_sse_data_payload( + event, + json.dumps(obj, ensure_ascii=False, separators=(",", ":")), + ), + True, + True, + ) + + return event, False, False + + +def _message_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + elif isinstance(item, str): + parts.append(item) + return "\n".join(parts) + return str(content) + + +def _normalize_messages_for_qwen(messages: Any) -> Any: + """Make common OpenAI message layouts acceptable to the Qwen chat template.""" + if not isinstance(messages, list): + return messages + + system_parts: list[str] = [] + normal_messages: list[Any] = [] + for message in messages: + if not isinstance(message, dict): + normal_messages.append(message) + continue + + role = message.get("role") + if role in {"system", "developer"}: + system_parts.append(_message_text(message.get("content", ""))) + else: + normal_messages.append(message) + + if not system_parts: + return messages + + system_message = { + "role": "system", + "content": "\n\n".join(part for part in system_parts if part), + } + return [system_message, *normal_messages] + + +def _request_path(path: str) -> str: + normalized = urlsplit(path).path.rstrip("/") + return normalized or "/" + + +def _coerce_optional_bool(value: Any) -> bool | None: + if isinstance(value, bool): + return value + if isinstance(value, int) and value in (0, 1): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower().replace("-", "_").replace(" ", "_") + if normalized in _TRUE_STRINGS: + return True + if normalized in _FALSE_STRINGS: + return False + return None + + +def _coerce_thinking_object(value: Any) -> bool | None: + coerced = _coerce_optional_bool(value) + if coerced is not None: + return coerced + if not isinstance(value, dict): + return None + + for key in ("enable_thinking", "enabled", "enable", "value"): + if key in value: + coerced = _coerce_optional_bool(value.get(key)) + if coerced is not None: + return coerced + + budget = value.get("budget_tokens") + if isinstance(budget, int): + return budget > 0 + + effort = value.get("effort") or value.get("reasoning_effort") + coerced = _coerce_optional_bool(effort) + if coerced is not None: + return coerced + if isinstance(effort, str) and effort.strip(): + return True + return None + + +def _requested_thinking_enabled(payload: dict[str, Any]) -> bool | None: + for key in ("enable_thinking", "thinking_enabled", "thinking"): + if key in payload: + coerced = _coerce_thinking_object(payload.get(key)) + if coerced is not None: + return coerced + + template_kwargs = payload.get("chat_template_kwargs") + if isinstance(template_kwargs, dict) and "enable_thinking" in template_kwargs: + coerced = _coerce_optional_bool(template_kwargs.get("enable_thinking")) + if coerced is not None: + return coerced + + if "reasoning" in payload: + coerced = _coerce_thinking_object(payload.get("reasoning")) + if coerced is not None: + return coerced + + if "reasoning_effort" in payload: + effort = payload.get("reasoning_effort") + coerced = _coerce_optional_bool(effort) + if coerced is not None: + return coerced + if isinstance(effort, str) and effort.strip(): + return True + + return None + + +def _apply_thinking_policy( + payload: dict[str, Any], + *, + allow_thinking: bool, + default_thinking: bool = False, +) -> bool: + template_kwargs = payload.get("chat_template_kwargs") + if not isinstance(template_kwargs, dict): + template_kwargs = {} + else: + template_kwargs = dict(template_kwargs) + + requested = _requested_thinking_enabled(payload) + if allow_thinking: + enable_thinking = bool(requested) if requested is not None else default_thinking + else: + enable_thinking = False + template_kwargs["enable_thinking"] = enable_thinking + payload["chat_template_kwargs"] = template_kwargs + + # Keep compatibility toggles out of the backend OpenAI schema. vLLM only + # needs chat_template_kwargs for Qwen's tokenizer chat template. + for key in _THINKING_PROXY_FIELDS: + payload.pop(key, None) + + return enable_thinking + + +class Qwen36ProxyHandler(BaseHTTPRequestHandler): + backend_url: str = "http://127.0.0.1:8001" + force_disable_thinking: bool = True + default_thinking: bool = False + allow_completions: bool = False + + def log_message(self, fmt: str, *args): # noqa: D401 + print(f"{self.address_string()} - {fmt % args}", flush=True) + + def _write_stream_response(self, resp, *, inject_thinking_start: bool): + thinking_start_decided = not inject_thinking_start + buffer = b"" + + while True: + chunk = resp.read(8192) + if not chunk: + break + + if thinking_start_decided: + self.wfile.write(chunk) + self.wfile.flush() + continue + + buffer += chunk + while b"\n\n" in buffer: + event, buffer = buffer.split(b"\n\n", 1) + event += b"\n\n" + event, decided, _changed = _prepend_think_start_to_sse_event(event) + thinking_start_decided = thinking_start_decided or decided + self.wfile.write(event) + self.wfile.flush() + if thinking_start_decided and buffer: + self.wfile.write(buffer) + self.wfile.flush() + buffer = b"" + break + + if buffer: + self.wfile.write(buffer) + self.wfile.flush() + + def _forward(self, method: str, body: bytes | None = None, *, inject_thinking_start: bool = False): + headers = { + key: value + for key, value in self.headers.items() + if key.lower() not in {"host", "content-length", "connection"} + } + url = self.backend_url.rstrip("/") + self.path + req = urllib.request.Request(url, data=body, headers=headers, method=method) + try: + with urllib.request.urlopen(req, timeout=None) as resp: + content_type = resp.headers.get("Content-Type", "") + if "text/event-stream" in content_type.lower(): + self.send_response(resp.status) + for key, value in resp.headers.items(): + if key.lower() in _HOP_BY_HOP_HEADERS: + continue + self.send_header(key, value) + self.end_headers() + self._write_stream_response(resp, inject_thinking_start=inject_thinking_start) + else: + response_body = resp.read() + self.send_response(resp.status) + for key, value in resp.headers.items(): + if key.lower() in _HOP_BY_HOP_HEADERS: + continue + self.send_header(key, value) + self.send_header("Content-Length", str(len(response_body))) + self.end_headers() + self.wfile.write(response_body) + except urllib.error.HTTPError as exc: + error_body = exc.read() + self.send_response(exc.code) + for key, value in exc.headers.items(): + if key.lower() in _HOP_BY_HOP_HEADERS: + continue + self.send_header(key, value) + self.send_header("Content-Length", str(len(error_body))) + self.end_headers() + self.wfile.write(error_body) + + def do_GET(self): # noqa: N802 + self._forward("GET") + + def do_POST(self): # noqa: N802 + length = int(self.headers.get("Content-Length", "0") or "0") + raw_body = self.rfile.read(length) if length else b"" + + request_path = _request_path(self.path) + + if request_path == "/v1/completions" and not self.allow_completions: + _json_response( + self, + 400, + { + "error": { + "message": ( + "Raw /v1/completions is disabled for Qwen3.6. " + "Use /v1/chat/completions so the Qwen chat template " + "and non-thinking mode are applied." + ), + "type": "invalid_request_error", + "code": "qwen36_chat_required", + } + }, + ) + return + + if request_path == "/v1/chat/completions" and raw_body: + try: + payload = json.loads(raw_body) + except json.JSONDecodeError: + self._forward("POST", raw_body) + return + + thinking_enabled = _apply_thinking_policy( + payload, + allow_thinking=not self.force_disable_thinking, + default_thinking=self.default_thinking, + ) + payload["messages"] = _normalize_messages_for_qwen(payload.get("messages")) + raw_body = json.dumps(payload).encode("utf-8") + else: + thinking_enabled = False + + self._forward("POST", raw_body, inject_thinking_start=thinking_enabled) + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--backend-url", + default=os.getenv("QWEN36_BACKEND_URL", "http://127.0.0.1:8001"), + ) + parser.add_argument("--allow-completions", action="store_true") + parser.add_argument( + "--allow-thinking", + action="store_true", + help=( + "Allow request-level Qwen thinking mode. Requests still default to " + "non-thinking; enable with enable_thinking=true, thinking=true, " + "reasoning.enabled=true, reasoning_effort=low/medium/high, or " + "chat_template_kwargs.enable_thinking=true." + ), + ) + parser.add_argument( + "--default-thinking", + action="store_true", + help=( + "Enable Qwen thinking by default when --allow-thinking is set and " + "the request does not explicitly provide a thinking toggle. " + "Explicit enable_thinking=false still disables it." + ), + ) + args = parser.parse_args() + + Qwen36ProxyHandler.backend_url = args.backend_url + Qwen36ProxyHandler.allow_completions = args.allow_completions + Qwen36ProxyHandler.force_disable_thinking = not args.allow_thinking + Qwen36ProxyHandler.default_thinking = args.default_thinking + + server = ThreadingHTTPServer((args.host, args.port), Qwen36ProxyHandler) + print( + "Qwen3.6 proxy listening on " + f"{args.host}:{args.port}, backend={args.backend_url}, " + f"allow_completions={args.allow_completions}, " + f"force_disable_thinking={not args.allow_thinking}, " + f"default_thinking={args.default_thinking}", + flush=True, + ) + server.serve_forever() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/qwen36_hybrid_apc_scheduler_patch.py b/contrib/models/Qwen3.6-27B/vllm/qwen36_hybrid_apc_scheduler_patch.py new file mode 100644 index 00000000..8563e6a3 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/qwen36_hybrid_apc_scheduler_patch.py @@ -0,0 +1,2774 @@ +"""vLLM scheduler patch for Qwen Hybrid APC fallback. + +This module is intentionally opt-in. The safe fallback for current Hybrid APC +validation is to make vLLM skip attention-prefix reads before slot allocation +when the GDN checkpoint side is not integrated with the scheduler yet. +""" + +from __future__ import annotations + +import hashlib +import importlib.abc +import importlib.machinery +import json +import logging +import os +import struct +import sys +from typing import Any, Hashable, NamedTuple + +import torch + + +logger = logging.getLogger(__name__) +_SCHEDULER_MODULE = "vllm.v1.core.sched.scheduler" +_KV_CACHE_MANAGER_MODULE = "vllm.v1.core.kv_cache_manager" +_VLLM_NEURON_RUNNER_MODULE = "vllm_neuron.worker.neuronx_distributed_model_runner" +_VLLM_NEURON_LOADER_MODULE = "vllm_neuron.worker.neuronx_distributed_model_loader" +_PATCHED_MODULES = { + _SCHEDULER_MODULE, + _KV_CACHE_MANAGER_MODULE, + _VLLM_NEURON_RUNNER_MODULE, + _VLLM_NEURON_LOADER_MODULE, +} + + +class HybridGDNPrefixKey(NamedTuple): + cumulative_prefix_hash: Hashable + prefix_len: int + block_size: int + cache_salt: Hashable | None + model_revision: str + layout_version: int + tp_rank: int + recurrent_dtype: str + conv_dtype: str + + +_GDN_PREFIX_KEYS: set[HybridGDNPrefixKey] = set() +_AUTHORIZED_PREFIX_READS: dict[int, list[HybridGDNPrefixKey]] = {} +_AUTHORIZED_PREFIX_READS_BY_REQUEST: dict[Hashable, list[HybridGDNPrefixKey]] = {} +_SCHEDULER_OUTPUT_METADATA_ATTR = "_qwen36_hybrid_apc_metadata_by_request_id" +_SCHEDULER_OUTPUT_REQUEST_RECORDS_ATTR = "_qwen36_hybrid_apc_request_records" +_MAX_PREFIX_CACHE_HIT_LEN_ATTR = "_qwen36_hybrid_apc_max_prefix_cache_hit_len" +_MAX_PREFIX_CACHE_BLOCKS_ATTR = "_qwen36_hybrid_apc_max_prefix_cache_blocks" +_RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR = ( + "_qwen36_hybrid_apc_prefill_completion_state_for_output" +) +_HYBRID_APC_RUNTIME_CONFIG_KEYS = ( + "use_hybrid_apc_manager", + "use_qwen_hybrid_chunked_prefill", + "use_qwen_hybrid_chunked_prefill_nki", + "gdn_checkpoint_interval", + "max_gdn_checkpoint_slots", + "gdn_recurrent_cache_dtype", + "gdn_conv_cache_dtype", + "hybrid_recurrent_cache_dtype", + "hybrid_conv_cache_dtype", + "hybrid_cache_mode", + "hybrid_cache_prefix_boundary_only", + "hybrid_cache_block_boundary_only", + "hybrid_cache_validate_exact", + "hybrid_apc_layout_version", + "hybrid_apc_allow_residual_replay", + "hybrid_apc_cache_salt", + "hybrid_apc_model_revision", + "hybrid_apc_require_vllm_metadata", + "hybrid_apc_allow_local_hash_fallback", + "hybrid_apc_require_attention_block_refs", + "hybrid_apc_reject_unbacked_attention_hits", + "hybrid_apc_disable_unbacked_prefix_reads", + "hybrid_apc_enable_backed_prefix_reads", + "hybrid_apc_max_backed_prefix_read_len", + "hybrid_apc_allow_mixed_prefill_decode", + "hybrid_apc_prefill_chunk_tokens", + "qwen_prefill_group_size", +) +_HYBRID_APC_BRIDGE_CONFIG_ATTRS = { + "hybrid_apc_allow_local_hash_fallback": "allow_local_hash_fallback", + "hybrid_apc_require_attention_block_refs": "require_attention_block_refs", + "hybrid_apc_reject_unbacked_attention_hits": "reject_unbacked_attention_hits", + "hybrid_apc_cache_salt": "cache_salt", + "hybrid_apc_model_revision": "model_revision", + "hybrid_apc_layout_version": "layout_version", + "hybrid_recurrent_cache_dtype": "recurrent_dtype", + "hybrid_conv_cache_dtype": "conv_dtype", +} +_KV_CACHE_ATTENTION_LAYER_TYPES = { + "attention", + "full_attention", + "self_attention", + "sliding_attention", +} + + +def _env_flag(name: str) -> bool: + value = os.environ.get(name) + return value is not None and value.strip().lower() not in { + "", + "0", + "false", + "no", + "off", + } + + +def _env_int(name: str) -> int | None: + value = os.environ.get(name) + if value is None or value.strip() == "": + return None + return int(value) + + +def _get_hf_config(vllm_config: Any) -> Any: + model_config = getattr(vllm_config, "model_config", None) + return getattr(model_config, "hf_config", None) + + +def _get_additional_config(vllm_config: Any) -> dict[str, Any]: + additional_config = getattr(vllm_config, "additional_config", None) + return additional_config if isinstance(additional_config, dict) else {} + + +def _config_flag(config: Any, name: str, default: bool = False) -> bool: + return bool(getattr(config, name, default)) + + +def _config_value(config: Any, name: str, default: Any) -> Any: + return getattr(config, name, default) + + +def _num_layers_from_hf_config( + hf_config: Any, + original_get_kv_cache_spec: Any | None = None, +) -> int | None: + if hf_config is None: + return None + original_globals = getattr(original_get_kv_cache_spec, "__globals__", {}) + get_num_layers = original_globals.get("get_num_layers_from_hf_config") + if get_num_layers is not None: + try: + return int(get_num_layers(hf_config)) + except Exception: + pass + for attr in ("num_hidden_layers", "num_layers", "n_layer"): + value = getattr(hf_config, attr, None) + if value is not None: + return int(value) + layer_types = getattr(hf_config, "layer_types", None) + if layer_types is not None: + try: + return len(layer_types) + except TypeError: + return None + return None + + +def _hybrid_kv_attention_layer_indices( + hf_config: Any, + num_layers: int, +) -> list[int] | None: + layer_types = getattr(hf_config, "layer_types", None) + if layer_types is not None: + try: + layer_types = list(layer_types) + except TypeError: + layer_types = None + if layer_types is not None and len(layer_types) == num_layers: + attention_indices = [ + idx + for idx, layer_type in enumerate(layer_types) + if str(layer_type).lower() in _KV_CACHE_ATTENTION_LAYER_TYPES + ] + if 0 < len(attention_indices) < num_layers: + return attention_indices + return None + + full_attention_interval = getattr(hf_config, "full_attention_interval", None) + if full_attention_interval: + interval = int(full_attention_interval) + if interval > 1: + attention_indices = [ + idx for idx in range(num_layers) if (idx + 1) % interval == 0 + ] + if attention_indices and len(attention_indices) < num_layers: + return attention_indices + return None + + +def _local_num_kv_heads(hf_config: Any, parallel_config: Any) -> int: + tp_size = max(1, int(getattr(parallel_config, "tensor_parallel_size", 1) or 1)) + total_kv_heads = getattr(hf_config, "num_key_value_heads", None) + if total_kv_heads is None: + total_kv_heads = getattr(hf_config, "num_attention_heads", None) + if total_kv_heads is None: + return tp_size + return max(1, int(total_kv_heads) // tp_size) + + +def _full_attention_spec_class(original_get_kv_cache_spec: Any) -> Any | None: + original_globals = getattr(original_get_kv_cache_spec, "__globals__", {}) + spec_cls = original_globals.get("FullAttentionSpec") + if spec_cls is not None: + return spec_cls + try: + from vllm.v1.kv_cache_interface import FullAttentionSpec # noqa: WPS433 + + return FullAttentionSpec + except Exception: + return None + + +def _scheduler_config_flag( + scheduler: Any, + name: str, + default: bool = False, +) -> bool: + vllm_config = getattr(scheduler, "vllm_config", None) + additional_config = _get_additional_config(vllm_config) + if name in additional_config: + return bool(additional_config[name]) + return _config_flag(_get_hf_config(vllm_config), name, default) + + +def _scheduler_config_value( + scheduler: Any, + name: str, + default: Any, +) -> Any: + vllm_config = getattr(scheduler, "vllm_config", None) + additional_config = _get_additional_config(vllm_config) + if name in additional_config: + return additional_config[name] + return _config_value(_get_hf_config(vllm_config), name, default) + + +def _max_num_seqs_for_scheduler(scheduler: Any) -> int: + scheduler_config = getattr(scheduler, "scheduler_config", None) + max_num_seqs = getattr(scheduler_config, "max_num_seqs", 1) + return int(max_num_seqs or 1) + + +def _should_defer_waiting_prefills_while_running(scheduler: Any) -> bool: + if _env_flag("QWEN36_HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE"): + return False + if _scheduler_config_flag( + scheduler, + "hybrid_apc_allow_mixed_prefill_decode", + ): + return False + if _env_flag("QWEN36_HYBRID_APC_DEFER_WAITING_WHILE_RUNNING"): + return True + return ( + _scheduler_config_flag(scheduler, "use_hybrid_apc_manager") + and _scheduler_config_flag(scheduler, "use_qwen_hybrid_chunked_prefill") + and _max_num_seqs_for_scheduler(scheduler) > 1 + ) + + +def _new_empty_queue_like(queue: Any): + try: + return type(queue)() + except Exception: + return None + + +def _queue_add(queue: Any, request: Any) -> None: + add_request = getattr(queue, "add_request", None) + if add_request is not None: + add_request(request) + else: + queue.append(request) + + +def _merge_waiting_queues(front: Any, back: Any): + merged = _new_empty_queue_like(front) + if merged is None: + return back + for request in front: + _queue_add(merged, request) + for request in back: + _queue_add(merged, request) + return merged + + +def _normalize_dtype(value: Any, default: str) -> str: + if value is None: + value = default + normalized = str(value).lower() + aliases = { + "fp32": "float32", + "float32": "float32", + "torch.float32": "float32", + "bf16": "bfloat16", + "bfloat16": "bfloat16", + "torch.bfloat16": "bfloat16", + } + return aliases.get(normalized, normalized) + + +def _normalize_request_id(request_id: Any) -> Hashable | None: + if request_id is None: + return None + if isinstance(request_id, list): + return tuple(request_id) + try: + hash(request_id) + except TypeError: + return repr(request_id) + return request_id + + +def _to_registry_key(key: Any) -> HybridGDNPrefixKey: + return HybridGDNPrefixKey( + cumulative_prefix_hash=getattr(key, "cumulative_prefix_hash"), + prefix_len=int(getattr(key, "prefix_len")), + block_size=int(getattr(key, "block_size")), + cache_salt=getattr(key, "cache_salt", None), + model_revision=str(getattr(key, "model_revision", "unknown")), + layout_version=int(getattr(key, "layout_version", 1)), + tp_rank=int(getattr(key, "tp_rank", 0)), + recurrent_dtype=_normalize_dtype( + getattr(key, "recurrent_dtype", None), + "float32", + ), + conv_dtype=_normalize_dtype(getattr(key, "conv_dtype", None), "bfloat16"), + ) + + +def register_hybrid_apc_gdn_checkpoint(key: Any) -> HybridGDNPrefixKey: + """Publish a committed GDN checkpoint boundary to the scheduler process.""" + + registry_key = _to_registry_key(key) + _GDN_PREFIX_KEYS.add(registry_key) + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + print( + "[hybrid_apc_debug] scheduler-register " + f"prefix_len={registry_key.prefix_len} " + f"model_revision={registry_key.model_revision} " + f"registry_size={len(_GDN_PREFIX_KEYS)}", + flush=True, + ) + return registry_key + + +def unregister_hybrid_apc_gdn_checkpoint(key: Any) -> bool: + registry_key = _to_registry_key(key) + if registry_key not in _GDN_PREFIX_KEYS: + return False + _GDN_PREFIX_KEYS.remove(registry_key) + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + print( + "[hybrid_apc_debug] scheduler-unregister " + f"prefix_len={registry_key.prefix_len} " + f"model_revision={registry_key.model_revision} " + f"registry_size={len(_GDN_PREFIX_KEYS)}", + flush=True, + ) + return True + + +def clear_hybrid_apc_gdn_checkpoint_registry() -> None: + _GDN_PREFIX_KEYS.clear() + _AUTHORIZED_PREFIX_READS.clear() + _AUTHORIZED_PREFIX_READS_BY_REQUEST.clear() + + +def authorize_hybrid_apc_prefix_read( + key: Any, + *, + request_id: Hashable | None = None, +) -> HybridGDNPrefixKey: + """Publish a scheduler-approved prefix read for suffix-only model prep.""" + + registry_key = _to_registry_key(key) + normalized_request_id = _normalize_request_id(request_id) + if normalized_request_id is None: + _AUTHORIZED_PREFIX_READS.setdefault(registry_key.prefix_len, []).append( + registry_key + ) + else: + _AUTHORIZED_PREFIX_READS_BY_REQUEST.setdefault( + normalized_request_id, + [], + ).append(registry_key) + return registry_key + + +def _pop_matching_authorized_key( + candidates: list[HybridGDNPrefixKey], + *, + prefix_len: int, + cache_salt: Hashable | None, + model_revision: str, + layout_version: int, + tp_rank: int, + recurrent_dtype: str, + conv_dtype: str, +) -> HybridGDNPrefixKey | None: + for idx, key in enumerate(candidates): + if key.prefix_len != prefix_len: + continue + if key.cache_salt != cache_salt: + continue + if key.model_revision != str(model_revision): + continue + if key.layout_version != int(layout_version): + continue + if key.tp_rank != int(tp_rank): + continue + if key.recurrent_dtype != recurrent_dtype or key.conv_dtype != conv_dtype: + continue + return candidates.pop(idx) + return None + + +def pop_hybrid_apc_authorized_prefix_key( + *, + prefix_len: int, + request_id: Hashable | None = None, + cache_salt: Hashable | None = None, + model_revision: str = "unknown", + layout_version: int = 1, + tp_rank: int = 0, + recurrent_dtype: str = "float32", + conv_dtype: str = "bfloat16", +) -> HybridGDNPrefixKey | None: + """Consume the exact key for a prefix read allowed by the scheduler.""" + + prefix_len = int(prefix_len) + recurrent_dtype = _normalize_dtype(recurrent_dtype, "float32") + conv_dtype = _normalize_dtype(conv_dtype, "bfloat16") + normalized_request_id = _normalize_request_id(request_id) + if normalized_request_id is not None: + candidates = _AUTHORIZED_PREFIX_READS_BY_REQUEST.get(normalized_request_id) + if candidates: + matched = _pop_matching_authorized_key( + candidates, + prefix_len=prefix_len, + cache_salt=cache_salt, + model_revision=model_revision, + layout_version=layout_version, + tp_rank=tp_rank, + recurrent_dtype=recurrent_dtype, + conv_dtype=conv_dtype, + ) + if matched is not None: + if not candidates: + _AUTHORIZED_PREFIX_READS_BY_REQUEST.pop( + normalized_request_id, + None, + ) + return matched + + candidates = _AUTHORIZED_PREFIX_READS.get(prefix_len) + if not candidates: + return None + matched = _pop_matching_authorized_key( + candidates, + prefix_len=prefix_len, + cache_salt=cache_salt, + model_revision=model_revision, + layout_version=layout_version, + tp_rank=tp_rank, + recurrent_dtype=recurrent_dtype, + conv_dtype=conv_dtype, + ) + if matched is not None and not candidates: + _AUTHORIZED_PREFIX_READS.pop(prefix_len, None) + if matched is not None: + return matched + return None + + +def _block_size_for_scheduler(scheduler: Any) -> int: + cache_config = getattr(scheduler, "cache_config", None) + block_size = getattr(cache_config, "block_size", None) + if block_size is None: + hf_config = _get_hf_config(getattr(scheduler, "vllm_config", None)) + block_size = _config_value( + hf_config, + "gdn_checkpoint_interval", + 0, + ) + return int(block_size or 0) + + +def _local_cumulative_prefix_hashes( + token_ids: list[int] | tuple[int, ...], + *, + block_size: int, + max_prefix_len: int, +) -> dict[int, str]: + max_prefix_len = max(0, int(max_prefix_len)) + max_prefix_len = max_prefix_len // block_size * block_size + parent_digest = b"" + hashes: dict[int, str] = {} + for block_start in range(0, max_prefix_len, block_size): + block_end = block_start + block_size + block = [int(token) for token in token_ids[block_start:block_end]] + digest = hashlib.blake2b(digest_size=16) + digest.update(parent_digest) + digest.update(struct.pack(" dict[int, Hashable]: + block_hashes = list(getattr(request, "block_hashes", ()) or ()) + if not block_hashes: + return {} + if max_prefix_len is None: + max_prefix_len = len(block_hashes) * block_size + max_prefix_len = max(0, int(max_prefix_len)) + max_prefix_len = max_prefix_len // block_size * block_size + hashes: dict[int, Hashable] = {} + for index, block_hash in enumerate(block_hashes): + prefix_len = (index + 1) * block_size + if prefix_len > max_prefix_len: + break + hashes[prefix_len] = block_hash + return hashes + + +def _candidate_cumulative_prefix_hashes( + scheduler: Any, + request: Any, + *, + max_prefix_len: int, +) -> list[dict[int, Hashable]]: + block_size = _block_size_for_scheduler(scheduler) + if block_size <= 0: + return [] + candidates = [] + vllm_hashes = _vllm_cumulative_prefix_hashes( + request, + block_size=block_size, + max_prefix_len=max_prefix_len, + ) + if vllm_hashes: + candidates.append(vllm_hashes) + token_ids = getattr(request, "prompt_token_ids", None) + if token_ids: + local_hashes = _local_cumulative_prefix_hashes( + token_ids, + block_size=block_size, + max_prefix_len=max_prefix_len, + ) + if local_hashes: + candidates.append(local_hashes) + return candidates + + +def _request_registry_key( + *, + scheduler: Any, + request: Any, + cumulative_prefix_hash: Hashable, + prefix_len: int, + block_size: int, +) -> HybridGDNPrefixKey: + return HybridGDNPrefixKey( + cumulative_prefix_hash=cumulative_prefix_hash, + prefix_len=int(prefix_len), + block_size=int(block_size), + cache_salt=getattr(request, "cache_salt", None), + model_revision=str( + _scheduler_config_value( + scheduler, + "hybrid_apc_model_revision", + "unknown", + ) + ), + layout_version=int( + _scheduler_config_value(scheduler, "hybrid_apc_layout_version", 1) + ), + tp_rank=int(_scheduler_config_value(scheduler, "tp_rank", 0)), + recurrent_dtype=_normalize_dtype( + _scheduler_config_value( + scheduler, + "hybrid_recurrent_cache_dtype", + _scheduler_config_value( + scheduler, + "gdn_recurrent_cache_dtype", + "float32", + ), + ), + "float32", + ), + conv_dtype=_normalize_dtype( + _scheduler_config_value( + scheduler, + "hybrid_conv_cache_dtype", + _scheduler_config_value( + scheduler, + "gdn_conv_cache_dtype", + "bfloat16", + ), + ), + "bfloat16", + ), + ) + + +def _request_max_cache_hit_len(scheduler: Any, request: Any) -> int: + if request is None: + return 0 + block_size = _block_size_for_scheduler(scheduler) + if block_size <= 0: + return 0 + token_ids = getattr(request, "prompt_token_ids", None) + token_count = int(getattr(request, "num_tokens", len(token_ids or ()))) + max_cache_hit_len = max(0, token_count - 1) + if token_ids: + max_cache_hit_len = min(max_cache_hit_len, len(token_ids)) + return max_cache_hit_len + + +def backed_gdn_prefix_hits(scheduler: Any, request: Any) -> dict[int, HybridGDNPrefixKey]: + """Return request prefix lengths with registered GDN checkpoints.""" + + if request is None: + return {} + block_size = _block_size_for_scheduler(scheduler) + max_cache_hit_len = _request_max_cache_hit_len(scheduler, request) + if block_size <= 0 or max_cache_hit_len <= 0: + return {} + hits: dict[int, HybridGDNPrefixKey] = {} + for hashes in _candidate_cumulative_prefix_hashes( + scheduler, + request, + max_prefix_len=max_cache_hit_len, + ): + for prefix_len in sorted(hashes, reverse=True): + if prefix_len in hits: + continue + key = _request_registry_key( + scheduler=scheduler, + request=request, + cumulative_prefix_hash=hashes[prefix_len], + prefix_len=prefix_len, + block_size=block_size, + ) + if key in _GDN_PREFIX_KEYS: + hits[prefix_len] = key + return hits + + +def backed_gdn_prefix_hit(scheduler: Any, request: Any) -> HybridGDNPrefixKey | None: + """Return the largest request prefix with a registered GDN checkpoint.""" + + hits = backed_gdn_prefix_hits(scheduler, request) + if not hits: + return None + return hits[max(hits)] + + +def _required_backed_prefix_lens(scheduler: Any, request: Any) -> tuple[int, ...]: + if request is None: + return () + block_size = _block_size_for_scheduler(scheduler) + max_cache_hit_len = _request_max_cache_hit_len(scheduler, request) + if block_size <= 0 or max_cache_hit_len <= 0: + return () + required: set[int] = set() + for hashes in _candidate_cumulative_prefix_hashes( + scheduler, + request, + max_prefix_len=max_cache_hit_len, + ): + required.update(int(prefix_len) for prefix_len in hashes) + return tuple(sorted(required)) + + +def _block_id_groups(block_ids: Any) -> list[list[int]]: + if block_ids is None: + return [] + if isinstance(block_ids, tuple): + groups = block_ids + elif ( + isinstance(block_ids, list) + and block_ids + and all(isinstance(item, (list, tuple)) for item in block_ids) + ): + groups = tuple(block_ids) + else: + groups = (block_ids,) + normalized = [] + for group in groups: + try: + normalized.append([int(block_id) for block_id in group]) + except TypeError: + continue + return normalized + + +def _attention_block_refs_by_prefix_len( + block_ids: Any, + *, + block_size: int, +) -> dict[int, tuple[int, ...]]: + groups = _block_id_groups(block_ids) + if not groups: + return {} + max_blocks = max(len(group) for group in groups) + refs_by_prefix_len: dict[int, tuple[int, ...]] = {} + for block_count in range(1, max_blocks + 1): + refs: list[int] = [] + for group in groups: + refs.extend(group[:block_count]) + if refs: + refs_by_prefix_len[block_count * block_size] = tuple(refs) + return refs_by_prefix_len + + +def _scheduler_request_metadata( + scheduler: Any, + request: Any, + *, + block_ids: Any = None, + num_computed_tokens: int | None = None, + active_suffix_len: int | None = None, +) -> dict[str, Any]: + block_size = _block_size_for_scheduler(scheduler) + if request is None or block_size <= 0: + return {} + token_ids = getattr(request, "prompt_token_ids", None) + prompt_token_count = int( + getattr(request, "num_prompt_tokens", len(token_ids or ())) or 0 + ) + if token_ids is not None: + prompt_token_count = min(prompt_token_count, len(token_ids)) + full_request_prefix_len = prompt_token_count + request_prefix_len = full_request_prefix_len + if num_computed_tokens is not None and active_suffix_len is not None: + scheduled_prefix_len = int(num_computed_tokens) + int(active_suffix_len) + request_prefix_len = max(0, min(full_request_prefix_len, scheduled_prefix_len)) + cumulative_hashes = _vllm_cumulative_prefix_hashes( + request, + block_size=block_size, + max_prefix_len=request_prefix_len, + ) + metadata: dict[str, Any] = {} + if cumulative_hashes: + metadata["cumulative_hashes_by_prefix_len"] = cumulative_hashes + refs_by_prefix_len = _attention_block_refs_by_prefix_len( + block_ids, + block_size=block_size, + ) + if refs_by_prefix_len: + metadata["attention_block_refs_by_prefix_len"] = refs_by_prefix_len + metadata["request_prefix_len"] = request_prefix_len + full_token_ids = token_ids or getattr(request, "all_token_ids", None) + has_computed_prefix = ( + num_computed_tokens is not None and int(num_computed_tokens) > 0 + ) + if has_computed_prefix and full_token_ids is not None: + full_token_ids = list(full_token_ids) + if len(full_token_ids) >= request_prefix_len: + metadata["full_input_ids"] = tuple( + int(token_id) + for token_id in full_token_ids[:request_prefix_len] + ) + if num_computed_tokens is not None: + metadata["vllm_attention_hit_len"] = int(num_computed_tokens) + if active_suffix_len is not None: + metadata["active_suffix_len"] = int(active_suffix_len) + return metadata + + +def _unbacked_prefix_reads_disabled_requested(scheduler: Any) -> bool: + if _env_flag("QWEN36_HYBRID_APC_ENABLE_PREFIX_READS"): + return False + + disable_requested = _env_flag("QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS") + if disable_requested: + return True + if not _scheduler_config_flag(scheduler, "use_hybrid_apc_manager"): + return False + if _scheduler_config_flag( + scheduler, + "hybrid_apc_reject_unbacked_attention_hits", + ) or _scheduler_config_flag( + scheduler, + "hybrid_apc_require_vllm_metadata", + ): + return True + return _scheduler_config_flag( + scheduler, + "hybrid_apc_disable_unbacked_prefix_reads", + ) + + +def _request_prefix_len(request: Any) -> int: + if request is None: + return 0 + token_ids = getattr(request, "prompt_token_ids", None) + return int(getattr(request, "num_tokens", len(token_ids or ())) or 0) + + +def _backed_prefix_read_decision(scheduler: Any, request: Any) -> dict[str, Any]: + backed_hits = backed_gdn_prefix_hits(scheduler, request) + required_prefix_lens = _required_backed_prefix_lens(scheduler, request) + backed_hit_len = max(backed_hits) if backed_hits else 0 + max_readable_prefix_len = max(required_prefix_lens) if required_prefix_lens else 0 + missing_higher_backed_lens = [ + prefix_len + for prefix_len in required_prefix_lens + if prefix_len > backed_hit_len and prefix_len not in backed_hits + ] + supports_backed = _supports_backed_prefix_reads(scheduler) + max_backed_prefix_read_len = _max_backed_prefix_read_len(scheduler) + exceeds_backed_prefix_cap = ( + max_backed_prefix_read_len > 0 + and backed_hit_len > max_backed_prefix_read_len + ) + capped_backed_hits = [ + prefix_len + for prefix_len in backed_hits + if max_backed_prefix_read_len <= 0 or prefix_len <= max_backed_prefix_read_len + ] + prefix_read_len = max(capped_backed_hits) if capped_backed_hits else 0 + allowed = ( + prefix_read_len > 0 + and supports_backed + ) + return { + "allowed": allowed, + "backed_hits": backed_hits, + "required_prefix_lens": required_prefix_lens, + "backed_hit_len": backed_hit_len, + "prefix_read_len": prefix_read_len, + "missing_backed_lens": tuple(missing_higher_backed_lens), + "supports_backed": supports_backed, + "max_backed_prefix_read_len": max_backed_prefix_read_len, + "exceeds_backed_prefix_cap": exceeds_backed_prefix_cap, + } + + +def _request_from_scheduler(scheduler: Any, req_id: Any) -> Any: + requests = getattr(scheduler, "requests", None) + if isinstance(requests, dict): + return requests.get(req_id) + return None + + +def _attach_scheduler_output_metadata(scheduler: Any, scheduler_output: Any) -> None: + metadata_by_request_id: dict[Hashable, dict[str, Any]] = {} + active_suffix_lens = _num_scheduled_tokens_by_request_id(scheduler_output) + for req_data in getattr(scheduler_output, "scheduled_new_reqs", ()) or (): + req_id = getattr(req_data, "req_id", None) + request = _request_from_scheduler(scheduler, req_id) + metadata = _scheduler_request_metadata( + scheduler, + request, + block_ids=getattr(req_data, "block_ids", None), + num_computed_tokens=getattr(req_data, "num_computed_tokens", None), + active_suffix_len=active_suffix_lens.get(_normalize_request_id(req_id)), + ) + if metadata: + metadata_by_request_id[_normalize_request_id(req_id)] = metadata + _authorize_scheduled_prefix_read( + scheduler, + request, + request_id=req_id, + prefix_len=metadata.get("vllm_attention_hit_len"), + ) + + cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) + req_ids = list(getattr(cached_reqs, "req_ids", ()) or ()) + new_block_ids = list(getattr(cached_reqs, "new_block_ids", ()) or ()) + num_computed_tokens = list( + getattr(cached_reqs, "num_computed_tokens", ()) or () + ) + for index, req_id in enumerate(req_ids): + request = _request_from_scheduler(scheduler, req_id) + block_ids = new_block_ids[index] if index < len(new_block_ids) else None + computed = ( + int(num_computed_tokens[index]) + if index < len(num_computed_tokens) + else None + ) + metadata = _scheduler_request_metadata( + scheduler, + request, + block_ids=block_ids, + num_computed_tokens=computed, + active_suffix_len=active_suffix_lens.get(_normalize_request_id(req_id)), + ) + if metadata: + metadata_by_request_id[_normalize_request_id(req_id)] = metadata + _authorize_scheduled_prefix_read( + scheduler, + request, + request_id=req_id, + prefix_len=metadata.get("vllm_attention_hit_len"), + ) + + if metadata_by_request_id: + setattr( + scheduler_output, + _SCHEDULER_OUTPUT_METADATA_ATTR, + metadata_by_request_id, + ) + + +def backed_gdn_prefix_hit_len(scheduler: Any, request: Any) -> int: + hit = backed_gdn_prefix_hit(scheduler, request) + if hit is None: + return 0 + return hit.prefix_len + + +def _request_id_for_scheduler_request(request: Any) -> Hashable | None: + if request is None: + return None + for attr in ("request_id", "req_id", "id"): + request_id = getattr(request, attr, None) + if request_id is not None: + return _normalize_request_id(request_id) + return None + + +def _supports_backed_prefix_reads(scheduler: Any) -> bool: + """Return whether this artifact can consume a backed Hybrid APC prefix.""" + + if _env_flag("QWEN36_HYBRID_APC_ENABLE_BACKED_PREFIX_READS"): + return True + + if not _scheduler_config_flag(scheduler, "hybrid_apc_enable_backed_prefix_reads"): + return False + + # A backed GDN checkpoint is not enough on its own. The CTE graph must also + # consume attention KV prefix state; otherwise warm requests restore GDN + # state but full-attention layers still see only the suffix. + return _scheduler_config_flag(scheduler, "use_qwen_hybrid_chunked_prefill") + + +def _max_backed_prefix_read_len(scheduler: Any) -> int: + env_value = _env_int("QWEN36_HYBRID_APC_MAX_BACKED_PREFIX_READ_LEN") + if env_value is not None: + return max(0, env_value) + return max( + 0, + int( + _scheduler_config_value( + scheduler, + "hybrid_apc_max_backed_prefix_read_len", + 0, + ) + or 0 + ), + ) + + +def _set_request_prefix_cache_cap( + request: Any, + *, + prefix_len: int, + block_size: int, +) -> None: + if request is None: + return + prefix_len = max(0, int(prefix_len)) + block_size = max(1, int(block_size)) + try: + setattr(request, _MAX_PREFIX_CACHE_HIT_LEN_ATTR, prefix_len) + setattr( + request, + _MAX_PREFIX_CACHE_BLOCKS_ATTR, + prefix_len // block_size, + ) + except Exception: + return + + +def should_disable_unbacked_prefix_reads(scheduler: Any, request: Any = None) -> bool: + """Return whether this scheduler should avoid vLLM APC reads. + + The current Qwen Hybrid APC control plane can prove an attention hit is + invalid only inside model request prep. That is too late for allocation. + This opt-in fallback makes vLLM allocate the request as no-prefix unless + the scheduler process has a registered matching GDN checkpoint boundary and + the compiled artifact can consume the matching attention KV prefix in CTE. + """ + + disable_requested = _unbacked_prefix_reads_disabled_requested(scheduler) + if not disable_requested: + return False + decision = _backed_prefix_read_decision(scheduler, request) + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + prompt_len = len(getattr(request, "prompt_token_ids", ()) or ()) + print( + "[hybrid_apc_debug] scheduler-decision " + f"disable_requested={disable_requested} " + f"backed_hit_len={decision['backed_hit_len']} " + f"prefix_read_len={decision['prefix_read_len']} " + f"supports_backed={decision['supports_backed']} " + f"max_num_seqs={_max_num_seqs_for_scheduler(scheduler)} " + f"prompt_len={prompt_len} " + f"required_backed_lens={decision['required_prefix_lens']} " + f"missing_backed_lens={decision['missing_backed_lens']} " + f"max_backed_prefix_read_len={decision['max_backed_prefix_read_len']} " + f"exceeds_backed_prefix_cap={decision['exceeds_backed_prefix_cap']} " + f"registry_size={len(_GDN_PREFIX_KEYS)}", + flush=True, + ) + if decision["allowed"]: + request_id = _request_id_for_scheduler_request(request) + prefix_len = decision["prefix_read_len"] + _set_request_prefix_cache_cap( + request, + prefix_len=prefix_len, + block_size=_block_size_for_scheduler(scheduler), + ) + authorize_hybrid_apc_prefix_read( + decision["backed_hits"][prefix_len], + request_id=request_id, + ) + return False + return True + + +def patch_scheduler_class(scheduler_cls: type) -> bool: + """Patch a vLLM Scheduler class in-place. + + Returns True if this call installed the patch, False if the class was + already patched. + """ + + original_add_request = getattr(scheduler_cls, "add_request", None) + if original_add_request is None: + raise AttributeError(f"{scheduler_cls!r} has no add_request method") + installed = False + + if not getattr(original_add_request, "_qwen36_hybrid_apc_patched", False): + + def add_request_with_hybrid_apc_fallback(self, request): + if should_disable_unbacked_prefix_reads(self, request): + request.skip_reading_prefix_cache = True + return original_add_request(self, request) + + add_request_with_hybrid_apc_fallback._qwen36_hybrid_apc_patched = True + add_request_with_hybrid_apc_fallback._qwen36_original_add_request = ( + original_add_request + ) + scheduler_cls.add_request = add_request_with_hybrid_apc_fallback + installed = True + + original_schedule = getattr(scheduler_cls, "schedule", None) + if original_schedule is not None and not getattr( + original_schedule, + "_qwen36_hybrid_apc_metadata_patched", + False, + ): + + def schedule_with_hybrid_apc_metadata(self, *args, **kwargs): + deferred_waiting = None + temporary_waiting = None + waiting = getattr(self, "waiting", None) + running = getattr(self, "running", None) + if ( + waiting + and running + and _should_defer_waiting_prefills_while_running(self) + ): + temporary_waiting = _new_empty_queue_like(waiting) + if temporary_waiting is not None: + deferred_waiting = waiting + self.waiting = temporary_waiting + try: + scheduler_output = original_schedule(self, *args, **kwargs) + finally: + if deferred_waiting is not None: + current_waiting = getattr(self, "waiting", temporary_waiting) + if current_waiting: + self.waiting = _merge_waiting_queues( + current_waiting, + deferred_waiting, + ) + else: + self.waiting = deferred_waiting + _attach_scheduler_output_metadata(self, scheduler_output) + return scheduler_output + + schedule_with_hybrid_apc_metadata._qwen36_hybrid_apc_metadata_patched = True + schedule_with_hybrid_apc_metadata._qwen36_original_schedule = ( + original_schedule + ) + scheduler_cls.schedule = schedule_with_hybrid_apc_metadata + installed = True + + return installed + + +def _patch_scheduler_module(module: Any) -> bool: + scheduler_cls = getattr(module, "Scheduler", None) + if scheduler_cls is None: + return False + installed = patch_scheduler_class(scheduler_cls) + if installed: + logger.info("Installed Qwen Hybrid APC scheduler fallback patch") + return installed + + +def _patch_kv_cache_manager_module(module: Any) -> bool: + kv_cache_manager_cls = getattr(module, "KVCacheManager", None) + if kv_cache_manager_cls is None: + return False + original_get_computed_blocks = getattr( + kv_cache_manager_cls, + "get_computed_blocks", + None, + ) + if original_get_computed_blocks is None or getattr( + original_get_computed_blocks, + "_qwen36_hybrid_apc_prefix_cap_patched", + False, + ): + return False + + def get_computed_blocks_with_hybrid_apc_cap(self, request, *args, **kwargs): + cap_blocks = getattr(request, _MAX_PREFIX_CACHE_BLOCKS_ATTR, None) + try: + cap_blocks = None if cap_blocks is None else max(0, int(cap_blocks)) + except (TypeError, ValueError): + cap_blocks = None + if cap_blocks is None: + return original_get_computed_blocks(self, request, *args, **kwargs) + if cap_blocks <= 0: + return getattr(self, "empty_kv_cache_blocks"), 0 + + block_hashes = getattr(request, "block_hashes", None) + if not block_hashes or len(block_hashes) <= cap_blocks: + return original_get_computed_blocks(self, request, *args, **kwargs) + + original_block_hashes = block_hashes + if isinstance(block_hashes, tuple): + capped_block_hashes = block_hashes[:cap_blocks] + else: + capped_block_hashes = list(block_hashes[:cap_blocks]) + try: + request.block_hashes = capped_block_hashes + return original_get_computed_blocks(self, request, *args, **kwargs) + finally: + request.block_hashes = original_block_hashes + + get_computed_blocks_with_hybrid_apc_cap._qwen36_hybrid_apc_prefix_cap_patched = ( + True + ) + get_computed_blocks_with_hybrid_apc_cap._qwen36_original_get_computed_blocks = ( + original_get_computed_blocks + ) + kv_cache_manager_cls.get_computed_blocks = get_computed_blocks_with_hybrid_apc_cap + logger.info("Installed Qwen Hybrid APC KV prefix cap patch") + return True + + +def _request_ids_from_model_input(model_input: Any) -> tuple[Hashable, ...] | None: + request_ids = getattr(model_input, "request_ids", None) + return _as_request_id_tuple(request_ids) + + +def _as_request_id_tuple(request_ids: Any) -> tuple[Hashable, ...] | None: + if request_ids is None: + return None + if isinstance(request_ids, tuple): + return request_ids + if isinstance(request_ids, list): + return tuple(request_ids) + if isinstance(request_ids, (str, bytes)): + return (request_ids,) + try: + return tuple(request_ids) + except TypeError: + return (request_ids,) + + +def _request_ids_from_model_input_or_scheduler_output( + model_input: Any, + scheduler_output: Any, +) -> tuple[Hashable, ...] | None: + request_ids = _request_ids_from_model_input(model_input) + if request_ids: + return request_ids + cached_request_ids = _request_ids_from_scheduler_output( + scheduler_output, + kind="cached", + ) + new_request_ids = _request_ids_from_scheduler_output( + scheduler_output, + kind="new", + ) + combined = tuple(cached_request_ids or ()) + tuple(new_request_ids or ()) + return combined or None + + +def _request_ids_from_scheduler_output( + scheduler_output: Any, + *, + kind: str, +) -> tuple[Hashable, ...] | None: + if kind == "cached": + cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) + return _as_request_id_tuple(getattr(cached_reqs, "req_ids", None)) + if kind == "new": + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", None) + if new_reqs is None: + return None + return tuple(getattr(req, "req_id") for req in new_reqs) + raise ValueError(f"unknown scheduler request kind: {kind}") + + +def _num_scheduled_tokens_by_request_id(scheduler_output: Any) -> dict[Hashable, int]: + values = getattr(scheduler_output, "num_scheduled_tokens", None) + if not isinstance(values, dict): + return {} + scheduled_tokens: dict[Hashable, int] = {} + for req_id, value in values.items(): + normalized = _normalize_request_id(req_id) + if normalized is None: + continue + try: + scheduled_tokens[normalized] = int(value) + except (TypeError, ValueError): + continue + return scheduled_tokens + + +def _scheduler_metadata_for_request_id( + metadata_by_request_id: Any, + request_id: Any, +) -> dict[str, Any]: + if not isinstance(metadata_by_request_id, dict): + return {} + normalized = _normalize_request_id(request_id) + metadata = metadata_by_request_id.get(normalized) + if metadata is None and request_id is not None: + metadata = metadata_by_request_id.get(str(request_id)) + return metadata if isinstance(metadata, dict) else {} + + +def _authorize_scheduled_prefix_read( + scheduler: Any, + request: Any, + *, + request_id: Any, + prefix_len: int | None, +) -> None: + """Authorize a vLLM prefix hit that is backed by a committed GDN checkpoint.""" + + if request is None or prefix_len is None: + return + prefix_len = int(prefix_len) + if prefix_len <= 0 or not _supports_backed_prefix_reads(scheduler): + return + max_backed_prefix_read_len = _max_backed_prefix_read_len(scheduler) + if max_backed_prefix_read_len > 0 and prefix_len > max_backed_prefix_read_len: + return + key = backed_gdn_prefix_hits(scheduler, request).get(prefix_len) + if key is None: + return + authorize_hybrid_apc_prefix_read( + key, + request_id=_request_id_for_scheduler_request(request) or request_id, + ) + + +def _hybrid_apc_request_records_from_model_input( + model_input: Any, + scheduler_output: Any, +) -> tuple[dict[str, Any], ...] | None: + request_ids = _request_ids_from_model_input_or_scheduler_output( + model_input, + scheduler_output, + ) + if not request_ids: + return None + metadata_by_request_id = getattr( + scheduler_output, + _SCHEDULER_OUTPUT_METADATA_ATTR, + None, + ) + if not isinstance(metadata_by_request_id, dict): + return None + + num_scheduled_tokens = _num_scheduled_tokens_by_request_id(scheduler_output) + records: list[dict[str, Any]] = [] + found_metadata = False + for request_id in request_ids: + metadata = _scheduler_metadata_for_request_id(metadata_by_request_id, request_id) + record: dict[str, Any] = {"request_id": request_id} + for key in ( + "cumulative_hashes_by_prefix_len", + "attention_block_refs_by_prefix_len", + "request_prefix_len", + "full_input_ids", + "vllm_attention_hit_len", + ): + if key in metadata: + record[key] = metadata[key] + found_metadata = True + normalized = _normalize_request_id(request_id) + if normalized in num_scheduled_tokens: + record["active_suffix_len"] = num_scheduled_tokens[normalized] + records.append(record) + return tuple(records) if found_metadata else None + + +def _request_id_target_models(model: Any) -> list[Any]: + targets = [] + seen = set() + current = model + for _ in range(4): + if current is None: + break + current_id = id(current) + if current_id in seen: + break + seen.add(current_id) + targets.append(current) + for attr in ( + "context_encoding_model", + "token_generation_model", + "fused_spec_model", + ): + wrapper = getattr(current, attr, None) + if wrapper is None: + continue + wrapper_id = id(wrapper) + if wrapper_id in seen: + continue + seen.add(wrapper_id) + targets.append(wrapper) + current = getattr(current, "model", None) + return targets + + +def _runner_hybrid_apc_runtime_config(runner: Any) -> dict[str, Any]: + additional_config = _get_additional_config(getattr(runner, "vllm_config", None)) + runtime_config = { + key: additional_config[key] + for key in _HYBRID_APC_RUNTIME_CONFIG_KEYS + if key in additional_config + } + if runtime_config.get("hybrid_apc_require_vllm_metadata"): + runtime_config["hybrid_apc_allow_local_hash_fallback"] = False + runtime_config["hybrid_apc_require_attention_block_refs"] = True + runtime_config["hybrid_apc_reject_unbacked_attention_hits"] = True + return runtime_config + + +def _config_targets_for_model(model: Any) -> list[Any]: + targets = [] + seen = set() + for target in _request_id_target_models(model): + config = getattr(target, "config", None) + if config is not None: + config_id = id(config) + if config_id not in seen: + seen.add(config_id) + targets.append(config) + if any(hasattr(target, key) for key in _HYBRID_APC_RUNTIME_CONFIG_KEYS): + target_id = id(target) + if target_id not in seen: + seen.add(target_id) + targets.append(target) + return targets + + +def _apply_runtime_config_values( + *, + target: Any, + values: dict[str, Any], + previous_values: list[tuple[Any, str, Any]], + missing: Any, +) -> None: + for attr, value in values.items(): + previous_values.append((target, attr, getattr(target, attr, missing))) + setattr(target, attr, value) + + +def _apply_hybrid_apc_runtime_config( + model: Any, + values: dict[str, Any], + *, + previous_values: list[tuple[Any, str, Any]], + missing: Any, +) -> None: + if not values: + return + for target in _config_targets_for_model(model): + _apply_runtime_config_values( + target=target, + values=values, + previous_values=previous_values, + missing=missing, + ) + bridge_values = { + bridge_attr: values[config_attr] + for config_attr, bridge_attr in _HYBRID_APC_BRIDGE_CONFIG_ATTRS.items() + if config_attr in values + } + if not bridge_values: + return + for target in _request_id_target_models(model): + bridge = getattr(target, "hybrid_apc_bridge", None) + if bridge is None: + continue + _apply_runtime_config_values( + target=bridge, + values=bridge_values, + previous_values=previous_values, + missing=missing, + ) + + +def _debug_logits_tensor(stage: str, tensor: Any) -> None: + if not _env_flag("QWEN36_VLLM_LOGITS_DEBUG"): + return + if tensor is None or not hasattr(tensor, "numel"): + print(f"[qwen36_vllm_logits_debug] stage={stage} tensor=none", flush=True) + return + try: + import torch # noqa: WPS433 + + if tensor.numel() == 0: + print( + "[qwen36_vllm_logits_debug] " + f"stage={stage} shape={tuple(tensor.shape)} dtype={tensor.dtype} empty", + flush=True, + ) + return + flat = tensor.detach().reshape(-1) + if torch.is_floating_point(flat): + finite_mask = torch.isfinite(flat) + finite_count = int(finite_mask.sum().item()) + nan_count = int(torch.isnan(flat).sum().item()) + posinf_count = int( + torch.logical_and(torch.isinf(flat), flat > 0).sum().item() + ) + neginf_count = int( + torch.logical_and(torch.isinf(flat), flat < 0).sum().item() + ) + if finite_count: + finite_flat = flat[finite_mask].float() + finite_min = float(finite_flat.min().item()) + finite_max = float(finite_flat.max().item()) + else: + finite_min = "none" + finite_max = "none" + row_argmax = [] + row_argmax_values = [] + if tensor.ndim >= 2: + rows = tensor.detach().float().reshape(tensor.shape[0], -1) + argmax = rows.argmax(dim=-1) + row_argmax = [int(item) for item in argmax[:8].cpu().tolist()] + row_argmax_values = [ + float(rows[row, argmax[row]].item()) + for row in range(min(rows.shape[0], 8)) + ] + print( + "[qwen36_vllm_logits_debug] " + f"stage={stage} shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"finite={finite_count} nan={nan_count} posinf={posinf_count} " + f"neginf={neginf_count} finite_min={finite_min} " + f"finite_max={finite_max} row_argmax={row_argmax} " + f"row_argmax_values={row_argmax_values}", + flush=True, + ) + else: + flat_i64 = flat.to(torch.int64) + print( + "[qwen36_vllm_logits_debug] " + f"stage={stage} shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"min={int(flat_i64.min().item())} max={int(flat_i64.max().item())}", + flush=True, + ) + except Exception as exc: # pragma: no cover - diagnostic only + print( + "[qwen36_vllm_logits_debug] " + f"stage={stage} summary_error={type(exc).__name__}: {exc}", + flush=True, + ) + + +def _expand_completed_prefill_logits(hidden_states: Any, model_input: Any) -> Any: + """Restore completed-only CTE logits to vLLM's scheduled request rows.""" + prefill_state = getattr(model_input, "prefill_completion_state", None) + if prefill_state is None or not hasattr(hidden_states, "shape"): + return hidden_states + if len(getattr(hidden_states, "shape", ())) == 0: + return hidden_states + + try: + import torch # noqa: WPS433 + + if hasattr(prefill_state, "detach"): + state_values = [ + bool(item) + for item in prefill_state.detach().cpu().reshape(-1).tolist() + ] + else: + state_values = [bool(item) for item in prefill_state] + scheduled_rows = len(state_values) + output_rows = int(hidden_states.shape[0]) + if output_rows == scheduled_rows: + return hidden_states + + completed_rows = [idx for idx, is_done in enumerate(state_values) if is_done] + if output_rows != len(completed_rows): + return hidden_states + if not torch.is_floating_point(hidden_states): + return hidden_states + + expanded = hidden_states.new_full( + (scheduled_rows, *tuple(hidden_states.shape[1:])), + float("-inf"), + ) + for src_row, dst_row in enumerate(completed_rows): + expanded[dst_row] = hidden_states[src_row] + if _env_flag("QWEN36_VLLM_LOGITS_DEBUG"): + print( + "[qwen36_vllm_logits_debug] " + f"expanded_completed_prefill_logits output_rows={output_rows} " + f"scheduled_rows={scheduled_rows} completed_rows={completed_rows}", + flush=True, + ) + return expanded + except Exception as exc: # pragma: no cover - defensive shim only + if _env_flag("QWEN36_VLLM_LOGITS_DEBUG"): + print( + "[qwen36_vllm_logits_debug] " + f"expand_completed_prefill_logits_error={type(exc).__name__}: {exc}", + flush=True, + ) + return hidden_states + + +def _prefill_completion_state_values(prefill_completion_state: Any) -> list[bool] | None: + if prefill_completion_state is None: + return None + try: + if hasattr(prefill_completion_state, "numel"): + if prefill_completion_state.numel() == 0: + return None + values = prefill_completion_state.reshape(-1) + normalized = [] + for value in values: + try: + normalized.append(bool(value.item())) + except AttributeError: + normalized.append(bool(value)) + return normalized + values = list(prefill_completion_state) + except Exception: + return None + if not values: + return None + normalized = [] + for value in values: + try: + normalized.append(bool(value.item())) + except AttributeError: + normalized.append(bool(value)) + return normalized + + +def _prefill_completion_has_incomplete_row(prefill_completion_state: Any) -> bool: + values = _prefill_completion_state_values(prefill_completion_state) + return bool(values) and not all(values) + + +def _runner_vocab_size(runner: Any) -> int | None: + owners = [ + runner, + getattr(runner, "model", None), + getattr(getattr(runner, "model", None), "model", None), + getattr(getattr(getattr(runner, "model", None), "model", None), "config", None), + getattr(runner, "model_config", None), + ] + for owner in owners: + vocab_size = getattr(owner, "vocab_size", None) + if vocab_size is not None: + try: + return int(vocab_size) + except (TypeError, ValueError): + return None + config = getattr(owner, "config", None) + vocab_size = getattr(config, "vocab_size", None) + if vocab_size is not None: + try: + return int(vocab_size) + except (TypeError, ValueError): + return None + return None + + +def _format_token_id(value: int) -> str: + if value < 0: + return str(value) + return f"{value} (0x{value & 0xFFFFFFFF:08x})" + + +def _prefill_state_for_output_rows( + values: list[bool], + output_rows: int, +) -> list[bool]: + if output_rows <= 0: + return [] + if output_rows == len(values): + return values + completed_count = sum(1 for value in values if value) + if completed_count > 0 and output_rows == completed_count: + return [True] * output_rows + return values[:output_rows] + + +def _validate_completed_prefill_sampled_tokens( + sampled_token_ids: Any, + prefill_completion_state: Any, + *, + vocab_size: int | None, + stage: str, +) -> None: + values = _prefill_completion_state_values(prefill_completion_state) + if not values or sampled_token_ids is None or not hasattr(sampled_token_ids, "shape"): + return + if not hasattr(sampled_token_ids, "dtype"): + return + if sampled_token_ids.dtype not in (torch.int32, torch.int64): + raise ValueError( + "Qwen3.6 sampled token ids must be int32 or int64 before vLLM " + f"publishes completed prefill rows; stage={stage}; " + f"dtype={sampled_token_ids.dtype}" + ) + shape = getattr(sampled_token_ids, "shape", ()) + if not shape: + return + row_values = _prefill_state_for_output_rows(values, int(shape[0])) + row_count = min(len(row_values), int(shape[0])) + for row_idx, is_done in enumerate(row_values[:row_count]): + if not is_done: + continue + row = sampled_token_ids[row_idx].reshape(-1) + if row.numel() == 0: + continue + invalid_id, reason = _sampled_token_invalid_id_and_reason( + row, + vocab_size=vocab_size, + ) + if invalid_id is None: + continue + raise ValueError( + "Qwen3.6 sampled token id contract violated before vLLM output " + f"update: {reason}; stage={stage}; row={row_idx}; " + f"token_id={_format_token_id(invalid_id)}; " + f"prefill_completion_state={values}; " + f"sampled_shape={tuple(sampled_token_ids.shape)}" + ) + + +def _sampled_token_invalid_id_and_reason( + row: Any, + *, + vocab_size: int | None, +) -> tuple[int | None, str | None]: + if row is None or not hasattr(row, "numel") or row.numel() == 0: + return None, None + min_id = int(row.min().item()) + max_id = int(row.max().item()) + if min_id < 0: + return min_id, "negative" + if vocab_size is not None and max_id >= vocab_size: + return max_id, f"out-of-vocab for vocab_size={vocab_size}" + return None, None + + +def _logits_argmax_token_ids_for_sample_shape( + logits_source: Any, + sampled_token_ids: Any, + *, + vocab_size: int | None = None, +) -> Any: + logits_tensor = _first_tensor_like(logits_source) + if logits_tensor is None or not hasattr(logits_tensor, "dim"): + return None + if not torch.is_floating_point(logits_tensor): + return None + if logits_tensor.dim() >= 3: + logits_for_argmax = logits_tensor[:, -1, :] + elif logits_tensor.dim() == 2: + logits_for_argmax = logits_tensor + elif logits_tensor.dim() == 1: + logits_for_argmax = logits_tensor.reshape(1, -1) + else: + return None + if vocab_size is not None and int(logits_for_argmax.shape[-1]) < int(vocab_size): + return None + + argmax = logits_for_argmax.detach().float().argmax(dim=-1) + target_shape = tuple(getattr(sampled_token_ids, "shape", ())) + if len(target_shape) <= 1: + shaped = argmax.reshape(-1) + else: + shaped = argmax.reshape(-1, *([1] * (len(target_shape) - 1))) + return shaped.to( + device=sampled_token_ids.device, + dtype=sampled_token_ids.dtype, + ) + + +def _summarize_logits_for_fallback(logits_source: Any) -> str: + logits_tensor = _first_tensor_like(logits_source) + if logits_tensor is None or not hasattr(logits_tensor, "dim"): + return "logits=unavailable" + if not torch.is_floating_point(logits_tensor): + return ( + f"logits_shape={tuple(getattr(logits_tensor, 'shape', ())) } " + f"logits_dtype={getattr(logits_tensor, 'dtype', None)} non_float" + ) + try: + logits_float = logits_tensor.detach().float() + flat = logits_float.reshape(-1) + finite_mask = torch.isfinite(flat) + finite_count = int(finite_mask.sum().item()) + nan_count = int(torch.isnan(flat).sum().item()) + posinf_count = int( + torch.logical_and(torch.isinf(flat), flat > 0).sum().item() + ) + neginf_count = int( + torch.logical_and(torch.isinf(flat), flat < 0).sum().item() + ) + finite_min = finite_max = None + if finite_count: + finite_values = flat[finite_mask] + finite_min = float(finite_values.min().item()) + finite_max = float(finite_values.max().item()) + logits_for_argmax = ( + logits_float[:, -1, :] if logits_float.dim() >= 3 else logits_float + ) + argmax = logits_for_argmax.argmax(dim=-1).detach().cpu().reshape(-1) + argmax_values = ( + logits_for_argmax.gather( + dim=-1, + index=logits_for_argmax.argmax(dim=-1, keepdim=True), + ) + .detach() + .cpu() + .reshape(-1) + ) + return ( + f"logits_shape={tuple(logits_tensor.shape)} logits_dtype={logits_tensor.dtype} " + f"finite={finite_count}/{int(flat.numel())} nan={nan_count} " + f"posinf={posinf_count} neginf={neginf_count} " + f"finite_min={finite_min} finite_max={finite_max} " + f"argmax={argmax[:4].tolist()} " + f"argmax_values={[float(item) for item in argmax_values[:4].tolist()]}" + ) + except Exception as exc: # pragma: no cover - diagnostic only + return f"logits_summary_error={type(exc).__name__}: {exc}" + + +def _mask_incomplete_prefill_sampled_tokens( + sampler_output: Any, + prefill_completion_state: Any, + *, + vocab_size: int | None = None, + stage: str = "sample", + logits_source: Any = None, +) -> Any: + values = _prefill_completion_state_values(prefill_completion_state) + if not values: + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + print( + "[hybrid_apc_debug] sample-mask skip " + f"prefill_completion_state={values}", + flush=True, + ) + return sampler_output + + sampled_token_ids = getattr(sampler_output, "sampled_token_ids", None) + if sampled_token_ids is None or not hasattr(sampled_token_ids, "clone"): + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + print( + "[hybrid_apc_debug] sample-mask missing-sampled-token-ids " + f"prefill_completion_state={values} " + f"sampler_output_type={type(sampler_output).__name__}", + flush=True, + ) + return sampler_output + shape = getattr(sampled_token_ids, "shape", ()) + if not shape: + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + print( + "[hybrid_apc_debug] sample-mask scalar-sampled-token-ids " + f"prefill_completion_state={values}", + flush=True, + ) + return sampler_output + row_count = min(len(values), int(shape[0])) + if row_count <= 0: + return sampler_output + + row_values = _prefill_state_for_output_rows(values, int(shape[0])) + row_count = min(len(row_values), int(shape[0])) + fallback_token_ids = None + masked_token_ids = None + repaired_completed_rows: list[dict[str, Any]] = [] + for row_idx, is_done in enumerate(row_values[:row_count]): + if not is_done: + if masked_token_ids is None: + masked_token_ids = sampled_token_ids.clone() + masked_token_ids[row_idx] = -1 + continue + + row = sampled_token_ids[row_idx].reshape(-1) + invalid_id, reason = _sampled_token_invalid_id_and_reason( + row, + vocab_size=vocab_size, + ) + if invalid_id is None: + continue + + if fallback_token_ids is None: + fallback_token_ids = _logits_argmax_token_ids_for_sample_shape( + logits_source, + sampled_token_ids, + vocab_size=vocab_size, + ) + if ( + fallback_token_ids is None + or not hasattr(fallback_token_ids, "shape") + or int(fallback_token_ids.shape[0]) <= row_idx + ): + raise ValueError( + "Qwen3.6 completed prefill sampled token is invalid and logits " + "are unavailable for host fallback. Compile the artifact with " + "--output-logits-with-on-device-sampling from a build that " + "gathers vocab-parallel output logits, or use " + "--disable-on-device-sampling for host sampling. " + f"{reason}; stage={stage}; row={row_idx}; " + f"token_id={_format_token_id(invalid_id)}; " + f"prefill_completion_state={values}; " + f"effective_output_state={row_values}; " + f"sampled_shape={tuple(sampled_token_ids.shape)}" + ) + + if masked_token_ids is None: + masked_token_ids = sampled_token_ids.clone() + masked_token_ids[row_idx] = fallback_token_ids[row_idx] + repaired_completed_rows.append( + { + "row": row_idx, + "reason": reason, + "token_id": invalid_id, + "fallback": int(fallback_token_ids[row_idx].reshape(-1)[0].item()), + "logits_summary": _summarize_logits_for_fallback(logits_source), + } + ) + + if masked_token_ids is None: + _validate_completed_prefill_sampled_tokens( + sampled_token_ids, + values, + vocab_size=vocab_size, + stage=stage, + ) + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + print( + "[hybrid_apc_debug] sample-mask skip " + f"prefill_completion_state={values}", + flush=True, + ) + return sampler_output + + _validate_completed_prefill_sampled_tokens( + masked_token_ids, + values, + vocab_size=vocab_size, + stage=stage, + ) + try: + sampler_output.sampled_token_ids = masked_token_ids + except Exception: + return sampler_output + for row in repaired_completed_rows: + logger.warning( + "Replacing invalid completed-prefill sampled token with logits " + "argmax before vLLM output update: stage=%s row=%s %s " + "token_id=%s fallback_token_id=%s prefill_completion_state=%s", + stage, + row["row"], + row["reason"], + _format_token_id(int(row["token_id"])), + row["fallback"], + values, + ) + logger.warning( + "Qwen3.6 fallback logits summary: stage=%s row=%s %s", + stage, + row["row"], + row["logits_summary"], + ) + if _env_flag("QWEN36_HYBRID_APC_DEBUG"): + try: + before = sampled_token_ids.detach().cpu().reshape(-1).tolist() + after = masked_token_ids.detach().cpu().reshape(-1).tolist() + except Exception: + before = "unavailable" + after = "unavailable" + print( + "[hybrid_apc_debug] sample-mask applied " + f"prefill_completion_state={values} before={before} after={after}", + flush=True, + ) + return sampler_output + + +def _shape_of(value: Any) -> list[int] | None: + shape = getattr(value, "shape", None) + if shape is None: + return None + return [int(item) for item in shape] + + +def _flatten_int_sample(value: Any, *, limit: int = 8) -> list[int] | None: + detach = getattr(value, "detach", None) + if detach is None: + return None + try: + tensor = detach().cpu().reshape(-1) + return [int(item) for item in tensor[:limit].tolist()] + except Exception: + return None + + +def _is_tensor_like(value: Any) -> bool: + return hasattr(value, "detach") and hasattr(value, "shape") + + +def _first_tensor_like(value: Any) -> Any: + if _is_tensor_like(value): + return value + if isinstance(value, (list, tuple)): + for item in value: + found = _first_tensor_like(item) + if found is not None: + return found + return None + + +def _describe_sample_logits_value(value: Any, *, depth: int = 0) -> dict[str, Any]: + row: dict[str, Any] = {"type": type(value).__name__} + shape = _shape_of(value) + if shape is not None: + row["shape"] = shape + if isinstance(value, (list, tuple)): + row["len"] = len(value) + if depth < 3: + row["items"] = [ + _describe_sample_logits_value(item, depth=depth + 1) + for item in value[:4] + ] + return row + + +def _split_sample_logits_output(value: Any) -> tuple[Any, Any, str]: + """Return token IDs and logits from sample+logits debug model outputs.""" + + tokens = getattr(value, "tokens", None) + logits = getattr(value, "logits", None) + if tokens is not None or logits is not None: + return tokens, logits, type(value).__name__ + + if isinstance(value, (list, tuple)): + if len(value) == 1: + nested_tokens, nested_logits, nested_kind = _split_sample_logits_output( + value[0] + ) + return nested_tokens, nested_logits, f"{type(value).__name__}[{nested_kind}]" + if len(value) >= 2: + return value[0], value[1], type(value).__name__ + + return value, None, type(value).__name__ + + +def _json_float_value(value: float) -> float | str: + if value != value: + return "nan" + if value == float("inf"): + return "inf" + if value == float("-inf"): + return "-inf" + return float(value) + + +def _log_sample_logits_comparison( + hidden_states: Any, + model_input: Any, + sampler_output: Any, +) -> None: + """Debug-only compare traced sampled tokens with returned logits argmax.""" + + path = os.environ.get("QWEN36_SAMPLE_LOGITS_COMPARE_JSONL") + if not path: + return + try: + tokens, logits, hidden_state_kind = _split_sample_logits_output(hidden_states) + token_tensor = _first_tensor_like(tokens) + logits_tensor = _first_tensor_like(logits) + row: dict[str, Any] = { + "request_ids": list(getattr(model_input, "request_ids", ()) or ()), + "hidden_state_type": hidden_state_kind, + "hidden_state_structure": _describe_sample_logits_value(hidden_states), + "tokens_shape": _shape_of(token_tensor), + "logits_shape": _shape_of(logits_tensor), + "sampler_output_type": type(sampler_output).__name__, + } + if token_tensor is not None and logits_tensor is not None: + row["sampled_dtype"] = str(token_tensor.dtype) + row["logits_dtype"] = str(logits_tensor.dtype) + logits_tensor = logits_tensor.detach().float() + flat_logits = logits_tensor.reshape(-1) + finite_mask = torch.isfinite(flat_logits) + finite_count = int(finite_mask.sum().item()) + row.update( + { + "logits_numel": int(flat_logits.numel()), + "logits_finite": finite_count, + "logits_nan": int(torch.isnan(flat_logits).sum().item()), + "logits_posinf": int( + torch.logical_and(torch.isinf(flat_logits), flat_logits > 0) + .sum() + .item() + ), + "logits_neginf": int( + torch.logical_and(torch.isinf(flat_logits), flat_logits < 0) + .sum() + .item() + ), + } + ) + if finite_count: + finite_flat = flat_logits[finite_mask] + row["logits_finite_min"] = float(finite_flat.min().item()) + row["logits_finite_max"] = float(finite_flat.max().item()) + else: + row["logits_finite_min"] = None + row["logits_finite_max"] = None + logits_for_argmax = ( + logits_tensor[:, -1, :] + if logits_tensor.dim() >= 3 + else logits_tensor + ) + argmax_tokens = logits_for_argmax.argmax(dim=-1).detach().cpu().reshape(-1) + argmax_values = ( + logits_for_argmax.gather( + dim=-1, + index=logits_for_argmax.argmax(dim=-1, keepdim=True), + ) + .detach() + .cpu() + .reshape(-1) + ) + sampled_tokens = token_tensor.detach().cpu().reshape(-1) + count = min(int(argmax_tokens.numel()), int(sampled_tokens.numel())) + row.update( + { + "sampled_tokens": [ + int(item) for item in sampled_tokens[: min(count, 8)].tolist() + ], + "logits_argmax_tokens": [ + int(item) for item in argmax_tokens[: min(count, 8)].tolist() + ], + "logits_argmax_values": [ + _json_float_value(float(item)) + for item in argmax_values[: min(count, 8)].tolist() + ], + "num_compared": count, + "num_matches": int( + (sampled_tokens[:count] == argmax_tokens[:count]).sum().item() + ) + if count + else 0, + } + ) + else: + row["sampled_tokens"] = _flatten_int_sample(token_tensor) + + with open(path, "a", encoding="utf-8") as handle: + handle.write(json.dumps(row, sort_keys=True) + "\n") + except Exception as exc: + try: + with open(path, "a", encoding="utf-8") as handle: + handle.write( + json.dumps( + { + "error": f"{type(exc).__name__}: {exc}", + "hidden_state_type": type(hidden_states).__name__, + "stage": "sample_logits_compare", + }, + sort_keys=True, + ) + + "\n" + ) + except Exception: + return + + +def _log_sample_logits_split_error( + hidden_states: Any, + model_input: Any, + exc: BaseException, +) -> None: + path = os.environ.get("QWEN36_SAMPLE_LOGITS_COMPARE_JSONL") + if not path: + return + try: + tokens, logits, hidden_state_kind = _split_sample_logits_output(hidden_states) + token_tensor = _first_tensor_like(tokens) + logits_tensor = _first_tensor_like(logits) + row = { + "error": f"{type(exc).__name__}: {exc}", + "hidden_state_type": hidden_state_kind, + "hidden_state_structure": _describe_sample_logits_value(hidden_states), + "logits_shape": _shape_of(logits_tensor), + "request_ids": list(getattr(model_input, "request_ids", ()) or ()), + "stage": "sample_on_device", + "tokens_shape": _shape_of(token_tensor), + } + with open(path, "a", encoding="utf-8") as handle: + handle.write(json.dumps(row, sort_keys=True) + "\n") + except Exception: + return + + +def patch_neuron_model_runner_class(runner_cls: type) -> bool: + """Patch vLLM-Neuron runner to expose scheduler row metadata.""" + + original_execute = getattr(runner_cls, "_execute_model_for_text", None) + if original_execute is None: + raise AttributeError( + f"{runner_cls!r} has no _execute_model_for_text method" + ) + original_prepare = getattr(runner_cls, "_prepare_model_input", None) + original_prepare_logits = getattr( + runner_cls, + "_prepare_logits_for_sampling", + None, + ) + original_sample_on_device = getattr( + runner_cls, + "_sample_on_device", + None, + ) + original_get_kv_cache_spec = getattr( + runner_cls, + "get_kv_cache_spec", + None, + ) + original_sample_tokens = getattr(runner_cls, "sample_tokens", None) + original_generate_output = getattr( + runner_cls, + "_generate_model_runner_output", + None, + ) + + missing = object() + installed = False + + if original_prepare is not None and not getattr( + original_prepare, + "_qwen36_hybrid_apc_model_input_patched", + False, + ): + + def prepare_model_input_with_hybrid_apc_metadata( + self, + scheduler_output, + *args, + **kwargs, + ): + model_input = original_prepare(self, scheduler_output, *args, **kwargs) + object.__setattr__( + model_input, + "_qwen36_cached_request_ids", + _request_ids_from_scheduler_output( + scheduler_output, + kind="cached", + ), + ) + object.__setattr__( + model_input, + "_qwen36_new_request_ids", + _request_ids_from_scheduler_output( + scheduler_output, + kind="new", + ), + ) + metadata_by_request_id = getattr( + scheduler_output, + _SCHEDULER_OUTPUT_METADATA_ATTR, + None, + ) + if metadata_by_request_id is not None: + object.__setattr__( + model_input, + _SCHEDULER_OUTPUT_METADATA_ATTR, + metadata_by_request_id, + ) + request_records = _hybrid_apc_request_records_from_model_input( + model_input, + scheduler_output, + ) + if request_records is not None: + object.__setattr__( + model_input, + _SCHEDULER_OUTPUT_REQUEST_RECORDS_ATTR, + request_records, + ) + return model_input + + prepare_model_input_with_hybrid_apc_metadata._qwen36_hybrid_apc_model_input_patched = ( + True + ) + prepare_model_input_with_hybrid_apc_metadata._qwen36_original_prepare_model_input = ( + original_prepare + ) + runner_cls._prepare_model_input = prepare_model_input_with_hybrid_apc_metadata + installed = True + + if original_prepare_logits is not None and not getattr( + original_prepare_logits, + "_qwen36_vllm_logits_debug_patched", + False, + ): + + def prepare_logits_for_sampling_with_debug( + self, + hidden_states, + model_input, + *args, + **kwargs, + ): + if _env_flag("QWEN36_VLLM_LOGITS_DEBUG"): + _debug_logits_tensor("runner_hidden_states_before_prepare", hidden_states) + prefill_state = getattr(model_input, "prefill_completion_state", None) + request_ids = getattr(model_input, "request_ids", None) + print( + "[qwen36_vllm_logits_debug] " + f"request_ids={request_ids} prefill_completion_state={prefill_state}", + flush=True, + ) + hidden_states = _expand_completed_prefill_logits(hidden_states, model_input) + _debug_logits_tensor( + "runner_hidden_states_after_prefill_expand", + hidden_states, + ) + logits = original_prepare_logits( + self, + hidden_states, + model_input, + *args, + **kwargs, + ) + _debug_logits_tensor("runner_logits_after_prepare", logits) + return logits + + prepare_logits_for_sampling_with_debug._qwen36_vllm_logits_debug_patched = ( + True + ) + prepare_logits_for_sampling_with_debug._qwen36_original_prepare_logits = ( + original_prepare_logits + ) + runner_cls._prepare_logits_for_sampling = prepare_logits_for_sampling_with_debug + installed = True + + if original_sample_on_device is not None and not getattr( + original_sample_on_device, + "_qwen36_clone_incomplete_prefill_tokens_patched", + False, + ): + + def sample_on_device_with_incomplete_prefill_clone( + self, + hidden_states, + model_input, + *args, + **kwargs, + ): + prefill_state = getattr(model_input, "prefill_completion_state", None) + if _prefill_completion_has_incomplete_row(prefill_state): + clone = getattr(hidden_states, "clone", None) + if clone is not None: + hidden_states = clone() + hidden_states_for_sampling, logits_for_fallback, _ = _split_sample_logits_output( + hidden_states + ) + token_tensor_for_sampling = _first_tensor_like(hidden_states_for_sampling) + if token_tensor_for_sampling is not None: + clone = getattr(token_tensor_for_sampling, "clone", None) + if clone is not None: + token_tensor_for_sampling = clone() + hidden_states_for_sampling = token_tensor_for_sampling + try: + sampler_output = original_sample_on_device( + self, + hidden_states_for_sampling, + model_input, + *args, + **kwargs, + ) + except Exception as exc: + _log_sample_logits_split_error(hidden_states, model_input, exc) + raise + sampler_output = _mask_incomplete_prefill_sampled_tokens( + sampler_output, + prefill_state, + vocab_size=_runner_vocab_size(self), + stage="sample_on_device", + logits_source=logits_for_fallback, + ) + _log_sample_logits_comparison(hidden_states, model_input, sampler_output) + return sampler_output + + sample_on_device_with_incomplete_prefill_clone._qwen36_clone_incomplete_prefill_tokens_patched = ( + True + ) + sample_on_device_with_incomplete_prefill_clone._qwen36_original_sample_on_device = ( + original_sample_on_device + ) + runner_cls._sample_on_device = sample_on_device_with_incomplete_prefill_clone + installed = True + + if original_get_kv_cache_spec is not None and not getattr( + original_get_kv_cache_spec, + "_qwen36_hybrid_kv_cache_spec_patched", + False, + ): + + def get_kv_cache_spec_with_qwen36_hybrid_layers(self): + model_config = getattr(self, "model_config", None) + hf_config = getattr(model_config, "hf_config", None) + num_layers = _num_layers_from_hf_config( + hf_config, + original_get_kv_cache_spec, + ) + if num_layers is None: + return original_get_kv_cache_spec(self) + attention_layer_indices = _hybrid_kv_attention_layer_indices( + hf_config, + num_layers, + ) + if attention_layer_indices is None: + return original_get_kv_cache_spec(self) + full_attention_spec_cls = _full_attention_spec_class( + original_get_kv_cache_spec + ) + if full_attention_spec_cls is None: + return original_get_kv_cache_spec(self) + + parallel_config = getattr(self, "parallel_config", None) + model = getattr(self, "model", None) + get_sliding_window = getattr(model_config, "get_sliding_window", None) + sliding_window = ( + get_sliding_window() if callable(get_sliding_window) else None + ) + local_kv_heads = _local_num_kv_heads(hf_config, parallel_config) + kv_cache_spec = {} + for layer_idx in attention_layer_indices: + layer_name = f"layers.{layer_idx}.self_attn" + kv_cache_spec[layer_name] = full_attention_spec_cls( + block_size=getattr(self, "block_size"), + num_kv_heads=local_kv_heads, + head_size=getattr(model, "head_dim"), + dtype=getattr(model_config, "dtype"), + sliding_window=sliding_window, + ) + logger.info( + "Using Qwen hybrid KV-cache spec for %d/%d attention layers " + "with %d local KV heads", + len(attention_layer_indices), + num_layers, + local_kv_heads, + ) + return kv_cache_spec + + get_kv_cache_spec_with_qwen36_hybrid_layers._qwen36_hybrid_kv_cache_spec_patched = ( + True + ) + get_kv_cache_spec_with_qwen36_hybrid_layers._qwen36_original_get_kv_cache_spec = ( + original_get_kv_cache_spec + ) + runner_cls.get_kv_cache_spec = get_kv_cache_spec_with_qwen36_hybrid_layers + installed = True + + if original_generate_output is not None and not getattr( + original_generate_output, + "_qwen36_mask_incomplete_prefill_output_patched", + False, + ): + + def generate_model_runner_output_with_prefill_mask( + self, + sampler_outputs, + *args, + **kwargs, + ): + prefill_state = getattr( + self, + _RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR, + None, + ) + if prefill_state is not None: + sampler_outputs = _mask_incomplete_prefill_sampled_tokens( + sampler_outputs, + prefill_state, + vocab_size=_runner_vocab_size(self), + stage="generate_model_runner_output", + ) + return original_generate_output(self, sampler_outputs, *args, **kwargs) + + generate_model_runner_output_with_prefill_mask._qwen36_mask_incomplete_prefill_output_patched = ( + True + ) + generate_model_runner_output_with_prefill_mask._qwen36_original_generate_model_runner_output = ( + original_generate_output + ) + runner_cls._generate_model_runner_output = ( + generate_model_runner_output_with_prefill_mask + ) + installed = True + + if original_sample_tokens is not None and not getattr( + original_sample_tokens, + "_qwen36_capture_prefill_state_for_output_patched", + False, + ): + + def sample_tokens_with_prefill_state_for_output(self, *args, **kwargs): + model_input = getattr(self, "_cached_model_input", None) + if getattr(self, "_cached_logits", None) is None: + return None + prefill_state = getattr(model_input, "prefill_completion_state", None) + previous_value = getattr( + self, + _RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR, + missing, + ) + setattr(self, _RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR, prefill_state) + try: + return original_sample_tokens(self, *args, **kwargs) + finally: + if previous_value is missing: + try: + delattr(self, _RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR) + except AttributeError: + pass + else: + setattr( + self, + _RUNNER_PREFILL_STATE_FOR_OUTPUT_ATTR, + previous_value, + ) + + sample_tokens_with_prefill_state_for_output._qwen36_capture_prefill_state_for_output_patched = ( + True + ) + sample_tokens_with_prefill_state_for_output._qwen36_original_sample_tokens = ( + original_sample_tokens + ) + runner_cls.sample_tokens = sample_tokens_with_prefill_state_for_output + installed = True + + if getattr(original_execute, "_qwen36_hybrid_apc_request_ids_patched", False): + return installed + + def execute_model_for_text_with_request_ids(self, model_input, *args, **kwargs): + model = getattr(self, "model", None) + runtime_config = _runner_hybrid_apc_runtime_config(self) + request_ids = _request_ids_from_model_input(model_input) + if request_ids is None: + request_ids = tuple( + getattr(model_input, "_qwen36_cached_request_ids", ()) or () + ) + tuple(getattr(model_input, "_qwen36_new_request_ids", ()) or ()) + if not request_ids: + request_ids = None + metadata = { + "_qwen36_vllm_request_ids": request_ids, + "_qwen36_vllm_cached_request_ids": getattr( + model_input, + "_qwen36_cached_request_ids", + None, + ), + "_qwen36_vllm_new_request_ids": getattr( + model_input, + "_qwen36_new_request_ids", + None, + ), + "_qwen36_vllm_prefill_completion_state": getattr( + model_input, + "prefill_completion_state", + None, + ), + "_qwen36_vllm_hybrid_apc_metadata_by_request_id": getattr( + model_input, + _SCHEDULER_OUTPUT_METADATA_ATTR, + None, + ), + "_qwen36_vllm_hybrid_apc_request_records": getattr( + model_input, + _SCHEDULER_OUTPUT_REQUEST_RECORDS_ATTR, + None, + ), + } + previous_values = [] + _apply_hybrid_apc_runtime_config( + model, + runtime_config, + previous_values=previous_values, + missing=missing, + ) + if any(value is not None for value in metadata.values()): + for target in _request_id_target_models(model): + for attr, value in metadata.items(): + if value is None: + continue + previous_values.append( + ( + target, + attr, + getattr(target, attr, missing), + ) + ) + setattr(target, attr, value) + try: + return original_execute(self, model_input, *args, **kwargs) + finally: + for target, attr, previous_value in reversed(previous_values): + if previous_value is missing: + try: + delattr(target, attr) + except AttributeError: + pass + else: + setattr(target, attr, previous_value) + + execute_model_for_text_with_request_ids._qwen36_hybrid_apc_request_ids_patched = ( + True + ) + execute_model_for_text_with_request_ids._qwen36_original_execute_model_for_text = ( + original_execute + ) + runner_cls._execute_model_for_text = execute_model_for_text_with_request_ids + return True + + +def _patch_neuron_runner_module(module: Any) -> bool: + runner_cls = getattr(module, "NeuronxDistributedModelRunner", None) + if runner_cls is None: + return False + installed = patch_neuron_model_runner_class(runner_cls) + if installed: + logger.info("Installed Qwen Hybrid APC vLLM-Neuron runner patch") + return installed + + +def _restore_nested_output(restore, value: Any) -> Any: + if hasattr(value, "shape"): + return restore(value) + if isinstance(value, list): + return [_restore_nested_output(restore, item) for item in value] + if isinstance(value, tuple): + return tuple(_restore_nested_output(restore, item) for item in value) + return value + + +def _patch_neuron_loader_module(module: Any) -> bool: + causal_lm_cls = getattr(module, "NeuronCausalLM", None) + if causal_lm_cls is None: + return False + original_forward = getattr(causal_lm_cls, "forward", None) + if original_forward is None or getattr( + original_forward, + "_qwen36_sample_logits_tokens_patched", + False, + ): + return False + + def forward_with_sample_logits_tokens( + self, + input_ids, + input_block_ids, + **kwargs, + ): + import time as _time # noqa: WPS433 + + forward_start = _time.perf_counter() + batch_size = ( + input_ids.shape[0] + if hasattr(input_ids, "shape") + else len(input_ids) + ) + + with self._reordered(input_block_ids, input_ids=input_ids, **kwargs) as ( + sorted_ids, + inputs, + restore, + ): + model_start = _time.perf_counter() + output = self.model( + inputs["input_ids"], + attention_mask=None, + seq_ids=sorted_ids, + block_table=inputs["block_tables"], + **{ + key: value + for key, value in inputs.items() + if key + not in ["input_ids", "block_tables", "prefill_completion_state"] + }, + ) + model_elapsed = (_time.perf_counter() - model_start) * 1000 + module.logger.debug("[PERF] model_execution: %.2fms", model_elapsed) + + output_proc_start = _time.perf_counter() + if self.model.config.neuron_config.on_device_sampling_config: + tokens = getattr(output, "tokens", None) + logits = getattr(output, "logits", None) + if tokens is not None and logits is not None: + output = [tokens, logits] + else: + output = output.hidden_states + if getattr( + self.model.config.neuron_config, + "enable_fused_speculation", + False, + ): + fused = output + output = self._remask_fused_spec_output(fused, inputs) + else: + if self.neuron_config.is_chunked_prefill: + assert kwargs.get("prefill_completion_state") is not None + idx_for_sampling = ( + kwargs["prefill_completion_state"].nonzero().flatten() + ) + output = output.logits[0, idx_for_sampling, :] + else: + output = output.logits[:, -1, :] + output_proc_elapsed = (_time.perf_counter() - output_proc_start) * 1000 + module.logger.debug( + "[PERF] output_processing: %.2fms", + output_proc_elapsed, + ) + + restore_start = _time.perf_counter() + result = _restore_nested_output(restore, output) + restore_elapsed = (_time.perf_counter() - restore_start) * 1000 + module.logger.debug("[PERF] restore: %.2fms", restore_elapsed) + + forward_elapsed = (_time.perf_counter() - forward_start) * 1000 + module.logger.debug( + "[PERF] forward() total: %.2fms [batch=%d]", + forward_elapsed, + batch_size, + ) + return result + + forward_with_sample_logits_tokens._qwen36_sample_logits_tokens_patched = True + forward_with_sample_logits_tokens._qwen36_original_forward = original_forward + causal_lm_cls.forward = forward_with_sample_logits_tokens + logger.info("Installed Qwen sample+logits vLLM-Neuron loader patch") + return True + + +def _patch_module(module_name: str, module: Any) -> bool: + if module_name == _SCHEDULER_MODULE: + return _patch_scheduler_module(module) + if module_name == _KV_CACHE_MANAGER_MODULE: + return _patch_kv_cache_manager_module(module) + if module_name == _VLLM_NEURON_RUNNER_MODULE: + return _patch_neuron_runner_module(module) + if module_name == _VLLM_NEURON_LOADER_MODULE: + return _patch_neuron_loader_module(module) + return False + + +class _HybridAPCSchedulerPatchLoader(importlib.abc.Loader): + _qwen36_hybrid_apc_loader = True + + def __init__(self, wrapped_loader: importlib.abc.Loader): + self.wrapped_loader = wrapped_loader + + def create_module(self, spec): + create_module = getattr(self.wrapped_loader, "create_module", None) + if create_module is None: + return None + return create_module(spec) + + def exec_module(self, module): + self.wrapped_loader.exec_module(module) + _patch_module(module.__name__, module) + + +class _HybridAPCSchedulerPatchFinder(importlib.abc.MetaPathFinder): + _qwen36_hybrid_apc_import_hook = True + + def find_spec(self, fullname, path, target=None): + if fullname not in _PATCHED_MODULES: + return None + spec = importlib.machinery.PathFinder.find_spec(fullname, path) + if spec is None or spec.loader is None: + return spec + if getattr(spec.loader, "_qwen36_hybrid_apc_loader", False): + return spec + spec.loader = _HybridAPCSchedulerPatchLoader(spec.loader) + return spec + + +def install_import_hook() -> bool: + """Patch vLLM components lazily, without importing vLLM at Python startup.""" + + installed = False + for module_name in _PATCHED_MODULES: + module = sys.modules.get(module_name) + if module is not None: + installed = _patch_module(module_name, module) or installed + for finder in sys.meta_path: + if getattr(finder, "_qwen36_hybrid_apc_import_hook", False): + return installed + sys.meta_path.insert(0, _HybridAPCSchedulerPatchFinder()) + return installed + + +def install() -> bool: + """Install the vLLM scheduler patch when vLLM is available.""" + + from vllm.v1.core.sched.scheduler import Scheduler # noqa: WPS433 + + installed = False + module = sys.modules.get(_SCHEDULER_MODULE) + if module is not None: + installed = _patch_scheduler_module(module) + else: + installed = patch_scheduler_class(Scheduler) + kv_cache_manager_module = sys.modules.get(_KV_CACHE_MANAGER_MODULE) + if kv_cache_manager_module is not None: + installed = _patch_kv_cache_manager_module(kv_cache_manager_module) or installed + runner_module = sys.modules.get(_VLLM_NEURON_RUNNER_MODULE) + if runner_module is not None: + installed = _patch_neuron_runner_module(runner_module) or installed + loader_module = sys.modules.get(_VLLM_NEURON_LOADER_MODULE) + if loader_module is not None: + installed = _patch_neuron_loader_module(loader_module) or installed + if installed: + logger.info("Installed Qwen Hybrid APC scheduler fallback patch") + return installed diff --git a/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py b/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py new file mode 100644 index 00000000..f4b7e706 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py @@ -0,0 +1,586 @@ +#!/usr/bin/env python3 +"""Offline vLLM smoke runner for Qwen3.6-27B on Neuron.""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from pathlib import Path + + +_FP8_ENV_DEFAULTS = { + "XLA_HANDLE_SPECIAL_SCALAR": "1", + "UNSAFE_FP8FNCAST": "1", +} + + +def _ensure_fp8_environment() -> None: + for name, value in _FP8_ENV_DEFAULTS.items(): + os.environ.setdefault(name, value) + + +def _contrib_root(repo_root: str | None) -> Path: + if repo_root: + return Path(repo_root).expanduser().resolve() / "contrib" / "models" / "Qwen3.6-27B" + return Path(__file__).resolve().parents[1] + + +def _parse_int_list(values: list[str] | None) -> list[int] | None: + if values is None: + return None + tokens: list[str] = [] + for value in values: + tokens.extend(value.replace(",", " ").split()) + return [int(token) for token in tokens] + + +def _parse_bucket_pairs(values: list[str] | None) -> list[list[int]] | None: + if values is None: + return None + pairs: set[tuple[int, int]] = set() + for value in values: + for token in value.replace(",", " ").split(): + if ":" in token: + active, prefix = token.split(":", 1) + elif "x" in token: + active, prefix = token.split("x", 1) + else: + raise ValueError( + "--context-encoding-bucket-pairs entries must use " + f"ACTIVE:PREFIX syntax, got {token!r}" + ) + active_tokens, prefix_tokens = int(active), int(prefix) + if active_tokens <= 0 or prefix_tokens < 0: + raise ValueError( + "Context-encoding bucket pairs must be positive active " + f"tokens and non-negative prefix tokens, got {token!r}" + ) + pairs.add((active_tokens, prefix_tokens)) + return [[active, prefix] for active, prefix in sorted(pairs)] + + +def _cte_buckets(args: argparse.Namespace) -> list[int]: + profile_buckets = { + "short": [128, 256, 512, 1024], + "general": [256, 512, 1024, 2048], + "long": [4096, 8192, 16384, 32768], + "262k": [256], + } + if args.cte_bucket_profile != "single": + buckets = list(profile_buckets[args.cte_bucket_profile]) + else: + buckets = _parse_int_list(args.cte_buckets) or [args.cte_bucket] + buckets = sorted(set(buckets)) + if not buckets: + raise ValueError("At least one CTE bucket is required") + for bucket in buckets: + if bucket <= 0: + raise ValueError(f"CTE buckets must be positive, got {bucket}") + if bucket % 128 != 0: + raise ValueError( + f"CTE bucket {bucket} is not 128-aligned; DeltaNet CTE uses 128-token chunks" + ) + if buckets[-1] > args.seq_len: + raise ValueError( + f"Largest CTE bucket {buckets[-1]} exceeds --seq-len {args.seq_len}" + ) + return buckets + + +def _token_generation_buckets(args: argparse.Namespace) -> list[int]: + buckets = _parse_int_list(args.token_generation_buckets) or [args.seq_len] + buckets = sorted(set(buckets)) + if not buckets: + raise ValueError("At least one token-generation bucket is required") + for bucket in buckets: + if bucket <= 0: + raise ValueError( + f"Token-generation buckets must be positive, got {bucket}" + ) + if bucket > args.seq_len: + raise ValueError( + f"Token-generation bucket {bucket} exceeds --seq-len {args.seq_len}" + ) + return buckets + + +def _token_generation_batches(args: argparse.Namespace) -> list[int] | None: + batches = _parse_int_list(args.token_generation_batches) + if batches is None: + return None + batches = sorted(set(batches)) + if not batches: + raise ValueError("Token-generation batches cannot be empty") + for batch in batches: + if batch <= 0: + raise ValueError( + f"Token-generation batches must be positive, got {batch}" + ) + if batch > args.max_num_seqs: + raise ValueError( + f"Token-generation batch {batch} exceeds --max-num-seqs " + f"{args.max_num_seqs}" + ) + return batches + + +def _validate_hybrid_apc_args(args: argparse.Namespace): + if not args.enable_hybrid_apc: + return + if args.hybrid_cache_mode != "all": + raise ValueError("--enable-hybrid-apc requires --hybrid-cache-mode all") + if args.gdn_checkpoint_interval != args.block_size: + raise ValueError( + "--enable-hybrid-apc v0 requires --gdn-checkpoint-interval " + "to equal --block-size" + ) + args.enable_prefix_caching = True + + +def _max_num_batched_tokens(args: argparse.Namespace, cte_buckets: list[int]) -> int: + max_cte_bucket = cte_buckets[-1] + if not args.enable_vllm_chunked_prefill: + return max_cte_bucket + if not args.enable_hybrid_apc: + return max_cte_bucket + + checkpoint_interval = int(args.gdn_checkpoint_interval) + checkpoint_aligned_buckets = [ + bucket for bucket in cte_buckets if bucket % checkpoint_interval == 0 + ] + if not checkpoint_aligned_buckets: + raise ValueError( + "--enable-hybrid-apc with vLLM chunked prefill requires at least " + "one compiled CTE bucket that is a multiple of " + f"--gdn-checkpoint-interval ({checkpoint_interval}); got " + f"{cte_buckets}" + ) + requested_chunk = int(getattr(args, "hybrid_apc_prefill_chunk_tokens", 0) or 0) + if requested_chunk <= 0: + return min(max_cte_bucket, checkpoint_aligned_buckets[-1]) + if requested_chunk % checkpoint_interval != 0: + raise ValueError( + "--hybrid-apc-prefill-chunk-tokens must be a multiple of " + f"--gdn-checkpoint-interval ({checkpoint_interval}), got {requested_chunk}" + ) + if requested_chunk not in cte_buckets: + raise ValueError( + "--hybrid-apc-prefill-chunk-tokens must match a compiled CTE bucket, " + f"got {requested_chunk} with buckets {cte_buckets}" + ) + return min(max_cte_bucket, requested_chunk) + + +def _effective_prefill_group_size( + args: argparse.Namespace, + cte_buckets: list[int], +) -> int: + if args.enable_vllm_chunked_prefill: + return _max_num_batched_tokens(args, cte_buckets) + return cte_buckets[-1] + + +def _pa_num_blocks(args: argparse.Namespace) -> int: + num_gpu_blocks_override = getattr(args, "num_gpu_blocks_override", None) + if num_gpu_blocks_override is not None: + return max(1, num_gpu_blocks_override) + return max( + 1, + ((args.seq_len + args.block_size - 1) // args.block_size) + * args.max_num_seqs, + ) + + +def _normalize_cache_dtype(value: str | None, *, default: str = "float32") -> str: + if value is None: + value = default + normalized = str(value).lower() + aliases = { + "fp32": "float32", + "float32": "float32", + "torch.float32": "float32", + "bf16": "bfloat16", + "bfloat16": "bfloat16", + "torch.bfloat16": "bfloat16", + } + if normalized not in aliases: + raise ValueError( + "GDN recurrent cache dtype must be float32 or bfloat16, " + f"got {value}" + ) + return aliases[normalized] + + +def _recurrent_cache_dtype(args: argparse.Namespace) -> str: + dtype = _normalize_cache_dtype( + args.hybrid_gdn_recurrent_cache_dtype or args.gdn_recurrent_cache_dtype, + default="float32", + ) + if args.enable_hybrid_apc and args.hybrid_cache_mode == "all" and dtype != "float32": + raise ValueError( + "Hybrid APC all-mode requires float32 recurrent GDN checkpoint " + "cache state; use --gdn-recurrent-cache-dtype float32" + ) + return dtype + + +def _override_config(args: argparse.Namespace) -> dict: + _validate_hybrid_apc_args(args) + cte_buckets = _cte_buckets(args) + max_cte_bucket = cte_buckets[-1] + prefill_group_size = _effective_prefill_group_size(args, cte_buckets) + context_encoding_bucket_pairs = _parse_bucket_pairs( + args.context_encoding_bucket_pairs + ) + token_generation_buckets = _token_generation_buckets(args) + token_generation_batches = _token_generation_batches(args) + recurrent_cache_dtype = _recurrent_cache_dtype(args) + conv_cache_dtype = args.hybrid_gdn_conv_cache_dtype or args.gdn_conv_cache_dtype + neuron_config = { + "tp_degree": args.tensor_parallel_size, + "batch_size": args.max_num_seqs, + "ctx_batch_size": args.ctx_batch_size, + "tkg_batch_size": args.max_num_seqs, + "seq_len": args.seq_len, + "max_length": args.seq_len, + "max_context_length": max_cte_bucket, + "context_encoding_buckets": cte_buckets, + "token_generation_buckets": token_generation_buckets, + "enable_bucketing": len(cte_buckets) > 1 + or len(token_generation_buckets) > 1, + "logical_nc_config": args.logical_nc_config, + "torch_dtype": "bfloat16", + "save_sharded_checkpoint": True, + "gdn_checkpoint_interval": args.gdn_checkpoint_interval, + "max_gdn_checkpoint_slots": args.max_gdn_checkpoint_slots, + "gdn_recurrent_cache_dtype": recurrent_cache_dtype, + "gdn_conv_cache_dtype": conv_cache_dtype, + "hybrid_recurrent_cache_dtype": recurrent_cache_dtype, + "hybrid_conv_cache_dtype": conv_cache_dtype, + "hybrid_cache_mode": args.hybrid_cache_mode, + } + if args.async_mode: + neuron_config["async_mode"] = True + if token_generation_batches is not None: + neuron_config["token_generation_batches"] = token_generation_batches + if ( + args.enable_prefix_caching + or args.enable_hybrid_apc + or args.enable_vllm_chunked_prefill + ): + neuron_config["is_block_kv_layout"] = True + neuron_config["pa_block_size"] = args.block_size + neuron_config["pa_num_blocks"] = _pa_num_blocks(args) + uses_prefix_cte_contract = context_encoding_bucket_pairs is not None + if args.enable_prefix_caching or args.enable_hybrid_apc or uses_prefix_cte_contract: + neuron_config["is_prefix_caching"] = True + if context_encoding_bucket_pairs is not None: + neuron_config["context_encoding_bucket_pairs"] = ( + context_encoding_bucket_pairs + ) + if args.enable_vllm_chunked_prefill: + neuron_config.update( + { + "chunked_prefill_config": { + "max_num_seqs": args.max_num_seqs, + "tkg_model_enabled": True, + "kernel_q_tile_size": args.kernel_q_tile_size, + "kernel_kv_tile_size": args.kernel_kv_tile_size, + }, + } + ) + return { + "max_prompt_length": max_cte_bucket, + "use_hybrid_apc_manager": args.enable_hybrid_apc, + "use_text_only_cte_inputs": args.text_only_cte, + "use_compact_cte_attention_mask": args.compact_cte_attention_mask, + "use_cold_zero_conv_fast_path": args.cold_zero_conv_fast_path, + "gdn_checkpoint_interval": args.gdn_checkpoint_interval, + "max_gdn_checkpoint_slots": args.max_gdn_checkpoint_slots, + "gdn_recurrent_cache_dtype": recurrent_cache_dtype, + "gdn_conv_cache_dtype": conv_cache_dtype, + "hybrid_recurrent_cache_dtype": recurrent_cache_dtype, + "hybrid_conv_cache_dtype": conv_cache_dtype, + "hybrid_cache_mode": args.hybrid_cache_mode, + "hybrid_cache_prefix_boundary_only": args.hybrid_cache_prefix_boundary_only, + "hybrid_cache_block_boundary_only": args.hybrid_cache_prefix_boundary_only, + "hybrid_cache_validate_exact": args.hybrid_cache_validate_exact, + "hybrid_apc_require_vllm_metadata": args.hybrid_apc_require_vllm_metadata, + "hybrid_apc_allow_local_hash_fallback": not args.hybrid_apc_require_vllm_metadata, + "hybrid_apc_require_attention_block_refs": args.hybrid_apc_require_vllm_metadata, + "hybrid_apc_reject_unbacked_attention_hits": getattr( + args, + "hybrid_apc_reject_unbacked_attention_hits", + True, + ), + "hybrid_apc_disable_unbacked_prefix_reads": getattr( + args, + "hybrid_apc_disable_unbacked_prefix_reads", + False, + ), + "hybrid_apc_enable_backed_prefix_reads": getattr( + args, + "hybrid_apc_enable_backed_prefix_reads", + False, + ), + "hybrid_apc_max_backed_prefix_read_len": getattr( + args, + "hybrid_apc_max_backed_prefix_read_len", + 0, + ), + "hybrid_apc_prefill_chunk_tokens": ( + prefill_group_size + if args.enable_hybrid_apc and args.enable_vllm_chunked_prefill + else 0 + ), + "qwen_prefill_group_size": prefill_group_size, + "use_qwen_hybrid_chunked_prefill": args.enable_vllm_chunked_prefill, + "use_qwen_hybrid_chunked_prefill_nki": args.enable_vllm_chunked_prefill, + "override_neuron_config": neuron_config, + } + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repo-root", default=None) + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-artifacts", default=None) + parser.add_argument("--prompt", default="What is 17 * 23? Answer with the number only.") + parser.add_argument("--chat", action="store_true") + parser.add_argument("--enable-vllm-chunked-prefill", action="store_true") + parser.add_argument("--enable-prefix-caching", action="store_true") + parser.add_argument("--enable-hybrid-apc", action="store_true") + parser.add_argument("--mamba-cache-mode", default=None) + parser.add_argument("--mamba-cache-dtype", default=None) + parser.add_argument("--mamba-ssm-cache-dtype", default=None) + parser.add_argument("--gdn-checkpoint-interval", type=int, default=256) + parser.add_argument("--max-gdn-checkpoint-slots", type=int, default=8) + parser.add_argument("--gdn-recurrent-cache-dtype", default="float32") + parser.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + parser.add_argument("--hybrid-gdn-recurrent-cache-dtype", default=None) + parser.add_argument("--hybrid-gdn-conv-cache-dtype", default=None) + parser.add_argument("--hybrid-cache-mode", default="all") + parser.add_argument( + "--hybrid-cache-prefix-boundary-only", + "--hybrid-cache-block-boundary-only", + dest="hybrid_cache_prefix_boundary_only", + action=argparse.BooleanOptionalAction, + default=True, + ) + parser.add_argument("--hybrid-cache-validate-exact", action="store_true") + parser.add_argument( + "--hybrid-apc-require-vllm-metadata", + action="store_true", + help=( + "Require serving-provided vLLM cumulative prefix hashes and attention " + "block refs instead of the local token-hash validation fallback." + ), + ) + parser.add_argument( + "--hybrid-apc-reject-unbacked-attention-hits", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Reject attention prefix-cache hits that do not have a matching GDN " + "checkpoint. Disable only for controlled plumbing/debug isolation." + ), + ) + parser.add_argument( + "--hybrid-apc-disable-unbacked-prefix-reads", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Make vLLM skip prefix-cache reads for Qwen Hybrid APC until scheduler " + "GDN checkpoint metadata is available." + ), + ) + parser.add_argument( + "--hybrid-apc-enable-backed-prefix-reads", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Allow vLLM prefix-cache reads when both attention KV and GDN " + "checkpoint state are backed by a CTE artifact compiled for that " + "contract." + ), + ) + parser.add_argument( + "--hybrid-apc-max-backed-prefix-read-len", + type=int, + default=0, + help=( + "Optional safety cap for backed prefix reads. Prefix reads above this " + "token length are disabled even when a GDN checkpoint is registered." + ), + ) + parser.add_argument( + "--hybrid-apc-prefill-chunk-tokens", + type=int, + default=0, + help=( + "Opt into larger vLLM chunked-prefill chunks for Hybrid APC. The " + "value must be a compiled CTE bucket and a multiple of " + "--gdn-checkpoint-interval. Default 0 keeps conservative " + "checkpoint-sized chunks." + ), + ) + parser.add_argument("--num-gpu-blocks-override", type=int, default=None) + parser.add_argument("--max-tokens", type=int, default=64) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--max-num-seqs", type=int, default=1) + parser.add_argument("--ctx-batch-size", type=int, default=1) + parser.add_argument("--token-generation-buckets", nargs="+", default=None) + parser.add_argument("--token-generation-batches", nargs="+", default=None) + parser.add_argument("--async-mode", action="store_true") + parser.add_argument("--max-model-len", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--cte-bucket", type=int, default=512) + parser.add_argument("--cte-buckets", nargs="+", default=None) + parser.add_argument("--context-encoding-bucket-pairs", nargs="+", default=None) + parser.add_argument( + "--cte-bucket-profile", + choices=("single", "short", "general", "long", "262k"), + default="single", + ) + parser.add_argument("--block-size", type=int, default=128) + parser.add_argument("--kernel-q-tile-size", type=int, default=128) + parser.add_argument("--kernel-kv-tile-size", type=int, default=1024) + parser.add_argument( + "--text-only-cte", + action=argparse.BooleanOptionalAction, + default=True, + ) + parser.add_argument( + "--compact-cte-attention-mask", + action=argparse.BooleanOptionalAction, + default=True, + ) + parser.add_argument( + "--cold-zero-conv-fast-path", + action=argparse.BooleanOptionalAction, + default=False, + ) + args = parser.parse_args() + + contrib_root = _contrib_root(args.repo_root) + script_dir = Path(__file__).resolve().parent + sys.path.insert(0, str(script_dir)) + sys.path.insert(0, str(contrib_root)) + os.environ["PYTHONPATH"] = ( + f"{script_dir}:{contrib_root}:{os.environ.get('PYTHONPATH', '')}" + ) + os.environ.setdefault("VLLM_NEURON_FRAMEWORK", "neuronx-distributed-inference") + os.environ.setdefault("VLLM_PLUGINS", "neuron") + if args.enable_hybrid_apc: + os.environ.setdefault("QWEN36_HYBRID_APC_INSTALL_PATCH", "1") + if args.enable_vllm_chunked_prefill: + os.environ["DISABLE_NEURON_CUSTOM_SCHEDULER"] = "1" + if args.hybrid_apc_disable_unbacked_prefix_reads: + os.environ["QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS"] = "1" + if args.compiled_artifacts: + os.environ["NEURON_COMPILED_ARTIFACTS"] = str( + Path(args.compiled_artifacts).expanduser().resolve() + ) + _ensure_fp8_environment() + + from hf_qwen35_config import register_qwen35_config # noqa: WPS433 + from qwen36_hybrid_apc_scheduler_patch import ( # noqa: WPS433 + install_import_hook as install_hybrid_apc_scheduler_patch, + ) + + register_qwen35_config() + install_hybrid_apc_scheduler_patch() + + from vllm import LLM, SamplingParams # noqa: WPS433 + + prompt = args.prompt + if args.chat: + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, + trust_remote_code=True, + ) + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": args.prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + + additional_config = _override_config(args) + print("VLLM_QWEN36_CONFIG", json.dumps(additional_config, sort_keys=True), flush=True) + cte_buckets = _cte_buckets(args) + max_cte_bucket = max(cte_buckets) + + llm_kwargs = { + "model": str(Path(args.model_path).expanduser().resolve()), + "trust_remote_code": True, + "dtype": "bfloat16", + "tensor_parallel_size": args.tensor_parallel_size, + "max_num_seqs": args.max_num_seqs, + "max_model_len": args.max_model_len, + "enable_prefix_caching": args.enable_prefix_caching, + "enable_chunked_prefill": args.enable_vllm_chunked_prefill, + "additional_config": additional_config, + } + recurrent_cache_dtype = _recurrent_cache_dtype(args) + if args.enable_prefix_caching or args.enable_hybrid_apc: + llm_kwargs["mamba_cache_mode"] = args.mamba_cache_mode or "all" + llm_kwargs["mamba_ssm_cache_dtype"] = ( + args.mamba_ssm_cache_dtype or recurrent_cache_dtype + ) + elif args.mamba_cache_mode is not None: + llm_kwargs["mamba_cache_mode"] = args.mamba_cache_mode + if args.mamba_cache_dtype is not None: + llm_kwargs["mamba_cache_dtype"] = args.mamba_cache_dtype + if ( + args.mamba_ssm_cache_dtype is not None + and "mamba_ssm_cache_dtype" not in llm_kwargs + ): + llm_kwargs["mamba_ssm_cache_dtype"] = args.mamba_ssm_cache_dtype + if ( + args.num_gpu_blocks_override is not None + or args.enable_prefix_caching + or args.enable_hybrid_apc + or args.enable_vllm_chunked_prefill + ): + llm_kwargs["num_gpu_blocks_override"] = _pa_num_blocks(args) + if ( + args.enable_prefix_caching + or args.enable_hybrid_apc + or args.enable_vllm_chunked_prefill + ): + llm_kwargs["block_size"] = args.block_size + if args.enable_vllm_chunked_prefill: + llm_kwargs["max_num_batched_tokens"] = _effective_prefill_group_size( + args, cte_buckets + ) + llm = LLM(**llm_kwargs) + + sampling = SamplingParams( + temperature=args.temperature, + top_k=args.top_k, + max_tokens=args.max_tokens, + ) + start = time.perf_counter() + outputs = llm.generate([prompt], sampling) + elapsed = time.perf_counter() - start + text = outputs[0].outputs[0].text + token_ids = outputs[0].outputs[0].token_ids + + print("PROMPT", prompt) + print("OUTPUT", text) + print("TOKENS", list(token_ids)) + print("ELAPSED_SECONDS", f"{elapsed:.3f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py b/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py new file mode 100644 index 00000000..f004d6ea --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +"""vLLM CLI wrapper that registers Qwen3.6 aliases before validation.""" + +from __future__ import annotations + +import sys + +from hf_qwen35_config import register_qwen35_config +from qwen36_hybrid_apc_scheduler_patch import ( + install_import_hook as install_hybrid_apc_scheduler_patch, +) + + +def main() -> int: + register_qwen35_config() + install_hybrid_apc_scheduler_patch() + + from vllm.entrypoints.cli.main import main as vllm_main + + sys.argv = ["vllm", "serve", *sys.argv[1:]] + return int(vllm_main() or 0) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py b/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py new file mode 100644 index 00000000..cbddbd83 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py @@ -0,0 +1,23 @@ +"""Auto-register Qwen3.5/Qwen3.6 HF config when this folder is on PYTHONPATH. + +Do not import vLLM here unless explicitly requested through an environment +flag. Neuron helper commands such as libneuronpjrt-path run inside Python +subprocesses and expect clean stdout. +""" + +import os + +from hf_qwen35_config import register_qwen35_hf_config + +register_qwen35_hf_config() + +if any( + os.environ.get(name) + for name in ( + "QWEN36_HYBRID_APC_INSTALL_PATCH", + "QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS", + ) +): + from qwen36_hybrid_apc_scheduler_patch import install_import_hook + + install_import_hook() diff --git a/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh b/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh new file mode 100755 index 00000000..231d0b08 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh @@ -0,0 +1,697 @@ +#!/usr/bin/env bash +set -euo pipefail + +MODEL_PATH="" +COMPILED_ARTIFACTS="" +MAX_MODEL_LEN="512" +SEQ_LEN="512" +CTE_BUCKET="512" +CTE_BUCKETS="" +CTE_BUCKET_PROFILE="single" +CONTEXT_ENCODING_BUCKET_PAIRS="" +TP_DEGREE="4" +LNC="2" +MAX_NUM_SEQS="1" +CTX_BATCH_SIZE="1" +TOKEN_GENERATION_BUCKETS="" +TOKEN_GENERATION_BATCHES="" +ASYNC_MODE="0" +PORT="8000" +HOST="0.0.0.0" +ENABLE_CHUNKED_PREFILL="0" +ENABLE_PREFIX_CACHING="0" +ENABLE_HYBRID_APC="0" +MAMBA_CACHE_MODE="" +MAMBA_CACHE_DTYPE="" +MAMBA_SSM_CACHE_DTYPE="" +BLOCK_SIZE="" +GDN_CHECKPOINT_INTERVAL="256" +MAX_GDN_CHECKPOINT_SLOTS="8" +GDN_RECURRENT_CACHE_DTYPE="float32" +GDN_CONV_CACHE_DTYPE="bfloat16" +HYBRID_GDN_RECURRENT_CACHE_DTYPE="" +HYBRID_GDN_CONV_CACHE_DTYPE="" +HYBRID_CACHE_MODE="all" +HYBRID_CACHE_PREFIX_BOUNDARY_ONLY="1" +HYBRID_CACHE_VALIDATE_EXACT="0" +HYBRID_APC_REQUIRE_VLLM_METADATA="1" +HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS="0" +HYBRID_APC_ENABLE_BACKED_PREFIX_READS="0" +HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE="0" +HYBRID_APC_PREFILL_CHUNK_TOKENS="0" +NUM_GPU_BLOCKS_OVERRIDE="" +GPU_MEMORY_UTILIZATION="" +KV_CACHE_DTYPE="" +KV_CACHE_MEMORY_BYTES="" +KERNEL_Q_TILE_SIZE="128" +KERNEL_KV_TILE_SIZE="1024" +TEXT_ONLY_CTE="1" +COMPACT_CTE_ATTENTION_MASK="1" +COLD_ZERO_CONV_FAST_PATH="0" + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2 ;; + --compiled-artifacts) COMPILED_ARTIFACTS="$2"; shift 2 ;; + --max-model-len) MAX_MODEL_LEN="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --cte-bucket) CTE_BUCKET="$2"; shift 2 ;; + --cte-buckets) CTE_BUCKETS="$2"; shift 2 ;; + --cte-bucket-profile) CTE_BUCKET_PROFILE="$2"; shift 2 ;; + --context-encoding-bucket-pairs) CONTEXT_ENCODING_BUCKET_PAIRS="$2"; shift 2 ;; + --tensor-parallel-size) TP_DEGREE="$2"; shift 2 ;; + --logical-nc-config) LNC="$2"; shift 2 ;; + --max-num-seqs) MAX_NUM_SEQS="$2"; shift 2 ;; + --ctx-batch-size) CTX_BATCH_SIZE="$2"; shift 2 ;; + --token-generation-buckets) TOKEN_GENERATION_BUCKETS="$2"; shift 2 ;; + --token-generation-batches) TOKEN_GENERATION_BATCHES="$2"; shift 2 ;; + --async-mode) ASYNC_MODE="1"; shift ;; + --no-async-mode) ASYNC_MODE="0"; shift ;; + --enable-vllm-chunked-prefill) ENABLE_CHUNKED_PREFILL="1"; shift ;; + --enable-prefix-caching) ENABLE_PREFIX_CACHING="1"; shift ;; + --disable-prefix-caching|--no-enable-prefix-caching) ENABLE_PREFIX_CACHING="0"; shift ;; + --enable-hybrid-apc) ENABLE_HYBRID_APC="1"; shift ;; + --mamba-cache-mode) MAMBA_CACHE_MODE="$2"; shift 2 ;; + --mamba-cache-dtype) MAMBA_CACHE_DTYPE="$2"; shift 2 ;; + --mamba-ssm-cache-dtype) MAMBA_SSM_CACHE_DTYPE="$2"; shift 2 ;; + --block-size) BLOCK_SIZE="$2"; shift 2 ;; + --gdn-checkpoint-interval) GDN_CHECKPOINT_INTERVAL="$2"; shift 2 ;; + --max-gdn-checkpoint-slots) MAX_GDN_CHECKPOINT_SLOTS="$2"; shift 2 ;; + --gdn-recurrent-cache-dtype) GDN_RECURRENT_CACHE_DTYPE="$2"; shift 2 ;; + --gdn-conv-cache-dtype) GDN_CONV_CACHE_DTYPE="$2"; shift 2 ;; + --hybrid-gdn-recurrent-cache-dtype) HYBRID_GDN_RECURRENT_CACHE_DTYPE="$2"; shift 2 ;; + --hybrid-gdn-conv-cache-dtype) HYBRID_GDN_CONV_CACHE_DTYPE="$2"; shift 2 ;; + --hybrid-cache-mode) HYBRID_CACHE_MODE="$2"; shift 2 ;; + --hybrid-cache-prefix-boundary-only|--hybrid-cache-block-boundary-only) HYBRID_CACHE_PREFIX_BOUNDARY_ONLY="1"; shift ;; + --no-hybrid-cache-prefix-boundary-only|--no-hybrid-cache-block-boundary-only) HYBRID_CACHE_PREFIX_BOUNDARY_ONLY="0"; shift ;; + --hybrid-cache-validate-exact) HYBRID_CACHE_VALIDATE_EXACT="1"; shift ;; + --hybrid-apc-require-vllm-metadata) HYBRID_APC_REQUIRE_VLLM_METADATA="1"; shift ;; + --no-hybrid-apc-require-vllm-metadata|--allow-hybrid-apc-local-hash-fallback) HYBRID_APC_REQUIRE_VLLM_METADATA="0"; shift ;; + --hybrid-apc-disable-unbacked-prefix-reads) HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS="1"; shift ;; + --no-hybrid-apc-disable-unbacked-prefix-reads) HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS="0"; shift ;; + --hybrid-apc-enable-backed-prefix-reads) HYBRID_APC_ENABLE_BACKED_PREFIX_READS="1"; shift ;; + --no-hybrid-apc-enable-backed-prefix-reads) HYBRID_APC_ENABLE_BACKED_PREFIX_READS="0"; shift ;; + --hybrid-apc-allow-mixed-prefill-decode) HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE="1"; shift ;; + --no-hybrid-apc-allow-mixed-prefill-decode) HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE="0"; shift ;; + --hybrid-apc-prefill-chunk-tokens) HYBRID_APC_PREFILL_CHUNK_TOKENS="$2"; shift 2 ;; + --num-gpu-blocks-override) NUM_GPU_BLOCKS_OVERRIDE="$2"; shift 2 ;; + --gpu-memory-utilization) GPU_MEMORY_UTILIZATION="$2"; shift 2 ;; + --kv-cache-dtype) KV_CACHE_DTYPE="$2"; shift 2 ;; + --kv-cache-memory-bytes) KV_CACHE_MEMORY_BYTES="$2"; shift 2 ;; + --kernel-q-tile-size) KERNEL_Q_TILE_SIZE="$2"; shift 2 ;; + --kernel-kv-tile-size) KERNEL_KV_TILE_SIZE="$2"; shift 2 ;; + --text-only-cte) TEXT_ONLY_CTE="1"; shift ;; + --no-text-only-cte|--multimodal-cte) TEXT_ONLY_CTE="0"; shift ;; + --compact-cte-attention-mask) COMPACT_CTE_ATTENTION_MASK="1"; shift ;; + --no-compact-cte-attention-mask) COMPACT_CTE_ATTENTION_MASK="0"; shift ;; + --cold-zero-conv-fast-path) COLD_ZERO_CONV_FAST_PATH="1"; shift ;; + --no-cold-zero-conv-fast-path) COLD_ZERO_CONV_FAST_PATH="0"; shift ;; + --host) HOST="$2"; shift 2 ;; + --port) PORT="$2"; shift 2 ;; + *) echo "Unknown argument: $1" >&2; exit 2 ;; + esac +done + +if [[ -z "${MODEL_PATH}" ]]; then + echo "ERROR: --model-path is required" >&2 + exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CONTRIB_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +REPO_ROOT="$(cd "${CONTRIB_ROOT}/../../.." && pwd)" +export PYTHONPATH="${SCRIPT_DIR}:${CONTRIB_ROOT}:${REPO_ROOT}/src:${PYTHONPATH:-}" +export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference" +export VLLM_PLUGINS="${VLLM_PLUGINS:-neuron}" + +if [[ -n "${COMPILED_ARTIFACTS}" ]]; then + export NEURON_COMPILED_ARTIFACTS="${COMPILED_ARTIFACTS}" + export XLA_HANDLE_SPECIAL_SCALAR="${XLA_HANDLE_SPECIAL_SCALAR:-1}" + export UNSAFE_FP8FNCAST="${UNSAFE_FP8FNCAST:-1}" +fi +if [[ -z "${BLOCK_SIZE}" ]]; then + BLOCK_SIZE="128" +fi +if [[ "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then + export DISABLE_NEURON_CUSTOM_SCHEDULER="1" +fi +if [[ "${ENABLE_HYBRID_APC}" == "1" || "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then + export QWEN36_HYBRID_APC_INSTALL_PATCH="1" +fi +if [[ "${HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS}" == "1" ]]; then + export QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS="1" +fi +if [[ "${HYBRID_APC_ENABLE_BACKED_PREFIX_READS}" == "1" ]]; then + export QWEN36_HYBRID_APC_ENABLE_BACKED_PREFIX_READS="1" +fi +if [[ -z "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}" ]]; then + HYBRID_GDN_RECURRENT_CACHE_DTYPE="${GDN_RECURRENT_CACHE_DTYPE}" +fi +if [[ -z "${HYBRID_GDN_CONV_CACHE_DTYPE}" ]]; then + HYBRID_GDN_CONV_CACHE_DTYPE="${GDN_CONV_CACHE_DTYPE}" +fi +case "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}" in + fp32|float32|torch.float32) + HYBRID_GDN_RECURRENT_CACHE_DTYPE="float32" + ;; + bf16|bfloat16|torch.bfloat16) + if [[ "${ENABLE_HYBRID_APC}" == "1" && "${HYBRID_CACHE_MODE}" == "all" ]]; then + echo "ERROR: Hybrid APC all-mode requires float32 recurrent GDN checkpoint cache state; use --gdn-recurrent-cache-dtype float32." >&2 + exit 2 + fi + HYBRID_GDN_RECURRENT_CACHE_DTYPE="bfloat16" + ;; + *) + echo "ERROR: unsupported --hybrid-gdn-recurrent-cache-dtype ${HYBRID_GDN_RECURRENT_CACHE_DTYPE}; expected float32 or bfloat16" >&2 + exit 2 + ;; +esac +if [[ "${ENABLE_PREFIX_CACHING}" == "1" || "${ENABLE_HYBRID_APC}" == "1" ]]; then + ENABLE_PREFIX_CACHING="1" +fi +if [[ "${ENABLE_PREFIX_CACHING}" == "1" || "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then + if [[ -z "${NUM_GPU_BLOCKS_OVERRIDE}" ]]; then + NUM_GPU_BLOCKS_OVERRIDE=$(( ((SEQ_LEN + BLOCK_SIZE - 1) / BLOCK_SIZE) * MAX_NUM_SEQS )) + fi +fi +if [[ "${ENABLE_HYBRID_APC}" == "1" ]]; then + if [[ "${HYBRID_CACHE_MODE}" != "all" ]]; then + echo "ERROR: --enable-hybrid-apc requires --hybrid-cache-mode all" >&2 + exit 2 + fi + if [[ "${GDN_CHECKPOINT_INTERVAL}" != "${BLOCK_SIZE}" ]]; then + echo "ERROR: --enable-hybrid-apc v0 requires --gdn-checkpoint-interval to equal --block-size" >&2 + exit 2 + fi +fi +if [[ "${ENABLE_PREFIX_CACHING}" == "1" && -z "${MAMBA_CACHE_MODE}" ]]; then + MAMBA_CACHE_MODE="all" +fi +if [[ "${ENABLE_PREFIX_CACHING}" == "1" && -z "${MAMBA_SSM_CACHE_DTYPE}" ]]; then + case "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}" in + auto|float16|float32) + MAMBA_SSM_CACHE_DTYPE="${HYBRID_GDN_RECURRENT_CACHE_DTYPE}" + ;; + *) + MAMBA_SSM_CACHE_DTYPE="auto" + ;; + esac +fi +case "${MAMBA_SSM_CACHE_DTYPE}" in + ""|auto|float16|float32) + ;; + bfloat16|bf16) + echo "WARNING: vLLM --mamba-ssm-cache-dtype does not accept ${MAMBA_SSM_CACHE_DTYPE}; using auto while preserving hybrid GDN cache dtype in Neuron config." >&2 + MAMBA_SSM_CACHE_DTYPE="auto" + ;; + *) + echo "ERROR: unsupported --mamba-ssm-cache-dtype ${MAMBA_SSM_CACHE_DTYPE}; expected auto, float16, or float32" >&2 + exit 2 + ;; +esac + +CTE_BUCKETS_JSON="$( + python3 - < int("${SEQ_LEN}"): + raise SystemExit( + f"largest CTE bucket {buckets[-1]} exceeds --seq-len ${SEQ_LEN}" + ) +print(json.dumps(buckets)) +PY +)" +MAX_CTE_BUCKET="$( + python3 - < max_num_seqs: + raise SystemExit( + "TOKEN_GENERATION_BATCHES cannot contain values greater than " + f"MAX_NUM_SEQS ({max_num_seqs})" + ) +compiled_artifacts = "${COMPILED_ARTIFACTS}" +compiled_max_prompt = 0 +compiled_uses_prefix_caching = False +compiled_prefix_buckets = None +compiled_prefix_cte_attention_backend = None +compiled_prefix_cte_attention_segment_size = None +compiled_ctx_batch_size = 0 +compiled_tkg_batch_size = 0 +compiled_token_generation_buckets = None +compiled_token_generation_batches = None +compiled_kernel_flags = {} +compiled_decode_memory_flags = {} +compiled_weights_to_skip_layout_optimization = None +compiled_disable_token_generation_wlo = ( + os.environ.get("QWEN36_DISABLE_TOKEN_GENERATION_WLO") == "1" +) +if compiled_artifacts: + config_path = Path(compiled_artifacts).expanduser() / "neuron_config.json" + if config_path.exists(): + with config_path.open(encoding="utf-8") as handle: + compiled_config = json.load(handle) + compiled_disable_token_generation_wlo = ( + compiled_disable_token_generation_wlo + or bool(compiled_config.get("disable_token_generation_wlo")) + ) + nested_config = compiled_config.get("neuron_config") + if isinstance(nested_config, dict): + compiled_config = nested_config + compiled_disable_token_generation_wlo = ( + compiled_disable_token_generation_wlo + or bool(compiled_config.get("disable_token_generation_wlo")) + ) + compiled_max_prompt = int( + compiled_config.get("max_context_length") + or compiled_config.get("max_length") + or compiled_config.get("seq_len") + or 0 + ) + if context_encoding_bucket_pairs is None: + context_encoding_bucket_pairs = compiled_config.get( + "context_encoding_bucket_pairs" + ) + compiled_uses_prefix_caching = bool( + compiled_config.get("is_prefix_caching") + ) + compiled_prefix_buckets = compiled_config.get("prefix_buckets") + compiled_prefix_cte_attention_backend = compiled_config.get( + "prefix_cte_attention_backend" + ) + compiled_prefix_cte_attention_segment_size = compiled_config.get( + "prefix_cte_attention_segment_size" + ) + compiled_ctx_batch_size = int( + compiled_config.get("ctx_batch_size") + or compiled_config.get("batch_size") + or compiled_config.get("max_batch_size") + or 0 + ) + compiled_tkg_batch_size = int( + compiled_config.get("tkg_batch_size") + or compiled_config.get("batch_size") + or compiled_config.get("max_batch_size") + or 0 + ) + compiled_token_generation_batches = compiled_config.get( + "token_generation_batches" + ) + compiled_token_generation_buckets = compiled_config.get( + "token_generation_buckets" + ) + compiled_weights_to_skip_layout_optimization = compiled_config.get( + "weights_to_skip_layout_optimization" + ) + for flag_name in ( + "fused_qkv", + "qkv_kernel_enabled", + "qkv_nki_kernel_enabled", + "qkv_tkg_nki_kernel_enabled", + "attn_block_tkg_nki_kernel_enabled", + "attn_block_tkg_nki_kernel_cascaded_attention", + "attn_block_tkg_nki_kernel_cache_update", + "attn_block_tkg_nki_kernel_use_online_softmax", + "attn_block_tkg_nki_kernel_disable_gpsimd_sb2sb", + "out_proj_kernel_enabled", + "mlp_kernel_enabled", + "mlp_tkg_nki_kernel_enabled", + "quantized_mlp_kernel_enabled", + "rmsnorm_quantize_kernel_enabled", + "quantize_clamp_bound", + ): + if flag_name in compiled_config: + compiled_kernel_flags[flag_name] = compiled_config[flag_name] + for flag_name in ( + "k_cache_transposed", + "kv_cache_quant", + "kv_quant_config", + "quantized", + "quantization_dtype", + "quantization_type", + "quantization_block_size", + "quantization_block_axis", + "quantization_scale_dtype", + "quantized_checkpoints_path", + "modules_to_not_convert", + "draft_model_modules_to_not_convert", + "activation_quantization_type", + ): + if flag_name in compiled_config: + compiled_decode_memory_flags[flag_name] = compiled_config[flag_name] +runtime_max_prompt = compiled_max_prompt or max_cte_bucket +if compiled_artifacts and max_num_seqs > 1: + if compiled_tkg_batch_size and max_num_seqs > compiled_tkg_batch_size: + raise SystemExit( + "compiled artifact cannot serve requested continuous batching: " + f"MAX_NUM_SEQS={max_num_seqs} but compiled tkg_batch_size=" + f"{compiled_tkg_batch_size}" + ) + if compiled_ctx_batch_size and int("${CTX_BATCH_SIZE}") > compiled_ctx_batch_size: + raise SystemExit( + "compiled artifact cannot serve requested CTE batch: " + f"CTX_BATCH_SIZE=${CTX_BATCH_SIZE} but compiled ctx_batch_size=" + f"{compiled_ctx_batch_size}" + ) + +def normalize_int_list(values): + if values is None: + return None + if isinstance(values, str): + return parse_int_list("compiled int list", values) + normalized = sorted(set(int(value) for value in values)) + return normalized or None + +if token_generation_batches is None: + token_generation_batches = normalize_int_list(compiled_token_generation_batches) +if token_generation_buckets is None: + token_generation_buckets = ( + normalize_int_list(compiled_token_generation_buckets) or [seq_len] + ) +if token_generation_buckets[-1] > seq_len: + raise SystemExit( + f"TOKEN_GENERATION_BUCKETS cannot contain values greater than SEQ_LEN ({seq_len})" + ) +if token_generation_batches is not None: + token_generation_batches = [ + batch for batch in token_generation_batches if batch <= max_num_seqs + ] + if not token_generation_batches: + token_generation_batches = None +num_gpu_blocks_override = "${NUM_GPU_BLOCKS_OVERRIDE}" +pa_num_blocks = ( + int(num_gpu_blocks_override) + if num_gpu_blocks_override + else max( + 1, + ((int("${SEQ_LEN}") + int("${BLOCK_SIZE}") - 1) // int("${BLOCK_SIZE}")) + * int("${MAX_NUM_SEQS}"), + ) +) +neuron_config = { + "tp_degree": int("${TP_DEGREE}"), + "batch_size": max_num_seqs, + "ctx_batch_size": int("${CTX_BATCH_SIZE}"), + "tkg_batch_size": max_num_seqs, + "seq_len": seq_len, + "max_length": seq_len, + "max_context_length": runtime_max_prompt, + "context_encoding_buckets": cte_buckets, + "token_generation_buckets": token_generation_buckets, + "enable_bucketing": len(cte_buckets) > 1 or len(token_generation_buckets) > 1, + "logical_nc_config": int("${LNC}"), + "torch_dtype": "bfloat16", + "save_sharded_checkpoint": True, + "pa_block_size": int("${BLOCK_SIZE}"), + "pa_num_blocks": pa_num_blocks, + "gdn_checkpoint_interval": int("${GDN_CHECKPOINT_INTERVAL}"), + "max_gdn_checkpoint_slots": int("${MAX_GDN_CHECKPOINT_SLOTS}"), + "gdn_recurrent_cache_dtype": "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}", + "gdn_conv_cache_dtype": "${HYBRID_GDN_CONV_CACHE_DTYPE}", + "hybrid_recurrent_cache_dtype": "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}", + "hybrid_conv_cache_dtype": "${HYBRID_GDN_CONV_CACHE_DTYPE}", + "hybrid_cache_mode": "${HYBRID_CACHE_MODE}", +} +if async_mode: + neuron_config["async_mode"] = True +if token_generation_batches is not None: + neuron_config["token_generation_batches"] = token_generation_batches +if compiled_weights_to_skip_layout_optimization is not None: + neuron_config["weights_to_skip_layout_optimization"] = ( + compiled_weights_to_skip_layout_optimization + ) +neuron_config.update(compiled_kernel_flags) +neuron_config.update(compiled_decode_memory_flags) +if enable_prefix_caching or enable_hybrid_apc or enable_chunked: + neuron_config["is_block_kv_layout"] = True +uses_prefix_cte_contract = ( + context_encoding_bucket_pairs is not None or compiled_uses_prefix_caching +) +if enable_prefix_caching or enable_hybrid_apc or uses_prefix_cte_contract: + neuron_config["is_prefix_caching"] = True + if context_encoding_bucket_pairs is not None: + neuron_config["context_encoding_bucket_pairs"] = context_encoding_bucket_pairs + if compiled_prefix_buckets is not None: + neuron_config["prefix_buckets"] = compiled_prefix_buckets + if compiled_prefix_cte_attention_backend is not None: + neuron_config["prefix_cte_attention_backend"] = ( + compiled_prefix_cte_attention_backend + ) + if compiled_prefix_cte_attention_segment_size is not None: + neuron_config["prefix_cte_attention_segment_size"] = ( + compiled_prefix_cte_attention_segment_size + ) +# NeuronConfig.chunked_prefill_config trips the built-in block TKG attention +# kernel validation. Qwen Hybrid APC chunking still uses the top-level +# use_qwen_hybrid_chunked_prefill flags below. +if enable_chunked and not compiled_kernel_flags.get( + "attn_block_tkg_nki_kernel_enabled", + False, +): + neuron_config.update({ + "chunked_prefill_config": { + "max_num_seqs": int("${MAX_NUM_SEQS}"), + "tkg_model_enabled": True, + "kernel_q_tile_size": int("${KERNEL_Q_TILE_SIZE}"), + "kernel_kv_tile_size": int("${KERNEL_KV_TILE_SIZE}"), + }, + }) +print(json.dumps({ + "max_prompt_length": runtime_max_prompt, + "use_hybrid_apc_manager": enable_hybrid_apc, + "use_text_only_cte_inputs": "${TEXT_ONLY_CTE}" == "1", + "use_compact_cte_attention_mask": "${COMPACT_CTE_ATTENTION_MASK}" == "1", + "use_cold_zero_conv_fast_path": "${COLD_ZERO_CONV_FAST_PATH}" == "1", + "gdn_checkpoint_interval": int("${GDN_CHECKPOINT_INTERVAL}"), + "max_gdn_checkpoint_slots": int("${MAX_GDN_CHECKPOINT_SLOTS}"), + "gdn_recurrent_cache_dtype": "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}", + "gdn_conv_cache_dtype": "${HYBRID_GDN_CONV_CACHE_DTYPE}", + "hybrid_recurrent_cache_dtype": "${HYBRID_GDN_RECURRENT_CACHE_DTYPE}", + "hybrid_conv_cache_dtype": "${HYBRID_GDN_CONV_CACHE_DTYPE}", + "hybrid_cache_mode": "${HYBRID_CACHE_MODE}", + "hybrid_cache_prefix_boundary_only": "${HYBRID_CACHE_PREFIX_BOUNDARY_ONLY}" == "1", + "hybrid_cache_block_boundary_only": "${HYBRID_CACHE_PREFIX_BOUNDARY_ONLY}" == "1", + "hybrid_cache_validate_exact": "${HYBRID_CACHE_VALIDATE_EXACT}" == "1", + "hybrid_apc_require_vllm_metadata": enable_hybrid_apc and "${HYBRID_APC_REQUIRE_VLLM_METADATA}" == "1", + "hybrid_apc_allow_local_hash_fallback": not (enable_hybrid_apc and "${HYBRID_APC_REQUIRE_VLLM_METADATA}" == "1"), + "hybrid_apc_require_attention_block_refs": enable_hybrid_apc and "${HYBRID_APC_REQUIRE_VLLM_METADATA}" == "1", + "hybrid_apc_disable_unbacked_prefix_reads": enable_hybrid_apc and "${HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS}" == "1", + "hybrid_apc_enable_backed_prefix_reads": enable_hybrid_apc and "${HYBRID_APC_ENABLE_BACKED_PREFIX_READS}" == "1", + "hybrid_apc_allow_mixed_prefill_decode": enable_hybrid_apc and "${HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE}" == "1", + "hybrid_apc_prefill_chunk_tokens": int("${MAX_BATCHED_TOKENS}") if enable_hybrid_apc and enable_chunked else 0, + "qwen_prefill_group_size": int("${MAX_BATCHED_TOKENS}") if enable_chunked else max_cte_bucket, + "use_qwen_hybrid_chunked_prefill": enable_chunked, + "use_qwen_hybrid_chunked_prefill_nki": enable_chunked, + "disable_token_generation_wlo": compiled_disable_token_generation_wlo, + "override_neuron_config": neuron_config, +})) +PY +)" + +echo "Starting vLLM for Qwen3.6-27B" +echo "MODEL_PATH=${MODEL_PATH}" +echo "NEURON_COMPILED_ARTIFACTS=${NEURON_COMPILED_ARTIFACTS:-}" +echo "XLA_HANDLE_SPECIAL_SCALAR=${XLA_HANDLE_SPECIAL_SCALAR:-}" +echo "UNSAFE_FP8FNCAST=${UNSAFE_FP8FNCAST:-}" +echo "QWEN36_DISABLE_TOKEN_GENERATION_WLO=${QWEN36_DISABLE_TOKEN_GENERATION_WLO:-}" +echo "PYTHONPATH=${PYTHONPATH}" +echo "ENABLE_PREFIX_CACHING=${ENABLE_PREFIX_CACHING}" +echo "ENABLE_HYBRID_APC=${ENABLE_HYBRID_APC}" +echo "MAMBA_CACHE_MODE=${MAMBA_CACHE_MODE:-}" +echo "MAMBA_CACHE_DTYPE=${MAMBA_CACHE_DTYPE:-}" +echo "MAMBA_SSM_CACHE_DTYPE=${MAMBA_SSM_CACHE_DTYPE:-}" +echo "BLOCK_SIZE=${BLOCK_SIZE}" +echo "CTE_BUCKETS=${CTE_BUCKETS_JSON}" +echo "CONTEXT_ENCODING_BUCKET_PAIRS=${CONTEXT_ENCODING_BUCKET_PAIRS}" +echo "CTX_BATCH_SIZE=${CTX_BATCH_SIZE}" +echo "KERNEL_Q_TILE_SIZE=${KERNEL_Q_TILE_SIZE}" +echo "KERNEL_KV_TILE_SIZE=${KERNEL_KV_TILE_SIZE}" +echo "TEXT_ONLY_CTE=${TEXT_ONLY_CTE}" +echo "COMPACT_CTE_ATTENTION_MASK=${COMPACT_CTE_ATTENTION_MASK}" +echo "COLD_ZERO_CONV_FAST_PATH=${COLD_ZERO_CONV_FAST_PATH}" +echo "GDN_CHECKPOINT_INTERVAL=${GDN_CHECKPOINT_INTERVAL}" +echo "MAX_GDN_CHECKPOINT_SLOTS=${MAX_GDN_CHECKPOINT_SLOTS}" +echo "HYBRID_GDN_RECURRENT_CACHE_DTYPE=${HYBRID_GDN_RECURRENT_CACHE_DTYPE}" +echo "HYBRID_GDN_CONV_CACHE_DTYPE=${HYBRID_GDN_CONV_CACHE_DTYPE}" +echo "HYBRID_APC_REQUIRE_VLLM_METADATA=${HYBRID_APC_REQUIRE_VLLM_METADATA}" +echo "HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS=${HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS}" +echo "HYBRID_APC_ENABLE_BACKED_PREFIX_READS=${HYBRID_APC_ENABLE_BACKED_PREFIX_READS}" +echo "HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE=${HYBRID_APC_ALLOW_MIXED_PREFILL_DECODE}" +echo "HYBRID_APC_PREFILL_CHUNK_TOKENS=${HYBRID_APC_PREFILL_CHUNK_TOKENS}" +echo "GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION}" +echo "KV_CACHE_DTYPE=${KV_CACHE_DTYPE}" +echo "KV_CACHE_MEMORY_BYTES=${KV_CACHE_MEMORY_BYTES}" +echo "ADDITIONAL_CONFIG=${ADDITIONAL_CONFIG}" + +VLLM_ARGS=( + "${MODEL_PATH}" + --host "${HOST}" \ + --port "${PORT}" \ + --trust-remote-code \ + --dtype bfloat16 \ + --tensor-parallel-size "${TP_DEGREE}" \ + --max-num-seqs "${MAX_NUM_SEQS}" \ + --max-model-len "${MAX_MODEL_LEN}" \ + --generation-config vllm \ + --additional-config "${ADDITIONAL_CONFIG}" +) +if [[ "${ENABLE_PREFIX_CACHING}" == "1" ]]; then + VLLM_ARGS+=(--enable-prefix-caching) +else + VLLM_ARGS+=(--no-enable-prefix-caching) +fi +if [[ -n "${MAMBA_CACHE_MODE}" ]]; then + VLLM_ARGS+=(--mamba-cache-mode "${MAMBA_CACHE_MODE}") +fi +if [[ -n "${MAMBA_CACHE_DTYPE}" ]]; then + VLLM_ARGS+=(--mamba-cache-dtype "${MAMBA_CACHE_DTYPE}") +fi +if [[ -n "${MAMBA_SSM_CACHE_DTYPE}" ]]; then + VLLM_ARGS+=(--mamba-ssm-cache-dtype "${MAMBA_SSM_CACHE_DTYPE}") +fi +if [[ -n "${NUM_GPU_BLOCKS_OVERRIDE}" ]]; then + VLLM_ARGS+=(--num-gpu-blocks-override "${NUM_GPU_BLOCKS_OVERRIDE}") +fi +if [[ -n "${GPU_MEMORY_UTILIZATION}" ]]; then + VLLM_ARGS+=(--gpu-memory-utilization "${GPU_MEMORY_UTILIZATION}") +fi +if [[ -n "${KV_CACHE_DTYPE}" ]]; then + VLLM_ARGS+=(--kv-cache-dtype "${KV_CACHE_DTYPE}") +fi +if [[ -n "${KV_CACHE_MEMORY_BYTES}" ]]; then + VLLM_ARGS+=(--kv-cache-memory-bytes "${KV_CACHE_MEMORY_BYTES}") +fi +if [[ "${ENABLE_PREFIX_CACHING}" == "1" || "${ENABLE_HYBRID_APC}" == "1" || "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then + VLLM_ARGS+=(--block-size "${BLOCK_SIZE}") +fi +if [[ "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then + VLLM_ARGS+=( + --enable-chunked-prefill + --max-num-batched-tokens "${MAX_BATCHED_TOKENS}" + ) +else + VLLM_ARGS+=(--no-enable-chunked-prefill) +fi + +exec python "${SCRIPT_DIR}/serve_qwen36.py" "${VLLM_ARGS[@]}" diff --git a/src/neuronx_distributed_inference/models/config.py b/src/neuronx_distributed_inference/models/config.py index 9b58163d..8929060e 100644 --- a/src/neuronx_distributed_inference/models/config.py +++ b/src/neuronx_distributed_inference/models/config.py @@ -178,6 +178,9 @@ def __init__(self, **kwargs) -> None: # Expose argmax kernel flag at top-level for easier configuration with # models like EAGLE self.disable_argmax_kernel = kwargs.pop("disable_argmax_kernel", False) + self.disable_context_encoding_argmax_kernel = kwargs.pop( + "disable_context_encoding_argmax_kernel", False + ) # async self.async_mode = kwargs.pop("async_mode", False) @@ -188,12 +191,75 @@ def __init__(self, **kwargs) -> None: self.bucket_n_active_tokens = kwargs.pop("bucket_n_active_tokens", False) self.context_encoding_buckets = kwargs.pop("context_encoding_buckets", None) self.prefix_buckets = kwargs.pop("prefix_buckets", None) + self.context_encoding_bucket_pairs = kwargs.pop( + "context_encoding_bucket_pairs", None + ) self.token_generation_buckets = kwargs.pop("token_generation_buckets", None) + self.prefix_cte_attention_backend = kwargs.pop( + "prefix_cte_attention_backend", "attention_cte" + ) + assert self.prefix_cte_attention_backend in ( + "attention_cte", + "segmented_cte", + ), ( + "prefix_cte_attention_backend must be one of " + "attention_cte or segmented_cte" + ) + self.prefix_cte_attention_segment_size = kwargs.pop( + "prefix_cte_attention_segment_size", None + ) + if self.prefix_cte_attention_segment_size is not None: + self.prefix_cte_attention_segment_size = int( + self.prefix_cte_attention_segment_size + ) + assert self.prefix_cte_attention_segment_size > 0, ( + "prefix_cte_attention_segment_size must be positive when set" + ) + self.prefix_cte_attention_chunk_size = kwargs.pop( + "prefix_cte_attention_chunk_size", None + ) + if self.prefix_cte_attention_chunk_size is not None: + self.prefix_cte_attention_chunk_size = int( + self.prefix_cte_attention_chunk_size + ) + assert self.prefix_cte_attention_chunk_size > 0, ( + "prefix_cte_attention_chunk_size must be positive when set" + ) if self.context_encoding_buckets is not None: self.context_encoding_buckets.sort() assert ( self.context_encoding_buckets[-1] <= self.max_context_length ), f"Context bucket {self.context_encoding_buckets[-1]} should be <= {self.max_context_length}" + if self.context_encoding_bucket_pairs is not None: + bucket_pairs = [] + for bucket_pair in self.context_encoding_bucket_pairs: + assert len(bucket_pair) == 2, ( + "Context encoding bucket pairs must be [active_tokens, " + f"prefix_tokens], got {bucket_pair}" + ) + active_tokens, prefix_tokens = int(bucket_pair[0]), int(bucket_pair[1]) + assert active_tokens > 0, ( + f"Context encoding active bucket must be positive, got {active_tokens}" + ) + assert prefix_tokens >= 0, ( + f"Context encoding prefix bucket must be non-negative, got {prefix_tokens}" + ) + assert active_tokens <= self.max_context_length, ( + f"Context encoding active bucket {active_tokens} should be <= " + f"{self.max_context_length}" + ) + assert prefix_tokens <= self.max_context_length, ( + f"Context encoding prefix bucket {prefix_tokens} should be <= " + f"{self.max_context_length}" + ) + bucket_pairs.append([active_tokens, prefix_tokens]) + self.context_encoding_bucket_pairs = sorted( + set(tuple(pair) for pair in bucket_pairs) + ) + self.context_encoding_bucket_pairs = [ + [active_tokens, prefix_tokens] + for active_tokens, prefix_tokens in self.context_encoding_bucket_pairs + ] if self.token_generation_buckets is not None: self.token_generation_buckets.sort() assert ( @@ -420,6 +486,19 @@ def __init__(self, **kwargs) -> None: self.strided_context_parallel_kernel_enabled = kwargs.pop("strided_context_parallel_kernel_enabled", False) self.qkv_kernel_enabled = kwargs.pop("qkv_kernel_enabled", False) self.qkv_nki_kernel_enabled = kwargs.pop("qkv_nki_kernel_enabled", False) + self.qkv_tkg_nki_kernel_enabled = kwargs.pop("qkv_tkg_nki_kernel_enabled", False) + if self.qkv_tkg_nki_kernel_enabled: + assert not self.fused_qkv, ( + "qkv_tkg_nki_kernel_enabled uses split Q/K/V projections and " + "cannot be combined with fused_qkv." + ) + assert not ( + self.qkv_kernel_enabled or self.qkv_nki_kernel_enabled + ), ( + "qkv_tkg_nki_kernel_enabled intentionally bypasses the stock " + "QKV CTE/TKG wrapper; do not combine it with qkv_kernel_enabled " + "or qkv_nki_kernel_enabled." + ) self.qkv_cte_nki_kernel_fuse_rope = kwargs.pop("qkv_cte_nki_kernel_fuse_rope", False) if self.qkv_cte_nki_kernel_fuse_rope: assert self.qkv_kernel_enabled and self.qkv_nki_kernel_enabled, \ diff --git a/src/neuronx_distributed_inference/models/model_base.py b/src/neuronx_distributed_inference/models/model_base.py index 341de1c3..d63c63a6 100644 --- a/src/neuronx_distributed_inference/models/model_base.py +++ b/src/neuronx_distributed_inference/models/model_base.py @@ -38,7 +38,14 @@ TOKEN_GENERATION_MODEL_TAG, ModelWrapper, ) -from neuronx_distributed_inference.modules.async_execution import causal_lm_async_execution +from neuronx_distributed_inference.modules.async_execution import ( + cancel_hybrid_apc_request, + causal_lm_async_execution, + finish_hybrid_apc_request, + prepare_disabled_hybrid_apc_model_inputs, + prepare_hybrid_apc_model_inputs, + prepare_hybrid_apc_request_for_execution, +) from neuronx_distributed_inference.modules.eagle.hidden_state import HiddenStateRollingBuffer from neuronx_distributed_inference.modules.eagle.token_tree import TokenTree from neuronx_distributed_inference.modules.flashdecode.utils import ( @@ -1178,7 +1185,19 @@ def _sample_on_device( ): sampling_inputs = logits[:, -1, :] res = self.sampler( - sampling_inputs, sampling_params, rank_id=self.rank_util.get_rank() + sampling_inputs, + sampling_params, + rank_id=self.rank_util.get_rank(), + disable_argmax_kernel_override=( + True + if is_for_context_encoding + and getattr( + self.neuron_config, + "disable_context_encoding_argmax_kernel", + False, + ) + else None + ), ) res = res.to(torch.int32) # Otherwise we return the full logits for multinomial sampling in spec decoding @@ -3442,6 +3461,8 @@ def forward( if not generation_model.is_neuron(): self._copy_past_key_values(outputs) + self._debug_raw_outputs(outputs) + # Get processed and constructed outputs constructed_outputs = self._get_constructed_outputs(outputs, is_run_on_neuron) @@ -3574,60 +3595,173 @@ def _get_model_outputs( self.base_model = ( self.context_encoding_model if is_context_encoding else generation_model ) - if self.neuron_config.enable_eagle_speculation: - outputs = self.base_model( - input_ids, - attention_mask, - position_ids, - seq_ids, - sampling_params, - torch.empty(0), - adapter_ids, - slot_mapping, - block_table, - num_queries, - computed_context_lens, - torch.empty(0), - torch.empty(0), - torch.empty(0), - torch.empty(0), - torch.empty(0), - *llava_args, - ) - elif self.neuron_config.enable_fused_speculation: - outputs = self.base_model( - input_ids, - attention_mask, - position_ids, - seq_ids, - sampling_params, - torch.empty(0), - adapter_ids, - slot_mapping, - block_table, - num_queries, - computed_context_lens, - *llava_args, - ) - else: - outputs = self.base_model( - input_ids, - attention_mask, - position_ids, - seq_ids, - sampling_params, - torch.empty(0), - adapter_ids, - torch.empty(0), - torch.empty(0), - torch.empty(0), - torch.empty(0), - slot_mapping, - block_table, - num_queries, - computed_context_lens, - *llava_args, + hybrid_apc_request_dict = None + llava_args = tuple(llava_args or ()) + + def _replace_hybrid_apc_args(extra_args, hybrid_args): + hybrid_args = tuple(hybrid_args or ()) + if not hybrid_args: + return tuple(extra_args or ()) + extra_args = tuple(extra_args or ()) + if len(extra_args) >= len(hybrid_args): + return (*extra_args[: -len(hybrid_args)], *hybrid_args) + return (*extra_args, *hybrid_args) + + try: + use_hybrid_apc_manager = bool( + getattr(self.config, "use_hybrid_apc_manager", False) + or getattr(self.neuron_config, "use_hybrid_apc_manager", False) + or getattr( + getattr(self.config, "neuron_config", None), + "use_hybrid_apc_manager", + False, + ) ) + if is_context_encoding and use_hybrid_apc_manager: + hybrid_apc_request_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "seq_ids": seq_ids, + "sampling_params": sampling_params, + "adapter_ids": adapter_ids, + "slot_mapping": slot_mapping, + "block_table": block_table, + "full_context_lens": full_context_lens, + "computed_context_lens": computed_context_lens, + "num_queries": num_queries, + } + request_records = getattr( + self, + "_qwen36_vllm_hybrid_apc_request_records", + None, + ) + if request_records is not None: + hybrid_apc_request_dict["hybrid_request_records"] = request_records + request_ids = getattr(self, "_qwen36_vllm_request_ids", None) + if request_ids is not None: + if isinstance(request_ids, list): + request_ids = tuple(request_ids) + if isinstance(request_ids, tuple) and len(request_ids) == 1: + hybrid_apc_request_dict["hybrid_request_id"] = request_ids[0] + else: + hybrid_apc_request_dict["hybrid_request_id"] = request_ids + cached_request_ids = getattr( + self, + "_qwen36_vllm_cached_request_ids", + None, + ) + if cached_request_ids is not None: + hybrid_apc_request_dict[ + "hybrid_cached_request_ids" + ] = cached_request_ids + prefill_completion_state = getattr( + self, + "_qwen36_vllm_prefill_completion_state", + None, + ) + if prefill_completion_state is not None: + hybrid_apc_request_dict[ + "hybrid_prefill_completion_state" + ] = prefill_completion_state + + prepared_inputs = prepare_hybrid_apc_request_for_execution( + self, + hybrid_apc_request_dict, + ) + input_ids = prepared_inputs.get("input_ids", input_ids) + attention_mask = prepared_inputs.get("attention_mask", attention_mask) + position_ids = prepared_inputs.get("position_ids", position_ids) + seq_ids = prepared_inputs.get("seq_ids", seq_ids) + sampling_params = prepared_inputs.get("sampling_params", sampling_params) + adapter_ids = prepared_inputs.get("adapter_ids", adapter_ids) + slot_mapping = prepared_inputs.get("slot_mapping", slot_mapping) + block_table = prepared_inputs.get("block_table", block_table) + full_context_lens = prepared_inputs.get( + "full_context_lens", + full_context_lens, + ) + computed_context_lens = prepared_inputs.get( + "computed_context_lens", + computed_context_lens, + ) + num_queries = prepared_inputs.get( + "num_queries", + full_context_lens - computed_context_lens, + ) + llava_args = _replace_hybrid_apc_args( + llava_args, + prepare_hybrid_apc_model_inputs(self, prepared_inputs), + ) + elif use_hybrid_apc_manager: + llava_args = _replace_hybrid_apc_args( + llava_args, + prepare_disabled_hybrid_apc_model_inputs( + self, + {"seq_ids": seq_ids}, + ), + ) + + if self.neuron_config.enable_eagle_speculation: + outputs = self.base_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), + adapter_ids, + slot_mapping, + block_table, + num_queries, + computed_context_lens, + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + *llava_args, + ) + elif self.neuron_config.enable_fused_speculation: + outputs = self.base_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), + adapter_ids, + slot_mapping, + block_table, + num_queries, + computed_context_lens, + *llava_args, + ) + else: + outputs = self.base_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + torch.empty(0), + adapter_ids, + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + slot_mapping, + block_table, + num_queries, + computed_context_lens, + *llava_args, + ) + except Exception: + if hybrid_apc_request_dict is not None: + cancel_hybrid_apc_request(hybrid_apc_request_dict) + raise + if hybrid_apc_request_dict is not None: + finish_hybrid_apc_request(hybrid_apc_request_dict) is_run_on_neuron = self.base_model.is_neuron() elif self._is_prefill(position_ids): if self.neuron_config.is_medusa: @@ -3819,6 +3953,8 @@ def _get_captured_tensors_offset(self): return 0 def _construct_output_with_tokens_and_logits(self, next_tokens, logits, hidden_states=[]): + self._debug_constructed_output("tokens", next_tokens) + self._debug_constructed_output("logits", logits) OutputParams = CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, @@ -3841,6 +3977,8 @@ def _construct_output(self, logits_or_next_tokens): logits_or_next_tokens = logits_or_next_tokens[0] next_tokens = logits_or_next_tokens + output_kind = "tokens" if self.on_device_sampling else "logits" + self._debug_constructed_output(output_kind, logits_or_next_tokens) OutputParams = CausalLMOutputWithPast( logits=None if self.on_device_sampling else logits_or_next_tokens, hidden_states=logits_or_next_tokens, @@ -3861,6 +3999,93 @@ def _construct_output(self, logits_or_next_tokens): return OutputParams + def _debug_raw_outputs(self, outputs): + if os.environ.get("NXDI_RAW_OUTPUT_DEBUG") != "1": + return + limit = int(os.environ.get("NXDI_RAW_OUTPUT_DEBUG_LIMIT", "8")) + if isinstance(outputs, (list, tuple)): + print( + f"[nxdi_raw_output_debug] count={len(outputs)} limit={limit}", + flush=True, + ) + iterable = enumerate(outputs[:limit]) + else: + print("[nxdi_raw_output_debug] count=1 limit=1", flush=True) + iterable = [(0, outputs)] + for idx, tensor in iterable: + self._debug_constructed_output(f"raw_output[{idx}]", tensor) + + def _debug_constructed_output(self, name, tensor): + if os.environ.get("NXDI_OUTPUT_DEBUG") != "1": + return + if tensor is None or not hasattr(tensor, "numel"): + print(f"[nxdi_output_debug] name={name} tensor=none", flush=True) + return + if tensor.numel() == 0: + print( + f"[nxdi_output_debug] name={name} shape={tuple(tensor.shape)} " + f"dtype={tensor.dtype} empty", + flush=True, + ) + return + try: + flat = tensor.detach().reshape(-1) + if torch.is_floating_point(flat): + finite_mask = torch.isfinite(flat) + finite_count = int(finite_mask.sum().item()) + nan_count = int(torch.isnan(flat).sum().item()) + posinf_count = int( + torch.logical_and(torch.isinf(flat), flat > 0).sum().item() + ) + neginf_count = int( + torch.logical_and(torch.isinf(flat), flat < 0).sum().item() + ) + if finite_count: + finite_flat = flat[finite_mask].float() + finite_min = float(finite_flat.min().item()) + finite_max = float(finite_flat.max().item()) + first_row = tensor.reshape(-1, tensor.shape[-1])[0].float() + argmax = int(torch.argmax(first_row).item()) + else: + finite_min = "none" + finite_max = "none" + argmax = "none" + topk_suffix = "" + topk = int(os.environ.get("NXDI_OUTPUT_DEBUG_TOPK", "0")) + if finite_count and topk > 0 and tensor.shape[-1] > 0: + topk = min(topk, int(tensor.shape[-1])) + top_values, top_indices = torch.topk(first_row, k=topk) + topk_suffix = ( + f" first_row_top{topk}_ids=" + f"{[int(idx.item()) for idx in top_indices]}" + f" first_row_top{topk}_values=" + f"{[float(value.item()) for value in top_values]}" + ) + print( + "[nxdi_output_debug] " + f"name={name} shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"finite={finite_count}/{tensor.numel()} nan={nan_count} " + f"posinf={posinf_count} neginf={neginf_count} " + f"finite_min={finite_min} finite_max={finite_max} " + f"first_row_argmax={argmax}{topk_suffix}", + flush=True, + ) + else: + flat_i64 = flat.to(torch.int64) + print( + "[nxdi_output_debug] " + f"name={name} shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"min={int(flat_i64.min().item())} " + f"max={int(flat_i64.max().item())}", + flush=True, + ) + except Exception as exc: + print( + "[nxdi_output_debug] " + f"name={name} summary_error={type(exc).__name__}: {exc}", + flush=True, + ) + def _prepare_inputs(self): accepted_indices = torch.zeros( (self.neuron_config.batch_size, self.neuron_config.num_medusa_heads + 1), diff --git a/src/neuronx_distributed_inference/models/model_wrapper.py b/src/neuronx_distributed_inference/models/model_wrapper.py index 825fbbe2..b35c15f8 100644 --- a/src/neuronx_distributed_inference/models/model_wrapper.py +++ b/src/neuronx_distributed_inference/models/model_wrapper.py @@ -41,6 +41,11 @@ FUSED_SPECULATION_MODEL_TAG = "fused_speculation_model" VISION_ENCODER_MODEL_TAG = "vision_encoder_model" +_HYBRID_APC_MIN_EXTRA_PREFIX_ARG_COUNT = 14 +_HYBRID_APC_CONTROL_EXTRA_ARG_COUNT = 5 +_HYBRID_APC_RESTORE_ACTIVE_CONTROL_ARG_INDEX = 1 +_PREFIX_CACHING_EXTRA_ARG_START = 15 + # Get the modules_to_not_convert from the neuron configs def get_modules_to_not_convert(neuron_config: NeuronConfig): @@ -373,18 +378,68 @@ def _get_input_shape_for_prefix_caching( prefix_size, adapter_ids, ): - if self.neuron_config.enable_fused_speculation and self.tag == FUSED_SPECULATION_MODEL_TAG: + sample_prefix_size = prefix_size + if ( + self.tag == CONTEXT_ENCODING_MODEL_TAG + and getattr( + self.neuron_config, + "prefix_cte_attention_backend", + "attention_cte", + ) + == "segmented_cte" + ): + sample_prefix_size = max( + 0, + min( + int(prefix_size), + int(self.neuron_config.max_context_length) - int(n_active_tokens), + ), + ) + if self.tag == CONTEXT_ENCODING_MODEL_TAG: + active_positions = torch.arange( + sample_prefix_size, + sample_prefix_size + n_active_tokens, + dtype=torch.int32, + ).unsqueeze(0) + position_ids = active_positions.repeat(batch_size, 1) + slot_mapping = position_ids.clone() + elif self.neuron_config.enable_fused_speculation and self.tag == FUSED_SPECULATION_MODEL_TAG: slot_mapping = torch.zeros((batch_size, self.neuron_config.speculation_length), dtype=torch.int32) else: slot_mapping = torch.zeros((batch_size, n_active_tokens), dtype=torch.int32) - num_blocks = prefix_size // self.neuron_config.pa_block_size - active_block_table = torch.zeros(1, dtype=torch.int32) if num_blocks == 0 else torch.zeros( - (batch_size, num_blocks), dtype=torch.int32 + if ( + self.tag == CONTEXT_ENCODING_MODEL_TAG + and getattr( + self.neuron_config, + "prefix_cte_attention_backend", + "attention_cte", + ) + == "segmented_cte" + ): + block_table_tokens = min( + int(self.neuron_config.max_context_length), + int(prefix_size) + int(n_active_tokens), + ) + else: + block_table_tokens = prefix_size + num_blocks = ( + block_table_tokens + self.neuron_config.pa_block_size - 1 + ) // self.neuron_config.pa_block_size + active_block_table = ( + torch.zeros(1, dtype=torch.int32) + if num_blocks == 0 + else torch.arange(num_blocks, dtype=torch.int32) + .unsqueeze(0) + .repeat(batch_size, 1) ) num_queries = torch.full((batch_size, 1), n_active_tokens, dtype=torch.int32) - computed_context_lens = torch.full((batch_size, 1), prefix_size, dtype=torch.int32) + computed_context_lens = torch.full( + (batch_size, 1), + sample_prefix_size, + dtype=torch.int32, + ) if self.neuron_config.enable_eagle_speculation: if self.tag == FUSED_SPECULATION_MODEL_TAG: return ( @@ -413,12 +468,16 @@ def _get_input_shape_for_prefix_caching( target_attention_mask = torch.zeros(1, dtype=torch.int32) else: target_attention_mask = torch.ones((batch_size, prefix_size), dtype=torch.int32) - target_position_ids = torch.arange(0, prefill, dtype=torch.int32).unsqueeze(0) + target_position_ids = torch.arange(prefix_size, prefix_size + prefill, dtype=torch.int32).unsqueeze(0) target_position_ids = target_position_ids.repeat(batch_size, 1) - target_slot_mapping = torch.zeros((batch_size, prefill), dtype=torch.int32) + target_slot_mapping = target_position_ids.clone() target_num_blocks = prefix_size // self.neuron_config.pa_block_size - target_active_block_table = torch.zeros(1, dtype=torch.int32) if target_num_blocks == 0 else torch.zeros( - (batch_size, target_num_blocks), dtype=torch.int32 + target_active_block_table = ( + torch.zeros(1, dtype=torch.int32) + if target_num_blocks == 0 + else torch.arange(target_num_blocks, dtype=torch.int32) + .unsqueeze(0) + .repeat(batch_size, 1) ) return ( input_ids, @@ -592,10 +651,12 @@ def _forward_with_pad(self, *args): block_kv_empty_args = args[5:7] block_kv_slot_mapping = args[7] block_kv_args = args[8:11] + extra_prefix_args = args[11:] else: block_kv_empty_args = args[5:11] block_kv_slot_mapping = args[11] block_kv_args = args[12:15] + extra_prefix_args = args[15:] # pad the inputs up to the compiled batch size in the end reorder_seq_ids = not self.is_prefix_caching @@ -679,6 +740,36 @@ def _forward_with_pad(self, *args): eagle_empty_args = args[11:16] for arg in eagle_empty_args: padded_args.append(arg) + else: + extra_prefix_arg_count = len(extra_prefix_args) + for extra_prefix_arg_index, arg in enumerate(extra_prefix_args): + if arg.numel() == 0: + padded_args.append(arg) + elif arg.dim() == 3 and arg.shape[0] == 3 and arg.shape[1] == seq_ids.shape[0]: + padded = torch.zeros( + (arg.shape[0], target_batch_size, arg.shape[2]), + dtype=arg.dtype, + ) + padded[:, : arg.shape[1], :] = arg + padded_args.append(padded) + elif arg.shape[0] == seq_ids.shape[0]: + if self._is_hybrid_apc_control_extra_arg( + extra_prefix_arg_index, + extra_prefix_arg_count, + ): + padded = torch.zeros( + (target_batch_size,) + tuple(arg.shape[1:]), + dtype=arg.dtype, + device=arg.device, + ) + padded[: arg.shape[0]] = arg + padded_args.append(padded) + else: + padded_args.append( + self._pad_helper(arg, pad_type="repeat_first_batchline") + ) + else: + padded_args.append(arg) outputs = self._forward(*padded_args) @@ -702,6 +793,55 @@ def _forward_with_pad(self, *args): logits, *kv_cache = outputs return [torch.index_select(logits, 0, seq_ids), *kv_cache] + def _is_hybrid_apc_control_extra_arg( + self, + extra_prefix_arg_index: int, + extra_prefix_arg_count: int, + ) -> bool: + if not getattr(self.config, "use_hybrid_apc_manager", False): + return False + if self.tag not in (CONTEXT_ENCODING_MODEL_TAG, TOKEN_GENERATION_MODEL_TAG): + return False + return ( + extra_prefix_arg_count >= _HYBRID_APC_MIN_EXTRA_PREFIX_ARG_COUNT + and extra_prefix_arg_index + >= extra_prefix_arg_count - _HYBRID_APC_CONTROL_EXTRA_ARG_COUNT + ) + + def _has_hybrid_apc_control_tail(self, args) -> bool: + if not getattr(self.config, "use_hybrid_apc_manager", False): + return False + if self.tag not in (CONTEXT_ENCODING_MODEL_TAG, TOKEN_GENERATION_MODEL_TAG): + return False + return len(args) >= ( + _PREFIX_CACHING_EXTRA_ARG_START + _HYBRID_APC_MIN_EXTRA_PREFIX_ARG_COUNT + ) + + def _hybrid_apc_restore_active_arg(self, args): + if self.tag != CONTEXT_ENCODING_MODEL_TAG: + return None + if not self._has_hybrid_apc_control_tail(args): + return None + + control_start = len(args) - _HYBRID_APC_CONTROL_EXTRA_ARG_COUNT + restore_active = args[ + control_start + _HYBRID_APC_RESTORE_ACTIVE_CONTROL_ARG_INDEX + ] + if not torch.is_tensor(restore_active): + raise RuntimeError( + "Hybrid APC argument contract mismatch: expected restore-active " + "tensor at index 1 of the final 5 Hybrid APC control args" + ) + return restore_active + + def _hybrid_apc_restore_active(self, args) -> bool: + restore_active = self._hybrid_apc_restore_active_arg(args) + return ( + restore_active is not None + and restore_active.numel() > 0 + and bool(restore_active.to(torch.bool).any().item()) + ) + def _forward(self, *args): if self.async_mode: return self._process_async_inputs(*args) @@ -960,8 +1100,28 @@ def get_target_2d_bucket_for_prefix_caching(self, *args, strategy="first_fit"): else: vertical_dim = args[13] horizontal_dim = args[14] + hybrid_apc_restore_active = self._hybrid_apc_restore_active(args) if not self.tag == CONTEXT_ENCODING_MODEL_TAG: + if self.tag == TOKEN_GENERATION_MODEL_TAG: + input_shape = getattr(args[0], "shape", ()) + batch_size = input_shape[0] if len(input_shape) > 0 else 1 + active_len = input_shape[-1] if len(input_shape) > 1 else 1 + vertical_dim = torch.full( + (batch_size, 1), active_len, dtype=torch.int32 + ) + if horizontal_dim.numel() == 0: + horizontal_dim = torch.full((args[0].shape[0], 1), args[1].shape[-1], dtype=torch.int32) + elif horizontal_dim.dim() == 0: + horizontal_dim = horizontal_dim.reshape(1, 1) + elif horizontal_dim.dim() == 1: + horizontal_dim = horizontal_dim.reshape(-1, 1) + if vertical_dim.numel() == 0: + vertical_dim = torch.full((args[0].shape[0], 1), args[0].shape[-1], dtype=torch.int32) + elif vertical_dim.dim() == 0: + vertical_dim = vertical_dim.reshape(1, 1) + elif vertical_dim.dim() == 1: + vertical_dim = vertical_dim.reshape(-1, 1) # Determine all buckets that meet horizontal condition horizontal_max = torch.max(horizontal_dim) horizontal_mask = buckets[:, 1] > horizontal_max + speculation_length @@ -984,84 +1144,488 @@ def get_target_2d_bucket_for_prefix_caching(self, *args, strategy="first_fit"): else: if not self.neuron_config.allow_input_truncation: raise ValueError( - f"Input len {vertical_dim} exceeds largest bucket ({buckets[-1][1]}) for {self.tag}" + f"Active len {vertical_dim} with context len {horizontal_dim} " + f"exceeds largest bucket ({buckets[-1].tolist()}) for {self.tag}" ) else: bucket_idx = -1 return buckets[bucket_idx] # recover the bucket for special handling else: - horizontal_dim = horizontal_dim[0][0] - vertical_dim = vertical_dim[0][0] - prefix_buckets = [] - prefill_buckets = [] - for b in buckets: - if b[0] not in prefill_buckets: - prefill_buckets.append(b[0]) - if b[1] not in prefix_buckets: - prefix_buckets.append(b[1]) + def _cte_bucket_dim_or_default(tensor, default_value): + if tensor.numel() == 0: + return torch.tensor(default_value, dtype=torch.int32) + values = tensor.reshape(-1).to(torch.int32) + batch_size = args[0].shape[0] if args[0].dim() > 0 else 1 + if batch_size > 1: + return torch.max(values) + return values[0] + + if horizontal_dim.numel() == 0: + horizontal_dim = torch.tensor(0, dtype=torch.int32) + else: + horizontal_dim = _cte_bucket_dim_or_default(horizontal_dim, 0) + default_vertical_dim = args[0].shape[-1] if args[0].dim() > 0 else 1 + if vertical_dim.numel() == 0: + vertical_dim = torch.tensor(default_vertical_dim, dtype=torch.int32) + else: + vertical_dim = _cte_bucket_dim_or_default( + vertical_dim, default_vertical_dim + ) + bucket_pairs = [ + (int(bucket[0].item()), int(bucket[1].item())) + for bucket in buckets + ] + prefill_buckets = sorted({bucket[0] for bucket in bucket_pairs}) + prefix_buckets = sorted({bucket[1] for bucket in bucket_pairs}) # Corner case total_context = vertical_dim + horizontal_dim - if total_context <= 512 and total_context > 256: + vertical_dim_int = int(vertical_dim.item()) + horizontal_dim_int = int(horizontal_dim.item()) + input_token_len = args[0].shape[-1] if args[0].dim() > 1 else 1 + suffix_only_cte_continuation = ( + horizontal_dim_int > 0 + and input_token_len <= vertical_dim_int + and input_token_len < vertical_dim_int + horizontal_dim_int + ) + if ( + not hybrid_apc_restore_active + and not suffix_only_cte_continuation + and total_context <= 512 + and total_context > 256 + ): for b in buckets: if b[0] == 512 and b[1] == 0: return b - # Select prefill bucket - prefill_index = 0 + # Select a compiled 2D bucket. NxDI's default prefix-cache buckets + # are a full CTE x prefix grid, but production artifacts may prune + # compiler-problematic high-prefix pairs. if self.neuron_config.enable_eagle_speculation: vertical_dim = vertical_dim + self.neuron_config.pa_block_size - for b in prefill_buckets: - if vertical_dim > b: - prefill_index += 1 - else: - break - # check prefill overflow - if prefill_index == len(prefill_buckets): - if not self.neuron_config.allow_input_truncation: - raise ValueError( - f"Prefill len {vertical_dim} exceeds largest bucket ({prefill_buckets[-1]}) for {self.tag}" + target_prefill_len = int(vertical_dim.item()) + target_prefix_len = int(horizontal_dim.item()) + + def _required_prefix_for_prefill(prefill_len): + empty_prefill_slots = max(0, prefill_len - target_prefill_len) + if self.neuron_config.enable_eagle_speculation: + # Calculate how many blocks can be moved from prefix to prefill. + empty_prefill_block_slots = ( + empty_prefill_slots // self.neuron_config.pa_block_size ) - else: - prefill_index = len(prefill_buckets) - 1 - # Select prefix bucket - prefill_len = prefill_buckets[prefill_index] - empty_prefill_slots = max(0, prefill_len - vertical_dim) - if self.neuron_config.enable_eagle_speculation: - # Calculate how many blocks can be moved from prefix to prefill. - empty_prefill_block_slots = empty_prefill_slots // self.neuron_config.pa_block_size - horizontal_dim = max(0, horizontal_dim - empty_prefill_block_slots * self.neuron_config.pa_block_size) - else: - horizontal_dim = max(0, horizontal_dim - empty_prefill_slots) - prefix_index = 0 - for b in prefix_buckets: - if horizontal_dim > b: - prefix_index += 1 - else: - break - # TODO: Handle this corner scenario by using the largest prefix bucket and up the prefill bucket - assert prefix_index != len(prefix_buckets), f"Prefix len {horizontal_dim} exceeds largest bucket {prefix_buckets[-1]} for {self.tag}" - bucket_idx = prefill_index * len(prefix_buckets) + prefix_index - return buckets[bucket_idx] + if not hybrid_apc_restore_active: + return max( + 0, + target_prefix_len + - empty_prefill_block_slots + * self.neuron_config.pa_block_size, + ) + elif not hybrid_apc_restore_active: + return max(0, target_prefix_len - empty_prefill_slots) + + # Restored Hybrid APC CTE uses the attention-mask tensor as an + # active suffix validity mask, so it is padded to the prefill + # bucket instead of the restored-prefix bucket. Route to a + # traced shape whose block table width matches that mask width. + return max(target_prefix_len, prefill_len) + + candidate_indices = [] + for bucket_idx, (prefill_len, prefix_len) in enumerate(bucket_pairs): + if prefill_len < target_prefill_len: + continue + if prefix_len < _required_prefix_for_prefill(prefill_len): + continue + candidate_indices.append(bucket_idx) + + if candidate_indices: + bucket_idx = min( + candidate_indices, + key=lambda idx: (bucket_pairs[idx][0], bucket_pairs[idx][1]), + ) + return buckets[bucket_idx] + + if self.neuron_config.allow_input_truncation: + return buckets[-1] + raise ValueError( + f"Prefill len {target_prefill_len} with prefix len " + f"{target_prefix_len} exceeds compiled 2D buckets for {self.tag}; " + f"largest prefill bucket {prefill_buckets[-1]}, largest prefix " + f"bucket {prefix_buckets[-1]}" + ) def _pad_prefix_caching_inputs(self, *args, pad_type="first_fit"): - if self.tag == CONTEXT_ENCODING_MODEL_TAG and args[0].shape[0] > 1: - # We delay all paddings for CTE until we really need them - return args + def _debug_int(value): + if hasattr(value, "item"): + return int(value.item()) + return int(value) + + def _debug_minmax(tensor): + if not hasattr(tensor, "numel") or tensor.numel() == 0: + return "empty" + flat = tensor.reshape(-1) + return f"{int(flat.min().item())}:{int(flat.max().item())}" + + debug_hybrid_apc = os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1" + hybrid_apc_restore_active = self._hybrid_apc_restore_active(args) + + def _first_or_default(tensor, default_value): + if tensor.numel() == 0: + return torch.tensor(default_value, dtype=torch.int32) + return tensor.reshape(-1)[0] + + def _length_matrix_or_default(tensor, default_value): + if tensor.numel() == 0: + return torch.full((args[0].shape[0], 1), default_value, dtype=torch.int32) + if tensor.dim() == 0: + return tensor.reshape(1, 1).to(torch.int32) + if tensor.dim() == 1: + return tensor.reshape(-1, 1).to(torch.int32) + return tensor.to(torch.int32) + + def _mask_block_table_to_prefix_lens( + block_table, + prefix_lens, + active_lens=None, + ): + if ( + block_table.numel() == 0 + or block_table.dim() < 2 + or prefix_lens.numel() == 0 + ): + return block_table + masked = block_table.clone() + flat_prefix_lens = prefix_lens.reshape(-1).to(torch.int64) + flat_active_lens = ( + active_lens.reshape(-1).to(torch.int64) + if active_lens is not None and active_lens.numel() > 0 + else None + ) + row_count = min(masked.shape[0], int(flat_prefix_lens.numel())) + block_size = int(self.neuron_config.pa_block_size) + for row_idx in range(row_count): + prefix_len = max(0, int(flat_prefix_lens[row_idx].item())) + if flat_active_lens is not None and row_idx < int(flat_active_lens.numel()): + prefix_len += max(0, int(flat_active_lens[row_idx].item())) + keep_blocks = min( + masked.shape[1], + (prefix_len + block_size - 1) // block_size, + ) + if keep_blocks < masked.shape[1]: + masked[row_idx, keep_blocks:] = 0 + return masked + + def _fill_segmented_cte_active_blocks( + block_table, + slot_mapping, + prefix_lens, + active_lens, + ): + if ( + block_table.numel() == 0 + or block_table.dim() < 2 + or slot_mapping.numel() == 0 + or slot_mapping.dim() < 2 + or prefix_lens.numel() == 0 + or active_lens.numel() == 0 + ): + return block_table + + block_size = int(self.neuron_config.pa_block_size) + flat_prefix_lens = prefix_lens.reshape(-1).to(torch.int64) + flat_active_lens = active_lens.reshape(-1).to(torch.int64) + row_count = min( + block_table.shape[0], + slot_mapping.shape[0], + int(flat_prefix_lens.numel()), + int(flat_active_lens.numel()), + ) + max_needed_blocks = block_table.shape[1] + for row_idx in range(row_count): + prefix_len = max(0, int(flat_prefix_lens[row_idx].item())) + active_len = max( + 0, + min( + int(flat_active_lens[row_idx].item()), + slot_mapping.shape[1], + ), + ) + max_needed_blocks = max( + max_needed_blocks, + (prefix_len + active_len + block_size - 1) // block_size, + ) + + patched = block_table + if max_needed_blocks > patched.shape[1]: + patched = F.pad( + patched, + (0, max_needed_blocks - patched.shape[1]), + "constant", + 0, + ) + patched = patched.clone() + + for row_idx in range(row_count): + prefix_len = max(0, int(flat_prefix_lens[row_idx].item())) + active_len = max( + 0, + min( + int(flat_active_lens[row_idx].item()), + slot_mapping.shape[1], + ), + ) + for token_idx in range(active_len): + slot = int(slot_mapping[row_idx, token_idx].item()) + if slot < 0: + continue + logical_block = (prefix_len + token_idx) // block_size + if logical_block >= patched.shape[1]: + continue + patched[row_idx, logical_block] = slot // block_size + return patched + # Calculate the buckets prefill_bucket, prefix_bucket = self.get_target_2d_bucket_for_prefix_caching(*args, strategy=pad_type) + use_segmented_prefix_cte = ( + getattr( + self.neuron_config, + "prefix_cte_attention_backend", + "attention_cte", + ) + == "segmented_cte" + ) + + def _prefix_block_table_blocks(prefix_tokens, active_tokens=0): + block_size = int(self.neuron_config.pa_block_size) + total_tokens = int(prefix_tokens) + if self.tag == CONTEXT_ENCODING_MODEL_TAG and use_segmented_prefix_cte: + total_tokens = min( + int(self.neuron_config.max_context_length), + total_tokens + int(active_tokens), + ) + return (total_tokens + block_size - 1) // block_size if self.tag == CONTEXT_ENCODING_MODEL_TAG: if self.neuron_config.enable_fused_speculation: slot_mapping = args[7] block_table = args[8] - prefill_len = args[9][0] - prefix_len = args[10][0] + prefill_len = _first_or_default(args[9], args[0].shape[-1]) + prefix_len = _first_or_default(args[10], 0) + num_queries = _length_matrix_or_default(args[9], prefill_len) + computed_context_lens = _length_matrix_or_default(args[10], prefix_len) else: slot_mapping = args[11] block_table = args[12] - prefill_len = args[13][0] - prefix_len = args[14][0] + prefill_len = _first_or_default(args[13], args[0].shape[-1]) + prefix_len = _first_or_default(args[14], 0) + num_queries = _length_matrix_or_default(args[13], prefill_len) + computed_context_lens = _length_matrix_or_default(args[14], prefix_len) + if slot_mapping.dim() == 1: + if args[0].shape[0] > 1 and slot_mapping.shape[0] == args[0].shape[0]: + slot_mapping = slot_mapping.view(args[0].shape[0], 1) + else: + slot_mapping = slot_mapping.view(1, -1) + if block_table.dim() == 1: + if args[0].shape[0] > 1 and block_table.shape[0] == args[0].shape[0]: + block_table = block_table.view(args[0].shape[0], 1) + else: + block_table = block_table.view(1, -1) + slot_mapping = slot_mapping.to(torch.int32) + block_table = block_table.to(torch.int32) + if use_segmented_prefix_cte: + block_table = _fill_segmented_cte_active_blocks( + block_table, + slot_mapping, + computed_context_lens, + num_queries, + ) + block_table = _mask_block_table_to_prefix_lens( + block_table, + computed_context_lens, + num_queries if use_segmented_prefix_cte else None, + ) + if args[0].shape[0] > 1: + prefill_len = torch.max(num_queries.reshape(-1)) + prefix_len = torch.max(computed_context_lens.reshape(-1)) + if debug_hybrid_apc: + print( + "[hybrid_apc_debug] pad-pre " + f"tag={self.tag} input_shape={tuple(args[0].shape)} " + f"attention_shape={tuple(args[1].shape)} " + f"position_shape={tuple(args[2].shape)} position_minmax={_debug_minmax(args[2])} " + f"slot_shape={tuple(slot_mapping.shape)} slot_minmax={_debug_minmax(slot_mapping)} " + f"block_shape={tuple(block_table.shape)} block_minmax={_debug_minmax(block_table)} " + f"prefill_len={_debug_int(prefill_len)} prefix_len={_debug_int(prefix_len)} " + f"prefill_bucket={prefill_bucket} prefix_bucket={prefix_bucket}", + flush=True, + ) + if args[0].shape[0] > 1: + batch_size = args[0].shape[0] + prefill_bucket_int = _debug_int(prefill_bucket) + prefix_bucket_int = _debug_int(prefix_bucket) + + def _right_pad_or_trim_dim1(tensor, target_len, pad_value): + if tensor.shape[1] > target_len: + return tensor[:, :target_len] + return F.pad( + tensor, + (0, target_len - tensor.shape[1]), + "constant", + pad_value, + ) + + def _restore_attention_mask(tensor, target_len): + tensor = tensor.to(torch.int32) + if ( + target_len > prefill_bucket_int + and tensor.shape[1] <= prefill_bucket_int + ): + full_context_lens = ( + computed_context_lens.reshape(-1).to(torch.int64) + + num_queries.reshape(-1).to(torch.int64) + ) + full_mask = torch.zeros( + (batch_size, target_len), + dtype=torch.int32, + device=tensor.device, + ) + for row_idx in range( + min(batch_size, int(full_context_lens.numel())) + ): + active_len = max( + 0, + min( + int(full_context_lens[row_idx].item()), + target_len, + ), + ) + if active_len: + full_mask[row_idx, :active_len] = 1 + return full_mask + return _right_pad_or_trim_dim1(tensor, target_len, 0) + + padded_inputs = _right_pad_or_trim_dim1( + args[0], prefill_bucket_int, self.config.pad_token_id + ) + padded_position_id = _right_pad_or_trim_dim1( + args[2], prefill_bucket_int, 1 + ) + padded_slot_mapping = _right_pad_or_trim_dim1( + slot_mapping, prefill_bucket_int, -1 + ) + + if hybrid_apc_restore_active: + active_attn_mask = args[1] + if active_attn_mask.dim() == 1: + active_attn_mask = active_attn_mask.view(1, -1) + if active_attn_mask.shape[0] < batch_size: + pad_rows = torch.zeros( + ( + batch_size - active_attn_mask.shape[0], + active_attn_mask.shape[1], + ), + dtype=active_attn_mask.dtype, + device=active_attn_mask.device, + ) + active_attn_mask = torch.cat([active_attn_mask, pad_rows], dim=0) + elif active_attn_mask.shape[0] > batch_size: + active_attn_mask = active_attn_mask[:batch_size] + attention_target_len = ( + prefix_bucket_int + if prefix_bucket_int > 0 + else prefill_bucket_int + ) + padded_attn_mask = _restore_attention_mask( + active_attn_mask, + attention_target_len, + ) + elif prefix_bucket_int == 0: + padded_attn_mask = torch.zeros( + 1, dtype=torch.int32, device=args[1].device + ) + else: + padded_attn_mask = torch.zeros( + (batch_size, prefix_bucket_int), + dtype=torch.int32, + device=args[1].device, + ) + prefix_lengths = computed_context_lens.reshape(-1).to(torch.int64) + for row_idx in range(min(batch_size, int(prefix_lengths.numel()))): + row_prefix_len = max( + 0, + min( + int(prefix_lengths[row_idx].item()), + prefix_bucket_int, + ), + ) + if row_prefix_len: + padded_attn_mask[row_idx, :row_prefix_len] = 1 + + if prefix_bucket_int == 0 and not use_segmented_prefix_cte: + padded_block_table = torch.zeros( + 1, dtype=torch.int32, device=block_table.device + ) + else: + num_blocks = _prefix_block_table_blocks( + prefix_bucket_int, prefill_bucket + ) + if block_table.shape[0] < batch_size: + pad_rows = torch.zeros( + (batch_size - block_table.shape[0], block_table.shape[1]), + dtype=block_table.dtype, + device=block_table.device, + ) + block_table = torch.cat([block_table, pad_rows], dim=0) + elif block_table.shape[0] > batch_size: + block_table = block_table[:batch_size] + if block_table.shape[1] > num_blocks: + padded_block_table = block_table[:, :num_blocks] + else: + padded_block_table = F.pad( + block_table, + (0, num_blocks - block_table.shape[1]), + "constant", + 0, + ) + + if self.neuron_config.enable_fused_speculation: + args = ( + padded_inputs, + padded_attn_mask, + padded_position_id, + *args[3:7], + padded_slot_mapping, + padded_block_table, + num_queries, + computed_context_lens, + *args[11:], + ) + else: + args = ( + padded_inputs, + padded_attn_mask, + padded_position_id, + *args[3:11], + padded_slot_mapping, + padded_block_table, + num_queries, + computed_context_lens, + *args[15:], + ) + if debug_hybrid_apc: + print( + "[hybrid_apc_debug] pad-post " + f"tag={self.tag} batched_cte=1 " + f"prefill_bucket={prefill_bucket} prefix_bucket={prefix_bucket} " + f"padded_input_shape={tuple(padded_inputs.shape)} " + f"padded_attention_shape={tuple(padded_attn_mask.shape)} " + f"padded_position_shape={tuple(padded_position_id.shape)} " + f"padded_slot_shape={tuple(padded_slot_mapping.shape)} " + f"padded_slot_minmax={_debug_minmax(padded_slot_mapping)} " + f"padded_block_shape={tuple(padded_block_table.shape)} " + f"padded_block_minmax={_debug_minmax(padded_block_table)}", + flush=True, + ) + return tuple(args) if self.neuron_config.enable_eagle_speculation: target_recomputation = 0 if prefix_bucket == 0 else self.neuron_config.pa_block_size extra_prefill_slots = max(0, prefill_bucket - prefill_len - target_recomputation) @@ -1095,21 +1659,60 @@ def _pad_prefix_caching_inputs(self, *args, pad_type="first_fit"): target_padded_slot_mapping = F.pad(slot_mapping, (prefix_len - target_adjusted_prefix_len, 0), "constant", -1) target_padded_slot_mapping = F.pad(target_padded_slot_mapping, (0, prefill_bucket - target_padded_slot_mapping.shape[1]), "constant", -1) - num_blocks = prefix_bucket // self.neuron_config.pa_block_size + num_blocks = _prefix_block_table_blocks( + prefix_bucket, prefill_bucket + ) if num_blocks == 0: padded_block_table = torch.zeros(1, dtype=torch.int) target_padded_block_table = torch.zeros(1, dtype=torch.int) else: padded_block_table = F.pad(block_table, (0, num_blocks - block_table.shape[1]), "constant", 0) target_padded_block_table = F.pad(block_table, (0, num_blocks - block_table.shape[1]), "constant", 0) - args = (padded_inputs, padded_attn_mask, padded_position_id, *args[3:7], padded_slot_mapping, padded_block_table, *args[9:11], target_padded_inputs, target_padded_attn_mask, target_padded_position_id, target_padded_slot_mapping, target_padded_block_table) + args = (padded_inputs, padded_attn_mask, padded_position_id, *args[3:7], padded_slot_mapping, padded_block_table, num_queries, computed_context_lens, target_padded_inputs, target_padded_attn_mask, target_padded_position_id, target_padded_slot_mapping, target_padded_block_table) return tuple(args) else: extra_prefill_slots = max(0, prefill_bucket - prefill_len) - adjusted_prefix_len = max(0, prefix_len - extra_prefill_slots) - sliced_inputs = args[0][:, adjusted_prefix_len:] - sliced_attn_mask = args[1][:, :adjusted_prefix_len] - sliced_position_id = args[2][:, adjusted_prefix_len:] + suffix_only_cte_continuation = ( + _debug_int(prefix_len) > 0 + and args[0].shape[-1] <= _debug_int(prefill_len) + and args[0].shape[-1] < _debug_int(prefill_len) + _debug_int(prefix_len) + ) + if hybrid_apc_restore_active: + # Hybrid APC request prep has already sliced input_ids, + # attention_mask, position_ids, and slot_mapping to the + # suffix. Preserve those tensors and keep computed_context + # as the restored prefix length. + adjusted_prefix_len = prefix_len + sliced_inputs = args[0] + sliced_attn_mask = args[1] + sliced_position_id = args[2] + elif suffix_only_cte_continuation: + # Qwen chunked-prefill continuations are already suffix-only + # but still need the prefix bucket/mask for block-KV attention. + adjusted_prefix_len = prefix_len + sliced_inputs = args[0] + sliced_position_id = args[2] + prefix_bucket_int = _debug_int(prefix_bucket) + sliced_attn_mask = torch.zeros( + (args[0].shape[0], prefix_bucket_int), + dtype=args[1].dtype, + device=args[1].device, + ) + prefix_lengths = computed_context_lens.reshape(-1).to(torch.int64) + for row_idx in range( + min(args[0].shape[0], int(prefix_lengths.numel())) + ): + row_prefix_len = max( + 0, + min(int(prefix_lengths[row_idx].item()), prefix_bucket_int), + ) + if row_prefix_len: + sliced_attn_mask[row_idx, :row_prefix_len] = 1 + else: + adjusted_prefix_len = max(0, prefix_len - extra_prefill_slots) + sliced_inputs = args[0][:, adjusted_prefix_len:] + sliced_attn_mask = args[1][:, :adjusted_prefix_len] + sliced_position_id = args[2][:, adjusted_prefix_len:] padded_inputs = F.pad(sliced_inputs, (0, prefill_bucket - sliced_inputs.shape[1]), "constant", self.config.pad_token_id) if prefix_bucket == 0: @@ -1117,29 +1720,73 @@ def _pad_prefix_caching_inputs(self, *args, pad_type="first_fit"): else: padded_attn_mask = F.pad(sliced_attn_mask, (0, prefix_bucket - sliced_attn_mask.shape[1]), "constant", 0) padded_position_id = F.pad(sliced_position_id, (0, prefill_bucket - sliced_position_id.shape[1]), "constant", 1) - padded_slot_mapping = F.pad(slot_mapping, (prefix_len - adjusted_prefix_len, 0), "constant", -1) + left_slot_pad = ( + 0 + if hybrid_apc_restore_active or suffix_only_cte_continuation + else prefix_len - adjusted_prefix_len + ) + padded_slot_mapping = F.pad(slot_mapping, (left_slot_pad, 0), "constant", -1) padded_slot_mapping = F.pad(padded_slot_mapping, (0, prefill_bucket - padded_slot_mapping.shape[1]), "constant", -1) - num_blocks = prefix_bucket // self.neuron_config.pa_block_size + num_blocks = _prefix_block_table_blocks( + prefix_bucket, prefill_bucket + ) if num_blocks == 0: padded_block_table = torch.zeros(1, dtype=torch.int) else: padded_block_table = F.pad(block_table, (0, num_blocks - block_table.shape[1]), "constant", 0) if self.neuron_config.enable_fused_speculation: - args = (padded_inputs, padded_attn_mask, padded_position_id, *args[3:7], padded_slot_mapping, padded_block_table, *args[9:]) + args = (padded_inputs, padded_attn_mask, padded_position_id, *args[3:7], padded_slot_mapping, padded_block_table, num_queries, computed_context_lens, *args[11:]) else: - args = (padded_inputs, padded_attn_mask, padded_position_id, *args[3:11], padded_slot_mapping, padded_block_table, *args[13:]) + args = (padded_inputs, padded_attn_mask, padded_position_id, *args[3:11], padded_slot_mapping, padded_block_table, num_queries, computed_context_lens, *args[15:]) + if debug_hybrid_apc: + print( + "[hybrid_apc_debug] pad-post " + f"tag={self.tag} adjusted_prefix_len={_debug_int(adjusted_prefix_len)} " + f"extra_prefill_slots={_debug_int(extra_prefill_slots)} " + f"padded_input_shape={tuple(padded_inputs.shape)} " + f"padded_attention_shape={tuple(padded_attn_mask.shape)} " + f"padded_position_shape={tuple(padded_position_id.shape)} " + f"padded_slot_shape={tuple(padded_slot_mapping.shape)} " + f"padded_slot_minmax={_debug_minmax(padded_slot_mapping)} " + f"padded_block_shape={tuple(padded_block_table.shape)} " + f"padded_block_minmax={_debug_minmax(padded_block_table)}", + flush=True, + ) return tuple(args) else: padded_attn_mask = F.pad(args[1], (0, prefix_bucket - args[1].shape[1]), "constant", 0) + slot_mapping_arg_idx = 7 if self.neuron_config.enable_fused_speculation else 11 block_table_arg_idx = 8 if self.neuron_config.enable_fused_speculation else 12 + num_queries_arg_idx = 9 if self.neuron_config.enable_fused_speculation else 13 + computed_context_lens_arg_idx = 10 if self.neuron_config.enable_fused_speculation else 14 + slot_mapping = args[slot_mapping_arg_idx] block_table = args[block_table_arg_idx] - pad_right = (prefix_bucket // self.neuron_config.pa_block_size) - block_table.shape[1] + if slot_mapping.dim() == 1: + slot_mapping = slot_mapping.view(1, -1) + if block_table.dim() == 1: + block_table = block_table.view(1, -1) + slot_mapping = slot_mapping.to(torch.int32) + block_table = block_table.to(torch.int32) + padded_slot_mapping = F.pad(slot_mapping, (0, prefill_bucket - slot_mapping.shape[1]), "constant", -1) + pad_right = _prefix_block_table_blocks(prefix_bucket) - block_table.shape[1] block_table_padding = -1 if self.neuron_config.attn_block_tkg_nki_kernel_enabled else 0 padded_block_table = F.pad(block_table, (0, pad_right), "constant", block_table_padding) + if self.tag == TOKEN_GENERATION_MODEL_TAG: + num_queries = torch.full( + (args[0].shape[0], 1), + args[0].shape[-1], + dtype=torch.int32, + ) + else: + num_queries = _length_matrix_or_default(args[num_queries_arg_idx], args[0].shape[-1]) + computed_context_lens = _length_matrix_or_default(args[computed_context_lens_arg_idx], args[1].shape[-1]) new_args = list(args) new_args[1] = padded_attn_mask + new_args[slot_mapping_arg_idx] = padded_slot_mapping new_args[block_table_arg_idx] = padded_block_table + new_args[num_queries_arg_idx] = num_queries + new_args[computed_context_lens_arg_idx] = computed_context_lens return tuple(new_args) def _process_async_inputs(self, *args): @@ -1285,7 +1932,11 @@ def _process_args(self, *args): # set hidden_states if None if args[5] is None: - dummy_hidden_states = torch.zeros((input_batch_size), dtype=torch.int32) + dummy_hidden_states = ( + torch.empty(0) + if self.is_prefix_caching + else torch.zeros((input_batch_size), dtype=torch.int32) + ) args = (*args[:5], dummy_hidden_states, *args[6:]) # set adapter_ids if None diff --git a/src/neuronx_distributed_inference/modules/async_execution.py b/src/neuronx_distributed_inference/modules/async_execution.py index 5e67aa92..c7343a4c 100644 --- a/src/neuronx_distributed_inference/modules/async_execution.py +++ b/src/neuronx_distributed_inference/modules/async_execution.py @@ -1,3 +1,5 @@ +import os +import time from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union import torch @@ -7,6 +9,2271 @@ from neuronx_distributed_inference.models.model_wrapper import ModelWrapper +def _is_hybrid_apc_enabled(neuron_base_instance: "NeuronBaseForCausalLM") -> bool: + for owner in ( + neuron_base_instance, + getattr(neuron_base_instance, "config", None), + getattr(neuron_base_instance, "neuron_config", None), + getattr(getattr(neuron_base_instance, "config", None), "neuron_config", None), + ): + if bool(getattr(owner, "use_hybrid_apc_manager", False)): + return True + return False + + +def _async_request_ids_signature(neuron_base_instance: "NeuronBaseForCausalLM"): + request_ids = getattr(neuron_base_instance, "_qwen36_vllm_request_ids", None) + if request_ids is None: + return None + if isinstance(request_ids, torch.Tensor): + return tuple(request_ids.detach().cpu().reshape(-1).tolist()) + if isinstance(request_ids, (str, bytes)): + return (request_ids,) + try: + return tuple(request_ids) + except TypeError: + return (request_ids,) + + +def _batch_vector( + input_dict: Dict[str, Any], + key: str, + *, + batch_size: int, + default: int = 0, +) -> torch.Tensor: + value = input_dict.get(key) + if value is None: + return torch.full((batch_size,), default, dtype=torch.int32) + value = value.reshape(-1).to(torch.int32) + if value.shape[0] == batch_size: + return value + if value.shape[0] > batch_size: + return value[:batch_size] + pad = torch.full((batch_size - value.shape[0],), default, dtype=value.dtype) + return torch.cat([value, pad], dim=0) + + +def _first_present(*values): + for value in values: + if value is not None: + return value + return None + + +def _to_python_int(value: Any) -> int: + if isinstance(value, torch.Tensor): + return int(value.reshape(-1)[0].item()) + return int(value) + + +def _single_batch_value(value: Any): + if value is None: + return None + if isinstance(value, torch.Tensor): + flat = value.reshape(-1) + if flat.numel() != 1: + return None + return flat[0] + if isinstance(value, (list, tuple)): + if len(value) != 1: + return None + return value[0] + return value + + +def _multi_batch_int_values(value: Any) -> list[int] | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + flat = value.reshape(-1) + if flat.numel() <= 1: + return None + return [int(item.item()) for item in flat] + if isinstance(value, (list, tuple)): + if len(value) <= 1: + return None + try: + return [int(item) for item in value] + except (TypeError, ValueError): + return None + return None + + +def _single_batch_tensor(value: Any) -> bool: + return isinstance(value, torch.Tensor) and value.ndim >= 1 and value.shape[0] == 1 + + +def _truthy_single_value(value: Any) -> bool: + item = _single_batch_value(value) + if item is None: + return False + if isinstance(item, torch.Tensor): + return bool(item.item()) + return bool(item) + + +def _request_id_matches(candidate: Any, request_id: Any) -> bool: + if candidate == request_id: + return True + try: + return str(candidate) == str(request_id) + except Exception: + return False + + +def _request_id_in_collection(request_id: Any, values: Any) -> bool: + if request_id is None or values is None: + return False + if isinstance(values, torch.Tensor): + values = values.reshape(-1).tolist() + elif isinstance(values, (str, bytes)): + values = (values,) + try: + iterator = iter(values) + except TypeError: + return _request_id_matches(values, request_id) + return any(_request_id_matches(value, request_id) for value in iterator) + + +def _as_hybrid_apc_request_id_tuple(value: Any) -> tuple[Any, ...] | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + return tuple(value.detach().cpu().reshape(-1).tolist()) + if isinstance(value, (str, bytes)): + return (value,) + try: + return tuple(value) + except TypeError: + return (value,) + + +def _lookup_hybrid_apc_metadata_for_request(metadata_by_request_id: Any, request_id: Any): + if not isinstance(metadata_by_request_id, dict): + return None + for key in (request_id, str(request_id)): + metadata = metadata_by_request_id.get(key) + if isinstance(metadata, dict): + return metadata + for key, metadata in metadata_by_request_id.items(): + if _request_id_matches(key, request_id) and isinstance(metadata, dict): + return metadata + return None + + +def _with_hybrid_apc_owner_metadata( + input_dict: Dict[str, Any], + owner: Any, +) -> Dict[str, Any]: + output = input_dict + request_records = getattr(owner, "_qwen36_vllm_hybrid_apc_request_records", None) + if request_records is not None and "hybrid_request_records" not in output: + output = dict(output) + output["hybrid_request_records"] = request_records + + request_ids = _as_hybrid_apc_request_id_tuple( + getattr(owner, "_qwen36_vllm_request_ids", None) + ) + metadata_by_request_id = getattr( + owner, + "_qwen36_vllm_hybrid_apc_metadata_by_request_id", + None, + ) + if ( + request_records is None + and request_ids + and isinstance(metadata_by_request_id, dict) + and "hybrid_request_records" not in output + ): + records = [] + for request_id in request_ids: + metadata = _lookup_hybrid_apc_metadata_for_request( + metadata_by_request_id, + request_id, + ) + if metadata is None: + continue + record = {"request_id": request_id} + record.update(metadata) + records.append(record) + if records: + output = dict(output) + output["hybrid_request_records"] = tuple(records) + + if request_ids and "hybrid_request_id" not in output: + output = dict(output) + output["hybrid_request_id"] = request_ids[0] if len(request_ids) == 1 else request_ids + + for attr, key in ( + ("_qwen36_vllm_cached_request_ids", "hybrid_cached_request_ids"), + ("_qwen36_vllm_prefill_completion_state", "hybrid_prefill_completion_state"), + ): + value = getattr(owner, attr, None) + if value is not None and key not in output: + output = dict(output) + output[key] = value + return output + + +def _with_hybrid_apc_candidate_owner_metadata( + input_dict: Dict[str, Any], + *owners: Any, +) -> Dict[str, Any]: + output = input_dict + seen: set[int] = set() + for owner in owners: + if owner is None: + continue + owner_id = id(owner) + if owner_id in seen: + continue + seen.add(owner_id) + output = _with_hybrid_apc_owner_metadata(output, owner) + return output + + +def _batch_size_from_input_dict(input_dict: Dict[str, Any]) -> int: + batch_size = 1 + for key in ( + "input_ids", + "seq_ids", + "computed_context_lens", + "full_context_lens", + "vllm_attention_hit_len", + "hybrid_attention_hit_len", + "attention_hit_len", + "hybrid_prefill_completion_state", + "hybrid_active_suffix_len", + "active_suffix_len", + "hybrid_request_records", + ): + value = input_dict.get(key) + if isinstance(value, torch.Tensor) and value.ndim >= 1: + batch_size = max(batch_size, int(value.reshape(value.shape[0], -1).shape[0])) + elif isinstance(value, (list, tuple)): + batch_size = max(batch_size, len(value)) + return batch_size + + +def _batch_int_list( + input_dict: Dict[str, Any], + key: str, + *, + batch_size: int, +) -> list[int] | None: + value = input_dict.get(key) + if value is None: + return None + if isinstance(value, torch.Tensor): + flat = value.reshape(-1) + if flat.numel() < batch_size: + return None + return [int(item.item()) for item in flat[:batch_size]] + if isinstance(value, (list, tuple)) and len(value) >= batch_size: + try: + return [int(item) for item in value[:batch_size]] + except (TypeError, ValueError): + return None + return None + + +def _vectorized_query_lengths( + input_dict: Dict[str, Any], + *, + batch_size: int, +) -> list[int] | None: + num_queries = _batch_int_list(input_dict, "num_queries", batch_size=batch_size) + if num_queries is not None: + return num_queries + + active_suffix_len = _batch_int_list( + input_dict, + "hybrid_active_suffix_len", + batch_size=batch_size, + ) + if active_suffix_len is None: + active_suffix_len = _batch_int_list( + input_dict, + "active_suffix_len", + batch_size=batch_size, + ) + if active_suffix_len is not None: + return [max(0, int(query_len)) for query_len in active_suffix_len] + + records = _hybrid_apc_request_records(input_dict, batch_size=batch_size) + record_active_suffix_len = _hybrid_apc_record_values(records, "active_suffix_len") + if ( + isinstance(record_active_suffix_len, (list, tuple)) + and len(record_active_suffix_len) >= batch_size + ): + try: + return [ + max(0, int(query_len)) + for query_len in record_active_suffix_len[:batch_size] + ] + except (TypeError, ValueError): + pass + if batch_size == 1 and record_active_suffix_len is not None: + try: + return [max(0, int(record_active_suffix_len))] + except (TypeError, ValueError): + pass + + full_context_lens = _batch_int_list( + input_dict, + "full_context_lens", + batch_size=batch_size, + ) + computed_context_lens = _batch_int_list( + input_dict, + "computed_context_lens", + batch_size=batch_size, + ) + if full_context_lens is not None and computed_context_lens is not None: + return [ + max(0, full_len - computed_len) + for full_len, computed_len in zip(full_context_lens, computed_context_lens) + ] + + input_ids = input_dict.get("input_ids") + if ( + isinstance(input_ids, torch.Tensor) + and input_ids.ndim >= 2 + and input_ids.shape[0] == 1 + and input_ids.shape[1] % batch_size == 0 + ): + return [input_ids.shape[1] // batch_size] * batch_size + return None + + +def _select_batch_item( + value: Any, + index: int, + batch_size: int, + *, + key: str = "", + query_lengths: list[int] | None = None, +): + if isinstance(value, torch.Tensor): + if value.numel() == 0 or value.ndim == 0: + return value + if ( + key in {"rotary_position_id", "rotary_position_ids"} + and value.ndim >= 2 + and value.shape[1] == batch_size + ): + return value[:, index : index + 1, ...] + if value.shape[0] == batch_size: + return value[index : index + 1] + if ( + value.ndim >= 2 + and value.shape[0] == 1 + and query_lengths is not None + and key + in { + "input_ids", + "attention_mask", + "position_ids", + "slot_mapping", + "inputs_embeds", + } + ): + offset = sum(query_lengths[:index]) + length = query_lengths[index] + if offset + length <= value.shape[1]: + return value[:, offset : offset + length, ...] + if ( + key in {"seq_ids", "adapter_ids"} + and value.ndim == 1 + and value.shape[0] == 1 + and batch_size > 1 + ): + fill_value = index if key == "seq_ids" else int(value.reshape(-1)[0].item()) + return torch.tensor([fill_value], dtype=value.dtype, device=value.device) + return value + if key == "llava_args" and isinstance(value, (list, tuple)): + return [ + _select_batch_item( + item, + index, + batch_size, + key=f"{key}[{idx}]", + query_lengths=query_lengths, + ) + for idx, item in enumerate(value) + ] + if isinstance(value, tuple) and len(value) == batch_size: + return value[index] + if isinstance(value, list) and len(value) == batch_size: + return value[index] + return value + + +def _hybrid_apc_request_records( + input_dict: Dict[str, Any], + *, + batch_size: int, +) -> tuple[dict[str, Any], ...] | None: + records = input_dict.get("hybrid_request_records") + if records is None: + return None + if isinstance(records, dict): + records = (records,) + elif isinstance(records, list): + records = tuple(records) + if not isinstance(records, tuple): + return None + if len(records) != batch_size: + raise ValueError( + "hybrid APC request record count must match batch size: " + f"records={len(records)} batch_size={batch_size}" + ) + if not all(isinstance(record, dict) for record in records): + raise ValueError("hybrid APC request records must be dictionaries") + return records + + +def _hybrid_apc_record_values( + records: tuple[dict[str, Any], ...] | None, + key: str, +): + if not records: + return None + values = [record.get(key) for record in records] + if not any(value is not None for value in values): + return None + return values[0] if len(values) == 1 else tuple(values) + + +def _apply_hybrid_apc_request_record( + row_input: Dict[str, Any], + record: dict[str, Any] | None, +) -> None: + if not isinstance(record, dict): + return + for source_key, target_key in ( + ("request_id", "hybrid_request_id"), + ("vllm_attention_hit_len", "vllm_attention_hit_len"), + ("request_prefix_len", "request_prefix_len"), + ("cumulative_hashes_by_prefix_len", "cumulative_hashes_by_prefix_len"), + ("attention_block_refs_by_prefix_len", "attention_block_refs_by_prefix_len"), + ("active_suffix_len", "hybrid_active_suffix_len"), + ("full_input_ids", "hybrid_full_input_ids"), + ): + value = record.get(source_key) + if value is not None: + if source_key == "full_input_ids" and not isinstance(value, torch.Tensor): + input_ids = row_input.get("input_ids") + dtype = ( + input_ids.dtype + if isinstance(input_ids, torch.Tensor) + else torch.int64 + ) + device = ( + input_ids.device + if isinstance(input_ids, torch.Tensor) + else None + ) + value = torch.tensor([list(value)], dtype=dtype, device=device) + row_input[target_key] = value + + +def _pad_value_for_key( + neuron_base_instance: "NeuronBaseForCausalLM", + key: str, +) -> int: + if key == "input_ids": + return int(getattr(neuron_base_instance.config, "pad_token_id", 0) or 0) + if key in { + "attention_mask", + "hybrid_restore_mask", + "hybrid_restore_prefix_lens", + "hybrid_commit_mask", + }: + return 0 + if key in {"position_ids", "rotary_position_id", "rotary_position_ids"}: + return 1 + if key == "slot_mapping": + return -1 + return 0 + + +def _right_pad_dim1(tensor: torch.Tensor, target_len: int, pad_value: int) -> torch.Tensor: + if tensor.ndim < 2 or tensor.shape[1] == target_len: + return tensor + if tensor.shape[1] > target_len: + raise ValueError( + f"cannot pad tensor with dim1 {tensor.shape[1]} down to {target_len}" + ) + pad_shape = list(tensor.shape) + pad_shape[1] = target_len - tensor.shape[1] + pad = torch.full( + tuple(pad_shape), + pad_value, + dtype=tensor.dtype, + device=tensor.device, + ) + return torch.cat([tensor, pad], dim=1) + + +def _right_pad_last_dim( + tensor: torch.Tensor, + target_len: int, + pad_value: int, +) -> torch.Tensor: + if tensor.shape[-1] == target_len: + return tensor + if tensor.shape[-1] > target_len: + raise ValueError( + f"cannot pad tensor with last dim {tensor.shape[-1]} down to {target_len}" + ) + pad_shape = list(tensor.shape) + pad_shape[-1] = target_len - tensor.shape[-1] + pad = torch.full( + tuple(pad_shape), + pad_value, + dtype=tensor.dtype, + device=tensor.device, + ) + return torch.cat([tensor, pad], dim=-1) + + +def _resize_dim1(tensor: torch.Tensor, target_len: int, pad_value: int) -> torch.Tensor: + if tensor.ndim < 2 or tensor.shape[1] == target_len: + return tensor + if tensor.shape[1] > target_len: + return tensor[:, :target_len, ...] + return _right_pad_dim1(tensor, target_len, pad_value) + + +def _configured_cte_bucket_len( + neuron_base_instance: "NeuronBaseForCausalLM", + current_len: int, +) -> int: + bucket_sources = ( + getattr( + getattr(neuron_base_instance, "neuron_config", None), + "context_encoding_buckets", + None, + ), + getattr( + getattr( + getattr(neuron_base_instance, "context_encoding_model", None), + "neuron_config", + None, + ), + "context_encoding_buckets", + None, + ), + getattr( + getattr( + getattr(neuron_base_instance, "context_encoding_model", None), + "neuron_config", + None, + ), + "buckets", + None, + ), + ) + buckets: list[int] = [] + for source in bucket_sources: + if source is None: + continue + for bucket in source: + if isinstance(bucket, (list, tuple)): + if not bucket: + continue + bucket = bucket[0] + try: + buckets.append(int(bucket)) + except (TypeError, ValueError): + continue + if buckets: + break + for bucket in sorted(set(buckets)): + if current_len <= bucket: + return bucket + return current_len + + +def _pa_block_size(neuron_base_instance: "NeuronBaseForCausalLM") -> int | None: + for owner in ( + getattr(neuron_base_instance, "neuron_config", None), + getattr(getattr(neuron_base_instance, "config", None), "neuron_config", None), + ): + block_size = getattr(owner, "pa_block_size", None) + if block_size: + return int(block_size) + return None + + +def _active_block_table_target_len( + neuron_base_instance: "NeuronBaseForCausalLM", + target_context_len: int | None, +) -> int | None: + if target_context_len is None: + return None + block_size = _pa_block_size(neuron_base_instance) + if not block_size: + return None + return max(1, (int(target_context_len) + block_size - 1) // block_size) + + +def _restore_block_table_target_len( + neuron_base_instance: "NeuronBaseForCausalLM", + row_input_dicts: list[Dict[str, Any]], +) -> int | None: + block_size = _pa_block_size(neuron_base_instance) + if not block_size: + return None + target_len = 0 + for row_input in row_input_dicts: + restore_mask = _single_batch_value(row_input.get("hybrid_restore_mask")) + if restore_mask is None or _to_python_int(restore_mask) <= 0: + continue + restore_prefix_len = _single_batch_value( + row_input.get("hybrid_restore_prefix_lens") + ) + if restore_prefix_len is None: + continue + restore_blocks = ( + _to_python_int(restore_prefix_len) + block_size - 1 + ) // block_size + target_len = max(target_len, restore_blocks) + return target_len or None + + +def _configured_max_context_len( + neuron_base_instance: "NeuronBaseForCausalLM", +) -> int | None: + owners = ( + getattr(neuron_base_instance, "neuron_config", None), + getattr(getattr(neuron_base_instance, "config", None), "neuron_config", None), + getattr( + getattr(neuron_base_instance, "context_encoding_model", None), + "neuron_config", + None, + ), + getattr(neuron_base_instance, "config", None), + ) + for owner in owners: + if owner is None: + continue + for attr in ( + "seq_len", + "max_context_length", + "max_model_len", + "max_position_embeddings", + ): + value = getattr(owner, attr, None) + if value is None: + continue + try: + value = int(value) + except (TypeError, ValueError): + continue + if value > 0: + return value + return None + + +def _full_block_table_target_len( + neuron_base_instance: "NeuronBaseForCausalLM", + row_input_dicts: list[Dict[str, Any]], +) -> int | None: + target_len = 0 + block_size = _pa_block_size(neuron_base_instance) + max_context_len = _configured_max_context_len(neuron_base_instance) + if block_size and max_context_len: + target_len = max(target_len, (max_context_len + block_size - 1) // block_size) + for row_input in row_input_dicts: + block_table = row_input.get("block_table") + if not isinstance(block_table, torch.Tensor) or block_table.numel() == 0: + continue + if block_table.ndim >= 2: + target_len = max(target_len, int(block_table.shape[1])) + elif block_table.ndim == 1: + target_len = max(target_len, int(block_table.numel())) + return target_len or None + + +def _uses_block_backed_restore(row_input_dicts: list[Dict[str, Any]]) -> bool: + for row_input in row_input_dicts: + block_table = row_input.get("block_table") + if not isinstance(block_table, torch.Tensor) or block_table.numel() == 0: + continue + restore_mask = _single_batch_value(row_input.get("hybrid_restore_mask")) + if restore_mask is not None and _to_python_int(restore_mask) > 0: + return True + computed_len = _single_batch_value(row_input.get("computed_context_lens")) + if computed_len is not None and _to_python_int(computed_len) > 0: + return True + return False + + +def _synthesize_slots_from_block_table( + *, + block_table_row: torch.Tensor, + position_row: torch.Tensor, + q_len: int, + block_size: int, + dtype: torch.dtype, +) -> torch.Tensor | None: + if q_len <= 0: + return torch.empty((0,), dtype=dtype, device=position_row.device) + positions = position_row[:q_len].to(torch.int64) + logical_blocks = torch.div(positions, block_size, rounding_mode="floor") + if logical_blocks.numel() == 0: + return torch.empty((0,), dtype=dtype, device=position_row.device) + block_table_i64 = block_table_row.to(torch.int64) + nonzero_block_indices = (block_table_i64 != 0).nonzero(as_tuple=False).reshape(-1) + if int(nonzero_block_indices.numel()) > 0: + block_table_i64 = block_table_i64[: int(nonzero_block_indices[-1].item()) + 1] + if int(logical_blocks.max().item()) >= int(block_table_i64.shape[0]): + # Some vLLM cached/chunked rows carry only the active suffix/decode + # block table, while position_ids remain absolute in the request. + min_logical_block = int(logical_blocks.min().item()) + rebased_blocks = logical_blocks - min_logical_block + if int(rebased_blocks.max().item()) >= int(block_table_i64.shape[0]): + return None + logical_blocks = rebased_blocks + offsets = positions.remainder(block_size) + physical_blocks = torch.index_select( + block_table_i64, + 0, + logical_blocks, + ) + return (physical_blocks * block_size + offsets).to(dtype=dtype) + + +def _repair_vectorized_slot_mapping( + neuron_base_instance: "NeuronBaseForCausalLM", + combined: Dict[str, Any], +) -> None: + input_ids = combined.get("input_ids") + position_ids = combined.get("position_ids") + block_table = combined.get("block_table") + if ( + not isinstance(input_ids, torch.Tensor) + or input_ids.ndim != 2 + or not isinstance(position_ids, torch.Tensor) + or position_ids.ndim != 2 + or not isinstance(block_table, torch.Tensor) + or block_table.ndim != 2 + ): + return + + batch_size, target_len = int(input_ids.shape[0]), int(input_ids.shape[1]) + if batch_size <= 0 or target_len <= 0: + return + block_size = _pa_block_size(neuron_base_instance) + if not block_size: + return + + query_lengths = _vectorized_query_lengths(combined, batch_size=batch_size) + if query_lengths is None: + return + + slot_mapping = combined.get("slot_mapping") + slot_dtype = ( + slot_mapping.dtype + if isinstance(slot_mapping, torch.Tensor) + else torch.int32 + ) + slot_device = input_ids.device + scalar_slots = None + slot_rows = None + if isinstance(slot_mapping, torch.Tensor) and slot_mapping.numel() > 0: + slot_device = slot_mapping.device + if slot_mapping.ndim == 1 and int(slot_mapping.numel()) == batch_size: + scalar_slots = slot_mapping.to(dtype=slot_dtype) + elif slot_mapping.ndim >= 2 and slot_mapping.shape[0] >= batch_size: + slot_rows = _resize_dim1( + slot_mapping[:batch_size].to(dtype=slot_dtype), + target_len, + -1, + ) + + repaired = torch.full( + (batch_size, target_len), + -1, + dtype=slot_dtype, + device=slot_device, + ) + changed = slot_rows is None or tuple(slot_rows.shape[:2]) != (batch_size, target_len) + for row_idx, query_len in enumerate(query_lengths[:batch_size]): + q_len = max(0, min(int(query_len), target_len)) + if q_len == 0: + continue + if slot_rows is not None and bool((slot_rows[row_idx, :q_len] >= 0).all().item()): + repaired[row_idx, :q_len] = slot_rows[row_idx, :q_len] + if q_len < target_len: + repaired[row_idx, q_len:] = -1 + continue + if scalar_slots is not None and q_len == 1 and int(scalar_slots.numel()) > row_idx: + repaired[row_idx, 0] = scalar_slots[row_idx] + changed = True + continue + synthesized = _synthesize_slots_from_block_table( + block_table_row=block_table[row_idx], + position_row=position_ids[row_idx], + q_len=q_len, + block_size=block_size, + dtype=slot_dtype, + ) + if synthesized is None: + if slot_rows is not None: + repaired[row_idx] = slot_rows[row_idx] + elif scalar_slots is not None and int(scalar_slots.numel()) > row_idx: + repaired[row_idx, 0] = scalar_slots[row_idx] + continue + repaired[row_idx, :q_len] = synthesized + changed = True + + if changed or slot_rows is None or bool((repaired[:, :target_len] != slot_rows).any().item()): + combined["slot_mapping"] = repaired + + +def _repair_vectorized_batch_vectors(combined: Dict[str, Any]) -> None: + input_ids = combined.get("input_ids") + if not isinstance(input_ids, torch.Tensor) or input_ids.ndim < 2: + return + batch_size = int(input_ids.shape[0]) + if batch_size <= 1: + return + + seq_ids = combined.get("seq_ids") + if not isinstance(seq_ids, torch.Tensor) or seq_ids.reshape(-1).shape[0] != batch_size: + dtype = seq_ids.dtype if isinstance(seq_ids, torch.Tensor) else torch.int32 + device = seq_ids.device if isinstance(seq_ids, torch.Tensor) else input_ids.device + combined["seq_ids"] = torch.arange(batch_size, dtype=dtype, device=device) + + adapter_ids = combined.get("adapter_ids") + if ( + not isinstance(adapter_ids, torch.Tensor) + or adapter_ids.numel() == 0 + or adapter_ids.reshape(-1).shape[0] != batch_size + ): + dtype = adapter_ids.dtype if isinstance(adapter_ids, torch.Tensor) else torch.int32 + device = adapter_ids.device if isinstance(adapter_ids, torch.Tensor) else input_ids.device + fill_value = ( + int(adapter_ids.reshape(-1)[0].item()) + if isinstance(adapter_ids, torch.Tensor) and adapter_ids.numel() > 0 + else 0 + ) + combined["adapter_ids"] = torch.full( + (batch_size,), + fill_value, + dtype=dtype, + device=device, + ) + + +def _repair_vectorized_attention_mask_for_block_table( + neuron_base_instance: "NeuronBaseForCausalLM", + combined: Dict[str, Any], +) -> None: + input_ids = combined.get("input_ids") + attention_mask = combined.get("attention_mask") + block_table = combined.get("block_table") + if ( + not isinstance(input_ids, torch.Tensor) + or input_ids.ndim != 2 + or not isinstance(attention_mask, torch.Tensor) + or attention_mask.ndim != 2 + or not isinstance(block_table, torch.Tensor) + or block_table.ndim != 2 + ): + return + + batch_size = int(input_ids.shape[0]) + if ( + batch_size <= 0 + or block_table.shape[0] != batch_size + or block_table.shape[1] <= 1 + ): + return + restore_mask = _batch_int_list( + combined, + "hybrid_restore_mask", + batch_size=batch_size, + ) + computed_context_lens = _batch_int_list( + combined, + "computed_context_lens", + batch_size=batch_size, + ) + if not ( + (restore_mask is not None and any(value > 0 for value in restore_mask)) + or ( + computed_context_lens is not None + and any(value > 0 for value in computed_context_lens) + ) + ): + return + + block_size = _pa_block_size(neuron_base_instance) + if not block_size: + return + target_len = int(block_table.shape[1]) * int(block_size) + max_context_len = _configured_max_context_len(neuron_base_instance) + if max_context_len is not None: + target_len = max(target_len, max_context_len) + if int(attention_mask.shape[1]) == target_len: + return + if int(attention_mask.shape[1]) > target_len: + target_len = int(attention_mask.shape[1]) + + context_lens = _batch_int_list(combined, "full_context_lens", batch_size=batch_size) + if context_lens is None: + num_queries = _batch_int_list(combined, "num_queries", batch_size=batch_size) + if computed_context_lens is not None and num_queries is not None: + context_lens = [ + computed_len + query_len + for computed_len, query_len in zip(computed_context_lens, num_queries) + ] + if context_lens is None: + context_lens = [ + int(row.to(torch.int64).sum().item()) + for row in attention_mask[:batch_size] + ] + + repaired = torch.zeros( + (batch_size, target_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + for row_idx, context_len in enumerate(context_lens[:batch_size]): + active_len = max(0, min(int(context_len), target_len)) + if active_len: + repaired[row_idx, :active_len] = 1 + combined["attention_mask"] = repaired + + +_VECTOR_CTE_SEQUENCE_KEYS = { + "input_ids", + "attention_mask", + "position_ids", + "slot_mapping", + "inputs_embeds", +} + + +def _with_zero_hybrid_apc_slots(input_dict: Dict[str, Any]) -> Dict[str, Any]: + output = dict(input_dict) + seq_ids = input_dict.get("seq_ids") + if isinstance(seq_ids, torch.Tensor) and seq_ids.ndim >= 1: + batch_size = int(seq_ids.reshape(-1).shape[0]) + device = seq_ids.device + else: + input_ids = input_dict.get("input_ids") + batch_size = ( + int(input_ids.shape[0]) + if isinstance(input_ids, torch.Tensor) and input_ids.ndim >= 1 + else 1 + ) + device = input_ids.device if isinstance(input_ids, torch.Tensor) else None + kwargs = {"dtype": torch.int32} + if device is not None: + kwargs["device"] = device + zeros = torch.zeros((batch_size,), **kwargs) + output.setdefault("hybrid_restore_slot_ids", zeros) + output.setdefault("hybrid_restore_mask", torch.zeros_like(zeros)) + output.setdefault("hybrid_restore_prefix_lens", torch.zeros_like(zeros)) + output.setdefault("hybrid_commit_slot_ids", torch.zeros_like(zeros)) + output.setdefault("hybrid_commit_mask", torch.zeros_like(zeros)) + if "num_queries" not in output: + query_lengths = _vectorized_query_lengths(output, batch_size=batch_size) + if query_lengths is None: + input_ids = output.get("input_ids") + active_len = ( + int(input_ids.shape[1]) + if isinstance(input_ids, torch.Tensor) and input_ids.ndim >= 2 + else 0 + ) + query_lengths = [active_len] * batch_size + query_kwargs = {"dtype": torch.int32} + if device is not None: + query_kwargs["device"] = device + output["num_queries"] = torch.tensor( + [[max(0, int(query_len))] for query_len in query_lengths[:batch_size]], + **query_kwargs, + ) + input_ids = output.get("input_ids") + if isinstance(input_ids, torch.Tensor) and input_ids.ndim >= 2: + query_lengths = _vectorized_query_lengths(output, batch_size=batch_size) + if query_lengths is None: + query_lengths = [int(input_ids.shape[1])] * batch_size + attention_mask = torch.zeros( + input_ids.shape[:2], + dtype=torch.int32, + device=input_ids.device, + ) + for row_idx, query_len in enumerate(query_lengths[:batch_size]): + active_len = max(0, min(int(query_len), input_ids.shape[1])) + if active_len: + attention_mask[row_idx, :active_len] = 1 + output["attention_mask"] = attention_mask + return output + + +_UNBACKED_SUFFIX_ONLY_HYBRID_APC_ERROR = ( + "suffix-only hybrid APC received an attention prefix hit " + "without scheduler-authorized GDN checkpoint metadata" +) + + +def _is_unbacked_suffix_only_hybrid_apc_error(exc: Exception) -> bool: + return isinstance(exc, ValueError) and _UNBACKED_SUFFIX_ONLY_HYBRID_APC_ERROR in str( + exc + ) + + +def _active_chunk_suffix_len( + *, + suffix_len: int, + active_suffix_len: int | None, +) -> int | None: + if active_suffix_len is None: + return int(suffix_len) + active_len = int(active_suffix_len) + if active_len <= 0 or active_len > int(suffix_len): + return None + return active_len + + +def _is_seq_id_fallback_request_id(request_id: Any) -> bool: + return ( + isinstance(request_id, tuple) + and len(request_id) == 2 + and request_id[0] == "seq_id" + ) + + +def _is_same_request_chunked_prefill_continuation( + input_dict: Dict[str, Any], + *, + request_id: Any, + request_prefix_len: int, + hit_len: int, + suffix_len: int, + active_suffix_len: int | None, +) -> bool: + if suffix_len <= 0 or hit_len <= 0: + return False + active_len = _active_chunk_suffix_len( + suffix_len=suffix_len, + active_suffix_len=active_suffix_len, + ) + if active_len is None: + return False + if int(request_prefix_len) - int(hit_len) != int(active_len): + return True + if ( + input_dict.get("hybrid_prefill_completion_state") is not None + and not _truthy_single_value(input_dict.get("hybrid_prefill_completion_state")) + and _request_id_in_collection( + request_id, + input_dict.get("hybrid_cached_request_ids"), + ) + ): + return True + if _is_seq_id_fallback_request_id(request_id): + return True + return False + + +def _with_inert_hybrid_apc_chunk_continuation( + input_dict: Dict[str, Any], + *, + hit_len: int, + active_prefix_len: int, + suffix_len: int, +) -> Dict[str, Any]: + input_ids = input_dict.get("input_ids") + device = input_ids.device if isinstance(input_ids, torch.Tensor) else None + kwargs = {"dtype": torch.int32} + if device is not None: + kwargs["device"] = device + + output = dict(input_dict) + output["computed_context_lens"] = torch.tensor([[max(0, int(hit_len))]], **kwargs) + output["full_context_lens"] = torch.tensor( + [[max(0, int(active_prefix_len))]], + **kwargs, + ) + output["num_queries"] = torch.tensor([[max(0, int(suffix_len))]], **kwargs) + return _with_zero_hybrid_apc_slots(output) + + +def _with_same_request_gdn_active_carry( + input_dict: Dict[str, Any], +) -> Dict[str, Any]: + """Keep attention prefix reads but carry same-request GDN state directly.""" + + output = dict(input_dict) + _zero_mask_if_present(output, "hybrid_restore_mask") + return output + + +def _replace_prepared_input_dict(prepared, input_dict: Dict[str, Any]): + if hasattr(prepared, "_replace"): + return prepared._replace(input_dict=input_dict) + prepared.input_dict = input_dict + return prepared + + +def _is_completed_cached_decode_row( + input_dict: Dict[str, Any], + *, + request_id: Any, + query_len: int | None, +) -> bool: + if not _request_id_in_collection( + request_id, + input_dict.get("hybrid_cached_request_ids"), + ): + return False + if not _truthy_single_value(input_dict.get("hybrid_prefill_completion_state")): + return False + return query_len is None or int(query_len) <= 1 + + +def _combine_vectorized_hybrid_apc_inputs( + neuron_base_instance: "NeuronBaseForCausalLM", + original_input_dict: Dict[str, Any], + row_input_dicts: list[Dict[str, Any]], +) -> Dict[str, Any]: + combined = dict(original_input_dict) + keys: set[str] = set() + for row_input in row_input_dicts: + keys.update(row_input.keys()) + + max_sequence_dim1 = None + for key in _VECTOR_CTE_SEQUENCE_KEYS: + tensors = [ + row_input.get(key) + for row_input in row_input_dicts + if isinstance(row_input.get(key), torch.Tensor) + ] + if not tensors or not any(tensor.ndim >= 1 for tensor in tensors): + continue + key_max = max(tensor.shape[1] if tensor.ndim >= 2 else 1 for tensor in tensors) + max_sequence_dim1 = ( + key_max + if max_sequence_dim1 is None + else max(max_sequence_dim1, key_max) + ) + target_sequence_dim1 = ( + _configured_cte_bucket_len(neuron_base_instance, max_sequence_dim1) + if max_sequence_dim1 is not None + else None + ) + target_block_dim1 = _active_block_table_target_len( + neuron_base_instance, + target_sequence_dim1, + ) + restore_block_dim1 = _restore_block_table_target_len( + neuron_base_instance, + row_input_dicts, + ) + full_block_dim1 = ( + _full_block_table_target_len(neuron_base_instance, row_input_dicts) + if _uses_block_backed_restore(row_input_dicts) + else None + ) + if full_block_dim1 is not None: + target_block_dim1 = max(target_block_dim1 or 0, full_block_dim1) + + for key in keys: + if key.startswith("_hybrid_apc"): + continue + values = [row_input.get(key) for row_input in row_input_dicts] + if not all(isinstance(value, torch.Tensor) for value in values): + continue + + tensors = [value for value in values if isinstance(value, torch.Tensor)] + if all(tensor.numel() == 0 for tensor in tensors): + combined[key] = tensors[0] + continue + if all(tensor.ndim == 0 for tensor in tensors): + combined[key] = torch.stack(tensors) + continue + if ( + key in {"rotary_position_id", "rotary_position_ids"} + and all(tensor.ndim == 3 and tensor.shape[1] == 1 for tensor in tensors) + ): + target_dim = max(tensor.shape[-1] for tensor in tensors) + combined[key] = torch.cat( + [ + _right_pad_last_dim( + tensor, + target_dim, + _pad_value_for_key(neuron_base_instance, key), + ) + for tensor in tensors + ], + dim=1, + ) + continue + if all(tensor.ndim >= 1 and tensor.shape[0] == 1 for tensor in tensors): + max_dim1 = None + if any(tensor.ndim >= 2 for tensor in tensors): + max_dim1 = max(tensor.shape[1] if tensor.ndim >= 2 else 1 for tensor in tensors) + target_dim1 = max_dim1 + if key in _VECTOR_CTE_SEQUENCE_KEYS and target_sequence_dim1 is not None: + target_dim1 = target_sequence_dim1 + elif key == "block_table" and target_block_dim1 is not None: + target_dim1 = target_block_dim1 + if restore_block_dim1 is not None: + target_dim1 = max(target_dim1, restore_block_dim1) + padded = [] + for tensor in tensors: + current = ( + tensor.reshape(1, -1) + if target_dim1 is not None and tensor.ndim == 1 + else tensor + ) + if target_dim1 is not None and current.ndim >= 2: + resize = _resize_dim1 if key == "block_table" else _right_pad_dim1 + current = resize( + current, + target_dim1, + _pad_value_for_key(neuron_base_instance, key), + ) + padded.append(current) + try: + combined[key] = torch.cat(padded, dim=0) + except RuntimeError as exc: + raise ValueError( + f"cannot combine vectorized hybrid APC tensor {key!r}: " + f"{[tuple(tensor.shape) for tensor in padded]}" + ) from exc + continue + if all(tuple(tensor.shape) == tuple(tensors[0].shape) for tensor in tensors): + combined[key] = tensors[0] + + _repair_vectorized_batch_vectors(combined) + _repair_vectorized_attention_mask_for_block_table(neuron_base_instance, combined) + _repair_vectorized_slot_mapping(neuron_base_instance, combined) + return combined + + +def _prepare_vectorized_hybrid_apc_requests( + neuron_base_instance: "NeuronBaseForCausalLM", + input_dict: Dict[str, Any], + *, + bridge: Any, + batch_size: int, +) -> Dict[str, Any]: + row_outputs: list[Dict[str, Any]] = [] + prepared_requests = [] + query_lengths = _vectorized_query_lengths(input_dict, batch_size=batch_size) + request_records = _hybrid_apc_request_records( + input_dict, + batch_size=batch_size, + ) + try: + for index in range(batch_size): + row_input = { + key: _select_batch_item( + value, + index, + batch_size, + key=key, + query_lengths=query_lengths, + ) + for key, value in input_dict.items() + if not key.startswith("_hybrid_apc") + } + if request_records is not None: + _apply_hybrid_apc_request_record(row_input, request_records[index]) + row_input["hybrid_apc_bridge"] = bridge + row_output = prepare_hybrid_apc_request_for_execution( + neuron_base_instance, + row_input, + ) + row_outputs.append(row_output) + prepared = row_input.get("_hybrid_apc_prepared") + if prepared is not None: + prepared_requests.append(prepared) + except Exception: + for prepared in prepared_requests: + bridge.cancel_request(prepared) + raise + + if prepared_requests: + input_dict["_hybrid_apc_bridge"] = bridge + input_dict["_hybrid_apc_prepared"] = prepared_requests + + combined = _combine_vectorized_hybrid_apc_inputs( + neuron_base_instance, + input_dict, + row_outputs, + ) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] prepare-vectorized " + f"batch_size={batch_size} prepared={len(prepared_requests)} " + f"input_shape={tuple(input_dict['input_ids'].shape)} " + f"prepared_shape={tuple(combined['input_ids'].shape)} " + f"computed={combined.get('computed_context_lens')} " + f"num_queries={combined.get('num_queries')} " + f"restore_mask={combined.get('hybrid_restore_mask')} " + f"commit_mask={combined.get('hybrid_commit_mask')}", + flush=True, + ) + return combined + + +def _env_flag(name: str) -> bool: + return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"} + + +def _hybrid_gdn_restore_disabled() -> bool: + return _env_flag("QWEN36_DISABLE_HYBRID_GDN_RESTORE") or _env_flag( + "QWEN36_DISABLE_HYBRID_GDN_RESTORE_COMMIT" + ) + + +def _hybrid_gdn_restore_mask_zeroed() -> bool: + return _env_flag("QWEN36_ZERO_HYBRID_GDN_RESTORE_MASK") + + +def _hybrid_gdn_commit_disabled() -> bool: + return _env_flag("QWEN36_DISABLE_HYBRID_GDN_COMMIT") or _env_flag( + "QWEN36_DISABLE_HYBRID_GDN_RESTORE_COMMIT" + ) + + +def _zero_mask_if_present(input_dict: Dict[str, Any], key: str): + value = input_dict.get(key) + if isinstance(value, torch.Tensor): + input_dict[key] = torch.zeros_like(value) + + +def _apply_hybrid_gdn_debug_switches(input_dict: Dict[str, Any]) -> Dict[str, Any]: + if _hybrid_gdn_restore_disabled() or _hybrid_gdn_restore_mask_zeroed(): + _zero_mask_if_present(input_dict, "hybrid_restore_mask") + if _hybrid_gdn_commit_disabled(): + _zero_mask_if_present(input_dict, "hybrid_commit_mask") + return input_dict + + +def _get_hybrid_apc_bridge( + neuron_base_instance: "NeuronBaseForCausalLM", + input_dict: Dict[str, Any], +): + bridge = _first_present( + input_dict.get("hybrid_apc_bridge"), + getattr(neuron_base_instance, "hybrid_apc_bridge", None), + getattr(neuron_base_instance, "_hybrid_apc_last_bridge", None), + ) + if bridge is None: + ensure_bridge = getattr( + neuron_base_instance, + "ensure_hybrid_apc_scheduler_bridge", + None, + ) + if ensure_bridge is not None: + bridge = ensure_bridge() + return bridge + + +def _select_hybrid_apc_owner( + neuron_base_instance: "NeuronBaseForCausalLM", + model_to_execute: "ModelWrapper", + input_dict: Dict[str, Any], +): + candidates = ( + neuron_base_instance, + model_to_execute, + getattr(neuron_base_instance, "context_encoding_model", None), + getattr(neuron_base_instance, "token_generation_model", None), + ) + fallback = neuron_base_instance + seen: set[int] = set() + for candidate in candidates: + if candidate is None: + continue + candidate_id = id(candidate) + if candidate_id in seen: + continue + seen.add(candidate_id) + if fallback is neuron_base_instance and _is_hybrid_apc_enabled(candidate): + fallback = candidate + if not _is_hybrid_apc_enabled(candidate): + continue + if _get_hybrid_apc_bridge(candidate, input_dict) is not None: + return candidate + return fallback + + +def _requires_external_hybrid_apc_metadata( + neuron_base_instance: "NeuronBaseForCausalLM", + bridge: Any, +) -> bool: + for owner in ( + getattr(neuron_base_instance, "config", None), + getattr(neuron_base_instance, "neuron_config", None), + getattr(getattr(neuron_base_instance, "config", None), "neuron_config", None), + ): + if bool(getattr(owner, "hybrid_apc_require_vllm_metadata", False)): + return True + return bool( + getattr(bridge, "requires_external_metadata", False) + ) + + +def prepare_hybrid_apc_request_for_execution( + neuron_base_instance: "NeuronBaseForCausalLM", + input_dict: Dict[str, Any], +) -> Dict[str, Any]: + """Run scheduler-side hybrid APC request preparation when metadata exists. + + The concrete bridge lives with the Qwen contrib model. This function is + intentionally duck-typed so the core async path does not import contrib + modules. vLLM/NxDI request prep can opt in by attaching a bridge object plus + attention-hit metadata to ``input_dict``. + """ + + if not _is_hybrid_apc_enabled(neuron_base_instance): + return input_dict + + bridge = _get_hybrid_apc_bridge(neuron_base_instance, input_dict) + requires_external_metadata = _requires_external_hybrid_apc_metadata( + neuron_base_instance, + bridge, + ) + if bridge is None: + if requires_external_metadata: + raise ValueError( + "hybrid APC requires a scheduler bridge attached to the model " + "or input_dict" + ) + return input_dict + + lifecycle_input_dict = input_dict + batch_size = _batch_size_from_input_dict(input_dict) + request_records = _hybrid_apc_request_records( + input_dict, + batch_size=batch_size, + ) + if batch_size == 1 and request_records is not None: + input_dict = dict(input_dict) + _apply_hybrid_apc_request_record(input_dict, request_records[0]) + + request_id = _first_present( + input_dict.get("hybrid_request_id"), + input_dict.get("request_id"), + _hybrid_apc_record_values(request_records, "request_id"), + ) + if request_id is None: + seq_id = _single_batch_value(input_dict.get("seq_ids")) + if seq_id is not None: + request_id = ("seq_id", _to_python_int(seq_id)) + + attention_hit_len_source = _first_present( + _hybrid_apc_record_values(request_records, "vllm_attention_hit_len"), + input_dict.get("vllm_attention_hit_len"), + input_dict.get("hybrid_attention_hit_len"), + input_dict.get("attention_hit_len"), + ) + if attention_hit_len_source is None: + attention_hit_len_source = input_dict.get("computed_context_lens") + attention_hit_len = _single_batch_value(attention_hit_len_source) + multi_attention_hit_lens = _multi_batch_int_values(attention_hit_len_source) + if attention_hit_len is None and multi_attention_hit_lens is None: + multi_attention_hit_lens = _multi_batch_int_values( + input_dict.get("computed_context_lens") + ) + if ( + attention_hit_len is None + and multi_attention_hit_lens is not None + and batch_size > 1 + ): + if ( + all(hit_len == 0 for hit_len in multi_attention_hit_lens) + and not requires_external_metadata + ): + return input_dict + return _prepare_vectorized_hybrid_apc_requests( + neuron_base_instance, + input_dict, + bridge=bridge, + batch_size=batch_size, + ) + if attention_hit_len is None and multi_attention_hit_lens is not None: + if all(hit_len == 0 for hit_len in multi_attention_hit_lens): + if requires_external_metadata: + raise ValueError( + "hybrid APC v0 request prep supports one request at a time; " + "vectorized continuous-batching metadata is not wired yet" + ) + return input_dict + raise ValueError( + "hybrid APC v0 request prep supports one request at a time; " + "vectorized continuous-batching metadata is not wired yet" + ) + + if request_id is None and attention_hit_len is None and not requires_external_metadata: + return input_dict + if request_id is None: + raise ValueError("hybrid APC request prep requires request_id") + if attention_hit_len is None: + raise ValueError("hybrid APC request prep requires attention hit length") + + request_prefix_len = _first_present( + input_dict.get("request_prefix_len"), + input_dict.get("hybrid_request_prefix_len"), + input_dict.get("prompt_len"), + _single_batch_value(input_dict.get("full_context_lens")), + ) + if request_prefix_len is not None: + request_prefix_len = _to_python_int(request_prefix_len) + + active_suffix_len = _first_present( + input_dict.get("hybrid_active_suffix_len"), + input_dict.get("active_suffix_len"), + _hybrid_apc_record_values(request_records, "active_suffix_len"), + ) + active_suffix_len = _single_batch_value(active_suffix_len) + if active_suffix_len is not None: + active_suffix_len = max(0, _to_python_int(active_suffix_len)) + + query_len = _single_batch_value(input_dict.get("num_queries")) + if query_len is None and active_suffix_len is not None: + query_len = active_suffix_len + elif ( + query_len is None + and request_prefix_len is not None + and attention_hit_len is not None + ): + query_len = max(0, request_prefix_len - _to_python_int(attention_hit_len)) + elif query_len is not None: + query_len = _to_python_int(query_len) + + if _is_completed_cached_decode_row( + input_dict, + request_id=request_id, + query_len=query_len, + ): + return _with_zero_hybrid_apc_slots(input_dict) + + cumulative_hashes_by_prefix_len = _first_present( + input_dict.get("vllm_or_local_prefix_hashes"), + input_dict.get("cumulative_hashes_by_prefix_len"), + input_dict.get("hybrid_cumulative_hashes_by_prefix_len"), + ) + attention_block_refs_by_prefix_len = _first_present( + input_dict.get("attention_block_refs"), + input_dict.get("attention_block_refs_by_prefix_len"), + input_dict.get("hybrid_attention_block_refs_by_prefix_len"), + ) + if ( + requires_external_metadata + and not cumulative_hashes_by_prefix_len + and _to_python_int(attention_hit_len) <= 0 + ): + return _with_zero_hybrid_apc_slots(input_dict) + full_input_ids = _first_present( + input_dict.get("hybrid_full_input_ids"), + input_dict.get("full_input_ids"), + input_dict.get("prompt_input_ids"), + ) + bridge_input_dict = input_dict + prepared = None + if full_input_ids is not None: + if not _single_batch_tensor(full_input_ids): + if requires_external_metadata: + raise ValueError( + "hybrid APC v0 request prep supports one request at a time; " + "vectorized continuous-batching metadata is not wired yet" + ) + return input_dict + bridge_input_dict = dict(input_dict) + bridge_input_dict["input_ids"] = full_input_ids + for source_key, target_key in ( + ("hybrid_full_attention_mask", "attention_mask"), + ("full_attention_mask", "attention_mask"), + ("hybrid_full_position_ids", "position_ids"), + ("full_position_ids", "position_ids"), + ("hybrid_full_slot_mapping", "slot_mapping"), + ("full_slot_mapping", "slot_mapping"), + ): + value = input_dict.get(source_key) + if value is not None: + bridge_input_dict[target_key] = value + elif request_prefix_len is not None: + input_ids = input_dict.get("input_ids") + if ( + isinstance(input_ids, torch.Tensor) + and input_ids.ndim >= 2 + and input_ids.shape[1] < request_prefix_len + ): + if _to_python_int(attention_hit_len) <= 0: + active_prefix_len = min( + request_prefix_len, + int(input_ids.shape[1]), + ) + if ( + cumulative_hashes_by_prefix_len + and active_prefix_len in cumulative_hashes_by_prefix_len + ): + prepared = bridge.prepare_request( + request_id=request_id, + input_dict=input_dict, + attention_hit_len=0, + request_prefix_len=active_prefix_len, + cumulative_hashes_by_prefix_len=cumulative_hashes_by_prefix_len, + attention_block_refs_by_prefix_len=attention_block_refs_by_prefix_len, + ) + else: + return _with_zero_hybrid_apc_slots(input_dict) + # The live prefix-caching request has already been sliced to the + # attention suffix. Without full prompt tokens the bridge cannot + # compute or apply an exact GDN checkpoint boundary. + else: + prepare_suffix_only = getattr( + bridge, + "prepare_suffix_only_request", + None, + ) + if prepare_suffix_only is not None: + suffix_len = int(input_ids.shape[1]) + active_prefix_len = request_prefix_len + hit_len = _to_python_int(attention_hit_len) + # vLLM chunked prefill may report the final prompt length in + # request_prefix_len while scheduling only the next suffix + # chunk. Hybrid APC restore/commit must use the active chunk + # boundary, otherwise the suffix-only bridge rejects the row. + same_request_chunk_continuation = ( + _is_same_request_chunked_prefill_continuation( + input_dict, + request_id=request_id, + request_prefix_len=request_prefix_len, + hit_len=hit_len, + suffix_len=suffix_len, + active_suffix_len=active_suffix_len, + ) + ) + if same_request_chunk_continuation: + active_chunk_suffix_len = _active_chunk_suffix_len( + suffix_len=suffix_len, + active_suffix_len=active_suffix_len, + ) + if active_chunk_suffix_len is None: + active_chunk_suffix_len = suffix_len + active_prefix_len = min( + request_prefix_len, + hit_len + active_chunk_suffix_len, + ) + if ( + active_chunk_suffix_len != suffix_len + or _is_seq_id_fallback_request_id(request_id) + ): + return _with_inert_hybrid_apc_chunk_continuation( + input_dict, + hit_len=hit_len, + active_prefix_len=active_prefix_len, + suffix_len=active_chunk_suffix_len, + ) + try: + prepared = prepare_suffix_only( + request_id=request_id, + input_dict=input_dict, + attention_hit_len=hit_len, + request_prefix_len=active_prefix_len, + cumulative_hashes_by_prefix_len=cumulative_hashes_by_prefix_len, + attention_block_refs_by_prefix_len=attention_block_refs_by_prefix_len, + ) + if ( + prepared is not None + and same_request_chunk_continuation + ): + prepared = _replace_prepared_input_dict( + prepared, + _with_same_request_gdn_active_carry( + prepared.input_dict + ), + ) + except ValueError as exc: + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] suffix-unbacked-check " + f"request_id={request_id!r} " + f"same_request_chunk_continuation={same_request_chunk_continuation} " + f"request_prefix_len={request_prefix_len} " + f"hit_len={hit_len} suffix_len={suffix_len} " + f"active_suffix_len={active_suffix_len} " + f"query_len={query_len} " + f"error={exc}", + flush=True, + ) + if ( + same_request_chunk_continuation + and _is_unbacked_suffix_only_hybrid_apc_error(exc) + ): + return _with_inert_hybrid_apc_chunk_continuation( + input_dict, + hit_len=hit_len, + active_prefix_len=active_prefix_len, + suffix_len=suffix_len, + ) + raise + if prepared is None: + if requires_external_metadata: + raise ValueError( + "hybrid APC production mode received suffix-only input " + "without hybrid_full_input_ids/full_input_ids; request prep " + "must attach full prompt tokens before suffix slicing" + ) + return input_dict + if not _single_batch_tensor(bridge_input_dict.get("input_ids")): + if requires_external_metadata: + raise ValueError( + "hybrid APC v0 request prep supports one request at a time; " + "vectorized continuous-batching metadata is not wired yet" + ) + return input_dict + + if prepared is None: + prepared = bridge.prepare_request( + request_id=request_id, + input_dict=bridge_input_dict, + attention_hit_len=_to_python_int(attention_hit_len), + request_prefix_len=request_prefix_len, + cumulative_hashes_by_prefix_len=cumulative_hashes_by_prefix_len, + attention_block_refs_by_prefix_len=attention_block_refs_by_prefix_len, + ) + setattr(neuron_base_instance, "_hybrid_apc_last_bridge", bridge) + input_dict["_hybrid_apc_bridge"] = bridge + input_dict["_hybrid_apc_prepared"] = prepared + prepared.input_dict["_hybrid_apc_bridge"] = bridge + prepared.input_dict["_hybrid_apc_prepared"] = prepared + if lifecycle_input_dict is not input_dict: + lifecycle_input_dict["_hybrid_apc_bridge"] = bridge + lifecycle_input_dict["_hybrid_apc_prepared"] = prepared + _apply_hybrid_gdn_debug_switches(prepared.input_dict) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + prepared_inputs = prepared.input_dict + print( + "[hybrid_apc_debug] prepare " + f"request_id={request_id!r} attention_hit_len={_to_python_int(attention_hit_len)} " + f"request_prefix_len={request_prefix_len} restore_len={prepared.plan.restore_checkpoint_prefix_len} " + f"commit_prefix_len={prepared.commit_prefix_len} restore_slot={prepared.plan.checkpoint_slot} " + f"commit_slot={prepared.commit_slot} input_shape={tuple(input_dict['input_ids'].shape)} " + f"prepared_shape={tuple(prepared_inputs['input_ids'].shape)} " + f"computed={prepared_inputs.get('computed_context_lens')} " + f"num_queries={prepared_inputs.get('num_queries')} " + f"restore_mask={prepared_inputs.get('hybrid_restore_mask')} " + f"commit_mask={prepared_inputs.get('hybrid_commit_mask')}", + flush=True, + ) + return prepared.input_dict + + +def finish_hybrid_apc_request(input_dict: Dict[str, Any]): + bridge = input_dict.pop("_hybrid_apc_bridge", None) + prepared = input_dict.pop("_hybrid_apc_prepared", None) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] finish " + f"has_bridge={bridge is not None} has_prepared={prepared is not None}", + flush=True, + ) + if bridge is None or prepared is None: + return + + prepared_requests = prepared if isinstance(prepared, list) else [prepared] + if _hybrid_gdn_commit_disabled(): + for prepared_request in prepared_requests: + bridge.cancel_request(prepared_request) + return + for prepared_request in prepared_requests: + actual_refs = _first_present( + input_dict.get("actual_refs"), + input_dict.get("actual_attention_block_refs"), + input_dict.get("hybrid_actual_attention_block_refs"), + getattr(prepared_request, "attention_block_refs", None), + ) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] finish-commit " + f"request_id={prepared_request.request_id!r} " + f"commit_prefix_len={getattr(prepared_request, 'commit_prefix_len', None)} " + f"commit_slot={getattr(prepared_request, 'commit_slot', None)} " + f"actual_refs={actual_refs}", + flush=True, + ) + try: + bridge.commit_prefill(prepared_request, attention_block_refs=actual_refs) + except Exception: + bridge.cancel_request(prepared_request) + raise + bridge.finish_request(prepared_request.request_id) + + +def cancel_hybrid_apc_request(input_dict: Dict[str, Any]): + bridge = input_dict.pop("_hybrid_apc_bridge", None) + prepared = input_dict.pop("_hybrid_apc_prepared", None) + if bridge is None or prepared is None: + return + prepared_requests = prepared if isinstance(prepared, list) else [prepared] + for prepared_request in prepared_requests: + bridge.cancel_request(prepared_request) + + +def _active_hybrid_apc_slots( + slot_ids: torch.Tensor, + mask: torch.Tensor, + *, + name: str, +) -> list[int]: + slot_ids = slot_ids.reshape(-1).to(torch.int64) + mask = mask.reshape(-1).to(torch.bool) + if slot_ids.shape != mask.shape: + raise ValueError(f"{name} slot ids and mask must have matching shape") + return [int(slot.item()) for slot in slot_ids[mask]] + + +def _validate_hybrid_apc_slot_inputs( + neuron_base_instance: "NeuronBaseForCausalLM", + *, + restore_slot_ids: torch.Tensor, + restore_mask: torch.Tensor, + commit_slot_ids: torch.Tensor, + commit_mask: torch.Tensor, +): + """Validate active checkpoint slots before the traced model clamps them.""" + + active_restore_slots = _active_hybrid_apc_slots( + restore_slot_ids, + restore_mask, + name="hybrid restore", + ) + active_commit_slots = _active_hybrid_apc_slots( + commit_slot_ids, + commit_mask, + name="hybrid commit", + ) + if not active_restore_slots and not active_commit_slots: + return + + max_slots = getattr( + getattr(neuron_base_instance, "config", None), + "max_gdn_checkpoint_slots", + None, + ) + if max_slots is not None: + max_slots = int(max_slots) + for kind, slots in ( + ("restore", active_restore_slots), + ("commit", active_commit_slots), + ): + for slot in slots: + if slot < 0 or slot >= max_slots: + raise ValueError( + f"hybrid APC {kind} slot {slot} is outside [0, {max_slots})" + ) + + allocator = getattr(neuron_base_instance, "hybrid_apc_slot_allocator", None) + if allocator is None: + return + + committed_slots = set(getattr(allocator, "committed_slots", ())) + reserved_slots = set(getattr(allocator, "reserved_slots", ())) + for slot in active_restore_slots: + if slot not in committed_slots: + raise ValueError( + f"hybrid APC restore slot {slot} is not a committed checkpoint slot" + ) + for slot in active_commit_slots: + if slot not in reserved_slots: + raise ValueError( + f"hybrid APC commit slot {slot} is not a reserved checkpoint slot" + ) + + +def prepare_hybrid_apc_model_inputs( + neuron_base_instance: "NeuronBaseForCausalLM", + input_dict: Dict[str, Any], +) -> list[torch.Tensor]: + """Build optional Qwen hybrid APC args for prefix-caching execution. + + The vLLM scheduler owns prefix hashes and checkpoint-slot allocation. This + bridge only translates scheduler-provided values into the fixed traced model + inputs. If no restore/commit slots are supplied, masks stay zero and the + model executes without GDN checkpoint reuse. + """ + + if not _is_hybrid_apc_enabled(neuron_base_instance): + return [] + + batch_size = int(input_dict["seq_ids"].reshape(-1).shape[0]) + empty = torch.empty(0) + + computed_context_lens = input_dict.get("computed_context_lens") + if computed_context_lens is None: + restore_prefix_lens = torch.zeros((batch_size,), dtype=torch.int32) + else: + restore_prefix_lens = computed_context_lens.reshape(-1).to(torch.int32) + if restore_prefix_lens.shape[0] != batch_size: + restore_prefix_lens = _batch_vector( + {"value": restore_prefix_lens}, + "value", + batch_size=batch_size, + ) + + restore_slot_ids = _batch_vector( + input_dict, + "hybrid_restore_slot_ids", + batch_size=batch_size, + default=0, + ) + restore_mask = _batch_vector( + input_dict, + "hybrid_restore_mask", + batch_size=batch_size, + default=0, + ) + + if "hybrid_restore_prefix_lens" in input_dict: + restore_prefix_lens = _batch_vector( + input_dict, + "hybrid_restore_prefix_lens", + batch_size=batch_size, + default=0, + ) + + commit_slot_ids = _batch_vector( + input_dict, + "hybrid_commit_slot_ids", + batch_size=batch_size, + default=0, + ) + if "hybrid_commit_mask" in input_dict: + commit_mask = _batch_vector( + input_dict, + "hybrid_commit_mask", + batch_size=batch_size, + default=0, + ) + else: + commit_mask = torch.zeros((batch_size,), dtype=torch.int32) + + switch_inputs = { + "hybrid_restore_mask": restore_mask, + "hybrid_commit_mask": commit_mask, + } + _apply_hybrid_gdn_debug_switches(switch_inputs) + restore_mask = switch_inputs["hybrid_restore_mask"] + commit_mask = switch_inputs["hybrid_commit_mask"] + + _validate_hybrid_apc_slot_inputs( + neuron_base_instance, + restore_slot_ids=restore_slot_ids, + restore_mask=restore_mask, + commit_slot_ids=commit_slot_ids, + commit_mask=commit_mask, + ) + + llava_args = input_dict.get("llava_args") or [] + rotary_position_id = _first_present( + input_dict.get("rotary_position_id"), + input_dict.get("rotary_position_ids"), + llava_args[2] if len(llava_args) >= 3 else None, + empty, + ) + vision_embeddings = _first_present( + input_dict.get("vision_embeddings"), + llava_args[0] if len(llava_args) >= 1 else None, + empty, + ) + vision_mask = _first_present( + input_dict.get("vision_mask"), + llava_args[1] if len(llava_args) >= 2 else None, + empty, + ) + + return [ + input_dict.get("tile_q_indices", empty), + input_dict.get("tile_block_tables", empty), + input_dict.get("tile_masks", empty), + input_dict.get("inputs_embeds", empty), + input_dict.get("kv_cache", empty), + input_dict.get("active_mask", empty), + rotary_position_id, + vision_embeddings, + vision_mask, + restore_slot_ids, + restore_mask, + restore_prefix_lens, + commit_slot_ids, + commit_mask, + ] + + +def prepare_disabled_hybrid_apc_model_inputs( + neuron_base_instance: "NeuronBaseForCausalLM", + input_dict: Dict[str, Any], +) -> list[torch.Tensor]: + """Build inert Hybrid APC args for decode/TKG execution. + + Hybrid APC restore/commit is a prefill concern. The compiled model still has + the fixed Hybrid APC inputs, so decode must pass the same arity, but it does + not need request planning, slot validation, or mask/vector normalization on + every generated token. + """ + + if not _is_hybrid_apc_enabled(neuron_base_instance): + return [] + + seq_ids = input_dict["seq_ids"].reshape(-1) + batch_size = int(seq_ids.shape[0]) + device = seq_ids.device + empty = torch.empty(0, device=device) + zeros = torch.zeros((batch_size,), dtype=torch.int32, device=device) + + llava_args = input_dict.get("llava_args") or [] + rotary_position_id = _first_present( + input_dict.get("rotary_position_id"), + input_dict.get("rotary_position_ids"), + llava_args[2] if len(llava_args) >= 3 else None, + empty, + ) + vision_embeddings = _first_present( + input_dict.get("vision_embeddings"), + llava_args[0] if len(llava_args) >= 1 else None, + empty, + ) + vision_mask = _first_present( + input_dict.get("vision_mask"), + llava_args[1] if len(llava_args) >= 2 else None, + empty, + ) + + return [ + input_dict.get("tile_q_indices", empty), + input_dict.get("tile_block_tables", empty), + input_dict.get("tile_masks", empty), + input_dict.get("inputs_embeds", empty), + input_dict.get("kv_cache", empty), + input_dict.get("active_mask", empty), + rotary_position_id, + vision_embeddings, + vision_mask, + zeros, + zeros, + zeros, + zeros, + zeros, + ] + + +def _is_context_encoding_execution( + neuron_base_instance: "NeuronBaseForCausalLM", + model_to_execute: "ModelWrapper", + input_dict: Dict[str, Any], +) -> bool: + if getattr(model_to_execute, "tag", None) == "context_encoding_model": + return True + input_ids = input_dict.get("input_ids") + if ( + isinstance(input_ids, torch.Tensor) + and input_ids.ndim >= 2 + and input_ids.shape[-1] > 1 + and not getattr(neuron_base_instance.neuron_config, "enable_fused_speculation", False) + and not getattr(neuron_base_instance.neuron_config, "enable_eagle_speculation", False) + ): + return True + is_prefill = getattr(neuron_base_instance, "_is_prefill", None) + position_ids = input_dict.get("position_ids") + if callable(is_prefill) and position_ids is not None: + return bool(is_prefill(position_ids)) + return False + + +def _is_cached_chunked_prefill_continuation(inputs: Dict[str, Any]) -> bool: + batch_size = _batch_size_from_input_dict(inputs) + if batch_size != 1: + return False + + request_records = _hybrid_apc_request_records(inputs, batch_size=batch_size) + active_suffix_len = _first_present( + inputs.get("hybrid_active_suffix_len"), + inputs.get("active_suffix_len"), + _hybrid_apc_record_values(request_records, "active_suffix_len"), + ) + active_suffix_len = _single_batch_value(active_suffix_len) + if active_suffix_len is None: + return False + active_suffix_len = _to_python_int(active_suffix_len) + if active_suffix_len <= 0: + return False + + attention_hit_len = _first_present( + _hybrid_apc_record_values(request_records, "vllm_attention_hit_len"), + inputs.get("vllm_attention_hit_len"), + inputs.get("hybrid_attention_hit_len"), + inputs.get("attention_hit_len"), + inputs.get("computed_context_lens"), + ) + attention_hit_len = _single_batch_value(attention_hit_len) + if attention_hit_len is None: + return False + attention_hit_len = _to_python_int(attention_hit_len) + if attention_hit_len <= 0: + return False + + request_prefix_len = _first_present( + _hybrid_apc_record_values(request_records, "request_prefix_len"), + inputs.get("request_prefix_len"), + inputs.get("hybrid_request_prefix_len"), + inputs.get("prompt_len"), + _single_batch_value(inputs.get("full_context_lens")), + ) + request_prefix_len = _single_batch_value(request_prefix_len) + if request_prefix_len is None: + return False + + return _to_python_int(request_prefix_len) >= attention_hit_len + active_suffix_len + + +def _is_chunked_prefill_execution( + neuron_base_instance: "NeuronBaseForCausalLM", + inputs: Dict[str, Any], + *, + is_fused_speculation: bool, +) -> bool: + if is_fused_speculation: + return False + if getattr(neuron_base_instance.neuron_config, "enable_eagle_speculation", False): + return False + input_ids = inputs.get("input_ids") + if ( + isinstance(input_ids, torch.Tensor) + and input_ids.ndim >= 2 + and input_ids.shape[-1] > 1 + ): + return True + return _is_cached_chunked_prefill_continuation(inputs) + + +def _debug_hybrid_apc_owner_metadata_summary(owner: Any) -> str: + if owner is None: + return "None" + parts = [type(owner).__name__] + for attr in ( + "_qwen36_vllm_request_ids", + "_qwen36_vllm_cached_request_ids", + "_qwen36_vllm_prefill_completion_state", + "_qwen36_vllm_hybrid_apc_request_records", + "_qwen36_vllm_hybrid_apc_metadata_by_request_id", + ): + value = getattr(owner, attr, None) + if value is None: + continue + if isinstance(value, dict): + parts.append(f"{attr}=dict[{len(value)}]") + elif isinstance(value, (list, tuple)): + parts.append(f"{attr}=seq[{len(value)}]") + elif isinstance(value, torch.Tensor): + parts.append(f"{attr}=tensor{tuple(value.shape)}") + else: + parts.append(f"{attr}={type(value).__name__}") + return " ".join(parts) + + +def _format_token_id(value: int) -> str: + if value < 0: + return str(value) + return f"{value} (0x{value & 0xFFFFFFFF:08x})" + + +def _model_vocab_size(neuron_base_instance: "NeuronBaseForCausalLM") -> int | None: + for owner in ( + neuron_base_instance, + getattr(neuron_base_instance, "config", None), + getattr(neuron_base_instance, "model", None), + getattr(getattr(neuron_base_instance, "model", None), "config", None), + ): + vocab_size = getattr(owner, "vocab_size", None) + if vocab_size is not None: + try: + return int(vocab_size) + except (TypeError, ValueError): + return None + return None + + +def _summarize_tensor_minmax(value: Any) -> str: + if not isinstance(value, torch.Tensor) or value.numel() == 0: + return "empty" + try: + flat = value.detach().reshape(-1) + return f"{int(flat.min().item())}:{int(flat.max().item())}" + except Exception as exc: + return f"unavailable:{type(exc).__name__}" + + +def _validate_token_generation_input_ids( + neuron_base_instance: "NeuronBaseForCausalLM", + model_to_execute: "ModelWrapper", + input_dict: Dict[str, Any], +) -> None: + if getattr(model_to_execute, "tag", None) != "token_generation_model": + return + input_ids = input_dict.get("input_ids") + if not isinstance(input_ids, torch.Tensor): + return + if input_ids.numel() == 0: + raise ValueError("Token generation input_ids must be non-empty") + if input_ids.dtype not in (torch.int32, torch.int64): + raise ValueError( + "Token generation input_ids must be int32 or int64 before Neuron " + f"execution, got {input_ids.dtype}" + ) + + min_id = int(input_ids.min().item()) + max_id = int(input_ids.max().item()) + vocab_size = _model_vocab_size(neuron_base_instance) + invalid_id = None + reason = None + if min_id < 0: + invalid_id = min_id + reason = "negative" + elif vocab_size is not None and max_id >= vocab_size: + invalid_id = max_id + reason = f"out-of-vocab for vocab_size={vocab_size}" + if invalid_id is None: + return + + request_ids = getattr(neuron_base_instance, "_qwen36_vllm_request_ids", None) + if request_ids is None: + request_ids = input_dict.get("request_ids", input_dict.get("request_id")) + raise ValueError( + "Token generation input_ids contract violated before Neuron execution: " + f"{reason}; token_id={_format_token_id(invalid_id)}; " + f"request_ids={request_ids}; " + f"input_shape={tuple(input_ids.shape)} dtype={input_ids.dtype}; " + f"position_minmax={_summarize_tensor_minmax(input_dict.get('position_ids'))}; " + f"slot_minmax={_summarize_tensor_minmax(input_dict.get('slot_mapping'))}; " + f"block_minmax={_summarize_tensor_minmax(input_dict.get('block_table'))}; " + f"num_queries={_summarize_tensor_minmax(input_dict.get('num_queries'))}; " + "computed_context_lens=" + f"{_summarize_tensor_minmax(input_dict.get('computed_context_lens'))}" + ) + + +def _with_disabled_hybrid_apc_controls(input_dict: Dict[str, Any]) -> Dict[str, Any]: + output = dict(input_dict) + _zero_mask_if_present(output, "hybrid_restore_mask") + _zero_mask_if_present(output, "hybrid_commit_mask") + return output + + class AsyncTensorWrapper: """ Wrapper class for tensors from models executed with async runtime. @@ -76,56 +2343,135 @@ def execute_model_prefix_caching( input_dict: Dict[str, Any], pad_type: str = "first_fit", ) -> Tuple[AsyncTensorWrapper, bool]: - if "num_queries" not in input_dict: - full_context_lens = input_dict["full_context_lens"] - computed_context_lens = input_dict["computed_context_lens"] - num_queries = full_context_lens - computed_context_lens - input_dict["num_queries"] = num_queries + original_input_dict = input_dict + hybrid_apc_owner = _select_hybrid_apc_owner( + neuron_base_instance, + model_to_execute, + input_dict, + ) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] async-owner " + f"base_enabled={_is_hybrid_apc_enabled(neuron_base_instance)} " + f"wrapper_enabled={_is_hybrid_apc_enabled(model_to_execute)} " + f"owner_type={type(hybrid_apc_owner).__name__} " + f"tag={getattr(model_to_execute, 'tag', None)}", + flush=True, + ) + try: + is_context_encoding = _is_context_encoding_execution( + neuron_base_instance, + model_to_execute, + input_dict, + ) + if is_context_encoding: + input_dict = _with_hybrid_apc_owner_metadata( + input_dict, + hybrid_apc_owner, + ) + input_dict = prepare_hybrid_apc_request_for_execution( + hybrid_apc_owner, + input_dict, + ) + prepared_bridge = input_dict.get("_hybrid_apc_bridge") + if prepared_bridge is not None: + for bridge_owner in ( + neuron_base_instance, + model_to_execute, + getattr(neuron_base_instance, "context_encoding_model", None), + getattr(neuron_base_instance, "token_generation_model", None), + ): + if bridge_owner is not None: + setattr( + bridge_owner, + "_hybrid_apc_last_bridge", + prepared_bridge, + ) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] async-prepared-return " + f"has_bridge={'_hybrid_apc_bridge' in input_dict} " + f"has_prepared={'_hybrid_apc_prepared' in input_dict}", + flush=True, + ) + for lifecycle_key in ("_hybrid_apc_bridge", "_hybrid_apc_prepared"): + if lifecycle_key in input_dict: + original_input_dict[lifecycle_key] = input_dict[lifecycle_key] + if "_hybrid_apc_prepared" in input_dict: + setattr( + neuron_base_instance, + "_hybrid_apc_pending_input_dict", + input_dict, + ) + else: + input_dict = _with_disabled_hybrid_apc_controls(input_dict) + if "num_queries" not in input_dict: + full_context_lens = input_dict["full_context_lens"] + computed_context_lens = input_dict["computed_context_lens"] + num_queries = full_context_lens - computed_context_lens + input_dict["num_queries"] = num_queries - if ( - not neuron_base_instance.neuron_config.enable_fused_speculation - and not neuron_base_instance.neuron_config.enable_eagle_speculation - ): - return model_to_execute( - input_dict["input_ids"], - input_dict["attention_mask"], - input_dict["position_ids"], - input_dict["seq_ids"], - input_dict["sampling_params"], - torch.empty(0), # prev_hidden - input_dict["adapter_ids"], - torch.empty(0), # accepted_indices - torch.empty(0), # current_length - torch.empty(0), # medusa_mask - torch.empty(0), # scatter_index - input_dict["slot_mapping"], - input_dict["block_table"], - input_dict["num_queries"], - input_dict["computed_context_lens"], - pad_type=pad_type - ), model_to_execute.is_neuron() - elif neuron_base_instance.neuron_config.enable_eagle_speculation: - return model_to_execute( - input_dict["input_ids"], - input_dict["attention_mask"], - input_dict["position_ids"], - input_dict["seq_ids"], - input_dict["sampling_params"], - torch.empty(0), # prev_hidden - input_dict["adapter_ids"], - input_dict["slot_mapping"], - input_dict["block_table"], - input_dict["num_queries"], - input_dict["computed_context_lens"], - torch.empty(0), # target_input_ids - torch.empty(0), # target_attention_mask - torch.empty(0), # target_position_ids - torch.empty(0), # target_slot_mapping - torch.empty(0), # target_active_block_table - pad_type=pad_type - ), model_to_execute.is_neuron() - else: - raise NotImplementedError("Non-EAGLE fused speculation with prefix caching does not support async mode.") + if ( + not neuron_base_instance.neuron_config.enable_fused_speculation + and not neuron_base_instance.neuron_config.enable_eagle_speculation + ): + if is_context_encoding: + hybrid_apc_args = prepare_hybrid_apc_model_inputs( + hybrid_apc_owner, input_dict + ) + else: + _validate_token_generation_input_ids( + neuron_base_instance, + model_to_execute, + input_dict, + ) + hybrid_apc_args = prepare_disabled_hybrid_apc_model_inputs( + hybrid_apc_owner, input_dict + ) + return model_to_execute( + input_dict["input_ids"], + input_dict["attention_mask"], + input_dict["position_ids"], + input_dict["seq_ids"], + input_dict["sampling_params"], + torch.empty(0), # prev_hidden + input_dict["adapter_ids"], + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + input_dict["slot_mapping"], + input_dict["block_table"], + input_dict["num_queries"], + input_dict["computed_context_lens"], + *hybrid_apc_args, + pad_type=pad_type + ), model_to_execute.is_neuron() + elif neuron_base_instance.neuron_config.enable_eagle_speculation: + return model_to_execute( + input_dict["input_ids"], + input_dict["attention_mask"], + input_dict["position_ids"], + input_dict["seq_ids"], + input_dict["sampling_params"], + torch.empty(0), # prev_hidden + input_dict["adapter_ids"], + input_dict["slot_mapping"], + input_dict["block_table"], + input_dict["num_queries"], + input_dict["computed_context_lens"], + torch.empty(0), # target_input_ids + torch.empty(0), # target_attention_mask + torch.empty(0), # target_position_ids + torch.empty(0), # target_slot_mapping + torch.empty(0), # target_active_block_table + pad_type=pad_type + ), model_to_execute.is_neuron() + else: + raise NotImplementedError("Non-EAGLE fused speculation with prefix caching does not support async mode.") + except Exception: + cancel_hybrid_apc_request(original_input_dict) + raise def execute_model( @@ -196,26 +2542,118 @@ def causal_lm_async_execution( # PREFILL STAGE: is_prefill = neuron_base_instance._is_prefill(inputs["position_ids"]) + prefill_probe_inputs = inputs + if is_prefix_caching: + prefill_probe_inputs = _with_hybrid_apc_candidate_owner_metadata( + inputs, + neuron_base_instance, + getattr(neuron_base_instance, "context_encoding_model", None), + getattr(neuron_base_instance, "token_generation_model", None), + ) + probe_is_chunked_prefill = ( + is_prefix_caching + and _is_chunked_prefill_execution( + neuron_base_instance, + prefill_probe_inputs, + is_fused_speculation=is_fused_speculation, + ) + ) + if ( + is_prefix_caching + and not is_prefill + and os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1" + ): + input_ids = inputs.get("input_ids") + records = prefill_probe_inputs.get("hybrid_request_records") + print( + "[hybrid_apc_debug] prefill-route-probe " + f"input_shape={tuple(input_ids.shape) if isinstance(input_ids, torch.Tensor) else None} " + f"probe_is_chunked_prefill={probe_is_chunked_prefill} " + f"probe_keys={sorted(k for k in prefill_probe_inputs if k.startswith('hybrid_') or k in ('request_prefix_len', 'active_suffix_len', 'vllm_attention_hit_len'))} " + f"records_len={len(records) if isinstance(records, tuple) else None} " + f"base={_debug_hybrid_apc_owner_metadata_summary(neuron_base_instance)} " + f"context={_debug_hybrid_apc_owner_metadata_summary(getattr(neuron_base_instance, 'context_encoding_model', None))} " + f"token={_debug_hybrid_apc_owner_metadata_summary(getattr(neuron_base_instance, 'token_generation_model', None))}", + flush=True, + ) + if ( + is_prefix_caching + and not is_prefill + and probe_is_chunked_prefill + ): + is_prefill = True + inputs = prefill_probe_inputs + elif is_prefix_caching and is_prefill: + inputs = prefill_probe_inputs neuron_base_instance.async_should_stop = False prefill_outputs = None is_run_on_neuron = None if is_prefill: - prefill_outputs, is_run_on_neuron = execute_model( - neuron_base_instance, neuron_base_instance.context_encoding_model, inputs - ) + try: + timing_enabled = os.environ.get("QWEN36_PREFILL_TIMING") == "1" + execute_start = time.perf_counter() if timing_enabled else None + prefill_outputs, is_run_on_neuron = execute_model( + neuron_base_instance, neuron_base_instance.context_encoding_model, inputs + ) + if timing_enabled and execute_start is not None: + input_ids = inputs.get("input_ids") + position_ids = inputs.get("position_ids") + computed_context_lens = inputs.get("computed_context_lens") + num_queries = inputs.get("num_queries") + print( + "[qwen36_perf] async_execute_model " + f"elapsed_ms={(time.perf_counter() - execute_start) * 1000.0:.3f} " + f"is_run_on_neuron={is_run_on_neuron} " + f"input_shape={tuple(input_ids.shape) if isinstance(input_ids, torch.Tensor) else None} " + f"position_shape={tuple(position_ids.shape) if isinstance(position_ids, torch.Tensor) else None} " + f"num_queries={num_queries.reshape(-1).tolist() if isinstance(num_queries, torch.Tensor) and num_queries.numel() else []} " + f"computed={computed_context_lens.reshape(-1).tolist() if isinstance(computed_context_lens, torch.Tensor) and computed_context_lens.numel() else []} " + f"request_ids={_async_request_ids_signature(neuron_base_instance)}", + flush=True, + ) - # Sequence IDs from vLLM will be in sorted order, but the maximum range of sequence IDs is - # not [0, num_requested_prefills] but [0, max_num_seqs]. To prevent out-of-bound accesses, - # we convert the sequence IDs to their argsorted values. - _seq_ids = torch.argsort(inputs["seq_ids"]) + # Sequence IDs from vLLM will be in sorted order, but the maximum range of sequence IDs is + # not [0, num_requested_prefills] but [0, max_num_seqs]. To prevent out-of-bound accesses, + # we convert the sequence IDs to their argsorted values. + _seq_ids = torch.argsort(inputs["seq_ids"]) - outputs = prefill_outputs.sync_async_result_to_cpu( - _seq_ids, is_fused_speculation=is_fused_speculation, is_prefix_caching=is_prefix_caching - ) + sync_start = time.perf_counter() if timing_enabled else None + outputs = prefill_outputs.sync_async_result_to_cpu( + _seq_ids, is_fused_speculation=is_fused_speculation, is_prefix_caching=is_prefix_caching + ) + if timing_enabled and sync_start is not None: + print( + "[qwen36_perf] async_sync_result " + f"elapsed_ms={(time.perf_counter() - sync_start) * 1000.0:.3f} " + f"is_run_on_neuron={is_run_on_neuron} " + f"request_ids={_async_request_ids_signature(neuron_base_instance)}", + flush=True, + ) + pending_hybrid_apc = getattr( + neuron_base_instance, + "_hybrid_apc_pending_input_dict", + None, + ) + neuron_base_instance._hybrid_apc_pending_input_dict = None + finish_hybrid_apc_request( + pending_hybrid_apc if pending_hybrid_apc is not None else inputs + ) + except Exception: + pending_hybrid_apc = getattr( + neuron_base_instance, + "_hybrid_apc_pending_input_dict", + None, + ) + neuron_base_instance._hybrid_apc_pending_input_dict = None + cancel_hybrid_apc_request( + pending_hybrid_apc if pending_hybrid_apc is not None else inputs + ) + raise # clean up async state neuron_base_instance.prior_outputs = None neuron_base_instance.prior_seq_ids = None + neuron_base_instance.prior_request_ids = None return outputs, is_run_on_neuron @@ -234,10 +2672,20 @@ def causal_lm_async_execution( buckets=generation_model.neuron_config.buckets, max_num_tokens_generated=generation_length, ) + request_ids_signature = _async_request_ids_signature(neuron_base_instance) + prior_request_ids = getattr(neuron_base_instance, "prior_request_ids", None) + request_ids_changed = ( + request_ids_signature is not None + and prior_request_ids is not None + and request_ids_signature != prior_request_ids + ) + force_sync_for_hybrid_apc = _is_hybrid_apc_enabled(neuron_base_instance) stay_in_sync_mode = ( not torch.equal(neuron_base_instance.prior_seq_ids, inputs["seq_ids"]) or hits_bucket_boundary + or request_ids_changed + or force_sync_for_hybrid_apc ) start_async = not stay_in_sync_mode and neuron_base_instance.prior_outputs is None continue_async = not stay_in_sync_mode and not start_async @@ -246,6 +2694,7 @@ def causal_lm_async_execution( # reset async state neuron_base_instance.prior_outputs = None neuron_base_instance.prior_seq_ids = None + neuron_base_instance.prior_request_ids = None if stay_in_sync_mode or start_async: next_outputs, is_run_on_neuron = execute_model( @@ -257,6 +2706,7 @@ def causal_lm_async_execution( if start_async: neuron_base_instance.prior_outputs = next_outputs neuron_base_instance.prior_seq_ids = inputs["seq_ids"] + neuron_base_instance.prior_request_ids = request_ids_signature if start_async or continue_async: if within_bounds(inputs, neuron_base_instance.neuron_config.seq_len, generation_length): @@ -279,6 +2729,7 @@ def causal_lm_async_execution( ) neuron_base_instance.prior_outputs = None neuron_base_instance.prior_seq_ids = None + neuron_base_instance.prior_request_ids = None neuron_base_instance.async_should_stop = True else: raise RuntimeError( @@ -297,10 +2748,12 @@ def causal_lm_async_execution( # make sure prior outputs is not set neuron_base_instance.prior_outputs = None neuron_base_instance.prior_seq_ids = None + neuron_base_instance.prior_request_ids = None return outputs, is_run_on_neuron # next step neuron_base_instance.prior_outputs = next_outputs neuron_base_instance.prior_seq_ids = inputs["seq_ids"] + neuron_base_instance.prior_request_ids = request_ids_signature return outputs, is_run_on_neuron diff --git a/src/neuronx_distributed_inference/modules/attention/attention_base.py b/src/neuronx_distributed_inference/modules/attention/attention_base.py index ccbf4506..1ded860c 100644 --- a/src/neuronx_distributed_inference/modules/attention/attention_base.py +++ b/src/neuronx_distributed_inference/modules/attention/attention_base.py @@ -82,7 +82,38 @@ def import_nki_cte_attention_kernel(): return nki.jit(attention_cte), _has_native_gqa_tp_support +def import_nki_segmented_cte_attention_kernel(): + """Import the Neuron 2.30 block-KV segmented CTE kernel when available.""" + try: + mod = import_module("nkilib.core.attention.attention_segmented_cte") + except ImportError: + return None + + attention_segmented_cte = getattr(mod, "attention_segmented_cte", None) + if attention_segmented_cte is None: + return None + return nki.jit(attention_segmented_cte) + + +def import_nki_qwen_segmented_cte_256_kernel(): + """Import the Qwen head_dim=256 segmented CTE kernel when available.""" + try: + mod = import_module( + "neuronx_distributed_inference.modules.attention.nki_kernels." + "qwen_segcte256.attention_segmented_cte_256" + ) + except ImportError: + return None + + attention_segmented_cte = getattr(mod, "attention_segmented_cte", None) + if attention_segmented_cte is None: + return None + return nki.jit(attention_segmented_cte) + + _flash_fwd_call_nki, _has_native_gqa_tp_support = import_nki_cte_attention_kernel() +_segmented_cte_call_nki = import_nki_segmented_cte_attention_kernel() +_qwen_segmented_cte_256_call_nki = import_nki_qwen_segmented_cte_256_kernel() logger = logging.getLogger("Neuron") @@ -297,6 +328,7 @@ def init_tkg_cp_qkv_o_proj(self, process_group, rank_ordering=None): fused_rmsnorm_skip_gamma=self.neuron_config.fused_rmsnorm_skip_gamma, logical_nc_config=self.neuron_config.logical_nc_config, qkv_kernel_nbsd_layout=self.neuron_config.qkv_kernel_nbsd_layout, + quantized=self.neuron_config.quantized, on_cpu=self.neuron_config.on_cpu, rank_ordering=rank_ordering, ) @@ -319,6 +351,7 @@ def init_tkg_cp_qkv_o_proj(self, process_group, rank_ordering=None): out_proj_kernel_enabled=self.attn_block_tkg_nki_kernel_enabled or self.neuron_config.out_proj_kernel_enabled, logical_nc_config=self.neuron_config.logical_nc_config, rank_ordering=rank_ordering, + quantized=self.neuron_config.quantized, ) if self.learned_sinks_size is not None: self.tkg_learned_sinks = LearnedSink(self.learned_sinks_size, self.num_attention_heads, self.torch_dtype, process_group.size(), rank_ordering) @@ -352,6 +385,7 @@ def init_gqa_properties(self): fused_rmsnorm_skip_gamma=self.neuron_config.fused_rmsnorm_skip_gamma, logical_nc_config=self.neuron_config.logical_nc_config, qkv_kernel_nbsd_layout=self.neuron_config.qkv_kernel_nbsd_layout, + quantized=self.neuron_config.quantized, on_cpu=self.neuron_config.on_cpu, tiling_factor=self.neuron_config.cc_pipeline_tiling_factor if self.neuron_config.tile_cc else 1, seq_len_threshold_for_cc_tiling=self.neuron_config.seq_len_threshold_for_cc_tiling, @@ -377,6 +411,7 @@ def init_gqa_properties(self): logical_nc_config=self.neuron_config.logical_nc_config, tiling_factor=self.neuron_config.cc_pipeline_tiling_factor if self.neuron_config.tile_cc else 1, rank_ordering=cte_rank_ordering, + quantized=self.neuron_config.quantized, ) self.learned_sinks = None if self.learned_sinks_size is not None: @@ -769,11 +804,201 @@ def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask) -> Tensor: attn_output = torch.matmul(active_scores, V_active) return attn_output, flash_attn_strategy - def perform_prefix_prefill(self, Q, K, V, q_len, bsz, attention_mask, past_key_value, active_mask) -> Tensor: + def _prefix_cte_attention_backend(self) -> str: + return getattr( + self.neuron_config, + "prefix_cte_attention_backend", + "attention_cte", + ) + + def _prefix_cte_attention_segment_size(self, q_len: int) -> int: + segment_size = getattr( + self.neuron_config, + "prefix_cte_attention_segment_size", + None, + ) + if segment_size is None: + return int(q_len) + return int(segment_size) + + def _prepare_segmented_cte_cache(self, segmented_past_key_value): + K_cache, V_cache = segmented_past_key_value + block_size = int(self.neuron_config.pa_block_size) + + if K_cache.dim() != 4 or V_cache.dim() != 4: + raise ValueError( + "segmented_cte prefix prefill requires 4D block KV cache " + f"tensors, got K={tuple(K_cache.shape)} V={tuple(V_cache.shape)}" + ) + if K_cache.shape[1] == block_size and V_cache.shape[1] == block_size: + K_cache = K_cache.permute(0, 2, 1, 3) + V_cache = V_cache.permute(0, 2, 1, 3) + elif K_cache.shape[2] != block_size or V_cache.shape[2] != block_size: + raise ValueError( + "segmented_cte prefix prefill expected block KV cache in " + "(blocks, block, heads, dim) or (blocks, heads, block, dim) " + f"layout, got K={tuple(K_cache.shape)} V={tuple(V_cache.shape)}" + ) + return K_cache, V_cache + + def perform_prefix_prefill_segmented_cte( + self, + Q, + q_len, + bsz, + segmented_past_key_value, + active_block_table, + computed_context_lens, + ) -> Tensor: + use_qwen_segcte256 = self.head_dim == 256 + segmented_cte_call_nki = _qwen_segmented_cte_256_call_nki if use_qwen_segcte256 else _segmented_cte_call_nki + if segmented_cte_call_nki is None: + if use_qwen_segcte256: + raise ImportError( + "prefix_cte_attention_backend=segmented_cte with head_dim=256 " + "requires the Qwen qwen_segcte256 kernel package." + ) + raise ImportError( + "prefix_cte_attention_backend=segmented_cte requires " + "nkilib.core.attention.attention_segmented_cte from Neuron " + "2.30+." + ) + if segmented_past_key_value is None: + raise ValueError( + "segmented_cte prefix prefill requires an already-updated " + "block KV cache." + ) + if active_block_table is None or computed_context_lens is None: + raise ValueError( + "segmented_cte prefix prefill requires active_block_table and " + "computed_context_lens." + ) + if self.k_cache_transposed: + raise ValueError("segmented_cte prefix prefill does not support transposed K cache.") + if use_qwen_segcte256 and self.sliding_window is not None: + raise ValueError("qwen_segcte256 segmented CTE does not support sliding-window attention.") + + block_size = int(self.neuron_config.pa_block_size) + prior_seg_size = self._prefix_cte_attention_segment_size(q_len) + if prior_seg_size % block_size != 0: + raise ValueError( + "prefix_cte_attention_segment_size must be divisible by " + f"pa_block_size ({block_size}), got {prior_seg_size}." + ) + + K_cache, V_cache = self._prepare_segmented_cte_cache(segmented_past_key_value) + Q = Q.reshape(bsz * self.num_heads, q_len, self.head_dim).to(self.torch_dtype) + Q = Q / self.softmax_scale + K_cache = K_cache.to(self.torch_dtype) + V_cache = V_cache.to(self.torch_dtype) + active_block_table = active_block_table.to(torch.int32) + prior_tokens = computed_context_lens.reshape(bsz, 1).to(torch.int32) + + learned_sinks = self.get_learned_sinks() + sink = None + if learned_sinks is not None: + sink = ( + learned_sinks.reshape(1, self.num_heads) + .expand(bsz, self.num_heads) + .reshape(bsz * self.num_heads, 1) + .to(self.torch_dtype) + ) + + flash_attn_strategy = self.get_flash_attention_strategy( + q_len, has_attention_mask=False + ) + tp_out = not use_qwen_segcte256 + if flash_attn_strategy == FlashAttentionStrategy.SHARDED_KERNEL: + attn_output = segmented_cte_call_nki[self.logical_nc_config]( + Q, + K_cache, + V_cache, + active_block_table, + prior_tokens, + block_size, + prior_seg_size, + 1.0, + tp_q=True, + tp_out=tp_out, + sliding_window=self.sliding_window, + sink=sink, + num_q_heads=self.num_heads, + k_pre_transposed=False, + ) + elif flash_attn_strategy == FlashAttentionStrategy.UNSHARDED_KERNEL: + attn_output = segmented_cte_call_nki( + Q, + K_cache, + V_cache, + active_block_table, + prior_tokens, + block_size, + prior_seg_size, + 1.0, + tp_q=True, + tp_out=tp_out, + sliding_window=self.sliding_window, + sink=sink, + num_q_heads=self.num_heads, + k_pre_transposed=False, + ) + else: + raise ValueError( + "segmented_cte prefix prefill requires the NKI flash attention " + f"strategy, got {flash_attn_strategy}." + ) + + if use_qwen_segcte256: + attn_output = attn_output.reshape((bsz, self.num_heads, q_len, self.head_dim)).permute(0, 1, 3, 2) + else: + attn_output = attn_output.reshape((bsz, self.num_heads, self.head_dim, q_len)) + logger.debug("Segmented CTE attn output after reshape %s", attn_output.shape) + return attn_output, flash_attn_strategy + + def perform_prefix_prefill( + self, + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value, + active_mask, + active_block_table=None, + computed_context_lens=None, + segmented_past_key_value=None, + ) -> Tensor: """attention computation at prefilling (context encoding) phase""" + if self._prefix_cte_attention_backend() == "segmented_cte": + return self.perform_prefix_prefill_segmented_cte( + Q, + q_len, + bsz, + segmented_past_key_value, + active_block_table, + computed_context_lens, + ) + K_prior = past_key_value[0] V_prior = past_key_value[1] - prior_len = K_prior.shape[-2] + prior_len = K_prior.shape[-1] if self.k_cache_transposed else K_prior.shape[-2] + + prefix_chunk_size = getattr( + self.neuron_config, "prefix_cte_attention_chunk_size", None + ) + if prefix_chunk_size is not None and prior_len > int(prefix_chunk_size): + return self.perform_prefix_prefill_chunked_prior( + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value, + active_mask, + int(prefix_chunk_size), + ) flash_attn_strategy = self.get_flash_attention_strategy(q_len, has_attention_mask=False) logger.debug(f"Flash attention strategy: {flash_attn_strategy}") @@ -913,6 +1138,112 @@ def perform_prefix_prefill(self, Q, K, V, q_len, bsz, attention_mask, past_key_v return attn_output, flash_attn_strategy + def perform_prefix_prefill_chunked_prior( + self, + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value, + active_mask, + prefix_chunk_size, + ) -> Tensor: + """ + Prefix-prefill attention with an online softmax over cached-prefix chunks. + + This path avoids materializing the full [Q, prefix] score tensor for very + long prefix-cache buckets. It is intended for long Hybrid APC buckets + such as [active=3072, prefix=262144], where the monolithic CTE NEFF can + exceed per-HBM scratchpad during runtime load. + """ + logger.debug( + "ATTN: chunked prefix prior for Q.shape=%s prefix_chunk_size=%s", + Q.shape, + prefix_chunk_size, + ) + + K_prior = past_key_value[0] + V_prior = past_key_value[1] + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + K_prior = repeat_kv(K_prior, self.num_key_value_groups) + V_prior = repeat_kv(V_prior, self.num_key_value_groups) + + prior_len = K_prior.shape[-1] if self.k_cache_transposed else K_prior.shape[-2] + compute_dtype = torch.float32 + q_compute = Q.to(compute_dtype) + + running_max = torch.full( + (bsz, self.num_heads, q_len, 1), + torch.finfo(compute_dtype).min, + dtype=compute_dtype, + device=Q.device, + ) + running_sum = torch.zeros( + (bsz, self.num_heads, q_len, 1), + dtype=compute_dtype, + device=Q.device, + ) + running_output = torch.zeros( + (bsz, self.num_heads, q_len, self.head_dim), + dtype=compute_dtype, + device=Q.device, + ) + + def update_online_softmax(scores, values, max_so_far, sum_so_far, out_so_far): + scores = scores.to(compute_dtype) + chunk_max = torch.max(scores, dim=-1, keepdim=True).values + new_max = torch.maximum(max_so_far, chunk_max) + old_scale = torch.exp(max_so_far - new_max) + chunk_scale = torch.exp(scores - new_max) + new_sum = sum_so_far * old_scale + torch.sum( + chunk_scale, dim=-1, keepdim=True + ) + new_out = out_so_far * old_scale + torch.matmul( + chunk_scale.to(values.dtype), + values, + ).to(compute_dtype) + return new_max, new_sum, new_out + + for prefix_start in range(0, int(prior_len), int(prefix_chunk_size)): + prefix_end = min(prefix_start + int(prefix_chunk_size), int(prior_len)) + if self.k_cache_transposed: + K_prior_chunk = K_prior[:, :, :, prefix_start:prefix_end] + else: + K_prior_chunk = K_prior[:, :, prefix_start:prefix_end, :].transpose(2, 3) + V_prior_chunk = V_prior[:, :, prefix_start:prefix_end, :] + prior_scores = torch.matmul(q_compute, K_prior_chunk.to(compute_dtype)) + prior_scores = prior_scores / self.softmax_scale + running_max, running_sum, running_output = update_online_softmax( + prior_scores, + V_prior_chunk, + running_max, + running_sum, + running_output, + ) + + active_scores = torch.matmul( + q_compute, + K_active.transpose(2, 3).to(compute_dtype), + ) + active_scores = active_scores / self.softmax_scale + active_scores = torch.where( + active_mask, + active_scores, + torch.finfo(active_scores.dtype).min, + ) + running_max, running_sum, running_output = update_online_softmax( + active_scores, + V_active, + running_max, + running_sum, + running_output, + ) + attn_output = (running_output / running_sum).to(Q.dtype) + return attn_output, FlashAttentionStrategy.NONE + def perform_prefill_chunked_attn(self, Q, K, V, q_len, bsz, attention_mask, chunk_size) -> Tensor: """attention computation at prefilling (context encoding) phase using native PyTorch ops""" K_active = repeat_kv(K, self.num_key_value_groups) @@ -947,9 +1278,35 @@ def perform_prefill_chunked_attn(self, Q, K, V, q_len, bsz, attention_mask, chun attn_output = torch.cat(outputs, dim=2) return attn_output, FlashAttentionStrategy.NONE - def perform_prefix_prefill_windowed_attn(self, Q, K, V, q_len, bsz, attention_mask, window_size, past_key_value, active_mask) -> Tensor: + def perform_prefix_prefill_windowed_attn( + self, + Q, + K, + V, + q_len, + bsz, + attention_mask, + window_size, + past_key_value, + active_mask, + active_block_table=None, + computed_context_lens=None, + segmented_past_key_value=None, + ) -> Tensor: """attention computation at prefilling (context encoding) phase with sliding window""" - return self.perform_prefix_prefill(Q, K, V, q_len, bsz, attention_mask, past_key_value, active_mask) + return self.perform_prefix_prefill( + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value, + active_mask, + active_block_table=active_block_table, + computed_context_lens=computed_context_lens, + segmented_past_key_value=segmented_past_key_value, + ) def get_flash_attention_strategy_cp(self, q_len): """ @@ -1460,11 +1817,36 @@ def compute_for_token_gen( return attn_output - def attention_context_encode(self, Q, K, V, q_len, bsz, attention_mask, past_key_value=None, active_mask=None): - if past_key_value is None: + def attention_context_encode( + self, + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value=None, + active_mask=None, + active_block_table=None, + computed_context_lens=None, + segmented_past_key_value=None, + ): + if past_key_value is None and segmented_past_key_value is None: attn_output, flash_attn_strategy = self.perform_prefill(Q, K, V, q_len, bsz, attention_mask) else: - attn_output, flash_attn_strategy = self.perform_prefix_prefill(Q, K, V, q_len, bsz, attention_mask, past_key_value, active_mask) + attn_output, flash_attn_strategy = self.perform_prefix_prefill( + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value, + active_mask, + active_block_table=active_block_table, + computed_context_lens=computed_context_lens, + segmented_past_key_value=segmented_past_key_value, + ) if self.flash_decoding_enabled: K, V = self._filter_kv_for_flash_decoding(K, V, q_len, Q) @@ -1510,7 +1892,21 @@ def attention_context_encode_chunked_attention(self, Q, K, V, q_len, bsz, attent attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, K, V - def attention_context_encode_windowed_attention(self, Q, K, V, q_len, bsz, attention_mask, window_size=None, past_key_value=None, active_mask=None): + def attention_context_encode_windowed_attention( + self, + Q, + K, + V, + q_len, + bsz, + attention_mask, + window_size=None, + past_key_value=None, + active_mask=None, + active_block_table=None, + computed_context_lens=None, + segmented_past_key_value=None, + ): if past_key_value is None: attn_output, flash_attn_strategy = self.perform_prefill_windowed_attn(Q, K, V, q_len, bsz, attention_mask, window_size) if flash_attn_strategy not in [FlashAttentionStrategy.NONE]: @@ -1518,7 +1914,20 @@ def attention_context_encode_windowed_attention(self, Q, K, V, q_len, bsz, atten else: attn_output = attn_output.transpose(1, 2).contiguous() # transpose BHSD -> BSHD else: - attn_output, _ = self.perform_prefix_prefill_windowed_attn(Q, K, V, q_len, bsz, attention_mask, window_size, past_key_value, active_mask) + attn_output, _ = self.perform_prefix_prefill_windowed_attn( + Q, + K, + V, + q_len, + bsz, + attention_mask, + window_size, + past_key_value, + active_mask, + active_block_table=active_block_table, + computed_context_lens=computed_context_lens, + segmented_past_key_value=segmented_past_key_value, + ) attn_output = attn_output.transpose(1, 2).contiguous() # transpose BHSD -> BSHD return attn_output, K, V @@ -1733,9 +2142,32 @@ def standard_causal_attention_forward( if rotary_position_ids is None: rotary_position_ids = position_ids + segmented_raw_past_key_value = None + use_raw_segmented_prefix_cte = ( + get_kv_per_layer + and getattr(self.neuron_config, "is_prefix_caching", False) + and self._prefix_cte_attention_backend() == "segmented_cte" + and kwargs.get("active_block_table") is not None + and getattr(kwargs.get("active_block_table"), "ndim", 0) > 1 + and q_len >= 128 + ) if get_kv_per_layer: assert kv_mgr is not None - past_key_value = kv_mgr.get_kv_by_layer_id(**kwargs) + if use_raw_segmented_prefix_cte: + get_raw_kv_by_layer_id = getattr( + kv_mgr, "get_raw_kv_by_layer_id", None + ) + if get_raw_kv_by_layer_id is None: + raise ValueError( + "segmented_cte prefix prefill requires a block KV " + "manager with get_raw_kv_by_layer_id()." + ) + segmented_raw_past_key_value = get_raw_kv_by_layer_id( + idx=kwargs["idx"], + kvcache_buffer=kwargs.get("kvcache_buffer"), + ) + else: + past_key_value = kv_mgr.get_kv_by_layer_id(**kwargs) is_token_gen = past_key_value is not None @@ -1825,6 +2257,26 @@ def standard_causal_attention_forward( use_polar_compatible_rope=use_polar_compatible_rope, ) + segmented_past_key_value = None + segmented_updated_kv = None + use_segmented_prefix_cte = ( + not is_token_gen + and (past_key_value is not None or segmented_raw_past_key_value is not None) + and self._prefix_cte_attention_backend() == "segmented_cte" + ) + if use_segmented_prefix_cte: + if kv_mgr is None or not update_kv_per_layer: + raise ValueError( + "segmented_cte prefix prefill requires kv_mgr and " + "update_kv_per_layer so active KV is present in the block cache." + ) + segmented_updated_kv = kv_mgr.update_kv_by_layer_id( + kv_per_layer=(K, V), + position_ids=position_ids, + **kwargs, + ) + segmented_past_key_value = segmented_updated_kv + if is_token_gen: attn_output = self.attention_tokengen( Q, K, V, attention_mask, position_ids, past_key_value, active_mask, **kwargs @@ -1833,7 +2285,19 @@ def standard_causal_attention_forward( # transpose BHSD -> BSHD attn_output = attn_output.transpose(1, 2).contiguous() else: - attn_output, K, V = self.attention_context_encode(Q, K, V, q_len, bsz, attention_mask, past_key_value, active_mask) + attn_output, K, V = self.attention_context_encode( + Q, + K, + V, + q_len, + bsz, + attention_mask, + past_key_value, + active_mask, + active_block_table=kwargs.get("active_block_table"), + computed_context_lens=kwargs.get("computed_context_lens"), + segmented_past_key_value=segmented_past_key_value, + ) # merge multi head hidden attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) @@ -1845,9 +2309,11 @@ def standard_causal_attention_forward( # Output K in BNSd if not transposed, otherwise BNdS K = K.permute(0, 1, 3, 2) - kv: Tuple[Tensor, Tensor] = (K, V) + kv: Tuple[Tensor, Tensor] = ( + segmented_updated_kv if segmented_updated_kv is not None else (K, V) + ) - if update_kv_per_layer: + if update_kv_per_layer and segmented_updated_kv is None: assert kv_mgr is not None kv = kv_mgr.update_kv_by_layer_id( kv_per_layer=kv, diff --git a/src/neuronx_distributed_inference/modules/attention/gqa.py b/src/neuronx_distributed_inference/modules/attention/gqa.py index 56383d38..b44b165a 100644 --- a/src/neuronx_distributed_inference/modules/attention/gqa.py +++ b/src/neuronx_distributed_inference/modules/attention/gqa.py @@ -1,5 +1,6 @@ import enum import logging +import math from typing import Optional, Tuple import torch @@ -16,13 +17,23 @@ from torch.distributed import ProcessGroup from torch.nn import functional as F -from neuronx_distributed_inference.modules.attention.utils import transpose_parallel_linear_layer +from neuronx_distributed_inference.modules.attention.utils import ( + preprocess_quantized_linear_layer, + transpose_parallel_linear_layer, +) from neuronx_distributed_inference.modules.lora_serving.lora_module import is_lora_module import nki from nkilib.core.output_projection.output_projection_cte import output_projection_cte from nkilib.core.qkv.qkv import qkv -from nkilib.core.utils.common_types import NormType, QKVOutputLayout +from nkilib.core.utils.common_types import NormType, QKVOutputLayout, QuantizationType + +try: + from neuronx_distributed_inference.modules.attention.nki_kernels.qwen_gated_output_projection import ( + qwen_gated_output_projection_cte, + ) +except Exception: + qwen_gated_output_projection_cte = None logger = logging.getLogger("Neuron") # To satisfy test_gqa @@ -243,6 +254,58 @@ def _replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0): return tensor.view(shape) +def _rank_block_qwen_q_gate_for_tp( + q_tensor: torch.Tensor, + gate_tensor: torch.Tensor, + *, + num_attention_heads: int, + head_dim: int, + tp_degree: int, + dim: int = 0, +) -> torch.Tensor: + """Pack Qwen Q/gate heads so each TP shard receives local Q then local gate.""" + + if q_tensor is None or gate_tensor is None: + raise ValueError("Qwen packed QKV+gate requires both Q and gate tensors") + if q_tensor.shape != gate_tensor.shape: + raise ValueError( + "Qwen packed QKV+gate requires Q and gate tensors with identical " + f"shapes, got {tuple(q_tensor.shape)} and {tuple(gate_tensor.shape)}" + ) + if num_attention_heads % tp_degree != 0: + raise ValueError( + "Qwen packed QKV+gate requires attention heads divisible by TP degree, " + f"got heads={num_attention_heads}, tp_degree={tp_degree}" + ) + expected_width = num_attention_heads * head_dim + if q_tensor.shape[dim] != expected_width: + raise ValueError( + "Qwen packed QKV+gate tensor width does not match attention shape, " + f"got width={q_tensor.shape[dim]}, expected={expected_width}" + ) + + shape = ( + q_tensor.shape[:dim] + + (num_attention_heads, head_dim) + + q_tensor.shape[dim + 1 :] + ) + q_heads = q_tensor.reshape(shape) + gate_heads = gate_tensor.reshape(shape) + heads_per_rank = num_attention_heads // tp_degree + rank_blocks = [] + for rank in range(tp_degree): + start = rank * heads_per_rank + rank_blocks.append(q_heads.narrow(dim, start, heads_per_rank)) + rank_blocks.append(gate_heads.narrow(dim, start, heads_per_rank)) + packed = torch.cat(rank_blocks, dim=dim) + packed_shape = ( + q_tensor.shape[:dim] + + (2 * expected_width,) + + q_tensor.shape[dim + 1 :] + ) + return packed.reshape(packed_shape).contiguous() + + class BaseGroupQueryAttention(nn.Module): def __init__( self, @@ -370,6 +433,7 @@ def __init__( seq_len_threshold_for_cc_tiling: int = 16834, logical_nc_config: int = 1, qkv_kernel_nbsd_layout: bool = False, + quantized: bool = False, on_cpu: bool = False, rank_ordering: dict = None, ): @@ -404,6 +468,7 @@ def __init__( self.seq_len_threshold_for_cc_tiling = seq_len_threshold_for_cc_tiling self.logical_nc_config = logical_nc_config self.qkv_kernel_nbsd_layout = qkv_kernel_nbsd_layout + self.quantized = quantized self.rank_ordering = rank_ordering if self.tensor_model_parallel_group is not None: @@ -418,7 +483,16 @@ def __init__( tensor_model_parallel_group=self.tensor_model_parallel_group, rank_ordering=rank_ordering, ) - if self.qkv_kernel_enabled or self.qkv_nki_kernel_enabled: + if ( + (self.qkv_kernel_enabled or self.qkv_nki_kernel_enabled) + and self.quantized + ): + setattr( + self.Wqkv, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) + elif self.qkv_kernel_enabled or self.qkv_nki_kernel_enabled: # we need to transpose the weights on the CPU side to avoid # needing to transpose on the device when using QKV kernel self.Wqkv.weight = transpose_parallel_linear_layer(self.Wqkv.weight) @@ -580,6 +654,15 @@ def _kernel_qkv_forward(self, hidden_states, rmsnorm, residual, cos_cache, sin_c qkv_norm_type = NormType.RMS_NORM_SKIP_GAMMA fuse_rope = cos_cache is not None and sin_cache is not None + qkv_w_scale = None + qkv_in_scale = None + quantization_type = QuantizationType.NONE + qkv_scale = getattr(self.Wqkv, "scale", None) + if qkv_scale is not None: + qkv_w_scale = qkv_scale.data + qkv_input_scale = getattr(self.Wqkv, "input_scale", None) + qkv_in_scale = qkv_input_scale.data if qkv_input_scale is not None else None + quantization_type = QuantizationType.ROW fused_residual_add = False mlp_prev = None @@ -595,24 +678,54 @@ def _kernel_qkv_forward(self, hidden_states, rmsnorm, residual, cos_cache, sin_c ) mlp_prev = residual - QKV = qkv_kernel[self.logical_nc_config]( - input=hidden_states, - fused_qkv_weights=self.Wqkv.weight.data, - output_layout=qkv_output_layout, - bias=self.Wqkv.bias.data.unsqueeze(0) if self.bias else None, - fused_residual_add=fused_residual_add, - mlp_prev=mlp_prev, - attention_prev=attention_prev, - fused_norm_type=qkv_norm_type, - gamma_norm_weights=rmsnorm.weight.data.unsqueeze(0) if fused_rmsnorm else None, - norm_eps=self.rms_norm_eps, - fused_rope=fuse_rope, - cos_cache=cos_cache, - sin_cache=sin_cache, - d_head=self.head_dim, - num_q_heads=self.num_attention_heads // self.tp_degree, - num_kv_heads=self.num_key_value_heads // self.tp_degree, - ) + # --- Qwen3.6 FP8 qkv_cte workaround (sequence tiling) --- + # The nkilib `qkv` kernel routes S > SEQLEN_THRESHOLD_FOR_QKV_CTE (=96) or + # B*S > pmax (=128) to the `qkv_cte` sub-kernel, whose FP8 path corrupts the + # projection for prefills beyond ~96 tokens (validated: coherent <=96, garbage + # >96). The `qkv_tkg` sub-kernel (B*S <= 128, S <= 96, no fused_rope) is correct. + # The QKV projection is per-token (per-token RMSNorm + per-token residual add, + # NO cross-token mixing here), so slicing the sequence into <=96-token tiles and + # concatenating reproduces the full-S result exactly while routing every sub-call + # to the correct qkv_tkg path. Only valid when fused_rope is off (qkv_tkg has no + # RoPE); for the Qwen3.6 attention path RoPE is applied after this projection. + def _qkv_kernel_call(_input, _mlp_prev, _attention_prev): + return qkv_kernel[self.logical_nc_config]( + input=_input, + fused_qkv_weights=self.Wqkv.weight.data, + output_layout=qkv_output_layout, + bias=self.Wqkv.bias.data.unsqueeze(0) if self.bias else None, + fused_residual_add=fused_residual_add, + mlp_prev=_mlp_prev, + attention_prev=_attention_prev, + fused_norm_type=qkv_norm_type, + gamma_norm_weights=rmsnorm.weight.data.unsqueeze(0) if fused_rmsnorm else None, + norm_eps=self.rms_norm_eps, + fused_rope=fuse_rope, + cos_cache=cos_cache, + sin_cache=sin_cache, + quantization_type=quantization_type, + qkv_w_scale=qkv_w_scale, + qkv_in_scale=qkv_in_scale, + d_head=self.head_dim, + num_q_heads=self.num_attention_heads // self.tp_degree, + num_kv_heads=self.num_key_value_heads // self.tp_degree, + ) + + _qkv_tile = min(96, max(1, 128 // bs)) + if (not fuse_rope) and (seqlen > _qkv_tile or bs * seqlen > 128): + _qkv_parts = [] + for _ts in range(0, seqlen, _qkv_tile): + _te = min(_ts + _qkv_tile, seqlen) + _inp = hidden_states[:, _ts:_te, :].contiguous() + if fused_residual_add: + _mp = mlp_prev[:, _ts:_te, :].contiguous() + _ap = attention_prev[:, _ts:_te, :].contiguous() + else: + _mp, _ap = mlp_prev, attention_prev + _qkv_parts.append(_qkv_kernel_call(_inp, _mp, _ap)) + QKV = torch.cat(_qkv_parts, dim=(2 if self.qkv_kernel_nbsd_layout else 1)) + else: + QKV = _qkv_kernel_call(hidden_states, mlp_prev, attention_prev) if fused_residual_add: residual = hidden_states @@ -680,6 +793,10 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: prefix_parts = prefix.split(".") prefix = ".".join(prefix_parts[:-1]) hf_prefix = ".".join(prefix_parts[:-2]) + qwen_qkv_gate_packed = bool(getattr(self, "qwen_qkv_gate_packed", False)) + gate_proj_weight = None + gate_proj_scale = None + gate_proj_bias = None if self.fused_qkv: self.replace_prefixes( old_prefix=f"{hf_prefix}.Wqkv", @@ -690,24 +807,27 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: qkv_weight, qkv_scale, _ = self.get_weight( prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict ) - q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split( + q_split_sizes = [self._src_num_attention_heads * self.head_dim] + if qwen_qkv_gate_packed: + q_split_sizes.append(self._src_num_attention_heads * self.head_dim) + q_split_sizes.extend( [ - self._src_num_attention_heads * self.head_dim, self._src_num_key_value_heads * self.head_dim, self._src_num_key_value_heads * self.head_dim, - ], - dim=0, + ] ) + qkv_parts = qkv_weight.split(q_split_sizes, dim=0) + if qwen_qkv_gate_packed: + q_proj_weight, gate_proj_weight, k_proj_weight, v_proj_weight = qkv_parts + else: + q_proj_weight, k_proj_weight, v_proj_weight = qkv_parts if qkv_scale is not None: - q_proj_scale, k_proj_scale, v_proj_scale = qkv_scale.split( - [ - self._src_num_attention_heads * self.head_dim, - self._src_num_key_value_heads * self.head_dim, - self._src_num_key_value_heads * self.head_dim, - ], - dim=0, - ) + qkv_scale_parts = qkv_scale.split(q_split_sizes, dim=0) + if qwen_qkv_gate_packed: + q_proj_scale, gate_proj_scale, k_proj_scale, v_proj_scale = qkv_scale_parts + else: + q_proj_scale, k_proj_scale, v_proj_scale = qkv_scale_parts else: q_proj_scale, k_proj_scale, v_proj_scale = None, None, None @@ -715,14 +835,11 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict ) if qkv_bias is not None: - q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split( - [ - self._src_num_attention_heads * self.head_dim, - self._src_num_key_value_heads * self.head_dim, - self._src_num_key_value_heads * self.head_dim, - ], - dim=0, - ) + qkv_bias_parts = qkv_bias.split(q_split_sizes, dim=0) + if qwen_qkv_gate_packed: + q_proj_bias, gate_proj_bias, k_proj_bias, v_proj_bias = qkv_bias_parts + else: + q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias_parts else: q_proj_bias, k_proj_bias, v_proj_bias = None, None, None else: @@ -822,6 +939,22 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: target_heads=self.num_attention_heads, source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads, ) + if qwen_qkv_gate_packed: + gate_proj_weight, gate_proj_scale = maybe_pad_interleaved( + gate_proj_weight, + pad_dim=0, + source_heads=self._src_num_attention_heads, + target_heads=self.num_attention_heads, + source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads, + tensor_scale=gate_proj_scale, + ) + gate_proj_bias, _ = maybe_pad_interleaved( + gate_proj_bias, + pad_dim=0, + source_heads=self._src_num_attention_heads, + target_heads=self.num_attention_heads, + source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads, + ) if self.sharding_strategy == GQA.CONVERT_TO_MHA: q_proj_weight, q_proj_scale = maybe_pad_tail( @@ -837,6 +970,20 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: target_heads=self.num_attention_heads, pad_dim=0, ) + if qwen_qkv_gate_packed: + gate_proj_weight, gate_proj_scale = maybe_pad_tail( + gate_proj_weight, + source_heads=self._src_num_attention_heads, + target_heads=self.num_attention_heads, + pad_dim=0, + tensor_scale=gate_proj_scale, + ) + gate_proj_bias, _ = maybe_pad_tail( + gate_proj_bias, + source_heads=self._src_num_attention_heads, + target_heads=self.num_attention_heads, + pad_dim=0, + ) k_proj_weight, k_proj_scale = maybe_pad_tail( k_proj_weight, source_heads=self._src_num_key_value_heads, @@ -865,18 +1012,47 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: ) if self.fused_qkv: - qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0) + if qwen_qkv_gate_packed: + q_gate_weight = _rank_block_qwen_q_gate_for_tp( + q_proj_weight, + gate_proj_weight, + num_attention_heads=self.num_attention_heads, + head_dim=self.head_dim, + tp_degree=self.tp_degree, + ) + qkv_weight_parts = [q_gate_weight] + else: + qkv_weight_parts = [q_proj_weight] + qkv_weight_parts.extend([k_proj_weight, v_proj_weight]) + qkv_weight = torch.cat(qkv_weight_parts, dim=0) qkv_scale = None - if all(scale is not None for scale in (q_proj_scale, k_proj_scale, v_proj_scale)): - qkv_scale = torch.cat([q_proj_scale, k_proj_scale, v_proj_scale], dim=0) + if qwen_qkv_gate_packed: + q_gate_scale = None + if q_proj_scale is not None and gate_proj_scale is not None: + q_gate_scale = _rank_block_qwen_q_gate_for_tp( + q_proj_scale, + gate_proj_scale, + num_attention_heads=self.num_attention_heads, + head_dim=self.head_dim, + tp_degree=self.tp_degree, + ) + qkv_scale_parts = [q_gate_scale] + else: + qkv_scale_parts = [q_proj_scale] + qkv_scale_parts.extend([k_proj_scale, v_proj_scale]) + if all(scale is not None for scale in qkv_scale_parts): + qkv_scale = torch.cat(qkv_scale_parts, dim=0) # Set heads info as weight parameter attributes to be used in weights sharding fused_qkv_params = ( [self.Wqkv.weight, self.Wqkv.scale] if qkv_scale is not None else [self.Wqkv.weight] ) + packed_num_attention_heads = ( + self.num_attention_heads * 2 if qwen_qkv_gate_packed else self.num_attention_heads + ) for param in fused_qkv_params: setattr(param, "fused_qkv", True) - setattr(param, "num_attention_heads", self.num_attention_heads) + setattr(param, "num_attention_heads", packed_num_attention_heads) setattr(param, "num_key_value_heads", self.num_key_value_heads) setattr(param, "head_dim", self.head_dim) @@ -889,7 +1065,19 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: scale=qkv_scale, ) if self.bias: - qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0) + if qwen_qkv_gate_packed: + q_gate_bias = _rank_block_qwen_q_gate_for_tp( + q_proj_bias, + gate_proj_bias, + num_attention_heads=self.num_attention_heads, + head_dim=self.head_dim, + tp_degree=self.tp_degree, + ) + qkv_bias_parts = [q_gate_bias] + else: + qkv_bias_parts = [q_proj_bias] + qkv_bias_parts.extend([k_proj_bias, v_proj_bias]) + qkv_bias = torch.cat(qkv_bias_parts, dim=0) self.set_bias( tensor=qkv_bias, prefix=prefix, @@ -973,6 +1161,7 @@ def __init__( logical_nc_config: int = 1, rank_ordering: dict = None, tiling_factor: int = 1, + quantized: bool = False, ): super().__init__( hidden_size=hidden_size, @@ -993,6 +1182,7 @@ def __init__( self.rpl_reduce_dtype = rpl_reduce_dtype self.sequence_parallel_enabled = sequence_parallel_enabled self.rank_ordering = rank_ordering + self.quantized = quantized if self.tensor_model_parallel_group is not None: self.o_proj = RowParallelLinear( @@ -1008,7 +1198,13 @@ def __init__( rank_ordering=rank_ordering, tile_cc=self.tiling_factor > 1, ) - if self.out_proj_kernel_enabled: + if self.out_proj_kernel_enabled and self.quantized: + setattr( + self.o_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, + ) + elif self.out_proj_kernel_enabled: # we need to transpose the weights on the CPU side to avoid # needing to transpose on the device when using out proj kernel self.o_proj.weight = transpose_parallel_linear_layer(self.o_proj.weight) @@ -1033,11 +1229,47 @@ def _kernel_o_proj(self, attention_output): heads_per_core = self.num_attention_heads // self.tp_degree assert ( nd == heads_per_core * self.head_dim - ), f"attention_output.shape = {attention_output.shape}, heads_per_core = {heads_per_core}, head_dim = {self.head_dim}" + ), ( + f"attention_output.shape = {attention_output.shape}, " + f"heads_per_core = {heads_per_core}, head_dim = {self.head_dim}" + ) - # Kernel wants BndS layout for input. attention_output = attention_output.reshape(B, S, heads_per_core, self.head_dim) - kernel_attn_in = attention_output.permute(0, 2, 3, 1) + o_proj_scale = getattr(self.o_proj, "scale", None) + if o_proj_scale is not None: + # ROW path dynamically quantizes BF16 attention rows on-device and + # consumes FP8 weights plus per-row dequant scales shaped [128, H]. + max_row_quant_head_dim = 128 + if self.head_dim > max_row_quant_head_dim: + fold_factor = math.ceil(self.head_dim / max_row_quant_head_dim) + while self.head_dim % fold_factor != 0: + fold_factor += 1 + folded_head_dim = self.head_dim // fold_factor + folded_heads = heads_per_core * fold_factor + if folded_head_dim > max_row_quant_head_dim: + raise RuntimeError( + "Output projection NKI ROW FP8 path cannot fold " + f"head_dim={self.head_dim} to <= {max_row_quant_head_dim}" + ) + if folded_heads > 17: + raise RuntimeError( + "Output projection NKI ROW FP8 path exceeds validated " + f"head count after folding: heads={folded_heads}" + ) + kernel_attn_in = attention_output.reshape( + B, S, folded_heads, folded_head_dim + ) + else: + kernel_attn_in = attention_output + quantization_type = QuantizationType.ROW + weight_scales = o_proj_scale.data + elif self.quantized: + raise RuntimeError("Output projection NKI FP8 path requires o_proj.scale") + else: + # Non-quantized kernel wants BndS layout for input. + kernel_attn_in = attention_output.permute(0, 2, 3, 1) + quantization_type = QuantizationType.NONE + weight_scales = None out = torch.zeros(B, S, H, dtype=attention_output.dtype, device=attention_output.device) @@ -1045,6 +1277,8 @@ def _kernel_o_proj(self, attention_output): attention=kernel_attn_in, weight=self.o_proj.weight.data, bias=self.o_proj.bias.data.unsqueeze(0) / self.tp_degree if self.bias else None, + quantization_type=quantization_type, + weight_scales=weight_scales, ) # All-reduce or reduce-scatter, depending on whether SP is enabled @@ -1064,6 +1298,111 @@ def _kernel_o_proj(self, attention_output): return out + def _kernel_gated_o_proj(self, attention_output, gate): + if qwen_gated_output_projection_cte is None: + return self._kernel_o_proj(attention_output * torch.sigmoid(gate)) + + logger.debug( + f"Qwen gated output projection kernel: logical_nc_config={self.logical_nc_config}" + ) + nd, H = self.o_proj.weight.shape + B, S, attn_nd = attention_output.shape + if attn_nd != nd: + raise RuntimeError( + f"attention_output.shape = {attention_output.shape}, " + f"o_proj.weight.shape = {self.o_proj.weight.shape}" + ) + if gate.shape != attention_output.shape: + raise RuntimeError( + "Qwen gated output projection requires gate shape to match " + f"attention_output shape, got gate={gate.shape}, " + f"attention_output={attention_output.shape}" + ) + + heads_per_core = self.num_attention_heads // self.tp_degree + assert ( + nd == heads_per_core * self.head_dim + ), ( + f"attention_output.shape = {attention_output.shape}, " + f"heads_per_core = {heads_per_core}, head_dim = {self.head_dim}" + ) + + attention_output = attention_output.reshape(B, S, heads_per_core, self.head_dim) + gate = gate.reshape(B, S, heads_per_core, self.head_dim) + + o_proj_scale = getattr(self.o_proj, "scale", None) + if o_proj_scale is None: + gated = attention_output.reshape(B, S, nd) * torch.sigmoid( + gate.reshape(B, S, nd) + ) + return self._kernel_o_proj(gated) + + max_row_quant_head_dim = 128 + if self.head_dim > max_row_quant_head_dim: + fold_factor = math.ceil(self.head_dim / max_row_quant_head_dim) + while self.head_dim % fold_factor != 0: + fold_factor += 1 + folded_head_dim = self.head_dim // fold_factor + folded_heads = heads_per_core * fold_factor + if folded_head_dim > max_row_quant_head_dim: + raise RuntimeError( + "Qwen gated output projection ROW FP8 path cannot fold " + f"head_dim={self.head_dim} to <= {max_row_quant_head_dim}" + ) + if folded_heads > 17: + raise RuntimeError( + "Qwen gated output projection ROW FP8 path exceeds validated " + f"head count after folding: heads={folded_heads}" + ) + kernel_attn_in = attention_output.reshape( + B, S, folded_heads, folded_head_dim + ) + kernel_gate_in = gate.reshape(B, S, folded_heads, folded_head_dim) + else: + kernel_attn_in = attention_output + kernel_gate_in = gate + + out = qwen_gated_output_projection_cte[self.logical_nc_config]( + attention=kernel_attn_in, + gate=kernel_gate_in, + weight=self.o_proj.weight.data, + bias=( + self.o_proj.bias.data.unsqueeze(0) / self.tp_degree + if self.bias + else None + ), + weight_scales=o_proj_scale.data, + ) + + original_dtype = out.dtype + out = out.to(self.rpl_reduce_dtype) + + if self.sequence_parallel_enabled: + out = reduce_scatter_to_sequence_parallel_region( + out, 1, process_group=self.tensor_model_parallel_group + ) + else: + out = reduce_from_tensor_model_parallel_region( + out, process_group=self.tensor_model_parallel_group + ) + + out = out.to(original_dtype) + + return out + + def forward_gated( + self, attention_output: torch.Tensor, gate: torch.Tensor, adapter_ids=None + ): + if ( + self.out_proj_kernel_enabled + and self.quantized + and getattr(self.o_proj, "scale", None) is not None + ): + return self._kernel_gated_o_proj(attention_output, gate) + + gated = attention_output * torch.sigmoid(gate) + return self.forward(gated, adapter_ids=adapter_ids) + def forward(self, attention_output: torch.Tensor, adapter_ids=None): if self.out_proj_kernel_enabled: return self._kernel_o_proj(attention_output) diff --git a/src/neuronx_distributed_inference/modules/attention/nki_kernels/__init__.py b/src/neuronx_distributed_inference/modules/attention/nki_kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_gated_output_projection.py b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_gated_output_projection.py new file mode 100644 index 00000000..0ee70d51 --- /dev/null +++ b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_gated_output_projection.py @@ -0,0 +1,371 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen attention gate fusion for ROW FP8 output projection CTE.""" + +from typing import List, Optional + +import nki +import nki.isa as nisa +import nki.language as nl + +from nkilib.core.output_projection.output_projection_cte.output_projection_cte_parameters import ( + P_MAX, + QuantizationConfig, + TilingConfig, + build_quantization_config, + build_tiling_config, + validate_output_projection_inputs, +) +from nkilib.core.output_projection.output_projection_cte.output_projection_cte_quantization import ( + _compute_row_matmul_dequantize, + _perform_input_row_quantization, + _write_results_to_output, +) +from nkilib.core.output_projection.output_projection_cte.output_projection_cte_tensor_io import ( + load_bias, + load_quantized_weights, + load_row_weight_dequant_scales, +) +from nkilib.core.utils.common_types import QuantizationType +from nkilib.core.utils.kernel_assert import kernel_assert +from nkilib.core.utils.kernel_helpers import get_program_sharding_info +from nkilib.core.utils.tensor_view import TensorView + + +def _process_gated_row_quantized_batch_tile( + attention_view: TensorView, + gate_view: TensorView, + output_view: TensorView, + w_sbuf_list: List[nl.ndarray], + bias_sbuf: Optional[nl.ndarray], + weight_row_scale_sbuf: nl.ndarray, + s_block_idx: int, + h_block_idx: int, + n_orig: int, + d_orig: int, + cfg: TilingConfig, + quant_config: QuantizationConfig, +) -> None: + curr_h_block_size = cfg.h_tile.get_tile_bound(h_block_idx) + s_start = s_block_idx * cfg.s_tile.tile_size + + h_dim = n_orig * d_orig + quant_attention_sb = [] + global_dequant_scales = [] + + attn_flat_view = attention_view.reshape((attention_view.shape[0], h_dim)) + gate_flat_view = gate_view.reshape((gate_view.shape[0], h_dim)) + + xpose_dtype = quant_config.quant_data_type if quant_config.use_double_row else nl.bfloat16 + transposed_heads = [] + for head_idx in range(n_orig): + transposed_heads.append( + nl.ndarray((d_orig, cfg.s_tile.tile_size), dtype=xpose_dtype, buffer=nl.sbuf) + ) + + for s_sub_idx in range(cfg.s_tile.subtile_dim_info.tile_count): + curr_s_sub = cfg.s_tile.get_local_subtile_bound(s_block_idx, s_sub_idx) + if curr_s_sub <= 0: + break + s_sub_start = cfg.s_tile.get_local_subtile_start(s_sub_idx) + + attn_sub = nl.ndarray((P_MAX, h_dim), dtype=nl.bfloat16, buffer=nl.sbuf) + attn_sub_view = attn_flat_view.slice( + dim=0, start=s_sub_start, end=s_sub_start + curr_s_sub + ) + nisa.dma_copy(dst=attn_sub[:curr_s_sub, :h_dim], src=attn_sub_view.get_view()) + + gate_sub = nl.ndarray((P_MAX, h_dim), dtype=nl.bfloat16, buffer=nl.sbuf) + gate_sub_view = gate_flat_view.slice( + dim=0, start=s_sub_start, end=s_sub_start + curr_s_sub + ) + nisa.dma_copy(dst=gate_sub[:curr_s_sub, :h_dim], src=gate_sub_view.get_view()) + nisa.activation( + dst=gate_sub[:curr_s_sub, :h_dim], + data=gate_sub[:curr_s_sub, :h_dim], + op=nl.sigmoid, + ) + nisa.tensor_tensor( + dst=attn_sub[:curr_s_sub, :h_dim], + data1=attn_sub[:curr_s_sub, :h_dim], + data2=gate_sub[:curr_s_sub, :h_dim], + op=nl.multiply, + ) + + quant_sub, dequant_scale = _perform_input_row_quantization( + input_sbuf=attn_sub[:curr_s_sub, :h_dim], + quant_dtype=quant_config.quant_data_type, + ) + global_dequant_scales.append(dequant_scale) + + quant_nd = quant_sub.reshape((curr_s_sub, n_orig, d_orig)) + for head_idx in range(n_orig): + if quant_config.use_double_row: + fp8_head = nl.ndarray( + (curr_s_sub, d_orig), + dtype=quant_config.quant_data_type, + buffer=nl.sbuf, + ) + nisa.tensor_copy( + dst=fp8_head, + src=quant_nd[:curr_s_sub, head_idx, :d_orig], + ) + + fp8_psum_step = 2 + xpose_psum = nl.ndarray( + (d_orig, curr_s_sub, fp8_psum_step), + dtype=quant_config.quant_data_type, + buffer=nl.psum, + ) + nisa.nc_transpose( + dst=xpose_psum.ap( + [[curr_s_sub * fp8_psum_step, d_orig], [fp8_psum_step, curr_s_sub]], + offset=0, + ), + data=fp8_head, + ) + nisa.tensor_copy( + dst=transposed_heads[head_idx][ + :d_orig, s_sub_start : s_sub_start + curr_s_sub + ], + src=xpose_psum[:d_orig, :curr_s_sub, 0], + ) + else: + nisa.dma_transpose( + dst=transposed_heads[head_idx][ + :d_orig, s_sub_start : s_sub_start + curr_s_sub + ], + src=quant_nd[:curr_s_sub, head_idx, :d_orig], + ) + + if quant_config.use_double_row: + if n_orig == cfg.n_size // 2: + for head_idx in range(n_orig): + packed_sb = nl.ndarray( + (cfg.d_size, 2, cfg.s_tile.tile_size), + dtype=xpose_dtype, + buffer=nl.sbuf, + ) + nisa.dma_copy( + dst=packed_sb[: cfg.d_size, 0:1, : cfg.s_tile.tile_size], + src=transposed_heads[head_idx][: cfg.d_size, : cfg.s_tile.tile_size], + ) + nisa.dma_copy( + dst=packed_sb[: cfg.d_size, 1:2, : cfg.s_tile.tile_size], + src=transposed_heads[head_idx][ + cfg.d_size : d_orig, : cfg.s_tile.tile_size + ], + ) + quant_attention_sb.append(packed_sb) + else: + for pair_idx in range(n_orig // 2): + packed_sb = nl.ndarray( + (cfg.d_size, 2, cfg.s_tile.tile_size), + dtype=xpose_dtype, + buffer=nl.sbuf, + ) + nisa.dma_copy( + dst=packed_sb[: cfg.d_size, 0:1, : cfg.s_tile.tile_size], + src=transposed_heads[pair_idx * 2][: cfg.d_size, : cfg.s_tile.tile_size], + ) + nisa.dma_copy( + dst=packed_sb[: cfg.d_size, 1:2, : cfg.s_tile.tile_size], + src=transposed_heads[pair_idx * 2 + 1][ + : cfg.d_size, : cfg.s_tile.tile_size + ], + ) + quant_attention_sb.append(packed_sb) + else: + quant_attention_sb = transposed_heads + + result_sb = _compute_row_matmul_dequantize( + quant_attention_sb=quant_attention_sb, + w_sbuf_list=w_sbuf_list, + bias_sbuf=bias_sbuf, + weight_row_scale_sbuf=weight_row_scale_sbuf, + input_dequant_scale_sb=[global_dequant_scales], + s_block_idx=s_block_idx, + h_block_idx=h_block_idx, + curr_h_block_size=curr_h_block_size, + attention_dtype=nl.bfloat16, + cfg=cfg, + quant_config=quant_config, + ) + + _write_results_to_output( + result_sb=result_sb, + output_view=output_view, + s_start=s_start, + s_block_idx=s_block_idx, + curr_h_block_size=curr_h_block_size, + cfg=cfg, + ) + + +def _perform_gated_row_quantized_projection( + attention_hbm: nl.ndarray, + gate_hbm: nl.ndarray, + weight_hbm: nl.ndarray, + output_hbm: nl.ndarray, + bias_hbm: Optional[nl.ndarray], + weight_scale_hbm: nl.ndarray, + prg_id: int, + cfg: TilingConfig, + quant_config: QuantizationConfig, +) -> None: + weight_hbm = weight_hbm.reshape((cfg.n_size, cfg.d_size, cfg.h_size)) + n_orig = attention_hbm.shape[2] + d_orig = attention_hbm.shape[3] + + for h_block_idx in range(cfg.h_tile.tile_count): + h_start = cfg.h_sharded_size * prg_id + h_block_idx * cfg.h_tile.tile_size + curr_h_block_size = cfg.h_tile.get_tile_bound(h_block_idx) + + weight_view = TensorView(weight_hbm).slice( + dim=2, start=h_start, end=h_start + curr_h_block_size + ) + w_sbuf_list = load_quantized_weights( + weight_view=weight_view, cfg=cfg, quant_config=quant_config + ) + + weight_row_scale_sbuf = load_row_weight_dequant_scales( + weight_scale_hbm, + h_start, + curr_h_block_size, + cfg.h_tile.tile_size, + ) + + bias_sbuf = None + if bias_hbm != None: + bias_view = TensorView(bias_hbm).slice( + dim=1, start=h_start, end=h_start + curr_h_block_size + ) + bias_sbuf = load_bias(bias_view=bias_view, cfg=cfg) + + for batch_idx in range(cfg.b_size): + for s_block_idx in range(cfg.s_tile.tile_count): + curr_s_tile_size = cfg.s_tile.get_tile_bound(s_block_idx) + s_start = s_block_idx * cfg.s_tile.tile_size + + attention_view = ( + TensorView(attention_hbm) + .select(dim=0, index=batch_idx) + .slice(dim=0, start=s_start, end=s_start + curr_s_tile_size) + ) + gate_view = ( + TensorView(gate_hbm) + .select(dim=0, index=batch_idx) + .slice(dim=0, start=s_start, end=s_start + curr_s_tile_size) + ) + output_view = ( + TensorView(output_hbm) + .select(dim=0, index=batch_idx) + .slice(dim=1, start=h_start, end=h_start + curr_h_block_size) + ) + + _process_gated_row_quantized_batch_tile( + attention_view=attention_view, + gate_view=gate_view, + output_view=output_view, + w_sbuf_list=w_sbuf_list, + bias_sbuf=bias_sbuf, + weight_row_scale_sbuf=weight_row_scale_sbuf, + s_block_idx=s_block_idx, + h_block_idx=h_block_idx, + n_orig=n_orig, + d_orig=d_orig, + cfg=cfg, + quant_config=quant_config, + ) + + +@nki.jit +def qwen_gated_output_projection_cte( + attention: nl.ndarray, + gate: nl.ndarray, + weight: nl.ndarray, + bias: Optional[nl.ndarray] = None, + weight_scales: Optional[nl.ndarray] = None, + output_dtype: Optional[type] = None, +) -> nl.ndarray: + """Compute output projection over ``attention * sigmoid(gate)`` for Qwen CTE.""" + kernel_assert( + len(attention.shape) == 4, + f"Qwen gated output projection expects attention [B, S, N, D], got {len(attention.shape)}D", + ) + kernel_assert( + len(gate.shape) == 4, + f"Qwen gated output projection expects gate [B, S, N, D], got {len(gate.shape)}D", + ) + b_size, s_size, n_size, d_size = attention.shape + kernel_assert( + gate.shape[0] == b_size + and gate.shape[1] == s_size + and gate.shape[2] == n_size + and gate.shape[3] == d_size, + "Qwen gated output projection requires gate shape to match attention shape", + ) + _, h_size = weight.shape + + _, n_prgs, prg_id = get_program_sharding_info() + if n_prgs == None: + n_prgs = 1 + prg_id = 0 + + validate_output_projection_inputs( + b_size=b_size, + n_size=n_size, + d_size=d_size, + s_size=s_size, + h_size=h_size, + n_prgs=n_prgs, + attention_dtype=attention.dtype, + weight_dtype=weight.dtype, + quantization_type=QuantizationType.ROW, + input_scales=None, + weight_scales=weight_scales, + ) + quant_config = build_quantization_config( + quantization_type=QuantizationType.ROW, + input_scales=None, + weight_scales=weight_scales, + input_data_type=attention.dtype, + weight_data_type=weight.dtype, + ) + tiling_config = build_tiling_config( + b_size=b_size, + n_size=n_size, + d_size=d_size, + s_size=s_size, + h_size=h_size, + n_prgs=n_prgs, + quant_config=quant_config, + weight_dtype=weight.dtype, + ) + + out_dtype = output_dtype if output_dtype != None else attention.dtype + out = nl.ndarray((b_size, s_size, h_size), dtype=out_dtype, buffer=nl.shared_hbm) + _perform_gated_row_quantized_projection( + attention_hbm=attention, + gate_hbm=gate, + weight_hbm=weight, + output_hbm=out, + bias_hbm=bias, + weight_scale_hbm=weight_scales, + prg_id=prg_id, + cfg=tiling_config, + quant_config=quant_config, + ) + return out diff --git a/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/__init__.py b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py new file mode 100644 index 00000000..f3c0e880 --- /dev/null +++ b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py @@ -0,0 +1,2113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Segmented attention operations with block-based KV cache support. + +This module provides utilities for attention computation with segmented KV cache, +supporting iterative processing of large sequences through multiple segments. +""" + +import math +from typing import Optional + +import nki.isa as nisa +import nki.language as nl + +from nkilib.core.utils.attention_reduce import _MAX_FREE_TILES, reduce_one_batch +from nkilib.core.utils.kernel_assert import kernel_assert +from nkilib.core.utils.kernel_helpers import get_verified_program_sharding_info +from nkilib.core.utils.modular_allocator import ModularAllocator +from nkilib.core.attention.attention_cte import ( + _K_TILE_SZ, + _V_TILE_SZ, + _attention_cte, +) +from .fused_segmented_attention_256 import fused_segmented_attention_impl + +_QWEN256_D_CHUNK = 128 + + +def _alloc_k_cache_sbuf(allocator, head_dim, num_k_tiles): + if head_dim <= _QWEN256_D_CHUNK: + return allocator.alloc_sbuf_tensor( + shape=(head_dim, _K_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[num_k_tiles], + num_free_tiles=[num_k_tiles], + align_to=32, + ) + + kernel_assert(head_dim == 256, f"qwen_segcte256 expects head_dim 256 when splitting K, got {head_dim}") + k_lo = allocator.alloc_sbuf_tensor( + shape=(_QWEN256_D_CHUNK, _K_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[num_k_tiles], + num_free_tiles=[num_k_tiles], + align_to=32, + ) + k_hi = allocator.alloc_sbuf_tensor( + shape=(_QWEN256_D_CHUNK, _K_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[num_k_tiles], + num_free_tiles=[num_k_tiles], + align_to=32, + ) + k_tiles = [] + for i in range(num_k_tiles): + k_tiles.append((k_lo[i], k_hi[i])) + return k_tiles + + +def floor_nisa_kernel(src_t: nl.ndarray, dst_t: nl.ndarray, p_size: int, f_size: int, allocator: ModularAllocator): + """ + NISA implementation for floor operation using integer casting. + + Algorithm: + casted = (int) a + b = (float) casted + larger = (b > a) * (casted - 1) + smaller = (b <= a) * casted + floor(a) = larger + smaller + + Args: + src_t: Source tensor to compute floor of (dtype: fp32) + dst_t: Destination tensor for floor result (dtype: int32) + p_size: First dimension size + f_size: Second dimension size + allocator: SBUF allocator for temporary tensors + """ + orig_addr = allocator.get_current_address() + + dst_cast = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int32, align_to=4) + dst_cast_back = allocator.alloc_sbuf_tensor((p_size, f_size), nl.float32, align_to=4) + dst_cast_minus1 = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int32, align_to=4) + + nisa.tensor_copy(dst=dst_cast, src=src_t) + nisa.tensor_copy(dst=dst_cast_back, src=dst_cast) + nisa.tensor_scalar(dst=dst_cast_minus1[...], data=dst_cast[...], op0=nl.subtract, operand0=1) + + condition = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int8) + condition_not = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int8) + + nisa.tensor_tensor(dst=condition[...], data1=dst_cast_back[...], data2=src_t[...], op=nl.greater) + nisa.tensor_scalar(dst=condition_not[...], data=condition[...], op0=nl.logical_xor, operand0=1) + + smaller = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int32, align_to=4) + larger = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int32, align_to=4) + nisa.tensor_tensor(dst=smaller[...], data1=dst_cast[...], data2=condition_not[...], op=nl.multiply) + nisa.tensor_tensor(dst=larger[...], data1=dst_cast_minus1[...], data2=condition[...], op=nl.multiply) + + nisa.tensor_tensor(dst=dst_t, data1=larger[...], data2=smaller[...], op=nl.add) + + allocator.set_current_address(address=orig_addr) + + +def ceil_nisa_kernel(src_t: nl.ndarray, dst_t: nl.ndarray, p_size: int, f_size: int, allocator: ModularAllocator): + """ + NISA implementation for ceil operation using floor. + + Algorithm: + ceil(x) = -floor(-x) + + Args: + src_t: Source tensor to compute ceil of (dtype: fp32) + dst_t: Destination tensor for ceil result (dtype: int32) + p_size: First dimension size + f_size: Second dimension size + allocator: SBUF allocator for temporary tensors + """ + orig_addr = allocator.get_current_address() + + # Negate input + neg_src = allocator.alloc_sbuf_tensor((p_size, f_size), nl.float32, align_to=4) + nisa.tensor_scalar(dst=neg_src[...], data=src_t[...], op0=nl.multiply, operand0=-1.0) + + # Compute floor(-x) + floor_neg = allocator.alloc_sbuf_tensor((p_size, f_size), nl.int32, align_to=4) + floor_nisa_kernel(src_t=neg_src, dst_t=floor_neg, p_size=p_size, f_size=f_size, allocator=allocator) + + # Negate result: -floor(-x) + nisa.tensor_scalar(dst=dst_t[...], data=floor_neg[...], op0=nl.multiply, operand0=-1) + + allocator.set_current_address(address=orig_addr) + + +def load_kv_cache( + k_cache, + v_cache, + block_tables, + k_sbuf, + v_sbuf, + b_i, + h_i, + block_table_offset, + num_blocks, + allocator: ModularAllocator, + k_pre_transposed: bool = False, +): + """ + Load KV cache from block tables to SBUF for a single KV head. + + Args: + k_cache: K cache in HBM. Shape depends on k_pre_transposed: + - False: (num_blocks_total, num_kv_head, block_size, head_dim) + - True: (num_blocks_total * num_kv_head, head_dim, block_size) + v_cache: V cache in HBM with shape (num_blocks_total, num_kv_head, block_size, head_dim) + block_tables: Block table tensor with shape (batch_size, max_blocks_per_seq) + k_sbuf: K SBUF tiles to load into + v_sbuf: V SBUF tiles to load into + b_i: Current sequence index in batch + h_i: Current KV head index + block_table_offset: SBUF tensor (1, 1) indicating the block offset for the current segment + num_blocks: Number of blocks to load + allocator: SBUF allocator for temporary tensor allocation + k_pre_transposed: If True, K cache is already stored in transposed layout + (head_dim, block_size) per block, so no transpose is needed during loading. + """ + kernel_assert( + not k_pre_transposed, + "qwen_segcte256 supports only k_pre_transposed=False; " + "the transposed-K path has not been production validated", + ) + + num_kv_head = v_cache.shape[1] + block_size = v_cache.shape[2] + head_dim = v_cache.shape[3] + bs, max_blocks_per_seq = block_tables.shape + + # Get K_TILE_SIZE and V_TILE_SIZE from k_sbuf and v_sbuf shapes + K_TILE_SIZE = k_sbuf[0][0].shape[1] if head_dim == 256 else k_sbuf[0].shape[1] + V_TILE_SIZE = v_sbuf[0].shape[0] + + # Store the original sbuf address + orig_sbuf_addr = allocator.get_current_address() + + kernel_assert( + K_TILE_SIZE >= block_size and K_TILE_SIZE % block_size == 0, + f"K_TILE_SIZE must be >= block_size and divisible by block_size", + ) + num_blocks_per_k_tile = K_TILE_SIZE // block_size + num_k_tiles = num_blocks // num_blocks_per_k_tile + + # Chunk the loading into iterations of up to MAX_BLOCKS_PER_LOAD blocks each. + # This avoids SBUF dimension constraints on block index tensors. + MAX_BLOCKS_PER_LOAD = 128 + num_chunks = math.ceil(num_blocks / MAX_BLOCKS_PER_LOAD) + + for chunk_i in range(num_chunks): + chunk_start = chunk_i * MAX_BLOCKS_PER_LOAD + chunk_num_blocks = min(MAX_BLOCKS_PER_LOAD, num_blocks - chunk_start) + + # Save SBUF address so each chunk's temp allocations are freed + chunk_sbuf_addr = allocator.get_current_address() + + # Compute the block_table_offset for this chunk + chunk_block_table_offset = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + if chunk_start == 0: + nisa.tensor_copy(dst=chunk_block_table_offset, src=block_table_offset) + else: + nisa.tensor_scalar( + dst=chunk_block_table_offset, + data=block_table_offset, + op0=nl.add, + operand0=chunk_start, + ) + + # Load block indices from block table as (1, chunk_num_blocks) for V cache / K fallback scalar_offset usage + block_table_idx_before_tp = allocator.alloc_sbuf_tensor(shape=(1, chunk_num_blocks), dtype=nl.uint32) + nisa.dma_copy( + src=block_tables.ap( + pattern=[[1, 1], [1, chunk_num_blocks]], + offset=b_i * max_blocks_per_seq, + scalar_offset=chunk_block_table_offset, + indirect_dim=1, + ), + dst=block_table_idx_before_tp, + ) + + chunk_num_k_tiles = chunk_num_blocks // num_blocks_per_k_tile + chunk_k_tile_start = chunk_start // num_blocks_per_k_tile + + # Load K cache: transposed path (already head_dim x block_size) or original path (needs transpose) + if k_pre_transposed: + # K cache is (num_blocks_total * num_kv_head, head_dim, block_size) — already transposed. + # Per-block dma_copy with scalar_offset. + # For block b and head h, the index into dim-0 is b * num_kv_head + h. + block_table_idx_tp = allocator.alloc_sbuf_tensor(shape=(1, chunk_num_blocks), dtype=nl.uint32) + nisa.tensor_scalar( + dst=block_table_idx_tp, + data=block_table_idx_before_tp, + op0=nl.multiply, + operand0=num_kv_head, + ) + for i in range(chunk_num_k_tiles): + for j in range(num_blocks_per_k_tile): + blk_idx = i * num_blocks_per_k_tile + j + + nisa.dma_copy( + dst=k_sbuf[chunk_k_tile_start + i].ap( + pattern=[[K_TILE_SIZE, head_dim], [1, block_size]], offset=j * block_size + ), + src=k_cache.ap( + pattern=[[block_size, head_dim], [1, block_size]], + offset=h_i * head_dim * block_size, + scalar_offset=block_table_idx_tp.ap( + pattern=[[chunk_num_blocks, 1], [1, 1]], offset=blk_idx + ), + ), + dge_mode=nisa.dge_mode.hwdge, + ) + else: + # Load K cache with dma_transpose when possible (original non-transposed layout). + # HW DGE indirect dma_transpose requires src.shape[-1] % 128 == 0, so + # head_dim must be a multiple of 128 (observed 87.9% rel diff with + # head_dim=64). Fall back to the per-block dma_copy + nc_transpose + # path otherwise. + use_dma_transpose = head_dim == 128 and (chunk_num_blocks % 16 == 0) + + if use_dma_transpose: + # Load block indices directly as (chunk_num_blocks, 1) + block_table_idx = allocator.alloc_sbuf_tensor(shape=(chunk_num_blocks, 1), dtype=nl.uint32) + nisa.dma_copy( + src=block_tables.ap( + pattern=[[1, chunk_num_blocks], [1, 1]], + offset=b_i * max_blocks_per_seq, + scalar_offset=chunk_block_table_offset, + indirect_dim=1, + ), + dst=block_table_idx, + ) + + # Single dma_transpose for this chunk's blocks + k_sbuf_tmp = allocator.alloc_sbuf_tensor( + shape=(head_dim, 1, block_size, chunk_num_blocks), dtype=k_cache.dtype, align_to=32 + ) + nisa.dma_transpose( + src=k_cache.ap( + pattern=[ + [num_kv_head * block_size * head_dim, chunk_num_blocks], + [1, 1], + [head_dim, block_size], + [1, head_dim], + ], + offset=h_i * block_size * head_dim, + vector_offset=block_table_idx, + ), + dst=k_sbuf_tmp, + axes=(3, 1, 2, 0), + oob_mode=nisa.oob_mode.skip, + ) + + # Rearrange from interleaved to contiguous layout in k_sbuf tiles. + # Iterate per-block (flat) so chunk_num_blocks < num_blocks_per_k_tile + # (partial tile) still fills the first chunk_num_blocks slots of the + # first tile rather than silently skipping (the old nested form + # `for i in range(chunk_num_k_tiles)` gave 0 iterations when + # chunk_num_k_tiles = chunk_num_blocks // num_blocks_per_k_tile = 0). + # Any unwritten K-tile tail is handled by the MM1 num_f bound and + # the upstream memset that zeroes k_cache_sbuf. + for blk_idx in range(chunk_num_blocks): + i = blk_idx // num_blocks_per_k_tile + j = blk_idx % num_blocks_per_k_tile + nisa.tensor_copy( + src=k_sbuf_tmp.ap( + pattern=[[block_size * chunk_num_blocks, head_dim], [chunk_num_blocks, block_size]], + offset=blk_idx, + ), + dst=k_sbuf[chunk_k_tile_start + i].ap( + pattern=[[K_TILE_SIZE, head_dim], [1, block_size]], + offset=j * block_size, + ), + ) + else: + if head_dim <= 128: + print( + f"WARNING: chunk_num_blocks={chunk_num_blocks} is not a multiple of 16. " + f"Falling back to per-block dma_copy + nc_transpose for K cache loading." + ) + + # Fallback: Load without transpose per block, then transpose each block. + # For Qwen head_dim=256, split D into two 128-wide SBUF partition + # tiles and keep the pair under the same K-tile index. + if head_dim == 256: + for d_half in range(2): + d_offset = d_half * _QWEN256_D_CHUNK + k_sbuf_no_tp = allocator.alloc_sbuf_tensor( + shape=(_QWEN256_D_CHUNK, _QWEN256_D_CHUNK), + dtype=k_cache.dtype, + align_to=32, + ) + k_psum_transposed = nl.ndarray( + (_QWEN256_D_CHUNK, _QWEN256_D_CHUNK), + dtype=k_cache.dtype, + buffer=nl.psum, + address=(0, 0), + ) + + for blk_idx in range(chunk_num_blocks): + i = blk_idx // num_blocks_per_k_tile + j = blk_idx % num_blocks_per_k_tile + + for token_half in range(block_size // _QWEN256_D_CHUNK): + token_offset = token_half * _QWEN256_D_CHUNK + nisa.dma_copy( + dst=k_sbuf_no_tp, + src=k_cache.ap( + pattern=[[head_dim, _QWEN256_D_CHUNK], [1, _QWEN256_D_CHUNK]], + offset=h_i * block_size * head_dim + token_offset * head_dim + d_offset, + scalar_offset=block_table_idx_before_tp.ap( + pattern=[[chunk_num_blocks, 1], [1, 1]], offset=blk_idx + ), + ), + dge_mode=nisa.dge_mode.hwdge, + ) + nisa.nc_transpose(dst=k_psum_transposed, data=k_sbuf_no_tp) + nisa.tensor_copy( + dst=k_sbuf[chunk_k_tile_start + i][d_half].ap( + pattern=[[K_TILE_SIZE, _QWEN256_D_CHUNK], [1, _QWEN256_D_CHUNK]], + offset=j * block_size + token_offset, + ), + src=k_psum_transposed, + ) + else: + k_sbuf_no_tp = allocator.alloc_sbuf_tensor( + shape=(block_size, head_dim), + dtype=k_cache.dtype, + align_to=32, + ) + k_psum_transposed = nl.ndarray( + (head_dim, block_size), dtype=k_cache.dtype, buffer=nl.psum, address=(0, 0) + ) + + for blk_idx in range(chunk_num_blocks): + i = blk_idx // num_blocks_per_k_tile + j = blk_idx % num_blocks_per_k_tile + + nisa.dma_copy( + dst=k_sbuf_no_tp, + src=k_cache.ap( + pattern=[[head_dim, block_size], [1, head_dim]], + offset=h_i * block_size * head_dim, + scalar_offset=block_table_idx_before_tp.ap( + pattern=[[chunk_num_blocks, 1], [1, 1]], offset=blk_idx + ), + ), + dge_mode=nisa.dge_mode.hwdge, + ) + nisa.nc_transpose(dst=k_psum_transposed, data=k_sbuf_no_tp) + nisa.tensor_copy( + dst=k_sbuf[chunk_k_tile_start + i].ap( + pattern=[[K_TILE_SIZE, head_dim], [1, block_size]], offset=j * block_size + ), + src=k_psum_transposed, + ) + + # Load V cache without transpose + if block_size >= V_TILE_SIZE: + # Original path: each block has one or more V tiles + kernel_assert( + block_size % V_TILE_SIZE == 0, + f"block_size must be divisible by V_TILE_SIZE when block_size >= V_TILE_SIZE", + ) + num_v_tiles_per_block = block_size // V_TILE_SIZE + + for i in range(chunk_num_blocks): + for j in range(num_v_tiles_per_block): + nisa.dma_copy( + dst=v_sbuf[(chunk_start + i) * num_v_tiles_per_block + j].ap( + pattern=[[head_dim, V_TILE_SIZE], [1, head_dim]], offset=0 + ), + src=v_cache.ap( + pattern=[[head_dim, V_TILE_SIZE], [1, head_dim]], + offset=h_i * block_size * head_dim + j * V_TILE_SIZE * head_dim, + scalar_offset=block_table_idx_before_tp.ap( + pattern=[[chunk_num_blocks, 1], [1, 1]], offset=i + ), + ), + dge_mode=nisa.dge_mode.hwdge, + ) + else: + # Small block path: each V tile spans multiple blocks + kernel_assert( + V_TILE_SIZE % block_size == 0, + f"V_TILE_SIZE must be divisible by block_size when block_size < V_TILE_SIZE", + ) + num_blocks_per_v_tile = V_TILE_SIZE // block_size + chunk_num_v_tiles = chunk_num_blocks // num_blocks_per_v_tile + chunk_v_tile_start = chunk_start // num_blocks_per_v_tile + + for v_tile_idx in range(chunk_num_v_tiles): + for blk_in_tile in range(num_blocks_per_v_tile): + block_idx = v_tile_idx * num_blocks_per_v_tile + blk_in_tile + nisa.dma_copy( + dst=v_sbuf[chunk_v_tile_start + v_tile_idx].ap( + pattern=[[head_dim, block_size], [1, head_dim]], + offset=blk_in_tile * block_size * head_dim, + ), + src=v_cache.ap( + pattern=[[head_dim, block_size], [1, head_dim]], + offset=h_i * block_size * head_dim, + scalar_offset=block_table_idx_before_tp.ap( + pattern=[[chunk_num_blocks, 1], [1, 1]], offset=block_idx + ), + ), + dge_mode=nisa.dge_mode.hwdge, + ) + + # Free this chunk's temp allocations + allocator.set_current_address(address=chunk_sbuf_addr) + + # Restore SBUF address to maintain callee-safe behavior + # This allows load_kv_cache to be called multiple times without address conflicts + allocator.set_current_address(address=orig_sbuf_addr) + + +def _attention_segmented_cte_swa_impl( + q: nl.ndarray, + k_cache: nl.ndarray, + v_cache: nl.ndarray, + block_tables: nl.ndarray, + prior_tokens: nl.ndarray, + block_size: int, + prior_seg_size: int, + scale: float, + tp_q: bool, + tp_out: bool, + sliding_window: int, + sink: Optional[nl.ndarray], + num_q_heads: int = 1, + k_pre_transposed: bool = False, + k_scale: Optional[nl.ndarray] = None, + v_scale: Optional[nl.ndarray] = None, +): + """ + Simplified sliding window attention implementation with single-iteration processing. + + With sliding window attention, we only need to attend to at most (sliding_window - 1) + prior tokens, so everything can be done in a single attention_cte call per batch. + + Strategy: + 1. Load active KV from offset (prior_tokens // block_size) + 2. Always load window-sized prior KV (at most ceil((sw-1)/block_size) blocks) + 3. Call attention_cte with prefix caching; prior_used_len dynamically masks + (0 when no prior tokens, clamped otherwise) + + Handles LNC2 sharding similar to the normal multi-segment flow. + + Args: + q: Query tensor + k_cache: K cache in HBM + v_cache: V cache in HBM + block_tables: Block table tensor + prior_tokens: Prior tokens tensor in HBM + block_size: Block size + prior_seg_size: Segment size + scale: Attention scale factor + tp_q: Query transpose flag + tp_out: Output transpose flag + sliding_window: Sliding window size + sink: Optional sink tensor + + Returns: + result: Attention output tensor + """ + # Extract dimensions + if tp_q: + bs_q, seqlen_q, head_dim = q.shape + else: + bs_q, head_dim, seqlen_q = q.shape + + kernel_assert(seqlen_q % 128 == 0, f"Query seqlen {seqlen_q} must be a multiple of 128") + + # Derive num_kv_heads from v_cache shape for GQA mapping (v_cache's + # layout is independent of k_pre_transposed). + num_kv_heads = v_cache.shape[1] + + # Get sharding info for multi-core parallelization + grid_ndim, num_shard, shard_id = get_verified_program_sharding_info("attention_segmented_cte", max_sharding=2) + + # Primary sharding: divide bs_q evenly across shards + num_bs_per_shard = bs_q // num_shard + bs_offset = shard_id * num_bs_per_shard + + # Secondary sharding: handle remainder if bs_q is odd + has_remainder = (bs_q % num_shard) != 0 + last_batch = bs_q - 1 + + # Initialize allocator + allocator = ModularAllocator(initial_address=0) + + # Load KV dequantization scales into SBUF if provided (FP8 KV cache support) + if k_scale is not None: + k_scale_sb = allocator.alloc_sbuf_tensor(shape=(nl.tile_size.pmax, 1), dtype=nl.float32) + nisa.dma_copy(dst=k_scale_sb, src=k_scale) + else: + k_scale_sb = None + if v_scale is not None: + v_scale_sb = allocator.alloc_sbuf_tensor(shape=(nl.tile_size.pmax, 1), dtype=nl.float32) + nisa.dma_copy(dst=v_scale_sb, src=v_scale) + else: + v_scale_sb = None + + # Load prior_tokens to SBUF + prior_tokens_sbuf = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.dma_copy(dst=prior_tokens_sbuf, src=prior_tokens) + + # Calculate active block offset = prior_tokens // block_size + block_size_shift = int(math.log2(block_size)) + active_block_offset_swa = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=active_block_offset_swa, + data=prior_tokens_sbuf, + op0=nl.right_shift, + operand0=block_size_shift, + ) + + # Prior loading: load ceil((sw-1)/block_size) blocks right before active + # num_prior_blocks_to_load is compile-time since sliding_window is compile-time + effective_prior_size = sliding_window - 1 # Max prior tokens + # Must load at least _K_TILE_SZ/block_size blocks so load_kv_cache can fill K tiles + min_blocks_per_k_tile = _K_TILE_SZ // block_size + num_prior_blocks_to_load = max(math.ceil(effective_prior_size / block_size), min_blocks_per_k_tile) + + # Calculate effective_prior_len = min(num_prior_blocks_to_load * block_size, prior_tokens) + # We use the loaded block count (not sw-1) because attention_cte's SWA mask + # will handle masking the extra tokens beyond the window. + # When prior_tokens=0, effective_prior_len=0 and attention_cte masks out all prior. + effective_prior_len_sbuf = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=effective_prior_len_sbuf, + data=prior_tokens_sbuf, + op0=nl.minimum, + operand0=num_prior_blocks_to_load * block_size, + ) + + # Re-buffer sink into private_hbm so _attention_cte's DMA load of + # sink[batch_id, 0] has a buffer tier it can consume. The dynamic-range + # no-op below is kept for scheduling/allocator parity with the prior>0 + # and prior=0 paths; removing it regressed prior>0 cases empirically. + # _attention_cte correctly includes the sink in the section-0 softmax + # even when prior_used_len=0, so no value-masking is needed here. + if sink is not None: + sink_masked = nl.ndarray(shape=sink.shape, dtype=sink.dtype, buffer=nl.private_hbm) + nisa.dma_copy(dst=sink_masked, src=sink) + no_prior_sbuf = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar(dst=no_prior_sbuf, data=effective_prior_len_sbuf, op0=nl.less_equal, operand0=0) + no_prior_reg = nisa.register_alloc() + nisa.register_load(dst=no_prior_reg, src=no_prior_sbuf) + sink_reload_sbuf = allocator.alloc_sbuf_tensor(shape=sink.shape, dtype=sink.dtype, align_to=4) + nisa.dma_copy(dst=sink_reload_sbuf, src=sink) + for _ in nl.dynamic_range(0, no_prior_reg): + nisa.dma_copy(dst=sink_masked, src=sink_reload_sbuf) + sink = sink_masked + + # prior_block_offset = max(0, active_block_offset - num_prior_blocks_to_load) (dynamic) + # Clamp to 0 to avoid negative offset when prior_tokens < sliding_window + prior_block_offset_swa = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32, align_to=4) + nisa.tensor_scalar( + dst=prior_block_offset_swa, + data=active_block_offset_swa, + op0=nl.subtract, + operand0=num_prior_blocks_to_load, + op1=nl.maximum, + operand1=0, + ) + # Allocate K/V sbuf for active segment (sized by seqlen_q, not prior_seg_size) + num_active_blocks_swa = seqlen_q // block_size + num_k_tiles_active_swa = math.ceil(seqlen_q / _K_TILE_SZ) + num_v_tiles_active_swa = num_k_tiles_active_swa * (_K_TILE_SZ // _V_TILE_SZ) + num_grps = math.ceil(seqlen_q / 128) + + k_cache_sbuf = _alloc_k_cache_sbuf(allocator, head_dim, num_k_tiles_active_swa) + v_cache_sbuf = allocator.alloc_sbuf_tensor( + shape=(_V_TILE_SZ, head_dim), + dtype=nl.bfloat16, + block_dim=[num_v_tiles_active_swa], + num_free_tiles=[num_v_tiles_active_swa], + ) + + # Allocate HBM buffers for unnormalized output and softmax stats (single batch, reused) + # Intermediate uses tp_out=False so per-Q-position normalization (tensor_scalar with + # 128-element partition vector) works correctly. When tp_out=True, nc_transpose at + # final write. This constraint comes from tensor_scalar requiring the correction + # vector to match the tile's partition dimension (128 Q positions, not head_dim). + softmax_shape_swa = (1, 128, num_grps) + out_o_hbm_swa = nl.ndarray(shape=(1, seqlen_q, head_dim), dtype=q.dtype, buffer=nl.private_hbm) + out_neg_max_hbm_swa = nl.ndarray(shape=softmax_shape_swa, dtype=nl.float32, buffer=nl.private_hbm) + out_sum_hbm_swa = nl.ndarray(shape=softmax_shape_swa, dtype=nl.float32, buffer=nl.private_hbm) + + # Allocate result + if tp_out: + result = nl.ndarray(shape=(bs_q, head_dim, seqlen_q), dtype=q.dtype, buffer=nl.shared_hbm) + else: + result = nl.ndarray(shape=(bs_q, seqlen_q, head_dim), dtype=q.dtype, buffer=nl.shared_hbm) + + # Workaround for NCC_IBIR251: Allocate Q buffer for single batch + # Makes Q "internal" so access patterns work in dynamic loops + if tp_q: + q_internal = nl.ndarray(shape=(1, seqlen_q, head_dim), dtype=q.dtype, buffer=nl.private_hbm) + else: + q_internal = nl.ndarray(shape=(1, head_dim, seqlen_q), dtype=q.dtype, buffer=nl.private_hbm) + + # Prior KV tile counts (compile-time, used by both primary and remainder paths) + # attention_cte derives seqlen_k_prior = len(k_prior_sbuf) * _K_TILE_SZ + # V tiles must be consistent: num_v_tiles = num_k_tiles * (_K_TILE_SZ // _V_TILE_SZ) + num_prior_k_tiles = math.ceil(num_prior_blocks_to_load * block_size / _K_TILE_SZ) + num_prior_v_tiles = num_prior_k_tiles * (_K_TILE_SZ // _V_TILE_SZ) + + # Allocate prior KV SBUF once outside the loop (reused across iterations) + k_prior_sbuf_swa = _alloc_k_cache_sbuf(allocator, head_dim, num_prior_k_tiles) + v_prior_sbuf_swa = allocator.alloc_sbuf_tensor( + shape=(_V_TILE_SZ, head_dim), + dtype=nl.bfloat16, + block_dim=[num_prior_v_tiles], + num_free_tiles=[num_prior_v_tiles], + ) + + # Process primary batches (one at a time) + for b in range(num_bs_per_shard): + batch_id = b + bs_offset # Global bs_q index + + # Derive batch index and KV head index from batch_id for GQA + b_i = batch_id // num_q_heads + h_i = (batch_id % num_q_heads) * num_kv_heads // num_q_heads + + # Copy this batch's query data (layout matches tp_q) + if tp_q: + nisa.dma_copy( + dst=q_internal[0, :, :], + src=q.ap( + pattern=[[head_dim, seqlen_q], [1, head_dim]], + offset=batch_id * seqlen_q * head_dim, + ), + ) + else: + nisa.dma_copy( + dst=q_internal[0, :, :], + src=q.ap( + pattern=[[seqlen_q, head_dim], [1, seqlen_q]], + offset=batch_id * head_dim * seqlen_q, + ), + ) + + # Load active KV cache + load_kv_cache( + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + b_i, + h_i, + active_block_offset_swa, + num_active_blocks_swa, + allocator, + k_pre_transposed=k_pre_transposed, + ) + + # Load at most window-sized prior KV; prior_used_len dynamically masks + # (0 when no prior tokens, clamped to actual prior otherwise). + load_kv_cache( + k_cache, + v_cache, + block_tables, + k_prior_sbuf_swa, + v_prior_sbuf_swa, + b_i, + h_i, + prior_block_offset_swa, + num_prior_blocks_to_load, + allocator, + k_pre_transposed=k_pre_transposed, + ) + + init_sbuf_addr = allocator.get_current_address() + + _attention_cte( + q_internal, + None, + None, + scale=scale, + causal_mask=True, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + sliding_window=sliding_window, + sink=sink, + k_cache_sbuf=k_cache_sbuf, + v_cache_sbuf=v_cache_sbuf, + k_prior_sbuf=k_prior_sbuf_swa, + v_prior_sbuf=v_prior_sbuf_swa, + prior_used_len=effective_prior_len_sbuf, + out_o_hbm=out_o_hbm_swa, + out_neg_max_hbm=out_neg_max_hbm_swa, + out_sum_hbm=out_sum_hbm_swa, + init_sbuf_addr=init_sbuf_addr, + k_scale_sb=k_scale_sb, + ) + allocator.set_current_address(init_sbuf_addr) + + # Normalize (divide by S) and write to result + sb_p = nl.tile_size.pmax + sm_pat = [[num_grps, sb_p], [1, num_grps]] + o_tile_pat = [[head_dim, sb_p], [1, head_dim]] + + norm_addr = allocator.get_current_address() + sum_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + nisa.dma_copy(dst=sum_sb, src=out_sum_hbm_swa.ap(pattern=sm_pat, offset=0)) + sum_recip_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + nisa.reciprocal(sum_recip_sb, sum_sb) + + num_free = min(num_grps, _MAX_FREE_TILES) + o_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), + dtype=nl.bfloat16, + block_dim=[num_grps], + num_free_tiles=[num_free], + ) + if tp_out: + o_tp_psum = nl.ndarray((head_dim, sb_p), dtype=nl.bfloat16, buffer=nl.psum, address=(0, 0)) + o_tp_sb = allocator.alloc_sbuf_tensor(shape=(head_dim, sb_p), dtype=nl.bfloat16) + for grp_i in range(num_grps): + grp_o_offset = grp_i * sb_p * head_dim + nisa.dma_copy(dst=o_sb[grp_i], src=out_o_hbm_swa.ap(pattern=o_tile_pat, offset=grp_o_offset)) + # Delayed V dequant: fold v_scale multiply into normalization tensor_scalar + if v_scale_sb is not None: + nisa.tensor_scalar( + dst=o_sb[grp_i], + data=o_sb[grp_i], + op0=nl.multiply, + operand0=sum_recip_sb[:, grp_i], + op1=nl.multiply, + operand1=v_scale_sb, + ) + else: + nisa.tensor_scalar( + dst=o_sb[grp_i], + data=o_sb[grp_i], + op0=nl.multiply, + operand0=sum_recip_sb[:, grp_i], + ) + if tp_out: + # nc_transpose (128, head_dim) → (head_dim, 128), then write with transposed AP + nisa.nc_transpose(dst=o_tp_psum, data=o_sb[grp_i]) + nisa.tensor_copy(dst=o_tp_sb, src=o_tp_psum) + nisa.dma_copy( + dst=result.ap( + pattern=[[seqlen_q, head_dim], [1, sb_p]], + offset=batch_id * head_dim * seqlen_q + grp_i * sb_p, + ), + src=o_tp_sb, + ) + else: + dst_o_offset = batch_id * num_grps * sb_p * head_dim + grp_o_offset + nisa.dma_copy(dst=result.ap(pattern=o_tile_pat, offset=dst_o_offset), src=o_sb[grp_i]) + allocator.set_current_address(norm_addr) + + # Handle remainder batch (if bs_q is odd) with asymmetric sequence sharding. + # + # Core 0 handles Q[0 : ceil(num_grps/2)*128], Core 1 handles Q[ceil(num_grps/2)*128 : + # num_grps*128]. For num_grps == 1 the split is degenerate — Core 0 does the full + # remainder and Core 1 short-circuits. + core0_grp_length = (num_grps + 1) // 2 + core1_grp_length = num_grps // 2 + if shard_id == 0: + grp_length = core0_grp_length + grp_start = 0 + else: + grp_length = core1_grp_length + grp_start = core0_grp_length + + run_remainder = has_remainder and grp_length > 0 + + if run_remainder: + # Per-shard active-segment sizing (Core 0's sizing also used by Core 1's + # active-offset math to step past Core 0's active region). + effective_q_tokens = grp_length * 128 + num_blocks_per_effective_seg = effective_q_tokens // block_size + num_k_tiles_per_effective_seg = math.ceil(effective_q_tokens / _K_TILE_SZ) + num_v_tiles_per_effective_seg = math.ceil(effective_q_tokens / _V_TILE_SZ) + num_grps_effective = math.ceil(effective_q_tokens / 128) + + core0_q_tokens = core0_grp_length * 128 + core0_num_blocks_per_effective_seg = core0_q_tokens // block_size + + # Calculate sequence start position for this core's Q slice + seq_start = grp_start * 128 + + # Copy only this core's portion of Q to q_internal_remainder + if tp_q: + q_internal_remainder = nl.ndarray( + shape=(1, effective_q_tokens, head_dim), dtype=q.dtype, buffer=nl.private_hbm + ) + nisa.dma_copy( + dst=q_internal_remainder[0, :, :], + src=q.ap( + pattern=[[head_dim, effective_q_tokens], [1, head_dim]], + offset=last_batch * seqlen_q * head_dim + seq_start * head_dim, + ), + ) + else: + q_internal_remainder = nl.ndarray( + shape=(1, head_dim, effective_q_tokens), dtype=q.dtype, buffer=nl.private_hbm + ) + nisa.dma_copy( + dst=q_internal_remainder[0, :, :], + src=q.ap( + pattern=[[seqlen_q, head_dim], [1, effective_q_tokens]], + offset=last_batch * head_dim * seqlen_q + seq_start, + ), + ) + + # Allocate K/V sbuf for effective segment + k_cache_sbuf_rem = _alloc_k_cache_sbuf(allocator, head_dim, num_k_tiles_per_effective_seg) + v_cache_sbuf_rem = allocator.alloc_sbuf_tensor( + shape=(_V_TILE_SZ, head_dim), + dtype=nl.bfloat16, + block_dim=[num_v_tiles_per_effective_seg], + num_free_tiles=[num_v_tiles_per_effective_seg], + ) + # Allocate HBM buffers for remainder batch (per-shard Q length) + rem_softmax_shape_swa = (1, 128, num_grps_effective) + out_o_hbm_rem = nl.ndarray(shape=(1, effective_q_tokens, head_dim), dtype=q.dtype, buffer=nl.private_hbm) + out_neg_max_hbm_rem = nl.ndarray(shape=rem_softmax_shape_swa, dtype=nl.float32, buffer=nl.private_hbm) + out_sum_hbm_rem = nl.ndarray(shape=rem_softmax_shape_swa, dtype=nl.float32, buffer=nl.private_hbm) + + # Adjust active_block_offset and prior parameters for each core. + # Core 0: same as primary batch (its Q is the leading portion of the remainder). + # Core 1: active starts after Core 0's active region; prior must cover + # Core 0's active as part of its sliding-window prior context. + active_block_offset_rem = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + prior_block_offset_rem = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + effective_prior_len_rem = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + + if shard_id == 0: + # Core 0: Loads leading active KV, prior same as primary batch + nisa.tensor_copy(dst=active_block_offset_rem, src=active_block_offset_swa) + nisa.tensor_copy(dst=prior_block_offset_rem, src=prior_block_offset_swa) + nisa.tensor_copy(dst=effective_prior_len_rem, src=effective_prior_len_sbuf) + else: + # Core 1: Loads trailing active KV, starting at + # active_block_offset_swa + core0_num_blocks_per_effective_seg + nisa.tensor_scalar( + dst=active_block_offset_rem, + data=active_block_offset_swa, + op0=nl.add, + operand0=core0_num_blocks_per_effective_seg, + ) + # Core 1's prior: blocks right before its active segment + # prior_block_offset = max(0, active_block_offset_rem - num_prior_blocks_to_load) + nisa.tensor_scalar( + dst=prior_block_offset_rem, + data=active_block_offset_rem, + op0=nl.subtract, + operand0=num_prior_blocks_to_load, + op1=nl.maximum, + operand1=0, + ) + # Core 1's effective_prior_len = min(num_prior_blocks_to_load * block_size, + # prior_tokens + core0_q_tokens) + # prior_tokens + core0_q_tokens = total tokens before Core 1's active + nisa.tensor_scalar( + dst=effective_prior_len_rem, + data=prior_tokens_sbuf, + op0=nl.add, + operand0=core0_q_tokens, + op1=nl.minimum, + operand1=num_prior_blocks_to_load * block_size, + ) + # Derive batch/head indices for the remainder batch + rem_b_i = last_batch // num_q_heads + rem_h_i = (last_batch % num_q_heads) * num_kv_heads // num_q_heads + + # Load active KV cache for this core's portion + load_kv_cache( + k_cache, + v_cache, + block_tables, + k_cache_sbuf_rem, + v_cache_sbuf_rem, + rem_b_i, + rem_h_i, + active_block_offset_rem, + num_blocks_per_effective_seg, + allocator, + k_pre_transposed=k_pre_transposed, + ) + + # Load at most window-sized prior KV for remainder; prior_used_len dynamically masks + # (0 when no prior tokens, clamped to actual prior otherwise). + k_prior_sbuf_rem = _alloc_k_cache_sbuf(allocator, head_dim, num_prior_k_tiles) + v_prior_sbuf_rem = allocator.alloc_sbuf_tensor( + shape=(_V_TILE_SZ, head_dim), + dtype=nl.bfloat16, + block_dim=[num_prior_v_tiles], + num_free_tiles=[num_prior_v_tiles], + ) + load_kv_cache( + k_cache, + v_cache, + block_tables, + k_prior_sbuf_rem, + v_prior_sbuf_rem, + rem_b_i, + rem_h_i, + prior_block_offset_rem, + num_prior_blocks_to_load, + allocator, + k_pre_transposed=k_pre_transposed, + ) + + init_sbuf_addr = allocator.get_current_address() + + _attention_cte( + q_internal_remainder, + None, + None, + scale=scale, + causal_mask=True, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + sliding_window=sliding_window, + sink=sink, + k_cache_sbuf=k_cache_sbuf_rem, + v_cache_sbuf=v_cache_sbuf_rem, + k_prior_sbuf=k_prior_sbuf_rem, + v_prior_sbuf=v_prior_sbuf_rem, + prior_used_len=effective_prior_len_rem, + out_o_hbm=out_o_hbm_rem, + out_neg_max_hbm=out_neg_max_hbm_rem, + out_sum_hbm=out_sum_hbm_rem, + init_sbuf_addr=init_sbuf_addr, + k_scale_sb=k_scale_sb, + ) + allocator.set_current_address(init_sbuf_addr) + + # Normalize and write results to HBM for remainder batch + rem_sb_p = nl.tile_size.pmax + rem_sm_pat = [[num_grps_effective, rem_sb_p], [1, num_grps_effective]] + rem_o_tile_pat = [[head_dim, rem_sb_p], [1, head_dim]] + + rem_norm_addr = allocator.get_current_address() + rem_sum_sb = allocator.alloc_sbuf_tensor(shape=(rem_sb_p, num_grps_effective), dtype=nl.float32) + nisa.dma_copy(dst=rem_sum_sb, src=out_sum_hbm_rem.ap(pattern=rem_sm_pat, offset=0)) + rem_sum_recip_sb = allocator.alloc_sbuf_tensor(shape=(rem_sb_p, num_grps_effective), dtype=nl.float32) + nisa.reciprocal(rem_sum_recip_sb, rem_sum_sb) + + rem_num_free = min(num_grps_effective, _MAX_FREE_TILES) + rem_o_sb = allocator.alloc_sbuf_tensor( + shape=(rem_sb_p, head_dim), + dtype=nl.bfloat16, + block_dim=[num_grps_effective], + num_free_tiles=[rem_num_free], + ) + if tp_out: + rem_o_tp_psum = nl.ndarray((head_dim, rem_sb_p), dtype=nl.bfloat16, buffer=nl.psum, address=(0, 0)) + rem_o_tp_sb = allocator.alloc_sbuf_tensor(shape=(head_dim, rem_sb_p), dtype=nl.bfloat16) + for local_grp_i in range(num_grps_effective): + global_grp_i = grp_start + local_grp_i + grp_start_pos = global_grp_i * 128 + + src_o_offset = local_grp_i * rem_sb_p * head_dim + nisa.dma_copy(dst=rem_o_sb[local_grp_i], src=out_o_hbm_rem.ap(pattern=rem_o_tile_pat, offset=src_o_offset)) + # Delayed V dequant: fold v_scale multiply into normalization tensor_scalar + if v_scale_sb is not None: + nisa.tensor_scalar( + dst=rem_o_sb[local_grp_i], + data=rem_o_sb[local_grp_i], + op0=nl.multiply, + operand0=rem_sum_recip_sb[:, local_grp_i], + op1=nl.multiply, + operand1=v_scale_sb, + ) + else: + nisa.tensor_scalar( + dst=rem_o_sb[local_grp_i], + data=rem_o_sb[local_grp_i], + op0=nl.multiply, + operand0=rem_sum_recip_sb[:, local_grp_i], + ) + if tp_out: + nisa.nc_transpose(dst=rem_o_tp_psum, data=rem_o_sb[local_grp_i]) + nisa.tensor_copy(dst=rem_o_tp_sb, src=rem_o_tp_psum) + nisa.dma_copy( + dst=result.ap( + pattern=[[seqlen_q, head_dim], [1, rem_sb_p]], + offset=last_batch * head_dim * seqlen_q + grp_start_pos, + ), + src=rem_o_tp_sb, + ) + else: + dst_o_offset = last_batch * num_grps * rem_sb_p * head_dim + global_grp_i * rem_sb_p * head_dim + nisa.dma_copy(dst=result.ap(pattern=rem_o_tile_pat, offset=dst_o_offset), src=rem_o_sb[local_grp_i]) + allocator.set_current_address(rem_norm_addr) + + return result + + +def attention_segmented_cte( + q: nl.ndarray, + k_cache: nl.ndarray, + v_cache: nl.ndarray, + block_tables: nl.ndarray, + prior_tokens: nl.ndarray, + block_size: int, + prior_seg_size: int, + scale: float = 1.0, + tp_q: bool = True, + tp_out: bool = False, + sliding_window: Optional[int] = None, + sink: Optional[nl.ndarray] = None, + num_q_heads: int = 1, + kvp_offset: Optional[nl.ndarray] = None, + k_pre_transposed: bool = False, + k_scale: Optional[nl.ndarray] = None, + v_scale: Optional[nl.ndarray] = None, +): + """ + Segmented attention computation with block-based KV cache and prefix caching. + + SEGMENTED ATTENTION OVERVIEW: + ================================ + + Case 1: Partial Prior (prior_tokens=640, prior_seg_size=512, block_size=128) + ----------------------------------------------------------------------- + KV Cache Block Layout: + ┌────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┐ + │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ Block indices + └────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┘ + └────Full Prior 0 (4 blks)──────────┴Partial ┴─────────Active (4 blks)───────────┘ Segments + offset=0 Prior offset=5 + offset=4 + + Iteration: Active+Partial(causal) → Prior0(no causal) + + Case 2: Full Prior Only (prior_tokens=1024, prior_seg_size=512, block_size=128) + -------------------------------------------------------------------------- + KV Cache Block Layout: + ┌────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┬────────┐ + │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │ 10 │ 11 │ Block indices + └────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┘ + └────Full Prior 0 (4 blks)──────────┴────Full Prior 1 (4 blks)──────────┴──────Active (4 blks)──────────────┘ Segments + offset=0 offset=4 offset=8 + + Iteration: Active(causal) → Prior1(no causal) → Prior0(no causal) + + Pseudo-code algorithm: + num_full_prior_segments = floor(prior_tokens / prior_seg_size) + partial_prior_tokens = prior_tokens - num_full_prior_segments * prior_seg_size + + # Load active segment KV + k_active, v_active = load_kv_cache(block_tables, offset=prior_tokens // block_size) + + # First iteration: Process active segment with causal mask + if partial_prior_tokens > 0: + # Partial segment prior exists + k_prior, v_prior = load_kv_cache(block_tables, offset=num_full_prior_segments * num_blocks_per_seg) + output = _attention_cte(q, k_prior, v_prior, k_active, v_active, + causal_mask=True, prior_used_len=partial_prior_tokens) + else: + # Full segment prior or no prior + output = _attention_cte(q, k_active, v_active, causal_mask=True) + + # Remaining iterations: Process full prior segments without causal mask + for i in range(num_full_prior_segments): + k_seg, v_seg = load_kv_cache(block_tables, offset=decremented_offset) + seg_output = _attention_cte(q, k_seg, v_seg, causal_mask=False) + output = reduce_one_batch(output, seg_output) # Online softmax rescaling + + return output + + LNC2 SHARDING STRATEGY: + ====================== + + Primary Sharding (Even bs_q): + ------------------------------ + Divides bs_q evenly across 2 cores. Example with bs_q=4: + + Core 0: Q[0], Q[1] (num_bs_per_shard = 2) + Core 1: Q[2], Q[3] (num_bs_per_shard = 2) + + Secondary Sharding (Odd bs_q with Remainder): + ---------------------------------------------- + Primary batches divided evenly, remainder batch uses 50/50 sequence sharding. + Example with bs_q=3, prior_seg_size=2048 (16 groups of 128 tokens): + + Core 0: Q[0] (primary) + Core 1: Q[1] (primary) + + Q[2] (remainder) - 50/50 SEQUENCE SPLIT: + ┌─────────────────────────────────────────────────────────────┐ + │ Q[2]: 2048 tokens (16 groups) │ + ├──────────────────────────────┬──────────────────────────────┤ + │ Core 0: Groups [0-7] │ Core 1: Groups [8-15] │ + │ Tokens [0-1023] │ Tokens [1024-2047] │ + │ effective_prior_seg_size = 1024 │ effective_prior_seg_size = 1024 │ + └──────────────────────────────┴──────────────────────────────┘ + + Prior Segment Handling (with prior_tokens > 0): + ------------------------------------------------ + When original segment has N prior segments, effective segments double. + Example: prior_seg_size=2048, prior=2048 (1 segment), remainder with 50/50 split: + + Original (prior_seg_size=2048): + ┌────────────┬─────────────┐ + │ Prior Seg 0│ Active Seg 1│ + │ (2048 tok) │ (2048 tok) │ + └────────────┴─────────────┘ + + After 50/50 split (effective_prior_seg_size=1024): + ┌──────┬──────┬──────┬──────┐ + │ P0.0 │ P0.1 │ A1.0 │ A1.1 │ (Each = 1024 tokens) + └──────┴──────┴──────┴──────┘ + + Core 0's view (segments from Core 0's perspective): + - Seg 0: Prior tokens [0-1023] (P0.0) + - Seg 1: Prior tokens [1024-2047] (P0.1) + - Seg 2: Active tokens [0-1023] (A1.0) ← Core 0's active + Total: 2 prior segments (2N where N=1) + + Core 1's view (segments from Core 1's perspective): + - Seg 0: Prior tokens [0-1023] (P0.0) + - Seg 1: Prior tokens [1024-2047] (P0.1) + - Seg 2: Active tokens [0-1023] (A1.0) ← Core 0's active (prior for Core 1!) + - Seg 3: Active tokens [1024-2047] (A1.1) ← Core 1's active + Total: 3 prior segments (2N+1) + + Implementation: + - Both cores do 2N iterations in main loop + - Core 1 does +1 extra iteration at block_offset=0 to process Seg 0 + + + Args: + q: Query tensor with shape (batch_size, seqlen_q, d) when tp_q=True + k_cache: K cache in HBM with shape (num_blocks, block_size, num_kv_head, head_dim) + v_cache: V cache in HBM with shape (num_blocks, block_size, num_kv_head, head_dim) + block_tables: Block table tensor with shape (batch_size, max_blocks_per_seq). + max_blocks_per_seq only needs to cover + ceil((prior_tokens + seqlen_q) / block_size); the kernel pads + internally when seqlen_q < prior_seg_size so the traced + partial-prior speculative read stays in-bounds. + prior_tokens: Total number of prior (cached) tokens, shape (1, 1). Must be multiple of block_size. + block_size: Size of each block in the KV cache + prior_seg_size: Size of each KV segment to process iteratively + scale: Scaling factor for attention scores (default 1.0) + tp_q: Query tensor transpose flag (default True) + tp_out: Output tensor transpose flag (default False) + + Returns: + If kvp_offset is None: output tensor with attention results. Shape depends on tp_out parameter. + If kvp_offset is set: tuple of (output, out_neg_max_hbm, out_sum_recip_hbm) for softmax stat + reduction by the caller. + + Example calculations: + prior_tokens=0: prior_last_segment_tokens=0, iterations=1 + prior_tokens=640: prior_last_segment_tokens=128, iterations=2 (prior_seg_size=512) + prior_tokens=1024: prior_last_segment_tokens=512, iterations=2 (prior_seg_size=512) + """ + kernel_assert( + kvp_offset == None, + "qwen_segcte256 KVP mode is not production validated; use the " + "non-KVP segmented CTE path", + ) + kernel_assert( + not k_pre_transposed, + "qwen_segcte256 supports only k_pre_transposed=False; " + "the transposed-K path has not been production validated", + ) + + # Extract dimensions + if tp_q: + bs_q, seqlen_q, d = q.shape + else: + bs_q, d, seqlen_q = q.shape + + # Derive dims from v_cache so we don't depend on k_cache's shape, which + # varies by layout (see k_pre_transposed argument on load_kv_cache). + num_kv_head = v_cache.shape[1] + head_dim = v_cache.shape[3] + bs, max_blocks_per_seq = block_tables.shape + + # Get sharding info for multi-core parallelization + grid_ndim, num_shard, shard_id = get_verified_program_sharding_info("attention_segmented_cte", max_sharding=2) + + # KVP is intentionally rejected above. Keep all public outputs in shared_hbm + # per NKI 0.3 output-buffer requirements. + if kvp_offset is not None: + num_shard = 1 + shard_id = 0 + result_buffer = nl.shared_hbm + + # Primary sharding: divide bs_q (batch_size * num_q_heads) evenly across shards + num_bs_per_shard = bs_q // num_shard + bs_offset = shard_id * num_bs_per_shard + + # Secondary sharding: handle remainder on sequence dimension if bs_q is odd + has_remainder = (bs_q % num_shard) != 0 + last_batch = bs_q - 1 + + # Validate inputs + kernel_assert(seqlen_q % 128 == 0, f"Query seqlen {seqlen_q} must be a multiple of 128") + kernel_assert(seqlen_q % block_size == 0, f"Query seqlen {seqlen_q} must be divisible by block_size {block_size}") + kernel_assert(d == head_dim, f"Query head_dim {d} must match cache head_dim {head_dim}") + kernel_assert( + prior_seg_size % block_size == 0, + f"prior_seg_size {prior_seg_size} must be divisible by block_size {block_size}", + ) + kernel_assert(head_dim <= 256, f"head_dim must be <= 256 (got {head_dim}). Larger head_dim not yet supported by qwen_segcte256.") + + num_blocks_per_seg = prior_seg_size // block_size + + # Initialize allocator + allocator = ModularAllocator(initial_address=0) + + # Pad block_tables internally so every compile-time-traced scalar-DGE read + # stays in bounds. Two independent dynamic paths need headroom: + # 1. the partial-prior helper's one-past segment read at + # (num_full_prior_segments + 1) * num_blocks_per_seg + # 2. Qwen head_dim=256 active streaming, which still walks the full CTE + # bucket even when the final real active chunk is shorter. At pfx256, + # a 768-token final chunk in a 3072 CTE bucket can otherwise read + # active block-table offsets past the 1024 real prefix blocks. + # Unconditional: the NKI compiler's pessimistic bound check on dynamic + # scalar_offset into block_tables makes any conditional predicate + # incomplete. Done before the SWA dispatch so both paths benefit. + num_active_blocks_for_padding = seqlen_q // block_size + padded_width_for_prior = ( + (max_blocks_per_seq // num_blocks_per_seg + 1) * num_blocks_per_seg + ) + padded_width_for_active_stream = ( + max_blocks_per_seq + num_active_blocks_for_padding + ) + padded_width = max(padded_width_for_prior, padded_width_for_active_stream) + if padded_width % num_blocks_per_seg != 0: + padded_width = ( + (padded_width + num_blocks_per_seg - 1) + // num_blocks_per_seg + ) * num_blocks_per_seg + + if padded_width > max_blocks_per_seq: + block_tables_internal = nl.ndarray(shape=(bs, padded_width), dtype=block_tables.dtype, buffer=nl.private_hbm) + pad_addr = allocator.get_current_address() + pad_scratch = allocator.alloc_sbuf_tensor(shape=(bs, padded_width), dtype=block_tables.dtype) + nisa.memset(pad_scratch[...], value=0) + nisa.dma_copy(dst=block_tables_internal, src=pad_scratch) + nisa.dma_copy( + dst=block_tables_internal.ap(pattern=[[padded_width, bs], [1, max_blocks_per_seq]], offset=0), + src=block_tables, + ) + allocator.set_current_address(pad_addr) + block_tables = block_tables_internal + max_blocks_per_seq = padded_width + + # Sliding window attention: use simplified single-iteration path + if sliding_window is not None and sliding_window > 0: + kernel_assert( + sliding_window % block_size == 0, + f"sliding_window {sliding_window} must be divisible by block_size {block_size}", + ) + return _attention_segmented_cte_swa_impl( + q=q, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + prior_tokens=prior_tokens, + block_size=block_size, + prior_seg_size=prior_seg_size, + scale=scale, + tp_q=tp_q, + tp_out=tp_out, + sliding_window=sliding_window, + sink=sink, + num_q_heads=num_q_heads, + k_pre_transposed=k_pre_transposed, + k_scale=k_scale, + v_scale=v_scale, + ) + + num_k_tiles_per_seg = math.ceil(prior_seg_size / _K_TILE_SZ) + num_v_tiles_per_seg = num_k_tiles_per_seg * (_K_TILE_SZ // _V_TILE_SZ) + num_grps = math.ceil(seqlen_q / 128) + + # Active segment tile/block counts (may differ from prior when seqlen_q != prior_seg_size) + num_active_blocks = seqlen_q // block_size + num_k_tiles_active = math.ceil(seqlen_q / _K_TILE_SZ) + num_v_tiles_active = num_k_tiles_active * (_K_TILE_SZ // _V_TILE_SZ) + + # Qwen head_dim=256 streams the active CTE through the same small K/V + # window used for prior segments. Keeping the old max(active, prior) + # allocation made the 3072-token CTE bucket carry all active K/V in SBUF. + if head_dim == 256: + active_stream_tokens = min(prior_seg_size, seqlen_q) + kernel_assert( + active_stream_tokens % block_size == 0, + "qwen_segcte256 active stream chunk must be divisible by block_size", + ) + num_k_tiles_active_stream = math.ceil(active_stream_tokens / _K_TILE_SZ) + num_v_tiles_active_stream = num_k_tiles_active_stream * (_K_TILE_SZ // _V_TILE_SZ) + num_k_tiles_sbuf = max(num_k_tiles_per_seg, num_k_tiles_active_stream) + num_v_tiles_sbuf = max(num_v_tiles_per_seg, num_v_tiles_active_stream) + else: + # K/V sbuf must be large enough for both active and prior segments. + num_k_tiles_sbuf = max(num_k_tiles_per_seg, num_k_tiles_active) + num_v_tiles_sbuf = max(num_v_tiles_per_seg, num_v_tiles_active) + + # Load KV dequantization scales into SBUF if provided (FP8 KV cache support) + if k_scale is not None: + k_scale_sb = allocator.alloc_sbuf_tensor(shape=(nl.tile_size.pmax, 1), dtype=nl.float32) + nisa.dma_copy(dst=k_scale_sb, src=k_scale) + else: + k_scale_sb = None + if v_scale is not None: + v_scale_sb = allocator.alloc_sbuf_tensor(shape=(nl.tile_size.pmax, 1), dtype=nl.float32) + nisa.dma_copy(dst=v_scale_sb, src=v_scale) + else: + v_scale_sb = None + + # Compute segment offsets ONCE (both cores execute this for consistent control flow) + prior_tokens_sbuf = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.dma_copy(dst=prior_tokens_sbuf, src=prior_tokens) + + # num_full_prior_segments = floor(prior_tokens / prior_seg_size) + num_full_prior_segments_f32 = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.float32) + num_full_prior_segments_i32 = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=num_full_prior_segments_f32, data=prior_tokens_sbuf, op0=nl.multiply, operand0=1 / prior_seg_size + ) + floor_nisa_kernel( + src_t=num_full_prior_segments_f32, dst_t=num_full_prior_segments_i32, p_size=1, f_size=1, allocator=allocator + ) + + # Compute block offsets + block_size_shift = int(math.log2(block_size)) + active_block_offset = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar(dst=active_block_offset, data=prior_tokens_sbuf, op0=nl.right_shift, operand0=block_size_shift) + + # prior_block_offset is allocated per-batch inside the processing loop + # (see comment below). Allocating here instead would cause cross-batch + # compiler scheduling hazards: fused_impl decrements this tensor every + # full-prior iteration, and the compiler can schedule later batches' + # DMA descriptors using the previous batch's decremented value. + + # Compute partial prior and flags + temp_seg_tokens = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar(dst=temp_seg_tokens, data=num_full_prior_segments_i32, op0=nl.multiply, operand0=prior_seg_size) + + partial_prior_tokens = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_tensor(dst=partial_prior_tokens, data1=prior_tokens_sbuf, data2=temp_seg_tokens, op=nl.subtract) + + is_partial_prior_segment = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar(dst=is_partial_prior_segment, data=partial_prior_tokens, op0=nl.greater, operand0=0) + + is_not_partial_prior_segment = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=is_not_partial_prior_segment, data=is_partial_prior_segment, op0=nl.subtract, operand0=1, reverse0=True + ) + + # Allocate K/V sbuf for max(active, prior) tiles (reused across iterations). + # fused_segmented_attention_impl aliases k_prior_sbuf / v_prior_sbuf onto + # these buffers via list slicing, saving ~48 KB per partition on the hot + # non-KVP path. + k_cache_sbuf = _alloc_k_cache_sbuf(allocator, head_dim, num_k_tiles_sbuf) + v_cache_sbuf = allocator.alloc_sbuf_tensor( + shape=(_V_TILE_SZ, head_dim), + dtype=nl.bfloat16, + block_dim=[num_v_tiles_sbuf], + num_free_tiles=[num_v_tiles_sbuf], + ) + + # Allocate HBM buffers for unnormalized output and softmax stats (single batch, reused) + # Uses tp_out=False for intermediate HBM (matches reduce_one_batch layout) + softmax_shape = (1, 128, num_grps) + o_prev_hbm = nl.ndarray(shape=(1, seqlen_q, head_dim), dtype=nl.float32, buffer=nl.private_hbm) + neg_max_prev_hbm = nl.ndarray(shape=softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + sum_prev_hbm = nl.ndarray(shape=softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + o_curr_hbm = nl.ndarray(shape=(1, seqlen_q, head_dim), dtype=nl.float32, buffer=nl.private_hbm) + neg_max_curr_hbm = nl.ndarray(shape=softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + sum_curr_hbm = nl.ndarray(shape=softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + + # Copy final results to HBM (allocate for full bs_q, write only assigned portion) + # Intermediates always use non-transposed layout (tp_out=False) for reduce_one_batch compatibility. + # When tp_out=True, we transpose during the final normalize+write step. + if tp_out: + result = nl.ndarray(shape=(bs_q, head_dim, seqlen_q), dtype=q.dtype, buffer=result_buffer) + else: + result = nl.ndarray(shape=(bs_q, seqlen_q, head_dim), dtype=q.dtype, buffer=result_buffer) + + # Allocate softmax stats tensors for KV-parallel mode + if kvp_offset is not None: + out_neg_max_hbm = nl.ndarray(shape=(bs_q, seqlen_q), dtype=nl.float32, buffer=result_buffer) + out_sum_recip_hbm = nl.ndarray(shape=(bs_q, seqlen_q), dtype=nl.float32, buffer=result_buffer) + + # Load kvp_offset into SBUF once (reused per batch) + kvp_offset_sbuf = None + if kvp_offset is not None: + kvp_offset_sbuf = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.dma_copy(dst=kvp_offset_sbuf, src=kvp_offset) + + # Workaround for NCC_IBIR251: Allocate Q buffer once for single batch + # Makes Q "internal" so dma_transpose/access patterns work in dynamic loops (LNC2) + if tp_q: + q_internal = nl.ndarray(shape=(1, seqlen_q, head_dim), dtype=q.dtype, buffer=nl.private_hbm) + else: + q_internal = nl.ndarray(shape=(1, head_dim, seqlen_q), dtype=q.dtype, buffer=nl.private_hbm) + + # Process primary batches one at a time using helper function (skip if no primary batches) + for b_idx in range(num_bs_per_shard): + batch_id = b_idx + bs_offset # Global bs_q index + + # Derive batch index and KV head index from batch_id for GQA + batch_b_i = batch_id // num_q_heads + batch_h_i = (batch_id % num_q_heads) * num_kv_head // num_q_heads + + # Copy this batch's query data (layout matches tp_q) + if tp_q: + nisa.dma_copy( + dst=q_internal[0, :, :], + src=q.ap( + pattern=[[head_dim, seqlen_q], [1, head_dim]], + offset=batch_id * seqlen_q * head_dim, + ), + ) + else: + nisa.dma_copy( + dst=q_internal[0, :, :], + src=q.ap( + pattern=[[seqlen_q, head_dim], [1, seqlen_q]], + offset=batch_id * head_dim * seqlen_q, + ), + ) + + # Allocate a fresh prior_block_offset SBUF tensor per batch to avoid + # cross-batch aliasing. fused_segmented_attention_impl's full-prior + # loop decrements this value every iteration; without a fresh per- + # batch tensor, the compiler can schedule subsequent batches' + # compiled DMA descriptors using the decremented (zero or underflow) + # value from the previous batch instead of the restored value. That + # causes HW scalar-DGE OOB on multi-batch (bs_q > 1) configs with + # >=2 full-prior iterations. + prior_block_offset = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.tensor_scalar( + dst=prior_block_offset, data=num_full_prior_segments_i32, op0=nl.multiply, operand0=num_blocks_per_seg + ) + + # Process this single batch + # fused_segmented_attention_impl handles both KVP and non-KVP via kvp_offset parameter. + # Note: fused_segmented_attention_impl handles allocator reset internally + fused_segmented_attention_impl( + q_hbm=q_internal, + num_batches=1, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + k_cache_sbuf=k_cache_sbuf, + v_cache_sbuf=v_cache_sbuf, + o_prev_hbm=o_prev_hbm, + neg_max_prev_hbm=neg_max_prev_hbm, + sum_prev_hbm=sum_prev_hbm, + o_curr_hbm=o_curr_hbm, + neg_max_curr_hbm=neg_max_curr_hbm, + sum_curr_hbm=sum_curr_hbm, + prior_tokens_sbuf=prior_tokens_sbuf, + num_full_prior_segments_i32=num_full_prior_segments_i32, + partial_prior_tokens=partial_prior_tokens, + is_partial_prior_segment=is_partial_prior_segment, + is_not_partial_prior_segment=is_not_partial_prior_segment, + active_block_offset=active_block_offset, + prior_block_offset=prior_block_offset, + allocator=allocator, + prior_seg_size=prior_seg_size, + block_size=block_size, + scale=scale, + head_dim=head_dim, + num_grps=num_grps, + num_active_blocks=num_active_blocks, + num_k_tiles_active=num_k_tiles_active, + num_v_tiles_active=num_v_tiles_active, + num_blocks_per_seg=num_blocks_per_seg, + num_k_tiles_per_seg=num_k_tiles_per_seg, + num_v_tiles_per_seg=num_v_tiles_per_seg, + b_i=batch_b_i, + h_i=batch_h_i, + tp_q=tp_q, + tp_out=tp_out, + load_kv_cache_fn=load_kv_cache, + attention_cte_fn=_attention_cte, + sink=sink, + kvp_offset=kvp_offset_sbuf, + k_pre_transposed=k_pre_transposed, + k_scale_sb=k_scale_sb, + ) + + # Normalize and write results to final output for this batch + # o_prev_hbm[0] has unnormalized output, normalize (divide by S) and write to result[batch_id] + sb_p = nl.tile_size.pmax + sm_pat = [[num_grps, sb_p], [1, num_grps]] + o_tile_pat = [[head_dim, sb_p], [1, head_dim]] + + norm_addr = allocator.get_current_address() + sum_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + nisa.dma_copy(dst=sum_sb, src=sum_prev_hbm.ap(pattern=sm_pat, offset=0)) + sum_recip_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + nisa.reciprocal(sum_recip_sb, sum_sb) + + num_free = min(num_grps, _MAX_FREE_TILES) + o_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), + dtype=nl.bfloat16, + block_dim=[num_grps], + num_free_tiles=[num_free], + ) + if tp_out: + o_tp_psum = nl.ndarray((head_dim, sb_p), dtype=nl.bfloat16, buffer=nl.psum, address=(0, 0)) + o_tp_sb = allocator.alloc_sbuf_tensor(shape=(head_dim, sb_p), dtype=nl.bfloat16) + for grp_i in range(num_grps): + grp_o_offset = grp_i * sb_p * head_dim + nisa.dma_copy(dst=o_sb[grp_i], src=o_prev_hbm.ap(pattern=o_tile_pat, offset=grp_o_offset)) + # Delayed V dequant: fold v_scale multiply into normalization tensor_scalar + if v_scale_sb is not None: + nisa.tensor_scalar( + dst=o_sb[grp_i], + data=o_sb[grp_i], + op0=nl.multiply, + operand0=sum_recip_sb[:, grp_i], + op1=nl.multiply, + operand1=v_scale_sb, + ) + else: + nisa.tensor_scalar( + dst=o_sb[grp_i], + data=o_sb[grp_i], + op0=nl.multiply, + operand0=sum_recip_sb[:, grp_i], + ) + # Write to result[batch_id] + if tp_out: + # nc_transpose (128, head_dim) → (head_dim, 128), then write with transposed AP + nisa.nc_transpose(dst=o_tp_psum, data=o_sb[grp_i]) + nisa.tensor_copy(dst=o_tp_sb, src=o_tp_psum) + nisa.dma_copy( + dst=result.ap( + pattern=[[seqlen_q, head_dim], [1, sb_p]], + offset=batch_id * head_dim * seqlen_q + grp_i * sb_p, + ), + src=o_tp_sb, + ) + else: + dst_o_offset = batch_id * num_grps * sb_p * head_dim + grp_o_offset + nisa.dma_copy(dst=result.ap(pattern=o_tile_pat, offset=dst_o_offset), src=o_sb[grp_i]) + allocator.set_current_address(norm_addr) + + # Write softmax stats for KV-parallel mode + if kvp_offset is not None: + stats_addr = allocator.get_current_address() + neg_max_sb_kvp = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + sum_sb_kvp = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + sum_recip_sb_kvp = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + nisa.dma_copy(dst=neg_max_sb_kvp, src=neg_max_prev_hbm.ap(pattern=sm_pat, offset=0)) + nisa.dma_copy(dst=sum_sb_kvp, src=sum_prev_hbm.ap(pattern=sm_pat, offset=0)) + nisa.reciprocal(sum_recip_sb_kvp, sum_sb_kvp) + # Write in token-ordered layout: token t = p + g * sb_p stored at offset t + # AP pattern [[1, sb_p], [sb_p, num_grps]] stores [p, g] at offset p + g * sb_p + tok_pat = [[1, sb_p], [sb_p, num_grps]] + nisa.dma_copy( + dst=out_neg_max_hbm.ap(pattern=tok_pat, offset=batch_id * sb_p * num_grps), src=neg_max_sb_kvp + ) + nisa.dma_copy( + dst=out_sum_recip_hbm.ap(pattern=tok_pat, offset=batch_id * sb_p * num_grps), src=sum_recip_sb_kvp + ) + allocator.set_current_address(stats_addr) + + # Secondary sharding: handle remainder bs_q item with asymmetric sequence sharding. + # + # Core 0 handles Q[0 : ceil(num_grps/2)*128], Core 1 handles Q[ceil(num_grps/2)*128 : + # num_grps*128]. For num_grps == 1 the split is degenerate — Core 0 does the full + # remainder and Core 1 short-circuits. + # + # Prior-segment chunking reuses the top-level prior_seg_size on both cores so + # num_full_prior_segments_i32 is identical across cores; this keeps the 3 + # nl.dynamic_range loops inside fused_segmented_attention_impl iterating the + # same number of times (LNC2 sync requirement). + core0_grp_length = (num_grps + 1) // 2 + core1_grp_length = num_grps // 2 + if shard_id == 0: + grp_length = core0_grp_length + grp_start = 0 + else: + grp_length = core1_grp_length + grp_start = core0_grp_length + + # Both cores must enter the remainder block with matching dynamic_range + # iteration counts to satisfy LNC2 basic-block symmetry (NCC_IXGM002). + # When num_grps == 1, Core 1's grp_length == 0 — instead of short- + # circuiting (which would leave Core 1's IR with fewer basic blocks + # than Core 0's), we use effective_grp_length = max(grp_length, 1) for + # sizing and have Core 1 trace the same code paths against private_hbm + # scratch. Its final result-write is redirected away from the shared + # `result` at the HBM normalization step below. + run_remainder = has_remainder + # Static-Python predicate: Core 0 always has real work; Core 1 has real + # work only when core1_grp_length > 0. + run_work = grp_length > 0 + effective_grp_length = max(grp_length, 1) + # Core 1's dummy path needs grp_start=0 so its (discarded) DMA writes + # target a valid in-bounds offset of its scratch buffer. + effective_grp_start = grp_start if run_work else 0 + + if run_remainder: + # Per-shard active-segment token/tile counts (drive load_kv_cache + fused_impl's + # active processing; NOT the dynamic prior-segment loops). + # Uses effective_grp_length so Core 1's dummy path still has valid + # (non-zero) sizing for the allocations and dynamic_range parameters. + effective_q_tokens = effective_grp_length * 128 + num_blocks_per_effective_seg = effective_q_tokens // block_size + num_k_tiles_per_effective_seg = math.ceil(effective_q_tokens / _K_TILE_SZ) + num_v_tiles_per_effective_seg = num_k_tiles_per_effective_seg * (_K_TILE_SZ // _V_TILE_SZ) + + # Core 0's active-segment sizing — used by Core 1's "+1 extra iteration" to + # attend over Core 0's active KV region, and as the shared effective + # segment size below (ceil so it's >= core1's). + core0_q_tokens = core0_grp_length * 128 + core0_num_blocks_per_effective_seg = core0_q_tokens // block_size + core0_num_k_tiles_per_effective_seg = math.ceil(core0_q_tokens / _K_TILE_SZ) + core0_num_v_tiles_per_effective_seg = core0_num_k_tiles_per_effective_seg * (_K_TILE_SZ // _V_TILE_SZ) + + # Shared effective prior-segment size for the remainder's fused_impl call. + # Both cores pass the same value so num_full_prior_segments (computed from + # prior_tokens / effective_prior_seg_size_shared) is identical across + # cores and the 3 inner nl.dynamic_range loops iterate in lockstep + # (LNC2 sync requirement). + # + # Use the top-level prior_seg_size so the remainder path iterates the + # SAME number of prior segments as the primary path (e.g. 3 iterations + # for prior_tokens=32512, prior_seg_size=8192). Using a smaller size + # here (e.g. core0_q_tokens=128) would balloon the iteration count to + # prior_tokens/core0_q_tokens ≈ 254, amplifying bf16 accumulation + # drift in the flash-attention online softmax to ~15% rel error. The + # k_cache_sbuf / v_cache_sbuf allocations above (lines 1708–1722) are + # already sized for max(num_k_tiles_per_seg, num_k_tiles_active), so + # using prior_seg_size fits. + effective_prior_seg_size_shared = prior_seg_size + num_blocks_per_effective_seg_shared = effective_prior_seg_size_shared // block_size + num_k_tiles_per_effective_seg_shared = math.ceil(effective_prior_seg_size_shared / _K_TILE_SZ) + num_v_tiles_per_effective_seg_shared = num_k_tiles_per_effective_seg_shared * (_K_TILE_SZ // _V_TILE_SZ) + + # Sequence start position for this core's Q slice. Core 1's dummy + # path (run_work=False) clamps to 0 so the DMA source offset + # below stays in-bounds (reads bytes overlapping Core 0's Q slice; + # result is discarded downstream). + seq_start = grp_start * 128 if run_work else 0 + + # Copy only this core's portion of Q to q_internal_remainder. + if tp_q: + q_internal_remainder = nl.ndarray( + shape=(1, effective_q_tokens, head_dim), dtype=q.dtype, buffer=nl.private_hbm + ) + nisa.dma_copy( + dst=q_internal_remainder[0, :, :], + src=q.ap( + pattern=[[head_dim, effective_q_tokens], [1, head_dim]], + offset=last_batch * seqlen_q * head_dim + seq_start * head_dim, + ), + ) + else: + q_internal_remainder = nl.ndarray( + shape=(1, head_dim, effective_q_tokens), dtype=q.dtype, buffer=nl.private_hbm + ) + nisa.dma_copy( + dst=q_internal_remainder[0, :, :], + src=q.ap( + pattern=[[seqlen_q, head_dim], [1, effective_q_tokens]], + offset=last_batch * head_dim * seqlen_q + seq_start, + ), + ) + + # Allocate HBM buffers for remainder (per-shard Q length). + # Use effective_grp_length so Core 1's dummy path gets valid + # non-zero shapes; its writes are isolated in private_hbm. + # + # Intermediate unnormalized output buffers are f32 to match the + # primary batch path (lines 1732/1735) and avoid bf16 quantization + # on every flash-attention online-softmax combine. With many prior + # segments (num_full >= 128), accumulated bf16 quantization error + # on each segment's o_rem_prev_hbm round-trip causes systematic + # drift (observed 7–22% rel diff for num_full >= 128). + rem_softmax_shape = (1, 128, effective_grp_length) + o_rem_prev_hbm = nl.ndarray(shape=(1, effective_q_tokens, head_dim), dtype=nl.float32, buffer=nl.private_hbm) + neg_max_rem_prev_hbm = nl.ndarray(shape=rem_softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + sum_rem_prev_hbm = nl.ndarray(shape=rem_softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + o_rem_curr_hbm = nl.ndarray(shape=(1, effective_q_tokens, head_dim), dtype=nl.float32, buffer=nl.private_hbm) + neg_max_rem_curr_hbm = nl.ndarray(shape=rem_softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + sum_rem_curr_hbm = nl.ndarray(shape=rem_softmax_shape, dtype=nl.float32, buffer=nl.private_hbm) + + # Recompute num_full_prior_segments, partial_prior_tokens, and the + # associated flags using the SHARED effective_prior_seg_size so both + # cores see identical values → dynamic loops inside fused_impl iterate + # the same count. + num_full_prior_segments_remainder_f32 = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.float32) + num_full_prior_segments_remainder_i32 = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=num_full_prior_segments_remainder_f32, + data=prior_tokens_sbuf, + op0=nl.multiply, + operand0=1.0 / effective_prior_seg_size_shared, + ) + floor_nisa_kernel( + src_t=num_full_prior_segments_remainder_f32, + dst_t=num_full_prior_segments_remainder_i32, + p_size=1, + f_size=1, + allocator=allocator, + ) + + # prior_block_offset_remainder = num_full_prior_segments_remainder * num_blocks_per_effective_seg_shared + prior_block_offset_remainder = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.tensor_scalar( + dst=prior_block_offset_remainder, + data=num_full_prior_segments_remainder_i32, + op0=nl.multiply, + operand0=num_blocks_per_effective_seg_shared, + ) + + # partial_prior_tokens_remainder = prior_tokens - num_full_prior_segments_remainder * effective_prior_seg_size_shared + temp_seg_tokens_remainder = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=temp_seg_tokens_remainder, + data=num_full_prior_segments_remainder_i32, + op0=nl.multiply, + operand0=effective_prior_seg_size_shared, + ) + partial_prior_tokens_remainder = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_tensor( + dst=partial_prior_tokens_remainder, + data1=prior_tokens_sbuf, + data2=temp_seg_tokens_remainder, + op=nl.subtract, + ) + + is_partial_prior_segment_remainder = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=is_partial_prior_segment_remainder, + data=partial_prior_tokens_remainder, + op0=nl.greater, + operand0=0, + ) + is_not_partial_prior_segment_remainder = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + nisa.tensor_scalar( + dst=is_not_partial_prior_segment_remainder, + data=is_partial_prior_segment_remainder, + op0=nl.subtract, + operand0=1, + reverse0=True, + ) + + # Adjust active_block_offset per core: Core 0 uses the primary offset, + # Core 1 steps past Core 0's active region (size = core0_num_blocks_per_effective_seg). + # When run_work=False (Core 1's dummy path at num_grps=1), use the + # primary offset so the KV cache read stays in-bounds; Core 1's + # output is discarded via the private_hbm scratch redirect below. + active_block_offset_remainder = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.int32) + if shard_id == 0 or not run_work: + nisa.tensor_copy(dst=active_block_offset_remainder, src=active_block_offset) + else: + nisa.tensor_scalar( + dst=active_block_offset_remainder, + data=active_block_offset, + op0=nl.add, + operand0=core0_num_blocks_per_effective_seg, + ) + + # Derive batch/head indices for remainder batch + rem_b_i = last_batch // num_q_heads + rem_h_i = (last_batch % num_q_heads) * num_kv_head // num_q_heads + + # Process this core's portion. Pass the full k/v_cache_sbuf (allocated + # earlier with max(num_k_tiles_per_seg, num_k_tiles_active) tiles) so + # fused_impl can address num_k_tiles_per_seg_shared tiles during the + # prior loop and num_k_tiles_per_effective_seg tiles during the + # per-shard active segment. + # fused_segmented_attention_impl handles both KVP and non-KVP via kvp_offset parameter. + fused_segmented_attention_impl( + q_hbm=q_internal_remainder, + num_batches=1, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + k_cache_sbuf=k_cache_sbuf, + v_cache_sbuf=v_cache_sbuf, + o_prev_hbm=o_rem_prev_hbm, + neg_max_prev_hbm=neg_max_rem_prev_hbm, + sum_prev_hbm=sum_rem_prev_hbm, + o_curr_hbm=o_rem_curr_hbm, + neg_max_curr_hbm=neg_max_rem_curr_hbm, + sum_curr_hbm=sum_rem_curr_hbm, + prior_tokens_sbuf=prior_tokens_sbuf, + # Prior-driven params use the SHARED effective values so + # num_full_prior_segments is identical on both cores. + num_full_prior_segments_i32=num_full_prior_segments_remainder_i32, + partial_prior_tokens=partial_prior_tokens_remainder, + is_partial_prior_segment=is_partial_prior_segment_remainder, + is_not_partial_prior_segment=is_not_partial_prior_segment_remainder, + active_block_offset=active_block_offset_remainder, + prior_block_offset=prior_block_offset_remainder, + allocator=allocator, + prior_seg_size=effective_prior_seg_size_shared, + block_size=block_size, + scale=scale, + head_dim=head_dim, + num_grps=effective_grp_length, + # Active-driven params use per-shard Q length. + num_active_blocks=num_blocks_per_effective_seg, + num_k_tiles_active=num_k_tiles_per_effective_seg, + num_v_tiles_active=num_v_tiles_per_effective_seg, + # Prior-segment chunk sizing uses the shared effective value. + num_blocks_per_seg=num_blocks_per_effective_seg_shared, + num_k_tiles_per_seg=num_k_tiles_per_effective_seg_shared, + num_v_tiles_per_seg=num_v_tiles_per_effective_seg_shared, + b_i=rem_b_i, + h_i=rem_h_i, + tp_q=tp_q, + tp_out=tp_out, + load_kv_cache_fn=load_kv_cache, + attention_cte_fn=_attention_cte, + sink=sink, + kvp_offset=kvp_offset_sbuf, + k_pre_transposed=k_pre_transposed, + k_scale_sb=k_scale_sb, + ) + + # Core 1 does one extra iteration to process Core 0's active segment. + # Core 0's active region covers active_block_offset .. active_block_offset + + # core0_num_blocks_per_effective_seg blocks (the first core0_q_tokens tokens + # of active K). Core 1 needs to attend to these positions (they're + # causally before Core 1's own Q) — fused_impl's active segment only + # covered Core 1's own slice (blocks starting at active_block_offset + + # core0_num_blocks_per_effective_seg). + if shard_id == 1: + init_sbuf_addr = allocator.get_current_address() + # Block offset = active_block_offset (start of all active blocks); + # load_kv_cache will load core0_num_blocks_per_effective_seg blocks + # starting from there, which is exactly Core 0's active region. + extra_offset = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.tensor_copy(dst=extra_offset, src=active_block_offset) + + # Zero the LAST K/V tile that the extra iteration will read. + # k_cache_sbuf / v_cache_sbuf still hold data from the prior-segment + # loop. load_kv_cache writes core0_num_blocks_per_effective_seg blocks + # starting at tile 0 — any preceding tile is fully overwritten, and + # tiles beyond the consumed range (k_cache_sbuf[:core0_num_k_tiles_per_effective_seg]) + # are sliced away. The only stale region the matmul can see is the + # PARTIAL tail of the last consumed tile when + # core0_num_blocks_per_effective_seg % num_blocks_per_k_tile != 0. + # Zeroing just that one tile is sufficient to bound the spurious + # contribution (Q·0 = 0 scores → zero numerator contribution). + if head_dim == 256: + nisa.memset(k_cache_sbuf[core0_num_k_tiles_per_effective_seg - 1][0][...], value=0.0) + nisa.memset(k_cache_sbuf[core0_num_k_tiles_per_effective_seg - 1][1][...], value=0.0) + else: + nisa.memset(k_cache_sbuf[core0_num_k_tiles_per_effective_seg - 1][...], value=0.0) + nisa.memset(v_cache_sbuf[core0_num_v_tiles_per_effective_seg - 1][...], value=0.0) + + # Load Core 0's active KV segment. + load_kv_cache( + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + rem_b_i, + rem_h_i, + extra_offset, + core0_num_blocks_per_effective_seg, + allocator, + k_pre_transposed=k_pre_transposed, + ) + allocator.set_current_address(init_sbuf_addr) + + # Compute attention for this extra segment + _attention_cte( + q_internal_remainder, + None, + None, + scale=scale, + causal_mask=False, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + k_cache_sbuf=k_cache_sbuf[:core0_num_k_tiles_per_effective_seg], + v_cache_sbuf=v_cache_sbuf[:core0_num_v_tiles_per_effective_seg], + out_o_hbm=o_rem_curr_hbm, + out_neg_max_hbm=neg_max_rem_curr_hbm, + out_sum_hbm=sum_rem_curr_hbm, + init_sbuf_addr=init_sbuf_addr, + k_scale_sb=k_scale_sb, + ) + allocator.set_current_address(init_sbuf_addr) + + # HBM-based reduction: combine extra segment into accumulated results. + # Uses effective_grp_length so Core 1's dummy path (num_grps=1) has + # valid non-zero shapes/iteration counts. + rem_sb_p = nl.tile_size.pmax + rem_softmax_pat = [[effective_grp_length, rem_sb_p], [1, effective_grp_length]] + rem_o_pat = [[head_dim, rem_sb_p], [1, head_dim]] + rem_num_free = min(effective_grp_length, _MAX_FREE_TILES) + + rem_neg_max_prev_sb = allocator.alloc_sbuf_tensor(shape=(rem_sb_p, effective_grp_length), dtype=nl.float32) + rem_sum_prev_sb = allocator.alloc_sbuf_tensor(shape=(rem_sb_p, effective_grp_length), dtype=nl.float32) + rem_neg_max_curr_sb = allocator.alloc_sbuf_tensor(shape=(rem_sb_p, effective_grp_length), dtype=nl.float32) + rem_sum_curr_sb = allocator.alloc_sbuf_tensor(shape=(rem_sb_p, effective_grp_length), dtype=nl.float32) + rem_o_prev_sb = allocator.alloc_sbuf_tensor( + shape=(rem_sb_p, head_dim), + dtype=nl.float32, + block_dim=[effective_grp_length], + num_free_tiles=[rem_num_free], + ) + rem_o_curr_sb = allocator.alloc_sbuf_tensor( + shape=(rem_sb_p, head_dim), + dtype=nl.float32, + block_dim=[effective_grp_length], + num_free_tiles=[rem_num_free], + ) + rem_o_new_sb = allocator.alloc_sbuf_tensor( + shape=(rem_sb_p, head_dim), + dtype=nl.float32, + block_dim=[effective_grp_length], + num_free_tiles=[rem_num_free], + ) + rem_batch_loop_addr = allocator.get_current_address() + + reduce_one_batch( + o_rem_prev_hbm, + neg_max_rem_prev_hbm, + sum_rem_prev_hbm, + o_rem_curr_hbm, + neg_max_rem_curr_hbm, + sum_rem_curr_hbm, + 0, + 0, + effective_grp_length, + head_dim, + effective_grp_length, + rem_sb_p, + rem_softmax_pat, + rem_o_pat, + rem_neg_max_prev_sb, + rem_sum_prev_sb, + rem_neg_max_curr_sb, + rem_sum_curr_sb, + rem_o_prev_sb, + rem_o_curr_sb, + rem_o_new_sb, + rem_batch_loop_addr, + allocator, + ) + allocator.set_current_address(init_sbuf_addr) + else: + # Core 0 does dummy ops for control flow consistency + rng_seeds_sb = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.memset(rng_seeds_sb, 0.0) + nisa.set_rng_seed(rng_seeds_sb) + + # Normalize and write each core's portion to result[last_batch]. + # Core 1's dummy path (run_work=False, i.e. num_grps=1) redirects + # the result write to a private_hbm scratch tensor so it doesn't + # corrupt the shared `result`. Both cores trace the same static + # Python for-loop so the IR stays symmetric for LNC2. + rem_norm_addr = allocator.get_current_address() + rem_sb_p2 = nl.tile_size.pmax + rem_sm_pat = [[effective_grp_length, rem_sb_p2], [1, effective_grp_length]] + rem_o_tile_pat = [[head_dim, rem_sb_p2], [1, head_dim]] + + # Scratch result buffer for Core 1's dummy path. Core 0 points to + # the shared `result`; Core 1 points to a fresh private_hbm with + # matching shape so the DMA writes don't corrupt shared state. + if run_work: + result_write_target = result + else: + result_write_target = nl.ndarray(shape=result.shape, dtype=result.dtype, buffer=nl.private_hbm) + + rem_sum_sb2 = allocator.alloc_sbuf_tensor(shape=(rem_sb_p2, effective_grp_length), dtype=nl.float32) + nisa.dma_copy(dst=rem_sum_sb2, src=sum_rem_prev_hbm.ap(pattern=rem_sm_pat, offset=0)) + rem_sum_recip_sb2 = allocator.alloc_sbuf_tensor(shape=(rem_sb_p2, effective_grp_length), dtype=nl.float32) + nisa.reciprocal(rem_sum_recip_sb2, rem_sum_sb2) + + rem_num_free2 = min(effective_grp_length, _MAX_FREE_TILES) + rem_o_sb2 = allocator.alloc_sbuf_tensor( + shape=(rem_sb_p2, head_dim), + dtype=nl.bfloat16, + block_dim=[effective_grp_length], + num_free_tiles=[rem_num_free2], + ) + if tp_out: + rem_o_tp_psum2 = nl.ndarray((head_dim, rem_sb_p2), dtype=nl.bfloat16, buffer=nl.psum, address=(0, 0)) + rem_o_tp_sb2 = allocator.alloc_sbuf_tensor(shape=(head_dim, rem_sb_p2), dtype=nl.bfloat16) + for local_grp_i in range(effective_grp_length): + global_grp_i = effective_grp_start + local_grp_i + grp_start_pos = global_grp_i * 128 + + # Read from o_rem_prev_hbm[0] at local group offset + src_o_offset = local_grp_i * rem_sb_p2 * head_dim + nisa.dma_copy( + dst=rem_o_sb2[local_grp_i], src=o_rem_prev_hbm.ap(pattern=rem_o_tile_pat, offset=src_o_offset) + ) + # Delayed V dequant: fold v_scale multiply into normalization tensor_scalar + if v_scale_sb is not None: + nisa.tensor_scalar( + dst=rem_o_sb2[local_grp_i], + data=rem_o_sb2[local_grp_i], + op0=nl.multiply, + operand0=rem_sum_recip_sb2[:, local_grp_i], + op1=nl.multiply, + operand1=v_scale_sb, + ) + else: + nisa.tensor_scalar( + dst=rem_o_sb2[local_grp_i], + data=rem_o_sb2[local_grp_i], + op0=nl.multiply, + operand0=rem_sum_recip_sb2[:, local_grp_i], + ) + # Write to result[last_batch] at global group position (or scratch + # on Core 1's dummy path). For Core 1's dummy path, grp_start=0 + # and local_grp_i=0, so grp_start_pos=0 — writes to the scratch + # buffer's beginning, which is safe. + if tp_out: + nisa.nc_transpose(dst=rem_o_tp_psum2, data=rem_o_sb2[local_grp_i]) + nisa.tensor_copy(dst=rem_o_tp_sb2, src=rem_o_tp_psum2) + nisa.dma_copy( + dst=result_write_target.ap( + pattern=[[seqlen_q, head_dim], [1, rem_sb_p2]], + offset=last_batch * head_dim * seqlen_q + grp_start_pos, + ), + src=rem_o_tp_sb2, + ) + else: + dst_o_offset = last_batch * num_grps * rem_sb_p2 * head_dim + global_grp_i * rem_sb_p2 * head_dim + nisa.dma_copy( + dst=result_write_target.ap(pattern=rem_o_tile_pat, offset=dst_o_offset), + src=rem_o_sb2[local_grp_i], + ) + allocator.set_current_address(rem_norm_addr) + + if kvp_offset is not None: + return result, out_neg_max_hbm, out_sum_recip_hbm + return result diff --git a/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py new file mode 100644 index 00000000..c9583685 --- /dev/null +++ b/src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py @@ -0,0 +1,1527 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Fused segmented attention: replaces per-segment _attention_cte + external reduction +with a single flash-attention loop over all segments. + +Uses kv_section_idx=0 so K/V indexing always starts at tile 0 (each segment has +its own K/V in SBUF), while section_idx > 0 triggers the flash attention +accumulation path in _write_back_impl and _update_max_impl. This keeps the PV +accumulation in float32 SBUF across segments, matching _attention_cte's internal +precision. +""" + +import math + +import nki.isa as nisa +import nki.language as nl + +from nkilib.core.utils.attention_reduce import _MAX_FREE_TILES, reduce_one_batch +from nkilib.core.utils.kernel_assert import kernel_assert +from nkilib.core.utils.kernel_helpers import PSUM_BANK_SIZE, div_ceil +from nkilib.core.utils.modular_allocator import ModularAllocator +from nkilib.core.utils.stream_shuffle_broadcast import stream_shuffle_broadcast +from nki.isa import reduce_cmd +from nki.language.opcode import maximum as _maximum +from nkilib.core.attention.attention_cte import ( + _FLOAT32_MIN, + _K_TILE_SZ, + _LARGE_TILE_SZ, + _Q_GRP_SZ, + _V_TILE_SZ, + AttnConfig, + AttnInternalBuffers, + SectionParams, + _allocate_attention_buffers, + _compute_tile_parameters, + _exp_impl, + _fused_qkmax_and_pv_impl, + _get_kv_tile_apc, + _has_any_compute_causal, + _has_any_compute_swa, + _load_q_impl, + _pv_impl, + _qk_and_max_impl, + _setup_range_select_bounds, + _update_max_impl, + _write_back_impl, +) + + +def _run_groups(grp_start, grp_end, ac, atp, sp, bufs, q, batch_id, o, sbuf_addr, sink=None): + """Run Q-group loop with software pipelining.""" + n = grp_end - grp_start + if n <= 1: + _load_q_impl(grp_start, ac, atp, sp, bufs, q, batch_id, sbuf_addr) + _qk_and_max_impl(grp_start, ac, atp, sp, bufs, batch_id) + _update_max_impl(grp_start, ac, atp, sp, bufs, sink) + _exp_impl(grp_start, ac, atp, sp, bufs, sink) + _pv_impl(grp_start, ac, atp, sp, bufs) + _write_back_impl(grp_start, ac, atp, sp, bufs, o, batch_id) + else: + _load_q_impl(grp_start, ac, atp, sp, bufs, q, batch_id, sbuf_addr) + _qk_and_max_impl(grp_start, ac, atp, sp, bufs, batch_id) + _update_max_impl(grp_start, ac, atp, sp, bufs, sink) + _exp_impl(grp_start, ac, atp, sp, bufs, sink) + + _load_q_impl(grp_start + 1, ac, atp, sp, bufs, q, batch_id, sbuf_addr) + _qk_and_max_impl(grp_start + 1, ac, atp, sp, bufs, batch_id) + _update_max_impl(grp_start + 1, ac, atp, sp, bufs, sink) + + for grp_i in range(grp_start, grp_end - 2): + _load_q_impl(grp_i + 2, ac, atp, sp, bufs, q, batch_id, sbuf_addr) + _exp_impl(grp_i + 1, ac, atp, sp, bufs, sink) + _fused_qkmax_and_pv_impl(grp_i, ac, atp, sp, bufs, batch_id) + _write_back_impl(grp_i, ac, atp, sp, bufs, o, batch_id) + _update_max_impl(grp_i + 2, ac, atp, sp, bufs, sink) + + _pv_impl(grp_end - 2, ac, atp, sp, bufs) + _write_back_impl(grp_end - 2, ac, atp, sp, bufs, o, batch_id) + _exp_impl(grp_end - 1, ac, atp, sp, bufs, sink) + _pv_impl(grp_end - 1, ac, atp, sp, bufs) + _write_back_impl(grp_end - 1, ac, atp, sp, bufs, o, batch_id) + + +def _make_ac_atp( + seqlen_q, seqlen_k, head_dim, dtype, causal, scale, tp_q, tp_out, num_sections, use_cp=False, global_cp_deg=None +): + """Create AttnConfig + AttnTileParams.""" + ac = AttnConfig( + seqlen_q=seqlen_q, + seqlen_k_active=seqlen_k, + seqlen_k_prior=None, + d=head_dim, + tp_q=tp_q, + tp_k=False, + tp_out=tp_out, + is_prefix_caching=False, + causal_mask=causal, + use_swa=False, + sliding_window=0, + use_cp=use_cp, + global_cp_deg=global_cp_deg, + cp_strided_q_slicing=False, + cp_striped_input=False, + scale=scale, + cache_softmax=True, + skip_output_normalization=True, + dtype=dtype, + softmax_dtype=nl.float32, + mm_out_dtype=nl.float32, + is_sequence_packed=False, + ) + atp = _compute_tile_parameters(ac, is_seqlen_sharded=False) + if head_dim == 256: + atp.num_q_grps_per_load = min(4, atp.num_grps) + atp.num_sections = num_sections + return ac, atp + + +def _kvp_partial_prior_attention( + q_hbm, + k_cache_sbuf, + v_cache_sbuf, + k_prior_sbuf, + v_prior_sbuf, + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + kvp_offset_active_hbm, + kvp_offset, + prior_block_offset, + partial_prior_tokens, + num_k_tiles_active, + num_v_tiles_active, + num_k_tiles_per_seg, + num_v_tiles_per_seg, + n_grps, + head_dim, + block_size, + sb_p, + scale, + tp_q, + allocator, + attention_cte_fn, + sink=None, + k_scale_sb=None, +): + """KVP partial prior: two separate attention_cte calls (active + prior) then reduce. + + Splits into two calls to avoid the unsupported use_cp + is_prefix_caching combination: + 1. Active-only: causal_mask=True with cp_offset=kvp_offset_active + 2. Prior-only: causal_mask=False with effective_prior_used_len + Results are reduced via online softmax into o_prev_hbm in-place. + """ + init_sbuf_addr = allocator.get_current_address() + + # Call 1: active-only with causal mask + cp_offset. + attention_cte_fn( + q_hbm, + None, + None, + scale=scale, + causal_mask=True, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + k_cache_sbuf=k_cache_sbuf[:num_k_tiles_active], + v_cache_sbuf=v_cache_sbuf[:num_v_tiles_active], + out_o_hbm=o_prev_hbm, + out_neg_max_hbm=neg_max_prev_hbm, + out_sum_hbm=sum_prev_hbm, + init_sbuf_addr=init_sbuf_addr, + sink=sink, + cp_offset=kvp_offset_active_hbm, + global_cp_deg=1, + k_scale_sb=k_scale_sb, + ) + allocator.set_current_address(init_sbuf_addr) + + # Compute effective_prior_used_len = max(0, min(partial_prior_tokens, kvp_offset - prior_block_offset*block_size)) + effective_prior_used_len = allocator.alloc_sbuf_tensor((1, 1), nl.int32) + prior_seg_start_sbuf = allocator.alloc_sbuf_tensor((1, 1), nl.int32) + nisa.tensor_scalar(dst=prior_seg_start_sbuf, data=prior_block_offset, op0=nl.multiply, operand0=block_size) + nisa.tensor_tensor(dst=effective_prior_used_len, data1=kvp_offset, data2=prior_seg_start_sbuf, op=nl.subtract) + nisa.tensor_tensor( + dst=effective_prior_used_len, data1=effective_prior_used_len, data2=partial_prior_tokens, op=nl.minimum + ) + nisa.tensor_scalar(dst=effective_prior_used_len, data=effective_prior_used_len, op0=nl.maximum, operand0=0) + + # Re-zero active KV sbuf before Call 2 (prior-only, causal_mask=False). + for k_idx in range(num_k_tiles_active): + nisa.memset(k_cache_sbuf[k_idx][...], value=0.0) + for v_idx in range(num_v_tiles_active): + nisa.memset(v_cache_sbuf[v_idx][...], value=0.0) + + call2_sbuf_addr = allocator.get_current_address() + # Call 2: prior-only with causal_mask=False and effective_prior_used_len (SBUF). + attention_cte_fn( + q_hbm, + None, + None, + scale=scale, + causal_mask=False, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + k_cache_sbuf=k_cache_sbuf[:num_k_tiles_per_seg], + v_cache_sbuf=v_cache_sbuf[:num_v_tiles_per_seg], + k_prior_sbuf=k_prior_sbuf, + v_prior_sbuf=v_prior_sbuf, + prior_used_len=effective_prior_used_len, + out_o_hbm=o_curr_hbm, + out_neg_max_hbm=neg_max_curr_hbm, + out_sum_hbm=sum_curr_hbm, + init_sbuf_addr=call2_sbuf_addr, + ) + allocator.set_current_address(call2_sbuf_addr) + + # Reduce active (o_prev_hbm) + prior (o_curr_hbm) into o_prev_hbm. + softmax_pat = [[n_grps, sb_p], [1, n_grps]] + o_pat = [[head_dim, sb_p], [1, head_dim]] + num_free = min(n_grps, _MAX_FREE_TILES) + neg_max_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + sum_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + neg_max_curr_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + sum_curr_sb_buf = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + o_prev_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + o_curr_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + o_new_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + reduce_batch_addr = allocator.get_current_address() + reduce_one_batch( + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + 0, + 0, + n_grps, + head_dim, + n_grps, + sb_p, + softmax_pat, + o_pat, + neg_max_prev_sb, + sum_prev_sb, + neg_max_curr_sb, + sum_curr_sb_buf, + o_prev_sb, + o_curr_sb, + o_new_sb, + reduce_batch_addr, + allocator, + ) + allocator.set_current_address(init_sbuf_addr) + + +def _nonkvp_partial_prior_attention( + q_hbm, + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + prior_block_offset, + partial_prior_tokens, + num_k_tiles_active, + num_v_tiles_active, + num_k_tiles_per_seg, + num_v_tiles_per_seg, + num_blocks_per_seg, + num_v_tiles_for_prior, + b_i, + h_i, + n_grps, + head_dim, + sb_p, + scale, + tp_q, + allocator, + attention_cte_fn, + load_kv_cache_fn, + sink=None, +): + """Non-KVP partial prior: two sequential attention_cte calls then reduce. + + Mirrors _kvp_partial_prior_attention's 2-pass shape but without cp_offset / + global_cp_deg, and uses the static partial_prior_tokens directly as + prior_used_len (no dynamic effective_prior_used_len math needed). + + Pass 1: active-only, causal_mask=True, sink applied. + ---- Allocate k_prior_sbuf/v_prior_sbuf ALIASED onto k_cache_sbuf / + v_cache_sbuf (same physical SBUF region). Pass 1 has already reduced + its active-K result into o_prev_hbm, so the active K data in + k_cache_sbuf is no longer needed. Load prior K/V into the aliased + region. + Pass 2: prior-only, causal_mask=False, prior_used_len=partial_prior_tokens. + Reduce via online softmax into o_prev_hbm in-place. + + Saves ~48 KB/partition of peak SBUF at head_dim=128 prior_seg_size=8192 + compared to the previous single-fused-call design (which held both + k_cache_sbuf and a separate k_prior_sbuf concurrently live through APC). + """ + init_sbuf_addr = allocator.get_current_address() + + # Pass 1: active-only with causal mask. No k_prior_sbuf reference. + attention_cte_fn( + q_hbm, + None, + None, + scale=scale, + causal_mask=True, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + k_cache_sbuf=k_cache_sbuf[:num_k_tiles_active], + v_cache_sbuf=v_cache_sbuf[:num_v_tiles_active], + out_o_hbm=o_prev_hbm, + out_neg_max_hbm=neg_max_prev_hbm, + out_sum_hbm=sum_prev_hbm, + init_sbuf_addr=init_sbuf_addr, + sink=sink, + ) + allocator.set_current_address(init_sbuf_addr) + + # Alias k_prior_sbuf/v_prior_sbuf onto the first N tiles of + # k_cache_sbuf/v_cache_sbuf via Python list slicing — same physical SBUF, + # no new allocation. k_cache_sbuf is sized with + # max(num_k_tiles_active, num_k_tiles_per_seg) at the caller, so the slice + # is always in range. + kernel_assert( + num_k_tiles_per_seg <= len(k_cache_sbuf), + "k_cache_sbuf must be sized >= num_k_tiles_per_seg for aliased reuse", + ) + kernel_assert( + num_v_tiles_for_prior <= len(v_cache_sbuf), + "v_cache_sbuf must be sized >= num_v_tiles_for_prior for aliased reuse", + ) + k_prior_sbuf = k_cache_sbuf[:num_k_tiles_per_seg] + v_prior_sbuf = v_cache_sbuf[:num_v_tiles_for_prior] + + # Load prior K/V into the aliased region. This overwrites the active K/V + # from Pass 1, which is safe because Pass 1's results are already in + # o_prev_hbm. + load_kv_cache_fn( + k_cache, + v_cache, + block_tables, + k_prior_sbuf, + v_prior_sbuf, + b_i, + h_i, + prior_block_offset, + num_blocks_per_seg, + allocator, + ) + + call2_sbuf_addr = allocator.get_current_address() + # Pass 2: non-APC call treating the aliased prior data as the active K/V. + # `kv_used_len=partial_prior_tokens` dynamically masks K positions beyond + # the used prior range. Previously this was an APC call (k_prior_sbuf + + # k_cache_sbuf both pointing to the aliased memory), but that caused the + # kernel to attend the prior data TWICE — once as "active" (unmasked) and + # once as "prior" (masked by prior_used_len) — inflating sum_curr_hbm by + # ~2× and skewing the reduce_one_batch combination. Using kv_used_len in + # non-APC mode keeps Bucket B's SBUF aliasing AND produces correct output. + attention_cte_fn( + q_hbm, + None, + None, + scale=scale, + causal_mask=False, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + k_cache_sbuf=k_prior_sbuf, + v_cache_sbuf=v_prior_sbuf, + kv_used_len=partial_prior_tokens, + out_o_hbm=o_curr_hbm, + out_neg_max_hbm=neg_max_curr_hbm, + out_sum_hbm=sum_curr_hbm, + init_sbuf_addr=call2_sbuf_addr, + ) + allocator.set_current_address(call2_sbuf_addr) + + # Reduce active (o_prev_hbm) + prior (o_curr_hbm) into o_prev_hbm. + softmax_pat = [[n_grps, sb_p], [1, n_grps]] + o_pat = [[head_dim, sb_p], [1, head_dim]] + num_free = min(n_grps, _MAX_FREE_TILES) + neg_max_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + sum_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + neg_max_curr_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + sum_curr_sb_buf = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + o_prev_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + o_curr_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + o_new_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + reduce_batch_addr = allocator.get_current_address() + reduce_one_batch( + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + 0, + 0, + n_grps, + head_dim, + n_grps, + sb_p, + softmax_pat, + o_pat, + neg_max_prev_sb, + sum_prev_sb, + neg_max_curr_sb, + sum_curr_sb_buf, + o_prev_sb, + o_curr_sb, + o_new_sb, + reduce_batch_addr, + allocator, + ) + allocator.set_current_address(init_sbuf_addr) + + +_allocate_attention_buffers_base = _allocate_attention_buffers +_load_q_impl_base = _load_q_impl +_qk_and_max_impl_base = _qk_and_max_impl +_pv_impl_base = _pv_impl + + +def _zero_k_tiles(k_sbuf, num_tiles, head_dim): + if head_dim == 256: + for k_idx in range(num_tiles): + nisa.memset(k_sbuf[k_idx][0][...], value=0.0) + nisa.memset(k_sbuf[k_idx][1][...], value=0.0) + else: + for k_idx in range(num_tiles): + nisa.memset(k_sbuf[k_idx][...], value=0.0) + + +def _repeat_ref(value, count): + values = [] + for _ in range(count): + values.append(value) + return values + + +def _allocate_attention_buffers( + allocator, + ac: AttnConfig, + atp, + bufs: AttnInternalBuffers, + sink=None, + k_cache_sbuf=None, + v_cache_sbuf=None, +): + if ac.d <= 128: + return _allocate_attention_buffers_base(allocator, ac, atp, bufs, sink, k_cache_sbuf, v_cache_sbuf) + + kernel_assert(ac.d == 256, f"qwen_segcte256 only supports head_dim=256, got {ac.d}") + kernel_assert(not ac.tp_out, "qwen_segcte256 uses tp_out=False to keep head_dim on the free axis") + + mm1_p, mm1_n = atp.sb_p, nl.tile_size.psum_fmax + mm2_p, mm2_n = atp.sb_p, ac.d + num_q_slots = div_ceil(atp.num_grps, atp.num_q_grps_per_load) + + if k_cache_sbuf is not None and len(k_cache_sbuf) > 0: + bufs.k_sb = k_cache_sbuf + else: + k_lo = allocator.alloc_sbuf_tensor( + shape=(128, _K_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[atp.num_k_tiles_per_section], + num_free_tiles=[atp.num_k_tiles_per_section], + align_to=32, + ) + k_hi = allocator.alloc_sbuf_tensor( + shape=(128, _K_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[atp.num_k_tiles_per_section], + num_free_tiles=[atp.num_k_tiles_per_section], + align_to=32, + ) + bufs.k_sb = [] + for i in range(atp.num_k_tiles_per_section): + bufs.k_sb.append((k_lo[i], k_hi[i])) + + if v_cache_sbuf is not None and len(v_cache_sbuf) > 0: + bufs.v_sb = v_cache_sbuf + else: + bufs.v_sb = allocator.alloc_sbuf_tensor( + shape=(_V_TILE_SZ, ac.d), + dtype=nl.bfloat16, + block_dim=[atp.num_v_tiles_per_section], + num_free_tiles=[atp.num_v_tiles_per_section], + ) + + # This kernel runs Q groups sequentially, not with the upstream + # software-pipelined 3-group schedule. Keep one physical scratch window and + # alias all logical group slots to it so pfx256 does not allocate per-group + # MM1/MM2 scratch for the full CTE bucket. + q_sb_lo = allocator.alloc_sbuf_tensor( + shape=(128, atp.sb_p * atp.num_q_grps_per_load), + dtype=nl.bfloat16, + align_to=32, + ) + q_sb_hi = allocator.alloc_sbuf_tensor( + shape=(128, atp.sb_p * atp.num_q_grps_per_load), + dtype=nl.bfloat16, + align_to=32, + ) + bufs.q_sb_lo = _repeat_ref(q_sb_lo, num_q_slots) + bufs.q_sb_hi = _repeat_ref(q_sb_hi, num_q_slots) + + flash_attn_correction_factor = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + ) + bufs.flash_attn_correction_factor = _repeat_ref(flash_attn_correction_factor, atp.num_grps) + mm1_partial_max_n_elts = atp.num_k_tiles_per_section + (sink is not None) + mm1_partial_max = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, mm1_partial_max_n_elts), + dtype=nl.float32, + align_to=4, + ) + bufs.mm1_partial_max = _repeat_ref(mm1_partial_max, atp.num_grps) + mm1_section_max = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + ) + bufs.mm1_section_max = _repeat_ref(mm1_section_max, atp.num_grps) + n_final_reduce_sum_elts = div_ceil(atp.section_len, atp.exp_inst_elems) + (sink is not None) + exp_partial_sum = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, n_final_reduce_sum_elts), + dtype=nl.float32, + ) + bufs.exp_partial_sum = _repeat_ref(exp_partial_sum, atp.num_grps) + exp_section_sum = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + ) + bufs.exp_section_sum = _repeat_ref(exp_section_sum, atp.num_grps) + prev_mm1_running_max = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + ) + bufs.prev_mm1_running_max = _repeat_ref(prev_mm1_running_max, atp.num_grps) + prev_exp_running_sum = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + ) + bufs.prev_exp_running_sum = _repeat_ref(prev_exp_running_sum, atp.num_grps) + mm2_prev_output = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, ac.d), + dtype=ac.mm_out_dtype, + ) + bufs.mm2_prev_output = _repeat_ref(mm2_prev_output, atp.num_grps) + mm2_accum_flash_attn = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, ac.d), + dtype=nl.float32, + ) + bufs.mm2_accum_flash_attn = _repeat_ref(mm2_accum_flash_attn, atp.num_grps) + mm2_final = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, ac.d), + dtype=ac.mm_out_dtype, + ) + bufs.mm2_final = _repeat_ref(mm2_final, atp.num_grps) + mm2_sb = allocator.alloc_sbuf_tensor( + shape=(mm2_p, mm2_n), + dtype=ac.mm_out_dtype, + ) + bufs.mm2_sb = _repeat_ref(mm2_sb, atp.num_grps) + mm1_masked_tiles = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, _LARGE_TILE_SZ), + dtype=nl.float32, + block_dim=[atp.num_large_tiles_per_section], + num_free_tiles=[1], + ) + mm1_masked_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + mm1_masked_row.append(mm1_masked_tiles[large_tile_idx]) + bufs.mm1_masked = _repeat_ref(mm1_masked_row, atp.num_grps) + exp_sb_tiles = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, _LARGE_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[atp.num_large_tiles_per_section], + num_free_tiles=[1], + ) + exp_sb_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + exp_sb_row.append(exp_sb_tiles[large_tile_idx]) + bufs.exp_sb = _repeat_ref(exp_sb_row, atp.num_grps) + + bufs.mm1_psum = [] + for grp_idx in range(atp.num_grps): + mm1_psum_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + tile_row = [] + for k_tile_idx in range(4): + tile_row.append( + nl.ndarray( + (mm1_p, mm1_n), + dtype=ac.mm_out_dtype, + buffer=nl.psum, + address=(0, (k_tile_idx % 4) * PSUM_BANK_SIZE), + ) + ) + mm1_psum_row.append(tile_row) + bufs.mm1_psum.append(mm1_psum_row) + + if not atp.dynamic_sel_mask: + mm1_copy_tiles = allocator.alloc_sbuf_tensor( + shape=(mm1_p, mm1_n), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_large_tiles_per_section, 4], + num_free_tiles=[1, 1], + ) + mm1_affine_select_output_tiles = allocator.alloc_sbuf_tensor( + shape=(mm1_p, mm1_n), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_large_tiles_per_section, 4], + num_free_tiles=[1, 1], + ) + mm1_copy_row = [] + mm1_affine_select_output_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + mm1_copy_tile_row = [] + mm1_affine_select_output_tile_row = [] + for k_tile_idx in range(4): + mm1_copy_tile_row.append(mm1_copy_tiles[large_tile_idx][k_tile_idx]) + mm1_affine_select_output_tile_row.append( + mm1_affine_select_output_tiles[large_tile_idx][k_tile_idx] + ) + mm1_copy_row.append(mm1_copy_tile_row) + mm1_affine_select_output_row.append(mm1_affine_select_output_tile_row) + bufs.mm1_copy_sb = _repeat_ref(mm1_copy_row, atp.num_grps) + bufs.mm1_affine_select_output = _repeat_ref(mm1_affine_select_output_row, atp.num_grps) + + exp_tp_tiles = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, atp.mm2_grp_sz), + dtype=nl.bfloat16, + block_dim=[atp.num_large_tiles_per_section, atp.num_tps_in_mm2_grp], + num_free_tiles=[1, atp.num_tps_in_mm2_grp], + align_to=32, + ) + exp_tp_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + exp_tp_tile_row = [] + for tp_idx in range(atp.num_tps_in_mm2_grp): + exp_tp_tile_row.append(exp_tp_tiles[large_tile_idx][tp_idx]) + exp_tp_row.append(exp_tp_tile_row) + bufs.exp_tp_sb = _repeat_ref(exp_tp_row, atp.num_grps) + + bufs.mm2_psum = [] + for grp_idx in range(atp.num_grps): + mm2_psum_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + mm2_psum_row.append( + nl.ndarray( + (mm2_p, mm2_n), + dtype=ac.mm_out_dtype, + buffer=nl.psum, + address=(0, ((4 + (large_tile_idx % 4)) * PSUM_BANK_SIZE)), + ) + ) + bufs.mm2_psum.append(mm2_psum_row) + + +def _load_q_impl(grp_i, ac: AttnConfig, atp, sp: SectionParams, bufs: AttnInternalBuffers, q, batch_id, sbuf_addr): + if ac.d <= 128: + return _load_q_impl_base(grp_i, ac, atp, sp, bufs, q, batch_id, sbuf_addr) + + kernel_assert(ac.d == 256, f"qwen_segcte256 only supports head_dim=256, got {ac.d}") + if grp_i % atp.num_q_grps_per_load != 0: + return + + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac, atp.num_q_grps_per_load) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if not has_any_compute_pred: + return + + q_seqlen_offset = grp_i * _Q_GRP_SZ + q_slot = grp_i // atp.num_q_grps_per_load + num_f = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ * atp.num_q_grps_per_load) + kernel_assert(str(q.dtype) == str(nl.bfloat16), "qwen_segcte256 currently expects bf16 Q input") + + if ac.tp_q: + _, seqlen, _ = q.shape + for d_half in range(2): + d_offset = d_half * 128 + q_dst = bufs.q_sb_lo[q_slot] if d_half == 0 else bufs.q_sb_hi[q_slot] + nisa.dma_transpose( + dst=q_dst.ap([[_Q_GRP_SZ * atp.num_q_grps_per_load, 128], [1, 1], [1, 1], [1, num_f]]), + src=q.ap( + [[ac.d, num_f], [1, 1], [1, 1], [1, 128]], + offset=batch_id * seqlen * ac.d + q_seqlen_offset * ac.d + d_offset, + ), + ) + else: + _, _, seqlen = q.shape + for d_half in range(2): + d_offset = d_half * 128 + q_dst = bufs.q_sb_lo[q_slot] if d_half == 0 else bufs.q_sb_hi[q_slot] + nisa.dma_copy( + dst=q_dst.ap(pattern=[[_Q_GRP_SZ * atp.num_q_grps_per_load, 128], [1, num_f]], offset=0), + src=q.ap( + pattern=[[seqlen, 128], [1, num_f]], + offset=batch_id * ac.d * seqlen + d_offset * seqlen + q_seqlen_offset, + ), + ) + + +def _qk_and_max_impl(grp_i, ac: AttnConfig, atp, sp: SectionParams, bufs: AttnInternalBuffers, batch_id: int = 0): + if ac.d <= 128: + return _qk_and_max_impl_base(grp_i, ac, atp, sp, bufs, batch_id) + + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if has_any_compute_pred: + nisa.memset(bufs.mm1_partial_max[grp_i], value=_FLOAT32_MIN) + for large_tile_idx in range(atp.num_large_tiles_per_section): + _qk_and_max_large_tile_impl_256(grp_i, large_tile_idx, ac, atp, sp, bufs, batch_id) + + +def _qk_and_max_large_tile_impl_256(qkmax_grp, large_tile_idx, ac, atp, sp, bufs, batch_id: int = 0): + q_seqlen_offset = qkmax_grp * atp.sb_p + num_k_tiles_in_large_tile = _LARGE_TILE_SZ // _K_TILE_SZ + for k_tile_idx in range(num_k_tiles_in_large_tile): + mm1_psum_tile = bufs.mm1_psum[qkmax_grp][large_tile_idx][k_tile_idx] + if not atp.dynamic_sel_mask: + mm1_copy_sb_tile = bufs.mm1_copy_sb[qkmax_grp][large_tile_idx][k_tile_idx] + mm1_affine_select_output_tile = bufs.mm1_affine_select_output[qkmax_grp][large_tile_idx][k_tile_idx] + mm1_masked_tile = bufs.mm1_masked[qkmax_grp][large_tile_idx] + mm1_partial_max_tile = bufs.mm1_partial_max[qkmax_grp] + + k_tile_idx_in_section = large_tile_idx * num_k_tiles_in_large_tile + k_tile_idx + _kv_sec_idx = sp.kv_section_idx if sp.kv_section_idx is not None else sp.section_idx + k_tile_idx_global = atp.num_k_tiles_per_section * _kv_sec_idx + k_tile_idx_in_section + is_prior_tile, seqlen_k, k_start_pos, _ = _get_kv_tile_apc( + ac.is_prefix_caching, + False, + True, + atp.seqlen_k_active_updated, + ac.seqlen_k_prior, + k_tile_idx_global * _K_TILE_SZ, + None, + ) + + if atp.is_causal and not is_prior_tile: + matmul_selection = _has_any_compute_causal(qkmax_grp, k_start_pos, ac) + if ac.use_swa: + matmul_selection = matmul_selection and _has_any_compute_swa(qkmax_grp, k_start_pos, _K_TILE_SZ, ac) + else: + matmul_selection = True + + if q_seqlen_offset >= ac.seqlen_q or k_start_pos >= seqlen_k: + matmul_selection = False + + if matmul_selection and k_tile_idx_in_section < atp.num_k_tiles_per_section: + num_f = min(seqlen_k - k_start_pos, _K_TILE_SZ) + num_q_free = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + if is_prior_tile and bufs.k_sb_prior is not None: + k_tile_to_use = bufs.k_sb_prior[k_start_pos // _K_TILE_SZ] + elif bufs.k_sb_prior is not None: + k_tile_to_use = bufs.k_sb[k_start_pos // _K_TILE_SZ] + else: + k_tile_to_use = bufs.k_sb[k_tile_idx_in_section] + + q_slot = qkmax_grp // atp.num_q_grps_per_load + q_offset = (qkmax_grp % atp.num_q_grps_per_load) * _Q_GRP_SZ + nisa.nc_matmul( + mm1_psum_tile[:num_q_free, :num_f], + bufs.q_sb_lo[q_slot][:128, nl.ds(q_offset, num_q_free)], + k_tile_to_use[0][:128, :num_f], + ) + nisa.nc_matmul( + mm1_psum_tile[:num_q_free, :num_f], + bufs.q_sb_hi[q_slot][:128, nl.ds(q_offset, num_q_free)], + k_tile_to_use[1][:128, :num_f], + ) + + num_p = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + num_f = min(seqlen_k - k_start_pos, _K_TILE_SZ) + diagonal_sel_mask = ( + matmul_selection and ((qkmax_grp * _Q_GRP_SZ) < (k_start_pos + _K_TILE_SZ)) + if (atp.is_causal and not is_prior_tile and not atp.dynamic_sel_mask) + else False + ) + if ac.use_swa and atp.is_causal and not is_prior_tile: + diagonal_sel_mask = not atp.dynamic_sel_mask + + if diagonal_sel_mask: + nisa.tensor_copy(mm1_copy_sb_tile[:num_p, :num_f], mm1_psum_tile[:num_p, :num_f]) + nisa.affine_select( + mm1_affine_select_output_tile[:num_p, :num_f], + pattern=[[-1, num_f]], + offset=qkmax_grp * atp.sb_p - k_start_pos, + channel_multiplier=1, + cmp_op=nl.greater_equal, + on_true_tile=mm1_copy_sb_tile[:num_p, :num_f], + on_false_value=_FLOAT32_MIN, + ) + if ac.use_swa: + nisa.affine_select( + mm1_affine_select_output_tile[:num_p, :num_f], + pattern=[[1, num_f]], + offset=(k_start_pos + ac.sliding_window - 1 - qkmax_grp * atp.sb_p), + channel_multiplier=-1, + cmp_op=nl.greater_equal, + on_true_tile=mm1_affine_select_output_tile[:num_p, :num_f], + on_false_value=_FLOAT32_MIN, + ) + nisa.tensor_scalar_reduce( + mm1_masked_tile[:num_p, nl.ds(k_tile_idx * _K_TILE_SZ, num_f)], + data=mm1_affine_select_output_tile[:num_p, :num_f], + op0=nl.multiply, + operand0=ac.scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max_tile[:num_p, k_tile_idx_in_section], + ) + elif atp.dynamic_sel_mask or is_prior_tile: + if is_prior_tile: + bound0 = bufs.range_sel_lbs_prior[:num_p, qkmax_grp] if ac.use_swa else bufs.zero_bias_tensor + bound1 = bufs.range_sel_ubs_prior[:num_p, qkmax_grp] + comp_op1 = nl.less + elif ac.is_sequence_packed: + bound0 = bufs.range_sel_lbs[:num_p, nl.ds(qkmax_grp, 1)] + bound1 = bufs.range_sel_ubs[:num_p, nl.ds(qkmax_grp, 1)] + comp_op1 = nl.less_equal if atp.is_causal else nl.less + else: + bound0 = bufs.range_sel_lbs[:num_p, qkmax_grp] if ac.use_swa else bufs.zero_bias_tensor + bound1 = bufs.range_sel_ubs[:num_p, qkmax_grp] + comp_op1 = nl.less_equal + + kernel_assert(ac.scale == 1.0, "range_select path doesn't support scale != 1.0") + nisa.range_select( + mm1_masked_tile[:num_p, nl.ds(k_tile_idx * _K_TILE_SZ, num_f)], + on_true_tile=mm1_psum_tile[:num_p, :num_f], + on_false_value=_FLOAT32_MIN, + comp_op0=nl.greater_equal, + comp_op1=comp_op1, + bound0=bound0[:num_p, :1], + bound1=bound1[:num_p, :1], + reduce_op=_maximum, + reduce_res=mm1_partial_max_tile[:num_p, k_tile_idx_in_section], + reduce_cmd=reduce_cmd.reset_reduce, + range_start=k_start_pos, + ) + else: + nisa.tensor_scalar_reduce( + mm1_masked_tile[:num_p, nl.ds(k_tile_idx * _K_TILE_SZ, num_f)], + data=mm1_psum_tile[:num_p, :num_f], + op0=nl.multiply, + operand0=ac.scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max_tile[:num_p, k_tile_idx_in_section], + ) + + +def _run_groups(grp_start, grp_end, ac, atp, sp, bufs, q, batch_id, o, sbuf_addr, sink=None): + for grp_i in range(grp_start, grp_end): + _load_q_impl(grp_i, ac, atp, sp, bufs, q, batch_id, sbuf_addr) + _qk_and_max_impl(grp_i, ac, atp, sp, bufs, batch_id) + _update_max_impl(grp_i, ac, atp, sp, bufs, sink) + _exp_impl(grp_i, ac, atp, sp, bufs, sink) + _pv_impl_base(grp_i, ac, atp, sp, bufs) + _write_back_impl(grp_i, ac, atp, sp, bufs, o, batch_id) + + +def _run_attention_from_sbuf( + q_hbm, + k_sbuf, + v_sbuf, + out_o_hbm, + out_neg_max_hbm, + out_sum_hbm, + seqlen_q, + seqlen_k, + head_dim, + sb_p, + n_grps, + scale, + causal, + tp_q, + allocator, + sink=None, + kv_used_len=None, +): + ac, atp = _make_ac_atp(seqlen_q, seqlen_k, head_dim, q_hbm.dtype, causal, scale, tp_q, False, 2) + ac.has_kv_used_len = kv_used_len is not None + bufs = AttnInternalBuffers() + bufs.zero_bias_tensor = allocator.alloc_sbuf_tensor(shape=(sb_p, 1), dtype=nl.float32) + nisa.memset(bufs.zero_bias_tensor, 0.0) + bufs.k_scale_sb = None + bufs.mm1_running_max = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + bufs.exp_running_sum = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + bufs.exp_sum_reciprocal = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + if sink is not None: + bufs.sink_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, 1), dtype=nl.float32) + nisa.dma_copy(dst=bufs.sink_sb[0, 0], src=sink[0, 0]) + stream_shuffle_broadcast(src=bufs.sink_sb, dst=bufs.sink_sb) + _allocate_attention_buffers(allocator, ac, atp, bufs, sink, k_sbuf, v_sbuf) + _setup_range_select_bounds(ac, atp, bufs, allocator, None, None, None, None, batch_id=0, kv_used_len=kv_used_len) + sp = SectionParams( + section_idx=0, + section_offset=0, + section_offset_active=0, + next_section_offset_active=seqlen_k, + section_contains_prefix=False, + next_section_contains_prefix=False, + kv_section_idx=0, + ) + sbuf_inner = allocator.get_current_address() + _run_groups(0, n_grps, ac, atp, sp, bufs, q_hbm, 0, out_o_hbm, sbuf_inner, sink=sink) + nisa.dma_copy(dst=out_neg_max_hbm.ap(pattern=[[n_grps, sb_p], [1, n_grps]], offset=0), src=bufs.mm1_running_max) + nisa.dma_copy(dst=out_sum_hbm.ap(pattern=[[n_grps, sb_p], [1, n_grps]], offset=0), src=bufs.exp_running_sum) + + +def _nonkvp_partial_prior_attention( + q_hbm, + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + prior_block_offset, + partial_prior_tokens, + num_k_tiles_active, + num_v_tiles_active, + num_k_tiles_per_seg, + num_v_tiles_per_seg, + num_blocks_per_seg, + num_v_tiles_for_prior, + b_i, + h_i, + n_grps, + head_dim, + sb_p, + scale, + tp_q, + allocator, + attention_cte_fn, + load_kv_cache_fn, + sink=None, +): + init_sbuf_addr = allocator.get_current_address() + seqlen_q = q_hbm.shape[1] if tp_q else q_hbm.shape[2] + _run_attention_from_sbuf( + q_hbm, + k_cache_sbuf[:num_k_tiles_active], + v_cache_sbuf[:num_v_tiles_active], + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + seqlen_q, + seqlen_q, + head_dim, + sb_p, + n_grps, + scale, + True, + tp_q, + allocator, + sink=sink, + ) + allocator.set_current_address(init_sbuf_addr) + + kernel_assert(num_k_tiles_per_seg <= len(k_cache_sbuf), "k_cache_sbuf must fit prior segment") + kernel_assert(num_v_tiles_for_prior <= len(v_cache_sbuf), "v_cache_sbuf must fit prior segment") + k_prior_sbuf = k_cache_sbuf[:num_k_tiles_per_seg] + v_prior_sbuf = v_cache_sbuf[:num_v_tiles_for_prior] + load_kv_cache_fn( + k_cache, + v_cache, + block_tables, + k_prior_sbuf, + v_prior_sbuf, + b_i, + h_i, + prior_block_offset, + num_blocks_per_seg, + allocator, + ) + + call2_sbuf_addr = allocator.get_current_address() + _run_attention_from_sbuf( + q_hbm, + k_prior_sbuf, + v_prior_sbuf, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + seqlen_q, + num_k_tiles_per_seg * _K_TILE_SZ, + head_dim, + sb_p, + n_grps, + scale, + False, + tp_q, + allocator, + kv_used_len=partial_prior_tokens, + ) + allocator.set_current_address(call2_sbuf_addr) + + softmax_pat = [[n_grps, sb_p], [1, n_grps]] + o_pat = [[head_dim, sb_p], [1, head_dim]] + num_free = min(n_grps, _MAX_FREE_TILES) + neg_max_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + sum_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + neg_max_curr_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + sum_curr_sb_buf = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + o_prev_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + o_curr_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + o_new_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[n_grps], num_free_tiles=[num_free] + ) + reduce_batch_addr = allocator.get_current_address() + reduce_one_batch( + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + 0, + 0, + n_grps, + head_dim, + n_grps, + sb_p, + softmax_pat, + o_pat, + neg_max_prev_sb, + sum_prev_sb, + neg_max_curr_sb, + sum_curr_sb_buf, + o_prev_sb, + o_curr_sb, + o_new_sb, + reduce_batch_addr, + allocator, + ) + allocator.set_current_address(init_sbuf_addr) + + +def fused_segmented_attention_impl( + q_hbm, + num_batches, + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + prior_tokens_sbuf, + num_full_prior_segments_i32, + partial_prior_tokens, + is_partial_prior_segment, + is_not_partial_prior_segment, + active_block_offset, + prior_block_offset, + allocator, + prior_seg_size, + block_size, + scale, + head_dim, + num_grps, + num_active_blocks, + num_k_tiles_active, + num_v_tiles_active, + num_blocks_per_seg, + num_k_tiles_per_seg, + num_v_tiles_per_seg, + b_i=0, + h_i=0, + tp_q=True, + tp_out=False, + load_kv_cache_fn=None, + attention_cte_fn=None, + sink=None, + kvp_offset=None, + k_pre_transposed=False, + k_scale_sb=None, +): + """Fused segmented-attention impl with SBUF aliasing across active/prior passes. + + Uses kv_section_idx=0 so K/V indexing starts at tile 0 for every segment, + while section_idx controls flash attention accumulation: + - Active segment: section_idx=0 (init running stats) + - Prior segments: section_idx=1 (accumulate via _write_back_impl) + + The PV accumulation stays in float32 SBUF across all segments, matching + _attention_cte's internal flash attention precision. + """ + orig_addr = allocator.get_current_address() + + kernel_assert( + kvp_offset == None, + "qwen_segcte256 KVP mode is not production validated; use the " + "non-KVP segmented CTE path", + ) + kernel_assert( + not k_pre_transposed, + "qwen_segcte256 supports only k_pre_transposed=False; " + "the transposed-K path has not been production validated", + ) + + is_kvp = False + # KVP: compute kvp_offset_active = kvp_offset - prior_tokens_sbuf (for active segment cp_offset) + # and allocate kvp_offset_prior_sbuf/hbm for per-iteration prior segment cp_offset. + kvp_offset_active_hbm = None + kvp_offset_prior_sbuf = None + kvp_offset_prior_hbm = None + if is_kvp: + kvp_offset_active_sbuf = allocator.alloc_sbuf_tensor((1, 1), nl.int32) + kvp_offset_active_hbm = nl.ndarray((1, 1), dtype=nl.int32, buffer=nl.shared_hbm) + nisa.tensor_tensor(dst=kvp_offset_active_sbuf, data1=kvp_offset, data2=prior_tokens_sbuf, op=nl.subtract) + nisa.dma_copy(dst=kvp_offset_active_hbm, src=kvp_offset_active_sbuf) + kvp_offset_prior_sbuf = allocator.alloc_sbuf_tensor((1, 1), nl.int32) + kvp_offset_prior_hbm = nl.ndarray((1, 1), dtype=nl.int32, buffer=nl.shared_hbm) + + seqlen_q = q_hbm.shape[1] if tp_q else q_hbm.shape[2] + seqlen_k_active = seqlen_q # actual tokens, not rounded to tile boundary + seqlen_k_prior = prior_seg_size # actual tokens, not rounded to tile boundary + + # num_sections must be > 1 to enable flash attention accumulation path + max_blocks_per_seq = block_tables.shape[1] + max_prior_segments = math.ceil(max_blocks_per_seq * block_size / prior_seg_size) + total_sections = max(max_prior_segments + 1, 2) + + # Prior config: KVP uses causal=True + cp to handle shifted causal mask; non-KVP uses causal=False + ac_p, atp_p = _make_ac_atp( + seqlen_q, + seqlen_k_prior, + head_dim, + q_hbm.dtype, + is_kvp, + scale, + tp_q, + False, + total_sections, + use_cp=is_kvp, + global_cp_deg=1 if is_kvp else None, + ) + + sb_p = atp_p.sb_p + n_grps = atp_p.num_grps + + # Running buffers (persist across all segments in SBUF) + bufs = AttnInternalBuffers() + bufs.zero_bias_tensor = allocator.alloc_sbuf_tensor(shape=(sb_p, 1), dtype=nl.float32) + nisa.memset(bufs.zero_bias_tensor, 0.0) + bufs.k_scale_sb = k_scale_sb + bufs.mm1_running_max = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + bufs.exp_running_sum = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + bufs.exp_sum_reciprocal = allocator.alloc_sbuf_tensor(shape=(sb_p, n_grps), dtype=nl.float32) + + sbuf_outer = allocator.get_current_address() + + # Load sink token into SBUF if provided + if sink is not None: + bufs.sink_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, 1), dtype=nl.float32) + nisa.dma_copy(dst=bufs.sink_sb[0, 0], src=sink[0, 0]) + stream_shuffle_broadcast(src=bufs.sink_sb, dst=bufs.sink_sb) + + active_stream_tokens = min(prior_seg_size, seqlen_k_active) + kernel_assert( + active_stream_tokens % block_size == 0, + "qwen_segcte256 active streaming requires an active stream chunk divisible by block_size", + ) + num_active_stream_sections = math.ceil(seqlen_k_active / active_stream_tokens) + num_blocks_per_active_stream = active_stream_tokens // block_size + num_k_tiles_per_active_stream = math.ceil(active_stream_tokens / _K_TILE_SZ) + num_v_tiles_per_active_stream = num_k_tiles_per_active_stream * (_K_TILE_SZ // _V_TILE_SZ) + + active_stream_offset = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + active_stream_addr = allocator.get_current_address() + for active_section_idx in range(num_active_stream_sections): + if active_section_idx == 0: + nisa.tensor_copy(dst=active_stream_offset, src=active_block_offset) + else: + nisa.tensor_scalar( + dst=active_stream_offset, + data=active_block_offset, + op0=nl.add, + operand0=active_section_idx * num_blocks_per_active_stream, + ) + + load_kv_cache_fn( + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + b_i, + h_i, + active_stream_offset, + num_blocks_per_active_stream, + allocator, + k_pre_transposed=k_pre_transposed, + ) + + section_offset_active = active_section_idx * active_stream_tokens + next_section_offset_active = min(section_offset_active + active_stream_tokens, seqlen_k_active) + ac_a, atp_a = _make_ac_atp( + seqlen_q, + next_section_offset_active, + head_dim, + q_hbm.dtype, + True, + scale, + tp_q, + False, + total_sections, + ) + atp_a.section_len = active_stream_tokens + atp_a.num_large_tiles_per_section = div_ceil(active_stream_tokens, _LARGE_TILE_SZ) + atp_a.num_k_tiles_per_section = num_k_tiles_per_active_stream + atp_a.num_v_tiles_per_section = num_v_tiles_per_active_stream + sp_active = SectionParams( + section_idx=active_section_idx, + section_offset=section_offset_active, + section_offset_active=section_offset_active, + next_section_offset_active=next_section_offset_active, + section_contains_prefix=False, + next_section_contains_prefix=False, + kv_section_idx=active_section_idx, + ) + allocator.set_current_address(sbuf_outer) + _allocate_attention_buffers( + allocator, + ac_a, + atp_a, + bufs, + sink, + k_cache_sbuf[:num_k_tiles_per_active_stream], + v_cache_sbuf[:num_v_tiles_per_active_stream], + ) + _setup_range_select_bounds(ac_a, atp_a, bufs, allocator, None, None, None, None, batch_id=0) + sbuf_inner = allocator.get_current_address() + _run_groups(0, n_grps, ac_a, atp_a, sp_active, bufs, q_hbm, 0, o_prev_hbm, sbuf_inner, sink=sink) + allocator.set_current_address(active_stream_addr) + + is_partial_reg = nisa.register_alloc() + nisa.register_load(dst=is_partial_reg, src=is_partial_prior_segment) + + for _ in nl.dynamic_range(0, is_partial_reg): + load_kv_cache_fn( + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + b_i, + h_i, + prior_block_offset, + num_blocks_per_seg, + allocator, + k_pre_transposed=k_pre_transposed, + ) + allocator.set_current_address(sbuf_outer) + ac_partial, atp_partial = _make_ac_atp( + seqlen_q, + seqlen_k_prior, + head_dim, + q_hbm.dtype, + False, + scale, + tp_q, + False, + total_sections, + ) + ac_partial.has_kv_used_len = True + atp_partial.dynamic_sel_mask = True + _allocate_attention_buffers( + allocator, + ac_partial, + atp_partial, + bufs, + sink, + k_cache_sbuf[:num_k_tiles_per_seg], + v_cache_sbuf[:num_v_tiles_per_seg], + ) + _setup_range_select_bounds( + ac_partial, + atp_partial, + bufs, + allocator, + None, + None, + None, + None, + batch_id=0, + kv_used_len=partial_prior_tokens, + ) + sp_partial = SectionParams( + section_idx=1, + section_offset=0, + section_offset_active=0, + next_section_offset_active=seqlen_k_prior, + section_contains_prefix=False, + next_section_contains_prefix=False, + kv_section_idx=0, + ) + sbuf_inner_partial = allocator.get_current_address() + _run_groups( + 0, + n_grps, + ac_partial, + atp_partial, + sp_partial, + bufs, + q_hbm, + 0, + o_prev_hbm, + sbuf_inner_partial, + sink=sink, + ) + allocator.set_current_address(active_stream_addr) + + sm_pat = [[num_grps, sb_p], [1, num_grps]] + + # --- PRIOR SEGMENTS (section_idx=1, kv_section_idx=0, dynamic loop) --- + # section_idx=1 triggers accumulation: _write_back_impl loads prev output from o_prev_hbm, + # applies correction factor, adds fresh PV, writes back. Running stats update in SBUF. + # kv_section_idx=0 ensures K/V indexing starts at tile 0 (each segment's own SBUF data). + sp_prior = SectionParams( + section_idx=1, + section_offset=0, + section_offset_active=0, + next_section_offset_active=seqlen_k_prior, + section_contains_prefix=False, + next_section_contains_prefix=False, + kv_section_idx=0, + ) + + prior_offset_save = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.tensor_copy(dst=prior_offset_save, src=prior_block_offset) + + num_prior_reg = nisa.register_alloc() + nisa.register_load(dst=num_prior_reg, src=num_full_prior_segments_i32) + + loop_addr = allocator.get_current_address() + + for _ in nl.dynamic_range(0, num_prior_reg): + nisa.tensor_scalar( + dst=prior_block_offset, data=prior_block_offset, op0=nl.subtract, operand0=num_blocks_per_seg + ) + + load_kv_cache_fn( + k_cache, + v_cache, + block_tables, + k_cache_sbuf, + v_cache_sbuf, + b_i, + h_i, + prior_block_offset, + num_blocks_per_seg, + allocator, + k_pre_transposed=k_pre_transposed, + ) + + allocator.set_current_address(sbuf_outer) + _allocate_attention_buffers(allocator, ac_p, atp_p, bufs, sink, k_cache_sbuf, v_cache_sbuf) + + # KVP: compute kvp_offset_prior = kvp_offset - prior_block_offset * block_size + if is_kvp: + nisa.tensor_scalar(dst=kvp_offset_prior_sbuf, data=prior_block_offset, op0=nl.multiply, operand0=block_size) + nisa.tensor_tensor(dst=kvp_offset_prior_sbuf, data1=kvp_offset, data2=kvp_offset_prior_sbuf, op=nl.subtract) + nisa.dma_copy(dst=kvp_offset_prior_hbm, src=kvp_offset_prior_sbuf) + + prior_cp_offset = kvp_offset_prior_hbm if is_kvp else None + + if is_kvp: + # KVP: use attention_cte_fn with cp_offset, then reduce into accumulated output + init_sbuf_addr = allocator.get_current_address() + attention_cte_fn( + q_hbm, + None, + None, + scale=scale, + causal_mask=True, + tp_q=tp_q, + tp_k=False, + tp_out=False, + cache_softmax=True, + skip_output_normalization=True, + k_cache_sbuf=k_cache_sbuf[:num_k_tiles_per_seg], + v_cache_sbuf=v_cache_sbuf[:num_v_tiles_per_seg], + out_o_hbm=o_curr_hbm, + out_neg_max_hbm=neg_max_curr_hbm, + out_sum_hbm=sum_curr_hbm, + init_sbuf_addr=init_sbuf_addr, + cp_offset=prior_cp_offset, + global_cp_deg=1, + k_scale_sb=k_scale_sb, + ) + allocator.set_current_address(init_sbuf_addr) + + # Reduce current segment into accumulated output + softmax_pat = [[num_grps, sb_p], [1, num_grps]] + o_pat = [[head_dim, sb_p], [1, head_dim]] + num_free = min(num_grps, _MAX_FREE_TILES) + neg_max_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + sum_prev_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + neg_max_curr_sb = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + sum_curr_sb_buf = allocator.alloc_sbuf_tensor(shape=(sb_p, num_grps), dtype=nl.float32) + o_prev_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[num_grps], num_free_tiles=[num_free] + ) + o_curr_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[num_grps], num_free_tiles=[num_free] + ) + o_new_sb = allocator.alloc_sbuf_tensor( + shape=(sb_p, head_dim), dtype=nl.float32, block_dim=[num_grps], num_free_tiles=[num_free] + ) + batch_loop_addr = allocator.get_current_address() + reduce_one_batch( + o_prev_hbm, + neg_max_prev_hbm, + sum_prev_hbm, + o_curr_hbm, + neg_max_curr_hbm, + sum_curr_hbm, + 0, + 0, + num_grps, + head_dim, + num_grps, + sb_p, + softmax_pat, + o_pat, + neg_max_prev_sb, + sum_prev_sb, + neg_max_curr_sb, + sum_curr_sb_buf, + o_prev_sb, + o_curr_sb, + o_new_sb, + batch_loop_addr, + allocator, + ) + # Reload updated stats into SBUF running buffers + nisa.dma_copy(dst=bufs.mm1_running_max, src=neg_max_prev_hbm.ap(pattern=softmax_pat, offset=0)) + nisa.dma_copy(dst=bufs.exp_running_sum, src=sum_prev_hbm.ap(pattern=softmax_pat, offset=0)) + else: + _setup_range_select_bounds(ac_p, atp_p, bufs, allocator, None, None, None, None, batch_id=0) + sbuf_inner_p = allocator.get_current_address() + _run_groups(0, n_grps, ac_p, atp_p, sp_prior, bufs, q_hbm, 0, o_prev_hbm, sbuf_inner_p) + + allocator.set_current_address(loop_addr) + + # Restore + nisa.tensor_copy(dst=prior_block_offset, src=prior_offset_save) + + # Write running stats to HBM for caller's normalization + sm_pat = [[num_grps, sb_p], [1, num_grps]] + nisa.dma_copy(dst=neg_max_prev_hbm.ap(pattern=sm_pat, offset=0), src=bufs.mm1_running_max) + nisa.dma_copy(dst=sum_prev_hbm.ap(pattern=sm_pat, offset=0), src=bufs.exp_running_sum) + + allocator.set_current_address(orig_addr) diff --git a/src/neuronx_distributed_inference/modules/attention/utils.py b/src/neuronx_distributed_inference/modules/attention/utils.py index b83da8b4..cfe28726 100644 --- a/src/neuronx_distributed_inference/modules/attention/utils.py +++ b/src/neuronx_distributed_inference/modules/attention/utils.py @@ -83,6 +83,10 @@ def pad_to_128_multiple(x, dim, tensor_grp_size=None): quantized_weight_cache = {} +def _is_row_parallel_quantized_projection(prefix: str) -> bool: + return "down_proj" in prefix or "o_proj" in prefix + + def _get_weight_from_state_dict_quantized(prefix: str, state_dict: Dict[str, Any], tensor_grp_size: Optional[int] = None) -> torch.Tensor: """ Get weight from state dict with quantization support. @@ -106,7 +110,7 @@ def _get_weight_from_state_dict_quantized(prefix: str, state_dict: Dict[str, Any assert ( quantized_tensor.dtype == torch.float8_e4m3fn ), "Expected weight type to be float8_e4m3fn" - dim = 0 if "down_proj" in prefix else 1 + dim = 0 if _is_row_parallel_quantized_projection(prefix) else 1 quantized_tensor = pad_to_128_multiple(quantized_tensor.view(torch.int8).t(), dim, tensor_grp_size) quantized_tensor = quantized_tensor.view(torch.float8_e4m3fn) quantized_tensor = quantized_tensor.contiguous() @@ -151,7 +155,7 @@ def _get_scale_from_state_dict_quantized(prefix: str, state_dict: Dict[str, Any] # transpose --> [1, H] # broadcast --> [128, H] scale = state_dict[prefix + "scale"] - if "down_proj" not in prefix: + if not _is_row_parallel_quantized_projection(prefix): scale = pad_to_128_multiple(scale, 0, tensor_grp_size) scale = scale.t() scale = torch.broadcast_to(scale, (128, scale.shape[1])) diff --git a/src/neuronx_distributed_inference/modules/autobucketing.py b/src/neuronx_distributed_inference/modules/autobucketing.py index 81d4f8e2..759c77dc 100644 --- a/src/neuronx_distributed_inference/modules/autobucketing.py +++ b/src/neuronx_distributed_inference/modules/autobucketing.py @@ -159,6 +159,12 @@ def generate_buckets_for_cte(inference_config: InferenceConfig): if inference_config.neuron_config.is_chunked_prefill: return generate_buckets_for_chunked_prefill_cte(inference_config) + if ( + inference_config.neuron_config.is_prefix_caching + and inference_config.neuron_config.context_encoding_bucket_pairs is not None + ): + return inference_config.neuron_config.context_encoding_bucket_pairs + if not inference_config.neuron_config.enable_bucketing: if inference_config.neuron_config.is_prefix_caching: buckets = generate_2d_buckets_for_prefix_caching( diff --git a/src/neuronx_distributed_inference/modules/generation/sampling.py b/src/neuronx_distributed_inference/modules/generation/sampling.py index 540eb5a1..41a63631 100644 --- a/src/neuronx_distributed_inference/modules/generation/sampling.py +++ b/src/neuronx_distributed_inference/modules/generation/sampling.py @@ -369,24 +369,42 @@ def _multinomial(self, probs, dim, num_samples=1): counts = torch.sum(greater_than_rand, dim=dim).unsqueeze(dim) return counts - def _argmax_sample(self, token_logits, return_values, dim): - if self.neuron_config.on_cpu: - return torch.argmax(token_logits, dim=dim) - else: - # distributed argmax - tokens = nxd_argmax( - tensor=token_logits, - dim=dim, - gather_dim=dim, - keepdim=False, - process_group=self.process_group, - disable_argmax_kernel=self.neuron_config.disable_argmax_kernel - ) - values = torch.ones(tokens.shape, dtype=token_logits.dtype, device=tokens.device) + def _argmax_sample( + self, + token_logits, + return_values, + dim, + disable_argmax_kernel_override=None, + ): + if self.neuron_config.on_cpu or not getattr( + self.neuron_config, "vocab_parallel", False + ): + tokens = torch.argmax(token_logits, dim=dim).to(torch.int32) if return_values: + values = torch.ones( + tokens.shape, dtype=token_logits.dtype, device=tokens.device + ) return tokens, values return tokens + # Distributed argmax is only needed when the vocab dimension is sharded. + tokens = nxd_argmax( + tensor=token_logits, + dim=dim, + gather_dim=dim, + keepdim=False, + process_group=self.process_group, + disable_argmax_kernel=( + self.neuron_config.disable_argmax_kernel + if disable_argmax_kernel_override is None + else disable_argmax_kernel_override + ), + ) + values = torch.ones(tokens.shape, dtype=token_logits.dtype, device=tokens.device) + if return_values: + return tokens, values + return tokens + def _multinomial_sample(self, token_logits, sampling_params, return_values, dim, rank_id): batch_size = token_logits.shape[0] top_k = sampling_params[:, 0] @@ -432,7 +450,14 @@ def _multinomial_sample(self, token_logits, sampling_params, return_values, dim, counts = self._multinomial(probs_soft_max, dim) return torch.gather(input=top_k_logits_indices, dim=dim, index=counts).flatten() - def forward(self, token_logits, sampling_params, return_values=False, rank_id=None): + def forward( + self, + token_logits, + sampling_params, + return_values=False, + rank_id=None, + disable_argmax_kernel_override=None, + ): """ forward to perform topk, topp, temperature and multinomial sampling. @@ -461,7 +486,12 @@ def forward(self, token_logits, sampling_params, return_values=False, rank_id=No token_logits, sampling_params, return_values, dim, rank_id ) else: - return self._argmax_sample(token_logits, return_values, dim) + return self._argmax_sample( + token_logits, + return_values, + dim, + disable_argmax_kernel_override=disable_argmax_kernel_override, + ) class DataParallelSampler(Sampler): @@ -564,9 +594,22 @@ def _top_k_masked(self, logits, top_k, dim, rank_id): sorted_logits = sorted_logits.masked_fill_(mask, self.IGNORED_LOGITS_VALUE) return sorted_logits, indices - def forward(self, token_logits, sampling_params, return_values=False, rank_id=None): + def forward( + self, + token_logits, + sampling_params, + return_values=False, + rank_id=None, + disable_argmax_kernel_override=None, + ): # Override forward to handle final gathering - results = super().forward(token_logits, sampling_params, return_values, rank_id) + results = super().forward( + token_logits, + sampling_params, + return_values, + rank_id, + disable_argmax_kernel_override=disable_argmax_kernel_override, + ) if return_values: top_k_logits_indices, probs_soft_max = results[0], results[1] if self.do_sample or self.dynamic or self.is_medusa: diff --git a/src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py b/src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py index 8c3539a4..a0de59cd 100644 --- a/src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py +++ b/src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py @@ -131,10 +131,6 @@ def _fetch_cache(self, idx: int, kvcache_buffer=None): def get_kv_by_layer_id(self, idx, active_block_table, kvcache_buffer=None, **kwargs): k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer=kvcache_buffer) - if self.kv_quant_config: - k_cache = self._dequantize_cache(k_cache, idx, is_key=True) - v_cache = self._dequantize_cache(v_cache, idx, is_key=False) - if self.is_prefix_caching: key_state = self._get_block_cache_and_reshape_bhsd(k_cache, active_block_table) value_state = self._get_block_cache_and_reshape_bhsd(v_cache, active_block_table) @@ -145,8 +141,32 @@ def get_kv_by_layer_id(self, idx, active_block_table, kvcache_buffer=None, **kwa else: raise ValueError("Can't find a proper way to read block KV cache.") + if self.kv_quant_config: + key_state = self._dequantize_cache(key_state, idx, is_key=True) + value_state = self._dequantize_cache(value_state, idx, is_key=False) + return key_state, value_state + @staticmethod + def _safe_active_block_table(active_block_table: Tensor, num_blocks: int) -> Tensor: + """Map padded/invalid block ids to the reserved padding block.""" + pad_block_id = torch.full_like(active_block_table, num_blocks - 1) + valid_block = torch.logical_and( + active_block_table >= 0, + active_block_table < num_blocks, + ) + return torch.where(valid_block, active_block_table, pad_block_id) + + def get_raw_kv_by_layer_id(self, idx, kvcache_buffer=None, **kwargs): + """Return the block-layout KV cache without flattening through a block table.""" + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer=kvcache_buffer) + + if self.kv_quant_config: + k_cache = self._dequantize_cache(k_cache, idx, is_key=True) + v_cache = self._dequantize_cache(v_cache, idx, is_key=False) + + return k_cache, v_cache + def _get_block_cache_and_reshape_bhsd(self, cache: Tensor, active_block_table: Tensor): """ Reorder the cache based on the table indices from active_block_table, and return @@ -165,14 +185,27 @@ def _get_block_cache_and_reshape_bhsd(self, cache: Tensor, active_block_table: T batch_size, _ = active_block_table.shape if self.block_tiling: - _, _, num_block_tiles, num_heads_per_rank, head_dimension = cache.shape + num_blocks, _, num_block_tiles, num_heads_per_rank, head_dimension = cache.shape + active_block_table = self._safe_active_block_table( + active_block_table, + num_blocks, + ) cache_reshaped = cache.reshape(-1, num_block_tiles, num_heads_per_rank, head_dimension) index_array = active_block_table.reshape(-1) * self.block_tiling_factor - index_array = index_array.unsqueeze(-1) + torch.arange(self.block_tiling_factor) + index_array = index_array.unsqueeze(-1) + torch.arange( + self.block_tiling_factor, + device=active_block_table.device, + dtype=active_block_table.dtype, + ) selected_cache = cache_reshaped.index_select( dim=0, index=index_array.reshape(-1) ).reshape(batch_size, -1, num_heads_per_rank, head_dimension) else: + num_blocks = cache.shape[0] + active_block_table = self._safe_active_block_table( + active_block_table, + num_blocks, + ) selected_cache = cache.index_select( dim=0, index=active_block_table.reshape(-1) ).reshape(batch_size, -1, num_heads_per_rank, head_dimension) @@ -200,10 +233,18 @@ def _get_cache_for_chunked_prefill( # TKG usecase batch_size, _ = active_block_table.shape num_blocks, num_heads_per_rank, block_size, head_dimension = cache.shape + active_block_table = self._safe_active_block_table( + active_block_table, + num_blocks, + ) cache = cache.reshape(num_blocks * num_heads_per_rank, block_size * head_dimension) - indices = torch.arange(num_heads_per_rank).reshape(1, -1, 1) \ + indices = torch.arange( + num_heads_per_rank, + dtype=active_block_table.dtype, + device=active_block_table.device, + ).reshape(1, -1, 1) \ + active_block_table.reshape(batch_size, 1, -1) * num_heads_per_rank indices = indices.reshape(-1) @@ -309,11 +350,9 @@ def _update_cache_with_reshape(self, latest, cache, slot_mapping, padding_id=-1) else: pad_dest_index = torch.tensor((num_blocks - 1) * block_size, device=device, dtype=dtype) - slot_mapping = torch.where( - slot_mapping == padding_id, - pad_dest_index, - slot_mapping, - ) + max_slot_index = torch.tensor(cache.shape[0], device=device, dtype=dtype) + valid_slot = torch.logical_and(slot_mapping >= 0, slot_mapping < max_slot_index) + slot_mapping = torch.where(valid_slot, slot_mapping, pad_dest_index) slot_mapping = slot_mapping.expand( (batch_size * n_active_tokens, num_heads_per_rank * head_dim) ) @@ -347,11 +386,9 @@ def _update_cache_with_index_put( pad_dest_index = torch.tensor(num_blocks * block_size - 1, device=device, dtype=dtype) - slot_mapping = torch.where( - slot_mapping == padding_id, - pad_dest_index, - slot_mapping, - ) + max_slot_index = torch.tensor(num_blocks * block_size, device=device, dtype=dtype) + valid_slot = torch.logical_and(slot_mapping >= 0, slot_mapping < max_slot_index) + slot_mapping = torch.where(valid_slot, slot_mapping, pad_dest_index) block_id = slot_mapping // self.pa_block_size block_id = block_id.view(batch_size, 1, n_active_tokens) @@ -380,21 +417,34 @@ def generate_tokengen_slot_mapping( block_size: torch.Tensor, ): B = position_ids.shape[0] + if block_table.shape[0] == 0 or block_table.shape[1] == 0: + return torch.ones_like(slot_mapping) * -1 # Determine active sequences from slot mapping -1 pad active_mask = (slot_mapping >= 0) row_indices = torch.arange(B, dtype=position_ids.dtype, device=position_ids.device) + safe_row_indices = row_indices.clamp(min=0, max=block_table.shape[0] - 1) block_indices = (position_ids // block_size).squeeze(dim=1) - - block_number = block_table[row_indices, block_indices] + valid_block_index = torch.logical_and( + block_indices >= 0, + block_indices < block_table.shape[1], + ) + safe_block_indices = block_indices.clamp(min=0, max=block_table.shape[1] - 1) + + block_number = block_table[safe_row_indices, safe_block_indices] + valid_block_number = block_number >= 0 block_offset = (position_ids % block_size).squeeze(dim=1) cur_slots = block_size * block_number + block_offset cur_slots = cur_slots.unsqueeze(dim=1) - # Mask out inactive sequences + # Mask out inactive/padded rows before any invalid table entry can become a slot. inactive_slots = torch.ones_like(cur_slots) * -1 - final_slots = torch.where(active_mask, cur_slots, inactive_slots) + valid_generated_slot = torch.logical_and( + active_mask, + torch.logical_and(valid_block_index.unsqueeze(dim=1), valid_block_number.unsqueeze(dim=1)), + ) + final_slots = torch.where(valid_generated_slot, cur_slots, inactive_slots) return final_slots @@ -407,6 +457,8 @@ def generate_fusedspec_slot_mapping( ): B = position_ids.shape[0] speculation_length = slot_mapping.shape[1] + if block_table.shape[0] == 0 or block_table.shape[1] == 0: + return torch.ones_like(slot_mapping) * -1 # Determine active sequences from slot mapping -1 pad active_mask = ~torch.all(slot_mapping < 0, dim=1).unsqueeze(dim=1) @@ -417,15 +469,30 @@ def generate_fusedspec_slot_mapping( expanded_positions = position_ids + relative_speculative_positions row_indices = torch.arange(B, dtype=position_ids.dtype, device=position_ids.device).unsqueeze(dim=1) + safe_row_indices = row_indices.clamp(min=0, max=block_table.shape[0] - 1) expanded_row_indices = torch.tile(row_indices, (1, speculation_length)) + expanded_safe_row_indices = torch.tile(safe_row_indices, (1, speculation_length)) expanded_block_indices = (expanded_positions // block_size) - block_number = block_table[expanded_row_indices, expanded_block_indices] + valid_block_index = torch.logical_and( + expanded_block_indices >= 0, + expanded_block_indices < block_table.shape[1], + ) + expanded_safe_block_indices = expanded_block_indices.clamp( + min=0, + max=block_table.shape[1] - 1, + ) + block_number = block_table[expanded_safe_row_indices, expanded_safe_block_indices] + valid_block_number = block_number >= 0 block_offset = (expanded_positions % block_size) cur_slots = block_size * block_number + block_offset - # Mask out inactive sequences + # Mask out inactive/padded rows before any invalid table entry can become a slot. inactive_slots = torch.ones_like(cur_slots) * -1 - final_slots = torch.where(expanded_active_mask, cur_slots, inactive_slots) + valid_generated_slot = torch.logical_and( + expanded_active_mask, + torch.logical_and(valid_block_index, valid_block_number), + ) + final_slots = torch.where(valid_generated_slot, cur_slots, inactive_slots) return final_slots diff --git a/src/neuronx_distributed_inference/modules/kvcache/hybrid_prefix_cache.py b/src/neuronx_distributed_inference/modules/kvcache/hybrid_prefix_cache.py new file mode 100644 index 00000000..94ebff84 --- /dev/null +++ b/src/neuronx_distributed_inference/modules/kvcache/hybrid_prefix_cache.py @@ -0,0 +1,238 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Hybrid prefix-boundary checkpoint cache primitives. + +Attention KV can remain in the normal vLLM/NxDI block cache. GDN recurrent and +conv state should not be cached as ordinary per-block data; it is a checkpoint +for the cumulative prefix at a reusable boundary. +""" + +from __future__ import annotations + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Hashable, Mapping + + +@dataclass(frozen=True) +class HybridPrefixCheckpointKey: + cumulative_prefix_hash: Hashable + prefix_len: int + cache_salt: Hashable | None = None + model_revision: Hashable | None = None + layout_version: int = 1 + + +@dataclass +class HybridPrefixCheckpoint: + key: HybridPrefixCheckpointKey + recurrent_states: Mapping[int, Any] + conv_states: Mapping[int, Any] + ref_count: int = 0 + last_access_tick: int = 0 + + def has_all_required_state(self, required_gdn_layers: tuple[int, ...]) -> bool: + return all( + layer_id in self.recurrent_states and layer_id in self.conv_states + for layer_id in required_gdn_layers + ) + + +@dataclass(frozen=True) +class HybridPrefixReusePlan: + attention_hit_len: int + gdn_checkpoint_hit_len: int + restore_checkpoint_prefix_len: int + residual_replay_len: int + suffix_len: int + checkpoint_key: HybridPrefixCheckpointKey | None + + +class HybridPrefixCheckpointCache: + """LRU/refcount cache for cumulative-prefix GDN checkpoints.""" + + def __init__( + self, + *, + required_gdn_layers: list[int] | tuple[int, ...], + checkpoint_interval: int, + max_checkpoints: int | None = None, + layout_version: int = 1, + ): + if not required_gdn_layers: + raise ValueError("required_gdn_layers must not be empty") + self.required_gdn_layers = tuple(int(layer) for layer in required_gdn_layers) + self.checkpoint_interval = int(checkpoint_interval) + if self.checkpoint_interval <= 0: + raise ValueError( + f"checkpoint_interval must be positive, got {checkpoint_interval}" + ) + self.max_checkpoints = max_checkpoints + if self.max_checkpoints is not None and self.max_checkpoints <= 0: + raise ValueError(f"max_checkpoints must be positive, got {max_checkpoints}") + self.layout_version = int(layout_version) + self._checkpoints: OrderedDict[ + HybridPrefixCheckpointKey, HybridPrefixCheckpoint + ] = OrderedDict() + self._tick = 0 + + def __len__(self) -> int: + return len(self._checkpoints) + + def _next_tick(self) -> int: + self._tick += 1 + return self._tick + + def make_key( + self, + *, + cumulative_prefix_hash: Hashable, + prefix_len: int, + cache_salt: Hashable | None = None, + model_revision: Hashable | None = None, + layout_version: int | None = None, + ) -> HybridPrefixCheckpointKey: + prefix_len = int(prefix_len) + if prefix_len < 0: + raise ValueError(f"prefix_len must be non-negative, got {prefix_len}") + if prefix_len % self.checkpoint_interval != 0: + raise ValueError( + "prefix_len must align to checkpoint_interval " + f"{self.checkpoint_interval}, got {prefix_len}" + ) + return HybridPrefixCheckpointKey( + cumulative_prefix_hash=cumulative_prefix_hash, + prefix_len=prefix_len, + cache_salt=cache_salt, + model_revision=model_revision, + layout_version=self.layout_version + if layout_version is None + else int(layout_version), + ) + + def put_checkpoint( + self, + *, + cumulative_prefix_hash: Hashable, + prefix_len: int, + recurrent_states: Mapping[int, Any], + conv_states: Mapping[int, Any], + cache_salt: Hashable | None = None, + model_revision: Hashable | None = None, + layout_version: int | None = None, + ) -> HybridPrefixCheckpointKey: + key = self.make_key( + cumulative_prefix_hash=cumulative_prefix_hash, + prefix_len=prefix_len, + cache_salt=cache_salt, + model_revision=model_revision, + layout_version=layout_version, + ) + checkpoint = HybridPrefixCheckpoint( + key=key, + recurrent_states=dict(recurrent_states), + conv_states=dict(conv_states), + last_access_tick=self._next_tick(), + ) + if not checkpoint.has_all_required_state(self.required_gdn_layers): + raise ValueError( + "checkpoint must include recurrent and conv state for every " + f"required GDN layer: {self.required_gdn_layers}" + ) + self._checkpoints[key] = checkpoint + self._checkpoints.move_to_end(key) + self.evict_to_capacity() + return key + + def get_checkpoint( + self, + key: HybridPrefixCheckpointKey, + ) -> HybridPrefixCheckpoint | None: + checkpoint = self._checkpoints.get(key) + if checkpoint is None: + return None + checkpoint.last_access_tick = self._next_tick() + self._checkpoints.move_to_end(key) + return checkpoint + + def inc_ref(self, key: HybridPrefixCheckpointKey) -> int: + checkpoint = self.get_checkpoint(key) + if checkpoint is None: + raise KeyError(key) + checkpoint.ref_count += 1 + return checkpoint.ref_count + + def dec_ref(self, key: HybridPrefixCheckpointKey) -> int: + checkpoint = self.get_checkpoint(key) + if checkpoint is None: + raise KeyError(key) + checkpoint.ref_count = max(0, checkpoint.ref_count - 1) + return checkpoint.ref_count + + def evict_to_capacity(self) -> list[HybridPrefixCheckpointKey]: + if self.max_checkpoints is None: + return [] + evicted: list[HybridPrefixCheckpointKey] = [] + for key, checkpoint in list(self._checkpoints.items()): + if len(self._checkpoints) <= self.max_checkpoints: + break + if checkpoint.ref_count > 0: + continue + del self._checkpoints[key] + evicted.append(key) + return evicted + + def compute_reuse_plan( + self, + *, + cumulative_hashes_by_prefix_len: Mapping[int, Hashable], + attention_hit_len: int, + request_prefix_len: int, + cache_salt: Hashable | None = None, + model_revision: Hashable | None = None, + layout_version: int | None = None, + ) -> HybridPrefixReusePlan: + attention_hit_len = max(0, int(attention_hit_len)) + request_prefix_len = max(0, int(request_prefix_len)) + target_suffix_start = min(attention_hit_len, request_prefix_len) + + candidate_prefix_lens = sorted( + ( + int(prefix_len) + for prefix_len in cumulative_hashes_by_prefix_len + if int(prefix_len) <= target_suffix_start + and int(prefix_len) % self.checkpoint_interval == 0 + ), + reverse=True, + ) + for prefix_len in candidate_prefix_lens: + key = self.make_key( + cumulative_prefix_hash=cumulative_hashes_by_prefix_len[prefix_len], + prefix_len=prefix_len, + cache_salt=cache_salt, + model_revision=model_revision, + layout_version=layout_version, + ) + checkpoint = self.get_checkpoint(key) + if checkpoint is None: + continue + if not checkpoint.has_all_required_state(self.required_gdn_layers): + continue + return HybridPrefixReusePlan( + attention_hit_len=attention_hit_len, + gdn_checkpoint_hit_len=prefix_len, + restore_checkpoint_prefix_len=prefix_len, + residual_replay_len=target_suffix_start - prefix_len, + suffix_len=request_prefix_len - target_suffix_start, + checkpoint_key=key, + ) + + return HybridPrefixReusePlan( + attention_hit_len=attention_hit_len, + gdn_checkpoint_hit_len=0, + restore_checkpoint_prefix_len=0, + residual_replay_len=target_suffix_start, + suffix_len=request_prefix_len - target_suffix_start, + checkpoint_key=None, + ) diff --git a/src/neuronx_distributed_inference/modules/kvcache/utils.py b/src/neuronx_distributed_inference/modules/kvcache/utils.py index 54f72f76..65d5a427 100644 --- a/src/neuronx_distributed_inference/modules/kvcache/utils.py +++ b/src/neuronx_distributed_inference/modules/kvcache/utils.py @@ -549,3 +549,71 @@ def get_kv_shapes(max_len: int, bsz: int, num_kv_heads_per_rank: int, head_dim: max_len, ) return k_shape, v_shape + + +def floor_to_block_boundary(length: int, block_size: int) -> int: + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}") + return max(0, int(length)) // int(block_size) * int(block_size) + + +def get_cumulative_prefix_hash_at_boundary( + cumulative_hashes_by_prefix_len, + prefix_len: int, + block_size: int, +): + boundary = floor_to_block_boundary(prefix_len, block_size) + if boundary not in cumulative_hashes_by_prefix_len: + raise KeyError(f"missing cumulative prefix hash for boundary {boundary}") + return cumulative_hashes_by_prefix_len[boundary] + + +def validate_active_block_table_for_prefix( + active_block_table: Tensor, + prefix_len: int, + block_size: int, +) -> bool: + boundary = floor_to_block_boundary(prefix_len, block_size) + expected_blocks = boundary // int(block_size) + if active_block_table.numel() < expected_blocks: + raise ValueError( + f"active_block_table has {active_block_table.numel()} blocks, " + f"expected at least {expected_blocks} for prefix {boundary}" + ) + return True + + +def make_hybrid_restore_inputs( + checkpoint_slot_ids: Tensor, + restore_prefix_lens: Tensor, + *, + block_size: int, +): + if checkpoint_slot_ids.shape != restore_prefix_lens.shape: + raise ValueError("checkpoint_slot_ids and restore_prefix_lens must match") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}") + restore_mask = restore_prefix_lens > 0 + if torch.any(restore_prefix_lens[restore_mask] % int(block_size) != 0): + raise ValueError("restore_prefix_lens must be block aligned") + return ( + checkpoint_slot_ids.to(torch.int32), + restore_mask.to(torch.int32), + restore_prefix_lens.to(torch.int32), + ) + + +def make_hybrid_commit_inputs( + checkpoint_slot_ids: Tensor, + commit_prefix_lens: Tensor, + *, + block_size: int, +): + if checkpoint_slot_ids.shape != commit_prefix_lens.shape: + raise ValueError("checkpoint_slot_ids and commit_prefix_lens must match") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}") + commit_mask = commit_prefix_lens > 0 + if torch.any(commit_prefix_lens[commit_mask] % int(block_size) != 0): + raise ValueError("commit_prefix_lens must be block aligned") + return checkpoint_slot_ids.to(torch.int32), commit_mask.to(torch.int32) diff --git a/src/neuronx_distributed_inference/modules/sliding_window/attention.py b/src/neuronx_distributed_inference/modules/sliding_window/attention.py index 3a9a8d2c..a82fd4e8 100644 --- a/src/neuronx_distributed_inference/modules/sliding_window/attention.py +++ b/src/neuronx_distributed_inference/modules/sliding_window/attention.py @@ -35,7 +35,7 @@ class FlashConfig: windowed_context_encoding: bool = False # if True, uses offset-ed mask for WCTE window = sliding window -@nki.jit(mode="trace") +@nki.jit def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ, use_dma_transpose=False): for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): # Temporarily disable use_dma_tranpose by default until we stablized it @@ -58,7 +58,7 @@ def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ, use_dma_transp ) -@nki.jit(mode="trace") +@nki.jit def _flash_attention_core( q_local_tile, k, @@ -207,7 +207,7 @@ def _flash_attention_core( l_buffer[:, 0] = nl.add(m_current, nisa.activation(nl.log, exp, bias=ps)) -@nki.jit(mode="trace") +@nki.jit def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): LARGE_TILE_SZ = config.seq_tile_size B_P_SIZE = 128 diff --git a/src/neuronx_distributed_inference/utils/constants.py b/src/neuronx_distributed_inference/utils/constants.py index eb2e647c..6c0912c8 100644 --- a/src/neuronx_distributed_inference/utils/constants.py +++ b/src/neuronx_distributed_inference/utils/constants.py @@ -66,3 +66,18 @@ "qwen3_vl": {"causal-lm": NeuronQwen3VLForCausalLM, "image-encoding": NeuronQwen3VLForImageEncoding}, } + +# QWEN36_CONTRIB_VLLM_REGISTER_BEGIN +# Registered by contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh. +# Requires PYTHONPATH to include the Qwen3.6-27B contrib directory at runtime. +try: + from src.modeling_qwen35 import ( + NeuronQwen35ForCausalLM as _Qwen36ContribForCausalLM, + ) +except Exception: + _Qwen36ContribForCausalLM = None + +if _Qwen36ContribForCausalLM is not None: + MODEL_TYPES.setdefault("qwen3_5", {})["causal-lm"] = _Qwen36ContribForCausalLM + MODEL_TYPES.setdefault("qwen3_5_text", {})["causal-lm"] = _Qwen36ContribForCausalLM +# QWEN36_CONTRIB_VLLM_REGISTER_END diff --git a/src/neuronx_distributed_inference/utils/hf_adapter.py b/src/neuronx_distributed_inference/utils/hf_adapter.py index 6b81b5b4..c789409f 100644 --- a/src/neuronx_distributed_inference/utils/hf_adapter.py +++ b/src/neuronx_distributed_inference/utils/hf_adapter.py @@ -199,8 +199,26 @@ def _sample( # forward pass to get next token outputs = self(**model_inputs, return_dict=True) - if outputs.logits is not None: - next_token_logits = outputs.logits[:, -1, :].clone() + if outputs.logits is not None and ( + not self.on_device_sampling or output_scores or output_logits + ): + logits = outputs.logits + if isinstance(logits, (list, tuple)): + logits = next( + ( + tensor + for tensor in logits + if hasattr(tensor, "ndim") + and tensor.ndim >= 2 + and torch.is_floating_point(tensor) + ), + logits[0], + ) + next_token_logits = ( + logits[:, -1, :].clone() + if logits.ndim >= 3 + else logits.clone() + ) # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) @@ -228,6 +246,18 @@ def _sample( next_tokens = self.sampler(next_token_scores, sampling_params) else: next_tokens = outputs.tokens + if isinstance(next_tokens, (list, tuple)): + next_tokens = next( + ( + tensor + for tensor in next_tokens + if hasattr(tensor, "ndim") + and not torch.is_floating_point(tensor) + ), + next_tokens[0], + ) + if hasattr(next_tokens, "ndim") and next_tokens.ndim > 1: + next_tokens = next_tokens.reshape(next_tokens.shape[0], -1)[:, -1] # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: @@ -311,10 +341,11 @@ def prepare_inputs_for_generation( "medusa_args": (accepted_indices, current_length, medusa_mask, scatter_index), "sampling_params": sampling_params, "input_capture_hook": input_capture_hook, - "tensor_capture_hook": tensor_capture_hook, "adapter_ids": adapter_ids } ) + if tensor_capture_hook is not None: + model_inputs["tensor_capture_hook"] = tensor_capture_hook tf_args = [] if self.neuron_config.tensor_replacement_config: diff --git a/test/unit/models/test_model_wrapper.py b/test/unit/models/test_model_wrapper.py index 9471b49b..ec09795d 100644 --- a/test/unit/models/test_model_wrapper.py +++ b/test/unit/models/test_model_wrapper.py @@ -347,6 +347,207 @@ def test_batch_bucketing_input_generation(): assert inputs[1][1].shape == (2, 128) # attention_mask shape +def test_prefix_caching_cte_input_generation_uses_valid_block_slots(): + """Warmup inputs for prefix CTE should follow vLLM-style block mapping.""" + neuron_config = NeuronConfig( + batch_size=1, + torch_dtype=torch.float32, + buckets=[[256, 512]], + bucket_n_active_tokens=True, + is_prefix_caching=True, + is_block_kv_layout=True, + pa_block_size=256, + pa_num_blocks=9, + ) + config = InferenceConfig(neuron_config=neuron_config) + config.pad_token_id = 0 + config.hidden_size = 128 + + wrapper = ModelWrapper(config, MockModel, tag=CONTEXT_ENCODING_MODEL_TAG) + generated = wrapper.input_generator()[0] + + position_ids = generated[2] + slot_mapping = generated[11] + active_block_table = generated[12] + num_queries = generated[13] + computed_context_lens = generated[14] + + assert torch.equal(position_ids[0, :3], torch.tensor([512, 513, 514], dtype=torch.int32)) + assert torch.equal(position_ids[0, -3:], torch.tensor([765, 766, 767], dtype=torch.int32)) + assert torch.equal(slot_mapping, position_ids) + assert torch.equal(active_block_table, torch.tensor([[0, 1]], dtype=torch.int32)) + assert torch.equal(num_queries, torch.tensor([[256]], dtype=torch.int32)) + assert torch.equal(computed_context_lens, torch.tensor([[512]], dtype=torch.int32)) + + +def _batched_prefix_cte_wrapper(): + neuron_config = NeuronConfig( + batch_size=2, + torch_dtype=torch.float32, + buckets=[ + [256, 0], + [256, 256], + [256, 512], + [512, 0], + [512, 256], + [512, 512], + ], + bucket_n_active_tokens=True, + is_prefix_caching=True, + is_block_kv_layout=True, + pa_block_size=256, + pa_num_blocks=16, + seq_len=512, + max_context_length=512, + max_length=512, + ) + config = InferenceConfig(neuron_config=neuron_config) + config.pad_token_id = 0 + config.hidden_size = 128 + return ModelWrapper(config, MockModel, tag=CONTEXT_ENCODING_MODEL_TAG) + + +def _batched_prefix_cte_args(active_len, *, query_lens, prefix_lens): + batch_size = len(query_lens) + empty = torch.empty(0) + return ( + torch.ones((batch_size, active_len), dtype=torch.int32), + torch.ones((batch_size, active_len), dtype=torch.int32), + torch.arange(active_len, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1), + torch.arange(batch_size, dtype=torch.int32), + torch.ones((batch_size, 3), dtype=torch.float32), + empty, + torch.zeros((batch_size,), dtype=torch.int32), + empty, + empty, + empty, + empty, + torch.arange(batch_size * active_len, dtype=torch.int32).reshape( + batch_size, active_len + ), + torch.arange(batch_size * 3, dtype=torch.int32).reshape(batch_size, 3), + torch.tensor(query_lens, dtype=torch.int32).reshape(batch_size, 1), + torch.tensor(prefix_lens, dtype=torch.int32).reshape(batch_size, 1), + ) + + +def test_prefix_caching_batched_cte_pad_inputs_to_prefill_bucket(): + wrapper = _batched_prefix_cte_wrapper() + padded = wrapper.pad_inputs( + *_batched_prefix_cte_args(16, query_lens=[16, 16], prefix_lens=[0, 0]) + ) + + assert padded[0].shape == (2, 256) + assert padded[1].shape == (1,) + assert padded[2].shape == (2, 256) + assert padded[11].shape == (2, 256) + assert padded[12].shape == (1,) + assert torch.equal(padded[13], torch.tensor([[16], [16]], dtype=torch.int32)) + assert torch.equal(padded[14], torch.tensor([[0], [0]], dtype=torch.int32)) + + +def test_prefix_caching_batched_cte_uses_max_prefix_bucket_for_mixed_rows(): + wrapper = _batched_prefix_cte_wrapper() + padded = wrapper.pad_inputs( + *_batched_prefix_cte_args(511, query_lens=[511, 1], prefix_lens=[0, 512]) + ) + + assert padded[0].shape == (2, 512) + assert padded[1].shape == (2, 512) + assert padded[2].shape == (2, 512) + assert padded[11].shape == (2, 512) + assert padded[12].shape == (2, 2) + assert torch.equal(padded[1][0], torch.zeros(512, dtype=torch.int32)) + assert torch.equal(padded[1][1], torch.ones(512, dtype=torch.int32)) + assert torch.equal(padded[13], torch.tensor([[511], [1]], dtype=torch.int32)) + assert torch.equal(padded[14], torch.tensor([[0], [512]], dtype=torch.int32)) + + +def test_prefix_caching_pad_zeros_hybrid_apc_controls_for_dummy_rows(): + wrapper = _batched_prefix_cte_wrapper() + wrapper.config.use_hybrid_apc_manager = True + captured = {} + + def _capture_forward(*args): + captured["args"] = args + return torch.zeros((2, 1), dtype=torch.float32) + + wrapper._forward = _capture_forward + wrapper.is_neuron = lambda: True + empty = torch.empty(0) + rotary_position_ids = torch.zeros((3, 1, 16), dtype=torch.int32) + hybrid_extra_args = ( + empty, + empty, + empty, + empty, + empty, + empty, + rotary_position_ids, + empty, + empty, + torch.tensor([3], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + torch.tensor([256], dtype=torch.int32), + torch.tensor([4], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + ) + + wrapper._forward_with_pad( + *_batched_prefix_cte_args(16, query_lens=[16], prefix_lens=[0]), + *hybrid_extra_args, + ) + + padded_args = captured["args"] + assert torch.equal(padded_args[24], torch.tensor([3, 0], dtype=torch.int32)) + assert torch.equal(padded_args[25], torch.tensor([1, 0], dtype=torch.int32)) + assert torch.equal(padded_args[26], torch.tensor([256, 0], dtype=torch.int32)) + assert torch.equal(padded_args[27], torch.tensor([4, 0], dtype=torch.int32)) + assert torch.equal(padded_args[28], torch.tensor([1, 0], dtype=torch.int32)) + + +def test_prefix_caching_hybrid_apc_restore_active_uses_control_tail(): + wrapper = _batched_prefix_cte_wrapper() + wrapper.config.use_hybrid_apc_manager = True + captured = {} + + def _capture_forward(*args): + captured["args"] = args + return torch.zeros((2, 1), dtype=torch.float32) + + wrapper._forward = _capture_forward + wrapper.is_neuron = lambda: True + empty = torch.empty(0) + rotary_position_ids = torch.zeros((3, 1, 16), dtype=torch.int32) + hybrid_extra_args = ( + empty, + empty, + empty, + empty, + empty, + empty, + rotary_position_ids, + empty, + empty, + torch.tensor([99], dtype=torch.int32), + torch.tensor([3], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + torch.tensor([256], dtype=torch.int32), + torch.tensor([4], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + ) + + wrapper._forward_with_pad( + *_batched_prefix_cte_args(16, query_lens=[16], prefix_lens=[0]), + *hybrid_extra_args, + ) + + padded_args = captured["args"] + assert torch.equal(padded_args[-5], torch.tensor([3, 0], dtype=torch.int32)) + assert torch.equal(padded_args[-4], torch.tensor([1, 0], dtype=torch.int32)) + assert wrapper._hybrid_apc_restore_active(padded_args) + + def test_batch_bucketing_target_bucket_selection(): """Test get_target_bucket selects smallest bucket that fits.""" config = create_base_config() diff --git a/test/unit/models/test_prefix_caching_bucket_selection.py b/test/unit/models/test_prefix_caching_bucket_selection.py index cf8fc92d..82b559a5 100644 --- a/test/unit/models/test_prefix_caching_bucket_selection.py +++ b/test/unit/models/test_prefix_caching_bucket_selection.py @@ -84,6 +84,29 @@ def test_cte_no_spec(self, inp_args, prefill_bucket, prefix_bucket): assert computed_prefill_bucket == prefill_bucket assert computed_prefix_bucket == prefix_bucket + def test_cte_sparse_grid_selects_next_compiled_pair(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [512, 0], + [512, 32768], + [1536, 0], + [1536, 65536], + [3072, 0], + [3072, 131072], + ] + inp_args = ( + [torch.tensor(0)] * 13 + + [torch.tensor([[512]])] + + [torch.tensor([[131072]])] + ) + + computed_prefill_bucket, computed_prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + + assert computed_prefill_bucket == 3072 + assert computed_prefix_bucket == 131072 + @pytest.mark.parametrize( "inp_args, prefix_bucket", [ @@ -99,3 +122,553 @@ def test_tkg_no_spec(self, inp_args, prefix_bucket): computed_prefill_bucket, computed_prefix_bucket = model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) assert computed_prefill_bucket == 1 assert computed_prefix_bucket == prefix_bucket + + def test_tkg_bucket_selection_uses_decode_active_len_not_bad_num_queries(self): + model_wrapper = self.setup_token_generation() + inp_args = [ + torch.ones((1, 1), dtype=torch.int32), # input_ids + torch.ones((1, 15), dtype=torch.int32), # attention_mask + torch.tensor([[15]], dtype=torch.int32), # position_ids + torch.zeros((1,), dtype=torch.int32), # seq_ids + torch.ones((1, 3), dtype=torch.float32), # sampling_params + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), # adapter_ids + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.zeros((1, 1), dtype=torch.int32), # slot_mapping + torch.zeros((1, 4), dtype=torch.int32), # block_table + torch.tensor([[15]], dtype=torch.int32), # bad num_queries + torch.tensor([[15]], dtype=torch.int32), # computed_context_lens + ] + + computed_prefill_bucket, computed_prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + + assert computed_prefill_bucket == 1 + assert computed_prefix_bucket == 16 + + def test_tkg_padding_rewrites_bad_num_queries_to_decode_active_len(self): + model_wrapper = self.setup_token_generation() + inp_args = [ + torch.ones((1, 1), dtype=torch.int32), # input_ids + torch.ones((1, 15), dtype=torch.int32), # attention_mask + torch.tensor([[15]], dtype=torch.int32), # position_ids + torch.zeros((1,), dtype=torch.int32), # seq_ids + torch.ones((1, 3), dtype=torch.float32), # sampling_params + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), # adapter_ids + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.zeros((1, 1), dtype=torch.int32), # slot_mapping + torch.zeros((1, 4), dtype=torch.int32), # block_table + torch.tensor([[15]], dtype=torch.int32), # bad num_queries + torch.tensor([[15]], dtype=torch.int32), # computed_context_lens + ] + + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert torch.equal(padded_args[13], torch.tensor([[1]], dtype=torch.int32)) + assert torch.equal(padded_args[14], torch.tensor([[15]], dtype=torch.int32)) + assert padded_args[0].shape[-1] == 1 + assert padded_args[1].shape[-1] == 16 + + def test_cte_hybrid_apc_restore_padding_keeps_suffix_and_prefix_bucket(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [256, 0], + [256, 256], + [512, 0], + [512, 256], + ] + model_wrapper.neuron_config.pa_block_size = 256 + + suffix_len = 16 + restore_len = 256 + inp_args = [ + torch.arange(suffix_len, dtype=torch.int32).reshape(1, suffix_len), + torch.ones((1, suffix_len), dtype=torch.int32), + torch.arange( + restore_len, + restore_len + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len), + torch.zeros((1,), dtype=torch.int32), + torch.ones((1, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange(1792, 1792 + suffix_len, dtype=torch.int32).reshape( + 1, + suffix_len, + ), + torch.arange(8, dtype=torch.int32).reshape(1, 8), + torch.tensor([[suffix_len]], dtype=torch.int32), + torch.tensor([[restore_len]], dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.tensor([0], dtype=torch.int32), # restore slot + torch.tensor([1], dtype=torch.int32), # restore mask + torch.tensor([restore_len], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), # commit slot + torch.tensor([0], dtype=torch.int32), # commit mask + ] + + prefill_bucket, prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert int(prefill_bucket) == 256 + assert int(prefix_bucket) == 256 + assert padded_args[0].shape == (1, 256) + assert padded_args[1].shape == (1, 256) + assert padded_args[2].shape == (1, 256) + assert torch.equal(padded_args[0][:, :suffix_len], inp_args[0]) + assert torch.equal(padded_args[1][:, :suffix_len], inp_args[1]) + assert torch.equal(padded_args[2][:, :suffix_len], inp_args[2]) + assert torch.equal(padded_args[11][:, :suffix_len], inp_args[11]) + assert torch.equal( + padded_args[11][:, suffix_len:], + torch.full((1, 256 - suffix_len), -1, dtype=torch.int32), + ) + assert padded_args[12].shape == (1, 1) + assert torch.equal(padded_args[12], torch.tensor([[0]], dtype=torch.int32)) + assert torch.equal(padded_args[13], torch.tensor([[suffix_len]], dtype=torch.int32)) + assert torch.equal(padded_args[14], torch.tensor([[restore_len]], dtype=torch.int32)) + + def test_cte_suffix_only_continuation_keeps_prefix_bucket(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [256, 0], + [256, 256], + [512, 0], + [512, 256], + ] + model_wrapper.neuron_config.pa_block_size = 256 + + suffix_len = 256 + prefix_len = 256 + inp_args = [ + torch.arange(suffix_len, dtype=torch.int32).reshape(1, suffix_len), + torch.ones((1, suffix_len), dtype=torch.int32), + torch.arange( + prefix_len, + prefix_len + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len), + torch.zeros((1,), dtype=torch.int32), + torch.ones((1, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange(512, 512 + suffix_len, dtype=torch.int32).reshape( + 1, + suffix_len, + ), + torch.tensor([[0]], dtype=torch.int32), + torch.tensor([[suffix_len]], dtype=torch.int32), + torch.tensor([[prefix_len]], dtype=torch.int32), + ] + + prefill_bucket, prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert int(prefill_bucket) == 256 + assert int(prefix_bucket) == 256 + assert padded_args[0].shape == (1, 256) + assert padded_args[1].shape == (1, 256) + assert padded_args[2].shape == (1, 256) + assert torch.equal(padded_args[0], inp_args[0]) + assert torch.equal(padded_args[1], torch.ones((1, 256), dtype=torch.int32)) + assert torch.equal(padded_args[2], inp_args[2]) + assert torch.equal(padded_args[11], inp_args[11]) + assert torch.equal(padded_args[12], torch.tensor([[0]], dtype=torch.int32)) + assert torch.equal( + padded_args[13], torch.tensor([[suffix_len]], dtype=torch.int32) + ) + assert torch.equal( + padded_args[14], torch.tensor([[prefix_len]], dtype=torch.int32) + ) + + def test_cte_suffix_only_partial_continuation_does_not_left_pad_slots(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [256, 0], + [256, 256], + [256, 512], + [256, 1024], + ] + model_wrapper.neuron_config.pa_block_size = 256 + + suffix_len = 48 + prefix_len = 768 + inp_args = [ + torch.arange(suffix_len, dtype=torch.int32).reshape(1, suffix_len), + torch.ones((1, suffix_len), dtype=torch.int32), + torch.arange( + prefix_len, + prefix_len + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len), + torch.zeros((1,), dtype=torch.int32), + torch.ones((1, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange(1024, 1024 + suffix_len, dtype=torch.int32).reshape( + 1, + suffix_len, + ), + torch.tensor([[0, 1, 2]], dtype=torch.int32), + torch.tensor([[suffix_len]], dtype=torch.int32), + torch.tensor([[prefix_len]], dtype=torch.int32), + ] + + prefill_bucket, prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert int(prefill_bucket) == 256 + assert int(prefix_bucket) == 1024 + assert torch.equal(padded_args[0][:, :suffix_len], inp_args[0]) + assert torch.equal(padded_args[2][:, :suffix_len], inp_args[2]) + assert torch.equal(padded_args[11][:, :suffix_len], inp_args[11]) + assert torch.equal( + padded_args[11][:, suffix_len:], + torch.full((1, 256 - suffix_len), -1, dtype=torch.int32), + ) + assert padded_args[1].shape == (1, 1024) + assert torch.equal( + padded_args[1][:, :prefix_len], + torch.ones((1, prefix_len), dtype=torch.int32), + ) + assert torch.equal( + padded_args[1][:, prefix_len:], + torch.zeros((1, 1024 - prefix_len), dtype=torch.int32), + ) + assert torch.equal( + padded_args[12], torch.tensor([[0, 1, 2, 0]], dtype=torch.int32) + ) + assert torch.equal( + padded_args[13], torch.tensor([[suffix_len]], dtype=torch.int32) + ) + assert torch.equal( + padded_args[14], torch.tensor([[prefix_len]], dtype=torch.int32) + ) + + def test_segmented_cte_padding_fills_active_block_table_from_slots(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [256, 1024], + ] + model_wrapper.neuron_config.pa_block_size = 256 + model_wrapper.neuron_config.max_context_length = 1024 + model_wrapper.neuron_config.prefix_cte_attention_backend = "segmented_cte" + + suffix_len = 48 + prefix_len = 768 + inp_args = [ + torch.arange(suffix_len, dtype=torch.int32).reshape(1, suffix_len), + torch.ones((1, suffix_len), dtype=torch.int32), + torch.arange( + prefix_len, + prefix_len + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len), + torch.zeros((1,), dtype=torch.int32), + torch.ones((1, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange(1024, 1024 + suffix_len, dtype=torch.int32).reshape( + 1, + suffix_len, + ), + torch.tensor([[0, 1, 2, -1]], dtype=torch.int32), + torch.tensor([[suffix_len]], dtype=torch.int32), + torch.tensor([[prefix_len]], dtype=torch.int32), + ] + + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert torch.equal( + padded_args[12], + torch.tensor([[0, 1, 2, 4]], dtype=torch.int32), + ) + + def test_segmented_cte_padding_fills_short_suffix_block_after_prefix_hit(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [512, 512], + ] + model_wrapper.neuron_config.pa_block_size = 256 + model_wrapper.neuron_config.max_context_length = 1024 + model_wrapper.neuron_config.prefix_cte_attention_backend = "segmented_cte" + + suffix_len = 14 + prefix_len = 512 + suffix_physical_block = 9 + inp_args = [ + torch.arange(suffix_len, dtype=torch.int32).reshape(1, suffix_len), + torch.ones((1, suffix_len), dtype=torch.int32), + torch.arange( + prefix_len, + prefix_len + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len), + torch.zeros((1,), dtype=torch.int32), + torch.ones((1, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((1,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange( + suffix_physical_block * 256, + suffix_physical_block * 256 + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len), + torch.tensor([[0, 1, -1]], dtype=torch.int32), + torch.tensor([[suffix_len]], dtype=torch.int32), + torch.tensor([[prefix_len]], dtype=torch.int32), + ] + + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert torch.equal( + padded_args[12], + torch.tensor([[0, 1, suffix_physical_block, 0]], dtype=torch.int32), + ) + + def test_segmented_cte_cold_cte2048_keeps_active_block_table_when_batched(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [2048, 0], + ] + model_wrapper.neuron_config.pa_block_size = 256 + model_wrapper.neuron_config.max_context_length = 32768 + model_wrapper.neuron_config.prefix_cte_attention_backend = "segmented_cte" + + active_len = 2048 + input_ids = torch.arange(active_len, dtype=torch.int32).repeat(2, 1) + attention_mask = torch.ones((2, active_len), dtype=torch.int32) + position_ids = torch.arange(active_len, dtype=torch.int32).repeat(2, 1) + slot_mapping = torch.stack( + ( + torch.arange(0, active_len, dtype=torch.int32), + torch.arange(4096, 4096 + active_len, dtype=torch.int32), + ) + ) + inp_args = [ + input_ids, + attention_mask, + position_ids, + torch.zeros((2,), dtype=torch.int32), + torch.ones((2, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((2,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + slot_mapping, + torch.full((2, 8), -1, dtype=torch.int32), + torch.full((2, 1), active_len, dtype=torch.int32), + torch.zeros((2, 1), dtype=torch.int32), + ] + + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert padded_args[12].shape == (2, 8) + assert torch.equal( + padded_args[12][0], + torch.arange(8, dtype=torch.int32), + ) + assert torch.equal( + padded_args[12][1], + torch.arange(16, 24, dtype=torch.int32), + ) + + def test_cte_batched_hybrid_apc_restore_padding_uses_full_attention_mask(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [256, 0], + [256, 4096], + ] + model_wrapper.neuron_config.pa_block_size = 256 + + suffix_len = 16 + restore_len = 256 + attention_mask = torch.cat( + [ + torch.ones((1, suffix_len), dtype=torch.int32), + torch.cat( + [ + torch.ones((1, 12), dtype=torch.int32), + torch.zeros((1, suffix_len - 12), dtype=torch.int32), + ], + dim=1, + ), + ], + dim=0, + ) + inp_args = [ + torch.arange(2 * suffix_len, dtype=torch.int32).reshape(2, suffix_len), + attention_mask, + torch.arange( + restore_len, + restore_len + suffix_len, + dtype=torch.int32, + ).reshape(1, suffix_len).expand(2, -1), + torch.arange(2, dtype=torch.int32), + torch.ones((2, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((2,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange(2 * suffix_len, dtype=torch.int32).reshape(2, suffix_len), + torch.tensor([[8], [9]], dtype=torch.int32), + torch.tensor([[suffix_len], [12]], dtype=torch.int32), + torch.tensor([[restore_len], [restore_len]], dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.tensor([0, 1], dtype=torch.int32), # restore slots + torch.tensor([1, 1], dtype=torch.int32), # restore mask + torch.tensor([restore_len, restore_len], dtype=torch.int32), + torch.tensor([0, 0], dtype=torch.int32), # commit slots + torch.tensor([0, 0], dtype=torch.int32), # commit mask + ] + + prefill_bucket, prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert int(prefill_bucket) == 256 + assert int(prefix_bucket) == 4096 + assert padded_args[1].shape == (2, 4096) + assert torch.equal( + padded_args[1].sum(dim=1), + torch.tensor([restore_len + suffix_len, restore_len + 12]), + ) + assert torch.equal( + padded_args[1][0, : restore_len + suffix_len], + torch.ones((restore_len + suffix_len,), dtype=torch.int32), + ) + assert torch.equal( + padded_args[1][0, restore_len + suffix_len :], + torch.zeros((4096 - restore_len - suffix_len,), dtype=torch.int32), + ) + assert padded_args[12].shape == (2, 16) + + def test_cte_batched_hybrid_apc_restore_routes_mixed_warm_cold_to_compiled_shape(self): + model_wrapper = self.setup_context_encoding() + model_wrapper.neuron_config.buckets = [ + [256, 0], + [256, 256], + [256, 512], + [512, 0], + [512, 256], + [512, 512], + ] + model_wrapper.neuron_config.pa_block_size = 256 + + active_len = 272 + warm_suffix_len = 16 + restore_len = 256 + attention_mask = torch.zeros((2, active_len), dtype=torch.int32) + attention_mask[0, :warm_suffix_len] = 1 + attention_mask[1, :active_len] = 1 + inp_args = [ + torch.arange(2 * active_len, dtype=torch.int32).reshape(2, active_len), + attention_mask, + torch.arange(active_len, dtype=torch.int32).reshape(1, active_len).expand(2, -1), + torch.arange(2, dtype=torch.int32), + torch.ones((2, 3), dtype=torch.float32), + torch.empty(0), + torch.zeros((2,), dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.arange(2 * active_len, dtype=torch.int32).reshape(2, active_len), + torch.tensor([[8, 9], [10, 11]], dtype=torch.int32), + torch.tensor([[warm_suffix_len], [active_len]], dtype=torch.int32), + torch.tensor([[restore_len], [0]], dtype=torch.int32), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.empty(0), + torch.tensor([0, 0], dtype=torch.int32), # restore slots + torch.tensor([1, 0], dtype=torch.int32), # restore mask + torch.tensor([restore_len, 0], dtype=torch.int32), + torch.tensor([1, 2], dtype=torch.int32), # commit slots + torch.tensor([1, 0], dtype=torch.int32), # commit mask + ] + + prefill_bucket, prefix_bucket = ( + model_wrapper.get_target_2d_bucket_for_prefix_caching(*inp_args) + ) + padded_args = model_wrapper._pad_prefix_caching_inputs(*inp_args) + + assert int(prefill_bucket) == 512 + assert int(prefix_bucket) == 512 + assert padded_args[0].shape == (2, 512) + assert padded_args[1].shape == (2, 512) + assert padded_args[2].shape == (2, 512) + assert padded_args[11].shape == (2, 512) + assert padded_args[12].shape == (2, 2) + assert torch.equal( + padded_args[12], + torch.tensor([[8, 0], [0, 0]], dtype=torch.int32), + ) + assert torch.equal( + padded_args[13], + torch.tensor([[warm_suffix_len], [active_len]], dtype=torch.int32), + ) + assert torch.equal( + padded_args[14], + torch.tensor([[restore_len], [0]], dtype=torch.int32), + ) diff --git a/test/unit/modules/attention/test_attention_base.py b/test/unit/modules/attention/test_attention_base.py index 9830c5b3..32569d07 100644 --- a/test/unit/modules/attention/test_attention_base.py +++ b/test/unit/modules/attention/test_attention_base.py @@ -772,6 +772,75 @@ def test_perform_prefix_prefill_sharded_flash_attn(mock_flash_fwd_call, attn_mod _check_prefix_prefill_flash_attn_kernel_call(mock_flash_fwd_kernel, attn_module, batch_size, seq_len, seq_len_prior) + +@pytest.mark.parametrize( + "attn_module", + [ + ({"num_key_value_heads": 4, "prefix_cte_attention_chunk_size": 2}), + ({"num_key_value_heads": 2, "prefix_cte_attention_chunk_size": 2}), + ], + indirect=["attn_module"], +) +@patch("neuronx_distributed_inference.modules.attention.attention_base._flash_fwd_call_nki") +def test_perform_prefix_prefill_chunked_prior_matches_native( + mock_flash_fwd_call, + attn_module, +): + q_len = 3 + seq_len_prior = 5 + batch_size = 1 + q = torch.rand((batch_size, attn_module.num_heads, q_len, attn_module.head_dim)) + k = torch.rand( + (batch_size, attn_module.num_key_value_heads, q_len, attn_module.head_dim) + ) + v = torch.rand( + (batch_size, attn_module.num_key_value_heads, q_len, attn_module.head_dim) + ) + k_prior = torch.rand( + (batch_size, attn_module.num_key_value_heads, seq_len_prior, attn_module.head_dim) + ) + v_prior = torch.rand( + (batch_size, attn_module.num_key_value_heads, seq_len_prior, attn_module.head_dim) + ) + prior_mask = torch.ones((batch_size, seq_len_prior)) + active_mask = _create_attn_mask(batch_size, q_len) + + attn_module.get_flash_attention_strategy = MagicMock( + return_value=FlashAttentionStrategy.SHARDED_KERNEL + ) + actual_output, actual_strategy = attn_module.perform_prefix_prefill( + q, + k, + v, + q_len, + batch_size, + prior_mask, + [k_prior, v_prior], + active_mask, + ) + + assert actual_strategy == FlashAttentionStrategy.NONE + mock_flash_fwd_call.assert_not_called() + + attn_module.neuron_config.prefix_cte_attention_chunk_size = None + attn_module.get_flash_attention_strategy = MagicMock( + return_value=FlashAttentionStrategy.NONE + ) + expected_output, expected_strategy = attn_module.perform_prefix_prefill( + q, + k, + v, + q_len, + batch_size, + prior_mask, + [k_prior, v_prior], + active_mask, + ) + + assert expected_strategy == FlashAttentionStrategy.NONE + torch.testing.assert_close(actual_output, expected_output, atol=1e-6, rtol=1e-5) + + @pytest.mark.parametrize( "attn_module, batch_size, flash_attention_strategy", # fmt: off diff --git a/test/unit/modules/attention/test_gqa.py b/test/unit/modules/attention/test_gqa.py index 39956383..3b814220 100644 --- a/test/unit/modules/attention/test_gqa.py +++ b/test/unit/modules/attention/test_gqa.py @@ -1,11 +1,81 @@ +import sys +import types from unittest.mock import MagicMock, Mock, call, patch import pytest import torch +_lora_pkg = types.ModuleType("neuronx_distributed_inference.modules.lora_serving") +_lora_pkg.__path__ = [] +_lora_module = types.ModuleType( + "neuronx_distributed_inference.modules.lora_serving.lora_module" +) +_lora_module.is_lora_module = lambda _module: False +sys.modules.setdefault( + "neuronx_distributed_inference.modules.lora_serving", + _lora_pkg, +) +sys.modules.setdefault( + "neuronx_distributed_inference.modules.lora_serving.lora_module", + _lora_module, +) + from neuronx_distributed_inference.modules.attention import gqa +def test_preshard_hook_preserves_qwen_qkv_gate_packed_weight(): + hidden_size = 16 + head_dim = 4 + num_attention_heads = 8 + num_key_value_heads = 4 + tp_degree = 4 + q_width = num_attention_heads * head_dim + kv_width = num_key_value_heads * head_dim + + qkv_proj = gqa.GroupQueryAttention_QKV( + hidden_size=hidden_size, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + tp_degree=tp_degree, + fused_qkv=True, + gather_output=False, + bias=False, + ) + qkv_proj.qwen_qkv_gate_packed = True + + q_weight = torch.arange(q_width * hidden_size, dtype=torch.float32).reshape( + q_width, hidden_size + ) + gate_weight = q_weight + 10_000 + k_weight = q_weight[:kv_width] + 20_000 + v_weight = q_weight[:kv_width] + 30_000 + packed_weight = torch.cat([q_weight, gate_weight, k_weight, v_weight], dim=0) + state_dict = {"layers.0.self_attn.Wqkv.weight": packed_weight.clone()} + + qkv_proj.preshard_hook(state_dict, "layers.0.self_attn.weight") + + heads_per_rank = num_attention_heads // tp_degree + q_heads = q_weight.reshape(num_attention_heads, head_dim, hidden_size) + gate_heads = gate_weight.reshape(num_attention_heads, head_dim, hidden_size) + q_gate_rank_blocks = [] + for rank in range(tp_degree): + start = rank * heads_per_rank + q_gate_rank_blocks.append(q_heads[start : start + heads_per_rank]) + q_gate_rank_blocks.append(gate_heads[start : start + heads_per_rank]) + expected_q_gate = torch.cat(q_gate_rank_blocks, dim=0).reshape( + 2 * q_width, + hidden_size, + ) + expected = torch.cat([expected_q_gate, k_weight, v_weight], dim=0) + + torch.testing.assert_close(state_dict["layers.0.self_attn.Wqkv.weight"], expected) + assert getattr(qkv_proj.Wqkv.weight, "fused_qkv") is True + assert getattr(qkv_proj.Wqkv.weight, "num_attention_heads") == num_attention_heads * 2 + assert getattr(qkv_proj.Wqkv.weight, "num_key_value_heads") == num_key_value_heads + assert getattr(qkv_proj.Wqkv.weight, "head_dim") == head_dim + + @pytest.mark.parametrize( "batch_size, seq_len, fuse_rope", # fmt: off @@ -64,6 +134,8 @@ def test_kernel_qkv_forward_rope_fusion(mock_qkv_kernel, batch_size, seq_len, fu qkv_proj.Wqkv.weight.shape = (hidden_size, fused_qkv_size) qkv_proj.Wqkv.weight.dtype = torch.float32 qkv_proj.Wqkv.bias = None + qkv_proj.Wqkv.scale = None + qkv_proj.Wqkv.input_scale = None # Mock _split_fused_qkv to return Q, K, V Q = torch.rand((batch_size, seq_len, num_attention_heads * head_dim // tp_degree)) @@ -93,6 +165,10 @@ def test_kernel_qkv_forward_rope_fusion(mock_qkv_kernel, batch_size, seq_len, fu # When rope fusion is disabled, cos_cache and sin_cache should NOT be passed assert kernel_kwargs["cos_cache"] is None assert kernel_kwargs["sin_cache"] is None + + assert kernel_kwargs["quantization_type"] == gqa.QuantizationType.NONE + assert kernel_kwargs["qkv_w_scale"] is None + assert kernel_kwargs["qkv_in_scale"] is None # Verify result is a tuple with Q, K, V, residual assert len(result) == 4 @@ -101,3 +177,213 @@ def test_kernel_qkv_forward_rope_fusion(mock_qkv_kernel, batch_size, seq_len, fu assert K.shape == (batch_size, seq_len, num_key_value_heads * head_dim // tp_degree) assert V.shape == (batch_size, seq_len, num_key_value_heads * head_dim // tp_degree) assert residual is None + + +@patch('neuronx_distributed_inference.modules.attention.gqa.qkv_kernel') +def test_kernel_qkv_forward_passes_fp8_weight_scale(mock_qkv_kernel): + hidden_size = 16 + head_dim = 4 + num_attention_heads = 8 + num_key_value_heads = 2 + tp_degree = 2 + batch_size = 1 + seq_len = 8 + fused_qkv_size = (num_attention_heads + 2 * num_key_value_heads) * head_dim // tp_degree + + hidden_states = torch.rand((batch_size, seq_len, hidden_size)) + QKV = torch.rand((batch_size, seq_len, fused_qkv_size)) + mock_kernel_call = MagicMock(return_value=QKV) + mock_qkv_kernel.__getitem__ = MagicMock(return_value=mock_kernel_call) + + qkv_proj = Mock(spec=gqa.GroupQueryAttention_QKV) + qkv_proj.num_attention_heads = num_attention_heads + qkv_proj.num_key_value_heads = num_key_value_heads + qkv_proj.tp_degree = tp_degree + qkv_proj.head_dim = head_dim + qkv_proj.fused_rmsnorm = False + qkv_proj.fused_rmsnorm_skip_gamma = False + qkv_proj.logical_nc_config = 1 + qkv_proj.bias = False + qkv_proj.qkv_kernel_nbsd_layout = False + qkv_proj.rms_norm_eps = 1e-6 + + qkv_proj.Wqkv = Mock() + qkv_proj.Wqkv.weight = Mock() + qkv_proj.Wqkv.weight.data = torch.rand((hidden_size, fused_qkv_size)) + qkv_scale = torch.rand((128, fused_qkv_size), dtype=torch.float32) + qkv_proj.Wqkv.scale = Mock() + qkv_proj.Wqkv.scale.data = qkv_scale + qkv_proj.Wqkv.input_scale = None + qkv_proj.Wqkv.bias = None + + Q = torch.rand((batch_size, seq_len, num_attention_heads * head_dim // tp_degree)) + K = torch.rand((batch_size, seq_len, num_key_value_heads * head_dim // tp_degree)) + V = torch.rand((batch_size, seq_len, num_key_value_heads * head_dim // tp_degree)) + qkv_proj._split_fused_qkv = Mock(return_value=(Q, K, V)) + + result = gqa.GroupQueryAttention_QKV._kernel_qkv_forward( + qkv_proj, hidden_states, None, None, None, None + ) + + kernel_kwargs = mock_kernel_call.call_args.kwargs + assert kernel_kwargs["quantization_type"] == gqa.QuantizationType.ROW + torch.testing.assert_close(kernel_kwargs["qkv_w_scale"], qkv_scale) + assert kernel_kwargs["qkv_in_scale"] is None + assert len(result) == 4 + + +@patch("neuronx_distributed_inference.modules.attention.gqa.reduce_from_tensor_model_parallel_region") +@patch("neuronx_distributed_inference.modules.attention.gqa.output_projection_cte") +def test_kernel_o_proj_uses_bnds_layout_without_quantization( + mock_output_projection_cte, + mock_reduce_from_tensor_model_parallel_region, +): + hidden_size = 16 + head_dim = 4 + num_attention_heads = 8 + num_key_value_heads = 2 + tp_degree = 2 + batch_size = 1 + seq_len = 8 + heads_per_core = num_attention_heads // tp_degree + nd = heads_per_core * head_dim + + attention_output = torch.rand((batch_size, seq_len, nd)) + kernel_out = torch.rand((batch_size, seq_len, hidden_size)) + mock_kernel_call = MagicMock(return_value=kernel_out) + mock_output_projection_cte.__getitem__ = MagicMock(return_value=mock_kernel_call) + mock_reduce_from_tensor_model_parallel_region.side_effect = lambda x, process_group=None: x + + o_proj = Mock(spec=gqa.GroupQueryAttention_O) + o_proj.num_attention_heads = num_attention_heads + o_proj.tp_degree = tp_degree + o_proj.head_dim = head_dim + o_proj.logical_nc_config = 1 + o_proj.bias = False + o_proj.quantized = False + o_proj.rpl_reduce_dtype = torch.float32 + o_proj.sequence_parallel_enabled = False + o_proj.tensor_model_parallel_group = None + o_proj.o_proj = Mock() + o_proj.o_proj.weight = Mock() + o_proj.o_proj.weight.shape = (nd, hidden_size) + o_proj.o_proj.weight.dtype = torch.float32 + o_proj.o_proj.weight.data = torch.rand((nd, hidden_size)) + o_proj.o_proj.bias = None + o_proj.o_proj.scale = None + + result = gqa.GroupQueryAttention_O._kernel_o_proj(o_proj, attention_output) + + mock_output_projection_cte.__getitem__.assert_called_once_with(o_proj.logical_nc_config) + kernel_kwargs = mock_kernel_call.call_args.kwargs + assert kernel_kwargs["attention"].shape == (batch_size, heads_per_core, head_dim, seq_len) + assert kernel_kwargs["quantization_type"] == gqa.QuantizationType.NONE + assert kernel_kwargs["weight_scales"] is None + torch.testing.assert_close(result, kernel_out.to(torch.float32)) + + +@patch("neuronx_distributed_inference.modules.attention.gqa.reduce_from_tensor_model_parallel_region") +@patch("neuronx_distributed_inference.modules.attention.gqa.output_projection_cte") +def test_kernel_o_proj_passes_fp8_row_weight_scales( + mock_output_projection_cte, + mock_reduce_from_tensor_model_parallel_region, +): + hidden_size = 16 + head_dim = 4 + num_attention_heads = 8 + num_key_value_heads = 2 + tp_degree = 2 + batch_size = 1 + seq_len = 8 + heads_per_core = num_attention_heads // tp_degree + nd = heads_per_core * head_dim + + attention_output = torch.rand((batch_size, seq_len, nd)) + kernel_out = torch.rand((batch_size, seq_len, hidden_size)) + mock_kernel_call = MagicMock(return_value=kernel_out) + mock_output_projection_cte.__getitem__ = MagicMock(return_value=mock_kernel_call) + mock_reduce_from_tensor_model_parallel_region.side_effect = lambda x, process_group=None: x + + o_proj = Mock(spec=gqa.GroupQueryAttention_O) + o_proj.num_attention_heads = num_attention_heads + o_proj.tp_degree = tp_degree + o_proj.head_dim = head_dim + o_proj.logical_nc_config = 1 + o_proj.bias = False + o_proj.quantized = True + o_proj.rpl_reduce_dtype = torch.float32 + o_proj.sequence_parallel_enabled = False + o_proj.tensor_model_parallel_group = None + o_proj.o_proj = Mock() + o_proj.o_proj.weight = Mock() + o_proj.o_proj.weight.shape = (nd, hidden_size) + o_proj.o_proj.weight.dtype = torch.float8_e4m3fn + o_proj.o_proj.weight.data = torch.rand((nd, hidden_size)).to(torch.float8_e4m3fn) + o_proj.o_proj.bias = None + weight_scales = torch.rand((128, hidden_size), dtype=torch.float32) + o_proj.o_proj.scale = Mock() + o_proj.o_proj.scale.data = weight_scales + + result = gqa.GroupQueryAttention_O._kernel_o_proj(o_proj, attention_output) + + mock_output_projection_cte.__getitem__.assert_called_once_with(o_proj.logical_nc_config) + kernel_kwargs = mock_kernel_call.call_args.kwargs + assert kernel_kwargs["attention"].shape == (batch_size, seq_len, heads_per_core, head_dim) + assert kernel_kwargs["quantization_type"] == gqa.QuantizationType.ROW + torch.testing.assert_close(kernel_kwargs["weight_scales"], weight_scales) + torch.testing.assert_close(result, kernel_out.to(torch.float32)) + + +@patch("neuronx_distributed_inference.modules.attention.gqa.reduce_from_tensor_model_parallel_region") +@patch("neuronx_distributed_inference.modules.attention.gqa.output_projection_cte") +def test_kernel_o_proj_folds_large_head_dim_for_fp8_row_weight_scales( + mock_output_projection_cte, + mock_reduce_from_tensor_model_parallel_region, +): + hidden_size = 16 + head_dim = 256 + num_attention_heads = 8 + tp_degree = 2 + batch_size = 1 + seq_len = 8 + heads_per_core = num_attention_heads // tp_degree + nd = heads_per_core * head_dim + + attention_output = torch.rand((batch_size, seq_len, nd)) + kernel_out = torch.rand((batch_size, seq_len, hidden_size)) + mock_kernel_call = MagicMock(return_value=kernel_out) + mock_output_projection_cte.__getitem__ = MagicMock(return_value=mock_kernel_call) + mock_reduce_from_tensor_model_parallel_region.side_effect = lambda x, process_group=None: x + + o_proj = Mock(spec=gqa.GroupQueryAttention_O) + o_proj.num_attention_heads = num_attention_heads + o_proj.tp_degree = tp_degree + o_proj.head_dim = head_dim + o_proj.logical_nc_config = 1 + o_proj.bias = False + o_proj.quantized = True + o_proj.rpl_reduce_dtype = torch.float32 + o_proj.sequence_parallel_enabled = False + o_proj.tensor_model_parallel_group = None + o_proj.o_proj = Mock() + o_proj.o_proj.weight = Mock() + o_proj.o_proj.weight.shape = (nd, hidden_size) + o_proj.o_proj.weight.dtype = torch.float8_e4m3fn + o_proj.o_proj.weight.data = torch.rand((nd, hidden_size)).to(torch.float8_e4m3fn) + o_proj.o_proj.bias = None + weight_scales = torch.rand((128, hidden_size), dtype=torch.float32) + o_proj.o_proj.scale = Mock() + o_proj.o_proj.scale.data = weight_scales + + result = gqa.GroupQueryAttention_O._kernel_o_proj(o_proj, attention_output) + + kernel_kwargs = mock_kernel_call.call_args.kwargs + assert kernel_kwargs["attention"].shape == ( + batch_size, + seq_len, + heads_per_core * 2, + head_dim // 2, + ) + assert kernel_kwargs["quantization_type"] == gqa.QuantizationType.ROW + torch.testing.assert_close(kernel_kwargs["weight_scales"], weight_scales) + torch.testing.assert_close(result, kernel_out.to(torch.float32)) diff --git a/test/unit/modules/generation/test_sampling.py b/test/unit/modules/generation/test_sampling.py index d7efcb66..24e2d515 100644 --- a/test/unit/modules/generation/test_sampling.py +++ b/test/unit/modules/generation/test_sampling.py @@ -331,6 +331,68 @@ def get_sampler(topk, num_beams, on_device=True): return Sampler(neuron_config, **sampler_kwargs) +def test_greedy_full_vocab_argmax_uses_torch_argmax(): + neuron_config = NeuronConfig( + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False) + ) + neuron_config.on_cpu = False + neuron_config.vocab_parallel = False + + with patch( + "neuronx_distributed_inference.modules.generation.sampling.parallel_state.get_tensor_model_parallel_group", + return_value=None, + ): + sampler = Sampler(neuron_config) + + logits = torch.tensor( + [[0.0, 4.0, 1.0], [3.0, -1.0, 2.0]], + dtype=torch.float32, + ) + with patch( + "neuronx_distributed_inference.modules.generation.sampling.nxd_argmax", + side_effect=AssertionError("distributed argmax should not be used"), + ): + tokens = sampler._argmax_sample(logits, return_values=False, dim=1) + + assert tokens.dtype == torch.int32 + assert tokens.tolist() == [1, 0] + + +def test_vocab_parallel_argmax_accepts_context_encoding_kernel_override(): + neuron_config = NeuronConfig( + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False), + disable_argmax_kernel=False, + ) + neuron_config.on_cpu = False + neuron_config.vocab_parallel = True + + with patch( + "neuronx_distributed_inference.modules.generation.sampling.parallel_state.get_tensor_model_parallel_group", + return_value="tp_group", + ): + sampler = Sampler(neuron_config) + + logits = torch.tensor( + [[0.0, 4.0, 1.0], [3.0, -1.0, 2.0]], + dtype=torch.float32, + ) + with patch( + "neuronx_distributed_inference.modules.generation.sampling.nxd_argmax", + return_value=torch.tensor([1, 0], dtype=torch.int32), + ) as mock_argmax: + tokens = sampler.forward( + logits, + torch.ones((2, 3), dtype=torch.float32), + disable_argmax_kernel_override=True, + ) + + assert tokens.dtype == torch.int32 + assert tokens.tolist() == [1, 0] + mock_argmax.assert_called_once() + assert mock_argmax.call_args.kwargs["process_group"] == "tp_group" + assert mock_argmax.call_args.kwargs["disable_argmax_kernel"] is True + + def run_sampler_accuracy_test(batch_size, topk, num_beams=1): torch.manual_seed(0) torch.distributed.init_process_group("xla", init_method="pjrt://") diff --git a/test/unit/modules/kvcache/test_block_kv_cache_manager.py b/test/unit/modules/kvcache/test_block_kv_cache_manager.py index 9bb9dacc..df8449c0 100644 --- a/test/unit/modules/kvcache/test_block_kv_cache_manager.py +++ b/test/unit/modules/kvcache/test_block_kv_cache_manager.py @@ -329,6 +329,47 @@ def test_block_kv_get_kv_by_layer_id_with_dequantization(): assert v_cache.dtype == torch.float32 +def test_block_kv_prefix_quantization_dequantizes_selected_blocks_only(): + kv_quant_config = KVQuantizationConfig( + quant_dtype=torch.bfloat16, + direct_cast=True, + ) + + kv_cache_mgr = _pa_prepare_cache_mgr( + tp_degree=1, + batch_size=1, + pa_num_blocks=16, + pa_block_size=128, + num_attention_heads=8, + num_kv_head=4, + hidden_size=32, + num_hidden_layers=2, + is_prefix_caching=True, + kv_quant_config=kv_quant_config, + ) + seen_shapes = [] + original_dequantize = kv_cache_mgr._dequantize_cache + + def record_dequantize(cache_tensor, layer_idx, is_key=True): + seen_shapes.append((is_key, tuple(cache_tensor.shape))) + return original_dequantize(cache_tensor, layer_idx, is_key=is_key) + + kv_cache_mgr._dequantize_cache = record_dequantize + + block_table = torch.tensor([[0, 2]], dtype=torch.int64) + k_cache, v_cache = kv_cache_mgr.get_kv_by_layer_id( + idx=0, + active_block_table=block_table, + ) + + assert k_cache.dtype == torch.float32 + assert v_cache.dtype == torch.float32 + assert seen_shapes == [ + (True, (1, 4, 256, 4)), + (False, (1, 4, 256, 4)), + ] + + def test_block_kv_direct_cast_update_integration(): """Test update_kv_by_layer_id with direct_cast quantization.""" kv_quant_config = KVQuantizationConfig( @@ -493,6 +534,89 @@ def test_prefix_caching_reading_kv_cache(): assert torch.equal(actual, expected) +def test_prefix_caching_reading_maps_invalid_block_ids_to_padding_block(): + kv_cache_mgr = _pa_prepare_cache_mgr( + tp_degree=1, + batch_size=1, + pa_num_blocks=4, + pa_block_size=128, + seq_len=128 * 128, + num_attention_heads=4, + num_kv_head=2, + hidden_size=8, + num_hidden_layers=1, + is_prefix_caching=True, + ) + assert not kv_cache_mgr.block_tiling + _pa_mock_kv_cache_in_mgr(kv_cache_mgr) + + pad_block_id = kv_cache_mgr.pa_num_blocks + kv_cache_mgr.past_key_values[0].data[pad_block_id].fill_(99) + safe_table = BlockKVCacheManager._safe_active_block_table( + torch.tensor([[0, -1, 2, 999]], dtype=torch.int64), + kv_cache_mgr.pa_num_blocks + BlockKVCacheManager._NUM_EXTRA_RESERVED_BLOCK, + ) + assert torch.equal( + safe_table, + torch.tensor([[0, pad_block_id, 2, pad_block_id]], dtype=torch.int64), + ) + + k_cache, _ = kv_cache_mgr.get_kv_by_layer_id( + idx=0, + active_block_table=torch.tensor([[0, -1, 2, 999]], dtype=torch.int64), + ) + assert torch.equal(k_cache[:, :, :128, :], torch.zeros_like(k_cache[:, :, :128, :])) + assert torch.equal( + k_cache[:, :, 128:256, :], + torch.full_like(k_cache[:, :, 128:256, :], 99), + ) + assert torch.equal( + k_cache[:, :, 256:384, :], + torch.full_like(k_cache[:, :, 256:384, :], 2), + ) + assert torch.equal( + k_cache[:, :, 384:512, :], + torch.full_like(k_cache[:, :, 384:512, :], 99), + ) + + +def test_prefix_caching_tiled_reading_maps_invalid_block_ids_to_padding_block(): + kv_cache_mgr = _pa_prepare_cache_mgr( + tp_degree=1, + batch_size=2, + pa_num_blocks=16, + pa_block_size=128, + seq_len=1024, + num_attention_heads=8, + num_kv_head=4, + hidden_size=16, + num_hidden_layers=1, + is_prefix_caching=True, + ) + assert kv_cache_mgr.block_tiling + _pa_mock_kv_cache_in_mgr(kv_cache_mgr) + + pad_block_id = kv_cache_mgr.pa_num_blocks + kv_cache_mgr.past_key_values[0].data[pad_block_id].fill_(77) + k_cache, _ = kv_cache_mgr.get_kv_by_layer_id( + idx=0, + active_block_table=torch.tensor([[0, -1, 2, 999]], dtype=torch.int64), + ) + assert torch.equal(k_cache[:, :, :128, :], torch.zeros_like(k_cache[:, :, :128, :])) + assert torch.equal( + k_cache[:, :, 128:256, :], + torch.full_like(k_cache[:, :, 128:256, :], 77), + ) + assert torch.equal( + k_cache[:, :, 256:384, :], + torch.full_like(k_cache[:, :, 256:384, :], 2), + ) + assert torch.equal( + k_cache[:, :, 384:512, :], + torch.full_like(k_cache[:, :, 384:512, :], 77), + ) + + def _pa_mock_kv_cache_in_mgr(kv_cache_mgr: BlockKVCacheManager): for layer_id in range(len(kv_cache_mgr.past_key_values)): for block_id in range(kv_cache_mgr.pa_num_blocks): @@ -860,6 +984,77 @@ def test_chunked_prefill_writing_kv_cache(): assert torch.equal(actual[block_id, :, block_offset, :], expected[0, :, seq_pos, :]) +def test_chunked_prefill_tkg_read_maps_invalid_block_ids_to_padding_block(): + kv_cache_mgr = _pa_prepare_cache_mgr( + tp_degree=1, + batch_size=1, + pa_num_blocks=4, + pa_block_size=2, + seq_len=128, + num_attention_heads=2, + num_kv_head=1, + hidden_size=2, + num_hidden_layers=1, + chunked_prefill_config=ChunkedPrefillConfig(tkg_model_enabled=True), + ) + with torch.no_grad(): + for block_id in range(4): + kv_cache_mgr.past_key_values[0][block_id].fill_(block_id) + + k_cache, _ = kv_cache_mgr.get_kv_by_layer_id( + idx=0, + active_block_table=torch.tensor([[0, -1, 2, 999]], dtype=torch.int64), + is_for_context_encoding=False, + ) + + assert torch.equal(k_cache[:, :, 0:2, :], torch.zeros_like(k_cache[:, :, 0:2, :])) + assert torch.equal( + k_cache[:, :, 2:4, :], + torch.full_like(k_cache[:, :, 2:4, :], 3), + ) + assert torch.equal( + k_cache[:, :, 4:6, :], + torch.full_like(k_cache[:, :, 4:6, :], 2), + ) + assert torch.equal( + k_cache[:, :, 6:8, :], + torch.full_like(k_cache[:, :, 6:8, :], 3), + ) + + +def test_chunked_prefill_update_maps_any_invalid_slot_to_padding_slot(): + kv_cache_mgr = _pa_prepare_cache_mgr( + tp_degree=1, + batch_size=1, + pa_num_blocks=4, + pa_block_size=2, + seq_len=128, + num_attention_heads=2, + num_kv_head=1, + hidden_size=2, + num_hidden_layers=1, + chunked_prefill_config=ChunkedPrefillConfig(tkg_model_enabled=True), + ) + latest = _pa_prepare_latest_kv_cache( + batch_size=1, + n_active_tokens=4, + head_dim=1, + num_kv_heads_per_rank=1, + num_hidden_layers=1, + ) + + updated_cache = kv_cache_mgr.update_cache( + is_for_context_encoding=False, + seq_ids=None, + position_ids=None, + new_key_values=latest, + seq_len=None, + scatter_index=torch.tensor([[0, -1, -2, 999]], dtype=torch.int64), + ) + + assert torch.equal(updated_cache[0][0, :, 0, :], latest[0][0][0, :, 0, :]) + + def _pa_prepare_latest_kv_cache( batch_size=1, n_active_tokens=32, @@ -983,6 +1178,50 @@ def test_generate_tokengen_slot_mapping_bs1(): assert torch.allclose(cpu_result, device_result) +def test_generate_tokengen_slot_mapping_masks_inactive_oob_block_index(): + block_size = torch.tensor(256, dtype=torch.int32) + position_ids = torch.tensor([[1024], [1023]], dtype=torch.int32) + target_slot_mapping = torch.tensor([[-1], [1023]], dtype=torch.int32) + block_table = torch.tensor( + [ + [0, 1, 2, 3], + [0, 1, 2, 3], + ], + dtype=torch.int32, + ) + + result = generate_tokengen_slot_mapping( + position_ids, + target_slot_mapping, + block_table, + block_size, + ) + + assert torch.equal(result, torch.tensor([[-1], [1023]], dtype=torch.int32)) + + +def test_generate_tokengen_slot_mapping_masks_active_invalid_table_entry(): + block_size = torch.tensor(256, dtype=torch.int32) + position_ids = torch.tensor([[1024], [1023]], dtype=torch.int32) + target_slot_mapping = torch.tensor([[1024], [1023]], dtype=torch.int32) + block_table = torch.tensor( + [ + [0, 1, 2, 3], + [0, 1, 2, 3], + ], + dtype=torch.int32, + ) + + result = generate_tokengen_slot_mapping( + position_ids, + target_slot_mapping, + block_table, + block_size, + ) + + assert torch.equal(result, torch.tensor([[-1], [1023]], dtype=torch.int32)) + + def test_generate_fusedspec_slot_mapping(): batch_size = 4 speculation_length = 5 @@ -1023,3 +1262,40 @@ def test_generate_fusedspec_slot_mapping(): assert cpu_result.shape == device_result.shape assert cpu_result.dtype == device_result.dtype assert torch.allclose(cpu_result, device_result) + + +def test_generate_fusedspec_slot_mapping_masks_inactive_oob_block_indices(): + block_size = torch.tensor(256, dtype=torch.int32) + position_ids = torch.tensor([[1024], [1022]], dtype=torch.int32) + target_slot_mapping = torch.tensor( + [ + [-1, -1, -1], + [1022, 1023, 1024], + ], + dtype=torch.int32, + ) + block_table = torch.tensor( + [ + [0, 1, 2, 3], + [0, 1, 2, 3], + ], + dtype=torch.int32, + ) + + result = generate_fusedspec_slot_mapping( + position_ids, + target_slot_mapping, + block_table, + block_size, + ) + + assert torch.equal( + result, + torch.tensor( + [ + [-1, -1, -1], + [1022, 1023, -1], + ], + dtype=torch.int32, + ), + ) diff --git a/test/unit/modules/kvcache/test_hybrid_prefix_cache.py b/test/unit/modules/kvcache/test_hybrid_prefix_cache.py new file mode 100644 index 00000000..495a2c07 --- /dev/null +++ b/test/unit/modules/kvcache/test_hybrid_prefix_cache.py @@ -0,0 +1,141 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from neuronx_distributed_inference.modules.kvcache.hybrid_prefix_cache import ( + HybridPrefixCheckpointCache, +) + + +def _states(value): + return {0: f"l0-{value}", 1: f"l1-{value}", 2: f"l2-{value}"} + + +class TestHybridPrefixCheckpointCache(unittest.TestCase): + def test_reuses_deepest_cumulative_prefix_checkpoint(self): + cache = HybridPrefixCheckpointCache( + required_gdn_layers=[0, 1, 2], + checkpoint_interval=256, + ) + cache.put_checkpoint( + cumulative_prefix_hash="h256", + prefix_len=256, + recurrent_states=_states("r256"), + conv_states=_states("c256"), + ) + key512 = cache.put_checkpoint( + cumulative_prefix_hash="h512", + prefix_len=512, + recurrent_states=_states("r512"), + conv_states=_states("c512"), + ) + + plan = cache.compute_reuse_plan( + cumulative_hashes_by_prefix_len={ + 256: "h256", + 512: "h512", + 768: "h768", + 1024: "h1024", + }, + attention_hit_len=1024, + request_prefix_len=1200, + ) + + self.assertEqual(plan.checkpoint_key, key512) + self.assertEqual(plan.restore_checkpoint_prefix_len, 512) + self.assertEqual(plan.residual_replay_len, 512) + self.assertEqual(plan.suffix_len, 176) + + def test_missing_gdn_family_state_is_not_accepted(self): + cache = HybridPrefixCheckpointCache( + required_gdn_layers=[0, 1, 2], + checkpoint_interval=256, + ) + with self.assertRaisesRegex(ValueError, "every required GDN layer"): + cache.put_checkpoint( + cumulative_prefix_hash="h512", + prefix_len=512, + recurrent_states={0: "r0", 1: "r1", 2: "r2"}, + conv_states={0: "c0", 1: "c1"}, + ) + + def test_hash_salt_and_revision_are_part_of_identity(self): + cache = HybridPrefixCheckpointCache( + required_gdn_layers=[0, 1, 2], + checkpoint_interval=256, + ) + cache.put_checkpoint( + cumulative_prefix_hash="same-hash", + prefix_len=256, + recurrent_states=_states("r"), + conv_states=_states("c"), + cache_salt="tenant-a", + model_revision="rev-a", + ) + + miss = cache.compute_reuse_plan( + cumulative_hashes_by_prefix_len={256: "same-hash"}, + attention_hit_len=256, + request_prefix_len=300, + cache_salt="tenant-b", + model_revision="rev-a", + ) + hit = cache.compute_reuse_plan( + cumulative_hashes_by_prefix_len={256: "same-hash"}, + attention_hit_len=256, + request_prefix_len=300, + cache_salt="tenant-a", + model_revision="rev-a", + ) + + self.assertIsNone(miss.checkpoint_key) + self.assertIsNotNone(hit.checkpoint_key) + + def test_refcount_blocks_eviction(self): + cache = HybridPrefixCheckpointCache( + required_gdn_layers=[0, 1, 2], + checkpoint_interval=256, + max_checkpoints=2, + ) + key1 = cache.put_checkpoint( + cumulative_prefix_hash="h256", + prefix_len=256, + recurrent_states=_states("r256"), + conv_states=_states("c256"), + ) + key2 = cache.put_checkpoint( + cumulative_prefix_hash="h512", + prefix_len=512, + recurrent_states=_states("r512"), + conv_states=_states("c512"), + ) + cache.inc_ref(key1) + key3 = cache.put_checkpoint( + cumulative_prefix_hash="h768", + prefix_len=768, + recurrent_states=_states("r768"), + conv_states=_states("c768"), + ) + + self.assertIsNotNone(cache.get_checkpoint(key1)) + self.assertIsNone(cache.get_checkpoint(key2)) + self.assertIsNotNone(cache.get_checkpoint(key3)) + + def test_checkpoint_length_must_align_to_interval(self): + cache = HybridPrefixCheckpointCache( + required_gdn_layers=[0, 1, 2], + checkpoint_interval=256, + ) + + with self.assertRaisesRegex(ValueError, "checkpoint_interval"): + cache.put_checkpoint( + cumulative_prefix_hash="h300", + prefix_len=300, + recurrent_states=_states("r300"), + conv_states=_states("c300"), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit/modules/test_async_execution.py b/test/unit/modules/test_async_execution.py index f3528131..8a62f78b 100644 --- a/test/unit/modules/test_async_execution.py +++ b/test/unit/modules/test_async_execution.py @@ -1,10 +1,49 @@ import unittest +from contextlib import redirect_stdout +from io import StringIO +from types import SimpleNamespace from typing import List from unittest.mock import patch import torch -from neuronx_distributed_inference.modules.async_execution import AsyncTensorWrapper +from neuronx_distributed_inference.modules.async_execution import ( + AsyncTensorWrapper, + _async_request_ids_signature, + _combine_vectorized_hybrid_apc_inputs, + _is_chunked_prefill_execution, + _is_context_encoding_execution, + _with_hybrid_apc_candidate_owner_metadata, + _with_hybrid_apc_owner_metadata, + cancel_hybrid_apc_request, + execute_model_prefix_caching, + finish_hybrid_apc_request, + prepare_disabled_hybrid_apc_model_inputs, + prepare_hybrid_apc_model_inputs, + prepare_hybrid_apc_request_for_execution, +) + + +class TestAsyncRequestIdsSignature(unittest.TestCase): + def test_signature_preserves_request_order(self): + model = SimpleNamespace(_qwen36_vllm_request_ids=["req-b", "req-a"]) + + self.assertEqual( + _async_request_ids_signature(model), + ("req-b", "req-a"), + ) + + def test_signature_accepts_single_request_id(self): + model = SimpleNamespace(_qwen36_vllm_request_ids="req-a") + + self.assertEqual(_async_request_ids_signature(model), ("req-a",)) + + def test_signature_accepts_tensor_request_ids(self): + model = SimpleNamespace( + _qwen36_vllm_request_ids=torch.tensor([1, 0], dtype=torch.int32) + ) + + self.assertEqual(_async_request_ids_signature(model), (1, 0)) class TestAsyncTensorWrapper(unittest.TestCase): @@ -279,3 +318,2309 @@ def test_early_exit(self, mock_is_ranked_io): ) assert res is None, f"Early Exit should return None, but found {res}" + + +class TestHybridAPCAsyncBridge(unittest.TestCase): + def test_bridge_is_empty_when_hybrid_apc_disabled(self): + base = SimpleNamespace(config=SimpleNamespace(use_hybrid_apc_manager=False)) + + args = prepare_hybrid_apc_model_inputs(base, {"seq_ids": torch.tensor([0])}) + + self.assertEqual(args, []) + + def test_bridge_builds_restore_and_commit_tensors(self): + base = SimpleNamespace(config=SimpleNamespace(use_hybrid_apc_manager=True)) + input_dict = { + "seq_ids": torch.tensor([3, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([[256], [0]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([7, 0], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1, 0], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([8, 9], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([1, 0], dtype=torch.int32), + } + + args = prepare_hybrid_apc_model_inputs(base, input_dict) + + self.assertEqual(len(args), 14) + self.assertTrue(torch.equal(args[9], torch.tensor([7, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(args[10], torch.tensor([1, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(args[11], torch.tensor([256, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(args[12], torch.tensor([8, 9], dtype=torch.int32))) + self.assertTrue(torch.equal(args[13], torch.tensor([1, 0], dtype=torch.int32))) + + def test_bridge_does_not_infer_restore_mask_from_slot_presence(self): + base = SimpleNamespace(config=SimpleNamespace(use_hybrid_apc_manager=True)) + input_dict = { + "seq_ids": torch.tensor([3], dtype=torch.int32), + "computed_context_lens": torch.tensor([[256]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([7], dtype=torch.int32), + } + + args = prepare_hybrid_apc_model_inputs(base, input_dict) + + self.assertTrue(torch.equal(args[9], torch.tensor([7], dtype=torch.int32))) + self.assertTrue(torch.equal(args[10], torch.tensor([0], dtype=torch.int32))) + self.assertTrue(torch.equal(args[11], torch.tensor([256], dtype=torch.int32))) + + def test_bridge_debug_switches_zero_restore_and_commit_masks(self): + base = SimpleNamespace(config=SimpleNamespace(use_hybrid_apc_manager=True)) + input_dict = { + "seq_ids": torch.tensor([3, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([[256], [0]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([7, 0], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1, 0], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([8, 9], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([1, 1], dtype=torch.int32), + } + + with patch.dict( + "os.environ", + {"QWEN36_DISABLE_HYBRID_GDN_RESTORE_COMMIT": "1"}, + ): + args = prepare_hybrid_apc_model_inputs(base, input_dict) + + self.assertTrue(torch.equal(args[9], torch.tensor([7, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(args[10], torch.tensor([0, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(args[12], torch.tensor([8, 9], dtype=torch.int32))) + self.assertTrue(torch.equal(args[13], torch.tensor([0, 0], dtype=torch.int32))) + + def test_bridge_rejects_active_slot_out_of_range(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + max_gdn_checkpoint_slots=2, + ) + ) + input_dict = { + "seq_ids": torch.tensor([0], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([2], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1], dtype=torch.int32), + } + + with self.assertRaisesRegex(ValueError, "outside \\[0, 2\\)"): + prepare_hybrid_apc_model_inputs(base, input_dict) + + def test_bridge_validates_active_slots_against_allocator(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + max_gdn_checkpoint_slots=10, + ), + hybrid_apc_slot_allocator=SimpleNamespace( + committed_slots=(5,), + reserved_slots=(7,), + ), + ) + input_dict = { + "seq_ids": torch.tensor([0], dtype=torch.int32), + "computed_context_lens": torch.tensor([[128]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([5], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([7], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([1], dtype=torch.int32), + } + + args = prepare_hybrid_apc_model_inputs(base, input_dict) + self.assertTrue(torch.equal(args[9], torch.tensor([5], dtype=torch.int32))) + self.assertTrue(torch.equal(args[12], torch.tensor([7], dtype=torch.int32))) + + input_dict["hybrid_commit_slot_ids"] = torch.tensor([6], dtype=torch.int32) + with self.assertRaisesRegex(ValueError, "not a reserved checkpoint slot"): + prepare_hybrid_apc_model_inputs(base, input_dict) + + def test_disabled_bridge_builds_inert_decode_args_without_validation(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + max_gdn_checkpoint_slots=1, + ), + hybrid_apc_slot_allocator=SimpleNamespace( + committed_slots=(), + reserved_slots=(), + ), + ) + input_dict = { + "seq_ids": torch.tensor([3, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([[128], [256]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([99, 100], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1, 1], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([101, 102], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([1, 1], dtype=torch.int32), + } + + args = prepare_disabled_hybrid_apc_model_inputs(base, input_dict) + + self.assertEqual(len(args), 14) + for index in (9, 10, 11, 12, 13): + self.assertTrue( + torch.equal(args[index], torch.zeros((2,), dtype=torch.int32)) + ) + + def test_prefix_caching_execution_prepares_and_finishes_hybrid_apc(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel() + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_apc_bridge": bridge, + "request_id": "req-1", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + "cumulative_hashes_by_prefix_len": {2: "h2", 4: "h4"}, + "attention_block_refs": {4: (11, 12)}, + "actual_refs": (21, 22), + } + ) + + result, is_neuron = execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(result, "model-output") + self.assertFalse(is_neuron) + self.assertEqual(bridge.prepare_kwargs["request_id"], "req-1") + self.assertEqual(bridge.prepare_kwargs["attention_hit_len"], 2) + self.assertEqual( + bridge.prepare_kwargs["cumulative_hashes_by_prefix_len"], + {2: "h2", 4: "h4"}, + ) + self.assertEqual( + bridge.prepare_kwargs["attention_block_refs_by_prefix_len"], + {4: (11, 12)}, + ) + self.assertTrue( + torch.equal( + model.calls[0][0], + torch.tensor([[12, 13]], dtype=torch.int32), + ) + ) + self.assertIn("_hybrid_apc_prepared", input_dict) + + finish_hybrid_apc_request(input_dict) + + self.assertEqual(bridge.committed[0][0].request_id, "req-1") + self.assertEqual(bridge.committed[0][1], (21, 22)) + self.assertEqual(bridge.finished, ["req-1"]) + self.assertNotIn("_hybrid_apc_prepared", input_dict) + + def test_prefix_caching_execution_uses_wrapper_hybrid_apc_owner(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel() + model.config = SimpleNamespace(use_hybrid_apc_manager=True) + model.hybrid_apc_bridge = bridge + input_dict = _prefix_input_dict() + input_dict.update( + { + "request_id": "req-wrapper-owner", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + "actual_refs": (31, 32), + } + ) + + result, is_neuron = execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(result, "model-output") + self.assertFalse(is_neuron) + self.assertEqual(bridge.prepare_kwargs["request_id"], "req-wrapper-owner") + self.assertIn("_hybrid_apc_prepared", input_dict) + self.assertTrue( + torch.equal(model.calls[0][-1], torch.tensor([1], dtype=torch.int32)) + ) + + finish_hybrid_apc_request(input_dict) + + self.assertEqual(bridge.committed[0][0].request_id, "req-wrapper-owner") + self.assertEqual(bridge.committed[0][1], (31, 32)) + self.assertEqual(bridge.finished, ["req-wrapper-owner"]) + + def test_prefix_caching_execution_uses_wrapper_direct_runtime_flag(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=False), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel() + model.use_hybrid_apc_manager = True + model.hybrid_apc_bridge = bridge + input_dict = _prefix_input_dict() + input_dict.update( + { + "request_id": "req-wrapper-direct", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + result, is_neuron = execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(result, "model-output") + self.assertFalse(is_neuron) + self.assertEqual(bridge.prepare_kwargs["request_id"], "req-wrapper-direct") + self.assertIn("_hybrid_apc_prepared", input_dict) + + def test_prefix_caching_execution_finds_context_wrapper_bridge(self): + bridge = _FakeHybridBridge() + context_owner = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + context_encoding_model=context_owner, + ) + model = _FakePrefixModel() + input_dict = _prefix_input_dict() + input_dict.update( + { + "request_id": "req-context-owner", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(bridge.prepare_kwargs["request_id"], "req-context-owner") + self.assertIn("_hybrid_apc_prepared", input_dict) + + def test_prefix_caching_execution_reuses_last_bridge_for_continuation(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + hybrid_apc_bridge=bridge, + ) + model = _FakePrefixModel() + first_input = _prefix_input_dict() + first_input.update( + { + "request_id": "req-first", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + execute_model_prefix_caching(base, model, first_input) + finish_hybrid_apc_request(first_input) + base.hybrid_apc_bridge = None + + second_input = _prefix_input_dict() + second_input.update( + { + "request_id": "req-second", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + execute_model_prefix_caching(base, model, second_input) + + self.assertEqual(bridge.prepare_kwargs["request_id"], "req-second") + self.assertIn("_hybrid_apc_prepared", second_input) + + def test_prefix_caching_execution_uses_wrapper_scheduler_records(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel() + model.config = SimpleNamespace(use_hybrid_apc_manager=True) + model.hybrid_apc_bridge = bridge + model._qwen36_vllm_request_ids = ("req-from-record",) + model._qwen36_vllm_hybrid_apc_request_records = ( + { + "request_id": "req-from-record", + "vllm_attention_hit_len": 2, + "request_prefix_len": 4, + "cumulative_hashes_by_prefix_len": {2: "h2", 4: "h4"}, + }, + ) + input_dict = _prefix_input_dict() + + result, is_neuron = execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(result, "model-output") + self.assertFalse(is_neuron) + self.assertEqual(bridge.prepare_kwargs["request_id"], "req-from-record") + self.assertEqual(bridge.prepare_kwargs["attention_hit_len"], 2) + self.assertEqual( + bridge.prepare_kwargs["cumulative_hashes_by_prefix_len"], + {2: "h2", 4: "h4"}, + ) + self.assertIn("_hybrid_apc_prepared", input_dict) + + def test_prefix_caching_execution_does_not_prepare_hybrid_apc_for_generation(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel(tag="token_generation_model") + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.tensor([[13]], dtype=torch.int32), + "attention_mask": torch.ones((1, 5), dtype=torch.int32), + "position_ids": torch.tensor([[4]], dtype=torch.int32), + "slot_mapping": torch.tensor([[4]], dtype=torch.int32), + "full_context_lens": torch.tensor([[5]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[4]], dtype=torch.int32), + "hybrid_apc_bridge": bridge, + "request_id": "req-1", + "vllm_attention_hit_len": torch.tensor([4], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([5], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([7], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([1], dtype=torch.int32), + } + ) + + with patch( + "neuronx_distributed_inference.modules.async_execution.prepare_hybrid_apc_model_inputs", + side_effect=AssertionError("decode should use inert Hybrid APC args"), + ): + result, is_neuron = execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(result, "model-output") + self.assertFalse(is_neuron) + self.assertEqual(bridge.prepare_calls, []) + self.assertTrue( + torch.equal(model.calls[0][-4], torch.tensor([0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(model.calls[0][-1], torch.tensor([0], dtype=torch.int32)) + ) + + def test_prefix_caching_generation_rejects_invalid_token_before_neuron(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + vocab_size=248320, + ), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + _qwen36_vllm_request_ids=("req-invalid",), + ) + model = _FakePrefixModel(tag="token_generation_model") + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.tensor([[1065353216]], dtype=torch.int32), + "attention_mask": torch.ones((1, 1280), dtype=torch.int32), + "position_ids": torch.tensor([[1024]], dtype=torch.int32), + "slot_mapping": torch.tensor([[2560]], dtype=torch.int32), + "block_table": torch.tensor([[6, 7, 8, 9, 10]], dtype=torch.int32), + "full_context_lens": torch.tensor([[1025]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[1024]], dtype=torch.int32), + "num_queries": torch.tensor([[1]], dtype=torch.int32), + } + ) + + with self.assertRaisesRegex( + ValueError, + "Token generation input_ids contract violated before Neuron execution", + ) as cm: + execute_model_prefix_caching(base, model, input_dict) + + self.assertIn("0x3f800000", str(cm.exception)) + self.assertEqual(model.calls, []) + + def test_chunked_prefill_with_nonzero_positions_still_uses_context_execution(self): + base = SimpleNamespace( + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + _is_prefill=lambda position_ids: not bool(position_ids.min().item()), + ) + model = _FakePrefixModel(tag="token_generation_model") + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.arange(256, dtype=torch.int32).reshape(1, 256), + "position_ids": torch.arange(256, 512, dtype=torch.int32).reshape(1, 256), + } + ) + + self.assertTrue(_is_context_encoding_execution(base, model, input_dict)) + self.assertTrue( + _is_chunked_prefill_execution( + base, + input_dict, + is_fused_speculation=False, + ) + ) + + def test_single_token_decode_remains_generation_execution(self): + base = SimpleNamespace( + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + _is_prefill=lambda position_ids: not bool(position_ids.min().item()), + ) + model = _FakePrefixModel(tag="token_generation_model") + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.tensor([[13]], dtype=torch.int32), + "position_ids": torch.tensor([[512]], dtype=torch.int32), + } + ) + + self.assertFalse(_is_context_encoding_execution(base, model, input_dict)) + self.assertFalse( + _is_chunked_prefill_execution( + base, + input_dict, + is_fused_speculation=False, + ) + ) + + def test_single_token_cached_prefill_continuation_uses_context_execution(self): + base = SimpleNamespace( + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + _is_prefill=lambda position_ids: not bool(position_ids.min().item()), + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.empty((1, 0), dtype=torch.int32), + "position_ids": torch.empty((1, 0), dtype=torch.int32), + "hybrid_prefill_completion_state": torch.tensor([0], dtype=torch.int32), + "vllm_attention_hit_len": torch.tensor([2048], dtype=torch.int32), + "request_prefix_len": 2049, + "active_suffix_len": 1, + } + ) + + self.assertTrue( + _is_chunked_prefill_execution( + base, + input_dict, + is_fused_speculation=False, + ) + ) + + def test_owner_metadata_single_token_continuation_uses_context_execution(self): + base = SimpleNamespace( + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + _qwen36_vllm_request_ids=("req-2049",), + _qwen36_vllm_prefill_completion_state=torch.tensor( + [0], + dtype=torch.int32, + ), + _qwen36_vllm_hybrid_apc_metadata_by_request_id={ + "req-2049": { + "vllm_attention_hit_len": 2048, + "request_prefix_len": 2049, + "active_suffix_len": 1, + "full_input_ids": tuple(range(2049)), + }, + }, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.empty((1, 0), dtype=torch.int32), + "position_ids": torch.empty((1, 0), dtype=torch.int32), + } + ) + + enriched = _with_hybrid_apc_owner_metadata(input_dict, base) + + self.assertIn("hybrid_request_records", enriched) + self.assertTrue( + _is_chunked_prefill_execution( + base, + enriched, + is_fused_speculation=False, + ) + ) + + def test_candidate_owner_metadata_uses_wrapper_records_for_prefill_probe(self): + base = SimpleNamespace() + wrapper = SimpleNamespace( + _qwen36_vllm_request_ids=("req-wrapper-2049",), + _qwen36_vllm_prefill_completion_state=torch.tensor( + [0], + dtype=torch.int32, + ), + _qwen36_vllm_hybrid_apc_metadata_by_request_id={ + "req-wrapper-2049": { + "vllm_attention_hit_len": 2048, + "request_prefix_len": 2049, + "active_suffix_len": 1, + }, + }, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.empty((1, 0), dtype=torch.int32), + "position_ids": torch.empty((1, 0), dtype=torch.int32), + } + ) + + enriched = _with_hybrid_apc_candidate_owner_metadata(input_dict, base, wrapper) + + self.assertIn("hybrid_request_records", enriched) + self.assertTrue( + _is_chunked_prefill_execution( + SimpleNamespace( + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ) + ), + enriched, + is_fused_speculation=False, + ) + ) + + def test_completed_single_token_decode_stays_generation_execution(self): + base = SimpleNamespace( + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + _is_prefill=lambda position_ids: not bool(position_ids.min().item()), + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.tensor([[13]], dtype=torch.int32), + "position_ids": torch.tensor([[2048]], dtype=torch.int32), + "hybrid_prefill_completion_state": torch.tensor([1], dtype=torch.int32), + "vllm_attention_hit_len": torch.tensor([2048], dtype=torch.int32), + "request_prefix_len": 2048, + "active_suffix_len": 1, + } + ) + + self.assertFalse( + _is_chunked_prefill_execution( + base, + input_dict, + is_fused_speculation=False, + ) + ) + + def test_commit_debug_switch_cancels_instead_of_committing_metadata(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel() + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_apc_bridge": bridge, + "request_id": "req-no-commit", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + with patch.dict("os.environ", {"QWEN36_DISABLE_HYBRID_GDN_COMMIT": "1"}): + execute_model_prefix_caching(base, model, input_dict) + finish_hybrid_apc_request(input_dict) + + self.assertEqual(bridge.committed, []) + self.assertEqual(bridge.cancelled[0].request_id, "req-no-commit") + self.assertNotIn("_hybrid_apc_prepared", input_dict) + + def test_hybrid_apc_debug_trace_includes_restore_commit_evidence(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "request_id": "req-debug", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + stdout = StringIO() + with patch.dict("os.environ", {"QWEN36_HYBRID_APC_DEBUG": "1"}): + with redirect_stdout(stdout): + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + log = stdout.getvalue() + self.assertIn("[hybrid_apc_debug] prepare", log) + self.assertIn("request_id='req-debug'", log) + self.assertIn("attention_hit_len=2", log) + self.assertIn("restore_len=2", log) + self.assertIn("commit_prefix_len=4", log) + self.assertIn("restore_slot=5", log) + self.assertIn("commit_slot=7", log) + self.assertIn("input_shape=(1, 4)", log) + self.assertIn("prepared_shape=(1, 2)", log) + self.assertIn("computed=tensor([[2]], dtype=torch.int32)", log) + self.assertIn("restore_mask=tensor([1], dtype=torch.int32)", log) + self.assertIn("commit_mask=tensor([1], dtype=torch.int32)", log) + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor([[12, 13]], dtype=torch.int32), + ) + ) + + def test_prepare_debug_switches_zero_prepared_restore_commit_masks(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "request_id": "req-debug-switches", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + with patch.dict( + "os.environ", + {"QWEN36_DISABLE_HYBRID_GDN_RESTORE_COMMIT": "1"}, + ): + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertTrue( + torch.equal( + prepared["hybrid_restore_slot_ids"], + torch.tensor([5], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_slot_ids"], + torch.tensor([7], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_prepare_debug_switch_zeroes_only_restore_mask(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "request_id": "req-zero-restore-only", + "vllm_attention_hit_len": torch.tensor([2], dtype=torch.int32), + } + ) + + with patch.dict( + "os.environ", + {"QWEN36_ZERO_HYBRID_GDN_RESTORE_MASK": "1"}, + ): + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([1], dtype=torch.int32), + ) + ) + + def test_prefix_caching_execution_cancels_hybrid_apc_on_model_failure(self): + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + ) + bridge = _FakeHybridBridge() + model = _FakePrefixModel(should_fail=True) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_apc_bridge": bridge, + "request_id": "req-1", + "vllm_attention_hit_len": 2, + } + ) + + with self.assertRaisesRegex(RuntimeError, "model failed"): + execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(bridge.cancelled[0].request_id, "req-1") + self.assertNotIn("_hybrid_apc_prepared", input_dict) + + def test_prefix_caching_execution_uses_model_registered_bridge_and_derived_hit(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + hybrid_apc_bridge=bridge, + ) + model = _FakePrefixModel() + input_dict = _prefix_input_dict() + + result, is_neuron = execute_model_prefix_caching(base, model, input_dict) + + self.assertEqual(result, "model-output") + self.assertFalse(is_neuron) + self.assertEqual(bridge.prepare_kwargs["request_id"], ("seq_id", 0)) + self.assertEqual(bridge.prepare_kwargs["attention_hit_len"], 0) + + def test_prefix_caching_execution_uses_full_prompt_tokens_for_suffix_request(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + hybrid_apc_bridge=bridge, + ) + model = _FakePrefixModel() + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["hybrid_full_input_ids"] = torch.tensor( + [[10, 11, 12, 13]], + dtype=torch.int32, + ) + + execute_model_prefix_caching(base, model, input_dict) + + self.assertTrue( + torch.equal( + bridge.prepare_kwargs["input_dict"]["input_ids"], + torch.tensor([[10, 11, 12, 13]], dtype=torch.int32), + ) + ) + self.assertEqual(bridge.prepare_kwargs["attention_hit_len"], 2) + + def test_prefix_caching_execution_skips_hybrid_apc_for_suffix_without_full_prompt(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + hybrid_apc_bridge=bridge, + ) + model = _FakePrefixModel() + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + + execute_model_prefix_caching(base, model, input_dict) + + self.assertIsNone(bridge.prepare_kwargs) + self.assertTrue( + torch.equal( + model.calls[0][0], + torch.tensor([[12, 13]], dtype=torch.int32), + ) + ) + + def test_single_token_same_request_suffix_prepares_hybrid_apc(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + neuron_config=SimpleNamespace( + enable_fused_speculation=False, + enable_eagle_speculation=False, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "input_ids": torch.tensor([[99]], dtype=torch.int32), + "attention_mask": torch.ones((1, 1), dtype=torch.int32), + "position_ids": torch.tensor([[2048]], dtype=torch.int32), + "slot_mapping": torch.tensor([[2304]], dtype=torch.int32), + "block_table": torch.arange(10, dtype=torch.int32).reshape(1, 10), + "computed_context_lens": torch.tensor([[2048]], dtype=torch.int32), + "full_context_lens": torch.tensor([[2049]], dtype=torch.int32), + "request_id": "req-2049", + "hybrid_cached_request_ids": ("req-2049",), + "hybrid_prefill_completion_state": torch.tensor([0], dtype=torch.int32), + "vllm_attention_hit_len": torch.tensor([2048], dtype=torch.int32), + "request_prefix_len": 2049, + "active_suffix_len": 1, + } + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual(bridge.suffix_prepare_calls[0]["request_id"], "req-2049") + self.assertEqual(bridge.suffix_prepare_calls[0]["attention_hit_len"], 2048) + self.assertEqual(bridge.suffix_prepare_calls[0]["request_prefix_len"], 2049) + self.assertTrue( + torch.equal(prepared["input_ids"], torch.tensor([[99]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + prepared["computed_context_lens"], + torch.tensor([[2048]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal(prepared["num_queries"], torch.tensor([[1]], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(prepared["hybrid_restore_mask"], torch.tensor([0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_prefix_lens"], + torch.tensor([2048], dtype=torch.int32), + ) + ) + + def test_vectorized_no_hit_batch_skips_hybrid_apc_request_prep(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-a", "req-b"), + "full_context_lens": torch.tensor([4, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([0, 0], dtype=torch.int32), + } + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertIs(prepared, input_dict) + self.assertIsNone(bridge.prepare_kwargs) + self.assertNotIn("_hybrid_apc_prepared", input_dict) + + def test_vectorized_attention_hit_batch_prepares_each_row(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-a", "req-b"), + "vllm_attention_hit_len": (2, 2), + "request_prefix_len": (4, 4), + "full_context_lens": torch.tensor([4, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([2, 2], dtype=torch.int32), + } + ) + input_dict["input_ids"] = input_dict["input_ids"].repeat(2, 1) + input_dict["attention_mask"] = input_dict["attention_mask"].repeat(2, 1) + input_dict["position_ids"] = input_dict["position_ids"].repeat(2, 1) + input_dict["seq_ids"] = torch.tensor([0, 1], dtype=torch.int32) + input_dict["sampling_params"] = input_dict["sampling_params"].repeat(2, 1) + input_dict["adapter_ids"] = torch.tensor([0, 0], dtype=torch.int32) + input_dict["slot_mapping"] = input_dict["slot_mapping"].repeat(2, 1) + input_dict["block_table"] = input_dict["block_table"].repeat(2, 1) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual( + [call["request_id"] for call in bridge.prepare_calls], + ["req-a", "req-b"], + ) + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor([[12, 13], [12, 13]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["computed_context_lens"], + torch.tensor([[2], [2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2], [2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([1, 1], dtype=torch.int32), + ) + ) + self.assertEqual(len(input_dict["_hybrid_apc_prepared"]), 2) + + finish_hybrid_apc_request(input_dict) + + self.assertEqual([item[0].request_id for item in bridge.committed], ["req-a", "req-b"]) + self.assertEqual(bridge.finished, ["req-a", "req-b"]) + + def test_vectorized_strict_metadata_is_selected_per_row(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-a", "req-b"), + "full_context_lens": torch.tensor([4, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([2, 2], dtype=torch.int32), + "cumulative_hashes_by_prefix_len": ( + {2: b"a2", 4: b"a4"}, + {2: b"b2", 4: b"b4"}, + ), + "attention_block_refs_by_prefix_len": ( + {4: (11, 12)}, + {4: (21, 22)}, + ), + } + ) + input_dict["input_ids"] = input_dict["input_ids"].repeat(2, 1) + input_dict["attention_mask"] = input_dict["attention_mask"].repeat(2, 1) + input_dict["position_ids"] = input_dict["position_ids"].repeat(2, 1) + input_dict["seq_ids"] = torch.tensor([0, 1], dtype=torch.int32) + input_dict["sampling_params"] = input_dict["sampling_params"].repeat(2, 1) + input_dict["adapter_ids"] = torch.tensor([0, 0], dtype=torch.int32) + input_dict["slot_mapping"] = input_dict["slot_mapping"].repeat(2, 1) + input_dict["block_table"] = input_dict["block_table"].repeat(2, 1) + + prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual( + bridge.prepare_calls[0]["cumulative_hashes_by_prefix_len"], + {2: b"a2", 4: b"a4"}, + ) + self.assertEqual( + bridge.prepare_calls[1]["cumulative_hashes_by_prefix_len"], + {2: b"b2", 4: b"b4"}, + ) + self.assertEqual( + bridge.prepare_calls[0]["attention_block_refs_by_prefix_len"], + {4: (11, 12)}, + ) + self.assertEqual( + bridge.prepare_calls[1]["attention_block_refs_by_prefix_len"], + {4: (21, 22)}, + ) + + def test_vectorized_strict_metadata_keeps_missing_rows_aligned(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("cached-a", "new-a"), + "vllm_attention_hit_len": (None, 2), + "request_prefix_len": (None, 4), + "full_context_lens": torch.tensor([4, 4], dtype=torch.int32), + "computed_context_lens": torch.tensor([0, 2], dtype=torch.int32), + "cumulative_hashes_by_prefix_len": (None, {2: b"new-a-2"}), + "attention_block_refs_by_prefix_len": (None, {2: (21,)}), + } + ) + input_dict["input_ids"] = input_dict["input_ids"].repeat(2, 1) + input_dict["attention_mask"] = input_dict["attention_mask"].repeat(2, 1) + input_dict["position_ids"] = input_dict["position_ids"].repeat(2, 1) + input_dict["seq_ids"] = torch.tensor([0, 1], dtype=torch.int32) + input_dict["sampling_params"] = input_dict["sampling_params"].repeat(2, 1) + input_dict["adapter_ids"] = torch.tensor([0, 0], dtype=torch.int32) + input_dict["slot_mapping"] = input_dict["slot_mapping"].repeat(2, 1) + input_dict["block_table"] = input_dict["block_table"].repeat(2, 1) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual( + [call["request_id"] for call in bridge.prepare_calls], + ["new-a"], + ) + self.assertEqual( + bridge.prepare_calls[0]["cumulative_hashes_by_prefix_len"], + {2: b"new-a-2"}, + ) + self.assertEqual( + bridge.prepare_calls[0]["attention_block_refs_by_prefix_len"], + {2: (21,)}, + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0, 1], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[4], [2]], dtype=torch.int32), + ) + ) + + def test_vectorized_mixed_hit_batch_pads_prepared_rows(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-a", "req-b"), + "full_context_lens": torch.tensor([[4], [4]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[0], [2]], dtype=torch.int32), + } + ) + input_dict["input_ids"] = input_dict["input_ids"].repeat(2, 1) + input_dict["attention_mask"] = input_dict["attention_mask"].repeat(2, 1) + input_dict["position_ids"] = input_dict["position_ids"].repeat(2, 1) + input_dict["seq_ids"] = torch.tensor([0, 1], dtype=torch.int32) + input_dict["sampling_params"] = input_dict["sampling_params"].repeat(2, 1) + input_dict["adapter_ids"] = torch.tensor([0, 0], dtype=torch.int32) + input_dict["slot_mapping"] = input_dict["slot_mapping"].repeat(2, 1) + input_dict["block_table"] = input_dict["block_table"].repeat(2, 1) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor([[10, 11, 12, 13], [12, 13, 0, 0]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["attention_mask"], + torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["slot_mapping"], + torch.tensor([[0, 1, 2, 3], [2, 3, -1, -1]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[4], [2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0, 1], dtype=torch.int32), + ) + ) + + def test_vectorized_packed_suffix_batch_splits_by_query_lengths(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-a", "req-b"), + "input_ids": torch.tensor([[12, 13, 22, 23]], dtype=torch.int32), + "attention_mask": torch.ones((1, 4), dtype=torch.int32), + "position_ids": torch.tensor([[2, 3, 2, 3]], dtype=torch.int32), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "adapter_ids": torch.tensor([0], dtype=torch.int32), + "slot_mapping": torch.tensor([[2, 3, 6, 7]], dtype=torch.int32), + "block_table": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "full_context_lens": torch.tensor([[4], [4]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[2], [2]], dtype=torch.int32), + "num_queries": torch.tensor([[2], [2]], dtype=torch.int32), + "cumulative_hashes_by_prefix_len": ( + {2: "hash-a-2", 4: "hash-a-4"}, + {2: "hash-b-2", 4: "hash-b-4"}, + ), + "attention_block_refs_by_prefix_len": ( + {2: (1,), 4: (1, 2)}, + {2: (3,), 4: (3, 4)}, + ), + } + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual( + [call["request_id"] for call in bridge.suffix_prepare_calls], + ["req-a", "req-b"], + ) + self.assertEqual( + bridge.suffix_prepare_calls[0]["cumulative_hashes_by_prefix_len"], + {2: "hash-a-2", 4: "hash-a-4"}, + ) + self.assertEqual( + bridge.suffix_prepare_calls[1]["attention_block_refs_by_prefix_len"], + {2: (3,), 4: (3, 4)}, + ) + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor([[12, 13], [22, 23]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["position_ids"], + torch.tensor([[2, 3], [2, 3]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["slot_mapping"], + torch.tensor([[2, 3], [6, 7]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["seq_ids"], + torch.tensor([0, 1], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([1, 1], dtype=torch.int32), + ) + ) + + def test_vectorized_request_records_override_collapsed_metadata(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": "cold-only", + "vllm_attention_hit_len": 0, + "input_ids": torch.tensor( + [[12, 13, 20, 21, 22, 23]], + dtype=torch.int32, + ), + "attention_mask": torch.ones((1, 6), dtype=torch.int32), + "position_ids": torch.tensor( + [[2, 3, 0, 1, 2, 3]], + dtype=torch.int32, + ), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "adapter_ids": torch.tensor([0], dtype=torch.int32), + "slot_mapping": torch.tensor( + [[2, 3, 10, 11, 12, 13]], + dtype=torch.int32, + ), + "block_table": torch.tensor( + [[1, 2], [5, 6]], + dtype=torch.int32, + ), + "hybrid_request_records": ( + { + "request_id": "warm", + "vllm_attention_hit_len": 2, + "request_prefix_len": 4, + "active_suffix_len": 2, + "cumulative_hashes_by_prefix_len": { + 2: "warm-h2", + 4: "warm-h4", + }, + "attention_block_refs_by_prefix_len": { + 2: (1,), + 4: (1, 2), + }, + }, + { + "request_id": "cold", + "vllm_attention_hit_len": 0, + "request_prefix_len": 4, + "active_suffix_len": 4, + }, + ), + } + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual( + [call["request_id"] for call in bridge.suffix_prepare_calls], + ["warm"], + ) + self.assertEqual(bridge.prepare_calls, []) + self.assertEqual( + bridge.suffix_prepare_calls[0]["cumulative_hashes_by_prefix_len"], + {2: "warm-h2", 4: "warm-h4"}, + ) + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor([[12, 13, 0, 0], [20, 21, 22, 23]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([1, 0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2], [4]], dtype=torch.int32), + ) + ) + + def test_vectorized_cached_decode_row_does_not_require_prefix_restore(self): + bridge = _FakeHybridBridge() + original_prepare_request = bridge.prepare_request + + def prepare_request_with_vector_full_context_lens(**kwargs): + prepared = original_prepare_request(**kwargs) + prepared.input_dict["full_context_lens"] = prepared.input_dict[ + "full_context_lens" + ].reshape(-1) + return prepared + + bridge.prepare_request = prepare_request_with_vector_full_context_lens + base = SimpleNamespace( + config=SimpleNamespace(use_hybrid_apc_manager=True), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-cached", "req-new"), + "hybrid_cached_request_ids": ("req-cached",), + "hybrid_prefill_completion_state": torch.tensor( + [True, False], + dtype=torch.bool, + ), + "input_ids": torch.tensor([[99, 20, 21, 22]], dtype=torch.int32), + "attention_mask": torch.ones((1, 4), dtype=torch.int32), + "position_ids": torch.tensor([[4, 0, 1, 2]], dtype=torch.int32), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "adapter_ids": torch.tensor([0], dtype=torch.int32), + "slot_mapping": torch.tensor([[4, 5, 6, 7]], dtype=torch.int32), + "block_table": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "full_context_lens": torch.tensor([[5], [3]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[4], [0]], dtype=torch.int32), + "num_queries": torch.tensor([[1], [3]], dtype=torch.int32), + } + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual( + [call["request_id"] for call in bridge.prepare_calls], + ["req-new"], + ) + self.assertEqual(bridge.suffix_prepare_calls, []) + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor([[99, 0, 0], [20, 21, 22]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0, 0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0, 1], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["full_context_lens"], + torch.tensor([[5], [3]], dtype=torch.int32), + ) + ) + + def test_vectorized_cached_decode_row_pads_to_cte_bucket(self): + bridge = _FakeHybridBridge() + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + pad_token_id=0, + ), + neuron_config=SimpleNamespace( + context_encoding_buckets=[2, 4], + pa_block_size=2, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict.update( + { + "hybrid_request_id": ("req-cached", "req-new"), + "hybrid_cached_request_ids": ("req-cached",), + "hybrid_prefill_completion_state": torch.tensor( + [True, False], + dtype=torch.bool, + ), + "input_ids": torch.tensor([[99, 20, 21, 22]], dtype=torch.int32), + "attention_mask": torch.ones((1, 4), dtype=torch.int32), + "position_ids": torch.tensor([[4, 0, 1, 2]], dtype=torch.int32), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "adapter_ids": torch.tensor([0], dtype=torch.int32), + "slot_mapping": torch.tensor([[8, 10, 11, 12]], dtype=torch.int32), + "block_table": torch.tensor( + [[1, 2, 3], [4, 5, 6]], + dtype=torch.int32, + ), + "full_context_lens": torch.tensor([[5], [3]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[4], [0]], dtype=torch.int32), + "num_queries": torch.tensor([[1], [3]], dtype=torch.int32), + } + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertTrue( + torch.equal( + prepared["input_ids"], + torch.tensor( + [[99, 0, 0, 0], [20, 21, 22, 0]], + dtype=torch.int32, + ), + ) + ) + self.assertTrue( + torch.equal( + prepared["attention_mask"], + torch.tensor( + [[1, 1, 1, 1, 1, 0], [1, 1, 1, 0, 0, 0]], + dtype=torch.int32, + ), + ) + ) + self.assertTrue( + torch.equal( + prepared["slot_mapping"], + torch.tensor([[8, -1, -1, -1], [10, 11, 12, -1]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["block_table"], + torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32), + ) + ) + + def test_vectorized_combiner_repairs_short_active_slot_mapping(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + pad_token_id=0, + ), + neuron_config=SimpleNamespace( + context_encoding_buckets=[4], + pa_block_size=2, + ), + ) + row_cached_decode = { + "input_ids": torch.tensor([[99]], dtype=torch.int32), + "attention_mask": torch.ones((1, 1), dtype=torch.int32), + "position_ids": torch.tensor([[4]], dtype=torch.int32), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "slot_mapping": torch.tensor([8], dtype=torch.int32), + "block_table": torch.tensor([[2, 3]], dtype=torch.int32), + "full_context_lens": torch.tensor([[5]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[4]], dtype=torch.int32), + "num_queries": torch.tensor([[1]], dtype=torch.int32), + } + row_prefill = { + "input_ids": torch.tensor([[20, 21, 22]], dtype=torch.int32), + "attention_mask": torch.ones((1, 3), dtype=torch.int32), + "position_ids": torch.tensor([[0, 1, 2]], dtype=torch.int32), + "slot_mapping": torch.tensor([10], dtype=torch.int32), + "block_table": torch.tensor([[4, 5]], dtype=torch.int32), + "full_context_lens": torch.tensor([[3]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[0]], dtype=torch.int32), + "num_queries": torch.tensor([[3]], dtype=torch.int32), + } + + combined = _combine_vectorized_hybrid_apc_inputs( + base, + dict(row_cached_decode), + [row_cached_decode, row_prefill], + ) + + self.assertTrue( + torch.equal( + combined["slot_mapping"], + torch.tensor( + [[8, -1, -1, -1], [8, 9, 10, -1]], + dtype=torch.int32, + ), + ) + ) + self.assertTrue( + torch.equal( + combined["seq_ids"], + torch.tensor([0, 1], dtype=torch.int32), + ) + ) + + def test_vectorized_combiner_preserves_restore_prefix_block_table(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + pad_token_id=0, + ), + neuron_config=SimpleNamespace( + context_encoding_buckets=[4], + pa_block_size=2, + ), + ) + row_a = { + "input_ids": torch.tensor([[30, 31]], dtype=torch.int32), + "attention_mask": torch.ones((1, 2), dtype=torch.int32), + "position_ids": torch.tensor([[4, 5]], dtype=torch.int32), + "slot_mapping": torch.tensor([[20, 21]], dtype=torch.int32), + "block_table": torch.tensor([[7, 8]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[4]], dtype=torch.int32), + "num_queries": torch.tensor([[2]], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1], dtype=torch.int32), + "hybrid_restore_prefix_lens": torch.tensor([4], dtype=torch.int32), + "rotary_position_ids": torch.tensor( + [[[4, 5]], [[4, 5]], [[4, 5]]], + dtype=torch.int32, + ), + } + row_b = { + **row_a, + "input_ids": torch.tensor([[40, 41]], dtype=torch.int32), + "block_table": torch.tensor([[9, 10]], dtype=torch.int32), + "rotary_position_ids": torch.tensor( + [[[6, 7]], [[6, 7]], [[6, 7]]], + dtype=torch.int32, + ), + } + + combined = _combine_vectorized_hybrid_apc_inputs( + base, + dict(row_a), + [row_a, row_b], + ) + + self.assertTrue( + torch.equal( + combined["block_table"], + torch.tensor([[7, 8], [9, 10]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + combined["rotary_position_ids"], + torch.tensor( + [ + [[4, 5], [6, 7]], + [[4, 5], [6, 7]], + [[4, 5], [6, 7]], + ], + dtype=torch.int32, + ), + ) + ) + + def test_vectorized_combiner_repairs_active_window_slot_mapping(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + pad_token_id=0, + ), + neuron_config=SimpleNamespace( + context_encoding_buckets=[4], + pa_block_size=2, + ), + ) + row_suffix = { + "input_ids": torch.tensor([[30, 31]], dtype=torch.int32), + "attention_mask": torch.ones((1, 2), dtype=torch.int32), + "position_ids": torch.tensor([[2, 3]], dtype=torch.int32), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "slot_mapping": torch.full((1, 2), -1, dtype=torch.int32), + "block_table": torch.tensor([[4]], dtype=torch.int32), + "full_context_lens": torch.tensor([[4]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[2]], dtype=torch.int32), + "num_queries": torch.tensor([[2]], dtype=torch.int32), + } + row_decode = { + "input_ids": torch.tensor([[99]], dtype=torch.int32), + "attention_mask": torch.ones((1, 1), dtype=torch.int32), + "position_ids": torch.tensor([[4]], dtype=torch.int32), + "slot_mapping": torch.full((1, 1), -1, dtype=torch.int32), + "block_table": torch.tensor([[5]], dtype=torch.int32), + "full_context_lens": torch.tensor([[5]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[4]], dtype=torch.int32), + "num_queries": torch.tensor([[1]], dtype=torch.int32), + } + + combined = _combine_vectorized_hybrid_apc_inputs( + base, + dict(row_suffix), + [row_suffix, row_decode], + ) + + self.assertTrue( + torch.equal( + combined["slot_mapping"], + torch.tensor( + [[8, 9, -1, -1], [10, -1, -1, -1]], + dtype=torch.int32, + ), + ) + ) + + def test_cancel_hybrid_apc_request_is_noop_without_prepared_request(self): + input_dict = {} + + cancel_hybrid_apc_request(input_dict) + + self.assertEqual(input_dict, {}) + + def test_strict_hybrid_apc_requires_attached_bridge(self): + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ) + ) + + with self.assertRaisesRegex(ValueError, "requires a scheduler bridge"): + prepare_hybrid_apc_request_for_execution(base, _prefix_input_dict()) + + def test_strict_hybrid_apc_rejects_suffix_without_full_prompt(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + bridge.prepare_suffix_only_request = None + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + + with self.assertRaisesRegex(ValueError, "suffix-only input"): + prepare_hybrid_apc_request_for_execution(base, input_dict) + + def test_strict_hybrid_apc_suffix_chunk_uses_active_prefix_boundary(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["request_prefix_len"] = 6 + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_active_suffix_len"] = torch.tensor([2], dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertEqual(len(bridge.suffix_prepare_calls), 1) + self.assertEqual(bridge.suffix_prepare_calls[0]["request_prefix_len"], 4) + self.assertTrue( + torch.equal( + prepared["full_context_lens"], + torch.tensor([[4]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_prefix_lens"], + torch.tensor([2], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_suffix_chunk_without_checkpoint_uses_inert_controls( + self, + ): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + + def raise_unbacked_suffix_error(**kwargs): + bridge.suffix_prepare_calls.append(kwargs) + raise ValueError( + "suffix-only hybrid APC received an attention prefix hit " + "without scheduler-authorized GDN checkpoint metadata" + ) + + bridge.prepare_suffix_only_request = raise_unbacked_suffix_error + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["full_context_lens"] = torch.tensor([[6]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["request_prefix_len"] = 6 + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_active_suffix_len"] = torch.tensor([2], dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertEqual(len(bridge.suffix_prepare_calls), 1) + self.assertEqual(bridge.suffix_prepare_calls[0]["request_prefix_len"], 4) + self.assertNotIn("_hybrid_apc_prepared", input_dict) + self.assertTrue(torch.equal(prepared["input_ids"], input_dict["input_ids"])) + self.assertTrue( + torch.equal( + prepared["computed_context_lens"], + torch.tensor([[2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["full_context_lens"], + torch.tensor([[4]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_suffix_chunk_uses_scheduled_active_len(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor( + [[12, 13, 14, 15]], + dtype=torch.int32, + ) + input_dict["attention_mask"] = torch.ones((1, 4), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3, 4, 5]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3, 4, 5]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["full_context_lens"] = torch.tensor([[6]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["request_prefix_len"] = 6 + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_active_suffix_len"] = torch.tensor([2], dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertFalse(bridge.suffix_prepare_calls) + self.assertTrue(torch.equal(prepared["input_ids"], input_dict["input_ids"])) + self.assertTrue( + torch.equal( + prepared["attention_mask"], + torch.tensor([[1, 1, 0, 0]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["full_context_lens"], + torch.tensor([[4]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_full_suffix_without_checkpoint_still_raises(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + + def raise_unbacked_suffix_error(**kwargs): + bridge.suffix_prepare_calls.append(kwargs) + raise ValueError( + "suffix-only hybrid APC received an attention prefix hit " + "without scheduler-authorized GDN checkpoint metadata" + ) + + bridge.prepare_suffix_only_request = raise_unbacked_suffix_error + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["request_prefix_len"] = 4 + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_active_suffix_len"] = torch.tensor([2], dtype=torch.int32) + + with self.assertRaisesRegex(ValueError, "scheduler-authorized GDN"): + prepare_hybrid_apc_request_for_execution(base, input_dict) + + def test_strict_hybrid_apc_cached_prefill_suffix_without_checkpoint_is_inert( + self, + ): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + + def raise_unbacked_suffix_error(**kwargs): + bridge.suffix_prepare_calls.append(kwargs) + raise ValueError( + "suffix-only hybrid APC received an attention prefix hit " + "without scheduler-authorized GDN checkpoint metadata" + ) + + bridge.prepare_suffix_only_request = raise_unbacked_suffix_error + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["request_id"] = "req-cached-prefill" + input_dict["request_prefix_len"] = 4 + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_active_suffix_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_cached_request_ids"] = ("req-cached-prefill",) + input_dict["hybrid_prefill_completion_state"] = torch.tensor( + [False], + dtype=torch.bool, + ) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertEqual(len(bridge.suffix_prepare_calls), 1) + self.assertTrue( + torch.equal( + prepared["full_context_lens"], + torch.tensor([[4]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_seq_id_prefill_suffix_without_checkpoint_is_inert( + self, + ): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + + def raise_unbacked_suffix_error(**kwargs): + bridge.suffix_prepare_calls.append(kwargs) + raise ValueError( + "suffix-only hybrid APC received an attention prefix hit " + "without scheduler-authorized GDN checkpoint metadata" + ) + + bridge.prepare_suffix_only_request = raise_unbacked_suffix_error + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["full_context_lens"] = torch.tensor([[4]], dtype=torch.int32) + input_dict["seq_ids"] = torch.tensor([0], dtype=torch.int32) + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertFalse(bridge.suffix_prepare_calls) + self.assertTrue( + torch.equal( + prepared["full_context_lens"], + torch.tensor([[4]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_suffix_chunk_other_bridge_error_still_raises(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + + def raise_other_suffix_error(**kwargs): + bridge.suffix_prepare_calls.append(kwargs) + raise ValueError("unrelated suffix bridge error") + + bridge.prepare_suffix_only_request = raise_other_suffix_error + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[2, 3]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[2]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["request_prefix_len"] = 6 + input_dict["vllm_attention_hit_len"] = torch.tensor([2], dtype=torch.int32) + input_dict["hybrid_active_suffix_len"] = torch.tensor([2], dtype=torch.int32) + + with self.assertRaisesRegex(ValueError, "unrelated suffix bridge error"): + prepare_hybrid_apc_request_for_execution(base, input_dict) + + def test_strict_hybrid_apc_allows_zero_hit_partial_chunk(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[0, 1]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[0, 1]], dtype=torch.int32) + input_dict["full_context_lens"] = torch.tensor([[4]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[0]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["vllm_attention_hit_len"] = torch.tensor([0], dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertFalse(bridge.suffix_prepare_calls) + self.assertTrue(torch.equal(prepared["input_ids"], input_dict["input_ids"])) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_commits_zero_hit_chunk_boundary_with_hash(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["input_ids"] = torch.tensor([[12, 13]], dtype=torch.int32) + input_dict["attention_mask"] = torch.ones((1, 2), dtype=torch.int32) + input_dict["position_ids"] = torch.tensor([[0, 1]], dtype=torch.int32) + input_dict["slot_mapping"] = torch.tensor([[0, 1]], dtype=torch.int32) + input_dict["full_context_lens"] = torch.tensor([[4]], dtype=torch.int32) + input_dict["computed_context_lens"] = torch.tensor([[0]], dtype=torch.int32) + input_dict["request_id"] = "req-strict" + input_dict["vllm_attention_hit_len"] = torch.tensor([0], dtype=torch.int32) + input_dict["cumulative_hashes_by_prefix_len"] = {2: b"hash-2"} + input_dict["attention_block_refs_by_prefix_len"] = {2: [9]} + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertEqual(len(bridge.prepare_calls), 1) + self.assertFalse(bridge.suffix_prepare_calls) + self.assertEqual(bridge.prepare_kwargs["request_prefix_len"], 2) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([1], dtype=torch.int32), + ) + ) + + def test_strict_hybrid_apc_allows_zero_hit_without_hash_metadata(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["request_id"] = "req-short-cold" + input_dict["request_prefix_len"] = 255 + input_dict["vllm_attention_hit_len"] = torch.tensor([0], dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertTrue(torch.equal(prepared["input_ids"], input_dict["input_ids"])) + self.assertTrue( + torch.equal( + prepared["hybrid_restore_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["hybrid_commit_mask"], + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_zero_hit_strict_hybrid_apc_rebuilds_active_attention_mask(self): + bridge = _FakeHybridBridge() + bridge.requires_external_metadata = True + base = SimpleNamespace( + config=SimpleNamespace( + use_hybrid_apc_manager=True, + hybrid_apc_require_vllm_metadata=True, + ), + hybrid_apc_bridge=bridge, + ) + input_dict = _prefix_input_dict() + input_dict["request_id"] = "req-cold" + input_dict["request_prefix_len"] = 4 + input_dict["vllm_attention_hit_len"] = torch.tensor([0], dtype=torch.int32) + input_dict["attention_mask"] = torch.zeros((1, 4), dtype=torch.int32) + + prepared = prepare_hybrid_apc_request_for_execution(base, input_dict) + + self.assertFalse(bridge.prepare_calls) + self.assertTrue( + torch.equal( + prepared["num_queries"], + torch.tensor([[4]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + prepared["attention_mask"], + torch.ones((1, 4), dtype=torch.int32), + ) + ) + + +def _prefix_input_dict(): + return { + "input_ids": torch.tensor([[10, 11, 12, 13]], dtype=torch.int32), + "attention_mask": torch.ones((1, 4), dtype=torch.int32), + "position_ids": torch.arange(4, dtype=torch.int32).unsqueeze(0), + "seq_ids": torch.tensor([0], dtype=torch.int32), + "sampling_params": torch.zeros((1, 1), dtype=torch.int32), + "adapter_ids": torch.zeros((1,), dtype=torch.int32), + "slot_mapping": torch.arange(4, dtype=torch.int32).unsqueeze(0), + "block_table": torch.tensor([[1, 2]], dtype=torch.int32), + "full_context_lens": torch.tensor([[4]], dtype=torch.int32), + "computed_context_lens": torch.tensor([[0]], dtype=torch.int32), + } + + +class _FakeHybridBridge: + def __init__(self): + self.prepare_kwargs = None + self.prepare_calls = [] + self.suffix_prepare_calls = [] + self.committed = [] + self.finished = [] + self.cancelled = [] + + def prepare_request(self, **kwargs): + self.prepare_kwargs = kwargs + self.prepare_calls.append(kwargs) + input_dict = dict(kwargs["input_dict"]) + restore_len = int(kwargs["attention_hit_len"]) + prompt_len = int(input_dict["input_ids"].shape[1]) + suffix_len = prompt_len - restore_len + restore_slot = 5 if restore_len > 0 else 0 + input_dict.update( + { + "input_ids": input_dict["input_ids"][:, restore_len:prompt_len], + "attention_mask": input_dict["attention_mask"][:, restore_len:prompt_len], + "position_ids": input_dict["position_ids"][:, restore_len:prompt_len], + "slot_mapping": input_dict["slot_mapping"][:, restore_len:prompt_len], + "computed_context_lens": torch.tensor([[restore_len]], dtype=torch.int32), + "full_context_lens": torch.tensor([[prompt_len]], dtype=torch.int32), + "num_queries": torch.tensor([[suffix_len]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([restore_slot], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1 if restore_len > 0 else 0], dtype=torch.int32), + "hybrid_restore_prefix_lens": torch.tensor([restore_len], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([7], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([1], dtype=torch.int32), + } + ) + return SimpleNamespace( + request_id=kwargs["request_id"], + input_dict=input_dict, + plan=SimpleNamespace( + restore_checkpoint_prefix_len=restore_len, + checkpoint_slot=restore_slot, + ), + commit_prefix_len=prompt_len, + commit_slot=7, + attention_block_refs=(11, 12), + ) + + def prepare_suffix_only_request(self, **kwargs): + self.suffix_prepare_calls.append(kwargs) + input_dict = dict(kwargs["input_dict"]) + restore_len = int(kwargs["attention_hit_len"]) + prompt_len = int(kwargs["request_prefix_len"]) + suffix_len = int(input_dict["input_ids"].shape[1]) + input_dict.update( + { + "computed_context_lens": torch.tensor([[restore_len]], dtype=torch.int32), + "full_context_lens": torch.tensor([[prompt_len]], dtype=torch.int32), + "num_queries": torch.tensor([[suffix_len]], dtype=torch.int32), + "hybrid_restore_slot_ids": torch.tensor([5], dtype=torch.int32), + "hybrid_restore_mask": torch.tensor([1], dtype=torch.int32), + "hybrid_restore_prefix_lens": torch.tensor([restore_len], dtype=torch.int32), + "hybrid_commit_slot_ids": torch.tensor([0], dtype=torch.int32), + "hybrid_commit_mask": torch.tensor([0], dtype=torch.int32), + } + ) + return SimpleNamespace( + request_id=kwargs["request_id"], + input_dict=input_dict, + plan=SimpleNamespace( + restore_checkpoint_prefix_len=restore_len, + checkpoint_slot=5, + ), + commit_prefix_len=prompt_len, + commit_slot=None, + attention_block_refs=(11, 12), + ) + + def commit_prefill(self, prepared, *, attention_block_refs=None): + self.committed.append((prepared, tuple(attention_block_refs))) + + def finish_request(self, request_id): + self.finished.append(request_id) + + def cancel_request(self, prepared): + self.cancelled.append(prepared) + + +class _FakePrefixModel: + def __init__(self, should_fail=False, tag="context_encoding_model"): + self.should_fail = should_fail + self.tag = tag + self.calls = [] + + def __call__(self, *args, **kwargs): + if self.should_fail: + raise RuntimeError("model failed") + self.calls.append(args) + return "model-output" + + def is_neuron(self): + return False diff --git a/test/unit/modules/test_autobucketing.py b/test/unit/modules/test_autobucketing.py index 0a7292d8..3fd395b1 100644 --- a/test/unit/modules/test_autobucketing.py +++ b/test/unit/modules/test_autobucketing.py @@ -118,6 +118,28 @@ def test_generate_buckets_for_cte(): assert result == [128, 256, 512, 1024] +def test_generate_buckets_for_cte_uses_sparse_prefix_pairs(): + n_config = NeuronConfig( + enable_bucketing=True, + is_prefix_caching=True, + max_context_length=65536, + max_length=65536, + context_encoding_buckets=[512, 1536], + prefix_buckets=[256, 65536], + context_encoding_bucket_pairs=[ + [512, 0], + [512, 256], + [1536, 0], + [1536, 65536], + ], + ) + config = InferenceConfig(neuron_config=n_config) + + result = autobucketing.generate_buckets_for_cte(config) + + assert result == [[512, 0], [512, 256], [1536, 0], [1536, 65536]] + + def test_generate_buckets_for_tkg(): # Test with bucketing disabled and no prefix caching n_config = NeuronConfig( diff --git a/test/unit/scripts/test_qwen36_openai_compat_server.py b/test/unit/scripts/test_qwen36_openai_compat_server.py new file mode 100644 index 00000000..235dc0a8 --- /dev/null +++ b/test/unit/scripts/test_qwen36_openai_compat_server.py @@ -0,0 +1,46 @@ +import importlib.util +from pathlib import Path + +import pytest + + +_REPO_ROOT = Path(__file__).resolve().parents[3] +_SCRIPT_PATH = ( + _REPO_ROOT + / "contrib" + / "models" + / "Qwen3.6-27B" + / "scripts" + / "openai_compat_server.py" +) +_SPEC = importlib.util.spec_from_file_location( + "qwen36_openai_compat_server", + _SCRIPT_PATH, +) +_SERVER = importlib.util.module_from_spec(_SPEC) +_SPEC.loader.exec_module(_SERVER) + + +def test_stop_string_is_treated_as_one_sequence(): + assert _SERVER._normalize_stop_sequences("END") == ["END"] + + +def test_stop_list_preserves_string_sequences(): + assert _SERVER._normalize_stop_sequences(["END", "DONE"]) == ["END", "DONE"] + + +def test_completion_prompt_preserves_token_id_prompt(): + assert _SERVER._completion_prompt([101, 202, 303]) == [101, 202, 303] + + +def test_completion_prompt_uses_first_token_id_prompt_for_batched_input(): + assert _SERVER._completion_prompt([[101, 202], [303, 404]]) == [101, 202] + + +def test_completion_prompt_uses_first_text_prompt_for_batched_input(): + assert _SERVER._completion_prompt(["first", "second"]) == "first" + + +def test_completion_prompt_rejects_mixed_token_id_prompt(): + with pytest.raises(ValueError, match="token-id prompt lists"): + _SERVER._completion_prompt([101, "bad"]) diff --git a/test/unit/scripts/test_qwen36_validation_gates.py b/test/unit/scripts/test_qwen36_validation_gates.py new file mode 100644 index 00000000..df530a0f --- /dev/null +++ b/test/unit/scripts/test_qwen36_validation_gates.py @@ -0,0 +1,53 @@ +import importlib.util +import urllib.error +from pathlib import Path + + +_REPO_ROOT = Path(__file__).resolve().parents[3] + + +def _load_script(name: str, relative_path: str): + spec = importlib.util.spec_from_file_location(name, _REPO_ROOT / relative_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +_CHAT_APC = _load_script( + "qwen36_openai_chat_apc_validation", + "validation_scripts/qwen36_openai_chat_apc_validation.py", +) +_BOUNDARY_APC = _load_script( + "qwen36_openai_boundary_apc_probe", + "validation_scripts/qwen36_openai_boundary_apc_probe.py", +) + + +def test_chat_apc_gate_fails_without_exact_repeats(): + summary = { + "all_status_ok": True, + "warm_full_exact_text": False, + "partial_repeat_exact_text": True, + "multi_turn_repeat_exact_text": True, + "semantic_smoke_passed": True, + "warm_full_speedup_passed": True, + "partial_reference_speedup_passed": True, + } + + assert _CHAT_APC._apc_gate_failures(summary) == ["warm_full_exact_text"] + + +def test_chat_apc_speedup_gate_requires_threshold(): + assert _CHAT_APC._speedup_passes(2.0, 1.5) + assert not _CHAT_APC._speedup_passes(1.0, 1.5) + assert not _CHAT_APC._speedup_passes(None, 1.5) + assert _CHAT_APC._speedup_passes(None, 0.0) + + +def test_boundary_metric_snapshot_is_optional(monkeypatch): + def raise_url_error(*args, **kwargs): + raise urllib.error.URLError("metrics disabled") + + monkeypatch.setattr(_BOUNDARY_APC.urllib.request, "urlopen", raise_url_error) + + assert _BOUNDARY_APC._metric_snapshot("http://127.0.0.1:8000", 0.1) == {} diff --git a/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_258k_chunk_timing.json b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_258k_chunk_timing.json new file mode 100644 index 00000000..86ab5dd0 --- /dev/null +++ b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_258k_chunk_timing.json @@ -0,0 +1,7 @@ +{ + "chunk_count": 0, + "run_count": 0, + "runs": [], + "runtime_log": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_runtime_20260609T000000Z.log", + "total_elapsed_s": 0 +} diff --git a/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_summary.json b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_summary.json new file mode 100644 index 00000000..c1f60890 --- /dev/null +++ b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_summary.json @@ -0,0 +1,66 @@ +{ + "artifact": "/mnt/trainium_artifacts/qwen_artifacts/qwen36_256k_fp8_loadfix_lmheadbf16_gatesbf16_kvbf16_qkvnki_segmented_cte512_gdnseg512_cte2048_pfx256k_pa1025_slots64_20260608T195113Z_256k_loadfix_segcte2048_chatfix_hostsampling_kkt_hier_scan7", + "baseline_chunk_timing": null, + "baseline_prefill_tok_s": 622.9956983504886, + "boundary": { + "case_count": 1, + "case_failures": [], + "lengths": [ + 16374 + ], + "ok": true, + "path": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl", + "summary_ok": true, + "summary_present": false + }, + "chunk_timing": { + "chunk_count": 0, + "path": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/baseline_258k_chunk_timing.json", + "run_count": 0, + "runs": [], + "runtime_log": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_runtime_20260609T000000Z.log", + "total_elapsed_s": 0 + }, + "chunk_timing_comparison": null, + "coherent": true, + "log_scan": { + "empty": true, + "line_count": 0, + "path": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/log_scan_empty.txt", + "sample": [] + }, + "long_context": [ + { + "case_count": 1, + "case_failures": [], + "lengths": [ + 242864 + ], + "ok": true, + "path": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_probe_20260609T000000Z.jsonl", + "summary_ok": true, + "summary_present": false + } + ], + "material_speedup_threshold": 1.2, + "materially_faster": true, + "near_max": null, + "raw_16k": { + "completion_tokens": 16, + "elapsed_seconds": 6.83794949100411, + "path": "validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl", + "prefill_tok_s": 2394.577500395602, + "prompt_tokens": 16374, + "text": "Here's a thinking process:\n\n1. **Analyze User Input:**", + "time_basis": "ttft_seconds" + }, + "recommended_next_action": { + "action": "promote_after_repeat_validation", + "reason": "Candidate recovered the target prefill speed while staying coherent." + }, + "run_id": "baseline_nativechunk_20260609T000000Z", + "speed_class": "target_recovered", + "speedup_vs_baseline": 3.843650135523803, + "target_recovered": true, + "target_tok_s": 1200.0 +} diff --git a/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/log_scan_empty.txt b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/log_scan_empty.txt new file mode 100644 index 00000000..e69de29b diff --git a/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_probe_20260609T000000Z.jsonl b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_probe_20260609T000000Z.jsonl new file mode 100644 index 00000000..8a047774 --- /dev/null +++ b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_258k_probe_20260609T000000Z.jsonl @@ -0,0 +1 @@ +{"bad_empty": false, "bad_repeated_bang": false, "content_chunk_count": 16, "content_preview": "Here's a thinking process:\n\n1. **Analyze User Input:**", "effective_target_tokens": 253904, "elapsed_seconds": 238.5765401269964, "error": null, "pass": true, "prompt_tokens": 253899, "prompt_tokens_estimated": true, "status": 200, "target_tokens": 258000, "thinking": true, "ttft_seconds": 235.98190240700205, "usage": {"completion_tokens": 16, "prompt_tokens": 242864, "total_tokens": 242880}} diff --git a/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl new file mode 100644 index 00000000..5a55baaa --- /dev/null +++ b/validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl @@ -0,0 +1 @@ +{"bad_empty": false, "bad_repeated_bang": false, "content_chunk_count": 16, "content_preview": "Here's a thinking process:\n\n1. **Analyze User Input:**", "elapsed_seconds": 8.87324038400402, "error": null, "pass": true, "prompt_tokens": 16374, "status": 200, "target_tokens": 16384, "thinking": true, "ttft_seconds": 6.83794949100411, "usage": {"completion_tokens": 16, "prompt_tokens": 16374, "total_tokens": 16390}} diff --git a/validation_scripts/neuron_memory_sampler.py b/validation_scripts/neuron_memory_sampler.py new file mode 100644 index 00000000..7a1cca61 --- /dev/null +++ b/validation_scripts/neuron_memory_sampler.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""Sample host RSS and Neuron device memory while a benchmark runs.""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import signal +import subprocess +import time +from pathlib import Path +from typing import Any + + +def _read_int(path: Path) -> int | None: + try: + text = path.read_text(encoding="utf-8").strip() + except OSError: + return None + if not text: + return None + try: + return int(text) + except ValueError: + return None + + +def _sample_neuron_sysfs() -> dict[str, Any]: + root = Path("/sys/devices/virtual/neuron_device") + totals: dict[str, int] = {} + cores: dict[str, dict[str, int]] = {} + if not root.exists(): + return {"available": False, "totals_bytes": totals, "cores": cores} + for path in root.glob("neuron*/neuron_core*/stats/memory_usage/device_mem/**/*"): + if not path.is_file(): + continue + value = _read_int(path) + if value is None: + continue + parts = path.parts + try: + neuron = next(part for part in parts if part.startswith("neuron")) + core = next(part for part in parts if part.startswith("neuron_core")) + except StopIteration: + continue + category = path.parent.name if path.name == "bytes" else path.name + key = f"{neuron}/{core}" + cores.setdefault(key, {})[category] = value + totals[category] = totals.get(category, 0) + value + return {"available": True, "totals_bytes": totals, "cores": cores} + + +def _sample_processes(match: re.Pattern[str]) -> dict[str, Any]: + result = subprocess.run( + ["ps", "-eo", "pid,ppid,rss,comm,args"], + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + ) + processes = [] + total_rss_kb = 0 + own_pid = os.getpid() + for line in result.stdout.splitlines()[1:]: + fields = line.strip().split(None, 4) + if len(fields) < 5: + continue + pid, ppid, rss, command, args = fields + if int(pid) == own_pid: + continue + if not match.search(args) and not match.search(command): + continue + rss_kb = int(rss) + total_rss_kb += rss_kb + processes.append( + { + "pid": int(pid), + "ppid": int(ppid), + "rss_kb": rss_kb, + "command": command, + "args": args, + } + ) + return {"processes": processes, "total_rss_kb": total_rss_kb} + + +def _write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--output-jsonl", type=Path, required=True) + parser.add_argument("--summary-json", type=Path, required=True) + parser.add_argument( + "--match", + default="VLLM::EngineCore|qwen36_.*bench|qwen36_.*sweep", + help="Regex matched against process command names and args.", + ) + parser.add_argument("--interval-seconds", type=float, default=2.0) + parser.add_argument("--duration-seconds", type=float, default=0.0) + parser.add_argument( + "--stop-when-no-match", + action="store_true", + help="Exit once at least one matching process has been seen and later none remain.", + ) + args = parser.parse_args() + + args.output_jsonl.parent.mkdir(parents=True, exist_ok=True) + match = re.compile(args.match) + start = time.time() + stop_requested = False + saw_process = False + samples = 0 + peak_host_rss_kb = 0 + peak_neuron_total_bytes = 0 + peak_neuron_by_category: dict[str, int] = {} + + def _request_stop(_signum: int, _frame: object) -> None: + nonlocal stop_requested + stop_requested = True + + signal.signal(signal.SIGTERM, _request_stop) + signal.signal(signal.SIGINT, _request_stop) + + with args.output_jsonl.open("a", encoding="utf-8") as handle: + while True: + process_sample = _sample_processes(match) + neuron_sample = _sample_neuron_sysfs() + now = time.time() + row = { + "timestamp_unix": now, + "elapsed_seconds": now - start, + "host": process_sample, + "neuron": neuron_sample, + } + handle.write(json.dumps(row, sort_keys=True) + "\n") + handle.flush() + samples += 1 + + if process_sample["processes"]: + saw_process = True + peak_host_rss_kb = max(peak_host_rss_kb, int(process_sample["total_rss_kb"])) + neuron_totals = neuron_sample.get("totals_bytes", {}) + current_neuron_total = sum(int(value) for value in neuron_totals.values()) + peak_neuron_total_bytes = max(peak_neuron_total_bytes, current_neuron_total) + for category, value in neuron_totals.items(): + peak_neuron_by_category[category] = max( + peak_neuron_by_category.get(category, 0), + int(value), + ) + + if stop_requested: + break + if args.duration_seconds and now - start >= args.duration_seconds: + break + if ( + args.stop_when_no_match + and saw_process + and not process_sample["processes"] + ): + break + time.sleep(args.interval_seconds) + + _write_json( + args.summary_json, + { + "samples": samples, + "duration_seconds": time.time() - start, + "peak_host_rss_kb": peak_host_rss_kb, + "peak_host_rss_gib": peak_host_rss_kb / 1024 / 1024, + "peak_neuron_total_bytes": peak_neuron_total_bytes, + "peak_neuron_total_gib": peak_neuron_total_bytes / 1024 / 1024 / 1024, + "peak_neuron_by_category_bytes": peak_neuron_by_category, + "peak_neuron_by_category_gib": { + key: value / 1024 / 1024 / 1024 + for key, value in peak_neuron_by_category.items() + }, + }, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_artifact_config_audit.py b/validation_scripts/qwen36_artifact_config_audit.py new file mode 100644 index 00000000..81b70781 --- /dev/null +++ b/validation_scripts/qwen36_artifact_config_audit.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +"""Audit Qwen3.6 Neuron artifact config for APC/prefill A/B experiments.""" + +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any + + +def _load_config(path: Path) -> dict[str, Any]: + config_path = path + if path.is_dir(): + config_path = path / "neuron_config.json" + with config_path.open() as handle: + return json.load(handle) + + +def _first_config_value(config: dict[str, Any], *keys: str, default: Any = None) -> Any: + for key in keys: + if key in config: + return config[key] + override = config.get("override_neuron_config") + if isinstance(override, dict): + for key in keys: + if key in override: + return override[key] + nested = config.get("neuron_config") + if isinstance(nested, dict): + for key in keys: + if key in nested: + return nested[key] + nested_override = nested.get("override_neuron_config") + if isinstance(nested_override, dict): + for key in keys: + if key in nested_override: + return nested_override[key] + return default + + +def _bool_config(config: dict[str, Any], *keys: str) -> bool: + return bool(_first_config_value(config, *keys, default=False)) + + +def _compile_backend_from_log(path: Path | None) -> str | None: + if path is None or not path.exists(): + return None + backend = None + for line in path.read_text(errors="replace").splitlines(): + if line.startswith("DELTANET_CTE_BACKEND "): + backend = line.split(maxsplit=1)[1].strip() + elif " --deltanet-cte-backend " in line: + backend = line.rsplit(" --deltanet-cte-backend ", 1)[1].split()[0] + return backend + + +def _warning( + warnings: list[dict[str, Any]], + *, + code: str, + message: str, + value: Any = None, +): + warnings.append({"code": code, "message": message, "value": value}) + + +def audit( + *, + artifact: Path, + compile_log: Path | None, + recommended_block_size: int, + min_usable_headroom_blocks: int, + strict_hybrid_gate: bool, +) -> dict[str, Any]: + config = _load_config(artifact) + seq_len = int(_first_config_value(config, "seq_len", "max_length", default=0) or 0) + max_num_seqs = int(_first_config_value(config, "batch_size", default=1) or 1) + ctx_batch_size = int(_first_config_value(config, "ctx_batch_size", default=1) or 1) + block_size = int(_first_config_value(config, "pa_block_size", default=0) or 0) + pa_num_blocks = int(_first_config_value(config, "pa_num_blocks", default=0) or 0) + max_gdn_slots = int( + _first_config_value(config, "max_gdn_checkpoint_slots", default=0) or 0 + ) + cte_buckets = _first_config_value(config, "context_encoding_buckets", default=[]) + token_generation_buckets = _first_config_value( + config, + "token_generation_buckets", + default=[], + ) + prefix_buckets = _first_config_value(config, "prefix_buckets", default=[]) + tkg_batch_size = int(_first_config_value(config, "tkg_batch_size", default=1) or 1) + async_mode = _bool_config(config, "async_mode") + output_logits = _bool_config(config, "output_logits") + on_device_sampling_config = _first_config_value( + config, + "on_device_sampling_config", + default=None, + ) + min_blocks = ( + max(1, math.ceil(seq_len / block_size) * max_num_seqs) + if seq_len > 0 and block_size > 0 + else 0 + ) + usable_headroom_blocks = pa_num_blocks - min_blocks if pa_num_blocks else None + usable_headroom_blocks = ( + max(0, usable_headroom_blocks) + if usable_headroom_blocks is not None + else None + ) + required_full_prompt_boundaries = ( + math.ceil(seq_len / block_size) if seq_len > 0 and block_size > 0 else 0 + ) + compile_backend = _compile_backend_from_log(compile_log) + if compile_backend is None and "nki_chunked" in str(artifact): + compile_backend = "nki_chunked_from_artifact_name" + + warnings: list[dict[str, Any]] = [] + if block_size and recommended_block_size and block_size != recommended_block_size: + _warning( + warnings, + code="non_recommended_block_size", + message=( + "Artifact PA block size differs from the configured Neuron " + "performance recommendation." + ), + value={"pa_block_size": block_size, "recommended": recommended_block_size}, + ) + if ( + usable_headroom_blocks is not None + and usable_headroom_blocks < min_usable_headroom_blocks + ): + _warning( + warnings, + code="low_pa_headroom", + message=( + "PA block capacity has little usable residency headroom after " + "minimum sequence capacity." + ), + value={ + "pa_num_blocks": pa_num_blocks, + "min_blocks": min_blocks, + "usable_headroom_blocks": usable_headroom_blocks, + "minimum_expected": min_usable_headroom_blocks, + }, + ) + if strict_hybrid_gate and max_gdn_slots and required_full_prompt_boundaries > max_gdn_slots: + _warning( + warnings, + code="strict_gate_boundary_slots_exceed_gdn_slots", + message=( + "With the current disable-unbacked-prefix-reads gate, a full " + "prompt can require more backed prefix boundaries than the GDN " + "checkpoint slot budget can hold unless boundary chunk commits " + "or a less conservative gate are used." + ), + value={ + "required_full_prompt_boundaries": required_full_prompt_boundaries, + "max_gdn_checkpoint_slots": max_gdn_slots, + }, + ) + if compile_backend and "nki_chunked" in compile_backend: + _warning( + warnings, + code="nki_chunked_deltanet_cte", + message=( + "Compile log or artifact name indicates the nki_chunked DeltaNet " + "CTE backend; compare against a fused-control artifact." + ), + value=compile_backend, + ) + if ( + seq_len >= 32768 + and isinstance(token_generation_buckets, list) + and token_generation_buckets == [seq_len] + ): + _warning( + warnings, + code="single_full_length_tkg_bucket", + message=( + "Decode has only a full-length token-generation bucket. Short " + "generations will still use the largest TKG trace shape; compare " + "against an artifact compiled with smaller TKG buckets such as " + "8192,32768,seq_len." + ), + value=token_generation_buckets, + ) + if not async_mode: + _warning( + warnings, + code="sync_neuron_runtime_decode", + message=( + "Neuron async_mode is disabled. The previous fast decode control " + "path used async runtime execution for token generation." + ), + value=False, + ) + if tkg_batch_size <= 1: + _warning( + warnings, + code="single_sequence_tkg_batch", + message=( + "tkg_batch_size is 1, so decode cannot amortize per-token runner " + "overhead across concurrent sequences." + ), + value=tkg_batch_size, + ) + + summary = { + "artifact": str(artifact), + "compile_log": str(compile_log) if compile_log is not None else None, + "seq_len": seq_len, + "max_num_seqs": max_num_seqs, + "ctx_batch_size": ctx_batch_size, + "pa_block_size": block_size, + "pa_num_blocks": pa_num_blocks, + "pa_min_blocks": min_blocks, + "pa_usable_headroom_blocks": usable_headroom_blocks, + "max_gdn_checkpoint_slots": max_gdn_slots, + "required_full_prompt_boundaries": required_full_prompt_boundaries, + "context_encoding_buckets": cte_buckets, + "token_generation_buckets": token_generation_buckets, + "tkg_batch_size": tkg_batch_size, + "async_mode": async_mode, + "output_logits": output_logits, + "on_device_sampling": on_device_sampling_config is not None, + "prefix_buckets": prefix_buckets, + "is_prefix_caching": _bool_config(config, "is_prefix_caching"), + "use_hybrid_apc_manager": _bool_config(config, "use_hybrid_apc_manager"), + "use_qwen_hybrid_chunked_prefill": _bool_config( + config, + "use_qwen_hybrid_chunked_prefill", + ), + "deltanet_cte_backend": compile_backend, + "warnings": warnings, + } + summary["warning_count"] = len(warnings) + return summary + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("artifact", help="Artifact directory or neuron_config.json path") + parser.add_argument("--compile-log", type=Path, default=None) + parser.add_argument("--recommended-block-size", type=int, default=32) + parser.add_argument("--min-usable-headroom-blocks", type=int, default=8) + parser.add_argument( + "--no-strict-hybrid-gate", + action="store_true", + help="Do not warn when full-prompt boundary count exceeds GDN slot count.", + ) + parser.add_argument("--strict", action="store_true") + args = parser.parse_args() + + summary = audit( + artifact=Path(args.artifact).expanduser().resolve(), + compile_log=args.compile_log.expanduser().resolve() + if args.compile_log is not None + else None, + recommended_block_size=args.recommended_block_size, + min_usable_headroom_blocks=args.min_usable_headroom_blocks, + strict_hybrid_gate=not args.no_strict_hybrid_gate, + ) + print(json.dumps(summary, indent=2, sort_keys=True)) + return 1 if args.strict and summary["warnings"] else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_bf16_length_sweep.py b/validation_scripts/qwen36_bf16_length_sweep.py new file mode 100644 index 00000000..2540ea77 --- /dev/null +++ b/validation_scripts/qwen36_bf16_length_sweep.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +"""Sweep no-prefix Qwen3.6 BF16 prompt lengths and print raw debug markers.""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path +from types import SimpleNamespace + + +def _parse_int_list(values: list[str]) -> list[int]: + out: list[int] = [] + for value in values: + out.extend(int(token) for token in value.replace(",", " ").split()) + return out + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-artifacts", required=True) + parser.add_argument("--repo-root", default=Path(__file__).resolve().parents[1]) + parser.add_argument("--repeats", nargs="+", default=["0,1,2,4,8,16,24,32,40,48,56,64"]) + parser.add_argument("--line", default="System: answer deterministically.\n") + parser.add_argument("--suffix", default="\nUser: What is 17 * 23?\nAssistant:") + parser.add_argument("--max-tokens", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=2048) + parser.add_argument("--cte-buckets", nargs="+", default=["512"]) + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--ctx-batch-size", type=int, default=1) + parser.add_argument("--skip-fp8-env", action="store_true") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + repo_root = Path(args.repo_root).expanduser().resolve() + qwen_root = repo_root / "contrib" / "models" / "Qwen3.6-27B" + sys.path.insert(0, str(qwen_root / "vllm")) + sys.path.insert(0, str(qwen_root)) + + os.environ.setdefault("VLLM_NEURON_FRAMEWORK", "neuronx-distributed-inference") + os.environ.setdefault("VLLM_PLUGINS", "neuron") + os.environ["NEURON_COMPILED_ARTIFACTS"] = str( + Path(args.compiled_artifacts).expanduser().resolve() + ) + if not args.skip_fp8_env: + os.environ.setdefault("XLA_HANDLE_SPECIAL_SCALAR", "1") + os.environ.setdefault("UNSAFE_FP8FNCAST", "1") + + from transformers import AutoTokenizer # noqa: WPS433 + from hf_qwen35_config import register_qwen35_config # noqa: WPS433 + import run_offline_inference as runner # noqa: WPS433 + from vllm import LLM, SamplingParams # noqa: WPS433 + + register_qwen35_config() + cte_buckets = [str(bucket) for bucket in _parse_int_list(args.cte_buckets)] + runner_args = SimpleNamespace( + enable_hybrid_apc=False, + hybrid_cache_mode="all", + gdn_checkpoint_interval=256, + max_gdn_checkpoint_slots=8, + block_size=256, + enable_prefix_caching=False, + enable_vllm_chunked_prefill=False, + cte_bucket_profile="single", + cte_buckets=cte_buckets, + cte_bucket=int(cte_buckets[-1]), + seq_len=args.seq_len, + tensor_parallel_size=args.tensor_parallel_size, + max_num_seqs=1, + ctx_batch_size=args.ctx_batch_size, + logical_nc_config=args.logical_nc_config, + hybrid_gdn_recurrent_cache_dtype=None, + gdn_recurrent_cache_dtype="float32", + hybrid_gdn_conv_cache_dtype=None, + gdn_conv_cache_dtype="bfloat16", + kernel_q_tile_size=128, + kernel_kv_tile_size=1024, + text_only_cte=True, + compact_cte_attention_mask=True, + cold_zero_conv_fast_path=False, + hybrid_cache_prefix_boundary_only=True, + hybrid_cache_validate_exact=False, + hybrid_apc_require_vllm_metadata=False, + num_gpu_blocks_override=None, + ) + llm = LLM( + model=str(Path(args.model_path).expanduser().resolve()), + trust_remote_code=True, + dtype="bfloat16", + tensor_parallel_size=args.tensor_parallel_size, + max_num_seqs=1, + max_model_len=args.seq_len, + enable_prefix_caching=False, + enable_chunked_prefill=False, + additional_config=runner._override_config(runner_args), + ) + sampling = SamplingParams(temperature=0.0, top_k=1, max_tokens=args.max_tokens) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + + for repeats in _parse_int_list(args.repeats): + prompt = args.line * repeats + args.suffix + token_count = len(tokenizer(prompt).input_ids) + print(f"SWEEP_CASE repeats={repeats} tokens={token_count}", flush=True) + outputs = llm.generate([prompt], sampling) + token_ids = list(outputs[0].outputs[0].token_ids) + print( + f"SWEEP_RESULT repeats={repeats} tokens={token_count} " + f"output_tokens={token_ids}", + flush=True, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_chat_completion_context_bench.py b/validation_scripts/qwen36_chat_completion_context_bench.py new file mode 100644 index 00000000..0ae195ba --- /dev/null +++ b/validation_scripts/qwen36_chat_completion_context_bench.py @@ -0,0 +1,562 @@ +#!/usr/bin/env python3 +"""Benchmark OpenAI-compatible chat completions across context lengths. + +The benchmark builds deterministic multi-turn chat histories, sends +``/v1/chat/completions`` requests with ``max_tokens=1``, and records wall/TTFT +latency. Streaming is used by default when the server supports it. +""" + +from __future__ import annotations + +import argparse +import json +import time +import urllib.error +import urllib.request +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + + +def _parse_lengths(raw: str) -> list[int]: + lengths = [int(item) for item in raw.replace(",", " ").split()] + if not lengths: + raise ValueError("at least one context length is required") + return lengths + + +def _chat_token_count(tokenizer: Any, messages: list[dict[str, str]]) -> int: + try: + token_ids = tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + ) + except TypeError: + token_ids = tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + ) + if isinstance(token_ids, dict): + token_ids = token_ids.get("input_ids", token_ids) + elif hasattr(token_ids, "input_ids"): + token_ids = token_ids.input_ids + + if ( + isinstance(token_ids, list) + and token_ids + and isinstance(token_ids[0], list) + ): + return len(token_ids[0]) + return len(token_ids) + + +def _base_messages(turns: int, *, salt: str = "") -> list[dict[str, str]]: + messages = [ + { + "role": "system", + "content": ( + "You are a deterministic latency benchmark assistant. " + "Reply with one concise token. " + f"Benchmark salt: {salt}." + ), + } + ] + for idx in range(max(1, turns)): + messages.append( + { + "role": "user", + "content": f"Turn {idx}: remember benchmark key {idx}.", + } + ) + messages.append( + { + "role": "assistant", + "content": f"ack {idx}", + } + ) + messages.append({"role": "user", "content": "Return the next benchmark token."}) + return messages + + +def _make_messages( + tokenizer: Any, + target_tokens: int, + turns: int, + *, + salt: str = "", +) -> tuple[list[dict[str, str]], int]: + messages = _base_messages(turns, salt=salt) + filler_phrase = ( + " latency-prefix alpha beta gamma delta epsilon zeta eta theta iota kappa" + ) + + def set_repeats(repeats: int) -> None: + messages[-1]["content"] = ( + "Return the next benchmark token." + + (filler_phrase * max(0, repeats)) + ) + + set_repeats(0) + base_count = _chat_token_count(tokenizer, messages) + if base_count >= target_tokens: + return messages, base_count + + set_repeats(1) + one_repeat_count = _chat_token_count(tokenizer, messages) + filler_delta = max(1, one_repeat_count - base_count) + repeats = max(0, (target_tokens - base_count) // filler_delta) + + set_repeats(repeats) + prompt_tokens = _chat_token_count(tokenizer, messages) + while repeats > 0 and prompt_tokens > target_tokens: + overshoot = prompt_tokens - target_tokens + repeats = max(0, repeats - max(1, (overshoot // filler_delta) + 1)) + set_repeats(repeats) + prompt_tokens = _chat_token_count(tokenizer, messages) + + return messages, prompt_tokens + + +def _post_json(url: str, payload: dict[str, Any], timeout: float) -> tuple[int, dict[str, Any]]: + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + return response.status, json.loads(response.read().decode("utf-8")) + except urllib.error.HTTPError as exc: + try: + payload = json.loads(exc.read().decode("utf-8")) + except Exception: + payload = {"error": {"message": str(exc)}} + return exc.code, payload + + +def _completion_tokens_from_usage(usage: Any) -> int | None: + if not isinstance(usage, dict): + return None + completion_tokens = usage.get("completion_tokens") + if completion_tokens is None: + return None + try: + return int(completion_tokens) + except (TypeError, ValueError): + return None + + +def _completion_tokens_from_text(tokenizer: Any, text: str) -> int: + if not text: + return 0 + try: + return len(tokenizer.encode(text, add_special_tokens=False)) + except TypeError: + encoded = tokenizer(text, add_special_tokens=False) + return len(encoded.get("input_ids", [])) + + +def _token_latency_metrics( + *, + total_seconds: float, + ttft_seconds: float | None, + completion_tokens: int | None, + content_chunk_count: int | None, +) -> dict[str, Any]: + content_chunk_tpot_seconds = ( + (total_seconds - ttft_seconds) / (content_chunk_count - 1) + if ttft_seconds is not None + and content_chunk_count is not None + and content_chunk_count > 1 + else None + ) + completion_tokens_per_second = ( + completion_tokens / total_seconds + if completion_tokens is not None + and completion_tokens > 0 + and total_seconds > 0 + else None + ) + decode_elapsed_seconds = ( + total_seconds - ttft_seconds + if ttft_seconds is not None and total_seconds >= ttft_seconds + else None + ) + token_tpot_seconds = ( + decode_elapsed_seconds / (completion_tokens - 1) + if decode_elapsed_seconds is not None + and completion_tokens is not None + and completion_tokens > 1 + else None + ) + decode_tokens_per_second = ( + (completion_tokens - 1) / decode_elapsed_seconds + if decode_elapsed_seconds is not None + and decode_elapsed_seconds > 0 + and completion_tokens is not None + and completion_tokens > 1 + else None + ) + return { + "tpot_seconds": token_tpot_seconds, + "token_tpot_seconds": token_tpot_seconds, + "content_chunk_tpot_seconds": content_chunk_tpot_seconds, + "decode_elapsed_seconds": decode_elapsed_seconds, + "decode_tokens_per_second": decode_tokens_per_second, + "completion_tokens_per_second": completion_tokens_per_second, + } + + +def _response_text(response: dict[str, Any]) -> str: + choices = response.get("choices") if isinstance(response, dict) else None + if not isinstance(choices, list) or not choices: + return "" + first = choices[0] + if not isinstance(first, dict): + return "" + message = first.get("message") + if isinstance(message, dict): + content = message.get("content") + return "" if content is None else str(content) + text = first.get("text") + return "" if text is None else str(text) + + +def _stream_chat( + url: str, + payload: dict[str, Any], + timeout: float, +) -> tuple[ + int, + float | None, + float, + list[str], + int, + str, + dict[str, Any] | None, + dict[str, Any] | None, +]: + payload = dict(payload) + payload["stream"] = True + payload["stream_options"] = {"include_usage": True} + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + start = time.perf_counter() + chunks: list[str] = [] + content_parts: list[str] = [] + usage_payload = None + first_content_seconds = None + content_chunk_count = 0 + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + status = response.status + for raw_line in response: + line = raw_line.decode("utf-8", errors="replace").strip() + if not line or not line.startswith("data:"): + continue + data = line[len("data:") :].strip() + if data == "[DONE]": + break + try: + payload_chunk = json.loads(data) + usage = payload_chunk.get("usage") + if isinstance(usage, dict): + usage_payload = usage + choices = payload_chunk.get("choices") or [] + delta = (choices[0].get("delta") or {}) if choices else {} + content = delta.get("content") + except Exception: + content = None + if content: + content_parts.append(str(content)) + content_chunk_count += 1 + if first_content_seconds is None: + first_content_seconds = time.perf_counter() - start + chunks.append(data) + total_seconds = time.perf_counter() - start + return ( + status, + first_content_seconds, + total_seconds, + chunks, + content_chunk_count, + "".join(content_parts), + usage_payload, + None, + ) + except urllib.error.HTTPError as exc: + total_seconds = time.perf_counter() - start + try: + error_payload = json.loads(exc.read().decode("utf-8")) + except Exception: + error_payload = {"error": {"message": str(exc)}} + return ( + exc.code, + first_content_seconds, + total_seconds, + chunks, + content_chunk_count, + "".join(content_parts), + usage_payload, + error_payload, + ) + + +def _run_one( + *, + url: str, + model: str, + messages: list[dict[str, str]], + max_tokens: int, + timeout: float, + stream: bool, + ignore_eos: bool, + tokenizer: Any, +) -> dict[str, Any]: + payload = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0, + "chat_template_kwargs": {"enable_thinking": False}, + } + if ignore_eos: + payload["ignore_eos"] = True + if stream: + ( + status, + first_chunk_seconds, + total_seconds, + chunks, + content_chunk_count, + content_text, + usage_payload, + error_payload, + ) = _stream_chat( + url, + payload, + timeout, + ) + completion_tokens = _completion_tokens_from_usage(usage_payload) + completion_token_source = "usage" + if completion_tokens is None: + completion_tokens = _completion_tokens_from_text(tokenizer, content_text) + completion_token_source = "tokenizer" + latency_metrics = _token_latency_metrics( + total_seconds=total_seconds, + ttft_seconds=first_chunk_seconds, + completion_tokens=completion_tokens, + content_chunk_count=content_chunk_count, + ) + if status < 400: + return { + "status": status, + "stream": True, + "ttft_seconds": first_chunk_seconds, + "total_seconds": total_seconds, + "chunk_count": len(chunks), + "content_chunk_count": content_chunk_count, + "completion_tokens": completion_tokens, + "completion_token_source": completion_token_source, + "content_text": content_text, + "usage": usage_payload, + **latency_metrics, + "error": None, + } + return { + "status": status, + "stream": True, + "ttft_seconds": first_chunk_seconds, + "total_seconds": total_seconds, + "chunk_count": len(chunks), + "content_chunk_count": content_chunk_count, + "completion_tokens": completion_tokens, + "completion_token_source": completion_token_source, + "content_text": content_text, + "usage": usage_payload, + **latency_metrics, + "error": error_payload, + } + + start = time.perf_counter() + status, response = _post_json(url, payload, timeout) + total_seconds = time.perf_counter() - start + usage_payload = response.get("usage") if isinstance(response, dict) else None + content_text = _response_text(response) if isinstance(response, dict) else "" + completion_tokens = _completion_tokens_from_usage(usage_payload) + completion_token_source = "usage" + if completion_tokens is None: + completion_tokens = _completion_tokens_from_text(tokenizer, content_text) + completion_token_source = "tokenizer" + return { + "status": status, + "stream": False, + "ttft_seconds": None, + "total_seconds": total_seconds, + "chunk_count": None, + "content_text": content_text, + "completion_tokens": completion_tokens, + "completion_token_source": completion_token_source, + "completion_tokens_per_second": ( + completion_tokens / total_seconds + if completion_tokens > 0 and total_seconds > 0 + else None + ), + "error": None if status < 400 else response, + "usage": usage_payload, + } + + +def _row_passed(row: dict[str, Any], *, max_tokens: int) -> bool: + if int(row["status"]) >= 400: + return False + if row.get("stream") and max_tokens > 0: + return int(row.get("content_chunk_count") or 0) > 0 + return True + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--base-url", default="http://127.0.0.1:8000") + parser.add_argument("--model", default="Qwen3.6-27B") + parser.add_argument("--model-path", required=True) + parser.add_argument( + "--lengths", + default="1024,2048,4096,8192,16384,32768", + help="Comma or space separated target chat-template token lengths.", + ) + parser.add_argument("--turns", type=int, default=8) + parser.add_argument("--repeats", type=int, default=2) + parser.add_argument( + "--concurrency", + type=int, + default=1, + help="Number of concurrent requests per length/repeat group.", + ) + parser.add_argument("--max-tokens", type=int, default=1) + parser.add_argument("--timeout", type=float, default=900.0) + parser.add_argument("--no-stream", action="store_true") + parser.add_argument("--ignore-eos", action="store_true") + parser.add_argument( + "--unique-per-request", + action="store_true", + help="Add a unique system-message salt for each length/repeat.", + ) + parser.add_argument("--output-json", required=True) + args = parser.parse_args() + + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + endpoint = args.base_url.rstrip("/") + "/v1/chat/completions" + results = [] + for target_tokens in _parse_lengths(args.lengths): + for repeat_idx in range(args.repeats): + requests = [] + for concurrency_idx in range(args.concurrency): + salt = ( + f"target={target_tokens};repeat={repeat_idx};" + f"concurrency={concurrency_idx};unique=1" + if args.unique_per_request + else "" + ) + messages, prompt_tokens = _make_messages( + tokenizer, + target_tokens=target_tokens, + turns=args.turns, + salt=salt, + ) + requests.append( + { + "messages": messages, + "prompt_tokens": prompt_tokens, + "concurrency_index": concurrency_idx, + } + ) + + group_start = time.perf_counter() + if args.concurrency == 1: + group_results = [ + _run_one( + url=endpoint, + model=args.model, + messages=requests[0]["messages"], + max_tokens=args.max_tokens, + timeout=args.timeout, + stream=not args.no_stream, + ignore_eos=args.ignore_eos, + tokenizer=tokenizer, + ) + ] + else: + with ThreadPoolExecutor(max_workers=args.concurrency) as executor: + futures = [ + executor.submit( + _run_one, + url=endpoint, + model=args.model, + messages=request["messages"], + max_tokens=args.max_tokens, + timeout=args.timeout, + stream=not args.no_stream, + ignore_eos=args.ignore_eos, + tokenizer=tokenizer, + ) + for request in requests + ] + group_results = [future.result() for future in futures] + group_wall_seconds = time.perf_counter() - group_start + group_prompt_tokens = sum(int(request["prompt_tokens"]) for request in requests) + group_effective_tps = ( + group_prompt_tokens / group_wall_seconds + if group_wall_seconds > 0 + and all(int(result["status"]) < 400 for result in group_results) + else None + ) + + for request, result in zip(requests, group_results): + row = { + "target_tokens": target_tokens, + "prompt_tokens": request["prompt_tokens"], + "repeat": repeat_idx, + "concurrency": args.concurrency, + "concurrency_index": request["concurrency_index"], + "group_wall_seconds": group_wall_seconds, + "group_prompt_tokens": group_prompt_tokens, + "group_effective_prompt_tokens_per_second": group_effective_tps, + **result, + } + print(json.dumps(row, sort_keys=True), flush=True) + results.append(row) + + output = { + "base_url": args.base_url, + "model": args.model, + "lengths": _parse_lengths(args.lengths), + "turns": args.turns, + "repeats": args.repeats, + "concurrency": args.concurrency, + "max_tokens": args.max_tokens, + "ignore_eos": args.ignore_eos, + "passed": all(_row_passed(row, max_tokens=args.max_tokens) for row in results), + "results": results, + } + output_path = Path(args.output_json) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump(output, f, indent=2, sort_keys=True) + return 0 if output["passed"] else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_hf_first_mismatch_logits.py b/validation_scripts/qwen36_hf_first_mismatch_logits.py new file mode 100644 index 00000000..7787a248 --- /dev/null +++ b/validation_scripts/qwen36_hf_first_mismatch_logits.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +"""Inspect HF logits at Neuron/HF first mismatch positions for Qwen3.6.""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import Any + + +def _insert_hf_ref(path: Path | None) -> None: + if path is not None: + sys.path.insert(0, str(path.expanduser().resolve())) + + +def _rank_of(logits, token_id: int) -> int: + token_logit = logits[token_id] + return int((logits > token_logit).sum().item()) + 1 + + +def _token_entry(tokenizer: Any, token_id: int, logit: float, rank: int) -> dict[str, Any]: + return { + "token_id": int(token_id), + "rank": int(rank), + "logit": float(logit), + "text": tokenizer.decode( + [int(token_id)], + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ), + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=Path, required=True) + parser.add_argument("--hf-ref-pkgs", type=Path) + parser.add_argument("--goldens-json", type=Path, required=True) + parser.add_argument("--neuron-json", type=Path, required=True) + parser.add_argument("--output-json", type=Path) + parser.add_argument("--top-k", type=int, default=10) + parser.add_argument("--limit", type=int) + parser.add_argument("--dtype", choices=("bfloat16", "float32"), default="bfloat16") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + _insert_hf_ref(args.hf_ref_pkgs) + + import torch # noqa: WPS433 + from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: WPS433 + + with args.goldens_json.expanduser().open(encoding="utf-8") as handle: + goldens = json.load(handle) + golden_by_index = {int(case["index"]): case for case in goldens["cases"]} + + with args.neuron_json.expanduser().open(encoding="utf-8") as handle: + neuron = json.load(handle) + mismatch_cases = [ + case + for case in neuron["cases"] + if case.get("first_mismatch") is not None + ] + if args.limit is not None: + mismatch_cases = mismatch_cases[: args.limit] + + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float32 + load_start = time.perf_counter() + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, + trust_remote_code=True, + ) + model = AutoModelForCausalLM.from_pretrained( + args.model_path, + trust_remote_code=True, + torch_dtype=dtype, + low_cpu_mem_usage=True, + ) + model.eval() + load_elapsed = time.perf_counter() - load_start + + rows = [] + with torch.no_grad(): + for neuron_case in mismatch_cases: + case_index = int(neuron_case["index"]) + golden_case = golden_by_index[case_index] + mismatch = neuron_case["first_mismatch"] + position = int(mismatch["position"]) + expected_token = int(mismatch["expected"]) + actual_token = int(mismatch["actual"]) + + prompt_ids = tokenizer.encode( + golden_case["prompt"], + add_special_tokens=False, + ) + prefix_ids = prompt_ids + [ + int(token) + for token in golden_case["hf_generated_tokens"][:position] + ] + input_ids = torch.tensor([prefix_ids], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + start = time.perf_counter() + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + elapsed = time.perf_counter() - start + logits = outputs.logits[0, -1].float() + top_values, top_indices = torch.topk(logits, k=args.top_k) + top_tokens = [ + _token_entry( + tokenizer, + int(token_id), + float(logit), + rank + 1, + ) + for rank, (logit, token_id) in enumerate(zip(top_values, top_indices)) + ] + + expected_rank = _rank_of(logits, expected_token) + actual_rank = _rank_of(logits, actual_token) + expected_logit = float(logits[expected_token].item()) + actual_logit = float(logits[actual_token].item()) + top1_logit = float(top_values[0].item()) + top2_logit = float(top_values[1].item()) if args.top_k > 1 else None + row = { + "index": case_index, + "position": position, + "prompt_tokens_reported": int(golden_case.get("prompt_tokens", 0)), + "prompt_tokens_encoded": len(prompt_ids), + "prefix_tokens_evaluated": len(prefix_ids), + "hf_expected": _token_entry( + tokenizer, + expected_token, + expected_logit, + expected_rank, + ), + "neuron_actual": _token_entry( + tokenizer, + actual_token, + actual_logit, + actual_rank, + ), + "top1_minus_expected": top1_logit - expected_logit, + "expected_minus_neuron": expected_logit - actual_logit, + "top1_minus_top2": ( + top1_logit - top2_logit if top2_logit is not None else None + ), + "forward_seconds": elapsed, + "top_tokens": top_tokens, + } + rows.append(row) + print( + json.dumps( + { + "index": row["index"], + "position": row["position"], + "expected_rank": expected_rank, + "actual_rank": actual_rank, + "expected_minus_neuron": row["expected_minus_neuron"], + "top1_minus_expected": row["top1_minus_expected"], + }, + sort_keys=True, + ), + flush=True, + ) + + report = { + "stage": "hf_first_mismatch_logits", + "model_path": str(args.model_path), + "goldens_json": str(args.goldens_json), + "neuron_json": str(args.neuron_json), + "hf_ref_pkgs": str(args.hf_ref_pkgs) if args.hf_ref_pkgs else None, + "dtype": args.dtype, + "top_k": args.top_k, + "hf_load_seconds": load_elapsed, + "num_mismatches_checked": len(rows), + "cases": rows, + } + print(json.dumps(report, indent=2, sort_keys=True), flush=True) + if args.output_json: + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_hf_neuron_greedy_match.py b/validation_scripts/qwen36_hf_neuron_greedy_match.py new file mode 100644 index 00000000..222dfb9f --- /dev/null +++ b/validation_scripts/qwen36_hf_neuron_greedy_match.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +"""Compare Qwen3.6 Neuron greedy tokens against saved HF goldens.""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import Any + + +def _load_bench(repo_root: Path): + sys.path.insert(0, str(repo_root / "validation_scripts")) + import qwen36_offline_decode_bench as bench # noqa: WPS433 + + return bench + + +def _add_runtime_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--repo-root", type=Path, default=Path.cwd()) + parser.add_argument("--model-path", type=Path, required=True) + parser.add_argument("--compiled-artifacts", type=Path, required=True) + parser.add_argument("--goldens-json", type=Path, required=True) + parser.add_argument("--output-json", type=Path) + parser.add_argument("--max-tokens", type=int) + parser.add_argument("--limit", type=int) + parser.add_argument("--warmup", action="store_true") + parser.add_argument("--max-model-len", type=int) + parser.add_argument("--seq-len", type=int) + parser.add_argument("--cte-buckets") + parser.add_argument("--context-encoding-bucket-pairs") + parser.add_argument("--token-generation-buckets") + parser.add_argument("--token-generation-batches") + parser.add_argument("--async-mode", action="store_true") + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--max-num-seqs", type=int, default=1) + parser.add_argument("--ctx-batch-size", type=int, default=1) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--pa-num-blocks", type=int) + parser.add_argument("--gdn-checkpoint-interval", type=int, default=256) + parser.add_argument("--max-gdn-checkpoint-slots", type=int, default=64) + parser.add_argument("--gdn-recurrent-cache-dtype", default="float32") + parser.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + parser.add_argument("--kernel-q-tile-size", type=int, default=128) + parser.add_argument("--kernel-kv-tile-size", type=int, default=1024) + + +def _prepare_args(args: argparse.Namespace, bench: Any) -> None: + args.repo_root = args.repo_root.expanduser().resolve() + args.model_path = args.model_path.expanduser().resolve() + args.compiled_artifacts = args.compiled_artifacts.expanduser().resolve() + resolved = bench._resolve_config_defaults(args) + args.artifact_config = resolved["artifact_config"] + args.seq_len = resolved["seq_len"] + args.max_model_len = resolved["max_model_len"] + args.resolved_cte_buckets = resolved["cte_buckets"] + args.resolved_context_encoding_bucket_pairs = resolved[ + "context_encoding_bucket_pairs" + ] + args.resolved_token_generation_buckets = resolved["token_generation_buckets"] + args.resolved_token_generation_batches = resolved["token_generation_batches"] + args.pa_num_blocks = resolved["pa_num_blocks"] + args.prompt = "" + args.warmup_tokens = min(4, int(args.max_tokens or 4)) + + +def _position_match(expected: list[int], actual: list[int]) -> dict[str, Any]: + total = max(len(expected), len(actual)) + compared = min(len(expected), len(actual)) + matches = sum(1 for left, right in zip(expected, actual) if left == right) + first_mismatch = None + for index in range(total): + left = expected[index] if index < len(expected) else None + right = actual[index] if index < len(actual) else None + if left != right: + first_mismatch = { + "position": index, + "expected": left, + "actual": right, + } + break + return { + "expected_len": len(expected), + "actual_len": len(actual), + "compared_positions": compared, + "total_positions": total, + "matches": matches, + "match_rate": (matches / total) if total else 1.0, + "first_mismatch": first_mismatch, + } + + +def main() -> int: + parser = argparse.ArgumentParser() + _add_runtime_args(parser) + args = parser.parse_args() + args.repo_root = args.repo_root.expanduser().resolve() + bench = _load_bench(args.repo_root) + + with args.goldens_json.expanduser().open(encoding="utf-8") as handle: + goldens = json.load(handle) + cases = list(goldens["cases"]) + if args.limit is not None: + cases = cases[: args.limit] + if args.max_tokens is None: + args.max_tokens = int(goldens.get("max_new_tokens") or 16) + + _prepare_args(args, bench) + bench._ensure_paths(args.repo_root) + bench._ensure_runtime_env(args) + + llm = None + try: + llm, sampling, warmup_sampling = bench._build_llm(args) + if args.warmup and cases: + llm.generate([cases[0]["prompt"]], warmup_sampling) + + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, + trust_remote_code=True, + ) + rows = [] + start_all = time.perf_counter() + for case in cases: + expected = [ + int(item) + for item in case["hf_generated_tokens"][: args.max_tokens] + ] + start = time.perf_counter() + outputs = llm.generate([case["prompt"]], sampling) + elapsed = time.perf_counter() - start + actual = [int(item) for item in outputs[0].outputs[0].token_ids] + comparison = _position_match(expected, actual) + row = { + "index": int(case["index"]), + "prompt_tokens": int(case.get("prompt_tokens", 0)), + "elapsed_seconds": elapsed, + "tok_s": (len(actual) / elapsed) if elapsed > 0 else None, + "hf_generated_tokens": expected, + "neuron_generated_tokens": actual, + "hf_text": case.get("hf_text", ""), + "neuron_text": tokenizer.decode( + actual, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ), + **comparison, + } + rows.append(row) + print( + json.dumps( + { + "index": row["index"], + "match_rate": row["match_rate"], + "matches": row["matches"], + "total_positions": row["total_positions"], + "tok_s": row["tok_s"], + "first_mismatch": row["first_mismatch"], + }, + sort_keys=True, + ), + flush=True, + ) + + elapsed_all = time.perf_counter() - start_all + total_matches = sum(int(row["matches"]) for row in rows) + total_positions = sum(int(row["total_positions"]) for row in rows) + report = { + "stage": "neuron_vs_hf_greedy_match", + "goldens_json": str(args.goldens_json.expanduser()), + "artifact": str(args.compiled_artifacts), + "model_path": str(args.model_path), + "max_tokens": args.max_tokens, + "num_cases": len(rows), + "overall_matches": total_matches, + "overall_positions": total_positions, + "overall_match_rate": ( + total_matches / total_positions if total_positions else 1.0 + ), + "elapsed_seconds": elapsed_all, + "avg_tok_s": ( + sum(float(row["tok_s"]) for row in rows if row["tok_s"]) + / max(1, sum(1 for row in rows if row["tok_s"])) + ), + "pa_num_blocks": args.pa_num_blocks, + "cte_buckets": args.resolved_cte_buckets, + "context_encoding_bucket_pairs": args.resolved_context_encoding_bucket_pairs, + "token_generation_buckets": args.resolved_token_generation_buckets, + "async_mode": args.async_mode, + "cases": rows, + } + print(json.dumps(report, indent=2, sort_keys=True), flush=True) + if args.output_json: + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + finally: + if llm is not None: + shutdown = getattr(llm, "shutdown", None) + if shutdown is not None: + shutdown() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_hybrid_apc_context_sweep.py b/validation_scripts/qwen36_hybrid_apc_context_sweep.py new file mode 100644 index 00000000..2fd33ed7 --- /dev/null +++ b/validation_scripts/qwen36_hybrid_apc_context_sweep.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +"""Offline context-length sweep for Qwen3.6 Hybrid APC artifacts.""" + +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import qwen36_hybrid_apc_validation as hybrid_validation + + +def _parse_lengths(raw: str) -> list[int]: + lengths = [int(item) for item in raw.replace(",", " ").split() if item] + if not lengths: + raise ValueError("at least one length is required") + return lengths + + +def _artifact_neuron_config(compiled_artifacts: Path) -> dict[str, Any]: + config_path = compiled_artifacts / "neuron_config.json" + if not config_path.exists(): + return {} + with config_path.open(encoding="utf-8") as handle: + config = json.load(handle) + nested = config.get("neuron_config") + return nested if isinstance(nested, dict) else config + + +def _runtime_pa_override( + args: argparse.Namespace, + artifact_config: dict[str, Any], + *, + seq_len: int, + max_num_seqs: int, +) -> int: + """Return vLLM's user-intended block count, excluding its null block.""" + + block_size = int(args.block_size) + min_usable_blocks = ((seq_len + block_size - 1) // block_size) * max_num_seqs + if args.pa_num_blocks is not None: + return max(1, int(args.pa_num_blocks)) + + artifact_blocks = int(artifact_config.get("pa_num_blocks") or 0) + if artifact_blocks <= 0: + return max(1, min_usable_blocks) + + uses_block_kv = bool( + artifact_config.get("is_block_kv_layout") + or artifact_config.get("is_prefix_caching") + ) + if uses_block_kv and artifact_blocks > min_usable_blocks: + return artifact_blocks - 1 + return artifact_blocks + + +def _single_token_pool(tokenizer) -> list[int]: + return hybrid_validation._compact_single_token_ids(tokenizer) + + +def _role_token_ids(tokenizer, *, role_index: int, token_count: int) -> list[int]: + if token_count <= 0: + return [] + pool = _single_token_pool(tokenizer) + return [pool[(role_index + (position * 7)) % len(pool)] for position in range(token_count)] + + +def _prompt_parts_for_length( + tokenizer, + *, + target_tokens: int, + suffix_tokens: int, + prefix_role_index: int, + suffix_role_index: int, +) -> tuple[list[int], list[int]]: + if target_tokens <= suffix_tokens: + raise ValueError( + f"target length {target_tokens} must be larger than suffix length {suffix_tokens}" + ) + prefix = _role_token_ids( + tokenizer, + role_index=prefix_role_index, + token_count=target_tokens - suffix_tokens, + ) + suffix = _role_token_ids( + tokenizer, + role_index=suffix_role_index, + token_count=suffix_tokens, + ) + return prefix, suffix + + +def _prompt_for_length( + tokenizer, + *, + target_tokens: int, + suffix_tokens: int, + role_index: int, +) -> dict[str, list[int]]: + prefix, suffix = _prompt_parts_for_length( + tokenizer, + target_tokens=target_tokens, + suffix_tokens=suffix_tokens, + prefix_role_index=role_index, + suffix_role_index=role_index + 997, + ) + return {"prompt_token_ids": prefix + suffix} + + +def _partial_refill_prompts( + tokenizer, + *, + target_tokens: int, + suffix_tokens: int, + role_index: int, +) -> tuple[ + dict[str, list[int]], + dict[str, list[int]], + dict[str, list[int]], + int, +]: + shared_prefix, warmup_suffix = _prompt_parts_for_length( + tokenizer, + target_tokens=target_tokens, + suffix_tokens=suffix_tokens, + prefix_role_index=role_index, + suffix_role_index=role_index + 997, + ) + cold_prefix, measured_suffix = _prompt_parts_for_length( + tokenizer, + target_tokens=target_tokens, + suffix_tokens=suffix_tokens, + prefix_role_index=role_index + 2003, + suffix_role_index=role_index + 3001, + ) + warm_prompt = {"prompt_token_ids": shared_prefix + measured_suffix} + warmup_prompt = {"prompt_token_ids": shared_prefix + warmup_suffix} + cold_prompt = {"prompt_token_ids": cold_prefix + measured_suffix} + return cold_prompt, warmup_prompt, warm_prompt, len(shared_prefix) + + +def _generate(llm: Any, sampling: Any, prompt: dict[str, list[int]]) -> dict[str, Any]: + start = time.perf_counter() + outputs = llm.generate([prompt], sampling) + elapsed = time.perf_counter() - start + output = outputs[0].outputs[0] + tokens = [int(token_id) for token_id in output.token_ids] + return { + "elapsed_seconds": elapsed, + "generated_token_count": len(tokens), + "generated_tokens": tokens, + "generated_text": output.text, + } + + +def _effective_vocab_size(model_path: Path, tokenizer: Any) -> int: + sizes = [ + int(size) + for size in ( + getattr(tokenizer, "vocab_size", None), + len(tokenizer), + ) + if size + ] + try: + from transformers import AutoConfig # noqa: WPS433 + + config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True) + config_vocab_size = getattr(config, "vocab_size", None) + if config_vocab_size: + sizes.append(int(config_vocab_size)) + except Exception: + pass + if not sizes: + raise ValueError("could not determine model/tokenizer vocabulary size") + return max(sizes) + + +def _build_args(args: argparse.Namespace, artifact_config: dict[str, Any]) -> SimpleNamespace: + seq_len = int( + args.seq_len + or artifact_config.get("seq_len") + or artifact_config.get("max_context_length") + or artifact_config.get("max_length") + or 131072 + ) + cte_buckets = args.cte_buckets or ",".join( + str(item) for item in artifact_config.get("context_encoding_buckets", []) + ) + if not cte_buckets: + cte_buckets = "256,512" + token_generation_buckets = args.token_generation_buckets + if token_generation_buckets is None: + artifact_tkg_buckets = artifact_config.get("token_generation_buckets") or [] + if artifact_tkg_buckets: + token_generation_buckets = [ + ",".join(str(item) for item in artifact_tkg_buckets) + ] + token_generation_batches = args.token_generation_batches + if token_generation_batches is None: + artifact_tkg_batches = artifact_config.get("token_generation_batches") or [] + if artifact_tkg_batches: + token_generation_batches = [ + ",".join(str(item) for item in artifact_tkg_batches) + ] + context_encoding_bucket_pairs = args.context_encoding_bucket_pairs + if context_encoding_bucket_pairs is None: + artifact_pairs = artifact_config.get("context_encoding_bucket_pairs") or [] + if artifact_pairs: + context_encoding_bucket_pairs = [ + f"{int(active)}:{int(prefix)}" + for active, prefix in artifact_pairs + ] + async_mode = ( + bool(args.async_mode) + if args.async_mode is not None + else bool(artifact_config.get("async_mode", False)) + ) + ctx_batch_size = int( + args.ctx_batch_size + if args.ctx_batch_size is not None + else artifact_config.get("ctx_batch_size") or 1 + ) + max_num_seqs = int(args.max_num_seqs or 1) + pa_num_blocks = _runtime_pa_override( + args, + artifact_config, + seq_len=seq_len, + max_num_seqs=max_num_seqs, + ) + return SimpleNamespace( + model_path=str(args.model_path), + compiled_artifacts=str(args.compiled_artifacts), + skip_fp8_env=args.skip_fp8_env, + max_model_len=int(args.max_model_len or seq_len), + seq_len=seq_len, + cte_bucket=max(hybrid_validation._parse_bucket_values([cte_buckets])), + cte_buckets=[cte_buckets], + context_encoding_bucket_pairs=context_encoding_bucket_pairs, + cte_bucket_profile="single", + tensor_parallel_size=args.tensor_parallel_size, + max_num_seqs=max_num_seqs, + logical_nc_config=args.logical_nc_config, + ctx_batch_size=ctx_batch_size, + token_generation_buckets=token_generation_buckets, + token_generation_batches=token_generation_batches, + async_mode=async_mode, + block_size=args.block_size, + gdn_checkpoint_interval=args.gdn_checkpoint_interval, + max_gdn_checkpoint_slots=args.max_gdn_checkpoint_slots, + gdn_recurrent_cache_dtype=args.gdn_recurrent_cache_dtype, + gdn_conv_cache_dtype=args.gdn_conv_cache_dtype, + hybrid_apc_require_vllm_metadata=True, + hybrid_apc_reject_unbacked_attention_hits=True, + hybrid_apc_disable_unbacked_prefix_reads=False, + hybrid_apc_enable_backed_prefix_reads=True, + hybrid_apc_prefill_chunk_tokens=args.hybrid_apc_prefill_chunk_tokens, + hybrid_apc_max_backed_prefix_read_len=0, + enable_vllm_chunked_prefill=True, + kernel_q_tile_size=args.kernel_q_tile_size, + kernel_kv_tile_size=args.kernel_kv_tile_size, + num_gpu_blocks_override=pa_num_blocks, + gpu_memory_utilization=args.gpu_memory_utilization, + max_tokens=args.max_tokens, + dummy_token_ids=args.dummy_token_ids, + require_real_tokens=args.require_real_tokens, + ) + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=Path, required=True) + parser.add_argument("--compiled-artifacts", type=Path, required=True) + parser.add_argument("--lengths", default="16384,32768,65536,131000") + parser.add_argument("--output-json", type=Path, required=True) + parser.add_argument("--max-tokens", type=int, default=1) + parser.add_argument("--suffix-tokens", type=int, default=16) + parser.add_argument( + "--warm-mode", + choices=("partial", "exact"), + default="partial", + help=( + "partial warms a shared prefix with one suffix, then measures the " + "same prefix with a different suffix; exact repeats the full prompt." + ), + ) + parser.add_argument("--seq-len", type=int) + parser.add_argument("--max-model-len", type=int) + parser.add_argument("--cte-buckets") + parser.add_argument("--context-encoding-bucket-pairs", nargs="+", default=None) + parser.add_argument("--pa-num-blocks", type=int) + parser.add_argument("--gpu-memory-utilization", type=float) + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--max-num-seqs", type=int) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--ctx-batch-size", type=int) + parser.add_argument("--token-generation-buckets", nargs="+", default=None) + parser.add_argument("--token-generation-batches", nargs="+", default=None) + parser.add_argument("--async-mode", action="store_true", default=None) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--gdn-checkpoint-interval", type=int, default=256) + parser.add_argument("--max-gdn-checkpoint-slots", type=int, default=64) + parser.add_argument("--gdn-recurrent-cache-dtype", default="float32") + parser.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + parser.add_argument("--hybrid-apc-prefill-chunk-tokens", type=int, default=0) + parser.add_argument("--kernel-q-tile-size", type=int, default=128) + parser.add_argument("--kernel-kv-tile-size", type=int, default=1024) + parser.add_argument("--skip-fp8-env", action="store_true") + parser.add_argument("--require-real-tokens", action="store_true") + parser.add_argument("--dummy-token-ids", nargs="+", type=int, default=[0]) + args = parser.parse_args() + + args.model_path = args.model_path.expanduser().resolve() + args.compiled_artifacts = args.compiled_artifacts.expanduser().resolve() + artifact_config = _artifact_neuron_config(args.compiled_artifacts) + runtime_args = _build_args(args, artifact_config) + + from transformers import AutoTokenizer # noqa: WPS433 + from vllm import SamplingParams # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained(str(args.model_path), trust_remote_code=True) + sampling = SamplingParams(temperature=0.0, top_k=1, max_tokens=args.max_tokens) + configured_dummy_ids = {int(token_id) for token_id in args.dummy_token_ids} + dummy_ids = configured_dummy_ids | hybrid_validation._effective_dummy_token_ids( + runtime_args, + tokenizer, + ) + vocab_size = _effective_vocab_size(args.model_path, tokenizer) + llm = None + rows: list[dict[str, Any]] = [] + try: + llm, _unused_sampling = hybrid_validation._build_llm( + runtime_args, + enable_hybrid_apc=True, + ) + for index, target_tokens in enumerate(_parse_lengths(args.lengths)): + if target_tokens + args.max_tokens > runtime_args.seq_len: + raise ValueError( + f"target_tokens + max_tokens exceeds seq_len: " + f"{target_tokens} + {args.max_tokens} > {runtime_args.seq_len}" + ) + role_index = index * 1009 + if args.warm_mode == "exact": + prompt = _prompt_for_length( + tokenizer, + target_tokens=target_tokens, + suffix_tokens=args.suffix_tokens, + role_index=role_index, + ) + cold = _generate(llm, sampling, prompt) + prefix_warmup = None + warm = _generate(llm, sampling, prompt) + actual_prompt_tokens = len(prompt["prompt_token_ids"]) + shared_prefix_tokens = actual_prompt_tokens + else: + ( + cold_prompt, + warmup_prompt, + warm_prompt, + shared_prefix_tokens, + ) = _partial_refill_prompts( + tokenizer, + target_tokens=target_tokens, + suffix_tokens=args.suffix_tokens, + role_index=role_index, + ) + cold = _generate(llm, sampling, cold_prompt) + prefix_warmup = _generate(llm, sampling, warmup_prompt) + warm = _generate(llm, sampling, warm_prompt) + actual_prompt_tokens = len(warm_prompt["prompt_token_ids"]) + generated_tokens = [ + token + for result in (cold, prefix_warmup, warm) + if result is not None + for token in result["generated_tokens"] + ] + non_dummy = [ + token + for token in generated_tokens + if token not in dummy_ids + ] + invalid_token_ids = [ + token + for token in generated_tokens + if token < 0 or token >= vocab_size + ] + unique_generated_tokens = sorted(set(generated_tokens)) + row = { + "target_prompt_tokens": target_tokens, + "actual_prompt_tokens": actual_prompt_tokens, + "warm_mode": args.warm_mode, + "shared_prefix_tokens": shared_prefix_tokens, + "suffix_tokens": args.suffix_tokens, + "max_tokens": args.max_tokens, + "cold": cold, + "prefix_warmup": prefix_warmup, + "warm": warm, + "repeat_exact": cold["generated_tokens"] == warm["generated_tokens"], + "real_tokens_passed": bool(non_dummy), + "non_dummy_generated_token_count": len(non_dummy), + "all_generated_tokens_dummy": bool(generated_tokens) + and all(token in dummy_ids for token in generated_tokens), + "unique_generated_token_count": len(unique_generated_tokens), + "unique_generated_tokens": unique_generated_tokens, + "configured_dummy_token_ids": sorted(configured_dummy_ids), + "effective_dummy_token_ids": sorted(dummy_ids), + "token_range_passed": not invalid_token_ids, + "invalid_token_ids": sorted(set(invalid_token_ids)), + "vocab_size": vocab_size, + "cold_effective_prompt_tokens_per_second": target_tokens + / cold["elapsed_seconds"] + if cold["elapsed_seconds"] > 0 + else None, + "warm_effective_prompt_tokens_per_second": target_tokens + / warm["elapsed_seconds"] + if warm["elapsed_seconds"] > 0 + else None, + } + print(json.dumps(row, sort_keys=True), flush=True) + rows.append(row) + finally: + if llm is not None: + hybrid_validation._shutdown_llm(llm) + + report = { + "artifact": str(args.compiled_artifacts), + "artifact_neuron_config": { + key: artifact_config.get(key) + for key in ( + "seq_len", + "max_context_length", + "context_encoding_buckets", + "prefix_buckets", + "token_generation_buckets", + "ctx_batch_size", + "tkg_batch_size", + "pa_block_size", + "pa_num_blocks", + "output_logits", + "on_device_sampling_config", + ) + }, + "lengths": _parse_lengths(args.lengths), + "warm_mode": args.warm_mode, + "rows": rows, + "passed": all( + (args.warm_mode != "exact" or row["repeat_exact"]) + and row["real_tokens_passed"] + and row["token_range_passed"] + for row in rows + ), + } + args.output_json.expanduser().parent.mkdir(parents=True, exist_ok=True) + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + print(json.dumps(report, indent=2, sort_keys=True), flush=True) + return 0 if report["passed"] else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_hybrid_apc_validation.py b/validation_scripts/qwen36_hybrid_apc_validation.py new file mode 100644 index 00000000..dcbbc182 --- /dev/null +++ b/validation_scripts/qwen36_hybrid_apc_validation.py @@ -0,0 +1,1475 @@ +#!/usr/bin/env python3 +"""Trainium validation harness for Qwen3.6 hybrid APC. + +This script is intentionally separate from unit tests because it expects a +Neuron/vLLM runtime and compiled artifacts. It covers two gates: + +* token exactness for cold vs warm full-prefix and partial-prefix reuse; +* HBM planning for GDN checkpoint slot budgets. +""" + +from __future__ import annotations + +import argparse +import gc +import importlib.util +import json +import multiprocessing +import os +import queue +import sys +import time +import traceback +from pathlib import Path +from types import SimpleNamespace +from typing import Any + + +REPO_ROOT = Path(__file__).resolve().parents[1] +QWEN_ROOT = REPO_ROOT / "contrib" / "models" / "Qwen3.6-27B" +RUNNER_PATH = QWEN_ROOT / "vllm" / "run_offline_inference.py" +HYBRID_APC_PATH = QWEN_ROOT / "src" / "hybrid_apc.py" +FP8_ENV_DEFAULTS = { + "XLA_HANDLE_SPECIAL_SCALAR": "1", + "UNSAFE_FP8FNCAST": "1", +} +COMPACT_SINGLE_TOKEN_PIECES = [ + " one", + " two", + " three", + " four", + " five", + " six", + " seven", + " eight", + " nine", + " ten", + " alpha", + " beta", + " gamma", + " delta", + " token", +] + + +def _ensure_fp8_environment() -> None: + for name, value in FP8_ENV_DEFAULTS.items(): + os.environ.setdefault(name, value) + + +def _load_module(name: str, path: Path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +def _compiled_neuron_config(args) -> dict: + if not args.compiled_artifacts: + return {} + config_path = Path(args.compiled_artifacts).expanduser() / "neuron_config.json" + if not config_path.exists(): + return {} + with config_path.open(encoding="utf-8") as handle: + config = json.load(handle) + nested = config.get("neuron_config") + return nested if isinstance(nested, dict) else config + + +def _align_additional_config_to_compiled_artifact( + args, + additional_config: dict, +) -> dict: + """Keep runtime additional_config compatible with precompiled artifacts. + + The Qwen chunked-prefill runner uses CTE buckets for active prefill chunk + shapes, but vLLM-Neuron validates the top-level max_prompt_length against + the artifact's compiled max_context_length when loading precompiled NEFFs. + """ + + compiled_config = _compiled_neuron_config(args) + if not compiled_config: + return additional_config + compiled_max_prompt = int( + compiled_config.get("max_context_length") + or compiled_config.get("max_length") + or compiled_config.get("seq_len") + or 0 + ) + if compiled_max_prompt <= 0: + return additional_config + + aligned = dict(additional_config) + aligned["max_prompt_length"] = compiled_max_prompt + override = dict(aligned.get("override_neuron_config") or {}) + override["max_context_length"] = compiled_max_prompt + if ( + "context_encoding_bucket_pairs" not in override + and compiled_config.get("context_encoding_bucket_pairs") is not None + ): + override["context_encoding_bucket_pairs"] = compiled_config[ + "context_encoding_bucket_pairs" + ] + aligned["override_neuron_config"] = override + return aligned + + +def _validate_generation_batch_support(args) -> None: + if args.max_tokens <= 0 or args.max_num_seqs <= 1: + return + neuron_config = _compiled_neuron_config(args) + if not neuron_config: + return + tkg_batch_size = int( + neuron_config.get("tkg_batch_size") + or neuron_config.get("batch_size") + or neuron_config.get("max_batch_size") + or 1 + ) + if args.max_num_seqs > tkg_batch_size: + raise ValueError( + "batched generation requires a compiled artifact with " + f"tkg_batch_size >= --max-num-seqs; got tkg_batch_size={tkg_batch_size} " + f"and max_num_seqs={args.max_num_seqs}" + ) + ctx_batch_size = int( + neuron_config.get("ctx_batch_size") + or neuron_config.get("batch_size") + or neuron_config.get("max_batch_size") + or 1 + ) + if args.max_num_seqs > ctx_batch_size: + raise ValueError( + "batched generation requires a compiled artifact with " + f"ctx_batch_size >= --max-num-seqs for grouped prefill host logits; " + f"got ctx_batch_size={ctx_batch_size} and max_num_seqs={args.max_num_seqs}" + ) + + +def _parse_bucket_values(values) -> list[int]: + if isinstance(values, str): + values = [values] + buckets = [] + for value in values: + for part in str(value).split(","): + part = part.strip() + if part: + buckets.append(int(part)) + return sorted(set(buckets)) + + +def _next_bucket(token_count: int, buckets: list[int]) -> int: + for bucket in buckets: + if token_count <= bucket: + return bucket + raise ValueError( + f"prompt token length {token_count} exceeds compiled CTE buckets {buckets}" + ) + + +def _padding_token_id(tokenizer) -> int: + for token_id in (tokenizer.pad_token_id, tokenizer.eos_token_id): + if token_id is not None: + return int(token_id) + raise ValueError("tokenizer must define a pad_token_id or eos_token_id") + + +def _maybe_bucket_align_labeled_prompts(args, labeled_prompts): + if not getattr(args, "align_prompts_to_cte_buckets", False): + return labeled_prompts + + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained( + str(Path(args.model_path).expanduser().resolve()), + trust_remote_code=True, + ) + buckets = _parse_bucket_values(args.cte_buckets) + pad_token_id = _padding_token_id(tokenizer) + aligned = [] + for label, prompt in labeled_prompts: + if isinstance(prompt, dict): + prompt_token_ids = list(prompt.get("prompt_token_ids", [])) + else: + prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False) + bucket = _next_bucket(len(prompt_token_ids), buckets) + aligned.append( + ( + label, + { + "prompt_token_ids": prompt_token_ids + + [pad_token_id] * (bucket - len(prompt_token_ids)), + }, + ) + ) + return aligned + + +def _runner_args(args, *, enable_hybrid_apc: bool): + return SimpleNamespace( + cte_bucket=args.cte_bucket, + cte_buckets=args.cte_buckets, + context_encoding_bucket_pairs=getattr( + args, "context_encoding_bucket_pairs", None + ), + cte_bucket_profile=args.cte_bucket_profile, + seq_len=args.seq_len, + tensor_parallel_size=args.tensor_parallel_size, + max_num_seqs=args.max_num_seqs, + ctx_batch_size=args.ctx_batch_size, + logical_nc_config=args.logical_nc_config, + block_size=args.block_size, + num_gpu_blocks_override=args.num_gpu_blocks_override, + enable_prefix_caching=enable_hybrid_apc, + enable_hybrid_apc=enable_hybrid_apc, + enable_vllm_chunked_prefill=args.enable_vllm_chunked_prefill, + token_generation_buckets=getattr(args, "token_generation_buckets", None), + token_generation_batches=getattr(args, "token_generation_batches", None), + async_mode=getattr(args, "async_mode", False), + kernel_q_tile_size=args.kernel_q_tile_size, + kernel_kv_tile_size=args.kernel_kv_tile_size, + hybrid_gdn_recurrent_cache_dtype=None, + gdn_recurrent_cache_dtype=args.gdn_recurrent_cache_dtype, + hybrid_gdn_conv_cache_dtype=None, + gdn_conv_cache_dtype=args.gdn_conv_cache_dtype, + gdn_checkpoint_interval=args.gdn_checkpoint_interval, + max_gdn_checkpoint_slots=args.max_gdn_checkpoint_slots, + hybrid_cache_mode="all", + hybrid_cache_prefix_boundary_only=True, + hybrid_cache_validate_exact=True, + hybrid_apc_require_vllm_metadata=getattr( + args, "hybrid_apc_require_vllm_metadata", False + ), + hybrid_apc_reject_unbacked_attention_hits=getattr( + args, + "hybrid_apc_reject_unbacked_attention_hits", + True, + ), + hybrid_apc_disable_unbacked_prefix_reads=getattr( + args, + "hybrid_apc_disable_unbacked_prefix_reads", + False, + ), + hybrid_apc_enable_backed_prefix_reads=getattr( + args, + "hybrid_apc_enable_backed_prefix_reads", + False, + ), + hybrid_apc_prefill_chunk_tokens=getattr( + args, + "hybrid_apc_prefill_chunk_tokens", + 0, + ), + hybrid_apc_max_backed_prefix_read_len=getattr( + args, + "hybrid_apc_max_backed_prefix_read_len", + 0, + ), + text_only_cte=True, + compact_cte_attention_mask=True, + cold_zero_conv_fast_path=False, + ) + + +def _build_llm(args, *, enable_hybrid_apc: bool): + sys.path.insert(0, str(REPO_ROOT / "src")) + sys.path.insert(0, str(QWEN_ROOT / "vllm")) + sys.path.insert(0, str(QWEN_ROOT)) + os.environ["PYTHONPATH"] = ( + f"{REPO_ROOT / 'src'}:{QWEN_ROOT / 'vllm'}:{QWEN_ROOT}:" + f"{os.environ.get('PYTHONPATH', '')}" + ) + os.environ.setdefault("VLLM_NEURON_FRAMEWORK", "neuronx-distributed-inference") + os.environ.setdefault("VLLM_PLUGINS", "neuron") + if enable_hybrid_apc: + os.environ.setdefault("QWEN36_HYBRID_APC_INSTALL_PATCH", "1") + if args.enable_vllm_chunked_prefill: + os.environ["DISABLE_NEURON_CUSTOM_SCHEDULER"] = "1" + if args.hybrid_apc_disable_unbacked_prefix_reads: + os.environ["QWEN36_HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS"] = "1" + if args.compiled_artifacts: + os.environ["NEURON_COMPILED_ARTIFACTS"] = str( + Path(args.compiled_artifacts).expanduser().resolve() + ) + if not args.skip_fp8_env: + _ensure_fp8_environment() + + runner = _load_module("qwen36_run_offline_inference_validation", RUNNER_PATH) + from hf_qwen35_config import register_qwen35_config # noqa: WPS433 + from qwen36_hybrid_apc_scheduler_patch import ( # noqa: WPS433 + install_import_hook as install_hybrid_apc_scheduler_patch, + ) + + register_qwen35_config() + install_hybrid_apc_scheduler_patch() + from vllm import LLM, SamplingParams # noqa: WPS433 + + runner_args = _runner_args(args, enable_hybrid_apc=enable_hybrid_apc) + additional_config = _align_additional_config_to_compiled_artifact( + args, + runner._override_config(runner_args), + ) + llm_kwargs = { + "model": str(Path(args.model_path).expanduser().resolve()), + "trust_remote_code": True, + "dtype": "bfloat16", + "tensor_parallel_size": args.tensor_parallel_size, + "max_num_seqs": args.max_num_seqs, + "max_model_len": args.max_model_len, + "enable_prefix_caching": enable_hybrid_apc, + "enable_chunked_prefill": args.enable_vllm_chunked_prefill, + "additional_config": additional_config, + # vLLM multiplies its default CPU swap space by tensor_parallel_size. + # Neuron validation runs with large TP counts and no CUDA swap path, so + # the default can exceed host RAM before the Neuron model is loaded. + "swap_space": 0, + } + gpu_memory_utilization = getattr(args, "gpu_memory_utilization", None) + if gpu_memory_utilization is not None: + llm_kwargs["gpu_memory_utilization"] = float(gpu_memory_utilization) + if enable_hybrid_apc or args.enable_vllm_chunked_prefill: + llm_kwargs["block_size"] = args.block_size + if enable_hybrid_apc: + llm_kwargs["mamba_cache_mode"] = "all" + recurrent_cache_dtype = str(args.gdn_recurrent_cache_dtype).lower() + if recurrent_cache_dtype in {"bfloat16", "bf16"}: + recurrent_cache_dtype = "auto" + llm_kwargs["mamba_ssm_cache_dtype"] = recurrent_cache_dtype + if args.enable_vllm_chunked_prefill: + llm_kwargs["max_num_batched_tokens"] = runner._max_num_batched_tokens( + runner_args, + runner._cte_buckets(runner_args), + ) + if ( + runner_args.enable_prefix_caching + or runner_args.enable_hybrid_apc + or runner_args.enable_vllm_chunked_prefill + ): + llm_kwargs["num_gpu_blocks_override"] = runner._pa_num_blocks(runner_args) + sampling = SamplingParams(temperature=0.0, top_k=1, max_tokens=args.max_tokens) + return LLM(**llm_kwargs), sampling + + +def _generate(llm, sampling, prompt: str): + start = time.perf_counter() + outputs = llm.generate([prompt], sampling) + elapsed = time.perf_counter() - start + token_ids = list(outputs[0].outputs[0].token_ids) + return {"tokens": token_ids, "elapsed_seconds": elapsed} + + +def _generate_many(llm, sampling, labeled_prompts): + start = time.perf_counter() + outputs = llm.generate([prompt for _label, prompt in labeled_prompts], sampling) + elapsed = time.perf_counter() - start + return { + label: { + "tokens": list(output.outputs[0].token_ids), + "elapsed_seconds": elapsed, + } + for (label, _prompt), output in zip(labeled_prompts, outputs) + } + + +def _shutdown_llm(llm) -> None: + if llm is None: + return + for target in ( + llm, + getattr(llm, "llm_engine", None), + getattr(getattr(llm, "llm_engine", None), "engine_core", None), + getattr(getattr(llm, "llm_engine", None), "engine_core_client", None), + ): + shutdown = getattr(target, "shutdown", None) + if shutdown is None: + continue + try: + shutdown() + except Exception: + pass + del llm + gc.collect() + + +def _generate_batch_worker(args_dict, enable_hybrid_apc: bool, labeled_prompts, result_queue): + llm = None + try: + args = argparse.Namespace(**args_dict) + llm, sampling = _build_llm(args, enable_hybrid_apc=enable_hybrid_apc) + labeled_prompts = _maybe_bucket_align_labeled_prompts(args, labeled_prompts) + results = {} + for label, prompt in labeled_prompts: + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + prompt_len = ( + len(prompt.get("prompt_token_ids", [])) + if isinstance(prompt, dict) + else len(prompt) + ) + print( + "[hybrid_apc_debug] generate " + f"label={label} enable_hybrid_apc={enable_hybrid_apc} " + f"prompt_len={prompt_len}", + flush=True, + ) + results[label] = _generate(llm, sampling, prompt) + result_queue.put({"ok": True, "results": results}) + except BaseException: + result_queue.put({"ok": False, "traceback": traceback.format_exc()}) + finally: + _shutdown_llm(llm) + + +def _generate_batch(args, *, enable_hybrid_apc: bool, labeled_prompts): + ctx = multiprocessing.get_context("spawn") + result_queue = ctx.Queue() + proc = ctx.Process( + target=_generate_batch_worker, + args=(vars(args), enable_hybrid_apc, labeled_prompts, result_queue), + ) + proc.start() + proc.join() + + try: + message = result_queue.get(timeout=1.0) + except queue.Empty as exc: + raise RuntimeError( + f"generation worker exited with code {proc.exitcode} without a report" + ) from exc + if not message["ok"]: + raise RuntimeError(message["traceback"]) + if proc.exitcode not in (0, None): + raise RuntimeError(f"generation worker exited with code {proc.exitcode}") + return message["results"] + + +def _generate_grouped_batch_worker( + args_dict, + enable_hybrid_apc: bool, + labeled_prompt_groups, + result_queue, +): + llm = None + try: + args = argparse.Namespace(**args_dict) + llm, sampling = _build_llm(args, enable_hybrid_apc=enable_hybrid_apc) + results = {} + for group in labeled_prompt_groups: + group = _maybe_bucket_align_labeled_prompts(args, group) + if os.environ.get("QWEN36_HYBRID_APC_DEBUG") == "1": + print( + "[hybrid_apc_debug] generate-group " + f"labels={[label for label, _prompt in group]} " + f"enable_hybrid_apc={enable_hybrid_apc}", + flush=True, + ) + if len(group) == 1: + label, prompt = group[0] + results[label] = _generate(llm, sampling, prompt) + else: + results.update(_generate_many(llm, sampling, group)) + result_queue.put({"ok": True, "results": results}) + except BaseException: + result_queue.put({"ok": False, "traceback": traceback.format_exc()}) + finally: + _shutdown_llm(llm) + + +def _generate_grouped_batch(args, *, enable_hybrid_apc: bool, labeled_prompt_groups): + ctx = multiprocessing.get_context("spawn") + result_queue = ctx.Queue() + proc = ctx.Process( + target=_generate_grouped_batch_worker, + args=(vars(args), enable_hybrid_apc, labeled_prompt_groups, result_queue), + ) + proc.start() + proc.join() + + try: + message = result_queue.get(timeout=1.0) + except queue.Empty as exc: + raise RuntimeError( + f"generation worker exited with code {proc.exitcode} without a report" + ) from exc + if not message["ok"]: + raise RuntimeError(message["traceback"]) + if proc.exitcode not in (0, None): + raise RuntimeError(f"generation worker exited with code {proc.exitcode}") + return message["results"] + + +def _token_check(label: str, result: dict, dummy_token_ids: set[int]) -> dict: + tokens = [int(token) for token in result.get("tokens", [])] + unique_token_ids = sorted(set(tokens)) + non_dummy_tokens = [token for token in tokens if token not in dummy_token_ids] + passed = bool(non_dummy_tokens) + check = { + "label": label, + "generated_token_count": len(tokens), + "unique_token_ids": unique_token_ids, + "dummy_token_ids": sorted(dummy_token_ids), + "non_dummy_token_count": len(non_dummy_tokens), + "passed": passed, + } + if not passed: + check["failure"] = "generated tokens are empty or all configured dummy tokens" + return check + + +def _real_token_checks(results_by_label: dict[str, dict], dummy_token_ids: set[int]) -> dict: + checks = { + label: _token_check(label, result, dummy_token_ids) + for label, result in sorted(results_by_label.items()) + } + return { + "passed": bool(checks) and all(check["passed"] for check in checks.values()), + "checks": checks, + } + + +def _effective_dummy_token_ids(args, tokenizer=None) -> set[int]: + configured = {int(token_id) for token_id in args.dummy_token_ids} + if tokenizer is None or configured != {0}: + return configured + special_ids = { + int(token_id) + for token_id in ( + getattr(tokenizer, "pad_token_id", None), + getattr(tokenizer, "eos_token_id", None), + ) + if token_id is not None + } + return special_ids or configured + + +def _prompt_token_count(prompt: Any) -> int: + if isinstance(prompt, dict): + return len(prompt.get("prompt_token_ids", [])) + return len(str(prompt)) + + +def _token_prompt(prefix_ids: list[int], suffix_ids: list[int] | None = None) -> dict: + return {"prompt_token_ids": list(prefix_ids) + list(suffix_ids or [])} + + +def _compact_boundary_lengths(args) -> list[int]: + if getattr(args, "compact_boundary_lens", None): + candidates = _parse_bucket_values(args.compact_boundary_lens) + else: + block_size = int(args.block_size) + candidates = [ + block_size - 1, + block_size, + block_size + 1, + (2 * block_size) - 1, + 2 * block_size, + (2 * block_size) + 1, + ] + max_suffix = max(1, int(getattr(args, "compact_suffix_tokens", 16))) + max_prompt_len = max(1, int(args.seq_len) - max_suffix - max(1, int(args.max_tokens))) + return [ + prefix_len + for prefix_len in sorted(set(candidates)) + if 0 < prefix_len <= max_prompt_len + ] + + +def _make_prefix_ids(tokenizer, *, label: str, target_len: int) -> list[int]: + seed = f"System {label}: answer deterministically.\n" + text = seed + ids = tokenizer.encode(text, add_special_tokens=False) + while len(ids) < target_len: + text += seed + ids = tokenizer.encode(text, add_special_tokens=False) + return list(ids[:target_len]) + + +def _make_suffix_ids(tokenizer, *, label: str, target_len: int) -> list[int]: + seed = f"\nUser: compact gate suffix {label}. Answer with one token.\nAssistant:" + ids = tokenizer.encode(seed, add_special_tokens=False) + if len(ids) >= target_len: + return list(ids[:target_len]) + pad_piece = tokenizer.encode(" detail", add_special_tokens=False) + if not pad_piece: + raise ValueError("tokenizer returned no tokens for compact suffix padding") + while len(ids) < target_len: + ids.extend(pad_piece) + return list(ids[:target_len]) + + +def _single_token_piece(tokenizer, start_index: int) -> str: + for offset in range(len(COMPACT_SINGLE_TOKEN_PIECES)): + piece = COMPACT_SINGLE_TOKEN_PIECES[ + (int(start_index) + offset) % len(COMPACT_SINGLE_TOKEN_PIECES) + ] + if len(tokenizer.encode(piece, add_special_tokens=False)) == 1: + return piece + raise ValueError("could not find a compact-gate single-token text piece") + + +def _compact_single_token_ids(tokenizer) -> list[int]: + special_ids = { + int(token_id) + for token_id in ( + getattr(tokenizer, "pad_token_id", None), + getattr(tokenizer, "eos_token_id", None), + ) + if token_id is not None + } + ids = [] + for piece in COMPACT_SINGLE_TOKEN_PIECES: + piece_ids = tokenizer.encode(piece, add_special_tokens=False) + if len(piece_ids) == 1 and int(piece_ids[0]) not in special_ids: + token_id = int(piece_ids[0]) + if token_id not in ids: + ids.append(token_id) + if len(ids) < 4: + next_id = max(ids or [0]) + 1 + while len(ids) < 4: + if next_id not in special_ids and next_id not in ids: + ids.append(next_id) + next_id += 1 + return ids + + +def _compact_role_prefix_ids( + tokenizer, + *, + role_index: int, + token_count: int, +) -> list[int]: + """Build a globally unique, tokenizer-stable prefix for one compact-gate role.""" + + token_count = int(token_count) + if token_count <= 0: + return [] + pool = _compact_single_token_ids(tokenizer) + base = len(pool) + header_len = min(token_count, 6) + header = [ + pool[(int(role_index) // (base**offset)) % base] + for offset in range(header_len) + ] + if token_count <= header_len: + return header[:token_count] + body = [ + pool[(int(role_index) + 3 + (position * 7)) % base] + for position in range(token_count - header_len) + ] + return header + body + + +def _repeat_single_token_piece(tokenizer, *, start_index: int, token_count: int) -> str: + piece = _single_token_piece(tokenizer, start_index) + text = piece * int(token_count) + actual = len(tokenizer.encode(text, add_special_tokens=False)) + if actual != int(token_count): + raise ValueError( + "compact-gate tokenizer-stable text construction failed: " + f"wanted {token_count} tokens but built {actual}" + ) + return text + + +def _compact_instruction_suffix_text( + tokenizer, + *, + label: str, + start_index: int, + token_count: int, +) -> str: + tails = ( + " one two three four five", + " one two three", + " one", + ) + tail = None + tail_len = 0 + for candidate in tails: + candidate_len = len(tokenizer.encode(candidate, add_special_tokens=False)) + if candidate_len <= int(token_count): + tail = candidate + tail_len = candidate_len + break + if tail is None: + raise ValueError(f"compact-gate suffix {label!r} cannot fit in {token_count} tokens") + filler_len = int(token_count) - tail_len + filler = ( + _repeat_single_token_piece( + tokenizer, + start_index=start_index, + token_count=filler_len, + ) + if filler_len > 0 + else "" + ) + text = filler + tail + actual = len(tokenizer.encode(text, add_special_tokens=False)) + if actual != int(token_count): + raise ValueError( + "compact-gate instruction suffix construction failed: " + f"wanted {token_count} tokens but built {actual} for {label}" + ) + return text + + +def _stable_text_prompt(tokenizer, prefix: str, suffix: str = "") -> str: + prefix_ids = tokenizer.encode(prefix, add_special_tokens=False) + full = prefix + suffix + full_ids = tokenizer.encode(full, add_special_tokens=False) + if full_ids[: len(prefix_ids)] != prefix_ids: + raise ValueError( + "compact-gate text prompt is not tokenizer-stable at the prefix/suffix " + "boundary" + ) + return full + + +def _compact_case_plan(args, tokenizer) -> dict: + suffix_tokens = int(getattr(args, "compact_suffix_tokens", 16)) + prefill_batch_budget = max(_parse_bucket_values(args.cte_buckets)) + largest_cte_bucket = prefill_batch_budget + cases = [] + for index, prefix_len in enumerate(_compact_boundary_lengths(args)): + prefix_a_ids = _compact_role_prefix_ids( + tokenizer, + role_index=(index * 3), + token_count=prefix_len, + ) + prefix_b_ids = _compact_role_prefix_ids( + tokenizer, + role_index=(index * 3) + 1, + token_count=prefix_len, + ) + cold_prefix_ids = _compact_role_prefix_ids( + tokenizer, + role_index=(index * 3) + 2, + token_count=prefix_len, + ) + suffix_a = _compact_instruction_suffix_text( + tokenizer, + label=f"{prefix_len}-a", + start_index=(index * 5) + 2, + token_count=suffix_tokens, + ) + suffix_b = _compact_instruction_suffix_text( + tokenizer, + label=f"{prefix_len}-b", + start_index=(index * 5) + 3, + token_count=suffix_tokens, + ) + cold_suffix = _compact_instruction_suffix_text( + tokenizer, + label=f"{prefix_len}-cold", + start_index=(index * 5) + 4, + token_count=suffix_tokens, + ) + suffix_a_ids = tokenizer.encode(suffix_a, add_special_tokens=False) + suffix_b_ids = tokenizer.encode(suffix_b, add_special_tokens=False) + cold_suffix_ids = tokenizer.encode(cold_suffix, add_special_tokens=False) + backed_checkpoint_hit = ( + prefix_len % int(args.gdn_checkpoint_interval) == 0 + and ( + int(getattr(args, "hybrid_apc_max_backed_prefix_read_len", 0) or 0) + <= 0 + or prefix_len <= int(args.hybrid_apc_max_backed_prefix_read_len) + ) + ) + speedup_required = backed_checkpoint_hit and prefix_len < largest_cte_bucket + speedup_skip_reason = None + if backed_checkpoint_hit and not speedup_required: + speedup_skip_reason = ( + "restore prefix reaches largest CTE bucket; current artifacts " + "cannot prove grouped warm speedup for this boundary" + ) + warm_partial_active_len = suffix_tokens if backed_checkpoint_hit else ( + prefix_len + suffix_tokens + ) + cold_mixed_active_len = prefix_len + suffix_tokens + cases.append( + { + "case": f"boundary_{prefix_len}", + "prefix_len": prefix_len, + "full_token_len": prefix_len, + "partial_token_len": prefix_len + suffix_tokens, + "full_a": _token_prompt(prefix_a_ids), + "full_b": _token_prompt(prefix_b_ids), + "partial_a": _token_prompt(prefix_a_ids, suffix_a_ids), + "partial_b": _token_prompt(prefix_b_ids, suffix_b_ids), + "mixed_cold": _token_prompt(cold_prefix_ids, cold_suffix_ids), + "full_grouped": (2 * prefix_len) <= prefill_batch_budget, + "partial_grouped": ( + 2 * warm_partial_active_len + ) + <= prefill_batch_budget, + "mixed_grouped": ( + warm_partial_active_len + cold_mixed_active_len + ) + <= prefill_batch_budget, + "backed_checkpoint_hit": backed_checkpoint_hit, + "speedup_required": speedup_required, + "speedup_skip_reason": speedup_skip_reason, + } + ) + return { + "boundary_lengths": [case["prefix_len"] for case in cases], + "cases": cases, + } + + +def _compact_exactness_check( + *, + name: str, + cold_label: str, + warm_label: str, + cold_results: dict[str, dict], + warm_results: dict[str, dict], +) -> dict: + cold_tokens = list(cold_results[cold_label]["tokens"]) + warm_tokens = list(warm_results[warm_label]["tokens"]) + return { + "name": name, + "cold_label": cold_label, + "warm_label": warm_label, + "passed": cold_tokens == warm_tokens, + "cold_tokens": cold_tokens, + "warm_tokens": warm_tokens, + } + + +def _compact_speedup_check( + *, + name: str, + cold_labels: list[str], + warm_label: str, + cold_results: dict[str, dict], + warm_results: dict[str, dict], + min_speedup: float, +) -> dict: + cold_serial = sum(float(cold_results[label]["elapsed_seconds"]) for label in cold_labels) + warm_elapsed = float(warm_results[warm_label]["elapsed_seconds"]) + speedup = cold_serial / warm_elapsed if warm_elapsed > 0 else float("inf") + return { + "name": name, + "cold_labels": list(cold_labels), + "warm_label": warm_label, + "cold_serial_seconds": cold_serial, + "warm_group_seconds": warm_elapsed, + "speedup": speedup, + "min_speedup": min_speedup, + "passed": speedup >= min_speedup, + } + + +def _write_report(args, report: dict) -> None: + if args.output_json: + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + print(json.dumps(report, indent=2, sort_keys=True)) + + +def run_exactness(args) -> int: + shared = args.shared_prefix + prompt_a = shared + args.suffix_a + prompt_b = shared + args.suffix_b + + # This validation uses the v3 vLLM APC artifact, which is compiled for + # prefix/block KV layout. Keep prefix metadata enabled even for cold + # references, and isolate each cold prompt in a fresh process so it cannot + # observe cache state from the other reference prompt. + cold_full = _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=[ + ("cold_full", prompt_a), + ], + )["cold_full"] + cold_partial = _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=[ + ("cold_partial", prompt_b), + ], + )["cold_partial"] + warm_results = _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=[ + ("warmup_full", prompt_a), + ("warm_full", prompt_a), + ("warmup_partial", prompt_a), + ("warm_partial", prompt_b), + ], + ) + + warmup_full = warm_results["warmup_full"] + warm_full = warm_results["warm_full"] + warm_partial = warm_results["warm_partial"] + all_results = { + "cold_full": cold_full, + "warmup_full": warmup_full, + "warm_full": warm_full, + "cold_partial": cold_partial, + "warm_partial": warm_partial, + } + real_token_checks = _real_token_checks( + all_results, + _effective_dummy_token_ids(args), + ) + + report = { + "full_prefix_exact": cold_full["tokens"] == warm_full["tokens"], + "partial_prefix_exact": cold_partial["tokens"] == warm_partial["tokens"], + "cold_full": cold_full, + "warmup_full": warmup_full, + "warm_full": warm_full, + "cold_partial": cold_partial, + "warm_partial": warm_partial, + "real_generated_tokens_required": args.require_real_tokens, + "real_generated_tokens_passed": real_token_checks["passed"], + "real_generated_token_checks": real_token_checks["checks"], + "negative_tests": { + "missing_gdn_state_fallback": "requires scheduler fault injection", + "zeroed_conv_state": "requires model debug hook", + }, + } + if args.output_json: + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + print(json.dumps(report, indent=2, sort_keys=True)) + passed = report["full_prefix_exact"] and report["partial_prefix_exact"] + if args.require_real_tokens: + passed = passed and real_token_checks["passed"] + return 0 if passed else 1 + + +def run_batched_exactness(args) -> int: + if not args.shared_prefix_2: + raise ValueError("--shared-prefix-2 is required for batched-exactness") + _validate_generation_batch_support(args) + + prompt_full_a = args.shared_prefix + args.suffix_a + prompt_partial_a = args.shared_prefix + args.suffix_b + prompt_full_b = args.shared_prefix_2 + args.suffix_c + prompt_partial_b = args.shared_prefix_2 + args.suffix_d + + cold_partial_a = _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=[ + ("cold_partial_a", prompt_partial_a), + ], + )["cold_partial_a"] + cold_partial_b = _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=[ + ("cold_partial_b", prompt_partial_b), + ], + )["cold_partial_b"] + warm_results = _generate_grouped_batch( + args, + enable_hybrid_apc=True, + labeled_prompt_groups=[ + [("warmup_full_a", prompt_full_a)], + [("warmup_full_b", prompt_full_b)], + [ + ("warm_partial_a", prompt_partial_a), + ("warm_partial_b", prompt_partial_b), + ], + ], + ) + + all_results = { + "cold_partial_a": cold_partial_a, + "cold_partial_b": cold_partial_b, + **warm_results, + } + real_token_checks = _real_token_checks( + all_results, + _effective_dummy_token_ids(args), + ) + report = { + "batched_partial_a_exact": ( + cold_partial_a["tokens"] == warm_results["warm_partial_a"]["tokens"] + ), + "batched_partial_b_exact": ( + cold_partial_b["tokens"] == warm_results["warm_partial_b"]["tokens"] + ), + "max_num_seqs": args.max_num_seqs, + "cold_partial_a": cold_partial_a, + "cold_partial_b": cold_partial_b, + "warmup_full_a": warm_results["warmup_full_a"], + "warmup_full_b": warm_results["warmup_full_b"], + "warm_partial_a": warm_results["warm_partial_a"], + "warm_partial_b": warm_results["warm_partial_b"], + "real_generated_tokens_required": args.require_real_tokens, + "real_generated_tokens_passed": real_token_checks["passed"], + "real_generated_token_checks": real_token_checks["checks"], + } + if args.output_json: + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + print(json.dumps(report, indent=2, sort_keys=True)) + passed = report["batched_partial_a_exact"] and report["batched_partial_b_exact"] + if args.require_real_tokens: + passed = passed and real_token_checks["passed"] + return 0 if passed else 1 + + +def run_compact_gate(args) -> int: + if not args.hybrid_apc_require_vllm_metadata: + raise ValueError("compact-gate requires --hybrid-apc-require-vllm-metadata") + if not args.hybrid_apc_disable_unbacked_prefix_reads: + raise ValueError( + "compact-gate requires --hybrid-apc-disable-unbacked-prefix-reads" + ) + if not args.hybrid_apc_enable_backed_prefix_reads: + raise ValueError( + "compact-gate requires --hybrid-apc-enable-backed-prefix-reads" + ) + if args.max_num_seqs < 2: + raise ValueError("compact-gate requires --max-num-seqs >= 2") + _validate_generation_batch_support(args) + + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained( + str(Path(args.model_path).expanduser().resolve()), + trust_remote_code=True, + ) + plan = _compact_case_plan(args, tokenizer) + cases = plan["cases"] + if not cases: + raise ValueError("compact-gate produced no cases") + + cold_full_prompts = [] + cold_partial_prompts = [] + for case in cases: + name = case["case"] + cold_full_prompts.extend( + [ + (f"cold_full_a__{name}", case["full_a"]), + (f"cold_full_b__{name}", case["full_b"]), + ] + ) + cold_partial_prompts.extend( + [ + (f"cold_partial_a__{name}", case["partial_a"]), + (f"cold_partial_b__{name}", case["partial_b"]), + (f"cold_mixed__{name}", case["mixed_cold"]), + ] + ) + + cold_results = {} + cold_results.update( + _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=cold_full_prompts, + ) + ) + cold_results.update( + _generate_batch( + args, + enable_hybrid_apc=True, + labeled_prompts=cold_partial_prompts, + ) + ) + + warm_groups = [] + for case in cases: + name = case["case"] + warm_groups.extend( + [ + [(f"warmup_full_a__{name}", case["full_a"])], + [(f"warmup_full_b__{name}", case["full_b"])], + ] + ) + if case["full_grouped"]: + warm_groups.append( + [ + (f"warm_full_a__{name}", case["full_a"]), + (f"warm_full_b__{name}", case["full_b"]), + ] + ) + else: + warm_groups.extend( + [ + [(f"warm_full_a__{name}", case["full_a"])], + [(f"warm_full_b__{name}", case["full_b"])], + ] + ) + if case["partial_grouped"]: + warm_groups.append( + [ + (f"warm_partial_a__{name}", case["partial_a"]), + (f"warm_partial_b__{name}", case["partial_b"]), + ] + ) + else: + warm_groups.extend( + [ + [(f"warm_partial_a__{name}", case["partial_a"])], + [(f"warm_partial_b__{name}", case["partial_b"])], + ] + ) + if case["mixed_grouped"]: + warm_groups.append( + [ + (f"mixed_warm_a__{name}", case["partial_a"]), + (f"mixed_cold__{name}", case["mixed_cold"]), + ] + ) + else: + warm_groups.extend( + [ + [(f"mixed_warm_a__{name}", case["partial_a"])], + [(f"mixed_cold__{name}", case["mixed_cold"])], + ] + ) + first_case = cases[0] + warm_groups.append( + [ + ( + f"eviction_probe_partial_a__{first_case['case']}", + first_case["partial_a"], + ) + ] + ) + warm_results = _generate_grouped_batch( + args, + enable_hybrid_apc=True, + labeled_prompt_groups=warm_groups, + ) + + exactness_checks = [] + speedup_checks = [] + speedup_skipped = [] + for case in cases: + name = case["case"] + exactness_checks.extend( + [ + _compact_exactness_check( + name=f"same_full_a__{name}", + cold_label=f"cold_full_a__{name}", + warm_label=f"warm_full_a__{name}", + cold_results=cold_results, + warm_results=warm_results, + ), + _compact_exactness_check( + name=f"same_full_b__{name}", + cold_label=f"cold_full_b__{name}", + warm_label=f"warm_full_b__{name}", + cold_results=cold_results, + warm_results=warm_results, + ), + _compact_exactness_check( + name=f"partial_a__{name}", + cold_label=f"cold_partial_a__{name}", + warm_label=f"warm_partial_a__{name}", + cold_results=cold_results, + warm_results=warm_results, + ), + _compact_exactness_check( + name=f"partial_b__{name}", + cold_label=f"cold_partial_b__{name}", + warm_label=f"warm_partial_b__{name}", + cold_results=cold_results, + warm_results=warm_results, + ), + _compact_exactness_check( + name=f"mixed_warm_a__{name}", + cold_label=f"cold_partial_a__{name}", + warm_label=f"mixed_warm_a__{name}", + cold_results=cold_results, + warm_results=warm_results, + ), + _compact_exactness_check( + name=f"mixed_cold__{name}", + cold_label=f"cold_mixed__{name}", + warm_label=f"mixed_cold__{name}", + cold_results=cold_results, + warm_results=warm_results, + ), + ] + ) + if case["speedup_required"] and case["partial_grouped"]: + speedup_checks.append( + _compact_speedup_check( + name=f"grouped_warm_partials__{name}", + cold_labels=[ + f"cold_partial_a__{name}", + f"cold_partial_b__{name}", + ], + warm_label=f"warm_partial_a__{name}", + cold_results=cold_results, + warm_results=warm_results, + min_speedup=float(args.compact_min_grouped_speedup), + ) + ) + elif case["backed_checkpoint_hit"] and case["partial_grouped"]: + speedup_skipped.append( + { + "name": f"grouped_warm_partials__{name}", + "prefix_len": case["prefix_len"], + "reason": case["speedup_skip_reason"] + or "speedup is not required for this boundary", + } + ) + + exactness_checks.append( + _compact_exactness_check( + name=f"eviction_probe_partial_a__{first_case['case']}", + cold_label=f"cold_partial_a__{first_case['case']}", + warm_label=f"eviction_probe_partial_a__{first_case['case']}", + cold_results=cold_results, + warm_results=warm_results, + ) + ) + + all_results = { + **{f"cold::{label}": result for label, result in cold_results.items()}, + **{f"warm::{label}": result for label, result in warm_results.items()}, + } + real_token_checks = _real_token_checks( + all_results, + _effective_dummy_token_ids(args, tokenizer), + ) + total_requests = len(cold_results) + len(warm_results) + grouped_partial_case_count = sum(1 for case in cases if case["partial_grouped"]) + grouped_mixed_case_count = sum(1 for case in cases if case["mixed_grouped"]) + acceptance = { + "request_count": total_requests, + "min_request_count": args.compact_min_requests, + "request_count_passed": total_requests >= args.compact_min_requests, + "exactness_passed": all(check["passed"] for check in exactness_checks), + "real_generated_tokens_passed": real_token_checks["passed"], + "grouped_partial_case_count": grouped_partial_case_count, + "grouped_partial_coverage_passed": grouped_partial_case_count > 0, + "grouped_mixed_case_count": grouped_mixed_case_count, + "grouped_mixed_coverage_passed": grouped_mixed_case_count > 0, + "speedup_checks_required": len(speedup_checks), + "speedup_passed": bool(speedup_checks) + and all(check["passed"] for check in speedup_checks), + "runtime_exception_free": True, + "eviction_probe_passed": exactness_checks[-1]["passed"], + } + acceptance["passed"] = all( + bool(acceptance[name]) + for name in ( + "request_count_passed", + "exactness_passed", + "real_generated_tokens_passed", + "grouped_partial_coverage_passed", + "grouped_mixed_coverage_passed", + "speedup_passed", + "runtime_exception_free", + "eviction_probe_passed", + ) + ) + report = { + "compact_gate_passed": acceptance["passed"], + "acceptance": acceptance, + "boundary_lengths": plan["boundary_lengths"], + "block_size": args.block_size, + "gdn_checkpoint_interval": args.gdn_checkpoint_interval, + "max_gdn_checkpoint_slots": args.max_gdn_checkpoint_slots, + "max_num_seqs": args.max_num_seqs, + "max_tokens": args.max_tokens, + "compact_suffix_tokens": args.compact_suffix_tokens, + "hybrid_apc_require_vllm_metadata": args.hybrid_apc_require_vllm_metadata, + "hybrid_apc_disable_unbacked_prefix_reads": ( + args.hybrid_apc_disable_unbacked_prefix_reads + ), + "hybrid_apc_enable_backed_prefix_reads": ( + args.hybrid_apc_enable_backed_prefix_reads + ), + "hybrid_apc_max_backed_prefix_read_len": ( + args.hybrid_apc_max_backed_prefix_read_len + ), + "exactness_checks": exactness_checks, + "speedup_checks": speedup_checks, + "speedup_skipped": speedup_skipped, + "real_generated_token_checks": real_token_checks["checks"], + } + _write_report(args, report) + return 0 if acceptance["passed"] else 1 + + +def run_hbm(args) -> int: + hybrid_apc = _load_module("qwen36_hybrid_apc_validation", HYBRID_APC_PATH) + rows = [] + for context_len in args.context_lens: + for interval in args.checkpoint_intervals: + estimate = hybrid_apc.estimate_qwen_hybrid_cache_bytes_per_rank( + max_context_len=context_len, + checkpoint_interval=interval, + recurrent_dtype=args.gdn_recurrent_cache_dtype, + conv_dtype=args.gdn_conv_cache_dtype, + ) + rows.append( + { + "context_len": context_len, + "checkpoint_interval": interval, + "num_gdn_checkpoints": estimate["num_gdn_checkpoints"], + "attention_kv_gib": estimate["attention_kv_bytes"] / 2**30, + "gdn_checkpoint_gib": estimate["gdn_checkpoint_bytes"] / 2**30, + "total_gib": estimate["total_bytes"] / 2**30, + "bytes_per_gdn_checkpoint": estimate["gdn_bytes_per_checkpoint"], + } + ) + print(json.dumps(rows, indent=2, sort_keys=True)) + return 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command", required=True) + + def add_common_exact_args(exact): + exact.add_argument("--model-path", required=True) + exact.add_argument("--compiled-artifacts") + exact.add_argument( + "--skip-fp8-env", + action="store_true", + help="Do not set FP8 runtime environment defaults for BF16 control artifacts.", + ) + exact.add_argument("--max-model-len", type=int, default=2048) + exact.add_argument("--seq-len", type=int, default=2048) + exact.add_argument("--cte-bucket", type=int, default=512) + exact.add_argument("--cte-buckets", nargs="+", default=["256,512"]) + exact.add_argument("--context-encoding-bucket-pairs", nargs="+", default=None) + exact.add_argument( + "--align-prompts-to-cte-buckets", + action="store_true", + help=( + "Tokenize prompts and pad token ids to the next compiled CTE bucket " + "before calling vLLM. This is useful for static Neuron artifacts " + "that reject non-bucket prompt shapes." + ), + ) + exact.add_argument("--cte-bucket-profile", default="single") + exact.add_argument("--tensor-parallel-size", type=int, default=4) + exact.add_argument("--max-num-seqs", type=int, default=1) + exact.add_argument("--logical-nc-config", type=int, default=2) + exact.add_argument("--ctx-batch-size", type=int, default=1) + exact.add_argument("--token-generation-buckets", nargs="+", default=None) + exact.add_argument("--token-generation-batches", nargs="+", default=None) + exact.add_argument("--async-mode", action="store_true") + exact.add_argument("--block-size", type=int, default=256) + exact.add_argument("--gdn-checkpoint-interval", type=int, default=256) + exact.add_argument("--max-gdn-checkpoint-slots", type=int, default=8) + exact.add_argument("--gdn-recurrent-cache-dtype", default="float32") + exact.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + exact.add_argument("--hybrid-apc-require-vllm-metadata", action="store_true") + exact.add_argument( + "--hybrid-apc-reject-unbacked-attention-hits", + action=argparse.BooleanOptionalAction, + default=True, + ) + exact.add_argument( + "--hybrid-apc-disable-unbacked-prefix-reads", + action=argparse.BooleanOptionalAction, + default=False, + ) + exact.add_argument( + "--hybrid-apc-enable-backed-prefix-reads", + action=argparse.BooleanOptionalAction, + default=False, + ) + exact.add_argument("--hybrid-apc-max-backed-prefix-read-len", type=int, default=0) + exact.add_argument("--enable-vllm-chunked-prefill", action="store_true") + exact.add_argument("--kernel-q-tile-size", type=int, default=128) + exact.add_argument("--kernel-kv-tile-size", type=int, default=1024) + exact.add_argument("--num-gpu-blocks-override", type=int) + exact.add_argument("--max-tokens", type=int, default=32) + exact.add_argument( + "--shared-prefix", + default="System: answer deterministically.\n" * 64, + ) + exact.add_argument("--suffix-a", default="\nUser: What is 17 * 23?\nAssistant:") + exact.add_argument("--suffix-b", default="\nUser: What is 19 * 29?\nAssistant:") + exact.add_argument( + "--require-real-tokens", + action="store_true", + help=( + "Fail exactness if every generated token for any checked request is a " + "configured dummy token." + ), + ) + exact.add_argument( + "--dummy-token-ids", + nargs="+", + type=int, + default=[0], + help="Token ids treated as dummy generated output when --require-real-tokens is set.", + ) + exact.add_argument("--output-json", type=Path) + + exact = subparsers.add_parser("exactness") + add_common_exact_args(exact) + exact.set_defaults(func=run_exactness) + + batched = subparsers.add_parser("batched-exactness") + add_common_exact_args(batched) + batched.add_argument("--shared-prefix-2", required=True) + batched.add_argument("--suffix-c", default="") + batched.add_argument("--suffix-d", default="\nUser: What is 23 * 31?\nAssistant:") + batched.set_defaults(func=run_batched_exactness) + + compact = subparsers.add_parser("compact-gate") + add_common_exact_args(compact) + compact.add_argument( + "--compact-boundary-lens", + nargs="+", + help=( + "Prefix token lengths to test. Defaults to block_size +/- 1 and " + "2*block_size +/- 1." + ), + ) + compact.add_argument( + "--compact-suffix-tokens", + type=int, + default=16, + help="Tokenized suffix length appended to partial-prefix prompts.", + ) + compact.add_argument( + "--compact-min-requests", + type=int, + default=50, + help="Minimum generated request count required for the compact gate.", + ) + compact.add_argument( + "--compact-min-grouped-speedup", + type=float, + default=1.5, + help=( + "Minimum warm grouped throughput speedup required for checkpoint " + "boundary cases." + ), + ) + compact.set_defaults(func=run_compact_gate) + + hbm = subparsers.add_parser("hbm") + hbm.add_argument("--context-lens", nargs="+", type=int, default=[131072, 262144]) + hbm.add_argument("--checkpoint-intervals", nargs="+", type=int, default=[128, 256, 512]) + hbm.add_argument("--gdn-recurrent-cache-dtype", default="float32") + hbm.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + hbm.set_defaults(func=run_hbm) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + return args.func(args) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_offline_decode_bench.py b/validation_scripts/qwen36_offline_decode_bench.py new file mode 100644 index 00000000..b3d8dd0d --- /dev/null +++ b/validation_scripts/qwen36_offline_decode_bench.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +"""Offline vLLM decode benchmark for Qwen3.6 Neuron artifacts. + +This intentionally bypasses the OpenAI HTTP server while keeping the same +vLLM/NxDI model runner and compiled artifact path. +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from pathlib import Path +from typing import Any + + +FP8_ENV_DEFAULTS = { + "XLA_HANDLE_SPECIAL_SCALAR": "1", + "UNSAFE_FP8FNCAST": "1", +} + + +def _parse_buckets(raw: str) -> list[int]: + return [int(item) for item in raw.replace(",", " ").split() if item] + + +def _parse_bucket_pairs(raw: str | None) -> list[list[int]] | None: + if not raw: + return None + pairs: set[tuple[int, int]] = set() + for token in raw.replace(",", " ").split(): + if ":" in token: + active, prefix = token.split(":", 1) + elif "x" in token: + active, prefix = token.split("x", 1) + else: + raise ValueError( + "context-encoding bucket pairs must use ACTIVE:PREFIX syntax, " + f"got {token!r}" + ) + active_tokens, prefix_tokens = int(active), int(prefix) + if active_tokens <= 0 or prefix_tokens < 0: + raise ValueError( + "context-encoding bucket pairs must use positive active tokens " + f"and non-negative prefix tokens, got {token!r}" + ) + pairs.add((active_tokens, prefix_tokens)) + return [[active, prefix] for active, prefix in sorted(pairs)] + + +def _validated_int_list( + values: list[int], + *, + name: str, + maximum: int | None = None, +) -> list[int]: + values = sorted(set(int(item) for item in values)) + if not values: + raise ValueError(f"{name} cannot be empty") + for value in values: + if value <= 0: + raise ValueError(f"{name} values must be positive, got {value}") + if maximum is not None and value > maximum: + raise ValueError(f"{name} value {value} exceeds {maximum}") + return values + + +def _artifact_neuron_config(compiled_artifacts: Path) -> dict[str, Any]: + config_path = compiled_artifacts / "neuron_config.json" + if not config_path.exists(): + return {} + with config_path.open(encoding="utf-8") as handle: + config = json.load(handle) + nested = config.get("neuron_config") + return nested if isinstance(nested, dict) else config + + +def _runtime_pa_override( + args: argparse.Namespace, + artifact_config: dict[str, Any], + *, + max_model_len: int, +) -> int: + """Return vLLM's user-intended block count, excluding its null block.""" + + block_size = int(args.block_size) + max_num_seqs = int(args.max_num_seqs or 1) + min_usable_blocks = ((max_model_len + block_size - 1) // block_size) * max_num_seqs + if args.pa_num_blocks is not None: + return max(1, int(args.pa_num_blocks)) + + artifact_blocks = int(artifact_config.get("pa_num_blocks") or 0) + if artifact_blocks <= 0: + return max(1, min_usable_blocks) + + uses_block_kv = bool( + artifact_config.get("is_block_kv_layout") + or artifact_config.get("is_prefix_caching") + ) + if uses_block_kv and artifact_blocks > min_usable_blocks: + return artifact_blocks - 1 + return artifact_blocks + + +def _resolve_config_defaults(args: argparse.Namespace) -> dict[str, Any]: + artifact_config = _artifact_neuron_config(args.compiled_artifacts) + seq_len = int( + args.seq_len + or artifact_config.get("seq_len") + or artifact_config.get("max_context_length") + or artifact_config.get("max_length") + or 131072 + ) + max_model_len = int(args.max_model_len or seq_len) + cte_buckets = ( + _parse_buckets(args.cte_buckets) + if args.cte_buckets + else [int(item) for item in artifact_config.get("context_encoding_buckets", [])] + ) + if not cte_buckets: + cte_buckets = [256, 512] + cte_buckets = _validated_int_list( + cte_buckets, + name="context encoding buckets", + maximum=seq_len, + ) + token_generation_buckets = ( + _parse_buckets(args.token_generation_buckets) + if args.token_generation_buckets + else [ + int(item) + for item in artifact_config.get("token_generation_buckets", []) + ] + ) + if not token_generation_buckets: + token_generation_buckets = [seq_len] + token_generation_buckets = _validated_int_list( + token_generation_buckets, + name="token generation buckets", + maximum=seq_len, + ) + token_generation_batches = ( + _parse_buckets(args.token_generation_batches) + if args.token_generation_batches + else None + ) + if token_generation_batches is not None: + token_generation_batches = _validated_int_list( + token_generation_batches, + name="token generation batches", + maximum=args.max_num_seqs, + ) + pa_num_blocks = _runtime_pa_override( + args, + artifact_config, + max_model_len=max_model_len, + ) + context_encoding_bucket_pairs = _parse_bucket_pairs( + args.context_encoding_bucket_pairs + ) + if context_encoding_bucket_pairs is None: + artifact_pairs = artifact_config.get("context_encoding_bucket_pairs") or [] + if artifact_pairs: + context_encoding_bucket_pairs = [ + [int(active), int(prefix)] + for active, prefix in artifact_pairs + ] + return { + "artifact_config": artifact_config, + "seq_len": seq_len, + "max_model_len": max_model_len, + "cte_buckets": cte_buckets, + "context_encoding_bucket_pairs": context_encoding_bucket_pairs, + "token_generation_buckets": token_generation_buckets, + "token_generation_batches": token_generation_batches, + "pa_num_blocks": pa_num_blocks, + } + + +def _ensure_paths(repo_root: Path) -> Path: + qwen_root = repo_root / "contrib" / "models" / "Qwen3.6-27B" + for path in (repo_root / "src", qwen_root / "vllm", qwen_root): + sys.path.insert(0, str(path)) + os.environ["PYTHONPATH"] = ( + f"{repo_root / 'src'}:{qwen_root / 'vllm'}:{qwen_root}:" + f"{os.environ.get('PYTHONPATH', '')}" + ) + return qwen_root + + +def _ensure_runtime_env(args: argparse.Namespace) -> None: + os.environ.setdefault("VLLM_NEURON_FRAMEWORK", "neuronx-distributed-inference") + os.environ.setdefault("VLLM_PLUGINS", "neuron") + os.environ.setdefault("QWEN36_HYBRID_APC_INSTALL_PATCH", "1") + os.environ.setdefault("DISABLE_NEURON_CUSTOM_SCHEDULER", "1") + for name, value in FP8_ENV_DEFAULTS.items(): + os.environ.setdefault(name, value) + os.environ["NEURON_COMPILED_ARTIFACTS"] = str(args.compiled_artifacts) + + +def _additional_config(args: argparse.Namespace) -> dict[str, Any]: + override_neuron_config = { + "tp_degree": args.tensor_parallel_size, + "batch_size": args.max_num_seqs, + "ctx_batch_size": args.ctx_batch_size, + "tkg_batch_size": args.max_num_seqs, + "seq_len": args.seq_len, + "max_length": args.seq_len, + "max_context_length": args.seq_len, + "context_encoding_buckets": args.resolved_cte_buckets, + "token_generation_buckets": args.resolved_token_generation_buckets, + "enable_bucketing": len(args.resolved_cte_buckets) > 1 + or len(args.resolved_token_generation_buckets) > 1, + "logical_nc_config": args.logical_nc_config, + "torch_dtype": "bfloat16", + "save_sharded_checkpoint": True, + "pa_block_size": args.block_size, + "pa_num_blocks": args.pa_num_blocks, + "gdn_checkpoint_interval": args.gdn_checkpoint_interval, + "max_gdn_checkpoint_slots": args.max_gdn_checkpoint_slots, + "gdn_recurrent_cache_dtype": args.gdn_recurrent_cache_dtype, + "gdn_conv_cache_dtype": args.gdn_conv_cache_dtype, + "hybrid_recurrent_cache_dtype": args.gdn_recurrent_cache_dtype, + "hybrid_conv_cache_dtype": args.gdn_conv_cache_dtype, + "hybrid_cache_mode": "all", + "is_block_kv_layout": True, + "is_prefix_caching": True, + "chunked_prefill_config": { + "max_num_seqs": args.max_num_seqs, + "tkg_model_enabled": True, + "kernel_q_tile_size": args.kernel_q_tile_size, + "kernel_kv_tile_size": args.kernel_kv_tile_size, + }, + } + if args.async_mode: + override_neuron_config["async_mode"] = True + if args.resolved_context_encoding_bucket_pairs is not None: + override_neuron_config["context_encoding_bucket_pairs"] = ( + args.resolved_context_encoding_bucket_pairs + ) + if args.resolved_token_generation_batches is not None: + override_neuron_config["token_generation_batches"] = ( + args.resolved_token_generation_batches + ) + + return { + "max_prompt_length": args.seq_len, + "use_hybrid_apc_manager": True, + "use_text_only_cte_inputs": True, + "use_compact_cte_attention_mask": True, + "use_cold_zero_conv_fast_path": False, + "gdn_checkpoint_interval": args.gdn_checkpoint_interval, + "max_gdn_checkpoint_slots": args.max_gdn_checkpoint_slots, + "gdn_recurrent_cache_dtype": args.gdn_recurrent_cache_dtype, + "gdn_conv_cache_dtype": args.gdn_conv_cache_dtype, + "hybrid_recurrent_cache_dtype": args.gdn_recurrent_cache_dtype, + "hybrid_conv_cache_dtype": args.gdn_conv_cache_dtype, + "hybrid_cache_mode": "all", + "hybrid_cache_prefix_boundary_only": True, + "hybrid_cache_block_boundary_only": True, + "hybrid_cache_validate_exact": False, + "hybrid_apc_require_vllm_metadata": True, + "hybrid_apc_allow_local_hash_fallback": False, + "hybrid_apc_require_attention_block_refs": True, + "hybrid_apc_disable_unbacked_prefix_reads": False, + "hybrid_apc_enable_backed_prefix_reads": True, + "use_qwen_hybrid_chunked_prefill": True, + "use_qwen_hybrid_chunked_prefill_nki": True, + "override_neuron_config": override_neuron_config, + } + + +def _build_llm(args: argparse.Namespace): + from hf_qwen35_config import register_qwen35_config # noqa: WPS433 + from qwen36_hybrid_apc_scheduler_patch import ( # noqa: WPS433 + install_import_hook as install_hybrid_apc_scheduler_patch, + ) + + register_qwen35_config() + install_hybrid_apc_scheduler_patch() + + from vllm import LLM, SamplingParams # noqa: WPS433 + + recurrent_cache_dtype = str(args.gdn_recurrent_cache_dtype).lower() + if recurrent_cache_dtype in {"bfloat16", "bf16"}: + recurrent_cache_dtype = "auto" + + llm = LLM( + model=str(args.model_path), + trust_remote_code=True, + dtype="bfloat16", + tensor_parallel_size=args.tensor_parallel_size, + max_model_len=args.max_model_len, + enable_prefix_caching=True, + enable_chunked_prefill=True, + additional_config=_additional_config(args), + block_size=args.block_size, + num_gpu_blocks_override=args.pa_num_blocks, + mamba_cache_mode="all", + mamba_ssm_cache_dtype=recurrent_cache_dtype, + max_num_batched_tokens=max(args.resolved_cte_buckets), + max_num_seqs=args.max_num_seqs, + ) + sampling = SamplingParams( + temperature=0.0, + top_k=1, + max_tokens=args.max_tokens, + ) + warmup_sampling = SamplingParams( + temperature=0.0, + top_k=1, + max_tokens=args.warmup_tokens, + ) + return llm, sampling, warmup_sampling + + +def _run_generate(llm: Any, prompt: str, sampling: Any) -> dict[str, Any]: + start = time.perf_counter() + outputs = llm.generate([prompt], sampling) + elapsed = time.perf_counter() - start + output = outputs[0].outputs[0] + token_ids = list(output.token_ids) + return { + "elapsed_s": elapsed, + "completion_tokens": len(token_ids), + "tok_s": (len(token_ids) / elapsed) if elapsed > 0 else None, + "text": output.text, + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--repo-root", type=Path, default=Path.cwd()) + parser.add_argument("--model-path", type=Path, required=True) + parser.add_argument("--compiled-artifacts", type=Path, required=True) + parser.add_argument("--output-json", type=Path) + parser.add_argument("--prompt", default="Explain software benchmarking in two concise paragraphs.") + parser.add_argument("--max-tokens", type=int, default=64) + parser.add_argument("--warmup-tokens", type=int, default=8) + parser.add_argument("--repeats", type=int, default=3) + parser.add_argument("--max-model-len", type=int) + parser.add_argument("--seq-len", type=int) + parser.add_argument("--cte-buckets") + parser.add_argument("--context-encoding-bucket-pairs") + parser.add_argument("--token-generation-buckets") + parser.add_argument("--token-generation-batches") + parser.add_argument("--async-mode", action="store_true") + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--max-num-seqs", type=int, default=1) + parser.add_argument("--ctx-batch-size", type=int, default=1) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--pa-num-blocks", type=int) + parser.add_argument("--gdn-checkpoint-interval", type=int, default=256) + parser.add_argument("--max-gdn-checkpoint-slots", type=int, default=64) + parser.add_argument("--gdn-recurrent-cache-dtype", default="float32") + parser.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + parser.add_argument("--kernel-q-tile-size", type=int, default=128) + parser.add_argument("--kernel-kv-tile-size", type=int, default=1024) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + args.repo_root = args.repo_root.expanduser().resolve() + args.model_path = args.model_path.expanduser().resolve() + args.compiled_artifacts = args.compiled_artifacts.expanduser().resolve() + resolved = _resolve_config_defaults(args) + args.artifact_config = resolved["artifact_config"] + args.seq_len = resolved["seq_len"] + args.max_model_len = resolved["max_model_len"] + args.resolved_cte_buckets = resolved["cte_buckets"] + args.resolved_context_encoding_bucket_pairs = resolved[ + "context_encoding_bucket_pairs" + ] + args.resolved_token_generation_buckets = resolved["token_generation_buckets"] + args.resolved_token_generation_batches = resolved["token_generation_batches"] + args.pa_num_blocks = resolved["pa_num_blocks"] + _ensure_paths(args.repo_root) + _ensure_runtime_env(args) + + llm = None + report: dict[str, Any] = { + "artifact": str(args.compiled_artifacts), + "model_path": str(args.model_path), + "prompt": args.prompt, + "max_tokens": args.max_tokens, + "max_num_seqs": args.max_num_seqs, + "pa_num_blocks": args.pa_num_blocks, + "cte_buckets": args.resolved_cte_buckets, + "token_generation_buckets": args.resolved_token_generation_buckets, + "context_encoding_bucket_pairs": args.resolved_context_encoding_bucket_pairs, + "token_generation_batches": args.resolved_token_generation_batches, + "async_mode": args.async_mode, + "max_model_len": args.max_model_len, + "seq_len": args.seq_len, + "artifact_neuron_config": { + key: args.artifact_config.get(key) + for key in ( + "seq_len", + "max_length", + "max_context_length", + "context_encoding_buckets", + "prefix_buckets", + "token_generation_buckets", + "tkg_batch_size", + "ctx_batch_size", + "pa_block_size", + "pa_num_blocks", + "output_logits", + "on_device_sampling_config", + ) + }, + } + try: + llm, sampling, warmup_sampling = _build_llm(args) + report["warmup"] = _run_generate(llm, args.prompt, warmup_sampling) + rows = [] + for index in range(args.repeats): + row = _run_generate(llm, args.prompt, sampling) + row["run"] = index + 1 + rows.append(row) + print(json.dumps(row, sort_keys=True), flush=True) + report["runs"] = rows + report["avg_tok_s"] = sum(float(row["tok_s"]) for row in rows) / len(rows) + report["avg_elapsed_s"] = ( + sum(float(row["elapsed_s"]) for row in rows) / len(rows) + ) + print(json.dumps(report, indent=2, sort_keys=True), flush=True) + if args.output_json: + args.output_json.expanduser().write_text( + json.dumps(report, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + finally: + if llm is not None: + shutdown = getattr(llm, "shutdown", None) + if shutdown is not None: + shutdown() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_openai_boundary_apc_probe.py b/validation_scripts/qwen36_openai_boundary_apc_probe.py new file mode 100644 index 00000000..b4787a73 --- /dev/null +++ b/validation_scripts/qwen36_openai_boundary_apc_probe.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 +"""Boundary-aligned APC probe for a Qwen3.6 OpenAI-compatible server. + +This probe intentionally uses raw ``/v1/completions`` token-id prompts so the +prompt length is exactly the checkpoint boundary under test. If running behind +``qwen36_chat_proxy.py``, start that proxy with ``--allow-completions`` or send +this probe directly to the vLLM server. +""" + +from __future__ import annotations + +import argparse +import json +import time +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any + + +def _parse_lengths(raw: str) -> list[int]: + lengths = [int(item) for item in raw.replace(",", " ").split()] + if not lengths: + raise ValueError("at least one boundary length is required") + return lengths + + +def _load_json_from_url(url: str, timeout: float) -> tuple[int, dict[str, Any]]: + request = urllib.request.Request(url, method="GET") + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + return response.status, json.loads(response.read().decode("utf-8")) + except urllib.error.HTTPError as exc: + try: + payload = json.loads(exc.read().decode("utf-8")) + except Exception: + payload = {"error": {"message": str(exc)}} + return exc.code, payload + + +def _detect_model(base_url: str, fallback: str, timeout: float) -> str: + status, payload = _load_json_from_url(base_url.rstrip("/") + "/v1/models", timeout) + if status < 400: + data = payload.get("data") if isinstance(payload, dict) else None + if isinstance(data, list) and data: + model_id = data[0].get("id") + if isinstance(model_id, str) and model_id: + return model_id + return fallback + + +def _metric_snapshot(base_url: str, timeout: float) -> dict[str, float]: + try: + data = urllib.request.urlopen( + base_url.rstrip("/") + "/metrics", + timeout=timeout, + ).read().decode("utf-8") + except (OSError, TimeoutError, urllib.error.HTTPError, urllib.error.URLError): + return {} + + wanted: dict[str, float] = {} + for line in data.splitlines(): + try: + value = float(line.rsplit(" ", 1)[1]) + except (IndexError, ValueError): + continue + if line.startswith("vllm:prefix_cache_queries_total"): + wanted["prefix_cache_queries_total"] = value + elif line.startswith("vllm:prefix_cache_hits_total"): + wanted["prefix_cache_hits_total"] = value + elif line.startswith("vllm:prompt_tokens_cached_total"): + wanted["prompt_tokens_cached_total"] = value + elif ( + line.startswith("vllm:prompt_tokens_by_source_total") + and 'source="local_compute"' in line + ): + wanted["local_compute"] = value + elif ( + line.startswith("vllm:prompt_tokens_by_source_total") + and 'source="local_cache_hit"' in line + ): + wanted["local_cache_hit"] = value + return wanted + + +def _exact_token_ids(tokenizer: Any, length: int, salt: str) -> list[int]: + filler = tokenizer.encode( + " boundary aligned hybrid apc checkpoint validation", + add_special_tokens=False, + ) + if not filler: + raise RuntimeError("tokenizer produced no filler tokens") + token_ids = tokenizer.encode(f"Boundary APC probe {salt}. ", add_special_tokens=False) + while len(token_ids) < length: + token_ids.extend(filler) + return token_ids[:length] + + +def _post_completion( + *, + endpoint: str, + model: str, + prompt_token_ids: list[int], + max_tokens: int, + timeout: float, +) -> dict[str, Any]: + payload = { + "model": model, + "prompt": prompt_token_ids, + "max_tokens": max_tokens, + "temperature": 0, + "stream": False, + } + request = urllib.request.Request( + endpoint, + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + start = time.perf_counter() + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + elapsed = time.perf_counter() - start + body = json.loads(response.read().decode("utf-8")) + status = response.status + error = None + except urllib.error.HTTPError as exc: + elapsed = time.perf_counter() - start + status = exc.code + try: + body = json.loads(exc.read().decode("utf-8")) + except Exception: + body = {"error": {"message": str(exc)}} + error = body + + choices = body.get("choices") if isinstance(body, dict) else None + choice = choices[0] if isinstance(choices, list) and choices else {} + usage = body.get("usage") if isinstance(body, dict) else None + return { + "status": status, + "elapsed_seconds": elapsed, + "text": choice.get("text") if isinstance(choice, dict) else None, + "usage": usage, + "valid_openai_body": isinstance(choices, list) and bool(choices) and usage is not None, + "error": error, + } + + +def _metric_delta(before: dict[str, float], after: dict[str, float]) -> dict[str, float]: + return { + key: after.get(key, 0.0) - before.get(key, 0.0) + for key in sorted(set(before) | set(after)) + } + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--base-url", default="http://127.0.0.1:8000") + parser.add_argument("--model", default="auto") + parser.add_argument("--model-path", required=True) + parser.add_argument("--lengths", default="256,512,1024,2048,4096") + parser.add_argument("--repeats", type=int, default=3) + parser.add_argument("--max-tokens", type=int, default=1) + parser.add_argument("--timeout", type=float, default=300.0) + parser.add_argument("--output-jsonl", required=True) + parser.add_argument( + "--require-prefix-cache-query", + action="store_true", + help="Return non-zero if repeated boundary prompts never query prefix cache.", + ) + parser.add_argument( + "--require-prefix-cache-hit", + action="store_true", + help="Return non-zero if repeated boundary prompts never hit prefix cache.", + ) + args = parser.parse_args() + + from transformers import AutoTokenizer # noqa: WPS433 + + base_url = args.base_url.rstrip("/") + model = ( + _detect_model(base_url, "Qwen3.6-27B", args.timeout) + if args.model == "auto" + else args.model + ) + endpoint = base_url + "/v1/completions" + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + output_path = Path(args.output_jsonl) + output_path.parent.mkdir(parents=True, exist_ok=True) + + rows: list[dict[str, Any]] = [] + before_all = _metric_snapshot(base_url, args.timeout) + with output_path.open("w", encoding="utf-8") as handle: + header = { + "phase": "before_metrics", + "base_url": base_url, + "model": model, + "metrics": before_all, + } + print(json.dumps(header, sort_keys=True), flush=True) + handle.write(json.dumps(header, sort_keys=True) + "\n") + + for length in _parse_lengths(args.lengths): + prompt_ids = _exact_token_ids(tokenizer, length, f"len-{length}") + for repeat in range(args.repeats): + metrics_before = _metric_snapshot(base_url, args.timeout) + result = _post_completion( + endpoint=endpoint, + model=model, + prompt_token_ids=prompt_ids, + max_tokens=args.max_tokens, + timeout=args.timeout, + ) + metrics_after = _metric_snapshot(base_url, args.timeout) + elapsed = float(result["elapsed_seconds"]) + row = { + "phase": "case", + "label": f"boundary_{length}_repeat_{repeat}", + "length": length, + "repeat": repeat, + "prompt_tokens": length, + **result, + "effective_prompt_tokens_per_second": ( + length / elapsed if elapsed > 0 and int(result["status"]) < 400 else None + ), + "metrics_before": metrics_before, + "metrics_after": metrics_after, + "metric_delta": _metric_delta(metrics_before, metrics_after), + } + rows.append(row) + print(json.dumps(row, sort_keys=True), flush=True) + handle.write(json.dumps(row, sort_keys=True) + "\n") + handle.flush() + + after_all = _metric_snapshot(base_url, args.timeout) + total_delta = _metric_delta(before_all, after_all) + repeated_rows = [row for row in rows if int(row["repeat"]) > 0] + summary = { + "phase": "summary", + "all_status_ok": all(int(row["status"]) < 400 for row in rows), + "all_valid_openai_body": all(bool(row["valid_openai_body"]) for row in rows), + "total_rows": len(rows), + "total_metric_delta": total_delta, + "repeated_prefix_cache_query_delta": sum( + row["metric_delta"].get("prefix_cache_queries_total", 0.0) + for row in repeated_rows + ), + "repeated_prefix_cache_hit_delta": sum( + row["metric_delta"].get("prefix_cache_hits_total", 0.0) + for row in repeated_rows + ), + "output_jsonl": str(output_path), + } + print(json.dumps(summary, sort_keys=True), flush=True) + handle.write(json.dumps(summary, sort_keys=True) + "\n") + + failed = not summary["all_status_ok"] or not summary["all_valid_openai_body"] + if args.require_prefix_cache_query and summary["repeated_prefix_cache_query_delta"] <= 0: + failed = True + if args.require_prefix_cache_hit and summary["repeated_prefix_cache_hit_delta"] <= 0: + failed = True + return 1 if failed else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_openai_chat_apc_validation.py b/validation_scripts/qwen36_openai_chat_apc_validation.py new file mode 100644 index 00000000..594401dd --- /dev/null +++ b/validation_scripts/qwen36_openai_chat_apc_validation.py @@ -0,0 +1,641 @@ +#!/usr/bin/env python3 +"""OpenAI-compatible chat validation for Qwen3.6 Hybrid APC serving.""" + +from __future__ import annotations + +import argparse +import json +import statistics +import time +import urllib.error +import urllib.request +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + + +def _parse_lengths(raw: str) -> list[int]: + lengths = [int(item) for item in raw.replace(",", " ").split()] + if not lengths: + raise ValueError("at least one length is required") + return lengths + + +def _chat_token_count(tokenizer: Any, messages: list[dict[str, str]]) -> int: + try: + token_ids = tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + ) + except TypeError: + token_ids = tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + ) + return len(token_ids) + + +def _build_messages( + *, + shared_key: str, + salt: str, + filler_repeats: int, + suffix: str, + turns: int, +) -> list[dict[str, str]]: + filler = ( + " hybrid apc prefix checkpoint recurrent delta state validation " + "attention blocks restore suffix deterministic " + ) + messages = [ + { + "role": "system", + "content": ( + "You are a deterministic latency validation assistant. " + f"shared-key={shared_key}; salt={salt}. " + "Return exactly one short answer token." + ), + } + ] + for idx in range(max(1, turns - 1)): + messages.append( + { + "role": "user", + "content": ( + f"Shared setup turn {idx}. " + "Remember the validation marker and answer tersely." + ), + } + ) + messages.append({"role": "assistant", "content": f"ack-{idx}"}) + messages.append( + { + "role": "user", + "content": ( + "Shared benchmark document begins. " + + (filler * max(0, filler_repeats)) + + f" Shared benchmark document ends. {suffix}" + ), + } + ) + return messages + + +def _make_messages( + tokenizer: Any, + target_tokens: int, + *, + shared_key: str, + salt: str, + suffix: str, + turns: int, +) -> tuple[list[dict[str, str]], int]: + def build(repeats: int) -> list[dict[str, str]]: + return _build_messages( + shared_key=shared_key, + salt=salt, + filler_repeats=repeats, + suffix=suffix, + turns=turns, + ) + + low = 0 + high = 1 + while _chat_token_count(tokenizer, build(high)) <= target_tokens: + low = high + high *= 2 + + while low + 1 < high: + mid = (low + high) // 2 + if _chat_token_count(tokenizer, build(mid)) <= target_tokens: + low = mid + else: + high = mid + + messages = build(low) + return messages, _chat_token_count(tokenizer, messages) + + +def _load_json_from_url(url: str, timeout: float) -> tuple[int, dict[str, Any]]: + request = urllib.request.Request(url, method="GET") + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + return response.status, json.loads(response.read().decode("utf-8")) + except urllib.error.HTTPError as exc: + try: + payload = json.loads(exc.read().decode("utf-8")) + except Exception: + payload = {"error": {"message": str(exc)}} + return exc.code, payload + + +def _detect_model(base_url: str, fallback: str, timeout: float) -> str: + status, payload = _load_json_from_url(base_url.rstrip("/") + "/v1/models", timeout) + if status < 400: + data = payload.get("data") if isinstance(payload, dict) else None + if isinstance(data, list) and data: + model_id = data[0].get("id") + if isinstance(model_id, str) and model_id: + return model_id + return fallback + + +def _post_chat( + *, + endpoint: str, + model: str, + messages: list[dict[str, str]], + max_tokens: int, + timeout: float, + enable_thinking: bool = False, +) -> dict[str, Any]: + payload = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0, + "stream": False, + "chat_template_kwargs": {"enable_thinking": enable_thinking}, + } + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + endpoint, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + start = time.perf_counter() + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + elapsed = time.perf_counter() - start + response_payload = json.loads(response.read().decode("utf-8")) + choice = (response_payload.get("choices") or [{}])[0] + message = choice.get("message") or {} + content = message.get("content") + return { + "status": response.status, + "elapsed_seconds": elapsed, + "content": "" if content is None else str(content), + "finish_reason": choice.get("finish_reason"), + "usage": response_payload.get("usage"), + "error": None, + } + except urllib.error.HTTPError as exc: + elapsed = time.perf_counter() - start + try: + error_payload = json.loads(exc.read().decode("utf-8")) + except Exception: + error_payload = {"error": {"message": str(exc)}} + return { + "status": exc.code, + "elapsed_seconds": elapsed, + "content": "", + "finish_reason": None, + "usage": None, + "error": error_payload, + } + + +def _run_case( + *, + endpoint: str, + model: str, + label: str, + messages: list[dict[str, str]], + prompt_tokens: int, + max_tokens: int, + timeout: float, + enable_thinking: bool = False, +) -> dict[str, Any]: + result = _post_chat( + endpoint=endpoint, + model=model, + messages=messages, + max_tokens=max_tokens, + timeout=timeout, + enable_thinking=enable_thinking, + ) + row = { + "label": label, + "prompt_tokens": prompt_tokens, + **result, + } + elapsed = float(row["elapsed_seconds"]) + row["effective_prompt_tokens_per_second"] = ( + prompt_tokens / elapsed if elapsed > 0 and int(row["status"]) < 400 else None + ) + print(json.dumps(row, sort_keys=True), flush=True) + return row + + +def _semantic_smoke_cases() -> list[dict[str, Any]]: + return [ + { + "label": "semantic_arithmetic", + "messages": [ + { + "role": "system", + "content": "You answer with only the requested value.", + }, + { + "role": "user", + "content": "What is 17 * 23? Answer with digits only.", + }, + ], + "contains": "391", + }, + { + "label": "semantic_marker_copy", + "messages": [ + { + "role": "system", + "content": "You answer with only the requested marker.", + }, + { + "role": "user", + "content": "Return exactly this marker: BASELINE_OK_27B", + }, + ], + "contains": "BASELINE_OK_27B", + }, + { + "label": "semantic_multi_turn_recall", + "messages": [ + { + "role": "system", + "content": "You answer with only the requested value.", + }, + { + "role": "user", + "content": "Remember validation code ZX-417.", + }, + {"role": "assistant", "content": "Remembered."}, + { + "role": "user", + "content": "What validation code did I ask you to remember?", + }, + ], + "contains": "ZX-417", + }, + ] + + +def _avg(values: list[float]) -> float | None: + return statistics.fmean(values) if values else None + + +def _successful_elapsed(rows: list[dict[str, Any]]) -> list[float]: + return [ + float(row["elapsed_seconds"]) + for row in rows + if int(row.get("status", 500)) < 400 + ] + + +def _speedup_passes(value: float | None, minimum: float) -> bool: + return minimum <= 0 or (value is not None and value >= minimum) + + +def _apc_gate_failures(summary: dict[str, Any]) -> list[str]: + checks = { + "all_status_ok": bool(summary["all_status_ok"]), + "warm_full_exact_text": bool(summary["warm_full_exact_text"]), + "partial_repeat_exact_text": bool(summary["partial_repeat_exact_text"]), + "multi_turn_repeat_exact_text": bool(summary["multi_turn_repeat_exact_text"]), + "semantic_smoke_passed": bool(summary["semantic_smoke_passed"]), + "warm_full_speedup_passed": bool(summary["warm_full_speedup_passed"]), + "partial_reference_speedup_passed": bool( + summary["partial_reference_speedup_passed"] + ), + } + return [name for name, passed in checks.items() if not passed] + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--base-url", default="http://127.0.0.1:8000") + parser.add_argument("--model", default="auto") + parser.add_argument("--model-path", required=True) + parser.add_argument("--cold-lengths", default="256,512,1024,1536,1984") + parser.add_argument("--target-tokens", type=int, default=1984) + parser.add_argument("--turns", type=int, default=5) + parser.add_argument("--max-tokens", type=int, default=1) + parser.add_argument("--semantic-max-tokens", type=int, default=16) + parser.add_argument("--cold-repeats", type=int, default=2) + parser.add_argument("--warm-repeats", type=int, default=3) + parser.add_argument("--mixed-repeats", type=int, default=3) + parser.add_argument("--timeout", type=float, default=180.0) + parser.add_argument("--output-json", required=True) + parser.add_argument( + "--min-warm-full-speedup", + type=float, + default=1.5, + help=( + "Minimum warm-full repeat speedup over the initial request. " + "Set to 0 to disable this speed gate." + ), + ) + parser.add_argument( + "--min-partial-speedup", + type=float, + default=1.2, + help=( + "Minimum partial-prefix warm speedup over its cold reference. " + "Set to 0 to disable this speed gate." + ), + ) + args = parser.parse_args() + + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + base_url = args.base_url.rstrip("/") + model = ( + _detect_model(base_url, "Qwen3.6-27B", args.timeout) + if args.model == "auto" + else args.model + ) + endpoint = base_url + "/v1/chat/completions" + stamp = int(time.time()) + + rows: list[dict[str, Any]] = [] + cold_rows: list[dict[str, Any]] = [] + for length in _parse_lengths(args.cold_lengths): + for repeat in range(args.cold_repeats): + messages, prompt_tokens = _make_messages( + tokenizer, + length, + shared_key=f"cold-{length}-{repeat}-{stamp}", + salt=f"cold-early-{length}-{repeat}-{stamp}", + suffix="Answer with the word cold.", + turns=args.turns, + ) + row = _run_case( + endpoint=endpoint, + model=model, + label=f"cold_len_{length}_repeat_{repeat}", + messages=messages, + prompt_tokens=prompt_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ) + rows.append(row) + cold_rows.append(row) + + warm_messages, warm_prompt_tokens = _make_messages( + tokenizer, + args.target_tokens, + shared_key=f"warm-full-{stamp}", + salt=f"warm-full-{stamp}", + suffix="Answer with the word warm.", + turns=args.turns, + ) + warm_full_rows = [ + _run_case( + endpoint=endpoint, + model=model, + label="warm_full_initial", + messages=warm_messages, + prompt_tokens=warm_prompt_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ) + ] + for repeat in range(args.warm_repeats): + warm_full_rows.append( + _run_case( + endpoint=endpoint, + model=model, + label=f"warm_full_repeat_{repeat}", + messages=warm_messages, + prompt_tokens=warm_prompt_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ) + ) + rows.extend(warm_full_rows) + + partial_key = f"partial-shared-{stamp}" + partial_warmup_messages, partial_warmup_tokens = _make_messages( + tokenizer, + args.target_tokens, + shared_key=partial_key, + salt=partial_key, + suffix="Suffix alpha. Answer with the word alpha.", + turns=args.turns, + ) + partial_target_messages, partial_target_tokens = _make_messages( + tokenizer, + args.target_tokens, + shared_key=partial_key, + salt=partial_key, + suffix="Suffix beta. Answer with the word beta.", + turns=args.turns, + ) + partial_cold_messages, partial_cold_tokens = _make_messages( + tokenizer, + args.target_tokens, + shared_key=f"partial-cold-reference-{stamp}", + salt=f"partial-cold-reference-{stamp}", + suffix="Suffix beta. Answer with the word beta.", + turns=args.turns, + ) + partial_rows = [ + _run_case( + endpoint=endpoint, + model=model, + label="partial_cold_reference", + messages=partial_cold_messages, + prompt_tokens=partial_cold_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ), + _run_case( + endpoint=endpoint, + model=model, + label="partial_warmup_alpha", + messages=partial_warmup_messages, + prompt_tokens=partial_warmup_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ), + ] + for repeat in range(args.warm_repeats): + partial_rows.append( + _run_case( + endpoint=endpoint, + model=model, + label=f"partial_warm_beta_repeat_{repeat}", + messages=partial_target_messages, + prompt_tokens=partial_target_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ) + ) + rows.extend(partial_rows) + + mixed_rows: list[dict[str, Any]] = [] + for repeat in range(args.mixed_repeats): + mixed_warm_messages, mixed_warm_tokens = _make_messages( + tokenizer, + args.target_tokens, + shared_key=partial_key, + salt=partial_key, + suffix=f"Suffix mixed warm {repeat}. Answer with the word beta.", + turns=args.turns, + ) + mixed_cold_messages, mixed_cold_tokens = _make_messages( + tokenizer, + args.target_tokens, + shared_key=f"mixed-cold-{repeat}-{stamp}", + salt=f"mixed-cold-{repeat}-{stamp}", + suffix=f"Suffix mixed cold {repeat}. Answer with the word cold.", + turns=args.turns, + ) + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [ + executor.submit( + _run_case, + endpoint=endpoint, + model=model, + label=f"mixed_warm_repeat_{repeat}", + messages=mixed_warm_messages, + prompt_tokens=mixed_warm_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ), + executor.submit( + _run_case, + endpoint=endpoint, + model=model, + label=f"mixed_cold_repeat_{repeat}", + messages=mixed_cold_messages, + prompt_tokens=mixed_cold_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ), + ] + for future in futures: + mixed_rows.append(future.result()) + rows.extend(mixed_rows) + + multi_messages, multi_prompt_tokens = _make_messages( + tokenizer, + min(args.target_tokens, 1536), + shared_key=f"multi-turn-{stamp}", + salt=f"multi-turn-{stamp}", + suffix="Answer with the word multi.", + turns=max(args.turns, 8), + ) + multi_rows = [] + for repeat in range(args.warm_repeats): + multi_rows.append( + _run_case( + endpoint=endpoint, + model=model, + label=f"multi_turn_repeat_{repeat}", + messages=multi_messages, + prompt_tokens=multi_prompt_tokens, + max_tokens=args.max_tokens, + timeout=args.timeout, + ) + ) + rows.extend(multi_rows) + + semantic_rows = [] + for case in _semantic_smoke_cases(): + prompt_tokens = _chat_token_count(tokenizer, case["messages"]) + row = _run_case( + endpoint=endpoint, + model=model, + label=case["label"], + messages=case["messages"], + prompt_tokens=prompt_tokens, + max_tokens=args.semantic_max_tokens, + timeout=args.timeout, + ) + row["expected_contains"] = case["contains"] + row["semantic_passed"] = case["contains"] in row.get("content", "") + semantic_rows.append(row) + rows.extend(semantic_rows) + + warm_initial = warm_full_rows[0] + warm_repeats = warm_full_rows[1:] + partial_warm = [ + row for row in partial_rows if row["label"].startswith("partial_warm_beta") + ] + mixed_warm = [row for row in mixed_rows if row["label"].startswith("mixed_warm")] + mixed_cold = [row for row in mixed_rows if row["label"].startswith("mixed_cold")] + all_ok = all(int(row.get("status", 500)) < 400 for row in rows) + warm_full_exact = len({row["content"] for row in warm_full_rows}) == 1 + partial_repeat_exact = bool(partial_warm) and len( + {row["content"] for row in partial_warm} + ) == 1 + multi_repeat_exact = bool(multi_rows) and len( + {row["content"] for row in multi_rows} + ) == 1 + semantic_passed = all(bool(row.get("semantic_passed")) for row in semantic_rows) + + warm_initial_elapsed = float(warm_initial["elapsed_seconds"]) + warm_repeat_avg = _avg(_successful_elapsed(warm_repeats)) + partial_cold_elapsed = float(partial_rows[0]["elapsed_seconds"]) + partial_warm_avg = _avg(_successful_elapsed(partial_warm)) + warm_full_speedup = ( + warm_initial_elapsed / warm_repeat_avg + if warm_repeat_avg and warm_repeat_avg > 0 + else None + ) + partial_reference_speedup = ( + partial_cold_elapsed / partial_warm_avg + if partial_warm_avg and partial_warm_avg > 0 + else None + ) + summary = { + "all_status_ok": all_ok, + "base_url": base_url, + "model": model, + "cold_request_count": len(cold_rows), + "warm_full_exact_text": warm_full_exact, + "partial_repeat_exact_text": partial_repeat_exact, + "multi_turn_repeat_exact_text": multi_repeat_exact, + "semantic_smoke_passed": semantic_passed, + "warm_full_initial_seconds": warm_initial_elapsed, + "warm_full_repeat_avg_seconds": warm_repeat_avg, + "warm_full_speedup": warm_full_speedup, + "min_warm_full_speedup": args.min_warm_full_speedup, + "warm_full_speedup_passed": _speedup_passes( + warm_full_speedup, + args.min_warm_full_speedup, + ), + "partial_cold_reference_seconds": partial_cold_elapsed, + "partial_warm_beta_avg_seconds": partial_warm_avg, + "partial_reference_speedup": partial_reference_speedup, + "min_partial_speedup": args.min_partial_speedup, + "partial_reference_speedup_passed": _speedup_passes( + partial_reference_speedup, + args.min_partial_speedup, + ), + "mixed_warm_avg_seconds": _avg(_successful_elapsed(mixed_warm)), + "mixed_cold_avg_seconds": _avg(_successful_elapsed(mixed_cold)), + "multi_turn_avg_seconds": _avg(_successful_elapsed(multi_rows)), + } + failures = _apc_gate_failures(summary) + summary["apc_validation_passed"] = not failures + summary["apc_gate_failures"] = failures + output = { + "summary": summary, + "results": rows, + } + output_path = Path(args.output_json) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump(output, f, indent=2, sort_keys=True) + print(json.dumps({"summary": summary, "output_json": str(output_path)}, sort_keys=True)) + return 0 if summary["apc_validation_passed"] else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_split_qkv_tkg_probe.py b/validation_scripts/qwen36_split_qkv_tkg_probe.py new file mode 100644 index 00000000..205c1ab6 --- /dev/null +++ b/validation_scripts/qwen36_split_qkv_tkg_probe.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +"""Minimal Qwen3.6 split-QKV TKG kernel probe. + +Runs one preprod qkv_tkg projection shape at a time so runtime OOBs can be +localized to Q, K, or V without compiling the full model. +""" + +from __future__ import annotations + +import argparse +import json +import os +from pathlib import Path + +os.environ.setdefault("XLA_HANDLE_SPECIAL_SCALAR", "1") +os.environ.setdefault("UNSAFE_FP8FNCAST", "1") + +import torch + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--projection", choices=("q", "k", "v"), required=True) + parser.add_argument("--weight-dtype", choices=("bf16", "fp8"), default="bf16") + parser.add_argument("--lnc", type=int, default=1) + parser.add_argument("--hidden-size", type=int, default=5120) + parser.add_argument("--head-dim", type=int, default=256) + parser.add_argument("--tp-degree", type=int, default=4) + parser.add_argument("--num-attention-heads", type=int, default=24) + parser.add_argument("--num-key-value-heads", type=int, default=4) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--inspect-dir", default=None) + parser.add_argument("--cpu-backend-hlo", action="store_true") + return parser.parse_args() + + +def _local_heads(args: argparse.Namespace) -> int: + if args.projection == "q": + return args.num_attention_heads // args.tp_degree + return args.num_key_value_heads // args.tp_degree + + +def _make_weight( + hidden_size: int, + output_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + weight = torch.randn(hidden_size, output_size, dtype=torch.bfloat16) * 0.01 + if dtype is torch.float8_e4m3fn: + return weight.to(torch.float8_e4m3fn) + return weight.to(dtype) + + +def main() -> int: + args = _parse_args() + if args.num_attention_heads % args.tp_degree != 0: + raise ValueError("num_attention_heads must be divisible by tp_degree") + if args.num_key_value_heads % args.tp_degree != 0: + raise ValueError("num_key_value_heads must be divisible by tp_degree") + + neuron_cc_flags = f"--target trn2 --lnc {args.lnc}" + if args.weight_dtype == "fp8": + neuron_cc_flags += ( + " --internal-hlo2tensorizer-options=' " + "--experimental-unsafe-fp8e4m3fn-as-fp8e4m3 --verify-hlo=true'" + ) + os.environ.setdefault("NEURON_CC_FLAGS", neuron_cc_flags) + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", "trn2") + os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-1" if args.lnc > 1 else "0") + os.environ.setdefault("NEURON_RT_ENABLE_DGE_NOTIFICATIONS", "1") + os.environ.setdefault("NEURON_FRAMEWORK_DEBUG", "1") + os.environ.setdefault("XLA_IR_DEBUG", "1") + os.environ.setdefault("XLA_HLO_DEBUG", "1") + if args.inspect_dir: + inspect_dir = Path(args.inspect_dir).expanduser().resolve() + inspect_dir.mkdir(parents=True, exist_ok=True) + os.environ["NEURON_RT_INSPECT_ENABLE"] = "1" + os.environ["NEURON_RT_INSPECT_DEVICE_PROFILE"] = "1" + os.environ["NEURON_RT_INSPECT_SYSTEM_PROFILE"] = "0" + os.environ["NEURON_RT_INSPECT_OUTPUT_DIR"] = str(inspect_dir) + + from neuronxcc.nki._pre_prod_kernels import ( # noqa: PLC0415 + NormType, + QKVOutputLayout, + QuantizationType, + ) + from neuronxcc.nki._pre_prod_kernels.qkv_tkg_impl import ( # noqa: PLC0415 + nki_qkv_projection_tkg_impl, + ) + + local_heads = _local_heads(args) + output_size = local_heads * args.head_dim + weight_dtype = ( + torch.float8_e4m3fn if args.weight_dtype == "fp8" else torch.bfloat16 + ) + quantization_type = ( + QuantizationType.ROW if args.weight_dtype == "fp8" else QuantizationType.NONE + ) + kernel = nki_qkv_projection_tkg_impl[args.lnc] + + def run_kernel(hidden: torch.Tensor, weight: torch.Tensor, scales: torch.Tensor): + return kernel( + hidden=hidden, + qkv_w=weight, + norm_w=None, + fused_add=False, + mlp_prev=None, + attn_prev=None, + d_head=args.head_dim, + output_layout=QKVOutputLayout.BSD, + eps=1e-6, + norm_type=NormType.NO_NORM, + qkvInSB=False, + qkv_bias=None, + norm_bias=None, + hidden_actual=args.hidden_size, + B=1, + S=1, + H=args.hidden_size, + num_q_heads=local_heads, + num_kv_heads=local_heads, + quantization_type=quantization_type, + qkv_w_scales=scales if args.weight_dtype == "fp8" else None, + qkv_in_scales=None, + ) + + torch.manual_seed(args.seed) + hidden_cpu = torch.randn(1, 1, args.hidden_size, dtype=torch.bfloat16) + weight_cpu = _make_weight(args.hidden_size, output_size, weight_dtype) + scale_cpu = torch.ones((128, output_size), dtype=torch.float32) + + metadata = { + "projection": args.projection, + "hidden_shape": list(hidden_cpu.shape), + "weight_shape": list(weight_cpu.shape), + "scale_shape": list(scale_cpu.shape), + "weight_dtype": str(weight_cpu.dtype), + "local_heads": local_heads, + "head_dim": args.head_dim, + "lnc": args.lnc, + "quantization_type": str(quantization_type), + "neuron_cc_flags": os.environ.get("NEURON_CC_FLAGS"), + "neuron_compile_cache_url": os.environ.get("NEURON_COMPILE_CACHE_URL"), + "visible_cores": os.environ.get("NEURON_RT_VISIBLE_CORES"), + } + print("PROBE_CONFIG", json.dumps(metadata, sort_keys=True), flush=True) + + if args.cpu_backend_hlo: + import torch_neuronx.xla_impl.trace as trace # noqa: PLC0415 + + artifacts = trace.generate_hlo( + run_kernel, + (hidden_cpu, weight_cpu, scale_cpu), + inline_weights_to_neff=False, + return_weights=False, + cpu_backend=True, + preserve_parameters=False, + ) + print("HLO_OK", type(artifacts), flush=True) + return 0 + + from torch_xla.core import xla_model as xm # noqa: PLC0415 + + device = xm.xla_device() + hidden = hidden_cpu.to(device) + weight = weight_cpu.to(device) + scales = scale_cpu.to(device) + output = run_kernel(hidden, weight, scales) + xm.mark_step() + output_cpu = output.detach().cpu() + print( + "OUTPUT", + tuple(output_cpu.shape), + output_cpu.dtype, + "finite", + bool(torch.isfinite(output_cpu.float()).all()), + "sum", + float(output_cpu.float().sum()), + flush=True, + ) + + if args.weight_dtype == "bf16": + ref = hidden_cpu.float().reshape(1, args.hidden_size) @ weight_cpu.float() + diff = (output_cpu.float().reshape_as(ref) - ref).abs() + print( + "BF16_REF", + "max_abs", + float(diff.max()), + "mean_abs", + float(diff.mean()), + flush=True, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/validation_scripts/qwen36_steady_cold_prefill_bench.py b/validation_scripts/qwen36_steady_cold_prefill_bench.py new file mode 100644 index 00000000..266a0285 --- /dev/null +++ b/validation_scripts/qwen36_steady_cold_prefill_bench.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +"""Steady-state cold-prefill benchmark for Qwen3.6 Hybrid APC artifacts.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +import qwen36_hybrid_apc_context_sweep as context_sweep +import qwen36_hybrid_apc_validation as hybrid_validation + + +def _row_for_prompt( + *, + llm: Any, + sampling: Any, + tokenizer: Any, + vocab_size: int, + dummy_ids: set[int], + target_tokens: int, + suffix_tokens: int, + role_index: int, + phase: str, + run: int, +) -> dict[str, Any]: + prompt = context_sweep._prompt_for_length( + tokenizer, + target_tokens=target_tokens, + suffix_tokens=suffix_tokens, + role_index=role_index, + ) + result = context_sweep._generate(llm, sampling, prompt) + generated_tokens = [int(token_id) for token_id in result["generated_tokens"]] + invalid_token_ids = [ + token_id for token_id in generated_tokens if token_id < 0 or token_id >= vocab_size + ] + non_dummy = [token_id for token_id in generated_tokens if token_id not in dummy_ids] + elapsed = float(result["elapsed_seconds"]) + row = { + "phase": phase, + "run": run, + "target_prompt_tokens": target_tokens, + "actual_prompt_tokens": len(prompt["prompt_token_ids"]), + "suffix_tokens": suffix_tokens, + "elapsed_seconds": elapsed, + "effective_prompt_tokens_per_second": ( + target_tokens / elapsed if elapsed > 0 else None + ), + "generated_text": result["generated_text"], + "generated_token_count": result["generated_token_count"], + "generated_tokens": generated_tokens, + "unique_generated_tokens": sorted(set(generated_tokens)), + "token_range_passed": not invalid_token_ids, + "invalid_token_ids": sorted(set(invalid_token_ids)), + "real_tokens_passed": bool(non_dummy), + "non_dummy_generated_token_count": len(non_dummy), + } + print(json.dumps(row, sort_keys=True), flush=True) + return row + + +def _summarize(rows: list[dict[str, Any]]) -> dict[str, Any]: + rates = [ + float(row["effective_prompt_tokens_per_second"]) + for row in rows + if row.get("effective_prompt_tokens_per_second") is not None + ] + if not rates: + return { + "count": 0, + "avg_effective_prompt_tokens_per_second": None, + "min_effective_prompt_tokens_per_second": None, + "max_effective_prompt_tokens_per_second": None, + } + return { + "count": len(rates), + "avg_effective_prompt_tokens_per_second": sum(rates) / len(rates), + "min_effective_prompt_tokens_per_second": min(rates), + "max_effective_prompt_tokens_per_second": max(rates), + } + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=Path, required=True) + parser.add_argument("--compiled-artifacts", type=Path, required=True) + parser.add_argument("--output-json", type=Path, required=True) + parser.add_argument("--prompt-tokens", type=int, default=16384) + parser.add_argument("--suffix-tokens", type=int, default=16) + parser.add_argument("--warmup-runs", type=int, default=1) + parser.add_argument("--measured-runs", type=int, default=3) + parser.add_argument("--role-index-base", type=int, default=700000) + parser.add_argument("--max-tokens", type=int, default=1) + parser.add_argument("--seq-len", type=int) + parser.add_argument("--max-model-len", type=int) + parser.add_argument("--cte-buckets") + parser.add_argument("--context-encoding-bucket-pairs", nargs="+", default=None) + parser.add_argument("--pa-num-blocks", type=int) + parser.add_argument("--gpu-memory-utilization", type=float) + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--max-num-seqs", type=int) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--ctx-batch-size", type=int) + parser.add_argument("--token-generation-buckets", nargs="+", default=None) + parser.add_argument("--token-generation-batches", nargs="+", default=None) + parser.add_argument("--async-mode", action="store_true", default=None) + parser.add_argument("--block-size", type=int, default=256) + parser.add_argument("--gdn-checkpoint-interval", type=int, default=256) + parser.add_argument("--max-gdn-checkpoint-slots", type=int, default=64) + parser.add_argument("--gdn-recurrent-cache-dtype", default="float32") + parser.add_argument("--gdn-conv-cache-dtype", default="bfloat16") + parser.add_argument("--hybrid-apc-prefill-chunk-tokens", type=int, default=0) + parser.add_argument("--kernel-q-tile-size", type=int, default=128) + parser.add_argument("--kernel-kv-tile-size", type=int, default=1024) + parser.add_argument("--skip-fp8-env", action="store_true") + parser.add_argument("--require-real-tokens", action="store_true") + parser.add_argument("--dummy-token-ids", nargs="+", type=int, default=[0]) + args = parser.parse_args() + + args.model_path = args.model_path.expanduser().resolve() + args.compiled_artifacts = args.compiled_artifacts.expanduser().resolve() + args.output_json = args.output_json.expanduser().resolve() + artifact_config = context_sweep._artifact_neuron_config(args.compiled_artifacts) + runtime_args = context_sweep._build_args(args, artifact_config) + + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained( + str(args.model_path), + trust_remote_code=True, + ) + vocab_size = context_sweep._effective_vocab_size(args.model_path, tokenizer) + configured_dummy_ids = {int(token_id) for token_id in args.dummy_token_ids} + dummy_ids = configured_dummy_ids | hybrid_validation._effective_dummy_token_ids( + runtime_args, + tokenizer, + ) + + if args.prompt_tokens + args.max_tokens > runtime_args.seq_len: + raise ValueError( + "prompt_tokens + max_tokens exceeds seq_len: " + f"{args.prompt_tokens} + {args.max_tokens} > {runtime_args.seq_len}" + ) + + rows: list[dict[str, Any]] = [] + llm = None + try: + llm, sampling = hybrid_validation._build_llm( + runtime_args, + enable_hybrid_apc=True, + ) + total_runs = args.warmup_runs + args.measured_runs + for run in range(total_runs): + phase = "warmup" if run < args.warmup_runs else "measured" + role_index = args.role_index_base + (run * 1009) + rows.append( + _row_for_prompt( + llm=llm, + sampling=sampling, + tokenizer=tokenizer, + vocab_size=vocab_size, + dummy_ids=dummy_ids, + target_tokens=args.prompt_tokens, + suffix_tokens=args.suffix_tokens, + role_index=role_index, + phase=phase, + run=run + 1, + ) + ) + finally: + if llm is not None: + hybrid_validation._shutdown_llm(llm) + + measured_rows = [row for row in rows if row["phase"] == "measured"] + correctness_passed = all( + row["token_range_passed"] + and (row["real_tokens_passed"] or not args.require_real_tokens) + for row in measured_rows + ) + report = { + "artifact": str(args.compiled_artifacts), + "artifact_neuron_config": { + key: artifact_config.get(key) + for key in ( + "seq_len", + "max_context_length", + "context_encoding_buckets", + "context_encoding_bucket_pairs", + "prefix_buckets", + "token_generation_buckets", + "output_logits", + "pa_num_blocks", + ) + }, + "prompt_tokens": args.prompt_tokens, + "warmup_runs": args.warmup_runs, + "measured_runs": args.measured_runs, + "configured_dummy_token_ids": sorted(configured_dummy_ids), + "effective_dummy_token_ids": sorted(dummy_ids), + "vocab_size": vocab_size, + "correctness_passed": correctness_passed, + "summary": _summarize(measured_rows), + "rows": rows, + } + args.output_json.parent.mkdir(parents=True, exist_ok=True) + with args.output_json.open("w", encoding="utf-8") as handle: + json.dump(report, handle, indent=2, sort_keys=True) + handle.write("\n") + print(json.dumps(report["summary"], sort_keys=True), flush=True) + return 0 if correctness_passed else 2 + + +if __name__ == "__main__": + raise SystemExit(main()) From e6ef24aa3b0b156e339fc2c59eb8e18259da2ad4 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Wed, 10 Jun 2026 22:27:48 +0530 Subject: [PATCH 2/3] Trim internal docs and port verified nativechunk evidence into README Co-Authored-By: Claude Fable 5 --- contrib/models/Qwen3.6-27B/README.md | 108 +- .../docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md | 193 -- .../docs/FULL_FP8_ISSUES_AND_FIXES.md | 405 --- .../docs/HYBRID_APC_PRODUCTION_PLAN.md | 364 --- .../QWEN36_FP8_TIERFIX_VALIDATION_20260526.md | 2364 ----------------- .../docs/patches/mtp_batched_accept.patch | 154 -- .../docs/patches/mtp_batched_accept_README.md | 118 - 7 files changed, 37 insertions(+), 3669 deletions(-) delete mode 100644 contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md delete mode 100644 contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md delete mode 100644 contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md delete mode 100644 contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md delete mode 100644 contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch delete mode 100644 contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index a6327f07..f868a818 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -83,80 +83,46 @@ Unit tests are architecture-level and do not depend on weights. Coverage include | Capital of Japan | Tokyo | PASS | | sqrt(144) | 12 | PASS | -## Performance Benchmarks - -### Qwen3.6-27B on trn2.3xlarge (TP=4, LNC=2, SDK 2.29, BF16) - -**TTFT (Time To First Token)** - -| Input Length | P50 (ms) | P95 (ms) | -|-------------|----------|----------| -| 16 tokens | 305.3 | 305.6 | -| 64 tokens | 305.4 | 305.9 | -| 128 tokens | 306.6 | 306.8 | -| 256 tokens | 306.2 | 306.3 | - -**TPOT / Throughput** - -| Output Length | TPOT P50 (ms) | tok/s P50 | E2E P50 (ms) | -|--------------|---------------|-----------|---------------| -| 16 | 54.3 | 18.4 | 1,121 | -| 32 | 54.4 | 18.4 | 1,993 | -| 64 | 54.2 | 18.5 | 3,720 | -| 128 | 54.2 | 18.5 | 4,912 | - -### Comparison with Qwen3.5-27B - -| Metric | Qwen3.5-27B | Qwen3.6-27B | Delta | -|--------|------------|------------|-------| -| TPOT P50 | 53 ms | 54.2 ms | +2.3% | -| Throughput | 18.9 tok/s | 18.5 tok/s | -2.1% | -| TTFT (128 tok) | 576 ms | 306.6 ms | -47% * | - -\* TTFT improvement is due to compilation config differences (256-token bucket vs 128-token bucket), not model differences. Architectural performance is equivalent. - -### Long-Context vLLM Baseline - -A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) -with the vLLM Neuron plugin, Qwen chunked prefill, and native vLLM APC enabled. - -| Metric | Result | -|--------|--------| -| Max model length | 131,072 tokens | -| Context encoding bucket | 512 | -| Prefill throughput | 404-428 tok/s from 512 through 64K prompt tokens | -| Decode throughput | 26.3-26.6 tok/s | -| 64K quality | needle retrieval prompts returned all expected codes | -| State reset | repeated short-after-long validation passed after 32K and 64K requests | -| Peak Neuron device memory | ~53.25 GB decimal during the 64K eval | - -TTFT/TPOT details for the same 128K FP8/vLLM artifact: - -| Metric | Result | Notes | -|--------|--------|-------| -| Decode TPOT | ~37.6-38.0 ms/token | Derived from 26.3-26.6 tok/s decode | -| Cold 512-token TTFT | ~1.2-1.3s | Derived from measured prefill throughput plus one decode step | -| Cold 32K-token TTFT | ~76.6-81.1s | Derived from measured prefill throughput plus one decode step | -| Cold 64K-token TTFT | ~153-162s | Derived from measured prefill throughput plus one decode step | -| Warm APC latency, ~10.8K prompt | 1.36-2.38s | Exact-repeat, partial-prefix, and cross-prefix validation runs | -| Cold APC baseline, ~10.8K prompt | 25.17-26.68s | Same prompts with prefix cache disabled or cold | - -Native vLLM prefix caching/APC was also validated with exact greedy output -matches: - -| APC Scenario | Cold | Warm | Speedup | Result | -|--------------|------|------|---------|--------| -| Server exact-repeat, ~10.8K prompt tokens | 26.68s | 1.67s | 16.0x | exact text match | -| Offline exact-repeat | 26.19s | 2.38s | 11.0x | exact token-ID match | -| Offline partial-prefix reuse | 25.52s | 1.70s | 15.0x | exact token-ID match | -| Server cross-prefix reuse | 25.17s | 1.36s | 18.5x | exact text match | +## Current Validation Evidence + +The current publishable baseline is the coherent 256K native-chunk loadfix +lineage on `trn2.3xlarge` with TP=4, LNC=2, SDK 2.29, host sampling, BF16 KV, +BF16 LM head, BF16 gates, QKV NKI enabled, segmented CTE512, GDN segment 512, +and CTE bucket 2048. + +Validation evidence is included in +`validation_outputs/qwen36_nativechunk_baseline_20260609T000000Z/`. + +| Case | Prompt tokens | TTFT | Prefill tok/s | Source | +|------|--------------:|-----:|--------------:|--------| +| 16K native chunk | 16,374 | 6.8379 s | 2,394.6 | `qwen36_256k_nativechunk_crossguard_clean16k_probe_20260609T000000Z.jsonl` | +| Long context, usage-accounted | 242,864 | 235.9819 s | 1,029.2 | `qwen36_256k_nativechunk_crossguard_258k_probe_20260609T000000Z.jsonl` | +| Same long-context run, estimated target prompt | 253,899 estimated | 235.9819 s | 1,075.9 | tokenizer-derived fallback, not usage-accounted | + +`baseline_summary.json` marks the run coherent, target recovered, and +materially faster. `log_scan_empty.txt` contains no invalid-token, fallback, +NaN, NRT, or traceback markers. + +### Correctness Evidence + +- BF16 smoke tests: 7/7 text-only quality prompts passed with + `enable_thinking=False`. +- HF greedy comparison: 156/160 token positions matched HF greedy (97.5%), + and 9/10 prompts matched exactly for all 16 generated tokens. +- Strict Hybrid APC exactness passed for full-prefix, partial-prefix, and + real-token generation cases. +- Current native-chunk validation has `pass=true`, thinking enabled, coherent + first text, and an empty bad-marker log scan. ### Key Observations -- **BF16 TP=4 is HBM-limited:** The pure BF16 path is limited to short contexts on trn2.3xlarge. The validated 128K baseline uses MLP-only FP8 weights plus the hybrid cache manager. -- **DeltaNet enables efficient TKG:** Token generation uses O(1) per-token recurrence instead of O(n) KV cache attention for 48/64 layers. -- **vLLM APC is high leverage:** Repeated-prefix requests avoid replaying long chunked prefill and are the largest observed latency win for chat/RAG-style workloads. -- **Performance equivalent to Qwen3.5-27B:** The BF16 TPOT difference is within measurement noise. Expected since architectures are identical. +- **Qwen3.6 is hybrid, not transformer-only:** 48 of 64 layers use DeltaNet/GDN + recurrent state, so standard KV-only prefix caching is insufficient. +- **Safe Hybrid APC needs three-way agreement:** reusable prefix length is the + intersection of attention KV hits, GDN recurrent checkpoints, and GDN + convolution checkpoints. +- **Native Neuron chunking is required for the current coherent fast path:** + generic vLLM chunking is not the validated path for this model on Neuron. ## Usage diff --git a/contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md b/contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md deleted file mode 100644 index 7e758a85..00000000 --- a/contrib/models/Qwen3.6-27B/docs/CODEX_CONTINUOUS_BATCHING_PROMPT.md +++ /dev/null @@ -1,193 +0,0 @@ -# Codex Prompt — Enable Continuous Batching for Qwen3.6-27B - -## Context - -vLLM v1 already does continuous batching. The bottleneck is on the Neuron -side: the current MTP artifact (and baseline v3) was compiled with -`tkg_batch_size=1`, meaning the device can only execute one decode stream -per forward call regardless of how many vLLM tries to schedule. - -To enable real continuous batching: recompile with `tkg_batch_size > 1` so -the device-side decode graph processes multiple sequences in parallel. - -Current compile harness has it hardcoded: -- `contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_mtp.py:79-81` - - `batch_size=1, ctx_batch_size=1, tkg_batch_size=1` - -vLLM start script ALREADY wires `tkg_batch_size = MAX_NUM_SEQS` via -`--override-neuron-config`. This override only takes effect if the -underlying NEFF was compiled with the matching batch size. **No runtime -override of compile-time batch dimension is possible.** - -## Goal - -Compile and validate a continuous-batching artifact with `tkg_batch_size=8` -(and matching `batch_size=8`). Demonstrate aggregate throughput scaling -with `max-num-seqs=8` on real workloads. Document HBM peak. - -## Phase A: Compile harness update (target 0.5 day) - -A.1 Modify `contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_mtp.py` -to accept a CLI `--tkg-batch-size` argument (default 1 for backward compat). -Apply to NeuronConfig: -```python -batch_size = args.tkg_batch_size -ctx_batch_size = 1 # prefill stays single-stream per CTE call -tkg_batch_size = args.tkg_batch_size -``` - -A.2 Add `--max-num-seqs` and `--max-model-len` mirrors if not already present. - -A.3 Reduce `seq_len` for the first batched run to keep HBM in budget: -- batch=8, seq_len=16384 → KV cache ~8 GB, total HBM ~63 GB (fits) -- batch=8, seq_len=32768 → KV cache ~16 GB, total HBM ~71 GB (fits) -- batch=8, seq_len=65536 → KV cache ~34 GB, total HBM ~89 GB (tight) - -Start with batch=8, seq_len=32768. Validate. Push to 65536 only if HBM -budget allows. - -A.4 Document in the compile harness which compile flags need to match the -vLLM `--max-num-seqs` value at serve time. - -## Phase B: Compile + load (target 0.5 day) - -B.1 Compile artifact: -```bash -python contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_mtp.py \ - --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-path /opt/dlami/nvme/qwen_artifacts/qwen36_27b_32k_fp8_batch8_run1 \ - --seq-len 32768 \ - --cte-bucket 512 \ - --tp-degree 4 \ - --logical-nc-config 2 \ - --tkg-batch-size 8 \ - --load-after-compile -``` - -Expected compile time: ~22 min. Slightly longer than batch=1 due to bigger -TKG graph. - -B.2 Load artifact on hardware. Verify load succeeds and HBM peak after -load is below 80 GB (leaves headroom for activations during inference). - -B.3 If NRT_RESOURCE (HBM blew up): drop seq_len to 16384 or batch to 4. -Report which tensor exceeded budget. - -## Phase C: Single-stream regression check (target 0.25 day) - -C.1 Bring up vLLM server with `--max-num-seqs 8` pointing to the new -artifact: -```bash -bash contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_32k_fp8_batch8_run1 \ - --max-num-seqs 8 \ - --seq-len 32768 \ - --enable-chunked-prefill \ - --enable-prefix-caching -``` - -C.2 Single-stream smoke (one request at a time): -- Math: 17 × 23 should return 391 -- 762-token MGS prompt: coherent output - -C.3 Single-stream perf (one request at a time): -- 4K prompt, 128-token decode: measure prefill tok/s + decode tok/s -- Compare against baseline v3 (~418 prefill, ~27 decode) -- Expect: equivalent prefill, slightly lower decode (batch=8 graph has - some per-step overhead even at active batch=1). Acceptable: within 10%. - -If decode regresses more than 20% from baseline v3: there's a configuration -issue. Investigate before continuing. - -## Phase D: Aggregate throughput measurement (target 0.5 day) - -D.1 Use existing `validation_scripts/qwen36_27b_vllm_concurrency_eval.py` -or write a small async harness if needed. Test at: -- concurrency=1 (baseline) -- concurrency=2 -- concurrency=4 -- concurrency=8 (full batch) -- concurrency=16 (tests queueing behavior) - -D.2 Two prompt distributions: -- "Short": 1K prompts, 128 token decode (chat-like workload) -- "Medium": 8K prompts, 256 token decode (RAG-like workload) - -D.3 Capture for each (concurrency, prompt_len) point: -- Aggregate input tok/s -- Aggregate output tok/s -- Per-stream input tok/s -- Per-stream output tok/s -- P50 / P95 TTFT -- P50 / P95 inter-token latency (TPOT) -- HBM peak (neuron-monitor during run) - -D.4 Expected scaling pattern: -- concurrency=1: ~baseline single-stream -- concurrency=4: ~2-3× aggregate (sub-linear because per-stream slows) -- concurrency=8: ~3-5× aggregate at the batch ceiling -- concurrency=16: queued, aggregate same as 8 but TTFT spikes - -If concurrency=8 aggregate is NOT 3× the concurrency=1 number: batching -isn't actually happening at the device. Verify by checking -neuron-monitor: should see batch=8 graph activity, not batch=1. - -## Phase E: APC interaction (target 0.25 day) - -E.1 Add a shared prefix to the prompts (system message + variable user -turn). Repeat the concurrency=8 measurement. - -E.2 Expected: APC hit rate ≥ 50% across the 8 concurrent streams (they -share the system prompt). Aggregate prefill should jump significantly on -warm streams. - -E.3 If APC hit rate stays low when streams share a prefix: APC cache is -being evicted across concurrent streams. Investigate cache size limits. - -## Phase F: Documentation (target 0.25 day) - -F.1 Update `OPTIMIZATION_ARC.md` with the continuous-batching results: -- Add a "Continuous batching" row to the "What worked" table -- Update the "Hardware utilization" section with the aggregate numbers -- Update the "How this compares to NVIDIA" table with aggregate Trainium - numbers (the Millstone H100 page shows aggregate at 5 concurrent) - -F.2 Create `vllm/CONTINUOUS_BATCHING.md` with: -- Compile flags required -- vLLM serve flags required -- Measured throughput curve (concurrency 1-16) -- HBM budget table by (batch, seq_len) -- APC interaction notes - -## Hard constraints - -1. Do not modify baseline v3 artifact. Tag the new artifact as - `qwen36-27b-continuous-batching-v1` if all gates pass. -2. Commit + push after each phase. -3. Maximum compile attempts: 3. Each ~22 min. -4. If HBM exceeds 92 GB at batch=8: drop seq_len to 16384 and retry. -5. Do not enable MTP speculation in this artifact (defer to PR #4). Spec - decoding + continuous batching together is harder; tackle one at a time. - -## Expected outcomes - -| Outcome | Probability | Meaning | -|---|---|---| -| batch=8 compiles, loads, scales to 3-5× aggregate | 60% | Best case; ship as PR #2 (continuous batching baseline) | -| batch=8 compiles but scales less than 2× | 20% | Diagnose: probably KV cache contention or scheduler overhead | -| HBM blowup, must drop to batch=4 or seq_len=16K | 15% | Acceptable fallback; still 2-3× aggregate | -| Quality regression at batch>1 | 5% | Bug in hybrid cache at batch>1; investigate | - -Begin with Phase A. Report after Phase B. Do not chain phases. - -## Why this is high priority - -Currently single-stream decode is 27 tok/s. After continuous batching -with batch=8: -- Aggregate decode probably 150-200 tok/s (4-7× single) -- This is the metric that maps to "production serving capacity" -- Without it, you cannot answer "how many users can one instance serve?" -- Without it, the cost-per-token comparison vs H100 cannot be made - -This is the prerequisite for the multi-instance scaling discussion and -for any honest production-deployment claims. diff --git a/contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md b/contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md deleted file mode 100644 index ec7a8b4a..00000000 --- a/contrib/models/Qwen3.6-27B/docs/FULL_FP8_ISSUES_AND_FIXES.md +++ /dev/null @@ -1,405 +0,0 @@ -# Qwen3.6 27B Full FP8 — Issues Encountered and Fixes - -Consolidated catalog of every issue hit during the full-FP8 / 256K hybrid-APC -work and how each one was resolved. Branch: `codex/full-fp8-qwen36`. - -Source-of-truth detail (with exact log lines, PIDs, artifact paths) lives in: - -- [QWEN36_FP8_TIERFIX_VALIDATION_20260526.md](./QWEN36_FP8_TIERFIX_VALIDATION_20260526.md) — full chronological log -- [HYBRID_APC_PRODUCTION_PLAN.md](./HYBRID_APC_PRODUCTION_PLAN.md) — production bucket strategy -- [profile_artifacts/qwen36_256k_fp8_sparse_runtime_20260525/ERROR_LOG.md](../../../../profile_artifacts/qwen36_256k_fp8_sparse_runtime_20260525/ERROR_LOG.md) — runtime load failures -- [AGENTS.md](../../../../AGENTS.md) — error-logging contract and measurement rules - -This document is the index. Each entry: **what broke → why → what we changed → verification**. - ---- - -## Table of Contents - -1. [Quantization & Checkpoint Conversion](#1-quantization--checkpoint-conversion) -2. [Neuron Compiler Failures](#2-neuron-compiler-failures-neuronx-cc) -3. [vLLM / Hybrid APC / Scheduler](#3-vllm--hybrid-apc--scheduler) -4. [Runtime Load & Memory (NRT_RESOURCE / scratchpad / HBM)](#4-runtime-load--memory-nrt_resource--scratchpad--hbm) -5. [Custom NKI Kernel (`qwen_segcte256`)](#5-custom-nki-kernel-qwen_segcte256) -6. [Validation Harness & Measurement Bugs](#6-validation-harness--measurement-bugs) -7. [Tooling, Sync, Shell, SSH](#7-tooling-sync-shell-ssh) -8. [Lessons Codified in `AGENTS.md`](#8-lessons-codified-in-agentsmd) - ---- - -## 1. Quantization & Checkpoint Conversion - -### 1.1 MLP-only FP8 scope was insufficient for "full FP8" - -- **Symptom:** Original path only quantized MLP layers; attention, DeltaNet projections, and fused QKV stayed BF16. -- **Cause:** `_mlp_only_modules_to_not_convert` excluded entire `self_attn` and `linear_attn` modules; checkpoint rewrite only handled MLP scale tensors. -- **Fix:** Added `fp8_full` mode in [qwen36_27b_compile_fp8.py](../test/integration/qwen36_27b_compile_fp8.py); broadened module selector to all Linear matmuls (MLP + attention + DeltaNet `in_proj_*` / `out_proj`); kept embeddings, norms, rotary, `lm_head`, DeltaNet `conv1d`/`A_log`/`dt_bias` in BF16. -- **Verification:** Unit tests in [test_qwen36_compile_fp8_config.py](../test/unit/test_qwen36_compile_fp8_config.py). - -### 1.2 Scale tensors not transformed alongside weights - -- **Symptom:** Loading FP8 artifact failed because scale tensors didn't match the transformed weights (Q/gate split, fused QKV concat, DeltaNet QKV TP reorder). -- **Cause:** Checkpoint converter in [modeling_qwen35.py](../src/modeling_qwen35.py) only transformed `.weight`; FP8 needs the matching `.scale` to follow the same reorder/split/concat. -- **Fix:** Added scale-aware transforms in the converter for: q_proj weight/scale split → `output_gate_proj`, fused `Wqkv` weight+scale creation, and DeltaNet `in_proj_qkv.weight/scale` TP reorder using identical permutation. FP8 concat uses `view(torch.int8)` round-trip because PyTorch rejects direct `torch.float8_e4m3fn` concat. -- **Verification:** [test_weight_conversion.py](../test/unit/test_weight_conversion.py). - ---- - -## 2. Neuron Compiler Failures (`neuronx-cc`) - -### 2.1 `NCC_ITIN902 TensorInitialization` / `AffineIV doesn't appear in params or loopnest` - -- **Symptom:** Compiler internal error during NEFF tensorization on specific 2D prefix-caching bucket pairs: - - `cte=256, prefix=16384` - - `cte=1024, prefix=1024` - - `cte=2048, prefix=2048` - - `cte=4096, prefix=4096` -- **Cause:** Compiler bug in `neuronx-cc` lowering on power-of-two square shapes and the small-active/large-prefix corner. AWS log itself says "open a Neuron SDK issue." -- **Fix:** Avoid those exact shapes. Use safe CTE ladder `512, 768, 1536, 3072` (256-aligned, non-square) and limit prefix-bucket granularity. -- **Verification:** `cte512_768_1536_3072_pfx16k` compile completed with `COMPILE_DONE` and 0 `NCC_ITIN902`. - -### 2.2 Combined dense + long-prefix artifact — `[F137] neuronx-cc forcibly killed (-9)` - -- **Symptom:** Compiling all five long-prefix pairs (`3072:0,32k,64k,128k,256k`) in one run was killed by OOM on the largest buckets (`bk3`, `bk4`). -- **Cause:** Compile-host RAM pressure when multiple HLOs compile in parallel for very large shapes. -- **Fix:** Split tiers into separate compile runs. Implemented sequential orchestrator in [tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh](../../../../tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh). -- **Verification:** Three tiered artifacts (`pfx32k_64k`, `pfx128k`, `pfx256k`) all reached `COMPILE_DONE` on the same host once compiled sequentially. - -### 2.3 Bash script wrote artifact paths with spaces (`tkg32768 131072 262144`) - -- **Symptom:** Orchestrator created malformed paths because the launcher used `${ARR[*]}` instead of joining with `_`. -- **Cause:** Bash array word-splitting in label construction. -- **Fix:** Use explicit `IFS=_` join in the helper script. -- **Verification:** Relaunch produced underscore-only paths. - -### 2.4 `head_dim must be <= 128 (got 256)` — `NCC_INKI016` - -- **Symptom:** AWS Neuron 2.30 `attention_segmented_cte` kernel hard-asserts `head_dim <= 128`; Qwen3.6 uses `head_dim=256`. -- **Cause:** Kernel was not designed for 256-wide head dim. -- **Fix:** Wrote custom `qwen_segcte256` kernel that splits Q/K into two 128-wide D tiles and accumulates `Q_lo@K_lo + Q_hi@K_hi` into one PSUM before softmax. See [Section 5](#5-custom-nki-kernel-qwen_segcte256). -- **Verification:** Custom kernel BIR-compiled cleanly for production shape `q=(2,3072,256)`, `k/v=(1024,1,256,256)`, `prior_seg_size=32768`. - -### 2.5 `NCC_INLA001 Allocated memory out of bound (128x402724)` — SBUF scratch too large - -- **Symptom:** First version of the custom segmented CTE kernel compiled HLO but exceeded SBUF in the backend. -- **Cause:** Each Q group held its own K/V segment buffers + scratch live simultaneously; per-group block-dim allocation × 24 Q groups blew SBUF. -- **Fix:** Two-stage kernel rewrite in [fused_segmented_attention_256.py](../../../../src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py): - 1. Allocate one reusable Q-group window instead of `block_dim=[num_grps]`. - 2. Stream active CTE in 512-token chunks through the same online-softmax accumulator. - 3. Cap packed Q loads to 4 groups (instead of 8) for `head_dim=256`. -- **Verification:** Production-shape BIR scratch dropped from `402724` to `31360`, under the 32767 SBUF free-dim limit. Full compile produced `COMPILE_DONE`. - ---- - -## 3. vLLM / Hybrid APC / Scheduler - -### 3.1 64 GiB KV cache estimate on 96 GiB Trn2 (model wouldn't even start) - -- **Symptom:** vLLM rejected 256K context with `64.0 GiB KV cache needed, 39.12 GiB available`. -- **Cause:** vLLM-Neuron runner created `FullAttentionSpec` for all 64 layers. Qwen3.6 is hybrid — only 16 of 64 layers are full-attention; the other 48 are DeltaNet (no token-long KV). -- **Fix:** Patched `get_kv_cache_spec` in [qwen36_hybrid_apc_scheduler_patch.py](../vllm/qwen36_hybrid_apc_scheduler_patch.py:129) to report KV only for the 16 full-attention layers, with local KV heads per TP rank. -- **Verification:** Server log: `Using Qwen hybrid KV-cache spec for 16/64 attention layers`, `GPU KV cache size: 262,400 tokens`. - -### 3.2 Warm prefix-cache continuation crashed with "no `hybrid_full_input_ids`" - -- **Symptom:** Second request reusing a cached prefix died because runner received suffix-only tokens without the full prompt context the GDN path needs. -- **Cause:** Scheduler metadata didn't carry full `all_token_ids` through suffix-prefill requests; runner's strict guard rightly rejected. -- **Fix:** Scheduler patch now attaches `full_input_ids` only when `num_computed_tokens > 0` (cached continuation), not for the first cold chunk; async prep bridge unpacks it back to `hybrid_full_input_ids` and slices to active suffix length. -- **Verification:** [test_hybrid_apc_manager.py](../test/unit/test_hybrid_apc_manager.py) + working 8k→16k→18432 cold/warm exactness on TRN2. - -### 3.3 `request_prefix_len` polluted by generated tokens - -- **Symptom:** During decode, vLLM `request.num_tokens` grows past the original prompt length; that leaked into APC metadata and made cold vs warm runs schedule differently. -- **Cause:** Metadata used `request.num_tokens` instead of the original prompt length. -- **Fix:** Cap `request_prefix_len` to the prompt-only token count. -- **Verification:** Cold and warm 8k runs now schedule identically (`prompt_len=8192 restore_len=6144 suffix_len=2048`). - -### 3.4 Dummy token `0` (`!!!!`) leaked into output during chunked prefill - -- **Symptom:** Cold output started with two `0` tokens before real decoding; warm output started correctly. -- **Cause:** vLLM-Neuron host-logits sampling appended `sampled_token_ids` from incomplete chunked-prefill rows. Earlier mask attempt was on the wrong path (`_sample_on_device` instead of `_generate_model_runner_output` for `hostlogits` artifacts). -- **Fix:** Added `_generate_model_runner_output` wrapper that masks incomplete-prefill rows before vLLM appends sampled IDs to request state. Used scalar `.item()` to work on Neuron/XLA tensors. -- **Verification:** 8k cold and warm now both emit `[3817, 7840, 9197, 4590]`. Test coverage added. - -### 3.5 Chunked prefill above 16k prefix exceeded compiled prefix bucket - -- **Symptom:** Cold 32k prompt failed with `Prefill len 512 with prefix len 16896 exceeds compiled 2D buckets... largest prefix bucket 16384`. -- **Cause:** vLLM's chunked-prefill continuation presents (active_chunk, computed_prefix) shapes to NxDI's 2D bucket selector; with `pfx16k`, a 32k prompt eventually reaches prefix `16896`. -- **Fix:** Two approaches: - 1. Runtime cap on backed prefix reads (`QWEN36_HYBRID_APC_MAX_BACKED_PREFIX_READ_LEN`). - 2. Split production artifacts by prefix tier; route long contexts to the long artifact. -- **Verification:** Tiered split (`pfx16k` for short, `pfx32k_64k`, `pfx128k`, `pfx256k` for long) compiled and validated for context up to 128K. 256K required the custom kernel (see §5). - -### 3.6 Sparse 2D bucket support needed — runtime assumed full Cartesian grid - -- **Symptom:** Wanting sparse pairs like `cte=3072 × prefix=262144` only (without the failing `cte=512 × prefix=262144`) wasn't possible — runtime did `bucket_idx = prefill_index * len(prefix_buckets) + prefix_index`. -- **Cause:** NxDI runtime hard-assumed a rectangular bucket grid. -- **Fix:** Added `context_encoding_bucket_pairs` config + sparse-pair-aware runtime selection in [model_wrapper.py](../../../../src/neuronx_distributed_inference/models/model_wrapper.py:1126) and [autobucketing.py](../../../../src/neuronx_distributed_inference/modules/autobucketing.py:162). Wired through compile script and vLLM serving config. -- **Verification:** Unit tests in [test_autobucketing.py](../../../../test/unit/modules/test_autobucketing.py) and [test_prefix_caching_bucket_selection.py](../../../../test/unit/models/test_prefix_caching_bucket_selection.py). - -### 3.7 Async sample called before any `execute_model()` (V1 scheduler) - -- **Symptom:** With async scheduling, `sample_tokens()` was invoked once before any cached logits existed; Neuron runner raised. -- **Cause:** vLLM-Neuron runner had no "no output yet" guard like the GPU path. -- **Fix:** Added no-output guard in the runner wrapper. - -### 3.8 Contract mismatch: `expected 24/29 tensors, got 15` - -- **Symptom:** With prefix caching disabled at vLLM level but Hybrid APC enabled, the model wrapper got only 15 mandatory tensors while artifact expected 29. -- **Cause:** Compiled artifact's input contract is fixed at trace time. Runtime config flipping `is_prefix_caching` off without recompiling broke the contract. -- **Fix:** - 1. Server script preserves the compiled `is_prefix_caching` contract from `neuron_config.json` even when vLLM-level prefix caching is off. - 2. Qwen wrapper expands 15-tensor runtime input to 24/29-tensor traced form by padding with inert MRoPE/vision/tile tensors. -- **Verification:** Unit coverage in [test_qwen36_model_aliases.py](../test/unit/test_qwen36_model_aliases.py). - ---- - -## 4. Runtime Load & Memory (`NRT_RESOURCE` / scratchpad / HBM) - -### 4.1 Combined sparse artifact failed to load on `trn2.3xlarge` - -- **Symptom:** Artifact compiled fine, but TRN2 load failed with `Failed to allocate 1.000GB ... usage: shared scratchpad` at `_tp0_bk36` (long-prefix NEFF). -- **Cause:** Trn2.3xlarge has 96 GiB total but in four 24 GiB HBM banks under LNC=2. Per-bank usage hit `~22-24 GiB` (tensors + scratchpad) before runtime needed another aligned 1 GiB allocation. The "combined" artifact loaded **every** compiled CTE×prefix NEFF at once. -- **Fix:** Physical split into multiple artifacts; route requests to the smallest artifact that covers the prefix tier. -- **Verification:** `pfx32k_64k` and `pfx128k` loaded and ran end-to-end after the split. - -### 4.2 Runtime bucket-override JSON didn't reduce loaded NEFFs - -- **Symptom:** Setting `--context-encoding-bucket-pairs 512:0 512:512` at runtime still failed at `_tp0_bk36` load. -- **Cause:** Saved `model.pt` references all compiled workdir NEFFs; runtime overrides control routing, not which NEFFs get staged. -- **Fix:** Per §4.1 — split artifacts physically; runtime overrides alone are insufficient. - -### 4.3 `NEURON_SCRATCHPAD_PAGE_SIZE=2048` did not help - -- **Symptom:** Tried larger scratchpad page size to relieve alignment pressure; still failed with `Failed to allocate 2.000GB`. -- **Cause:** Total scratchpad footprint, not just alignment fragmentation. -- **Fix:** Abandon page-size-only mitigation for over-broad artifacts; compile narrower artifacts. - -### 4.4 Initial three-tier artifacts compiled but failed to load - -- **Symptom:** `pfx32k_64k`, `pfx128k`, `pfx256k` all compiled with `seq_len=262144`, `pa_num_blocks=1024`, `tkg=[32768,131072,262144]` — and all failed `NRT_RESOURCE` at load. -- **Cause:** "Tiered" by prefix only; every tier still paid the full 256K cache and 3 TKG buckets. -- **Fix:** True tier-specific budgets in [tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh](../../../../tmp_compile_qwen256k_fp8_full_prod_three_prefix_tiers_hostlogits.sh): - - `pfx32k_64k`: `seq_len=65536`, `pa_num_blocks=256`, `tkg=[32768,65536]`, keep dense `3072:0`. - - `pfx128k`: `seq_len=131072`, `pa_num_blocks=512`, `tkg=[131072]`, omit dense `3072:0`. - - `pfx256k`: `seq_len=262144`, `pa_num_blocks=1024`, `tkg=[262144]`, omit dense `3072:0`. -- **Verification:** All three tierfix artifacts loaded; `pfx32k_64k` and `pfx128k` passed prefill + chat. - -### 4.5 Device profiling caused `NRT_RESOURCE` on `pfx256k` load - -- **Symptom:** First 256K runtime validation died because `NEURON_RT_INSPECT_DEVICE_PROFILE=1` reserved `2.348 GB HBM per NC`, pushing per-bank load over the edge. -- **Cause:** Device profiler adds non-trivial HBM tax. -- **Fix:** Run validation without `NEURON_RT_INSPECT_DEVICE_PROFILE`. Profile separately on smaller artifacts or with reduced sampling. - -### 4.6 Null block (vLLM adds 1 reserved block) — `pa_num_blocks=1024` was too small - -- **Symptom:** vLLM logs showed `num_gpu_blocks` becoming `1025` after the runtime adds a reserved null block, but compiled artifact only had 1024 physical blocks. -- **Cause:** Off-by-one between compile-time `pa_num_blocks` and runtime "user-usable + null" convention. -- **Fix:** Compile with `pa_num_blocks=1025`. Validation runners now treat compiled count as physical (includes null) and set `num_gpu_blocks_override` to `compiled - 1`. -- **Verification:** Compile config logged `pa_num_blocks=1025, pa_min_blocks=1024, pa_headroom_blocks=1`. Updated [qwen36_hybrid_apc_context_sweep.py](../../../../validation_scripts/qwen36_hybrid_apc_context_sweep.py) + [qwen36_offline_decode_bench.py](../../../../validation_scripts/qwen36_offline_decode_bench.py). - ---- - -## 5. Custom NKI Kernel (`qwen_segcte256`) - -Required because AWS Neuron 2.30 `attention_segmented_cte` rejects `head_dim > 128`. - -### 5.1 `dma_copy dst partition dimension 256 exceeds maximum 128` - -- **Symptom:** BIR compile failed when loading K cache: K SBUF tile shape `(256, 512)` violated the 128-partition rule. -- **Cause:** Tried to keep `head_dim=256` on the partition axis. -- **Fix:** Load each 256-token KV block as two 128-token halves: temp `(128, 128)`, transpose each, write into 128-token offset inside K tile. - -### 5.2 `unsupported expression` — list comprehensions - -- **Symptom:** `[(k_lo[i], k_hi[i]) for i in range(...)]` rejected by NKI specialization. -- **Cause:** NKI front-end doesn't accept Python list comprehensions inside kernel helpers. -- **Fix:** Build the list with explicit `for ... append`. - -### 5.3 `failed to resolve name 'x::0.shape'` - -- **Symptom:** After splitting K into `(lo, hi)` pair, old metadata lookup `k_sbuf[0].shape[1]` returned `.shape` from the pair tuple. -- **Fix:** Branch the metadata lookup to use `k_sbuf[0][0].shape[1]` on the split-K path. - -### 5.4 `dma_transpose dst.shape must match transposed src.shape` - -- **Symptom:** Q load pattern used `ac.d=256` as D extent while destination was 128. -- **Fix:** Use 128-wide D extent in source pattern: `[[ac.d, num_f], [1,1], [1,1], [1,128]]`. - -### 5.5 `reduce_one_batch` signature mismatch - -- **Symptom:** Compile failed with `batch_idx * sb_p * num_grps` where `batch_idx` was an object. -- **Cause:** Copied call signature didn't match installed Neuron 2.30 helper's argument order. -- **Fix:** Call with explicit keyword arguments matching the installed helper. - -### 5.6 `NCC_INLA001 Allocated memory out of bound (128x402724)` - -- See [§2.5](#25-ncc_inla001-allocated-memory-out-of-bound-128x402724--sbuf-scratch-too-large). The fix (group-window aliasing + active streaming + Q-pack cap) reduced production-shape SBUF scratch from `402724` to `31360`. - -### 5.7 `_exp_impl` partial-sum slot index out of range - -- **Symptom:** Active-streaming variant tried to index exp partial-sum slot 1 when each 512-token chunk only allocated slot 0. -- **Cause:** Active attention config still referenced full active KV length per chunk instead of per-chunk view. -- **Fix:** Specialize the active attention config per chunk with that chunk's `global KV end`, retain global K start via `kv_section_idx`. - -### 5.8 Runtime `scalar DGE out-of-bound access` at 256K prefill (PA-blocks) - -- **Symptom:** Compile passed, model loaded, KV initialized, then context-encoding NEFF crashed mid-execution with repeated scalar-DGE OOB. -- **First hypothesis tested:** vLLM adds a null block (`1025` physical), but artifact had `pa_num_blocks=1024`. -- **Fix attempted:** Recompiled with `pa_num_blocks=1025`. **Did not fix it** — runtime still hit DGE OOB on the new artifact. - -### 5.9 Runtime DGE OOB — root cause: final partial active chunk reads past block table - -- **Symptom:** Even with `pa1025`, the 261,888-token prefill failed in `context_encoding_model/_tp0_bk0` with scalar-DGE OOB. -- **Cause:** Active stream loop always processed 6 full sections per CTE bucket, even when the final real active chunk was only 768 tokens. At the end of the 256K prompt, the kernel read block-table offsets beyond the 1024-entry table. -- **Fix:** Pad the kernel's internal block table by the CTE active block count (1024 → 1036 entries). Padded active stream loads resolve to block 0 instead of reading past the table. -- **Verification:** Bound-fix artifact compiled (`COMPILE_DONE`), and the no-device-profile 256K runtime validation passed: - - Cold 261,888 prefill: `551.97s` - - Warm refill (16-token suffix on shared 261,872-token prefix): `10.76s` - - Cold throughput: `474.46 tok/s`; warm refill throughput: `24,342.68 tok/s` - - Real-token + token-range checks: passed - - Host RSS peak: `35.31 GiB`; Neuron active allocation peak: `~28 GiB`; high-water counter: `58 GiB` - -### 5.10 Block-table active-block-fill (necessary but not sufficient) - -- **Symptom:** Earlier hypothesis was that `block_table` had `-1` entries for the active suffix. -- **Fix attempted:** Fill active block ids from `slot_mapping // pa_block_size` before NKI dispatch. Aligned with AWS docs on `nisa.dma_copy` dynamic addressing. -- **Result:** Helped the 8K smoke test but did NOT fix the 256K case. The real bug was §5.9. - -### 5.11 Production envelope and fail-closed hardening - -- **Finding:** The bound-fix 256K artifact has strong validation evidence, but only for the exact serving envelope: 256K context, `pa_num_blocks=1025`, one `cte3072:pfx262144` bucket, `qwen_segcte256` segment size 512, batch/concurrency 1, backed prefix reads, non-KVP, and non-transposed K cache. -- **Risk:** Enabling Hybrid APC outside the blessed vLLM launcher could previously fall back to local prompt hashing or synthetic attention block refs. -- **Fix:** When `use_hybrid_apc_manager=True`, `Qwen35InferenceConfig` now defaults to requiring vLLM metadata and attention block refs, with local hash fallback disabled. Validation-only flows can still opt back into local fallback explicitly. -- **Risk:** The generic ModelWrapper used absolute Hybrid APC control positions (`args[25]`) for restore-active detection. -- **Fix:** Restore-active detection now reads from the final five Hybrid APC control args, so future pre-control extras do not silently misbucket CTE. -- **Risk:** `qwen_segcte256` still exposed KVP and transposed-K branches that were not validated for production and contained NKI 0.3-sensitive HBM output/intermediate patterns. -- **Fix:** `qwen_segcte256` now raises immediately for `kvp_offset`/KVP or `k_pre_transposed=True`. The validated production path remains the non-KVP, non-transposed K path used by `attention_base.py`. - ---- - -## 6. Validation Harness & Measurement Bugs - -### 6.1 TPOT measured from streamed content chunks, not tokens - -- **Symptom:** Reported TPOT was `~109 ms/chunk` at 16K context (with 16 generated tokens → only 8 streamed content chunks), masking real decode speed. -- **Fix:** [qwen36_chat_completion_context_bench.py](../../../../validation_scripts/qwen36_chat_completion_context_bench.py) now requests `stream_options: {"include_usage": true}` and computes `token_tpot_seconds` from `usage.completion_tokens`. Old chunk metric preserved as `content_chunk_tpot_seconds`. -- **Verification:** Corrected 16k pfx16k measurement: `~50-52 ms/token`, `~19-20 decode tok/s`. - -### 6.2 "Warm prefill" was actually full-prompt cache replay - -- **Symptom:** Sub-second warm runs were misinterpreted as refill speed. -- **Cause:** [qwen36_hybrid_apc_context_sweep.py](../../../../validation_scripts/qwen36_hybrid_apc_context_sweep.py) generated the exact same prompt twice — that's an exact cache hit, not a refill. -- **Fix:** Default warm mode now: warm shared prefix + suffix A, then measure shared prefix + suffix B. -- **Verification:** Corrected 16k partial refill: `0.91s` for a 16,368-token shared prefix → ~`18k tok/s` reuse rate. - -### 6.3 Sweep accepted dummy token `0` (`!!!!`) as "real" output - -- **Symptom:** Validator's `vocab_size=248044` check passed because token `0` was within range, masking the chunked-prefill output leak. -- **Fix:** Tighter validation: explicitly fail if all generated tokens equal the configured dummy id, regardless of vocab bounds. Then use `usage.completion_tokens` + tokenizer/AutoConfig vocab fallback for true range check. - -### 6.4 Hardcoded `seq_len=262144, pa_num_blocks=1024` for every tier - -- **Symptom:** Three-tier validation runner forced 256K cache shape on the 64K and 128K tiers, causing `NRT_RESOURCE`. -- **Fix:** [tmp_run_qwen256k_fp8_tierfix_validation.sh](../../../../tmp_run_qwen256k_fp8_tierfix_validation.sh) now uses per-tier `(seq_len, pa_num_blocks, tkg buckets)`. - -### 6.5 Chat wrapper passed `--pa-num-blocks` to server script (unknown arg) - -- **Symptom:** `start_vllm_server.sh` rejected `--pa-num-blocks`. -- **Fix:** Pass `--pa-num-blocks` only to offline benchmarks; server gets `--num-gpu-blocks-override` via the appropriate path. - -### 6.6 Memory sampler could hang the wrapper if benchmark never started - -- **Symptom:** With `--stop-when-no-match`, sampler waited forever if vLLM died during startup. -- **Fix:** Sampler ignores its own PID, handles SIGTERM/SIGINT to write summary JSON; wrapper explicitly stops sampler per phase rather than relying on regex disappearance. Sampler regex broadened to match vLLM server processes during startup. - -### 6.7 `start_vllm_server.sh` forced `ENABLE_PREFIX_CACHING=1` when `--enable-hybrid-apc` - -- **Symptom:** Couldn't test "Hybrid APC on, prefix-cache reads off" because flags were coupled. -- **Fix:** Split controls: `ENABLE_PREFIX_CACHING`, `ENABLE_HYBRID_APC`, `HYBRID_APC_DISABLE_UNBACKED_PREFIX_READS`, `HYBRID_APC_ENABLE_BACKED_PREFIX_READS`, and `QWEN36_HYBRID_APC_INSTALL_PATCH` are now independent. - -### 6.8 Validator's `vocab_size` check rejected legitimate model tokens - -- **Symptom:** Model emitted token `248068`, valid for the loaded model (`vocab_size=248320`) but above the tokenizer's base `vocab_size=248044`. -- **Fix:** Use `max(tokenizer.vocab_size, len(tokenizer), AutoConfig.vocab_size)` as the upper bound. - ---- - -## 7. Tooling, Sync, Shell, SSH - -### 7.1 Stale remote code (no `--context-encoding-bucket-pairs`) - -- **Symptom:** Remote compile script lacked sparse-pair CLI flag even though local repo had it. -- **Fix:** Sync the compile entrypoint along with runtime code; `bash -n` + `py_compile` checks before launching. - -### 7.2 `scp` multi-file → wrong directory - -- **Symptom:** Multi-file `scp` of mixed sources landed extra copies in the last directory. -- **Fix:** Use explicit per-file destinations or `rsync -R`. Cleaned the misplaced copies and removed them. - -### 7.3 Local zsh expanded `*` in remote command - -- **Symptom:** `ssh host "find /path/*"` failed locally because zsh tried to glob the path on the Mac. -- **Fix:** Quote remote command bodies; use single quotes around the SSH command argument. - -### 7.4 `rsync --info=stats2` rejected by macOS BSD rsync - -- **Fix:** Use portable `--stats`. - -### 7.5 SSH `Permission denied (publickey)` for EC2-to-EC2 transfers - -- **Symptom:** Source EC2 had no key for destination. -- **Fix:** Three options used at various times: - 1. SSH agent forwarding from local `trainium.pem`. - 2. Temporary ed25519 key created on source, authorized on destination, removed after transfer. - 3. `scp -3` relay through local (slow — avoid for large artifacts). - -### 7.6 `pkill -f` matched its own SSH command, killed the shell - -- **Symptom:** Cleanup SSH exited 255 with no output because broad `pgrep -f` pattern matched the SSH command line itself. -- **Fix:** Use explicit PIDs from prior status or narrower patterns; never use `pkill -f` patterns that could match the controlling shell. - -### 7.7 Remote `python` not on PATH - -- **Symptom:** Status/parsing commands failed with `python: command not found`. -- **Fix:** Use `python3` for remote helpers; activate Neuron venv for actual runtime work. - -### 7.8 Overlay venv missing PyTorch / `libneuronpjrt-path` - -- **Symptom:** Neuron 2.30 overlay venv had `nki 0.4` but no PyTorch; later, `torch_xla` failed to find the base venv's `libneuronpjrt-path` helper. -- **Fix:** Compile launcher adds base venv `site-packages` behind overlay, and base venv `bin` to `PATH`. - -### 7.9 Backgrounded shell ate `$ROOT` (`tee /run.log`) - -- **Symptom:** Wrapper backgrounded too broadly; nested var expansion broke; ended up writing to `/run.log` and python wasn't on PATH. -- **Fix:** Cleaner wrapper structure: start sampler separately as a tracked nohup PID; run benchmark in main subshell with explicit env activation. - -### 7.10 TRN2 SSH banner timeout / port unreachable during heavy compile - -- **Symptom:** SSH banner exchange timed out, then later TCP itself stopped. Local AWS CLI had stale credentials so couldn't inspect instance state. -- **Mitigation:** Use light periodic probes (not long live-tails) during heavy compiles; keep heartbeat automation as the resume signal. - ---- - -## 8. Lessons Codified in `AGENTS.md` - -Two operational rules added to the repo's [AGENTS.md](../../../../AGENTS.md): - -1. **Error-logging contract:** every error gets logged with what failed, exact error text, how we got there, hypothesis, fix, and verification — enough detail that another agent can reconstruct it. -2. **Measurement discipline:** - - TPOT must come from `usage.completion_tokens` (request `stream_options.include_usage`), not streamed content chunks. - - "Warm refill" requires a shared prefix + different suffix; identical prompts only measure exact cache hits. - - Record artifact, CTE buckets, prefix buckets, and whether backed prefix reads were enabled with every reported number. - ---- - -## Final State Summary - -| Tier | Artifact | Status | -|---|---|---| -| 16K | `cte512_768_1536_3072_pfx16k` | **Production-validated** — chat + multi-turn smoke pass | -| 64K | `pfx32k_64k_pa256` | **Loads + runs**, prefill + chat pass | -| 128K | `pfx128k_pa512` | **Loads + runs**, prefill + chat pass | -| 256K | `pfx256k_segcte512stream_qpack4_boundfix_pa1025` | **Validated only for the exact gated config** — cold `551.97s`, warm refill `10.76s`, real tokens validated | - -**Open work before general "production-ready":** repeat 256K runs (×3-5), full OpenAI server path test on `pfx256k`, multi-turn chat at long context, soak/load test, and fresh validation for any other bucket, KVP, transposed K cache, sliding-window, or multi-seq serving configuration. diff --git a/contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md b/contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md deleted file mode 100644 index 571b0266..00000000 --- a/contrib/models/Qwen3.6-27B/docs/HYBRID_APC_PRODUCTION_PLAN.md +++ /dev/null @@ -1,364 +0,0 @@ -# Qwen3.6 Hybrid APC Production Plan - -## Build Order - -```text -1. Production hybrid APC correctness -2. Dynamic CTE bucket serving -3. Block-size, bucket, and HBM tuning -4. GDN state dtype and memory optimization -5. Decode-side improvements -6. Kernel fusion and speculative decode -``` - -Do not start with FP8 recurrent cache, MTP, EAGLE, Medusa, flash decode, KV -tiling, or deeper GDN kernel fusion. Those add scheduler and rollback -complexity before the cache contract is correct. - -## Target Cache Object - -```text -HybridPrefixCheckpoint - cumulative_prefix_hash - token_ids_hash - cache_salt / tenant key - prefix_length_at_boundary - - attention: - per-attention-layer KV block refs - - gdn: - per-GDN-layer recurrent_state checkpoint - per-GDN-layer conv_state checkpoint - - metadata: - dtype - layout_version - model_revision - ref_count - last_access_time - valid_state_mask -``` - -The usable hit is the deepest cumulative-prefix boundary where all required -state exists: - -```text -usable_hit_len = - intersection( - attention_KV_full_block_hit, - all_GDN_recurrent_prefix_checkpoint_hits, - all_GDN_conv_prefix_checkpoint_hits - ) -``` - -If attention KV hits 16K but GDN state only hits 12K, suffix prefill must resume -from 12K. - -## Qwen3.6 GDN State - -At every reusable cumulative-prefix boundary, cache: - -```text -recurrent_state: [num_local_value_heads, key_dim, value_dim] -conv_state: [conv_dim, conv_kernel_size - 1] -``` - -Initial dtype policy: - -```text -attention KV: bfloat16 -GDN conv_state: bfloat16 -GDN recurrent_state: float32 -``` - -Conv state is small but correctness-critical. Recurrent state dominates GDN -cache memory and should remain FP32 until BF16 exactness is proven. - -## Restore Flow - -For prompt length `P` and hybrid hit length `H`: - -```text -cached prefix: tokens [0, H) -suffix prefill: tokens [H, P) -decode: tokens [P, ...) -``` - -Serving path: - -```text -1. vLLM hashes prompt blocks. -2. Hybrid APC computes usable H. -3. Restore attention block table for [0, H). -4. Restore GDN recurrent_state at H. -5. Restore GDN conv_state at H. -6. Send only suffix tokens [H, P) to Neuron CTE. -7. Position IDs start at H. -8. Attention suffix attends to cached KV plus new suffix KV. -9. GDN recurrence starts from restored recurrent_state. -10. GDN conv starts from restored conv_state. -11. Store new boundary checkpoints for newly completed blocks. -12. Decode uses final restored and updated state. -``` - -## Sprint Plan - -### Sprint 1: Correctness Foundation - -Build: - -```text -HybridAPCManager -GDN recurrent/conv prefix-boundary checkpoint cache -hybrid hit intersection -partial-prefix restore path -FP32 recurrent cache option -correctness tests -``` - -Success criteria: - -```text -warm full-prefix output == cold output -partial-prefix output == cold output -attention-only false hit cannot happen -concurrent requests do not leak state -``` - -Current v0 branch status: - -```text -implemented: - HybridAPCMetadataStore for cumulative-prefix checkpoint metadata - bounded model-side HybridGDNCheckpointCache tensor bank - model restore/commit slot inputs - use_hybrid_apc_manager initialization without the old guard - v0 launcher validation requiring checkpoint interval == block size - async prefix-caching bridge for scheduler-supplied restore/commit tensors - request finish/cancel lifecycle callbacks for checkpoint refcounts - Trainium exactness and HBM validation harness - -still required before production: - vLLM scheduler integration that computes cumulative-prefix hashes and slots - Trainium execution of cold/warm exactness harness on compiled artifacts - production cancellation/eviction callback wiring from vLLM events - long-context HBM sweep to choose checkpoint slot count and commit policy - larger production prefix buckets for 32K+ warm reuse -``` - -Production prefix-bucket plan: - -```text -Previous 256K FP8 artifact was correct only up to its compiled prefix bucket -coverage: - prefix_buckets = [256, 512, 1024, 2048, 4096, 8192, 16384] - -32K/64K/128K contexts can still run on the 256K artifact, but warm APC reuse -above 16K must replay the remainder. This is correct but slower. - -Production strategy is one sparse 2D CTE/prefix artifact, not two separate -models: - - dense fast path: - CTE buckets = [512, 768, 1536, 3072] - prefix buckets = [0, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] - - long-prefix fallback: - [CTE 3072, prefix 65536] - [CTE 3072, prefix 131072] - [CTE 3072, prefix 262144] - -The dense fast path is for common short/normal cached prefixes and preserves -prefill speed by avoiding unnecessary padding to 3072. The sparse long-prefix -fallback enables 64K/128K/256K prefix reuse without compiling the full CTE x -prefix Cartesian grid that triggers Neuron compiler tensorization failures. -``` - -Implementation notes: - -```text -compile flag: - --context-encoding-bucket-pairs ACTIVE:PREFIX ... - -runtime behavior: - Prefix-caching CTE bucket selection now chooses the smallest actual compiled - [active_tokens, prefix_tokens] pair that can serve the request, instead of - assuming every CTE bucket exists for every prefix bucket. - -serving behavior: - vLLM override config forwards context_encoding_bucket_pairs so loaded - artifacts use the same sparse matrix they were compiled with. -``` - -## Fixed Bug Record: Neuron Tensorization Failure on Full 2D Prefix Grid - -```text -What failed: - 256K FP8 full Hybrid APC compile with pfx256k and multiple CTE buckets. - -How we got there: - Host: ubuntu@16.26.202.235 - Script: - tmp_compile_qwen256k_fp8_full_cte512_768_1536_3072_pfx256k_hostlogits.sh - Key args: - --seq-len 262144 - --max-context-length 262144 - --cte-buckets 512 768 1536 3072 - --prefix-buckets 256 512 1024 2048 4096 8192 16384 32768 65536 131072 262144 - --weight-dtype fp8_full - --enable-prefix-caching - --enable-hybrid-apc - --enable-vllm-chunked-prefill - -Exact error: - NCC_ITIN902 TensorInitialization error: - AffineIV doesn't appear in params or loopnest - -Failed generated buckets: - bk9 = [CTE 512, prefix 65536] - bk10 = [CTE 512, prefix 131072] - bk21 = [CTE 768, prefix 65536] - bk22 = [CTE 768, prefix 131072] - -Root cause hypothesis: - HLO generation succeeds, then neuronx-cc fails inside internal tensorization - for some small-active-token / large-prefix-token 2D prefix-cache shapes. This - is a Neuron compiler lowering bug, not disk pressure and not an invalid model - config. - -Fix: - Stop compiling the full Cartesian product. Add explicit sparse - context_encoding_bucket_pairs and route runtime selection over the actual - compiled pair list. - -Mitigation shape set: - Dense fast path only up to 32K prefix for all production CTE buckets: - [512/768/1536/3072] x [0..32768] - Long-prefix fallback only on largest CTE bucket: - [3072, 65536], [3072, 131072], [3072, 262144] - -Verification: - Unit/config tests passed: - 38 local contrib tests passed - 86 remote Neuron-env focused tests passed - Sparse high-prefix probe compile started with 7 CTE HLOs and no NCC_ITIN902 - observed at HLO generation time; final NEFF compile result must still be - checked before treating the sparse artifact as production-ready. -``` - -## Fixed Bug Record: Invalid Fast Warm Prefill - -This bug is useful to showcase because the first symptom looked like excellent -performance, but the warm path was not executing the same model semantics as -cold prefill. - -```text -Symptom: - Warm prefill appeared sub-second, but cold/warm generated token IDs diverged. - Cold also leaked placeholder token IDs: - cold = [0, 0, 3817, 7840] - warm = [3817, 7840, 9197, 4590] - -Root causes: - 1. vLLM attention prefix hits could exceed the deepest GDN checkpoint that - was actually available. - 2. Scheduler metadata used request token counts that could include generated - tokens instead of prompt-only tokens. - 3. Incomplete chunked-prefill rows in the host-logits path could append - placeholder sampled IDs as real generated tokens. - -Fix: - 1. Cap vLLM prefix-cache reads to the largest GDN-backed checkpoint. - 2. Build Hybrid APC metadata from prompt-only length/token IDs. - 3. Mask incomplete chunked-prefill sampled IDs to -1 before vLLM appends - them to request state. - -Evidence after fix: - 8K cold/warm exactness passed: - cold = [3817, 7840, 9197, 4590] - warm = [3817, 7840, 9197, 4590] - repeat_exact = true - - Warm prefill became slower than the invalid shortcut, but correct: - cold ~= 15.26s - warm ~= 4.95s -``` - -### Sprint 2: Dynamic CTE Buckets - -Build: - -```text -multi-bucket CTE artifact path -runtime suffix bucket selection -262K TP=4 [256] artifact -block_size 128/256 comparison -``` - -Success criteria: - -```text -short prompts retain 1.5x-2.3x latency gain -262K TP=4 [256] loads -TP=4 beats TP=8 unless TP=4 cannot load -``` - -### Sprint 3: Memory and HBM Tuning - -Build: - -```text -GDN recurrent state slot accounting -eviction/ref-count policy -FP32 vs BF16 recurrent experiment -attention KV memory report -hybrid cache memory dashboard -``` - -### Sprint 4: Decode Optimization - -Build: - -```text -lower-overhead GDN state gather/scatter -decode microbenchmarks -batch-slot reuse optimization -possibly fused recurrent step -``` - -## Test Matrix - -Correctness: - -```text -cold vs warm exact token IDs -partial-prefix exact match -non-block-aligned shared prefix floors to full block -attention hit with missing GDN state falls back -conv-state restore failure test by zeroing conv state -multi-hit chat simulation -mixed cold/warm continuous batching -long-context warm hit at 128K and 262K -``` - -Performance: - -```text -Context length: 256, 512, 2K, 8K, 32K, 128K, 262K -Block size: 64, 128, 256 -CTE buckets: [256], [512], [256,512], [256,512,1024] -TP: 4, and 8 only if HBM/load requires it -Cache mode: no APC, attention APC only, hybrid APC -GDN dtype: recurrent FP32, recurrent BF16 experiment -Workloads: single request, repeated system prompt, chat, long-doc QA -``` - -Immediate Trainium experiments: - -```text -262K TP=4, block_size=256, CTE buckets [256] -262K TP=4, block_size=128, CTE buckets [256] -128K TP=4, block_size=128, CTE buckets [256,512] -128K TP=4, block_size=256, CTE buckets [256,512] -``` diff --git a/contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md b/contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md deleted file mode 100644 index 6c61706f..00000000 --- a/contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md +++ /dev/null @@ -1,2364 +0,0 @@ -# Qwen3.6 27B FP8 Tierfix Validation - 2026-05-26 - -This note records the 2026-05-26 TRN2 validation of the three prefix-tier FP8 -artifacts and the current blocker for 256K prefix serving. - -Raw result JSON is stored at: - -```text -profile_artifacts/qwen36_fp8_tierfix_validation_20260526/summary_partial_with_pfx256_failure.json -``` - -Remote validation root: - -```text -/home/ubuntu/validation_logs/fp8_256k/tierfix_validation_20260526T152617Z -``` - -Test host: - -```text -ubuntu@16.50.61.215 -instance: trn2.3xlarge -logical-neuroncore-config: 2 -``` - -## Artifact Results - -### 32K/64K Prefix Tier - -Artifact: - -```text -/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_64k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx32k_64k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx32k_64k_pa256_slots64_tkg32768_65536_async_20260526T132620Z_tierfix_pfx32k_64k -``` - -Compiled limits: - -```text -seq_len=65536 -max_context_length=65536 -pa_num_blocks=256 -pa_block_size=256 -prefix_buckets=[32768, 65536] -context_encoding_bucket_pairs=[[3072, 0], [3072, 32768], [3072, 65536]] -token_generation_buckets=[32768, 65536] -``` - -Prefill: - -| target tokens | cold prefill | cold TPS | warm refill | warm refill TPS | real tokens | -| --- | ---: | ---: | ---: | ---: | --- | -| 32768 | 60.596s | 540.77 | 5.480s | 5979.19 | pass | -| 65280 | 121.736s | 536.24 | 5.667s | 11519.06 | pass | - -Chat/decode: - -| target tokens | run | TTFT | TPOT | decode TPS | completion tokens | -| --- | --- | ---: | ---: | ---: | ---: | -| 32768 | cold | 59.913s | 79.67ms | 12.55 | 64 | -| 32768 | repeat | 5.549s | 78.36ms | 12.76 | 64 | -| 65280 | cold | 67.959s | 83.66ms | 11.95 | 64 | -| 65280 | repeat | 5.785s | 83.50ms | 11.98 | 64 | - -Runtime evidence: - -```text -vLLM reported GPU KV cache size: 65,792 tokens -max concurrency for 65,536 tokens: 1.00x -peak host RSS during chat: 34.04 GiB -``` - -### 128K Prefix Tier - -Artifact: - -```text -/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_128k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx128k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx128k_pa512_slots64_tkg131072_async_20260526T132620Z_tierfix_pfx128k -``` - -Compiled limits: - -```text -seq_len=131072 -max_context_length=131072 -pa_num_blocks=512 -pa_block_size=256 -prefix_buckets=[131072] -context_encoding_bucket_pairs=[[3072, 131072]] -token_generation_buckets=[131072] -``` - -Prefill: - -| target tokens | cold prefill | cold TPS | warm refill | warm refill TPS | real tokens | -| --- | ---: | ---: | ---: | ---: | --- | -| 130816 | 298.355s | 438.46 | 6.972s | 18762.74 | pass | - -Chat/decode: - -| target tokens | run | TTFT | TPOT | decode TPS | completion tokens | -| --- | --- | ---: | ---: | ---: | ---: | -| 130816 | cold | 298.550s | 173.88ms | 5.75 | 64 | -| 130816 | repeat | 7.250s | 173.41ms | 5.77 | 64 | - -Runtime evidence: - -```text -vLLM reported GPU KV cache size: 131,328 tokens -max concurrency for 131,072 tokens: 1.00x -peak host RSS during chat: 33.75 GiB -``` - -### 256K Prefix Tier - -Artifact: - -```text -/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T132620Z_tierfix_pfx256k -``` - -Compiled limits: - -```text -seq_len=262144 -max_context_length=262144 -pa_num_blocks=1024 -pa_block_size=256 -prefix_buckets=[262144] -context_encoding_bucket_pairs=[[3072, 262144]] -token_generation_buckets=[262144] -``` - -Result: failed during Neuron Runtime load before validation could run. - -Exact failure: - -```text -NRT_RESOURCE in nrt_load_util -Failed to allocate 1.000GB (alignment: 4.000MB, usage: shared scratchpad) -Failed to load NN: - .../context_encoding_model/_tp0_bk0/model.MODULE_c3eddc16a94d9c7dfe80+5c498585.neff -err: 4 -Failed to create logical core info for subgraph 0 to 1 -Failed to stage graph to NeuronCore -Failed to load collectives for model -``` - -TDRV memory table at failure: - -```text -per-HBM TOTAL: 22.056GB -Model Tensors: 12.052GB -Shared Scratchpad: 10.000GB -Failed next alloc: 1.000GB shared scratchpad -``` - -Retry with debug scratchpad placement disabled: - -```text -NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 -root=/home/ubuntu/validation_logs/fp8_256k/tierfix_validation_pfx256_dbg0_20260526T160225Z -``` - -Result: still failed with the same `NRT_RESOURCE` class. - -```text -per-HBM TOTAL: 22.056GB -Model Tensors: 12.052GB -Shared Scratchpad: 7.000GB on one logical core + 3.000GB on sibling -Failed next alloc: 1.000GB shared scratchpad -``` - -Probe with smaller runtime scratchpad page: - -```text -NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 -NEURON_SCRATCHPAD_PAGE_SIZE=512 -root=/home/ubuntu/validation_logs/fp8_256k/pfx256_pagesize512_probe_20260526T160430Z -``` - -Result: still failed. - -```text -NRT_RESOURCE in nrt_load_util -Failed to allocate 512.000MB (alignment: 4.000MB, usage: shared scratchpad) -per-HBM TOTAL: 23.056GB -Model Tensors: 12.052GB -Shared Scratchpad: 11.000GB -``` - -## Why 256K Prefix Fails - -The current 256K prefix artifact is not failing because GDN attention KV cache -needs a full-attention 256K cache. It fails earlier: Neuron Runtime cannot load -the 256K context-encoding NEFF because the compiled NEFF's model tensors plus -shared scratchpad exceed the usable HBM slice for that logical placement. - -AWS Neuron's device-memory documentation describes HBM usage categories such as -model tensors, shared scratchpad, non-shared scratchpad, DMA rings, and runtime -allocations. It also documents that scratchpad page size must be coordinated -between compile-time `NEURON_CC_FLAGS=--hbm-scratchpad-page-size=...` and -runtime `NEURON_SCRATCHPAD_PAGE_SIZE=...`; changing only runtime placement/page -size is not guaranteed to repair a NEFF whose compiled scratchpad layout is too -large. See: - -```text -https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-runtime/explore/device-memory.html -``` - -AWS Trainium2 documentation lists 96 GiB of device memory per Trainium2 chip, -but this validation shows the failing pfx256 CTE NEFF is constrained by the -per-HBM/logical-placement allocation shown in the TDRV table, not by aggregate -host RAM or the headline chip memory number. See: - -```text -https://awsdocs-neuron.readthedocs-hosted.com/en/latest/about-neuron/arch/neuron-hardware/trainium2.html -``` - -Current hypothesis: - -```text -The pfx256 context-encoding NEFF at [CTE 3072, prefix 262144] has too much -compiled tensor + shared scratchpad footprint for trn2.3xlarge LNC2 placement. -Runtime-only tweaks did not reduce that footprint enough. Fixing it requires a -new compile with lower pfx256 CTE scratchpad/tensor footprint, a different -tiling/page-size compile, or avoiding the pfx262144 CTE NEFF. -``` - -## Can We Use 128K Prefix and Infer 256K Context? - -Not with the current 128K artifact. - -The current 128K artifact is a 128K-total artifact: - -```text -max_context_length=131072 -seq_len=131072 -pa_num_blocks=512 -token_generation_buckets=[131072] -``` - -It cannot serve or decode a 256K context because the compiled position range, -KV capacity, and token-generation bucket stop at 131072. - -A separate 256K-total artifact with only a 128K prefix bucket is a valid next -mitigation to test: - -```text -seq_len=262144 -max_context_length=262144 -pa_num_blocks=1024 -prefix_buckets=[131072] -context_encoding_bucket_pairs=[[3072, 131072]] -token_generation_buckets=[262144] -``` - -That would be semantically valid for 256K context if it loads, but it changes -the caching behavior: - -```text -cached reusable prefix: up to 128K -remaining prompt suffix: replay/refill up to the requested context length -decode positions: up to 256K, because max_context_length and tkg are 256K -``` - -So 128K prefix is not a replacement for 256K context. It is a cache boundary -inside a 256K-capable artifact. It should be correct, but slower than true -pfx256 reuse for prompts where the reusable shared prefix is above 128K. - -Risk: - -```text -This still needs a 256K token-generation/KV-capable artifact. The pfx128 CTE -NEFF may avoid the pfx256 shared-scratchpad load failure, but the 256K decode -and PA footprint still must be compiled and load-tested before we call it -production-ready. -``` - -## Robust 256K Prefix Fix Under Test - -The robust fix is to keep the 256K prefix bucket but stop compiling the -prefix-attention CTE as one monolithic `[active_tokens, prefix_tokens]` score -tensor. - -Implementation: - -```text -src/neuronx_distributed_inference/models/config.py - NeuronConfig.prefix_cte_attention_chunk_size - -src/neuronx_distributed_inference/modules/attention/attention_base.py - NeuronAttentionBase.perform_prefix_prefill_chunked_prior() - -contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py - --prefix-cte-attention-chunk-size -``` - -Behavior: - -```text -If prefix_cte_attention_chunk_size is set and prior_len exceeds it, prefix CTE -attention streams cached-prefix K/V in fixed chunks and combines the chunks with -online softmax. This avoids materializing the full [Q, prefix] score tensor. -The compiled bucket can still be [CTE 3072, prefix 262144]. -``` - -Why this is the robust path: - -```text -The failed pfx256 compile produced an 11GB page-aligned scratchpad requirement -for the pfx256 context-encoding NEFF. 32K/64K prefix-tier artifacts already -compiled and loaded. Streaming pfx256 as eight 32K chunks should bound live -attention-score memory near the proven smaller prefix shapes while preserving -correct full-256K prefix semantics. -``` - -Compile probe started: - -```text -host: ubuntu@16.51.90.254 -pid: 59247 -artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_stream32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_stream32k_pa1024_slots64_tkg262144_async_20260526T164116Z_pfx256_stream32k -workdir: - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_stream32k_cte3072_pfx256k_stream32k_pa1024_tkg262144_20260526T164116Z_pfx256_stream32k -log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_stream32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_stream32k_pa1024_slots64_tkg262144_async_20260526T164116Z_pfx256_stream32k_compile.log -key args: - --seq-len 262144 - --max-context-length 262144 - --prefix-buckets 262144 - --context-encoding-bucket-pairs 3072:262144 - --token-generation-buckets 262144 - --pa-num-blocks 1024 - --prefix-cte-attention-chunk-size 32768 -``` - -Validation so far: - -```text -Local config test: - python3 -m pytest contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py \ - -k 'prefix_cte_attention_chunk_size or sparse_context_encoding_bucket_pairs_are_forwarded' - result: 2 passed - -Remote attention test: - NEURON_PLATFORM_TARGET_OVERRIDE=trn2 python -m pytest \ - test/unit/modules/attention/test_attention_base.py \ - -k 'prefix_prefill_chunked_prior or prefix_prefill_sharded_flash_attn or prefix_prefill_unsharded_flash_attn' - result: 6 passed - -Compile status at start: - HLO generation completed for context_encoding_model and token_generation_model. - neuronx-cc priority compilation started with no NCC_ITIN902/NRT_RESOURCE in - the main log at the time this note was updated. -``` - -## Error Log - -### Remote `rg` Missing - -```text -What failed: - Remote repo inspection command using `rg`. - -How it failed: - bash: line 1: rg: command not found - -How we got there: - Host ubuntu@16.50.61.215 did not have ripgrep installed. - -Hypothesis: - The TRN2 instance image lacks the local developer tooling installed on the - Mac workspace. - -Fix: - Switched remote inspection to `find`, `grep`, `sed`, and `python3`. - -Verification: - Remote config inspection completed and printed the 128K/256K neuron_config - limits recorded in this note. -``` - -### Launcher PID Redirection - -```text -What failed: - Initial background validation launcher. - -How it failed: - bash: line 1: ${PID}: ambiguous redirect - -How we got there: - A shell grouping/variable expansion issue while starting the long validation - command and writing the PID file. - -Hypothesis: - The PID variable was expanded in the wrong shell context. - -Fix: - Manually wrote the detected validation PID to: - /home/ubuntu/validation_logs/fp8_256k/tierfix_validation_20260526T152617Z/run.pid - -Verification: - The validation continued and produced the raw summary JSON stored in this - repo. -``` - -### Remote `python` Missing - -```text -What failed: - Remote JSON/config parsing helper invoked as `python`. - -How it failed: - bash: line 1: python: command not found - -How we got there: - The remote instance exposes Python as `python3`, not `python`. - -Hypothesis: - No `python` compatibility symlink on the remote image. - -Fix: - Reran the helper with `python3`. - -Verification: - Parsed artifact config fields successfully. -``` - -### Local Attention Unit Test Missing `torch_xla` - -```text -What failed: - Local focused attention unit test. - -Command: - python3 -m pytest test/unit/modules/attention/test_attention_base.py \ - -k 'prefix_prefill_chunked_prior or prefix_prefill_sharded_flash_attn or prefix_prefill_unsharded_flash_attn' - -How it failed: - ModuleNotFoundError: No module named 'torch_xla' - -How we got there: - The Mac workspace Python environment does not include torch_xla. - -Hypothesis: - Local environment is not the Neuron inference venv. - -Fix: - Synced the changed files to ubuntu@16.51.90.254 and reran in the Neuron venv. - -Verification: - Remote test passed with 6 selected tests after setting - NEURON_PLATFORM_TARGET_OVERRIDE=trn2. -``` - -### Remote Attention Unit Test Platform Override - -```text -What failed: - First remote focused attention unit test on ubuntu@16.51.90.254. - -How it failed: - RuntimeError: Unsupported Platform - r7i.24xlarge - If you want to compile on CPU, please supply a compiler target argument. - -How we got there: - The compile host is a CPU/cross-compile instance. Importing Neuron/NxD modules - without a platform override caused torch_neuronx to infer the host platform - instead of the target Trainium platform. - -Hypothesis: - Neuron unit tests that import NxD need NEURON_PLATFORM_TARGET_OVERRIDE when - running on non-Trainium compile hosts. - -Fix: - Reran with: - NEURON_PLATFORM_TARGET_OVERRIDE=trn2 - -Verification: - 6 selected attention prefix-prefill tests passed. -``` - -### 256K Prefix Runtime Load Failure - -```text -What failed: - 256K pfx256 artifact prefill/runtime load. - -How it failed: - NRT_RESOURCE in nrt_load_util: - Failed to allocate 1.000GB shared scratchpad - Failed to load context_encoding_model/_tp0_bk0/model...neff, err: 4 - -How we got there: - Artifact: - qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_... - Inputs: - seq_len=262144 - pa_num_blocks=1024 - length=261888 - CTE/prefix pair=[3072, 262144] - token_generation_buckets=[262144] - -Hypothesis: - The pfx256 context-encoding NEFF compiled tensor + shared scratchpad footprint - exceeds the usable per-HBM allocation for the logical NeuronCore placement. - -Fix attempted: - Retried with: - NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 - Then probed: - NEURON_RT_DBG_SCRATCHPAD_ON_SINGLE_CORE=0 - NEURON_SCRATCHPAD_PAGE_SIZE=512 - -Verification: - Both retries still failed with `NRT_RESOURCE`, so the remaining blocker is a - compiled NEFF footprint issue, not just a runtime placement knob. -``` - -### Python-Level 256K Prefix Chunking Did Not Reduce NEFF Memory - -```text -What failed: - Robust pfx256 mitigation probe using Python-level prefix attention chunking. - -How it failed: - The compile itself completed, but the context-encoding NEFF memory footprint - did not improve: - COMPILE_DONE - context HBM: 24.101GB - total page-aligned scratchpad: 11.000000GB - -How we got there: - Host: - ubuntu@16.51.90.254 - Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_stream32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_stream32k_pa1024_slots64_tkg262144_async_20260526T164116Z_pfx256_stream32k - Workdir: - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_stream32k_cte3072_pfx256k_stream32k_pa1024_tkg262144_20260526T164116Z_pfx256_stream32k - Inputs: - --seq-len 262144 - --max-context-length 262144 - --cte-buckets 3072 - --prefix-buckets 262144 - --context-encoding-bucket-pairs 3072:262144 - --token-generation-buckets 262144 - --pa-num-blocks 1024 - --prefix-cte-attention-chunk-size 32768 - -Hypothesis: - XLA/Neuron static tracing still lowered the Python chunk loop into a graph - with the same large flat prefix attention footprint. This confirms that - chunking must happen inside an NKI kernel or via a newer segmented CTE kernel, - not in regular PyTorch graph code. - -Fix or mitigation applied: - Do not treat the pfx256_stream32k artifact as the production fix. Web/docs - research points to Neuron 2.30 NKI Library `Attention Segmented CTE` and - `KV-Parallel Segmented Prefill` as the next production-grade path because - they process block KV/prefix cache in configurable segments inside the kernel. - -Verification: - Pending. Need either: - 1. Upgrade/overlay a Neuron 2.30 NKI Library containing segmented CTE and - wire prefix CTE to that kernel; or - 2. Write a custom NKI segmented prefix attention kernel if the library - kernel is unavailable in our runtime. -``` - -### Neuron 2.30 Segmented CTE Overlay Inspection - -```text -What failed: - First SSH inspection command after creating the Neuron 2.30 segmented CTE - overlay on ubuntu@16.51.90.254. - -How it failed: - The command exited 1 because it used `set -o pipefail` with: - find "$NKILIB_DIR" -maxdepth 5 -type f | grep -E "attention.*(seg|prefill|cte).*\.py$|kv.*prefill.*\.py$" - The `find` maxdepth/pattern missed the files under - src/nkilib_src/nkilib/core/attention, so `grep` returned no matches. - -How we got there: - Host: - ubuntu@16.51.90.254 - Overlay venv: - /home/ubuntu/venvs/neuron_230_segmented_cte - Source checkout: - /home/ubuntu/nki-library-2.30 - Branch: - 2.30_release - Installed Python packages: - nki==0.4.0+25940409122.gd30719f9 - neuronx-cc==2.25.3371.0+f524f7f8 - -Root cause: - Inspection-command bug, not an overlay setup failure. - -Fix: - Reran inspection with the overlay activated and direct Python imports. - -Verification: - Confirmed: - IMPORT_OK nkilib.core.attention.attention_segmented_cte - IMPORT_OK nkilib.core.attention.kv_parallel_segmented_prefill - attention_segmented_cte signature accepts block KV cache, block_tables, - prior_tokens, block_size, and prior_seg_size. -``` - -### Local Segmented CTE Search/Syntax Checks - -```text -What failed: - Local source search for k-cache transposition references. - -How it failed: - Command exited 2: - rg: src/modeling_qwen35.py: No such file or directory (os error 2) - -How we got there: - I searched `src/modeling_qwen35.py`, but this repository stores the Qwen model - file at: - contrib/models/Qwen3.6-27B/src/modeling_qwen35.py - -Root cause: - Wrong local path in the search command. - -Fix: - Reran with `rg --files` and then searched existing paths under `src` and - `contrib/models/Qwen3.6-27B`. - -Verification: - Found the relevant `k_cache_transposed` references and confirmed block KV - cache disables transposed K cache. - -What failed: - First local Python syntax command: - python -m py_compile ... - -How it failed: - zsh:1: command not found: python - -How we got there: - The local Mac shell exposes `python3` but not `python`. - -Fix: - Reran: - python3 -m py_compile ... - -Verification: - Syntax compile passed for: - src/neuronx_distributed_inference/modules/attention/attention_base.py - src/neuronx_distributed_inference/models/config.py - src/neuronx_distributed_inference/models/model_wrapper.py - contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py - contrib/models/Qwen3.6-27B/test/unit/test_qwen36_compile_fp8_config.py - Focused config tests passed: - 23 passed, 3 subtests passed -``` - -### Segmented CTE Overlay Wiring and Compile Launch - -```text -What failed: - Remote diff preview after syncing segmented CTE files to ubuntu@16.51.90.254. - -How it failed: - The command exited 141 because it ran `git diff ... | head -240` under - `set -o pipefail`; `head` closed the pipe and `git diff` received SIGPIPE. - -How we got there: - Files had already been installed into: - /home/ubuntu/inferentia-gdn-fused-noclamp-4340808 - The failing command was only a preview step after install. - -Root cause: - Shell preview mistake, not a sync failure. - -Fix: - Reran remote status/syntax checks without piping through `head`. - -Verification: - Remote `py_compile` passed for the synced files. - -What failed: - First remote focused tests in the Neuron 2.30 overlay venv. - -How it failed: - /home/ubuntu/venvs/neuron_230_segmented_cte/bin/python: - No module named pytest - -How we got there: - The overlay venv was intentionally minimal and only installed newer - nki/neuronx-cc. - -Root cause: - Missing test dependency in the overlay venv. - -Fix: - Ran focused unit tests in the base Neuron venv instead: - /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 - -Verification: - Remote tests passed: - 42 passed, 3 subtests passed - -What failed: - First overlay import check for the synced attention module. - -How it failed: - ModuleNotFoundError: No module named 'torch' - -How we got there: - The overlay venv had nki 0.4 / neuronx-cc 2.25 but did not inherit the base - Neuron venv's PyTorch/NxD packages. - -Root cause: - `python -m venv --system-site-packages` does not inherit packages installed - inside another venv. - -Fix: - Added a `.pth` file in the overlay site-packages pointing to: - /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/lib/python3.12/site-packages - -Verification: - Overlay then imported PyTorch, but exposed the next PATH issue below. - -What failed: - Overlay import after adding base site-packages. - -How it failed: - FileNotFoundError: - [Errno 2] No such file or directory: 'libneuronpjrt-path' - -How we got there: - torch_xla imported from the base venv site-packages, but the overlay - activation did not include the base venv `bin` directory on PATH. - -Root cause: - Base Neuron helper executables were not visible when using overlay Python. - -Fix: - Updated the compile launcher to export: - PATH="${NEURON_VENV}/bin:${BASE_NEURON_VENV}/bin:${PATH}" - -Verification: - Overlay import check passed: - TORCH 2.9.1+cu128 - NKI 0.4.0+25940409122.gd30719f9 - NEURONXCC 2.25.3371.0+f524f7f8 - SEGMENTED_KERNEL True - -What failed: - Potential segmented CTE compile-sample invalidity for - [active=3072, prefix=262144] with `pa_num_blocks=1024`. - -How it would fail: - The generated sample `slot_mapping` would write active KV at positions - 262144..265215, past the 256K cache capacity, before segmented CTE reads - active KV from block cache. - -How we got there: - Existing prefix CTE sampled `computed_context_lens=prefix_bucket`; this was - fine for flat `attention_cte` because active KV was passed separately, but - segmented CTE reads active KV from the updated block cache. - -Root cause: - The sample value for `computed_context_lens` was not constrained to - `max_context_length - active_bucket` for segmented CTE. - -Fix: - In `model_wrapper.py`, for context-encoding segmented CTE samples, keep the - bucket shape at 262144 but set the sample prior to: - min(prefix_bucket, max_context_length - n_active_tokens) - For the pfx256/cte3072 trace this is: - computed_context_lens=259072 - -Verification: - The segmented CTE compile got through both context HLOs and the - token-generation HLO without sample OOB or import errors. - -What was cleaned up: - Removed the two known-bad pfx256 probes before launching the new compile: - qwen36_27b_256k_..._prod_pfx256k_..._20260526T132620Z_tierfix_pfx256k - qwen36_27b_256k_..._prod_pfx256k_stream32k_..._20260526T164116Z_pfx256_stream32k - plus their `_nxd_model_workdir_*` directories. - -Why: - They were already proven not production-ready: - pfx256 tierfix hit runtime load NRT_RESOURCE. - pfx256_stream32k compiled but kept the same large HBM/scratch footprint. - -Verification: - Free disk on /mnt/trainium_artifacts increased from 35GB to 93GB. - -Current compile: - Host: - ubuntu@16.51.90.254 - PID: - 65885 - Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_segcte32k_pa1024_slots64_tkg262144_async_20260526T174252Z - Log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_segcte32k_pa1024_slots64_tkg262144_async_20260526T174252Z_compile.log - Key flags: - --context-encoding-bucket-pairs 3072:262144 - --prefix-cte-attention-backend segmented_cte - --prefix-cte-attention-segment-size 32768 - --pa-num-blocks 1024 - Status: - HLO generation completed for both context_encoding_model traces and the - token_generation_model trace. neuronx-cc compilation is running. -``` - -### Segmented CTE Compile Completed but pfx256 Footprint Still Has Flat Gather - -```text -What failed: - The pfx256 segmented CTE compile completed, but it did not eliminate the - high-footprint 256K-prefix context NEFF. - -How it failed: - Compile status: - COMPILE_DONE - Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_segcte32k_pa1024_slots64_tkg262144_async_20260526T174252Z - Context bucket summaries: - context_encoding_model/_tp0_bk0: - Total estimated HBM usage: 13.65GB - Total page-aligned scratchpad: 1.500000GB - context_encoding_model/_tp0_bk1: - Total estimated HBM usage: 24.10GB - Total page-aligned scratchpad: 11.000000GB - Token-generation summary: - token_generation_model/_tp0_bk0: - Total estimated HBM usage: 12.42GB - Total page-aligned scratchpad: 0.500000GB - -How we got there: - Host: - ubuntu@16.51.90.254 - Key compile args: - --context-encoding-bucket-pairs 3072:262144 - --prefix-cte-attention-backend segmented_cte - --prefix-cte-attention-segment-size 32768 - --pa-num-blocks 1024 - Overlay: - nki==0.4.0+25940409122.gd30719f9 - neuronx-cc==2.25.3371.0+f524f7f8 - -Evidence: - `neuron_config.json` in the artifact records: - prefix_cte_attention_backend=segmented_cte - prefix_cte_attention_segment_size=32768 - But `context_encoding_model/_tp0_bk1/log-neuron-cc.txt` still contains - large indirect loads from: - get_kv_by_layer_id/_get_block_cache_and_reshape_bhsd/aten.index_select - with cache-shaped tensors such as: - bfloat16 (1025, 65536) - That means the long-prefix trace still materialized the flattened block-cache - gather before/alongside the segmented CTE path. - -Root cause / hypothesis: - The integration still calls `kv_mgr.get_kv_by_layer_id(**kwargs)` before the - segmented CTE pre-update path. For prefix caching, that method gathers block - KV into flat BHSD prior tensors through `_get_block_cache_and_reshape_bhsd`. - Those flattened gathers remain in the HLO and dominate the 256K-prefix NEFF, - so using `attention_segmented_cte` later is not enough. - -Fix or next mitigation: - The robust fix is to add a true raw-block-cache prefix path for segmented - CTE: - 1. In context encoding when `prefix_cte_attention_backend=segmented_cte`, - do not call `kv_mgr.get_kv_by_layer_id` for prefix prior. - 2. Fetch raw per-layer block KV via `kv_mgr._fetch_cache(...)` or a clean - public wrapper. - 3. Pre-update active K/V into raw block KV. - 4. Call `attention_segmented_cte` with raw block KV, `active_block_table`, - and `computed_context_lens`. - 5. Return the updated raw block KV and skip the old flat prior path. - -Verification: - Not fixed yet. The completed artifact should not be treated as the pfx256 - production fix. It can be transferred only for confirmation, but based on the - compile footprint it is expected to have the same runtime-load risk as the - previous pfx256 artifact. -``` - -### Raw Block Segmented CTE Fix Applied - -```text -What failed: - Web/docs review found that the previous segmented CTE integration did not - match the official block-KV contract. The Qwen hybrid prefill path still - called `get_kv_by_layer_id`, which flattened prefix blocks before attention. - -How it failed: - The pfx256 segmented CTE artifact compiled, but `log-neuron-cc.txt` still - showed `_get_block_cache_and_reshape_bhsd/aten.index_select` in the pfx256 - context bucket and HBM reached 24.10GB per core with 11GB page-aligned - scratchpad. - -How we got there: - Branch: - codex/full-fp8-qwen36 - Backend: - prefix_cte_attention_backend=segmented_cte - Bucket: - context_encoding_bucket_pairs=3072:262144 - The base attention code had a segmented CTE call, but Qwen's hybrid path - pre-fetched `past_key_values` through `QwenHybridBlockKVCacheManager.get_cache` - and then used `perform_qwen_chunked_prefill` over flat selected prefix KV. - -Root cause / hypothesis: - Official Neuron docs say NxDI prefix caching uses block KV, but the default - prefix-caching flow gathers block KV into a flat layout before attention. - Neuron 2.30 adds `Attention Segmented CTE` and `KV-Parallel Segmented - Prefill` kernels specifically for block-based KV cache. Therefore the fix is - not another bucket shape; it is avoiding the flat gather entirely for the - segmented CTE path. - -Fix applied: - - Added `BlockKVCacheManager.get_raw_kv_by_layer_id()` to return block-layout - KV without `_get_block_cache_and_reshape_bhsd`. - - Changed `QwenHybridBlockKVCacheManager.get_cache()` so segmented context - prefix buckets return raw block KV for full-attention layers. - - Changed Qwen chunked prefill so `prefix_cte_attention_backend=segmented_cte` - pre-updates active K/V into raw block cache and calls - `attention_segmented_cte` with `active_block_table` and - `computed_context_lens`. - - Changed Qwen cache update to accept already-updated raw block KV and skip a - second block-cache update. - - Fixed the base attention segmented path so it no longer requires flat - `past_key_value` before dispatching to segmented CTE. - -Verification: - Pending local unit tests and a new pfx256 compile. Expected compile evidence - for success: - - no `_get_block_cache_and_reshape_bhsd/aten.index_select` in pfx256 - context HLO/logs; - - pfx256 context HBM below the per-core 24GB limit with materially smaller - scratchpad than the failed 24.10GB / 11GB artifact. -``` - -### Raw Block Segmented CTE Compile Blocked by head_dim=256 - -```text -What failed: - Fresh pfx256 raw-block segmented CTE compile failed during HLO generation. - -How it failed: - Host: - ubuntu@16.51.90.254 - PID: - 69740 - Log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_rawsegcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_rawsegcte32k_pa1024_slots64_tkg262144_async_20260526T183314Z_compile.log - Artifact target: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_rawsegcte32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_rawsegcte32k_pa1024_slots64_tkg262144_async_20260526T183314Z - Exact error: - AssertionError: error: failed to compile NKI kernel: - Collected 1 different diagnostics: - - [x2] error: assertion failed: [INTERNAL_ERROR] [NCC_INKI016] - Kernel validation exception: head_dim must be <= 128 (got 256). - Larger head_dim not yet supported. - Please check the validation - message and adjust kernel inputs accordingly - -How we got there: - Command launched `tmp_compile_qwen256k_fp8_full_prod_prefix_tier_hostlogits.sh` - with: - TIER_NAME=pfx256k_rawsegcte32k - PREFIX_BUCKETS_STR=262144 - PAIR_ARGS_STR=3072:262144 - CTE_BUCKETS_STR=3072 - TKG_BUCKETS_STR=262144 - PA_NUM_BLOCKS=1024 - PREFIX_CTE_ATTENTION_BACKEND=segmented_cte - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=32768 - NEURON_VENV=/home/ubuntu/venvs/neuron_230_segmented_cte - NKI_LIBRARY_SRC=/home/ubuntu/nki-library-2.30/src/nkilib_src - -Root cause / hypothesis: - Proven root cause for this failure: Neuron 2.30 `attention_segmented_cte` - hard-validates `head_dim <= 128`, while Qwen3.6 27B has attention - `head_dim=256`. - Evidence: - /home/ubuntu/nki-library-2.30/src/nkilib_src/nkilib/core/attention/attention_segmented_cte.py - contains: - kernel_assert(head_dim <= 128, ...) - The bundled NKI model test config for `qwen3_235b` uses `d_head=128`, so - the official segmented CTE Qwen coverage does not cover this 27B head_dim. - -Fix or mitigation: - The raw-block segmented CTE integration is correct structurally, but the - official kernel cannot support this model without a head_dim=256 variant. - Viable next options are: - 1. Build a Qwen-specific head_dim=256 segmented CTE kernel that accumulates - QK over two 128-wide D tiles before softmax, then computes PV over the - full 256-wide V. This is the robust fix if we require a true 256K prefix - bucket. - 2. Use the existing production-safe tier strategy with <=128K prefix buckets - and route 256K-context requests through a lower prefix bucket, accepting - extra refill work. - 3. Open/escalate an AWS Neuron issue requesting head_dim=256 support in - `attention_segmented_cte`. - -Verification: - Compile did not complete. Do not retry this exact raw-block segmented CTE - compile until the head_dim=256 kernel limitation is addressed. -``` - -### Qwen head_dim=256 Segmented CTE Kernel Bring-Up - -```text -What failed: - The first Qwen-specific segmented CTE prototype was a direct copy of the - Neuron 2.30 segmented CTE kernel with only the top-level head_dim validator - relaxed. - -How it failed: - Host: - ubuntu@16.51.90.254 - Remote repo: - /home/ubuntu/inferentia-gdn-fused-noclamp-4340808 - Probe: - Offline NKI compile_to_bir with q=(2,256,256), - k/v_cache=(8,1,256,256), block_size=256, prior_seg_size=512, - tp_q=True, tp_out=False, target=trn2. - Exact errors hit and fixed: - 1. dma_copy dst partition dimension 256 exceeds maximum 128 - at attention_segmented_cte_256.py load_kv_cache. - Cause: copied kernel still loaded K as one (head_dim, K_TILE) tile. - Fix: split K into low/high (128, K_TILE) tiles. - 2. unsupported expression on list comprehensions creating K tile pairs. - Cause: NKI specialization rejected Python list comprehensions. - Fix: build the list with explicit for/append meta-programming. - 3. failed to resolve name 'x::0.shape' from k_sbuf[0].shape. - Cause: split K tile entries are Python pairs, not NKI tensors. - Fix: read K_TILE_SIZE from k_sbuf[0][0].shape for head_dim=256. - 4. dma_copy dst partition dimension 256 exceeds maximum 128 on the - temporary non-transposed K block load. - Cause: temp was shaped (block_size, 128), and block_size=256 became - the partition dimension. - Fix: load each K block in 128-token by 128-dim chunks. - 5. dma_copy src/dst element mismatch src=32768 dst=16384. - Cause: source access pattern still selected full D=256 for a 128-wide - destination. - Fix: use HBM source pattern [[head_dim, 128], [1, 128]] for each - 128-token by 128-dim K chunk. - 6. dma_transpose Q shape mismatch: source D=256, destination D=128. - Cause: split Q source pattern used full D as the transposed extent. - Fix: keep token stride at ac.d but set the transposed D count to 128: - [[ac.d, num_f], [1, 1], [1, 1], [1, 128]]. - 7. reduce_one_batch batch_idx typed as object. - Cause: copied call used an older helper signature and passed output - tensors where Neuron 2.30 expects batch_idx/grp_start/grp_end. - Fix: call reduce_one_batch with batch_idx=0, grp_start=0, - grp_end=n_grps, d=head_dim, num_grps=n_grps, sb_p=sb_p. - -Fix implemented: - Added a Qwen-specific NKI package: - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/ - - The kernel keeps V and output on legal free dimensions and only splits the - QK contraction: - logits = Q_lo @ K_lo + Q_hi @ K_hi - - This follows the documented NKI matmul rule that contraction dimensions - larger than 128 must be accumulated through multiple nc_matmul writes to the - same PSUM tile. - -Verification: - Local syntax: - python3 -m py_compile attention_base.py attention_segmented_cte_256.py - fused_segmented_attention_256.py - - Remote syntax: - PYCOMPILE_OK on ubuntu@16.51.90.254 under - /home/ubuntu/venvs/neuron_230_segmented_cte with Neuron 2.30 NKI source. - - Remote NKI BIR probe: - BIR_OK for q=(2,256,256), k/v=(8,1,256,256), prior_seg_size=512. - - Remote production-shape NKI BIR probe: - BIR_Q3072_OK for q=(2,3072,256), k/v=(1024,1,256,256), - block_size=256, prior_seg_size=32768, pa_num_blocks=1024. - Reported scratch: - sb_scratch_sizes=[402724] - psum_scratch_sizes=[15360] - -Remaining work: - This validates NKI front-end/BIR legality for the target bucket geometry. - Full model compile and runtime numerical validation are still required before - calling the pfx256 artifact production-ready. -``` - -### pfx256 segcte256d32k Full Compile Failed on SBUF Scratch Allocation - -```text -What failed: - Full Qwen3.6 27B FP8 pfx256k compile with the Qwen head_dim=256 segmented - CTE kernel failed during neuronx-cc compilation of context_encoding_model. - -How it failed: - Host: - ubuntu@16.51.90.254 - PID: - 76788 - Log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte256d32k_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T191006Z_compile.log - Workdir: - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte256d32k_cte3072_pfx256k_pa1024_tkg262144_20260526T191006Z - Failing buckets: - context_encoding_model/_tp0_bk0 - context_encoding_model/_tp0_bk1 - Exact compiler error: - [INTERNAL_ERROR] [NCC_INLA001] Unhandled exception with message: - Allocated memory out of bound - {scratch_sb_for_inst__I-361}@SB<0,0>(128x402724) - #Internal DebugInfo: - Exit: - neuronx-cc returned non-zero exit status 70. - -How we got there: - Compile command used: - TIER_NAME=pfx256k_segcte256d32k - PREFIX_BUCKETS_STR=262144 - PAIR_ARGS_STR=3072:262144 - CTE_BUCKETS_STR=3072 - TKG_BUCKETS_STR=262144 - PA_NUM_BLOCKS=1024 - PREFIX_CTE_ATTENTION_BACKEND=segmented_cte - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=32768 - NEURON_VENV=/home/ubuntu/venvs/neuron_230_segmented_cte - NKI_LIBRARY_SRC=/home/ubuntu/nki-library-2.30/src/nkilib_src - -Root cause / hypothesis: - Proven: - The custom head_dim=256 NKI kernel is BIR-legal for the target shape, but - the backend rejects the generated context CTE kernel because its live SBUF - scratch allocation is too large: 128x402724. - Best current hypothesis: - The first head_dim=256 kernel keeps too many per-segment K/V and per-Q-group - attention buffers live in SBUF. Splitting D into two 128-wide K/Q tiles fixed - the head_dim validator, but doubled K-side live storage and still inherited - the upstream segmented-CTE allocation style that materializes too much segment - state at once. - -Additional probes: - Offline BIR probes after failure showed scratch is still high even with lower - segment sizes: - q=3072, segment=8192 -> sb_scratch_sizes=[206116] - q=3072, segment=4096 -> sb_scratch_sizes=[116064] - q=3072, segment=2048 -> sb_scratch_sizes=[107840] - q=3072, segment=512 -> sb_scratch_sizes=[107808] - q=512, segment=512 -> sb_scratch_sizes=[54208] - These are BIR-legal but still likely too high for backend SBUF placement. - -Fix or mitigation: - Do not retry the same pfx256 segcte256d32k compile. - The next robust kernel fix is to reduce live SBUF, not only segment length: - - stream K/V tiles through the QK and PV loops instead of holding an entire - prior segment in SBUF; - - allocate MM1/MM2 scratch per Q group or a small group window instead of - block_dim=[num_grps] for all 3072 active tokens; - - keep only the running softmax stats/output persistent across segments. - This is a second-stage kernel rewrite. The current kernel fixed the head_dim - problem but is not production-ready for pfx256 because of SBUF pressure. - -Verification: - The compile failed. No artifact was produced. The heartbeat monitor was - stopped after recording this failure. -``` - -### pfx256 Kernel Rewrite: Active CTE Streaming + Q-Pack Cap - -```text -What failed: - The first qwen_segcte256 kernel fixed head_dim=256 front-end legality, but - full pfx256 compile failed because the context CTE kernel needed an illegal - live SBUF allocation: - {scratch_sb_for_inst__I-361}@SB<0,0>(128x402724) - -How we got there: - Host: - ubuntu@16.51.90.254 - Branch: - codex/full-fp8-qwen36 - Files: - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py - Target compile shape: - CTE bucket 3072, prefix bucket 262144, PA blocks 1024, head_dim 256, - FP8 full model, hybrid APC, segmented CTE backend. - -Errors encountered while fixing: - 1. NKI front-end rejected list comprehensions inside the kernel. - Command: - Remote inline compile_kernel_to_nir probe with - q=(2,256,256), k/v=(8,1,256,256), prior_seg_size=512. - Exact pattern: - unsupported expression on list comprehensions for mm1_masked_row, - exp_sb_row, mm1_copy_row, mm1_affine_select_output_row, exp_tp_row, - and _repeat_ref. - Root cause: - NKI kernels do not accept those Python list-comprehension expressions. - Fix: - Replaced each list comprehension with an explicit for-loop and append. - Verification: - The same BIR probe passed: - BIR_SMALL_OK - sb_scratch_sizes=[30592, 30592] - psum_scratch_sizes=[9216, 9216] - - 2. First active-streaming BIR hit an out-of-bound tensor access. - Command: - Remote inline compile_kernel_to_nir probe with - q=(2,3072,256), k/v=(1024,1,256,256), block_tables=(1,1024), - prior_seg_size=512. - Exact error: - assertion failed: Out-of-bound access for tensor `unnamed` on dimension - 1: index 1 exceed dimension size of 1. - Called from fused_segmented_attention_256.py in _exp_impl(). - Root cause / hypothesis: - The active stream allocated exp/running partial-sum columns for one - 512-token chunk, but ac.seqlen_k_active_updated still described the full - 3072-token active range, so _exp_impl tried to address chunk index 1 in - a one-column buffer. - Fix: - Rebuild ac/atp per active stream chunk with - seqlen_k_active_updated=next_section_offset_active, while preserving the - global K position through SectionParams.kv_section_idx. - Verification: - The q=3072, segment=512 BIR probe advanced past _exp_impl and compiled. - - 3. A docs/inspection helper import failed while probing NKI internals. - Command: - Import nki.framework.torch_xla in - /home/ubuntu/venvs/neuron_230_segmented_cte. - Exact error: - FileNotFoundError: [Errno 2] No such file or directory: - 'libneuronpjrt-path' - Root cause / hypothesis: - Importing torch_xla through the overlay venv initialized torch_neuronx - without the base Neuron runtime path. - Fix: - Avoid that inspection path for BIR probes; import NkiTensor from - nki.language.tensor and shared_hbm from nki.language.buffers. - Verification: - BIR probes compiled with the direct NKI imports. - -Fix implemented: - The robust simple fix is not a larger prefix segment. It is a smaller live - working set: - - alias per-Q-group temporary SBUF buffers to one reusable group window; - - stream active CTE K/V through the same bounded K/V SBUF window used by - prior-prefix segments; - - keep only running max/sum/output persistent across active/prior segments; - - cap Q group packing to 4 groups for head_dim=256. - - The compile must use: - PREFIX_CTE_ATTENTION_BACKEND=segmented_cte - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 - CTE_BUCKETS_STR=3072 - PAIR_ARGS_STR=3072:262144 - PREFIX_BUCKETS_STR=262144 - PA_NUM_BLOCKS=1024 - -Verification: - Local syntax: - python3 -m py_compile \ - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/fused_segmented_attention_256.py \ - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py - - Remote syntax: - REMOTE_PYCOMPILE_OK on ubuntu@16.51.90.254 under - /home/ubuntu/venvs/neuron_230_segmented_cte. - - Remote BIR scratch results after the rewrite: - q=3072, prior_seg_size=512: - BIR_STREAMACTIVE_Q3072_SEG512_QPACK4_OK - sb_scratch_sizes=[31360, 31360] - psum_scratch_sizes=[9216, 9216] - - q=3072, prior_seg_size=1024: - sb_scratch_sizes=[35488, 35488] - - q=3072, prior_seg_size=2048: - sb_scratch_sizes=[43680, 43680] - - q=3072, prior_seg_size=4096: - sb_scratch_sizes=[60096, 60096] - -Conclusion: - For Trn2 head_dim=256 with the current NKI layout, prior_seg_size=512 is the - only verified segment size under the documented SBUF free-dimension limit - of 32767. The previous 32k segment path and the 1024+ segment probes remain - unsafe. Full model compile and runtime validation are still required before - marking the pfx256 artifact production-ready. -``` - -### pfx256 segcte512stream Full Compile Launched - -```text -What changed: - Launched the full model compile using the verified active-streaming kernel - shape instead of the failed 32k-segment kernel. - -Host: - ubuntu@16.51.90.254 - -Compile PID: - 84525 - -Artifact target: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z - -Workdir: - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_cte3072_pfx256k_pa1024_tkg262144_20260526T195604Z - -Log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z_compile.log - -PID file: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z_compile.pid - -Inputs / flags: - TIER_NAME=pfx256k_segcte512stream_qpack4 - PREFIX_BUCKETS_STR=262144 - PAIR_ARGS_STR=3072:262144 - CTE_BUCKETS_STR=3072 - TKG_BUCKETS_STR=262144 - SEQ_LEN=262144 - MAX_CONTEXT_LENGTH=262144 - PA_NUM_BLOCKS=1024 - OMIT_ZERO_PREFIX_PAIR=1 - PREFIX_CTE_ATTENTION_BACKEND=segmented_cte - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 - NEURON_VENV=/home/ubuntu/venvs/neuron_230_segmented_cte - BASE_NEURON_VENV=/opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 - NKI_LIBRARY_SRC=/home/ubuntu/nki-library-2.30/src/nkilib_src - -Initial status: - HLO generation completed for context_encoding_model and token_generation_model. - neuronx-cc started compiling token_generation_model. - -Profiling note: - NKI profiling cannot replace compile preflight because it requires an - executed NEFF/NTFF pair. The pre-compile gate for this issue is BIR scratch - sizing. After COMPILE_DONE, runtime validation should be run with: - NEURON_RT_INSPECT_ENABLE=1 - NEURON_RT_INSPECT_DEVICE_PROFILE=1 - NEURON_RT_INSPECT_OUTPUT_DIR=/inspect - Then identify the qwen_segcte256 NEFF, capture with profile-nth-exec=2, and - save neuron-profile/neuron-explorer summary-json beside the cold/warm - prefill, TTFT, TPOT, and memory results. -``` - -### Runtime Validation Attempt with Device Profiling Failed - -```text -What failed: - First pfx256 segcte512stream runtime validation on the TRN2 validation host - failed during context_encoding_model NEFF load. - -How it failed: - Host: - ubuntu@16.50.61.215 - Output root: - /home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_20260526T202502Z - Phase: - pfx256k_prefill - Status: - pfx256k_prefill.status = 1 - Exact runtime errors: - Failed to allocate 1.000GB (alignment: 4.000MB, usage: shared scratchpad) - on ND 0:NC 2 - device profiling is enabled, which requires 2.348GB of HBM per NC. Either - reduce the the model size to enable profiling, or disable profiling. - Failed to allocate virtual scratchpad! - Failed to load NN: - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_cte3072_pfx256k_pa1024_tkg262144_20260526T195604Z/context_encoding_model/_tp0_bk0/model.MODULE_dc595ea41a524c32e935+86f42f0e.neff - Failure: NRT_RESOURCE in nrt_load_util - RuntimeError: Engine core initialization failed. - -How we got there: - Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z - Validation flags: - seq_len=262144 - pa_num_blocks=1024 - cte_buckets=3072 - token_generation_buckets=262144 - context_encoding_bucket_pairs=3072:262144 - max_tokens=1 - length=261888 - Profiling environment enabled: - NEURON_RT_INSPECT_ENABLE=1 - NEURON_RT_INSPECT_DEVICE_PROFILE=1 - NEURON_RT_INSPECT_OUTPUT_DIR=/home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_20260526T202502Z/inspect - -Memory evidence: - Runtime table for the failing HBM group showed: - Model tensors: 12.052GB - Shared scratchpad: 6.000GB - Profiler buffers: 4.758GB total, 2.379GB per NC - Total shown on HBM group: 22.814GB - -Root cause / hypothesis: - Proven: - Device profiling itself adds enough HBM pressure to prevent the 256K context - NEFF from loading. The error explicitly names profiler buffers and says to - disable profiling or reduce model size. - Not proven: - This does not prove the artifact fails without profiling. The profiler - overhead is the immediate blocker for this attempt. - -Fix / mitigation: - Rerun the same pfx256 validation without NEURON_RT_INSPECT_DEVICE_PROFILE. - Keep memory sampling enabled via neuron_memory_sampler. If runtime validation - passes, profile a smaller context/shorter profile variant or capture profiling - from a reduced-shape NEFF because full 256K device profiling does not fit. - -Verification: - Pending rerun without device profiling. -``` - -### Runtime Validation Without Device Profiling Failed with DGE OOB - -```text -What failed: - The no-profile pfx256 segcte512stream runtime validation failed during the - 261888-token context prefill execution after the artifact loaded. - -How it failed: - Host: - ubuntu@16.50.61.215 - Output root: - /home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_noprofile_20260526T202721Z - Wrapper PID: - 21303 - Context sweep PID: - 21309 - Phase: - pfx256k_prefill - Log: - /home/ubuntu/validation_logs/fp8_256k/pfx256_segcte512_runtime_noprofile_20260526T202721Z/pfx256k_prefill.log - Exact runtime errors: - TDRV:exec_process_custom_notification nd0:nc2:h_model.id1006: - Received notification generated at runtime: failed to run scatter/gather - (indirect memory copy via scalar DGE), due to out-of-bound access. - model name = - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_cte3072_pfx256k_pa1024_tkg262144_20260526T195604Z/context_encoding_model/_tp0_bk0/model.MODULE_dc595ea41a524c32e935+86f42f0e.neff. - - NMGR:kmgr_exec_worker_do_work Async request 88 failed for model - .../context_encoding_model/_tp0_bk0/model.MODULE_dc595ea41a524c32e935+86f42f0e.neff - on vnc 1 with status 1006 - - NMGR:kmgr_async_exec_default_exec_status_callback Exec id 88 for model - 10006 on worker 1 failed with fatal status 1006... aborting. - - /opt/workspace/KaenaRuntime/kmgr/kmgr_async_exec.cc:34: - void kmgr_async_exec_default_exec_status_callback(...): - Assertion `0' failed. - - ERROR Engine core proc EngineCore_DP0 died unexpectedly, shutting down client. - -How we got there: - Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1024_slots64_tkg262144_async_20260526T195604Z - Runtime flags: - seq_len=262144 - max_model_len=262144 - pa_num_blocks=1024 - block_size=256 - gdn_checkpoint_interval=256 - max_gdn_checkpoint_slots=64 - cte_buckets=3072 - token_generation_buckets=262144 - context_encoding_bucket_pairs=3072:262144 - lengths=261888 - max_tokens=1 - suffix_tokens=16 - require_real_tokens=true - Runtime environment: - NEURON_RT_INSPECT_ENABLE=0 - NEURON_RT_INSPECT_DEVICE_PROFILE unset - Kernel/compile path: - PREFIX_CTE_ATTENTION_BACKEND=segmented_cte - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 - TIER_NAME=pfx256k_segcte512stream_qpack4 - -Memory evidence: - The artifact loaded before execution. The memory sampler showed the runtime - had dropped to present-only bookkeeping after the fatal DGE error, not a - NRT_RESOURCE allocation failure: - latest host RSS: about 1.08 GiB for qwen36_hybrid_apc_context_sweep.py - neuron present: about 6.46GB total - latest total bytes: 0 - This separates this failure from the earlier device-profiling HBM failure. - -Root cause / hypothesis: - Proven: - The 256K pfx artifact compiles and loads without device profiling, but the - context_encoding_model NEFF issues an out-of-bound scalar DGE access during - execution. - Best current hypothesis: - The qwen_segcte256 segmented CTE kernel has a runtime address calculation - bug for the actual long-prefix path. The likely fault is in the mapping of - block-table, prior segment, active segment, or kv_section_idx offsets for - the 261888-token request. BIR scratch sizing and compile legality did not - catch it because the address goes out of range only with real runtime block - tables and long-prefix execution. - -Fix / mitigation applied: - Stopped the failed validation wrapper and context sweep on ubuntu@16.50.61.215 - to free Neuron resources: - kill -TERM 21309 / children through wrapper PID 21303, then kill stale - sampler PID 21308. - -Next mitigation: - Do not retry the same pfx256 segcte512stream artifact as production. Build a - targeted runtime/addressing debug path for qwen_segcte256: - 1. Reproduce with a smaller debug prefix artifact or a reduced long-prefix - request that still uses segmented_cte address math. - 2. Instrument or assert the block table index, prior segment start, active - stream start, kv_section_idx, and max addressed block before DGE loads. - 3. Patch the segmented CTE offset mapping, then rerun BIR preflight and a - no-profile runtime prefill before enabling any profiling. - -Verification: - Validation did not complete. No prefill, TTFT, TPOT, or chat metrics were - produced for this artifact. -``` - -### Null-Block PA Count Mismatch Hypothesis for DGE OOB - -```text -Additional evidence: - The failing pfx256 artifact was compiled with: - pa_num_blocks=1024 - block_size=256 - max_context_length=262144 - Runtime vLLM logs showed: - Adding 1 to num_gpu_blocks_override (1024 -> 1025) to account for null - block allocation - User provided pa_num_blocks (1024) matching original - --num-gpu-blocks-override intent. Incrementing pa_num_blocks to 1025 to - match the increment for a null block in vllm. - -Why this matters: - For vLLM, the user-intended usable block count for 256K at block size 256 is - 1024. vLLM adds one reserved null block, so the physical block-KV cache needs - 1025 blocks. The current artifact was compiled as pa1024, so a block-table - value of 1024 can be legal to vLLM but out of bounds for the compiled NEFF's - raw block-KV cache. That matches the observed scalar-DGE OOB in - context_encoding_model. - -Root cause / hypothesis update: - Best current hypothesis is now a PA physical-block sizing mismatch, not - scratch/HBM pressure. The qwen_segcte256 kernel may still need address tests, - but the first robust/simple fix to try is compiling the artifact with 1025 - physical PA blocks while running vLLM with 1024 usable blocks. - -Fix applied to validation scripts: - Updated validation_scripts/qwen36_hybrid_apc_context_sweep.py and - validation_scripts/qwen36_offline_decode_bench.py so artifact pa_num_blocks - is treated as physical block count. When the artifact uses block KV or prefix - caching and has more blocks than the minimum usable request, validation passes - artifact_pa_num_blocks - 1 as vLLM's num_gpu_blocks_override. - -Next mitigation: - Compile a replacement pfx256 segcte512stream artifact with: - PA_NUM_BLOCKS=1025 - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 - CTE_BUCKETS_STR=3072 - PAIR_ARGS_STR=3072:262144 - PREFIX_BUCKETS_STR=262144 - Then validate it with user-usable pa override 1024 so vLLM adds the null - block back to 1025. -``` - -### PA1025 Relaunch Setup Errors and Correction - -```text -What failed: - First corrected relaunch attempt on ubuntu@16.50.61.215 used: - TIER_NAME=pfx256k_segcte512stream_qpack4_pafix - PA_NUM_BLOCKS=1025 - PREFIX_CTE_ATTENTION_BACKEND=segmented_cte - PREFIX_CTE_ATTENTION_SEGMENT_SIZE=512 - PAIR_ARGS_STR=3072:262144 - -How it failed: - The remote helper script was stale and hardcoded: - --pa-num-blocks 1024 - The resulting process PID 28426 was running a pa1024 compile even though the - environment requested PA_NUM_BLOCKS=1025. The log showed: - CONTEXT_TRACE_SHAPE ... "pa_num_blocks": 1024, "pa_min_blocks": 1024, - "pa_headroom_blocks": 0 - -How we got there: - The local helper had already been updated to include _pa${PA_NUM_BLOCKS} in - the artifact name and to pass --pa-num-blocks "${PA_NUM_BLOCKS}", but that - helper had not been synced to ubuntu@16.50.61.215. - -Fix / mitigation applied: - Stopped PID 28426 before useful compilation work continued, synced the local - helper to: - /home/ubuntu/inferentia-gdn-fused-noclamp-4340808/tmp_compile_qwen256k_fp8_full_prod_prefix_tier_hostlogits.sh - Verified the synced helper contains: - BASE=..._pa${PA_NUM_BLOCKS}_... - --pa-num-blocks "${PA_NUM_BLOCKS}" - -Verification: - Relaunch produced a _pa1025_ artifact name. -``` - -```text -What failed: - The next relaunch on ubuntu@16.50.61.215 failed before compilation started. - -Exact error: - qwen36_27b_compile_fp8.py: error: unrecognized arguments: - --omit-zero-prefix-pair - --prefix-cte-attention-backend segmented_cte - --prefix-cte-attention-segment-size 512 - -How we got there: - Remote repo branch was codex/full-fp8-qwen36 at 03e7e3a, but - contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py was - stale relative to the local full-FP8 branch work. The runtime modules already - had segmented_cte support, but the compile entrypoint did not expose the - required CLI flags. - -Fix / mitigation applied: - Synced the local compile entrypoint to: - /home/ubuntu/inferentia-gdn-fused-noclamp-4340808/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py - Verified with: - python3 -m py_compile contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py - grep for omit-zero-prefix, prefix-cte-attention, and segmented_cte. - -Verification: - Corrected relaunch started as PID 29224 with: - artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z - log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z_compile.log - The compile log now shows: - CONTEXT_TRACE_SHAPE ... "pa_num_blocks": 1025, - "pa_min_blocks": 1024, - "pa_headroom_blocks": 1, - "prefix_cte_attention_backend": "segmented_cte", - "prefix_cte_attention_segment_size": 512 - and then enters HLO generation for context_encoding_model. -``` - -```text -What failed: - The PA1025 relaunch at 20260526T205447Z reached HLO tracing but failed in - Python before neuronx-cc compilation. - -Exact error: - AttributeError: 'QwenHybridBlockKVCacheManager' object has no attribute - 'get_raw_kv_by_layer_id'. Did you mean: 'get_kv_by_layer_id'? - -Evidence: - PID/log: - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z_compile.pid - /home/ubuntu/validation_logs/fp8_256k/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205447Z_compile.log - Stack: - modeling_qwen35.py:get_cache -> self.get_raw_kv_by_layer_id(...) - torch.nn.Module.__getattr__ raised AttributeError. - -How we got there: - The Qwen model file and attention path expected the newer raw block-KV cache - accessor, but the remote - src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py - had not been synced with the matching full-FP8 branch changes. - -Fix / mitigation applied: - Synced the matching local cache/runtime files to ubuntu@16.50.61.215: - src/neuronx_distributed_inference/modules/kvcache/block_kv_cache_manager.py - src/neuronx_distributed_inference/models/config.py - src/neuronx_distributed_inference/models/model_wrapper.py - src/neuronx_distributed_inference/modules/async_execution.py - src/neuronx_distributed_inference/modules/autobucketing.py - src/neuronx_distributed_inference/modules/attention/attention_base.py - src/neuronx_distributed_inference/modules/attention/nki_kernels/ - Verified with py_compile and confirmed: - def get_raw_kv_by_layer_id(self, idx, kvcache_buffer=None, **kwargs) - -Verification: - Relaunched as PID 30174 with artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z - Current log shows: - CONTEXT_TRACE_SHAPE ... "pa_num_blocks": 1025, - "pa_headroom_blocks": 1, - "prefix_cte_attention_backend": "segmented_cte", - "prefix_cte_attention_segment_size": 512 - Finished generating HLO for context_encoding_model - Started loading module token_generation_model -``` - -```text -Operator error: - During the remote sync fix, one multi-file scp command targeted the attention - directory for all source files. It created extra inert copies under: - src/neuronx_distributed_inference/modules/attention/config.py - src/neuronx_distributed_inference/modules/attention/model_wrapper.py - src/neuronx_distributed_inference/modules/attention/async_execution.py - src/neuronx_distributed_inference/modules/attention/autobucketing.py - -Impact / hypothesis: - These files are not imported by the current attention package path, but they - are remote workspace clutter and should be removed after explicit approval or - during the next cleanup pass. - -Fix / mitigation applied: - Re-copied each file to its correct destination. No compile path depends on - the accidental files. -``` - -### PA1025 pfx256 Runtime Validation Failed with DGE OOB - -```text -What failed: - No-device-profile runtime validation of the corrected PA1025 pfx256k - segmented CTE artifact on TRN2 ubuntu@16.50.61.215. - -Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z - -Validation output root: - /home/ubuntu/validation_logs/fp8_256k/pfx256_pa1025_runtime_noprofile_20260527T034323Z - -Inputs / flags: - validation_scripts/qwen36_hybrid_apc_context_sweep.py - --lengths 261888 - --max-tokens 16 - --suffix-tokens 16 - --seq-len 262144 - --max-model-len 262144 - --cte-buckets 3072 - --context-encoding-bucket-pairs 3072:262144 - --token-generation-buckets 262144 - --async-mode - --block-size 256 - --gdn-checkpoint-interval 256 - --max-gdn-checkpoint-slots 64 - --gdn-recurrent-cache-dtype float32 - --gdn-conv-cache-dtype bfloat16 - --require-real-tokens - Device profiling was explicitly disabled: - unset NEURON_RT_INSPECT_ENABLE - unset NEURON_RT_INSPECT_DEVICE_PROFILE - unset NEURON_RT_INSPECT_OUTPUT_DIR - -Observed runtime context: - Engine loaded the compiled artifact successfully. - vLLM reported: - GPU KV cache size: 262,400 tokens - Maximum concurrency for 262,144 tokens per request: 1.00x - Prompt execution started for the 261888-token request. - -Exact error: - At 2026-05-27T03:53:20Z, the context_encoding_model NEFF repeatedly emitted: - TDRV:exec_process_custom_notification ... failed to run scatter/gather - (indirect memory copy via scalar DGE), due to out-of-bound access. - model name = - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_pafix_cte3072_pfx256k_pa1025_tkg262144_20260526T205813Z/context_encoding_model/_tp0_bk0/model.MODULE_30b568c5d3faaeced212+b0ee5af3.neff - The first repeated failures were on Neuron cores nc4/nc6 and then expanded - across other cores. - -Memory evidence: - This was not an NRT_RESOURCE/OOM failure. The memory sampler captured: - peak_host_rss_gib: 35.310752868652344 - peak_neuron_by_category_gib.total: 157.06698608398438 - peak_neuron_by_category_gib.present: 7.32489013671875 - Note: peak_neuron_total_gib in the sampler summary sums peak/present/total - categories and should not be used as a real HBM footprint. - -Root cause / hypothesis update: - The null-block PA mismatch was not the root cause for the 256K runtime - failure. Compiling with 1025 physical PA blocks fixed the compile/load shape - and provided physical capacity for the null block, but the actual long-prefix - segmented CTE path still generates an out-of-range scalar-DGE address at - runtime. The best current hypothesis is now a qwen_segcte256 address-mapping - bug for the pfx256k context_encoding bucket, likely in block-table indexing, - prior-segment offset, active-stream offset, or kv_section_idx mapping inside - the custom segmented CTE kernel. - -Fix / mitigation applied: - Stopped the failed validation run and sampler after the DGE OOB: - wrapper PID: 31812 - sampler PID: 31814 - context sweep PID: 31815 - EngineCore PID: 31872 - Verified those PIDs were no longer present afterward. - -Remaining blocker: - The PA1025 pfx256k artifact is not runtime-valid and is not production-ready. - Do not run OpenAI/server TTFT/TPOT validation on this artifact until the - segmented CTE 256K address calculation is fixed or replaced. - -Next mitigation: - Build a targeted qwen_segcte256 debug/fix path: - 1. Reproduce with a small diagnostic harness that exercises the same - segmented CTE addressing with controlled block_table values. - 2. Add bounds checks or debug-side assertions for physical block id, - kv_head/block offset, prior segment start, active segment start, and - kv_section_idx before the DGE loads. - 3. Patch the qwen_segcte256 NKI address mapping, then recompile the - pfx256k bucket and rerun no-device-profile runtime validation. -``` - -```text -Operator/status-check errors encountered during this validation: - -1. A status check command exited 127 because it used `python` in a remote - non-login shell where only the activated venv process had `python` on PATH. - The validation process itself was unaffected. Mitigation: subsequent status - parsing used `python3` or an activated venv. - -2. The first cleanup command used a broad pgrep pattern: - qwen36_hybrid_apc_context_sweep|VLLM::EngineCore|neuron_memory_sampler - and matched its own SSH-side shell command, causing the SSH cleanup command - to exit 255 before printing post-cleanup status. Mitigation: reran cleanup - with explicit known PIDs 31812, 31814, 31815, and 31872, then verified they - were no longer running. -``` - -### Segmented CTE Active Block-Table Fill Fix - -```text -What failed: - Follow-up investigation of the PA1025 pfx256k runtime DGE OOB found that the - segmented CTE kernel reads the active suffix K/V from the raw paged KV cache - through active_block_table. If active suffix logical block-table entries are - still unset, the NKI kernel can consume an invalid block id for scalar DGE. - -Evidence / how we got there: - Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z - Compile trace shape: - context_encoding_bucket_pairs=[[3072,262144]] - pa_num_blocks=1025 - pa_min_blocks=1024 - pa_headroom_blocks=1 - prefix_cte_attention_backend=segmented_cte - prefix_cte_attention_segment_size=512 - Runtime loaded the artifact with pa_num_blocks=1025, then failed inside the - context_encoding_model NEFF with: - failed to run scatter/gather (indirect memory copy via scalar DGE), - due to out-of-bound access - There were no debug `pad-pre`, `pad-post`, or `qwen-cte-call` lines in the - failed validation log because QWEN36_HYBRID_APC_DEBUG was not enabled. - -Root cause / best current hypothesis: - BlockKVCacheManager writes the active suffix K/V into the raw block cache by - slot_mapping. The qwen_segcte256 path then reads active K/V from the raw block - cache by active_block_table. For segmented CTE, active_block_table must contain - physical block ids for logical active-suffix blocks as well as prefix blocks. - If those active entries remain -1 or otherwise unset, the NKI kernel casts the - block table to uint32 and can form a huge scalar-DGE HBM offset. That matches - the observed runtime-only scalar DGE OOB after successful load. - -Fix / mitigation applied locally: - Patched: - src/neuronx_distributed_inference/models/model_wrapper.py - Added segmented-CTE-only input preprocessing in `_pad_prefix_caching_inputs`: - - derive active logical block positions from computed_context_lens + token - index - - derive active physical block ids from slot_mapping // pa_block_size - - fill those active logical block-table entries before masking/padding - - include active tokens when sizing the segmented CTE block table - This leaves the non-segmented attention_cte path unchanged. - -Test added: - test/unit/models/test_prefix_caching_bucket_selection.py - test_segmented_cte_padding_fills_active_block_table_from_slots - The focused case starts with block_table [[0, 1, 2, -1]], prefix_len=768, - suffix_len=48, pa_block_size=256, and slot_mapping in physical block 4. The - expected padded block table is [[0, 1, 2, 4]]. - -Local verification: - Command: - python3 -m py_compile src/neuronx_distributed_inference/models/model_wrapper.py test/unit/models/test_prefix_caching_bucket_selection.py - Result: - pass - -Local test environment errors: - Command: - python3 -m pytest test/unit/models/test_prefix_caching_bucket_selection.py -q - Result: - exit 2 during collection - Exact error: - ModuleNotFoundError: No module named 'neuronx_distributed_inference' - Mitigation: - reran with PYTHONPATH=src - - Command: - PYTHONPATH=src python3 -m pytest test/unit/models/test_prefix_caching_bucket_selection.py -q - Result: - exit 2 during collection - Exact error: - ModuleNotFoundError: No module named 'neuronx_distributed' - Root cause / hypothesis: - The local Mac environment lacks the Neuron/NxD Python dependency needed for - this test module. This is an environment dependency issue, not a syntax - failure; py_compile passed locally. - -Next verification: - Sync the patch to TRN2 ubuntu@16.50.61.215, run py_compile and the focused - pytest in the Neuron venv, then rerun a no-device-profile pfx256 validation - with QWEN36_HYBRID_APC_DEBUG=1. If the shorter debug validation passes, rerun - the original 261888-token validation against the same compiled artifact. -``` - -### Active Block-Table Fill Validation Results - -```text -What passed: - Remote syntax/unit validation on TRN2 ubuntu@16.50.61.215 after syncing: - src/neuronx_distributed_inference/models/model_wrapper.py - test/unit/models/test_prefix_caching_bucket_selection.py - contrib/models/Qwen3.6-27B/docs/QWEN36_FP8_TIERFIX_VALIDATION_20260526.md - -Command: - cd /home/ubuntu/inferentia-gdn-fused-noclamp-4340808 - source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate - PYTHONPATH=src python -m py_compile \ - src/neuronx_distributed_inference/models/model_wrapper.py \ - test/unit/models/test_prefix_caching_bucket_selection.py - PYTHONPATH=src python -m pytest \ - test/unit/models/test_prefix_caching_bucket_selection.py -q - -Result: - 35 passed, 46 warnings in 5.33s - -What passed at runtime: - Short debug validation with the PA1025 pfx256k artifact and the local - active-block-table fill patch completed without DGE OOB. - -Artifact: - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_pafix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260526T205813Z - -Output root: - /home/ubuntu/validation_logs/fp8_256k/pfx256_pa1025_activeblockfix_short_20260527T_test - -Inputs: - --lengths 8192 - --max-tokens 1 - --suffix-tokens 16 - --seq-len 262144 - --max-model-len 262144 - --cte-buckets 3072 - --context-encoding-bucket-pairs 3072:262144 - --token-generation-buckets 262144 - --async-mode - --block-size 256 - --gdn-checkpoint-interval 256 - --max-gdn-checkpoint-slots 64 - --gdn-recurrent-cache-dtype float32 - --gdn-conv-cache-dtype bfloat16 - --require-real-tokens - QWEN36_HYBRID_APC_DEBUG=1 - -Short-run evidence: - The debug trace showed the active/prefix block table now includes the active - physical block range. Examples: - prefix_len=6144, slot_mapping max=8447, block_table max=32 - prefix_len=6144, slot_mapping max=18687, block_table max=72 - The 8192-token run completed: - cold elapsed: 14.63105383799848s - warm elapsed: 4.818916980999347s - real_tokens_passed: true - -What still failed: - Full 261888-token validation with the same artifact and patch still failed - inside the context_encoding_model NEFF with scalar DGE OOB. - -Output root: - /home/ubuntu/validation_logs/fp8_256k/pfx256_pa1025_activeblockfix_full_20260527T0524Z - -Inputs: - --lengths 261888 - --max-tokens 16 - --suffix-tokens 16 - --seq-len 262144 - --max-model-len 262144 - --cte-buckets 3072 - --context-encoding-bucket-pairs 3072:262144 - --token-generation-buckets 262144 - --async-mode - --block-size 256 - --gdn-checkpoint-interval 256 - --max-gdn-checkpoint-slots 64 - --gdn-recurrent-cache-dtype float32 - --gdn-conv-cache-dtype bfloat16 - --require-real-tokens - Device profiling and QWEN36_HYBRID_APC_DEBUG were disabled for the full run. - -Exact error: - First repeated failures at run.log lines 2592+: - 2026-May-27 05:21:35.021738 ... ERROR TDRV:exec_process_custom_notification - nd0:nc6:h_model.id1005: Received notification generated at runtime: - failed to run scatter/gather (indirect memory copy via scalar DGE), - due to out-of-bound access. model name = - /mnt/trainium_artifacts/qwen_artifacts/_nxd_model_workdir_256k_fp8_full_prod_pfx256k_segcte512stream_qpack4_pafix_cte3072_pfx256k_pa1025_tkg262144_20260526T205813Z/context_encoding_model/_tp0_bk0/model.MODULE_30b568c5d3faaeced212+b0ee5af3.neff. - The same error appeared on nc4, nc5, nc6, nc7 and later nc0/nc1/nc2/nc3. - The runtime also reported: - TDRV:exec_request_process_errors [ND 0][NC 6] Out of bounds access on model ... - NMGR:dlr_exec_wait Execution completed with err: 1006. mode->h_nn=1008, lnc=2 - -Core dump evidence: - Neuron generated NRT_EXEC_OOB dumps: - /tmp/neuron-core-dump/dt-20260527-051233-cid-d99e36ea74c263ca - i-05d3f024966df11d5-nd0-nc4-pid-39738-tid-39861-lid-1 - i-05d3f024966df11d5-nd0-nc6-pid-39738-tid-39862-lid-2 - i-05d3f024966df11d5-nd0-nc2-pid-39738-tid-39863-lid-3 - -Memory evidence: - This was not a Neuron load OOM/NRT_RESOURCE failure. - Memory summary: - peak_host_rss_gib: 34.55003356933594 - peak_neuron_by_category_gib.present: 6.589611053466797 - peak_neuron_by_category_gib.total: 157.06698608398438 - As before, the sampler's peak_neuron_total_gib sums sysfs categories and is - not a single real HBM allocation. - -Root cause / hypothesis update: - The active-block-table fill is necessary and fixes a real input-prep hazard, - but it is not sufficient for the pfx256/261888 path. The remaining scalar DGE - OOB is likely inside qwen_segcte256 address generation for high prior segment - indices, for example: - - prior segment block-table offset when prefix_len approaches 256K - - the first/last partial-prior segment around a 512-token segment boundary - - segment index to block-table index arithmetic in the NKI kernel - - kv_section_idx or KV-head/block offset at high logical block ids - This is now confirmed as a kernel/addressing bug, not a PA1025 capacity issue - and not just missing active physical block ids. - -Fix / mitigation applied: - Stopped the failed full validation and sampler: - sampler PID: 39613 - wrapper bash PID: 39684 - context sweep PID: 39692 - EngineCore PID: 39738 - PID 39738 became a short-lived defunct EngineCore while neuron-dump wrote - NRT_EXEC_OOB dumps. No qwen36 sweep/sampler process remained afterward. - -Remaining blocker: - The PA1025 pfx256k segmented CTE artifact is still not runtime-valid for - 261888-token / 256K-context serving. It must not be called production-ready. - -Next mitigation: - Add high-prefix debug instrumentation or a CPU/NKI address simulator for - qwen_segcte256 and binary-search the failing prefix length with the pfx256 - artifact. The short 8K smoke is not enough; test lengths should bracket the - failure, e.g. 32768, 65536, 131072, 196608, 229376, and 261888, with debug - enabled only around the final failing CTE chunk. -``` - -```text -Operator errors during the active-block-table validation: - -1. The first full-run wrapper backgrounded too broad a shell command and lost - ROOT/PATH state. It printed: - tee: /run.log: Permission denied - bash: line 1: python: command not found - The validation did not start. An orphaned sampler PID 39303 was killed. - Mitigation: reran with explicit absolute output paths and separate sampler - launch. - -2. The first separate sampler launch quoted ROOT incorrectly inside nested - local/remote shell expansion. It printed: - mkdir: missing operand - bash: line 1: /sampler.pid: Permission denied - No validation ran from that command. Mitigation: relaunched sampler with - literal absolute paths. -``` - -### Root Cause Found: Final Partial Active Chunk Reads Past Block Table - -```text -What failed: - The PA1025 pfx256k artifact still emitted scalar DGE OOB at 261888 tokens even - after the Python active-block-table fill. The 8192-token smoke passed, which - meant the remaining bug was specific to high-prefix / end-of-context address - generation. - -Code path: - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/ - attention_segmented_cte_256.py - fused_segmented_attention_256.py - -Root cause: - In qwen_segcte256 active streaming, the compiled 3072-token CTE bucket is - split into six 512-token active stream sections: - active_stream_tokens = 512 - num_active_stream_sections = ceil(3072 / 512) = 6 - num_blocks_per_active_stream = 512 / 256 = 2 - - For the final real chunk of a 261888-token prompt: - prior_tokens = 261120 - real active_len = 768 - active_block_offset = prior_tokens // 256 = 1020 - - The compiled active-stream loop still loads all six bucket sections, so the - block-table offsets are: - 1020, 1022, 1024, 1026, 1028, 1030 - - A real pfx256 block_table has 1024 entries. The older internal padding only - padded to 1026 entries for the prior-segment one-past read: - padded_width = (1024 // 2 + 1) * 2 = 1026 - - Therefore active sections 4 and 5 can read block-table offsets 1028/1030, - outside the internally padded table. That exactly matches AWS Neuron's DGE - docs: scalar/vector DGE offsets must still resolve to valid tensor addresses, - otherwise runtime reports out-of-bound scatter/gather. - -Fix applied locally: - Patched `attention_segmented_cte_256.py` to pad the internal block table for - both hazards: - padded_width_for_prior = one extra prior segment - padded_width_for_active_stream = max_blocks_per_seq + seqlen_q // block_size - padded_width = rounded max of both - - For pfx256 cte3072 this pads from 1024 to 1036 entries, so out-of-range - compiled active-stream sections read zero block ids from the padded tail - instead of DGE-reading past the block table. Block id 0 is the existing null - block, so this matches the intended padding semantics. - -Why this aligns with docs: - The NKI/DGE docs allow dynamic/scalar-offset DMA patterns, but the program is - responsible for keeping the dynamic address inside the tensor. Padding the - source table before the scalar DGE access is the simple robust fix; relying on - masks after the DMA is too late because the OOB happens during the DMA - descriptor execution. - -Verification so far: - Local syntax: - python3 -m py_compile \ - src/neuronx_distributed_inference/modules/attention/nki_kernels/qwen_segcte256/attention_segmented_cte_256.py - result: pass - -Remaining work: - Sync to TRN2, run remote py_compile, recompile the pfx256 segmented CTE - artifact, then rerun the 261888-token no-device-profile validation. The old - PA1025 artifact cannot be fixed in place because this change is inside the - compiled NKI kernel. -``` - -### Bound-Fix PFX256 Runtime Validation Passed - -```text -Artifact: - /mnt/trainium_artifacts/qwen_artifacts/ - qwen36_27b_256k_fp8_full_lmheadbf16_hybrid_apc_prod_pfx256k_segcte512stream_qpack4_boundfix_nki_fusedstable_directsolve_hostlogits_b256_cte3072_pfx256k_pa1025_slots64_tkg262144_async_20260527T052822Z - -Validation root: - /home/ubuntu/validation_logs/fp8_256k/pfx256_boundfix_runtime_20260527T0552Z - -Inputs: - length: 261888 prompt tokens - max_tokens: 16 - seq_len/max_model_len: 262144 - cte/prefix pair: 3072:262144 - token_generation_bucket: 262144 - pa_num_blocks: 1025 - backend: segmented_cte - segment_size: 512 - profiling: disabled - -Result: - passed: true - real_tokens_passed: true - token_range_passed: true - non_dummy_generated_token_count: 48 - unique_generated_token_count: 41 - -Timings: - cold prefill+decode: 551.9684613459976s - prefix warmup prefill+decode: 551.5538088760004s - measured warm/refill+decode: 10.758389775000978s - cold effective prompt throughput: 474.46189110402327 tokens/s - warm/refill effective prompt throughput: 24342.67631839693 tokens/s - -Memory summary: - peak_host_rss_gib: 35.314510345458984 - peak_neuron_by_category_gib.present: 6.291294097900391 - peak_neuron_by_category_gib.total: 159.61318969726562 - -Notes: - The sysfs Neuron memory sampler aggregates categories and logical cores; the - `present` category is the most useful live resident counter from this sampler. - The larger `total` and `peak` aggregates are not single-device HBM usage. - -Monitor-side error encountered: - Command: - ssh ... 'python - < 262144). Running this sequence through the model will result in indexing errors - It had already completed 32768, 65536, and 131072 rows successfully. - The stuck child was manually terminated, so the suite recorded: - [2026-05-27T08:33:26+00:00] END server_context_bench rc=143 - How we got there: - validation_scripts/qwen36_chat_completion_context_bench.py was run with: - --lengths 32768,65536,131072,261888 --turns 8 --repeats 1 - The old prompt builder doubled filler repetitions until it exceeded the - target, which created a transient 426209-token chat-template probe for the - 261888-token target. - Root cause: - Validation harness bug, not a Neuron runtime/model failure. The prompt - builder used exponential overshoot probes that are too large near the - 262144-token model limit. - Fix: - Updated _make_messages in validation_scripts/qwen36_chat_completion_context_bench.py - to estimate filler repetitions from one-repeat token delta and correct - downward instead of doubling past the target. Synced the fixed script to - TRN2 and reran only server startup + server_context_bench. - Verification: - python3 -m py_compile validation_scripts/qwen36_chat_completion_context_bench.py - passed locally. - Corrected context bench passed: - 32768 target: prompt 32764, status 200, TTFT 57.0513s, completion 16 - 65536 target: prompt 65524, status 200, TTFT 66.4657s, completion 16 - 131072 target: prompt 131070, status 200, TTFT 132.5198s, completion 9 - 261888 target: prompt 261876, status 200, TTFT 319.2527s, completion 16 - -Memory summaries: - Primary server peak host RSS: 35.3254 GiB - Corrected server peak host RSS: 35.3525 GiB - Corrected server live Neuron `present` peak: 10.7526 GiB from sampler - -Monitor/tooling errors encountered: - write_stdin failed when attempting to interrupt old tail sessions: - stdin is closed for this session; rerun exec_command with tty=true to keep stdin open - This was a local monitoring-tool state issue. It did not affect remote - validation. The old remote suite had already exited and the corrected rerun - used a new tail session. - - A local sandboxed ps probe failed: - zsh:1: operation not permitted: ps - This was local sandboxing, not a repo or remote failure. Remote process - checks were done through ssh instead. - - The command used to stop the completed remote live tail returned ssh exit - code 255 with no stderr: - ssh ... 'pkill -f "tail -n 80 -F .*prod_readiness_boundfix_contextbench_capped_20260527T083606Z" || true' - Hypothesis: - pkill matched and terminated the remote tail/ssh session while the command - was still attached, so ssh reported disconnect as 255. - Verification: - The tail session then reported `Process exited with code 255`; validation - had already completed and the server had already shut down cleanly. -``` diff --git a/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch b/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch deleted file mode 100644 index a4ed7d11..00000000 --- a/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept.patch +++ /dev/null @@ -1,154 +0,0 @@ -From: Deepankar Singh -Subject: [PATCH] Qwen3.6-27B OpenAI server: consume all fused-spec accepted tokens per host loop - -The decode loop currently calls _token_scalar(out.tokens) which flattens the -output and returns only the first token. For non-spec inference this is -correct (out.tokens has shape (1,1)). For fused-spec MTP it discards all -accepted tokens after the first, causing the host to advance only one token -per Python iteration even when the device accepted multiple. - -Root cause of the observed 1.6x MTP gain vs the expected 2-2.5x: - - spec length = 2 in the artifact - - device returns N accepted tokens per forward (N in 1..2 typically) - - server keeps only tokens[0], requeues the rest implicitly by feeding - tokens[0] back as the next input - - effective speedup = 1 + (P_accept_2)*0.5 ~= 1.6x at high acceptance - -Fix: - - Add _accepted_tokens() that returns the prefix of in-vocab non-pad tokens - - Rewrite the decode loop as a while-loop that runs one forward per - iteration and commits ALL accepted tokens (up to max_tokens, stopping at - first EOS). - - Pre-allocate decode_ids / decode_position_ids / decode_attention_mask - (already done in the current code; preserved). - - Position-id update uses the position of the most recently committed - token: pos_value = prompt_tokens + len(new_ids) - 1. - -Expected result: decode 1.6x -> 2.0-2.4x on the same artifact, no recompile. - ---- - contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py | 75 +++++++++++++-------- - 1 file changed, 47 insertions(+), 28 deletions(-) - -diff --git a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py ---- a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py -+++ b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py -@@ -50,6 +50,7 @@ - return str(prompt) - - -+# Legacy: keep for prefill path which always returns a single token. - def _token_scalar(tokens: Any) -> int: - if hasattr(tokens, "detach"): - tokens = tokens.detach().cpu() -@@ -58,6 +59,36 @@ def _token_scalar(tokens: Any) -> int: - return int(tokens.reshape(-1)[0].item()) - - -+def _accepted_tokens( -+ tokens: Any, -+ vocab_size: int, -+ pad_id: Any = None, -+) -> List[int]: -+ """Return the prefix of legitimately-accepted tokens from a fused-spec output. -+ -+ For non-spec inference, ``out.tokens`` is shape (1, 1) and this returns a -+ 1-element list. For fused-spec MTP with speculation length K, ``out.tokens`` -+ is shape (1, K) where unused slots are padded with -1 or pad_token_id. -+ Scan left-to-right; the first out-of-vocab or pad value marks the boundary. -+ """ -+ if hasattr(tokens, "detach"): -+ tokens = tokens.detach().cpu() -+ if hasattr(tokens, "ndim") and tokens.ndim == 0: -+ v = int(tokens.item()) -+ if 0 <= v < vocab_size and (pad_id is None or v != pad_id): -+ return [v] -+ return [] -+ flat = tokens.reshape(-1).tolist() -+ accepted: List[int] = [] -+ for raw in flat: -+ v = int(raw) -+ if v < 0 or v >= vocab_size: -+ break -+ if pad_id is not None and v == pad_id: -+ break -+ accepted.append(v) -+ return accepted -+ -+ - class QwenOpenAIServer: - def __init__(self, args: argparse.Namespace): - self.args = args -@@ -187,8 +218,7 @@ class QwenOpenAIServer: - if first_token is None: - raise RuntimeError("prefill produced no token") - -- new_ids = [] -- current_token = first_token -+ new_ids: List[int] = [] - vocab_size = len(self.tokenizer) - raw_eos_id = self.tokenizer.eos_token_id - eos_ids = ( -@@ -196,6 +226,7 @@ class QwenOpenAIServer: - if isinstance(raw_eos_id, (list, tuple, set)) - else {raw_eos_id} - ) -+ pad_id = self.tokenizer.pad_token_id - decode_ids = torch.empty((1, 1), dtype=torch.int32) - decode_position_ids = torch.empty((1, 1), dtype=torch.int32) - decode_attention_mask = torch.ones( -@@ -203,21 +234,24 @@ class QwenOpenAIServer: - dtype=torch.int32, - ) - finish_reason = "length" -+ -+ # Bootstrap: commit the prefill token at position prompt_tokens. -+ if first_token < 0 or first_token >= vocab_size: -+ raise RuntimeError(f"prefill generated invalid token id: {first_token}") -+ new_ids.append(first_token) -+ if first_token in eos_ids: -+ finish_reason = "stop" -+ -+ # Decode loop: one forward per iteration, consume ALL accepted tokens. -+ # For non-spec, accepted is length 1. For fused-spec MTP, accepted may -+ # be length up to (speculation_length + 1). - with torch.no_grad(): -- for step in range(max_tokens): -- if current_token in eos_ids: -- finish_reason = "stop" -- break -- if current_token < 0 or current_token >= vocab_size: -- raise RuntimeError(f"model generated invalid token id: {current_token}") -- new_ids.append(current_token) -- if step == max_tokens - 1: -- break -- -- pos_value = prompt_tokens + step -- decode_ids[0, 0] = current_token -+ while len(new_ids) < max_tokens and finish_reason == "length": -+ last_token = new_ids[-1] -+ pos_value = prompt_tokens + len(new_ids) - 1 -+ decode_ids[0, 0] = last_token - decode_position_ids[0, 0] = pos_value - active_attention_mask = decode_attention_mask[:, : pos_value + 1] - out = self.model( - input_ids=decode_ids, - attention_mask=active_attention_mask, -@@ -226,7 +260,17 @@ class QwenOpenAIServer: - sampling_params=sampling_params, - return_dict=True, - ) -- current_token = _token_scalar(out.tokens) -+ accepted = _accepted_tokens(out.tokens, vocab_size, pad_id=pad_id) -+ if not accepted: -+ raise RuntimeError("model produced no accepted tokens in decode step") -+ for tok in accepted: -+ if len(new_ids) >= max_tokens: -+ break -+ if tok < 0 or tok >= vocab_size: -+ raise RuntimeError(f"model generated invalid token id: {tok}") -+ new_ids.append(tok) -+ if tok in eos_ids: -+ finish_reason = "stop" -+ break - elapsed = time.perf_counter() - t0 diff --git a/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md b/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md deleted file mode 100644 index d00b4d3c..00000000 --- a/contrib/models/Qwen3.6-27B/docs/patches/mtp_batched_accept_README.md +++ /dev/null @@ -1,118 +0,0 @@ -# MTP Batched Accept Fix - -## Problem - -The OpenAI-compatible server in `scripts/openai_compat_server.py` discards -fused-spec accepted tokens beyond the first. The decode loop calls -`_token_scalar(out.tokens)` which returns only `tokens[0]`, then feeds that -token back as the next input. Result: host advances 1 token per Python loop -iteration even when the device accepted multiple via MTP speculation. - -Observed effect on `qwen36_27b_128k_fp8_mtp_run2` artifact: -- Expected decode: 2.0-2.5x baseline (NVIDIA's published MTP gain for length=2) -- Actual decode: 1.6x baseline (44 tok/s vs 27 baseline) -- Gap is purely host-loop, not device compute - -## Fix - -Patch: `mtp_batched_accept.patch` - -Changes: -1. Add `_accepted_tokens(tokens, vocab_size, pad_id)` helper that scans the - fused-spec output tensor and returns the prefix of in-vocab non-pad tokens. -2. Rewrite the decode loop as a `while` loop with bootstrap + iterations: - - Bootstrap: commit `first_token` from prefill at position `prompt_tokens`. - - Each iteration: feed `new_ids[-1]` at position - `prompt_tokens + len(new_ids) - 1`, then commit all accepted tokens - returned by the device. - - Stop on EOS, max_tokens cap, or invalid token id. -3. No model recompile required. No NeuronConfig changes. - -## Apply - -From repo root on branch `codex/qwen36-mtp-vllm-apc`: - -```bash -git apply docs/patches/mtp_batched_accept.patch -# or, if line numbers shifted: -git apply --3way docs/patches/mtp_batched_accept.patch -``` - -Verify by inspection: - -```bash -grep -n "_accepted_tokens\|while len(new_ids)" \ - contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py -``` - -Should show the helper definition near the top and the new while-loop in the -decode path. - -## Validation gates (in order) - -Run against the existing `qwen36_27b_128k_fp8_mtp_run2` artifact. - -### Gate 1: Smoke -Math prompt returns 391 with coherent text. No invalid token errors. Same -behavior as before the patch. - -```bash -curl -s -X POST http://localhost:8000/v1/chat/completions \ - -H 'Content-Type: application/json' \ - -d '{"model":"qwen3.6-27b-128k-fp8-mtp","messages":[ - {"role":"user","content":"What is 17 * 23?"}], - "max_tokens":32}' -``` - -Expect output containing `391`. - -### Gate 2: Greedy parity -Same 5 fixed prompts before and after patch. Greedy decode (top_k=1). -Token-by-token output should be **identical** between pre-patch and post-patch -because the patch only changes how the host loop consumes the device output, -not the math. - -If mismatch: bug in `_accepted_tokens` (likely missing pad sentinel or -off-by-one). Investigate before measuring perf. - -### Gate 3: Decode tok/s -Same benchmarks as the MTP results doc: -- 32-token prompt, 128-token completion: expect decode tok/s ~50-60 (vs 41.6) -- 28-token prompt, 256-token completion: expect decode tok/s ~55-65 (vs 44.3) -- 3959-token prompt, 128-token completion: expect decode tok/s ~55-65 (vs 45.2) - -If decode is unchanged from previous MTP measurements: spec is not actually -accepting multiple tokens per forward. Verify by logging -`len(accepted)` distribution during a 200-token generation; expect mean ≥ 1.5. - -### Gate 4: Long-context coherence -16K-token prompt, 256-token completion. Output should be coherent and not -contain any invalid tokens. Same quality as pre-patch. - -## Expected speedup - -| Workload | Before patch | After patch | Mechanism | -|---|---:|---:|---| -| 32-tok / 128-out decode | 41.6 tok/s | **~55-65 tok/s** | Consume 2 accepted per forward | -| 28-tok / 256-out decode | 44.3 tok/s | **~55-65 tok/s** | Sustained spec acceptance | -| 4K / 128-out decode | 45.2 tok/s | **~55-65 tok/s** | Same | - -Combined with baseline v3 (27 tok/s) → MTP after patch (~55-65) is **2.0-2.4x -total decode speedup**, matching NVIDIA's published number for spec length=2. - -## What this does NOT do - -- Does not change prefill speed (still ~420 tok/s flat across contexts) -- Does not change model quality (same math, same tokens, same logits) -- Does not change vLLM bridge (custom OpenAI server only) -- Does not change cache management -- Does not require artifact recompile - -## Followups after this lands - -1. Tag artifact + branch as `qwen36-27b-mtp-v2` with the new tok/s numbers -2. Apply the same batched-accept logic to the vLLM-Neuron decode path (once - the v1 MTP registry gap is fixed) -3. Investigate speculation length=3 (currently length=2 in the artifact) -4. Measure acceptance rate distribution; if mean < 1.5, MTP head quality is - the limit, not the host loop From 5619ee067e5e210315bab469556df696e9610181 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Wed, 10 Jun 2026 22:32:53 +0530 Subject: [PATCH 3/3] Remove stale Qwen benchmark references --- contrib/models/Qwen3.6-27B/README.md | 26 +- contrib/models/Qwen3.6-27B/vllm/README.md | 305 ++++------------------ 2 files changed, 63 insertions(+), 268 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index f868a818..7558f93a 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -225,7 +225,10 @@ The DeltaNet forward path can be controlled via environment variables: ## Caveats -1. **BF16 HBM pressure at TP=4:** The pure BF16 model consumes nearly all HBM on trn2.3xlarge. Use the FP8/vLLM path for the validated 128K artifact, or a larger instance for additional batching/headroom. +1. **HBM pressure at TP=4:** The 27B text decoder is memory-constrained on + `trn2.3xlarge`. The current validated long-context path uses selective FP8 + weights while keeping sensitive KV, LM-head, gate, and GDN state paths in + BF16/FP32. Larger instances are recommended for batching/headroom. 2. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29). No library modifications needed -- runs on stock SDK 2.29 DLAMI (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). @@ -233,7 +236,9 @@ The DeltaNet forward path can be controlled via environment variables: 4. **Vision encoder runs on CPU:** The ViT cannot be placed on Neuron because HBM is fully consumed by the text decoder. This adds ~918ms latency per image. Future optimization: quantize text decoder to free HBM, or use larger instance. -5. **Compilation time:** The short-context BF16 path compiles in roughly 13 minutes. The validated 128K FP8/vLLM artifact takes longer because it includes long-context cache shapes and presharded checkpoints. +5. **Compilation time:** Long-context artifacts take substantially longer than + short smoke artifacts because they include long-context cache shapes, + segmented CTE graphs, and presharded checkpoints. 6. **+1 RMSNorm convention:** Qwen3.5/3.6 uses `output = norm(x) * (1 + weight)` for most RMSNorm layers, but DeltaNet internal norms use standard `output = norm(x) * weight`. The weight conversion handles this automatically, but custom weight loading must be aware of both conventions. @@ -245,21 +250,20 @@ The DeltaNet forward path can be controlled via environment variables: | seq_len | Path | Status | Notes | |---------|------|--------|-------| -| 128 | BF16 NxDI | **PASS** | BF16 baseline/quality checks | -| 256 | BF16 NxDI | **PASS** | BF16 benchmark bucket | -| 512 | BF16 NxDI | **PASS** | 4 DeltaNet chunks | -| 65,536 | FP8/vLLM | **PASS** | chunked prefill, quality, and state-reset validation | -| 131,072 | FP8/vLLM | **PASS** | compiled and served with 512-token CTE bucket | +| 16,384 | 256K native-chunk loadfix | **PASS** | 2,394.6 usage-accounted prompt tok/s | +| 242,864 usage-accounted | 256K native-chunk loadfix | **PASS** | 1,029.2 usage-accounted prompt tok/s | +| 253,899 estimated | 256K native-chunk loadfix | **PASS** | tokenizer-derived estimate for the same long-context run | -For production long-context serving on trn2.3xlarge, use the FP8/vLLM artifact -and 512-token context encoding bucket. Larger instances are recommended for -larger batches or additional serving headroom. +For production long-context serving on `trn2.3xlarge`, use the validated 256K +loadfix artifact contract: segmented CTE512, CTE bucket 2048, BF16 KV, BF16 LM +head, BF16 gates, FP32 GDN recurrent state, BF16 GDN conv state, and host +sampling. ## Compatibility Matrix | Instance | TP | LNC | Status | Notes | |----------|-----|-----|--------|-------| -| trn2.3xlarge | 4 | 2 | **PASS** | BF16 short-context and FP8 128K vLLM/APC validated | +| trn2.3xlarge | 4 | 2 | **PASS** | 256K native-chunk loadfix validation passed at 16K and long context | | trn2.12xlarge | 16 | 2 | Expected PASS | Untested, recommended for batching/headroom | ### SDK Configuration diff --git a/contrib/models/Qwen3.6-27B/vllm/README.md b/contrib/models/Qwen3.6-27B/vllm/README.md index 4c921efa..1ce3d9c1 100644 --- a/contrib/models/Qwen3.6-27B/vllm/README.md +++ b/contrib/models/Qwen3.6-27B/vllm/README.md @@ -4,8 +4,10 @@ This folder contains the first-pass vLLM integration helpers for the Qwen3.6-27B contrib model. The current goal is **vLLM serving through the Neuron/NxDI plugin** for the -validated Qwen3.6 artifact, including long prompts through vLLM's native -chunked-prefill scheduler. +validated coherent Qwen3.6 artifact. The validated fast long-context path uses +the compiled Neuron-native chunking contract captured in the artifact config; +the launcher should mirror that contract instead of relying on generic vLLM +chunk slicing. ## Which vLLM Neuron Package? @@ -136,71 +138,37 @@ contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ --port 8000 ``` -Precompiled artifact path: +Current 256K precompiled artifact path: ```bash contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-buckets 128,256,512 \ - --port 8000 -``` - -Cold-prefill bucket waste is the first performance target. CTE buckets must stay -128-aligned because the fused DeltaNet CTE path operates in 128-token chunks. -Use one of the explicit profiles when compiling artifacts: - -```bash -# Short-prompt latency ---cte-bucket-profile short # [128,256,512,1024] - -# General production ---cte-bucket-profile general # [256,512,1024,2048] - -# Long-context artifact ---cte-bucket-profile long # [4096,8192,16384,32768] - -# 262K load experiment ---cte-bucket-profile 262k # [256] -``` - -`--cold-zero-conv-fast-path` is only for a cold-only CTE artifact whose suffix -prefill always starts at position 0. Leave it disabled for APC or partial-prefix -serving because restored GDN conv state must be consumed exactly. - -Long-prompt precompiled artifact path: - -```bash -contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ - --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-buckets 256,512 \ + --compiled-artifacts /mnt/trainium_artifacts/qwen_artifacts/qwen36_256k_fp8_loadfix_lmheadbf16_gatesbf16_kvbf16_qkvnki_segmented_cte512_gdnseg512_cte2048_pfx256k_pa1025_slots64_20260608T195113Z_256k_loadfix_segcte2048_chatfix_hostsampling_kkt_hier_scan7 \ + --max-model-len 262144 \ + --seq-len 262144 \ + --cte-bucket 2048 \ --block-size 256 \ - --enable-vllm-chunked-prefill \ + --num-gpu-blocks-override 1024 \ --port 8000 ``` -Native vLLM prefix-cache experiment: +Prefix-cache experiment: ```bash contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-buckets 256,512 \ - --block-size 128 \ - --enable-vllm-chunked-prefill \ + --compiled-artifacts /mnt/trainium_artifacts/qwen_artifacts/qwen36_256k_fp8_loadfix_lmheadbf16_gatesbf16_kvbf16_qkvnki_segmented_cte512_gdnseg512_cte2048_pfx256k_pa1025_slots64_20260608T195113Z_256k_loadfix_segcte2048_chatfix_hostsampling_kkt_hier_scan7 \ + --max-model-len 262144 \ + --seq-len 262144 \ + --cte-bucket 2048 \ + --block-size 256 \ --enable-prefix-caching \ --gdn-checkpoint-interval 256 \ --hybrid-gdn-recurrent-cache-dtype float32 \ --hybrid-gdn-conv-cache-dtype bfloat16 \ --mamba-cache-mode all \ --mamba-ssm-cache-dtype float32 \ + --num-gpu-blocks-override 1024 \ --port 8000 ``` @@ -211,32 +179,17 @@ cumulative prefix hash. If native APC does not produce exact greedy matches and a clear warm-hit speedup, the next step is a hybrid APC path that restores those GDN checkpoints alongside attention KV. -For APC experiments, do not treat `256` as the only block size. It can be useful -for long-context amortization, but it is coarse for chat-style prefix reuse. -Run explicit sweeps at `64` and `128`; include `32` when hit granularity matters -enough to justify possible block-table/layout overhead. Keep the GDN checkpoint -interval separate from the attention block size. - -Immediate Trainium experiments: - -```text -262K TP=4, block_size=256, CTE buckets [256] -262K TP=4, block_size=128, CTE buckets [256] -128K TP=4, block_size=128, CTE buckets [256,512] -128K TP=4, block_size=256, CTE buckets [256,512] -``` - Production chat proxy: ```bash contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-bucket 512 \ + --compiled-artifacts /mnt/trainium_artifacts/qwen_artifacts/qwen36_256k_fp8_loadfix_lmheadbf16_gatesbf16_kvbf16_qkvnki_segmented_cte512_gdnseg512_cte2048_pfx256k_pa1025_slots64_20260608T195113Z_256k_loadfix_segcte2048_chatfix_hostsampling_kkt_hier_scan7 \ + --max-model-len 262144 \ + --seq-len 262144 \ + --cte-bucket 2048 \ --block-size 256 \ - --enable-vllm-chunked-prefill \ + --num-gpu-blocks-override 1024 \ --port 8001 ``` @@ -265,12 +218,11 @@ Offline long-prompt smoke: ```bash python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-bucket 512 \ + --compiled-artifacts /mnt/trainium_artifacts/qwen_artifacts/qwen36_256k_fp8_loadfix_lmheadbf16_gatesbf16_kvbf16_qkvnki_segmented_cte512_gdnseg512_cte2048_pfx256k_pa1025_slots64_20260608T195113Z_256k_loadfix_segcte2048_chatfix_hostsampling_kkt_hier_scan7 \ + --max-model-len 262144 \ + --seq-len 262144 \ + --cte-bucket 2048 \ --block-size 256 \ - --enable-vllm-chunked-prefill \ --chat \ --prompt "$(python - <<'PY' print('Summarize this document in one paragraph. ' + 'Neuron inference ' * 700) @@ -278,206 +230,45 @@ PY )" ``` -Offline token-exact prefix-cache validation: +Optional Hybrid APC validation should be artifact-specific. The acceptance gate +is strict: repeated greedy calls must produce identical output, warm-hit latency +should be materially lower than cold-fill latency, and GDN recurrent/conv state +must be proven exact alongside attention KV cache hits. Attention-only prefix +cache hits are not sufficient for this hybrid model. -```bash -python validation_scripts/qwen36_vllm_prefix_cache_offline.py \ - --repo-root /home/ubuntu/inferentia-gdn \ - --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-bucket 512 \ - --block-size 128 \ - --enable-vllm-chunked-prefill \ - --mamba-cache-mode all -``` +Current validation run on Trn2 with the 256K loadfix artifact: -Offline partial-prefix validation: - -```bash -python validation_scripts/qwen36_vllm_prefix_cache_partial_offline.py \ - --repo-root /home/ubuntu/inferentia-gdn \ - --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-bucket 512 \ - --block-size 128 \ - --enable-vllm-chunked-prefill \ - --mamba-cache-mode all -``` - -Server-side prefix-cache validation through the guarded proxy: - -```bash -python validation_scripts/qwen36_prefix_cache_validation.py \ - --base-url http://127.0.0.1:8000 \ - --model qwen3.6-27b-neuron-128k-fp8-mlp -``` - -The acceptance gate is strict: repeated greedy calls must produce identical -output, and warm-hit latency should be materially lower than cold-fill latency. -For hybrid Qwen3.6, prefix-cache validation is not complete until the GDN -recurrent/conv state behavior is proven, not just attention KV cache hits. - -Hybrid APC exactness and HBM harness: - -```bash -python validation_scripts/qwen36_hybrid_apc_validation.py exactness \ - --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_hybrid_apc \ - --seq-len 2048 \ - --cte-buckets 256,512 \ - --block-size 256 \ - --gdn-checkpoint-interval 256 \ - --enable-vllm-chunked-prefill - -python validation_scripts/qwen36_hybrid_apc_validation.py hbm \ - --context-lens 131072 262144 \ - --checkpoint-intervals 128 256 512 -``` - -Native APC validation run on Trn2 with the FP8 128K artifact: - -- server exact-repeat, `~10.8K` prompt tokens: `26.68s` cold to `1.67s` warm, - `16.0x` speedup, exact greedy text match; -- offline exact-repeat, token IDs exposed: `26.19s` cold to `2.38s` warm, - `11.0x` speedup, exact greedy token-ID match; -- offline partial-prefix reuse, token IDs exposed: `25.52s` no-cache target to - `1.70s` APC target after a different shared-prefix warmup request, `15.0x` - speedup, exact greedy token-ID match. -- server hardening, exact repeat: `25.38s` cold to `1.55s` warm, `16.35x` - speedup, exact text match; -- server hardening, cross-prefix reuse after unrelated prefix: `25.17s` cold to - `1.36s` warm, exact text match; -- shared-prefix concurrency at 1/2/4 requests returned all requested markers - exactly; the artifact still queues because it is compiled for `max_num_seqs=1`. - -Validation run on Trn2 with the FP8 128K artifact: - -- state-reset artifact: `/opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1`; -- OpenAI-compatible `/v1/chat/completions` behind the proxy passes focused - quality checks without callers passing `chat_template_kwargs`; -- repeated short-after-long validation passes after 32K and 64K requests, - confirming DeltaNet recurrent/conv state is reset for new requests; -- 32K and 64K needle retrieval prompts return all expected codes; -- measured prefill is `404-428 tok/s` from 512 through 64K prompt tokens; -- measured decode is `26.3-26.6 tok/s`; -- peak Neuron device memory is about `53.25 GB` decimal for the 64K eval. +- 16K native-chunk run: `16,374` prompt tokens, `6.8379s` TTFT, + `2,394.6 tok/s` usage-accounted, `pass=true`, thinking enabled. +- Long-context native-chunk run: `242,864` `usage.prompt_tokens`, `235.9819s` + TTFT, `1,029.2 tok/s` usage-accounted, `pass=true`, thinking enabled. +- The same long-context run has tokenizer-estimated prompt length `253,899`, + which gives `1,075.9 tok/s`; keep that separate from usage-accounted + throughput. +- `log_scan_empty.txt` contains no invalid-token, fallback, NaN, NRT, or + traceback markers. Raw `/v1/completions` prompts are not chat-templated and can pollute the hybrid state if sent directly to the backend. Keep the backend private and expose the proxy on the public port for production calls. -4K BF16 Hybrid APC boundary/server probes: - -```bash -# Artifact/config audit before spending a Trn2 run. This flags oversized PA -# blocks, low block headroom, strict-gate boundary pressure, and nki_chunked CTE. -python validation_scripts/qwen36_artifact_config_audit.py \ - /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_4096_bf16_hybrid_apc_nki_chunked_prefix4096_ctx2_tkg2_r7i_20260520T082342Z \ - --compile-log /home/ubuntu/validation_logs/hybrid_apc_real_tokens/qwen36_4k_bf16_hybrid_apc_nki_chunked_prefix4096_20260520T082342Z_compile.log - -# Boundary-aligned APC proof. Run this directly against vLLM or a proxy started -# with --allow-completions because exact token-ID prompt lengths are required. -python validation_scripts/qwen36_openai_boundary_apc_probe.py \ - --base-url http://127.0.0.1:8000 \ - --model-path /home/ubuntu/models/Qwen3.6-27B \ - --lengths 256,512,1024,2048,4096 \ - --repeats 3 \ - --require-prefix-cache-query \ - --output-jsonl /home/ubuntu/validation_logs/hybrid_apc_real_tokens/boundary_apc_probe.jsonl - -# Cold prefill ctx-batch utilization check. Compare --concurrency 1 and 2 with -# --unique-per-request to avoid warm-cache reuse. -python validation_scripts/qwen36_chat_completion_context_bench.py \ - --base-url http://127.0.0.1:8000 \ - --model /home/ubuntu/models/Qwen3.6-27B \ - --model-path /home/ubuntu/models/Qwen3.6-27B \ - --lengths 4096 \ - --turns 8 \ - --repeats 3 \ - --concurrency 2 \ - --unique-per-request \ - --no-stream \ - --output-json /home/ubuntu/validation_logs/hybrid_apc_real_tokens/chat_4k_concurrency2.json -``` - -4K BF16 compile controls for the current investigation: - -```bash -# Single-request cold-prefill latency control: smaller PA blocks, usable block -# headroom, and fused DeltaNet CTE. Use a fresh compiled path and workdir. -python contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py \ - --repo-root /home/ubuntu/inferentia-gdn-experimental \ - --model-path /home/ubuntu/models/Qwen3.6-27B \ - --compiled-path /mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_4096_bf16_hybrid_apc_fused_block32_ctx1 \ - --base-compile-work-dir /mnt/trainium_artifacts/qwen_artifacts/_work_qwen36_4k_fused_block32_ctx1 \ - --weight-dtype bf16_control \ - --seq-len 4096 \ - --max-context-length 4096 \ - --cte-buckets 256,512,1024,2048,4096 \ - --prefix-buckets 4096 \ - --block-size 32 \ - --pa-headroom-blocks 64 \ - --tp-degree 4 \ - --logical-nc-config 2 \ - --max-num-seqs 1 \ - --ctx-batch-size 1 \ - --skip-warmup \ - --enable-prefix-caching \ - --enable-hybrid-apc \ - --enable-vllm-chunked-prefill \ - --deltanet-cte-backend fused \ - --gdn-checkpoint-interval 32 \ - --max-gdn-checkpoint-slots 160 \ - --hybrid-apc-require-vllm-metadata \ - --hybrid-apc-enable-backed-prefix-reads -``` - -The `block_size=32` control follows Neuron's prefix-cache performance guidance, -but it also increases the number of prefix boundaries the strict Hybrid APC gate -must prove. Without boundary chunk commits, a full 4096-token prompt has 128 -possible attention-hit boundaries at block size 32, so `max_gdn_checkpoint_slots` -must be sized accordingly or the safe gate will keep skipping APC reads. - ## Offline Smoke ```bash python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ - --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ - --max-model-len 131072 \ - --seq-len 131072 \ - --cte-buckets 128,256,512 \ + --compiled-artifacts /mnt/trainium_artifacts/qwen_artifacts/qwen36_256k_fp8_loadfix_lmheadbf16_gatesbf16_kvbf16_qkvnki_segmented_cte512_gdnseg512_cte2048_pfx256k_pa1025_slots64_20260608T195113Z_256k_loadfix_segcte2048_chatfix_hostsampling_kkt_hier_scan7 \ + --max-model-len 262144 \ + --seq-len 262144 \ + --cte-buckets 2048 \ --chat \ --prompt "What is 17 * 23? Answer with the number only." ``` ## Next Milestone -For cold-prefill latency, fix bucket waste before speculative decode or cache -quantization. The serving entrypoints now support multi-bucket CTE artifacts, -text-only CTE inputs, compact CTE masks, context-batch profiles, and attention -tile overrides. - For warm-prefix production APC, the required contract remains a unified prefix-cache object whose attention KV, GDN recurrent state, and GDN conv state -are jointly addressable, evictable, restorable, and exact under continuous -batching. - -Recommended order: - -1. Dynamic CTE buckets: start with `[128,256,512]` for 2K short-prompt tests, - `[256,512]` for 128K, and `[256]` for the 262K TP=4 load experiment. -2. Fused GDN CTE path validation: qwen chunked-prefill should use fused - DeltaNet with restored initial state by default. -3. Text-only CTE and compact-mask validation: no full dummy vision reductions - and no dense 4D causal masks in normal text serving. -4. Hybrid APC exactness: cold vs warm greedy token IDs, partial-prefix reuse, - multi-hit chat history, continuous batching movement, and eviction pressure. -5. Attention block-size sweeps at `64` and `128`, with `32` included for - granularity-sensitive chat workloads. -6. FP8 KV/cache only after the BF16/FP32 baseline is exact. -7. MTP/spec decode after recurrent-state rollback semantics are explicit. +are jointly addressable, evictable, restorable, and exact. Speculation, FP8 +cache variants, resident row-IO experiments, and continuous-batching extensions +are intentionally outside this baseline contribution.