Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions include/LinearAlgebra/Solvers/csr_lu_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class SparseLUSolver
*
* @param b Right-hand side vector (modified in place to contain the solution).
*/
void solveInPlace(Vector<T> b) const;
void solveInPlace(T* b) const;
void solveInPlace(Vector<T>& b) const;

private:
// LU decomposition data structures
Expand All @@ -76,7 +75,7 @@ class SparseLUSolver

// Core methods
void factorize(const SparseMatrixCSR<T>& A);
void solveInPlacePermuted(T* b) const;
void solveInPlacePermuted(Vector<T>& b) const;

// Reordering and permutation utilities
std::vector<int> computeRCM(const SparseMatrixCSR<T>& A) const;
Expand Down Expand Up @@ -129,17 +128,7 @@ SparseLUSolver<T>::SparseLUSolver(const SparseMatrixCSR<T>& A, T tolerance_abs,
* @param b - Right-hand side vector (overwritten with solution)
*/
template <typename T>
void SparseLUSolver<T>::solveInPlace(Vector<T> b) const
{
solveInPlace(b.data());
}

/**
* Solves Ax = b for raw pointer
* @param b - Right-hand side vector (overwritten with solution)
*/
template <typename T>
void SparseLUSolver<T>::solveInPlace(T* b) const
void SparseLUSolver<T>::solveInPlace(Vector<T>& b) const
{
assert(factorized_);
const int n = perm.size();
Expand All @@ -153,7 +142,7 @@ void SparseLUSolver<T>::solveInPlace(T* b) const
}

// Solve permuted system
solveInPlacePermuted(b_perm.data());
solveInPlacePermuted(b_perm);

// Unpermute solution: x = P^T * x_perm
for (int i = 0; i < n; i++) {
Expand All @@ -166,7 +155,7 @@ void SparseLUSolver<T>::solveInPlace(T* b) const
* @param b - Permuted right-hand side vector (overwritten with solution)
*/
template <typename T>
void SparseLUSolver<T>::solveInPlacePermuted(T* b) const
void SparseLUSolver<T>::solveInPlacePermuted(Vector<T>& b) const
{
const int n = L_row_ptr.size() - 1;

Expand Down
Loading