Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ if (SPBLAS_CPU_BACKEND)
add_example(simple_sptrsv)
add_example(spmm_csc)
add_example(matrix_opt_example)
if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND)
# needs CPU + matrix_opt + operation_info_t to run
add_example(sptrsv_csr) # needs triangular_solve{_inspect} to run
add_example(spmm_csr) # needs multiply{_inspect} to run
endif()
endif()

# GPU examples
Expand Down
53 changes: 53 additions & 0 deletions examples/spmm_csr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <spblas/spblas.hpp>

#include <fmt/core.h>
#include <fmt/ranges.h>

int main(int argc, char** argv) {
using namespace spblas;
namespace md = spblas::__mdspan;

using T = float;

spblas::index_t m = 10;
spblas::index_t n = 10;
spblas::index_t k = 10;
spblas::index_t nnz_in = 20;

fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n\t### Running Advanced SpMM Example:");
fmt::print("\n\t###");
fmt::print("\n\t###");
fmt::print("\n\t### with ");
fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k,
nnz_in);
fmt::print("\n\t### x, a dense matrix, of size ({}, {})", k, n);
fmt::print("\n\t### y, a dense vector, of size ({}, {})", m, n);
fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
sizeof(spblas::index_t));
fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n");

auto&& [values, rowptr, colind, shape, nnz] = generate_csr<T>(m, k, nnz_in);

csr_view<T> a(values, rowptr, colind, shape, nnz);
matrix_opt a_opt(a);

std::vector<T> x_values(k * n, 1);
std::vector<T> y_values(m * n, 0);

md::mdspan x(x_values.data(), k, n);
md::mdspan y(y_values.data(), m, n);

// Y = A * X
auto state = multiply_inspect(a_opt, x, y);
multiply(state, a_opt, x, y);

fmt::print("{}\n", spblas::__backend::values(y));

fmt::print("\tExample is completed!\n");

return 0;
}
63 changes: 63 additions & 0 deletions examples/sptrsv_csr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <spblas/spblas.hpp>

#include <fmt/core.h>
#include <fmt/ranges.h>

int main(int argc, char** argv) {
using namespace spblas;

using T = float;

spblas::index_t m = 100;
spblas::index_t nnz_in = 20;

fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n\t### Running Full SpTRSV Example:");
fmt::print("\n\t###");
fmt::print("\n\t### solve for x: A * x = alpha * b");
fmt::print("\n\t###");
fmt::print("\n\t### with ");
fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, m,
nnz_in);
fmt::print("\n\t### x, a dense vector, of size ({}, {})", m, 1);
fmt::print("\n\t### b, a dense vector, of size ({}, {})", m, 1);
fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
sizeof(spblas::index_t));
fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n");

auto&& [values, rowptr, colind, shape, nnz] =
generate_csr<T, spblas::index_t>(m, m, nnz_in);

// scale values of matrix to make the implicit unit diagonal matrix
// be diagonally dominant, so it is solveable
T scale_factor = 1e-3f;
std::transform(values.begin(), values.end(), values.begin(),
[scale_factor](T val) { return scale_factor * val; });

csr_view<T, spblas::index_t> a(values, rowptr, colind, shape, nnz);

matrix_opt a_opt(a);

// Scale every value of `a` by 5 in place.
// scale(5.f, a);

std::vector<T> x(m, 0);
std::vector<T> b(m, 1);

T alpha = 1.2f;
auto b_scaled = scaled(alpha, b);

// solve for x: lower(A) * x = alpha * b
triangular_solve_inspect(a_opt, spblas::upper_triangle_t{},
spblas::implicit_unit_diagonal_t{}, b_scaled, x);

triangular_solve(a_opt, spblas::upper_triangle_t{},
spblas::implicit_unit_diagonal_t{}, b_scaled, x);
Comment thread
yhmtsai marked this conversation as resolved.

fmt::print("\tExample is completed!\n");

return 0;
}
28 changes: 28 additions & 0 deletions include/spblas/algorithms/multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,46 @@

namespace spblas {

// SpMV variants
template <matrix A, vector B, vector C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c);

