diff --git a/lightx2v/server/api/server.py b/lightx2v/server/api/server.py
index fd5071141..ff1ac0ba5 100644
--- a/lightx2v/server/api/server.py
+++ b/lightx2v/server/api/server.py
@@ -129,7 +129,11 @@ async def _process_single_task(self, task_info: Any):
             result = await generation_service.generate_with_stop_event(message, task_info.stop_event)
 
             if result:
-                task_manager.complete_task(task_id, result.save_result_path)
+                task_manager.complete_task(
+                    task_id,
+                    save_result_path=result.save_result_path or None,
+                    result_png=getattr(result, "result_png", None),
+                )
                 logger.info(f"Task {task_id} completed successfully")
             else:
                 if task_info.stop_event.is_set():
diff --git a/lightx2v/server/api/tasks/image.py b/lightx2v/server/api/tasks/image.py
index d98f64225..2086eb9af 100644
--- a/lightx2v/server/api/tasks/image.py
+++ b/lightx2v/server/api/tasks/image.py
@@ -4,12 +4,12 @@ from pathlib import Path
 
 from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
+from fastapi.responses import Response
 from loguru import logger
 
 from ...schema import ImageTaskRequest, TaskResponse
 from ...task_manager import TaskStatus, task_manager
 from ..deps import get_services, validate_url_async
-from .common import _stream_file_response
 
 router = APIRouter()
 
 
@@ -20,9 +20,6 @@ def _write_file_sync(file_path: Path, content: bytes) -> None:
 
 
 async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_interval_seconds: float):
-    services = get_services()
-    assert services.file_service is not None, "File service is not initialized"
-
     start_time = time.monotonic()
     while True:
         task_status = task_manager.get_task_status(task_id)
@@ -31,14 +28,14 @@ async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_
 
         status = task_status.get("status")
         if status == TaskStatus.COMPLETED.value:
-            save_result_path = task_status.get("save_result_path")
-            if not save_result_path:
-                raise HTTPException(status_code=500, detail=f"Task completed but no result path found: {task_id}")
-
-            full_path = Path(save_result_path)
-            if not full_path.is_absolute():
-                full_path = services.file_service.output_video_dir / save_result_path
-            return _stream_file_response(full_path)
+            result_png = task_manager.get_task_result_png(task_id)
+            if result_png:
+                return Response(
+                    content=result_png,
+                    media_type="image/png",
+                    headers={"Content-Disposition": 'inline; filename="result.png"'},
+                )
+            raise HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}")
 
         if status == TaskStatus.FAILED.value:
             raise HTTPException(status_code=500, detail=task_status.get("error", "Task failed"))
@@ -72,6 +69,7 @@ async def create_image_task(message: ImageTaskRequest):
             if not await validate_url_async(message.image_mask_path):
                 raise HTTPException(status_code=400, detail=f"Image mask URL is not accessible: {message.image_mask_path}")
 
+        message.prefer_memory_result = False
         task_id = task_manager.create_task(message)
         message.task_id = task_id
 
@@ -108,6 +106,7 @@ async def create_image_task_sync(
             if not await validate_url_async(message.image_mask_path):
                 raise HTTPException(status_code=400, detail=f"Image mask URL is not accessible: {message.image_mask_path}")
 
+        message.prefer_memory_result = True
         task_id = task_manager.create_task(message)
         message.task_id = task_id
 
@@ -184,6 +183,7 @@ async def save_file_async(file: UploadFile, target_dir: Path) -> str:
     )
 
     try:
+        message.prefer_memory_result = False
         task_id = task_manager.create_task(message)
         message.task_id = task_id
 
diff --git a/lightx2v/server/schema.py b/lightx2v/server/schema.py
index 9acd32ae1..ba7f16fc0 100644
--- a/lightx2v/server/schema.py
+++ b/lightx2v/server/schema.py
@@ -40,6 +40,8 @@ class BaseTaskRequest(DisaggOverrideRequest):
     target_shape: list[int] = Field([], description="Return video or image shape")
     lora_name: Optional[str] = Field(None, description="LoRA filename to load from lora_dir, None to disable LoRA")
     lora_strength: float = Field(1.0, description="LoRA strength")
+    # Internal switch: sync API sets this True to return image from memory only.
+    prefer_memory_result: bool = Field(default=False, exclude=True)
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -83,6 +85,8 @@ class TaskResponse(BaseModel):
     task_id: str
     task_status: str
     save_result_path: str
