[tmva][sofie] Fix the allocation of intermediate tensors in the memory pool #19730
Conversation
Thank you for fixing this, some comments:
Test Results: 19 files, 19 suites, 3d 11h 32m 14s ⏱️. Results for commit 5e1d1e2.
Force-pushed from 59f9ea2 to 096e3ad.
When broadcasting from a scalar tensor, the tensor size is 0 and shape.front() is undefined. Add a check on the size before calling shape.front(). (A sketch of this guard follows below.)
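A minimal sketch of that guard (hypothetical helper name, not the actual SOFIE code): a scalar tensor is stored with an empty shape and size 0, so shape.front() must not be dereferenced in that case.

```cpp
#include <cstddef>
#include <vector>

inline std::size_t LeadingDim(const std::vector<std::size_t> &shape, std::size_t size) {
   // Scalar broadcast case: empty shape, size 0 -> calling shape.front()
   // here would be undefined behaviour.
   if (size == 0 || shape.empty())
      return 1;              // treat the scalar as a length-1 leading dimension
   return shape.front();     // safe: the shape has at least one entry
}
```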
Add counters in RModel to monitor the allocations of Constant, Weight, intermediate and other tensor types at code generation (see the sketch below).
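A sketch of the kind of counters meant here; the struct, member and method names are illustrative assumptions, not the actual RModel interface.

```cpp
#include <cstddef>
#include <iostream>

struct TensorAllocationCounters {
   std::size_t constantBytes = 0;      // Constant tensors
   std::size_t weightBytes = 0;        // Weight (initialized) tensors
   std::size_t intermediateBytes = 0;  // Intermediate tensors in the memory pool
   std::size_t otherBytes = 0;         // Any other tensor allocations

   void Print() const {
      std::cout << "Constant:      " << constantBytes << " bytes\n"
                << "Weights:       " << weightBytes << " bytes\n"
                << "Intermediates: " << intermediateBytes << " bytes\n"
                << "Other:         " << otherBytes << " bytes\n";
   }
};
```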
By adding the conv temporary tensors to the input lists, they will be flushed afterwards and their memory can be reused by the next operator (illustrated in the sketch below).
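A sketch of the idea, under assumptions: listing the convolution's temporary buffer among the operator's input tensors lets the memory planner see its last use at this operator, so its chunk returns to the pool for the next operator. The function and tensor names below are illustrative, not the actual SOFIE API.

```cpp
#include <string>
#include <vector>

std::vector<std::string> GetConvInputTensorNames(const std::string &x,
                                                 const std::string &w,
                                                 const std::string &convTmp) {
   // Including the temporary tensor here means it is "flushed" (freed in the
   // pool) once this operator has produced its output.
   return {x, w, convTmp};
}
```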
LGTM! Thanks for fixing this!
Fix an issue with merging the free chunks in the available_stack list of memory. This makes it easier to reuse the memory more efficiently. In addition, order the output tensors by decreasing size. Also add debugging output of the chunks currently allocated and available during the process. (A sketch of the chunk-merging idea follows below.)
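A minimal sketch of the chunk-merging idea (not the actual SOFIE implementation): keep the free chunks of the intermediate-tensor pool sorted by offset and merge neighbours whose ranges touch, so a later, larger tensor can reuse one contiguous chunk instead of failing to fit in two fragments.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct Chunk {
   std::size_t offset = 0;  // start of the free region in the pool
   std::size_t size = 0;    // length of the free region in bytes
};

void MergeFreeChunks(std::vector<Chunk> &freeChunks) {
   std::sort(freeChunks.begin(), freeChunks.end(),
             [](const Chunk &a, const Chunk &b) { return a.offset < b.offset; });
   std::vector<Chunk> merged;
   for (const Chunk &c : freeChunks) {
      if (!merged.empty() && merged.back().offset + merged.back().size >= c.offset) {
         // Adjacent (or overlapping) free regions: grow the previous chunk.
         std::size_t end = std::max(merged.back().offset + merged.back().size,
                                    c.offset + c.size);
         merged.back().size = end - merged.back().offset;
      } else {
         merged.push_back(c);
      }
   }
   freeChunks.swap(merged);
}
```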
Force-pushed from 096e3ad to 5e1d1e2.
This PR fixes an issue in merging the free chunks of memory used by the memory pool for the allocation of the intermediate tensors. It provides a significant (x2) improvement in the total memory used for the intermediate tensors.
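For illustration only, a hypothetical helper showing how the pieces above could combine when planning the pool (this is not the generated SOFIE code): each intermediate tensor is assigned an offset, trying the merged free chunks first and growing the pool only when nothing fits; placing the larger tensors first lets them claim the big merged chunks. It reuses the `Chunk` struct and `MergeFreeChunks` sketch shown earlier.

```cpp
#include <cstddef>
#include <vector>

std::size_t Allocate(std::vector<Chunk> &freeChunks, std::size_t &poolSize,
                     std::size_t tensorSize) {
   MergeFreeChunks(freeChunks);
   for (std::size_t i = 0; i < freeChunks.size(); ++i) {
      if (freeChunks[i].size >= tensorSize) {
         std::size_t offset = freeChunks[i].offset;
         freeChunks[i].offset += tensorSize;   // shrink the chunk from the front
         freeChunks[i].size -= tensorSize;
         if (freeChunks[i].size == 0)
            freeChunks.erase(freeChunks.begin() + static_cast<std::ptrdiff_t>(i));
         return offset;
      }
   }
   std::size_t offset = poolSize;               // no free chunk fits: grow the pool
   poolSize += tensorSize;
   return offset;
}
```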