template <matrix A, vector B, vector C>
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c);

template <matrix A, vector B, vector C>
void multiply(A&& a, B&& b, C&& c);

template <matrix A, vector B, vector C>
void multiply(operation_into_t& info, A&& a, B&& b, C&& c);

// SpMM variants
template <matrix A, matrix B, matrix C>
void multiply(A&& a, B&& b, C&& c);

template <matrix A, matrix B, matrix C>
void multiply(operation_info_t& info, A&& a, B&& b, C&& c);

// SpMM and SpGEMM multiply_inspect variants
template <matrix A, matrix B, matrix C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c);

template <matrix A, matrix B, matrix C>
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c);

// SpGEMM variants
template <typename ExecutionPolicy, matrix A, matrix B, matrix C>
operation_info_t multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b,
C&& c);

template <typename ExecutionPolicy, matrix A, matrix B, matrix C>
void multiply_compute(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
B&& b, C&& c);

template <typename ExecutionPolicy, matrix A, matrix B, matrix C>
void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
B&& b, C&& c);

template <matrix A, matrix B, matrix C>
operation_info_t multiply_compute(A&& a, B&& b, C&& c);

Expand Down
49 changes: 48 additions & 1 deletion include/spblas/algorithms/multiply_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@

namespace spblas {

// SpMV inspect
template <matrix A, vector B, vector C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c) {
log_trace("");
return operation_info_t{};
}

// SpMV inspect
template <matrix A, vector B, vector C>
operation_info_t multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
}

// C = AB
// SpMV
template <matrix A, vector B, vector C>
Expand All @@ -39,6 +52,15 @@ void multiply(A&& a, B&& b, C&& c) {
});
}

// C = AB
// SpMV with info input
template <matrix A, vector B, vector C>
requires(__backend::lookupable<B> && __backend::lookupable<C>)
void multiply(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
multiply(std::forward<A>(a), std::forward<B>(b), std::forward<C>(c));
}

// C = AB
// SpMM
template <matrix A, matrix B, matrix C>
Expand All @@ -52,11 +74,14 @@ void multiply(A&& a, B&& b, C&& c) {
"multiply: matrix dimensions are incompatible.");
}

// initializes c to zero so we can use += everywhere
__backend::for_each(c, [](auto&& e) {
auto&& [_, v] = e;
v = 0;
});

// traverses elements of a and performs appropriate
// multiplication with B rows
__backend::for_each(a, [&](auto&& e) {
auto&& [idx, a_v] = e;
auto&& [i, k] = idx;
Expand All @@ -66,23 +91,44 @@ void multiply(A&& a, B&& b, C&& c) {
});
}

// C = AB
// SpMM with info
template <matrix A, matrix B, matrix C>
requires(__backend::lookupable<B> && __backend::lookupable<C>)
void multiply(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
multiply(std::forward<A>(a), std::forward<B>(b), std::forward<C>(c));
}

// C = AB
// SpMM or SpGEMM multiply_inspect variants end up here
template <matrix A, matrix B, matrix C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c) {
log_trace("");
return operation_info_t{};
}

// C = AB
// SpMM or SpGEMM multiply_inspect variants end up here
template <matrix A, matrix B, matrix C>
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c){};
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
};

// C = AB
// SpGEMM compute stage with CSR output
template <matrix A, matrix B, matrix C>
requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
__detail::is_csr_view_v<C>)
void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
auto new_info = multiply_compute(std::forward<A>(a), std::forward<B>(b),
std::forward<C>(c));
info.update_impl_(new_info.result_shape(), new_info.result_nnz());
}

// C = AB
// SpGEMM compute stage with CSC output
template <matrix A, matrix B, matrix C>
requires(__backend::column_iterable<A> && __backend::column_iterable<B> &&
__detail::is_csc_view_v<C>)
Expand All @@ -93,6 +139,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
}

