Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

logger = logging.getLogger(__name__)

_MAX_NOTE_LENGTH = 12

# Chromatic note order for comparison (octave-independent).
_NOTE_ORDER = [
"C",
Expand Down Expand Up @@ -54,16 +56,23 @@ def _parse_note(note: str) -> tuple[str, int]:
"""
if not note:
return ("C", 4)
import re
if len(note) > _MAX_NOTE_LENGTH:
return ("C", 4)
if note[0].upper() not in {"A", "B", "C", "D", "E", "F", "G"}:
return ("C", 4)

match = re.match(r"^([A-Ga-g](?:#|b|sharp|flat)?)(.*)$", note)
if not match:
return (note, 4)
name = note[0].upper()
octave_str = note[1:]
octave_str_lower = octave_str.lower()
for accidental_text, accidental in (("sharp", "#"), ("flat", "b"), ("#", "#"), ("b", "b")):
if octave_str_lower.startswith(accidental_text):
name += accidental
octave_str = octave_str[len(accidental_text) :]
break
Comment thread
seonghobae marked this conversation as resolved.

name, octave_str = match.groups()
if octave_str == "":
return (name, 4)
if octave_str == "-" or not re.match(r"^-?\d+$", octave_str):
if octave_str == "-" or not octave_str.removeprefix("-").isdigit():
return (name, 4)

return (name, int(octave_str))
Expand Down Expand Up @@ -121,7 +130,11 @@ def _ranges_overlap(low_a: str, high_a: str, low_b: str, high_b: str) -> bool:
midi_high_a = _note_to_midi(high_a)
midi_low_b = _note_to_midi(low_b)
midi_high_b = _note_to_midi(high_b)
return midi_low_a <= midi_high_b and midi_low_b <= midi_high_a
if midi_low_a > midi_high_a or midi_low_b > midi_high_b:
return False
overlap_low = max(midi_low_a, midi_low_b)
overlap_high = min(midi_high_a, midi_high_b)
return overlap_low <= overlap_high


def _overlap_severity(
Expand All @@ -142,14 +155,18 @@ def _overlap_severity(
midi_high_a = _note_to_midi(high_a)
midi_low_b = _note_to_midi(low_b)
midi_high_b = _note_to_midi(high_b)
if midi_low_a > midi_high_a or midi_low_b > midi_high_b:
return "low"

overlap_low = max(midi_low_a, midi_low_b)
overlap_high = min(midi_high_a, midi_high_b)
overlap_size = overlap_high - overlap_low
if overlap_size <= 0:
return "low"

range_a_size = midi_high_a - midi_low_a
range_b_size = midi_high_b - midi_low_b
min_range = min(range_a_size, range_b_size) if min(range_a_size, range_b_size) > 0 else 1
min_range = max(1, min(range_a_size, range_b_size))

ratio = overlap_size / min_range
if ratio > 0.5:
Expand All @@ -159,6 +176,15 @@ def _overlap_severity(
return "low"


def _safe_note_string(value: object) -> str:
"""Return a bounded note string or a safe default for untrusted range data."""
if not isinstance(value, str):
return "C4"
if len(value) > _MAX_NOTE_LENGTH:
return "C4"
return value


class RangeAnalyzer:
"""Analyzes pitch ranges and detects overlaps between roles."""

Expand Down Expand Up @@ -192,31 +218,36 @@ def analyze(
for role in section_roles:
role_range = role.get("range")
if isinstance(role_range, dict):
lowest_note = _safe_note_string(role_range.get("lowestNote", ""))
highest_note = _safe_note_string(role_range.get("highestNote", ""))
ranges.append(
{
"role_id": str(role.get("id", "")),
"role_name": str(role.get("name", "")),
"lowestNote": str(role_range.get("lowestNote", "")),
"highestNote": str(role_range.get("highestNote", "")),
"lowestNote": lowest_note,
"highestNote": highest_note,
}
)

ranges_with_midi = []
for r in ranges:
midi_low = _note_to_midi(r["lowestNote"])
midi_high = _note_to_midi(r["highestNote"])
if midi_low > midi_high:
continue
ranges_with_midi.append(
(
r,
_note_to_midi(r["lowestNote"]),
_note_to_midi(r["highestNote"]),
midi_low,
midi_high,
)
)

# Sort ranges by lowest note MIDI value for efficient overlap detection
ranges_with_midi.sort(key=lambda x: x[1])

# Detect overlaps between all pairs of ranges
for a_idx in range(len(ranges_with_midi)):
r_a, midi_low_a, midi_high_a = ranges_with_midi[a_idx]
for a_idx, (r_a, _midi_low_a, midi_high_a) in enumerate(ranges_with_midi):
for b_idx in range(a_idx + 1, len(ranges_with_midi)):
r_b, midi_low_b, midi_high_b = ranges_with_midi[b_idx]

Expand All @@ -225,25 +256,23 @@ def analyze(
if midi_low_b > midi_high_a:
break

# Check for overlap
if midi_low_a <= midi_high_b and midi_low_b <= midi_high_a:
severity = _overlap_severity(
r_a["lowestNote"],
r_a["highestNote"],
r_b["lowestNote"],
r_b["highestNote"],
)

overlaps.append(
{
"role_a": r_a["role_id"],
"role_b": r_b["role_id"],
"overlap_region": (
f"{r_a['role_name']} and {r_b['role_name']} overlap"
),
"severity": severity,
}
)
severity = _overlap_severity(
r_a["lowestNote"],
r_a["highestNote"],
r_b["lowestNote"],
r_b["highestNote"],
)

overlaps.append(
{
"role_a": r_a["role_id"],
"role_b": r_b["role_id"],
"overlap_region": (
f"{r_a['role_name']} and {r_b['role_name']} overlap"
),
"severity": severity,
}
)

summaries.append(
{
Expand Down
90 changes: 89 additions & 1 deletion services/analysis-engine/tests/test_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def test_parse_note_basic() -> None:
assert _parse_note("C4") == ("C", 4)
assert _parse_note("G#3") == ("G#", 3)
assert _parse_note("Bb2") == ("Bb", 2)
assert _parse_note("csharp4") == ("C#", 4)
assert _parse_note("CSharp4") == ("C#", 4)
assert _parse_note("bflat2") == ("Bb", 2)
assert _parse_note("BFLAT2") == ("Bb", 2)
assert _parse_note("") == ("C", 4)


Expand All @@ -25,7 +29,7 @@ def test_parse_note_without_octave() -> None:

def test_parse_note_all_digits() -> None:
"""Test note parsing when input is all digits (edge case)."""
assert _parse_note("4") == ("4", 4)
assert _parse_note("4") == ("C", 4)


def test_parse_note_malformed_negative_octave_falls_back() -> None:
Expand All @@ -34,6 +38,11 @@ def test_parse_note_malformed_negative_octave_falls_back() -> None:
assert _parse_note("C#-") == ("C#", 4)


def test_parse_note_rejects_overlong_inputs() -> None:
"""Test overlong note strings are bounded before parsing."""
assert _parse_note("A" * 64) == ("C", 4)


def test_note_to_midi() -> None:
"""Test MIDI number conversion for note comparison."""
assert _note_to_midi("C4") == 60
Expand All @@ -51,6 +60,7 @@ def test_ranges_overlap_true() -> None:
def test_ranges_overlap_false() -> None:
"""Test non-overlapping ranges are correctly identified."""
assert _ranges_overlap("C2", "E2", "A4", "C5") is False
assert _ranges_overlap("C4", "C5", "C5", "C4") is False


def test_overlap_severity_high() -> None:
Expand All @@ -67,6 +77,18 @@ def test_overlap_severity_low() -> None:
assert result == "low"


def test_overlap_severity_low_for_inverted_range() -> None:
"""Test malformed inverted ranges fail closed to low severity."""
result = _overlap_severity("C5", "C4", "C4", "C5")
assert result == "low"


def test_overlap_severity_low_for_touching_boundary() -> None:
"""Test boundary-only overlap stays low severity."""
result = _overlap_severity("C4", "C5", "C5", "C6")
assert result == "low"


def test_overlap_severity_medium() -> None:
"""Test medium severity overlap detection."""
# C3-C5 = 24 semitones, A3-G6 = 34 semitones.
Expand Down Expand Up @@ -168,6 +190,72 @@ def test_range_analyzer_no_overlap() -> None:
assert result["sections"][0]["overlaps"] == []


def test_range_analyzer_does_not_overlap_inverted_ranges() -> None:
"""Test malformed inverted ranges do not create false-positive overlaps."""
analyzer = RangeAnalyzer()
sections = [{"id": "verse-1"}]
roles_by_section = {
"verse-1": [
{
"id": "normal",
"name": "Normal",
"range": {"lowestNote": "C4", "highestNote": "C5"},
},
{
"id": "inverted",
"name": "Inverted",
"range": {"lowestNote": "C5", "highestNote": "C4"},
},
]
}

result = analyzer.analyze(sections, roles_by_section)

assert result["sections"][0]["overlaps"] == []


def test_range_analyzer_bounds_overlong_note_strings() -> None:
"""Test overlong note strings are replaced before result serialization."""
analyzer = RangeAnalyzer()
sections = [{"id": "verse-1"}]
roles_by_section = {
"verse-1": [
{
"id": "bass",
"name": "Bass",
"range": {"lowestNote": "A" * 64, "highestNote": "B" * 64},
}
]
}

result = analyzer.analyze(sections, roles_by_section)
ranges = result["sections"][0]["ranges"]

assert ranges[0]["lowestNote"] == "C4"
assert ranges[0]["highestNote"] == "C4"


def test_range_analyzer_defaults_non_string_note_values() -> None:
"""Test non-string note values are replaced before result serialization."""
analyzer = RangeAnalyzer()
sections = [{"id": "verse-1"}]
roles_by_section = {
"verse-1": [
{
"id": "bass",
"name": "Bass",
"range": {"lowestNote": ["C4"], "highestNote": {"note": "G4"}},
}
]
}

result = analyzer.analyze(sections, roles_by_section)
ranges = result["sections"][0]["ranges"]

assert ranges[0]["lowestNote"] == "C4"
assert ranges[0]["highestNote"] == "C4"


def test_range_analyzer_invalid_section() -> None:
"""Test analyzer handles non-dict sections gracefully."""
analyzer = RangeAnalyzer()
Expand Down