PaliGemma fix attention mask for finetunes #30918
Conversation
Hey! Can you provide a reproducer for:
Good catch! Main comment is that the separation between prefix and suffix is brittle as-is - we want to merge this fix as soon as possible, though!
```python
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)

        mask = input_ids == self.config.prefix_suffix_separator_index
```
This is not very robust unfortunately: one or several `\n` tokens might be present in the prefix. Could this depend on the labels passed instead?
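For illustration, a minimal sketch of that failure mode (the token ids are made up; `sep_id` just stands in for whatever id the `\n` token has):

```python
import torch

sep_id = 108  # hypothetical id for the "\n" token; the real id depends on the tokenizer
# A prefix that itself contains a newline, followed by the real prefix/suffix separator.
input_ids = torch.tensor([[5, 108, 7, 9, 108, 3, 4]])

mask = input_ids == sep_id
# argmax picks the FIRST match: index 1 (the newline inside the prefix),
# not index 4 (the actual separator), so the prefix/suffix split lands in the wrong place.
print(mask.int().argmax(dim=1, keepdim=True))  # tensor([[1]])
```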
Force-pushed from 03cfe8e to 9a732f4.
I accidentally ran style on the whole repo, and now the tests are failing because I had to force-push away the update to basically every file (இ﹏இ`。). I think all the style errors are in code not touched by this change.
Anyways, OK, I updated how we check where to split, using a new index in the labels. Constructing the right labels array is getting to be sort of tricky for new users, so the PaliGemmaProcessor now returns labels in the BatchEncoding for easier finetuning. This is sort of necessary because the prefix and suffix might change length during tokenization, and it'd be annoying to try to find the right index of that `\n`. Hope y'all are OK with this idea!
A lot of the logic you are re-implementing is natively supported in the tokenizer via `token_type_ids`. In this case I think it would be wise to leverage `create_token_type_ids_from_sequences`; check out codegen for an example.
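For context, a rough sketch (an assumption about the shape of such an override, not the actual Gemma/PaliGemma implementation) of what `create_token_type_ids_from_sequences` could look like, with prefix tokens getting type 0 and suffix/label tokens getting type 1:

```python
from typing import List, Optional

def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    # Single sequence (inference): everything is prefix.
    if token_ids_1 is None:
        return [0] * len(token_ids_0)
    # Pair (finetuning): prefix tokens get 0, suffix/label tokens get 1.
    return [0] * len(token_ids_0) + [1] * len(token_ids_1)
```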
Hey @probicheaux, left another review! I think your causal mask building is fine, we just need to remove the `training` flag from the processor and pass the labels explicitly. Let me know if you have time to add the suggestions!
```python
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        suffix: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
```
Suggested change:
```diff
-        suffix: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        labels: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
```
```python
            suffix (`str`, `List[str]`, `List[List[str]]`):
                The suffix or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
                for more information.
```
Suggested change:
```diff
-            suffix (`str`, `List[str]`, `List[List[str]]`):
-                The suffix or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
-                for more information.
+            labels (`str`, `List[str]`, `List[List[str]]`):
+                The label suffix or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
+                for more information.
```
Overall, replace `suffix` with `labels` wherever it's used, I think.
```diff
         else:
-            text_inputs = self.tokenizer(
+            inputs = self.tokenizer(
```
Here I think the key is using the `text_pair` option in the tokenizer along with `return_token_type_ids=True`. So:

```python
text = ["Is the witholded tax rate equal to 31%?"]
labels = ["Yes"]
self.tokenizer(text=text, text_pair=labels, return_token_type_ids=True)
```

This will give `inputs` the additional key `token_type_ids`, of the form `[[0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]`, and will also have the correct shape for batched outputs. Then you already have your mask from `token_type_ids` and don't need to specify the `training` bool flag below. The token for the newline character has to be inserted before the first 1 and assigned token type id 0, so that it's part of the prompt with full block attention.
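One way to satisfy that newline constraint, sketched under the assumption that the separator is simply kept at the end of the prefix string (the checkpoint name is only for illustration, not what this PR ships):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

text = ["Is the witholded tax rate equal to 31%?\n"]  # keep "\n" inside the prefix text
labels = ["Yes"]
inputs = tokenizer(text=text, text_pair=labels, return_token_type_ids=True)

# The newline is tokenized as part of `text`, so it receives token_type_id 0 and stays
# inside the fully-attended prefix block; only the label tokens receive token_type_id 1.
print(inputs["token_type_ids"])
```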
```python
            causal_mask = self.construct_causal_mask_with_block_attention(
                attention_mask, labels, text_mask, inputs_embeds
            )
```
Here, the `token_type_ids` can be passed from the forward, since `inputs` are unpacked and will contain them. Then you already have your mask to build the causal mask!
```python
        mask = labels == self.config.prefix_suffix_separator_index
        # Get the index of the first \n in each row
        indices = mask.int().argmax(dim=1, keepdim=True)
```
Assuming you're getting the `token_type_ids` here, you just have to do something like this:

```diff
-        mask = labels == self.config.prefix_suffix_separator_index
-        # Get the index of the first \n in each row
-        indices = mask.int().argmax(dim=1, keepdim=True)
+        indices = (token_type_ids == 1).int().argmax(dim=1)
```
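A hypothetical continuation of that suggestion (variable names are mine, not from the PR), showing how those indices could be turned into a per-row prefix mask for the block-attention construction:

```python
import torch

token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1],
                               [0, 0, 0, 1, 1, 1, 1, 1]])
indices = (token_type_ids == 1).int().argmax(dim=1)        # first suffix position per row
positions = torch.arange(token_type_ids.shape[1])
prefix_mask = positions[None, :] < indices[:, None]        # True for image/prompt positions
```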
Merged as part of #30967
What does this PR do?
Previously, the PaliGemma attention mask was a 4d attention mask of all 1s, which preempted any downstream causal mask creation by Gemma. PaliGemma has full attention over the image tokens and the prefix of the prompt, and causal attention over the suffix. This PR implements that logic, allowing for finetuning PaliGemma without label leakage.
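To make the described attention pattern concrete, here is an illustrative sketch (function name, shapes, and the boolean convention are assumptions, not the exact code added by this PR) of how a block-attention-plus-causal mask can be assembled from a prefix indicator and a padding mask:

```python
import torch

def build_prefix_lm_mask(prefix_mask: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    prefix_mask:    (batch, seq_len) bool, True for image tokens and the prompt prefix.
    attention_mask: (batch, seq_len) bool, True for non-padding positions.
    Returns a (batch, 1, seq_len, seq_len) bool mask where True means "may attend".
    """
    seq_len = prefix_mask.shape[1]
    # Standard causal (lower-triangular) mask: each position sees itself and the past.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=prefix_mask.device))
    # Additionally allow every query to see any key inside the prefix block, which gives
    # the image + prompt tokens full bidirectional attention among themselves.
    mask = causal.unsqueeze(0) | prefix_mask[:, None, :]
    # Never attend to padding keys.
    mask = mask & attention_mask[:, None, :]
    return mask[:, None, :, :]
```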
(Additionally, I noticed that using GemmaTokenizerFast led to incorrect, character-by-character tokenization of some tokens, whereas GemmaTokenizer tokenized them properly. Be sure to set `use_fast=False` when instantiating the PaliGemma processor. I was going to open an issue for this observation; edit: ah, looks like that is a known issue?)
Who can review?
@ArthurZucker @molbap @merveenoyan