-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
59 lines (45 loc) · 1.92 KB
/
run.py
File metadata and controls
59 lines (45 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import tarfile
import time

import mxnet as mx
import numpy as np
from mxnet.gluon.data.vision import transforms
from sagemaker.mxnet import MXNetModel
from sagemaker.session import Session
# Fetch the pretrained ResNet-50 v2 ONNX model, pack it into a gzipped tar
# archive, and upload that archive to S3 under the 'model' key prefix.
# Use the path returned by download() rather than repeating the file name,
# so the archive always contains exactly the file that was fetched.
onnx_file = mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx')
with tarfile.open('onnx_model.tar.gz', mode='w:gz') as archive:
    archive.add(onnx_file)
model_data = Session().upload_data(path='onnx_model.tar.gz', key_prefix='model')

# IAM role the model/endpoint runs under. Allow overriding it from the
# environment so the account-specific ARN does not have to be edited into the
# source; the original ARN remains the default.
role = os.environ.get('SAGEMAKER_ROLE_ARN', 'arn:aws:iam::841569659894:role/sagemaker-access-role')

# NOTE(review): the inference image below is pinned to us-west-2; running from
# a session in a different region presumably needs a matching image URI.
mxnet_model = MXNetModel(model_data=model_data,
                         entry_point='resnet50.py',
                         role=role,
                         image='763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-gpu-py36-cu100-ubuntu16.04',
                         py_version='py3',
                         framework_version='1.4.1')
# Creates a new single-instance endpoint on every run of this script.
predictor = mxnet_model.deploy(initial_instance_count=1, instance_type='ml.p3.8xlarge')

# Download and decode the sample image used for the timed prediction.
img_path = mx.test_utils.download('https://s3.amazonaws.com/onnx-mxnet/examples/mallard_duck.jpg')
img = mx.image.imread(img_path)
def preprocess(img):
    """Convert a decoded image into a normalized single-image batch.

    Applies, in order: ``Resize(256)``, ``CenterCrop(224)``, ``ToTensor()``
    and per-channel ``Normalize`` with the mean/std constants below, then
    prepends a batch axis of size 1.

    Args:
        img: image NDArray, e.g. as returned by ``mx.image.imread``.

    Returns:
        The transformed NDArray with a new leading axis (batch of one).
    """
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean / std; presumably the usual ImageNet statistics
        # the pretrained ResNet expects — confirm against the model card.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Add the batch dimension the network input requires.
    return pipeline(img).expand_dims(axis=0)
# Fetch the class-label list that do_pred() reads from 'synset.txt'.
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
# Preprocess the sample image once; do_pred() sends this batch of one.
input_image = preprocess(img)
def do_pred(top_k=5, synset_path='synset.txt'):
    """Send the preprocessed image to the endpoint and print the top classes.

    Uses the module-level ``predictor`` and ``input_image``. Only the remote
    ``predictor.predict`` call is timed; loading labels and ranking scores
    happen after the timer stops.

    Args:
        top_k: how many of the highest-scoring classes to print (default 5).
        synset_path: text file with one class label per line; line order is
            assumed to match the model's output indices.

    Returns:
        Elapsed seconds spent in ``predictor.predict``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    scores = predictor.predict(input_image.asnumpy())
    elapsed = time.perf_counter() - start_time

    with open(synset_path, encoding='utf-8') as f:
        labels = [line.rstrip() for line in f]

    # The response may be a flat score vector or batch-shaped (e.g. [[...]]
    # for our single image). Flatten it so argsort ranks classes, not rows.
    scores = np.asarray(scores).ravel()
    # NOTE(review): "probability" assumes the entry point returns softmax
    # outputs rather than raw logits — confirm in resnet50.py.
    for i in np.argsort(scores)[::-1][:top_k]:
        print(f'class={labels[i]} ; probability={scores[i]:f}')
    return elapsed
# Time one prediction, then always tear the endpoint down: this script deploys
# a fresh ml.p3.8xlarge endpoint on every run, so without cleanup each run
# would leave another instance running.
try:
    costtime = do_pred()
    print(f"this run cost {costtime}s")
finally:
    predictor.delete_endpoint()