diff --git a/decart/models.py b/decart/models.py index a63cbe9..7a3cbe7 100644 --- a/decart/models.py +++ b/decart/models.py @@ -113,6 +113,7 @@ class ImageToImageInput(DecartBaseModel): max_length=1000, ) data: FileInput + reference_image: Optional[FileInput] = None seed: Optional[int] = None resolution: Optional[str] = None enhance_prompt: Optional[bool] = None diff --git a/decart/process/request.py b/decart/process/request.py index 9f6b132..49ff7d2 100644 --- a/decart/process/request.py +++ b/decart/process/request.py @@ -89,7 +89,7 @@ async def send_request( for key, value in inputs.items(): if value is not None: - if key in ("data", "start", "end"): + if key in ("data", "start", "end", "reference_image"): content, content_type = await file_input_to_bytes(value, session) form_data.add_field(key, content, content_type=content_type) else: diff --git a/tests/test_process.py b/tests/test_process.py index 736bd7e..007c5b7 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -41,6 +41,38 @@ async def test_process_image_to_image() -> None: assert result == b"fake image data" +@pytest.mark.asyncio +async def test_process_image_to_image_with_reference_image() -> None: + """Test image-to-image with optional reference_image.""" + client = DecartClient(api_key="test-key") + + with patch("aiohttp.ClientSession") as mock_session_cls: + mock_response = MagicMock() + mock_response.ok = True + mock_response.read = AsyncMock(return_value=b"fake image data") + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock() + mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response) + mock_session.post.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_session_cls.return_value = mock_session + + result = await client.process( + { + "model": models.image("lucy-pro-i2i"), + "prompt": "Add the object from the reference image", + "data": b"fake input image", + "reference_image": b"fake reference image", + "enhance_prompt": False, + } + ) + + assert result == b"fake image data" + + @pytest.mark.asyncio async def test_process_rejects_video_models() -> None: """Test that process() rejects video models with helpful error message."""