Fix issue with concatenate, masking and symbolic inputs #21611
base: master
Conversation
We were trying to use a symbolic input shape as a fixed broadcast shape. Instead, we need to capture the input as an input node whose shape is used for broadcasting at execution time on real input tensors.

Fixes #21581
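For context, here is a minimal sketch of the kind of graph that hits this code path. This is an illustrative assumption, not the exact model from the linked issue; shapes and layer choices are placeholders.

```python
# A minimal sketch, assuming the public Keras 3 API; the shapes and layers
# are illustrative, not taken verbatim from the linked issue.
import numpy as np
import keras
from keras import layers

# Symbolic inputs with an unknown time dimension. Masking produces a 2D
# (batch, time) mask for a 3D input, so Concatenate.compute_mask must
# broadcast the mask up to the input's rank.
x1 = keras.Input(shape=(None, 4))
x2 = keras.Input(shape=(None, 4))
masked = layers.Masking(mask_value=0.0)(x2)
out = layers.Concatenate(axis=1)([x1, masked])

# Before this fix, the broadcast used the symbolic input shape (containing
# None) as a fixed target; after it, the input is captured as a graph node
# and the broadcast happens on the real tensors at execution time.
model = keras.Model([x1, x2], out)
a = np.random.rand(2, 3, 4).astype("float32")
b = np.random.rand(2, 3, 4).astype("float32")
print(model([a, b]).shape)  # (2, 6, 4)
```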
Summary of Changes
Hello @mattdangerw, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses an issue within Keras's `Concatenate` layer concerning how masking interacts with symbolic input shapes. Previously, the system incorrectly used symbolic input shapes as fixed broadcast shapes. The fix ensures that the input is captured as an input node, allowing its shape to be used for broadcasting at execution time with real input tensors, thereby resolving an issue where masking was not correctly applied with symbolic inputs.
Highlights
- Masking with Symbolic Inputs: The `compute_mask` method in `Concatenate` layers has been updated to correctly handle broadcasting of masks when the mask's dimensionality is less than the input's. The previous `broadcast_to` approach was replaced with a sequence of operations (`expand_dims`, `cast`, `zeros_like`, `cast`) to ensure the input is captured as a symbolic input in the operation graph, enabling proper broadcasting at execution time.
- New Test Case for Symbolic Masking: A new test, `test_concatenate_with_mask_symbolic`, has been added to `merging_test.py`. It specifically validates the fix by checking the behavior of `Concatenate` with `Masking` when symbolic inputs are used, ensuring the mask is correctly propagated (a hedged sketch of such a test follows this list).
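As a rough idea of what that regression test exercises, here is a hedged sketch; the actual `test_concatenate_with_mask_symbolic` in `merging_test.py` may use different shapes, layers, and assertions.

```python
# A hedged sketch of the kind of check the new test performs; not the
# actual test body from merging_test.py.
import numpy as np
import keras
from keras import layers

def test_concatenate_with_mask_symbolic():
    # Symbolic (None-shaped) inputs, each producing a 2D mask for a 3D input.
    x1 = keras.Input(shape=(None, 8))
    x2 = keras.Input(shape=(None, 8))
    m1 = layers.Masking(mask_value=0.0)(x1)
    m2 = layers.Masking(mask_value=0.0)(x2)
    out = layers.Concatenate(axis=1)([m1, m2])
    model = keras.Model([x1, x2], out)

    # Building the graph above is the core regression check: before the fix,
    # compute_mask tried to broadcast against a symbolic shape.
    a = np.ones((2, 3, 8), dtype="float32")
    b = np.zeros((2, 3, 8), dtype="float32")  # fully masked timesteps
    y = model([a, b])
    assert tuple(y.shape) == (2, 6, 8)
```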
Code Review
This pull request effectively addresses an issue with mask computation in the `Concatenate` layer when dealing with symbolic inputs. The approach of using `zeros_like` to facilitate broadcasting in a graph-compatible way is a solid fix. The addition of a targeted test case for symbolic inputs is also a great way to prevent regressions. I have one suggestion to refactor the implementation for better readability and to use more idiomatic boolean operations.
```python
mask_i = ops.expand_dims(mask_i, axis=-1)
# Broadcast mask shape to match in a way where we capture the
# input as a symbolic input in the op graph.
mask_dtype = mask_i.dtype
mask_i = ops.cast(mask_i, "int32")
mask_i = mask_i + ops.zeros_like(input_i, "int32")
mask_i = ops.cast(mask_i, dtype=mask_dtype)
masks.append(mask_i)
```
The current implementation for broadcasting the mask works, but it's a bit verbose with the casting to `int32` and back. A more direct and idiomatic way to achieve this for boolean tensors is to use `ops.logical_or` with a tensor of zeros. This avoids the type casting and makes the intent clearer.
Suggested change:

```python
mask_i = ops.expand_dims(mask_i, axis=-1)
# Broadcast mask to the same shape as the input by using
# `ops.logical_or` with a zero tensor of the target shape. This
# correctly handles symbolic tensors.
mask_i = ops.logical_or(mask_i, ops.zeros_like(input_i, dtype="bool"))
masks.append(mask_i)
```
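For readers comparing the two idioms, here is a standalone check, with assumed shapes, that the cast-and-add approach and the `ops.logical_or` approach produce the same broadcast mask on concrete tensors:

```python
# Standalone equivalence check (assumed shapes): int32 add-with-zeros vs.
# logical_or with a boolean zeros tensor of the input's shape.
import numpy as np
from keras import ops

input_i = np.random.rand(2, 3, 4).astype("float32")
mask_i = np.array([[True, False, True], [False, True, True]])

m = ops.expand_dims(mask_i, axis=-1)  # (2, 3, 1)
# Idiom from the PR: cast to int32, add zeros of the input's shape, cast back.
via_cast = ops.cast(
    ops.cast(m, "int32") + ops.zeros_like(input_i, "int32"), "bool"
)
# Suggested idiom: logical_or against boolean zeros of the input's shape.
via_or = ops.logical_or(m, ops.zeros_like(input_i, dtype="bool"))

assert ops.convert_to_numpy(via_cast).shape == (2, 3, 4)
np.testing.assert_array_equal(
    ops.convert_to_numpy(via_cast), ops.convert_to_numpy(via_or)
)
```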
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #21611      +/-   ##
==========================================
- Coverage   82.71%   82.45%   -0.26%
==========================================
  Files         568      572       +4
  Lines       56897    57342     +445
  Branches     8890     8970      +80
==========================================
+ Hits        47063    47283     +220
- Misses       7640     7760     +120
- Partials     2194     2299     +105
```

Flags with carried forward coverage won't be shown. View the full report in Codecov by Sentry.