Add distillation training recipe for CorrDiff #1533

Open
jialusui1102 wants to merge 10 commits into NVIDIA:main from jialusui1102:distillation-training
Conversation

@jialusui1102
Collaborator

PhysicsNeMo Pull Request

Description

Added a distillation recipe and helper functions for CorrDiff distillation training with the FastGen framework.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment; these comments are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@greptile-apps
Contributor

greptile-apps bot commented Mar 25, 2026

Greptile Summary

This PR adds distillation training support for CorrDiff using the FastGen framework, introducing a new distill.py training script, a helpers/distill_helpers.py module with FastGen wrappers (SuperPatching2D, FastGenNet, CMModel/SCMModel/DMD2Model, DistillLoss), a few-step sampler for distillation inference in generate.py, and accompanying Hydra configs for three datasets (HRRR-mini, GEFS-HRRR, Taiwan/CWB).

Critical issues requiring fixes before merging:

  • Missing default config (distill.py line 129): @hydra.main uses config_name="config_distill_mini" which does not exist. The script will crash at startup for any user who doesn't explicitly pass --config-name. The closest matching config added in this PR is config_distill_hrrr_mini_diffusion.
  • Null-pointer crash on diffusion_model (distill.py line 455): diffusion_model.img_out_channels is accessed without a None-guard, even though the code explicitly sets diffusion_model = None when no checkpoint path is provided. A similar crash exists at torch.compile(diffusion_model) when use_torch_compile=True.
  • Validation loss scaling bug (distill.py lines 922–926): The loss returned by loss_fn is asserted to be a scalar (0-dim tensor) during training, but the validation loop applies .sum() / batch_size_per_gpu to that scalar. This divides an already-averaged value by batch_size_per_gpu, making the validation metric batch_size_per_gpu× smaller than training loss and incomparable in WandB/TensorBoard.
  • Wrong config key in error message (generate.py line 216): The fallback raise ValueError references cfg.sampling.type instead of cfg.sampler.type, which will throw an omegaconf.errors.ConfigAttributeError rather than the intended ValueError.

Minor issues:

  • torch._dynamo reset and verbose/logging configuration are set at module scope (lines 78–83 of distill.py), affecting all importers globally rather than being scoped to main().
  • Several unused imports (DistributedDataParallel, RegressionLoss, RegressionLossCE) and unreferenced variables (residual_loss, edm_precond_super_res) remain in distill.py.
  • The SCMModel docstring in distill_helpers.py is truncated (starts with "-time Consistency Model").

Important Files Changed

Filename Overview
examples/weather/corrdiff/distill.py New 1149-line distillation training script. Critical issues: (1) config_name="config_distill_mini" references a non-existent config file causing an immediate runtime failure; (2) diffusion_model.img_out_channels is accessed without a None-guard even though the code explicitly supports diffusion_model = None; (3) validation loss is incorrectly divided by batch_size_per_gpu making it incomparable to training loss; (4) torch._dynamo reset/config set at module scope affecting all importers.
examples/weather/corrdiff/helpers/distill_helpers.py New 898-line helper module implementing FastGen wrappers (SuperPatching2D, FastGenNet, CMModel, SCMModel, DMD2Model, DistillLoss) for CorrDiff distillation. Logic is generally sound; minor issue is a truncated docstring in SCMModel ("-time Consistency Model...").
examples/weather/corrdiff/generate.py Adds few-step sampler branch. The sampler integration is clean, but the pre-existing fallback error message incorrectly references cfg.sampling.type instead of cfg.sampler.type, which will raise a secondary OmegaConf error rather than the intended ValueError.
examples/weather/corrdiff/conf/base/distill/base_all.yaml Base distillation hyperparameter config; looks well-structured with sensible defaults for LR schedulers, optimizer, and IO settings.
examples/weather/corrdiff/conf/config_distill_gefs_hrrr_diffusion.yaml Top-level config for GEFS-HRRR distillation. Placeholder values (<path/to/checkpoint.mdlus>) for required checkpoint paths are clearly documented.
examples/weather/corrdiff/conf/config_distill_hrrr_mini_diffusion.yaml Top-level config for HRRR-mini distillation. Consistent structure with other top-level configs; no anomalies.
examples/weather/corrdiff/conf/config_distill_taiwan_diffusion.yaml Top-level config for Taiwan (CWB) distillation. Minor formatting nit: missing newline at end of file.
examples/weather/corrdiff/conf/base/generation/sampler/few-step.yaml New few-step sampler config. Missing newline at end of file; otherwise straightforward.

Reviews (1): Last reviewed commit: "update CHANEGLOG.md"

Comment thread: examples/weather/corrdiff/distill.py (Outdated)


# Train the CorrDiff model using the configurations in "conf/config_training.yaml"
@hydra.main(version_base="1.2", config_path="conf", config_name="config_distill_mini")
Contributor


P0 Non-existent default config name

The config_name="config_distill_mini" does not correspond to any file in conf/. The three configs added in this PR are config_distill_hrrr_mini_diffusion.yaml, config_distill_gefs_hrrr_diffusion.yaml, and config_distill_taiwan_diffusion.yaml. Running distill.py without an explicit --config-name override will fail immediately with a Hydra MissingConfigException.

