-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresnet50.py
More file actions
37 lines (33 loc) · 1.42 KB
/
resnet50.py
File metadata and controls
37 lines (33 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from __future__ import print_function
import json
import logging
import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet
import numpy as np
logging.basicConfig(level=logging.DEBUG)
def model_fn(model_dir):
"""
Load the onnx model. Called once when hosting service starts.
:param: model_dir The directory where model files are stored.
:return: a Module object
"""
sym, arg_params, aux_params = onnx_mxnet.import_model('%s/resnet50v2.onnx' % model_dir)
# create module
mod = mx.mod.Module(symbol=sym, data_names=['data'], label_names=None)
mod.bind(for_training=False, data_shapes=[('data', [1, 3, 224, 224])], label_shapes=mod._label_shapes)
mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True)
return mod
def transform_fn(mod, data, input_content_type, output_content_type):
"""
Transform a request using the loaded model. Called once per inference request.
:param mod: The resnet152 model.
:param data: The request payload.
:param input_content_type: The request content type.
:param output_content_type: The (desired) response content type.
:return: response payload and content type.
"""
input_data = json.loads(data)
mod.predict(mx.nd.array(input_data))
scores = mx.ndarray.softmax(mod.get_outputs()[0]).asnumpy()
scores = np.squeeze(scores)
return json.dumps(scores.tolist()), output_content_type