-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
111 lines (81 loc) · 3.85 KB
/
main.py
File metadata and controls
111 lines (81 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
import warnings
import torch
from common.utils import parse_arguments, plot_benchmark_results
from src.image_processor import ImageProcessor
from src.model import ModelLoader
from src.onnx_inference import ONNXInference
from src.ov_inference import OVInference
from src.pytorch_inference import PyTorchInference
from src.tensorrt_inference import TensorRTInference
# Silence torchvision's image-I/O UserWarnings so they don't clutter output.
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io.image")

# Send inference logs to a file instead of stdout.
logging.basicConfig(filename="inference.log", level=logging.INFO)

# True only when a CUDA device is present AND torch-tensorrt imports cleanly;
# the import is what the TensorRT backend actually requires, so the flag is
# probed once here at module load.
CUDA_AVAILABLE = False
if torch.cuda.is_available():
    try:
        import torch_tensorrt  # noqa: F401

        CUDA_AVAILABLE = True
    except ImportError:
        print("torch-tensorrt not installed. Running in CPU mode only.")
def _run_onnx_inference(args, model_loader, img_batch) -> dict[str, tuple[float, float]]:
    """Benchmark the ONNX CPU backend on *img_batch* and run one prediction.

    Returns a single-entry mapping from the backend label to its
    benchmark result (as produced by ``ONNXInference.benchmark``).
    """
    runner = ONNXInference(model_loader, args.onnx_path, debug_mode=args.DEBUG)
    timing = runner.benchmark(img_batch)
    runner.predict(img_batch)
    return {"ONNX (CPU)": timing}
def _run_openvino_inference(args, model_loader, img_batch) -> dict[str, tuple[float, float]]:
    """Benchmark the OpenVINO CPU backend on *img_batch* and run one prediction.

    Returns a single-entry mapping from the backend label to its
    benchmark result (as produced by ``OVInference.benchmark``).
    """
    runner = OVInference(model_loader, args.ov_path, debug_mode=args.DEBUG)
    timing = runner.benchmark(img_batch)
    runner.predict(img_batch)
    return {"OpenVINO (CPU)": timing}
def _run_pytorch_cpu_inference(args, model_loader, img_batch) -> dict[str, tuple[float, float]]:
    """Benchmark eager PyTorch on the CPU and run one prediction.

    Returns a single-entry mapping from the backend label to its
    benchmark result (as produced by ``PyTorchInference.benchmark``).
    """
    runner = PyTorchInference(model_loader, device="cpu", debug_mode=args.DEBUG)
    timing = runner.benchmark(img_batch)
    runner.predict(img_batch)
    return {"PyTorch (CPU)": timing}
def _run_pytorch_cuda_inference(
    args, model_loader, device, img_batch
) -> dict[str, tuple[float, float]]:
    """Benchmark eager PyTorch on *device* (CUDA) and run one prediction.

    Returns a single-entry mapping from the backend label to its
    benchmark result (as produced by ``PyTorchInference.benchmark``).
    """
    print("Running CUDA inference...")
    runner = PyTorchInference(model_loader, device=device, debug_mode=args.DEBUG)
    timing = runner.benchmark(img_batch)
    runner.predict(img_batch)
    return {"PyTorch (CUDA)": timing}
def _run_tensorrt_inference(
    args, model_loader, device, img_batch
) -> dict[str, tuple[float, float]]:
    """Benchmark TensorRT at fp16 and fp32 precision, one prediction each.

    Returns a mapping with one entry per precision, keyed
    ``f"TRT_{precision}"`` (e.g. ``"TRT_torch.float16"``).
    """
    timings = {}
    for precision in (torch.float16, torch.float32):
        runner = TensorRTInference(
            model_loader, device=device, precision=precision, debug_mode=args.DEBUG
        )
        timing = runner.benchmark(img_batch)
        runner.predict(img_batch)
        timings[f"TRT_{precision}"] = timing
    return timings
def main():
    """Parse CLI arguments, run the selected inference backend(s), and
    plot a comparison chart when all backends were benchmarked.

    Backend selection is driven by ``args.mode`` ("onnx", "ov", "cpu",
    "cuda", "tensorrt", or "all"); CUDA-dependent backends only run when
    a CUDA device is present.
    """
    args = parse_arguments()
    if args.DEBUG:
        print("Debug mode enabled")

    benchmark_results = {}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_loader = ModelLoader(device=device)
    img_processor = ImageProcessor(img_path=args.image_path, device=device)
    img_batch = img_processor.process_image()

    # CPU-capable backends.
    if args.mode in ["onnx", "all"]:
        benchmark_results.update(_run_onnx_inference(args, model_loader, img_batch))
    if args.mode in ["ov", "all"]:
        benchmark_results.update(_run_openvino_inference(args, model_loader, img_batch))
    if args.mode in ["cpu", "all"]:
        benchmark_results.update(_run_pytorch_cpu_inference(args, model_loader, img_batch))

    # CUDA-dependent backends.
    if torch.cuda.is_available():
        if args.mode in ["cuda", "all"]:
            benchmark_results.update(
                _run_pytorch_cuda_inference(args, model_loader, device, img_batch)
            )
        if args.mode in ["tensorrt", "all"]:
            # Fix: TensorRT needs torch_tensorrt, not just a CUDA device.
            # CUDA_AVAILABLE (probed at module load) is True only when the
            # import succeeded; previously it was never checked, so a CUDA
            # machine without torch-tensorrt crashed here.
            if CUDA_AVAILABLE:
                benchmark_results.update(
                    _run_tensorrt_inference(args, model_loader, device, img_batch)
                )
            else:
                print("Skipping TensorRT benchmark: torch-tensorrt is not installed.")

    if args.mode == "all":
        plot_benchmark_results(benchmark_results)


if __name__ == "__main__":
    main()