Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lightx2v/server/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ async def _process_single_task(self, task_info: Any):
result = await generation_service.generate_with_stop_event(message, task_info.stop_event)

if result:
task_manager.complete_task(task_id, result.save_result_path)
task_manager.complete_task(
task_id,
save_result_path=result.save_result_path or None,
result_png=getattr(result, "result_png", None),
)
logger.info(f"Task {task_id} completed successfully")
else:
if task_info.stop_event.is_set():
Expand Down
24 changes: 12 additions & 12 deletions lightx2v/server/api/tasks/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from pathlib import Path

from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import Response
from loguru import logger

from ...schema import ImageTaskRequest, TaskResponse
from ...task_manager import TaskStatus, task_manager
from ..deps import get_services, validate_url_async
from .common import _stream_file_response

router = APIRouter()

Expand All @@ -20,9 +20,6 @@ def _write_file_sync(file_path: Path, content: bytes) -> None:


async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_interval_seconds: float):
services = get_services()
assert services.file_service is not None, "File service is not initialized"

start_time = time.monotonic()
while True:
task_status = task_manager.get_task_status(task_id)
Expand All @@ -31,14 +28,14 @@ async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_

status = task_status.get("status")
if status == TaskStatus.COMPLETED.value:
save_result_path = task_status.get("save_result_path")
if not save_result_path:
raise HTTPException(status_code=500, detail=f"Task completed but no result path found: {task_id}")

full_path = Path(save_result_path)
if not full_path.is_absolute():
full_path = services.file_service.output_video_dir / save_result_path
return _stream_file_response(full_path)
result_png = task_manager.get_task_result_png(task_id)
if result_png:
return Response(
content=result_png,
media_type="image/png",
headers={"Content-Disposition": 'inline; filename="result.png"'},
)
raise HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}")

if status == TaskStatus.FAILED.value:
raise HTTPException(status_code=500, detail=task_status.get("error", "Task failed"))
Expand Down Expand Up @@ -72,6 +69,7 @@ async def create_image_task(message: ImageTaskRequest):
if not await validate_url_async(message.image_mask_path):
raise HTTPException(status_code=400, detail=f"Image mask URL is not accessible: {message.image_mask_path}")

message.prefer_memory_result = False
task_id = task_manager.create_task(message)
message.task_id = task_id

Expand Down Expand Up @@ -108,6 +106,7 @@ async def create_image_task_sync(
if not await validate_url_async(message.image_mask_path):
raise HTTPException(status_code=400, detail=f"Image mask URL is not accessible: {message.image_mask_path}")

message.prefer_memory_result = True
task_id = task_manager.create_task(message)
message.task_id = task_id

Expand Down Expand Up @@ -184,6 +183,7 @@ async def save_file_async(file: UploadFile, target_dir: Path) -> str:
)

try:
message.prefer_memory_result = False
task_id = task_manager.create_task(message)
message.task_id = task_id

Expand Down
4 changes: 4 additions & 0 deletions lightx2v/server/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class BaseTaskRequest(DisaggOverrideRequest):
target_shape: list[int] = Field([], description="Return video or image shape")
lora_name: Optional[str] = Field(None, description="LoRA filename to load from lora_dir, None to disable LoRA")
lora_strength: float = Field(1.0, description="LoRA strength")
# Internal switch: sync API sets this True to return image from memory only.
prefer_memory_result: bool = Field(default=False, exclude=True)

def __init__(self, **data):
super().__init__(**data)
Expand Down Expand Up @@ -83,6 +85,8 @@ class TaskResponse(BaseModel):
task_id: str
task_status: str
save_result_path: str
# Filled after image generation in-process; never serialized in JSON responses.
result_png: Optional[bytes] = Field(default=None, exclude=True)


class StopTaskResponse(BaseModel):
Expand Down
14 changes: 14 additions & 0 deletions lightx2v/server/services/generation/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A

self._prepare_output_path(message.save_result_path, task_data)
task_data["seed"] = message.seed
prefer_memory_result = bool(getattr(message, "prefer_memory_result", False))
task_data.pop("prefer_memory_result", None)
task_data["return_result_tensor"] = prefer_memory_result

result = await self.inference_service.submit_task_async(task_data)

Expand All @@ -56,6 +59,17 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
actual_save_path = self.file_service.get_output_path(message.save_result_path)
if not actual_save_path.suffix:
actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
if prefer_memory_result:
result_png = result.get("result_png")
if not result_png:
raise RuntimeError("Image inference did not return in-memory PNG bytes (result_png)")
return TaskResponse(
task_id=message.task_id,
task_status="completed",
save_result_path="",
result_png=result_png,
)

