Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions demos/image-classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ <h1 class="title">
<input type="radio" name="backend" id="webnn_npu" />WebNN NPU
</label>
</div>
<!-- FP16/FP32 data-type picker. Starts hidden ("hide"); index.js toggles the
     "hide" class on #dtypeBtns based on the selected backend. -->
<div class="btn-group-toggle dtypes hide" data-toggle="buttons" id="dtypeBtns">
<label class="btn" id="label_dtype_fp16">
<input type="radio" name="dtype" id="dtype_fp16" />FP16
</label>
<label class="btn" id="label_dtype_fp32">
<input type="radio" name="dtype" id="dtype_fp32" />FP32
</label>
</div>
<div class="btn-group-toggle models" data-toggle="buttons" id="modelBtns">
<label class="btn" id="label_mobilenet-v2">
<input type="radio" name="model" id="mobilenet-v2" />MobileNet V2
Expand Down
229 changes: 213 additions & 16 deletions demos/image-classification/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,91 @@ if (useRemoteModels) {
log("[Transformer.js] env.allowRemoteModels: " + transformers.env.allowRemoteModels);
log("[Transformer.js] env.allowLocalModels: " + transformers.env.allowLocalModels);

/** FP16 model repos keyed by UI model id. FP16 is the default dtype. */
const FP16_MODEL_PATHS = {
"mobilenet-v2": "webnn/mobilenet-v2",
"resnet-50": "xenova/resnet-50",
"efficientnet-lite4": "webnn/efficientnet-lite4",
};

/**
 * FP32 model repos keyed by UI model id. Local AMD repos use the xenova layout
 * (`config.json`, preprocessor, `onnx/`); on the Hub the same repos may keep
 * these under a `webnn/` folder, so remote loads rewrite JSON URLs and pass
 * `subfolder: "webnn/onnx"` for the weights.
 * No "efficientnet-lite4" entry: callers fall back to resnet-50.
 */
const FP32_MODEL_PATHS = {
"mobilenet-v2": "amd/MobileNetV2",
"resnet-50": "amd/resnet50",
};

/** Hub repos whose `config.json` / preprocessor live under `webnn/` and ONNX under `webnn/onnx/`. */
const AMD_WEBNN_HUB_LAYOUT_MODEL_IDS = ["amd/resnet50", "amd/MobileNetV2"];

/** True when `modelPath` is one of the AMD webnn-layout Hub repos above. */
function isAmdWebnnHubLayoutModel(modelPath) {
  return AMD_WEBNN_HUB_LAYOUT_MODEL_IDS.includes(modelPath);
}

/**
 * True when `urlString` is a remote Hub artifact URL (`https?://…/resolve/…`)
 * for one of the AMD webnn-layout repos. Local file URLs (no `/resolve/`),
 * non-strings, and other repos return false.
 */
function isRemoteHubArtifactUrl(urlString) {
  if (typeof urlString !== "string" || !/^https?:\/\//i.test(urlString) || !urlString.includes("/resolve/")) {
    return false;
  }
  // Derive the repo check from AMD_WEBNN_HUB_LAYOUT_MODEL_IDS instead of
  // hard-coding "/amd/resnet50/" / "/amd/MobileNetV2/", so adding a repo to
  // the list is a one-place change.
  return AMD_WEBNN_HUB_LAYOUT_MODEL_IDS.some(id => urlString.includes(`/${id}/`));
}

/**
 * Remote Hub only: map `.../resolve/<rev>/config.json` (or
 * `preprocessor_config.json`) to the `webnn/`-prefixed path used by the AMD
 * repos. URLs that are not remote Hub artifacts — including local file URLs
 * without `/resolve/` — and URLs already under `webnn/` are returned unchanged.
 */
function rewriteAmdWebnnHubJsonAssetUrl(urlString) {
  if (!isRemoteHubArtifactUrl(urlString)) {
    return urlString;
  }
  const alreadyRewritten =
    urlString.includes("/webnn/config.json") || urlString.includes("/webnn/preprocessor_config.json");
  if (alreadyRewritten) {
    return urlString;
  }
  for (const id of AMD_WEBNN_HUB_LAYOUT_MODEL_IDS) {
    const repo = id.replace(/\//g, "\\/");
    const pattern = new RegExp(
      `^(https?://[^/]+/${repo}/resolve/[^/]+/)(config\\.json|preprocessor_config\\.json)(\\?.*)?$`,
      "i",
    );
    const parts = pattern.exec(urlString);
    if (parts !== null) {
      const [, prefix, fileName, query] = parts;
      return `${prefix}webnn/${fileName}${query ?? ""}`;
    }
  }
  return urlString;
}

// One-shot guard so the fetch wrapper below is installed at most once.
let amdWebnnHubJsonFetchInstalled = false;

/**
 * Wrap `env.fetch` so URLs are passed through rewriteAmdWebnnHubJsonAssetUrl
 * before the real fetch runs. Handles both string inputs and Request objects;
 * anything else is forwarded untouched. Calling this twice is a no-op.
 */
function ensureAmdWebnnHubJsonFetch(env) {
  if (amdWebnnHubJsonFetchInstalled) {
    return;
  }
  const baseFetch = env.fetch.bind(env);
  env.fetch = async (resource, init) => {
    if (typeof resource === "string") {
      return baseFetch(rewriteAmdWebnnHubJsonAssetUrl(resource), init);
    }
    if (typeof Request !== "undefined" && resource instanceof Request) {
      const rewritten = rewriteAmdWebnnHubJsonAssetUrl(resource.url);
      if (rewritten !== resource.url) {
        // Clone the original Request so method/headers/body survive the URL swap.
        return baseFetch(new Request(rewritten, resource), init);
      }
    }
    return baseFetch(resource, init);
  };
  amdWebnnHubJsonFetchInstalled = true;
}

/**
 * Map a UI model id and dtype to the model repo path to load.
 * Unknown ids fall back to the resnet-50 repo for the chosen dtype; since
 * FP32_MODEL_PATHS has no "efficientnet-lite4" entry, that id also falls back
 * to resnet-50 under fp32.
 */
const resolveModelPath = (id, dtype) => {
  if (dtype === "fp32") {
    // The `??` fallback already covers "efficientnet-lite4" (absent from the
    // fp32 map), so no separate special case is needed.
    return FP32_MODEL_PATHS[id] ?? FP32_MODEL_PATHS["resnet-50"];
  }
  return FP16_MODEL_PATHS[id] ?? FP16_MODEL_PATHS["resnet-50"];
};

let provider = "webnn";
let deviceType = "gpu";
let dataType = "fp16";
Expand All @@ -52,6 +137,7 @@ let runs = 1;
let range, rangeValue, runSpan;
let backendLabels, modelLabels;
let label_webgpu, label_webnn_gpu, label_webnn_npu, label_mobilenetV2, label_resnet50, label_efficientnetLite4;
let dtypeLabels, label_dtype_fp16, label_dtype_fp32;
let uploadImage, label_uploadImage;
let imageUrl, image;
let classify;
Expand All @@ -66,6 +152,55 @@ let dataTypeSpan;
let modelIdSpan;
let latency, latencyDiv, indicator;
let title, device, badge;
let dtypeBtnsRow;

/** True when the URL query selects the WebNN provider on the NPU device. */
const isWebnnNpuFromQuery = () => {
  const providerParam = getQueryValue("provider")?.toLowerCase();
  const deviceParam = getQueryValue("devicetype")?.toLowerCase();
  return providerParam === "webnn" && deviceParam === "npu";
};

/** Show the dtype button row only when the URL query selects WebNN + NPU. */
const syncDtypeRowVisibility = () => {
  if (!dtypeBtnsRow) {
    return;
  }
  // toggle(force): add "hide" unless WebNN NPU is selected.
  dtypeBtnsRow.classList.toggle("hide", !isWebnnNpuFromQuery());
};

/**
 * Hide the EfficientNet-Lite4 model option when WebNN NPU + FP32 is selected
 * (that combination is redirected to resnet-50 elsewhere in this file).
 */
const syncEfficientnetFp32Visibility = () => {
  if (!label_efficientnetLite4) {
    return;
  }
  const fp32OnNpu = isWebnnNpuFromQuery() && getQueryValue("dtype")?.toLowerCase() === "fp32";
  label_efficientnetLite4.classList.toggle("hide", fp32OnNpu);
};

/**
 * FP32 AMD ONNX exports name their image tensor `input`, while the
 * image-classification pipeline feeds `pixel_values`. For any repo listed in
 * FP32_MODEL_PATHS, wrap the model's `_call` so `pixel_values` is mirrored
 * into `input` whenever `input` is absent. Other repos are left untouched.
 */
const patchAmdClassifierPixelValuesInput = (classifier, modelPath) => {
  if (!Object.values(FP32_MODEL_PATHS).includes(modelPath)) {
    return;
  }
  const model = classifier?.model;
  if (!model) {
    return;
  }
  const forward = model._call.bind(model);
  model._call = async inputs => {
    const needsAlias = inputs?.pixel_values != null && inputs.input == null;
    return forward(needsAlias ? { ...inputs, input: inputs.pixel_values } : inputs);
  };
};

const main = async () => {
fullResult.setAttribute("class", "none");
Expand All @@ -77,22 +212,25 @@ const main = async () => {

if (getQueryValue("model")) {
modelId = getQueryValue("model");
switch (modelId) {
case "mobilenet-v2":
modelPath = "webnn/mobilenet-v2";
break;
case "resnet-50":
modelPath = "xenova/resnet-50";
break;
case "efficientnet-lite4":
modelPath = "webnn/efficientnet-lite4";
break;
default:
modelPath = "xenova/resnet-50";
break;
if (!["mobilenet-v2", "resnet-50", "efficientnet-lite4"].includes(modelId)) {
modelId = "resnet-50";
}
}

const dtypeParam = getQueryValue("dtype");
const urlDtype = dtypeParam?.toLowerCase() === "fp32" ? "fp32" : "fp16";
dataType = isWebnnNpuFromQuery() ? urlDtype : "fp16";

if (isWebnnNpuFromQuery() && dataType === "fp32" && modelId === "efficientnet-lite4") {
modelId = "resnet-50";
}

modelPath = resolveModelPath(modelId, dataType);

if (isAmdWebnnHubLayoutModel(modelPath)) {
ensureAmdWebnnHubJsonFetch(transformers.env);
}

await remapHuggingFaceDomainIfNeeded(transformers.env);

let device = "webnn-gpu";
Expand Down Expand Up @@ -120,6 +258,10 @@ const main = async () => {
options.session_options.freeDimensionOverrides = dimensionOverrides[modelId];
}

if (dataType === "fp32" && Object.values(FP32_MODEL_PATHS).includes(modelPath)) {
options.subfolder = useRemoteModels && isAmdWebnnHubLayoutModel(modelPath) ? "webnn/onnx" : "onnx";
}

modelIdSpan.innerHTML = dataType;
dataTypeSpan.innerHTML = modelPath;

Expand All @@ -129,6 +271,7 @@ const main = async () => {
WebNNPerf.configure({ model: modelId, device: deviceType, provider });

const classifier = await transformers.pipeline("image-classification", modelPath, options);
patchAmdClassifierPixelValuesInput(classifier, modelPath);

let [err, output] = await asyncErrorHandling(classifier(imageUrl, { topk: 3 }));

Expand Down Expand Up @@ -240,6 +383,17 @@ const checkWebNN = async () => {
}
};

/** Reset both dtype buttons, then mark the one matching `dataType` active. */
const initDtypeSelector = () => {
  dtypeLabels.forEach(label => label.setAttribute("class", "btn"));
  const activeLabel = dataType === "fp32" ? label_dtype_fp32 : label_dtype_fp16;
  activeLabel.setAttribute("class", "btn active");
};

const initModelSelector = () => {
provider = getQueryValue("provider").toLowerCase();
deviceType = getQueryValue("devicetype").toLowerCase();
Expand Down Expand Up @@ -288,6 +442,7 @@ const controls = async () => {

let backendBtns = $("#backendBtns");
let modelBtns = $("#modelBtns");
let dtypeBtns = $("#dtypeBtns");

const updateBackend = e => {
backendLabels.forEach(label => {
Expand Down Expand Up @@ -319,7 +474,7 @@ const controls = async () => {
currentUrl = window.location.href;
updatedUrl = updateQueryStringParameter(currentUrl, "devicetype", "npu");
window.history.pushState({}, "", updatedUrl);
provider = "webgpu";
provider = "webnn";
deviceType = "npu";
}

Expand Down Expand Up @@ -350,8 +505,30 @@ const controls = async () => {
updateUi();
};

/**
 * Change handler for the FP16/FP32 toggle: marks the clicked button active,
 * updates the `dataType` global and the `dtype` query parameter (pushState),
 * then refreshes the UI. Unknown ids only refresh the UI.
 */
const updateDtype = e => {
  dtypeLabels.forEach(label => label.setAttribute("class", "btn"));
  e.target.parentNode.setAttribute("class", "btn active");
  const selected = e.target.id.trim();
  const nextDtype = { dtype_fp16: "fp16", dtype_fp32: "fp32" }[selected];
  if (nextDtype) {
    dataType = nextDtype;
    const updatedUrl = updateQueryStringParameter(window.location.href, "dtype", nextDtype);
    window.history.pushState({}, "", updatedUrl);
  }
  updateUi();
};

backendBtns.addEventListener("change", updateBackend, false);
modelBtns.addEventListener("change", updateModel, false);
dtypeBtns.addEventListener("change", updateDtype, false);
};

const badgeUpdate = () => {
Expand Down Expand Up @@ -410,9 +587,25 @@ const updateUi = async () => {
modelId = getQueryValue("model");
}

if (getQueryValue("dtype")) {
const d = getQueryValue("dtype").toLowerCase();
dataType = d === "fp32" ? "fp32" : "fp16";
} else {
dataType = "fp16";
}

if (isWebnnNpuFromQuery() && dataType === "fp32" && modelId === "efficientnet-lite4") {
modelId = "resnet-50";
window.history.replaceState({}, "", updateQueryStringParameter(window.location.href, "model", "resnet-50"));
}

initModelSelector();
badgeUpdate();
log(`[Config] Demo config updated · ${modelId} · ${provider} · ${deviceType}`);

initDtypeSelector();
syncDtypeRowVisibility();
syncEfficientnetFp32Visibility();
log(`[Config] Demo config updated · ${modelId} · ${provider} · ${deviceType} · ${dataType}`);
await checkWebNN();
console.log(provider);
console.log(deviceType);
Expand All @@ -432,7 +625,7 @@ const changeImage = async () => {
const ui = async () => {
imageUrl = "./static/tiger.jpg";
if (!(getQueryValue("provider") && getQueryValue("model") && getQueryValue("devicetype") && getQueryValue("run"))) {
let url = "?provider=webnn&devicetype=gpu&model=resnet-50&run=5";
let url = "?provider=webnn&devicetype=gpu&model=resnet-50&run=5&dtype=fp16";
location.replace(url);
}

Expand All @@ -450,6 +643,10 @@ const ui = async () => {
label_mobilenetV2 = $("#label_mobilenet-v2");
label_resnet50 = $("#label_resnet-50");
label_efficientnetLite4 = $("#label_efficientnet-lite4");
dtypeLabels = $$(".dtypes label");
label_dtype_fp16 = $("#label_dtype_fp16");
label_dtype_fp32 = $("#label_dtype_fp32");
dtypeBtnsRow = $("#dtypeBtns");
image = $("#image");
uploadImage = $("#upload-image");
label_uploadImage = $("#label_upload-image");
Expand Down
Loading