-
Notifications
You must be signed in to change notification settings - Fork 5
added ai2thor environment with a method that proposes steps and uses C… #280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from align_system.algorithms.abstracts import ADMComponent | ||
| from align_system.utils import logging | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class ArgmaxAlignmentADMComponent(ADMComponent):
    """
    Alignment step that picks the choice with the highest predicted
    KDMA score, averaged across all attributes and samples.

    Replaces the ADEPT random effects model for domains (like AI2Thor)
    where no calibrated statistical model exists. For single-attribute
    pipelines this reduces to a plain argmax over the LLM's scores.
    """

    def __init__(self, attributes=None):
        # Kept on the instance only; `run` does not read it.
        self.attributes = attributes or {}

    def run_returns(self):
        """Names of the three values returned by `run`, in order."""
        return ("chosen_choice", "best_sample_idx", "alignment_info")

    def run(self, attribute_prediction_scores, alignment_target=None):
        """
        Score every choice and return the highest-scoring one.

        A choice's score is the mean, over its attributes, of each
        attribute's mean sample score. Attributes with no samples are
        ignored; a choice with no usable attributes scores 0.0. Ties go
        to the choice that appears first in the input mapping.

        attribute_prediction_scores: dict[choice_str, dict[kdma, list[float]]]
            A list or tuple holds the samples for one attribute; any
            other value is treated as a single sample.
        alignment_target: accepted but not used by this component.

        Returns (best_choice, 0, alignment_info). The sample index is
        always 0; alignment_info carries this class's name under
        "source" and the per-choice scores under "choice_scores".

        Raises ValueError if attribute_prediction_scores is empty.
        """
        # Fail with a message that names the problem instead of letting
        # max() raise "max() arg is an empty sequence" further down.
        if not attribute_prediction_scores:
            raise ValueError(
                "attribute_prediction_scores is empty; no choice to select"
            )

        choice_totals: dict[str, float] = {}
        for choice, attr_scores in attribute_prediction_scores.items():
            attr_means = []
            for scores in attr_scores.values():
                # Tuples are sample sequences too; wrapping one as a single
                # element would make sum() fail on a tuple operand.
                vals = scores if isinstance(scores, (list, tuple)) else [scores]
                if vals:
                    attr_means.append(sum(vals) / len(vals))
            choice_totals[choice] = (
                sum(attr_means) / len(attr_means) if attr_means else 0.0
            )

        best_choice = max(choice_totals, key=choice_totals.get)

        log.info(f"[ArgmaxAlignment] scores: {choice_totals}")
        log.info(f"[ArgmaxAlignment] chosen: {best_choice}")

        alignment_info = {
            "source": type(self).__name__,
            "choice_scores": choice_totals,
        }

        return best_choice, 0, alignment_info
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import json | ||
|
|
||
| import ollama | ||
|
|
||
| from align_system.algorithms.abstracts import StructuredInferenceEngine | ||
| # from align_system.algorithms.planner_adm.llm_ollama import _extract_json_object, _repair_json | ||
| from align_system.utils import logging | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
| def _extract_json_object(s: str) -> str: | ||
| s = s.strip() | ||
| if s.startswith("```"): | ||
| parts = s.split("```") | ||
| if len(parts) >= 3: | ||
| s = parts[1].strip() | ||
| i = s.find("{") | ||
| j = s.rfind("}") | ||
| if i == -1 or j == -1 or j <= i: | ||
| raise ValueError("No JSON object found") | ||
| return s[i : j + 1] | ||
|
|
||
def _loads_json(s: str) -> dict:
    """
    Extract the JSON object embedded in `s` and parse it.

    Raises ValueError when no object is found or the extracted text is
    not valid JSON (json.JSONDecodeError is a ValueError subclass).
    """
    # The extracted text always starts with "{", so a successful parse
    # yields a dict.
    return json.loads(_extract_json_object(s))
|
|
||
|
|
||
def _repair_json(model: str, bad_text: str, schema_hint: str, num_ctx: int) -> dict:
    """
    Ask the Ollama model to rewrite `bad_text` as valid JSON and parse it.

    model: name of the Ollama model to call.
    bad_text: the earlier output that could not be parsed.
    schema_hint: text included in the prompt to describe the expected shape.
    num_ctx: value passed as the "num_ctx" generation option.

    Returns the parsed object. Raises ValueError if the reply still does
    not contain parseable JSON; errors from the Ollama call propagate.
    """
    prompt = (
        "You output invalid JSON. Fix it.\n"
        "Return ONLY valid JSON, no prose.\n"
        f"Schema hint:\n{schema_hint}\n\n"
        f"Bad output:\n{bad_text}\n"
    )
    # The repair call always uses temperature 0.0, whatever the caller's
    # sampling temperature was.
    resp = ollama.generate(
        model=model,
        prompt=prompt,
        options={"temperature": 0.0, "num_ctx": num_ctx},
    )
    return _loads_json(resp["response"])
