From 08d603c9b384bdc3fcd35203635655cd2d64b5bb Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Wed, 27 May 2026 19:46:44 +0000 Subject: [PATCH] feat: add the adapted dataset from JaxBench --- .../10p_Sparse_MoE/kernel_task.yaml | 33 ++++++ .../10p_Sparse_MoE/reference.py | 50 +++++++++ .../11p_Megablox_GMM/kernel_task.yaml | 37 ++++++ .../11p_Megablox_GMM/reference.py | 81 +++++++++++++ .../12p_RMSNorm/kernel_task.yaml | 18 +++ .../12p_RMSNorm/reference.py | 25 +++++ .../13p_Cross_Entropy/kernel_task.yaml | 25 +++++ .../13p_Cross_Entropy/reference.py | 32 ++++++ .../14p_Ragged_Dot/kernel_task.yaml | 25 +++++ .../14p_Ragged_Dot/reference.py | 28 +++++ .../15p_RetNet_Retention/kernel_task.yaml | 29 +++++ .../15p_RetNet_Retention/reference.py | 46 ++++++++ .../16p_Mamba2_SSD/kernel_task.yaml | 29 +++++ .../16p_Mamba2_SSD/reference.py | 47 ++++++++ .../kernel_task.yaml | 30 +++++ .../17p_Triangle_Multiplication/reference.py | 46 ++++++++ .../18k_Conv2D_ReLU_BiasAdd/kernel_task.yaml | 23 ++++ .../18k_Conv2D_ReLU_BiasAdd/reference.py | 38 +++++++ .../kernel_task.yaml | 21 ++++ .../reference.py | 28 +++++ .../1p_Flash_Attention/kernel_task.yaml | 20 ++++ .../1p_Flash_Attention/reference.py | 30 +++++ .../kernel_task.yaml | 19 ++++ .../20k_Gemm_Multiply_LeakyReLU/reference.py | 27 +++++ .../kernel_task.yaml | 23 ++++ .../21k_Gemm_Divide_Sum_Scaling/reference.py | 34 ++++++ .../kernel_task.yaml | 25 +++++ .../reference.py | 45 ++++++++ .../kernel_task.yaml | 17 +++ .../reference.py | 26 +++++ .../kernel_task.yaml | 23 ++++ .../reference.py | 34 ++++++ .../kernel_task.yaml | 21 ++++ .../25k_Conv3d_GroupNorm_Mean/reference.py | 44 ++++++++ .../kernel_task.yaml | 20 ++++ .../reference.py | 33 ++++++ .../27k_Matmul_Mish_Mish/kernel_task.yaml | 17 +++ .../27k_Matmul_Mish_Mish/reference.py | 23 ++++ .../kernel_task.yaml | 31 +++++ .../reference.py | 68 +++++++++++ .../kernel_task.yaml | 20 ++++ .../reference.py | 33 ++++++ .../2p_GQA_Attention/kernel_task.yaml | 29 +++++ .../2p_GQA_Attention/reference.py | 46 ++++++++ .../kernel_task.yaml | 17 +++ .../reference.py | 25 +++++ .../kernel_task.yaml | 20 ++++ .../31k_Gemm_BatchNorm_GELU_ReLU/reference.py | 30 +++++ .../kernel_task.yaml | 20 ++++ .../32k_Gemm_Sigmoid_LogSumExp/reference.py | 27 +++++ .../33k_Conv3d_Mish_Tanh/kernel_task.yaml | 21 ++++ .../33k_Conv3d_Mish_Tanh/reference.py | 36 ++++++ .../kernel_task.yaml | 26 +++++ .../reference.py | 49 ++++++++ .../kernel_task.yaml | 33 ++++++ .../reference.py | 40 +++++++ .../36k_Matmul_Sigmoid_Sum/kernel_task.yaml | 20 ++++ .../36k_Matmul_Sigmoid_Sum/reference.py | 26 +++++ .../37k_Matmul_Swish_Scaling/kernel_task.yaml | 19 ++++ .../37k_Matmul_Swish_Scaling/reference.py | 25 +++++ .../kernel_task.yaml | 21 ++++ .../38k_Matmul_Dropout_Softmax/reference.py | 26 +++++ .../kernel_task.yaml | 16 +++ .../reference.py | 30 +++++ .../3p_MLA_Attention/kernel_task.yaml | 41 +++++++ .../3p_MLA_Attention/reference.py | 85 ++++++++++++++ .../kernel_task.yaml | 18 +++ .../reference.py | 35 ++++++ .../41k_Gemm_Add_ReLU/kernel_task.yaml | 19 ++++ .../41k_Gemm_Add_ReLU/reference.py | 25 +++++ .../kernel_task.yaml | 22 ++++ .../42k_Gemm_Max_Subtract_GELU/reference.py | 29 +++++ .../kernel_task.yaml | 25 +++++ .../reference.py | 34 ++++++ .../44k_Matmul_Divide_GELU/kernel_task.yaml | 21 ++++ .../44k_Matmul_Divide_GELU/reference.py | 27 +++++ .../kernel_task.yaml | 24 ++++ .../reference.py | 39 +++++++ .../kernel_task.yaml | 23 ++++ .../reference.py | 48 ++++++++ .../kernel_task.yaml | 20 ++++ .../reference.py | 29 +++++ .../kernel_task.yaml | 22 ++++ .../reference.py | 33 ++++++ .../kernel_task.yaml | 22 ++++ .../reference.py | 39 +++++++ .../4p_Sparse_Attention/kernel_task.yaml | 30 +++++ .../4p_Sparse_Attention/reference.py | 46 ++++++++ .../50k_Matmul_GELU_Softmax/kernel_task.yaml | 20 ++++ .../50k_Matmul_GELU_Softmax/reference.py | 26 +++++ .../5p_Flex_Attention/kernel_task.yaml | 30 +++++ .../5p_Flex_Attention/reference.py | 43 +++++++ .../6p_Paged_Attention/kernel_task.yaml | 44 ++++++++ .../6p_Paged_Attention/reference.py | 104 +++++++++++++++++ .../kernel_task.yaml | 45 ++++++++ .../7p_Ragged_Paged_Attention/reference.py | 106 ++++++++++++++++++ .../8p_GEMM/kernel_task.yaml | 20 ++++ .../8p_GEMM/reference.py | 23 ++++ .../9p_SwiGLU_MLP/kernel_task.yaml | 29 +++++ .../9p_SwiGLU_MLP/reference.py | 34 ++++++ 100 files changed, 3206 insertions(+) create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/reference.py create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/kernel_task.yaml create mode 100644 MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/reference.py diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/kernel_task.yaml new file mode 100644 index 0000000..82a9938 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/kernel_task.yaml @@ -0,0 +1,33 @@ +task_id: 10p_Sparse_MoE +description: Kernel task for 10p_Sparse_MoE +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + from functools import partial + + CONFIG = { + 'name': 'mixtral_8x7b_moe', + 'model': 'Mixtral-8x7B', + 'operator': 'sparse_moe', + 'batch': 2, + 'seq_len': 4096, + 'emb_dim': 4096, + 'mlp_dim': 14336, + 'num_experts': 8, + 'num_experts_per_tok': 2, + } + key = jax.random.key(42) + keys = jax.random.split(key, 5) + B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim'] + N = CONFIG['num_experts'] + x = jax.random.normal(keys[0], (B, S, E), dtype=dtype) + router = jax.random.normal(keys[1], (E, N), dtype=dtype) * 0.02 + gate_k = jax.random.normal(keys[2], (N, E, M), dtype=dtype) * 0.02 + up_k = jax.random.normal(keys[3], (N, E, M), dtype=dtype) * 0.02 + down_k = jax.random.normal(keys[4], (N, M, E), dtype=dtype) * 0.02 + + dynamic_args = [x, router, gate_k, up_k, down_k] + static_args = [CONFIG['num_experts_per_tok']] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/reference.py new file mode 100644 index 0000000..1d35fdb --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/10p_Sparse_MoE/reference.py @@ -0,0 +1,50 @@ +# Imports +import jax +import jax.numpy as jnp +from functools import partial + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'mixtral_8x7b_moe', + 'model': 'Mixtral-8x7B', + 'operator': 'sparse_moe', + 'batch': 2, + 'seq_len': 4096, + 'emb_dim': 4096, + 'mlp_dim': 14336, + 'num_experts': 8, + 'num_experts_per_tok': 2, + } + key = jax.random.key(42) + keys = jax.random.split(key, 5) + B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim'] + N = CONFIG['num_experts'] + x = jax.random.normal(keys[0], (B, S, E), dtype=dtype) + router = jax.random.normal(keys[1], (E, N), dtype=dtype) * 0.02 + gate_k = jax.random.normal(keys[2], (N, E, M), dtype=dtype) * 0.02 + up_k = jax.random.normal(keys[3], (N, E, M), dtype=dtype) * 0.02 + down_k = jax.random.normal(keys[4], (N, M, E), dtype=dtype) * 0.02 + + dynamic_args = [x, router, gate_k, up_k, down_k] + static_args = [CONFIG['num_experts_per_tok']] + + return dynamic_args, static_args + +# Computation +def computation(x, router_weights, expert_gate_kernels, expert_up_kernels, expert_down_kernels, num_experts_per_tok): + B, S, E = x.shape + N = router_weights.shape[-1] + K = num_experts_per_tok + logits = jnp.dot(x, router_weights) + top_k_logits, top_k_indices = jax.lax.top_k(logits, K) + router_probs = jax.nn.softmax(top_k_logits, axis=-1) + gate_out = jax.nn.silu(jnp.einsum('bse,nem->bsnm', x, expert_gate_kernels)) + up_out = jnp.einsum('bse,nem->bsnm', x, expert_up_kernels) + hidden = gate_out * up_out + expert_outputs = jnp.einsum('bsnm,nme->bsne', hidden, expert_down_kernels) + one_hot = jax.nn.one_hot(top_k_indices, N) + weighted = one_hot * router_probs[..., None] + expert_weights = weighted.sum(axis=2) + output = jnp.einsum('bsne,bsn->bse', expert_outputs, expert_weights) + return output \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/kernel_task.yaml new file mode 100644 index 0000000..81ef5f6 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/kernel_task.yaml @@ -0,0 +1,37 @@ +task_id: 11p_Megablox_GMM +description: Kernel task for 11p_Megablox_GMM +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'megablox_gmm_qwen3_235b', + 'model': 'Qwen3-235B-A22B', + 'operator': 'grouped_matmul', + 'num_experts': 128, + 'num_experts_per_tok': 8, + 'emb_dim': 4096, + 'moe_mlp_dim': 1536, + 'seq_len': 4096, + } + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + G = CONFIG['num_experts'] + top_k = CONFIG['num_experts_per_tok'] + K = CONFIG['emb_dim'] + N = CONFIG['moe_mlp_dim'] + S = CONFIG['seq_len'] + M = S * top_k + limit = 1 / (M * K) + lhs = jax.random.uniform(k1, (M, K), dtype=dtype, minval=-limit, maxval=limit) + lhs = lhs.astype(jnp.bfloat16).astype(dtype) + rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit) + rhs = rhs.astype(jnp.bfloat16).astype(dtype) + max_expert_size = M // G + group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32) + + dynamic_args = [lhs, rhs, group_sizes] + static_args = [max_expert_size] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/reference.py new file mode 100644 index 0000000..47455b3 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/11p_Megablox_GMM/reference.py @@ -0,0 +1,81 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'megablox_gmm_qwen3_235b', + 'model': 'Qwen3-235B-A22B', + 'operator': 'grouped_matmul', + 'num_experts': 128, + 'num_experts_per_tok': 8, + 'emb_dim': 4096, + 'moe_mlp_dim': 1536, + 'seq_len': 4096, + } + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + G = CONFIG['num_experts'] + top_k = CONFIG['num_experts_per_tok'] + K = CONFIG['emb_dim'] + N = CONFIG['moe_mlp_dim'] + S = CONFIG['seq_len'] + M = S * top_k + limit = 1 / (M * K) + lhs = jax.random.uniform(k1, (M, K), dtype=dtype, minval=-limit, maxval=limit) + lhs = lhs.astype(jnp.bfloat16).astype(dtype) + rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit) + rhs = rhs.astype(jnp.bfloat16).astype(dtype) + max_expert_size = M // G + group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32) + + dynamic_args = [lhs, rhs, group_sizes] + static_args = [max_expert_size] + + return dynamic_args, static_args + +# Computation +def computation(lhs, rhs, group_sizes, max_expert_size): + G = rhs.shape[0] + M, K = lhs.shape + N = rhs.shape[2] + + group_ends = jnp.cumsum(group_sizes) + group_starts = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]] + ) + + res_flat = jnp.zeros((M + max_expert_size, N), dtype=lhs.dtype) + + def body_fun(carry_res_flat, i): + start = group_starts[i] + count = group_sizes[i] + + expert_lhs = jax.lax.dynamic_slice( + lhs, (start, 0), (max_expert_size, K) + ) + expert_rhs = rhs[i, :, :] + + res = jax.lax.dot( + expert_lhs, expert_rhs, preferred_element_type=jnp.float32 + ) + + mask = ( + jax.lax.broadcasted_iota(jnp.int32, (max_expert_size, N), 0) < count + ) + res_masked = jnp.where(mask, res, 0.0) + + current_slice = jax.lax.dynamic_slice( + carry_res_flat, (start, 0), (max_expert_size, N) + ) + updated_slice = current_slice + res_masked.astype(carry_res_flat.dtype) + carry_res_flat = jax.lax.dynamic_update_slice( + carry_res_flat, updated_slice, (start, 0) + ) + + return carry_res_flat, None + + res_flat, _ = jax.lax.scan(body_fun, res_flat, jnp.arange(G)) + + return res_flat[:M, :] \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/kernel_task.yaml new file mode 100644 index 0000000..d32524b --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/kernel_task.yaml @@ -0,0 +1,18 @@ +task_id: 12p_RMSNorm +description: Kernel task for 12p_RMSNorm +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + from jax import lax + + batch = 8 + seq_len = 4096 + emb_dim = 8192 + epsilon = 1e-5 + + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + x = jax.random.normal(k1, (batch, seq_len, emb_dim), dtype=dtype) + scale = jax.random.normal(k2, (emb_dim,), dtype=dtype) * 0.1 + 1.0 + return [x, scale], [epsilon] diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/reference.py new file mode 100644 index 0000000..e03707c --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/12p_RMSNorm/reference.py @@ -0,0 +1,25 @@ +# Imports +import jax +import jax.numpy as jnp +from jax import lax + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + batch = 8 + seq_len = 4096 + emb_dim = 8192 + epsilon = 1e-5 + + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + x = jax.random.normal(k1, (batch, seq_len, emb_dim), dtype=dtype) + scale = jax.random.normal(k2, (emb_dim,), dtype=dtype) * 0.1 + 1.0 + return [x, scale], [epsilon] + +# Computation +def computation(x, scale, epsilon): + x_f32 = jnp.asarray(x, jnp.float32) + mean2 = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True) + normed = x_f32 * lax.rsqrt(mean2 + epsilon) + normed = jnp.asarray(normed, x.dtype) + return normed * scale \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/kernel_task.yaml new file mode 100644 index 0000000..e6678cd --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/kernel_task.yaml @@ -0,0 +1,25 @@ +task_id: 13p_Cross_Entropy +description: Kernel task for 13p_Cross_Entropy +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'llama3_8b_cross_entropy', + 'model': 'Llama-3.1-8B', + 'operator': 'fused_cross_entropy', + 'batch_tokens': 8192, + 'hidden_dim': 4096, + 'vocab_size': 128256, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + B, H, V = CONFIG['batch_tokens'], CONFIG['hidden_dim'], CONFIG['vocab_size'] + hidden = jax.random.normal(k1, (B, H), dtype=dtype) + weight = jax.random.normal(k2, (H, V), dtype=dtype) * 0.02 + labels = jax.random.randint(k3, (B,), 0, V) + dynamic_args = [hidden, weight, labels] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/reference.py new file mode 100644 index 0000000..ec17179 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/13p_Cross_Entropy/reference.py @@ -0,0 +1,32 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + CONFIG = { + 'name': 'llama3_8b_cross_entropy', + 'model': 'Llama-3.1-8B', + 'operator': 'fused_cross_entropy', + 'batch_tokens': 8192, + 'hidden_dim': 4096, + 'vocab_size': 128256, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + B, H, V = CONFIG['batch_tokens'], CONFIG['hidden_dim'], CONFIG['vocab_size'] + hidden = jax.random.normal(k1, (B, H), dtype=dtype) + weight = jax.random.normal(k2, (H, V), dtype=dtype) * 0.02 + labels = jax.random.randint(k3, (B,), 0, V) + dynamic_args = [hidden, weight, labels] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(hidden, weight, labels): + logits = jnp.dot(hidden, weight) + log_probs = jax.nn.log_softmax(logits, axis=-1) + one_hot = jax.nn.one_hot(labels, logits.shape[-1]) + loss = -jnp.sum(one_hot * log_probs, axis=-1) + return jnp.mean(loss) \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/kernel_task.yaml new file mode 100644 index 0000000..769a5fd --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/kernel_task.yaml @@ -0,0 +1,25 @@ +task_id: 14p_Ragged_Dot +description: Kernel task for 14p_Ragged_Dot +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'mixtral_8x7b_ragged_dot', + 'model': 'Mixtral-8x7B', + 'operator': 'ragged_dot', + 'num_groups': 8, + 'M': 8192, + 'K': 4096, + 'N': 14336, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + G, M, K, N = CONFIG['num_groups'], CONFIG['M'], CONFIG['K'], CONFIG['N'] + x = jax.random.normal(k1, (G, M // G, K), dtype=dtype) + weights = jax.random.normal(k2, (G, K, N), dtype=dtype) * 0.02 + dynamic_args = [x, weights] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/reference.py new file mode 100644 index 0000000..c5f9a01 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/14p_Ragged_Dot/reference.py @@ -0,0 +1,28 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + CONFIG = { + 'name': 'mixtral_8x7b_ragged_dot', + 'model': 'Mixtral-8x7B', + 'operator': 'ragged_dot', + 'num_groups': 8, + 'M': 8192, + 'K': 4096, + 'N': 14336, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + G, M, K, N = CONFIG['num_groups'], CONFIG['M'], CONFIG['K'], CONFIG['N'] + x = jax.random.normal(k1, (G, M // G, K), dtype=dtype) + weights = jax.random.normal(k2, (G, K, N), dtype=dtype) * 0.02 + dynamic_args = [x, weights] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weights): + return jnp.einsum('gmk,gkn->gmn', x, weights) \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/kernel_task.yaml new file mode 100644 index 0000000..38aad6f --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/kernel_task.yaml @@ -0,0 +1,29 @@ +task_id: 15p_RetNet_Retention +description: Kernel task for 15p_RetNet_Retention +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + from functools import partial + + CONFIG = { + 'name': 'retnet_6_7b_retention', + 'model': 'RetNet-6.7B', + 'operator': 'multi_scale_retention', + 'batch': 4, + 'seq_len': 4096, + 'num_heads': 16, + 'head_dim': 256, + 'd_model': 4096, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + keys = jax.random.split(key, 3) + B, S = CONFIG['batch'], CONFIG['seq_len'] + H, D = CONFIG['num_heads'], CONFIG['head_dim'] + query = jax.random.normal(keys[0], (B, H, S, D), dtype=dtype) + key_val = jax.random.normal(keys[1], (B, H, S, D), dtype=dtype) + value = jax.random.normal(keys[2], (B, H, S, D), dtype=dtype) + dynamic_args = [query, key_val, value] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/reference.py new file mode 100644 index 0000000..c8058eb --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/15p_RetNet_Retention/reference.py @@ -0,0 +1,46 @@ +# Imports +import jax +import jax.numpy as jnp +from functools import partial + +# Initialization +def get_inputs(): + CONFIG = { + 'name': 'retnet_6_7b_retention', + 'model': 'RetNet-6.7B', + 'operator': 'multi_scale_retention', + 'batch': 4, + 'seq_len': 4096, + 'num_heads': 16, + 'head_dim': 256, + 'd_model': 4096, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + keys = jax.random.split(key, 3) + B, S = CONFIG['batch'], CONFIG['seq_len'] + H, D = CONFIG['num_heads'], CONFIG['head_dim'] + query = jax.random.normal(keys[0], (B, H, S, D), dtype=dtype) + key_val = jax.random.normal(keys[1], (B, H, S, D), dtype=dtype) + value = jax.random.normal(keys[2], (B, H, S, D), dtype=dtype) + dynamic_args = [query, key_val, value] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(query, key, value): + B, H, S, D = query.shape + gammas = 1.0 - jnp.exp2(-5.0 - jnp.arange(H, dtype=jnp.float32)) + positions = jnp.arange(S, dtype=jnp.float32) + distance = positions[:, None] - positions[None, :] + causal_mask = (distance >= 0).astype(jnp.float32) + log_gamma = jnp.log(gammas) + decay = jnp.exp(log_gamma[:, None, None] * jnp.maximum(distance, 0.0)[None, :, :]) + decay = decay * causal_mask[None, :, :] + qk = jnp.einsum('bhsd,bhtd->bhst', query.astype(jnp.float32), key.astype(jnp.float32)) + qk = qk * decay[None, :, :, :] + retention_sum = jnp.sum(jnp.abs(qk), axis=-1, keepdims=True) + retention_sum = jnp.maximum(retention_sum, 1.0) + qk = qk / retention_sum + output = jnp.einsum('bhst,bhtd->bhsd', qk.astype(query.dtype), value) + return output \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/kernel_task.yaml new file mode 100644 index 0000000..2058c10 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/kernel_task.yaml @@ -0,0 +1,29 @@ +task_id: 16p_Mamba2_SSD +description: Kernel task for 16p_Mamba2_SSD +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'mamba2_2_7b_ssd', + 'model': 'Mamba-2-2.7B', + 'operator': 'state_space_duality', + 'batch': 4, + 'seq_len': 4096, + 'num_heads': 64, + 'head_dim': 64, + 'd_state': 128, + 'd_model': 2560, + } + rng = jax.random.key(42) + keys = jax.random.split(rng, 5) + B, S = CONFIG['batch'], CONFIG['seq_len'] + H, D = CONFIG['num_heads'], CONFIG['head_dim'] + query = jax.random.normal(keys[0], (B, H, S, D), dtype=dtype) + key = jax.random.normal(keys[1], (B, H, S, D), dtype=dtype) + value = jax.random.normal(keys[2], (B, H, S, D), dtype=dtype) + A_log = jax.random.normal(keys[3], (B, H, S), dtype=jnp.float32) * 0.5 - 4.0 + dynamic_args = [query, key, value, A_log] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/reference.py new file mode 100644 index 0000000..cb1c90a --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/16p_Mamba2_SSD/reference.py @@ -0,0 +1,47 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'mamba2_2_7b_ssd', + 'model': 'Mamba-2-2.7B', + 'operator': 'state_space_duality', + 'batch': 4, + 'seq_len': 4096, + 'num_heads': 64, + 'head_dim': 64, + 'd_state': 128, + 'd_model': 2560, + } + rng = jax.random.key(42) + keys = jax.random.split(rng, 5) + B, S = CONFIG['batch'], CONFIG['seq_len'] + H, D = CONFIG['num_heads'], CONFIG['head_dim'] + query = jax.random.normal(keys[0], (B, H, S, D), dtype=dtype) + key = jax.random.normal(keys[1], (B, H, S, D), dtype=dtype) + value = jax.random.normal(keys[2], (B, H, S, D), dtype=dtype) + A_log = jax.random.normal(keys[3], (B, H, S), dtype=jnp.float32) * 0.5 - 4.0 + dynamic_args = [query, key, value, A_log] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(query, key, value, A_log): + B, H, S, D = query.shape + a = jax.nn.sigmoid(A_log.astype(jnp.float32)) + log_a = jnp.log(a + 1e-8) + log_a_cumsum = jnp.cumsum(log_a, axis=-1) + diff = log_a_cumsum[:, :, :, None] - log_a_cumsum[:, :, None, :] + causal = jnp.tril(jnp.ones((S, S), dtype=jnp.bool_)) + L = jnp.exp(jnp.where(causal[None, None, :, :], diff, -1e30)) + scores = jnp.einsum('bhsd,bhtd->bhst', + query.astype(jnp.float32), + key.astype(jnp.float32)) + scores = scores * L + scores_sum = jnp.sum(scores, axis=-1, keepdims=True) + scores_sum = jnp.where(jnp.abs(scores_sum) < 1e-6, 1.0, scores_sum) + scores = scores / jnp.maximum(jnp.abs(scores_sum), 1.0) + output = jnp.einsum('bhst,bhtd->bhsd', scores.astype(query.dtype), value) + return output \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/kernel_task.yaml new file mode 100644 index 0000000..c9312f3 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/kernel_task.yaml @@ -0,0 +1,30 @@ +task_id: 17p_Triangle_Multiplication +description: Kernel task for 17p_Triangle_Multiplication +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'alphafold_768_triangle_mult', + 'model': 'AlphaFold2', + 'operator': 'triangle_mult_outgoing', + 'N': 1536, + 'C': 128, + 'direction': 'outgoing', + } + key = jax.random.key(42) + keys = jax.random.split(key, 9) + N, C = CONFIG['N'], CONFIG['C'] + pair_act = jax.random.normal(keys[0], (N, N, C), dtype=dtype) + mask = jnp.ones((N, N, 1), dtype=dtype) + left_proj = jax.random.normal(keys[1], (C, C), dtype=dtype) * 0.02 + right_proj = jax.random.normal(keys[2], (C, C), dtype=dtype) * 0.02 + left_gate = jax.random.normal(keys[3], (C, C), dtype=dtype) * 0.02 + right_gate = jax.random.normal(keys[4], (C, C), dtype=dtype) * 0.02 + center_scale = jax.random.normal(keys[5], (C,), dtype=dtype) * 0.1 + 1.0 + out_proj = jax.random.normal(keys[6], (C, C), dtype=dtype) * 0.02 + out_gate = jax.random.normal(keys[7], (C, C), dtype=dtype) * 0.02 + dynamic_args = [pair_act, mask, left_proj, right_proj, left_gate, right_gate, center_scale, out_proj, out_gate] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/reference.py new file mode 100644 index 0000000..16d6054 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/17p_Triangle_Multiplication/reference.py @@ -0,0 +1,46 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'alphafold_768_triangle_mult', + 'model': 'AlphaFold2', + 'operator': 'triangle_mult_outgoing', + 'N': 1536, + 'C': 128, + 'direction': 'outgoing', + } + key = jax.random.key(42) + keys = jax.random.split(key, 9) + N, C = CONFIG['N'], CONFIG['C'] + pair_act = jax.random.normal(keys[0], (N, N, C), dtype=dtype) + mask = jnp.ones((N, N, 1), dtype=dtype) + left_proj = jax.random.normal(keys[1], (C, C), dtype=dtype) * 0.02 + right_proj = jax.random.normal(keys[2], (C, C), dtype=dtype) * 0.02 + left_gate = jax.random.normal(keys[3], (C, C), dtype=dtype) * 0.02 + right_gate = jax.random.normal(keys[4], (C, C), dtype=dtype) * 0.02 + center_scale = jax.random.normal(keys[5], (C,), dtype=dtype) * 0.1 + 1.0 + out_proj = jax.random.normal(keys[6], (C, C), dtype=dtype) * 0.02 + out_gate = jax.random.normal(keys[7], (C, C), dtype=dtype) * 0.02 + dynamic_args = [pair_act, mask, left_proj, right_proj, left_gate, right_gate, center_scale, out_proj, out_gate] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(pair_act, mask, left_proj_w, right_proj_w, left_gate_w, right_gate_w, center_scale, out_proj_w, out_gate_w): + act = pair_act * mask + left_proj = jnp.dot(act, left_proj_w) + right_proj = jnp.dot(act, right_proj_w) + left_gate = jax.nn.sigmoid(jnp.dot(act, left_gate_w)) + right_gate = jax.nn.sigmoid(jnp.dot(act, right_gate_w)) + left_proj = left_proj * left_gate + right_proj = right_proj * right_gate + result = jnp.einsum('ikc,jkc->ijc', left_proj, right_proj) + eps = 1e-6 + rms = jnp.sqrt(jnp.mean(result * result, axis=-1, keepdims=True) + eps) + result = result / rms * center_scale + output = jnp.dot(result, out_proj_w) + gate = jax.nn.sigmoid(jnp.dot(pair_act, out_gate_w)) + return output * gate \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/kernel_task.yaml new file mode 100644 index 0000000..d64cdc4 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/kernel_task.yaml @@ -0,0 +1,23 @@ +task_id: 18k_Conv2D_ReLU_BiasAdd +description: Kernel task for 18k_Conv2D_ReLU_BiasAdd +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 128 + in_channels = 64 + out_channels = 128 + kernel_size = 3 + height = width = 128 + + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + x = jax.random.uniform(k1, (batch_size, in_channels, height, width), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + bias = jnp.zeros((out_channels, 1, 1), dtype=dtype) + + dynamic_args = [x, weight, conv_bias, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/reference.py new file mode 100644 index 0000000..ff66c21 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/18k_Conv2D_ReLU_BiasAdd/reference.py @@ -0,0 +1,38 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 128 + in_channels = 64 + out_channels = 128 + kernel_size = 3 + height = width = 128 + + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + x = jax.random.uniform(k1, (batch_size, in_channels, height, width), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + bias = jnp.zeros((out_channels, 1, 1), dtype=dtype) + + dynamic_args = [x, weight, conv_bias, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, conv_bias, bias): + x = jnp.transpose(x, (0, 2, 3, 1)) + kernel = jnp.transpose(weight, (2, 3, 1, 0)) + x = jax.lax.conv_general_dilated( + x, kernel, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + x = x + conv_bias.reshape(1, 1, 1, -1) + x = jax.nn.relu(x) + x = jnp.transpose(x, (0, 3, 1, 2)) + x = x + bias + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/kernel_task.yaml new file mode 100644 index 0000000..67fd686 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/kernel_task.yaml @@ -0,0 +1,21 @@ +task_id: 19k_Matmul_Subtract_Multiply_ReLU +description: Kernel task for 19k_Matmul_Subtract_Multiply_ReLU +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + subtract_value = 2.0 + multiply_value = 1.5 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [subtract_value, multiply_value] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/reference.py new file mode 100644 index 0000000..9bddeff --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/19k_Matmul_Subtract_Multiply_ReLU/reference.py @@ -0,0 +1,28 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + subtract_value = 2.0 + multiply_value = 1.5 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [subtract_value, multiply_value] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, subtract_value, multiply_value): + x = jnp.matmul(x, weight) + bias + x = x - subtract_value + x = x * multiply_value + x = jax.nn.relu(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/kernel_task.yaml new file mode 100644 index 0000000..2cde484 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 1p_Flash_Attention +description: Kernel task for 1p_Flash_Attention +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + B = 4 + S = 4096 + H = 64 + D = 128 + query = jax.random.normal(k1, (B, H, S, D), dtype=dtype) + key_t = jax.random.normal(k2, (B, H, S, D), dtype=dtype) + value = jax.random.normal(k3, (B, H, S, D), dtype=dtype) + dynamic_args = [query, key_t, value] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/reference.py new file mode 100644 index 0000000..21eba34 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/1p_Flash_Attention/reference.py @@ -0,0 +1,30 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + B = 4 + S = 4096 + H = 64 + D = 128 + query = jax.random.normal(k1, (B, H, S, D), dtype=dtype) + key_t = jax.random.normal(k2, (B, H, S, D), dtype=dtype) + value = jax.random.normal(k3, (B, H, S, D), dtype=dtype) + dynamic_args = [query, key_t, value] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(query, key, value): + B, H, S, D = query.shape + scale = D ** -0.5 + attn = jnp.einsum('bhqd,bhkd->bhqk', query, key) * scale + mask = jnp.tril(jnp.ones((S, S))) + attn = jnp.where(mask, attn, -1e9) + attn = jax.nn.softmax(attn, axis=-1) + output = jnp.einsum('bhqk,bhkd->bhqd', attn, value) + return output \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/kernel_task.yaml new file mode 100644 index 0000000..7472a8e --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/kernel_task.yaml @@ -0,0 +1,19 @@ +task_id: 20k_Gemm_Multiply_LeakyReLU +description: Kernel task for 20k_Gemm_Multiply_LeakyReLU +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/reference.py new file mode 100644 index 0000000..ad6749d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/20k_Gemm_Multiply_LeakyReLU/reference.py @@ -0,0 +1,27 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + multiplier = 2.0 + negative_slope = 0.1 + x = jnp.matmul(x, weight) + bias + x = x * multiplier + x = jnp.where(x >= 0, x, x * negative_slope) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/kernel_task.yaml new file mode 100644 index 0000000..e106ba8 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/kernel_task.yaml @@ -0,0 +1,23 @@ +task_id: 21k_Gemm_Divide_Sum_Scaling +description: Kernel task for 21k_Gemm_Divide_Sum_Scaling +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + dtype = jnp.float32 + batch_size = 4096 + input_size = 8192 + hidden_size = 8192 + scaling_factor = 1.5 + + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + x = jax.random.uniform(k1, (batch_size, input_size), dtype=dtype) + weight = jax.random.normal(k2, (input_size, hidden_size), dtype=dtype) + + dynamic_args = [x, weight] + static_args = [scaling_factor] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/reference.py new file mode 100644 index 0000000..24655e3 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/21k_Gemm_Divide_Sum_Scaling/reference.py @@ -0,0 +1,34 @@ +# Imports +import jax +import jax.numpy as jnp +import jax.lax as lax + +# Initialization +def get_inputs(): + dtype = jnp.float32 + batch_size = 4096 + input_size = 8192 + hidden_size = 8192 + scaling_factor = 1.5 + + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + x = jax.random.uniform(k1, (batch_size, input_size), dtype=dtype) + weight = jax.random.normal(k2, (input_size, hidden_size), dtype=dtype) + + dynamic_args = [x, weight] + static_args = [scaling_factor] + + return dynamic_args, static_args + +# Computation +def computation(x, weight, scaling_factor): + x = lax.dot_general( + x, weight.T, + dimension_numbers=(((1,), (0,)), ((), ())), + precision=lax.Precision.HIGHEST + ) + x = x / 2.0 + x = jnp.sum(x, axis=1, keepdims=True) + x = x * scaling_factor + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/kernel_task.yaml new file mode 100644 index 0000000..118f049 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/kernel_task.yaml @@ -0,0 +1,25 @@ +task_id: 22k_Conv2d_InstanceNorm_Divide +description: Kernel task for 22k_Conv2d_InstanceNorm_Divide +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 128 + in_channels = 64 + out_channels = 128 + kernel_size = 3 + divide_by_value = 2.0 + + dtype = jnp.float32 + key = jax.random.key(0) + height = width = 128 + x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + in_weight = jnp.ones(out_channels, dtype=dtype) + in_bias = jnp.zeros(out_channels, dtype=dtype) + + dynamic_args = [x, weight, conv_bias, in_weight, in_bias] + static_args = [divide_by_value] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/reference.py new file mode 100644 index 0000000..63ed040 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/22k_Conv2d_InstanceNorm_Divide/reference.py @@ -0,0 +1,45 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 128 + in_channels = 64 + out_channels = 128 + kernel_size = 3 + divide_by_value = 2.0 + + dtype = jnp.float32 + key = jax.random.key(0) + height = width = 128 + x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + in_weight = jnp.ones(out_channels, dtype=dtype) + in_bias = jnp.zeros(out_channels, dtype=dtype) + + dynamic_args = [x, weight, conv_bias, in_weight, in_bias] + static_args = [divide_by_value] + return dynamic_args, static_args + +# Computation +def computation(x, weight, conv_bias, in_weight, in_bias, divide_by_value): + x_nhwc = jnp.transpose(x, (0, 2, 3, 1)) + kernel = jnp.transpose(weight, (2, 3, 1, 0)) + x = jax.lax.conv_general_dilated( + x_nhwc, kernel, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + x = x + conv_bias.reshape(1, 1, 1, -1) + x = jnp.transpose(x, (0, 3, 1, 2)) + + mean = jnp.mean(x, axis=(2, 3), keepdims=True) + var = jnp.var(x, axis=(2, 3), keepdims=True) + x = (x - mean) / jnp.sqrt(var + 1e-5) + x = x * in_weight.reshape(1, -1, 1, 1) + in_bias.reshape(1, -1, 1, 1) + + x = x / divide_by_value + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/kernel_task.yaml new file mode 100644 index 0000000..ba0118b --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/kernel_task.yaml @@ -0,0 +1,17 @@ +task_id: 23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp +description: Kernel task for 23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/reference.py new file mode 100644 index 0000000..e5dd6f7 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/reference.py @@ -0,0 +1,26 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.matmul(x, weight.T) + bias + x = jnp.sum(x, axis=1, keepdims=True) + x = jnp.max(x, axis=1, keepdims=True) + x = jnp.mean(x, axis=1, keepdims=True) + x = jax.scipy.special.logsumexp(x, axis=1, keepdims=True) + x = jax.scipy.special.logsumexp(x, axis=1, keepdims=True) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/kernel_task.yaml new file mode 100644 index 0000000..c5bbd6d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/kernel_task.yaml @@ -0,0 +1,23 @@ +task_id: 24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish +description: Kernel task for 24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + input_size = 8192 + hidden_size = 8192 + scale_factor = 2.0 + clamp_min = -10.0 + clamp_max = 10.0 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + weight = jnp.zeros((input_size, hidden_size), dtype=dtype) + bias = jnp.zeros(hidden_size, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [scale_factor, clamp_min, clamp_max] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/reference.py new file mode 100644 index 0000000..2931b42 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/24k_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish/reference.py @@ -0,0 +1,34 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + input_size = 8192 + hidden_size = 8192 + scale_factor = 2.0 + clamp_min = -10.0 + clamp_max = 10.0 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + weight = jnp.zeros((input_size, hidden_size), dtype=dtype) + bias = jnp.zeros(hidden_size, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [scale_factor, clamp_min, clamp_max] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, scale_factor, clamp_min, clamp_max): + x = jnp.matmul(x, weight.T) + bias + x = x * scale_factor + x = x + x + x = jnp.clip(x, clamp_min, clamp_max) + x = jax.scipy.special.logsumexp(x, axis=1, keepdims=True) + softplus_x = jnp.logaddexp(x, 0.0) + mish_x = x * jnp.tanh(softplus_x) + x = x * mish_x + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/kernel_task.yaml new file mode 100644 index 0000000..79d3f37 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/kernel_task.yaml @@ -0,0 +1,21 @@ +task_id: 25k_Conv3d_GroupNorm_Mean +description: Kernel task for 25k_Conv3d_GroupNorm_Mean +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + key = jax.random.key(0) + batch_size = 128 + in_channels = 3 + out_channels = 24 + kernel_size = 3 + D, H, W = 24, 32, 32 + x = jax.random.uniform(key, (batch_size, in_channels, D, H, W), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + gamma = jnp.ones(out_channels, dtype=dtype) + beta = jnp.zeros(out_channels, dtype=dtype) + dynamic_args = [x, weight, conv_bias, gamma, beta] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/reference.py new file mode 100644 index 0000000..057d346 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/25k_Conv3d_GroupNorm_Mean/reference.py @@ -0,0 +1,44 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + key = jax.random.key(0) + batch_size = 128 + in_channels = 3 + out_channels = 24 + kernel_size = 3 + D, H, W = 24, 32, 32 + x = jax.random.uniform(key, (batch_size, in_channels, D, H, W), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + gamma = jnp.ones(out_channels, dtype=dtype) + beta = jnp.zeros(out_channels, dtype=dtype) + dynamic_args = [x, weight, conv_bias, gamma, beta] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, conv_bias, gamma, beta): + num_groups = 8 + x = jnp.transpose(x, (0, 2, 3, 4, 1)) + kernel = jnp.transpose(weight, (2, 3, 4, 1, 0)) + x = jax.lax.conv_general_dilated( + x, kernel, + window_strides=(1, 1, 1), + padding='VALID', + dimension_numbers=('NDHWC', 'DHWIO', 'NDHWC') + ) + x = x + conv_bias.reshape(1, 1, 1, 1, -1) + x = jnp.transpose(x, (0, 4, 1, 2, 3)) + N, C, D, H, W = x.shape + G = num_groups + x = x.reshape(N, G, C // G, D, H, W) + mean = jnp.mean(x, axis=(2, 3, 4, 5), keepdims=True) + var = jnp.var(x, axis=(2, 3, 4, 5), keepdims=True) + x = (x - mean) / jnp.sqrt(var + 1e-5) + x = x.reshape(N, C, D, H, W) + x = x * gamma.reshape(1, -1, 1, 1, 1) + beta.reshape(1, -1, 1, 1, 1) + x = jnp.mean(x, axis=(1, 2, 3, 4)) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/kernel_task.yaml new file mode 100644 index 0000000..1b19a95 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply +description: Kernel task for 26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + batch_size, in_features, out_features = 4096, 8192, 8192 + x = jax.random.uniform(k1, (batch_size, in_features), dtype=dtype) + y = jax.random.uniform(k2, (batch_size, out_features), dtype=dtype) + bmm_weight = jnp.zeros((out_features, in_features), dtype=dtype) + bmm_bias = jnp.zeros(out_features, dtype=dtype) + in_weight = jnp.ones(out_features, dtype=dtype) + in_bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, y, bmm_weight, bmm_bias, in_weight, in_bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/reference.py new file mode 100644 index 0000000..ac3e66d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/26k_BMM_InstanceNorm_Sum_ResidualAdd_Multiply/reference.py @@ -0,0 +1,33 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + batch_size, in_features, out_features = 4096, 8192, 8192 + x = jax.random.uniform(k1, (batch_size, in_features), dtype=dtype) + y = jax.random.uniform(k2, (batch_size, out_features), dtype=dtype) + bmm_weight = jnp.zeros((out_features, in_features), dtype=dtype) + bmm_bias = jnp.zeros(out_features, dtype=dtype) + in_weight = jnp.ones(out_features, dtype=dtype) + in_bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, y, bmm_weight, bmm_bias, in_weight, in_bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, y, bmm_weight, bmm_bias, in_weight, in_bias): + eps = 1e-5 + x = x @ bmm_weight.T + bmm_bias + x = jnp.expand_dims(jnp.expand_dims(x, 2), 3) + mean = jnp.mean(x, axis=(2, 3), keepdims=True) + var = jnp.var(x, axis=(2, 3), keepdims=True) + x = (x - mean) / jnp.sqrt(var + eps) + x = x * jnp.reshape(in_weight, (1, -1, 1, 1)) + jnp.reshape(in_bias, (1, -1, 1, 1)) + x = jnp.squeeze(jnp.squeeze(x, axis=3), axis=2) + x = x + y + x = x * y + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/kernel_task.yaml new file mode 100644 index 0000000..f5c3877 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/kernel_task.yaml @@ -0,0 +1,17 @@ +task_id: 27k_Matmul_Mish_Mish +description: Kernel task for 27k_Matmul_Mish_Mish +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/reference.py new file mode 100644 index 0000000..5e9c8c0 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/27k_Matmul_Mish_Mish/reference.py @@ -0,0 +1,23 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = x @ weight + bias + x = x * jnp.tanh(jax.nn.softplus(x)) + x = x * jnp.tanh(jax.nn.softplus(x)) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/kernel_task.yaml new file mode 100644 index 0000000..73ccc95 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/kernel_task.yaml @@ -0,0 +1,31 @@ +task_id: 28k_ConvTranspose3d_LayerNorm_GELU_Scaling +description: Kernel task for 28k_ConvTranspose3d_LayerNorm_GELU_Scaling +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + import jax.lax as lax + + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + batch_size = 32 + in_channels = 32 + out_channels = 64 + kernel_size_val = 4 + D, H, W = 16, 32, 32 + + x = jax.random.uniform(k1, (batch_size, in_channels, D, H, W), dtype=dtype) + conv_weight = jax.random.normal(k2, (in_channels, out_channels, kernel_size_val, kernel_size_val, kernel_size_val), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + ln_weight = jnp.ones(out_channels, dtype=dtype) + ln_bias = jnp.zeros(out_channels, dtype=dtype) + + stride_val = 2 + padding_val = 1 + eps_val = 1e-5 + scaling_factor_val = 1.0 + + dynamic_args = [x, conv_weight, conv_bias, ln_weight, ln_bias] + static_args = [stride_val, padding_val, kernel_size_val, eps_val, scaling_factor_val] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/reference.py new file mode 100644 index 0000000..0b6ca41 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/28k_ConvTranspose3d_LayerNorm_GELU_Scaling/reference.py @@ -0,0 +1,68 @@ +# Imports +import jax +import jax.numpy as jnp +import jax.lax as lax + +# Initialization +def get_inputs(dtype=jnp.float32): + key = jax.random.key(0) + k1, k2 = jax.random.split(key) + batch_size = 32 + in_channels = 32 + out_channels = 64 + kernel_size_val = 4 + D, H, W = 16, 32, 32 + + x = jax.random.uniform(k1, (batch_size, in_channels, D, H, W), dtype=dtype) + conv_weight = jax.random.normal(k2, (in_channels, out_channels, kernel_size_val, kernel_size_val, kernel_size_val), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + ln_weight = jnp.ones(out_channels, dtype=dtype) + ln_bias = jnp.zeros(out_channels, dtype=dtype) + + stride_val = 2 + padding_val = 1 + eps_val = 1e-5 + scaling_factor_val = 1.0 + + dynamic_args = [x, conv_weight, conv_bias, ln_weight, ln_bias] + static_args = [stride_val, padding_val, kernel_size_val, eps_val, scaling_factor_val] + + return dynamic_args, static_args + +# Computation +def computation(x, conv_weight, conv_bias, ln_weight, ln_bias, stride, padding, kernel_size, eps, scaling_factor): + x = jnp.transpose(x, (0, 2, 3, 4, 1)) + kernel = jnp.transpose(conv_weight, (2, 3, 4, 1, 0)) + kernel = jnp.flip(kernel, axis=(0, 1, 2)) + + batch_size, d_in, h_in, w_in, channels = x.shape + k = kernel_size + + d_dilated = d_in + (d_in - 1) * (stride - 1) + h_dilated = h_in + (h_in - 1) * (stride - 1) + w_dilated = w_in + (w_in - 1) * (stride - 1) + x_dilated = jnp.zeros((batch_size, d_dilated, h_dilated, w_dilated, channels), dtype=x.dtype) + x_dilated = x_dilated.at[:, ::stride, ::stride, ::stride, :].set(x) + x = x_dilated + + pad = k - 1 - padding + jax_padding = ((pad, pad), (pad, pad), (pad, pad)) + + x = lax.conv_general_dilated( + x, kernel, + window_strides=(1, 1, 1), + padding=jax_padding, + dimension_numbers=('NDHWC', 'DHWOI', 'NDHWC') + ) + x = x + conv_bias.reshape(1, 1, 1, 1, -1) + + x = jnp.transpose(x, (0, 4, 1, 2, 3)) + + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.mean((x - mean) ** 2, axis=-1, keepdims=True) + x = (x - mean) / jnp.sqrt(var + eps) + x = x * ln_weight + ln_bias + + x = jax.nn.gelu(x) + x = x * scaling_factor + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/kernel_task.yaml new file mode 100644 index 0000000..72bc7da --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 29k_Matmul_Swish_Sum_GroupNorm +description: Kernel task for 29k_Matmul_Swish_Sum_GroupNorm +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + key = jax.random.key(0) + batch_size = 8192 + in_features = 4096 + out_features = 4096 + num_groups = 64 + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + gn_weight = jnp.ones(out_features, dtype=dtype) + gn_bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias, gn_weight, gn_bias] + static_args = [num_groups, out_features] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/reference.py new file mode 100644 index 0000000..d8f5e3b --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/29k_Matmul_Swish_Sum_GroupNorm/reference.py @@ -0,0 +1,33 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + key = jax.random.key(0) + batch_size = 8192 + in_features = 4096 + out_features = 4096 + num_groups = 64 + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + gn_weight = jnp.ones(out_features, dtype=dtype) + gn_bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias, gn_weight, gn_bias] + static_args = [num_groups, out_features] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, gn_weight, gn_bias, num_groups, out_features): + x = jnp.matmul(x, weight) + x = jax.nn.sigmoid(x) * x + x = x + bias + group_size = out_features // num_groups + x = x.reshape(-1, num_groups, group_size) + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + x = (x - mean) / jnp.sqrt(var + 1e-5) + x = x.reshape(-1, out_features) + x = x * gn_weight + gn_bias + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/kernel_task.yaml new file mode 100644 index 0000000..1a77447 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/kernel_task.yaml @@ -0,0 +1,29 @@ +task_id: 2p_GQA_Attention +description: Kernel task for 2p_GQA_Attention +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + from functools import partial + + CONFIG = { + 'name': 'llama3_405b_gqa', + 'model': 'Llama-3.1-405B', + 'operator': 'gqa_attention', + 'batch': 4, + 'seq_len': 4096, + 'num_query_heads': 128, + 'num_kv_heads': 8, + 'head_dim': 128, + 'emb_dim': 16384, + } + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + B, S = CONFIG['batch'], CONFIG['seq_len'] + Hq, Hkv, D = CONFIG['num_query_heads'], CONFIG['num_kv_heads'], CONFIG['head_dim'] + query = jax.random.normal(k1, (B, S, Hq, D), dtype=dtype) + key_t = jax.random.normal(k2, (B, S, Hkv, D), dtype=dtype) + value = jax.random.normal(k3, (B, S, Hkv, D), dtype=dtype) + dynamic_args = [query, key_t, value] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/reference.py new file mode 100644 index 0000000..2722104 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/2p_GQA_Attention/reference.py @@ -0,0 +1,46 @@ +# Imports +import jax +import jax.numpy as jnp +from functools import partial + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'llama3_405b_gqa', + 'model': 'Llama-3.1-405B', + 'operator': 'gqa_attention', + 'batch': 4, + 'seq_len': 4096, + 'num_query_heads': 128, + 'num_kv_heads': 8, + 'head_dim': 128, + 'emb_dim': 16384, + } + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + B, S = CONFIG['batch'], CONFIG['seq_len'] + Hq, Hkv, D = CONFIG['num_query_heads'], CONFIG['num_kv_heads'], CONFIG['head_dim'] + query = jax.random.normal(k1, (B, S, Hq, D), dtype=dtype) + key_t = jax.random.normal(k2, (B, S, Hkv, D), dtype=dtype) + value = jax.random.normal(k3, (B, S, Hkv, D), dtype=dtype) + dynamic_args = [query, key_t, value] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(query, key, value): + B, S, Hq, D = query.shape + Hkv = key.shape[2] + G = Hq // Hkv + key = jnp.repeat(key[:, :, :, None, :], G, axis=3).reshape(B, S, Hq, D) + value = jnp.repeat(value[:, :, :, None, :], G, axis=3).reshape(B, S, Hq, D) + q = query.transpose(0, 2, 1, 3) + k = key.transpose(0, 2, 1, 3) + v = value.transpose(0, 2, 1, 3) + scale = D ** -0.5 + attn = jnp.einsum('bhqd,bhkd->bhqk', q, k) * scale + mask = jnp.tril(jnp.ones((S, S))) + attn = jnp.where(mask, attn, -1e9) + attn = jax.nn.softmax(attn, axis=-1) + out = jnp.einsum('bhqk,bhkd->bhqd', attn, v) + return out.transpose(0, 2, 1, 3) \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/kernel_task.yaml new file mode 100644 index 0000000..9f00a82 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/kernel_task.yaml @@ -0,0 +1,17 @@ +task_id: 30k_Matmul_Scaling_ResidualAdd +description: Kernel task for 30k_Matmul_Scaling_ResidualAdd +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 16384 + in_features = 4096 + out_features = 4096 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/reference.py new file mode 100644 index 0000000..9ffd75d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/30k_Matmul_Scaling_ResidualAdd/reference.py @@ -0,0 +1,25 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 16384 + in_features = 4096 + out_features = 4096 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + scaling_factor = 0.5 + x = jnp.matmul(x, weight) + bias + original_x = x + x = x * scaling_factor + x = x + original_x + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/kernel_task.yaml new file mode 100644 index 0000000..5f040ba --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 31k_Gemm_BatchNorm_GELU_ReLU +description: Kernel task for 31k_Gemm_BatchNorm_GELU_ReLU +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 16384 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + k1, k2, k3 = jax.random.split(key, 3) + x = jax.random.uniform(k1, (batch_size, in_features), dtype=dtype) + gemm_weight = jax.random.normal(k2, (out_features, in_features), dtype=dtype) + gemm_bias = jax.random.normal(k3, (out_features,), dtype=dtype) + bn_weight = jnp.ones(out_features, dtype=dtype) + bn_bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, gemm_weight, gemm_bias, bn_weight, bn_bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/reference.py new file mode 100644 index 0000000..c585b6f --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/31k_Gemm_BatchNorm_GELU_ReLU/reference.py @@ -0,0 +1,30 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 16384 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + k1, k2, k3 = jax.random.split(key, 3) + x = jax.random.uniform(k1, (batch_size, in_features), dtype=dtype) + gemm_weight = jax.random.normal(k2, (out_features, in_features), dtype=dtype) + gemm_bias = jax.random.normal(k3, (out_features,), dtype=dtype) + bn_weight = jnp.ones(out_features, dtype=dtype) + bn_bias = jnp.zeros(out_features, dtype=dtype) + dynamic_args = [x, gemm_weight, gemm_bias, bn_weight, bn_bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, gemm_weight, gemm_bias, bn_weight, bn_bias): + eps = 1e-5 + x = jnp.matmul(x, gemm_weight.T) + gemm_bias + mean = jnp.mean(x, axis=0, keepdims=True) + var = jnp.mean((x - mean) ** 2, axis=0, keepdims=True) + x = (x - mean) / jnp.sqrt(var + eps) * bn_weight + bn_bias + x = jax.nn.gelu(x) + x = jax.nn.relu(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/kernel_task.yaml new file mode 100644 index 0000000..0a4a64e --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 32k_Gemm_Sigmoid_LogSumExp +description: Kernel task for 32k_Gemm_Sigmoid_LogSumExp +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 16384 + input_size = 2048 + hidden_size = 4096 + output_size = 1024 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + w1 = jnp.zeros((hidden_size, input_size), dtype=dtype) + b1 = jnp.zeros(hidden_size, dtype=dtype) + w2 = jnp.zeros((output_size, hidden_size), dtype=dtype) + b2 = jnp.zeros(output_size, dtype=dtype) + dynamic_args = [x, w1, b1, w2, b2] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/reference.py new file mode 100644 index 0000000..9507242 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/32k_Gemm_Sigmoid_LogSumExp/reference.py @@ -0,0 +1,27 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 16384 + input_size = 2048 + hidden_size = 4096 + output_size = 1024 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + w1 = jnp.zeros((hidden_size, input_size), dtype=dtype) + b1 = jnp.zeros(hidden_size, dtype=dtype) + w2 = jnp.zeros((output_size, hidden_size), dtype=dtype) + b2 = jnp.zeros(output_size, dtype=dtype) + dynamic_args = [x, w1, b1, w2, b2] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, w1, b1, w2, b2): + x = jnp.matmul(x, w1.T) + b1 + x = jax.nn.sigmoid(x) + x = jnp.matmul(x, w2.T) + b2 + x = jax.scipy.special.logsumexp(x, axis=1) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/kernel_task.yaml new file mode 100644 index 0000000..7551edd --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/kernel_task.yaml @@ -0,0 +1,21 @@ +task_id: 33k_Conv3d_Mish_Tanh +description: Kernel task for 33k_Conv3d_Mish_Tanh +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 16 + in_channels = 32 + out_channels = 64 + kernel_size = 3 + D, H, W = 32, 64, 64 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_channels, D, H, W), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size, kernel_size), dtype=dtype) + bias = jnp.zeros(out_channels, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/reference.py new file mode 100644 index 0000000..97ce080 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/33k_Conv3d_Mish_Tanh/reference.py @@ -0,0 +1,36 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 16 + in_channels = 32 + out_channels = 64 + kernel_size = 3 + D, H, W = 32, 64, 64 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_channels, D, H, W), dtype=dtype) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size, kernel_size), dtype=dtype) + bias = jnp.zeros(out_channels, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.transpose(x, (0, 2, 3, 4, 1)) + kernel = jnp.transpose(weight, (2, 3, 4, 1, 0)) + x = jax.lax.conv_general_dilated( + x, kernel, + window_strides=(1, 1, 1), + padding=((0, 0), (0, 0), (0, 0)), + dimension_numbers=('NDHWC', 'DHWIO', 'NDHWC') + ) + x = x + bias.reshape(1, 1, 1, 1, -1) + x = x * jnp.tanh(jnp.log(1 + jnp.exp(x))) + x = jnp.tanh(x) + x = jnp.transpose(x, (0, 4, 1, 2, 3)) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/kernel_task.yaml new file mode 100644 index 0000000..8407d45 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/kernel_task.yaml @@ -0,0 +1,26 @@ +task_id: 34k_Conv2d_Activation_BatchNorm +description: Kernel task for 34k_Conv2d_Activation_BatchNorm +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + import jax.lax as lax + + batch_size = 64 + in_channels = 64 + out_channels = 128 + kernel_size = 3 + height = 128 + width = 128 + + key = jax.random.key(0) + k1, k2, k3 = jax.random.split(key, 3) + x = jax.random.uniform(k1, (batch_size, in_channels, height, width), dtype=dtype) + conv_weight = jax.random.normal(k2, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jax.random.normal(k3, (out_channels,), dtype=dtype) + bn_weight = jnp.ones(out_channels, dtype=dtype) + bn_bias = jnp.zeros(out_channels, dtype=dtype) + + dynamic_args = [x, conv_weight, conv_bias, bn_weight, bn_bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/reference.py new file mode 100644 index 0000000..e82df68 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/34k_Conv2d_Activation_BatchNorm/reference.py @@ -0,0 +1,49 @@ +# Imports +import jax +import jax.numpy as jnp +import jax.lax as lax + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 64 + in_channels = 64 + out_channels = 128 + kernel_size = 3 + height = 128 + width = 128 + + key = jax.random.key(0) + k1, k2, k3 = jax.random.split(key, 3) + x = jax.random.uniform(k1, (batch_size, in_channels, height, width), dtype=dtype) + conv_weight = jax.random.normal(k2, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jax.random.normal(k3, (out_channels,), dtype=dtype) + bn_weight = jnp.ones(out_channels, dtype=dtype) + bn_bias = jnp.zeros(out_channels, dtype=dtype) + + dynamic_args = [x, conv_weight, conv_bias, bn_weight, bn_bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, conv_weight, conv_bias, bn_weight, bn_bias): + eps = 1e-5 + x = jnp.transpose(x, (0, 2, 3, 1)) + weight = jnp.transpose(conv_weight, (2, 3, 1, 0)) + x = lax.conv_general_dilated( + x, weight, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + x = x + conv_bias.reshape(1, 1, 1, -1) + x = jnp.transpose(x, (0, 3, 1, 2)) + + softplus_x = jax.nn.softplus(x) + x = jnp.multiply(jnp.tanh(softplus_x), x) + + mean = jnp.mean(x, axis=(0, 2, 3), keepdims=True) + var = jnp.mean((x - mean) ** 2, axis=(0, 2, 3), keepdims=True) + w = bn_weight.reshape(1, -1, 1, 1) + b = bn_bias.reshape(1, -1, 1, 1) + x = (x - mean) / jnp.sqrt(var + eps) * w + b + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/kernel_task.yaml new file mode 100644 index 0000000..559c7c1 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/kernel_task.yaml @@ -0,0 +1,33 @@ +task_id: 35k_Gemm_Scaling_Hardtanh_GELU +description: Kernel task for 35k_Gemm_Scaling_Hardtanh_GELU +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': '53_Gemm_Scaling_Hardtanh_GELU', + 'batch_size': 4096, + 'in_features': 8192, + 'out_features': 8192, + 'scaling_factor': 0.5, + 'hardtanh_min': -2, + 'hardtanh_max': 2, + } + + batch_size = CONFIG['batch_size'] + in_features = CONFIG['in_features'] + out_features = CONFIG['out_features'] + scaling_factor = CONFIG['scaling_factor'] + hardtanh_min = CONFIG['hardtanh_min'] + hardtanh_max = CONFIG['hardtanh_max'] + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [scaling_factor, hardtanh_min, hardtanh_max] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/reference.py new file mode 100644 index 0000000..5cb7e83 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/35k_Gemm_Scaling_Hardtanh_GELU/reference.py @@ -0,0 +1,40 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + CONFIG = { + 'name': '53_Gemm_Scaling_Hardtanh_GELU', + 'batch_size': 4096, + 'in_features': 8192, + 'out_features': 8192, + 'scaling_factor': 0.5, + 'hardtanh_min': -2, + 'hardtanh_max': 2, + } + + batch_size = CONFIG['batch_size'] + in_features = CONFIG['in_features'] + out_features = CONFIG['out_features'] + scaling_factor = CONFIG['scaling_factor'] + hardtanh_min = CONFIG['hardtanh_min'] + hardtanh_max = CONFIG['hardtanh_max'] + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [scaling_factor, hardtanh_min, hardtanh_max] + + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, scaling_factor, hardtanh_min, hardtanh_max): + x = jnp.matmul(x, weight) + bias + x = x * scaling_factor + x = jnp.clip(x, hardtanh_min, hardtanh_max) + x = x * 0.5 * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3))) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/kernel_task.yaml new file mode 100644 index 0000000..1838bd6 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 36k_Matmul_Sigmoid_Sum +description: Kernel task for 36k_Matmul_Sigmoid_Sum +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + input_size = 8192 + hidden_size = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + weight = jnp.zeros((input_size, hidden_size), dtype=dtype) + bias = jnp.zeros(hidden_size, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/reference.py new file mode 100644 index 0000000..763e3e8 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/36k_Matmul_Sigmoid_Sum/reference.py @@ -0,0 +1,26 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + input_size = 8192 + hidden_size = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + weight = jnp.zeros((input_size, hidden_size), dtype=dtype) + bias = jnp.zeros(hidden_size, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.matmul(x, weight) + bias + x = jax.nn.sigmoid(x) + x = jnp.sum(x, axis=1, keepdims=True) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/kernel_task.yaml new file mode 100644 index 0000000..7f719fe --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/kernel_task.yaml @@ -0,0 +1,19 @@ +task_id: 37k_Matmul_Swish_Scaling +description: Kernel task for 37k_Matmul_Swish_Scaling +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/reference.py new file mode 100644 index 0000000..b330f03 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/37k_Matmul_Swish_Scaling/reference.py @@ -0,0 +1,25 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.matmul(x, weight) + bias + x = x * jax.nn.sigmoid(x) + x = x * 2.0 + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/kernel_task.yaml new file mode 100644 index 0000000..f3c9fd4 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/kernel_task.yaml @@ -0,0 +1,21 @@ +task_id: 38k_Matmul_Dropout_Softmax +description: Kernel task for 38k_Matmul_Dropout_Softmax +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((out_features, in_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/reference.py new file mode 100644 index 0000000..db8b916 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/38k_Matmul_Dropout_Softmax/reference.py @@ -0,0 +1,26 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((out_features, in_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = x @ weight.T + bias + x = jax.nn.softmax(x, axis=1) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/kernel_task.yaml new file mode 100644 index 0000000..ac872b8 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/kernel_task.yaml @@ -0,0 +1,16 @@ +task_id: 39k_Conv2d_GELU_GlobalAvgPool +description: Kernel task for 39k_Conv2d_GELU_GlobalAvgPool +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + key = jax.random.key(0) + batch_size, in_channels, out_channels, kernel_size = 128, 8, 64, 3 + height, width = 256, 256 + x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=jnp.float32) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=jnp.float32) + bias = jnp.zeros(out_channels, dtype=jnp.float32) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/reference.py new file mode 100644 index 0000000..0a6f2e6 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/39k_Conv2d_GELU_GlobalAvgPool/reference.py @@ -0,0 +1,30 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + key = jax.random.key(0) + batch_size, in_channels, out_channels, kernel_size = 128, 8, 64, 3 + height, width = 256, 256 + x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=jnp.float32) + weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=jnp.float32) + bias = jnp.zeros(out_channels, dtype=jnp.float32) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.transpose(x, (0, 2, 3, 1)) + kernel = jnp.transpose(weight, (2, 3, 1, 0)) + x = jax.lax.conv_general_dilated( + x, kernel, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=('NHWC', 'HWIO', 'NHWC') + ) + x = x + bias.reshape(1, 1, 1, -1) + x = jax.nn.gelu(x) + x = jnp.mean(x, axis=(1, 2)) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/kernel_task.yaml new file mode 100644 index 0000000..b852c3d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/kernel_task.yaml @@ -0,0 +1,41 @@ +task_id: 3p_MLA_Attention +description: Kernel task for 3p_MLA_Attention +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + from functools import partial + + CONFIG = { + 'name': 'deepseek_v3_mla', + 'model': 'DeepSeek-V3-671B', + 'operator': 'mla_attention', + 'batch': 4, + 'seq_len': 2048, + 'emb_dim': 7168, + 'num_heads': 128, + 'q_lora_rank': 1536, + 'kv_lora_rank': 512, + 'qk_nope_head_dim': 128, + 'qk_rope_head_dim': 64, + 'v_head_dim': 128, + 'rope_theta': 10000, + } + key = jax.random.key(42) + keys = jax.random.split(key, 8) + C = CONFIG + B, S, E = C['batch'], C['seq_len'], C['emb_dim'] + H = C['num_heads'] + ql, kvl = C['q_lora_rank'], C['kv_lora_rank'] + nope, rope, vd = C['qk_nope_head_dim'], C['qk_rope_head_dim'], C['v_head_dim'] + x = jax.random.normal(keys[0], (B, S, E), dtype=dtype) + q_down = jax.random.normal(keys[1], (E, ql), dtype=dtype) * 0.02 + q_up = jax.random.normal(keys[2], (ql, H * (nope + rope)), dtype=dtype) * 0.02 + kv_down = jax.random.normal(keys[3], (E, kvl + rope), dtype=dtype) * 0.02 + k_up = jax.random.normal(keys[4], (kvl, H * nope), dtype=dtype) * 0.02 + v_up = jax.random.normal(keys[5], (kvl, H * vd), dtype=dtype) * 0.02 + o_proj = jax.random.normal(keys[6], (H * vd, E), dtype=dtype) * 0.02 + + dynamic_args = [x, q_down, q_up, kv_down, k_up, v_up, o_proj] + static_args = [H, nope, rope, vd, kvl, C['rope_theta']] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/reference.py new file mode 100644 index 0000000..0c40c1e --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/3p_MLA_Attention/reference.py @@ -0,0 +1,85 @@ +# Imports +import jax +import jax.numpy as jnp +from functools import partial + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'deepseek_v3_mla', + 'model': 'DeepSeek-V3-671B', + 'operator': 'mla_attention', + 'batch': 4, + 'seq_len': 2048, + 'emb_dim': 7168, + 'num_heads': 128, + 'q_lora_rank': 1536, + 'kv_lora_rank': 512, + 'qk_nope_head_dim': 128, + 'qk_rope_head_dim': 64, + 'v_head_dim': 128, + 'rope_theta': 10000, + } + key = jax.random.key(42) + keys = jax.random.split(key, 8) + C = CONFIG + B, S, E = C['batch'], C['seq_len'], C['emb_dim'] + H = C['num_heads'] + ql, kvl = C['q_lora_rank'], C['kv_lora_rank'] + nope, rope, vd = C['qk_nope_head_dim'], C['qk_rope_head_dim'], C['v_head_dim'] + x = jax.random.normal(keys[0], (B, S, E), dtype=dtype) + q_down = jax.random.normal(keys[1], (E, ql), dtype=dtype) * 0.02 + q_up = jax.random.normal(keys[2], (ql, H * (nope + rope)), dtype=dtype) * 0.02 + kv_down = jax.random.normal(keys[3], (E, kvl + rope), dtype=dtype) * 0.02 + k_up = jax.random.normal(keys[4], (kvl, H * nope), dtype=dtype) * 0.02 + v_up = jax.random.normal(keys[5], (kvl, H * vd), dtype=dtype) * 0.02 + o_proj = jax.random.normal(keys[6], (H * vd, E), dtype=dtype) * 0.02 + + dynamic_args = [x, q_down, q_up, kv_down, k_up, v_up, o_proj] + static_args = [H, nope, rope, vd, kvl, C['rope_theta']] + return dynamic_args, static_args + +# Computation +def _compute_rope(head_dim, seq_len, theta, dtype): + freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)) + pos = jnp.arange(seq_len, dtype=jnp.float32) + angles = jnp.outer(pos, freqs) + return jnp.cos(angles).astype(dtype), jnp.sin(angles).astype(dtype) + +def _apply_rope(x, cos, sin): + x1, x2 = x[..., ::2], x[..., 1::2] + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1) + return rotated.reshape(x.shape) + +def computation(x, q_down_proj, q_up_proj, kv_down_proj, k_up_proj, v_up_proj, o_proj, + num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank, rope_theta): + B, S, E = x.shape + H = num_heads + nope = qk_nope_head_dim + rope = qk_rope_head_dim + vd = v_head_dim + kvl = kv_lora_rank + q = jnp.dot(jnp.dot(x, q_down_proj), q_up_proj) + q = q.reshape(B, S, H, nope + rope) + q_nope, q_rope = q[..., :nope], q[..., nope:] + kv = jnp.dot(x, kv_down_proj) + k_latent, k_rope_raw = kv[..., :kvl], kv[..., kvl:] + k_nope = jnp.dot(k_latent, k_up_proj).reshape(B, S, H, nope) + cos, sin = _compute_rope(rope, S, rope_theta, x.dtype) + k_rope = jnp.broadcast_to(k_rope_raw[:, :, None, :], (B, S, H, rope)) + q_rope = _apply_rope(q_rope, cos, sin) + k_rope = _apply_rope(k_rope, cos, sin) + v = jnp.dot(k_latent, v_up_proj).reshape(B, S, H, vd) + q_full = jnp.concatenate([q_nope, q_rope], axis=-1).transpose(0, 2, 1, 3) + k_full = jnp.concatenate([k_nope, k_rope], axis=-1).transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + hd = nope + rope + attn = jnp.einsum('bhqd,bhkd->bhqk', q_full, k_full) * (hd ** -0.5) + mask = jnp.tril(jnp.ones((S, S))) + attn = jnp.where(mask, attn, -1e9) + attn = jax.nn.softmax(attn, axis=-1) + out = jnp.einsum('bhqk,bhkd->bhqd', attn, v) + out = out.transpose(0, 2, 1, 3).reshape(B, S, H * vd) + return jnp.dot(out, o_proj) \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/kernel_task.yaml new file mode 100644 index 0000000..660778c --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/kernel_task.yaml @@ -0,0 +1,18 @@ +task_id: 40k_Gemm_GroupNorm_Min_BiasAdd +description: Kernel task for 40k_Gemm_GroupNorm_Min_BiasAdd +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + key = jax.random.key(0) + batch_size, in_features, out_features, num_groups = 4096, 8192, 8192, 512 + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((out_features, in_features), dtype=dtype) + linear_bias = jnp.zeros(out_features, dtype=dtype) + gn_weight = jnp.ones(out_features, dtype=dtype) + gn_bias = jnp.zeros(out_features, dtype=dtype) + bias = jnp.zeros((1, out_features, 1, 1), dtype=dtype) + dynamic_args = [x, weight, linear_bias, gn_weight, gn_bias, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/reference.py new file mode 100644 index 0000000..d8b49f7 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/40k_Gemm_GroupNorm_Min_BiasAdd/reference.py @@ -0,0 +1,35 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + key = jax.random.key(0) + batch_size, in_features, out_features, num_groups = 4096, 8192, 8192, 512 + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((out_features, in_features), dtype=dtype) + linear_bias = jnp.zeros(out_features, dtype=dtype) + gn_weight = jnp.ones(out_features, dtype=dtype) + gn_bias = jnp.zeros(out_features, dtype=dtype) + bias = jnp.zeros((1, out_features, 1, 1), dtype=dtype) + dynamic_args = [x, weight, linear_bias, gn_weight, gn_bias, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, linear_bias, gn_weight, gn_bias, bias): + num_groups = 512 + eps = 1e-5 + x = jnp.matmul(x, weight.T) + linear_bias + N, C = x.shape + G = num_groups + x = x.reshape(N, G, C // G) + mean = jnp.mean(x, axis=2, keepdims=True) + var = jnp.var(x, axis=2, keepdims=True) + x = (x - mean) / jnp.sqrt(var + eps) + x = x.reshape(N, C) + x = x * gn_weight + gn_bias + x = jnp.min(x, axis=1, keepdims=True) + x = x.reshape(1, 1, N, 1) + x = x + bias + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/kernel_task.yaml new file mode 100644 index 0000000..416b18b --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/kernel_task.yaml @@ -0,0 +1,19 @@ +task_id: 41k_Gemm_Add_ReLU +description: Kernel task for 41k_Gemm_Add_ReLU +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/reference.py new file mode 100644 index 0000000..3d0e8fe --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/41k_Gemm_Add_ReLU/reference.py @@ -0,0 +1,25 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.matmul(x, weight) + x = x + bias + x = jax.nn.relu(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/kernel_task.yaml new file mode 100644 index 0000000..78e977d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/kernel_task.yaml @@ -0,0 +1,22 @@ +task_id: 42k_Gemm_Max_Subtract_GELU +description: Kernel task for 42k_Gemm_Max_Subtract_GELU +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + config = { + 'name': '80_Gemm_Max_Subtract_GELU', + 'batch_size': 4096, + 'in_features': 8192, + 'out_features': 8192, + 'max_dim': 1, + } + dtype = jnp.float32 + key = jax.random.key(0) + x = jax.random.uniform(key, (config['batch_size'], config['in_features']), dtype=dtype) + weight = jnp.zeros((config['in_features'], config['out_features']), dtype=dtype) + bias = jnp.zeros(config['out_features'], dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/reference.py new file mode 100644 index 0000000..6b2266d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/42k_Gemm_Max_Subtract_GELU/reference.py @@ -0,0 +1,29 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + config = { + 'name': '80_Gemm_Max_Subtract_GELU', + 'batch_size': 4096, + 'in_features': 8192, + 'out_features': 8192, + 'max_dim': 1, + } + dtype = jnp.float32 + key = jax.random.key(0) + x = jax.random.uniform(key, (config['batch_size'], config['in_features']), dtype=dtype) + weight = jnp.zeros((config['in_features'], config['out_features']), dtype=dtype) + bias = jnp.zeros(config['out_features'], dtype=dtype) + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.matmul(x, weight) + bias + x = jnp.max(x, axis=1, keepdims=True) + x = x - jnp.mean(x, axis=1, keepdims=True) + x = jax.nn.gelu(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/kernel_task.yaml new file mode 100644 index 0000000..f2155c2 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/kernel_task.yaml @@ -0,0 +1,25 @@ +task_id: 43k_Gemm_BatchNorm_Scaling_Softmax +description: Kernel task for 43k_Gemm_BatchNorm_Scaling_Softmax +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + bn_scale = jnp.ones(out_features, dtype=dtype) + bn_bias = jnp.zeros(out_features, dtype=dtype) + bn_mean = jnp.zeros(out_features, dtype=dtype) + bn_var = jnp.ones(out_features, dtype=dtype) + scale = jnp.ones((1,), dtype=dtype) + + dynamic_args = [x, weight, bias, bn_scale, bn_bias, bn_mean, bn_var, scale] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/reference.py new file mode 100644 index 0000000..1ccb4e1 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/43k_Gemm_BatchNorm_Scaling_Softmax/reference.py @@ -0,0 +1,34 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + bn_scale = jnp.ones(out_features, dtype=dtype) + bn_bias = jnp.zeros(out_features, dtype=dtype) + bn_mean = jnp.zeros(out_features, dtype=dtype) + bn_var = jnp.ones(out_features, dtype=dtype) + scale = jnp.ones((1,), dtype=dtype) + + dynamic_args = [x, weight, bias, bn_scale, bn_bias, bn_mean, bn_var, scale] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, bn_scale, bn_bias, bn_mean, bn_var, scale): + bn_eps = 1e-5 + x = jnp.matmul(x, weight) + bias + x_normalized = (x - bn_mean) / jnp.sqrt(bn_var + bn_eps) + x = bn_scale * x_normalized + bn_bias + x = scale * x + x = jax.nn.softmax(x, axis=1) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/kernel_task.yaml new file mode 100644 index 0000000..082dc92 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/kernel_task.yaml @@ -0,0 +1,21 @@ +task_id: 44k_Matmul_Divide_GELU +description: Kernel task for 44k_Matmul_Divide_GELU +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + input_size = 8192 + output_size = 8192 + divisor = 10.0 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + weight = jnp.zeros((input_size, output_size), dtype=dtype) + bias = jnp.zeros(output_size, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [divisor] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/reference.py new file mode 100644 index 0000000..345665d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/44k_Matmul_Divide_GELU/reference.py @@ -0,0 +1,27 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + input_size = 8192 + output_size = 8192 + divisor = 10.0 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, input_size), dtype=dtype) + weight = jnp.zeros((input_size, output_size), dtype=dtype) + bias = jnp.zeros(output_size, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [divisor] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, divisor): + x = x @ weight + bias + x = x / divisor + x = jax.nn.gelu(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/kernel_task.yaml new file mode 100644 index 0000000..aa58cc7 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/kernel_task.yaml @@ -0,0 +1,24 @@ +task_id: 45k_Gemm_GroupNorm_Swish_Multiply_Swish +description: Kernel task for 45k_Gemm_GroupNorm_Swish_Multiply_Swish +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + num_groups = 256 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + gemm_weight = jnp.zeros((out_features, in_features), dtype=dtype) + gemm_bias = jnp.zeros(out_features, dtype=dtype) + gn_weight = jnp.ones(out_features, dtype=dtype) + gn_bias = jnp.zeros(out_features, dtype=dtype) + multiply_weight = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, gemm_weight, gemm_bias, gn_weight, gn_bias, multiply_weight] + static_args = [num_groups, out_features] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/reference.py new file mode 100644 index 0000000..d7c4da3 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/45k_Gemm_GroupNorm_Swish_Multiply_Swish/reference.py @@ -0,0 +1,39 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + num_groups = 256 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + gemm_weight = jnp.zeros((out_features, in_features), dtype=dtype) + gemm_bias = jnp.zeros(out_features, dtype=dtype) + gn_weight = jnp.ones(out_features, dtype=dtype) + gn_bias = jnp.zeros(out_features, dtype=dtype) + multiply_weight = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, gemm_weight, gemm_bias, gn_weight, gn_bias, multiply_weight] + static_args = [num_groups, out_features] + + return dynamic_args, static_args + +# Computation +def computation(x, gemm_weight, gemm_bias, gn_weight, gn_bias, multiply_weight, num_groups, out_features): + x = jnp.matmul(x, gemm_weight.T) + gemm_bias + batch_size = x.shape[0] + group_size = out_features // num_groups + x_grouped = x.reshape(batch_size, num_groups, group_size) + mean = jnp.mean(x_grouped, axis=-1, keepdims=True) + var = jnp.var(x_grouped, axis=-1, keepdims=True) + x_normalized = (x_grouped - mean) / jnp.sqrt(var + 1e-5) + x = x_normalized.reshape(batch_size, out_features) + x = x * gn_weight + gn_bias + x = x * jax.nn.sigmoid(x) + x = x * multiply_weight + x = x * jax.nn.sigmoid(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/kernel_task.yaml new file mode 100644 index 0000000..8274d54 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/kernel_task.yaml @@ -0,0 +1,23 @@ +task_id: 46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp +description: Kernel task for 46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + import jax.scipy.special + + key = jax.random.key(0) + batch_size = 128 + in_channels = 8 + out_channels = 64 + kernel_size = 3 + height = 128 + width = 128 + x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=dtype) + conv_weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + gn_weight = jnp.ones(out_channels, dtype=dtype) + gn_bias = jnp.zeros(out_channels, dtype=dtype) + dynamic_args = [x, conv_weight, conv_bias, gn_weight, gn_bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/reference.py new file mode 100644 index 0000000..963865d --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/46k_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp/reference.py @@ -0,0 +1,48 @@ +# Imports +import jax +import jax.numpy as jnp +import jax.scipy.special + +# Initialization +def get_inputs(dtype=jnp.float32): + key = jax.random.key(0) + batch_size = 128 + in_channels = 8 + out_channels = 64 + kernel_size = 3 + height = 128 + width = 128 + x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=dtype) + conv_weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) + conv_bias = jnp.zeros(out_channels, dtype=dtype) + gn_weight = jnp.ones(out_channels, dtype=dtype) + gn_bias = jnp.zeros(out_channels, dtype=dtype) + dynamic_args = [x, conv_weight, conv_bias, gn_weight, gn_bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, conv_weight, conv_bias, gn_weight, gn_bias): + groups = 16 + eps = 1e-5 + x_nhwc = jnp.transpose(x, (0, 2, 3, 1)) + kernel = jnp.transpose(conv_weight, (2, 3, 1, 0)) + x_conv = jax.lax.conv_general_dilated( + x_nhwc, kernel, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=('NHWC', 'HWIO', 'NHWC')) + x_conv = x_conv + conv_bias.reshape(1, 1, 1, -1) + x_conv = jnp.transpose(x_conv, (0, 3, 1, 2)) + N, C, H, W = x_conv.shape + x = x_conv.reshape(N, groups, C // groups, H, W) + mean = jnp.mean(x, axis=(2, 3, 4), keepdims=True) + var = jnp.var(x, axis=(2, 3, 4), keepdims=True) + x = (x - mean) / jnp.sqrt(var + eps) + x = x.reshape(N, C, H, W) + x_norm = x * gn_weight.reshape(1, -1, 1, 1) + gn_bias.reshape(1, -1, 1, 1) + x_tanh = jnp.tanh(x_norm) + x_hard_swish = x_tanh * jnp.minimum(jnp.maximum(x_tanh + 3, 0), 6) / 6 + x_res = x_conv + x_hard_swish + x_logsumexp = jax.scipy.special.logsumexp(x_res, axis=1, keepdims=True) + return x_logsumexp \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/kernel_task.yaml new file mode 100644 index 0000000..7428ff6 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh +description: Kernel task for 47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + add_value = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias, add_value] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/reference.py new file mode 100644 index 0000000..118d3ee --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/reference.py @@ -0,0 +1,29 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + add_value = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias, add_value] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, add_value): + x = x @ weight + bias + x = x + add_value + x = jax.nn.swish(x) + x = jnp.tanh(x) + x = jax.nn.gelu(x) + x = jnp.clip(x, -1.0, 1.0) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/kernel_task.yaml new file mode 100644 index 0000000..c47f5f9 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/kernel_task.yaml @@ -0,0 +1,22 @@ +task_id: 48k_Matmul_BatchNorm_BiasAdd_Divide_Swish +description: Kernel task for 48k_Matmul_BatchNorm_BiasAdd_Divide_Swish +input_gen_code: |- + def get_inputs(dtype=jnp.float32): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + linear_bias = jnp.zeros(out_features, dtype=dtype) + bn_scale = jnp.ones(out_features, dtype=dtype) + bn_bias = jnp.zeros(out_features, dtype=dtype) + bn_mean = jnp.zeros(out_features, dtype=dtype) + bn_var = jnp.ones(out_features, dtype=dtype) + bias = jnp.zeros((1,), dtype=dtype) + dynamic_args = [x, weight, linear_bias, bn_scale, bn_bias, bn_mean, bn_var, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/reference.py new file mode 100644 index 0000000..84e0ad5 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/reference.py @@ -0,0 +1,33 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.float32): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + linear_bias = jnp.zeros(out_features, dtype=dtype) + bn_scale = jnp.ones(out_features, dtype=dtype) + bn_bias = jnp.zeros(out_features, dtype=dtype) + bn_mean = jnp.zeros(out_features, dtype=dtype) + bn_var = jnp.ones(out_features, dtype=dtype) + bias = jnp.zeros((1,), dtype=dtype) + dynamic_args = [x, weight, linear_bias, bn_scale, bn_bias, bn_mean, bn_var, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, linear_bias, bn_scale, bn_bias, bn_mean, bn_var, bias): + bn_eps = 1e-5 + divide_value = 1.0 + x = jnp.matmul(x, weight) + linear_bias + x_normalized = (x - bn_mean) / jnp.sqrt(bn_var + bn_eps) + x = bn_scale * x_normalized + bn_bias + x = x + bias + x = x / divide_value + x = x * jax.nn.sigmoid(x) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/kernel_task.yaml new file mode 100644 index 0000000..2a3b09f --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/kernel_task.yaml @@ -0,0 +1,22 @@ +task_id: 49k_Matmul_AvgPool_GELU_Scale_Max +description: Kernel task for 49k_Matmul_AvgPool_GELU_Scale_Max +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + pool_kernel_size = 16 + scale_factor = 2.0 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [pool_kernel_size, scale_factor] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/reference.py new file mode 100644 index 0000000..b84141f --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/49k_Matmul_AvgPool_GELU_Scale_Max/reference.py @@ -0,0 +1,39 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + pool_kernel_size = 16 + scale_factor = 2.0 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [pool_kernel_size, scale_factor] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias, pool_kernel_size, scale_factor): + x = jnp.matmul(x, weight) + bias + x = jnp.expand_dims(x, axis=1) + x = jax.lax.reduce_window( + x, + init_value=0.0, + computation=jax.lax.add, + window_dimensions=(1, 1, pool_kernel_size), + window_strides=(1, 1, pool_kernel_size), + padding='VALID' + ) / pool_kernel_size + x = jnp.squeeze(x, axis=1) + x = jax.nn.gelu(x) + x = x * scale_factor + x = jnp.max(x, axis=1) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/kernel_task.yaml new file mode 100644 index 0000000..3bb03ef --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/kernel_task.yaml @@ -0,0 +1,30 @@ +task_id: 4p_Sparse_Attention +description: Kernel task for 4p_Sparse_Attention +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'llama3_70b_sparse_attention', + 'model': 'Llama-3.1-70B', + 'operator': 'sparse_attention', + 'batch': 4, + 'seq_len': 4096, + 'num_query_heads': 64, + 'num_kv_heads': 8, + 'head_dim': 128, + } + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + S = CONFIG['seq_len'] + H_q = CONFIG['num_query_heads'] + H_kv = CONFIG['num_kv_heads'] + D = CONFIG['head_dim'] + q = jax.random.normal(k1, (H_q, S, D), dtype=dtype) * (D ** -0.5) + k = jax.random.normal(k2, (H_kv, S, D), dtype=dtype) * 0.02 + v = jax.random.normal(k3, (H_kv, S, D), dtype=dtype) * 0.02 + + dynamic_args = [q, k, v] + static_args = [S, H_q, H_kv, D] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/reference.py new file mode 100644 index 0000000..a45f288 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/4p_Sparse_Attention/reference.py @@ -0,0 +1,46 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'llama3_70b_sparse_attention', + 'model': 'Llama-3.1-70B', + 'operator': 'sparse_attention', + 'batch': 4, + 'seq_len': 4096, + 'num_query_heads': 64, + 'num_kv_heads': 8, + 'head_dim': 128, + } + key = jax.random.key(42) + k1, k2, k3 = jax.random.split(key, 3) + S = CONFIG['seq_len'] + H_q = CONFIG['num_query_heads'] + H_kv = CONFIG['num_kv_heads'] + D = CONFIG['head_dim'] + q = jax.random.normal(k1, (H_q, S, D), dtype=dtype) * (D ** -0.5) + k = jax.random.normal(k2, (H_kv, S, D), dtype=dtype) * 0.02 + v = jax.random.normal(k3, (H_kv, S, D), dtype=dtype) * 0.02 + + dynamic_args = [q, k, v] + static_args = [S, H_q, H_kv, D] + return dynamic_args, static_args + +# Computation +def computation(q, k, v, S, H_q, H_kv, D): + num_q_per_kv = H_q // H_kv + + k = jnp.repeat(k, num_q_per_kv, axis=0) + v = jnp.repeat(v, num_q_per_kv, axis=0) + + attn = jnp.einsum('hqd,hkd->hqk', q, k) + + causal = jnp.tril(jnp.ones((S, S), dtype=jnp.bool_)) + attn = jnp.where(causal[None, :, :], attn, -1e30) + + attn = jax.nn.softmax(attn, axis=-1) + + out = jnp.einsum('hqk,hkd->hqd', attn, v) + return out \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/kernel_task.yaml new file mode 100644 index 0000000..87ed5dd --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 50k_Matmul_GELU_Softmax +description: Kernel task for 50k_Matmul_GELU_Softmax +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + + batch_size = 4096 + in_features = 8192 + out_features = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/reference.py new file mode 100644 index 0000000..e424247 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/50k_Matmul_GELU_Softmax/reference.py @@ -0,0 +1,26 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + batch_size = 4096 + in_features = 8192 + out_features = 8192 + dtype = jnp.float32 + + key = jax.random.key(0) + x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) + weight = jnp.zeros((in_features, out_features), dtype=dtype) + bias = jnp.zeros(out_features, dtype=dtype) + + dynamic_args = [x, weight, bias] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, weight, bias): + x = jnp.matmul(x, weight) + bias + x = jax.nn.gelu(x) + x = jax.nn.softmax(x, axis=1) + return x \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/kernel_task.yaml new file mode 100644 index 0000000..a5cfcbd --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/kernel_task.yaml @@ -0,0 +1,30 @@ +task_id: 5p_Flex_Attention +description: Kernel task for 5p_Flex_Attention +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'llama3_70b_flex_attention', + 'model': 'Llama-3.1-70B', + 'operator': 'flex_attention', + 'batch': 4, + 'seq_len': 4096, + 'num_heads': 64, + 'head_dim': 128, + } + key = jax.random.key(42) + k1, k2, k3, k4 = jax.random.split(key, 4) + B = CONFIG['batch'] + S = CONFIG['seq_len'] + H = CONFIG['num_heads'] + D = CONFIG['head_dim'] + q = jax.random.normal(k1, (B, H, S, D), dtype=dtype) + k = jax.random.normal(k2, (B, H, S, D), dtype=dtype) * 0.02 + v = jax.random.normal(k3, (B, H, S, D), dtype=dtype) * 0.02 + rel_pos_bias = jax.random.normal(k4, (H, S, S), dtype=dtype) * 0.01 + + dynamic_args = [q, k, v, rel_pos_bias] + static_args = [D, S] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/reference.py new file mode 100644 index 0000000..b40ddd3 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/5p_Flex_Attention/reference.py @@ -0,0 +1,43 @@ +# Imports +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + 'name': 'llama3_70b_flex_attention', + 'model': 'Llama-3.1-70B', + 'operator': 'flex_attention', + 'batch': 4, + 'seq_len': 4096, + 'num_heads': 64, + 'head_dim': 128, + } + key = jax.random.key(42) + k1, k2, k3, k4 = jax.random.split(key, 4) + B = CONFIG['batch'] + S = CONFIG['seq_len'] + H = CONFIG['num_heads'] + D = CONFIG['head_dim'] + q = jax.random.normal(k1, (B, H, S, D), dtype=dtype) + k = jax.random.normal(k2, (B, H, S, D), dtype=dtype) * 0.02 + v = jax.random.normal(k3, (B, H, S, D), dtype=dtype) * 0.02 + rel_pos_bias = jax.random.normal(k4, (H, S, S), dtype=dtype) * 0.01 + + dynamic_args = [q, k, v, rel_pos_bias] + static_args = [D, S] + return dynamic_args, static_args + +# Computation +def computation(q, k, v, rel_pos_bias, head_dim, seq_len): + D = head_dim + S = seq_len + sm_scale = D ** -0.5 + + attn = jnp.einsum('bhqd,bhkd->bhqk', q, k) * sm_scale + attn = attn + rel_pos_bias[None, :, :, :] + causal = jnp.tril(jnp.ones((S, S), dtype=jnp.bool_)) + attn = jnp.where(causal[None, None, :, :], attn, -1e30) + attn = jax.nn.softmax(attn, axis=-1) + out = jnp.einsum('bhqk,bhkd->bhqd', attn, v) + return out \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/kernel_task.yaml new file mode 100644 index 0000000..d287114 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/kernel_task.yaml @@ -0,0 +1,44 @@ +task_id: 6p_Paged_Attention +description: Kernel task for 6p_Paged_Attention +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + + CONFIG = { + 'name': 'llama3_70b_paged_attention', + 'model': 'Llama-3.1-70B', + 'operator': 'paged_attention', + 'num_seqs': 64, + 'max_seq_len': 4096, + 'num_query_heads': 64, + 'num_kv_heads': 8, + 'head_dim': 128, + 'page_size': 16, + 'pages_per_seq': 256, + } + + key = jax.random.key(42) + keys = jax.random.split(key, 5) + num_seqs = CONFIG['num_seqs'] + num_q_heads = CONFIG['num_query_heads'] + num_kv_heads = CONFIG['num_kv_heads'] + head_dim = CONFIG['head_dim'] + page_size = CONFIG['page_size'] + pages_per_seq = CONFIG['pages_per_seq'] + total_pages = num_seqs * pages_per_seq + max_seq_len_derived = pages_per_seq * page_size + + max_num_tokens = num_seqs + queries = jax.random.normal(keys[0], (max_num_tokens, num_q_heads, head_dim), dtype=dtype) + k_pages = jax.random.normal(keys[1], (total_pages, page_size, num_kv_heads, head_dim), dtype=dtype) * 0.02 + v_pages = jax.random.normal(keys[2], (total_pages, page_size, num_kv_heads, head_dim), dtype=dtype) * 0.02 + + kv_lens = jnp.full((num_seqs,), max_seq_len_derived, dtype=jnp.int32) + page_indices = jnp.arange(total_pages, dtype=jnp.int32).reshape(num_seqs, pages_per_seq) + cu_q_lens = jnp.arange(num_seqs + 1, dtype=jnp.int32) + + dynamic_args = [queries, k_pages, v_pages, kv_lens, page_indices, cu_q_lens] + static_args = [num_seqs, num_q_heads, num_kv_heads, head_dim, max_seq_len_derived] + + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/reference.py new file mode 100644 index 0000000..455e363 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/6p_Paged_Attention/reference.py @@ -0,0 +1,104 @@ +# Imports +import jax +import jax.numpy as jnp + + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + CONFIG = { + "name": "llama3_70b_paged_attention", + "model": "Llama-3.1-70B", + "operator": "paged_attention", + "num_seqs": 64, + "max_seq_len": 4096, + "num_query_heads": 64, + "num_kv_heads": 8, + "head_dim": 128, + "page_size": 16, + "pages_per_seq": 256, + } + + key = jax.random.key(42) + keys = jax.random.split(key, 5) + num_seqs = CONFIG["num_seqs"] + num_q_heads = CONFIG["num_query_heads"] + num_kv_heads = CONFIG["num_kv_heads"] + head_dim = CONFIG["head_dim"] + page_size = CONFIG["page_size"] + pages_per_seq = CONFIG["pages_per_seq"] + total_pages = num_seqs * pages_per_seq + max_seq_len_derived = pages_per_seq * page_size + + max_num_tokens = num_seqs + queries = jax.random.normal( + keys[0], (max_num_tokens, num_q_heads, head_dim), dtype=dtype + ) + k_pages = ( + jax.random.normal( + keys[1], (total_pages, page_size, num_kv_heads, head_dim), dtype=dtype + ) + * 0.02 + ) + v_pages = ( + jax.random.normal( + keys[2], (total_pages, page_size, num_kv_heads, head_dim), dtype=dtype + ) + * 0.02 + ) + + kv_lens = jnp.full((num_seqs,), max_seq_len_derived, dtype=jnp.int32) + page_indices = jnp.arange(total_pages, dtype=jnp.int32).reshape( + num_seqs, pages_per_seq + ) + cu_q_lens = jnp.arange(num_seqs + 1, dtype=jnp.int32) + + dynamic_args = [queries, k_pages, v_pages, kv_lens, page_indices, cu_q_lens] + static_args = [ + num_seqs, + num_q_heads, + num_kv_heads, + head_dim, + max_seq_len_derived, + ] + + return dynamic_args, static_args + + +# Computation +def computation( + queries, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + num_q_heads, + num_kv_heads, + head_dim, + max_seq_len, +): + num_q_per_kv = num_q_heads // num_kv_heads + sm_scale = head_dim**-0.5 + + def attend_one_seq(seq_idx): + q_start = cu_q_lens[seq_idx] + q_end = cu_q_lens[seq_idx + 1] + q = jax.lax.dynamic_slice( + queries, (q_start, 0, 0), (1, num_q_heads, head_dim) + ) + seq_pages = page_indices[seq_idx] + k = k_pages[seq_pages].reshape(max_seq_len, num_kv_heads, head_dim) + v = v_pages[seq_pages].reshape(max_seq_len, num_kv_heads, head_dim) + k = jnp.repeat(k, num_q_per_kv, axis=1) + v = jnp.repeat(v, num_q_per_kv, axis=1) + attn = jnp.einsum("qhd,khd->hqk", q, k) * sm_scale + kv_len = kv_lens[seq_idx] + mask = jnp.arange(max_seq_len) < kv_len + attn = jnp.where(mask[None, None, :], attn, -1e30) + attn = jax.nn.softmax(attn, axis=-1) + out = jnp.einsum("hqk,khd->qhd", attn, v) + return out.squeeze(0) + + outputs = jax.vmap(attend_one_seq)(jnp.arange(num_seqs)) + return outputs diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/kernel_task.yaml new file mode 100644 index 0000000..c073fe5 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/kernel_task.yaml @@ -0,0 +1,45 @@ +task_id: 7p_Ragged_Paged_Attention +description: Kernel task for 7p_Ragged_Paged_Attention +input_gen_code: |- + def get_inputs(): + CONFIG = { + 'name': 'ragged_paged_attention_llama70b', + 'model': 'Llama-3.1-70B', + 'operator': 'ragged_paged_attention', + 'max_num_batched_tokens': 4096, + 'max_num_seqs': 64, + 'num_q_heads': 64, + 'num_kv_heads': 8, + 'head_dim': 128, + 'page_size': 16, + 'pages_per_seq': 256, + } + + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + max_num_batched_tokens = CONFIG['max_num_batched_tokens'] + max_num_seqs = CONFIG['max_num_seqs'] + H_q = CONFIG['num_q_heads'] + H_kv = CONFIG['num_kv_heads'] + head_dim = CONFIG['head_dim'] + page_size = CONFIG['page_size'] + pages_per_seq = CONFIG['pages_per_seq'] + num_combined_kv_heads = 2 * H_kv + total_num_pages = max_num_seqs * pages_per_seq + q = jax.random.normal(k1, (max_num_batched_tokens, H_q, head_dim), dtype=dtype) + kv_pages = jax.random.normal( + k2, (total_num_pages, page_size, num_combined_kv_heads, head_dim), dtype=dtype + ) + tokens_per_seq = max_num_batched_tokens // max_num_seqs + kv_len_per_seq = pages_per_seq * page_size + kv_lens = jnp.full((max_num_seqs,), kv_len_per_seq, dtype=jnp.int32) + page_indices = jnp.arange(total_num_pages, dtype=jnp.int32).reshape( + max_num_seqs, pages_per_seq + ) + cu_q_lens = jnp.arange(max_num_seqs + 1, dtype=jnp.int32) * tokens_per_seq + num_seqs = jnp.array([max_num_seqs], dtype=jnp.int32) + + dynamic_args = [q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs] + static_args = [head_dim, max_num_seqs, max_num_batched_tokens] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/reference.py new file mode 100644 index 0000000..f8e96e4 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/7p_Ragged_Paged_Attention/reference.py @@ -0,0 +1,106 @@ +# Imports +import math + +import jax +import jax.numpy as jnp + +# Initialization +def get_inputs(): + CONFIG = { + 'name': 'ragged_paged_attention_llama70b', + 'model': 'Llama-3.1-70B', + 'operator': 'ragged_paged_attention', + 'max_num_batched_tokens': 4096, + 'max_num_seqs': 64, + 'num_q_heads': 64, + 'num_kv_heads': 8, + 'head_dim': 128, + 'page_size': 16, + 'pages_per_seq': 256, + } + + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + max_tokens = CONFIG['max_num_batched_tokens'] + max_seqs = CONFIG['max_num_seqs'] + H_q = CONFIG['num_q_heads'] + H_kv = CONFIG['num_kv_heads'] + D = CONFIG['head_dim'] + page_size = CONFIG['page_size'] + pages_per_seq = CONFIG['pages_per_seq'] + head_dim = CONFIG['head_dim'] + max_num_seqs = CONFIG['max_num_seqs'] + max_num_batched_tokens = CONFIG['max_num_batched_tokens'] + num_combined_kv_heads = 2 * H_kv + total_num_pages = max_seqs * pages_per_seq + q = jax.random.normal(k1, (max_tokens, H_q, D), dtype=dtype) + kv_pages = jax.random.normal( + k2, (total_num_pages, page_size, num_combined_kv_heads, D), dtype=dtype + ) + tokens_per_seq = max_tokens // max_seqs + kv_len_per_seq = pages_per_seq * page_size + kv_lens = jnp.full((max_seqs,), kv_len_per_seq, dtype=jnp.int32) + page_indices = jnp.arange(total_num_pages, dtype=jnp.int32).reshape( + max_seqs, pages_per_seq + ) + cu_q_lens = jnp.arange(max_seqs + 1, dtype=jnp.int32) * tokens_per_seq + num_seqs = jnp.array([max_seqs], dtype=jnp.int32) + + dynamic_args = [q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs] + static_args = [head_dim, max_num_seqs, max_num_batched_tokens] + return dynamic_args, static_args + +# Computation +def computation(queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, + head_dim, max_num_seqs, max_num_batched_tokens): + DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + sm_scale = 1.0 / math.sqrt(head_dim) + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + num_kv_heads = num_combined_kv_heads // 2 + num_q_heads = queries.shape[1] + num_query_per_kv = num_q_heads // num_kv_heads + + max_seqs = max_num_seqs + tokens_per_seq = max_num_batched_tokens // max_seqs + + outputs = [] + for i in range(max_seqs): + q_start = cu_q_lens[i] + kv_len = kv_lens[i] + indices = page_indices[i] + + q = jax.lax.dynamic_slice( + queries, (q_start, 0, 0), (tokens_per_seq, num_q_heads, head_dim) + ) + + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim) + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim) + + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + + attn = jnp.einsum( + "qhd,khd->hqk", q, k, preferred_element_type=jnp.float32 + ) + attn *= sm_scale + + q_span = (kv_len - tokens_per_seq) + jax.lax.broadcasted_iota( + jnp.int32, attn.shape, 1 + ) + kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) + + mask = (q_span < kv_span) | (kv_span >= kv_len) + attn = jnp.where(mask, mask_value, attn) + + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) + + is_valid = i < num_seqs[0] + out = jnp.where(is_valid, out, 0.0) + + outputs.append(out) + + return jnp.concatenate(outputs, axis=0) \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/kernel_task.yaml new file mode 100644 index 0000000..5dcb8ed --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/kernel_task.yaml @@ -0,0 +1,20 @@ +task_id: 8p_GEMM +description: Kernel task for 8p_GEMM +input_gen_code: |- + def get_inputs(dtype=jnp.bfloat16): + import jax + import jax.numpy as jnp + import time + import numpy as np + import json + + M = 8192 + K = 8192 + N = 28672 + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + A = jax.random.normal(k1, (M, K), dtype=dtype) + B = jax.random.normal(k2, (K, N), dtype=dtype) * 0.02 + dynamic_args = [A, B] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/reference.py new file mode 100644 index 0000000..d3cd90f --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/8p_GEMM/reference.py @@ -0,0 +1,23 @@ +# Imports +import jax +import jax.numpy as jnp +import time +import numpy as np +import json + +# Initialization +def get_inputs(dtype=jnp.bfloat16): + M = 8192 + K = 8192 + N = 28672 + key = jax.random.key(42) + k1, k2 = jax.random.split(key, 2) + A = jax.random.normal(k1, (M, K), dtype=dtype) + B = jax.random.normal(k2, (K, N), dtype=dtype) * 0.02 + dynamic_args = [A, B] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(A, B): + return jnp.dot(A, B) \ No newline at end of file diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/kernel_task.yaml b/MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/kernel_task.yaml new file mode 100644 index 0000000..0484292 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/kernel_task.yaml @@ -0,0 +1,29 @@ +task_id: 9p_SwiGLU_MLP +description: Kernel task for 9p_SwiGLU_MLP +input_gen_code: |- + def get_inputs(): + import jax + import jax.numpy as jnp + from functools import partial + + CONFIG = { + 'name': 'llama3_70b_swiglu', + 'model': 'Llama-3.1-70B', + 'operator': 'swiglu_mlp', + 'batch': 2, + 'seq_len': 4096, + 'emb_dim': 8192, + 'mlp_dim': 28672, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2, k3, k4 = jax.random.split(key, 4) + B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim'] + x = jax.random.normal(k1, (B, S, E), dtype=dtype) + gate = jax.random.normal(k2, (E, M), dtype=dtype) * 0.02 + up = jax.random.normal(k3, (E, M), dtype=dtype) * 0.02 + down = jax.random.normal(k4, (M, E), dtype=dtype) * 0.02 + + dynamic_args = [x, gate, up, down] + static_args = [] + return dynamic_args, static_args diff --git a/MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/reference.py b/MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/reference.py new file mode 100644 index 0000000..7253c66 --- /dev/null +++ b/MaxKernel/evaluation/jaxbench_adapted_dataset/9p_SwiGLU_MLP/reference.py @@ -0,0 +1,34 @@ +# Imports +import jax +import jax.numpy as jnp +from functools import partial + +# Initialization +def get_inputs(): + CONFIG = { + 'name': 'llama3_70b_swiglu', + 'model': 'Llama-3.1-70B', + 'operator': 'swiglu_mlp', + 'batch': 2, + 'seq_len': 4096, + 'emb_dim': 8192, + 'mlp_dim': 28672, + } + dtype = jnp.bfloat16 + key = jax.random.key(42) + k1, k2, k3, k4 = jax.random.split(key, 4) + B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim'] + x = jax.random.normal(k1, (B, S, E), dtype=dtype) + gate = jax.random.normal(k2, (E, M), dtype=dtype) * 0.02 + up = jax.random.normal(k3, (E, M), dtype=dtype) * 0.02 + down = jax.random.normal(k4, (M, E), dtype=dtype) * 0.02 + + dynamic_args = [x, gate, up, down] + static_args = [] + return dynamic_args, static_args + +# Computation +def computation(x, gate_kernel, up_kernel, down_kernel): + gate = jax.nn.silu(jnp.dot(x, gate_kernel)) + up = jnp.dot(x, up_kernel) + return jnp.dot(gate * up, down_kernel) \ No newline at end of file