Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
task_id: 10p_Sparse_MoE
description: Kernel task for 10p_Sparse_MoE
input_gen_code: |-
def get_inputs(dtype=jnp.bfloat16):
import jax
import jax.numpy as jnp
from functools import partial

CONFIG = {
'name': 'mixtral_8x7b_moe',
'model': 'Mixtral-8x7B',
'operator': 'sparse_moe',
'batch': 2,
'seq_len': 4096,
'emb_dim': 4096,
'mlp_dim': 14336,
'num_experts': 8,
'num_experts_per_tok': 2,
}
key = jax.random.key(42)
keys = jax.random.split(key, 5)
B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim']
N = CONFIG['num_experts']
x = jax.random.normal(keys[0], (B, S, E), dtype=dtype)
router = jax.random.normal(keys[1], (E, N), dtype=dtype) * 0.02
gate_k = jax.random.normal(keys[2], (N, E, M), dtype=dtype) * 0.02
up_k = jax.random.normal(keys[3], (N, E, M), dtype=dtype) * 0.02
down_k = jax.random.normal(keys[4], (N, M, E), dtype=dtype) * 0.02

dynamic_args = [x, router, gate_k, up_k, down_k]
static_args = [CONFIG['num_experts_per_tok']]

return dynamic_args, static_args
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Imports
import jax
import jax.numpy as jnp
from functools import partial

# Initialization
def get_inputs(dtype=jnp.bfloat16):
CONFIG = {
'name': 'mixtral_8x7b_moe',
'model': 'Mixtral-8x7B',
'operator': 'sparse_moe',
'batch': 2,
'seq_len': 4096,
'emb_dim': 4096,
'mlp_dim': 14336,
'num_experts': 8,
'num_experts_per_tok': 2,
}
key = jax.random.key(42)
keys = jax.random.split(key, 5)
B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim']
N = CONFIG['num_experts']
x = jax.random.normal(keys[0], (B, S, E), dtype=dtype)
router = jax.random.normal(keys[1], (E, N), dtype=dtype) * 0.02
gate_k = jax.random.normal(keys[2], (N, E, M), dtype=dtype) * 0.02
up_k = jax.random.normal(keys[3], (N, E, M), dtype=dtype) * 0.02
down_k = jax.random.normal(keys[4], (N, M, E), dtype=dtype) * 0.02

dynamic_args = [x, router, gate_k, up_k, down_k]
static_args = [CONFIG['num_experts_per_tok']]

return dynamic_args, static_args

# Computation
def computation(x, router_weights, expert_gate_kernels, expert_up_kernels, expert_down_kernels, num_experts_per_tok):
B, S, E = x.shape
N = router_weights.shape[-1]
K = num_experts_per_tok
logits = jnp.dot(x, router_weights)
top_k_logits, top_k_indices = jax.lax.top_k(logits, K)
router_probs = jax.nn.softmax(top_k_logits, axis=-1)
gate_out = jax.nn.silu(jnp.einsum('bse,nem->bsnm', x, expert_gate_kernels))
up_out = jnp.einsum('bse,nem->bsnm', x, expert_up_kernels)
hidden = gate_out * up_out
expert_outputs = jnp.einsum('bsnm,nme->bsne', hidden, expert_down_kernels)
one_hot = jax.nn.one_hot(top_k_indices, N)
weighted = one_hot * router_probs[..., None]
expert_weights = weighted.sum(axis=2)
output = jnp.einsum('bsne,bsn->bse', expert_outputs, expert_weights)
return output
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
task_id: 11p_Megablox_GMM
description: Kernel task for 11p_Megablox_GMM
input_gen_code: |-
def get_inputs(dtype=jnp.bfloat16):
import jax
import jax.numpy as jnp

CONFIG = {
'name': 'megablox_gmm_qwen3_235b',
'model': 'Qwen3-235B-A22B',
'operator': 'grouped_matmul',
'num_experts': 128,
'num_experts_per_tok': 8,
'emb_dim': 4096,
'moe_mlp_dim': 1536,
'seq_len': 4096,
}
key = jax.random.key(42)
k1, k2 = jax.random.split(key, 2)
G = CONFIG['num_experts']
top_k = CONFIG['num_experts_per_tok']
K = CONFIG['emb_dim']
N = CONFIG['moe_mlp_dim']
S = CONFIG['seq_len']
M = S * top_k
limit = 1 / (M * K)
lhs = jax.random.uniform(k1, (M, K), dtype=dtype, minval=-limit, maxval=limit)
lhs = lhs.astype(jnp.bfloat16).astype(dtype)
rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit)
rhs = rhs.astype(jnp.bfloat16).astype(dtype)
max_expert_size = M // G
group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32)

