Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions src/eaa/task_managers/tuning/bo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
from typing import Optional, Callable

import torch

from eaa.task_managers.base import BaseTaskManager
from eaa.tools.base import BaseTool
from eaa.tools.bo import BayesianOptimizationTool
from eaa.api.llm_config import LLMConfig

logger = logging.getLogger(__name__)


class BayesianOptimizationTaskManager(BaseTaskManager):

def __init__(
self,
llm_config: LLMConfig = None,
tools: list[BaseTool] = [],
bayesian_optimization_tool: BayesianOptimizationTool = None,
initial_points: Optional[torch.Tensor] = None,
n_initial_points: int = 20,
objective_function: Callable = None,
message_db_path: Optional[str] = None,
build: bool = True,
*args, **kwargs
) -> None:
"""Bayesian optimization task manager.

Parameters
----------
llm_config : LLMConfig, optional
The configuration for the LLM.
tools : list[BaseTool], optional
A list of tools for the agent. This should NOT include the
`BayesianOptimizationTool`.
bayesian_optimization_tool : BayesianOptimizationTool
The Bayesian optimization tool to use.
initial_points : torch.Tensor, optional
A (n_points, n_features) tensor giving the initial points where
the objective function should be evaluated to initialize the
Gaussian process model. If None, random initial points will be
generated.
n_initial_points : int, optional
The number of initial points to generate if `initial_points` is None.
objective_function : Callable
The objective function to be maximized. This function should take
a single argument, which is a (n_points, n_features) tensor of
points to evaluate the objective function at. It should return
a (n_points, n_objectives) tensor of objective function values.
message_db_path : Optional[str]
If provided, the entire chat history will be stored in
a SQLite database at the given path. This is essential
if you want to use the WebUI, which polls the database
for new messages.
build : bool, optional
Whether to build the internal state of the task manager.
"""
if bayesian_optimization_tool is None:
raise ValueError(
"Bayesian optimization tool should be explicitly passed to "
"`bayesian_optimization_tool`."
)
if objective_function is None:
raise ValueError("`objective_function` is required.")

self.bayesian_optimization_tool = bayesian_optimization_tool

for tool in tools:
if isinstance(tool, BayesianOptimizationTool):
raise ValueError(
"`BayesianOptimizationTool` should not be included in `tools`. "
"Instead, pass it to `bayesian_optimization_tool`."
)

self.objective_function = objective_function

self.initial_points = initial_points
self.n_initial_points = n_initial_points

super().__init__(
llm_config=llm_config,
tools=tools,
message_db_path=message_db_path,
build=build,
*args, **kwargs
)

def run(
self,
n_iterations: int = 50,
*args, **kwargs
) -> None:
"""Run Bayesian optimization. Upon the second or later call,
this function continues from the last iteration.

Parameters
----------
n_iterations : int, optional
The number of iterations to run.
"""
if len(self.bayesian_optimization_tool.xs_untransformed) == 0:
if self.initial_points is None:
xs_init = self.bayesian_optimization_tool.get_random_initial_points(n_points=self.n_initial_points)
else:
xs_init = self.initial_points
logger.info(f"Initial points (shape: {xs_init.shape}):\n{xs_init}")

for x in xs_init:
x = x[None, :]
y = self.objective_function(x)
self.bayesian_optimization_tool.update(x, y)
self.bayesian_optimization_tool.build()

for i in range(n_iterations):
candidates = self.bayesian_optimization_tool.suggest(n_suggestions=1)
logger.info(f"Candidate suggested: {candidates[0]}")
y = self.objective_function(candidates)
logger.info(f"Objective function value: {y.item()}")
self.bayesian_optimization_tool.update(candidates, y)
148 changes: 148 additions & 0 deletions src/eaa/task_managers/tuning/bo_mic_optics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import logging
from typing import Optional

import torch
import numpy as np
from PIL import Image

from eaa.task_managers.tuning.bo import BayesianOptimizationTaskManager
from eaa.task_managers.imaging.feature_tracking import FeatureTrackingTaskManager
from eaa.tools.base import BaseTool
from eaa.tools.bo import BayesianOptimizationTool
from eaa.api.llm_config import LLMConfig

logger = logging.getLogger(__name__)


