Skip to content
9 changes: 7 additions & 2 deletions src/google/adk/tools/load_web_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

_ALLOWED_URL_SCHEMES = frozenset({'http', 'https'})
_DEFAULT_PORT_BY_SCHEME = {'http': 80, 'https': 443}
# Default timeout in seconds for HTTP requests.
_DEFAULT_TIMEOUT_SECONDS = 10
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address


Expand Down Expand Up @@ -230,6 +232,7 @@ def _fetch_direct_response(
url,
allow_redirects=False,
proxies={'http': None, 'https': None},
timeout=_DEFAULT_TIMEOUT_SECONDS,
)
except requests.RequestException as exc:
last_error = exc
Expand All @@ -253,7 +256,9 @@ def _fetch_response(url: str) -> requests.Response:
# localhost-style names can be rejected locally without breaking proxy use.
if parsed_ip_literal is not None and _is_blocked_address(parsed_ip_literal):
raise ValueError(f'Blocked host: {target.hostname}')
return requests.get(url, allow_redirects=False)
return requests.get(
url, allow_redirects=False, timeout=_DEFAULT_TIMEOUT_SECONDS
)

if parsed_ip_literal is not None:
if _is_blocked_address(parsed_ip_literal):
Expand Down Expand Up @@ -285,7 +290,7 @@ def load_web_page(url: str) -> str:

try:
response = _fetch_response(url)
except ValueError:
except (ValueError, requests.RequestException):
return _failed_to_fetch_message(url)

# Set allow_redirects=False to prevent SSRF attacks via redirection.
Expand Down
122 changes: 121 additions & 1 deletion tests/unittests/tools/test_load_web_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def test_load_web_page_uses_proxy_for_unresolved_public_hostnames(monkeypatch):

assert result == 'This page has enough words to keep.'
mock_get.assert_called_once_with(
'https://does-not-resolve.invalid', allow_redirects=False
'https://does-not-resolve.invalid',
allow_redirects=False,
timeout=load_web_page_module._DEFAULT_TIMEOUT_SECONDS,
)
mock_send.assert_not_called()

Expand Down Expand Up @@ -272,3 +274,121 @@ def _send(
'https://93.184.216.35',
]
mock_get.assert_not_called()


def test_load_web_page_passes_timeout_to_pinned_session(monkeypatch):
_clear_proxy_env(monkeypatch)
monkeypatch.setattr(
load_web_page_module.socket,
'getaddrinfo',
mock.Mock(
return_value=[(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
'',
('93.184.216.34', 0),
)]
),
)
monkeypatch.setattr(
'bs4.BeautifulSoup',
mock.Mock(
return_value=mock.Mock(
get_text=mock.Mock(
return_value='This page has enough words to keep.'
)
)
),
)
captured_timeouts: list[object] = []

def _send(
self,
request,
stream=False,
timeout=None,
verify=True,
cert=None,
proxies=None,
):
del self, request, stream, verify, cert, proxies
captured_timeouts.append(timeout)
return _create_response(
'<html><body><p>This page has enough words to keep.</p></body></html>'
)

monkeypatch.setattr(load_web_page_module.HTTPAdapter, 'send', _send)

load_web_page('https://example.com')

assert captured_timeouts == [load_web_page_module._DEFAULT_TIMEOUT_SECONDS]


def test_load_web_page_passes_timeout_to_proxied_get(monkeypatch):
monkeypatch.setenv('HTTPS_PROXY', 'http://proxy.example.test:8080')
monkeypatch.setenv('NO_PROXY', '')
monkeypatch.setattr(
load_web_page_module.socket,
'getaddrinfo',
mock.Mock(side_effect=AssertionError('unexpected local DNS lookup')),
)
monkeypatch.setattr(
'bs4.BeautifulSoup',
mock.Mock(
return_value=mock.Mock(
get_text=mock.Mock(
return_value='This page has enough words to keep.'
)
)
),
)
mock_get = mock.Mock(
return_value=_create_response(
'<html><body><p>This page has enough words to keep.</p></body></html>'
)
)
monkeypatch.setattr(load_web_page_module.requests, 'get', mock_get)

load_web_page('https://does-not-resolve.invalid')

mock_get.assert_called_once_with(
'https://does-not-resolve.invalid',
allow_redirects=False,
timeout=load_web_page_module._DEFAULT_TIMEOUT_SECONDS,
)


def test_load_web_page_returns_failure_on_timeout(monkeypatch):
_clear_proxy_env(monkeypatch)
monkeypatch.setattr(
load_web_page_module.socket,
'getaddrinfo',
mock.Mock(
return_value=[(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
'',
('93.184.216.34', 0),
)]
),
)

def _send(
self,
request,
stream=False,
timeout=None,
verify=True,
cert=None,
proxies=None,
):
del self, request, stream, timeout, verify, cert, proxies
raise requests.exceptions.Timeout('boom')

monkeypatch.setattr(load_web_page_module.HTTPAdapter, 'send', _send)

result = load_web_page('https://example.com')

assert result == 'Failed to fetch url: https://example.com'