dynamic_args = [lhs, rhs, group_sizes]
static_args = [max_expert_size]

return dynamic_args, static_args
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Imports
import jax
import jax.numpy as jnp

# Initialization
def get_inputs(dtype=jnp.bfloat16):
CONFIG = {
'name': 'megablox_gmm_qwen3_235b',
'model': 'Qwen3-235B-A22B',
'operator': 'grouped_matmul',
'num_experts': 128,
'num_experts_per_tok': 8,
'emb_dim': 4096,
'moe_mlp_dim': 1536,
'seq_len': 4096,
}
key = jax.random.key(42)
k1, k2 = jax.random.split(key, 2)
G = CONFIG['num_experts']
top_k = CONFIG['num_experts_per_tok']
K = CONFIG['emb_dim']
N = CONFIG['moe_mlp_dim']
S = CONFIG['seq_len']
M = S * top_k
limit = 1 / (M * K)
lhs = jax.random.uniform(k1, (M, K), dtype=dtype, minval=-limit, maxval=limit)
lhs = lhs.astype(jnp.bfloat16).astype(dtype)
rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit)
rhs = rhs.astype(jnp.bfloat16).astype(dtype)
max_expert_size = M // G
group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32)

dynamic_args = [lhs, rhs, group_sizes]
static_args = [max_expert_size]

return dynamic_args, static_args

# Computation
def computation(lhs, rhs, group_sizes, max_expert_size):
G = rhs.shape[0]
M, K = lhs.shape
N = rhs.shape[2]

group_ends = jnp.cumsum(group_sizes)
group_starts = jnp.concatenate(
[jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]
)

res_flat = jnp.zeros((M + max_expert_size, N), dtype=lhs.dtype)

def body_fun(carry_res_flat, i):
start = group_starts[i]
count = group_sizes[i]

expert_lhs = jax.lax.dynamic_slice(
lhs, (start, 0), (max_expert_size, K)
)
expert_rhs = rhs[i, :, :]

res = jax.lax.dot(
expert_lhs, expert_rhs, preferred_element_type=jnp.float32
)

mask = (
jax.lax.broadcasted_iota(jnp.int32, (max_expert_size, N), 0) < count
)
res_masked = jnp.where(mask, res, 0.0)

current_slice = jax.lax.dynamic_slice(
carry_res_flat, (start, 0), (max_expert_size, N)
)
updated_slice = current_slice + res_masked.astype(carry_res_flat.dtype)
carry_res_flat = jax.lax.dynamic_update_slice(
carry_res_flat, updated_slice, (start, 0)
)

return carry_res_flat, None

res_flat, _ = jax.lax.scan(body_fun, res_flat, jnp.arange(G))

return res_flat[:M, :]
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
task_id: 12p_RMSNorm
description: Kernel task for 12p_RMSNorm
input_gen_code: |-
def get_inputs(dtype=jnp.bfloat16):
import jax
import jax.numpy as jnp
from jax import lax

batch = 8
seq_len = 4096
emb_dim = 8192
epsilon = 1e-5

key = jax.random.key(42)
k1, k2 = jax.random.split(key, 2)
x = jax.random.normal(k1, (batch, seq_len, emb_dim), dtype=dtype)
scale = jax.random.normal(k2, (emb_dim,), dtype=dtype) * 0.1 + 1.0
return [x, scale], [epsilon]
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Imports
import jax
import jax.numpy as jnp
from jax import lax

# Initialization
def get_inputs(dtype=jnp.bfloat16):
batch = 8
seq_len = 4096
emb_dim = 8192
epsilon = 1e-5

key = jax.random.key(42)
k1, k2 = jax.random.split(key, 2)
x = jax.random.normal(k1, (batch, seq_len, emb_dim), dtype=dtype)
scale = jax.random.normal(k2, (emb_dim,), dtype=dtype) * 0.1 + 1.0
return [x, scale], [epsilon]

# Computation
def computation(x, scale, epsilon):
x_f32 = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True)
normed = x_f32 * lax.rsqrt(mean2 + epsilon)
normed = jnp.asarray(normed, x.dtype)
return normed * scale
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
task_id: 13p_Cross_Entropy
description: Kernel task for 13p_Cross_Entropy
input_gen_code: |-
def get_inputs():
import jax
import jax.numpy as jnp