The closest match to the apparent intent is config_distill_hrrr_mini_diffusion:

Suggested change
@hydra.main(version_base="1.2", config_path="conf", config_name="config_distill_mini")
@hydra.main(version_base="1.2", config_path="conf", config_name="config_distill_hrrr_mini_diffusion")

Collaborator Author


updated

{
"precision": PRECISION_MAP[fp_optimizations],
"precision_infer": PRECISION_MAP[str(input_dtype)],
"input_shape": (diffusion_model.img_out_channels, *super_patch_shape),
Contributor


P0 AttributeError when diffusion_model is None

diffusion_model is explicitly set to None (around line 403) when no diffusion_checkpoint_path is provided or the key is absent. The attribute access diffusion_model.img_out_channels on line 455 will raise AttributeError: 'NoneType' object has no attribute 'img_out_channels' in that case.

A similar crash exists a few lines earlier (line ~434): torch.compile(diffusion_model) when use_torch_compile=True and diffusion_model is None.

A guard is needed before this block, e.g.:

if diffusion_model is None:
    raise ValueError(
        "A diffusion checkpoint must be provided for distillation training. "
        "Set cfg.distill.io.diffusion_checkpoint_path."
    )

Collaborator Author


added

Comment on lines +922 to +926
loss_valid = (
(loss_valid.sum() / batch_size_per_gpu)
.cpu()
.item()
)
Contributor


P1 Incorrect validation loss scaling

Training explicitly asserts loss.ndim == 0 (the loss returned by loss_fn is already a scalar). However, the validation path applies .sum() / batch_size_per_gpu to the same scalar output. loss_valid.sum() on a 0-dim tensor is a no-op, but the subsequent division by batch_size_per_gpu shrinks the value by that factor, making the logged validation loss batch_size_per_gpu times smaller than the comparable training loss. This makes the two metrics incomparable in TensorBoard/WandB.

Suggested change
loss_valid = (
(loss_valid.sum() / batch_size_per_gpu)
.cpu()
.item()
)
loss_valid = loss_valid.cpu().item()

Collaborator Author


skip since that's the original code

Comment thread: examples/weather/corrdiff/generate.py (Outdated)
)

else:
raise ValueError(f"Unknown sampling method {cfg.sampling.type}")
Contributor


P1 Wrong config key in error message

Every other branch in this if/elif chain uses cfg.sampler.type, but the fallback error message references cfg.sampling.type, which is a different (and likely non-existent) key. Accessing cfg.sampling.type will raise an omegaconf.errors.ConfigAttributeError instead of surfacing the intended ValueError.

Suggested change
raise ValueError(f"Unknown sampling method {cfg.sampling.type}")
raise ValueError(f"Unknown sampling method {cfg.sampler.type}")

Collaborator Author


updated

Comment on lines +78 to +83
torch._dynamo.reset()
# Increase the cache size limit
torch._dynamo.config.cache_size_limit = 264 # Set to a higher value
torch._dynamo.config.verbose = True # Enable verbose logging
torch._dynamo.config.suppress_errors = False # Forces the error to show all details
torch._logging.set_logs(recompiles=True, graph_breaks=True)
Contributor


P2 Global-scope torch._dynamo configuration with side effects

These five statements run at module import time (not inside main()), so they apply to any process that imports distill.py:

  • torch._dynamo.config.verbose = True emits verbose graph-break logs for all dynamo usage.
  • torch._logging.set_logs(recompiles=True, graph_breaks=True) turns on extra logging globally.
  • torch._dynamo.reset() discards any previously compiled caches.

This is especially problematic for multi-process training where workers import the module, and for test suites that import helper functions. These should be moved inside main() and ideally made conditional on a config flag (e.g., cfg.distill.perf.profile_mode).

Collaborator Author


Keep it there for torch.compile



class SCMModel(SCMBaseModel):
"""-time Consistency Model with TrigFlow for CorrDiff distillation.
Contributor


P2 Truncated docstring for SCMModel

The class docstring starts with -time Consistency Model, which looks like the beginning of the model name was accidentally deleted. Given the TrigFlow parameterization, the intended name is most likely "Continuous-time" (sCM, the continuous-time consistency model formulation that uses TrigFlow).

Suggested change
"""-time Consistency Model with TrigFlow for CorrDiff distillation.
"""Continuous-time Consistency Model with TrigFlow for CorrDiff distillation.

Collaborator Author


updated

from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
import torch
from torch.nn.parallel import DistributedDataParallel
Contributor


P2 Unused imports

Several symbols are imported but never referenced in distill.py:

  • DistributedDataParallel (line 27) — DDPWrapper from fastgen is used instead.
  • RegressionLoss, RegressionLossCE (line 41) — only ResidualLoss is referenced.
  • residual_loss and edm_precond_super_res (lines 246/266) are assigned but never consumed downstream.

Removing these will reduce confusion about which loss/preconditioner classes are actually active during distillation.