// C = AB
// SpGEMM fill stage with CSR or CSC output
template <matrix A, matrix B, matrix C>
void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) {
log_trace("");
Expand Down
13 changes: 8 additions & 5 deletions include/spblas/algorithms/triangular_solve.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
#include <spblas/concepts.hpp>
#include <spblas/detail/operation_info_t.hpp>

template <class ExecutionPolicy, in - matrix InMat, class Triangle,
class DiagonalStorage, in - vector InVec, out - vector OutVec>
void triangular_matrix_vector_solve(ExecutionPolicy&& exec, InMat A, Triangle t,
DiagonalStorage d, InVec b, OutVec x);

namespace spblas {

template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle uplo,
DiagonalStorage diag, B&& b, X&& x);

template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
operation_info_t triangular_solve_inspect(A&& a, Triangle uplo,
DiagonalStorage diag, B&& b, X&& x);

template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x);

Expand Down
44 changes: 44 additions & 0 deletions include/spblas/algorithms/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,41 @@

namespace spblas {

// X = inv(A) B
// SpTRSV inspect stage
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
__backend::lookupable<X>)
operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d,
B&& b, X&& x) {
log_trace("");
static_assert(std::is_same_v<Triangle, upper_triangle_t> ||
std::is_same_v<Triangle, lower_triangle_t>);
assert(__backend::shape(a)[0] == __backend::shape(a)[1]);

return operation_info_t{};
}

// X = inv(A) B
// SpTRSV inspect stage
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
__backend::lookupable<X>)
void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle t,
DiagonalStorage d, B&& b, X&& x) {
log_trace("");
static_assert(std::is_same_v<Triangle, upper_triangle_t> ||
std::is_same_v<Triangle, lower_triangle_t>);
assert(__backend::shape(a)[0] == __backend::shape(a)[1]);
}

// X = inv(A) B
// SpTRSV solve stage
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
__backend::lookupable<X>)
void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) {
log_trace("");
static_assert(std::is_same_v<Triangle, upper_triangle_t> ||
std::is_same_v<Triangle, lower_triangle_t>);
assert(__backend::shape(a)[0] == __backend::shape(a)[1]);
Expand Down Expand Up @@ -62,4 +93,17 @@ void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) {
}
}

// X = inv(A) B
// SpTRSV solve stage with info
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
__backend::lookupable<X>)
void triangular_solve(operation_info_t& info, A&& a, Triangle t,
DiagonalStorage d, B&& b, X&& x) {
log_trace("");
triangular_solve(std::forward<A>(a), std::forward<Triangle>(t),
std::forward<DiagonalStorage>(d), std::forward<B>(b),
std::forward<X>(x));
}

} // namespace spblas
20 changes: 16 additions & 4 deletions include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q,
oneapi::mkl::sparse::init_matrix_handle(&handle);

oneapi::mkl::sparse::set_csr_data(
q, handle, m.shape()[0], m.shape()[1], oneapi::mkl::index_base::zero,
m.rowptr().data(), m.colind().data(), m.values().data())
q, handle, m.shape()[0], m.shape()[1],
#if defined(__INTEL_MKL__) && \
((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \
(__INTEL_MKL__ > 2025))
m.size(), // nnz added in 2025.3, and without deprecated
#endif
oneapi::mkl::index_base::zero, m.rowptr().data(), m.colind().data(),
m.values().data())
.wait();

return handle;
Expand All @@ -33,8 +39,14 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q,
oneapi::mkl::sparse::init_matrix_handle(&handle);

oneapi::mkl::sparse::set_csr_data(
q, handle, m.shape()[1], m.shape()[0], oneapi::mkl::index_base::zero,
m.colptr().data(), m.rowind().data(), m.values().data())
q, handle, m.shape()[1], m.shape()[0],
#if defined(__INTEL_MKL__) && \
((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \
(__INTEL_MKL__ > 2025))
m.size(), // nnz added in 2025.3, and without deprecated
#endif
Comment on lines +43 to +47
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MKL SYCL api change in 2025.3 adding nnz to csr/csc/bsr and adding coo format which already had nnz in definition.

oneapi::mkl::index_base::zero, m.colptr().data(), m.rowind().data(),
m.values().data())
.wait();

return handle;
Expand Down
Loading
Loading