Description
During a review of the PyTorch Tabular codebase, I noticed that the ONNX export path in src/pytorch_tabular/tabular_model.py carries the comment # TODO Need to test ONNX export. In addition, the "onnx" save type parametrization in tests/test_common.py is commented out, so the export functionality is currently unverified and may be incomplete.
Exporting to ONNX matters for deploying models to production environments that run C++ or other non-Python runtimes.
Scope and Missing Functionality
A robust ONNX export implementation needs to close several gaps in the current code:
- Model Scope: Phase 1 will fully validate export for stable models (CategoryEmbeddingModel and AutoInt) and raise a clear NotImplementedError for architectures incompatible with standard ONNX opset 13 (such as GATE, NODE, and TabNet).
- Missing Input Branches: The current code expects both continuous and categorical tensors. It needs to build the input schemas and dynamic_axes dictionaries dynamically so that export still works when one of those feature groups is absent.
- Model State Enforcement: The model should be switched to eval() mode for the duration of torch.onnx.export so that Dropout and Batch Normalization layers are traced with their inference behavior.
- Metadata Attachment: Add pytorch_tabular_version and the PyTorch model class name to the ONNX file's metadata properties for provenance tracking.
- Numerical Equivalence Verification: Re-enable the onnx test parametrization and assert numerical equivalence within a tolerance (np.testing.assert_allclose) between the onnxruntime outputs and the PyTorch model outputs.
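To make the Missing Input Branches item concrete, here is a minimal sketch of how the input names, dummy-input shapes, and dynamic_axes could be derived from the feature counts. The helper name build_export_spec, the input names "continuous" and "categorical", and the output name "output" are assumptions for illustration and are not taken from the existing code.

```python
from typing import Dict, List, Tuple


def build_export_spec(
    continuous_dim: int, categorical_dim: int
) -> Tuple[List[str], Dict[str, Dict[int, str]], Dict[str, Tuple[int, int]]]:
    """Return input names, dynamic axes, and dummy-input shapes for export.

    Hypothetical helper: a feature branch is included only when it has
    at least one column, so a model with no categorical (or no continuous)
    features still gets a consistent export specification.
    """
    if continuous_dim < 0 or categorical_dim < 0:
        raise ValueError("feature counts must be non-negative")
    if continuous_dim == 0 and categorical_dim == 0:
        raise ValueError("at least one feature branch is required")

    input_names: List[str] = []
    dummy_shapes: Dict[str, Tuple[int, int]] = {}
    if continuous_dim > 0:
        input_names.append("continuous")
        dummy_shapes["continuous"] = (1, continuous_dim)
    if categorical_dim > 0:
        input_names.append("categorical")
        dummy_shapes["categorical"] = (1, categorical_dim)

    # Mark axis 0 (the batch axis) as dynamic for every input and the output,
    # using the {axis_index: axis_name} form accepted by torch.onnx.export.
    dynamic_axes = {name: {0: "batch_size"} for name in input_names}
    dynamic_axes["output"] = {0: "batch_size"}  # assumed output name
    return input_names, dynamic_axes, dummy_shapes
```

The returned names and axes would then be passed as input_names and dynamic_axes to torch.onnx.export, with dummy tensors created from the shapes (float for continuous, integer indices for categorical).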
Expected Behavior
The "onnx" parametrization should be reintroduced in test_common.py. The test suite should verify that supported models export cleanly under Opset 13 using save_model_for_inference(kind="onnx"), and that loading the model using onnxruntime yields numerically equivalent predictions (within a specified tolerance) to the PyTorch native outputs for identical batches across varying feature combinations.
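The tolerance check described above could look roughly like the sketch below. The helper name assert_outputs_match and the default rtol/atol values of 1e-4 are assumptions; the real test would feed it the array returned by an onnxruntime session and the PyTorch output converted to NumPy.

```python
import numpy as np


def assert_outputs_match(onnx_out, torch_out, rtol=1e-4, atol=1e-4):
    """Assert two prediction arrays agree within a tolerance.

    Hypothetical helper: both arguments are anything np.asarray accepts.
    """
    onnx_arr = np.asarray(onnx_out, dtype=np.float32)
    torch_arr = np.asarray(torch_out, dtype=np.float32)
    # A shape mismatch usually means a dropped or reordered input branch,
    # so report it separately from a numerical difference.
    assert onnx_arr.shape == torch_arr.shape, (
        f"shape mismatch: {onnx_arr.shape} vs {torch_arr.shape}"
    )
    # Passes when |onnx - torch| <= atol + rtol * |torch| element-wise.
    np.testing.assert_allclose(onnx_arr, torch_arr, rtol=rtol, atol=atol)
```

Running the same check over batches with only continuous, only categorical, and both feature groups would cover the "varying feature combinations" requirement.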
Proposal
I am proposing to:
- Formalize onnx and onnxruntime as optional testing dependencies and use pytest.mark.skipif for test environments lacking them.
- Complete the programmatic construction of dummy inputs and dynamic axes inside save_model_for_inference.
- Wrap the export operation so that the model is put in eval() mode and its previous train()/eval() state is restored afterwards.
- Implement rigorous tests verifying file graph validity (onnx.checker.check_model) and runtime output equivalence within a 1e-4 tolerance.
- Provide graceful NotImplementedError exits for unsupported models.
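As a rough sketch of the state restoration and the unsupported-model guard from the list above: the names eval_mode, ensure_onnx_supported, and ONNX_SUPPORTED_MODELS are hypothetical, and the class names in the set are the ones used in this issue, which may differ from the actual class names in the library. The context manager relies only on the training attribute and the eval()/train(mode) methods that torch.nn.Module provides.

```python
from contextlib import contextmanager

# Assumed allow-list for Phase 1; names follow this issue's wording.
ONNX_SUPPORTED_MODELS = frozenset({"CategoryEmbeddingModel", "AutoInt"})


def ensure_onnx_supported(model) -> None:
    """Raise NotImplementedError for models outside the Phase 1 scope."""
    name = type(model).__name__
    if name not in ONNX_SUPPORTED_MODELS:
        raise NotImplementedError(
            f"ONNX export is not yet supported for {name}"
        )


@contextmanager
def eval_mode(model):
    """Switch a module to eval mode and restore its previous mode on exit."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Restore the original mode even if the export raised.
        model.train(was_training)
```

Inside save_model_for_inference, the torch.onnx.export call would then sit in a "with eval_mode(model):" block, preceded by ensure_onnx_supported(model).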
Would you like me to raise a PR for this?