Summary
The experimental Apple Metal backend (backends/apple/metal, AOTI-based) can't lower common CNNs (MobileNetV3, YOLO). Export fails with "missing fallback kernels" for ops the backend has no c-shim for. Sharing findings + suggestions — not a complete analysis, just flagging that standard CNNs can't currently target Metal.
Environment
- ExecuTorch 1.3.1 (pip), torch 2.12.0, torchao 0.17.0 built from git with TORCHAO_BUILD_EXPERIMENTAL_MPS=1, macOS arm64 (Apple M2), Python 3.11.
Minimal repro
import torch, torch.nn as nn
from torch.export import export
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
from executorch.backends.apple.metal.metal_backend import MetalBackend

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        a, b = torch.split(x, 4, dim=1)  # -> aten.split_copy.Tensor
        return self.fc(a + b)

ep = export(M().eval(), (torch.randn(1, 8),))
to_edge_transform_and_lower(
    ep,
    partitioner=[MetalPartitioner([MetalBackend.generate_method_name_compile_spec("forward")])],
).to_executorch()

W torch/_inductor/ir.py: aten.split_copy.Tensor is missing a c-shim implementation, using proxy executor as fallback
RuntimeError: Method forward missing fallback kernels (1 total):
- aten::split_copy.Tensor
Please add them to the AOTI backend.
Real models (exported via MetalPartitioner):
- MobileNetV3-small → missing aoti_torch_mps_addmm_out (the classifier Linear).
- YOLO11n → missing aten::split_copy.Tensor.
Root cause
MetalBackend.get_supported_fallback_kernels() only lists:
aoti_torch_mps_bmm_out, aoti_torch_mps_convolution, aoti_torch_mps_mm_out, _scaled_dot_product_attention_math_for_mps::call, torchao::_linear_fp_act_4bit_weight.
When AOTInductor emits a fallback for any other op (split_copy, slice_copy, addmm, …) there's no c-shim, and get_decomposition_table() is empty → export fails.
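Read this way, the failure is a set difference between the fallbacks AOTInductor emits and the kernels the backend allows. Below is a minimal sketch in plain Python under that assumption. `SUPPORTED` copies the list above; `missing_fallbacks` and the `emitted` set are hypothetical names for illustration, not ExecuTorch APIs, and the real check inside the backend may be structured differently.

```python
# Kernel names copied from the get_supported_fallback_kernels() list quoted above.
SUPPORTED = {
    "aoti_torch_mps_bmm_out",
    "aoti_torch_mps_convolution",
    "aoti_torch_mps_mm_out",
    "_scaled_dot_product_attention_math_for_mps::call",
    "torchao::_linear_fp_act_4bit_weight",
}


def missing_fallbacks(emitted, supported=SUPPORTED):
    """Hypothetical helper: emitted fallback kernels that have no supported c-shim."""
    return sorted(set(emitted) - set(supported))


# Fallbacks reported for the two real models in this issue, plus one supported kernel.
emitted = {
    "aoti_torch_mps_mm_out",
    "aoti_torch_mps_addmm_out",
    "aten::split_copy.Tensor",
}
missing = missing_fallbacks(emitted)
print(missing)
```

Any non-empty result corresponds to the "missing fallback kernels" error shown in the repro.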
What I tried (why decomposition alone isn't enough)
- Adding aten.addmm.default to get_decomposition_table() looks viable: torch has a stock decomp (addmm → mm + add) and mm_out is supported.
- split_copy has no stock decomposition. A custom split_copy → slice_copy decomposition just moves the failure to aten::slice_copy.Tensor, which is also missing a c-shim. The supported set has no view/copy primitive to bottom out on.
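The two rewrites discussed above can be illustrated without torch. This is a sketch on nested Python lists, not the actual torch decompositions: `mm`, `addmm_reference`, `addmm_decomposed`, and `split_to_slices` are hypothetical helpers, and the bias is assumed to be already broadcast to the full output shape.

```python
def mm(a, b):
    """Plain matrix multiply on nested lists (stand-in for aten.mm)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def addmm_reference(bias, a, b, beta=1, alpha=1):
    """Fused form: beta * bias + alpha * (a @ b), computed element by element."""
    return [
        [
            beta * bias[i][j] + alpha * sum(a[i][k] * b[k][j] for k in range(len(b)))
            for j in range(len(b[0]))
        ]
        for i in range(len(a))
    ]


def addmm_decomposed(bias, a, b, beta=1, alpha=1):
    """Decomposed form: one mm followed by a scaled elementwise add."""
    prod = mm(a, b)
    return [
        [beta * s + alpha * p for s, p in zip(bias_row, prod_row)]
        for bias_row, prod_row in zip(bias, prod)
    ]


def split_to_slices(dim_size, split_size):
    """(start, end) bounds that turn split(x, split_size) into one slice per chunk."""
    return [(s, min(s + split_size, dim_size)) for s in range(0, dim_size, split_size)]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
bias = [[1, 1], [1, 1]]
same = addmm_reference(bias, a, b) == addmm_decomposed(bias, a, b)
bounds = split_to_slices(8, 4)  # the repro's torch.split(x, 4, dim=1) on 8 columns
print(same, bounds)
```

The addmm rewrite bottoms out on mm, which the backend supports. The split rewrite only produces slice bounds, so it still needs a slice kernel the backend can run.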
Suggestions
- Add AOTI MPS c-shims for the common view/copy ops AOTInductor emits as fallbacks (slice_copy, split_copy, addmm_out, …); or
- Provide decompositions to already-supported ops (works for addmm → mm; the slice/split family needs a supported target); or
- Have AOTInductor codegen these for the MPS device rather than emitting proxy-executor fallbacks.
Happy to help test fixes.