Skip to content

Commit

Permalink
hang [h/v]cat lower
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Jun 28, 2024
1 parent 3c9b0ad commit 223f060
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 34 deletions.
3 changes: 2 additions & 1 deletion src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ const LinearMapVector = AbstractVector{<:LinearMap}
const LinearMapTupleOrVector = Union{LinearMapTuple,LinearMapVector}

Base.eltype(::LinearMap{T}) where {T} = T
Base.eltype(::Type{L}) where {T,L<:LinearMap{T}} = T
Base.eltype(::Type{<:LinearMap{T}}) where {T} = T
Base.eltypeof(x::LinearMap) = eltype(x)
Base.eltypeof(J::UniformScaling) = eltype(J) # fix upstream
Base.promote_eltypeof(v1::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}, vs::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}...) where {T} = T

# conversion to LinearMap
Expand Down
39 changes: 20 additions & 19 deletions src/blockmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,7 @@ julia> L * ones(Int, 6)
"""
Base.hcat

Base.hcat(As::T...) where {T<:LinearMap} = Base._cat_t(Val(2), eltype(T), As...)
function Base._cat_t(::Val{2}, ::Type{T}, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) where {T}
nbc = length(As)

# find first non-UniformScaling to detect number of rows
j = findfirst(A -> !isa(A, UniformScaling), As)
# this should not happen, function should only be called with at least one LinearMap
@assert !isnothing(j)
@inbounds nrows = size(As[j], 1)::Int

return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nbc)), 1, 1, As...), (nbc,))
end
Base.hcat(As::T...) where {T<:LinearMap} = Base._cat(Val(2), As...)

############
# vcat
Expand Down Expand Up @@ -123,18 +112,29 @@ julia> L * ones(Int, 3)
"""
Base.vcat

Base.vcat(As::T...) where {T<:LinearMap} = Base._cat_t(Val(1), eltype(T), As...)
function Base._cat_t(::Val{1}, ::Type{T}, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) where {T}
nbr = length(As)
Base.vcat(As::T...) where {T<:LinearMap} = Base._cat(Val(1), As...)

function Base._cat(dims, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...)
T = promote_type(map(eltype, As)...)
nb = length(As)

# find first non-UniformScaling to detect number of rows
j = findfirst(A -> !isa(A, UniformScaling), As)
# this should not happen, function should only be called with at least one LinearMap
@assert !isnothing(j)
@inbounds ncols = size(As[j], 2)::Int

rows = ntuple(_ -> 1, Val(nbr))
return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nbr)), 1, 2, As...), rows)
if dims isa Val{2}
@inbounds nrows = size(As[j], 1)::Int
return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nb)), 1, 1, As...), (nb,))
elseif dims isa Val{1}
@inbounds ncols = size(As[j], 2)::Int

rows = ntuple(_ -> 1, Val(nb))
return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nb)), 1, 2, As...), rows)
elseif dims isa Dims{2}
Base._cat_t(dims, T, As...)
else
throw(ArgumentError("unhandled dims argument"))
end
end

############
Expand Down Expand Up @@ -225,6 +225,7 @@ end

promote_to_lmaps_(n::Int, dim, A::AbstractVecOrMat) = (check_dim(A, dim, n); LinearMap(A))
promote_to_lmaps_(n::Int, dim, J::UniformScaling) = UniformScalingMap(J.λ, n)
promote_to_lmaps_(n::Int, dim, Q::AbstractQ) = (check_dim(Q, dim, n); LinearMap(Q))
promote_to_lmaps_(n::Int, dim, A::LinearMap) = (check_dim(A, dim, n); A)
promote_to_lmaps(n, k, dim) = ()
promote_to_lmaps(n, k, dim, A) = (promote_to_lmaps_(n[k], dim, A),)
Expand Down
4 changes: 1 addition & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Octonions = "d00ba074-1e29-4f5e-9fd4-d67071d6a14d"
Expand All @@ -15,11 +14,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8"
BlockArrays = "0.16"
BlockArrays = "0.16, 1"
ChainRulesCore = "1"
ChainRulesTestUtils = "1.9"
Documenter = "1"
InteractiveUtils = "1.6"
IterativeSolvers = "0.9"
LinearAlgebra = "1.6"
Octonions = "0.1, 0.2"
Expand Down
22 changes: 11 additions & 11 deletions test/blockmap.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, LinearMaps, LinearAlgebra, SparseArrays, InteractiveUtils
using Test, LinearMaps, LinearAlgebra, SparseArrays
using LinearMaps: FiveArg

