Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 47 additions & 191 deletions .github/scripts/check_sdk_api_breakage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@ class member must have been marked deprecated in the *previous* release using
from __future__ import annotations

import ast
import io
import json
import os
import subprocess
import sys
import tarfile
import tempfile
import tomllib
import urllib.request
from collections.abc import Iterable
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from pathlib import Path

Expand Down Expand Up @@ -114,14 +110,6 @@ class FieldDefaultChange:
ACP_BASE_REF_ENV = "ACP_VERSION_CHECK_BASE_REF"


def _get_base_ref() -> str | None:
base_ref = os.environ.get(ACP_BASE_REF_ENV) or os.environ.get("GITHUB_BASE_REF")
if not base_ref:
return None
base_ref = base_ref.strip()
return base_ref or None


def read_version_from_pyproject(path: str) -> str:
"""Read the version string from a pyproject.toml file."""
with open(path, "rb") as f:
Expand Down Expand Up @@ -178,12 +166,8 @@ def _min_version_from_requirement(req_str: str) -> pkg_version.Version | None:
return max(lower_bounds)


def _git_ref_candidates(ref: str) -> tuple[str, ...]:
return tuple(dict.fromkeys((f"origin/{ref}", ref)))


def _git_show_file(ref: str, rel_path: str) -> str | None:
for candidate in _git_ref_candidates(ref):
for candidate in (f"origin/{ref}", ref):
result = subprocess.run(
["git", "show", f"{candidate}:{rel_path}"],
check=False,
Expand All @@ -195,29 +179,6 @@ def _git_show_file(ref: str, rel_path: str) -> str | None:
return None


def _git_archive_directory(
repo_root: str,
ref: str,
rel_path: str,
dest_root: str,
) -> bool:
for candidate in _git_ref_candidates(ref):
result = subprocess.run(
["git", "archive", "--format=tar", candidate, rel_path],
cwd=repo_root,
check=False,
capture_output=True,
)
if result.returncode != 0:
continue

with tarfile.open(fileobj=io.BytesIO(result.stdout)) as archive:
archive.extractall(dest_root, filter="data")
return True

return False


def _load_base_pyproject(base_ref: str) -> dict | None:
rel_path = "openhands-sdk/pyproject.toml"
content = _git_show_file(base_ref, rel_path)
Expand Down Expand Up @@ -245,7 +206,7 @@ def _check_acp_version_bump(repo_root: str) -> int:
)
return 0

base_ref = _get_base_ref()
base_ref = os.environ.get(ACP_BASE_REF_ENV) or os.environ.get("GITHUB_BASE_REF")
if not base_ref:
print(
"::warning title=ACP version::No base ref found; skipping ACP version check"
Expand Down Expand Up @@ -607,23 +568,19 @@ def _object_path(obj: object | None) -> str:
)


def _write_field_default_change_report(
changes: list[FieldDefaultChange],
*,
field_default_changes_since_base: list[FieldDefaultChange] | None = None,
) -> None:
def _write_field_default_change_report(changes: list[FieldDefaultChange]) -> None:
"""Write detected public Field default changes to a JSON report file."""
report_path = os.environ.get(FIELD_DEFAULT_CHANGE_REPORT_ENV, "").strip()
if not report_path:
return

report = {"field_default_changes": [asdict(change) for change in changes]}
if field_default_changes_since_base is not None:
report["field_default_changes_since_base"] = [
asdict(change) for change in field_default_changes_since_base
]

Path(report_path).write_text(json.dumps(report, indent=2) + "\n")
Path(report_path).write_text(
json.dumps(
{"field_default_changes": [asdict(change) for change in changes]},
indent=2,
)
+ "\n"
)


