From d131c1b5b6037bba6717c945f1180c315eed620d Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 10 May 2026 18:16:35 -0400 Subject: [PATCH 1/4] make sure LQ pullback does not modify input --- src/pullbacks/lq.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 251fafa86..146f6145d 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -13,8 +13,8 @@ function check_and_prepare_lq_cotangents( size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q")) ΔQ₁ .= view(ΔQ, 1:p, 1:n) if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ + ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :] # extra columns in the case of qr_full Q₃ = view(Q, (minmn + 1):size(Q, 1), :) - ΔQ₃ = view(ΔQ, (minmn + 1):size(Q, 1), :) ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) Δgauge_Q = norm(ΔQ₃, Inf) From 4e19011facf794f62a1ada62df48bc7d67abdab5 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 10 May 2026 18:16:50 -0400 Subject: [PATCH 2/4] remove unnecessary line --- src/pullbacks/qr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 97bd45ae6..41c0f452b 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -15,7 +15,6 @@ function check_and_prepare_qr_cotangents( ΔQ₁ .= view(ΔQ, 1:m, 1:p) if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full - Q₁ = view(Q, :, 1:minmn) Q₃ = view(Q, :, (minmn + 1):size(Q, 2)) Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1) From 742074caa68b75cfde258966c3e18fa02eda9aac Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 11 May 2026 08:53:21 -0400 Subject: [PATCH 3/4] copy of view to make sure this is copied --- src/pullbacks/lq.jl | 2 +- src/pullbacks/qr.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 146f6145d..49fc305de 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -13,7 +13,7 @@ function check_and_prepare_lq_cotangents( size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q")) ΔQ₁ .= view(ΔQ, 1:p, 1:n) if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ - ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :] # extra columns in the case of qr_full + ΔQ₃ = copy(view(ΔQ, (minmn + 1):size(Q, 1), :)) # extra columns in the case of qr_full Q₃ = view(Q, (minmn + 1):size(Q, 1), :) ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 41c0f452b..fb90704c1 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -14,7 +14,7 @@ function check_and_prepare_qr_cotangents( size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q")) ΔQ₁ .= view(ΔQ, 1:m, 1:p) if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ - ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full + ΔQ₃ = copy(view(ΔQ, :, (minmn + 1):size(Q, 2))) # extra columns in the case of qr_full Q₃ = view(Q, :, (minmn + 1):size(Q, 2)) Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1) From b25c6c90cdee8a03ea04312e1d94d19b8c9f5f68 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 11 May 2026 09:00:12 -0400 Subject: [PATCH 4/4] add tests for not modifying output tangents in chainrules --- test/testsuite/chainrules.jl | 201 +++++++++++++++++++++++++++++++---- 1 file changed, 178 insertions(+), 23 deletions(-) diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 25b498413..af5604881 100755 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -63,50 +63,63 @@ function test_chainrules_qr( @testset "qr_compact" begin QR, ΔQR = ad_qr_compact_setup(A) ΔQ, ΔR = ΔQR + ΔQR_copy = deepcopy(ΔQR) test_rrule( cr_copy_qr_compact, A, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol ) + @test isequal(ΔQR, ΔQR_copy) test_rrule( config, qr_compact, A; fkwargs = (; positive = true), output_tangent = ΔQR, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔQR, ΔQR_copy) + ΔQ_copy = deepcopy(ΔQ) test_rrule( config, first ∘ qr_compact, A; fkwargs = (; positive = true), output_tangent = ΔQ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔQ, ΔQ_copy) + ΔR_copy = deepcopy(ΔR) test_rrule( config, last ∘ qr_compact, A; fkwargs = (; positive = true), output_tangent = ΔR, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔR, ΔR_copy) end @testset "qr_null" begin N, ΔN = ad_qr_null_setup(A) + ΔN_copy = deepcopy(ΔN) test_rrule( cr_copy_qr_null, A, alg ⊢ NoTangent(); output_tangent = ΔN, atol = atol, rtol = rtol ) + @test isequal(ΔN, ΔN_copy) test_rrule( config, qr_null, A; fkwargs = (; positive = true), output_tangent = ΔN, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔN, ΔN_copy) m, n = size(A) end @testset "qr_full" begin QR, ΔQR = ad_qr_full_setup(A) + ΔQR_copy = deepcopy(ΔQR) test_rrule( cr_copy_qr_full, A, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol ) + @test isequal(ΔQR, ΔQR_copy) test_rrule( config, qr_full, A; fkwargs = (; positive = true), output_tangent = ΔQR, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔQR, ΔQR_copy) m, n = size(A) end @testset "qr_compact - rank-deficient A" begin @@ -115,15 +128,18 @@ function test_chainrules_qr( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) QR, ΔQR = ad_qr_compact_setup(Ard) ΔQ, ΔR = ΔQR + ΔQR_copy = deepcopy(ΔQR) test_rrule( cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔQR, atol = atol, rtol = rtol ) + @test isequal(ΔQR, ΔQR_copy) test_rrule( config, qr_compact, Ard; fkwargs = (; positive = true), output_tangent = ΔQR, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔQR, ΔQR_copy) end end end @@ -142,64 +158,80 @@ function test_chainrules_lq( @testset "lq_compact" begin LQ, ΔLQ = ad_lq_compact_setup(A) ΔL, ΔQ = ΔLQ + ΔLQ_copy = deepcopy(ΔLQ) test_rrule( cr_copy_lq_compact, A, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol ) + @test isequal(ΔLQ, ΔLQ_copy) test_rrule( config, lq_compact, A; fkwargs = (; positive = true), output_tangent = ΔLQ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔLQ, ΔLQ_copy) + ΔL_copy = deepcopy(ΔL) test_rrule( config, first ∘ lq_compact, A; fkwargs = (; positive = true), output_tangent = ΔL, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔL, ΔL_copy) + ΔQ_copy = deepcopy(ΔQ) test_rrule( config, last ∘ lq_compact, A; fkwargs = (; positive = true), output_tangent = ΔQ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔQ, ΔQ_copy) end @testset "lq_null" begin Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + ΔNᴴ_copy = deepcopy(ΔNᴴ) test_rrule( cr_copy_lq_null, A, alg ⊢ NoTangent(); output_tangent = ΔNᴴ, atol = atol, rtol = rtol ) + @test isequal(ΔNᴴ, ΔNᴴ_copy) test_rrule( config, lq_null, A; fkwargs = (; positive = true), output_tangent = ΔNᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔNᴴ, ΔNᴴ_copy) end @testset "lq_full" begin LQ, ΔLQ = ad_lq_full_setup(A) + ΔLQ_copy = deepcopy(ΔLQ) test_rrule( cr_copy_lq_full, A, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol ) + @test isequal(ΔLQ, ΔLQ_copy) test_rrule( config, lq_full, A; fkwargs = (; positive = true), output_tangent = ΔLQ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔLQ, ΔLQ_copy) end @testset "lq_compact - rank-deficient A" begin m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) LQ, ΔLQ = ad_lq_compact_setup(Ard) + ΔLQ_copy = deepcopy(ΔLQ) test_rrule( cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol ) + @test isequal(ΔLQ, ΔLQ_copy) test_rrule( config, lq_compact, Ard; fkwargs = (; positive = true), output_tangent = ΔLQ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔLQ, ΔLQ_copy) end end end @@ -218,45 +250,60 @@ function test_chainrules_eig( @testset "eig_full" begin DV, ΔDV = ad_eig_full_setup(A) ΔD, ΔV = ΔDV + ΔDV_copy = deepcopy(ΔDV) test_rrule( cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol ) + @test isequal(ΔDV, ΔDV_copy) test_rrule( config, eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔDV, ΔDV_copy) + ΔD_copy = deepcopy(ΔD) test_rrule( config, first ∘ eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔD, ΔD_copy) + ΔV_copy = deepcopy(ΔV) test_rrule( config, last ∘ eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔV, ΔV_copy) end @testset "eig_vals" begin D, ΔD = ad_eig_vals_setup(A) + ΔD_copy = deepcopy(ΔD) test_rrule( cr_copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol ) + @test isequal(ΔD, ΔD_copy) test_rrule( config, eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔD, ΔD_copy) end @testset "eig_trunc" begin for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ot = (ΔDVtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDVtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol ) + @test isequal(ot, ot_copy) + ΔDVtrunc_copy = deepcopy(ΔDVtrunc) test_rrule( cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) + @test isequal(ΔDVtrunc, ΔDVtrunc_copy) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) Ddiag = diagview(DV[1]) p = sortperm(Ddiag, by = abs, rev = true) @@ -268,15 +315,20 @@ function test_chainrules_eig( end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ot = (ΔDVtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDVtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol ) + @test isequal(ot, ot_copy) + ΔDVtrunc_copy = deepcopy(ΔDVtrunc) test_rrule( cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) + @test isequal(ΔDVtrunc, ΔDVtrunc_copy) end end end @@ -296,32 +348,42 @@ function test_chainrules_eigh( @testset "eigh_full" begin DV, ΔDV = ad_eigh_full_setup(A) ΔD, ΔV = ΔDV + ΔDV_copy = deepcopy(ΔDV) test_rrule( cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol ) + @test isequal(ΔDV, ΔDV_copy) # eigh_full does not include a projector onto the Hermitian part of the matrix test_rrule( config, eigh_full ∘ Matrix ∘ Hermitian, A; output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔDV, ΔDV_copy) + ΔD_copy = deepcopy(ΔD) test_rrule( config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔD, ΔD_copy) + ΔV_copy = deepcopy(ΔV) test_rrule( config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔV, ΔV_copy) end @testset "eigh_vals" begin D, ΔD = ad_eigh_vals_setup(A) + ΔD_copy = deepcopy(ΔD) test_rrule( cr_copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol ) + @test isequal(ΔD, ΔD_copy) test_rrule( config, eigh_vals ∘ Matrix ∘ Hermitian, A; output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔD, ΔD_copy) end @testset "eigh_trunc" begin eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) @@ -329,15 +391,20 @@ function test_chainrules_eigh( for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ot = (ΔDVtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDVtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol ) + @test isequal(ot, ot_copy) + ΔDVtrunc_copy = deepcopy(ΔDVtrunc) test_rrule( cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) + @test isequal(ΔDVtrunc, ΔDVtrunc_copy) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) @@ -346,32 +413,42 @@ function test_chainrules_eigh( ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), trunc) truncalg = TruncatedAlgorithm(alg, trunc) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ot = (ΔDVtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔDVtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ot, ot_copy) + ΔDVtrunc_copy = deepcopy(ΔDVtrunc) test_rrule( config, eigh_trunc_no_error2, A; fkwargs = (; trunc = trunc), output_tangent = ΔDVtrunc, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔDVtrunc, ΔDVtrunc_copy) end D, ΔD = ad_eigh_vals_setup(A / 2) truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + ot = (ΔDVtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDVtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol ) + @test isequal(ot, ot_copy) + ΔDVtrunc_copy = deepcopy(ΔDVtrunc) test_rrule( cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) + @test isequal(ΔDVtrunc, ΔDVtrunc_copy) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -379,18 +456,23 @@ function test_chainrules_eigh( truncalg = TruncatedAlgorithm(alg, trunc) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + ot = (ΔDVtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔDVtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ot, ot_copy) + ΔDVtrunc_copy = deepcopy(ΔDVtrunc) test_rrule( config, eigh_trunc_no_error2, A; fkwargs = (; trunc = trunc), output_tangent = ΔDVtrunc, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔDVtrunc, ΔDVtrunc_copy) end end end @@ -408,53 +490,67 @@ function test_chainrules_svd( alg = MatrixAlgebraKit.default_svd_algorithm(A) @testset "svd_compact" begin USV, ΔUSVᴴ = ad_svd_compact_setup(A) + ΔUSVᴴ_copy = deepcopy(ΔUSVᴴ) test_rrule( cr_copy_svd_compact, A, alg ⊢ NoTangent(); output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol ) + @test isequal(ΔUSVᴴ, ΔUSVᴴ_copy) test_rrule( config, svd_compact, A, alg ⊢ NoTangent(); output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔUSVᴴ, ΔUSVᴴ_copy) end @testset "svd_full" begin USV, ΔUSVᴴ = ad_svd_full_setup(A) + ΔUSVᴴ_copy = deepcopy(ΔUSVᴴ) test_rrule( cr_copy_svd_full, A, alg ⊢ NoTangent(); output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol ) + @test isequal(ΔUSVᴴ, ΔUSVᴴ_copy) test_rrule( config, svd_full, A, alg ⊢ NoTangent(); output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔUSVᴴ, ΔUSVᴴ_copy) end @testset "svd_vals" begin S, ΔS = ad_svd_vals_setup(A) + ΔS_copy = deepcopy(ΔS) test_rrule( cr_copy_svd_vals, A, alg ⊢ NoTangent(); output_tangent = ΔS, atol, rtol ) + @test isequal(ΔS, ΔS_copy) test_rrule( config, svd_vals, A, alg ⊢ NoTangent(); output_tangent = ΔS, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔS, ΔS_copy) end @testset "svd_trunc" begin @testset for r in 1:4:minmn truncalg = TruncatedAlgorithm(alg, truncrank(r)) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ot = (ΔUSVᴴtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol ) + @test isequal(ot, ot_copy) + ΔUSVᴴtrunc_copy = deepcopy(ΔUSVᴴtrunc) test_rrule( cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol ) + @test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy) U, S, Vᴴ = USVᴴ ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) Strunc = Diagonal(diagview(S)[ind]) @@ -466,31 +562,41 @@ function test_chainrules_svd( @test isapprox(dA1, dA2; atol = atol, rtol = rtol) trunc = truncrank(r) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + ot = (ΔUSVᴴtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ot, ot_copy) + ΔUSVᴴtrunc_copy = deepcopy(ΔUSVᴴtrunc) test_rrule( config, svd_trunc_no_error, A; fkwargs = (; trunc = trunc), output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy) end S, ΔS = ad_svd_vals_setup(A) truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ot = (ΔUSVᴴtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol ) + @test isequal(ot, ot_copy) + ΔUSVᴴtrunc_copy = deepcopy(ΔUSVᴴtrunc) test_rrule( cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol ) + @test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy) U, S, Vᴴ = USVᴴ ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) Strunc = Diagonal(diagview(S)[ind]) @@ -501,18 +607,23 @@ function test_chainrules_svd( @test isapprox(dA1, dA2; atol = atol, rtol = rtol) trunc = trunctol(; atol = S[1, 1] / 2) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + ot = (ΔUSVᴴtrunc..., zero(real(T))) + ot_copy = deepcopy(ot) test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + output_tangent = ot, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ot, ot_copy) + ΔUSVᴴtrunc_copy = deepcopy(ΔUSVᴴtrunc) test_rrule( config, svd_trunc_no_error, A; fkwargs = (; trunc = trunc), output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy) end end end @@ -530,20 +641,36 @@ function test_chainrules_polar( alg = MatrixAlgebraKit.default_polar_algorithm(A) @testset "left_polar" begin if m >= n - test_rrule(cr_copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + _, ΔWP = ad_left_polar_setup(A) + ΔWP_copy = deepcopy(ΔWP) + test_rrule( + cr_copy_left_polar, A, alg ⊢ NoTangent(); + output_tangent = ΔWP, atol = atol, rtol = rtol + ) + @test isequal(ΔWP, ΔWP_copy) test_rrule( config, left_polar, A, alg ⊢ NoTangent(); + output_tangent = ΔWP, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔWP, ΔWP_copy) end end @testset "right_polar" begin if m <= n - test_rrule(cr_copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + _, ΔPWᴴ = ad_right_polar_setup(A) + ΔPWᴴ_copy = deepcopy(ΔPWᴴ) + test_rrule( + cr_copy_right_polar, A, alg ⊢ NoTangent(); + output_tangent = ΔPWᴴ, atol = atol, rtol = rtol + ) + @test isequal(ΔPWᴴ, ΔPWᴴ_copy) test_rrule( config, right_polar, A, alg ⊢ NoTangent(); + output_tangent = ΔPWᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔPWᴴ, ΔPWᴴ_copy) end end end @@ -561,43 +688,65 @@ function test_chainrules_orthnull( config = Zygote.ZygoteRuleConfig() N, ΔN = ad_left_null_setup(A) Nᴴ, ΔNᴴ = ad_right_null_setup(A) + _, ΔVC = ad_left_orth_setup(A) + ΔVC_copy = deepcopy(ΔVC) test_rrule( config, left_orth, A; + output_tangent = ΔVC, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔVC, ΔVC_copy) test_rrule( config, left_orth, A; - fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + fkwargs = (; alg = :qr), output_tangent = ΔVC, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) - m >= n && + @test isequal(ΔVC, ΔVC_copy) + if m >= n test_rrule( - config, left_orth, A; - fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) + config, left_orth, A; + fkwargs = (; alg = :polar), output_tangent = ΔVC, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + @test isequal(ΔVC, ΔVC_copy) + end + ΔN_copy = deepcopy(ΔN) test_rrule( config, left_null, A; fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔN, ΔN_copy) + _, ΔCVᴴ = ad_right_orth_setup(A) + ΔCVᴴ_copy = deepcopy(ΔCVᴴ) test_rrule( config, right_orth, A; + output_tangent = ΔCVᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔCVᴴ, ΔCVᴴ_copy) test_rrule( config, right_orth, A; fkwargs = (; alg = :lq), + output_tangent = ΔCVᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) - m <= n && + @test isequal(ΔCVᴴ, ΔCVᴴ_copy) + if m <= n test_rrule( - config, right_orth, A; fkwargs = (; alg = :polar), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) + config, right_orth, A; fkwargs = (; alg = :polar), + output_tangent = ΔCVᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + @test isequal(ΔCVᴴ, ΔCVᴴ_copy) + end + ΔNᴴ_copy = deepcopy(ΔNᴴ) test_rrule( config, right_null, A; fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + @test isequal(ΔNᴴ, ΔNᴴ_copy) end end @@ -613,11 +762,17 @@ function test_chainrules_projections( if m == n @testset "project_hermitian" begin alg = MatrixAlgebraKit.default_hermitian_algorithm(A) - test_rrule(project_hermitian, A, alg; atol, rtol) + Δ = randn!(similar(A)) + Δ_copy = deepcopy(Δ) + test_rrule(project_hermitian, A, alg; output_tangent = Δ, atol, rtol) + @test isequal(Δ, Δ_copy) end @testset "project_antihermitian" begin alg = MatrixAlgebraKit.default_hermitian_algorithm(A) - test_rrule(project_antihermitian, A, alg; atol, rtol) + Δ = randn!(similar(A)) + Δ_copy = deepcopy(Δ) + test_rrule(project_antihermitian, A, alg; output_tangent = Δ, atol, rtol) + @test isequal(Δ, Δ_copy) end end end