@testset "block maps" begin
Expand Down Expand Up @@ -29,10 +29,9 @@ using LinearMaps: FiveArg
A = [A11 A12 A11]
@test Matrix(L) == A == mul!(zero(A), L, 1, true, false)
A = [I I I A11 A11 A11 a]
@test (@which [A11 A11 A11]).module != LinearMaps
@test (@which [I I I A11 A11 A11]).module != LinearMaps
@test (@which hcat(I, I, I)).module != LinearMaps
@test (@which hcat(I, I, I, LinearMap(A11), A11, A11)).module == LinearMaps
@test [A11 A11 A11] isa AbstractArray
@test [I I I A11 A11 A11 qr(A11).Q a] isa AbstractArray
@test [I I I A11 A11 A11 qr(A11).Q LinearMap(a)] isa LinearMap
maps = @inferred LinearMaps.promote_to_lmaps(ntuple(i->m, 7), 1, 1, I, I, I, LinearMap(A11), A11, A11, a)
@inferred LinearMaps.rowcolranges(maps, (7,))
L = @inferred hcat(I, I, I, LinearMap(A11), A11, A11, a)
Expand Down Expand Up @@ -68,15 +67,15 @@ using LinearMaps: FiveArg
Lv = LinearMaps.BlockMap{elty}([LinearMap(A11), LinearMap(A21)], (1,1))
@test Lv.maps isa Vector
@test L == Lv
@test (@which [A11; A21]).module != LinearMaps
@test [A11; A21] isa AbstractArray
A = [A11; A21]
x = rand(elty, n)
@test size(L) == size(A)
@test Matrix(L) == A == mul!(copy(A), L, 1, true, false)
@test Matrix(Lv) == A == mul!(copy(A), Lv, 1, true, false)
@test L * x Lv * x A * x
A = [I; I; I; A11; A11; A11; reduce(hcat, fill(v, n))]
@test (@which [I; I; I; A11; A11; A11; v v v v v v v v v v]).module != LinearMaps
@test [I; I; I; A11; A11; A11; v v v] isa AbstractArray
L = @inferred vcat(I, I, I, LinearMap(A11), LinearMap(A11), LinearMap(A11), reduce(hcat, fill(v, n)))
@test L == [I; I; I; LinearMap(A11); LinearMap(A11); LinearMap(A11); reduce(hcat, fill(v, n))]
@test L isa LinearMaps.BlockMap{elty}
Expand All @@ -97,7 +96,7 @@ using LinearMaps: FiveArg
A21 = rand(elty, m2, m1)
A22 = ones(elty, m2, m2)
A = [A11 A12; A21 A22]
@test (@which [A11 A12; A21 A22]).module != LinearMaps
@test [A11 A12; A21 A22] isa AbstractArray
@inferred hvcat((2,2), LinearMap(A11), LinearMap(A12), LinearMap(A21), LinearMap(A22))
L = [LinearMap(A11) LinearMap(A12); LinearMap(A21) LinearMap(A22)]
@test L.maps isa Tuple
Expand All @@ -116,7 +115,7 @@ using LinearMaps: FiveArg
end
@test convert(AbstractMatrix, L) == A
A = [I A12; A21 I]
@test (@which [I A12; A21 I]).module != LinearMaps
@test [I A12; A21 I] isa AbstractArray
@inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I)
L = @inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I)
@test L isa LinearMaps.BlockMap{elty}
Expand Down Expand Up @@ -222,8 +221,9 @@ using LinearMaps: FiveArg
end
# Md = diag(M1, M2, M3, M2, M1) # unsupported so use sparse:
Md = Matrix(blockdiag(sparse.((M1, M2, M3, M2, M1))...))
@test (@which blockdiag(sparse.((M1, M2, M3, M2, M1))...)).module != LinearMaps
@test (@which cat(M1, M2, M3, M2, M1; dims=(1,2))).module != LinearMaps
@test blockdiag(sparse.((M1, M2, M3, M2, M1))...) isa AbstractArray
@test cat(M1, M2, M3, M2, M1; dims=(1,2)) isa AbstractArray
@test cat(M2, M2, qr(M3).Q, M3[:,1]; dims=(1,2)) isa AbstractArray
x = randn(elty, size(Md, 2))
Bd = @inferred blockdiag(L1, L2, L3, L2, L1)
@test Bd.maps isa Tuple
Expand Down

0 comments on commit 223f060

Please sign in to comment.