diff --git a/CHANGELOG.md b/CHANGELOG.md index d68f4d7..f3dac64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.1.1] - 2026-06-10 + +- Changed to sparse outputs +- Removed dead code + +### Added + +- Initial release with basic tool, preprocessing and plotting functions + ## [0.0.1] - 2026-05-29 ### Added diff --git a/pyproject.toml b/pyproject.toml index e3b523b..9a0fe52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "hatchling" ] [project] name = "cellpin" -version = "0.0.1" +version = "0.1.1" description = "Cellpin: A Student–Teacher Variational Model for Spatial Transcriptomics Imputation and Encoding" readme = "README.md" license = { file = "LICENSE" } diff --git a/src/cellpin/models/cellpin_model.py b/src/cellpin/models/cellpin_model.py index 1d32827..a716a7a 100644 --- a/src/cellpin/models/cellpin_model.py +++ b/src/cellpin/models/cellpin_model.py @@ -15,7 +15,7 @@ ------------------- * KL annealing (linear warm-up over ``kl_warmup_epochs`` epochs). * Per-stage configurable loss weights. -* ``get_cell_embedding``, ``embed_and_impute``, ``impute_to_anndata`` API. +* ``get_cell_embedding``, ``embed_and_impute`` API. * ``fit()`` convenience wrapper: pretrain → train in one call. """ @@ -963,11 +963,16 @@ def _build_output_anndata( counts: np.ndarray, embeddings: np.ndarray, obs_adata: ad.AnnData | None, + return_sparse: bool = True, ) -> ad.AnnData: - """Build the base output AnnData shared by impute() and impute_to_anndata(). + """Build the base output AnnData for impute(). Sets var_names, X_cellpin embedding, and copies obs/obsm/layers from obs_adata. + Genes absent from obs_adata layers are filled with 0; var['is_measured'] marks + which genes were present in obs_adata. """ + import scipy.sparse as sp + adata_out = ad.AnnData(X=counts) adata_out.var_names = self.gene_names adata_out.obsm["X_cellpin"] = embeddings @@ -984,72 +989,26 @@ def _build_output_anndata( out_gene_idx = {g: i for i, g in enumerate(adata_out.var_names)} src_cols = [out_gene_idx[g] for g in obs_adata.var_names if g in out_gene_idx] src_mask = [g in out_gene_idx for g in obs_adata.var_names] + n_missing = adata_out.n_vars - len(src_cols) + if n_missing > 0: + print(f" [impute] Filling {n_missing} gene(s) absent from obs_adata layers with 0") for lyr_key, lyr_val in obs_adata.layers.items(): if hasattr(lyr_val, "toarray"): lyr_val = lyr_val.toarray() - print( - f" [impute] Filling {adata_out.n_vars - len(src_cols)} " - "gene(s) absent from obs_adata layers with sentinel -2.0" - ) - mat = np.full((adata_out.n_obs, adata_out.n_vars), -2.0, dtype=np.float32) + mat = np.zeros((adata_out.n_obs, adata_out.n_vars), dtype=np.float32) mat[:, src_cols] = np.asarray(lyr_val, dtype=np.float32)[:, src_mask] - adata_out.layers[lyr_key] = mat - - return adata_out - - @torch.no_grad() - def impute_to_anndata( - self, - dataloader: torch.utils.data.DataLoader, - obs_adata: ad.AnnData | None = None, - use_mean: bool = True, - mc_impute: bool = False, - mc_samples: int = 50, - mask_fraction: float = 0.2, - table_key: str = "table", - ) -> ad.AnnData: - """Impute full-gene expression and return as AnnData. - - Args: - dataloader: DataLoader to run inference on. - obs_adata: Optional AnnData (or :class:`spatialdata.SpatialData`) whose - ``.obs`` is copied to the output. If SpatialData, the AnnData is read - from ``obs_adata.tables[table_key]`` and the result is returned as an - updated SpatialData object. Must have the same number of observations. - use_mean: Use posterior mean for the latent (deterministic). - Ignored when ``mc_impute=True``. - mc_impute: Use MC averaging over ``mc_samples`` stochastic forward - passes (recommended; ~+0.01 mean Pearson over deterministic). - mc_samples: Number of MC samples (default 50). - mask_fraction: Fraction of panel genes randomly zeroed per MC pass - (default 0.2). - table_key: Table name to read/write when ``obs_adata`` is a SpatialData - object (default ``"table"``). - - Returns: - ------- - :class:`anndata.AnnData` with ``X`` = imputed counts, - ``obsm['X_cellpin']`` = embeddings, ``layers['imputed']``. - If ``obs_adata`` was a SpatialData object, returns the updated SpatialData - with the result stored in ``sdata.tables[table_key]``. + adata_out.layers[lyr_key] = sp.csr_matrix(mat) if return_sparse else mat + + # Mark which genes were measured in obs_adata + measured = np.zeros(adata_out.n_vars, dtype=bool) + out_gene_set = set(adata_out.var_names) + for i, g in enumerate(adata_out.var_names): + measured[i] = g in set(obs_adata.var_names) and g in out_gene_set + adata_out.var["is_measured"] = measured + else: + # No obs_adata: all output genes are considered measured (imputed from sc ref) + adata_out.var["is_measured"] = np.ones(adata_out.n_vars, dtype=bool) - Raises: - ------ - ValueError: If ``obs_adata`` has the wrong number of cells. - """ - obs_adata, sdata = _resolve_sdata(obs_adata, table_key) - embeddings, imputed_arr, _ = self.embed_and_impute( - dataloader, - use_mean=use_mean, - mc_impute=mc_impute, - mc_samples=mc_samples, - mask_fraction=mask_fraction, - ) - adata_out = self._build_output_anndata(imputed_arr, embeddings, obs_adata) - adata_out.layers["imputed"] = adata_out.X.copy() - if sdata is not None: - sdata.tables[table_key] = adata_out - return sdata return adata_out def fit( @@ -1135,15 +1094,11 @@ def impute( area_key: str | None = None, nb_count_samples: int = 100, return_int: bool = False, + return_sparse: bool = True, table_key: str = "table", ) -> ad.AnnData: """Impute with MC averaging and optional count-space normalisation. - More complete than :meth:`impute_to_anndata`: adds MC dropout, integer - rounding, and area-normalised log-normalised layers. Use - :meth:`impute_to_anndata` when you only need embeddings + raw imputed - counts. - Args: dataloader: DataLoader to run inference on. obs_adata: Optional AnnData (or :class:`spatialdata.SpatialData`) whose @@ -1168,6 +1123,9 @@ def impute( ``log1p(norm(E[X])) > E[log1p(norm(X))]``; sampling inside the transform corrects this bias. More samples → lower variance. return_int: If ``True``, round ``X`` to integer counts (``int32``). + return_sparse: If ``True`` (default), store ``X``, ``layers['imputed']``, + and ``layers['imputed_norm']`` as :class:`scipy.sparse.csr_matrix`. + Set to ``False`` to keep dense numpy arrays. table_key: Table name to read/write when ``obs_adata`` is a SpatialData object (default ``"table"``). @@ -1176,6 +1134,8 @@ def impute( :class:`anndata.AnnData` with ``X`` = imputed (float or int) counts, ``obsm['X_cellpin']`` = embeddings, ``layers['imputed']`` = copy of ``X``, and optionally ``layers['imputed_norm']``. + ``var['is_measured']`` marks genes present in ``obs_adata`` (all ``True`` + when ``obs_adata`` is ``None``). If ``obs_adata`` was a SpatialData object, returns the updated SpatialData with the result stored in ``sdata.tables[table_key]``. @@ -1185,6 +1145,8 @@ def impute( ``area_key`` is specified but not found in ``adata.obs``, or if any cell area is ≤ 0. """ + import scipy.sparse as sp + obs_adata, sdata = _resolve_sdata(obs_adata, table_key) embeddings, counts, log_library = self.embed_and_impute( @@ -1200,8 +1162,11 @@ def impute( if return_int: counts = np.round(counts, decimals=0).astype(np.int32) - adata_out = self._build_output_anndata(counts, embeddings, obs_adata) - adata_out.layers["imputed"] = counts.copy() + adata_out = self._build_output_anndata(counts, embeddings, obs_adata, return_sparse=return_sparse) + + imputed = sp.csr_matrix(counts) if return_sparse else counts.copy() + adata_out.X = imputed + adata_out.layers["imputed"] = imputed if return_norm: # MC estimate of E[log1p(norm(X))] where X ~ NB(mu, theta). @@ -1240,7 +1205,8 @@ def impute( normed = draw * (norm_target_sum / lib) log1p_acc += np.log1p(normed) - adata_out.layers["imputed_norm"] = (log1p_acc / K).astype(np.float32) + norm_layer = (log1p_acc / K).astype(np.float32) + adata_out.layers["imputed_norm"] = sp.csr_matrix(norm_layer) if return_sparse else norm_layer if sdata is not None: sdata.tables[table_key] = adata_out diff --git a/tests/models/test_cellpin.py b/tests/models/test_cellpin.py index 5cb89d7..69c1972 100644 --- a/tests/models/test_cellpin.py +++ b/tests/models/test_cellpin.py @@ -73,16 +73,111 @@ def test_cellpin_get_cell_embedding(small_datasets): assert embeddings.dtype == np.float32 -def test_cellpin_impute_to_anndata(small_datasets): +def test_cellpin_impute(small_datasets): + """Basic shape and key checks for impute().""" + import scipy.sparse as sp + sc_ds, st_ds = small_datasets model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG) model.eval() loader = DataLoader(st_ds, batch_size=4, shuffle=False) - adata_out = model.impute_to_anndata(loader) + adata_out = model.impute(loader, mc_samples=2) assert adata_out.n_obs == len(st_ds) assert adata_out.n_vars == sc_ds.X.shape[1] assert "X_cellpin" in adata_out.obsm assert adata_out.obsm["X_cellpin"].shape == (len(st_ds), MINIMAL_CONFIG["n_latent"]) assert "imputed" in adata_out.layers + assert "is_measured" in adata_out.var.columns + + +def test_impute_no_sentinel(small_datasets): + """No -2 values anywhere in imputed output.""" + sc_ds, st_ds = small_datasets + model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG) + model.eval() + + loader = DataLoader(st_ds, batch_size=4, shuffle=False) + adata_sc_loader = DataLoader(sc_ds, batch_size=4, shuffle=False) + + # impute spatial with obs_adata that has fewer genes (triggers the fill path) + import anndata as ad + import numpy as np + rng = np.random.default_rng(0) + obs_adata = ad.AnnData(X=rng.integers(1, 10, size=(len(st_ds), 8)).astype(np.float32)) + obs_adata.var_names = [f"gene{i}" for i in range(8)] + obs_adata.layers["counts"] = obs_adata.X.copy() + + adata_out = model.impute(loader, obs_adata=obs_adata, mc_samples=2, return_sparse=False) + + assert np.all(adata_out.X >= 0), "X contains negative values" + for lyr in adata_out.layers.values(): + arr = lyr.toarray() if hasattr(lyr, "toarray") else lyr + assert np.all(arr >= 0), f"Layer contains negative values" + + +def test_is_measured_var_column(small_datasets): + """is_measured correctly marks panel genes vs unmeasured genes.""" + import anndata as ad + import numpy as np + + sc_ds, st_ds = small_datasets + model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG) + model.eval() + + loader = DataLoader(st_ds, batch_size=4, shuffle=False) + + # obs_adata has only 8 of the 20 genes + rng = np.random.default_rng(1) + obs_adata = ad.AnnData(X=rng.integers(1, 10, size=(len(st_ds), 8)).astype(np.float32)) + obs_adata.var_names = [f"gene{i}" for i in range(8)] + obs_adata.layers["counts"] = obs_adata.X.copy() + + adata_out = model.impute(loader, obs_adata=obs_adata, mc_samples=2) + + assert "is_measured" in adata_out.var.columns + assert adata_out.var["is_measured"].dtype == bool + # genes 0-7 are in obs_adata, genes 8-19 are not + assert adata_out.var.loc["gene0", "is_measured"] is np.bool_(True) + assert adata_out.var.loc["gene7", "is_measured"] is np.bool_(True) + assert adata_out.var.loc["gene8", "is_measured"] is np.bool_(False) + assert adata_out.var.loc["gene19", "is_measured"] is np.bool_(False) + assert adata_out.var["is_measured"].sum() == 8 + + +def test_impute_sparse_output(small_datasets): + """X and layers are sparse by default; dense when return_sparse=False.""" + import scipy.sparse as sp + + sc_ds, st_ds = small_datasets + model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG) + model.eval() + + loader = DataLoader(st_ds, batch_size=4, shuffle=False) + + adata_sparse = model.impute(loader, mc_samples=2, return_sparse=True) + assert sp.issparse(adata_sparse.X) + assert sp.issparse(adata_sparse.layers["imputed"]) + + adata_dense = model.impute(loader, mc_samples=2, return_sparse=False) + assert not sp.issparse(adata_dense.X) + assert not sp.issparse(adata_dense.layers["imputed"]) + + +def test_impute_return_int_sparse(small_datasets): + """return_int=True produces sparse int32 with no negative values.""" + import scipy.sparse as sp + + sc_ds, st_ds = small_datasets + model = CellPin(sc_dataset=sc_ds, config=MINIMAL_CONFIG) + model.eval() + + loader = DataLoader(st_ds, batch_size=4, shuffle=False) + adata_out = model.impute(loader, mc_samples=2, return_int=True, return_sparse=True) + + assert sp.issparse(adata_out.X) + assert adata_out.X.dtype == np.int32 + assert sp.issparse(adata_out.layers["imputed"]) + assert adata_out.layers["imputed"].dtype == np.int32 + assert adata_out.X.min() >= 0