Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ class ConnectionBody(StrictBaseModel):
login: str | None = Field(default=None)
schema_: str | None = Field(None, alias="schema")
port: int | None = Field(default=None)

@field_validator("port")
@classmethod
def validate_port(cls, v: int | None) -> int | None:
"""Validate that port is within the valid TCP/UDP range (0-65535)."""
if v is not None and not (0 <= v <= 65535):
raise ValueError(f"Port must be between 0 and 65535, got {v}")
return v

password: str | None = Field(default=None)
extra: str | None = Field(default=None)
team_name: str | None = Field(max_length=50, default=None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ class ConnectionTestConnectionResponse(BaseModel):
schema_: str | None = Field(None, alias="schema")
port: int | None = None
extra: str | None = None

@field_validator("port")
@classmethod
def validate_port(cls, v: int | None) -> int | None:
"""Validate that port is within the valid TCP/UDP range (0-65535)."""
if v is not None and not (0 <= v <= 65535):
raise ValueError(f"Port must be between 0 and 65535, got {v}")
return v
5 changes: 5 additions & 0 deletions airflow-core/src/airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ def connections_add(args):
)
args.conn_type = "generic"

# Validate port if provided
if args.conn_port is not None:
if not isinstance(args.conn_port, int) or not (0 <= args.conn_port <= 65535):
raise SystemExit(f"Port must be between 0 and 65535, got {args.conn_port}")

if has_uri or has_json:
invalid_args = []
if has_uri and not _valid_uri(args.conn_uri):
Expand Down
10 changes: 10 additions & 0 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
self.login = login
self.password = password
self.schema = schema
self._validate_port(port, conn_id)
self.port = port
self.extra = extra

Expand All @@ -202,6 +203,13 @@ def __init__(
mask_secret(quote(self.password))
self.team_name = team_name

@staticmethod
def _validate_port(port: int | None, conn_id: str | None = None) -> None:
"""Validate that port is within the valid TCP/UDP range (0-65535)."""
if port is not None and not (0 <= port <= 65535):
conn_msg = f" for connection {conn_id!r}" if conn_id else ""
raise ValueError(f"Port must be between 0 and 65535{conn_msg}, got {port}")

@staticmethod
def _validate_extra(extra, conn_id) -> None:
"""Verify that ``extra`` is a JSON-encoded Python dict."""
Expand Down Expand Up @@ -257,6 +265,7 @@ def _parse_from_uri(self, uri: str):
self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username
self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password
self.port = uri_parts.port
self._validate_port(self.port, self.conn_id)
if uri_parts.query:
query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
if self.EXTRA_KEY in query:
Expand Down Expand Up @@ -591,6 +600,7 @@ def from_json(cls, value, conn_id=None) -> Connection:
kwargs["port"] = int(port)
except ValueError:
raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.")
cls._validate_port(kwargs.get("port"), conn_id)
return Connection(conn_id=conn_id, **kwargs)

def as_json(self) -> str:
Expand Down
9 changes: 9 additions & 0 deletions airflow-core/src/airflow/models/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ class ConnectionTestRequest(Base, FernetFieldsMixin):
login: Mapped[str | None] = mapped_column(Text, nullable=True)
schema: Mapped[str | None] = mapped_column("schema", String(500), nullable=True)
port: Mapped[int | None] = mapped_column(Integer, nullable=True)

@staticmethod
def _validate_port(port: int | None, connection_id: str | None = None) -> None:
"""Validate that port is within the valid TCP/UDP range (0-65535)."""
if port is not None and not (0 <= port <= 65535):
conn_msg = f" for connection {connection_id!r}" if connection_id else ""
raise ValueError(f"Port must be between 0 and 65535{conn_msg}, got {port}")

commit_on_success: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default="0"
)
Expand Down Expand Up @@ -152,6 +160,7 @@ def __init__(
self.login = login
self.password = password
self.schema = schema
self._validate_port(port, connection_id)
self.port = port
self.extra = extra
self.commit_on_success = commit_on_success
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2062,3 +2062,30 @@ def test_post_should_fail_with_non_json_object_as_extra(
"method": "POST",
},
)


class TestConnectionBodyPortValidation:
"""Test port validation in ConnectionBody model."""

@pytest.mark.parametrize(
"port",
[0, 1, 80, 443, 3306, 5432, 8080, 65535],
)
def test_valid_ports(self, port):
"""Test that valid port numbers (0-65535) are accepted."""
body = ConnectionBody(connection_id="test", conn_type="test", port=port)
assert body.port == port

def test_none_port_allowed(self):
"""Test that None port is allowed (optional field)."""
body = ConnectionBody(connection_id="test", conn_type="test", port=None)
assert body.port is None