CONFIG = {
'name': 'llama3_8b_cross_entropy',
'model': 'Llama-3.1-8B',
'operator': 'fused_cross_entropy',
'batch_tokens': 8192,
'hidden_dim': 4096,
'vocab_size': 128256,
}
dtype = jnp.bfloat16
key = jax.random.key(42)
k1, k2, k3 = jax.random.split(key, 3)
B, H, V = CONFIG['batch_tokens'], CONFIG['hidden_dim'], CONFIG['vocab_size']
hidden = jax.random.normal(k1, (B, H), dtype=dtype)
weight = jax.random.normal(k2, (H, V), dtype=dtype) * 0.02
labels = jax.random.randint(k3, (B,), 0, V)
dynamic_args = [hidden, weight, labels]
static_args = []
return dynamic_args, static_args
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Imports
import jax
import jax.numpy as jnp

# Initialization
def get_inputs():
CONFIG = {
'name': 'llama3_8b_cross_entropy',
'model': 'Llama-3.1-8B',
'operator': 'fused_cross_entropy',
'batch_tokens': 8192,
'hidden_dim': 4096,
'vocab_size': 128256,
}
dtype = jnp.bfloat16
key = jax.random.key(42)
k1, k2, k3 = jax.random.split(key, 3)
B, H, V = CONFIG['batch_tokens'], CONFIG['hidden_dim'], CONFIG['vocab_size']
hidden = jax.random.normal(k1, (B, H), dtype=dtype)
weight = jax.random.normal(k2, (H, V), dtype=dtype) * 0.02
labels = jax.random.randint(k3, (B,), 0, V)
dynamic_args = [hidden, weight, labels]
static_args = []
return dynamic_args, static_args

# Computation
def computation(hidden, weight, labels):
logits = jnp.dot(hidden, weight)
log_probs = jax.nn.log_softmax(logits, axis=-1)
one_hot = jax.nn.one_hot(labels, logits.shape[-1])
loss = -jnp.sum(one_hot * log_probs, axis=-1)
return jnp.mean(loss)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
task_id: 14p_Ragged_Dot
description: Kernel task for 14p_Ragged_Dot
input_gen_code: |-
def get_inputs():
import jax
import jax.numpy as jnp

CONFIG = {
'name': 'mixtral_8x7b_ragged_dot',
'model': 'Mixtral-8x7B',
'operator': 'ragged_dot',
'num_groups': 8,
'M': 8192,
'K': 4096,
'N': 14336,
}
dtype = jnp.bfloat16
key = jax.random.key(42)
k1, k2 = jax.random.split(key, 2)
G, M, K, N = CONFIG['num_groups'], CONFIG['M'], CONFIG['K'], CONFIG['N']
x = jax.random.normal(k1, (G, M // G, K), dtype=dtype)
weights = jax.random.normal(k2, (G, K, N), dtype=dtype) * 0.02
dynamic_args = [x, weights]
static_args = []
return dynamic_args, static_args
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Imports
import jax
import jax.numpy as jnp

# Initialization
def get_inputs():
CONFIG = {
'name': 'mixtral_8x7b_ragged_dot',
'model': 'Mixtral-8x7B',
'operator': 'ragged_dot',
'num_groups': 8,
'M': 8192,
'K': 4096,
'N': 14336,
}
dtype = jnp.bfloat16
key = jax.random.key(42)
k1, k2 = jax.random.split(key, 2)
G, M, K, N = CONFIG['num_groups'], CONFIG['M'], CONFIG['K'], CONFIG['N']
x = jax.random.normal(k1, (G, M // G, K), dtype=dtype)
weights = jax.random.normal(k2, (G, K, N), dtype=dtype) * 0.02
dynamic_args = [x, weights]
static_args = []
return dynamic_args, static_args

# Computation
def computation(x, weights):
return jnp.einsum('gmk,gkn->gmn', x, weights)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
task_id: 15p_RetNet_Retention
description: Kernel task for 15p_RetNet_Retention
input_gen_code: |-
def get_inputs():
import jax
import jax.numpy as jnp
from functools import partial

CONFIG = {
'name': 'retnet_6_7b_retention',
'model': 'RetNet-6.7B',
'operator': 'multi_scale_retention',
'batch': 4,
'seq_len': 4096,
'num_heads': 16,
'head_dim': 256,
'd_model': 4096,
}
dtype = jnp.bfloat16
key = jax.random.key(42)
keys = jax.random.split(key, 3)
B, S = CONFIG['batch'], CONFIG['seq_len']
H, D = CONFIG['num_heads'], CONFIG['head_dim']
query = jax.random.normal(keys[0], (B, H, S, D), dtype=dtype)
key_val = jax.random.normal(keys[1], (B, H, S, D), dtype=dtype)
value = jax.random.normal(keys[2], (B, H, S, D), dtype=dtype)
dynamic_args = [query, key_val, value]
static_args = []
return dynamic_args, static_args
Loading