Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.1.1] - 2026-06-10

- Changed to sparse outputs
- Removed dead code

### Added

- Initial release with basic tool, preprocessing and plotting functions

## [0.0.1] - 2026-05-29

### Added
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = [ "hatchling" ]

[project]
name = "cellpin"
version = "0.0.1"
version = "0.1.1"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the changelog

description = "Cellpin: A Student–Teacher Variational Model for Spatial Transcriptomics Imputation and Encoding"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down
108 changes: 37 additions & 71 deletions src/cellpin/models/cellpin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
-------------------
* KL annealing (linear warm-up over ``kl_warmup_epochs`` epochs).
* Per-stage configurable loss weights.
* ``get_cell_embedding``, ``embed_and_impute``, ``impute_to_anndata`` API.
* ``get_cell_embedding``, ``embed_and_impute`` API.
* ``fit()`` convenience wrapper: pretrain → train in one call.
"""

Expand Down Expand Up @@ -963,11 +963,16 @@ def _build_output_anndata(
counts: np.ndarray,
embeddings: np.ndarray,
obs_adata: ad.AnnData | None,
return_sparse: bool = True,
) -> ad.AnnData:
"""Build the base output AnnData shared by impute() and impute_to_anndata().
"""Build the base output AnnData for impute().

Sets var_names, X_cellpin embedding, and copies obs/obsm/layers from obs_adata.
Genes absent from obs_adata layers are filled with 0; var['is_measured'] marks
which genes were present in obs_adata.
"""
import scipy.sparse as sp

