-
Notifications
You must be signed in to change notification settings - Fork 5
Pipeline Statefulness for Open World Challenges #277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ygefen
wants to merge
3
commits into
main
Choose a base branch
from
ygefen/feat/statefulness
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
159 changes: 159 additions & 0 deletions
159
align_system/algorithms/open_world/open_world_dialog.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| import json | ||
| from typing import Sequence | ||
| from rich.highlighter import JSONHighlighter | ||
| from swagger_client.models import KDMAValue | ||
| from enum import Enum | ||
| from align_system.utils import logging, call_with_coerced_args | ||
| from align_system.algorithms.abstracts import ADMComponent | ||
| from align_system.data_models.dialog import DialogElement, Dialog | ||
|
|
||
| log = logging.getLogger(__name__) | ||
| JSON_HIGHLIGHTER = JSONHighlighter() | ||
|
|
||
| class HistoryMode(Enum): | ||
| in_dialog = "IN_DIALOG" | ||
| in_prompt = "IN_PROMPT" | ||
| in_prompt_with_reasoning = "IN_PROMPT_WITH_REASONING" | ||
| off="OFF" | ||
|
|
||
| class BasicOpenWorldDialogADMComponent(ADMComponent): | ||
| ''' | ||
| IMPORTANT: This ADM is not compatible with batch mode LLM calls. | ||
| ''' | ||
| def __init__(self, | ||
| structured_inference_engine, | ||
| scenario_description_template, | ||
| prompt_template, | ||
| output_schema_template, | ||
| system_prompt_template, | ||
| history_mode | ||
| ): | ||
| self.structured_inference_engine = structured_inference_engine | ||
| self.scenario_description_template = scenario_description_template | ||
| self.prompt_template = prompt_template | ||
| self.output_schema_template = output_schema_template | ||
| self.system_prompt_template = system_prompt_template | ||
| self.history_mode = HistoryMode(history_mode) | ||
|
|
||
| self.dialog: Dialog = [] | ||
|
|
||
| def get_kdma(self, alignment_target): | ||
| kdma_values = alignment_target.kdma_values | ||
| if len(kdma_values) != 1: | ||
| raise RuntimeError("This ADM assumes a single KDMA target, aborting!") | ||
| kdma_value = kdma_values[0] | ||
| if isinstance(kdma_value, KDMAValue): | ||
| kdma_value = kdma_value.to_dict() | ||
|
|
||
| kdma = kdma_value['kdma'] | ||
| value = kdma_value['value'] | ||
| return kdma, value | ||
|
|
||
| def create_system_prompt(self, *, alignment_target, choices, scenario_description): | ||
| if self.system_prompt_template is not None: | ||
| kdma, value = self.get_kdma(alignment_target) if alignment_target else [None, None] | ||
| system_prompt = call_with_coerced_args( | ||
| self.system_prompt_template, | ||
| {'target_kdma': kdma, | ||
| 'target_value': value, | ||
| "choices": choices, | ||
| "scenario_description": scenario_description | ||
| }) | ||
| return str(system_prompt) | ||
|
|
||
| def create_user_prompt(self, scenario_state, choices): | ||
| scenario_description = call_with_coerced_args( | ||
| self.scenario_description_template, | ||
| {'scenario_state': scenario_state}) | ||
|
|
||
| # import web_pdb; web_pdb.set_trace() | ||
| def first_if_list(x): | ||
| return x[0] if isinstance(x, list) else x | ||
| actions = { | ||
| HistoryMode.in_dialog: {}, | ||
| HistoryMode.in_prompt: {'actions': [first_if_list(json.loads(x.content))["action_choice"] for x in self.dialog if x.role == 'assistant']}, | ||
| HistoryMode.in_prompt_with_reasoning: {'actions': [f'{first_if_list(json.loads(x.content))["action_choice"]}. Reasoning: {first_if_list(json.loads(x.content))["detailed_reasoning"]}' for x in self.dialog if x.role == 'assistant']}, | ||
| HistoryMode.off: {} | ||
| }[self.history_mode] | ||
|
|
||
| user_prompt = call_with_coerced_args( | ||
| self.prompt_template, | ||
| {'scenario_state': scenario_state, | ||
| 'scenario_description': scenario_description, | ||
| 'choices': choices, | ||
| **actions}, | ||
| partial=False) | ||
|
|
||
| return str(user_prompt) | ||
|
|
||
| def run_returns(self): | ||
| return ('chosen_choice', 'justification') | ||
|
|
||
| def run_sanity_check(self, scenario_state, | ||
| choices, | ||
| alignment_target): | ||
|
|
||
| # Add System Prompt First Time Only | ||
| # if len(self.dialog) == 0: | ||
| self.dialog.clear() | ||
|
|
||
| system_prompt = self.create_system_prompt(alignment_target=alignment_target, choices=choices, scenario_description=scenario_state.unstructured) | ||
| if system_prompt: | ||
| self.dialog.append(DialogElement(role='system', content=system_prompt)) | ||
|
|
||
| user_prompt = self.create_user_prompt(scenario_state, choices) | ||
| self.dialog.append(DialogElement(role="user", content=str(user_prompt))) | ||
|
|
||
| response = self.structured_inference_engine.run_inference( | ||
| prompts=self.structured_inference_engine.dialog_to_prompt(self.dialog), | ||
| schema=call_with_coerced_args(self.output_schema_template,{'choices': choices}) | ||
| ) | ||
|
|
||
| self.dialog.append(DialogElement(role="assistant", content=json.dumps(response))) | ||
|
|
||
| if isinstance(response, Sequence): | ||
| response = response[0] | ||
|
|
||
| chosen_choice = response['action_choice'] | ||
| justification = response['detailed_reasoning'] | ||
| return chosen_choice, justification, self.dialog | ||
|
|
||
|
|
||
| def run(self, scenario_state, | ||
| choices, | ||
| alignment_target): | ||
|
|
||
| # Add System Prompt First Time Only | ||
| if len(self.dialog) == 0: | ||
| system_prompt = self.create_system_prompt(alignment_target=alignment_target, choices=choices, scenario_description=scenario_state.unstructured) | ||
| if system_prompt: | ||
| self.dialog.append(DialogElement(role='system', content=system_prompt)) | ||
|
|
||
| user_prompt = self.create_user_prompt(scenario_state, choices) | ||
| self.dialog.append(DialogElement(role="user", content=str(user_prompt))) | ||
|
|
||
| _system_and_latest_prompt = [self.dialog[0], self.dialog[-1]] | ||
| mode_adjusted_dialog = { | ||
| HistoryMode.in_dialog: self.dialog, | ||
| HistoryMode.in_prompt: _system_and_latest_prompt, | ||
| HistoryMode.in_prompt_with_reasoning: _system_and_latest_prompt, | ||
| HistoryMode.off: _system_and_latest_prompt | ||
| }[self.history_mode] | ||
|
|
||
| response = self.structured_inference_engine.run_inference( | ||
| prompts=self.structured_inference_engine.dialog_to_prompt(mode_adjusted_dialog), | ||
| schema=call_with_coerced_args(self.output_schema_template,{'choices': choices}) | ||
| ) | ||
|
|
||
| self.dialog.append(DialogElement(role="assistant", content=json.dumps(response))) | ||
|
|
||
| if isinstance(response, Sequence): | ||
| response = response[0] | ||
|
|
||
| chosen_choice = response['action_choice'] | ||
| justification = response['detailed_reasoning'] | ||
| return chosen_choice, justification | ||
|
|
||
| def reset_history(self): | ||
| super().reset_history() | ||
| self.dialog.clear() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| from typing import List, Dict, Any | ||
|
|
||
| from align_system.algorithms.abstracts import ActionBasedADM, ADMComponent | ||
|
|
||
| class InMemoryState(): | ||
| def __init__(self): | ||
| self.state = [] | ||
| def save(self, state: str) -> None: | ||
| self.state.append(state) | ||
| def retreive(self) -> List[Any]: | ||
| return self.state | ||
|
|
||
| class FullStateInMemorySaverADM(ADMComponent): | ||
| def __init__(self, state: InMemoryState): | ||
| self.state = state | ||
| def run(self, state): | ||
| self.state.save(state) | ||
| def run_returns(self): | ||
| return "" | ||
|
|
||
|
|
||
| class PassthroughStateRetriever(ADMComponent): | ||
| def __init__(self, state: InMemoryState): | ||
| self.state = state | ||
| def run(self) -> Dict[str, List[str]]: | ||
| return {"previous_state": self.state.retreive()} | ||
| def run_returns(self): | ||
| return ["previous_state"] | ||
|
|
||
| # hydra to compose saver and retrieve with the same memorystate | ||
| # hydra to create the pipeline |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| name: open_world_dialog | ||
|
|
||
| defaults: | ||
| # - /inference_engine@structured_inference_engine: claude_haiku | ||
| # - /inference_engine@structured_inference_engine: outlines_structured_greedy | ||
| - /inference_engine@structured_inference_engine: outlines_structured_multinomial | ||
| - /adm_component/misc@step_definitions.format_choices: itm_format_choices | ||
| - /adm_component/history@step_definitions.dialog_builder: dialog_builder | ||
| - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action | ||
| - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info | ||
| - _self_ | ||
|
|
||
|
|
||
|
|
||
|
|
||
| # structured_inference_engine: | ||
| # model_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | ||
| # structured_inference_engine: | ||
| # _target_: align_system.algorithms.outlines_inference_engine.SpectrumTunedInferenceEngine | ||
| # model_name: tsor13/spectrum-Qwen3-14B-v1 | ||
| structured_inference_engine: | ||
| _target_: align_system.algorithms.outlines_inference_engine.SpectrumTunedInferenceEngine | ||
| model_name: tsor13/spectrum-Llama-3.1-8B-v1 | ||
|
|
||
| instance: | ||
| _target_: align_system.algorithms.pipeline_adm.PipelineADM | ||
| steps: | ||
| - ${ref:adm.step_definitions.format_choices} | ||
| - ${ref:adm.step_definitions.dialog_builder} | ||
| - ${ref:adm.step_definitions.ensure_chosen_action} | ||
| - ${ref:adm.step_definitions.populate_choice_info} |
16 changes: 16 additions & 0 deletions
16
align_system/configs/adm_component/history/dialog_builder.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| _target_: align_system.algorithms.open_world.open_world_dialog.BasicOpenWorldDialogADMComponent | ||
|
|
||
| structured_inference_engine: ${ref:adm.structured_inference_engine} | ||
|
|
||
| scenario_description_template: | ||
| _target_: align_system.prompt_engineering.outlines_prompts.DefaultITMScenarioDescription | ||
| prompt_template: | ||
| _target_: align_system.prompt_engineering.outlines_prompts.DefaultITMPrompt | ||
| output_schema_template: | ||
| _target_: align_system.prompt_engineering.outlines_prompts.StructuredOutputChoiceSelectionSchema | ||
| system_prompt_template: | ||
| _target_: align_system.prompt_engineering.outlines_prompts.Phase2BaselinePrompt | ||
| # history_mode: "IN_DIALOG" | ||
| history_mode: "IN_PROMPT" | ||
| # history_mode: "IN_PROMPT_WITH_REASONING" | ||
| # history_mode: "OFF" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| _target_: align_system.drivers.itm_phase2_openworld.ITMPhase2OpenWorldDriver |
17 changes: 17 additions & 0 deletions
17
align_system/configs/experiment/statefulness_poc/baseline_open_world_dialog.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # @package _global_ | ||
| # Baseline "marble through the system" experiment. | ||
| # Uses itm_phase1 driver (filtering disabled) to verify the open-world | ||
| # pipeline runs end-to-end before testing the full statefulness hypothesis. | ||
| defaults: | ||
| - override /adm: open_world_dialog | ||
| - override /driver: itm_phase1 | ||
| - override /interface: input_output_file | ||
|
|
||
| interface: | ||
| input_output_filepath: scripts/statefulness_poc/custom_input_output.json | ||
| state_hydration_domain: open_world | ||
|
|
||
| align_to_target: false | ||
|
|
||
| driver: | ||
| apply_action_filtering: false |
18 changes: 18 additions & 0 deletions
18
align_system/configs/experiment/statefulness_poc/open_world_dialog.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # @package _global_ | ||
| defaults: | ||
| - override /adm: open_world_dialog | ||
| - override /driver: itm_phase2_openworld | ||
| - override /interface: input_output_file | ||
|
|
||
| interface: | ||
| input_output_filepath: scripts/statefulness_poc/custom_input_output.json | ||
| state_hydration_domain: open_world | ||
|
|
||
| align_to_target: false | ||
|
|
||
| adm: | ||
| step_definitions: | ||
| dialog_builder: | ||
| system_prompt_template: null | ||
| output_schema_template: | ||
| _target_: align_system.prompt_engineering.outlines_prompts.StructuredOutputChoiceSelectionSchema |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addes a
reset_historydefault NOOP for backwards compatibility and so that we do not have to checkhasattrin the code.In fact I would push to do this same pattern for
output_conflict_resolverto avoid the hasattr check for that too.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I agree with you that doing the same thing with
output_conflict_resolverwould tidy some things up; however the way it's set up currently it's configurable at the hydra config level, but this change would only make it configurable at the code level (I think)? FWIW we don't use this pattern very often but here's one config where it's used: https://github.com/ITM-Kitware/align-system/blob/main/align_system/configs/adm/phase2_pipeline_fewshot_comparative_regression_swap_average.yaml#L53Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, so in the ABC it would look like:
i.e. a Callable type noop, and the hydra in your example would override it in the subclass with the injected callable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, can you override members like that in the hydra config? I was under the impression that you can only set / pass initialization args with hydra instantiation. If the config I linked works without modification with this ABC update you're suggesting then I'm in favor of this change.