@pytest.mark.parametrize(
"port",
[-1, 65536, 99999, 99999999],
)
def test_invalid_ports(self, port):
"""Test that invalid port numbers are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
ConnectionBody(connection_id="test", conn_type="test", port=port)
51 changes: 51 additions & 0 deletions airflow-core/tests/unit/models/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,54 @@ def test_get_conn_id_to_team_name_mapping(self, testing_team: Team, session: Ses
"test_conn2": None,
}
clear_db_connections()

def test_port_validation_valid_ports(self):
"""Test that valid port numbers (0-65535) are accepted."""
for port in [0, 1, 80, 443, 3306, 5432, 8080, 65535]:
conn = Connection(conn_id=f"test_{port}", conn_type="test", port=port)
assert conn.port == port

def test_port_validation_invalid_negative_port(self):
"""Test that negative port numbers are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(conn_id="test_neg", conn_type="test", port=-1)

def test_port_validation_invalid_port_too_large(self):
"""Test that port numbers > 65535 are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(conn_id="test_large", conn_type="test", port=65536)

def test_port_validation_invalid_port_very_large(self):
"""Test that very large port numbers are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(conn_id="test_very_large", conn_type="test", port=99999999)

def test_port_validation_none_allowed(self):
"""Test that None port is allowed (optional field)."""
conn = Connection(conn_id="test_none", conn_type="test", port=None)
assert conn.port is None

def test_port_validation_from_uri_valid_port(self):
"""Test that valid ports from URI are accepted."""
conn = Connection(uri="postgres://user:pass@host:5432/db", conn_id="test_uri")
assert conn.port == 5432

def test_port_validation_from_uri_invalid_port(self):
"""Test that invalid ports from URI are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(uri="postgres://user:pass@host:99999/db", conn_id="test_uri_invalid")

def test_port_validation_from_json_valid_port(self):
"""Test that valid ports from JSON are accepted."""
conn = Connection.from_json('{"conn_type": "postgres", "port": "5432"}', conn_id="test_json")
assert conn.port == 5432

def test_port_validation_from_json_invalid_port(self):
"""Test that invalid ports from JSON are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection.from_json('{"conn_type": "postgres", "port": "99999"}', conn_id="test_json_invalid")

def test_port_validation_from_json_negative_port(self):
"""Test that negative ports from JSON are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection.from_json('{"conn_type": "postgres", "port": "-1"}', conn_id="test_json_neg")
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ class Connection:

EXTRA_KEY = "__extra__"

def __attrs_post_init__(self) -> None:
"""Validate port after initialization."""
if self.port is not None and not (0 <= self.port <= 65535):
raise ValueError(f"Port must be between 0 and 65535, got {self.port}")

@overload
def __init__(self, *, conn_id: str, uri: str) -> None: ...

Expand Down
55 changes: 55 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,58 @@ def test_from_uri_roundtrip(self):
original_extra = json.loads(conn_from_original.extra)
roundtrip_extra = json.loads(conn_from_roundtrip.extra)
assert original_extra == roundtrip_extra


class TestConnectionPortValidation:
"""Test port validation in Connection model."""

def test_port_validation_valid_ports(self):
"""Test that valid port numbers (0-65535) are accepted."""
for port in [0, 1, 80, 443, 3306, 5432, 8080, 65535]:
conn = Connection(conn_id=f"test_{port}", conn_type="test", port=port)
assert conn.port == port

def test_port_validation_none_allowed(self):
"""Test that None port is allowed (optional field)."""
conn = Connection(conn_id="test_none", conn_type="test", port=None)
assert conn.port is None

def test_port_validation_invalid_negative_port(self):
"""Test that negative port numbers are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(conn_id="test_neg", conn_type="test", port=-1)

def test_port_validation_invalid_port_too_large(self):
"""Test that port numbers > 65535 are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(conn_id="test_large", conn_type="test", port=65536)

def test_port_validation_invalid_port_very_large(self):
"""Test that very large port numbers are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection(conn_id="test_very_large", conn_type="test", port=99999999)

def test_port_validation_from_uri_valid_port(self):
"""Test that valid ports from URI are accepted."""
conn = Connection.from_uri("postgres://user:***@host:5432/db", conn_id="test_uri")
assert conn.port == 5432

def test_port_validation_from_uri_invalid_port(self):
"""Test that invalid ports from URI are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection.from_uri("postgres://user:***@host:99999/db", conn_id="test_uri_invalid")

def test_port_validation_from_json_valid_port(self):
"""Test that valid ports from JSON are accepted."""
conn = Connection.from_json('{"conn_type": "postgres", "port": "5432"}', conn_id="test_json")
assert conn.port == 5432

def test_port_validation_from_json_invalid_port(self):
"""Test that invalid ports from JSON are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection.from_json('{"conn_type": "postgres", "port": "99999"}', conn_id="test_json_invalid")

def test_port_validation_from_json_negative_port(self):
"""Test that negative ports from JSON are rejected."""
with pytest.raises(ValueError, match="Port must be between 0 and 65535"):
Connection.from_json('{"conn_type": "postgres", "port": "-1"}', conn_id="test_json_neg")
Loading