diff --git a/aif_gen/cli/commands/generate.py b/aif_gen/cli/commands/generate.py index a516cab8..23423c50 100644 --- a/aif_gen/cli/commands/generate.py +++ b/aif_gen/cli/commands/generate.py @@ -79,6 +79,25 @@ default=0.99, help='Temperature for sampling from the model.', ) +@click.option( + '--n-candidates', + type=click.IntRange(min=2, max=64, clamp=True), + default=None, + help=( + 'Enable Sample-N → Score → Select pipeline (RLCD + HelpSteer2 + West-of-N) ' + 'with this many candidate responses per prompt. If unset (and no `generation` ' + 'block is present in the config), uses the legacy joint-generation pipeline.' + ), +) +@click.option( + '--target-margin', + type=click.FloatRange(min=0.0, max=1.0, clamp=True), + default=None, + help=( + 'Difficulty target for the sampled pipeline (0 = hardest, smallest score gap; ' + '1 = easiest, largest score gap). Defaults to 0.6 when --n-candidates is set.' + ), +) def generate( data_config_name: pathlib.Path, model: str, @@ -91,6 +110,8 @@ def generate( hf_repo_id: Optional[str], include_preference_axes: bool, temperature: float, + n_candidates: Optional[int], + target_margin: Optional[float], ) -> None: r"""Generate a new ContinualAlignmentDataset. @@ -105,6 +126,21 @@ def generate( data_config = yaml.safe_load(data_config_name.read_text()) logging.debug(f'Configuration: {data_config}') + # Resolve generation pipeline config: CLI flags override the YAML `generation:` block. 
+ yaml_generation = ( + data_config.get('generation') if isinstance(data_config, dict) else None + ) + generation_config: Optional[dict] = None + if n_candidates is not None or yaml_generation is not None: + generation_config = dict(yaml_generation) if yaml_generation else {} + if n_candidates is not None: + generation_config['n_candidates'] = n_candidates + if target_margin is not None: + generation_config['target_margin'] = target_margin + generation_config.setdefault('n_candidates', 6) + generation_config.setdefault('target_margin', 0.6) + logging.info(f'Sampled-pipeline generation_config: {generation_config}') + output_file.parent.mkdir(parents=True, exist_ok=True) if not dry_run: @@ -131,6 +167,7 @@ def generate( dry_run, include_preference_axes=include_preference_axes, temperature=temperature, + generation_config=generation_config, ) dataset = asyncio.get_event_loop().run_until_complete(future) if dataset is not None: diff --git a/aif_gen/generate/engine.py b/aif_gen/generate/engine.py index beab95c2..9df6cedc 100644 --- a/aif_gen/generate/engine.py +++ b/aif_gen/generate/engine.py @@ -18,6 +18,7 @@ ) from aif_gen.generate.caching import AsyncElasticsearchCache from aif_gen.generate.mappers import PromptMapper, ResponseMapper +from aif_gen.generate.mappers.pair_selector import ScoredCandidate, select_pair from aif_gen.task.alignment_task import AlignmentTask @@ -31,6 +32,7 @@ async def generate_continual_dataset( dry_run: bool = False, include_preference_axes: bool = False, temperature: float = 1.0, + generation_config: Optional[Dict[str, Any]] = None, ) -> Optional[ContinualAlignmentDataset]: r"""Generate a ContinualAlignmentDataset dataset given the AlignmentTask, and model. @@ -44,6 +46,13 @@ async def generate_continual_dataset( dry_run (bool): If True, ignore the config and generate a dummy sample to ensure the model is setup correctly. include_preference_axes (bool): If True, include the preference axes in the prompt for response mapper. 
temperature (float): Temperature for the model. + generation_config (Optional[Dict[str, Any]]): Optional dict enabling the + Sample-N → Score → Select pipeline (RLCD + HelpSteer2 + West-of-N). + When provided, overrides the legacy joint-generation path. Recognized + keys: ``n_candidates`` (int), ``target_margin`` (float in [0, 1]), + ``min_margin`` (float), ``max_margin`` (Optional[float]), + ``length_ratio_band`` ([float, float]), ``rubric_weights`` (dict), + ``persona_schedule`` (Optional[List[(str, float)]]). Returns: Optional[ContinualAlignmentDataset]: The synthetically generated dataset. @@ -51,25 +60,50 @@ async def generate_continual_dataset( prompt_mapper = PromptMapper() response_mapper = ResponseMapper() task_specs = data_config['task_specs'] + use_sampled_pipeline = generation_config is not None + if use_sampled_pipeline: + assert generation_config is not None # for type narrowing + logging.info( + f'Using Sample-N → Score → Select generation pipeline: {generation_config}' + ) if dry_run: logging.info(f'Doing dry-run data generation on a single sample...') mock_task = AlignmentTask.from_dict(task_specs[0]['alignment_task']) - coro = _generate_sample( - mock_task, - client, - model_name, - prompt_mapper, - response_mapper, - async_semaphore, - max_tokens_prompt_response, - max_tokens_chosen_rejected_response, - dataset_idx=-1, - prompt_idx=-1, - cache=None, - include_preference_axes=include_preference_axes, - temperature=temperature, - ) + if use_sampled_pipeline: + assert generation_config is not None + coro = _generate_sample_sampled( + mock_task, + client, + model_name, + prompt_mapper, + response_mapper, + async_semaphore, + max_tokens_prompt_response, + max_tokens_chosen_rejected_response, + dataset_idx=-1, + prompt_idx=-1, + cache=None, + judge_cache=None, + generation_config=generation_config, + temperature=temperature, + ) + else: + coro = _generate_sample( + mock_task, + client, + model_name, + prompt_mapper, + response_mapper, + 
async_semaphore, + max_tokens_prompt_response, + max_tokens_chosen_rejected_response, + dataset_idx=-1, + prompt_idx=-1, + cache=None, + include_preference_axes=include_preference_axes, + temperature=temperature, + ) try: _ = await coro except BaseException as e: @@ -81,6 +115,11 @@ async def generate_continual_dataset( cache = await AsyncElasticsearchCache.maybe_from_env_var( index_name=f'CACHE_DATA_GENERATION_{model_name}' ) + judge_cache = None + if use_sampled_pipeline: + judge_cache = await AsyncElasticsearchCache.maybe_from_env_var( + index_name=f'CACHE_DATA_GENERATION_RUBRIC_{model_name}' + ) futures, tasks, dataset_sizes = [], [], [] for dataset_idx, task_spec in enumerate(task_specs): task = AlignmentTask.from_dict(task_spec['alignment_task']) @@ -90,21 +129,40 @@ async def generate_continual_dataset( tasks.append(task) dataset_sizes.append(dataset_size) for _sample_idx in range(dataset_size): - coro = _generate_sample( - task, - client, - model_name, - prompt_mapper, - response_mapper, - async_semaphore, - max_tokens_prompt_response, - max_tokens_chosen_rejected_response, - dataset_idx=dataset_idx, - prompt_idx=_sample_idx, - cache=cache, - include_preference_axes=include_preference_axes, - temperature=temperature, - ) + if use_sampled_pipeline: + assert generation_config is not None + coro = _generate_sample_sampled( + task, + client, + model_name, + prompt_mapper, + response_mapper, + async_semaphore, + max_tokens_prompt_response, + max_tokens_chosen_rejected_response, + dataset_idx=dataset_idx, + prompt_idx=_sample_idx, + cache=cache, + judge_cache=judge_cache, + generation_config=generation_config, + temperature=temperature, + ) + else: + coro = _generate_sample( + task, + client, + model_name, + prompt_mapper, + response_mapper, + async_semaphore, + max_tokens_prompt_response, + max_tokens_chosen_rejected_response, + dataset_idx=dataset_idx, + prompt_idx=_sample_idx, + cache=cache, + include_preference_axes=include_preference_axes, + 
@backoff.on_exception(
    backoff.expo,
    (openai.RateLimitError, openai.InternalServerError, openai.APITimeoutError),
    max_tries=_get_tries(),
)
async def _generate_sample_sampled(
    task: AlignmentTask,
    client: openai.AsyncOpenAI,
    model_name: str,
    prompt_mapper: PromptMapper,
    response_mapper: ResponseMapper,
    async_semaphore: asyncio.Semaphore,
    max_tokens_prompt_response: int,
    max_tokens_chosen_rejected_response: int,
    dataset_idx: int,
    prompt_idx: int,
    generation_config: Optional[Dict[str, Any]],
    cache: 'AsyncElasticsearchCache | None' = None,
    judge_cache: 'AsyncElasticsearchCache | None' = None,
    temperature: float = 1.0,
) -> Optional[Tuple[AlignmentDatasetSample, int]]:
    r"""Generate one sample via Sample-N → Score → Select.

    1. Generate the task prompt (same as legacy).
    2. Sample N candidate responses with varied (persona, temperature).
    3. Score each candidate with the rubric judge (HelpSteer2-style).
    4. Select a (chosen, rejected) pair whose score gap matches `target_margin`.

    Args:
        task (AlignmentTask): Task whose objective/preference drive the prompts and the rubric.
        client (openai.AsyncOpenAI): Client used for prompt, candidate and judge requests.
        model_name (str): Model used for every request issued here.
        prompt_mapper (PromptMapper): Builds the meta-prompt that yields the task prompt.
        response_mapper (ResponseMapper): Builds persona-conditioned candidate prompts and
            the default persona schedule.
        async_semaphore (asyncio.Semaphore): Bounds the number of in-flight requests.
        max_tokens_prompt_response (int): Token limit for the task-prompt request.
        max_tokens_chosen_rejected_response (int): Token limit for each candidate request.
        dataset_idx (int): Index of the owning dataset; returned alongside the sample.
        prompt_idx (int): Index of the sample; used as the cache nonce and in log messages.
        generation_config (Optional[Dict[str, Any]]): Pipeline settings. Keys read here:
            ``n_candidates``, ``target_margin``, ``min_margin``, ``max_margin``,
            ``length_ratio_band``, ``rubric_weights``, ``persona_schedule``.
            ``None`` is treated as an empty dict (all defaults).
        cache (AsyncElasticsearchCache | None): Optional cache for prompt/candidate outputs.
        judge_cache (AsyncElasticsearchCache | None): Optional cache handed to the rubric judge.
        temperature (float): Temperature for the task prompt and the base temperature of the
            default persona schedule.

    Returns:
        Optional[Tuple[AlignmentDatasetSample, int]]: ``(sample, dataset_idx)``, or ``None``
        when fewer than two candidates are usable, no admissible pair exists, or the task
        prompt fails schema validation.

    Raises:
        ValueError: If the configuration is invalid (checked before any request is sent),
            or if the model returns an empty task-prompt response.
    """
    # Local import to avoid circular import at module load time.
    from aif_gen.validation.llm_judge import (
        _get_rubric_judge_prompt,
        _get_rubric_score,
    )

    if generation_config is None:
        generation_config = {}
    n_candidates = int(generation_config.get('n_candidates', 6))
    target_margin = float(generation_config.get('target_margin', 0.6))
    min_margin = float(generation_config.get('min_margin', 0.3))
    raw_max_margin = generation_config.get('max_margin', None)
    max_margin = float(raw_max_margin) if raw_max_margin is not None else None
    length_band_raw = generation_config.get('length_ratio_band', [0.5, 2.0])
    length_ratio_band = (float(length_band_raw[0]), float(length_band_raw[1]))
    rubric_weights = generation_config.get('rubric_weights', None)

    # Fail fast: these values may come straight from YAML (the CLI only clamps
    # its own flags). Without these checks an out-of-range target_margin is
    # only rejected by select_pair, i.e. after all N generation and N judge
    # requests for the sample have already been paid for.
    if n_candidates < 2:
        raise ValueError(f'n_candidates must be >= 2, got {n_candidates}')
    if not 0.0 <= target_margin <= 1.0:
        raise ValueError(f'target_margin must be in [0, 1], got {target_margin}')
    if max_margin is not None and max_margin < min_margin:
        raise ValueError(
            f'max_margin ({max_margin}) must be >= min_margin ({min_margin})'
        )
    if not 0.0 <= length_ratio_band[0] <= length_ratio_band[1]:
        raise ValueError(
            f'length_ratio_band must satisfy 0 <= lo <= hi, got {length_ratio_band}'
        )

    persona_schedule = generation_config.get('persona_schedule', None)
    if persona_schedule is None:
        persona_schedule = response_mapper.default_persona_schedule(
            n_candidates, base_temperature=temperature
        )
    else:
        persona_schedule = [(str(p), float(t)) for p, t in persona_schedule]
        if len(persona_schedule) != n_candidates:
            raise ValueError(
                f'persona_schedule length ({len(persona_schedule)}) must equal '
                f'n_candidates ({n_candidates})'
            )

    try:
        # ---- 1. Generate the task prompt (cached, same as legacy) ----
        meta_prompt = prompt_mapper.generate_prompt(task)
        meta_prompt_nonce = f'{prompt_idx}'

        async with async_semaphore:
            output = None
            if cache is not None:
                output = await cache.get(meta_prompt, nonce=meta_prompt_nonce)
            if output is None:
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=[{'role': 'user', 'content': meta_prompt}],
                    max_tokens=max_tokens_prompt_response,
                    response_format={
                        'type': 'json_schema',
                        'json_schema': {
                            'name': 'PromptProposal',
                            'schema': _PromptProposal.model_json_schema(),
                            'strict': True,
                        },
                    },
                    temperature=temperature,
                )
                output = response.choices[0].message.content

        if output is None:
            raise ValueError(f'Received None response to meta prompt: {meta_prompt}')
        prompt = _PromptProposal.model_validate_json(output).prompt
        # Only cache output that passed schema validation.
        if cache is not None:
            await cache.set(query=meta_prompt, value=output, nonce=meta_prompt_nonce)

        # ---- 2. Sample N candidate responses in parallel ----
        async def _one_candidate(
            persona: str, t: float, cand_idx: int
        ) -> Optional[str]:
            cand_prompt = response_mapper.generate_candidate_prompt(
                task, prompt, persona
            )
            # The nonce keeps candidates that share a prompt/persona but differ
            # in slot or temperature from colliding in the cache.
            cand_nonce = f'{prompt_idx}:cand:{cand_idx}:{persona}:{t:.3f}'
            async with async_semaphore:
                cand_output = None
                if cache is not None:
                    cand_output = await cache.get(cand_prompt, nonce=cand_nonce)
                if cand_output is None:
                    resp = await client.chat.completions.create(
                        model=model_name,
                        messages=[{'role': 'user', 'content': cand_prompt}],
                        max_tokens=max_tokens_chosen_rejected_response,
                        response_format={
                            'type': 'json_schema',
                            'json_schema': {
                                'name': 'CandidateResponse',
                                'schema': _Response.model_json_schema(),
                                'strict': True,
                            },
                        },
                        temperature=t,
                    )
                    cand_output = resp.choices[0].message.content
            if cand_output is None:
                return None
            try:
                parsed = _Response.model_validate_json(cand_output)
            except pydantic.ValidationError as e:
                # A single malformed candidate is dropped, not fatal: the pool
                # only needs two usable candidates.
                logging.warning(f'Candidate {cand_idx} schema parse failed: {e}')
                return None
            if cache is not None:
                await cache.set(query=cand_prompt, value=cand_output, nonce=cand_nonce)
            return parsed.response

        cand_coros = [
            _one_candidate(persona, t, i)
            for i, (persona, t) in enumerate(persona_schedule)
        ]
        cand_texts = await asyncio.gather(*cand_coros, return_exceptions=False)

        # ---- 3. Score each successful candidate with the rubric judge ----
        async def _score_one(text: Optional[str]) -> Optional[float]:
            if text is None:
                return None
            judge_prompt = _get_rubric_judge_prompt(
                prompt=prompt,
                response=text,
                preference=task.preference,
                objective=task.objective,
            )
            return await _get_rubric_score(
                judge_prompt,
                client,
                model_name,
                async_semaphore,
                max_tokens_judge_response=64,
                cache=judge_cache,
                weights=rubric_weights,
            )

        score_coros = [_score_one(t) for t in cand_texts]
        scores = await asyncio.gather(*score_coros, return_exceptions=False)

        scored: List[ScoredCandidate] = []
        for (persona, _t), text, score in zip(persona_schedule, cand_texts, scores):
            if text is None or score is None:
                continue
            scored.append(ScoredCandidate(text=text, score=score, persona=persona))

        if len(scored) < 2:
            logging.warning(
                f'Only {len(scored)}/{n_candidates} usable candidates for prompt_idx={prompt_idx}; dropping sample.'
            )
            return None

        # ---- 4. Select pair by margin ----
        pair = select_pair(
            scored,
            target_margin=target_margin,
            min_margin=min_margin,
            max_margin=max_margin,
            length_ratio_band=length_ratio_band,
        )
        if pair is None:
            logging.info(
                f'No admissible pair for prompt_idx={prompt_idx} '
                f'(target_margin={target_margin}, min_margin={min_margin}); dropping sample.'
            )
            return None

        chosen_cand, rejected_cand = pair
        sample = AlignmentDatasetSample(
            prompt=prompt,
            chosen=chosen_cand.text,
            rejected=rejected_cand.text,
        )
        logging.debug(
            f'Sampled pipeline: prompt_idx={prompt_idx} '
            f'chosen_score={chosen_cand.score:.2f} ({chosen_cand.persona}) '
            f'rejected_score={rejected_cand.score:.2f} ({rejected_cand.persona})'
        )
        return sample, dataset_idx

    except pydantic.ValidationError as e:
        logging.error(f'Failed to bind structured output json schema: {e}')
        return None
@dataclass
class ScoredCandidate:
    r"""A single response candidate paired with its rubric score and origin metadata.

    Args:
        text (str): The candidate response text.
        score (float): The aggregated rubric score (typically in [1, 5]).
        persona (str): Generation persona used for this candidate
            (e.g., 'aligned', 'anti_aligned', 'neutral'). Used as a tie-breaker
            to encourage semantic, not stylistic, contrast in the chosen pair.
    """

    text: str
    score: float
    persona: str


def select_pair(
    candidates: List[ScoredCandidate],
    target_margin: float,
    min_margin: float = 0.0,
    max_margin: Optional[float] = None,
    length_ratio_band: Tuple[float, float] = (0.5, 2.0),
) -> Optional[Tuple[ScoredCandidate, ScoredCandidate]]:
    r"""Select a (chosen, rejected) pair from N scored candidates by score margin.

    The score gap of the selected pair targets:
        target_gap = (s_max - s_min) * target_margin
    where ``target_margin=1.0`` gives the easiest pair (best vs worst, the
    classic West-of-N criterion) and ``target_margin → 0`` gives the hardest
    pair (smallest admissible score gap). Pairs outside ``[min_margin, max_margin]``
    on the raw score scale, or outside ``length_ratio_band`` on the
    chosen/rejected length ratio, are rejected. Pairs with identical scores
    are never returned, even when ``min_margin`` is 0: without a strict
    ordering there is no meaningful "chosen" response.

    Args:
        candidates: Pool of scored response candidates. Must contain at least 2.
        target_margin: Difficulty knob in [0, 1]. 0 = hardest, 1 = easiest.
        min_margin: Minimum admissible raw score gap. Pairs below are dropped
            as ambiguous.
        max_margin: Maximum admissible raw score gap. None disables the cap.
        length_ratio_band: (lo, hi) bounds on len(chosen)/len(rejected).

    Returns:
        (chosen, rejected) tuple with ``chosen.score > rejected.score``, or
        None if no admissible pair exists.

    Raises:
        ValueError: If ``target_margin`` is outside [0, 1].
    """
    if not 0.0 <= target_margin <= 1.0:
        raise ValueError(f'target_margin must be in [0, 1], got {target_margin}')
    if len(candidates) < 2:
        return None

    # Ascending by score; sorted() is stable, so equal scores keep input order.
    ranked = sorted(candidates, key=lambda c: c.score)
    target_gap = (ranked[-1].score - ranked[0].score) * target_margin
    lo, hi = length_ratio_band

    best: Optional[Tuple[ScoredCandidate, ScoredCandidate]] = None
    best_key: Optional[Tuple[float, int]] = None

    for i, low in enumerate(ranked):
        for high in ranked[i + 1 :]:
            gap = high.score - low.score
            if max_margin is not None and gap > max_margin:
                # Scores only grow along the inner scan, so every later
                # partner for this `low` would exceed the cap as well.
                break
            # A zero gap carries no preference signal: labelling one of two
            # equally scored responses as "chosen" would be arbitrary.
            if gap <= 0.0 or gap < min_margin:
                continue
            # Guard against empty texts so the ratio is always defined.
            ratio = max(len(high.text), 1) / max(len(low.text), 1)
            if not lo <= ratio <= hi:
                continue

            # Primary key: distance from target gap.
            # Secondary key: prefer pairs from *different* personas (encourages
            # semantic, not stylistic, contrast). 0 = different, 1 = same.
            key = (abs(gap - target_gap), int(low.persona == high.persona))
            if best_key is None or key < best_key:
                best_key = key
                best = (high, low)  # (chosen=higher score, rejected=lower)

    return best
def generate_candidate_prompt(
    self,
    task: AlignmentTask,
    task_prompt: str,
    persona: str,
) -> str:
    r"""Generate a single-response prompt conditioned on a persona.

    This is the RLCD-style contrastive prompting (arXiv:2307.12950): instead
    of asking one model in one call to produce both chosen and rejected,
    we sample N independent responses with opposing instructions. The
    difference in the *prompts* is what produces meaningfully differentiated
    outputs from a single base model.

    Args:
        task: AlignmentTask whose ``preference`` is quoted in the aligned and
            anti-aligned prompts.
        task_prompt: The prompt the response should answer.
        persona: One of 'aligned', 'anti_aligned', 'neutral'.

    Returns:
        Prompt string for a single response generation call, followed by
        ``self.suffix_context`` when that attribute is set.

    Raises:
        ValueError: If ``persona`` is not one of ``VALID_PERSONAS``.
    """
    if persona not in VALID_PERSONAS:
        raise ValueError(
            f'persona must be one of {VALID_PERSONAS}, got {persona!r}'
        )

    if persona == PERSONA_ALIGNED:
        preference_clause = (
            f"You MUST strictly follow this preference in every aspect of your response: '{task.preference}'.\n"
            'Make the preference clearly evident in style, tone, and content.\n'
        )
    elif persona == PERSONA_ANTI_ALIGNED:
        preference_clause = (
            f"You MUST deliberately violate this preference while still answering the prompt and staying on-topic: '{task.preference}'.\n"
            'Do not adopt the preferred style/tone/content; produce something that clearly does not satisfy the preference.\n'
        )
    else:  # neutral
        preference_clause = 'Answer the prompt naturally without considering any particular stylistic preference.\n'

    # Assemble from unindented pieces rather than dedent()-ing an indented
    # template: the interpolated clause (and possibly the task prompt) contains
    # lines starting at column 0, which leaves dedent() with no common margin
    # to strip, so the template's own indentation would leak into the prompt.
    prompt = (
        f"Generate a single response to the following prompt: '{task_prompt}'.\n"
        f'{preference_clause}'
        'You don\'t need to start your response by saying "here is the response" nor to give any meta-explanation. Just provide the response.\n'
    )
    if self.suffix_context:
        prompt += self.suffix_context
    return prompt
@staticmethod
def default_persona_schedule(
    n_candidates: int, base_temperature: float = 1.0
) -> List[Tuple[str, float]]:
    r"""Default (persona, temperature) schedule for N candidate samples.

    Cycles through a fixed six-slot rotation of aligned / anti_aligned /
    neutral personas at the base, a lower and a higher temperature, so the
    pool spans a wide rubric-score range. Deterministic given N.

    Args:
        n_candidates: Number of candidates per prompt. Must be >= 2.
        base_temperature: Center temperature. The lower variant is
            ``max(0.1, base - 0.3)`` and the higher one ``min(2.0, base + 0.2)``.

    Returns:
        List of (persona, temperature) of length n_candidates.

    Raises:
        ValueError: If ``n_candidates`` is smaller than 2.
    """
    if n_candidates < 2:
        raise ValueError(f'n_candidates must be >= 2, got {n_candidates}')

    t_lo = max(0.1, base_temperature - 0.3)
    t_hi = min(2.0, base_temperature + 0.2)
    # The first two slots are aligned and anti_aligned, so every schedule of
    # length >= 2 already contains both contrastive personas; no fix-up pass
    # is needed after slicing the rotation.
    rotation = [
        (PERSONA_ALIGNED, base_temperature),
        (PERSONA_ANTI_ALIGNED, t_hi),
        (PERSONA_NEUTRAL, base_temperature),
        (PERSONA_ALIGNED, t_lo),
        (PERSONA_ANTI_ALIGNED, base_temperature),
        (PERSONA_NEUTRAL, t_hi),
    ]
    return [rotation[i % len(rotation)] for i in range(n_candidates)]
Score the following response on three axes, ' + 'each on an INTEGER scale from 1 to 5 (1 = very poor, 5 = excellent).\n\n' + f'OBJECTIVE: {objective}\n' + f'PREFERENCE: {preference}\n' + f'PROMPT: {prompt}\n' + f'RESPONSE: {response}\n\n' + 'Score the response on:\n' + ' - preference_adherence: How well does the response follow the PREFERENCE? ' + '(1 = ignores or contradicts the preference; 5 = perfectly embodies it.)\n' + ' - objective_fidelity: How well does the response stay on-topic to the OBJECTIVE? ' + '(1 = off-topic; 5 = perfectly addresses the objective.)\n' + ' - coherence: Is the response fluent, internally consistent, and grammatical? ' + '(1 = incoherent; 5 = perfectly coherent.)\n' + 'Respond ONLY as JSON with integer fields ' + '`preference_adherence`, `objective_fidelity`, `coherence`.' + ) + + +def _aggregate_rubric( + pa: int, of: int, c: int, weights: Optional[Dict[str, float]] = None +) -> float: + r"""Aggregate per-axis 1–5 scores into a single scalar in [1, 5].""" + w = weights or DEFAULT_RUBRIC_WEIGHTS + return ( + w['preference_adherence'] * pa + + w['objective_fidelity'] * of + + w['coherence'] * c + ) + + +@backoff.on_exception( + backoff.expo, + (openai.RateLimitError, openai.InternalServerError, openai.APITimeoutError), + max_tries=_get_tries(), +) +async def _get_rubric_score( + prompt: str, + client: openai.AsyncOpenAI, + model_name: str, + async_semaphore: asyncio.Semaphore, + max_tokens_judge_response: int = 64, + cache: Optional[AsyncElasticsearchCache] = None, + weights: Optional[Dict[str, float]] = None, +) -> Optional[float]: + r"""Call the judge model with a rubric prompt; return aggregated [1, 5] score, or None on failure.""" + try: + async with async_semaphore: + model_response: Optional[str] = None + if cache is not None: + model_response = await cache.get(prompt) + + if model_response is None: + response = await client.chat.completions.create( + model=model_name, + temperature=0, + messages=[{'role': 'user', 'content': 
@backoff.on_exception(
    backoff.expo,
    (openai.RateLimitError, openai.InternalServerError, openai.APITimeoutError),
    max_tries=_get_tries(),
)
async def _get_rubric_score(
    prompt: str,
    client: openai.AsyncOpenAI,
    model_name: str,
    async_semaphore: asyncio.Semaphore,
    max_tokens_judge_response: int = 64,
    cache: Optional[AsyncElasticsearchCache] = None,
    weights: Optional[Dict[str, float]] = None,
) -> Optional[float]:
    r"""Ask the judge model to grade one response and fold the axes into one score.

    Args:
        prompt: Fully rendered rubric prompt; also used as the cache key.
        client: Client used for the judge request.
        model_name: Judge model name.
        async_semaphore: Bounds the number of in-flight requests.
        max_tokens_judge_response: Token limit for the judge's JSON reply.
        cache: Optional cache of raw judge replies keyed by ``prompt``.
        weights: Optional per-axis weights forwarded to ``_aggregate_rubric``.

    Returns:
        The aggregated score, or ``None`` when the reply does not validate
        against the rubric schema.

    Raises:
        ValueError: If the model returns an empty reply.
    """
    rubric_format = {
        'type': 'json_schema',
        'json_schema': {
            'name': 'RubricResponse',
            'schema': _RubricResponse.model_json_schema(),
            'strict': True,
        },
    }
    try:
        async with async_semaphore:
            raw_verdict: Optional[str] = (
                await cache.get(prompt) if cache is not None else None
            )
            if raw_verdict is None:
                # temperature=0 keeps grading as repeatable as the API allows.
                completion = await client.chat.completions.create(
                    model=model_name,
                    temperature=0,
                    messages=[{'role': 'user', 'content': prompt}],
                    max_tokens=max_tokens_judge_response,
                    response_format=rubric_format,
                )
                raw_verdict = completion.choices[0].message.content

        if raw_verdict is None:
            raise ValueError(f'Received None response to prompt: {prompt}')

        verdict = _RubricResponse.model_validate_json(raw_verdict)
        # Clamp to the documented 1–5 range to guard against off-spec outputs.
        pa, of, c = (
            max(1, min(5, axis_score))
            for axis_score in (
                verdict.preference_adherence,
                verdict.objective_fidelity,
                verdict.coherence,
            )
        )

        # Store the reply only once it has passed schema validation.
        if cache is not None:
            await cache.set(query=prompt, value=raw_verdict)

        return _aggregate_rubric(pa, of, c, weights)
    except pydantic.ValidationError as e:
        logging.error(f'Failed to bind rubric output json schema: {e}')
        return None
select_pair(cands, target_margin=1.0) # type: ignore[misc] + assert chosen.score == 5.0 + assert rejected.score == 1.0 + + +def test_select_pair_target_margin_zero_picks_smallest_gap() -> None: + cands = [ + _mk('a' * 10, 1.0), + _mk('b' * 10, 1.2), + _mk('c' * 10, 5.0), + ] + chosen, rejected = select_pair(cands, target_margin=0.0) # type: ignore[misc] + assert chosen.score - rejected.score == pytest.approx(0.2) + + +def test_select_pair_monotone_in_target_margin() -> None: + cands = [_mk('x' * 10, float(s), 'aligned') for s in range(1, 6)] + gaps = [] + for tm in (0.0, 0.25, 0.5, 0.75, 1.0): + pair = select_pair(cands, target_margin=tm) + assert pair is not None + gaps.append(pair[0].score - pair[1].score) + # The achieved gap must be (weakly) non-decreasing in target_margin. + for a, b in zip(gaps, gaps[1:]): + assert a <= b + + +def test_select_pair_orders_chosen_above_rejected() -> None: + cands = [_mk('a' * 10, 4.5), _mk('b' * 10, 1.5)] + chosen, rejected = select_pair(cands, target_margin=1.0) # type: ignore[misc] + assert chosen.score > rejected.score + + +def test_select_pair_min_margin_filters_too_close_pairs() -> None: + cands = [_mk('a' * 10, 3.0), _mk('b' * 10, 3.05)] + assert select_pair(cands, target_margin=0.0, min_margin=0.5) is None + + +def test_select_pair_length_ratio_band_filters_lopsided_pairs() -> None: + short = _mk('x', 1.0) # length 1 + long = _mk('x' * 100, 5.0) # length 100, ratio = 100 > 2.0 + assert ( + select_pair([short, long], target_margin=1.0, length_ratio_band=(0.5, 2.0)) + is None + ) + + +def test_select_pair_prefers_different_personas_on_tie() -> None: + # Two pairs with identical gap; selector should prefer the one with + # different personas as a semantic contrast tie-breaker. 
def test_select_pair_prefers_different_personas_on_tie() -> None:
    # Build a genuine tie: both low candidates score 1.0, so pairing either of
    # them with `high` gives the same gap (4.0) and therefore the same distance
    # from the target gap. Only the persona tie-breaker can separate the two
    # pairs. The same-persona pair is listed first, so a selector without the
    # tie-breaker would return it.
    same_persona_low = _mk('a' * 10, 1.0, 'aligned')
    other_persona_low = _mk('b' * 10, 1.0, 'neutral')
    high = _mk('c' * 10, 5.0, 'aligned')

    chosen, rejected = select_pair(  # type: ignore[misc]
        [same_persona_low, other_persona_low, high], target_margin=1.0
    )

    assert chosen is high
    assert rejected is other_persona_low
    assert chosen.persona != rejected.persona
test_generate_candidate_prompt_anti_aligned_includes_violation_clause(): + mapper = ResponseMapper() + task = _mk_task() + p = mapper.generate_candidate_prompt(task, 'tell a story', PERSONA_ANTI_ALIGNED) + assert task.preference in p + assert 'violate' in p.lower() or 'deliberately' in p.lower() + + +def test_generate_candidate_prompt_neutral_makes_no_preference_claim(): + mapper = ResponseMapper() + task = _mk_task() + p = mapper.generate_candidate_prompt(task, 'tell a story', PERSONA_NEUTRAL) + # Neutral prompt must NOT instruct following the preference, and must NOT + # cite the task preference text. + assert task.preference not in p + assert 'naturally' in p.lower() or 'without considering' in p.lower() + + +def test_generate_candidate_prompt_rejects_unknown_persona(): + import pytest + + mapper = ResponseMapper() + task = _mk_task() + with pytest.raises(ValueError): + mapper.generate_candidate_prompt(task, 'tell a story', 'bogus') + + +def test_default_persona_schedule_length_and_coverage(): + mapper = ResponseMapper() + schedule = mapper.default_persona_schedule(6, base_temperature=1.0) + assert len(schedule) == 6 + personas = {p for p, _ in schedule} + assert PERSONA_ALIGNED in personas + assert PERSONA_ANTI_ALIGNED in personas + for p, t in schedule: + assert p in VALID_PERSONAS + assert 0.0 <= t <= 2.0