diff --git a/src/dense_sparse.jl b/src/dense_sparse.jl index a14efa9..b600721 100644 --- a/src/dense_sparse.jl +++ b/src/dense_sparse.jl @@ -29,12 +29,12 @@ description(P::DenseSparseArray) = "" function _convert_idx(x::Vector,GU::GamsUniverse,d::Symbol) @assert (all(i∈GU[d] for i∈x)) "At least one element of $x is not in set $d" - return x + return [[i] for i∈x] end function _convert_idx(x::GamsSet,GU::GamsUniverse,d::Symbol) @assert x==GU[d] "The set\n\n$x\ndoes not match the domain set $d" - return [i for i∈x] + return [[i] for i∈x] end @@ -51,26 +51,39 @@ function Base.getindex(P::DenseSparseArray{T,N},idx::CartesianIndex{N}) where {T end +""" -#Assume input has no masks -function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} - return _getindex(P,idx...) +if v = [:a,:b,:c,:d,:e] and p = (2,2,1) + +we expect output of + +[[:a,:b],[:c,:d],[:e]] +""" +function partition(v,p) + ind = ((sum(p[i] for i∈1:n;init=1),sum(p[i] for i∈1:(n+1);init=1)-1) for n∈0:(length(p)-1)) + return [[v[i] for i∈a:b] for (a,b)∈ind] end +function dimension(x) + return 1 +end -function _getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N} +function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any}) where {T,N} + domain_match = partition(domain(P),dimension.(idx)) #Used for masking + + @assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message" + GU = universe(P) - d = domain(P) + idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match) - idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) |> - x -> collect(Iterators.product(x...)) |> - x -> dropdims(x,dims=tuple(findall(size(x).==1)...)) + X = Tuple.(Iterators.flatten.(Iterators.product(idx...))) data_dict = data(P) - length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T)) + length(X) == 1 ? get(data_dict,X[1],0) : get.(Ref(data_dict),X,0) end + ################ ### setindex ### ################