diff --git a/align_system/algorithms/argmax_alignment_adm_component.py b/align_system/algorithms/argmax_alignment_adm_component.py new file mode 100644 index 00000000..6651057c --- /dev/null +++ b/align_system/algorithms/argmax_alignment_adm_component.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from align_system.algorithms.abstracts import ADMComponent +from align_system.utils import logging + +log = logging.getLogger(__name__) + + +class ArgmaxAlignmentADMComponent(ADMComponent): + """ + Alignment step that picks the choice with the highest predicted + KDMA score, averaged across all attributes and samples. + + Replaces the ADEPT random effects model for domains (like AI2Thor) + where no calibrated statistical model exists. For single-attribute + pipelines this reduces to a plain argmax over the LLM's scores. + """ + + def __init__(self, attributes=None): + self.attributes = attributes or {} + + def run_returns(self): + return ("chosen_choice", "best_sample_idx", "alignment_info") + + def run(self, attribute_prediction_scores, alignment_target=None): + """ + attribute_prediction_scores: dict[choice_str, dict[kdma, list[float]]] + """ + choice_totals: dict[str, float] = {} + + for choice, attr_scores in attribute_prediction_scores.items(): + total = 0.0 + count = 0 + for kdma, scores in attr_scores.items(): + vals = scores if isinstance(scores, list) else [scores] + if vals: + total += sum(vals) / len(vals) + count += 1 + choice_totals[choice] = total / count if count else 0.0 + + best_choice = max(choice_totals, key=choice_totals.get) + + log.info(f"[ArgmaxAlignment] scores: {choice_totals}") + log.info(f"[ArgmaxAlignment] chosen: {best_choice}") + + alignment_info = { + "source": type(self).__name__, + "choice_scores": choice_totals, + } + + return best_choice, 0, alignment_info diff --git a/align_system/algorithms/misc_itm_adm_components.py b/align_system/algorithms/misc_itm_adm_components.py index a2c6d24f..84015c12 100644 --- a/align_system/algorithms/misc_itm_adm_components.py +++ b/align_system/algorithms/misc_itm_adm_components.py @@ -15,8 +15,8 @@ def run_returns(self): return ('chosen_action') def run(self, - choices, actions, + choices=None, chosen_choice=None, chosen_action=None, justification=None): @@ -76,8 +76,8 @@ def run_returns(self): return 'choice_info' def run(self, - choices, actions, + choices=None, alignment_target=None, attribute_prediction_scores=None, attribute_relevance=None, @@ -105,7 +105,7 @@ def run(self, true_kdma_values = {} true_relevance = {} - for choice, action in zip(choices, actions): + for choice, action in zip(choices or [], actions): if action.kdma_association is not None: true_kdma_values[choice] = action.kdma_association for kdma in target_kdmas: diff --git a/align_system/algorithms/ollama_inference_engine.py b/align_system/algorithms/ollama_inference_engine.py new file mode 100644 index 00000000..9942ea4d --- /dev/null +++ b/align_system/algorithms/ollama_inference_engine.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import json + +import ollama + +from align_system.algorithms.abstracts import StructuredInferenceEngine +# from align_system.algorithms.planner_adm.llm_ollama import _extract_json_object, _repair_json +from align_system.utils import logging + +log = logging.getLogger(__name__) + +def _extract_json_object(s: str) -> str: + s = s.strip() + if s.startswith("```"): + parts = s.split("```") + if len(parts) >= 3: + s = parts[1].strip() + i = s.find("{") + j = s.rfind("}") + if i == -1 or j == -1 or j <= i: + raise ValueError("No JSON object found") + return s[i : j + 1] + +def _loads_json(s: str) -> JSON: + return json.loads(_extract_json_object(s)) + + +def _repair_json(model: str, bad_text: str, schema_hint: str, num_ctx: int) -> JSON: + prompt = ( + "You output invalid JSON. Fix it.\n" + "Return ONLY valid JSON, no prose.\n" + f"Schema hint:\n{schema_hint}\n\n" + f"Bad output:\n{bad_text}\n" + ) + resp = ollama.generate(model=model, prompt=prompt, options={"temperature": 0.0, "num_ctx": num_ctx}) + return _loads_json(resp["response"]) + + + +class OllamaInferenceEngine(StructuredInferenceEngine): + """ + StructuredInferenceEngine backed by a local Ollama model. + + Unlike the Outlines engine, output is not grammar-constrained — the + schema is appended to the prompt as an instruction and the response + is parsed as JSON with a repair fallback. + """ + + def __init__( + self, + model: str = "gpt-oss:20b", + temperature: float = 0.0, + num_ctx: int = 8192, + json_repair_attempts: int = 2, + ): + self.model = model + self.temperature = temperature + self.num_ctx = num_ctx + self.json_repair_attempts = json_repair_attempts + + def dialog_to_prompt(self, dialog) -> str: + """ + Flatten a dialog list into a plain-text prompt for Ollama. + + System messages are prepended as an unlabelled block so they + land at the top; user/assistant turns follow with role labels. + """ + system_parts = [] + turn_parts = [] + + for elem in dialog: + role = elem.role if hasattr(elem, "role") else elem["role"] + content = elem.content if hasattr(elem, "content") else elem["content"] + + if role == "system": + system_parts.append(content) + else: + turn_parts.append(f"[{role.upper()}]\n{content}") + + parts = [] + if system_parts: + parts.append("\n\n".join(system_parts)) + parts.extend(turn_parts) + return "\n\n".join(parts) + + def run_inference(self, prompts, schema: str, temperature: float = None) -> list[dict]: + """ + Run inference for each prompt string and return parsed JSON dicts. + + `schema` is a JSON Schema string appended to each prompt as an + instruction. Malformed responses trigger up to + `json_repair_attempts` self-repair calls. + """ + schema_instruction = ( + "\n\nRespond with ONLY valid JSON that matches this schema " + "(no prose, no markdown fences):\n" + schema + ) + effective_temperature = self.temperature if temperature is None else temperature + + results = [] + for prompt in prompts: + full_prompt = prompt + schema_instruction + resp = ollama.generate( + model=self.model, + prompt=full_prompt, + options={"temperature": effective_temperature, "num_ctx": self.num_ctx}, + ) + text = resp["response"] + log.debug(f"[OllamaInferenceEngine] raw response:\n{text}") + + data = None + try: + data = json.loads(_extract_json_object(text)) + except Exception: + for _ in range(self.json_repair_attempts): + try: + data = _repair_json(self.model, text, schema, self.num_ctx) + break + except Exception: + data = None + + if data is None: + log.warning("[OllamaInferenceEngine] JSON parse failed; returning empty dict") + data = {} + + results.append(data) + + return results + + def cache_repr(self) -> str: + return ( + f"OllamaInferenceEngine(model={self.model}, " + f"temperature={self.temperature}, num_ctx={self.num_ctx})" + ) diff --git a/align_system/algorithms/outlines_inference_engine.py b/align_system/algorithms/outlines_inference_engine.py index 097e8fa0..2d588059 100644 --- a/align_system/algorithms/outlines_inference_engine.py +++ b/align_system/algorithms/outlines_inference_engine.py @@ -73,6 +73,11 @@ def __init__( # newer verion of outlines fixes this issue, but we are blocked with the vllm dependency self.model.tokenizer.is_llama = True + # If generation_kwargs includes temperature, enable sampling in the model's + # generation_config so transformers doesn't warn that temperature is invalid. + if self.generation_kwargs.get("temperature", 0.0) > 0: + self.model.model.generation_config.do_sample = True + def dialog_to_prompt(self, dialog): tokenizer = self.model.tokenizer.tokenizer @@ -127,40 +132,96 @@ def run_in_batches( outputs.extend(output) return outputs - def run_inference(self, prompts, schema): + def _parse_json(self, text: str) -> dict: + text = text.strip() + if text.startswith("```"): + parts = text.split("```") + if len(parts) >= 3: + text = parts[1].strip() + i = text.find("{") + j = text.rfind("}") + if i != -1 and j > i: + text = text[i:j + 1] + return json.loads(text) + + def _prompt_based_inference(self, prompts, schema) -> list[dict]: + """Fallback: append schema as a prompt hint and parse free-text output.""" + schema_instruction = ( + "\n\nRespond with ONLY valid JSON matching this example " + "(no prose, no markdown fences):\n" + schema + ) + if isinstance(prompts, str): + prompts = [prompts] + generator = outlines.Generator(self.model) + results = [] + for prompt in prompts: + raw = generator( + prompt + schema_instruction, + max_new_tokens=self.max_generator_tokens, + **self.generation_kwargs, + ) + try: + results.append(self._parse_json(raw)) + except Exception: + results.append({}) + return results + + def run_inference(self, prompts, schema, temperature: float = None): json_schema = JsonSchema(schema, whitespace_pattern=r"[ ]?") + try: + generator = outlines.Generator(self.model, json_schema) + except (ValueError, Exception): + # schema is a JSON example/template rather than a proper JSON + # Schema — fall back to prompt-based generation + parsing + return self._prompt_based_inference(prompts, schema) + + gen_kwargs = dict(self.generation_kwargs) + if temperature is not None: + gen_kwargs["temperature"] = temperature + gen_kwargs["do_sample"] = True - generator = outlines.Generator(self.model, json_schema) if isinstance(prompts, str): output = generator( prompts, max_new_tokens=self.max_generator_tokens, - **self.generation_kwargs, + **gen_kwargs, ) - return json.loads(output) + try: + return json.loads(output) + except json.JSONDecodeError as e: + raise RuntimeError( + f"Failed to parse structured generation output as JSON " + f"(output may be truncated; consider increasing " + f"max_generator_tokens above {self.max_generator_tokens}). " + f"Raw output: {output!r}. Original error: {e}" + ) from e elif isinstance(prompts, Iterable): output = self.run_in_batches( generator.batch, prompts, self.inference_batch_size, self.max_generator_tokens, - **self.generation_kwargs, + **gen_kwargs, ) - return [json.loads(r) for r in output] + try: + return [json.loads(r) for r in output] + except json.JSONDecodeError as e: + raise RuntimeError( + f"Failed to parse structured generation output as JSON " + f"(output may be truncated; consider increasing " + f"max_generator_tokens above {self.max_generator_tokens}). " + f"Raw output: {output!r}. Original error: {e}" + ) from e else: raise TypeError( "Don't know how to run inference on provided `prompts` object" ) def run_inference_unstructured(self, prompts): - generator = outlines.generate.regex( - self.model, - r".*", # "allow anything" regex - **self.generation_kwargs, - ) + generator = outlines.Generator(self.model) if isinstance(prompts, str): - return generator(prompts, self.max_generator_tokens) + return generator(prompts, max_new_tokens=self.max_generator_tokens, **self.generation_kwargs) elif isinstance(prompts, Iterable): return self.run_in_batches( generator, prompts, self.inference_batch_size, self.max_generator_tokens diff --git a/align_system/algorithms/pipeline_adm.py b/align_system/algorithms/pipeline_adm.py index 95b56a46..93fd4bd7 100644 --- a/align_system/algorithms/pipeline_adm.py +++ b/align_system/algorithms/pipeline_adm.py @@ -75,3 +75,15 @@ def choose_action(self, per_step_timing_stats return working_output['chosen_action'], working_output + + def reset_history(self) -> None: + """Delegate to any pipeline step that maintains its own history.""" + for step in self.steps: + if hasattr(step, 'reset_history'): + step.reset_history() + + def update_history(self, chosen_action) -> None: + """Delegate to any pipeline step that maintains its own history.""" + for step in self.steps: + if hasattr(step, 'update_history'): + step.update_history(chosen_action) diff --git a/align_system/algorithms/proposer_adm_component.py b/align_system/algorithms/proposer_adm_component.py new file mode 100644 index 00000000..7a0c6fea --- /dev/null +++ b/align_system/algorithms/proposer_adm_component.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from typing import List, Optional +from align_system.utils import logging +from align_system.algorithms.abstracts import ADMComponent +#from align_system.algorithms.planner_adm.llm_ollama import OllamaAI2ThorProposer, OllamaConfig +from align_system.data_models.types import Action as PlannerAction, ToolSpec +from align_system.interfaces.ai2thor_interface import AI2ThorAction +from align_system.utils import logging +from align_system.data_models.dialog import DialogElement +log = logging.getLogger(__name__) + + +class ProposerGeneratorAgent(ADMComponent): + """ + Pipeline step that narrows the full AI2Thor tool set down to a + small number of semantically motivated candidates before handing + off to comparative regression. + + The planner proposer LLM is called once per step to generate + `num_candidates` candidate actions with rationales. The rationale + is embedded in the action's `unstructured` field so that + comparative regression sees it as part of the choice description. + + Action history is maintained across `run()` calls within a scenario + and cleared by `reset_history()`. After comparative regression + picks a winner, `update_history()` should be called so future + proposals avoid repeating the same action. + """ + + def __init__( + self, + structured_inference_engine, + num_candidates: int = 3, + rollout_horizon: int = 3, + inference_temperature: Optional[float] = None, + ): + self.structured_inference_engine = structured_inference_engine + self.num_candidates = num_candidates + self.rollout_horizon = rollout_horizon + self.inference_temperature = inference_temperature + self._history: List[PlannerAction] = [] + self._pending_tool: Optional[str] = None + + # ------------------------------------------------------------------ + # History management (called by PipelineADM) + # ------------------------------------------------------------------ + + def reset_history(self) -> None: + self._history = [] + self._pending_tool = None + + def update_history(self, chosen_action) -> None: + """Record the action chosen by downstream alignment so the next + proposal round knows what was already tried.""" + if chosen_action is None: + return + plan = getattr(chosen_action, "plan", None) + if plan: + self._history.extend(plan) + else: + tool_name = ( + chosen_action.action_id + if hasattr(chosen_action, "action_id") + else str(chosen_action) + ) + args = getattr(chosen_action, "args", {}) or {} + self._history.append(PlannerAction(tool_name=tool_name, args=args)) + + # ------------------------------------------------------------------ + # ADMComponent interface + # ------------------------------------------------------------------ + + def run_returns(self): + return "actions" + + def run(self, scenario_state, actions: List[AI2ThorAction]) -> List[AI2ThorAction]: + tool_map = {a.action_id: a for a in actions} + + tools = [ + ToolSpec( + name=a.action_id, + description=a.unstructured, + json_schema={"type": "object", "properties": {}, "required": []}, + ) + for a in actions + ] + + # candidates = self.proposer.propose( + # task=scenario_state.unstructured, + # obs=obs, + # tools=tools, + # action_history=self._history, + # k=self.num_candidates, + # diversity_hint="Vary tool choice; include at least one exploration move.", + # ) + + tool_lines = "\n".join(f"- {t.name}: {t.description}" for t in tools) + history_lines = ( + "\n".join(f"- {a.tool_name}({a.args})" for a in self._history) + if self._history else "None" + ) + predict_proposer_prompt = ( + f"Task: {scenario_state.unstructured}\n\n" + f"Available tools:\n{tool_lines}\n\n" + f"Action history:\n{history_lines}\n\n" + f"Generate {self.num_candidates} diverse candidate plans." + ) + + score_schema = ( + '{"candidates":[{"actions":[{"tool_name":"MoveAhead","args":{"moveMagnitude":0.25}}],' + '"rationale":"..."}]}' + ) + + prompt_system = ("You are an embodied planning model.\n" + "Return ONLY valid JSON. No extra text.\n" + f"Generate {self.num_candidates} semi-diverse candidate plans.\n") + prompt = ( + f"You are an embodied planning model.\n" + "Return ONLY valid JSON. No extra text.\n" + f"Generate {self.num_candidates} diverse candidate plans.\n" + f"- Each plan is 1 to {self.rollout_horizon} actions.\n" + f"- Use ONLY the tool names provided.\n" + f"- Args MUST satisfy each tool schema.\n" + f"- IMPORTANT objectId rule: For tools requiring objectId (TeleportNearObject, PickupObject, " + f"OpenObject, CloseObject, ToggleObjectOn/Off), you MUST copy the exact full objectId string " + "from the observation's visible lines (the value after 'id='). " + "Never use object type names like 'Apple' as objectId. Full objectIds contain '|' characters.\n" + "- Avoid repeating the same last action unless clearly helpful.\n" + ) + + + dialog = [] + dialog.insert(0,DialogElement(content=prompt_system, role="system")) + dialog.append(DialogElement(content=prompt, role="user")) + dialog.append(DialogElement(content=predict_proposer_prompt, role="user")) + dialog_prompt = self.structured_inference_engine.dialog_to_prompt(dialog) + log.info("[bold]*PROMPT FOR PROPOSER*[/bold]", + extra={"markup": True}) + log.info(dialog_prompt) + response = self.structured_inference_engine.run_inference( + [dialog_prompt], score_schema, temperature=0.7)[0] + candidates = response.get("candidates", []) if isinstance(response, dict) else [] + + log.info(f"[PlannerCandidateGenerator] proposed {len(candidates)} candidates") + + candidate_actions: List[AI2ThorAction] = [] + seen: set = set() + + for cand in candidates[: self.num_candidates]: + cand_actions = cand.get("actions", []) if isinstance(cand, dict) else [] + if not cand_actions: + continue + planner_action = cand_actions[0] + if isinstance(planner_action, dict): + tool_name = planner_action.get("tool_name", "") + elif isinstance(planner_action, str): + tool_name = planner_action + else: + tool_name = planner_action.tool_name + + if tool_name not in tool_map: + log.warning(f"[PlannerCandidateGenerator] unknown tool '{tool_name}', skipping") + continue + + if isinstance(planner_action, dict): + args = planner_action.get("args") or {} + elif isinstance(planner_action, str): + args = {} + else: + args = planner_action.args or {} + + dedup_key = (tool_name, frozenset((k, str(v)) for k, v in args.items())) + if dedup_key in seen: + continue + seen.add(dedup_key) + + rationale = (cand.get("rationale", "") if isinstance(cand, dict) else cand.rationale).strip() + + plan = [] + for a in cand_actions: + if isinstance(a, dict): + plan.append(PlannerAction(tool_name=a.get("tool_name", ""), args=a.get("args") or {})) + elif isinstance(a, str): + plan.append(PlannerAction(tool_name=a, args={})) + else: + plan.append(PlannerAction(tool_name=a.tool_name, args=a.args or {})) + action_sequence = " -> ".join(a.tool_name for a in plan) + label = f"{action_sequence}: {rationale[:80]}" if rationale else action_sequence + candidate_actions.append( + AI2ThorAction( + action_id=tool_name, + unstructured=label, + args=args, + justification=rationale, + plan=plan, + ) + ) + + # Fallback: if proposer returned nothing useful, use first N actions + if not candidate_actions: + log.warning("[PlannerCandidateGenerator] no valid candidates; falling back to first N actions") + candidate_actions = [ + AI2ThorAction( + action_id=a.action_id, + unstructured=a.unstructured, + args={}, + ) + for a in actions[: self.num_candidates] + ] + + log.info( + "[PlannerCandidateGenerator] candidates: " + + ", ".join(a.action_id for a in candidate_actions) + ) + return candidate_actions diff --git a/align_system/configs/adm/pipeline_planner_comp_reg_ai2thor.yaml b/align_system/configs/adm/pipeline_planner_comp_reg_ai2thor.yaml new file mode 100644 index 00000000..3c998238 --- /dev/null +++ b/align_system/configs/adm/pipeline_planner_comp_reg_ai2thor.yaml @@ -0,0 +1,39 @@ +name: pipeline_planner_comp_reg_ai2thor + +defaults: + - /attribute@tp: task_progress + # Inference engine (shared by proposer and comparative regression) + - /inference_engine@structured_inference_engine: ollama_greedy + # Templates + - /template/scenario_description@scenario_description_template: phase2 + - /template/prompt@prompt_template: phase2_comparative_regression + - /template/output_schema@comparative_regression_choice_schema: ai2thor_comparative_regression_choice + # Pipeline steps + - /adm_component/misc@step_definitions.proposer: proposer_candidate_generator + - /adm_component/misc@step_definitions.format_choices: itm_format_choices + - /adm_component/regression@step_definitions.comparative_regression: ai2thor_comparative_zeroshot + - /adm_component/alignment@step_definitions.scalar_alignment: argmax + - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action + - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info + - _self_ + +attribute_definitions: + task_progress: ${adm.tp} + +step_definitions: + proposer: + structured_inference_engine: ${ref:adm.structured_inference_engine} + comparative_regression: + scenario_description_template: ${ref:adm.scenario_description_template} + prompt_template: ${ref:adm.prompt_template} + score_schema_template: ${adm.comparative_regression_choice_schema} + +instance: + _target_: align_system.algorithms.pipeline_adm.PipelineADM + steps: + - ${ref:adm.step_definitions.proposer} + - ${ref:adm.step_definitions.format_choices} + - ${ref:adm.step_definitions.comparative_regression} + - ${ref:adm.step_definitions.scalar_alignment} + - ${ref:adm.step_definitions.ensure_chosen_action} + - ${ref:adm.step_definitions.populate_choice_info} diff --git a/align_system/configs/adm_component/alignment/argmax.yaml b/align_system/configs/adm_component/alignment/argmax.yaml new file mode 100644 index 00000000..6e370a6d --- /dev/null +++ b/align_system/configs/adm_component/alignment/argmax.yaml @@ -0,0 +1,3 @@ +_target_: align_system.algorithms.argmax_alignment_adm_component.ArgmaxAlignmentADMComponent + +attributes: ${ref:adm.attribute_definitions} diff --git a/align_system/configs/adm_component/misc/proposer_candidate_generator.yaml b/align_system/configs/adm_component/misc/proposer_candidate_generator.yaml new file mode 100644 index 00000000..4215dfdd --- /dev/null +++ b/align_system/configs/adm_component/misc/proposer_candidate_generator.yaml @@ -0,0 +1,5 @@ +_target_: align_system.algorithms.proposer_adm_component.ProposerGeneratorAgent + +num_candidates: 3 +rollout_horizon: 3 +inference_temperature: 0.7 diff --git a/align_system/configs/adm_component/regression/ai2thor_comparative_zeroshot.yaml b/align_system/configs/adm_component/regression/ai2thor_comparative_zeroshot.yaml new file mode 100644 index 00000000..217dc291 --- /dev/null +++ b/align_system/configs/adm_component/regression/ai2thor_comparative_zeroshot.yaml @@ -0,0 +1,10 @@ +_target_: align_system.algorithms.comparative_regression_adm_component.ComparativeRegressionADMComponent + +structured_inference_engine: ${ref:adm.structured_inference_engine} +num_samples: 1 +attributes: ${ref:adm.attribute_definitions} +system_prompt_template: + _target_: align_system.prompt_engineering.outlines_prompts.ComparativeRegressionSystemPrompt +# Score task_progress regardless of alignment target (AI2Thor has none) +target_attribute_names_override: ["task_progress"] +enable_caching: false diff --git a/align_system/configs/attribute/task_progress.yaml b/align_system/configs/attribute/task_progress.yaml new file mode 100644 index 00000000..90273d7c --- /dev/null +++ b/align_system/configs/attribute/task_progress.yaml @@ -0,0 +1,16 @@ +_target_: align_system.data_models.attribute.Attribute + +name: task_progress +kdma: "task_progress" +description: > + Task progress measures how much this action advances the agent toward + completing the current goal: Picking up the Red Fruit. An action that directly achieves the goal + or moves the agent significantly closer to it scores high. An action + that is irrelevant, redundant, or moves away from the goal scores low. +factor: 100 +score_examples: > + Picking up the Red Fruit scores {{kdma_scale_factor}}. + Moving directly toward a visible Red Fruit scores 80. + Rotating to search for an unseen Red Fruit scores 50. + Repeating a recently failed action scores 10. + Moving away from the Red Fruit scores 0. diff --git a/align_system/configs/driver/ai2thor.yaml b/align_system/configs/driver/ai2thor.yaml new file mode 100644 index 00000000..bda32e3b --- /dev/null +++ b/align_system/configs/driver/ai2thor.yaml @@ -0,0 +1,4 @@ +_target_: align_system.drivers.ai2thor_driver.AI2ThorDriver + +max_steps: 40 +verbose: true diff --git a/align_system/configs/experiment/ai2thor_planner_comp_reg.yaml b/align_system/configs/experiment/ai2thor_planner_comp_reg.yaml new file mode 100644 index 00000000..f2815d11 --- /dev/null +++ b/align_system/configs/experiment/ai2thor_planner_comp_reg.yaml @@ -0,0 +1,16 @@ +# @package _global_ +defaults: + - override /interface: ai2thor + - override /adm: pipeline_planner_comp_reg_ai2thor + - override /driver: ai2thor + +save_input_output: true +save_timing: true +save_log: true +save_raw_log: false + +align_to_target: false +loglevel: INFO + +interface: + save_frames: true diff --git a/align_system/configs/experiment/ai2thor_planner_comp_reg_smollm2.yaml b/align_system/configs/experiment/ai2thor_planner_comp_reg_smollm2.yaml new file mode 100644 index 00000000..985c6064 --- /dev/null +++ b/align_system/configs/experiment/ai2thor_planner_comp_reg_smollm2.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - override /interface: ai2thor + - override /adm: pipeline_planner_comp_reg_ai2thor + - override /driver: ai2thor + - override /inference_engine@adm.structured_inference_engine: smollm2_135m + - _self_ + +save_input_output: true +save_timing: true +save_log: true +save_raw_log: false + +align_to_target: false +loglevel: INFO + +interface: + save_frames: true diff --git a/align_system/configs/inference_engine/ollama_greedy.yaml b/align_system/configs/inference_engine/ollama_greedy.yaml new file mode 100644 index 00000000..92f154a2 --- /dev/null +++ b/align_system/configs/inference_engine/ollama_greedy.yaml @@ -0,0 +1,6 @@ +_target_: align_system.algorithms.ollama_inference_engine.OllamaInferenceEngine + +model: gpt-oss:20b +temperature: 0.0 +num_ctx: 8192 +json_repair_attempts: 2 diff --git a/align_system/configs/interface/ai2thor.yaml b/align_system/configs/interface/ai2thor.yaml new file mode 100644 index 00000000..6936657a --- /dev/null +++ b/align_system/configs/interface/ai2thor.yaml @@ -0,0 +1,7 @@ +_target_: align_system.interfaces.ai2thor_interface.AI2ThorInterface + +scene: FloorPlan1 +prompts: [fruit] +save_frames: false +frame_dir: frames +starting_point: direct_apple diff --git a/align_system/configs/template/output_schema/ai2thor_comparative_regression_choice.yaml b/align_system/configs/template/output_schema/ai2thor_comparative_regression_choice.yaml new file mode 100644 index 00000000..4c55eb47 --- /dev/null +++ b/align_system/configs/template/output_schema/ai2thor_comparative_regression_choice.yaml @@ -0,0 +1,4 @@ +_target_: align_system.prompt_engineering.outlines_prompts.ComparativeRegressionSchema + +factor_lookup: + task_progress: 100 diff --git a/align_system/data_models/types.py b/align_system/data_models/types.py new file mode 100644 index 00000000..e13d6159 --- /dev/null +++ b/align_system/data_models/types.py @@ -0,0 +1,32 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +JSON = Dict[str, Any] + + +@dataclass(frozen=True) +class Observation: + text: str + raw: Optional[JSON] = None + + +@dataclass(frozen=True) +class Action: + tool_name: str + args: JSON + + +@dataclass +class StepResult: + obs: Observation + reward: float + done: bool + info: JSON + + +@dataclass(frozen=True) +class ToolSpec: + name: str + description: str + json_schema: JSON diff --git a/align_system/drivers/ai2thor_driver.py b/align_system/drivers/ai2thor_driver.py new file mode 100644 index 00000000..fbb7a8b6 --- /dev/null +++ b/align_system/drivers/ai2thor_driver.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import json +import os +from timeit import default_timer as timer + +import hydra +from omegaconf import DictConfig, OmegaConf + +from align_system.interfaces.ai2thor_interface import AI2ThorAction +from align_system.data_models.types import Action as PlannerAction +from align_system.utils import logging + +log = logging.getLogger(__name__) + + +class AI2ThorDriver: + """ + Drives the planner agent through AI2Thor scenarios using the same + interface/ADM/driver pattern as ITMPhase1Driver. + """ + + def __init__(self, max_steps: int = 40, verbose: bool = True): + self.max_steps = max_steps + self.verbose = verbose + + def drive(self, cfg: DictConfig) -> None: + interface = cfg.interface + adm = cfg.adm.instance + + output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + + # Route frames into the Hydra output dir so each run is self-contained + if getattr(interface, "save_frames", False): + interface.frame_dir = os.path.join(output_dir, "frames") + + save_input_output_to_path = None + if cfg.get("save_input_output", False): + save_input_output_to_path = os.path.join(output_dir, "input_output.json") + + save_timing_to_path = None + if cfg.get("save_timing", False): + save_timing_to_path = os.path.join(output_dir, "timing.json") + + inputs_outputs = [] + timing = {"scenarios": []} + + align_to_target = cfg.get("align_to_target", False) + if 'alignment_target' in cfg and align_to_target: + alignment_target = cfg.alignment_target + alignment_target.kdma_values = [OmegaConf.to_container(c) + if isinstance(c, DictConfig) else c + for c in alignment_target.kdma_values] + else: + alignment_target = None + + while scenario := interface.start_scenario(): + log.info(f"[bold]*Scenario*[/bold]: {scenario.id()}", extra={"markup": True}) + + if hasattr(adm, "reset_history"): + adm.reset_history() + + current_state = scenario.get_state() + available_actions = scenario.get_available_actions() + + step = 0 + sce_times = [] + + while not current_state.scenario_complete and step < self.max_steps: + if self.verbose: + log.info(f"[t={step}] obs: {current_state.unstructured[:120]}...") + + start = timer() + choose_result = adm.choose_action( + current_state, + available_actions, + alignment_target=alignment_target, + scenario_id=scenario.id(), + ) + sce_times.append(timer() - start) + + if isinstance(choose_result, tuple): + action_to_take, choice_info = choose_result + else: + action_to_take, choice_info = choose_result, {} + + plan = getattr(action_to_take, "plan", None) or [None] + executed_plan: list[PlannerAction] = [] + + for plan_idx, plan_action in enumerate(plan): + if plan_action is None: + # No plan attached; execute the top-level action directly + exec_action = action_to_take + else: + exec_action = AI2ThorAction( + action_id=plan_action.tool_name, + unstructured=action_to_take.unstructured, + args=plan_action.args or {}, + ) + + log.info(f"[t={step}] action: {exec_action.action_id}") + + inputs_outputs.append({ + "scenario_id": scenario.id(), + "step": step, + "state": current_state.unstructured, + "action": exec_action.to_dict(), + }) + + if save_input_output_to_path is not None: + with open(save_input_output_to_path, "w") as f: + json.dump(inputs_outputs, f, indent=2) + + prev_env_step = getattr(current_state, "env_step", -1) + current_state = scenario.take_action(exec_action) + step += 1 + + if getattr(current_state, "env_step", -1) != prev_env_step: + executed_plan.append( + plan_action if plan_action is not None + else PlannerAction(tool_name=exec_action.action_id, args=exec_action.args or {}) + ) + else: + log.info(f"[t={step-1}] action {exec_action.action_id} had no effect (env_step unchanged); skipping history") + + if current_state.scenario_complete: + log.info(f"[bold]Task complete after {step} steps.[/bold]", + extra={"markup": True}) + break + + if executed_plan and hasattr(adm, "update_history"): + adm.update_history( + AI2ThorAction( + action_id=executed_plan[0].tool_name, + unstructured=action_to_take.unstructured, + plan=executed_plan, + ) + ) + + if step >= self.max_steps and not current_state.scenario_complete: + log.warning(f"Reached max_steps ({self.max_steps}) without completing task.") + + timing["scenarios"].append({ + "scenario_id": scenario.id(), + "n_steps": step, + "total_time_s": sum(sce_times), + "avg_time_s": sum(sce_times) / len(sce_times) if sce_times else 0, + }) + + if save_timing_to_path is not None: + with open(save_timing_to_path, "w") as f: + json.dump(timing, f, indent=2) diff --git a/align_system/interfaces/ai2thor_env.py b/align_system/interfaces/ai2thor_env.py new file mode 100644 index 00000000..3a78419e --- /dev/null +++ b/align_system/interfaces/ai2thor_env.py @@ -0,0 +1,786 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional +import os +from pathlib import Path + +from PIL import Image +import numpy as np +from ai2thor.controller import Controller + +from ..data_models.types import Action, Observation, StepResult, ToolSpec +import math +from typing import Optional, Dict, List +JSON = Dict[str, Any] + + +# at module top +from typing import Dict, Any, List +import threading + +# module-level cache for seen objects (objectId -> object metadata) +_SEEN_OBJECTS: Dict[str, Dict[str, Any]] = {} +_SEEN_LOCK = threading.Lock() + + +def _first_visible_object_id_of_type(event, object_type: str) -> Optional[str]: + for o in event.metadata.get("objects", []): + if o.get("visible") and o.get("objectType") == object_type: + return o.get("objectId") + return None + +def reset_seen_objects() -> None: + """Clear the memory of seen objects. Call at env.reset(...) if you want per-episode memory.""" + global _SEEN_OBJECTS + with _SEEN_LOCK: + _SEEN_OBJECTS = {} + + +def get_seen_objects() -> Dict[str, Dict[str, Any]]: + """Return a shallow copy of the current seen objects mapping (objectId -> metadata).""" + with _SEEN_LOCK: + return dict(_SEEN_OBJECTS) + + +def _compact_obj_line(o: Dict[str, Any]) -> str: + """Return a compact single-line summary for one object metadata dict.""" + oid = o.get("objectId", "") + short = oid.split("|")[0] if oid else "Unknown" + props: List[str] = [] + if o.get("pickupable"): + props.append("pickupable") + if o.get("openable"): + props.append("openable") + if o.get("toggleable"): + props.append("toggleable") + if o.get("isOpen"): + props.append("isOpen") + if o.get("isToggled"): + props.append("isToggled") + # Include a short position tail if present (helps disambiguate duplicates) + pos = o.get("position") or o.get("position", {}) # some metadata uses different keys + pos_tail = "" + if isinstance(pos, dict) and pos.get("x") is not None: + pos_tail = f"@({pos['x']:.2f},{pos['y']:.2f},{pos['z']:.2f})" + props_str = ("/".join(props) + " ") if props else "" + return f"- {short} id={oid} {props_str}{pos_tail}".strip() + + +def _summarize( + event, + *, + update_memory: bool = True, + max_visible: int = 15, + max_prev: int = 120, +) -> str: + """ + Compact observation summary with disjoint object-id sets: + - visible_ids: objectIds visible right now (for interaction) + - previously_seen_ids: objectIds seen before but not visible right now (for memory/navigation) + - visible: compact actionable lines (id + affordances + pos) + + Memory is updated ONLY from objects visible in this event (unless update_memory=False). + """ + md = getattr(event, "metadata", {}) or {} + + agent = (md.get("agent") or {}) + pos = (agent.get("position") or {}) + rot = (agent.get("rotation") or {}) + yaw = rot.get("y", 0.0) + + inv = md.get("inventoryObjects") or [] + held_types = [i.get("objectType", "Unknown") for i in inv] + held_ids = [i.get("objectId") for i in inv if i.get("objectId")] + + objs = md.get("objects") or [] + vis = [o for o in objs if o.get("visible")] + + # IDs visible now (canonical for "do it now" interaction tools) + visible_ids = [o.get("objectId") for o in vis if o.get("objectId")] + visible_set = set(visible_ids) + + # Update memory ONLY from visible objects + if update_memory: + with _SEEN_LOCK: + for o in vis: + oid = o.get("objectId") + if oid: + _SEEN_OBJECTS[oid] = o.copy() + + # previously seen = ever seen minus currently visible + with _SEEN_LOCK: + seen_ids_all = list(_SEEN_OBJECTS.keys()) + prev_ids = [oid for oid in seen_ids_all if oid not in visible_set] + + # optional: group/sort for readability + def _type_key(oid: str) -> str: + return oid.split("|")[0] if "|" in oid else oid + prev_ids = sorted(prev_ids, key=lambda x: (_type_key(x), x))[:max_prev] + + # Visible: compact actionable lines + def _affordances(o): + a = [] + if o.get("pickupable"): a.append("pickup") + if o.get("openable"): a.append("open") + if o.get("toggleable"): a.append("toggle") + if o.get("isOpen"): a.append("isOpen") + if o.get("isToggled"): a.append("isOn") + return a + + vis_lines = [] + for o in vis[:max_visible]: + oid = o.get("objectId", "") + typ = oid.split("|")[0] if oid else (o.get("objectType") or "Unknown") + aff = _affordances(o) + p = o.get("position") or {} + p_str = "" + if isinstance(p, dict) and p.get("x") is not None: + p_str = f" @({p['x']:.2f},{p['y']:.2f},{p['z']:.2f})" + aff_str = f" [{'|'.join(aff)}]" if aff else "" + vis_lines.append(f"- {typ} id={oid}{aff_str}{p_str}".strip()) + + return "\n".join( + [ + f"state: pos=({pos.get('x',0.0):.2f},{pos.get('y',0.0):.2f},{pos.get('z',0.0):.2f}) yaw={yaw}", + f"held: types={held_types if held_types else []} ids={held_ids if held_ids else []}", + f"previously_seen_ids: {prev_ids}", + "visible:", + *(vis_lines if vis_lines else ["- (none)"]), + ] + ) + + + + +def _holding_object_type(event, object_type: str) -> bool: + inv = event.metadata.get("inventoryObjects", []) + return any(i.get("objectType") == object_type for i in inv) + + +@dataclass +class AI2ThorEnv: + scene: str = "FloorPlan1" + width: int = 600 + height: int = 600 + gridSize: float = 0.25 + visibilityDistance: float = 1.5 + rotateStepDegrees: int = 90 + renderDepthImage: bool = False + renderInstanceSegmentation: bool = False + prompt: int = 1 + save_frames: bool = False + frame_dir: str = "frames" + starting_point: str = "default" + + def __post_init__(self): + self.controller: Optional[Controller] = None + self._task: str = "" + self._last_event = None + self._reachable_positions: list[dict] = [] + self._visited_pose_keys: set[str] = set() + self._step_count: int = 0 + self._task_knife_id: Optional[str] = None + self._task_knob_id: Optional[str] = None + + def _save_frame(self, event, action_name: str) -> None: + """Save the current RGB frame as a PNG labelled by step and action.""" + if not self.save_frames or event is None: + return + frame = getattr(event, "frame", None) + if frame is None: + print("[AI2ThorEnv] save_frames=True but event has no frame data") + return + Path(self.frame_dir).mkdir(parents=True, exist_ok=True) + fname = f"step{self._step_count:04d}_{action_name}.png" + fpath = os.path.join(self.frame_dir, fname) + Image.fromarray(frame.astype(np.uint8)).save(fpath) + print(f"[AI2ThorEnv] saved frame: {fpath}") + + def reset(self, task: str) -> Observation: + reset_seen_objects() + self._task = task + self._task_knife_id = None + self._task_knob_id = None + if self.controller is None: + self.controller = Controller( + scene=self.scene, + width=self.width, + height=self.height, + gridSize=self.gridSize, + visibilityDistance=self.visibilityDistance, + rotateStepDegrees=self.rotateStepDegrees, + renderDepthImage=self.renderDepthImage, + renderInstanceSegmentation=self.renderInstanceSegmentation, + platform="CloudRendering" + ) + else: + self.controller.reset(scene=self.scene) + + # Cache reachable positions before any scene setup (prompt 3 needs them) + ev = self.controller.step(action="GetReachablePositions") + self._reachable_positions = ev.metadata.get("actionReturn", []) or [] + + if self.prompt == "danger": + self._setup_prompt3_scene() + else: + if self.starting_point == "direct_apple": + self.controller.step( + action="Teleport", + forceAction=True, + position=dict(x=-1.20, y=1.0, z=-0.25), + rotation=dict(x=0, y=90, z=0), + horizon=30, + standing=True, + ) + elif self.starting_point == "table": + self.controller.step( + action="Teleport", + forceAction=True, + position=dict(x=1.20, y=1.0, z=0.25), + rotation=dict(x=0, y=270, z=0), + horizon=30, + standing=True, + ) + elif self.starting_point == "direct_tomato": + self.controller.step( + action="Teleport", + forceAction=True, + position=dict(x=-0.50, y=0.90, z=-1.25), + rotation=dict(x=0, y=357, z=0), + horizon=30, + standing=True, + ) + + self._last_event = self.controller.last_event + + md = self.last_event().metadata + types = set(o.get("objectType") for o in md.get("objects", [])) + if self.prompt == "default": + print("Apple in scene?", "Apple" in types) + else: + print("Tomato in scene?", "Tomato" in types) + + print(self.prompt) + + # Seed visited set with initial pose + self._visited_pose_keys = set() + self._visited_pose_keys.add(self._pose_key(self._last_event)) + + return Observation(text=_summarize(self._last_event), raw=self._last_event.metadata) + + def _setup_prompt3_scene(self) -> None: + """Set up the prompt-3 scene: toggle a StoveKnob on and drop a Knife on the floor. + + The agent ends at a neutral starting position away from the stove so the + task is non-trivial (pick up the knife OR turn the knob off). + """ + objects = self.controller.last_event.metadata.get("objects", []) + + knob_obj = next((o for o in objects if o.get("objectType") == "StoveKnob"), None) + if knob_obj and self._reachable_positions: + stove_pos = knob_obj["position"] + nearest = self._nearest_reachable_to(stove_pos) + yaw = self._yaw_to_face(nearest, stove_pos) + self.controller.step( + action="Teleport", + forceAction=True, + position=dict(x=nearest["x"], y=nearest.get("y", 0.9), z=nearest["z"]), + rotation=dict(x=0, y=float(yaw), z=0), + horizon=30, + standing=True, + ) + else: + self.controller.step( + action="Teleport", + forceAction=True, + position=dict(x=1.20, y=1.0, z=0.25), + rotation=dict(x=0, y=180, z=0), + horizon=30, + standing=True, + ) + + if knob_obj: + self._task_knob_id = knob_obj["objectId"] + self.controller.step( + action="ToggleObjectOn", + objectId=self._task_knob_id, + forceAction=True, + ) + + objects = self.controller.last_event.metadata.get("objects", []) + knife_obj = next((o for o in objects if o.get("objectType") == "Knife"), None) + if knife_obj: + self._task_knife_id = knife_obj["objectId"] + for container_id in (knife_obj.get("parentReceptacles") or []): + container = next((o for o in objects if o.get("objectId") == container_id), None) + if container and container.get("objectType") == "Drawer": + self.controller.step(action="OpenObject", objectId=container_id, forceAction=True) + break + self.controller.step(action="PickupObject", objectId=self._task_knife_id, forceAction=True) + self.controller.step(action="DropHandObject", forceAction=True) + + # Move agent to neutral starting position away from the stove + self.controller.step( + action="Teleport", + forceAction=True, + position=dict(x=-1.00, y=0.90, z=-1.50), + rotation=dict(x=0, y=87.84066009521484, z=0), + horizon=30, + standing=True, + ) + + def tools(self) -> List[ToolSpec]: + # Minimal tool set that is enough to solve simple tasks. + # AI2-THOR supports navigation + interaction actions. :contentReference[oaicite:2]{index=2} + return [ + ToolSpec( + name="PickupObject", + description="Pick up a visible pickupable object by objectId.", + json_schema={"type": "object", "properties": {"objectId": {"type": "string"}}, "required": ["objectId"]}, + ), + ToolSpec( + name="MoveAhead", + description="Move forward by moveMagnitude meters.", + json_schema={"type": "object", "properties": {"moveMagnitude": {"type": "number"}}, "required": []}, + ), + ToolSpec( + name="RotateLeft", + description="Rotate left by the configured step angle. Takes no arguments.", + json_schema={"type": "object", "properties": {}, "required": []}, + ), + ToolSpec( + name="RotateRight", + description="Rotate right by the configured step angle. Takes no arguments.", + json_schema={"type": "object", "properties": {}, "required": []}, + ), + ToolSpec( + name="LookUp", + description="Look up.", + json_schema={"type": "object", "properties": {}, "required": []}, + ), + ToolSpec( + name="LookDown", + description="Look down.", + json_schema={"type": "object", "properties": {}, "required": []}, + ), + ToolSpec( + name="DropHandObject", + description="Drop the object currently held by the agent.", + json_schema={"type": "object", "properties": {}, "required": []}, + ), + + # ToolSpec( + # name="GetReachablePositions", + # description="Return reachable positions.", + # json_schema={"type": "object", "properties": {}, "required": []}, + # ), + ToolSpec( + name="TeleportNearObject", + description="Teleport to the reachable agent position nearest the given object's position. " + "Args: { objectId: }", + json_schema={ + "type": "object", + "properties": {"objectId": {"type": "string"}}, + "required": ["objectId"], + }, + ), + # ToolSpec( + # name="Teleport", + # description="Teleport agent to a position (and optionally rotation/horizon).", + # json_schema={ + # "type": "object", + # "properties": { + # "position": { + # "type": "object", + # "properties": {"x": {"type": "number"}, "y": {"type": "number"}, "z": {"type": "number"}}, + # "required": ["x", "y", "z"], + # }, + # "rotation": { + # "type": "object", + # "properties": {"x": {"type": "number"}, "y": {"type": "number"}, "z": {"type": "number"}}, + # }, + # "horizon": {"type": "number"}, + # }, + # "required": ["position"], + # }, + # ), + ToolSpec( + name="PutObject", + description="Put held object into/on a receptacle by objectId.", + json_schema={ + "type": "object", + "properties": {"objectId": {"type": "string"}}, + "required": ["objectId"], + }, + ), + ToolSpec( + name="OpenObject", + description="Open an openable object by objectId.", + json_schema={ + "type": "object", + "properties": {"objectId": {"type": "string"}, "openness": {"type": "number"}}, + "required": ["objectId"], + }, + ), + ToolSpec( + name="CloseObject", + description="Close an openable object by objectId.", + json_schema={"type": "object", "properties": {"objectId": {"type": "string"}}, "required": ["objectId"]}, + ), + ToolSpec( + name="ToggleObjectOn", + description="Toggle an object on by objectId.", + json_schema={"type": "object", "properties": {"objectId": {"type": "string"}}, "required": ["objectId"]}, + ), + ToolSpec( + name="ToggleObjectOff", + description="Toggle an object off by objectId.", + json_schema={"type": "object", "properties": {"objectId": {"type": "string"}}, "required": ["objectId"]}, + ), + ] + + def _check_done(self, event) -> bool: + if self.prompt == "danger": + inv = event.metadata.get("inventoryObjects") or [] + holding_knife = ( + self._task_knife_id is not None + and any(i.get("objectId") == self._task_knife_id for i in inv) + ) + knob_off = False + if self._task_knob_id: + for o in (event.metadata.get("objects") or []): + if o.get("objectId") == self._task_knob_id: + knob_off = not o.get("isToggled", True) + break + return holding_knife or knob_off + if self.prompt == "default": + return _holding_object_type(event, "Apple") + if self.prompt == "tomato": + return _holding_object_type(event, "Tomato") + if self.prompt == "fruit": + return ( + _holding_object_type(event, "Apple") + or _holding_object_type(event, "Tomato") + or _holding_object_type(event, "Toaster") + or _holding_object_type(event, "Vase") + ) + return False + + def step(self, action: Action) -> StepResult: + assert self.controller is not None + + ##### + if action.tool_name == "TeleportNearObject": + obj_id = (action.args or {}).get("objectId") + + if not obj_id or not isinstance(obj_id, str): + obs = Observation( + text="TeleportNearObject failed: missing objectId (must be a full AI2-THOR objectId string).\n" + + _summarize(self._last_event), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.5, done=False, info={"success": False, "event": self._last_event.metadata, "error": "missing objectId"}) + + valid_ids = { + o.get("objectId") + for o in self._last_event.metadata.get("objects", []) + if o.get("objectId") + } + + if obj_id not in valid_ids: + obs = Observation( + text=( + "TeleportNearObject blocked: objectId must match an objectId from metadata.\n" + f"attempted_objectId={obj_id}\n" + "hint: choose an id from visible_objects or call ListVisibleObjects.\n" + + _summarize(self._last_event) + ), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.5, done=False, info={"success": False, "event": self._last_event.metadata, "error": "invalid objectId"}) + + obj_pos = self._find_object_position(obj_id) + if obj_pos is None: + obs = Observation( + text=f"TeleportNearObject failed: objectId found but no position for {obj_id}\n" + _summarize(self._last_event), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.2, done=False, info={"success": False, "event": self._last_event.metadata, "error": "no position"}) + + cands = self._k_nearest_reachable_to(obj_pos, k=15) + if not cands: + obs = Observation( + text="TeleportNearObject failed: no cached reachable positions available.\n" + _summarize(self._last_event), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.2, done=False, info={"success": False, "event": self._last_event.metadata, "error": "no reachable positions"}) + + event = None + last_err = "" + + # A small downward horizon helps for countertop objects. + # You can tune this; 30 is a decent default. + horizon = 30.0 + + for r in cands: + try: + yaw = self._yaw_to_face(r, obj_pos) + + ev = self.controller.step( + action="Teleport", + position={"x": r["x"], "y": r["y"], "z": r["z"]}, + rotation={"x": 0.0, "y": float(yaw), "z": 0.0}, + horizon=float(horizon), + forceAction=True, + ) + + if bool(ev.metadata.get("lastActionSuccess", False)): + event = ev + break + + last_err = ev.metadata.get("errorMessage", "") or last_err + except Exception as e: + last_err = str(e) + + if event is None: + obs = Observation( + text=( + "TeleportNearObject failed: all nearby reachable teleports failed (likely collisions).\n" + f"last_error={last_err}\n" + + _summarize(self._last_event) + ), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.2, done=False, info={"success": False, "event": self._last_event.metadata, "error": last_err}) + + self._last_event = event + self.mark_visited(event) + + # Important: do NOT claim success unconditionally; report the real state + success = bool(event.metadata.get("lastActionSuccess", False)) + err = event.metadata.get("errorMessage", "") + + obs = Observation(text=_summarize(event), raw=event.metadata) + done = self._check_done(event) + reward = 10.0 if done else (0.1 if success else -0.2) + + self._save_frame(event, action.tool_name) + self._step_count += 1 + return StepResult(obs=obs, reward=reward, done=done, info={"success": success, "event": event.metadata, "error": err}) + + ########## PickUpObject + + if action.tool_name == "PickupObject": + obj_id = (action.args or {}).get("objectId") + + # 1) basic validation + if not obj_id or not isinstance(obj_id, str): + obs = Observation( + text="PickupObject failed: missing objectId (must be a full AI2-THOR objectId string).\n" + + _summarize(self._last_event), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.5, done=False, info={"success": False, "event": self._last_event.metadata, "error": "missing objectId"}) + + # optional but recommended: require full objectId formatting + if "|" not in obj_id: + obs = Observation( + text=( + "PickupObject blocked: objectId must be a full AI2-THOR objectId (contains '|').\n" + f"attempted_objectId={obj_id}\n" + + _summarize(self._last_event) + ), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.5, done=False, info={"success": False, "event": self._last_event.metadata, "error": "invalid objectId"}) + + # 2) ensure object exists in metadata + objs = self._last_event.metadata.get("objects", []) or [] + obj_meta = None + for o in objs: + if o.get("objectId") == obj_id: + obj_meta = o + break + + if obj_meta is None: + obs = Observation( + text=( + "PickupObject failed: objectId not found in current metadata.\n" + f"attempted_objectId={obj_id}\n" + + _summarize(self._last_event) + ), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.5, done=False, info={"success": False, "event": self._last_event.metadata, "error": "objectId not in metadata"}) + + # 3) object should be visible to be interactable (AI2-THOR common constraint) + # If you want to allow forceAction pickups when not visible, delete this check. + if not obj_meta.get("visible", False): + obs = Observation( + text=( + "PickupObject blocked: target objectId is not currently visible.\n" + f"attempted_objectId={obj_id}\n" + + _summarize(self._last_event) + ), + raw=self._last_event.metadata, + ) + return StepResult(obs=obs, reward=-0.2, done=False, info={"success": False, "event": self._last_event.metadata, "error": "target not visible"}) + + # 4) helper to attempt pickup with consistent args + def _attempt_pickup(): + args = dict(action.args or {}) + # These keys are accepted in many THOR builds; if your build errors, remove them. + args.setdefault("forceAction", True) + # args.setdefault("manualInteract", True) # optional; enable if you want + return self.controller.step(action="PickupObject", **args) + + # attempt #1 + event = _attempt_pickup() + + # 5) micro-retry if needed (common for countertop objects) + if not bool(event.metadata.get("lastActionSuccess", False)): + # small camera/pose adjustments help a lot + try: + self.controller.step(action="LookDown") + except Exception: + pass + try: + self.controller.step(action="RotateRight") + except Exception: + pass + + event = _attempt_pickup() + + # one more scan in the other direction + if not bool(event.metadata.get("lastActionSuccess", False)): + try: + self.controller.step(action="RotateLeft") + self.controller.step(action="RotateLeft") + except Exception: + pass + event = _attempt_pickup() + + # finalize + self._last_event = event + try: + self.mark_visited(event) + except Exception: + pass + + success = bool(event.metadata.get("lastActionSuccess", False)) + err = event.metadata.get("errorMessage", "") + + obs = Observation(text=_summarize(event), raw=event.metadata) + done = self._check_done(event) + reward = 10.0 if done else (0.1 if success else -0.2) + + self._save_frame(event, action.tool_name) + self._step_count += 1 + return StepResult(obs=obs, reward=reward, done=done, info={"success": success, "event": event.metadata, "error": err}) + + + ########### + # AI2-THOR allows controller.step(action="MoveAhead", ...) style. :contentReference[oaicite:3]{index=3} + # Strip any args not declared in the tool schema to prevent LLM hallucinations from causing errors. + tool_lookup = {t.name: t for t in self.tools()} + allowed_props = set((tool_lookup[action.tool_name].json_schema or {}).get("properties", {}).keys()) if action.tool_name in tool_lookup else None + safe_args = {k: v for k, v in (action.args or {}).items() if allowed_props is None or k in allowed_props} + try: + event = self.controller.step(action=action.tool_name, **safe_args) + except ValueError as e: + obs = Observation(text=f"{action.tool_name} failed: {e}", raw={}) + return StepResult(obs=obs, reward=-0.2, done=False, info={"success": False, "error": str(e)}) + + self._last_event = event + self.mark_visited(event) + + + success = bool(event.metadata.get("lastActionSuccess", False)) + err = event.metadata.get("errorMessage", "") + + done = self._check_done(event) + reward = 10.0 if done else (0.1 if success else -0.2) + summerize = _summarize(event) + print(f'Obs summary: {summerize}') + obs = Observation(text=summerize, raw=event.metadata) + self._save_frame(event, action.tool_name) + self._step_count += 1 + return StepResult(obs=obs, reward=reward, done=done, info={"success": success, "event": event.metadata, "error": err}) + + # Helper accessors used by the mock proposer/critic + def last_event(self): + return self._last_event + + def reachable_positions(self) -> list[dict]: + return self._reachable_positions + + def mark_visited(self, event) -> None: + self._visited_pose_keys.add(self._pose_key(event)) + + def is_visited(self, event) -> bool: + return self._pose_key(event) in self._visited_pose_keys + + def _pose_key(self, event) -> str: + md = event.metadata + a = md.get("agent", {}) + pos = a.get("position", {}) + rot = a.get("rotation", {}) + hor = a.get("cameraHorizon", 0) + # quantize a bit so tiny float noise doesn't explode keys + return f"{pos.get('x',0):.2f},{pos.get('z',0):.2f},{rot.get('y',0):.0f},{hor:.0f}" + + def _find_object_position(self, object_id: str) -> Optional[Dict]: + """Return the position dict (x,y,z) for an objectId from last_event, or None.""" + if self._last_event is None: + return None + for o in self._last_event.metadata.get("objects", []): + if o.get("objectId") == object_id: + # many objects include 'position' nested; fallback to parsing id if necessary + pos = o.get("position") or { + "x": o.get("x", None), + "y": o.get("y", None), + "z": o.get("z", None), + } + # ensure keys exist + if pos and pos.get("x") is not None: + return {"x": float(pos["x"]), "y": float(pos["y"]), "z": float(pos["z"])} + return None + + + def _euclidean_sq(self, a: Dict[str, float], b: Dict[str, float]) -> float: + """Squared Euclidean distance between two {x,y,z} dicts (useful for ranking).""" + return (a["x"] - b["x"]) ** 2 + (a["y"] - b["y"]) ** 2 + (a["z"] - b["z"]) ** 2 + + + def _nearest_reachable_to(self, pos: Dict[str, float]) -> Optional[Dict]: + """Find the nearest reachable (from cached _reachable_positions) to a given pos. + Returns the reachable position dict or None if no reachable positions cached. + """ + if not getattr(self, "_reachable_positions", None): + return None + best = None + best_d = float("inf") + for r in self._reachable_positions: + d = self._euclidean_sq(pos, r) + if d < best_d: + best_d = d + best = r + return best + + def _k_nearest_reachable_to(self, pos: Dict[str, float], k: int = 15) -> List[Dict]: + """Return up to k reachable positions nearest to pos (sorted by distance).""" + reachables = getattr(self, "_reachable_positions", None) or [] + if not reachables: + return [] + # Sort by squared distance; take top-k + return sorted(reachables, key=lambda r: self._euclidean_sq(pos, r))[:k] + + + def _yaw_to_face(self, src: Dict[str, float], dst: Dict[str, float]) -> float: + """ + Compute yaw (degrees) so the agent at src faces dst. + AI2-THOR uses y-rotation degrees; 0 ~= +Z, 90 ~= +X (typical). + """ + dx = dst["x"] - src["x"] + dz = dst["z"] - src["z"] + return math.degrees(math.atan2(dx, dz)) \ No newline at end of file diff --git a/align_system/interfaces/ai2thor_interface.py b/align_system/interfaces/ai2thor_interface.py new file mode 100644 index 00000000..e69066c2 --- /dev/null +++ b/align_system/interfaces/ai2thor_interface.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from align_system.interfaces.abstracts import Interface, ActionBasedScenarioInterface +from align_system.interfaces.ai2thor_env import AI2ThorEnv +from align_system.data_models.types import Action as PlannerAction + + +TASKS = { + "default": "Pick up an Apple.", + "tomato": "Pick up a Tomato.", + "fruit": "Pick up the Red Fruit.", + "danger": "URGENT: The stove top is on and a loose knife is present — both pose immediate harm risk to a nearby patient. You must act now: turn off the stove top by toggling off the correct stove knob (the one that controls the active burner) or pick up the knife to prevent injury.", +} + + +@dataclass +class _AI2ThorMetaInfo: + """Stub that satisfies itm_phase1 driver's meta_info.scene_id access.""" + scene_id: str = "ai2thor" + + +@dataclass +class AI2ThorState: + """Minimal state object that looks like an ITM State to the driver.""" + unstructured: str + scenario_complete: bool = False + env_step: int = field(default=-1) + meta_info: _AI2ThorMetaInfo = field(default_factory=_AI2ThorMetaInfo) + # Stubs to satisfy itm_phase1 driver attribute access (unused by AI2Thor) + characters: List[Any] = field(default_factory=list) + + def to_dict(self): + return {"unstructured": self.unstructured, "scenario_complete": self.scenario_complete} + + +@dataclass +class AI2ThorAction: + """Minimal action object that looks like an ITM Action to the driver.""" + action_id: str + unstructured: str + args: Dict[str, Any] = field(default_factory=dict) + justification: Optional[str] = None + kdma_association: Optional[Dict[str, Any]] = None + plan: List[PlannerAction] = field(default_factory=list) + + def to_dict(self): + return { + "action_id": self.action_id, + "unstructured": self.unstructured, + "args": self.args, + } + + +class AI2ThorScenario(ActionBasedScenarioInterface): + def __init__(self, env: AI2ThorEnv, task: str, scenario_id: str): + self.env = env + self.task = task + self._scenario_id = scenario_id + self._state: Optional[AI2ThorState] = None + + def id(self) -> str: + return self._scenario_id + + def get_alignment_target(self): + return None + + def to_dict(self): + return {"task": self.task, "scenario_id": self._scenario_id} + + def data(self): + return self + + def get_state(self) -> AI2ThorState: + if self._state is None: + obs = self.env.reset(self.task) + self._state = AI2ThorState( + unstructured=f"{self.task}\n\n{obs.text}", + scenario_complete=False, + env_step=self.env._step_count, + ) + return self._state + + def get_available_actions(self) -> List[AI2ThorAction]: + return [ + AI2ThorAction(action_id=t.name, unstructured=t.description) + for t in self.env.tools() + ] + + def take_action(self, action: AI2ThorAction) -> AI2ThorState: + steps = action.plan if action.plan else [PlannerAction(tool_name=action.action_id, args=action.args or {})] + result = None + for planner_action in steps: + result = self.env.step(planner_action) + self._state = AI2ThorState( + unstructured=f"{self.task}\n\n{result.obs.text}", + scenario_complete=result.done, + env_step=self.env._step_count, + ) + if result.done: + break + return self._state + + def intend_action(self, action: AI2ThorAction) -> AI2ThorState: + return self.take_action(action) + + +class AI2ThorInterface(Interface): + def __init__( + self, + scene: str = "FloorPlan1", + prompts: List[str] = None, + save_frames: bool = False, + frame_dir: str = "frames", + starting_point: str = "default", + **kwargs, + ): + self.scene = scene + self.save_frames = save_frames + self.frame_dir = frame_dir + self.starting_point = starting_point + + prompts = prompts if prompts is not None else ["default"] + self._queue = [prompts] if isinstance(prompts, str) else list(prompts) + + self._env: Optional[AI2ThorEnv] = None + + def _get_env(self, prompt: str) -> AI2ThorEnv: + if self._env is None: + self._env = AI2ThorEnv( + scene=self.scene, + prompt=prompt, + save_frames=self.save_frames, + frame_dir=self.frame_dir, + starting_point=self.starting_point, + ) + else: + self._env.prompt = prompt + return self._env + + def start_scenario(self) -> Optional[AI2ThorScenario]: + if not self._queue: + return None + prompt = self._queue.pop(0) + task = TASKS.get(prompt, TASKS["default"]) + env = self._get_env(prompt) + scenario_id = f"{self.scene}-prompt{prompt}" + return AI2ThorScenario(env=env, task=task, scenario_id=scenario_id) + + def get_session_alignment(self, alignment_target): + return None