+    # Filled after image generation in-process; never serialized in JSON responses.
+    result_png: Optional[bytes] = Field(default=None, exclude=True)
 
 
 class StopTaskResponse(BaseModel):
diff --git a/lightx2v/server/services/generation/image.py b/lightx2v/server/services/generation/image.py
index 110492db5..df22e8b97 100644
--- a/lightx2v/server/services/generation/image.py
+++ b/lightx2v/server/services/generation/image.py
@@ -43,6 +43,9 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
             self._prepare_output_path(message.save_result_path, task_data)
 
             task_data["seed"] = message.seed
+            prefer_memory_result = bool(getattr(message, "prefer_memory_result", False))
+            task_data.pop("prefer_memory_result", None)
+            task_data["return_result_tensor"] = prefer_memory_result
 
             result = await self.inference_service.submit_task_async(task_data)
 
@@ -56,6 +59,17 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
                 actual_save_path = self.file_service.get_output_path(message.save_result_path)
                 if not actual_save_path.suffix:
                     actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
+                if prefer_memory_result:
+                    result_png = result.get("result_png")
+                    if not result_png:
+                        raise RuntimeError("Image inference did not return in-memory PNG bytes (result_png)")
+                    return TaskResponse(
+                        task_id=message.task_id,
+                        task_status="completed",
+                        save_result_path="",
+                        result_png=result_png,
+                    )
+
                 return TaskResponse(
                     task_id=message.task_id,
                     task_status="completed",
diff --git a/lightx2v/server/services/inference/pipeline_image_encode.py b/lightx2v/server/services/inference/pipeline_image_encode.py
new file mode 100644
index 000000000..25bf890cb
--- /dev/null
+++ b/lightx2v/server/services/inference/pipeline_image_encode.py
@@ -0,0 +1,143 @@
+"""Normalize image runner outputs to PNG bytes (in-memory, no disk)."""
+
+from __future__ import annotations
+
+import base64
+import os
+import time
+from io import BytesIO
+from typing import Any, Optional
+
+import torch
+from PIL import Image
+from loguru import logger
+
+try:
+    from torchvision.io import encode_png as tv_encode_png
+except Exception:
+    tv_encode_png = None
+
+
+def _get_png_compression_level() -> int:
+    """Read LIGHTX2V_SYNC_PNG_COMPRESSION (default 6); non-integer values fall back to 6, out-of-range values are clamped to [0, 9]."""
+    raw = os.getenv("LIGHTX2V_SYNC_PNG_COMPRESSION", "6")
+    try:
+        level = int(raw)
+    except ValueError:
+        logger.warning(f"Invalid LIGHTX2V_SYNC_PNG_COMPRESSION={raw}, fallback to 6")
+        return 6
+    if level < 0 or level > 9:
+        logger.warning(f"LIGHTX2V_SYNC_PNG_COMPRESSION={level} out of range [0,9], clamped")
+        level = max(0, min(9, level))
+    return level
+
+
+PNG_COMPRESSION_LEVEL = _get_png_compression_level()
+
+
+def _pil_to_png_bytes(pil_image: Image.Image) -> bytes:
+    """Encode a PIL image as PNG bytes; any mode other than RGB/RGBA is converted to RGB first."""
+    buf = BytesIO()
+    img = pil_image
+    if img.mode not in ("RGB", "RGBA"):
+        img = img.convert("RGB")
+    img.save(buf, format="PNG", compress_level=PNG_COMPRESSION_LEVEL)
+    return buf.getvalue()
+
+
+def _pil_images_structure_to_png(images: Any) -> bytes:
+    """Encode the first image of a sequence (or of its first nested list) as PNG bytes; raise TypeError if it has no ``save``."""
+    first = images[0]
+    if isinstance(first, list):
+        pil_image = first[0]
+    else:
+        pil_image = first
+    if not hasattr(pil_image, "save"):
+        raise TypeError(f"Unexpected image element type: {type(pil_image)}")
+    return _pil_to_png_bytes(pil_image)
+
+
+def _tensor_to_png_bytes(image_tensor: torch.Tensor) -> bytes:
+    """Encode a CHW or HWC image tensor (4-D input: first item only) as PNG bytes; floats with max <= 1.0 are scaled by 255, other floats clamped to [0, 255]."""
+    total_start = time.perf_counter()
+    task_tag = f"shape={tuple(image_tensor.shape)},dtype={image_tensor.dtype},device={image_tensor.device}"
+
+    cpu_start = time.perf_counter()
+    tensor = image_tensor.detach().cpu()
+    cpu_ms = (time.perf_counter() - cpu_start) * 1000
+
+    if tensor.ndim == 4:
+        tensor = tensor[0]
+    if tensor.ndim != 3:
+        raise TypeError(f"Unsupported tensor shape: {tuple(tensor.shape)}")
+
+    prep_start = time.perf_counter()
+    # Normalize layout once: keep CHW for fast PNG encoding path.
+    if tensor.shape[0] in (1, 3, 4):
+        tensor_chw = tensor
+    elif tensor.shape[-1] in (1, 3, 4):
+        tensor_chw = tensor.permute(2, 0, 1)
+    else:
+        raise TypeError(f"Unsupported tensor channel layout: {tuple(tensor.shape)}")
+
+    if tensor_chw.dtype.is_floating_point:
+        # Most runners output floats in [0, 1].
+        if float(tensor_chw.max()) <= 1.0:
+            tensor_chw = (tensor_chw.clamp(0.0, 1.0) * 255.0).round()
+        else:
+            tensor_chw = tensor_chw.clamp(0.0, 255.0).round()
+
+    tensor_chw = tensor_chw.to(torch.uint8)
+    prep_ms = (time.perf_counter() - prep_start) * 1000
+
+    # Fast path: torchvision encodes CHW uint8 directly, but only accepts 1 or 3 channels; RGBA falls through to PIL below.
+    if tv_encode_png is not None and tensor_chw.shape[0] in (1, 3):
+        encode_start = time.perf_counter()
+        png_bytes = tv_encode_png(tensor_chw, compression_level=PNG_COMPRESSION_LEVEL).numpy().tobytes()
+        encode_ms = (time.perf_counter() - encode_start) * 1000
+        total_ms = (time.perf_counter() - total_start) * 1000
+        logger.info(f"Tensor->PNG(tv) cost total={total_ms:.2f}ms cpu_copy={cpu_ms:.2f}ms preprocess={prep_ms:.2f}ms encode={encode_ms:.2f}ms level={PNG_COMPRESSION_LEVEL} [{task_tag}]")
+        return png_bytes
+
+    encode_start = time.perf_counter()
+    arr = tensor_chw.permute(1, 2, 0).numpy()
+    if arr.shape[-1] == 1:
+        arr = arr[:, :, 0]
+    png_bytes = _pil_to_png_bytes(Image.fromarray(arr))
+    encode_ms = (time.perf_counter() - encode_start) * 1000
+    total_ms = (time.perf_counter() - total_start) * 1000
+    logger.info(f"Tensor->PNG(pil) cost total={total_ms:.2f}ms cpu_copy={cpu_ms:.2f}ms preprocess={prep_ms:.2f}ms encode={encode_ms:.2f}ms level={PNG_COMPRESSION_LEVEL} [{task_tag}]")
+    return png_bytes
+
+
+def encode_pipeline_return_to_png_bytes(pipeline_return: Any) -> Optional[bytes]:
+    """Convert run_pipeline return value to a single PNG byte string, or None if not applicable."""
+    if pipeline_return is None:
+        return None
+    try:
+        if isinstance(pipeline_return, tuple) and len(pipeline_return) > 0:
+            # e.g. BagelRunner returns (images, audio_or_none)
+            pipeline_return = pipeline_return[0]
+        if isinstance(pipeline_return, dict):
+            images = pipeline_return.get("images")
+            if images is None:
+                return None
+            if isinstance(images, torch.Tensor):
+                return _tensor_to_png_bytes(images)
+            return _pil_images_structure_to_png(images)
+        if isinstance(pipeline_return, list) and len(pipeline_return) > 0:
+            if isinstance(pipeline_return[0], torch.Tensor):
+                return _tensor_to_png_bytes(pipeline_return[0])
+            return _pil_images_structure_to_png(pipeline_return)
+        if isinstance(pipeline_return, torch.Tensor):
+            return _tensor_to_png_bytes(pipeline_return)
+        if isinstance(pipeline_return, Image.Image):
+            return _pil_to_png_bytes(pipeline_return)
+        if isinstance(pipeline_return, str):
+            raw = base64.b64decode(pipeline_return)
+            img = Image.open(BytesIO(raw)).convert("RGB")
+            return _pil_to_png_bytes(img)
+    except Exception as e:
+        logger.exception(f"Failed to encode pipeline output to PNG: {e}")
+        return None
+    return None
diff --git a/lightx2v/server/services/inference/worker.py b/lightx2v/server/services/inference/worker.py
index 6a1fabf8a..238ffd9a5 100644
--- a/lightx2v/server/services/inference/worker.py
+++ b/lightx2v/server/services/inference/worker.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import time
 from pathlib import Path
 from typing import Any, Dict
 
@@ -11,6 +12,7 @@ from lightx2v.utils.set_config import set_config, set_parallel_config
 
 
 from ..distributed_utils import DistributedManager
+from .pipeline_image_encode import encode_pipeline_return_to_png_bytes
 
 
 class TorchrunInferenceWorker:
@@ -66,6 +68,7 @@ def init(self, args) -> bool:
     async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
         has_error = False
         error_msg = ""
+        pipeline_return = None
 
         try:
             if self.world_size > 1 and self.rank == 0:
@@ -79,7 +82,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
                 self.switch_lora(lora_name, lora_strength)
 
             task_data["task"] = self.runner.config["task"]
-            task_data["return_result_tensor"] = False
+            task_data["return_result_tensor"] = bool(task_data.get("return_result_tensor", False))
             task_data["negative_prompt"] = task_data.get("negative_prompt", "")
 
             target_fps = task_data.pop("target_fps", None)
@@ -93,7 +96,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
             update_input_info_from_dict(self.input_info, task_data)
 
             self.runner.set_config(task_data)
-            self.runner.run_pipeline(self.input_info)
+            pipeline_return = self.runner.run_pipeline(self.input_info)
 
             await asyncio.sleep(0)
 
@@ -114,12 +117,20 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
                     "message": f"Inference failed: {error_msg}",
                 }
             else:
