Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions align_system/algorithms/argmax_alignment_adm_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from align_system.algorithms.abstracts import ADMComponent
from align_system.utils import logging

log = logging.getLogger(__name__)


class ArgmaxAlignmentADMComponent(ADMComponent):
    """
    Alignment step that picks the choice with the highest predicted
    KDMA score, averaged across all attributes and samples.

    Replaces the ADEPT random effects model for domains (like AI2Thor)
    where no calibrated statistical model exists. For single-attribute
    pipelines this reduces to a plain argmax over the LLM's scores.
    """

    def __init__(self, attributes=None):
        """
        attributes: optional mapping stored on the instance as-is
            (``{}`` when omitted or falsy). ``run`` does not read it.
        """
        self.attributes = attributes or {}

    def run_returns(self):
        """Names of the three values returned by ``run``, in order."""
        return ("chosen_choice", "best_sample_idx", "alignment_info")

    def run(self, attribute_prediction_scores, alignment_target=None):
        """
        Score every choice and return the highest-scoring one.

        attribute_prediction_scores: dict[choice_str, dict[kdma, list[float]]]
            Per-choice, per-KDMA predicted scores. A KDMA's value may be a
            list/tuple of samples or a single bare number.
        alignment_target: accepted but not used by this component.

        Each KDMA's samples are averaged, then those per-KDMA means are
        averaged into one score per choice. KDMAs with no samples are
        skipped; a choice with no usable KDMAs scores 0.0. Ties go to the
        choice that appears first in the input mapping.

        Returns (chosen_choice, best_sample_idx, alignment_info), where
        best_sample_idx is always 0 and alignment_info carries this
        component's class name plus the per-choice scores.

        Raises ValueError if attribute_prediction_scores is empty or None.
        """
        if not attribute_prediction_scores:
            # Without this guard the failure surfaces inside max() as
            # "max() arg is an empty sequence", which hides the real cause.
            raise ValueError(
                "ArgmaxAlignmentADMComponent.run requires at least one "
                "choice in attribute_prediction_scores"
            )

        choice_totals: dict[str, float] = {}

        for choice, attr_scores in attribute_prediction_scores.items():
            total = 0.0
            count = 0
            for scores in attr_scores.values():
                # Accept tuples as sample containers too; anything else is
                # treated as a single sample.
                vals = list(scores) if isinstance(scores, (list, tuple)) else [scores]
                if vals:
                    total += sum(vals) / len(vals)
                    count += 1
            choice_totals[choice] = total / count if count else 0.0

        # max() returns the first maximal key, so ties resolve by input order.
        best_choice = max(choice_totals, key=choice_totals.get)

        log.info(f"[ArgmaxAlignment] scores: {choice_totals}")
        log.info(f"[ArgmaxAlignment] chosen: {best_choice}")

        alignment_info = {
            "source": type(self).__name__,
            "choice_scores": choice_totals,
        }

        return best_choice, 0, alignment_info
6 changes: 3 additions & 3 deletions align_system/algorithms/misc_itm_adm_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def run_returns(self):
return ('chosen_action')

def run(self,
choices,
actions,
choices=None,
chosen_choice=None,
chosen_action=None,
justification=None):
Expand Down Expand Up @@ -76,8 +76,8 @@ def run_returns(self):
return 'choice_info'

