Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions charge/_tags.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,55 @@
def verifier(func):
from typing import Callable

def verifier(func: Callable) -> Callable:
"""Decorator to mark a method as a verifier.

Args:
func (Callable): The function to mark as a verifier.

Returns:
Callable: The function marked as a verifier.
"""
func.__verifier_tag = True
return func


def is_verifier(func):
def is_verifier(func: Callable) -> bool:
"""Check if a method is marked as a verifier.

Args:
func (Callable): The function to check.

Returns:
bool: True if the function is marked as a verifier, False otherwise.
"""
return hasattr(func, "__verifier_tag")


def hypothesis(func):
def hypothesis(func: Callable) -> Callable:
"""Decorator to mark a method as a hypothesis.

Args:
func (Callable): The function to mark as a hypothesis.

Returns:
Callable: The function marked as a hypothesis.


Side Effects:
Marks the function as a hypothesis.
`func` is modified in place to include the `__hypothesis_tag` attribute.
"""
func.__hypothesis_tag = True
return func


def is_hypothesis(func):
def is_hypothesis(func: Callable) -> bool:
"""Check if a method is marked as a hypothesis.

Args:
func (Callable): The function to check.

Returns:
bool: True if the function is marked as a hypothesis, False otherwise.
"""
return hasattr(func, "__hypothesis_tag")
7 changes: 7 additions & 0 deletions charge/_to_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
def task_to_mcp(class_info, methods_list) -> str:
"""
Convert an Task class to an MCP server definition string.

Args:
class_info (dict): A dictionary containing information about the class.
methods_list (list): A list of methods to be converted to MCP server definition strings.

Returns:
str: A string representing the MCP server definition.
"""
return_str = ""
return_str += "from mcp.server.fastmcp import FastMCP\n"
Expand Down
8 changes: 5 additions & 3 deletions charge/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import readline
import inspect
import asyncio
from typing import Any


def enable_cmd_history_and_shell_integration(history: str):
def enable_cmd_history_and_shell_integration(history: str) -> None:
"""Enable persistent command-line history and integrate with the interactive shell.
Attempts to load an existing readline history file, sets the in-memory history
length to 1000, and registers an atexit handler to persist history on exit.
Expand All @@ -28,12 +28,14 @@ def enable_cmd_history_and_shell_integration(history: str):
atexit.register(readline.write_history_file, history)


async def maybe_await_async(var, *args, **kwargs):
async def maybe_await_async(var: Any, *args: Any, **kwargs: Any) -> Any:
"""Utility function to handle both synchronous and asynchronous callables or values.

Args:
var: A value, callable, or awaitable.
*args: Positional arguments to pass if var is callable.
**kwargs: Keyword arguments to pass if var is callable.

Returns:
The result of the callable or awaitable, or the value itself.
"""
Expand Down
58 changes: 57 additions & 1 deletion charge/clients/Client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
################################################################################
## Copyright 2025 Lawrence Livermore National Security, LLC. and Binghamton University.
## See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################
from typing import Type, Dict, Optional
from abc import ABC, abstractmethod
from charge.tasks.Task import Task
Expand All @@ -10,11 +16,25 @@
import argparse
import atexit
import readline
from charge.agents.Agent import Agent



class Client:
"""Base client class for orchestrating tasks and interacting with MCP servers.

Subclasses must implement configuration, execution, and interaction methods.
"""
def __init__(
self, task: Task, path: str = ".", max_retries: int = 3
):
"""Initialize the client with a task instance.

Args:
task: The Task object this client will manage.
path: Directory path for generated files.
max_retries: Maximum number of retry attempts for server communication.
"""
self.task = task
self.path = path
self.max_retries = max_retries
Expand All @@ -24,10 +44,15 @@ def __init__(
self._setup()

def reset(self):
"""Reset internal message and reasoning traces to start a fresh run."""
self.messages = []
self.reasoning_trace = []

def _setup(self):
"""Inspect the task class and collect verifier methods.

Populates ``self.verifier_methods`` with methods marked as verifiers.
"""
cls_info = inspect_class(self.task)
methods = inspect.getmembers(self.task, predicate=inspect.ismethod)
name = cls_info["name"]
Expand All @@ -46,6 +71,10 @@ def _setup(self):
self.verifier_methods = verifier_methods

def setup_mcp_servers(self):
"""Generate MCP server files for hypothesis and verifier methods.

Creates Python files containing MCP representations of task methods.
"""

class_info = inspect_class(self.task)
name = class_info["name"]
Expand Down Expand Up @@ -77,26 +106,53 @@ def setup_mcp_servers(self):

@abstractmethod
def configure(model: str, backend: str) -> (str, str, str, Dict[str, str]):
"""Configure the client with model and backend details.

Returns a tuple of (model, backend, additional_info, config_dict).
"""
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
async def run(self):
"""Execute the full task workflow.

Subclasses should implement the orchestration logic here.
"""
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
async def step(self, agent, task: str):
async def step(self, agent: Agent, task: str):
"""Perform a single step of the task using the given agent.

