Concrete Dropout for GeoTransolver model UQ #1548

Merged
mnabian merged 14 commits into NVIDIA:main from mnabian:geotransolver-uq
Apr 14, 2026

Conversation

@mnabian
Collaborator

@mnabian mnabian commented Apr 3, 2026

PhysicsNeMo Pull Request

Description

Adds concrete dropout-based uncertainty quantification (UQ) to GeoTransolver. Concrete dropout (Gal, Hron & Kendall, 2017) makes the dropout probability a learnable parameter per layer using the concrete (Gumbel-softmax) relaxation, enabling calibrated MC-Dropout inference without manual tuning of dropout rates.

All changes are backward compatible — concrete dropout is disabled by default and existing behavior is unchanged.

Changes

Model (physicsnemo/experimental/models/geotransolver/)

New file: concrete_dropout.py

  • ConcreteDropout module: wraps a layer with a learnable dropout probability using the concrete relaxation. During training, applies soft binary masks via z = sigmoid((log(u) - log(1-u) + log(p) - log(1-p)) / temperature) with temperature=0.1. During eval, passes input unchanged.
  • collect_concrete_dropout_losses(model): gathers entropy regularization losses from all ConcreteDropout modules. The loss p*log(p) + (1-p)*log(1-p) prevents dropout rates from collapsing to 0 or 1.
  • get_concrete_dropout_rates(model): extracts learned per-layer dropout probabilities for monitoring.
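The mechanics above can be sketched as follows. This is a hedged, illustrative reimplementation under the names the PR describes (ConcreteDropout, collect_concrete_dropout_losses); the merged code lives in physicsnemo and may differ in details such as initialization and epsilon handling:

```python
import torch
import torch.nn as nn


class ConcreteDropout(nn.Module):
    """Dropout with a learnable rate via the binary-concrete relaxation."""

    def __init__(self, init_p: float = 0.1, temperature: float = 0.1):
        super().__init__()
        # Keep p in logit space so the optimizer can move it freely in (0, 1).
        self.p_logit = nn.Parameter(torch.logit(torch.tensor(init_p)))
        self.temperature = temperature

    @property
    def p(self) -> torch.Tensor:
        return torch.sigmoid(self.p_logit)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return x  # eval: identity, as the PR describes
        p, eps = self.p, 1e-7
        u = torch.rand_like(x)
        # Soft binary mask from the concrete relaxation (z near 1 means "drop").
        z = torch.sigmoid(
            (torch.log(u + eps) - torch.log(1.0 - u + eps)
             + torch.log(p + eps) - torch.log(1.0 - p + eps))
            / self.temperature
        )
        # Rescale kept activations, as in standard inverted dropout.
        return x * (1.0 - z) / (1.0 - p + eps)


def collect_concrete_dropout_losses(model: nn.Module) -> torch.Tensor:
    """Sum p*log(p) + (1-p)*log(1-p) over all ConcreteDropout modules."""
    losses = [
        m.p * torch.log(m.p) + (1.0 - m.p) * torch.log(1.0 - m.p)
        for m in model.modules()
        if isinstance(m, ConcreteDropout)
    ]
    return torch.stack(losses).sum() if losses else torch.tensor(0.0)
```

Note that the entropy term is minimized at p = 0.5, so on its own it pulls rates toward 0.5 rather than letting them collapse to the endpoints.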

Modified: gale.py

  • GALE: replaces inherited out_dropout with ConcreteDropout when enabled (attention output projection).
  • GALE_block: adds ConcreteDropout after the attention residual connection (attn_dropout) and after the FFN residual connection (ffn_dropout).

Modified: context_projector.py

  • ContextProjector: adds ConcreteDropout on output slice tokens (output_dropout).
  • Parameters propagated through MultiScaleFeatureExtractor and GlobalContextBuilder.

Modified: geotransolver.py

  • Constructor accepts concrete_dropout: bool = False, dropout_reg: float = 1e-3, weight_reg: float = 1e-6.
  • New methods: concrete_dropout_reg_loss(), concrete_dropout_rates(), enable_mc_dropout().
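The enable_mc_dropout() behavior can be sketched generically: put the whole model in eval mode, then flip only the concrete-dropout layers back to train mode so their stochastic masks stay active. The helper below is an assumption-labeled stand-in, parameterized over the dropout class so it runs with plain nn.Dropout here:

```python
import torch.nn as nn


def enable_mc_dropout(model: nn.Module, dropout_cls: type = nn.Dropout) -> nn.Module:
    """Put the model in eval mode but keep dropout layers stochastic."""
    model.eval()  # freezes norm statistics and disables every dropout layer...
    for m in model.modules():
        if isinstance(m, dropout_cls):
            m.train()  # ...then re-enable just the dropout masks
    return model
```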

Modified: __init__.py

  • Exports ConcreteDropout, collect_concrete_dropout_losses, get_concrete_dropout_rates.

Training recipe (examples/.../transformer_models/src/)

Modified: train.py

  • Adds concrete dropout regularization loss to the training loop, gated by lambda_reg > 0 (default 0, no-op).
  • Logs learned per-layer dropout rates to TensorBoard at each epoch end.
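The lambda_reg gating can be illustrated with a tiny helper; the names here are assumptions for illustration, not the recipe's actual code:

```python
import torch


def total_loss(task_loss: torch.Tensor, reg_loss: torch.Tensor,
               lambda_reg: float) -> torch.Tensor:
    """Add the concrete-dropout regularizer only when lambda_reg > 0."""
    if lambda_reg > 0:
        return task_loss + lambda_reg * reg_loss
    return task_loss  # the default lambda_reg=0.0 leaves training untouched
```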

Modified: conf/model/geotransolver.yaml

  • concrete_dropout: false (default off)
  • dropout_reg: 1.0e-3 (entropy regularization coefficient, ~1/(2N) for N=400 samples)
  • weight_reg: 1.0e-6 (weight regularization coefficient, unused in loss since AdamW handles L2)

Modified: conf/training/base.yaml

  • lambda_reg: 0.0 (multiplier for regularization loss, default disabled)

Inference recipe (examples/.../transformer_models/src/)

Modified: inference_on_zarr.py

  • Adds mc_dropout_inference_loop(): runs N stochastic forward passes, returns mean predictions, std predictions, all per-sample predictions, averaged loss/metrics, and targets.
  • MC-Dropout model setup: model.eval() with ConcreteDropout layers kept in train mode. Gated by +mc_dropout_samples=N (default 0, disabled).
  • When enabled, logs per-field uncertainty stats and computes force coefficient uncertainty (Cd/Cl std across MC samples).
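A minimal sketch of the MC-Dropout statistics described above: N stochastic forward passes, then a per-point mean and standard deviation. The real mc_dropout_inference_loop also aggregates losses and metrics; this illustrative function covers only the prediction statistics:

```python
import torch


@torch.no_grad()
def mc_dropout_predict(model, x: torch.Tensor, n_samples: int):
    """Run n_samples stochastic forward passes; return mean, std, all samples."""
    preds = torch.stack([model(x) for _ in range(n_samples)])  # shape (N, ...)
    return preds.mean(dim=0), preds.std(dim=0), preds
```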

Modified: inference_on_vtk.py

  • Runs a deterministic forward pass (ConcreteDropout disabled) followed by an MC-Dropout pass (if enabled).
  • Writes per-point fields to VTP/VTU output files:
    • PredictedPressure, PredictedWallShearStress (deterministic, always written)
    • MCMeanPressure, MCMeanWallShearStress (MC mean, only when MC enabled)
    • MCStdPressure, MCStdWallShearStress (MC std, only when MC enabled)
    • Equivalent fields for volume mode (Velocity, Pressure, Nut)

Usage

Training with concrete dropout:

python src/train.py --config-name geotransolver_surface \
    model.concrete_dropout=true \
    training.lambda_reg=0.1

Inference with MC-Dropout UQ (zarr dataset):

python src/inference_on_zarr.py --config-name geotransolver_surface \
    +mc_dropout_samples=20

Inference with MC-Dropout UQ (raw VTK files):

python src/inference_on_vtk.py --config-name geotransolver_surface \
    +vtk_inference.input_dir=/path/to/runs \
    +vtk_inference.output_dir=/path/to/output \
    +mc_dropout_samples=20

Test

This has been tested on the DrivAerML dataset. Two interesting observations:

  1. The accuracy of a GeoTransolver model trained with concrete dropout is on par with that of a baseline model trained without dropout.
     (screenshot)
  2. There is a high correlation between the magnitude of the MC-dropout standard deviation and the magnitude of the prediction error, suggesting the uncertainty bounds are meaningful.
     (screenshots)
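The second observation can be quantified as a Pearson correlation between the per-point MC std and the absolute prediction error; a sketch of such a check (not part of the PR) is:

```python
import torch


def uq_error_correlation(mc_std: torch.Tensor, pred_mean: torch.Tensor,
                         target: torch.Tensor) -> float:
    """Pearson correlation between per-point MC std and absolute error."""
    err = (pred_mean - target).abs().flatten()
    s = mc_std.flatten()
    ec, sc = err - err.mean(), s - s.mean()
    return ((ec * sc).sum() / (ec.norm() * sc.norm())).item()
```

A value near 1 indicates the uncertainty estimate tracks the actual error magnitude well.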

Checklist

Dependencies

None

@greptile-apps
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR adds Concrete Dropout-based uncertainty quantification to GeoTransolver by introducing a learnable per-layer dropout probability (via the binary-concrete relaxation) that enables calibrated MC-Dropout inference without manual hyperparameter tuning. All changes are backward-compatible — the feature is opt-in via concrete_dropout: false default. The core model integration (gale.py, context_projector.py, geotransolver.py) is clean and well-structured, and the training recipe changes are minimal and correct.

There are a few issues worth addressing before merge:

  • weight_reg is a silent no-op (concrete_dropout.py lines 160–183): the parameter is accepted, stored, and described in the regularization_loss() docstring as an active weight-penalty term, but is never used in the loss computation. This creates a misleading interface for users who tune this value expecting it to affect training.
  • Return type annotation mismatch (inference_on_zarr.py line 348): mc_dropout_inference_loop is annotated as returning a 4-tuple but actually returns 6 values, breaking static type checking.
  • Train/eval mode toggled per-run inside a compiled model (inference_on_vtk.py lines 699–750): when cfg.compile=True and MC-Dropout is active, each run in the loop flips ConcreteDropout submodules between eval and train. torch.compile treats the training flag as a compile-time constant, so this forces a full graph retrace on every run. inference_on_zarr.py handles this correctly (mode set once before the loop); inference_on_vtk.py should follow the same pattern.
  • Unused import (concrete_dropout.py line 34): import torch.nn.functional as F is never referenced.
  • print instead of logger (inference_on_zarr.py line 396): one progress message inside mc_dropout_inference_loop uses a bare print() rather than the logger, bypassing rank-filtering and log-level control.

Important Files Changed

Filename Overview
physicsnemo/experimental/models/geotransolver/concrete_dropout.py New file implementing ConcreteDropout; the weight_reg parameter and in_features are accepted, stored, and documented as active components of the regularization loss but are never used in regularization_loss(), creating a misleading API. Unused torch.nn.functional as F import should be removed.
physicsnemo/experimental/models/geotransolver/gale.py ConcreteDropout is cleanly integrated into GALE and GALE_block; out_dropout override and attn/ffn dropout additions are backward-compatible and logically correct.
physicsnemo/experimental/models/geotransolver/context_projector.py concrete_dropout parameters are propagated correctly through MultiScaleFeatureExtractor and GlobalContextBuilder; ContextProjector output_dropout integration is clean.
physicsnemo/experimental/models/geotransolver/geotransolver.py New constructor params and convenience methods (concrete_dropout_reg_loss, concrete_dropout_rates, enable_mc_dropout) are well-implemented and backward compatible.
examples/cfd/external_aerodynamics/transformer_models/src/inference_on_zarr.py mc_dropout_inference_loop return type annotation claims 4 elements but returns 6; bare print() bypasses logging infrastructure. MC-dropout mode setup order (compile then set mode) is correct here.
examples/cfd/external_aerodynamics/transformer_models/src/inference_on_vtk.py Per-run toggling of ConcreteDropout train/eval state inside the compiled model's inference loop will force torch.compile to retrace the graph on every run, causing severe performance regression when both compile and mc_dropout_samples are enabled.
examples/cfd/external_aerodynamics/transformer_models/src/train.py Regularization loss integration and TensorBoard logging of dropout rates are clean and correctly gated by lambda_reg and concrete_dropout flags.
physicsnemo/experimental/models/geotransolver/__init__.py Correctly exports the three new public symbols from concrete_dropout.py.

Comments Outside Diff (2)

  1. examples/cfd/external_aerodynamics/transformer_models/src/inference_on_zarr.py, line 348-349 (link)

    P1 Return type annotation does not match actual return value

    The function signature declares a 4-element return tuple:

    ) -> tuple[torch.Tensor, torch.Tensor, float, dict]:

    but the function actually returns 6 values:

    return mean_predictions, std_predictions, stacked, mean_loss, mean_metrics, targets

    This mismatch will confuse static analysis tools and anyone writing callers against this type hint. The annotation should be updated to match the six returned values.

  2. examples/cfd/external_aerodynamics/transformer_models/src/inference_on_zarr.py, line 396 (link)

    P2 print used instead of logger

    The rest of both inference_on_zarr.py and inference_on_vtk.py consistently use logger.info(...). This bare print() call will bypass any logging configuration (e.g., distributed rank filtering, log-level control).

Reviews (1): Last reviewed commit: "concrete dropout for geotransolver" | Re-trigger Greptile

@mnabian mnabian self-assigned this Apr 3, 2026
Collaborator

@coreyjadams coreyjadams left a comment


Hi @mnabian - I think ConcreteDropout is a great idea to add to physicsnemo in general, actually! We already have GumbelSoftmax which, numerically, is closely related: https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/nn/module/gumbel_softmax.py

If I understand it correctly, ConcreteDropout is a two-category optimization of that same mathematical concept. Could we add this implementation in a similar way, in physicsnemo.nn, and add some tests? I think it is just fine to add it to GeoTransolver like this.

I don't think we need to merge the implementations, btw: the sigmoid version here is computationally more efficient than a softmax over two categories.

@mnabian mnabian requested a review from loliverhennigh as a code owner April 9, 2026 23:02
@mnabian
Collaborator Author

mnabian commented Apr 10, 2026

/blossom-ci

@mnabian
Collaborator Author

mnabian commented Apr 10, 2026

> Hi @mnabian - I think ConcreteDropout is a great idea to add to physicsnemo in general, actually! We already have GumbelSoftmax which, numerically, is closely related: https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/nn/module/gumbel_softmax.py
>
> If I understand it correctly, ConcreteDropout is a two-category optimization of that same mathematical concept. Could we add this implementation in a similar way, in physicsnemo.nn, and add some tests? I think it is just fine to add it to GeoTransolver like this.
>
> I don't think we need to merge the implementations, btw: the sigmoid version here is computationally more efficient than a softmax over two categories.

Thanks for reviewing the PR. I have moved the concrete_dropout to nn/module, and added tests.

@mnabian mnabian requested a review from coreyjadams April 10, 2026 00:53
@mnabian
Collaborator Author

mnabian commented Apr 10, 2026

/blossom-ci

Collaborator

@loliverhennigh loliverhennigh left a comment


LGTM

Collaborator

@coreyjadams coreyjadams left a comment


Thanks @mnabian, looks good!

@mnabian
Collaborator Author

mnabian commented Apr 13, 2026

/blossom-ci

@mnabian mnabian enabled auto-merge April 13, 2026 22:27
@mnabian
Collaborator Author

mnabian commented Apr 13, 2026

/blossom-ci

@mnabian mnabian added this pull request to the merge queue Apr 14, 2026
Merged via the queue into NVIDIA:main with commit 3bba0e1 Apr 14, 2026
4 checks passed

4 participants