class MicroscopyOpticsTuningBOTaskManager(
BayesianOptimizationTaskManager,
FeatureTrackingTaskManager
):
def __init__(
self,
llm_config: LLMConfig = None,
image_acquisition_tool: BaseTool = None,
parameter_setting_tool: BaseTool = None,
bayesian_optimization_tool: BayesianOptimizationTool = None,
initial_points: Optional[torch.Tensor] = None,
n_initial_points: int = 20,
image_acquisition_kwargs: dict = {},
feature_tracking_kwargs: dict = {},
message_db_path: Optional[str] = None,
*args, **kwargs
):
"""The Bayesian optimization task manager for microscopy optics tuning.

Parameters
----------
llm_config : LLMConfig, optional
The configuration for the LLM.
bayesian_optimization_tool : BayesianOptimizationTool
The Bayesian optimization tool to use.
initial_points : torch.Tensor, optional
A (n_points, n_features) tensor giving the initial points where
the objective function should be evaluated to initialize the
Gaussian process model. If None, random initial points will be
generated.
n_initial_points : int, optional
The number of initial points to generate if `initial_points` is None.
image_acquisition_kwargs : dict, optional
The arguments of the image acquisition tool that should be used
when acquiring images for evaluating the objective function.
feature_tracking_kwargs : dict, optional
The arguments of the feature tracking task manager's `run` method.
message_db_path : Optional[str]
If provided, the entire chat history will be stored in
a SQLite database at the given path. This is essential
if you want to use the WebUI, which polls the database
for new messages.
build : bool, optional
Whether to build the internal state of the task manager.
"""
if bayesian_optimization_tool is None:
raise ValueError(
"Bayesian optimization tool should be explicitly passed to "
"`bayesian_optimization_tool`."
)
if image_acquisition_tool is None:
raise ValueError(
"Image acquisition tool should be explicitly passed to "
"`image_acquisition_tool`."
)
if parameter_setting_tool is None:
raise ValueError(
"Parameter setting tool should be explicitly passed to "
"`parameter_setting_tool`."
)

self.image_acquisition_tool = image_acquisition_tool
self.parameter_setting_tool = parameter_setting_tool
self.image_acquisition_kwargs = image_acquisition_kwargs
self.feature_tracking_kwargs = feature_tracking_kwargs

BayesianOptimizationTaskManager.__init__(
self,
llm_config=llm_config,
tools=[],
bayesian_optimization_tool=bayesian_optimization_tool,
initial_points=initial_points,
n_initial_points=n_initial_points,
objective_function=self.objective_function,
message_db_path=message_db_path,
build=False,
*args, **kwargs
)
FeatureTrackingTaskManager.__init__(
self,
llm_config=llm_config,
tools=[image_acquisition_tool],
build=True,
*args, **kwargs
)

def objective_function(self, x: torch.Tensor, *args, **kwargs):
"""Calculate the objective function value.

Parameters
----------
x : torch.Tensor
A (n_points, n_features) tensor of points to evaluate the
objective function at.

Returns
-------
torch.Tensor
A (n_points, 1) tensor of objective function values.
"""
if x.ndim != 2:
raise ValueError(
"`x` should be a 2D tensor of shape (n_points, n_features)."
)

objective_values = torch.zeros(x.shape[0], 1, device=x.device)

for i, x_i in enumerate(x):
# Acquire an image with the current parameters. It will be used
# as the reference image for feature tracking.
acquired_image_path = self.image_acquisition_tool.acquire_image(
**self.image_acquisition_kwargs
)

# Apply parameters.
self.parameter_setting_tool.set_parameters(x_i)

# Now the original feature will have drifted. Run feature tracking
# to bring it back.
self.run_feature_tracking(
**self.feature_tracking_kwargs
)

# Get a new image after feature tracking.
acquired_image_path = self.image_acquisition_tool.acquire_image(
**self.image_acquisition_kwargs
)
image = Image.open(acquired_image_path)
image = np.array(image)

objective_values[i, 0] = np.std(image)
return objective_values
37 changes: 37 additions & 0 deletions tests/test_bo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

from eaa.tools.bo import BayesianOptimizationTool
from eaa.task_managers.tuning.bo import BayesianOptimizationTaskManager

import test_utils as tutils

Expand Down Expand Up @@ -85,6 +86,41 @@ def objective_function(x: torch.Tensor) -> torch.Tensor:

final_suggestion = candidates[0]
assert torch.allclose(final_suggestion.float(), torch.tensor([1.0, 2.0]), rtol=0.1)

def test_bo_task_manager(self):
def objective_function(x: torch.Tensor) -> torch.Tensor:
# Expected input shape: (n_samples, n_features)
# Maximum: x = [1, 2]
y = torch.exp(-((x[:, 0] - 1) ** 2 + (x[:, 1] - 2) ** 2) / (2 * 10 ** 2))
return y[:, None]

tutils.set_seed(42)

bo_tool = BayesianOptimizationTool(
bounds=([-10, -10], [10, 10]),
acquisition_function_class=botorch.acquisition.LogExpectedImprovement,
acquisition_function_kwargs={
"best_f": -100
},
model_class=botorch.models.SingleTaskGP,
model_kwargs={
"covar_module": gpytorch.kernels.MaternKernel(
nu=2.5,
)
},
optimization_function=botorch.optim.optimize_acqf,
)

task_manager = BayesianOptimizationTaskManager(
llm_config=None,
bayesian_optimization_tool=bo_tool,
n_initial_points=20,
objective_function=objective_function,
)
task_manager.run(n_iterations=20)

final_suggestion = task_manager.bayesian_optimization_tool.xs_untransformed[-1]
assert torch.allclose(final_suggestion.float(), torch.tensor([1.0, 2.0]), rtol=0.1)


if __name__ == '__main__':
Expand All @@ -94,3 +130,4 @@ def objective_function(x: torch.Tensor) -> torch.Tensor:
tester = TestBayesianOptimization()
tester.setup_method(name="", generate_data=False, generate_gold=False, debug=True)
tester.test_bayesian_optimization()
tester.test_bo_task_manager()