def _member_deprecation_metadata(
Expand Down Expand Up @@ -671,8 +628,6 @@ def _collect_breakages_pairs(
title: str,
package: str,
field_default_changes: list[FieldDefaultChange] | None = None,
field_defaults_only: bool = False,
emit_diagnostics: bool = True,
) -> tuple[list[object], int]:
"""Find breaking changes between pairs of old/new API objects.

Expand All @@ -699,23 +654,20 @@ def _collect_breakages_pairs(
old_value = getattr(br, "old_value", None)
new_value = getattr(br, "new_value", None)
if _is_field_metadata_only_change(old_value, new_value):
if emit_diagnostics:
print(
f"::notice title={title}::Ignoring Field "
"metadata-only change (non-breaking): "
f"{obj.name if obj else 'unknown'}"
)
print(
f"::notice title={title}::Ignoring Field metadata-only "
f"change (non-breaking): {obj.name if obj else 'unknown'}"
)
continue
if _is_field_default_only_change(old_value, new_value):
object_path = _object_path(obj)
old_default = _field_default_repr(old_value) or "<unknown>"
new_default = _field_default_repr(new_value) or "<unknown>"
if emit_diagnostics:
print(
f"::warning title={title}::Public Field default "
"changed (release-note-required): "
f"{object_path} {old_default} -> {new_default}"
)
print(
f"::warning title={title}::Public Field default changed "
f"(release-note-required): {object_path} "
f"{old_default} -> {new_default}"
)
if field_default_changes is not None:
field_default_changes.append(
FieldDefaultChange(
Expand All @@ -727,9 +679,6 @@ def _collect_breakages_pairs(
)
continue

if field_defaults_only:
continue

print(br.explain(style=ExplanationStyle.GITHUB))
breakages.append(br)

Expand All @@ -753,8 +702,6 @@ def _collect_breakages_pairs(
print(f"::error title={title}::{error}")
removal_policy_errors += len(errors)
except AliasResolutionError as e:
if field_defaults_only:
continue
if isinstance(old, Alias) or isinstance(new, Alias):
old_target = old.target_path if isinstance(old, Alias) else None
new_target = new.target_path if isinstance(new, Alias) else None
Expand All @@ -780,8 +727,6 @@ def _collect_breakages_pairs(
f"unresolved alias: {e}"
)
except Exception as e:
if field_defaults_only:
raise RuntimeError("Failed to collect Field default changes") from e
print(f"::warning title={title}::Failed to compute breakages: {e}")

return breakages, removal_policy_errors
Expand Down Expand Up @@ -912,36 +857,6 @@ def _load_current(
return None


@contextmanager
def _load_from_git_ref(
griffe_module: object,
repo_root: str,
ref: str,
cfg: PackageConfig,
):
title = f"{cfg.distribution} API"
with tempfile.TemporaryDirectory() as tmpdir:
if not _git_archive_directory(repo_root, ref, cfg.source_dir, tmpdir):
print(
f"::warning title={title}::Failed to load {cfg.distribution} from "
f"git ref {ref}: unable to archive {cfg.source_dir}"
)
yield None
return

try:
yield griffe_module.load(
cfg.package,
search_paths=[os.path.join(tmpdir, cfg.source_dir)],
)
except Exception as e:
print(
f"::warning title={title}::Failed to load {cfg.distribution} from "
f"git ref {ref}: {e}"
)
yield None


def _load_prev_from_pypi(
griffe_module: object,
prev: str,
Expand All @@ -964,40 +879,6 @@ def _load_prev_from_pypi(
return None


def _collect_field_default_changes_since_ref(
griffe_module: object,
repo_root: str,
ref: str,
cfg: PackageConfig,
) -> list[FieldDefaultChange] | None:
new_root = _load_current(griffe_module, repo_root, cfg)
if not new_root:
return None

with _load_from_git_ref(griffe_module, repo_root, ref, cfg) as old_root:
if not old_root:
return None

changes: list[FieldDefaultChange] = []
try:
_compute_breakages(
old_root,
new_root,
cfg,
field_default_changes=changes,
field_defaults_only=True,
emit_diagnostics=False,
)
except Exception as e:
print(
f"::warning title={cfg.distribution} API::Failed to compare "
f"Field defaults against base ref {ref}: {e}"
)
return None

return changes


# Names of module-level data registries that declare deprecated public
# re-exports as ``{name: {"deprecated_in": ..., "removed_in": ...}}`` and are
# consumed by a module-level ``__getattr__``. The SDK uses this form for renamed
Expand Down Expand Up @@ -1201,8 +1082,6 @@ def _compute_breakages(
*,
current_version: str = "9999.0.0",
field_default_changes: list[FieldDefaultChange] | None = None,
field_defaults_only: bool = False,
emit_diagnostics: bool = True,
) -> tuple[int, int]:
"""Detect breaking changes between old and new package versions.

Expand Down Expand Up @@ -1237,47 +1116,44 @@ def _compute_breakages(
# evaluate) __all__, we can't compute meaningful breakages.
#
# In this situation, skip rather than failing the entire workflow.
if emit_diagnostics:
print(
f"::notice title={title}::Skipping breakage check; baseline release "
f"has no statically-evaluable {pkg}.__all__: {e}"
)
print(
f"::notice title={title}::Skipping breakage check; baseline release "
f"has no statically-evaluable {pkg}.__all__: {e}"
)
return 0, 0

if not field_defaults_only:
removed = sorted(old_exports - new_exports)

# Check deprecation runway policy (exports)
for name in removed:
total_breaks += 1 # every removal is a structural break
errors = _deprecation_schedule_errors(
feature=name,
metadata=(
deprecated.metadata.get(name, DeprecationMetadata())
if name in deprecated.top_level
else None
),
current_version=current_version,
removed = sorted(old_exports - new_exports)

# Check deprecation runway policy (exports)
for name in removed:
total_breaks += 1 # every removal is a structural break
errors = _deprecation_schedule_errors(
feature=name,
metadata=(
deprecated.metadata.get(name, DeprecationMetadata())
if name in deprecated.top_level
else None
),
current_version=current_version,
)
if not errors:
print(
f"::notice title={title}::Removed previously-deprecated symbol "
f"'{name}' from {pkg}.__all__ after its scheduled removal version"
)
if not errors:
print(
f"::notice title={title}::Removed previously-deprecated symbol "
f"'{name}' from {pkg}.__all__ after its scheduled removal version"
)
continue
continue

for error in errors:
print(f"::error title={title}::{error}")
removal_policy_errors += len(errors)
for error in errors:
print(f"::error title={title}::{error}")
removal_policy_errors += len(errors)

common = sorted(old_exports & new_exports)
pairs: list[tuple[object, object]] = []
for name in common:
try:
pairs.append((old_mod[name], new_mod[name]))
except Exception as e:
if emit_diagnostics:
print(f"::warning title={title}::Unable to resolve symbol {name}: {e}")
print(f"::warning title={title}::Unable to resolve symbol {name}: {e}")

breakages, member_policy_errors = _collect_breakages_pairs(
pairs,
Expand All @@ -1286,8 +1162,6 @@ def _compute_breakages(
title=title,
package=cfg.package,
field_default_changes=field_default_changes,
field_defaults_only=field_defaults_only,
emit_diagnostics=emit_diagnostics,
)
total_breaks += len(breakages)
removal_policy_errors += member_policy_errors
Expand Down Expand Up @@ -1357,8 +1231,6 @@ def main() -> int:
import griffe

field_default_changes: list[FieldDefaultChange] = []
field_default_changes_since_base: list[FieldDefaultChange] | None = []
base_ref = _get_base_ref()
for cfg in PACKAGES:
print(f"\n{'=' * 60}")
print(f"Checking {cfg.distribution} ({cfg.package})")
Expand All @@ -1369,24 +1241,8 @@ def main() -> int:
cfg,
field_default_changes=field_default_changes,
)
if base_ref and field_default_changes_since_base is not None:
changes_since_base = _collect_field_default_changes_since_ref(
griffe,
repo_root,
base_ref,
cfg,
)
if changes_since_base is None:
field_default_changes_since_base = None
else:
field_default_changes_since_base.extend(changes_since_base)

_write_field_default_change_report(
field_default_changes,
field_default_changes_since_base=(
field_default_changes_since_base if base_ref else None
),
)
_write_field_default_change_report(field_default_changes)
return rc


Expand Down
Loading
Loading