diff --git a/include/spblas/views/matrix_view.hpp b/include/spblas/views/matrix_view.hpp new file mode 100644 index 0000000..cca047b --- /dev/null +++ b/include/spblas/views/matrix_view.hpp @@ -0,0 +1,338 @@ +#pragma once +#include +#include + +namespace spblas { +namespace matrix_view { +namespace diag { +// always treat the diagonal value as unit +class implicit_unit {}; + +// always treat the diagonal value as zero +class implicit_zero {}; + +// use the matrix diagonal +class explicit_diag {}; + +template +concept diag = + std::is_same_v || std::is_same_v || + std::is_same_v; +} // namespace diag + +namespace uplo { +// full matrix +class full {}; + +// take the lower triangular part of the matrix +class lower {}; + +// take the upper triangular part of the matrix +class upper {}; + +// take the diagonal part of the matrix +class diag {}; + +template +concept uplo = std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; +} // namespace uplo + +namespace __detail { + +// help the diagonal transformation +// implicit_* will overwite the previous type +// explicit_diag keeps the type +template +struct decide_diag {}; + +template +struct decide_diag { + using type = diag::implicit_unit; +}; + +template +struct decide_diag { + using type = diag::implicit_zero; +}; + +template +struct decide_diag { + using type = T; +}; + +} // namespace __detail + +/** + * This is a view contain all possiblitity how kernel can intepret the matrix + * with specific order. This will not touch the matrix_opt itself, it will leave + * the operation to backend to decide what to do. + * + * @tparam Conjugate whether to conjugate the matrix + * @tparam Transpose whether to transpose the matrix + * @tparam Diagonal how to handle diagonal + * @tparam UpLo how to access the part of matrix + */ +template +class general : public spblas::view_base { +public: + general(matrix_opt&& t) : obj(t) {} + + auto& base() { + return obj; + } + + auto& base() const { + return obj; + } + +private: + matrix_opt& obj; +}; + +template +auto conjugate(matrix_opt&& matrix) { + return general(matrix); +} + +template +auto conjugate( + general&& matrix) { + return general( + matrix.base()); +} + +template +auto conjugate( + general&& matrix) { + return general( + matrix.base()); +} + +template +auto transpose(matrix_opt&& matrix) { + return general(matrix); +} + +template + requires(!std::is_same_v) +auto transpose( + general&& matrix) { + return general( + matrix.base()); +} + +template + requires(!std::is_same_v) +auto transpose( + general&& matrix) { + return general( + matrix.base()); +} + +template +auto transpose( + general&& matrix) { + return general( + matrix.base()); +} + +template +auto diagonal(matrix_opt&& matrix, TreatDiag = {}) { + return general(matrix); +} + +template +auto diagonal(general::type, + UpLo>&& matrix, + TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(matrix_opt&& matrix, uplo::lower, TreatDiag = {}) { + return general(matrix); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::lower>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::lower>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::upper>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::upper>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::lower, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(matrix_opt&& matrix, uplo::upper, TreatDiag = {}) { + return general(matrix); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::upper>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::upper>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::lower>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::lower>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +template +auto triangle(general&& matrix, + uplo::upper, TreatDiag = {}) { + return general::type, + uplo::diag>(matrix.base()); +} + +} // namespace matrix_view +} // namespace spblas diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index b458a17..79ec433 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -12,7 +12,8 @@ if (SPBLAS_CPU_BACKEND) add_test.cpp transpose_test.cpp triangular_solve_test.cpp - mdspan_overlays.cpp) + mdspan_overlays.cpp + matrix_view_test.cpp) if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND) list(APPEND TEST_SOURCES conjugate_test.cpp) diff --git a/test/gtest/matrix_view_test.cpp b/test/gtest/matrix_view_test.cpp new file mode 100644 index 0000000..81cc994 --- /dev/null +++ b/test/gtest/matrix_view_test.cpp @@ -0,0 +1,230 @@ +#include +#include +#include + +namespace { + +class temp {}; + +using ::testing::StaticAssertTypeEq; +// only for the testing +using namespace spblas::matrix_view; + +TEST(Tag, Conjugate) { + temp t; + + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + EXPECT_EQ(&(conjugate(t).base()), &t); + EXPECT_EQ(&(conjugate(conjugate(t)).base()), &t); +} + +TEST(Tag, Tranpose) { + temp t; + + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + EXPECT_EQ(&(transpose(t).base()), &t); + EXPECT_EQ(&(transpose(transpose(t)).base()), &t); +} + +TEST(Tag, Diagonal) { + temp t; + + StaticAssertTypeEq>(); + EXPECT_EQ(&(diagonal(t).base()), &t); +} + +TEST(Tag, Lower) { + temp t; + + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + // Unit or Zero Diag will overwrite the old one + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + + EXPECT_EQ(&(triangle(t, uplo::lower()).base()), &t); + EXPECT_EQ(&(triangle(t, uplo::lower(), diag::implicit_zero()).base()), &t); + EXPECT_EQ(&(triangle(t, uplo::lower(), diag::implicit_unit()).base()), &t); + EXPECT_EQ(&(triangle(triangle(t, uplo::lower(), diag::implicit_zero()), + uplo::lower(), diag::implicit_unit()) + .base()), + &t); + EXPECT_EQ(&(triangle(triangle(t, uplo::lower(), diag::implicit_unit()), + uplo::lower(), diag::implicit_zero()) + .base()), + &t); +} + +TEST(Tag, Upper) { + temp t; + + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + // Unit or Zero Diag will overwrite the old one + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + EXPECT_EQ(&(triangle(t, uplo::upper()).base()), &t); + EXPECT_EQ(&(triangle(t, uplo::upper(), diag::implicit_zero()).base()), &t); + EXPECT_EQ(&(triangle(t, uplo::upper(), diag::implicit_unit()).base()), &t); + EXPECT_EQ(&(triangle(triangle(t, uplo::upper(), diag::implicit_zero()), + uplo::upper(), diag::implicit_unit()) + .base()), + &t); + EXPECT_EQ(&(triangle(triangle(t, uplo::upper(), diag::implicit_unit()), + uplo::upper(), diag::implicit_zero()) + .base()), + &t); +} + +TEST(Tag, MixUpperAndLower) { + temp t; + + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); +} + +TEST(Tag, GetTransposeOfUpper) { + temp t; + + StaticAssertTypeEq>(); + // Lower(M^T) = Upper(M)^T + StaticAssertTypeEq>(); +} + +TEST(Tag, GetTransposeOfLower) { + temp t; + + StaticAssertTypeEq>(); + // Upper(M^T) = Lower(M)^T + StaticAssertTypeEq>(); +} + +TEST(Tag, LongChain) { + temp t; + + StaticAssertTypeEq>(); + StaticAssertTypeEq>(); + StaticAssertTypeEq< + decltype(transpose(conjugate( + triangle(transpose(triangle(t, uplo::lower(), diag::implicit_zero())), + uplo::lower(), diag::implicit_unit())))), + general>(); +} + +} // namespace