Concrete Dropout for GeoTransolver model UQ #1548
Greptile Summary
This PR adds Concrete Dropout-based uncertainty quantification to GeoTransolver by introducing a learnable per-layer dropout probability (via the binary-concrete relaxation) that enables calibrated MC-Dropout inference without manual hyperparameter tuning. All changes are backward-compatible; the feature is opt-in. There are a few issues worth addressing before merge:
coreyjadams
left a comment
Hi @mnabian - I think ConcreteDropout is a great idea to add to physicsnemo in general, actually! We already have GumbelSoftmax which, numerically, is closely related: https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/nn/module/gumbel_softmax.py
If I understand it correctly, ConcreteDropout is a two-category optimization of that same mathematical concept. Could we add this implementation in a similar way, in physicsnemo.nn, and add some tests? I think it is just fine to add it to GeoTransolver like this.
I don't think we need to merge the implementations, btw: the sigmoid version here is computationally more efficient than softmax in 2 categories.
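The relationship described in this comment can be checked numerically. The snippet below is a minimal sketch in plain Python (it does not use the physicsnemo GumbelSoftmax module, and all function names are illustrative): for two categories, the first component of a Gumbel-softmax sample equals a sigmoid applied to the difference of the two perturbed logits, so one sigmoid can stand in for a two-way softmax. The difference of two independent Gumbel samples follows a logistic distribution, which is the same noise as log(u) - log(1-u) for uniform u.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one standard Gumbel(0, 1) sample by inverting its CDF."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def softmax_first_of_two(a, b):
    """First component of softmax([a, b]), shifted by the max for stability."""
    m = max(a, b)
    ea, eb = math.exp(a - m), math.exp(b - m)
    return ea / (ea + eb)


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


rng = random.Random(0)
p, temperature = 0.2, 0.1  # illustrative dropout probability and temperature
logit_drop, logit_keep = math.log(p), math.log(1.0 - p)

max_gap = 0.0
for _ in range(1000):
    g_drop, g_keep = sample_gumbel(rng), sample_gumbel(rng)
    # Two-category Gumbel-softmax: perturb both logits, then softmax.
    two_way = softmax_first_of_two(
        (logit_drop + g_drop) / temperature,
        (logit_keep + g_keep) / temperature,
    )
    # Binary-concrete form: a single sigmoid of the perturbed logit difference.
    binary = sigmoid((logit_drop - logit_keep + g_drop - g_keep) / temperature)
    max_gap = max(max_gap, abs(two_way - binary))

print(f"largest gap over 1000 draws: {max_gap:.2e}")
```

The two forms agree up to floating-point rounding, while the sigmoid form needs one exponential per element instead of two plus a normalization.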
/blossom-ci
Thanks for reviewing the PR. I have moved the concrete_dropout to |
coreyjadams
left a comment
Thanks @mnabian, looks good!
PhysicsNeMo Pull Request
Description
Adds concrete dropout-based uncertainty quantification (UQ) to GeoTransolver. Concrete dropout (Gal, Hron & Kendall, 2017) makes the dropout probability a learnable parameter per layer using the concrete (Gumbel-softmax) relaxation, enabling calibrated MC-Dropout inference without manual tuning of dropout rates.
All changes are backward compatible — concrete dropout is disabled by default and existing behavior is unchanged.
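The relaxation named in the description can be sketched in a few lines. What follows is a minimal, framework-free sketch in plain Python, not the PR's PyTorch module. It uses the mask formula and temperature=0.1 quoted under Changes, together with the entropy term p*log(p) + (1-p)*log(1-p). Several details are assumptions made for illustration: the function names, the EPS clamp, and the choice to treat z as a drop mask with 1/(1-p) rescaling (the convention of the reference implementation by Gal et al.; the PR description does not say which convention it uses).

```python
import math
import random

TEMPERATURE = 0.1  # value stated in the PR description
EPS = 1e-7         # assumption: small clamp that keeps every log finite


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def _clamp(v):
    return min(max(v, EPS), 1.0 - EPS)


def concrete_mask(p, u, temperature=TEMPERATURE):
    """z = sigmoid((log u - log(1-u) + log p - log(1-p)) / temperature)."""
    p, u = _clamp(p), _clamp(u)
    logit = math.log(u) - math.log(1.0 - u) + math.log(p) - math.log(1.0 - p)
    return _sigmoid(logit / temperature)


def concrete_dropout(xs, p, rng, training=True):
    """Apply a soft dropout mask to a list of activations."""
    if not training:
        return list(xs)  # eval mode: input passes through unchanged
    # Assumption: z close to 1 means "drop", and survivors are rescaled
    # by 1/(1-p) so the expected activation is roughly preserved.
    return [x * (1.0 - concrete_mask(p, rng.random())) / (1.0 - p) for x in xs]


def entropy_regularizer(p):
    """Negative Bernoulli entropy: p*log(p) + (1-p)*log(1-p)."""
    p = _clamp(p)
    return p * math.log(p) + (1.0 - p) * math.log(1.0 - p)


rng = random.Random(0)
p = 0.2  # illustrative dropout probability
masks = [concrete_mask(p, rng.random()) for _ in range(20000)]
mean_mask = sum(masks) / len(masks)
dropped = concrete_dropout([1.0, 2.0, 3.0], p, rng)
print(f"mean mask value: {mean_mask:.3f}")
print("regularizer at p=0.01, 0.5, 0.99:",
      [round(entropy_regularizer(q), 3) for q in (0.01, 0.5, 0.99)])
```

At low temperature the mask is close to 0 or 1 for most draws, and it is near 1 with probability about p. The regularizer is lowest at p = 0.5, where it equals -log 2, and rises toward 0 as p approaches 0 or 1, which is why minimizing it pushes the learned rates away from both extremes.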
Changes
Model (physicsnemo/experimental/models/geotransolver/)
New file: concrete_dropout.py
- ConcreteDropout module: wraps a layer with a learnable dropout probability using the concrete relaxation. During training, applies soft binary masks via z = sigmoid((log(u) - log(1-u) + log(p) - log(1-p)) / temperature) with temperature=0.1. During eval, passes input unchanged.
- collect_concrete_dropout_losses(model): gathers entropy regularization losses from all ConcreteDropout modules. The loss p*log(p) + (1-p)*log(1-p) prevents dropout rates from collapsing to 0 or 1.
- get_concrete_dropout_rates(model): extracts learned per-layer dropout probabilities for monitoring.
Modified: gale.py
- GALE: replaces inherited out_dropout with ConcreteDropout when enabled (attention output projection).
- GALE_block: adds ConcreteDropout after the attention residual connection (attn_dropout) and after the FFN residual connection (ffn_dropout).
Modified: context_projector.py
- ContextProjector: adds ConcreteDropout on output slice tokens (output_dropout).
- MultiScaleFeatureExtractor and GlobalContextBuilder.
Modified: geotransolver.py
- concrete_dropout: bool = False, dropout_reg: float = 1e-3, weight_reg: float = 1e-6.
- concrete_dropout_reg_loss(), concrete_dropout_rates(), enable_mc_dropout().
Modified: __init__.py
- ConcreteDropout, collect_concrete_dropout_losses, get_concrete_dropout_rates.
Training recipe (examples/.../transformer_models/src/)
Modified: train.py
- lambda_reg > 0 (default 0, no-op).
Modified: conf/model/geotransolver.yaml
- concrete_dropout: false (default off)
- dropout_reg: 1.0e-3 (entropy regularization coefficient, ~1/(2N) for N=400 samples)
- weight_reg: 1.0e-6 (weight regularization coefficient, unused in loss since AdamW handles L2)
Modified: conf/training/base.yaml
- lambda_reg: 0.0 (multiplier for regularization loss, default disabled)
Inference recipe (examples/.../transformer_models/src/)
Modified: inference_on_zarr.py
- mc_dropout_inference_loop(): runs N stochastic forward passes, returns mean predictions, std predictions, all per-sample predictions, averaged loss/metrics, and targets.
- model.eval() with ConcreteDropout layers kept in train mode. Gated by +mc_dropout_samples=N (default 0, disabled).
Modified: inference_on_vtk.py
- PredictedPressure, PredictedWallShearStress (deterministic, always written)
- MCMeanPressure, MCMeanWallShearStress (MC mean, only when MC enabled)
- MCStdPressure, MCStdWallShearStress (MC std, only when MC enabled)
Usage
Training with concrete dropout:
    python src/train.py --config-name geotransolver_surface \
        model.concrete_dropout=true \
        training.lambda_reg=0.1
Inference with MC-Dropout UQ (zarr dataset):
    python src/inference_on_zarr.py --config-name geotransolver_surface \
        +mc_dropout_samples=20
Inference with MC-Dropout UQ (raw VTK files):
    python src/inference_on_vtk.py --config-name geotransolver_surface \
        +vtk_inference.input_dir=/path/to/runs \
        +vtk_inference.output_dir=/path/to/output \
        +mc_dropout_samples=20
Test
This has been tested on the DrivAerML dataset. Two interesting observations:
Checklist
Dependencies
None
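For reference, the aggregation step of the MC-Dropout inference described under the inference recipe can be illustrated as follows. This is a minimal sketch in plain Python under stated assumptions, not the PR's mc_dropout_inference_loop(): the toy model, its weights, and the helper names are hypothetical, ordinary Bernoulli dropout stands in for ConcreteDropout, and the sample standard deviation (n-1 denominator) is an arbitrary choice here. The sketch runs N stochastic forward passes on the same input and reduces them to a per-output mean and standard deviation.

```python
import random
import statistics


def make_toy_model(weights, p, rng):
    """Return a stochastic forward function: input dropout, then a linear map.

    Hypothetical stand-in for a network whose dropout layers stay active
    at inference time.
    """
    def forward(x):
        # Drop each input with probability p and rescale survivors by 1/(1-p).
        kept = [xi * (0.0 if rng.random() < p else 1.0) / (1.0 - p) for xi in x]
        return [sum(w * k for w, k in zip(row, kept)) for row in weights]
    return forward


def mc_dropout_predict(stochastic_forward, x, n_samples):
    """Run n_samples stochastic passes; return (mean, std, all samples)."""
    if n_samples < 2:
        raise ValueError("need at least 2 samples to estimate a std")
    samples = [stochastic_forward(x) for _ in range(n_samples)]
    columns = list(zip(*samples))  # one tuple of draws per output element
    mean = [statistics.fmean(c) for c in columns]
    std = [statistics.stdev(c) for c in columns]  # sample std (n-1)
    return mean, std, samples


rng = random.Random(0)
forward = make_toy_model(
    weights=[[0.5, -1.0, 2.0], [1.0, 1.0, 1.0]],  # illustrative 2x3 weights
    p=0.2,
    rng=rng,
)
# 20 passes, matching the +mc_dropout_samples=20 value used in the usage commands.
mean, std, samples = mc_dropout_predict(forward, [1.0, 2.0, 3.0], n_samples=20)
print("MC mean:", [round(m, 3) for m in mean])
print("MC std: ", [round(s, 3) for s in std])
```

The per-output std is what an MCStd field would carry in this picture, and it is zero only when every stochastic pass returns the same value.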