Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 338 additions & 0 deletions include/spblas/views/matrix_view.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
#pragma once
#include <spblas/views/view_base.hpp>
#include <type_traits>

namespace spblas {
namespace matrix_view {
namespace diag {
// always treat the diagonal value as unit
class implicit_unit {};

// always treat the diagonal value as zero
class implicit_zero {};

// use the matrix diagonal
class explicit_diag {};

template <typename d>
concept diag =
std::is_same_v<d, implicit_unit> || std::is_same_v<d, implicit_zero> ||
std::is_same_v<d, explicit_diag>;
} // namespace diag

namespace uplo {
// full matrix
class full {};

// take the lower triangular part of the matrix
class lower {};

// take the upper triangular part of the matrix
class upper {};

// take the diagonal part of the matrix
class diag {};

template <typename d>
concept uplo = std::is_same_v<d, full> || std::is_same_v<d, lower> ||
std::is_same_v<d, upper> || std::is_same_v<d, diag>;
} // namespace uplo

namespace __detail {

// help the diagonal transformation
// implicit_* will overwite the previous type
// explicit_diag keeps the type
template <typename T, typename U>
struct decide_diag {};

template <typename T>
struct decide_diag<T, diag::implicit_unit> {
using type = diag::implicit_unit;
};

template <typename T>
struct decide_diag<T, diag::implicit_zero> {
using type = diag::implicit_zero;
};

template <typename T>
struct decide_diag<T, diag::explicit_diag> {
using type = T;
};

} // namespace __detail

/**
* This is a view contain all possiblitity how kernel can intepret the matrix
* with specific order. This will not touch the matrix_opt itself, it will leave
* the operation to backend to decide what to do.
*
* @tparam Conjugate whether to conjugate the matrix
* @tparam Transpose whether to transpose the matrix
* @tparam Diagonal how to handle diagonal
* @tparam UpLo how to access the part of matrix
*/
template <typename matrix_opt, typename Conjugate = std::false_type,
typename Transpose = std::false_type,
Comment on lines +76 to +77
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was using none, conj, trans specific type in previous commit.
The current one use std::false_type for none and std::true_type for conj and trans.

diag::diag Diagonal = diag::explicit_diag,
uplo::uplo UpLo = uplo::full>
class general : public spblas::view_base {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other good name for that?
fixed_order, vendor, legacy, blas-like?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like something like "legacy_pattern" or "blas_pattern"

public:
general(matrix_opt&& t) : obj(t) {}

auto& base() {
return obj;
}

auto& base() const {
return obj;
}

private:
matrix_opt& obj;
};

template <typename matrix_opt>
auto conjugate(matrix_opt&& matrix) {
return general<matrix_opt, std::true_type>(matrix);
}

template <typename matrix_opt, typename Transpose, typename Diagonal,
typename UpLo>
auto conjugate(
general<matrix_opt, std::true_type, Transpose, Diagonal, UpLo>&& matrix) {
return general<matrix_opt, std::false_type, Transpose, Diagonal, UpLo>(
matrix.base());
}

template <typename Transpose, typename Diagonal, typename UpLo,
typename matrix_opt>
auto conjugate(
general<matrix_opt, std::false_type, Transpose, Diagonal, UpLo>&& matrix) {
return general<matrix_opt, std::true_type, Transpose, Diagonal, UpLo>(
matrix.base());
}

template <typename matrix_opt>
auto transpose(matrix_opt&& matrix) {
return general<matrix_opt, std::false_type, std::true_type>(matrix);
}

template <typename Conjugate, typename Diagonal, typename UpLo,
typename matrix_opt>
requires(!std::is_same_v<UpLo, uplo::diag>)
auto transpose(
general<matrix_opt, Conjugate, std::false_type, Diagonal, UpLo>&& matrix) {
return general<matrix_opt, Conjugate, std::true_type, Diagonal, UpLo>(
matrix.base());
}

template <typename Conjugate, typename Diagonal, typename UpLo,
typename matrix_opt>
requires(!std::is_same_v<UpLo, uplo::diag>)
auto transpose(
general<matrix_opt, Conjugate, std::true_type, Diagonal, UpLo>&& matrix) {
return general<matrix_opt, Conjugate, std::false_type, Diagonal, UpLo>(
matrix.base());
}

template <typename Conjugate, typename Diagonal, typename Transpose,
typename matrix_opt>
auto transpose(
general<matrix_opt, Conjugate, Transpose, Diagonal, uplo::diag>&& matrix) {
return general<matrix_opt, Conjugate, std::false_type, Diagonal, uplo::diag>(
matrix.base());
}

template <typename matrix_opt, typename TreatDiag = diag::explicit_diag>
auto diagonal(matrix_opt&& matrix, TreatDiag = {}) {
return general<matrix_opt, std::false_type, std::false_type, TreatDiag,
uplo::diag>(matrix);
}

template <typename Conjugate, typename Transpose, typename Diagonal,
typename UpLo, typename matrix_opt, typename TreatDiag>
auto diagonal(general<matrix_opt, Conjugate, Transpose,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
UpLo>&& matrix,
TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename matrix_opt, typename TreatDiag = diag::explicit_diag>
auto triangle(matrix_opt&& matrix, uplo::lower, TreatDiag = {}) {
return general<matrix_opt, std::false_type, std::false_type, TreatDiag,
uplo::lower>(matrix);
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::full>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::lower>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::lower>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::lower>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::upper>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::diag>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::full>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::true_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::upper>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::lower>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::upper>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::true_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::upper>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::diag>&& matrix,
uplo::lower, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename matrix_opt, typename TreatDiag = diag::explicit_diag>
auto triangle(matrix_opt&& matrix, uplo::upper, TreatDiag = {}) {
return general<matrix_opt, std::false_type, std::false_type, TreatDiag,
uplo::upper>(matrix);
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::full>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::upper>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::lower>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::upper>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::upper>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::false_type, Diagonal,
uplo::diag>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::full>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::true_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::lower>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::lower>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::true_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::lower>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::upper>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

template <typename Conjugate, typename Diagonal, typename matrix_opt,
typename TreatDiag = diag::explicit_diag>
auto triangle(general<matrix_opt, Conjugate, std::true_type, Diagonal,
uplo::diag>&& matrix,
uplo::upper, TreatDiag = {}) {
return general<matrix_opt, Conjugate, std::false_type,
typename __detail::decide_diag<Diagonal, TreatDiag>::type,
uplo::diag>(matrix.base());
}

} // namespace matrix_view
} // namespace spblas
3 changes: 2 additions & 1 deletion test/gtest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ if (SPBLAS_CPU_BACKEND)
add_test.cpp
transpose_test.cpp
triangular_solve_test.cpp
mdspan_overlays.cpp)
mdspan_overlays.cpp
matrix_view_test.cpp)

if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND)
list(APPEND TEST_SOURCES conjugate_test.cpp)
Expand Down
Loading
Loading