Skip to content

Conversation

han-ol
Copy link
Collaborator

@han-ol han-ol commented Jul 22, 2025

This PR seeks to address #541.

It looks to me like we need to tweak JAXApproximator.stateless_compute_metrics for this to work in jax as well.

The other backends are already covered with just the changes in the initial commit.

EDIT: You can find an example in https://github.com/bayesflow-org/bayesflow/blob/add-loss/examples/Custom_losses_with_add_loss.ipynb

Copy link

codecov bot commented Jul 22, 2025

Codecov Report

Attention: Patch coverage is 75.00000% with 3 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...low/approximators/model_comparison_approximator.py 50.00% 3 Missing ⚠️
Files with missing lines Coverage Δ
bayesflow/approximators/continuous_approximator.py 91.45% <100.00%> (+0.22%) ⬆️
...low/approximators/model_comparison_approximator.py 83.90% <50.00%> (-1.30%) ⬇️

@han-ol han-ol requested review from LarsKue and vpratz and removed request for LarsKue July 22, 2025 12:17
@han-ol
Copy link
Collaborator Author

han-ol commented Jul 22, 2025

I added tests and a minimal example notebook.

Tests are passing on torch and tensorflow, but fail on jax.

@LarsKue since you are the architect of the stateless_compute_metrics, could you look into how we can make this work for jax?

The final section of the keras guide on custom training loops in jax proves that this can be rather straight forward, but I am unsure how to implement it in our case: https://keras.io/guides/writing_a_custom_training_loop_in_jax/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant