
Fix issue with concatenate, masking and symbolic inputs #21611


Open · wants to merge 1 commit into base: master

Conversation

mattdangerw (Member)

We were trying to use a symbolic input shape as a fixed broadcast shape. Instead, we need to capture the input as an input node whose shape is used for broadcasting at execution time on real input tensors.

Fixes #21581
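For context, a minimal sketch of the failure mode, assuming a setup like the linked issue (masked Embedding outputs concatenated, then fed to BatchNormalization); the layer sizes and shapes here are illustrative, not taken from the issue:

import numpy as np
import keras
from keras import layers

# Two variable-length integer inputs whose embeddings carry 2D masks.
x1 = keras.Input(shape=(None,), dtype="int32")
x2 = keras.Input(shape=(None,), dtype="int32")
e1 = layers.Embedding(input_dim=10, output_dim=4, mask_zero=True)(x1)
e2 = layers.Embedding(input_dim=10, output_dim=4, mask_zero=True)(x2)
# Concatenate must broadcast each 2D mask to its 3D input. With symbolic
# inputs, the old code treated the symbolic shape as a fixed broadcast
# shape, which broke downstream layers such as BatchNormalization.
merged = layers.Concatenate(axis=1)([e1, e2])
out = layers.BatchNormalization()(merged)
model = keras.Model([x1, x2], out)
model([np.array([[1, 2, 0]]), np.array([[3, 0, 0]])])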

gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @mattdangerw, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue within Keras's Concatenate layer concerning how masking interacts with symbolic input shapes. Previously, the system incorrectly used symbolic input shapes as fixed broadcast shapes. The fix ensures that the input is captured as an input node, allowing its shape to be used for broadcasting at execution time with real input tensors, thereby resolving an issue where masking was not correctly applied with symbolic inputs.

Highlights

  • Masking with Symbolic Inputs: The compute_mask method in Concatenate layers has been updated to correctly handle broadcasting of masks when the mask's dimensionality is less than the input's. The previous broadcast_to approach was replaced with a sequence of operations (expand_dims, cast, zeros_like, cast) to ensure the input is captured as a symbolic input in the operation graph, enabling proper broadcasting at execution time.
  • New Test Case for Symbolic Masking: A new test, test_concatenate_with_mask_symbolic, has been added to merging_test.py. This test specifically validates the fix by checking the behavior of Concatenate with Masking when symbolic inputs are used, ensuring the mask is correctly propagated. A sketch of such a test appears below.
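A rough sketch of what such a test plausibly checks (the real test lives in merging_test.py; the names, shapes, and assertion below are assumptions):

import numpy as np
import keras
from keras import layers

def test_concatenate_with_mask_symbolic():
    # Build with symbolic (functional) inputs so the masks are KerasTensors.
    x1 = keras.Input(shape=(3, 2))
    x2 = keras.Input(shape=(3, 2))
    m1 = layers.Masking(mask_value=0.0)(x1)
    m2 = layers.Masking(mask_value=0.0)(x2)
    out = layers.Concatenate(axis=1)([m1, m2])
    model = keras.Model([x1, x2], out)
    # Running on real tensors should now succeed; previously the symbolic
    # mask broadcast failed during graph construction.
    a = np.ones((2, 3, 2))
    b = np.zeros((2, 3, 2))  # fully masked timesteps
    y = model([a, b])
    assert tuple(y.shape) == (2, 6, 2)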

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively addresses an issue with mask computation in the Concatenate layer when dealing with symbolic inputs. The approach of using zeros_like to facilitate broadcasting in a graph-compatible way is a solid fix. The addition of a targeted test case for symbolic inputs is also a great way to prevent regressions. I have one suggestion to refactor the implementation for better readability and to use more idiomatic boolean operations.

Comment on lines +149 to +156
mask_i = ops.expand_dims(mask_i, axis=-1)
# Broadcast mask shape to match in a way where we capture the
# input as a symbolic input in the op graph.
mask_dtype = mask_i.dtype
mask_i = ops.cast(mask_i, "int32")
mask_i = mask_i + ops.zeros_like(input_i, "int32")
mask_i = ops.cast(mask_i, dtype=mask_dtype)
masks.append(mask_i)
Severity: medium

The current implementation for broadcasting the mask works, but it's a bit verbose with the casting to int32 and back. A more direct and idiomatic way to achieve this for boolean tensors is to use ops.logical_or with a tensor of zeros. This avoids the type casting and makes the intent clearer.

Suggested change (replacing the snippet above):
mask_i = ops.expand_dims(mask_i, axis=-1)
# Broadcast mask to the same shape as the input by using
# `ops.logical_or` with a zero tensor of the target shape. This
# correctly handles symbolic tensors.
mask_i = ops.logical_or(mask_i, ops.zeros_like(input_i, dtype="bool"))
masks.append(mask_i)
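Both variants rely on the same trick: combining the mask with a tensor derived from input_i (via ops.zeros_like) pulls the input into the op graph as a node, so broadcasting happens against the real input shape at execution time rather than against a fixed symbolic shape. The ops.logical_or form simply stays in the boolean dtype and skips the int32 round trip.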

codecov-commenter commented Aug 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.45%. Comparing base (ac5c97f) to head (9e84b7c).
⚠️ Report is 7 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21611      +/-   ##
==========================================
- Coverage   82.71%   82.45%   -0.26%     
==========================================
  Files         568      572       +4     
  Lines       56897    57342     +445     
  Branches     8890     8970      +80     
==========================================
+ Hits        47063    47283     +220     
- Misses       7640     7760     +120     
- Partials     2194     2299     +105     
Flag               Coverage            Δ
keras              82.26% <100.00%>    (-0.26%) ⬇️
keras-jax          63.58% <100.00%>    (-0.08%) ⬇️
keras-numpy        57.86% <0.00%>      (-0.40%) ⬇️
keras-openvino     34.34% <0.00%>      (-0.22%) ⬇️
keras-tensorflow   64.22% <100.00%>    (+<0.01%) ⬆️
keras-torch        63.79% <100.00%>    (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.



Successfully merging this pull request may close these issues.

BatchNormalization fails after Concatenation of masked Embeddings