Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ def _read_model_profile_bytes(profile_path: Path) -> bytes:
return profile_bytes


def _contains_parent_path_segment(path: Path) -> bool:
    """Report whether the textual form of *path* has a ``..`` component.

    The check is purely lexical: the path is converted to text, every
    separator other than ``/`` (the platform separator, the platform
    alternate separator when one exists, and the backslash) is rewritten
    to ``/``, and the resulting components are inspected. Nothing is
    resolved against the filesystem.
    """
    # os.altsep may be None, and "/" needs no rewriting, so drop both.
    foreign_separators = [
        candidate
        for candidate in (os.sep, os.altsep, "\\")
        if candidate and candidate != "/"
    ]

    unified = str(path)
    for candidate in foreign_separators:
        unified = unified.replace(candidate, "/")

    return ".." in unified.split("/")


@dataclass(frozen=True)
class AudioSeparationConfig:
"""Resource and band-split settings for local stem separation."""
Expand Down Expand Up @@ -133,6 +143,8 @@ def separate(self, audio_path: str | Path) -> AudioSeparationResult:
def _resolve_audio_file(self, audio_path: str | Path) -> Path:
"""Normalize and validate the selected source path."""
candidate = Path(audio_path).expanduser()
if _contains_parent_path_segment(candidate):
raise ValueError("Path traversal attempt detected in selected audio path")
try:
path = candidate.resolve(strict=True)
except FileNotFoundError as error:
Expand Down Expand Up @@ -216,6 +228,8 @@ def _load_model_profile(self) -> dict[str, float]:

if self.config.model_profile_path:
profile_candidate = Path(self.config.model_profile_path).expanduser()
if _contains_parent_path_segment(profile_candidate):
raise ValueError("Path traversal attempt detected in selected model profile path")
try:
profile_path = profile_candidate.resolve(strict=True)
except FileNotFoundError as error:
Expand Down
81 changes: 81 additions & 0 deletions services/analysis-engine/tests/test_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,36 @@ def test_audio_stem_separator_rejects_missing_audio_file(tmp_path) -> None:
separator.separate(tmp_path / "missing.wav")


def test_audio_stem_separator_rejects_parent_traversal_in_audio_file(tmp_path) -> None:
    """Verify an audio path holding a ``..`` segment raises the traversal error."""
    traversal_path = tmp_path / "nested" / ".." / "rehearsal.wav"
    config = AudioSeparationConfig(target_sample_rate=8_000)
    stem_separator = AudioStemSeparator(config)

    with pytest.raises(ValueError, match="Path traversal attempt detected"):
        stem_separator.separate(traversal_path)


def test_audio_stem_separator_rejects_altsep_parent_traversal_in_audio_file() -> None:
    """Verify a backslash-separated ``..`` segment is rejected, including on non-Windows hosts."""
    backslash_traversal = "safe\\..\\rehearsal.wav"
    config = AudioSeparationConfig(target_sample_rate=8_000)
    stem_separator = AudioStemSeparator(config)

    with pytest.raises(ValueError, match="Path traversal attempt detected"):
        stem_separator.separate(backslash_traversal)
Comment thread
seonghobae marked this conversation as resolved.


@pytest.mark.parametrize(
    "audio_path",
    [
        "safe/..\\rehearsal.wav",
        "safe\\../rehearsal.wav",
    ],
)
def test_audio_stem_separator_rejects_mixed_separator_parent_traversal(
    audio_path: str,
) -> None:
    """Verify ``..`` segments bounded by mixed ``/`` and backslash separators are rejected."""
    config = AudioSeparationConfig(target_sample_rate=8_000)
    stem_separator = AudioStemSeparator(config)

    with pytest.raises(ValueError, match="Path traversal attempt detected"):
        stem_separator.separate(audio_path)


def test_audio_stem_separator_rejects_directory_source(tmp_path) -> None:
"""Ensure directories are not accepted as audio files."""
source_dir = tmp_path / "source-dir"
Expand Down Expand Up @@ -277,6 +307,15 @@ def fail_decode(*args, **kwargs):
assert str(tmp_path) not in str(error.value)


def test_audio_stem_separator_fit_length_zero() -> None:
    """Check that fitting a four-sample signal to length zero yields shape ``(0,)``."""
    config = AudioSeparationConfig(target_sample_rate=8_000)
    stem_separator = AudioStemSeparator(config)
    samples = np.ones(4, dtype=np.float32)

    result = stem_separator._fit_length(samples, 0)

    assert result.shape == (0,)


def test_audio_stem_separator_uses_verified_local_model_profile(tmp_path) -> None:
"""Ensure local model profile overrides are applied only when checksum is verified."""
profile_path = tmp_path / "profile.json"
Expand Down Expand Up @@ -403,6 +442,48 @@ def test_audio_stem_separator_rejects_missing_local_model_profile(tmp_path) -> N
)


def test_audio_stem_separator_rejects_parent_traversal_in_model_profile(tmp_path) -> None:
    """Verify a model profile path holding a ``..`` segment raises the traversal error."""
    config_kwargs = {
        "target_sample_rate": 8_000,
        "model_profile_path": str(tmp_path / "profiles" / ".." / "profile.json"),
        "model_profile_sha256": "0" * 64,
    }

    # Config and separator are both built inside the context so the capture
    # scope matches wherever the ValueError originates.
    with pytest.raises(ValueError, match="Path traversal attempt detected"):
        AudioStemSeparator(AudioSeparationConfig(**config_kwargs))


def test_audio_stem_separator_rejects_altsep_parent_traversal_in_model_profile() -> None:
    """Verify a backslash-separated ``..`` in the profile path is rejected, including on non-Windows hosts."""
    config_kwargs = {
        "target_sample_rate": 8_000,
        "model_profile_path": "profiles\\..\\profile.json",
        "model_profile_sha256": "0" * 64,
    }

    # Config and separator are both built inside the context so the capture
    # scope matches wherever the ValueError originates.
    with pytest.raises(ValueError, match="Path traversal attempt detected"):
        AudioStemSeparator(AudioSeparationConfig(**config_kwargs))


@pytest.mark.parametrize(
    "model_profile_path",
    [
        "profiles/..\\profile.json",
        "profiles\\../profile.json",
    ],
)
def test_audio_stem_separator_rejects_mixed_separator_parent_traversal_in_model_profile(
    model_profile_path: str,
) -> None:
    """Verify profile paths with ``..`` between mixed ``/`` and backslash separators are rejected."""
    config_kwargs = {
        "target_sample_rate": 8_000,
        "model_profile_path": model_profile_path,
        "model_profile_sha256": "0" * 64,
    }

    # Config and separator are both built inside the context so the capture
    # scope matches wherever the ValueError originates.
    with pytest.raises(ValueError, match="Path traversal attempt detected"):
        AudioStemSeparator(AudioSeparationConfig(**config_kwargs))


def test_audio_stem_separator_rejects_non_numeric_band_profile(tmp_path) -> None:
"""Ensure profile numeric coercion failures use bounded verification errors."""
profile_path = tmp_path / "profile.json"
Expand Down