Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions mattergen/common/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def from_cache_path(
Load a dataset from a specified cache path.

Args:
name: Name of the reference dataset.
cache_path: Path to the cache directory containing the dataset.
transforms: List of transforms to apply to **each datapoint** when loading, e.g., to make the lattice matrices symmetric.
properties: List of properties to condition on.
dataset_transforms: List of transforms to apply to the **whole dataset**, e.g., to filter out certain entries.
Expand Down Expand Up @@ -318,9 +318,8 @@ def from_num_atoms_distribution(

Args:
num_atoms_distribution: A dictionary with the number of atoms as keys and the probability of that number of atoms as values.
num_samples: The number of samples to generate.
transforms: List of transforms to apply to **each datapoint** when loading, e.g., to make the lattice matrices symmetric.
properties: List of properties to condition on.
dataset_transforms: List of transforms to apply to the **whole dataset**, e.g., to filter out certain entries.

Returns:
The dataset.
Expand Down
9 changes: 7 additions & 2 deletions mattergen/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,15 @@ def __init__(

Args:
gemnet: a GNN module
hidden_dim (int, optional): Number of hidden dimensions in the GemNet. Defaults to 128.
hidden_dim (int, optional): Number of hidden dimensions in the GemNet. Defaults to 512.
denoise_atom_types (bool, optional): Whether to denoise the atom types. Defaults to False.
atom_type_diffusion (str, optional): Which type of atom type diffusion to use. Defaults to "mask".
condition_on (Optional[List[str]], optional): Which aspects of the data to condition on. Strings must be in ["property", "chemical_system"]. If None (default), condition on ["chemical_system"].
property_embeddings (torch.nn.ModuleDict, optional): Property embeddings used in a
conditioned model trained from scratch.
property_embeddings_adapt (torch.nn.ModuleDict, optional): Property embeddings used in
a conditioned model, fine-tuned from a base model. Unused here.
Comment on lines +196 to +197
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me it seemed that property_embeddings_adapt is unused here and only used in the child class GemNetTAdapter, which handles the fine tuning, so I added the hint "unused here". It could probably be removed, but then one might also want to remove it from some yaml files and I wasn't too sure about it, so it just stuck to the hint instead of trying to remove it. Possibly the consistency to the other classes might be desireable.

element_mask_func (Callable, optional): Function to mask out unsupported or undesired
elements that shall not be contained in the generated structures.
"""
super(GemNetTDenoiser, self).__init__()

Expand Down