diff --git a/include/LinearAlgebra/Solvers/csr_lu_solver.h b/include/LinearAlgebra/Solvers/csr_lu_solver.h index 89a1e59a..5f32507a 100644 --- a/include/LinearAlgebra/Solvers/csr_lu_solver.h +++ b/include/LinearAlgebra/Solvers/csr_lu_solver.h @@ -59,8 +59,7 @@ class SparseLUSolver * * @param b Right-hand side vector (modified in place to contain the solution). */ - void solveInPlace(Vector b) const; - void solveInPlace(T* b) const; + void solveInPlace(Vector& b) const; private: // LU decomposition data structures @@ -76,7 +75,7 @@ class SparseLUSolver // Core methods void factorize(const SparseMatrixCSR& A); - void solveInPlacePermuted(T* b) const; + void solveInPlacePermuted(Vector& b) const; // Reordering and permutation utilities std::vector computeRCM(const SparseMatrixCSR& A) const; @@ -129,17 +128,7 @@ SparseLUSolver::SparseLUSolver(const SparseMatrixCSR& A, T tolerance_abs, * @param b - Right-hand side vector (overwritten with solution) */ template -void SparseLUSolver::solveInPlace(Vector b) const -{ - solveInPlace(b.data()); -} - -/** - * Solves Ax = b for raw pointer - * @param b - Right-hand side vector (overwritten with solution) - */ -template -void SparseLUSolver::solveInPlace(T* b) const +void SparseLUSolver::solveInPlace(Vector& b) const { assert(factorized_); const int n = perm.size(); @@ -153,7 +142,7 @@ void SparseLUSolver::solveInPlace(T* b) const } // Solve permuted system - solveInPlacePermuted(b_perm.data()); + solveInPlacePermuted(b_perm); // Unpermute solution: x = P^T * x_perm for (int i = 0; i < n; i++) { @@ -166,7 +155,7 @@ void SparseLUSolver::solveInPlace(T* b) const * @param b - Permuted right-hand side vector (overwritten with solution) */ template -void SparseLUSolver::solveInPlacePermuted(T* b) const +void SparseLUSolver::solveInPlacePermuted(Vector& b) const { const int n = L_row_ptr.size() - 1;