Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting Unitful LinearMaps #196

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Manifest.toml
78 changes: 54 additions & 24 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,25 @@ struct CompositeMap{T, As<:LinearMapTupleOrVector} <: LinearMap{T}
for n in 2:N
check_dim_mul(maps[n], maps[n-1])
end
for TA in Base.Iterators.map(eltype, maps)
promote_type(T, TA) == T ||
error("eltype $TA cannot be promoted to $T in CompositeMap constructor")
end
Tprod = eltype(prod(m -> oneunit(eltype(m)), maps)) # handles units
JeffFessler marked this conversation as resolved.
Show resolved Hide resolved
promote_type(T, Tprod) == T ||
error("eltype $Tprod and $T incompatible in CompositeMap constructor")
# for TA in Base.Iterators.map(eltype, maps) # todo: cut
# promote_type(T, TA) == T ||
# error("eltype $TA cannot be promoted to $T in CompositeMap constructor")
# end
new{T, As}(maps)
end
end

CompositeMap{T}(maps::As) where {T, As<:LinearMapTupleOrVector} = CompositeMap{T, As}(maps)

# constructor with eltype inferred from the product
function CompositeMap(maps::As) where {As <: LinearMapTupleOrVector}
T = eltype(prod(m -> oneunit(eltype(m)), maps)) # todo: can the compiler infer?
return CompositeMap{T, As}(maps)
end
JeffFessler marked this conversation as resolved.
Show resolved Hide resolved

Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::LinearMapTupleOrVector) =
CompositeMap{promote_type(map(eltype, maps)...)}(reverse(maps))
Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::AbstractVector{<:LinearMap{T}}) where {T} =
Expand Down Expand Up @@ -78,39 +88,51 @@ end

# scalar multiplication and division (non-commutative case)
function Base.:(*)(α::Number, A::LinearMap)
T = promote_type(typeof(α), eltype(A))
return CompositeMap{T}(_combine(A, UniformScalingMap(α, size(A, 1))))
# T = promote_type(typeof(α), eltype(A))
# T = eltype(oneunit(α) * oneunit(eltype(A)))
# return CompositeMap{T}(_combine(A, UniformScalingMap(α, size(A, 1))))
return CompositeMap(_combine(A, UniformScalingMap(α, size(A, 1))))
JeffFessler marked this conversation as resolved.
Show resolved Hide resolved
end
function Base.:(*)(α::Number, A::CompositeMap)
T = promote_type(typeof(α), eltype(A))
# T = promote_type(typeof(α), eltype(A))
# T = eltype(oneunit(α) * oneunit(eltype(A)))
Alast = last(A.maps)
if Alast isa UniformScalingMap
return CompositeMap{T}(_combine(_front(A.maps), α * Alast))
# return CompositeMap{T}(_combine(_front(A.maps), α * Alast))
return CompositeMap(_combine(_front(A.maps), α * Alast))
else
return CompositeMap{T}(_combine(A.maps, UniformScalingMap(α, size(A, 1))))
# return CompositeMap{T}(_combine(A.maps, UniformScalingMap(α, size(A, 1))))
return CompositeMap(_combine(A.maps, UniformScalingMap(α, size(A, 1))))
end
end
# needed for disambiguation
function Base.:(*)(α::RealOrComplex, A::CompositeMap{<:RealOrComplex})
T = Base.promote_op(*, typeof(α), eltype(A))
# T = Base.promote_op(*, typeof(α), eltype(A))
T = eltype(oneunit(α) * oneunit(eltype(A)))
return ScaledMap{T}(α, A)
end
function Base.:(*)(A::LinearMap, α::Number)
T = promote_type(typeof(α), eltype(A))
return CompositeMap{T}(_combine(UniformScalingMap(α, size(A, 2)), A))
# T = promote_type(typeof(α), eltype(A))
# T = eltype(oneunit(eltype(A)) * oneunit(α))
# return CompositeMap{T}(_combine(UniformScalingMap(α, size(A, 2)), A))
return CompositeMap(_combine(UniformScalingMap(α, size(A, 2)), A))
end
function Base.:(*)(A::CompositeMap, α::Number)
T = promote_type(typeof(α), eltype(A))
# T = promote_type(typeof(α), eltype(A))
# T = eltype(oneunit(eltype(A)) * oneunit(α))
Afirst = first(A.maps)
if Afirst isa UniformScalingMap
return CompositeMap{T}(_combine(Afirst * α, _tail(A.maps)))
# return CompositeMap{T}(_combine(Afirst * α, _tail(A.maps)))
return CompositeMap(_combine(Afirst * α, _tail(A.maps)))
else
return CompositeMap{T}(_combine(UniformScalingMap(α, size(A, 2)), A.maps))
# return CompositeMap{T}(_combine(UniformScalingMap(α, size(A, 2)), A.maps))
return CompositeMap(_combine(UniformScalingMap(α, size(A, 2)), A.maps))
end
end
# needed for disambiguation
function Base.:(*)(A::CompositeMap{<:RealOrComplex}, α::RealOrComplex)
T = Base.promote_op(*, typeof(α), eltype(A))
# T = Base.promote_op(*, typeof(α), eltype(A))
T = eltype(oneunit(eltype(A)) * oneunit(α))
return ScaledMap{T}(α, A)
end

Expand All @@ -135,20 +157,28 @@ julia> LinearMap(ones(Int, 3, 3)) * CS * I * rand(3, 3);
```
"""
function Base.:(*)(A₁::LinearMap, A₂::LinearMap)
T = promote_type(eltype(A₁), eltype(A₂))
return CompositeMap{T}(_combine(A₂, A₁))
# T = promote_type(eltype(A₁), eltype(A₂))
# T = eltype(prod(m -> oneunit(eltype(m)), (A₁, A₂)))
# return CompositeMap{T}(_combine(A₂, A₁))
return CompositeMap(_combine(A₂, A₁))
end
function Base.:(*)(A₁::LinearMap, A₂::CompositeMap)
T = promote_type(eltype(A₁), eltype(A₂))
return CompositeMap{T}(_combine(A₂.maps, A₁))
# T = promote_type(eltype(A₁), eltype(A₂))
# T = eltype(prod(m -> oneunit(eltype(m)), (A₁, A₂)))
# return CompositeMap{T}(_combine(A₂.maps, A₁))
return CompositeMap(_combine(A₂.maps, A₁))
end
function Base.:(*)(A₁::CompositeMap, A₂::LinearMap)
T = promote_type(eltype(A₁), eltype(A₂))
return CompositeMap{T}(_combine(A₂, A₁.maps))
# T = promote_type(eltype(A₁), eltype(A₂))
# T = eltype(prod(m -> oneunit(eltype(m)), (A₁, A₂)))
# return CompositeMap{T}(_combine(A₂, A₁.maps))
return CompositeMap(_combine(A₂, A₁.maps))
end
function Base.:(*)(A₁::CompositeMap, A₂::CompositeMap)
T = promote_type(eltype(A₁), eltype(A₂))
return CompositeMap{T}(_combine(A₂.maps, A₁.maps))
# T = promote_type(eltype(A₁), eltype(A₂))
# T = eltype(prod(m -> oneunit(eltype(m)), (A₁, A₂)))
# return CompositeMap{T}(_combine(A₂.maps, A₁.maps))
return CompositeMap(_combine(A₂.maps, A₁.maps))
end
# needed for disambiguation
Base.:(*)(A₁::ScaledMap, A₂::CompositeMap) = A₁.λ * (A₁.lmap * A₂)
Expand Down