[WIP] Spectral-Grassmann OT#792
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #792 +/- ##
==========================================
- Coverage 96.87% 96.86% -0.02%
==========================================
Files 113 115 +2
Lines 23062 23339 +277
==========================================
+ Hits 22342 22608 +266
- Misses 720 731 +11 🚀 New features to boost your workflow:
|
rflamary
left a comment
There was a problem hiding this comment.
Hello @osheasienna and @thibaut-germain this is a nice first step.
Here are below a few comments that we can discuss together
| return C | ||
|
|
||
|
|
||
| def metric( |
There was a problem hiding this comment.
| def metric( | |
| def sgot_metric( |
| return prod ** (q / 2) | ||
|
|
||
|
|
||
| def ot_plan(C, Ws=None, Wt=None, nx=None): |
There was a problem hiding this comment.
this function is not needed, this is two lines and the ormalization wrt ws and wt are not oK because it rcan retrun very weird things
| ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### | ||
| ##################################################################################################################################### | ||
| ##################################################################################################################################### | ||
| def cost( |
There was a problem hiding this comment.
| def cost( | |
| def sgot_cost_matrix( |
| imag_scale=1.0, | ||
| nx=None, | ||
| ): | ||
| """Compute the SGOT cost matrix between two spectral decompositions. |
There was a problem hiding this comment.
recall here the equation with eta and define with math teh different acceptable metrics
| raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}") | ||
| lam2 = Dt | ||
|
|
||
| lam1 = nx.astype(lam1, "complex128") |
There was a problem hiding this comment.
is that necessary? seems overkill to add a function to the backend for that . When and why does it fails?
| logits_s = rng.randn(r) | ||
| logits_t = rng.randn(r) | ||
|
|
||
| Ws = np.exp(logits_s) |
There was a problem hiding this comment.
simpler and return only positive values
| Ws = np.exp(logits_s) | |
| Ws = rng.rand(r) |
| """Create test_cost for each trial: sweep over HPs and run cost().""" | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
| n_trials = 10 | ||
| for _ in range(n_trials): |
| def test_hyperparameter_sweep(): | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
|
|
||
| for _ in range(10): |
| This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. | ||
|
|
||
| #### New features | ||
| - Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) |
| ## Upcomming 0.9.7.post1 | ||
|
|
||
| #### New features | ||
| The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920). |
There was a problem hiding this comment.
move this text to the new feature of 0.9.7.dev0 this is what we are working on. Also add a line in the Itemize with the PR number
rflamary
left a comment
There was a problem hiding this comment.
A few comments from talking together
| if grassman_metric == "procrustes": | ||
| return 2.0 * (1.0 - delta) | ||
| if grassman_metric == "martin": | ||
| return -nx.log(nx.clip(delta**2, eps, 1e300)) |
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = C2 ** (p / 2.0) |
There was a problem hiding this comment.
| C = C2 ** (p / 2.0) | |
| C = nx.real(C2) ** (p / 2.0) |
| q=1, | ||
| r=2, | ||
| grassman_metric="chordal", | ||
| real_scale=1.0, |
There was a problem hiding this comment.
lets call this eigen_scaling and set it to None by default
| nx=None, | ||
| ): | ||
| """Compute the SGOT metric between two spectral decompositions. | ||
|
|
There was a problem hiding this comment.
add equation that illustrate p q and r
| import numpy as np | ||
| import pytest | ||
|
|
||
| from ot.backend import get_backend |
There was a problem hiding this comment.
| from ot.backend import get_backend | |
| from ot.backend import get_backend, torch, jax |
| rng = np.random.RandomState(0) | ||
|
|
||
|
|
||
| def rand_complex(shape): |
There was a problem hiding this comment.
| def rand_complex(shape): | |
| def rand_complex(shape,rng): |
| return real + 1j * imag | ||
|
|
||
|
|
||
| def random_atoms(d=8, r=4): |
There was a problem hiding this comment.
| def random_atoms(d=8, r=4): | |
| def random_atoms(d=8, r=4,seed=42): |
|
|
||
|
|
||
| @pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) | ||
| def test_cost_backend_consistency(backend_name): |
There was a problem hiding this comment.
| def test_cost_backend_consistency(backend_name): | |
| def test_cost_backend_consistency(nx): |
| # --------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_hyperparameter_sweep_cost(nx): |
There was a problem hiding this comment.
| def test_hyperparameter_sweep_cost(nx): | |
| def test_hyperparameter_sweep_cost(nx,grassmann_types,p,q,r,eta): |
| Ws = Ws / nx.sum(Ws) | ||
| Wt = Wt / nx.sum(Wt) | ||
|
|
||
| P = ot.emd2(Ws, Wt, nx.real(C)) |
There was a problem hiding this comment.
emd2 retruns directly obj no need to compute it again below
| else: | ||
| real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1] | ||
|
|
||
| Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale |
There was a problem hiding this comment.
| Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale | |
| C_real = nx.real(Dsn)[:,None] - nx.real(Dtn)[None,:] | |
| C_real = C_real**2 | |
| C_imag = nx.imag(Dsn)[:,None] - nx.imag(Dtn)[None,:] | |
| C_imag = C_imag**2 | |
| prod = C_real + C_imag | |
| return prod ** (q / 2) |
| A_norm: array-like, shape (d, n) | ||
| Column-normalized array. | ||
| """ | ||
| nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) |
There was a problem hiding this comment.
| nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) | |
| nrm = nx.norm(A, axis=0, keepdims=True) |
You can replace it with the function nx.norm which manages the case of complex number
| return delta | ||
|
|
||
|
|
||
| def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300): |
There was a problem hiding this comment.
Epsilon is too small for the machine precision, you can set it to 1e-12 for instance.
| if nx is None: | ||
| nx = get_backend(delta) | ||
|
|
||
| delta = nx.clip(delta, 0.0, 1.0) |
There was a problem hiding this comment.
If delta is not in [0,1] it should raise an error, this is an issue in the computation of delta outside of this function.
| ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### | ||
| ##################################################################################################################################### | ||
| ##################################################################################################################################### | ||
| def sgot_cost_matrix( |
There was a problem hiding this comment.
You should add eps in the definition of the function as this parameters appears in downstream functions. Keep the same epsilon for all functions.
| # information-geometric interpretation in Germain et al. (2025). | ||
| delta2 = nx.maximum(delta**2, eps) | ||
| return -nx.log(delta2) | ||
| raise ValueError(f"Unknown grassman_metric: {grassman_metric}") |
There was a problem hiding this comment.
In this function the power q should also be a parameter:
for any distance you can set:
result = square_ditance(delta)
then
return nx.real(result)**(q/2)
Set by default q to the same value as for eigenvalue cost
| C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx) | ||
|
|
||
| delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) | ||
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) |
There was a problem hiding this comment.
the power parameter q should also affect the Grassmann cost
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = nx.real(C2) ** (p / 2.0) |
There was a problem hiding this comment.
you your cost function already return a real no need for nx.real here
| return C | ||
|
|
||
|
|
||
| def _validate_sgot_metric_inputs(Ds, Dt): |
There was a problem hiding this comment.
You can add verifications you wrote in line 272-290 in this function and also add verifications than source and target have the same shapes.
| ) | ||
|
|
||
|
|
||
| def sgot_metric( |
There was a problem hiding this comment.
You will need to add eps also in this function.
Types of changes
Adding sgot file in the ot folder.
Motivation and context / Related issue
Keep track of SGOT implementation in POT.
How has this been tested (if it applies)
Not tested yet.
PR checklist