-                return {
+                out: Dict[str, Any] = {
                     "task_id": task_data["task_id"],
                     "status": "success",
-                    "save_result_path": task_data["save_result_path"],
+                    "save_result_path": task_data.get("save_result_path"),
                     "message": "Inference completed",
                 }
+                if task_data.get("return_result_tensor"):
+                    encode_start = time.perf_counter()
+                    png = encode_pipeline_return_to_png_bytes(pipeline_return)
+                    encode_elapsed_ms = (time.perf_counter() - encode_start) * 1000
+                    logger.info(f"Task {task_data.get('task_id')} encode result_png cost {encode_elapsed_ms:.2f} ms")
+                    if png:
+                        out["result_png"] = png
+                return out
         else:
             return None
 
diff --git a/lightx2v/server/task_manager.py b/lightx2v/server/task_manager.py
index ed12b5f96..93b8a4e40 100644
--- a/lightx2v/server/task_manager.py
+++ b/lightx2v/server/task_manager.py
@@ -28,6 +28,7 @@ class TaskInfo:
     end_time: Optional[datetime] = None
     error: Optional[str] = None
     save_result_path: Optional[str] = None
+    result_png: Optional[bytes] = None
     stop_event: threading.Event = field(default_factory=threading.Event)
     thread: Optional[threading.Thread] = None
 
@@ -81,7 +82,7 @@ def start_task(self, task_id: str) -> TaskInfo:
 
             return task
 
-    def complete_task(self, task_id: str, save_result_path: Optional[str] = None):
+    def complete_task(self, task_id: str, save_result_path: Optional[str] = None, result_png: Optional[bytes] = None):
         with self._lock:
             if task_id not in self._tasks:
                 logger.warning(f"Task {task_id} not found for completion")
@@ -90,8 +91,8 @@ def complete_task(self, task_id: str, save_result_path: Optional[str] = None):
             task = self._tasks[task_id]
             task.status = TaskStatus.COMPLETED
             task.end_time = datetime.now()
-            if save_result_path:
-                task.save_result_path = save_result_path
+            task.save_result_path = save_result_path
+            task.result_png = result_png
 
             self.completed_tasks += 1
             self._emit_queue_metrics_unlocked()
@@ -141,6 +142,13 @@ def get_task(self, task_id: str) -> Optional[TaskInfo]:
         with self._lock:
             return self._tasks.get(task_id)
 
+    def get_task_result_png(self, task_id: str) -> Optional[bytes]:
+        with self._lock:
+            task = self._tasks.get(task_id)
+            if not task:
+                return None
+            return task.result_png
+
     def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
         task = self.get_task(task_id)
         if not task: