12 changes: 11 additions & 1 deletion scripts/deepseek.py
@@ -590,7 +590,17 @@ def __init__(
         self.meta = DeepSeekV3Meta(
             config, max_tokens=max_tokens, dtype=torch.float16
         )
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
+        # Try loading the tokenizer; fall back to the official DeepSeek-V3 tokenizer if the local one fails
+        try:
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                model_dir_path, trust_remote_code=True
+            )
+        except Exception as e:
+            print(f"Warning: Failed to load tokenizer from {model_dir_path}: {e}")
+            print("Falling back to deepseek-ai/DeepSeek-V3 tokenizer...")
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                "deepseek-ai/DeepSeek-V3", trust_remote_code=True
+            )
     else:
         raise ValueError("Unsupported model architecture")

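Because the new fallback can silently substitute a different vocabulary than the one the local checkpoint was trained with, a quick smoke test of the fallback path is cheap insurance. A minimal sketch (the model id comes from the diff; the sample string is arbitrary):

```python
from transformers import AutoTokenizer

# Load the fallback tokenizer exactly as the except branch does.
tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)

# Round-trip a sample string to confirm the tokenizer is usable.
ids = tok.encode("Hello, DeepSeek!")
print(len(ids), "tokens ->", tok.decode(ids, skip_special_tokens=True))
```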
178 changes: 149 additions & 29 deletions src/models/deepseek_v3/deepseek_v3.cpp
@@ -138,11 +138,39 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc
         max_seq_len = std::max(max_seq_len, size_t(seq_len));
         max_total_len = std::max(max_total_len, size_t(total_len));
     }
+    // Buffers for the original (non-absorb) mode - kept for comparison
     auto full_k_buf = Tensor::buffer(dt_logits, {max_total_len, nh * d_qk}, rsrc.memory_pool);
     auto kv_b_buf = Tensor::buffer(dt_logits, {max_total_len, nh * (d_nope + d_v)}, rsrc.memory_pool);
     auto attn_score_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
     auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, d_v}, rsrc.memory_pool);

+    // Check if all requests are decode steps (seq_len == 1) for the decode-specific optimization
+    bool is_all_decode = true;
+    for (uint32_t req = 0; req < nreq; req++) {
+        if (req_lens[req] != 1) {
+            is_all_decode = false;
+            break;
+        }
+    }
+
+    // ============== ABSORB MODE BUFFERS ==============
+    // wkv_b dequantized weight: [r_kv, nh * (d_nope + d_v)]
+    auto wkv_b_dequant = Tensor::buffer(dt_logits, {r_kv, nh * (d_nope + d_v)}, rsrc.memory_pool);
+
+    std::shared_ptr<Tensor> q_absorbed_buf, weighted_kv_buf;
+    if (is_all_decode) {
+        // DECODE OPTIMIZATION: smaller buffers suffice when seq_len == 1
+        q_absorbed_buf = Tensor::buffer(dt_logits, {1, nh, r_kv}, rsrc.memory_pool);
+        weighted_kv_buf = Tensor::buffer(dt_logits, {1, nh, r_kv}, rsrc.memory_pool);
+    } else {
+        // Prefill: full-size buffers are needed
+        q_absorbed_buf = Tensor::buffer(dt_logits, {max_seq_len, nh, r_kv}, rsrc.memory_pool);
+        weighted_kv_buf = Tensor::buffer(dt_logits, {max_seq_len, nh, r_kv}, rsrc.memory_pool);
+    }
+
+    // Scale factor for attention
+    float attn_scale = 1.f / float(sqrt(d_qk));
+
     // Compute
     for (uint32_t layer = 0; layer < nlayer; layer++) {
         // 1. Attention
// Compute
for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention
@@ -173,6 +201,12 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc
         auto k_rot = kv_a_buf->slice(1, r_kv, d_rope)->view({ntok, 1, d_rope});
         rope_v2(k_rot, k_rot, pos_ids_buf, weights->sin_table, weights->cos_table);

+        // Dequantize wkv_b once per layer
+        getInferenceContext().dequant(wkv_b_dequant,
+                                      weights->w_layers[layer].mla->kv_b_proj->w,
+                                      weights->w_layers[layer].mla->kv_b_proj->s,
+                                      weights->w_layers[layer].mla->kv_b_proj->z);
+
         size_t token_offset = 0;
         for (uint32_t req = 0; req < nreq; req++) {
             auto past_len = req_pos[req];
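Hoisting the dequantization above the request loop expands the quantized kv_b_proj weight once per layer instead of once per request. The diff does not show the quantization scheme itself; the sketch below assumes a per-group affine layout (the `(w - z) * s` form, the shapes, and the group size are all assumptions):

```python
import numpy as np

def dequant(w_q: np.ndarray, scale: np.ndarray, zero: np.ndarray,
            group_size: int = 128) -> np.ndarray:
    """Expand a row-grouped affine-quantized weight matrix to float32."""
    w = w_q.astype(np.float32)
    for g in range(scale.shape[0]):
        rows = slice(g * group_size, (g + 1) * group_size)
        w[rows] = (w[rows] - zero[g]) * scale[g]  # assumed affine scheme
    return w

# Toy usage: a [r_kv, nh * (d_nope + d_v)]-shaped weight with 128-row groups.
w_q = np.random.randint(0, 16, size=(512, 1024), dtype=np.int8)
scale = np.random.rand(4, 1024).astype(np.float32)
zero = np.random.randint(0, 16, size=(4, 1024)).astype(np.float32)
wkv_b_dequant = dequant(w_q, scale, zero)
print(wkv_b_dequant.shape)  # (512, 1024)
```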
@@ -184,37 +218,123 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc
             auto kv_pass_req = kv_a_req->slice(1, 0, r_kv);
             auto k_rot_req = kv_a_req->slice(1, r_kv, d_rope);

-            // concat cache
+            // Update cache with new tokens
             rearrange(caches[req]->kv_pass[idev][layer]->slice(0, past_len, seq_len), kv_pass_req);
             rearrange(caches[req]->k_rot[idev][layer]->slice(0, past_len, seq_len), k_rot_req);
-            // kv_b_proj
-            auto kv_b_req = kv_b_buf->slice(0, 0, total_len);
-            dequant_linear(kv_b_req, caches[req]->kv_pass[idev][layer]->slice(0, 0, total_len),
-                           weights->w_layers[layer].mla->kv_b_proj->w,
-                           weights->w_layers[layer].mla->kv_b_proj->s,
-                           weights->w_layers[layer].mla->kv_b_proj->z,
-                           1.0, 0.0, nullptr, nullptr);
-            auto full_v_req = kv_b_req->slice(1, nh * d_nope, nh * d_v);
-            // concat k
-            auto full_k_req = full_k_buf->slice(0, 0, total_len);
-            auto full_k_pass_req = full_k_req->slice(1, 0, nh * d_nope);
-            auto full_k_rot_req = full_k_req->slice(1, nh * d_nope, nh * d_rope);
-            rearrange(full_k_pass_req, kv_b_req->slice(1, 0, nh * d_nope));
-            rearrange(full_k_rot_req->view({total_len, nh, d_rope}), k_rot_req->view_as({total_len, nh, d_rope}, {ptrdiff_t(d_rope), 0, 1})); // expand k_rot
-
-            // self attention
-            auto attn_score_req = attn_score_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len});
-            linear(attn_score_req,
-                   q_req->view({seq_len, nh, d_qk})->permute({1, 0, 2}),
-                   full_k_req->view({total_len, nh, d_qk})->permute({1, 2, 0}),
-                   1.f / float(sqrt(d_qk)), 0.f, nullptr, nullptr);
-            // softmax
-            causalSoftmax(attn_score_req, attn_score_req);
-            // attn val
-            auto attn_val_req = attn_val_buf->slice(1, 0, seq_len)->view({nh, seq_len, d_v});
-            linear(attn_val_req, attn_score_req, full_v_req->view({total_len, nh, d_v})->permute({1, 0, 2}), 1.f, 0.f, nullptr, nullptr);
-            // rearrange attn val
-            rearrange(o_req, attn_val_req->permute({1, 0, 2}));
+
+            // ============== ABSORB MODE ATTENTION ==============
+            // Key insight: instead of decompressing the entire cache every time,
+            // we absorb wkv_b into Q and compute attention on the compressed cache.
+            //
+            // Original: scores = Q @ (kv_cache @ wkv_b).T
+            // Absorb:   scores = (Q @ wkv_b.T) @ kv_cache.T
+            //
+            // The wkv_b projection now touches seq_len query rows instead of
+            // total_len cache rows, so its cost per step drops from
+            // O(total_len) rows to O(seq_len) rows. (See the NumPy sketch
+            // after this file's diff.)
+
+            auto kv_cache_req = caches[req]->kv_pass[idev][layer]->slice(0, 0, total_len); // [total_len, r_kv]
+            auto pe_cache_req = caches[req]->k_rot[idev][layer]->slice(0, 0, total_len);   // [total_len, d_rope]
+
+            // Split q into nope and pe parts
+            auto q_req_view = q_req->view({seq_len, nh, d_qk});
+            auto q_nope_req = q_req_view->slice(2, 0, d_nope);    // [seq_len, nh, d_nope]
+            auto q_pe_req = q_req_view->slice(2, d_nope, d_rope); // [seq_len, nh, d_rope]
+
+            // Split wkv_b into k and v parts
+            // wkv_b: [r_kv, nh * (d_nope + d_v)]
+            auto wkv_b_k = wkv_b_dequant->slice(1, 0, nh * d_nope);        // [r_kv, nh * d_nope]
+            auto wkv_b_v = wkv_b_dequant->slice(1, nh * d_nope, nh * d_v); // [r_kv, nh * d_v]
+
+            auto attn_score_req = attn_score_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len});
+
+            if (seq_len == 1) {
+                // ============== DECODE PATH (seq_len == 1) ==============
+                // Same absorb-order computation as the prefill branch below;
+                // the only difference is that the single-token buffers
+                // allocated above are used, so every GEMM here is a skinny
+                // one-row product per head.
+
+                // Step 1: Absorb wkv_b_k into q_nope (single token)
+                auto q_absorbed_req = q_absorbed_buf->slice(0, 0, 1);         // [1, nh, r_kv]
+                linear(q_absorbed_req->permute({1, 0, 2}),                    // [nh, 1, r_kv]
+                       q_nope_req->permute({1, 0, 2}),                        // [nh, 1, d_nope]
+                       wkv_b_k->view({r_kv, nh, d_nope})->permute({1, 2, 0}), // [nh, d_nope, r_kv]
+                       1.f, 0.f, nullptr, nullptr);
+
+                // Step 2: Compute attention scores with a fused add.
+                // First compute q_absorbed @ kv_cache.T into attn_score_req
+                linear(attn_score_req,
+                       q_absorbed_req->permute({1, 0, 2}), // [nh, 1, r_kv]
+                       kv_cache_req->permute({1, 0}),      // [r_kv, total_len]
+                       attn_scale, 0.f, nullptr, nullptr);
+
+                // Then compute q_pe @ pe_cache.T and accumulate into the same
+                // buffer (beta = 1 fuses the add)
+                linear(attn_score_req,
+                       q_pe_req->permute({1, 0, 2}),  // [nh, 1, d_rope]
+                       pe_cache_req->permute({1, 0}), // [d_rope, total_len]
+                       attn_scale, 1.f, nullptr, nullptr);
+
+                // Softmax
+                causalSoftmax(attn_score_req, attn_score_req);
+
+                // Step 3: Compute the weighted KV cache
+                auto weighted_kv_req = weighted_kv_buf->slice(0, 0, 1); // [1, nh, r_kv]
+                linear(weighted_kv_req->permute({1, 0, 2}), // [nh, 1, r_kv]
+                       attn_score_req,                      // [nh, 1, total_len]
+                       kv_cache_req,                        // [total_len, r_kv]
+                       1.f, 0.f, nullptr, nullptr);
+
+                // Step 4: Apply wkv_b_v to get the final attention output
+                auto attn_val_req = attn_val_buf->slice(1, 0, 1)->view({nh, 1, d_v});
+                linear(attn_val_req,
+                       weighted_kv_req->permute({1, 0, 2}),                // [nh, 1, r_kv]
+                       wkv_b_v->view({r_kv, nh, d_v})->permute({1, 0, 2}), // [nh, r_kv, d_v]
+                       1.f, 0.f, nullptr, nullptr);
+
+                // Rearrange: [nh, 1, d_v] -> [1, nh, d_v]
+                rearrange(o_req, attn_val_req->permute({1, 0, 2}));
+
+            } else {
+                // ============== PREFILL PATH (seq_len > 1) ==============
+                // Step 1: Absorb wkv_b_k into q_nope
+                auto q_absorbed_req = q_absorbed_buf->slice(0, 0, seq_len);   // [seq_len, nh, r_kv]
+                linear(q_absorbed_req->permute({1, 0, 2}),                    // [nh, seq_len, r_kv]
+                       q_nope_req->permute({1, 0, 2}),                        // [nh, seq_len, d_nope]
+                       wkv_b_k->view({r_kv, nh, d_nope})->permute({1, 2, 0}), // [nh, d_nope, r_kv]
+                       1.f, 0.f, nullptr, nullptr);
+
+                // Step 2: Compute attention scores with a fused add.
+                // First compute q_absorbed @ kv_cache.T
+                linear(attn_score_req,
+                       q_absorbed_req->permute({1, 0, 2}), // [nh, seq_len, r_kv]
+                       kv_cache_req->permute({1, 0}),      // [r_kv, total_len]
+                       attn_scale, 0.f, nullptr, nullptr);
+
+                // Then compute q_pe @ pe_cache.T and accumulate (beta = 1 fuses the add)
+                linear(attn_score_req,
+                       q_pe_req->permute({1, 0, 2}),  // [nh, seq_len, d_rope]
+                       pe_cache_req->permute({1, 0}), // [d_rope, total_len]
+                       attn_scale, 1.f, nullptr, nullptr);
+
+                // Softmax (causal mask applies within the new tokens)
+                causalSoftmax(attn_score_req, attn_score_req);
+
+                // Step 3: Compute the weighted KV cache
+                auto weighted_kv_req = weighted_kv_buf->slice(0, 0, seq_len); // [seq_len, nh, r_kv]
+                linear(weighted_kv_req->permute({1, 0, 2}), // [nh, seq_len, r_kv]
+                       attn_score_req,                      // [nh, seq_len, total_len]
+                       kv_cache_req,                        // [total_len, r_kv]
+                       1.f, 0.f, nullptr, nullptr);
+
+                // Step 4: Apply wkv_b_v to get the final attention output
+                auto attn_val_req = attn_val_buf->slice(1, 0, seq_len)->view({nh, seq_len, d_v});
+                linear(attn_val_req,
+                       weighted_kv_req->permute({1, 0, 2}),                // [nh, seq_len, r_kv]
+                       wkv_b_v->view({r_kv, nh, d_v})->permute({1, 0, 2}), // [nh, r_kv, d_v]
+                       1.f, 0.f, nullptr, nullptr);
+
+                // Rearrange: [nh, seq_len, d_v] -> [seq_len, nh, d_v]
+                rearrange(o_req, attn_val_req->permute({1, 0, 2}));
+            }

             token_offset += seq_len;
         }
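The absorb identity and the four-step pipeline are easy to check numerically. Below is a hedged NumPy mirror of the decode branch under toy sizes; every name is hypothetical, and `linear`'s alpha/beta arguments are modeled as a scaled matmul plus accumulation. It first confirms that absorbing wkv_b_k into q gives the same scores as decompressing the cache, then runs steps 1-4 on the compressed cache:

```python
import numpy as np

rng = np.random.default_rng(0)
nh, r_kv, d_nope, d_rope, d_v, total = 8, 64, 32, 16, 32, 10
attn_scale = 1.0 / np.sqrt(d_nope + d_rope)      # d_qk = d_nope + d_rope

q_nope = rng.standard_normal((nh, 1, d_nope))    # new token's q, nope part
q_pe   = rng.standard_normal((nh, 1, d_rope))    # new token's q, rope part
kv_cache = rng.standard_normal((total, r_kv))    # compressed kv cache
pe_cache = rng.standard_normal((total, d_rope))  # shared rope'd key cache
w_k = rng.standard_normal((nh, d_nope, r_kv))    # per-head wkv_b_k, as the kernel views it
w_v = rng.standard_normal((nh, r_kv, d_v))       # per-head wkv_b_v

# Absorb identity: decompressing the cache into full keys and scoring ...
k_full = kv_cache @ w_k.transpose(0, 2, 1)       # non-absorb: [nh, total, d_nope]
scores_ref = q_nope @ k_full.transpose(0, 2, 1)  # [nh, 1, total]
# ... equals absorbing w_k into q and scoring against the compressed cache.
q_abs = q_nope @ w_k                             # Step 1: [nh, 1, r_kv]
scores = q_abs @ kv_cache.T                      # [nh, 1, total]
assert np.allclose(scores_ref, scores)

# Step 2: scale, then accumulate the rope term (the beta=1 fused add).
scores = attn_scale * scores
scores += attn_scale * (q_pe @ pe_cache.T)

# Softmax (the causal mask is trivial for a single query token).
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)

# Step 3: weight the *compressed* cache by the attention probabilities.
weighted_kv = probs @ kv_cache                   # [nh, 1, r_kv]

# Step 4: decompress only the weighted result.
out = weighted_kv @ w_v                          # [nh, 1, d_v]
print(out.shape)                                 # (8, 1, 32)
```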
1 change: 1 addition & 0 deletions test/models/deepseek_mla/__init__.py
@@ -0,0 +1 @@
+# DeepSeek MLA (Multi-head Latent Attention) Test Module