Args:
agent: The agent performing the step.
task: Description of the task step.
"""
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
async def chat(self):
"""Interactively chat with the orchestrator.

Subclasses should handle chat I/O.
"""
raise NotImplementedError("Subclasses must implement this method.")

@abstractmethod
async def refine(self, feedback: str):
"""Refine the task based on feedback.

Args:
feedback: Feedback string to adjust the task execution.
"""
raise NotImplementedError("Subclasses must implement this method.")

@staticmethod
def add_std_parser_arguments(parser: argparse.ArgumentParser):
"""Utility method to add standard command‑line arguments for the client.

Populates an ``argparse.ArgumentParser`` with common options.
"""
parser.add_argument(
"--model",
type=str,
Expand Down
2 changes: 1 addition & 1 deletion charge/clients/autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
timeout: int = 60,
backend: Optional[str] = None,
model_kwargs: Optional[dict] = None,
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(task=task, **kwargs)
self.max_retries = max_retries
Expand Down
7 changes: 7 additions & 0 deletions charge/experiments/Experiment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
################################################################################
## Copyright 2025 Lawrence Livermore National Security, LLC. and Binghamton University.
## See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################

from abc import abstractmethod
from typing import Any, List, Union, Optional
from charge.tasks.Task import Task
Expand Down
13 changes: 10 additions & 3 deletions charge/inspector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import inspect
from typing import Type


def inspect_class(cls):

def inspect_class(cls: Type) -> dict:
"""Inspect a class and return its type, name, and file path.

Args:
cls: The class to inspect.

Returns:
dict: A dictionary containing the type, name, and file path of the class.
"""
type_ = type(cls)
module = inspect.getmodule(cls.__class__)
file = module.__file__ if module else "Unknown"
Expand Down
2 changes: 1 addition & 1 deletion charge/servers/SMARTS_reactions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def verify_reaction_SMARTS(smarts: str) -> Tuple[bool, str]:
The bool indicates if the SMARTS is valid, and the str is an error message if it is not.

Args:
smiles (str): The input SMILES string.
smarts (str): The input SMARTS string.
Returns:
A tuple containing:
bool: True if the SMARTS is valid, False if it is invalid.
Expand Down
4 changes: 4 additions & 0 deletions charge/servers/SMILES_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def canonicalize_smiles(smiles: str) -> str:

Args:
smiles (str): The input SMILES string.

Returns:
str: The canonicalized SMILES string.
"""
Expand All @@ -51,6 +52,7 @@ def verify_smiles(smiles: str) -> bool:

Args:
smiles (str): The input SMILES string.

Returns:
bool: True if the SMILES is valid, False otherwise.
"""
Expand Down Expand Up @@ -81,6 +83,7 @@ def get_synthesizability(smiles: str) -> float:

Args:
smiles (str): The input SMILES string.

Returns:
float: The synthesizability score.
"""
Expand Down Expand Up @@ -110,6 +113,7 @@ def known_smiles(smiles: str) -> bool:

Args:
smiles (str): The input SMILES string.

Returns:
bool: True if the SMILES is known to this MCP server, False otherwise.
"""
Expand Down
7 changes: 7 additions & 0 deletions charge/servers/molecule_pricer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
################################################################################
## Copyright 2025 Lawrence Livermore National Security, LLC. and Binghamton University.
## See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################

from loguru import logger
try:
import chemprice
Expand Down
40 changes: 37 additions & 3 deletions charge/servers/server_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
################################################################################
## Copyright 2025 Lawrence Livermore National Security, LLC. and Binghamton University.
## See the top-level LICENSE file for details.
##
## SPDX-License-Identifier: Apache-2.0
################################################################################

import argparse
from mcp.server.fastmcp import FastMCP


def add_server_arguments(parser: argparse.ArgumentParser) -> None:
"""
Add standard server arguments to an argparse parser.

Args:
parser (argparse.ArgumentParser): The parser to add arguments to.

Returns:
None
"""
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
Expand All @@ -24,11 +35,28 @@ def add_server_arguments(parser: argparse.ArgumentParser) -> None:
)


def update_mcp_network(mcp: FastMCP, host: str, port: str):
def update_mcp_network(mcp: FastMCP, host: str, port: str) -> None:
"""
Update the MCP network settings.

Args:
mcp (FastMCP): The MCP server to update.
host (str): The host to run the server on.
port (str): The port to run the server on.

Returns:
None
"""
mcp.settings.host = host
mcp.settings.port = port

def get_hostname():
def get_hostname() -> Tuple[str, str]:
"""
Get the hostname and IP address of the host.

Returns:
Tuple[str, str]: The hostname and IP address of the host.
"""
import socket
hostname = socket.gethostname()
try:
Expand All @@ -37,7 +65,13 @@ def get_hostname():
host = "127.0.0.1"
return hostname, host

def try_get_public_hostname():
def try_get_public_hostname() -> Tuple[str, str]:
"""
Try to get the public hostname and IP address of the host.

Returns:
Tuple[str, str]: The public hostname and IP address of the host.
"""
import socket
hostname = socket.gethostname()
try:
Expand Down
Loading