From 1e20372df5b963f00299ec78433a57d55be2f905 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 29 Jun 2026 10:33:42 +0200 Subject: [PATCH] fix the tests --- tests/trainer/test_trainer.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a04b6cb2..dabcab61 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -271,24 +271,19 @@ def test_validation_metrics_change_when_predict_cache_is_reused( rng=vf_rng, ) - trainer = CellFlowTrainer(solver=solver, predict_kwargs={"max_steps": 3, "throw": False}) - metrics_callback = Metrics(metrics=["e_distance"]) + trainer = CellFlowTrainer(solver=solver, predict_kwargs={"max_steps": 16, "throw": False}) - valid_source_data, valid_true_data, valid_pred_data = trainer._validation_step(valid_loader) - metric_before = metrics_callback.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)[ - "val_e_distance_mean" - ] + _, _, valid_pred_before = trainer._validation_step(valid_loader) + preds_before = jax.tree.leaves(valid_pred_before) batch = dataloader.sample(None) - metric_diffs = [] for i in range(3): solver.step_fn(jax.random.PRNGKey(i), batch) - valid_source_data, valid_true_data, valid_pred_data = trainer._validation_step(valid_loader) - metric_after = metrics_callback.on_log_iteration( - valid_source_data, valid_true_data, valid_pred_data, solver - )["val_e_distance_mean"] - metric_diffs.append(abs(metric_after - metric_before)) - metric_before = metric_after + _, _, valid_pred_after = trainer._validation_step(valid_loader) + preds_after = jax.tree.leaves(valid_pred_after) + # The jitted predict fn must be reused (single cache entry), not rebuilt. assert len(solver._predict_fn_cache) == 1 - assert any(diff > 0 for diff in metric_diffs) + assert any( + not np.allclose(np.asarray(a), np.asarray(b)) for a, b in zip(preds_before, preds_after, strict=True) + )