adata_out = ad.AnnData(X=counts)
adata_out.var_names = self.gene_names
adata_out.obsm["X_cellpin"] = embeddings
Expand All @@ -984,72 +989,26 @@ def _build_output_anndata(
out_gene_idx = {g: i for i, g in enumerate(adata_out.var_names)}
src_cols = [out_gene_idx[g] for g in obs_adata.var_names if g in out_gene_idx]
src_mask = [g in out_gene_idx for g in obs_adata.var_names]
n_missing = adata_out.n_vars - len(src_cols)
if n_missing > 0:
print(f" [impute] Filling {n_missing} gene(s) absent from obs_adata layers with 0")
for lyr_key, lyr_val in obs_adata.layers.items():
if hasattr(lyr_val, "toarray"):
lyr_val = lyr_val.toarray()
print(
f" [impute] Filling {adata_out.n_vars - len(src_cols)} "
"gene(s) absent from obs_adata layers with sentinel -2.0"
)
mat = np.full((adata_out.n_obs, adata_out.n_vars), -2.0, dtype=np.float32)
mat = np.zeros((adata_out.n_obs, adata_out.n_vars), dtype=np.float32)
mat[:, src_cols] = np.asarray(lyr_val, dtype=np.float32)[:, src_mask]
adata_out.layers[lyr_key] = mat

return adata_out

@torch.no_grad()
def impute_to_anndata(
self,
dataloader: torch.utils.data.DataLoader,
obs_adata: ad.AnnData | None = None,
use_mean: bool = True,
mc_impute: bool = False,
mc_samples: int = 50,
mask_fraction: float = 0.2,
table_key: str = "table",
) -> ad.AnnData:
"""Impute full-gene expression and return as AnnData.

Args:
dataloader: DataLoader to run inference on.
obs_adata: Optional AnnData (or :class:`spatialdata.SpatialData`) whose
``.obs`` is copied to the output. If SpatialData, the AnnData is read
from ``obs_adata.tables[table_key]`` and the result is returned as an
updated SpatialData object. Must have the same number of observations.
use_mean: Use posterior mean for the latent (deterministic).
Ignored when ``mc_impute=True``.
mc_impute: Use MC averaging over ``mc_samples`` stochastic forward
passes (recommended; ~+0.01 mean Pearson over deterministic).
mc_samples: Number of MC samples (default 50).
mask_fraction: Fraction of panel genes randomly zeroed per MC pass
(default 0.2).
table_key: Table name to read/write when ``obs_adata`` is a SpatialData
object (default ``"table"``).

Returns:
-------
:class:`anndata.AnnData` with ``X`` = imputed counts,
``obsm['X_cellpin']`` = embeddings, ``layers['imputed']``.
If ``obs_adata`` was a SpatialData object, returns the updated SpatialData
with the result stored in ``sdata.tables[table_key]``.
adata_out.layers[lyr_key] = sp.csr_matrix(mat) if return_sparse else mat

# Mark which genes were measured in obs_adata
measured = np.zeros(adata_out.n_vars, dtype=bool)
out_gene_set = set(adata_out.var_names)
for i, g in enumerate(adata_out.var_names):
measured[i] = g in set(obs_adata.var_names) and g in out_gene_set
adata_out.var["is_measured"] = measured
else:
# No obs_adata: all output genes are considered measured (imputed from sc ref)
adata_out.var["is_measured"] = np.ones(adata_out.n_vars, dtype=bool)

Raises:
------
ValueError: If ``obs_adata`` has the wrong number of cells.
"""
obs_adata, sdata = _resolve_sdata(obs_adata, table_key)
embeddings, imputed_arr, _ = self.embed_and_impute(
dataloader,
use_mean=use_mean,
mc_impute=mc_impute,
mc_samples=mc_samples,
mask_fraction=mask_fraction,
)
adata_out = self._build_output_anndata(imputed_arr, embeddings, obs_adata)
adata_out.layers["imputed"] = adata_out.X.copy()
if sdata is not None:
sdata.tables[table_key] = adata_out
return sdata
return adata_out

def fit(
Expand Down Expand Up @@ -1135,15 +1094,11 @@ def impute(
area_key: str | None = None,
nb_count_samples: int = 100,
return_int: bool = False,
return_sparse: bool = True,
table_key: str = "table",
) -> ad.AnnData:
"""Impute with MC averaging and optional count-space normalisation.

More complete than :meth:`impute_to_anndata`: adds MC dropout, integer
rounding, and area-normalised log-normalised layers. Use
:meth:`impute_to_anndata` when you only need embeddings + raw imputed
counts.

Args:
dataloader: DataLoader to run inference on.
obs_adata: Optional AnnData (or :class:`spatialdata.SpatialData`) whose
Expand All @@ -1168,6 +1123,9 @@ def impute(
``log1p(norm(E[X])) > E[log1p(norm(X))]``; sampling inside the
transform corrects this bias. More samples → lower variance.
return_int: If ``True``, round ``X`` to integer counts (``int32``).
return_sparse: If ``True`` (default), store ``X``, ``layers['imputed']``,
and ``layers['imputed_norm']`` as :class:`scipy.sparse.csr_matrix`.
Set to ``False`` to keep dense numpy arrays.
table_key: Table name to read/write when ``obs_adata`` is a SpatialData
object (default ``"table"``).

Expand All @@ -1176,6 +1134,8 @@ def impute(
:class:`anndata.AnnData` with ``X`` = imputed (float or int) counts,
``obsm['X_cellpin']`` = embeddings, ``layers['imputed']`` = copy of
``X``, and optionally ``layers['imputed_norm']``.
``var['is_measured']`` marks genes present in ``obs_adata`` (all ``True``
when ``obs_adata`` is ``None``).
If ``obs_adata`` was a SpatialData object, returns the updated SpatialData
with the result stored in ``sdata.tables[table_key]``.

Expand All @@ -1185,6 +1145,8 @@ def impute(
``area_key`` is specified but not found in ``adata.obs``, or if
any cell area is ≤ 0.
"""
import scipy.sparse as sp

obs_adata, sdata = _resolve_sdata(obs_adata, table_key)

embeddings, counts, log_library = self.embed_and_impute(
Expand All @@ -1200,8 +1162,11 @@ def impute(
if return_int:
counts = np.round(counts, decimals=0).astype(np.int32)

adata_out = self._build_output_anndata(counts, embeddings, obs_adata)
adata_out.layers["imputed"] = counts.copy()
adata_out = self._build_output_anndata(counts, embeddings, obs_adata, return_sparse=return_sparse)

imputed = sp.csr_matrix(counts) if return_sparse else counts.copy()
adata_out.X = imputed
adata_out.layers["imputed"] = imputed

if return_norm:
# MC estimate of E[log1p(norm(X))] where X ~ NB(mu, theta).
Expand Down Expand Up @@ -1240,7 +1205,8 @@ def impute(
normed = draw * (norm_target_sum / lib)
log1p_acc += np.log1p(normed)

adata_out.layers["imputed_norm"] = (log1p_acc / K).astype(np.float32)
norm_layer = (log1p_acc / K).astype(np.float32)
adata_out.layers["imputed_norm"] = sp.csr_matrix(norm_layer) if return_sparse else norm_layer

if sdata is not None:
sdata.tables[table_key] = adata_out
Expand Down
99 changes: 97 additions & 2 deletions tests/models/test_cellpin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,111 @@ def test_cellpin_get_cell_embedding(small_datasets):
assert embeddings.dtype == np.float32


def test_cellpin_impute_to_anndata(small_datasets):
def test_cellpin_impute(small_datasets):
"""Basic shape and key checks for impute()."""
import scipy.sparse as sp

sc_ds, st_ds = small_datasets
model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG)
model.eval()

loader = DataLoader(st_ds, batch_size=4, shuffle=False)
adata_out = model.impute_to_anndata(loader)
adata_out = model.impute(loader, mc_samples=2)

assert adata_out.n_obs == len(st_ds)
assert adata_out.n_vars == sc_ds.X.shape[1]
assert "X_cellpin" in adata_out.obsm
assert adata_out.obsm["X_cellpin"].shape == (len(st_ds), MINIMAL_CONFIG["n_latent"])
assert "imputed" in adata_out.layers
assert "is_measured" in adata_out.var.columns


def test_impute_no_sentinel(small_datasets):
"""No -2 values anywhere in imputed output."""
sc_ds, st_ds = small_datasets
model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG)
model.eval()

loader = DataLoader(st_ds, batch_size=4, shuffle=False)
adata_sc_loader = DataLoader(sc_ds, batch_size=4, shuffle=False)

# impute spatial with obs_adata that has fewer genes (triggers the fill path)
import anndata as ad
import numpy as np
rng = np.random.default_rng(0)
obs_adata = ad.AnnData(X=rng.integers(1, 10, size=(len(st_ds), 8)).astype(np.float32))
obs_adata.var_names = [f"gene{i}" for i in range(8)]
obs_adata.layers["counts"] = obs_adata.X.copy()

adata_out = model.impute(loader, obs_adata=obs_adata, mc_samples=2, return_sparse=False)

assert np.all(adata_out.X >= 0), "X contains negative values"
for lyr in adata_out.layers.values():
arr = lyr.toarray() if hasattr(lyr, "toarray") else lyr
assert np.all(arr >= 0), f"Layer contains negative values"


def test_is_measured_var_column(small_datasets):
"""is_measured correctly marks panel genes vs unmeasured genes."""
import anndata as ad
import numpy as np

sc_ds, st_ds = small_datasets
model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG)
model.eval()

loader = DataLoader(st_ds, batch_size=4, shuffle=False)

# obs_adata has only 8 of the 20 genes
rng = np.random.default_rng(1)
obs_adata = ad.AnnData(X=rng.integers(1, 10, size=(len(st_ds), 8)).astype(np.float32))
obs_adata.var_names = [f"gene{i}" for i in range(8)]
obs_adata.layers["counts"] = obs_adata.X.copy()

adata_out = model.impute(loader, obs_adata=obs_adata, mc_samples=2)

assert "is_measured" in adata_out.var.columns
assert adata_out.var["is_measured"].dtype == bool
# genes 0-7 are in obs_adata, genes 8-19 are not
assert adata_out.var.loc["gene0", "is_measured"] is np.bool_(True)
assert adata_out.var.loc["gene7", "is_measured"] is np.bool_(True)
assert adata_out.var.loc["gene8", "is_measured"] is np.bool_(False)
assert adata_out.var.loc["gene19", "is_measured"] is np.bool_(False)
assert adata_out.var["is_measured"].sum() == 8


def test_impute_sparse_output(small_datasets):
"""X and layers are sparse by default; dense when return_sparse=False."""
import scipy.sparse as sp

sc_ds, st_ds = small_datasets
model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG)
model.eval()

loader = DataLoader(st_ds, batch_size=4, shuffle=False)

adata_sparse = model.impute(loader, mc_samples=2, return_sparse=True)
assert sp.issparse(adata_sparse.X)
assert sp.issparse(adata_sparse.layers["imputed"])

adata_dense = model.impute(loader, mc_samples=2, return_sparse=False)
assert not sp.issparse(adata_dense.X)
assert not sp.issparse(adata_dense.layers["imputed"])


def test_impute_return_int_sparse(small_datasets):
"""return_int=True produces sparse int32 with no negative values."""
import scipy.sparse as sp

sc_ds, st_ds = small_datasets
model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG)
model.eval()

loader = DataLoader(st_ds, batch_size=4, shuffle=False)
adata_out = model.impute(loader, mc_samples=2, return_int=True, return_sparse=True)

assert sp.issparse(adata_out.X)
assert adata_out.X.dtype == np.int32
assert sp.issparse(adata_out.layers["imputed"])
assert adata_out.layers["imputed"].dtype == np.int32
assert adata_out.X.min() >= 0
Loading