def run(self,
choices,
actions,
choices=None,
alignment_target=None,
attribute_prediction_scores=None,
attribute_relevance=None,
Expand Down Expand Up @@ -105,7 +105,7 @@ def run(self,

true_kdma_values = {}
true_relevance = {}
for choice, action in zip(choices, actions):
for choice, action in zip(choices or [], actions):
if action.kdma_association is not None:
true_kdma_values[choice] = action.kdma_association
for kdma in target_kdmas:
Expand Down
135 changes: 135 additions & 0 deletions align_system/algorithms/ollama_inference_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

import json

import ollama

from align_system.algorithms.abstracts import StructuredInferenceEngine
# from align_system.algorithms.planner_adm.llm_ollama import _extract_json_object, _repair_json
from align_system.utils import logging

log = logging.getLogger(__name__)

def _extract_json_object(s: str) -> str:
    """
    Return the substring of *s* spanning the first ``{`` to the last ``}``.

    If the stripped text opens with a Markdown code fence that is also
    closed, only the content of that first fenced section is searched.
    The result is a candidate JSON object; it is not validated here.

    Raises ValueError when no ``{`` ... ``}`` span can be found.
    """
    text = s.strip()
    if text.startswith("```"):
        segments = text.split("```")
        # A closed fence splits into at least: "", fenced body, remainder.
        if len(segments) >= 3:
            text = segments[1].strip()

    start, end = text.find("{"), text.rfind("}")
    # `end <= start` also covers a missing "}" (end == -1) once start >= 0.
    if start < 0 or end <= start:
        raise ValueError("No JSON object found")
    return text[start:end + 1]

def _loads_json(s: str) -> dict:
    """
    Parse the JSON object embedded in *s*.

    The text is first narrowed to its outermost ``{...}`` span (dropping
    any surrounding prose or code fence) and then decoded.

    Raises ValueError if no object span is found or the span is not valid
    JSON (json.JSONDecodeError is a ValueError subclass).
    """
    # The return annotation was previously `JSON`, a name this module never
    # defines or imports; the extracted span always decodes to a dict.
    return json.loads(_extract_json_object(s))


def _repair_json(model: str, bad_text: str, schema_hint: str, num_ctx: int) -> dict:
    """
    Ask the model to rewrite its own malformed output as valid JSON.

    model: name of the Ollama model to query.
    bad_text: the unparseable response to be fixed.
    schema_hint: schema text included in the prompt as guidance.
    num_ctx: context window size passed through to Ollama.

    Returns the parsed JSON object from the repair response.

    Raises ValueError if the repair response still contains no parseable
    JSON object. Errors raised by the Ollama call propagate unchanged.
    """
    # The return annotation was previously `JSON`, a name this module never
    # defines or imports.
    prompt = (
        "You output invalid JSON. Fix it.\n"
        "Return ONLY valid JSON, no prose.\n"
        f"Schema hint:\n{schema_hint}\n\n"
        f"Bad output:\n{bad_text}\n"
    )
    # Temperature is pinned to 0.0 for the repair call regardless of the
    # temperature used for the original generation.
    resp = ollama.generate(model=model, prompt=prompt, options={"temperature": 0.0, "num_ctx": num_ctx})
    return _loads_json(resp["response"])



class OllamaInferenceEngine(StructuredInferenceEngine):
"""
StructuredInferenceEngine backed by a local Ollama model.

Unlike the Outlines engine, output is not grammar-constrained — the
schema is appended to the prompt as an instruction and the response
is parsed as JSON with a repair fallback.
"""
Comment on lines +45 to +48

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does look like Ollama supports structured output with a JSON schema: https://ollama.com/blog/structured-outputs. Seems like we should use that if possible


def __init__(
    self,
    model: str = "gpt-oss:20b",
    temperature: float = 0.0,
    num_ctx: int = 8192,
    json_repair_attempts: int = 2,
):
    """
    Store the generation settings for later inference calls.

    model: Ollama model name.
    temperature: default sampling temperature (run_inference may
        override it per call).
    num_ctx: context window size sent with each request.
    json_repair_attempts: maximum number of repair calls made when a
        response cannot be parsed as JSON.
    """
    # Request settings, read back by run_inference and cache_repr.
    self.model, self.temperature, self.num_ctx = model, temperature, num_ctx
    # Bounds the repair loop in run_inference.
    self.json_repair_attempts = json_repair_attempts

def dialog_to_prompt(self, dialog) -> str:
    """
    Render a dialog as one plain-text prompt for Ollama.

    Each element may expose ``role``/``content`` as attributes or as
    mapping keys. All system messages are gathered, in order, into an
    unlabelled block at the top; every other turn follows in its
    original order, prefixed with its upper-cased role in brackets.
    Sections are separated by a blank line.
    """
    def _field(message, name):
        # Attribute access wins; otherwise fall back to item lookup.
        return getattr(message, name) if hasattr(message, name) else message[name]

    system_blocks, labelled_turns = [], []
    for message in dialog:
        role = _field(message, "role")
        content = _field(message, "content")
        if role != "system":
            labelled_turns.append(f"[{role.upper()}]\n{content}")
        else:
            system_blocks.append(content)

    # The system block is omitted entirely when there are no system messages.
    preamble = ["\n\n".join(system_blocks)] if system_blocks else []
    return "\n\n".join(preamble + labelled_turns)

def run_inference(self, prompts, schema: str, temperature: float | None = None) -> list[dict] | dict:
    """
    Run inference for each prompt string and return parsed JSON dicts.

    prompts: an iterable of prompt strings, or a single prompt string.
    schema: a JSON Schema string appended to each prompt as an
        instruction (output is not grammar-constrained).
    temperature: per-call override; ``self.temperature`` is used when
        this is None.

    Returns a list with one dict per prompt, or a single dict when
    `prompts` was a single string (mirroring the Outlines engine's
    handling of a bare string). Malformed responses trigger up to
    `json_repair_attempts` self-repair calls; if none succeeds the
    result for that prompt is an empty dict.
    """
    schema_instruction = (
        "\n\nRespond with ONLY valid JSON that matches this schema "
        "(no prose, no markdown fences):\n" + schema
    )
    effective_temperature = self.temperature if temperature is None else temperature

    # A bare string would otherwise be iterated character by character,
    # issuing one model call per character.
    single_prompt = isinstance(prompts, str)
    if single_prompt:
        prompts = [prompts]

    results = []
    for prompt in prompts:
        resp = ollama.generate(
            model=self.model,
            prompt=prompt + schema_instruction,
            options={"temperature": effective_temperature, "num_ctx": self.num_ctx},
        )
        text = resp["response"]
        log.debug(f"[OllamaInferenceEngine] raw response:\n{text}")

        data = None
        try:
            data = json.loads(_extract_json_object(text))
        except Exception as parse_err:
            log.debug(f"[OllamaInferenceEngine] initial JSON parse failed: {parse_err}")
            for attempt in range(self.json_repair_attempts):
                try:
                    data = _repair_json(self.model, text, schema, self.num_ctx)
                    break
                except Exception as repair_err:
                    # Deliberately best-effort: a failed repair (bad JSON
                    # again, or a failed model call) must not abort the
                    # batch, but it should leave a trace.
                    log.debug(
                        f"[OllamaInferenceEngine] repair attempt {attempt + 1} "
                        f"failed: {repair_err}"
                    )

        if data is None:
            log.warning("[OllamaInferenceEngine] JSON parse failed; returning empty dict")
            data = {}

        results.append(data)

    return results[0] if single_prompt else results

def cache_repr(self) -> str:
    """
    Return a string identifying this engine's configuration.

    Includes the model name, default temperature and context size;
    `json_repair_attempts` is not part of the string.
    """
    fields = ", ".join(
        [
            f"model={self.model}",
            f"temperature={self.temperature}",
            f"num_ctx={self.num_ctx}",
        ]
    )
    return f"OllamaInferenceEngine({fields})"
85 changes: 73 additions & 12 deletions align_system/algorithms/outlines_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def __init__(
# newer version of outlines fixes this issue, but we are blocked with the vllm dependency
self.model.tokenizer.is_llama = True

# If generation_kwargs includes temperature, enable sampling in the model's
# generation_config so transformers doesn't warn that temperature is invalid.
if self.generation_kwargs.get("temperature", 0.0) > 0:
self.model.model.generation_config.do_sample = True

def dialog_to_prompt(self, dialog):
tokenizer = self.model.tokenizer.tokenizer

Expand Down Expand Up @@ -127,40 +132,96 @@ def run_in_batches(
outputs.extend(output)
return outputs

def run_inference(self, prompts, schema):
def _parse_json(self, text: str) -> dict:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really excited about the idea of including this here. Have you been running into JSON validation (or similar) errors with the Outlines inference engine? Just curious why this is even needed.

text = text.strip()
if text.startswith("```"):
parts = text.split("```")
if len(parts) >= 3:
text = parts[1].strip()
i = text.find("{")
j = text.rfind("}")
if i != -1 and j > i:
text = text[i:j + 1]
return json.loads(text)

def _prompt_based_inference(self, prompts, schema) -> list[dict]:
    """
    Fallback path: generate free text and parse it as JSON.

    The schema text is appended to every prompt as an instruction, an
    unconstrained generator produces the completion, and the completion
    is parsed with ``self._parse_json``. A single prompt string is
    treated as a one-element batch.

    Always returns a list with one entry per prompt; an entry is ``{}``
    when that prompt's output could not be parsed.
    """
    suffix = (
        "\n\nRespond with ONLY valid JSON matching this example "
        "(no prose, no markdown fences):\n" + schema
    )
    batch = [prompts] if isinstance(prompts, str) else prompts
    free_text_generator = outlines.Generator(self.model)

    parsed_outputs = []
    for single_prompt in batch:
        completion = free_text_generator(
            single_prompt + suffix,
            max_new_tokens=self.max_generator_tokens,
            **self.generation_kwargs,
        )
        try:
            parsed = self._parse_json(completion)
        except Exception:
            # Best-effort: an unparseable completion becomes an empty dict.
            parsed = {}
        parsed_outputs.append(parsed)
    return parsed_outputs

def run_inference(self, prompts, schema, temperature: float = None):
json_schema = JsonSchema(schema, whitespace_pattern=r"[ ]?")
try:
generator = outlines.Generator(self.model, json_schema)
except (ValueError, Exception):
# schema is a JSON example/template rather than a proper JSON
# Schema — fall back to prompt-based generation + parsing
return self._prompt_based_inference(prompts, schema)

gen_kwargs = dict(self.generation_kwargs)
if temperature is not None:
gen_kwargs["temperature"] = temperature
gen_kwargs["do_sample"] = True

generator = outlines.Generator(self.model, json_schema)
if isinstance(prompts, str):
output = generator(
prompts,
max_new_tokens=self.max_generator_tokens,
**self.generation_kwargs,
**gen_kwargs,
)
return json.loads(output)
try:
return json.loads(output)
except json.JSONDecodeError as e:
raise RuntimeError(
f"Failed to parse structured generation output as JSON "
f"(output may be truncated; consider increasing "
f"max_generator_tokens above {self.max_generator_tokens}). "
f"Raw output: {output!r}. Original error: {e}"
) from e
elif isinstance(prompts, Iterable):
output = self.run_in_batches(
generator.batch,
prompts,
self.inference_batch_size,
self.max_generator_tokens,
**self.generation_kwargs,
**gen_kwargs,
)
return [json.loads(r) for r in output]
try:
return [json.loads(r) for r in output]
except json.JSONDecodeError as e:
raise RuntimeError(
f"Failed to parse structured generation output as JSON "
f"(output may be truncated; consider increasing "
f"max_generator_tokens above {self.max_generator_tokens}). "
f"Raw output: {output!r}. Original error: {e}"
) from e
else:
raise TypeError(
"Don't know how to run inference on provided `prompts` object"
)

def run_inference_unstructured(self, prompts):
generator = outlines.generate.regex(
self.model,
r".*", # "allow anything" regex
**self.generation_kwargs,
)
generator = outlines.Generator(self.model)

if isinstance(prompts, str):
return generator(prompts, self.max_generator_tokens)
return generator(prompts, max_new_tokens=self.max_generator_tokens, **self.generation_kwargs)
elif isinstance(prompts, Iterable):
return self.run_in_batches(
generator, prompts, self.inference_batch_size, self.max_generator_tokens
Expand Down
12 changes: 12 additions & 0 deletions align_system/algorithms/pipeline_adm.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to have the history tracking as you have it in here for now, but I'm more inclined to merge Yoni's approach on this: https://github.com/ITM-Kitware/align-system/pull/277/changes#diff-ea512e45fac46d4935ce85a4837bdbcd27b5891a09c1ac5f6aad038076f4d497

As it maintains the full working_output history.

Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,15 @@ def choose_action(self,
per_step_timing_stats

return working_output['chosen_action'], working_output

def reset_history(self) -> None:
    """
    Call ``reset_history()`` on every pipeline step that defines it.

    Steps without a ``reset_history`` attribute are skipped; steps are
    visited in pipeline order.
    """
    # Generator (not a list) so each step is checked just before it is called.
    history_steps = (s for s in self.steps if hasattr(s, 'reset_history'))
    for history_step in history_steps:
        history_step.reset_history()

def update_history(self, chosen_action) -> None:
    """
    Pass ``chosen_action`` to every pipeline step that defines
    ``update_history``.

    Steps without an ``update_history`` attribute are skipped; steps are
    visited in pipeline order.
    """
    # Generator (not a list) so each step is checked just before it is called.
    history_steps = (s for s in self.steps if hasattr(s, 'update_history'))
    for history_step in history_steps:
        history_step.update_history(chosen_action)
Loading