From 98c0658a9e0206496a0d3b48301e04000c4f69c4 Mon Sep 17 00:00:00 2001 From: roi-meir Date: Tue, 14 Apr 2026 17:25:05 +0300 Subject: [PATCH 1/2] fix: freeze F during prediction in project_poisson_nmf Without `update.factors = NULL`, fit_poisson_nmf used its default seq(1, ncol(X)) and refitted every row of F on the new (test) data, silently corrupting the training-estimated factors. Adding `update.factors = NULL` reduces to integer(0), which causes the main update loop to skip all factor updates so that only L is learned for the new data points. --- R/predict.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/predict.R b/R/predict.R index da200eb..3582495 100644 --- a/R/predict.R +++ b/R/predict.R @@ -134,6 +134,8 @@ project_poisson_nmf <- function (X, F, numiter, ...) { rownames(L0) <- rownames(X) colnames(L0) <- colnames(F) fit <- init_poisson_nmf(X,F = F,L = L0) - fit <- fit_poisson_nmf(X,fit0 = fit,numiter = numiter,...) + + # Fix the factors and update the loadings. + fit <- fit_poisson_nmf(X,fit0 = fit,numiter = numiter, update.factors = NULL, ...) return(fit) } From dbad59bf482cdc696513c04eaddc83e870de392b Mon Sep 17 00:00:00 2001 From: roi-meir Date: Tue, 14 Apr 2026 17:43:36 +0300 Subject: [PATCH 2/2] Add tests/testthat/test_predict.R The test verifies project_poisson_nmf returns F with the same column-normalised structure as the input (catches the update.factors = NULL fix) --- tests/testthat/test_predict.R | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/testthat/test_predict.R diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R new file mode 100644 index 0000000..57dd217 --- /dev/null +++ b/tests/testthat/test_predict.R @@ -0,0 +1,20 @@ +context("predict") + +# Test that project_poisson_nmf returns F with the same structure as the input F +test_that("project_poisson_nmf returns F with the same structure as the input F", { + + set.seed(1) + dat <- simulate_multinom_gene_data(175,1200,k = 3) + train <- dat$X[1:100,] + test <- dat$X[101:175,] + + fit <- init_poisson_nmf(train,F = dat$F,init.method = "random") + fit <- fit_poisson_nmf(train,fit0 = fit,verbose = "none") + fit <- poisson2multinom(fit) + F_pois <- multinom2poisson(fit)$F + + out <- project_poisson_nmf(test,F_pois,numiter = 20,verbose = "none", update.factors = NULL) + + norm_col <- function(M) sweep(M,2,colSums(M),"/") + expect_equal(norm_col(out$F),norm_col(F_pois),tolerance = 1e-10) +})