return TaskResponse(
task_id=message.task_id,
task_status="completed",
Expand Down
139 changes: 139 additions & 0 deletions lightx2v/server/services/inference/pipeline_image_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Normalize image runner outputs to PNG bytes (in-memory, no disk)."""

from __future__ import annotations

import base64
import os
import time
from io import BytesIO
from typing import Any, Optional

import torch
from PIL import Image
from loguru import logger

try:
from torchvision.io import encode_png as tv_encode_png
except Exception:
tv_encode_png = None


def _get_png_compression_level() -> int:
    """Resolve the PNG compression level (0-9) from the environment.

    Reads LIGHTX2V_SYNC_PNG_COMPRESSION; a non-integer value falls back to 6,
    and an out-of-range integer is clamped into [0, 9]. Both cases log a warning.
    """
    raw = os.getenv("LIGHTX2V_SYNC_PNG_COMPRESSION", "6")
    try:
        level = int(raw)
    except ValueError:
        logger.warning(f"Invalid LIGHTX2V_SYNC_PNG_COMPRESSION={raw}, fallback to 6")
        return 6
    if not 0 <= level <= 9:
        logger.warning(f"LIGHTX2V_SYNC_PNG_COMPRESSION={level} out of range [0,9], clamped")
        level = min(9, max(0, level))
    return level


# Resolved once at import time; every encoder in this module shares it.
PNG_COMPRESSION_LEVEL = _get_png_compression_level()


def _pil_to_png_bytes(pil_image: Image.Image) -> bytes:
    """Encode a PIL image to PNG bytes in memory (no disk I/O).

    Modes other than RGB/RGBA are converted to RGB before saving so the PNG
    encoder always receives a supported layout.
    """
    if pil_image.mode not in ("RGB", "RGBA"):
        pil_image = pil_image.convert("RGB")
    out = BytesIO()
    pil_image.save(out, format="PNG", compress_level=PNG_COMPRESSION_LEVEL)
    return out.getvalue()


def _pil_images_structure_to_png(images: Any) -> bytes:
    """Pick the first PIL image out of a (possibly nested) list and encode it.

    Accepts either ``[img, ...]`` or ``[[img, ...], ...]``; anything without a
    ``save`` method at that position is rejected with a TypeError.
    """
    head = images[0]
    candidate = head[0] if isinstance(head, list) else head
    if not hasattr(candidate, "save"):
        raise TypeError(f"Unexpected image element type: {type(candidate)}")
    return _pil_to_png_bytes(candidate)


def _tensor_to_png_bytes(image_tensor: torch.Tensor) -> bytes:
    """Encode an image tensor to PNG bytes entirely in memory.

    Accepts a 3D tensor in CHW or HWC layout (channels in {1, 3, 4}) or a 4D
    tensor, of which only the first item is used. Float tensors are scaled to
    [0, 255]; everything else is cast to uint8 as-is. Prefers torchvision's
    ``encode_png`` when importable, otherwise falls back to PIL. Each stage is
    timed and logged at INFO level.

    Raises:
        TypeError: if the tensor rank or channel layout is unsupported.
    """
    total_start = time.perf_counter()
    # Captured before any transformation so the log reflects the caller's tensor.
    task_tag = f"shape={tuple(image_tensor.shape)},dtype={image_tensor.dtype},device={image_tensor.device}"

    # detach+cpu: PNG encoding needs host memory and no autograd graph.
    cpu_start = time.perf_counter()
    tensor = image_tensor.detach().cpu()
    cpu_ms = (time.perf_counter() - cpu_start) * 1000

    # 4D input: only the first item is encoded; the rest is silently dropped.
    if tensor.ndim == 4:
        tensor = tensor[0]
    if tensor.ndim != 3:
        raise TypeError(f"Unsupported tensor shape: {tuple(tensor.shape)}")

    prep_start = time.perf_counter()
    # Normalize layout once: keep CHW for fast PNG encoding path.
    # NOTE(review): a tensor whose first AND last dims are both in {1,3,4}
    # (e.g. 3x3 or 4x4 spatial) is ambiguous and resolves as CHW — confirm
    # runners never emit such shapes.
    if tensor.shape[0] in (1, 3, 4):
        tensor_chw = tensor
    elif tensor.shape[-1] in (1, 3, 4):
        tensor_chw = tensor.permute(2, 0, 1)
    else:
        raise TypeError(f"Unsupported tensor channel layout: {tuple(tensor.shape)}")

    if tensor_chw.dtype.is_floating_point:
        # Most runners output floats in [0, 1].
        # NOTE(review): heuristic — an all-dark [0, 255] float image whose max
        # happens to be <= 1.0 would be misdetected and rescaled.
        if float(tensor_chw.max()) <= 1.0:
            tensor_chw = (tensor_chw.clamp(0.0, 1.0) * 255.0).round()
        else:
            tensor_chw = tensor_chw.clamp(0.0, 255.0).round()

    tensor_chw = tensor_chw.to(torch.uint8)
    prep_ms = (time.perf_counter() - prep_start) * 1000

    # Fast path: encode PNG directly from CHW uint8 tensor.
    if tv_encode_png is not None:
        encode_start = time.perf_counter()
        png_bytes = tv_encode_png(tensor_chw, compression_level=PNG_COMPRESSION_LEVEL).numpy().tobytes()
        encode_ms = (time.perf_counter() - encode_start) * 1000
        total_ms = (time.perf_counter() - total_start) * 1000
        logger.info(f"Tensor->PNG(tv) cost total={total_ms:.2f}ms cpu_copy={cpu_ms:.2f}ms preprocess={prep_ms:.2f}ms encode={encode_ms:.2f}ms level={PNG_COMPRESSION_LEVEL} [{task_tag}]")
        return png_bytes

    # Fallback path: hand CHW off to PIL as HWC; single-channel tensors become
    # 2D arrays so Image.fromarray picks grayscale mode.
    encode_start = time.perf_counter()
    arr = tensor_chw.permute(1, 2, 0).numpy()
    if arr.shape[-1] == 1:
        arr = arr[:, :, 0]
    png_bytes = _pil_to_png_bytes(Image.fromarray(arr))
    encode_ms = (time.perf_counter() - encode_start) * 1000
    total_ms = (time.perf_counter() - total_start) * 1000
    logger.info(f"Tensor->PNG(pil) cost total={total_ms:.2f}ms cpu_copy={cpu_ms:.2f}ms preprocess={prep_ms:.2f}ms encode={encode_ms:.2f}ms level={PNG_COMPRESSION_LEVEL} [{task_tag}]")
    return png_bytes


def encode_pipeline_return_to_png_bytes(pipeline_return: Any) -> Optional[bytes]:
    """Convert run_pipeline return value to a single PNG byte string, or None if not applicable.

    Handles the output shapes produced by image runners:
      - ``(images, audio_or_none)`` tuples (e.g. BagelRunner): first element used
      - ``{"images": tensor | [PIL] | PIL | []}`` dicts
      - bare tensors, non-empty lists of tensors/PIL images, single PIL images
      - base64-encoded image strings

    Returns None for inputs it cannot interpret (including empty containers,
    which mean "no image produced"); encoding failures are logged and also
    yield None so callers can fall back to the on-disk result.
    """
    if pipeline_return is None:
        return None
    try:
        if isinstance(pipeline_return, tuple) and len(pipeline_return) > 0:
            # e.g. BagelRunner returns (images, audio_or_none)
            pipeline_return = pipeline_return[0]
            # The unwrapped payload may itself be None (no image generated).
            if pipeline_return is None:
                return None
        if isinstance(pipeline_return, dict):
            images = pipeline_return.get("images")
            if images is None:
                return None
            if isinstance(images, torch.Tensor):
                return _tensor_to_png_bytes(images)
            if isinstance(images, (list, tuple)):
                # Empty container: treat as "no image" rather than raising IndexError.
                return _pil_images_structure_to_png(images) if images else None
            if isinstance(images, Image.Image):
                return _pil_to_png_bytes(images)
            return None
        if isinstance(pipeline_return, list) and len(pipeline_return) > 0:
            if isinstance(pipeline_return[0], torch.Tensor):
                return _tensor_to_png_bytes(pipeline_return[0])
            return _pil_images_structure_to_png(pipeline_return)
        if isinstance(pipeline_return, torch.Tensor):
            return _tensor_to_png_bytes(pipeline_return)
        if isinstance(pipeline_return, Image.Image):
            return _pil_to_png_bytes(pipeline_return)
        if isinstance(pipeline_return, str):
            raw = base64.b64decode(pipeline_return)
            img = Image.open(BytesIO(raw)).convert("RGB")
            return _pil_to_png_bytes(img)
    except Exception as e:
        logger.exception(f"Failed to encode pipeline output to PNG: {e}")
        return None
    return None
19 changes: 15 additions & 4 deletions lightx2v/server/services/inference/worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import time
from pathlib import Path
from typing import Any, Dict

Expand All @@ -11,6 +12,7 @@
from lightx2v.utils.set_config import set_config, set_parallel_config

from ..distributed_utils import DistributedManager
from .pipeline_image_encode import encode_pipeline_return_to_png_bytes


class TorchrunInferenceWorker:
Expand Down Expand Up @@ -66,6 +68,7 @@ def init(self, args) -> bool:
async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
has_error = False
error_msg = ""
pipeline_return = None

try:
if self.world_size > 1 and self.rank == 0:
Expand All @@ -79,7 +82,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
self.switch_lora(lora_name, lora_strength)

task_data["task"] = self.runner.config["task"]
task_data["return_result_tensor"] = False
task_data["return_result_tensor"] = bool(task_data.get("return_result_tensor", False))
task_data["negative_prompt"] = task_data.get("negative_prompt", "")

target_fps = task_data.pop("target_fps", None)
Expand All @@ -93,7 +96,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
update_input_info_from_dict(self.input_info, task_data)

self.runner.set_config(task_data)
self.runner.run_pipeline(self.input_info)
pipeline_return = self.runner.run_pipeline(self.input_info)

await asyncio.sleep(0)

Expand All @@ -114,12 +117,20 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
"message": f"Inference failed: {error_msg}",
}
else:
return {
out: Dict[str, Any] = {
"task_id": task_data["task_id"],
"status": "success",
"save_result_path": task_data["save_result_path"],
"save_result_path": task_data.get("save_result_path"),
"message": "Inference completed",
}
if task_data.get("return_result_tensor"):
encode_start = time.perf_counter()
png = encode_pipeline_return_to_png_bytes(pipeline_return)
encode_elapsed_ms = (time.perf_counter() - encode_start) * 1000
logger.info(f"Task {task_data.get('task_id')} encode result_png cost {encode_elapsed_ms:.2f} ms")
if png:
out["result_png"] = png
return out
else:
return None

Expand Down
14 changes: 11 additions & 3 deletions lightx2v/server/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TaskInfo:
end_time: Optional[datetime] = None
error: Optional[str] = None
save_result_path: Optional[str] = None
result_png: Optional[bytes] = None
stop_event: threading.Event = field(default_factory=threading.Event)
thread: Optional[threading.Thread] = None

Expand Down Expand Up @@ -81,7 +82,7 @@ def start_task(self, task_id: str) -> TaskInfo:

return task

def complete_task(self, task_id: str, save_result_path: Optional[str] = None):
def complete_task(self, task_id: str, save_result_path: Optional[str] = None, result_png: Optional[bytes] = None):
with self._lock:
if task_id not in self._tasks:
logger.warning(f"Task {task_id} not found for completion")
Expand All @@ -90,8 +91,8 @@ def complete_task(self, task_id: str, save_result_path: Optional[str] = None):
task = self._tasks[task_id]
task.status = TaskStatus.COMPLETED
task.end_time = datetime.now()
if save_result_path:
task.save_result_path = save_result_path
task.save_result_path = save_result_path
task.result_png = result_png

self.completed_tasks += 1
self._emit_queue_metrics_unlocked()
Expand Down Expand Up @@ -141,6 +142,13 @@ def get_task(self, task_id: str) -> Optional[TaskInfo]:
with self._lock:
return self._tasks.get(task_id)

def get_task_result_png(self, task_id: str) -> Optional[bytes]:
    """Return and release the in-memory PNG bytes for a task.

    Returns None for unknown tasks or when no PNG is held. The bytes are
    cleared from the TaskInfo on first retrieval: completed task records are
    retained by the manager, so keeping per-task PNG payloads alive would
    accumulate into significant memory use (potential OOM) under frequent
    sync-API traffic.
    """
    with self._lock:
        task = self._tasks.get(task_id)
        if not task:
            return None
        png = task.result_png
        # Drop the reference so the retained task record no longer pins the image.
        task.result_png = None
        return png
Comment on lines +145 to +150
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The result_png field stores image bytes in memory. Since TaskManager retains up to 1000 tasks by default, frequent use of the sync API could lead to significant memory consumption and potential Out-Of-Memory (OOM) errors. It is recommended to clear the result_png data once it has been retrieved by the client.

Suggested change
def get_task_result_png(self, task_id: str) -> Optional[bytes]:
with self._lock:
task = self._tasks.get(task_id)
if not task:
return None
return task.result_png
def get_task_result_png(self, task_id: str) -> Optional[bytes]:
with self._lock:
task = self._tasks.get(task_id)
if not task:
return None
res = task.result_png
task.result_png = None # Clear memory after retrieval
return res


def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
task = self.get_task(task_id)
if not task:
Expand Down
Loading