def windowed_phase_cross_correlation(
    moving: np.ndarray,
    ref: np.ndarray,
) -> np.ndarray:
    """Estimate the translation between two images by phase correlation.

    Both images are multiplied by a separable 2D Hann window before the
    FFT. The normalised cross-power spectrum is transformed back to real
    space and the location of its maximum is taken as the shift.

    Parameters
    ----------
    moving : np.ndarray
        A 2D image.
    ref : np.ndarray
        A 2D image with the same shape as `moving`.

    Returns
    -------
    np.ndarray
        The integer shift (axis 0, axis 1) of the moving image with respect
        to the reference image, in pixels. Peak indices beyond half the
        image size are wrapped to negative shifts.

    Raises
    ------
    AssertionError
        If the two images do not have the same shape.
    """
    assert moving.shape == ref.shape, (
        "The shapes of the moving and reference images must be the same."
    )
    win = np.outer(np.hanning(moving.shape[0]), np.hanning(moving.shape[1]))

    f_moving = np.fft.fft2(moving * win)
    f_ref = np.fft.fft2(ref * win)

    cross_power = f_moving * f_ref.conj()
    magnitude = np.abs(cross_power)
    # Keep only the phase. Bins with zero energy are left at 0 instead of
    # evaluating 0/0: a single NaN would turn the whole inverse transform
    # into NaN and make argmax silently report a (0, 0) shift.
    cross_power = np.divide(
        cross_power,
        magnitude,
        out=np.zeros_like(cross_power),
        where=magnitude > 0,
    )

    corr_map = np.fft.ifft2(cross_power).real
    shift = np.array(np.unravel_index(np.argmax(corr_map), corr_map.shape))
    # The correlation map is periodic, so a peak in the upper half of an
    # axis corresponds to a negative shift along that axis.
    size = np.array(corr_map.shape)
    return np.where(shift > size / 2, shift - size, shift)
