diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/_shared_macros.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/_shared_macros.j2 index b055b9ca31..ced2529f18 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/_shared_macros.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/_shared_macros.j2 @@ -171,16 +171,23 @@ def _get_http_options(): timeout, transcoded_request, body=None): - + uri = transcoded_request['uri'] method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' + # Build query string manually to avoid URL-encoding special characters like '$'. + # The `requests` library encodes '$' as '%24' when using the `params` argument, + # which causes API errors for parameters like '$alt'. See: + # https://github.com/googleapis/gapic-generator-python/issues/2514 + _query_params = rest_helpers.flatten_query_params(query_params, strict=True) + _request_url = "{host}{uri}".format(host=host, uri=uri) + if _query_params: + _request_url = "{}?{}".format(_request_url, urlencode(_query_params, safe="$")) response = {{ await_prefix }}getattr(session, method)( - "{host}{uri}".format(host=host, uri=uri), + _request_url, timeout=timeout, headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), {% if body_spec %} data=body, {% endif %} diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 index a55ced7c08..f5f57b0fe9 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 @@ -33,11 +33,13 @@ from google.iam.v1 import policy_pb2 # type: ignore from google.cloud.location import locations_pb2 # type: ignore {% endif %} -from requests import __version__ as requests_version import dataclasses from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlencode import warnings +from requests import __version__ as requests_version + {{ shared_macros.operations_mixin_imports(api, service, opts) }} from .rest_base import _Base{{ service.name }}RestTransport diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 index d7f8bb7e68..abcfe570a4 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 @@ -997,3 +997,346 @@ def test_test_iam_permissions_from_dict(): {% endif %} {% endif %} + +{# REST transport tests for mixin URL query params encoding #} +{% if 'rest' in opts.transport %} + +{# Operations mixin REST URL encoding tests #} +{% if api.has_operations_mixin %} + +{% if "ListOperations" in api.mixin_api_methods %} +def test_list_operations_rest_url_query_params_encoding(): + # Verify that special characters like '$' are correctly preserved (not URL-encoded) + # when building the URL query string for mixin methods. + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.list_operations.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "GetOperation" in api.mixin_api_methods %} +def test_get_operation_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.get_operation.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations/op1', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "DeleteOperation" in api.mixin_api_methods %} +def test_delete_operation_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.delete_operation.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.delete.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations/op1', + 'method': 'delete', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.delete.called + call_url = mock_session.delete.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "CancelOperation" in api.mixin_api_methods %} +def test_cancel_operation_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.cancel_operation.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations/op1:cancel', + 'method': 'post', + 'body': {}, + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + body={}, + ) + + assert mock_session.post.called + call_url = mock_session.post.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% endif %} {# operations_mixin #} + +{# Location mixin REST URL encoding tests #} +{% if api.has_location_mixin %} + +{% if "ListLocations" in api.mixin_api_methods %} +def test_list_locations_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.list_locations.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/projects/p1/locations', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "GetLocation" in api.mixin_api_methods %} +def test_get_location_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.get_location.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/projects/p1/locations/l1', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% endif %} {# location_mixin #} + +{# IAM mixin REST URL encoding tests #} +{% if api.has_iam_mixin or opts.add_iam_methods %} + +{% if "SetIamPolicy" in api.mixin_api_methods or opts.add_iam_methods %} +def test_set_iam_policy_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.set_iam_policy.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/resource:setIamPolicy', + 'method': 'post', + 'body': {}, + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + body={}, + ) + + assert mock_session.post.called + call_url = mock_session.post.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "GetIamPolicy" in api.mixin_api_methods or opts.add_iam_methods %} +def test_get_iam_policy_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.get_iam_policy.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/resource:getIamPolicy', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "TestIamPermissions" in api.mixin_api_methods or opts.add_iam_methods %} +def test_test_iam_permissions_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.test_iam_permissions.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/resource:testIamPermissions', + 'method': 'post', + 'body': {}, + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + body={}, + ) + + assert mock_session.post.called + call_url = mock_session.post.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% endif %} {# iam_mixin #} + +{% endif %} {# rest transport #} diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index ac385e285d..9569470c47 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -18,6 +18,8 @@ import grpc from grpc.experimental import aio {% if "rest" in opts.transport %} from collections.abc import Iterable +import urllib.parse + from google.protobuf import json_format import json {% endif %} @@ -45,6 +47,7 @@ from google.api_core import client_options from google.api_core import exceptions as core_exceptions from google.api_core import grpc_helpers from google.api_core import path_template +from google.api_core import rest_helpers from google.api_core import retry as retries {% if service.has_lro %} from google.api_core import future @@ -1057,6 +1060,53 @@ def test_{{ method_name }}_raw_page_lro(): {% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.client_method_name|snake_case %}{% if method.http_options %} {# TODO(kbandes): remove this if condition when client streaming are supported. #} {% if not method.client_streaming %} + +def test_{{ method_name }}_rest_url_query_params_encoding(): + # Verify that special characters like '$' are correctly preserved (not URL-encoded) + # when building the URL query string. This tests the urlencode call with safe="$". + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials) + method_class = transport.{{ method.transport_safe_name|snake_case }}.__class__ + # Get the _get_response static method from the method class + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + mock_session.post.return_value = mock_response + mock_session.put.return_value = mock_response + mock_session.patch.return_value = mock_response + mock_session.delete.return_value = mock_response + + # Mock flatten_query_params to return query params that include '$' character + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test', + 'method': '{{ method.http_options[0].method }}', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + # Verify the session method was called with the URL containing query params + session_method = getattr(mock_session, '{{ method.http_options[0].method }}') + assert session_method.called + + # The URL should contain '$alt' (not '%24alt') because safe="$" is used + call_url = session_method.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url + + @pytest.mark.parametrize("request_type", [ {{ method.input.ident }}, dict, @@ -1451,8 +1501,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide ('$alt', 'json;enum-encoding=int') {% endif %} ] - actual_params = req.call_args.kwargs['params'] - assert expected_params == actual_params + # Verify query params are correctly included in the URL + # Session.request is called as request(method, url, ...), so url is args[1] + actual_url = req.call_args.args[1] + parsed_url = urllib.parse.urlparse(actual_url) + actual_params = urllib.parse.parse_qsl(parsed_url.query, keep_blank_values=True) + assert set(expected_params).issubset(set(actual_params)) def test_{{ method_name }}_rest_unset_required_fields(): @@ -1461,9 +1515,9 @@ def test_{{ method_name }}_rest_unset_required_fields(): unset_fields = transport.{{ method.transport_safe_name|snake_case }}._get_unset_required_fields({}) assert set(unset_fields) == (set(({% for param in method.query_params|sort %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{param.name|camel_case}}", {% endfor %}))) - {% endif %}{# required_fields #} + {% if not method.client_streaming %} @pytest.mark.parametrize("null_interceptor", [True, False]) def test_{{ method_name }}_rest_interceptors(null_interceptor): diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 index 6db274e82f..31c3718584 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 @@ -164,16 +164,23 @@ def _get_http_options(): timeout, transcoded_request, body=None): - + uri = transcoded_request['uri'] method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' + # Build query string manually to avoid URL-encoding special characters like '$'. + # The `requests` library encodes '$' as '%24' when using the `params` argument, + # which causes API errors for parameters like '$alt'. See: + # https://github.com/googleapis/gapic-generator-python/issues/2514 + _query_params = rest_helpers.flatten_query_params(query_params, strict=True) + _request_url = "{host}{uri}".format(host=host, uri=uri) + if _query_params: + _request_url = "{}?{}".format(_request_url, urlencode(_query_params, safe="$")) response = {{ await_prefix }}getattr(session, method)( - "{host}{uri}".format(host=host, uri=uri), + _request_url, timeout=timeout, headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), {% if body_spec %} data=body, {% endif %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 95efafb389..2a55f8fd87 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -27,11 +27,13 @@ from google.iam.v1 import policy_pb2 # type: ignore from google.cloud.location import locations_pb2 # type: ignore {% endif %} -from requests import __version__ as requests_version import dataclasses from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlencode import warnings +from requests import __version__ as requests_version + {{ shared_macros.operations_mixin_imports(api, service, opts) }} from .rest_base import _Base{{ service.name }}RestTransport diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 index 1d6ec87374..2c758c2b8d 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 @@ -48,9 +48,10 @@ from google.iam.v1 import policy_pb2 # type: ignore from google.cloud.location import locations_pb2 # type: ignore {% endif %} -import json # type: ignore import dataclasses +import json # type: ignore from typing import Any, Dict, List, Callable, Tuple, Optional, Sequence, Union +from urllib.parse import urlencode {{ shared_macros.operations_mixin_imports(api, service, opts) }} diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 index 169807a961..db320f22fe 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/_test_mixins.py.j2 @@ -1829,3 +1829,346 @@ async def test_test_iam_permissions_from_dict_async(): call.assert_called() {% endif %} {% endif %} + +{# REST transport tests for mixin URL query params encoding #} +{% if 'rest' in opts.transport %} + +{# Operations mixin REST URL encoding tests #} +{% if api.has_operations_mixin %} + +{% if "ListOperations" in api.mixin_api_methods %} +def test_list_operations_rest_url_query_params_encoding(): + # Verify that special characters like '$' are correctly preserved (not URL-encoded) + # when building the URL query string for mixin methods. + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.list_operations.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "GetOperation" in api.mixin_api_methods %} +def test_get_operation_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.get_operation.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations/op1', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "DeleteOperation" in api.mixin_api_methods %} +def test_delete_operation_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.delete_operation.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.delete.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations/op1', + 'method': 'delete', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.delete.called + call_url = mock_session.delete.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "CancelOperation" in api.mixin_api_methods %} +def test_cancel_operation_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.cancel_operation.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test/operations/op1:cancel', + 'method': 'post', + 'body': {}, + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + body={}, + ) + + assert mock_session.post.called + call_url = mock_session.post.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% endif %} {# operations_mixin #} + +{# Location mixin REST URL encoding tests #} +{% if api.has_location_mixin %} + +{% if "ListLocations" in api.mixin_api_methods %} +def test_list_locations_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.list_locations.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/projects/p1/locations', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "GetLocation" in api.mixin_api_methods %} +def test_get_location_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.get_location.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/projects/p1/locations/l1', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% endif %} {# location_mixin #} + +{# IAM mixin REST URL encoding tests #} +{% if api.has_iam_mixin or opts.add_iam_methods %} + +{% if "SetIamPolicy" in api.mixin_api_methods or opts.add_iam_methods %} +def test_set_iam_policy_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.set_iam_policy.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/resource:setIamPolicy', + 'method': 'post', + 'body': {}, + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + body={}, + ) + + assert mock_session.post.called + call_url = mock_session.post.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "GetIamPolicy" in api.mixin_api_methods or opts.add_iam_methods %} +def test_get_iam_policy_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.get_iam_policy.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/resource:getIamPolicy', + 'method': 'get', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + assert mock_session.get.called + call_url = mock_session.get.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% if "TestIamPermissions" in api.mixin_api_methods or opts.add_iam_methods %} +def test_test_iam_permissions_rest_url_query_params_encoding(): + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials()) + method_class = transport.test_iam_permissions.__class__ + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/resource:testIamPermissions', + 'method': 'post', + 'body': {}, + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + body={}, + ) + + assert mock_session.post.called + call_url = mock_session.post.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url +{% endif %} + +{% endif %} {# iam_mixin #} + +{% endif %} {# rest transport #} diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index c0e92cd9d6..73aeb74a50 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -21,6 +21,8 @@ import grpc from grpc.experimental import aio {% if "rest" in opts.transport %} from collections.abc import Iterable, AsyncIterable +import urllib.parse + from google.protobuf import json_format {% endif %} import json @@ -72,6 +74,7 @@ from google.api_core import exceptions as core_exceptions from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import rest_helpers from google.api_core import retry as retries {% if service.has_lro or service.has_extended_lro %} from google.api_core import future diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 index f15326d670..b738ff3176 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 @@ -1004,6 +1004,53 @@ def test_{{ method_name }}_raw_page_lro(): {% with method_name = method.client_method_name|snake_case + "_unary" if method.extended_lro and not full_extended_lro else method.client_method_name|snake_case, method_output = method.extended_lro.operation_type if method.extended_lro and not full_extended_lro else method.output %}{% if method.http_options %} {# TODO(kbandes): remove this if condition when lro and client streaming are supported. #} {% if not method.client_streaming %} + +def test_{{ method_name }}_rest_url_query_params_encoding(): + # Verify that special characters like '$' are correctly preserved (not URL-encoded) + # when building the URL query string. This tests the urlencode call with safe="$". + transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials) + method_class = transport.{{ method.transport_safe_name|snake_case }}.__class__ + # Get the _get_response static method from the method class + get_response_fn = method_class._get_response + + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + mock_session.post.return_value = mock_response + mock_session.put.return_value = mock_response + mock_session.patch.return_value = mock_response + mock_session.delete.return_value = mock_response + + # Mock flatten_query_params to return query params that include '$' character + with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten: + mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')] + + transcoded_request = { + 'uri': '/v1/test', + 'method': '{{ method.http_options[0].method }}', + } + + get_response_fn( + host='https://example.com', + metadata=[], + query_params={}, + session=mock_session, + timeout=None, + transcoded_request=transcoded_request, + ) + + # Verify the session method was called with the URL containing query params + session_method = getattr(mock_session, '{{ method.http_options[0].method }}') + assert session_method.called + + # The URL should contain '$alt' (not '%24alt') because safe="$" is used + call_url = session_method.call_args.args[0] + assert '$alt=json' in call_url + assert '%24alt' not in call_url + assert 'foo=bar' in call_url + + def test_{{ method_name }}_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call @@ -1200,8 +1247,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide ('$alt', 'json;enum-encoding=int') {% endif %} ] - actual_params = req.call_args.kwargs['params'] - assert expected_params == actual_params + # Verify query params are correctly included in the URL + # Session.request is called as request(method, url, ...), so url is args[1] + actual_url = req.call_args.args[1] + parsed_url = urllib.parse.urlparse(actual_url) + actual_params = urllib.parse.parse_qsl(parsed_url.query, keep_blank_values=True) + assert set(expected_params).issubset(set(actual_params)) def test_{{ method_name }}_rest_unset_required_fields():