Skip to content

Commit

Permalink
Fix and test for issue 100 (#103)
Browse files Browse the repository at this point in the history
* Fix and test for issue 100

* Formatter

* apply code suggestions

* Remove GPUArraysCore and bump version

* Migrate to type domain
  • Loading branch information
lkdvos authored Nov 13, 2024
1 parent 1b175b7 commit 7e91607
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 11 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
name = "KrylovKit"
uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
authors = ["Jutho Haegeman"]
version = "0.8.2"
version = "0.8.3"

[deps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -22,7 +21,6 @@ Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
GPUArraysCore = "0.1"
LinearAlgebra = "1"
PackageExtensionCompat = "1"
Printf = "1"
Expand Down
1 change: 0 additions & 1 deletion src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ using VectorInterface: add!!
using LinearAlgebra
using Printf
using Random
using GPUArraysCore
using PackageExtensionCompat
const IndexRange = AbstractRange{Int}

Expand Down
22 changes: 15 additions & 7 deletions src/orthonormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ LinearAlgebra.mul!(y, b::OrthonormalBasis, x::AbstractVector) = unproject!!(y, b

const BLOCKSIZE = 4096

# helper function to determine if a multithreaded approach should be used
# this uses functionality beyond VectorInterface, but can be faster
_use_multithreaded_array_kernel(y) = _use_multithreaded_array_kernel(typeof(y))
_use_multithreaded_array_kernel(::Type) = false
function _use_multithreaded_array_kernel(::Type{<:Array{T}}) where {T<:Number}
return isbitstype(T) && get_num_threads() > 1
end
function _use_multithreaded_array_kernel(::Type{<:OrthonormalBasis{T}}) where {T}
return _use_multithreaded_array_kernel(T)
end

"""
project!!(y::AbstractVector, b::OrthonormalBasis, x,
[α::Number = 1, β::Number = 0, r = Base.OneTo(length(b))])
Expand Down Expand Up @@ -127,8 +138,7 @@ function unproject!!(y,
α::Number=true,
β::Number=false,
r=Base.OneTo(length(b)))
if y isa AbstractArray && !(y isa AbstractGPUArray) && IndexStyle(y) isa IndexLinear &&
get_num_threads() > 1
if _use_multithreaded_array_kernel(y)
return unproject_linear_multithreaded!(y, b, x, α, β, r)
end
# general case: using only vector operations, i.e. axpy! (similar to BLAS level 1)
Expand Down Expand Up @@ -157,7 +167,7 @@ function unproject_linear_multithreaded!(y::AbstractArray,
length(b[rj]) == m || throw(DimensionMismatch())
end
if n == 0
return β == 1 ? y : β == 0 ? fill!(y, 0) : rmul!(y, β)
return β == 1 ? y : β == 0 ? zerovector!(y) : scale!(y, β)
end
let m = m, n = n, y = y, x = x, b = b, blocksize = prevpow(2, div(BLOCKSIZE, n))
@sync for II in splitrange(1:blocksize:m, get_num_threads())
Expand Down Expand Up @@ -213,8 +223,7 @@ It is the user's responsibility to make sure that the result is still an orthono
α::Number=true,
β::Number=true,
r=Base.OneTo(length(b)))
if y isa AbstractArray && !(y isa AbstractGPUArray) && IndexStyle(y) isa IndexLinear &&
Threads.nthreads() > 1
if _use_multithreaded_array_kernel(y)
return rank1update_linear_multithreaded!(b, y, x, α, β, r)
end
# general case: using only vector operations, i.e. axpy! (similar to BLAS level 1)
Expand Down Expand Up @@ -294,8 +303,7 @@ and are stored in `b`, so the old basis vectors are thrown away. Note that, by d
the subspace spanned by these basis vectors is exactly the same.
"""
function basistransform!(b::OrthonormalBasis{T}, U::AbstractMatrix) where {T} # U should be unitary or isometric
if T <: AbstractArray && !(T <: AbstractGPUArray) && IndexStyle(T) isa IndexLinear &&
get_num_threads() > 1
if _use_multithreaded_array_kernel(b)
return basistransform_linear_multithreaded!(b, U)
end
m, n = size(U)
Expand Down
19 changes: 19 additions & 0 deletions test/issues.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# https://github.com/Jutho/KrylovKit.jl/issues/100
@testset "Issue #100" begin
N = 32 # needs to be large enough to trigger shrinking
A = rand(N, N)
A += A'
v₀ = [rand(N ÷ 2), rand(N ÷ 2)]

vals, vecs, = eigsolve(v₀, 4, :LM; ishermitian=true) do v
v′ = vcat(v...)
y = A * v′
return [y[1:(N ÷ 2)], y[(N ÷ 2 + 1):end]]
end

vals2, vecs2, = eigsolve(A, 4, :LM; ishermitian=true)
@test vals vals2
for (v, v′) in zip(vecs, vecs2)
@test abs(inner(vcat(v...), v′)) 1
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ include("ad/svdsolve.jl")
t = time() - t
println("Tests finished in $t seconds")

# Issues
# ------
include("issues.jl")

module AquaTests
using KrylovKit
using Aqua
Expand Down

2 comments on commit 7e91607

@lkdvos
Copy link
Collaborator Author

@lkdvos lkdvos commented on 7e91607 Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119360

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.3 -m "<description of version>" 7e91607104024ead221ddfd6d444288606b88997
git push origin v0.8.3

Please sign in to comment.