diff --git a/endpoint/exceptions.py b/endpoint/exceptions.py index f95139d7..0286b476 100644 --- a/endpoint/exceptions.py +++ b/endpoint/exceptions.py @@ -4,23 +4,32 @@ import json import werkzeug.exceptions -import werkzeug.wrappers +from odoo.http import Response -class RequestValidationError(werkzeug.exceptions.BadRequest): - """Bad request raised when the body fails JSON Schema validation. - Emits ``{"detail": [{"loc", "msg", "type"}, ...]}`` (FastAPI-style) - instead of the generic werkzeug HTML body. +class RequestValidationError(werkzeug.exceptions.HTTPException): + """Raised when the body fails JSON/XML schema validation. + + Emits ``{"detail": [{"loc", "msg", "type"}, ...]}`` (FastAPI-style) with a + 400 status instead of the generic werkzeug HTML body. + + ``code`` is left as ``None`` on purpose: this makes Odoo treat the + exception as a ready-made response and run the normal ``post_dispatch`` + hook on it, so route-managed headers (CORS, CSP, ...) are preserved. A real + ``code`` (e.g. via ``BadRequest``) would route through ``handle_error`` + instead and drop those headers. """ + code = None + def __init__(self, detail): super().__init__() self.detail = detail def get_response(self, environ=None, scope=None): - return werkzeug.wrappers.Response( + return Response( json.dumps({"detail": self.detail}), - status=self.code, + status=400, mimetype="application/json", ) diff --git a/endpoint/tests/test_endpoint.py b/endpoint/tests/test_endpoint.py index 6c41a80f..49030bec 100644 --- a/endpoint/tests/test_endpoint.py +++ b/endpoint/tests/test_endpoint.py @@ -140,6 +140,7 @@ def test_routing(self): "methods": ["GET"], "routes": ["/demo/one"], "type": "http", + "cors": "*", "csrf": False, "readonly": False, }, @@ -161,6 +162,7 @@ def test_routing(self): "methods": ["POST"], "routes": ["/new/one"], "type": "http", + "cors": "*", "csrf": False, "readonly": False, }, @@ -176,6 +178,7 @@ def test_routing(self): "methods": ["POST"], "routes": ["/foo/new/one"], "type": "http", + "cors": "*", "csrf": False, "readonly": False, }, diff --git a/endpoint/tests/test_endpoint_content_schema_validation.py b/endpoint/tests/test_endpoint_content_schema_validation.py index 711ba87b..c4abca8d 100644 --- a/endpoint/tests/test_endpoint_content_schema_validation.py +++ b/endpoint/tests/test_endpoint_content_schema_validation.py @@ -152,6 +152,21 @@ def test_json_invalid_body(self): self.assertEqual(payload["detail"][0]["loc"], ["body", "data"]) self.assertEqual(payload["detail"][0]["type"], "type") + @mute_logger("endpoint.endpoint") + def test_validation_error_preserves_cors_headers(self): + response = self.url_open( + "/demo/schema", + data=json.dumps({"data": "not-an-array"}), + headers={ + "Content-Type": "application/json", + "Origin": "https://editor.swagger.io", + }, + ) + self.assertEqual(response.status_code, 400) + # Route-managed CORS headers (default cors="*") must survive on the + # validation-error response, not only on a successful one. + self.assertEqual(response.headers.get("Access-Control-Allow-Origin"), "*") + @mute_logger("endpoint.endpoint") def test_json_malformed_body(self): response = self.url_open( diff --git a/endpoint/views/endpoint_view.xml b/endpoint/views/endpoint_view.xml index d51ddfba..cc50404e 100644 --- a/endpoint/views/endpoint_view.xml +++ b/endpoint/views/endpoint_view.xml @@ -67,6 +67,7 @@ required="request_method in ('POST', 'PUT')" invisible="request_method not in ('POST', 'PUT')" /> + diff --git a/endpoint_route_handler/models/endpoint_route_handler.py b/endpoint_route_handler/models/endpoint_route_handler.py index bba74a54..56e8c07b 100644 --- a/endpoint_route_handler/models/endpoint_route_handler.py +++ b/endpoint_route_handler/models/endpoint_route_handler.py @@ -43,6 +43,7 @@ class EndpointRouteHandler(models.AbstractModel): endpoint_hash = fields.Char( compute="_compute_endpoint_hash", help="Identify the route with its main params" ) + cors = fields.Char(help="Comma-separated list of allowed origins", default="*") csrf = fields.Boolean(default=False) readonly = fields.Boolean(default=False) @@ -248,6 +249,7 @@ def _get_routing_info(self): auth=self.auth_type, methods=[self.request_method], routes=[route], + cors=self.cors, csrf=self.csrf, readonly=self.readonly, )