diff --git a/src/dense_sparse.jl b/src/dense_sparse.jl index fc056dd..88a02c9 100644 --- a/src/dense_sparse.jl +++ b/src/dense_sparse.jl @@ -22,17 +22,17 @@ description(P::DenseSparseArray) = "" -@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || x∈GU[d].aliases || x∈GU[d]) ? (x == d || x∈GU[d].aliases ? [[i] for i∈GU[d]] : [[x]]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d.")) +@inline _convert_idx(x::Symbol,GU::GamsUniverse,d::Symbol) = (x==d || x∈GU[d].aliases || x∈GU[d]) ? (x == d || x∈GU[d].aliases ? [i for i∈GU[d]] : [x]) : throw(DomainError(x, "Symbol $x is neither a set nor an element of the set $d.")) function _convert_idx(x::Vector,GU::GamsUniverse,d::Symbol) @assert (all(i∈GU[d] for i∈x)) "At least one element of $x is not in set $d" - return [[i] for i∈x] + return x end function _convert_idx(x::GamsSet,GU::GamsUniverse,d::Symbol) @assert x==GU[d] "The set\n\n$x\ndoes not match the domain set $d" - return [[i] for i∈x] + return [i for i∈x] end @@ -49,59 +49,45 @@ function Base.getindex(P::DenseSparseArray{T,N},idx::CartesianIndex{N}) where {T end -""" -if v = [:a,:b,:c,:d,:e] and p = (2,2,1) - -we expect output of - -[[:a,:b],[:c,:d],[:e]] -""" -function partition(v,p) - ind = ((sum(p[i] for i∈1:n;init=1),sum(p[i] for i∈1:(n+1);init=1)-1) for n∈0:(length(p)-1)) - return [[v[i] for i∈a:b] for (a,b)∈ind] -end - -function dimension(x) - return 1 +#Assume input has no masks +function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} + return _getindex(P,idx...) end -function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any}) where {T,N} - domain_match = partition(domain(P),dimension.(idx)) #Used for masking - @assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" - +function _getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} GU = universe(P) - idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match) + d = domain(P) - X = Tuple.(Iterators.flatten.(Iterators.product(idx...))) + idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) |> + x -> collect(Iterators.product(x...)) |> + x -> dropdims(x,dims=tuple(findall(size(x).==1)...)) data_dict = data(P) - length(X) == 1 ? get(data_dict,X[1],0) : get.(Ref(data_dict),X,0) + length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T)) end - ################ ### setindex ### ################ function Base.setindex!(P::DenseSparseArray{T,N}, value, idx::Vararg{Any,N}) where {T,N} - domain_match = partition(domain(P),dimension.(idx)) - @assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" + #domain_match = partition(domain(P),dimension.(idx)) + #@assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" GU = universe(P) - idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match) - - X = Tuple.(Iterators.flatten.(Iterators.product(idx...))) - - + d = domain(P) + idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) + + idx = collect(Iterators.product(idx...)) - if length(X) == 1 - _setindex!(P,value,X[1]) + if length(idx) == 1 + _setindex!(P,value,idx[1]) else - _setindex!.(Ref(P), value, X) + _setindex!.(Ref(P), value, idx) end end diff --git a/src/parameter.jl b/src/parameter.jl index 7b5ff3b..351dad5 100644 --- a/src/parameter.jl +++ b/src/parameter.jl @@ -17,6 +17,61 @@ function _convert_idx(x::Mask,GU::GamsUniverse,d::Vararg{Symbol}) return keys(data(x)) end +""" + +if v = [:a,:b,:c,:d,:e] and p = (2,2,1) + +we expect output of + +[[:a,:b],[:c,:d],[:e]] +""" +function partition(v,p) + ind = ((sum(p[i] for i∈1:n;init=1),sum(p[i] for i∈1:(n+1);init=1)-1) for n∈0:(length(p)-1)) + return [[v[i] for i∈a:b] for (a,b)∈ind] +end + +function dimension(x) + return 1 +end + +function Base.getindex(P::Parameter{T,N},idx::Vararg{Any}) where {T,N} + if any(isa(x,Mask) for x∈idx) + _getindex_mask(P,idx...) + else + + GamsStructure._getindex(P,idx...) + end +end + +""" + Code from TupleTools.jl package +""" +flatten(x::Any) = (x,) +flatten(t::Tuple{}) = () +flatten(t::Tuple) = (flatten(t[1])..., flatten(Base.tail(t))...) +flatten(x, r...) = (flatten(x)..., flatten(r)...) + +function _getindex_mask(P::Parameter{T,N},idx...) where {T,N} + + GU = GamsStructure.universe(P) + d = GamsStructure.domain(P) + + domain_match = GamsStructure.partition(d,GamsStructure.dimension.(idx)) + + idx = map((x,d) -> GamsStructure._convert_idx(x,GU,d...), idx, domain_match) |> + x -> GamsStructure.collect(Iterators.product(x...)) |> + x -> GamsStructure.dropdims(x,dims=tuple(findall(size(x).==1)...)) |> + x -> flatten.(x) + + + #return idx + data_dict = GamsStructure.data(P) + length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T)) + +end + + + """ @create_parameters(GU,block)