
PaliGemma fix attention mask for finetunes #30918

Conversation

@probicheaux (Contributor) commented on May 20, 2024

What does this PR do?

Previously, the PaliGemma attention mask was a 4D attention mask of all ones, which preempted any downstream causal mask creation by Gemma. PaliGemma uses full attention over the image tokens and the prefix of the prompt, and causal attention over the suffix. This PR implements that logic, allowing PaliGemma to be finetuned without label leakage.
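
For illustration, here is a minimal sketch of the masking scheme described above (the names build_prefix_lm_mask and prefix_length are mine, not the PR's): prefix positions (image tokens plus prompt) attend to each other bidirectionally, while suffix positions attend to the full prefix and causally to earlier suffix tokens.

    import torch

    def build_prefix_lm_mask(prefix_length: int, sequence_length: int) -> torch.Tensor:
        # True means "query may attend to key".
        # Start from a standard causal (lower-triangular) mask.
        mask = torch.tril(torch.ones(sequence_length, sequence_length, dtype=torch.bool))
        # Every position may attend to the whole prefix (block attention over image + prompt).
        mask[:, :prefix_length] = True
        # Prefix positions still cannot attend to suffix positions: those keys lie
        # above the diagonal and outside the prefix columns.
        return mask

    # Example: 4 prefix tokens (image + prompt) followed by 3 suffix tokens.
    print(build_prefix_lm_mask(prefix_length=4, sequence_length=7).int())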

(Additionally, I noticed that using GemmaTokenizerFast led to incorrect tokenization of some tokens, splitting them character by character, whereas GemmaTokenizer tokenized them correctly. Be sure to set use_fast=False when instantiating the PaliGemma processor. I'll open an issue for this observation.
Edit: ah, it looks like that is a known issue?)
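
A minimal way to check the slow/fast discrepancy might look like the following (the checkpoint id is a placeholder and the exact affected tokens are not shown; this is an assumed reproduction recipe, not a verified log):

    from transformers import GemmaTokenizer, GemmaTokenizerFast

    # Placeholder checkpoint id; substitute the PaliGemma checkpoint you are testing.
    checkpoint = "google/paligemma-3b-pt-224"

    slow = GemmaTokenizer.from_pretrained(checkpoint)
    fast = GemmaTokenizerFast.from_pretrained(checkpoint)

    text = "detect cat\n"
    # Compare the two tokenizations; the slow tokenizer is the reference here.
    print(slow.tokenize(text))
    print(fast.tokenize(text))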

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @molbap @merveenoyan

@pcuenca requested a review from molbap on May 20, 2024 21:37
@ArthurZucker (Collaborator) commented on May 21, 2024

Hey! Two points:

  1. The tokenizer issue: can you provide a reproducer? It's probably just about the added tokens and the conversion.
  2. The mask: you are right, when you finetune you want the non-prompt text to be causal, whereas right now everything is ones.

@molbap (Contributor) left a comment

Good catch! My main comment is that the separation between prefix and suffix is brittle as-is; we want to merge this fix as soon as possible, though!

src/transformers/models/paligemma/modeling_paligemma.py (outdated)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)

    mask = input_ids == self.config.prefix_suffix_separator_index

This is not very robust unfortunately: one or several \n tokens might be present in the prefix. Could this depend on the labels passed instead?

src/transformers/models/paligemma/modeling_paligemma.py (outdated, resolved)
@probicheaux force-pushed the paligemma-causal-attention-mask branch from 03cfe8e to 9a732f4 on May 22, 2024 02:17
@probicheaux (Contributor, author) commented:

I accidentally ran style on the whole repo, and now the tests are failing because I had to force-push away updates to basically every file (இ﹏இ`。)

I think all the style errors are in code not touched by this change.

@probicheaux (Contributor, author) commented:

Anyway, I updated how we check where to split, using a new index in the labels. Constructing the right labels array was getting a bit tricky for new users, so the PaliGemmaProcessor now returns labels in the BatchEncoding for easier finetuning. This is more or less necessary because the prefix and suffix can change length during tokenization, and it would be annoying to hunt for the right index of that \n. Hope y'all are OK with this idea!
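
To make the intended workflow concrete, usage at this point in the PR would look roughly like this (the checkpoint id and image path are placeholders, and the suffix argument follows the diff discussed below; the final API may differ):

    from PIL import Image
    from transformers import PaliGemmaProcessor

    # Placeholder checkpoint id and image path.
    processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
    image = Image.open("example.jpg")

    prompt = "answer en What is on the table?"  # prefix (full block attention)
    target = "a cup of coffee"                  # suffix (causal attention, used as labels)

    inputs = processor(text=prompt, images=image, suffix=target, return_tensors="pt")
    # The returned BatchEncoding also carries `labels`, so it can be fed
    # directly to the model for a finetuning step.
    print(inputs.keys())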

@probicheaux requested a review from molbap on May 22, 2024 02:24
@ArthurZucker (Collaborator) left a comment

A lot of the logic you are re-implementing is natively supported in the tokenizer via token_type_ids. In this case I think it would be wise to leverage create_token_type_ids_from_sequences; check out CodeGen for an example.
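
For context, create_token_type_ids_from_sequences is the standard tokenizer hook for marking which segment each token belongs to. A minimal sketch of such an override, marking the prefix as segment 0 and the suffix as segment 1 (illustrative only; special tokens are ignored for brevity, and this is not the code that ended up in the PR):

    from typing import List, Optional

    class PrefixSuffixTokenizerMixin:
        def create_token_type_ids_from_sequences(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
        ) -> List[int]:
            # Segment 0 = prefix (full block attention), segment 1 = suffix (causal).
            if token_ids_1 is None:
                return [0] * len(token_ids_0)
            return [0] * len(token_ids_0) + [1] * len(token_ids_1)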

@molbap (Contributor) left a comment

Hey @probicheaux, left another review! I think your causal mask building is fine; we just need to remove the training flag from the processor and pass the labels explicitly. Let me know if you have time to add the suggestions!


        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        suffix: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,

Suggested change:
-        suffix: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        labels: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,

Comment on lines +154 to +156:
    suffix (`str`, `List[str]`, `List[List[str]]`):
        The suffix or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
        for more information.

Suggested change:
-    suffix (`str`, `List[str]`, `List[List[str]]`):
-        The suffix or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
-        for more information.
+    labels (`str`, `List[str]`, `List[List[str]]`):
+        The label suffix or batch of label suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
+        for more information.


Overall, replace suffix with labels everywhere it's used, I think.

    else:
-       text_inputs = self.tokenizer(
+       inputs = self.tokenizer(

Here I think the key is using the text_pair option in the tokenizer along with return_token_type_ids=True. So:

    text = ["Is the withheld tax rate equal to 31%?"]
    labels = ["Yes"]
    self.tokenizer(text=text, text_pair=labels, return_token_type_ids=True)

This will give inputs the additional key 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 1, 1]] and will also have the correct shape for batched outputs.

Then you already have your mask from token_type_ids and don't need to specify the training bool flag below. The token for the newline character has to be inserted before the first 1 and assigned token type id 0, so that it is part of the prompt with full block attention.
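
To make the connection to the mask explicit, here is one way token_type_ids could be turned into the block-plus-causal mask (a sketch under the convention above, with 0 marking the prefix and 1 the suffix; padding handling is omitted, and this is not the PR's implementation):

    import torch

    def mask_from_token_type_ids(token_type_ids: torch.Tensor) -> torch.Tensor:
        # token_type_ids: (batch, seq) with 0 for prefix tokens and 1 for suffix tokens.
        # Returns a boolean (batch, seq, seq) mask where True means "query may attend to key".
        _, seq = token_type_ids.shape
        causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=token_type_ids.device))
        is_prefix_key = (token_type_ids == 0).unsqueeze(1)  # (batch, 1, seq)
        # Attend if the key is in the prefix (block attention) or the causal rule allows it.
        return is_prefix_key | causal.unsqueeze(0)

    print(mask_from_token_type_ids(torch.tensor([[0, 0, 0, 0, 1, 1]])).int())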

Comment on lines +355 to +357:
    causal_mask = self.construct_causal_mask_with_block_attention(
        attention_mask, labels, text_mask, inputs_embeds
    )

Here, the token_type_ids can be passed from the forward, since the inputs are unpacked and will contain them. Then you already have your mask to build the causal mask!

Comment on lines +301 to +303:
    mask = labels == self.config.prefix_suffix_separator_index
    # Get the index of the first \n in each row
    indices = mask.int().argmax(dim=1, keepdim=True)

Assuming you're getting the token_type_ids here, you just have to do something like this:

Suggested change:
-    mask = labels == self.config.prefix_suffix_separator_index
-    # Get the index of the first \n in each row
-    indices = mask.int().argmax(dim=1, keepdim=True)
+    indices = (token_type_ids == 1).int().argmax(dim=1)

@probicheaux (Contributor, author) commented:

Merged as part of #30967
