diff --git a/R/predict.R b/R/predict.R index da200eb..3582495 100644 --- a/R/predict.R +++ b/R/predict.R @@ -134,6 +134,8 @@ project_poisson_nmf <- function (X, F, numiter, ...) { rownames(L0) <- rownames(X) colnames(L0) <- colnames(F) fit <- init_poisson_nmf(X,F = F,L = L0) - fit <- fit_poisson_nmf(X,fit0 = fit,numiter = numiter,...) + + # Fix the factors and update the loadings. + fit <- fit_poisson_nmf(X,fit0 = fit,numiter = numiter, update.factors = NULL, ...) return(fit) } diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R new file mode 100644 index 0000000..57dd217 --- /dev/null +++ b/tests/testthat/test_predict.R @@ -0,0 +1,20 @@ +context("predict") + +# Test that project_poisson_nmf returns F with the same structure as the input F +test_that("project_poisson_nmf returns F with the same structure as the input F", { + + set.seed(1) + dat <- simulate_multinom_gene_data(175,1200,k = 3) + train <- dat$X[1:100,] + test <- dat$X[101:175,] + + fit <- init_poisson_nmf(train,F = dat$F,init.method = "random") + fit <- fit_poisson_nmf(train,fit0 = fit,verbose = "none") + fit <- poisson2multinom(fit) + F_pois <- multinom2poisson(fit)$F + + out <- project_poisson_nmf(test,F_pois,numiter = 20,verbose = "none", update.factors = NULL) + + norm_col <- function(M) sweep(M,2,colSums(M),"/") + expect_equal(norm_col(out$F),norm_col(F_pois),tolerance = 1e-10) +})