diff --git a/scripts/t_vaccine/harden.py b/scripts/t_vaccine/harden.py index a66e11c5..6f64c748 100644 --- a/scripts/t_vaccine/harden.py +++ b/scripts/t_vaccine/harden.py @@ -13,7 +13,8 @@ rho=3, S=8 layers, K=200 steps, N_h=200 harmful examples Expected runtime: ~1 hour on A100, ~23.5GB GPU memory. -Expected output: aligned LoRA adapter (adapter_model.safetensors). +Expected output: full merged checkpoint (LoRA adapter is merged into the base +model before saving), loadable directly via AutoModelForCausalLM / vLLM. """ import argparse @@ -44,6 +45,7 @@ def main(): parser.add_argument("--probability-steps", type=int, default=200, help="Probability recalc interval K (paper: 200)") parser.add_argument("--prompt-data-size", type=int, default=200, help="Harmful dataset size N_h (paper: 200)") parser.add_argument("--num-epochs", type=int, default=20, help="Training epochs (paper: 20)") + parser.add_argument("--learning-rate", type=float, default=1e-3, help="Learning rate (paper: 1e-3, tuned on Llama-2-7B)") # Dataset paths: The original T-Vaccine paper uses # beavertails_with_refusals_train.json from rosati2024immunization, generated @@ -82,6 +84,7 @@ def main(): print(f"Probability recalc interval (K): {args.probability_steps}") print(f"Harmful dataset size (N_h): {args.prompt_data_size}") print(f"Epochs: {args.num_epochs}") + print(f"Learning rate: {args.learning_rate}") print("=" * 80) config = TVaccineConfig( @@ -98,7 +101,7 @@ def main(): save_strategy="steps", save_steps=100000, save_total_limit=0, - learning_rate=1e-3, + learning_rate=args.learning_rate, weight_decay=0.1, warmup_ratio=0.1, lr_scheduler_type="cosine", diff --git a/src/tamperbench/whitebox/defenses/sdd/sdd.py b/src/tamperbench/whitebox/defenses/sdd/sdd.py index 0537230f..6766e177 100644 --- a/src/tamperbench/whitebox/defenses/sdd/sdd.py +++ b/src/tamperbench/whitebox/defenses/sdd/sdd.py @@ -137,6 +137,18 @@ def _load_tokenizer(model_name: str) -> PreTrainedTokenizer: if tokenizer.pad_token is None: tokenizer.add_special_tokens(special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN}) + # Base models (e.g. Meta-Llama-3-8B) lack a chat template, which SFTTrainer + # requires for conversational-format data. Set a minimal template so + # apply_chat_template works for both base and instruct models. + if not getattr(tokenizer, "chat_template", None): + tokenizer.chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}### Instruction:\n{{ message['content'] }}\n\n" + "{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content'] }}" + "{% endif %}{% endfor %}" + "{% if add_generation_prompt %}### Response:\n{% endif %}" + ) + return tokenizer diff --git a/src/tamperbench/whitebox/defenses/t_vaccine/t_vaccine_trainer.py b/src/tamperbench/whitebox/defenses/t_vaccine/t_vaccine_trainer.py index 9f848947..33238f06 100644 --- a/src/tamperbench/whitebox/defenses/t_vaccine/t_vaccine_trainer.py +++ b/src/tamperbench/whitebox/defenses/t_vaccine/t_vaccine_trainer.py @@ -33,8 +33,8 @@ def get_leaf_modules_with_grad(module): module_list = [] for name, module in module.named_modules(): if 'LlamaAttention' in str(type(module)) or 'OPTAttention' in str(type(module)) or 'Qwen2Attention' in str( - type(module)) or 'Gemma2Attention' in str(type(module)) or 'GemmaAttention' in str( - type(module)) or 'MistralAttention' in str(type(module)): + type(module)) or 'Qwen3Attention' in str(type(module)) or 'Gemma2Attention' in str( + type(module)) or 'GemmaAttention' in str(type(module)) or 'MistralAttention' in str(type(module)): module_list += [module] return module_list diff --git a/src/tamperbench/whitebox/defenses/t_vaccine/train.py b/src/tamperbench/whitebox/defenses/t_vaccine/train.py index 6f1f7c96..bf32760a 100644 --- a/src/tamperbench/whitebox/defenses/t_vaccine/train.py +++ b/src/tamperbench/whitebox/defenses/t_vaccine/train.py @@ -690,15 +690,15 @@ def _train_main( if training_args.bf16: model = model.to(torch.bfloat16) + # Only inject pad_token if the tokenizer lacks one — that's actually + # required for padded batched SFT. Do NOT inject bos/eos/unk: modern + # tokenizers (Qwen3, etc.) legitimately don't define BOS, and forcing + # DEFAULT_BOS_TOKEN ("") onto them collides with an unrelated vocab + # entry, writes a bogus bos_token_id into the saved config, and + # produces empty generations at inference. special_tokens_dict = dict() if tokenizer.pad_token is None: special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN - if tokenizer.eos_token is None: - special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN - if tokenizer.bos_token is None: - special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN - if tokenizer.unk_token is None: - special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN smart_tokenizer_and_embedding_resize( special_tokens_dict=special_tokens_dict,