diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 4c54f3c0c..585278242 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -47,4 +47,9 @@ Base.@deprecate( schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...)) ) +function MatrixAlgebraKit.default_exponential_algorithm(E::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + eig_alg = MatrixAlgebraKit.default_eig_algorithm(E; kwargs...) + return MatrixFunctionViaEig(eig_alg) +end + end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 22bf79e9c..946b92df1 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -30,6 +30,7 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! +export exponential, exponential!, exponentialr, exponentialr! export Householder, Native_HouseholderQR, Native_HouseholderLQ export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar @@ -40,6 +41,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton +export MatrixFunctionViaLA, MatrixFunctionViaEig, MatrixFunctionViaEigh export DefaultAlgorithm export DiagonalAlgorithm export NativeBlocked @@ -95,9 +97,12 @@ include("common/matrixproperties.jl") include("yalapack.jl") include("algorithms.jl") + include("interface/projections.jl") include("interface/decompositions.jl") include("interface/truncation.jl") +include("interface/matrixfunctions.jl") + include("interface/qr.jl") include("interface/lq.jl") include("interface/svd.jl") @@ -107,6 +112,7 @@ include("interface/gen_eig.jl") include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") +include("interface/exponential.jl") include("implementations/projections.jl") include("implementations/truncation.jl") @@ -119,6 +125,7 @@ include("implementations/gen_eig.jl") include("implementations/schur.jl") include("implementations/polar.jl") include("implementations/orthnull.jl") +include("implementations/exponential.jl") include("common/gauge.jl") # needs to be defined after the functions are diff --git a/src/common/view.jl b/src/common/view.jl index 8cd989ea0..61c8da8b8 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -20,6 +20,26 @@ See also [`diagview`](@ref). diagonal(v::AbstractVector) = Diagonal(v) +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, returning +a diagonal result. + +See also [`map_diagonal!`](@ref). +""" +map_diagonal(f, src, srcs...) = diagonal(f.(diagview(src), map(diagview, srcs)...)) + +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, +into the diagonal elements of destination `dst`. + +See also [`map_diagonal`](@ref). +""" +map_diagonal!(f, dst, src, srcs...) = (diagview(dst) .= f.(diagview(src), map(diagview, srcs)...); dst) + # triangularind function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl new file mode 100644 index 000000000..879d14fa2 --- /dev/null +++ b/src/implementations/exponential.jl @@ -0,0 +1,141 @@ +# Inputs +# ------ +function copy_input(::typeof(exponential), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) +end + +copy_input(::typeof(exponential), A::Diagonal) = copy(A) + +function copy_input(::typeof(exponentialr), τ::Number, A::AbstractMatrix) + return τ, copy!(similar(A, float(eltype(A))), A) +end + +copy_input(::typeof(exponentialr), τ::Number, A::Diagonal) = τ, copy(A) + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + @check_size(expA, (m, m)) + return @check_scalar(expA, A) +end + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + if !ishermitian(A) + throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) + end + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + @check_size(expA, (m, m)) + return @check_scalar(expA, A) +end + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert expA isa Diagonal + @check_size(expA, (m, m)) + @check_scalar(expA, A) + return nothing +end + +function check_input(::typeof(exponentialr!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + return @check_size(expA, (m, m)) +end + +function check_input(::typeof(exponentialr!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + if !ishermitian(A) + throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) + end + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + return @check_size(expA, (m, m)) +end + +function check_input(::typeof(exponentialr!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert expA isa Diagonal + return @check_size(expA, (m, m)) +end + +# Outputs +# ------- +initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) = A +initialize_output(::typeof(exponentialr!), τ::T, A::AbstractMatrix, ::AbstractAlgorithm) where {T <: Real} = A +initialize_output(::typeof(exponentialr!), τ::Number, A::AbstractMatrix, ::AbstractAlgorithm) = + complex(A) + +# Implementation +# -------------- +function exponential!(A, expA, alg::MatrixFunctionViaLA) + check_input(exponential!, A, expA, alg) + return LinearAlgebra.exp!(A) +end + +function exponential!(A, expA, alg::MatrixFunctionViaEigh) + check_input(exponential!, A, expA, alg) + D, V = eigh_full!(A, alg.eigh_alg) + expD = map_diagonal!(x -> exp(x / 2), D, D) + VexpD = rmul!(V, expD) + return mul!(expA, VexpD, V') +end + +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) + check_input(exponential!, A, expA, alg) + D, V = eig_full!(A, alg.eig_alg) + expD = map_diagonal!(exp, D, D) + iV = inv(V) + VexpD = rmul!(V, expD) + if eltype(A) <: Real + expA .= real.(VexpD * iV) + else + mul!(expA, VexpD, iV) + end + return expA +end + +function exponentialr!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaLA) + check_input(exponentialr!, A, expA, alg) + expA .= A .* τ + return LinearAlgebra.exp!(expA) +end + +function exponentialr!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + check_input(exponentialr!, A, expA, alg) + D, V = eigh_full!(A, alg.eigh_alg) + expD = map_diagonal(x -> exp(x * τ), D) + VexpD = V * expD + if eltype(A) <: Real && eltype(τ) <: Real + return expA .= real.(VexpD * V') + else + return mul!(expA, VexpD, V') + end +end + +function exponentialr!(τ::Number, A, expA, alg::MatrixFunctionViaEig) + check_input(exponentialr!, A, expA, alg) + D, V = eig_full!(A, alg.eig_alg) + expD = map_diagonal!(x -> exp(x * τ), D, D) + iV = inv(V) + VexpD = rmul!(V, expD) + if eltype(A) <: Real && eltype(τ) <: Real + expA .= real.(VexpD * iV) + return expA + else + return mul!(expA, VexpD, iV) + end +end + +# Diagonal logic +# -------------- +function exponential!(A, expA, alg::DiagonalAlgorithm) + check_input(exponential!, A, expA, alg) + return map_diagonal!(exp, expA, A) +end + +function exponentialr!(τ::Number, A, expA, alg::DiagonalAlgorithm) + check_input(exponentialr!, A, expA, alg) + return map_diagonal!(x -> exp(x * τ), expA, A) +end diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl new file mode 100644 index 000000000..01baa6269 --- /dev/null +++ b/src/interface/exponential.jl @@ -0,0 +1,60 @@ +# Exponential functions +# -------------- + +""" + exponential(A; kwargs...) -> expA + exponential(A, alg::AbstractAlgorithm) -> expA + exponential!(A, [expA]; kwargs...) -> expA + exponential!(A, [expA], alg::AbstractAlgorithm) -> expA + +Compute the exponential of the square matrix `A`, + +!!! note + The bang method `exponential!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `expA` as output. + +See also [`exponentialr(!)`](@ref exponentialr). +""" +@functiondef exponential + +""" + exponentialr(τ, A; kwargs...) -> expiτA + exponentialr(τ, A, alg::AbstractAlgorithm) -> expiτA + exponentialr!(τ, A, [expiτA]; kwargs...) -> expiτA + exponentialr!(τ, A, [expiτA], alg::AbstractAlgorithm) -> expiτA + +Compute the exponential of `i*τ*A`, where `i` is the imaginary unit, `τ` is a scalar, and `A` is a square matrix. +This allows the user to use the hermitian eigendecomposition when `A` is hermitian, even when `i*τ*A` is not. + +!!! note + The bang method `exponentialr!` optionally accepts the output structure and + possibly destroys the input matrix `A`. + Always use the return value of the function as it may not always be + possible to use the provided `expiτA` as output. + +See also [`exponential(!)`](@ref exponential). +""" +@functiondef n_args = 2 exponentialr + +# Algorithm selection +# ------------------- +default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) +function default_exponential_algorithm(T::Type; kwargs...) + return MatrixFunctionViaLA(; kwargs...) +end +function default_exponential_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} + return DiagonalAlgorithm(; kwargs...) +end + +for f in (:exponential!,) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_exponential_algorithm(A; kwargs...) + end +end + +for f in (:exponentialr!,) + @eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B} + return default_exponential_algorithm(B; kwargs...) + end +end diff --git a/src/interface/matrixfunctions.jl b/src/interface/matrixfunctions.jl new file mode 100644 index 000000000..ce24652de --- /dev/null +++ b/src/interface/matrixfunctions.jl @@ -0,0 +1,39 @@ +# ================================ +# EXPONENTIAL ALGORITHMS +# ================================ +""" + MatrixFunctionViaLA() + +Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`. +""" +@algdef MatrixFunctionViaLA + +""" + MatrixFunctionViaEigh() + +Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`. +The `eigh_alg` specifies which hermitian eigendecomposition implementation to use. +""" +struct MatrixFunctionViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm + eigh_alg::A +end +function Base.show(io::IO, alg::MatrixFunctionViaEigh) + print(io, "MatrixFunctionViaEigh(") + _show_alg(io, alg.eigh_alg) + return print(io, ")") +end + +""" + MatrixFunctionViaEig() + +Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`. +The `eig_alg` specifies which eigendecomposition implementation to use. +""" +struct MatrixFunctionViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm + eig_alg::A +end +function Base.show(io::IO, alg::MatrixFunctionViaEig) + print(io, "MatrixFunctionViaEig(") + _show_alg(io, alg.eig_alg) + return print(io, ")") +end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl deleted file mode 100644 index 8b1378917..000000000 --- a/src/matrixfunctions.jl +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/exponential.jl b/test/exponential.jl new file mode 100644 index 000000000..13b5fba0f --- /dev/null +++ b/test/exponential.jl @@ -0,0 +1,88 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra +using LinearAlgebra: exp + +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = LinearAlgebra.normalize!(randn(rng, T, m, m)) + Ac = copy(A) + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac + + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expA2 = @constinferred exponential(A, alg) + @test expA ≈ expA2 + @test A == Ac + end + + @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + +@testset "exponentialr! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + τ = randn(rng, T) + Ac = copy(A) + + Aτ = A * τ + expAτ = LinearAlgebra.exp(Aτ) + + expAτ2 = @constinferred exponentialr(τ, A) + @test expAτ ≈ expAτ2 + @test A == Ac + + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expAτ2 = @constinferred exponentialr(τ, A, alg) + @test expAτ ≈ expAτ2 + @test A == Ac + end + + @test_throws DomainError exponentialr(τ, A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + +@testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + m = 54 + + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac +end + +@testset "exponentialr! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + m = 1 + + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + Aτ = A * τ + expAτ = LinearAlgebra.exp(Aτ) + + expAτ2 = @constinferred exponentialr(τ, A) + @test expAτ ≈ expAτ2 + @test A == Ac +end diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl new file mode 100644 index 000000000..79ccef803 --- /dev/null +++ b/test/genericlinearalgebra/exponential.jl @@ -0,0 +1,47 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expA = @constinferred exponential!(copy(A); alg) + expA2 = @constinferred exponential(A; alg) + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eigh_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + end +end + +using GenericSchur +@testset "exponentialr! for T1 = $T1, T2 = $T2" for T1 in GenericFloats, T2 in GenericFloats + rng = StableRNG(123) + m = 54 + A = randn(rng, T1, m, m) + A = (A + A') / 2 + τ = randn(rng, T2) + + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expτA = @constinferred exponentialr!(τ, copy(A); alg) + expτA2 = @constinferred exponentialr(τ, A; alg) + @test expτA2 ≈ expτA + + Dexp, Vexp = @constinferred eig_full(expτA) + + @test sort(diagview(Dexp); by = real) ≈ sort(LinearAlgebra.exp.(diagview(D) .* τ); by = real) + end +end diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl new file mode 100644 index 000000000..ffe57a2a9 --- /dev/null +++ b/test/genericschur/exponential.jl @@ -0,0 +1,46 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + expA_LA = @constinferred exponential(A) + @testset "algorithm $alg" for alg in algs + expA = @constinferred exponential!(copy(A)) + expA2 = @constinferred exponential(A; alg = alg) + @test expA ≈ expA_LA + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eig_full(expA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D)); by = imag) + end +end + +@testset "exponentialr! for T1 = $T1, T2 = $T2" for T1 in GenericFloats, T2 in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T1, m, m) + τ = randn(rng, T2) + + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expτA = @constinferred exponentialr!(τ, copy(A)) + expτA2 = @constinferred exponentialr(τ, A; alg) + @test expτA2 ≈ expτA + + Dexp, Vexp = @constinferred eig_full(expτA) + @test sort(diagview(Dexp); by = x -> (imag(x), real(x))) ≈ sort(LinearAlgebra.exp.(diagview(D) .* τ); by = x -> (imag(x), real(x))) + end +end