From e7a5a8068dfd234e1da5bf50a72a644fe1318810 Mon Sep 17 00:00:00 2001 From: Luis Date: Fri, 11 Apr 2025 10:18:41 +0200 Subject: [PATCH 1/2] Update docstring of GemNetTDenoiser init to match arguments --- mattergen/denoiser.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mattergen/denoiser.py b/mattergen/denoiser.py index bb01970b..dac1fa46 100644 --- a/mattergen/denoiser.py +++ b/mattergen/denoiser.py @@ -188,10 +188,15 @@ def __init__( Args: gemnet: a GNN module - hidden_dim (int, optional): Number of hidden dimensions in the GemNet. Defaults to 128. + hidden_dim (int, optional): Number of hidden dimensions in the GemNet. Defaults to 512. denoise_atom_types (bool, optional): Whether to denoise the atom types. Defaults to False. atom_type_diffusion (str, optional): Which type of atom type diffusion to use. Defaults to "mask". - condition_on (Optional[List[str]], optional): Which aspects of the data to condition on. Strings must be in ["property", "chemical_system"]. If None (default), condition on ["chemical_system"]. + property_embeddings (torch.nn.ModuleDict, optional): Property embeddings used in a + conditioned model trained from scratch. + property_embeddings_adapt (torch.nn.ModuleDict, optional): Property embeddings used in + a conditioned model, fine-tuned from a base model. Unused here. + element_mask_func (Callable, optional): Function to mask out unsupported or undesired + elements that shall not be contained in the generated structures. """ super(GemNetTDenoiser, self).__init__() From 655a68e949861e3a41d33ef054967a8f36822505 Mon Sep 17 00:00:00 2001 From: Luis Date: Fri, 11 Apr 2025 10:18:59 +0200 Subject: [PATCH 2/2] Update docstrings of dataset methods --- mattergen/common/data/dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mattergen/common/data/dataset.py b/mattergen/common/data/dataset.py index 812c4f21..e6ade463 100644 --- a/mattergen/common/data/dataset.py +++ b/mattergen/common/data/dataset.py @@ -101,7 +101,7 @@ def from_cache_path( Load a dataset from a specified cache path. Args: - name: Name of the reference dataset. + cache_path: Path to the cache directory containing the dataset. transforms: List of transforms to apply to **each datapoint** when loading, e.g., to make the lattice matrices symmetric. properties: List of properties to condition on. dataset_transforms: List of transforms to apply to the **whole dataset**, e.g., to filter out certain entries. @@ -318,9 +318,8 @@ def from_num_atoms_distribution( Args: num_atoms_distribution: A dictionary with the number of atoms as keys and the probability of that number of atoms as values. + num_samples: The number of samples to generate. transforms: List of transforms to apply to **each datapoint** when loading, e.g., to make the lattice matrices symmetric. - properties: List of properties to condition on. - dataset_transforms: List of transforms to apply to the **whole dataset**, e.g., to filter out certain entries. Returns: The dataset.