From 79883ef2eb403d795ca5a998457069cb3ddae164 Mon Sep 17 00:00:00 2001 From: mitchphillipson Date: Tue, 23 Jan 2024 07:23:49 -0600 Subject: [PATCH 1/3] Improve performance getting indices --- src/dense_sparse.jl | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/dense_sparse.jl b/src/dense_sparse.jl index b0cf33f..33300d1 100644 --- a/src/dense_sparse.jl +++ b/src/dense_sparse.jl @@ -22,17 +22,17 @@ description(P::DenseSparseArray) = "" -@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || x∈GU[d]) ? (x == d ? [[i] for i∈GU[d]] : [[x]]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d.")) +@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || x∈GU[d].aliases || x∈GU[d]) ? (x == d ? [i for i∈GU[d]] : [x]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d.")) function _convert_idx(x::Vector,GU::GamsUniverse,d::Symbol) @assert (all(i∈GU[d] for i∈x)) "At least one element of $x is not in set $d" - return [[i] for i∈x] + return x end function _convert_idx(x::GamsSet,GU::GamsUniverse,d::Symbol) @assert x==GU[d] "The set\n\n$x\ndoes not match the domain set $d" - return [[i] for i∈x] + return [i for i∈x] end @@ -66,19 +66,18 @@ function dimension(x) return 1 end -function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any}) where {T,N} - domain_match = partition(domain(P),dimension.(idx)) #Used for masking - - @assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" +#Assume input has no masks +function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} GU = universe(P) - idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match) + d = domain(P) + idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) - X = Tuple.(Iterators.flatten.(Iterators.product(idx...))) + idx = collect(Iterators.product(idx...)) + idx = dropdims(idx,dims=tuple(findall(size(idx).==1)...)) data_dict = data(P) - length(X) == 1 ? get(data_dict,X[1],0) : get.(Ref(data_dict),X,0) - + length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T)) end @@ -88,20 +87,19 @@ end function Base.setindex!(P::DenseSparseArray{T,N}, value, idx::Vararg{Any,N}) where {T,N} - domain_match = partition(domain(P),dimension.(idx)) - @assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" + #domain_match = partition(domain(P),dimension.(idx)) + #@assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" GU = universe(P) - idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match) - - X = Tuple.(Iterators.flatten.(Iterators.product(idx...))) - - + d = domain(P) + idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) + + idx = collect(Iterators.product(idx...)) - if length(X) == 1 - _setindex!(P,value,X[1]) + if length(idx) == 1 + _setindex!(P,value,idx[1]) else - _setindex!.(Ref(P), value, X) + _setindex!.(Ref(P), value, idx) end end From 26ebf8a0aa59cfb2b7d1e9fe64eac892573a81dd Mon Sep 17 00:00:00 2001 From: mitchphillipson Date: Tue, 23 Jan 2024 07:41:41 -0600 Subject: [PATCH 2/3] Alias fix --- src/dense_sparse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense_sparse.jl b/src/dense_sparse.jl index 33300d1..2b3140c 100644 --- a/src/dense_sparse.jl +++ b/src/dense_sparse.jl @@ -22,7 +22,7 @@ description(P::DenseSparseArray) = "" -@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || x∈GU[d].aliases || x∈GU[d]) ? (x == d ? [i for i∈GU[d]] : [x]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d.")) +@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || x∈GU[d].aliases || x∈GU[d]) ? (x == d || x∈GU[d].aliases ? [i for i∈GU[d]] : [x]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d.")) function _convert_idx(x::Vector,GU::GamsUniverse,d::Symbol) From 6bf547bc76bb166306a7f24912d3662edaaa63a6 Mon Sep 17 00:00:00 2001 From: mitchphillipson Date: Tue, 23 Jan 2024 08:53:59 -0600 Subject: [PATCH 3/3] Fixing masking indexing --- src/dense_sparse.jl | 28 +++++++---------------- src/parameter.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/src/dense_sparse.jl b/src/dense_sparse.jl index 2b3140c..88a02c9 100644 --- a/src/dense_sparse.jl +++ b/src/dense_sparse.jl @@ -49,37 +49,25 @@ function Base.getindex(P::DenseSparseArray{T,N},idx::CartesianIndex{N}) where {T end -""" -if v = [:a,:b,:c,:d,:e] and p = (2,2,1) - -we expect output of - -[[:a,:b],[:c,:d],[:e]] -""" -function partition(v,p) - ind = ((sum(p[i] for i∈1:n;init=1),sum(p[i] for i∈1:(n+1);init=1)-1) for n∈0:(length(p)-1)) - return [[v[i] for i∈a:b] for (a,b)∈ind] +#Assume input has no masks +function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} + return _getindex(P,idx...) end -function dimension(x) - return 1 -end -#Assume input has no masks -function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} - +function _getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} GU = universe(P) d = domain(P) - idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) - idx = collect(Iterators.product(idx...)) - idx = dropdims(idx,dims=tuple(findall(size(idx).==1)...)) + idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) |> + x -> collect(Iterators.product(x...)) |> + x -> dropdims(x,dims=tuple(findall(size(x).==1)...)) data_dict = data(P) length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T)) -end +end ################ ### setindex ### diff --git a/src/parameter.jl b/src/parameter.jl index 7b5ff3b..351dad5 100644 --- a/src/parameter.jl +++ b/src/parameter.jl @@ -17,6 +17,61 @@ function _convert_idx(x::Mask,GU::GamsUniverse,d::Vararg{Symbol}) return keys(data(x)) end +""" + +if v = [:a,:b,:c,:d,:e] and p = (2,2,1) + +we expect output of + +[[:a,:b],[:c,:d],[:e]] +""" +function partition(v,p) + ind = ((sum(p[i] for i∈1:n;init=1),sum(p[i] for i∈1:(n+1);init=1)-1) for n∈0:(length(p)-1)) + return [[v[i] for i∈a:b] for (a,b)∈ind] +end + +function dimension(x) + return 1 +end + +function Base.getindex(P::Parameter{T,N},idx::Vararg{Any}) where {T,N} + if any(isa(x,Mask) for x∈idx) + _getindex_mask(P,idx...) + else + + GamsStructure._getindex(P,idx...) + end +end + +""" + Code from TupleTools.jl package +""" +flatten(x::Any) = (x,) +flatten(t::Tuple{}) = () +flatten(t::Tuple) = (flatten(t[1])..., flatten(Base.tail(t))...) +flatten(x, r...) = (flatten(x)..., flatten(r)...) + +function _getindex_mask(P::Parameter{T,N},idx...) where {T,N} + + GU = GamsStructure.universe(P) + d = GamsStructure.domain(P) + + domain_match = GamsStructure.partition(d,GamsStructure.dimension.(idx)) + + idx = map((x,d) -> GamsStructure._convert_idx(x,GU,d...), idx, domain_match) |> + x -> GamsStructure.collect(Iterators.product(x...)) |> + x -> GamsStructure.dropdims(x,dims=tuple(findall(size(x).==1)...)) |> + x -> flatten.(x) + + + #return idx + data_dict = GamsStructure.data(P) + length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T)) + +end + + + """ @create_parameters(GU,block)