
Fix combine=True masking only last assistant turn in multi-turn finetuning#803

Open
Mr-Neutr0n wants to merge 1 commit into zai-org:main from Mr-Neutr0n:fix/combine-mode-multi-turn-labels

Conversation

@Mr-Neutr0n

Summary

When combine=True in process_batch(), the current label masking logic only finds the last occurrence of the assistant role token (151337) and unmasks tokens after it:

last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
for j in range(last_assistant_index + 1, len(input_ids)):
    loss_masks[j] = True

In multi-turn conversations this means all earlier assistant responses are masked with -100 and never contribute to the training loss. The model effectively only learns from the final assistant reply, which wastes all intermediate assistant turns.

This is inconsistent with the non-combine branch, which correctly sets loss_mask_val = True for every assistant message.
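The effect is easy to see on a toy token sequence. In this sketch, 151337 and 151336 are the real assistant-role and end-of-turn token IDs; all other IDs are hypothetical content tokens chosen for illustration:

```python
# Toy two-turn conversation: role token (151337), content, end token (151336), twice.
input_ids = [100, 151337, 11, 12, 151336, 101, 151337, 21, 22, 151336]
loss_masks = [False] * len(input_ids)

# Current combine=True logic: search backward for the *last* assistant role token
# and unmask only what follows it.
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
for j in range(last_assistant_index + 1, len(input_ids)):
    loss_masks[j] = True

print(loss_masks)
# Only indices 7-9 (the second assistant turn) are True; the first assistant
# turn at indices 1-4 stays masked and never contributes to the loss.
```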

Fix

Instead of searching backward for the last 151337, the fix iterates forward through the full token sequence and unmasks every assistant response segment — from each assistant role token (151337) through its corresponding end-of-turn token (151336). This matches the per-message masking behavior of the non-combine path.

for j in range(len(input_ids)):
    if input_ids[j] == 151337:  # assistant role token
        loss_masks[j] = True
        k = j + 1
        while k < len(input_ids):
            loss_masks[k] = True
            if input_ids[k] == 151336:  # end token
                break
            k += 1
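As a sanity check, the fixed loop can be run on a toy two-turn sequence. Only 151337 and 151336 are the real special tokens here; the content token IDs are made up for the example:

```python
# Toy two-turn conversation: role token (151337), content, end token (151336), twice.
input_ids = [100, 151337, 11, 12, 151336, 101, 151337, 21, 22, 151336]
loss_masks = [False] * len(input_ids)

# Fixed logic: unmask every assistant segment, from each role token through
# its matching end-of-turn token.
for j in range(len(input_ids)):
    if input_ids[j] == 151337:  # assistant role token
        loss_masks[j] = True
        k = j + 1
        while k < len(input_ids):
            loss_masks[k] = True
            if input_ids[k] == 151336:  # end token
                break
            k += 1

print(loss_masks)
# Both assistant segments (indices 1-4 and 6-9) are now True; only the
# non-assistant tokens at indices 0 and 5 remain masked.
```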

Impact

Anyone using combine: true in their finetuning config with multi-turn conversation data was silently losing training signal from all assistant turns except the last one. This fix ensures all assistant responses contribute to the loss as intended.