|
|
||
|
|
||
|
|
||
class OllamaInferenceEngine(StructuredInferenceEngine):
    """
    StructuredInferenceEngine backed by a local Ollama model.

    Unlike the Outlines engine, output is not grammar-constrained — the
    schema is appended to the prompt as an instruction and the response
    is parsed as JSON with a repair fallback.
    """

    def __init__(
        self,
        model: str = "gpt-oss:20b",
        temperature: float = 0.0,
        num_ctx: int = 8192,
        json_repair_attempts: int = 2,
    ):
        self.model = model
        self.temperature = temperature
        self.num_ctx = num_ctx
        self.json_repair_attempts = json_repair_attempts

    def dialog_to_prompt(self, dialog) -> str:
        """
        Flatten a dialog list into a plain-text prompt for Ollama.

        System messages are prepended as an unlabelled block so they
        land at the top; user/assistant turns follow with role labels.
        Each element may expose `role`/`content` either as attributes
        or as mapping keys.
        """
        system_parts = []
        turn_parts = []

        for elem in dialog:
            role = elem.role if hasattr(elem, "role") else elem["role"]
            content = elem.content if hasattr(elem, "content") else elem["content"]

            if role == "system":
                system_parts.append(content)
            else:
                turn_parts.append(f"[{role.upper()}]\n{content}")

        parts = []
        if system_parts:
            parts.append("\n\n".join(system_parts))
        parts.extend(turn_parts)
        return "\n\n".join(parts)

    def _infer_one(self, full_prompt: str, schema: str, temperature: float) -> dict:
        """Generate once for `full_prompt` and parse the reply, repairing if needed."""
        resp = ollama.generate(
            model=self.model,
            prompt=full_prompt,
            options={"temperature": temperature, "num_ctx": self.num_ctx},
        )
        text = resp["response"]
        log.debug(f"[OllamaInferenceEngine] raw response:\n{text}")

        try:
            return _loads_json(text)
        except Exception as err:
            # Best-effort by design: record why parsing failed, then repair.
            log.debug(f"[OllamaInferenceEngine] initial JSON parse failed: {err}")

        # NOTE(review): every attempt resends the same `text` at temperature
        # 0.0, so later attempts may simply repeat the first — confirm
        # whether more than one attempt is useful.
        for attempt in range(1, self.json_repair_attempts + 1):
            try:
                return _repair_json(self.model, text, schema, self.num_ctx)
            except Exception as err:
                log.debug(
                    f"[OllamaInferenceEngine] repair attempt {attempt} failed: {err}"
                )

        log.warning("[OllamaInferenceEngine] JSON parse failed; returning empty dict")
        return {}

    def run_inference(
        self, prompts, schema: str, temperature: float | None = None
    ) -> list[dict] | dict:
        """
        Run inference for each prompt string and return parsed JSON dicts.

        `schema` is a JSON Schema string appended to each prompt as an
        instruction. Malformed responses trigger up to
        `json_repair_attempts` self-repair calls; if all of them fail
        the result for that prompt is an empty dict.

        `prompts` may be a single string, in which case a single dict is
        returned, or an iterable of strings, in which case a list with
        one dict per prompt is returned. `temperature` overrides the
        engine's default for this call when it is not None.
        """
        schema_instruction = (
            "\n\nRespond with ONLY valid JSON that matches this schema "
            "(no prose, no markdown fences):\n" + schema
        )
        effective_temperature = self.temperature if temperature is None else temperature

        # A bare string must not be iterated character by character; treat
        # it as one prompt and return one result, as the Outlines engine does.
        if isinstance(prompts, str):
            return self._infer_one(
                prompts + schema_instruction, schema, effective_temperature
            )

        return [
            self._infer_one(prompt + schema_instruction, schema, effective_temperature)
            for prompt in prompts
        ]

    def cache_repr(self) -> str:
        """Return a string describing the model, temperature and context size."""
        return (
            f"OllamaInferenceEngine(model={self.model}, "
            f"temperature={self.temperature}, num_ctx={self.num_ctx})"
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,6 +73,11 @@ def __init__( | |
| # newer version of outlines fixes this issue, but we are blocked with the vllm dependency | |
| self.model.tokenizer.is_llama = True | ||
|
|
||
| # If generation_kwargs includes temperature, enable sampling in the model's | ||
| # generation_config so transformers doesn't warn that temperature is invalid. | ||
| if self.generation_kwargs.get("temperature", 0.0) > 0: | ||
| self.model.model.generation_config.do_sample = True | ||
|
|
||
| def dialog_to_prompt(self, dialog): | ||
| tokenizer = self.model.tokenizer.tokenizer | ||
|
|
||
|
|
@@ -127,40 +132,96 @@ def run_in_batches( | |
| outputs.extend(output) | ||
| return outputs | ||
|
|
||
| def run_inference(self, prompts, schema): | ||
| def _parse_json(self, text: str) -> dict: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not really excited about the idea of including this here. Have you been running into JSON validation etc. errors with the outlines inference engine here or? Just curious why this is even needed. |
||
| text = text.strip() | ||
| if text.startswith("```"): | ||
| parts = text.split("```") | ||
| if len(parts) >= 3: | ||
| text = parts[1].strip() | ||
| i = text.find("{") | ||
| j = text.rfind("}") | ||
| if i != -1 and j > i: | ||
| text = text[i:j + 1] | ||
| return json.loads(text) | ||
|
|
||
| def _prompt_based_inference(self, prompts, schema) -> list[dict]: | ||
| """Fallback: append schema as a prompt hint and parse free-text output.""" | ||
| schema_instruction = ( | ||
| "\n\nRespond with ONLY valid JSON matching this example " | ||
| "(no prose, no markdown fences):\n" + schema | ||
| ) | ||
| if isinstance(prompts, str): | ||
| prompts = [prompts] | ||
| generator = outlines.Generator(self.model) | ||
| results = [] | ||
| for prompt in prompts: | ||
| raw = generator( | ||
| prompt + schema_instruction, | ||
| max_new_tokens=self.max_generator_tokens, | ||
| **self.generation_kwargs, | ||
| ) | ||
| try: | ||
| results.append(self._parse_json(raw)) | ||
| except Exception: | ||
| results.append({}) | ||
| return results | ||
|
|
||
| def run_inference(self, prompts, schema, temperature: float = None): | ||
| json_schema = JsonSchema(schema, whitespace_pattern=r"[ ]?") | ||
| try: | ||
| generator = outlines.Generator(self.model, json_schema) | ||
| except (ValueError, Exception): | ||
| # schema is a JSON example/template rather than a proper JSON | ||
| # Schema — fall back to prompt-based generation + parsing | ||
| return self._prompt_based_inference(prompts, schema) | ||
|
|
||
| gen_kwargs = dict(self.generation_kwargs) | ||
| if temperature is not None: | ||
| gen_kwargs["temperature"] = temperature | ||
| gen_kwargs["do_sample"] = True | ||
|
|
||
| generator = outlines.Generator(self.model, json_schema) | ||
| if isinstance(prompts, str): | ||
| output = generator( | ||
| prompts, | ||
| max_new_tokens=self.max_generator_tokens, | ||
| **self.generation_kwargs, | ||
| **gen_kwargs, | ||
| ) | ||
| return json.loads(output) | ||
| try: | ||
| return json.loads(output) | ||
| except json.JSONDecodeError as e: | ||
| raise RuntimeError( | ||
| f"Failed to parse structured generation output as JSON " | ||
| f"(output may be truncated; consider increasing " | ||
| f"max_generator_tokens above {self.max_generator_tokens}). " | ||
| f"Raw output: {output!r}. Original error: {e}" | ||
| ) from e | ||
| elif isinstance(prompts, Iterable): | ||
| output = self.run_in_batches( | ||
| generator.batch, | ||
| prompts, | ||
| self.inference_batch_size, | ||
| self.max_generator_tokens, | ||
| **self.generation_kwargs, | ||
| **gen_kwargs, | ||
| ) | ||
| return [json.loads(r) for r in output] | ||
| try: | ||
| return [json.loads(r) for r in output] | ||
| except json.JSONDecodeError as e: | ||
| raise RuntimeError( | ||
| f"Failed to parse structured generation output as JSON " | ||
| f"(output may be truncated; consider increasing " | ||
| f"max_generator_tokens above {self.max_generator_tokens}). " | ||
| f"Raw output: {output!r}. Original error: {e}" | ||
| ) from e | ||
| else: | ||
| raise TypeError( | ||
| "Don't know how to run inference on provided `prompts` object" | ||
| ) | ||
|
|
||
| def run_inference_unstructured(self, prompts): | ||
| generator = outlines.generate.regex( | ||
| self.model, | ||
| r".*", # "allow anything" regex | ||
| **self.generation_kwargs, | ||
| ) | ||
| generator = outlines.Generator(self.model) | ||
|
|
||
| if isinstance(prompts, str): | ||
| return generator(prompts, self.max_generator_tokens) | ||
| return generator(prompts, max_new_tokens=self.max_generator_tokens, **self.generation_kwargs) | ||
| elif isinstance(prompts, Iterable): | ||
| return self.run_in_batches( | ||
| generator, prompts, self.inference_batch_size, self.max_generator_tokens | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fine to have the history tracking as you have it in here for now, but I'm more inclined to merge Yoni's approach on this: https://github.com/ITM-Kitware/align-system/pull/277/changes#diff-ea512e45fac46d4935ce85a4837bdbcd27b5891a09c1ac5f6aad038076f4d497 As it maintains the full working_output history. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does look like Ollama supports structured output with a JSON schema: https://ollama.com/blog/structured-outputs. Seems like we should use that if possible