diff --git a/.gitignore b/.gitignore
index 441facd..95e55f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -207,3 +207,9 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
+
+
+#
+autobolt/
+*.db-*
+*.db
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
index 8d41493..465a54d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -17,4 +17,5 @@ dependencies:
- pandas
- "autobolt @ git+https://github.com/sriyanc2001/AutoBolt.git"
- black
- - pytest
\ No newline at end of file
+ - pytest
+ - sqlalchemy
\ No newline at end of file
diff --git a/src/autoboltagent/VLLMModelCustom.py b/src/autoboltagent/VLLMModelCustom.py
new file mode 100644
index 0000000..242f66a
--- /dev/null
+++ b/src/autoboltagent/VLLMModelCustom.py
@@ -0,0 +1,97 @@
+from smolagents import VLLMModel
+from smolagents.tools import Tool
+from smolagents.monitoring import TokenUsage
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from smolagents.models import (
+ ChatMessage,
+ MessageRole,
+ remove_content_after_stop_sequences,
+
+)
+
+from typing import Any
+
+class VLLMModelCustom(VLLMModel):
+ def __init__(
+ self,
+ model_id,
+ model_kwargs: dict[str, Any] | None = None,
+ apply_chat_template_kwargs: dict[str, Any] | None = None,
+ sampling_params=None,
+ **kwargs,
+ ):
+ super().__init__(
+ model_id=model_id,
+ model_kwargs=model_kwargs,
+ apply_chat_template_kwargs=apply_chat_template_kwargs,
+ **kwargs
+ )
+ self.sampling_params = sampling_params
+
+ def generate(
+ self,
+ messages: list[ChatMessage | dict],
+ stop_sequences: list[str] | None = None,
+ response_format: dict[str, str] | None = None,
+ tools_to_call_from: list[Tool] | None = None,
+ **kwargs,
+ ) -> ChatMessage:
+ from vllm import SamplingParams # type: ignore
+ from vllm.sampling_params import StructuredOutputsParams # type: ignore
+
+ completion_kwargs = self._prepare_completion_kwargs(
+ messages=messages,
+ flatten_messages_as_text=(not self._is_vlm),
+ stop_sequences=stop_sequences,
+ tools_to_call_from=tools_to_call_from,
+ **kwargs,
+ )
+
+ prepared_stop_sequences = completion_kwargs.pop("stop", [])
+ messages = completion_kwargs.pop("messages")
+ tools = completion_kwargs.pop("tools", None)
+ completion_kwargs.pop("tool_choice", None)
+
+ if not self.sampling_params:
+ # Override the OpenAI schema for VLLM compatibility
+ structured_outputs = (
+ StructuredOutputsParams(json=response_format["json_schema"]["schema"]) if response_format else None
+ )
+
+
+ self.sampling_params = SamplingParams(
+ n=kwargs.get("n", 1),
+ temperature=kwargs.get("temperature", 0.0),
+ max_tokens=kwargs.get("max_tokens", 2048),
+ stop=prepared_stop_sequences,
+ structured_outputs=structured_outputs,
+ )
+
+
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tools=tools,
+ add_generation_prompt=True,
+ tokenize=False,
+ **self.apply_chat_template_kwargs,
+ )
+
+
+ out = self.model.generate(
+ prompt,
+ sampling_params=self.sampling_params,
+ **completion_kwargs,
+ )
+
+ output_text = out[0].outputs[0].text
+ if stop_sequences is not None and not self.supports_stop_parameter:
+ output_text = remove_content_after_stop_sequences(output_text, stop_sequences)
+ return ChatMessage(
+ role=MessageRole.ASSISTANT,
+ content=output_text,
+ raw={"out": output_text, "completion_kwargs": completion_kwargs},
+ token_usage=TokenUsage(
+ input_tokens=len(out[0].prompt_token_ids),
+ output_tokens=len(out[0].outputs[0].token_ids),
+ ),
+ )
\ No newline at end of file
diff --git a/src/autoboltagent/agents.py b/src/autoboltagent/agents.py
index 35a19fb..0fc6416 100644
--- a/src/autoboltagent/agents.py
+++ b/src/autoboltagent/agents.py
@@ -7,8 +7,10 @@
)
from .tools import AnalyticalTool, FiniteElementTool
+from .tools.logger import AgentLogger
-class GuessingAgent(smolagents.ToolCallingAgent):
+
+class GuessingAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that makes guesses without using any tools.
@@ -16,7 +18,7 @@ class GuessingAgent(smolagents.ToolCallingAgent):
It is designed to provide initial estimates or solutions based on its knowledge and reasoning capabilities.
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a GuessingAgent that does not use any tools.
@@ -33,7 +35,7 @@ def __init__(self, model: smolagents.Model) -> None:
)
-class LowFidelityAgent(smolagents.ToolCallingAgent):
+class LowFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes a low-fidelity analytical tool for bolted connection design.
@@ -41,24 +43,43 @@ class LowFidelityAgent(smolagents.ToolCallingAgent):
It is designed to provide solutions based on simplified models and assumptions, making it suitable for quick estimates and preliminary designs.
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model, agent_id: str, run_id: str, target_fos: float, agent_logger: AgentLogger|None = None, max_steps=20) -> None:
"""
Initializes a LowFidelityAgent that uses an analytical tool.
Args:
model: An instance of smolagents.Model to be used by the agent.
"""
+
+ self.agent_logger = agent_logger
+ self.agent_id = agent_id
+ self.run_id = run_id
+ self.target_fos = target_fos
+
+ callbacks = [self.log] if self.agent_logger else []
+
super().__init__(
name="LowFidelityAgent",
tools=[AnalyticalTool()],
add_base_tools=False,
model=model,
instructions=BASE_INSTRUCTIONS + TOOL_USING_INSTRUCTION,
+ step_callbacks = callbacks,
verbosity_level=2,
+ max_steps=max_steps
)
+ def log(self, step, agent):
+ if self.agent_logger and step.__class__.__name__ == "ActionStep":
+ self.agent_logger.log(
+ agent_id=self.agent_id,
+ run_id=self.run_id,
+ target_fos=self.target_fos,
+ action_step=step
+ )
+
-class HighFidelityAgent(smolagents.ToolCallingAgent):
+class HighFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes a high-fidelity finite element analysis tool for bolted connection design.
@@ -66,7 +87,7 @@ class HighFidelityAgent(smolagents.ToolCallingAgent):
It is designed to provide accurate and reliable solutions based on comprehensive models, making it suitable for
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a HighFidelityAgent that uses a finite element tool.
@@ -83,7 +104,7 @@ def __init__(self, model: smolagents.Model) -> None:
)
-class DualFidelityAgent(smolagents.ToolCallingAgent):
+class DualFidelityAgent(smolagents.agents.ToolCallingAgent):
"""
An agent that utilizes both low-fidelity and high-fidelity tools for bolted connection design.
@@ -91,7 +112,7 @@ class DualFidelityAgent(smolagents.ToolCallingAgent):
It is designed to provide solutions that balance speed and accuracy by using the low-fidelity tool
"""
- def __init__(self, model: smolagents.Model) -> None:
+ def __init__(self, model: smolagents.models.Model) -> None:
"""
Initializes a DualFidelityAgent that uses both analytical and finite element tools.
diff --git a/src/autoboltagent/grammars.py b/src/autoboltagent/grammars.py
new file mode 100644
index 0000000..ce4dfcd
--- /dev/null
+++ b/src/autoboltagent/grammars.py
@@ -0,0 +1,74 @@
+# grammar that outputs low-fidelity tool call or final_answer
+low_fidelity_agent_grammar = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": \"analytical_fos_calculation\", \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
+
+# grammar that outputs high-fidelity tool call or final_answer
+high_fidelity_agent_grammar = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": \"fea_fos_calculation\", \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
+
+# grammar that outputs both types of tool call and final_answer
+dual_fidelity_agent_grammar = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": ("\"analytical_fos_calculation\"" | "\"fea_fos_calculation\""), \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
+
+low_fidelity_agent_grammar_original = r"""
+root ::= "\n" payload "\n"
+payload ::= fos | final
+
+fos ::= "{\"name\": \"analytical_fos_calculation\", \"arguments\": " fos_args "}}"
+fos_args ::= "{" ( num_field ", " dia_field ) | ( dia_field ", " num_field ) "}"
+num_field ::= "\"num_bolts\": " int
+dia_field ::= "\"bolt_diameter\": " number
+
+final ::= "{\"name\": \"final_answer\", \"arguments\": {\"answer\": " fos_args "}}"
+
+int ::= digit+
+number ::= digit+ frac?
+frac ::= "." digit{1,4}
+
+digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
+"""
\ No newline at end of file
diff --git a/src/autoboltagent/tools/high_fidelity_tool.py b/src/autoboltagent/tools/high_fidelity_tool.py
index 02a7bf9..c43b9a4 100644
--- a/src/autoboltagent/tools/high_fidelity_tool.py
+++ b/src/autoboltagent/tools/high_fidelity_tool.py
@@ -1,10 +1,11 @@
import autobolt
import smolagents
+from typing import Dict, Any, Union, cast
from .inputs import INPUTS
-class FiniteElementTool(smolagents.Tool):
+class FiniteElementTool(smolagents.tools.Tool):
"""
A tool that calculates the factor of safety for a bolted connection using finite element analysis.
@@ -15,7 +16,8 @@ class FiniteElementTool(smolagents.Tool):
name = "fea_fos_calculation"
description = "Calculates the factor of safety using finite element analysis."
- inputs = INPUTS
+ input_type = dict[str, dict[str, Union[str, type, bool]]]
+ inputs: input_type = cast(input_type,INPUTS)
output_type = "number"
@@ -65,3 +67,79 @@ def forward(
comparison = "within acceptable range"
return f"The factor of safety for the assembly is {fos:.2f} ({comparison})."
+
+
+class VerboseFiniteElementTool(smolagents.tools.Tool):
+ """
+ A tool that calculates the factor of safety for a bolted connection using finite element analysis.
+
+ This tool leverages the autobolt library to perform finite element calculations and determine the factor of safety
+ for a bolted connection based on the provided parameters.
+ """
+
+ name = "fea_fos_calculation"
+ description = "Calculates the factor of safety using finite element analysis."
+
+ input_type = dict[str, dict[str, Union[str, type, bool]]]
+ inputs = {
+ "num_bolts": {
+ "type": "number",
+ "description": "Number of bolts used in the joint",
+ },
+ "bolt_diameter": {
+ "type": "number",
+ "description": "Diameter of the bolt in mm",
+ }
+ }
+
+ output_type = "object"
+
+ def __init__(self, joint_configuration: Dict[str, Any], tolerance: float = 0.1):
+ super().__init__()
+ self.tolerance = tolerance
+ self.desired_safety_factor = joint_configuration["desired_safety_factor"]
+ self.load = joint_configuration["load"]
+ self.preload = joint_configuration["preload"]
+ self.bolt_yield_strength = joint_configuration["bolt_yield_strength"]
+ self.bolt_elastic_modulus = joint_configuration["bolt_elastic_modulus"]
+ self.plate_thickness = joint_configuration["plate_thickness"]
+ self.plate_elastic_modulus = joint_configuration["plate_elastic_modulus"]
+ self.plate_yield_strength = joint_configuration["plate_yield_strength"]
+ self.pitch = joint_configuration["pitch"]
+
+ def forward(
+ self,
+ num_bolts: int,
+ bolt_diameter: float,
+ ) -> dict:
+
+ # Define dimensions of the plate
+ plate_width = 0.1 # [m]
+ plate_length = 0.2 # [m]
+
+ # Compute traction
+ traction = -self.load / (self.plate_thickness / 1000 * plate_length)
+
+ fos = autobolt.calculate_fos(
+ plate_thickness_m=self.plate_thickness / 1000,
+ num_holes=num_bolts,
+ elastic_modulus=self.plate_elastic_modulus * 10**9,
+ yield_strength=self.plate_yield_strength * 10**6,
+ traction_values=[(0, traction, 0)],
+ hole_radius_m=bolt_diameter / 2 / 1000,
+ plate_length_m=plate_length,
+ plate_width_m=plate_width,
+ edge_margin_m=plate_length / (2 * num_bolts),
+ hole_spacing_m=plate_length / num_bolts,
+ hole_offset_from_bottom_m=0.020, # [m] vertical position of hole centers (Y from bottom edge)
+ plate_gap_mm=0.01, # [mm] gap between the two plates
+ poissons_ratio=0.3, # Poisson's ratio for steel
+ )
+
+ diff = fos - self.desired_safety_factor
+
+ return {
+ "ok": bool((abs(diff) < self.tolerance)),
+ "fos": float(fos),
+ "diff": float(diff),
+ }
\ No newline at end of file
diff --git a/src/autoboltagent/tools/logger.py b/src/autoboltagent/tools/logger.py
new file mode 100644
index 0000000..e1e7db7
--- /dev/null
+++ b/src/autoboltagent/tools/logger.py
@@ -0,0 +1,126 @@
+from sqlalchemy import create_engine
+
+from sqlalchemy.orm import declarative_base, Session, sessionmaker, Mapped, mapped_column
+from contextlib import contextmanager
+
+from pathlib import Path
+
+from datetime import datetime, timezone
+
+Base = declarative_base()
+
+class Iteration(Base):
+ __tablename__ = "iterations"
+
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+ agent_id: Mapped[str] = mapped_column()
+ run_id: Mapped[str] = mapped_column()
+ iteration_no: Mapped[int] = mapped_column()
+
+ start_time: Mapped[datetime] = mapped_column(nullable=True)
+ end_time: Mapped[datetime] = mapped_column(nullable=True)
+
+ status: Mapped[str] = mapped_column(nullable=True)
+ tool_call: Mapped[str] = mapped_column(nullable=True)
+ observations: Mapped[str] = mapped_column(nullable=True)
+ target_fos: Mapped[float] = mapped_column(nullable=True)
+
+ failure_reason: Mapped[str] = mapped_column(nullable=True)
+ llm_output: Mapped[str] = mapped_column(nullable=True)
+ error_message: Mapped[str] = mapped_column(nullable=True)
+
+class AgentLogger:
+ _instance = None
+
+ def connect_to_db(self, db_url: str):
+ self.db_url = db_url
+ self.engine = create_engine(db_url, future=True, pool_pre_ping=True)
+
+ try:
+ with self.engine.connect() as conn:
+ conn.exec_driver_sql("PRAGMA journal_mode=WAL;")
+ conn.exec_driver_sql("PRAGMA synchronous=NORMAL;")
+
+ Base.metadata.create_all(self.engine)
+
+ self.db_session = sessionmaker(bind=self.engine, expire_on_commit=False, future=True)
+ except Exception as e:
+ raise IOError("Failed to connect to DB, check if file in use", repr(e))
+
+ def __new__(cls, db_url: str):
+ if not cls._instance:
+ cls._instance = super().__new__(cls)
+ cls._instance.db_url = None
+ cls._instance.engine = None
+ cls._instance.db_session = None
+ cls._instance.connect_to_db(db_url)
+ return cls._instance
+
+ @classmethod
+ def reset(cls):
+ if not cls._instance:
+ return
+
+ inst = cls._instance
+
+ if inst.engine:
+ inst.engine.dispose()
+
+ if inst.db_url:
+ for suffix in ("", "-wal", "-shm"):
+ file_path = Path(inst.db_url.replace("sqlite:///", "") + suffix)
+ file_path.unlink(missing_ok=True)
+
+ cls._instance = None
+
+ def log(
+ self,
+ run_id,
+ agent_id,
+ target_fos,
+ action_step
+ ):
+
+ iteration_no = action_step.step_number
+ start_dt = datetime.fromtimestamp(action_step.timing.start_time, tz=timezone.utc)
+ end_dt = datetime.fromtimestamp(action_step.timing.end_time, tz=timezone.utc)
+
+ error = getattr(action_step, "error", None)
+ tool_calls = getattr(action_step, "tool_calls", None)
+ observations = getattr(action_step, "observations", None)
+ llm_message = getattr(action_step, "model_output_message", None)
+ llm_output = getattr(llm_message, "content", None)
+
+ print(tool_calls)
+ print(error)
+ print(observations)
+ print(llm_message)
+ print(action_step.token_usage)
+
+ try:
+
+ with self.db_session() as session:
+ session.add(
+ Iteration(
+ run_id=run_id,
+ agent_id=agent_id,
+ iteration_no=iteration_no,
+ start_time=start_dt,
+ end_time=end_dt,
+ tool_call=str(tool_calls[0] if tool_calls else None),
+ observations=observations,
+ target_fos=target_fos,
+ llm_output=llm_output,
+ error_message=error.message if (error and error.message) else None
+ )
+ )
+ session.flush()
+ session.commit()
+
+ except Exception as e:
+ print("\n\n")
+ print(e)
+ print("\n\n")
+
+
+
\ No newline at end of file
diff --git a/src/autoboltagent/tools/low_fidelity_tool.py b/src/autoboltagent/tools/low_fidelity_tool.py
index e250a0e..b5a62f9 100644
--- a/src/autoboltagent/tools/low_fidelity_tool.py
+++ b/src/autoboltagent/tools/low_fidelity_tool.py
@@ -1,5 +1,7 @@
import smolagents
+from typing import Dict, Any, Union, cast
+
from .fastener_toolkit import (
get_joint_constant,
get_tensile_stress_area,
@@ -8,7 +10,7 @@
from .inputs import INPUTS
-class AnalyticalTool(smolagents.Tool):
+class AnalyticalTool(smolagents.tools.Tool):
"""
A tool that calculates the factor of safety for a bolted connection using analytical expressions.
@@ -18,8 +20,9 @@ class AnalyticalTool(smolagents.Tool):
name = "analytical_fos_calculation"
description = "Calculates the factor of safety using analytical expressions."
-
- inputs = INPUTS
+
+ input_type = dict[str, dict[str, Union[str, type, bool]]]
+ inputs: input_type = cast(input_type,INPUTS)
output_type = "number"
@@ -81,3 +84,82 @@ def forward(
f"The factor of safety for bolts is {bolt_fos:.2f} ({bolt_comparison}) and "
f"the factor of safety for plates is {plate_fos:.2f} ({plate_comparison})."
)
+
+
+class VerboseAnalyticalTool(smolagents.tools.Tool):
+ """
+ A tool that calculates the factor of safety for a bolted connection using analytical expressions.
+
+ This tool uses established engineering formulas to compute the factor of safety for a bolted connection
+ based on the provided parameters.
+ """
+
+ name = "analytical_fos_calculation"
+ description = "Calculates the factor of safety using analytical expressions."
+
+ inputs = {
+ "num_bolts": {
+ "type": "number",
+ "description": "Number of bolts used in the joint",
+ },
+ "bolt_diameter": {
+ "type": "number",
+ "description": "Diameter of the bolt in mm",
+ }
+ }
+
+ output_type = "object"
+
+ def __init__(self, joint_configuration: Dict[str, Any], tolerance: float = 0.1):
+ super().__init__()
+ self.tolerance = tolerance
+ self.desired_safety_factor = joint_configuration["desired_safety_factor"]
+ self.load = joint_configuration["load"]
+ self.preload = joint_configuration["preload"]
+ self.bolt_yield_strength = joint_configuration["bolt_yield_strength"]
+ self.bolt_elastic_modulus = joint_configuration["bolt_elastic_modulus"]
+ self.plate_thickness = joint_configuration["plate_thickness"]
+ self.plate_elastic_modulus = joint_configuration["plate_elastic_modulus"]
+ self.plate_yield_strength = joint_configuration["plate_yield_strength"]
+ self.pitch = joint_configuration["pitch"]
+
+ def forward(
+ self,
+ num_bolts: int,
+ bolt_diameter: float,
+ ) -> dict:
+
+ load_per_bolt = self.load / num_bolts
+ preload_per_bolt = self.preload / num_bolts
+ tensile_area = get_tensile_stress_area(bolt_diameter, self.pitch)
+
+ c = get_joint_constant(
+ bolt_diameter,
+ self.plate_thickness * 2,
+ self.plate_elastic_modulus,
+ self.bolt_elastic_modulus,
+ )
+
+ bolt_fos = bolt_yield_safety_factor(
+ c=c,
+ load=load_per_bolt,
+ preload=preload_per_bolt,
+ a_ts=tensile_area,
+ b_ys=self.bolt_yield_strength,
+ )
+
+ bearing_area = bolt_diameter * self.plate_thickness * num_bolts
+ bearing_stress = self.load / bearing_area
+ allowable_bearing_stress = 1.5 * self.plate_yield_strength
+ plate_fos = allowable_bearing_stress / bearing_stress
+
+ bolt_diff = bolt_fos - self.desired_safety_factor
+ plate_diff = plate_fos - self.desired_safety_factor
+
+ return {
+ "ok": (abs(bolt_diff) < self.tolerance) and (abs(plate_diff) < self.tolerance),
+ "bolt_fos": bolt_fos,
+ "bolt_diff": bolt_diff,
+ "plate_fos": plate_fos,
+ "plate_diff": plate_diff
+ }
\ No newline at end of file
diff --git a/src/autoboltagent/verbose_agents.py b/src/autoboltagent/verbose_agents.py
new file mode 100644
index 0000000..d8c9286
--- /dev/null
+++ b/src/autoboltagent/verbose_agents.py
@@ -0,0 +1,49 @@
+import smolagents
+
+from .verbose_prompts import (
+ TOOL_USING_INSTRUCTION,
+ SIMPLIFIED_TOOL_USING_INSTRUCTION,
+ BASE_INSTRUCTIONS,
+ INPUT_FORMAT,
+ OUTPUT_FORMAT,
+ LOW_FIDELITY_TOOL_INSTRUCTION,
+ EXAMPLE_TASK_INSTRUCTIONS
+)
+from .tools.low_fidelity_tool import VerboseAnalyticalTool
+
+from .tools.logger import AgentLogger
+
+
+class VerboseLowFidelityAgent(smolagents.agents.ToolCallingAgent):
+ def __init__(self, model: smolagents.models.Model, joint_configuration: dict, agent_id: str, run_id: str, target_fos: float, agent_logger: AgentLogger|None = None) -> None:
+
+ self.agent_logger = agent_logger
+ self.agent_id = agent_id
+ self.run_id = run_id
+ self.target_fos = target_fos
+
+ callbacks = [self.log] if self.agent_logger else []
+ super().__init__(
+ name="VerboseLowFidelityAgent",
+ tools=[VerboseAnalyticalTool(joint_configuration)],
+ add_base_tools=False,
+ model=model,
+ instructions=(
+ BASE_INSTRUCTIONS +
+ INPUT_FORMAT +
+ OUTPUT_FORMAT +
+ SIMPLIFIED_TOOL_USING_INSTRUCTION +
+ LOW_FIDELITY_TOOL_INSTRUCTION
+ ),
+ step_callbacks=callbacks,
+ verbosity_level=2,
+ )
+
+ def log(self, step, agent):
+ if self.agent_logger and step.__class__.__name__ == "ActionStep":
+ self.agent_logger.log(
+ agent_id=self.agent_id,
+ run_id=self.run_id,
+ target_fos=self.target_fos,
+ action_step=step
+ )
\ No newline at end of file
diff --git a/src/autoboltagent/verbose_prompts.py b/src/autoboltagent/verbose_prompts.py
new file mode 100644
index 0000000..a663bdf
--- /dev/null
+++ b/src/autoboltagent/verbose_prompts.py
@@ -0,0 +1,331 @@
+# Base prompt for the agent
+BASE_INSTRUCTIONS = """
+# BASE INSTRUCTIONS
+
+You are a mechanical engineering expert specializing in the design of bolted connections.
+You will be given tasks that require you to determine the number and size of bolts to achieve a required factor of safety.
+Work iteratively to refine your solution.
+Before you complete the task, you must satisfy the following requirements:
+- The output of the analytical tool must have ok==True.
+- You must recommend both a bolt size (diameter) and the number of bolts.
+
+# HARD TERMINATION GATE (NON-NEGOTIABLE)
+
+You have exactly two allowed tool calls:
+1) analytical_fos_calculation(num_bolts, bolt_diameter)
+2) final_answer(answer)
+
+Rule A (no early final):
+- You MUST NOT call final_answer unless the most recent tool observation contains: ok == True.
+- If ok == False, calling final_answer is a failure.
+-
+
+Rule B (forced continuation):
+- If the most recent tool observation has ok == False, your next message MUST be a tool call to analytical_fos_calculation.
+- Do not explain, do not summarize, do not output any recommendation when ok == False.
+
+Rule C (final schema):
+- final_answer MUST be called exactly as:
+ {"name":"final_answer","arguments":{"answer":{"num_bolts":,"bolt_diameter":}}}
+- The field "answer" MUST be an object, NOT a string.
+"""
+# The factor of safety for both the bolt and the plate is within +/-0.1 of the target value.
+INPUT_FORMAT = """
+# INPUT FORMAT
+
+You will be given a json-like object called "joint_configuration" with some fields corresponding to a joint configuration. The fields are listed below:
+
+ - load: (number) The external load force, in Newtons (N)
+ - desired_safety_factor: (number) the desired FOS number
+ - bolt_yield_strength: (number) the yield strength of the bolt material, in MegaPascals (MPa)
+ - plate_yield_strength: (number) the yield strength of the plate material, in MegaPascals (MPa)
+ - preload: the force of preload per joint, in Newtons (N)
+ - pitch: (number) thread pitch in mm
+ - plate_thickness: (number) plate thickness in mm
+ - bolt_elastic_modulus: (number) elastic modulus of bolt, in GigaPascals (GPa)
+ - plate_elastic_modulus: (number) elastic modulus of plate material, in GigaPascals (GPa)
+
+Below is an example of a valid input:
+
+joint_configuration = {
+ "load": 60000,
+ "desired_safety_factor": 3.0,
+ "bolt_yield_strength": 940,
+ "plate_yield_strength": 250,
+ "preload": 150000,
+ "pitch": 1.5,
+ "plate_thickness": 10,
+ "bolt_elastic_modulus": 210,
+ "plate_elastic_modulus": 210
+}
+"""
+
+# Instructions for using tools
+TOOL_USING_INSTRUCTION = """
+# TOOL INSTRUCTIONS
+
+You will be given some tool(s) that use different methods to calculate the FOS of a joint configuration.
+
+They will be called with the following function signature:
+
+tool_call(
+ desired_safety_factor: float,
+ load: float,
+ preload: float,
+ num_bolts: int,
+ bolt_diameter: float,
+ bolt_yield_strength: float,
+ bolt_elastic_modulus: float,
+ plate_thickness: float,
+ plate_elastic_modulus: float,
+ plate_yield_strength: float,
+ pitch: float
+)
+
+These inputs include the specifications of the joint_configuration as well as num_bolts and bolt_diameter (in mm), which you will need to supply. You MUST include every parameter in the signature. Partial calls are strictly forbidden and will result in failure.
+The output of the tool will be a python dictionary with two fields: bolt_fos and plate_fos, which refer to the factor of safety for the bolt and plate respectively.
+
+When calling the tool, you must copy all joint_configuration fields exactly as provided. Do not change, round, “correct,” or infer any value (including pitch).
+"""
+
+SIMPLIFIED_TOOL_USING_INSTRUCTION = """
+# TOOL INSTRUCTIONS
+
+You will be given one or more tools that tell you if the factor of safety (FOS) for a joint configuration is within target.
+
+The tool is called using the following EXACT format:
+
+
+{"name":"","arguments":{"num_bolts":,"bolt_diameter":}}
+
+"""
+
+LOW_FIDELITY_TOOL_INSTRUCTION = """
+## analytical_fos_calculation
+
+This tool is the low-fidelity tool that uses analytical methods to calculate the FOS for the joint configuration and determine if it is within tolerance. It is computationally efficient but may be lacking in accuracy.
+"""
+
+HIGH_FIDELITY_TOOL_INSTRUCTION = """
+## fea_fos_calculation
+
+This tool is the high-fidelity tool that uses a computationally intensive but very accurate finite element analysis method to calculate the FOS.
+"""
+
+TOOL_OUTPUT_FORMAT = """
+# TOOL OUTPUT FORMAT
+
+The output of the FOS tool will be a json-like object with the following fields:
+
+- ok: (bool) this field is true if the bolt_fos and plate_fos values are within tolerance, and false otherwise. If ok is true, then the task is complete.
+- bolt_fos: (float) the calculated FOS value for the bolt
+- bolt_diff: (float) the signed difference between the calculated bolt FOS and the desired FOS.
+- plate_fos: (float) the calculated FOS value for the plate
+- plate_diff: (float) the signed difference between the calculated plate FOS and the desired FOS.
+
+Below is an example output from a FOS tool:
+
+{
+ 'ok': False,
+ 'bolt_fos': 0.00667551832522406,
+ 'bolt_diff': -2.993324481674776,
+ 'plate_fos': 0.5,
+ 'plate_diff': -2.5
+}
+"""
+
+SEARCH_RULES = """
+# SEARCH RULES (2-PHASE, PLATE THEN BOLT)
+
+Goal: make ok == True.
+
+Definitions (from tool output):
+- bolt_diff = bolt_fos - desired_safety_factor
+- plate_diff = plate_fos - desired_safety_factor
+- tolerance = 0.1
+
+Important structure:
+- plate_fos depends ONLY on (num_bolts * bolt_diameter), NOT preload or moduli.
+- Therefore, first satisfy plate_diff, then hold plate_fos ~ constant while tuning bolt_fos.
+
+PHASE 0 (plate target product):
+Compute the target capacity-product:
+ target_dn = (desired_safety_factor * load) / (1.5 * plate_yield_strength * plate_thickness)
+
+PHASE 1 (plate bracketing on capacity):
+- Keep bolt_diameter fixed initially.
+- Adjust num_bolts aggressively until plate_diff changes sign (bracket plate).
+ * if plate_diff < 0: increase num_bolts by +6 (cap 40)
+ * if plate_diff > 0: decrease num_bolts by -6 (floor 2)
+- Once bracketed, bisect num_bolts until abs(plate_diff) <= tolerance.
+
+PHASE 2 (bolt tuning with plate held near target):
+- Keep (num_bolts * bolt_diameter) approximately constant near target_dn.
+- If bolt_diff < 0 (bolt too weak), increase num_bolts and decrease bolt_diameter to keep num_bolts*bolt_diameter ~ constant.
+- If bolt_diff > 0 (bolt too strong), decrease num_bolts and increase bolt_diameter to keep product ~ constant.
+- After each move, re-check plate_diff; if abs(plate_diff) > tolerance, do ONE correction step to restore plate by nudging num_bolts (keeping diameter fixed).
+
+Constraints:
+- num_bolts ∈ [2,40], bolt_diameter ∈ [3.0,40.0]
+- Never reuse an exact (num_bolts, bolt_diameter) pair.
+"""
+
+SEARCH_RULES_1 = """
+# SEARCH RULES (BRACKET + BISECT, NO EXTRA STATE)
+
+Goal: make 'ok' == True in the tool output
+
+Definitions:
+- bolt_diff and plate_diff are provided by the tool.
+- controlling_diff = bolt_diff if abs(bolt_diff) >= abs(plate_diff) else plate_diff.
+
+Memory rule (use transcript only):
+- Treat each tool observation as a data point:
+ (num_bolts, bolt_diameter, bolt_diff, plate_diff).
+
+Bracketing rule:
+- Find the most recent previous data point in the transcript whose controlling_diff has the OPPOSITE SIGN
+ of the current controlling_diff.
+- If such a point exists, you have a bracket.
+
+Update rule:
+A) If you HAVE a bracket:
+ - Next guess MUST be the midpoint between the two bracket endpoints for EXACTLY ONE variable:
+ * If both points have the same bolt_diameter, bisect num_bolts:
+ num_bolts_next = round((n_low + n_high)/2)
+ bolt_diameter_next = current bolt_diameter
+ * Otherwise, bisect bolt_diameter:
+ bolt_diameter_next = round((d_low + d_high)/2, 2)
+ num_bolts_next = current num_bolts
+ - This guarantees zig-zag and shrinking steps.
+
+B) If you DO NOT have a bracket yet:
+ - Make a LARGE move to force a sign change in controlling_diff:
+ * If controlling_diff < 0, increase capacity: num_bolts += 6 (cap at 40).
+ * If controlling_diff > 0, decrease capacity: num_bolts -= 6 (floor at 2).
+ - Keep bolt_diameter fixed until the first bracket is found.
+
+Bounds:
+- num_bolts in [2, 40], bolt_diameter in [3.0, 40.0].
+"""
+
+OUTPUT_FORMAT = """
+# OUTPUT FORMAT (STRICT)
+
+You will output your reasoning and a tool call. Your reasoning must be at most 512 characters, and must be very specifically explain why you are chosen your numbers. Cite the FOS and diff numbers from the tool call.
+
+# FINAL ANSWER GATE (STRICT)
+
+Before you call ANY tool, you must check the most recent observation:
+
+- If there is no observation yet: call analytical_fos_calculation.
+- If the most recent observation has ok == False:
+ You MUST call analytical_fos_calculation next.
+ Calling final_answer here is INVALID and will be graded as FAILURE.
+
+- If the most recent observation has ok == True:
+ You MUST call final_answer next.
+
+If ok == False:
+Output ONLY your reasoning (512 chars MAX) and the tool call:
+
+reason:
+
+{"name":"analytical_fos_calculation","arguments":{"num_bolts": , "bolt_diameter": }}
+
+
+If ok == True:
+Output ONLY:
+
+reason:
+
+{"name":"final_answer", "arguments":{"answer":{"num_bolts": , "bolt_diameter": }}}
+
+"""
+
+FOS_CONTEXT = """
+
+"""
+
+# Instructions specific to dual-fidelity agent
+DUAL_FIDELITY_COORDINATION = """
+
+You should also note that you have access to a low-fidelity analytical tool and a high-fidelity finite element analysis tool.
+- Use the low-fidelity tool for quick initial estimates and to explore different design options.
+- Use the high-fidelity tool to validate and refine your designs.
+"""
+
+MINIMAL_PROMPT = """
+# BOLTED JOINT DESIGN TASK
+
+You must determine the number and size of bolts to achieve a target safety factor (FOS) for both the bolt and plate. Use tool calls iteratively. Do not guess final values until ok==True in the tool response.
+
+---
+
+# INPUT FORMAT
+
+You will receive a dictionary named `joint_configuration` containing:
+- load, desired_safety_factor, bolt_yield_strength, plate_yield_strength
+- preload, pitch, plate_thickness, bolt_elastic_modulus, plate_elastic_modulus
+
+These values are fixed and used in tool calls.
+
+---
+
+# TOOL USAGE
+
+Use only this tool format:
+```json
+{"name": "analytical_fos_calculation", "arguments": {"num_bolts": , "bolt_diameter": }}
+```
+
+The tool will return:
+```json
+{"bolt_fos": , "plate_fos": }
+```
+
+You must compute:
+```python
+bolt_diff = bolt_fos - desired_safety_factor
+plate_diff = plate_fos - desired_safety_factor
+controlling_diff = bolt_diff if abs(bolt_diff) >= abs(plate_diff) else plate_diff
+sign = -1 if controlling_diff < 0 else +1
+```
+
+---
+
+# SEARCH RULES
+
+- Keep `bolt_diameter` fixed throughout. Search only on `num_bolts ∈ [2, 40]`.
+- Never change `num_bolts` by ±1 or reuse a tried value.
+
+**Phase 1: Bracketing**
+- Start from any `num_bolts`.
+- If `sign = -1`: increase `num_bolts` by 12.
+- If `sign = +1`: decrease `num_bolts` by 12.
+- Continue until you’ve tried two values with opposite signs → this forms a bracket.
+
+**Phase 2: Bisection**
+- Once a bracket exists, bisect it:
+ `num_bolts = round((low + high)/2)`
+- If you get stuck (same sign 2× in a row and no improvement), increase jump to ±18 once.
+
+---
+
+# TERMINATION
+
+When both diffs are within ±0.1, output only:
+```json
+{"name": "final_answer", "arguments": {"answer": {"num_bolts": , "bolt_diameter": }}}
+```
+
+"""
+
+# Instructions for an example design task
+EXAMPLE_TASK_INSTRUCTIONS = """
+Given the following joint configuration:
+
+joint_configuration = {}
+
+Determine the optimal number of bolts and the major diameter of the bolts:
+"""
diff --git a/tests/test_agents.py b/tests/test_agents.py
index d002ba4..4fe66f3 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -10,7 +10,7 @@ def is_macos() -> bool:
return platform.system() == "Darwin"
-def get_testing_model() -> smolagents.Model:
+def get_testing_model() -> smolagents.models.Model:
if is_macos():
# Use a local model on macOS for faster testing
return smolagents.MLXModel(
@@ -18,7 +18,7 @@ def get_testing_model() -> smolagents.Model:
)
else:
# Use the smallest Instruct model available for fast CI feedback
- return smolagents.TransformersModel(
+ return smolagents.models.TransformersModel(
model_id="HuggingFaceTB/SmolLM-135M-Instruct",
max_new_tokens=200, # Keep generation short for speed
)
@@ -38,7 +38,15 @@ def test_guessing_agent():
def test_low_fidelity_agent():
# Create the LowFidelityAgent and run it
- response = autoboltagent.LowFidelityAgent(get_testing_model()).run(
+ agent = autoboltagent.LowFidelityAgent(
+ model=get_testing_model(),
+ agent_id="low fidelity agent",
+ run_id="test 1",
+ target_fos=3.0,
+ max_steps=5
+ )
+
+ response = agent.run(
autoboltagent.prompts.EXAMPLE_TASK_INSTRUCTIONS
)
diff --git a/tests/test_logger.py b/tests/test_logger.py
new file mode 100644
index 0000000..23179d2
--- /dev/null
+++ b/tests/test_logger.py
@@ -0,0 +1,103 @@
+from autoboltagent.tools.logger import AgentLogger
+from autoboltagent.tools.logger import Iteration
+from smolagents import ActionStep, Timing, ToolCall, ChatMessage, MessageRole, AgentError
+from sqlalchemy import create_engine, select
+from sqlalchemy.orm import Session
+from datetime import datetime, timezone
+import pytest
+
+db_url = "sqlite:///agent_logs_test.db"
+
+@pytest.fixture
+def logger():
+ """
+ Fixture to create a new AgentLogger with a fresh db per test
+ """
+ AgentLogger.reset()
+ logger = AgentLogger(db_url)
+ yield logger
+ AgentLogger.reset()
+
+def get_log_session(db_url):
+ """
+ Helper function to get db session from url
+ """
+ engine = create_engine(db_url)
+ return Session(engine)
+
+def test_logger_empty_when_fresh(logger):
+ """
+ Test to see if the logger db is empty when fresh
+ """
+ with get_log_session(db_url) as session:
+ iterations = session.query(Iteration).all()
+
+ assert len(iterations) == 0
+
+def test_logger_one_write_fields_persist(logger):
+ """
+ Test one log write and see if all fields persists in the db
+ """
+ start_time = datetime.now(timezone.utc).timestamp()
+ end_time = datetime.now(timezone.utc).timestamp()
+
+ step = ActionStep(
+ step_number=1,
+ timing=Timing(
+ start_time=start_time,
+ end_time=end_time
+ ),
+ observations = "observation observation",
+ tool_calls=[ToolCall(name="tool call", arguments={"asdf": 1}, id="1341fad")],
+ model_output_message=ChatMessage(role=MessageRole("assistant"), content="LLM output 1"),
+ error=None
+ )
+
+ logger.log(
+ run_id="run_1",
+ agent_id="agent_1",
+ target_fos=1,
+ action_step=step
+ )
+
+ with get_log_session(db_url) as session:
+ iteration = session.query(Iteration).one()
+
+ assert iteration.run_id == "run_1"
+ assert iteration.agent_id == "agent_1"
+ assert iteration.iteration_no == 1
+ assert type(iteration.start_time) == datetime
+ assert type(iteration.end_time) == datetime
+ assert iteration.target_fos == 1
+ assert iteration.llm_output == "LLM output 1"
+ assert iteration.error_message == None
+
+
+def test_logger_large_write(logger):
+ """
+ Test 50 log writes and see if db contains 50 rows
+ """
+ for i in range(50):
+ start_time = datetime.now(timezone.utc).timestamp()
+ end_time = datetime.now(timezone.utc).timestamp()
+ step = ActionStep(
+ step_number=1,
+ timing=Timing(
+ start_time=start_time,
+ end_time=end_time
+ ),
+ observations = "observation observation",
+ tool_calls=[ToolCall(name="tool call", arguments={"asdf": 1}, id="1341fad")],
+ model_output_message=ChatMessage(role=MessageRole("assistant"), content=f"LLM output {i}"),
+ error=None
+ )
+ logger.log(
+ run_id="run_1",
+ agent_id="agent_1",
+ target_fos=1,
+ action_step=step
+ )
+ with get_log_session(db_url) as session:
+ iterations = session.query(Iteration).all()
+
+ assert len(iterations) == 50
\ No newline at end of file
diff --git a/utilities/agent_testing.ipynb b/utilities/agent_testing.ipynb
new file mode 100644
index 0000000..3354bb6
--- /dev/null
+++ b/utilities/agent_testing.ipynb
@@ -0,0 +1,288 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "1efd1f50",
+ "metadata": {},
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aa00e50c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from autoboltagent import verbose_prompts, grammars\n",
+ "from autoboltagent import prompts\n",
+ "from autoboltagent.verbose_agents import VerboseLowFidelityAgent\n",
+ "from autoboltagent.agents import LowFidelityAgent\n",
+ "from autoboltagent.tools.logger import AgentLogger\n",
+ "from autoboltagent.VLLMModelCustom import VLLMModelCustom"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "85dcc950",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from vllm import SamplingParams\n",
+ "from vllm.sampling_params import StructuredOutputsParams\n",
+ "import smolagents"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fb89c141",
+ "metadata": {},
+ "source": [
+ "# Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a87e3fec",
+ "metadata": {},
+ "source": [
+ "### set up logger and db"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c84a2a1a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "db_url = \"sqlite:///../src/agent_logs_grammar_prod.db\"\n",
+ "logger = AgentLogger(db_url)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d00aa648",
+ "metadata": {},
+ "source": [
+ "### params"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "653c34d2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "joint_configuration = {\n",
+ " \"load\": 60000,\n",
+ " \"desired_safety_factor\": 3.0,\n",
+ " \"bolt_yield_strength\": 940,\n",
+ " \"plate_yield_strength\": 250,\n",
+ " \"preload\": 150000,\n",
+ " \"pitch\": 1.5,\n",
+ " \"plate_thickness\": 10,\n",
+ " \"bolt_elastic_modulus\": 210,\n",
+ " \"plate_elastic_modulus\": 210\n",
+ " }\n",
+ "\n",
+ "input = \"\"\"{\n",
+ " \"load\": 60000,\n",
+ " \"desired_safety_factor\": 3.0,\n",
+ " \"bolt_yield_strength\": 940,\n",
+ " \"plate_yield_strength\": 250,\n",
+ " \"preload\": 150000,\n",
+ " \"pitch\": 1.5,\n",
+ " \"plate_thickness\": 10,\n",
+ " \"bolt_elastic_modulus\": 210,\n",
+ " \"plate_elastic_modulus\": 210\n",
+ " }\"\"\"\n",
+ "\n",
+ "grammar_sop = StructuredOutputsParams(\n",
+ " grammar=grammars.low_fidelity_agent_grammar_debug\n",
+ ")\n",
+ "\n",
+ "sampling_params = SamplingParams(\n",
+ " max_tokens=200,\n",
+ " temperature=0.0,\n",
+ " structured_outputs=grammar_sop\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8f17629d",
+ "metadata": {},
+ "source": [
+ "# Models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8e3c23e",
+ "metadata": {},
+ "source": [
+ "### Local"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4143950e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = smolagents.VLLMModel(\n",
+ " model_id=\"RedHatAI/Qwen2.5-3B-Instruct-quantized.w8a8\",\n",
+ " model_kwargs={\n",
+ " \"gpu_memory_utilization\": 0.85,\n",
+ " },\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "00cec7de",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Custom local VLLM model with grammar\n",
+ "\n",
+ "model = VLLMModelCustom(\n",
+ " model_id=\"RedHatAI/Qwen2.5-3B-Instruct-quantized.w8a8\",\n",
+ " apply_chat_template_kwargs=None,\n",
+ " model_kwargs={\n",
+ " \"gpu_memory_utilization\": 0.85,\n",
+ " },\n",
+ " sampling_params=sampling_params,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "278e5eeb",
+ "metadata": {},
+ "source": [
+ "### Cloud models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ce9ba947",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "FIREWORKS_API_KEY = \"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a14fa02a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# smolagents InferenceClientModel\n",
+ "\n",
+ "model = smolagents.InferenceClientModel( # type: ignore\n",
+ " provider=\"fireworks-ai\",\n",
+ " model_id=\"openai/gpt-oss-20b\",\n",
+ " token=FIREWORKS_API_KEY,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "83fce9b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = smolagents.OpenAIServerModel(\n",
+ " model_id=\"accounts/fireworks/models/gpt-oss-20b\",\n",
+ " api_base=\"https://api.fireworks.ai/inference/v1\", \n",
+ " api_key=FIREWORKS_API_KEY,\n",
+ " # response_format={\n",
+ " # \"type\": \"grammar\",\n",
+ " # \"grammar\": grammars.low_fidelity_agent_grammar_debug\n",
+ " # }\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "92627081",
+ "metadata": {},
+ "source": [
+ "# Run agent"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "05023b62",
+ "metadata": {},
+ "source": [
+ "### Single agent run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "96b94fe8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "agent = VerboseLowFidelityAgent(model, joint_configuration, \"verbose low fidelity agent\", f\"verbose prompts + minimized + reason 512 + gpt 11\", 3.0, logger, max_steps=100)\n",
+ "instruction = verbose_prompts.EXAMPLE_TASK_INSTRUCTIONS.format(input)\n",
+ "agent.run(instruction)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4207a92",
+ "metadata": {},
+ "source": [
+ "### Loop agent run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "691f9b15",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for i in range(0,25):\n",
+ " agent = VerboseLowFidelityAgent(model, joint_configuration, \"verbose low fidelity agent\", f\"verbose prompts + minimized + reason 512 + gpt {i}\", 3.0, logger, max_steps=100)\n",
+ " instruction = verbose_prompts.EXAMPLE_TASK_INSTRUCTIONS.format(input)\n",
+ " agent.run(instruction)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "autoboltagent",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/utilities/tool_testing.ipynb b/utilities/tool_testing.ipynb
new file mode 100644
index 0000000..10a05a7
--- /dev/null
+++ b/utilities/tool_testing.ipynb
@@ -0,0 +1,208 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "de3fb190",
+ "metadata": {},
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "44771eae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from autoboltagent.tools.low_fidelity_tool import VerboseAnalyticalTool\n",
+ "from autoboltagent.tools.high_fidelity_tool import VerboseFiniteElementTool"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cffaf6f6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import deque\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "acf3d0aa",
+ "metadata": {},
+ "source": [
+ "# Manual Guessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0f96da72",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "joint_config = {\n",
+ " \"load\": 60000,\n",
+ " \"desired_safety_factor\": 3.0,\n",
+ " \"bolt_yield_strength\": 940,\n",
+ " \"plate_yield_strength\": 250,\n",
+ " \"preload\": 150000,\n",
+ " \"pitch\": 1.5,\n",
+ " \"plate_thickness\": 10,\n",
+ " \"bolt_elastic_modulus\": 210,\n",
+ " \"plate_elastic_modulus\": 210\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bf28dd68",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "analytical_tool = VerboseAnalyticalTool(joint_configuration=joint_config)\n",
+ "fea_tool = VerboseFiniteElementTool(joint_configuration=joint_config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6897da3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cached_analytical_guess(past_guesses, num_bolts, bolt_diameter):\n",
+ " result = analytical_tool.forward(num_bolts=num_bolts, bolt_diameter=bolt_diameter)\n",
+ " past_guesses.appendleft((num_bolts, bolt_diameter, result))\n",
+ " \n",
+ " for row in past_guesses:\n",
+ " print(row)\n",
+ "\n",
+ "def cached_fea_guess(past_guesses, num_bolts, bolt_diameter):\n",
+ " result = fea_tool.forward(num_bolts=num_bolts, bolt_diameter=bolt_diameter)\n",
+ " past_guesses.appendleft((num_bolts, bolt_diameter, result))\n",
+ " \n",
+ " for row in past_guesses:\n",
+ " print(row)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b8a88d0e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "past_guesses = deque()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cc066a9d",
+ "metadata": {},
+ "source": [
+ "### Usage\n",
+ "\n",
+ "Simply modify the num_bolts and bolt_diameter parameters; past guesses are saved and will be printed out as well. Note that the FEA tool generates a lot of output that can't be suppressed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "87688dd8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cached_analytical_guess(past_guesses, num_bolts=2, bolt_diameter=17)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d58a840c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "logging.basicConfig(level=logging.CRITICAL)\n",
+ "\n",
+ "cached_fea_guess(past_guesses, num_bolts=2, bolt_diameter=17)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cd57334d",
+ "metadata": {},
+ "source": [
+ "# Brute Force Search"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b19839ea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def brute_force_search(\n",
+ " num_bolts_low, \n",
+ " num_bolts_high, \n",
+ " bolt_diameter_low, \n",
+ " bolt_diameter_high,\n",
+ " step_size=0.1\n",
+ "):\n",
+ " for num_bolts in range(num_bolts_low, num_bolts_high):\n",
+ " for bolt_diameter in np.arange(bolt_diameter_low, bolt_diameter_high, step_size):\n",
+ " result = analytical_tool.forward(num_bolts=num_bolts, bolt_diameter=float(bolt_diameter))\n",
+ "\n",
+ " if abs(result[\"bolt_diff\"]) < 0.25 and abs(result[\"plate_diff\"]) < 0.25:\n",
+ " print(num_bolts, bolt_diameter,result)\n",
+ " if result[\"ok\"]:\n",
+ " print(\"ANSWER \", num_bolts, bolt_diameter,result)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4af7b215",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "brute_force_search(num_bolts_low=2, num_bolts_high=40, bolt_diameter_low=3.0, bolt_diameter_high=40.0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34f4e480",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "autoboltagent",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}