1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ authors = ["Jutho Haegeman <jutho.haegeman@ugent.be>, Lukas Devos, Katharine Hya
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -53,6 +53,7 @@ makedocs(;
"user_interface/decompositions.md",
"user_interface/algorithms.md",
"user_interface/truncations.md",
"user_interface/sketching.md",
"user_interface/properties.md",
"user_interface/matrix_functions.md",
],
4 changes: 4 additions & 0 deletions docs/src/changelog.md
@@ -22,10 +22,14 @@ When releasing a new version, move the "Unreleased" changes to a new version sec

### Added

- Randomized sketching API: new `SketchedAlgorithm` wrapper, `SketchingStrategy` supertype with the `GaussianSketching` strategy, and standalone `left_sketch` / `right_sketch` (plus bang variants) entry points for low-rank factorizations. `svd_trunc` / `svd_trunc!` / `svd_trunc_no_error` accept a new `sketch =` keyword that mirrors the existing `trunc =` pattern. The CUDA extension routes `SketchedAlgorithm(; driver = CUSOLVER())` to cuSOLVER's fused `gesvdr` kernel ([#225](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/225)). See [Sketching](@ref) for details.

### Changed

### Deprecated

- `CUSOLVER_Randomized` is deprecated in favour of `SketchedAlgorithm(; driver = CUSOLVER())` or the new `sketch =` keyword on `svd_trunc`. Calls using `CUSOLVER_Randomized` now emit a deprecation warning at algorithm selection and are forwarded to the new pipeline ([#225](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/225)).

### Removed

### Fixed
1 change: 1 addition & 0 deletions docs/src/dev_interface.md
@@ -12,4 +12,5 @@ MatrixAlgebraKit.default_algorithm
MatrixAlgebraKit.select_algorithm
MatrixAlgebraKit.findtruncated
MatrixAlgebraKit.findtruncated_svd
MatrixAlgebraKit.select_sketching
```
17 changes: 17 additions & 0 deletions docs/src/user_interface/algorithms.md
@@ -98,6 +98,23 @@ They all accept an optional `driver` keyword to select the computational backend

For full docstring details on each algorithm type, see the corresponding section in [Decompositions](@ref).

### Algorithm Wrappers

In addition to the terminal algorithms above, MatrixAlgebraKit provides algorithm *wrappers* that compose a terminal algorithm with additional behaviour rather than computing a factorization themselves.
Their `alg` field is one of the algorithms in the table above; the wrapper adds either a post-truncation step or a sketching step (or both).

| Wrapper | Applicable decompositions | Key keyword arguments |
|:--------|:--------------------------|:----------------------|
| [`TruncatedAlgorithm`](@ref) | `svd_trunc`, `eigh_trunc`, `eig_trunc` | `alg`, `trunc` |
| [`SketchedAlgorithm`](@ref) | `svd_trunc` | `sketch`, `trunc`, `alg`, `driver` |

```@docs; canonical=false
MatrixAlgebraKit.TruncatedAlgorithm
MatrixAlgebraKit.SketchedAlgorithm
```

The corresponding truncated-decomposition functions ([`svd_trunc`](@ref), [`eigh_trunc`](@ref), [`eig_trunc`](@ref)) also accept `trunc =` and (for SVD) `sketch =` keyword arguments that construct the appropriate wrapper automatically; see [Truncations](@ref) and [Sketching](@ref) for the full interface.
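As a quick illustration (the sketch dimension and rank below are illustrative values; see the linked pages for the authoritative interface), the keyword form and the explicit wrapper form of a sketched truncated SVD are intended to be interchangeable:

```julia
using MatrixAlgebraKit

A = randn(100, 60)

# keyword form: the SketchedAlgorithm wrapper is constructed automatically
U, S, Vᴴ, ϵ = svd_trunc(A; sketch = (; howmany = 10), trunc = truncrank(5))

# explicit form: construct the wrapper yourself and pass it as the algorithm
alg = SketchedAlgorithm(; sketch = GaussianSketching(10), trunc = truncrank(5))
U′, S′, Vᴴ′, ϵ′ = svd_trunc(A, alg)
```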

## [Driver Selection](@id sec_driverselection)

!!! note "Expert use case"
2 changes: 2 additions & 0 deletions docs/src/user_interface/decompositions.md
@@ -140,6 +140,8 @@ It is also possible to compute the singular values only, using the [`svd_vals`](
This then returns a vector of the values on the diagonal of `Σ`.

Finally, we also support computing a partial or truncated SVD, using the [`svd_trunc`](@ref) function.
To control the behavior of the truncation, we refer to [Truncations](@ref) for more information.
Furthermore, for large matrices with rapidly decaying spectra, the truncated SVD can additionally be computed via randomized sketching; see [Sketching](@ref).

```@docs; canonical=false
svd_full
152 changes: 152 additions & 0 deletions docs/src/user_interface/sketching.md
@@ -0,0 +1,152 @@
```@meta
CurrentModule = MatrixAlgebraKit
CollapsedDocStrings = true
```

# Sketching

*Sketching* methods project a large matrix onto a low-dimensional subspace before any expensive dense factorization is performed.
The result is an approximate low-rank decomposition that can be substantially cheaper than the corresponding full decomposition, depending on the quality of the sketch and the spectrum of the original matrix.

This page first describes how to compute a sketch as a standalone operation, then how to plug a sketch into a partial decomposition such as a truncated SVD.
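To make the idea concrete, here is a minimal plain-Julia version of the classic randomized range finder (an illustration of the general technique, not MatrixAlgebraKit's implementation):

```julia
using LinearAlgebra, Random

# Minimal randomized range finder: multiply A by a random Gaussian test
# matrix Ω to sample its range, then orthonormalize to obtain an isometry Q.
function naive_sketch(A::AbstractMatrix, howmany::Integer; rng = Random.default_rng())
    Ω = randn(rng, eltype(A), size(A, 2), howmany)
    Q = Matrix(qr(A * Ω).Q)   # (m, howmany) isometry approximating range(A)
    B = Q' * A                # small core factor, so that A ≈ Q * B
    return Q, B
end

A = randn(100, 4) * randn(4, 80)   # exactly rank 4
Q, B = naive_sketch(A, 4)
```

Because `A` here has exact rank 4 and the sketch dimension matches, `A ≈ Q * B` holds to machine precision; for matrices of higher numerical rank the approximation quality depends on the spectrum.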

## Standalone Sketches

The basic sketching primitives are [`left_sketch`](@ref) and [`right_sketch`](@ref) (with their bang counterparts).
They return an isometric range/co-range factor together with a corresponding core matrix, independent of any subsequent decomposition.

[`left_sketch`](@ref) computes an isometric matrix `Q` whose column span approximates the range of `A`, together with the core factor `B = Q' * A`:

```jldoctest sketching; output=false
using MatrixAlgebraKit
using MatrixAlgebraKit: diagview
using LinearAlgebra: norm
using Random: MersenneTwister

# A rank-3 matrix
A = randn(MersenneTwister(0), 8, 3) * randn(MersenneTwister(1), 3, 6);

Q, B = left_sketch(A; howmany = 3, rng = MersenneTwister(42));
isisometric(Q) && A ≈ Q * B

# output
true
```

[`right_sketch`](@ref) is the dual operation, returning a right-isometric matrix `Pᴴ` (orthonormal rows) and core factor `B = A * Pᴴ'`:

```jldoctest sketching; output=false
B, Pᴴ = right_sketch(A; howmany = 3, rng = MersenneTwister(42));
isisometric(Pᴴ') && A ≈ B * Pᴴ

# output
true
```

These functions follow the same conventions as other decompositions: `left_sketch!` / `right_sketch!` may destroy the input, and the bang forms optionally accept pre-allocated output tuples.
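For instance, assuming the pre-allocated form mirrors the other in-place decompositions (a hypothetical sketch; consult the docstrings below for the exact signature):

```julia
# Hypothetical pre-allocated call; the output-tuple layout (Q, B) is an
# assumption based on the convention described above.
Q = similar(A, size(A, 1), 3)
B = similar(A, 3, size(A, 2))
Q, B = left_sketch!(copy(A), (Q, B); rng = MersenneTwister(42))
```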

```@docs; canonical=false
left_sketch
left_sketch!
right_sketch
right_sketch!
```

## Available Sketching Strategies

The keyword arguments accepted by `left_sketch` / `right_sketch` are forwarded to the default sketching strategy for the input type (currently [`GaussianSketching`](@ref) for all `AbstractMatrix`).
For full control, construct a strategy directly and pass it as the second positional argument:

```jldoctest sketching; output=false
Q, B = left_sketch(A, GaussianSketching(3; numiter = 4, rng = MersenneTwister(42)));
A ≈ Q * B

# output
true
```

```@docs; canonical=false
GaussianSketching
SketchingStrategy
```

The `numiter` keyword controls the number of power iterations.
The first iteration is the initial sketch; each additional iteration applies `A * A'` to improve accuracy on slowly decaying spectra, at the cost of two extra matrix multiplications per iteration.
The default `numiter = 2` is a conservative choice; values of 4–5 often improve accuracy significantly when the singular values decay slowly.
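The benefit of extra power iterations is easy to demonstrate with a self-contained plain-Julia experiment (independent of MatrixAlgebraKit's internals):

```julia
using LinearAlgebra, Random

# A 200×100 matrix with slowly decaying singular values σᵢ = 1/i.
rng = MersenneTwister(0)
U0 = Matrix(qr(randn(rng, 200, 200)).Q)
V0 = Matrix(qr(randn(rng, 100, 100)).Q)
A = U0[:, 1:100] * Diagonal(1 ./ (1:100)) * V0'

# Relative error of projecting A onto a rank-k sketched subspace.
function sketch_error(A, k, numiter, rng)
    Ω = randn(rng, size(A, 2), k)
    Y = A * Ω
    for _ in 2:numiter          # each extra pass applies A * A' to sharpen the subspace
        Y = A * (A' * Y)
    end
    Q = Matrix(qr(Y).Q)
    return norm(A - Q * (Q' * A)) / norm(A)
end

e1 = sketch_error(A, 10, 1, MersenneTwister(1))
e4 = sketch_error(A, 10, 4, MersenneTwister(1))
# With slow spectral decay, e4 is typically noticeably smaller than e1.
```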

!!! note "Additional strategies"
[`GaussianSketching`](@ref) is currently the only built-in [`SketchingStrategy`](@ref).
Additional strategies (for example, structured or subsampled sketches) may be added in the future; the interface is deliberately written against the abstract [`SketchingStrategy`](@ref) supertype so that new strategies plug in without changes to the downstream decomposition code.

## Sketched Partial Decompositions

A sketch can be combined with a small dense decomposition of the core factor to obtain an approximate partial decomposition of the original matrix.
At present, this is supported for the truncated SVD via [`svd_trunc`](@ref) / [`svd_trunc!`](@ref) / [`svd_trunc_no_error`](@ref).

!!! note "Not yet supported"
Sketched variants of [`eigh_trunc`](@ref) and [`eig_trunc`](@ref) are natural extensions of the same machinery but are not implemented yet.
The [`SketchedAlgorithm`](@ref) wrapper and the `sketch =` keyword are currently only accepted by the truncated SVD functions.

There are two equivalent ways to request a sketched truncated SVD, paralleling the two-form syntax used for [Truncations](@ref).

### 1. Using the `sketch` keyword with a `NamedTuple`

The simplest form is to pass a `NamedTuple` of sketch parameters together with the desired truncation:

```jldoctest sketching; output=false
U, S, Vᴴ, ϵ = svd_trunc(A;
sketch = (; howmany = 3, rng = MersenneTwister(42)),
trunc = truncrank(3),
);
size(diagview(S), 1) == 3 && A ≈ U * S * Vᴴ

# output
true
```

The `NamedTuple` keywords are forwarded to the default sketching strategy for the input type, exactly as for `left_sketch` above.

### 2. Using an explicit `SketchedAlgorithm`

For full control, construct a [`SketchedAlgorithm`](@ref) value directly and pass it as the `alg` argument:

```jldoctest sketching; output=false
alg = SketchedAlgorithm(;
sketch = GaussianSketching(3; rng = MersenneTwister(42)),
trunc = truncrank(3),
);
U, S, Vᴴ, ϵ = svd_trunc(A, alg);
A ≈ U * S * Vᴴ

# output
true
```

When an `alg::SketchedAlgorithm` is supplied, the `sketch` and `trunc` keywords cannot also be specified at the call site; doing so throws an `ArgumentError`.
All configuration must instead live inside the algorithm constructor.

### The `SketchedAlgorithm` Wrapper

```@docs; canonical=false
SketchedAlgorithm
```

`SketchedAlgorithm` differs from [`TruncatedAlgorithm`](@ref) in that it is *self-truncating*: the sketch step itself produces a small dense problem of size `sketch.howmany`, and any further `trunc` is applied to the result of the inner factorization rather than to a full dense decomposition.
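Schematically, the generic sketch-then-decompose pipeline amounts to the following plain-Julia sketch (a simplified illustration of the idea, not the actual implementation):

```julia
using LinearAlgebra, Random

# Simplified sketch-then-decompose pipeline for a truncated SVD:
# 1. sketch:    A ≈ Q * B with Q an (m, howmany) isometry
# 2. decompose: small dense SVD of the (howmany, n) core factor B
# 3. lift:      U = Q * F.U recovers approximate left singular vectors of A
function sketched_svd(A, howmany; rng = Random.default_rng())
    Ω = randn(rng, size(A, 2), howmany)
    Q = Matrix(qr(A * Ω).Q)
    B = Q' * A
    F = svd(B)
    return Q * F.U, F.S, F.Vt
end

A = randn(50, 3) * randn(3, 40)   # rank 3
U, S, Vᴴ = sketched_svd(A, 3)
A ≈ U * Diagonal(S) * Vᴴ          # true up to floating-point error
```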

The `driver` field selects the backend implementing the sketched pipeline:

- `Native()` (the default for CPU array types) runs the generic *sketch-then-decompose* pipeline using the standard [`left_sketch!`](@ref) / [`right_sketch!`](@ref) building blocks followed by the inner `alg`.
- `CUSOLVER()` (the default for CUDA array types, with the appropriate extension loaded) dispatches to cuSOLVER's fused `gesvdr` kernel, which performs the sketch and the small SVD in a single device call.

To force a particular driver, set it explicitly:

```julia
using MatrixAlgebraKit: CUSOLVER # driver types are not exported by default

# assuming `A_cuda::CuMatrix` and a target rank `k` are defined
alg = SketchedAlgorithm(;
sketch = GaussianSketching(k; numiter = 4),
trunc = truncrank(k),
driver = CUSOLVER(),
)
U, S, Vᴴ, ϵ = svd_trunc(A_cuda, alg)
```
30 changes: 26 additions & 4 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@@ -4,11 +4,12 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, TruncationByOrder, AbstractAlgorithm
using MatrixAlgebraKit: GaussianSketching, SketchingStrategy, SketchedAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
import MatrixAlgebraKit: _sylvester, svd_rank
using CUDA, CUDA.cuBLAS
using CUDA: i32
using LinearAlgebra
@@ -17,6 +18,7 @@ using LinearAlgebra: BlasFloat
include("yacusolver.jl")

MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
MatrixAlgebraKit.default_driver(::Type{<:SketchedAlgorithm}, ::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return QRIteration(; kwargs...)
@@ -50,8 +52,28 @@
gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)

_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)
# Sketched SVD via cuSOLVER's gesvdr kernel.
# The full m×m / n×n shapes of U / Vᴴ allow YACUSOLVER.gesvdr! to reuse them as cuSOLVER workspace.
# `alg` is accepted but unused: cuSOLVER's gesvdr fuses the inner SVD itself.
function gesvdr!(
::CUSOLVER, A::StridedCuMatrix, S, U::StridedCuMatrix, Vᴴ::StridedCuMatrix;
sketch::GaussianSketching, trunc::TruncationByOrder, alg::AbstractAlgorithm = DefaultAlgorithm()
)
isempty(A) && return U, S, Vᴴ
m, n = size(A)
sketch_amount = min(sketch.howmany, m, n)
k = min(trunc.howmany, m, n)
p = max(sketch_amount - k, 0)
numiter = sketch.numiter

V = Vᴴ # gesvdr returns V rather than Vᴴ; the buffer has the same (n, n) size, so reuse Vᴴ as workspace

YACUSOLVER.gesvdr!(A, diagview(S), U, V; k, p, numiter)

# Truncation operates on Vᴴ, so take the adjoint of V here
USVᴴtrunc, _ = MatrixAlgebraKit.truncate(MatrixAlgebraKit.svd_trunc!, (U, S, V'), trunc)
return USVᴴtrunc
end

geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, Dd, V)
58 changes: 31 additions & 27 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
@@ -266,41 +266,53 @@ for (bname, fname, elty, relty) in
end
end

# Wrapper for randomized SVD
# Wrapper for randomized SVD.
# Caller must supply full-size buffers: U is (m, m) and V is (n, n); both are reused
# directly as cuSOLVER's workspace, with the computed factors stored in their leading
# columns.
# !!! Warning: this function takes in/returns V instead of Vᴴ
function gesvdr!(
A::StridedCuMatrix{T},
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Vᴴ::StridedCuMatrix{T} = similar(A, T, min(size(A)...), size(A, 2));
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), size(A, 1)),
V::StridedCuMatrix{T} = similar(A, T, size(A, 2), size(A, 2));
k::Int = length(S),
p::Int = min(size(A)...) - k - 1,
niters::Int = 1
numiter::Int = 1,
) where {T <: BlasFloat}
chkstride1(A, U, S, Vᴴ)
chkstride1(A, U, S, V)
m, n = size(A)
minmn = min(m, n)
jobu = length(U) == 0 ? 'N' : 'S'
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
R = eltype(S)
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
R == real(T) ||
throw(ArgumentError("S does not have the matching real `eltype` of A"))

Ṽ = similar(Vᴴ, (n, n))
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
length(S) == minmn ||
throw(DimensionMismatch("length of S ($(length(S))) must equal min(size(A)) = $minmn"))
size(U) == (m, m) ||
throw(DimensionMismatch("U must have shape (m, m) = ($m, $m); got $(size(U))"))
size(V) == (n, n) ||
throw(DimensionMismatch("V must have shape (n, n) = ($n, $n); got $(size(V))"))
k < minmn ||
throw(DimensionMismatch("rank k ($k) must be less than min(size(A)) = $minmn"))
k + p < minmn ||
throw(DimensionMismatch("k + p ($(k + p)) must be less than min(size(A)) = $minmn"))

isempty(A) && return S, U, V

jobu = 'S'
jobv = 'S'
lda = max(1, stride(A, 2))
ldu = max(1, stride(Ũ, 2))
ldv = max(1, stride(Ṽ, 2))
ldu = max(1, stride(U, 2))
ldv = max(1, stride(V, 2))
params = cuSOLVER.CuSolverParameters()
dh = cuSOLVER.dense_handle()

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
cuSOLVER.cusolverDnXgesvdr_bufferSize(
dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
dh, params, jobu, jobv, m, n, k, p, numiter,
T, A, lda, R, S, T, U, ldu, T, V, ldv,
T, out_gpu, out_cpu
)

@@ -311,8 +323,8 @@
bufferSize()...
) do buffer_gpu, buffer_cpu
return cuSOLVER.cusolverDnXgesvdr(
dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
dh, params, jobu, jobv, m, n, k, p, numiter,
T, A, lda, R, S, T, U, ldu, T, V, ldv,
T, buffer_gpu, sizeof(buffer_gpu),
buffer_cpu, sizeof(buffer_cpu),
dh.info
Expand All @@ -321,16 +333,8 @@ function gesvdr!(

flag = @allowscalar dh.info[1]
cuSOLVER.chklapackerror(BlasInt(flag))
if Ũ !== U && length(U) > 0
U .= view(Ũ, 1:m, 1:size(U, 2))
end
if length(Vᴴ) > 0
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
end
Ũ !== U && CUDA.unsafe_free!(Ũ)
CUDA.unsafe_free!(Ṽ)

return S, U, Vᴴ
return S, U, V
end

# Wrapper for general eigensolver