diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index a0e60694..b2bbded4 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -18,4 +18,14 @@ for T in (BLASFloats..., GenericFloats...) AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_eig(CuMatrix{T}, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end + if AMDGPU.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_eig(ROCMatrix{T}, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, ROCVector{T}} + TestSuite.test_mooncake_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index e39f6831..dce1eee3 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -18,4 +18,14 @@ for T in (BLASFloats..., GenericFloats...) AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_eigh(CuMatrix{T}, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end + if AMDGPU.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_eigh(ROCMatrix{T}, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, ROCVector{T}} + TestSuite.test_mooncake_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 42d0fdb6..266aa98c 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -20,4 +20,18 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.test_mooncake_lq(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end + if CUDA.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_lq(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_lq(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end + end + if AMDGPU.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_lq(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, ROCVector{T}} + TestSuite.test_mooncake_lq(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end + end end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl index f096fdb8..c2aaefb2 100644 --- a/test/mooncake/svd.jl +++ b/test/mooncake/svd.jl @@ -20,4 +20,18 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end + if CUDA.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_svd(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end + end + if AMDGPU.functional() && T ∈ BLASFloats + TestSuite.test_mooncake_svd(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, ROCVector{T}} + TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end + end end