🐛[BUG]: Sinusoidal positional embedding freq_bands formula does not produce octave doublings #1522

@szamanian

Description

Version

2.1.0a0 (main branch, HEAD)

On which installation method(s) does this occur?

Source

Describe the issue

The sinusoidal positional embedding in SongUNetPosEmbd._get_positional_embedding() and MultiDiffusionModel2D.__init__() computes frequency bands as:

freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)

This does not produce octave doublings (powers of two) for num_freq >= 2.

For example, with num_freq=3 this gives [1.0, 2.83, 8.0] instead of the expected [1.0, 2.0, 4.0]. The middle value 2.83 is 2^1.5, not an integer power of two.

The root cause is that np.linspace(0, num_freq, num_freq) returns num_freq evenly spaced values from 0 to num_freq with both endpoints included, so the step between exponents is num_freq / (num_freq - 1) rather than 1. The fix is to use np.arange(num_freq), which produces the integer exponents [0, 1, 2, ..., num_freq - 1].
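The step-size argument can be checked directly. The sketch below is only an illustration: the helper name linspace_exponents is hypothetical and not part of the codebase. It prints the exponents the current formula generates for a few values of num_freq next to the step num_freq / (num_freq - 1).

```python
import numpy as np


def linspace_exponents(num_freq):
    """Exponents as computed by the current formula (hypothetical helper name).

    np.linspace includes the endpoint by default, so the last exponent is
    num_freq itself rather than num_freq - 1.
    """
    return np.linspace(0.0, num_freq, num=num_freq)


for num_freq in (2, 3, 4, 5):
    exps = linspace_exponents(num_freq)
    observed_step = exps[1] - exps[0]
    predicted_step = num_freq / (num_freq - 1)
    print(f"num_freq={num_freq}: exponents={exps}, "
          f"step={observed_step:.4f}, predicted={predicted_step:.4f}")
```

The step only approaches 1 as num_freq grows, so the exponents are never the integers 0, 1, 2, ... for any num_freq >= 2.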

This also creates an internal inconsistency: SinusoidalEncoding in physicsnemo/models/figconvnet/components/encodings.py already uses the correct formula 2.0 ** torch.arange(num_freq).

Affected files:

physicsnemo/models/diffusion_unets/song_unet.py (L1334-1336)
physicsnemo/models/multi_diffusion/models.py (L453-455)

Correct reference:

physicsnemo/models/figconvnet/components/encodings.py (L121-123) (uses 2.0 ** torch.arange(num_freq))
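A possible regression check for the two affected files, written as a minimal sketch. Both octave_freq_bands and is_octave_doubling are hypothetical names introduced here for illustration; they do not exist in physicsnemo. The check asserts that each band is exactly twice the previous one.

```python
import numpy as np


def octave_freq_bands(num_freq):
    """Proposed formula: 2**0, 2**1, ..., 2**(num_freq - 1) (hypothetical helper)."""
    return 2.0 ** np.arange(num_freq)


def is_octave_doubling(bands):
    """Return True if every band is twice the previous one (hypothetical helper)."""
    bands = np.asarray(bands, dtype=float)
    if bands.size < 2:
        return True  # zero or one band trivially satisfies the property
    return bool(np.allclose(bands[1:] / bands[:-1], 2.0))


num_freq = 3
current = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)
proposed = octave_freq_bands(num_freq)

print("current is octave doubling: ", is_octave_doubling(current))
print("proposed is octave doubling:", is_octave_doubling(proposed))
```

If keeping np.linspace is preferred, np.linspace(0.0, num_freq - 1, num=num_freq) yields the same integer exponents as np.arange(num_freq) for num_freq >= 1.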

Minimum reproducible example

import numpy as np

num_freq = 3

# Current implementation (song_unet.py line ~1334, multi_diffusion/models.py line ~453)
freq_bands_current = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)
print(f"Current:  {freq_bands_current}")   # [1.0, 2.83, 8.0]

# Expected (octave doublings, consistent with figconvnet/components/encodings.py)
freq_bands_fixed = 2.0 ** np.arange(num_freq)
print(f"Expected: {freq_bands_fixed}")      # [1.0, 2.0, 4.0]

Relevant log output

Current:  [1.         2.82842712 8.        ]
Expected: [1. 2. 4.]

Environment details

Environment location: Bare-metal
Installation method: source (git clone + uv sync)
OS: Linux
Python: 3.11.14

Metadata

Labels

? - Needs Triage (Need team to review and classify)
bug (Something isn't working)
external (Issues/PR filed by people outside the team)
