Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchphillipson committed May 9, 2024
2 parents abc1d0b + feecad5 commit eb0caec
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/dense_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ description(P::DenseSparseArray) = ""

function _convert_idx(x::Vector,GU::GamsUniverse,d::Symbol)
@assert (all(iGU[d] for ix)) "At least one element of $x is not in set $d"
return x
return [[i] for ix]
end

function _convert_idx(x::GamsSet,GU::GamsUniverse,d::Symbol)
@assert x==GU[d] "The set\n\n$x\ndoes not match the domain set $d"
return [i for ix]
return [[i] for ix]
end


Expand All @@ -51,26 +51,39 @@ function Base.getindex(P::DenseSparseArray{T,N},idx::CartesianIndex{N}) where {T
end


"""
#Assume input has no masks
function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N}
return _getindex(P,idx...)
if v = [:a,:b,:c,:d,:e] and p = (2,2,1)
we expect output of
[[:a,:b],[:c,:d],[:e]]
"""
function partition(v,p)
ind = ((sum(p[i] for i1:n;init=1),sum(p[i] for i1:(n+1);init=1)-1) for n0:(length(p)-1))
return [[v[i] for ia:b] for (a,b)ind]
end

function dimension(x)
return 1
end

function _getindex(P::DenseSparseArray{T,N},idx::Vararg{Any,N}) where {T,N}
function Base.getindex(P::DenseSparseArray{T,N},idx::Vararg{Any}) where {T,N}
domain_match = partition(domain(P),dimension.(idx)) #Used for masking

@assert sum(length.(domain_match)) == dimension(P) "Not enough inputs, or something. Get a better error message"

GU = universe(P)
d = domain(P)
idx = map((x,d)->_convert_idx(x,GU,d...),idx,domain_match)

idx = map((x,d) -> _convert_idx(x,GU,d), idx, d) |>
x -> collect(Iterators.product(x...)) |>
x -> dropdims(x,dims=tuple(findall(size(x).==1)...))
X = Tuple.(Iterators.flatten.(Iterators.product(idx...)))

data_dict = data(P)
length(idx) == 1 ? get(data_dict,idx[1],zero(T)) : get.(Ref(data_dict),idx,zero(T))
length(X) == 1 ? get(data_dict,X[1],0) : get.(Ref(data_dict),X,0)

end


################
### setindex ###
################
Expand Down

0 comments on commit eb0caec

Please sign in to comment.