Open
Description
Let's generalize the easy_gcg code to optimize prompts on a dataset of (x, y) pairs, where each x is the question and y is the answer.
We want to solve u* := argmax_u E_{(x, y) ~ D} [P(y | u + x)], where the expectation is taken over dataset pairs (x, y) ~ D.
We can start by simply aggregating gradients for the swaps in GCG over multiple elements of the batch (see Magic_Words/magic_words/easy_gcg.py, line 178 in 32840cd):

    def stochastic_easy_gcg_qa_ids(question_ids: list[torch.Tensor],
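
To make the gradient aggregation concrete, here is a minimal sketch of one way it could work. It is not the code in easy_gcg.py: the function name `aggregate_swap_candidates` is hypothetical, and it assumes each per-example gradient is taken with respect to the one-hot prompt tokens of a loss such as -log P(y | u + x), as in standard GCG. Note that averaging log-likelihood gradients targets E[log P] rather than E[P], so treat the loss choice as an assumption too.

```python
import torch


def aggregate_swap_candidates(per_example_grads: list[torch.Tensor],
                              top_k: int) -> torch.Tensor:
    """Average per-example one-hot gradients and return top-k swap candidates.

    Hypothetical helper. Each tensor has shape [prompt_len, vocab_size] and is
    assumed to hold d(loss_i)/d(one_hot_prompt) for one (x, y) pair.
    Returns token indices of shape [prompt_len, top_k].
    """
    # Stack to [batch, prompt_len, vocab_size] and average over the batch,
    # giving a Monte Carlo estimate of the gradient of the expected loss.
    mean_grad = torch.stack(per_example_grads, dim=0).mean(dim=0)
    # The most negative entries are the swaps predicted, to first order,
    # to reduce the averaged loss the most.
    return torch.topk(-mean_grad, k=top_k, dim=-1).indices


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy input: 3 dataset examples, 4 prompt tokens, vocabulary of 10.
    grads = [torch.randn(4, 10) for _ in range(3)]
    candidates = aggregate_swap_candidates(grads, top_k=2)
    print(candidates.shape)
```

The candidates returned here would still need to be re-scored exactly against the dataset, since the gradient only gives a first-order estimate of each swap's effect.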
All that remains is to create an efficient batch_compute_score_dataset() function to compute the scores of each potential new prompt w.r.t. the dataset (see Magic_Words/magic_words/easy_gcg.py, line 263 in 32840cd):

    alt_scores = batch_compute_score_dataset(alt_prompt_ids,
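
As a starting point, the scoring step can be written as a plain reference loop that a batched version could later be checked against. This is only a sketch of the intended semantics (the empirical mean of P(y | u + x) over the dataset for each candidate prompt), not the signature of batch_compute_score_dataset(), which is not shown in full above. The names `score_prompts_on_dataset`, `best_prompt_index`, and `answer_prob_fn` are hypothetical; `answer_prob_fn` stands in for whatever model call returns P(y | u + x).

```python
from typing import Callable, Sequence

TokenIds = Sequence[int]
# Hypothetical scorer: returns P(y | u + x) for one prompt and one (x, y) pair.
AnswerProbFn = Callable[[TokenIds, TokenIds, TokenIds], float]


def score_prompts_on_dataset(alt_prompt_ids: Sequence[TokenIds],
                             dataset: Sequence[tuple[TokenIds, TokenIds]],
                             answer_prob_fn: AnswerProbFn) -> list[float]:
    """Return the mean of P(y | u + x) over the dataset for each candidate u."""
    if not dataset:
        raise ValueError("dataset must contain at least one (x, y) pair")
    scores = []
    for prompt in alt_prompt_ids:
        # Unbatched on purpose: one scorer call per (prompt, example) pair.
        total = sum(answer_prob_fn(prompt, x, y) for x, y in dataset)
        scores.append(total / len(dataset))
    return scores


def best_prompt_index(scores: Sequence[float]) -> int:
    """Index of the highest-scoring candidate prompt."""
    return max(range(len(scores)), key=scores.__getitem__)


if __name__ == "__main__":
    # Toy scorer: probability 1 if the prompt's first token equals the
    # answer's first token, else 0. Purely illustrative.
    def toy_prob(u: TokenIds, x: TokenIds, y: TokenIds) -> float:
        return 1.0 if u[0] == y[0] else 0.0

    data = [([1], [7]), ([2], [7]), ([3], [8])]
    scores = score_prompts_on_dataset([[7], [8], [9]], data, toy_prob)
    print(scores, best_prompt_index(scores))
```

An efficient version would presumably replace the inner loop with padded, batched forward passes, but the per-candidate mean it computes should match this loop.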