diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl index 949129ea..7d985836 100644 --- a/test/enzyme/eig.jl +++ b/test/enzyme/eig.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...) AT = Diagonal{T, Vector{T}} TestSuite.test_enzyme_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_eig(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl index 64c796fc..024b2269 100644 --- a/test/enzyme/eigh.jl +++ b/test/enzyme/eigh.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...) AT = Diagonal{T, Vector{T}} TestSuite.test_enzyme_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_eigh(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/lq.jl b/test/enzyme/lq.jl index 7c747529..c10aadcb 100644 --- a/test/enzyme/lq.jl +++ b/test/enzyme/lq.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) AT = Diagonal{T, Vector{T}} m == n && TestSuite.test_enzyme_lq(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_lq(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/orthnull.jl b/test/enzyme/orthnull.jl index 086873d3..aa8ac967 100644 --- a/test/enzyme/orthnull.jl +++ b/test/enzyme/orthnull.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) AT = Diagonal{T, Vector{T}} m == n && TestSuite.test_enzyme_orthnull(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_orthnull(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/polar.jl b/test/enzyme/polar.jl index 183086ad..1843ecd5 100644 --- a/test/enzyme/polar.jl +++ b/test/enzyme/polar.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) AT = Diagonal{T, Vector{T}} #m == n && TestSuite.test_enzyme_polar(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_polar(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/projections.jl b/test/enzyme/projections.jl index 52b222a5..cf3d9070 100644 --- a/test/enzyme/projections.jl +++ b/test/enzyme/projections.jl @@ -18,4 +18,8 @@ for T in (BLASFloats..., GenericFloats...) TestSuite.test_enzyme_projections(T, (m, m); atol, rtol) TestSuite.test_enzyme_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol) end + if CUDA.functional() + TestSuite.test_enzyme_projections(CuMatrix{T}, (m, n); atol, rtol) + TestSuite.test_enzyme_projections(Diagonal{T, CuVector{T}}, (m, m); atol, rtol) + end end diff --git a/test/enzyme/qr.jl b/test/enzyme/qr.jl index 2d8b9e7e..3d3116a1 100644 --- a/test/enzyme/qr.jl +++ b/test/enzyme/qr.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) AT = Diagonal{T, Vector{T}} m == n && TestSuite.test_enzyme_qr(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_qr(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index e4aaa7aa..29e2963e 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -18,4 +18,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) AT = Diagonal{T, Vector{T}} m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if CUDA.functional() + TestSuite.test_enzyme_svd(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end end