Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions align_system/algorithms/abstracts.py

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addes a reset_history default NOOP for backwards compatibility and so that we do not have to check hasattr in the code.
In fact I would push to do this same pattern for output_conflict_resolver to avoid the hasattr check for that too.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree with you that doing the same thing with output_conflict_resolver would tidy some things up; however the way it's set up currently it's configurable at the hydra config level, but this change would only make it configurable at the code level (I think)? FWIW we don't use this pattern very often but here's one config where it's used: https://github.com/ITM-Kitware/align-system/blob/main/align_system/configs/adm/phase2_pipeline_fewshot_comparative_regression_swap_average.yaml#L53

@ygefen ygefen May 21, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, so in the ABC it would look like:

def output_conflict_resolver(*args, **kwargs):
    return *args, **kwargs

i.e. a Callable type noop, and the hydra in your example would override it in the subclass with the injected callable.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, can you override members like that in the hydra config? I was under the impression that you can only set / pass initialization args with hydra instantiation. If the config I linked works without modification with this ABC update you're suggesting then I'm in favor of this change.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def choose_action(self,
**kwargs) -> Union[Action, tuple[Action, dict]]:
pass

def reset_history(self):
pass

class StructuredInferenceEngine(ABC):
@abstractmethod
Expand All @@ -38,3 +40,6 @@ def run_returns(self) -> Union[str, Iterable[str]]:
returns expect from the `run` method
'''
pass

def reset_history(self):
pass
Empty file.
159 changes: 159 additions & 0 deletions align_system/algorithms/open_world/open_world_dialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import json
from typing import Sequence
from rich.highlighter import JSONHighlighter
from swagger_client.models import KDMAValue
from enum import Enum
from align_system.utils import logging, call_with_coerced_args
from align_system.algorithms.abstracts import ADMComponent
from align_system.data_models.dialog import DialogElement, Dialog

log = logging.getLogger(__name__)
JSON_HIGHLIGHTER = JSONHighlighter()

class HistoryMode(Enum):
in_dialog = "IN_DIALOG"
in_prompt = "IN_PROMPT"
in_prompt_with_reasoning = "IN_PROMPT_WITH_REASONING"
off="OFF"

class BasicOpenWorldDialogADMComponent(ADMComponent):
'''
IMPORTANT: This ADM is not compatible with batch mode LLM calls.
'''
def __init__(self,
structured_inference_engine,
scenario_description_template,
prompt_template,
output_schema_template,
system_prompt_template,
history_mode
):
self.structured_inference_engine = structured_inference_engine
self.scenario_description_template = scenario_description_template
self.prompt_template = prompt_template
self.output_schema_template = output_schema_template
self.system_prompt_template = system_prompt_template
self.history_mode = HistoryMode(history_mode)

self.dialog: Dialog = []

def get_kdma(self, alignment_target):
kdma_values = alignment_target.kdma_values
if len(kdma_values) != 1:
raise RuntimeError("This ADM assumes a single KDMA target, aborting!")
kdma_value = kdma_values[0]
if isinstance(kdma_value, KDMAValue):
kdma_value = kdma_value.to_dict()

kdma = kdma_value['kdma']
value = kdma_value['value']
return kdma, value

def create_system_prompt(self, *, alignment_target, choices, scenario_description):
if self.system_prompt_template is not None:
kdma, value = self.get_kdma(alignment_target) if alignment_target else [None, None]
system_prompt = call_with_coerced_args(
self.system_prompt_template,
{'target_kdma': kdma,
'target_value': value,
"choices": choices,
"scenario_description": scenario_description
})
return str(system_prompt)

def create_user_prompt(self, scenario_state, choices):
scenario_description = call_with_coerced_args(
self.scenario_description_template,
{'scenario_state': scenario_state})

# import web_pdb; web_pdb.set_trace()
def first_if_list(x):
return x[0] if isinstance(x, list) else x
actions = {
HistoryMode.in_dialog: {},
HistoryMode.in_prompt: {'actions': [first_if_list(json.loads(x.content))["action_choice"] for x in self.dialog if x.role == 'assistant']},
HistoryMode.in_prompt_with_reasoning: {'actions': [f'{first_if_list(json.loads(x.content))["action_choice"]}. Reasoning: {first_if_list(json.loads(x.content))["detailed_reasoning"]}' for x in self.dialog if x.role == 'assistant']},
HistoryMode.off: {}
}[self.history_mode]

user_prompt = call_with_coerced_args(
self.prompt_template,
{'scenario_state': scenario_state,
'scenario_description': scenario_description,
'choices': choices,
**actions},
partial=False)

return str(user_prompt)

def run_returns(self):
return ('chosen_choice', 'justification')

def run_sanity_check(self, scenario_state,
choices,
alignment_target):

# Add System Prompt First Time Only
# if len(self.dialog) == 0:
self.dialog.clear()

system_prompt = self.create_system_prompt(alignment_target=alignment_target, choices=choices, scenario_description=scenario_state.unstructured)
if system_prompt:
self.dialog.append(DialogElement(role='system', content=system_prompt))

user_prompt = self.create_user_prompt(scenario_state, choices)
self.dialog.append(DialogElement(role="user", content=str(user_prompt)))

response = self.structured_inference_engine.run_inference(
prompts=self.structured_inference_engine.dialog_to_prompt(self.dialog),
schema=call_with_coerced_args(self.output_schema_template,{'choices': choices})
)

self.dialog.append(DialogElement(role="assistant", content=json.dumps(response)))

if isinstance(response, Sequence):
response = response[0]

chosen_choice = response['action_choice']
justification = response['detailed_reasoning']
return chosen_choice, justification, self.dialog


def run(self, scenario_state,
choices,
alignment_target):

# Add System Prompt First Time Only
if len(self.dialog) == 0:
system_prompt = self.create_system_prompt(alignment_target=alignment_target, choices=choices, scenario_description=scenario_state.unstructured)
if system_prompt:
self.dialog.append(DialogElement(role='system', content=system_prompt))

user_prompt = self.create_user_prompt(scenario_state, choices)
self.dialog.append(DialogElement(role="user", content=str(user_prompt)))

_system_and_latest_prompt = [self.dialog[0], self.dialog[-1]]
mode_adjusted_dialog = {
HistoryMode.in_dialog: self.dialog,
HistoryMode.in_prompt: _system_and_latest_prompt,
HistoryMode.in_prompt_with_reasoning: _system_and_latest_prompt,
HistoryMode.off: _system_and_latest_prompt
}[self.history_mode]

response = self.structured_inference_engine.run_inference(
prompts=self.structured_inference_engine.dialog_to_prompt(mode_adjusted_dialog),
schema=call_with_coerced_args(self.output_schema_template,{'choices': choices})
)

self.dialog.append(DialogElement(role="assistant", content=json.dumps(response)))

if isinstance(response, Sequence):
response = response[0]

chosen_choice = response['action_choice']
justification = response['detailed_reasoning']
return chosen_choice, justification

def reset_history(self):
super().reset_history()
self.dialog.clear()
13 changes: 10 additions & 3 deletions align_system/algorithms/outlines_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(

def dialog_to_prompt(self, dialog):
tokenizer = self.model.tokenizer.tokenizer

# import web_pdb; web_pdb.set_trace()
try:
encoded_dialog = tokenizer.apply_chat_template(dialog)
except jinja2.exceptions.TemplateError:
Expand All @@ -92,7 +92,9 @@ def dialog_to_prompt(self, dialog):
dialog = [{"role": "user", "content": updated_content}, *rest]

encoded_dialog = tokenizer.apply_chat_template(dialog)

with open("prompt.txt", 'a') as f:
f.write(tokenizer.decode(encoded_dialog))
f.write("\n\n------------------------------------------------------------------\n\n")
return tokenizer.decode(encoded_dialog)

# Function borrowed from
Expand Down Expand Up @@ -128,6 +130,11 @@ def run_in_batches(
return outputs

def run_inference(self, prompts, schema):
# print("run inference")
# print(type(prompts))
# print(isinstance(prompts, Iterable))
# # print(json.dumps(json.loads(prompts), indent=2))
# print(prompts)
json_schema = JsonSchema(schema, whitespace_pattern=r"[ ]?")

generator = outlines.Generator(self.model, json_schema)
Expand Down Expand Up @@ -201,7 +208,7 @@ def dialog_to_prompt(self, dialog):
element.role = "input"
elif element.role == "assistant":
element.role = "output"
else:
elif element.role not in ("description", "input", "output"):
raise RuntimeError(f"{element.role} dialog element unrecognized.")

try:
Expand Down
19 changes: 15 additions & 4 deletions align_system/algorithms/pipeline_adm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import deque
from collections.abc import Iterable
from timeit import default_timer as timer

Expand All @@ -8,8 +9,9 @@


class PipelineADM(ActionBasedADM):
def __init__(self, steps: list[ADMComponent]):
def __init__(self, steps: list[ADMComponent], history_window=None):
self.steps = steps
self.history = deque(maxlen=history_window)
Comment thread
ygefen marked this conversation as resolved.

def choose_action(self,
scenario_state,
Expand All @@ -21,15 +23,18 @@ def choose_action(self,
'actions': available_actions,
'alignment_target': alignment_target,
**kwargs}


per_step_timing_stats = []

for i, step in enumerate(self.steps):
# import web_pdb; web_pdb.set_trace()
step_returns = step.run_returns()

start_time = timer()
# Run the step
run_output = call_with_coerced_args(step.run, working_output)
# Run the step, temporarily adding historical working outputs to working_output
working_output_with_history = {**{"history": list(self.history)}, **working_output}
Comment thread
ygefen marked this conversation as resolved.
run_output = call_with_coerced_args(step.run, working_output_with_history)
end_time = timer()

per_step_timing_stats.append(
Expand Down Expand Up @@ -73,5 +78,11 @@ def choose_action(self,

working_output.setdefault('choice_info', {})['per_step_timing_stats'] =\
per_step_timing_stats

self.history.append(working_output)
return working_output['chosen_action'], working_output

def reset_history(self):
self.history.clear()
for step in self.steps:
if hasattr(step, 'reset_history'):
step.reset_history()
31 changes: 31 additions & 0 deletions align_system/algorithms/state_feedback_adm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Dict, Any

from align_system.algorithms.abstracts import ActionBasedADM, ADMComponent

class InMemoryState():
def __init__(self):
self.state = []
def save(self, state: str) -> None:
self.state.append(state)
def retreive(self) -> List[Any]:
return self.state

class FullStateInMemorySaverADM(ADMComponent):
def __init__(self, state: InMemoryState):
self.state = state
def run(self, state):
self.state.save(state)
def run_returns(self):
return ""


class PassthroughStateRetriever(ADMComponent):
def __init__(self, state: InMemoryState):
self.state = state
def run(self) -> Dict[str, List[str]]:
return {"previous_state": self.state.retreive()}
def run_returns(self):
return ["previous_state"]

# hydra to compose saver and retrieve with the same memorystate
# hydra to create the pipeline
31 changes: 31 additions & 0 deletions align_system/configs/adm/open_world_dialog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: open_world_dialog

defaults:
# - /inference_engine@structured_inference_engine: claude_haiku
# - /inference_engine@structured_inference_engine: outlines_structured_greedy
- /inference_engine@structured_inference_engine: outlines_structured_multinomial
- /adm_component/misc@step_definitions.format_choices: itm_format_choices
- /adm_component/history@step_definitions.dialog_builder: dialog_builder
- /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action
- /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info
- _self_




# structured_inference_engine:
# model_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
# structured_inference_engine:
# _target_: align_system.algorithms.outlines_inference_engine.SpectrumTunedInferenceEngine
# model_name: tsor13/spectrum-Qwen3-14B-v1
structured_inference_engine:
_target_: align_system.algorithms.outlines_inference_engine.SpectrumTunedInferenceEngine
model_name: tsor13/spectrum-Llama-3.1-8B-v1

instance:
_target_: align_system.algorithms.pipeline_adm.PipelineADM
steps:
- ${ref:adm.step_definitions.format_choices}
- ${ref:adm.step_definitions.dialog_builder}
- ${ref:adm.step_definitions.ensure_chosen_action}
- ${ref:adm.step_definitions.populate_choice_info}
16 changes: 16 additions & 0 deletions align_system/configs/adm_component/history/dialog_builder.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_target_: align_system.algorithms.open_world.open_world_dialog.BasicOpenWorldDialogADMComponent

structured_inference_engine: ${ref:adm.structured_inference_engine}

scenario_description_template:
_target_: align_system.prompt_engineering.outlines_prompts.DefaultITMScenarioDescription
prompt_template:
_target_: align_system.prompt_engineering.outlines_prompts.DefaultITMPrompt
output_schema_template:
_target_: align_system.prompt_engineering.outlines_prompts.StructuredOutputChoiceSelectionSchema
system_prompt_template:
_target_: align_system.prompt_engineering.outlines_prompts.Phase2BaselinePrompt
# history_mode: "IN_DIALOG"
history_mode: "IN_PROMPT"
# history_mode: "IN_PROMPT_WITH_REASONING"
# history_mode: "OFF"
1 change: 1 addition & 0 deletions align_system/configs/driver/itm_phase2_openworld.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_target_: align_system.drivers.itm_phase2_openworld.ITMPhase2OpenWorldDriver
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_
# Baseline "marble through the system" experiment.
# Uses itm_phase1 driver (filtering disabled) to verify the open-world
# pipeline runs end-to-end before testing the full statefulness hypothesis.
defaults:
- override /adm: open_world_dialog
- override /driver: itm_phase1
- override /interface: input_output_file

interface:
input_output_filepath: scripts/statefulness_poc/custom_input_output.json
state_hydration_domain: open_world

align_to_target: false

driver:
apply_action_filtering: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# @package _global_
defaults:
- override /adm: open_world_dialog
- override /driver: itm_phase2_openworld
- override /interface: input_output_file

interface:
input_output_filepath: scripts/statefulness_poc/custom_input_output.json
state_hydration_domain: open_world

align_to_target: false

adm:
step_definitions:
dialog_builder:
system_prompt_template: null
output_schema_template:
_target_: align_system.prompt_engineering.outlines_prompts.StructuredOutputChoiceSelectionSchema
Loading