Implement Comprehensive ONNX Export Testing and Validation (save_model_for_inference) #654

@Tomgrinds777

Description

During a review of the PyTorch Tabular codebase, I noticed that the ONNX export path in src/pytorch_tabular/tabular_model.py carries the comment # TODO Need to test ONNX export. In addition, the "onnx" save-type parametrization in tests/test_common.py is commented out, so the export path is currently untested and appears incomplete.

ONNX export matters for deploying models to production environments that use C++ or other non-Python runtimes.

Scope and Missing Functionality

A robust ONNX export needs to address several gaps in the current implementation:

  1. Model Scope: Phase 1 will fully validate export for stable models (CategoryEmbeddingModel and AutoInt) and raise a clear NotImplementedError for architectures incompatible with standard ONNX opset 13 (such as GATE, NODE, and TabNet).
  2. Missing Input Branches: The current code expects both continuous and categorical tensors. It needs to build the input schema and the dynamic_axes dictionary dynamically, so that export still works when one of those feature types is absent.
  3. Model State Enforcement: The model should be switched to eval() mode before torch.onnx.export is called, so that Dropout and Batch Normalization are traced with their inference behavior.
  4. Metadata Attachment: Add pytorch_tabular_version and the PyTorch model class name to the ONNX file's metadata properties for provenance tracking.
  5. Numerical Equivalence Verification: Re-enabling the onnx test parametrization to assert strict numerical equivalence (np.testing.assert_allclose) between the ONNX runtime (onnxruntime) outputs and the PyTorch model outputs.
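To make item 2 concrete, here is a minimal sketch of how the input names, dummy-input shapes, and dynamic axes could be derived from the feature counts. The helper name build_export_spec, the input names "continuous" and "categorical", and the output name "output" are assumptions for illustration, not the library's actual API. The function only builds plain Python structures, so the real implementation would still need to create tensors from the shapes and pass everything to torch.onnx.export.

```python
from typing import Dict, List, Tuple


def build_export_spec(
    n_continuous: int, n_categorical: int, batch_size: int = 1
) -> Tuple[List[str], Dict[str, Tuple[int, int]], Dict[str, Dict[int, str]]]:
    """Return input names, dummy-input shapes, and dynamic axes.

    Only feature branches that are actually present are included.
    All names used here are hypothetical placeholders.
    """
    if n_continuous <= 0 and n_categorical <= 0:
        raise ValueError("At least one continuous or categorical feature is required.")

    input_names: List[str] = []
    shapes: Dict[str, Tuple[int, int]] = {}
    dynamic_axes: Dict[str, Dict[int, str]] = {}

    # Assumed input names; skip any branch with zero features.
    for name, width in (("continuous", n_continuous), ("categorical", n_categorical)):
        if width > 0:
            input_names.append(name)
            shapes[name] = (batch_size, width)
            # Axis 0 is the batch dimension and should stay dynamic.
            dynamic_axes[name] = {0: "batch_size"}

    # Assumed output name; the batch dimension is dynamic here as well.
    dynamic_axes["output"] = {0: "batch_size"}
    return input_names, shapes, dynamic_axes
```

With only continuous features, build_export_spec(3, 0) yields a single "continuous" input and no "categorical" entry in the dynamic axes, which is the case the current code does not handle.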

Expected Behavior

The "onnx" parametrization should be reintroduced in test_common.py. The test suite should verify that supported models export cleanly under Opset 13 using save_model_for_inference(kind="onnx"), and that loading the model using onnxruntime yields numerically equivalent predictions (within a specified tolerance) to the PyTorch native outputs for identical batches across varying feature combinations.
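The equivalence check described above could look roughly like the sketch below. The helper name assert_onnx_matches_torch is hypothetical, and the default tolerances of 1e-4 are an assumption taken from the tolerance mentioned in this issue. In a real test the two arrays would come from the PyTorch model and from an onnxruntime session run on the same batch; here they are passed in as plain arrays so the comparison logic stands on its own.

```python
import numpy as np


def assert_onnx_matches_torch(torch_out, onnx_out, rtol=1e-4, atol=1e-4):
    """Raise AssertionError if the two outputs differ in shape or value.

    Hypothetical helper: both arguments are array-likes holding the
    predictions for the same batch.
    """
    torch_arr = np.asarray(torch_out, dtype=np.float32)
    onnx_arr = np.asarray(onnx_out, dtype=np.float32)

    # A shape mismatch usually means an input branch or axis was exported wrongly.
    if torch_arr.shape != onnx_arr.shape:
        raise AssertionError(
            f"Shape mismatch: torch {torch_arr.shape} vs onnx {onnx_arr.shape}"
        )

    # Element-wise comparison within the given relative/absolute tolerance.
    np.testing.assert_allclose(onnx_arr, torch_arr, rtol=rtol, atol=atol)
```

Running this helper once per feature combination (continuous only, categorical only, both) would cover the "varying feature combinations" requirement.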

Proposal

I am proposing to:

  1. Formalize onnx and onnxruntime as optional testing dependencies and use pytest.mark.skipif for test environments lacking them.
  2. Complete the programmatic construction of dummy inputs and dynamic axes inside save_model_for_inference.
  3. Switch the model to eval() for the export and restore its previous training state afterward.
  4. Implement tests that verify graph validity (onnx.checker.check_model) and runtime output equivalence within a tolerance of 1e-4.
  5. Raise a clear NotImplementedError for unsupported models.
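Steps 3 and 5 could be sketched as follows. The names eval_mode, ensure_onnx_supported, and UNSUPPORTED_FOR_ONNX are hypothetical, and the set uses the short model labels from this issue rather than the library's real class names, which may differ. The context manager relies only on the training attribute and the eval() and train(mode) methods that torch.nn.Module provides, so the sketch itself does not import torch.

```python
from contextlib import contextmanager

# Short labels taken from this issue; the real class names may differ.
UNSUPPORTED_FOR_ONNX = {"GATE", "NODE", "TabNet"}


def ensure_onnx_supported(model_name: str) -> None:
    """Fail early with a clear message for models outside the Phase 1 scope."""
    if model_name in UNSUPPORTED_FOR_ONNX:
        raise NotImplementedError(
            f"ONNX export is not yet supported for {model_name}."
        )


@contextmanager
def eval_mode(model):
    """Put the model in eval mode and restore its previous state on exit."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Restore the original mode even if the export raises.
        model.train(was_training)
```

The export call would then sit inside a with eval_mode(model): block, so a failed export cannot leave the model stuck in eval mode.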

Would you like me to raise a PR for this?