diff --git a/src/eaa/task_managers/base.py b/src/eaa/task_managers/base.py index 95faa34..987e99a 100644 --- a/src/eaa/task_managers/base.py +++ b/src/eaa/task_managers/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Callable import sqlite3 import logging import time @@ -327,6 +327,8 @@ def run_feedback_loop( n_first_images_to_keep: Optional[int] = None, n_past_images_to_keep: Optional[int] = None, allow_non_image_tool_responses: bool = True, + hook_functions: Optional[dict[str, Callable]] = None, + *args, **kwargs ) -> None: """Run an agent-involving feedback loop. @@ -364,7 +366,20 @@ def run_feedback_loop( allow_non_image_tool_responses : bool, optional If False, the agent will be asked to redo the tool call if it returns anything that is not an image path. + hook_functions : dict[str, Callable], optional + A dictionary of hook functions to call at certain points in the loop. + The keys specify the points where the hook functions are called, and + the values are the callables. Allowed keys are: + - `image_path_tool_response`: + args: {"img_path": str} + return: {"response": Dict[str, Any], "outgoing": Dict[str, Any]} + Executed when the tool response is an image path, after the tool + response is added to the context but before the image is loaded and + sent to the agent. When this function is given, it **replaces** the + `agent.receive` call so be sure to send the image to the agent in + the hook if this is intended. 
""" + hook_functions = hook_functions or {} round = 0 image_path = None response, outgoing = self.agent.receive( @@ -404,12 +419,15 @@ def run_feedback_loop( if tool_response_type == ToolReturnType.IMAGE_PATH: image_path = tool_response["content"] - response, outgoing = self.agent.receive( - message_with_acquired_image, - image_path=image_path, - context=self.context, - return_outgoing_message=True - ) + if "image_path_tool_response" in hook_functions: + response, outgoing = hook_functions["image_path_tool_response"](image_path) + else: + response, outgoing = self.agent.receive( + message_with_acquired_image, + image_path=image_path, + context=self.context, + return_outgoing_message=True + ) elif tool_response_type == ToolReturnType.EXCEPTION: response, outgoing = self.agent.receive( "The tool returned an exception. Please fix the exception and try again.", diff --git a/src/eaa/task_managers/imaging/param_tuning.py b/src/eaa/task_managers/imaging/param_tuning.py deleted file mode 100644 index bb84997..0000000 --- a/src/eaa/task_managers/imaging/param_tuning.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import Optional -from textwrap import dedent -import logging - -from eaa.tools.imaging.acquisition import AcquireImage -from eaa.tools.imaging.param_tuning import SetParameters -from eaa.task_managers.imaging.base import ImagingBaseTaskManager -from eaa.tools.base import ToolReturnType -from eaa.agents.base import print_message -from eaa.api.llm_config import LLMConfig - -logger = logging.getLogger(__name__) - - -class ParameterTuningTaskManager(ImagingBaseTaskManager): - - def __init__( - self, - llm_config: LLMConfig = None, - param_setting_tool: SetParameters = None, - acquisition_tool: AcquireImage = None, - initial_parameters: dict[str, float] = None, - parameter_ranges: list[tuple[float, ...], tuple[float, ...]] = None, - message_db_path: Optional[str] = None, - build: bool = True, - *args, **kwargs - ) -> None: - """An agent that searches for the best setup 
parameters - for an imaging system. - - Parameters - ---------- - llm_config : LLMConfig - The configuration for the LLM. - param_setting_tool : SetParameters - The tool to use to set the parameters. - acquisition_tool : SimulatedAcquireImage, optional - The tool to use to acquire images. This tool will - not be called by AI; it is executed automatically - following each parameter adjustment. - initial_parameters : dict[str, float], optional - The initial parameters given as a dictionary of - parameter names and values. - parameter_ranges : list[tuple[float, ...], tuple[float, ...]] - The ranges of the parameters. It should be given as a list of - 2 tuples, where the first tuple gives the lower bounds and the - second tuple gives the upper bounds. The order of the parameters - should match the order of the initial parameters. - message_db_path : Optional[str] - If provided, the entire chat history will be stored in - a SQLite database at the given path. This is essential - if you want to use the WebUI, which polls the database - for new messages. - """ - if "tools" in kwargs.keys(): - raise ValueError( - "`tools` should not be provided to `ParameterTuningTaskManager`. Instead, " - "provide the `param_setting_tool` and `acquisition_tool`." 
- ) - - self.param_setting_tool = param_setting_tool - self.acquisition_tool = acquisition_tool - self.initial_parameters = initial_parameters - self.parameter_names = list(initial_parameters.keys()) - self.parameter_ranges = parameter_ranges - - super().__init__( - llm_config=llm_config, - tools=[param_setting_tool], - message_db_path=message_db_path, - build=build, - *args, **kwargs - ) - - def set_initial_parameters(self, initial_params: dict[str, float]): - self.initial_parameters = initial_params - - def prerun_check(self, *args, **kwargs) -> bool: - if self.initial_parameters is None: - raise ValueError("initial_parameters must be provided.") - return super().prerun_check(*args, **kwargs) - - def run( - self, - acquisition_tool_kwargs: dict = {}, - n_past_images_to_keep: int = 3, - max_iters: int = 10, - initial_prompt: Optional[str] = None, - additional_prompt: Optional[str] = None, - ) -> None: - """Run the parameter tuning task. - - Parameters - ---------- - acquisition_tool_kwargs : dict - The arguments for the acquisition tool. These arguments will be - used to acquire images for evaluation. - n_past_images_to_keep : int, optional - The number of most recent images to keep in the context. Having past - images in the context allows to agent to "remember" images it - has seen before; however, it also increases the context size - and inference cost. - max_iters : int, optional - The maximum number of iterations to run. - initial_prompt : str, optional - If provided, this prompt will override the default initial prompt. - additional_prompt : str, optional - If provided, this prompt will be added to the initial prompt (either - the default one or the one provided by `initial_prompt`). 
- """ - self.prerun_check() - - initial_parameter_values = list(self.initial_parameters.values()) - self.param_setting_tool.set_parameters(initial_parameter_values) - last_img_path = self.acquisition_tool.acquire_image(**acquisition_tool_kwargs) - - bounds_str = "" - for i, param in enumerate(self.parameter_names): - bounds_str += f"{param}: {self.parameter_ranges[0][i]} to {self.parameter_ranges[1][i]}\n" - - if initial_prompt is None: - initial_prompt = dedent( - f"""\ - You are tuning the parameters of a microscope to attain the best - image sharpness. The parameters are {list(self.parameter_names)}, - and their current values are {initial_parameter_values}. An image acquired - with the current parameters is shown below. - - - - Here are the tunable ranges of the parameters: - {bounds_str} - - You can change the parameters using your parameter setting tool. - An image acquired with the new parameters will be given to you - after each parameter change. Here are some detailed instructions: - - - Tune parameters one by one. Start with the first parameter, tweak it - to attain the sharpest possible image, then move on to the next parameter. - Do not change more than one parameter at a time. - - The sharpness of the image is convex with regards to the parameters. There - is only one optimal point; assume there is no local maximum. As such, if - you find the image comes more blurry when changing a parameter in a direction, - you should consider changing it the other way; if you find the image comes - sharper when changing a parameter in a direction, you are on the right track. - - For each parameter, first get a coarse estimate of the optimal value, then - fine-tune it. To get a coarse estimate, look for a peak in the sharpness. In - other words, find a parameter value that gives a sharper image than the value - immediately before and after it. 
For example, if the image becomes sharper when - you increase the parameter from 4 to 5, but becomes blurrier when you increase - it from 5 to 6, then the optimal value is around 5. - - Choose the step size for changing parameters wisely. For each parameter, start - with a large step size, and decrease it as you get closer to the optimal point. - - Only call the parameter setting tool one at a time. Do not call it multiple times - in one response. - - When you finish or when you need human input, add "TERMINATE" to your response.\ - """ - ) - if additional_prompt is not None: - initial_prompt += "\nAdditional instructions:\n" + additional_prompt - - round = 0 - response, outgoing = self.agent.receive( - initial_prompt, - context=self.context, - image_path=last_img_path, - return_outgoing_message=True - ) - self.update_message_history(outgoing, update_context=True, update_full_history=True) - self.update_message_history(response, update_context=True, update_full_history=True) - while round < max_iters: - if response["content"] is not None and "TERMINATE" in response["content"]: - message = self.get_user_input( - "Termination condition triggered. What to do next? Type \"exit\" to exit. " - ) - if message.lower() == "exit": - return - else: - response, outgoing = self.agent.receive( - message, - context=self.context, - image_path=None, - return_outgoing_message=True - ) - self.update_message_history(outgoing, update_context=True, update_full_history=True) - self.update_message_history(response, update_context=True, update_full_history=True) - continue - - tool_responses, tool_response_types = self.agent.handle_tool_call(response, return_tool_return_types=True) - if len(tool_responses) == 1: - tool_response = tool_responses[0] - tool_response_type = tool_response_types[0] - # Just save the tool response, but don't send yet. We will send it - # together with the image later. 
- print_message(tool_response) - self.update_message_history(tool_response, update_context=True, update_full_history=True) - - if tool_response_type == ToolReturnType.EXCEPTION: - response, outgoing = self.agent.receive( - "The tool returned an exception. Please fix the exception and try again.", - image_path=None, - context=self.context, - return_outgoing_message=True - ) - else: - # Acquire an image with the new parameters. - last_img_path = self.acquisition_tool.acquire_image(**acquisition_tool_kwargs) - response, outgoing = self.agent.receive( - "An image acquired with the new parameters is shown below.", - image_path=last_img_path, - context=self.context, - return_outgoing_message=True - ) - self.purge_context_images(keep_first_n=1, keep_last_n=n_past_images_to_keep - 1) - self.update_message_history(outgoing, update_context=True, update_full_history=True) - self.update_message_history(response, update_context=True, update_full_history=True) - elif len(tool_responses) > 1: - response, outgoing = self.agent.receive( - "There are more than one tool calls in your response. " - "Make sure you only make one call at a time. Please redo " - "your tool calls.", - image_path=None, - context=self.context, - return_outgoing_message=True - ) - self.update_message_history(outgoing, update_context=True, update_full_history=True) - self.update_message_history(response, update_context=True, update_full_history=True) - else: - response, outgoing = self.agent.receive( - "There is no tool call in the response. 
class BaseParameterTuningTaskManager(BaseTaskManager):
    """Common base for task managers that tune parameters through a tool.

    Stores the parameter-setting tool, the initial parameter values and the
    parameter ranges, and registers the parameter-setting tool (plus any
    extra tools) with the parent task manager. Subclasses implement `run`.
    """

    def __init__(
        self,
        llm_config: LLMConfig = None,
        param_setting_tool: SetParameters = None,
        tools: list[BaseTool] = (),
        initial_parameters: dict[str, float] = None,
        parameter_ranges: list[tuple[float, ...], tuple[float, ...]] = None,
        message_db_path: Optional[str] = None,
        build: bool = True,
        *args, **kwargs
    ) -> None:
        """An agent that searches for the best setup parameters
        for an imaging system.

        Parameters
        ----------
        llm_config : LLMConfig
            The configuration for the LLM.
        param_setting_tool : SetParameters
            The tool to use to set the parameters.
        tools : list[BaseTool], optional
            Other tools provided to the agent. They are registered after
            the parameter setting tool.
        initial_parameters : dict[str, float], optional
            The initial parameters given as a dictionary of
            parameter names and values. May be left as None at
            construction; `prerun_check` rejects None before a run.
        parameter_ranges : list[tuple[float, ...], tuple[float, ...]]
            The ranges of the parameters. It should be given as a list of
            2 tuples, where the first tuple gives the lower bounds and the
            second tuple gives the upper bounds. The order of the parameters
            should match the order of the initial parameters.
        message_db_path : Optional[str]
            If provided, the entire chat history will be stored in
            a SQLite database at the given path. This is essential
            if you want to use the WebUI, which polls the database
            for new messages.
        build : bool, optional
            Forwarded unchanged to the parent constructor.
        """
        self.param_setting_tool: SetParameters = param_setting_tool
        self.initial_parameters: dict[str, float] = initial_parameters
        # `initial_parameters` defaults to None; dereferencing it here would
        # raise AttributeError before `prerun_check` can report the problem.
        self.parameter_names = (
            list(initial_parameters.keys())
            if initial_parameters is not None
            else []
        )
        self.parameter_ranges = parameter_ranges

        super().__init__(
            llm_config=llm_config,
            tools=[param_setting_tool, *tools],
            message_db_path=message_db_path,
            build=build,
            *args, **kwargs
        )

    def build(self, *args, **kwargs):
        """Run the parent build, then push the initial parameters to the tool."""
        super().build(*args, **kwargs)
        self.initialize_parameter_setting_tool()

    def initialize_parameter_setting_tool(self):
        """Set the parameter setting tool to the initial parameter values.

        Does nothing when no initial parameters were given; that case is
        reported by `prerun_check` instead of failing here.
        """
        if self.initial_parameters is None:
            return
        self.param_setting_tool.set_parameters(list(self.initial_parameters.values()))

    def prerun_check(self, *args, **kwargs) -> bool:
        """Validate the state before a run.

        Raises
        ------
        ValueError
            If `initial_parameters` is None.
        """
        if self.initial_parameters is None:
            raise ValueError("initial_parameters must be provided.")
        return super().prerun_check(*args, **kwargs)

    def run(self, *args, **kwargs) -> None:
        """Run the tuning task. Must be implemented by subclasses."""
        raise NotImplementedError
llm_config: LLMConfig = None, + param_setting_tool: SetParameters = None, + acquisition_tool: AcquireImage = None, + tools: list[BaseTool] = (), + initial_parameters: dict[str, float] = None, + parameter_ranges: list[tuple[float, ...], tuple[float, ...]] = None, + message_db_path: Optional[str] = None, + build: bool = True, + *args, **kwargs + ): + """A task manager for focusing a scanning microscope. + + The task manager assumes that the user has a test pattern that has + thin lines that can be used to evaluate the focus. It expects a + 2D image acquisition tool, a line scan tool, and a parameter setting + tool. The workflow is as follows: + + 1. The user provides a reference image that highlights the thin + feature that should be used to evaluate the focus through line + scan, or describe it verbally. + 2. The agent runs a line scan across the feature and obtain its + line profile and the FWHM of its Gaussian fit. + 3. The agent uses the parameter setting tool to adjust the parameters + controlling the focus. + 4. The agent runs a 2D image scan around the area to acquire a new image, + which may have drifted due to the focus adjustment. + 5. The agent runs a new line scan across the same feature used previously + and compare the FWHM of the Gaussian fit. + 6. The agent repeats the process until the FWHM of the Gaussian fit is + minimized. + + Parameters + ---------- + llm_config : LLMConfig, optional + The LLM configuration to use. + param_setting_tool : SetParameters + The tool to use to set the parameters. + acquisition_tool : AcquireImage + The BaseTool object used to acquire data. It should contain a 2D + image acquisition tool and a line scan tool. + tools : list[BaseTool], optional + Other tools provided to the agent. + initial_parameters : dict[str, float], optional + The initial parameters given as a dictionary of + parameter names and values. + parameter_ranges : list[tuple[float, ...], tuple[float, ...]] + The ranges of the parameters. 
It should be given as a list of + 2 tuples, where the first tuple gives the lower bounds and the + second tuple gives the upper bounds. The order of the parameters + should match the order of the initial parameters. + message_db_path : Optional[str], optional + If provided, the entire chat history will be stored in + a SQLite database at the given path. This is essential + if you want to use the WebUI, which polls the database + for new messages. + build : bool, optional + Whether to build the internal state of the task manager. + """ + self.acquisition_tool = acquisition_tool + + self.last_acquisition_count_registered = -1 + + super().__init__( + llm_config=llm_config, + param_setting_tool=param_setting_tool, + tools=[acquisition_tool, *tools], + initial_parameters=initial_parameters, + parameter_ranges=parameter_ranges, + message_db_path=message_db_path, + build=build, + *args, **kwargs + ) + + def run_registration_and_send_image(self, image_path: str) -> None: + """Register the new image with the previous one and + send the offset and the new image to the agent. + + This routine assumes `self.image_km1` and `self.image_k` of + `self.acquisition_tool` are already set. + """ + image_k = self.acquisition_tool.image_k + image_km1 = self.acquisition_tool.image_km1 + + if ( + image_km1 is None + or self.acquisition_tool.counter == self.last_acquisition_count_registered + ): + response, outgoing = self.agent.receive( + "Here is the new image.", + image_path=image_path, + context=self.context, + return_outgoing_message=True + ) + else: + # Run registration. + image_k = image_k if image_k.ndim == 2 else image_k.mean(-1) + image_km1 = image_km1 if image_km1.ndim == 2 else image_km1.mean(-1) + shift = windowed_phase_cross_correlation(image_k, image_km1) + shift = shift * self.acquisition_tool.psize_k + + response, outgoing = self.agent.receive( + f"Here is the new image. 
Phase correlation has found the offset between " + f"the new image and the previous one to be {shift.tolist()} (y, x). Use " + f"this offset to adjust the line scan positions by **adding** it to both " + f"the x and y coordinates of the start and end points of the previous line scan.", + image_path=image_path, + context=self.context, + return_outgoing_message=True + ) + self.last_acquisition_count_registered = self.acquisition_tool.counter + return response, outgoing + + def run( + self, + reference_image_path: str, + reference_feature_description: Optional[str] = None, + suggested_2d_scan_kwargs: dict = None, + suggested_parameter_step_size: Optional[float] = None, + line_scan_step_size: float = None, + initial_prompt: Optional[str] = None, + max_iters: int = 20, + n_past_images_to_keep: Optional[int] = None, + additional_prompt: Optional[str] = None, + *args, **kwargs + ): + """Run the focusing task. + + Parameters + ---------- + reference_image_path : Optional[str] + The path to the reference image, which should show a 2D scan + of the ROI with the desired line scan path indicated by a + marker. `reference_feature_description` will be ignored if + this argument is provided. + reference_feature_description : Optional[str] + The description of the feature across which line scans should + be done. Ignored if `reference_image_path` is provided. + suggested_2d_scan_kwargs : dict + The suggested kwargs for the 2D scan. The argument should match + the arguments of the 2D image acquisition tool. + suggested_parameter_step_size : float + The suggested step size for the parameter adjustment. + line_scan_step_size : float + The step size for the line scan. + initial_prompt : Optional[str] + If provided, this prompt will override the default initial prompt. + max_iters : int, optional + The maximum number of iterations to run. + n_past_images_to_keep : int, optional + The number of past images to keep in the context. If None, all images + will be kept. 
+ additional_prompt : Optional[str] + If provided, this prompt will be added to the initial prompt. + """ + if reference_image_path is None and reference_feature_description is None: + raise ValueError( + "Either `reference_image_path` or `reference_feature_description` must be provided." + ) + + if initial_prompt is None: + feat_text_description = "" + if reference_feature_description is not None: + feat_text_description = f"Also, here is the description of the feature: {reference_feature_description}. " + param_step_size_prompt = "" + if suggested_parameter_step_size is not None: + param_step_size_prompt = dedent( + f"""\ + - The suggested step size for adjusting the parameter is + {suggested_parameter_step_size}. You can adjust the step size + to a smaller value if you want to fine-tune the parameter. + """ + ) + + initial_prompt = dedent( + f"""\ + You will adjust the focus of a scanning microscope by adjusting + the parameters of its optics. The focusing quality can be evalutated + by performing a line scan across a thin feature and observe the FWHM + of its Gaussian fit. The smaller the FWHM, the sharper the image. + But each time you adjust the focus, the image may drift due to + the change of the optics. You will need to perform a 2D scan + prior to the line scan to locate the feature that is line-scanned. + + + You will see a reference 2D scan image in this message. + This image is acquired in the region of interest that + contains the thin feature to be line-scanned. The line scan path + across that feature is indicated by a marker. {feat_text_description} + + Follow the procedure below to focus the microscope: + + 1. First, perform a 2D scan of the region of interest using the + "acquire_image" tool and the following arguments: + {suggested_2d_scan_kwargs}. + The image should look similar to the reference image. + Determine the coordinates of the line scan path across the feature, + and use the "scan_line" tool to perform a line scan across the feature. 
+ 2. The line scan tool will return a plot along the scan line. You should + see a peak in the plot. A Gaussian fit will be included in the plot + and the FWHM of the Gaussian fit will be shown. + 3. Adjust the optics parameters using the parameter setting tool. + The initial parameter values are {self.initial_parameters}. + 4. Acquire an image of the region using the image acquisition tool. + Here are the suggested arguments: {suggested_2d_scan_kwargs}. The + image acquired may have drifted compared to the last one you saw, + but you should still see the line-scanned feature there. If not, + try adjusting the image acquisition tool's parameters to locate that + feature. Along with this image, you will also be given the offset of + this image compared to the previous image found through phase correlation. + Use this offset to adjust the line scan positions. Note that the offset + is just a suggestion. If the new image does not appear to have any overlap + with the previous one, the offset won't be reliable. In that case, try + adjusting the image acquisition tool's parameters to move the field of view + closer to the previous image. + 5. Once you find the line-scanned feature, perform a new line scan across + it again. Due to the drift, the start/end points' coordinates may need to + be changed. Read the coordinates from the axis ticks. + 6. You will be presented with the new line scan plot and the FWHM of the + Gaussian fit. + 7. Compare the new FWHM with the last one. If it is smaller, you are on the + right track. Keep adjusting the parameters to the same direction. Otherwise, + adjust the parameters in the opposite direction. + 8. Repeat the process from step 4. + 9. When you find the FWHM is minimized, you are done. Add "TERMINATE" to + your response to hand over control back to the user. + + Other notes: + + - Your line scan should cross only one line feature, and you should see + **exactly one peak** in the line scan plot. 
If there isn't one, or if there + are multiple peaks, or if the Gaussian fit looks bad, check your arguments + to the line scan tool and run it again. Make sure your line scan strictly + follow the marker in the reference image. + - The line scan plot should show a complete peak. If the peak is incomplete, + adjust the line scan tool's arguments to make it complete. + - The minimal point of the FWHM is indicated by an inflection of the trend + of the FWHM with regards to the optics parameters. For example, if the FWHM + is 3 with a parameter value of 10, then 1 with a parameter value of 11, then + 3 with a parameter value of 12, this means the optimal parameter value is around + 11. + {param_step_size_prompt} + - When calling a tool, explain what you are doing. + - When making a tool call, only call one tool at a time. Do not call multiple + tools in one response. + + When you finish or when you need human input, add "TERMINATE" to your response.\ + """ + ) + if additional_prompt is not None: + initial_prompt += "\nAdditional instructions:\n" + additional_prompt + + # Always keep the first (reference) image. 
class ParameterTuningTaskManager(BaseParameterTuningTaskManager):
    """Task manager in which an LLM agent tunes imaging parameters.

    The agent is only handed the parameter setting tool. After every
    successful parameter change the manager itself acquires a new image with
    ``acquisition_tool`` and sends it back to the agent.
    """

    # The template is dedented *before* values are substituted. Substituting
    # first (as an f-string would) injects the unindented, multi-line bounds
    # list into the text, which leaves `dedent` without a common margin to strip.
    _DEFAULT_PROMPT_TEMPLATE = dedent(
        """\
        You are tuning the parameters of a microscope to attain the best
        image sharpness. The parameters are {parameter_names},
        and their current values are {initial_values}. An image acquired
        with the current parameters is shown below.



        Here are the tunable ranges of the parameters:
        {bounds}

        You can change the parameters using your parameter setting tool.
        An image acquired with the new parameters will be given to you
        after each parameter change. Here are some detailed instructions:

        - Tune parameters one by one. Start with the first parameter, tweak it
          to attain the sharpest possible image, then move on to the next parameter.
          Do not change more than one parameter at a time.
        - The sharpness of the image is convex with regards to the parameters. There
          is only one optimal point; assume there is no local maximum. As such, if
          you find the image comes more blurry when changing a parameter in a direction,
          you should consider changing it the other way; if you find the image comes
          sharper when changing a parameter in a direction, you are on the right track.
        - For each parameter, first get a coarse estimate of the optimal value, then
          fine-tune it. To get a coarse estimate, look for a peak in the sharpness. In
          other words, find a parameter value that gives a sharper image than the value
          immediately before and after it. For example, if the image becomes sharper when
          you increase the parameter from 4 to 5, but becomes blurrier when you increase
          it from 5 to 6, then the optimal value is around 5.
        - Choose the step size for changing parameters wisely. For each parameter, start
          with a large step size, and decrease it as you get closer to the optimal point.
        - Only call the parameter setting tool one at a time. Do not call it multiple times
          in one response.

        When you finish or when you need human input, add "TERMINATE" to your response."""
    )

    def __init__(
        self,
        llm_config: LLMConfig = None,
        param_setting_tool: SetParameters = None,
        acquisition_tool: AcquireImage = None,
        initial_parameters: Optional[dict[str, float]] = None,
        parameter_ranges: Optional[list[tuple[float, ...]]] = None,
        message_db_path: Optional[str] = None,
        build: bool = True,
        *args, **kwargs
    ) -> None:
        """An agent that searches for the best setup parameters
        for an imaging system.

        Parameters
        ----------
        llm_config : LLMConfig
            The configuration for the LLM.
        param_setting_tool : SetParameters
            The tool to use to set the parameters.
        acquisition_tool : AcquireImage, optional
            The tool to use to acquire images. This tool will
            not be called by AI; it is executed automatically
            following each parameter adjustment.
        initial_parameters : dict[str, float], optional
            The initial parameters given as a dictionary of
            parameter names and values.
        parameter_ranges : list[tuple[float, ...]]
            The ranges of the parameters. It should be given as a list of
            2 tuples, where the first tuple gives the lower bounds and the
            second tuple gives the upper bounds. The order of the parameters
            should match the order of the initial parameters.
        message_db_path : Optional[str]
            If provided, the entire chat history will be stored in
            a SQLite database at the given path. This is essential
            if you want to use the WebUI, which polls the database
            for new messages.
        build : bool, optional
            Forwarded unchanged to the base class constructor.

        Raises
        ------
        ValueError
            If `tools` is passed; the tool list is derived from
            `param_setting_tool`.
        """
        if "tools" in kwargs:
            raise ValueError(
                "`tools` should not be provided to `ParameterTuningTaskManager`. Instead, "
                "provide the `param_setting_tool` and `acquisition_tool`."
            )

        # Set before the base constructor runs so the attribute already
        # exists during whatever the base class does at construction time.
        self.acquisition_tool = acquisition_tool

        super().__init__(
            llm_config=llm_config,
            param_setting_tool=param_setting_tool,
            tools=[param_setting_tool],
            initial_parameters=initial_parameters,
            parameter_ranges=parameter_ranges,
            message_db_path=message_db_path,
            build=build,
            *args, **kwargs
        )

    def prerun_check(self, *args, **kwargs) -> bool:
        """Validate that everything `run` needs is present.

        Raises
        ------
        ValueError
            If `initial_parameters` or `acquisition_tool` is missing.
        """
        if self.initial_parameters is None:
            raise ValueError("initial_parameters must be provided.")
        if self.acquisition_tool is None:
            # Fail early with a clear message instead of an AttributeError
            # on the first `acquire_image` call inside `run`.
            raise ValueError("acquisition_tool must be provided.")
        return super().prerun_check(*args, **kwargs)

    def _send_and_record(
        self,
        message: str,
        image_path: Optional[str] = None,
        purge_keep_last_n: Optional[int] = None,
    ):
        """Send a message to the agent and record both sides of the exchange.

        When `purge_keep_last_n` is given, older context images are purged
        after the agent replied but before the new exchange is stored, so the
        image just sent is not counted among the "past" images.

        Returns the agent's response message.
        """
        response, outgoing = self.agent.receive(
            message,
            context=self.context,
            image_path=image_path,
            return_outgoing_message=True
        )
        if purge_keep_last_n is not None:
            self.purge_context_images(keep_first_n=1, keep_last_n=purge_keep_last_n)
        self.update_message_history(outgoing, update_context=True, update_full_history=True)
        self.update_message_history(response, update_context=True, update_full_history=True)
        return response

    def _build_default_prompt(self, initial_parameter_values: list) -> str:
        """Fill the default prompt with parameter names, values and bounds."""
        lower_bounds, upper_bounds = self.parameter_ranges[0], self.parameter_ranges[1]
        bounds_str = "".join(
            f"{name}: {lower_bounds[i]} to {upper_bounds[i]}\n"
            for i, name in enumerate(self.parameter_names)
        )
        return self._DEFAULT_PROMPT_TEMPLATE.format(
            parameter_names=list(self.parameter_names),
            initial_values=initial_parameter_values,
            bounds=bounds_str,
        )

    def run(
        self,
        acquisition_tool_kwargs: Optional[dict] = None,
        n_past_images_to_keep: int = 3,
        max_iters: int = 10,
        initial_prompt: Optional[str] = None,
        additional_prompt: Optional[str] = None,
    ) -> None:
        """Run the parameter tuning task.

        Parameters
        ----------
        acquisition_tool_kwargs : dict, optional
            The arguments for the acquisition tool. These arguments will be
            used to acquire images for evaluation. Defaults to no arguments.
        n_past_images_to_keep : int, optional
            The number of most recent images to keep in the context. Having past
            images in the context allows the agent to "remember" images it
            has seen before; however, it also increases the context size
            and inference cost.
        max_iters : int, optional
            The maximum number of iterations to run. Turns driven by human
            input after a "TERMINATE" do not count towards this limit.
        initial_prompt : str, optional
            If provided, this prompt will override the default initial prompt.
        additional_prompt : str, optional
            If provided, this prompt will be added to the initial prompt (either
            the default one or the one provided by `initial_prompt`).
        """
        self.prerun_check()
        # A `None` default avoids sharing one mutable dict between calls.
        if acquisition_tool_kwargs is None:
            acquisition_tool_kwargs = {}

        initial_parameter_values = list(self.initial_parameters.values())
        self.param_setting_tool.set_parameters(initial_parameter_values)
        last_img_path = self.acquisition_tool.acquire_image(**acquisition_tool_kwargs)

        if initial_prompt is None:
            initial_prompt = self._build_default_prompt(initial_parameter_values)
        if additional_prompt is not None:
            initial_prompt += "\nAdditional instructions:\n" + additional_prompt

        n_rounds = 0
        response = self._send_and_record(initial_prompt, image_path=last_img_path)
        while n_rounds < max_iters:
            content = response["content"]
            if content is not None and "TERMINATE" in content:
                message = self.get_user_input(
                    "Termination condition triggered. What to do next? Type \"exit\" to exit. "
                )
                if message.lower() == "exit":
                    return
                response = self._send_and_record(message)
                # Human-driven turns do not consume the iteration budget.
                continue

            tool_responses, tool_response_types = self.agent.handle_tool_call(
                response, return_tool_return_types=True
            )
            if len(tool_responses) == 1:
                tool_response = tool_responses[0]
                # Just save the tool response, but don't send yet. We will
                # send it together with the image later.
                print_message(tool_response)
                self.update_message_history(
                    tool_response, update_context=True, update_full_history=True
                )

                if tool_response_types[0] == ToolReturnType.EXCEPTION:
                    follow_up = "The tool returned an exception. Please fix the exception and try again."
                    follow_up_image = None
                else:
                    # Acquire an image with the new parameters.
                    last_img_path = self.acquisition_tool.acquire_image(**acquisition_tool_kwargs)
                    follow_up = "An image acquired with the new parameters is shown below."
                    follow_up_image = last_img_path
                response = self._send_and_record(
                    follow_up,
                    image_path=follow_up_image,
                    purge_keep_last_n=n_past_images_to_keep - 1,
                )
            elif len(tool_responses) > 1:
                response = self._send_and_record(
                    "There are more than one tool calls in your response. "
                    "Make sure you only make one call at a time. Please redo "
                    "your tool calls."
                )
            else:
                response = self._send_and_record(
                    "There is no tool call in the response. Make sure you call the tool correctly."
                )
            n_rounds += 1
def plot_2d_image(
    self,
    image: np.ndarray,
    add_axis_ticks: bool = False,
    x_ticks: Optional[List[float]] = None,
    y_ticks: Optional[List[float]] = None,
    add_grid_lines: bool = False,
    invert_yaxis: bool = False,
) -> plt.Figure:
    """Plot a 2D image in gray scale and return the figure.

    The figure is not saved or closed here; that is left to the caller.

    Parameters
    ----------
    image : np.ndarray
        The image to plot.
    add_axis_ticks : bool, optional
        If True, axis ticks are added to the image to indicate positions.
        If False, the image is drawn without ticks.
    x_ticks : List[float], optional
        The x-axis ticks to add to the image. Required when `add_axis_ticks` is True.
    y_ticks : List[float], optional
        The y-axis ticks to add to the image. Required when `add_axis_ticks` is True.
    add_grid_lines : bool, optional
        If True, grid lines are added to the image. Grid lines follow the
        tick positions, so they are only visible together with
        `add_axis_ticks`.
    invert_yaxis : bool, optional
        If True, the y-axis is inverted.

    Returns
    -------
    plt.Figure
        The figure containing the plotted image.

    Raises
    ------
    ValueError
        If `add_axis_ticks` is True but `x_ticks` or `y_ticks` is missing.
    """
    if add_axis_ticks and (x_ticks is None or y_ticks is None):
        raise ValueError(
            "`x_ticks` and `y_ticks` must be provided when `add_axis_ticks` is True."
        )

    fig, ax = plt.subplots(1, 1)
    # Always draw the image; drawing it only when ticks are requested
    # produces an empty figure for the default `add_axis_ticks=False`.
    ax.imshow(image, cmap='gray')
    if add_axis_ticks:
        # Five evenly spaced pixel indices per axis, labelled with the
        # physical positions found at those indices.
        x_idx = np.linspace(0, len(x_ticks) - 1, 5, dtype=int)
        y_idx = np.linspace(0, len(y_ticks) - 1, 5, dtype=int)
        ax.set_xticks(x_idx)
        ax.set_yticks(y_idx)
        ax.set_xticklabels([np.round(x_ticks[i], 2) for i in x_idx])
        ax.set_yticklabels([np.round(y_ticks[i], 2) for i in y_idx])
        ax.set_xlabel("x")
        ax.set_ylabel("y")
    else:
        ax.set_xticks([])
        ax.set_yticks([])
    ax.grid(add_grid_lines)
    if invert_yaxis:
        ax.invert_yaxis()
    fig.tight_layout()
    return fig
def post_image_acquisition(func):
    """A decorator to be used to decorate the `acquire_image` method.

    After the decorated method returns, the instance's `counter` attribute
    is incremented by one. If the method raises, the counter is left
    untouched and the exception propagates.

    The wrapper keeps the name, docstring and signature of the decorated
    method, so code that introspects `acquire_image` still sees the
    original function.
    """
    # Imported locally so this helper does not depend on the module's
    # import block.
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        ret = func(self, *args, **kwargs)
        self.counter += 1
        return ret
    return wrapper
def update_image_buffers(self, new_image: np.ndarray, psize: float = 1):
    """Push a newly acquired image into the image buffers.

    The current image and pixel size become the previous ones, and the new
    image and pixel size become the current ones. While `counter` is still
    0, the new image and pixel size are also stored as the first ones.

    Parameters
    ----------
    new_image : np.ndarray
        The new image.
    psize : float, optional
        The pixel size (or scan step) of the new image.
    """
    is_first_image = self.counter == 0

    # Shift current -> previous, then store the new image as current.
    self.image_km1, self.psize_km1 = self.image_k, self.psize_k
    self.image_k, self.psize_k = new_image, psize

    if is_first_image:
        self.image_0, self.psize_0 = new_image, psize
""" self.whole_image = whole_image self.interpolator = None self.blur = None self.offset = np.array([0, 0]) + self.line_scan_gaussian_fit_y_threshold = line_scan_gaussian_fit_y_threshold self.return_message = return_message self.add_axis_ticks = add_axis_ticks + self.add_grid_lines = add_grid_lines + self.invert_yaxis = invert_yaxis + self.add_line_scan_candidates_to_image = add_line_scan_candidates_to_image + self.plot_image_in_log_scale = plot_image_in_log_scale + + self.line_scan_candidates: Dict[int, list[int]] = {} super().__init__(*args, **kwargs) @@ -132,7 +199,72 @@ def set_offset(self, offset: np.ndarray): of (y, x) coordinates. """ self.offset = offset + + def add_line_scan_candidates( + self, + fig: plt.Figure, + length: float = 30, + gap: float = 5, + spacing: float = 30, + horizontal: bool = True, + ): + """Add markers indicating line scan paths that can be chosen from + to a figure. + + Parameters + ---------- + fig : plt.Figure + The figure to add the markers to. + ny, nx : int + The number of markers to add in the y and x directions. + length : float + The length of the markers. + gap : float + The gap between the ends of the markers. + spacing : float + The parallel spacing between the markers. + horizontal : bool, optional + If True, the markers are added horizontally. If False, the markers + are added vertically. 
def add_line_scan_candidates(
    self,
    fig: "plt.Figure",
    length: float = 30,
    gap: float = 5,
    spacing: float = 30,
    horizontal: bool = True,
) -> "plt.Figure":
    """Add markers indicating line scan paths that can be chosen from
    to a figure.

    Each marker is a red line segment labelled with its index. The
    segments are laid out on a regular grid covering the current axis
    limits of the first axes of `fig`. `self.line_scan_candidates` is reset
    and refilled with ``index -> [start_x, start_y, end_x, end_y]`` so that
    a candidate can later be looked up by the index shown in the image.

    Parameters
    ----------
    fig : plt.Figure
        The figure to add the markers to.
    length : float
        The length of the markers.
    gap : float
        The gap between the ends of the markers.
    spacing : float
        The parallel spacing between the markers.
    horizontal : bool, optional
        If True, the markers are added horizontally. If False, the markers
        are added vertically.

    Returns
    -------
    plt.Figure
        The same figure, with the markers drawn on it. The axis limits are
        restored to what they were before drawing.
    """
    self.line_scan_candidates = {}

    ax = fig.get_axes()[0]
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    # Axis limits may be given in descending order (an image plot usually
    # has an inverted y-axis), so sort them before building the grid.
    x_lo, x_hi = sorted(xlim)
    y_lo, y_hi = sorted(ylim)

    if horizontal:
        start_xs = np.arange(x_lo, x_hi, length + gap)
        end_xs = start_xs + length
        start_ys = np.arange(y_lo, y_hi, spacing)
        end_ys = start_ys
    else:
        start_ys = np.arange(y_lo, y_hi, length + gap)
        end_ys = start_ys + length
        start_xs = np.arange(x_lo, x_hi, spacing)
        end_xs = start_xs

    # Indices run over y fastest for each x position.
    for x0, x1 in zip(start_xs, end_xs):
        for y0, y1 in zip(start_ys, end_ys):
            index = len(self.line_scan_candidates)
            ax.plot([x0, x1], [y0, y1], color="red")
            ax.text(
                (x0 + x1) / 2,
                (y0 + y1) / 2,
                f"{index}",
                color="red",
                horizontalalignment="center",
                verticalalignment="bottom"
            )
            # Store plain Python floats rather than NumPy scalars.
            self.line_scan_candidates[index] = [float(x0), float(y0), float(x1), float(y1)]

    # Drawing can change the limits; put them back so the image extent
    # is unchanged.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return fig
def scan_line_by_choice(
    self,
    choice: int,
    scan_step: float = 1.0,
) -> Annotated[str, "The path to the plot of the line scan."]:
    """Conduct a line scan along a chosen path. To use this tool,
    you must call the tool "acquire_image" first, examine the image
    with the candidates, and then call this tool with the index of the
    candidate you want to use.

    Parameters
    ----------
    choice : int
        The index of the line scan candidate to use. You should have
        seen an image with the line scan candidates.
    scan_step : float
        The step size of the line scan.

    Returns
    -------
    str
        The path of the plot of the line scan saved in hard drive.

    Raises
    ------
    KeyError
        If `choice` is not the index of a known line scan candidate. The
        message lists the valid indices, or says that no candidates exist
        yet.
    """
    try:
        start_x, start_y, end_x, end_y = self.line_scan_candidates[choice]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message
        # that tells the caller how to recover.
        if not self.line_scan_candidates:
            hint = (
                "No line scan candidates are available. Acquire an image "
                "with line scan candidates first."
            )
        else:
            hint = f"Valid choices are {sorted(self.line_scan_candidates)}."
        raise KeyError(f"Unknown line scan candidate {choice!r}. {hint}") from None
    return self.scan_line(start_x, start_y, end_x, end_y, scan_step=scan_step)