Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions aif_gen/cli/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,25 @@
default=0.99,
help='Temperature for sampling from the model.',
)
@click.option(
'--n-candidates',
type=click.IntRange(min=2, max=64, clamp=True),
default=None,
help=(
'Enable Sample-N → Score → Select pipeline (RLCD + HelpSteer2 + West-of-N) '
'with this many candidate responses per prompt. If unset (and no `generation` '
'block is present in the config), uses the legacy joint-generation pipeline.'
),
)
@click.option(
'--target-margin',
type=click.FloatRange(min=0.0, max=1.0, clamp=True),
default=None,
help=(
'Difficulty target for the sampled pipeline (0 = hardest, smallest score gap; '
'1 = easiest, largest score gap). Defaults to 0.6 when --n-candidates is set.'
),
)
def generate(
data_config_name: pathlib.Path,
model: str,
Expand All @@ -91,6 +110,8 @@ def generate(
hf_repo_id: Optional[str],
include_preference_axes: bool,
temperature: float,
n_candidates: Optional[int],
target_margin: Optional[float],
) -> None:
r"""Generate a new ContinualAlignmentDataset.

Expand All @@ -105,6 +126,21 @@ def generate(
data_config = yaml.safe_load(data_config_name.read_text())
logging.debug(f'Configuration: {data_config}')

# Resolve generation pipeline config: CLI flags override the YAML `generation:` block.
yaml_generation = (
data_config.get('generation') if isinstance(data_config, dict) else None
)
generation_config: Optional[dict] = None
if n_candidates is not None or yaml_generation is not None:
generation_config = dict(yaml_generation) if yaml_generation else {}
if n_candidates is not None:
generation_config['n_candidates'] = n_candidates
if target_margin is not None:
generation_config['target_margin'] = target_margin
generation_config.setdefault('n_candidates', 6)
generation_config.setdefault('target_margin', 0.6)
logging.info(f'Sampled-pipeline generation_config: {generation_config}')

output_file.parent.mkdir(parents=True, exist_ok=True)

if not dry_run:
Expand All @@ -131,6 +167,7 @@ def generate(
dry_run,
include_preference_axes=include_preference_axes,
temperature=temperature,
generation_config=generation_config,
)
dataset = asyncio.get_event_loop().run_until_complete(future)
if dataset is not None:
Expand Down
Loading
Loading