diff --git a/pyproject.toml b/pyproject.toml index 21eea5c..19a08a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "botorch", "openai", "tqdm", + "gradio", ] dynamic = ["version"] diff --git a/requirements.txt b/requirements.txt index 6751967..e34c54a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,84 +1,174 @@ # This file was autogenerated by uv via the following command: -# uv pip compile pyproject.toml --output-file requirements.txt +# uv pip compile requirements.txt --output-file requirements.txt +aiofiles==24.1.0 + # via gradio annotated-types==0.7.0 - # via pydantic + # via + # -r requirements.txt + # pydantic anyio==4.9.0 # via + # -r requirements.txt + # gradio # httpx # openai + # starlette botorch==0.14.0 - # via eaa (pyproject.toml) + # via -r requirements.txt certifi==2025.4.26 # via + # -r requirements.txt # httpcore # httpx + # requests +charset-normalizer==3.4.2 + # via requests +click==8.2.1 + # via + # typer + # uvicorn contourpy==1.3.2 - # via matplotlib + # via + # -r requirements.txt + # matplotlib cycler==0.12.1 - # via matplotlib + # via + # -r requirements.txt + # matplotlib distro==1.9.0 - # via openai + # via + # -r requirements.txt + # openai +fastapi==0.115.12 + # via gradio +ffmpy==0.6.0 + # via gradio filelock==3.18.0 - # via torch + # via + # -r requirements.txt + # huggingface-hub + # torch fonttools==4.57.0 - # via matplotlib + # via + # -r requirements.txt + # matplotlib fsspec==2025.5.1 - # via torch + # via + # -r requirements.txt + # gradio-client + # huggingface-hub + # torch gpytorch==1.14 - # via botorch + # via + # -r requirements.txt + # botorch +gradio==5.34.0 + # via -r requirements.txt +gradio-client==1.10.3 + # via gradio +groovy==0.1.2 + # via gradio h11==0.16.0 - # via httpcore + # via + # -r requirements.txt + # httpcore + # uvicorn +hf-xet==1.1.3 + # via huggingface-hub httpcore==1.0.9 - # via httpx + # via + # -r requirements.txt + # httpx httpx==0.28.1 - # via openai + # via + # -r requirements.txt + # gradio + # gradio-client + # openai + # safehttpx +huggingface-hub==0.33.0 + # via + # gradio + # gradio-client idna==3.10 # via + # -r requirements.txt # anyio # httpx + # requests imageio==2.37.0 - # via scikit-image + # via + # -r requirements.txt + # scikit-image jaxtyping==0.3.2 # via + # -r requirements.txt # gpytorch # linear-operator jinja2==3.1.6 - # via torch + # via + # -r requirements.txt + # gradio + # torch jiter==0.10.0 - # via openai + # via + # -r requirements.txt + # openai joblib==1.5.1 - # via scikit-learn + # via + # -r requirements.txt + # scikit-learn kiwisolver==1.4.8 - # via matplotlib + # via + # -r requirements.txt + # matplotlib lazy-loader==0.4 - # via scikit-image + # via + # -r requirements.txt + # scikit-image linear-operator==0.6 # via + # -r requirements.txt # botorch # gpytorch +markdown-it-py==3.0.0 + # via rich markupsafe==3.0.2 - # via jinja2 + # via + # -r requirements.txt + # gradio + # jinja2 matplotlib==3.10.1 - # via eaa (pyproject.toml) + # via -r requirements.txt +mdurl==0.1.2 + # via markdown-it-py mpmath==1.3.0 # via + # -r requirements.txt # gpytorch # linear-operator # sympy multipledispatch==1.0.0 - # via botorch + # via + # -r requirements.txt + # botorch mypy-extensions==1.1.0 - # via typing-inspect + # via + # -r requirements.txt + # typing-inspect networkx==3.4.2 # via + # -r requirements.txt # scikit-image # torch numpy==2.2.5 # via - # eaa (pyproject.toml) + # -r requirements.txt # contourpy + # gradio # imageio # matplotlib + # pandas # pyro-ppl # scikit-image # scikit-learn @@ -86,123 +176,245 @@ numpy==2.2.5 # tifffile nvidia-cublas-cu12==12.6.4.1 # via + # -r requirements.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch nvidia-cuda-cupti-cu12==12.6.80 - # via torch + # via + # -r requirements.txt + # torch nvidia-cuda-nvrtc-cu12==12.6.77 - # via torch + # via + # -r requirements.txt + # torch nvidia-cuda-runtime-cu12==12.6.77 - # via torch + # via + # -r requirements.txt + # torch nvidia-cudnn-cu12==9.5.1.17 - # via torch + # via + # -r requirements.txt + # torch nvidia-cufft-cu12==11.3.0.4 - # via torch + # via + # -r requirements.txt + # torch nvidia-cufile-cu12==1.11.1.6 - # via torch + # via + # -r requirements.txt + # torch nvidia-curand-cu12==10.3.7.77 - # via torch + # via + # -r requirements.txt + # torch nvidia-cusolver-cu12==11.7.1.2 - # via torch + # via + # -r requirements.txt + # torch nvidia-cusparse-cu12==12.5.4.2 # via + # -r requirements.txt # nvidia-cusolver-cu12 # torch nvidia-cusparselt-cu12==0.6.3 - # via torch + # via + # -r requirements.txt + # torch nvidia-nccl-cu12==2.26.2 - # via torch + # via + # -r requirements.txt + # torch nvidia-nvjitlink-cu12==12.6.85 # via + # -r requirements.txt # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch nvidia-nvtx-cu12==12.6.77 - # via torch + # via + # -r requirements.txt + # torch openai==1.83.0 - # via eaa (pyproject.toml) + # via -r requirements.txt opt-einsum==3.4.0 - # via pyro-ppl + # via + # -r requirements.txt + # pyro-ppl +orjson==3.10.18 + # via gradio packaging==25.0 # via + # -r requirements.txt + # gradio + # gradio-client + # huggingface-hub # lazy-loader # matplotlib # scikit-image +pandas==2.3.0 + # via gradio pillow==11.2.1 # via + # -r requirements.txt + # gradio # imageio # matplotlib # scikit-image pydantic==2.11.3 - # via openai + # via + # -r requirements.txt + # fastapi + # gradio + # openai pydantic-core==2.33.1 - # via pydantic + # via + # -r requirements.txt + # pydantic +pydub==0.25.1 + # via gradio +pygments==2.19.1 + # via rich pyparsing==3.2.3 - # via matplotlib + # via + # -r requirements.txt + # matplotlib pyre-extensions==0.0.32 - # via botorch + # via + # -r requirements.txt + # botorch pyro-api==0.1.2 - # via pyro-ppl + # via + # -r requirements.txt + # pyro-ppl pyro-ppl==1.9.1 - # via botorch + # via + # -r requirements.txt + # botorch python-dateutil==2.9.0.post0 - # via matplotlib + # via + # -r requirements.txt + # matplotlib + # pandas +python-multipart==0.0.20 + # via gradio +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # gradio + # huggingface-hub +requests==2.32.4 + # via huggingface-hub +rich==14.0.0 + # via typer +ruff==0.11.13 + # via gradio +safehttpx==0.1.6 + # via gradio scikit-image==0.25.2 - # via eaa (pyproject.toml) + # via -r requirements.txt scikit-learn==1.6.1 - # via gpytorch + # via + # -r requirements.txt + # gpytorch scipy==1.15.2 # via - # eaa (pyproject.toml) + # -r requirements.txt # botorch # gpytorch # linear-operator # scikit-image # scikit-learn +semantic-version==2.10.0 + # via gradio setuptools==80.9.0 - # via triton + # via + # -r requirements.txt + # triton +shellingham==1.5.4 + # via typer six==1.17.0 - # via python-dateutil + # via + # -r requirements.txt + # python-dateutil sniffio==1.3.1 # via + # -r requirements.txt # anyio # openai +starlette==0.46.2 + # via + # fastapi + # gradio sympy==1.14.0 - # via torch + # via + # -r requirements.txt + # torch threadpoolctl==3.6.0 # via + # -r requirements.txt # botorch # scikit-learn tifffile==2025.3.30 - # via scikit-image + # via + # -r requirements.txt + # scikit-image +tomlkit==0.13.3 + # via gradio torch==2.7.0 # via + # -r requirements.txt # botorch # linear-operator # pyro-ppl tqdm==4.67.1 # via - # eaa (pyproject.toml) + # -r requirements.txt + # huggingface-hub # openai # pyro-ppl triton==3.3.0 - # via torch + # via + # -r requirements.txt + # torch +typer==0.16.0 + # via gradio typing-extensions==4.13.2 # via + # -r requirements.txt # anyio # botorch + # fastapi + # gradio + # gradio-client + # huggingface-hub # openai # pydantic # pydantic-core # pyre-extensions # torch + # typer # typing-inspect # typing-inspection typing-inspect==0.9.0 - # via pyre-extensions + # via + # -r requirements.txt + # pyre-extensions typing-inspection==0.4.0 - # via pydantic + # via + # -r requirements.txt + # pydantic +tzdata==2025.2 + # via pandas +urllib3==2.4.0 + # via requests +uvicorn==0.34.3 + # via gradio wadler-lindig==0.1.6 - # via jaxtyping + # via + # -r requirements.txt + # jaxtyping +websockets==15.0.1 + # via gradio-client diff --git a/src/eaa/agents/openai.py b/src/eaa/agents/openai.py index 9a46f1e..1e1f1b6 100644 --- a/src/eaa/agents/openai.py +++ b/src/eaa/agents/openai.py @@ -653,7 +653,11 @@ def resolve_json_type(py_type): } -def print_message(message: Dict[str, Any], response_requested: Optional[bool] = None) -> None: +def print_message( + message: Dict[str, Any], + response_requested: Optional[bool] = None, + return_string: bool = False +) -> None: """Print the message. Parameters @@ -662,6 +666,8 @@ def print_message(message: Dict[str, Any], response_requested: Optional[bool] = The message to be printed. response_requested : bool, optional Whether a response is requested for the message. + return_string : bool, optional + If True, the message is returned as a string instead of printed. """ color_dict = { "user": "\033[94m", @@ -692,4 +698,7 @@ def print_message(message: Dict[str, Any], response_requested: Optional[bool] = text += f"Arguments: {tool_call['function']['arguments']}\n" text += "\n ========================================= \n" - print(f"{color}{text}\033[0m") + if return_string: + return text + else: + print(f"{color}{text}\033[0m") diff --git a/src/eaa/gui/chat.py b/src/eaa/gui/chat.py new file mode 100644 index 0000000..cdaba35 --- /dev/null +++ b/src/eaa/gui/chat.py @@ -0,0 +1,103 @@ +"""WebUI based on Gradio. + +To use the WebUI, import `launch_gui` from `eaa.gui.chat` and call it after +creating the task manager. + +Example: +```python +from eaa.gui.chat import launch_gui + +task_manager = TaskManager(...) +launch_gui(task_manager) +``` +""" + +import threading +from typing import Optional +import re + +import gradio as gr + +from eaa.task_managers.base import BaseTaskManager +from eaa.util import decode_image_base64 +from eaa.agents.openai import print_message + + +class ChatUI: + def __init__(self, task_manager: BaseTaskManager): + self.task_manager = task_manager + self.chatbot: Optional[gr.Chatbot] = None + self.blocks: Optional[gr.Blocks] = None + self._setup_ui() + + def _setup_ui(self): + with gr.Blocks() as self.blocks: + self.chatbot = gr.Chatbot(type="messages") + + # Create a function to update the chat + def update_chat(): + context = self.task_manager.full_history + context_processed = [] + for i in range(len(context)): + if isinstance(context[i]["content"], list) and "type" in context[i]["content"][0]: + for item in context[i]["content"]: + if item["type"] == "image_url": + img_base64 = item["image_url"]["url"] + base64_data = re.sub('^data:image/.+;base64,', '', img_base64) + pil_image = decode_image_base64(base64_data, return_type="pil") + if pil_image.mode != "RGB": + pil_image = pil_image.convert("RGB") + gradio_im = gr.Image(pil_image) + context_processed.append({"content": gradio_im, "role": context[i]["role"]}) + elif item["type"] == "text": + context_processed.append({"content": item["text"], "role": context[i]["role"]}) + elif ( + "tool_calls" in context[i] + and isinstance(context[i]["tool_calls"], list) + and len(context[i]["tool_calls"]) > 0 + ): + tool_call_message = print_message(context[i], return_string=True) + context_processed.append({"content": tool_call_message, "role": context[i]["role"]}) + elif context[i]["role"] == "tool": + tool_response_message = print_message(context[i], return_string=True) + context_processed.append({"content": tool_response_message, "role": "user"}) + else: + if context[i]["content"] is None: + context[i]["content"] = "" + context_processed.append(context[i]) + return context_processed + + # Set up periodic updates + self.blocks.load(update_chat, None, self.chatbot, stream_every=1) + + def launch(self, **kwargs): + """Launch the UI in a non-blocking way""" + def run_server(): + self.blocks.launch(**kwargs) + + # Start the server in a separate thread + server_thread = threading.Thread(target=run_server) + server_thread.daemon = True # Make thread daemon so it exits when main program exits + server_thread.start() + + return server_thread + + +def launch_gui(task_manager: BaseTaskManager, **kwargs) -> threading.Thread: + """ + Launch the GUI in a non-blocking way. + + Parameters + ---------- + task_manager : BaseTaskManager + The task manager instance to use + **kwargs : dict + Additional arguments to pass to gr.Blocks.launch() + + Returns + ------- + threading.Thread + The thread running the Gradio server + """ + ui = ChatUI(task_manager) + return ui.launch(**kwargs) diff --git a/src/eaa/task_managers/base.py b/src/eaa/task_managers/base.py index e8ef998..0d89ca3 100644 --- a/src/eaa/task_managers/base.py +++ b/src/eaa/task_managers/base.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + from eaa.tools.base import BaseTool from eaa.comms import get_api_key from eaa.agents.openai import OpenAIAgent @@ -16,6 +18,7 @@ def __init__( *args, **kwargs ): self.context = [] + self.full_history = [] self.model = model_name self.model_base_url = model_base_url self.access_token = access_token @@ -99,6 +102,17 @@ def get_llm_config(self, *args, **kwargs): def prerun_check(self, *args, **kwargs) -> bool: return True + + def update_message_history( + self, + message: Dict[str, Any], + update_context: bool = True, + update_full_history: bool = True + ) -> None: + if update_context: + self.context.append(message) + if update_full_history: + self.full_history.append(message) def run(self, *args, **kwargs) -> None: self.prerun_check() @@ -110,6 +124,6 @@ def run_conversation(self, *args, **kwargs) -> None: if message.lower() == "exit": break response, outgoing_message = self.agent.receive(message, return_outgoing_message=True) - self.context.append(outgoing_message) - self.context.append(response) + self.update_message_history(outgoing_message, update_context=True, update_full_history=True) + self.update_message_history(response, update_context=True, update_full_history=True) message = input("Enter a message: ") \ No newline at end of file diff --git a/src/eaa/task_managers/imaging/base.py b/src/eaa/task_managers/imaging/base.py index 798154e..e517f70 100644 --- a/src/eaa/task_managers/imaging/base.py +++ b/src/eaa/task_managers/imaging/base.py @@ -109,8 +109,8 @@ def run_imaging_feedback_loop( image_path=initial_image_path, return_outgoing_message=True ) - self.context.append(outgoing) - self.context.append(response) + self.update_message_history(outgoing, update_context=True, update_full_history=True) + self.update_message_history(response, update_context=True, update_full_history=True) while round < max_rounds: if response["content"] is not None and "TERMINATE" in response["content"]: message = input( @@ -125,8 +125,8 @@ def run_imaging_feedback_loop( image_path=None, return_outgoing_message=True ) - self.context.append(outgoing) - self.context.append(response) + self.update_message_history(outgoing, update_context=True, update_full_history=True) + self.update_message_history(response, update_context=True, update_full_history=True) continue tool_responses, tool_response_types = self.agent.handle_tool_call(response, return_tool_return_types=True) @@ -136,7 +136,7 @@ def run_imaging_feedback_loop( # Just save the tool response, but don't send yet. We will send it # together with the image later. print_message(tool_response) - self.context.append(tool_response) + self.update_message_history(tool_response, update_context=True, update_full_history=True) if not tool_response_type == ToolReturnType.IMAGE_PATH: raise ValueError( @@ -150,9 +150,8 @@ def run_imaging_feedback_loop( context=self.context, return_outgoing_message=True ) - if store_all_images_in_context: - self.context.append(outgoing) - self.context.append(response) + self.update_message_history(outgoing, update_context=store_all_images_in_context, update_full_history=True) + self.update_message_history(response, update_context=True, update_full_history=True) elif len(tool_responses) > 1: response, outgoing = self.agent.receive( "There are more than one tool calls in your response. " @@ -162,8 +161,8 @@ def run_imaging_feedback_loop( context=self.context, return_outgoing_message=True ) - self.context.append(outgoing) - self.context.append(response) + self.update_message_history(outgoing, update_context=True, update_full_history=True) + self.update_message_history(response, update_context=True, update_full_history=True) else: response, outgoing = self.agent.receive( "There is no tool call in the response. Make sure you call the tool correctly.", @@ -171,7 +170,7 @@ def run_imaging_feedback_loop( context=self.context, return_outgoing_message=True ) - self.context.append(outgoing) - self.context.append(response) + self.update_message_history(outgoing, update_context=True, update_full_history=True) + self.update_message_history(response, update_context=True, update_full_history=True) round += 1 diff --git a/src/eaa/util.py b/src/eaa/util.py index 089279e..bd97551 100644 --- a/src/eaa/util.py +++ b/src/eaa/util.py @@ -1,8 +1,9 @@ +from typing import Tuple, Literal import datetime import base64 import io import re -from typing import Tuple +from io import BytesIO from PIL import Image import numpy as np @@ -96,6 +97,33 @@ def encode_image_base64( return base64_data +def decode_image_base64( + base64_data: str, + return_type: Literal["numpy", "pil"] = "numpy" +) -> np.ndarray | Image.Image: + """Decode a base64-encoded image to a NumPy array or PIL image. + + Parameters + ---------- + base64_data : str + The base64-encoded image data. + return_type : Literal["numpy", "pil"], optional + The type of the returned image. + + Returns + ------- + np.ndarray | Image.Image + The decoded image. + """ + pil_image = Image.open(BytesIO(base64.b64decode(base64_data))) + if return_type == "numpy": + return np.array(pil_image) + elif return_type == "pil": + return pil_image + else: + raise ValueError(f"Invalid return type: {return_type}") + + def get_image_path_from_text( text: str, return_text_without_image_tag: bool = False